1use std::collections::{HashMap, HashSet, hash_map::Entry};
25
26use super::dense_array::DenseArray;
27
28pub use nodedb_types::graph::Direction;
30
31pub struct CsrIndex {
33 pub(crate) node_to_id: HashMap<String, u32>,
35 pub(crate) id_to_node: Vec<String>,
36
37 pub(crate) label_to_id: HashMap<String, u16>,
39 pub(crate) id_to_label: Vec<String>,
40
41 pub(crate) out_offsets: Vec<u32>,
51 pub(crate) out_targets: DenseArray<u32>,
52 pub(crate) out_labels: DenseArray<u16>,
53 pub(crate) out_weights: Option<DenseArray<f64>>,
55
56 pub(crate) in_offsets: Vec<u32>,
57 pub(crate) in_targets: DenseArray<u32>,
58 pub(crate) in_labels: DenseArray<u16>,
59 pub(crate) in_weights: Option<DenseArray<f64>>,
61
62 pub(crate) buffer_out: Vec<Vec<(u16, u32)>>,
65 pub(crate) buffer_in: Vec<Vec<(u16, u32)>>,
66 pub(crate) buffer_out_weights: Vec<Vec<f64>>,
69 pub(crate) buffer_in_weights: Vec<Vec<f64>>,
71
72 pub(crate) deleted_edges: HashSet<(u32, u16, u32)>,
74
75 pub(crate) has_weights: bool,
78
79 pub(crate) node_label_bits: Vec<u64>,
86 pub(crate) node_label_to_id: HashMap<String, u8>,
87 pub(crate) node_label_names: Vec<String>,
88
89 pub(crate) access_counts: Vec<std::cell::Cell<u32>>,
94 pub(crate) query_epoch: u64,
96}
97
98impl Default for CsrIndex {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104impl CsrIndex {
105 pub fn new() -> Self {
106 Self {
107 node_to_id: HashMap::new(),
108 id_to_node: Vec::new(),
109 label_to_id: HashMap::new(),
110 id_to_label: Vec::new(),
111 out_offsets: vec![0],
112 out_targets: DenseArray::default(),
113 out_labels: DenseArray::default(),
114 out_weights: None,
115 in_offsets: vec![0],
116 in_targets: DenseArray::default(),
117 in_labels: DenseArray::default(),
118 in_weights: None,
119 buffer_out: Vec::new(),
120 buffer_in: Vec::new(),
121 buffer_out_weights: Vec::new(),
122 buffer_in_weights: Vec::new(),
123 deleted_edges: HashSet::new(),
124 has_weights: false,
125 node_label_bits: Vec::new(),
126 node_label_to_id: HashMap::new(),
127 node_label_names: Vec::new(),
128 access_counts: Vec::new(),
129 query_epoch: 0,
130 }
131 }
132
133 pub(crate) fn ensure_node(&mut self, node: &str) -> u32 {
135 match self.node_to_id.entry(node.to_string()) {
136 Entry::Occupied(e) => *e.get(),
137 Entry::Vacant(e) => {
138 let id = self.id_to_node.len() as u32;
139 e.insert(id);
140 self.id_to_node.push(node.to_string());
141 self.out_offsets
143 .push(*self.out_offsets.last().unwrap_or(&0));
144 self.in_offsets.push(*self.in_offsets.last().unwrap_or(&0));
145 self.buffer_out.push(Vec::new());
147 self.buffer_in.push(Vec::new());
148 self.buffer_out_weights.push(Vec::new());
149 self.buffer_in_weights.push(Vec::new());
150 self.node_label_bits.push(0);
151 self.access_counts.push(std::cell::Cell::new(0));
152 id
153 }
154 }
155 }
156
157 fn ensure_label(&mut self, label: &str) -> u16 {
159 match self.label_to_id.entry(label.to_string()) {
160 Entry::Occupied(e) => *e.get(),
161 Entry::Vacant(e) => {
162 let id = self.id_to_label.len() as u16;
163 e.insert(id);
164 self.id_to_label.push(label.to_string());
165 id
166 }
167 }
168 }
169
170 fn ensure_node_label(&mut self, label: &str) -> Option<u8> {
174 if let Some(&id) = self.node_label_to_id.get(label) {
175 return Some(id);
176 }
177 let id = self.node_label_names.len();
178 if id >= 64 {
179 return None; }
181 let id = id as u8;
182 self.node_label_to_id.insert(label.to_string(), id);
183 self.node_label_names.push(label.to_string());
184 Some(id)
185 }
186
187 pub fn add_node_label(&mut self, node: &str, label: &str) -> bool {
189 let node_id = self.ensure_node(node);
190 let Some(label_id) = self.ensure_node_label(label) else {
191 return false;
192 };
193 self.node_label_bits[node_id as usize] |= 1u64 << label_id;
194 true
195 }
196
197 pub fn remove_node_label(&mut self, node: &str, label: &str) {
199 let Some(&node_id) = self.node_to_id.get(node) else {
200 return;
201 };
202 let Some(&label_id) = self.node_label_to_id.get(label) else {
203 return;
204 };
205 self.node_label_bits[node_id as usize] &= !(1u64 << label_id);
206 }
207
208 pub fn node_has_label(&self, node_id: u32, label: &str) -> bool {
210 let Some(&label_id) = self.node_label_to_id.get(label) else {
211 return false;
212 };
213 let bits = self
214 .node_label_bits
215 .get(node_id as usize)
216 .copied()
217 .unwrap_or(0);
218 bits & (1u64 << label_id) != 0
219 }
220
221 pub fn node_labels(&self, node_id: u32) -> Vec<&str> {
223 let bits = self
224 .node_label_bits
225 .get(node_id as usize)
226 .copied()
227 .unwrap_or(0);
228 if bits == 0 {
229 return Vec::new();
230 }
231 let mut labels = Vec::new();
232 for (i, name) in self.node_label_names.iter().enumerate() {
233 if bits & (1u64 << i) != 0 {
234 labels.push(name.as_str());
235 }
236 }
237 labels
238 }
239
240 pub fn add_edge(&mut self, src: &str, label: &str, dst: &str) {
243 self.add_edge_internal(src, label, dst, 1.0, false);
244 }
245
246 pub fn add_edge_weighted(&mut self, src: &str, label: &str, dst: &str, weight: f64) {
252 self.add_edge_internal(src, label, dst, weight, weight != 1.0);
253 }
254
255 fn add_edge_internal(
256 &mut self,
257 src: &str,
258 label: &str,
259 dst: &str,
260 weight: f64,
261 force_weights: bool,
262 ) {
263 let src_id = self.ensure_node(src);
264 let dst_id = self.ensure_node(dst);
265 let label_id = self.ensure_label(label);
266
267 let out = &self.buffer_out[src_id as usize];
269 if out.iter().any(|&(l, d)| l == label_id && d == dst_id) {
270 return;
271 }
272 if self.dense_has_edge(src_id, label_id, dst_id) {
274 return;
275 }
276
277 if force_weights && !self.has_weights {
279 self.enable_weights();
280 }
281
282 self.buffer_out[src_id as usize].push((label_id, dst_id));
283 self.buffer_in[dst_id as usize].push((label_id, src_id));
284
285 if self.has_weights {
286 self.buffer_out_weights[src_id as usize].push(weight);
287 self.buffer_in_weights[dst_id as usize].push(weight);
288 }
289
290 self.deleted_edges.remove(&(src_id, label_id, dst_id));
292 }
293
294 pub fn remove_edge(&mut self, src: &str, label: &str, dst: &str) {
296 let (Some(&src_id), Some(&dst_id)) = (self.node_to_id.get(src), self.node_to_id.get(dst))
297 else {
298 return;
299 };
300 let Some(&label_id) = self.label_to_id.get(label) else {
301 return;
302 };
303
304 let out_buf = &self.buffer_out[src_id as usize];
306 if let Some(pos) = out_buf
307 .iter()
308 .position(|&(l, d)| l == label_id && d == dst_id)
309 {
310 self.buffer_out[src_id as usize].swap_remove(pos);
311 if self.has_weights {
312 self.buffer_out_weights[src_id as usize].swap_remove(pos);
313 }
314 }
315 let in_buf = &self.buffer_in[dst_id as usize];
316 if let Some(pos) = in_buf
317 .iter()
318 .position(|&(l, s)| l == label_id && s == src_id)
319 {
320 self.buffer_in[dst_id as usize].swap_remove(pos);
321 if self.has_weights {
322 self.buffer_in_weights[dst_id as usize].swap_remove(pos);
323 }
324 }
325
326 if self.dense_has_edge(src_id, label_id, dst_id) {
328 self.deleted_edges.insert((src_id, label_id, dst_id));
329 }
330 }
331
332 pub fn remove_node_edges(&mut self, node: &str) -> usize {
334 let Some(&node_id) = self.node_to_id.get(node) else {
335 return 0;
336 };
337 let mut removed = 0;
338
339 let out_edges: Vec<(u16, u32)> = self.iter_out_edges(node_id).collect();
341 for (label_id, dst_id) in &out_edges {
342 let in_buf = &self.buffer_in[*dst_id as usize];
343 if let Some(pos) = in_buf
344 .iter()
345 .position(|&(l, s)| l == *label_id && s == node_id)
346 {
347 self.buffer_in[*dst_id as usize].swap_remove(pos);
348 if self.has_weights {
349 self.buffer_in_weights[*dst_id as usize].swap_remove(pos);
350 }
351 }
352 self.deleted_edges.insert((node_id, *label_id, *dst_id));
353 removed += 1;
354 }
355 self.buffer_out[node_id as usize].clear();
356 if self.has_weights {
357 self.buffer_out_weights[node_id as usize].clear();
358 }
359
360 let in_edges: Vec<(u16, u32)> = self.iter_in_edges(node_id).collect();
362 for (label_id, src_id) in &in_edges {
363 let out_buf = &self.buffer_out[*src_id as usize];
364 if let Some(pos) = out_buf
365 .iter()
366 .position(|&(l, d)| l == *label_id && d == node_id)
367 {
368 self.buffer_out[*src_id as usize].swap_remove(pos);
369 if self.has_weights {
370 self.buffer_out_weights[*src_id as usize].swap_remove(pos);
371 }
372 }
373 self.deleted_edges.insert((*src_id, *label_id, node_id));
374 removed += 1;
375 }
376 self.buffer_in[node_id as usize].clear();
377 if self.has_weights {
378 self.buffer_in_weights[node_id as usize].clear();
379 }
380
381 removed
382 }
383
384 pub fn remove_nodes_with_prefix(&mut self, prefix: &str) {
389 let matching_nodes: Vec<String> = self
390 .node_to_id
391 .keys()
392 .filter(|k| k.starts_with(prefix))
393 .cloned()
394 .collect();
395 for node in &matching_nodes {
396 self.remove_node_edges(node);
397 }
398 }
399
400 pub fn neighbors(
402 &self,
403 node: &str,
404 label_filter: Option<&str>,
405 direction: Direction,
406 ) -> Vec<(String, String)> {
407 let Some(&node_id) = self.node_to_id.get(node) else {
408 return Vec::new();
409 };
410 self.record_access(node_id);
411 let label_id = label_filter.and_then(|l| self.label_to_id.get(l).copied());
412
413 let mut result = Vec::new();
414
415 if matches!(direction, Direction::Out | Direction::Both) {
416 for (lid, dst) in self.iter_out_edges(node_id) {
417 if label_id.is_none_or(|f| f == lid) {
418 result.push((
419 self.id_to_label[lid as usize].clone(),
420 self.id_to_node[dst as usize].clone(),
421 ));
422 }
423 }
424 }
425 if matches!(direction, Direction::In | Direction::Both) {
426 for (lid, src) in self.iter_in_edges(node_id) {
427 if label_id.is_none_or(|f| f == lid) {
428 result.push((
429 self.id_to_label[lid as usize].clone(),
430 self.id_to_node[src as usize].clone(),
431 ));
432 }
433 }
434 }
435
436 result
437 }
438
439 pub fn neighbors_multi(
441 &self,
442 node: &str,
443 label_filters: &[&str],
444 direction: Direction,
445 ) -> Vec<(String, String)> {
446 let Some(&node_id) = self.node_to_id.get(node) else {
447 return Vec::new();
448 };
449 self.record_access(node_id);
450 let label_ids: Vec<u16> = label_filters
451 .iter()
452 .filter_map(|l| self.label_to_id.get(*l).copied())
453 .collect();
454 let match_label = |lid: u16| label_ids.is_empty() || label_ids.contains(&lid);
455
456 let mut result = Vec::new();
457
458 if matches!(direction, Direction::Out | Direction::Both) {
459 for (lid, dst) in self.iter_out_edges(node_id) {
460 if match_label(lid) {
461 result.push((
462 self.id_to_label[lid as usize].clone(),
463 self.id_to_node[dst as usize].clone(),
464 ));
465 }
466 }
467 }
468 if matches!(direction, Direction::In | Direction::Both) {
469 for (lid, src) in self.iter_in_edges(node_id) {
470 if match_label(lid) {
471 result.push((
472 self.id_to_label[lid as usize].clone(),
473 self.id_to_node[src as usize].clone(),
474 ));
475 }
476 }
477 }
478
479 result
480 }
481
482 pub fn add_node(&mut self, name: &str) -> u32 {
485 self.ensure_node(name)
486 }
487
488 pub fn node_count(&self) -> usize {
489 self.id_to_node.len()
490 }
491
492 pub fn contains_node(&self, node: &str) -> bool {
493 self.node_to_id.contains_key(node)
494 }
495
496 pub fn node_name(&self, dense_id: u32) -> &str {
498 &self.id_to_node[dense_id as usize]
499 }
500
501 pub fn node_id(&self, name: &str) -> Option<u32> {
503 self.node_to_id.get(name).copied()
504 }
505
506 pub fn label_name(&self, label_id: u16) -> &str {
508 &self.id_to_label[label_id as usize]
509 }
510
511 pub fn label_id(&self, name: &str) -> Option<u16> {
513 self.label_to_id.get(name).copied()
514 }
515
516 pub fn out_degree(&self, node_id: u32) -> usize {
518 self.iter_out_edges(node_id).count()
519 }
520
521 pub fn in_degree(&self, node_id: u32) -> usize {
523 self.iter_in_edges(node_id).count()
524 }
525
526 pub fn edge_count(&self) -> usize {
528 let n = self.id_to_node.len();
529 (0..n).map(|i| self.out_degree(i as u32)).sum()
530 }
531
532 pub(crate) fn build_dense(edges: &[Vec<(u16, u32)>]) -> (Vec<u32>, Vec<u32>, Vec<u16>) {
536 let n = edges.len();
537 let total: usize = edges.iter().map(|e| e.len()).sum();
538 let mut offsets = Vec::with_capacity(n + 1);
539 let mut targets = Vec::with_capacity(total);
540 let mut labels = Vec::with_capacity(total);
541
542 let mut offset = 0u32;
543 for node_edges in edges {
544 offsets.push(offset);
545 for &(lid, target) in node_edges {
546 targets.push(target);
547 labels.push(lid);
548 }
549 offset += node_edges.len() as u32;
550 }
551 offsets.push(offset);
552
553 (offsets, targets, labels)
554 }
555
556 fn dense_has_edge(&self, src: u32, label: u16, dst: u32) -> bool {
558 for (lid, target) in self.dense_out_edges(src) {
559 if lid == label && target == dst {
560 return true;
561 }
562 }
563 false
564 }
565
566 pub(crate) fn dense_out_edges(&self, node: u32) -> impl Iterator<Item = (u16, u32)> + '_ {
568 let idx = node as usize;
569 if idx + 1 >= self.out_offsets.len() {
570 return Vec::new().into_iter();
571 }
572 let start = self.out_offsets[idx] as usize;
573 let end = self.out_offsets[idx + 1] as usize;
574 (start..end)
575 .map(move |i| (self.out_labels[i], self.out_targets[i]))
576 .collect::<Vec<_>>()
577 .into_iter()
578 }
579
580 pub(crate) fn dense_in_edges(&self, node: u32) -> impl Iterator<Item = (u16, u32)> + '_ {
582 let idx = node as usize;
583 if idx + 1 >= self.in_offsets.len() {
584 return Vec::new().into_iter();
585 }
586 let start = self.in_offsets[idx] as usize;
587 let end = self.in_offsets[idx + 1] as usize;
588 (start..end)
589 .map(move |i| (self.in_labels[i], self.in_targets[i]))
590 .collect::<Vec<_>>()
591 .into_iter()
592 }
593
594 pub fn iter_out_edges(&self, node: u32) -> impl Iterator<Item = (u16, u32)> + '_ {
596 let idx = node as usize;
597 let dense = self
598 .dense_out_edges(node)
599 .filter(move |&(lid, dst)| !self.deleted_edges.contains(&(node, lid, dst)));
600 let buffer = if idx < self.buffer_out.len() {
601 self.buffer_out[idx].to_vec()
602 } else {
603 Vec::new()
604 };
605 dense.chain(buffer)
606 }
607
608 pub fn iter_in_edges(&self, node: u32) -> impl Iterator<Item = (u16, u32)> + '_ {
610 let idx = node as usize;
611 let dense = self
612 .dense_in_edges(node)
613 .filter(move |&(lid, src)| !self.deleted_edges.contains(&(src, lid, node)));
614 let buffer = if idx < self.buffer_in.len() {
615 self.buffer_in[idx].to_vec()
616 } else {
617 Vec::new()
618 };
619 dense.chain(buffer)
620 }
621}
622
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a KNOWS chain a→b→c→d plus a single WORKS edge a→e, all still buffered.
    fn sample_graph() -> CsrIndex {
        let mut graph = CsrIndex::new();
        for (src, label, dst) in [
            ("a", "KNOWS", "b"),
            ("b", "KNOWS", "c"),
            ("c", "KNOWS", "d"),
            ("a", "WORKS", "e"),
        ] {
            graph.add_edge(src, label, dst);
        }
        graph
    }

    /// Neighbor names (second element) of a `(label, neighbor)` list.
    fn neighbor_names(pairs: &[(String, String)]) -> Vec<&str> {
        pairs.iter().map(|(_, name)| name.as_str()).collect()
    }

    #[test]
    fn neighbors_out() {
        let graph = sample_graph();
        let outgoing = graph.neighbors("a", None, Direction::Out);
        assert_eq!(outgoing.len(), 2);
        let names = neighbor_names(&outgoing);
        assert!(names.contains(&"b"));
        assert!(names.contains(&"e"));
    }

    #[test]
    fn neighbors_filtered() {
        let graph = sample_graph();
        let knows = graph.neighbors("a", Some("KNOWS"), Direction::Out);
        assert_eq!(knows.len(), 1);
        assert_eq!(knows[0].1, "b");
    }

    #[test]
    fn neighbors_in() {
        let graph = sample_graph();
        let incoming = graph.neighbors("b", None, Direction::In);
        assert_eq!(incoming.len(), 1);
        assert_eq!(incoming[0].1, "a");
    }

    #[test]
    fn incremental_remove() {
        let mut graph = sample_graph();
        let knows_from_a = |g: &CsrIndex| g.neighbors("a", Some("KNOWS"), Direction::Out).len();
        assert_eq!(knows_from_a(&graph), 1);
        graph.remove_edge("a", "KNOWS", "b");
        assert_eq!(knows_from_a(&graph), 0);
    }

    #[test]
    fn duplicate_add_is_idempotent() {
        let mut graph = CsrIndex::new();
        for _ in 0..2 {
            graph.add_edge("a", "L", "b");
        }
        assert_eq!(graph.neighbors("a", None, Direction::Out).len(), 1);
    }

    #[test]
    fn compact_merges_buffer_into_dense() {
        let mut graph = CsrIndex::new();
        graph.add_edge("a", "L", "b");
        graph.add_edge("b", "L", "c");
        assert_eq!(graph.neighbors("a", None, Direction::Out).len(), 1);

        graph.compact();
        assert!(graph.buffer_out.iter().all(Vec::is_empty));
        assert_eq!(graph.neighbors("a", None, Direction::Out).len(), 1);
        assert_eq!(graph.neighbors("b", None, Direction::Out).len(), 1);
    }

    #[test]
    fn compact_handles_deletes() {
        let mut graph = CsrIndex::new();
        graph.add_edge("a", "L", "b");
        graph.add_edge("a", "L", "c");
        graph.compact();

        // Removing a dense edge hides it before the next compaction...
        graph.remove_edge("a", "L", "b");
        assert_eq!(graph.neighbors("a", None, Direction::Out).len(), 1);

        // ...and it stays gone afterwards.
        graph.compact();
        assert_eq!(graph.neighbors("a", None, Direction::Out).len(), 1);
        assert_eq!(graph.neighbors("a", None, Direction::Out)[0].1, "c");
    }

    #[test]
    fn label_interning_reduces_memory() {
        let mut graph = CsrIndex::new();
        for i in 0..100 {
            let (src, dst) = (format!("n{i}"), format!("n{}", i + 1));
            graph.add_edge(&src, "FOLLOWS", &dst);
        }
        assert_eq!(graph.id_to_label, ["FOLLOWS"]);
    }

    #[test]
    fn edge_count() {
        assert_eq!(sample_graph().edge_count(), 4);
    }

    #[test]
    fn checkpoint_roundtrip() {
        let mut graph = sample_graph();
        graph.compact();

        let snapshot = graph.checkpoint_to_bytes();
        assert!(!snapshot.is_empty());

        let restored = CsrIndex::from_checkpoint(&snapshot).expect("roundtrip");
        assert_eq!(restored.node_count(), graph.node_count());
        assert_eq!(restored.edge_count(), graph.edge_count());

        let knows = restored.neighbors("a", Some("KNOWS"), Direction::Out);
        assert_eq!(knows.len(), 1);
        assert_eq!(knows[0].1, "b");
    }

    #[test]
    fn memory_estimation() {
        let graph = sample_graph();
        assert!(graph.estimated_memory_bytes() > 0);
    }

    #[test]
    fn out_degree_and_in_degree() {
        let mut graph = CsrIndex::new();
        for (src, dst) in [("a", "b"), ("a", "c"), ("d", "b")] {
            graph.add_edge(src, "L", dst);
        }

        let a = graph.node_to_id["a"];
        let b = graph.node_to_id["b"];

        assert_eq!(graph.out_degree(a), 2);
        assert_eq!(graph.in_degree(b), 2);
    }

    #[test]
    fn remove_node_edges_all() {
        let mut graph = CsrIndex::new();
        for (src, dst) in [("a", "b"), ("a", "c"), ("d", "a")] {
            graph.add_edge(src, "L", dst);
        }

        assert_eq!(graph.remove_node_edges("a"), 3);
        assert!(graph.neighbors("a", None, Direction::Out).is_empty());
        assert!(graph.neighbors("a", None, Direction::In).is_empty());
    }

    #[test]
    fn add_node_idempotent() {
        let mut graph = CsrIndex::new();
        let first = graph.add_node("x");
        let second = graph.add_node("x");
        assert_eq!(first, second);
        assert_eq!(graph.node_count(), 1);
    }

    #[test]
    fn node_labels_bitset() {
        let mut graph = CsrIndex::new();
        graph.add_edge("alice", "KNOWS", "bob");
        graph.add_edge("acme", "EMPLOYS", "alice");

        for (node, label) in [("alice", "Person"), ("bob", "Person"), ("acme", "Company")] {
            assert!(graph.add_node_label(node, label));
        }

        let alice = graph.node_id("alice").unwrap();
        let bob = graph.node_id("bob").unwrap();
        let acme = graph.node_id("acme").unwrap();

        assert!(graph.node_has_label(alice, "Person"));
        assert!(!graph.node_has_label(alice, "Company"));
        assert!(graph.node_has_label(acme, "Company"));
        assert!(!graph.node_has_label(acme, "Person"));

        // A node can carry several labels; they are reported in interning order.
        assert!(graph.add_node_label("alice", "Employee"));
        assert!(graph.node_has_label(alice, "Person"));
        assert!(graph.node_has_label(alice, "Employee"));
        assert_eq!(graph.node_labels(alice), ["Person", "Employee"]);

        // Removing one label leaves the others intact.
        graph.remove_node_label("alice", "Employee");
        assert!(!graph.node_has_label(alice, "Employee"));
        assert!(graph.node_has_label(alice, "Person"));

        // Unknown labels are never reported.
        assert!(!graph.node_has_label(bob, "NonExistent"));
    }
}