// mnemo_postgres/storage.rs

1use mnemo_core::error::{Error, Result};
2use mnemo_core::model::acl::{Acl, Permission};
3use mnemo_core::model::agent_profile::AgentProfile;
4use mnemo_core::model::checkpoint::Checkpoint;
5use mnemo_core::model::delegation::{Delegation, DelegationScope};
6use mnemo_core::model::embedding_baseline::EmbeddingBaseline;
7use mnemo_core::model::event::AgentEvent;
8use mnemo_core::model::memory::MemoryRecord;
9use mnemo_core::model::relation::Relation;
10use mnemo_core::storage::{MemoryFilter, StorageBackend};
11use pgvector::Vector;
12use sqlx::Row;
13use uuid::Uuid;
14
15/// PostgreSQL-backed storage for Mnemo.
16///
17/// Wraps a `sqlx::PgPool` and runs schema migrations on construction.
18/// Embeddings are stored using the pgvector `vector` column type, while
19/// event embeddings are stored as `BYTEA` (serialised `Vec<f32>` in
20/// little-endian byte order), matching the DuckDB backend convention.
21pub struct PgStorage {
22    pool: sqlx::PgPool,
23    #[allow(dead_code)]
24    dimensions: usize,
25}
26
27impl PgStorage {
28    /// Connect to a PostgreSQL database and run migrations.
29    ///
30    /// `url` is a standard `postgres://` connection string.
31    /// `dimensions` controls the width of the pgvector `vector` column.
32    pub async fn connect(url: &str, dimensions: usize) -> Result<Self> {
33        let pool = sqlx::PgPool::connect(url)
34            .await
35            .map_err(|e| Error::Storage(e.to_string()))?;
36        let storage = Self { pool, dimensions };
37        crate::migrations::run_migrations(&storage.pool, dimensions).await?;
38        Ok(storage)
39    }
40
41    /// Build a `PgStorage` from an existing pool (useful for tests).
42    pub async fn from_pool(pool: sqlx::PgPool, dimensions: usize) -> Result<Self> {
43        crate::migrations::run_migrations(&pool, dimensions).await?;
44        Ok(Self { pool, dimensions })
45    }
46}
47
48// ---------------------------------------------------------------------------
49// Helpers
50// ---------------------------------------------------------------------------
51
52fn map_sqlx(e: sqlx::Error) -> Error {
53    Error::Storage(e.to_string())
54}
55
/// Encodes an optional `f32` vector as little-endian bytes, four per element.
///
/// Returns `None` when there is no embedding.
fn serialize_embedding(embedding: &Option<Vec<f32>>) -> Option<Vec<u8>> {
    let values = embedding.as_ref()?;
    let mut bytes = Vec::with_capacity(values.len() * std::mem::size_of::<f32>());
    for value in values {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    Some(bytes)
}
61
/// Decodes little-endian `f32` bytes (the layout `serialize_embedding` writes).
///
/// `None` stays `None`. Trailing bytes that do not form a complete 4-byte
/// element are ignored. NOTE(review): a truncated blob therefore decodes to a
/// shorter vector rather than being reported as corrupt.
fn deserialize_embedding(blob: Option<Vec<u8>>) -> Option<Vec<f32>> {
    let bytes = blob?;
    let floats = bytes
        .chunks_exact(4)
        .map(|chunk| {
            // `chunks_exact(4)` guarantees every chunk is exactly 4 bytes long.
            let mut word = [0u8; 4];
            word.copy_from_slice(chunk);
            f32::from_le_bytes(word)
        })
        .collect();
    Some(floats)
}
70
71fn row_to_memory(row: &sqlx::postgres::PgRow) -> std::result::Result<MemoryRecord, sqlx::Error> {
72    let tags: Vec<String> = row.try_get::<Vec<String>, _>("tags").unwrap_or_default();
73    let metadata: serde_json::Value = row
74        .try_get("metadata")
75        .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
76
77    // pgvector stores the embedding as its own type; we retrieve the raw text
78    // representation and parse back to Vec<f32>. If the column is NULL we get None.
79    let embedding: Option<Vec<f32>> = {
80        let raw: Option<String> = row.try_get("embedding_text").ok().flatten();
81        raw.and_then(|s| {
82            // pgvector text output looks like "[0.1,0.2,0.3]"
83            let trimmed = s.trim_start_matches('[').trim_end_matches(']');
84            if trimmed.is_empty() {
85                None
86            } else {
87                Some(
88                    trimmed
89                        .split(',')
90                        .filter_map(|v| v.trim().parse::<f32>().ok())
91                        .collect(),
92                )
93            }
94        })
95    };
96
97    Ok(MemoryRecord {
98        id: row.get("id"),
99        agent_id: row.get("agent_id"),
100        content: row.get("content"),
101        memory_type: row
102            .get::<String, _>("memory_type")
103            .parse()
104            .unwrap_or(mnemo_core::model::memory::MemoryType::Semantic),
105        scope: row
106            .get::<String, _>("scope")
107            .parse()
108            .unwrap_or(mnemo_core::model::memory::Scope::Private),
109        importance: row.get("importance"),
110        tags,
111        metadata,
112        embedding,
113        content_hash: row.get("content_hash"),
114        prev_hash: row.get("prev_hash"),
115        source_type: row
116            .get::<String, _>("source_type")
117            .parse()
118            .unwrap_or(mnemo_core::model::memory::SourceType::Agent),
119        source_id: row.get("source_id"),
120        consolidation_state: row
121            .get::<String, _>("consolidation_state")
122            .parse()
123            .unwrap_or(mnemo_core::model::memory::ConsolidationState::Raw),
124        access_count: row.get::<i64, _>("access_count") as u64,
125        org_id: row.get("org_id"),
126        thread_id: row.get("thread_id"),
127        created_at: row.get("created_at"),
128        updated_at: row.get("updated_at"),
129        last_accessed_at: row.get("last_accessed_at"),
130        expires_at: row.get("expires_at"),
131        deleted_at: row.get("deleted_at"),
132        decay_rate: row.get("decay_rate"),
133        created_by: row.get("created_by"),
134        version: row.get::<i32, _>("version") as u32,
135        prev_version_id: row.get("prev_version_id"),
136        quarantined: row.get("quarantined"),
137        quarantine_reason: row.get("quarantine_reason"),
138        decay_function: row.get("decay_function"),
139    })
140}
141
/// Column list used by every `SELECT` whose rows are decoded by `row_to_memory`.
///
/// The pgvector `embedding` column is cast to text and aliased as
/// `embedding_text`; `row_to_memory` parses that text back into `Vec<f32>`,
/// so no pgvector Rust decode path is needed. Keep this list in sync with the
/// columns `row_to_memory` reads.
const MEMORY_COLUMNS: &str = r#"
    id, agent_id, content, memory_type, scope, importance,
    tags, metadata, embedding::text AS embedding_text,
    content_hash, prev_hash, source_type, source_id,
    consolidation_state, access_count, org_id, thread_id,
    created_at, updated_at, last_accessed_at, expires_at,
    deleted_at, decay_rate, created_by, version, prev_version_id,
    quarantined, quarantine_reason, decay_function
"#;
154
155fn row_to_event(row: &sqlx::postgres::PgRow) -> std::result::Result<AgentEvent, sqlx::Error> {
156    let payload: serde_json::Value = row.try_get("payload").unwrap_or(serde_json::Value::Null);
157    let embedding_blob: Option<Vec<u8>> = row.try_get("embedding").unwrap_or(None);
158
159    Ok(AgentEvent {
160        id: row.get("id"),
161        agent_id: row.get("agent_id"),
162        thread_id: row.get("thread_id"),
163        run_id: row.get("run_id"),
164        parent_event_id: row.get("parent_event_id"),
165        event_type: row
166            .get::<String, _>("event_type")
167            .parse()
168            .unwrap_or(mnemo_core::model::event::EventType::Error),
169        payload,
170        trace_id: row.get("trace_id"),
171        span_id: row.get("span_id"),
172        model: row.get("model"),
173        tokens_input: row.get("tokens_input"),
174        tokens_output: row.get("tokens_output"),
175        latency_ms: row.get("latency_ms"),
176        cost_usd: row.get("cost_usd"),
177        timestamp: row.get("timestamp"),
178        logical_clock: row.get("logical_clock"),
179        content_hash: row.get("content_hash"),
180        prev_hash: row.get("prev_hash"),
181        embedding: deserialize_embedding(embedding_blob),
182    })
183}
184
185fn row_to_relation(row: &sqlx::postgres::PgRow) -> std::result::Result<Relation, sqlx::Error> {
186    let metadata: serde_json::Value = row
187        .try_get("metadata")
188        .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
189
190    Ok(Relation {
191        id: row.get("id"),
192        source_id: row.get("source_id"),
193        target_id: row.get("target_id"),
194        relation_type: row.get("relation_type"),
195        weight: row.get("weight"),
196        metadata,
197        created_at: row.get("created_at"),
198    })
199}
200
201fn row_to_checkpoint(row: &sqlx::postgres::PgRow) -> std::result::Result<Checkpoint, sqlx::Error> {
202    let state_snapshot: serde_json::Value = row
203        .try_get("state_snapshot")
204        .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
205    let state_diff: Option<serde_json::Value> = row.try_get("state_diff").unwrap_or(None);
206
207    // memory_refs is stored as TEXT[] of UUID strings
208    let memory_refs_raw: Vec<String> = row.try_get("memory_refs").unwrap_or_default();
209    let memory_refs: Vec<Uuid> = memory_refs_raw
210        .iter()
211        .filter_map(|s| Uuid::parse_str(s).ok())
212        .collect();
213
214    let metadata: serde_json::Value = row
215        .try_get("metadata")
216        .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
217
218    Ok(Checkpoint {
219        id: row.get("id"),
220        thread_id: row.get("thread_id"),
221        agent_id: row.get("agent_id"),
222        parent_id: row.get("parent_id"),
223        branch_name: row.get("branch_name"),
224        state_snapshot,
225        state_diff,
226        memory_refs,
227        event_cursor: row.get("event_cursor"),
228        label: row.get("label"),
229        created_at: row.get("created_at"),
230        metadata,
231    })
232}
233
234fn row_to_delegation(row: &sqlx::postgres::PgRow) -> std::result::Result<Delegation, sqlx::Error> {
235    let scope_type: String = row.get("scope_type");
236    let scope_value: Option<serde_json::Value> = row.try_get("scope_value").unwrap_or(None);
237
238    let scope = match scope_type.as_str() {
239        "by_tag" => {
240            let tags: Vec<String> = scope_value
241                .and_then(|v| serde_json::from_value(v).ok())
242                .unwrap_or_default();
243            DelegationScope::ByTag(tags)
244        }
245        "by_memory_id" => {
246            let id_strs: Vec<String> = scope_value
247                .and_then(|v| serde_json::from_value(v).ok())
248                .unwrap_or_default();
249            let uuids = id_strs
250                .into_iter()
251                .filter_map(|s| Uuid::parse_str(&s).ok())
252                .collect();
253            DelegationScope::ByMemoryId(uuids)
254        }
255        _ => DelegationScope::AllMemories,
256    };
257
258    Ok(Delegation {
259        id: row.get("id"),
260        delegator_id: row.get("delegator_id"),
261        delegate_id: row.get("delegate_id"),
262        permission: row
263            .get::<String, _>("permission")
264            .parse()
265            .unwrap_or(Permission::Read),
266        scope,
267        max_depth: row.get::<i32, _>("max_depth") as u32,
268        current_depth: row.get::<i32, _>("current_depth") as u32,
269        parent_delegation_id: row.get("parent_delegation_id"),
270        created_at: row.get("created_at"),
271        expires_at: row.get("expires_at"),
272        revoked_at: row.get("revoked_at"),
273    })
274}
275
276// ---------------------------------------------------------------------------
277// StorageBackend implementation
278// ---------------------------------------------------------------------------
279
280#[async_trait::async_trait]
281impl StorageBackend for PgStorage {
282    // -----------------------------------------------------------------------
283    // Memory CRUD
284    // -----------------------------------------------------------------------
285
286    async fn insert_memory(&self, record: &MemoryRecord) -> Result<()> {
287        let embedding_param: Option<Vector> =
288            record.embedding.as_ref().map(|v| Vector::from(v.clone()));
289
290        let tags_slice: &[String] = &record.tags;
291
292        sqlx::query(
293            r#"
294INSERT INTO memories (
295    id, agent_id, content, memory_type, scope, importance,
296    tags, metadata, embedding,
297    content_hash, prev_hash, source_type, source_id,
298    consolidation_state, access_count, org_id, thread_id,
299    created_at, updated_at, last_accessed_at, expires_at,
300    deleted_at, decay_rate, created_by, version, prev_version_id,
301    quarantined, quarantine_reason, decay_function
302) VALUES (
303    $1, $2, $3, $4, $5, $6,
304    $7, $8, $9,
305    $10, $11, $12, $13,
306    $14, $15, $16, $17,
307    $18, $19, $20, $21,
308    $22, $23, $24, $25, $26,
309    $27, $28, $29
310)
311"#,
312        )
313        .bind(record.id)
314        .bind(&record.agent_id)
315        .bind(&record.content)
316        .bind(record.memory_type.to_string())
317        .bind(record.scope.to_string())
318        .bind(record.importance)
319        .bind(tags_slice)
320        .bind(&record.metadata)
321        .bind(&embedding_param)
322        .bind(&record.content_hash)
323        .bind(&record.prev_hash)
324        .bind(record.source_type.to_string())
325        .bind(&record.source_id)
326        .bind(record.consolidation_state.to_string())
327        .bind(record.access_count as i64)
328        .bind(&record.org_id)
329        .bind(&record.thread_id)
330        .bind(&record.created_at)
331        .bind(&record.updated_at)
332        .bind(&record.last_accessed_at)
333        .bind(&record.expires_at)
334        .bind(&record.deleted_at)
335        .bind(record.decay_rate)
336        .bind(&record.created_by)
337        .bind(record.version as i32)
338        .bind(record.prev_version_id)
339        .bind(record.quarantined)
340        .bind(&record.quarantine_reason)
341        .bind(&record.decay_function)
342        .execute(&self.pool)
343        .await
344        .map_err(map_sqlx)?;
345
346        Ok(())
347    }
348
349    async fn get_memory(&self, id: Uuid) -> Result<Option<MemoryRecord>> {
350        let sql = format!("SELECT {MEMORY_COLUMNS} FROM memories WHERE id = $1");
351        let row = sqlx::query(&sql)
352            .bind(id)
353            .fetch_optional(&self.pool)
354            .await
355            .map_err(map_sqlx)?;
356
357        match row {
358            Some(r) => Ok(Some(row_to_memory(&r).map_err(map_sqlx)?)),
359            None => Ok(None),
360        }
361    }
362
363    async fn update_memory(&self, record: &MemoryRecord) -> Result<()> {
364        let embedding_param: Option<Vector> =
365            record.embedding.as_ref().map(|v| Vector::from(v.clone()));
366
367        let tags_slice: &[String] = &record.tags;
368
369        let result = sqlx::query(
370            r#"
371UPDATE memories SET
372    agent_id = $1, content = $2, memory_type = $3, scope = $4,
373    importance = $5, tags = $6, metadata = $7,
374    embedding = $8,
375    content_hash = $9, prev_hash = $10, source_type = $11,
376    source_id = $12, consolidation_state = $13, access_count = $14,
377    org_id = $15, thread_id = $16, updated_at = $17,
378    last_accessed_at = $18, expires_at = $19, deleted_at = $20,
379    decay_rate = $21, created_by = $22, version = $23,
380    prev_version_id = $24, quarantined = $25, quarantine_reason = $26,
381    decay_function = $27
382WHERE id = $28
383"#,
384        )
385        .bind(&record.agent_id)
386        .bind(&record.content)
387        .bind(record.memory_type.to_string())
388        .bind(record.scope.to_string())
389        .bind(record.importance)
390        .bind(tags_slice)
391        .bind(&record.metadata)
392        .bind(&embedding_param)
393        .bind(&record.content_hash)
394        .bind(&record.prev_hash)
395        .bind(record.source_type.to_string())
396        .bind(&record.source_id)
397        .bind(record.consolidation_state.to_string())
398        .bind(record.access_count as i64)
399        .bind(&record.org_id)
400        .bind(&record.thread_id)
401        .bind(&record.updated_at)
402        .bind(&record.last_accessed_at)
403        .bind(&record.expires_at)
404        .bind(&record.deleted_at)
405        .bind(record.decay_rate)
406        .bind(&record.created_by)
407        .bind(record.version as i32)
408        .bind(record.prev_version_id)
409        .bind(record.quarantined)
410        .bind(&record.quarantine_reason)
411        .bind(&record.decay_function)
412        .bind(record.id)
413        .execute(&self.pool)
414        .await
415        .map_err(map_sqlx)?;
416
417        if result.rows_affected() == 0 {
418            return Err(Error::NotFound(format!("memory {} not found", record.id)));
419        }
420        Ok(())
421    }
422
423    async fn soft_delete_memory(&self, id: Uuid) -> Result<()> {
424        let now = chrono::Utc::now().to_rfc3339();
425        let result = sqlx::query(
426            "UPDATE memories SET deleted_at = $1, updated_at = $2 WHERE id = $3 AND deleted_at IS NULL",
427        )
428        .bind(&now)
429        .bind(&now)
430        .bind(id)
431        .execute(&self.pool)
432        .await
433        .map_err(map_sqlx)?;
434
435        if result.rows_affected() == 0 {
436            return Err(Error::NotFound(format!(
437                "memory {id} not found or already deleted"
438            )));
439        }
440        Ok(())
441    }
442
443    async fn hard_delete_memory(&self, id: Uuid) -> Result<()> {
444        let result = sqlx::query("DELETE FROM memories WHERE id = $1")
445            .bind(id)
446            .execute(&self.pool)
447            .await
448            .map_err(map_sqlx)?;
449
450        if result.rows_affected() == 0 {
451            return Err(Error::NotFound(format!("memory {id} not found")));
452        }
453
454        // Clean up ACLs for this memory
455        sqlx::query("DELETE FROM acls WHERE memory_id = $1")
456            .bind(id)
457            .execute(&self.pool)
458            .await
459            .map_err(map_sqlx)?;
460
461        Ok(())
462    }
463
464    async fn list_memories(
465        &self,
466        filter: &MemoryFilter,
467        limit: usize,
468        offset: usize,
469    ) -> Result<Vec<MemoryRecord>> {
470        let mut conditions: Vec<String> = Vec::new();
471        // We'll track bind-parameter index manually.
472        // The MEMORY_COLUMNS select doesn't use numbered params.
473        let mut param_idx: usize = 0;
474
475        // We accumulate bind values in a specific order and push them later
476        // via a dynamic query builder. Unfortunately sqlx's dynamic queries
477        // require us to build the SQL string with numbered placeholders and
478        // bind all values in order.
479
480        // We'll collect (sql_fragment, value_type) tuples, then bind them.
481        // Use a simpler approach: build the query string, then bind
482        // parameters positionally.
483
484        if !filter.include_deleted {
485            conditions.push("deleted_at IS NULL".to_string());
486        }
487
488        // We'll use an enum-based approach below to track what to bind.
489        #[derive(Debug)]
490        enum Param {
491            Str(String),
492            F32(f32),
493        }
494        let mut params: Vec<Param> = Vec::new();
495
496        if let Some(ref agent_id) = filter.agent_id {
497            param_idx += 1;
498            conditions.push(format!("agent_id = ${param_idx}"));
499            params.push(Param::Str(agent_id.clone()));
500        }
501        if let Some(memory_type) = filter.memory_type {
502            param_idx += 1;
503            conditions.push(format!("memory_type = ${param_idx}"));
504            params.push(Param::Str(memory_type.to_string()));
505        }
506        if let Some(scope) = filter.scope {
507            param_idx += 1;
508            conditions.push(format!("scope = ${param_idx}"));
509            params.push(Param::Str(scope.to_string()));
510        }
511        if let Some(min_importance) = filter.min_importance {
512            param_idx += 1;
513            conditions.push(format!("importance >= ${param_idx}"));
514            params.push(Param::F32(min_importance));
515        }
516        if let Some(ref org_id) = filter.org_id {
517            param_idx += 1;
518            conditions.push(format!("org_id = ${param_idx}"));
519            params.push(Param::Str(org_id.clone()));
520        }
521        if let Some(ref thread_id) = filter.thread_id {
522            param_idx += 1;
523            conditions.push(format!("thread_id = ${param_idx}"));
524            params.push(Param::Str(thread_id.clone()));
525        }
526
527        let where_clause = if conditions.is_empty() {
528            String::new()
529        } else {
530            format!("WHERE {}", conditions.join(" AND "))
531        };
532
533        let sql = format!(
534            "SELECT {MEMORY_COLUMNS} FROM memories {where_clause} ORDER BY created_at DESC LIMIT {limit} OFFSET {offset}"
535        );
536
537        let mut query = sqlx::query(&sql);
538        for p in &params {
539            match p {
540                Param::Str(s) => query = query.bind(s),
541                Param::F32(f) => query = query.bind(*f),
542            }
543        }
544
545        let rows = query.fetch_all(&self.pool).await.map_err(map_sqlx)?;
546        let mut results = Vec::with_capacity(rows.len());
547        for r in &rows {
548            results.push(row_to_memory(r).map_err(map_sqlx)?);
549        }
550        Ok(results)
551    }
552
553    async fn touch_memory(&self, id: Uuid) -> Result<()> {
554        let now = chrono::Utc::now().to_rfc3339();
555        sqlx::query(
556            "UPDATE memories SET access_count = access_count + 1, last_accessed_at = $1 WHERE id = $2",
557        )
558        .bind(&now)
559        .bind(id)
560        .execute(&self.pool)
561        .await
562        .map_err(map_sqlx)?;
563        Ok(())
564    }
565
566    // -----------------------------------------------------------------------
567    // ACL
568    // -----------------------------------------------------------------------
569
570    async fn insert_acl(&self, acl: &Acl) -> Result<()> {
571        sqlx::query(
572            r#"
573INSERT INTO acls (id, memory_id, principal_type, principal_id, permission, granted_by, created_at, expires_at)
574VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
575"#,
576        )
577        .bind(acl.id)
578        .bind(acl.memory_id)
579        .bind(acl.principal_type.to_string())
580        .bind(&acl.principal_id)
581        .bind(acl.permission.to_string())
582        .bind(&acl.granted_by)
583        .bind(&acl.created_at)
584        .bind(&acl.expires_at)
585        .execute(&self.pool)
586        .await
587        .map_err(map_sqlx)?;
588        Ok(())
589    }
590
591    async fn check_permission(
592        &self,
593        memory_id: Uuid,
594        principal_id: &str,
595        required: Permission,
596    ) -> Result<bool> {
597        // Check if the principal is the owner
598        let owner_row = sqlx::query("SELECT agent_id FROM memories WHERE id = $1")
599            .bind(memory_id)
600            .fetch_optional(&self.pool)
601            .await
602            .map_err(map_sqlx)?;
603
604        match owner_row {
605            None => return Err(Error::NotFound(format!("memory {memory_id} not found"))),
606            Some(row) => {
607                let owner: String = row.get("agent_id");
608                if owner == principal_id {
609                    return Ok(true);
610                }
611            }
612        }
613
614        // Check ACLs (direct grants)
615        let now = chrono::Utc::now().to_rfc3339();
616        let acl_rows = sqlx::query(
617            "SELECT permission FROM acls WHERE memory_id = $1 AND principal_id = $2 AND (expires_at IS NULL OR expires_at > $3)",
618        )
619        .bind(memory_id)
620        .bind(principal_id)
621        .bind(&now)
622        .fetch_all(&self.pool)
623        .await
624        .map_err(map_sqlx)?;
625
626        for row in &acl_rows {
627            let perm_str: String = row.get("permission");
628            if let Ok(perm) = perm_str.parse::<Permission>()
629                && perm.satisfies(required)
630            {
631                return Ok(true);
632            }
633        }
634
635        // Check public ACLs
636        let public_rows = sqlx::query(
637            "SELECT permission FROM acls WHERE memory_id = $1 AND principal_type = 'public' AND (expires_at IS NULL OR expires_at > $2)",
638        )
639        .bind(memory_id)
640        .bind(&now)
641        .fetch_all(&self.pool)
642        .await
643        .map_err(map_sqlx)?;
644
645        for row in &public_rows {
646            let perm_str: String = row.get("permission");
647            if let Ok(perm) = perm_str.parse::<Permission>()
648                && perm.satisfies(required)
649            {
650                return Ok(true);
651            }
652        }
653
654        // Check delegations
655        if self
656            .check_delegation(principal_id, memory_id, required)
657            .await?
658        {
659            return Ok(true);
660        }
661
662        Ok(false)
663    }
664
665    // -----------------------------------------------------------------------
666    // Relations
667    // -----------------------------------------------------------------------
668
669    async fn insert_relation(&self, relation: &Relation) -> Result<()> {
670        sqlx::query(
671            r#"
672INSERT INTO relations (id, source_id, target_id, relation_type, weight, metadata, created_at)
673VALUES ($1, $2, $3, $4, $5, $6, $7)
674"#,
675        )
676        .bind(relation.id)
677        .bind(relation.source_id)
678        .bind(relation.target_id)
679        .bind(&relation.relation_type)
680        .bind(relation.weight)
681        .bind(&relation.metadata)
682        .bind(&relation.created_at)
683        .execute(&self.pool)
684        .await
685        .map_err(map_sqlx)?;
686        Ok(())
687    }
688
689    async fn get_relations_from(&self, source_id: Uuid) -> Result<Vec<Relation>> {
690        let rows = sqlx::query(
691            "SELECT id, source_id, target_id, relation_type, weight, metadata, created_at FROM relations WHERE source_id = $1",
692        )
693        .bind(source_id)
694        .fetch_all(&self.pool)
695        .await
696        .map_err(map_sqlx)?;
697
698        let mut results = Vec::with_capacity(rows.len());
699        for r in &rows {
700            results.push(row_to_relation(r).map_err(map_sqlx)?);
701        }
702        Ok(results)
703    }
704
705    async fn get_relations_to(&self, target_id: Uuid) -> Result<Vec<Relation>> {
706        let rows = sqlx::query(
707            "SELECT id, source_id, target_id, relation_type, weight, metadata, created_at FROM relations WHERE target_id = $1",
708        )
709        .bind(target_id)
710        .fetch_all(&self.pool)
711        .await
712        .map_err(map_sqlx)?;
713
714        let mut results = Vec::with_capacity(rows.len());
715        for r in &rows {
716            results.push(row_to_relation(r).map_err(map_sqlx)?);
717        }
718        Ok(results)
719    }
720
721    async fn delete_relation(&self, id: Uuid) -> Result<()> {
722        let result = sqlx::query("DELETE FROM relations WHERE id = $1")
723            .bind(id)
724            .execute(&self.pool)
725            .await
726            .map_err(map_sqlx)?;
727
728        if result.rows_affected() == 0 {
729            return Err(Error::NotFound(format!("relation {id} not found")));
730        }
731        Ok(())
732    }
733
734    // -----------------------------------------------------------------------
735    // Chain linking
736    // -----------------------------------------------------------------------
737
738    async fn get_latest_memory_hash(
739        &self,
740        agent_id: &str,
741        thread_id: Option<&str>,
742    ) -> Result<Option<Vec<u8>>> {
743        let row = if let Some(tid) = thread_id {
744            sqlx::query(
745                "SELECT content_hash FROM memories WHERE agent_id = $1 AND thread_id = $2 AND deleted_at IS NULL ORDER BY created_at DESC LIMIT 1",
746            )
747            .bind(agent_id)
748            .bind(tid)
749            .fetch_optional(&self.pool)
750            .await
751            .map_err(map_sqlx)?
752        } else {
753            sqlx::query(
754                "SELECT content_hash FROM memories WHERE agent_id = $1 AND thread_id IS NULL AND deleted_at IS NULL ORDER BY created_at DESC LIMIT 1",
755            )
756            .bind(agent_id)
757            .fetch_optional(&self.pool)
758            .await
759            .map_err(map_sqlx)?
760        };
761
762        Ok(row.map(|r| r.get::<Vec<u8>, _>("content_hash")))
763    }
764
765    async fn get_latest_event_hash(
766        &self,
767        agent_id: &str,
768        thread_id: Option<&str>,
769    ) -> Result<Option<Vec<u8>>> {
770        let row = if let Some(tid) = thread_id {
771            sqlx::query(
772                "SELECT content_hash FROM agent_events WHERE agent_id = $1 AND thread_id = $2 ORDER BY timestamp DESC LIMIT 1",
773            )
774            .bind(agent_id)
775            .bind(tid)
776            .fetch_optional(&self.pool)
777            .await
778            .map_err(map_sqlx)?
779        } else {
780            sqlx::query(
781                "SELECT content_hash FROM agent_events WHERE agent_id = $1 ORDER BY timestamp DESC LIMIT 1",
782            )
783            .bind(agent_id)
784            .fetch_optional(&self.pool)
785            .await
786            .map_err(map_sqlx)?
787        };
788        Ok(row.map(|r| r.get::<Vec<u8>, _>("content_hash")))
789    }
790
791    async fn get_sync_watermark(&self, key: &str) -> Result<Option<String>> {
792        let row = sqlx::query("SELECT value FROM sync_metadata WHERE key = $1")
793            .bind(key)
794            .fetch_optional(&self.pool)
795            .await
796            .map_err(map_sqlx)?;
797        Ok(row.map(|r| r.get::<String, _>("value")))
798    }
799
800    async fn set_sync_watermark(&self, key: &str, value: &str) -> Result<()> {
801        let now = chrono::Utc::now().to_rfc3339();
802        sqlx::query(
803            "INSERT INTO sync_metadata (key, value, updated_at) VALUES ($1, $2, $3) ON CONFLICT (key) DO UPDATE SET value = $2, updated_at = $3",
804        )
805        .bind(key)
806        .bind(value)
807        .bind(now)
808        .execute(&self.pool)
809        .await
810        .map_err(map_sqlx)?;
811        Ok(())
812    }
813
814    // -----------------------------------------------------------------------
815    // Permission-safe ANN
816    // -----------------------------------------------------------------------
817
818    async fn list_accessible_memory_ids(&self, agent_id: &str, limit: usize) -> Result<Vec<Uuid>> {
819        let now = chrono::Utc::now().to_rfc3339();
820        let rows = sqlx::query(
821            r#"
822SELECT id FROM memories
823WHERE (
824    agent_id = $1
825    OR scope = 'public'
826    OR id IN (
827        SELECT memory_id FROM acls
828        WHERE principal_id = $2 AND (expires_at IS NULL OR expires_at > $3)
829    )
830)
831AND deleted_at IS NULL
832LIMIT $4
833"#,
834        )
835        .bind(agent_id)
836        .bind(agent_id)
837        .bind(&now)
838        .bind(limit as i64)
839        .fetch_all(&self.pool)
840        .await
841        .map_err(map_sqlx)?;
842
843        let ids: Vec<Uuid> = rows.iter().map(|r| r.get("id")).collect();
844        Ok(ids)
845    }
846
847    // -----------------------------------------------------------------------
848    // Events
849    // -----------------------------------------------------------------------
850
851    async fn insert_event(&self, event: &AgentEvent) -> Result<()> {
852        let payload_json = &event.payload;
853        let embedding_blob = serialize_embedding(&event.embedding);
854
855        sqlx::query(
856            r#"
857INSERT INTO agent_events (
858    id, agent_id, thread_id, run_id, parent_event_id, event_type,
859    payload, trace_id, span_id, model, tokens_input, tokens_output,
860    latency_ms, cost_usd, "timestamp", logical_clock, content_hash,
861    prev_hash, embedding
862) VALUES (
863    $1, $2, $3, $4, $5, $6,
864    $7, $8, $9, $10, $11, $12,
865    $13, $14, $15, $16, $17,
866    $18, $19
867)
868"#,
869        )
870        .bind(event.id)
871        .bind(&event.agent_id)
872        .bind(&event.thread_id)
873        .bind(&event.run_id)
874        .bind(event.parent_event_id)
875        .bind(event.event_type.to_string())
876        .bind(payload_json)
877        .bind(&event.trace_id)
878        .bind(&event.span_id)
879        .bind(&event.model)
880        .bind(event.tokens_input)
881        .bind(event.tokens_output)
882        .bind(event.latency_ms)
883        .bind(event.cost_usd)
884        .bind(&event.timestamp)
885        .bind(event.logical_clock)
886        .bind(&event.content_hash)
887        .bind(&event.prev_hash)
888        .bind(&embedding_blob)
889        .execute(&self.pool)
890        .await
891        .map_err(map_sqlx)?;
892        Ok(())
893    }
894
895    async fn list_events(
896        &self,
897        agent_id: &str,
898        limit: usize,
899        offset: usize,
900    ) -> Result<Vec<AgentEvent>> {
901        let rows = sqlx::query(
902            r#"
903SELECT id, agent_id, thread_id, run_id, parent_event_id, event_type,
904       payload, trace_id, span_id, model, tokens_input, tokens_output,
905       latency_ms, cost_usd, "timestamp", logical_clock, content_hash,
906       prev_hash, embedding
907FROM agent_events
908WHERE agent_id = $1
909ORDER BY "timestamp" DESC
910LIMIT $2 OFFSET $3
911"#,
912        )
913        .bind(agent_id)
914        .bind(limit as i64)
915        .bind(offset as i64)
916        .fetch_all(&self.pool)
917        .await
918        .map_err(map_sqlx)?;
919
920        let mut results = Vec::with_capacity(rows.len());
921        for r in &rows {
922            results.push(row_to_event(r).map_err(map_sqlx)?);
923        }
924        Ok(results)
925    }
926
927    async fn get_events_by_thread(&self, thread_id: &str, limit: usize) -> Result<Vec<AgentEvent>> {
928        let rows = sqlx::query(
929            r#"
930SELECT id, agent_id, thread_id, run_id, parent_event_id, event_type,
931       payload, trace_id, span_id, model, tokens_input, tokens_output,
932       latency_ms, cost_usd, "timestamp", logical_clock, content_hash,
933       prev_hash, embedding
934FROM agent_events
935WHERE thread_id = $1
936ORDER BY "timestamp" ASC
937LIMIT $2
938"#,
939        )
940        .bind(thread_id)
941        .bind(limit as i64)
942        .fetch_all(&self.pool)
943        .await
944        .map_err(map_sqlx)?;
945
946        let mut results = Vec::with_capacity(rows.len());
947        for r in &rows {
948            results.push(row_to_event(r).map_err(map_sqlx)?);
949        }
950        Ok(results)
951    }
952
953    async fn get_event(&self, id: Uuid) -> Result<Option<AgentEvent>> {
954        let row = sqlx::query(
955            r#"
956SELECT id, agent_id, thread_id, run_id, parent_event_id, event_type,
957       payload, trace_id, span_id, model, tokens_input, tokens_output,
958       latency_ms, cost_usd, "timestamp", logical_clock, content_hash,
959       prev_hash, embedding
960FROM agent_events
961WHERE id = $1
962"#,
963        )
964        .bind(id)
965        .fetch_optional(&self.pool)
966        .await
967        .map_err(map_sqlx)?;
968
969        match row {
970            Some(r) => Ok(Some(row_to_event(&r).map_err(map_sqlx)?)),
971            None => Ok(None),
972        }
973    }
974
975    async fn list_child_events(
976        &self,
977        parent_event_id: Uuid,
978        limit: usize,
979    ) -> Result<Vec<AgentEvent>> {
980        let rows = sqlx::query(
981            r#"
982SELECT id, agent_id, thread_id, run_id, parent_event_id, event_type,
983       payload, trace_id, span_id, model, tokens_input, tokens_output,
984       latency_ms, cost_usd, "timestamp", logical_clock, content_hash,
985       prev_hash, embedding
986FROM agent_events
987WHERE parent_event_id = $1
988ORDER BY "timestamp" ASC
989LIMIT $2
990"#,
991        )
992        .bind(parent_event_id)
993        .bind(limit as i64)
994        .fetch_all(&self.pool)
995        .await
996        .map_err(map_sqlx)?;
997
998        let mut results = Vec::with_capacity(rows.len());
999        for r in &rows {
1000            results.push(row_to_event(r).map_err(map_sqlx)?);
1001        }
1002        Ok(results)
1003    }
1004
1005    // -----------------------------------------------------------------------
1006    // Ordered listing
1007    // -----------------------------------------------------------------------
1008
1009    async fn list_memories_by_agent_ordered(
1010        &self,
1011        agent_id: &str,
1012        thread_id: Option<&str>,
1013        limit: usize,
1014    ) -> Result<Vec<MemoryRecord>> {
1015        let rows = if let Some(tid) = thread_id {
1016            let sql = format!(
1017                "SELECT {MEMORY_COLUMNS} FROM memories WHERE agent_id = $1 AND thread_id = $2 AND deleted_at IS NULL ORDER BY created_at ASC LIMIT $3"
1018            );
1019            sqlx::query(&sql)
1020                .bind(agent_id)
1021                .bind(tid)
1022                .bind(limit as i64)
1023                .fetch_all(&self.pool)
1024                .await
1025                .map_err(map_sqlx)?
1026        } else {
1027            let sql = format!(
1028                "SELECT {MEMORY_COLUMNS} FROM memories WHERE agent_id = $1 AND deleted_at IS NULL ORDER BY created_at ASC LIMIT $2"
1029            );
1030            sqlx::query(&sql)
1031                .bind(agent_id)
1032                .bind(limit as i64)
1033                .fetch_all(&self.pool)
1034                .await
1035                .map_err(map_sqlx)?
1036        };
1037
1038        let mut results = Vec::with_capacity(rows.len());
1039        for r in &rows {
1040            results.push(row_to_memory(r).map_err(map_sqlx)?);
1041        }
1042        Ok(results)
1043    }
1044
1045    // -----------------------------------------------------------------------
1046    // Sync support
1047    // -----------------------------------------------------------------------
1048
1049    async fn list_memories_since(
1050        &self,
1051        updated_after: &str,
1052        limit: usize,
1053    ) -> Result<Vec<MemoryRecord>> {
1054        let sql = format!(
1055            "SELECT {MEMORY_COLUMNS} FROM memories WHERE updated_at > $1 ORDER BY updated_at ASC LIMIT $2"
1056        );
1057        let rows = sqlx::query(&sql)
1058            .bind(updated_after)
1059            .bind(limit as i64)
1060            .fetch_all(&self.pool)
1061            .await
1062            .map_err(map_sqlx)?;
1063
1064        let mut results = Vec::with_capacity(rows.len());
1065        for r in &rows {
1066            results.push(row_to_memory(r).map_err(map_sqlx)?);
1067        }
1068        Ok(results)
1069    }
1070
1071    async fn upsert_memory(&self, record: &MemoryRecord) -> Result<()> {
1072        match self.update_memory(record).await {
1073            Ok(()) => Ok(()),
1074            Err(Error::NotFound(_)) => self.insert_memory(record).await,
1075            Err(e) => Err(e),
1076        }
1077    }
1078
1079    // -----------------------------------------------------------------------
1080    // Expired memory cleanup
1081    // -----------------------------------------------------------------------
1082
1083    async fn cleanup_expired(&self) -> Result<usize> {
1084        let now = chrono::Utc::now().to_rfc3339();
1085        let result = sqlx::query(
1086            "UPDATE memories SET deleted_at = $1 WHERE expires_at IS NOT NULL AND expires_at < $2 AND deleted_at IS NULL",
1087        )
1088        .bind(&now)
1089        .bind(&now)
1090        .execute(&self.pool)
1091        .await
1092        .map_err(map_sqlx)?;
1093
1094        Ok(result.rows_affected() as usize)
1095    }
1096
1097    // -----------------------------------------------------------------------
1098    // Delegations
1099    // -----------------------------------------------------------------------
1100
1101    async fn insert_delegation(&self, d: &Delegation) -> Result<()> {
1102        let scope_type = d.scope.to_string();
1103        let scope_value: serde_json::Value = match &d.scope {
1104            DelegationScope::AllMemories => serde_json::Value::Null,
1105            DelegationScope::ByTag(tags) => serde_json::json!(tags),
1106            DelegationScope::ByMemoryId(ids) => {
1107                serde_json::json!(ids.iter().map(|id| id.to_string()).collect::<Vec<_>>())
1108            }
1109        };
1110
1111        sqlx::query(
1112            r#"
1113INSERT INTO delegations (
1114    id, delegator_id, delegate_id, permission, scope_type, scope_value,
1115    max_depth, current_depth, parent_delegation_id,
1116    created_at, expires_at, revoked_at
1117) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
1118"#,
1119        )
1120        .bind(d.id)
1121        .bind(&d.delegator_id)
1122        .bind(&d.delegate_id)
1123        .bind(d.permission.to_string())
1124        .bind(&scope_type)
1125        .bind(&scope_value)
1126        .bind(d.max_depth as i32)
1127        .bind(d.current_depth as i32)
1128        .bind(d.parent_delegation_id)
1129        .bind(&d.created_at)
1130        .bind(&d.expires_at)
1131        .bind(&d.revoked_at)
1132        .execute(&self.pool)
1133        .await
1134        .map_err(map_sqlx)?;
1135        Ok(())
1136    }
1137
1138    async fn list_delegations_for(&self, delegate_id: &str) -> Result<Vec<Delegation>> {
1139        let now = chrono::Utc::now().to_rfc3339();
1140        let rows = sqlx::query(
1141            r#"
1142SELECT id, delegator_id, delegate_id, permission, scope_type, scope_value,
1143       max_depth, current_depth, parent_delegation_id,
1144       created_at, expires_at, revoked_at
1145FROM delegations
1146WHERE delegate_id = $1 AND revoked_at IS NULL AND (expires_at IS NULL OR expires_at > $2)
1147"#,
1148        )
1149        .bind(delegate_id)
1150        .bind(&now)
1151        .fetch_all(&self.pool)
1152        .await
1153        .map_err(map_sqlx)?;
1154
1155        let mut results = Vec::with_capacity(rows.len());
1156        for r in &rows {
1157            results.push(row_to_delegation(r).map_err(map_sqlx)?);
1158        }
1159        Ok(results)
1160    }
1161
1162    async fn revoke_delegation(&self, id: Uuid) -> Result<()> {
1163        let now = chrono::Utc::now().to_rfc3339();
1164        let result = sqlx::query(
1165            "UPDATE delegations SET revoked_at = $1 WHERE id = $2 AND revoked_at IS NULL",
1166        )
1167        .bind(&now)
1168        .bind(id)
1169        .execute(&self.pool)
1170        .await
1171        .map_err(map_sqlx)?;
1172
1173        if result.rows_affected() == 0 {
1174            return Err(Error::NotFound(format!(
1175                "delegation {id} not found or already revoked"
1176            )));
1177        }
1178        Ok(())
1179    }
1180
1181    async fn check_delegation(
1182        &self,
1183        delegate_id: &str,
1184        memory_id: Uuid,
1185        required: Permission,
1186    ) -> Result<bool> {
1187        let delegations = self.list_delegations_for(delegate_id).await?;
1188
1189        // Get the memory to inspect its tags for scope matching
1190        let memory = match self.get_memory(memory_id).await? {
1191            Some(m) => m,
1192            None => return Ok(false),
1193        };
1194
1195        for d in &delegations {
1196            if !d.permission.satisfies(required) {
1197                continue;
1198            }
1199            match &d.scope {
1200                DelegationScope::AllMemories => return Ok(true),
1201                DelegationScope::ByMemoryId(ids) => {
1202                    if ids.contains(&memory_id) {
1203                        return Ok(true);
1204                    }
1205                }
1206                DelegationScope::ByTag(tags) => {
1207                    if tags.iter().any(|t| memory.tags.contains(t)) {
1208                        return Ok(true);
1209                    }
1210                }
1211            }
1212        }
1213        Ok(false)
1214    }
1215
1216    // -----------------------------------------------------------------------
1217    // Agent Profiles
1218    // -----------------------------------------------------------------------
1219
1220    async fn insert_or_update_agent_profile(&self, profile: &AgentProfile) -> Result<()> {
1221        sqlx::query(
1222            r#"
1223INSERT INTO agent_profiles (agent_id, avg_importance, avg_content_length, total_memories, last_updated)
1224VALUES ($1, $2, $3, $4, $5)
1225ON CONFLICT (agent_id) DO UPDATE SET
1226    avg_importance = EXCLUDED.avg_importance,
1227    avg_content_length = EXCLUDED.avg_content_length,
1228    total_memories = EXCLUDED.total_memories,
1229    last_updated = EXCLUDED.last_updated
1230"#,
1231        )
1232        .bind(&profile.agent_id)
1233        .bind(profile.avg_importance)
1234        .bind(profile.avg_content_length)
1235        .bind(profile.total_memories as i64)
1236        .bind(&profile.last_updated)
1237        .execute(&self.pool)
1238        .await
1239        .map_err(map_sqlx)?;
1240        Ok(())
1241    }
1242
1243    async fn get_agent_profile(&self, agent_id: &str) -> Result<Option<AgentProfile>> {
1244        let row = sqlx::query(
1245            "SELECT agent_id, avg_importance, avg_content_length, total_memories, last_updated FROM agent_profiles WHERE agent_id = $1",
1246        )
1247        .bind(agent_id)
1248        .fetch_optional(&self.pool)
1249        .await
1250        .map_err(map_sqlx)?;
1251
1252        Ok(row.map(|r| AgentProfile {
1253            agent_id: r.get("agent_id"),
1254            avg_importance: r.get("avg_importance"),
1255            avg_content_length: r.get("avg_content_length"),
1256            total_memories: r.get::<i64, _>("total_memories") as u64,
1257            last_updated: r.get("last_updated"),
1258        }))
1259    }
1260
1261    // -----------------------------------------------------------------------
1262    // Embedding baselines (v0.3.3)
1263    // -----------------------------------------------------------------------
1264
1265    async fn insert_or_update_embedding_baseline(
1266        &self,
1267        baseline: &EmbeddingBaseline,
1268    ) -> Result<()> {
1269        let mu_json =
1270            serde_json::to_value(&baseline.mu).map_err(|e| Error::Storage(e.to_string()))?;
1271        let cov_json =
1272            serde_json::to_value(&baseline.cov_diag).map_err(|e| Error::Storage(e.to_string()))?;
1273        sqlx::query(
1274            r#"
1275INSERT INTO embedding_baseline (agent_id, mu, cov_diag, n, updated_at)
1276VALUES ($1, $2, $3, $4, $5)
1277ON CONFLICT (agent_id) DO UPDATE SET
1278    mu = EXCLUDED.mu,
1279    cov_diag = EXCLUDED.cov_diag,
1280    n = EXCLUDED.n,
1281    updated_at = EXCLUDED.updated_at
1282"#,
1283        )
1284        .bind(&baseline.agent_id)
1285        .bind(&mu_json)
1286        .bind(&cov_json)
1287        .bind(baseline.n as i64)
1288        .bind(&baseline.updated_at)
1289        .execute(&self.pool)
1290        .await
1291        .map_err(map_sqlx)?;
1292        Ok(())
1293    }
1294
1295    async fn get_embedding_baseline(&self, agent_id: &str) -> Result<Option<EmbeddingBaseline>> {
1296        let row = sqlx::query(
1297            "SELECT agent_id, mu, cov_diag, n, updated_at FROM embedding_baseline WHERE agent_id = $1",
1298        )
1299        .bind(agent_id)
1300        .fetch_optional(&self.pool)
1301        .await
1302        .map_err(map_sqlx)?;
1303
1304        match row {
1305            None => Ok(None),
1306            Some(r) => {
1307                let mu_val: serde_json::Value = r.get("mu");
1308                let cov_val: serde_json::Value = r.get("cov_diag");
1309                let mu: Vec<f32> =
1310                    serde_json::from_value(mu_val).map_err(|e| Error::Storage(e.to_string()))?;
1311                let cov_diag: Vec<f32> =
1312                    serde_json::from_value(cov_val).map_err(|e| Error::Storage(e.to_string()))?;
1313                Ok(Some(EmbeddingBaseline {
1314                    agent_id: r.get("agent_id"),
1315                    mu,
1316                    cov_diag,
1317                    n: r.get::<i64, _>("n") as u64,
1318                    updated_at: r.get("updated_at"),
1319                }))
1320            }
1321        }
1322    }
1323
1324    // -----------------------------------------------------------------------
1325    // Checkpoints
1326    // -----------------------------------------------------------------------
1327
1328    async fn insert_checkpoint(&self, cp: &Checkpoint) -> Result<()> {
1329        let memory_refs_strs: Vec<String> =
1330            cp.memory_refs.iter().map(|id| id.to_string()).collect();
1331        let refs_slice: &[String] = &memory_refs_strs;
1332
1333        sqlx::query(
1334            r#"
1335INSERT INTO checkpoints (
1336    id, thread_id, agent_id, parent_id, branch_name,
1337    state_snapshot, state_diff, memory_refs, event_cursor,
1338    label, created_at, metadata
1339) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
1340"#,
1341        )
1342        .bind(cp.id)
1343        .bind(&cp.thread_id)
1344        .bind(&cp.agent_id)
1345        .bind(cp.parent_id)
1346        .bind(&cp.branch_name)
1347        .bind(&cp.state_snapshot)
1348        .bind(&cp.state_diff)
1349        .bind(refs_slice)
1350        .bind(cp.event_cursor)
1351        .bind(&cp.label)
1352        .bind(&cp.created_at)
1353        .bind(&cp.metadata)
1354        .execute(&self.pool)
1355        .await
1356        .map_err(map_sqlx)?;
1357        Ok(())
1358    }
1359
1360    async fn get_checkpoint(&self, id: Uuid) -> Result<Option<Checkpoint>> {
1361        let row = sqlx::query(
1362            r#"
1363SELECT id, thread_id, agent_id, parent_id, branch_name,
1364       state_snapshot, state_diff, memory_refs, event_cursor,
1365       label, created_at, metadata
1366FROM checkpoints WHERE id = $1
1367"#,
1368        )
1369        .bind(id)
1370        .fetch_optional(&self.pool)
1371        .await
1372        .map_err(map_sqlx)?;
1373
1374        match row {
1375            Some(r) => Ok(Some(row_to_checkpoint(&r).map_err(map_sqlx)?)),
1376            None => Ok(None),
1377        }
1378    }
1379
1380    async fn list_checkpoints(
1381        &self,
1382        thread_id: &str,
1383        branch: Option<&str>,
1384        limit: usize,
1385    ) -> Result<Vec<Checkpoint>> {
1386        let rows = if let Some(branch_name) = branch {
1387            sqlx::query(
1388                r#"
1389SELECT id, thread_id, agent_id, parent_id, branch_name,
1390       state_snapshot, state_diff, memory_refs, event_cursor,
1391       label, created_at, metadata
1392FROM checkpoints
1393WHERE thread_id = $1 AND branch_name = $2
1394ORDER BY created_at DESC
1395LIMIT $3
1396"#,
1397            )
1398            .bind(thread_id)
1399            .bind(branch_name)
1400            .bind(limit as i64)
1401            .fetch_all(&self.pool)
1402            .await
1403            .map_err(map_sqlx)?
1404        } else {
1405            sqlx::query(
1406                r#"
1407SELECT id, thread_id, agent_id, parent_id, branch_name,
1408       state_snapshot, state_diff, memory_refs, event_cursor,
1409       label, created_at, metadata
1410FROM checkpoints
1411WHERE thread_id = $1
1412ORDER BY created_at DESC
1413LIMIT $2
1414"#,
1415            )
1416            .bind(thread_id)
1417            .bind(limit as i64)
1418            .fetch_all(&self.pool)
1419            .await
1420            .map_err(map_sqlx)?
1421        };
1422
1423        let mut results = Vec::with_capacity(rows.len());
1424        for r in &rows {
1425            results.push(row_to_checkpoint(r).map_err(map_sqlx)?);
1426        }
1427        Ok(results)
1428    }
1429
1430    async fn get_latest_checkpoint(
1431        &self,
1432        thread_id: &str,
1433        branch: &str,
1434    ) -> Result<Option<Checkpoint>> {
1435        let row = sqlx::query(
1436            r#"
1437SELECT id, thread_id, agent_id, parent_id, branch_name,
1438       state_snapshot, state_diff, memory_refs, event_cursor,
1439       label, created_at, metadata
1440FROM checkpoints
1441WHERE thread_id = $1 AND branch_name = $2
1442ORDER BY created_at DESC
1443LIMIT 1
1444"#,
1445        )
1446        .bind(thread_id)
1447        .bind(branch)
1448        .fetch_optional(&self.pool)
1449        .await
1450        .map_err(map_sqlx)?;
1451
1452        match row {
1453            Some(r) => Ok(Some(row_to_checkpoint(&r).map_err(map_sqlx)?)),
1454            None => Ok(None),
1455        }
1456    }
1457}