use std::sync::Arc;
use async_trait::async_trait;
use chrono::{DateTime, TimeZone, Utc};
use uuid::Uuid;
use khive_storage::error::StorageError;
use khive_storage::types::{
BatchWriteSummary, Edge, EdgeFilter, EdgeSortField, GraphPath, NeighborHit, NeighborQuery,
Page, PageRequest, PathNode, SortDirection, SortOrder, TraversalRequest,
};
use khive_storage::GraphStore;
use khive_storage::LinkId;
use khive_storage::StorageCapability;
use khive_types::EdgeRelation;
use crate::error::SqliteError;
use crate::pool::ConnectionPool;
fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
StorageError::driver(StorageCapability::Graph, op, e)
}
/// Wraps a crate-level [`SqliteError`] (e.g. from pool acquisition) as a
/// graph-capability driver error tagged with the operation name `op`.
fn map_sqlite_err(err: SqliteError, op: &'static str) -> StorageError {
    let capability = StorageCapability::Graph;
    StorageError::driver(capability, op, err)
}
/// SQLite-backed graph store whose edges are scoped to a single namespace.
pub struct SqlGraphStore {
    /// Value written to, and filtered on, the `namespace` column of every edge.
    namespace: String,
    /// Connection pool; in file-backed mode only its configuration is consulted
    /// to open standalone connections.
    pool: Arc<ConnectionPool>,
    /// `true`: each operation opens its own connection to the database file.
    /// `false`: operations borrow a connection from `pool`.
    is_file_backed: bool,
}
impl SqlGraphStore {
    /// Creates a graph store whose reads and writes are confined to `namespace`.
    ///
    /// When `is_file_backed` is true every operation opens its own short-lived
    /// connection to the database file configured on `pool`; otherwise
    /// connections are borrowed from `pool` itself.
    pub fn new_scoped(
        pool: Arc<ConnectionPool>,
        is_file_backed: bool,
        namespace: impl Into<String>,
    ) -> Self {
        Self {
            pool,
            is_file_backed,
            namespace: namespace.into(),
        }
    }

    /// Opens a dedicated connection to the database file configured on `pool`.
    ///
    /// `read_only` selects `SQLITE_OPEN_READ_ONLY` instead of
    /// `SQLITE_OPEN_READ_WRITE`. The connection gets the pool's busy timeout,
    /// `foreign_keys = ON` and `synchronous = NORMAL`.
    ///
    /// This performs blocking file I/O, so it must only be called from a
    /// blocking context (e.g. inside `spawn_blocking`).
    ///
    /// # Errors
    /// `StorageError::Pool` if the pool is configured for an in-memory database
    /// (no path); a driver error if opening or configuring the connection fails.
    fn open_standalone(
        pool: &ConnectionPool,
        read_only: bool,
    ) -> Result<rusqlite::Connection, StorageError> {
        let (pool_op, open_op, access) = if read_only {
            (
                "graph_reader",
                "open_graph_reader",
                rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY,
            )
        } else {
            (
                "graph_writer",
                "open_graph_writer",
                rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE,
            )
        };
        let config = pool.config();
        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
            operation: pool_op.into(),
            message: "in-memory databases do not support standalone connections".into(),
        })?;
        let conn = rusqlite::Connection::open_with_flags(
            path,
            access
                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
        )
        .map_err(|e| map_err(e, open_op))?;
        conn.busy_timeout(config.busy_timeout)
            .map_err(|e| map_err(e, open_op))?;
        conn.pragma_update(None, "foreign_keys", "ON")
            .map_err(|e| map_err(e, open_op))?;
        conn.pragma_update(None, "synchronous", "NORMAL")
            .map_err(|e| map_err(e, open_op))?;
        Ok(conn)
    }

    /// Runs `f` against a writable connection on Tokio's blocking thread pool.
    ///
    /// File-backed stores open a fresh standalone connection *inside* the
    /// blocking task, so file I/O and PRAGMA setup never run on the async
    /// executor. Otherwise the pool's writer is obtained via `try_writer`.
    /// `rusqlite` errors from `f` and task join failures are tagged with `op`.
    async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
    where
        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
        R: Send + 'static,
    {
        let pool = Arc::clone(&self.pool);
        let file_backed = self.is_file_backed;
        tokio::task::spawn_blocking(move || {
            if file_backed {
                let conn = Self::open_standalone(&pool, false)?;
                f(&conn).map_err(|e| map_err(e, op))
            } else {
                let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
                f(guard.conn()).map_err(|e| map_err(e, op))
            }
        })
        .await
        .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
    }

    /// Runs `f` against a read connection on Tokio's blocking thread pool.
    ///
    /// File-backed stores open a fresh read-only connection inside the
    /// blocking task; otherwise a pooled reader is used. `rusqlite` errors from
    /// `f` and task join failures are tagged with `op`.
    async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
    where
        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
        R: Send + 'static,
    {
        let pool = Arc::clone(&self.pool);
        let file_backed = self.is_file_backed;
        tokio::task::spawn_blocking(move || {
            if file_backed {
                let conn = Self::open_standalone(&pool, true)?;
                f(&conn).map_err(|e| map_err(e, op))
            } else {
                let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
                f(guard.conn()).map_err(|e| map_err(e, op))
            }
        })
        .await
        .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
    }
}
fn read_edge(row: &rusqlite::Row<'_>) -> Result<Edge, rusqlite::Error> {
let id_str: String = row.get(0)?;
let source_str: String = row.get(1)?;
let target_str: String = row.get(2)?;
let relation_str: String = row.get(3)?;
let weight: f64 = row.get(4)?;
let created_micros: i64 = row.get(5)?;
let metadata_str: Option<String> = row.get(6)?;
let id = parse_uuid(&id_str)?;
let source_id = parse_uuid(&source_str)?;
let target_id = parse_uuid(&target_str)?;
let created_at = micros_to_datetime(created_micros);
let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
rusqlite::Error::FromSqlConversionFailure(3, rusqlite::types::Type::Text, Box::new(e))
})?;
let metadata = metadata_str.and_then(|s| serde_json::from_str(&s).ok());
Ok(Edge {
id: id.into(),
source_id,
target_id,
relation,
weight,
created_at,
metadata,
})
}
fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
Uuid::parse_str(s).map_err(|e| {
rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
})
}
/// Converts microseconds since the Unix epoch to a UTC timestamp.
///
/// If chrono cannot map the value to a single UTC instant, the current time is
/// returned instead of failing the read.
fn micros_to_datetime(micros: i64) -> DateTime<Utc> {
    match Utc.timestamp_micros(micros).single() {
        Some(instant) => instant,
        None => Utc::now(),
    }
}
/// Builds the ` WHERE ...` clause and positional parameters for an [`EdgeFilter`].
///
/// `?1` is always bound to `namespace`. Every further condition appends its
/// values to the parameter vector and references them by 1-based position, so
/// the clause and the vector must be used together. Empty list filters and
/// absent bounds add no condition. The `created_at` range is start-inclusive
/// and end-exclusive, compared as microseconds since the Unix epoch.
fn build_edge_filter_sql(
    namespace: &str,
    filter: &EdgeFilter,
) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
    type Params = Vec<Box<dyn rusqlite::types::ToSql>>;

    /// Binds `value` as the next positional parameter and returns its `?N` placeholder.
    fn bind<T: rusqlite::types::ToSql + 'static>(params: &mut Params, value: T) -> String {
        params.push(Box::new(value));
        format!("?{}", params.len())
    }

    /// Appends `column IN (?a,?b,...)` over `values`; adds nothing when there are none.
    fn bind_in_list(
        column: &str,
        values: impl Iterator<Item = String>,
        conditions: &mut Vec<String>,
        params: &mut Params,
    ) {
        let mut placeholders = Vec::new();
        for value in values {
            placeholders.push(bind(params, value));
        }
        if !placeholders.is_empty() {
            conditions.push(format!("{} IN ({})", column, placeholders.join(",")));
        }
    }

    let mut params: Params = Vec::new();
    let namespace_slot = bind(&mut params, namespace.to_string());
    let mut conditions = vec![format!("namespace = {}", namespace_slot)];

    bind_in_list(
        "id",
        filter.ids.iter().map(|id| id.to_string()),
        &mut conditions,
        &mut params,
    );
    bind_in_list(
        "source_id",
        filter.source_ids.iter().map(|id| id.to_string()),
        &mut conditions,
        &mut params,
    );
    bind_in_list(
        "target_id",
        filter.target_ids.iter().map(|id| id.to_string()),
        &mut conditions,
        &mut params,
    );
    bind_in_list(
        "relation",
        filter.relations.iter().map(|r| r.to_string()),
        &mut conditions,
        &mut params,
    );

    if let Some(min_w) = filter.min_weight {
        let slot = bind(&mut params, min_w);
        conditions.push(format!("weight >= {}", slot));
    }
    if let Some(max_w) = filter.max_weight {
        let slot = bind(&mut params, max_w);
        conditions.push(format!("weight <= {}", slot));
    }
    if let Some(ref range) = filter.created_at {
        if let Some(start) = range.start {
            let slot = bind(&mut params, start.timestamp_micros());
            conditions.push(format!("created_at >= {}", slot));
        }
        if let Some(end) = range.end {
            let slot = bind(&mut params, end.timestamp_micros());
            conditions.push(format!("created_at < {}", slot));
        }
    }

    (format!(" WHERE {}", conditions.join(" AND ")), params)
}
/// Maps a sortable edge field to its `graph_edges` column name.
///
/// Only these fixed column names are ever returned, so the result is safe to
/// interpolate into an `ORDER BY` clause.
fn edge_sort_col(field: &EdgeSortField) -> &'static str {
    match *field {
        EdgeSortField::Relation => "relation",
        EdgeSortField::Weight => "weight",
        EdgeSortField::CreatedAt => "created_at",
    }
}
/// Upsert statement shared by `upsert_edge` and `upsert_edges`.
///
/// Parameters: `?1` namespace, `?2` id, `?3` source_id, `?4` target_id,
/// `?5` relation, `?6` weight, `?7` created_at (microseconds since the Unix
/// epoch), `?8` metadata JSON. A row with the same `(namespace, id)` is
/// overwritten in place. Inserting a new id whose
/// `(namespace, source_id, target_id, relation)` triple already exists is
/// skipped without error (`DO NOTHING`).
const UPSERT_EDGE_SQL: &str = "INSERT INTO graph_edges \
     (namespace, id, source_id, target_id, relation, weight, created_at, metadata) \
     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) \
     ON CONFLICT(namespace, id) DO UPDATE SET \
     source_id = excluded.source_id, \
     target_id = excluded.target_id, \
     relation = excluded.relation, \
     weight = excluded.weight, \
     created_at = excluded.created_at, \
     metadata = excluded.metadata \
     ON CONFLICT(namespace, source_id, target_id, relation) DO NOTHING";

/// Binds `edge` (scoped to `namespace`) to a statement prepared from
/// `UPSERT_EDGE_SQL` and executes it, returning the number of rows changed.
///
/// Metadata that fails to serialize is stored as an empty string.
fn execute_upsert(
    stmt: &mut rusqlite::Statement<'_>,
    namespace: &str,
    edge: &Edge,
) -> Result<usize, rusqlite::Error> {
    let metadata_str = edge
        .metadata
        .as_ref()
        .map(|v| serde_json::to_string(v).unwrap_or_default());
    stmt.execute(rusqlite::params![
        namespace,
        Uuid::from(edge.id).to_string(),
        edge.source_id.to_string(),
        edge.target_id.to_string(),
        edge.relation.to_string(),
        edge.weight,
        edge.created_at.timestamp_micros(),
        metadata_str,
    ])
}

#[async_trait]
impl GraphStore for SqlGraphStore {
    /// Inserts `edge`, or overwrites the row with the same id in this namespace.
    ///
    /// A new id whose `(source, target, relation)` triple already exists in the
    /// namespace is ignored without error.
    async fn upsert_edge(&self, edge: Edge) -> Result<(), StorageError> {
        let namespace = self.namespace.clone();
        self.with_writer("upsert_edge", move |conn| {
            let mut stmt = conn.prepare(UPSERT_EDGE_SQL)?;
            execute_upsert(&mut stmt, &namespace, &edge)?;
            Ok(())
        })
        .await
    }

    /// Upserts `edges` inside a single `BEGIN IMMEDIATE` transaction.
    ///
    /// A row that fails is counted in `failed` (the first message is kept in
    /// `first_error`) and the batch continues. `affected` counts statements that
    /// executed without error, including duplicate triples skipped by
    /// `DO NOTHING`. If `COMMIT` fails the transaction is rolled back and the
    /// error returned.
    async fn upsert_edges(&self, edges: Vec<Edge>) -> Result<BatchWriteSummary, StorageError> {
        let attempted = edges.len() as u64;
        let namespace = self.namespace.clone();
        self.with_writer("upsert_edges", move |conn| {
            conn.execute_batch("BEGIN IMMEDIATE")?;
            // Prepare once for the whole batch. On failure, roll back so the
            // connection (possibly a pooled writer) is not left mid-transaction.
            let mut stmt = match conn.prepare(UPSERT_EDGE_SQL) {
                Ok(stmt) => stmt,
                Err(e) => {
                    let _ = conn.execute_batch("ROLLBACK");
                    return Err(e);
                }
            };
            let mut affected = 0u64;
            let mut failed = 0u64;
            let mut first_error = String::new();
            for edge in &edges {
                match execute_upsert(&mut stmt, &namespace, edge) {
                    Ok(_) => affected += 1,
                    Err(e) => {
                        if first_error.is_empty() {
                            first_error = e.to_string();
                        }
                        failed += 1;
                    }
                }
            }
            // Release the prepared statement before ending the transaction.
            drop(stmt);
            if let Err(e) = conn.execute_batch("COMMIT") {
                let _ = conn.execute_batch("ROLLBACK");
                return Err(e);
            }
            Ok(BatchWriteSummary {
                attempted,
                affected,
                failed,
                first_error,
            })
        })
        .await
    }

    /// Fetches the edge with `id` in this namespace, or `None` if absent.
    async fn get_edge(&self, id: LinkId) -> Result<Option<Edge>, StorageError> {
        let namespace = self.namespace.clone();
        let id_str = Uuid::from(id).to_string();
        self.with_reader("get_edge", move |conn| {
            let mut stmt = conn.prepare(
                "SELECT id, source_id, target_id, relation, weight, created_at, metadata \
                 FROM graph_edges WHERE namespace = ?1 AND id = ?2",
            )?;
            let mut rows = stmt.query(rusqlite::params![namespace, id_str])?;
            match rows.next()? {
                Some(row) => Ok(Some(read_edge(row)?)),
                None => Ok(None),
            }
        })
        .await
    }

    /// Deletes the edge with `id` in this namespace; returns whether a row was removed.
    async fn delete_edge(&self, id: LinkId) -> Result<bool, StorageError> {
        let namespace = self.namespace.clone();
        let id_str = Uuid::from(id).to_string();
        self.with_writer("delete_edge", move |conn| {
            let deleted = conn.execute(
                "DELETE FROM graph_edges WHERE namespace = ?1 AND id = ?2",
                rusqlite::params![namespace, id_str],
            )?;
            Ok(deleted > 0)
        })
        .await
    }

    /// Returns one page of edges matching `filter`, plus the total match count.
    ///
    /// Without explicit `sort`, edges are ordered newest first. `id` is always
    /// appended as a final tie-breaker, so rows with equal sort keys keep a
    /// stable order across `LIMIT`/`OFFSET` pages.
    async fn query_edges(
        &self,
        filter: EdgeFilter,
        sort: Vec<SortOrder<EdgeSortField>>,
        page: PageRequest,
    ) -> Result<Page<Edge>, StorageError> {
        let namespace = self.namespace.clone();
        self.with_reader("query_edges", move |conn| {
            let (where_clause, filter_params) = build_edge_filter_sql(&namespace, &filter);

            let count_sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
            let total: i64 = {
                let mut stmt = conn.prepare(&count_sql)?;
                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
                    filter_params.iter().map(|p| p.as_ref()).collect();
                stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
            };

            let order_clause = if sort.is_empty() {
                " ORDER BY created_at DESC, id".to_string()
            } else {
                let parts: Vec<String> = sort
                    .iter()
                    .map(|s| {
                        let dir = match s.direction {
                            SortDirection::Asc => "ASC",
                            SortDirection::Desc => "DESC",
                        };
                        format!("{} {}", edge_sort_col(&s.field), dir)
                    })
                    .collect();
                format!(" ORDER BY {}, id", parts.join(", "))
            };

            // Reuse the filter parameters (?1..?N); LIMIT and OFFSET follow them.
            let mut all_params = filter_params;
            all_params.push(Box::new(page.limit as i64));
            all_params.push(Box::new(page.offset as i64));
            let limit_idx = all_params.len() - 1;
            let offset_idx = all_params.len();
            let data_sql = format!(
                "SELECT id, source_id, target_id, relation, weight, created_at, metadata \
                 FROM graph_edges{}{} LIMIT ?{} OFFSET ?{}",
                where_clause, order_clause, limit_idx, offset_idx,
            );
            let mut stmt = conn.prepare(&data_sql)?;
            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
                all_params.iter().map(|p| p.as_ref()).collect();
            let items = stmt
                .query_map(param_refs.as_slice(), read_edge)?
                .collect::<Result<Vec<_>, _>>()?;
            Ok(Page {
                items,
                total: Some(total as u64),
            })
        })
        .await
    }

    /// Counts edges in this namespace matching `filter`.
    async fn count_edges(&self, filter: EdgeFilter) -> Result<u64, StorageError> {
        let namespace = self.namespace.clone();
        self.with_reader("count_edges", move |conn| {
            let (where_clause, params) = build_edge_filter_sql(&namespace, &filter);
            let sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
            let mut stmt = conn.prepare(&sql)?;
            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
                params.iter().map(|p| p.as_ref()).collect();
            let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
            Ok(count as u64)
        })
        .await
    }

    /// Lists the direct neighbours of `node_id` in the requested direction.
    ///
    /// With `Direction::Both`, outgoing and incoming edges are combined with
    /// `UNION ALL`. Optional relation, minimum-weight and limit filters are
    /// applied to the combined set. No ordering is imposed.
    async fn neighbors(
        &self,
        node_id: Uuid,
        query: NeighborQuery,
    ) -> Result<Vec<NeighborHit>, StorageError> {
        use khive_storage::types::Direction;
        let namespace = self.namespace.clone();
        let node_str = node_id.to_string();
        self.with_reader("neighbors", move |conn| {
            let base_out = "SELECT target_id AS node_id, id AS edge_id, relation, weight \
                            FROM graph_edges WHERE namespace = ?1 AND source_id = ?2";
            let base_in = "SELECT source_id AS node_id, id AS edge_id, relation, weight \
                           FROM graph_edges WHERE namespace = ?1 AND target_id = ?2";
            let sql = match query.direction {
                Direction::Out => base_out.to_string(),
                Direction::In => base_in.to_string(),
                Direction::Both => format!("{} UNION ALL {}", base_out, base_in),
            };
            // ?1 and ?2 are fixed; optional filters are numbered from ?3.
            let mut conditions: Vec<String> = Vec::new();
            let mut extra_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
            let mut param_idx = 3;
            if let Some(ref rels) = query.relations {
                if !rels.is_empty() {
                    let placeholders: Vec<String> = rels
                        .iter()
                        .map(|r| {
                            extra_params.push(Box::new(r.to_string()));
                            let p = format!("?{}", param_idx);
                            param_idx += 1;
                            p
                        })
                        .collect();
                    conditions.push(format!("relation IN ({})", placeholders.join(",")));
                }
            }
            if let Some(min_w) = query.min_weight {
                extra_params.push(Box::new(min_w));
                conditions.push(format!("weight >= ?{}", param_idx));
                param_idx += 1;
            }
            let where_extra = if conditions.is_empty() {
                String::new()
            } else {
                format!(" WHERE {}", conditions.join(" AND "))
            };
            let limit_clause = if let Some(lim) = query.limit {
                extra_params.push(Box::new(lim as i64));
                format!(" LIMIT ?{}", param_idx)
            } else {
                String::new()
            };
            let full_sql = format!(
                "SELECT node_id, edge_id, relation, weight FROM ({}){}{}",
                sql, where_extra, limit_clause
            );
            let mut stmt = conn.prepare(&full_sql)?;
            let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
            all_params.push(Box::new(namespace.clone()));
            all_params.push(Box::new(node_str.clone()));
            all_params.extend(extra_params);
            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
                all_params.iter().map(|p| p.as_ref()).collect();
            let rows = stmt.query_map(param_refs.as_slice(), |row| {
                let nid_str: String = row.get(0)?;
                let eid_str: String = row.get(1)?;
                let relation_str: String = row.get(2)?;
                let weight: f64 = row.get(3)?;
                Ok((nid_str, eid_str, relation_str, weight))
            })?;
            let mut hits = Vec::new();
            for row in rows {
                let (nid_str, eid_str, relation_str, weight) = row?;
                let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
                    rusqlite::Error::FromSqlConversionFailure(
                        2,
                        rusqlite::types::Type::Text,
                        Box::new(e),
                    )
                })?;
                hits.push(NeighborHit {
                    node_id: parse_uuid(&nid_str)?,
                    edge_id: parse_uuid(&eid_str)?,
                    relation,
                    weight,
                    name: None,
                    kind: None,
                });
            }
            Ok(hits)
        })
        .await
    }

    /// Walks the graph from each root up to `options.max_depth` hops.
    ///
    /// Uses a recursive CTE that refuses to revisit a node already on the
    /// current path, which prevents cycles. Each reached path contributes one
    /// `PathNode`, so a node reachable by several paths can appear more than
    /// once. `total_weight` is the largest cumulative path weight seen, starting
    /// from 0.0. One `GraphPath` is produced per root that has any nodes; with
    /// `include_roots` that is every root.
    async fn traverse(&self, request: TraversalRequest) -> Result<Vec<GraphPath>, StorageError> {
        use khive_storage::types::Direction;
        if request.roots.is_empty() {
            return Ok(Vec::new());
        }
        let roots = request.roots.clone();
        let opts = request.options.clone();
        let include_roots = request.include_roots;
        let namespace = self.namespace.clone();
        self.with_reader("traverse", move |conn| {
            let (join_condition, next_node) = match opts.direction {
                Direction::Out => ("e.source_id = t.node_id", "e.target_id"),
                Direction::In => ("e.target_id = t.node_id", "e.source_id"),
                Direction::Both => (
                    "(e.source_id = t.node_id OR e.target_id = t.node_id)",
                    "CASE WHEN e.source_id = t.node_id THEN e.target_id ELSE e.source_id END",
                ),
            };

            // ?1 namespace, ?2 root and ?3 max depth are fixed; optional filters
            // are numbered from ?4 in the order they are pushed below.
            let mut extra_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
            let mut param_idx = 4;
            let mut relation_cond = String::new();
            if let Some(ref rels) = opts.relations {
                if !rels.is_empty() {
                    let placeholders: Vec<String> = rels
                        .iter()
                        .map(|r| {
                            extra_params.push(Box::new(r.to_string()));
                            let p = format!("?{}", param_idx);
                            param_idx += 1;
                            p
                        })
                        .collect();
                    relation_cond = format!(" AND e.relation IN ({})", placeholders.join(","));
                }
            }
            let mut weight_cond = String::new();
            if let Some(min_w) = opts.min_weight {
                extra_params.push(Box::new(min_w));
                weight_cond = format!(" AND e.weight >= ?{}", param_idx);
                param_idx += 1;
            }
            let limit_clause = if let Some(lim) = opts.limit {
                extra_params.push(Box::new(lim as i64));
                format!(" LIMIT ?{}", param_idx)
            } else {
                String::new()
            };
            let cte_sql = format!(
                "WITH RECURSIVE traversal(node_id, edge_id, depth, path, total_weight) AS (\
                 SELECT ?2, NULL, 0, ?2, 0.0 \
                 UNION ALL \
                 SELECT {next_node}, e.id, t.depth + 1, \
                 t.path || ',' || {next_node}, \
                 t.total_weight + e.weight \
                 FROM graph_edges e \
                 JOIN traversal t ON {join_condition} \
                 WHERE e.namespace = ?1 \
                 AND t.depth < ?3 \
                 AND (',' || t.path || ',') NOT LIKE '%,' || {next_node} || ',%'{rel_cond}{wt_cond} \
                 ) \
                 SELECT node_id, edge_id, depth, path, total_weight \
                 FROM traversal WHERE depth > 0 \
                 ORDER BY depth{limit}",
                next_node = next_node,
                join_condition = join_condition,
                rel_cond = relation_cond,
                wt_cond = weight_cond,
                limit = limit_clause,
            );
            // The SQL text and every parameter except ?2 are identical for all
            // roots, so the statement is prepared once and only ?2 is rebound.
            let mut stmt = conn.prepare(&cte_sql)?;
            let max_depth = opts.max_depth as i64;

            let mut all_paths: Vec<GraphPath> = Vec::new();
            for root_id in &roots {
                let root_str = root_id.to_string();
                let mut param_refs: Vec<&dyn rusqlite::types::ToSql> =
                    Vec::with_capacity(3 + extra_params.len());
                param_refs.push(&namespace);
                param_refs.push(&root_str);
                param_refs.push(&max_depth);
                for p in &extra_params {
                    param_refs.push(p.as_ref());
                }

                let rows = stmt.query_map(param_refs.as_slice(), |row| {
                    let node_str: String = row.get(0)?;
                    let edge_str: Option<String> = row.get(1)?;
                    let depth: i64 = row.get(2)?;
                    let total_weight: f64 = row.get(4)?;
                    Ok((node_str, edge_str, depth, total_weight))
                })?;

                let mut nodes = Vec::new();
                let mut max_weight = 0.0f64;
                if include_roots {
                    nodes.push(PathNode {
                        node_id: *root_id,
                        via_edge: None,
                        depth: 0,
                        name: None,
                        kind: None,
                    });
                }
                for row in rows {
                    let (node_str, edge_str, depth, total_weight) = row?;
                    let node_id = parse_uuid(&node_str)?;
                    let via_edge = edge_str.map(|s| parse_uuid(&s)).transpose()?;
                    nodes.push(PathNode {
                        node_id,
                        via_edge,
                        depth: depth as usize,
                        name: None,
                        kind: None,
                    });
                    if total_weight > max_weight {
                        max_weight = total_weight;
                    }
                }
                // With `include_roots` the root alone makes `nodes` non-empty, so a
                // path is always emitted; otherwise only roots that reached at
                // least one node produce a path.
                if !nodes.is_empty() {
                    all_paths.push(GraphPath {
                        root_id: *root_id,
                        nodes,
                        total_weight: max_weight,
                    });
                }
            }
            Ok(all_paths)
        })
        .await
    }
}
/// Schema for the graph edge table and its lookup indexes, as one SQL batch.
///
/// Every statement uses `IF NOT EXISTS`, so the batch can be re-run against an
/// already-initialised database. Edges are keyed by `(namespace, id)`. A unique
/// index rejects a second edge with the same
/// `(namespace, source_id, target_id, relation)` triple.
const GRAPH_DDL: &str = concat!(
    "CREATE TABLE IF NOT EXISTS graph_edges (",
    "namespace TEXT NOT NULL,",
    "id TEXT NOT NULL,",
    "source_id TEXT NOT NULL,",
    "target_id TEXT NOT NULL,",
    "relation TEXT NOT NULL,",
    "weight REAL NOT NULL DEFAULT 1.0,",
    "created_at INTEGER NOT NULL,",
    "metadata TEXT,",
    "PRIMARY KEY (namespace, id)",
    ");",
    "CREATE UNIQUE INDEX IF NOT EXISTS idx_graph_edges_unique_triple ON graph_edges(namespace, source_id, target_id, relation);",
    // NOTE(review): this index is a column prefix of the unique triple index
    // above and looks redundant; confirm query plans before dropping it.
    "CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_source ON graph_edges(namespace, source_id);",
    "CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_target ON graph_edges(namespace, target_id);",
    "CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_relation ON graph_edges(namespace, relation);",
    "CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_src_rel ON graph_edges(namespace, source_id, relation);",
    "CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_tgt_rel ON graph_edges(namespace, target_id, relation);",
);
/// Creates the `graph_edges` table and its indexes on `conn` if missing.
///
/// Safe to call repeatedly, because every statement in the schema batch uses
/// `IF NOT EXISTS`.
pub(crate) fn ensure_graph_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
    conn.execute_batch(GRAPH_DDL)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;
    use khive_storage::types::{Direction, TraversalOptions};

    /// In-memory pool with the graph schema installed through the same
    /// `ensure_graph_schema` entry point production code uses.
    fn setup_memory_pool() -> Arc<ConnectionPool> {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());
        {
            let writer = pool.writer().unwrap();
            ensure_graph_schema(writer.conn()).unwrap();
        }
        pool
    }

    fn setup_memory_store() -> SqlGraphStore {
        SqlGraphStore::new_scoped(setup_memory_pool(), false, "default")
    }

    fn make_edge(source: Uuid, target: Uuid, relation: EdgeRelation, weight: f64) -> Edge {
        Edge {
            id: Uuid::new_v4().into(),
            source_id: source,
            target_id: target,
            relation,
            weight,
            created_at: Utc::now(),
            metadata: None,
        }
    }

    /// Fixed, in-range timestamp for deterministic ordering tests.
    fn at_micros(micros: i64) -> DateTime<Utc> {
        Utc.timestamp_micros(micros)
            .single()
            .expect("test timestamp must be in range")
    }

    #[tokio::test]
    async fn test_upsert_and_get_edge() {
        let store = setup_memory_store();
        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let edge = Edge {
            id: Uuid::new_v4().into(),
            source_id: src,
            target_id: tgt,
            relation: EdgeRelation::Extends,
            weight: 0.8,
            created_at: Utc::now(),
            metadata: None,
        };
        let edge_id = edge.id;
        store.upsert_edge(edge).await.unwrap();
        let fetched = store.get_edge(edge_id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, edge_id);
        assert_eq!(fetched.source_id, src);
        assert_eq!(fetched.target_id, tgt);
        assert_eq!(fetched.relation, EdgeRelation::Extends);
        assert!((fetched.weight - 0.8).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_upsert_same_id_updates_in_place() {
        let store = setup_memory_store();
        let src = Uuid::new_v4();
        let first = make_edge(src, Uuid::new_v4(), EdgeRelation::Extends, 0.25);
        let edge_id = first.id;
        store.upsert_edge(first).await.unwrap();

        let new_target = Uuid::new_v4();
        let mut second = make_edge(src, new_target, EdgeRelation::DependsOn, 0.75);
        second.id = edge_id;
        store.upsert_edge(second).await.unwrap();

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 1);
        let fetched = store.get_edge(edge_id).await.unwrap().unwrap();
        assert_eq!(fetched.target_id, new_target);
        assert_eq!(fetched.relation, EdgeRelation::DependsOn);
        assert!((fetched.weight - 0.75).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_delete_edge() {
        let store = setup_memory_store();
        let edge = make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Contains, 1.0);
        let edge_id = edge.id;
        store.upsert_edge(edge).await.unwrap();
        assert!(store.get_edge(edge_id).await.unwrap().is_some());
        let deleted = store.delete_edge(edge_id).await.unwrap();
        assert!(deleted);
        assert!(store.get_edge(edge_id).await.unwrap().is_none());
        let deleted_again = store.delete_edge(edge_id).await.unwrap();
        assert!(!deleted_again);
    }

    #[tokio::test]
    async fn test_namespaces_are_isolated() {
        let pool = setup_memory_pool();
        let alpha = SqlGraphStore::new_scoped(Arc::clone(&pool), false, "alpha");
        let beta = SqlGraphStore::new_scoped(pool, false, "beta");
        let edge = make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Contains, 1.0);
        let edge_id = edge.id;
        alpha.upsert_edge(edge).await.unwrap();

        assert!(alpha.get_edge(edge_id).await.unwrap().is_some());
        assert!(beta.get_edge(edge_id).await.unwrap().is_none());
        assert_eq!(beta.count_edges(EdgeFilter::default()).await.unwrap(), 0);
        assert!(!beta.delete_edge(edge_id).await.unwrap());
        assert_eq!(alpha.count_edges(EdgeFilter::default()).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_count_edges() {
        let store = setup_memory_store();
        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 0);
        for _ in 0..5 {
            store
                .upsert_edge(make_edge(
                    Uuid::new_v4(),
                    Uuid::new_v4(),
                    EdgeRelation::DependsOn,
                    1.0,
                ))
                .await
                .unwrap();
        }
        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_count_edges_min_weight_filter() {
        let store = setup_memory_store();
        for weight in [0.5, 0.8, 1.0] {
            store
                .upsert_edge(make_edge(
                    Uuid::new_v4(),
                    Uuid::new_v4(),
                    EdgeRelation::DependsOn,
                    weight,
                ))
                .await
                .unwrap();
        }
        let filter = EdgeFilter {
            min_weight: Some(0.75),
            ..EdgeFilter::default()
        };
        assert_eq!(store.count_edges(filter).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_query_edges_default_order_and_total() {
        let store = setup_memory_store();
        let mut ids = Vec::new();
        for micros in [1_000, 2_000, 3_000] {
            let mut edge = make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Extends, 1.0);
            edge.created_at = at_micros(micros);
            ids.push(edge.id);
            store.upsert_edge(edge).await.unwrap();
        }
        let page = store
            .query_edges(EdgeFilter::default(), vec![], PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page.total, Some(3));
        let returned: Vec<_> = page.items.iter().map(|e| e.id).collect();
        assert_eq!(returned, vec![ids[2], ids[1], ids[0]], "default order is newest first");
    }

    #[tokio::test]
    async fn test_neighbors_outbound() {
        let store = setup_memory_store();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();
        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(a, c, EdgeRelation::DependsOn, 0.7))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(d, a, EdgeRelation::Extends, 0.5))
            .await
            .unwrap();
        let query = NeighborQuery {
            direction: Direction::Out,
            relations: None,
            limit: None,
            min_weight: None,
        };
        let hits = store.neighbors(a, query).await.unwrap();
        assert_eq!(hits.len(), 2);
        let neighbor_ids: Vec<Uuid> = hits.iter().map(|h| h.node_id).collect();
        assert!(neighbor_ids.contains(&b));
        assert!(neighbor_ids.contains(&c));
        assert!(!neighbor_ids.contains(&d));
    }

    #[tokio::test]
    async fn test_neighbors_both_with_weight_and_limit() {
        let store = setup_memory_store();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();
        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(a, c, EdgeRelation::DependsOn, 0.7))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(d, a, EdgeRelation::Extends, 0.9))
            .await
            .unwrap();

        let all = store
            .neighbors(
                a,
                NeighborQuery {
                    direction: Direction::Both,
                    relations: None,
                    limit: None,
                    min_weight: None,
                },
            )
            .await
            .unwrap();
        let mut all_ids: Vec<Uuid> = all.iter().map(|h| h.node_id).collect();
        all_ids.sort();
        let mut expected = vec![b, c, d];
        expected.sort();
        assert_eq!(all_ids, expected);

        let heavy = store
            .neighbors(
                a,
                NeighborQuery {
                    direction: Direction::Both,
                    relations: None,
                    limit: None,
                    min_weight: Some(0.8),
                },
            )
            .await
            .unwrap();
        let heavy_ids: Vec<Uuid> = heavy.iter().map(|h| h.node_id).collect();
        assert_eq!(heavy_ids.len(), 2);
        assert!(heavy_ids.contains(&b));
        assert!(heavy_ids.contains(&d));

        let limited = store
            .neighbors(
                a,
                NeighborQuery {
                    direction: Direction::Out,
                    relations: None,
                    limit: Some(1),
                    min_weight: None,
                },
            )
            .await
            .unwrap();
        assert_eq!(limited.len(), 1);
    }

    #[tokio::test]
    async fn test_traverse_depth_2() {
        let store = setup_memory_store();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();
        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(b, c, EdgeRelation::Extends, 2.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(c, d, EdgeRelation::Extends, 3.0))
            .await
            .unwrap();
        let request = TraversalRequest {
            roots: vec![a],
            options: TraversalOptions::new(2).with_direction(Direction::Out),
            include_roots: true,
        };
        let paths = store.traverse(request).await.unwrap();
        assert_eq!(paths.len(), 1);
        let path = &paths[0];
        let node_ids: Vec<Uuid> = path.nodes.iter().map(|n| n.node_id).collect();
        assert!(node_ids.contains(&a));
        assert!(node_ids.contains(&b));
        assert!(node_ids.contains(&c));
        assert!(!node_ids.contains(&d));
    }

    #[tokio::test]
    async fn test_traverse_multiple_roots_without_include_roots() {
        let store = setup_memory_store();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();
        let isolated = Uuid::new_v4();
        let ab = make_edge(a, b, EdgeRelation::Extends, 1.0);
        let ab_id = ab.id;
        store.upsert_edge(ab).await.unwrap();
        store
            .upsert_edge(make_edge(c, d, EdgeRelation::Extends, 2.0))
            .await
            .unwrap();

        let request = TraversalRequest {
            roots: vec![a, c, isolated],
            options: TraversalOptions::new(1).with_direction(Direction::Out),
            include_roots: false,
        };
        let paths = store.traverse(request).await.unwrap();
        assert_eq!(paths.len(), 2, "a root with no reachable nodes yields no path");

        assert_eq!(paths[0].root_id, a);
        assert_eq!(paths[0].nodes.len(), 1);
        assert_eq!(paths[0].nodes[0].node_id, b);
        assert_eq!(paths[0].nodes[0].depth, 1);
        assert_eq!(paths[0].nodes[0].via_edge, Some(Uuid::from(ab_id)));
        assert!((paths[0].total_weight - 1.0).abs() < 1e-9);

        assert_eq!(paths[1].root_id, c);
        assert_eq!(paths[1].nodes.len(), 1);
        assert_eq!(paths[1].nodes[0].node_id, d);
    }

    #[tokio::test]
    async fn test_metadata_roundtrip() {
        let store = setup_memory_store();
        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let meta = serde_json::json!({"note": "important link", "confidence": 0.95});
        let edge = Edge {
            id: Uuid::new_v4().into(),
            source_id: src,
            target_id: tgt,
            relation: EdgeRelation::Implements,
            weight: 0.9,
            created_at: Utc::now(),
            metadata: Some(meta.clone()),
        };
        let edge_id = edge.id;
        store.upsert_edge(edge).await.unwrap();
        let fetched = store.get_edge(edge_id).await.unwrap().unwrap();
        assert_eq!(
            fetched.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via get_edge"
        );
        let page = store
            .query_edges(EdgeFilter::default(), vec![], PageRequest::default())
            .await
            .unwrap();
        let from_query = page
            .items
            .iter()
            .find(|e| e.id == edge_id)
            .expect("edge must appear in query_edges result");
        assert_eq!(
            from_query.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via query_edges"
        );
    }

    #[tokio::test]
    async fn test_upsert_edges_batch() {
        let store = setup_memory_store();
        let edges: Vec<Edge> = (0..10)
            .map(|i| {
                make_edge(
                    Uuid::new_v4(),
                    Uuid::new_v4(),
                    EdgeRelation::Implements,
                    i as f64,
                )
            })
            .collect();
        let summary = store.upsert_edges(edges).await.unwrap();
        assert_eq!(summary.attempted, 10);
        assert_eq!(summary.affected, 10);
        assert_eq!(summary.failed, 0);
        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 10);
    }

    #[tokio::test]
    async fn test_upsert_edges_batch_skips_duplicate_triple() {
        let store = setup_memory_store();
        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let edges = vec![
            make_edge(src, tgt, EdgeRelation::Extends, 1.0),
            make_edge(src, tgt, EdgeRelation::Extends, 0.5),
        ];
        let summary = store.upsert_edges(edges).await.unwrap();
        assert_eq!(summary.attempted, 2);
        assert_eq!(summary.failed, 0);
        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn graph_duplicate_edges_ignored() {
        let store = setup_memory_store();
        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let edge1 = Edge {
            id: Uuid::new_v4().into(),
            source_id: src,
            target_id: tgt,
            relation: EdgeRelation::Extends,
            weight: 1.0,
            created_at: Utc::now(),
            metadata: None,
        };
        let edge2 = Edge {
            id: Uuid::new_v4().into(),
            source_id: src,
            target_id: tgt,
            relation: EdgeRelation::Extends,
            weight: 0.5,
            created_at: Utc::now(),
            metadata: None,
        };
        store.upsert_edge(edge1).await.unwrap();
        store.upsert_edge(edge2).await.unwrap();
        assert_eq!(
            store.count_edges(EdgeFilter::default()).await.unwrap(),
            1,
            "duplicate (source, target, relation) triple must be ignored; only one edge must exist"
        );
    }
}