// lzma_rust2/enc/lzma2_writer_mt.rs

use std::{
    io::{self, Write},
    sync::{
        atomic::{AtomicBool, AtomicU32, Ordering},
        mpsc::SyncSender,
        Arc, Mutex,
    },
};

use super::Lzma2Writer;
use crate::{
    error_invalid_input, set_error,
    work_pool::{WorkPool, WorkPoolConfig},
    work_queue::WorkerHandle,
    AutoFinish, AutoFinisher, ByteWriter, Lzma2Options,
};

/// A work unit for a worker thread.
#[derive(Debug, Clone)]
struct WorkUnit {
    data: Vec<u8>,
    options: Lzma2Options,
}

/// A multi-threaded LZMA2 compressor.
pub struct Lzma2WriterMt<W: Write> {
    inner: W,
    options: Lzma2Options,
    chunk_size: usize,
    current_work_unit: Vec<u8>,
    work_pool: WorkPool<WorkUnit, Vec<u8>>,
}

impl<W: Write> Lzma2WriterMt<W> {
    /// Creates a new multi-threaded LZMA2 writer.
    ///
    /// - `inner`: The writer to write compressed data to.
    /// - `options`: The LZMA2 options used for compressing. The chunk size must be set when
    ///   using the multi-threaded encoder. If you only need a single chunk, use the
    ///   single-threaded encoder instead.
    /// - `num_workers`: The maximum number of worker threads used for compression.
    ///   Currently capped at 256 threads.
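    ///
    /// # Example
    ///
    /// A minimal usage sketch (illustrative only: it assumes `Lzma2Options` implements
    /// `Default` and that `chunk_size` is an `Option<NonZeroU64>` field; adjust to the
    /// actual API):
    ///
    /// ```ignore
    /// use std::{io::Write, num::NonZeroU64};
    ///
    /// let mut options = Lzma2Options::default();
    /// options.chunk_size = NonZeroU64::new(1 << 20);
    ///
    /// let mut writer = Lzma2WriterMt::new(Vec::new(), options, 4)?;
    /// writer.write_all(b"example data")?;
    /// let compressed: Vec<u8> = writer.finish()?;
    /// ```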
    pub fn new(inner: W, options: Lzma2Options, num_workers: u32) -> crate::Result<Self> {
        // Use the larger of the configured chunk size and the dictionary size.
        let chunk_size = match options.chunk_size {
            None => return Err(error_invalid_input("chunk size must be set")),
            Some(chunk_size) => chunk_size.get().max(options.lzma_options.dict_size as u64),
        };

        let chunk_size = usize::try_from(chunk_size)
            .map_err(|_| error_invalid_input("chunk size bigger than usize"))?;

        // We don't know how many work units we'll have ahead of time.
        let num_work = u64::MAX;

        Ok(Self {
            inner,
            options,
            chunk_size,
            current_work_unit: Vec::with_capacity(chunk_size),
            work_pool: WorkPool::new(
                WorkPoolConfig::new(num_workers, num_work),
                worker_thread_logic,
            ),
        })
    }

    /// Sends the current work unit to the workers.
    fn send_work_unit(&mut self) -> io::Result<()> {
        if self.current_work_unit.is_empty() {
            return Ok(());
        }

        self.drain_available_results()?;

        // Every work unit is compressed as an independent single-chunk stream,
        // so the per-unit options must not carry a chunk size or a preset
        // dictionary of their own.
        let work_data = core::mem::take(&mut self.current_work_unit);
        let mut single_chunk_options = self.options.clone();
        single_chunk_options.chunk_size = None;
        single_chunk_options.lzma_options.preset_dict = None;

        // `dispatch_next_work` pulls the work unit out of this `Option`,
        // guaranteeing it is handed out at most once.
        let mut work_data_opt = Some(work_data);

        self.work_pool.dispatch_next_work(&mut |_seq| {
            let data = work_data_opt.take().ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "work already provided")
            })?;
            Ok(WorkUnit {
                data,
                options: single_chunk_options.clone(),
            })
        })?;

        self.drain_available_results()?;

        Ok(())
    }

    /// Drains all currently available results from the work pool and writes them.
    fn drain_available_results(&mut self) -> io::Result<()> {
        while let Some(compressed_data) = self.work_pool.try_get_result()? {
            self.inner.write_all(&compressed_data)?;
        }
        Ok(())
    }

    /// Returns a wrapper around `self` that will finish the stream on drop.
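    ///
    /// Any error from the implicit finish on drop is ignored (see
    /// [`AutoFinish::finish_ignore_error`]); call [`Lzma2WriterMt::finish`]
    /// directly if you need to handle errors.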
    pub fn auto_finish(self) -> AutoFinisher<Self> {
        AutoFinisher(Some(self))
    }

    /// Consumes the `Lzma2WriterMt` and returns the inner writer without
    /// finishing the stream.
    pub fn into_inner(self) -> W {
        self.inner
    }

    /// Finishes the compression and returns the underlying writer.
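    ///
    /// This dispatches any buffered data, waits for all outstanding work
    /// units to finish, writes their results, and terminates the stream
    /// with the LZMA2 end-of-stream marker.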
    pub fn finish(mut self) -> io::Result<W> {
        if !self.current_work_unit.is_empty() {
            self.send_work_unit()?;
        }

        // If no data was provided to compress, write an empty LZMA2 stream.
        if self.work_pool.next_index_to_dispatch() == 0 {
            self.inner.write_u8(0x00)?;
            self.inner.flush()?;

            return Ok(self.inner);
        }

        // Mark the WorkPool as finished so it knows no more work is coming.
        self.work_pool.finish();

        // Wait for all remaining work to complete.
        while let Some(compressed_data) = self.work_pool.get_result(|_| {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "no more work to dispatch",
            ))
        })? {
            self.inner.write_all(&compressed_data)?;
        }

        // Terminate the LZMA2 stream with the 0x00 end-of-stream control byte.
        self.inner.write_u8(0x00)?;
        self.inner.flush()?;

        Ok(self.inner)
    }
}

/// The logic for a single worker thread.
fn worker_thread_logic(
    worker_handle: WorkerHandle<(u64, WorkUnit)>,
    result_tx: SyncSender<(u64, Vec<u8>)>,
    shutdown_flag: Arc<AtomicBool>,
    error_store: Arc<Mutex<Option<io::Error>>>,
    active_workers: Arc<AtomicU32>,
) {
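    // Each iteration steals an indexed work unit from the shared queue,
    // compresses it into an in-memory buffer with a single-threaded
    // `Lzma2Writer`, and sends the result back tagged with its sequence
    // index so the pool can reassemble the output in order. On an I/O
    // error the worker records it via `set_error` and exits.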
    while !shutdown_flag.load(Ordering::Acquire) {
        let (index, work_unit) = match worker_handle.steal() {
            Some(work) => {
                active_workers.fetch_add(1, Ordering::Release);
                work
            }
            None => {
                // No more work available and queue is closed.
                break;
            }
        };

        let mut compressed_buffer = Vec::new();

        let mut writer = Lzma2Writer::new(&mut compressed_buffer, work_unit.options);

        let result = match writer.write_all(&work_unit.data) {
            Ok(_) => match writer.flush() {
                Ok(_) => compressed_buffer,
                Err(error) => {
                    active_workers.fetch_sub(1, Ordering::Release);
                    set_error(error, &error_store, &shutdown_flag);
                    return;
                }
            },
            Err(error) => {
                active_workers.fetch_sub(1, Ordering::Release);
                set_error(error, &error_store, &shutdown_flag);
                return;
            }
        };

        // The receiving side hung up; nothing left to do.
        if result_tx.send((index, result)).is_err() {
            active_workers.fetch_sub(1, Ordering::Release);
            return;
        }

        active_workers.fetch_sub(1, Ordering::Release);
    }
}

impl<W: Write> Write for Lzma2WriterMt<W> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        let mut total_written = 0;
        let mut remaining_buf = buf;

        // Fill the current work unit up to `chunk_size`, dispatching it to
        // the worker pool whenever it is full.
        while !remaining_buf.is_empty() {
            let chunk_remaining = self.chunk_size.saturating_sub(self.current_work_unit.len());
            let to_write = remaining_buf.len().min(chunk_remaining);

            if to_write > 0 {
                self.current_work_unit
                    .extend_from_slice(&remaining_buf[..to_write]);
                total_written += to_write;
                remaining_buf = &remaining_buf[to_write..];
            }

            if self.current_work_unit.len() >= self.chunk_size {
                self.send_work_unit()?;
            }

            self.drain_available_results()?;
        }

        Ok(total_written)
    }

    fn flush(&mut self) -> io::Result<()> {
        if !self.current_work_unit.is_empty() {
            self.send_work_unit()?;
        }

        // Write out all results that have finished compressing so far.
        self.drain_available_results()?;

        self.inner.flush()
    }
}

impl<W: Write> AutoFinish for Lzma2WriterMt<W> {
    fn finish_ignore_error(self) {
        let _ = self.finish();
    }
}