use std::collections::HashSet;
use vyre_foundation::execution_plan::fusion::fuse_programs;
use vyre_foundation::ir::{DataType, Program};
use crate::bitset::bitset_words;
use crate::bitset::or::bitset_or;
use crate::graph::csr_forward_traverse::csr_forward_traverse;
use crate::graph::program_graph::ProgramGraphShape;
/// Operation identifier for the reachability program builder.
///
/// `reachable_program` hands this string to `crate::invalid_output_program`
/// when fusing its sub-programs fails.
pub const OP_ID: &str = "vyre-primitives::graph::reachable_program";
/// Error returned by `reachable` when an edge endpoint is not a valid node id.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UnknownNode {
    /// Position of the offending edge within the `edges` slice.
    pub index: usize,
    /// The endpoint value that was not below `node_count`.
    pub node: u32,
    /// The node count the edge was checked against.
    pub node_count: u32,
}

impl std::fmt::Display for UnknownNode {
    /// Renders the diagnostic together with its remediation hint.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The struct is `Copy`, so pull the fields out once and name them in the template.
        let Self {
            index,
            node,
            node_count,
        } = *self;
        write!(
            f,
            "reachable: edges[{index}] references node {node} but node_count = {node_count}. \
             Fix: callers must deduplicate and bounds-check edges before \
             calling this primitive."
        )
    }
}

// No underlying cause to expose, so the default `Error` methods suffice.
impl std::error::Error for UnknownNode {}
/// Computes the set of nodes reachable from `sources` by following `edges` forward.
///
/// `edges` holds `(from, to)` pairs over node ids in `0..node_count`. Every
/// in-range source is part of the result, along with each node reachable from
/// it. A source that is not below `node_count` is not an error: it is never
/// traversed, but it is copied into the returned set unchanged.
///
/// # Errors
///
/// Returns [`UnknownNode`] for the first edge (in slice order) with an endpoint
/// that is not below `node_count`. Within an edge, `from` is checked before `to`.
pub fn reachable(
    node_count: u32,
    edges: &[(u32, u32)],
    sources: &[u32],
) -> Result<HashSet<u32>, UnknownNode> {
    // Marks the end of a per-node edge chain.
    const END: usize = usize::MAX;
    let total = node_count as usize;

    // Linked adjacency: `first_edge[v]` is the most recently added edge leaving `v`,
    // and `next_in_chain[e]` is the edge that was at the head before `e` was added.
    let mut first_edge = vec![END; total];
    let mut targets = Vec::with_capacity(edges.len());
    let mut next_in_chain = Vec::with_capacity(edges.len());
    for (index, &(from, to)) in edges.iter().enumerate() {
        let offending = [from, to]
            .iter()
            .copied()
            .find(|&endpoint| (endpoint as usize) >= total);
        if let Some(node) = offending {
            return Err(UnknownNode {
                index,
                node,
                node_count,
            });
        }
        // Every earlier edge was stored, so the slot for this one is its length-so-far.
        let edge_id = targets.len();
        targets.push(to);
        next_in_chain.push(std::mem::replace(&mut first_edge[from as usize], edge_id));
    }

    // Stack-driven traversal seeded with all sources.
    let mut seen = vec![false; total];
    let mut unknown_sources = Vec::new();
    let mut pending = sources.to_vec();
    while let Some(node) = pending.pop() {
        let slot = node as usize;
        if slot >= total {
            // Only sources can land here: edge targets were validated above.
            unknown_sources.push(node);
            continue;
        }
        // Mark the node and skip it if it had already been expanded.
        if std::mem::replace(&mut seen[slot], true) {
            continue;
        }
        let mut cursor = first_edge[slot];
        while cursor != END {
            let target = targets[cursor];
            if !seen[target as usize] {
                pending.push(target);
            }
            cursor = next_in_chain[cursor];
        }
    }

    let reached_count = seen.iter().filter(|&&hit| hit).count();
    let mut reached = HashSet::with_capacity(reached_count.saturating_add(unknown_sources.len()));
    reached.extend(
        seen.iter()
            .enumerate()
            .filter(|&(_, &hit)| hit)
            .map(|(id, _)| id as u32),
    );
    reached.extend(unknown_sources);
    Ok(reached)
}
#[must_use]
pub fn reachable_program(
node_count: u32,
edge_count: u32,
sources_buf: &str,
reach_out: &str,
max_iters: u32,
) -> Program {
let shape = ProgramGraphShape::new(node_count, edge_count);
let words = bitset_words(node_count);
let frontier_a = "reach_frontier_a";
let frontier_b = "reach_frontier_b";
let mut arms: Vec<Program> =
Vec::with_capacity((max_iters as usize).saturating_mul(2).saturating_add(1));
arms.push(bitset_or(sources_buf, reach_out, reach_out, words));
for i in 0..max_iters {
let in_buf = if i == 0 {
sources_buf
} else if i % 2 == 1 {
frontier_a
} else {
frontier_b
};
let out_buf = if i % 2 == 0 { frontier_a } else { frontier_b };
arms.push(csr_forward_traverse(shape, in_buf, out_buf, u32::MAX));
arms.push(bitset_or(out_buf, reach_out, reach_out, words));
}
fuse_programs(&arms).unwrap_or_else(|error| {
crate::invalid_output_program(
OP_ID,
reach_out,
DataType::U32,
format!("Fix: reachable_program composition failed: {error}"),
)
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the expected result set from a list of node ids.
    fn node_set(ids: &[u32]) -> HashSet<u32> {
        let mut set = HashSet::with_capacity(ids.len());
        set.extend(ids);
        set
    }

    /// With no sources there is nothing to seed the traversal.
    #[test]
    fn empty_sources_reach_nothing() {
        let reached = reachable(3, &[(0, 1), (1, 2)], &[]).unwrap();
        assert_eq!(reached.len(), 0);
    }

    /// A source at the head of a chain reaches every node on it.
    #[test]
    fn single_source_reaches_chain() {
        let reached = reachable(3, &[(0, 1), (1, 2)], &[0]).unwrap();
        assert_eq!(reached, node_set(&[0, 1, 2]));
    }

    /// A two-node cycle must not loop forever.
    #[test]
    fn cycle_terminates() {
        let reached = reachable(2, &[(0, 1), (1, 0)], &[0]).unwrap();
        assert_eq!(reached, node_set(&[0, 1]));
    }

    /// Nodes in a component without a source stay out of the result.
    #[test]
    fn disconnected_source_not_included() {
        let reached = reachable(4, &[(0, 1), (2, 3)], &[0]).unwrap();
        assert_eq!(reached, node_set(&[0, 1]));
        for absent in &[2u32, 3] {
            assert!(!reached.contains(absent));
        }
    }

    /// An out-of-range source triggers no traversal but is echoed back in the result.
    #[test]
    fn unknown_source_is_noop() {
        let reached = reachable(2, &[(0, 1)], &[7]).unwrap();
        assert_eq!(reached, node_set(&[7]));
    }

    /// An edge with an endpoint outside the node range yields a descriptive error.
    #[test]
    fn out_of_range_edge_is_reported_not_silently_dropped() {
        let UnknownNode {
            index,
            node,
            node_count,
        } = reachable(3, &[(0, 1), (5, 1)], &[0]).unwrap_err();
        assert_eq!(index, 1);
        assert_eq!(node, 5);
        assert_eq!(node_count, 3);
    }

    /// The composed program is non-trivial and declares every buffer it is expected to use.
    #[test]
    fn reachable_program_smoke() {
        let program = reachable_program(4, 4, "sources", "reach", 2);
        assert!(!program.is_explicit_noop());
        assert!(!program.buffers().is_empty());
        assert!(!program.entry().is_empty());

        let names: Vec<&str> = program
            .buffers()
            .iter()
            .map(|buffer| buffer.name())
            .collect();
        let expected_buffers = [
            "pg_edge_offsets",
            "pg_edge_targets",
            "sources",
            "reach",
            "reach_frontier_a",
            "reach_frontier_b",
        ];
        for expected in &expected_buffers {
            assert!(names.contains(expected));
        }
    }

    /// With zero rounds the program still contains the seeding stage.
    #[test]
    fn reachable_program_zero_iters_seeds_only() {
        let program = reachable_program(4, 4, "sources", "reach", 0);
        assert!(!program.is_explicit_noop());
        assert!(!program.buffers().is_empty());
    }
}