1use ahash::HashMap;
4use mentedb_core::edge::{EdgeType, MemoryEdge};
5use mentedb_core::error::{MenteError, MenteResult};
6use mentedb_core::types::{MemoryId, Timestamp};
7use serde::{Deserialize, Serialize};
8
/// Per-edge payload kept by the graph, without the endpoint ids.
///
/// The fields mirror what [`StoredEdge::from_memory_edge`] copies out of a `MemoryEdge`.
/// The three optional fields are omitted from serialized output when `None` and default
/// to `None` when absent on input.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StoredEdge {
    /// Kind of relationship this edge represents.
    pub edge_type: EdgeType,
    /// Edge weight.
    pub weight: f32,
    /// Timestamp recorded for the creation of the edge.
    pub created_at: Timestamp,
    /// Inclusive start of the validity window; `None` is treated as 0 by `is_valid_at`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub valid_from: Option<Timestamp>,
    /// Exclusive end of the validity window; `None` means the window is open-ended.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub valid_until: Option<Timestamp>,
    /// Optional label attached to the edge.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub label: Option<String>,
}
28
29impl StoredEdge {
30 pub fn from_memory_edge(edge: &MemoryEdge) -> Self {
32 Self {
33 edge_type: edge.edge_type,
34 weight: edge.weight,
35 created_at: edge.created_at,
36 valid_from: edge.valid_from,
37 valid_until: edge.valid_until,
38 label: edge.label.clone(),
39 }
40 }
41
42 pub fn is_valid_at(&self, at: Timestamp) -> bool {
44 let from = self.valid_from.unwrap_or(0);
45 match self.valid_until {
46 Some(until) => at >= from && at < until,
47 None => at >= from,
48 }
49 }
50
51 pub fn is_invalidated(&self) -> bool {
53 self.valid_until.is_some()
54 }
55}
56
/// An edge held in the delta log, i.e. added since the last `CsrGraph::compact`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DeltaEdge {
    /// Dense index of the source node.
    source_idx: u32,
    /// Dense index of the target node.
    target_idx: u32,
    /// Payload of the edge.
    data: StoredEdge,
}
64
/// Compressed row storage: three parallel arrays describing the entries of each row.
///
/// The entries of row `r` occupy positions `row_offsets[r]..row_offsets[r + 1]` of both
/// `col_indices` and `edge_data`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct CompressedStorage {
    /// Start position of every row, plus one trailing end position.
    row_offsets: Vec<u32>,
    /// Column (neighbor) index of every entry, grouped by row.
    col_indices: Vec<u32>,
    /// Edge payload of every entry, parallel to `col_indices`.
    edge_data: Vec<StoredEdge>,
}
75
76impl CompressedStorage {
77 #[allow(dead_code)]
78 fn new(num_nodes: usize) -> Self {
79 Self {
80 row_offsets: vec![0; num_nodes + 1],
81 col_indices: Vec::new(),
82 edge_data: Vec::new(),
83 }
84 }
85
86 fn neighbors(&self, row: u32) -> &[u32] {
88 let row = row as usize;
89 if row + 1 >= self.row_offsets.len() {
90 return &[];
91 }
92 let start = self.row_offsets[row] as usize;
93 let end = self.row_offsets[row + 1] as usize;
94 &self.col_indices[start..end]
95 }
96
97 fn edge_data_for(&self, row: u32) -> &[StoredEdge] {
98 let row = row as usize;
99 if row + 1 >= self.row_offsets.len() {
100 return &[];
101 }
102 let start = self.row_offsets[row] as usize;
103 let end = self.row_offsets[row + 1] as usize;
104 &self.edge_data[start..end]
105 }
106}
107
/// Directed graph stored as compacted adjacency arrays plus a log of pending changes.
///
/// Reads combine the compacted arrays (minus tombstoned pairs) with the pending edges;
/// `compact` folds the pending changes back into the arrays.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CsrGraph {
    /// Maps a node id to its dense index; `remove_node` deletes the entry.
    id_to_idx: HashMap<MemoryId, u32>,
    /// Maps a dense index back to its node id; `add_node` appends to it.
    idx_to_id: Vec<MemoryId>,

    /// Compacted edges grouped by source index (outgoing adjacency).
    csr: CompressedStorage,
    /// Compacted edges grouped by target index (incoming adjacency).
    csc: CompressedStorage,

    /// Edges added since the last `compact`.
    delta_edges: Vec<DeltaEdge>,
    /// `(source, target)` index pairs whose compacted edges are hidden until the next
    /// `compact`.
    removed_edges: Vec<(u32, u32)>,
}
126
127impl CsrGraph {
128 pub fn new() -> Self {
130 Self {
131 id_to_idx: HashMap::default(),
132 idx_to_id: Vec::new(),
133 csr: CompressedStorage::default(),
134 csc: CompressedStorage::default(),
135 delta_edges: Vec::new(),
136 removed_edges: Vec::new(),
137 }
138 }
139
140 pub fn add_node(&mut self, id: MemoryId) -> u32 {
142 if let Some(&idx) = self.id_to_idx.get(&id) {
143 return idx;
144 }
145 let idx = self.idx_to_id.len() as u32;
146 self.id_to_idx.insert(id, idx);
147 self.idx_to_id.push(id);
148 idx
149 }
150
151 pub fn remove_node(&mut self, id: MemoryId) {
153 let Some(&idx) = self.id_to_idx.get(&id) else {
154 return;
155 };
156 for &neighbor in self.csr.neighbors(idx) {
158 self.removed_edges.push((idx, neighbor));
159 }
160 for &neighbor in self.csc.neighbors(idx) {
161 self.removed_edges.push((neighbor, idx));
162 }
163 self.delta_edges
165 .retain(|e| e.source_idx != idx && e.target_idx != idx);
166 self.id_to_idx.remove(&id);
167 }
168
169 pub fn add_edge(&mut self, edge: &MemoryEdge) {
171 let source_idx = self.add_node(edge.source);
172 let target_idx = self.add_node(edge.target);
173 self.delta_edges.push(DeltaEdge {
174 source_idx,
175 target_idx,
176 data: StoredEdge::from_memory_edge(edge),
177 });
178 }
179
180 pub fn strengthen_edge(&mut self, source: MemoryId, target: MemoryId, delta: f32) {
183 if let Some(existing) = self
185 .outgoing(source)
186 .into_iter()
187 .find(|(id, _)| *id == target)
188 {
189 let (_, stored) = existing;
190 let new_weight = (stored.weight + delta).min(1.0);
191 let source_idx = self.add_node(source);
192 let target_idx = self.add_node(target);
193 self.delta_edges.push(DeltaEdge {
194 source_idx,
195 target_idx,
196 data: StoredEdge {
197 weight: new_weight,
198 ..stored
199 },
200 });
201 }
202 }
203
204 pub fn remove_edge(&mut self, source: MemoryId, target: MemoryId) {
206 let (Some(&src_idx), Some(&tgt_idx)) =
207 (self.id_to_idx.get(&source), self.id_to_idx.get(&target))
208 else {
209 return;
210 };
211 self.removed_edges.push((src_idx, tgt_idx));
212 self.delta_edges
213 .retain(|e| !(e.source_idx == src_idx && e.target_idx == tgt_idx));
214 }
215
216 pub fn outgoing(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
218 let Some(&idx) = self.id_to_idx.get(&id) else {
219 return Vec::new();
220 };
221 self.outgoing_by_idx(idx)
222 }
223
224 pub fn outgoing_valid_at(&self, id: MemoryId, at: Timestamp) -> Vec<(MemoryId, StoredEdge)> {
226 self.outgoing(id)
227 .into_iter()
228 .filter(|(_, e)| e.is_valid_at(at))
229 .collect()
230 }
231
232 pub(crate) fn outgoing_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
233 let mut results = Vec::new();
234
235 let neighbors = self.csr.neighbors(idx);
237 let edges = self.csr.edge_data_for(idx);
238 for (i, &neighbor) in neighbors.iter().enumerate() {
239 if !self.is_removed(idx, neighbor)
240 && let Some(&id) = self.idx_to_id.get(neighbor as usize)
241 {
242 results.push((id, edges[i].clone()));
243 }
244 }
245
246 for delta in &self.delta_edges {
248 if delta.source_idx == idx
249 && let Some(&id) = self.idx_to_id.get(delta.target_idx as usize)
250 {
251 results.push((id, delta.data.clone()));
252 }
253 }
254
255 results
256 }
257
258 pub fn incoming(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
260 let Some(&idx) = self.id_to_idx.get(&id) else {
261 return Vec::new();
262 };
263 self.incoming_by_idx(idx)
264 }
265
266 pub fn incoming_valid_at(&self, id: MemoryId, at: Timestamp) -> Vec<(MemoryId, StoredEdge)> {
268 self.incoming(id)
269 .into_iter()
270 .filter(|(_, e)| e.is_valid_at(at))
271 .collect()
272 }
273
274 pub(crate) fn incoming_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
275 let mut results = Vec::new();
276
277 let neighbors = self.csc.neighbors(idx);
279 let edges = self.csc.edge_data_for(idx);
280 for (i, &neighbor) in neighbors.iter().enumerate() {
281 if !self.is_removed(neighbor, idx)
282 && let Some(&id) = self.idx_to_id.get(neighbor as usize)
283 {
284 results.push((id, edges[i].clone()));
285 }
286 }
287
288 for delta in &self.delta_edges {
290 if delta.target_idx == idx
291 && let Some(&id) = self.idx_to_id.get(delta.source_idx as usize)
292 {
293 results.push((id, delta.data.clone()));
294 }
295 }
296
297 results
298 }
299
300 pub fn contains_node(&self, id: MemoryId) -> bool {
302 self.id_to_idx.contains_key(&id)
303 }
304
305 pub fn node_count(&self) -> usize {
307 self.idx_to_id.len()
308 }
309
310 pub(crate) fn get_idx(&self, id: MemoryId) -> Option<u32> {
312 self.id_to_idx.get(&id).copied()
313 }
314
315 #[allow(dead_code)]
317 pub(crate) fn get_id(&self, idx: u32) -> Option<MemoryId> {
318 self.idx_to_id.get(idx as usize).copied()
319 }
320
321 pub fn node_ids(&self) -> &[MemoryId] {
323 &self.idx_to_id
324 }
325
326 fn is_removed(&self, source: u32, target: u32) -> bool {
327 self.removed_edges
328 .iter()
329 .any(|&(s, t)| s == source && t == target)
330 }
331
332 pub fn compact(&mut self) {
334 let num_nodes = self.idx_to_id.len();
335
336 let mut all_edges: Vec<(u32, u32, StoredEdge)> = Vec::new();
338
339 for row in 0..num_nodes {
341 let row = row as u32;
342 let neighbors = self.csr.neighbors(row);
343 let edges = self.csr.edge_data_for(row);
344 for (i, &col) in neighbors.iter().enumerate() {
345 if !self.is_removed(row, col) {
346 all_edges.push((row, col, edges[i].clone()));
347 }
348 }
349 }
350
351 for delta in &self.delta_edges {
353 all_edges.push((delta.source_idx, delta.target_idx, delta.data.clone()));
354 }
355
356 self.csr = Self::build_compressed(&all_edges, num_nodes, false);
358
359 self.csc = Self::build_compressed(&all_edges, num_nodes, true);
361
362 self.delta_edges.clear();
363 self.removed_edges.clear();
364 }
365
366 fn build_compressed(
367 edges: &[(u32, u32, StoredEdge)],
368 num_nodes: usize,
369 transpose: bool,
370 ) -> CompressedStorage {
371 let mut counts = vec![0u32; num_nodes];
373 for &(src, tgt, ref _data) in edges {
374 let row = if transpose { tgt } else { src };
375 if (row as usize) < num_nodes {
376 counts[row as usize] += 1;
377 }
378 }
379
380 let mut row_offsets = vec![0u32; num_nodes + 1];
382 for i in 0..num_nodes {
383 row_offsets[i + 1] = row_offsets[i] + counts[i];
384 }
385
386 let total = row_offsets[num_nodes] as usize;
387 let mut col_indices = vec![0u32; total];
388 let mut edge_data = vec![
389 StoredEdge {
390 edge_type: EdgeType::Related,
391 weight: 0.0,
392 created_at: 0,
393 valid_from: None,
394 valid_until: None,
395 label: None,
396 };
397 total
398 ];
399
400 let mut cursors = row_offsets[..num_nodes].to_vec();
402 for &(src, tgt, ref data) in edges {
403 let (row, col) = if transpose { (tgt, src) } else { (src, tgt) };
404 if (row as usize) < num_nodes {
405 let pos = cursors[row as usize] as usize;
406 col_indices[pos] = col;
407 edge_data[pos] = data.clone();
408 cursors[row as usize] += 1;
409 }
410 }
411
412 CompressedStorage {
413 row_offsets,
414 col_indices,
415 edge_data,
416 }
417 }
418 pub fn save(&self, path: &std::path::Path) -> MenteResult<()> {
420 let data =
421 serde_json::to_vec(self).map_err(|e| MenteError::Serialization(e.to_string()))?;
422 std::fs::write(path, data)?;
423 Ok(())
424 }
425
426 pub fn load(path: &std::path::Path) -> MenteResult<Self> {
428 let data = std::fs::read(path)?;
429 let graph: Self =
430 serde_json::from_slice(&data).map_err(|e| MenteError::Serialization(e.to_string()))?;
431 Ok(graph)
432 }
433}
434
435impl Default for CsrGraph {
436 fn default() -> Self {
437 Self::new()
438 }
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an always-valid, unlabeled edge of the given type with fixed weight and
    /// creation time.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 0.8,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
            label: None,
        }
    }

    /// Adding the same id twice yields the same index and a single node.
    #[test]
    fn test_add_node_idempotent() {
        let mut graph = CsrGraph::new();
        let node = MemoryId::new();

        let first = graph.add_node(node);
        let second = graph.add_node(node);

        assert_eq!(first, second);
        assert_eq!(graph.node_count(), 1);
    }

    /// Pending edges are visible from both endpoints before any compaction.
    #[test]
    fn test_add_and_query_edges() {
        let mut graph = CsrGraph::new();
        let (a, b, c) = (MemoryId::new(), MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        graph.add_edge(&make_edge(a, c, EdgeType::Related));

        assert_eq!(graph.outgoing(a).len(), 2);

        let into_b = graph.incoming(b);
        assert_eq!(into_b.len(), 1);
        assert_eq!(into_b[0].0, a);
    }

    /// Removing a pending edge hides it immediately.
    #[test]
    fn test_remove_edge() {
        let mut graph = CsrGraph::new();
        let (a, b) = (MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        assert_eq!(graph.outgoing(a).len(), 1);

        graph.remove_edge(a, b);
        assert_eq!(graph.outgoing(a).len(), 0);
    }

    /// Edges stay queryable in both directions after compaction.
    #[test]
    fn test_compact() {
        let mut graph = CsrGraph::new();
        let (a, b, c) = (MemoryId::new(), MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        graph.add_edge(&make_edge(b, c, EdgeType::Before));
        graph.compact();

        let from_a = graph.outgoing(a);
        assert_eq!(from_a.len(), 1);
        assert_eq!(from_a[0].0, b);

        let into_c = graph.incoming(c);
        assert_eq!(into_c.len(), 1);
        assert_eq!(into_c[0].0, b);
    }

    /// A compacted edge that is removed does not survive the next compaction.
    #[test]
    fn test_compact_with_removals() {
        let mut graph = CsrGraph::new();
        let (a, b, c) = (MemoryId::new(), MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        graph.add_edge(&make_edge(a, c, EdgeType::Related));
        graph.compact();

        graph.remove_edge(a, b);
        graph.compact();

        let from_a = graph.outgoing(a);
        assert_eq!(from_a.len(), 1);
        assert_eq!(from_a[0].0, c);
    }

    /// Removing a node unregisters it and drops its edges, leaving other nodes alone.
    #[test]
    fn test_remove_node_cleans_id_to_idx() {
        let mut graph = CsrGraph::new();
        let (a, b) = (MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        assert!(graph.contains_node(a));
        assert!(graph.contains_node(b));

        graph.remove_node(a);
        assert!(
            !graph.contains_node(a),
            "removed node should not be in id_to_idx"
        );
        assert!(graph.contains_node(b), "unrelated node should still exist");

        // Neither endpoint may still see the edge.
        assert!(graph.outgoing(a).is_empty());
        assert!(graph.incoming(b).is_empty());
    }

    /// A removed node can be added again and only sees its new edges.
    #[test]
    fn test_remove_node_then_readd() {
        let mut graph = CsrGraph::new();
        let (a, b, c) = (MemoryId::new(), MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        graph.remove_node(a);

        // Re-registers `a` under a fresh index.
        graph.add_edge(&make_edge(a, c, EdgeType::Related));
        assert!(graph.contains_node(a));

        let from_a = graph.outgoing(a);
        assert_eq!(from_a.len(), 1);
        assert_eq!(from_a[0].0, c);
    }
}