Skip to main content

agentic_memory/graph/
memory_graph.rs

1//! Core graph structure — nodes + edges with adjacency indexes.
2
3use std::collections::HashMap;
4
5use crate::index::{ClusterMap, DocLengths, SessionIndex, TemporalIndex, TermIndex, TypeIndex};
6use crate::types::{AmemError, AmemResult, CognitiveEvent, Edge, EdgeType, MAX_EDGES_PER_NODE};
7
8/// The core in-memory graph structure holding cognitive events and their relationships.
9pub struct MemoryGraph {
10    /// All nodes, indexed by ID.
11    nodes: Vec<CognitiveEvent>,
12    /// All edges, grouped by source_id.
13    edges: Vec<Edge>,
14    /// Adjacency index: source_id -> (start_index, count) in edges vec.
15    adjacency: HashMap<u64, (usize, usize)>,
16    /// Reverse adjacency: target_id -> list of source_ids.
17    reverse_adjacency: HashMap<u64, Vec<u64>>,
18    /// Next available node ID.
19    next_id: u64,
20    /// Feature vector dimension.
21    dimension: usize,
22    /// Type index.
23    pub(crate) type_index: TypeIndex,
24    /// Temporal index.
25    pub(crate) temporal_index: TemporalIndex,
26    /// Session index.
27    pub(crate) session_index: SessionIndex,
28    /// Cluster map.
29    pub(crate) cluster_map: ClusterMap,
30    /// BM25 inverted index (optional, may not be present in old files).
31    pub term_index: Option<TermIndex>,
32    /// Document lengths for BM25 normalization (optional).
33    pub doc_lengths: Option<DocLengths>,
34}
35
36impl MemoryGraph {
37    /// Create a new empty graph.
38    pub fn new(dimension: usize) -> Self {
39        Self {
40            nodes: Vec::new(),
41            edges: Vec::new(),
42            adjacency: HashMap::new(),
43            reverse_adjacency: HashMap::new(),
44            next_id: 0,
45            dimension,
46            type_index: TypeIndex::new(),
47            temporal_index: TemporalIndex::new(),
48            session_index: SessionIndex::new(),
49            cluster_map: ClusterMap::new(dimension),
50            term_index: None,
51            doc_lengths: None,
52        }
53    }
54
55    /// Create from pre-existing data (used by reader).
56    pub fn from_parts(
57        nodes: Vec<CognitiveEvent>,
58        edges: Vec<Edge>,
59        dimension: usize,
60    ) -> AmemResult<Self> {
61        let next_id = nodes.iter().map(|n| n.id + 1).max().unwrap_or(0);
62
63        let mut graph = Self {
64            nodes: Vec::new(),
65            edges: Vec::new(),
66            adjacency: HashMap::new(),
67            reverse_adjacency: HashMap::new(),
68            next_id,
69            dimension,
70            type_index: TypeIndex::new(),
71            temporal_index: TemporalIndex::new(),
72            session_index: SessionIndex::new(),
73            cluster_map: ClusterMap::new(dimension),
74            term_index: None,
75            doc_lengths: None,
76        };
77
78        // Insert nodes directly (they already have IDs assigned)
79        graph.nodes = nodes;
80
81        // Rebuild indexes from nodes
82        graph.type_index.rebuild(&graph.nodes);
83        graph.temporal_index.rebuild(&graph.nodes);
84        graph.session_index.rebuild(&graph.nodes);
85
86        // Sort edges by source_id, then target_id
87        let mut sorted_edges = edges;
88        sorted_edges.sort_by(|a, b| {
89            a.source_id
90                .cmp(&b.source_id)
91                .then(a.target_id.cmp(&b.target_id))
92        });
93        graph.edges = sorted_edges;
94
95        // Build adjacency indexes
96        graph.rebuild_adjacency();
97
98        Ok(graph)
99    }
100
101    /// Number of nodes.
102    pub fn node_count(&self) -> usize {
103        self.nodes.len()
104    }
105
106    /// Number of edges.
107    pub fn edge_count(&self) -> usize {
108        self.edges.len()
109    }
110
111    /// Get a node by ID (immutable).
112    pub fn get_node(&self, id: u64) -> Option<&CognitiveEvent> {
113        // Fast path: if IDs are sequential, nodes[id] has id == id
114        let idx = id as usize;
115        if idx < self.nodes.len() && self.nodes[idx].id == id {
116            return Some(&self.nodes[idx]);
117        }
118        // Fallback: linear scan (needed after remove_node)
119        self.nodes.iter().find(|n| n.id == id)
120    }
121
122    /// Get a node by ID (mutable).
123    pub fn get_node_mut(&mut self, id: u64) -> Option<&mut CognitiveEvent> {
124        // Fast path: if IDs are sequential, nodes[id] has id == id
125        let idx = id as usize;
126        if idx < self.nodes.len() && self.nodes[idx].id == id {
127            return Some(&mut self.nodes[idx]);
128        }
129        // Fallback: linear scan (needed after remove_node)
130        self.nodes.iter_mut().find(|n| n.id == id)
131    }
132
133    /// Ensure adjacency indexes are up to date.
134    /// No-op in the current implementation (adjacency is always up to date).
135    pub fn ensure_adjacency(&mut self) {
136        // Currently a no-op — adjacency is rebuilt on every mutation.
137    }
138
139    /// Get all edges from a source node.
140    pub fn edges_from(&self, source_id: u64) -> &[Edge] {
141        if let Some(&(start, count)) = self.adjacency.get(&source_id) {
142            &self.edges[start..start + count]
143        } else {
144            &[]
145        }
146    }
147
148    /// Get all edges that point TO this node.
149    pub fn edges_to(&self, target_id: u64) -> Vec<&Edge> {
150        if let Some(sources) = self.reverse_adjacency.get(&target_id) {
151            let mut result = Vec::new();
152            for &src_id in sources {
153                for edge in self.edges_from(src_id) {
154                    if edge.target_id == target_id {
155                        result.push(edge);
156                    }
157                }
158            }
159            result
160        } else {
161            Vec::new()
162        }
163    }
164
165    /// Get all nodes (immutable slice).
166    pub fn nodes(&self) -> &[CognitiveEvent] {
167        &self.nodes
168    }
169
170    /// Get all edges (immutable slice).
171    pub fn edges(&self) -> &[Edge] {
172        &self.edges
173    }
174
175    /// The feature vector dimension for this graph.
176    pub fn dimension(&self) -> usize {
177        self.dimension
178    }
179
180    /// Add a node, returns the assigned ID.
181    pub fn add_node(&mut self, mut event: CognitiveEvent) -> AmemResult<u64> {
182        // Validate content size
183        event.validate(self.dimension)?;
184
185        // Pad feature vec if empty
186        if event.feature_vec.is_empty() {
187            event.feature_vec = vec![0.0; self.dimension];
188        } else if event.feature_vec.len() != self.dimension {
189            return Err(AmemError::DimensionMismatch {
190                expected: self.dimension,
191                got: event.feature_vec.len(),
192            });
193        }
194
195        // Assign ID
196        let id = self.next_id;
197        event.id = id;
198        self.next_id += 1;
199
200        // Update indexes
201        self.type_index.add_node(&event);
202        self.temporal_index.add_node(&event);
203        self.session_index.add_node(&event);
204
205        self.nodes.push(event);
206
207        Ok(id)
208    }
209
210    /// Add an edge between two existing nodes.
211    pub fn add_edge(&mut self, edge: Edge) -> AmemResult<()> {
212        // Validate: no self-edges
213        if edge.source_id == edge.target_id {
214            return Err(AmemError::SelfEdge(edge.source_id));
215        }
216
217        // Validate: source exists
218        if self.get_node(edge.source_id).is_none() {
219            return Err(AmemError::NodeNotFound(edge.source_id));
220        }
221
222        // Validate: target exists
223        if self.get_node(edge.target_id).is_none() {
224            return Err(AmemError::InvalidEdgeTarget(edge.target_id));
225        }
226
227        // Check max edges per node
228        let current_count = self
229            .adjacency
230            .get(&edge.source_id)
231            .map(|(_, c)| *c)
232            .unwrap_or(0);
233        if current_count >= MAX_EDGES_PER_NODE as usize {
234            return Err(AmemError::TooManyEdges(MAX_EDGES_PER_NODE));
235        }
236
237        self.edges.push(edge);
238        self.rebuild_adjacency();
239
240        Ok(())
241    }
242
243    /// Remove a node and all its edges.
244    pub fn remove_node(&mut self, id: u64) -> AmemResult<CognitiveEvent> {
245        let pos = self
246            .nodes
247            .iter()
248            .position(|n| n.id == id)
249            .ok_or(AmemError::NodeNotFound(id))?;
250
251        let removed = self.nodes.remove(pos);
252
253        // Remove from indexes
254        self.type_index.remove_node(id, removed.event_type);
255        self.temporal_index.remove_node(id, removed.created_at);
256        self.session_index.remove_node(id, removed.session_id);
257
258        // Remove all edges involving this node
259        self.edges
260            .retain(|e| e.source_id != id && e.target_id != id);
261
262        // Rebuild adjacency
263        self.rebuild_adjacency();
264
265        Ok(removed)
266    }
267
268    /// Remove a specific edge.
269    pub fn remove_edge(
270        &mut self,
271        source_id: u64,
272        target_id: u64,
273        edge_type: EdgeType,
274    ) -> AmemResult<()> {
275        let initial_len = self.edges.len();
276        self.edges.retain(|e| {
277            !(e.source_id == source_id && e.target_id == target_id && e.edge_type == edge_type)
278        });
279        if self.edges.len() == initial_len {
280            return Err(AmemError::NodeNotFound(source_id));
281        }
282        self.rebuild_adjacency();
283        Ok(())
284    }
285
286    /// Rebuild adjacency indexes from the current edge list.
287    fn rebuild_adjacency(&mut self) {
288        self.adjacency.clear();
289        self.reverse_adjacency.clear();
290
291        // Sort edges by source_id, then target_id
292        self.edges.sort_by(|a, b| {
293            a.source_id
294                .cmp(&b.source_id)
295                .then(a.target_id.cmp(&b.target_id))
296        });
297
298        let mut i = 0;
299        while i < self.edges.len() {
300            let source = self.edges[i].source_id;
301            let start = i;
302            while i < self.edges.len() && self.edges[i].source_id == source {
303                // Build reverse adjacency
304                self.reverse_adjacency
305                    .entry(self.edges[i].target_id)
306                    .or_default()
307                    .push(source);
308                i += 1;
309            }
310            self.adjacency.insert(source, (start, i - start));
311        }
312
313        // Dedup reverse adjacency
314        for list in self.reverse_adjacency.values_mut() {
315            list.sort_unstable();
316            list.dedup();
317        }
318    }
319
320    /// Get the next available node ID (for builder use).
321    pub fn next_id(&self) -> u64 {
322        self.next_id
323    }
324
325    /// Get the type index.
326    pub fn type_index(&self) -> &TypeIndex {
327        &self.type_index
328    }
329
330    /// Get the temporal index.
331    pub fn temporal_index(&self) -> &TemporalIndex {
332        &self.temporal_index
333    }
334
335    /// Get the session index.
336    pub fn session_index(&self) -> &SessionIndex {
337        &self.session_index
338    }
339
340    /// Get the cluster map.
341    pub fn cluster_map(&self) -> &ClusterMap {
342        &self.cluster_map
343    }
344
345    /// Get a mutable reference to the cluster map.
346    pub fn cluster_map_mut(&mut self) -> &mut ClusterMap {
347        &mut self.cluster_map
348    }
349
350    /// Get the term index (for BM25 search). None if not built.
351    pub fn term_index(&self) -> Option<&TermIndex> {
352        self.term_index.as_ref()
353    }
354
355    /// Get the doc lengths (for BM25 normalization). None if not built.
356    pub fn doc_lengths(&self) -> Option<&DocLengths> {
357        self.doc_lengths.as_ref()
358    }
359
360    /// Set the term index.
361    pub fn set_term_index(&mut self, index: TermIndex) {
362        self.term_index = Some(index);
363    }
364
365    /// Set the doc lengths.
366    pub fn set_doc_lengths(&mut self, lengths: DocLengths) {
367        self.doc_lengths = Some(lengths);
368    }
369}