1use std::sync::Arc;
4
5use async_trait::async_trait;
6use chrono::{DateTime, TimeZone, Utc};
7use uuid::Uuid;
8
9use khive_storage::error::StorageError;
10use khive_storage::types::{
11 BatchWriteSummary, DeleteMode, Edge, EdgeFilter, EdgeSortField, GraphPath, NeighborHit,
12 NeighborQuery, Page, PageRequest, PathNode, SortDirection, SortOrder, TraversalRequest,
13};
14use khive_storage::GraphStore;
15use khive_storage::LinkId;
16use khive_storage::StorageCapability;
17use khive_types::EdgeRelation;
18
19use crate::error::SqliteError;
20use crate::pool::ConnectionPool;
21
22fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
24 StorageError::driver(StorageCapability::Graph, op, e)
25}
26
27fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
28 StorageError::driver(StorageCapability::Graph, op, e)
29}
30
31pub struct SqlGraphStore {
33 pool: Arc<ConnectionPool>,
34 is_file_backed: bool,
35 namespace: String,
36}
37
38impl SqlGraphStore {
39 pub fn new_scoped(
41 pool: Arc<ConnectionPool>,
42 is_file_backed: bool,
43 namespace: impl Into<String>,
44 ) -> Self {
45 Self {
46 pool,
47 is_file_backed,
48 namespace: namespace.into(),
49 }
50 }
51
52 fn open_standalone_writer(&self) -> Result<rusqlite::Connection, StorageError> {
53 let config = self.pool.config();
54 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
55 operation: "graph_writer".into(),
56 message: "in-memory databases do not support standalone connections".into(),
57 })?;
58
59 let conn = rusqlite::Connection::open_with_flags(
60 path,
61 rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE
62 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
63 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
64 )
65 .map_err(|e| map_err(e, "open_graph_writer"))?;
66
67 conn.busy_timeout(config.busy_timeout)
68 .map_err(|e| map_err(e, "open_graph_writer"))?;
69 conn.pragma_update(None, "foreign_keys", "ON")
70 .map_err(|e| map_err(e, "open_graph_writer"))?;
71 conn.pragma_update(None, "synchronous", "NORMAL")
72 .map_err(|e| map_err(e, "open_graph_writer"))?;
73
74 Ok(conn)
75 }
76
77 fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
78 let config = self.pool.config();
79 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
80 operation: "graph_reader".into(),
81 message: "in-memory databases do not support standalone connections".into(),
82 })?;
83
84 let conn = rusqlite::Connection::open_with_flags(
85 path,
86 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
87 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
88 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
89 )
90 .map_err(|e| map_err(e, "open_graph_reader"))?;
91
92 conn.busy_timeout(config.busy_timeout)
93 .map_err(|e| map_err(e, "open_graph_reader"))?;
94 conn.pragma_update(None, "foreign_keys", "ON")
95 .map_err(|e| map_err(e, "open_graph_reader"))?;
96 conn.pragma_update(None, "synchronous", "NORMAL")
97 .map_err(|e| map_err(e, "open_graph_reader"))?;
98
99 Ok(conn)
100 }
101
102 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
103 where
104 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
105 R: Send + 'static,
106 {
107 if self.is_file_backed {
108 let conn = self.open_standalone_writer()?;
109 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
110 .await
111 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
112 } else {
113 let pool = Arc::clone(&self.pool);
114 tokio::task::spawn_blocking(move || {
115 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
116 f(guard.conn()).map_err(|e| map_err(e, op))
117 })
118 .await
119 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
120 }
121 }
122
123 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
124 where
125 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
126 R: Send + 'static,
127 {
128 if self.is_file_backed {
129 let conn = self.open_standalone_reader()?;
130 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
131 .await
132 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
133 } else {
134 let pool = Arc::clone(&self.pool);
135 tokio::task::spawn_blocking(move || {
136 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
137 f(guard.conn()).map_err(|e| map_err(e, op))
138 })
139 .await
140 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
141 }
142 }
143}
144
145fn read_edge(row: &rusqlite::Row<'_>) -> Result<Edge, rusqlite::Error> {
150 let namespace: String = row.get(0)?;
151 let id_str: String = row.get(1)?;
152 let source_str: String = row.get(2)?;
153 let target_str: String = row.get(3)?;
154 let relation_str: String = row.get(4)?;
155 let weight: f64 = row.get(5)?;
156 let created_micros: i64 = row.get(6)?;
157 let updated_micros: i64 = row.get(7)?;
158 let deleted_micros: Option<i64> = row.get(8)?;
159 let metadata_str: Option<String> = row.get(9)?;
160 let target_backend: Option<String> = row.get(10)?;
161
162 let id = parse_uuid(&id_str)?;
163 let source_id = parse_uuid(&source_str)?;
164 let target_id = parse_uuid(&target_str)?;
165 let created_at = micros_to_datetime(created_micros);
166 let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
167 rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(e))
168 })?;
169 let metadata = match metadata_str {
170 Some(s) => {
171 let v = serde_json::from_str(&s).map_err(|e| {
172 rusqlite::Error::FromSqlConversionFailure(
173 9,
174 rusqlite::types::Type::Text,
175 Box::new(e),
176 )
177 })?;
178 Some(v)
179 }
180 None => None,
181 };
182
183 Ok(Edge {
184 id: id.into(),
185 namespace,
186 source_id,
187 target_id,
188 relation,
189 weight,
190 created_at,
191 updated_at: micros_to_datetime(updated_micros),
192 deleted_at: deleted_micros.map(micros_to_datetime),
193 metadata,
194 target_backend,
195 })
196}
197
198fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
199 Uuid::parse_str(s).map_err(|e| {
200 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
201 })
202}
203
204fn micros_to_datetime(micros: i64) -> DateTime<Utc> {
205 Utc.timestamp_micros(micros)
206 .single()
207 .unwrap_or_else(Utc::now)
208}
209
210fn build_edge_filter_sql(
211 namespace: &str,
212 filter: &EdgeFilter,
213) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
214 let mut conditions: Vec<String> = vec![
215 "namespace = ?1".to_string(),
216 "deleted_at IS NULL".to_string(),
217 ];
218 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
219
220 if !filter.ids.is_empty() {
221 let placeholders: Vec<String> = filter
222 .ids
223 .iter()
224 .map(|id| {
225 params.push(Box::new(id.to_string()));
226 format!("?{}", params.len())
227 })
228 .collect();
229 conditions.push(format!("id IN ({})", placeholders.join(",")));
230 }
231
232 if !filter.source_ids.is_empty() {
233 let placeholders: Vec<String> = filter
234 .source_ids
235 .iter()
236 .map(|id| {
237 params.push(Box::new(id.to_string()));
238 format!("?{}", params.len())
239 })
240 .collect();
241 conditions.push(format!("source_id IN ({})", placeholders.join(",")));
242 }
243
244 if !filter.target_ids.is_empty() {
245 let placeholders: Vec<String> = filter
246 .target_ids
247 .iter()
248 .map(|id| {
249 params.push(Box::new(id.to_string()));
250 format!("?{}", params.len())
251 })
252 .collect();
253 conditions.push(format!("target_id IN ({})", placeholders.join(",")));
254 }
255
256 if !filter.relations.is_empty() {
257 let placeholders: Vec<String> = filter
258 .relations
259 .iter()
260 .map(|r| {
261 params.push(Box::new(r.to_string()));
262 format!("?{}", params.len())
263 })
264 .collect();
265 conditions.push(format!("relation IN ({})", placeholders.join(",")));
266 }
267
268 if let Some(min_w) = filter.min_weight {
269 params.push(Box::new(min_w));
270 conditions.push(format!("weight >= ?{}", params.len()));
271 }
272
273 if let Some(max_w) = filter.max_weight {
274 params.push(Box::new(max_w));
275 conditions.push(format!("weight <= ?{}", params.len()));
276 }
277
278 if let Some(ref time_range) = filter.created_at {
279 if let Some(start) = time_range.start {
280 params.push(Box::new(start.timestamp_micros()));
281 conditions.push(format!("created_at >= ?{}", params.len()));
282 }
283 if let Some(end) = time_range.end {
284 params.push(Box::new(end.timestamp_micros()));
285 conditions.push(format!("created_at < ?{}", params.len()));
286 }
287 }
288
289 let clause = format!(" WHERE {}", conditions.join(" AND "));
290 (clause, params)
291}
292
293fn edge_sort_col(field: &EdgeSortField) -> &'static str {
294 match field {
295 EdgeSortField::CreatedAt => "created_at",
296 EdgeSortField::Weight => "weight",
297 EdgeSortField::Relation => "relation",
298 }
299}
300
301fn canonical_edge_endpoints(
310 relation: EdgeRelation,
311 source_id: Uuid,
312 target_id: Uuid,
313) -> (Uuid, Uuid) {
314 if relation.is_symmetric() && target_id < source_id {
315 (target_id, source_id)
316 } else {
317 (source_id, target_id)
318 }
319}
320
321#[async_trait]
322impl GraphStore for SqlGraphStore {
323 async fn upsert_edge(&self, edge: Edge) -> Result<(), StorageError> {
324 let namespace = self.namespace.clone();
325 if edge.namespace != namespace {
326 return Err(StorageError::InvalidInput {
327 capability: StorageCapability::Graph,
328 operation: "upsert_edge".into(),
329 message: format!(
330 "edge namespace {:?} does not match store namespace {:?}",
331 edge.namespace, namespace
332 ),
333 });
334 }
335 let id_str = Uuid::from(edge.id).to_string();
336 let (source_id, target_id) =
337 canonical_edge_endpoints(edge.relation, edge.source_id, edge.target_id);
338 let src_str = source_id.to_string();
339 let tgt_str = target_id.to_string();
340 let relation_str = edge.relation.to_string();
341 let metadata_str = edge
342 .metadata
343 .as_ref()
344 .map(serde_json::to_string)
345 .transpose()
346 .map_err(|e| StorageError::driver(StorageCapability::Graph, "upsert_edge", e))?;
347 self.with_writer("upsert_edge", move |conn| {
348 conn.execute(
349 "INSERT INTO graph_edges \
350 (namespace, id, source_id, target_id, relation, weight, \
351 created_at, updated_at, deleted_at, metadata, target_backend) \
352 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11) \
353 ON CONFLICT(namespace, id) DO UPDATE SET \
354 source_id = excluded.source_id, \
355 target_id = excluded.target_id, \
356 relation = excluded.relation, \
357 weight = excluded.weight, \
358 updated_at = excluded.updated_at, \
359 deleted_at = NULL, \
360 metadata = excluded.metadata, \
361 target_backend = excluded.target_backend \
362 ON CONFLICT(namespace, source_id, target_id, relation) DO UPDATE SET \
363 weight = excluded.weight, \
364 updated_at = excluded.updated_at, \
365 deleted_at = NULL, \
366 metadata = excluded.metadata, \
367 target_backend = excluded.target_backend",
368 rusqlite::params![
369 namespace,
370 id_str,
371 src_str,
372 tgt_str,
373 relation_str,
374 edge.weight,
375 edge.created_at.timestamp_micros(),
376 edge.updated_at.timestamp_micros(),
377 edge.deleted_at.map(|t| t.timestamp_micros()),
378 metadata_str,
379 edge.target_backend,
380 ],
381 )?;
382 Ok(())
383 })
384 .await
385 }
386
387 async fn upsert_edges(&self, edges: Vec<Edge>) -> Result<BatchWriteSummary, StorageError> {
388 let attempted = edges.len() as u64;
389 let namespace = self.namespace.clone();
390
391 for edge in &edges {
393 if edge.namespace != namespace {
394 return Err(StorageError::InvalidInput {
395 capability: StorageCapability::Graph,
396 operation: "upsert_edges".into(),
397 message: format!(
398 "edge namespace {:?} does not match store namespace {:?}",
399 edge.namespace, namespace
400 ),
401 });
402 }
403 }
404
405 self.with_writer("upsert_edges", move |conn| {
406 conn.execute_batch("BEGIN IMMEDIATE")?;
407 let mut affected = 0u64;
408
409 for edge in &edges {
410 let id_str = Uuid::from(edge.id).to_string();
411 let (canon_src, canon_tgt) =
412 canonical_edge_endpoints(edge.relation, edge.source_id, edge.target_id);
413 let src_str = canon_src.to_string();
414 let tgt_str = canon_tgt.to_string();
415 let relation_str = edge.relation.to_string();
416 let metadata_str = edge
417 .metadata
418 .as_ref()
419 .map(serde_json::to_string)
420 .transpose()
421 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
422 if let Err(e) = conn.execute(
423 "INSERT INTO graph_edges \
424 (namespace, id, source_id, target_id, relation, weight, \
425 created_at, updated_at, deleted_at, metadata, target_backend) \
426 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11) \
427 ON CONFLICT(namespace, id) DO UPDATE SET \
428 source_id = excluded.source_id, \
429 target_id = excluded.target_id, \
430 relation = excluded.relation, \
431 weight = excluded.weight, \
432 updated_at = excluded.updated_at, \
433 deleted_at = NULL, \
434 metadata = excluded.metadata, \
435 target_backend = excluded.target_backend \
436 ON CONFLICT(namespace, source_id, target_id, relation) DO UPDATE SET \
437 weight = excluded.weight, \
438 updated_at = excluded.updated_at, \
439 deleted_at = NULL, \
440 metadata = excluded.metadata, \
441 target_backend = excluded.target_backend",
442 rusqlite::params![
443 &namespace,
444 id_str,
445 src_str,
446 tgt_str,
447 relation_str,
448 edge.weight,
449 edge.created_at.timestamp_micros(),
450 edge.updated_at.timestamp_micros(),
451 edge.deleted_at.map(|t| t.timestamp_micros()),
452 metadata_str,
453 edge.target_backend.as_deref(),
454 ],
455 ) {
456 let _ = conn.execute_batch("ROLLBACK");
457 return Err(e);
458 }
459 affected += 1;
460 }
461
462 if let Err(e) = conn.execute_batch("COMMIT") {
463 let _ = conn.execute_batch("ROLLBACK");
464 return Err(e);
465 }
466 Ok(BatchWriteSummary {
467 attempted,
468 affected,
469 failed: 0,
470 first_error: String::new(),
471 })
472 })
473 .await
474 }
475
476 async fn get_edge(&self, id: LinkId) -> Result<Option<Edge>, StorageError> {
477 let namespace = self.namespace.clone();
478 let id_str = Uuid::from(id).to_string();
479
480 self.with_reader("get_edge", move |conn| {
481 let mut stmt = conn.prepare(
482 "SELECT namespace, id, source_id, target_id, relation, weight, \
483 created_at, updated_at, deleted_at, metadata, target_backend \
484 FROM graph_edges WHERE namespace = ?1 AND id = ?2 AND deleted_at IS NULL",
485 )?;
486 let mut rows = stmt.query(rusqlite::params![namespace, id_str])?;
487 match rows.next()? {
488 Some(row) => Ok(Some(read_edge(row)?)),
489 None => Ok(None),
490 }
491 })
492 .await
493 }
494
495 async fn get_edge_including_deleted(&self, id: LinkId) -> Result<Option<Edge>, StorageError> {
496 let namespace = self.namespace.clone();
497 let id_str = Uuid::from(id).to_string();
498
499 self.with_reader("get_edge_including_deleted", move |conn| {
500 let mut stmt = conn.prepare(
501 "SELECT namespace, id, source_id, target_id, relation, weight, \
502 created_at, updated_at, deleted_at, metadata, target_backend \
503 FROM graph_edges WHERE namespace = ?1 AND id = ?2",
504 )?;
505 let mut rows = stmt.query(rusqlite::params![namespace, id_str])?;
506 match rows.next()? {
507 Some(row) => Ok(Some(read_edge(row)?)),
508 None => Ok(None),
509 }
510 })
511 .await
512 }
513
514 async fn delete_edge(&self, id: LinkId, mode: DeleteMode) -> Result<bool, StorageError> {
515 let namespace = self.namespace.clone();
516 let id_str = Uuid::from(id).to_string();
517
518 self.with_writer("delete_edge", move |conn| {
519 let affected = match mode {
520 DeleteMode::Soft => conn.execute(
521 "UPDATE graph_edges SET deleted_at = ?3, updated_at = ?3 \
522 WHERE namespace = ?1 AND id = ?2 AND deleted_at IS NULL",
523 rusqlite::params![namespace, id_str, chrono::Utc::now().timestamp_micros(),],
524 )?,
525 DeleteMode::Hard => conn.execute(
526 "DELETE FROM graph_edges WHERE namespace = ?1 AND id = ?2",
527 rusqlite::params![namespace, id_str],
528 )?,
529 };
530 Ok(affected > 0)
531 })
532 .await
533 }
534
535 async fn query_edges(
536 &self,
537 filter: EdgeFilter,
538 sort: Vec<SortOrder<EdgeSortField>>,
539 page: PageRequest,
540 ) -> Result<Page<Edge>, StorageError> {
541 let namespace = self.namespace.clone();
542 self.with_reader("query_edges", move |conn| {
543 let (where_clause, filter_params) = build_edge_filter_sql(&namespace, &filter);
544
545 let count_sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
546 let total: i64 = {
547 let mut stmt = conn.prepare(&count_sql)?;
548 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
549 filter_params.iter().map(|p| p.as_ref()).collect();
550 stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
551 };
552
553 let order_clause = if sort.is_empty() {
554 " ORDER BY created_at DESC".to_string()
555 } else {
556 let parts: Vec<String> = sort
557 .iter()
558 .map(|s| {
559 let dir = match s.direction {
560 SortDirection::Asc => "ASC",
561 SortDirection::Desc => "DESC",
562 };
563 format!("{} {}", edge_sort_col(&s.field), dir)
564 })
565 .collect();
566 format!(" ORDER BY {}", parts.join(", "))
567 };
568
569 let (_, data_filter_params) = build_edge_filter_sql(&namespace, &filter);
570 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = data_filter_params;
571 all_params.push(Box::new(page.limit as i64));
572 all_params.push(Box::new(page.offset as i64));
573
574 let limit_idx = all_params.len() - 1;
575 let offset_idx = all_params.len();
576
577 let data_sql = format!(
578 "SELECT namespace, id, source_id, target_id, relation, weight, \
579 created_at, updated_at, deleted_at, metadata, target_backend \
580 FROM graph_edges{}{} LIMIT ?{} OFFSET ?{}",
581 where_clause, order_clause, limit_idx, offset_idx,
582 );
583
584 let mut stmt = conn.prepare(&data_sql)?;
585 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
586 all_params.iter().map(|p| p.as_ref()).collect();
587 let rows = stmt.query_map(param_refs.as_slice(), read_edge)?;
588
589 let mut items = Vec::new();
590 for row in rows {
591 items.push(row?);
592 }
593
594 Ok(Page {
595 items,
596 total: Some(total as u64),
597 })
598 })
599 .await
600 }
601
602 async fn count_edges(&self, filter: EdgeFilter) -> Result<u64, StorageError> {
603 let namespace = self.namespace.clone();
604 self.with_reader("count_edges", move |conn| {
605 let (where_clause, params) = build_edge_filter_sql(&namespace, &filter);
606 let sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
607 let mut stmt = conn.prepare(&sql)?;
608 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
609 params.iter().map(|p| p.as_ref()).collect();
610 let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
611 Ok(count as u64)
612 })
613 .await
614 }
615
616 async fn neighbors(
617 &self,
618 node_id: Uuid,
619 query: NeighborQuery,
620 ) -> Result<Vec<NeighborHit>, StorageError> {
621 use khive_storage::types::Direction;
622
623 let namespace = self.namespace.clone();
624 let node_str = node_id.to_string();
625
626 self.with_reader("neighbors", move |conn| {
627 let base_out = "SELECT target_id AS node_id, id AS edge_id, relation, weight \
628 FROM graph_edges \
629 WHERE namespace = ?1 AND source_id = ?2 AND deleted_at IS NULL";
630 let base_in = "SELECT source_id AS node_id, id AS edge_id, relation, weight \
631 FROM graph_edges \
632 WHERE namespace = ?1 AND target_id = ?2 AND deleted_at IS NULL";
633
634 let sql = match query.direction {
635 Direction::Out => base_out.to_string(),
636 Direction::In => base_in.to_string(),
637 Direction::Both => format!("{} UNION ALL {}", base_out, base_in),
638 };
639
640 let mut conditions: Vec<String> = Vec::new();
641 let mut extra_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
642 let mut param_idx = 3;
643
644 if let Some(ref rels) = query.relations {
645 if !rels.is_empty() {
646 let placeholders: Vec<String> = rels
647 .iter()
648 .map(|r| {
649 extra_params.push(Box::new(r.to_string()));
650 let p = format!("?{}", param_idx);
651 param_idx += 1;
652 p
653 })
654 .collect();
655 conditions.push(format!("relation IN ({})", placeholders.join(",")));
656 }
657 }
658
659 if let Some(min_w) = query.min_weight {
660 extra_params.push(Box::new(min_w));
661 conditions.push(format!("weight >= ?{}", param_idx));
662 param_idx += 1;
663 }
664
665 let where_extra = if conditions.is_empty() {
666 String::new()
667 } else {
668 format!(" WHERE {}", conditions.join(" AND "))
669 };
670
671 let limit_clause = if let Some(lim) = query.limit {
672 extra_params.push(Box::new(lim as i64));
673 format!(" LIMIT ?{}", param_idx)
674 } else {
675 String::new()
676 };
677
678 let full_sql = format!(
679 "SELECT node_id, edge_id, relation, weight FROM ({}){}{}",
680 sql, where_extra, limit_clause
681 );
682
683 let mut stmt = conn.prepare(&full_sql)?;
684
685 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
686 all_params.push(Box::new(namespace.clone()));
687 all_params.push(Box::new(node_str.clone()));
688 all_params.extend(extra_params);
689
690 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
691 all_params.iter().map(|p| p.as_ref()).collect();
692
693 let rows = stmt.query_map(param_refs.as_slice(), |row| {
694 let nid_str: String = row.get(0)?;
695 let eid_str: String = row.get(1)?;
696 let relation_str: String = row.get(2)?;
697 let weight: f64 = row.get(3)?;
698 Ok((nid_str, eid_str, relation_str, weight))
699 })?;
700
701 let mut hits = Vec::new();
702 for row in rows {
703 let (nid_str, eid_str, relation_str, weight) = row?;
704 let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
705 rusqlite::Error::FromSqlConversionFailure(
706 2,
707 rusqlite::types::Type::Text,
708 Box::new(e),
709 )
710 })?;
711 hits.push(NeighborHit {
712 node_id: parse_uuid(&nid_str)?,
713 edge_id: parse_uuid(&eid_str)?,
714 relation,
715 weight,
716 name: None,
717 kind: None,
718 });
719 }
720
721 Ok(hits)
722 })
723 .await
724 }
725
726 async fn traverse(&self, request: TraversalRequest) -> Result<Vec<GraphPath>, StorageError> {
727 use khive_storage::types::Direction;
728
729 if request.roots.is_empty() {
730 return Ok(Vec::new());
731 }
732
733 let roots = request.roots.clone();
734 let opts = request.options.clone();
735 let include_roots = request.include_roots;
736 let namespace = self.namespace.clone();
737
738 self.with_reader("traverse", move |conn| {
739 let mut all_paths: Vec<GraphPath> = Vec::new();
740
741 for root_id in &roots {
742 let root_str = root_id.to_string();
743
744 let (join_condition, next_node) = match opts.direction {
745 Direction::Out => ("e.source_id = t.node_id", "e.target_id"),
746 Direction::In => ("e.target_id = t.node_id", "e.source_id"),
747 Direction::Both => (
748 "(e.source_id = t.node_id OR e.target_id = t.node_id)",
749 "CASE WHEN e.source_id = t.node_id THEN e.target_id ELSE e.source_id END",
750 ),
751 };
752
753 let mut relation_cond = String::new();
754 let mut relation_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
755 let mut param_idx = 4;
756
757 if let Some(ref rels) = opts.relations {
758 if !rels.is_empty() {
759 let placeholders: Vec<String> = rels
760 .iter()
761 .map(|r| {
762 relation_params.push(Box::new(r.to_string()));
763 let p = format!("?{}", param_idx);
764 param_idx += 1;
765 p
766 })
767 .collect();
768 relation_cond =
769 format!(" AND e.relation IN ({})", placeholders.join(","));
770 }
771 }
772
773 let mut weight_cond = String::new();
774 if let Some(min_w) = opts.min_weight {
775 relation_params.push(Box::new(min_w));
776 weight_cond = format!(" AND e.weight >= ?{}", param_idx);
777 param_idx += 1;
778 }
779
780 let limit_clause = if let Some(lim) = opts.limit {
781 relation_params.push(Box::new(lim as i64));
782 format!(" LIMIT ?{}", param_idx)
783 } else {
784 String::new()
785 };
786
787 let cte_sql = format!(
788 "WITH RECURSIVE traversal(node_id, edge_id, depth, path, total_weight) AS (\
789 SELECT ?2, NULL, 0, ?2, 0.0 \
790 UNION ALL \
791 SELECT {next_node}, e.id, t.depth + 1, \
792 t.path || ',' || {next_node}, \
793 t.total_weight + e.weight \
794 FROM graph_edges e \
795 JOIN traversal t ON {join_condition} \
796 WHERE e.namespace = ?1 \
797 AND e.deleted_at IS NULL \
798 AND t.depth < ?3 \
799 AND (',' || t.path || ',') NOT LIKE '%,' || {next_node} || ',%'{rel_cond}{wt_cond} \
800 ) \
801 SELECT node_id, edge_id, depth, path, total_weight \
802 FROM traversal WHERE depth > 0 \
803 ORDER BY depth{limit}",
804 next_node = next_node,
805 join_condition = join_condition,
806 rel_cond = relation_cond,
807 wt_cond = weight_cond,
808 limit = limit_clause,
809 );
810
811 let mut stmt = conn.prepare(&cte_sql)?;
812
813 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
814 all_params.push(Box::new(namespace.clone()));
815 all_params.push(Box::new(root_str.clone()));
816 all_params.push(Box::new(opts.max_depth as i64));
817 all_params.extend(relation_params);
818
819 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
820 all_params.iter().map(|p| p.as_ref()).collect();
821
822 let rows = stmt.query_map(param_refs.as_slice(), |row| {
823 let node_str: String = row.get(0)?;
824 let edge_str: Option<String> = row.get(1)?;
825 let depth: i64 = row.get(2)?;
826 let _path: String = row.get(3)?;
827 let total_weight: f64 = row.get(4)?;
828 Ok((node_str, edge_str, depth, total_weight))
829 })?;
830
831 let mut nodes = Vec::new();
832 let mut max_weight = 0.0f64;
833 let mut seen: std::collections::HashSet<Uuid> =
838 std::collections::HashSet::new();
839
840 if include_roots {
841 seen.insert(*root_id);
842 nodes.push(PathNode {
843 node_id: *root_id,
844 via_edge: None,
845 depth: 0,
846 name: None,
847 kind: None,
848 });
849 }
850
851 for row in rows {
852 let (node_str, edge_str, depth, total_weight) = row?;
853 let node_id = parse_uuid(&node_str)?;
854 if !seen.insert(node_id) {
856 continue;
857 }
858 let via_edge = edge_str.map(|s| parse_uuid(&s)).transpose()?;
859 nodes.push(PathNode {
860 node_id,
861 via_edge,
862 depth: depth as usize,
863 name: None,
864 kind: None,
865 });
866 if total_weight > max_weight {
867 max_weight = total_weight;
868 }
869 }
870
871 if nodes.len() > if include_roots { 1 } else { 0 } || include_roots {
872 all_paths.push(GraphPath {
873 root_id: *root_id,
874 nodes,
875 total_weight: max_weight,
876 });
877 }
878 }
879
880 Ok(all_paths)
881 })
882 .await
883 }
884
885 async fn purge_incident_edges(&self, node_id: Uuid) -> Result<u64, StorageError> {
886 let namespace = self.namespace.clone();
887 let id_str = node_id.to_string();
888 self.with_writer("purge_incident_edges", move |conn| {
889 let affected = conn.execute(
890 "DELETE FROM graph_edges \
891 WHERE namespace = ?1 AND (source_id = ?2 OR target_id = ?2)",
892 rusqlite::params![namespace, id_str],
893 )?;
894 Ok(affected as u64)
895 })
896 .await
897 }
898}
899
900const GRAPH_DDL: &str = include_str!("../../sql/graph-ddl.sql");
905
906pub(crate) fn ensure_graph_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
907 conn.execute_batch(GRAPH_DDL)
908}
909
// Unit tests are compiled only for test builds and live in `graph_tests.rs`,
// which `#[path]` points this module declaration at.
#[cfg(test)]
#[path = "graph_tests.rs"]
mod tests;