use {
crate::{
format::{Flags, Footer},
reader::{Reader, Skippable},
},
flate2::{Crc, Decompress, DecompressError, FlushDecompress, Status},
std::{error, ffi::CString, fmt, mem, ops::ControlFlow},
};
/// Destinations for the optional gzip header fields a caller wants to keep.
///
/// Each field that is `Some` receives the corresponding value while the header
/// is being parsed. Fields left as `None` are skipped: their bytes are consumed
/// but never stored.
#[derive(Debug, Default)]
pub struct ReadHeader<'head> {
    /// Receives MTIME from the fixed header.
    pub mtime: Option<&'head mut u32>,
    /// Receives the FEXTRA bytes if the member has them. The box is swapped with
    /// the freshly read buffer; it is left untouched when FEXTRA is absent.
    pub extra: Option<&'head mut Box<[u8]>>,
    /// Receives the original file name (without its terminating NUL) if FNAME
    /// is set; it is left untouched otherwise.
    pub name: Option<&'head mut Option<CString>>,
    /// Receives the comment (without its terminating NUL) if FCOMMENT is set;
    /// it is left untouched otherwise.
    pub comment: Option<&'head mut Option<CString>>,
}
/// Position of the parser within a single gzip member, in stream order.
#[derive(Debug)]
enum State {
    /// Collecting the fixed 10-byte header.
    Start(Reader<[u8; 10]>),
    /// Collecting the 2-byte little-endian FEXTRA length (skipped when FEXTRA is unset).
    ExtraLen(Reader<[u8; 2]>),
    /// Consuming the FEXTRA payload: buffered when the caller asked for it,
    /// skipped otherwise.
    Extra(Reader<Skippable<Box<[u8]>>>),
    /// Accumulating the NUL-terminated file name (bytes kept only if requested).
    Name(Vec<u8>),
    /// Accumulating the NUL-terminated comment (bytes kept only if requested).
    Comment(Vec<u8>),
    /// Consuming the 2-byte header CRC (read, not verified).
    Crc(Reader<[u8; 2]>),
    /// Handing compressed bytes to the decompressor.
    Payload,
    /// Collecting the 8-byte trailer.
    Footer(Reader<[u8; 8]>),
}
/// Incremental parser for one gzip member. It walks the header, delegates the
/// compressed payload to a callback, then reads the trailer.
#[derive(Debug)]
struct Parser<'head> {
    /// Current position in the member.
    state: State,
    /// FLG byte from the fixed header; `Flags(0)` until it has been read.
    flags: Flags,
    /// Caller-provided destinations for optional header fields.
    header: ReadHeader<'head>,
    /// Trailer values; `Footer::empty()` until the trailer has been read.
    footer: Footer,
}
impl<'head> Parser<'head> {
    /// Creates a parser positioned at the first byte of a gzip member.
    ///
    /// `header` says which optional header fields the caller wants. Each
    /// requested field is written through its reference as soon as it has been
    /// parsed.
    #[inline]
    fn new(header: ReadHeader<'head>) -> Self {
        let state = State::Start(Reader::default());
        // Placeholder until the real FLG byte is read in `State::Start`.
        let flags = Flags(0);
        // Placeholder until the trailer is read in `State::Footer`.
        let footer = Footer::empty();
        Self {
            state,
            flags,
            header,
            footer,
        }
    }

    /// Advances the gzip state machine over `input`, moving the slice past every
    /// byte that was consumed.
    ///
    /// Header and trailer bytes are handled here. While in the payload, the
    /// remaining input is handed to `deco`, which must advance the slice itself.
    /// It returns `ControlFlow::Break(())` once the compressed stream has ended,
    /// or `ControlFlow::Continue(())` to end this call.
    ///
    /// Returns:
    /// - `Parsed::Done`: this call can make no further progress. Either the input
    ///   ran out part-way through a field, or `deco` returned `Continue`. The
    ///   current state is kept, so the next call resumes where this one stopped.
    /// - `Parsed::End`: the 8-byte trailer has been read into `self.footer`.
    ///   Any bytes after it are left in `input`.
    /// - `Parsed::InvalidHeader`: `parse_start` rejected the fixed header.
    fn parse<D>(&mut self, input: &mut &[u8], mut deco: D) -> Parsed
    where
        D: FnMut(&mut &[u8]) -> ControlFlow<()>,
    {
        loop {
            match &mut self.state {
                // Fixed 10-byte header: magic, CM, FLG, MTIME, XFL, OS (see `parse_start`).
                State::Start(read) => {
                    // NOTE(review): assumes `read_from` keeps partial bytes across calls
                    // and yields `Some` only once all bytes are collected. Confirm in `reader`.
                    let Some(&mut bytes) = read.read_from(input) else {
                        return Parsed::Done;
                    };
                    let Some((flags, mtime)) = parse_start(bytes) else {
                        return Parsed::InvalidHeader;
                    };
                    self.flags = flags;
                    if let Some(mtime_mut) = self.header.mtime.as_deref_mut() {
                        *mtime_mut = mtime;
                    }
                    self.state = State::ExtraLen(Reader::default());
                }
                // FEXTRA: a little-endian u16 length, then that many bytes.
                State::ExtraLen(read) => {
                    if !self.flags.has(Flags::EXTRA) {
                        self.state = State::Name(vec![]);
                        continue;
                    }
                    let Some(&mut bytes) = read.read_from(input) else {
                        return Parsed::Done;
                    };
                    let len = u16::from_le_bytes(bytes) as usize;
                    // Only buffer the extra field if the caller wants it; otherwise skip it.
                    let read = if self.header.extra.is_some() {
                        Reader::alloc(len).fill()
                    } else {
                        Reader::skip(len)
                    };
                    self.state = State::Extra(read);
                }
                State::Extra(read) => {
                    let Some(extra) = read.read_from(input) else {
                        return Parsed::Done;
                    };
                    // Hand the buffered bytes to the caller by swapping boxes; the
                    // caller's previous box ends up inside the reader.
                    if let Skippable::Fill(extra) = extra {
                        if let Some(header_extra) = self.header.extra.as_deref_mut() {
                            mem::swap(header_extra, extra);
                        }
                    }
                    self.state = State::Name(vec![]);
                }
                // FNAME: NUL-terminated. It may span several calls, so bytes accumulate in `out`.
                State::Name(out) => {
                    if !self.flags.has(Flags::NAME) {
                        self.state = State::Comment(vec![]);
                        continue;
                    }
                    // `read`: bytes before the next NUL (or all remaining input).
                    // `parse`: true when no NUL was found yet, i.e. the name continues.
                    let (read, parse) = read_while(0, input);
                    if self.header.name.is_some() {
                        out.extend_from_slice(read);
                    }
                    if parse {
                        return Parsed::Done;
                    }
                    // `out` never contains a NUL (read_while stops before it), so
                    // `CString::new` succeeds here.
                    if let Some(name) = self.header.name.as_deref_mut() {
                        *name = CString::new(mem::take(out)).ok();
                    }
                    self.state = State::Comment(vec![]);
                }
                // FCOMMENT: handled exactly like FNAME.
                State::Comment(out) => {
                    if !self.flags.has(Flags::COMMENT) {
                        self.state = State::Crc(Reader::default());
                        continue;
                    }
                    let (read, parse) = read_while(0, input);
                    if self.header.comment.is_some() {
                        out.extend_from_slice(read);
                    }
                    if parse {
                        return Parsed::Done;
                    }
                    if let Some(comment) = self.header.comment.as_deref_mut() {
                        *comment = CString::new(mem::take(out)).ok();
                    }
                    self.state = State::Crc(Reader::default());
                }
                // FHCRC: the two header-CRC bytes are consumed but not checked.
                State::Crc(read) => {
                    if !self.flags.has(Flags::CRC) {
                        self.state = State::Payload;
                        continue;
                    }
                    if read.read_from(input).is_none() {
                        return Parsed::Done;
                    };
                    self.state = State::Payload;
                }
                // `deco` consumes compressed bytes. `Break` means the compressed
                // stream ended and the trailer follows.
                State::Payload => match deco(input) {
                    ControlFlow::Continue(()) => return Parsed::Done,
                    ControlFlow::Break(()) => self.state = State::Footer(Reader::default()),
                },
                // Trailer: CRC32 and ISIZE (see `parse_footer`).
                State::Footer(buf) => {
                    let Some(&mut bytes) = buf.read_from(input) else {
                        return Parsed::Done;
                    };
                    self.footer = parse_footer(bytes);
                    return Parsed::End;
                }
            }
        }
    }
}
/// Outcome of one `Parser::parse` call.
enum Parsed {
    /// Paused: no further progress this call; the parser state is kept for the
    /// next one.
    Done,
    /// The trailer has been read.
    End,
    /// The fixed header was rejected.
    InvalidHeader,
}
/// Parses the fixed 10-byte gzip member header.
///
/// Layout: ID1 = 31, ID2 = 139, CM = 8 (deflate), FLG, MTIME (4 bytes,
/// little-endian), XFL, OS.
///
/// Returns the FLG byte and MTIME. Returns `None` if the magic bytes or the
/// compression method are wrong, or if any reserved FLG bit (bits 5-7) is set.
/// RFC 1952 section 2.3.1.2 requires rejecting reserved bits, because they may
/// announce a field this parser does not know how to skip.
fn parse_start(s: [u8; 10]) -> Option<(Flags, u32)> {
    // FLG bits 5, 6 and 7 are reserved and must be zero.
    const RESERVED: u8 = 0b1110_0000;

    // XFL and OS carry nothing this decoder needs.
    let [31, 139, 8, flags, mt0, mt1, mt2, mt3, _xfl, _os] = s else {
        return None;
    };
    if (flags & RESERVED) != 0 {
        return None;
    }
    let mtime = u32::from_le_bytes([mt0, mt1, mt2, mt3]);
    Some((Flags(flags), mtime))
}
/// Splits the 8-byte gzip trailer into its two little-endian words: the first
/// four bytes are the CRC32, the last four the ISIZE value.
fn parse_footer(s: [u8; 8]) -> Footer {
    let le_word = |at: usize| u32::from_le_bytes([s[at], s[at + 1], s[at + 2], s[at + 3]]);
    Footer {
        crc: le_word(0),
        isize: le_word(4),
    }
}
/// Takes bytes from `input` up to the first occurrence of the delimiter `u`.
///
/// If the delimiter is found, returns the bytes before it together with
/// `false`, and advances `input` past the delimiter. If it is not found,
/// returns all of `input` together with `true` (meaning "not terminated yet")
/// and leaves `input` empty.
fn read_while<'input>(u: u8, input: &mut &'input [u8]) -> (&'input [u8], bool) {
    let data: &'input [u8] = *input;
    let Some(stop) = memchr::memchr(u, data) else {
        *input = &[];
        return (data, true);
    };
    *input = &data[stop + 1..];
    (&data[..stop], false)
}
/// Streaming gzip decoder.
///
/// Combines the header/trailer parser with a raw-deflate decompressor and a
/// running CRC of the decompressed output. That CRC is checked against the
/// trailer once the member ends.
#[derive(Debug)]
pub struct Decoder<'head> {
    /// Deflate decompressor, created without a zlib header.
    decomp: Decompress,
    /// Gzip container state machine.
    parser: Parser<'head>,
    /// CRC of every byte written to the caller's output so far.
    crc: Crc,
}
impl<'head> Decoder<'head> {
#[inline]
pub fn new(header: ReadHeader<'head>) -> Self {
Self {
decomp: Decompress::new(false),
parser: Parser::new(header),
crc: Crc::default(),
}
}
pub fn decode(&mut self, mut input: &[u8], output: &mut [u8]) -> Decoded {
let mut written = 0;
let mut need_more_input = false;
let mut err = None;
let deco = |input: &mut &[u8]| {
let input_size = self.decomp.total_in();
let output_size = self.decomp.total_out();
let res = self.decomp.decompress(input, output, FlushDecompress::None);
let read = self.decomp.total_in() - input_size;
*input = &input[read as usize..];
written = (self.decomp.total_out() - output_size) as usize;
self.crc.update(&output[..written]);
match res {
Ok(Status::Ok) => ControlFlow::Continue(()),
Ok(Status::BufError) => {
need_more_input = true;
ControlFlow::Continue(())
}
Ok(Status::StreamEnd) => ControlFlow::Break(()),
Err(e) => {
err = Some(Error::Decompress(e));
ControlFlow::Continue(())
}
}
};
let initial_input_len = input.len();
let input_mut = &mut input;
let parsed = self.parser.parse(input_mut, deco);
let read = initial_input_len - input_mut.len();
match parsed {
Parsed::Done if need_more_input => {
debug_assert_eq!(written, 0, "nothing is written to the output");
Decoded::NeedMoreInput { read }
}
Parsed::Done => err.map_or(
Decoded::Done {
read,
written,
end: false,
},
Decoded::Fail,
),
Parsed::End if self.parser.footer.checksum(&self.crc) => Decoded::Done {
read,
written,
end: true,
},
Parsed::End => Decoded::Fail(Error::ChecksumMismatch),
Parsed::InvalidHeader => Decoded::Fail(Error::InvalidHeader),
}
}
}
/// Result of one `Decoder::decode` call.
#[derive(Debug)]
pub enum Decoded {
    /// Progress was made (possibly zero bytes).
    Done {
        /// Bytes consumed from the input slice.
        read: usize,
        /// Bytes written to the start of the output slice.
        written: usize,
        /// `true` once the trailer has been read and its checksum matched.
        end: bool,
    },
    /// The decompressor stalled (`Status::BufError`).
    ///
    /// NOTE(review): flate2 also reports `BufError` when the output buffer has
    /// no room. Confirm callers never pass an empty `output` expecting this
    /// variant to mean "input exhausted".
    NeedMoreInput {
        /// Bytes consumed from the input slice before the stall.
        read: usize,
    },
    /// Decoding cannot continue.
    Fail(Error),
}
/// Reasons a gzip `Decoder` can fail.
#[derive(Debug)]
pub enum Error {
    /// The fixed 10-byte header was rejected by `parse_start`.
    InvalidHeader,
    /// The trailer did not match the decompressed data, as judged by
    /// `Footer::checksum`.
    ChecksumMismatch,
    /// The deflate payload could not be decompressed.
    Decompress(DecompressError),
}
impl fmt::Display for Error {
    /// Writes a short human-readable description. Decompression errors defer
    /// to the underlying error's own `Display` output.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Decompress(inner) => fmt::Display::fmt(inner, f),
            Self::InvalidHeader => f.write_str("invalid header"),
            Self::ChecksumMismatch => f.write_str("the checksum doesn't match"),
        }
    }
}
impl error::Error for Error {
    /// Only decompression failures wrap a lower-level error; the gzip-level
    /// failures have no underlying cause.
    #[inline]
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match self {
            Self::Decompress(inner) => Some(inner),
            Self::InvalidHeader => None,
            Self::ChecksumMismatch => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `input` in a single call. Checks that every input byte is
    /// consumed, that the stream ends, and that the output equals `expected`.
    fn decode(expected: &[u8], input: &[u8]) {
        let mut decoder = Decoder::new(ReadHeader::default());
        let mut buf = vec![0; expected.len()];
        let result = decoder.decode(input, &mut buf);
        let Decoded::Done { read, written, end } = result else {
            panic!("failed to decode input");
        };
        assert_eq!(read, input.len());
        assert_eq!(written, expected.len());
        assert!(end);
        assert_eq!(buf, expected);
    }

    #[test]
    fn decode_hello() {
        decode(
            include_bytes!("../test/hello.txt"),
            include_bytes!("../test/hello.gzip"),
        );
    }

    #[test]
    fn decode_lorem() {
        decode(
            include_bytes!("../test/lorem.txt"),
            include_bytes!("../test/lorem.gzip"),
        );
    }

    /// Like `decode`, but feeds the input in 4-byte chunks to exercise resuming
    /// in every parser state.
    fn decode_partial(expected: &[u8], input: &[u8]) {
        let mut decoder = Decoder::new(ReadHeader::default());
        let mut buf = vec![0; expected.len()];
        let mut filled = 0;
        let mut saw_end = false;
        for chunk in input.chunks(4) {
            let result = decoder.decode(chunk, &mut buf[filled..]);
            let Decoded::Done { read, written, end } = result else {
                panic!("failed to decode input");
            };
            assert_eq!(read, chunk.len());
            filled += written;
            saw_end |= end;
        }
        assert_eq!(filled, expected.len());
        assert!(saw_end);
        assert_eq!(buf, expected);
    }

    #[test]
    fn decode_partial_hello() {
        decode_partial(
            include_bytes!("../test/hello.txt"),
            include_bytes!("../test/hello.gzip"),
        );
    }

    #[test]
    fn decode_partial_lorem() {
        decode_partial(
            include_bytes!("../test/lorem.txt"),
            include_bytes!("../test/lorem.gzip"),
        );
    }

    /// After a truncated stream has been fully consumed, another call with no
    /// new input must report `NeedMoreInput` without consuming anything.
    #[test]
    fn decode_no_input() {
        let expected = include_bytes!("../test/lorem.txt");
        let full = include_bytes!("../test/lorem.gzip");
        let half = &full[..full.len() / 2];
        let mut decoder = Decoder::new(ReadHeader::default());
        let mut buf = vec![0; expected.len()];
        let Decoded::Done {
            read, end: false, ..
        } = decoder.decode(half, &mut buf)
        else {
            panic!("failed to decode input");
        };
        let rest = &half[read..];
        let outcome = decoder.decode(rest, &mut buf);
        assert!(matches!(outcome, Decoded::NeedMoreInput { read: 0 }));
    }

    /// Zeroing one byte of the stored CRC32 must make the trailer check fail.
    #[test]
    fn decode_checksum_mismatch() {
        let expected = include_bytes!("../test/hello.txt");
        let corrupted = const {
            let mut bytes = *include_bytes!("../test/hello.gzip");
            bytes[bytes.len() - 5] = 0;
            bytes
        };
        let mut decoder = Decoder::new(ReadHeader::default());
        let mut buf = vec![0; expected.len()];
        let outcome = decoder.decode(corrupted.as_slice(), &mut buf);
        assert!(matches!(outcome, Decoded::Fail(Error::ChecksumMismatch)));
    }
}