lzma_rust2/xz/
writer_mt.rs

1use std::{
2    io::{self, Write},
3    sync::{
4        atomic::{AtomicBool, AtomicU32, Ordering},
5        mpsc::SyncSender,
6        Arc, Mutex,
7    },
8};
9
10use super::{
11    add_padding, write_xz_block_header, write_xz_index, write_xz_stream_footer,
12    write_xz_stream_header, CheckType, ChecksumCalculator, FilterConfig, FilterType, IndexRecord,
13};
14use crate::{
15    enc::{Lzma2Writer, LzmaOptions},
16    error_invalid_input, set_error,
17    work_pool::{WorkPool, WorkPoolConfig},
18    work_queue::WorkerHandle,
19    AutoFinish, AutoFinisher, Lzma2Options, Result, XzOptions,
20};
21
/// A work unit for a worker thread.
///
/// One unit corresponds to exactly one XZ block; it is compressed
/// independently of all other units.
#[derive(Debug, Clone)]
struct WorkUnit {
    // Raw input bytes for this block (at most `block_size` of them).
    uncompressed_data: Vec<u8>,
    // LZMA options the worker uses for compression.
    lzma_options: LzmaOptions,
    // Checksum algorithm for the block's check field.
    check_type: CheckType,
}
29
/// A result unit from a worker thread.
///
/// Carries everything the writer thread needs to emit one finished XZ block.
#[derive(Debug)]
struct ResultUnit {
    // LZMA2-compressed block payload.
    compressed_data: Vec<u8>,
    // Finalized check bytes, computed over the *uncompressed* data.
    checksum: Vec<u8>,
    // How many input bytes this block covers (goes into the index).
    uncompressed_size: u64,
}
37
/// A multi-threaded XZ compressor.
pub struct XzWriterMt<W: Write> {
    // Destination for the compressed XZ stream.
    inner: W,
    // Compression settings (check type, filters, LZMA options, block size).
    options: XzOptions,
    // Bytes buffered for the block currently being filled.
    current_work_unit: Vec<u8>,
    // Target uncompressed size of each block, in bytes.
    block_size: usize,
    // Pool of worker threads compressing blocks in parallel.
    work_pool: WorkPool<WorkUnit, ResultUnit>,
    // One record per finished block; written out as the XZ index at the end.
    index_records: Vec<IndexRecord>,
    // NOTE(review): initialized in `new` but never updated or read in this
    // file — per-block checksums are computed by the workers. Verify whether
    // this field is dead.
    checksum_calculator: ChecksumCalculator,
    // Whether the stream header has been emitted yet (it is written lazily).
    header_written: bool,
    // Running total of uncompressed bytes across all finished blocks.
    // NOTE(review): only ever incremented here — confirm it is read elsewhere.
    total_uncompressed_pos: u64,
}
50
impl<W: Write> XzWriterMt<W> {
    /// Creates a new multi-threaded XZ writer.
    ///
    /// - `inner`: The writer to write compressed data to.
    /// - `options`: The XZ options used for compressing. Block size must be set when using the
    ///   multi-threaded encoder. If you need just one block, then use the single-threaded encoder.
    /// - `num_workers`: The maximum number of worker threads for compression.
    ///   Currently capped at 256 threads.
    ///
    /// # Errors
    ///
    /// Fails if more than 3 pre-filters are configured, if no block size is
    /// set, or if the block size does not fit into `usize`.
    pub fn new(inner: W, options: XzOptions, num_workers: u32) -> Result<Self> {
        if options.filters.len() > 3 {
            return Err(error_invalid_input(
                "XZ allows only at most 3 pre-filters plus LZMA2",
            ));
        }

        // A block smaller than the dictionary wastes dictionary memory,
        // so the effective block size is raised to at least `dict_size`.
        let block_size = match options.block_size {
            None => return Err(error_invalid_input("block size must be set")),
            Some(block_size) => block_size.get().max(options.lzma_options.dict_size as u64),
        };

        let block_size = usize::try_from(block_size)
            .map_err(|_| error_invalid_input("block size bigger than usize"))?;

        let checksum_calculator = ChecksumCalculator::new(options.check_type);

        // We don't know how many work units we'll have ahead of time.
        let num_work = u64::MAX;

        Ok(Self {
            inner,
            options,
            // Cap the up-front allocation at 1 MiB; the buffer grows on
            // demand toward `block_size`.
            current_work_unit: Vec::with_capacity(block_size.min(1024 * 1024)),
            block_size,
            work_pool: WorkPool::new(
                WorkPoolConfig::new(num_workers, num_work),
                worker_thread_logic,
            ),
            index_records: Vec::new(),
            checksum_calculator,
            header_written: false,
            total_uncompressed_pos: 0,
        })
    }

    /// Writes the XZ stream header exactly once; later calls are no-ops.
    fn write_stream_header(&mut self) -> Result<()> {
        if self.header_written {
            return Ok(());
        }

        write_xz_stream_header(&mut self.inner, self.options.check_type)?;
        self.header_written = true;

        Ok(())
    }

    /// Writes a block header for the next block and returns the header's
    /// encoded size in bytes (needed for the index's unpadded size).
    ///
    /// NOTE(review): `_block_uncompressed_size` is currently unused — the
    /// header is written without the optional size fields. Confirm that is
    /// intended.
    fn write_block_header(&mut self, _block_uncompressed_size: u64) -> Result<u64> {
        // Add LZMA2 filter to the list
        let mut filters = self.options.filters.clone();
        filters.push(FilterConfig {
            filter_type: FilterType::Lzma2,
            property: 0,
        });

        write_xz_block_header(
            &mut self.inner,
            &filters,
            self.options.lzma_options.dict_size,
        )
    }

    /// Sends the current work unit to the workers.
    fn send_work_unit(&mut self) -> Result<()> {
        if self.current_work_unit.is_empty() {
            return Ok(());
        }

        // Ensure stream header is written before any blocks
        self.write_stream_header()?;

        // Opportunistically write out finished blocks so completed results
        // don't pile up in memory while we keep dispatching.
        self.drain_available_results()?;

        // Hand the buffer over without copying; this leaves an empty buffer
        // behind for the next block.
        let work_data = core::mem::take(&mut self.current_work_unit);
        let mut work_data_opt = Some(work_data);

        self.work_pool.dispatch_next_work(&mut |_seq| {
            // The pool should invoke this closure at most once per dispatch;
            // a second call would find the Option already taken.
            let data = work_data_opt.take().ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "work already provided")
            })?;
            Ok(WorkUnit {
                uncompressed_data: data,
                lzma_options: self.options.lzma_options.clone(),
                check_type: self.options.check_type,
            })
        })?;

        self.drain_available_results()?;

        Ok(())
    }

    /// Drains all currently available results from the work pool and writes them.
    /// Non-blocking: in-flight work units are left to finish in the background.
    fn drain_available_results(&mut self) -> Result<()> {
        while let Some(result) = self.work_pool.try_get_result()? {
            self.write_compressed_block(
                result.compressed_data,
                result.checksum,
                result.uncompressed_size,
            )?;
        }
        Ok(())
    }

    /// Emits one finished block: header, compressed payload, padding to a
    /// 4-byte boundary, then the check bytes; records it for the index.
    fn write_compressed_block(
        &mut self,
        compressed_data: Vec<u8>,
        checksum: Vec<u8>,
        block_uncompressed_size: u64,
    ) -> Result<()> {
        let block_header_size = self.write_block_header(block_uncompressed_size)?;

        let data_size = compressed_data.len() as u64;
        // Pad the compressed payload up to the next multiple of 4 bytes.
        let padding_needed = (4 - (data_size % 4)) % 4;

        self.inner.write_all(&compressed_data)?;

        add_padding(&mut self.inner, padding_needed as usize)?;

        self.inner.write_all(&checksum)?;

        // The index stores the *unpadded* size: header + data + check,
        // excluding the block padding.
        let unpadded_size = block_header_size + data_size + self.options.check_type.checksum_size();
        self.index_records.push(IndexRecord {
            unpadded_size,
            uncompressed_size: block_uncompressed_size,
        });

        self.total_uncompressed_pos += block_uncompressed_size;

        Ok(())
    }

    /// Returns a wrapper around `self` that will finish the stream on drop.
    pub fn auto_finish(self) -> AutoFinisher<Self> {
        AutoFinisher(Some(self))
    }

    /// Consume the XzWriterMt and return the inner writer.
    ///
    /// NOTE(review): this does NOT finish the stream; use [`Self::finish`]
    /// to produce a valid XZ file.
    pub fn into_inner(self) -> W {
        self.inner
    }

    /// Writes the XZ index built from all finished blocks.
    #[inline(always)]
    fn write_index(&mut self) -> Result<()> {
        write_xz_index(&mut self.inner, &self.index_records)
    }

    /// Writes the XZ stream footer (backward size + flags + CRC).
    #[inline(always)]
    fn write_stream_footer(&mut self) -> Result<()> {
        write_xz_stream_footer(
            &mut self.inner,
            &self.index_records,
            self.options.check_type,
        )
    }

    /// Finishes the compression and returns the underlying writer.
    pub fn finish(mut self) -> Result<W> {
        self.write_stream_header()?;

        // Flush any partially filled final block.
        if !self.current_work_unit.is_empty() {
            self.send_work_unit()?;
        }

        // If no data was provided to compress, write an empty XZ file.
        if self.work_pool.next_index_to_dispatch() == 0 {
            // Write empty index and footer
            self.write_index()?;
            self.write_stream_footer()?;

            self.inner.flush()?;

            return Ok(self.inner);
        }

        // Mark the WorkPool as finished so it knows no more work is coming.
        self.work_pool.finish();

        // Wait for all remaining work to complete. The closure only runs if
        // the pool asks for more work, which cannot happen after `finish()`,
        // so it returns an error unconditionally.
        while let Some(result) = self.work_pool.get_result(|_| {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "no more work to dispatch",
            ))
        })? {
            self.write_compressed_block(
                result.compressed_data,
                result.checksum,
                result.uncompressed_size,
            )?;
        }

        self.write_index()?;
        self.write_stream_footer()?;

        self.inner.flush()?;

        Ok(self.inner)
    }
}
259
260/// The logic for a single worker thread.
261fn worker_thread_logic(
262    worker_handle: WorkerHandle<(u64, WorkUnit)>,
263    result_tx: SyncSender<(u64, ResultUnit)>,
264    shutdown_flag: Arc<AtomicBool>,
265    error_store: Arc<Mutex<Option<io::Error>>>,
266    active_workers: Arc<AtomicU32>,
267) {
268    while !shutdown_flag.load(Ordering::Acquire) {
269        let (index, work_unit) = match worker_handle.steal() {
270            Some(work) => {
271                active_workers.fetch_add(1, Ordering::Release);
272                work
273            }
274            None => {
275                // No more work available and queue is closed.
276                break;
277            }
278        };
279
280        let mut compressed_buffer = Vec::new();
281        let uncompressed_size = work_unit.uncompressed_data.len() as u64;
282
283        let mut checksum_calculator = ChecksumCalculator::new(work_unit.check_type);
284        checksum_calculator.update(&work_unit.uncompressed_data);
285        let checksum = checksum_calculator.finalize_to_bytes();
286
287        let options = Lzma2Options {
288            lzma_options: work_unit.lzma_options,
289            ..Default::default()
290        };
291
292        let mut writer = Lzma2Writer::new(&mut compressed_buffer, options);
293        let result = match writer.write_all(&work_unit.uncompressed_data) {
294            Ok(_) => match writer.finish() {
295                Ok(_) => ResultUnit {
296                    compressed_data: compressed_buffer,
297                    checksum,
298                    uncompressed_size,
299                },
300                Err(error) => {
301                    active_workers.fetch_sub(1, Ordering::Release);
302                    set_error(error, &error_store, &shutdown_flag);
303                    return;
304                }
305            },
306            Err(error) => {
307                active_workers.fetch_sub(1, Ordering::Release);
308                set_error(error, &error_store, &shutdown_flag);
309                return;
310            }
311        };
312
313        if result_tx.send((index, result)).is_err() {
314            active_workers.fetch_sub(1, Ordering::Release);
315            return;
316        }
317
318        active_workers.fetch_sub(1, Ordering::Release);
319    }
320}
321
impl<W: Write> Write for XzWriterMt<W> {
    /// Buffers `buf` into the current block, dispatching a work unit to the
    /// pool whenever a full `block_size` worth of data has accumulated.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        let mut total_written = 0;
        let mut remaining_buf = buf;

        while !remaining_buf.is_empty() {
            // Fill the current block up to `block_size` bytes.
            let block_remaining = self.block_size.saturating_sub(self.current_work_unit.len());
            let to_write = remaining_buf.len().min(block_remaining);

            if to_write > 0 {
                self.current_work_unit
                    .extend_from_slice(&remaining_buf[..to_write]);
                total_written += to_write;
                remaining_buf = &remaining_buf[to_write..];
            }

            if self.current_work_unit.len() >= self.block_size {
                self.send_work_unit()?;
            }

            // Write out any blocks the workers have already finished.
            self.drain_available_results()?;
        }

        Ok(total_written)
    }

    fn flush(&mut self) -> Result<()> {
        // Dispatch whatever has been buffered as a (possibly short) block.
        if !self.current_work_unit.is_empty() {
            self.send_work_unit()?;
        }

        // Write out all results that are available right now.
        // NOTE(review): `try_get_result` is non-blocking, so this does NOT
        // wait for in-flight work units — compressed data may still be held
        // by workers after `flush` returns. Only `finish` guarantees all
        // data reaches `inner`. Confirm whether blocking here is desired.
        while let Some(result) = self.work_pool.try_get_result()? {
            self.write_compressed_block(
                result.compressed_data,
                result.checksum,
                result.uncompressed_size,
            )?;
        }

        self.inner.flush()
    }
}
369
370impl<W: Write> AutoFinish for XzWriterMt<W> {
371    fn finish_ignore_error(self) {
372        let _ = self.finish();
373    }
374}