lzma_rust2/xz/
reader.rs

1use alloc::boxed::Box;
2
3use super::{
4    BlockHeader, ChecksumCalculator, FilterType, Index, StreamFooter, StreamHeader, XZ_MAGIC,
5};
6use crate::{
7    error_invalid_data,
8    filter::{bcj::BcjReader, delta::DeltaReader},
9    CountingReader, Lzma2Reader, Read, Result,
10};
11
// NOTE(review): the lint is silenced without a stated reason; presumably the size
// difference between variants is accepted because nested readers are already boxed —
// confirm.
#[allow(clippy::large_enum_variant)]
/// The reader stack used while decoding: either the bare counting reader or a
/// filter decoder layered on top of another (boxed) `FilterReader`.
enum FilterReader<R: Read> {
    /// Bottom of every chain: wraps the underlying reader `R` and exposes a
    /// `bytes_read()` count.
    Counting(CountingReader<R>),
    /// LZMA2 decoder reading from the next layer down.
    Lzma2(Lzma2Reader<Box<FilterReader<R>>>),
    /// Delta filter decoder reading from the next layer down.
    Delta(DeltaReader<Box<FilterReader<R>>>),
    /// BCJ (branch/call/jump) filter decoder reading from the next layer down.
    Bcj(BcjReader<Box<FilterReader<R>>>),
    /// Placeholder left behind by `core::mem::replace` while the real reader is
    /// temporarily moved out. Every method hits `unimplemented!()` on it.
    Dummy,
}
20
21impl<R: Read> Read for FilterReader<R> {
22    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
23        match self {
24            FilterReader::Counting(reader) => reader.read(buf),
25            FilterReader::Lzma2(reader) => reader.read(buf),
26            FilterReader::Delta(reader) => reader.read(buf),
27            FilterReader::Bcj(reader) => reader.read(buf),
28            FilterReader::Dummy => unimplemented!(),
29        }
30    }
31}
32
33impl<R: Read> FilterReader<R> {
34    fn create_filter_chain(inner: R, filters: &[Option<FilterType>], properties: &[u32]) -> Self {
35        let mut chain_reader = FilterReader::Counting(CountingReader::new(inner));
36
37        for (filter, property) in filters
38            .iter()
39            .copied()
40            .zip(properties)
41            .filter_map(|(filter, property)| filter.map(|filter| (filter, *property)))
42            .rev()
43        {
44            chain_reader = match filter {
45                FilterType::Delta => {
46                    let distance = property as usize;
47                    FilterReader::Delta(DeltaReader::new(Box::new(chain_reader), distance))
48                }
49                FilterType::BcjX86 => {
50                    let start_offset = property as usize;
51                    FilterReader::Bcj(BcjReader::new_x86(Box::new(chain_reader), start_offset))
52                }
53                FilterType::BcjPpc => {
54                    let start_offset = property as usize;
55                    FilterReader::Bcj(BcjReader::new_ppc(Box::new(chain_reader), start_offset))
56                }
57                FilterType::BcjIa64 => {
58                    let start_offset = property as usize;
59                    FilterReader::Bcj(BcjReader::new_ia64(Box::new(chain_reader), start_offset))
60                }
61                FilterType::BcjArm => {
62                    let start_offset = property as usize;
63                    FilterReader::Bcj(BcjReader::new_arm(Box::new(chain_reader), start_offset))
64                }
65                FilterType::BcjArmThumb => {
66                    let start_offset = property as usize;
67                    FilterReader::Bcj(BcjReader::new_arm_thumb(
68                        Box::new(chain_reader),
69                        start_offset,
70                    ))
71                }
72                FilterType::BcjSparc => {
73                    let start_offset = property as usize;
74                    FilterReader::Bcj(BcjReader::new_sparc(Box::new(chain_reader), start_offset))
75                }
76                FilterType::BcjArm64 => {
77                    let start_offset = property as usize;
78                    FilterReader::Bcj(BcjReader::new_arm64(Box::new(chain_reader), start_offset))
79                }
80                FilterType::BcjRiscv => {
81                    let start_offset = property as usize;
82                    FilterReader::Bcj(BcjReader::new_riscv(Box::new(chain_reader), start_offset))
83                }
84                FilterType::Lzma2 => {
85                    let dict_size = property;
86                    FilterReader::Lzma2(Lzma2Reader::new(Box::new(chain_reader), dict_size, None))
87                }
88            };
89        }
90
91        chain_reader
92    }
93
94    fn bytes_read(&self) -> u64 {
95        match self {
96            FilterReader::Counting(reader) => reader.bytes_read(),
97            FilterReader::Lzma2(reader) => reader.inner().bytes_read(),
98            FilterReader::Delta(reader) => reader.inner().bytes_read(),
99            FilterReader::Bcj(reader) => reader.inner().bytes_read(),
100            FilterReader::Dummy => unimplemented!(),
101        }
102    }
103
104    fn into_inner(self) -> R {
105        match self {
106            FilterReader::Counting(reader) => reader.inner,
107            FilterReader::Lzma2(reader) => {
108                let filter_reader = reader.into_inner();
109                filter_reader.into_inner()
110            }
111            FilterReader::Delta(reader) => {
112                let filter_reader = reader.into_inner();
113                filter_reader.into_inner()
114            }
115            FilterReader::Bcj(reader) => {
116                let filter_reader = reader.into_inner();
117                filter_reader.into_inner()
118            }
119            FilterReader::Dummy => unimplemented!(),
120        }
121    }
122
123    fn inner(&self) -> &R {
124        match self {
125            FilterReader::Counting(reader) => &reader.inner,
126            FilterReader::Lzma2(reader) => {
127                let filter_reader = reader.inner();
128
129                filter_reader.inner()
130            }
131            FilterReader::Delta(reader) => {
132                let filter_reader = reader.inner();
133                filter_reader.inner()
134            }
135            FilterReader::Bcj(reader) => {
136                let filter_reader = reader.inner();
137                filter_reader.inner()
138            }
139            FilterReader::Dummy => unimplemented!(),
140        }
141    }
142
143    fn inner_mut(&mut self) -> &mut R {
144        match self {
145            FilterReader::Counting(reader) => &mut reader.inner,
146            FilterReader::Lzma2(reader) => {
147                let filter_reader = reader.inner_mut();
148                filter_reader.inner_mut()
149            }
150            FilterReader::Delta(reader) => {
151                let filter_reader = reader.inner_mut();
152                filter_reader.inner_mut()
153            }
154            FilterReader::Bcj(reader) => {
155                let filter_reader = reader.inner_mut();
156                filter_reader.inner_mut()
157            }
158            FilterReader::Dummy => unimplemented!(),
159        }
160    }
161}
162
/// A single-threaded XZ decompressor.
///
/// Decoding is driven lazily from `read`: the stream header is parsed on the
/// first call, and each block's filter chain is set up when the previous block
/// has been fully consumed.
pub struct XzReader<R: Read> {
    /// Current source. A plain counting reader between blocks, and the block's
    /// filter chain while a block is being decoded.
    reader: FilterReader<R>,
    /// Header of the stream currently being decoded; `None` until the first read.
    stream_header: Option<StreamHeader>,
    /// Running checksum of the current block's decompressed data. `Some` exactly
    /// while a block is in progress; taken when the block checksum is verified.
    checksum_calculator: Option<ChecksumCalculator>,
    /// Set once the index and footer have been parsed and no further stream
    /// follows (or is allowed); `read` then returns `Ok(0)`.
    finished: bool,
    /// Whether to look for another stream after a stream's footer.
    allow_multiple_streams: bool,
    /// Number of blocks decoded in the current stream; checked against the
    /// index record count and reset when a new stream starts.
    blocks_processed: u64,
}
172
173impl<R: Read> XzReader<R> {
174    /// Create a new [`XzReader`].
175    pub fn new(inner: R, allow_multiple_streams: bool) -> Self {
176        let reader = FilterReader::Counting(CountingReader::new(inner));
177
178        Self {
179            reader,
180            stream_header: None,
181            checksum_calculator: None,
182            finished: false,
183            allow_multiple_streams,
184            blocks_processed: 0,
185        }
186    }
187
188    /// Consume the XzReader and return the inner reader.
189    pub fn into_inner(self) -> R {
190        self.reader.into_inner()
191    }
192
193    /// Returns a reference to the inner reader.
194    pub fn inner(&self) -> &R {
195        self.reader.inner()
196    }
197
198    /// Returns a mutable reference to the inner reader.
199    pub fn inner_mut(&mut self) -> &mut R {
200        self.reader.inner_mut()
201    }
202}
203
204impl<R: Read> XzReader<R> {
205    fn ensure_stream_header(&mut self) -> Result<()> {
206        if self.stream_header.is_none() {
207            let header = StreamHeader::parse(&mut self.reader)?;
208            self.stream_header = Some(header);
209        }
210        Ok(())
211    }
212
213    fn prepare_next_block(&mut self) -> Result<bool> {
214        match BlockHeader::parse(&mut self.reader)? {
215            Some(block_header) => {
216                let base_reader: FilterReader<R> =
217                    core::mem::replace(&mut self.reader, FilterReader::Dummy);
218
219                self.reader = FilterReader::create_filter_chain(
220                    base_reader.into_inner(),
221                    &block_header.filters,
222                    &block_header.properties,
223                );
224
225                match self.stream_header.as_ref() {
226                    Some(header) => {
227                        self.checksum_calculator = Some(ChecksumCalculator::new(header.check_type));
228                    }
229                    None => {
230                        panic!("stream_header not set");
231                    }
232                }
233
234                self.blocks_processed += 1;
235
236                Ok(true)
237            }
238            None => {
239                // End of blocks reached, index follows.
240                self.parse_index_and_footer()?;
241
242                if self.allow_multiple_streams && self.try_start_next_stream()? {
243                    return self.prepare_next_block();
244                }
245
246                self.finished = true;
247                Ok(false)
248            }
249        }
250    }
251
252    fn consume_padding(&mut self, compressed_bytes: u64) -> Result<()> {
253        let padding_needed = match (4 - (compressed_bytes % 4)) % 4 {
254            0 => return Ok(()),
255            n => n as usize,
256        };
257
258        let mut padding_buf = [0u8; 3];
259
260        let bytes_read = self.reader.read(&mut padding_buf[..padding_needed])?;
261
262        if bytes_read != padding_needed {
263            return Err(error_invalid_data("incomplete XZ block padding"));
264        }
265
266        if !padding_buf[..bytes_read].iter().all(|&byte| byte == 0) {
267            return Err(error_invalid_data("invalid XZ block padding"));
268        }
269
270        Ok(())
271    }
272
273    fn verify_block_checksum(&mut self) -> Result<()> {
274        let checksum_calculator = self
275            .checksum_calculator
276            .take()
277            .expect("checksum_calculator not set");
278
279        match checksum_calculator {
280            ChecksumCalculator::None => { /* Nothing to check */ }
281            ChecksumCalculator::Crc32(_) => {
282                let mut checksum = [0u8; 4];
283                self.reader.read_exact(&mut checksum)?;
284
285                if !checksum_calculator.verify(&checksum) {
286                    return Err(error_invalid_data("invalid block checksum"));
287                }
288            }
289            ChecksumCalculator::Crc64(_) => {
290                let mut checksum = [0u8; 8];
291                self.reader.read_exact(&mut checksum)?;
292
293                if !checksum_calculator.verify(&checksum) {
294                    return Err(error_invalid_data("invalid block checksum"));
295                }
296            }
297            ChecksumCalculator::Sha256(_) => {
298                let mut checksum = [0u8; 32];
299                self.reader.read_exact(&mut checksum)?;
300
301                if !checksum_calculator.verify(&checksum) {
302                    return Err(error_invalid_data("invalid block checksum"));
303                }
304            }
305        }
306
307        Ok(())
308    }
309
310    /// Look for the start of the next stream by reading bytes one at a time
311    /// and checking for the XZ magic sequence, allowing for stream padding.
312    fn try_start_next_stream(&mut self) -> Result<bool> {
313        let mut padding_bytes = 0;
314        let mut buffer = [0u8; 6];
315
316        loop {
317            let mut byte_buffer = [0u8; 1];
318            let read = self.reader.read(&mut byte_buffer)?;
319            if read == 0 {
320                // EOF reached, no more streams.
321                return Ok(false);
322            }
323
324            let byte = byte_buffer[0];
325
326            if byte == 0 {
327                // Potential stream padding.
328                padding_bytes += 1;
329                continue;
330            }
331
332            // Non-zero byte found - check if it starts XZ magic.
333            if byte != XZ_MAGIC[0] {
334                return Err(error_invalid_data("invalid data after stream"));
335            }
336
337            buffer[0] = byte;
338            let mut buffer_pos = 1;
339
340            // Read the rest of the magic bytes.
341            while buffer_pos < 6 {
342                match self.reader.read(&mut byte_buffer)? {
343                    0 => {
344                        return Err(error_invalid_data("incomplete XZ magic bytes"));
345                    }
346                    1 => {
347                        buffer[buffer_pos] = byte_buffer[0];
348                        buffer_pos += 1;
349                    }
350                    _ => unreachable!(),
351                }
352            }
353
354            if buffer != XZ_MAGIC {
355                return Err(error_invalid_data("invalid data after stream padding"));
356            }
357
358            if padding_bytes % 4 != 0 {
359                return Err(error_invalid_data("stream padding size not multiple of 4"));
360            }
361
362            let stream_header = StreamHeader::parse_stream_header_flags_and_crc(&mut self.reader)?;
363
364            // Reset state for new stream.
365            self.stream_header = Some(stream_header);
366            self.blocks_processed = 0;
367
368            return Ok(true);
369        }
370    }
371
372    fn parse_index_and_footer(&mut self) -> Result<()> {
373        let index = Index::parse(&mut self.reader)?;
374
375        if index.number_of_records != self.blocks_processed {
376            return Err(error_invalid_data(
377                "number of blocks processed doesn't match index records",
378            ));
379        }
380
381        let stream_footer = StreamFooter::parse(&mut self.reader)?;
382
383        let header = self.stream_header.as_ref().expect("stream_header not set");
384
385        let header_flags = [0, header.check_type as u8];
386        if stream_footer.stream_flags != header_flags {
387            return Err(error_invalid_data(
388                "stream header and footer flags mismatch",
389            ));
390        }
391
392        Ok(())
393    }
394}
395
396impl<R: Read> Read for XzReader<R> {
397    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
398        if self.finished {
399            return Ok(0);
400        }
401
402        self.ensure_stream_header()?;
403
404        loop {
405            if self.checksum_calculator.is_some() {
406                let bytes_read = self.reader.read(buf)?;
407
408                if bytes_read > 0 {
409                    if let Some(ref mut calc) = self.checksum_calculator {
410                        calc.update(&buf[..bytes_read]);
411                    }
412
413                    return Ok(bytes_read);
414                } else {
415                    let reader = core::mem::replace(&mut self.reader, FilterReader::Dummy);
416                    let compressed_bytes = reader.bytes_read();
417                    self.reader = FilterReader::Counting(CountingReader::with_count(
418                        reader.into_inner(),
419                        compressed_bytes,
420                    ));
421
422                    self.consume_padding(compressed_bytes)?;
423                    self.verify_block_checksum()?;
424                }
425            } else {
426                // No current block, prepare the next one.
427                if !self.prepare_next_block()? {
428                    // No more blocks, we're done.
429                    return Ok(0);
430                }
431            }
432        }
433    }
434}