use std::collections::{HashMap, HashSet, hash_map::Entry};
use super::dense_array::DenseArray;
pub use nodedb_types::graph::Direction;
/// Graph adjacency index that combines dense offset/target arrays with
/// per-node append buffers and a tombstone set for removed dense edges.
///
/// Node names and edge labels are interned to small integer ids; all edge
/// storage below refers to those ids.
pub struct CsrIndex {
/// Node name -> dense node id (ids are handed out sequentially by `ensure_node`).
pub(crate) node_to_id: HashMap<String, u32>,
/// Node names indexed by dense node id (inverse of `node_to_id`).
pub(crate) id_to_node: Vec<String>,
/// Edge label name -> interned label id.
pub(crate) label_to_id: HashMap<String, u16>,
/// Edge label names indexed by label id (inverse of `label_to_id`).
pub(crate) id_to_label: Vec<String>,
/// Dense out-edge offsets: node `i` owns the index range
/// `out_offsets[i]..out_offsets[i + 1]` of `out_targets` / `out_labels`.
pub(crate) out_offsets: Vec<u32>,
/// Destination node ids of the dense out-edges.
pub(crate) out_targets: DenseArray<u32>,
/// Label ids of the dense out-edges, indexed like `out_targets`.
pub(crate) out_labels: DenseArray<u16>,
/// Optional weights for the dense out-edges.
/// NOTE(review): presumably parallel to `out_targets`; it is populated outside this file — confirm.
pub(crate) out_weights: Option<DenseArray<f64>>,
/// Dense in-edge offsets: node `i` owns the index range
/// `in_offsets[i]..in_offsets[i + 1]` of `in_targets` / `in_labels`.
pub(crate) in_offsets: Vec<u32>,
/// Source node ids of the dense in-edges.
pub(crate) in_targets: DenseArray<u32>,
/// Label ids of the dense in-edges, indexed like `in_targets`.
pub(crate) in_labels: DenseArray<u16>,
/// Optional weights for the dense in-edges.
/// NOTE(review): presumably parallel to `in_targets`; it is populated outside this file — confirm.
pub(crate) in_weights: Option<DenseArray<f64>>,
/// Per-node buffered out-edges as `(label id, destination id)` pairs; these
/// are the edges added through `add_edge*` that are not in the dense arrays.
pub(crate) buffer_out: Vec<Vec<(u16, u32)>>,
/// Per-node buffered in-edges as `(label id, source id)` pairs (mirror of `buffer_out`).
pub(crate) buffer_in: Vec<Vec<(u16, u32)>>,
/// Weights pushed alongside `buffer_out` entries while `has_weights` is set.
pub(crate) buffer_out_weights: Vec<Vec<f64>>,
/// Weights pushed alongside `buffer_in` entries while `has_weights` is set.
pub(crate) buffer_in_weights: Vec<Vec<f64>>,
/// Tombstones `(source id, label id, destination id)`; dense edges listed
/// here are skipped by `iter_out_edges` / `iter_in_edges`.
pub(crate) deleted_edges: HashSet<(u32, u16, u32)>,
/// Whether the weight buffers are being maintained for buffered edges.
pub(crate) has_weights: bool,
/// Per-node bitmask of node labels: bit `i` is set when the node carries
/// the node label with id `i`.
pub(crate) node_label_bits: Vec<u64>,
/// Node label name -> bit position (at most 64 node labels are handed out).
pub(crate) node_label_to_id: HashMap<String, u8>,
/// Node label names indexed by bit position.
pub(crate) node_label_names: Vec<String>,
/// One counter cell per node, created at 0 when the node is registered.
/// NOTE(review): presumably bumped by `record_access`, which is defined outside this file — confirm.
pub(crate) access_counts: Vec<std::cell::Cell<u32>>,
/// Starts at 0; how it is advanced is not visible in this file.
pub(crate) query_epoch: u64,
}
impl Default for CsrIndex {
fn default() -> Self {
Self::new()
}
}
impl CsrIndex {
/// Creates an empty index: no nodes, no labels, no edges, weights disabled.
///
/// Both offset arrays start with the single sentinel entry `0`, so that
/// every node registered later contributes exactly one more offset.
pub fn new() -> Self {
    Self {
        // Name <-> id interning tables.
        node_to_id: HashMap::new(),
        id_to_node: Vec::new(),
        label_to_id: HashMap::new(),
        id_to_label: Vec::new(),

        // Dense storage.
        out_offsets: vec![0],
        in_offsets: vec![0],
        out_targets: DenseArray::default(),
        in_targets: DenseArray::default(),
        out_labels: DenseArray::default(),
        in_labels: DenseArray::default(),
        out_weights: None,
        in_weights: None,

        // Buffered edges and tombstones.
        buffer_out: Vec::new(),
        buffer_in: Vec::new(),
        buffer_out_weights: Vec::new(),
        buffer_in_weights: Vec::new(),
        deleted_edges: HashSet::new(),
        has_weights: false,

        // Node labels.
        node_label_bits: Vec::new(),
        node_label_to_id: HashMap::new(),
        node_label_names: Vec::new(),

        // Access bookkeeping.
        access_counts: Vec::new(),
        query_epoch: 0,
    }
}
pub(crate) fn ensure_node(&mut self, node: &str) -> u32 {
match self.node_to_id.entry(node.to_string()) {
Entry::Occupied(e) => *e.get(),
Entry::Vacant(e) => {
let id = self.id_to_node.len() as u32;
e.insert(id);
self.id_to_node.push(node.to_string());
self.out_offsets
.push(*self.out_offsets.last().unwrap_or(&0));
self.in_offsets.push(*self.in_offsets.last().unwrap_or(&0));
self.buffer_out.push(Vec::new());
self.buffer_in.push(Vec::new());
self.buffer_out_weights.push(Vec::new());
self.buffer_in_weights.push(Vec::new());
self.node_label_bits.push(0);
self.access_counts.push(std::cell::Cell::new(0));
id
}
}
}
/// Returns the interned id of the edge label `label`, interning it first
/// if it has not been seen before.
///
/// # Panics
///
/// Panics when more than `u16::MAX + 1` distinct edge labels are interned.
/// The previous `as u16` cast wrapped around instead, which made new labels
/// silently share ids with existing ones.
fn ensure_label(&mut self, label: &str) -> u16 {
    match self.label_to_id.entry(label.to_string()) {
        Entry::Occupied(slot) => *slot.get(),
        Entry::Vacant(slot) => {
            // Checked before anything is stored, so a failure leaves the
            // interning tables untouched.
            let id = u16::try_from(self.id_to_label.len())
                .expect("CsrIndex: edge label count exceeds the u16 id space");
            slot.insert(id);
            self.id_to_label.push(label.to_string());
            id
        }
    }
}
/// Returns the bit position assigned to the node label `label`, assigning
/// the next free position if the label is new.
///
/// Returns `None` when the label is new and all 64 positions of the
/// per-node `u64` bitmask are already taken.
fn ensure_node_label(&mut self, label: &str) -> Option<u8> {
    let known = self.node_label_to_id.get(label).copied();
    if known.is_some() {
        return known;
    }

    // Only positions 0..=63 fit into the per-node bitmask.
    let slot = u8::try_from(self.node_label_names.len())
        .ok()
        .filter(|&position| position < 64)?;

    self.node_label_to_id.insert(label.to_string(), slot);
    self.node_label_names.push(label.to_string());
    Some(slot)
}
/// Attaches the node label `label` to `node`, registering the node if it
/// does not exist yet.
///
/// Returns `false` (leaving the node's labels unchanged) when `label` is
/// new and no free label slot is left; the node is still registered in
/// that case.
pub fn add_node_label(&mut self, node: &str, label: &str) -> bool {
    let node_id = self.ensure_node(node);
    match self.ensure_node_label(label) {
        Some(bit) => {
            self.node_label_bits[node_id as usize] |= 1u64 << bit;
            true
        }
        None => false,
    }
}
/// Detaches the node label `label` from `node`.
///
/// Does nothing when either the node or the label is unknown.
pub fn remove_node_label(&mut self, node: &str, label: &str) {
    let ids = self
        .node_to_id
        .get(node)
        .zip(self.node_label_to_id.get(label));
    if let Some((&node_id, &label_id)) = ids {
        let mask = 1u64 << label_id;
        self.node_label_bits[node_id as usize] &= !mask;
    }
}
/// Reports whether the node with dense id `node_id` carries the node
/// label `label`.
///
/// Unknown labels and out-of-range node ids yield `false`.
pub fn node_has_label(&self, node_id: u32, label: &str) -> bool {
    self.node_label_to_id.get(label).is_some_and(|&bit| {
        let mask = 1u64 << bit;
        self.node_label_bits
            .get(node_id as usize)
            .is_some_and(|&bits| bits & mask != 0)
    })
}
/// Lists the names of all node labels attached to the node with dense id
/// `node_id`, in label-id order.
///
/// Out-of-range ids and unlabeled nodes yield an empty vector.
pub fn node_labels(&self, node_id: u32) -> Vec<&str> {
    let bits = self
        .node_label_bits
        .get(node_id as usize)
        .map_or(0, |&mask| mask);
    if bits == 0 {
        return Vec::new();
    }

    self.node_label_names
        .iter()
        .enumerate()
        .filter(|&(position, _)| bits & (1u64 << position) != 0)
        .map(|(_, name)| name.as_str())
        .collect()
}
/// Adds the edge `src -[label]-> dst` with weight `1.0`, registering any
/// node or label that is not known yet.
///
/// This entry point never switches weight tracking on by itself.
pub fn add_edge(&mut self, src: &str, label: &str, dst: &str) {
    let unit_weight = 1.0;
    let force_weights = false;
    self.add_edge_internal(src, label, dst, unit_weight, force_weights);
}
/// Adds the edge `src -[label]-> dst` with the given `weight`.
///
/// Weight tracking is requested only when `weight` differs from `1.0`.
pub fn add_edge_weighted(&mut self, src: &str, label: &str, dst: &str, weight: f64) {
    let needs_weight_tracking = weight != 1.0;
    self.add_edge_internal(src, label, dst, weight, needs_weight_tracking);
}
/// Shared implementation of `add_edge` and `add_edge_weighted`.
///
/// Registers the endpoints and the label, then makes sure the edge is
/// visible exactly once:
///
/// * already in the out-buffer of `src`: nothing to do;
/// * already in the dense arrays: any tombstone for it is cleared, so an
///   edge that was removed and is now re-added becomes visible again (it
///   keeps the weight stored with the dense entry; `weight` is ignored, as
///   for every duplicate add);
/// * otherwise: appended to the out-buffer of `src` and the in-buffer of
///   `dst`, together with `weight` while weight tracking is enabled.
///
/// `force_weights` turns weight tracking on before the edge is buffered.
fn add_edge_internal(
    &mut self,
    src: &str,
    label: &str,
    dst: &str,
    weight: f64,
    force_weights: bool,
) {
    let src_id = self.ensure_node(src);
    let dst_id = self.ensure_node(dst);
    let label_id = self.ensure_label(label);
    let key = (src_id, label_id, dst_id);

    let already_buffered = self.buffer_out[src_id as usize]
        .iter()
        .any(|&(l, d)| l == label_id && d == dst_id);
    if already_buffered {
        return;
    }

    if self.dense_has_edge(src_id, label_id, dst_id) {
        // The dense entry may be hidden by a tombstone from an earlier
        // removal. Returning without clearing it would leave the edge
        // invisible even though the caller just added it.
        self.deleted_edges.remove(&key);
        return;
    }

    if force_weights && !self.has_weights {
        self.enable_weights();
    }

    self.buffer_out[src_id as usize].push((label_id, dst_id));
    self.buffer_in[dst_id as usize].push((label_id, src_id));
    if self.has_weights {
        self.buffer_out_weights[src_id as usize].push(weight);
        self.buffer_in_weights[dst_id as usize].push(weight);
    }

    // A tombstone can also exist for an edge that only ever lived in the
    // buffers (bulk removal records one for every removed edge).
    self.deleted_edges.remove(&key);
}
/// Removes the edge `src -[label]-> dst`.
///
/// Unknown nodes or labels make this a no-op. A buffered copy of the edge
/// is dropped from both direction buffers (and from the weight buffers
/// while weights are tracked); a dense copy is hidden by recording a
/// tombstone.
pub fn remove_edge(&mut self, src: &str, label: &str, dst: &str) {
    let Some(&src_id) = self.node_to_id.get(src) else {
        return;
    };
    let Some(&dst_id) = self.node_to_id.get(dst) else {
        return;
    };
    let Some(&label_id) = self.label_to_id.get(label) else {
        return;
    };
    let (src_slot, dst_slot) = (src_id as usize, dst_id as usize);

    // Forward direction.
    let forward = self.buffer_out[src_slot]
        .iter()
        .position(|&(l, d)| l == label_id && d == dst_id);
    if let Some(pos) = forward {
        self.buffer_out[src_slot].swap_remove(pos);
        if self.has_weights {
            self.buffer_out_weights[src_slot].swap_remove(pos);
        }
    }

    // Reverse direction.
    let backward = self.buffer_in[dst_slot]
        .iter()
        .position(|&(l, s)| l == label_id && s == src_id);
    if let Some(pos) = backward {
        self.buffer_in[dst_slot].swap_remove(pos);
        if self.has_weights {
            self.buffer_in_weights[dst_slot].swap_remove(pos);
        }
    }

    // Dense entries cannot be removed in place; hide them instead.
    if self.dense_has_edge(src_id, label_id, dst_id) {
        self.deleted_edges.insert((src_id, label_id, dst_id));
    }
}
/// Removes every edge that starts or ends at `node` and returns how many
/// edges were removed. The node itself stays registered.
///
/// Returns `0` for an unknown node. Each removed edge gets a tombstone,
/// and its mirrored entry in the peer's buffer is dropped when present.
pub fn remove_node_edges(&mut self, node: &str) -> usize {
    let node_id = match self.node_to_id.get(node) {
        Some(&id) => id,
        None => return 0,
    };
    let slot = node_id as usize;

    // Outgoing edges: drop the mirrored in-buffer entry at each target.
    let outgoing: Vec<(u16, u32)> = self.iter_out_edges(node_id).collect();
    for &(label_id, dst_id) in &outgoing {
        let peer = dst_id as usize;
        let mirrored = self.buffer_in[peer]
            .iter()
            .position(|&(l, s)| l == label_id && s == node_id);
        if let Some(pos) = mirrored {
            self.buffer_in[peer].swap_remove(pos);
            if self.has_weights {
                self.buffer_in_weights[peer].swap_remove(pos);
            }
        }
        self.deleted_edges.insert((node_id, label_id, dst_id));
    }
    self.buffer_out[slot].clear();
    if self.has_weights {
        self.buffer_out_weights[slot].clear();
    }

    // Incoming edges are collected only now, so edges already handled
    // above (for example self-loops) are not seen a second time.
    let incoming: Vec<(u16, u32)> = self.iter_in_edges(node_id).collect();
    for &(label_id, src_id) in &incoming {
        let peer = src_id as usize;
        let mirrored = self.buffer_out[peer]
            .iter()
            .position(|&(l, d)| l == label_id && d == node_id);
        if let Some(pos) = mirrored {
            self.buffer_out[peer].swap_remove(pos);
            if self.has_weights {
                self.buffer_out_weights[peer].swap_remove(pos);
            }
        }
        self.deleted_edges.insert((src_id, label_id, node_id));
    }
    self.buffer_in[slot].clear();
    if self.has_weights {
        self.buffer_in_weights[slot].clear();
    }

    outgoing.len() + incoming.len()
}
/// Removes all edges of every node whose name starts with `prefix`.
///
/// Only edges are removed (via `remove_node_edges`); the matching nodes
/// keep their ids and stay registered.
pub fn remove_nodes_with_prefix(&mut self, prefix: &str) {
    // Names are copied out first because removing edges needs `&mut self`.
    let mut doomed: Vec<String> = Vec::new();
    for name in self.node_to_id.keys() {
        if name.starts_with(prefix) {
            doomed.push(name.clone());
        }
    }

    for name in &doomed {
        self.remove_node_edges(name);
    }
}
/// Returns the neighbours of `node` as `(edge label, neighbour name)`
/// pairs, optionally restricted to a single edge label.
///
/// * `label_filter` — `None` returns edges of every label; `Some(name)`
///   returns only edges with that label. A label that has never been
///   interned matches no edge, so the result is empty.
/// * `direction` — `Out` yields edge targets, `In` yields edge sources,
///   `Both` yields outgoing edges followed by incoming ones.
///
/// An unknown `node` yields an empty vector. A known node has the access
/// recorded via `record_access` regardless of the filter outcome.
pub fn neighbors(
    &self,
    node: &str,
    label_filter: Option<&str>,
    direction: Direction,
) -> Vec<(String, String)> {
    let Some(&node_id) = self.node_to_id.get(node) else {
        return Vec::new();
    };
    self.record_access(node_id);

    // Resolve the filter up front. An unknown label must not degrade to
    // "no filter", otherwise a typo in the label would return every edge.
    let label_id = match label_filter {
        None => None,
        Some(name) => match self.label_to_id.get(name) {
            Some(&id) => Some(id),
            None => return Vec::new(),
        },
    };

    let mut result = Vec::new();
    if matches!(direction, Direction::Out | Direction::Both) {
        for (lid, dst) in self.iter_out_edges(node_id) {
            if label_id.is_none_or(|f| f == lid) {
                result.push((
                    self.id_to_label[lid as usize].clone(),
                    self.id_to_node[dst as usize].clone(),
                ));
            }
        }
    }
    if matches!(direction, Direction::In | Direction::Both) {
        for (lid, src) in self.iter_in_edges(node_id) {
            if label_id.is_none_or(|f| f == lid) {
                result.push((
                    self.id_to_label[lid as usize].clone(),
                    self.id_to_node[src as usize].clone(),
                ));
            }
        }
    }
    result
}
/// Returns the neighbours of `node` as `(edge label, neighbour name)`
/// pairs, restricted to edges whose label is one of `label_filters`.
///
/// * `label_filters` — an empty slice means "no restriction". Labels that
///   have never been interned are ignored; if none of the requested labels
///   is known, no edge can match and the result is empty.
/// * `direction` — `Out` yields edge targets, `In` yields edge sources,
///   `Both` yields outgoing edges followed by incoming ones.
///
/// An unknown `node` yields an empty vector. A known node has the access
/// recorded via `record_access` regardless of the filter outcome.
pub fn neighbors_multi(
    &self,
    node: &str,
    label_filters: &[&str],
    direction: Direction,
) -> Vec<(String, String)> {
    let Some(&node_id) = self.node_to_id.get(node) else {
        return Vec::new();
    };
    self.record_access(node_id);

    let label_ids: Vec<u16> = label_filters
        .iter()
        .filter_map(|l| self.label_to_id.get(*l).copied())
        .collect();
    // "No restriction" is decided by the caller's slice, not by the
    // resolved ids: a non-empty request made up only of unknown labels
    // must match nothing rather than everything.
    let unrestricted = label_filters.is_empty();
    if !unrestricted && label_ids.is_empty() {
        return Vec::new();
    }
    let match_label = |lid: u16| unrestricted || label_ids.contains(&lid);

    let mut result = Vec::new();
    if matches!(direction, Direction::Out | Direction::Both) {
        for (lid, dst) in self.iter_out_edges(node_id) {
            if match_label(lid) {
                result.push((
                    self.id_to_label[lid as usize].clone(),
                    self.id_to_node[dst as usize].clone(),
                ));
            }
        }
    }
    if matches!(direction, Direction::In | Direction::Both) {
        for (lid, src) in self.iter_in_edges(node_id) {
            if match_label(lid) {
                result.push((
                    self.id_to_label[lid as usize].clone(),
                    self.id_to_node[src as usize].clone(),
                ));
            }
        }
    }
    result
}
pub fn add_node(&mut self, name: &str) -> u32 {
self.ensure_node(name)
}
/// Number of registered nodes.
pub fn node_count(&self) -> usize {
    let names = &self.id_to_node;
    names.len()
}
/// Reports whether a node named `node` is registered.
pub fn contains_node(&self, node: &str) -> bool {
    self.node_to_id.get(node).is_some()
}
/// Name of the node with dense id `dense_id`.
///
/// # Panics
///
/// Panics if `dense_id` does not belong to a registered node.
pub fn node_name(&self, dense_id: u32) -> &str {
    self.id_to_node[dense_id as usize].as_str()
}
/// Dense id of the node named `name`, or `None` if it is not registered.
pub fn node_id(&self, name: &str) -> Option<u32> {
    let id = self.node_to_id.get(name)?;
    Some(*id)
}
/// Name of the edge label with id `label_id`.
///
/// # Panics
///
/// Panics if `label_id` has not been handed out by this index.
pub fn label_name(&self, label_id: u16) -> &str {
    self.id_to_label[label_id as usize].as_str()
}
/// Id of the edge label named `name`, or `None` if it was never interned.
pub fn label_id(&self, name: &str) -> Option<u16> {
    let id = self.label_to_id.get(name)?;
    Some(*id)
}
/// Number of visible outgoing edges of the node with dense id `node_id`
/// (dense edges that are not tombstoned plus buffered edges).
pub fn out_degree(&self, node_id: u32) -> usize {
    let outgoing = self.iter_out_edges(node_id);
    outgoing.count()
}
/// Number of visible incoming edges of the node with dense id `node_id`
/// (dense edges that are not tombstoned plus buffered edges).
pub fn in_degree(&self, node_id: u32) -> usize {
    let incoming = self.iter_in_edges(node_id);
    incoming.count()
}
/// Total number of visible edges, computed by summing the out-degree of
/// every registered node.
pub fn edge_count(&self) -> usize {
    let mut total = 0usize;
    for idx in 0..self.id_to_node.len() {
        total += self.out_degree(idx as u32);
    }
    total
}
/// Flattens per-node adjacency lists into offset/target/label arrays.
///
/// `edges[i]` holds the `(label id, target id)` pairs of node `i`. The
/// result is `(offsets, targets, labels)` where `offsets` has
/// `edges.len() + 1` entries and node `i` owns the index range
/// `offsets[i]..offsets[i + 1]` of the other two vectors. Edge order within
/// each node is preserved.
pub(crate) fn build_dense(edges: &[Vec<(u16, u32)>]) -> (Vec<u32>, Vec<u32>, Vec<u16>) {
    let edge_total: usize = edges.iter().map(Vec::len).sum();
    let mut offsets = Vec::with_capacity(edges.len() + 1);
    let mut targets = Vec::with_capacity(edge_total);
    let mut labels = Vec::with_capacity(edge_total);

    // Leading sentinel; every row then appends its end offset.
    let mut running = 0u32;
    offsets.push(running);
    for adjacency in edges {
        targets.extend(adjacency.iter().map(|&(_, target)| target));
        labels.extend(adjacency.iter().map(|&(lid, _)| lid));
        running += adjacency.len() as u32;
        offsets.push(running);
    }

    (offsets, targets, labels)
}
/// Reports whether the dense arrays contain the edge `src -[label]-> dst`.
///
/// Tombstones are not consulted: an edge hidden by `deleted_edges` still
/// counts as present here.
fn dense_has_edge(&self, src: u32, label: u16, dst: u32) -> bool {
    self.dense_out_edges(src)
        .any(|(lid, target)| lid == label && target == dst)
}
/// Iterates over the dense outgoing edges of `node` as
/// `(label id, target id)` pairs, ignoring tombstones and buffered edges.
///
/// Nodes without an offset range (ids beyond the offset array) yield
/// nothing. The iterator reads the dense arrays lazily; no temporary
/// vector is built.
pub(crate) fn dense_out_edges(&self, node: u32) -> impl Iterator<Item = (u16, u32)> + '_ {
    let idx = node as usize;
    // An empty range stands in for "node has no dense slice", which keeps
    // a single concrete iterator type without collecting into a Vec.
    let range = match (self.out_offsets.get(idx), self.out_offsets.get(idx + 1)) {
        (Some(&start), Some(&end)) => start as usize..end as usize,
        _ => 0..0,
    };
    range.map(move |i| (self.out_labels[i], self.out_targets[i]))
}
/// Iterates over the dense incoming edges of `node` as
/// `(label id, source id)` pairs, ignoring tombstones and buffered edges.
///
/// Nodes without an offset range (ids beyond the offset array) yield
/// nothing. The iterator reads the dense arrays lazily; no temporary
/// vector is built.
pub(crate) fn dense_in_edges(&self, node: u32) -> impl Iterator<Item = (u16, u32)> + '_ {
    let idx = node as usize;
    // An empty range stands in for "node has no dense slice", which keeps
    // a single concrete iterator type without collecting into a Vec.
    let range = match (self.in_offsets.get(idx), self.in_offsets.get(idx + 1)) {
        (Some(&start), Some(&end)) => start as usize..end as usize,
        _ => 0..0,
    };
    range.map(move |i| (self.in_labels[i], self.in_targets[i]))
}
/// Iterates over all visible outgoing edges of `node` as
/// `(label id, target id)` pairs.
///
/// Dense edges come first, skipping those recorded in `deleted_edges`,
/// followed by the node's buffered edges. An id without a buffer slot
/// contributes no buffered edges. The buffer is borrowed, not cloned.
pub fn iter_out_edges(&self, node: u32) -> impl Iterator<Item = (u16, u32)> + '_ {
    let dense = self
        .dense_out_edges(node)
        .filter(move |&(lid, dst)| !self.deleted_edges.contains(&(node, lid, dst)));
    // `get` yields `None` for an unknown id; flattening the option turns
    // that into an empty sequence.
    let buffered = self
        .buffer_out
        .get(node as usize)
        .into_iter()
        .flatten()
        .copied();
    dense.chain(buffered)
}
/// Iterates over all visible incoming edges of `node` as
/// `(label id, source id)` pairs.
///
/// Dense edges come first, skipping those recorded in `deleted_edges`,
/// followed by the node's buffered edges. An id without a buffer slot
/// contributes no buffered edges. The buffer is borrowed, not cloned.
pub fn iter_in_edges(&self, node: u32) -> impl Iterator<Item = (u16, u32)> + '_ {
    let dense = self
        .dense_in_edges(node)
        .filter(move |&(lid, src)| !self.deleted_edges.contains(&(src, lid, node)));
    // `get` yields `None` for an unknown id; flattening the option turns
    // that into an empty sequence.
    let buffered = self
        .buffer_in
        .get(node as usize)
        .into_iter()
        .flatten()
        .copied();
    dense.chain(buffered)
}
}
// Unit tests for `CsrIndex`. Several of them call methods that are defined
// outside this file (`compact`, `checkpoint_to_bytes`, `from_checkpoint`,
// `estimated_memory_bytes`).
#[cfg(test)]
mod tests {
use super::*;
// Shared fixture: a -KNOWS-> b -KNOWS-> c -KNOWS-> d, plus a -WORKS-> e.
fn make_csr() -> CsrIndex {
let mut csr = CsrIndex::new();
csr.add_edge("a", "KNOWS", "b");
csr.add_edge("b", "KNOWS", "c");
csr.add_edge("c", "KNOWS", "d");
csr.add_edge("a", "WORKS", "e");
csr
}
// Unfiltered outgoing neighbours of "a" are "b" and "e".
#[test]
fn neighbors_out() {
let csr = make_csr();
let n = csr.neighbors("a", None, Direction::Out);
assert_eq!(n.len(), 2);
let dsts: Vec<&str> = n.iter().map(|(_, d)| d.as_str()).collect();
assert!(dsts.contains(&"b"));
assert!(dsts.contains(&"e"));
}
// A label filter keeps only the edges carrying that label.
#[test]
fn neighbors_filtered() {
let csr = make_csr();
let n = csr.neighbors("a", Some("KNOWS"), Direction::Out);
assert_eq!(n.len(), 1);
assert_eq!(n[0].1, "b");
}
// Incoming direction reports the edge source.
#[test]
fn neighbors_in() {
let csr = make_csr();
let n = csr.neighbors("b", None, Direction::In);
assert_eq!(n.len(), 1);
assert_eq!(n[0].1, "a");
}
// Removing a buffered edge makes it disappear from query results.
#[test]
fn incremental_remove() {
let mut csr = make_csr();
assert_eq!(csr.neighbors("a", Some("KNOWS"), Direction::Out).len(), 1);
csr.remove_edge("a", "KNOWS", "b");
assert_eq!(csr.neighbors("a", Some("KNOWS"), Direction::Out).len(), 0);
}
// Adding the same edge twice stores it once.
#[test]
fn duplicate_add_is_idempotent() {
let mut csr = CsrIndex::new();
csr.add_edge("a", "L", "b");
csr.add_edge("a", "L", "b");
assert_eq!(csr.neighbors("a", None, Direction::Out).len(), 1);
}
// After `compact()` the out-buffers are empty and queries still see the edges.
#[test]
fn compact_merges_buffer_into_dense() {
let mut csr = CsrIndex::new();
csr.add_edge("a", "L", "b");
csr.add_edge("b", "L", "c");
assert_eq!(csr.neighbors("a", None, Direction::Out).len(), 1);
csr.compact();
assert!(csr.buffer_out.iter().all(|b| b.is_empty()));
assert_eq!(csr.neighbors("a", None, Direction::Out).len(), 1);
assert_eq!(csr.neighbors("b", None, Direction::Out).len(), 1);
}
// An edge removed after compaction stays hidden, before and after the next compaction.
#[test]
fn compact_handles_deletes() {
let mut csr = CsrIndex::new();
csr.add_edge("a", "L", "b");
csr.add_edge("a", "L", "c");
csr.compact();
csr.remove_edge("a", "L", "b");
assert_eq!(csr.neighbors("a", None, Direction::Out).len(), 1);
csr.compact();
assert_eq!(csr.neighbors("a", None, Direction::Out).len(), 1);
assert_eq!(csr.neighbors("a", None, Direction::Out)[0].1, "c");
}
// 100 edges sharing one label intern that label exactly once.
#[test]
fn label_interning_reduces_memory() {
let mut csr = CsrIndex::new();
for i in 0..100 {
csr.add_edge(&format!("n{i}"), "FOLLOWS", &format!("n{}", i + 1));
}
assert_eq!(csr.id_to_label.len(), 1);
assert_eq!(csr.id_to_label[0], "FOLLOWS");
}
// The fixture contains four edges.
#[test]
fn edge_count() {
let csr = make_csr();
assert_eq!(csr.edge_count(), 4);
}
// A compacted index survives serialisation to bytes and back.
#[test]
fn checkpoint_roundtrip() {
let mut csr = make_csr();
csr.compact();
let bytes = csr.checkpoint_to_bytes();
assert!(!bytes.is_empty());
let restored = CsrIndex::from_checkpoint(&bytes).expect("roundtrip");
assert_eq!(restored.node_count(), csr.node_count());
assert_eq!(restored.edge_count(), csr.edge_count());
let n = restored.neighbors("a", Some("KNOWS"), Direction::Out);
assert_eq!(n.len(), 1);
assert_eq!(n[0].1, "b");
}
// The memory estimate of a non-empty index is positive.
#[test]
fn memory_estimation() {
let csr = make_csr();
let mem = csr.estimated_memory_bytes();
assert!(mem > 0);
}
// Degrees count buffered edges per direction.
#[test]
fn out_degree_and_in_degree() {
let mut csr = CsrIndex::new();
csr.add_edge("a", "L", "b");
csr.add_edge("a", "L", "c");
csr.add_edge("d", "L", "b");
let a_id = *csr.node_to_id.get("a").unwrap();
let b_id = *csr.node_to_id.get("b").unwrap();
assert_eq!(csr.out_degree(a_id), 2);
assert_eq!(csr.in_degree(b_id), 2);
}
// Removing a node's edges drops both outgoing and incoming ones and reports the total.
#[test]
fn remove_node_edges_all() {
let mut csr = CsrIndex::new();
csr.add_edge("a", "L", "b");
csr.add_edge("a", "L", "c");
csr.add_edge("d", "L", "a");
let removed = csr.remove_node_edges("a");
assert_eq!(removed, 3);
assert_eq!(csr.neighbors("a", None, Direction::Out).len(), 0);
assert_eq!(csr.neighbors("a", None, Direction::In).len(), 0);
}
// Registering the same node twice returns the same id and keeps one node.
#[test]
fn add_node_idempotent() {
let mut csr = CsrIndex::new();
let id1 = csr.add_node("x");
let id2 = csr.add_node("x");
assert_eq!(id1, id2);
assert_eq!(csr.node_count(), 1);
}
// Node labels are tracked per node: add, query, list in id order, remove.
#[test]
fn node_labels_bitset() {
let mut csr = CsrIndex::new();
csr.add_edge("alice", "KNOWS", "bob");
csr.add_edge("acme", "EMPLOYS", "alice");
assert!(csr.add_node_label("alice", "Person"));
assert!(csr.add_node_label("bob", "Person"));
assert!(csr.add_node_label("acme", "Company"));
let alice_id = csr.node_id("alice").unwrap();
let bob_id = csr.node_id("bob").unwrap();
let acme_id = csr.node_id("acme").unwrap();
assert!(csr.node_has_label(alice_id, "Person"));
assert!(!csr.node_has_label(alice_id, "Company"));
assert!(csr.node_has_label(acme_id, "Company"));
assert!(!csr.node_has_label(acme_id, "Person"));
// A second label on the same node coexists with the first one.
assert!(csr.add_node_label("alice", "Employee"));
assert!(csr.node_has_label(alice_id, "Person"));
assert!(csr.node_has_label(alice_id, "Employee"));
assert_eq!(csr.node_labels(alice_id), vec!["Person", "Employee"]);
// Removing one label leaves the other untouched.
csr.remove_node_label("alice", "Employee");
assert!(!csr.node_has_label(alice_id, "Employee"));
assert!(csr.node_has_label(alice_id, "Person"));
// A label that was never registered is reported as absent.
assert!(!csr.node_has_label(bob_id, "NonExistent"));
}
}