Skip to main content

engram/
graph_postgres.rs

1//! PostgreSQL-backed `GraphStore` implementation.
2//!
3//! Uses native Postgres types: `UUID` for ids, `TIMESTAMPTZ` for timestamps,
4//! `JSONB` for attributes. Placeholder syntax uses `$1, $2, …`.
5
6use crate::fact::{Entity, EntityId, Relationship, RelationshipId, SubGraph};
7use crate::graph::GraphStore;
8use crate::scope::Scope;
9use crate::store::MemoryError;
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use sqlx::PgPool;
13use std::collections::{HashMap, HashSet, VecDeque};
14use uuid::Uuid;
15
16// ---------------------------------------------------------------------------
17// DDL
18// ---------------------------------------------------------------------------
19
/// DDL statements for the Postgres graph tables.
///
/// Each element is exactly one SQL statement, meant to be executed on its own
/// and in slice order: the `entities` table and its indexes first, then the
/// `relationships` table and its indexes. Every statement uses
/// `IF NOT EXISTS`, so re-running the whole list is harmless.
///
/// Schema notes:
/// * `attributes` is `JSONB NOT NULL` and defaults to the JSON `null` value
///   (not SQL `NULL`).
/// * `relationships.source_id` / `target_id` are plain `UUID` columns; no
///   foreign-key constraint ties them to `entities`.
/// * `invalid_at` is nullable on `relationships`.
const PG_GRAPH_STORE_DDL: &[&str] = &[
    r#"
    CREATE TABLE IF NOT EXISTS entities (
        id          UUID PRIMARY KEY,
        name        TEXT NOT NULL,
        entity_type TEXT,
        org_id      TEXT NOT NULL DEFAULT 'default',
        agent_id    TEXT,
        user_id     TEXT,
        session_id  TEXT,
        attributes  JSONB NOT NULL DEFAULT 'null',
        created_at  TIMESTAMPTZ NOT NULL,
        updated_at  TIMESTAMPTZ NOT NULL
    )
    "#,
    "CREATE INDEX IF NOT EXISTS idx_pg_entities_name    ON entities (name)",
    "CREATE INDEX IF NOT EXISTS idx_pg_entities_user_id ON entities (user_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_entities_org_id  ON entities (org_id)",
    r#"
    CREATE TABLE IF NOT EXISTS relationships (
        id          UUID PRIMARY KEY,
        source_id   UUID NOT NULL,
        relation    TEXT NOT NULL,
        target_id   UUID NOT NULL,
        org_id      TEXT NOT NULL DEFAULT 'default',
        agent_id    TEXT,
        user_id     TEXT,
        session_id  TEXT,
        valid_from  TIMESTAMPTZ NOT NULL,
        invalid_at  TIMESTAMPTZ,
        created_at  TIMESTAMPTZ NOT NULL
    )
    "#,
    "CREATE INDEX IF NOT EXISTS idx_pg_rel_source_id  ON relationships (source_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_rel_target_id  ON relationships (target_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_rel_relation   ON relationships (relation)",
    "CREATE INDEX IF NOT EXISTS idx_pg_rel_invalid_at ON relationships (invalid_at)",
];
60
61// ---------------------------------------------------------------------------
62// PostgresGraphStore
63// ---------------------------------------------------------------------------
64
65pub struct PostgresGraphStore {
66    pool: PgPool,
67}
68
69impl PostgresGraphStore {
70    pub fn new(pool: PgPool) -> Self {
71        Self { pool }
72    }
73
74    /// Open a connection pool from a database URL and return a store.
75    pub async fn open(database_url: &str) -> Result<Self, sqlx::Error> {
76        let pool = PgPool::connect(database_url).await?;
77        Ok(Self { pool })
78    }
79
80    /// Apply the DDL. Safe to call multiple times (uses `IF NOT EXISTS`).
81    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
82        for stmt in PG_GRAPH_STORE_DDL {
83            sqlx::query(stmt).execute(&self.pool).await?;
84        }
85        Ok(())
86    }
87}
88
89// ---------------------------------------------------------------------------
90// Internal row types
91// ---------------------------------------------------------------------------
92
93#[derive(sqlx::FromRow)]
94struct EntityRow {
95    id: Uuid,
96    name: String,
97    entity_type: Option<String>,
98    org_id: String,
99    agent_id: Option<String>,
100    user_id: Option<String>,
101    session_id: Option<String>,
102    attributes: serde_json::Value,
103    created_at: DateTime<Utc>,
104    updated_at: DateTime<Utc>,
105}
106
107#[derive(sqlx::FromRow)]
108struct RelationshipRow {
109    id: Uuid,
110    source_id: Uuid,
111    relation: String,
112    target_id: Uuid,
113    org_id: String,
114    agent_id: Option<String>,
115    user_id: Option<String>,
116    session_id: Option<String>,
117    valid_from: DateTime<Utc>,
118    invalid_at: Option<DateTime<Utc>>,
119    created_at: DateTime<Utc>,
120}
121
122// ---------------------------------------------------------------------------
123// Conversion helpers
124// ---------------------------------------------------------------------------
125
126fn row_to_entity(row: EntityRow) -> Result<Entity, MemoryError> {
127    let attributes: serde_json::Map<String, serde_json::Value> = match &row.attributes {
128        serde_json::Value::Null => serde_json::Map::new(),
129        serde_json::Value::Object(map) => map.clone(),
130        other => serde_json::from_value(other.clone())
131            .map_err(|e| MemoryError::Serialization(e.to_string()))?,
132    };
133
134    Ok(Entity {
135        id: row.id,
136        name: row.name,
137        entity_type: row.entity_type.unwrap_or_else(|| "unknown".to_string()),
138        scope: Scope {
139            org_id: row.org_id,
140            agent_id: row.agent_id,
141            user_id: row.user_id,
142            session_id: row.session_id,
143        },
144        attributes,
145        created_at: row.created_at,
146        updated_at: row.updated_at,
147    })
148}
149
150fn row_to_relationship(row: RelationshipRow) -> Result<Relationship, MemoryError> {
151    Ok(Relationship {
152        id: row.id,
153        source_id: row.source_id,
154        relation: row.relation,
155        target_id: row.target_id,
156        scope: Scope {
157            org_id: row.org_id,
158            agent_id: row.agent_id,
159            user_id: row.user_id,
160            session_id: row.session_id,
161        },
162        valid_from: row.valid_from,
163        invalid_at: row.invalid_at,
164        created_at: row.created_at,
165    })
166}
167
168// ---------------------------------------------------------------------------
169// GraphStore implementation
170// ---------------------------------------------------------------------------
171
172#[async_trait]
173impl GraphStore for PostgresGraphStore {
174    async fn upsert_entity(&self, entity: &Entity) -> Result<(), MemoryError> {
175        let attributes_json = if entity.attributes.is_empty() {
176            serde_json::Value::Null
177        } else {
178            serde_json::to_value(&entity.attributes)
179                .map_err(|e| MemoryError::Serialization(e.to_string()))?
180        };
181
182        sqlx::query(
183            r#"
184            INSERT INTO entities
185                (id, name, entity_type, org_id, agent_id, user_id, session_id,
186                 attributes, created_at, updated_at)
187            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
188            ON CONFLICT(id) DO UPDATE SET
189                name        = EXCLUDED.name,
190                entity_type = EXCLUDED.entity_type,
191                attributes  = EXCLUDED.attributes,
192                updated_at  = EXCLUDED.updated_at
193            "#,
194        )
195        .bind(entity.id)
196        .bind(&entity.name)
197        .bind(&entity.entity_type)
198        .bind(&entity.scope.org_id)
199        .bind(entity.scope.agent_id.as_deref())
200        .bind(entity.scope.user_id.as_deref())
201        .bind(entity.scope.session_id.as_deref())
202        .bind(&attributes_json)
203        .bind(entity.created_at)
204        .bind(entity.updated_at)
205        .execute(&self.pool)
206        .await
207        .map_err(|e| MemoryError::Database(e.to_string()))?;
208
209        Ok(())
210    }
211
212    async fn upsert_relationship(&self, rel: &Relationship) -> Result<(), MemoryError> {
213        sqlx::query(
214            r#"
215            INSERT INTO relationships
216                (id, source_id, relation, target_id, org_id, agent_id, user_id, session_id,
217                 valid_from, invalid_at, created_at)
218            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
219            ON CONFLICT(id) DO UPDATE SET
220                relation   = EXCLUDED.relation,
221                invalid_at = EXCLUDED.invalid_at
222            "#,
223        )
224        .bind(rel.id)
225        .bind(rel.source_id)
226        .bind(&rel.relation)
227        .bind(rel.target_id)
228        .bind(&rel.scope.org_id)
229        .bind(rel.scope.agent_id.as_deref())
230        .bind(rel.scope.user_id.as_deref())
231        .bind(rel.scope.session_id.as_deref())
232        .bind(rel.valid_from)
233        .bind(rel.invalid_at)
234        .bind(rel.created_at)
235        .execute(&self.pool)
236        .await
237        .map_err(|e| MemoryError::Database(e.to_string()))?;
238
239        Ok(())
240    }
241
242    async fn invalidate_relationship(
243        &self,
244        id: RelationshipId,
245        invalid_at: DateTime<Utc>,
246    ) -> Result<(), MemoryError> {
247        sqlx::query("UPDATE relationships SET invalid_at = $1 WHERE id = $2")
248            .bind(invalid_at)
249            .bind(id)
250            .execute(&self.pool)
251            .await
252            .map_err(|e| MemoryError::Database(e.to_string()))?;
253        Ok(())
254    }
255
256    async fn get_entity(&self, id: EntityId) -> Result<Option<Entity>, MemoryError> {
257        let row = sqlx::query_as::<_, EntityRow>("SELECT * FROM entities WHERE id = $1")
258            .bind(id)
259            .fetch_optional(&self.pool)
260            .await
261            .map_err(|e| MemoryError::Database(e.to_string()))?;
262
263        row.map(row_to_entity).transpose()
264    }
265
266    async fn neighbors(
267        &self,
268        id: EntityId,
269        depth: u8,
270        as_of: Option<DateTime<Utc>>,
271    ) -> Result<SubGraph, MemoryError> {
272        // Build the temporal validity filter for relationships.
273        let validity_clause = match as_of {
274            Some(t) => {
275                let s = t.to_rfc3339();
276                format!("valid_from <= '{s}' AND (invalid_at IS NULL OR invalid_at > '{s}')")
277            }
278            None => "invalid_at IS NULL".to_string(),
279        };
280
281        let mut visited_entities: HashSet<EntityId> = HashSet::new();
282        visited_entities.insert(id);
283
284        let mut discovered_entities: HashMap<EntityId, Entity> = HashMap::new();
285        let mut discovered_relationships: HashMap<RelationshipId, Relationship> = HashMap::new();
286
287        // BFS queue: (entity_id, remaining_depth)
288        let mut queue: VecDeque<(EntityId, u8)> = VecDeque::new();
289        queue.push_back((id, depth));
290
291        while let Some((current_id, remaining)) = queue.pop_front() {
292            if remaining == 0 {
293                continue;
294            }
295
296            // Fetch all valid relationships where current_id is source or target.
297            let sql = format!(
298                "SELECT * FROM relationships WHERE (source_id = $1 OR target_id = $2) AND {validity_clause}"
299            );
300
301            let rel_rows = sqlx::query_as::<_, RelationshipRow>(&sql)
302                .bind(current_id)
303                .bind(current_id)
304                .fetch_all(&self.pool)
305                .await
306                .map_err(|e| MemoryError::Database(e.to_string()))?;
307
308            for row in rel_rows {
309                let rel = row_to_relationship(row)?;
310                let neighbor_id = if rel.source_id == current_id {
311                    rel.target_id
312                } else {
313                    rel.source_id
314                };
315
316                // Store the relationship (deduplicated by id).
317                discovered_relationships.entry(rel.id).or_insert(rel);
318
319                // Enqueue unvisited neighbors.
320                if !visited_entities.contains(&neighbor_id) {
321                    visited_entities.insert(neighbor_id);
322
323                    // Fetch the neighbor entity.
324                    if let Some(entity) = self.get_entity(neighbor_id).await? {
325                        discovered_entities.entry(neighbor_id).or_insert(entity);
326                    }
327
328                    queue.push_back((neighbor_id, remaining - 1));
329                }
330            }
331        }
332
333        Ok(SubGraph {
334            entities: discovered_entities.into_values().collect(),
335            relationships: discovered_relationships.into_values().collect(),
336        })
337    }
338
339    async fn search_entities(&self, query: &str, top_k: usize) -> Result<Vec<Entity>, MemoryError> {
340        let sql = "SELECT * FROM entities WHERE name ILIKE $1 LIMIT $2";
341
342        let pattern = format!("%{}%", query.replace('%', "\\%").replace('_', "\\_"));
343
344        let rows = sqlx::query_as::<_, EntityRow>(sql)
345            .bind(&pattern)
346            .bind(top_k as i64)
347            .fetch_all(&self.pool)
348            .await
349            .map_err(|e| MemoryError::Database(e.to_string()))?;
350
351        rows.into_iter().map(row_to_entity).collect()
352    }
353
354    async fn delete_by_scope(&self, scope: &Scope) -> Result<u64, MemoryError> {
355        // Build WHERE clause matching the scope fields.
356        let mut wheres = vec![format!("org_id = '{}'", scope.org_id.replace('\'', "''"))];
357
358        if let Some(ref user_id) = scope.user_id {
359            wheres.push(format!("user_id = '{}'", user_id.replace('\'', "''")));
360        }
361        if let Some(ref agent_id) = scope.agent_id {
362            wheres.push(format!("agent_id = '{}'", agent_id.replace('\'', "''")));
363        }
364        if let Some(ref session_id) = scope.session_id {
365            wheres.push(format!("session_id = '{}'", session_id.replace('\'', "''")));
366        }
367
368        let where_clause = wheres.join(" AND ");
369
370        // Delete relationships first.
371        let rel_sql = format!("DELETE FROM relationships WHERE {where_clause}");
372        let rel_result = sqlx::query(&rel_sql)
373            .execute(&self.pool)
374            .await
375            .map_err(|e| MemoryError::Database(e.to_string()))?;
376
377        // Then delete entities.
378        let ent_sql = format!("DELETE FROM entities WHERE {where_clause}");
379        let ent_result = sqlx::query(&ent_sql)
380            .execute(&self.pool)
381            .await
382            .map_err(|e| MemoryError::Database(e.to_string()))?;
383
384        Ok(rel_result.rows_affected() + ent_result.rows_affected())
385    }
386}