Skip to main content

engram/
store_sqlite.rs

1//! SQLite-backed `FactStore` implementation.
2//!
3//! All `DateTime<Utc>` values are stored as RFC 3339 strings.
4//! UUIDs are stored as TEXT. `entity_refs` is a JSON array of UUID strings.
5//! `metadata` is a JSON object string (or `"null"` when empty).
6
7use crate::fact::{Fact, FactFilter, FactId, FactPatch, MemoryTier};
8use crate::scope::Scope;
9use crate::store::{FactStore, MemoryError, StoreStats};
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use sqlx::SqlitePool;
13use uuid::Uuid;
14
15// ---------------------------------------------------------------------------
16// DDL
17// ---------------------------------------------------------------------------
18
/// Schema for the `facts` table and its secondary indexes.
///
/// `SqliteFactStore::migrate` runs this text by splitting it on `;` and
/// executing each non-empty piece, so no statement in here may contain an
/// embedded semicolon. Every statement uses `IF NOT EXISTS`, which is what
/// makes repeated migration safe.
pub const FACT_STORE_DDL: &str = r#"
CREATE TABLE IF NOT EXISTS facts (
    id              TEXT PRIMARY KEY,
    text            TEXT NOT NULL,
    org_id          TEXT NOT NULL DEFAULT 'default',
    agent_id        TEXT,
    user_id         TEXT,
    session_id      TEXT,
    tier            TEXT NOT NULL DEFAULT 'conversation',
    category        TEXT,
    source          TEXT,
    confidence      REAL,
    valid_from      TEXT NOT NULL,
    invalid_at      TEXT,
    created_at      TEXT NOT NULL,
    entity_refs     TEXT NOT NULL DEFAULT '[]',
    supersedes      TEXT,
    superseded_by   TEXT,
    access_count    INTEGER NOT NULL DEFAULT 0,
    last_accessed   TEXT,
    metadata        TEXT NOT NULL DEFAULT 'null'
);
CREATE INDEX IF NOT EXISTS idx_facts_org_id     ON facts (org_id);
CREATE INDEX IF NOT EXISTS idx_facts_user_id    ON facts (user_id);
CREATE INDEX IF NOT EXISTS idx_facts_agent_id   ON facts (agent_id);
CREATE INDEX IF NOT EXISTS idx_facts_session_id ON facts (session_id);
CREATE INDEX IF NOT EXISTS idx_facts_tier       ON facts (tier);
CREATE INDEX IF NOT EXISTS idx_facts_category   ON facts (category);
CREATE INDEX IF NOT EXISTS idx_facts_valid_from ON facts (valid_from);
CREATE INDEX IF NOT EXISTS idx_facts_invalid_at ON facts (invalid_at);
"#;
50
/// FTS5 virtual table and sync triggers for keyword search.
///
/// Kept separate from `FACT_STORE_DDL` because trigger bodies contain `;`,
/// which breaks the `split(';')` DDL execution approach; each entry here is
/// executed as one whole statement. Uses a standalone FTS5 table (no
/// `content=` directive) to avoid column-name mismatches.
///
/// The triggers mirror `facts.text` into `facts_fts`:
/// - `facts_ai` indexes a row when it is inserted,
/// - `facts_ad` drops it from the index when it is deleted,
/// - `facts_au` re-indexes it when its `text` column is updated, so keyword
///   search does not keep matching the old text after an update.
const FTS5_DDL: &[&str] = &[
    "CREATE VIRTUAL TABLE IF NOT EXISTS facts_fts USING fts5(fact_id UNINDEXED, text)",
    "CREATE TRIGGER IF NOT EXISTS facts_ai AFTER INSERT ON facts BEGIN INSERT INTO facts_fts(fact_id, text) VALUES (new.id, new.text); END",
    "CREATE TRIGGER IF NOT EXISTS facts_ad AFTER DELETE ON facts BEGIN DELETE FROM facts_fts WHERE fact_id = old.id; END",
    // Delete-then-insert (rather than UPDATE) also repairs a fact whose index
    // row is missing for any reason.
    "CREATE TRIGGER IF NOT EXISTS facts_au AFTER UPDATE OF text ON facts BEGIN DELETE FROM facts_fts WHERE fact_id = old.id; INSERT INTO facts_fts(fact_id, text) VALUES (new.id, new.text); END",
];
60
61// ---------------------------------------------------------------------------
62// SqliteFactStore
63// ---------------------------------------------------------------------------
64
65pub struct SqliteFactStore {
66    pool: SqlitePool,
67}
68
69impl SqliteFactStore {
70    pub fn new(pool: SqlitePool) -> Self {
71        Self { pool }
72    }
73
74    /// Open a connection pool from a database URL and return a store.
75    pub async fn open(database_url: &str) -> Result<Self, sqlx::Error> {
76        let pool = SqlitePool::connect(database_url).await?;
77        Ok(Self { pool })
78    }
79
80    /// Apply the DDL. Safe to call multiple times (uses `IF NOT EXISTS`).
81    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
82        for stmt in FACT_STORE_DDL.split(';') {
83            let stmt = stmt.trim();
84            if stmt.is_empty() {
85                continue;
86            }
87            sqlx::query(stmt).execute(&self.pool).await?;
88        }
89        // FTS5 DDL is handled separately because triggers contain `;` inside
90        for stmt in FTS5_DDL {
91            sqlx::query(stmt).execute(&self.pool).await?;
92        }
93        Ok(())
94    }
95}
96
97// ---------------------------------------------------------------------------
98// Internal row type
99// ---------------------------------------------------------------------------
100
/// Raw row of the `facts` table as decoded by sqlx.
///
/// Values are held exactly as stored (strings for UUIDs, timestamps and JSON);
/// `row_to_fact` parses them into a `Fact`.
#[derive(sqlx::FromRow)]
struct FactRow {
    // Fact UUID in text form.
    id: String,
    text: String,
    // Scope columns; `org_id` is the only mandatory one.
    org_id: String,
    agent_id: Option<String>,
    user_id: Option<String>,
    session_id: Option<String>,
    // Tier name; decoded by `tier_from_str`.
    tier: String,
    category: Option<String>,
    source: Option<String>,
    // Stored as REAL; narrowed to `f32` when converted.
    confidence: Option<f64>,
    // RFC 3339 timestamp strings.
    valid_from: String,
    invalid_at: Option<String>,
    created_at: String,
    // JSON array of UUID strings.
    entity_refs: String,
    // UUIDs in text form, when present.
    supersedes: Option<String>,
    superseded_by: Option<String>,
    access_count: i64,
    // RFC 3339 timestamp string, when present.
    last_accessed: Option<String>,
    // JSON object text, or the literal `null` for "no metadata".
    metadata: String,
}
123
124// ---------------------------------------------------------------------------
125// Conversion helpers
126// ---------------------------------------------------------------------------
127
128fn parse_dt(s: &str) -> Result<DateTime<Utc>, MemoryError> {
129    DateTime::parse_from_rfc3339(s)
130        .map(|dt| dt.with_timezone(&Utc))
131        .map_err(|e| MemoryError::Serialization(e.to_string()))
132}
133
134fn parse_opt_dt(s: &Option<String>) -> Result<Option<DateTime<Utc>>, MemoryError> {
135    match s {
136        None => Ok(None),
137        Some(s) => parse_dt(s).map(Some),
138    }
139}
140
141fn tier_from_str(s: &str) -> MemoryTier {
142    match s {
143        "working" => MemoryTier::Working,
144        "knowledge" => MemoryTier::Knowledge,
145        _ => MemoryTier::Conversation,
146    }
147}
148
149fn tier_to_str(t: &MemoryTier) -> &'static str {
150    match t {
151        MemoryTier::Working => "working",
152        MemoryTier::Conversation => "conversation",
153        MemoryTier::Knowledge => "knowledge",
154    }
155}
156
157fn row_to_fact(row: FactRow) -> Result<Fact, MemoryError> {
158    let id = Uuid::parse_str(&row.id).map_err(|e| MemoryError::Serialization(e.to_string()))?;
159
160    let entity_refs: Vec<Uuid> = {
161        let strings: Vec<String> = serde_json::from_str(&row.entity_refs)
162            .map_err(|e| MemoryError::Serialization(e.to_string()))?;
163        strings
164            .iter()
165            .map(|s| Uuid::parse_str(s).map_err(|e| MemoryError::Serialization(e.to_string())))
166            .collect::<Result<Vec<_>, _>>()?
167    };
168
169    let metadata: serde_json::Map<String, serde_json::Value> =
170        if row.metadata == "null" || row.metadata.is_empty() {
171            serde_json::Map::new()
172        } else {
173            serde_json::from_str(&row.metadata)
174                .map_err(|e| MemoryError::Serialization(e.to_string()))?
175        };
176
177    let supersedes = row
178        .supersedes
179        .as_deref()
180        .map(|s| Uuid::parse_str(s).map_err(|e| MemoryError::Serialization(e.to_string())))
181        .transpose()?;
182
183    let superseded_by = row
184        .superseded_by
185        .as_deref()
186        .map(|s| Uuid::parse_str(s).map_err(|e| MemoryError::Serialization(e.to_string())))
187        .transpose()?;
188
189    Ok(Fact {
190        id,
191        text: row.text,
192        scope: Scope {
193            org_id: row.org_id,
194            agent_id: row.agent_id,
195            user_id: row.user_id,
196            session_id: row.session_id,
197        },
198        tier: tier_from_str(&row.tier),
199        category: row.category,
200        source: row.source,
201        confidence: row.confidence.map(|c| c as f32),
202        valid_from: parse_dt(&row.valid_from)?,
203        invalid_at: parse_opt_dt(&row.invalid_at)?,
204        created_at: parse_dt(&row.created_at)?,
205        // embeddings are not persisted in the facts table
206        embedding: Vec::new(),
207        entity_refs,
208        supersedes,
209        superseded_by,
210        access_count: row.access_count as u64,
211        last_accessed: parse_opt_dt(&row.last_accessed)?,
212        metadata,
213    })
214}
215
216// ---------------------------------------------------------------------------
217// FactStore implementation
218// ---------------------------------------------------------------------------
219
220#[async_trait]
221impl FactStore for SqliteFactStore {
222    async fn insert_fact(&self, fact: Fact) -> Result<FactId, MemoryError> {
223        let entity_refs_json = {
224            let strs: Vec<String> = fact.entity_refs.iter().map(|u| u.to_string()).collect();
225            serde_json::to_string(&strs).map_err(|e| MemoryError::Serialization(e.to_string()))?
226        };
227
228        let metadata_json = if fact.metadata.is_empty() {
229            "null".to_string()
230        } else {
231            serde_json::to_string(&fact.metadata)
232                .map_err(|e| MemoryError::Serialization(e.to_string()))?
233        };
234
235        sqlx::query(
236            r#"
237            INSERT OR IGNORE INTO facts
238                (id, text, org_id, agent_id, user_id, session_id,
239                 tier, category, source, confidence,
240                 valid_from, invalid_at, created_at,
241                 entity_refs, supersedes, superseded_by,
242                 access_count, last_accessed, metadata)
243            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
244            "#,
245        )
246        .bind(fact.id.to_string())
247        .bind(&fact.text)
248        .bind(&fact.scope.org_id)
249        .bind(fact.scope.agent_id.as_deref())
250        .bind(fact.scope.user_id.as_deref())
251        .bind(fact.scope.session_id.as_deref())
252        .bind(tier_to_str(&fact.tier))
253        .bind(fact.category.as_deref())
254        .bind(fact.source.as_deref())
255        .bind(fact.confidence.map(|c| c as f64))
256        .bind(fact.valid_from.to_rfc3339())
257        .bind(fact.invalid_at.map(|dt| dt.to_rfc3339()))
258        .bind(fact.created_at.to_rfc3339())
259        .bind(entity_refs_json)
260        .bind(fact.supersedes.map(|u| u.to_string()))
261        .bind(fact.superseded_by.map(|u| u.to_string()))
262        .bind(fact.access_count as i64)
263        .bind(fact.last_accessed.map(|dt| dt.to_rfc3339()))
264        .bind(metadata_json)
265        .execute(&self.pool)
266        .await
267        .map_err(|e| MemoryError::Database(e.to_string()))?;
268
269        Ok(fact.id)
270    }
271
272    async fn get_fact(&self, id: FactId) -> Result<Fact, MemoryError> {
273        let row = sqlx::query_as::<_, FactRow>("SELECT * FROM facts WHERE id = ?")
274            .bind(id.to_string())
275            .fetch_optional(&self.pool)
276            .await
277            .map_err(|e| MemoryError::Database(e.to_string()))?
278            .ok_or_else(|| MemoryError::NotFound(id.to_string()))?;
279
280        row_to_fact(row)
281    }
282
283    async fn update_fact(&self, id: FactId, patch: FactPatch) -> Result<Fact, MemoryError> {
284        // Build a (column = ?, value_string) list for each non-None patch field.
285        // We bind all values as strings; SQLite will coerce them.
286        let mut cols: Vec<&'static str> = Vec::new();
287        let mut vals: Vec<String> = Vec::new();
288
289        if let Some(ref text) = patch.text {
290            cols.push("text = ?");
291            vals.push(text.clone());
292        }
293        if let Some(ref tier) = patch.tier {
294            cols.push("tier = ?");
295            vals.push(tier_to_str(tier).to_string());
296        }
297        if let Some(ref category) = patch.category {
298            cols.push("category = ?");
299            vals.push(category.clone());
300        }
301        if let Some(ref source) = patch.source {
302            cols.push("source = ?");
303            vals.push(source.clone());
304        }
305        if let Some(confidence) = patch.confidence {
306            cols.push("confidence = ?");
307            vals.push((confidence as f64).to_string());
308        }
309        if let Some(invalid_at) = patch.invalid_at {
310            cols.push("invalid_at = ?");
311            vals.push(invalid_at.to_rfc3339());
312        }
313        if let Some(superseded_by) = patch.superseded_by {
314            cols.push("superseded_by = ?");
315            vals.push(superseded_by.to_string());
316        }
317        if !patch.metadata.is_empty() {
318            let json = serde_json::to_string(&patch.metadata)
319                .map_err(|e| MemoryError::Serialization(e.to_string()))?;
320            cols.push("metadata = ?");
321            vals.push(json);
322        }
323
324        if !cols.is_empty() {
325            let sql = format!("UPDATE facts SET {} WHERE id = ?", cols.join(", "));
326            let mut q = sqlx::query(&sql);
327            for v in &vals {
328                q = q.bind(v.as_str());
329            }
330            q = q.bind(id.to_string());
331            q.execute(&self.pool)
332                .await
333                .map_err(|e| MemoryError::Database(e.to_string()))?;
334        }
335
336        self.get_fact(id).await
337    }
338
339    async fn list_facts(&self, filter: &FactFilter) -> Result<Vec<Fact>, MemoryError> {
340        let mut wheres: Vec<String> = vec!["1=1".to_string()];
341
342        if let Some(ref scope) = filter.scope {
343            wheres.push(format!("org_id = '{}'", scope.org_id.replace('\'', "''")));
344            if let Some(ref user_id) = scope.user_id {
345                wheres.push(format!("user_id = '{}'", user_id.replace('\'', "''")));
346            }
347            if let Some(ref agent_id) = scope.agent_id {
348                wheres.push(format!("agent_id = '{}'", agent_id.replace('\'', "''")));
349            }
350            if let Some(ref session_id) = scope.session_id {
351                wheres.push(format!("session_id = '{}'", session_id.replace('\'', "''")));
352            }
353        }
354
355        if let Some(ref tier) = filter.tier {
356            wheres.push(format!("tier = '{}'", tier_to_str(tier)));
357        }
358
359        if let Some(ref category) = filter.category {
360            wheres.push(format!("category = '{}'", category.replace('\'', "''")));
361        }
362
363        if let Some(as_of) = filter.as_of {
364            let s = as_of.to_rfc3339();
365            wheres.push(format!("valid_from <= '{s}'"));
366            wheres.push(format!("(invalid_at IS NULL OR invalid_at > '{s}')"));
367        } else if filter.valid_only {
368            wheres.push("invalid_at IS NULL".to_string());
369        }
370
371        if let Some(ref text_contains) = filter.text_contains {
372            let escaped = text_contains.replace('\'', "''");
373            wheres.push(format!("text LIKE '%{escaped}%'"));
374        }
375
376        let where_clause = wheres.join(" AND ");
377        let sql = format!(
378            "SELECT * FROM facts WHERE {where_clause} ORDER BY created_at DESC LIMIT {} OFFSET {}",
379            filter.limit, filter.offset
380        );
381
382        let rows = sqlx::query_as::<_, FactRow>(&sql)
383            .fetch_all(&self.pool)
384            .await
385            .map_err(|e| MemoryError::Database(e.to_string()))?;
386
387        rows.into_iter().map(row_to_fact).collect()
388    }
389
390    async fn invalidate_fact(&self, id: FactId) -> Result<(), MemoryError> {
391        let now = Utc::now().to_rfc3339();
392        sqlx::query("UPDATE facts SET invalid_at = ? WHERE id = ?")
393            .bind(&now)
394            .bind(id.to_string())
395            .execute(&self.pool)
396            .await
397            .map_err(|e| MemoryError::Database(e.to_string()))?;
398        Ok(())
399    }
400
401    async fn delete_scope_data(&self, scope: &Scope) -> Result<u64, MemoryError> {
402        let mut wheres = vec![format!("org_id = '{}'", scope.org_id.replace('\'', "''"))];
403
404        if let Some(ref user_id) = scope.user_id {
405            wheres.push(format!("user_id = '{}'", user_id.replace('\'', "''")));
406        }
407        if let Some(ref agent_id) = scope.agent_id {
408            wheres.push(format!("agent_id = '{}'", agent_id.replace('\'', "''")));
409        }
410        if let Some(ref session_id) = scope.session_id {
411            wheres.push(format!("session_id = '{}'", session_id.replace('\'', "''")));
412        }
413
414        let where_clause = wheres.join(" AND ");
415        let sql = format!("DELETE FROM facts WHERE {where_clause}");
416
417        let result = sqlx::query(&sql)
418            .execute(&self.pool)
419            .await
420            .map_err(|e| MemoryError::Database(e.to_string()))?;
421
422        Ok(result.rows_affected())
423    }
424
425    async fn export(&self, filter: &FactFilter) -> Result<Vec<Fact>, MemoryError> {
426        self.list_facts(filter).await
427    }
428
429    async fn import(&self, facts: Vec<Fact>) -> Result<u64, MemoryError> {
430        let mut imported: u64 = 0;
431        for fact in facts {
432            let result = sqlx::query(
433                r#"
434                INSERT OR IGNORE INTO facts
435                    (id, text, org_id, agent_id, user_id, session_id,
436                     tier, category, source, confidence,
437                     valid_from, invalid_at, created_at,
438                     entity_refs, supersedes, superseded_by,
439                     access_count, last_accessed, metadata)
440                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
441                "#,
442            )
443            .bind(fact.id.to_string())
444            .bind(&fact.text)
445            .bind(&fact.scope.org_id)
446            .bind(fact.scope.agent_id.as_deref())
447            .bind(fact.scope.user_id.as_deref())
448            .bind(fact.scope.session_id.as_deref())
449            .bind(tier_to_str(&fact.tier))
450            .bind(fact.category.as_deref())
451            .bind(fact.source.as_deref())
452            .bind(fact.confidence.map(|c| c as f64))
453            .bind(fact.valid_from.to_rfc3339())
454            .bind(fact.invalid_at.map(|dt| dt.to_rfc3339()))
455            .bind(fact.created_at.to_rfc3339())
456            .bind({
457                let strs: Vec<String> = fact.entity_refs.iter().map(|u| u.to_string()).collect();
458                serde_json::to_string(&strs)
459                    .map_err(|e| MemoryError::Serialization(e.to_string()))?
460            })
461            .bind(fact.supersedes.map(|u| u.to_string()))
462            .bind(fact.superseded_by.map(|u| u.to_string()))
463            .bind(fact.access_count as i64)
464            .bind(fact.last_accessed.map(|dt| dt.to_rfc3339()))
465            .bind(if fact.metadata.is_empty() {
466                "null".to_string()
467            } else {
468                serde_json::to_string(&fact.metadata)
469                    .map_err(|e| MemoryError::Serialization(e.to_string()))?
470            })
471            .execute(&self.pool)
472            .await
473            .map_err(|e| MemoryError::Database(e.to_string()))?;
474
475            imported += result.rows_affected();
476        }
477        Ok(imported)
478    }
479
480    async fn stats(&self) -> Result<StoreStats, MemoryError> {
481        let (total,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM facts")
482            .fetch_one(&self.pool)
483            .await
484            .map_err(|e| MemoryError::Database(e.to_string()))?;
485
486        let (valid,): (i64,) =
487            sqlx::query_as("SELECT COUNT(*) FROM facts WHERE invalid_at IS NULL")
488                .fetch_one(&self.pool)
489                .await
490                .map_err(|e| MemoryError::Database(e.to_string()))?;
491
492        let (invalidated,): (i64,) =
493            sqlx::query_as("SELECT COUNT(*) FROM facts WHERE invalid_at IS NOT NULL")
494                .fetch_one(&self.pool)
495                .await
496                .map_err(|e| MemoryError::Database(e.to_string()))?;
497
498        Ok(StoreStats {
499            total_facts: total as u64,
500            valid_facts: valid as u64,
501            invalidated_facts: invalidated as u64,
502            total_entities: 0,
503            total_relationships: 0,
504        })
505    }
506
507    async fn record_access(&self, id: FactId) -> Result<(), MemoryError> {
508        let now = Utc::now().to_rfc3339();
509        sqlx::query(
510            "UPDATE facts SET access_count = access_count + 1, last_accessed = ? WHERE id = ?",
511        )
512        .bind(&now)
513        .bind(id.to_string())
514        .execute(&self.pool)
515        .await
516        .map_err(|e| MemoryError::Database(e.to_string()))?;
517        Ok(())
518    }
519
520    async fn keyword_search(
521        &self,
522        query: &str,
523        scope: &Scope,
524        top_k: usize,
525    ) -> Result<Vec<Fact>, MemoryError> {
526        let sql = r#"
527            SELECT f.*
528            FROM facts_fts fts
529            INNER JOIN facts f ON f.id = fts.fact_id
530            WHERE facts_fts MATCH ?
531              AND f.org_id = ?
532              AND (? IS NULL OR f.user_id = ?)
533              AND f.invalid_at IS NULL
534            ORDER BY fts.rank
535            LIMIT ?
536        "#;
537
538        let normalized = normalize_fts_query(query);
539        if normalized.is_empty() {
540            return Ok(Vec::new());
541        }
542
543        let rows = sqlx::query_as::<_, FactRow>(sql)
544            .bind(&normalized)
545            .bind(&scope.org_id)
546            .bind(scope.user_id.as_deref())
547            .bind(scope.user_id.as_deref())
548            .bind(top_k as i64)
549            .fetch_all(&self.pool)
550            .await
551            .map_err(|e| MemoryError::Database(e.to_string()))?;
552
553        rows.into_iter().map(row_to_fact).collect()
554    }
555}
556
/// Normalize a raw keyword query into an FTS5 MATCH expression.
///
/// FTS5 requires `peanut*` for prefix matching — plain `peanut` only matches
/// the exact token `peanut` and would miss `peanuts`. This helper lowercases
/// each whitespace-separated token, strips everything except alphanumerics
/// and `_`, and appends `*`, so that reasonable user input like `"peanut"` or
/// `"food allergies"` matches what they expect.
///
/// Lowercasing matters beyond cosmetics: upper-case `AND`, `OR`, `NOT` and
/// `NEAR` are FTS5 operators, so a plain-language query such as
/// `salt AND pepper` would otherwise turn into an invalid expression.
///
/// Queries that already contain FTS5 syntax (a `"`, `:` or `(`) are passed
/// through unchanged apart from trimming. A query that is blank, or that
/// consists only of punctuation, yields an empty string.
fn normalize_fts_query(query: &str) -> String {
    let trimmed = query.trim();
    if trimmed.is_empty() {
        return String::new();
    }
    // Pass through advanced FTS5 syntax so power users aren't blocked.
    if trimmed.contains(&['"', ':', '('][..]) {
        return trimmed.to_string();
    }

    trimmed
        .split_whitespace()
        .filter_map(|token| {
            // Lowercase first, then filter, so that anything a case mapping
            // introduces (e.g. combining marks) is stripped as well.
            let cleaned: String = token
                .to_lowercase()
                .chars()
                .filter(|c| c.is_alphanumeric() || *c == '_')
                .collect();
            if cleaned.is_empty() {
                None
            } else {
                Some(format!("{}*", cleaned))
            }
        })
        .collect::<Vec<_>>()
        .join(" ")
}
591
#[cfg(test)]
mod tests {
    use super::normalize_fts_query;

    /// A bare word becomes a prefix query.
    #[test]
    fn single_token_gets_prefix_star() {
        let normalized = normalize_fts_query("peanut");
        assert_eq!(normalized, "peanut*");
    }

    /// Every whitespace-separated word gets its own `*`.
    #[test]
    fn multi_token_each_gets_prefix_star() {
        let normalized = normalize_fts_query("food allergies");
        assert_eq!(normalized, "food* allergies*");
    }

    /// Apostrophes, question marks and the like are dropped from tokens.
    #[test]
    fn punctuation_stripped() {
        let normalized = normalize_fts_query("what's up?");
        assert_eq!(normalized, "whats* up*");
    }

    /// Whitespace-only input produces no MATCH expression at all.
    #[test]
    fn empty_query_returns_empty() {
        let normalized = normalize_fts_query("   ");
        assert_eq!(normalized, "");
    }

    /// Input containing FTS5 syntax is handed through untouched.
    #[test]
    fn quoted_phrase_passes_through() {
        let phrase = r#""exact phrase""#;
        assert_eq!(normalize_fts_query(phrase), phrase);
    }
}
621}