vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
// Shared DEFLATE decoder used by gzip, zlib, and raw deflate operations.

/// Longest Huffman code length, in bits, that `Huffman::from_lengths` accepts.
pub const MAX_BITS: u8 = 15;
/// Base match length for each length index; `copy_match` indexes this with
/// `symbol - 257` and adds the value of the extra bits from `LEN_EXTRA`.
pub const LEN_BASE: [u16; 29] = [
    3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131,
    163, 195, 227, 258,
];
/// Number of extra bits read after each length index and added to `LEN_BASE`.
pub const LEN_EXTRA: [u8; 29] = [
    0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0,
];
/// Base back-reference distance for each distance symbol (0..=29).
pub const DIST_BASE: [u16; 30] = [
    1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537,
    2049, 3073, 4097, 6145, 8193, 12_289, 16_385, 24_577,
];
/// Number of extra bits read after each distance symbol and added to `DIST_BASE`.
pub const DIST_EXTRA: [u8; 30] = [
    0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13,
    13,
];
/// Order in which `dynamic_tables` assigns the 3-bit code-length code lengths
/// it reads: the n-th value read belongs to slot `CODE_ORDER[n]`.
pub const CODE_ORDER: [usize; 19] = [
    16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15,
];
/// Inflate a raw RFC 1951 DEFLATE stream into a freshly allocated buffer.
///
/// Blocks are decoded back to back until the block whose final-block bit is
/// set has been processed. Any input that follows that block is not examined.
/// `max_output` caps the total number of bytes that may be produced.
///
/// # Errors
///
/// Returns an actionable `Fix: ...` message for malformed streams, truncated
/// input, the reserved block type 3, or expansion above `max_output`.
pub fn decompress(input: &[u8], max_output: usize) -> Result<Vec<u8>, String> {
    let mut reader = BitReader::new(input);
    let mut output = Vec::new();
    let mut last_block = false;
    while !last_block {
        // Block header: a 1-bit final flag followed by a 2-bit block type.
        last_block = reader.bits(1)? != 0;
        let block_type = reader.bits(2)?;
        if block_type == 0 {
            stored_block(&mut reader, &mut output, max_output)?;
        } else if block_type == 1 {
            let litlen = fixed_litlen()?;
            let dist = fixed_dist()?;
            huffman_block(&mut reader, &litlen, &dist, &mut output, max_output)?;
        } else if block_type == 2 {
            let (litlen, dist) = dynamic_tables(&mut reader)?;
            huffman_block(&mut reader, &litlen, &dist, &mut output, max_output)?;
        } else {
            return Err("Fix: reject DEFLATE block type 3; it is reserved by RFC 1951.".into());
        }
    }
    Ok(output)
}
/// Copy one stored (uncompressed) block from `reader` into `output`.
///
/// The reader is first aligned to a byte boundary. A little-endian `LEN` /
/// `NLEN` pair is then read, followed by `LEN` raw bytes that are appended to
/// `output`.
///
/// # Errors
///
/// Fails when `LEN` and `NLEN` are not bitwise complements, when appending
/// `LEN` bytes would push `output` past `max`, or when the input ends early.
pub fn stored_block(
    reader: &mut BitReader<'_>,
    output: &mut Vec<u8>,
    max: usize,
) -> Result<(), String> {
    reader.align_byte();
    let declared = reader.u16_le()?;
    let complement = reader.u16_le()?;
    // Exact complements XOR to all ones.
    if declared ^ complement != u16::MAX {
        return Err("Fix: repair stored DEFLATE block LEN/NLEN complement mismatch.".into());
    }
    let count = usize::from(declared);
    ensure_room(output.len(), count, max)?;
    output.extend_from_slice(reader.bytes(count)?);
    Ok(())
}
/// Decode one Huffman-coded block, appending its bytes to `output`.
///
/// Symbols are pulled from `litlen` until the end-of-block symbol (256) is
/// seen. Symbols below 256 are literal bytes; symbols 257..=285 start a
/// length/distance match that is resolved by `copy_match` using `dist`.
///
/// # Errors
///
/// Fails on literal/length symbols above 285, on any error reported by the
/// table decoder or `copy_match`, and when `output` would grow past `max`.
pub fn huffman_block(
    reader: &mut BitReader<'_>,
    litlen: &Huffman,
    dist: &Huffman,
    output: &mut Vec<u8>,
    max: usize,
) -> Result<(), String> {
    loop {
        let symbol = litlen.decode(reader)?;
        if symbol == 256 {
            return Ok(());
        }
        if symbol > 285 {
            return Err("Fix: reject DEFLATE literal/length symbol above 285.".into());
        }
        if symbol > 256 {
            copy_match(reader, dist, output, max, usize::from(symbol - 257))?;
            continue;
        }
        // Only literal bytes (0..=255) reach this point.
        ensure_room(output.len(), 1, max)?;
        let literal = u8::try_from(symbol)
            .map_err(|error| format!("Fix: literal symbol out of range: {error}"))?;
        output.push(literal);
    }
}
/// Resolve one length/distance match and append the copied bytes to `output`.
///
/// `len_index` is the length symbol minus 257 and selects the entry of
/// `LEN_BASE` / `LEN_EXTRA` to use. The distance symbol is decoded from
/// `dist`, and both values are completed with their extra bits from `reader`.
///
/// # Errors
///
/// Fails when `len_index` is outside the length tables, when the distance
/// symbol is above 29, when the distance reaches before the start of
/// `output`, when the copy would grow `output` past `max`, or when the input
/// ends early.
pub fn copy_match(
    reader: &mut BitReader<'_>,
    dist: &Huffman,
    output: &mut Vec<u8>,
    max: usize,
    len_index: usize,
) -> Result<(), String> {
    // Checked lookup: this function is public, so a caller-supplied index must
    // produce an error rather than an out-of-bounds panic.
    let (Some(&len_base), Some(&len_extra)) =
        (LEN_BASE.get(len_index), LEN_EXTRA.get(len_index))
    else {
        return Err(
            "Fix: reject DEFLATE length index above 28; length symbols span 257..=285.".into(),
        );
    };
    let length = usize::from(len_base)
        + usize::try_from(reader.bits(len_extra)?)
            .map_err(|error| format!("Fix: length extra bits must fit usize: {error}"))?;
    let dist_sym = usize::from(dist.decode(reader)?);
    if dist_sym >= DIST_BASE.len() {
        return Err("Fix: reject DEFLATE distance symbol above 29.".into());
    }
    let distance = usize::from(DIST_BASE[dist_sym])
        + usize::try_from(reader.bits(DIST_EXTRA[dist_sym])?)
            .map_err(|error| format!("Fix: distance extra bits must fit usize: {error}"))?;
    if distance == 0 || distance > output.len() {
        return Err("Fix: reject DEFLATE back-reference before the output window.".into());
    }
    ensure_room(output.len(), length, max)?;
    let start = output.len() - distance;
    if distance >= length {
        // Source and destination do not overlap, so copy in one step.
        output.extend_from_within(start..start + length);
    } else {
        // Overlapping match: later bytes read bytes written earlier in this
        // same copy, so it has to proceed one byte at a time.
        output.reserve(length);
        for offset in 0..length {
            let byte = output[start + offset];
            output.push(byte);
        }
    }
    Ok(())
}
pub fn dynamic_tables(reader: &mut BitReader<'_>) -> Result<(Huffman, Huffman), String> {
    let hlit = usize::try_from(reader.bits(5)? + 257)
        .map_err(|error| format!("Fix: HLIT must fit usize: {error}"))?;
    let hdist = usize::try_from(reader.bits(5)? + 1)
        .map_err(|error| format!("Fix: HDIST must fit usize: {error}"))?;
    let hclen = usize::try_from(reader.bits(4)? + 4)
        .map_err(|error| format!("Fix: HCLEN must fit usize: {error}"))?;
    let mut code_lengths = vec![0_u8; 19];
    for &slot in &CODE_ORDER[..hclen] {
        code_lengths[slot] = u8::try_from(reader.bits(3)?)
            .map_err(|error| format!("Fix: code length must fit u8: {error}"))?;
    }
    let code_tree = Huffman::from_lengths(&code_lengths)?;
    let mut lengths = Vec::with_capacity(hlit + hdist);
    while lengths.len() < hlit + hdist {
        match code_tree.decode(reader)? {
            sym @ 0..=15 => lengths.push(
                u8::try_from(sym)
                    .map_err(|error| format!("Fix: code length symbol must fit u8: {error}"))?,
            ),
            16 => repeat_previous(reader, &mut lengths, 3, 3, hlit + hdist)?,
            17 => repeat_zero(reader, &mut lengths, 3, 3, hlit + hdist)?,
            18 => repeat_zero(reader, &mut lengths, 11, 7, hlit + hdist)?,
            _ => return Err("Fix: reject invalid DEFLATE code-length repeat symbol.".into()),
        }
    }
    if lengths.get(256).copied().unwrap_or(0) == 0 {
        return Err(
            "Fix: dynamic DEFLATE literal tree must include end-of-block symbol 256.".into(),
        );
    }
    Ok((
        Huffman::from_lengths(&lengths[..hlit])?,
        Huffman::from_lengths(&lengths[hlit..])?,
    ))
}
/// Append copies of the most recently stored code length to `out`.
///
/// The repeat count is `base` plus the value of `bits` extra bits read from
/// `reader`; `limit` is the maximum length `out` may reach.
///
/// # Errors
///
/// Fails when `out` is still empty, or when `repeat_value` reports an error.
pub fn repeat_previous(
    reader: &mut BitReader<'_>,
    out: &mut Vec<u8>,
    base: usize,
    bits: u8,
    limit: usize,
) -> Result<(), String> {
    match out.last().copied() {
        Some(previous) => repeat_value(reader, out, base, bits, limit, previous),
        None => Err("Fix: DEFLATE repeat-previous length appears before any length.".into()),
    }
}
/// Append a run of zero code lengths to `out`.
///
/// The run length is `base` plus the value of `bits` extra bits read from
/// `reader`; `limit` is the maximum length `out` may reach.
///
/// # Errors
///
/// Propagates any error reported by `repeat_value`.
pub fn repeat_zero(
    reader: &mut BitReader<'_>,
    out: &mut Vec<u8>,
    base: usize,
    bits: u8,
    limit: usize,
) -> Result<(), String> {
    let unused_length = 0_u8;
    repeat_value(reader, out, base, bits, limit, unused_length)
}
/// Append `base + extra` copies of `value` to `out`, where `extra` is the
/// value of `bits` bits read from `reader`.
///
/// # Errors
///
/// Fails when the input ends early or when the run would make `out` longer
/// than `limit`; in the latter case `out` is left unchanged.
pub fn repeat_value(
    reader: &mut BitReader<'_>,
    out: &mut Vec<u8>,
    base: usize,
    bits: u8,
    limit: usize,
    value: u8,
) -> Result<(), String> {
    let raw_extra = reader.bits(bits)?;
    let extra = usize::try_from(raw_extra)
        .map_err(|error| format!("Fix: repeat extra bits must fit usize: {error}"))?;
    let new_len = out.len() + (base + extra);
    if new_len <= limit {
        out.resize(new_len, value);
        Ok(())
    } else {
        Err("Fix: DEFLATE code-length repeat overruns declared table length.".into())
    }
}
/// Build the literal/length table used by fixed-Huffman blocks.
///
/// The 288 symbols get code lengths 8 (0..=143), 9 (144..=255),
/// 7 (256..=279) and 8 (280..=287).
///
/// # Errors
///
/// Propagates any error from `Huffman::from_lengths`.
pub fn fixed_litlen() -> Result<Huffman, String> {
    let lengths: Vec<u8> = (0_u16..288)
        .map(|symbol| match symbol {
            0..=143 => 8,
            144..=255 => 9,
            256..=279 => 7,
            _ => 8,
        })
        .collect();
    Huffman::from_lengths(&lengths)
}
/// Build the distance table used by fixed-Huffman blocks: 32 symbols, each
/// with a 5-bit code.
///
/// # Errors
///
/// Propagates any error from `Huffman::from_lengths`.
pub fn fixed_dist() -> Result<Huffman, String> {
    let lengths = [5_u8; 32];
    Huffman::from_lengths(&lengths)
}
/// Check that growing an output of `current` bytes by `add` bytes stays
/// within `max`.
///
/// # Errors
///
/// Fails when `current + add` exceeds `max` or overflows `usize`.
pub fn ensure_room(current: usize, add: usize, max: usize) -> Result<(), String> {
    match current.checked_add(add) {
        Some(total) if total <= max => Ok(()),
        _ => Err(String::from(
            "Fix: reject decompression bomb; declared max_output_ratio would be exceeded.",
        )),
    }
}
/// Canonical Huffman decoding table for one DEFLATE alphabet.
#[derive(Clone)]
pub struct Huffman {
    /// One `(bit-reversed code, code length, symbol)` triple per symbol that
    /// has a non-zero code length.
    entries: Vec<(u16, u8, u16)>,
}
impl Huffman {
    /// Build a table from per-symbol code lengths; a length of zero marks the
    /// symbol at that index as unused.
    ///
    /// Codes are assigned canonically: shorter codes first and, within one
    /// length, in increasing symbol order. Each code is stored bit-reversed
    /// so it can be compared with bits accumulated in the order `decode`
    /// reads them.
    ///
    /// # Errors
    ///
    /// Fails when a length exceeds `MAX_BITS`, when the lengths are
    /// over-subscribed (they claim more codes than exist), when a symbol
    /// index does not fit `u16`, or when no symbol has a non-zero length.
    pub(crate) fn from_lengths(lengths: &[u8]) -> Result<Self, String> {
        let mut bl_count = [0_u32; MAX_BITS as usize + 1];
        for &len in lengths {
            if len > MAX_BITS {
                return Err("Fix: DEFLATE Huffman code length exceeds 15 bits.".into());
            }
            if len != 0 {
                let slot = &mut bl_count[usize::from(len)];
                *slot = slot.saturating_add(1);
            }
        }
        // Kraft check: every extra bit doubles the number of free codes and
        // every symbol of that length consumes one. Without it, hostile
        // lengths make the code arithmetic below overflow and yield a table
        // with wrapped or ambiguous codes.
        let mut unused = 1_u32;
        for &count in &bl_count[1..] {
            unused = (unused << 1).checked_sub(count).ok_or_else(|| {
                "Fix: reject over-subscribed DEFLATE Huffman code lengths.".to_string()
            })?;
        }
        // First canonical code of each length. After the check above every
        // assigned code is below 2^15, so this arithmetic cannot overflow.
        let mut code = 0_u32;
        let mut next_code = [0_u32; MAX_BITS as usize + 1];
        for bits in 1..=usize::from(MAX_BITS) {
            code = (code + bl_count[bits - 1]) << 1;
            next_code[bits] = code;
        }
        let mut entries = Vec::new();
        for (symbol, &len) in lengths.iter().enumerate() {
            if len == 0 {
                continue;
            }
            let slot = &mut next_code[usize::from(len)];
            let code = u16::try_from(*slot)
                .map_err(|error| format!("Fix: Huffman code must fit u16: {error}"))?;
            *slot += 1;
            let sym = u16::try_from(symbol)
                .map_err(|error| format!("Fix: Huffman symbol index must fit u16: {error}"))?;
            entries.push((reverse_bits(code, len), len, sym));
        }
        if entries.is_empty() {
            return Err("Fix: DEFLATE Huffman tree must contain at least one symbol.".into());
        }
        Ok(Self { entries })
    }

    /// Decode the next symbol by reading one bit at a time from `reader`.
    ///
    /// After each bit the accumulated value is compared with every entry of
    /// that exact length; the first match is returned.
    ///
    /// # Errors
    ///
    /// Fails when the input ends early or when no code matches within
    /// `MAX_BITS` bits.
    pub(crate) fn decode(&self, reader: &mut BitReader<'_>) -> Result<u16, String> {
        let mut code = 0_u16;
        for len in 1..=MAX_BITS {
            let bit = u16::try_from(reader.bits(1)?)
                .map_err(|error| format!("Fix: Huffman bit must fit u16: {error}"))?;
            // The n-th bit read lands at position n - 1, which matches the
            // bit-reversed codes stored in `entries`.
            code |= bit << (len - 1);
            for &(entry_code, entry_len, symbol) in &self.entries {
                if entry_len == len && entry_code == code {
                    return Ok(symbol);
                }
            }
        }
        Err("Fix: DEFLATE Huffman code does not match any symbol.".into())
    }
}
/// Reverse the order of the lowest `len` bits of `code`.
///
/// Bits are taken from the low end of `code` one at a time and shifted into
/// the result, so the first bit taken ends up most significant. Result bits
/// pushed beyond the 16-bit width (possible only when `len > 16`) are
/// discarded.
pub fn reverse_bits(code: u16, len: u8) -> u16 {
    let (reversed, _) = (0..len).fold((0_u16, code), |(acc, rest), _| {
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
/// Cursor over a borrowed byte slice that hands out bits starting with the
/// least significant bit of each byte.
pub struct BitReader<'a> {
    /// The complete input being read.
    input: &'a [u8],
    /// Index of the byte the next bit is taken from.
    byte: usize,
    /// Position (0..=7) of the next bit inside that byte.
    bit: u8,
}
impl<'a> BitReader<'a> {
    /// Create a reader positioned at the first bit of `input`.
    pub(crate) fn new(input: &'a [u8]) -> Self {
        Self {
            input,
            byte: 0,
            bit: 0,
        }
    }

    /// Read `count` bits; the first bit read becomes bit 0 of the result.
    ///
    /// # Errors
    ///
    /// Fails without consuming anything when `count` exceeds 32, because the
    /// result could not hold the bits. Fails when the input runs out; bits
    /// consumed before the end was reached stay consumed.
    pub(crate) fn bits(&mut self, count: u8) -> Result<u32, String> {
        if u32::from(count) > u32::BITS {
            return Err(
                "Fix: request at most 32 bits per DEFLATE bit read; the result is a u32.".into(),
            );
        }
        let mut value = 0_u32;
        for shift in 0..count {
            let Some(&byte) = self.input.get(self.byte) else {
                return Err(
                    "Fix: provide a complete DEFLATE stream; bit reader reached EOF.".into(),
                );
            };
            value |= u32::from((byte >> self.bit) & 1) << shift;
            self.bit += 1;
            if self.bit == 8 {
                self.bit = 0;
                self.byte += 1;
            }
        }
        Ok(value)
    }

    /// Skip any remaining bits of a partially consumed byte. Does nothing
    /// when the reader already sits on a byte boundary.
    pub(crate) fn align_byte(&mut self) {
        if self.bit != 0 {
            self.bit = 0;
            self.byte += 1;
        }
    }

    /// Align to a byte boundary and read a little-endian `u16`.
    ///
    /// # Errors
    ///
    /// Fails when fewer than two bytes remain.
    pub(crate) fn u16_le(&mut self) -> Result<u16, String> {
        self.align_byte();
        let bytes = self.bytes(2)?;
        Ok(u16::from_le_bytes([bytes[0], bytes[1]]))
    }

    /// Align to a byte boundary and borrow the next `count` bytes of input.
    ///
    /// # Errors
    ///
    /// Fails when the end position overflows `usize` or lies past the end of
    /// the input; the reader position is unchanged in both cases apart from
    /// the alignment.
    pub(crate) fn bytes(&mut self, count: usize) -> Result<&'a [u8], String> {
        self.align_byte();
        let end = self
            .byte
            .checked_add(count)
            .ok_or_else(|| "Fix: reject DEFLATE byte read with overflowing length.".to_string())?;
        let Some(bytes) = self.input.get(self.byte..end) else {
            return Err("Fix: provide a complete DEFLATE stream; byte reader reached EOF.".into());
        };
        self.byte = end;
        Ok(bytes)
    }
}