use std::collections::HashSet;
use crate::engine::graph::csr::CsrIndex;
use crate::engine::graph::edge_store::Direction;
/// Upper bound on the number of `(node, path)` rows a single variable-length
/// expansion may return. The expansion stops and reports `truncated` when this
/// cap is hit.
const MAX_VARLEN_RESULTS: usize = 100_000;
/// Upper bound on the number of nodes queued for the next BFS layer of a
/// variable-length expansion. The expansion stops and reports `truncated` when
/// this cap is hit.
const MAX_VARLEN_FRONTIER: usize = 100_000;

/// Outcome of a variable-length (`*min..max`) expansion.
#[derive(Debug)]
pub(super) struct VarLenExpansion {
    /// Reached node ids, each paired with its path string (empty when the caller
    /// did not request paths).
    pub results: Vec<(u32, String)>,
    /// `true` when the expansion stopped early because a size cap was hit, so
    /// `results` may be incomplete.
    pub truncated: bool,
}
/// Expands a variable-length pattern (`*min_hops..max_hops`) breadth-first from `source`.
///
/// All paths share one `visited` set. As a result:
/// - Every node is reported at most once, at the shallowest depth at which BFS discovers it.
/// - Each node's neighbours are collected at most once.
/// - A node is returned only when that shallowest depth lies in `min_hops..=max_hops`. For
///   example, a node first reached at depth 1 is *not* returned again for a `*2..2` pattern.
/// - The source itself is returned (at depth 0) only when `min_hops == 0`.
///
/// * `label_filter` – when `Some(name)`, only edges whose label name equals `name` are
///   followed.
/// * `direction` – which adjacency lists to follow (see `collect_neighbors`).
/// * `want_path` – when `true`, every result carries the node names along its BFS path
///   joined with `"->"` (regardless of `direction`). When `false` the string is empty
///   and no names are looked up.
///
/// Returned rows are capped at `MAX_VARLEN_RESULTS` and each BFS layer at
/// `MAX_VARLEN_FRONTIER`. `truncated` is set only when a newly discovered node actually
/// had to be dropped because of one of those caps; the expansion stops at that point.
pub(super) fn expand_variable_length(
    csr: &CsrIndex,
    source: u32,
    label_filter: Option<&str>,
    direction: Direction,
    min_hops: usize,
    max_hops: usize,
    want_path: bool,
) -> VarLenExpansion {
    // Node names are only looked up (and allocated) when the caller wants paths.
    let name_of = |id: u32| -> String {
        if want_path {
            csr.node_name_raw(id).to_string()
        } else {
            String::new()
        }
    };

    let src_name = name_of(source);
    let mut results: Vec<(u32, String)> = Vec::new();
    if min_hops == 0 {
        // `*0..k` includes the source itself at depth 0.
        results.push((source, src_name.clone()));
    }
    if max_hops == 0 {
        return VarLenExpansion {
            results,
            truncated: false,
        };
    }

    let mut visited: HashSet<u32> = HashSet::new();
    visited.insert(source);
    let mut frontier: Vec<(u32, String)> = vec![(source, src_name)];
    let mut truncated = false;

    'outer: for depth in 1..=max_hops {
        if frontier.is_empty() {
            break;
        }
        // Whether nodes discovered at this depth are reported / expanded further.
        let report = depth >= min_hops;
        let expand = depth < max_hops;
        let mut next_frontier: Vec<(u32, String)> = Vec::new();
        for (node, path) in &frontier {
            for (_, dst) in collect_neighbors(csr, *node, label_filter, direction) {
                // BFS order guarantees the first visit is the shallowest one.
                if !visited.insert(dst) {
                    continue;
                }
                let new_path = if want_path {
                    format!("{path}->{}", name_of(dst))
                } else {
                    String::new()
                };
                // Caps are checked *before* storing. `truncated` therefore means a node
                // was really dropped, not that the reachable set happened to fill a cap
                // exactly.
                if report {
                    if results.len() >= MAX_VARLEN_RESULTS {
                        truncated = true;
                        break 'outer;
                    }
                    results.push((dst, new_path.clone()));
                }
                if expand {
                    if next_frontier.len() >= MAX_VARLEN_FRONTIER {
                        truncated = true;
                        break 'outer;
                    }
                    next_frontier.push((dst, new_path));
                }
            }
        }
        frontier = next_frontier;
    }

    VarLenExpansion { results, truncated }
}
/// Returns the `(label_id, neighbour_id)` pairs adjacent to `node`.
///
/// Direction handling:
/// - `Direction::Out` follows outgoing edges.
/// - `Direction::In` follows incoming edges.
/// - `Direction::Both` follows outgoing edges first, then incoming ones.
///
/// When `label_filter` is `Some(name)`, only edges whose label name equals `name` are
/// kept. No deduplication is performed: a neighbour joined by several edges appears
/// once per edge.
pub(super) fn collect_neighbors(
    csr: &CsrIndex,
    node: u32,
    label_filter: Option<&str>,
    direction: Direction,
) -> Vec<(u32, u32)> {
    // Exhaustive on purpose: a new `Direction` variant must be handled here explicitly.
    let (follow_out, follow_in) = match direction {
        Direction::Out => (true, false),
        Direction::In => (false, true),
        Direction::Both => (true, true),
    };
    // Skip the label lookup entirely when no filter was supplied.
    let keeps = |label_id: u32| {
        label_filter.is_none() || csr_label_matches(csr, label_id, label_filter)
    };

    let mut adjacent: Vec<(u32, u32)> = Vec::new();
    if follow_out {
        adjacent.extend(
            csr.iter_out_edges_raw(node)
                .filter(|&(label_id, _)| keeps(label_id)),
        );
    }
    if follow_in {
        adjacent.extend(
            csr.iter_in_edges_raw(node)
                .filter(|&(label_id, _)| keeps(label_id)),
        );
    }
    adjacent
}
/// Reports whether edge label `label_id` satisfies `filter`.
///
/// `None` accepts every label. `Some(name)` accepts only labels whose name, as
/// returned by `CsrIndex::label_name`, equals `name`.
fn csr_label_matches(csr: &CsrIndex, label_id: u32, filter: Option<&str>) -> bool {
    if let Some(wanted) = filter {
        csr.label_name(label_id) == wanted
    } else {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::graph::csr::CsrIndex;
    use crate::engine::graph::edge_store::Direction;

    /// Collects the destination ids of an expansion's rows into a set.
    fn dst_set(results: &[(u32, String)]) -> std::collections::HashSet<u32> {
        results.iter().map(|(d, _)| *d).collect()
    }

    #[test]
    fn variable_length_expansion_dedups_nodes_across_paths() {
        let mut csr = CsrIndex::new();
        let nodes = ["a", "b", "c", "d", "e", "f"];
        for &src in &nodes {
            for &dst in &nodes {
                if src != dst {
                    csr.add_edge(src, "l", dst).unwrap();
                }
            }
        }
        let a = csr.node_id_raw("a").unwrap();
        let VarLenExpansion { results, truncated } =
            expand_variable_length(&csr, a, Some("l"), Direction::Out, 1, 8, false);
        let distinct_dsts = dst_set(&results);
        // On the complete graph every other node is reached at depth 1. BFS dedup must
        // report each one exactly once instead of enumerating the 5^d distinct paths.
        assert_eq!(
            results.len(),
            distinct_dsts.len(),
            "each node must be reported at most once; got {} rows for {} distinct nodes",
            results.len(),
            distinct_dsts.len()
        );
        assert_eq!(
            distinct_dsts.len(),
            nodes.len() - 1,
            "every node except the source must be reachable in *1..8; got {distinct_dsts:?}"
        );
        assert!(
            !distinct_dsts.contains(&a),
            "the source must not be reported for *1..8; got {distinct_dsts:?}"
        );
        assert!(!truncated, "a 6-node graph must not hit the expansion caps");
    }

    #[test]
    fn variable_length_expansion_includes_source_at_zero_hops() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "l", "b").unwrap();
        csr.add_edge("b", "l", "c").unwrap();
        let a = csr.node_id_raw("a").unwrap();
        let b = csr.node_id_raw("b").unwrap();
        let c = csr.node_id_raw("c").unwrap();
        let VarLenExpansion { results, truncated } =
            expand_variable_length(&csr, a, Some("l"), Direction::Out, 0, 2, false);
        let dsts = dst_set(&results);
        assert!(
            dsts.contains(&a),
            "*0..k must include the source node at depth 0; got dsts {dsts:?}"
        );
        let expected: std::collections::HashSet<u32> = [a, b, c].into_iter().collect();
        assert_eq!(dsts, expected, "*0..2 on a->b->c must return {{a, b, c}}");
        assert_eq!(results.len(), 3, "each node must appear exactly once");
        assert!(!truncated);
    }

    #[test]
    fn variable_length_expansion_zero_max_hops_returns_at_most_the_source() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "l", "b").unwrap();
        let a = csr.node_id_raw("a").unwrap();

        let only_source = expand_variable_length(&csr, a, Some("l"), Direction::Out, 0, 0, false);
        assert_eq!(only_source.results.len(), 1, "*0..0 must return just the source");
        assert_eq!(only_source.results[0].0, a);
        assert!(!only_source.truncated);

        let empty = expand_variable_length(&csr, a, Some("l"), Direction::Out, 1, 0, false);
        assert!(empty.results.is_empty(), "*1..0 must return nothing");
        assert!(!empty.truncated);
    }

    #[test]
    fn variable_length_expansion_exact_length_returns_only_that_depth() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "l", "b").unwrap();
        csr.add_edge("b", "l", "c").unwrap();
        csr.add_edge("c", "l", "d").unwrap();
        let VarLenExpansion { results, truncated } = expand_variable_length(
            &csr,
            csr.node_id_raw("a").unwrap(),
            Some("l"),
            Direction::Out,
            2,
            2,
            false,
        );
        let dsts = dst_set(&results);
        let c = csr.node_id_raw("c").unwrap();
        let expected: std::collections::HashSet<u32> = [c].into_iter().collect();
        assert_eq!(
            dsts, expected,
            "*2..2 must return exactly the depth-2 reachable set {{c}}; got {dsts:?}"
        );
        assert_eq!(results.len(), 1);
        assert!(!truncated);
    }

    #[test]
    fn variable_length_expansion_follows_only_matching_labels() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "l", "b").unwrap();
        csr.add_edge("a", "other", "c").unwrap();
        csr.add_edge("b", "other", "d").unwrap();
        let a = csr.node_id_raw("a").unwrap();
        let b = csr.node_id_raw("b").unwrap();
        let VarLenExpansion { results, truncated } =
            expand_variable_length(&csr, a, Some("l"), Direction::Out, 1, 3, false);
        let expected: std::collections::HashSet<u32> = [b].into_iter().collect();
        assert_eq!(
            dst_set(&results),
            expected,
            "only edges labelled \"l\" may be followed"
        );
        assert!(!truncated);
    }

    #[test]
    fn variable_length_expansion_caps_frontier_per_hop() {
        // A wide star that stays well below both caps: every leaf is reported exactly
        // once and nothing is flagged as truncated.
        let mut csr = CsrIndex::new();
        const LEAVES: usize = 5_000;
        for i in 0..LEAVES {
            csr.add_edge("root", "l", &format!("leaf_{i}")).unwrap();
        }
        let VarLenExpansion { results, truncated } = expand_variable_length(
            &csr,
            csr.node_id_raw("root").unwrap(),
            Some("l"),
            Direction::Out,
            1,
            5,
            false,
        );
        assert_eq!(
            results.len(),
            LEAVES,
            "star with {LEAVES} leaves must return exactly {LEAVES} results; got {}",
            results.len()
        );
        assert_eq!(dst_set(&results).len(), LEAVES, "leaves must not repeat");
        assert!(!truncated, "{LEAVES} leaves is far below the expansion caps");
    }
}