extern crate alloc;
use alloc::vec::Vec;
use crate::error::Error;
use crate::traits::{RawDecoder, RawProgress};
use super::{CHUNK_SIZE, split_for_pos};
/// Largest chunk body a header can describe: the 12-bit size field stores `len - 1`.
const MAX_CHUNK_BODY: usize = 0x0FFF + 1;
/// Where the decoder currently is within the chunk stream.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Phase {
    /// Waiting for the low byte of the next 16-bit little-endian chunk header.
    HeaderLow,
    /// The low header byte `hdr_lo` has been read; waiting for the high byte.
    HeaderHigh { hdr_lo: u8 },
    /// Collecting a chunk body.
    Body {
        /// Whether the header's compressed flag (bit 15) was set.
        compressed: bool,
        /// Body bytes still to be read from the input.
        body_remaining: u16,
    },
    /// A decoded chunk is waiting to be copied out to the caller.
    Draining,
    /// The zero-header end-of-stream marker has been read.
    Done,
}
/// Incremental decoder for a stream of length-prefixed chunks.
///
/// Each chunk body is collected into `chunk_buf`, decoded as a whole into
/// `out_buf`, and then copied to the caller's output, possibly across several calls.
pub struct Decoder {
    /// Current state of the chunk state machine.
    phase: Phase,
    /// Set after a decoding error; later calls fail until the decoder is reset.
    poisoned: bool,
    /// Raw bytes of the chunk body currently being collected.
    chunk_buf: Vec<u8>,
    /// Decoded bytes of the most recently completed chunk.
    out_buf: Vec<u8>,
    /// How many bytes of `out_buf` have already been copied to the caller.
    out_idx: usize,
}
impl Decoder {
pub fn new() -> Self {
Self {
chunk_buf: Vec::with_capacity(MAX_CHUNK_BODY),
out_buf: Vec::with_capacity(MAX_CHUNK_BODY),
out_idx: 0,
phase: Phase::HeaderLow,
poisoned: false,
}
}
fn drain(&mut self, output: &mut [u8], written: &mut usize) {
let avail = self.out_buf.len() - self.out_idx;
let room = output.len() - *written;
let n = avail.min(room);
if n > 0 {
output[*written..*written + n]
.copy_from_slice(&self.out_buf[self.out_idx..self.out_idx + n]);
self.out_idx += n;
*written += n;
}
if self.out_idx == self.out_buf.len() {
self.out_buf.clear();
self.out_idx = 0;
self.phase = Phase::HeaderLow;
}
}
fn decode_compressed_chunk(&mut self) -> Result<(), Error> {
let body = &self.chunk_buf[..];
let mut i = 0usize;
let mut out: Vec<u8> = Vec::with_capacity(CHUNK_SIZE);
while i < body.len() {
let flag = body[i];
i += 1;
for bit in 0..8 {
if i >= body.len() {
break;
}
let is_match = (flag >> bit) & 1 != 0;
if !is_match {
out.push(body[i]);
i += 1;
if out.len() > CHUNK_SIZE {
return Err(Error::Corrupt);
}
} else {
if i + 2 > body.len() {
return Err(Error::Corrupt);
}
let token = u16::from_le_bytes([body[i], body[i + 1]]);
i += 2;
let pos = out.len();
if pos == 0 {
return Err(Error::Corrupt);
}
let (_off_bits, length_bits) = split_for_pos(pos);
let length_mask: u16 = (1u16 << length_bits) - 1;
let length = ((token & length_mask) as usize) + 3;
let offset = ((token >> length_bits) as usize) + 1;
if offset > pos {
return Err(Error::InvalidDistance);
}
if out.len() + length > CHUNK_SIZE {
return Err(Error::Corrupt);
}
let src_start = pos - offset;
for k in 0..length {
let b = out[src_start + k];
out.push(b);
}
}
}
}
self.out_buf = out;
self.out_idx = 0;
Ok(())
}
fn decode_uncompressed_chunk(&mut self) -> Result<(), Error> {
self.out_buf.clear();
self.out_buf.extend_from_slice(&self.chunk_buf);
self.out_idx = 0;
Ok(())
}
}
impl Default for Decoder {
fn default() -> Self {
Self::new()
}
}
impl RawDecoder for Decoder {
    /// Runs the chunk state machine over `input`, writing decoded bytes to `output`.
    ///
    /// Returns in three cases:
    /// - the input is exhausted;
    /// - `output` fills while decoded bytes are still pending (`done: false`);
    /// - the zero-header end-of-stream marker has been read (`done: true`).
    ///
    /// After the end marker, further calls consume nothing and keep reporting
    /// `done: true`.
    ///
    /// # Errors
    /// - `Error::Corrupt` if an earlier call failed and the decoder was not reset.
    /// - `Error::BadHeader` for a chunk header with the wrong signature or body size.
    /// - Any error from decoding a chunk body (e.g. `Error::Corrupt`,
    ///   `Error::InvalidDistance`).
    ///
    /// Except for the first case, an error poisons the decoder until `raw_reset`.
    fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.poisoned {
            return Err(Error::Corrupt);
        }
        let mut consumed = 0usize;
        let mut written = 0usize;
        loop {
            match self.phase {
                Phase::Draining => {
                    self.drain(output, &mut written);
                    if self.phase == Phase::Draining {
                        // `output` is full; the rest of this chunk goes out on a later call.
                        return Ok(RawProgress {
                            consumed,
                            written,
                            done: false,
                        });
                    }
                }
                Phase::Done => {
                    return Ok(RawProgress {
                        consumed,
                        written,
                        done: true,
                    });
                }
                // Every remaining phase needs at least one more input byte.
                _ if consumed >= input.len() => {
                    return Ok(RawProgress {
                        consumed,
                        written,
                        done: false,
                    });
                }
                Phase::HeaderLow => {
                    self.phase = Phase::HeaderHigh {
                        hdr_lo: input[consumed],
                    };
                    consumed += 1;
                }
                Phase::HeaderHigh { hdr_lo } => {
                    let header = u16::from_le_bytes([hdr_lo, input[consumed]]);
                    consumed += 1;
                    if header == 0 {
                        // End-of-stream marker; the `Done` arm reports it.
                        self.phase = Phase::Done;
                        continue;
                    }
                    // Bit 15: compressed flag; bits 12-14: signature; bits 0-11: size - 1.
                    let compressed = (header & 0x8000) != 0;
                    let signature = (header >> 12) & 0x7;
                    let body_size = usize::from(header & 0x0FFF) + 1;
                    if signature != 0b011 || body_size > MAX_CHUNK_BODY {
                        self.poisoned = true;
                        return Err(Error::BadHeader);
                    }
                    self.chunk_buf.clear();
                    self.phase = Phase::Body {
                        compressed,
                        // body_size <= 0x1000, so the narrowing cast is lossless.
                        body_remaining: body_size as u16,
                    };
                }
                Phase::Body {
                    compressed,
                    body_remaining,
                } => {
                    let take = (input.len() - consumed).min(usize::from(body_remaining));
                    self.chunk_buf
                        .extend_from_slice(&input[consumed..consumed + take]);
                    consumed += take;
                    // take <= body_remaining, so this neither truncates nor underflows.
                    let body_remaining = body_remaining - take as u16;
                    if body_remaining > 0 {
                        self.phase = Phase::Body {
                            compressed,
                            body_remaining,
                        };
                        continue;
                    }

                    let decoded = if compressed {
                        self.decode_compressed_chunk()
                    } else {
                        self.decode_uncompressed_chunk()
                    };
                    if let Err(e) = decoded {
                        self.poisoned = true;
                        return Err(e);
                    }
                    // A chunk that decodes to nothing goes straight to the next header.
                    self.phase = if self.out_buf.is_empty() {
                        Phase::HeaderLow
                    } else {
                        Phase::Draining
                    };
                }
            }
        }
    }

    /// Signals end of input and flushes any pending decoded bytes into `output`.
    ///
    /// Reports `done: true` when the stream stopped on a chunk boundary or after
    /// the end marker. Reports `done: false` when `output` filled before the
    /// pending chunk was fully delivered; call again with more room in that case.
    ///
    /// # Errors
    /// - `Error::Corrupt` if an earlier call failed and the decoder was not reset.
    /// - `Error::UnexpectedEnd` if input stopped inside a chunk header or body.
    ///   This also poisons the decoder.
    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.poisoned {
            return Err(Error::Corrupt);
        }
        let mut written = 0usize;
        if self.phase == Phase::Draining {
            self.drain(output, &mut written);
        }
        match self.phase {
            Phase::HeaderLow | Phase::Done => Ok(RawProgress {
                consumed: 0,
                written,
                done: true,
            }),
            // `output` filled up before the pending chunk was fully delivered.
            Phase::Draining => Ok(RawProgress {
                consumed: 0,
                written,
                done: false,
            }),
            Phase::HeaderHigh { .. } | Phase::Body { .. } => {
                self.poisoned = true;
                Err(Error::UnexpectedEnd)
            }
        }
    }

    /// Returns the decoder to its initial state.
    ///
    /// Discards buffered chunk and output data and clears the poisoned flag.
    fn raw_reset(&mut self) {
        self.chunk_buf.clear();
        self.out_buf.clear();
        self.out_idx = 0;
        self.phase = Phase::HeaderLow;
        self.poisoned = false;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Uncompressed chunk carrying `abc`.
    ///
    /// Header 0x3002 = signature 0b011, compressed bit clear, size field 3 - 1.
    /// It is followed by a zero end-of-stream header.
    const RAW_ABC_THEN_END: [u8; 7] = [0x02, 0x30, b'a', b'b', b'c', 0x00, 0x00];

    #[test]
    fn empty_stream_finishes_cleanly() {
        let mut dec = Decoder::new();
        let mut out = [0u8; 16];
        let p = dec.raw_finish(&mut out).unwrap();
        assert!(p.done);
        assert_eq!(p.written, 0);
    }

    #[test]
    fn zero_terminator_ends_cleanly() {
        let mut dec = Decoder::new();
        let mut out = [0u8; 16];
        let p = dec.raw_decode(&[0u8, 0u8], &mut out).unwrap();
        assert!(p.done);
        assert_eq!(p.written, 0);
    }

    #[test]
    fn uncompressed_chunk_is_copied_verbatim() {
        let mut dec = Decoder::new();
        let mut out = [0u8; 16];
        let p = dec.raw_decode(&RAW_ABC_THEN_END, &mut out).unwrap();
        assert!(p.done);
        assert_eq!(p.consumed, RAW_ABC_THEN_END.len());
        assert_eq!(&out[..p.written], &b"abc"[..]);
    }

    #[test]
    fn compressed_literal_only_chunk() {
        // Header 0xB003: compressed bit, signature 0b011, 4-byte body.
        // Flag byte 0x00 marks the following three bytes as literals.
        let input = [0x03, 0xB0, 0x00, b'x', b'y', b'z', 0x00, 0x00];
        let mut dec = Decoder::new();
        let mut out = [0u8; 16];
        let p = dec.raw_decode(&input, &mut out).unwrap();
        assert!(p.done);
        assert_eq!(&out[..p.written], &b"xyz"[..]);
    }

    #[test]
    fn small_output_resumes_draining() {
        let mut dec = Decoder::new();
        let mut out = [0u8; 2];

        let p = dec.raw_decode(&RAW_ABC_THEN_END, &mut out).unwrap();
        assert!(!p.done);
        assert_eq!(p.consumed, 5);
        assert_eq!(&out[..p.written], &b"ab"[..]);

        let p = dec.raw_decode(&RAW_ABC_THEN_END[5..], &mut out).unwrap();
        assert!(p.done);
        assert_eq!(p.consumed, 2);
        assert_eq!(&out[..p.written], &b"c"[..]);
    }

    #[test]
    fn bad_signature_poisons_until_reset() {
        let mut dec = Decoder::new();
        let mut out = [0u8; 4];
        // Header 0x0001 is non-zero but its signature bits are 0b000.
        assert!(matches!(
            dec.raw_decode(&[0x01, 0x00], &mut out),
            Err(Error::BadHeader)
        ));
        assert!(matches!(
            dec.raw_decode(&[0x00, 0x00], &mut out),
            Err(Error::Corrupt)
        ));
        dec.raw_reset();
        assert!(dec.raw_decode(&[0x00, 0x00], &mut out).unwrap().done);
    }

    #[test]
    fn truncated_body_is_unexpected_end() {
        let mut dec = Decoder::new();
        let mut out = [0u8; 4];
        let p = dec.raw_decode(&RAW_ABC_THEN_END[..3], &mut out).unwrap();
        assert!(!p.done);
        assert_eq!(p.written, 0);
        assert!(matches!(
            dec.raw_finish(&mut out),
            Err(Error::UnexpectedEnd)
        ));
    }
}