// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library.
// Documentation
// RFC 1950 zlib wrapper decompression.

use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::compression::deflate_core;
use crate::ops::{AlgebraicLaw, OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};

/// Algebraic laws declared for the zlib decompression operation and passed to
/// `ZlibDecompress::SPEC`.
///
/// Holds a single `Bounded { lo: 0, hi: 255 }` law — presumably the value range
/// of each output byte; confirm against the `AlgebraicLaw` documentation.
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded { lo: 0, hi: 255 }];

/// Maximum decompressed bytes accepted per compressed input byte.
///
/// `decompress_bytes` multiplies the input length by this ratio (with an
/// overflow check) and passes the product to `deflate_core::decompress` as its
/// `max_output` argument.
pub const MAX_OUTPUT_RATIO: usize = 1024;

/// zlib decompression operation.
///
/// Unit type that carries no state; the operation's specification and guard
/// program are exposed as the associated items `SPEC` and `program`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ZlibDecompress;

impl ZlibDecompress {
    /// Declarative operation specification.
    ///
    /// Registers the operation under the id `compression.zlib_decompress`, with
    /// the shared bytes-to-bytes input/output signatures, the module's [`LAWS`],
    /// and [`Self::program`] as the program builder.
    pub const SPEC: OpSpec = OpSpec::composition(
        "compression.zlib_decompress",
        BYTES_TO_BYTES_INPUTS,
        BYTES_TO_BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Build the canonical dispatch guard program.
    ///
    /// The program declares a read buffer `input` and an output buffer `out`
    /// (the integers `0` and `1` are presumably their binding slots — confirm
    /// against `BufferDecl`) and consists of a single node: return when the
    /// length of `out` is greater than the length of `input` times 1024.
    #[must_use]
    pub fn program() -> Program {
        let buffers = vec![
            BufferDecl::read("input", 0, DataType::Bytes),
            BufferDecl::output("out", 1, DataType::Bytes),
        ];

        // The literal 1024 carries the same value as `MAX_OUTPUT_RATIO`; the two
        // have to be kept in sync by hand.
        let out_len = Expr::buf_len("out");
        let out_cap = Expr::mul(Expr::buf_len("input"), Expr::u32(1024));
        let over_cap = Expr::gt(out_len, out_cap);
        let guard = Node::if_then(over_cap, vec![Node::Return]);

        Program::new(buffers, [1, 1, 1], vec![guard])
    }
}

/// Decode a zlib-wrapped (RFC 1950) byte stream.
///
/// The two-byte CMF/FLG header is validated first, then the bytes between the
/// header and the four-byte trailer are handed to `deflate_core::decompress`
/// with an output cap of `input.len() * MAX_OUTPUT_RATIO`, and finally the
/// big-endian Adler-32 trailer is compared against the decompressed bytes.
///
/// # Errors
///
/// Returns an actionable `Fix: ...` message when:
/// - the input is shorter than six bytes (header plus trailer),
/// - the compression method is not DEFLATE (8) or CINFO exceeds 7,
/// - the FCHECK header checksum is wrong,
/// - the preset-dictionary flag is set,
/// - the output cap computation overflows `usize`,
/// - `deflate_core::decompress` reports an error, or
/// - the Adler-32 trailer does not match the decompressed bytes.
pub fn decompress_bytes(input: &[u8]) -> Result<Vec<u8>, String> {
    // Two header bytes plus four trailer bytes is the shortest input accepted.
    let (cmf, flg) = match input {
        [cmf, flg, _, _, _, _, ..] => (*cmf, *flg),
        _ => {
            return Err(
                "Fix: provide a complete zlib stream with header and Adler-32 trailer.".into(),
            );
        }
    };

    // CMF: low nibble is the compression method, high nibble is CINFO.
    let method = cmf & 0x0f;
    let window_info = cmf >> 4;
    if method != 8 {
        return Err("Fix: zlib CM compression method must be DEFLATE (8).".into());
    }
    if window_info > 7 {
        return Err("Fix: zlib CINFO window size must be at most 32 KiB.".into());
    }

    // CMF and FLG read as one big-endian 16-bit value must be a multiple of 31.
    if u16::from_be_bytes([cmf, flg]) % 31 != 0 {
        return Err("Fix: repair zlib FCHECK header checksum.".into());
    }

    let uses_preset_dictionary = flg & 0x20 != 0;
    if uses_preset_dictionary {
        return Err("Fix: zlib preset dictionaries are not accepted by this operation.".into());
    }

    let max_output = match input.len().checked_mul(MAX_OUTPUT_RATIO) {
        Some(limit) => limit,
        None => {
            return Err(
                "Fix: reject zlib input whose max_output_ratio multiplication overflows."
                    .to_string(),
            );
        }
    };

    // The length guard above guarantees at least two bytes before the trailer.
    let (head_and_payload, trailer) = input.split_at(input.len() - 4);
    let output = deflate_core::decompress(&head_and_payload[2..], max_output)?;

    // The trailer stores the checksum most-significant byte first.
    let expected = trailer
        .iter()
        .fold(0_u32, |acc, &byte| (acc << 8) | u32::from(byte));

    if adler32(&output) == expected {
        Ok(output)
    } else {
        Err("Fix: zlib Adler-32 trailer does not match decompressed bytes.".into())
    }
}

/// Compute the Adler-32 checksum of `bytes`.
///
/// Two running sums are kept modulo 65 521: `low` starts at 1 and accumulates
/// every byte, `high` starts at 0 and accumulates each successive value of
/// `low`. The result packs `high` into the upper 16 bits and `low` into the
/// lower 16 bits, so an empty slice yields `1`.
pub fn adler32(bytes: &[u8]) -> u32 {
    const MODULUS: u32 = 65_521;

    // Both sums are reduced after every byte, so the additions never overflow.
    let (low, high) = bytes
        .iter()
        .fold((1_u32, 0_u32), |(low, high), &byte| {
            let next_low = (low + u32::from(byte)) % MODULUS;
            let next_high = (high + next_low) % MODULUS;
            (next_low, next_high)
        });

    (high << 16) | low
}