use alloc::vec::Vec;
use crate::checksum::Adler32;
use crate::deflate;
use crate::error::Error;
use crate::traits::{Algorithm, Flush, RawDecoder, RawEncoder, RawProgress};
/// CMF byte emitted by the encoder: CM = 8 in the low nibble (the only
/// compression method the decoder accepts) and CINFO = 7 in the high nibble.
const HEADER_CMF: u8 = (7 << 4) | 8;
/// Settings accepted by the zlib `Encoder`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EncoderConfig {
    /// Compression level passed straight to the deflate encoder. The zlib
    /// header's FLEVEL field is derived from this value after clamping it to
    /// `1..=9`.
    pub level: u8,
}

impl Default for EncoderConfig {
    /// Returns a configuration using compression level 6.
    fn default() -> Self {
        const DEFAULT_LEVEL: u8 = 6;
        EncoderConfig {
            level: DEFAULT_LEVEL,
        }
    }
}
fn header_bytes(level: u8) -> (u8, u8) {
let level = level.clamp(1, 9);
let flevel: u8 = match level {
1 => 0,
2..=5 => 1,
6 => 2,
_ => 3, };
let cmf = HEADER_CMF;
let partial = ((cmf as u32) << 8) | ((flevel as u32) << 6);
let fcheck = (31 - (partial % 31)) % 31;
let flg = (flevel << 6) | (fcheck as u8);
(cmf, flg)
}
/// Settings accepted by the zlib `Decoder`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct DecoderConfig {
    /// Preset dictionary. It is used only when a stream's header sets FDICT; in
    /// that case its Adler-32 must match the dictionary id in the stream.
    /// Presumably this must be the same dictionary the producer used — verify
    /// against the encoder side.
    pub dictionary: Vec<u8>,
}
/// Marker type selecting the zlib stream format (a two-byte header, deflate
/// data, and a big-endian Adler-32 trailer).
#[derive(Debug, Clone, Copy, Default)]
pub struct Zlib;

impl Algorithm for Zlib {
    const NAME: &'static str = "zlib";

    type Encoder = Encoder;
    type Decoder = Decoder;
    type EncoderConfig = EncoderConfig;
    type DecoderConfig = DecoderConfig;

    /// Builds a zlib encoder from `config`.
    fn encoder_with(config: Self::EncoderConfig) -> Encoder {
        Encoder::with_config(config)
    }

    /// Builds a zlib decoder from `config`.
    fn decoder_with(config: Self::DecoderConfig) -> Decoder {
        Decoder::with_config(config)
    }
}
/// Where the zlib `Decoder` currently is within the stream.
#[derive(Copy, Clone, Eq, PartialEq)]
enum DecPhase {
    /// Collecting the two header bytes (CMF, FLG).
    Header,
    /// Collecting the four-byte dictionary id (only when FLG sets FDICT).
    DictId,
    /// Feeding compressed bytes to the inner deflate decoder.
    Deflate,
    /// Collecting the four-byte big-endian Adler-32 trailer.
    Trailer,
    /// Trailer verified; the stream is complete.
    Done,
}
/// Streaming zlib decoder wrapping a raw deflate decoder.
pub struct Decoder {
    /// Raw deflate decoder for the compressed payload.
    inner: deflate::Decoder,
    /// Running Adler-32 over every decompressed byte written to the caller.
    adler: Adler32,
    /// CMF/FLG bytes as they arrive, plus how many have been collected.
    header: [u8; 2],
    header_idx: u8,
    /// Preset dictionary supplied through `DecoderConfig`.
    dictionary: Vec<u8>,
    /// Big-endian dictionary id as it arrives, plus how many bytes have been collected.
    dictid: [u8; 4],
    dictid_idx: u8,
    /// Bytes returned by the inner decoder's `drain_trailing_bytes` once the
    /// deflate data ends. They are read as trailer bytes before fresh input.
    trailer_carryover: Vec<u8>,
    trailer_carryover_idx: usize,
    /// Adler-32 trailer as it arrives, plus how many bytes have been collected.
    trailer: [u8; 4],
    trailer_idx: u8,
    /// Current parsing phase.
    phase: DecPhase,
    /// Set after any error. Further decode/finish calls fail until reset.
    poisoned: bool,
}
impl Decoder {
    /// Creates a decoder with no preset dictionary.
    ///
    /// A stream whose header sets FDICT is rejected with `Error::Unsupported`
    /// unless a dictionary is supplied via [`Decoder::with_config`].
    pub fn new() -> Self {
        Self {
            inner: deflate::Decoder::new(),
            adler: Adler32::new(),
            header: [0u8; 2],
            header_idx: 0,
            dictionary: Vec::new(),
            dictid: [0u8; 4],
            dictid_idx: 0,
            trailer_carryover: Vec::new(),
            trailer_carryover_idx: 0,
            trailer: [0u8; 4],
            trailer_idx: 0,
            phase: DecPhase::Header,
            poisoned: false,
        }
    }

    /// Creates a decoder that uses `config.dictionary` when the stream header
    /// requests a preset dictionary (FDICT set).
    pub fn with_config(config: DecoderConfig) -> Self {
        let mut d = Self::new();
        d.dictionary = config.dictionary;
        d
    }

    /// Marks the decoder as poisoned and returns `e` for propagation.
    ///
    /// While poisoned, `raw_decode` and `raw_finish` return `Error::Corrupt`
    /// until `raw_reset` is called.
    fn poison(&mut self, e: Error) -> Error {
        self.poisoned = true;
        e
    }

    /// Validates the collected CMF/FLG header and picks the next phase.
    ///
    /// Returns `DictId` when FDICT is set, and `Deflate` otherwise.
    ///
    /// # Errors
    /// - `Unsupported`: CM is not 8 (deflate), or FDICT is set but no dictionary
    ///   was configured.
    /// - `BadHeader`: CINFO exceeds 7 (window larger than 32 KiB, which
    ///   RFC 1950 forbids), or `CMF * 256 + FLG` is not a multiple of 31.
    ///
    /// Every error also poisons the decoder.
    fn validate_header(&mut self) -> Result<DecPhase, Error> {
        let cmf = self.header[0];
        let flg = self.header[1];
        // CM: low nibble of CMF; 8 is the only defined method (deflate).
        if cmf & 0x0F != 8 {
            return Err(self.poison(Error::Unsupported));
        }
        // CINFO: high nibble of CMF, log2(window size) - 8. Values above 7 are
        // not allowed by RFC 1950 and would describe a window larger than
        // deflate can reference.
        if cmf >> 4 > 7 {
            return Err(self.poison(Error::BadHeader));
        }
        // FCHECK makes the 16-bit header a multiple of 31.
        let total = ((cmf as u32) << 8) | (flg as u32);
        if !total.is_multiple_of(31) {
            return Err(self.poison(Error::BadHeader));
        }
        // FDICT (bit 5 of FLG): a 4-byte dictionary id follows the header.
        let fdict = (flg & 0x20) != 0;
        if fdict {
            if self.dictionary.is_empty() {
                return Err(self.poison(Error::Unsupported));
            }
            Ok(DecPhase::DictId)
        } else {
            Ok(DecPhase::Deflate)
        }
    }

    /// Checks the collected dictionary id against the Adler-32 of the
    /// configured dictionary, then loads the dictionary into the inner decoder.
    ///
    /// # Errors
    /// `ChecksumMismatch` (and poisons) when the ids differ.
    fn validate_dictid_and_seed(&mut self) -> Result<(), Error> {
        let on_wire = u32::from_be_bytes(self.dictid);
        let mut sum = Adler32::new();
        sum.update(&self.dictionary);
        if sum.finalize() != on_wire {
            return Err(self.poison(Error::ChecksumMismatch));
        }
        self.inner.load_dictionary(&self.dictionary);
        Ok(())
    }

    /// Appends one byte to `self.trailer`.
    ///
    /// The byte comes from `trailer_carryover` first, then from
    /// `input[*consumed]`. Returns `Some(false)` when a carryover byte was used,
    /// `Some(true)` when an input byte was used (advancing `*consumed`), and
    /// `None` when neither source has a byte left.
    ///
    /// Callers must ensure `trailer_idx < 4`; otherwise the array index panics.
    fn next_trailer_byte(&mut self, input: &[u8], consumed: &mut usize) -> Option<bool> {
        if self.trailer_carryover_idx < self.trailer_carryover.len() {
            let b = self.trailer_carryover[self.trailer_carryover_idx];
            self.trailer_carryover_idx += 1;
            self.trailer[self.trailer_idx as usize] = b;
            self.trailer_idx += 1;
            Some(false)
        } else if *consumed < input.len() {
            self.trailer[self.trailer_idx as usize] = input[*consumed];
            *consumed += 1;
            self.trailer_idx += 1;
            Some(true)
        } else {
            None
        }
    }
}
impl Default for Decoder {
fn default() -> Self {
Self::new()
}
}
impl RawDecoder for Decoder {
    /// Decodes as much of `input` into `output` as possible.
    ///
    /// Walks the phases header -> (dictionary id) -> deflate -> trailer -> done,
    /// stopping when no phase can make further progress. The Adler-32 of the
    /// decompressed data is checked against the trailer. Any error poisons the
    /// decoder.
    ///
    /// NOTE(review): every return here reports `done: false`, even once the
    /// stream is complete. Completion is only reported by `raw_finish`.
    fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.poisoned {
            return Err(Error::Corrupt);
        }
        let mut consumed = 0usize;
        let mut written = 0usize;
        loop {
            let initial_consumed = consumed;
            let initial_written = written;
            match self.phase {
                DecPhase::Header => {
                    while self.header_idx < 2 && consumed < input.len() {
                        self.header[self.header_idx as usize] = input[consumed];
                        self.header_idx += 1;
                        consumed += 1;
                    }
                    if self.header_idx == 2 {
                        self.phase = self.validate_header()?;
                    } else {
                        return Ok(RawProgress {
                            consumed,
                            written,
                            done: false,
                        });
                    }
                }
                DecPhase::DictId => {
                    while self.dictid_idx < 4 && consumed < input.len() {
                        self.dictid[self.dictid_idx as usize] = input[consumed];
                        self.dictid_idx += 1;
                        consumed += 1;
                    }
                    if self.dictid_idx == 4 {
                        self.validate_dictid_and_seed()?;
                        self.phase = DecPhase::Deflate;
                    } else {
                        return Ok(RawProgress {
                            consumed,
                            written,
                            done: false,
                        });
                    }
                }
                DecPhase::Deflate => {
                    let before_written = written;
                    let p = self
                        .inner
                        .raw_decode(&input[consumed..], &mut output[written..])
                        .map_err(|e| self.poison(e))?;
                    consumed += p.consumed;
                    written += p.written;
                    // Checksum covers exactly the bytes handed to the caller.
                    self.adler.update(&output[before_written..written]);
                    if self.inner.is_complete() {
                        // Presumably input the inner decoder read past the end of
                        // the deflate data; it holds the start of the trailer.
                        self.trailer_carryover = self.inner.drain_trailing_bytes();
                        self.trailer_carryover_idx = 0;
                        self.phase = DecPhase::Trailer;
                    } else if p.consumed == 0 && p.written == 0 {
                        return Ok(RawProgress {
                            consumed,
                            written,
                            done: false,
                        });
                    }
                }
                DecPhase::Trailer => {
                    while self.trailer_idx < 4 {
                        if self.next_trailer_byte(input, &mut consumed).is_none() {
                            return Ok(RawProgress {
                                consumed,
                                written,
                                done: false,
                            });
                        }
                    }
                    let expected = u32::from_be_bytes(self.trailer);
                    if expected != self.adler.finalize() {
                        return Err(self.poison(Error::ChecksumMismatch));
                    }
                    self.phase = DecPhase::Done;
                }
                DecPhase::Done => {
                    // Bytes after the trailer are left unconsumed.
                    return Ok(RawProgress {
                        consumed,
                        written,
                        done: false,
                    });
                }
            }
            if consumed == initial_consumed && written == initial_written {
                return Ok(RawProgress {
                    consumed,
                    written,
                    done: false,
                });
            }
        }
    }

    /// Signals end of input and drains whatever output is still pending.
    ///
    /// Returns `done: true` once the trailer has been verified. If `output`
    /// was completely filled before the stream ended, more decoded data may
    /// still be pending. In that case the call returns `done: false` so the
    /// caller can retry with fresh output space (the same contract
    /// `Encoder::raw_finish` uses).
    ///
    /// # Errors
    /// `UnexpectedEnd` (and poisons) when the stream is truncated, `Corrupt`
    /// if already poisoned, plus any error `raw_decode` produces.
    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.poisoned {
            return Err(Error::Corrupt);
        }
        let empty: [u8; 0] = [];
        let p = self.raw_decode(&empty, output)?;
        if matches!(self.phase, DecPhase::Done) {
            return Ok(RawProgress {
                consumed: 0,
                written: p.written,
                done: true,
            });
        }
        // Output filled up: this may be lack of space rather than truncation,
        // so don't discard pending data by poisoning. A truncated stream is
        // reported on the next call, when no further output is produced.
        if p.written > 0 && p.written == output.len() {
            return Ok(RawProgress {
                consumed: 0,
                written: p.written,
                done: false,
            });
        }
        Err(self.poison(Error::UnexpectedEnd))
    }

    /// Returns the decoder to its initial state, keeping the configured
    /// dictionary and clearing any poison.
    fn raw_reset(&mut self) {
        self.inner.raw_reset();
        self.adler.reset();
        self.header_idx = 0;
        self.dictid_idx = 0;
        self.trailer_carryover.clear();
        self.trailer_carryover_idx = 0;
        self.trailer_idx = 0;
        self.phase = DecPhase::Header;
        self.poisoned = false;
    }
}
/// Where the zlib `Encoder` currently is in producing its output.
#[derive(Copy, Clone, Eq, PartialEq)]
enum EncPhase {
    /// Header bytes not yet fully written.
    Header,
    /// Accepting input for the inner deflate encoder.
    Deflate,
    /// Deflate finished; Adler-32 trailer bytes still being written.
    Trailer,
    /// Whole stream written.
    Done,
}
/// Streaming zlib encoder wrapping a raw deflate encoder.
pub struct Encoder {
    /// Raw deflate encoder for the payload.
    inner: deflate::Encoder,
    /// Running Adler-32 over every input byte the deflate encoder accepted.
    adler: Adler32,
    /// Precomputed CMF/FLG header, plus how many bytes have been emitted.
    header: [u8; 2],
    header_idx: u8,
    /// Big-endian Adler-32 trailer, plus how many bytes have been emitted.
    trailer: [u8; 4],
    trailer_idx: u8,
    /// Current output phase.
    phase: EncPhase,
}
impl Encoder {
pub fn new() -> Self {
Self::with_config(EncoderConfig::default())
}
pub fn with_config(config: EncoderConfig) -> Self {
let (cmf, flg) = header_bytes(config.level);
Self {
inner: deflate::Encoder::with_config(deflate::EncoderConfig {
level: config.level,
}),
adler: Adler32::new(),
header: [cmf, flg],
header_idx: 0,
trailer: [0u8; 4],
trailer_idx: 0,
phase: EncPhase::Header,
}
}
fn drain_header(&mut self, output: &mut [u8], written: &mut usize) -> bool {
while self.header_idx < 2 && *written < output.len() {
output[*written] = self.header[self.header_idx as usize];
*written += 1;
self.header_idx += 1;
}
self.header_idx == 2
}
fn drain_trailer(&mut self, output: &mut [u8], written: &mut usize) -> bool {
while self.trailer_idx < 4 && *written < output.len() {
output[*written] = self.trailer[self.trailer_idx as usize];
*written += 1;
self.trailer_idx += 1;
}
self.trailer_idx == 4
}
}
impl Default for Encoder {
fn default() -> Self {
Self::new()
}
}
impl RawEncoder for Encoder {
    /// Writes any pending header bytes, then compresses as much of `input`
    /// as the inner deflate encoder accepts.
    ///
    /// Only the accepted prefix of `input` is added to the Adler-32.
    ///
    /// # Errors
    /// `Corrupt` once `raw_finish` has begun emitting the trailer, plus any
    /// error from the inner encoder.
    fn raw_encode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
        let mut written = 0usize;
        if self.phase == EncPhase::Header {
            if !self.drain_header(output, &mut written) {
                return Ok(RawProgress {
                    consumed: 0,
                    written,
                    done: false,
                });
            }
            self.phase = EncPhase::Deflate;
        }
        if self.phase != EncPhase::Deflate {
            return Err(Error::Corrupt);
        }
        let step = self.inner.raw_encode(input, &mut output[written..])?;
        self.adler.update(&input[..step.consumed]);
        Ok(RawProgress {
            consumed: step.consumed,
            written: written + step.written,
            done: false,
        })
    }

    /// Finishes the stream: header (if still pending), remaining deflate
    /// output, then the big-endian Adler-32 trailer.
    ///
    /// Returns `done: false` when `output` runs out of space; call again with
    /// more room. Once the stream is complete, every call returns `done: true`.
    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
        let mut written = 0usize;
        if self.phase == EncPhase::Header {
            if !self.drain_header(output, &mut written) {
                return Ok(RawProgress {
                    consumed: 0,
                    written,
                    done: false,
                });
            }
            self.phase = EncPhase::Deflate;
        }
        while self.phase == EncPhase::Deflate {
            let step = self.inner.raw_finish(&mut output[written..])?;
            written += step.written;
            if step.done {
                // The checksum is final: stage it as the trailer.
                self.trailer = self.adler.finalize().to_be_bytes();
                self.trailer_idx = 0;
                self.phase = EncPhase::Trailer;
            } else if step.written == 0 {
                return Ok(RawProgress {
                    consumed: 0,
                    written,
                    done: false,
                });
            }
        }
        if self.phase == EncPhase::Trailer && self.drain_trailer(output, &mut written) {
            self.phase = EncPhase::Done;
        }
        let done = self.phase == EncPhase::Done;
        Ok(RawProgress {
            consumed: 0,
            written,
            done,
        })
    }

    /// Returns the encoder to its initial state, keeping the configured level
    /// and precomputed header.
    fn raw_reset(&mut self) {
        self.inner.raw_reset();
        self.adler.reset();
        self.phase = EncPhase::Header;
        self.header_idx = 0;
        self.trailer = [0u8; 4];
        self.trailer_idx = 0;
    }

    /// Writes any pending header bytes, then forwards the flush request to
    /// the inner deflate encoder. The result's `done` reflects the inner
    /// encoder's own report.
    ///
    /// # Errors
    /// `Corrupt` once `raw_finish` has begun emitting the trailer, plus any
    /// error from the inner encoder.
    fn raw_flush(&mut self, output: &mut [u8], mode: Flush) -> Result<RawProgress, Error> {
        let mut written = 0usize;
        if self.phase == EncPhase::Header {
            if !self.drain_header(output, &mut written) {
                return Ok(RawProgress {
                    consumed: 0,
                    written,
                    done: false,
                });
            }
            self.phase = EncPhase::Deflate;
        }
        if self.phase != EncPhase::Deflate {
            return Err(Error::Corrupt);
        }
        let step = self.inner.raw_flush(&mut output[written..], mode)?;
        Ok(RawProgress {
            consumed: 0,
            written: written + step.written,
            done: step.done,
        })
    }
}