1use crate::fact::{Entity, EntityId, Relationship, RelationshipId, SubGraph};
7use crate::graph::GraphStore;
8use crate::scope::Scope;
9use crate::store::MemoryError;
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use sqlx::SqlitePool;
13use std::collections::{HashMap, HashSet, VecDeque};
14use uuid::Uuid;
15
/// Schema for the SQLite-backed graph store, applied by [`SqliteGraphStore::migrate`].
///
/// `migrate` splits this text on `;` and executes each piece on its own, so no statement
/// in here may contain a literal semicolon. Every statement uses `IF NOT EXISTS`, so
/// applying the schema repeatedly is harmless (and picks up newly added indexes).
///
/// Timestamps are written as RFC 3339 text by the store. `attributes` holds a JSON object,
/// or the literal `null` when the entity has no attributes. The scope columns
/// (`org_id`, `agent_id`, `user_id`, `session_id`) mirror [`Scope`] and are indexed on
/// both tables because `delete_by_scope` filters on them.
pub const GRAPH_STORE_DDL: &str = r#"
CREATE TABLE IF NOT EXISTS entities (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    entity_type TEXT,
    org_id TEXT NOT NULL DEFAULT 'default',
    agent_id TEXT,
    user_id TEXT,
    session_id TEXT,
    attributes TEXT NOT NULL DEFAULT 'null',
    created_at TEXT NOT NULL,
    updated_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_entities_name ON entities (name);
CREATE INDEX IF NOT EXISTS idx_entities_user_id ON entities (user_id);
CREATE INDEX IF NOT EXISTS idx_entities_org_id ON entities (org_id);

CREATE TABLE IF NOT EXISTS relationships (
    id TEXT PRIMARY KEY,
    source_id TEXT NOT NULL,
    relation TEXT NOT NULL,
    target_id TEXT NOT NULL,
    org_id TEXT NOT NULL DEFAULT 'default',
    agent_id TEXT,
    user_id TEXT,
    session_id TEXT,
    valid_from TEXT NOT NULL,
    invalid_at TEXT,
    created_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_rel_source_id ON relationships (source_id);
CREATE INDEX IF NOT EXISTS idx_rel_target_id ON relationships (target_id);
CREATE INDEX IF NOT EXISTS idx_rel_relation ON relationships (relation);
CREATE INDEX IF NOT EXISTS idx_rel_invalid_at ON relationships (invalid_at);
CREATE INDEX IF NOT EXISTS idx_rel_org_id ON relationships (org_id);
CREATE INDEX IF NOT EXISTS idx_rel_user_id ON relationships (user_id);
"#;
55
56pub struct SqliteGraphStore {
61 pool: SqlitePool,
62}
63
64impl SqliteGraphStore {
65 pub fn new(pool: SqlitePool) -> Self {
66 Self { pool }
67 }
68
69 pub async fn open(database_url: &str) -> Result<Self, sqlx::Error> {
71 let pool = SqlitePool::connect(database_url).await?;
72 Ok(Self { pool })
73 }
74
75 pub async fn migrate(&self) -> Result<(), sqlx::Error> {
77 for stmt in GRAPH_STORE_DDL.split(';') {
78 let stmt = stmt.trim();
79 if stmt.is_empty() {
80 continue;
81 }
82 sqlx::query(stmt).execute(&self.pool).await?;
83 }
84 Ok(())
85 }
86}
87
/// Raw row of the `entities` table as returned by `SELECT *`.
///
/// `FromRow` maps columns to fields by name, so every field name here must match a
/// column in `GRAPH_STORE_DDL`. `row_to_entity` turns this into a domain [`Entity`].
#[derive(sqlx::FromRow)]
struct EntityRow {
    // UUID text; parsed with `Uuid::parse_str` in `row_to_entity`.
    id: String,
    name: String,
    // Nullable column; `row_to_entity` substitutes "unknown" when NULL.
    entity_type: Option<String>,
    org_id: String,
    agent_id: Option<String>,
    user_id: Option<String>,
    session_id: Option<String>,
    // JSON object text, or the literal "null" written for an empty attribute map.
    attributes: String,
    // RFC 3339 timestamps (see `parse_dt`).
    created_at: String,
    updated_at: String,
}
105
/// Raw row of the `relationships` table as returned by `SELECT *`.
///
/// Field names must match the column names in `GRAPH_STORE_DDL` because `FromRow`
/// maps by name. `row_to_relationship` turns this into a domain [`Relationship`].
#[derive(sqlx::FromRow)]
struct RelationshipRow {
    // UUID text for the relationship and both endpoint entities.
    id: String,
    source_id: String,
    relation: String,
    target_id: String,
    org_id: String,
    agent_id: Option<String>,
    user_id: Option<String>,
    session_id: Option<String>,
    // RFC 3339 timestamps; `invalid_at` is NULL while the relationship is still valid.
    valid_from: String,
    invalid_at: Option<String>,
    created_at: String,
}
120
121fn parse_dt(s: &str) -> Result<DateTime<Utc>, MemoryError> {
126 DateTime::parse_from_rfc3339(s)
127 .map(|dt| dt.with_timezone(&Utc))
128 .map_err(|e| MemoryError::Serialization(e.to_string()))
129}
130
131fn parse_opt_dt(s: &Option<String>) -> Result<Option<DateTime<Utc>>, MemoryError> {
132 match s {
133 None => Ok(None),
134 Some(s) => parse_dt(s).map(Some),
135 }
136}
137
138fn row_to_entity(row: EntityRow) -> Result<Entity, MemoryError> {
139 let id = Uuid::parse_str(&row.id).map_err(|e| MemoryError::Serialization(e.to_string()))?;
140
141 let attributes: serde_json::Map<String, serde_json::Value> =
142 if row.attributes == "null" || row.attributes.is_empty() {
143 serde_json::Map::new()
144 } else {
145 serde_json::from_str(&row.attributes)
146 .map_err(|e| MemoryError::Serialization(e.to_string()))?
147 };
148
149 Ok(Entity {
150 id,
151 name: row.name,
152 entity_type: row.entity_type.unwrap_or_else(|| "unknown".to_string()),
153 scope: Scope {
154 org_id: row.org_id,
155 agent_id: row.agent_id,
156 user_id: row.user_id,
157 session_id: row.session_id,
158 },
159 attributes,
160 created_at: parse_dt(&row.created_at)?,
161 updated_at: parse_dt(&row.updated_at)?,
162 })
163}
164
165fn row_to_relationship(row: RelationshipRow) -> Result<Relationship, MemoryError> {
166 let id = Uuid::parse_str(&row.id).map_err(|e| MemoryError::Serialization(e.to_string()))?;
167 let source_id =
168 Uuid::parse_str(&row.source_id).map_err(|e| MemoryError::Serialization(e.to_string()))?;
169 let target_id =
170 Uuid::parse_str(&row.target_id).map_err(|e| MemoryError::Serialization(e.to_string()))?;
171
172 Ok(Relationship {
173 id,
174 source_id,
175 relation: row.relation,
176 target_id,
177 scope: Scope {
178 org_id: row.org_id,
179 agent_id: row.agent_id,
180 user_id: row.user_id,
181 session_id: row.session_id,
182 },
183 valid_from: parse_dt(&row.valid_from)?,
184 invalid_at: parse_opt_dt(&row.invalid_at)?,
185 created_at: parse_dt(&row.created_at)?,
186 })
187}
188
189#[async_trait]
194impl GraphStore for SqliteGraphStore {
195 async fn upsert_entity(&self, entity: &Entity) -> Result<(), MemoryError> {
196 let attributes_json = if entity.attributes.is_empty() {
197 "null".to_string()
198 } else {
199 serde_json::to_string(&entity.attributes)
200 .map_err(|e| MemoryError::Serialization(e.to_string()))?
201 };
202
203 sqlx::query(
204 r#"
205 INSERT INTO entities
206 (id, name, entity_type, org_id, agent_id, user_id, session_id,
207 attributes, created_at, updated_at)
208 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
209 ON CONFLICT(id) DO UPDATE SET
210 name = excluded.name,
211 entity_type = excluded.entity_type,
212 attributes = excluded.attributes,
213 updated_at = excluded.updated_at
214 "#,
215 )
216 .bind(entity.id.to_string())
217 .bind(&entity.name)
218 .bind(&entity.entity_type)
219 .bind(&entity.scope.org_id)
220 .bind(entity.scope.agent_id.as_deref())
221 .bind(entity.scope.user_id.as_deref())
222 .bind(entity.scope.session_id.as_deref())
223 .bind(&attributes_json)
224 .bind(entity.created_at.to_rfc3339())
225 .bind(entity.updated_at.to_rfc3339())
226 .execute(&self.pool)
227 .await
228 .map_err(|e| MemoryError::Database(e.to_string()))?;
229
230 Ok(())
231 }
232
233 async fn upsert_relationship(&self, rel: &Relationship) -> Result<(), MemoryError> {
234 sqlx::query(
235 r#"
236 INSERT INTO relationships
237 (id, source_id, relation, target_id, org_id, agent_id, user_id, session_id,
238 valid_from, invalid_at, created_at)
239 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
240 ON CONFLICT(id) DO UPDATE SET
241 relation = excluded.relation,
242 invalid_at = excluded.invalid_at
243 "#,
244 )
245 .bind(rel.id.to_string())
246 .bind(rel.source_id.to_string())
247 .bind(&rel.relation)
248 .bind(rel.target_id.to_string())
249 .bind(&rel.scope.org_id)
250 .bind(rel.scope.agent_id.as_deref())
251 .bind(rel.scope.user_id.as_deref())
252 .bind(rel.scope.session_id.as_deref())
253 .bind(rel.valid_from.to_rfc3339())
254 .bind(rel.invalid_at.map(|dt| dt.to_rfc3339()))
255 .bind(rel.created_at.to_rfc3339())
256 .execute(&self.pool)
257 .await
258 .map_err(|e| MemoryError::Database(e.to_string()))?;
259
260 Ok(())
261 }
262
263 async fn invalidate_relationship(
264 &self,
265 id: RelationshipId,
266 invalid_at: DateTime<Utc>,
267 ) -> Result<(), MemoryError> {
268 sqlx::query("UPDATE relationships SET invalid_at = ? WHERE id = ?")
269 .bind(invalid_at.to_rfc3339())
270 .bind(id.to_string())
271 .execute(&self.pool)
272 .await
273 .map_err(|e| MemoryError::Database(e.to_string()))?;
274 Ok(())
275 }
276
277 async fn get_entity(&self, id: EntityId) -> Result<Option<Entity>, MemoryError> {
278 let row = sqlx::query_as::<_, EntityRow>("SELECT * FROM entities WHERE id = ?")
279 .bind(id.to_string())
280 .fetch_optional(&self.pool)
281 .await
282 .map_err(|e| MemoryError::Database(e.to_string()))?;
283
284 row.map(row_to_entity).transpose()
285 }
286
287 async fn neighbors(
288 &self,
289 id: EntityId,
290 depth: u8,
291 as_of: Option<DateTime<Utc>>,
292 ) -> Result<SubGraph, MemoryError> {
293 let validity_clause = match as_of {
295 Some(t) => {
296 let s = t.to_rfc3339();
297 format!("valid_from <= '{s}' AND (invalid_at IS NULL OR invalid_at > '{s}')")
298 }
299 None => "invalid_at IS NULL".to_string(),
300 };
301
302 let mut visited_entities: HashSet<EntityId> = HashSet::new();
303 visited_entities.insert(id);
304
305 let mut discovered_entities: HashMap<EntityId, Entity> = HashMap::new();
306 let mut discovered_relationships: HashMap<RelationshipId, Relationship> = HashMap::new();
307
308 let mut queue: VecDeque<(EntityId, u8)> = VecDeque::new();
310 queue.push_back((id, depth));
311
312 while let Some((current_id, remaining)) = queue.pop_front() {
313 if remaining == 0 {
314 continue;
315 }
316
317 let sql = format!(
319 "SELECT * FROM relationships WHERE (source_id = ? OR target_id = ?) AND {validity_clause}"
320 );
321
322 let rel_rows = sqlx::query_as::<_, RelationshipRow>(&sql)
323 .bind(current_id.to_string())
324 .bind(current_id.to_string())
325 .fetch_all(&self.pool)
326 .await
327 .map_err(|e| MemoryError::Database(e.to_string()))?;
328
329 for row in rel_rows {
330 let rel = row_to_relationship(row)?;
331 let neighbor_id = if rel.source_id == current_id {
332 rel.target_id
333 } else {
334 rel.source_id
335 };
336
337 discovered_relationships.entry(rel.id).or_insert(rel);
339
340 if !visited_entities.contains(&neighbor_id) {
342 visited_entities.insert(neighbor_id);
343
344 if let Some(entity) = self.get_entity(neighbor_id).await? {
346 discovered_entities.entry(neighbor_id).or_insert(entity);
347 }
348
349 queue.push_back((neighbor_id, remaining - 1));
350 }
351 }
352 }
353
354 Ok(SubGraph {
355 entities: discovered_entities.into_values().collect(),
356 relationships: discovered_relationships.into_values().collect(),
357 })
358 }
359
360 async fn search_entities(&self, query: &str, top_k: usize) -> Result<Vec<Entity>, MemoryError> {
361 let escaped = query.replace('\'', "''");
362 let sql = format!("SELECT * FROM entities WHERE name LIKE '%{escaped}%' LIMIT {top_k}");
363
364 let rows = sqlx::query_as::<_, EntityRow>(&sql)
365 .fetch_all(&self.pool)
366 .await
367 .map_err(|e| MemoryError::Database(e.to_string()))?;
368
369 rows.into_iter().map(row_to_entity).collect()
370 }
371
372 async fn delete_by_scope(&self, scope: &Scope) -> Result<u64, MemoryError> {
373 let mut wheres = vec![format!("org_id = '{}'", scope.org_id.replace('\'', "''"))];
375
376 if let Some(ref user_id) = scope.user_id {
377 wheres.push(format!("user_id = '{}'", user_id.replace('\'', "''")));
378 }
379 if let Some(ref agent_id) = scope.agent_id {
380 wheres.push(format!("agent_id = '{}'", agent_id.replace('\'', "''")));
381 }
382 if let Some(ref session_id) = scope.session_id {
383 wheres.push(format!("session_id = '{}'", session_id.replace('\'', "''")));
384 }
385
386 let where_clause = wheres.join(" AND ");
387
388 let rel_sql = format!("DELETE FROM relationships WHERE {where_clause}");
390 let rel_result = sqlx::query(&rel_sql)
391 .execute(&self.pool)
392 .await
393 .map_err(|e| MemoryError::Database(e.to_string()))?;
394
395 let ent_sql = format!("DELETE FROM entities WHERE {where_clause}");
397 let ent_result = sqlx::query(&ent_sql)
398 .execute(&self.pool)
399 .await
400 .map_err(|e| MemoryError::Database(e.to_string()))?;
401
402 Ok(rel_result.rows_affected() + ent_result.rows_affected())
403 }
404}