1use std::sync::Arc;
12
13use async_trait::async_trait;
14use chrono::{DateTime, TimeZone, Utc};
15use uuid::Uuid;
16
17use khive_storage::error::StorageError;
18use khive_storage::types::{
19 BatchWriteSummary, DeleteMode, Edge, EdgeFilter, EdgeSortField, GraphPath, NeighborHit,
20 NeighborQuery, Page, PageRequest, PathNode, SortDirection, SortOrder, TraversalRequest,
21};
22use khive_storage::GraphStore;
23use khive_storage::LinkId;
24use khive_storage::StorageCapability;
25use khive_types::EdgeRelation;
26
27use crate::error::SqliteError;
28use crate::pool::ConnectionPool;
29
30fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
32 StorageError::driver(StorageCapability::Graph, op, e)
33}
34
35fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
36 StorageError::driver(StorageCapability::Graph, op, e)
37}
38
/// SQLite-backed graph store whose operations are all scoped to a single namespace.
pub struct SqlGraphStore {
    /// Shared connection pool; also the source of the database path and busy timeout.
    pool: Arc<ConnectionPool>,
    /// When `true`, each operation opens its own standalone connection instead of
    /// borrowing one from the pool.
    is_file_backed: bool,
    /// Namespace applied to every read and required on every written edge.
    namespace: String,
}
45
46impl SqlGraphStore {
47 pub fn new_scoped(
49 pool: Arc<ConnectionPool>,
50 is_file_backed: bool,
51 namespace: impl Into<String>,
52 ) -> Self {
53 Self {
54 pool,
55 is_file_backed,
56 namespace: namespace.into(),
57 }
58 }
59
60 fn open_standalone_writer(&self) -> Result<rusqlite::Connection, StorageError> {
61 let config = self.pool.config();
62 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
63 operation: "graph_writer".into(),
64 message: "in-memory databases do not support standalone connections".into(),
65 })?;
66
67 let conn = rusqlite::Connection::open_with_flags(
68 path,
69 rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE
70 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
71 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
72 )
73 .map_err(|e| map_err(e, "open_graph_writer"))?;
74
75 conn.busy_timeout(config.busy_timeout)
76 .map_err(|e| map_err(e, "open_graph_writer"))?;
77 conn.pragma_update(None, "foreign_keys", "ON")
78 .map_err(|e| map_err(e, "open_graph_writer"))?;
79 conn.pragma_update(None, "synchronous", "NORMAL")
80 .map_err(|e| map_err(e, "open_graph_writer"))?;
81
82 Ok(conn)
83 }
84
85 fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
86 let config = self.pool.config();
87 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
88 operation: "graph_reader".into(),
89 message: "in-memory databases do not support standalone connections".into(),
90 })?;
91
92 let conn = rusqlite::Connection::open_with_flags(
93 path,
94 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
95 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
96 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
97 )
98 .map_err(|e| map_err(e, "open_graph_reader"))?;
99
100 conn.busy_timeout(config.busy_timeout)
101 .map_err(|e| map_err(e, "open_graph_reader"))?;
102 conn.pragma_update(None, "foreign_keys", "ON")
103 .map_err(|e| map_err(e, "open_graph_reader"))?;
104 conn.pragma_update(None, "synchronous", "NORMAL")
105 .map_err(|e| map_err(e, "open_graph_reader"))?;
106
107 Ok(conn)
108 }
109
110 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
111 where
112 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
113 R: Send + 'static,
114 {
115 if self.is_file_backed {
116 let conn = self.open_standalone_writer()?;
117 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
118 .await
119 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
120 } else {
121 let pool = Arc::clone(&self.pool);
122 tokio::task::spawn_blocking(move || {
123 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
124 f(guard.conn()).map_err(|e| map_err(e, op))
125 })
126 .await
127 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
128 }
129 }
130
131 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
132 where
133 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
134 R: Send + 'static,
135 {
136 if self.is_file_backed {
137 let conn = self.open_standalone_reader()?;
138 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
139 .await
140 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
141 } else {
142 let pool = Arc::clone(&self.pool);
143 tokio::task::spawn_blocking(move || {
144 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
145 f(guard.conn()).map_err(|e| map_err(e, op))
146 })
147 .await
148 .map_err(|e| StorageError::driver(StorageCapability::Graph, op, e))?
149 }
150 }
151}
152
153fn read_edge(row: &rusqlite::Row<'_>) -> Result<Edge, rusqlite::Error> {
158 let namespace: String = row.get(0)?;
159 let id_str: String = row.get(1)?;
160 let source_str: String = row.get(2)?;
161 let target_str: String = row.get(3)?;
162 let relation_str: String = row.get(4)?;
163 let weight: f64 = row.get(5)?;
164 let created_micros: i64 = row.get(6)?;
165 let updated_micros: i64 = row.get(7)?;
166 let deleted_micros: Option<i64> = row.get(8)?;
167 let metadata_str: Option<String> = row.get(9)?;
168 let target_backend: Option<String> = row.get(10)?;
169
170 let id = parse_uuid(&id_str)?;
171 let source_id = parse_uuid(&source_str)?;
172 let target_id = parse_uuid(&target_str)?;
173 let created_at = micros_to_datetime(created_micros);
174 let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
175 rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(e))
176 })?;
177 let metadata = match metadata_str {
178 Some(s) => {
179 let v = serde_json::from_str(&s).map_err(|e| {
180 rusqlite::Error::FromSqlConversionFailure(
181 9,
182 rusqlite::types::Type::Text,
183 Box::new(e),
184 )
185 })?;
186 Some(v)
187 }
188 None => None,
189 };
190
191 Ok(Edge {
192 id: id.into(),
193 namespace,
194 source_id,
195 target_id,
196 relation,
197 weight,
198 created_at,
199 updated_at: micros_to_datetime(updated_micros),
200 deleted_at: deleted_micros.map(micros_to_datetime),
201 metadata,
202 target_backend,
203 })
204}
205
206fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
207 Uuid::parse_str(s).map_err(|e| {
208 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
209 })
210}
211
212fn micros_to_datetime(micros: i64) -> DateTime<Utc> {
213 Utc.timestamp_micros(micros)
214 .single()
215 .unwrap_or_else(Utc::now)
216}
217
218fn build_edge_filter_sql(
219 namespace: &str,
220 filter: &EdgeFilter,
221) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
222 let mut conditions: Vec<String> = vec![
223 "namespace = ?1".to_string(),
224 "deleted_at IS NULL".to_string(),
225 ];
226 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
227
228 if !filter.ids.is_empty() {
229 let placeholders: Vec<String> = filter
230 .ids
231 .iter()
232 .map(|id| {
233 params.push(Box::new(id.to_string()));
234 format!("?{}", params.len())
235 })
236 .collect();
237 conditions.push(format!("id IN ({})", placeholders.join(",")));
238 }
239
240 if !filter.source_ids.is_empty() {
241 let placeholders: Vec<String> = filter
242 .source_ids
243 .iter()
244 .map(|id| {
245 params.push(Box::new(id.to_string()));
246 format!("?{}", params.len())
247 })
248 .collect();
249 conditions.push(format!("source_id IN ({})", placeholders.join(",")));
250 }
251
252 if !filter.target_ids.is_empty() {
253 let placeholders: Vec<String> = filter
254 .target_ids
255 .iter()
256 .map(|id| {
257 params.push(Box::new(id.to_string()));
258 format!("?{}", params.len())
259 })
260 .collect();
261 conditions.push(format!("target_id IN ({})", placeholders.join(",")));
262 }
263
264 if !filter.relations.is_empty() {
265 let placeholders: Vec<String> = filter
266 .relations
267 .iter()
268 .map(|r| {
269 params.push(Box::new(r.to_string()));
270 format!("?{}", params.len())
271 })
272 .collect();
273 conditions.push(format!("relation IN ({})", placeholders.join(",")));
274 }
275
276 if let Some(min_w) = filter.min_weight {
277 params.push(Box::new(min_w));
278 conditions.push(format!("weight >= ?{}", params.len()));
279 }
280
281 if let Some(max_w) = filter.max_weight {
282 params.push(Box::new(max_w));
283 conditions.push(format!("weight <= ?{}", params.len()));
284 }
285
286 if let Some(ref time_range) = filter.created_at {
287 if let Some(start) = time_range.start {
288 params.push(Box::new(start.timestamp_micros()));
289 conditions.push(format!("created_at >= ?{}", params.len()));
290 }
291 if let Some(end) = time_range.end {
292 params.push(Box::new(end.timestamp_micros()));
293 conditions.push(format!("created_at < ?{}", params.len()));
294 }
295 }
296
297 let clause = format!(" WHERE {}", conditions.join(" AND "));
298 (clause, params)
299}
300
301fn edge_sort_col(field: &EdgeSortField) -> &'static str {
302 match field {
303 EdgeSortField::CreatedAt => "created_at",
304 EdgeSortField::Weight => "weight",
305 EdgeSortField::Relation => "relation",
306 }
307}
308
309#[async_trait]
314impl GraphStore for SqlGraphStore {
315 async fn upsert_edge(&self, edge: Edge) -> Result<(), StorageError> {
316 let namespace = self.namespace.clone();
317 if edge.namespace != namespace {
318 return Err(StorageError::InvalidInput {
319 capability: StorageCapability::Graph,
320 operation: "upsert_edge".into(),
321 message: format!(
322 "edge namespace {:?} does not match store namespace {:?}",
323 edge.namespace, namespace
324 ),
325 });
326 }
327 let id_str = Uuid::from(edge.id).to_string();
328 let src_str = edge.source_id.to_string();
329 let tgt_str = edge.target_id.to_string();
330 let relation_str = edge.relation.to_string();
331 let metadata_str = edge
332 .metadata
333 .as_ref()
334 .map(serde_json::to_string)
335 .transpose()
336 .map_err(|e| StorageError::driver(StorageCapability::Graph, "upsert_edge", e))?;
337 self.with_writer("upsert_edge", move |conn| {
338 conn.execute(
339 "INSERT INTO graph_edges \
340 (namespace, id, source_id, target_id, relation, weight, \
341 created_at, updated_at, deleted_at, metadata, target_backend) \
342 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11) \
343 ON CONFLICT(namespace, id) DO UPDATE SET \
344 source_id = excluded.source_id, \
345 target_id = excluded.target_id, \
346 relation = excluded.relation, \
347 weight = excluded.weight, \
348 updated_at = excluded.updated_at, \
349 deleted_at = NULL, \
350 metadata = excluded.metadata, \
351 target_backend = excluded.target_backend \
352 ON CONFLICT(namespace, source_id, target_id, relation) DO UPDATE SET \
353 weight = excluded.weight, \
354 updated_at = excluded.updated_at, \
355 deleted_at = NULL, \
356 metadata = excluded.metadata, \
357 target_backend = excluded.target_backend",
358 rusqlite::params![
359 namespace,
360 id_str,
361 src_str,
362 tgt_str,
363 relation_str,
364 edge.weight,
365 edge.created_at.timestamp_micros(),
366 edge.updated_at.timestamp_micros(),
367 edge.deleted_at.map(|t| t.timestamp_micros()),
368 metadata_str,
369 edge.target_backend,
370 ],
371 )?;
372 Ok(())
373 })
374 .await
375 }
376
377 async fn upsert_edges(&self, edges: Vec<Edge>) -> Result<BatchWriteSummary, StorageError> {
378 let attempted = edges.len() as u64;
379 let namespace = self.namespace.clone();
380
381 for edge in &edges {
383 if edge.namespace != namespace {
384 return Err(StorageError::InvalidInput {
385 capability: StorageCapability::Graph,
386 operation: "upsert_edges".into(),
387 message: format!(
388 "edge namespace {:?} does not match store namespace {:?}",
389 edge.namespace, namespace
390 ),
391 });
392 }
393 }
394
395 self.with_writer("upsert_edges", move |conn| {
396 conn.execute_batch("BEGIN IMMEDIATE")?;
397 let mut affected = 0u64;
398
399 for edge in &edges {
400 let id_str = Uuid::from(edge.id).to_string();
401 let src_str = edge.source_id.to_string();
402 let tgt_str = edge.target_id.to_string();
403 let relation_str = edge.relation.to_string();
404 let metadata_str = edge
405 .metadata
406 .as_ref()
407 .map(serde_json::to_string)
408 .transpose()
409 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
410 if let Err(e) = conn.execute(
411 "INSERT INTO graph_edges \
412 (namespace, id, source_id, target_id, relation, weight, \
413 created_at, updated_at, deleted_at, metadata, target_backend) \
414 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11) \
415 ON CONFLICT(namespace, id) DO UPDATE SET \
416 source_id = excluded.source_id, \
417 target_id = excluded.target_id, \
418 relation = excluded.relation, \
419 weight = excluded.weight, \
420 updated_at = excluded.updated_at, \
421 deleted_at = NULL, \
422 metadata = excluded.metadata, \
423 target_backend = excluded.target_backend \
424 ON CONFLICT(namespace, source_id, target_id, relation) DO UPDATE SET \
425 weight = excluded.weight, \
426 updated_at = excluded.updated_at, \
427 deleted_at = NULL, \
428 metadata = excluded.metadata, \
429 target_backend = excluded.target_backend",
430 rusqlite::params![
431 &namespace,
432 id_str,
433 src_str,
434 tgt_str,
435 relation_str,
436 edge.weight,
437 edge.created_at.timestamp_micros(),
438 edge.updated_at.timestamp_micros(),
439 edge.deleted_at.map(|t| t.timestamp_micros()),
440 metadata_str,
441 edge.target_backend.as_deref(),
442 ],
443 ) {
444 let _ = conn.execute_batch("ROLLBACK");
445 return Err(e);
446 }
447 affected += 1;
448 }
449
450 if let Err(e) = conn.execute_batch("COMMIT") {
451 let _ = conn.execute_batch("ROLLBACK");
452 return Err(e);
453 }
454 Ok(BatchWriteSummary {
455 attempted,
456 affected,
457 failed: 0,
458 first_error: String::new(),
459 })
460 })
461 .await
462 }
463
464 async fn get_edge(&self, id: LinkId) -> Result<Option<Edge>, StorageError> {
465 let namespace = self.namespace.clone();
466 let id_str = Uuid::from(id).to_string();
467
468 self.with_reader("get_edge", move |conn| {
469 let mut stmt = conn.prepare(
470 "SELECT namespace, id, source_id, target_id, relation, weight, \
471 created_at, updated_at, deleted_at, metadata, target_backend \
472 FROM graph_edges WHERE namespace = ?1 AND id = ?2 AND deleted_at IS NULL",
473 )?;
474 let mut rows = stmt.query(rusqlite::params![namespace, id_str])?;
475 match rows.next()? {
476 Some(row) => Ok(Some(read_edge(row)?)),
477 None => Ok(None),
478 }
479 })
480 .await
481 }
482
483 async fn delete_edge(&self, id: LinkId, mode: DeleteMode) -> Result<bool, StorageError> {
484 let namespace = self.namespace.clone();
485 let id_str = Uuid::from(id).to_string();
486
487 self.with_writer("delete_edge", move |conn| {
488 let affected = match mode {
489 DeleteMode::Soft => conn.execute(
490 "UPDATE graph_edges SET deleted_at = ?3, updated_at = ?3 \
491 WHERE namespace = ?1 AND id = ?2 AND deleted_at IS NULL",
492 rusqlite::params![namespace, id_str, chrono::Utc::now().timestamp_micros(),],
493 )?,
494 DeleteMode::Hard => conn.execute(
495 "DELETE FROM graph_edges WHERE namespace = ?1 AND id = ?2",
496 rusqlite::params![namespace, id_str],
497 )?,
498 };
499 Ok(affected > 0)
500 })
501 .await
502 }
503
504 async fn query_edges(
505 &self,
506 filter: EdgeFilter,
507 sort: Vec<SortOrder<EdgeSortField>>,
508 page: PageRequest,
509 ) -> Result<Page<Edge>, StorageError> {
510 let namespace = self.namespace.clone();
511 self.with_reader("query_edges", move |conn| {
512 let (where_clause, filter_params) = build_edge_filter_sql(&namespace, &filter);
513
514 let count_sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
515 let total: i64 = {
516 let mut stmt = conn.prepare(&count_sql)?;
517 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
518 filter_params.iter().map(|p| p.as_ref()).collect();
519 stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
520 };
521
522 let order_clause = if sort.is_empty() {
523 " ORDER BY created_at DESC".to_string()
524 } else {
525 let parts: Vec<String> = sort
526 .iter()
527 .map(|s| {
528 let dir = match s.direction {
529 SortDirection::Asc => "ASC",
530 SortDirection::Desc => "DESC",
531 };
532 format!("{} {}", edge_sort_col(&s.field), dir)
533 })
534 .collect();
535 format!(" ORDER BY {}", parts.join(", "))
536 };
537
538 let (_, data_filter_params) = build_edge_filter_sql(&namespace, &filter);
539 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = data_filter_params;
540 all_params.push(Box::new(page.limit as i64));
541 all_params.push(Box::new(page.offset as i64));
542
543 let limit_idx = all_params.len() - 1;
544 let offset_idx = all_params.len();
545
546 let data_sql = format!(
547 "SELECT namespace, id, source_id, target_id, relation, weight, \
548 created_at, updated_at, deleted_at, metadata, target_backend \
549 FROM graph_edges{}{} LIMIT ?{} OFFSET ?{}",
550 where_clause, order_clause, limit_idx, offset_idx,
551 );
552
553 let mut stmt = conn.prepare(&data_sql)?;
554 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
555 all_params.iter().map(|p| p.as_ref()).collect();
556 let rows = stmt.query_map(param_refs.as_slice(), read_edge)?;
557
558 let mut items = Vec::new();
559 for row in rows {
560 items.push(row?);
561 }
562
563 Ok(Page {
564 items,
565 total: Some(total as u64),
566 })
567 })
568 .await
569 }
570
571 async fn count_edges(&self, filter: EdgeFilter) -> Result<u64, StorageError> {
572 let namespace = self.namespace.clone();
573 self.with_reader("count_edges", move |conn| {
574 let (where_clause, params) = build_edge_filter_sql(&namespace, &filter);
575 let sql = format!("SELECT COUNT(*) FROM graph_edges{}", where_clause);
576 let mut stmt = conn.prepare(&sql)?;
577 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
578 params.iter().map(|p| p.as_ref()).collect();
579 let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
580 Ok(count as u64)
581 })
582 .await
583 }
584
585 async fn neighbors(
586 &self,
587 node_id: Uuid,
588 query: NeighborQuery,
589 ) -> Result<Vec<NeighborHit>, StorageError> {
590 use khive_storage::types::Direction;
591
592 let namespace = self.namespace.clone();
593 let node_str = node_id.to_string();
594
595 self.with_reader("neighbors", move |conn| {
596 let base_out = "SELECT target_id AS node_id, id AS edge_id, relation, weight \
597 FROM graph_edges \
598 WHERE namespace = ?1 AND source_id = ?2 AND deleted_at IS NULL";
599 let base_in = "SELECT source_id AS node_id, id AS edge_id, relation, weight \
600 FROM graph_edges \
601 WHERE namespace = ?1 AND target_id = ?2 AND deleted_at IS NULL";
602
603 let sql = match query.direction {
604 Direction::Out => base_out.to_string(),
605 Direction::In => base_in.to_string(),
606 Direction::Both => format!("{} UNION ALL {}", base_out, base_in),
607 };
608
609 let mut conditions: Vec<String> = Vec::new();
610 let mut extra_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
611 let mut param_idx = 3;
612
613 if let Some(ref rels) = query.relations {
614 if !rels.is_empty() {
615 let placeholders: Vec<String> = rels
616 .iter()
617 .map(|r| {
618 extra_params.push(Box::new(r.to_string()));
619 let p = format!("?{}", param_idx);
620 param_idx += 1;
621 p
622 })
623 .collect();
624 conditions.push(format!("relation IN ({})", placeholders.join(",")));
625 }
626 }
627
628 if let Some(min_w) = query.min_weight {
629 extra_params.push(Box::new(min_w));
630 conditions.push(format!("weight >= ?{}", param_idx));
631 param_idx += 1;
632 }
633
634 let where_extra = if conditions.is_empty() {
635 String::new()
636 } else {
637 format!(" WHERE {}", conditions.join(" AND "))
638 };
639
640 let limit_clause = if let Some(lim) = query.limit {
641 extra_params.push(Box::new(lim as i64));
642 format!(" LIMIT ?{}", param_idx)
643 } else {
644 String::new()
645 };
646
647 let full_sql = format!(
648 "SELECT node_id, edge_id, relation, weight FROM ({}){}{}",
649 sql, where_extra, limit_clause
650 );
651
652 let mut stmt = conn.prepare(&full_sql)?;
653
654 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
655 all_params.push(Box::new(namespace.clone()));
656 all_params.push(Box::new(node_str.clone()));
657 all_params.extend(extra_params);
658
659 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
660 all_params.iter().map(|p| p.as_ref()).collect();
661
662 let rows = stmt.query_map(param_refs.as_slice(), |row| {
663 let nid_str: String = row.get(0)?;
664 let eid_str: String = row.get(1)?;
665 let relation_str: String = row.get(2)?;
666 let weight: f64 = row.get(3)?;
667 Ok((nid_str, eid_str, relation_str, weight))
668 })?;
669
670 let mut hits = Vec::new();
671 for row in rows {
672 let (nid_str, eid_str, relation_str, weight) = row?;
673 let relation = relation_str.parse::<EdgeRelation>().map_err(|e| {
674 rusqlite::Error::FromSqlConversionFailure(
675 2,
676 rusqlite::types::Type::Text,
677 Box::new(e),
678 )
679 })?;
680 hits.push(NeighborHit {
681 node_id: parse_uuid(&nid_str)?,
682 edge_id: parse_uuid(&eid_str)?,
683 relation,
684 weight,
685 name: None,
686 kind: None,
687 });
688 }
689
690 Ok(hits)
691 })
692 .await
693 }
694
695 async fn traverse(&self, request: TraversalRequest) -> Result<Vec<GraphPath>, StorageError> {
696 use khive_storage::types::Direction;
697
698 if request.roots.is_empty() {
699 return Ok(Vec::new());
700 }
701
702 let roots = request.roots.clone();
703 let opts = request.options.clone();
704 let include_roots = request.include_roots;
705 let namespace = self.namespace.clone();
706
707 self.with_reader("traverse", move |conn| {
708 let mut all_paths: Vec<GraphPath> = Vec::new();
709
710 for root_id in &roots {
711 let root_str = root_id.to_string();
712
713 let (join_condition, next_node) = match opts.direction {
714 Direction::Out => ("e.source_id = t.node_id", "e.target_id"),
715 Direction::In => ("e.target_id = t.node_id", "e.source_id"),
716 Direction::Both => (
717 "(e.source_id = t.node_id OR e.target_id = t.node_id)",
718 "CASE WHEN e.source_id = t.node_id THEN e.target_id ELSE e.source_id END",
719 ),
720 };
721
722 let mut relation_cond = String::new();
723 let mut relation_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
724 let mut param_idx = 4;
725
726 if let Some(ref rels) = opts.relations {
727 if !rels.is_empty() {
728 let placeholders: Vec<String> = rels
729 .iter()
730 .map(|r| {
731 relation_params.push(Box::new(r.to_string()));
732 let p = format!("?{}", param_idx);
733 param_idx += 1;
734 p
735 })
736 .collect();
737 relation_cond =
738 format!(" AND e.relation IN ({})", placeholders.join(","));
739 }
740 }
741
742 let mut weight_cond = String::new();
743 if let Some(min_w) = opts.min_weight {
744 relation_params.push(Box::new(min_w));
745 weight_cond = format!(" AND e.weight >= ?{}", param_idx);
746 param_idx += 1;
747 }
748
749 let limit_clause = if let Some(lim) = opts.limit {
750 relation_params.push(Box::new(lim as i64));
751 format!(" LIMIT ?{}", param_idx)
752 } else {
753 String::new()
754 };
755
756 let cte_sql = format!(
757 "WITH RECURSIVE traversal(node_id, edge_id, depth, path, total_weight) AS (\
758 SELECT ?2, NULL, 0, ?2, 0.0 \
759 UNION ALL \
760 SELECT {next_node}, e.id, t.depth + 1, \
761 t.path || ',' || {next_node}, \
762 t.total_weight + e.weight \
763 FROM graph_edges e \
764 JOIN traversal t ON {join_condition} \
765 WHERE e.namespace = ?1 \
766 AND e.deleted_at IS NULL \
767 AND t.depth < ?3 \
768 AND (',' || t.path || ',') NOT LIKE '%,' || {next_node} || ',%'{rel_cond}{wt_cond} \
769 ) \
770 SELECT node_id, edge_id, depth, path, total_weight \
771 FROM traversal WHERE depth > 0 \
772 ORDER BY depth{limit}",
773 next_node = next_node,
774 join_condition = join_condition,
775 rel_cond = relation_cond,
776 wt_cond = weight_cond,
777 limit = limit_clause,
778 );
779
780 let mut stmt = conn.prepare(&cte_sql)?;
781
782 let mut all_params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
783 all_params.push(Box::new(namespace.clone()));
784 all_params.push(Box::new(root_str.clone()));
785 all_params.push(Box::new(opts.max_depth as i64));
786 all_params.extend(relation_params);
787
788 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
789 all_params.iter().map(|p| p.as_ref()).collect();
790
791 let rows = stmt.query_map(param_refs.as_slice(), |row| {
792 let node_str: String = row.get(0)?;
793 let edge_str: Option<String> = row.get(1)?;
794 let depth: i64 = row.get(2)?;
795 let _path: String = row.get(3)?;
796 let total_weight: f64 = row.get(4)?;
797 Ok((node_str, edge_str, depth, total_weight))
798 })?;
799
800 let mut nodes = Vec::new();
801 let mut max_weight = 0.0f64;
802
803 if include_roots {
804 nodes.push(PathNode {
805 node_id: *root_id,
806 via_edge: None,
807 depth: 0,
808 name: None,
809 kind: None,
810 });
811 }
812
813 for row in rows {
814 let (node_str, edge_str, depth, total_weight) = row?;
815 let node_id = parse_uuid(&node_str)?;
816 let via_edge = edge_str.map(|s| parse_uuid(&s)).transpose()?;
817 nodes.push(PathNode {
818 node_id,
819 via_edge,
820 depth: depth as usize,
821 name: None,
822 kind: None,
823 });
824 if total_weight > max_weight {
825 max_weight = total_weight;
826 }
827 }
828
829 if nodes.len() > if include_roots { 1 } else { 0 } || include_roots {
830 all_paths.push(GraphPath {
831 root_id: *root_id,
832 nodes,
833 total_weight: max_weight,
834 });
835 }
836 }
837
838 Ok(all_paths)
839 })
840 .await
841 }
842}
843
/// Schema for the `graph_edges` table and its indexes.
///
/// Every statement uses `IF NOT EXISTS`, so the batch can be executed repeatedly.
/// Rows are keyed by `(namespace, id)`, and a unique index additionally enforces one
/// row per `(namespace, source_id, target_id, relation)` triple.
const GRAPH_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS graph_edges (\
    namespace TEXT NOT NULL,\
    id TEXT NOT NULL,\
    source_id TEXT NOT NULL,\
    target_id TEXT NOT NULL,\
    relation TEXT NOT NULL,\
    weight REAL NOT NULL DEFAULT 1.0,\
    created_at INTEGER NOT NULL,\
    updated_at INTEGER NOT NULL,\
    deleted_at INTEGER,\
    metadata TEXT,\
    target_backend TEXT,\
    PRIMARY KEY (namespace, id)\
    );\
    CREATE UNIQUE INDEX IF NOT EXISTS idx_graph_edges_unique_triple ON graph_edges(namespace, source_id, target_id, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_source ON graph_edges(namespace, source_id);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_target ON graph_edges(namespace, target_id);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_relation ON graph_edges(namespace, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_src_rel ON graph_edges(namespace, source_id, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_ns_tgt_rel ON graph_edges(namespace, target_id, relation);\
    CREATE INDEX IF NOT EXISTS idx_graph_edges_target_backend ON graph_edges(target_backend) WHERE target_backend IS NOT NULL;\
";
871
872pub(crate) fn ensure_graph_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
873 conn.execute_batch(GRAPH_DDL)
874}
875
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;
    use khive_storage::types::{Direction, TraversalOptions};

    /// Builds a pool-backed store over a fresh in-memory database with the schema applied.
    fn setup_memory_store() -> SqlGraphStore {
        let pool = Arc::new(
            ConnectionPool::new(PoolConfig {
                path: None,
                ..PoolConfig::default()
            })
            .unwrap(),
        );

        // Scope the writer guard so it is released before the store is used.
        {
            let guard = pool.writer().unwrap();
            guard.conn().execute_batch(GRAPH_DDL).unwrap();
        }

        SqlGraphStore::new_scoped(pool, false, "default")
    }

    /// Builds a live edge in the "default" namespace with a fresh id, stamped at `at`.
    fn edge_at(
        at: DateTime<Utc>,
        source: Uuid,
        target: Uuid,
        relation: EdgeRelation,
        weight: f64,
    ) -> Edge {
        Edge {
            id: Uuid::new_v4().into(),
            namespace: "default".to_string(),
            source_id: source,
            target_id: target,
            relation,
            weight,
            created_at: at,
            updated_at: at,
            deleted_at: None,
            metadata: None,
            target_backend: None,
        }
    }

    /// Same as [`edge_at`], stamped with the current time.
    fn make_edge(source: Uuid, target: Uuid, relation: EdgeRelation, weight: f64) -> Edge {
        edge_at(Utc::now(), source, target, relation, weight)
    }

    #[tokio::test]
    async fn test_upsert_and_get_edge() {
        let store = setup_memory_store();

        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let edge = make_edge(src, tgt, EdgeRelation::Extends, 0.8);
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();

        let fetched = store.get_edge(edge_id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, edge_id);
        assert_eq!(fetched.namespace, "default");
        assert_eq!(fetched.source_id, src);
        assert_eq!(fetched.target_id, tgt);
        assert_eq!(fetched.relation, EdgeRelation::Extends);
        assert!((fetched.weight - 0.8).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_delete_edge() {
        let store = setup_memory_store();

        let edge = make_edge(Uuid::new_v4(), Uuid::new_v4(), EdgeRelation::Contains, 1.0);
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();
        assert!(store.get_edge(edge_id).await.unwrap().is_some());

        let first_delete = store.delete_edge(edge_id, DeleteMode::Hard).await.unwrap();
        assert!(first_delete);

        assert!(store.get_edge(edge_id).await.unwrap().is_none());

        // Deleting a row that is already gone reports that nothing changed.
        let second_delete = store.delete_edge(edge_id, DeleteMode::Hard).await.unwrap();
        assert!(!second_delete);
    }

    #[tokio::test]
    async fn test_count_edges() {
        let store = setup_memory_store();

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 0);

        for _ in 0..5 {
            let edge = make_edge(
                Uuid::new_v4(),
                Uuid::new_v4(),
                EdgeRelation::DependsOn,
                1.0,
            );
            store.upsert_edge(edge).await.unwrap();
        }

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_neighbors_outbound() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();

        // Two edges leave `a`; one edge enters it from `d`.
        for (source, target, relation, weight) in [
            (a, b, EdgeRelation::Extends, 1.0),
            (a, c, EdgeRelation::DependsOn, 0.7),
            (d, a, EdgeRelation::Extends, 0.5),
        ] {
            store
                .upsert_edge(make_edge(source, target, relation, weight))
                .await
                .unwrap();
        }

        let query = NeighborQuery {
            direction: Direction::Out,
            relations: None,
            limit: None,
            min_weight: None,
        };

        let hits = store.neighbors(a, query).await.unwrap();
        assert_eq!(hits.len(), 2);

        let neighbor_ids: Vec<Uuid> = hits.iter().map(|h| h.node_id).collect();
        assert!(neighbor_ids.contains(&b));
        assert!(neighbor_ids.contains(&c));
        assert!(!neighbor_ids.contains(&d));
    }

    #[tokio::test]
    async fn test_traverse_depth_2() {
        let store = setup_memory_store();

        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();

        // Chain a -> b -> c -> d; a depth-2 walk from `a` must stop at `c`.
        for (source, target, weight) in [(a, b, 1.0), (b, c, 2.0), (c, d, 3.0)] {
            store
                .upsert_edge(make_edge(source, target, EdgeRelation::Extends, weight))
                .await
                .unwrap();
        }

        let request = TraversalRequest {
            roots: vec![a],
            options: TraversalOptions::new(2).with_direction(Direction::Out),
            include_roots: true,
        };

        let paths = store.traverse(request).await.unwrap();
        assert_eq!(paths.len(), 1);

        let visited: Vec<Uuid> = paths[0].nodes.iter().map(|n| n.node_id).collect();
        assert!(visited.contains(&a));
        assert!(visited.contains(&b));
        assert!(visited.contains(&c));
        assert!(!visited.contains(&d));
    }

    #[tokio::test]
    async fn test_metadata_roundtrip() {
        let store = setup_memory_store();

        let meta = serde_json::json!({"note": "important link", "confidence": 0.95});
        let mut edge = make_edge(
            Uuid::new_v4(),
            Uuid::new_v4(),
            EdgeRelation::Implements,
            0.9,
        );
        edge.metadata = Some(meta.clone());
        let edge_id = edge.id;

        store.upsert_edge(edge).await.unwrap();

        let fetched = store.get_edge(edge_id).await.unwrap().unwrap();
        assert_eq!(
            fetched.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via get_edge"
        );

        let page = store
            .query_edges(EdgeFilter::default(), vec![], PageRequest::default())
            .await
            .unwrap();
        let from_query = page
            .items
            .iter()
            .find(|e| e.id == edge_id)
            .expect("edge must appear in query_edges result");
        assert_eq!(
            from_query.metadata.as_ref(),
            Some(&meta),
            "metadata must survive a write/read roundtrip via query_edges"
        );
    }

    #[tokio::test]
    async fn test_upsert_edges_batch() {
        let store = setup_memory_store();

        let mut edges: Vec<Edge> = Vec::with_capacity(10);
        for i in 0..10 {
            edges.push(make_edge(
                Uuid::new_v4(),
                Uuid::new_v4(),
                EdgeRelation::Implements,
                i as f64,
            ));
        }

        let summary = store.upsert_edges(edges).await.unwrap();
        assert_eq!(summary.attempted, 10);
        assert_eq!(summary.affected, 10);
        assert_eq!(summary.failed, 0);

        assert_eq!(store.count_edges(EdgeFilter::default()).await.unwrap(), 10);
    }

    #[tokio::test]
    async fn graph_duplicate_edges_ignored() {
        let store = setup_memory_store();

        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let now = Utc::now();

        // Same natural key, different ids and weights.
        let edge1 = edge_at(now, src, tgt, EdgeRelation::Extends, 1.0);
        let edge2 = edge_at(now, src, tgt, EdgeRelation::Extends, 0.5);

        store.upsert_edge(edge1).await.unwrap();
        store.upsert_edge(edge2).await.unwrap();

        assert_eq!(
            store.count_edges(EdgeFilter::default()).await.unwrap(),
            1,
            "duplicate (source, target, relation) triple must be ignored; only one edge must exist"
        );
    }

    #[tokio::test]
    async fn graph_duplicate_edges_refresh_existing_row() {
        let store = setup_memory_store();

        let src = Uuid::new_v4();
        let tgt = Uuid::new_v4();
        let now = Utc::now();

        let edge1 = edge_at(now, src, tgt, EdgeRelation::Extends, 1.0);
        let edge2 = edge_at(now, src, tgt, EdgeRelation::Extends, 0.5);

        store.upsert_edge(edge1).await.unwrap();
        store.upsert_edge(edge2).await.unwrap();

        let edges = store
            .query_edges(EdgeFilter::default(), vec![], PageRequest::default())
            .await
            .unwrap();
        assert_eq!(
            edges.items.len(),
            1,
            "duplicate natural key must collapse to one row"
        );
        assert!(
            (edges.items[0].weight - 0.5).abs() < 0.001,
            "F053: natural-key conflict must DO UPDATE (weight=0.5 from second upsert); \
             current DO NOTHING keeps stale weight={}",
            edges.items[0].weight
        );
    }
}