Skip to main content

engram/
graph_sqlite.rs

1//! SQLite-backed `GraphStore` implementation.
2//!
3//! All `DateTime<Utc>` values are stored as RFC 3339 strings.
//! UUIDs are stored as TEXT. `attributes` is a JSON object string, or the
//! literal `null` when the entity has no attributes.
5
6use crate::fact::{Entity, EntityId, Relationship, RelationshipId, SubGraph};
7use crate::graph::GraphStore;
8use crate::scope::Scope;
9use crate::store::MemoryError;
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use sqlx::SqlitePool;
13use std::collections::{HashMap, HashSet, VecDeque};
14use uuid::Uuid;
15
16// ---------------------------------------------------------------------------
17// DDL
18// ---------------------------------------------------------------------------
19
/// SQL schema for the graph store: the `entities` and `relationships` tables
/// together with their lookup indexes.
///
/// Every statement uses `IF NOT EXISTS`. Statements are separated by `;`, and
/// `SqliteGraphStore::migrate` executes them one at a time by splitting the
/// text on that character, so no statement here may contain a literal `;`.
///
/// Ids and timestamps are `TEXT` columns. `attributes` defaults to the string
/// `'null'`, which `row_to_entity` reads back as an empty attribute map.
pub const GRAPH_STORE_DDL: &str = r#"
CREATE TABLE IF NOT EXISTS entities (
    id          TEXT PRIMARY KEY,
    name        TEXT NOT NULL,
    entity_type TEXT,
    org_id      TEXT NOT NULL DEFAULT 'default',
    agent_id    TEXT,
    user_id     TEXT,
    session_id  TEXT,
    attributes  TEXT NOT NULL DEFAULT 'null',
    created_at  TEXT NOT NULL,
    updated_at  TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_entities_name    ON entities (name);
CREATE INDEX IF NOT EXISTS idx_entities_user_id ON entities (user_id);
CREATE INDEX IF NOT EXISTS idx_entities_org_id  ON entities (org_id);

CREATE TABLE IF NOT EXISTS relationships (
    id          TEXT PRIMARY KEY,
    source_id   TEXT NOT NULL,
    relation    TEXT NOT NULL,
    target_id   TEXT NOT NULL,
    org_id      TEXT NOT NULL DEFAULT 'default',
    agent_id    TEXT,
    user_id     TEXT,
    session_id  TEXT,
    valid_from  TEXT NOT NULL,
    invalid_at  TEXT,
    created_at  TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_rel_source_id  ON relationships (source_id);
CREATE INDEX IF NOT EXISTS idx_rel_target_id  ON relationships (target_id);
CREATE INDEX IF NOT EXISTS idx_rel_relation   ON relationships (relation);
CREATE INDEX IF NOT EXISTS idx_rel_invalid_at ON relationships (invalid_at);
"#;
55
56// ---------------------------------------------------------------------------
57// SqliteGraphStore
58// ---------------------------------------------------------------------------
59
60pub struct SqliteGraphStore {
61    pool: SqlitePool,
62}
63
64impl SqliteGraphStore {
65    pub fn new(pool: SqlitePool) -> Self {
66        Self { pool }
67    }
68
69    /// Open a connection pool from a database URL and return a store.
70    pub async fn open(database_url: &str) -> Result<Self, sqlx::Error> {
71        let pool = SqlitePool::connect(database_url).await?;
72        Ok(Self { pool })
73    }
74
75    /// Apply the DDL. Safe to call multiple times (uses `IF NOT EXISTS`).
76    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
77        for stmt in GRAPH_STORE_DDL.split(';') {
78            let stmt = stmt.trim();
79            if stmt.is_empty() {
80                continue;
81            }
82            sqlx::query(stmt).execute(&self.pool).await?;
83        }
84        Ok(())
85    }
86}
87
88// ---------------------------------------------------------------------------
89// Internal row types
90// ---------------------------------------------------------------------------
91
/// Raw `entities` row as read from SQLite.
///
/// Ids and timestamps are still plain strings and `attributes` is still
/// serialized JSON; `row_to_entity` converts this into an `Entity`.
#[derive(sqlx::FromRow)]
struct EntityRow {
    id: String,
    name: String,
    // Nullable column; `row_to_entity` substitutes "unknown" when absent.
    entity_type: Option<String>,
    org_id: String,
    agent_id: Option<String>,
    user_id: Option<String>,
    session_id: Option<String>,
    // JSON text; "null" or an empty string is read back as an empty map.
    attributes: String,
    // RFC 3339 strings, parsed by `parse_dt`.
    created_at: String,
    updated_at: String,
}
105
/// Raw `relationships` row as read from SQLite.
///
/// Ids and timestamps are still plain strings; `row_to_relationship` converts
/// this into a `Relationship`.
#[derive(sqlx::FromRow)]
struct RelationshipRow {
    id: String,
    source_id: String,
    relation: String,
    target_id: String,
    org_id: String,
    agent_id: Option<String>,
    user_id: Option<String>,
    session_id: Option<String>,
    // RFC 3339 string, parsed by `parse_dt`.
    valid_from: String,
    // Nullable RFC 3339 string; `NULL` means the relationship has not been invalidated.
    invalid_at: Option<String>,
    created_at: String,
}
120
121// ---------------------------------------------------------------------------
122// Conversion helpers
123// ---------------------------------------------------------------------------
124
125fn parse_dt(s: &str) -> Result<DateTime<Utc>, MemoryError> {
126    DateTime::parse_from_rfc3339(s)
127        .map(|dt| dt.with_timezone(&Utc))
128        .map_err(|e| MemoryError::Serialization(e.to_string()))
129}
130
131fn parse_opt_dt(s: &Option<String>) -> Result<Option<DateTime<Utc>>, MemoryError> {
132    match s {
133        None => Ok(None),
134        Some(s) => parse_dt(s).map(Some),
135    }
136}
137
138fn row_to_entity(row: EntityRow) -> Result<Entity, MemoryError> {
139    let id = Uuid::parse_str(&row.id).map_err(|e| MemoryError::Serialization(e.to_string()))?;
140
141    let attributes: serde_json::Map<String, serde_json::Value> =
142        if row.attributes == "null" || row.attributes.is_empty() {
143            serde_json::Map::new()
144        } else {
145            serde_json::from_str(&row.attributes)
146                .map_err(|e| MemoryError::Serialization(e.to_string()))?
147        };
148
149    Ok(Entity {
150        id,
151        name: row.name,
152        entity_type: row.entity_type.unwrap_or_else(|| "unknown".to_string()),
153        scope: Scope {
154            org_id: row.org_id,
155            agent_id: row.agent_id,
156            user_id: row.user_id,
157            session_id: row.session_id,
158        },
159        attributes,
160        created_at: parse_dt(&row.created_at)?,
161        updated_at: parse_dt(&row.updated_at)?,
162    })
163}
164
165fn row_to_relationship(row: RelationshipRow) -> Result<Relationship, MemoryError> {
166    let id = Uuid::parse_str(&row.id).map_err(|e| MemoryError::Serialization(e.to_string()))?;
167    let source_id =
168        Uuid::parse_str(&row.source_id).map_err(|e| MemoryError::Serialization(e.to_string()))?;
169    let target_id =
170        Uuid::parse_str(&row.target_id).map_err(|e| MemoryError::Serialization(e.to_string()))?;
171
172    Ok(Relationship {
173        id,
174        source_id,
175        relation: row.relation,
176        target_id,
177        scope: Scope {
178            org_id: row.org_id,
179            agent_id: row.agent_id,
180            user_id: row.user_id,
181            session_id: row.session_id,
182        },
183        valid_from: parse_dt(&row.valid_from)?,
184        invalid_at: parse_opt_dt(&row.invalid_at)?,
185        created_at: parse_dt(&row.created_at)?,
186    })
187}
188
189// ---------------------------------------------------------------------------
190// GraphStore implementation
191// ---------------------------------------------------------------------------
192
193#[async_trait]
194impl GraphStore for SqliteGraphStore {
195    async fn upsert_entity(&self, entity: &Entity) -> Result<(), MemoryError> {
196        let attributes_json = if entity.attributes.is_empty() {
197            "null".to_string()
198        } else {
199            serde_json::to_string(&entity.attributes)
200                .map_err(|e| MemoryError::Serialization(e.to_string()))?
201        };
202
203        sqlx::query(
204            r#"
205            INSERT INTO entities
206                (id, name, entity_type, org_id, agent_id, user_id, session_id,
207                 attributes, created_at, updated_at)
208            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
209            ON CONFLICT(id) DO UPDATE SET
210                name        = excluded.name,
211                entity_type = excluded.entity_type,
212                attributes  = excluded.attributes,
213                updated_at  = excluded.updated_at
214            "#,
215        )
216        .bind(entity.id.to_string())
217        .bind(&entity.name)
218        .bind(&entity.entity_type)
219        .bind(&entity.scope.org_id)
220        .bind(entity.scope.agent_id.as_deref())
221        .bind(entity.scope.user_id.as_deref())
222        .bind(entity.scope.session_id.as_deref())
223        .bind(&attributes_json)
224        .bind(entity.created_at.to_rfc3339())
225        .bind(entity.updated_at.to_rfc3339())
226        .execute(&self.pool)
227        .await
228        .map_err(|e| MemoryError::Database(e.to_string()))?;
229
230        Ok(())
231    }
232
233    async fn upsert_relationship(&self, rel: &Relationship) -> Result<(), MemoryError> {
234        sqlx::query(
235            r#"
236            INSERT INTO relationships
237                (id, source_id, relation, target_id, org_id, agent_id, user_id, session_id,
238                 valid_from, invalid_at, created_at)
239            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
240            ON CONFLICT(id) DO UPDATE SET
241                relation   = excluded.relation,
242                invalid_at = excluded.invalid_at
243            "#,
244        )
245        .bind(rel.id.to_string())
246        .bind(rel.source_id.to_string())
247        .bind(&rel.relation)
248        .bind(rel.target_id.to_string())
249        .bind(&rel.scope.org_id)
250        .bind(rel.scope.agent_id.as_deref())
251        .bind(rel.scope.user_id.as_deref())
252        .bind(rel.scope.session_id.as_deref())
253        .bind(rel.valid_from.to_rfc3339())
254        .bind(rel.invalid_at.map(|dt| dt.to_rfc3339()))
255        .bind(rel.created_at.to_rfc3339())
256        .execute(&self.pool)
257        .await
258        .map_err(|e| MemoryError::Database(e.to_string()))?;
259
260        Ok(())
261    }
262
263    async fn invalidate_relationship(
264        &self,
265        id: RelationshipId,
266        invalid_at: DateTime<Utc>,
267    ) -> Result<(), MemoryError> {
268        sqlx::query("UPDATE relationships SET invalid_at = ? WHERE id = ?")
269            .bind(invalid_at.to_rfc3339())
270            .bind(id.to_string())
271            .execute(&self.pool)
272            .await
273            .map_err(|e| MemoryError::Database(e.to_string()))?;
274        Ok(())
275    }
276
277    async fn get_entity(&self, id: EntityId) -> Result<Option<Entity>, MemoryError> {
278        let row = sqlx::query_as::<_, EntityRow>("SELECT * FROM entities WHERE id = ?")
279            .bind(id.to_string())
280            .fetch_optional(&self.pool)
281            .await
282            .map_err(|e| MemoryError::Database(e.to_string()))?;
283
284        row.map(row_to_entity).transpose()
285    }
286
287    async fn neighbors(
288        &self,
289        id: EntityId,
290        depth: u8,
291        as_of: Option<DateTime<Utc>>,
292    ) -> Result<SubGraph, MemoryError> {
293        // Build the temporal validity filter for relationships.
294        let validity_clause = match as_of {
295            Some(t) => {
296                let s = t.to_rfc3339();
297                format!("valid_from <= '{s}' AND (invalid_at IS NULL OR invalid_at > '{s}')")
298            }
299            None => "invalid_at IS NULL".to_string(),
300        };
301
302        let mut visited_entities: HashSet<EntityId> = HashSet::new();
303        visited_entities.insert(id);
304
305        let mut discovered_entities: HashMap<EntityId, Entity> = HashMap::new();
306        let mut discovered_relationships: HashMap<RelationshipId, Relationship> = HashMap::new();
307
308        // BFS queue: (entity_id, remaining_depth)
309        let mut queue: VecDeque<(EntityId, u8)> = VecDeque::new();
310        queue.push_back((id, depth));
311
312        while let Some((current_id, remaining)) = queue.pop_front() {
313            if remaining == 0 {
314                continue;
315            }
316
317            // Fetch all valid relationships where current_id is source or target.
318            let sql = format!(
319                "SELECT * FROM relationships WHERE (source_id = ? OR target_id = ?) AND {validity_clause}"
320            );
321
322            let rel_rows = sqlx::query_as::<_, RelationshipRow>(&sql)
323                .bind(current_id.to_string())
324                .bind(current_id.to_string())
325                .fetch_all(&self.pool)
326                .await
327                .map_err(|e| MemoryError::Database(e.to_string()))?;
328
329            for row in rel_rows {
330                let rel = row_to_relationship(row)?;
331                let neighbor_id = if rel.source_id == current_id {
332                    rel.target_id
333                } else {
334                    rel.source_id
335                };
336
337                // Store the relationship (deduplicated by id).
338                discovered_relationships.entry(rel.id).or_insert(rel);
339
340                // Enqueue unvisited neighbors.
341                if !visited_entities.contains(&neighbor_id) {
342                    visited_entities.insert(neighbor_id);
343
344                    // Fetch the neighbor entity.
345                    if let Some(entity) = self.get_entity(neighbor_id).await? {
346                        discovered_entities.entry(neighbor_id).or_insert(entity);
347                    }
348
349                    queue.push_back((neighbor_id, remaining - 1));
350                }
351            }
352        }
353
354        Ok(SubGraph {
355            entities: discovered_entities.into_values().collect(),
356            relationships: discovered_relationships.into_values().collect(),
357        })
358    }
359
360    async fn search_entities(&self, query: &str, top_k: usize) -> Result<Vec<Entity>, MemoryError> {
361        let escaped = query.replace('\'', "''");
362        let sql = format!("SELECT * FROM entities WHERE name LIKE '%{escaped}%' LIMIT {top_k}");
363
364        let rows = sqlx::query_as::<_, EntityRow>(&sql)
365            .fetch_all(&self.pool)
366            .await
367            .map_err(|e| MemoryError::Database(e.to_string()))?;
368
369        rows.into_iter().map(row_to_entity).collect()
370    }
371
372    async fn delete_by_scope(&self, scope: &Scope) -> Result<u64, MemoryError> {
373        // Build WHERE clause matching the scope fields.
374        let mut wheres = vec![format!("org_id = '{}'", scope.org_id.replace('\'', "''"))];
375
376        if let Some(ref user_id) = scope.user_id {
377            wheres.push(format!("user_id = '{}'", user_id.replace('\'', "''")));
378        }
379        if let Some(ref agent_id) = scope.agent_id {
380            wheres.push(format!("agent_id = '{}'", agent_id.replace('\'', "''")));
381        }
382        if let Some(ref session_id) = scope.session_id {
383            wheres.push(format!("session_id = '{}'", session_id.replace('\'', "''")));
384        }
385
386        let where_clause = wheres.join(" AND ");
387
388        // Delete relationships first.
389        let rel_sql = format!("DELETE FROM relationships WHERE {where_clause}");
390        let rel_result = sqlx::query(&rel_sql)
391            .execute(&self.pool)
392            .await
393            .map_err(|e| MemoryError::Database(e.to_string()))?;
394
395        // Then delete entities.
396        let ent_sql = format!("DELETE FROM entities WHERE {where_clause}");
397        let ent_result = sqlx::query(&ent_sql)
398            .execute(&self.pool)
399            .await
400            .map_err(|e| MemoryError::Database(e.to_string()))?;
401
402        Ok(rel_result.rows_affected() + ent_result.rows_affected())
403    }
404}