// vyre-wgpu 0.1.0
//
// wgpu backend for vyre IR — implements VyreBackend, owns GPU runtime, buffer pool, pipeline cache.
// Documentation
//! LZ4 GPU decompression dispatch.

use crate::engine::decompress::dispatch_kernel::{
    dispatch_decompression_kernel, DecompressionKernelInputs, DecompressionKernelLabels,
};
pub use crate::engine::decompress::uniforms::Lz4Uniforms;
use vyre::error::{Error, Result};
use vyre::ops::compression::lz4;

/// Run the LZ4 GPU shader and return the raw decompressed bytes and status words.
///
/// # Errors
///
/// Returns `Error::Decompress` if descriptors are non-empty, the block is
/// malformed, sizing is invalid, or decompression readback fails.
pub fn dispatch_lz4(
    args: Lz4DispatchArgs<'_>,
    command_encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<(Vec<u8>, Vec<u32>)> {
    if !args.descriptor_words.is_empty() {
        return Err(Error::Decompress {
            message: format!(
                "LZ4 dispatch received {} descriptor words, but the current LZ4 IR backend consumes input/output/params/status buffers directly. Fix: pass descriptor fields through uniforms or update the LZ4 IR composition.",
                args.descriptor_words.len()
            ),
        });
    }
    let derived_output_bytes = decoded_len_from_block(args.data)?;
    let derived_output_words = super::words_for_output_bytes("LZ4", derived_output_bytes)?;
    let derived_output_len = super::u32_output_len("LZ4", derived_output_bytes)?;
    let uniforms = Lz4Uniforms {
        input_len: u32::try_from(args.data.len()).map_err(|source| Error::Decompress {
            message: format!(
                "LZ4 input length {} cannot fit u32 shader uniforms: {source}. Fix: split the compressed stream before GPU dispatch.",
                args.data.len()
            ),
        })?,
        output_len: derived_output_len,
        total_output_size: derived_output_len,
    };
    super::validate_output_ratio(
        "LZ4",
        args.data.len(),
        derived_output_bytes,
        lz4::MAX_OUTPUT_RATIO as usize,
    )?;
    super::validate_backend_capacity(
        args.device,
        args.data.len(),
        0,
        derived_output_words,
        args.status_words_len,
        std::mem::size_of::<Lz4Uniforms>() / std::mem::size_of::<u32>(),
    )?;

    let program = lz4::Lz4Decompress::program();
    let dispatch = dispatch_decompression_kernel(
        args.device,
        args.queue,
        DecompressionKernelInputs {
            data: args.data,
            uniform_bytes: bytemuck::bytes_of(&uniforms),
            output_words_len: derived_output_words,
            status_words_len: args.status_words_len,
            total_output_size: uniforms.total_output_size,
            block_count: args.block_count,
        },
        DecompressionKernelLabels {
            format: "LZ4",
            compressed_buffer: "vyre-lz4-compressed",
            output_buffer: "vyre-lz4-output",
            status_buffer: "vyre-lz4-status",
            uniforms_buffer: "vyre-lz4-uniforms",
            shader: "vyre-lz4-shader",
            bind_group_layout: "vyre-lz4-bind-group-layout",
            pipeline_layout: "vyre-lz4-pipeline-layout",
            pipeline: "vyre-lz4-pipeline",
            bind_group: "vyre-lz4-bind-group",
            readback_output: "vyre-lz4-readback-output",
            readback_status: "vyre-lz4-readback-status",
            command_encoder: "vyre-lz4-command-encoder",
            pass: "vyre-lz4-pass",
        },
        &program,
        command_encoder,
    )?;
    Ok((dispatch.output, dispatch.statuses))
}

/// Parameters for [`dispatch_lz4`].
///
/// `output_words_len` and `uniforms` are retained for compatibility only:
/// [`dispatch_lz4`] does not read either field and instead derives the output
/// size and the shader uniforms from `data`.
pub struct Lz4DispatchArgs<'a> {
    /// GPU device used to create buffers and pipelines.
    pub device: &'a wgpu::Device,
    /// GPU queue used to submit work.
    pub queue: &'a wgpu::Queue,
    /// Compressed input bytes. The decoded length is computed by walking this block.
    pub data: &'a [u8],
    /// Reserved descriptor words. Must be empty in the current implementation;
    /// [`dispatch_lz4`] returns `Error::Decompress` for a non-empty slice.
    pub descriptor_words: &'a [u32],
    /// Deprecated compatibility field. Output size is derived from the block.
    pub output_words_len: usize,
    /// Status buffer length in 32-bit words.
    pub status_words_len: usize,
    /// Deprecated compatibility field. Uniforms are derived after validation.
    pub uniforms: Lz4Uniforms,
    /// Number of compute workgroups to dispatch in X.
    pub block_count: u32,
}

pub(crate) fn decoded_len_from_block(data: &[u8]) -> Result<usize> {
    let mut cursor = 0usize;
    let mut output_len = 0usize;
    while cursor < data.len() {
        let token = data[cursor];
        cursor += 1;

        let literal_len = read_extended_len(data, &mut cursor, usize::from(token >> 4), "literal")?;
        cursor = cursor.checked_add(literal_len).ok_or_else(|| Error::Decompress {
            message: "LZ4 literal cursor overflow. Fix: reject this malformed block before GPU dispatch.".to_string(),
        })?;
        if cursor > data.len() {
            return Err(Error::Decompress {
                message: "LZ4 literal length exceeds compressed block bounds. Fix: reject this malformed block before GPU dispatch.".to_string(),
            });
        }
        output_len = output_len.checked_add(literal_len).ok_or_else(|| Error::Decompress {
            message: "LZ4 decoded length overflowed while reading literals. Fix: split or reject this compressed block.".to_string(),
        })?;
        if cursor == data.len() {
            break;
        }
        if data.len().saturating_sub(cursor) < 2 {
            return Err(Error::Decompress {
                message: "LZ4 match offset is truncated. Fix: reject this malformed block before GPU dispatch.".to_string(),
            });
        }
        let offset = u16::from_le_bytes([data[cursor], data[cursor + 1]]);
        cursor += 2;
        if offset == 0 || usize::from(offset) > output_len {
            return Err(Error::Decompress {
                message: format!(
                    "LZ4 match offset {offset} is invalid for decoded length {output_len}. Fix: reject this malformed block before GPU dispatch."
                ),
            });
        }
        let match_len = read_extended_len(data, &mut cursor, usize::from(token & 0x0f), "match")?
            .checked_add(4)
            .ok_or_else(|| Error::Decompress {
                message: "LZ4 match length overflow. Fix: reject this malformed block before GPU dispatch.".to_string(),
            })?;
        output_len = output_len.checked_add(match_len).ok_or_else(|| Error::Decompress {
            message: "LZ4 decoded length overflowed while reading match. Fix: split or reject this compressed block.".to_string(),
        })?;
    }
    Ok(output_len)
}

/// `read_extended_len` function.
pub fn read_extended_len(
    data: &[u8],
    cursor: &mut usize,
    nibble: usize,
    label: &'static str,
) -> Result<usize> {
    let mut len = nibble;
    if nibble != 15 {
        return Ok(len);
    }
    loop {
        let Some(byte) = data.get(*cursor).copied() else {
            return Err(Error::Decompress {
                message: format!(
                    "LZ4 {label} length extension is truncated. Fix: reject this malformed block before GPU dispatch."
                ),
            });
        };
        *cursor += 1;
        len = len.checked_add(usize::from(byte)).ok_or_else(|| Error::Decompress {
            message: format!(
                "LZ4 {label} length extension overflowed. Fix: reject this malformed block before GPU dispatch."
            ),
        })?;
        if byte != 255 {
            return Ok(len);
        }
    }
}