#![allow(missing_docs, unused_imports, unused_variables, unreachable_patterns, clippy::all)]
use crate::error::{Error, Result};
use crate::ops::AlgebraicLaw;
use crate::ir::validate::limits::MAX_GRAPH_NODES;
use crate::ops::graph::bfs::Bfs;
use crate::ops::graph::csr::{to_csr, CsrGraph};
use crate::ops::{OpSpec, BYTES_TO_U32_OUTPUTS, U32_INPUTS};
use std::collections::VecDeque;
/// Runs a depth-limited breadth-first search from `source` and appends every
/// other node it dequeues to `reached` as `(source, node, depth)`.
///
/// Traversal rules:
/// - `source` itself is never reported.
/// - A node at `depth >= max_depth` is reported but its out-edges are not
///   followed.
/// - A node whose `csr.node_data` entry satisfies [`is_sanitizer`] is reported
///   but its out-edges are not followed (this also applies to `source`).
/// - Edge targets that are not below `csr.node_count()` are ignored.
///
/// A `source` that is not a valid node index (`>= csr.node_count()`) leaves
/// `reached` untouched, matching how `reachable_nodes` skips such sources.
///
/// # Panics
///
/// Accesses to `csr.offsets`, `csr.node_data` and `csr.targets` use plain
/// indexing, so a CSR whose arrays are inconsistent with `node_count()` can
/// still panic. `reachable_nodes` validates the graph before calling this.
pub fn bfs_from_source(
    csr: &CsrGraph,
    source: u32,
    max_depth: u32,
    reached: &mut Vec<ReachableNode>,
) {
    let node_count = csr.node_count();
    let Ok(source_index) = usize::try_from(source) else {
        return;
    };
    // Guard before touching `visited`: this function is public, so it cannot
    // rely on the caller having range-checked `source`.
    if source_index >= node_count {
        return;
    }

    let mut visited = vec![false; node_count];
    visited[source_index] = true;

    let mut queue = VecDeque::new();
    queue.push_back((source, 0u32));

    while let Some((node, depth)) = queue.pop_front() {
        if node != source {
            reached.push((source, node, depth));
        }
        let Ok(node_index) = usize::try_from(node) else {
            continue;
        };
        // Depth-limited and sanitizer nodes are reported above but not expanded.
        if depth >= max_depth || is_sanitizer(csr.node_data[node_index]) {
            continue;
        }
        let Some(next_node_index) = node_index.checked_add(1) else {
            continue;
        };
        // `offsets[i]..offsets[i + 1]` is the range of `targets` holding the
        // out-edges of node `i`.
        let Ok(start) = usize::try_from(csr.offsets[node_index]) else {
            continue;
        };
        let Ok(end) = usize::try_from(csr.offsets[next_node_index]) else {
            continue;
        };
        for &target in &csr.targets[start..end] {
            let Ok(target_idx) = usize::try_from(target) else {
                continue;
            };
            // Mark on enqueue so each node enters the queue at most once, at
            // its shortest distance from `source`.
            if target_idx < node_count && !visited[target_idx] {
                visited[target_idx] = true;
                queue.push_back((target, depth.saturating_add(1)));
            }
        }
    }
}
impl ReachabilityOp {
/// Operation spec for this op, registered under the id `"graph.reachability"`.
///
/// Built with `OpSpec::composition` from `U32_INPUTS`, `BYTES_TO_U32_OUTPUTS`,
/// the module-level `LAWS`, and `Bfs::program`.
// NOTE(review): the role of each argument (inputs, outputs, laws, program
// builder) is inferred from the names only — confirm against `OpSpec::composition`.
pub const SPEC: OpSpec = OpSpec::composition(
"graph.reachability",
U32_INPUTS,
BYTES_TO_U32_OUTPUTS,
LAWS,
Bfs::program,
);
}
/// Returns `true` when the tag stored in bits 16..=23 of `node_data` equals 4,
/// the value this module treats as marking a sanitizer node.
///
/// All other bits of `node_data` are ignored.
pub fn is_sanitizer(node_data: u32) -> bool {
    const TAG_SHIFT: u32 = 16;
    const TAG_MASK: u32 = 0xFF;
    const SANITIZER_TAG: u32 = 4;

    let tag = (node_data >> TAG_SHIFT) & TAG_MASK;
    tag == SANITIZER_TAG
}
/// Algebraic laws passed to `ReachabilityOp::SPEC`: a single `Bounded` law
/// with `lo = 0` and `hi = u32::MAX`, i.e. the full `u32` range.
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
lo: 0,
hi: u32::MAX,
}];
/// Field-less marker type for the reachability op; its associated `SPEC`
/// constant carries the op's specification.
#[derive(Debug, Default, Clone, Copy)]
pub struct ReachabilityOp;
/// One reachability result, laid out as `(source, node, depth)`: `node` was
/// reached from `source` by a BFS that dequeued it at `depth`.
pub type ReachableNode = (u32, u32, u32);
/// Collects, for every entry of `sources`, the nodes a depth-limited BFS
/// reaches from it.
///
/// The graph is checked with `validate_reachability_csr` first. Sources are
/// then processed in slice order; any source that is not a valid node index
/// (`>= csr.node_count()`) is silently skipped. Results from all searches are
/// concatenated into one vector of `(source, node, depth)` tuples.
///
/// # Errors
///
/// Returns whatever error `validate_reachability_csr` reports for `csr`.
pub fn reachable_nodes(
    csr: &CsrGraph,
    sources: &[u32],
    max_depth: u32,
) -> Result<Vec<ReachableNode>> {
    let total_nodes = csr.node_count();
    validate_reachability_csr(csr, total_nodes)?;

    let in_range_sources = sources.iter().copied().filter(|&candidate| {
        usize::try_from(candidate).map_or(false, |index| index < total_nodes)
    });

    let mut collected = Vec::new();
    for start in in_range_sources {
        bfs_from_source(csr, start, max_depth, &mut collected);
    }
    Ok(collected)
}
pub fn validate_reachability_csr(csr: &CsrGraph, node_count: usize) -> Result<()> {
if node_count > MAX_GRAPH_NODES {
return Err(Error::Csr {
message: format!(
"GraphTooLarge: node_count {node_count} exceeds {MAX_GRAPH_NODES}. Fix: split the graph before CPU reachability."
),
});
}
let expected_offsets = node_count.checked_add(1).ok_or_else(|| Error::Csr {
message: "CsrInvalid: node_count + 1 overflows usize. Fix: split the graph before CPU reachability.".to_string(),
})?;
if csr.offsets.len() != expected_offsets {
return Err(Error::Csr {
message: format!(
"CsrInvalid: offsets length {} does not equal node_count + 1 ({expected_offsets}). Fix: rebuild CSR offsets before CPU reachability.",
csr.offsets.len()
),
});
}
csr.validate()
}
/// Two sources over a four-node graph: results are grouped per source in the
/// order the sources were given, each tuple being `(source, node, depth)`.
#[test]
pub fn finds_reachable_nodes_for_each_source() -> crate::error::Result<()> {
    let graph = to_csr(4, &[(0, 1), (1, 2), (3, 2)])?;
    let found = reachable_nodes(&graph, &[0, 3], 8)?;
    let expected: Vec<ReachableNode> = vec![(0, 1, 1), (0, 2, 2), (3, 2, 1)];
    assert_eq!(found, expected);
    Ok(())
}
/// With `max_depth = 1` on the chain 0 -> 1 -> 2, only the direct neighbour of
/// the source is reported.
#[test]
pub fn respects_max_depth() -> crate::error::Result<()> {
    let chain = to_csr(3, &[(0, 1), (1, 2)])?;
    let within_one_hop = reachable_nodes(&chain, &[0], 1)?;
    let expected: Vec<ReachableNode> = vec![(0, 1, 1)];
    assert_eq!(within_one_hop, expected);
    Ok(())
}