1use std::sync::Arc;
4
5use async_trait::async_trait;
6use chrono::{DateTime, TimeZone, Utc};
7use uuid::Uuid;
8
9use khive_storage::error::StorageError;
10use khive_storage::types::{
11 BatchWriteSummary, DeleteMode, Edge, EdgeFilter, EdgeSortField, GraphPath, NeighborHit,
12 NeighborQuery, Page, PageRequest, PathNode, SortDirection, SortOrder, TraversalRequest,
13};
14use khive_storage::GraphStore;
15use khive_storage::LinkId;
16use khive_storage::StorageCapability;
17use khive_types::EdgeRelation;
18
19use crate::error::SqliteError;
20use crate::pool::ConnectionPool;
21
22fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
24 StorageError::driver(StorageCapability::Graph, op, e)
25}
26
27fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
28 StorageError::driver(StorageCapability::Graph, op, e)
29}
30
31pub struct SqlGraphStore {
33 pool: Arc<ConnectionPool>,
34 is_file_backed: bool,
35 namespace: String,
36}
37
38impl SqlGraphStore {
39 pub fn new_scoped(
41 pool: Arc<ConnectionPool>,
42 is_file_backed: bool,
43 namespace: impl Into<String>,
44 ) -> Self {
45 Self {
46 pool,
47 is_file_backed,
48 namespace: namespace.into(),
49 }
50 }
51
52 fn open_standalone_writer(&self) -> Result<rusqlite::Connection, StorageError> {
53 let config = self.pool.config();
54 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
55 operation: "graph_writer".into(),
56 message: "in-memory databases do not support standalone connections".into(),
57 })?;
58
59 let conn = rusqlite::Connection::open_with_flags(
60 path,
61 rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE
62 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
63 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
64 )
65 .map_err(|e| map_err(e, "open_graph_writer"))?;
66
67 conn.busy_timeout(config.busy_timeout)
68 .map_err(|e| map_err(e, "open_graph_writer"))?;
69 conn.pragma_update(None, "foreign_keys", "ON")
70 .map_err(|e| map_err(e, "open_graph_writer"))?;
71 conn.pragma_update(None, "synchronous", "NORMAL")
72 .map_err(|e| map_err(e, "open_graph_writer"))?;
73
74 Ok(conn)
75 }
76
77 fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
78 let config = self.pool.config();
79 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
80 operation: "graph_reader".into(),
81 message: "in-memory databases do not support standalone connections".into(),
82 })?;
83
84 let conn = rusqlite::Connection::open_with_flags(
85 path,
86 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
87 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
88 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
89 )
90 .map_err(|e| map_err(e, "open_graph_reader"))?;
91
92 conn.busy_timeout(config.busy_timeout)
93 .map_err(|e| map_err(e, "open_graph_reader"))?;
94 conn.pragma_update(None, "foreign_keys", "ON")
95 .map_err(|e| map_err(e, "open_graph_reader"))?;
96 conn.pragma_update(None, "synchronous", "NORMAL")
97 .map_err(|e| map_err(e, "open_graph_reader"))?;
98
99 Ok(conn)
100 }
101
102 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
103 where
104 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
105 R: Send + 'static,
106 {
107 if self.is_file_backed {
108 let conn = self.open_standalone_writer()?;
109 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
110 .await
111 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
112 } else {
113 let pool = Arc::clone(&self.pool);
114 tokio::task::spawn_blocking(move || {
115 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
116 f(guard.conn()).map_err(|e| map_err(e, op))
117 })
118 .await
119 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
120 }
121 }
122
123 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
124 where
125 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
126 R: Send + 'static,
127 {
128 if self.is_file_backed {
129 let conn = self.open_standalone_reader()?;
130 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
131 .await
132 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
133 } else {
134 let pool = Arc::clone(&self.pool);
135 tokio::task::spawn_blocking(move || {
136 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
137 f(guard.conn()).map_err(|e| map_err(e, op))
138 })
139 .await
140 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
141 }
142 }
143}
144
145fn read_edge(row: &rusqlite::Row<'_>) -> Result<Edge, rusqlite::Error> {
150 let namespace: String = row.get(0)?;
151 let id_str: String = row.get(1)?;
152 let source_str: String = row.get(2)?;
153 let target_str: String = row.get(3)?;
154 let relation_str: String = row.get(4)?;
155 let weight: f64 = row.get(5)?;
156 let created_micros: i64 = row.get(6)?;
157 let updated_micros: i64 = row.get(7)?;
158 let deleted_micros: Option<i64> = row.get(8)?;
159 let metadata_str: Option<String> = row.get(9)?;
160 let target_backend: Option<String> = row.get(10)?;
161
162 let id = parse_uuid(&id_str)?;
163 let source_id = parse_uuid(&source_str)?;
164 let target_id = parse_uuid(&target_str)?;
165 let created_at = micros_to_datetime(created_micros);
166 let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
167 rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(e))
168 })?;
169 let metadata = match metadata_str {
170 Some(s) => {
171 let v = serde_json::from_str(&s).map_err(|e| {
172 rusqlite::Error::FromSqlConversionFailure(
173 9,
174 rusqlite::types::Type::Text,
175 Box::new(e),
176 )
177 })?;
178 Some(v)
179 }
180 None => None,
181 };
182
183 Ok(Edge {
184 id: id.into(),
185 namespace,
186 source_id,
187 target_id,
188 relation,
189 weight,
190 created_at,
191 updated_at: micros_to_datetime(updated_micros),
192 deleted_at: deleted_micros.map(micros_to_datetime),
193 metadata,
194 target_backend,
195 })
196}
197
198fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
199 Uuid::parse_str(s).map_err(|e| {
200 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
201 })
202}
203
204fn micros_to_datetime(micros: i64) -> DateTime<Utc> {
205 Utc.timestamp_micros(micros)
206 .single()
207 .unwrap_or_else(Utc::now)
208}
209
210fn build_edge_filter_sql(
211 namespace: &str,
212 filter: &EdgeFilter,
213) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
214 let mut conditions: Vec<String> = vec![
215 "namespace = ?1".to_string(),
216 "deleted_at IS NULL".to_string(),
217 ];
218 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
219
220 if !filter.ids.is_empty() {
221 let placeholders: Vec<String> = filter
222 .ids
223 .iter()
224 .map(|id| {
225 params.push(Box::new(id.to_string()));
226 format!("?{}", params.len())
227 })
228 .collect();
229 conditions.push(format!("id IN ({})", placeholders.join(",")));
230 }
231
232 if !filter.source_ids.is_empty() {
233 let placeholders: Vec<String> = filter
234 .source_ids
235 .iter()
236 .map(|id| {
237 params.push(Box::new(id.to_string()));
238 format!("?{}", params.len())
239 })
240 .collect();
241 conditions.push(format!("source_id IN ({})", placeholders.join(",")));
242 }
243
244 if !filter.target_ids.is_empty() {
245 let placeholders: Vec<String> = filter
246 .target_ids
247 .iter()
248 .map(|id| {
249 params.push(Box::new(id.to_string()));
250 format!("?{}", params.len())
251 })
252 .collect();
253 conditions.push(format!("target_id IN ({})", placeholders.join(",")));
254 }
255
256 if !filter.relations.is_empty() {
257 let placeholders: Vec<String> = filter
258 .relations
259 .iter()
260 .map(|r| {
261 params.push(Box::new(r.to_string()));
262 format!("?{}", params.len())
263 })
264 .collect();
265 conditions.push(format!("relation IN ({})", placeholders.join(",")));
266 }
267
268 if let Some(min_w) = filter.min_weight {
269 params.push(Box::new(min_w));
270 conditions.push(format!("weight >= ?{}", params.len()));
271 }
272
273 if let Some(max_w) = filter.max_weight {
274 params.push(Box::new(max_w));
275 conditions.push(format!("weight <= ?{}", params.len()));
276 }
277
278 if let Some(ref time_range) = filter.created_at {
279 if let Some(start) = time_range.start {
280 params.push(Box::new(start.timestamp_micros()));
281 conditions.push(format!("created_at >= ?{}", params.len()));
282 }
283 if let Some(end) = time_range.end {
284 params.push(Box::new(end.timestamp_micros()));
285 conditions.push(format!("created_at < ?{}", params.len()));
286 }
287 }
288
289 let clause = format!(" WHERE {}", conditions.join(" AND "));
290 (clause, params)
291}
292
293fn edge_sort_col(field: &EdgeSortField) -> &'static str {
294 match field {
295 EdgeSortField::CreatedAt => "created_at",
296 EdgeSortField::Weight => "weight",
297 EdgeSortField::Relation => "relation",
298 }
299}
300
301fn canonical_edge_endpoints(
310 relation: EdgeRelation,
311 source_id: Uuid,
312 target_id: Uuid,
313) -> (Uuid, Uuid) {
314 if relation.is_symmetric() && target_id < source_id {
315 (target_id, source_id)
316 } else {
317 (source_id, target_id)
318 }
319}
320
321#[async_trait]
322impl GraphStore for SqlGraphStore {
323 async fn upsert_edge(&self, edge: Edge) -> Result<(), StorageError> {
324 let namespace = self.namespace.clone();
325 if edge.namespace != namespace {
326 return Err(StorageError::InvalidInput {
327 capability: StorageCapability::Graph,
328 operation: "upsert_edge".into(),
329 message: format!(
330 "edge namespace {:?} does not match store namespace {:?}",
331 edge.namespace, namespace
332 ),
333 });
334 }
335 let id_str = Uuid::from(edge.id).to_string();
336 let (source_id, target_id) =
337 canonical_edge_endpoints(edge.relation, edge.source_id, edge.target_id);
338 let src_str = source_id.to_string();
339 let tgt_str = target_id.to_string();
340 let relation_str = edge.relation.to_string();
341 let metadata_str = edge
342 .metadata
343 .as_ref()
344 .map(serde_json::to_string)
345 .transpose()
346 .map_err(|e| StorageError::driver(StorageCapability::Graph, "upsert_edge", e))?;
347 self.with_writer("upsert_edge", move |conn| {
348 conn.execute(
349 "INSERT INTO graph_edges \
350 (namespace, id, source_id, target_id, relation, weight, \
351 created_at, updated_at, deleted_at, metadata, target_backend) \
352 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11) \
353 ON CONFLICT(namespace, id) DO UPDATE SET \
354 source_id = excluded.source_id, \
355 target_id = excluded.target_id, \
356 relation = excluded.relation, \
357 weight = excluded.weight, \
358 updated_at = excluded.updated_at, \
359 deleted_at = NULL, \
360 metadata = excluded.metadata, \
361 target_backend = excluded.target_backend \
362 ON CONFLICT(namespace, source_id, target_id, relation) DO UPDATE SET \
363 weight = excluded.weight, \
364 updated_at = excluded.updated_at, \
365 deleted_at = NULL, \
366 metadata = excluded.metadata, \
367 target_backend = excluded.target_backend",
368 rusqlite::params![
369 namespace,
370 id_str,
371 src_str,
372 tgt_str,
373 relation_str,
374 edge.weight,
375 edge.created_at.timestamp_micros(),
376 edge.updated_at.timestamp_micros(),
377 edge.deleted_at.map(|t| t.timestamp_micros()),
378 metadata_str,
379 edge.target_backend,
380 ],
381 )?;
382 Ok(())
383 })
384 .await
385 }
386
387 async fn upsert_edges(&self, edges: Vec<Edge>) -> Result<BatchWriteSummary, StorageError> {
388 let attempted = edges.len() as u64;
389 let namespace = self.namespace.clone();
390
391 for edge in &edges {
393 if edge.namespace != namespace {
394 return Err(StorageError::InvalidInput {
395 capability: StorageCapability::Graph,
396 operation: "upsert_edges".into(),
397 message: format!(
398 "edge namespace {:?} does not match store namespace {:?}",
399 edge.namespace, namespace
400 ),
401 });
402 }
403 }
404
405 self.with_writer("upsert_edges", move |conn| {
406 conn.execute_batch("BEGIN IMMEDIATE")?;
407 let mut affected = 0u64;
408
409 for edge in &edges {
410 let id_str = Uuid::from(edge.id).to_string();
411 let (canon_src, canon_tgt) =
412 canonical_edge_endpoints(edge.relation, edge.source_id, edge.target_id);
413 let src_str = canon_src.to_string();
414 let tgt_str = canon_tgt.to_string();
415 let relation_str = edge.relation.to_string();
416 let metadata_str = edge
417 .metadata
418 .as_ref()
419 .map(serde_json::to_string)
420 .transpose()
421 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
422 if let Err(e) = conn.execute(
423 "INSERT INTO graph_edges \
424 (namespace, id, source_id, target_id, relation, weight, \
425 created_at, updated_at, deleted_at, metadata, target_backend) \
426 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11) \
427 ON CONFLICT(namespace, id) DO UPDATE SET \
428 source_id = excluded.source_id, \
429 target_id = excluded.target_id, \
430 relation = excluded.relation, \
431 weight = excluded.weight, \
432 updated_at = excluded.updated_at, \
433 deleted_at = NULL, \
434 metadata = excluded.metadata, \
435 target_backend = excluded.target_backend \
436 ON CONFLICT(namespace, source_id, target_id, relation) DO UPDATE SET \
437 weight = excluded.weight, \
438 updated_at = excluded.updated_at, \
439 deleted_at = NULL, \
440 metadata = excluded.metadata, \
441 target_backend = excluded.target_backend",
442 rusqlite::params![
443 &namespace,
444 id_str,
445 src_str,
446 tgt_str,
447 relation_str,
448 edge.weight,
449 edge.created_at.timestamp_micros(),
450 edge.updated_at.timestamp_micros(),
451 edge.deleted_at.map(|t| t.timestamp_micros()),
452 metadata_str,
453 edge.target_backend.as_deref(),
454 ],
455 ) {
456 let _ = conn.execute_batch("ROLLBACK");
457 return Err(e);
458 }
459 affected += 1;
460 }
461
462 if let Err(e) = conn.execute_batch("COMMIT") {
463 let _ = conn.execute_batch("ROLLBACK");
464 return Err(e);
465 }
466 Ok(BatchWriteSummary {
467 attempted,
468 affected,
469 failed: 0,
470 first_error: String::new(),
471 })
472 })
473 .await
474 }
475
476 async fn get_edge(&self, id: LinkId) -> Result<Option<Edge>, StorageError> {
477 let namespace = self.namespace.clone();
478 let id_str = Uuid::from(id).to_string();
479
480 self.with_reader("get_edge", move |conn| {
481 let mut stmt = conn.prepare(
482 "SELECT namespace, id, source_id, target_id, relation, weight, \
483 created_at, updated_at, deleted_at, metadata, target_backend \
484 FROM graph_edges WHERE namespace = ?1 AND id = ?2 AND deleted_at IS NULL",
485 )?;
486 let mut rows = stmt.query(rusqlite::params![namespace, id_str])?;
487 match rows.next()? {
488 Some(row) => Ok(Some(read_edge(row)?)),
489 None => Ok(None),
490 }
491 })
492 .await
493 }
494
495 async fn delete_edge(&self, id: LinkId, mode: DeleteMode) -> Result<bool, StorageError> {
496 let namespace = self.namespace.clone();
497 let id_str = Uuid::from(id).to_string();
498
499 self.with_writer("delete_edge", move |conn| {
500 let affected = match mode {
501 DeleteMode::Soft => conn.execute(
502 "UPDATE graph_edges SET deleted_at = ?3, updated_at = ?3 \
503 WHERE namespace = ?1 AND id = ?2 AND deleted_at IS NULL",
504 rusqlite::params![namespace, id_str, chrono::Utc::now().timestamp_micros(),],
505 )?,
506 DeleteMode::Hard => conn.execute(
507 "DELETE FROM graph_edges WHERE namespace = ?1 AND id = ?2",
508 rusqlite::params![namespace, id_str],
509 )?,
510 };
511 Ok(affected > 0)
512 })
513 .await
514 }
515
516 async fn query_edges(
517 &self,
518 filter: EdgeFilter,
519 sort: Vec<SortOrder<EdgeSortField>>,
520 page: PageRequest,
521 ) -> Result<Page<Edge>, StorageError> {
522 let namespace = self.namespace.clone();
523 self.with_reader("query_edges", move |conn| {
524 let (where_clause, filter_params) = build_edge_filter_sql(&namespace, &filter);
525
526 let count_sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
527 let total: i64 = {
528 let mut stmt = conn.prepare(&count_sql)?;
529 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
530 filter_params.iter().map(|p| p.as_ref()).collect();
531 stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
532 };
533
534 let order_clause = if sort.is_empty() {
535 " ORDER BY created_at DESC".to_string()
536 } else {
537 let parts: Vec<String> = sort
538 .iter()
539 .map(|s| {
540 let dir = match s.direction {
541 SortDirection::Asc => "ASC",
542 SortDirection::Desc => "DESC",
543 };
544 format!("{} {}", edge_sort_col(&s.field), dir)
545 })
546 .collect();
547 format!(" ORDER BY {}", parts.join(", "))
548 };
549
550 let (_, data_filter_params) = build_edge_filter_sql(&namespace, &filter);
551 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = data_filter_params;
552 all_params.push(Box::new(page.limit as i64));
553 all_params.push(Box::new(page.offset as i64));
554
555 let limit_idx = all_params.len() - 1;
556 let offset_idx = all_params.len();
557
558 let data_sql = format!(
559 "SELECT namespace, id, source_id, target_id, relation, weight, \
560 created_at, updated_at, deleted_at, metadata, target_backend \
561 FROM graph_edges{}{} LIMIT ?{} OFFSET ?{}",
562 where_clause, order_clause, limit_idx, offset_idx,
563 );
564
565 let mut stmt = conn.prepare(&data_sql)?;
566 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
567 all_params.iter().map(|p| p.as_ref()).collect();
568 let rows = stmt.query_map(param_refs.as_slice(), read_edge)?;
569
570 let mut items = Vec::new();
571 for row in rows {
572 items.push(row?);
573 }
574
575 Ok(Page {
576 items,
577 total: Some(total as u64),
578 })
579 })
580 .await
581 }
582
583 async fn count_edges(&self, filter: EdgeFilter) -> Result<u64, StorageError> {
584 let namespace = self.namespace.clone();
585 self.with_reader("count_edges", move |conn| {
586 let (where_clause, params) = build_edge_filter_sql(&namespace, &filter);
587 let sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
588 let mut stmt = conn.prepare(&sql)?;
589 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
590 params.iter().map(|p| p.as_ref()).collect();
591 let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
592 Ok(count as u64)
593 })
594 .await
595 }
596
597 async fn neighbors(
598 &self,
599 node_id: Uuid,
600 query: NeighborQuery,
601 ) -> Result<Vec<NeighborHit>, StorageError> {
602 use khive_storage::types::Direction;
603
604 let namespace = self.namespace.clone();
605 let node_str = node_id.to_string();
606
607 self.with_reader("neighbors", move |conn| {
608 let base_out = "SELECT target_id AS node_id, id AS edge_id, relation, weight \
609 FROM graph_edges \
610 WHERE namespace = ?1 AND source_id = ?2 AND deleted_at IS NULL";
611 let base_in = "SELECT source_id AS node_id, id AS edge_id, relation, weight \
612 FROM graph_edges \
613 WHERE namespace = ?1 AND target_id = ?2 AND deleted_at IS NULL";
614
615 let sql = match query.direction {
616 Direction::Out => base_out.to_string(),
617 Direction::In => base_in.to_string(),
618 Direction::Both => format!("{} UNION ALL {}", base_out, base_in),
619 };
620
621 let mut conditions: Vec<String> = Vec::new();
622 let mut extra_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
623 let mut param_idx = 3;
624
625 if let Some(ref rels) = query.relations {
626 if !rels.is_empty() {
627 let placeholders: Vec<String> = rels
628 .iter()
629 .map(|r| {
630 extra_params.push(Box::new(r.to_string()));
631 let p = format!("?{}", param_idx);
632 param_idx += 1;
633 p
634 })
635 .collect();
636 conditions.push(format!("relation IN ({})", placeholders.join(",")));
637 }
638 }
639
640 if let Some(min_w) = query.min_weight {
641 extra_params.push(Box::new(min_w));
642 conditions.push(format!("weight >= ?{}", param_idx));
643 param_idx += 1;
644 }
645
646 let where_extra = if conditions.is_empty() {
647 String::new()
648 } else {
649 format!(" WHERE {}", conditions.join(" AND "))
650 };
651
652 let limit_clause = if let Some(lim) = query.limit {
653 extra_params.push(Box::new(lim as i64));
654 format!(" LIMIT ?{}", param_idx)
655 } else {
656 String::new()
657 };
658
659 let full_sql = format!(
660 "SELECT node_id, edge_id, relation, weight FROM ({}){}{}",
661 sql, where_extra, limit_clause
662 );
663
664 let mut stmt = conn.prepare(&full_sql)?;
665
666 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
667 all_params.push(Box::new(namespace.clone()));
668 all_params.push(Box::new(node_str.clone()));
669 all_params.extend(extra_params);
670
671 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
672 all_params.iter().map(|p| p.as_ref()).collect();
673
674 let rows = stmt.query_map(param_refs.as_slice(), |row| {
675 let nid_str: String = row.get(0)?;
676 let eid_str: String = row.get(1)?;
677 let relation_str: String = row.get(2)?;
678 let weight: f64 = row.get(3)?;
679 Ok((nid_str, eid_str, relation_str, weight))
680 })?;
681
682 let mut hits = Vec::new();
683 for row in rows {
684 let (nid_str, eid_str, relation_str, weight) = row?;
685 let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
686 rusqlite::Error::FromSqlConversionFailure(
687 2,
688 rusqlite::types::Type::Text,
689 Box::new(e),
690 )
691 })?;
692 hits.push(NeighborHit {
693 node_id: parse_uuid(&nid_str)?,
694 edge_id: parse_uuid(&eid_str)?,
695 relation,
696 weight,
697 name: None,
698 kind: None,
699 });
700 }
701
702 Ok(hits)
703 })
704 .await
705 }
706
707 async fn traverse(&self, request: TraversalRequest) -> Result<Vec<GraphPath>, StorageError> {
708 use khive_storage::types::Direction;
709
710 if request.roots.is_empty() {
711 return Ok(Vec::new());
712 }
713
714 let roots = request.roots.clone();
715 let opts = request.options.clone();
716 let include_roots = request.include_roots;
717 let namespace = self.namespace.clone();
718
719 self.with_reader("traverse", move |conn| {
720 let mut all_paths: Vec<GraphPath> = Vec::new();
721
722 for root_id in &roots {
723 let root_str = root_id.to_string();
724
725 let (join_condition, next_node) = match opts.direction {
726 Direction::Out => ("e.source_id = t.node_id", "e.target_id"),
727 Direction::In => ("e.target_id = t.node_id", "e.source_id"),
728 Direction::Both => (
729 "(e.source_id = t.node_id OR e.target_id = t.node_id)",
730 "CASE WHEN e.source_id = t.node_id THEN e.target_id ELSE e.source_id END",
731 ),
732 };
733
734 let mut relation_cond = String::new();
735 let mut relation_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
736 let mut param_idx = 4;
737
738 if let Some(ref rels) = opts.relations {
739 if !rels.is_empty() {
740 let placeholders: Vec<String> = rels
741 .iter()
742 .map(|r| {
743 relation_params.push(Box::new(r.to_string()));
744 let p = format!("?{}", param_idx);
745 param_idx += 1;
746 p
747 })
748 .collect();
749 relation_cond =
750 format!(" AND e.relation IN ({})", placeholders.join(","));
751 }
752 }
753
754 let mut weight_cond = String::new();
755 if let Some(min_w) = opts.min_weight {
756 relation_params.push(Box::new(min_w));
757 weight_cond = format!(" AND e.weight >= ?{}", param_idx);
758 param_idx += 1;
759 }
760
761 let limit_clause = if let Some(lim) = opts.limit {
762 relation_params.push(Box::new(lim as i64));
763 format!(" LIMIT ?{}", param_idx)
764 } else {
765 String::new()
766 };
767
768 let cte_sql = format!(
769 "WITH RECURSIVE traversal(node_id, edge_id, depth, path, total_weight) AS (\
770 SELECT ?2, NULL, 0, ?2, 0.0 \
771 UNION ALL \
772 SELECT {next_node}, e.id, t.depth + 1, \
773 t.path || ',' || {next_node}, \
774 t.total_weight + e.weight \
775 FROM graph_edges e \
776 JOIN traversal t ON {join_condition} \
777 WHERE e.namespace = ?1 \
778 AND e.deleted_at IS NULL \
779 AND t.depth < ?3 \
780 AND (',' || t.path || ',') NOT LIKE '%,' || {next_node} || ',%'{rel_cond}{wt_cond} \
781 ) \
782 SELECT node_id, edge_id, depth, path, total_weight \
783 FROM traversal WHERE depth > 0 \
784 ORDER BY depth{limit}",
785 next_node = next_node,
786 join_condition = join_condition,
787 rel_cond = relation_cond,
788 wt_cond = weight_cond,
789 limit = limit_clause,
790 );
791
792 let mut stmt = conn.prepare(&cte_sql)?;
793
794 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
795 all_params.push(Box::new(namespace.clone()));
796 all_params.push(Box::new(root_str.clone()));
797 all_params.push(Box::new(opts.max_depth as i64));
798 all_params.extend(relation_params);
799
800 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
801 all_params.iter().map(|p| p.as_ref()).collect();
802
803 let rows = stmt.query_map(param_refs.as_slice(), |row| {
804 let node_str: String = row.get(0)?;
805 let edge_str: Option<String> = row.get(1)?;
806 let depth: i64 = row.get(2)?;
807 let _path: String = row.get(3)?;
808 let total_weight: f64 = row.get(4)?;
809 Ok((node_str, edge_str, depth, total_weight))
810 })?;
811
812 let mut nodes = Vec::new();
813 let mut max_weight = 0.0f64;
814 let mut seen: std::collections::HashSet<Uuid> =
819 std::collections::HashSet::new();
820
821 if include_roots {
822 seen.insert(*root_id);
823 nodes.push(PathNode {
824 node_id: *root_id,
825 via_edge: None,
826 depth: 0,
827 name: None,
828 kind: None,
829 });
830 }
831
832 for row in rows {
833 let (node_str, edge_str, depth, total_weight) = row?;
834 let node_id = parse_uuid(&node_str)?;
835 if !seen.insert(node_id) {
837 continue;
838 }
839 let via_edge = edge_str.map(|s| parse_uuid(&s)).transpose()?;
840 nodes.push(PathNode {
841 node_id,
842 via_edge,
843 depth: depth as usize,
844 name: None,
845 kind: None,
846 });
847 if total_weight > max_weight {
848 max_weight = total_weight;
849 }
850 }
851
852 if nodes.len() > if include_roots { 1 } else { 0 } || include_roots {
853 all_paths.push(GraphPath {
854 root_id: *root_id,
855 nodes,
856 total_weight: max_weight,
857 });
858 }
859 }
860
861 Ok(all_paths)
862 })
863 .await
864 }
865}
866
/// Schema for the graph edge store, executed as one batch by `ensure_graph_schema`.
///
/// Every statement uses `IF NOT EXISTS`, so running it again is harmless.
/// - Rows are keyed by `(namespace, id)`.
/// - The unique index on `(namespace, source_id, target_id, relation)` allows at
///   most one edge per stored endpoint pair and relation in a namespace.
/// - Timestamp columns hold microseconds since the Unix epoch.
/// - A non-NULL `deleted_at` marks a soft-deleted edge.
/// - The remaining indexes support lookups by endpoint, relation, and
///   `target_backend`.
const GRAPH_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS graph_edges (\
        namespace TEXT NOT NULL,\
        id TEXT NOT NULL,\
        source_id TEXT NOT NULL,\
        target_id TEXT NOT NULL,\
        relation TEXT NOT NULL,\
        weight REAL NOT NULL DEFAULT 1.0,\
        created_at INTEGER NOT NULL,\
        updated_at INTEGER NOT NULL,\
        deleted_at INTEGER,\
        metadata TEXT,\
        target_backend TEXT,\
        PRIMARY KEY (namespace, id)\
    );\
    CREATE UNIQUE INDEX IF NOT EXISTS idx_graph_edges_unique_triple ON graph_edges(namespace, source_id, target_id, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_source ON graph_edges(namespace, source_id);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_target ON graph_edges(namespace, target_id);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_relation ON graph_edges(namespace, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_src_rel ON graph_edges(namespace, source_id, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_tgt_rel ON graph_edges(namespace, target_id, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_target_backend ON graph_edges(target_backend) WHERE target_backend IS NOT NULL;\
";
894
895pub(crate) fn ensure_graph_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
896 conn.execute_batch(GRAPH_DDL)
897}
898
899#[cfg(test)]
900#[path = "graph_tests.rs"]
901mod tests;