// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library
// Documentation
pub use crate::ops::graph::MAX_BFS_QUEUE;
use crate::error::{Error, Result};
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::AlgebraicLaw;
use crate::ops::graph::csr::CsrGraph;
use crate::ops::{OpSpec, BYTES_TO_U32_OUTPUTS, U32_INPUTS};
use std::collections::VecDeque;

// WGSL lowering marker for `graph.bfs`.
//
// No special per-op lowering is needed. The normal IR lowerer handles the
// CSR BFS composition.

/// BFS operation over CSR graph buffers.
///
/// Unit struct with no fields; the operation is exposed through its associated
/// items [`Bfs::SPEC`], [`Bfs::program`] and [`Bfs::program_with_queue_size`].
#[derive(Debug, Default, Clone, Copy)]
pub struct Bfs;

impl Bfs {
    /// Declarative operation specification.
    pub const SPEC: OpSpec =
        OpSpec::composition(OP_ID, U32_INPUTS, BYTES_TO_U32_OUTPUTS, LAWS, Self::program);

    /// Build the canonical IR program with the default queue size.
    ///
    /// Used by the `SPEC` for conformance testing. For GPU dispatch, prefer
    /// [`Self::program_with_queue_size`] with a runtime-detected limit.
    #[must_use]
    pub fn program() -> Program {
        Self::program_with_queue_size(MAX_BFS_QUEUE)
    }

    /// Build the IR program with a specific workgroup queue capacity.
    ///
    /// `max_queue_slots` determines the `var<workgroup>` array size in the
    /// lowered WGSL. It must satisfy `max_queue_slots * 4 <=` the device's
    /// `max_compute_workgroup_storage_size`.
    #[must_use]
    pub fn program_with_queue_size(max_queue_slots: u32) -> Program {
        let source_idx = Expr::var("source_idx");
        let start_node = Expr::var("start_node");
        let node = Expr::var("node");
        let tgt_node = Expr::var("tgt_node");
        let queue_tail = Expr::var("queue_tail");
        let queue_head = Expr::var("queue_head");
        let depth = Expr::var("depth");
        Program::new(
            vec![
                BufferDecl::read("node_labels", 0, DataType::U32),
                BufferDecl::read("edge_offsets", 1, DataType::U32),
                BufferDecl::read("edge_targets", 2, DataType::U32),
                BufferDecl::read("source_nodes", 3, DataType::U32),
                BufferDecl::output("findings", 4, DataType::U32),
                BufferDecl::read_write("finding_count", 5, DataType::U32),
                BufferDecl::read("params", 6, DataType::U32),
                BufferDecl::read_write("visited_set", 7, DataType::U32),
                BufferDecl::workgroup("queue", max_queue_slots, DataType::U32),
            ],
            WORKGROUP_SIZE,
            vec![
                Node::let_bind("source_idx", Expr::gid_x()),
                Node::if_then(
                    valid_source_invocation(source_idx.clone()),
                    vec![
                        Node::let_bind(
                            "start_node",
                            Expr::load("source_nodes", source_idx.clone()),
                        ),
                        Node::if_then(
                            Expr::lt(start_node.clone(), node_count()),
                            vec![
                                Node::let_bind("queue_head", Expr::u32(0)),
                                Node::let_bind("queue_tail", Expr::u32(1)),
                                Node::let_bind("depth", Expr::u32(0)),
                                Node::store("queue", Expr::u32(0), start_node.clone()),
                                mark_visited(source_idx.clone(), start_node.clone()),
                                Node::Loop {
                                    var: "level".to_string(),
                                    from: Expr::u32(0),
                                    to: Expr::load("params", Expr::u32(3)),
                                    body: vec![
                                        Node::let_bind("level_end", queue_tail.clone()),
                                        Node::Loop {
                                            var: "slot".to_string(),
                                            from: queue_head.clone(),
                                            to: Expr::var("level_end"),
                                            body: vec![
                                                Node::let_bind(
                                                    "node",
                                                    Expr::load("queue", Expr::var("slot")),
                                                ),
                                                report_if_sink(
                                                    start_node.clone(),
                                                    node.clone(),
                                                    depth.clone(),
                                                    source_idx.clone(),
                                                ),
                                                scan_neighbors(
                                                    source_idx.clone(),
                                                    node,
                                                    tgt_node,
                                                    queue_tail.clone(),
                                                    max_queue_slots,
                                                ),
                                            ],
                                        },
                                        Node::assign("queue_head", Expr::var("level_end")),
                                        Node::assign("depth", Expr::add(depth, Expr::u32(1))),
                                    ],
                                },
                            ],
                        ),
                    ],
                ),
            ],
        )
        .with_entry_op_id(OP_ID)
    }
}

/// Algebraic laws declared for `graph.bfs` and passed to [`Bfs::SPEC`]: a single
/// `Bounded` law whose range covers every `u32` value (`0..=u32::MAX`).
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
    lo: 0,
    hi: u32::MAX,
}];

/// Build the IR store that sets `node`'s bit in the `visited_set` bitmap.
///
/// The word index is `source_idx * words_per_source() + (node >> 5)` and the
/// bit mask is `1 << (node & 31)`. The emitted node is a plain load of that
/// word followed by a store of the word OR-ed with the mask.
pub fn mark_visited(source_idx: Expr, node: Expr) -> Node {
    let mask = Expr::shl(Expr::u32(1), Expr::bitand(node.clone(), Expr::u32(31)));
    let word_in_source = Expr::shr(node, Expr::u32(5));
    let slot = Expr::add(Expr::mul(source_idx, words_per_source()), word_in_source);
    let merged = Expr::bitor(Expr::load("visited_set", slot.clone()), mask);
    Node::store("visited_set", slot, merged)
}

/// IR expression for the node count, read from `params[1]`.
///
/// Used as the exclusive upper bound for start nodes and edge targets.
pub fn node_count() -> Expr {
    let node_count_slot = Expr::u32(1);
    Expr::load("params", node_count_slot)
}

/// Build the IR predicate that is true when `node`'s bit is clear in the
/// `visited_set` bitmap.
///
/// Uses the same addressing as [`mark_visited`]: word
/// `source_idx * words_per_source() + (node >> 5)`, mask `1 << (node & 31)`.
pub fn not_visited(source_idx: Expr, node: Expr) -> Expr {
    let mask = Expr::shl(Expr::u32(1), Expr::bitand(node.clone(), Expr::u32(31)));
    let word_in_source = Expr::shr(node, Expr::u32(5));
    let slot = Expr::add(Expr::mul(source_idx, words_per_source()), word_in_source);
    let masked = Expr::bitand(Expr::load("visited_set", slot), mask);
    Expr::eq(masked, Expr::u32(0))
}

/// Operation identifier for BFS; used by [`Bfs::SPEC`] and set as the entry op
/// id of the program built by [`Bfs::program_with_queue_size`].
pub const OP_ID: &str = "graph.bfs";

/// Build the IR that records a finding when `node` carries a reportable label.
///
/// The label is bits 16..24 of `node_labels[node]`. A finding is emitted when
/// the label equals 2 or 3 and `node` differs from `start_node`
/// (NOTE(review): the meaning of label values 2 and 3 is not visible here —
/// presumably sink kinds; confirm against the label encoding).
///
/// A slot is claimed with an atomic add on `finding_count[0]`. When the
/// claimed index is below `params[2]`, four consecutive words are written at
/// `findings[finding_idx * 4 ..]`: `start_node`, `node`, `depth`, `source_idx`.
pub fn report_if_sink(start_node: Expr, node: Expr, depth: Expr, source_idx: Expr) -> Node {
    // Word offsets inside the four-word finding record.
    let record_base = || Expr::mul(Expr::var("finding_idx"), Expr::u32(4));
    let record_field = |offset: u32| Expr::add(record_base(), Expr::u32(offset));

    let label_byte = Expr::bitand(
        Expr::shr(Expr::load("node_labels", node.clone()), Expr::u32(16)),
        Expr::u32(0xff),
    );
    let label_matches = Expr::BinOp {
        op: crate::ir::BinOp::Or,
        left: Box::new(Expr::eq(label_byte.clone(), Expr::u32(2))),
        right: Box::new(Expr::eq(label_byte, Expr::u32(3))),
    };
    let differs_from_start = Expr::BinOp {
        op: crate::ir::BinOp::Ne,
        left: Box::new(node.clone()),
        right: Box::new(start_node.clone()),
    };
    let should_report = Expr::BinOp {
        op: crate::ir::BinOp::And,
        left: Box::new(label_matches),
        right: Box::new(differs_from_start),
    };

    let write_record = vec![
        Node::store("findings", record_base(), start_node),
        Node::store("findings", record_field(1), node),
        Node::store("findings", record_field(2), depth),
        Node::store("findings", record_field(3), source_idx),
    ];
    let has_room = Expr::lt(Expr::var("finding_idx"), Expr::load("params", Expr::u32(2)));

    Node::if_then(
        should_report,
        vec![
            // The counter is bumped even when the record does not fit.
            Node::let_bind(
                "finding_idx",
                Expr::atomic_add("finding_count", Expr::u32(0), Expr::u32(1)),
            ),
            Node::if_then(has_room, write_record),
        ],
    )
}

/// Build the IR loop that walks `node`'s outgoing edges and enqueues each
/// neighbour that has not been visited yet.
///
/// The loop variable `edge` runs from `edge_offsets[node]` to
/// `edge_offsets[node + 1]`. Each target below the node count whose visited
/// bit is clear is marked visited and then appended at `queue[queue_tail]`.
/// If `queue_tail` has reached `max_queue_slots`, the program emits
/// `Node::Return` instead; the target has already been marked visited by then.
pub fn scan_neighbors(
    source_idx: Expr,
    node: Expr,
    tgt_node: Expr,
    queue_tail: Expr,
    max_queue_slots: u32,
) -> Node {
    let first_edge = Expr::load("edge_offsets", node.clone());
    let past_last_edge = Expr::load("edge_offsets", Expr::add(node, Expr::u32(1)));

    // Append to the queue while there is room, otherwise bail out.
    let has_room = Expr::lt(queue_tail.clone(), Expr::u32(max_queue_slots));
    let enqueue = vec![
        Node::store("queue", queue_tail.clone(), tgt_node.clone()),
        Node::assign("queue_tail", Expr::add(queue_tail, Expr::u32(1))),
    ];
    let enqueue_or_return = Node::if_then_else(has_room, enqueue, vec![Node::Return]);

    let target_in_range = Expr::lt(tgt_node.clone(), node_count());
    let target_unseen = not_visited(source_idx.clone(), tgt_node.clone());
    let claim_target = mark_visited(source_idx, tgt_node);

    Node::Loop {
        var: String::from("edge"),
        from: first_edge,
        to: past_last_edge,
        body: vec![
            Node::let_bind("tgt_node", Expr::load("edge_targets", Expr::var("edge"))),
            Node::if_then(
                target_in_range,
                vec![Node::if_then(
                    target_unseen,
                    vec![claim_target, enqueue_or_return],
                )],
            ),
        ],
    }
}

/// Build the IR predicate that decides whether an invocation should run BFS.
///
/// True when `source_idx < params[0]`, `words_per_source()` is non-zero, and
/// `source_idx <= u32::MAX / words_per_source()`. The last condition keeps the
/// product `source_idx * words_per_source()`, used to address `visited_set`,
/// within `u32`.
pub fn valid_source_invocation(source_idx: Expr) -> Expr {
    let within_source_count = Expr::lt(source_idx.clone(), Expr::load("params", Expr::u32(0)));
    let stride_is_nonzero = Expr::ne(words_per_source(), Expr::u32(0));
    let largest_safe_index = Expr::div(Expr::u32(u32::MAX), words_per_source());
    let base_fits_u32 = Expr::le(source_idx, largest_safe_index);
    Expr::and(
        within_source_count,
        Expr::and(stride_is_nonzero, base_fits_u32),
    )
}

pub fn validate_edge_offsets(csr: &CsrGraph) -> Result<()> {
    let expected_offsets = csr.node_count().checked_add(1).ok_or_else(|| Error::Csr {
        message:
            "CsrInvalid: node_count + 1 overflows usize. Fix: split the graph before BFS dispatch."
                .to_string(),
    })?;
    if csr.offsets.len() != expected_offsets {
        return Err(Error::Csr {
            message: format!(
                "CsrInvalid: edge_offsets length {} does not equal num_nodes + 1 ({expected_offsets}). Fix: rebuild CSR offsets before BFS dispatch.",
                csr.offsets.len()
            ),
        });
    }
    csr.validate()
}

/// Validate CPU graph inputs before dispatching the BFS shader.
///
/// # Errors
///
/// Returns an actionable error when CSR buffers are malformed, node counts are
/// oversized, or a source can produce a frontier wider than the shader queue.
pub fn validate_frontier_capacity(csr: &CsrGraph, sources: &[u32], max_depth: u32) -> Result<()> {
    validate_edge_offsets(csr)?;
    let max_queue = usize::try_from(MAX_BFS_QUEUE).map_err(|error| Error::Csr {
        message: format!(
            "Overflow: context MAX_BFS_QUEUE value {MAX_BFS_QUEUE} cannot fit usize: {error}. Fix: lower MAX_BFS_QUEUE for this target platform."
        ),
    })?;
    for &source in sources {
        let Ok(source_index) = usize::try_from(source) else {
            continue;
        };
        if source_index >= csr.node_count() {
            continue;
        }
        validate_source_frontier(csr, source, max_depth, max_queue)?;
    }
    Ok(())
}

pub fn validate_source_frontier(
    csr: &CsrGraph,
    source: u32,
    max_depth: u32,
    max_queue: usize,
) -> Result<()> {
    let node_count = csr.node_count();
    let mut visited = vec![false; node_count];
    let source_index = usize::try_from(source).map_err(|error| Error::Csr {
        message: format!(
            "InvalidEdge: source {source} cannot fit usize: {error}. Fix: use source node ids representable on this platform."
        ),
    })?;
    let mut queue = VecDeque::from([(source, 0u32)]);
    visited[source_index] = true;

    while let Some((node, depth)) = queue.pop_front() {
        if depth >= max_depth {
            continue;
        }
        let node_index = usize::try_from(node).map_err(|error| Error::Csr {
            message: format!(
                "CsrInvalid: queued node {node} cannot fit usize: {error}. Fix: rebuild CSR with platform-sized node ids."
            ),
        })?;
        let next_node_index = node_index.checked_add(1).ok_or_else(|| Error::Csr {
            message:
                "CsrInvalid: node_index + 1 overflows usize. Fix: rebuild CSR with fewer nodes."
                    .to_string(),
        })?;
        let start = usize::try_from(csr.offsets[node_index]).map_err(|error| Error::Csr {
            message: format!(
                "CsrInvalid: edge offset cannot fit usize: {error}. Fix: rebuild CSR with fewer edges."
            ),
        })?;
        let end = usize::try_from(csr.offsets[next_node_index]).map_err(|error| Error::Csr {
            message: format!(
                "CsrInvalid: edge offset cannot fit usize: {error}. Fix: rebuild CSR with fewer edges."
            ),
        })?;
        for &target in &csr.targets[start..end] {
            let target_index = usize::try_from(target).map_err(|error| Error::Csr {
                message: format!(
                    "CsrInvalid: target {target} cannot fit usize: {error}. Fix: rebuild CSR with platform-sized node ids."
                ),
            })?;
            if !visited[target_index] {
                visited[target_index] = true;
                queue.push_back((target, depth.saturating_add(1)));
                if queue.len() > max_queue {
                    return Err(Error::Csr {
                        message: format!(
                            "BfsFrontierTooLarge: frontier {} exceeds {MAX_BFS_QUEUE}. Fix: Split the graph into smaller subgraphs or implement multi-pass BFS",
                            queue.len()
                        ),
                    });
                }
            }
        }
    }
    Ok(())
}

/// IR expression for the per-source stride of the `visited_set` bitmap, in
/// `u32` words, read from `params[4]`.
pub fn words_per_source() -> Expr {
    let stride_slot = Expr::u32(4);
    Expr::load("params", stride_slot)
}

/// Workgroup size for multi-source BFS.
///
/// Passed to `Program::new` by [`Bfs::program_with_queue_size`]: 64 in the
/// first dimension and 1 in the other two. The program uses `gid_x` as the
/// source index.
pub const WORKGROUP_SIZE: [u32; 3] = [64, 1, 1];