Skip to main content

agentic_memory/graph/
memory_graph.rs

//! Core graph structure — nodes + edges with adjacency indexes.
2
3use std::collections::HashMap;
4
5use crate::index::{ClusterMap, SessionIndex, TemporalIndex, TypeIndex};
6use crate::types::{AmemError, AmemResult, CognitiveEvent, Edge, EdgeType, MAX_EDGES_PER_NODE};
7
8/// The core in-memory graph structure holding cognitive events and their relationships.
9pub struct MemoryGraph {
10    /// All nodes, indexed by ID.
11    nodes: Vec<CognitiveEvent>,
12    /// All edges, grouped by source_id.
13    edges: Vec<Edge>,
14    /// Adjacency index: source_id -> (start_index, count) in edges vec.
15    adjacency: HashMap<u64, (usize, usize)>,
16    /// Reverse adjacency: target_id -> list of source_ids.
17    reverse_adjacency: HashMap<u64, Vec<u64>>,
18    /// Next available node ID.
19    next_id: u64,
20    /// Feature vector dimension.
21    dimension: usize,
22    /// Type index.
23    pub(crate) type_index: TypeIndex,
24    /// Temporal index.
25    pub(crate) temporal_index: TemporalIndex,
26    /// Session index.
27    pub(crate) session_index: SessionIndex,
28    /// Cluster map.
29    pub(crate) cluster_map: ClusterMap,
30}
31
32impl MemoryGraph {
33    /// Create a new empty graph.
34    pub fn new(dimension: usize) -> Self {
35        Self {
36            nodes: Vec::new(),
37            edges: Vec::new(),
38            adjacency: HashMap::new(),
39            reverse_adjacency: HashMap::new(),
40            next_id: 0,
41            dimension,
42            type_index: TypeIndex::new(),
43            temporal_index: TemporalIndex::new(),
44            session_index: SessionIndex::new(),
45            cluster_map: ClusterMap::new(dimension),
46        }
47    }
48
49    /// Create from pre-existing data (used by reader).
50    pub fn from_parts(
51        nodes: Vec<CognitiveEvent>,
52        edges: Vec<Edge>,
53        dimension: usize,
54    ) -> AmemResult<Self> {
55        let next_id = nodes.iter().map(|n| n.id + 1).max().unwrap_or(0);
56
57        let mut graph = Self {
58            nodes: Vec::new(),
59            edges: Vec::new(),
60            adjacency: HashMap::new(),
61            reverse_adjacency: HashMap::new(),
62            next_id,
63            dimension,
64            type_index: TypeIndex::new(),
65            temporal_index: TemporalIndex::new(),
66            session_index: SessionIndex::new(),
67            cluster_map: ClusterMap::new(dimension),
68        };
69
70        // Insert nodes directly (they already have IDs assigned)
71        graph.nodes = nodes;
72
73        // Rebuild indexes from nodes
74        graph.type_index.rebuild(&graph.nodes);
75        graph.temporal_index.rebuild(&graph.nodes);
76        graph.session_index.rebuild(&graph.nodes);
77
78        // Sort edges by source_id, then target_id
79        let mut sorted_edges = edges;
80        sorted_edges.sort_by(|a, b| {
81            a.source_id
82                .cmp(&b.source_id)
83                .then(a.target_id.cmp(&b.target_id))
84        });
85        graph.edges = sorted_edges;
86
87        // Build adjacency indexes
88        graph.rebuild_adjacency();
89
90        Ok(graph)
91    }
92
93    /// Number of nodes.
94    pub fn node_count(&self) -> usize {
95        self.nodes.len()
96    }
97
98    /// Number of edges.
99    pub fn edge_count(&self) -> usize {
100        self.edges.len()
101    }
102
103    /// Get a node by ID (immutable).
104    pub fn get_node(&self, id: u64) -> Option<&CognitiveEvent> {
105        // Fast path: if IDs are sequential, nodes[id] has id == id
106        let idx = id as usize;
107        if idx < self.nodes.len() && self.nodes[idx].id == id {
108            return Some(&self.nodes[idx]);
109        }
110        // Fallback: linear scan (needed after remove_node)
111        self.nodes.iter().find(|n| n.id == id)
112    }
113
114    /// Get a node by ID (mutable).
115    pub fn get_node_mut(&mut self, id: u64) -> Option<&mut CognitiveEvent> {
116        // Fast path: if IDs are sequential, nodes[id] has id == id
117        let idx = id as usize;
118        if idx < self.nodes.len() && self.nodes[idx].id == id {
119            return Some(&mut self.nodes[idx]);
120        }
121        // Fallback: linear scan (needed after remove_node)
122        self.nodes.iter_mut().find(|n| n.id == id)
123    }
124
125    /// Ensure adjacency indexes are up to date.
126    /// No-op in the current implementation (adjacency is always up to date).
127    pub fn ensure_adjacency(&mut self) {
128        // Currently a no-op — adjacency is rebuilt on every mutation.
129    }
130
131    /// Get all edges from a source node.
132    pub fn edges_from(&self, source_id: u64) -> &[Edge] {
133        if let Some(&(start, count)) = self.adjacency.get(&source_id) {
134            &self.edges[start..start + count]
135        } else {
136            &[]
137        }
138    }
139
140    /// Get all edges that point TO this node.
141    pub fn edges_to(&self, target_id: u64) -> Vec<&Edge> {
142        if let Some(sources) = self.reverse_adjacency.get(&target_id) {
143            let mut result = Vec::new();
144            for &src_id in sources {
145                for edge in self.edges_from(src_id) {
146                    if edge.target_id == target_id {
147                        result.push(edge);
148                    }
149                }
150            }
151            result
152        } else {
153            Vec::new()
154        }
155    }
156
157    /// Get all nodes (immutable slice).
158    pub fn nodes(&self) -> &[CognitiveEvent] {
159        &self.nodes
160    }
161
162    /// Get all edges (immutable slice).
163    pub fn edges(&self) -> &[Edge] {
164        &self.edges
165    }
166
167    /// The feature vector dimension for this graph.
168    pub fn dimension(&self) -> usize {
169        self.dimension
170    }
171
172    /// Add a node, returns the assigned ID.
173    pub fn add_node(&mut self, mut event: CognitiveEvent) -> AmemResult<u64> {
174        // Validate content size
175        event.validate(self.dimension)?;
176
177        // Pad feature vec if empty
178        if event.feature_vec.is_empty() {
179            event.feature_vec = vec![0.0; self.dimension];
180        } else if event.feature_vec.len() != self.dimension {
181            return Err(AmemError::DimensionMismatch {
182                expected: self.dimension,
183                got: event.feature_vec.len(),
184            });
185        }
186
187        // Assign ID
188        let id = self.next_id;
189        event.id = id;
190        self.next_id += 1;
191
192        // Update indexes
193        self.type_index.add_node(&event);
194        self.temporal_index.add_node(&event);
195        self.session_index.add_node(&event);
196
197        self.nodes.push(event);
198
199        Ok(id)
200    }
201
202    /// Add an edge between two existing nodes.
203    pub fn add_edge(&mut self, edge: Edge) -> AmemResult<()> {
204        // Validate: no self-edges
205        if edge.source_id == edge.target_id {
206            return Err(AmemError::SelfEdge(edge.source_id));
207        }
208
209        // Validate: source exists
210        if self.get_node(edge.source_id).is_none() {
211            return Err(AmemError::NodeNotFound(edge.source_id));
212        }
213
214        // Validate: target exists
215        if self.get_node(edge.target_id).is_none() {
216            return Err(AmemError::InvalidEdgeTarget(edge.target_id));
217        }
218
219        // Check max edges per node
220        let current_count = self
221            .adjacency
222            .get(&edge.source_id)
223            .map(|(_, c)| *c)
224            .unwrap_or(0);
225        if current_count >= MAX_EDGES_PER_NODE as usize {
226            return Err(AmemError::TooManyEdges(MAX_EDGES_PER_NODE));
227        }
228
229        self.edges.push(edge);
230        self.rebuild_adjacency();
231
232        Ok(())
233    }
234
235    /// Remove a node and all its edges.
236    pub fn remove_node(&mut self, id: u64) -> AmemResult<CognitiveEvent> {
237        let pos = self
238            .nodes
239            .iter()
240            .position(|n| n.id == id)
241            .ok_or(AmemError::NodeNotFound(id))?;
242
243        let removed = self.nodes.remove(pos);
244
245        // Remove from indexes
246        self.type_index.remove_node(id, removed.event_type);
247        self.temporal_index.remove_node(id, removed.created_at);
248        self.session_index.remove_node(id, removed.session_id);
249
250        // Remove all edges involving this node
251        self.edges
252            .retain(|e| e.source_id != id && e.target_id != id);
253
254        // Rebuild adjacency
255        self.rebuild_adjacency();
256
257        Ok(removed)
258    }
259
260    /// Remove a specific edge.
261    pub fn remove_edge(
262        &mut self,
263        source_id: u64,
264        target_id: u64,
265        edge_type: EdgeType,
266    ) -> AmemResult<()> {
267        let initial_len = self.edges.len();
268        self.edges.retain(|e| {
269            !(e.source_id == source_id && e.target_id == target_id && e.edge_type == edge_type)
270        });
271        if self.edges.len() == initial_len {
272            return Err(AmemError::NodeNotFound(source_id));
273        }
274        self.rebuild_adjacency();
275        Ok(())
276    }
277
278    /// Rebuild adjacency indexes from the current edge list.
279    fn rebuild_adjacency(&mut self) {
280        self.adjacency.clear();
281        self.reverse_adjacency.clear();
282
283        // Sort edges by source_id, then target_id
284        self.edges.sort_by(|a, b| {
285            a.source_id
286                .cmp(&b.source_id)
287                .then(a.target_id.cmp(&b.target_id))
288        });
289
290        let mut i = 0;
291        while i < self.edges.len() {
292            let source = self.edges[i].source_id;
293            let start = i;
294            while i < self.edges.len() && self.edges[i].source_id == source {
295                // Build reverse adjacency
296                self.reverse_adjacency
297                    .entry(self.edges[i].target_id)
298                    .or_default()
299                    .push(source);
300                i += 1;
301            }
302            self.adjacency.insert(source, (start, i - start));
303        }
304
305        // Dedup reverse adjacency
306        for list in self.reverse_adjacency.values_mut() {
307            list.sort_unstable();
308            list.dedup();
309        }
310    }
311
312    /// Get the next available node ID (for builder use).
313    pub fn next_id(&self) -> u64 {
314        self.next_id
315    }
316
317    /// Get the type index.
318    pub fn type_index(&self) -> &TypeIndex {
319        &self.type_index
320    }
321
322    /// Get the temporal index.
323    pub fn temporal_index(&self) -> &TemporalIndex {
324        &self.temporal_index
325    }
326
327    /// Get the session index.
328    pub fn session_index(&self) -> &SessionIndex {
329        &self.session_index
330    }
331
332    /// Get the cluster map.
333    pub fn cluster_map(&self) -> &ClusterMap {
334        &self.cluster_map
335    }
336
337    /// Get a mutable reference to the cluster map.
338    pub fn cluster_map_mut(&mut self) -> &mut ClusterMap {
339        &mut self.cluster_map
340    }
341}