1use std::str::FromStr;
10
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use storage::SqlitePool;
14use thiserror::Error;
15use uuid::Uuid;
16
/// Errors returned by [`EpisodicGraph`] operations.
#[derive(Debug, Error)]
pub enum GraphError {
    /// Error surfaced by the `storage` crate's SQLite pool (e.g. from `with_conn`).
    #[error("SQLite error: {0}")]
    Sqlite(#[from] storage::sqlite::SqliteError),
    /// Error returned directly by `rusqlite`.
    #[error("rusqlite error: {0}")]
    Rusqlite(#[from] rusqlite::Error),
    /// A node body could not be serialized to, or parsed from, JSON text.
    #[error("invalid node body json: {0}")]
    Body(#[from] serde_json::Error),
    /// A stored timestamp matched neither RFC 3339 nor `%Y-%m-%d %H:%M:%S`; carries the raw text.
    #[error("invalid timestamp: {0}")]
    Timestamp(String),
}
29
30#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
35pub struct NodeKind(pub String);
36
37impl NodeKind {
38 pub fn new(s: impl Into<String>) -> Self {
39 Self(s.into())
40 }
41
42 pub fn as_str(&self) -> &str {
43 &self.0
44 }
45}
46
47impl From<&str> for NodeKind {
48 fn from(s: &str) -> Self {
49 Self(s.to_string())
50 }
51}
52
53#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
57pub struct EdgeKind(pub String);
58
59impl EdgeKind {
60 pub fn new(s: impl Into<String>) -> Self {
61 Self(s.into())
62 }
63
64 pub fn as_str(&self) -> &str {
65 &self.0
66 }
67}
68
69impl From<&str> for EdgeKind {
70 fn from(s: &str) -> Self {
71 Self(s.to_string())
72 }
73}
74
/// A single node of the episodic graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    /// Node id; [`Node::new`] assigns a random UUID v4 string.
    pub id: String,
    /// Optional session id associated with the node.
    pub session_id: Option<String>,
    /// Namespace of the node. Namespace-scoped text search also matches `/`-separated
    /// child namespaces (`ns/...`).
    pub namespace: String,
    /// Kind label of the node.
    pub kind: NodeKind,
    /// Arbitrary JSON payload, stored as serialized JSON text in the `body_json` column.
    pub body: serde_json::Value,
    /// Optional id linking the node to an external vector entry (presumably an embedding
    /// store — confirm with callers).
    pub vector_id: Option<String>,
    /// Node weight; [`Node::new`] sets `1.0`.
    pub weight: f32,
    /// Creation time in UTC.
    pub created_at: DateTime<Utc>,
}
89
90impl Node {
91 pub fn new(
94 kind: NodeKind,
95 body: serde_json::Value,
96 namespace: impl Into<String>,
97 session_id: Option<String>,
98 ) -> Self {
99 Self {
100 id: Uuid::new_v4().to_string(),
101 session_id,
102 namespace: namespace.into(),
103 kind,
104 body,
105 vector_id: None,
106 weight: 1.0,
107 created_at: Utc::now(),
108 }
109 }
110}
111
/// A directed, typed edge between two nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edge {
    /// Id of the source node (stored in `edges.src_id`).
    pub src: String,
    /// Id of the destination node (stored in `edges.dst_id`).
    pub dst: String,
    /// Kind label of the edge.
    pub kind: EdgeKind,
    /// Edge weight; [`Edge::new`] sets `1.0`.
    pub weight: f32,
    /// Creation time in UTC.
    pub created_at: DateTime<Utc>,
}
121
122impl Edge {
123 pub fn new(src: impl Into<String>, dst: impl Into<String>, kind: EdgeKind) -> Self {
124 Self {
125 src: src.into(),
126 dst: dst.into(),
127 kind,
128 weight: 1.0,
129 created_at: Utc::now(),
130 }
131 }
132}
133
/// One full-text search result returned by [`EpisodicGraph::search_text`].
#[derive(Debug, Clone)]
pub struct GraphHit {
    /// Id of the matching node.
    pub id: String,
    /// The indexed text that matched (the `text` column of `nodes_fts`).
    pub text: String,
    /// FTS5 `rank` of the match. `SqliteGraph::search_text` sorts hits ascending by this
    /// value; under FTS5's default bm25 ranking lower means a better match (confirm if a
    /// custom rank function is configured on `nodes_fts`).
    pub rank: f64,
    /// Namespace of the matching node.
    pub namespace: String,
    /// Kind of the matching node.
    pub node_kind: NodeKind,
    /// Creation time of the matching node.
    pub created_at: DateTime<Utc>,
    /// Weight of the matching node.
    pub weight: f32,
    /// Vector id stored on the matching node, if any.
    pub vector_id: Option<String>,
}
151
/// Storage interface for the episodic graph: typed nodes carrying JSON bodies, connected
/// by directed, typed edges.
///
/// Implementors must be `Send + Sync`.
pub trait EpisodicGraph: Send + Sync {
    /// Inserts `node`.
    fn add_node(&self, node: &Node) -> Result<(), GraphError>;
    /// Inserts the directed edge `edge.src -> edge.dst`.
    fn add_edge(&self, edge: &Edge) -> Result<(), GraphError>;
    /// Full-text search over node text, returning at most `limit` hits ordered by match
    /// rank (for `SqliteGraph`: FTS5 `rank`, ascending).
    ///
    /// When `namespace` is `Some(ns)`, only nodes in `ns` itself or in a child namespace
    /// `ns/...` are returned. A query that is empty after sanitization yields no hits.
    fn search_text(
        &self,
        query: &str,
        limit: usize,
        namespace: Option<&str>,
    ) -> Result<Vec<GraphHit>, GraphError>;
    /// Loads the node with `id`; `Ok(None)` if it does not exist.
    fn get_node(&self, id: &str) -> Result<Option<Node>, GraphError>;
    /// Outgoing edges of `id`, each paired with its destination node.
    ///
    /// `SqliteGraph` orders them by edge creation time.
    fn neighbors(&self, id: &str) -> Result<Vec<(Edge, Node)>, GraphError>;
    /// Incoming edges of `id`, each paired with its source node (the reverse view of
    /// [`neighbors`](Self::neighbors)).
    fn incoming(&self, id: &str) -> Result<Vec<(Edge, Node)>, GraphError>;
    /// Fewest-hop directed path from `src` to `dst` using at most `max_hops` edges,
    /// returned as node ids including both endpoints; `Ok(None)` if there is none.
    fn path(&self, src: &str, dst: &str, max_hops: u32) -> Result<Option<Vec<String>>, GraphError>;
    /// Returns every node; no ordering is specified.
    fn list_all_nodes(&self) -> Result<Vec<Node>, GraphError>;
    /// Overwrites the weight of node `id`.
    ///
    /// `SqliteGraph` does not report whether a node with that id existed.
    fn update_weight(&self, id: &str, weight: f32) -> Result<(), GraphError>;
    /// Deletes node `id`, returning `true` if a node was removed and `false` if none existed.
    fn delete_node(&self, id: &str) -> Result<bool, GraphError>;
}
193
/// [`EpisodicGraph`] implementation backed by SQLite.
///
/// Reads and writes the `nodes`, `edges` and `nodes_fts` tables. Their schema is not
/// created in this file (presumably during `storage` pool setup — confirm).
pub struct SqliteGraph {
    /// Connection pool that every operation runs on.
    db: SqlitePool,
}
198
199impl SqliteGraph {
200 pub fn new(db: SqlitePool) -> Self {
201 Self { db }
202 }
203
204 pub fn pool(&self) -> &SqlitePool {
205 &self.db
206 }
207}
208
209fn parse_ts(s: &str) -> Result<DateTime<Utc>, GraphError> {
210 if let Ok(dt) = DateTime::parse_from_rfc3339(s) {
213 return Ok(dt.with_timezone(&Utc));
214 }
215 if let Ok(naive) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
216 return Ok(naive.and_utc());
217 }
218 Err(GraphError::Timestamp(s.to_string()))
219}
220
221fn row_to_node(row: &rusqlite::Row<'_>) -> rusqlite::Result<NodeRow> {
222 Ok(NodeRow {
223 id: row.get(0)?,
224 session_id: row.get(1)?,
225 namespace: row.get(2)?,
226 node_kind: row.get(3)?,
227 body_json: row.get(4)?,
228 vector_id: row.get(5)?,
229 weight: row.get::<_, f64>(6)? as f32,
230 created_at: row.get(7)?,
231 })
232}
233
/// Raw `nodes` row as read from SQLite, before JSON and timestamp parsing.
struct NodeRow {
    id: String,
    session_id: Option<String>,
    namespace: String,
    /// Kind label; wrapped into [`NodeKind`] by `into_node`.
    node_kind: String,
    /// Serialized JSON body; parsed by `into_node`.
    body_json: String,
    vector_id: Option<String>,
    /// Weight, already narrowed from the `f64` read out of SQLite.
    weight: f32,
    /// Timestamp text; parsed by [`parse_ts`].
    created_at: String,
}
244
/// Raw row of an FTS search (`nodes_fts` joined with `nodes`), before timestamp parsing.
struct HitRow {
    id: String,
    /// Matched text from `nodes_fts`.
    text: String,
    /// FTS5 `rank` of the match.
    rank: f64,
    namespace: String,
    node_kind: String,
    /// Timestamp text; parsed by [`parse_ts`].
    created_at: String,
    /// Weight as read from SQLite; narrowed to `f32` when building a [`GraphHit`].
    weight: f64,
    vector_id: Option<String>,
}
257
258impl NodeRow {
259 fn into_node(self) -> Result<Node, GraphError> {
260 Ok(Node {
261 id: self.id,
262 session_id: self.session_id,
263 namespace: self.namespace,
264 kind: NodeKind(self.node_kind),
265 body: serde_json::Value::from_str(&self.body_json)?,
266 vector_id: self.vector_id,
267 weight: self.weight,
268 created_at: parse_ts(&self.created_at)?,
269 })
270 }
271}
272
/// Raw `edges` row as read from SQLite, before timestamp parsing.
struct EdgeRow {
    /// Source node id (`edges.src_id`).
    src: String,
    /// Destination node id (`edges.dst_id`).
    dst: String,
    /// Kind label; wrapped into [`EdgeKind`] by `into_edge`.
    edge_kind: String,
    /// Weight, already narrowed from the `f64` read out of SQLite.
    weight: f32,
    /// Timestamp text; parsed by [`parse_ts`].
    created_at: String,
}
280
281impl EdgeRow {
282 fn into_edge(self) -> Result<Edge, GraphError> {
283 Ok(Edge {
284 src: self.src,
285 dst: self.dst,
286 kind: EdgeKind(self.edge_kind),
287 weight: self.weight,
288 created_at: parse_ts(&self.created_at)?,
289 })
290 }
291}
292
293impl EpisodicGraph for SqliteGraph {
294 fn add_node(&self, node: &Node) -> Result<(), GraphError> {
295 let body = serde_json::to_string(&node.body)?;
296 let created = node.created_at.to_rfc3339();
297 self.db.with_conn(|conn| {
298 conn.execute(
299 "INSERT INTO nodes
300 (id, session_id, namespace, node_kind, body_json,
301 vector_id, weight, created_at)
302 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
303 rusqlite::params![
304 node.id,
305 node.session_id,
306 node.namespace,
307 node.kind.0,
308 body,
309 node.vector_id,
310 node.weight as f64,
311 created,
312 ],
313 )?;
314 Ok(())
315 })?;
316 Ok(())
317 }
318
319 fn add_edge(&self, edge: &Edge) -> Result<(), GraphError> {
320 let created = edge.created_at.to_rfc3339();
321 self.db.with_conn(|conn| {
322 conn.execute(
323 "INSERT INTO edges
324 (src_id, dst_id, edge_kind, weight, created_at)
325 VALUES (?1, ?2, ?3, ?4, ?5)",
326 rusqlite::params![edge.src, edge.dst, edge.kind.0, edge.weight as f64, created,],
327 )?;
328 Ok(())
329 })?;
330 Ok(())
331 }
332
333 fn search_text(
334 &self,
335 query: &str,
336 limit: usize,
337 namespace: Option<&str>,
338 ) -> Result<Vec<GraphHit>, GraphError> {
339 let sanitized = crate::episodic::sanitize_fts5_query(query);
340 if sanitized.is_empty() {
341 return Ok(Vec::new());
342 }
343
344 let raw: Vec<HitRow> = self.db.with_conn(|conn| {
349 let mut sql = String::from(
354 "SELECT n.id, f.text, f.rank, n.namespace, n.node_kind,
355 n.created_at, n.weight, n.vector_id
356 FROM nodes_fts f
357 JOIN nodes n ON n.rowid = f.rowid
358 WHERE nodes_fts MATCH ?1",
359 );
360 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(sanitized)];
361
362 if let Some(ns) = namespace {
363 sql.push_str(&format!(
364 " AND (n.namespace = ?{} OR n.namespace LIKE ?{})",
365 params.len() + 1,
366 params.len() + 2
367 ));
368 params.push(Box::new(ns.to_string()));
369 params.push(Box::new(format!("{ns}/%")));
370 }
371
372 sql.push_str(&format!(" ORDER BY f.rank LIMIT ?{}", params.len() + 1));
373 params.push(Box::new(limit as i64));
374
375 let mut stmt = conn.prepare(&sql)?;
376 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
377 params.iter().map(|p| p.as_ref()).collect();
378 let rows = stmt
379 .query_map(param_refs.as_slice(), |row| {
380 Ok(HitRow {
381 id: row.get(0)?,
382 text: row.get(1)?,
383 rank: row.get(2)?,
384 namespace: row.get(3)?,
385 node_kind: row.get(4)?,
386 created_at: row.get(5)?,
387 weight: row.get(6)?,
388 vector_id: row.get(7)?,
389 })
390 })?
391 .collect::<Result<Vec<_>, _>>()?;
392 Ok(rows)
393 })?;
394
395 raw.into_iter()
396 .map(|r| {
397 Ok(GraphHit {
398 id: r.id,
399 text: r.text,
400 rank: r.rank,
401 namespace: r.namespace,
402 node_kind: NodeKind(r.node_kind),
403 created_at: parse_ts(&r.created_at)?,
404 weight: r.weight as f32,
405 vector_id: r.vector_id,
406 })
407 })
408 .collect()
409 }
410
411 fn get_node(&self, id: &str) -> Result<Option<Node>, GraphError> {
412 let row: Option<NodeRow> = self.db.with_conn(|conn| {
413 let mut stmt = conn.prepare(
414 "SELECT id, session_id, namespace, node_kind, body_json,
415 vector_id, weight, created_at
416 FROM nodes WHERE id = ?1",
417 )?;
418 let mut rows = stmt.query([id])?;
419 if let Some(row) = rows.next()? {
420 Ok(Some(row_to_node(row)?))
421 } else {
422 Ok(None)
423 }
424 })?;
425 row.map(NodeRow::into_node).transpose()
426 }
427
428 fn neighbors(&self, id: &str) -> Result<Vec<(Edge, Node)>, GraphError> {
429 let raw: Vec<(EdgeRow, NodeRow)> = self.db.with_conn(|conn| {
430 let mut stmt = conn.prepare(
431 "SELECT e.src_id, e.dst_id, e.edge_kind, e.weight, e.created_at,
432 n.id, n.session_id, n.namespace, n.node_kind, n.body_json,
433 n.vector_id, n.weight, n.created_at
434 FROM edges e JOIN nodes n ON n.id = e.dst_id
435 WHERE e.src_id = ?1
436 ORDER BY e.created_at",
437 )?;
438 let mut rows = stmt.query([id])?;
439 let mut out = Vec::new();
440 while let Some(row) = rows.next()? {
441 let edge = EdgeRow {
442 src: row.get(0)?,
443 dst: row.get(1)?,
444 edge_kind: row.get(2)?,
445 weight: row.get::<_, f64>(3)? as f32,
446 created_at: row.get(4)?,
447 };
448 let node = NodeRow {
449 id: row.get(5)?,
450 session_id: row.get(6)?,
451 namespace: row.get(7)?,
452 node_kind: row.get(8)?,
453 body_json: row.get(9)?,
454 vector_id: row.get(10)?,
455 weight: row.get::<_, f64>(11)? as f32,
456 created_at: row.get(12)?,
457 };
458 out.push((edge, node));
459 }
460 Ok(out)
461 })?;
462 raw.into_iter()
463 .map(|(e, n)| Ok((e.into_edge()?, n.into_node()?)))
464 .collect()
465 }
466
467 fn incoming(&self, id: &str) -> Result<Vec<(Edge, Node)>, GraphError> {
468 let raw: Vec<(EdgeRow, NodeRow)> = self.db.with_conn(|conn| {
469 let mut stmt = conn.prepare(
470 "SELECT e.src_id, e.dst_id, e.edge_kind, e.weight, e.created_at,
471 n.id, n.session_id, n.namespace, n.node_kind, n.body_json,
472 n.vector_id, n.weight, n.created_at
473 FROM edges e JOIN nodes n ON n.id = e.src_id
474 WHERE e.dst_id = ?1
475 ORDER BY e.created_at",
476 )?;
477 let mut rows = stmt.query([id])?;
478 let mut out = Vec::new();
479 while let Some(row) = rows.next()? {
480 let edge = EdgeRow {
481 src: row.get(0)?,
482 dst: row.get(1)?,
483 edge_kind: row.get(2)?,
484 weight: row.get::<_, f64>(3)? as f32,
485 created_at: row.get(4)?,
486 };
487 let node = NodeRow {
488 id: row.get(5)?,
489 session_id: row.get(6)?,
490 namespace: row.get(7)?,
491 node_kind: row.get(8)?,
492 body_json: row.get(9)?,
493 vector_id: row.get(10)?,
494 weight: row.get::<_, f64>(11)? as f32,
495 created_at: row.get(12)?,
496 };
497 out.push((edge, node));
498 }
499 Ok(out)
500 })?;
501 raw.into_iter()
502 .map(|(e, n)| Ok((e.into_edge()?, n.into_node()?)))
503 .collect()
504 }
505
506 fn list_all_nodes(&self) -> Result<Vec<Node>, GraphError> {
507 let rows: Vec<NodeRow> = self.db.with_conn(|conn| {
508 let mut stmt = conn.prepare(
509 "SELECT id, session_id, namespace, node_kind, body_json,
510 vector_id, weight, created_at
511 FROM nodes",
512 )?;
513 let mut rows = stmt.query([])?;
514 let mut out = Vec::new();
515 while let Some(row) = rows.next()? {
516 out.push(row_to_node(row)?);
517 }
518 Ok(out)
519 })?;
520 rows.into_iter().map(NodeRow::into_node).collect()
521 }
522
523 fn update_weight(&self, id: &str, weight: f32) -> Result<(), GraphError> {
524 self.db.with_conn(|conn| {
525 conn.execute(
526 "UPDATE nodes SET weight = ?1 WHERE id = ?2",
527 rusqlite::params![weight as f64, id],
528 )?;
529 Ok(())
530 })?;
531 Ok(())
532 }
533
534 fn delete_node(&self, id: &str) -> Result<bool, GraphError> {
535 let deleted = self.db.with_conn(|conn| {
538 conn.execute("PRAGMA foreign_keys = ON", [])?;
539 let n = conn.execute("DELETE FROM nodes WHERE id = ?1", [id])?;
540 Ok(n > 0)
541 })?;
542 Ok(deleted)
543 }
544
545 fn path(&self, src: &str, dst: &str, max_hops: u32) -> Result<Option<Vec<String>>, GraphError> {
546 let max_depth = max_hops.max(1) as i64;
550 let result: Option<String> = self.db.with_conn(|conn| {
551 let mut stmt = conn.prepare(
552 "WITH RECURSIVE walk(node_id, depth, path) AS (
553 SELECT ?1, 0, ?1
554 UNION ALL
555 SELECT e.dst_id, w.depth + 1, w.path || '\u{1f}' || e.dst_id
556 FROM edges e JOIN walk w ON e.src_id = w.node_id
557 WHERE w.depth < ?2
558 AND instr(w.path || '\u{1f}', e.dst_id || '\u{1f}') = 0
559 )
560 SELECT path FROM walk WHERE node_id = ?3
561 ORDER BY depth LIMIT 1",
562 )?;
563 let mut rows = stmt.query(rusqlite::params![src, max_depth, dst])?;
564 if let Some(row) = rows.next()? {
565 let p: String = row.get(0)?;
566 Ok(Some(p))
567 } else {
568 Ok(None)
569 }
570 })?;
571 Ok(result.map(|p| p.split('\u{1f}').map(|s| s.to_string()).collect()))
572 }
573}
574
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory graph per test.
    fn store() -> SqliteGraph {
        SqliteGraph::new(SqlitePool::open_memory().expect("memory pool"))
    }

    /// Node of `kind` in the "personal" namespace with body `{"name": name}`.
    fn node(kind: &str, name: &str) -> Node {
        Node::new(
            NodeKind::new(kind),
            serde_json::json!({"name": name}),
            "personal",
            None,
        )
    }

    /// Node of kind `tool_call` with body `{"program": program}` in `namespace`.
    fn program_node(program: &str, namespace: &str) -> Node {
        Node::new(
            NodeKind::new("tool_call"),
            serde_json::json!({"program": program}),
            namespace,
            None,
        )
    }

    #[test]
    fn search_text_finds_node_by_body_term() {
        let g = store();
        let mut n = Node::new(
            NodeKind::new("tool_call"),
            serde_json::json!({"verb": "terminal.open", "program": "ripgrep"}),
            "personal",
            None,
        );
        n.weight = 0.5;
        g.add_node(&n).unwrap();
        // A non-matching node, to prove the search actually filters.
        g.add_node(&node("terminal_event", "cargo")).unwrap();

        let hits = g.search_text("ripgrep", 10, None).unwrap();
        assert_eq!(hits.len(), 1, "only the ripgrep node should match");
        assert_eq!(hits[0].id, n.id);
        assert_eq!(hits[0].node_kind.as_str(), "tool_call");
        assert!((hits[0].weight - 0.5).abs() < 1e-6);
    }

    #[test]
    fn search_text_respects_namespace_scope() {
        let g = store();
        let work = program_node("deploy", "work");
        let personal = program_node("deploy", "personal");
        g.add_node(&work).unwrap();
        g.add_node(&personal).unwrap();

        let hits = g.search_text("deploy", 10, Some("work")).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, work.id);
    }

    #[test]
    fn search_text_namespace_scope_includes_children_but_not_siblings() {
        let g = store();
        let parent = program_node("deploy", "work");
        let child = program_node("deploy", "work/team");
        // Shares the "work" prefix but is not a `work/...` child.
        let sibling = program_node("deploy", "workshop");
        g.add_node(&parent).unwrap();
        g.add_node(&child).unwrap();
        g.add_node(&sibling).unwrap();

        let ids: std::collections::HashSet<String> = g
            .search_text("deploy", 10, Some("work"))
            .unwrap()
            .into_iter()
            .map(|h| h.id)
            .collect();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&parent.id));
        assert!(ids.contains(&child.id));
        assert!(!ids.contains(&sibling.id));
    }

    #[test]
    fn search_text_index_stays_in_sync_on_delete() {
        let g = store();
        let n = program_node("ephemeral-tool", "personal");
        g.add_node(&n).unwrap();
        assert_eq!(g.search_text("ephemeral-tool", 10, None).unwrap().len(), 1);
        assert!(g.delete_node(&n.id).unwrap());
        assert!(
            g.search_text("ephemeral-tool", 10, None)
                .unwrap()
                .is_empty(),
            "delete trigger must drop the FTS row"
        );
    }

    #[test]
    fn search_text_empty_query_returns_empty() {
        let g = store();
        g.add_node(&node("tool_call", "x")).unwrap();
        assert!(g.search_text("   ", 10, None).unwrap().is_empty());
        assert!(g.search_text("!@#$", 10, None).unwrap().is_empty());
    }

    #[test]
    fn round_trip_node() {
        let g = store();
        let mut n = node("tool_call", "echo");
        n.weight = 0.75;
        n.vector_id = Some("vec-123".into());
        g.add_node(&n).unwrap();
        let got = g.get_node(&n.id).unwrap().expect("node should exist");
        assert_eq!(got.id, n.id);
        assert_eq!(got.kind, n.kind);
        assert_eq!(got.body, n.body);
        assert!((got.weight - 0.75).abs() < 1e-6);
        assert_eq!(got.vector_id.as_deref(), Some("vec-123"));
    }

    #[test]
    fn get_node_returns_none_for_missing() {
        let g = store();
        assert!(g.get_node("nope").unwrap().is_none());
    }

    #[test]
    fn update_weight_persists() {
        let g = store();
        let n = node("t", "w");
        g.add_node(&n).unwrap();
        g.update_weight(&n.id, 0.25).unwrap();
        let got = g.get_node(&n.id).unwrap().expect("node should exist");
        assert!((got.weight - 0.25).abs() < 1e-6);
    }

    #[test]
    fn delete_node_reports_missing() {
        let g = store();
        assert!(!g.delete_node("nope").unwrap());
    }

    #[test]
    fn neighbors_returns_outgoing_edges_with_destination_nodes() {
        let g = store();
        let a = node("tool_call", "a");
        let b = node("tool_event", "b");
        let c = node("tool_event", "c");
        g.add_node(&a).unwrap();
        g.add_node(&b).unwrap();
        g.add_node(&c).unwrap();
        g.add_edge(&Edge::new(&a.id, &b.id, EdgeKind::new("causal_produced")))
            .unwrap();
        g.add_edge(&Edge::new(&a.id, &c.id, EdgeKind::new("references")))
            .unwrap();

        let nb = g.neighbors(&a.id).unwrap();
        assert_eq!(nb.len(), 2);
        let kinds: Vec<&str> = nb.iter().map(|(e, _)| e.kind.as_str()).collect();
        assert!(kinds.contains(&"causal_produced"));
        assert!(kinds.contains(&"references"));
        let dst_ids: std::collections::HashSet<&str> =
            nb.iter().map(|(_, n)| n.id.as_str()).collect();
        assert!(dst_ids.contains(b.id.as_str()));
        assert!(dst_ids.contains(c.id.as_str()));
    }

    #[test]
    fn incoming_is_the_reverse_view() {
        let g = store();
        let a = node("t", "a");
        let b = node("t", "b");
        g.add_node(&a).unwrap();
        g.add_node(&b).unwrap();
        g.add_edge(&Edge::new(&a.id, &b.id, EdgeKind::new("k")))
            .unwrap();

        let inb = g.incoming(&b.id).unwrap();
        assert_eq!(inb.len(), 1);
        assert_eq!(inb[0].1.id, a.id);
        assert!(g.incoming(&a.id).unwrap().is_empty());
    }

    #[test]
    fn path_finds_chain_through_three_nodes() {
        let g = store();
        let a = node("t", "a");
        let b = node("t", "b");
        let c = node("t", "c");
        g.add_node(&a).unwrap();
        g.add_node(&b).unwrap();
        g.add_node(&c).unwrap();
        let k = EdgeKind::new("rel");
        g.add_edge(&Edge::new(&a.id, &b.id, k.clone())).unwrap();
        g.add_edge(&Edge::new(&b.id, &c.id, k.clone())).unwrap();

        let p = g.path(&a.id, &c.id, 5).unwrap().expect("path exists");
        assert_eq!(p, vec![a.id.clone(), b.id.clone(), c.id.clone()]);
    }

    #[test]
    fn path_respects_max_hops() {
        let g = store();
        let a = node("t", "a");
        let b = node("t", "b");
        let c = node("t", "c");
        g.add_node(&a).unwrap();
        g.add_node(&b).unwrap();
        g.add_node(&c).unwrap();
        let k = EdgeKind::new("rel");
        g.add_edge(&Edge::new(&a.id, &b.id, k.clone())).unwrap();
        g.add_edge(&Edge::new(&b.id, &c.id, k.clone())).unwrap();
        // a -> c needs two hops: too far at 1, reachable at exactly 2.
        assert!(g.path(&a.id, &c.id, 1).unwrap().is_none());
        assert!(g.path(&a.id, &c.id, 2).unwrap().is_some());
    }

    #[test]
    fn path_to_self_is_single_node() {
        let g = store();
        let a = node("t", "a");
        g.add_node(&a).unwrap();
        assert_eq!(g.path(&a.id, &a.id, 3).unwrap(), Some(vec![a.id.clone()]));
    }

    #[test]
    fn path_returns_none_when_disconnected() {
        let g = store();
        let a = node("t", "a");
        let b = node("t", "b");
        g.add_node(&a).unwrap();
        g.add_node(&b).unwrap();
        assert!(g.path(&a.id, &b.id, 5).unwrap().is_none());
    }

    #[test]
    fn path_handles_cycles_without_diverging() {
        let g = store();
        let a = node("t", "a");
        let b = node("t", "b");
        g.add_node(&a).unwrap();
        g.add_node(&b).unwrap();
        let k = EdgeKind::new("rel");
        g.add_edge(&Edge::new(&a.id, &b.id, k.clone())).unwrap();
        // Close the cycle b -> a.
        g.add_edge(&Edge::new(&b.id, &a.id, k)).unwrap();
        let p = g.path(&a.id, &b.id, 5).unwrap().expect("path exists");
        assert_eq!(p, vec![a.id, b.id]);
    }
}