lzma_rust2/
lzma2_reader.rs

1use super::{
2    decoder::LZMADecoder,
3    lz::LZDecoder,
4    range_dec::{RangeDecoder, RangeDecoderBuffer},
5};
6use byteorder::{self, BigEndian, ReadBytesExt};
7use std::io::{ErrorKind, Read};
8
/// Upper bound, in bytes, on the compressed payload of a single LZMA2 chunk.
///
/// The chunk header stores the compressed size as a 16-bit value plus one, so the
/// largest possible payload is `0xFFFF + 1 == 1 << 16` bytes. The range decoder's
/// input buffer is allocated with this capacity.
pub const COMPRESSED_SIZE_MAX: u32 = 1 << 16;
10
/// Decompresses a raw LZMA2 stream (no XZ headers).
///
/// The reader parses the stream chunk by chunk: each chunk starts with a control
/// byte that says whether the payload is LZMA-compressed or stored uncompressed,
/// and a control byte of `0x00` marks the end of the stream.
///
/// Once a read fails, the error is remembered and reported again by every later
/// read into a non-empty buffer.
///
/// # Examples
/// ```
/// use std::io::Read;
/// use lzma_rust2::LZMA2Reader;
/// use lzma_rust2::LZMA2Options;
///
/// let compressed: Vec<u8> = vec![1, 0, 12, 72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33, 0];
/// let mut reader = LZMA2Reader::new(compressed.as_slice(), LZMA2Options::DICT_SIZE_DEFAULT, None);
/// let mut decompressed = Vec::new();
/// reader.read_to_end(&mut decompressed).unwrap();
/// assert_eq!(&decompressed[..], b"Hello, world!");
/// ```
pub struct LZMA2Reader<R> {
    /// Underlying reader that supplies the compressed bytes.
    inner: R,
    /// Dictionary / output window; decoded bytes are flushed from here into the
    /// caller's buffer.
    lz: LZDecoder,
    /// Range decoder, loaded with the compressed payload of each LZMA chunk.
    rc: RangeDecoder<RangeDecoderBuffer>,
    /// LZMA decoder; `None` until the first properties byte has been decoded.
    lzma: Option<LZMADecoder>,
    /// Number of uncompressed bytes of the current chunk not yet handed to the caller.
    uncompressed_size: usize,
    /// `true` when the current chunk is LZMA-compressed, `false` when it is stored
    /// uncompressed.
    is_lzma_chunk: bool,
    /// `true` until a chunk that resets the dictionary has been seen. Starts as
    /// `false` when a non-empty preset dictionary was supplied.
    need_dict_reset: bool,
    /// `true` when the next LZMA chunk must carry a new properties byte.
    need_props: bool,
    /// Set once the end-of-stream control byte (`0x00`) has been read.
    end_reached: bool,
    /// The error from a failed `read`, kept so later reads can report it again.
    error: Option<std::io::Error>,
}
36
37#[inline]
38pub fn get_memory_usage(dict_size: u32) -> u32 {
39    40 + COMPRESSED_SIZE_MAX / 1024 + get_dict_size(dict_size) / 1024
40}
41
/// Rounds `dict_size` up to the next multiple of 16 bytes.
///
/// Sizes within 15 bytes of `u32::MAX` cannot be rounded up inside a `u32`. For
/// those the addition saturates and the result is the largest representable
/// multiple of 16 (`0xFFFF_FFF0`). A plain `+` would panic in debug builds and wrap
/// around to a near-zero dictionary in release builds.
#[inline]
fn get_dict_size(dict_size: u32) -> u32 {
    dict_size.saturating_add(15) & !15
}
46
47impl<R> LZMA2Reader<R> {
48    pub fn into_inner(self) -> R {
49        self.inner
50    }
51
52    pub fn get_ref(&self) -> &R {
53        &self.inner
54    }
55
56    pub fn get_mut(&mut self) -> &mut R {
57        &mut self.inner
58    }
59}
60
61impl<R: Read> LZMA2Reader<R> {
62    /// Create a new LZMA2 reader.
63    /// `inner` is the reader to read compressed data from.
64    /// `dict_size` is the dictionary size in bytes.
65    pub fn new(inner: R, dict_size: u32, preset_dict: Option<&[u8]>) -> Self {
66        let has_preset = preset_dict.as_ref().map(|a| !a.is_empty()).unwrap_or(false);
67        let lz = LZDecoder::new(get_dict_size(dict_size) as _, preset_dict);
68        let rc = RangeDecoder::new_buffer(COMPRESSED_SIZE_MAX as _);
69        Self {
70            inner,
71            lz,
72            rc,
73            lzma: None,
74            uncompressed_size: 0,
75            is_lzma_chunk: false,
76            need_dict_reset: !has_preset,
77            need_props: true,
78            end_reached: false,
79            error: None,
80        }
81    }
82
83    fn decode_chunk_header(&mut self) -> std::io::Result<()> {
84        let control = self.inner.read_u8()?;
85        if control == 0x00 {
86            self.end_reached = true;
87            return Ok(());
88        }
89
90        if control >= 0xE0 || control == 0x01 {
91            self.need_props = true;
92            self.need_dict_reset = false;
93            self.lz.reset();
94        } else if self.need_dict_reset {
95            return Err(std::io::Error::new(
96                ErrorKind::InvalidInput,
97                "Corrupted input data (LZMA2:0)",
98            ));
99        }
100        if control >= 0x80 {
101            self.is_lzma_chunk = true;
102            self.uncompressed_size = ((control & 0x1F) as usize) << 16;
103            self.uncompressed_size += self.inner.read_u16::<BigEndian>()? as usize + 1;
104            let compressed_size = self.inner.read_u16::<BigEndian>()? as usize + 1;
105            if control >= 0xC0 {
106                self.need_props = false;
107                self.decode_props()?;
108            } else if self.need_props {
109                return Err(std::io::Error::new(
110                    ErrorKind::InvalidInput,
111                    "Corrupted input data (LZMA2:1)",
112                ));
113            } else if control >= 0xA0 {
114                if let Some(l) = self.lzma.as_mut() {
115                    l.reset()
116                }
117            }
118            self.rc.prepare(&mut self.inner, compressed_size)?;
119        } else if control > 0x02 {
120            return Err(std::io::Error::new(
121                ErrorKind::InvalidInput,
122                "Corrupted input data (LZMA2:2)",
123            ));
124        } else {
125            self.is_lzma_chunk = false;
126            self.uncompressed_size = (self.inner.read_u16::<BigEndian>()? + 1) as _;
127        }
128        Ok(())
129    }
130
131    fn decode_props(&mut self) -> std::io::Result<()> {
132        let props = self.inner.read_u8()?;
133        if props > (4 * 5 + 4) * 9 + 8 {
134            return Err(std::io::Error::new(
135                ErrorKind::InvalidInput,
136                "Corrupted input data (LZMA2:3)",
137            ));
138        }
139        let pb = props / (9 * 5);
140        let props = props - pb * 9 * 5;
141        let lp = props / 9;
142        let lc = props - lp * 9;
143        if lc + lp > 4 {
144            return Err(std::io::Error::new(
145                ErrorKind::InvalidInput,
146                "Corrupted input data (LZMA2:4)",
147            ));
148        }
149        self.lzma = Some(LZMADecoder::new(lc as _, lp as _, pb as _));
150
151        Ok(())
152    }
153
154    fn read_decode(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
155        if buf.is_empty() {
156            return Ok(0);
157        }
158        if let Some(e) = &self.error {
159            return Err(std::io::Error::new(e.kind(), e.to_string()));
160        }
161
162        if self.end_reached {
163            return Ok(0);
164        }
165        let mut size = 0;
166        let mut len = buf.len();
167        let mut off = 0;
168        while len > 0 {
169            if self.uncompressed_size == 0 {
170                self.decode_chunk_header()?;
171                if self.end_reached {
172                    return Ok(size);
173                }
174            }
175
176            let copy_size_max = self.uncompressed_size.min(len);
177            if !self.is_lzma_chunk {
178                self.lz.copy_uncompressed(&mut self.inner, copy_size_max)?;
179            } else {
180                self.lz.set_limit(copy_size_max);
181                if let Some(lzma) = self.lzma.as_mut() {
182                    lzma.decode(&mut self.lz, &mut self.rc)?;
183                }
184            }
185
186            {
187                let copied_size = self.lz.flush(buf, off);
188                off += copied_size;
189                len -= copied_size;
190                size += copied_size;
191                self.uncompressed_size -= copied_size;
192                if self.uncompressed_size == 0 && (!self.rc.is_finished() || self.lz.has_pending())
193                {
194                    return Err(std::io::Error::new(
195                        ErrorKind::InvalidInput,
196                        "rc not finished or lz has pending",
197                    ));
198                }
199            }
200        }
201        Ok(size)
202    }
203}
204
205impl<R: Read> Read for LZMA2Reader<R> {
206    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
207        match self.read_decode(buf) {
208            Ok(size) => Ok(size),
209            Err(e) => {
210                let error = std::io::Error::new(e.kind(), e.to_string());
211                self.error = Some(e);
212                Err(error)
213            }
214        }
215    }
216}