fastlz_rs/
decompress.rs

1use core::fmt;
2
3use crate::util::*;
4
5#[cfg(feature = "alloc")]
6extern crate alloc;
7
8#[cfg(feature = "std")]
9extern crate std;
10
/// Decompression errors
///
/// The enum carries no payload, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum DecompressError {
    /// The input ended before a complete literal run or backreference could be read
    InputTruncated,
    /// The input contains a backreference that points before the start of the output
    InvalidBackreference,
    /// The input contains an invalid compression level indicator
    InvalidCompressionLevel,
    /// The output buffer was too small to hold all the output.
    ///
    /// The output that has been written *is* valid, but has been truncated.
    OutputTooSmall,
}
26impl fmt::Display for DecompressError {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        match self {
29            DecompressError::InputTruncated => write!(f, "input was truncated"),
30            DecompressError::InvalidBackreference => write!(f, "invalid backreference"),
31            DecompressError::InvalidCompressionLevel => write!(f, "invalid compression level"),
32            DecompressError::OutputTooSmall => write!(f, "output buffer was insufficient"),
33        }
34    }
35}
36#[cfg(feature = "std")]
37impl std::error::Error for DecompressError {}
38
39impl<'a> OutputSink<DecompressError> for BufOutput<'a> {
40    fn put_lits(&mut self, lits: &[u8]) -> Result<(), DecompressError> {
41        let mut len = lits.len();
42        let mut did_overflow = false;
43        if self.pos + len > self.buf.len() {
44            did_overflow = true;
45            len = self.buf.len() - self.pos;
46        }
47
48        self.buf[self.pos..self.pos + len].copy_from_slice(&lits[..len]);
49        self.pos += len;
50
51        if did_overflow {
52            Err(DecompressError::OutputTooSmall)
53        } else {
54            Ok(())
55        }
56    }
57
58    fn put_backref(&mut self, disp: usize, mut len: usize) -> Result<(), DecompressError> {
59        if disp + 1 > self.pos {
60            return Err(DecompressError::InvalidBackreference);
61        }
62
63        let mut did_overflow = false;
64        if self.pos + len > self.buf.len() {
65            did_overflow = true;
66            len = self.buf.len() - self.pos;
67        }
68
69        for i in 0..len {
70            self.buf[self.pos + i] = self.buf[self.pos - disp - 1 + i];
71        }
72        self.pos += len;
73
74        if did_overflow {
75            Err(DecompressError::OutputTooSmall)
76        } else {
77            Ok(())
78        }
79    }
80}
81
/// Output sink that appends to a growable vector; it never reports
/// `OutputTooSmall`.
#[cfg(feature = "alloc")]
impl OutputSink<DecompressError> for VecOutput {
    /// Appends the literal bytes `lits` to the vector.
    fn put_lits(&mut self, lits: &[u8]) -> Result<(), DecompressError> {
        self.vec.extend_from_slice(lits);
        Ok(())
    }

    /// Appends `len` bytes copied from `disp + 1` bytes behind the current end
    /// of the vector.
    ///
    /// The source and destination ranges may overlap (`len > disp + 1`), in
    /// which case the already-copied bytes are repeated.
    ///
    /// Returns `Err(InvalidBackreference)` without writing anything if the
    /// source would start before the beginning of the vector.
    fn put_backref(&mut self, disp: usize, len: usize) -> Result<(), DecompressError> {
        let pos = self.vec.len();
        // Equivalent to `disp + 1 > pos`, but cannot overflow for huge `disp`.
        if disp >= pos {
            return Err(DecompressError::InvalidBackreference);
        }

        let src = pos - disp - 1;
        // Reserve once, then push: this avoids zero-filling bytes that are
        // immediately overwritten. The copy must stay byte-by-byte and
        // front-to-back because overlapping matches read bytes pushed earlier
        // in this same loop.
        self.vec.reserve(len);
        for i in 0..len {
            let byte = self.vec[src + i];
            self.vec.push(byte);
        }

        Ok(())
    }
}
103
104trait InputHelper {
105    fn getc(&mut self) -> Result<u8, DecompressError>;
106    fn check_len(&mut self, min: usize) -> Result<(), DecompressError>;
107}
108impl InputHelper for &[u8] {
109    fn getc(&mut self) -> Result<u8, DecompressError> {
110        if self.len() == 0 {
111            return Err(DecompressError::InputTruncated);
112        }
113        let c = self[0];
114        *self = &self[1..];
115        Ok(c)
116    }
117
118    fn check_len(&mut self, min: usize) -> Result<(), DecompressError> {
119        if self.len() < min {
120            Err(DecompressError::InputTruncated)
121        } else {
122            Ok(())
123        }
124    }
125}
126
127fn decompress_lv1(
128    mut inp: &[u8],
129    outp: &mut impl OutputSink<DecompressError>,
130) -> Result<(), DecompressError> {
131    // special for first control byte
132    let mut ctrl = inp.getc().unwrap() & 0b000_11111;
133    loop {
134        if ctrl >> 5 == 0b000 {
135            // literal run
136            let len = (ctrl & 0b000_11111) as usize + 1;
137            inp.check_len(len)?;
138            outp.put_lits(&inp[..len])?;
139            inp = &inp[len..];
140        } else {
141            // backreference
142            let mut disp = ((ctrl & 0b000_11111) as usize) << 8;
143            let len = if ctrl >> 5 == 0b111 {
144                // long match
145                inp.getc()? as usize + 9
146            } else {
147                (ctrl >> 5) as usize + 2
148            };
149            disp |= inp.getc()? as usize;
150            outp.put_backref(disp, len)?;
151        }
152
153        if let Ok(c) = inp.getc() {
154            ctrl = c;
155        } else {
156            return Ok(());
157        }
158    }
159}
160
161fn decompress_lv2(
162    mut inp: &[u8],
163    outp: &mut impl OutputSink<DecompressError>,
164) -> Result<(), DecompressError> {
165    // special for first control byte
166    let mut ctrl = inp.getc().unwrap() & 0b000_11111;
167    loop {
168        if ctrl >> 5 == 0b000 {
169            // literal run
170            let len = (ctrl & 0b000_11111) as usize + 1;
171            inp.check_len(len)?;
172            outp.put_lits(&inp[..len])?;
173            inp = &inp[len..];
174        } else {
175            // backreference
176            let mut disp = ((ctrl & 0b000_11111) as usize) << 8;
177
178            let mut len = (ctrl >> 5) as usize + 2;
179            if ctrl >> 5 == 0b111 {
180                // long match
181                loop {
182                    let morelen = inp.getc()?;
183                    len += morelen as usize;
184                    if morelen != 0xff {
185                        break;
186                    }
187                }
188            }
189
190            disp |= inp.getc()? as usize;
191            if disp == 0b11111_11111111 {
192                let moredisp = ((inp.getc()? as usize) << 8) | (inp.getc()? as usize);
193                disp += moredisp;
194            }
195
196            outp.put_backref(disp, len)?;
197        }
198
199        if let Ok(c) = inp.getc() {
200            ctrl = c;
201        } else {
202            return Ok(());
203        }
204    }
205}
206
207fn decompress_impl(
208    inp: &[u8],
209    outp: &mut impl OutputSink<DecompressError>,
210) -> Result<(), DecompressError> {
211    if inp.len() == 0 {
212        return Ok(());
213    }
214
215    match inp[0] >> 5 {
216        0 => decompress_lv1(inp, outp),
217        1 => decompress_lv2(inp, outp),
218        _ => Err(DecompressError::InvalidCompressionLevel),
219    }
220}
221
222/// Decompress the input into a preallocated buffer
223///
224/// Returns the actual decompressed size on success, or an error otherwise
225pub fn decompress_to_buf(inp: &[u8], outp: &mut [u8]) -> Result<usize, DecompressError> {
226    let mut outp: BufOutput = outp.into();
227    decompress_impl(inp, &mut outp)?;
228    Ok(outp.pos)
229}
230
#[cfg(feature = "alloc")]
/// Decompress `inp` into a freshly allocated [Vec](alloc::vec::Vec)
///
/// Returns the decompressed bytes on success, or an error otherwise
///
/// If `capacity_hint` is provided, it will be passed to
/// [Vec::with_capacity](alloc::vec::Vec::with_capacity); otherwise the vector
/// starts out empty and grows on demand.
pub fn decompress_to_vec(
    inp: &[u8],
    capacity_hint: Option<usize>,
) -> Result<alloc::vec::Vec<u8>, DecompressError> {
    let backing: alloc::vec::Vec<u8> =
        capacity_hint.map_or_else(alloc::vec::Vec::new, alloc::vec::Vec::with_capacity);
    let mut sink: VecOutput = backing.into();
    decompress_impl(inp, &mut sink)?;
    Ok(sink.vec)
}
250
/// Unit tests for the output sinks and both decoder levels.
///
/// The `*_manual_*` tests feed hand-assembled streams; the `*_against_ref`
/// tests round-trip this very source file through the reference compressor.
/// The `test_error_*` / `test_empty_input` tests cover the failure paths of
/// the public API.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_buf_out_lits() {
        {
            let mut out = [0u8; 3];
            let mut outbuf: BufOutput = (&mut out[..]).into();
            outbuf.put_lits(&[1]).unwrap();
            assert_eq!(outbuf.buf, [1, 0, 0]);
            // test overflow: as much as fits is still written
            assert_eq!(
                outbuf.put_lits(&[2, 3, 4]),
                Err(DecompressError::OutputTooSmall)
            );
            assert_eq!(outbuf.buf, [1, 2, 3]);
        }

        {
            let mut out = [0u8; 3];
            let mut outbuf: BufOutput = (&mut out[..]).into();
            // test exact fit
            outbuf.put_lits(&[1, 2, 3]).unwrap();
            assert_eq!(outbuf.buf, [1, 2, 3]);
            // a full buffer rejects further data and is left untouched
            assert_eq!(outbuf.put_lits(&[4]), Err(DecompressError::OutputTooSmall));
            assert_eq!(outbuf.buf, [1, 2, 3]);
        }
    }

    #[test]
    fn test_buf_out_backref() {
        {
            let mut out = [0u8; 8];
            let mut outbuf: BufOutput = (&mut out[..]).into();
            outbuf.put_lits(&[1, 2, 3]).unwrap();

            // invalid, before the start
            assert_eq!(
                outbuf.put_backref(3, 5),
                Err(DecompressError::InvalidBackreference)
            );

            // overflow, but should still write up to limit
            assert_eq!(
                outbuf.put_backref(1, 6),
                Err(DecompressError::OutputTooSmall)
            );

            assert_eq!(outbuf.buf, [1, 2, 3, 2, 3, 2, 3, 2])
        }

        {
            let mut out = [0u8; 8];
            let mut outbuf: BufOutput = (&mut out[..]).into();
            outbuf.put_lits(&[1, 2, 3]).unwrap();

            // exact fit
            outbuf.put_backref(2, 5).unwrap();
            assert_eq!(outbuf.buf, [1, 2, 3, 1, 2, 3, 1, 2]);
        }

        // note: we already tested the "hard" case of len > disp
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_vec_out_lits() {
        let out = alloc::vec::Vec::new();
        let mut outbuf: VecOutput = out.into();
        outbuf.put_lits(&[1]).unwrap();
        assert_eq!(outbuf.vec, [1]);
        outbuf.put_lits(&[2, 3, 4]).unwrap();
        assert_eq!(outbuf.vec, [1, 2, 3, 4]);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_vec_out_backref() {
        let out = alloc::vec::Vec::new();
        let mut outbuf: VecOutput = out.into();
        outbuf.put_lits(&[1, 2, 3]).unwrap();
        outbuf.put_backref(1, 6).unwrap();
        assert_eq!(outbuf.vec, [1, 2, 3, 2, 3, 2, 3, 2, 3]);
    }

    #[test]
    fn test_empty_input() {
        // an empty stream is valid and decodes to nothing
        let mut out = [0xaau8; 4];
        assert_eq!(decompress_to_buf(&[], &mut out), Ok(0));
        assert_eq!(out, [0xaa; 4]);
    }

    #[test]
    fn test_error_invalid_level() {
        // only levels 0 and 1 (top three bits of the first byte) are defined
        let mut out = [0u8; 4];
        assert_eq!(
            decompress_to_buf(&[0x40, b'A'], &mut out),
            Err(DecompressError::InvalidCompressionLevel)
        );
        assert_eq!(
            decompress_to_buf(&[0xe0, b'A'], &mut out),
            Err(DecompressError::InvalidCompressionLevel)
        );
    }

    #[test]
    fn test_error_truncated_literals() {
        // both streams announce two literal bytes but only carry one
        let mut out = [0u8; 4];
        assert_eq!(
            decompress_to_buf(&[0x01, b'A'], &mut out),
            Err(DecompressError::InputTruncated)
        );
        assert_eq!(
            decompress_to_buf(&[0x21, b'A'], &mut out),
            Err(DecompressError::InputTruncated)
        );
    }

    #[test]
    fn test_error_truncated_backref() {
        // the match opcode is present but its distance byte is missing
        let mut out = [0u8; 8];
        assert_eq!(
            decompress_to_buf(&[0x00, b'A', 0x20], &mut out),
            Err(DecompressError::InputTruncated)
        );
        assert_eq!(
            decompress_to_buf(&[0x20, b'A', 0x20], &mut out),
            Err(DecompressError::InputTruncated)
        );
    }

    #[test]
    fn test_error_invalid_backref() {
        // distance reaches one byte before the start of the output
        let mut out = [0u8; 8];
        assert_eq!(
            decompress_to_buf(&[0x00, b'A', 0x20, 0x01], &mut out),
            Err(DecompressError::InvalidBackreference)
        );
        assert_eq!(
            decompress_to_buf(&[0x20, b'A', 0x20, 0x01], &mut out),
            Err(DecompressError::InvalidBackreference)
        );
    }

    #[test]
    fn test_error_output_too_small() {
        // the prefix that fits is still written before the error is reported
        let mut out = [0u8; 2];
        assert_eq!(
            decompress_to_buf(&[0x02, b'A', b'B', b'C'], &mut out),
            Err(DecompressError::OutputTooSmall)
        );
        assert_eq!(out, [b'A', b'B']);
    }

    #[test]
    fn test_lv1_manual_lits() {
        let mut out = [0u8; 5];
        let len = decompress_to_buf(&[0x01, b'A', b'B', 0x02, b'C', b'D', b'E'], &mut out).unwrap();
        assert_eq!(len, 5);
        assert_eq!(out, [b'A', b'B', b'C', b'D', b'E']);
    }

    #[test]
    fn test_lv1_manual_short_match() {
        let mut out = [0u8; 5];
        let len = decompress_to_buf(&[0x01, b'A', b'B', 0x20, 0x01], &mut out).unwrap();
        assert_eq!(len, 5);
        assert_eq!(out, [b'A', b'B', b'A', b'B', b'A']);
    }

    #[test]
    fn test_lv1_manual_long_match() {
        let mut out = [0u8; 11];
        let len = decompress_to_buf(&[0x01, b'A', b'B', 0xe0, 0x00, 0x01], &mut out).unwrap();
        assert_eq!(len, 11);
        assert_eq!(
            out,
            [b'A', b'B', b'A', b'B', b'A', b'B', b'A', b'B', b'A', b'B', b'A']
        );
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_lv1_against_ref() {
        extern crate std;

        let d = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let inp_fn = d.join("src/decompress.rs");

        let inp = std::fs::read(inp_fn).unwrap();

        let mut reference = crate::wasmtester::FastLZWasm::new();
        let ref_ = reference.fastlz_compress_level(1, &inp);
        std::println!("{:02x?}", &ref_[..8]);

        let out = decompress_to_vec(&ref_, None).unwrap();
        assert_eq!(inp, out);
    }

    #[test]
    fn test_lv2_manual_short_match() {
        let mut out = [0u8; 5];
        let len = decompress_to_buf(&[0x21, b'A', b'B', 0x20, 0x01], &mut out).unwrap();
        assert_eq!(len, 5);
        assert_eq!(out, [b'A', b'B', b'A', b'B', b'A']);
    }

    #[test]
    fn test_lv2_manual_long_match() {
        let mut out = [0u8; 11];
        let len = decompress_to_buf(&[0x21, b'A', b'B', 0xe0, 0x00, 0x01], &mut out).unwrap();
        assert_eq!(len, 11);
        assert_eq!(
            out,
            [b'A', b'B', b'A', b'B', b'A', b'B', b'A', b'B', b'A', b'B', b'A']
        );
    }

    #[test]
    fn test_lv2_manual_verylong_match() {
        let mut out = [0u8; 266];
        let len = decompress_to_buf(&[0x21, b'A', b'B', 0xe0, 0xff, 0x00, 0x01], &mut out).unwrap();
        assert_eq!(len, 266);
        for i in 0..(266 / 2) {
            assert_eq!(out[i * 2], b'A');
            assert_eq!(out[i * 2 + 1], b'B');
        }
    }

    #[test]
    fn test_lv2_manual_verylong_disp() {
        let mut out = [0u8; 0x2004];
        let len = decompress_to_buf(
            &[
                0x21, b'A', 0x00, 0xE0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
                0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
                0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x15, 0x00, 0x3F, 0xFF, 0x00, 0x00,
                0x00, b'Z',
            ],
            &mut out,
        )
        .unwrap();
        assert_eq!(len, 0x2004);
        for i in 0..0x2004 {
            if i == 0 {
                assert_eq!(out[i], b'A');
            } else if i == 0x2000 {
                assert_eq!(out[i], b'A');
            } else if i == 0x2003 {
                assert_eq!(out[i], b'Z');
            } else {
                assert_eq!(out[i], 0);
            }
        }
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_lv2_against_ref() {
        extern crate std;

        let d = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let inp_fn = d.join("src/decompress.rs");

        let inp = std::fs::read(inp_fn).unwrap();

        let mut reference = crate::wasmtester::FastLZWasm::new();
        let ref_ = reference.fastlz_compress_level(2, &inp);
        std::println!("{:02x?}", &ref_[..8]);

        let out = decompress_to_vec(&ref_, None).unwrap();
        assert_eq!(inp, out);
    }
}
453}