vyre-wgpu 0.1.0

wgpu backend for the vyre IR: implements `VyreBackend` and owns the GPU runtime, buffer pool, and pipeline cache.
Documentation
//! Zstd GPU decompression dispatch for raw/RLE blocks.

use crate::engine::decompress::dispatch_kernel::{
    dispatch_decompression_kernel, DecompressionKernelInputs, DecompressionKernelLabels,
};
pub use crate::engine::decompress::uniforms::ZstdUniforms;
use vyre::error::{Error, Result};
use vyre::ops::compression::zstd;

/// Run the zstd GPU shader and return the raw decompressed bytes and status words.
///
/// # Errors
///
/// Returns `Error::Decompress` if descriptors are non-empty, the stream is
/// malformed, sizing is invalid, or decompression readback fails.
pub fn dispatch_zstd(
    args: ZstdDispatchArgs<'_>,
    command_encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<(Vec<u8>, Vec<u32>)> {
    if !args.descriptor_words.is_empty() {
        return Err(Error::Decompress {
            message: format!(
                "zstd dispatch received {} descriptor words, but the current zstd IR backend consumes input/output/params/status buffers directly. Fix: pass descriptor fields through uniforms or update the zstd IR composition.",
                args.descriptor_words.len()
            ),
        });
    }
    let derived_output_bytes = decoded_len_from_blocks(args.data)?;
    let derived_output_words = super::words_for_output_bytes("zstd", derived_output_bytes)?;
    let derived_output_len = super::u32_output_len("zstd", derived_output_bytes)?;
    let uniforms = ZstdUniforms {
        input_len: u32::try_from(args.data.len()).map_err(|source| Error::Decompress {
            message: format!(
                "zstd input length {} cannot fit u32 shader uniforms: {source}. Fix: split the compressed stream before GPU dispatch.",
                args.data.len()
            ),
        })?,
        output_len: derived_output_len,
        total_output_size: derived_output_len,
    };
    super::validate_output_ratio(
        "zstd",
        args.data.len(),
        derived_output_bytes,
        zstd::MAX_OUTPUT_RATIO as usize,
    )?;
    super::validate_backend_capacity(
        args.device,
        args.data.len(),
        0,
        derived_output_words,
        args.status_words_len,
        std::mem::size_of::<ZstdUniforms>() / std::mem::size_of::<u32>(),
    )?;

    let program = zstd::ZstdDecompress::program();
    let dispatch = dispatch_decompression_kernel(
        args.device,
        args.queue,
        DecompressionKernelInputs {
            data: args.data,
            uniform_bytes: bytemuck::bytes_of(&uniforms),
            output_words_len: derived_output_words,
            status_words_len: args.status_words_len,
            total_output_size: uniforms.total_output_size,
            block_count: args.block_count,
        },
        DecompressionKernelLabels {
            format: "zstd",
            compressed_buffer: "vyre-zstd-compressed",
            output_buffer: "vyre-zstd-output",
            status_buffer: "vyre-zstd-status",
            uniforms_buffer: "vyre-zstd-uniforms",
            shader: "vyre-zstd-shader",
            bind_group_layout: "vyre-zstd-bind-group-layout",
            pipeline_layout: "vyre-zstd-pipeline-layout",
            pipeline: "vyre-zstd-pipeline",
            bind_group: "vyre-zstd-bind-group",
            readback_output: "vyre-zstd-readback-output",
            readback_status: "vyre-zstd-readback-status",
            command_encoder: "vyre-zstd-command-encoder",
            pass: "vyre-zstd-pass",
        },
        &program,
        command_encoder,
    )?;
    Ok((dispatch.output, dispatch.statuses))
}

/// Inputs bundled for a single [`dispatch_zstd`] call.
///
/// [`dispatch_zstd`] consults only `device`, `queue`, `data`,
/// `descriptor_words`, `status_words_len` and `block_count`. The
/// `output_words_len` and `uniforms` fields remain for source compatibility
/// and are ignored.
pub struct ZstdDispatchArgs<'a> {
    /// wgpu device on which the dispatch's buffers and pipelines are created.
    pub device: &'a wgpu::Device,
    /// Queue the decompression work is submitted to.
    pub queue: &'a wgpu::Queue,
    /// Compressed zstd block sequence to decode.
    pub data: &'a [u8],
    /// Reserved for descriptor input; must currently be empty or the dispatch
    /// is rejected.
    pub descriptor_words: &'a [u32],
    /// Deprecated and ignored: the output size is derived from the block headers.
    pub output_words_len: usize,
    /// Length of the status buffer, counted in 32-bit words.
    pub status_words_len: usize,
    /// Deprecated and ignored: uniforms are rebuilt after size validation.
    pub uniforms: ZstdUniforms,
    /// Workgroup count dispatched along the X dimension.
    pub block_count: u32,
}

pub(crate) fn decoded_len_from_blocks(data: &[u8]) -> Result<usize> {
    let mut cursor = 0usize;
    let mut output_len = 0usize;
    while cursor < data.len() {
        if data.len().saturating_sub(cursor) < 3 {
            return Err(Error::Decompress {
                message: "zstd block header is truncated. Fix: reject this malformed stream before GPU dispatch.".to_string(),
            });
        }
        let header = u32::from(data[cursor])
            | (u32::from(data[cursor + 1]) << 8)
            | (u32::from(data[cursor + 2]) << 16);
        cursor += 3;
        let last = (header & 1) != 0;
        let block_type = (header >> 1) & 0b11;
        let block_size = usize::try_from(header >> 3).map_err(|source| Error::Decompress {
            message: format!(
                "zstd block size cannot fit usize: {source}. Fix: reject this malformed stream before GPU dispatch."
            ),
        })?;
        match block_type {
            0 => {
                cursor = cursor.checked_add(block_size).ok_or_else(|| Error::Decompress {
                    message: "zstd raw block cursor overflow. Fix: reject this malformed stream before GPU dispatch.".to_string(),
                })?;
                if cursor > data.len() {
                    return Err(Error::Decompress {
                        message: "zstd raw block length exceeds input bounds. Fix: reject this malformed stream before GPU dispatch.".to_string(),
                    });
                }
                output_len = output_len.checked_add(block_size).ok_or_else(|| Error::Decompress {
                    message: "zstd decoded length overflowed while reading a raw block. Fix: split or reject this stream.".to_string(),
                })?;
            }
            1 => {
                if cursor >= data.len() {
                    return Err(Error::Decompress {
                        message: "zstd RLE block is missing its repeated byte. Fix: reject this malformed stream before GPU dispatch.".to_string(),
                    });
                }
                cursor += 1;
                output_len = output_len.checked_add(block_size).ok_or_else(|| Error::Decompress {
                    message: "zstd decoded length overflowed while reading an RLE block. Fix: split or reject this stream.".to_string(),
                })?;
            }
            2 => {
                return Err(Error::Decompress {
                    message: "zstd compressed block output length is not derivable from the block header. Fix: dispatch only raw/RLE blocks or provide a trusted frame parser that supplies verified sizes.".to_string(),
                });
            }
            _ => {
                return Err(Error::Decompress {
                    message: "zstd reserved block type encountered. Fix: reject this malformed stream before GPU dispatch.".to_string(),
                });
            }
        }
        if last {
            if cursor != data.len() {
                return Err(Error::Decompress {
                    message: "zstd stream has trailing bytes after the last block. Fix: reject this malformed stream before GPU dispatch.".to_string(),
                });
            }
            return Ok(output_len);
        }
    }
    Ok(output_len)
}