1use std::sync::Arc;
12
13use async_trait::async_trait;
14use chrono::{DateTime, TimeZone, Utc};
15use uuid::Uuid;
16
17use khive_storage::error::StorageError;
18use khive_storage::types::{
19 BatchWriteSummary, Edge, EdgeFilter, EdgeSortField, GraphPath, NeighborHit, NeighborQuery,
20 Page, PageRequest, PathNode, SortDirection, SortOrder, TraversalRequest,
21};
22use khive_storage::GraphStore;
23use khive_storage::LinkId;
24use khive_storage::StorageCapability;
25use khive_types::EdgeRelation;
26
27use crate::error::SqliteError;
28use crate::pool::ConnectionPool;
29
30fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
32 StorageError::driver(StorageCapability::Graph, op, e)
33}
34
35fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
36 StorageError::driver(StorageCapability::Graph, op, e)
37}
38
39pub struct SqlGraphStore {
41 pool: Arc<ConnectionPool>,
42 is_file_backed: bool,
43 namespace: String,
44}
45
46impl SqlGraphStore {
47 pub fn new_scoped(
49 pool: Arc<ConnectionPool>,
50 is_file_backed: bool,
51 namespace: impl Into<String>,
52 ) -> Self {
53 Self {
54 pool,
55 is_file_backed,
56 namespace: namespace.into(),
57 }
58 }
59
60 fn open_standalone_writer(&self) -> Result<rusqlite::Connection, StorageError> {
61 let config = self.pool.config();
62 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
63 operation: "graph_writer".into(),
64 message: "in-memory databases do not support standalone connections".into(),
65 })?;
66
67 let conn = rusqlite::Connection::open_with_flags(
68 path,
69 rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE
70 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
71 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
72 )
73 .map_err(|e| map_err(e, "open_graph_writer"))?;
74
75 conn.busy_timeout(config.busy_timeout)
76 .map_err(|e| map_err(e, "open_graph_writer"))?;
77 conn.pragma_update(None, "foreign_keys", "ON")
78 .map_err(|e| map_err(e, "open_graph_writer"))?;
79 conn.pragma_update(None, "synchronous", "NORMAL")
80 .map_err(|e| map_err(e, "open_graph_writer"))?;
81
82 Ok(conn)
83 }
84
85 fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
86 let config = self.pool.config();
87 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
88 operation: "graph_reader".into(),
89 message: "in-memory databases do not support standalone connections".into(),
90 })?;
91
92 let conn = rusqlite::Connection::open_with_flags(
93 path,
94 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
95 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
96 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
97 )
98 .map_err(|e| map_err(e, "open_graph_reader"))?;
99
100 conn.busy_timeout(config.busy_timeout)
101 .map_err(|e| map_err(e, "open_graph_reader"))?;
102 conn.pragma_update(None, "foreign_keys", "ON")
103 .map_err(|e| map_err(e, "open_graph_reader"))?;
104 conn.pragma_update(None, "synchronous", "NORMAL")
105 .map_err(|e| map_err(e, "open_graph_reader"))?;
106
107 Ok(conn)
108 }
109
110 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
111 where
112 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
113 R: Send + 'static,
114 {
115 if self.is_file_backed {
116 let conn = self.open_standalone_writer()?;
117 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
118 .await
119 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
120 } else {
121 let pool = Arc::clone(&self.pool);
122 tokio::task::spawn_blocking(move || {
123 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
124 f(guard.conn()).map_err(|e| map_err(e, op))
125 })
126 .await
127 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
128 }
129 }
130
131 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
132 where
133 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
134 R: Send + 'static,
135 {
136 if self.is_file_backed {
137 let conn = self.open_standalone_reader()?;
138 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
139 .await
140 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
141 } else {
142 let pool = Arc::clone(&self.pool);
143 tokio::task::spawn_blocking(move || {
144 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
145 f(guard.conn()).map_err(|e| map_err(e, op))
146 })
147 .await
148 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
149 }
150 }
151}
152
153fn read_edge(row: &rusqlite::Row<'_>) -> Result<Edge, rusqlite::Error> {
158 let id_str: String = row.get(0)?;
159 let source_str: String = row.get(1)?;
160 let target_str: String = row.get(2)?;
161 let relation_str: String = row.get(3)?;
162 let weight: f64 = row.get(4)?;
163 let created_micros: i64 = row.get(5)?;
164 let metadata_str: Option<String> = row.get(6)?;
165
166 let id = parse_uuid(&id_str)?;
167 let source_id = parse_uuid(&source_str)?;
168 let target_id = parse_uuid(&target_str)?;
169 let created_at = micros_to_datetime(created_micros);
170 let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
171 rusqlite::Error::FromSqlConversionFailure(3, rusqlite::types::Type::Text, Box::new(e))
172 })?;
173 let metadata = metadata_str.and_then(|s| serde_json::from_str(&s).ok());
174
175 Ok(Edge {
176 id: id.into(),
177 source_id,
178 target_id,
179 relation,
180 weight,
181 created_at,
182 metadata,
183 })
184}
185
186fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
187 Uuid::parse_str(s).map_err(|e| {
188 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
189 })
190}
191
192fn micros_to_datetime(micros: i64) -> DateTime<Utc> {
193 Utc.timestamp_micros(micros)
194 .single()
195 .unwrap_or_else(Utc::now)
196}
197
198fn build_edge_filter_sql(
199 namespace: &str,
200 filter: &EdgeFilter,
201) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
202 let mut conditions: Vec<String> = vec!["namespace = ?1".to_string()];
203 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
204
205 if !filter.ids.is_empty() {
206 let placeholders: Vec<String> = filter
207 .ids
208 .iter()
209 .map(|id| {
210 params.push(Box::new(id.to_string()));
211 format!("?{}", params.len())
212 })
213 .collect();
214 conditions.push(format!("id IN ({})", placeholders.join(",")));
215 }
216
217 if !filter.source_ids.is_empty() {
218 let placeholders: Vec<String> = filter
219 .source_ids
220 .iter()
221 .map(|id| {
222 params.push(Box::new(id.to_string()));
223 format!("?{}", params.len())
224 })
225 .collect();
226 conditions.push(format!("source_id IN ({})", placeholders.join(",")));
227 }
228
229 if !filter.target_ids.is_empty() {
230 let placeholders: Vec<String> = filter
231 .target_ids
232 .iter()
233 .map(|id| {
234 params.push(Box::new(id.to_string()));
235 format!("?{}", params.len())
236 })
237 .collect();
238 conditions.push(format!("target_id IN ({})", placeholders.join(",")));
239 }
240
241 if !filter.relations.is_empty() {
242 let placeholders: Vec<String> = filter
243 .relations
244 .iter()
245 .map(|r| {
246 params.push(Box::new(r.to_string()));
247 format!("?{}", params.len())
248 })
249 .collect();
250 conditions.push(format!("relation IN ({})", placeholders.join(",")));
251 }
252
253 if let Some(min_w) = filter.min_weight {
254 params.push(Box::new(min_w));
255 conditions.push(format!("weight >= ?{}", params.len()));
256 }
257
258 if let Some(max_w) = filter.max_weight {
259 params.push(Box::new(max_w));
260 conditions.push(format!("weight <= ?{}", params.len()));
261 }
262
263 if let Some(ref time_range) = filter.created_at {
264 if let Some(start) = time_range.start {
265 params.push(Box::new(start.timestamp_micros()));
266 conditions.push(format!("created_at >= ?{}", params.len()));
267 }
268 if let Some(end) = time_range.end {
269 params.push(Box::new(end.timestamp_micros()));
270 conditions.push(format!("created_at < ?{}", params.len()));
271 }
272 }
273
274 let clause = format!(" WHERE {}", conditions.join(" AND "));
275 (clause, params)
276}
277
278fn edge_sort_col(field: &EdgeSortField) -> &'static str {
279 match field {
280 EdgeSortField::CreatedAt => "created_at",
281 EdgeSortField::Weight => "weight",
282 EdgeSortField::Relation => "relation",
283 }
284}
285
286#[async_trait]
291impl GraphStore for SqlGraphStore {
292 async fn upsert_edge(&self, edge: Edge) -> Result<(), StorageError> {
293 let namespace = self.namespace.clone();
294 let id_str = Uuid::from(edge.id).to_string();
295 let src_str = edge.source_id.to_string();
296 let tgt_str = edge.target_id.to_string();
297 let relation_str = edge.relation.to_string();
298 let metadata_str = edge
299 .metadata
300 .as_ref()
301 .map(|v| serde_json::to_string(v).unwrap_or_default());
302 self.with_writer("upsert_edge", move |conn| {
303 conn.execute(
304 "INSERT OR REPLACE INTO graph_edges \
305 (namespace, id, source_id, target_id, relation, weight, created_at, metadata) \
306 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
307 rusqlite::params![
308 namespace,
309 id_str,
310 src_str,
311 tgt_str,
312 relation_str,
313 edge.weight,
314 edge.created_at.timestamp_micros(),
315 metadata_str,
316 ],
317 )?;
318 Ok(())
319 })
320 .await
321 }
322
323 async fn upsert_edges(&self, edges: Vec<Edge>) -> Result<BatchWriteSummary, StorageError> {
324 let attempted = edges.len() as u64;
325 let namespace = self.namespace.clone();
326
327 self.with_writer("upsert_edges", move |conn| {
328 conn.execute_batch("BEGIN IMMEDIATE")?;
329 let mut affected = 0u64;
330 let mut failed = 0u64;
331 let mut first_error = String::new();
332
333 for edge in &edges {
334 let id_str = Uuid::from(edge.id).to_string();
335 let src_str = edge.source_id.to_string();
336 let tgt_str = edge.target_id.to_string();
337 let relation_str = edge.relation.to_string();
338 let metadata_str = edge
339 .metadata
340 .as_ref()
341 .map(|v| serde_json::to_string(v).unwrap_or_default());
342 match conn.execute(
343 "INSERT OR REPLACE INTO graph_edges \
344 (namespace, id, source_id, target_id, relation, weight, created_at, metadata) \
345 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
346 rusqlite::params![
347 &namespace,
348 id_str,
349 src_str,
350 tgt_str,
351 relation_str,
352 edge.weight,
353 edge.created_at.timestamp_micros(),
354 metadata_str,
355 ],
356 ) {
357 Ok(_) => affected += 1,
358 Err(e) => {
359 if first_error.is_empty() {
360 first_error = e.to_string();
361 }
362 failed += 1;
363 }
364 }
365 }
366
367 if let Err(e) = conn.execute_batch("COMMIT") {
368 let _ = conn.execute_batch("ROLLBACK");
369 return Err(e);
370 }
371 Ok(BatchWriteSummary {
372 attempted,
373 affected,
374 failed,
375 first_error,
376 })
377 })
378 .await
379 }
380
381 async fn get_edge(&self, id: LinkId) -> Result<Option<Edge>, StorageError> {
382 let namespace = self.namespace.clone();
383 let id_str = Uuid::from(id).to_string();
384
385 self.with_reader("get_edge", move |conn| {
386 let mut stmt = conn.prepare(
387 "SELECT id, source_id, target_id, relation, weight, created_at, metadata \
388 FROM graph_edges WHERE namespace = ?1 AND id = ?2",
389 )?;
390 let mut rows = stmt.query(rusqlite::params![namespace, id_str])?;
391 match rows.next()? {
392 Some(row) => Ok(Some(read_edge(row)?)),
393 None => Ok(None),
394 }
395 })
396 .await
397 }
398
399 async fn delete_edge(&self, id: LinkId) -> Result<bool, StorageError> {
400 let namespace = self.namespace.clone();
401 let id_str = Uuid::from(id).to_string();
402
403 self.with_writer("delete_edge", move |conn| {
404 let deleted = conn.execute(
405 "DELETE FROM graph_edges WHERE namespace = ?1 AND id = ?2",
406 rusqlite::params![namespace, id_str],
407 )?;
408 Ok(deleted > 0)
409 })
410 .await
411 }
412
413 async fn query_edges(
414 &self,
415 filter: EdgeFilter,
416 sort: Vec<SortOrder<EdgeSortField>>,
417 page: PageRequest,
418 ) -> Result<Page<Edge>, StorageError> {
419 let namespace = self.namespace.clone();
420 self.with_reader("query_edges", move |conn| {
421 let (where_clause, filter_params) = build_edge_filter_sql(&namespace, &filter);
422
423 let count_sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
424 let total: i64 = {
425 let mut stmt = conn.prepare(&count_sql)?;
426 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
427 filter_params.iter().map(|p| p.as_ref()).collect();
428 stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
429 };
430
431 let order_clause = if sort.is_empty() {
432 " ORDER BY created_at DESC".to_string()
433 } else {
434 let parts: Vec<String> = sort
435 .iter()
436 .map(|s| {
437 let dir = match s.direction {
438 SortDirection::Asc => "ASC",
439 SortDirection::Desc => "DESC",
440 };
441 format!("{} {}", edge_sort_col(&s.field), dir)
442 })
443 .collect();
444 format!(" ORDER BY {}", parts.join(", "))
445 };
446
447 let (_, data_filter_params) = build_edge_filter_sql(&namespace, &filter);
448 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = data_filter_params;
449 all_params.push(Box::new(page.limit as i64));
450 all_params.push(Box::new(page.offset as i64));
451
452 let limit_idx = all_params.len() - 1;
453 let offset_idx = all_params.len();
454
455 let data_sql = format!(
456 "SELECT id, source_id, target_id, relation, weight, created_at, metadata \
457 FROM graph_edges{}{} LIMIT ?{} OFFSET ?{}",
458 where_clause, order_clause, limit_idx, offset_idx,
459 );
460
461 let mut stmt = conn.prepare(&data_sql)?;
462 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
463 all_params.iter().map(|p| p.as_ref()).collect();
464 let rows = stmt.query_map(param_refs.as_slice(), read_edge)?;
465
466 let mut items = Vec::new();
467 for row in rows {
468 items.push(row?);
469 }
470
471 Ok(Page {
472 items,
473 total: Some(total as u64),
474 })
475 })
476 .await
477 }
478
479 async fn count_edges(&self, filter: EdgeFilter) -> Result<u64, StorageError> {
480 let namespace = self.namespace.clone();
481 self.with_reader("count_edges", move |conn| {
482 let (where_clause, params) = build_edge_filter_sql(&namespace, &filter);
483 let sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
484 let mut stmt = conn.prepare(&sql)?;
485 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
486 params.iter().map(|p| p.as_ref()).collect();
487 let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
488 Ok(count as u64)
489 })
490 .await
491 }
492
493 async fn neighbors(
494 &self,
495 node_id: Uuid,
496 query: NeighborQuery,
497 ) -> Result<Vec<NeighborHit>, StorageError> {
498 use khive_storage::types::Direction;
499
500 let namespace = self.namespace.clone();
501 let node_str = node_id.to_string();
502
503 self.with_reader("neighbors", move |conn| {
504 let base_out = "SELECT target_id AS node_id, id AS edge_id, relation, weight \
505 FROM graph_edges WHERE namespace = ?1 AND source_id = ?2";
506 let base_in = "SELECT source_id AS node_id, id AS edge_id, relation, weight \
507 FROM graph_edges WHERE namespace = ?1 AND target_id = ?2";
508
509 let sql = match query.direction {
510 Direction::Out => base_out.to_string(),
511 Direction::In => base_in.to_string(),
512 Direction::Both => format!("{} UNION ALL {}", base_out, base_in),
513 };
514
515 let mut conditions: Vec<String> = Vec::new();
516 let mut extra_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
517 let mut param_idx = 3;
518
519 if let Some(ref rels) = query.relations {
520 if !rels.is_empty() {
521 let placeholders: Vec<String> = rels
522 .iter()
523 .map(|r| {
524 extra_params.push(Box::new(r.to_string()));
525 let p = format!("?{}", param_idx);
526 param_idx += 1;
527 p
528 })
529 .collect();
530 conditions.push(format!("relation IN ({})", placeholders.join(",")));
531 }
532 }
533
534 if let Some(min_w) = query.min_weight {
535 extra_params.push(Box::new(min_w));
536 conditions.push(format!("weight >= ?{}", param_idx));
537 param_idx += 1;
538 }
539
540 let where_extra = if conditions.is_empty() {
541 String::new()
542 } else {
543 format!(" WHERE {}", conditions.join(" AND "))
544 };
545
546 let limit_clause = if let Some(lim) = query.limit {
547 extra_params.push(Box::new(lim as i64));
548 format!(" LIMIT ?{}", param_idx)
549 } else {
550 String::new()
551 };
552
553 let full_sql = format!(
554 "SELECT node_id, edge_id, relation, weight FROM ({}){}{}",
555 sql, where_extra, limit_clause
556 );
557
558 let mut stmt = conn.prepare(&full_sql)?;
559
560 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
561 all_params.push(Box::new(namespace.clone()));
562 all_params.push(Box::new(node_str.clone()));
563 all_params.extend(extra_params);
564
565 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
566 all_params.iter().map(|p| p.as_ref()).collect();
567
568 let rows = stmt.query_map(param_refs.as_slice(), |row| {
569 let nid_str: String = row.get(0)?;
570 let eid_str: String = row.get(1)?;
571 let relation_str: String = row.get(2)?;
572 let weight: f64 = row.get(3)?;
573 Ok((nid_str, eid_str, relation_str, weight))
574 })?;
575
576 let mut hits = Vec::new();
577 for row in rows {
578 let (nid_str, eid_str, relation_str, weight) = row?;
579 let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
580 rusqlite::Error::FromSqlConversionFailure(
581 2,
582 rusqlite::types::Type::Text,
583 Box::new(e),
584 )
585 })?;
586 hits.push(NeighborHit {
587 node_id: parse_uuid(&nid_str)?,
588 edge_id: parse_uuid(&eid_str)?,
589 relation,
590 weight,
591 });
592 }
593
594 Ok(hits)
595 })
596 .await
597 }
598
599 async fn traverse(&self, request: TraversalRequest) -> Result<Vec<GraphPath>, StorageError> {
600 use khive_storage::types::Direction;
601
602 if request.roots.is_empty() {
603 return Ok(Vec::new());
604 }
605
606 let roots = request.roots.clone();
607 let opts = request.options.clone();
608 let include_roots = request.include_roots;
609 let namespace = self.namespace.clone();
610
611 self.with_reader("traverse", move |conn| {
612 let mut all_paths: Vec<GraphPath> = Vec::new();
613
614 for root_id in &roots {
615 let root_str = root_id.to_string();
616
617 let (join_condition, next_node) = match opts.direction {
618 Direction::Out => ("e.source_id = t.node_id", "e.target_id"),
619 Direction::In => ("e.target_id = t.node_id", "e.source_id"),
620 Direction::Both => (
621 "(e.source_id = t.node_id OR e.target_id = t.node_id)",
622 "CASE WHEN e.source_id = t.node_id THEN e.target_id ELSE e.source_id END",
623 ),
624 };
625
626 let mut relation_cond = String::new();
627 let mut relation_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
628 let mut param_idx = 4;
629
630 if let Some(ref rels) = opts.relations {
631 if !rels.is_empty() {
632 let placeholders: Vec<String> = rels
633 .iter()
634 .map(|r| {
635 relation_params.push(Box::new(r.to_string()));
636 let p = format!("?{}", param_idx);
637 param_idx += 1;
638 p
639 })
640 .collect();
641 relation_cond =
642 format!(" AND e.relation IN ({})", placeholders.join(","));
643 }
644 }
645
646 let mut weight_cond = String::new();
647 if let Some(min_w) = opts.min_weight {
648 relation_params.push(Box::new(min_w));
649 weight_cond = format!(" AND e.weight >= ?{}", param_idx);
650 param_idx += 1;
651 }
652
653 let limit_clause = if let Some(lim) = opts.limit {
654 relation_params.push(Box::new(lim as i64));
655 format!(" LIMIT ?{}", param_idx)
656 } else {
657 String::new()
658 };
659
660 let cte_sql = format!(
661 "WITH RECURSIVE traversal(node_id, edge_id, depth, path, total_weight) AS (\
662 SELECT ?2, NULL, 0, ?2, 0.0 \
663 UNION ALL \
664 SELECT {next_node}, e.id, t.depth + 1, \
665 t.path || ',' || {next_node}, \
666 t.total_weight + e.weight \
667 FROM graph_edges e \
668 JOIN traversal t ON {join_condition} \
669 WHERE e.namespace = ?1 \
670 AND t.depth < ?3 \
671 AND (',' || t.path || ',') NOT LIKE '%,' || {next_node} || ',%'{rel_cond}{wt_cond} \
672 ) \
673 SELECT node_id, edge_id, depth, path, total_weight \
674 FROM traversal WHERE depth > 0 \
675 ORDER BY depth{limit}",
676 next_node = next_node,
677 join_condition = join_condition,
678 rel_cond = relation_cond,
679 wt_cond = weight_cond,
680 limit = limit_clause,
681 );
682
683 let mut stmt = conn.prepare(&cte_sql)?;
684
685 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
686 all_params.push(Box::new(namespace.clone()));
687 all_params.push(Box::new(root_str.clone()));
688 all_params.push(Box::new(opts.max_depth as i64));
689 all_params.extend(relation_params);
690
691 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
692 all_params.iter().map(|p| p.as_ref()).collect();
693
694 let rows = stmt.query_map(param_refs.as_slice(), |row| {
695 let node_str: String = row.get(0)?;
696 let edge_str: Option<String> = row.get(1)?;
697 let depth: i64 = row.get(2)?;
698 let _path: String = row.get(3)?;
699 let total_weight: f64 = row.get(4)?;
700 Ok((node_str, edge_str, depth, total_weight))
701 })?;
702
703 let mut nodes = Vec::new();
704 let mut max_weight = 0.0f64;
705
706 if include_roots {
707 nodes.push(PathNode {
708 node_id: *root_id,
709 via_edge: None,
710 depth: 0,
711 });
712 }
713
714 for row in rows {
715 let (node_str, edge_str, depth, total_weight) = row?;
716 let node_id = parse_uuid(&node_str)?;
717 let via_edge = edge_str.map(|s| parse_uuid(&s)).transpose()?;
718 nodes.push(PathNode {
719 node_id,
720 via_edge,
721 depth: depth as usize,
722 });
723 if total_weight > max_weight {
724 max_weight = total_weight;
725 }
726 }
727
728 if nodes.len() > if include_roots { 1 } else { 0 } || include_roots {
729 all_paths.push(GraphPath {
730 root_id: *root_id,
731 nodes,
732 total_weight: max_weight,
733 });
734 }
735 }
736
737 Ok(all_paths)
738 })
739 .await
740 }
741}
742
/// Schema for the `graph_edges` table and its lookup indexes, executed by
/// `ensure_graph_schema` (and directly by the in-memory test setup).
///
/// Edges are keyed by `(namespace, id)`. `created_at` holds microseconds since the Unix epoch
/// (written with `timestamp_micros()`), and `metadata` holds JSON text or NULL. Every
/// statement uses `IF NOT EXISTS`, so the batch can be executed repeatedly.
///
/// The trailing `\` on each line is a Rust string continuation: the newline and the next
/// line's leading whitespace are not part of the SQL text.
///
/// NOTE(review): `idx_graph_edges_ns_source` / `idx_graph_edges_ns_target` are prefixes of
/// the `_ns_src_rel` / `_ns_tgt_rel` composite indexes and look redundant; confirm before
/// dropping them.
const GRAPH_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS graph_edges (\
        namespace TEXT NOT NULL,\
        id TEXT NOT NULL,\
        source_id TEXT NOT NULL,\
        target_id TEXT NOT NULL,\
        relation TEXT NOT NULL,\
        weight REAL NOT NULL DEFAULT 1.0,\
        created_at INTEGER NOT NULL,\
        metadata TEXT,\
        PRIMARY KEY (namespace, id)\
    );\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_source ON graph_edges(namespace, source_id);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_target ON graph_edges(namespace, target_id);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_relation ON graph_edges(namespace, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_src_rel ON graph_edges(namespace, source_id, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_tgt_rel ON graph_edges(namespace, target_id, relation);\
";
765
766pub(crate) fn ensure_graph_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
767 conn.execute_batch(GRAPH_DDL)
768}
769
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;
    use khive_storage::types::{Direction, TraversalOptions};

    /// Builds a store over a fresh in-memory pool (namespace `"default"`) with the graph
    /// schema already applied.
    fn setup_memory_store() -> SqlGraphStore {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());

        {
            let writer = pool.writer().unwrap();
            writer.conn().execute_batch(GRAPH_DDL).unwrap();
        }

        SqlGraphStore::new_scoped(pool, false, "default")
    }

    /// Builds an edge with a random id, `created_at = now` and no metadata.
    fn make_edge(source: Uuid, target: Uuid, relation: EdgeRelation, weight: f64) -> Edge {
        Edge {
            id: Uuid::new_v4().into(),
            source_id: source,
            target_id: target,
            relation,
            weight,
            created_at: Utc::now(),
            metadata: None,
        }
    }

    /// Builds a neighbor query with no relation/weight/limit filters.
    fn unfiltered(direction: Direction) -> NeighborQuery {
        NeighborQuery {
            direction,
            relations: None,
            limit: None,
            min_weight: None,
        }
    }

    #[tokio::test]
    async fn test_upsert_and_get_edge() {
        let store = setup_memory_store();

        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let edge = Edge {
            id: Uuid::new_v4().into(),
            source_id: src,
            target_id: tgt,
            relation: EdgeRelation::Extends,
            weight: 0.8,
            created_at: Utc::now(),
            metadata: None,
        };
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();

        let fetched = store.get_edge(edge_id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, edge_id);
        assert_eq!(fetched.source_id, src);
        assert_eq!(fetched.target_id, tgt);
        assert_eq!(fetched.relation, EdgeRelation::Extends);
        assert!((fetched.weight - 0.8).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_delete_edge() {
        let store = setup_memory_store();

        let edge = make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Contains, 1.0);
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();
        assert!(store.get_edge(edge_id).await.unwrap().is_some());

        let deleted = store.delete_edge(edge_id).await.unwrap();
        assert!(deleted);

        assert!(store.get_edge(edge_id).await.unwrap().is_none());

        // Deleting a missing edge reports `false` rather than failing.
        let deleted_again = store.delete_edge(edge_id).await.unwrap();
        assert!(!deleted_again);
    }

    #[tokio::test]
    async fn test_count_edges() {
        let store = setup_memory_store();

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 0);

        for _ in 0..5 {
            store
                .upsert_edge(make_edge(
                    Uuid::new_v4(),
                    Uuid::new_v4(),
                    EdgeRelation::DependsOn,
                    1.0,
                ))
                .await
                .unwrap();
        }

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_neighbors_outbound() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();

        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(a, c, EdgeRelation::DependsOn, 0.7))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(d, a, EdgeRelation::Extends, 0.5))
            .await
            .unwrap();

        let query = NeighborQuery {
            direction: Direction::Out,
            relations: None,
            limit: None,
            min_weight: None,
        };

        let hits = store.neighbors(a, query).await.unwrap();
        assert_eq!(hits.len(), 2);

        let neighbor_ids: Vec<Uuid> = hits.iter().map(|h| h.node_id).collect();
        assert!(neighbor_ids.contains(&b));
        assert!(neighbor_ids.contains(&c));
        assert!(!neighbor_ids.contains(&d));
    }

    /// Inbound lookups return edge sources; `Both` returns outbound and inbound neighbors.
    #[tokio::test]
    async fn test_neighbors_inbound_and_both() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();

        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(c, a, EdgeRelation::DependsOn, 0.5))
            .await
            .unwrap();

        let inbound = store.neighbors(a, unfiltered(Direction::In)).await.unwrap();
        let inbound_ids: Vec<Uuid> = inbound.iter().map(|h| h.node_id).collect();
        assert_eq!(inbound_ids, vec![c]);

        let both = store.neighbors(a, unfiltered(Direction::Both)).await.unwrap();
        let mut both_ids: Vec<Uuid> = both.iter().map(|h| h.node_id).collect();
        both_ids.sort();
        let mut expected = vec![b, c];
        expected.sort();
        assert_eq!(both_ids, expected);
    }

    /// `min_weight` drops lighter edges and `limit` caps the number of hits.
    #[tokio::test]
    async fn test_neighbors_min_weight_and_limit() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();

        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(a, c, EdgeRelation::Extends, 0.5))
            .await
            .unwrap();

        let mut heavy = unfiltered(Direction::Out);
        heavy.min_weight = Some(0.8);
        let hits = store.neighbors(a, heavy).await.unwrap();
        let ids: Vec<Uuid> = hits.iter().map(|h| h.node_id).collect();
        assert_eq!(ids, vec![b]);

        let mut capped = unfiltered(Direction::Out);
        capped.limit = Some(1);
        assert_eq!(store.neighbors(a, capped).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_traverse_depth_2() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();

        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(b, c, EdgeRelation::Extends, 2.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(c, d, EdgeRelation::Extends, 3.0))
            .await
            .unwrap();

        let request = TraversalRequest {
            roots: vec![a],
            options: TraversalOptions::new(2).with_direction(Direction::Out),
            include_roots: true,
        };

        let paths = store.traverse(request).await.unwrap();
        assert_eq!(paths.len(), 1);

        let path = &paths[0];
        let node_ids: Vec<Uuid> = path.nodes.iter().map(|n| n.node_id).collect();
        assert!(node_ids.contains(&a));
        assert!(node_ids.contains(&b));
        assert!(node_ids.contains(&c));
        // `d` is three hops away, beyond the requested depth of 2.
        assert!(!node_ids.contains(&d));
    }

    /// Without `include_roots`, only reached nodes are listed (in depth order), and a root
    /// with no outgoing edges produces no path at all.
    #[tokio::test]
    async fn test_traverse_without_roots() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let isolated = Uuid::new_v4();

        store
            .upsert_edge(make_edge(a, b, EdgeRelation::Extends, 1.0))
            .await
            .unwrap();
        store
            .upsert_edge(make_edge(b, c, EdgeRelation::Extends, 2.0))
            .await
            .unwrap();

        let request = TraversalRequest {
            roots: vec![a, isolated],
            options: TraversalOptions::new(2).with_direction(Direction::Out),
            include_roots: false,
        };

        let paths = store.traverse(request).await.unwrap();
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].root_id, a);

        let node_ids: Vec<Uuid> = paths[0].nodes.iter().map(|n| n.node_id).collect();
        assert_eq!(node_ids, vec![b, c]);
        let depths: Vec<usize> = paths[0].nodes.iter().map(|n| n.depth).collect();
        assert_eq!(depths, vec![1, 2]);
        assert!(paths[0].nodes.iter().all(|n| n.via_edge.is_some()));
    }

    /// With `include_roots`, a root that reaches nothing still yields a path of just itself.
    #[tokio::test]
    async fn test_traverse_isolated_root_included() {
        let store = setup_memory_store();
        let isolated = Uuid::new_v4();

        let request = TraversalRequest {
            roots: vec![isolated],
            options: TraversalOptions::new(3).with_direction(Direction::Both),
            include_roots: true,
        };

        let paths = store.traverse(request).await.unwrap();
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].nodes.len(), 1);
        assert_eq!(paths[0].nodes[0].node_id, isolated);
        assert_eq!(paths[0].nodes[0].depth, 0);
        assert!(paths[0].nodes[0].via_edge.is_none());
    }

    #[tokio::test]
    async fn test_metadata_roundtrip() {
        let store = setup_memory_store();

        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let meta = serde_json::json!({"note": "important link", "confidence": 0.95});
        let edge = Edge {
            id: Uuid::new_v4().into(),
            source_id: src,
            target_id: tgt,
            relation: EdgeRelation::Implements,
            weight: 0.9,
            created_at: Utc::now(),
            metadata: Some(meta.clone()),
        };
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();

        let fetched = store.get_edge(edge_id).await.unwrap().unwrap();
        assert_eq!(
            fetched.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via get_edge"
        );

        let page = store
            .query_edges(EdgeFilter::default(), vec![], PageRequest::default())
            .await
            .unwrap();
        let from_query = page
            .items
            .iter()
            .find(|e| e.id == edge_id)
            .expect("edge must appear in query_edges result");
        assert_eq!(
            from_query.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via query_edges"
        );
    }

    /// Paging through `query_edges` with the default order (newest first) visits every
    /// edge exactly once and reports the full total on each page.
    #[tokio::test]
    async fn test_query_edges_pagination() {
        let store = setup_memory_store();

        let mut ids = Vec::new();
        for i in 0..5i64 {
            let mut edge = make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Contains, 1.0);
            // Distinct, increasing timestamps make the expected order unambiguous.
            edge.created_at = Utc.timestamp_opt(1_700_000_000 + i, 0).single().unwrap();
            ids.push(edge.id);
            store.upsert_edge(edge).await.unwrap();
        }
        ids.reverse();

        let mut seen = Vec::new();
        for offset in [0, 2, 4] {
            let mut page = PageRequest::default();
            page.limit = 2;
            page.offset = offset;
            let result = store
                .query_edges(EdgeFilter::default(), vec![], page)
                .await
                .unwrap();
            assert_eq!(result.total, Some(5));
            seen.extend(result.items.iter().map(|e| e.id));
        }
        assert_eq!(seen, ids);
    }

    #[tokio::test]
    async fn test_upsert_edges_batch() {
        let store = setup_memory_store();

        let edges: Vec<Edge> = (0..10)
            .map(|i| {
                make_edge(
                    Uuid::new_v4(),
                    Uuid::new_v4(),
                    EdgeRelation::Implements,
                    i as f64,
                )
            })
            .collect();

        let summary = store.upsert_edges(edges).await.unwrap();
        assert_eq!(summary.attempted, 10);
        assert_eq!(summary.affected, 10);
        assert_eq!(summary.failed, 0);

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 10);
    }

    /// An empty batch succeeds with all-zero counts and writes nothing.
    #[tokio::test]
    async fn test_upsert_edges_empty_batch() {
        let store = setup_memory_store();

        let summary = store.upsert_edges(Vec::new()).await.unwrap();
        assert_eq!(summary.attempted, 0);
        assert_eq!(summary.affected, 0);
        assert_eq!(summary.failed, 0);
        assert!(summary.first_error.is_empty());

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 0);
    }
}