vyre-wgpu 0.1.0

wgpu backend for vyre IR — implements VyreBackend, owns GPU runtime, buffer pool, pipeline cache
Documentation
use crate::engine::dataflow::{
    checked_u32, create_bind_group, create_buffer, create_pipeline, dispatch_and_copy,
    read_findings, validate_inputs, BfsBindGroupInputs, GpuFinding, GpuParams, MAX_FINDINGS,
    WORKGROUP_SIZE,
};
use vyre::error::{Error, Result};
use vyre::ops::graph::csr::CsrGraph;
use crate::runtime::cache::BufferPool;

/// Run multi-source BFS reachability on the provided runtime device.
///
/// Returned tuples are `(source, sink, depth)`. The WGSL kernel reports nodes
/// whose packed `node_data` label type is sink (`2`) or both (`3`), preserving
/// the taint-flow contract migrated from pyrograph.
///
/// # Errors
///
/// Returns an error if inputs are invalid, if the GPU device is lost, or if
/// the requested allocation exceeds `MAX_FINDINGS`.
///
/// # Examples
///
/// ```
/// use vyre_wgpu::engine::dataflow::bfs_reachability;
///
/// // Requires a live wgpu device and queue.
/// // let results = bfs_reachability(&device, &queue, &csr, &[0], 8);
/// ```
#[expect(
    clippy::too_many_lines,
    reason = "this function owns one GPU BFS dispatch and keeps resource sizing adjacent to validation"
)]
pub fn bfs_reachability(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    csr: &CsrGraph,
    sources: &[u32],
    max_depth: u32,
    command_encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<Vec<(u32, u32, u32)>> {
    validate_inputs(csr, sources)?;
    if sources.is_empty() || csr.node_count() == 0 {
        return Ok(Vec::new());
    }

    let node_count = checked_u32(csr.node_count(), "node_count")?;
    let max_findings = checked_u32(
        sources
            .len()
            .checked_mul(csr.node_count())
            .ok_or_else(|| Error::Dataflow {
                message: "source_count * node_count overflows usize. Fix: reduce source count or graph size before GPU dispatch.".to_string(),
            })?,
        "max_findings",
    )?;
    let max_findings_usize = usize::try_from(max_findings).map_err(|source| Error::Dataflow {
        message: format!("max_findings {max_findings} cannot fit usize: {source}. Fix: reject this dispatch on this platform."),
    })?;
    if max_findings_usize > MAX_FINDINGS {
        return Err(Error::Dataflow {
            message: format!(
                "max_findings {max_findings_usize} exceeds limit {MAX_FINDINGS}. Fix: split sources or graph partitions before GPU dispatch."
            ),
        });
    }
    if max_findings == 0 {
        return Ok(Vec::new());
    }

    let params = GpuParams {
        num_sources: checked_u32(sources.len(), "source_count")?,
        num_nodes: node_count,
        max_findings,
        max_depth,
        words_per_source: node_count.div_ceil(32),
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
    };

    let queue_slots = crate::engine::bfs_queue_slots_for_device::bfs_queue_slots_for_device(device, None);
    let pipeline = create_pipeline(device, queue_slots)?;
    let node_buffer = create_buffer(
        device,
        queue,
        "graph node data",
        &csr.node_data,
        wgpu::BufferUsages::STORAGE,
    )?;
    let offset_buffer = create_buffer(
        device,
        queue,
        "graph offsets",
        &csr.offsets,
        wgpu::BufferUsages::STORAGE,
    )?;
    let target_buffer = create_buffer(
        device,
        queue,
        "graph targets",
        &csr.targets,
        wgpu::BufferUsages::STORAGE,
    )?;
    let source_buffer = create_buffer(
        device,
        queue,
        "graph sources",
        sources,
        wgpu::BufferUsages::STORAGE,
    )?;
    let params_buffer = create_buffer(
        device,
        queue,
        "graph params",
        &[params],
        wgpu::BufferUsages::STORAGE,
    )?;

    let finding_size =
        u64::try_from(std::mem::size_of::<GpuFinding>()).map_err(|source| Error::Dataflow {
            message: format!(
                "GpuFinding size cannot fit u64: {source}. Fix: run on a supported target."
            ),
        })?;
    let findings_size =
        u64::from(max_findings)
            .checked_mul(finding_size)
            .ok_or_else(|| Error::Dataflow {
                message: "findings buffer size overflow. Fix: split sources or graph partitions before GPU dispatch.".to_string(),
            })?;
    let pool = BufferPool::global();
    let findings_buffer = pool.acquire(
        device,
        "graph findings",
        findings_size,
        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
    )?;
    let finding_count_buffer = create_buffer(
        device,
        queue,
        "graph finding count",
        &[0u32],
        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
    )?;

    let visited_size = u64::from(params.num_sources) * u64::from(params.words_per_source) * 4;
    let visited_bytes = usize::try_from(visited_size.max(16)).map_err(|source| Error::Dataflow {
        message: format!("visited buffer size cannot fit usize: {source}. Fix: reject this dispatch on this platform."),
    })?;
    let visited_buffer = pool.acquire(
        device,
        "graph visited",
        visited_size.max(16),
        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
    )?;
    queue.write_buffer(&visited_buffer, 0, &vec![0u8; visited_bytes]);

    let bind_group_inputs = BfsBindGroupInputs {
        node_buffer: &node_buffer,
        offset_buffer: &offset_buffer,
        target_buffer: &target_buffer,
        source_buffer: &source_buffer,
        findings_buffer: &findings_buffer,
        finding_count_buffer: &finding_count_buffer,
        params_buffer: &params_buffer,
        visited_buffer: &visited_buffer,
    };
    let bind_group = create_bind_group(
        device,
        &pipeline,
        &bind_group_inputs,
    );

    let (readback_buffer, readback_count_buffer, submission) = dispatch_and_copy(
        device,
        queue,
        &pipeline,
        &bind_group,
        &findings_buffer,
        &finding_count_buffer,
        findings_size,
        params.num_sources.div_ceil(WORKGROUP_SIZE),
        command_encoder,
    )?;

    let findings = read_findings(
        device,
        &readback_buffer,
        &readback_count_buffer,
        max_findings,
        submission,
    )?;
    pool.release(node_buffer);
    pool.release(offset_buffer);
    pool.release(target_buffer);
    pool.release(source_buffer);
    pool.release(params_buffer);
    pool.release(findings_buffer);
    pool.release(finding_count_buffer);
    pool.release(visited_buffer);
    pool.release(readback_buffer);
    pool.release(readback_count_buffer);
    Ok(findings)
}