lzma_rs/decode/
stream.rs

1use crate::decode::lzbuffer::{LzBuffer, LzCircularBuffer};
2use crate::decode::lzma::{DecoderState, LzmaParams};
3use crate::decode::rangecoder::RangeDecoder;
4use crate::decompress::Options;
5use crate::error::Error;
6use std::fmt::Debug;
7use std::io::{self, BufRead, Cursor, Read, Write};
8
/// Smallest header that can be parsed: the `props` byte (`u8`) followed by
/// the `dict_size` field (`u32`).
const MIN_HEADER_LEN: usize = std::mem::size_of::<u8>() + std::mem::size_of::<u32>();

/// Largest header that can be parsed: the minimum header plus the
/// `unpacked_size` field (`u64`).
const MAX_HEADER_LEN: usize = MIN_HEADER_LEN + std::mem::size_of::<u64>();

/// Bytes that must follow the header before decoding can start: one ignored
/// byte (`u8`) and the initial `code` value (`u32`).
const START_BYTES: usize = std::mem::size_of::<u8>() + std::mem::size_of::<u32>();

/// Upper bound on the number of bytes staged while the header is still being
/// read: the largest header plus the start bytes.
const MAX_TMP_LEN: usize = MAX_HEADER_LEN + START_BYTES;
25
26/// Internal state of this streaming decoder. This is needed because we have to
27/// initialize the stream before processing any data.
28#[derive(Debug)]
29enum State<W>
30where
31    W: Write,
32{
33    /// Stream is initialized but header values have not yet been read.
34    Header(W),
35    /// Header values have been read and the stream is ready to process more data.
36    Data(RunState<W>),
37}
38
39/// Structures needed while decoding data.
40struct RunState<W>
41where
42    W: Write,
43{
44    decoder: DecoderState,
45    range: u32,
46    code: u32,
47    output: LzCircularBuffer<W>,
48}
49
50impl<W> Debug for RunState<W>
51where
52    W: Write,
53{
54    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
55        fmt.debug_struct("RunState")
56            .field("range", &self.range)
57            .field("code", &self.code)
58            .finish()
59    }
60}
61
62/// Lzma decompressor that can process multiple chunks of data using the
63/// `io::Write` interface.
64#[cfg_attr(docsrs, doc(cfg(stream)))]
65pub struct Stream<W>
66where
67    W: Write,
68{
69    /// Temporary buffer to hold data while the header is being read.
70    tmp: Cursor<[u8; MAX_TMP_LEN]>,
71    /// Whether the stream is initialized and ready to process data.
72    /// An `Option` is used to avoid interior mutability when updating the state.
73    state: Option<State<W>>,
74    /// Options given when a stream is created.
75    options: Options,
76}
77
78impl<W> Stream<W>
79where
80    W: Write,
81{
82    /// Initialize the stream. This will consume the `output` which is the sink
83    /// implementing `io::Write` that will receive decompressed bytes.
84    pub fn new(output: W) -> Self {
85        Self::new_with_options(&Options::default(), output)
86    }
87
88    /// Initialize the stream with the given `options`. This will consume the
89    /// `output` which is the sink implementing `io::Write` that will
90    /// receive decompressed bytes.
91    pub fn new_with_options(options: &Options, output: W) -> Self {
92        Self {
93            tmp: Cursor::new([0; MAX_TMP_LEN]),
94            state: Some(State::Header(output)),
95            options: *options,
96        }
97    }
98
99    /// Get a reference to the output sink
100    pub fn get_output(&self) -> Option<&W> {
101        self.state.as_ref().map(|state| match state {
102            State::Header(output) => &output,
103            State::Data(state) => state.output.get_output(),
104        })
105    }
106
107    /// Get a mutable reference to the output sink
108    pub fn get_output_mut(&mut self) -> Option<&mut W> {
109        self.state.as_mut().map(|state| match state {
110            State::Header(output) => output,
111            State::Data(state) => state.output.get_output_mut(),
112        })
113    }
114
115    /// Consumes the stream and returns the output sink. This also makes sure
116    /// we have properly reached the end of the stream.
117    pub fn finish(mut self) -> crate::error::Result<W> {
118        if let Some(state) = self.state.take() {
119            match state {
120                State::Header(output) => {
121                    if self.tmp.position() > 0 {
122                        Err(Error::LzmaError("failed to read header".to_string()))
123                    } else {
124                        Ok(output)
125                    }
126                }
127                State::Data(mut state) => {
128                    if !self.options.allow_incomplete {
129                        // Process one last time with empty input to force end of
130                        // stream checks
131                        let mut stream =
132                            Cursor::new(&self.tmp.get_ref()[0..self.tmp.position() as usize]);
133                        let mut range_decoder =
134                            RangeDecoder::from_parts(&mut stream, state.range, state.code);
135                        state
136                            .decoder
137                            .process(&mut state.output, &mut range_decoder)?;
138                    }
139                    let output = state.output.finish()?;
140                    Ok(output)
141                }
142            }
143        } else {
144            // this will occur if a call to `write()` fails
145            Err(Error::LzmaError(
146                "can't finish stream because of previous write error".to_string(),
147            ))
148        }
149    }
150
151    /// Attempts to read the header and transition into a running state.
152    ///
153    /// This function will consume the state, returning the next state on both
154    /// error and success.
155    fn read_header<R: BufRead>(
156        output: W,
157        mut input: &mut R,
158        options: &Options,
159    ) -> crate::error::Result<State<W>> {
160        match LzmaParams::read_header(&mut input, options) {
161            Ok(params) => {
162                let decoder = DecoderState::new(params.properties, params.unpacked_size);
163                let output = LzCircularBuffer::from_stream(
164                    output,
165                    params.dict_size as usize,
166                    options.memlimit.unwrap_or(usize::MAX),
167                );
168                // The RangeDecoder is only kept temporarily as we are processing
169                // chunks of data.
170                if let Ok(rangecoder) = RangeDecoder::new(&mut input) {
171                    Ok(State::Data(RunState {
172                        decoder,
173                        output,
174                        range: rangecoder.range,
175                        code: rangecoder.code,
176                    }))
177                } else {
178                    // Failed to create a RangeDecoder because we need more data,
179                    // try again later.
180                    Ok(State::Header(output.into_output()))
181                }
182            }
183            // Failed to read_header() because we need more data, try again later.
184            Err(Error::HeaderTooShort(_)) => Ok(State::Header(output)),
185            // Fatal error. Don't retry.
186            Err(e) => Err(e),
187        }
188    }
189
190    /// Process compressed data
191    fn read_data<R: BufRead>(mut state: RunState<W>, mut input: &mut R) -> io::Result<RunState<W>> {
192        // Construct our RangeDecoder from the previous range and code
193        // values.
194        let mut rangecoder = RangeDecoder::from_parts(&mut input, state.range, state.code);
195
196        // Try to process all bytes of data.
197        state
198            .decoder
199            .process_stream(&mut state.output, &mut rangecoder)
200            .map_err(|e| -> io::Error { e.into() })?;
201
202        Ok(RunState {
203            decoder: state.decoder,
204            output: state.output,
205            range: rangecoder.range,
206            code: rangecoder.code,
207        })
208    }
209}
210
211impl<W> Debug for Stream<W>
212where
213    W: Write + Debug,
214{
215    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
216        fmt.debug_struct("Stream")
217            .field("tmp", &self.tmp.position())
218            .field("state", &self.state)
219            .field("options", &self.options)
220            .finish()
221    }
222}
223
224impl<W> Write for Stream<W>
225where
226    W: Write,
227{
228    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
229        let mut input = Cursor::new(data);
230
231        if let Some(state) = self.state.take() {
232            let state = match state {
233                // Read the header values and transition into a running state.
234                State::Header(state) => {
235                    let res = if self.tmp.position() > 0 {
236                        // attempt to fill the tmp buffer
237                        let position = self.tmp.position();
238                        let bytes_read =
239                            input.read(&mut self.tmp.get_mut()[position as usize..])?;
240                        let bytes_read = if bytes_read < std::u64::MAX as usize {
241                            bytes_read as u64
242                        } else {
243                            return Err(io::Error::new(
244                                io::ErrorKind::Other,
245                                "Failed to convert integer to u64.",
246                            ));
247                        };
248                        self.tmp.set_position(position + bytes_read);
249
250                        // attempt to read the header from our tmp buffer
251                        let (position, res) = {
252                            let mut tmp_input =
253                                Cursor::new(&self.tmp.get_ref()[0..self.tmp.position() as usize]);
254                            let res = Stream::read_header(state, &mut tmp_input, &self.options);
255                            (tmp_input.position(), res)
256                        };
257
258                        // discard all bytes up to position if reading the header
259                        // was successful
260                        if let Ok(State::Data(_)) = &res {
261                            let tmp = *self.tmp.get_ref();
262                            let end = self.tmp.position();
263                            let new_len = end - position;
264                            (&mut self.tmp.get_mut()[0..new_len as usize])
265                                .copy_from_slice(&tmp[position as usize..end as usize]);
266                            self.tmp.set_position(new_len);
267                        }
268                        res
269                    } else {
270                        Stream::read_header(state, &mut input, &self.options)
271                    };
272
273                    match res {
274                        // occurs when not enough input bytes were provided to
275                        // read the entire header
276                        Ok(State::Header(val)) => {
277                            if self.tmp.position() == 0 {
278                                // reset the cursor because we may have partial reads
279                                input.set_position(0);
280                                let bytes_read = input.read(&mut self.tmp.get_mut()[..])?;
281                                let bytes_read = if bytes_read < std::u64::MAX as usize {
282                                    bytes_read as u64
283                                } else {
284                                    return Err(io::Error::new(
285                                        io::ErrorKind::Other,
286                                        "Failed to convert integer to u64.",
287                                    ));
288                                };
289                                self.tmp.set_position(bytes_read);
290                            }
291                            State::Header(val)
292                        }
293
294                        // occurs when the header was successfully read and we
295                        // move on to the next state
296                        Ok(State::Data(val)) => State::Data(val),
297
298                        // occurs when the output was consumed due to a
299                        // non-recoverable error
300                        Err(e) => {
301                            return Err(match e {
302                                Error::IoError(e) | Error::HeaderTooShort(e) => e,
303                                Error::LzmaError(e) | Error::XzError(e) => {
304                                    io::Error::new(io::ErrorKind::Other, e)
305                                }
306                            });
307                        }
308                    }
309                }
310
311                // Process another chunk of data.
312                State::Data(state) => {
313                    let state = if self.tmp.position() > 0 {
314                        let mut tmp_input =
315                            Cursor::new(&self.tmp.get_ref()[0..self.tmp.position() as usize]);
316                        let res = Stream::read_data(state, &mut tmp_input)?;
317                        self.tmp.set_position(0);
318                        res
319                    } else {
320                        state
321                    };
322                    State::Data(Stream::read_data(state, &mut input)?)
323                }
324            };
325            self.state.replace(state);
326        }
327        Ok(input.position() as usize)
328    }
329
330    /// Flushes the output sink. The internal buffer isn't flushed to avoid
331    /// corrupting the internal state. Instead, call `finish()` to finalize the
332    /// stream and flush all remaining internal data.
333    fn flush(&mut self) -> io::Result<()> {
334        if let Some(ref mut state) = self.state {
335            match state {
336                State::Header(_) => Ok(()),
337                State::Data(state) => state.output.get_output_mut().flush(),
338            }
339        } else {
340            Ok(())
341        }
342    }
343}
344
345impl From<Error> for io::Error {
346    fn from(error: Error) -> io::Error {
347        io::Error::new(io::ErrorKind::Other, format!("{:?}", error))
348    }
349}
350
#[cfg(test)]
mod test {
    use super::*;

    /// A stream that never receives input finishes with an empty sink.
    #[test]
    fn test_stream_noop() {
        let stream = Stream::new(Vec::new());
        let sink = stream.get_output().unwrap();
        assert!(sink.is_empty());

        assert!(stream.finish().unwrap().is_empty());
    }

    /// Writing empty slices is accepted and produces no output.
    #[test]
    fn test_stream_zero() {
        let mut stream = Stream::new(Vec::new());

        for _ in 0..2 {
            stream.write_all(&[]).unwrap();
        }

        let output = stream.finish().unwrap();
        assert!(output.is_empty());
    }

    /// An invalid properties byte in the header is rejected.
    #[test]
    #[should_panic(expected = "LZMA header invalid properties: 255 must be < 225")]
    fn test_bad_header() {
        let mut stream = Stream::new(Vec::new());

        stream.write_all(&[255u8; 32]).unwrap();

        let output = stream.finish().unwrap();

        assert!(output.is_empty());
    }

    /// Truncated input is reported by `finish()`.
    #[test]
    fn test_stream_incomplete() {
        let input = b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\xfb\xff\xff\xc0\x00\x00\x00";
        // Header size is 13 bytes but the first 5 bytes of data are read
        // together with it.
        let header_and_start = MAX_HEADER_LEN + START_BYTES;

        // Too few bytes to read the header: everything stays staged in `tmp`
        // and `finish()` complains about the header.
        for end in 1..header_and_start {
            let mut stream = Stream::new(Vec::new());
            stream.write_all(&input[..end]).unwrap();
            assert_eq!(stream.tmp.position(), end as u64);

            let err = stream.finish().unwrap_err();
            assert!(
                err.to_string().contains("failed to read header"),
                "error was: {}",
                err
            );
        }

        // Enough bytes for the header but not enough to terminate the
        // stream. A properly terminated stream will have a code value of 0.
        for end in header_and_start..input.len() {
            let mut stream = Stream::new(Vec::new());
            stream.write_all(&input[..end]).unwrap();

            // Header bytes will be buffered until there are enough to read.
            // `end` starts at `header_and_start`, so this check is not
            // exercised by the current range.
            if end < header_and_start {
                assert_eq!(stream.tmp.position(), end as u64);
            }

            let err = stream.finish().unwrap_err();
            assert!(err.to_string().contains("failed to fill whole buffer"));
        }
    }

    /// Every chunk size yields the same decompressed output.
    #[test]
    fn test_stream_chunked() {
        let small_input = include_bytes!("../../tests/files/small.txt");

        let mut reader = io::Cursor::new(&small_input[..]);
        let mut small_input_compressed = Vec::new();
        crate::lzma_compress(&mut reader, &mut small_input_compressed).unwrap();

        let cases: Vec<(&[u8], &[u8])> = vec![
            (b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\xfb\xff\xff\xc0\x00\x00\x00", b""),
            (&small_input_compressed[..], small_input),
        ];

        for (compressed, expected) in cases {
            for chunk_len in 1..compressed.len() {
                let mut stream = Stream::new(Vec::new());
                for piece in compressed.chunks(chunk_len) {
                    stream.write_all(piece).unwrap();
                }
                let output = stream.finish().unwrap();
                assert_eq!(expected, &output[..]);
            }
        }
    }

    /// A failed write poisons the stream so that `finish()` fails as well.
    #[test]
    fn test_stream_corrupted() {
        let mut stream = Stream::new(Vec::new());

        let write_err = stream
            .write_all(b"corrupted bytes here corrupted bytes here")
            .unwrap_err();
        assert!(write_err.to_string().contains("beyond output size"));

        let finish_err = stream.finish().unwrap_err();
        assert!(finish_err
            .to_string()
            .contains("can\'t finish stream because of previous write error"));
    }

    /// `allow_incomplete` lets `finish()` succeed on truncated input.
    #[test]
    fn test_allow_incomplete() {
        let input = include_bytes!("../../tests/files/small.txt");

        let mut reader = io::Cursor::new(&input[..]);
        let mut compressed = Vec::new();
        crate::lzma_compress(&mut reader, &mut compressed).unwrap();
        let compressed = &compressed[..compressed.len() / 2];

        // Should fail to finish() without the allow_incomplete option.
        let mut strict = Stream::new(Vec::new());
        strict.write_all(&compressed[..]).unwrap();
        strict.finish().unwrap_err();

        // Should succeed with the allow_incomplete option.
        let options = Options {
            allow_incomplete: true,
            ..Default::default()
        };
        let mut lenient = Stream::new_with_options(&options, Vec::new());
        lenient.write_all(&compressed[..]).unwrap();
        let output = lenient.finish().unwrap();
        assert_eq!(output, &input[..26]);
    }
}