// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library.
// Documentation
// Raw RFC 1951 DEFLATE decompression.

use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::compression::deflate_core;
use crate::ops::{AlgebraicLaw, OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};

/// Algebraic laws advertised by [`DeflateDecompress::SPEC`].
///
/// A single `Bounded` law with `lo: 0` and `hi: 255`, i.e. the value range of
/// one output byte.
pub const LAWS: &[AlgebraicLaw] = &[
    AlgebraicLaw::Bounded { lo: 0, hi: 255 },
];

/// Decompression-bomb cap: the most output bytes allowed per compressed input byte.
///
/// The CPU path scales it by the input length in [`max_output_for_input`].
/// The GPU dispatch guard scales it by `buf_len("input")` in [`max_allowed_len`].
pub const MAX_OUTPUT_RATIO: usize = 1 << 10;

/// Zero-sized marker type for the raw (RFC 1951) DEFLATE decompression operation.
///
/// Its associated [`DeflateDecompress::SPEC`] and [`DeflateDecompress::program`]
/// describe the operation.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeflateDecompress;

impl DeflateDecompress {
    /// Declarative operation specification.
    ///
    /// The spec has these parts:
    /// - a composition registered as `compression.deflate_decompress`;
    /// - the shared bytes-in/bytes-out buffer signature;
    /// - the laws in [`LAWS`];
    /// - [`Self::program`] as its program builder.
    pub const SPEC: OpSpec = OpSpec::composition(
        "compression.deflate_decompress",
        BYTES_TO_BYTES_INPUTS,
        BYTES_TO_BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Build the canonical dispatch guard program.
    ///
    /// The program declares two byte buffers: `input` (binding 0, read-only) and
    /// `out` (binding 1, output). Its body is a single early `Return`, taken when
    /// `buf_len("out")` exceeds `buf_len("input") * MAX_OUTPUT_RATIO`. The
    /// program contains only this guard and no decompression logic.
    ///
    /// NOTE(review): `[1, 1, 1]` is presumably the workgroup size; confirm
    /// against `Program::new`.
    #[must_use]
    pub fn program() -> Program {
        let buffers = vec![
            BufferDecl::read("input", 0, DataType::Bytes),
            BufferDecl::output("out", 1, DataType::Bytes),
        ];
        // Bail out before doing any work if the output buffer is larger than
        // the decompression-bomb cap allows.
        let output_too_large = Expr::gt(max_output_len(), max_allowed_len());
        let body = vec![Node::if_then(output_too_large, vec![Node::Return])];
        Program::new(buffers, [1, 1, 1], body)
    }
}

/// Decompress a raw DEFLATE (RFC 1951) byte stream under a decompression-bomb cap.
///
/// The output cap is `input.len() * MAX_OUTPUT_RATIO` bytes, as computed by
/// [`max_output_for_input`]. The actual decoding is delegated to
/// `deflate_core::decompress`.
///
/// # Errors
///
/// - Returns the `Fix: ...` message from [`max_output_for_input`] when computing
///   the cap overflows `usize`.
/// - Otherwise returns any error from `deflate_core::decompress` unchanged. Per
///   this module's contract, that happens when the stream is malformed or
///   would expand beyond [`MAX_OUTPUT_RATIO`].
pub fn decompress_bytes(input: &[u8]) -> Result<Vec<u8>, String> {
    max_output_for_input(input.len()).and_then(|cap| deflate_core::decompress(input, cap))
}

/// Compute the maximum number of decompressed bytes allowed for an input of
/// `input_len` compressed bytes.
///
/// The limit is `input_len * MAX_OUTPUT_RATIO`, so an empty input yields `0`.
///
/// # Errors
///
/// Returns a `Fix: ...` message when the multiplication overflows `usize`.
pub fn max_output_for_input(input_len: usize) -> Result<usize, String> {
    match input_len.checked_mul(MAX_OUTPUT_RATIO) {
        Some(limit) => Ok(limit),
        None => Err(String::from(
            "Fix: reject DEFLATE input whose max_output_ratio multiplication overflows.",
        )),
    }
}

/// IR expression for the length of the `out` buffer.
///
/// The dispatch guard compares this against [`max_allowed_len`].
pub fn max_output_len() -> Expr {
    const OUTPUT_BUFFER: &str = "out";
    Expr::buf_len(OUTPUT_BUFFER)
}

/// IR expression for the largest `out` length the dispatch guard tolerates.
///
/// The value is `buf_len("input") * MAX_OUTPUT_RATIO`, the same formula the CPU
/// path uses in [`max_output_for_input`].
///
/// NOTE(review): unlike the CPU path, this multiplication has no overflow
/// check. `Expr::mul` may use wrapping 32-bit arithmetic, as the `Expr::u32`
/// operand suggests. If so, inputs longer than `u32::MAX / 1024` bytes
/// (~4 MiB) would make the product wrap. Confirm against the IR's arithmetic
/// semantics.
pub fn max_allowed_len() -> Expr {
    // MAX_OUTPUT_RATIO is 1024, so narrowing it to u32 loses nothing.
    let ratio = MAX_OUTPUT_RATIO as u32;
    let input_len = Expr::buf_len("input");
    Expr::mul(input_len, Expr::u32(ratio))
}