use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::compression::deflate_core;
use crate::ops::{AlgebraicLaw, OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};
/// Algebraic laws attached to [`ZlibDecompress::SPEC`]: a single `Bounded`
/// law with `lo = 0` and `hi = 255` (presumably the value range of every
/// output byte — confirm against the `AlgebraicLaw` documentation).
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded { hi: 255, lo: 0 }];
/// Maximum decompressed-to-compressed size ratio (1024).
///
/// `decompress_bytes` hands `input.len() * MAX_OUTPUT_RATIO` to the DEFLATE
/// decoder as its output limit, bounding how far a small input can expand.
pub const MAX_OUTPUT_RATIO: usize = 1 << 10;
#[derive(Debug, Clone, Copy, Default)]
pub struct ZlibDecompress;
impl ZlibDecompress {
pub const SPEC: OpSpec = OpSpec::composition(
"compression.zlib_decompress",
BYTES_TO_BYTES_INPUTS,
BYTES_TO_BYTES_OUTPUTS,
LAWS,
Self::program,
);
#[must_use]
pub fn program() -> Program {
Program::new(
vec![
BufferDecl::read("input", 0, DataType::Bytes),
BufferDecl::output("out", 1, DataType::Bytes),
],
[1, 1, 1],
vec![Node::if_then(
Expr::gt(
Expr::buf_len("out"),
Expr::mul(Expr::buf_len("input"), Expr::u32(1024)),
),
vec![Node::Return],
)],
)
}
}
/// Decompresses a complete, in-memory zlib (RFC 1950) stream.
///
/// The frame is checked in this order:
/// 1. it is at least 6 bytes long;
/// 2. the CMF method nibble is 8 (DEFLATE);
/// 3. CINFO is at most 7;
/// 4. the FCHECK bits make `CMF * 256 + FLG` divisible by 31;
/// 5. the FDICT flag is clear.
///
/// The bytes between the 2-byte header and the 4-byte trailer are then
/// inflated by `deflate_core::decompress` with an output limit of
/// `input.len() * MAX_OUTPUT_RATIO`. Finally, the big-endian Adler-32 trailer
/// is compared against the inflated bytes.
///
/// # Errors
///
/// Returns a `"Fix: …"` message when any of the header checks above fails,
/// or when the output limit overflows `usize`, or when the Adler-32 trailer
/// does not match. Errors from `deflate_core::decompress` are propagated
/// unchanged via `?`.
pub fn decompress_bytes(input: &[u8]) -> Result<Vec<u8>, String> {
    // Two header bytes plus a four-byte trailer is the least we can slice apart.
    if input.len() < 6 {
        return Err(String::from(
            "Fix: provide a complete zlib stream with header and Adler-32 trailer.",
        ));
    }
    let (header, body) = input.split_at(2);
    let (payload, trailer) = body.split_at(body.len() - 4);
    let (cmf, flg) = (header[0], header[1]);

    // CMF: low nibble = compression method, high nibble = log2(window) - 8.
    let (method, window_info) = (cmf & 0x0f, cmf >> 4);
    if method != 8 {
        return Err(String::from(
            "Fix: zlib CM compression method must be DEFLATE (8).",
        ));
    }
    if window_info > 7 {
        return Err(String::from(
            "Fix: zlib CINFO window size must be at most 32 KiB.",
        ));
    }
    // FCHECK is chosen so the big-endian 16-bit word CMF:FLG is a multiple of 31.
    if u16::from_be_bytes([cmf, flg]) % 31 != 0 {
        return Err(String::from("Fix: repair zlib FCHECK header checksum."));
    }
    // FDICT (bit 5 of FLG) would require a preset dictionary.
    if flg & 0x20 != 0 {
        return Err(String::from(
            "Fix: zlib preset dictionaries are not accepted by this operation.",
        ));
    }

    let max_output = input
        .len()
        .checked_mul(MAX_OUTPUT_RATIO)
        .ok_or_else(|| {
            String::from("Fix: reject zlib input whose max_output_ratio multiplication overflows.")
        })?;
    // NOTE(review): assumes the DEFLATE stream fills `payload` exactly; whether
    // stray bytes before the trailer are rejected depends on
    // `deflate_core::decompress` — confirm.
    let output = deflate_core::decompress(payload, max_output)?;

    let stored = u32::from_be_bytes([trailer[0], trailer[1], trailer[2], trailer[3]]);
    if adler32(&output) == stored {
        Ok(output)
    } else {
        Err(String::from(
            "Fix: zlib Adler-32 trailer does not match decompressed bytes.",
        ))
    }
}
/// Computes the Adler-32 checksum (RFC 1950, section 8.2) of `bytes`.
///
/// Returns `(B << 16) | A`, where `A` is 1 plus the sum of all bytes and `B`
/// is the sum of the successive `A` values, both taken modulo 65 521.
/// The checksum of an empty slice is `1`.
///
/// Rather than reducing after every byte (two integer divisions per byte),
/// the running sums are reduced once per run of at most `NMAX` bytes. This is
/// the same bound zlib uses: the longest run for which `b` provably cannot
/// overflow a `u32` before reduction. The result is identical to per-byte
/// reduction.
pub fn adler32(bytes: &[u8]) -> u32 {
    const MOD: u32 = 65_521;
    // Largest n with 255 * n * (n + 1) / 2 + (n + 1) * (MOD - 1) <= u32::MAX,
    // i.e. `b` stays in range even when `a` and `b` start at MOD - 1 and every
    // byte is 0xFF.
    const NMAX: usize = 5_552;

    let mut a = 1_u32;
    let mut b = 0_u32;
    for run in bytes.chunks(NMAX) {
        for &byte in run {
            a += u32::from(byte);
            b += a;
        }
        a %= MOD;
        b %= MOD;
    }
    (b << 16) | a
}