1use ahash::HashMap;
4use mentedb_core::edge::{EdgeType, MemoryEdge};
5use mentedb_core::error::{MenteError, MenteResult};
6use mentedb_core::types::{MemoryId, Timestamp};
7use serde::{Deserialize, Serialize};
8
/// Per-edge payload kept in the graph. The endpoints are stored separately:
/// positionally in the CSR/CSC arrays, or explicitly in delta entries.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct StoredEdge {
    /// Semantic kind of the relationship (see [`EdgeType`]).
    pub edge_type: EdgeType,
    /// Edge strength / confidence score.
    pub weight: f32,
    /// Creation time of the edge.
    pub created_at: Timestamp,
    /// Start of the validity window. `None` is treated as "valid since
    /// timestamp 0" by `is_valid_at` — NOTE(review): assumes timestamps are
    /// non-negative; confirm `Timestamp` cannot be negative.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub valid_from: Option<Timestamp>,
    /// Exclusive end of the validity window; `None` means never invalidated.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub valid_until: Option<Timestamp>,
}
25
26impl StoredEdge {
27 pub fn from_memory_edge(edge: &MemoryEdge) -> Self {
29 Self {
30 edge_type: edge.edge_type,
31 weight: edge.weight,
32 created_at: edge.created_at,
33 valid_from: edge.valid_from,
34 valid_until: edge.valid_until,
35 }
36 }
37
38 pub fn is_valid_at(&self, at: Timestamp) -> bool {
40 let from = self.valid_from.unwrap_or(0);
41 match self.valid_until {
42 Some(until) => at >= from && at < until,
43 None => at >= from,
44 }
45 }
46
47 pub fn is_invalidated(&self) -> bool {
49 self.valid_until.is_some()
50 }
51}
52
/// An edge added since the last `compact`, held outside the compressed
/// arrays so insertion is O(1); folded into CSR/CSC storage on `compact`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DeltaEdge {
    /// Dense index of the source node.
    source_idx: u32,
    /// Dense index of the target node.
    target_idx: u32,
    /// Edge payload.
    data: StoredEdge,
}
60
/// Compressed sparse row (CSR) adjacency: row `r`'s neighbors live in
/// `col_indices[row_offsets[r]..row_offsets[r + 1]]`, with the matching
/// payloads at the same positions in `edge_data`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct CompressedStorage {
    /// Length `num_nodes + 1`; non-decreasing offsets into the two arrays.
    row_offsets: Vec<u32>,
    /// Neighbor (column) indices, grouped by row.
    col_indices: Vec<u32>,
    /// Edge payloads, parallel to `col_indices`.
    edge_data: Vec<StoredEdge>,
}
71
72impl CompressedStorage {
73 #[allow(dead_code)]
74 fn new(num_nodes: usize) -> Self {
75 Self {
76 row_offsets: vec![0; num_nodes + 1],
77 col_indices: Vec::new(),
78 edge_data: Vec::new(),
79 }
80 }
81
82 fn neighbors(&self, row: u32) -> &[u32] {
84 let row = row as usize;
85 if row + 1 >= self.row_offsets.len() {
86 return &[];
87 }
88 let start = self.row_offsets[row] as usize;
89 let end = self.row_offsets[row + 1] as usize;
90 &self.col_indices[start..end]
91 }
92
93 fn edge_data_for(&self, row: u32) -> &[StoredEdge] {
94 let row = row as usize;
95 if row + 1 >= self.row_offsets.len() {
96 return &[];
97 }
98 let start = self.row_offsets[row] as usize;
99 let end = self.row_offsets[row + 1] as usize;
100 &self.edge_data[start..end]
101 }
102}
103
/// Directed graph over `MemoryId`s backed by compressed sparse storage —
/// CSR for outgoing edges, CSC (the transpose) for incoming — plus unsorted
/// delta/tombstone lists that make mutation cheap; `compact` folds the
/// pending changes back into the compressed arrays.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CsrGraph {
    /// Maps an external id to its dense node index.
    id_to_idx: HashMap<MemoryId, u32>,
    /// Inverse of `id_to_idx`; position = dense index.
    idx_to_id: Vec<MemoryId>,

    /// Outgoing adjacency (row = source index).
    csr: CompressedStorage,
    /// Incoming adjacency (row = target index).
    csc: CompressedStorage,

    /// Edges added since the last `compact`.
    delta_edges: Vec<DeltaEdge>,
    /// Tombstones `(source_idx, target_idx)` masking compressed edges until
    /// the next `compact`.
    removed_edges: Vec<(u32, u32)>,
}
122
123impl CsrGraph {
124 pub fn new() -> Self {
126 Self {
127 id_to_idx: HashMap::default(),
128 idx_to_id: Vec::new(),
129 csr: CompressedStorage::default(),
130 csc: CompressedStorage::default(),
131 delta_edges: Vec::new(),
132 removed_edges: Vec::new(),
133 }
134 }
135
136 pub fn add_node(&mut self, id: MemoryId) -> u32 {
138 if let Some(&idx) = self.id_to_idx.get(&id) {
139 return idx;
140 }
141 let idx = self.idx_to_id.len() as u32;
142 self.id_to_idx.insert(id, idx);
143 self.idx_to_id.push(id);
144 idx
145 }
146
147 pub fn remove_node(&mut self, id: MemoryId) {
149 let Some(&idx) = self.id_to_idx.get(&id) else {
150 return;
151 };
152 for &neighbor in self.csr.neighbors(idx) {
154 self.removed_edges.push((idx, neighbor));
155 }
156 for &neighbor in self.csc.neighbors(idx) {
157 self.removed_edges.push((neighbor, idx));
158 }
159 self.delta_edges
161 .retain(|e| e.source_idx != idx && e.target_idx != idx);
162 }
163
164 pub fn add_edge(&mut self, edge: &MemoryEdge) {
166 let source_idx = self.add_node(edge.source);
167 let target_idx = self.add_node(edge.target);
168 self.delta_edges.push(DeltaEdge {
169 source_idx,
170 target_idx,
171 data: StoredEdge::from_memory_edge(edge),
172 });
173 }
174
175 pub fn remove_edge(&mut self, source: MemoryId, target: MemoryId) {
177 let (Some(&src_idx), Some(&tgt_idx)) =
178 (self.id_to_idx.get(&source), self.id_to_idx.get(&target))
179 else {
180 return;
181 };
182 self.removed_edges.push((src_idx, tgt_idx));
183 self.delta_edges
184 .retain(|e| !(e.source_idx == src_idx && e.target_idx == tgt_idx));
185 }
186
187 pub fn outgoing(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
189 let Some(&idx) = self.id_to_idx.get(&id) else {
190 return Vec::new();
191 };
192 self.outgoing_by_idx(idx)
193 }
194
195 pub fn outgoing_valid_at(&self, id: MemoryId, at: Timestamp) -> Vec<(MemoryId, StoredEdge)> {
197 self.outgoing(id)
198 .into_iter()
199 .filter(|(_, e)| e.is_valid_at(at))
200 .collect()
201 }
202
203 pub(crate) fn outgoing_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
204 let mut results = Vec::new();
205
206 let neighbors = self.csr.neighbors(idx);
208 let edges = self.csr.edge_data_for(idx);
209 for (i, &neighbor) in neighbors.iter().enumerate() {
210 if !self.is_removed(idx, neighbor)
211 && let Some(&id) = self.idx_to_id.get(neighbor as usize)
212 {
213 results.push((id, edges[i]));
214 }
215 }
216
217 for delta in &self.delta_edges {
219 if delta.source_idx == idx
220 && let Some(&id) = self.idx_to_id.get(delta.target_idx as usize)
221 {
222 results.push((id, delta.data));
223 }
224 }
225
226 results
227 }
228
229 pub fn incoming(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
231 let Some(&idx) = self.id_to_idx.get(&id) else {
232 return Vec::new();
233 };
234 self.incoming_by_idx(idx)
235 }
236
237 pub fn incoming_valid_at(&self, id: MemoryId, at: Timestamp) -> Vec<(MemoryId, StoredEdge)> {
239 self.incoming(id)
240 .into_iter()
241 .filter(|(_, e)| e.is_valid_at(at))
242 .collect()
243 }
244
245 pub(crate) fn incoming_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
246 let mut results = Vec::new();
247
248 let neighbors = self.csc.neighbors(idx);
250 let edges = self.csc.edge_data_for(idx);
251 for (i, &neighbor) in neighbors.iter().enumerate() {
252 if !self.is_removed(neighbor, idx)
253 && let Some(&id) = self.idx_to_id.get(neighbor as usize)
254 {
255 results.push((id, edges[i]));
256 }
257 }
258
259 for delta in &self.delta_edges {
261 if delta.target_idx == idx
262 && let Some(&id) = self.idx_to_id.get(delta.source_idx as usize)
263 {
264 results.push((id, delta.data));
265 }
266 }
267
268 results
269 }
270
271 pub fn contains_node(&self, id: MemoryId) -> bool {
273 self.id_to_idx.contains_key(&id)
274 }
275
276 pub fn node_count(&self) -> usize {
278 self.idx_to_id.len()
279 }
280
281 pub(crate) fn get_idx(&self, id: MemoryId) -> Option<u32> {
283 self.id_to_idx.get(&id).copied()
284 }
285
286 #[allow(dead_code)]
288 pub(crate) fn get_id(&self, idx: u32) -> Option<MemoryId> {
289 self.idx_to_id.get(idx as usize).copied()
290 }
291
292 pub fn node_ids(&self) -> &[MemoryId] {
294 &self.idx_to_id
295 }
296
297 fn is_removed(&self, source: u32, target: u32) -> bool {
298 self.removed_edges
299 .iter()
300 .any(|&(s, t)| s == source && t == target)
301 }
302
303 pub fn compact(&mut self) {
305 let num_nodes = self.idx_to_id.len();
306
307 let mut all_edges: Vec<(u32, u32, StoredEdge)> = Vec::new();
309
310 for row in 0..num_nodes {
312 let row = row as u32;
313 let neighbors = self.csr.neighbors(row);
314 let edges = self.csr.edge_data_for(row);
315 for (i, &col) in neighbors.iter().enumerate() {
316 if !self.is_removed(row, col) {
317 all_edges.push((row, col, edges[i]));
318 }
319 }
320 }
321
322 for delta in &self.delta_edges {
324 all_edges.push((delta.source_idx, delta.target_idx, delta.data));
325 }
326
327 self.csr = Self::build_compressed(&all_edges, num_nodes, false);
329
330 self.csc = Self::build_compressed(&all_edges, num_nodes, true);
332
333 self.delta_edges.clear();
334 self.removed_edges.clear();
335 }
336
337 fn build_compressed(
338 edges: &[(u32, u32, StoredEdge)],
339 num_nodes: usize,
340 transpose: bool,
341 ) -> CompressedStorage {
342 let mut counts = vec![0u32; num_nodes];
344 for &(src, tgt, _) in edges {
345 let row = if transpose { tgt } else { src };
346 if (row as usize) < num_nodes {
347 counts[row as usize] += 1;
348 }
349 }
350
351 let mut row_offsets = vec![0u32; num_nodes + 1];
353 for i in 0..num_nodes {
354 row_offsets[i + 1] = row_offsets[i] + counts[i];
355 }
356
357 let total = row_offsets[num_nodes] as usize;
358 let mut col_indices = vec![0u32; total];
359 let mut edge_data = vec![
360 StoredEdge {
361 edge_type: EdgeType::Related,
362 weight: 0.0,
363 created_at: 0,
364 valid_from: None,
365 valid_until: None,
366 };
367 total
368 ];
369
370 let mut cursors = row_offsets[..num_nodes].to_vec();
372 for &(src, tgt, data) in edges {
373 let (row, col) = if transpose { (tgt, src) } else { (src, tgt) };
374 if (row as usize) < num_nodes {
375 let pos = cursors[row as usize] as usize;
376 col_indices[pos] = col;
377 edge_data[pos] = data;
378 cursors[row as usize] += 1;
379 }
380 }
381
382 CompressedStorage {
383 row_offsets,
384 col_indices,
385 edge_data,
386 }
387 }
388 pub fn save(&self, path: &std::path::Path) -> MenteResult<()> {
390 let data =
391 serde_json::to_vec(self).map_err(|e| MenteError::Serialization(e.to_string()))?;
392 std::fs::write(path, data)?;
393 Ok(())
394 }
395
396 pub fn load(path: &std::path::Path) -> MenteResult<Self> {
398 let data = std::fs::read(path)?;
399 let graph: Self =
400 serde_json::from_slice(&data).map_err(|e| MenteError::Serialization(e.to_string()))?;
401 Ok(graph)
402 }
403}
404
405impl Default for CsrGraph {
406 fn default() -> Self {
407 Self::new()
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test edge with a fixed weight and creation time.
    fn edge_between(from: MemoryId, to: MemoryId, kind: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: from,
            target: to,
            edge_type: kind,
            weight: 0.8,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
        }
    }

    #[test]
    fn test_add_node_idempotent() {
        let mut graph = CsrGraph::new();
        let node = MemoryId::new();
        let first = graph.add_node(node);
        let second = graph.add_node(node);
        assert_eq!(first, second);
        assert_eq!(graph.node_count(), 1);
    }

    #[test]
    fn test_add_and_query_edges() {
        let mut graph = CsrGraph::new();
        let (a, b, c) = (MemoryId::new(), MemoryId::new(), MemoryId::new());

        graph.add_edge(&edge_between(a, b, EdgeType::Caused));
        graph.add_edge(&edge_between(a, c, EdgeType::Related));

        assert_eq!(graph.outgoing(a).len(), 2);

        let into_b = graph.incoming(b);
        assert_eq!(into_b.len(), 1);
        assert_eq!(into_b[0].0, a);
    }

    #[test]
    fn test_remove_edge() {
        let mut graph = CsrGraph::new();
        let (a, b) = (MemoryId::new(), MemoryId::new());

        graph.add_edge(&edge_between(a, b, EdgeType::Caused));
        assert_eq!(graph.outgoing(a).len(), 1);

        graph.remove_edge(a, b);
        assert!(graph.outgoing(a).is_empty());
    }

    #[test]
    fn test_compact() {
        let mut graph = CsrGraph::new();
        let (a, b, c) = (MemoryId::new(), MemoryId::new(), MemoryId::new());

        graph.add_edge(&edge_between(a, b, EdgeType::Caused));
        graph.add_edge(&edge_between(b, c, EdgeType::Before));
        graph.compact();

        // Compressed edges must match what the deltas described.
        let from_a = graph.outgoing(a);
        assert_eq!(from_a.len(), 1);
        assert_eq!(from_a[0].0, b);

        let into_c = graph.incoming(c);
        assert_eq!(into_c.len(), 1);
        assert_eq!(into_c[0].0, b);
    }

    #[test]
    fn test_compact_with_removals() {
        let mut graph = CsrGraph::new();
        let (a, b, c) = (MemoryId::new(), MemoryId::new(), MemoryId::new());

        graph.add_edge(&edge_between(a, b, EdgeType::Caused));
        graph.add_edge(&edge_between(a, c, EdgeType::Related));
        graph.compact();

        // Remove a compressed edge, then fold the tombstone in.
        graph.remove_edge(a, b);
        graph.compact();

        let remaining = graph.outgoing(a);
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].0, c);
    }
}