1use std::path::Path;
4
5use mentedb_core::edge::MemoryEdge;
6use mentedb_core::error::{MenteError, MenteResult};
7use mentedb_core::types::MemoryId;
8
9use crate::belief::propagate_update;
10use crate::contradiction::find_contradictions;
11use crate::csr::CsrGraph;
12use crate::traversal::extract_subgraph;
13
14pub struct GraphManager {
16 graph: CsrGraph,
17}
18
19impl GraphManager {
20 pub fn new() -> Self {
21 Self {
22 graph: CsrGraph::new(),
23 }
24 }
25
26 pub fn save(&self, dir: &Path) -> MenteResult<()> {
28 std::fs::create_dir_all(dir)?;
29 self.graph.save(&dir.join("graph.json"))
30 }
31
32 pub fn load(dir: &Path) -> MenteResult<Self> {
34 let graph = CsrGraph::load(&dir.join("graph.json"))?;
35 Ok(Self { graph })
36 }
37
38 pub fn graph(&self) -> &CsrGraph {
40 &self.graph
41 }
42
43 pub fn add_memory(&mut self, id: MemoryId) {
45 self.graph.add_node(id);
46 }
47
48 pub fn remove_memory(&mut self, id: MemoryId) {
50 self.graph.remove_node(id);
51 }
52
53 pub fn add_relationship(&mut self, edge: &MemoryEdge) -> MenteResult<()> {
55 if !self.graph.contains_node(edge.source) {
56 return Err(MenteError::MemoryNotFound(edge.source));
57 }
58 if !self.graph.contains_node(edge.target) {
59 return Err(MenteError::MemoryNotFound(edge.target));
60 }
61 self.graph.add_edge(edge);
62 Ok(())
63 }
64
65 pub fn get_context_subgraph(
67 &self,
68 center: MemoryId,
69 depth: usize,
70 ) -> (Vec<MemoryId>, Vec<MemoryEdge>) {
71 extract_subgraph(&self.graph, center, depth)
72 }
73
74 pub fn propagate_belief_change(
76 &self,
77 id: MemoryId,
78 new_confidence: f32,
79 ) -> Vec<(MemoryId, f32)> {
80 propagate_update(&self.graph, id, new_confidence)
81 }
82
83 pub fn find_all_contradictions(&self, id: MemoryId) -> Vec<MemoryId> {
85 find_contradictions(&self.graph, id)
86 }
87
88 pub fn compact(&mut self) {
90 self.graph.compact();
91 }
92}
93
94impl Default for GraphManager {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::edge::EdgeType;
    use uuid::Uuid;

    /// Builds an edge with a fixed weight (0.8) and timestamp (1000).
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 0.8,
            created_at: 1000,
        }
    }

    /// Asserts that `err` is `MemoryNotFound` carrying exactly `expected`.
    fn assert_not_found(err: MenteError, expected: MemoryId) {
        match err {
            MenteError::MemoryNotFound(id) => assert_eq!(id, expected),
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn test_add_memory_and_relationship() {
        let mut mgr = GraphManager::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        mgr.add_memory(a);
        mgr.add_memory(b);
        assert!(mgr
            .add_relationship(&make_edge(a, b, EdgeType::Caused))
            .is_ok());
    }

    #[test]
    fn test_relationship_missing_node() {
        let mut mgr = GraphManager::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        mgr.add_memory(a);
        // `b` was never added: the edge must be rejected and `b` reported.
        let err = mgr
            .add_relationship(&make_edge(a, b, EdgeType::Caused))
            .unwrap_err();
        assert_not_found(err, b);
    }

    #[test]
    fn test_relationship_missing_source() {
        let mut mgr = GraphManager::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        mgr.add_memory(b);
        // Only the target exists: the missing source must be reported.
        let err = mgr
            .add_relationship(&make_edge(a, b, EdgeType::Caused))
            .unwrap_err();
        assert_not_found(err, a);
    }

    #[test]
    fn test_relationship_both_missing_reports_source() {
        let mut mgr = GraphManager::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        // Neither endpoint exists: the source is checked first.
        let err = mgr
            .add_relationship(&make_edge(a, b, EdgeType::Related))
            .unwrap_err();
        assert_not_found(err, a);
    }

    #[test]
    fn test_default_is_empty() {
        let mgr = GraphManager::default();
        assert!(!mgr.graph().contains_node(Uuid::new_v4()));
    }

    #[test]
    fn test_context_subgraph() {
        let mut mgr = GraphManager::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        mgr.add_memory(a);
        mgr.add_memory(b);
        mgr.add_memory(c);
        mgr.add_relationship(&make_edge(a, b, EdgeType::Caused))
            .unwrap();
        mgr.add_relationship(&make_edge(b, c, EdgeType::Related))
            .unwrap();

        let (nodes, edges) = mgr.get_context_subgraph(a, 2);
        assert_eq!(nodes.len(), 3);
        assert_eq!(edges.len(), 2);
    }

    #[test]
    fn test_compact() {
        let mut mgr = GraphManager::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        mgr.add_memory(a);
        mgr.add_memory(b);
        mgr.add_relationship(&make_edge(a, b, EdgeType::Caused))
            .unwrap();
        mgr.compact();

        let out = mgr.graph().outgoing(a);
        assert_eq!(out.len(), 1);
    }

    #[test]
    fn test_belief_propagation() {
        let mut mgr = GraphManager::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        mgr.add_memory(a);
        mgr.add_memory(b);
        mgr.add_relationship(&MemoryEdge {
            source: a,
            target: b,
            edge_type: EdgeType::Caused,
            weight: 1.0,
            created_at: 1000,
        })
        .unwrap();

        let results = mgr.propagate_belief_change(a, 0.5);
        assert!(results.len() >= 2);
    }
}