1use ahash::HashMap;
4use mentedb_core::edge::{EdgeType, MemoryEdge};
5use mentedb_core::error::{MenteError, MenteResult};
6use mentedb_core::types::{MemoryId, Timestamp};
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11pub struct StoredEdge {
12 pub edge_type: EdgeType,
14 pub weight: f32,
16 pub created_at: Timestamp,
18 #[serde(default, skip_serializing_if = "Option::is_none")]
20 pub valid_from: Option<Timestamp>,
21 #[serde(default, skip_serializing_if = "Option::is_none")]
23 pub valid_until: Option<Timestamp>,
24 #[serde(default, skip_serializing_if = "Option::is_none")]
26 pub label: Option<String>,
27}
28
29impl StoredEdge {
30 pub fn from_memory_edge(edge: &MemoryEdge) -> Self {
32 Self {
33 edge_type: edge.edge_type,
34 weight: edge.weight,
35 created_at: edge.created_at,
36 valid_from: edge.valid_from,
37 valid_until: edge.valid_until,
38 label: edge.label.clone(),
39 }
40 }
41
42 pub fn is_valid_at(&self, at: Timestamp) -> bool {
44 let from = self.valid_from.unwrap_or(0);
45 match self.valid_until {
46 Some(until) => at >= from && at < until,
47 None => at >= from,
48 }
49 }
50
51 pub fn is_invalidated(&self) -> bool {
53 self.valid_until.is_some()
54 }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59struct DeltaEdge {
60 source_idx: u32,
61 target_idx: u32,
62 data: StoredEdge,
63}
64
65#[derive(Debug, Clone, Default, Serialize, Deserialize)]
67struct CompressedStorage {
68 row_offsets: Vec<u32>,
70 col_indices: Vec<u32>,
72 edge_data: Vec<StoredEdge>,
74}
75
76impl CompressedStorage {
77 #[allow(dead_code)]
78 fn new(num_nodes: usize) -> Self {
79 Self {
80 row_offsets: vec![0; num_nodes + 1],
81 col_indices: Vec::new(),
82 edge_data: Vec::new(),
83 }
84 }
85
86 fn neighbors(&self, row: u32) -> &[u32] {
88 let row = row as usize;
89 if row + 1 >= self.row_offsets.len() {
90 return &[];
91 }
92 let start = self.row_offsets[row] as usize;
93 let end = self.row_offsets[row + 1] as usize;
94 &self.col_indices[start..end]
95 }
96
97 fn edge_data_for(&self, row: u32) -> &[StoredEdge] {
98 let row = row as usize;
99 if row + 1 >= self.row_offsets.len() {
100 return &[];
101 }
102 let start = self.row_offsets[row] as usize;
103 let end = self.row_offsets[row + 1] as usize;
104 &self.edge_data[start..end]
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct CsrGraph {
111 id_to_idx: HashMap<MemoryId, u32>,
113 idx_to_id: Vec<MemoryId>,
115
116 csr: CompressedStorage,
118 csc: CompressedStorage,
120
121 delta_edges: Vec<DeltaEdge>,
123 removed_edges: Vec<(u32, u32)>,
125}
126
127impl CsrGraph {
128 pub fn new() -> Self {
130 Self {
131 id_to_idx: HashMap::default(),
132 idx_to_id: Vec::new(),
133 csr: CompressedStorage::default(),
134 csc: CompressedStorage::default(),
135 delta_edges: Vec::new(),
136 removed_edges: Vec::new(),
137 }
138 }
139
140 pub fn add_node(&mut self, id: MemoryId) -> u32 {
142 if let Some(&idx) = self.id_to_idx.get(&id) {
143 return idx;
144 }
145 let idx = self.idx_to_id.len() as u32;
146 self.id_to_idx.insert(id, idx);
147 self.idx_to_id.push(id);
148 idx
149 }
150
151 pub fn remove_node(&mut self, id: MemoryId) {
153 let Some(&idx) = self.id_to_idx.get(&id) else {
154 return;
155 };
156 for &neighbor in self.csr.neighbors(idx) {
158 self.removed_edges.push((idx, neighbor));
159 }
160 for &neighbor in self.csc.neighbors(idx) {
161 self.removed_edges.push((neighbor, idx));
162 }
163 self.delta_edges
165 .retain(|e| e.source_idx != idx && e.target_idx != idx);
166 }
167
168 pub fn add_edge(&mut self, edge: &MemoryEdge) {
170 let source_idx = self.add_node(edge.source);
171 let target_idx = self.add_node(edge.target);
172 self.delta_edges.push(DeltaEdge {
173 source_idx,
174 target_idx,
175 data: StoredEdge::from_memory_edge(edge),
176 });
177 }
178
179 pub fn strengthen_edge(&mut self, source: MemoryId, target: MemoryId, delta: f32) {
182 if let Some(existing) = self
184 .outgoing(source)
185 .into_iter()
186 .find(|(id, _)| *id == target)
187 {
188 let (_, stored) = existing;
189 let new_weight = (stored.weight + delta).min(1.0);
190 let source_idx = self.add_node(source);
191 let target_idx = self.add_node(target);
192 self.delta_edges.push(DeltaEdge {
193 source_idx,
194 target_idx,
195 data: StoredEdge {
196 weight: new_weight,
197 ..stored
198 },
199 });
200 }
201 }
202
203 pub fn remove_edge(&mut self, source: MemoryId, target: MemoryId) {
205 let (Some(&src_idx), Some(&tgt_idx)) =
206 (self.id_to_idx.get(&source), self.id_to_idx.get(&target))
207 else {
208 return;
209 };
210 self.removed_edges.push((src_idx, tgt_idx));
211 self.delta_edges
212 .retain(|e| !(e.source_idx == src_idx && e.target_idx == tgt_idx));
213 }
214
215 pub fn outgoing(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
217 let Some(&idx) = self.id_to_idx.get(&id) else {
218 return Vec::new();
219 };
220 self.outgoing_by_idx(idx)
221 }
222
223 pub fn outgoing_valid_at(&self, id: MemoryId, at: Timestamp) -> Vec<(MemoryId, StoredEdge)> {
225 self.outgoing(id)
226 .into_iter()
227 .filter(|(_, e)| e.is_valid_at(at))
228 .collect()
229 }
230
231 pub(crate) fn outgoing_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
232 let mut results = Vec::new();
233
234 let neighbors = self.csr.neighbors(idx);
236 let edges = self.csr.edge_data_for(idx);
237 for (i, &neighbor) in neighbors.iter().enumerate() {
238 if !self.is_removed(idx, neighbor)
239 && let Some(&id) = self.idx_to_id.get(neighbor as usize)
240 {
241 results.push((id, edges[i].clone()));
242 }
243 }
244
245 for delta in &self.delta_edges {
247 if delta.source_idx == idx
248 && let Some(&id) = self.idx_to_id.get(delta.target_idx as usize)
249 {
250 results.push((id, delta.data.clone()));
251 }
252 }
253
254 results
255 }
256
257 pub fn incoming(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
259 let Some(&idx) = self.id_to_idx.get(&id) else {
260 return Vec::new();
261 };
262 self.incoming_by_idx(idx)
263 }
264
265 pub fn incoming_valid_at(&self, id: MemoryId, at: Timestamp) -> Vec<(MemoryId, StoredEdge)> {
267 self.incoming(id)
268 .into_iter()
269 .filter(|(_, e)| e.is_valid_at(at))
270 .collect()
271 }
272
273 pub(crate) fn incoming_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
274 let mut results = Vec::new();
275
276 let neighbors = self.csc.neighbors(idx);
278 let edges = self.csc.edge_data_for(idx);
279 for (i, &neighbor) in neighbors.iter().enumerate() {
280 if !self.is_removed(neighbor, idx)
281 && let Some(&id) = self.idx_to_id.get(neighbor as usize)
282 {
283 results.push((id, edges[i].clone()));
284 }
285 }
286
287 for delta in &self.delta_edges {
289 if delta.target_idx == idx
290 && let Some(&id) = self.idx_to_id.get(delta.source_idx as usize)
291 {
292 results.push((id, delta.data.clone()));
293 }
294 }
295
296 results
297 }
298
299 pub fn contains_node(&self, id: MemoryId) -> bool {
301 self.id_to_idx.contains_key(&id)
302 }
303
304 pub fn node_count(&self) -> usize {
306 self.idx_to_id.len()
307 }
308
309 pub(crate) fn get_idx(&self, id: MemoryId) -> Option<u32> {
311 self.id_to_idx.get(&id).copied()
312 }
313
314 #[allow(dead_code)]
316 pub(crate) fn get_id(&self, idx: u32) -> Option<MemoryId> {
317 self.idx_to_id.get(idx as usize).copied()
318 }
319
320 pub fn node_ids(&self) -> &[MemoryId] {
322 &self.idx_to_id
323 }
324
325 fn is_removed(&self, source: u32, target: u32) -> bool {
326 self.removed_edges
327 .iter()
328 .any(|&(s, t)| s == source && t == target)
329 }
330
331 pub fn compact(&mut self) {
333 let num_nodes = self.idx_to_id.len();
334
335 let mut all_edges: Vec<(u32, u32, StoredEdge)> = Vec::new();
337
338 for row in 0..num_nodes {
340 let row = row as u32;
341 let neighbors = self.csr.neighbors(row);
342 let edges = self.csr.edge_data_for(row);
343 for (i, &col) in neighbors.iter().enumerate() {
344 if !self.is_removed(row, col) {
345 all_edges.push((row, col, edges[i].clone()));
346 }
347 }
348 }
349
350 for delta in &self.delta_edges {
352 all_edges.push((delta.source_idx, delta.target_idx, delta.data.clone()));
353 }
354
355 self.csr = Self::build_compressed(&all_edges, num_nodes, false);
357
358 self.csc = Self::build_compressed(&all_edges, num_nodes, true);
360
361 self.delta_edges.clear();
362 self.removed_edges.clear();
363 }
364
365 fn build_compressed(
366 edges: &[(u32, u32, StoredEdge)],
367 num_nodes: usize,
368 transpose: bool,
369 ) -> CompressedStorage {
370 let mut counts = vec![0u32; num_nodes];
372 for &(src, tgt, ref _data) in edges {
373 let row = if transpose { tgt } else { src };
374 if (row as usize) < num_nodes {
375 counts[row as usize] += 1;
376 }
377 }
378
379 let mut row_offsets = vec![0u32; num_nodes + 1];
381 for i in 0..num_nodes {
382 row_offsets[i + 1] = row_offsets[i] + counts[i];
383 }
384
385 let total = row_offsets[num_nodes] as usize;
386 let mut col_indices = vec![0u32; total];
387 let mut edge_data = vec![
388 StoredEdge {
389 edge_type: EdgeType::Related,
390 weight: 0.0,
391 created_at: 0,
392 valid_from: None,
393 valid_until: None,
394 label: None,
395 };
396 total
397 ];
398
399 let mut cursors = row_offsets[..num_nodes].to_vec();
401 for &(src, tgt, ref data) in edges {
402 let (row, col) = if transpose { (tgt, src) } else { (src, tgt) };
403 if (row as usize) < num_nodes {
404 let pos = cursors[row as usize] as usize;
405 col_indices[pos] = col;
406 edge_data[pos] = data.clone();
407 cursors[row as usize] += 1;
408 }
409 }
410
411 CompressedStorage {
412 row_offsets,
413 col_indices,
414 edge_data,
415 }
416 }
417 pub fn save(&self, path: &std::path::Path) -> MenteResult<()> {
419 let data =
420 serde_json::to_vec(self).map_err(|e| MenteError::Serialization(e.to_string()))?;
421 std::fs::write(path, data)?;
422 Ok(())
423 }
424
425 pub fn load(path: &std::path::Path) -> MenteResult<Self> {
427 let data = std::fs::read(path)?;
428 let graph: Self =
429 serde_json::from_slice(&data).map_err(|e| MenteError::Serialization(e.to_string()))?;
430 Ok(graph)
431 }
432}
433
434impl Default for CsrGraph {
435 fn default() -> Self {
436 Self::new()
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an edge of the given type with weight 0.8, `created_at` 1000 and no validity
    /// window or label.
    fn sample_edge(source: MemoryId, target: MemoryId, edge_type: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source,
            target,
            edge_type,
            weight: 0.8,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
            label: None,
        }
    }

    #[test]
    fn test_add_node_idempotent() {
        let mut graph = CsrGraph::new();
        let node = MemoryId::new();

        let first = graph.add_node(node);
        let again = graph.add_node(node);

        assert_eq!(first, again);
        assert_eq!(graph.node_count(), 1);
    }

    #[test]
    fn test_add_and_query_edges() {
        let mut graph = CsrGraph::new();
        let hub = MemoryId::new();
        let left = MemoryId::new();
        let right = MemoryId::new();

        graph.add_edge(&sample_edge(hub, left, EdgeType::Caused));
        graph.add_edge(&sample_edge(hub, right, EdgeType::Related));

        assert_eq!(graph.outgoing(hub).len(), 2);

        let into_left = graph.incoming(left);
        assert_eq!(into_left.len(), 1);
        assert_eq!(into_left[0].0, hub);
    }

    #[test]
    fn test_remove_edge() {
        let mut graph = CsrGraph::new();
        let from = MemoryId::new();
        let to = MemoryId::new();

        graph.add_edge(&sample_edge(from, to, EdgeType::Caused));
        assert_eq!(graph.outgoing(from).len(), 1);

        graph.remove_edge(from, to);
        assert_eq!(graph.outgoing(from).len(), 0);
    }

    #[test]
    fn test_compact() {
        let mut graph = CsrGraph::new();
        let head = MemoryId::new();
        let middle = MemoryId::new();
        let tail = MemoryId::new();

        graph.add_edge(&sample_edge(head, middle, EdgeType::Caused));
        graph.add_edge(&sample_edge(middle, tail, EdgeType::Before));
        graph.compact();

        let from_head = graph.outgoing(head);
        assert_eq!(from_head.len(), 1);
        assert_eq!(from_head[0].0, middle);

        let into_tail = graph.incoming(tail);
        assert_eq!(into_tail.len(), 1);
        assert_eq!(into_tail[0].0, middle);
    }

    #[test]
    fn test_compact_with_removals() {
        let mut graph = CsrGraph::new();
        let hub = MemoryId::new();
        let dropped = MemoryId::new();
        let kept = MemoryId::new();

        graph.add_edge(&sample_edge(hub, dropped, EdgeType::Caused));
        graph.add_edge(&sample_edge(hub, kept, EdgeType::Related));
        graph.compact();

        graph.remove_edge(hub, dropped);
        graph.compact();

        let remaining = graph.outgoing(hub);
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].0, kept);
    }
}