1use crate::fact::{Fact, FactFilter, FactId, FactPatch, MemoryTier};
8use crate::scope::Scope;
9use crate::store::{FactStore, MemoryError, StoreStats};
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use sqlx::SqlitePool;
13use uuid::Uuid;
14
/// Base schema for the `facts` table and its secondary indexes.
///
/// Every statement uses `IF NOT EXISTS`, so applying it to an existing database leaves the
/// existing objects as they are (an existing table's columns are not altered).
/// `SqliteFactStore::migrate` splits this string on `;` and executes the pieces one at a time,
/// so no individual statement here may contain a `;` of its own.
///
/// Storage conventions used by this module:
/// - timestamps (`valid_from`, `invalid_at`, `created_at`, `last_accessed`) are TEXT written
///   with `DateTime::to_rfc3339`;
/// - `id`, `supersedes` and `superseded_by` hold UUIDs as text;
/// - `entity_refs` is a JSON array of UUID strings;
/// - `metadata` is a JSON object, or the literal `null` when there is no metadata;
/// - `tier` is one of `working`, `conversation`, `knowledge`.
pub const FACT_STORE_DDL: &str = r#"
CREATE TABLE IF NOT EXISTS facts (
    id TEXT PRIMARY KEY,
    text TEXT NOT NULL,
    org_id TEXT NOT NULL DEFAULT 'default',
    agent_id TEXT,
    user_id TEXT,
    session_id TEXT,
    tier TEXT NOT NULL DEFAULT 'conversation',
    category TEXT,
    source TEXT,
    confidence REAL,
    valid_from TEXT NOT NULL,
    invalid_at TEXT,
    created_at TEXT NOT NULL,
    entity_refs TEXT NOT NULL DEFAULT '[]',
    supersedes TEXT,
    superseded_by TEXT,
    access_count INTEGER NOT NULL DEFAULT 0,
    last_accessed TEXT,
    metadata TEXT NOT NULL DEFAULT 'null'
);
CREATE INDEX IF NOT EXISTS idx_facts_org_id ON facts (org_id);
CREATE INDEX IF NOT EXISTS idx_facts_user_id ON facts (user_id);
CREATE INDEX IF NOT EXISTS idx_facts_agent_id ON facts (agent_id);
CREATE INDEX IF NOT EXISTS idx_facts_session_id ON facts (session_id);
CREATE INDEX IF NOT EXISTS idx_facts_tier ON facts (tier);
CREATE INDEX IF NOT EXISTS idx_facts_category ON facts (category);
CREATE INDEX IF NOT EXISTS idx_facts_valid_from ON facts (valid_from);
CREATE INDEX IF NOT EXISTS idx_facts_invalid_at ON facts (invalid_at);
"#;
50
/// FTS5 full-text index over `facts.text`, kept in sync with `facts` by triggers.
///
/// `SqliteFactStore::migrate` executes these entries one by one after `FACT_STORE_DDL`, so the
/// `facts` table already exists when the triggers are created. Each entry must be a single
/// statement; the semicolons inside a trigger's `BEGIN ... END` body are part of that statement.
///
/// `fact_id` is declared `UNINDEXED`: it is stored alongside the text but not searchable. The
/// triggers find index rows by comparing it. NOTE(review): that lookup presumably scans
/// `facts_fts`; check with `EXPLAIN QUERY PLAN` if delete/update volume grows.
const FTS5_DDL: &[&str] = &[
    "CREATE VIRTUAL TABLE IF NOT EXISTS facts_fts USING fts5(fact_id UNINDEXED, text)",
    "CREATE TRIGGER IF NOT EXISTS facts_ai AFTER INSERT ON facts BEGIN INSERT INTO facts_fts(fact_id, text) VALUES (new.id, new.text); END",
    "CREATE TRIGGER IF NOT EXISTS facts_ad AFTER DELETE ON facts BEGIN DELETE FROM facts_fts WHERE fact_id = old.id; END",
    // Without this trigger, changing `text` (e.g. via `update_fact`) leaves the old wording in
    // the index: keyword search keeps matching the stale text and never sees the new text.
    // The WHEN clause skips the re-index when the text did not actually change.
    "CREATE TRIGGER IF NOT EXISTS facts_au AFTER UPDATE OF text ON facts WHEN old.text IS NOT new.text BEGIN DELETE FROM facts_fts WHERE fact_id = old.id; INSERT INTO facts_fts(fact_id, text) VALUES (new.id, new.text); END",
];
60
61pub struct SqliteFactStore {
66 pool: SqlitePool,
67}
68
69impl SqliteFactStore {
70 pub fn new(pool: SqlitePool) -> Self {
71 Self { pool }
72 }
73
74 pub async fn open(database_url: &str) -> Result<Self, sqlx::Error> {
76 let pool = SqlitePool::connect(database_url).await?;
77 Ok(Self { pool })
78 }
79
80 pub async fn migrate(&self) -> Result<(), sqlx::Error> {
82 for stmt in FACT_STORE_DDL.split(';') {
83 let stmt = stmt.trim();
84 if stmt.is_empty() {
85 continue;
86 }
87 sqlx::query(stmt).execute(&self.pool).await?;
88 }
89 for stmt in FTS5_DDL {
91 sqlx::query(stmt).execute(&self.pool).await?;
92 }
93 Ok(())
94 }
95}
96
/// One row of the `facts` table as returned by `SELECT *`, decoded with `sqlx::FromRow`.
///
/// Field names mirror the column names in `FACT_STORE_DDL`. Values are kept in their raw
/// storage form (text timestamps, JSON strings, UUID text); `row_to_fact` turns them into a
/// `Fact`.
#[derive(sqlx::FromRow)]
struct FactRow {
    /// UUID of the fact, as text.
    id: String,
    text: String,
    org_id: String,
    agent_id: Option<String>,
    user_id: Option<String>,
    session_id: Option<String>,
    /// `working` / `conversation` / `knowledge`; `tier_from_str` maps anything else to
    /// `MemoryTier::Conversation`.
    tier: String,
    category: Option<String>,
    source: Option<String>,
    /// Stored as REAL; narrowed to `f32` by `row_to_fact`.
    confidence: Option<f64>,
    /// RFC 3339 timestamp text (parsed by `parse_dt`).
    valid_from: String,
    /// RFC 3339 timestamp text; `None` while the fact is still valid.
    invalid_at: Option<String>,
    /// RFC 3339 timestamp text.
    created_at: String,
    /// JSON array of UUID strings.
    entity_refs: String,
    /// UUID text of the fact this one replaces, if any.
    supersedes: Option<String>,
    /// UUID text of the fact that replaced this one, if any.
    superseded_by: Option<String>,
    access_count: i64,
    /// RFC 3339 timestamp text of the last `record_access`, if any.
    last_accessed: Option<String>,
    /// JSON object text; the literal `null` (or an empty string) means no metadata.
    metadata: String,
}
123
124fn parse_dt(s: &str) -> Result<DateTime<Utc>, MemoryError> {
129 DateTime::parse_from_rfc3339(s)
130 .map(|dt| dt.with_timezone(&Utc))
131 .map_err(|e| MemoryError::Serialization(e.to_string()))
132}
133
134fn parse_opt_dt(s: &Option<String>) -> Result<Option<DateTime<Utc>>, MemoryError> {
135 match s {
136 None => Ok(None),
137 Some(s) => parse_dt(s).map(Some),
138 }
139}
140
141fn tier_from_str(s: &str) -> MemoryTier {
142 match s {
143 "working" => MemoryTier::Working,
144 "knowledge" => MemoryTier::Knowledge,
145 _ => MemoryTier::Conversation,
146 }
147}
148
149fn tier_to_str(t: &MemoryTier) -> &'static str {
150 match t {
151 MemoryTier::Working => "working",
152 MemoryTier::Conversation => "conversation",
153 MemoryTier::Knowledge => "knowledge",
154 }
155}
156
157fn row_to_fact(row: FactRow) -> Result<Fact, MemoryError> {
158 let id = Uuid::parse_str(&row.id).map_err(|e| MemoryError::Serialization(e.to_string()))?;
159
160 let entity_refs: Vec<Uuid> = {
161 let strings: Vec<String> = serde_json::from_str(&row.entity_refs)
162 .map_err(|e| MemoryError::Serialization(e.to_string()))?;
163 strings
164 .iter()
165 .map(|s| Uuid::parse_str(s).map_err(|e| MemoryError::Serialization(e.to_string())))
166 .collect::<Result<Vec<_>, _>>()?
167 };
168
169 let metadata: serde_json::Map<String, serde_json::Value> =
170 if row.metadata == "null" || row.metadata.is_empty() {
171 serde_json::Map::new()
172 } else {
173 serde_json::from_str(&row.metadata)
174 .map_err(|e| MemoryError::Serialization(e.to_string()))?
175 };
176
177 let supersedes = row
178 .supersedes
179 .as_deref()
180 .map(|s| Uuid::parse_str(s).map_err(|e| MemoryError::Serialization(e.to_string())))
181 .transpose()?;
182
183 let superseded_by = row
184 .superseded_by
185 .as_deref()
186 .map(|s| Uuid::parse_str(s).map_err(|e| MemoryError::Serialization(e.to_string())))
187 .transpose()?;
188
189 Ok(Fact {
190 id,
191 text: row.text,
192 scope: Scope {
193 org_id: row.org_id,
194 agent_id: row.agent_id,
195 user_id: row.user_id,
196 session_id: row.session_id,
197 },
198 tier: tier_from_str(&row.tier),
199 category: row.category,
200 source: row.source,
201 confidence: row.confidence.map(|c| c as f32),
202 valid_from: parse_dt(&row.valid_from)?,
203 invalid_at: parse_opt_dt(&row.invalid_at)?,
204 created_at: parse_dt(&row.created_at)?,
205 embedding: Vec::new(),
207 entity_refs,
208 supersedes,
209 superseded_by,
210 access_count: row.access_count as u64,
211 last_accessed: parse_opt_dt(&row.last_accessed)?,
212 metadata,
213 })
214}
215
216#[async_trait]
221impl FactStore for SqliteFactStore {
222 async fn insert_fact(&self, fact: Fact) -> Result<FactId, MemoryError> {
223 let entity_refs_json = {
224 let strs: Vec<String> = fact.entity_refs.iter().map(|u| u.to_string()).collect();
225 serde_json::to_string(&strs).map_err(|e| MemoryError::Serialization(e.to_string()))?
226 };
227
228 let metadata_json = if fact.metadata.is_empty() {
229 "null".to_string()
230 } else {
231 serde_json::to_string(&fact.metadata)
232 .map_err(|e| MemoryError::Serialization(e.to_string()))?
233 };
234
235 sqlx::query(
236 r#"
237 INSERT OR IGNORE INTO facts
238 (id, text, org_id, agent_id, user_id, session_id,
239 tier, category, source, confidence,
240 valid_from, invalid_at, created_at,
241 entity_refs, supersedes, superseded_by,
242 access_count, last_accessed, metadata)
243 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
244 "#,
245 )
246 .bind(fact.id.to_string())
247 .bind(&fact.text)
248 .bind(&fact.scope.org_id)
249 .bind(fact.scope.agent_id.as_deref())
250 .bind(fact.scope.user_id.as_deref())
251 .bind(fact.scope.session_id.as_deref())
252 .bind(tier_to_str(&fact.tier))
253 .bind(fact.category.as_deref())
254 .bind(fact.source.as_deref())
255 .bind(fact.confidence.map(|c| c as f64))
256 .bind(fact.valid_from.to_rfc3339())
257 .bind(fact.invalid_at.map(|dt| dt.to_rfc3339()))
258 .bind(fact.created_at.to_rfc3339())
259 .bind(entity_refs_json)
260 .bind(fact.supersedes.map(|u| u.to_string()))
261 .bind(fact.superseded_by.map(|u| u.to_string()))
262 .bind(fact.access_count as i64)
263 .bind(fact.last_accessed.map(|dt| dt.to_rfc3339()))
264 .bind(metadata_json)
265 .execute(&self.pool)
266 .await
267 .map_err(|e| MemoryError::Database(e.to_string()))?;
268
269 Ok(fact.id)
270 }
271
272 async fn get_fact(&self, id: FactId) -> Result<Fact, MemoryError> {
273 let row = sqlx::query_as::<_, FactRow>("SELECT * FROM facts WHERE id = ?")
274 .bind(id.to_string())
275 .fetch_optional(&self.pool)
276 .await
277 .map_err(|e| MemoryError::Database(e.to_string()))?
278 .ok_or_else(|| MemoryError::NotFound(id.to_string()))?;
279
280 row_to_fact(row)
281 }
282
283 async fn update_fact(&self, id: FactId, patch: FactPatch) -> Result<Fact, MemoryError> {
284 let mut cols: Vec<&'static str> = Vec::new();
287 let mut vals: Vec<String> = Vec::new();
288
289 if let Some(ref text) = patch.text {
290 cols.push("text = ?");
291 vals.push(text.clone());
292 }
293 if let Some(ref tier) = patch.tier {
294 cols.push("tier = ?");
295 vals.push(tier_to_str(tier).to_string());
296 }
297 if let Some(ref category) = patch.category {
298 cols.push("category = ?");
299 vals.push(category.clone());
300 }
301 if let Some(ref source) = patch.source {
302 cols.push("source = ?");
303 vals.push(source.clone());
304 }
305 if let Some(confidence) = patch.confidence {
306 cols.push("confidence = ?");
307 vals.push((confidence as f64).to_string());
308 }
309 if let Some(invalid_at) = patch.invalid_at {
310 cols.push("invalid_at = ?");
311 vals.push(invalid_at.to_rfc3339());
312 }
313 if let Some(superseded_by) = patch.superseded_by {
314 cols.push("superseded_by = ?");
315 vals.push(superseded_by.to_string());
316 }
317 if !patch.metadata.is_empty() {
318 let json = serde_json::to_string(&patch.metadata)
319 .map_err(|e| MemoryError::Serialization(e.to_string()))?;
320 cols.push("metadata = ?");
321 vals.push(json);
322 }
323
324 if !cols.is_empty() {
325 let sql = format!("UPDATE facts SET {} WHERE id = ?", cols.join(", "));
326 let mut q = sqlx::query(&sql);
327 for v in &vals {
328 q = q.bind(v.as_str());
329 }
330 q = q.bind(id.to_string());
331 q.execute(&self.pool)
332 .await
333 .map_err(|e| MemoryError::Database(e.to_string()))?;
334 }
335
336 self.get_fact(id).await
337 }
338
339 async fn list_facts(&self, filter: &FactFilter) -> Result<Vec<Fact>, MemoryError> {
340 let mut wheres: Vec<String> = vec!["1=1".to_string()];
341
342 if let Some(ref scope) = filter.scope {
343 wheres.push(format!("org_id = '{}'", scope.org_id.replace('\'', "''")));
344 if let Some(ref user_id) = scope.user_id {
345 wheres.push(format!("user_id = '{}'", user_id.replace('\'', "''")));
346 }
347 if let Some(ref agent_id) = scope.agent_id {
348 wheres.push(format!("agent_id = '{}'", agent_id.replace('\'', "''")));
349 }
350 if let Some(ref session_id) = scope.session_id {
351 wheres.push(format!("session_id = '{}'", session_id.replace('\'', "''")));
352 }
353 }
354
355 if let Some(ref tier) = filter.tier {
356 wheres.push(format!("tier = '{}'", tier_to_str(tier)));
357 }
358
359 if let Some(ref category) = filter.category {
360 wheres.push(format!("category = '{}'", category.replace('\'', "''")));
361 }
362
363 if let Some(as_of) = filter.as_of {
364 let s = as_of.to_rfc3339();
365 wheres.push(format!("valid_from <= '{s}'"));
366 wheres.push(format!("(invalid_at IS NULL OR invalid_at > '{s}')"));
367 } else if filter.valid_only {
368 wheres.push("invalid_at IS NULL".to_string());
369 }
370
371 if let Some(ref text_contains) = filter.text_contains {
372 let escaped = text_contains.replace('\'', "''");
373 wheres.push(format!("text LIKE '%{escaped}%'"));
374 }
375
376 let where_clause = wheres.join(" AND ");
377 let sql = format!(
378 "SELECT * FROM facts WHERE {where_clause} ORDER BY created_at DESC LIMIT {} OFFSET {}",
379 filter.limit, filter.offset
380 );
381
382 let rows = sqlx::query_as::<_, FactRow>(&sql)
383 .fetch_all(&self.pool)
384 .await
385 .map_err(|e| MemoryError::Database(e.to_string()))?;
386
387 rows.into_iter().map(row_to_fact).collect()
388 }
389
390 async fn invalidate_fact(&self, id: FactId) -> Result<(), MemoryError> {
391 let now = Utc::now().to_rfc3339();
392 sqlx::query("UPDATE facts SET invalid_at = ? WHERE id = ?")
393 .bind(&now)
394 .bind(id.to_string())
395 .execute(&self.pool)
396 .await
397 .map_err(|e| MemoryError::Database(e.to_string()))?;
398 Ok(())
399 }
400
401 async fn delete_scope_data(&self, scope: &Scope) -> Result<u64, MemoryError> {
402 let mut wheres = vec![format!("org_id = '{}'", scope.org_id.replace('\'', "''"))];
403
404 if let Some(ref user_id) = scope.user_id {
405 wheres.push(format!("user_id = '{}'", user_id.replace('\'', "''")));
406 }
407 if let Some(ref agent_id) = scope.agent_id {
408 wheres.push(format!("agent_id = '{}'", agent_id.replace('\'', "''")));
409 }
410 if let Some(ref session_id) = scope.session_id {
411 wheres.push(format!("session_id = '{}'", session_id.replace('\'', "''")));
412 }
413
414 let where_clause = wheres.join(" AND ");
415 let sql = format!("DELETE FROM facts WHERE {where_clause}");
416
417 let result = sqlx::query(&sql)
418 .execute(&self.pool)
419 .await
420 .map_err(|e| MemoryError::Database(e.to_string()))?;
421
422 Ok(result.rows_affected())
423 }
424
425 async fn export(&self, filter: &FactFilter) -> Result<Vec<Fact>, MemoryError> {
426 self.list_facts(filter).await
427 }
428
429 async fn import(&self, facts: Vec<Fact>) -> Result<u64, MemoryError> {
430 let mut imported: u64 = 0;
431 for fact in facts {
432 let result = sqlx::query(
433 r#"
434 INSERT OR IGNORE INTO facts
435 (id, text, org_id, agent_id, user_id, session_id,
436 tier, category, source, confidence,
437 valid_from, invalid_at, created_at,
438 entity_refs, supersedes, superseded_by,
439 access_count, last_accessed, metadata)
440 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
441 "#,
442 )
443 .bind(fact.id.to_string())
444 .bind(&fact.text)
445 .bind(&fact.scope.org_id)
446 .bind(fact.scope.agent_id.as_deref())
447 .bind(fact.scope.user_id.as_deref())
448 .bind(fact.scope.session_id.as_deref())
449 .bind(tier_to_str(&fact.tier))
450 .bind(fact.category.as_deref())
451 .bind(fact.source.as_deref())
452 .bind(fact.confidence.map(|c| c as f64))
453 .bind(fact.valid_from.to_rfc3339())
454 .bind(fact.invalid_at.map(|dt| dt.to_rfc3339()))
455 .bind(fact.created_at.to_rfc3339())
456 .bind({
457 let strs: Vec<String> = fact.entity_refs.iter().map(|u| u.to_string()).collect();
458 serde_json::to_string(&strs)
459 .map_err(|e| MemoryError::Serialization(e.to_string()))?
460 })
461 .bind(fact.supersedes.map(|u| u.to_string()))
462 .bind(fact.superseded_by.map(|u| u.to_string()))
463 .bind(fact.access_count as i64)
464 .bind(fact.last_accessed.map(|dt| dt.to_rfc3339()))
465 .bind(if fact.metadata.is_empty() {
466 "null".to_string()
467 } else {
468 serde_json::to_string(&fact.metadata)
469 .map_err(|e| MemoryError::Serialization(e.to_string()))?
470 })
471 .execute(&self.pool)
472 .await
473 .map_err(|e| MemoryError::Database(e.to_string()))?;
474
475 imported += result.rows_affected();
476 }
477 Ok(imported)
478 }
479
480 async fn stats(&self) -> Result<StoreStats, MemoryError> {
481 let (total,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM facts")
482 .fetch_one(&self.pool)
483 .await
484 .map_err(|e| MemoryError::Database(e.to_string()))?;
485
486 let (valid,): (i64,) =
487 sqlx::query_as("SELECT COUNT(*) FROM facts WHERE invalid_at IS NULL")
488 .fetch_one(&self.pool)
489 .await
490 .map_err(|e| MemoryError::Database(e.to_string()))?;
491
492 let (invalidated,): (i64,) =
493 sqlx::query_as("SELECT COUNT(*) FROM facts WHERE invalid_at IS NOT NULL")
494 .fetch_one(&self.pool)
495 .await
496 .map_err(|e| MemoryError::Database(e.to_string()))?;
497
498 Ok(StoreStats {
499 total_facts: total as u64,
500 valid_facts: valid as u64,
501 invalidated_facts: invalidated as u64,
502 total_entities: 0,
503 total_relationships: 0,
504 })
505 }
506
507 async fn record_access(&self, id: FactId) -> Result<(), MemoryError> {
508 let now = Utc::now().to_rfc3339();
509 sqlx::query(
510 "UPDATE facts SET access_count = access_count + 1, last_accessed = ? WHERE id = ?",
511 )
512 .bind(&now)
513 .bind(id.to_string())
514 .execute(&self.pool)
515 .await
516 .map_err(|e| MemoryError::Database(e.to_string()))?;
517 Ok(())
518 }
519
520 async fn keyword_search(
521 &self,
522 query: &str,
523 scope: &Scope,
524 top_k: usize,
525 ) -> Result<Vec<Fact>, MemoryError> {
526 let sql = r#"
527 SELECT f.*
528 FROM facts_fts fts
529 INNER JOIN facts f ON f.id = fts.fact_id
530 WHERE facts_fts MATCH ?
531 AND f.org_id = ?
532 AND (? IS NULL OR f.user_id = ?)
533 AND f.invalid_at IS NULL
534 ORDER BY fts.rank
535 LIMIT ?
536 "#;
537
538 let normalized = normalize_fts_query(query);
539 if normalized.is_empty() {
540 return Ok(Vec::new());
541 }
542
543 let rows = sqlx::query_as::<_, FactRow>(sql)
544 .bind(&normalized)
545 .bind(&scope.org_id)
546 .bind(scope.user_id.as_deref())
547 .bind(scope.user_id.as_deref())
548 .bind(top_k as i64)
549 .fetch_all(&self.pool)
550 .await
551 .map_err(|e| MemoryError::Database(e.to_string()))?;
552
553 rows.into_iter().map(row_to_fact).collect()
554 }
555}
556
/// Converts free-form user input into an FTS5 `MATCH` expression.
///
/// - Whitespace-only input yields an empty string, which `keyword_search` treats as
///   "no results".
/// - Input containing `"`, `:` or `(` is assumed to already be FTS5 query syntax and is
///   returned trimmed but otherwise unchanged. NOTE(review): ordinary text that happens to
///   contain one of these characters (e.g. `note: buy milk`) takes this path too and may then be
///   rejected by FTS5.
/// - Otherwise, each whitespace-separated token keeps only its alphanumeric and `_` characters
///   and becomes a prefix term (`token*`). Tokens left empty are dropped. The terms are joined
///   with spaces, which FTS5 treats as an implicit AND.
///
/// Upper-case `AND`, `OR` and `NOT` are operators in FTS5 query syntax, so an unquoted `NOT*`
/// would make the whole query a syntax error. These words are emitted as quoted strings
/// (`"NOT"*`), which FTS5 matches as ordinary words.
fn normalize_fts_query(query: &str) -> String {
    // Bare words FTS5 parses as operators rather than as search terms.
    const FTS5_OPERATORS: [&str; 3] = ["AND", "OR", "NOT"];

    let trimmed = query.trim();
    if trimmed.is_empty() {
        return String::new();
    }
    if trimmed.contains('"') || trimmed.contains(':') || trimmed.contains('(') {
        return trimmed.to_string();
    }

    trimmed
        .split_whitespace()
        .filter_map(|token| {
            let cleaned: String = token
                .chars()
                .filter(|c| c.is_alphanumeric() || *c == '_')
                .collect();
            if cleaned.is_empty() {
                None
            } else if FTS5_OPERATORS.contains(&cleaned.as_str()) {
                Some(format!("\"{cleaned}\"*"))
            } else {
                Some(format!("{cleaned}*"))
            }
        })
        .collect::<Vec<_>>()
        .join(" ")
}

#[cfg(test)]
mod tests {
    use super::normalize_fts_query;

    #[test]
    fn single_token_gets_prefix_star() {
        assert_eq!(normalize_fts_query("peanut"), "peanut*");
    }

    #[test]
    fn multi_token_each_gets_prefix_star() {
        assert_eq!(normalize_fts_query("food allergies"), "food* allergies*");
    }

    #[test]
    fn punctuation_stripped() {
        assert_eq!(normalize_fts_query("what's up?"), "whats* up*");
    }

    #[test]
    fn empty_query_returns_empty() {
        assert_eq!(normalize_fts_query(" "), "");
    }

    #[test]
    fn quoted_phrase_passes_through() {
        let q = "\"exact phrase\"";
        assert_eq!(normalize_fts_query(q), q);
    }

    #[test]
    fn fts5_operator_words_are_quoted() {
        assert_eq!(
            normalize_fts_query("I do NOT like peanuts"),
            "I* do* \"NOT\"* like* peanuts*"
        );
        assert_eq!(
            normalize_fts_query("cats AND dogs OR birds"),
            "cats* \"AND\"* dogs* \"OR\"* birds*"
        );
    }

    #[test]
    fn lowercase_operator_words_stay_plain_terms() {
        assert_eq!(normalize_fts_query("salt and pepper"), "salt* and* pepper*");
    }
}