//! Tests for `vyre::engine::dataflow::bfs_reachability` against a CPU reference BFS.
//!
//! vyre 0.4.0 — GPU compute intermediate representation with a standard operation library.
use std::collections::{HashSet, VecDeque};

use vyre::engine::dataflow::{bfs_reachability, MAX_FINDINGS};
use vyre::ops::graph::csr::{to_csr, CsrGraph};

fn gpu() -> Option<&'static (wgpu::Device, wgpu::Queue)> {
    match vyre_wgpu::runtime::cached_device() {
        Ok(pair) => Some(pair),
        Err(error) => {
            panic!("GPU required on this machine (RTX 5090 / 4090 available per project invariant) — do not silently skip: dataflow engine test: {error}");
            None
        }
    }
}

/// Builds a test CSR graph from `edges` and tags every node listed in `sinks`.
///
/// Each sink's `node_data` entry is overwritten with label `2` placed in bits
/// 16..23, because the BFS kernel reads the label as `(node_data >> 16) & 0xff`.
/// `reference` in this file counts labels 2 and 3 as findings.
///
/// # Panics
///
/// Panics if `to_csr` rejects the input or a sink index is out of range for `node_data`.
fn graph(node_count: usize, edges: &[(u32, u32)], sinks: &[u32]) -> CsrGraph {
    let mut built = to_csr(node_count, edges).expect("test CSR must build");
    sinks.iter().for_each(|&sink| {
        let slot = usize::try_from(sink).expect("sink index must fit usize");
        built.node_data[slot] = 2 << 16;
    });
    built
}

/// CPU oracle for `bfs_reachability`.
///
/// Runs an independent breadth-first search from each entry of `sources`, visiting
/// each node at most once per source. It stops expanding at `max_depth` hops.
/// Every reached node whose label `(node_data >> 16) & 0xff` is 2 or 3 is recorded
/// as `(source, node, depth)`. The source itself (depth 0) is never recorded.
///
/// Returns all findings sorted so they can be compared with GPU output regardless
/// of order.
fn reference(csr: &CsrGraph, sources: &[u32], max_depth: u32) -> Vec<(u32, u32, u32)> {
    let mut findings = Vec::new();
    for &origin in sources {
        let mut visited: HashSet<u32> = HashSet::new();
        visited.insert(origin);
        let mut frontier: VecDeque<(u32, u32)> = VecDeque::new();
        frontier.push_back((origin, 0));

        while let Some((current, level)) = frontier.pop_front() {
            let idx = current as usize;
            let is_sink = matches!((csr.node_data[idx] >> 16) & 0xff, 2 | 3);
            if level > 0 && is_sink {
                findings.push((origin, current, level));
            }
            // Only expand neighbours while the depth budget allows another hop.
            if level < max_depth {
                let first = csr.offsets[idx] as usize;
                let last = csr.offsets[idx + 1] as usize;
                for &next in &csr.targets[first..last] {
                    if visited.insert(next) {
                        frontier.push_back((next, level + 1));
                    }
                }
            }
        }
    }
    findings.sort_unstable();
    findings
}

#[test]
fn bfs_reachability_matches_cpu_reference_for_small_graph() {
    let Some((device, queue)) = gpu() else { return };
    // Two branches out of node 0 (0 -> 1 -> 2 and 0 -> 3 -> 4), with sinks on both leaves.
    let csr = graph(5, &[(0, 1), (1, 2), (0, 3), (3, 4)], &[2, 4]);
    let expected = reference(&csr, &[0], 3);

    let mut actual =
        bfs_reachability(device, queue, &csr, &[0], 3, None).expect("GPU BFS must run");
    // `reference` returns sorted findings, so sort the GPU output before comparing.
    actual.sort_unstable();

    assert_eq!(actual, expected);
}

/// Edge cases: an empty graph, a single node that is its own source,
/// a graph where every non-source node is a sink, and a graph with no sinks.
#[test]
fn bfs_reachability_handles_empty_single_all_match_and_no_match_inputs() {
    let Some((device, queue)) = gpu() else { return };

    // No nodes and no sources: nothing can be reported.
    let empty = graph(0, &[], &[]);
    let got = bfs_reachability(device, queue, &empty, &[], 4, None)
        .expect("GPU BFS must run on an empty graph");
    assert!(got.is_empty(), "empty graph produced findings: {got:?}");

    // The only node is both the source and a sink; a depth-0 hit is not expected
    // to be reported.
    let single = graph(1, &[], &[0]);
    let got = bfs_reachability(device, queue, &single, &[0], 4, None)
        .expect("GPU BFS must run on a single-node graph");
    assert!(got.is_empty(), "single-node graph produced findings: {got:?}");

    // Every node reachable from the source is a sink: compare against the CPU oracle.
    let all_match = graph(3, &[(0, 1), (1, 2)], &[1, 2]);
    let mut got = bfs_reachability(device, queue, &all_match, &[0], 4, None)
        .expect("GPU BFS must run on an all-match graph");
    got.sort_unstable();
    assert_eq!(got, reference(&all_match, &[0], 4));

    // Nodes are reachable but none is labelled as a sink.
    let no_match = graph(3, &[(0, 1), (1, 2)], &[]);
    let got = bfs_reachability(device, queue, &no_match, &[0], 4, None)
        .expect("GPU BFS must run on a no-match graph");
    assert!(got.is_empty(), "graph without sinks produced findings: {got:?}");
}

/// Invalid inputs must be rejected with an error whose text contains a `Fix:` hint.
#[test]
fn bfs_reachability_rejects_hostile_sources_and_oversized_outputs() {
    let Some((device, queue)) = gpu() else { return };

    // Shared check so a failure reports which case broke and the actual error text.
    let assert_actionable = |case: &str, message: String| {
        assert!(
            message.contains("Fix:"),
            "{case}: error must contain a `Fix:` hint, got: {message}"
        );
    };

    // Source index 2 is out of range for a two-node graph.
    let csr = graph(2, &[(0, 1)], &[1]);
    let source_error = bfs_reachability(device, queue, &csr, &[2], 2, None).unwrap_err();
    assert_actionable("out-of-range source", source_error.to_string());

    // `offsets` describes zero nodes and zero edges, contradicting the one-entry
    // `targets` and the two-entry `node_data`.
    let malformed = CsrGraph {
        offsets: vec![0],
        targets: vec![0],
        node_data: vec![0, 0],
    };
    let malformed_error = bfs_reachability(device, queue, &malformed, &[0], 2, None).unwrap_err();
    assert_actionable("malformed CSR", malformed_error.to_string());

    // One node more than MAX_FINDINGS, with no edges and no sinks.
    // NOTE(review): presumably rejected because the output buffer is sized from the
    // node count and would exceed MAX_FINDINGS — confirm against bfs_reachability.
    let node_count = MAX_FINDINGS + 1;
    let oversized = CsrGraph {
        offsets: vec![0; node_count + 1],
        targets: Vec::new(),
        node_data: vec![0; node_count],
    };
    let oversized_error = bfs_reachability(device, queue, &oversized, &[0], 1, None).unwrap_err();
    assert_actionable("oversized output", oversized_error.to_string());
}

#[test]
fn bfs_reachability_repeated_dispatches_do_not_reuse_stale_results() {
    let Some((device, queue)) = gpu() else { return };
    // Linear chain 0 -> 1 -> 2 -> 3 with one sink at the tail. Exactly one finding
    // is expected, so the comparison below does not depend on output order.
    let csr = graph(4, &[(0, 1), (1, 2), (2, 3)], &[3]);

    // Dispatching many times should surface any results left over from earlier runs.
    for _round in 0..32 {
        let findings = bfs_reachability(device, queue, &csr, &[0], 4, None)
            .expect("repeated BFS dispatch");
        assert_eq!(findings, vec![(0, 3, 3)]);
    }
}