1use crate::fact::{Fact, FactFilter, FactId, FactPatch, MemoryTier};
8use crate::scope::Scope;
9use crate::store::{FactStore, MemoryError, StoreStats};
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use sqlx::PgPool;
13use uuid::Uuid;
14
/// DDL statements that create the `facts` table and its indexes.
///
/// Every statement uses `IF NOT EXISTS`, so the list can be executed
/// repeatedly. The table statement comes first because the index statements
/// refer to it.
const PG_FACT_STORE_DDL: &[&str] = &[
    // Main table. `search_vector` is a stored generated column derived from
    // `text`, so INSERT/UPDATE statements never write it explicitly.
    r#"
    CREATE TABLE IF NOT EXISTS facts (
        id UUID PRIMARY KEY,
        text TEXT NOT NULL,
        org_id TEXT NOT NULL DEFAULT 'default',
        agent_id TEXT,
        user_id TEXT,
        session_id TEXT,
        tier TEXT NOT NULL DEFAULT 'conversation',
        category TEXT,
        source TEXT,
        confidence DOUBLE PRECISION,
        valid_from TIMESTAMPTZ NOT NULL,
        invalid_at TIMESTAMPTZ,
        created_at TIMESTAMPTZ NOT NULL,
        entity_refs JSONB NOT NULL DEFAULT '[]',
        supersedes UUID,
        superseded_by UUID,
        access_count BIGINT NOT NULL DEFAULT 0,
        last_accessed TIMESTAMPTZ,
        metadata JSONB NOT NULL DEFAULT 'null',
        search_vector tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
    )
    "#,
    // Single-column indexes on scope, tier, category and validity columns.
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_org_id ON facts (org_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_user_id ON facts (user_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_agent_id ON facts (agent_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_session_id ON facts (session_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_tier ON facts (tier)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_category ON facts (category)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_valid_from ON facts (valid_from)",
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_invalid_at ON facts (invalid_at)",
    // GIN index over the generated full-text column.
    "CREATE INDEX IF NOT EXISTS idx_pg_facts_fts ON facts USING GIN (search_vector)",
];
56
/// A fact store that keeps its data in the PostgreSQL `facts` table.
pub struct PostgresFactStore {
    /// Connection pool every query of this store is executed on.
    pool: PgPool,
}
64
65impl PostgresFactStore {
66 pub fn new(pool: PgPool) -> Self {
67 Self { pool }
68 }
69
70 pub async fn open(database_url: &str) -> Result<Self, sqlx::Error> {
72 let pool = PgPool::connect(database_url).await?;
73 Ok(Self { pool })
74 }
75
76 pub async fn migrate(&self) -> Result<(), sqlx::Error> {
78 for stmt in PG_FACT_STORE_DDL {
79 sqlx::query(stmt).execute(&self.pool).await?;
80 }
81 Ok(())
82 }
83}
84
/// One row of the `facts` table as decoded by `sqlx::FromRow`.
///
/// Field names must match the selected column names. The `search_vector`
/// column is not part of this struct.
#[derive(sqlx::FromRow)]
struct FactRow {
    id: Uuid,
    text: String,
    org_id: String,
    agent_id: Option<String>,
    user_id: Option<String>,
    session_id: Option<String>,
    // Tier in its textual column form ("working", "conversation", "knowledge").
    tier: String,
    category: Option<String>,
    source: Option<String>,
    confidence: Option<f64>,
    valid_from: DateTime<Utc>,
    invalid_at: Option<DateTime<Utc>>,
    created_at: DateTime<Utc>,
    // JSONB column; decoded later as a JSON array of UUID strings.
    entity_refs: serde_json::Value,
    supersedes: Option<Uuid>,
    superseded_by: Option<Uuid>,
    access_count: i64,
    last_accessed: Option<DateTime<Utc>>,
    // JSONB column; decoded later as JSON null or a JSON object.
    metadata: serde_json::Value,
}
111
112fn tier_from_str(s: &str) -> MemoryTier {
117 match s {
118 "working" => MemoryTier::Working,
119 "knowledge" => MemoryTier::Knowledge,
120 _ => MemoryTier::Conversation,
121 }
122}
123
124fn tier_to_str(t: &MemoryTier) -> &'static str {
125 match t {
126 MemoryTier::Working => "working",
127 MemoryTier::Conversation => "conversation",
128 MemoryTier::Knowledge => "knowledge",
129 }
130}
131
132fn row_to_fact(row: FactRow) -> Result<Fact, MemoryError> {
133 let entity_refs: Vec<Uuid> = {
134 let strings: Vec<String> = serde_json::from_value(row.entity_refs.clone())
135 .map_err(|e| MemoryError::Serialization(e.to_string()))?;
136 strings
137 .iter()
138 .map(|s| Uuid::parse_str(s).map_err(|e| MemoryError::Serialization(e.to_string())))
139 .collect::<Result<Vec<_>, _>>()?
140 };
141
142 let metadata: serde_json::Map<String, serde_json::Value> = match &row.metadata {
143 serde_json::Value::Null => serde_json::Map::new(),
144 serde_json::Value::Object(map) => map.clone(),
145 other => serde_json::from_value(other.clone())
146 .map_err(|e| MemoryError::Serialization(e.to_string()))?,
147 };
148
149 Ok(Fact {
150 id: row.id,
151 text: row.text,
152 scope: Scope {
153 org_id: row.org_id,
154 agent_id: row.agent_id,
155 user_id: row.user_id,
156 session_id: row.session_id,
157 },
158 tier: tier_from_str(&row.tier),
159 category: row.category,
160 source: row.source,
161 confidence: row.confidence.map(|c| c as f32),
162 valid_from: row.valid_from,
163 invalid_at: row.invalid_at,
164 created_at: row.created_at,
165 embedding: Vec::new(),
166 entity_refs,
167 supersedes: row.supersedes,
168 superseded_by: row.superseded_by,
169 access_count: row.access_count as u64,
170 last_accessed: row.last_accessed,
171 metadata,
172 })
173}
174
175#[async_trait]
180impl FactStore for PostgresFactStore {
181 async fn insert_fact(&self, fact: Fact) -> Result<FactId, MemoryError> {
182 let entity_refs_json = {
183 let strs: Vec<String> = fact.entity_refs.iter().map(|u| u.to_string()).collect();
184 serde_json::to_value(&strs).map_err(|e| MemoryError::Serialization(e.to_string()))?
185 };
186
187 let metadata_json = if fact.metadata.is_empty() {
188 serde_json::Value::Null
189 } else {
190 serde_json::to_value(&fact.metadata)
191 .map_err(|e| MemoryError::Serialization(e.to_string()))?
192 };
193
194 sqlx::query(
195 r#"
196 INSERT INTO facts
197 (id, text, org_id, agent_id, user_id, session_id,
198 tier, category, source, confidence,
199 valid_from, invalid_at, created_at,
200 entity_refs, supersedes, superseded_by,
201 access_count, last_accessed, metadata)
202 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
203 ON CONFLICT (id) DO NOTHING
204 "#,
205 )
206 .bind(fact.id)
207 .bind(&fact.text)
208 .bind(&fact.scope.org_id)
209 .bind(fact.scope.agent_id.as_deref())
210 .bind(fact.scope.user_id.as_deref())
211 .bind(fact.scope.session_id.as_deref())
212 .bind(tier_to_str(&fact.tier))
213 .bind(fact.category.as_deref())
214 .bind(fact.source.as_deref())
215 .bind(fact.confidence.map(|c| c as f64))
216 .bind(fact.valid_from)
217 .bind(fact.invalid_at)
218 .bind(fact.created_at)
219 .bind(&entity_refs_json)
220 .bind(fact.supersedes)
221 .bind(fact.superseded_by)
222 .bind(fact.access_count as i64)
223 .bind(fact.last_accessed)
224 .bind(&metadata_json)
225 .execute(&self.pool)
226 .await
227 .map_err(|e| MemoryError::Database(e.to_string()))?;
228
229 Ok(fact.id)
230 }
231
232 async fn get_fact(&self, id: FactId) -> Result<Fact, MemoryError> {
233 let row = sqlx::query_as::<_, FactRow>(
234 "SELECT id, text, org_id, agent_id, user_id, session_id, tier, category, source, confidence, valid_from, invalid_at, created_at, entity_refs, supersedes, superseded_by, access_count, last_accessed, metadata FROM facts WHERE id = $1",
235 )
236 .bind(id)
237 .fetch_optional(&self.pool)
238 .await
239 .map_err(|e| MemoryError::Database(e.to_string()))?
240 .ok_or_else(|| MemoryError::NotFound(id.to_string()))?;
241
242 row_to_fact(row)
243 }
244
245 async fn update_fact(&self, id: FactId, patch: FactPatch) -> Result<Fact, MemoryError> {
246 let mut set_clauses: Vec<String> = Vec::new();
247 let mut vals: Vec<String> = Vec::new();
248 let mut param_idx: usize = 1;
249
250 if let Some(ref text) = patch.text {
251 set_clauses.push(format!("text = ${param_idx}"));
252 vals.push(text.clone());
253 param_idx += 1;
254 }
255 if let Some(ref tier) = patch.tier {
256 set_clauses.push(format!("tier = ${param_idx}"));
257 vals.push(tier_to_str(tier).to_string());
258 param_idx += 1;
259 }
260 if let Some(ref category) = patch.category {
261 set_clauses.push(format!("category = ${param_idx}"));
262 vals.push(category.clone());
263 param_idx += 1;
264 }
265 if let Some(ref source) = patch.source {
266 set_clauses.push(format!("source = ${param_idx}"));
267 vals.push(source.clone());
268 param_idx += 1;
269 }
270 if let Some(confidence) = patch.confidence {
271 set_clauses.push(format!("confidence = ${param_idx}"));
272 vals.push((confidence as f64).to_string());
273 param_idx += 1;
274 }
275 if let Some(invalid_at) = patch.invalid_at {
276 set_clauses.push(format!("invalid_at = ${param_idx}"));
277 vals.push(invalid_at.to_rfc3339());
278 param_idx += 1;
279 }
280 if let Some(superseded_by) = patch.superseded_by {
281 set_clauses.push(format!("superseded_by = ${param_idx}"));
282 vals.push(superseded_by.to_string());
283 param_idx += 1;
284 }
285 if !patch.metadata.is_empty() {
286 let json = serde_json::to_string(&patch.metadata)
287 .map_err(|e| MemoryError::Serialization(e.to_string()))?;
288 set_clauses.push(format!("metadata = ${param_idx}::jsonb"));
289 vals.push(json);
290 param_idx += 1;
291 }
292
293 if !set_clauses.is_empty() {
294 let sql = format!(
295 "UPDATE facts SET {} WHERE id = ${param_idx}",
296 set_clauses.join(", ")
297 );
298 let mut q = sqlx::query(&sql);
299 for v in &vals {
300 q = q.bind(v.as_str());
301 }
302 q = q.bind(id.to_string());
303 q.execute(&self.pool)
304 .await
305 .map_err(|e| MemoryError::Database(e.to_string()))?;
306 }
307
308 self.get_fact(id).await
309 }
310
311 async fn list_facts(&self, filter: &FactFilter) -> Result<Vec<Fact>, MemoryError> {
312 let mut wheres: Vec<String> = vec!["1=1".to_string()];
313
314 if let Some(ref scope) = filter.scope {
315 wheres.push(format!("org_id = '{}'", scope.org_id.replace('\'', "''")));
316 if let Some(ref user_id) = scope.user_id {
317 wheres.push(format!("user_id = '{}'", user_id.replace('\'', "''")));
318 }
319 if let Some(ref agent_id) = scope.agent_id {
320 wheres.push(format!("agent_id = '{}'", agent_id.replace('\'', "''")));
321 }
322 if let Some(ref session_id) = scope.session_id {
323 wheres.push(format!("session_id = '{}'", session_id.replace('\'', "''")));
324 }
325 }
326
327 if let Some(ref tier) = filter.tier {
328 wheres.push(format!("tier = '{}'", tier_to_str(tier)));
329 }
330
331 if let Some(ref category) = filter.category {
332 wheres.push(format!("category = '{}'", category.replace('\'', "''")));
333 }
334
335 if let Some(as_of) = filter.as_of {
336 let s = as_of.to_rfc3339();
337 wheres.push(format!("valid_from <= '{s}'"));
338 wheres.push(format!("(invalid_at IS NULL OR invalid_at > '{s}')"));
339 } else if filter.valid_only {
340 wheres.push("invalid_at IS NULL".to_string());
341 }
342
343 if let Some(ref text_contains) = filter.text_contains {
344 let escaped = text_contains.replace('\'', "''");
345 wheres.push(format!("text LIKE '%{escaped}%'"));
346 }
347
348 let where_clause = wheres.join(" AND ");
349 let sql = format!(
350 "SELECT id, text, org_id, agent_id, user_id, session_id, tier, category, source, confidence, valid_from, invalid_at, created_at, entity_refs, supersedes, superseded_by, access_count, last_accessed, metadata FROM facts WHERE {where_clause} ORDER BY created_at DESC LIMIT {} OFFSET {}",
351 filter.limit, filter.offset
352 );
353
354 let rows = sqlx::query_as::<_, FactRow>(&sql)
355 .fetch_all(&self.pool)
356 .await
357 .map_err(|e| MemoryError::Database(e.to_string()))?;
358
359 rows.into_iter().map(row_to_fact).collect()
360 }
361
362 async fn invalidate_fact(&self, id: FactId) -> Result<(), MemoryError> {
363 let now = Utc::now();
364 sqlx::query("UPDATE facts SET invalid_at = $1 WHERE id = $2")
365 .bind(now)
366 .bind(id)
367 .execute(&self.pool)
368 .await
369 .map_err(|e| MemoryError::Database(e.to_string()))?;
370 Ok(())
371 }
372
373 async fn delete_scope_data(&self, scope: &Scope) -> Result<u64, MemoryError> {
374 let mut wheres = vec![format!("org_id = '{}'", scope.org_id.replace('\'', "''"))];
375
376 if let Some(ref user_id) = scope.user_id {
377 wheres.push(format!("user_id = '{}'", user_id.replace('\'', "''")));
378 }
379 if let Some(ref agent_id) = scope.agent_id {
380 wheres.push(format!("agent_id = '{}'", agent_id.replace('\'', "''")));
381 }
382 if let Some(ref session_id) = scope.session_id {
383 wheres.push(format!("session_id = '{}'", session_id.replace('\'', "''")));
384 }
385
386 let where_clause = wheres.join(" AND ");
387 let sql = format!("DELETE FROM facts WHERE {where_clause}");
388
389 let result = sqlx::query(&sql)
390 .execute(&self.pool)
391 .await
392 .map_err(|e| MemoryError::Database(e.to_string()))?;
393
394 Ok(result.rows_affected())
395 }
396
397 async fn export(&self, filter: &FactFilter) -> Result<Vec<Fact>, MemoryError> {
398 self.list_facts(filter).await
399 }
400
401 async fn import(&self, facts: Vec<Fact>) -> Result<u64, MemoryError> {
402 let mut imported: u64 = 0;
403 for fact in facts {
404 let entity_refs_json = {
405 let strs: Vec<String> = fact.entity_refs.iter().map(|u| u.to_string()).collect();
406 serde_json::to_value(&strs)
407 .map_err(|e| MemoryError::Serialization(e.to_string()))?
408 };
409
410 let metadata_json = if fact.metadata.is_empty() {
411 serde_json::Value::Null
412 } else {
413 serde_json::to_value(&fact.metadata)
414 .map_err(|e| MemoryError::Serialization(e.to_string()))?
415 };
416
417 let result = sqlx::query(
418 r#"
419 INSERT INTO facts
420 (id, text, org_id, agent_id, user_id, session_id,
421 tier, category, source, confidence,
422 valid_from, invalid_at, created_at,
423 entity_refs, supersedes, superseded_by,
424 access_count, last_accessed, metadata)
425 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
426 ON CONFLICT (id) DO NOTHING
427 "#,
428 )
429 .bind(fact.id)
430 .bind(&fact.text)
431 .bind(&fact.scope.org_id)
432 .bind(fact.scope.agent_id.as_deref())
433 .bind(fact.scope.user_id.as_deref())
434 .bind(fact.scope.session_id.as_deref())
435 .bind(tier_to_str(&fact.tier))
436 .bind(fact.category.as_deref())
437 .bind(fact.source.as_deref())
438 .bind(fact.confidence.map(|c| c as f64))
439 .bind(fact.valid_from)
440 .bind(fact.invalid_at)
441 .bind(fact.created_at)
442 .bind(&entity_refs_json)
443 .bind(fact.supersedes)
444 .bind(fact.superseded_by)
445 .bind(fact.access_count as i64)
446 .bind(fact.last_accessed)
447 .bind(&metadata_json)
448 .execute(&self.pool)
449 .await
450 .map_err(|e| MemoryError::Database(e.to_string()))?;
451
452 imported += result.rows_affected();
453 }
454 Ok(imported)
455 }
456
457 async fn stats(&self) -> Result<StoreStats, MemoryError> {
458 let (total,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM facts")
459 .fetch_one(&self.pool)
460 .await
461 .map_err(|e| MemoryError::Database(e.to_string()))?;
462
463 let (valid,): (i64,) =
464 sqlx::query_as("SELECT COUNT(*) FROM facts WHERE invalid_at IS NULL")
465 .fetch_one(&self.pool)
466 .await
467 .map_err(|e| MemoryError::Database(e.to_string()))?;
468
469 let (invalidated,): (i64,) =
470 sqlx::query_as("SELECT COUNT(*) FROM facts WHERE invalid_at IS NOT NULL")
471 .fetch_one(&self.pool)
472 .await
473 .map_err(|e| MemoryError::Database(e.to_string()))?;
474
475 Ok(StoreStats {
476 total_facts: total as u64,
477 valid_facts: valid as u64,
478 invalidated_facts: invalidated as u64,
479 total_entities: 0,
480 total_relationships: 0,
481 })
482 }
483
484 async fn record_access(&self, id: FactId) -> Result<(), MemoryError> {
485 let now = Utc::now();
486 sqlx::query(
487 "UPDATE facts SET access_count = access_count + 1, last_accessed = $1 WHERE id = $2",
488 )
489 .bind(now)
490 .bind(id)
491 .execute(&self.pool)
492 .await
493 .map_err(|e| MemoryError::Database(e.to_string()))?;
494 Ok(())
495 }
496
497 async fn keyword_search(
498 &self,
499 query: &str,
500 scope: &Scope,
501 top_k: usize,
502 ) -> Result<Vec<Fact>, MemoryError> {
503 let trimmed = query.trim();
504 if trimmed.is_empty() {
505 return Ok(Vec::new());
506 }
507
508 let sql = r#"
509 SELECT id, text, org_id, agent_id, user_id, session_id, tier, category,
510 source, confidence, valid_from, invalid_at, created_at, entity_refs,
511 supersedes, superseded_by, access_count, last_accessed, metadata
512 FROM facts
513 WHERE search_vector @@ plainto_tsquery('english', $1)
514 AND org_id = $2
515 AND ($3::text IS NULL OR user_id = $3)
516 AND invalid_at IS NULL
517 ORDER BY ts_rank(search_vector, plainto_tsquery('english', $1)) DESC
518 LIMIT $4
519 "#;
520
521 let rows = sqlx::query_as::<_, FactRow>(sql)
522 .bind(trimmed)
523 .bind(&scope.org_id)
524 .bind(scope.user_id.as_deref())
525 .bind(top_k as i64)
526 .fetch_all(&self.pool)
527 .await
528 .map_err(|e| MemoryError::Database(e.to_string()))?;
529
530 rows.into_iter().map(row_to_fact).collect()
531 }
532}