use std::collections::{HashSet, VecDeque};
use vyre::engine::dataflow::{bfs_reachability, MAX_FINDINGS};
use vyre::ops::graph::csr::{to_csr, CsrGraph};
fn gpu() -> Option<&'static (wgpu::Device, wgpu::Queue)> {
match vyre_wgpu::runtime::cached_device() {
Ok(pair) => Some(pair),
Err(error) => {
panic!("GPU required on this machine (RTX 5090 / 4090 available per project invariant) — do not silently skip: dataflow engine test: {error}");
None
}
}
}
/// Builds a test CSR graph with `node_count` nodes from `edges`.
///
/// For every id in `sinks`, that node's `node_data` word is overwritten with
/// `2 << 16`, which `reference` decodes as label 2.
///
/// # Panics
/// Panics if `to_csr` rejects the input, or if a sink id is out of range.
fn graph(node_count: usize, edges: &[(u32, u32)], sinks: &[u32]) -> CsrGraph {
    let mut built = to_csr(node_count, edges).expect("test CSR must build");
    sinks
        .iter()
        .map(|&sink| usize::try_from(sink).expect("sink index must fit usize"))
        .for_each(|index| built.node_data[index] = 2 << 16);
    built
}
/// CPU oracle for `bfs_reachability`.
///
/// Runs an independent depth-limited breadth-first search from each entry of
/// `sources`. It records `(source, node, depth)` for every node reached at a depth
/// in `1..=max_depth` whose label byte, `(node_data >> 16) & 0xff`, is 2 or 3.
/// A source is never reported for itself at depth 0. Each node is visited at
/// most once per source, at its shortest BFS depth. The result is sorted so it
/// can be compared directly with a sorted GPU result.
///
/// Panics on out-of-range node ids or malformed offsets (plain indexing).
fn reference(csr: &CsrGraph, sources: &[u32], max_depth: u32) -> Vec<(u32, u32, u32)> {
    let mut findings = Vec::new();
    for &root in sources {
        let mut visited = HashSet::new();
        visited.insert(root);
        let mut frontier = VecDeque::new();
        frontier.push_back((root, 0_u32));
        while let Some((node, depth)) = frontier.pop_front() {
            let idx = node as usize;
            // Decode the label before looking at the depth so an out-of-range
            // node is still caught by this lookup, even at depth 0.
            let is_sink = matches!((csr.node_data[idx] >> 16) & 0xff, 2 | 3);
            if is_sink && depth != 0 {
                findings.push((root, node, depth));
            }
            if depth < max_depth {
                let (lo, hi) = (csr.offsets[idx] as usize, csr.offsets[idx + 1] as usize);
                for &next in &csr.targets[lo..hi] {
                    // `insert` is false for nodes already queued at a shallower depth.
                    if visited.insert(next) {
                        frontier.push_back((next, depth + 1));
                    }
                }
            }
        }
    }
    findings.sort_unstable();
    findings
}
/// Cross-checks the GPU BFS against the CPU oracle on a small two-branch graph
/// (0 -> 1 -> 2 and 0 -> 3 -> 4, sinks at 2 and 4) from source 0 with depth 3.
#[test]
fn bfs_reachability_matches_cpu_reference_for_small_graph() {
    let Some((device, queue)) = gpu() else { return };
    let edges: [(u32, u32); 4] = [(0, 1), (1, 2), (0, 3), (3, 4)];
    let csr = graph(5, &edges, &[2, 4]);
    let gpu_findings = {
        let mut findings =
            bfs_reachability(device, queue, &csr, &[0], 3, None).expect("GPU BFS must run");
        // The oracle output is sorted; sort here too so emission order is not asserted.
        findings.sort_unstable();
        findings
    };
    let cpu_findings = reference(&csr, &[0], 3);
    assert_eq!(gpu_findings, cpu_findings);
}
/// Degenerate inputs: empty graph, single self-source sink, every reachable
/// node a sink, and no sinks at all. All cases use a depth budget of 4.
#[test]
fn bfs_reachability_handles_empty_single_all_match_and_no_match_inputs() {
    let Some((device, queue)) = gpu() else { return };
    let run = |csr: &CsrGraph, sources: &[u32]| {
        bfs_reachability(device, queue, csr, sources, 4, None).unwrap()
    };

    // No nodes and no sources: nothing can be reached.
    assert_eq!(run(&graph(0, &[], &[]), &[]), []);

    // A lone sink that is also the source sits at depth 0, which never counts.
    assert_eq!(run(&graph(1, &[], &[0]), &[0]), []);

    // Every non-source node on the path is a sink: GPU must agree with the oracle.
    let all_match = graph(3, &[(0, 1), (1, 2)], &[1, 2]);
    let mut got = run(&all_match, &[0]);
    got.sort_unstable();
    assert_eq!(got, reference(&all_match, &[0], 4));

    // Nodes are reachable but none is labelled as a sink.
    assert_eq!(run(&graph(3, &[(0, 1), (1, 2)], &[]), &[0]), []);
}
/// Invalid inputs must be rejected with an error whose message contains an
/// actionable `Fix:` hint: an out-of-range source, an offsets array too short
/// for the node data, and a graph with more nodes than `MAX_FINDINGS`.
#[test]
fn bfs_reachability_rejects_hostile_sources_and_oversized_outputs() {
    let Some((device, queue)) = gpu() else { return };

    // Source id 2 does not exist in a two-node graph.
    let two_nodes = graph(2, &[(0, 1)], &[1]);
    let source_error = bfs_reachability(device, queue, &two_nodes, &[2], 2, None).unwrap_err();
    assert!(source_error.to_string().contains("Fix:"));

    // A single offset entry cannot describe adjacency ranges for two nodes.
    let malformed = CsrGraph {
        node_data: vec![0, 0],
        targets: vec![0],
        offsets: vec![0],
    };
    let malformed_error = bfs_reachability(device, queue, &malformed, &[0], 2, None).unwrap_err();
    assert!(malformed_error.to_string().contains("Fix:"));

    // One node more than MAX_FINDINGS, with no edges at all.
    let oversized_nodes = MAX_FINDINGS + 1;
    let oversized = CsrGraph {
        node_data: vec![0; oversized_nodes],
        targets: Vec::new(),
        offsets: vec![0; oversized_nodes + 1],
    };
    let oversized_error = bfs_reachability(device, queue, &oversized, &[0], 1, None).unwrap_err();
    assert!(oversized_error.to_string().contains("Fix:"));
}
/// Repeated dispatches must each report only their own findings.
///
/// Re-running one identical query cannot detect stale output, because a reused
/// buffer would hold the very same answer. This test therefore cycles through
/// three queries with distinct expected results. One of them has no findings at
/// all, which catches an output buffer or counter that is not reset between
/// dispatches.
#[test]
fn bfs_reachability_repeated_dispatches_do_not_reuse_stale_results() {
    let Some((device, queue)) = gpu() else { return };
    // Chain 0 -> 1 -> 2 -> 3 with node 3 as the only sink.
    let chain = graph(4, &[(0, 1), (1, 2), (2, 3)], &[3]);
    // Same shape without sinks: every dispatch on it must come back empty.
    let no_sinks = graph(4, &[(0, 1), (1, 2), (2, 3)], &[]);
    let cases: [(&CsrGraph, &[u32], Vec<(u32, u32, u32)>); 3] = [
        (&chain, &[0], vec![(0, 3, 3)]),
        (&chain, &[1], vec![(1, 3, 2)]),
        (&no_sinks, &[0], Vec::new()),
    ];
    for (round, (csr, sources, expected)) in cases.iter().cycle().take(32).enumerate() {
        let mut got = bfs_reachability(device, queue, *csr, *sources, 4, None)
            .expect("repeated BFS dispatch");
        got.sort_unstable();
        assert_eq!(&got, expected, "dispatch round {round}");
    }
}