1use std::path::Path;
4
5use mentedb_core::edge::MemoryEdge;
6use mentedb_core::error::{MenteError, MenteResult};
7use mentedb_core::types::MemoryId;
8use parking_lot::RwLock;
9
10use crate::belief::propagate_update;
11use crate::contradiction::find_contradictions;
12use crate::csr::CsrGraph;
13use crate::traversal::extract_subgraph;
14
15pub struct GraphManager {
19 graph: RwLock<CsrGraph>,
20}
21
22impl GraphManager {
23 pub fn new() -> Self {
25 Self {
26 graph: RwLock::new(CsrGraph::new()),
27 }
28 }
29
30 pub fn save(&self, dir: &Path) -> MenteResult<()> {
32 std::fs::create_dir_all(dir)?;
33 self.graph.read().save(&dir.join("graph.json"))
34 }
35
36 pub fn load(dir: &Path) -> MenteResult<Self> {
38 let graph = CsrGraph::load(&dir.join("graph.json"))?;
39 Ok(Self {
40 graph: RwLock::new(graph),
41 })
42 }
43
44 pub fn add_memory(&self, id: MemoryId) {
46 self.graph.write().add_node(id);
47 }
48
49 pub fn remove_memory(&self, id: MemoryId) {
51 self.graph.write().remove_node(id);
52 }
53
54 pub fn add_relationship(&self, edge: &MemoryEdge) -> MenteResult<()> {
56 let mut g = self.graph.write();
57 if !g.contains_node(edge.source) {
58 return Err(MenteError::MemoryNotFound(edge.source));
59 }
60 if !g.contains_node(edge.target) {
61 return Err(MenteError::MemoryNotFound(edge.target));
62 }
63 g.add_edge(edge);
64 Ok(())
65 }
66
67 pub fn get_context_subgraph(
69 &self,
70 center: MemoryId,
71 depth: usize,
72 ) -> (Vec<MemoryId>, Vec<MemoryEdge>) {
73 extract_subgraph(&self.graph.read(), center, depth)
74 }
75
76 pub fn propagate_belief_change(
78 &self,
79 id: MemoryId,
80 new_confidence: f32,
81 ) -> Vec<(MemoryId, f32)> {
82 propagate_update(&self.graph.read(), id, new_confidence)
83 }
84
85 pub fn find_all_contradictions(&self, id: MemoryId) -> Vec<MemoryId> {
87 find_contradictions(&self.graph.read(), id)
88 }
89
90 pub fn compact(&self) {
92 self.graph.write().compact();
93 }
94
95 pub fn strengthen_edge(&self, source: MemoryId, target: MemoryId, delta: f32) {
97 self.graph.write().strengthen_edge(source, target, delta);
98 }
99
100 pub fn read_graph(&self) -> parking_lot::RwLockReadGuard<'_, CsrGraph> {
104 self.graph.read()
105 }
106
107 pub fn graph(&self) -> parking_lot::RwLockReadGuard<'_, CsrGraph> {
109 self.graph.read()
110 }
111}
112
113impl Default for GraphManager {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::edge::EdgeType;

    /// Builds an edge with a fixed weight (0.8), timestamp (1000), no
    /// validity window and no label, so tests vary only endpoints and type.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 0.8,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
            label: None,
        }
    }

    /// Asserts that `result` is `Err(MenteError::MemoryNotFound(expected))`,
    /// i.e. the error names the specific missing endpoint.
    fn assert_not_found(result: MenteResult<()>, expected: MemoryId) {
        assert!(
            matches!(result, Err(MenteError::MemoryNotFound(id)) if id == expected),
            "expected MemoryNotFound for the missing endpoint"
        );
    }

    #[test]
    fn test_add_memory_and_relationship() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        mgr.add_memory(a);
        mgr.add_memory(b);
        assert!(mgr
            .add_relationship(&make_edge(a, b, EdgeType::Caused))
            .is_ok());
    }

    #[test]
    fn test_relationship_missing_node() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        mgr.add_memory(a);
        // Target `b` was never registered; the error must name it.
        assert_not_found(mgr.add_relationship(&make_edge(a, b, EdgeType::Caused)), b);
    }

    #[test]
    fn test_relationship_missing_source() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        mgr.add_memory(b);
        // Source `a` was never registered; the error must name it.
        assert_not_found(mgr.add_relationship(&make_edge(a, b, EdgeType::Caused)), a);
    }

    #[test]
    fn test_relationship_both_missing_reports_source() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        // The source endpoint is validated before the target.
        assert_not_found(mgr.add_relationship(&make_edge(a, b, EdgeType::Caused)), a);
    }

    #[test]
    fn test_remove_memory_blocks_new_relationships() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        mgr.add_memory(a);
        mgr.add_memory(b);
        mgr.remove_memory(b);
        assert_not_found(mgr.add_relationship(&make_edge(a, b, EdgeType::Caused)), b);
    }

    #[test]
    fn test_context_subgraph() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();
        mgr.add_memory(a);
        mgr.add_memory(b);
        mgr.add_memory(c);
        mgr.add_relationship(&make_edge(a, b, EdgeType::Caused))
            .expect("a and b were added");
        mgr.add_relationship(&make_edge(b, c, EdgeType::Related))
            .expect("b and c were added");

        let (nodes, edges) = mgr.get_context_subgraph(a, 2);
        assert_eq!(nodes.len(), 3);
        assert_eq!(edges.len(), 2);
    }

    #[test]
    fn test_compact() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        mgr.add_memory(a);
        mgr.add_memory(b);
        mgr.add_relationship(&make_edge(a, b, EdgeType::Caused))
            .expect("a and b were added");
        mgr.compact();

        let out = mgr.graph().outgoing(a);
        assert_eq!(out.len(), 1);
    }

    #[test]
    fn test_belief_propagation() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        mgr.add_memory(a);
        mgr.add_memory(b);
        // Full-strength causal edge so the change has something to propagate along.
        mgr.add_relationship(&MemoryEdge {
            weight: 1.0,
            ..make_edge(a, b, EdgeType::Caused)
        })
        .expect("a and b were added");

        let results = mgr.propagate_belief_change(a, 0.5);
        assert!(results.len() >= 2);
        assert!(
            results.iter().any(|&(id, _)| id == b),
            "the downstream memory should appear in the propagation result"
        );
    }
}