vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
#![allow(missing_docs, unused_imports, unused_variables, unreachable_patterns, clippy::all)]
use crate::error::{Error, Result};
use crate::ops::AlgebraicLaw;
use crate::ir::validate::limits::MAX_GRAPH_NODES;
use crate::ops::graph::bfs::Bfs;
use crate::ops::graph::csr::{to_csr, CsrGraph};
use crate::ops::{OpSpec, BYTES_TO_U32_OUTPUTS, U32_INPUTS};
use std::collections::VecDeque;

// WGSL lowering marker for `graph.reachability`.
//
// No special per-op lowering is needed. The reachability metadata reuses the
// BFS IR program.

/// Breadth-first search from a single `source`.
///
/// Every node reached at depth `1..=max_depth` is appended to `reached` as
/// `(source, node, depth)`, in BFS (queue) order.
///
/// Traversal rules:
/// - Each node is enqueued at most once, so `depth` is the shortest hop count
///   from `source`. `source` itself is never reported.
/// - A node whose metadata is a sanitizer (see [`is_sanitizer`]) is reported,
///   but its out-edges are not followed. This includes `source`.
/// - Out-edges of nodes already at `max_depth` are not followed.
///
/// This function does not validate `csr` itself. Malformed pieces are
/// skipped instead of panicking:
/// - an out-of-range `source` produces nothing;
/// - a node with missing offsets, or with an offset pair that does not
///   describe a valid slice of `csr.targets`, is not expanded;
/// - a node without a `node_data` entry is treated as a non-sanitizer.
///
/// Call [`validate_reachability_csr`] first (as [`reachable_nodes`] does) to
/// get an error for malformed graphs.
pub fn bfs_from_source(
    csr: &CsrGraph,
    source: u32,
    max_depth: u32,
    reached: &mut Vec<ReachableNode>,
) {
    let node_count = csr.node_count();
    // `source` comes straight from the caller. Reject it up front rather than
    // panicking on the `visited` index below.
    let Some(source_index) = usize::try_from(source)
        .ok()
        .filter(|&index| index < node_count)
    else {
        return;
    };
    let mut visited = vec![false; node_count];
    let mut queue = VecDeque::new();
    visited[source_index] = true;
    queue.push_back((source, 0u32));

    while let Some((node, depth)) = queue.pop_front() {
        if node != source {
            reached.push((source, node, depth));
        }
        if depth >= max_depth {
            continue;
        }
        let Ok(node_index) = usize::try_from(node) else {
            continue;
        };
        // Sanitizers are reported (above) but stop propagation. Missing
        // metadata is treated as "not a sanitizer" so traversal keeps going.
        if matches!(csr.node_data.get(node_index), Some(&data) if is_sanitizer(data)) {
            continue;
        }
        let Some(next_node_index) = node_index.checked_add(1) else {
            continue;
        };
        let (Some(&raw_start), Some(&raw_end)) = (
            csr.offsets.get(node_index),
            csr.offsets.get(next_node_index),
        ) else {
            continue;
        };
        let (Ok(start), Ok(end)) = (usize::try_from(raw_start), usize::try_from(raw_end)) else {
            continue;
        };
        // `get` yields `None` for `start > end` or `end > targets.len()`,
        // where direct slicing would panic.
        let Some(targets) = csr.targets.get(start..end) else {
            continue;
        };
        for &target in targets {
            let Ok(target_idx) = usize::try_from(target) else {
                continue;
            };
            if target_idx < node_count && !visited[target_idx] {
                visited[target_idx] = true;
                queue.push_back((target, depth.saturating_add(1)));
            }
        }
    }
}

impl ReachabilityOp {
    /// Declarative operation specification for `graph.reachability`.
    ///
    /// Assembled with `OpSpec::composition` from:
    /// - the op id;
    /// - the shared `U32_INPUTS` / `BYTES_TO_U32_OUTPUTS` descriptors;
    /// - this module's `LAWS`;
    /// - `Bfs::program`.
    ///
    /// No reachability-specific IR program is defined here; the op reuses the
    /// BFS IR program.
    pub const SPEC: OpSpec = OpSpec::composition(
        "graph.reachability",
        U32_INPUTS,
        BYTES_TO_U32_OUTPUTS,
        LAWS,
        Bfs::program,
    );
}

/// Returns `true` when the node metadata word carries the sanitizer label.
///
/// The label lives in bits 16..24 of `node_data`; a label value of `4` marks
/// a sanitizer.
pub fn is_sanitizer(node_data: u32) -> bool {
    const SANITIZER_LABEL: u8 = 4;
    // Shifting right by 16 and truncating to `u8` keeps exactly bits 16..24.
    let label = (node_data >> 16) as u8;
    label == SANITIZER_LABEL
}

/// Algebraic laws declared for `graph.reachability`.
///
/// Passed to `OpSpec::composition` by `ReachabilityOp::SPEC`.
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
    // NOTE(review): these bounds span the whole `u32` range. Presumably they
    // bound each output value; confirm against `AlgebraicLaw::Bounded`.
    lo: 0,
    hi: u32::MAX,
}];

/// Multi-source reachability operation metadata.
///
/// Zero-sized marker type. In this file its only item is the associated
/// `SPEC` constant. The CPU-side computation is [`reachable_nodes`].
#[derive(Debug, Default, Clone, Copy)]
pub struct ReachabilityOp;

/// A reached node tuple `(source, node, depth)`.
///
/// As produced by `bfs_from_source`, `depth` is the BFS hop count from
/// `source` to `node`. It is always at least `1`, because the source itself
/// is never reported.
pub type ReachableNode = (u32, u32, u32);

/// Compute multi-source reachability on the CPU.
///
/// The CSR is validated once. Then each entry of `sources` is searched, in
/// order, with [`bfs_from_source`], and all `(source, node, depth)` tuples
/// are appended to a single result vector. Sources outside `0..node_count`
/// are skipped silently. A duplicated source is searched again and so
/// contributes duplicate tuples.
///
/// Sanitizer semantics are encoded by node metadata whose label type is `4`,
/// matching the graph BFS shader.
///
/// # Errors
///
/// Returns an actionable error when the CSR shape is invalid or the node count
/// exceeds host-side allocation limits.
pub fn reachable_nodes(
    csr: &CsrGraph,
    sources: &[u32],
    max_depth: u32,
) -> Result<Vec<ReachableNode>> {
    let node_count = csr.node_count();
    validate_reachability_csr(csr, node_count)?;

    let in_range =
        |source: &u32| usize::try_from(*source).map_or(false, |index| index < node_count);

    let mut out = Vec::new();
    for &source in sources.iter().filter(|source| in_range(source)) {
        bfs_from_source(csr, source, max_depth, &mut out);
    }
    Ok(out)
}

pub fn validate_reachability_csr(csr: &CsrGraph, node_count: usize) -> Result<()> {
    if node_count > MAX_GRAPH_NODES {
        return Err(Error::Csr {
            message: format!(
                "GraphTooLarge: node_count {node_count} exceeds {MAX_GRAPH_NODES}. Fix: split the graph before CPU reachability."
            ),
        });
    }
    let expected_offsets = node_count.checked_add(1).ok_or_else(|| Error::Csr {
        message: "CsrInvalid: node_count + 1 overflows usize. Fix: split the graph before CPU reachability.".to_string(),
    })?;
    if csr.offsets.len() != expected_offsets {
        return Err(Error::Csr {
            message: format!(
                "CsrInvalid: offsets length {} does not equal node_count + 1 ({expected_offsets}). Fix: rebuild CSR offsets before CPU reachability.",
                csr.offsets.len()
            ),
        });
    }
    csr.validate()
}

// Unit tests, extracted from `ops/graph/reachability/kernel.rs`.

#[test]
pub fn finds_reachable_nodes_for_each_source() -> crate::error::Result<()> {
    // Graph 0 -> 1 -> 2 <- 3: source 0 reaches 1 and 2, source 3 reaches only 2.
    let graph = to_csr(4, &[(0, 1), (1, 2), (3, 2)])?;
    let actual = reachable_nodes(&graph, &[0, 3], 8)?;
    let expected: Vec<ReachableNode> = vec![(0, 1, 1), (0, 2, 2), (3, 2, 1)];
    assert_eq!(actual, expected);
    Ok(())
}

#[test]
pub fn respects_max_depth() -> crate::error::Result<()> {
    // Path 0 -> 1 -> 2 with max_depth 1: node 2 sits at depth 2 and must be cut off.
    let path = to_csr(3, &[(0, 1), (1, 2)])?;
    let limited = reachable_nodes(&path, &[0], 1)?;
    let expected: Vec<ReachableNode> = vec![(0, 1, 1)];
    assert_eq!(limited, expected);
    Ok(())
}