use ahash::{HashMap, HashSet};
use mentedb_core::edge::{EdgeType, MemoryEdge};
use mentedb_core::error::{MenteError, MenteResult};
use mentedb_core::types::{MemoryId, Timestamp};
use serde::{Deserialize, Serialize};
8
/// Per-edge payload kept alongside each adjacency entry.
///
/// The endpoints are not stored here: they are implied by where the payload
/// lives (a `DeltaEdge`'s indices, or a row/column slot in compressed storage).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct StoredEdge {
    /// Relationship kind, copied from the originating [`MemoryEdge`].
    pub edge_type: EdgeType,
    /// Edge weight, copied from the originating [`MemoryEdge`].
    pub weight: f32,
    /// Creation timestamp, copied from the originating [`MemoryEdge`].
    pub created_at: Timestamp,
}
19
20impl StoredEdge {
21 pub fn from_memory_edge(edge: &MemoryEdge) -> Self {
23 Self {
24 edge_type: edge.edge_type,
25 weight: edge.weight,
26 created_at: edge.created_at,
27 }
28 }
29}
30
/// An edge buffered since the last compaction, addressed by internal node
/// indices rather than by [`MemoryId`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DeltaEdge {
    /// Internal index of the source node.
    source_idx: u32,
    /// Internal index of the target node.
    target_idx: u32,
    /// Type, weight and timestamp of the edge.
    data: StoredEdge,
}
38
/// Compressed sparse adjacency arrays.
///
/// The entries of row `r` occupy positions `row_offsets[r]..row_offsets[r + 1]`
/// of both `col_indices` and `edge_data`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct CompressedStorage {
    /// Running totals of per-row entry counts (one more element than rows);
    /// empty in the `Default` value.
    row_offsets: Vec<u32>,
    /// Neighbor (column) index of each entry, grouped by row.
    col_indices: Vec<u32>,
    /// Payload of each entry, parallel to `col_indices`.
    edge_data: Vec<StoredEdge>,
}
49
50impl CompressedStorage {
51 #[allow(dead_code)]
52 fn new(num_nodes: usize) -> Self {
53 Self {
54 row_offsets: vec![0; num_nodes + 1],
55 col_indices: Vec::new(),
56 edge_data: Vec::new(),
57 }
58 }
59
60 fn neighbors(&self, row: u32) -> &[u32] {
62 let row = row as usize;
63 if row + 1 >= self.row_offsets.len() {
64 return &[];
65 }
66 let start = self.row_offsets[row] as usize;
67 let end = self.row_offsets[row + 1] as usize;
68 &self.col_indices[start..end]
69 }
70
71 fn edge_data_for(&self, row: u32) -> &[StoredEdge] {
72 let row = row as usize;
73 if row + 1 >= self.row_offsets.len() {
74 return &[];
75 }
76 let start = self.row_offsets[row] as usize;
77 let end = self.row_offsets[row + 1] as usize;
78 &self.edge_data[start..end]
79 }
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct CsrGraph {
85 id_to_idx: HashMap<MemoryId, u32>,
87 idx_to_id: Vec<MemoryId>,
89
90 csr: CompressedStorage,
92 csc: CompressedStorage,
94
95 delta_edges: Vec<DeltaEdge>,
97 removed_edges: Vec<(u32, u32)>,
99}
100
101impl CsrGraph {
102 pub fn new() -> Self {
104 Self {
105 id_to_idx: HashMap::default(),
106 idx_to_id: Vec::new(),
107 csr: CompressedStorage::default(),
108 csc: CompressedStorage::default(),
109 delta_edges: Vec::new(),
110 removed_edges: Vec::new(),
111 }
112 }
113
114 pub fn add_node(&mut self, id: MemoryId) -> u32 {
116 if let Some(&idx) = self.id_to_idx.get(&id) {
117 return idx;
118 }
119 let idx = self.idx_to_id.len() as u32;
120 self.id_to_idx.insert(id, idx);
121 self.idx_to_id.push(id);
122 idx
123 }
124
125 pub fn remove_node(&mut self, id: MemoryId) {
127 let Some(&idx) = self.id_to_idx.get(&id) else {
128 return;
129 };
130 for &neighbor in self.csr.neighbors(idx) {
132 self.removed_edges.push((idx, neighbor));
133 }
134 for &neighbor in self.csc.neighbors(idx) {
135 self.removed_edges.push((neighbor, idx));
136 }
137 self.delta_edges
139 .retain(|e| e.source_idx != idx && e.target_idx != idx);
140 }
141
142 pub fn add_edge(&mut self, edge: &MemoryEdge) {
144 let source_idx = self.add_node(edge.source);
145 let target_idx = self.add_node(edge.target);
146 self.delta_edges.push(DeltaEdge {
147 source_idx,
148 target_idx,
149 data: StoredEdge::from_memory_edge(edge),
150 });
151 }
152
153 pub fn remove_edge(&mut self, source: MemoryId, target: MemoryId) {
155 let (Some(&src_idx), Some(&tgt_idx)) =
156 (self.id_to_idx.get(&source), self.id_to_idx.get(&target))
157 else {
158 return;
159 };
160 self.removed_edges.push((src_idx, tgt_idx));
161 self.delta_edges
162 .retain(|e| !(e.source_idx == src_idx && e.target_idx == tgt_idx));
163 }
164
165 pub fn outgoing(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
167 let Some(&idx) = self.id_to_idx.get(&id) else {
168 return Vec::new();
169 };
170 self.outgoing_by_idx(idx)
171 }
172
173 pub(crate) fn outgoing_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
174 let mut results = Vec::new();
175
176 let neighbors = self.csr.neighbors(idx);
178 let edges = self.csr.edge_data_for(idx);
179 for (i, &neighbor) in neighbors.iter().enumerate() {
180 if !self.is_removed(idx, neighbor)
181 && let Some(&id) = self.idx_to_id.get(neighbor as usize)
182 {
183 results.push((id, edges[i]));
184 }
185 }
186
187 for delta in &self.delta_edges {
189 if delta.source_idx == idx
190 && let Some(&id) = self.idx_to_id.get(delta.target_idx as usize)
191 {
192 results.push((id, delta.data));
193 }
194 }
195
196 results
197 }
198
199 pub fn incoming(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
201 let Some(&idx) = self.id_to_idx.get(&id) else {
202 return Vec::new();
203 };
204 self.incoming_by_idx(idx)
205 }
206
207 pub(crate) fn incoming_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
208 let mut results = Vec::new();
209
210 let neighbors = self.csc.neighbors(idx);
212 let edges = self.csc.edge_data_for(idx);
213 for (i, &neighbor) in neighbors.iter().enumerate() {
214 if !self.is_removed(neighbor, idx)
215 && let Some(&id) = self.idx_to_id.get(neighbor as usize)
216 {
217 results.push((id, edges[i]));
218 }
219 }
220
221 for delta in &self.delta_edges {
223 if delta.target_idx == idx
224 && let Some(&id) = self.idx_to_id.get(delta.source_idx as usize)
225 {
226 results.push((id, delta.data));
227 }
228 }
229
230 results
231 }
232
233 pub fn contains_node(&self, id: MemoryId) -> bool {
235 self.id_to_idx.contains_key(&id)
236 }
237
238 pub fn node_count(&self) -> usize {
240 self.idx_to_id.len()
241 }
242
243 pub(crate) fn get_idx(&self, id: MemoryId) -> Option<u32> {
245 self.id_to_idx.get(&id).copied()
246 }
247
248 #[allow(dead_code)]
250 pub(crate) fn get_id(&self, idx: u32) -> Option<MemoryId> {
251 self.idx_to_id.get(idx as usize).copied()
252 }
253
254 pub fn node_ids(&self) -> &[MemoryId] {
256 &self.idx_to_id
257 }
258
259 fn is_removed(&self, source: u32, target: u32) -> bool {
260 self.removed_edges
261 .iter()
262 .any(|&(s, t)| s == source && t == target)
263 }
264
265 pub fn compact(&mut self) {
267 let num_nodes = self.idx_to_id.len();
268
269 let mut all_edges: Vec<(u32, u32, StoredEdge)> = Vec::new();
271
272 for row in 0..num_nodes {
274 let row = row as u32;
275 let neighbors = self.csr.neighbors(row);
276 let edges = self.csr.edge_data_for(row);
277 for (i, &col) in neighbors.iter().enumerate() {
278 if !self.is_removed(row, col) {
279 all_edges.push((row, col, edges[i]));
280 }
281 }
282 }
283
284 for delta in &self.delta_edges {
286 all_edges.push((delta.source_idx, delta.target_idx, delta.data));
287 }
288
289 self.csr = Self::build_compressed(&all_edges, num_nodes, false);
291
292 self.csc = Self::build_compressed(&all_edges, num_nodes, true);
294
295 self.delta_edges.clear();
296 self.removed_edges.clear();
297 }
298
299 fn build_compressed(
300 edges: &[(u32, u32, StoredEdge)],
301 num_nodes: usize,
302 transpose: bool,
303 ) -> CompressedStorage {
304 let mut counts = vec![0u32; num_nodes];
306 for &(src, tgt, _) in edges {
307 let row = if transpose { tgt } else { src };
308 if (row as usize) < num_nodes {
309 counts[row as usize] += 1;
310 }
311 }
312
313 let mut row_offsets = vec![0u32; num_nodes + 1];
315 for i in 0..num_nodes {
316 row_offsets[i + 1] = row_offsets[i] + counts[i];
317 }
318
319 let total = row_offsets[num_nodes] as usize;
320 let mut col_indices = vec![0u32; total];
321 let mut edge_data = vec![
322 StoredEdge {
323 edge_type: EdgeType::Related,
324 weight: 0.0,
325 created_at: 0,
326 };
327 total
328 ];
329
330 let mut cursors = row_offsets[..num_nodes].to_vec();
332 for &(src, tgt, data) in edges {
333 let (row, col) = if transpose { (tgt, src) } else { (src, tgt) };
334 if (row as usize) < num_nodes {
335 let pos = cursors[row as usize] as usize;
336 col_indices[pos] = col;
337 edge_data[pos] = data;
338 cursors[row as usize] += 1;
339 }
340 }
341
342 CompressedStorage {
343 row_offsets,
344 col_indices,
345 edge_data,
346 }
347 }
348 pub fn save(&self, path: &std::path::Path) -> MenteResult<()> {
350 let data =
351 serde_json::to_vec(self).map_err(|e| MenteError::Serialization(e.to_string()))?;
352 std::fs::write(path, data)?;
353 Ok(())
354 }
355
356 pub fn load(path: &std::path::Path) -> MenteResult<Self> {
358 let data = std::fs::read(path)?;
359 let graph: Self =
360 serde_json::from_slice(&data).map_err(|e| MenteError::Serialization(e.to_string()))?;
361 Ok(graph)
362 }
363}
364
365impl Default for CsrGraph {
366 fn default() -> Self {
367 Self::new()
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an edge with fixed weight and timestamp for test fixtures.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 0.8,
            created_at: 1000,
        }
    }

    #[test]
    fn test_add_node_idempotent() {
        let mut g = CsrGraph::new();
        let id = MemoryId::new();
        let idx1 = g.add_node(id);
        let idx2 = g.add_node(id);
        assert_eq!(idx1, idx2);
        assert_eq!(g.node_count(), 1);
    }

    #[test]
    fn test_add_and_query_edges() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(a, c, EdgeType::Related));

        let out = g.outgoing(a);
        assert_eq!(out.len(), 2);

        let inc_b = g.incoming(b);
        assert_eq!(inc_b.len(), 1);
        assert_eq!(inc_b[0].0, a);
    }

    #[test]
    fn test_remove_edge() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        assert_eq!(g.outgoing(a).len(), 1);

        g.remove_edge(a, b);
        assert_eq!(g.outgoing(a).len(), 0);
    }

    #[test]
    fn test_compact() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(b, c, EdgeType::Before));
        g.compact();

        let out_a = g.outgoing(a);
        assert_eq!(out_a.len(), 1);
        assert_eq!(out_a[0].0, b);

        let inc_c = g.incoming(c);
        assert_eq!(inc_c.len(), 1);
        assert_eq!(inc_c[0].0, b);
    }

    #[test]
    fn test_compact_with_removals() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(a, c, EdgeType::Related));
        g.compact();

        g.remove_edge(a, b);
        g.compact();

        let out = g.outgoing(a);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].0, c);
    }

    #[test]
    fn test_remove_node_hides_compacted_and_buffered_edges() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();
        let d = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(b, c, EdgeType::Before));
        g.compact();
        // Buffered (not yet compacted) edge into the node being removed.
        g.add_edge(&make_edge(d, b, EdgeType::Related));

        g.remove_node(b);
        assert!(g.outgoing(a).is_empty());
        assert!(g.incoming(c).is_empty());
        assert!(g.outgoing(d).is_empty());
        assert!(g.outgoing(b).is_empty());
        assert!(g.incoming(b).is_empty());

        g.compact();
        assert!(g.outgoing(a).is_empty());
        assert!(g.incoming(c).is_empty());
        assert!(g.incoming(b).is_empty());
    }

    #[test]
    fn test_readd_edge_after_repeated_removal() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.compact();
        g.remove_edge(a, b);
        g.remove_edge(a, b);
        g.add_edge(&make_edge(a, b, EdgeType::Related));

        let out = g.outgoing(a);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].1.edge_type, EdgeType::Related);

        g.compact();
        let inc = g.incoming(b);
        assert_eq!(inc.len(), 1);
        assert_eq!(inc[0].0, a);
        assert_eq!(inc[0].1.edge_type, EdgeType::Related);
    }

    #[test]
    fn test_save_load_roundtrip() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.compact();
        // Left uncompacted so the buffer is persisted too.
        g.add_edge(&make_edge(b, c, EdgeType::Before));

        let path = std::env::temp_dir().join(format!(
            "csr_graph_roundtrip_{}.json",
            std::process::id()
        ));
        g.save(&path).expect("save graph");
        let loaded = CsrGraph::load(&path);
        // Best-effort cleanup before asserting, so a failure does not leak the file.
        let _ = std::fs::remove_file(&path);
        let loaded = loaded.expect("load graph");

        assert_eq!(loaded.node_ids(), g.node_ids());
        assert_eq!(loaded.outgoing(a), g.outgoing(a));
        assert_eq!(loaded.outgoing(b), g.outgoing(b));
        assert_eq!(loaded.incoming(c), g.incoming(c));
    }
}