1use crate::fact::{Entity, EntityId, Relationship, RelationshipId, SubGraph};
7use crate::graph::GraphStore;
8use crate::scope::Scope;
9use crate::store::MemoryError;
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use sqlx::PgPool;
13use std::collections::{HashMap, HashSet, VecDeque};
14use uuid::Uuid;
15
/// Schema for the Postgres graph store, executed statement-by-statement, in
/// order, by `PostgresGraphStore::migrate`.
///
/// Every statement uses `IF NOT EXISTS`, so running the list against an
/// already-migrated database does not recreate anything. There are no foreign
/// keys: `relationships.source_id` / `target_id` are not checked against
/// `entities.id`.
const PG_GRAPH_STORE_DDL: &[&str] = &[
    // Graph nodes. The org/agent/user/session columns hold the entity's scope;
    // `attributes` defaults to the JSON value `null` (not SQL NULL).
    r#"
    CREATE TABLE IF NOT EXISTS entities (
        id UUID PRIMARY KEY,
        name TEXT NOT NULL,
        entity_type TEXT,
        org_id TEXT NOT NULL DEFAULT 'default',
        agent_id TEXT,
        user_id TEXT,
        session_id TEXT,
        attributes JSONB NOT NULL DEFAULT 'null',
        created_at TIMESTAMPTZ NOT NULL,
        updated_at TIMESTAMPTZ NOT NULL
    )
    "#,
    // Btree indexes on entity name and scope columns.
    // NOTE(review): a btree index on `name` cannot serve the `ILIKE '%…%'`
    // substring search used by `search_entities`; confirm whether a trigram
    // index is wanted.
    "CREATE INDEX IF NOT EXISTS idx_pg_entities_name ON entities (name)",
    "CREATE INDEX IF NOT EXISTS idx_pg_entities_user_id ON entities (user_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_entities_org_id ON entities (org_id)",
    // Directed edges between entities. A relationship is valid from
    // `valid_from` until `invalid_at`; NULL `invalid_at` means still valid.
    r#"
    CREATE TABLE IF NOT EXISTS relationships (
        id UUID PRIMARY KEY,
        source_id UUID NOT NULL,
        relation TEXT NOT NULL,
        target_id UUID NOT NULL,
        org_id TEXT NOT NULL DEFAULT 'default',
        agent_id TEXT,
        user_id TEXT,
        session_id TEXT,
        valid_from TIMESTAMPTZ NOT NULL,
        invalid_at TIMESTAMPTZ,
        created_at TIMESTAMPTZ NOT NULL
    )
    "#,
    // Indexes for traversal by either endpoint, by relation label, and for the
    // `invalid_at` validity filter.
    "CREATE INDEX IF NOT EXISTS idx_pg_rel_source_id ON relationships (source_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_rel_target_id ON relationships (target_id)",
    "CREATE INDEX IF NOT EXISTS idx_pg_rel_relation ON relationships (relation)",
    "CREATE INDEX IF NOT EXISTS idx_pg_rel_invalid_at ON relationships (invalid_at)",
];
60
/// `GraphStore` implementation backed by PostgreSQL through an `sqlx` pool.
///
/// Neither `new` nor `open` creates any tables; call `migrate` to create the
/// `entities` and `relationships` tables and their indexes.
pub struct PostgresGraphStore {
    /// Connection pool used for every query this store issues.
    pool: PgPool,
}
68
69impl PostgresGraphStore {
70 pub fn new(pool: PgPool) -> Self {
71 Self { pool }
72 }
73
74 pub async fn open(database_url: &str) -> Result<Self, sqlx::Error> {
76 let pool = PgPool::connect(database_url).await?;
77 Ok(Self { pool })
78 }
79
80 pub async fn migrate(&self) -> Result<(), sqlx::Error> {
82 for stmt in PG_GRAPH_STORE_DDL {
83 sqlx::query(stmt).execute(&self.pool).await?;
84 }
85 Ok(())
86 }
87}
88
/// Raw row of the `entities` table, as returned by `SELECT * FROM entities`.
///
/// `sqlx::FromRow` matches fields to columns by name, so these field names
/// must stay in sync with the `entities` DDL. Converted into an `Entity` by
/// `row_to_entity`.
#[derive(sqlx::FromRow)]
struct EntityRow {
    id: Uuid,
    name: String,
    /// Nullable column; `row_to_entity` substitutes `"unknown"` for NULL.
    entity_type: Option<String>,
    org_id: String,
    agent_id: Option<String>,
    user_id: Option<String>,
    session_id: Option<String>,
    /// JSONB attributes. Expected to be a JSON object, or JSON `null`, which is
    /// what `upsert_entity` writes for an empty attribute map.
    attributes: serde_json::Value,
    created_at: DateTime<Utc>,
    updated_at: DateTime<Utc>,
}
106
/// Raw row of the `relationships` table, as returned by
/// `SELECT * FROM relationships`.
///
/// `sqlx::FromRow` matches fields to columns by name, so these field names
/// must stay in sync with the `relationships` DDL. Converted into a
/// `Relationship` by `row_to_relationship`.
#[derive(sqlx::FromRow)]
struct RelationshipRow {
    id: Uuid,
    /// Entity the edge starts from.
    source_id: Uuid,
    /// Relation label of the edge.
    relation: String,
    /// Entity the edge points to.
    target_id: Uuid,
    org_id: String,
    agent_id: Option<String>,
    user_id: Option<String>,
    session_id: Option<String>,
    /// Start of the validity interval.
    valid_from: DateTime<Utc>,
    /// End of the validity interval; `None` (SQL NULL) means still valid.
    invalid_at: Option<DateTime<Utc>>,
    created_at: DateTime<Utc>,
}
121
122fn row_to_entity(row: EntityRow) -> Result<Entity, MemoryError> {
127 let attributes: serde_json::Map<String, serde_json::Value> = match &row.attributes {
128 serde_json::Value::Null => serde_json::Map::new(),
129 serde_json::Value::Object(map) => map.clone(),
130 other => serde_json::from_value(other.clone())
131 .map_err(|e| MemoryError::Serialization(e.to_string()))?,
132 };
133
134 Ok(Entity {
135 id: row.id,
136 name: row.name,
137 entity_type: row.entity_type.unwrap_or_else(|| "unknown".to_string()),
138 scope: Scope {
139 org_id: row.org_id,
140 agent_id: row.agent_id,
141 user_id: row.user_id,
142 session_id: row.session_id,
143 },
144 attributes,
145 created_at: row.created_at,
146 updated_at: row.updated_at,
147 })
148}
149
150fn row_to_relationship(row: RelationshipRow) -> Result<Relationship, MemoryError> {
151 Ok(Relationship {
152 id: row.id,
153 source_id: row.source_id,
154 relation: row.relation,
155 target_id: row.target_id,
156 scope: Scope {
157 org_id: row.org_id,
158 agent_id: row.agent_id,
159 user_id: row.user_id,
160 session_id: row.session_id,
161 },
162 valid_from: row.valid_from,
163 invalid_at: row.invalid_at,
164 created_at: row.created_at,
165 })
166}
167
168#[async_trait]
173impl GraphStore for PostgresGraphStore {
174 async fn upsert_entity(&self, entity: &Entity) -> Result<(), MemoryError> {
175 let attributes_json = if entity.attributes.is_empty() {
176 serde_json::Value::Null
177 } else {
178 serde_json::to_value(&entity.attributes)
179 .map_err(|e| MemoryError::Serialization(e.to_string()))?
180 };
181
182 sqlx::query(
183 r#"
184 INSERT INTO entities
185 (id, name, entity_type, org_id, agent_id, user_id, session_id,
186 attributes, created_at, updated_at)
187 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
188 ON CONFLICT(id) DO UPDATE SET
189 name = EXCLUDED.name,
190 entity_type = EXCLUDED.entity_type,
191 attributes = EXCLUDED.attributes,
192 updated_at = EXCLUDED.updated_at
193 "#,
194 )
195 .bind(entity.id)
196 .bind(&entity.name)
197 .bind(&entity.entity_type)
198 .bind(&entity.scope.org_id)
199 .bind(entity.scope.agent_id.as_deref())
200 .bind(entity.scope.user_id.as_deref())
201 .bind(entity.scope.session_id.as_deref())
202 .bind(&attributes_json)
203 .bind(entity.created_at)
204 .bind(entity.updated_at)
205 .execute(&self.pool)
206 .await
207 .map_err(|e| MemoryError::Database(e.to_string()))?;
208
209 Ok(())
210 }
211
212 async fn upsert_relationship(&self, rel: &Relationship) -> Result<(), MemoryError> {
213 sqlx::query(
214 r#"
215 INSERT INTO relationships
216 (id, source_id, relation, target_id, org_id, agent_id, user_id, session_id,
217 valid_from, invalid_at, created_at)
218 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
219 ON CONFLICT(id) DO UPDATE SET
220 relation = EXCLUDED.relation,
221 invalid_at = EXCLUDED.invalid_at
222 "#,
223 )
224 .bind(rel.id)
225 .bind(rel.source_id)
226 .bind(&rel.relation)
227 .bind(rel.target_id)
228 .bind(&rel.scope.org_id)
229 .bind(rel.scope.agent_id.as_deref())
230 .bind(rel.scope.user_id.as_deref())
231 .bind(rel.scope.session_id.as_deref())
232 .bind(rel.valid_from)
233 .bind(rel.invalid_at)
234 .bind(rel.created_at)
235 .execute(&self.pool)
236 .await
237 .map_err(|e| MemoryError::Database(e.to_string()))?;
238
239 Ok(())
240 }
241
242 async fn invalidate_relationship(
243 &self,
244 id: RelationshipId,
245 invalid_at: DateTime<Utc>,
246 ) -> Result<(), MemoryError> {
247 sqlx::query("UPDATE relationships SET invalid_at = $1 WHERE id = $2")
248 .bind(invalid_at)
249 .bind(id)
250 .execute(&self.pool)
251 .await
252 .map_err(|e| MemoryError::Database(e.to_string()))?;
253 Ok(())
254 }
255
256 async fn get_entity(&self, id: EntityId) -> Result<Option<Entity>, MemoryError> {
257 let row = sqlx::query_as::<_, EntityRow>("SELECT * FROM entities WHERE id = $1")
258 .bind(id)
259 .fetch_optional(&self.pool)
260 .await
261 .map_err(|e| MemoryError::Database(e.to_string()))?;
262
263 row.map(row_to_entity).transpose()
264 }
265
266 async fn neighbors(
267 &self,
268 id: EntityId,
269 depth: u8,
270 as_of: Option<DateTime<Utc>>,
271 ) -> Result<SubGraph, MemoryError> {
272 let validity_clause = match as_of {
274 Some(t) => {
275 let s = t.to_rfc3339();
276 format!("valid_from <= '{s}' AND (invalid_at IS NULL OR invalid_at > '{s}')")
277 }
278 None => "invalid_at IS NULL".to_string(),
279 };
280
281 let mut visited_entities: HashSet<EntityId> = HashSet::new();
282 visited_entities.insert(id);
283
284 let mut discovered_entities: HashMap<EntityId, Entity> = HashMap::new();
285 let mut discovered_relationships: HashMap<RelationshipId, Relationship> = HashMap::new();
286
287 let mut queue: VecDeque<(EntityId, u8)> = VecDeque::new();
289 queue.push_back((id, depth));
290
291 while let Some((current_id, remaining)) = queue.pop_front() {
292 if remaining == 0 {
293 continue;
294 }
295
296 let sql = format!(
298 "SELECT * FROM relationships WHERE (source_id = $1 OR target_id = $2) AND {validity_clause}"
299 );
300
301 let rel_rows = sqlx::query_as::<_, RelationshipRow>(&sql)
302 .bind(current_id)
303 .bind(current_id)
304 .fetch_all(&self.pool)
305 .await
306 .map_err(|e| MemoryError::Database(e.to_string()))?;
307
308 for row in rel_rows {
309 let rel = row_to_relationship(row)?;
310 let neighbor_id = if rel.source_id == current_id {
311 rel.target_id
312 } else {
313 rel.source_id
314 };
315
316 discovered_relationships.entry(rel.id).or_insert(rel);
318
319 if !visited_entities.contains(&neighbor_id) {
321 visited_entities.insert(neighbor_id);
322
323 if let Some(entity) = self.get_entity(neighbor_id).await? {
325 discovered_entities.entry(neighbor_id).or_insert(entity);
326 }
327
328 queue.push_back((neighbor_id, remaining - 1));
329 }
330 }
331 }
332
333 Ok(SubGraph {
334 entities: discovered_entities.into_values().collect(),
335 relationships: discovered_relationships.into_values().collect(),
336 })
337 }
338
339 async fn search_entities(&self, query: &str, top_k: usize) -> Result<Vec<Entity>, MemoryError> {
340 let sql = "SELECT * FROM entities WHERE name ILIKE $1 LIMIT $2";
341
342 let pattern = format!("%{}%", query.replace('%', "\\%").replace('_', "\\_"));
343
344 let rows = sqlx::query_as::<_, EntityRow>(sql)
345 .bind(&pattern)
346 .bind(top_k as i64)
347 .fetch_all(&self.pool)
348 .await
349 .map_err(|e| MemoryError::Database(e.to_string()))?;
350
351 rows.into_iter().map(row_to_entity).collect()
352 }
353
354 async fn delete_by_scope(&self, scope: &Scope) -> Result<u64, MemoryError> {
355 let mut wheres = vec![format!("org_id = '{}'", scope.org_id.replace('\'', "''"))];
357
358 if let Some(ref user_id) = scope.user_id {
359 wheres.push(format!("user_id = '{}'", user_id.replace('\'', "''")));
360 }
361 if let Some(ref agent_id) = scope.agent_id {
362 wheres.push(format!("agent_id = '{}'", agent_id.replace('\'', "''")));
363 }
364 if let Some(ref session_id) = scope.session_id {
365 wheres.push(format!("session_id = '{}'", session_id.replace('\'', "''")));
366 }
367
368 let where_clause = wheres.join(" AND ");
369
370 let rel_sql = format!("DELETE FROM relationships WHERE {where_clause}");
372 let rel_result = sqlx::query(&rel_sql)
373 .execute(&self.pool)
374 .await
375 .map_err(|e| MemoryError::Database(e.to_string()))?;
376
377 let ent_sql = format!("DELETE FROM entities WHERE {where_clause}");
379 let ent_result = sqlx::query(&ent_sql)
380 .execute(&self.pool)
381 .await
382 .map_err(|e| MemoryError::Database(e.to_string()))?;
383
384 Ok(rel_result.rows_affected() + ent_result.rows_affected())
385 }
386}