vyre-wgpu 0.1.0

wgpu backend for the vyre IR: implements `VyreBackend` and owns the GPU runtime, buffer pool, and pipeline cache.
Documentation
//! One-shot decompression kernel dispatch shared across formats.
//!
//! Every decompressor (zstd, lz4, zlib, gzip, deflate) ends up in
//! this module to upload bindings, submit a single compute dispatch,
//! and read back the output bytes plus the status words. Format-specific
//! modules supply the uniform layout, workgroup count, and debug
//! labels via [`DecompressionKernelInputs`] and
//! [`DecompressionKernelLabels`] so this code path stays agnostic.

use crate::engine::decompress::{
    align_to_copy, binding, bytes_for_u32s, decode_u32s, unpack_words_to_bytes,
};
use crate::runtime::cache::{BufferPool, PooledBuffer};
use std::sync::mpsc;
use vyre::error::{Error, Result};
use vyre::Program;

/// Buffer and dispatch inputs shared by decompression kernels.
///
/// Holds borrowed slices only, so the value is cheap to copy. `Debug` reports
/// slice lengths rather than slice contents.
#[derive(Clone, Copy)]
pub struct DecompressionKernelInputs<'a> {
    /// Compressed input bytes uploaded to binding 0.
    pub data: &'a [u8],
    /// Uniform bytes uploaded to binding 2.
    pub uniform_bytes: &'a [u8],
    /// Decompressed output buffer length in u32 words (binding 1).
    pub output_words_len: usize,
    /// Status buffer length in u32 words (binding 3).
    pub status_words_len: usize,
    /// Exact decompressed byte length returned to the caller.
    pub total_output_size: u32,
    /// Compute workgroups dispatched in the X dimension.
    pub block_count: u32,
}

impl std::fmt::Debug for DecompressionKernelInputs<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Print lengths instead of raw bytes: compressed payloads can be large
        // and dumping them makes log/debug output unusable.
        f.debug_struct("DecompressionKernelInputs")
            .field("data_len", &self.data.len())
            .field("uniform_bytes_len", &self.uniform_bytes.len())
            .field("output_words_len", &self.output_words_len)
            .field("status_words_len", &self.status_words_len)
            .field("total_output_size", &self.total_output_size)
            .field("block_count", &self.block_count)
            .finish()
    }
}

/// Debug labels and error context for one decompression kernel dispatch.
///
/// Every field is a `'static` string, so the whole set is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DecompressionKernelLabels {
    /// Human-readable format name (e.g. "zstd") interpolated into error messages.
    pub format: &'static str,
    /// Label for the compressed-input storage buffer (binding 0).
    pub compressed_buffer: &'static str,
    /// Label for the decompressed-output storage buffer (binding 1).
    pub output_buffer: &'static str,
    /// Label for the status storage buffer (binding 3).
    pub status_buffer: &'static str,
    /// Label for the uniforms storage buffer (binding 2).
    pub uniforms_buffer: &'static str,
    /// Shader module label.
    ///
    /// NOTE(review): not currently forwarded by `dispatch_decompression_kernel`
    /// to pipeline compilation.
    pub shader: &'static str,
    /// Label for the bind group layout.
    pub bind_group_layout: &'static str,
    /// Label for the pipeline layout.
    pub pipeline_layout: &'static str,
    /// Label for the compute pipeline.
    pub pipeline: &'static str,
    /// Label for the bind group.
    pub bind_group: &'static str,
    /// Label for the host-mappable output readback buffer.
    pub readback_output: &'static str,
    /// Label for the host-mappable status readback buffer.
    pub readback_status: &'static str,
    /// Label for the command encoder created for immediate dispatch.
    pub command_encoder: &'static str,
    /// Label for the compute pass.
    pub pass: &'static str,
}

/// Output returned by a decompression kernel dispatch.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct DispatchOutput {
    /// Decompressed bytes, trimmed to the requested `total_output_size`.
    pub output: Vec<u8>,
    /// Raw status words read back from the status buffer; their meaning is
    /// defined by the format-specific kernel.
    pub statuses: Vec<u32>,
}

/// Compile, bind, and dispatch a decompression kernel on the GPU.
///
/// # Errors
///
/// Returns `Error::Decompress` if the IR program cannot be lowered to WGSL or
/// the readback data is malformed. Returns `Error::Gpu` if buffer sizing,
/// device execution, or mapped readback fails.
pub fn dispatch_decompression_kernel(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    inputs: DecompressionKernelInputs<'_>,
    labels: DecompressionKernelLabels,
    program: &Program,
    encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<DispatchOutput> {
    vyre::ops::registry::gate::verify_program_certificate(program).map_err(|source| {
        Error::Decompress {
            message: source.to_string(),
        }
    })?;

    let pool = BufferPool::global();
    let compressed_buffer =
        create_storage_buffer(device, queue, labels.compressed_buffer, inputs.data)?;
    let storage_usage =
        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST;
    let output_buffer = pool.acquire(
        device,
        labels.output_buffer,
        bytes_for_u32s(inputs.output_words_len)?,
        storage_usage,
    )?;
    let status_buffer = pool.acquire(
        device,
        labels.status_buffer,
        bytes_for_u32s(inputs.status_words_len)?,
        storage_usage,
    )?;
    let uniforms_buffer =
        create_storage_buffer(device, queue, labels.uniforms_buffer, inputs.uniform_bytes)?;

    let shader_source = vyre::lower::wgsl::lower(program).map_err(|source| Error::Decompress {
        message: format!(
            "failed to lower {} decompression IR to WGSL: {source}. Fix: repair the {} IR composition before GPU dispatch.",
            labels.format, labels.format
        ),
    })?;
    let _shader_label = labels.shader;
    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some(labels.bind_group_layout),
        entries: &[
            storage_binding(0, true),
            storage_binding(1, false),
            storage_binding(2, true),
            storage_binding(3, false),
        ],
    });
    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some(labels.pipeline_layout),
        bind_group_layouts: &[&bind_group_layout],
        push_constant_ranges: &[],
    });
    let pipeline = crate::runtime::compile_compute_pipeline_with_layout(
        device,
        labels.pipeline,
        &shader_source,
        "main",
        Some(&pipeline_layout),
    )?;

    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some(labels.bind_group),
        layout: &bind_group_layout,
        entries: &[
            binding(0, &compressed_buffer),
            binding(1, &output_buffer),
            binding(2, &uniforms_buffer),
            binding(3, &status_buffer),
        ],
    });

    let output_size = align_to_copy(bytes_for_u32s(inputs.output_words_len)?);
    let status_size = align_to_copy(bytes_for_u32s(inputs.status_words_len)?);
    let readback_output = readback_buffer(device, labels.readback_output, output_size)?;
    let readback_status = readback_buffer(device, labels.readback_status, status_size)?;

    let mut owned_encoder: Option<wgpu::CommandEncoder> = encoder.is_none().then(|| {
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some(labels.command_encoder),
        })
    });
    let encoder = if let Some(encoder) = encoder {
        encoder
    } else {
        // SAFETY: `owned_encoder` is `Some` when `encoder` is `None`.
        owned_encoder
            .as_mut()
            .expect("owned encoder must be present")
    };
    encoder.clear_buffer(&output_buffer, 0, Some(output_buffer.size()));
    encoder.clear_buffer(&status_buffer, 0, Some(status_buffer.size()));
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some(labels.pass),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups(inputs.block_count, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&output_buffer, 0, &readback_output, 0, output_size);
    encoder.copy_buffer_to_buffer(&status_buffer, 0, &readback_status, 0, status_size);

    let Some(owned_encoder) = owned_encoder else {
        return Err(Error::Decompress {
            message: "dispatch_decompression_kernel was called with an external encoder; use deferred dispatch mode to submit and map externally. Fix: pass `None` for immediate readback behavior.".to_string(),
        });
    };
    let submission = queue.submit(Some(owned_encoder.finish()));
    let output_bytes =
        read_mapped_buffer(device, &readback_output, output_size, submission.clone())?;
    let status_bytes = read_mapped_buffer(device, &readback_status, status_size, submission)?;
    pool.release(compressed_buffer);
    pool.release(output_buffer);
    pool.release(status_buffer);
    pool.release(uniforms_buffer);
    pool.release(readback_output);
    pool.release(readback_status);

    let statuses = decode_u32s(&status_bytes)?;
    let output_words = decode_u32s(&output_bytes)?;
    let total_output_size =
        usize::try_from(inputs.total_output_size).map_err(|source| Error::Decompress {
            message: format!(
                "{} output size {} cannot fit usize: {source}. Fix: reject this decompression descriptor on this platform.",
                labels.format, inputs.total_output_size
            ),
        })?;
    let output = unpack_words_to_bytes(&output_words, total_output_size)?;
    Ok(DispatchOutput { output, statuses })
}

pub fn create_storage_buffer(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    label: &'static str,
    bytes: &[u8],
) -> Result<PooledBuffer> {
    let size = align_to_copy(u64::try_from(bytes.len()).map_err(|source| Error::Gpu {
        message: format!(
            "storage buffer '{label}' length {} cannot fit u64: {source}. Fix: split the upload.",
            bytes.len()
        ),
    })?);
    let buffer = BufferPool::global().acquire(
        device,
        label,
        size,
        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
    )?;
    queue.write_buffer(&buffer, 0, bytes);
    write_zero_padding(queue, &buffer, bytes.len(), size)?;
    Ok(buffer)
}

/// Acquire a pooled, host-mappable buffer of `size` bytes that GPU results can
/// be copied into (`COPY_DST`) and then mapped for reading (`MAP_READ`).
///
/// # Errors
///
/// Propagates any error from the global [`BufferPool`] acquisition.
pub fn readback_buffer(
    device: &wgpu::Device,
    label: &'static str,
    size: u64,
) -> Result<PooledBuffer> {
    let usage = wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST;
    let pool = BufferPool::global();
    pool.acquire(device, label, size, usage)
}

pub fn read_mapped_buffer(
    device: &wgpu::Device,
    buffer: &wgpu::Buffer,
    size: u64,
    submission: wgpu::SubmissionIndex,
) -> Result<Vec<u8>> {
    let slice = buffer.slice(0..size);
    let (sender, receiver) = mpsc::sync_channel(1);
    slice.map_async(wgpu::MapMode::Read, move |result| {
        if let Err(send_err) = sender.send(result) {
            tracing::warn!(
                ?send_err,
                "decompress readback receiver dropped before map_async result delivery"
            );
        }
    });
    match device.poll(wgpu::Maintain::wait_for(submission)) {
        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
    }
    receiver
        .recv()
        .map_err(|error| Error::Gpu {
            message: format!(
                "GPU readback channel closed: {error}. Fix: keep the readback receiver alive until map_async completes."
            ),
        })?
        .map_err(|error| Error::Gpu {
            message: format!(
                "GPU readback map failed: {error:?}. Fix: check for device loss, adapter timeout, or invalid readback buffer usage."
            ),
        })?;
    let mapped = slice.get_mapped_range();
    let output_len = usize::try_from(size).map_err(|source| Error::Gpu {
        message: format!("readback size cannot fit usize: {source}. Fix: split the workload."),
    })?;
    let mut bytes = vec![0_u8; output_len];
    bytes.copy_from_slice(&mapped[..output_len]);
    drop(mapped);
    buffer.unmap();
    Ok(bytes)
}

fn write_zero_padding(
    queue: &wgpu::Queue,
    buffer: &wgpu::Buffer,
    written_len: usize,
    total_size: u64,
) -> Result<()> {
    let mut offset = u64::try_from(written_len).map_err(|source| Error::Gpu {
        message: format!(
            "written byte length {written_len} cannot fit u64: {source}. Fix: split the upload."
        ),
    })?;
    let zeros = [0u8; 4096];
    while offset < total_size {
        let remaining = usize::try_from((total_size - offset).min(4096)).map_err(|source| {
            Error::Gpu {
                message: format!(
                    "padding byte length cannot fit usize: {source}. Fix: run on a supported target."
                ),
            }
        })?;
        queue.write_buffer(buffer, offset, &zeros[..remaining]);
        offset += u64::try_from(remaining).map_err(|source| Error::Gpu {
            message: format!(
                "padding chunk length {remaining} cannot fit u64: {source}. Fix: run on a supported target."
            ),
        })?;
    }
    Ok(())
}

/// Build a bind-group-layout entry for a storage buffer at `binding` that is
/// visible to the compute stage only.
///
/// `read_only` is forwarded to [`wgpu::BufferBindingType::Storage`]. The entry
/// uses no dynamic offset and no minimum binding size, and is not an array.
pub fn storage_binding(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
    let buffer_ty = wgpu::BindingType::Buffer {
        ty: wgpu::BufferBindingType::Storage { read_only },
        has_dynamic_offset: false,
        min_binding_size: None,
    };
    wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: buffer_ty,
        count: None,
    }
}