Skip to main content

engram/
store_postgres.rs

1//! PostgreSQL-backed `FactStore` implementation.
2//!
3//! Uses native Postgres types: `UUID` for ids, `TIMESTAMPTZ` for timestamps,
4//! `JSONB` for metadata and entity_refs, and a generated `tsvector` column
5//! for full-text search. Placeholder syntax uses `$1, $2, …`.
6
7use crate::fact::{Fact, FactFilter, FactId, FactPatch, MemoryTier};
8use crate::scope::Scope;
9use crate::store::{FactStore, MemoryError, StoreStats};
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use sqlx::PgPool;
13use uuid::Uuid;
14
15// ---------------------------------------------------------------------------
16// DDL
17// ---------------------------------------------------------------------------
18
/// DDL statements for the Postgres `facts` table and its indexes.
///
/// Each element is a single statement to be executed independently (no
/// splitting on `;` needed). Every statement carries `IF NOT EXISTS`, so the
/// list can be applied again to a database that already has these objects.
const PG_FACT_STORE_DDL: &[&str] = &[
    // The table itself. `search_vector` is a stored generated column derived
    // from `text` with the 'english' text-search configuration; it is never
    // written directly by the INSERT statements in this file.
    r#"
    CREATE TABLE IF NOT EXISTS facts (
        id              UUID PRIMARY KEY,
        text            TEXT NOT NULL,
        org_id          TEXT NOT NULL DEFAULT 'default',
        agent_id        TEXT,
        user_id         TEXT,
        session_id      TEXT,
        tier            TEXT NOT NULL DEFAULT 'conversation',
        category        TEXT,
        source          TEXT,
        confidence      DOUBLE PRECISION,
        valid_from      TIMESTAMPTZ NOT NULL,
        invalid_at      TIMESTAMPTZ,
        created_at      TIMESTAMPTZ NOT NULL,
        entity_refs     JSONB NOT NULL DEFAULT '[]',
        supersedes      UUID,
        superseded_by   UUID,
        access_count    BIGINT NOT NULL DEFAULT 0,
        last_accessed   TIMESTAMPTZ,
        metadata        JSONB NOT NULL DEFAULT 'null',
        search_vector   tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
    )
    "#,
    // Single-column indexes on the scope, tier, category and validity columns.
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_org_id     ON facts (org_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_user_id    ON facts (user_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_agent_id   ON facts (agent_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_session_id ON facts (session_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_tier       ON facts (tier)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_category   ON facts (category)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_valid_from ON facts (valid_from)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_invalid_at ON facts (invalid_at)",
    // GIN index over the generated tsvector column.
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_fts        ON facts USING GIN (search_vector)",
];
56
57// ---------------------------------------------------------------------------
58// PostgresFactStore
59// ---------------------------------------------------------------------------
60
/// `FactStore` implementation that keeps facts in a PostgreSQL `facts` table.
pub struct PostgresFactStore {
    /// sqlx connection pool that the store's queries are executed against.
    pool: PgPool,
}
64
65impl PostgresFactStore {
66    pub fn new(pool: PgPool) -> Self {
67        Self { pool }
68    }
69
70    /// Open a connection pool from a database URL and return a store.
71    pub async fn open(database_url: &str) -> Result<Self, sqlx::Error> {
72        let pool = PgPool::connect(database_url).await?;
73        Ok(Self { pool })
74    }
75
76    /// Apply the DDL. Safe to call multiple times (uses `IF NOT EXISTS`).
77    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
78        for stmt in PG_FACT_STORE_DDL {
79            sqlx::query(stmt).execute(&self.pool).await?;
80        }
81        Ok(())
82    }
83}
84
85// ---------------------------------------------------------------------------
86// Internal row type
87// ---------------------------------------------------------------------------
88
/// Raw `facts` row as decoded by sqlx, before conversion by `row_to_fact`.
///
/// Field names match the column names listed in this file's SELECT
/// statements. The generated `search_vector` column has no field here.
#[derive(sqlx::FromRow)]
struct FactRow {
    id: Uuid,
    text: String,
    // Scope columns; only `org_id` is non-optional.
    org_id: String,
    agent_id: Option<String>,
    user_id: Option<String>,
    session_id: Option<String>,
    // Tier as stored text; decoded by `tier_from_str`.
    tier: String,
    category: Option<String>,
    source: Option<String>,
    confidence: Option<f64>,
    valid_from: DateTime<Utc>,
    invalid_at: Option<DateTime<Utc>>,
    created_at: DateTime<Utc>,
    // JSONB value; `row_to_fact` parses it as an array of UUID strings.
    entity_refs: serde_json::Value,
    supersedes: Option<Uuid>,
    superseded_by: Option<Uuid>,
    access_count: i64,
    last_accessed: Option<DateTime<Utc>>,
    // JSONB value; `row_to_fact` accepts JSON `null` or a JSON object.
    metadata: serde_json::Value,
}
111
112// ---------------------------------------------------------------------------
113// Conversion helpers
114// ---------------------------------------------------------------------------
115
116fn tier_from_str(s: &str) -> MemoryTier {
117    match s {
118        "working" => MemoryTier::Working,
119        "knowledge" => MemoryTier::Knowledge,
120        _ => MemoryTier::Conversation,
121    }
122}
123
124fn tier_to_str(t: &MemoryTier) -> &'static str {
125    match t {
126        MemoryTier::Working => "working",
127        MemoryTier::Conversation => "conversation",
128        MemoryTier::Knowledge => "knowledge",
129    }
130}
131
132fn row_to_fact(row: FactRow) -> Result<Fact, MemoryError> {
133    let entity_refs: Vec<Uuid> = {
134        let strings: Vec<String> = serde_json::from_value(row.entity_refs.clone())
135            .map_err(|e| MemoryError::Serialization(e.to_string()))?;
136        strings
137            .iter()
138            .map(|s| Uuid::parse_str(s).map_err(|e| MemoryError::Serialization(e.to_string())))
139            .collect::<Result<Vec<_>, _>>()?
140    };
141
142    let metadata: serde_json::Map<String, serde_json::Value> = match &row.metadata {
143        serde_json::Value::Null => serde_json::Map::new(),
144        serde_json::Value::Object(map) => map.clone(),
145        other => serde_json::from_value(other.clone())
146            .map_err(|e| MemoryError::Serialization(e.to_string()))?,
147    };
148
149    Ok(Fact {
150        id: row.id,
151        text: row.text,
152        scope: Scope {
153            org_id: row.org_id,
154            agent_id: row.agent_id,
155            user_id: row.user_id,
156            session_id: row.session_id,
157        },
158        tier: tier_from_str(&row.tier),
159        category: row.category,
160        source: row.source,
161        confidence: row.confidence.map(|c| c as f32),
162        valid_from: row.valid_from,
163        invalid_at: row.invalid_at,
164        created_at: row.created_at,
165        embedding: Vec::new(),
166        entity_refs,
167        supersedes: row.supersedes,
168        superseded_by: row.superseded_by,
169        access_count: row.access_count as u64,
170        last_accessed: row.last_accessed,
171        metadata,
172    })
173}
174
175// ---------------------------------------------------------------------------
176// FactStore implementation
177// ---------------------------------------------------------------------------
178
179#[async_trait]
180impl FactStore for PostgresFactStore {
181    async fn insert_fact(&self, fact: Fact) -> Result<FactId, MemoryError> {
182        let entity_refs_json = {
183            let strs: Vec<String> = fact.entity_refs.iter().map(|u| u.to_string()).collect();
184            serde_json::to_value(&strs).map_err(|e| MemoryError::Serialization(e.to_string()))?
185        };
186
187        let metadata_json = if fact.metadata.is_empty() {
188            serde_json::Value::Null
189        } else {
190            serde_json::to_value(&fact.metadata)
191                .map_err(|e| MemoryError::Serialization(e.to_string()))?
192        };
193
194        sqlx::query(
195            r#"
196            INSERT INTO facts
197                (id, text, org_id, agent_id, user_id, session_id,
198                 tier, category, source, confidence,
199                 valid_from, invalid_at, created_at,
200                 entity_refs, supersedes, superseded_by,
201                 access_count, last_accessed, metadata)
202            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
203            ON CONFLICT (id) DO NOTHING
204            "#,
205        )
206        .bind(fact.id)
207        .bind(&fact.text)
208        .bind(&fact.scope.org_id)
209        .bind(fact.scope.agent_id.as_deref())
210        .bind(fact.scope.user_id.as_deref())
211        .bind(fact.scope.session_id.as_deref())
212        .bind(tier_to_str(&fact.tier))
213        .bind(fact.category.as_deref())
214        .bind(fact.source.as_deref())
215        .bind(fact.confidence.map(|c| c as f64))
216        .bind(fact.valid_from)
217        .bind(fact.invalid_at)
218        .bind(fact.created_at)
219        .bind(&entity_refs_json)
220        .bind(fact.supersedes)
221        .bind(fact.superseded_by)
222        .bind(fact.access_count as i64)
223        .bind(fact.last_accessed)
224        .bind(&metadata_json)
225        .execute(&self.pool)
226        .await
227        .map_err(|e| MemoryError::Database(e.to_string()))?;
228
229        Ok(fact.id)
230    }
231
232    async fn get_fact(&self, id: FactId) -> Result<Fact, MemoryError> {
233        let row = sqlx::query_as::<_, FactRow>(
234            "SELECT id, text, org_id, agent_id, user_id, session_id, tier, category, source, confidence, valid_from, invalid_at, created_at, entity_refs, supersedes, superseded_by, access_count, last_accessed, metadata FROM facts WHERE id = $1",
235        )
236        .bind(id)
237        .fetch_optional(&self.pool)
238        .await
239        .map_err(|e| MemoryError::Database(e.to_string()))?
240        .ok_or_else(|| MemoryError::NotFound(id.to_string()))?;
241
242        row_to_fact(row)
243    }
244
245    async fn update_fact(&self, id: FactId, patch: FactPatch) -> Result<Fact, MemoryError> {
246        let mut set_clauses: Vec<String> = Vec::new();
247        let mut vals: Vec<String> = Vec::new();
248        let mut param_idx: usize = 1;
249
250        if let Some(ref text) = patch.text {
251            set_clauses.push(format!("text = ${param_idx}"));
252            vals.push(text.clone());
253            param_idx += 1;
254        }
255        if let Some(ref tier) = patch.tier {
256            set_clauses.push(format!("tier = ${param_idx}"));
257            vals.push(tier_to_str(tier).to_string());
258            param_idx += 1;
259        }
260        if let Some(ref category) = patch.category {
261            set_clauses.push(format!("category = ${param_idx}"));
262            vals.push(category.clone());
263            param_idx += 1;
264        }
265        if let Some(ref source) = patch.source {
266            set_clauses.push(format!("source = ${param_idx}"));
267            vals.push(source.clone());
268            param_idx += 1;
269        }
270        if let Some(confidence) = patch.confidence {
271            set_clauses.push(format!("confidence = ${param_idx}"));
272            vals.push((confidence as f64).to_string());
273            param_idx += 1;
274        }
275        if let Some(invalid_at) = patch.invalid_at {
276            set_clauses.push(format!("invalid_at = ${param_idx}"));
277            vals.push(invalid_at.to_rfc3339());
278            param_idx += 1;
279        }
280        if let Some(superseded_by) = patch.superseded_by {
281            set_clauses.push(format!("superseded_by = ${param_idx}"));
282            vals.push(superseded_by.to_string());
283            param_idx += 1;
284        }
285        if !patch.metadata.is_empty() {
286            let json = serde_json::to_string(&patch.metadata)
287                .map_err(|e| MemoryError::Serialization(e.to_string()))?;
288            set_clauses.push(format!("metadata = ${param_idx}::jsonb"));
289            vals.push(json);
290            param_idx += 1;
291        }
292
293        if !set_clauses.is_empty() {
294            let sql = format!(
295                "UPDATE facts SET {} WHERE id = ${param_idx}",
296                set_clauses.join(", ")
297            );
298            let mut q = sqlx::query(&sql);
299            for v in &vals {
300                q = q.bind(v.as_str());
301            }
302            q = q.bind(id.to_string());
303            q.execute(&self.pool)
304                .await
305                .map_err(|e| MemoryError::Database(e.to_string()))?;
306        }
307
308        self.get_fact(id).await
309    }
310
311    async fn list_facts(&self, filter: &FactFilter) -> Result<Vec<Fact>, MemoryError> {
312        let mut wheres: Vec<String> = vec!["1=1".to_string()];
313
314        if let Some(ref scope) = filter.scope {
315            wheres.push(format!("org_id = '{}'", scope.org_id.replace('\'', "''")));
316            if let Some(ref user_id) = scope.user_id {
317                wheres.push(format!("user_id = '{}'", user_id.replace('\'', "''")));
318            }
319            if let Some(ref agent_id) = scope.agent_id {
320                wheres.push(format!("agent_id = '{}'", agent_id.replace('\'', "''")));
321            }
322            if let Some(ref session_id) = scope.session_id {
323                wheres.push(format!("session_id = '{}'", session_id.replace('\'', "''")));
324            }
325        }
326
327        if let Some(ref tier) = filter.tier {
328            wheres.push(format!("tier = '{}'", tier_to_str(tier)));
329        }
330
331        if let Some(ref category) = filter.category {
332            wheres.push(format!("category = '{}'", category.replace('\'', "''")));
333        }
334
335        if let Some(as_of) = filter.as_of {
336            let s = as_of.to_rfc3339();
337            wheres.push(format!("valid_from <= '{s}'"));
338            wheres.push(format!("(invalid_at IS NULL OR invalid_at > '{s}')"));
339        } else if filter.valid_only {
340            wheres.push("invalid_at IS NULL".to_string());
341        }
342
343        if let Some(ref text_contains) = filter.text_contains {
344            let escaped = text_contains.replace('\'', "''");
345            wheres.push(format!("text LIKE '%{escaped}%'"));
346        }
347
348        let where_clause = wheres.join(" AND ");
349        let sql = format!(
350            "SELECT id, text, org_id, agent_id, user_id, session_id, tier, category, source, confidence, valid_from, invalid_at, created_at, entity_refs, supersedes, superseded_by, access_count, last_accessed, metadata FROM facts WHERE {where_clause} ORDER BY created_at DESC LIMIT {} OFFSET {}",
351            filter.limit, filter.offset
352        );
353
354        let rows = sqlx::query_as::<_, FactRow>(&sql)
355            .fetch_all(&self.pool)
356            .await
357            .map_err(|e| MemoryError::Database(e.to_string()))?;
358
359        rows.into_iter().map(row_to_fact).collect()
360    }
361
362    async fn invalidate_fact(&self, id: FactId) -> Result<(), MemoryError> {
363        let now = Utc::now();
364        sqlx::query("UPDATE facts SET invalid_at = $1 WHERE id = $2")
365            .bind(now)
366            .bind(id)
367            .execute(&self.pool)
368            .await
369            .map_err(|e| MemoryError::Database(e.to_string()))?;
370        Ok(())
371    }
372
373    async fn delete_scope_data(&self, scope: &Scope) -> Result<u64, MemoryError> {
374        let mut wheres = vec![format!("org_id = '{}'", scope.org_id.replace('\'', "''"))];
375
376        if let Some(ref user_id) = scope.user_id {
377            wheres.push(format!("user_id = '{}'", user_id.replace('\'', "''")));
378        }
379        if let Some(ref agent_id) = scope.agent_id {
380            wheres.push(format!("agent_id = '{}'", agent_id.replace('\'', "''")));
381        }
382        if let Some(ref session_id) = scope.session_id {
383            wheres.push(format!("session_id = '{}'", session_id.replace('\'', "''")));
384        }
385
386        let where_clause = wheres.join(" AND ");
387        let sql = format!("DELETE FROM facts WHERE {where_clause}");
388
389        let result = sqlx::query(&sql)
390            .execute(&self.pool)
391            .await
392            .map_err(|e| MemoryError::Database(e.to_string()))?;
393
394        Ok(result.rows_affected())
395    }
396
    /// Export facts matching `filter`.
    ///
    /// Delegates directly to `list_facts`, so the result has the same
    /// ordering, paging and errors as that method.
    async fn export(&self, filter: &FactFilter) -> Result<Vec<Fact>, MemoryError> {
        self.list_facts(filter).await
    }
400
401    async fn import(&self, facts: Vec<Fact>) -> Result<u64, MemoryError> {
402        let mut imported: u64 = 0;
403        for fact in facts {
404            let entity_refs_json = {
405                let strs: Vec<String> = fact.entity_refs.iter().map(|u| u.to_string()).collect();
406                serde_json::to_value(&strs)
407                    .map_err(|e| MemoryError::Serialization(e.to_string()))?
408            };
409
410            let metadata_json = if fact.metadata.is_empty() {
411                serde_json::Value::Null
412            } else {
413                serde_json::to_value(&fact.metadata)
414                    .map_err(|e| MemoryError::Serialization(e.to_string()))?
415            };
416
417            let result = sqlx::query(
418                r#"
419                INSERT INTO facts
420                    (id, text, org_id, agent_id, user_id, session_id,
421                     tier, category, source, confidence,
422                     valid_from, invalid_at, created_at,
423                     entity_refs, supersedes, superseded_by,
424                     access_count, last_accessed, metadata)
425                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
426                ON CONFLICT (id) DO NOTHING
427                "#,
428            )
429            .bind(fact.id)
430            .bind(&fact.text)
431            .bind(&fact.scope.org_id)
432            .bind(fact.scope.agent_id.as_deref())
433            .bind(fact.scope.user_id.as_deref())
434            .bind(fact.scope.session_id.as_deref())
435            .bind(tier_to_str(&fact.tier))
436            .bind(fact.category.as_deref())
437            .bind(fact.source.as_deref())
438            .bind(fact.confidence.map(|c| c as f64))
439            .bind(fact.valid_from)
440            .bind(fact.invalid_at)
441            .bind(fact.created_at)
442            .bind(&entity_refs_json)
443            .bind(fact.supersedes)
444            .bind(fact.superseded_by)
445            .bind(fact.access_count as i64)
446            .bind(fact.last_accessed)
447            .bind(&metadata_json)
448            .execute(&self.pool)
449            .await
450            .map_err(|e| MemoryError::Database(e.to_string()))?;
451
452            imported += result.rows_affected();
453        }
454        Ok(imported)
455    }
456
457    async fn stats(&self) -> Result<StoreStats, MemoryError> {
458        let (total,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM facts")
459            .fetch_one(&self.pool)
460            .await
461            .map_err(|e| MemoryError::Database(e.to_string()))?;
462
463        let (valid,): (i64,) =
464            sqlx::query_as("SELECT COUNT(*) FROM facts WHERE invalid_at IS NULL")
465                .fetch_one(&self.pool)
466                .await
467                .map_err(|e| MemoryError::Database(e.to_string()))?;
468
469        let (invalidated,): (i64,) =
470            sqlx::query_as("SELECT COUNT(*) FROM facts WHERE invalid_at IS NOT NULL")
471                .fetch_one(&self.pool)
472                .await
473                .map_err(|e| MemoryError::Database(e.to_string()))?;
474
475        Ok(StoreStats {
476            total_facts: total as u64,
477            valid_facts: valid as u64,
478            invalidated_facts: invalidated as u64,
479            total_entities: 0,
480            total_relationships: 0,
481        })
482    }
483
484    async fn record_access(&self, id: FactId) -> Result<(), MemoryError> {
485        let now = Utc::now();
486        sqlx::query(
487            "UPDATE facts SET access_count = access_count + 1, last_accessed = $1 WHERE id = $2",
488        )
489        .bind(now)
490        .bind(id)
491        .execute(&self.pool)
492        .await
493        .map_err(|e| MemoryError::Database(e.to_string()))?;
494        Ok(())
495    }
496
497    async fn keyword_search(
498        &self,
499        query: &str,
500        scope: &Scope,
501        top_k: usize,
502    ) -> Result<Vec<Fact>, MemoryError> {
503        let trimmed = query.trim();
504        if trimmed.is_empty() {
505            return Ok(Vec::new());
506        }
507
508        let sql = r#"
509            SELECT id, text, org_id, agent_id, user_id, session_id, tier, category,
510                   source, confidence, valid_from, invalid_at, created_at, entity_refs,
511                   supersedes, superseded_by, access_count, last_accessed, metadata
512            FROM facts
513            WHERE search_vector @@ plainto_tsquery('english', $1)
514              AND org_id = $2
515              AND ($3::text IS NULL OR user_id = $3)
516              AND invalid_at IS NULL
517            ORDER BY ts_rank(search_vector, plainto_tsquery('english', $1)) DESC
518            LIMIT $4
519        "#;
520
521        let rows = sqlx::query_as::<_, FactRow>(sql)
522            .bind(trimmed)
523            .bind(&scope.org_id)
524            .bind(scope.user_id.as_deref())
525            .bind(top_k as i64)
526            .fetch_all(&self.pool)
527            .await
528            .map_err(|e| MemoryError::Database(e.to_string()))?;
529
530        rows.into_iter().map(row_to_fact).collect()
531    }
532}