vyre-wgpu 0.1.0

wgpu backend for vyre IR — implements VyreBackend, owns GPU runtime, buffer pool, pipeline cache
Documentation
//! Generic GPU graph dataflow engines.
//!
//! Host-side workflow dispatcher that owns GPU resources (buffers, pipelines,
//! readback), runs Programs from `vyre::ops::graph`, and returns typed host
//! results.

mod bfs;

use self::bfs::bfs_reachability::bfs_reachability;
use crate::bytemuck_safe::{safe_bytes_of_slice, safe_cast_slice};
use crate::runtime::cache::{BufferPool, PooledBuffer};
use bytemuck::{Pod, Zeroable};
use vyre::error::{Error, Result};
use vyre::ops::graph::bfs::Bfs;
use vyre::ops::graph::csr::CsrGraph;

/// Workgroup size (64) exported for the GPU graph dataflow code.
///
/// NOTE(review): presumably the number of invocations per compute workgroup
/// assumed when deriving dispatch counts; no use is visible in this module —
/// confirm against the BFS dispatch code before changing it.
pub const WORKGROUP_SIZE: u32 = 64;

/// Maximum BFS findings allocated by one dispatch (one million rows).
///
/// NOTE(review): presumably the upper bound used to size the findings buffer;
/// the sizing code is not visible in this module — confirm against the caller.
pub const MAX_FINDINGS: usize = 1_000_000;

/// One finding row as it appears in the findings readback buffer.
///
/// `#[repr(C)]` together with `Pod` fixes the layout at four consecutive `u32`
/// words (16 bytes), which is what lets `read_finding_rows` reinterpret the
/// mapped readback bytes as `&[GpuFinding]`.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub(crate) struct GpuFinding {
    /// Copied into the first element of each tuple returned by `read_finding_rows`.
    start_node: u32,
    /// Copied into the second element of each tuple returned by `read_finding_rows`.
    sink_node: u32,
    /// Copied into the third element of each tuple returned by `read_finding_rows`.
    depth: u32,
    /// Not surfaced by `read_finding_rows`.
    /// NOTE(review): presumably the index of the originating entry in the
    /// source list — confirm against the kernel.
    source_idx: u32,
}

/// Dispatch parameter block with a fixed eight-word (32-byte) `repr(C)` layout.
///
/// NOTE(review): presumably uploaded as the contents of the params buffer bound
/// by the BFS pipeline; the code that fills it is not visible in this module.
/// Field order is part of the layout and must not be changed independently of
/// whatever consumes these bytes.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub(crate) struct GpuParams {
    /// Presumably the number of BFS source nodes — TODO confirm.
    num_sources: u32,
    /// Presumably the node count of the CSR graph — TODO confirm.
    num_nodes: u32,
    /// Presumably the capacity of the findings buffer in rows — TODO confirm.
    max_findings: u32,
    /// Presumably the traversal depth limit — TODO confirm.
    max_depth: u32,
    /// Presumably the number of `u32` words per source in the visited bitset — TODO confirm.
    words_per_source: u32,
    /// Padding word; together with `_pad1` and `_pad2` it rounds the block up to eight words.
    _pad0: u32,
    /// Padding word.
    _pad1: u32,
    /// Padding word.
    _pad2: u32,
}

pub(crate) fn validate_inputs(csr: &CsrGraph, sources: &[u32]) -> Result<()> {
    csr.validate()?;
    for &source in sources {
        let source_index = usize::try_from(source).map_err(|err| Error::Dataflow {
            message: format!(
                "source node {source} cannot fit usize: {err}. Fix: reject this source on this platform."
            ),
        })?;
        if source_index >= csr.node_count() {
            return Err(Error::Dataflow {
                message: format!(
                    "source node {source} is outside node_count {}. Fix: remove invalid sources before GPU dispatch.",
                    csr.node_count()
                ),
            });
        }
    }
    Ok(())
}

pub(crate) fn checked_u32(value: usize, label: &str) -> Result<u32> {
    u32::try_from(value).map_err(|source| Error::Dataflow {
        message: format!("{label} value {value} exceeds u32::MAX: {source}. Fix: split the graph input before GPU dispatch."),
    })
}

/// `read_finding_count` function.
pub fn read_finding_count(
    device: &wgpu::Device,
    readback_count_buffer: &wgpu::Buffer,
    max_findings: u32,
    submission: wgpu::SubmissionIndex,
) -> Result<u32> {
    let slice = readback_count_buffer.slice(0..4);
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = tx.send(result);
    });
    match device.poll(wgpu::Maintain::wait_for(submission)) {
        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
    }
    rx.recv()
        .map_err(|error| Error::Dataflow {
            message: format!("GPU finding count map channel closed: {error}. Fix: keep the readback receiver alive until map_async completes."),
        })?
        .map_err(|error| Error::Dataflow {
            message: format!("GPU finding count map failed: {error:?}. Fix: check for device loss, adapter timeout, or invalid readback buffer usage."),
        })?;
    let mapped = slice.get_mapped_range();
    let count = u32::from_ne_bytes(mapped[..4].try_into().map_err(|source| {
        Error::Dataflow {
            message: format!("GPU finding count readback was not four bytes: {source}. Fix: inspect the count readback buffer size."),
        }
    })?)
    .min(max_findings);
    drop(mapped);
    readback_count_buffer.unmap();
    Ok(count)
}

/// `read_finding_rows` function.
pub fn read_finding_rows(
    device: &wgpu::Device,
    readback_buffer: &wgpu::Buffer,
    count: u32,
    submission: wgpu::SubmissionIndex,
) -> Result<Vec<(u32, u32, u32)>> {
    let finding_byte_len = u64::from(count)
        * u64::try_from(std::mem::size_of::<GpuFinding>()).map_err(|source| Error::Dataflow {
            message: format!(
                "GpuFinding size cannot fit u64: {source}. Fix: run on a supported target."
            ),
        })?;
    let slice = readback_buffer.slice(0..finding_byte_len);
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = tx.send(result);
    });
    match device.poll(wgpu::Maintain::wait_for(submission)) {
        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
    }
    rx.recv()
        .map_err(|error| Error::Dataflow {
            message: format!("GPU findings map channel closed: {error}. Fix: keep the readback receiver alive until map_async completes."),
        })?
        .map_err(|error| Error::Dataflow {
            message: format!("GPU findings map failed: {error:?}. Fix: check for device loss, adapter timeout, or invalid readback buffer usage."),
        })?;
    let mapped = slice.get_mapped_range();
    let findings: &[GpuFinding] = safe_cast_slice(&mapped).map_err(|error| Error::Dataflow {
        message: format!(
            "safe cast failed: {error}. Fix: ensure the readback buffer size matches GpuFinding layout."
        ),
    })?;
    let results = findings
        .iter()
        .take(usize::try_from(count).map_err(|source| Error::Dataflow {
            message: format!("finding count {count} cannot fit usize: {source}. Fix: reject this readback on this platform."),
        })?)
        .map(|finding| (finding.start_node, finding.sink_node, finding.depth))
        .collect();
    drop(mapped);
    readback_buffer.unmap();
    Ok(results)
}

/// Reads the finding count and, when it is non-zero, the finding rows.
///
/// The count is read first via `read_finding_count` (clamped to
/// `max_findings`); the rows are then read via `read_finding_rows` using the
/// same `submission`.
///
/// # Errors
///
/// Propagates any error from `read_finding_count` or `read_finding_rows`.
pub(crate) fn read_findings(
    device: &wgpu::Device,
    readback_buffer: &wgpu::Buffer,
    readback_count_buffer: &wgpu::Buffer,
    max_findings: u32,
    submission: wgpu::SubmissionIndex,
) -> Result<Vec<(u32, u32, u32)>> {
    let count_submission = submission.clone();
    let found = read_finding_count(
        device,
        readback_count_buffer,
        max_findings,
        count_submission,
    )?;
    match found {
        // Nothing was recorded, so the row buffer is never mapped.
        0 => Ok(Vec::new()),
        rows => read_finding_rows(device, readback_buffer, rows, submission),
    }
}

pub(crate) fn create_pipeline(
    device: &wgpu::Device,
    queue_slots: u32,
) -> Result<wgpu::ComputePipeline> {
    vyre::ops::registry::gate::verify_certificate(Bfs::SPEC.id()).map_err(|source| {
        Error::Dataflow {
            message: source.to_string(),
        }
    })?;
    let source = vyre::lower::wgsl::lower(&Bfs::program_with_queue_size(queue_slots)).map_err(
        |source| Error::Dataflow {
            message: format!("failed to lower graph BFS IR to WGSL: {source}. Fix: validate the canonical BFS Program and lowerer before dispatch."),
        },
    )?;
    crate::runtime::compile_compute_pipeline(device, "graph bfs pipeline", &source, "main")
        .map_err(|source| Error::Dataflow {
            message: format!(
                "failed to compile graph BFS pipeline: {source}. Fix: repair the BFS WGSL or GPU runtime configuration."
            ),
        })
}

pub(crate) fn create_buffer<T: Pod>(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    label: &'static str,
    data: &[T],
    usage: wgpu::BufferUsages,
) -> Result<PooledBuffer> {
    let contents = safe_bytes_of_slice(data);
    let effective = if contents.is_empty() {
        &[0u8; 16][..]
    } else {
        contents
    };
    let size = u64::try_from(effective.len()).map_err(|source| Error::Dataflow {
        message: format!(
            "buffer `{label}` has {} bytes that cannot fit u64: {source}. Fix: split the graph input before dispatch.",
            effective.len()
        ),
    })?;
    let buffer =
        BufferPool::global().acquire(device, label, size, usage | wgpu::BufferUsages::COPY_DST)?;
    queue.write_buffer(&buffer, 0, effective);
    Ok(buffer)
}

/// Buffers bound by the graph BFS compute pipeline.
///
/// Borrowed handles consumed by `create_bind_group`, which binds every field
/// as an entire-buffer binding in bind group 0 at the slot noted on the field.
pub(crate) struct BfsBindGroupInputs<'a> {
    /// CSR node metadata buffer (binding 0).
    pub(crate) node_buffer: &'a wgpu::Buffer,
    /// CSR row-offset buffer (binding 1).
    pub(crate) offset_buffer: &'a wgpu::Buffer,
    /// CSR adjacency target buffer (binding 2).
    pub(crate) target_buffer: &'a wgpu::Buffer,
    /// Starting source-node buffer (binding 3).
    pub(crate) source_buffer: &'a wgpu::Buffer,
    /// Storage buffer receiving discovered `(start, sink, depth)` rows (binding 4).
    pub(crate) findings_buffer: &'a wgpu::Buffer,
    /// Single-word storage buffer receiving the findings count (binding 5).
    pub(crate) finding_count_buffer: &'a wgpu::Buffer,
    /// Dispatch parameter block (binding 6).
    pub(crate) params_buffer: &'a wgpu::Buffer,
    /// Per-source visited bitset buffer (binding 7).
    pub(crate) visited_buffer: &'a wgpu::Buffer,
}

/// Creates the bind group for the graph BFS pipeline.
///
/// Uses the layout of bind group 0 taken from `pipeline` and binds each buffer
/// in `inputs` as an entire-buffer binding, in slots 0 through 7.
pub(crate) fn create_bind_group(
    device: &wgpu::Device,
    pipeline: &wgpu::ComputePipeline,
    inputs: &BfsBindGroupInputs<'_>,
) -> wgpu::BindGroup {
    // Slot assignment table: (binding index, buffer bound at that index).
    let slots: [(u32, &wgpu::Buffer); 8] = [
        (0, inputs.node_buffer),
        (1, inputs.offset_buffer),
        (2, inputs.target_buffer),
        (3, inputs.source_buffer),
        (4, inputs.findings_buffer),
        (5, inputs.finding_count_buffer),
        (6, inputs.params_buffer),
        (7, inputs.visited_buffer),
    ];
    let entries = slots.map(|(binding, buffer)| wgpu::BindGroupEntry {
        binding,
        resource: buffer.as_entire_binding(),
    });
    let group_layout = pipeline.get_bind_group_layout(0);
    let descriptor = wgpu::BindGroupDescriptor {
        label: Some("graph bfs bind group"),
        layout: &group_layout,
        entries: &entries,
    };
    device.create_bind_group(&descriptor)
}

/// Records the BFS compute pass plus readback copies and submits them.
///
/// Acquires two pooled readback buffers (`findings_size` bytes for the rows,
/// four bytes for the count), dispatches `workgroup_count` workgroups along X
/// with `pipeline` and `bind_group`, copies `findings_buffer` and
/// `finding_count_buffer` into the readback buffers, and submits the commands
/// on `queue`.
///
/// Returns the findings readback buffer, the count readback buffer, and the
/// submission index to wait on before mapping them.
///
/// # Errors
///
/// Returns `Error::Dataflow` when `encoder` is `Some`: only immediate
/// submission is supported. The caller's encoder is rejected before any
/// buffer is acquired or any command is recorded, so it is left untouched.
/// Also propagates any error from acquiring the pooled readback buffers.
#[expect(
    clippy::too_many_arguments,
    reason = "GPU dispatch helpers pass distinct wgpu handles whose grouping would hide binding roles"
)]
pub(crate) fn dispatch_and_copy(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    pipeline: &wgpu::ComputePipeline,
    bind_group: &wgpu::BindGroup,
    findings_buffer: &wgpu::Buffer,
    finding_count_buffer: &wgpu::Buffer,
    findings_size: u64,
    workgroup_count: u32,
    encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<(PooledBuffer, PooledBuffer, wgpu::SubmissionIndex)> {
    // Fail fast: recording into a caller-owned encoder and then returning an
    // error would leave that encoder holding copies into readback buffers that
    // have already been handed back to the pool.
    if encoder.is_some() {
        return Err(Error::Dataflow {
            message: "dispatch_and_copy was called with an external encoder. Submit the caller-owned command encoder before readback. Fix: call with `None` for immediate submit behavior.".to_string(),
        });
    }
    let pool = BufferPool::global();
    let readback_buffer = pool.acquire(
        device,
        "graph findings readback",
        findings_size,
        wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
    )?;
    let readback_count_buffer = pool.acquire(
        device,
        "graph finding count readback",
        4,
        wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
    )?;
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("graph bfs encoder"),
    });
    {
        // The pass borrows the encoder mutably; the scope ends it before the copies.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("graph bfs pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(pipeline);
        pass.set_bind_group(0, bind_group, &[]);
        pass.dispatch_workgroups(workgroup_count, 1, 1);
    }
    encoder.copy_buffer_to_buffer(findings_buffer, 0, &readback_buffer, 0, findings_size);
    encoder.copy_buffer_to_buffer(finding_count_buffer, 0, &readback_count_buffer, 0, 4);
    let submission = queue.submit(Some(encoder.finish()));
    Ok((readback_buffer, readback_count_buffer, submission))
}

/// Run multi-source BFS reachability using vyre's cached runtime device.
///
/// Obtains the device/queue pair from `crate::runtime::cached_device()` and
/// forwards `csr`, `sources`, and `max_depth` to `bfs_reachability`, passing
/// `None` as its final argument.
///
/// # Errors
///
/// Propagates the error from `cached_device()` when the cached GPU device is
/// unavailable (NOTE(review): documented as `Error::Gpu`; the variant is not
/// visible in this module — confirm), and any error returned by
/// `bfs_reachability`.
pub fn bfs_reachability_cached(
    csr: &CsrGraph,
    sources: &[u32],
    max_depth: u32,
) -> Result<Vec<(u32, u32, u32)>> {
    let (device, queue) = crate::runtime::cached_device()?;
    bfs_reachability(device, queue, csr, sources, max_depth, None)
}

// The main bfs_reachability function lives in its own file due to size.