#![cfg_attr(docsrs, doc(cfg(feature = "zip_reduce")))]
extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
use crate::traits::{Algorithm, RawDecoder, RawEncoder, RawProgress};
#[derive(Debug, Clone, Copy, Default)]
pub struct ZipReduce;
impl Algorithm for ZipReduce {
const NAME: &'static str = "zip-reduce";
type Encoder = Encoder;
type Decoder = Decoder;
type EncoderConfig = ();
type DecoderConfig = ();
fn encoder_with(_: ()) -> Encoder {
Encoder::new()
}
fn decoder_with(_: ()) -> Decoder {
Decoder::new()
}
}
/// Encoder half of [`ZipReduce`].
///
/// Compression is not implemented: both `raw_encode` and `raw_finish` return
/// `Error::Unsupported`, and the type carries no state.
#[derive(Default, Debug)]
pub struct Encoder;

impl Encoder {
    /// Returns the (stateless) encoder.
    pub const fn new() -> Self {
        Encoder
    }
}

impl RawEncoder for Encoder {
    fn raw_encode(&mut self, _input: &[u8], _output: &mut [u8]) -> Result<RawProgress, Error> {
        Err(Error::Unsupported)
    }

    fn raw_finish(&mut self, _output: &mut [u8]) -> Result<RawProgress, Error> {
        Err(Error::Unsupported)
    }

    fn raw_reset(&mut self) {
        // Nothing to reset: the encoder holds no state.
    }
}
/// Bit width of a follower-set index for a set holding `n` followers.
///
/// For `1..=32` followers this is the number of bits needed to write the largest index
/// `n - 1`, with a floor of one bit, so sets of one or two followers both use a 1-bit index.
/// Empty sets (`n == 0`) and out-of-range sizes (`n > 32`) yield 0.
const fn follower_idx_bw(n: u8) -> u8 {
    if n == 0 || n > 32 {
        return 0;
    }
    // Significant bits in the largest index, `n - 1`.
    let needed = (u8::BITS - (n - 1).leading_zeros()) as u8;
    if needed == 0 {
        1
    } else {
        needed
    }
}
/// Follower table for one predecessor byte value.
#[derive(Clone, Copy, Debug)]
struct FollowerSet {
    /// Number of meaningful entries at the front of `followers` (at most the array length, 32).
    size: u8,
    /// Number of bits used to encode an index into `followers`.
    idx_bw: u8,
    /// Candidate successor bytes; only the first `size` entries are meaningful.
    followers: [u8; 32],
}

impl FollowerSet {
    /// A set with no followers, zero index width, and a zero-filled table.
    const fn empty() -> Self {
        FollowerSet {
            size: 0,
            idx_bw: 0,
            followers: [0; 32],
        }
    }
}
/// Bit cursor over a byte buffer that is passed in on every call.
///
/// Only the absolute bit offset is stored. A checkpoint is taken by copying the reader and
/// restored by assigning the copy back. Bits are taken from the least significant end of each
/// byte first, and multi-byte values are assembled little-endian.
#[derive(Clone, Copy, Debug)]
struct BitReader {
    /// Offset of the next unread bit, counted from the start of the buffer.
    bitpos: u64,
}

impl BitReader {
    /// A reader positioned at bit 0.
    const fn new() -> Self {
        BitReader { bitpos: 0 }
    }

    /// Moves the cursor back by `dropped_bytes` whole bytes after the caller removed that many
    /// bytes from the front of its buffer.
    fn rebase(&mut self, dropped_bytes: usize) {
        let dropped_bits = dropped_bytes as u64 * 8;
        self.bitpos -= dropped_bits;
    }

    /// Index of the byte that holds the next unread bit.
    fn byte_pos(&self) -> usize {
        (self.bitpos >> 3) as usize
    }

    /// Whether `buf` still holds at least `n` unread bits.
    fn has_bits(&self, buf: &[u8], n: u32) -> bool {
        let available = buf.len() as u64 * 8;
        self.bitpos.saturating_add(u64::from(n)) <= available
    }

    /// Returns the next `n` bits (`n <= 32`) without advancing the cursor.
    ///
    /// # Errors
    /// `Error::UnexpectedEnd` if fewer than `n` bits remain in `buf`.
    fn peek_bits(&self, buf: &[u8], n: u32) -> Result<u32, Error> {
        debug_assert!(n <= 32);
        if !self.has_bits(buf, n) {
            return Err(Error::UnexpectedEnd);
        }
        if n == 0 {
            return Ok(0);
        }
        let first = self.byte_pos();
        let skip = (self.bitpos & 7) as u32;
        // Whole bytes touched by the `n` requested bits, starting `skip` bits into `first`.
        let span = (n + skip).div_ceil(8) as usize;
        let window = buf[first..]
            .iter()
            .take(span)
            .enumerate()
            .fold(0u64, |acc, (i, &byte)| acc | (u64::from(byte) << (8 * i)));
        let mask = u64::MAX >> (64 - n);
        Ok(((window >> skip) & mask) as u32)
    }

    /// Like `peek_bits`, but also consumes the bits on success. On error the cursor is left
    /// unchanged.
    fn read_bits(&mut self, buf: &[u8], n: u32) -> Result<u32, Error> {
        let bits = self.peek_bits(buf, n)?;
        self.bitpos += u64::from(n);
        Ok(bits)
    }
}
/// A back-reference whose copy was interrupted by output back-pressure and must be resumed.
#[derive(Clone, Copy, Debug)]
struct PendingMatch {
    /// How far back from the current output position the copy reads (at least 1).
    dist: usize,
    /// Bytes of the match still to be produced.
    remaining: usize,
}
/// Stage of the decoder's state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Phase {
    /// Waiting for the 5-byte header (factor byte + little-endian `u32` length).
    Header,
    /// Reading the 256 follower sets, from byte value 255 down to 0.
    FollowerSets,
    /// Decoding literals and back-references.
    Body,
    /// Every declared byte has been produced and handed to the caller.
    Done,
    /// A fatal error occurred; `raw_decode`/`raw_finish` fail with `Error::Corrupt` until reset.
    Poison,
}
/// Streaming decoder for the `zip-reduce` format.
///
/// The stream is a 5-byte header (factor, little-endian uncompressed length), then 256
/// follower sets, then a bit-packed body of literals and DLE-introduced back-references.
/// Decoded bytes live in a sliding window (`out`) and are copied to the caller's buffer as
/// room allows.
///
/// `Debug` is derived so the public type can be inspected like `Encoder` and `ZipReduce`.
#[derive(Debug)]
pub struct Decoder {
    /// Current state-machine stage.
    phase: Phase,
    /// Compression factor from the header (1..=4); sets the distance/length bit split.
    factor: u8,
    /// Declared uncompressed length from the header.
    uncomp_len: u32,
    /// Buffered input not yet discarded; `bits` is an offset into this buffer.
    input_buf: Vec<u8>,
    /// Bit cursor into `input_buf`.
    bits: BitReader,
    /// Next follower set to read (counts down from 255; negative once all are read).
    next_fset: i16,
    /// One follower set per predecessor byte value (256 entries once the header is parsed).
    fsets: Vec<FollowerSet>,
    /// Retained tail of the decoded output; `out[0]` is absolute output position `window_base`.
    out: Vec<u8>,
    /// Absolute output position of `out[0]` (bytes already slid out of the window).
    window_base: usize,
    /// Absolute output position of the next byte to copy to the caller.
    emit_cursor: usize,
    /// Most recently decoded byte; selects the follower set for the next byte.
    prev_byte: u8,
    /// Back-reference interrupted by back-pressure, resumed first by `decode_body`.
    pending: Option<PendingMatch>,
}
impl Default for Decoder {
fn default() -> Self {
Self::new()
}
}
impl Decoder {
/// Creates a decoder positioned before the 5-byte stream header.
///
/// The follower tables and the output window are allocated later, by `parse_header`.
pub fn new() -> Self {
    Self {
        phase: Phase::Header,
        pending: None,
        input_buf: Vec::new(),
        bits: BitReader::new(),
        fsets: Vec::new(),
        next_fset: 255,
        factor: 0,
        uncomp_len: 0,
        out: Vec::new(),
        window_base: 0,
        emit_cursor: 0,
        prev_byte: 0,
    }
}
/// Parses the 5-byte stream header: the compression factor (1..=4) followed by the
/// uncompressed length as a little-endian `u32`.
///
/// On success the header bytes are removed from `input_buf`, the bit reader, window and
/// follower tables are (re)initialised, and the decoder moves to `Phase::FollowerSets`.
///
/// # Errors
/// `Error::UnexpectedEnd` if fewer than 5 bytes are buffered (nothing is consumed), and
/// `Error::BadHeader` if the factor is outside 1..=4.
fn parse_header(&mut self) -> Result<(), Error> {
    let Some(&[factor, l0, l1, l2, l3]) = self.input_buf.get(..5) else {
        return Err(Error::UnexpectedEnd);
    };
    if !(1..=4).contains(&factor) {
        return Err(Error::BadHeader);
    }
    let ucl = u32::from_le_bytes([l0, l1, l2, l3]);
    self.factor = factor;
    self.uncomp_len = ucl;
    self.input_buf.drain(..5);
    self.bits = BitReader::new();
    // `out` is a sliding window: `slide_window` keeps about 4 * max_dist bytes of history and
    // `decode_body` buffers at most about 4 * max_dist bytes ahead of the caller. Pre-size it to
    // that working set instead of reserving up to 1 MiB based on the (untrusted) declared
    // length. `Vec` still grows on demand if needed.
    let max_dist = 256usize << factor;
    let cap = (ucl as usize).min(max_dist * 8);
    self.out = Vec::with_capacity(cap);
    self.window_base = 0;
    self.fsets = vec![FollowerSet::empty(); 256];
    self.next_fset = 255;
    self.phase = Phase::FollowerSets;
    Ok(())
}
/// Reads the 256 follower sets that precede the body, starting with byte value 255.
///
/// Each set is a 6-bit count (at most 32) followed by that many 8-bit follower bytes.
/// Progress is kept in `next_fset`, so the call can be repeated as more input arrives. A set
/// cut off by the end of the buffered input is rewound and re-read in full next time.
/// Afterwards the decoder moves to `Phase::Body`.
///
/// # Errors
/// `Error::UnexpectedEnd` when input runs out (resumable), and `Error::Corrupt` for a count
/// greater than 32.
fn read_follower_sets(&mut self) -> Result<(), Error> {
    while self.next_fset >= 0 {
        let idx = self.next_fset as usize;
        let checkpoint = self.bits;
        let count = self.bits.read_bits(&self.input_buf, 6)? as u8;
        if count > 32 {
            return Err(Error::Corrupt);
        }
        let set = &mut self.fsets[idx];
        set.size = count;
        set.idx_bw = follower_idx_bw(count);
        for slot in set.followers[..usize::from(count)].iter_mut() {
            *slot = match self.bits.read_bits(&self.input_buf, 8) {
                Ok(byte) => byte as u8,
                Err(Error::UnexpectedEnd) => {
                    // Rewind to the start of this set so it is re-read whole later.
                    self.bits = checkpoint;
                    return Err(Error::UnexpectedEnd);
                }
                Err(other) => return Err(other),
            };
        }
        self.next_fset -= 1;
    }
    self.phase = Phase::Body;
    Ok(())
}
/// Decodes one byte using the follower set of the previous byte (`prev_byte`).
///
/// With an empty set the byte is an 8-bit literal. Otherwise a 1-bit flag selects either an
/// 8-bit literal (flag 1) or an `idx_bw`-bit index into the follower table (flag 0). This
/// method does not update `prev_byte`.
///
/// # Errors
/// Any bit-reader error (e.g. `Error::UnexpectedEnd`) after rewinding the reader to where
/// this byte started. `Error::Corrupt` if the index is not below the set's size; in that case
/// the reader is left past the index.
fn read_next_byte(&mut self) -> Result<u8, Error> {
    let set = self.fsets[usize::from(self.prev_byte)];
    let rollback = self.bits;
    let literal = set.size == 0
        || match self.bits.read_bits(&self.input_buf, 1) {
            Ok(flag) => flag == 1,
            Err(e) => {
                self.bits = rollback;
                return Err(e);
            }
        };
    let width = if literal { 8 } else { u32::from(set.idx_bw) };
    let value = match self.bits.read_bits(&self.input_buf, width) {
        Ok(v) => v,
        Err(e) => {
            self.bits = rollback;
            return Err(e);
        }
    };
    if literal {
        return Ok(value as u8);
    }
    let slot = value as usize;
    if slot < usize::from(set.size) {
        Ok(set.followers[slot])
    } else {
        Err(Error::Corrupt)
    }
}
fn decode_body(&mut self, output: &mut [u8], written: &mut usize) -> Result<(), Error> {
let max_dist = ((1usize << self.factor) - 1) * 256 + 255 + 1;
let buffer_ahead = max_dist * 4;
if let Some(mut pm) = self.pending.take() {
while pm.remaining > 0 {
self.flush_emit(output, written);
self.slide_window(max_dist);
if self.produced() - self.emit_cursor >= buffer_ahead && *written >= output.len() {
self.pending = Some(pm);
return Ok(());
}
let pos = self.produced();
let b = if pm.dist > pos {
0u8
} else {
self.out[(pos - pm.dist) - self.window_base]
};
self.out.push(b);
pm.remaining -= 1;
if (self.produced() as u32) >= self.uncomp_len && pm.remaining > 0 {
return Err(Error::Corrupt);
}
}
}
let v_len_bits: u32 = (8 - self.factor) as u32;
let len_mask: u32 = (1u32 << v_len_bits) - 1;
while (self.produced() as u32) < self.uncomp_len {
self.flush_emit(output, written);
self.slide_window(max_dist);
if self.produced() - self.emit_cursor >= buffer_ahead && *written >= output.len() {
return Ok(());
}
let saved_bits = self.bits;
let saved_prev = self.prev_byte;
let cur = match self.read_next_byte() {
Ok(b) => b,
Err(Error::UnexpectedEnd) => {
self.bits = saved_bits;
return Ok(());
}
Err(e) => return Err(e),
};
self.prev_byte = cur;
if cur != DLE_BYTE {
self.out.push(cur);
continue;
}
let v = match self.read_next_byte() {
Ok(b) => b,
Err(Error::UnexpectedEnd) => {
self.bits = saved_bits;
self.prev_byte = saved_prev;
return Ok(());
}
Err(e) => return Err(e),
};
self.prev_byte = v;
if v == 0 {
self.out.push(DLE_BYTE);
continue;
}
let mut len = (v as u32 & len_mask) as usize;
if (len as u32) == len_mask {
let elb = match self.read_next_byte() {
Ok(b) => b,
Err(Error::UnexpectedEnd) => {
self.bits = saved_bits;
self.prev_byte = saved_prev;
return Ok(());
}
Err(e) => return Err(e),
};
self.prev_byte = elb;
len += elb as usize;
}
len += 3;
let w = match self.read_next_byte() {
Ok(b) => b,
Err(Error::UnexpectedEnd) => {
self.bits = saved_bits;
self.prev_byte = saved_prev;
return Ok(());
}
Err(e) => return Err(e),
};
self.prev_byte = w;
let dist_hi = (v as usize) >> v_len_bits;
let dist = dist_hi * 256 + w as usize + 1;
let remaining_out = (self.uncomp_len as usize) - self.produced();
if len > remaining_out {
return Err(Error::Corrupt);
}
let mut pm = PendingMatch {
dist,
remaining: len,
};
while pm.remaining > 0 {
let pos = self.produced();
let b = if pm.dist > pos {
0u8
} else {
self.out[(pos - pm.dist) - self.window_base]
};
self.out.push(b);
pm.remaining -= 1;
}
}
Ok(())
}
/// Total decoded bytes so far: those already slid out of the window (`window_base`) plus
/// those still held in `out`.
fn produced(&self) -> usize {
    let retained = self.out.len();
    retained + self.window_base
}
/// Drops window bytes that have already been emitted and are too far back to be referenced.
///
/// At least `keep = max(4 * max_dist, max_dist + 1)` trailing bytes are retained, which covers
/// the longest back-reference. Nothing at or past `emit_cursor` is ever dropped.
///
/// This is called as often as once per decoded byte, and draining the front of a `Vec` shifts
/// every retained byte. The window is therefore only slid once at least `max_dist` bytes have
/// become droppable. Sliding on every call would cost O(window) per output byte.
fn slide_window(&mut self, max_dist: usize) {
    let keep = max_dist.saturating_mul(4).max(max_dist + 1);
    let end = self.produced();
    let drop_limit = self.emit_cursor.min(end.saturating_sub(keep));
    let droppable = drop_limit.saturating_sub(self.window_base);
    if droppable == 0 || droppable < max_dist {
        return;
    }
    self.out.drain(..droppable);
    self.window_base += droppable;
}
/// Copies as many not-yet-emitted window bytes into `output[*written..]` as fit, and advances
/// both `*written` and `emit_cursor` by the number copied.
fn flush_emit(&mut self, output: &mut [u8], written: &mut usize) {
    let backlog = self.produced().saturating_sub(self.emit_cursor);
    let room = output.len().saturating_sub(*written);
    let count = backlog.min(room);
    if count == 0 {
        return;
    }
    let from = self.emit_cursor - self.window_base;
    let to = *written;
    output[to..to + count].copy_from_slice(&self.out[from..from + count]);
    *written += count;
    self.emit_cursor += count;
}
/// Discards input bytes the bit reader has fully consumed and rebases the reader.
///
/// Draining the front of `input_buf` moves every unconsumed byte. Compaction therefore only
/// happens once the consumed prefix is at least as long as the unconsumed tail, which keeps
/// the copying amortised O(1) per input byte. Without this, a caller that hands over a large
/// input in one `raw_decode` call and then drains it through small output buffers would
/// trigger a full memmove of the remaining input on every call.
fn compact_input(&mut self) {
    let consumed = self.bits.byte_pos();
    let unconsumed = self.input_buf.len().saturating_sub(consumed);
    if consumed == 0 || consumed < unconsumed {
        return;
    }
    self.input_buf.drain(..consumed);
    self.bits.rebase(consumed);
}
}
/// Escape byte (0x90) that introduces a back-reference in the body; `DLE, 0` encodes a literal
/// DLE byte.
const DLE_BYTE: u8 = b'\x90';
impl RawDecoder for Decoder {
fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
if matches!(self.phase, Phase::Poison) {
return Err(Error::Corrupt);
}
self.input_buf.extend_from_slice(input);
let mut written = 0usize;
self.flush_emit(output, &mut written);
if written == output.len() {
return Ok(RawProgress {
consumed: input.len(),
written,
done: false,
});
}
if matches!(self.phase, Phase::Header) {
match self.parse_header() {
Ok(()) => {}
Err(Error::UnexpectedEnd) => {
return Ok(RawProgress {
consumed: input.len(),
written,
done: false,
});
}
Err(e) => {
self.phase = Phase::Poison;
return Err(e);
}
}
}
if matches!(self.phase, Phase::FollowerSets) {
match self.read_follower_sets() {
Ok(()) => {}
Err(Error::UnexpectedEnd) => {
self.compact_input();
return Ok(RawProgress {
consumed: input.len(),
written,
done: false,
});
}
Err(e) => {
self.phase = Phase::Poison;
return Err(e);
}
}
}
if (self.produced() as u32) >= self.uncomp_len && matches!(self.phase, Phase::Body) {
self.phase = Phase::Done;
}
if matches!(self.phase, Phase::Body) {
match self.decode_body(output, &mut written) {
Ok(()) => {}
Err(e) => {
self.phase = Phase::Poison;
return Err(e);
}
}
self.flush_emit(output, &mut written);
if (self.produced() as u32) >= self.uncomp_len && self.emit_cursor == self.produced() {
self.phase = Phase::Done;
}
}
self.compact_input();
Ok(RawProgress {
consumed: input.len(),
written,
done: matches!(self.phase, Phase::Done),
})
}
/// Signals that no more input will arrive and drains whatever can still be decoded.
///
/// An earlier `raw_decode` may have buffered the header or follower sets without parsing
/// them, because it returns early when `output` is already full (e.g. empty). Those stages
/// are therefore run here too, instead of reporting such streams as truncated.
///
/// # Errors
/// `Error::Corrupt` if the decoder is poisoned. Any header, follower-set or body error,
/// including `Error::UnexpectedEnd` while parsing the header or follower sets, since no more
/// input is coming. `Error::UnexpectedEnd` if nothing was written, nothing is buffered, and the
/// stream is not finished. All errors poison the decoder.
fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
    if matches!(self.phase, Phase::Poison) {
        return Err(Error::Corrupt);
    }
    let mut written = 0usize;
    self.flush_emit(output, &mut written);
    if matches!(self.phase, Phase::Header) {
        if let Err(e) = self.parse_header() {
            self.phase = Phase::Poison;
            return Err(e);
        }
    }
    if matches!(self.phase, Phase::FollowerSets) {
        if let Err(e) = self.read_follower_sets() {
            self.phase = Phase::Poison;
            return Err(e);
        }
    }
    if matches!(self.phase, Phase::Body) {
        if let Err(e) = self.decode_body(output, &mut written) {
            self.phase = Phase::Poison;
            return Err(e);
        }
        self.flush_emit(output, &mut written);
        if (self.produced() as u32) >= self.uncomp_len && self.emit_cursor == self.produced() {
            self.phase = Phase::Done;
        }
    }
    let done = matches!(self.phase, Phase::Done);
    // Nothing emitted, nothing left buffered, and not finished: the stream was cut short.
    if !done && written == 0 && self.emit_cursor == self.produced() {
        self.phase = Phase::Poison;
        return Err(Error::UnexpectedEnd);
    }
    Ok(RawProgress {
        consumed: 0,
        written,
        done,
    })
}
/// Returns the decoder to its freshly constructed state, keeping buffer allocations.
fn raw_reset(&mut self) {
    // Discard buffered data; `clear` keeps the capacity for reuse.
    self.input_buf.clear();
    self.out.clear();
    self.fsets.clear();
    // Rewind every cursor and piece of decoding state.
    self.bits = BitReader::new();
    self.window_base = 0;
    self.emit_cursor = 0;
    self.next_fset = 255;
    self.prev_byte = 0;
    self.pending = None;
    self.factor = 0;
    self.uncomp_len = 0;
    self.phase = Phase::Header;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::RawDecoder;

    /// `follower_idx_bw` must reproduce the reference index-width table.
    #[test]
    fn follower_idx_bw_matches_reference() {
        let table: [(u8, u8); 11] = [
            (0, 0),
            (1, 1),
            (2, 1),
            (3, 2),
            (4, 2),
            (5, 3),
            (8, 3),
            (9, 4),
            (16, 4),
            (17, 5),
            (32, 5),
        ];
        for (count, width) in table {
            assert_eq!(follower_idx_bw(count), width, "follower_idx_bw({count})");
        }
    }

    /// Bits come out of each byte starting at the least significant end.
    #[test]
    fn bit_reader_reads_lsb_first() {
        let buf = [0b1011_0011u8, 0b1111_0000u8];
        let mut reader = BitReader::new();
        for expected in [0x3, 0xB, 0x0, 0xF] {
            assert_eq!(reader.read_bits(&buf, 4).unwrap(), expected);
        }
    }

    #[allow(dead_code)]
    mod fixtures {
        include!("../../tests/zip_reduce_fixtures.in");
        pub(super) const ABC_REPEATED_R4_PUB: &[u8] = ABC_REPEATED_R4;
    }

    /// Draining a factor-4 stream through a tiny output buffer must keep the retained window
    /// bounded while still producing every byte.
    #[test]
    fn sliding_window_bounds_retained_output() {
        let input = fixtures::ABC_REPEATED_R4_PUB;
        let max_dist = ((1usize << 4) - 1) * 256 + 255 + 1;
        let window_cap = max_dist * 8 + 4096;
        let mut dec = Decoder::new();
        let mut buf = [0u8; 7];
        let mut consumed = 0usize;
        let mut total = 0usize;
        loop {
            let progress = dec.raw_decode(&input[consumed..], &mut buf).unwrap();
            consumed += progress.consumed;
            total += progress.written;
            assert!(
                dec.out.len() <= window_cap,
                "retained window {} exceeded bound {} (OOM regression)",
                dec.out.len(),
                window_cap
            );
            if progress.done {
                break;
            }
            let stalled = progress.consumed == 0 && progress.written == 0;
            if stalled && consumed >= input.len() {
                break;
            }
        }
        assert_eq!(total, 66000, "decoded length mismatch");
    }
}