use crate::fact::{Entity, EntityId, Relationship, RelationshipId, SubGraph};
use crate::graph::GraphStore;
use crate::scope::Scope;
use crate::store::MemoryError;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use std::collections::{HashMap, HashSet, VecDeque};
use uuid::Uuid;
const PG_GRAPH_STORE_DDL: &[&str] = &[
r#"
CREATE TABLE IF NOT EXISTS entities (
id UUID PRIMARY KEY,
name TEXT NOT NULL,
entity_type TEXT,
org_id TEXT NOT NULL DEFAULT 'default',
agent_id TEXT,
user_id TEXT,
session_id TEXT,
attributes JSONB NOT NULL DEFAULT 'null',
created_at TIMESTAMPTZ NOT NULL,
updated_at TIMESTAMPTZ NOT NULL
)
"#,
"CREATE INDEX IF NOT EXISTS idx_pg_entities_name ON entities (name)",
"CREATE INDEX IF NOT EXISTS idx_pg_entities_user_id ON entities (user_id)",
"CREATE INDEX IF NOT EXISTS idx_pg_entities_org_id ON entities (org_id)",
r#"
CREATE TABLE IF NOT EXISTS relationships (
id UUID PRIMARY KEY,
source_id UUID NOT NULL,
relation TEXT NOT NULL,
target_id UUID NOT NULL,
org_id TEXT NOT NULL DEFAULT 'default',
agent_id TEXT,
user_id TEXT,
session_id TEXT,
valid_from TIMESTAMPTZ NOT NULL,
invalid_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL
)
"#,
"CREATE INDEX IF NOT EXISTS idx_pg_rel_source_id ON relationships (source_id)",
"CREATE INDEX IF NOT EXISTS idx_pg_rel_target_id ON relationships (target_id)",
"CREATE INDEX IF NOT EXISTS idx_pg_rel_relation ON relationships (relation)",
"CREATE INDEX IF NOT EXISTS idx_pg_rel_invalid_at ON relationships (invalid_at)",
];
pub struct PostgresGraphStore {
pool: PgPool,
}
impl PostgresGraphStore {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
pub async fn open(database_url: &str) -> Result<Self, sqlx::Error> {
let pool = PgPool::connect(database_url).await?;
Ok(Self { pool })
}
pub async fn migrate(&self) -> Result<(), sqlx::Error> {
for stmt in PG_GRAPH_STORE_DDL {
sqlx::query(stmt).execute(&self.pool).await?;
}
Ok(())
}
}
#[derive(sqlx::FromRow)]
struct EntityRow {
id: Uuid,
name: String,
entity_type: Option<String>,
org_id: String,
agent_id: Option<String>,
user_id: Option<String>,
session_id: Option<String>,
attributes: serde_json::Value,
created_at: DateTime<Utc>,
updated_at: DateTime<Utc>,
}
#[derive(sqlx::FromRow)]
struct RelationshipRow {
id: Uuid,
source_id: Uuid,
relation: String,
target_id: Uuid,
org_id: String,
agent_id: Option<String>,
user_id: Option<String>,
session_id: Option<String>,
valid_from: DateTime<Utc>,
invalid_at: Option<DateTime<Utc>>,
created_at: DateTime<Utc>,
}
fn row_to_entity(row: EntityRow) -> Result<Entity, MemoryError> {
let attributes: serde_json::Map<String, serde_json::Value> = match &row.attributes {
serde_json::Value::Null => serde_json::Map::new(),
serde_json::Value::Object(map) => map.clone(),
other => serde_json::from_value(other.clone())
.map_err(|e| MemoryError::Serialization(e.to_string()))?,
};
Ok(Entity {
id: row.id,
name: row.name,
entity_type: row.entity_type.unwrap_or_else(|| "unknown".to_string()),
scope: Scope {
org_id: row.org_id,
agent_id: row.agent_id,
user_id: row.user_id,
session_id: row.session_id,
},
attributes,
created_at: row.created_at,
updated_at: row.updated_at,
})
}
fn row_to_relationship(row: RelationshipRow) -> Result<Relationship, MemoryError> {
Ok(Relationship {
id: row.id,
source_id: row.source_id,
relation: row.relation,
target_id: row.target_id,
scope: Scope {
org_id: row.org_id,
agent_id: row.agent_id,
user_id: row.user_id,
session_id: row.session_id,
},
valid_from: row.valid_from,
invalid_at: row.invalid_at,
created_at: row.created_at,
})
}
#[async_trait]
impl GraphStore for PostgresGraphStore {
async fn upsert_entity(&self, entity: &Entity) -> Result<(), MemoryError> {
let attributes_json = if entity.attributes.is_empty() {
serde_json::Value::Null
} else {
serde_json::to_value(&entity.attributes)
.map_err(|e| MemoryError::Serialization(e.to_string()))?
};
sqlx::query(
r#"
INSERT INTO entities
(id, name, entity_type, org_id, agent_id, user_id, session_id,
attributes, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
ON CONFLICT(id) DO UPDATE SET
name = EXCLUDED.name,
entity_type = EXCLUDED.entity_type,
attributes = EXCLUDED.attributes,
updated_at = EXCLUDED.updated_at
"#,
)
.bind(entity.id)
.bind(&entity.name)
.bind(&entity.entity_type)
.bind(&entity.scope.org_id)
.bind(entity.scope.agent_id.as_deref())
.bind(entity.scope.user_id.as_deref())
.bind(entity.scope.session_id.as_deref())
.bind(&attributes_json)
.bind(entity.created_at)
.bind(entity.updated_at)
.execute(&self.pool)
.await
.map_err(|e| MemoryError::Database(e.to_string()))?;
Ok(())
}
async fn upsert_relationship(&self, rel: &Relationship) -> Result<(), MemoryError> {
sqlx::query(
r#"
INSERT INTO relationships
(id, source_id, relation, target_id, org_id, agent_id, user_id, session_id,
valid_from, invalid_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
ON CONFLICT(id) DO UPDATE SET
relation = EXCLUDED.relation,
invalid_at = EXCLUDED.invalid_at
"#,
)
.bind(rel.id)
.bind(rel.source_id)
.bind(&rel.relation)
.bind(rel.target_id)
.bind(&rel.scope.org_id)
.bind(rel.scope.agent_id.as_deref())
.bind(rel.scope.user_id.as_deref())
.bind(rel.scope.session_id.as_deref())
.bind(rel.valid_from)
.bind(rel.invalid_at)
.bind(rel.created_at)
.execute(&self.pool)
.await
.map_err(|e| MemoryError::Database(e.to_string()))?;
Ok(())
}
async fn invalidate_relationship(
&self,
id: RelationshipId,
invalid_at: DateTime<Utc>,
) -> Result<(), MemoryError> {
sqlx::query("UPDATE relationships SET invalid_at = $1 WHERE id = $2")
.bind(invalid_at)
.bind(id)
.execute(&self.pool)
.await
.map_err(|e| MemoryError::Database(e.to_string()))?;
Ok(())
}
async fn get_entity(&self, id: EntityId) -> Result<Option<Entity>, MemoryError> {
let row = sqlx::query_as::<_, EntityRow>("SELECT * FROM entities WHERE id = $1")
.bind(id)
.fetch_optional(&self.pool)
.await
.map_err(|e| MemoryError::Database(e.to_string()))?;
row.map(row_to_entity).transpose()
}
async fn neighbors(
&self,
id: EntityId,
depth: u8,
as_of: Option<DateTime<Utc>>,
) -> Result<SubGraph, MemoryError> {
let validity_clause = match as_of {
Some(t) => {
let s = t.to_rfc3339();
format!("valid_from <= '{s}' AND (invalid_at IS NULL OR invalid_at > '{s}')")
}
None => "invalid_at IS NULL".to_string(),
};
let mut visited_entities: HashSet<EntityId> = HashSet::new();
visited_entities.insert(id);
let mut discovered_entities: HashMap<EntityId, Entity> = HashMap::new();
let mut discovered_relationships: HashMap<RelationshipId, Relationship> = HashMap::new();
let mut queue: VecDeque<(EntityId, u8)> = VecDeque::new();
queue.push_back((id, depth));
while let Some((current_id, remaining)) = queue.pop_front() {
if remaining == 0 {
continue;
}
let sql = format!(
"SELECT * FROM relationships WHERE (source_id = $1 OR target_id = $2) AND {validity_clause}"
);
let rel_rows = sqlx::query_as::<_, RelationshipRow>(&sql)
.bind(current_id)
.bind(current_id)
.fetch_all(&self.pool)
.await
.map_err(|e| MemoryError::Database(e.to_string()))?;
for row in rel_rows {
let rel = row_to_relationship(row)?;
let neighbor_id = if rel.source_id == current_id {
rel.target_id
} else {
rel.source_id
};
discovered_relationships.entry(rel.id).or_insert(rel);
if !visited_entities.contains(&neighbor_id) {
visited_entities.insert(neighbor_id);
if let Some(entity) = self.get_entity(neighbor_id).await? {
discovered_entities.entry(neighbor_id).or_insert(entity);
}
queue.push_back((neighbor_id, remaining - 1));
}
}
}
Ok(SubGraph {
entities: discovered_entities.into_values().collect(),
relationships: discovered_relationships.into_values().collect(),
})
}
async fn search_entities(&self, query: &str, top_k: usize) -> Result<Vec<Entity>, MemoryError> {
let sql = "SELECT * FROM entities WHERE name ILIKE $1 LIMIT $2";
let pattern = format!("%{}%", query.replace('%', "\\%").replace('_', "\\_"));
let rows = sqlx::query_as::<_, EntityRow>(sql)
.bind(&pattern)
.bind(top_k as i64)
.fetch_all(&self.pool)
.await
.map_err(|e| MemoryError::Database(e.to_string()))?;
rows.into_iter().map(row_to_entity).collect()
}
async fn delete_by_scope(&self, scope: &Scope) -> Result<u64, MemoryError> {
let mut wheres = vec![format!("org_id = '{}'", scope.org_id.replace('\'', "''"))];
if let Some(ref user_id) = scope.user_id {
wheres.push(format!("user_id = '{}'", user_id.replace('\'', "''")));
}
if let Some(ref agent_id) = scope.agent_id {
wheres.push(format!("agent_id = '{}'", agent_id.replace('\'', "''")));
}
if let Some(ref session_id) = scope.session_id {
wheres.push(format!("session_id = '{}'", session_id.replace('\'', "''")));
}
let where_clause = wheres.join(" AND ");
let rel_sql = format!("DELETE FROM relationships WHERE {where_clause}");
let rel_result = sqlx::query(&rel_sql)
.execute(&self.pool)
.await
.map_err(|e| MemoryError::Database(e.to_string()))?;
let ent_sql = format!("DELETE FROM entities WHERE {where_clause}");
let ent_result = sqlx::query(&ent_sql)
.execute(&self.pool)
.await
.map_err(|e| MemoryError::Database(e.to_string()))?;
Ok(rel_result.rows_affected() + ent_result.rows_affected())
}
}