Skip to main content

khive_db/stores/
graph.rs

1//! SQL-backed `GraphStore` implementation.
2//!
3//! `SqlGraphStore` stores graph edges in a regular SQLite table.
4//! Traversal uses recursive CTEs for multi-hop queries.
5//!
6//! # Connection strategy
7//!
8//! - **File-backed**: Opens standalone connections per operation.
9//! - **In-memory**: Acquires pool connections per operation via `spawn_blocking`.
10
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use chrono::{DateTime, TimeZone, Utc};
15use uuid::Uuid;
16
17use khive_storage::error::StorageError;
18use khive_storage::types::{
19    BatchWriteSummary, DeleteMode, Edge, EdgeFilter, EdgeSortField, GraphPath, NeighborHit,
20    NeighborQuery, Page, PageRequest, PathNode, SortDirection, SortOrder, TraversalRequest,
21};
22use khive_storage::GraphStore;
23use khive_storage::LinkId;
24use khive_storage::StorageCapability;
25use khive_types::EdgeRelation;
26
27use crate::error::SqliteError;
28use crate::pool::ConnectionPool;
29
30/// Map a rusqlite error to `StorageError` with `Graph` capability.
31fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
32    StorageError::driver(StorageCapability::Graph, op, e)
33}
34
35fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
36    StorageError::driver(StorageCapability::Graph, op, e)
37}
38
39/// A GraphStore backed by SQLite tables.
40pub struct SqlGraphStore {
41    pool: Arc<ConnectionPool>,
42    is_file_backed: bool,
43    namespace: String,
44}
45
46impl SqlGraphStore {
47    /// Create a new store scoped to one namespace.
48    pub fn new_scoped(
49        pool: Arc<ConnectionPool>,
50        is_file_backed: bool,
51        namespace: impl Into<String>,
52    ) -> Self {
53        Self {
54            pool,
55            is_file_backed,
56            namespace: namespace.into(),
57        }
58    }
59
60    fn open_standalone_writer(&self) -> Result<rusqlite::Connection, StorageError> {
61        let config = self.pool.config();
62        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
63            operation: "graph_writer".into(),
64            message: "in-memory databases do not support standalone connections".into(),
65        })?;
66
67        let conn = rusqlite::Connection::open_with_flags(
68            path,
69            rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE
70                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
71                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
72        )
73        .map_err(|e| map_err(e, "open_graph_writer"))?;
74
75        conn.busy_timeout(config.busy_timeout)
76            .map_err(|e| map_err(e, "open_graph_writer"))?;
77        conn.pragma_update(None, "foreign_keys", "ON")
78            .map_err(|e| map_err(e, "open_graph_writer"))?;
79        conn.pragma_update(None, "synchronous", "NORMAL")
80            .map_err(|e| map_err(e, "open_graph_writer"))?;
81
82        Ok(conn)
83    }
84
85    fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
86        let config = self.pool.config();
87        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
88            operation: "graph_reader".into(),
89            message: "in-memory databases do not support standalone connections".into(),
90        })?;
91
92        let conn = rusqlite::Connection::open_with_flags(
93            path,
94            rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
95                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
96                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
97        )
98        .map_err(|e| map_err(e, "open_graph_reader"))?;
99
100        conn.busy_timeout(config.busy_timeout)
101            .map_err(|e| map_err(e, "open_graph_reader"))?;
102        conn.pragma_update(None, "foreign_keys", "ON")
103            .map_err(|e| map_err(e, "open_graph_reader"))?;
104        conn.pragma_update(None, "synchronous", "NORMAL")
105            .map_err(|e| map_err(e, "open_graph_reader"))?;
106
107        Ok(conn)
108    }
109
110    async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
111    where
112        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
113        R: Send + 'static,
114    {
115        if self.is_file_backed {
116            let conn = self.open_standalone_writer()?;
117            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
118                .await
119                .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
120        } else {
121            let pool = Arc::clone(&self.pool);
122            tokio::task::spawn_blocking(move || {
123                let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
124                f(guard.conn()).map_err(|e| map_err(e, op))
125            })
126            .await
127            .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
128        }
129    }
130
131    async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
132    where
133        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
134        R: Send + 'static,
135    {
136        if self.is_file_backed {
137            let conn = self.open_standalone_reader()?;
138            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
139                .await
140                .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
141        } else {
142            let pool = Arc::clone(&self.pool);
143            tokio::task::spawn_blocking(move || {
144                let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
145                f(guard.conn()).map_err(|e| map_err(e, op))
146            })
147            .await
148            .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
149        }
150    }
151}
152
153// =============================================================================
154// Helpers
155// =============================================================================
156
157fn read_edge(row: &rusqlite::Row<'_>) -> Result<Edge, rusqlite::Error> {
158    let namespace: String = row.get(0)?;
159    let id_str: String = row.get(1)?;
160    let source_str: String = row.get(2)?;
161    let target_str: String = row.get(3)?;
162    let relation_str: String = row.get(4)?;
163    let weight: f64 = row.get(5)?;
164    let created_micros: i64 = row.get(6)?;
165    let updated_micros: i64 = row.get(7)?;
166    let deleted_micros: Option<i64> = row.get(8)?;
167    let metadata_str: Option<String> = row.get(9)?;
168    let target_backend: Option<String> = row.get(10)?;
169
170    let id = parse_uuid(&id_str)?;
171    let source_id = parse_uuid(&source_str)?;
172    let target_id = parse_uuid(&target_str)?;
173    let created_at = micros_to_datetime(created_micros);
174    let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
175        rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(e))
176    })?;
177    let metadata = match metadata_str {
178        Some(s) => {
179            let v = serde_json::from_str(&s).map_err(|e| {
180                rusqlite::Error::FromSqlConversionFailure(
181                    9,
182                    rusqlite::types::Type::Text,
183                    Box::new(e),
184                )
185            })?;
186            Some(v)
187        }
188        None => None,
189    };
190
191    Ok(Edge {
192        id: id.into(),
193        namespace,
194        source_id,
195        target_id,
196        relation,
197        weight,
198        created_at,
199        updated_at: micros_to_datetime(updated_micros),
200        deleted_at: deleted_micros.map(micros_to_datetime),
201        metadata,
202        target_backend,
203    })
204}
205
206fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
207    Uuid::parse_str(s).map_err(|e| {
208        rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
209    })
210}
211
212fn micros_to_datetime(micros: i64) -> DateTime<Utc> {
213    Utc.timestamp_micros(micros)
214        .single()
215        .unwrap_or_else(Utc::now)
216}
217
218fn build_edge_filter_sql(
219    namespace: &str,
220    filter: &EdgeFilter,
221) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
222    let mut conditions: Vec<String> = vec![
223        "namespace = ?1".to_string(),
224        "deleted_at IS NULL".to_string(),
225    ];
226    let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
227
228    if !filter.ids.is_empty() {
229        let placeholders: Vec<String> = filter
230            .ids
231            .iter()
232            .map(|id| {
233                params.push(Box::new(id.to_string()));
234                format!("?{}", params.len())
235            })
236            .collect();
237        conditions.push(format!("id IN ({})", placeholders.join(",")));
238    }
239
240    if !filter.source_ids.is_empty() {
241        let placeholders: Vec<String> = filter
242            .source_ids
243            .iter()
244            .map(|id| {
245                params.push(Box::new(id.to_string()));
246                format!("?{}", params.len())
247            })
248            .collect();
249        conditions.push(format!("source_id IN ({})", placeholders.join(",")));
250    }
251
252    if !filter.target_ids.is_empty() {
253        let placeholders: Vec<String> = filter
254            .target_ids
255            .iter()
256            .map(|id| {
257                params.push(Box::new(id.to_string()));
258                format!("?{}", params.len())
259            })
260            .collect();
261        conditions.push(format!("target_id IN ({})", placeholders.join(",")));
262    }
263
264    if !filter.relations.is_empty() {
265        let placeholders: Vec<String> = filter
266            .relations
267            .iter()
268            .map(|r| {
269                params.push(Box::new(r.to_string()));
270                format!("?{}", params.len())
271            })
272            .collect();
273        conditions.push(format!("relation IN ({})", placeholders.join(",")));
274    }
275
276    if let Some(min_w) = filter.min_weight {
277        params.push(Box::new(min_w));
278        conditions.push(format!("weight >= ?{}", params.len()));
279    }
280
281    if let Some(max_w) = filter.max_weight {
282        params.push(Box::new(max_w));
283        conditions.push(format!("weight <= ?{}", params.len()));
284    }
285
286    if let Some(ref time_range) = filter.created_at {
287        if let Some(start) = time_range.start {
288            params.push(Box::new(start.timestamp_micros()));
289            conditions.push(format!("created_at >= ?{}", params.len()));
290        }
291        if let Some(end) = time_range.end {
292            params.push(Box::new(end.timestamp_micros()));
293            conditions.push(format!("created_at < ?{}", params.len()));
294        }
295    }
296
297    let clause = format!(" WHERE {}", conditions.join(" AND "));
298    (clause, params)
299}
300
301fn edge_sort_col(field: &EdgeSortField) -> &'static str {
302    match field {
303        EdgeSortField::CreatedAt => "created_at",
304        EdgeSortField::Weight => "weight",
305        EdgeSortField::Relation => "relation",
306    }
307}
308
309// =============================================================================
310// GraphStore implementation
311// =============================================================================
312
313#[async_trait]
314impl GraphStore for SqlGraphStore {
315    async fn upsert_edge(&self, edge: Edge) -> Result<(), StorageError> {
316        let namespace = self.namespace.clone();
317        if edge.namespace != namespace {
318            return Err(StorageError::InvalidInput {
319                capability: StorageCapability::Graph,
320                operation: "upsert_edge".into(),
321                message: format!(
322                    "edge namespace {:?} does not match store namespace {:?}",
323                    edge.namespace, namespace
324                ),
325            });
326        }
327        let id_str = Uuid::from(edge.id).to_string();
328        let src_str = edge.source_id.to_string();
329        let tgt_str = edge.target_id.to_string();
330        let relation_str = edge.relation.to_string();
331        let metadata_str = edge
332            .metadata
333            .as_ref()
334            .map(serde_json::to_string)
335            .transpose()
336            .map_err(|e| StorageError::driver(StorageCapability::Graph, "upsert_edge", e))?;
337        self.with_writer("upsert_edge", move |conn| {
338            conn.execute(
339                "INSERT INTO graph_edges \
340                 (namespace, id, source_id, target_id, relation, weight, \
341                  created_at, updated_at, deleted_at, metadata, target_backend) \
342                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11) \
343                 ON CONFLICT(namespace, id) DO UPDATE SET \
344                     source_id = excluded.source_id, \
345                     target_id = excluded.target_id, \
346                     relation = excluded.relation, \
347                     weight = excluded.weight, \
348                     updated_at = excluded.updated_at, \
349                     deleted_at = NULL, \
350                     metadata = excluded.metadata, \
351                     target_backend = excluded.target_backend \
352                 ON CONFLICT(namespace, source_id, target_id, relation) DO UPDATE SET \
353                     weight = excluded.weight, \
354                     updated_at = excluded.updated_at, \
355                     deleted_at = NULL, \
356                     metadata = excluded.metadata, \
357                     target_backend = excluded.target_backend",
358                rusqlite::params![
359                    namespace,
360                    id_str,
361                    src_str,
362                    tgt_str,
363                    relation_str,
364                    edge.weight,
365                    edge.created_at.timestamp_micros(),
366                    edge.updated_at.timestamp_micros(),
367                    edge.deleted_at.map(|t| t.timestamp_micros()),
368                    metadata_str,
369                    edge.target_backend,
370                ],
371            )?;
372            Ok(())
373        })
374        .await
375    }
376
377    async fn upsert_edges(&self, edges: Vec<Edge>) -> Result<BatchWriteSummary, StorageError> {
378        let attempted = edges.len() as u64;
379        let namespace = self.namespace.clone();
380
381        // Validate namespaces before acquiring writer.
382        for edge in &edges {
383            if edge.namespace != namespace {
384                return Err(StorageError::InvalidInput {
385                    capability: StorageCapability::Graph,
386                    operation: "upsert_edges".into(),
387                    message: format!(
388                        "edge namespace {:?} does not match store namespace {:?}",
389                        edge.namespace, namespace
390                    ),
391                });
392            }
393        }
394
395        self.with_writer("upsert_edges", move |conn| {
396            conn.execute_batch("BEGIN IMMEDIATE")?;
397            let mut affected = 0u64;
398
399            for edge in &edges {
400                let id_str = Uuid::from(edge.id).to_string();
401                let src_str = edge.source_id.to_string();
402                let tgt_str = edge.target_id.to_string();
403                let relation_str = edge.relation.to_string();
404                let metadata_str = edge
405                    .metadata
406                    .as_ref()
407                    .map(serde_json::to_string)
408                    .transpose()
409                    .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
410                if let Err(e) = conn.execute(
411                    "INSERT INTO graph_edges \
412                     (namespace, id, source_id, target_id, relation, weight, \
413                      created_at, updated_at, deleted_at, metadata, target_backend) \
414                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11) \
415                     ON CONFLICT(namespace, id) DO UPDATE SET \
416                         source_id = excluded.source_id, \
417                         target_id = excluded.target_id, \
418                         relation = excluded.relation, \
419                         weight = excluded.weight, \
420                         updated_at = excluded.updated_at, \
421                         deleted_at = NULL, \
422                         metadata = excluded.metadata, \
423                         target_backend = excluded.target_backend \
424                     ON CONFLICT(namespace, source_id, target_id, relation) DO UPDATE SET \
425                         weight = excluded.weight, \
426                         updated_at = excluded.updated_at, \
427                         deleted_at = NULL, \
428                         metadata = excluded.metadata, \
429                         target_backend = excluded.target_backend",
430                    rusqlite::params![
431                        &namespace,
432                        id_str,
433                        src_str,
434                        tgt_str,
435                        relation_str,
436                        edge.weight,
437                        edge.created_at.timestamp_micros(),
438                        edge.updated_at.timestamp_micros(),
439                        edge.deleted_at.map(|t| t.timestamp_micros()),
440                        metadata_str,
441                        edge.target_backend.as_deref(),
442                    ],
443                ) {
444                    let _ = conn.execute_batch("ROLLBACK");
445                    return Err(e);
446                }
447                affected += 1;
448            }
449
450            if let Err(e) = conn.execute_batch("COMMIT") {
451                let _ = conn.execute_batch("ROLLBACK");
452                return Err(e);
453            }
454            Ok(BatchWriteSummary {
455                attempted,
456                affected,
457                failed: 0,
458                first_error: String::new(),
459            })
460        })
461        .await
462    }
463
464    async fn get_edge(&self, id: LinkId) -> Result<Option<Edge>, StorageError> {
465        let namespace = self.namespace.clone();
466        let id_str = Uuid::from(id).to_string();
467
468        self.with_reader("get_edge", move |conn| {
469            let mut stmt = conn.prepare(
470                "SELECT namespace, id, source_id, target_id, relation, weight, \
471                        created_at, updated_at, deleted_at, metadata, target_backend \
472                 FROM graph_edges WHERE namespace = ?1 AND id = ?2 AND deleted_at IS NULL",
473            )?;
474            let mut rows = stmt.query(rusqlite::params![namespace, id_str])?;
475            match rows.next()? {
476                Some(row) => Ok(Some(read_edge(row)?)),
477                None => Ok(None),
478            }
479        })
480        .await
481    }
482
483    async fn delete_edge(&self, id: LinkId, mode: DeleteMode) -> Result<bool, StorageError> {
484        let namespace = self.namespace.clone();
485        let id_str = Uuid::from(id).to_string();
486
487        self.with_writer("delete_edge", move |conn| {
488            let affected = match mode {
489                DeleteMode::Soft => conn.execute(
490                    "UPDATE graph_edges SET deleted_at = ?3, updated_at = ?3 \
491                     WHERE namespace = ?1 AND id = ?2 AND deleted_at IS NULL",
492                    rusqlite::params![namespace, id_str, chrono::Utc::now().timestamp_micros(),],
493                )?,
494                DeleteMode::Hard => conn.execute(
495                    "DELETE FROM graph_edges WHERE namespace = ?1 AND id = ?2",
496                    rusqlite::params![namespace, id_str],
497                )?,
498            };
499            Ok(affected > 0)
500        })
501        .await
502    }
503
504    async fn query_edges(
505        &self,
506        filter: EdgeFilter,
507        sort: Vec<SortOrder<EdgeSortField>>,
508        page: PageRequest,
509    ) -> Result<Page<Edge>, StorageError> {
510        let namespace = self.namespace.clone();
511        self.with_reader("query_edges", move |conn| {
512            let (where_clause, filter_params) = build_edge_filter_sql(&namespace, &filter);
513
514            let count_sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
515            let total: i64 = {
516                let mut stmt = conn.prepare(&count_sql)?;
517                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
518                    filter_params.iter().map(|p| p.as_ref()).collect();
519                stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
520            };
521
522            let order_clause = if sort.is_empty() {
523                " ORDER BY created_at DESC".to_string()
524            } else {
525                let parts: Vec<String> = sort
526                    .iter()
527                    .map(|s| {
528                        let dir = match s.direction {
529                            SortDirection::Asc => "ASC",
530                            SortDirection::Desc => "DESC",
531                        };
532                        format!("{} {}", edge_sort_col(&s.field), dir)
533                    })
534                    .collect();
535                format!(" ORDER BY {}", parts.join(", "))
536            };
537
538            let (_, data_filter_params) = build_edge_filter_sql(&namespace, &filter);
539            let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = data_filter_params;
540            all_params.push(Box::new(page.limit as i64));
541            all_params.push(Box::new(page.offset as i64));
542
543            let limit_idx = all_params.len() - 1;
544            let offset_idx = all_params.len();
545
546            let data_sql = format!(
547                "SELECT namespace, id, source_id, target_id, relation, weight, \
548                        created_at, updated_at, deleted_at, metadata, target_backend \
549                 FROM graph_edges{}{} LIMIT ?{} OFFSET ?{}",
550                where_clause, order_clause, limit_idx, offset_idx,
551            );
552
553            let mut stmt = conn.prepare(&data_sql)?;
554            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
555                all_params.iter().map(|p| p.as_ref()).collect();
556            let rows = stmt.query_map(param_refs.as_slice(), read_edge)?;
557
558            let mut items = Vec::new();
559            for row in rows {
560                items.push(row?);
561            }
562
563            Ok(Page {
564                items,
565                total: Some(total as u64),
566            })
567        })
568        .await
569    }
570
571    async fn count_edges(&self, filter: EdgeFilter) -> Result<u64, StorageError> {
572        let namespace = self.namespace.clone();
573        self.with_reader("count_edges", move |conn| {
574            let (where_clause, params) = build_edge_filter_sql(&namespace, &filter);
575            let sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
576            let mut stmt = conn.prepare(&sql)?;
577            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
578                params.iter().map(|p| p.as_ref()).collect();
579            let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
580            Ok(count as u64)
581        })
582        .await
583    }
584
585    async fn neighbors(
586        &self,
587        node_id: Uuid,
588        query: NeighborQuery,
589    ) -> Result<Vec<NeighborHit>, StorageError> {
590        use khive_storage::types::Direction;
591
592        let namespace = self.namespace.clone();
593        let node_str = node_id.to_string();
594
595        self.with_reader("neighbors", move |conn| {
596            let base_out = "SELECT target_id AS node_id, id AS edge_id, relation, weight \
597                            FROM graph_edges \
598                            WHERE namespace = ?1 AND source_id = ?2 AND deleted_at IS NULL";
599            let base_in = "SELECT source_id AS node_id, id AS edge_id, relation, weight \
600                           FROM graph_edges \
601                           WHERE namespace = ?1 AND target_id = ?2 AND deleted_at IS NULL";
602
603            let sql = match query.direction {
604                Direction::Out => base_out.to_string(),
605                Direction::In => base_in.to_string(),
606                Direction::Both => format!("{} UNION ALL {}", base_out, base_in),
607            };
608
609            let mut conditions: Vec<String> = Vec::new();
610            let mut extra_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
611            let mut param_idx = 3;
612
613            if let Some(ref rels) = query.relations {
614                if !rels.is_empty() {
615                    let placeholders: Vec<String> = rels
616                        .iter()
617                        .map(|r| {
618                            extra_params.push(Box::new(r.to_string()));
619                            let p = format!("?{}", param_idx);
620                            param_idx += 1;
621                            p
622                        })
623                        .collect();
624                    conditions.push(format!("relation IN ({})", placeholders.join(",")));
625                }
626            }
627
628            if let Some(min_w) = query.min_weight {
629                extra_params.push(Box::new(min_w));
630                conditions.push(format!("weight >= ?{}", param_idx));
631                param_idx += 1;
632            }
633
634            let where_extra = if conditions.is_empty() {
635                String::new()
636            } else {
637                format!(" WHERE {}", conditions.join(" AND "))
638            };
639
640            let limit_clause = if let Some(lim) = query.limit {
641                extra_params.push(Box::new(lim as i64));
642                format!(" LIMIT ?{}", param_idx)
643            } else {
644                String::new()
645            };
646
647            let full_sql = format!(
648                "SELECT node_id, edge_id, relation, weight FROM ({}){}{}",
649                sql, where_extra, limit_clause
650            );
651
652            let mut stmt = conn.prepare(&full_sql)?;
653
654            let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
655            all_params.push(Box::new(namespace.clone()));
656            all_params.push(Box::new(node_str.clone()));
657            all_params.extend(extra_params);
658
659            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
660                all_params.iter().map(|p| p.as_ref()).collect();
661
662            let rows = stmt.query_map(param_refs.as_slice(), |row| {
663                let nid_str: String = row.get(0)?;
664                let eid_str: String = row.get(1)?;
665                let relation_str: String = row.get(2)?;
666                let weight: f64 = row.get(3)?;
667                Ok((nid_str, eid_str, relation_str, weight))
668            })?;
669
670            let mut hits = Vec::new();
671            for row in rows {
672                let (nid_str, eid_str, relation_str, weight) = row?;
673                let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
674                    rusqlite::Error::FromSqlConversionFailure(
675                        2,
676                        rusqlite::types::Type::Text,
677                        Box::new(e),
678                    )
679                })?;
680                hits.push(NeighborHit {
681                    node_id: parse_uuid(&nid_str)?,
682                    edge_id: parse_uuid(&eid_str)?,
683                    relation,
684                    weight,
685                    name: None,
686                    kind: None,
687                });
688            }
689
690            Ok(hits)
691        })
692        .await
693    }
694
695    async fn traverse(&self, request: TraversalRequest) -> Result<Vec<GraphPath>, StorageError> {
696        use khive_storage::types::Direction;
697
698        if request.roots.is_empty() {
699            return Ok(Vec::new());
700        }
701
702        let roots = request.roots.clone();
703        let opts = request.options.clone();
704        let include_roots = request.include_roots;
705        let namespace = self.namespace.clone();
706
707        self.with_reader("traverse", move |conn| {
708            let mut all_paths: Vec<GraphPath> = Vec::new();
709
710            for root_id in &roots {
711                let root_str = root_id.to_string();
712
713                let (join_condition, next_node) = match opts.direction {
714                    Direction::Out => ("e.source_id = t.node_id", "e.target_id"),
715                    Direction::In => ("e.target_id = t.node_id", "e.source_id"),
716                    Direction::Both => (
717                        "(e.source_id = t.node_id OR e.target_id = t.node_id)",
718                        "CASE WHEN e.source_id = t.node_id THEN e.target_id ELSE e.source_id END",
719                    ),
720                };
721
722                let mut relation_cond = String::new();
723                let mut relation_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
724                let mut param_idx = 4;
725
726                if let Some(ref rels) = opts.relations {
727                    if !rels.is_empty() {
728                        let placeholders: Vec<String> = rels
729                            .iter()
730                            .map(|r| {
731                                relation_params.push(Box::new(r.to_string()));
732                                let p = format!("?{}", param_idx);
733                                param_idx += 1;
734                                p
735                            })
736                            .collect();
737                        relation_cond =
738                            format!(" AND e.relation IN ({})", placeholders.join(","));
739                    }
740                }
741
742                let mut weight_cond = String::new();
743                if let Some(min_w) = opts.min_weight {
744                    relation_params.push(Box::new(min_w));
745                    weight_cond = format!(" AND e.weight >= ?{}", param_idx);
746                    param_idx += 1;
747                }
748
749                let limit_clause = if let Some(lim) = opts.limit {
750                    relation_params.push(Box::new(lim as i64));
751                    format!(" LIMIT ?{}", param_idx)
752                } else {
753                    String::new()
754                };
755
756                let cte_sql = format!(
757                    "WITH RECURSIVE traversal(node_id, edge_id, depth, path, total_weight) AS (\
758                         SELECT ?2, NULL, 0, ?2, 0.0 \
759                         UNION ALL \
760                         SELECT {next_node}, e.id, t.depth + 1, \
761                                t.path || ',' || {next_node}, \
762                                t.total_weight + e.weight \
763                         FROM graph_edges e \
764                         JOIN traversal t ON {join_condition} \
765                         WHERE e.namespace = ?1 \
766                           AND e.deleted_at IS NULL \
767                           AND t.depth < ?3 \
768                           AND (',' || t.path || ',') NOT LIKE '%,' || {next_node} || ',%'{rel_cond}{wt_cond} \
769                     ) \
770                     SELECT node_id, edge_id, depth, path, total_weight \
771                     FROM traversal WHERE depth > 0 \
772                     ORDER BY depth{limit}",
773                    next_node = next_node,
774                    join_condition = join_condition,
775                    rel_cond = relation_cond,
776                    wt_cond = weight_cond,
777                    limit = limit_clause,
778                );
779
780                let mut stmt = conn.prepare(&cte_sql)?;
781
782                let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
783                all_params.push(Box::new(namespace.clone()));
784                all_params.push(Box::new(root_str.clone()));
785                all_params.push(Box::new(opts.max_depth as i64));
786                all_params.extend(relation_params);
787
788                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
789                    all_params.iter().map(|p| p.as_ref()).collect();
790
791                let rows = stmt.query_map(param_refs.as_slice(), |row| {
792                    let node_str: String = row.get(0)?;
793                    let edge_str: Option<String> = row.get(1)?;
794                    let depth: i64 = row.get(2)?;
795                    let _path: String = row.get(3)?;
796                    let total_weight: f64 = row.get(4)?;
797                    Ok((node_str, edge_str, depth, total_weight))
798                })?;
799
800                let mut nodes = Vec::new();
801                let mut max_weight = 0.0f64;
802
803                if include_roots {
804                    nodes.push(PathNode {
805                        node_id: *root_id,
806                        via_edge: None,
807                        depth: 0,
808                        name: None,
809                        kind: None,
810                    });
811                }
812
813                for row in rows {
814                    let (node_str, edge_str, depth, total_weight) = row?;
815                    let node_id = parse_uuid(&node_str)?;
816                    let via_edge = edge_str.map(|s| parse_uuid(&s)).transpose()?;
817                    nodes.push(PathNode {
818                        node_id,
819                        via_edge,
820                        depth: depth as usize,
821                        name: None,
822                        kind: None,
823                    });
824                    if total_weight > max_weight {
825                        max_weight = total_weight;
826                    }
827                }
828
829                if nodes.len() > if include_roots { 1 } else { 0 } || include_roots {
830                    all_paths.push(GraphPath {
831                        root_id: *root_id,
832                        nodes,
833                        total_weight: max_weight,
834                    });
835                }
836            }
837
838            Ok(all_paths)
839        })
840        .await
841    }
842}
843
844// =============================================================================
845// DDL
846// =============================================================================
847
/// Schema for the SQL-backed graph store, executed as a single batch.
///
/// Creates the `graph_edges` table, keyed by `(namespace, id)`, plus its indexes.
/// Every statement uses `IF NOT EXISTS`, so the batch can be re-run safely against
/// an already-initialized database. `weight` defaults to `1.0` when not supplied.
///
/// Index notes:
/// - `idx_graph_edges_unique_triple` allows at most one row per
///   `(namespace, source_id, target_id, relation)`. `deleted_at` is not part of
///   that index, so a soft-deleted row still occupies its triple.
///   NOTE(review): confirm the upsert path handles re-creating an edge whose
///   triple belongs to a soft-deleted row.
/// - The `ns_source`, `ns_target`, `ns_relation`, `ns_src_rel` and `ns_tgt_rel`
///   indexes serve endpoint/relation lookups within a namespace.
///   NOTE(review): `idx_graph_edges_ns_source` is a leading-column prefix of
///   `idx_graph_edges_ns_src_rel` (and of the unique triple index). Likewise,
///   `idx_graph_edges_ns_target` is a prefix of `idx_graph_edges_ns_tgt_rel`.
///   They may be redundant; verify with `EXPLAIN QUERY PLAN` before removing.
/// - `idx_graph_edges_target_backend` is a partial index that only covers rows
///   whose `target_backend` is non-NULL.
///
/// Each line ends with `\`, a Rust string-literal continuation. The newline and
/// the next line's leading whitespace are dropped, so the SQL text contains no
/// newlines and statements are separated only by `;`.
///
/// `created_at`, `updated_at` and `deleted_at` are INTEGER columns.
/// NOTE(review): their unit (seconds/millis/micros) is not visible here; confirm
/// it against the row encode/decode code.
const GRAPH_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS graph_edges (\
        namespace TEXT NOT NULL,\
        id TEXT NOT NULL,\
        source_id TEXT NOT NULL,\
        target_id TEXT NOT NULL,\
        relation TEXT NOT NULL,\
        weight REAL NOT NULL DEFAULT 1.0,\
        created_at INTEGER NOT NULL,\
        updated_at INTEGER NOT NULL,\
        deleted_at INTEGER,\
        metadata TEXT,\
        target_backend TEXT,\
        PRIMARY KEY (namespace, id)\
    );\
    CREATE UNIQUE INDEX IF NOT EXISTS idx_graph_edges_unique_triple ON graph_edges(namespace, source_id, target_id, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_source ON graph_edges(namespace, source_id);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_target ON graph_edges(namespace, target_id);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_relation ON graph_edges(namespace, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_src_rel ON graph_edges(namespace, source_id, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_tgt_rel ON graph_edges(namespace, target_id, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_target_backend ON graph_edges(target_backend) WHERE target_backend IS NOT NULL;\
";
871
872pub(crate) fn ensure_graph_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
873    conn.execute_batch(GRAPH_DDL)
874}
875
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;
    use khive_storage::types::{Direction, TraversalOptions};

    /// Build an in-memory store scoped to the `"default"` namespace with the
    /// graph schema already applied.
    fn setup_memory_store() -> SqlGraphStore {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());

        {
            // Scope the writer so it is dropped before the store takes the pool.
            let writer = pool.writer().unwrap();
            writer.conn().execute_batch(GRAPH_DDL).unwrap();
        }

        SqlGraphStore::new_scoped(pool, false, "default")
    }

    /// Build a live (non-deleted) edge in the `"default"` namespace with a fresh
    /// id, stamping both `created_at` and `updated_at` with `at`.
    ///
    /// Use this instead of `make_edge` when several edges must share one
    /// timestamp (e.g. the natural-key conflict tests).
    fn make_edge_at(
        source: Uuid,
        target: Uuid,
        relation: EdgeRelation,
        weight: f64,
        at: DateTime<Utc>,
    ) -> Edge {
        Edge {
            id: Uuid::new_v4().into(),
            namespace: "default".to_string(),
            source_id: source,
            target_id: target,
            relation,
            weight,
            created_at: at,
            updated_at: at,
            deleted_at: None,
            metadata: None,
            target_backend: None,
        }
    }

    /// Same as `make_edge_at`, stamped with the current time.
    fn make_edge(source: Uuid, target: Uuid, relation: EdgeRelation, weight: f64) -> Edge {
        make_edge_at(source, target, relation, weight, Utc::now())
    }

    #[tokio::test]
    async fn test_upsert_and_get_edge() {
        let store = setup_memory_store();

        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let edge = make_edge(src, tgt, EdgeRelation::Extends, 0.8);
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();

        let fetched = store
            .get_edge(edge_id)
            .await
            .unwrap()
            .expect("upserted edge must be retrievable by id");
        assert_eq!(fetched.id, edge_id);
        assert_eq!(fetched.namespace, "default");
        assert_eq!(fetched.source_id, src);
        assert_eq!(fetched.target_id, tgt);
        assert_eq!(fetched.relation, EdgeRelation::Extends);
        assert!((fetched.weight - 0.8).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_delete_edge() {
        let store = setup_memory_store();

        let edge = make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Contains, 1.0);
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();
        assert!(store.get_edge(edge_id).await.unwrap().is_some());

        let deleted = store.delete_edge(edge_id, DeleteMode::Hard).await.unwrap();
        assert!(deleted);

        assert!(store.get_edge(edge_id).await.unwrap().is_none());

        // A second hard delete finds nothing and reports `false`.
        let deleted_again = store.delete_edge(edge_id, DeleteMode::Hard).await.unwrap();
        assert!(!deleted_again);
    }

    #[tokio::test]
    async fn test_count_edges() {
        let store = setup_memory_store();

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 0);

        for _ in 0..5 {
            store
                .upsert_edge(make_edge(
                    Uuid::new_v4(),
                    Uuid::new_v4(),
                    EdgeRelation::DependsOn,
                    1.0,
                ))
                .await
                .unwrap();
        }

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_neighbors_outbound() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();

        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(a, c, EdgeRelation::DependsOn, 0.7))
            .await
            .unwrap();
        // Inbound edge d -> a must not show up in an outbound query.
        store
            .upsert_edge(make_edge(d, a, EdgeRelation::Extends, 0.5))
            .await
            .unwrap();

        let query = NeighborQuery {
            direction: Direction::Out,
            relations: None,
            limit: None,
            min_weight: None,
        };

        let hits = store.neighbors(a, query).await.unwrap();
        assert_eq!(hits.len(), 2);

        let neighbor_ids: Vec<Uuid> = hits.iter().map(|h| h.node_id).collect();
        assert!(neighbor_ids.contains(&b));
        assert!(neighbor_ids.contains(&c));
        assert!(!neighbor_ids.contains(&d));
    }

    #[tokio::test]
    async fn test_traverse_depth_2() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();

        // Chain a -> b -> c -> d with weights 1, 2, 3.
        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(b, c, EdgeRelation::Extends, 2.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(c, d, EdgeRelation::Extends, 3.0))
            .await
            .unwrap();

        let request = TraversalRequest {
            roots: vec![a],
            options: TraversalOptions::new(2).with_direction(Direction::Out),
            include_roots: true,
        };

        let paths = store.traverse(request).await.unwrap();
        assert_eq!(paths.len(), 1);

        let path = &paths[0];
        assert_eq!(path.root_id, a);

        let node_ids: Vec<Uuid> = path.nodes.iter().map(|n| n.node_id).collect();
        assert!(node_ids.contains(&a));
        assert!(node_ids.contains(&b));
        assert!(node_ids.contains(&c));
        assert!(!node_ids.contains(&d), "d is 3 hops away, beyond max depth 2");

        // With include_roots, the root is reported first at depth 0 with no
        // incoming edge; every reached node carries the edge it was reached by.
        assert_eq!(path.nodes[0].node_id, a);
        assert!(path.nodes[0].via_edge.is_none());

        let node = |id: Uuid| {
            path.nodes
                .iter()
                .find(|n| n.node_id == id)
                .expect("node must be present in the traversal")
        };
        assert_eq!(node(a).depth, 0);
        assert_eq!(node(b).depth, 1);
        assert_eq!(node(c).depth, 2);
        assert!(node(b).via_edge.is_some());
        assert!(node(c).via_edge.is_some());

        // total_weight is the largest cumulative weight reached: a->b->c = 1.0 + 2.0.
        assert!(
            (path.total_weight - 3.0).abs() < 1e-9,
            "expected total_weight 3.0, got {}",
            path.total_weight
        );
    }

    #[tokio::test]
    async fn test_metadata_roundtrip() {
        let store = setup_memory_store();

        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let meta = serde_json::json!({"note": "important link", "confidence": 0.95});
        let mut edge = make_edge(src, tgt, EdgeRelation::Implements, 0.9);
        edge.metadata = Some(meta.clone());
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();

        let fetched = store.get_edge(edge_id).await.unwrap().unwrap();
        assert_eq!(
            fetched.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via get_edge"
        );

        // Also verify via query_edges.
        let page = store
            .query_edges(EdgeFilter::default(), vec![], PageRequest::default())
            .await
            .unwrap();
        let from_query = page
            .items
            .iter()
            .find(|e| e.id == edge_id)
            .expect("edge must appear in query_edges result");
        assert_eq!(
            from_query.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via query_edges"
        );
    }

    #[tokio::test]
    async fn test_upsert_edges_batch() {
        let store = setup_memory_store();

        let edges: Vec<Edge> = (0..10)
            .map(|i| {
                make_edge(
                    Uuid::new_v4(),
                    Uuid::new_v4(),
                    EdgeRelation::Implements,
                    i as f64,
                )
            })
            .collect();

        let summary = store.upsert_edges(edges).await.unwrap();
        assert_eq!(summary.attempted, 10);
        assert_eq!(summary.affected, 10);
        assert_eq!(summary.failed, 0);

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 10);
    }

    // ---- #229 deduplication test ----

    /// Two upserts sharing a `(source_id, target_id, relation)` triple must
    /// collapse to a single row. This test only checks the row count; see
    /// `graph_duplicate_edges_refresh_existing_row` for which values win.
    #[tokio::test]
    async fn graph_duplicate_edges_ignored() {
        let store = setup_memory_store();

        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();

        // Same triple, different ids and weights, identical timestamps.
        let now = Utc::now();
        let edge1 = make_edge_at(src, tgt, EdgeRelation::Extends, 1.0, now);
        let edge2 = make_edge_at(src, tgt, EdgeRelation::Extends, 0.5, now);

        store.upsert_edge(edge1).await.unwrap();
        store.upsert_edge(edge2).await.unwrap();

        assert_eq!(
            store.count_edges(EdgeFilter::default()).await.unwrap(),
            1,
            "duplicate (source, target, relation) triple must be ignored; only one edge must exist"
        );
    }

    // F053 (CRIT): a natural-key conflict must DO UPDATE (refresh weight/metadata),
    // not DO NOTHING. ADR-009 requires the second upsert's weight (0.5) to
    // overwrite the first (1.0).
    #[tokio::test]
    async fn graph_duplicate_edges_refresh_existing_row() {
        let store = setup_memory_store();
        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();

        let now = Utc::now();
        let edge1 = make_edge_at(src, tgt, EdgeRelation::Extends, 1.0, now);
        let edge2 = make_edge_at(src, tgt, EdgeRelation::Extends, 0.5, now);

        store.upsert_edge(edge1).await.unwrap();
        store.upsert_edge(edge2).await.unwrap();

        let edges = store
            .query_edges(EdgeFilter::default(), vec![], PageRequest::default())
            .await
            .unwrap();
        assert_eq!(
            edges.items.len(),
            1,
            "duplicate natural key must collapse to one row"
        );
        assert!(
            (edges.items[0].weight - 0.5).abs() < 0.001,
            "F053: natural-key conflict must DO UPDATE (expected weight=0.5 from second upsert); \
             got stale weight={}",
            edges.items[0].weight
        );
    }
}