1use mnemo_core::error::{Error, Result};
2use mnemo_core::model::acl::{Acl, Permission};
3use mnemo_core::model::agent_profile::AgentProfile;
4use mnemo_core::model::checkpoint::Checkpoint;
5use mnemo_core::model::delegation::{Delegation, DelegationScope};
6use mnemo_core::model::embedding_baseline::EmbeddingBaseline;
7use mnemo_core::model::event::AgentEvent;
8use mnemo_core::model::memory::MemoryRecord;
9use mnemo_core::model::relation::Relation;
10use mnemo_core::storage::{MemoryFilter, StorageBackend};
11use pgvector::Vector;
12use sqlx::Row;
13use uuid::Uuid;
14
/// Postgres implementation of the storage backend, built on a shared
/// `sqlx` connection pool.
pub struct PgStorage {
    /// Connection pool every query in this backend executes against.
    pool: sqlx::PgPool,
    /// Embedding dimensionality given to the constructors (which pass the same
    /// value to the migration runner). Presumably not read elsewhere, hence
    /// the `dead_code` allowance — NOTE(review): confirm before removing.
    #[allow(dead_code)]
    dimensions: usize,
}
26
27impl PgStorage {
28 pub async fn connect(url: &str, dimensions: usize) -> Result<Self> {
33 let pool = sqlx::PgPool::connect(url)
34 .await
35 .map_err(|e| Error::Storage(e.to_string()))?;
36 let storage = Self { pool, dimensions };
37 crate::migrations::run_migrations(&storage.pool, dimensions).await?;
38 Ok(storage)
39 }
40
41 pub async fn from_pool(pool: sqlx::PgPool, dimensions: usize) -> Result<Self> {
43 crate::migrations::run_migrations(&pool, dimensions).await?;
44 Ok(Self { pool, dimensions })
45 }
46}
47
48fn map_sqlx(e: sqlx::Error) -> Error {
53 Error::Storage(e.to_string())
54}
55
/// Packs an embedding into a byte blob: each `f32` as 4 little-endian bytes,
/// in order. `None` stays `None`; an empty vector yields an empty blob.
fn serialize_embedding(embedding: &Option<Vec<f32>>) -> Option<Vec<u8>> {
    let values = embedding.as_deref()?;
    let mut bytes = Vec::with_capacity(values.len() * std::mem::size_of::<f32>());
    for value in values {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    Some(bytes)
}
61
/// Inverse of `serialize_embedding`: reads consecutive 4-byte little-endian
/// `f32`s from the blob. Trailing bytes that do not fill a whole `f32` are
/// ignored. `None` stays `None`.
fn deserialize_embedding(blob: Option<Vec<u8>>) -> Option<Vec<f32>> {
    let bytes = blob?;
    let floats = bytes
        .chunks_exact(std::mem::size_of::<f32>())
        .map(|chunk| {
            let mut word = [0u8; 4];
            word.copy_from_slice(chunk);
            f32::from_le_bytes(word)
        })
        .collect();
    Some(floats)
}
70
71fn row_to_memory(row: &sqlx::postgres::PgRow) -> std::result::Result<MemoryRecord, sqlx::Error> {
72 let tags: Vec<String> = row.try_get::<Vec<String>, _>("tags").unwrap_or_default();
73 let metadata: serde_json::Value = row
74 .try_get("metadata")
75 .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
76
77 let embedding: Option<Vec<f32>> = {
80 let raw: Option<String> = row.try_get("embedding_text").ok().flatten();
81 raw.and_then(|s| {
82 let trimmed = s.trim_start_matches('[').trim_end_matches(']');
84 if trimmed.is_empty() {
85 None
86 } else {
87 Some(
88 trimmed
89 .split(',')
90 .filter_map(|v| v.trim().parse::<f32>().ok())
91 .collect(),
92 )
93 }
94 })
95 };
96
97 Ok(MemoryRecord {
98 id: row.get("id"),
99 agent_id: row.get("agent_id"),
100 content: row.get("content"),
101 memory_type: row
102 .get::<String, _>("memory_type")
103 .parse()
104 .unwrap_or(mnemo_core::model::memory::MemoryType::Semantic),
105 scope: row
106 .get::<String, _>("scope")
107 .parse()
108 .unwrap_or(mnemo_core::model::memory::Scope::Private),
109 importance: row.get("importance"),
110 tags,
111 metadata,
112 embedding,
113 content_hash: row.get("content_hash"),
114 prev_hash: row.get("prev_hash"),
115 source_type: row
116 .get::<String, _>("source_type")
117 .parse()
118 .unwrap_or(mnemo_core::model::memory::SourceType::Agent),
119 source_id: row.get("source_id"),
120 consolidation_state: row
121 .get::<String, _>("consolidation_state")
122 .parse()
123 .unwrap_or(mnemo_core::model::memory::ConsolidationState::Raw),
124 access_count: row.get::<i64, _>("access_count") as u64,
125 org_id: row.get("org_id"),
126 thread_id: row.get("thread_id"),
127 created_at: row.get("created_at"),
128 updated_at: row.get("updated_at"),
129 last_accessed_at: row.get("last_accessed_at"),
130 expires_at: row.get("expires_at"),
131 deleted_at: row.get("deleted_at"),
132 decay_rate: row.get("decay_rate"),
133 created_by: row.get("created_by"),
134 version: row.get::<i32, _>("version") as u32,
135 prev_version_id: row.get("prev_version_id"),
136 quarantined: row.get("quarantined"),
137 quarantine_reason: row.get("quarantine_reason"),
138 decay_function: row.get("decay_function"),
139 })
140}
141
/// Column list shared by every `SELECT` whose rows are decoded by
/// `row_to_memory`.
///
/// `embedding` is cast to text and aliased `embedding_text`; `row_to_memory`
/// reads that alias and parses the bracketed, comma-separated text back into
/// floats.
const MEMORY_COLUMNS: &str = r#"
    id, agent_id, content, memory_type, scope, importance,
    tags, metadata, embedding::text AS embedding_text,
    content_hash, prev_hash, source_type, source_id,
    consolidation_state, access_count, org_id, thread_id,
    created_at, updated_at, last_accessed_at, expires_at,
    deleted_at, decay_rate, created_by, version, prev_version_id,
    quarantined, quarantine_reason, decay_function
"#;
154
155fn row_to_event(row: &sqlx::postgres::PgRow) -> std::result::Result<AgentEvent, sqlx::Error> {
156 let payload: serde_json::Value = row.try_get("payload").unwrap_or(serde_json::Value::Null);
157 let embedding_blob: Option<Vec<u8>> = row.try_get("embedding").unwrap_or(None);
158
159 Ok(AgentEvent {
160 id: row.get("id"),
161 agent_id: row.get("agent_id"),
162 thread_id: row.get("thread_id"),
163 run_id: row.get("run_id"),
164 parent_event_id: row.get("parent_event_id"),
165 event_type: row
166 .get::<String, _>("event_type")
167 .parse()
168 .unwrap_or(mnemo_core::model::event::EventType::Error),
169 payload,
170 trace_id: row.get("trace_id"),
171 span_id: row.get("span_id"),
172 model: row.get("model"),
173 tokens_input: row.get("tokens_input"),
174 tokens_output: row.get("tokens_output"),
175 latency_ms: row.get("latency_ms"),
176 cost_usd: row.get("cost_usd"),
177 timestamp: row.get("timestamp"),
178 logical_clock: row.get("logical_clock"),
179 content_hash: row.get("content_hash"),
180 prev_hash: row.get("prev_hash"),
181 embedding: deserialize_embedding(embedding_blob),
182 })
183}
184
185fn row_to_relation(row: &sqlx::postgres::PgRow) -> std::result::Result<Relation, sqlx::Error> {
186 let metadata: serde_json::Value = row
187 .try_get("metadata")
188 .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
189
190 Ok(Relation {
191 id: row.get("id"),
192 source_id: row.get("source_id"),
193 target_id: row.get("target_id"),
194 relation_type: row.get("relation_type"),
195 weight: row.get("weight"),
196 metadata,
197 created_at: row.get("created_at"),
198 })
199}
200
201fn row_to_checkpoint(row: &sqlx::postgres::PgRow) -> std::result::Result<Checkpoint, sqlx::Error> {
202 let state_snapshot: serde_json::Value = row
203 .try_get("state_snapshot")
204 .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
205 let state_diff: Option<serde_json::Value> = row.try_get("state_diff").unwrap_or(None);
206
207 let memory_refs_raw: Vec<String> = row.try_get("memory_refs").unwrap_or_default();
209 let memory_refs: Vec<Uuid> = memory_refs_raw
210 .iter()
211 .filter_map(|s| Uuid::parse_str(s).ok())
212 .collect();
213
214 let metadata: serde_json::Value = row
215 .try_get("metadata")
216 .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
217
218 Ok(Checkpoint {
219 id: row.get("id"),
220 thread_id: row.get("thread_id"),
221 agent_id: row.get("agent_id"),
222 parent_id: row.get("parent_id"),
223 branch_name: row.get("branch_name"),
224 state_snapshot,
225 state_diff,
226 memory_refs,
227 event_cursor: row.get("event_cursor"),
228 label: row.get("label"),
229 created_at: row.get("created_at"),
230 metadata,
231 })
232}
233
234fn row_to_delegation(row: &sqlx::postgres::PgRow) -> std::result::Result<Delegation, sqlx::Error> {
235 let scope_type: String = row.get("scope_type");
236 let scope_value: Option<serde_json::Value> = row.try_get("scope_value").unwrap_or(None);
237
238 let scope = match scope_type.as_str() {
239 "by_tag" => {
240 let tags: Vec<String> = scope_value
241 .and_then(|v| serde_json::from_value(v).ok())
242 .unwrap_or_default();
243 DelegationScope::ByTag(tags)
244 }
245 "by_memory_id" => {
246 let id_strs: Vec<String> = scope_value
247 .and_then(|v| serde_json::from_value(v).ok())
248 .unwrap_or_default();
249 let uuids = id_strs
250 .into_iter()
251 .filter_map(|s| Uuid::parse_str(&s).ok())
252 .collect();
253 DelegationScope::ByMemoryId(uuids)
254 }
255 _ => DelegationScope::AllMemories,
256 };
257
258 Ok(Delegation {
259 id: row.get("id"),
260 delegator_id: row.get("delegator_id"),
261 delegate_id: row.get("delegate_id"),
262 permission: row
263 .get::<String, _>("permission")
264 .parse()
265 .unwrap_or(Permission::Read),
266 scope,
267 max_depth: row.get::<i32, _>("max_depth") as u32,
268 current_depth: row.get::<i32, _>("current_depth") as u32,
269 parent_delegation_id: row.get("parent_delegation_id"),
270 created_at: row.get("created_at"),
271 expires_at: row.get("expires_at"),
272 revoked_at: row.get("revoked_at"),
273 })
274}
275
276#[async_trait::async_trait]
281impl StorageBackend for PgStorage {
282 async fn insert_memory(&self, record: &MemoryRecord) -> Result<()> {
287 let embedding_param: Option<Vector> =
288 record.embedding.as_ref().map(|v| Vector::from(v.clone()));
289
290 let tags_slice: &[String] = &record.tags;
291
292 sqlx::query(
293 r#"
294INSERT INTO memories (
295 id, agent_id, content, memory_type, scope, importance,
296 tags, metadata, embedding,
297 content_hash, prev_hash, source_type, source_id,
298 consolidation_state, access_count, org_id, thread_id,
299 created_at, updated_at, last_accessed_at, expires_at,
300 deleted_at, decay_rate, created_by, version, prev_version_id,
301 quarantined, quarantine_reason, decay_function
302) VALUES (
303 $1, $2, $3, $4, $5, $6,
304 $7, $8, $9,
305 $10, $11, $12, $13,
306 $14, $15, $16, $17,
307 $18, $19, $20, $21,
308 $22, $23, $24, $25, $26,
309 $27, $28, $29
310)
311"#,
312 )
313 .bind(record.id)
314 .bind(&record.agent_id)
315 .bind(&record.content)
316 .bind(record.memory_type.to_string())
317 .bind(record.scope.to_string())
318 .bind(record.importance)
319 .bind(tags_slice)
320 .bind(&record.metadata)
321 .bind(&embedding_param)
322 .bind(&record.content_hash)
323 .bind(&record.prev_hash)
324 .bind(record.source_type.to_string())
325 .bind(&record.source_id)
326 .bind(record.consolidation_state.to_string())
327 .bind(record.access_count as i64)
328 .bind(&record.org_id)
329 .bind(&record.thread_id)
330 .bind(&record.created_at)
331 .bind(&record.updated_at)
332 .bind(&record.last_accessed_at)
333 .bind(&record.expires_at)
334 .bind(&record.deleted_at)
335 .bind(record.decay_rate)
336 .bind(&record.created_by)
337 .bind(record.version as i32)
338 .bind(record.prev_version_id)
339 .bind(record.quarantined)
340 .bind(&record.quarantine_reason)
341 .bind(&record.decay_function)
342 .execute(&self.pool)
343 .await
344 .map_err(map_sqlx)?;
345
346 Ok(())
347 }
348
349 async fn get_memory(&self, id: Uuid) -> Result<Option<MemoryRecord>> {
350 let sql = format!("SELECT {MEMORY_COLUMNS} FROM memories WHERE id = $1");
351 let row = sqlx::query(&sql)
352 .bind(id)
353 .fetch_optional(&self.pool)
354 .await
355 .map_err(map_sqlx)?;
356
357 match row {
358 Some(r) => Ok(Some(row_to_memory(&r).map_err(map_sqlx)?)),
359 None => Ok(None),
360 }
361 }
362
363 async fn update_memory(&self, record: &MemoryRecord) -> Result<()> {
364 let embedding_param: Option<Vector> =
365 record.embedding.as_ref().map(|v| Vector::from(v.clone()));
366
367 let tags_slice: &[String] = &record.tags;
368
369 let result = sqlx::query(
370 r#"
371UPDATE memories SET
372 agent_id = $1, content = $2, memory_type = $3, scope = $4,
373 importance = $5, tags = $6, metadata = $7,
374 embedding = $8,
375 content_hash = $9, prev_hash = $10, source_type = $11,
376 source_id = $12, consolidation_state = $13, access_count = $14,
377 org_id = $15, thread_id = $16, updated_at = $17,
378 last_accessed_at = $18, expires_at = $19, deleted_at = $20,
379 decay_rate = $21, created_by = $22, version = $23,
380 prev_version_id = $24, quarantined = $25, quarantine_reason = $26,
381 decay_function = $27
382WHERE id = $28
383"#,
384 )
385 .bind(&record.agent_id)
386 .bind(&record.content)
387 .bind(record.memory_type.to_string())
388 .bind(record.scope.to_string())
389 .bind(record.importance)
390 .bind(tags_slice)
391 .bind(&record.metadata)
392 .bind(&embedding_param)
393 .bind(&record.content_hash)
394 .bind(&record.prev_hash)
395 .bind(record.source_type.to_string())
396 .bind(&record.source_id)
397 .bind(record.consolidation_state.to_string())
398 .bind(record.access_count as i64)
399 .bind(&record.org_id)
400 .bind(&record.thread_id)
401 .bind(&record.updated_at)
402 .bind(&record.last_accessed_at)
403 .bind(&record.expires_at)
404 .bind(&record.deleted_at)
405 .bind(record.decay_rate)
406 .bind(&record.created_by)
407 .bind(record.version as i32)
408 .bind(record.prev_version_id)
409 .bind(record.quarantined)
410 .bind(&record.quarantine_reason)
411 .bind(&record.decay_function)
412 .bind(record.id)
413 .execute(&self.pool)
414 .await
415 .map_err(map_sqlx)?;
416
417 if result.rows_affected() == 0 {
418 return Err(Error::NotFound(format!("memory {} not found", record.id)));
419 }
420 Ok(())
421 }
422
423 async fn soft_delete_memory(&self, id: Uuid) -> Result<()> {
424 let now = chrono::Utc::now().to_rfc3339();
425 let result = sqlx::query(
426 "UPDATE memories SET deleted_at = $1, updated_at = $2 WHERE id = $3 AND deleted_at IS NULL",
427 )
428 .bind(&now)
429 .bind(&now)
430 .bind(id)
431 .execute(&self.pool)
432 .await
433 .map_err(map_sqlx)?;
434
435 if result.rows_affected() == 0 {
436 return Err(Error::NotFound(format!(
437 "memory {id} not found or already deleted"
438 )));
439 }
440 Ok(())
441 }
442
443 async fn hard_delete_memory(&self, id: Uuid) -> Result<()> {
444 let result = sqlx::query("DELETE FROM memories WHERE id = $1")
445 .bind(id)
446 .execute(&self.pool)
447 .await
448 .map_err(map_sqlx)?;
449
450 if result.rows_affected() == 0 {
451 return Err(Error::NotFound(format!("memory {id} not found")));
452 }
453
454 sqlx::query("DELETE FROM acls WHERE memory_id = $1")
456 .bind(id)
457 .execute(&self.pool)
458 .await
459 .map_err(map_sqlx)?;
460
461 Ok(())
462 }
463
464 async fn list_memories(
465 &self,
466 filter: &MemoryFilter,
467 limit: usize,
468 offset: usize,
469 ) -> Result<Vec<MemoryRecord>> {
470 let mut conditions: Vec<String> = Vec::new();
471 let mut param_idx: usize = 0;
474
475 if !filter.include_deleted {
485 conditions.push("deleted_at IS NULL".to_string());
486 }
487
488 #[derive(Debug)]
490 enum Param {
491 Str(String),
492 F32(f32),
493 }
494 let mut params: Vec<Param> = Vec::new();
495
496 if let Some(ref agent_id) = filter.agent_id {
497 param_idx += 1;
498 conditions.push(format!("agent_id = ${param_idx}"));
499 params.push(Param::Str(agent_id.clone()));
500 }
501 if let Some(memory_type) = filter.memory_type {
502 param_idx += 1;
503 conditions.push(format!("memory_type = ${param_idx}"));
504 params.push(Param::Str(memory_type.to_string()));
505 }
506 if let Some(scope) = filter.scope {
507 param_idx += 1;
508 conditions.push(format!("scope = ${param_idx}"));
509 params.push(Param::Str(scope.to_string()));
510 }
511 if let Some(min_importance) = filter.min_importance {
512 param_idx += 1;
513 conditions.push(format!("importance >= ${param_idx}"));
514 params.push(Param::F32(min_importance));
515 }
516 if let Some(ref org_id) = filter.org_id {
517 param_idx += 1;
518 conditions.push(format!("org_id = ${param_idx}"));
519 params.push(Param::Str(org_id.clone()));
520 }
521 if let Some(ref thread_id) = filter.thread_id {
522 param_idx += 1;
523 conditions.push(format!("thread_id = ${param_idx}"));
524 params.push(Param::Str(thread_id.clone()));
525 }
526
527 let where_clause = if conditions.is_empty() {
528 String::new()
529 } else {
530 format!("WHERE {}", conditions.join(" AND "))
531 };
532
533 let sql = format!(
534 "SELECT {MEMORY_COLUMNS} FROM memories {where_clause} ORDER BY created_at DESC LIMIT {limit} OFFSET {offset}"
535 );
536
537 let mut query = sqlx::query(&sql);
538 for p in ¶ms {
539 match p {
540 Param::Str(s) => query = query.bind(s),
541 Param::F32(f) => query = query.bind(*f),
542 }
543 }
544
545 let rows = query.fetch_all(&self.pool).await.map_err(map_sqlx)?;
546 let mut results = Vec::with_capacity(rows.len());
547 for r in &rows {
548 results.push(row_to_memory(r).map_err(map_sqlx)?);
549 }
550 Ok(results)
551 }
552
553 async fn touch_memory(&self, id: Uuid) -> Result<()> {
554 let now = chrono::Utc::now().to_rfc3339();
555 sqlx::query(
556 "UPDATE memories SET access_count = access_count + 1, last_accessed_at = $1 WHERE id = $2",
557 )
558 .bind(&now)
559 .bind(id)
560 .execute(&self.pool)
561 .await
562 .map_err(map_sqlx)?;
563 Ok(())
564 }
565
566 async fn insert_acl(&self, acl: &Acl) -> Result<()> {
571 sqlx::query(
572 r#"
573INSERT INTO acls (id, memory_id, principal_type, principal_id, permission, granted_by, created_at, expires_at)
574VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
575"#,
576 )
577 .bind(acl.id)
578 .bind(acl.memory_id)
579 .bind(acl.principal_type.to_string())
580 .bind(&acl.principal_id)
581 .bind(acl.permission.to_string())
582 .bind(&acl.granted_by)
583 .bind(&acl.created_at)
584 .bind(&acl.expires_at)
585 .execute(&self.pool)
586 .await
587 .map_err(map_sqlx)?;
588 Ok(())
589 }
590
591 async fn check_permission(
592 &self,
593 memory_id: Uuid,
594 principal_id: &str,
595 required: Permission,
596 ) -> Result<bool> {
597 let owner_row = sqlx::query("SELECT agent_id FROM memories WHERE id = $1")
599 .bind(memory_id)
600 .fetch_optional(&self.pool)
601 .await
602 .map_err(map_sqlx)?;
603
604 match owner_row {
605 None => return Err(Error::NotFound(format!("memory {memory_id} not found"))),
606 Some(row) => {
607 let owner: String = row.get("agent_id");
608 if owner == principal_id {
609 return Ok(true);
610 }
611 }
612 }
613
614 let now = chrono::Utc::now().to_rfc3339();
616 let acl_rows = sqlx::query(
617 "SELECT permission FROM acls WHERE memory_id = $1 AND principal_id = $2 AND (expires_at IS NULL OR expires_at > $3)",
618 )
619 .bind(memory_id)
620 .bind(principal_id)
621 .bind(&now)
622 .fetch_all(&self.pool)
623 .await
624 .map_err(map_sqlx)?;
625
626 for row in &acl_rows {
627 let perm_str: String = row.get("permission");
628 if let Ok(perm) = perm_str.parse::<Permission>()
629 && perm.satisfies(required)
630 {
631 return Ok(true);
632 }
633 }
634
635 let public_rows = sqlx::query(
637 "SELECT permission FROM acls WHERE memory_id = $1 AND principal_type = 'public' AND (expires_at IS NULL OR expires_at > $2)",
638 )
639 .bind(memory_id)
640 .bind(&now)
641 .fetch_all(&self.pool)
642 .await
643 .map_err(map_sqlx)?;
644
645 for row in &public_rows {
646 let perm_str: String = row.get("permission");
647 if let Ok(perm) = perm_str.parse::<Permission>()
648 && perm.satisfies(required)
649 {
650 return Ok(true);
651 }
652 }
653
654 if self
656 .check_delegation(principal_id, memory_id, required)
657 .await?
658 {
659 return Ok(true);
660 }
661
662 Ok(false)
663 }
664
665 async fn insert_relation(&self, relation: &Relation) -> Result<()> {
670 sqlx::query(
671 r#"
672INSERT INTO relations (id, source_id, target_id, relation_type, weight, metadata, created_at)
673VALUES ($1, $2, $3, $4, $5, $6, $7)
674"#,
675 )
676 .bind(relation.id)
677 .bind(relation.source_id)
678 .bind(relation.target_id)
679 .bind(&relation.relation_type)
680 .bind(relation.weight)
681 .bind(&relation.metadata)
682 .bind(&relation.created_at)
683 .execute(&self.pool)
684 .await
685 .map_err(map_sqlx)?;
686 Ok(())
687 }
688
689 async fn get_relations_from(&self, source_id: Uuid) -> Result<Vec<Relation>> {
690 let rows = sqlx::query(
691 "SELECT id, source_id, target_id, relation_type, weight, metadata, created_at FROM relations WHERE source_id = $1",
692 )
693 .bind(source_id)
694 .fetch_all(&self.pool)
695 .await
696 .map_err(map_sqlx)?;
697
698 let mut results = Vec::with_capacity(rows.len());
699 for r in &rows {
700 results.push(row_to_relation(r).map_err(map_sqlx)?);
701 }
702 Ok(results)
703 }
704
705 async fn get_relations_to(&self, target_id: Uuid) -> Result<Vec<Relation>> {
706 let rows = sqlx::query(
707 "SELECT id, source_id, target_id, relation_type, weight, metadata, created_at FROM relations WHERE target_id = $1",
708 )
709 .bind(target_id)
710 .fetch_all(&self.pool)
711 .await
712 .map_err(map_sqlx)?;
713
714 let mut results = Vec::with_capacity(rows.len());
715 for r in &rows {
716 results.push(row_to_relation(r).map_err(map_sqlx)?);
717 }
718 Ok(results)
719 }
720
721 async fn delete_relation(&self, id: Uuid) -> Result<()> {
722 let result = sqlx::query("DELETE FROM relations WHERE id = $1")
723 .bind(id)
724 .execute(&self.pool)
725 .await
726 .map_err(map_sqlx)?;
727
728 if result.rows_affected() == 0 {
729 return Err(Error::NotFound(format!("relation {id} not found")));
730 }
731 Ok(())
732 }
733
734 async fn get_latest_memory_hash(
739 &self,
740 agent_id: &str,
741 thread_id: Option<&str>,
742 ) -> Result<Option<Vec<u8>>> {
743 let row = if let Some(tid) = thread_id {
744 sqlx::query(
745 "SELECT content_hash FROM memories WHERE agent_id = $1 AND thread_id = $2 AND deleted_at IS NULL ORDER BY created_at DESC LIMIT 1",
746 )
747 .bind(agent_id)
748 .bind(tid)
749 .fetch_optional(&self.pool)
750 .await
751 .map_err(map_sqlx)?
752 } else {
753 sqlx::query(
754 "SELECT content_hash FROM memories WHERE agent_id = $1 AND thread_id IS NULL AND deleted_at IS NULL ORDER BY created_at DESC LIMIT 1",
755 )
756 .bind(agent_id)
757 .fetch_optional(&self.pool)
758 .await
759 .map_err(map_sqlx)?
760 };
761
762 Ok(row.map(|r| r.get::<Vec<u8>, _>("content_hash")))
763 }
764
765 async fn get_latest_event_hash(
766 &self,
767 agent_id: &str,
768 thread_id: Option<&str>,
769 ) -> Result<Option<Vec<u8>>> {
770 let row = if let Some(tid) = thread_id {
771 sqlx::query(
772 "SELECT content_hash FROM agent_events WHERE agent_id = $1 AND thread_id = $2 ORDER BY timestamp DESC LIMIT 1",
773 )
774 .bind(agent_id)
775 .bind(tid)
776 .fetch_optional(&self.pool)
777 .await
778 .map_err(map_sqlx)?
779 } else {
780 sqlx::query(
781 "SELECT content_hash FROM agent_events WHERE agent_id = $1 ORDER BY timestamp DESC LIMIT 1",
782 )
783 .bind(agent_id)
784 .fetch_optional(&self.pool)
785 .await
786 .map_err(map_sqlx)?
787 };
788 Ok(row.map(|r| r.get::<Vec<u8>, _>("content_hash")))
789 }
790
791 async fn get_sync_watermark(&self, key: &str) -> Result<Option<String>> {
792 let row = sqlx::query("SELECT value FROM sync_metadata WHERE key = $1")
793 .bind(key)
794 .fetch_optional(&self.pool)
795 .await
796 .map_err(map_sqlx)?;
797 Ok(row.map(|r| r.get::<String, _>("value")))
798 }
799
800 async fn set_sync_watermark(&self, key: &str, value: &str) -> Result<()> {
801 let now = chrono::Utc::now().to_rfc3339();
802 sqlx::query(
803 "INSERT INTO sync_metadata (key, value, updated_at) VALUES ($1, $2, $3) ON CONFLICT (key) DO UPDATE SET value = $2, updated_at = $3",
804 )
805 .bind(key)
806 .bind(value)
807 .bind(now)
808 .execute(&self.pool)
809 .await
810 .map_err(map_sqlx)?;
811 Ok(())
812 }
813
814 async fn list_accessible_memory_ids(&self, agent_id: &str, limit: usize) -> Result<Vec<Uuid>> {
819 let now = chrono::Utc::now().to_rfc3339();
820 let rows = sqlx::query(
821 r#"
822SELECT id FROM memories
823WHERE (
824 agent_id = $1
825 OR scope = 'public'
826 OR id IN (
827 SELECT memory_id FROM acls
828 WHERE principal_id = $2 AND (expires_at IS NULL OR expires_at > $3)
829 )
830)
831AND deleted_at IS NULL
832LIMIT $4
833"#,
834 )
835 .bind(agent_id)
836 .bind(agent_id)
837 .bind(&now)
838 .bind(limit as i64)
839 .fetch_all(&self.pool)
840 .await
841 .map_err(map_sqlx)?;
842
843 let ids: Vec<Uuid> = rows.iter().map(|r| r.get("id")).collect();
844 Ok(ids)
845 }
846
847 async fn insert_event(&self, event: &AgentEvent) -> Result<()> {
852 let payload_json = &event.payload;
853 let embedding_blob = serialize_embedding(&event.embedding);
854
855 sqlx::query(
856 r#"
857INSERT INTO agent_events (
858 id, agent_id, thread_id, run_id, parent_event_id, event_type,
859 payload, trace_id, span_id, model, tokens_input, tokens_output,
860 latency_ms, cost_usd, "timestamp", logical_clock, content_hash,
861 prev_hash, embedding
862) VALUES (
863 $1, $2, $3, $4, $5, $6,
864 $7, $8, $9, $10, $11, $12,
865 $13, $14, $15, $16, $17,
866 $18, $19
867)
868"#,
869 )
870 .bind(event.id)
871 .bind(&event.agent_id)
872 .bind(&event.thread_id)
873 .bind(&event.run_id)
874 .bind(event.parent_event_id)
875 .bind(event.event_type.to_string())
876 .bind(payload_json)
877 .bind(&event.trace_id)
878 .bind(&event.span_id)
879 .bind(&event.model)
880 .bind(event.tokens_input)
881 .bind(event.tokens_output)
882 .bind(event.latency_ms)
883 .bind(event.cost_usd)
884 .bind(&event.timestamp)
885 .bind(event.logical_clock)
886 .bind(&event.content_hash)
887 .bind(&event.prev_hash)
888 .bind(&embedding_blob)
889 .execute(&self.pool)
890 .await
891 .map_err(map_sqlx)?;
892 Ok(())
893 }
894
895 async fn list_events(
896 &self,
897 agent_id: &str,
898 limit: usize,
899 offset: usize,
900 ) -> Result<Vec<AgentEvent>> {
901 let rows = sqlx::query(
902 r#"
903SELECT id, agent_id, thread_id, run_id, parent_event_id, event_type,
904 payload, trace_id, span_id, model, tokens_input, tokens_output,
905 latency_ms, cost_usd, "timestamp", logical_clock, content_hash,
906 prev_hash, embedding
907FROM agent_events
908WHERE agent_id = $1
909ORDER BY "timestamp" DESC
910LIMIT $2 OFFSET $3
911"#,
912 )
913 .bind(agent_id)
914 .bind(limit as i64)
915 .bind(offset as i64)
916 .fetch_all(&self.pool)
917 .await
918 .map_err(map_sqlx)?;
919
920 let mut results = Vec::with_capacity(rows.len());
921 for r in &rows {
922 results.push(row_to_event(r).map_err(map_sqlx)?);
923 }
924 Ok(results)
925 }
926
927 async fn get_events_by_thread(&self, thread_id: &str, limit: usize) -> Result<Vec<AgentEvent>> {
928 let rows = sqlx::query(
929 r#"
930SELECT id, agent_id, thread_id, run_id, parent_event_id, event_type,
931 payload, trace_id, span_id, model, tokens_input, tokens_output,
932 latency_ms, cost_usd, "timestamp", logical_clock, content_hash,
933 prev_hash, embedding
934FROM agent_events
935WHERE thread_id = $1
936ORDER BY "timestamp" ASC
937LIMIT $2
938"#,
939 )
940 .bind(thread_id)
941 .bind(limit as i64)
942 .fetch_all(&self.pool)
943 .await
944 .map_err(map_sqlx)?;
945
946 let mut results = Vec::with_capacity(rows.len());
947 for r in &rows {
948 results.push(row_to_event(r).map_err(map_sqlx)?);
949 }
950 Ok(results)
951 }
952
953 async fn get_event(&self, id: Uuid) -> Result<Option<AgentEvent>> {
954 let row = sqlx::query(
955 r#"
956SELECT id, agent_id, thread_id, run_id, parent_event_id, event_type,
957 payload, trace_id, span_id, model, tokens_input, tokens_output,
958 latency_ms, cost_usd, "timestamp", logical_clock, content_hash,
959 prev_hash, embedding
960FROM agent_events
961WHERE id = $1
962"#,
963 )
964 .bind(id)
965 .fetch_optional(&self.pool)
966 .await
967 .map_err(map_sqlx)?;
968
969 match row {
970 Some(r) => Ok(Some(row_to_event(&r).map_err(map_sqlx)?)),
971 None => Ok(None),
972 }
973 }
974
975 async fn list_child_events(
976 &self,
977 parent_event_id: Uuid,
978 limit: usize,
979 ) -> Result<Vec<AgentEvent>> {
980 let rows = sqlx::query(
981 r#"
982SELECT id, agent_id, thread_id, run_id, parent_event_id, event_type,
983 payload, trace_id, span_id, model, tokens_input, tokens_output,
984 latency_ms, cost_usd, "timestamp", logical_clock, content_hash,
985 prev_hash, embedding
986FROM agent_events
987WHERE parent_event_id = $1
988ORDER BY "timestamp" ASC
989LIMIT $2
990"#,
991 )
992 .bind(parent_event_id)
993 .bind(limit as i64)
994 .fetch_all(&self.pool)
995 .await
996 .map_err(map_sqlx)?;
997
998 let mut results = Vec::with_capacity(rows.len());
999 for r in &rows {
1000 results.push(row_to_event(r).map_err(map_sqlx)?);
1001 }
1002 Ok(results)
1003 }
1004
1005 async fn list_memories_by_agent_ordered(
1010 &self,
1011 agent_id: &str,
1012 thread_id: Option<&str>,
1013 limit: usize,
1014 ) -> Result<Vec<MemoryRecord>> {
1015 let rows = if let Some(tid) = thread_id {
1016 let sql = format!(
1017 "SELECT {MEMORY_COLUMNS} FROM memories WHERE agent_id = $1 AND thread_id = $2 AND deleted_at IS NULL ORDER BY created_at ASC LIMIT $3"
1018 );
1019 sqlx::query(&sql)
1020 .bind(agent_id)
1021 .bind(tid)
1022 .bind(limit as i64)
1023 .fetch_all(&self.pool)
1024 .await
1025 .map_err(map_sqlx)?
1026 } else {
1027 let sql = format!(
1028 "SELECT {MEMORY_COLUMNS} FROM memories WHERE agent_id = $1 AND deleted_at IS NULL ORDER BY created_at ASC LIMIT $2"
1029 );
1030 sqlx::query(&sql)
1031 .bind(agent_id)
1032 .bind(limit as i64)
1033 .fetch_all(&self.pool)
1034 .await
1035 .map_err(map_sqlx)?
1036 };
1037
1038 let mut results = Vec::with_capacity(rows.len());
1039 for r in &rows {
1040 results.push(row_to_memory(r).map_err(map_sqlx)?);
1041 }
1042 Ok(results)
1043 }
1044
1045 async fn list_memories_since(
1050 &self,
1051 updated_after: &str,
1052 limit: usize,
1053 ) -> Result<Vec<MemoryRecord>> {
1054 let sql = format!(
1055 "SELECT {MEMORY_COLUMNS} FROM memories WHERE updated_at > $1 ORDER BY updated_at ASC LIMIT $2"
1056 );
1057 let rows = sqlx::query(&sql)
1058 .bind(updated_after)
1059 .bind(limit as i64)
1060 .fetch_all(&self.pool)
1061 .await
1062 .map_err(map_sqlx)?;
1063
1064 let mut results = Vec::with_capacity(rows.len());
1065 for r in &rows {
1066 results.push(row_to_memory(r).map_err(map_sqlx)?);
1067 }
1068 Ok(results)
1069 }
1070
1071 async fn upsert_memory(&self, record: &MemoryRecord) -> Result<()> {
1072 match self.update_memory(record).await {
1073 Ok(()) => Ok(()),
1074 Err(Error::NotFound(_)) => self.insert_memory(record).await,
1075 Err(e) => Err(e),
1076 }
1077 }
1078
1079 async fn cleanup_expired(&self) -> Result<usize> {
1084 let now = chrono::Utc::now().to_rfc3339();
1085 let result = sqlx::query(
1086 "UPDATE memories SET deleted_at = $1 WHERE expires_at IS NOT NULL AND expires_at < $2 AND deleted_at IS NULL",
1087 )
1088 .bind(&now)
1089 .bind(&now)
1090 .execute(&self.pool)
1091 .await
1092 .map_err(map_sqlx)?;
1093
1094 Ok(result.rows_affected() as usize)
1095 }
1096
1097 async fn insert_delegation(&self, d: &Delegation) -> Result<()> {
1102 let scope_type = d.scope.to_string();
1103 let scope_value: serde_json::Value = match &d.scope {
1104 DelegationScope::AllMemories => serde_json::Value::Null,
1105 DelegationScope::ByTag(tags) => serde_json::json!(tags),
1106 DelegationScope::ByMemoryId(ids) => {
1107 serde_json::json!(ids.iter().map(|id| id.to_string()).collect::<Vec<_>>())
1108 }
1109 };
1110
1111 sqlx::query(
1112 r#"
1113INSERT INTO delegations (
1114 id, delegator_id, delegate_id, permission, scope_type, scope_value,
1115 max_depth, current_depth, parent_delegation_id,
1116 created_at, expires_at, revoked_at
1117) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
1118"#,
1119 )
1120 .bind(d.id)
1121 .bind(&d.delegator_id)
1122 .bind(&d.delegate_id)
1123 .bind(d.permission.to_string())
1124 .bind(&scope_type)
1125 .bind(&scope_value)
1126 .bind(d.max_depth as i32)
1127 .bind(d.current_depth as i32)
1128 .bind(d.parent_delegation_id)
1129 .bind(&d.created_at)
1130 .bind(&d.expires_at)
1131 .bind(&d.revoked_at)
1132 .execute(&self.pool)
1133 .await
1134 .map_err(map_sqlx)?;
1135 Ok(())
1136 }
1137
1138 async fn list_delegations_for(&self, delegate_id: &str) -> Result<Vec<Delegation>> {
1139 let now = chrono::Utc::now().to_rfc3339();
1140 let rows = sqlx::query(
1141 r#"
1142SELECT id, delegator_id, delegate_id, permission, scope_type, scope_value,
1143 max_depth, current_depth, parent_delegation_id,
1144 created_at, expires_at, revoked_at
1145FROM delegations
1146WHERE delegate_id = $1 AND revoked_at IS NULL AND (expires_at IS NULL OR expires_at > $2)
1147"#,
1148 )
1149 .bind(delegate_id)
1150 .bind(&now)
1151 .fetch_all(&self.pool)
1152 .await
1153 .map_err(map_sqlx)?;
1154
1155 let mut results = Vec::with_capacity(rows.len());
1156 for r in &rows {
1157 results.push(row_to_delegation(r).map_err(map_sqlx)?);
1158 }
1159 Ok(results)
1160 }
1161
1162 async fn revoke_delegation(&self, id: Uuid) -> Result<()> {
1163 let now = chrono::Utc::now().to_rfc3339();
1164 let result = sqlx::query(
1165 "UPDATE delegations SET revoked_at = $1 WHERE id = $2 AND revoked_at IS NULL",
1166 )
1167 .bind(&now)
1168 .bind(id)
1169 .execute(&self.pool)
1170 .await
1171 .map_err(map_sqlx)?;
1172
1173 if result.rows_affected() == 0 {
1174 return Err(Error::NotFound(format!(
1175 "delegation {id} not found or already revoked"
1176 )));
1177 }
1178 Ok(())
1179 }
1180
1181 async fn check_delegation(
1182 &self,
1183 delegate_id: &str,
1184 memory_id: Uuid,
1185 required: Permission,
1186 ) -> Result<bool> {
1187 let delegations = self.list_delegations_for(delegate_id).await?;
1188
1189 let memory = match self.get_memory(memory_id).await? {
1191 Some(m) => m,
1192 None => return Ok(false),
1193 };
1194
1195 for d in &delegations {
1196 if !d.permission.satisfies(required) {
1197 continue;
1198 }
1199 match &d.scope {
1200 DelegationScope::AllMemories => return Ok(true),
1201 DelegationScope::ByMemoryId(ids) => {
1202 if ids.contains(&memory_id) {
1203 return Ok(true);
1204 }
1205 }
1206 DelegationScope::ByTag(tags) => {
1207 if tags.iter().any(|t| memory.tags.contains(t)) {
1208 return Ok(true);
1209 }
1210 }
1211 }
1212 }
1213 Ok(false)
1214 }
1215
1216 async fn insert_or_update_agent_profile(&self, profile: &AgentProfile) -> Result<()> {
1221 sqlx::query(
1222 r#"
1223INSERT INTO agent_profiles (agent_id, avg_importance, avg_content_length, total_memories, last_updated)
1224VALUES ($1, $2, $3, $4, $5)
1225ON CONFLICT (agent_id) DO UPDATE SET
1226 avg_importance = EXCLUDED.avg_importance,
1227 avg_content_length = EXCLUDED.avg_content_length,
1228 total_memories = EXCLUDED.total_memories,
1229 last_updated = EXCLUDED.last_updated
1230"#,
1231 )
1232 .bind(&profile.agent_id)
1233 .bind(profile.avg_importance)
1234 .bind(profile.avg_content_length)
1235 .bind(profile.total_memories as i64)
1236 .bind(&profile.last_updated)
1237 .execute(&self.pool)
1238 .await
1239 .map_err(map_sqlx)?;
1240 Ok(())
1241 }
1242
1243 async fn get_agent_profile(&self, agent_id: &str) -> Result<Option<AgentProfile>> {
1244 let row = sqlx::query(
1245 "SELECT agent_id, avg_importance, avg_content_length, total_memories, last_updated FROM agent_profiles WHERE agent_id = $1",
1246 )
1247 .bind(agent_id)
1248 .fetch_optional(&self.pool)
1249 .await
1250 .map_err(map_sqlx)?;
1251
1252 Ok(row.map(|r| AgentProfile {
1253 agent_id: r.get("agent_id"),
1254 avg_importance: r.get("avg_importance"),
1255 avg_content_length: r.get("avg_content_length"),
1256 total_memories: r.get::<i64, _>("total_memories") as u64,
1257 last_updated: r.get("last_updated"),
1258 }))
1259 }
1260
1261 async fn insert_or_update_embedding_baseline(
1266 &self,
1267 baseline: &EmbeddingBaseline,
1268 ) -> Result<()> {
1269 let mu_json =
1270 serde_json::to_value(&baseline.mu).map_err(|e| Error::Storage(e.to_string()))?;
1271 let cov_json =
1272 serde_json::to_value(&baseline.cov_diag).map_err(|e| Error::Storage(e.to_string()))?;
1273 sqlx::query(
1274 r#"
1275INSERT INTO embedding_baseline (agent_id, mu, cov_diag, n, updated_at)
1276VALUES ($1, $2, $3, $4, $5)
1277ON CONFLICT (agent_id) DO UPDATE SET
1278 mu = EXCLUDED.mu,
1279 cov_diag = EXCLUDED.cov_diag,
1280 n = EXCLUDED.n,
1281 updated_at = EXCLUDED.updated_at
1282"#,
1283 )
1284 .bind(&baseline.agent_id)
1285 .bind(&mu_json)
1286 .bind(&cov_json)
1287 .bind(baseline.n as i64)
1288 .bind(&baseline.updated_at)
1289 .execute(&self.pool)
1290 .await
1291 .map_err(map_sqlx)?;
1292 Ok(())
1293 }
1294
1295 async fn get_embedding_baseline(&self, agent_id: &str) -> Result<Option<EmbeddingBaseline>> {
1296 let row = sqlx::query(
1297 "SELECT agent_id, mu, cov_diag, n, updated_at FROM embedding_baseline WHERE agent_id = $1",
1298 )
1299 .bind(agent_id)
1300 .fetch_optional(&self.pool)
1301 .await
1302 .map_err(map_sqlx)?;
1303
1304 match row {
1305 None => Ok(None),
1306 Some(r) => {
1307 let mu_val: serde_json::Value = r.get("mu");
1308 let cov_val: serde_json::Value = r.get("cov_diag");
1309 let mu: Vec<f32> =
1310 serde_json::from_value(mu_val).map_err(|e| Error::Storage(e.to_string()))?;
1311 let cov_diag: Vec<f32> =
1312 serde_json::from_value(cov_val).map_err(|e| Error::Storage(e.to_string()))?;
1313 Ok(Some(EmbeddingBaseline {
1314 agent_id: r.get("agent_id"),
1315 mu,
1316 cov_diag,
1317 n: r.get::<i64, _>("n") as u64,
1318 updated_at: r.get("updated_at"),
1319 }))
1320 }
1321 }
1322 }
1323
1324 async fn insert_checkpoint(&self, cp: &Checkpoint) -> Result<()> {
1329 let memory_refs_strs: Vec<String> =
1330 cp.memory_refs.iter().map(|id| id.to_string()).collect();
1331 let refs_slice: &[String] = &memory_refs_strs;
1332
1333 sqlx::query(
1334 r#"
1335INSERT INTO checkpoints (
1336 id, thread_id, agent_id, parent_id, branch_name,
1337 state_snapshot, state_diff, memory_refs, event_cursor,
1338 label, created_at, metadata
1339) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
1340"#,
1341 )
1342 .bind(cp.id)
1343 .bind(&cp.thread_id)
1344 .bind(&cp.agent_id)
1345 .bind(cp.parent_id)
1346 .bind(&cp.branch_name)
1347 .bind(&cp.state_snapshot)
1348 .bind(&cp.state_diff)
1349 .bind(refs_slice)
1350 .bind(cp.event_cursor)
1351 .bind(&cp.label)
1352 .bind(&cp.created_at)
1353 .bind(&cp.metadata)
1354 .execute(&self.pool)
1355 .await
1356 .map_err(map_sqlx)?;
1357 Ok(())
1358 }
1359
1360 async fn get_checkpoint(&self, id: Uuid) -> Result<Option<Checkpoint>> {
1361 let row = sqlx::query(
1362 r#"
1363SELECT id, thread_id, agent_id, parent_id, branch_name,
1364 state_snapshot, state_diff, memory_refs, event_cursor,
1365 label, created_at, metadata
1366FROM checkpoints WHERE id = $1
1367"#,
1368 )
1369 .bind(id)
1370 .fetch_optional(&self.pool)
1371 .await
1372 .map_err(map_sqlx)?;
1373
1374 match row {
1375 Some(r) => Ok(Some(row_to_checkpoint(&r).map_err(map_sqlx)?)),
1376 None => Ok(None),
1377 }
1378 }
1379
1380 async fn list_checkpoints(
1381 &self,
1382 thread_id: &str,
1383 branch: Option<&str>,
1384 limit: usize,
1385 ) -> Result<Vec<Checkpoint>> {
1386 let rows = if let Some(branch_name) = branch {
1387 sqlx::query(
1388 r#"
1389SELECT id, thread_id, agent_id, parent_id, branch_name,
1390 state_snapshot, state_diff, memory_refs, event_cursor,
1391 label, created_at, metadata
1392FROM checkpoints
1393WHERE thread_id = $1 AND branch_name = $2
1394ORDER BY created_at DESC
1395LIMIT $3
1396"#,
1397 )
1398 .bind(thread_id)
1399 .bind(branch_name)
1400 .bind(limit as i64)
1401 .fetch_all(&self.pool)
1402 .await
1403 .map_err(map_sqlx)?
1404 } else {
1405 sqlx::query(
1406 r#"
1407SELECT id, thread_id, agent_id, parent_id, branch_name,
1408 state_snapshot, state_diff, memory_refs, event_cursor,
1409 label, created_at, metadata
1410FROM checkpoints
1411WHERE thread_id = $1
1412ORDER BY created_at DESC
1413LIMIT $2
1414"#,
1415 )
1416 .bind(thread_id)
1417 .bind(limit as i64)
1418 .fetch_all(&self.pool)
1419 .await
1420 .map_err(map_sqlx)?
1421 };
1422
1423 let mut results = Vec::with_capacity(rows.len());
1424 for r in &rows {
1425 results.push(row_to_checkpoint(r).map_err(map_sqlx)?);
1426 }
1427 Ok(results)
1428 }
1429
1430 async fn get_latest_checkpoint(
1431 &self,
1432 thread_id: &str,
1433 branch: &str,
1434 ) -> Result<Option<Checkpoint>> {
1435 let row = sqlx::query(
1436 r#"
1437SELECT id, thread_id, agent_id, parent_id, branch_name,
1438 state_snapshot, state_diff, memory_refs, event_cursor,
1439 label, created_at, metadata
1440FROM checkpoints
1441WHERE thread_id = $1 AND branch_name = $2
1442ORDER BY created_at DESC
1443LIMIT 1
1444"#,
1445 )
1446 .bind(thread_id)
1447 .bind(branch)
1448 .fetch_optional(&self.pool)
1449 .await
1450 .map_err(map_sqlx)?;
1451
1452 match row {
1453 Some(r) => Ok(Some(row_to_checkpoint(&r).map_err(map_sqlx)?)),
1454 None => Ok(None),
1455 }
1456 }
1457}