Skip to main content

khive_db/stores/
graph.rs

//! SQL-backed `GraphStore` implementation.
//!
//! `SqlGraphStore` stores graph edges in a regular SQLite table (`graph_edges`).
//! Multi-hop traversal uses recursive CTEs.
//!
//! # Connection strategy
//!
//! - **File-backed**: opens a fresh standalone connection for every operation
//!   (read-only for reads, read-write for writes).
//! - **In-memory**: checks a connection out of the shared pool for every operation.
//!
//! In both cases the SQLite statements execute on Tokio's blocking thread pool via
//! `spawn_blocking`.
10
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use chrono::{DateTime, TimeZone, Utc};
15use uuid::Uuid;
16
17use khive_storage::error::StorageError;
18use khive_storage::types::{
19    BatchWriteSummary, Edge, EdgeFilter, EdgeSortField, GraphPath, NeighborHit, NeighborQuery,
20    Page, PageRequest, PathNode, SortDirection, SortOrder, TraversalRequest,
21};
22use khive_storage::GraphStore;
23use khive_storage::LinkId;
24use khive_storage::StorageCapability;
25use khive_types::EdgeRelation;
26
27use crate::error::SqliteError;
28use crate::pool::ConnectionPool;
29
30/// Map a rusqlite error to `StorageError` with `Graph` capability.
31fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
32    StorageError::driver(StorageCapability::Graph, op, e)
33}
34
35fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
36    StorageError::driver(StorageCapability::Graph, op, e)
37}
38
39/// A GraphStore backed by SQLite tables.
40pub struct SqlGraphStore {
41    pool: Arc<ConnectionPool>,
42    is_file_backed: bool,
43    namespace: String,
44}
45
46impl SqlGraphStore {
47    /// Create a new store scoped to one namespace.
48    pub fn new_scoped(
49        pool: Arc<ConnectionPool>,
50        is_file_backed: bool,
51        namespace: impl Into<String>,
52    ) -> Self {
53        Self {
54            pool,
55            is_file_backed,
56            namespace: namespace.into(),
57        }
58    }
59
60    fn open_standalone_writer(&self) -> Result<rusqlite::Connection, StorageError> {
61        let config = self.pool.config();
62        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
63            operation: "graph_writer".into(),
64            message: "in-memory databases do not support standalone connections".into(),
65        })?;
66
67        let conn = rusqlite::Connection::open_with_flags(
68            path,
69            rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE
70                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
71                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
72        )
73        .map_err(|e| map_err(e, "open_graph_writer"))?;
74
75        conn.busy_timeout(config.busy_timeout)
76            .map_err(|e| map_err(e, "open_graph_writer"))?;
77        conn.pragma_update(None, "foreign_keys", "ON")
78            .map_err(|e| map_err(e, "open_graph_writer"))?;
79        conn.pragma_update(None, "synchronous", "NORMAL")
80            .map_err(|e| map_err(e, "open_graph_writer"))?;
81
82        Ok(conn)
83    }
84
85    fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
86        let config = self.pool.config();
87        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
88            operation: "graph_reader".into(),
89            message: "in-memory databases do not support standalone connections".into(),
90        })?;
91
92        let conn = rusqlite::Connection::open_with_flags(
93            path,
94            rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
95                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
96                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
97        )
98        .map_err(|e| map_err(e, "open_graph_reader"))?;
99
100        conn.busy_timeout(config.busy_timeout)
101            .map_err(|e| map_err(e, "open_graph_reader"))?;
102        conn.pragma_update(None, "foreign_keys", "ON")
103            .map_err(|e| map_err(e, "open_graph_reader"))?;
104        conn.pragma_update(None, "synchronous", "NORMAL")
105            .map_err(|e| map_err(e, "open_graph_reader"))?;
106
107        Ok(conn)
108    }
109
110    async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
111    where
112        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
113        R: Send + 'static,
114    {
115        if self.is_file_backed {
116            let conn = self.open_standalone_writer()?;
117            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
118                .await
119                .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
120        } else {
121            let pool = Arc::clone(&self.pool);
122            tokio::task::spawn_blocking(move || {
123                let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
124                f(guard.conn()).map_err(|e| map_err(e, op))
125            })
126            .await
127            .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
128        }
129    }
130
131    async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
132    where
133        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
134        R: Send + 'static,
135    {
136        if self.is_file_backed {
137            let conn = self.open_standalone_reader()?;
138            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
139                .await
140                .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
141        } else {
142            let pool = Arc::clone(&self.pool);
143            tokio::task::spawn_blocking(move || {
144                let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
145                f(guard.conn()).map_err(|e| map_err(e, op))
146            })
147            .await
148            .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
149        }
150    }
151}
152
153// =============================================================================
154// Helpers
155// =============================================================================
156
157fn read_edge(row: &rusqlite::Row<'_>) -> Result<Edge, rusqlite::Error> {
158    let id_str: String = row.get(0)?;
159    let source_str: String = row.get(1)?;
160    let target_str: String = row.get(2)?;
161    let relation_str: String = row.get(3)?;
162    let weight: f64 = row.get(4)?;
163    let created_micros: i64 = row.get(5)?;
164    let metadata_str: Option<String> = row.get(6)?;
165
166    let id = parse_uuid(&id_str)?;
167    let source_id = parse_uuid(&source_str)?;
168    let target_id = parse_uuid(&target_str)?;
169    let created_at = micros_to_datetime(created_micros);
170    let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
171        rusqlite::Error::FromSqlConversionFailure(3, rusqlite::types::Type::Text, Box::new(e))
172    })?;
173    let metadata = metadata_str.and_then(|s| serde_json::from_str(&s).ok());
174
175    Ok(Edge {
176        id: id.into(),
177        source_id,
178        target_id,
179        relation,
180        weight,
181        created_at,
182        metadata,
183    })
184}
185
186fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
187    Uuid::parse_str(s).map_err(|e| {
188        rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
189    })
190}
191
192fn micros_to_datetime(micros: i64) -> DateTime<Utc> {
193    Utc.timestamp_micros(micros)
194        .single()
195        .unwrap_or_else(Utc::now)
196}
197
198fn build_edge_filter_sql(
199    namespace: &str,
200    filter: &EdgeFilter,
201) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
202    let mut conditions: Vec<String> = vec!["namespace = ?1".to_string()];
203    let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
204
205    if !filter.ids.is_empty() {
206        let placeholders: Vec<String> = filter
207            .ids
208            .iter()
209            .map(|id| {
210                params.push(Box::new(id.to_string()));
211                format!("?{}", params.len())
212            })
213            .collect();
214        conditions.push(format!("id IN ({})", placeholders.join(",")));
215    }
216
217    if !filter.source_ids.is_empty() {
218        let placeholders: Vec<String> = filter
219            .source_ids
220            .iter()
221            .map(|id| {
222                params.push(Box::new(id.to_string()));
223                format!("?{}", params.len())
224            })
225            .collect();
226        conditions.push(format!("source_id IN ({})", placeholders.join(",")));
227    }
228
229    if !filter.target_ids.is_empty() {
230        let placeholders: Vec<String> = filter
231            .target_ids
232            .iter()
233            .map(|id| {
234                params.push(Box::new(id.to_string()));
235                format!("?{}", params.len())
236            })
237            .collect();
238        conditions.push(format!("target_id IN ({})", placeholders.join(",")));
239    }
240
241    if !filter.relations.is_empty() {
242        let placeholders: Vec<String> = filter
243            .relations
244            .iter()
245            .map(|r| {
246                params.push(Box::new(r.to_string()));
247                format!("?{}", params.len())
248            })
249            .collect();
250        conditions.push(format!("relation IN ({})", placeholders.join(",")));
251    }
252
253    if let Some(min_w) = filter.min_weight {
254        params.push(Box::new(min_w));
255        conditions.push(format!("weight >= ?{}", params.len()));
256    }
257
258    if let Some(max_w) = filter.max_weight {
259        params.push(Box::new(max_w));
260        conditions.push(format!("weight <= ?{}", params.len()));
261    }
262
263    if let Some(ref time_range) = filter.created_at {
264        if let Some(start) = time_range.start {
265            params.push(Box::new(start.timestamp_micros()));
266            conditions.push(format!("created_at >= ?{}", params.len()));
267        }
268        if let Some(end) = time_range.end {
269            params.push(Box::new(end.timestamp_micros()));
270            conditions.push(format!("created_at < ?{}", params.len()));
271        }
272    }
273
274    let clause = format!(" WHERE {}", conditions.join(" AND "));
275    (clause, params)
276}
277
278fn edge_sort_col(field: &EdgeSortField) -> &'static str {
279    match field {
280        EdgeSortField::CreatedAt => "created_at",
281        EdgeSortField::Weight => "weight",
282        EdgeSortField::Relation => "relation",
283    }
284}
285
286// =============================================================================
287// GraphStore implementation
288// =============================================================================
289
290#[async_trait]
291impl GraphStore for SqlGraphStore {
292    async fn upsert_edge(&self, edge: Edge) -> Result<(), StorageError> {
293        let namespace = self.namespace.clone();
294        let id_str = Uuid::from(edge.id).to_string();
295        let src_str = edge.source_id.to_string();
296        let tgt_str = edge.target_id.to_string();
297        let relation_str = edge.relation.to_string();
298        let metadata_str = edge
299            .metadata
300            .as_ref()
301            .map(|v| serde_json::to_string(v).unwrap_or_default());
302        self.with_writer("upsert_edge", move |conn| {
303            conn.execute(
304                "INSERT INTO graph_edges \
305                 (namespace, id, source_id, target_id, relation, weight, created_at, metadata) \
306                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) \
307                 ON CONFLICT(namespace, id) DO UPDATE SET \
308                     source_id = excluded.source_id, \
309                     target_id = excluded.target_id, \
310                     relation = excluded.relation, \
311                     weight = excluded.weight, \
312                     created_at = excluded.created_at, \
313                     metadata = excluded.metadata \
314                 ON CONFLICT(namespace, source_id, target_id, relation) DO NOTHING",
315                rusqlite::params![
316                    namespace,
317                    id_str,
318                    src_str,
319                    tgt_str,
320                    relation_str,
321                    edge.weight,
322                    edge.created_at.timestamp_micros(),
323                    metadata_str,
324                ],
325            )?;
326            Ok(())
327        })
328        .await
329    }
330
331    async fn upsert_edges(&self, edges: Vec<Edge>) -> Result<BatchWriteSummary, StorageError> {
332        let attempted = edges.len() as u64;
333        let namespace = self.namespace.clone();
334
335        self.with_writer("upsert_edges", move |conn| {
336            conn.execute_batch("BEGIN IMMEDIATE")?;
337            let mut affected = 0u64;
338            let mut failed = 0u64;
339            let mut first_error = String::new();
340
341            for edge in &edges {
342                let id_str = Uuid::from(edge.id).to_string();
343                let src_str = edge.source_id.to_string();
344                let tgt_str = edge.target_id.to_string();
345                let relation_str = edge.relation.to_string();
346                let metadata_str = edge
347                    .metadata
348                    .as_ref()
349                    .map(|v| serde_json::to_string(v).unwrap_or_default());
350                match conn.execute(
351                    "INSERT INTO graph_edges \
352                     (namespace, id, source_id, target_id, relation, weight, created_at, metadata) \
353                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) \
354                     ON CONFLICT(namespace, id) DO UPDATE SET \
355                         source_id = excluded.source_id, \
356                         target_id = excluded.target_id, \
357                         relation = excluded.relation, \
358                         weight = excluded.weight, \
359                         created_at = excluded.created_at, \
360                         metadata = excluded.metadata \
361                     ON CONFLICT(namespace, source_id, target_id, relation) DO NOTHING",
362                    rusqlite::params![
363                        &namespace,
364                        id_str,
365                        src_str,
366                        tgt_str,
367                        relation_str,
368                        edge.weight,
369                        edge.created_at.timestamp_micros(),
370                        metadata_str,
371                    ],
372                ) {
373                    Ok(_) => affected += 1,
374                    Err(e) => {
375                        if first_error.is_empty() {
376                            first_error = e.to_string();
377                        }
378                        failed += 1;
379                    }
380                }
381            }
382
383            if let Err(e) = conn.execute_batch("COMMIT") {
384                let _ = conn.execute_batch("ROLLBACK");
385                return Err(e);
386            }
387            Ok(BatchWriteSummary {
388                attempted,
389                affected,
390                failed,
391                first_error,
392            })
393        })
394        .await
395    }
396
397    async fn get_edge(&self, id: LinkId) -> Result<Option<Edge>, StorageError> {
398        let namespace = self.namespace.clone();
399        let id_str = Uuid::from(id).to_string();
400
401        self.with_reader("get_edge", move |conn| {
402            let mut stmt = conn.prepare(
403                "SELECT id, source_id, target_id, relation, weight, created_at, metadata \
404                 FROM graph_edges WHERE namespace = ?1 AND id = ?2",
405            )?;
406            let mut rows = stmt.query(rusqlite::params![namespace, id_str])?;
407            match rows.next()? {
408                Some(row) => Ok(Some(read_edge(row)?)),
409                None => Ok(None),
410            }
411        })
412        .await
413    }
414
415    async fn delete_edge(&self, id: LinkId) -> Result<bool, StorageError> {
416        let namespace = self.namespace.clone();
417        let id_str = Uuid::from(id).to_string();
418
419        self.with_writer("delete_edge", move |conn| {
420            let deleted = conn.execute(
421                "DELETE FROM graph_edges WHERE namespace = ?1 AND id = ?2",
422                rusqlite::params![namespace, id_str],
423            )?;
424            Ok(deleted > 0)
425        })
426        .await
427    }
428
429    async fn query_edges(
430        &self,
431        filter: EdgeFilter,
432        sort: Vec<SortOrder<EdgeSortField>>,
433        page: PageRequest,
434    ) -> Result<Page<Edge>, StorageError> {
435        let namespace = self.namespace.clone();
436        self.with_reader("query_edges", move |conn| {
437            let (where_clause, filter_params) = build_edge_filter_sql(&namespace, &filter);
438
439            let count_sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
440            let total: i64 = {
441                let mut stmt = conn.prepare(&count_sql)?;
442                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
443                    filter_params.iter().map(|p| p.as_ref()).collect();
444                stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
445            };
446
447            let order_clause = if sort.is_empty() {
448                " ORDER BY created_at DESC".to_string()
449            } else {
450                let parts: Vec<String> = sort
451                    .iter()
452                    .map(|s| {
453                        let dir = match s.direction {
454                            SortDirection::Asc => "ASC",
455                            SortDirection::Desc => "DESC",
456                        };
457                        format!("{} {}", edge_sort_col(&s.field), dir)
458                    })
459                    .collect();
460                format!(" ORDER BY {}", parts.join(", "))
461            };
462
463            let (_, data_filter_params) = build_edge_filter_sql(&namespace, &filter);
464            let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = data_filter_params;
465            all_params.push(Box::new(page.limit as i64));
466            all_params.push(Box::new(page.offset as i64));
467
468            let limit_idx = all_params.len() - 1;
469            let offset_idx = all_params.len();
470
471            let data_sql = format!(
472                "SELECT id, source_id, target_id, relation, weight, created_at, metadata \
473                 FROM graph_edges{}{} LIMIT ?{} OFFSET ?{}",
474                where_clause, order_clause, limit_idx, offset_idx,
475            );
476
477            let mut stmt = conn.prepare(&data_sql)?;
478            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
479                all_params.iter().map(|p| p.as_ref()).collect();
480            let rows = stmt.query_map(param_refs.as_slice(), read_edge)?;
481
482            let mut items = Vec::new();
483            for row in rows {
484                items.push(row?);
485            }
486
487            Ok(Page {
488                items,
489                total: Some(total as u64),
490            })
491        })
492        .await
493    }
494
495    async fn count_edges(&self, filter: EdgeFilter) -> Result<u64, StorageError> {
496        let namespace = self.namespace.clone();
497        self.with_reader("count_edges", move |conn| {
498            let (where_clause, params) = build_edge_filter_sql(&namespace, &filter);
499            let sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
500            let mut stmt = conn.prepare(&sql)?;
501            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
502                params.iter().map(|p| p.as_ref()).collect();
503            let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
504            Ok(count as u64)
505        })
506        .await
507    }
508
509    async fn neighbors(
510        &self,
511        node_id: Uuid,
512        query: NeighborQuery,
513    ) -> Result<Vec<NeighborHit>, StorageError> {
514        use khive_storage::types::Direction;
515
516        let namespace = self.namespace.clone();
517        let node_str = node_id.to_string();
518
519        self.with_reader("neighbors", move |conn| {
520            let base_out = "SELECT target_id AS node_id, id AS edge_id, relation, weight \
521                            FROM graph_edges WHERE namespace = ?1 AND source_id = ?2";
522            let base_in = "SELECT source_id AS node_id, id AS edge_id, relation, weight \
523                           FROM graph_edges WHERE namespace = ?1 AND target_id = ?2";
524
525            let sql = match query.direction {
526                Direction::Out => base_out.to_string(),
527                Direction::In => base_in.to_string(),
528                Direction::Both => format!("{} UNION ALL {}", base_out, base_in),
529            };
530
531            let mut conditions: Vec<String> = Vec::new();
532            let mut extra_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
533            let mut param_idx = 3;
534
535            if let Some(ref rels) = query.relations {
536                if !rels.is_empty() {
537                    let placeholders: Vec<String> = rels
538                        .iter()
539                        .map(|r| {
540                            extra_params.push(Box::new(r.to_string()));
541                            let p = format!("?{}", param_idx);
542                            param_idx += 1;
543                            p
544                        })
545                        .collect();
546                    conditions.push(format!("relation IN ({})", placeholders.join(",")));
547                }
548            }
549
550            if let Some(min_w) = query.min_weight {
551                extra_params.push(Box::new(min_w));
552                conditions.push(format!("weight >= ?{}", param_idx));
553                param_idx += 1;
554            }
555
556            let where_extra = if conditions.is_empty() {
557                String::new()
558            } else {
559                format!(" WHERE {}", conditions.join(" AND "))
560            };
561
562            let limit_clause = if let Some(lim) = query.limit {
563                extra_params.push(Box::new(lim as i64));
564                format!(" LIMIT ?{}", param_idx)
565            } else {
566                String::new()
567            };
568
569            let full_sql = format!(
570                "SELECT node_id, edge_id, relation, weight FROM ({}){}{}",
571                sql, where_extra, limit_clause
572            );
573
574            let mut stmt = conn.prepare(&full_sql)?;
575
576            let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
577            all_params.push(Box::new(namespace.clone()));
578            all_params.push(Box::new(node_str.clone()));
579            all_params.extend(extra_params);
580
581            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
582                all_params.iter().map(|p| p.as_ref()).collect();
583
584            let rows = stmt.query_map(param_refs.as_slice(), |row| {
585                let nid_str: String = row.get(0)?;
586                let eid_str: String = row.get(1)?;
587                let relation_str: String = row.get(2)?;
588                let weight: f64 = row.get(3)?;
589                Ok((nid_str, eid_str, relation_str, weight))
590            })?;
591
592            let mut hits = Vec::new();
593            for row in rows {
594                let (nid_str, eid_str, relation_str, weight) = row?;
595                let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
596                    rusqlite::Error::FromSqlConversionFailure(
597                        2,
598                        rusqlite::types::Type::Text,
599                        Box::new(e),
600                    )
601                })?;
602                hits.push(NeighborHit {
603                    node_id: parse_uuid(&nid_str)?,
604                    edge_id: parse_uuid(&eid_str)?,
605                    relation,
606                    weight,
607                    name: None,
608                    kind: None,
609                });
610            }
611
612            Ok(hits)
613        })
614        .await
615    }
616
617    async fn traverse(&self, request: TraversalRequest) -> Result<Vec<GraphPath>, StorageError> {
618        use khive_storage::types::Direction;
619
620        if request.roots.is_empty() {
621            return Ok(Vec::new());
622        }
623
624        let roots = request.roots.clone();
625        let opts = request.options.clone();
626        let include_roots = request.include_roots;
627        let namespace = self.namespace.clone();
628
629        self.with_reader("traverse", move |conn| {
630            let mut all_paths: Vec<GraphPath> = Vec::new();
631
632            for root_id in &roots {
633                let root_str = root_id.to_string();
634
635                let (join_condition, next_node) = match opts.direction {
636                    Direction::Out => ("e.source_id = t.node_id", "e.target_id"),
637                    Direction::In => ("e.target_id = t.node_id", "e.source_id"),
638                    Direction::Both => (
639                        "(e.source_id = t.node_id OR e.target_id = t.node_id)",
640                        "CASE WHEN e.source_id = t.node_id THEN e.target_id ELSE e.source_id END",
641                    ),
642                };
643
644                let mut relation_cond = String::new();
645                let mut relation_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
646                let mut param_idx = 4;
647
648                if let Some(ref rels) = opts.relations {
649                    if !rels.is_empty() {
650                        let placeholders: Vec<String> = rels
651                            .iter()
652                            .map(|r| {
653                                relation_params.push(Box::new(r.to_string()));
654                                let p = format!("?{}", param_idx);
655                                param_idx += 1;
656                                p
657                            })
658                            .collect();
659                        relation_cond =
660                            format!(" AND e.relation IN ({})", placeholders.join(","));
661                    }
662                }
663
664                let mut weight_cond = String::new();
665                if let Some(min_w) = opts.min_weight {
666                    relation_params.push(Box::new(min_w));
667                    weight_cond = format!(" AND e.weight >= ?{}", param_idx);
668                    param_idx += 1;
669                }
670
671                let limit_clause = if let Some(lim) = opts.limit {
672                    relation_params.push(Box::new(lim as i64));
673                    format!(" LIMIT ?{}", param_idx)
674                } else {
675                    String::new()
676                };
677
678                let cte_sql = format!(
679                    "WITH RECURSIVE traversal(node_id, edge_id, depth, path, total_weight) AS (\
680                         SELECT ?2, NULL, 0, ?2, 0.0 \
681                         UNION ALL \
682                         SELECT {next_node}, e.id, t.depth + 1, \
683                                t.path || ',' || {next_node}, \
684                                t.total_weight + e.weight \
685                         FROM graph_edges e \
686                         JOIN traversal t ON {join_condition} \
687                         WHERE e.namespace = ?1 \
688                           AND t.depth < ?3 \
689                           AND (',' || t.path || ',') NOT LIKE '%,' || {next_node} || ',%'{rel_cond}{wt_cond} \
690                     ) \
691                     SELECT node_id, edge_id, depth, path, total_weight \
692                     FROM traversal WHERE depth > 0 \
693                     ORDER BY depth{limit}",
694                    next_node = next_node,
695                    join_condition = join_condition,
696                    rel_cond = relation_cond,
697                    wt_cond = weight_cond,
698                    limit = limit_clause,
699                );
700
701                let mut stmt = conn.prepare(&cte_sql)?;
702
703                let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
704                all_params.push(Box::new(namespace.clone()));
705                all_params.push(Box::new(root_str.clone()));
706                all_params.push(Box::new(opts.max_depth as i64));
707                all_params.extend(relation_params);
708
709                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
710                    all_params.iter().map(|p| p.as_ref()).collect();
711
712                let rows = stmt.query_map(param_refs.as_slice(), |row| {
713                    let node_str: String = row.get(0)?;
714                    let edge_str: Option<String> = row.get(1)?;
715                    let depth: i64 = row.get(2)?;
716                    let _path: String = row.get(3)?;
717                    let total_weight: f64 = row.get(4)?;
718                    Ok((node_str, edge_str, depth, total_weight))
719                })?;
720
721                let mut nodes = Vec::new();
722                let mut max_weight = 0.0f64;
723
724                if include_roots {
725                    nodes.push(PathNode {
726                        node_id: *root_id,
727                        via_edge: None,
728                        depth: 0,
729                        name: None,
730                        kind: None,
731                    });
732                }
733
734                for row in rows {
735                    let (node_str, edge_str, depth, total_weight) = row?;
736                    let node_id = parse_uuid(&node_str)?;
737                    let via_edge = edge_str.map(|s| parse_uuid(&s)).transpose()?;
738                    nodes.push(PathNode {
739                        node_id,
740                        via_edge,
741                        depth: depth as usize,
742                        name: None,
743                        kind: None,
744                    });
745                    if total_weight > max_weight {
746                        max_weight = total_weight;
747                    }
748                }
749
750                if nodes.len() > if include_roots { 1 } else { 0 } || include_roots {
751                    all_paths.push(GraphPath {
752                        root_id: *root_id,
753                        nodes,
754                        total_weight: max_weight,
755                    });
756                }
757            }
758
759            Ok(all_paths)
760        })
761        .await
762    }
763}
764
765// =============================================================================
766// DDL
767// =============================================================================
768
769const GRAPH_DDL: &str = "\
770    CREATE TABLE IF NOT EXISTS graph_edges (\
771        namespace TEXT NOT NULL,\
772        id TEXT NOT NULL,\
773        source_id TEXT NOT NULL,\
774        target_id TEXT NOT NULL,\
775        relation TEXT NOT NULL,\
776        weight REAL NOT NULL DEFAULT 1.0,\
777        created_at INTEGER NOT NULL,\
778        metadata TEXT,\
779        PRIMARY KEY (namespace, id)\
780    );\
781    CREATE UNIQUE INDEX IF NOT EXISTS idx_graph_edges_unique_triple ON graph_edges(namespace, source_id, target_id, relation);\
782    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_source ON graph_edges(namespace, source_id);\
783    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_target ON graph_edges(namespace, target_id);\
784    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_relation ON graph_edges(namespace, relation);\
785    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_src_rel ON graph_edges(namespace, source_id, relation);\
786    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_tgt_rel ON graph_edges(namespace, target_id, relation);\
787";
788
789pub(crate) fn ensure_graph_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
790    conn.execute_batch(GRAPH_DDL)
791}
792
#[cfg(test)]
mod tests {
    // Every test runs against a fresh in-memory pool scoped to the "default"
    // namespace.
    use super::*;
    use crate::pool::PoolConfig;
    use khive_storage::types::{Direction, TraversalOptions};

    /// Build a store over a new in-memory pool with the graph schema applied.
    fn setup_memory_store() -> SqlGraphStore {
        let pool = Arc::new(
            ConnectionPool::new(PoolConfig {
                path: None,
                ..PoolConfig::default()
            })
            .unwrap(),
        );

        // The inner scope guarantees the writer handle is dropped before
        // `pool` is moved into the store.
        {
            let writer = pool.writer().unwrap();
            writer.conn().execute_batch(GRAPH_DDL).unwrap();
        }

        SqlGraphStore::new_scoped(pool, false, "default")
    }

    /// Construct an edge with a fresh random id, the current timestamp and no
    /// metadata.
    fn make_edge(source: Uuid, target: Uuid, relation: EdgeRelation, weight: f64) -> Edge {
        let id = Uuid::new_v4().into();
        Edge {
            id,
            source_id: source,
            target_id: target,
            relation,
            weight,
            created_at: Utc::now(),
            metadata: None,
        }
    }

    /// Upsert a single edge, panicking on any storage error.
    async fn put(store: &SqlGraphStore, edge: Edge) {
        store.upsert_edge(edge).await.unwrap();
    }

    #[tokio::test]
    async fn test_upsert_and_get_edge() {
        let store = setup_memory_store();
        let (src, tgt) = (Uuid::new_v4(), Uuid::new_v4());

        let edge = make_edge(src, tgt, EdgeRelation::Extends, 0.8);
        let edge_id = edge.id;
        put(&store, edge).await;

        let maybe_stored = store.get_edge(edge_id).await.unwrap();
        assert!(maybe_stored.is_some());
        let stored = maybe_stored.unwrap();
        assert_eq!(stored.id, edge_id);
        assert_eq!(stored.source_id, src);
        assert_eq!(stored.target_id, tgt);
        assert_eq!(stored.relation, EdgeRelation::Extends);
        assert!((stored.weight - 0.8).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_delete_edge() {
        let store = setup_memory_store();

        let edge = make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Contains, 1.0);
        let edge_id = edge.id;

        put(&store, edge).await;
        assert!(store.get_edge(edge_id).await.unwrap().is_some());

        // The first delete removes the row; a repeat delete reports nothing removed.
        assert!(store.delete_edge(edge_id).await.unwrap());
        assert!(store.get_edge(edge_id).await.unwrap().is_none());
        assert!(!store.delete_edge(edge_id).await.unwrap());
    }

    #[tokio::test]
    async fn test_count_edges() {
        let store = setup_memory_store();

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 0);

        for _ in 0..5 {
            let edge = make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::DependsOn, 1.0);
            put(&store, edge).await;
        }

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_neighbors_outbound() {
        let store = setup_memory_store();
        let [a, b, c, d] = [Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4()];

        // `a` has two outgoing edges; `d -> a` is incoming and must not appear.
        for (from, to, relation, weight) in [
            (a, b, EdgeRelation::Extends, 1.0),
            (a, c, EdgeRelation::DependsOn, 0.7),
            (d, a, EdgeRelation::Extends, 0.5),
        ] {
            put(&store, make_edge(from, to, relation, weight)).await;
        }

        let outbound = store
            .neighbors(
                a,
                NeighborQuery {
                    direction: Direction::Out,
                    relations: None,
                    limit: None,
                    min_weight: None,
                },
            )
            .await
            .unwrap();
        assert_eq!(outbound.len(), 2);

        let reached: Vec<Uuid> = outbound.iter().map(|hit| hit.node_id).collect();
        assert!(reached.contains(&b));
        assert!(reached.contains(&c));
        assert!(!reached.contains(&d));
    }

    #[tokio::test]
    async fn test_traverse_depth_2() {
        let store = setup_memory_store();
        let [a, b, c, d] = [Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4()];

        // Chain a -> b -> c -> d; a two-hop walk from `a` must not reach `d`.
        for (from, to, weight) in [(a, b, 1.0), (b, c, 2.0), (c, d, 3.0)] {
            put(&store, make_edge(from, to, EdgeRelation::Extends, weight)).await;
        }

        let paths = store
            .traverse(TraversalRequest {
                roots: vec![a],
                options: TraversalOptions::new(2).with_direction(Direction::Out),
                include_roots: true,
            })
            .await
            .unwrap();
        assert_eq!(paths.len(), 1);

        let visited: Vec<Uuid> = paths[0].nodes.iter().map(|node| node.node_id).collect();
        assert!(visited.contains(&a));
        assert!(visited.contains(&b));
        assert!(visited.contains(&c));
        assert!(!visited.contains(&d));
    }

    #[tokio::test]
    async fn test_metadata_roundtrip() {
        let store = setup_memory_store();
        let meta = serde_json::json!({"note": "important link", "confidence": 0.95});

        let mut edge = make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Implements, 0.9);
        edge.metadata = Some(meta.clone());
        let edge_id = edge.id;

        put(&store, edge).await;

        let via_get = store.get_edge(edge_id).await.unwrap().unwrap();
        assert_eq!(
            via_get.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via get_edge"
        );

        // The paginated listing is a second read path; check it as well.
        let page = store
            .query_edges(EdgeFilter::default(), vec![], PageRequest::default())
            .await
            .unwrap();
        let via_query = page
            .items
            .iter()
            .find(|candidate| candidate.id == edge_id)
            .expect("edge must appear in query_edges result");
        assert_eq!(
            via_query.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via query_edges"
        );
    }

    #[tokio::test]
    async fn test_upsert_edges_batch() {
        let store = setup_memory_store();

        // Weights 0.0 ..= 9.0, one per edge.
        let batch: Vec<Edge> = (0_i32..10)
            .map(f64::from)
            .map(|weight| {
                make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Implements, weight)
            })
            .collect();

        let summary = store.upsert_edges(batch).await.unwrap();
        assert_eq!(summary.attempted, 10);
        assert_eq!(summary.affected, 10);
        assert_eq!(summary.failed, 0);

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 10);
    }

    // ---- #229 deduplication test ----

    #[tokio::test]
    async fn graph_duplicate_edges_ignored() {
        let store = setup_memory_store();
        let (src, tgt) = (Uuid::new_v4(), Uuid::new_v4());

        // Same (source_id, target_id, relation) triple, distinct ids and weights.
        let first = make_edge(src, tgt, EdgeRelation::Extends, 1.0);
        let second = make_edge(src, tgt, EdgeRelation::Extends, 0.5);

        put(&store, first).await;
        put(&store, second).await;

        assert_eq!(
            store.count_edges(EdgeFilter::default()).await.unwrap(),
            1,
            "duplicate (source, target, relation) triple must be ignored; only one edge must exist"
        );
    }
}