1use ahash::HashMap;
4use mentedb_core::edge::{EdgeType, MemoryEdge};
5use mentedb_core::error::{MenteError, MenteResult};
6use mentedb_core::types::{MemoryId, Timestamp};
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
11pub struct StoredEdge {
12 pub edge_type: EdgeType,
13 pub weight: f32,
14 pub created_at: Timestamp,
15}
16
17impl StoredEdge {
18 pub fn from_memory_edge(edge: &MemoryEdge) -> Self {
19 Self {
20 edge_type: edge.edge_type,
21 weight: edge.weight,
22 created_at: edge.created_at,
23 }
24 }
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29struct DeltaEdge {
30 source_idx: u32,
31 target_idx: u32,
32 data: StoredEdge,
33}
34
35#[derive(Debug, Clone, Default, Serialize, Deserialize)]
37struct CompressedStorage {
38 row_offsets: Vec<u32>,
40 col_indices: Vec<u32>,
42 edge_data: Vec<StoredEdge>,
44}
45
46impl CompressedStorage {
47 #[allow(dead_code)]
48 fn new(num_nodes: usize) -> Self {
49 Self {
50 row_offsets: vec![0; num_nodes + 1],
51 col_indices: Vec::new(),
52 edge_data: Vec::new(),
53 }
54 }
55
56 fn neighbors(&self, row: u32) -> &[u32] {
58 let row = row as usize;
59 if row + 1 >= self.row_offsets.len() {
60 return &[];
61 }
62 let start = self.row_offsets[row] as usize;
63 let end = self.row_offsets[row + 1] as usize;
64 &self.col_indices[start..end]
65 }
66
67 fn edge_data_for(&self, row: u32) -> &[StoredEdge] {
68 let row = row as usize;
69 if row + 1 >= self.row_offsets.len() {
70 return &[];
71 }
72 let start = self.row_offsets[row] as usize;
73 let end = self.row_offsets[row + 1] as usize;
74 &self.edge_data[start..end]
75 }
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct CsrGraph {
81 id_to_idx: HashMap<MemoryId, u32>,
83 idx_to_id: Vec<MemoryId>,
85
86 csr: CompressedStorage,
88 csc: CompressedStorage,
90
91 delta_edges: Vec<DeltaEdge>,
93 removed_edges: Vec<(u32, u32)>,
95}
96
97impl CsrGraph {
98 pub fn new() -> Self {
99 Self {
100 id_to_idx: HashMap::default(),
101 idx_to_id: Vec::new(),
102 csr: CompressedStorage::default(),
103 csc: CompressedStorage::default(),
104 delta_edges: Vec::new(),
105 removed_edges: Vec::new(),
106 }
107 }
108
109 pub fn add_node(&mut self, id: MemoryId) -> u32 {
111 if let Some(&idx) = self.id_to_idx.get(&id) {
112 return idx;
113 }
114 let idx = self.idx_to_id.len() as u32;
115 self.id_to_idx.insert(id, idx);
116 self.idx_to_id.push(id);
117 idx
118 }
119
120 pub fn remove_node(&mut self, id: MemoryId) {
122 let Some(&idx) = self.id_to_idx.get(&id) else {
123 return;
124 };
125 for &neighbor in self.csr.neighbors(idx) {
127 self.removed_edges.push((idx, neighbor));
128 }
129 for &neighbor in self.csc.neighbors(idx) {
130 self.removed_edges.push((neighbor, idx));
131 }
132 self.delta_edges
134 .retain(|e| e.source_idx != idx && e.target_idx != idx);
135 }
136
137 pub fn add_edge(&mut self, edge: &MemoryEdge) {
139 let source_idx = self.add_node(edge.source);
140 let target_idx = self.add_node(edge.target);
141 self.delta_edges.push(DeltaEdge {
142 source_idx,
143 target_idx,
144 data: StoredEdge::from_memory_edge(edge),
145 });
146 }
147
148 pub fn remove_edge(&mut self, source: MemoryId, target: MemoryId) {
150 let (Some(&src_idx), Some(&tgt_idx)) =
151 (self.id_to_idx.get(&source), self.id_to_idx.get(&target))
152 else {
153 return;
154 };
155 self.removed_edges.push((src_idx, tgt_idx));
156 self.delta_edges
157 .retain(|e| !(e.source_idx == src_idx && e.target_idx == tgt_idx));
158 }
159
160 pub fn outgoing(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
162 let Some(&idx) = self.id_to_idx.get(&id) else {
163 return Vec::new();
164 };
165 self.outgoing_by_idx(idx)
166 }
167
168 pub(crate) fn outgoing_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
169 let mut results = Vec::new();
170
171 let neighbors = self.csr.neighbors(idx);
173 let edges = self.csr.edge_data_for(idx);
174 for (i, &neighbor) in neighbors.iter().enumerate() {
175 if !self.is_removed(idx, neighbor)
176 && let Some(&id) = self.idx_to_id.get(neighbor as usize)
177 {
178 results.push((id, edges[i]));
179 }
180 }
181
182 for delta in &self.delta_edges {
184 if delta.source_idx == idx
185 && let Some(&id) = self.idx_to_id.get(delta.target_idx as usize)
186 {
187 results.push((id, delta.data));
188 }
189 }
190
191 results
192 }
193
194 pub fn incoming(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
196 let Some(&idx) = self.id_to_idx.get(&id) else {
197 return Vec::new();
198 };
199 self.incoming_by_idx(idx)
200 }
201
202 pub(crate) fn incoming_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
203 let mut results = Vec::new();
204
205 let neighbors = self.csc.neighbors(idx);
207 let edges = self.csc.edge_data_for(idx);
208 for (i, &neighbor) in neighbors.iter().enumerate() {
209 if !self.is_removed(neighbor, idx)
210 && let Some(&id) = self.idx_to_id.get(neighbor as usize)
211 {
212 results.push((id, edges[i]));
213 }
214 }
215
216 for delta in &self.delta_edges {
218 if delta.target_idx == idx
219 && let Some(&id) = self.idx_to_id.get(delta.source_idx as usize)
220 {
221 results.push((id, delta.data));
222 }
223 }
224
225 results
226 }
227
228 pub fn contains_node(&self, id: MemoryId) -> bool {
230 self.id_to_idx.contains_key(&id)
231 }
232
233 pub fn node_count(&self) -> usize {
235 self.idx_to_id.len()
236 }
237
238 pub(crate) fn get_idx(&self, id: MemoryId) -> Option<u32> {
240 self.id_to_idx.get(&id).copied()
241 }
242
243 #[allow(dead_code)]
245 pub(crate) fn get_id(&self, idx: u32) -> Option<MemoryId> {
246 self.idx_to_id.get(idx as usize).copied()
247 }
248
249 pub fn node_ids(&self) -> &[MemoryId] {
251 &self.idx_to_id
252 }
253
254 fn is_removed(&self, source: u32, target: u32) -> bool {
255 self.removed_edges
256 .iter()
257 .any(|&(s, t)| s == source && t == target)
258 }
259
260 pub fn compact(&mut self) {
262 let num_nodes = self.idx_to_id.len();
263
264 let mut all_edges: Vec<(u32, u32, StoredEdge)> = Vec::new();
266
267 for row in 0..num_nodes {
269 let row = row as u32;
270 let neighbors = self.csr.neighbors(row);
271 let edges = self.csr.edge_data_for(row);
272 for (i, &col) in neighbors.iter().enumerate() {
273 if !self.is_removed(row, col) {
274 all_edges.push((row, col, edges[i]));
275 }
276 }
277 }
278
279 for delta in &self.delta_edges {
281 all_edges.push((delta.source_idx, delta.target_idx, delta.data));
282 }
283
284 self.csr = Self::build_compressed(&all_edges, num_nodes, false);
286
287 self.csc = Self::build_compressed(&all_edges, num_nodes, true);
289
290 self.delta_edges.clear();
291 self.removed_edges.clear();
292 }
293
294 fn build_compressed(
295 edges: &[(u32, u32, StoredEdge)],
296 num_nodes: usize,
297 transpose: bool,
298 ) -> CompressedStorage {
299 let mut counts = vec![0u32; num_nodes];
301 for &(src, tgt, _) in edges {
302 let row = if transpose { tgt } else { src };
303 if (row as usize) < num_nodes {
304 counts[row as usize] += 1;
305 }
306 }
307
308 let mut row_offsets = vec![0u32; num_nodes + 1];
310 for i in 0..num_nodes {
311 row_offsets[i + 1] = row_offsets[i] + counts[i];
312 }
313
314 let total = row_offsets[num_nodes] as usize;
315 let mut col_indices = vec![0u32; total];
316 let mut edge_data = vec![
317 StoredEdge {
318 edge_type: EdgeType::Related,
319 weight: 0.0,
320 created_at: 0,
321 };
322 total
323 ];
324
325 let mut cursors = row_offsets[..num_nodes].to_vec();
327 for &(src, tgt, data) in edges {
328 let (row, col) = if transpose { (tgt, src) } else { (src, tgt) };
329 if (row as usize) < num_nodes {
330 let pos = cursors[row as usize] as usize;
331 col_indices[pos] = col;
332 edge_data[pos] = data;
333 cursors[row as usize] += 1;
334 }
335 }
336
337 CompressedStorage {
338 row_offsets,
339 col_indices,
340 edge_data,
341 }
342 }
343 pub fn save(&self, path: &std::path::Path) -> MenteResult<()> {
345 let data =
346 serde_json::to_vec(self).map_err(|e| MenteError::Serialization(e.to_string()))?;
347 std::fs::write(path, data)?;
348 Ok(())
349 }
350
351 pub fn load(path: &std::path::Path) -> MenteResult<Self> {
353 let data = std::fs::read(path)?;
354 let graph: Self =
355 serde_json::from_slice(&data).map_err(|e| MenteError::Serialization(e.to_string()))?;
356 Ok(graph)
357 }
358}
359
360impl Default for CsrGraph {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Builds an edge with a fixed weight (0.8) and timestamp (1000).
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 0.8,
            created_at: 1000,
        }
    }

    #[test]
    fn test_add_node_idempotent() {
        let mut g = CsrGraph::new();
        let id = Uuid::new_v4();
        let idx1 = g.add_node(id);
        let idx2 = g.add_node(id);
        assert_eq!(idx1, idx2);
        assert_eq!(g.node_count(), 1);
    }

    #[test]
    fn test_add_and_query_edges() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(a, c, EdgeType::Related));

        let out = g.outgoing(a);
        assert_eq!(out.len(), 2);

        let inc_b = g.incoming(b);
        assert_eq!(inc_b.len(), 1);
        assert_eq!(inc_b[0].0, a);
    }

    #[test]
    fn test_remove_edge() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        assert_eq!(g.outgoing(a).len(), 1);

        g.remove_edge(a, b);
        assert_eq!(g.outgoing(a).len(), 0);
    }

    #[test]
    fn test_compact() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(b, c, EdgeType::Before));
        g.compact();

        let out_a = g.outgoing(a);
        assert_eq!(out_a.len(), 1);
        assert_eq!(out_a[0].0, b);

        let inc_c = g.incoming(c);
        assert_eq!(inc_c.len(), 1);
        assert_eq!(inc_c[0].0, b);
    }

    #[test]
    fn test_compact_with_removals() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(a, c, EdgeType::Related));
        g.compact();

        g.remove_edge(a, b);
        g.compact();

        let out = g.outgoing(a);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].0, c);
    }

    #[test]
    fn test_edge_payload_survives_compact() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.compact();

        let expected = StoredEdge {
            edge_type: EdgeType::Caused,
            weight: 0.8,
            created_at: 1000,
        };
        let out = g.outgoing(a);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].1, expected);
        let inc = g.incoming(b);
        assert_eq!(inc.len(), 1);
        assert_eq!(inc[0].1, expected);
    }

    #[test]
    fn test_remove_node_hides_compacted_and_delta_edges() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(b, c, EdgeType::Before));
        g.compact();
        // Pending (uncompacted) edge touching the node being removed.
        g.add_edge(&make_edge(b, d, EdgeType::Related));

        g.remove_node(b);
        assert!(g.outgoing(a).is_empty());
        assert!(g.outgoing(b).is_empty());
        assert!(g.incoming(b).is_empty());
        assert!(g.incoming(c).is_empty());
        assert!(g.incoming(d).is_empty());

        g.compact();
        assert!(g.outgoing(a).is_empty());
        assert!(g.incoming(c).is_empty());
        assert!(g.incoming(d).is_empty());
    }

    #[test]
    fn test_readd_after_remove() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.compact();
        g.remove_edge(a, b);
        assert!(g.outgoing(a).is_empty());

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        assert_eq!(g.outgoing(a).len(), 1);
        assert_eq!(g.incoming(b).len(), 1);

        g.compact();
        assert_eq!(g.outgoing(a).len(), 1);
        assert_eq!(g.incoming(b).len(), 1);
    }

    #[test]
    fn test_save_load_roundtrip() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.compact();
        // Left pending so the round-trip also covers delta edges.
        g.add_edge(&make_edge(b, c, EdgeType::Before));

        let path = std::env::temp_dir().join(format!("csr_graph_test_{}.json", Uuid::new_v4()));
        assert!(g.save(&path).is_ok(), "saving the graph failed");
        let loaded = CsrGraph::load(&path);
        let _ = std::fs::remove_file(&path);
        let Ok(loaded) = loaded else {
            panic!("loading the saved graph failed");
        };

        assert_eq!(loaded.node_count(), 3);
        let out_a = loaded.outgoing(a);
        assert_eq!(out_a.len(), 1);
        assert_eq!(out_a[0].0, b);
        let inc_c = loaded.incoming(c);
        assert_eq!(inc_c.len(), 1);
        assert_eq!(inc_c[0].0, b);
    }
}