Skip to main content

khive_db/stores/
graph.rs

1//! SQL-backed `GraphStore` implementation.
2//!
3//! `SqlGraphStore` stores graph edges in a regular SQLite table.
4//! Traversal uses recursive CTEs for multi-hop queries.
5//!
6//! # Connection strategy
7//!
8//! - **File-backed**: Opens standalone connections per operation.
9//! - **In-memory**: Acquires pool connections per operation via `spawn_blocking`.
10
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use chrono::{DateTime, TimeZone, Utc};
15use uuid::Uuid;
16
17use khive_storage::error::StorageError;
18use khive_storage::types::{
19    BatchWriteSummary, Edge, EdgeFilter, EdgeSortField, GraphPath, NeighborHit, NeighborQuery,
20    Page, PageRequest, PathNode, SortDirection, SortOrder, TraversalRequest,
21};
22use khive_storage::GraphStore;
23use khive_storage::LinkId;
24use khive_storage::StorageCapability;
25use khive_types::EdgeRelation;
26
27use crate::error::SqliteError;
28use crate::pool::ConnectionPool;
29
30/// Map a rusqlite error to `StorageError` with `Graph` capability.
31fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
32    StorageError::driver(StorageCapability::Graph, op, e)
33}
34
35fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
36    StorageError::driver(StorageCapability::Graph, op, e)
37}
38
39/// A GraphStore backed by SQLite tables.
40pub struct SqlGraphStore {
41    pool: Arc<ConnectionPool>,
42    is_file_backed: bool,
43    namespace: String,
44}
45
46impl SqlGraphStore {
47    /// Create a new store scoped to one namespace.
48    pub fn new_scoped(
49        pool: Arc<ConnectionPool>,
50        is_file_backed: bool,
51        namespace: impl Into<String>,
52    ) -> Self {
53        Self {
54            pool,
55            is_file_backed,
56            namespace: namespace.into(),
57        }
58    }
59
60    fn open_standalone_writer(&self) -> Result<rusqlite::Connection, StorageError> {
61        let config = self.pool.config();
62        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
63            operation: "graph_writer".into(),
64            message: "in-memory databases do not support standalone connections".into(),
65        })?;
66
67        let conn = rusqlite::Connection::open_with_flags(
68            path,
69            rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE
70                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
71                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
72        )
73        .map_err(|e| map_err(e, "open_graph_writer"))?;
74
75        conn.busy_timeout(config.busy_timeout)
76            .map_err(|e| map_err(e, "open_graph_writer"))?;
77        conn.pragma_update(None, "foreign_keys", "ON")
78            .map_err(|e| map_err(e, "open_graph_writer"))?;
79        conn.pragma_update(None, "synchronous", "NORMAL")
80            .map_err(|e| map_err(e, "open_graph_writer"))?;
81
82        Ok(conn)
83    }
84
85    fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
86        let config = self.pool.config();
87        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
88            operation: "graph_reader".into(),
89            message: "in-memory databases do not support standalone connections".into(),
90        })?;
91
92        let conn = rusqlite::Connection::open_with_flags(
93            path,
94            rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
95                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
96                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
97        )
98        .map_err(|e| map_err(e, "open_graph_reader"))?;
99
100        conn.busy_timeout(config.busy_timeout)
101            .map_err(|e| map_err(e, "open_graph_reader"))?;
102        conn.pragma_update(None, "foreign_keys", "ON")
103            .map_err(|e| map_err(e, "open_graph_reader"))?;
104        conn.pragma_update(None, "synchronous", "NORMAL")
105            .map_err(|e| map_err(e, "open_graph_reader"))?;
106
107        Ok(conn)
108    }
109
110    async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
111    where
112        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
113        R: Send + 'static,
114    {
115        if self.is_file_backed {
116            let conn = self.open_standalone_writer()?;
117            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
118                .await
119                .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
120        } else {
121            let pool = Arc::clone(&self.pool);
122            tokio::task::spawn_blocking(move || {
123                let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
124                f(guard.conn()).map_err(|e| map_err(e, op))
125            })
126            .await
127            .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
128        }
129    }
130
131    async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
132    where
133        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
134        R: Send + 'static,
135    {
136        if self.is_file_backed {
137            let conn = self.open_standalone_reader()?;
138            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
139                .await
140                .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
141        } else {
142            let pool = Arc::clone(&self.pool);
143            tokio::task::spawn_blocking(move || {
144                let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
145                f(guard.conn()).map_err(|e| map_err(e, op))
146            })
147            .await
148            .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
149        }
150    }
151}
152
153// =============================================================================
154// Helpers
155// =============================================================================
156
157fn read_edge(row: &rusqlite::Row<'_>) -> Result<Edge, rusqlite::Error> {
158    let id_str: String = row.get(0)?;
159    let source_str: String = row.get(1)?;
160    let target_str: String = row.get(2)?;
161    let relation_str: String = row.get(3)?;
162    let weight: f64 = row.get(4)?;
163    let created_micros: i64 = row.get(5)?;
164    let metadata_str: Option<String> = row.get(6)?;
165
166    let id = parse_uuid(&id_str)?;
167    let source_id = parse_uuid(&source_str)?;
168    let target_id = parse_uuid(&target_str)?;
169    let created_at = micros_to_datetime(created_micros);
170    let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
171        rusqlite::Error::FromSqlConversionFailure(3, rusqlite::types::Type::Text, Box::new(e))
172    })?;
173    let metadata = metadata_str.and_then(|s| serde_json::from_str(&s).ok());
174
175    Ok(Edge {
176        id: id.into(),
177        source_id,
178        target_id,
179        relation,
180        weight,
181        created_at,
182        metadata,
183    })
184}
185
186fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
187    Uuid::parse_str(s).map_err(|e| {
188        rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
189    })
190}
191
192fn micros_to_datetime(micros: i64) -> DateTime<Utc> {
193    Utc.timestamp_micros(micros)
194        .single()
195        .unwrap_or_else(Utc::now)
196}
197
198fn build_edge_filter_sql(
199    namespace: &str,
200    filter: &EdgeFilter,
201) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
202    let mut conditions: Vec<String> = vec!["namespace = ?1".to_string()];
203    let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
204
205    if !filter.ids.is_empty() {
206        let placeholders: Vec<String> = filter
207            .ids
208            .iter()
209            .map(|id| {
210                params.push(Box::new(id.to_string()));
211                format!("?{}", params.len())
212            })
213            .collect();
214        conditions.push(format!("id IN ({})", placeholders.join(",")));
215    }
216
217    if !filter.source_ids.is_empty() {
218        let placeholders: Vec<String> = filter
219            .source_ids
220            .iter()
221            .map(|id| {
222                params.push(Box::new(id.to_string()));
223                format!("?{}", params.len())
224            })
225            .collect();
226        conditions.push(format!("source_id IN ({})", placeholders.join(",")));
227    }
228
229    if !filter.target_ids.is_empty() {
230        let placeholders: Vec<String> = filter
231            .target_ids
232            .iter()
233            .map(|id| {
234                params.push(Box::new(id.to_string()));
235                format!("?{}", params.len())
236            })
237            .collect();
238        conditions.push(format!("target_id IN ({})", placeholders.join(",")));
239    }
240
241    if !filter.relations.is_empty() {
242        let placeholders: Vec<String> = filter
243            .relations
244            .iter()
245            .map(|r| {
246                params.push(Box::new(r.to_string()));
247                format!("?{}", params.len())
248            })
249            .collect();
250        conditions.push(format!("relation IN ({})", placeholders.join(",")));
251    }
252
253    if let Some(min_w) = filter.min_weight {
254        params.push(Box::new(min_w));
255        conditions.push(format!("weight >= ?{}", params.len()));
256    }
257
258    if let Some(max_w) = filter.max_weight {
259        params.push(Box::new(max_w));
260        conditions.push(format!("weight <= ?{}", params.len()));
261    }
262
263    if let Some(ref time_range) = filter.created_at {
264        if let Some(start) = time_range.start {
265            params.push(Box::new(start.timestamp_micros()));
266            conditions.push(format!("created_at >= ?{}", params.len()));
267        }
268        if let Some(end) = time_range.end {
269            params.push(Box::new(end.timestamp_micros()));
270            conditions.push(format!("created_at < ?{}", params.len()));
271        }
272    }
273
274    let clause = format!(" WHERE {}", conditions.join(" AND "));
275    (clause, params)
276}
277
278fn edge_sort_col(field: &EdgeSortField) -> &'static str {
279    match field {
280        EdgeSortField::CreatedAt => "created_at",
281        EdgeSortField::Weight => "weight",
282        EdgeSortField::Relation => "relation",
283    }
284}
285
286// =============================================================================
287// GraphStore implementation
288// =============================================================================
289
290#[async_trait]
291impl GraphStore for SqlGraphStore {
292    async fn upsert_edge(&self, edge: Edge) -> Result<(), StorageError> {
293        let namespace = self.namespace.clone();
294        let id_str = Uuid::from(edge.id).to_string();
295        let src_str = edge.source_id.to_string();
296        let tgt_str = edge.target_id.to_string();
297        let relation_str = edge.relation.to_string();
298        let metadata_str = edge
299            .metadata
300            .as_ref()
301            .map(|v| serde_json::to_string(v).unwrap_or_default());
302        self.with_writer("upsert_edge", move |conn| {
303            conn.execute(
304                "INSERT OR REPLACE INTO graph_edges \
305                 (namespace, id, source_id, target_id, relation, weight, created_at, metadata) \
306                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
307                rusqlite::params![
308                    namespace,
309                    id_str,
310                    src_str,
311                    tgt_str,
312                    relation_str,
313                    edge.weight,
314                    edge.created_at.timestamp_micros(),
315                    metadata_str,
316                ],
317            )?;
318            Ok(())
319        })
320        .await
321    }
322
323    async fn upsert_edges(&self, edges: Vec<Edge>) -> Result<BatchWriteSummary, StorageError> {
324        let attempted = edges.len() as u64;
325        let namespace = self.namespace.clone();
326
327        self.with_writer("upsert_edges", move |conn| {
328            conn.execute_batch("BEGIN IMMEDIATE")?;
329            let mut affected = 0u64;
330            let mut failed = 0u64;
331            let mut first_error = String::new();
332
333            for edge in &edges {
334                let id_str = Uuid::from(edge.id).to_string();
335                let src_str = edge.source_id.to_string();
336                let tgt_str = edge.target_id.to_string();
337                let relation_str = edge.relation.to_string();
338                let metadata_str = edge
339                    .metadata
340                    .as_ref()
341                    .map(|v| serde_json::to_string(v).unwrap_or_default());
342                match conn.execute(
343                    "INSERT OR REPLACE INTO graph_edges \
344                     (namespace, id, source_id, target_id, relation, weight, created_at, metadata) \
345                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
346                    rusqlite::params![
347                        &namespace,
348                        id_str,
349                        src_str,
350                        tgt_str,
351                        relation_str,
352                        edge.weight,
353                        edge.created_at.timestamp_micros(),
354                        metadata_str,
355                    ],
356                ) {
357                    Ok(_) => affected += 1,
358                    Err(e) => {
359                        if first_error.is_empty() {
360                            first_error = e.to_string();
361                        }
362                        failed += 1;
363                    }
364                }
365            }
366
367            if let Err(e) = conn.execute_batch("COMMIT") {
368                let _ = conn.execute_batch("ROLLBACK");
369                return Err(e);
370            }
371            Ok(BatchWriteSummary {
372                attempted,
373                affected,
374                failed,
375                first_error,
376            })
377        })
378        .await
379    }
380
381    async fn get_edge(&self, id: LinkId) -> Result<Option<Edge>, StorageError> {
382        let namespace = self.namespace.clone();
383        let id_str = Uuid::from(id).to_string();
384
385        self.with_reader("get_edge", move |conn| {
386            let mut stmt = conn.prepare(
387                "SELECT id, source_id, target_id, relation, weight, created_at, metadata \
388                 FROM graph_edges WHERE namespace = ?1 AND id = ?2",
389            )?;
390            let mut rows = stmt.query(rusqlite::params![namespace, id_str])?;
391            match rows.next()? {
392                Some(row) => Ok(Some(read_edge(row)?)),
393                None => Ok(None),
394            }
395        })
396        .await
397    }
398
399    async fn delete_edge(&self, id: LinkId) -> Result<bool, StorageError> {
400        let namespace = self.namespace.clone();
401        let id_str = Uuid::from(id).to_string();
402
403        self.with_writer("delete_edge", move |conn| {
404            let deleted = conn.execute(
405                "DELETE FROM graph_edges WHERE namespace = ?1 AND id = ?2",
406                rusqlite::params![namespace, id_str],
407            )?;
408            Ok(deleted > 0)
409        })
410        .await
411    }
412
413    async fn query_edges(
414        &self,
415        filter: EdgeFilter,
416        sort: Vec<SortOrder<EdgeSortField>>,
417        page: PageRequest,
418    ) -> Result<Page<Edge>, StorageError> {
419        let namespace = self.namespace.clone();
420        self.with_reader("query_edges", move |conn| {
421            let (where_clause, filter_params) = build_edge_filter_sql(&namespace, &filter);
422
423            let count_sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
424            let total: i64 = {
425                let mut stmt = conn.prepare(&count_sql)?;
426                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
427                    filter_params.iter().map(|p| p.as_ref()).collect();
428                stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
429            };
430
431            let order_clause = if sort.is_empty() {
432                " ORDER BY created_at DESC".to_string()
433            } else {
434                let parts: Vec<String> = sort
435                    .iter()
436                    .map(|s| {
437                        let dir = match s.direction {
438                            SortDirection::Asc => "ASC",
439                            SortDirection::Desc => "DESC",
440                        };
441                        format!("{} {}", edge_sort_col(&s.field), dir)
442                    })
443                    .collect();
444                format!(" ORDER BY {}", parts.join(", "))
445            };
446
447            let (_, data_filter_params) = build_edge_filter_sql(&namespace, &filter);
448            let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = data_filter_params;
449            all_params.push(Box::new(page.limit as i64));
450            all_params.push(Box::new(page.offset as i64));
451
452            let limit_idx = all_params.len() - 1;
453            let offset_idx = all_params.len();
454
455            let data_sql = format!(
456                "SELECT id, source_id, target_id, relation, weight, created_at, metadata \
457                 FROM graph_edges{}{} LIMIT ?{} OFFSET ?{}",
458                where_clause, order_clause, limit_idx, offset_idx,
459            );
460
461            let mut stmt = conn.prepare(&data_sql)?;
462            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
463                all_params.iter().map(|p| p.as_ref()).collect();
464            let rows = stmt.query_map(param_refs.as_slice(), read_edge)?;
465
466            let mut items = Vec::new();
467            for row in rows {
468                items.push(row?);
469            }
470
471            Ok(Page {
472                items,
473                total: Some(total as u64),
474            })
475        })
476        .await
477    }
478
479    async fn count_edges(&self, filter: EdgeFilter) -> Result<u64, StorageError> {
480        let namespace = self.namespace.clone();
481        self.with_reader("count_edges", move |conn| {
482            let (where_clause, params) = build_edge_filter_sql(&namespace, &filter);
483            let sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
484            let mut stmt = conn.prepare(&sql)?;
485            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
486                params.iter().map(|p| p.as_ref()).collect();
487            let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
488            Ok(count as u64)
489        })
490        .await
491    }
492
493    async fn neighbors(
494        &self,
495        node_id: Uuid,
496        query: NeighborQuery,
497    ) -> Result<Vec<NeighborHit>, StorageError> {
498        use khive_storage::types::Direction;
499
500        let namespace = self.namespace.clone();
501        let node_str = node_id.to_string();
502
503        self.with_reader("neighbors", move |conn| {
504            let base_out = "SELECT target_id AS node_id, id AS edge_id, relation, weight \
505                            FROM graph_edges WHERE namespace = ?1 AND source_id = ?2";
506            let base_in = "SELECT source_id AS node_id, id AS edge_id, relation, weight \
507                           FROM graph_edges WHERE namespace = ?1 AND target_id = ?2";
508
509            let sql = match query.direction {
510                Direction::Out => base_out.to_string(),
511                Direction::In => base_in.to_string(),
512                Direction::Both => format!("{} UNION ALL {}", base_out, base_in),
513            };
514
515            let mut conditions: Vec<String> = Vec::new();
516            let mut extra_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
517            let mut param_idx = 3;
518
519            if let Some(ref rels) = query.relations {
520                if !rels.is_empty() {
521                    let placeholders: Vec<String> = rels
522                        .iter()
523                        .map(|r| {
524                            extra_params.push(Box::new(r.to_string()));
525                            let p = format!("?{}", param_idx);
526                            param_idx += 1;
527                            p
528                        })
529                        .collect();
530                    conditions.push(format!("relation IN ({})", placeholders.join(",")));
531                }
532            }
533
534            if let Some(min_w) = query.min_weight {
535                extra_params.push(Box::new(min_w));
536                conditions.push(format!("weight >= ?{}", param_idx));
537                param_idx += 1;
538            }
539
540            let where_extra = if conditions.is_empty() {
541                String::new()
542            } else {
543                format!(" WHERE {}", conditions.join(" AND "))
544            };
545
546            let limit_clause = if let Some(lim) = query.limit {
547                extra_params.push(Box::new(lim as i64));
548                format!(" LIMIT ?{}", param_idx)
549            } else {
550                String::new()
551            };
552
553            let full_sql = format!(
554                "SELECT node_id, edge_id, relation, weight FROM ({}){}{}",
555                sql, where_extra, limit_clause
556            );
557
558            let mut stmt = conn.prepare(&full_sql)?;
559
560            let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
561            all_params.push(Box::new(namespace.clone()));
562            all_params.push(Box::new(node_str.clone()));
563            all_params.extend(extra_params);
564
565            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
566                all_params.iter().map(|p| p.as_ref()).collect();
567
568            let rows = stmt.query_map(param_refs.as_slice(), |row| {
569                let nid_str: String = row.get(0)?;
570                let eid_str: String = row.get(1)?;
571                let relation_str: String = row.get(2)?;
572                let weight: f64 = row.get(3)?;
573                Ok((nid_str, eid_str, relation_str, weight))
574            })?;
575
576            let mut hits = Vec::new();
577            for row in rows {
578                let (nid_str, eid_str, relation_str, weight) = row?;
579                let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
580                    rusqlite::Error::FromSqlConversionFailure(
581                        2,
582                        rusqlite::types::Type::Text,
583                        Box::new(e),
584                    )
585                })?;
586                hits.push(NeighborHit {
587                    node_id: parse_uuid(&nid_str)?,
588                    edge_id: parse_uuid(&eid_str)?,
589                    relation,
590                    weight,
591                });
592            }
593
594            Ok(hits)
595        })
596        .await
597    }
598
599    async fn traverse(&self, request: TraversalRequest) -> Result<Vec<GraphPath>, StorageError> {
600        use khive_storage::types::Direction;
601
602        if request.roots.is_empty() {
603            return Ok(Vec::new());
604        }
605
606        let roots = request.roots.clone();
607        let opts = request.options.clone();
608        let include_roots = request.include_roots;
609        let namespace = self.namespace.clone();
610
611        self.with_reader("traverse", move |conn| {
612            let mut all_paths: Vec<GraphPath> = Vec::new();
613
614            for root_id in &roots {
615                let root_str = root_id.to_string();
616
617                let (join_condition, next_node) = match opts.direction {
618                    Direction::Out => ("e.source_id = t.node_id", "e.target_id"),
619                    Direction::In => ("e.target_id = t.node_id", "e.source_id"),
620                    Direction::Both => (
621                        "(e.source_id = t.node_id OR e.target_id = t.node_id)",
622                        "CASE WHEN e.source_id = t.node_id THEN e.target_id ELSE e.source_id END",
623                    ),
624                };
625
626                let mut relation_cond = String::new();
627                let mut relation_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
628                let mut param_idx = 4;
629
630                if let Some(ref rels) = opts.relations {
631                    if !rels.is_empty() {
632                        let placeholders: Vec<String> = rels
633                            .iter()
634                            .map(|r| {
635                                relation_params.push(Box::new(r.to_string()));
636                                let p = format!("?{}", param_idx);
637                                param_idx += 1;
638                                p
639                            })
640                            .collect();
641                        relation_cond =
642                            format!(" AND e.relation IN ({})", placeholders.join(","));
643                    }
644                }
645
646                let mut weight_cond = String::new();
647                if let Some(min_w) = opts.min_weight {
648                    relation_params.push(Box::new(min_w));
649                    weight_cond = format!(" AND e.weight >= ?{}", param_idx);
650                    param_idx += 1;
651                }
652
653                let limit_clause = if let Some(lim) = opts.limit {
654                    relation_params.push(Box::new(lim as i64));
655                    format!(" LIMIT ?{}", param_idx)
656                } else {
657                    String::new()
658                };
659
660                let cte_sql = format!(
661                    "WITH RECURSIVE traversal(node_id, edge_id, depth, path, total_weight) AS (\
662                         SELECT ?2, NULL, 0, ?2, 0.0 \
663                         UNION ALL \
664                         SELECT {next_node}, e.id, t.depth + 1, \
665                                t.path || ',' || {next_node}, \
666                                t.total_weight + e.weight \
667                         FROM graph_edges e \
668                         JOIN traversal t ON {join_condition} \
669                         WHERE e.namespace = ?1 \
670                           AND t.depth < ?3 \
671                           AND (',' || t.path || ',') NOT LIKE '%,' || {next_node} || ',%'{rel_cond}{wt_cond} \
672                     ) \
673                     SELECT node_id, edge_id, depth, path, total_weight \
674                     FROM traversal WHERE depth > 0 \
675                     ORDER BY depth{limit}",
676                    next_node = next_node,
677                    join_condition = join_condition,
678                    rel_cond = relation_cond,
679                    wt_cond = weight_cond,
680                    limit = limit_clause,
681                );
682
683                let mut stmt = conn.prepare(&cte_sql)?;
684
685                let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
686                all_params.push(Box::new(namespace.clone()));
687                all_params.push(Box::new(root_str.clone()));
688                all_params.push(Box::new(opts.max_depth as i64));
689                all_params.extend(relation_params);
690
691                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
692                    all_params.iter().map(|p| p.as_ref()).collect();
693
694                let rows = stmt.query_map(param_refs.as_slice(), |row| {
695                    let node_str: String = row.get(0)?;
696                    let edge_str: Option<String> = row.get(1)?;
697                    let depth: i64 = row.get(2)?;
698                    let _path: String = row.get(3)?;
699                    let total_weight: f64 = row.get(4)?;
700                    Ok((node_str, edge_str, depth, total_weight))
701                })?;
702
703                let mut nodes = Vec::new();
704                let mut max_weight = 0.0f64;
705
706                if include_roots {
707                    nodes.push(PathNode {
708                        node_id: *root_id,
709                        via_edge: None,
710                        depth: 0,
711                    });
712                }
713
714                for row in rows {
715                    let (node_str, edge_str, depth, total_weight) = row?;
716                    let node_id = parse_uuid(&node_str)?;
717                    let via_edge = edge_str.map(|s| parse_uuid(&s)).transpose()?;
718                    nodes.push(PathNode {
719                        node_id,
720                        via_edge,
721                        depth: depth as usize,
722                    });
723                    if total_weight > max_weight {
724                        max_weight = total_weight;
725                    }
726                }
727
728                if nodes.len() > if include_roots { 1 } else { 0 } || include_roots {
729                    all_paths.push(GraphPath {
730                        root_id: *root_id,
731                        nodes,
732                        total_weight: max_weight,
733                    });
734                }
735            }
736
737            Ok(all_paths)
738        })
739        .await
740    }
741}
742
743// =============================================================================
744// DDL
745// =============================================================================
746
/// Schema for the `graph_edges` table and its indexes.
///
/// Rows are keyed by `(namespace, id)`. As written by the `GraphStore`
/// methods in this file, `created_at` holds microseconds since the Unix epoch
/// and `metadata` holds optional JSON text. Secondary indexes pair
/// `namespace` with `source_id`, `target_id`, `relation`, and the
/// `source_id`/`target_id` + `relation` combinations.
///
/// Every statement uses `IF NOT EXISTS`, so running the batch again is
/// harmless. Each line ends in `\`, which in a Rust string literal drops the
/// newline and the next line's leading whitespace, so the SQL is one line.
const GRAPH_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS graph_edges (\
        namespace TEXT NOT NULL,\
        id TEXT NOT NULL,\
        source_id TEXT NOT NULL,\
        target_id TEXT NOT NULL,\
        relation TEXT NOT NULL,\
        weight REAL NOT NULL DEFAULT 1.0,\
        created_at INTEGER NOT NULL,\
        metadata TEXT,\
        PRIMARY KEY (namespace, id)\
    );\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_source ON graph_edges(namespace, source_id);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_target ON graph_edges(namespace, target_id);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_relation ON graph_edges(namespace, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_src_rel ON graph_edges(namespace, source_id, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_tgt_rel ON graph_edges(namespace, target_id, relation);\
";
765
766pub(crate) fn ensure_graph_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
767    conn.execute_batch(GRAPH_DDL)
768}
769
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;
    use khive_storage::types::{Direction, TraversalOptions};

    /// Build a store in the `"default"` namespace over a fresh in-memory pool.
    ///
    /// The schema is created through [`ensure_graph_schema`] rather than by
    /// executing `GRAPH_DDL` directly, so the tests also cover the schema helper.
    fn setup_memory_store() -> SqlGraphStore {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());

        {
            // Drop the writer before handing the pool to the store.
            let writer = pool.writer().unwrap();
            ensure_graph_schema(writer.conn()).unwrap();
        }

        SqlGraphStore::new_scoped(pool, false, "default")
    }

    /// Build an edge with a fresh random id, `created_at = now`, and no metadata.
    fn make_edge(source: Uuid, target: Uuid, relation: EdgeRelation, weight: f64) -> Edge {
        Edge {
            id: Uuid::new_v4().into(),
            source_id: source,
            target_id: target,
            relation,
            weight,
            created_at: Utc::now(),
            metadata: None,
        }
    }

    #[tokio::test]
    async fn test_upsert_and_get_edge() {
        let store = setup_memory_store();

        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let edge = make_edge(src, tgt, EdgeRelation::Extends, 0.8);
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();

        let fetched = store
            .get_edge(edge_id)
            .await
            .unwrap()
            .expect("edge must be retrievable after upsert");
        assert_eq!(fetched.id, edge_id);
        assert_eq!(fetched.source_id, src);
        assert_eq!(fetched.target_id, tgt);
        assert_eq!(fetched.relation, EdgeRelation::Extends);
        assert!((fetched.weight - 0.8).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_delete_edge() {
        let store = setup_memory_store();

        let edge = make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Contains, 1.0);
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();
        assert!(store.get_edge(edge_id).await.unwrap().is_some());

        let deleted = store.delete_edge(edge_id).await.unwrap();
        assert!(deleted);

        assert!(store.get_edge(edge_id).await.unwrap().is_none());

        // A second delete of the same id reports that nothing was removed.
        let deleted_again = store.delete_edge(edge_id).await.unwrap();
        assert!(!deleted_again);
    }

    #[tokio::test]
    async fn test_count_edges() {
        let store = setup_memory_store();

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 0);

        for _ in 0..5 {
            store
                .upsert_edge(make_edge(
                    Uuid::new_v4(),
                    Uuid::new_v4(),
                    EdgeRelation::DependsOn,
                    1.0,
                ))
                .await
                .unwrap();
        }

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_neighbors_outbound() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();

        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(a, c, EdgeRelation::DependsOn, 0.7))
            .await
            .unwrap();
        // Inbound edge to `a`: must not show up in an outbound query.
        store
            .upsert_edge(make_edge(d, a, EdgeRelation::Extends, 0.5))
            .await
            .unwrap();

        let query = NeighborQuery {
            direction: Direction::Out,
            relations: None,
            limit: None,
            min_weight: None,
        };

        let hits = store.neighbors(a, query).await.unwrap();
        assert_eq!(hits.len(), 2);

        let neighbor_ids: Vec<Uuid> = hits.iter().map(|h| h.node_id).collect();
        assert!(neighbor_ids.contains(&b));
        assert!(neighbor_ids.contains(&c));
        assert!(!neighbor_ids.contains(&d));
    }

    #[tokio::test]
    async fn test_traverse_depth_2() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();

        // Chain a -> b -> c -> d; a depth-2 walk from `a` must stop at `c`.
        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(b, c, EdgeRelation::Extends, 2.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(c, d, EdgeRelation::Extends, 3.0))
            .await
            .unwrap();

        let request = TraversalRequest {
            roots: vec![a],
            options: TraversalOptions::new(2).with_direction(Direction::Out),
            include_roots: true,
        };

        let paths = store.traverse(request).await.unwrap();
        assert_eq!(paths.len(), 1);

        let path = &paths[0];
        let node_ids: Vec<Uuid> = path.nodes.iter().map(|n| n.node_id).collect();
        assert!(node_ids.contains(&a));
        assert!(node_ids.contains(&b));
        assert!(node_ids.contains(&c));
        assert!(!node_ids.contains(&d));
    }

    #[tokio::test]
    async fn test_metadata_roundtrip() {
        let store = setup_memory_store();

        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let meta = serde_json::json!({"note": "important link", "confidence": 0.95});
        let edge = Edge {
            id: Uuid::new_v4().into(),
            source_id: src,
            target_id: tgt,
            relation: EdgeRelation::Implements,
            weight: 0.9,
            created_at: Utc::now(),
            metadata: Some(meta.clone()),
        };
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();

        let fetched = store.get_edge(edge_id).await.unwrap().unwrap();
        assert_eq!(
            fetched.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via get_edge"
        );

        // Also verify via query_edges.
        let page = store
            .query_edges(EdgeFilter::default(), vec![], PageRequest::default())
            .await
            .unwrap();
        let from_query = page
            .items
            .iter()
            .find(|e| e.id == edge_id)
            .expect("edge must appear in query_edges result");
        assert_eq!(
            from_query.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via query_edges"
        );
    }

    #[tokio::test]
    async fn test_upsert_edges_batch() {
        let store = setup_memory_store();

        let edges: Vec<Edge> = (0..10)
            .map(|i| {
                make_edge(
                    Uuid::new_v4(),
                    Uuid::new_v4(),
                    EdgeRelation::Implements,
                    i as f64,
                )
            })
            .collect();

        let summary = store.upsert_edges(edges).await.unwrap();
        assert_eq!(summary.attempted, 10);
        assert_eq!(summary.affected, 10);
        assert_eq!(summary.failed, 0);

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 10);
    }

    #[tokio::test]
    async fn test_ensure_graph_schema_is_idempotent() {
        let store = setup_memory_store();

        // `setup_memory_store` already created the schema; re-applying it must
        // succeed (every DDL statement is `IF NOT EXISTS`) and leave it usable.
        {
            let writer = store.pool.writer().unwrap();
            ensure_graph_schema(writer.conn()).unwrap();
        }

        store
            .upsert_edge(make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Contains, 1.0))
            .await
            .unwrap();
        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 1);
    }
}