Skip to main content

mentedb_graph/
manager.rs

1//! High-level knowledge graph manager.
2
3use std::path::Path;
4
5use mentedb_core::edge::MemoryEdge;
6use mentedb_core::error::{MenteError, MenteResult};
7use mentedb_core::types::MemoryId;
8use parking_lot::RwLock;
9
10use crate::belief::propagate_update;
11use crate::contradiction::find_contradictions;
12use crate::csr::CsrGraph;
13use crate::traversal::extract_subgraph;
14
15/// Owns a `CsrGraph` and provides high-level graph operations.
16///
17/// All methods take `&self` — internal `RwLock` handles concurrency.
18pub struct GraphManager {
19    graph: RwLock<CsrGraph>,
20}
21
22impl GraphManager {
23    /// Creates a new graph manager with an empty graph.
24    pub fn new() -> Self {
25        Self {
26            graph: RwLock::new(CsrGraph::new()),
27        }
28    }
29
30    /// Save the graph to the given directory.
31    pub fn save(&self, dir: &Path) -> MenteResult<()> {
32        std::fs::create_dir_all(dir)?;
33        self.graph.read().save(&dir.join("graph.json"))
34    }
35
36    /// Load the graph from the given directory.
37    pub fn load(dir: &Path) -> MenteResult<Self> {
38        let graph = CsrGraph::load(&dir.join("graph.json"))?;
39        Ok(Self {
40            graph: RwLock::new(graph),
41        })
42    }
43
44    /// Register a memory node in the graph.
45    pub fn add_memory(&self, id: MemoryId) {
46        self.graph.write().add_node(id);
47    }
48
49    /// Remove a memory node and all its edges.
50    pub fn remove_memory(&self, id: MemoryId) {
51        self.graph.write().remove_node(id);
52    }
53
54    /// Add a relationship (edge) between two memory nodes.
55    pub fn add_relationship(&self, edge: &MemoryEdge) -> MenteResult<()> {
56        let mut g = self.graph.write();
57        if !g.contains_node(edge.source) {
58            return Err(MenteError::MemoryNotFound(edge.source));
59        }
60        if !g.contains_node(edge.target) {
61            return Err(MenteError::MemoryNotFound(edge.target));
62        }
63        g.add_edge(edge);
64        Ok(())
65    }
66
67    /// Extract a context subgraph around a center node.
68    pub fn get_context_subgraph(
69        &self,
70        center: MemoryId,
71        depth: usize,
72    ) -> (Vec<MemoryId>, Vec<MemoryEdge>) {
73        extract_subgraph(&self.graph.read(), center, depth)
74    }
75
76    /// Propagate a confidence change through the graph.
77    pub fn propagate_belief_change(
78        &self,
79        id: MemoryId,
80        new_confidence: f32,
81    ) -> Vec<(MemoryId, f32)> {
82        propagate_update(&self.graph.read(), id, new_confidence)
83    }
84
85    /// Find all nodes that contradict the given node.
86    pub fn find_all_contradictions(&self, id: MemoryId) -> Vec<MemoryId> {
87        find_contradictions(&self.graph.read(), id)
88    }
89
90    /// Merge the delta log into CSR/CSC compressed storage.
91    pub fn compact(&self) {
92        self.graph.write().compact();
93    }
94
95    /// Strengthen an edge weight (Hebbian learning: neurons that fire together wire together).
96    pub fn strengthen_edge(&self, source: MemoryId, target: MemoryId, delta: f32) {
97        self.graph.write().strengthen_edge(source, target, delta);
98    }
99
100    /// Access the underlying graph for read-only traversals.
101    ///
102    /// Returns a read guard — hold it only briefly to avoid blocking writers.
103    pub fn read_graph(&self) -> parking_lot::RwLockReadGuard<'_, CsrGraph> {
104        self.graph.read()
105    }
106
107    /// Alias for `read_graph()` — backward compatible access to the CsrGraph.
108    pub fn graph(&self) -> parking_lot::RwLockReadGuard<'_, CsrGraph> {
109        self.graph.read()
110    }
111}
112
113impl Default for GraphManager {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::edge::EdgeType;

    /// Builds an edge with weight 0.8, a fixed timestamp, no validity window and no label.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 0.8,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
            label: None,
        }
    }

    #[test]
    fn test_add_memory_and_relationship() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        mgr.add_memory(a);
        mgr.add_memory(b);
        assert!(
            mgr.add_relationship(&make_edge(a, b, EdgeType::Caused))
                .is_ok()
        );
    }

    #[test]
    fn test_relationship_missing_node() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        mgr.add_memory(a);
        // b not added: the error must name the missing target, not just be any error.
        let result = mgr.add_relationship(&make_edge(a, b, EdgeType::Caused));
        assert!(matches!(result, Err(MenteError::MemoryNotFound(id)) if id == b));
    }

    #[test]
    fn test_remove_memory_drops_node() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        mgr.add_memory(a);
        mgr.add_memory(b);
        mgr.remove_memory(b);
        // Once b is gone, add_relationship must reject it as an endpoint.
        assert!(
            mgr.add_relationship(&make_edge(a, b, EdgeType::Caused))
                .is_err()
        );
    }

    #[test]
    fn test_context_subgraph() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();
        mgr.add_memory(a);
        mgr.add_memory(b);
        mgr.add_memory(c);
        mgr.add_relationship(&make_edge(a, b, EdgeType::Caused))
            .unwrap();
        mgr.add_relationship(&make_edge(b, c, EdgeType::Related))
            .unwrap();

        let (nodes, edges) = mgr.get_context_subgraph(a, 2);
        assert_eq!(nodes.len(), 3);
        assert_eq!(edges.len(), 2);
    }

    #[test]
    fn test_compact() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        mgr.add_memory(a);
        mgr.add_memory(b);
        mgr.add_relationship(&make_edge(a, b, EdgeType::Caused))
            .unwrap();
        mgr.compact();

        let out = mgr.graph().outgoing(a);
        assert_eq!(out.len(), 1);
    }

    #[test]
    fn test_save_and_load_roundtrip() {
        // Unique per process and per call, so parallel test runs do not collide.
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let dir = std::env::temp_dir().join(format!(
            "mentedb_graph_manager_test_{}_{}",
            std::process::id(),
            nanos
        ));

        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        mgr.add_memory(a);
        mgr.add_memory(b);
        mgr.add_relationship(&make_edge(a, b, EdgeType::Caused))
            .unwrap();
        mgr.compact();

        let saved = mgr.save(&dir);
        let loaded = GraphManager::load(&dir);
        // Clean up before asserting so a failure does not leave files behind.
        let _ = std::fs::remove_dir_all(&dir);

        saved.expect("save should succeed");
        let loaded = loaded.expect("load should read back what save wrote");
        assert_eq!(loaded.graph().outgoing(a).len(), 1);
    }

    #[test]
    fn test_belief_propagation() {
        let mgr = GraphManager::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        mgr.add_memory(a);
        mgr.add_memory(b);
        mgr.add_relationship(&MemoryEdge {
            source: a,
            target: b,
            edge_type: EdgeType::Caused,
            weight: 1.0,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
            label: None,
        })
        .unwrap();

        let results = mgr.propagate_belief_change(a, 0.5);
        assert!(results.len() >= 2);
    }
}
219}