lzma_rust2/xz/
writer.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::num::NonZeroU64;
3
4use super::{
5    add_padding, write_xz_block_header, write_xz_index, write_xz_stream_footer,
6    write_xz_stream_header, CheckType, ChecksumCalculator, FilterConfig, FilterType, IndexRecord,
7};
8use crate::{
9    enc::{Lzma2Writer, LzmaOptions},
10    error_invalid_data, error_invalid_input,
11    filter::{bcj::BcjWriter, delta::DeltaWriter},
12    AutoFinish, AutoFinisher, CountingWriter, Lzma2Options, Result, Write,
13};
14
15#[allow(clippy::large_enum_variant)]
16enum FilterWriter<W: Write> {
17    Counting(CountingWriter<W>),
18    Lzma2(Lzma2Writer<Box<FilterWriter<W>>>),
19    Delta(DeltaWriter<Box<FilterWriter<W>>>),
20    Bcj(BcjWriter<Box<FilterWriter<W>>>),
21    Dummy,
22}
23
24impl<W: Write> Write for FilterWriter<W> {
25    fn write(&mut self, buf: &[u8]) -> Result<usize> {
26        match self {
27            FilterWriter::Counting(writer) => writer.write(buf),
28            FilterWriter::Lzma2(writer) => writer.write(buf),
29            FilterWriter::Delta(writer) => writer.write(buf),
30            FilterWriter::Bcj(writer) => writer.write(buf),
31            FilterWriter::Dummy => unimplemented!(),
32        }
33    }
34
35    fn flush(&mut self) -> Result<()> {
36        match self {
37            FilterWriter::Counting(writer) => writer.flush(),
38            FilterWriter::Lzma2(writer) => writer.flush(),
39            FilterWriter::Delta(writer) => writer.flush(),
40            FilterWriter::Bcj(writer) => writer.flush(),
41            FilterWriter::Dummy => unimplemented!(),
42        }
43    }
44}
45
46impl<W: Write> FilterWriter<W> {
47    fn create_filter_chain(
48        inner: CountingWriter<W>,
49        filters: &[FilterConfig],
50        lzma_options: &LzmaOptions,
51    ) -> Result<Self> {
52        let mut chain_writer = FilterWriter::Counting(inner);
53
54        for filter_config in filters.iter().rev() {
55            chain_writer = match filter_config.filter_type {
56                FilterType::Delta => {
57                    let distance = filter_config.property as usize;
58                    FilterWriter::Delta(DeltaWriter::new(Box::new(chain_writer), distance))
59                }
60                FilterType::BcjX86 => {
61                    let start_offset = filter_config.property as usize;
62                    FilterWriter::Bcj(BcjWriter::new_x86(Box::new(chain_writer), start_offset))
63                }
64                FilterType::BcjPpc => {
65                    let start_offset = filter_config.property as usize;
66                    FilterWriter::Bcj(BcjWriter::new_ppc(Box::new(chain_writer), start_offset))
67                }
68                FilterType::BcjIa64 => {
69                    let start_offset = filter_config.property as usize;
70                    FilterWriter::Bcj(BcjWriter::new_ia64(Box::new(chain_writer), start_offset))
71                }
72                FilterType::BcjArm => {
73                    let start_offset = filter_config.property as usize;
74                    FilterWriter::Bcj(BcjWriter::new_arm(Box::new(chain_writer), start_offset))
75                }
76                FilterType::BcjArmThumb => {
77                    let start_offset = filter_config.property as usize;
78                    FilterWriter::Bcj(BcjWriter::new_arm_thumb(
79                        Box::new(chain_writer),
80                        start_offset,
81                    ))
82                }
83                FilterType::BcjSparc => {
84                    let start_offset = filter_config.property as usize;
85                    FilterWriter::Bcj(BcjWriter::new_sparc(Box::new(chain_writer), start_offset))
86                }
87                FilterType::BcjArm64 => {
88                    let start_offset = filter_config.property as usize;
89                    FilterWriter::Bcj(BcjWriter::new_arm64(Box::new(chain_writer), start_offset))
90                }
91                FilterType::BcjRiscv => {
92                    let start_offset = filter_config.property as usize;
93                    FilterWriter::Bcj(BcjWriter::new_riscv(Box::new(chain_writer), start_offset))
94                }
95                FilterType::Lzma2 => {
96                    let options = Lzma2Options {
97                        lzma_options: lzma_options.clone(),
98                        ..Default::default()
99                    };
100                    FilterWriter::Lzma2(Lzma2Writer::new(Box::new(chain_writer), options))
101                }
102            };
103        }
104
105        Ok(chain_writer)
106    }
107
108    fn into_inner(self) -> W {
109        match self {
110            FilterWriter::Counting(writer) => writer.inner,
111            FilterWriter::Lzma2(writer) => {
112                let filter_writer = writer.into_inner();
113                filter_writer.into_inner()
114            }
115            FilterWriter::Delta(writer) => {
116                let filter_writer = writer.into_inner();
117                filter_writer.into_inner()
118            }
119            FilterWriter::Bcj(writer) => {
120                let filter_writer = writer.into_inner();
121                filter_writer.into_inner()
122            }
123            FilterWriter::Dummy => unimplemented!(),
124        }
125    }
126
127    fn inner(&self) -> &W {
128        match self {
129            FilterWriter::Counting(writer) => &writer.inner,
130            FilterWriter::Lzma2(writer) => {
131                let filter_writer = writer.inner();
132                filter_writer.inner()
133            }
134            FilterWriter::Delta(writer) => {
135                let filter_writer = writer.inner();
136                filter_writer.inner()
137            }
138            FilterWriter::Bcj(writer) => {
139                let filter_writer = writer.inner();
140                filter_writer.inner()
141            }
142            FilterWriter::Dummy => unimplemented!(),
143        }
144    }
145
146    fn inner_mut(&mut self) -> &mut W {
147        match self {
148            FilterWriter::Counting(writer) => &mut writer.inner,
149            FilterWriter::Lzma2(writer) => {
150                let filter_writer = writer.inner_mut();
151                filter_writer.inner_mut()
152            }
153            FilterWriter::Delta(writer) => {
154                let filter_writer = writer.inner_mut();
155                filter_writer.inner_mut()
156            }
157            FilterWriter::Bcj(writer) => {
158                let filter_writer = writer.inner_mut();
159                filter_writer.inner_mut()
160            }
161            FilterWriter::Dummy => unimplemented!(),
162        }
163    }
164
165    fn finish(self) -> Result<CountingWriter<W>> {
166        match self {
167            FilterWriter::Counting(writer) => Ok(writer),
168            FilterWriter::Lzma2(writer) => {
169                let inner_writer = writer.finish()?;
170                inner_writer.finish()
171            }
172            FilterWriter::Delta(writer) => {
173                let inner_writer = writer.into_inner();
174                inner_writer.finish()
175            }
176            FilterWriter::Bcj(writer) => {
177                let inner_writer = writer.finish()?;
178                inner_writer.finish()
179            }
180            FilterWriter::Dummy => unimplemented!(),
181        }
182    }
183}
184
185/// Configuration options for XZ compression.
186#[derive(Default, Debug, Clone)]
187pub struct XzOptions {
188    /// LZMA compression options.
189    pub lzma_options: LzmaOptions,
190    /// Checksum type to use.
191    pub check_type: CheckType,
192    /// Maximum uncompressed size for each block (None = single block).
193    /// Will get clamped to be at least the dict size to not waste memory.
194    pub block_size: Option<NonZeroU64>,
195    /// Pre-filter to use (at most 3).
196    pub filters: Vec<FilterConfig>,
197}
198
199impl XzOptions {
200    /// Create options with specific preset and checksum type.
201    pub fn with_preset(preset: u32) -> Self {
202        Self {
203            lzma_options: LzmaOptions::with_preset(preset),
204            check_type: CheckType::Crc64,
205            block_size: None,
206            filters: Vec::new(),
207        }
208    }
209
210    /// Set the checksum type to use (Default is CRC64).
211    pub fn set_check_sum_type(&mut self, check_type: CheckType) {
212        self.check_type = check_type;
213    }
214
215    /// Set the maximum block size (None means a single block, which is the default).
216    pub fn set_block_size(&mut self, block_size: Option<NonZeroU64>) {
217        self.block_size = block_size;
218    }
219
220    /// Prepend a filter to the chain. You can prepend at most 3 additional filter.
221    pub fn prepend_pre_filter(&mut self, filter_type: FilterType, property: u32) {
222        self.filters.insert(
223            0,
224            FilterConfig {
225                filter_type,
226                property,
227            },
228        );
229    }
230}
231
232/// A single-threaded XZ compressor.
233pub struct XzWriter<W: Write> {
234    writer: FilterWriter<W>,
235    options: XzOptions,
236    index_records: Vec<IndexRecord>,
237    block_uncompressed_size: u64,
238    checksum_calculator: ChecksumCalculator,
239    header_written: bool,
240    finished: bool,
241    total_uncompressed_pos: u64,
242    current_block_start_pos: u64,
243    current_block_header_size: u64,
244}
245
246impl<W: Write> XzWriter<W> {
247    /// Create a new XZ writer with the given options.
248    pub fn new(inner: W, options: XzOptions) -> Result<Self> {
249        let mut options = options;
250
251        if options.filters.len() > 3 {
252            return Err(error_invalid_input(
253                "XZ allows only at most 3 pre-filters plus LZMA2",
254            ));
255        }
256
257        if let Some(block_size) = options.block_size.as_mut() {
258            *block_size =
259                NonZeroU64::new(block_size.get().max(options.lzma_options.dict_size as u64))
260                    .expect("block size is zero");
261        }
262
263        // Last filter is always LZMA2.
264        options.filters.push(FilterConfig {
265            filter_type: FilterType::Lzma2,
266            property: 0,
267        });
268
269        let checksum_calculator = ChecksumCalculator::new(options.check_type);
270        let writer = FilterWriter::Counting(CountingWriter::new(inner));
271
272        Ok(Self {
273            writer,
274            options,
275            index_records: Vec::new(),
276            block_uncompressed_size: 0,
277            checksum_calculator,
278            header_written: false,
279            finished: false,
280            total_uncompressed_pos: 0,
281            current_block_start_pos: 0,
282            current_block_header_size: 0,
283        })
284    }
285
286    /// Returns a wrapper around `self` that will finish the stream on drop.
287    pub fn auto_finish(self) -> AutoFinisher<Self> {
288        AutoFinisher(Some(self))
289    }
290
291    /// Consume the XzWriter and return the inner writer.
292    pub fn into_inner(self) -> W {
293        self.writer.into_inner()
294    }
295
296    /// Returns a reference to the inner writer.
297    pub fn inner(&self) -> &W {
298        self.writer.inner()
299    }
300
301    /// Returns a mutable reference to the inner writer.
302    pub fn inner_mut(&mut self) -> &mut W {
303        self.writer.inner_mut()
304    }
305
306    fn write_stream_header(&mut self) -> Result<()> {
307        if self.header_written {
308            return Ok(());
309        }
310
311        write_xz_stream_header(&mut self.writer, self.options.check_type)?;
312
313        self.header_written = true;
314
315        Ok(())
316    }
317
318    fn prepare_next_block(&mut self) -> Result<()> {
319        let writer = core::mem::replace(&mut self.writer, FilterWriter::Dummy);
320        let counting_writer = writer.finish()?;
321        self.writer = FilterWriter::Counting(counting_writer);
322
323        self.current_block_header_size = write_xz_block_header(
324            &mut self.writer,
325            &self.options.filters,
326            self.options.lzma_options.dict_size,
327        )?;
328
329        let writer = core::mem::replace(&mut self.writer, FilterWriter::Dummy);
330        let counting_writer = writer.finish()?;
331        let bytes_written = counting_writer.bytes_written();
332        self.current_block_start_pos = bytes_written;
333
334        self.writer = FilterWriter::create_filter_chain(
335            counting_writer,
336            &self.options.filters,
337            &self.options.lzma_options,
338        )?;
339
340        self.block_uncompressed_size = 0;
341
342        Ok(())
343    }
344
345    fn should_finish_block(&self) -> bool {
346        if let Some(block_size) = self.options.block_size {
347            self.block_uncompressed_size >= block_size.get()
348        } else {
349            false
350        }
351    }
352
353    fn finish_current_block(&mut self) -> Result<()> {
354        // Finish the filter chain and get back to the counting writer.
355        let writer = core::mem::replace(&mut self.writer, FilterWriter::Dummy);
356        let counting_writer = writer.finish()?;
357        let bytes_written = counting_writer.bytes_written();
358        self.writer = FilterWriter::Counting(counting_writer);
359
360        let block_compressed_size = bytes_written - self.current_block_start_pos;
361
362        let data_size = block_compressed_size;
363        let padding_needed = (4 - (data_size % 4)) % 4;
364
365        add_padding(&mut self.writer, padding_needed as usize)?;
366
367        self.write_block_checksum()?;
368
369        let unpadded_size = self.current_block_header_size
370            + block_compressed_size
371            + self.options.check_type.checksum_size();
372
373        self.index_records.push(IndexRecord {
374            unpadded_size,
375            uncompressed_size: self.block_uncompressed_size,
376        });
377
378        self.block_uncompressed_size = 0;
379
380        Ok(())
381    }
382
383    fn get_block_header_size(&self, _compressed_size: u64, _uncompressed_size: u64) -> u64 {
384        // Block header: size_byte(1) + flags(1) + filter_id(1) + props_size(1)
385        // + dict_prop(1) + padding + crc32(4)
386        let base_size: u64 = 9;
387        base_size.div_ceil(4) * 4
388    }
389
390    fn write_block_checksum(&mut self) -> Result<()> {
391        let checksum = self.take_checksum();
392        self.writer.write_all(&checksum)?;
393
394        // Reset checksum calculator for next block.
395        self.checksum_calculator = ChecksumCalculator::new(self.options.check_type);
396
397        Ok(())
398    }
399
400    fn take_checksum(&mut self) -> Vec<u8> {
401        let calculator = core::mem::replace(
402            &mut self.checksum_calculator,
403            ChecksumCalculator::new(self.options.check_type),
404        );
405        calculator.finalize_to_bytes()
406    }
407
408    /// Finish writing the XZ stream and return the inner writer.
409    pub fn finish(mut self) -> Result<W> {
410        if self.finished {
411            return Ok(self.into_inner());
412        }
413
414        self.write_stream_header()?;
415        self.finish_current_block()?;
416
417        write_xz_index(&mut self.writer, &self.index_records)?;
418
419        write_xz_stream_footer(
420            &mut self.writer,
421            &self.index_records,
422            self.options.check_type,
423        )?;
424
425        Ok(self.into_inner())
426    }
427}
428
429impl<W: Write> Write for XzWriter<W> {
430    fn write(&mut self, buf: &[u8]) -> Result<usize> {
431        if self.finished {
432            return Err(error_invalid_data("XzWriter already finished"));
433        }
434
435        self.write_stream_header()?;
436
437        let mut total_written = 0;
438        let mut remaining = buf;
439
440        while !remaining.is_empty() {
441            // Check if we need to start a new block.
442            if self.should_finish_block() {
443                self.finish_current_block()?;
444            }
445
446            // Check if we need to prepare the next block (either first block or after finishing one).
447            if self.block_uncompressed_size == 0 {
448                self.prepare_next_block()?;
449            }
450
451            let max_write_size = match self.options.block_size {
452                Some(block_size) => {
453                    let remaining_capacity = block_size
454                        .get()
455                        .saturating_sub(self.block_uncompressed_size);
456                    remaining.len().min(remaining_capacity as usize)
457                }
458                None => remaining.len(),
459            };
460
461            if max_write_size == 0 {
462                // Block is full, finish it and continue.
463                continue;
464            }
465
466            let chunk_to_write = &remaining[..max_write_size];
467            let written = self.writer.write(chunk_to_write)?;
468
469            self.checksum_calculator.update(&remaining[..written]);
470
471            remaining = &remaining[written..];
472            total_written += written;
473            self.block_uncompressed_size += written as u64;
474            self.total_uncompressed_pos += written as u64;
475        }
476
477        Ok(total_written)
478    }
479
480    fn flush(&mut self) -> Result<()> {
481        self.writer.flush()
482    }
483}
484
485impl<W: Write> AutoFinish for XzWriter<W> {
486    fn finish_ignore_error(self) {
487        let _ = self.finish();
488    }
489}