agzip/
decode.rs

//! Types to decompress a gzip-compressed input stream.
2
3use {
4    crate::{
5        format::{Flags, Footer},
6        reader::{Reader, Skippable},
7    },
8    flate2::{Crc, Decompress, DecompressError, FlushDecompress, Status},
9    std::{error, ffi::CString, fmt, mem, ops::ControlFlow},
10};
11
/// Output slots for the fields of a gzip member header.
///
/// Fill in a mutable reference for every header field you want to receive and
/// pass the value to [`Decoder::new`]. Fields left as `None` are not stored; the
/// corresponding header bytes are skipped. A slot is written as soon as its part
/// of the header has been decoded.
#[derive(Debug, Default)]
pub struct ReadHeader<'head> {
    /// Receives the header's `MTIME` value (stored little-endian in the stream).
    pub mtime: Option<&'head mut u32>,

    /// Receives the optional extra field. Only written when the header
    /// actually carries one.
    pub extra: Option<&'head mut Box<[u8]>>,

    /// Receives the optional file name, without its NUL terminator. Only
    /// written when the header actually carries one.
    pub name: Option<&'head mut Option<CString>>,

    /// Receives the optional comment, without its NUL terminator. Only
    /// written when the header actually carries one.
    pub comment: Option<&'head mut Option<CString>>,
}
24
25#[derive(Debug)]
26enum State {
27    Start(Reader<[u8; 10]>),
28    ExtraLen(Reader<[u8; 2]>),
29    Extra(Reader<Skippable<Box<[u8]>>>),
30    Name(Vec<u8>),
31    Comment(Vec<u8>),
32    Crc(Reader<[u8; 2]>),
33    Payload,
34    Footer(Reader<[u8; 8]>),
35}
36
37#[derive(Debug)]
38struct Parser<'head> {
39    state: State,
40    flags: Flags,
41    header: ReadHeader<'head>,
42    footer: Footer,
43}
44
45impl<'head> Parser<'head> {
46    #[inline]
47    fn new(header: ReadHeader<'head>) -> Self {
48        let state = State::Start(Reader::default());
49        let flags = Flags(0);
50        let footer = Footer::empty();
51
52        Self {
53            state,
54            flags,
55            header,
56            footer,
57        }
58    }
59
60    fn parse<D>(&mut self, input: &mut &[u8], mut deco: D) -> Parsed
61    where
62        D: FnMut(&mut &[u8]) -> ControlFlow<()>,
63    {
64        loop {
65            match &mut self.state {
66                State::Start(read) => {
67                    let Some(&mut bytes) = read.read_from(input) else {
68                        return Parsed::Done;
69                    };
70
71                    let Some((flags, mtime)) = parse_start(bytes) else {
72                        return Parsed::InvalidHeader;
73                    };
74
75                    self.flags = flags;
76                    if let Some(mtime_mut) = self.header.mtime.as_deref_mut() {
77                        *mtime_mut = mtime;
78                    }
79
80                    self.state = State::ExtraLen(Reader::default());
81                }
82                State::ExtraLen(read) => {
83                    if !self.flags.has(Flags::EXTRA) {
84                        self.state = State::Name(vec![]);
85                        continue;
86                    }
87
88                    let Some(&mut bytes) = read.read_from(input) else {
89                        return Parsed::Done;
90                    };
91
92                    let len = u16::from_le_bytes(bytes) as usize;
93                    let read = if self.header.extra.is_some() {
94                        Reader::alloc(len).fill()
95                    } else {
96                        Reader::skip(len)
97                    };
98
99                    self.state = State::Extra(read);
100                }
101                State::Extra(read) => {
102                    let Some(extra) = read.read_from(input) else {
103                        return Parsed::Done;
104                    };
105
106                    if let Skippable::Fill(extra) = extra {
107                        if let Some(header_extra) = self.header.extra.as_deref_mut() {
108                            mem::swap(header_extra, extra);
109                        }
110                    }
111
112                    self.state = State::Name(vec![]);
113                }
114                State::Name(out) => {
115                    if !self.flags.has(Flags::NAME) {
116                        self.state = State::Comment(vec![]);
117                        continue;
118                    }
119
120                    let (read, parse) = read_while(0, input);
121                    if self.header.name.is_some() {
122                        out.extend_from_slice(read);
123                    }
124
125                    if parse {
126                        return Parsed::Done;
127                    }
128
129                    if let Some(name) = self.header.name.as_deref_mut() {
130                        *name = CString::new(mem::take(out)).ok();
131                    }
132
133                    self.state = State::Comment(vec![]);
134                }
135                State::Comment(out) => {
136                    if !self.flags.has(Flags::COMMENT) {
137                        self.state = State::Crc(Reader::default());
138                        continue;
139                    }
140
141                    let (read, parse) = read_while(0, input);
142                    if self.header.comment.is_some() {
143                        out.extend_from_slice(read);
144                    }
145
146                    if parse {
147                        return Parsed::Done;
148                    }
149
150                    if let Some(comment) = self.header.comment.as_deref_mut() {
151                        *comment = CString::new(mem::take(out)).ok();
152                    }
153
154                    self.state = State::Crc(Reader::default());
155                }
156                State::Crc(read) => {
157                    if !self.flags.has(Flags::CRC) {
158                        self.state = State::Payload;
159                        continue;
160                    }
161
162                    if read.read_from(input).is_none() {
163                        return Parsed::Done;
164                    };
165
166                    self.state = State::Payload;
167                }
168                State::Payload => match deco(input) {
169                    ControlFlow::Continue(()) => return Parsed::Done,
170                    ControlFlow::Break(()) => self.state = State::Footer(Reader::default()),
171                },
172                State::Footer(buf) => {
173                    let Some(&mut bytes) = buf.read_from(input) else {
174                        return Parsed::Done;
175                    };
176
177                    self.footer = parse_footer(bytes);
178                    return Parsed::End;
179                }
180            }
181        }
182    }
183}
184
/// Outcome of a single `Parser::parse` call.
enum Parsed {
    /// The fixed header was rejected; the stream cannot be decoded.
    InvalidHeader,
    /// Parsing stopped before the end of the member: the available input ran
    /// out, or the payload callback asked to be called again.
    Done,
    /// The trailer has been read; the gzip member is complete.
    End,
}
190
191fn parse_start(s: [u8; 10]) -> Option<(Flags, u32)> {
192    let [31, 139, 8, flags, mt3, mt2, mt1, mt0, xfl, os] = s else {
193        return None;
194    };
195
196    let flags = Flags(flags);
197    let mtime = u32::from_le_bytes([mt3, mt2, mt1, mt0]);
198    _ = xfl; // ignored
199    _ = os; // ignored
200
201    Some((flags, mtime))
202}
203
204fn parse_footer(s: [u8; 8]) -> Footer {
205    let [c3, c2, c1, c0, i3, i2, i1, i0] = s;
206    let crc = u32::from_le_bytes([c3, c2, c1, c0]);
207    let isize = u32::from_le_bytes([i3, i2, i1, i0]);
208    Footer { crc, isize }
209}
210
211fn read_while<'input>(u: u8, input: &mut &'input [u8]) -> (&'input [u8], bool) {
212    match memchr::memchr(u, input) {
213        Some(n) => {
214            let (left, right) = input.split_at(n);
215            *input = &right[1..];
216            (left, false)
217        }
218        None => {
219            let out = *input;
220            *input = &[];
221            (out, true)
222        }
223    }
224}
225
226/// The stream decoder.
227#[derive(Debug)]
228pub struct Decoder<'head> {
229    decomp: Decompress,
230    parser: Parser<'head>,
231    crc: Crc,
232}
233
234impl<'head> Decoder<'head> {
235    /// Creates a new decoder instance. Specify the necessary output fields for the gzip
236    /// [header](ReadHeader) that will be written as a result of decoding.
237    #[inline]
238    pub fn new(header: ReadHeader<'head>) -> Self {
239        Self {
240            decomp: Decompress::new(false),
241            parser: Parser::new(header),
242            crc: Crc::default(),
243        }
244    }
245
246    /// Decodes a portion of input data and writes it to the output buffer.
247    /// Then, check the returned [decoded](Decoded) value, which contains the operation status.
248    pub fn decode(&mut self, mut input: &[u8], output: &mut [u8]) -> Decoded {
249        let mut written = 0;
250        let mut need_more_input = false;
251        let mut err = None;
252
253        let deco = |input: &mut &[u8]| {
254            let input_size = self.decomp.total_in();
255            let output_size = self.decomp.total_out();
256
257            let res = self.decomp.decompress(input, output, FlushDecompress::None);
258
259            let read = self.decomp.total_in() - input_size;
260            *input = &input[read as usize..];
261
262            written = (self.decomp.total_out() - output_size) as usize;
263            self.crc.update(&output[..written]);
264
265            match res {
266                Ok(Status::Ok) => ControlFlow::Continue(()),
267                Ok(Status::BufError) => {
268                    need_more_input = true;
269                    ControlFlow::Continue(())
270                }
271                Ok(Status::StreamEnd) => ControlFlow::Break(()),
272                Err(e) => {
273                    err = Some(Error::Decompress(e));
274                    ControlFlow::Continue(())
275                }
276            }
277        };
278
279        let initial_input_len = input.len();
280        let input_mut = &mut input;
281        let parsed = self.parser.parse(input_mut, deco);
282        let read = initial_input_len - input_mut.len();
283
284        match parsed {
285            Parsed::Done if need_more_input => {
286                debug_assert_eq!(written, 0, "nothing is written to the output");
287                Decoded::NeedMoreInput { read }
288            }
289            Parsed::Done => err.map_or(
290                Decoded::Done {
291                    read,
292                    written,
293                    end: false,
294                },
295                Decoded::Fail,
296            ),
297            Parsed::End if self.parser.footer.checksum(&self.crc) => Decoded::Done {
298                read,
299                written,
300                end: true,
301            },
302            Parsed::End => Decoded::Fail(Error::ChecksumMismatch),
303            Parsed::InvalidHeader => Decoded::Fail(Error::InvalidHeader),
304        }
305    }
306}
307
308/// The [decode](Decoder::decode) operation status.
309#[derive(Debug)]
310pub enum Decoded {
311    /// A portion of the input data has been successfully decompressed
312    /// and written to the output buffer.
313    Done {
314        /// How much data has been read from the input buffer.
315        read: usize,
316
317        /// How much data has been written to the output buffer.
318        written: usize,
319
320        /// Whether all input data has been fully written.
321        end: bool,
322    },
323
324    /// More input data is required.
325    /// In this case, read more data into the input buffer and retry the operation.
326    NeedMoreInput {
327        /// How much data has been read from the input buffer.
328        /// Even though there was not enough data to complete the operation,
329        /// a smaller portion may have been read.
330        read: usize,
331    },
332
333    /// A decompression error has occurred.
334    Fail(Error),
335}
336
337/// The decoding error.
338#[derive(Debug)]
339pub enum Error {
340    /// The header data is invalid.
341    InvalidHeader,
342
343    /// The checksum doesn't match.
344    ChecksumMismatch,
345
346    /// The decompression error.
347    Decompress(DecompressError),
348}
349
350impl fmt::Display for Error {
351    #[inline]
352    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
353        match self {
354            Self::InvalidHeader => write!(f, "invalid header"),
355            Self::ChecksumMismatch => write!(f, "the checksum doesn't match"),
356            Self::Decompress(e) => e.fmt(f),
357        }
358    }
359}
360
361impl error::Error for Error {
362    #[inline]
363    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
364        match self {
365            Self::InvalidHeader | Self::ChecksumMismatch => None,
366            Self::Decompress(e) => Some(e),
367        }
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    fn decode(expected: &[u8], input: &[u8]) {
        let mut d = Decoder::new(ReadHeader::default());
        let mut output = vec![0; expected.len()];
        let Decoded::Done { read, written, end } = d.decode(input, output.as_mut_slice()) else {
            panic!("failed to decode input");
        };

        assert_eq!(read, input.len());
        assert_eq!(written, expected.len());
        assert!(end);
        assert_eq!(output, expected);
    }

    #[test]
    fn decode_hello() {
        decode(
            include_bytes!("../test/hello.txt"),
            include_bytes!("../test/hello.gzip"),
        );
    }

    #[test]
    fn decode_lorem() {
        decode(
            include_bytes!("../test/lorem.txt"),
            include_bytes!("../test/lorem.gzip"),
        );
    }

    fn decode_partial(expected: &[u8], input: &[u8]) {
        let mut d = Decoder::new(ReadHeader::default());
        let mut output = vec![0; expected.len()];
        let mut p = 0;
        let mut finished = false;

        for part in input.chunks(4) {
            let Decoded::Done { read, written, end } = d.decode(part, &mut output[p..]) else {
                panic!("failed to decode input");
            };

            p += written;
            finished = end || finished;

            assert_eq!(read, part.len());
        }

        assert_eq!(p, expected.len());
        assert!(finished);
        assert_eq!(output, expected);
    }

    #[test]
    fn decode_partial_hello() {
        decode_partial(
            include_bytes!("../test/hello.txt"),
            include_bytes!("../test/hello.gzip"),
        );
    }

    #[test]
    fn decode_partial_lorem() {
        decode_partial(
            include_bytes!("../test/lorem.txt"),
            include_bytes!("../test/lorem.gzip"),
        );
    }

    #[test]
    fn decode_no_input() {
        let expected = include_bytes!("../test/lorem.txt");
        let input = include_bytes!("../test/lorem.gzip");
        let input = &input[..input.len() / 2];

        let mut d = Decoder::new(ReadHeader::default());
        let mut output = vec![0; expected.len()];
        let Decoded::Done {
            read, end: false, ..
        } = d.decode(input, output.as_mut_slice())
        else {
            panic!("failed to decode input");
        };

        let input = &input[read..];
        let decoded = d.decode(input, output.as_mut_slice());
        assert!(matches!(decoded, Decoded::NeedMoreInput { read: 0 }));
    }

    #[test]
    fn decode_checksum_mismatch() {
        let expected = include_bytes!("../test/hello.txt");
        let input = const {
            let mut input = *include_bytes!("../test/hello.gzip");
            input[input.len() - 5] = 0;
            input
        }
        .as_slice();

        let mut d = Decoder::new(ReadHeader::default());
        let mut output = vec![0; expected.len()];
        let decoded = d.decode(input, output.as_mut_slice());
        assert!(matches!(decoded, Decoded::Fail(Error::ChecksumMismatch)));
    }

    /// A hand-built gzip member with FEXTRA, FNAME and FCOMMENT set and an
    /// empty payload (one final stored deflate block), so every header field
    /// has a known expected value.
    const WITH_FIELDS: [u8; 40] = [
        0x1f, 0x8b, 0x08, 0x1c, // magic, CM = deflate, FLG = FEXTRA | FNAME | FCOMMENT
        0x78, 0x56, 0x34, 0x12, // MTIME = 0x12345678, little-endian
        0x00, 0x03, // XFL, OS
        0x03, 0x00, b'x', b'y', b'z', // XLEN = 3 and the extra field
        b'f', b'i', b'l', b'e', b'.', b't', b'x', b't', 0x00, // FNAME
        b'h', b'i', 0x00, // FCOMMENT
        0x01, 0x00, 0x00, 0xff, 0xff, // final stored block, LEN = 0, NLEN = !0
        0x00, 0x00, 0x00, 0x00, // CRC-32 of no data
        0x00, 0x00, 0x00, 0x00, // ISIZE = 0
    ];

    /// Number of header bytes at the front of `WITH_FIELDS` (10 + 5 + 9 + 3).
    const WITH_FIELDS_HEADER_LEN: usize = 27;

    #[test]
    fn decode_header_fields() {
        let mut mtime = 0;
        let mut extra: Box<[u8]> = Box::default();
        let mut name = None;
        let mut comment = None;

        {
            let mut d = Decoder::new(ReadHeader {
                mtime: Some(&mut mtime),
                extra: Some(&mut extra),
                name: Some(&mut name),
                comment: Some(&mut comment),
            });

            let mut output = [0; 16];
            let Decoded::Done { read, written, end } = d.decode(&WITH_FIELDS, &mut output) else {
                panic!("failed to decode input");
            };

            assert_eq!(read, WITH_FIELDS.len());
            assert_eq!(written, 0);
            assert!(end);
        }

        assert_eq!(mtime, 0x1234_5678);
        assert_eq!(&extra[..], b"xyz");
        assert_eq!(name, Some(CString::new("file.txt").unwrap()));
        assert_eq!(comment, Some(CString::new("hi").unwrap()));
    }

    #[test]
    fn decode_header_fields_byte_by_byte() {
        let mut name = None;

        {
            // Only the name is requested, so the extra field and the comment
            // take the "skip" paths.
            let mut d = Decoder::new(ReadHeader {
                name: Some(&mut name),
                ..ReadHeader::default()
            });

            let mut output = [0; 16];
            let (header, rest) = WITH_FIELDS.split_at(WITH_FIELDS_HEADER_LEN);

            // Every single header byte must be consumed. Once the header is
            // complete the payload sees no data, which may be reported as
            // NeedMoreInput.
            for byte in header.chunks(1) {
                match d.decode(byte, &mut output) {
                    Decoded::Done {
                        read: 1,
                        written: 0,
                        end: false,
                    }
                    | Decoded::NeedMoreInput { read: 1 } => {}
                    other => panic!("unexpected status: {other:?}"),
                }
            }

            let Decoded::Done {
                read,
                written: 0,
                end: true,
            } = d.decode(rest, &mut output)
            else {
                panic!("failed to decode input");
            };
            assert_eq!(read, rest.len());
        }

        assert_eq!(name, Some(CString::new("file.txt").unwrap()));
    }
}