1use std::sync::Arc;
12
13use async_trait::async_trait;
14use chrono::{DateTime, TimeZone, Utc};
15use uuid::Uuid;
16
17use khive_storage::error::StorageError;
18use khive_storage::types::{
19 BatchWriteSummary, Edge, EdgeFilter, EdgeSortField, GraphPath, NeighborHit, NeighborQuery,
20 Page, PageRequest, PathNode, SortDirection, SortOrder, TraversalRequest,
21};
22use khive_storage::GraphStore;
23use khive_storage::LinkId;
24use khive_storage::StorageCapability;
25use khive_types::EdgeRelation;
26
27use crate::error::SqliteError;
28use crate::pool::ConnectionPool;
29
30fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
32 StorageError::driver(StorageCapability::Graph, op, e)
33}
34
35fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
36 StorageError::driver(StorageCapability::Graph, op, e)
37}
38
39pub struct SqlGraphStore {
41 pool: Arc<ConnectionPool>,
42 is_file_backed: bool,
43 namespace: String,
44}
45
46impl SqlGraphStore {
47 pub fn new_scoped(
49 pool: Arc<ConnectionPool>,
50 is_file_backed: bool,
51 namespace: impl Into<String>,
52 ) -> Self {
53 Self {
54 pool,
55 is_file_backed,
56 namespace: namespace.into(),
57 }
58 }
59
60 fn open_standalone_writer(&self) -> Result<rusqlite::Connection, StorageError> {
61 let config = self.pool.config();
62 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
63 operation: "graph_writer".into(),
64 message: "in-memory databases do not support standalone connections".into(),
65 })?;
66
67 let conn = rusqlite::Connection::open_with_flags(
68 path,
69 rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE
70 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
71 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
72 )
73 .map_err(|e| map_err(e, "open_graph_writer"))?;
74
75 conn.busy_timeout(config.busy_timeout)
76 .map_err(|e| map_err(e, "open_graph_writer"))?;
77 conn.pragma_update(None, "foreign_keys", "ON")
78 .map_err(|e| map_err(e, "open_graph_writer"))?;
79 conn.pragma_update(None, "synchronous", "NORMAL")
80 .map_err(|e| map_err(e, "open_graph_writer"))?;
81
82 Ok(conn)
83 }
84
85 fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
86 let config = self.pool.config();
87 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
88 operation: "graph_reader".into(),
89 message: "in-memory databases do not support standalone connections".into(),
90 })?;
91
92 let conn = rusqlite::Connection::open_with_flags(
93 path,
94 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
95 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
96 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
97 )
98 .map_err(|e| map_err(e, "open_graph_reader"))?;
99
100 conn.busy_timeout(config.busy_timeout)
101 .map_err(|e| map_err(e, "open_graph_reader"))?;
102 conn.pragma_update(None, "foreign_keys", "ON")
103 .map_err(|e| map_err(e, "open_graph_reader"))?;
104 conn.pragma_update(None, "synchronous", "NORMAL")
105 .map_err(|e| map_err(e, "open_graph_reader"))?;
106
107 Ok(conn)
108 }
109
110 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
111 where
112 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
113 R: Send + 'static,
114 {
115 if self.is_file_backed {
116 let conn = self.open_standalone_writer()?;
117 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
118 .await
119 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
120 } else {
121 let pool = Arc::clone(&self.pool);
122 tokio::task::spawn_blocking(move || {
123 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
124 f(guard.conn()).map_err(|e| map_err(e, op))
125 })
126 .await
127 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
128 }
129 }
130
131 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
132 where
133 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
134 R: Send + 'static,
135 {
136 if self.is_file_backed {
137 let conn = self.open_standalone_reader()?;
138 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
139 .await
140 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
141 } else {
142 let pool = Arc::clone(&self.pool);
143 tokio::task::spawn_blocking(move || {
144 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
145 f(guard.conn()).map_err(|e| map_err(e, op))
146 })
147 .await
148 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
149 }
150 }
151}
152
153fn read_edge(row: &rusqlite::Row<'_>) -> Result<Edge, rusqlite::Error> {
158 let id_str: String = row.get(0)?;
159 let source_str: String = row.get(1)?;
160 let target_str: String = row.get(2)?;
161 let relation_str: String = row.get(3)?;
162 let weight: f64 = row.get(4)?;
163 let created_micros: i64 = row.get(5)?;
164 let metadata_str: Option<String> = row.get(6)?;
165
166 let id = parse_uuid(&id_str)?;
167 let source_id = parse_uuid(&source_str)?;
168 let target_id = parse_uuid(&target_str)?;
169 let created_at = micros_to_datetime(created_micros);
170 let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
171 rusqlite::Error::FromSqlConversionFailure(3, rusqlite::types::Type::Text, Box::new(e))
172 })?;
173 let metadata = metadata_str.and_then(|s| serde_json::from_str(&s).ok());
174
175 Ok(Edge {
176 id: id.into(),
177 source_id,
178 target_id,
179 relation,
180 weight,
181 created_at,
182 metadata,
183 })
184}
185
186fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
187 Uuid::parse_str(s).map_err(|e| {
188 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
189 })
190}
191
192fn micros_to_datetime(micros: i64) -> DateTime<Utc> {
193 Utc.timestamp_micros(micros)
194 .single()
195 .unwrap_or_else(Utc::now)
196}
197
198fn build_edge_filter_sql(
199 namespace: &str,
200 filter: &EdgeFilter,
201) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
202 let mut conditions: Vec<String> = vec!["namespace = ?1".to_string()];
203 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
204
205 if !filter.ids.is_empty() {
206 let placeholders: Vec<String> = filter
207 .ids
208 .iter()
209 .map(|id| {
210 params.push(Box::new(id.to_string()));
211 format!("?{}", params.len())
212 })
213 .collect();
214 conditions.push(format!("id IN ({})", placeholders.join(",")));
215 }
216
217 if !filter.source_ids.is_empty() {
218 let placeholders: Vec<String> = filter
219 .source_ids
220 .iter()
221 .map(|id| {
222 params.push(Box::new(id.to_string()));
223 format!("?{}", params.len())
224 })
225 .collect();
226 conditions.push(format!("source_id IN ({})", placeholders.join(",")));
227 }
228
229 if !filter.target_ids.is_empty() {
230 let placeholders: Vec<String> = filter
231 .target_ids
232 .iter()
233 .map(|id| {
234 params.push(Box::new(id.to_string()));
235 format!("?{}", params.len())
236 })
237 .collect();
238 conditions.push(format!("target_id IN ({})", placeholders.join(",")));
239 }
240
241 if !filter.relations.is_empty() {
242 let placeholders: Vec<String> = filter
243 .relations
244 .iter()
245 .map(|r| {
246 params.push(Box::new(r.to_string()));
247 format!("?{}", params.len())
248 })
249 .collect();
250 conditions.push(format!("relation IN ({})", placeholders.join(",")));
251 }
252
253 if let Some(min_w) = filter.min_weight {
254 params.push(Box::new(min_w));
255 conditions.push(format!("weight >= ?{}", params.len()));
256 }
257
258 if let Some(max_w) = filter.max_weight {
259 params.push(Box::new(max_w));
260 conditions.push(format!("weight <= ?{}", params.len()));
261 }
262
263 if let Some(ref time_range) = filter.created_at {
264 if let Some(start) = time_range.start {
265 params.push(Box::new(start.timestamp_micros()));
266 conditions.push(format!("created_at >= ?{}", params.len()));
267 }
268 if let Some(end) = time_range.end {
269 params.push(Box::new(end.timestamp_micros()));
270 conditions.push(format!("created_at < ?{}", params.len()));
271 }
272 }
273
274 let clause = format!(" WHERE {}", conditions.join(" AND "));
275 (clause, params)
276}
277
278fn edge_sort_col(field: &EdgeSortField) -> &'static str {
279 match field {
280 EdgeSortField::CreatedAt => "created_at",
281 EdgeSortField::Weight => "weight",
282 EdgeSortField::Relation => "relation",
283 }
284}
285
286#[async_trait]
291impl GraphStore for SqlGraphStore {
292 async fn upsert_edge(&self, edge: Edge) -> Result<(), StorageError> {
293 let namespace = self.namespace.clone();
294 let id_str = Uuid::from(edge.id).to_string();
295 let src_str = edge.source_id.to_string();
296 let tgt_str = edge.target_id.to_string();
297 let relation_str = edge.relation.to_string();
298 let metadata_str = edge
299 .metadata
300 .as_ref()
301 .map(|v| serde_json::to_string(v).unwrap_or_default());
302 self.with_writer("upsert_edge", move |conn| {
303 conn.execute(
304 "INSERT INTO graph_edges \
305 (namespace, id, source_id, target_id, relation, weight, created_at, metadata) \
306 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) \
307 ON CONFLICT(namespace, id) DO UPDATE SET \
308 source_id = excluded.source_id, \
309 target_id = excluded.target_id, \
310 relation = excluded.relation, \
311 weight = excluded.weight, \
312 created_at = excluded.created_at, \
313 metadata = excluded.metadata \
314 ON CONFLICT(namespace, source_id, target_id, relation) DO NOTHING",
315 rusqlite::params![
316 namespace,
317 id_str,
318 src_str,
319 tgt_str,
320 relation_str,
321 edge.weight,
322 edge.created_at.timestamp_micros(),
323 metadata_str,
324 ],
325 )?;
326 Ok(())
327 })
328 .await
329 }
330
331 async fn upsert_edges(&self, edges: Vec<Edge>) -> Result<BatchWriteSummary, StorageError> {
332 let attempted = edges.len() as u64;
333 let namespace = self.namespace.clone();
334
335 self.with_writer("upsert_edges", move |conn| {
336 conn.execute_batch("BEGIN IMMEDIATE")?;
337 let mut affected = 0u64;
338 let mut failed = 0u64;
339 let mut first_error = String::new();
340
341 for edge in &edges {
342 let id_str = Uuid::from(edge.id).to_string();
343 let src_str = edge.source_id.to_string();
344 let tgt_str = edge.target_id.to_string();
345 let relation_str = edge.relation.to_string();
346 let metadata_str = edge
347 .metadata
348 .as_ref()
349 .map(|v| serde_json::to_string(v).unwrap_or_default());
350 match conn.execute(
351 "INSERT INTO graph_edges \
352 (namespace, id, source_id, target_id, relation, weight, created_at, metadata) \
353 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) \
354 ON CONFLICT(namespace, id) DO UPDATE SET \
355 source_id = excluded.source_id, \
356 target_id = excluded.target_id, \
357 relation = excluded.relation, \
358 weight = excluded.weight, \
359 created_at = excluded.created_at, \
360 metadata = excluded.metadata \
361 ON CONFLICT(namespace, source_id, target_id, relation) DO NOTHING",
362 rusqlite::params![
363 &namespace,
364 id_str,
365 src_str,
366 tgt_str,
367 relation_str,
368 edge.weight,
369 edge.created_at.timestamp_micros(),
370 metadata_str,
371 ],
372 ) {
373 Ok(_) => affected += 1,
374 Err(e) => {
375 if first_error.is_empty() {
376 first_error = e.to_string();
377 }
378 failed += 1;
379 }
380 }
381 }
382
383 if let Err(e) = conn.execute_batch("COMMIT") {
384 let _ = conn.execute_batch("ROLLBACK");
385 return Err(e);
386 }
387 Ok(BatchWriteSummary {
388 attempted,
389 affected,
390 failed,
391 first_error,
392 })
393 })
394 .await
395 }
396
397 async fn get_edge(&self, id: LinkId) -> Result<Option<Edge>, StorageError> {
398 let namespace = self.namespace.clone();
399 let id_str = Uuid::from(id).to_string();
400
401 self.with_reader("get_edge", move |conn| {
402 let mut stmt = conn.prepare(
403 "SELECT id, source_id, target_id, relation, weight, created_at, metadata \
404 FROM graph_edges WHERE namespace = ?1 AND id = ?2",
405 )?;
406 let mut rows = stmt.query(rusqlite::params![namespace, id_str])?;
407 match rows.next()? {
408 Some(row) => Ok(Some(read_edge(row)?)),
409 None => Ok(None),
410 }
411 })
412 .await
413 }
414
415 async fn delete_edge(&self, id: LinkId) -> Result<bool, StorageError> {
416 let namespace = self.namespace.clone();
417 let id_str = Uuid::from(id).to_string();
418
419 self.with_writer("delete_edge", move |conn| {
420 let deleted = conn.execute(
421 "DELETE FROM graph_edges WHERE namespace = ?1 AND id = ?2",
422 rusqlite::params![namespace, id_str],
423 )?;
424 Ok(deleted > 0)
425 })
426 .await
427 }
428
429 async fn query_edges(
430 &self,
431 filter: EdgeFilter,
432 sort: Vec<SortOrder<EdgeSortField>>,
433 page: PageRequest,
434 ) -> Result<Page<Edge>, StorageError> {
435 let namespace = self.namespace.clone();
436 self.with_reader("query_edges", move |conn| {
437 let (where_clause, filter_params) = build_edge_filter_sql(&namespace, &filter);
438
439 let count_sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
440 let total: i64 = {
441 let mut stmt = conn.prepare(&count_sql)?;
442 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
443 filter_params.iter().map(|p| p.as_ref()).collect();
444 stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
445 };
446
447 let order_clause = if sort.is_empty() {
448 " ORDER BY created_at DESC".to_string()
449 } else {
450 let parts: Vec<String> = sort
451 .iter()
452 .map(|s| {
453 let dir = match s.direction {
454 SortDirection::Asc => "ASC",
455 SortDirection::Desc => "DESC",
456 };
457 format!("{} {}", edge_sort_col(&s.field), dir)
458 })
459 .collect();
460 format!(" ORDER BY {}", parts.join(", "))
461 };
462
463 let (_, data_filter_params) = build_edge_filter_sql(&namespace, &filter);
464 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = data_filter_params;
465 all_params.push(Box::new(page.limit as i64));
466 all_params.push(Box::new(page.offset as i64));
467
468 let limit_idx = all_params.len() - 1;
469 let offset_idx = all_params.len();
470
471 let data_sql = format!(
472 "SELECT id, source_id, target_id, relation, weight, created_at, metadata \
473 FROM graph_edges{}{} LIMIT ?{} OFFSET ?{}",
474 where_clause, order_clause, limit_idx, offset_idx,
475 );
476
477 let mut stmt = conn.prepare(&data_sql)?;
478 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
479 all_params.iter().map(|p| p.as_ref()).collect();
480 let rows = stmt.query_map(param_refs.as_slice(), read_edge)?;
481
482 let mut items = Vec::new();
483 for row in rows {
484 items.push(row?);
485 }
486
487 Ok(Page {
488 items,
489 total: Some(total as u64),
490 })
491 })
492 .await
493 }
494
495 async fn count_edges(&self, filter: EdgeFilter) -> Result<u64, StorageError> {
496 let namespace = self.namespace.clone();
497 self.with_reader("count_edges", move |conn| {
498 let (where_clause, params) = build_edge_filter_sql(&namespace, &filter);
499 let sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
500 let mut stmt = conn.prepare(&sql)?;
501 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
502 params.iter().map(|p| p.as_ref()).collect();
503 let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
504 Ok(count as u64)
505 })
506 .await
507 }
508
509 async fn neighbors(
510 &self,
511 node_id: Uuid,
512 query: NeighborQuery,
513 ) -> Result<Vec<NeighborHit>, StorageError> {
514 use khive_storage::types::Direction;
515
516 let namespace = self.namespace.clone();
517 let node_str = node_id.to_string();
518
519 self.with_reader("neighbors", move |conn| {
520 let base_out = "SELECT target_id AS node_id, id AS edge_id, relation, weight \
521 FROM graph_edges WHERE namespace = ?1 AND source_id = ?2";
522 let base_in = "SELECT source_id AS node_id, id AS edge_id, relation, weight \
523 FROM graph_edges WHERE namespace = ?1 AND target_id = ?2";
524
525 let sql = match query.direction {
526 Direction::Out => base_out.to_string(),
527 Direction::In => base_in.to_string(),
528 Direction::Both => format!("{} UNION ALL {}", base_out, base_in),
529 };
530
531 let mut conditions: Vec<String> = Vec::new();
532 let mut extra_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
533 let mut param_idx = 3;
534
535 if let Some(ref rels) = query.relations {
536 if !rels.is_empty() {
537 let placeholders: Vec<String> = rels
538 .iter()
539 .map(|r| {
540 extra_params.push(Box::new(r.to_string()));
541 let p = format!("?{}", param_idx);
542 param_idx += 1;
543 p
544 })
545 .collect();
546 conditions.push(format!("relation IN ({})", placeholders.join(",")));
547 }
548 }
549
550 if let Some(min_w) = query.min_weight {
551 extra_params.push(Box::new(min_w));
552 conditions.push(format!("weight >= ?{}", param_idx));
553 param_idx += 1;
554 }
555
556 let where_extra = if conditions.is_empty() {
557 String::new()
558 } else {
559 format!(" WHERE {}", conditions.join(" AND "))
560 };
561
562 let limit_clause = if let Some(lim) = query.limit {
563 extra_params.push(Box::new(lim as i64));
564 format!(" LIMIT ?{}", param_idx)
565 } else {
566 String::new()
567 };
568
569 let full_sql = format!(
570 "SELECT node_id, edge_id, relation, weight FROM ({}){}{}",
571 sql, where_extra, limit_clause
572 );
573
574 let mut stmt = conn.prepare(&full_sql)?;
575
576 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
577 all_params.push(Box::new(namespace.clone()));
578 all_params.push(Box::new(node_str.clone()));
579 all_params.extend(extra_params);
580
581 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
582 all_params.iter().map(|p| p.as_ref()).collect();
583
584 let rows = stmt.query_map(param_refs.as_slice(), |row| {
585 let nid_str: String = row.get(0)?;
586 let eid_str: String = row.get(1)?;
587 let relation_str: String = row.get(2)?;
588 let weight: f64 = row.get(3)?;
589 Ok((nid_str, eid_str, relation_str, weight))
590 })?;
591
592 let mut hits = Vec::new();
593 for row in rows {
594 let (nid_str, eid_str, relation_str, weight) = row?;
595 let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
596 rusqlite::Error::FromSqlConversionFailure(
597 2,
598 rusqlite::types::Type::Text,
599 Box::new(e),
600 )
601 })?;
602 hits.push(NeighborHit {
603 node_id: parse_uuid(&nid_str)?,
604 edge_id: parse_uuid(&eid_str)?,
605 relation,
606 weight,
607 name: None,
608 kind: None,
609 });
610 }
611
612 Ok(hits)
613 })
614 .await
615 }
616
617 async fn traverse(&self, request: TraversalRequest) -> Result<Vec<GraphPath>, StorageError> {
618 use khive_storage::types::Direction;
619
620 if request.roots.is_empty() {
621 return Ok(Vec::new());
622 }
623
624 let roots = request.roots.clone();
625 let opts = request.options.clone();
626 let include_roots = request.include_roots;
627 let namespace = self.namespace.clone();
628
629 self.with_reader("traverse", move |conn| {
630 let mut all_paths: Vec<GraphPath> = Vec::new();
631
632 for root_id in &roots {
633 let root_str = root_id.to_string();
634
635 let (join_condition, next_node) = match opts.direction {
636 Direction::Out => ("e.source_id = t.node_id", "e.target_id"),
637 Direction::In => ("e.target_id = t.node_id", "e.source_id"),
638 Direction::Both => (
639 "(e.source_id = t.node_id OR e.target_id = t.node_id)",
640 "CASE WHEN e.source_id = t.node_id THEN e.target_id ELSE e.source_id END",
641 ),
642 };
643
644 let mut relation_cond = String::new();
645 let mut relation_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
646 let mut param_idx = 4;
647
648 if let Some(ref rels) = opts.relations {
649 if !rels.is_empty() {
650 let placeholders: Vec<String> = rels
651 .iter()
652 .map(|r| {
653 relation_params.push(Box::new(r.to_string()));
654 let p = format!("?{}", param_idx);
655 param_idx += 1;
656 p
657 })
658 .collect();
659 relation_cond =
660 format!(" AND e.relation IN ({})", placeholders.join(","));
661 }
662 }
663
664 let mut weight_cond = String::new();
665 if let Some(min_w) = opts.min_weight {
666 relation_params.push(Box::new(min_w));
667 weight_cond = format!(" AND e.weight >= ?{}", param_idx);
668 param_idx += 1;
669 }
670
671 let limit_clause = if let Some(lim) = opts.limit {
672 relation_params.push(Box::new(lim as i64));
673 format!(" LIMIT ?{}", param_idx)
674 } else {
675 String::new()
676 };
677
678 let cte_sql = format!(
679 "WITH RECURSIVE traversal(node_id, edge_id, depth, path, total_weight) AS (\
680 SELECT ?2, NULL, 0, ?2, 0.0 \
681 UNION ALL \
682 SELECT {next_node}, e.id, t.depth + 1, \
683 t.path || ',' || {next_node}, \
684 t.total_weight + e.weight \
685 FROM graph_edges e \
686 JOIN traversal t ON {join_condition} \
687 WHERE e.namespace = ?1 \
688 AND t.depth < ?3 \
689 AND (',' || t.path || ',') NOT LIKE '%,' || {next_node} || ',%'{rel_cond}{wt_cond} \
690 ) \
691 SELECT node_id, edge_id, depth, path, total_weight \
692 FROM traversal WHERE depth > 0 \
693 ORDER BY depth{limit}",
694 next_node = next_node,
695 join_condition = join_condition,
696 rel_cond = relation_cond,
697 wt_cond = weight_cond,
698 limit = limit_clause,
699 );
700
701 let mut stmt = conn.prepare(&cte_sql)?;
702
703 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
704 all_params.push(Box::new(namespace.clone()));
705 all_params.push(Box::new(root_str.clone()));
706 all_params.push(Box::new(opts.max_depth as i64));
707 all_params.extend(relation_params);
708
709 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
710 all_params.iter().map(|p| p.as_ref()).collect();
711
712 let rows = stmt.query_map(param_refs.as_slice(), |row| {
713 let node_str: String = row.get(0)?;
714 let edge_str: Option<String> = row.get(1)?;
715 let depth: i64 = row.get(2)?;
716 let _path: String = row.get(3)?;
717 let total_weight: f64 = row.get(4)?;
718 Ok((node_str, edge_str, depth, total_weight))
719 })?;
720
721 let mut nodes = Vec::new();
722 let mut max_weight = 0.0f64;
723
724 if include_roots {
725 nodes.push(PathNode {
726 node_id: *root_id,
727 via_edge: None,
728 depth: 0,
729 name: None,
730 kind: None,
731 });
732 }
733
734 for row in rows {
735 let (node_str, edge_str, depth, total_weight) = row?;
736 let node_id = parse_uuid(&node_str)?;
737 let via_edge = edge_str.map(|s| parse_uuid(&s)).transpose()?;
738 nodes.push(PathNode {
739 node_id,
740 via_edge,
741 depth: depth as usize,
742 name: None,
743 kind: None,
744 });
745 if total_weight > max_weight {
746 max_weight = total_weight;
747 }
748 }
749
750 if nodes.len() > if include_roots { 1 } else { 0 } || include_roots {
751 all_paths.push(GraphPath {
752 root_id: *root_id,
753 nodes,
754 total_weight: max_weight,
755 });
756 }
757 }
758
759 Ok(all_paths)
760 })
761 .await
762 }
763}
764
/// Schema for the graph edge store, executed via `execute_batch` by
/// `ensure_graph_schema`.
///
/// Layout:
/// - `graph_edges` is keyed by `(namespace, id)`. Ids are TEXT (UUID strings
///   written by this module), and `created_at` is an INTEGER written as epoch
///   microseconds.
/// - `idx_graph_edges_unique_triple` makes `(namespace, source_id, target_id,
///   relation)` unique. The edge upserts rely on it for their
///   `ON CONFLICT ... DO NOTHING` duplicate handling.
/// - The remaining indexes serve namespace-scoped lookups by source, target
///   and relation.
///
/// Every statement uses `IF NOT EXISTS`, so re-running the batch is a no-op.
/// The trailing `\` continuations join the whole literal into one line of SQL,
/// which is why no `--` SQL comments can be embedded in it.
///
/// NOTE(review): `idx_graph_edges_ns_source` and `idx_graph_edges_ns_target`
/// are left prefixes of `idx_graph_edges_ns_src_rel` and
/// `idx_graph_edges_ns_tgt_rel` (and `ns_source` is also a prefix of the unique
/// triple index). They look redundant and only add write cost; confirm before
/// dropping them. Removing them here would not drop them from existing
/// databases.
const GRAPH_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS graph_edges (\
        namespace TEXT NOT NULL,\
        id TEXT NOT NULL,\
        source_id TEXT NOT NULL,\
        target_id TEXT NOT NULL,\
        relation TEXT NOT NULL,\
        weight REAL NOT NULL DEFAULT 1.0,\
        created_at INTEGER NOT NULL,\
        metadata TEXT,\
        PRIMARY KEY (namespace, id)\
    );\
    CREATE UNIQUE INDEX IF NOT EXISTS idx_graph_edges_unique_triple ON graph_edges(namespace, source_id, target_id, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_source ON graph_edges(namespace, source_id);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_target ON graph_edges(namespace, target_id);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_relation ON graph_edges(namespace, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_src_rel ON graph_edges(namespace, source_id, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_tgt_rel ON graph_edges(namespace, target_id, relation);\
";
788
789pub(crate) fn ensure_graph_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
790 conn.execute_batch(GRAPH_DDL)
791}
792
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;
    use khive_storage::types::{Direction, TraversalOptions};

    // Builds an in-memory pool with the graph schema installed, so several
    // stores (for example different namespaces) can share one database.
    fn setup_memory_pool() -> Arc<ConnectionPool> {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());

        {
            let writer = pool.writer().unwrap();
            writer.conn().execute_batch(GRAPH_DDL).unwrap();
        }

        pool
    }

    fn setup_memory_store() -> SqlGraphStore {
        SqlGraphStore::new_scoped(setup_memory_pool(), false, "default")
    }

    fn make_edge(source: Uuid, target: Uuid, relation: EdgeRelation, weight: f64) -> Edge {
        Edge {
            id: Uuid::new_v4().into(),
            source_id: source,
            target_id: target,
            relation,
            weight,
            created_at: Utc::now(),
            metadata: None,
        }
    }

    // Unfiltered, unlimited neighbour query in `direction`.
    fn neighbor_query(direction: Direction) -> NeighborQuery {
        NeighborQuery {
            direction,
            relations: None,
            limit: None,
            min_weight: None,
        }
    }

    #[tokio::test]
    async fn test_upsert_and_get_edge() {
        let store = setup_memory_store();

        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let edge = Edge {
            id: Uuid::new_v4().into(),
            source_id: src,
            target_id: tgt,
            relation: EdgeRelation::Extends,
            weight: 0.8,
            created_at: Utc::now(),
            metadata: None,
        };
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();

        let fetched = store.get_edge(edge_id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, edge_id);
        assert_eq!(fetched.source_id, src);
        assert_eq!(fetched.target_id, tgt);
        assert_eq!(fetched.relation, EdgeRelation::Extends);
        assert!((fetched.weight - 0.8).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_upsert_same_id_updates_in_place() {
        let store = setup_memory_store();

        let raw_id = Uuid::new_v4();
        let link: LinkId = raw_id.into();
        let src = Uuid::new_v4();
        let new_tgt = Uuid::new_v4();

        let first = Edge {
            id: raw_id.into(),
            source_id: src,
            target_id: Uuid::new_v4(),
            relation: EdgeRelation::Extends,
            weight: 1.0,
            created_at: Utc::now(),
            metadata: None,
        };
        let second = Edge {
            id: raw_id.into(),
            source_id: src,
            target_id: new_tgt,
            relation: EdgeRelation::DependsOn,
            weight: 0.25,
            created_at: Utc::now(),
            metadata: None,
        };

        store.upsert_edge(first).await.unwrap();
        store.upsert_edge(second).await.unwrap();

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 1);
        let fetched = store
            .get_edge(link)
            .await
            .unwrap()
            .expect("re-upserting an existing id must keep the edge");
        assert_eq!(fetched.target_id, new_tgt);
        assert_eq!(fetched.relation, EdgeRelation::DependsOn);
        assert!((fetched.weight - 0.25).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_delete_edge() {
        let store = setup_memory_store();

        let edge = make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Contains, 1.0);
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();
        assert!(store.get_edge(edge_id).await.unwrap().is_some());

        let deleted = store.delete_edge(edge_id).await.unwrap();
        assert!(deleted);

        assert!(store.get_edge(edge_id).await.unwrap().is_none());

        let deleted_again = store.delete_edge(edge_id).await.unwrap();
        assert!(!deleted_again);
    }

    #[tokio::test]
    async fn test_count_edges() {
        let store = setup_memory_store();

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 0);

        for _ in 0..5 {
            store
                .upsert_edge(make_edge(
                    Uuid::new_v4(),
                    Uuid::new_v4(),
                    EdgeRelation::DependsOn,
                    1.0,
                ))
                .await
                .unwrap();
        }

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_namespaces_are_isolated() {
        let pool = setup_memory_pool();
        let alpha = SqlGraphStore::new_scoped(Arc::clone(&pool), false, "alpha");
        let beta = SqlGraphStore::new_scoped(pool, false, "beta");

        let edge = make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Contains, 1.0);
        let edge_id = edge.id;
        alpha.upsert_edge(edge).await.unwrap();

        assert_eq!(alpha.count_edges(EdgeFilter::default()).await.unwrap(), 1);
        assert_eq!(beta.count_edges(EdgeFilter::default()).await.unwrap(), 0);
        assert!(beta.get_edge(edge_id).await.unwrap().is_none());
        assert!(
            !beta.delete_edge(edge_id).await.unwrap(),
            "a namespace must not delete another namespace's edge"
        );
        assert!(alpha.get_edge(edge_id).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_neighbors_outbound() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();

        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(a, c, EdgeRelation::DependsOn, 0.7))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(d, a, EdgeRelation::Extends, 0.5))
            .await
            .unwrap();

        let query = NeighborQuery {
            direction: Direction::Out,
            relations: None,
            limit: None,
            min_weight: None,
        };

        let hits = store.neighbors(a, query).await.unwrap();
        assert_eq!(hits.len(), 2);

        let neighbor_ids: Vec<Uuid> = hits.iter().map(|h| h.node_id).collect();
        assert!(neighbor_ids.contains(&b));
        assert!(neighbor_ids.contains(&c));
        assert!(!neighbor_ids.contains(&d));
    }

    #[tokio::test]
    async fn test_neighbors_inbound_and_both() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();

        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(c, a, EdgeRelation::DependsOn, 0.5))
            .await
            .unwrap();

        let inbound = store
            .neighbors(a, neighbor_query(Direction::In))
            .await
            .unwrap();
        let inbound_ids: Vec<Uuid> = inbound.iter().map(|h| h.node_id).collect();
        assert_eq!(inbound_ids, vec![c]);

        let both = store
            .neighbors(a, neighbor_query(Direction::Both))
            .await
            .unwrap();
        let mut both_ids: Vec<Uuid> = both.iter().map(|h| h.node_id).collect();
        both_ids.sort();
        let mut expected = vec![b, c];
        expected.sort();
        assert_eq!(both_ids, expected);
    }

    #[tokio::test]
    async fn test_neighbors_min_weight_and_limit() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();

        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(a, c, EdgeRelation::Extends, 0.7))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(a, d, EdgeRelation::Extends, 0.2))
            .await
            .unwrap();

        let heavy = store
            .neighbors(
                a,
                NeighborQuery {
                    min_weight: Some(0.5),
                    ..neighbor_query(Direction::Out)
                },
            )
            .await
            .unwrap();
        let heavy_ids: Vec<Uuid> = heavy.iter().map(|h| h.node_id).collect();
        assert_eq!(heavy_ids.len(), 2);
        assert!(!heavy_ids.contains(&d));

        let limited = store
            .neighbors(
                a,
                NeighborQuery {
                    limit: Some(1),
                    ..neighbor_query(Direction::Out)
                },
            )
            .await
            .unwrap();
        assert_eq!(limited.len(), 1);
    }

    #[tokio::test]
    async fn test_traverse_depth_2() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();

        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(b, c, EdgeRelation::Extends, 2.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(c, d, EdgeRelation::Extends, 3.0))
            .await
            .unwrap();

        let request = TraversalRequest {
            roots: vec![a],
            options: TraversalOptions::new(2).with_direction(Direction::Out),
            include_roots: true,
        };

        let paths = store.traverse(request).await.unwrap();
        assert_eq!(paths.len(), 1);

        let path = &paths[0];
        let node_ids: Vec<Uuid> = path.nodes.iter().map(|n| n.node_id).collect();
        assert!(node_ids.contains(&a));
        assert!(node_ids.contains(&b));
        assert!(node_ids.contains(&c));
        assert!(!node_ids.contains(&d));
    }

    #[tokio::test]
    async fn test_traverse_root_handling() {
        let store = setup_memory_store();
        let lonely = Uuid::new_v4();

        let none = store
            .traverse(TraversalRequest {
                roots: vec![],
                options: TraversalOptions::new(2),
                include_roots: true,
            })
            .await
            .unwrap();
        assert!(none.is_empty(), "no roots must yield no paths");

        let without_root = store
            .traverse(TraversalRequest {
                roots: vec![lonely],
                options: TraversalOptions::new(2),
                include_roots: false,
            })
            .await
            .unwrap();
        assert!(
            without_root.is_empty(),
            "a root that reaches nothing yields no path unless roots are included"
        );

        let with_root = store
            .traverse(TraversalRequest {
                roots: vec![lonely],
                options: TraversalOptions::new(2),
                include_roots: true,
            })
            .await
            .unwrap();
        assert_eq!(with_root.len(), 1);
        assert_eq!(with_root[0].nodes.len(), 1);
        assert_eq!(with_root[0].nodes[0].node_id, lonely);
        assert_eq!(with_root[0].nodes[0].depth, 0);
    }

    #[tokio::test]
    async fn test_traverse_stops_at_cycles() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();

        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(b, a, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();

        let paths = store
            .traverse(TraversalRequest {
                roots: vec![a],
                options: TraversalOptions::new(3).with_direction(Direction::Out),
                include_roots: false,
            })
            .await
            .unwrap();
        assert_eq!(paths.len(), 1);
        let node_ids: Vec<Uuid> = paths[0].nodes.iter().map(|n| n.node_id).collect();
        assert_eq!(node_ids, vec![b], "the cycle back to the root must not be followed");
    }

    #[tokio::test]
    async fn test_metadata_roundtrip() {
        let store = setup_memory_store();

        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let meta = serde_json::json!({"note": "important link", "confidence": 0.95});
        let edge = Edge {
            id: Uuid::new_v4().into(),
            source_id: src,
            target_id: tgt,
            relation: EdgeRelation::Implements,
            weight: 0.9,
            created_at: Utc::now(),
            metadata: Some(meta.clone()),
        };
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();

        let fetched = store.get_edge(edge_id).await.unwrap().unwrap();
        assert_eq!(
            fetched.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via get_edge"
        );

        let page = store
            .query_edges(EdgeFilter::default(), vec![], PageRequest::default())
            .await
            .unwrap();
        let from_query = page
            .items
            .iter()
            .find(|e| e.id == edge_id)
            .expect("edge must appear in query_edges result");
        assert_eq!(
            from_query.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via query_edges"
        );
    }

    #[tokio::test]
    async fn test_upsert_edges_batch() {
        let store = setup_memory_store();

        let edges: Vec<Edge> = (0..10)
            .map(|i| {
                make_edge(
                    Uuid::new_v4(),
                    Uuid::new_v4(),
                    EdgeRelation::Implements,
                    i as f64,
                )
            })
            .collect();

        let summary = store.upsert_edges(edges).await.unwrap();
        assert_eq!(summary.attempted, 10);
        assert_eq!(summary.affected, 10);
        assert_eq!(summary.failed, 0);

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 10);
    }

    #[tokio::test]
    async fn graph_duplicate_edges_ignored() {
        let store = setup_memory_store();

        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();

        // Two distinct ids sharing one (source, target, relation) triple.
        let edge1 = Edge {
            id: Uuid::new_v4().into(),
            source_id: src,
            target_id: tgt,
            relation: EdgeRelation::Extends,
            weight: 1.0,
            created_at: Utc::now(),
            metadata: None,
        };
        let edge2 = Edge {
            id: Uuid::new_v4().into(),
            source_id: src,
            target_id: tgt,
            relation: EdgeRelation::Extends,
            weight: 0.5,
            created_at: Utc::now(),
            metadata: None,
        };

        store.upsert_edge(edge1).await.unwrap();
        store.upsert_edge(edge2).await.unwrap();

        assert_eq!(
            store.count_edges(EdgeFilter::default()).await.unwrap(),
            1,
            "duplicate (source, target, relation) triple must be ignored; only one edge must exist"
        );
    }
}
1077}