lzma_rust2/
lzma2_reader_mt.rs

use std::{
    collections::BTreeMap,
    io,
    io::{Cursor, Read},
    sync::{
        atomic::{AtomicBool, AtomicU32, Ordering},
        mpsc::{self, Receiver, SyncSender},
        Arc, Mutex,
    },
    thread,
};

use crate::{
    set_error,
    work_queue::{WorkStealingQueue, WorkerHandle},
    Lzma2Reader,
};

/// A work unit for a worker thread.
/// Contains the sequence number and the raw compressed bytes for a series of chunks.
type WorkUnit = (u64, Vec<u8>);

/// A result unit from a worker thread.
/// Contains the sequence number and the decompressed data.
type ResultUnit = (u64, Vec<u8>);

enum State {
    /// Actively reading from the inner reader and sending work to threads.
    Reading,
    /// The inner reader has reached EOF. We are now waiting for the remaining
    /// work to be completed by the worker threads.
    Draining,
    /// All data has been decompressed and returned. The stream is exhausted.
    Finished,
    /// A fatal error occurred in either the reader or a worker thread.
    Error,
}
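
// State transitions: `Reading` -> `Draining` (on clean EOF of the inner
// reader or when the result channel disconnects) -> `Finished` (once every
// dispatched sequence has been returned). Any state can transition to
// `Error` when a failure is recorded.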

/// A multi-threaded LZMA2 decompressor.
pub struct Lzma2ReaderMt<R: Read> {
    inner: R,
    result_rx: Receiver<ResultUnit>,
    result_tx: SyncSender<ResultUnit>,
    current_work_unit: Vec<u8>,
    next_sequence_to_dispatch: u64,
    next_sequence_to_return: u64,
    last_sequence_id: Option<u64>,
    out_of_order_chunks: BTreeMap<u64, Vec<u8>>,
    current_chunk: Cursor<Vec<u8>>,
    shutdown_flag: Arc<AtomicBool>,
    error_store: Arc<Mutex<Option<io::Error>>>,
    state: State,
    work_queue: WorkStealingQueue<WorkUnit>,
    active_workers: Arc<AtomicU32>,
    max_workers: u32,
    dict_size: u32,
    preset_dict: Option<Arc<Vec<u8>>>,
    worker_handles: Vec<thread::JoinHandle<()>>,
}

impl<R: Read> Lzma2ReaderMt<R> {
    /// Creates a new multi-threaded LZMA2 reader.
    ///
    /// - `inner`: The reader to read compressed data from.
    /// - `dict_size`: The dictionary size in bytes, as specified in the stream properties.
    /// - `preset_dict`: An optional preset dictionary.
    /// - `num_workers`: The maximum number of worker threads for decompression. Currently capped at 256 threads.
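    ///
    /// # Example
    ///
    /// A minimal usage sketch, assuming `Lzma2ReaderMt` is exported at the
    /// crate root; the file name and dictionary size here are illustrative:
    ///
    /// ```no_run
    /// use std::{fs::File, io::Read};
    ///
    /// use lzma_rust2::Lzma2ReaderMt;
    ///
    /// let file = File::open("data.lzma2").unwrap();
    /// // 8 MiB dictionary, no preset dictionary, up to 4 worker threads.
    /// let mut reader = Lzma2ReaderMt::new(file, 8 * 1024 * 1024, None, 4);
    ///
    /// let mut decompressed = Vec::new();
    /// reader.read_to_end(&mut decompressed).unwrap();
    /// ```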
    pub fn new(inner: R, dict_size: u32, preset_dict: Option<&[u8]>, num_workers: u32) -> Self {
        let max_workers = num_workers.clamp(1, 256);

        let work_queue = WorkStealingQueue::new();
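        // A bounded channel of capacity 1 provides backpressure: a worker
        // blocks in `send` until the consumer has taken the previous result,
        // so completed chunks don't pile up inside the channel itself.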
        let (result_tx, result_rx) = mpsc::sync_channel::<ResultUnit>(1);
        let shutdown_flag = Arc::new(AtomicBool::new(false));
        let error_store = Arc::new(Mutex::new(None));
        let active_workers = Arc::new(AtomicU32::new(0));
        let preset_dict = preset_dict.map(|s| s.to_vec()).map(Arc::new);

        let mut reader = Self {
            inner,
            result_rx,
            result_tx,
            current_work_unit: Vec::with_capacity(1024 * 1024),
            next_sequence_to_dispatch: 0,
            next_sequence_to_return: 0,
            last_sequence_id: None,
            out_of_order_chunks: BTreeMap::new(),
            current_chunk: Cursor::new(Vec::new()),
            shutdown_flag,
            error_store,
            state: State::Reading,
            work_queue,
            active_workers,
            max_workers,
            dict_size,
            preset_dict,
            worker_handles: Vec::new(),
        };

        reader.spawn_worker_thread();

        reader
    }

    fn spawn_worker_thread(&mut self) {
        let worker_handle = self.work_queue.worker();
        let result_tx = self.result_tx.clone();
        let shutdown_flag = Arc::clone(&self.shutdown_flag);
        let error_store = Arc::clone(&self.error_store);
        let active_workers = Arc::clone(&self.active_workers);
        let preset_dict = self.preset_dict.clone();
        let dict_size = self.dict_size;

        let handle = thread::spawn(move || {
            worker_thread_logic(
                worker_handle,
                result_tx,
                dict_size,
                preset_dict,
                shutdown_flag,
                error_store,
                active_workers,
            );
        });

        self.worker_handles.push(handle);
    }

    /// The count of independent chunks found inside the compressed file.
    /// This is effectively the maximum parallelization possible.
    pub fn chunk_count(&self) -> u64 {
        self.next_sequence_to_return
    }

    /// Reads one LZMA2 chunk from the inner reader and appends it to the current work unit.
    /// If the chunk is an independent block, it dispatches the current work unit.
    ///
    /// Returns `Ok(false)` on clean EOF, `Ok(true)` on success, and `Err` on I/O error.
    fn read_and_dispatch_chunk(&mut self) -> io::Result<bool> {
        let mut control_buf = [0u8; 1];
        match self.inner.read_exact(&mut control_buf) {
            Ok(_) => (),
            Err(error) if error.kind() == io::ErrorKind::UnexpectedEof => {
                // Clean end of stream.
                return Ok(false);
            }
            Err(error) => return Err(error),
        }

        let control = control_buf[0];

        if control == 0x00 {
            // End of stream marker.
            self.current_work_unit.push(0x00);
            self.send_work_unit();
            return Ok(false);
        }

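        // Control bytes >= 0xE0 (compressed chunk with dictionary reset) and
        // 0x01 (uncompressed chunk with dictionary reset) start chunks that
        // depend on no previously decoded data, so a new work unit can begin
        // here and be decompressed independently on another thread.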
        let is_independent_chunk = control >= 0xE0 || control == 0x01;

        // Split work units before independent chunks (but not for the very first chunk).
        if is_independent_chunk && !self.current_work_unit.is_empty() {
            self.current_work_unit.push(0x00);
            self.send_work_unit();
        }

        self.current_work_unit.push(control);

        let chunk_data_size = if control >= 0x80 {
            // Compressed chunk. Read header to find size.
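            // Header layout after the control byte: two big-endian bytes with
            // the low bits of (uncompressed size - 1), two big-endian bytes
            // with (compressed size - 1), and one properties byte when the
            // chunk carries new LZMA properties (control >= 0xC0). Bytes 2
            // and 3 therefore hold the compressed payload size.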
            let header_len = if control >= 0xC0 { 5 } else { 4 };
            let mut header_buf = [0; 5];
            self.inner.read_exact(&mut header_buf[..header_len])?;
            self.current_work_unit
                .extend_from_slice(&header_buf[..header_len]);
            u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize + 1
        } else if control == 0x01 || control == 0x02 {
            // Uncompressed chunk.
            let mut size_buf = [0u8; 2];
            self.inner.read_exact(&mut size_buf)?;
            self.current_work_unit.extend_from_slice(&size_buf);
            u16::from_be_bytes(size_buf) as usize + 1
        } else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("invalid LZMA2 control byte: {control:X}"),
            ));
        };

        // Read the chunk data itself.
        if chunk_data_size > 0 {
            let start_len = self.current_work_unit.len();
            self.current_work_unit
                .resize(start_len + chunk_data_size, 0);
            self.inner
                .read_exact(&mut self.current_work_unit[start_len..])?;
        }

        Ok(true)
    }

    /// Sends the current work unit to the workers.
    fn send_work_unit(&mut self) {
        if self.current_work_unit.is_empty() {
            return;
        }

        let work_unit =
            core::mem::replace(&mut self.current_work_unit, Vec::with_capacity(1024 * 1024));

        if !self
            .work_queue
            .push((self.next_sequence_to_dispatch, work_unit))
        {
            // The queue is closed, which indicates shutdown.
            self.state = State::Error;
            set_error(
                io::Error::new(io::ErrorKind::BrokenPipe, "worker threads have shut down"),
                &self.error_store,
                &self.shutdown_flag,
            );
        }

        // We spawn a new thread if we have work queued, no available workers, and haven't reached
        // the maximum allowed parallelism yet.
        let spawned_workers = self.worker_handles.len() as u32;
        let active_workers = self.active_workers.load(Ordering::Acquire);
        let queue_len = self.work_queue.len();

        if queue_len > 0 && active_workers == spawned_workers && spawned_workers < self.max_workers
        {
            self.spawn_worker_thread();
        }

        self.next_sequence_to_dispatch += 1;
    }

    fn get_next_uncompressed_chunk(&mut self) -> io::Result<Option<Vec<u8>>> {
        loop {
            // Always check for already-received chunks first.
            if let Some(result) = self
                .out_of_order_chunks
                .remove(&self.next_sequence_to_return)
            {
                self.next_sequence_to_return += 1;
                return Ok(Some(result));
            }

            // Check for a globally stored error.
            if let Some(err) = self.error_store.lock().unwrap().take() {
                self.state = State::Error;
                return Err(err);
            }

            match self.state {
                State::Reading => {
                    // First, always try to receive a result without blocking.
                    // This keeps the pipeline moving and avoids unnecessary blocking on I/O.
                    match self.result_rx.try_recv() {
                        Ok((seq, result)) => {
                            if seq == self.next_sequence_to_return {
                                self.next_sequence_to_return += 1;
                                return Ok(Some(result));
                            } else {
                                self.out_of_order_chunks.insert(seq, result);
                                continue; // Loop again to check the out_of_order_chunks
                            }
                        }
                        Err(mpsc::TryRecvError::Disconnected) => {
                            // All workers are done.
                            self.state = State::Draining;
                            continue;
                        }
                        Err(mpsc::TryRecvError::Empty) => {
                            // No results are ready. Now we can consider reading more input.
                        }
                    }

                    // If the work queue is empty, try to read more from the source.
                    if self.work_queue.is_empty() {
                        match self.read_and_dispatch_chunk() {
                            Ok(true) => {
                                // Successfully read and dispatched a chunk; loop to continue.
                                continue;
                            }
                            Ok(false) => {
                                // Clean EOF from inner reader.
                                // Send any remaining data as the final work unit.
                                self.send_work_unit();
                                self.last_sequence_id =
                                    Some(self.next_sequence_to_dispatch.saturating_sub(1));
                                self.state = State::Draining;
                                continue;
                            }
                            Err(error) => {
                                set_error(error, &self.error_store, &self.shutdown_flag);
                                self.state = State::Error;
                                continue;
                            }
                        }
                    }

                    // Now we MUST wait for a result to make progress.
                    match self.result_rx.recv() {
                        Ok((seq, result)) => {
                            if seq == self.next_sequence_to_return {
                                self.next_sequence_to_return += 1;
                                return Ok(Some(result));
                            } else {
                                self.out_of_order_chunks.insert(seq, result);
                                // We've made progress; loop to check the out_of_order_chunks
                                continue;
                            }
                        }
                        Err(_) => {
                            // All workers are done.
                            self.state = State::Draining;
                        }
                    }
                }
                State::Draining => {
                    if let Some(last_seq) = self.last_sequence_id {
                        if self.next_sequence_to_return > last_seq {
                            self.state = State::Finished;
                            continue;
                        }
                    }

                    // In Draining state, we only wait for results.
                    match self.result_rx.recv() {
                        Ok((seq, result)) => {
                            if seq == self.next_sequence_to_return {
                                self.next_sequence_to_return += 1;
                                return Ok(Some(result));
                            } else {
                                self.out_of_order_chunks.insert(seq, result);
                            }
                        }
                        Err(_) => {
                            // All workers finished, and the channel is empty. We are done.
                            self.state = State::Finished;
                        }
                    }
                }
                State::Finished => {
                    return Ok(None);
                }
                State::Error => {
                    // The error was already logged; now we just propagate it.
                    return Err(self.error_store.lock().unwrap().take().unwrap_or_else(|| {
                        io::Error::other("decompression failed with an unknown error")
                    }));
                }
            }
        }
    }
}

/// The logic for a single worker thread.
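/// Repeatedly steals a work unit from the queue, decompresses it with a fresh
/// `Lzma2Reader`, and sends the result back tagged with its sequence number.
/// Exits when the queue is closed, the shutdown flag is set, or an error occurs.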
fn worker_thread_logic(
    worker_handle: WorkerHandle<WorkUnit>,
    result_tx: SyncSender<ResultUnit>,
    dict_size: u32,
    preset_dict: Option<Arc<Vec<u8>>>,
    shutdown_flag: Arc<AtomicBool>,
    error_store: Arc<Mutex<Option<io::Error>>>,
    active_workers: Arc<AtomicU32>,
) {
    while !shutdown_flag.load(Ordering::Acquire) {
        let (seq, work_unit_data) = match worker_handle.steal() {
            Some(work) => {
                active_workers.fetch_add(1, Ordering::Release);
                work
            }
            None => {
                // No more work is available and the queue is closed.
                break;
            }
        };

        let mut reader = Lzma2Reader::new(
            work_unit_data.as_slice(),
            dict_size,
            preset_dict.as_deref().map(|v| v.as_slice()),
        );

        let mut decompressed_data = Vec::with_capacity(work_unit_data.len());
        let result = match reader.read_to_end(&mut decompressed_data) {
            Ok(_) => decompressed_data,
            Err(error) => {
                active_workers.fetch_sub(1, Ordering::Release);
                set_error(error, &error_store, &shutdown_flag);
                return;
            }
        };

        if result_tx.send((seq, result)).is_err() {
            active_workers.fetch_sub(1, Ordering::Release);
            return;
        }

        active_workers.fetch_sub(1, Ordering::Release);
    }
}

impl<R: Read> Read for Lzma2ReaderMt<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        let bytes_read = self.current_chunk.read(buf)?;

        if bytes_read > 0 {
            return Ok(bytes_read);
        }

        let chunk_data = self.get_next_uncompressed_chunk()?;

        let Some(chunk_data) = chunk_data else {
            // This is the clean end of the stream.
            return Ok(0);
        };

        self.current_chunk = Cursor::new(chunk_data);

        // Recursive call to read the new chunk data.
        self.read(buf)
    }
}

impl<R: Read> Drop for Lzma2ReaderMt<R> {
    fn drop(&mut self) {
        self.shutdown_flag.store(true, Ordering::Release);
        self.work_queue.close();
        // Worker threads will exit when the work queue is closed.
        // The JoinHandles will be dropped without joining, which is fine since we set the shutdown flag.
    }
}