lzma_rust2/xz/
reader_mt.rs

1use std::{
2    collections::BTreeMap,
3    io::{self, Cursor, Seek, SeekFrom},
4    sync::{
5        atomic::{AtomicBool, AtomicU32, Ordering},
6        mpsc::{self, Receiver, SyncSender},
7        Arc, Mutex,
8    },
9    thread,
10};
11
12use super::{create_filter_chain, BlockHeader, CheckType, Index, StreamFooter, StreamHeader};
13use crate::{
14    error_invalid_data, set_error,
15    work_queue::{WorkStealingQueue, WorkerHandle},
16    ByteReader, Read,
17};
18
/// Location and size metadata for a single XZ block, gathered from the
/// stream index by `scan_blocks`.
#[derive(Debug, Clone)]
struct XzBlock {
    /// Absolute byte offset of the block header in the underlying reader.
    start_pos: u64,
    /// Block size without trailing padding (header + compressed data +
    /// check), as recorded in the index.
    unpadded_size: u64,
    /// Decompressed size of the block, as recorded in the index.
    uncompressed_size: u64,
}
25
/// A work unit for a worker thread.
/// Contains the block's sequence number and its raw compressed bytes
/// (the padded block as read from the file).
type WorkUnit = (u64, Vec<u8>);

/// A result unit from a worker thread.
/// Contains the sequence number and the decompressed data.
type ResultUnit = (u64, Vec<u8>);
33
/// Phase of the reader's streaming state machine.
/// Normal progression is `Dispatching` → `Draining` → `Finished`;
/// `Error` can be entered from any phase.
enum State {
    /// Dispatching blocks to worker threads.
    Dispatching,
    /// All blocks dispatched, waiting for workers to complete.
    Draining,
    /// All data has been decompressed and returned. The stream is exhausted.
    Finished,
    /// A fatal error occurred in either the reader or a worker thread.
    Error,
}
44
/// A multi-threaded XZ decompressor.
pub struct XzReaderMt<R: Read + Seek> {
    /// Underlying source; temporarily `take`n while seeking/reading a block
    /// and put back afterwards.
    inner: Option<R>,
    /// Block metadata collected from the stream index by `scan_blocks`.
    blocks: Vec<XzBlock>,
    /// Integrity-check type declared in the stream header.
    check_type: CheckType,
    /// Receives decompressed chunks from worker threads.
    result_rx: Receiver<ResultUnit>,
    /// Cloned into each spawned worker so it can send results back.
    result_tx: SyncSender<ResultUnit>,
    /// Sequence number of the next block to hand to the work queue.
    next_sequence_to_dispatch: u64,
    /// Sequence number of the next chunk to hand to the caller.
    next_sequence_to_return: u64,
    /// Sequence number of the last block, set once dispatching finishes.
    last_sequence_id: Option<u64>,
    /// Chunks that arrived ahead of `next_sequence_to_return`.
    out_of_order_chunks: BTreeMap<u64, Vec<u8>>,
    /// Chunk currently being drained by `Read::read`.
    current_chunk: Cursor<Vec<u8>>,
    /// Tells workers to stop; set on error and on drop.
    shutdown_flag: Arc<AtomicBool>,
    /// First fatal error recorded by any thread.
    error_store: Arc<Mutex<Option<io::Error>>>,
    /// Current phase of the streaming state machine.
    state: State,
    /// Queue of compressed blocks awaiting decompression.
    work_queue: WorkStealingQueue<WorkUnit>,
    /// Number of workers currently processing a unit (not merely spawned).
    active_workers: Arc<AtomicU32>,
    /// Upper bound on spawned worker threads (clamped to 1..=256 in `new`).
    max_workers: u32,
    /// Handles of spawned workers; never joined (threads detach on drop).
    worker_handles: Vec<thread::JoinHandle<()>>,
    /// Whether concatenated XZ streams are permitted.
    /// NOTE(review): stored but not read by any logic visible in this file —
    /// confirm intended use (scan_blocks only parses a single stream).
    allow_multiple_streams: bool,
}
66
impl<R: Read + Seek> XzReaderMt<R> {
    /// Creates a new multi-threaded XZ reader.
    ///
    /// - `inner`: The reader to read compressed data from. Must implement Seek.
    /// - `allow_multiple_streams`: Whether to allow reading multiple XZ streams concatenated together.
    /// - `num_workers`: The maximum number of worker threads for decompression. Currently capped at 256 Threads.
    ///
    /// # Errors
    ///
    /// Fails if the stream header, footer, or index cannot be parsed, or if
    /// the file contains no blocks.
    pub fn new(inner: R, allow_multiple_streams: bool, num_workers: u32) -> io::Result<Self> {
        let max_workers = num_workers.clamp(1, 256);

        let work_queue = WorkStealingQueue::new();
        // Capacity 1 provides backpressure: a worker blocks in `send` until
        // the consumer has taken the previous result.
        let (result_tx, result_rx) = mpsc::sync_channel::<ResultUnit>(1);
        let shutdown_flag = Arc::new(AtomicBool::new(false));
        let error_store = Arc::new(Mutex::new(None));
        let active_workers = Arc::new(AtomicU32::new(0));

        let mut reader = Self {
            inner: Some(inner),
            blocks: Vec::new(),
            // Placeholder; the real value is set by `scan_blocks` below.
            check_type: CheckType::None,
            result_rx,
            result_tx,
            next_sequence_to_dispatch: 0,
            next_sequence_to_return: 0,
            last_sequence_id: None,
            out_of_order_chunks: BTreeMap::new(),
            current_chunk: Cursor::new(Vec::new()),
            shutdown_flag,
            error_store,
            state: State::Dispatching,
            work_queue,
            active_workers,
            max_workers,
            worker_handles: Vec::new(),
            allow_multiple_streams,
        };

        // Locate every block up front so they can later be dispatched to
        // workers independently by file offset.
        reader.scan_blocks()?;

        Ok(reader)
    }

    /// Scan the XZ file to collect information about all blocks.
    /// This reads the index at the end of the file to efficiently locate block boundaries.
    ///
    /// NOTE(review): only a single stream's header/footer/index is parsed
    /// here even though `new` accepts `allow_multiple_streams` — confirm
    /// whether concatenated streams are supported elsewhere.
    fn scan_blocks(&mut self) -> io::Result<()> {
        // Take ownership of the reader for the duration of the scan.
        let mut reader = self.inner.take().expect("inner reader not set");

        let stream_header = StreamHeader::parse(&mut reader)?;
        self.check_type = stream_header.check_type;

        // Blocks start immediately after the 12-byte stream header.
        let header_end_pos = reader.stream_position()?;

        let file_size = reader.seek(SeekFrom::End(0))?;

        // Minimum XZ file: 12 byte header + 12 byte footer + 8 byte minimum index.
        if file_size < 32 {
            return Err(error_invalid_data(
                "File too small to contain a valid XZ stream",
            ));
        }

        // The 12-byte stream footer sits at the very end of the file.
        reader.seek(SeekFrom::End(-12))?;

        let stream_footer = StreamFooter::parse(&mut reader)?;

        // Stream flags are two bytes: a null byte followed by the check type;
        // header and footer must agree.
        let header_flags = [0, self.check_type as u8];

        if stream_footer.stream_flags != header_flags {
            return Err(error_invalid_data(
                "stream header and footer flags mismatch",
            ));
        }

        // Now read the index using backward size.
        // The footer stores `(real_size / 4) - 1`; invert that here.
        let index_size = (stream_footer.backward_size + 1) * 4;
        // NOTE(review): a corrupt `backward_size` could make this subtraction
        // underflow (panic in debug builds) — consider validating
        // `index_size as u64 + 12 <= file_size` first.
        let index_start_pos = file_size - 12 - index_size as u64;

        reader.seek(SeekFrom::Start(index_start_pos))?;

        // Parse the index.
        let index_indicator = reader.read_u8()?;

        if index_indicator != 0 {
            return Err(error_invalid_data("invalid XZ index indicator"));
        }

        let index = Index::parse(&mut reader)?;

        // Blocks are laid out back to back after the stream header;
        // reconstruct each block's offset from the index records.
        let mut block_start_pos = header_end_pos;

        for record in &index.records {
            self.blocks.push(XzBlock {
                start_pos: block_start_pos,
                unpadded_size: record.unpadded_size,
                uncompressed_size: record.uncompressed_size,
            });

            // Each block is zero-padded to a multiple of four bytes on disk.
            let padding_needed = (4 - (record.unpadded_size % 4)) % 4;
            let actual_block_size = record.unpadded_size + padding_needed;

            block_start_pos += actual_block_size;
        }

        if self.blocks.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "No valid XZ blocks found",
            ));
        }

        // Put the reader back for later block dispatching.
        self.inner = Some(reader);
        Ok(())
    }

    /// Spawns one additional worker thread wired to the shared work queue,
    /// result channel, and shutdown/error state.
    fn spawn_worker_thread(&mut self) {
        let worker_handle = self.work_queue.worker();
        let result_tx = self.result_tx.clone();
        let shutdown_flag = Arc::clone(&self.shutdown_flag);
        let error_store = Arc::clone(&self.error_store);
        let active_workers = Arc::clone(&self.active_workers);
        let check_type = self.check_type;

        let handle = thread::spawn(move || {
            worker_thread_logic(
                worker_handle,
                result_tx,
                check_type,
                shutdown_flag,
                error_store,
                active_workers,
            );
        });

        self.worker_handles.push(handle);
    }

    /// Get the count of XZ blocks found in the file.
    pub fn block_count(&self) -> usize {
        self.blocks.len()
    }

    /// Reads the next not-yet-dispatched block from the source and pushes it
    /// onto the work queue, spawning an extra worker when all spawned workers
    /// are busy and the parallelism cap has not been reached.
    ///
    /// Returns `Ok(true)` if a block was dispatched and `Ok(false)` when all
    /// blocks have already been dispatched.
    fn dispatch_next_block(&mut self) -> io::Result<bool> {
        // The dispatch sequence number doubles as the index into `blocks`.
        let block_index = self.next_sequence_to_dispatch as usize;

        if block_index >= self.blocks.len() {
            // No more blocks to dispatch.
            return Ok(false);
        }

        let block = &self.blocks[block_index];
        let mut reader = self.inner.take().expect("inner reader not set");

        reader.seek(SeekFrom::Start(block.start_pos))?;

        // Read the whole padded block; padding restores 4-byte alignment.
        let padding_needed = (4 - (block.unpadded_size % 4)) % 4;
        let total_block_size = block.unpadded_size + padding_needed;

        let mut block_data = vec![0u8; total_block_size as usize];
        reader.read_exact(&mut block_data)?;

        self.inner = Some(reader);

        if !self
            .work_queue
            .push((self.next_sequence_to_dispatch, block_data))
        {
            // Queue is closed, this indicates shutdown.
            self.state = State::Error;
            set_error(
                io::Error::new(io::ErrorKind::BrokenPipe, "Worker threads have shut down"),
                &self.error_store,
                &self.shutdown_flag,
            );
            return Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "Worker threads have shut down",
            ));
        }

        // We spawn a new thread if we have work queued, no available workers, and haven't reached
        // the maximal allowed parallelism yet.
        let spawned_workers = self.worker_handles.len() as u32;
        let active_workers = self.active_workers.load(Ordering::Acquire);
        let queue_len = self.work_queue.len();

        if queue_len > 0 && active_workers == spawned_workers && spawned_workers < self.max_workers
        {
            self.spawn_worker_thread();
        }

        self.next_sequence_to_dispatch += 1;
        Ok(true)
    }

    /// Returns the next decompressed chunk in sequence order, or `Ok(None)`
    /// once every block has been returned.
    ///
    /// Drives the `Dispatching` → `Draining` → `Finished` state machine:
    /// while dispatching it interleaves non-blocking result collection with
    /// feeding the work queue; while draining it only waits for outstanding
    /// results. Results that arrive ahead of `next_sequence_to_return` are
    /// parked in `out_of_order_chunks`.
    fn get_next_uncompressed_chunk(&mut self) -> io::Result<Option<Vec<u8>>> {
        loop {
            // Always check for already-received chunks first.
            if let Some(result) = self
                .out_of_order_chunks
                .remove(&self.next_sequence_to_return)
            {
                self.next_sequence_to_return += 1;
                return Ok(Some(result));
            }

            // Check for a globally stored error.
            if let Some(err) = self.error_store.lock().unwrap().take() {
                self.state = State::Error;
                return Err(err);
            }

            match self.state {
                State::Dispatching => {
                    // First, always try to receive a result without blocking.
                    // This keeps the pipeline moving and avoids unnecessary blocking.
                    match self.result_rx.try_recv() {
                        Ok((seq, result)) => {
                            if seq == self.next_sequence_to_return {
                                self.next_sequence_to_return += 1;
                                return Ok(Some(result));
                            } else {
                                self.out_of_order_chunks.insert(seq, result);
                                continue; // Loop again to check the out_of_order_chunks.
                            }
                        }
                        Err(mpsc::TryRecvError::Disconnected) => {
                            // All workers are done.
                            self.state = State::Draining;
                            continue;
                        }
                        Err(mpsc::TryRecvError::Empty) => {
                            // No results are ready. Now, we can consider dispatching more work.
                        }
                    }

                    // If the work queue has capacity, try to read more from the source.
                    if self.work_queue.is_empty() {
                        match self.dispatch_next_block() {
                            Ok(true) => {
                                // Successfully read and dispatched a block, loop to continue.
                                continue;
                            }
                            Ok(false) => {
                                // No more blocks to dispatch.
                                // Set the last sequence ID and transition to draining.
                                self.last_sequence_id =
                                    Some(self.next_sequence_to_dispatch.saturating_sub(1));
                                self.state = State::Draining;
                                continue;
                            }
                            Err(error) => {
                                set_error(error, &self.error_store, &self.shutdown_flag);
                                self.state = State::Error;
                                continue;
                            }
                        }
                    }

                    // Now we MUST wait for a result to make progress.
                    match self.result_rx.recv() {
                        Ok((seq, result)) => {
                            if seq == self.next_sequence_to_return {
                                self.next_sequence_to_return += 1;
                                return Ok(Some(result));
                            } else {
                                self.out_of_order_chunks.insert(seq, result);
                                // We've made progress, loop to check the out_of_order_chunks.
                                continue;
                            }
                        }
                        Err(_) => {
                            // All workers are done.
                            self.state = State::Draining;
                        }
                    }
                }
                State::Draining => {
                    // Every sequence up to and including `last_seq` has been
                    // returned, so the stream is complete.
                    if let Some(last_seq) = self.last_sequence_id {
                        if self.next_sequence_to_return > last_seq {
                            self.state = State::Finished;
                            continue;
                        }
                    }

                    // In Draining state, we only wait for results.
                    match self.result_rx.recv() {
                        Ok((seq, result)) => {
                            if seq == self.next_sequence_to_return {
                                self.next_sequence_to_return += 1;
                                return Ok(Some(result));
                            } else {
                                self.out_of_order_chunks.insert(seq, result);
                            }
                        }
                        Err(_) => {
                            // All workers finished, and channel is empty. We are done.
                            self.state = State::Finished;
                        }
                    }
                }
                State::Finished => {
                    return Ok(None);
                }
                State::Error => {
                    // The error was already logged, now we just propagate it.
                    return Err(self.error_store.lock().unwrap().take().unwrap_or_else(|| {
                        io::Error::other("decompression failed with an unknown error")
                    }));
                }
            }
        }
    }
}
379
380/// The logic for a single worker thread.
/// The logic for a single worker thread.
///
/// Repeatedly steals compressed blocks from the shared queue, decompresses
/// them, and sends `(sequence, data)` results back over `result_tx`. The
/// loop exits when the shutdown flag is set, when the queue is closed and
/// empty, when the result receiver is dropped, or after a decompression
/// error has been recorded.
fn worker_thread_logic(
    worker_handle: WorkerHandle<WorkUnit>,
    result_tx: SyncSender<ResultUnit>,
    check_type: CheckType,
    shutdown_flag: Arc<AtomicBool>,
    error_store: Arc<Mutex<Option<io::Error>>>,
    active_workers: Arc<AtomicU32>,
) {
    while !shutdown_flag.load(Ordering::Acquire) {
        let (seq, work_unit_data) = match worker_handle.steal() {
            Some(work) => {
                // Mark this worker busy; the dispatcher compares this counter
                // against the spawned-thread count to decide whether to spawn
                // more workers.
                active_workers.fetch_add(1, Ordering::Release);
                work
            }
            None => {
                // No more work available and queue is closed
                break;
            }
        };

        let result = decompress_xz_block(work_unit_data, check_type);

        match result {
            Ok(decompressed_data) => {
                // `send` blocks until the consumer drains the channel; Err
                // means the reader (and its Receiver) was dropped, so stop.
                if result_tx.send((seq, decompressed_data)).is_err() {
                    active_workers.fetch_sub(1, Ordering::Release);
                    return;
                }
            }
            Err(error) => {
                active_workers.fetch_sub(1, Ordering::Release);
                // Record the first error; `set_error` also receives the
                // shutdown flag — presumably raising it so siblings stop
                // (confirm in `set_error`).
                set_error(error, &error_store, &shutdown_flag);
                return;
            }
        }

        active_workers.fetch_sub(1, Ordering::Release);
    }
}
420
421/// Decompresses a single XZ block by parsing the header and applying filters directly.
422fn decompress_xz_block(block_data: Vec<u8>, check_type: CheckType) -> io::Result<Vec<u8>> {
423    let (filters, properties, header_size) = BlockHeader::parse_from_slice(&block_data)?;
424
425    let checksum_size = check_type.checksum_size() as usize;
426    let padding_in_block_data = (4 - (block_data.len() % 4)) % 4;
427    let unpadded_size_in_data = block_data.len() - padding_in_block_data;
428    let compressed_data_end = unpadded_size_in_data - checksum_size;
429
430    if compressed_data_end <= header_size {
431        return Err(error_invalid_data(
432            "Block data too short for compressed content",
433        ));
434    }
435
436    let compressed_data = block_data[header_size..compressed_data_end].to_vec();
437    let mut compressed_data = compressed_data.as_slice();
438
439    let base_reader: Box<dyn Read> = Box::new(&mut compressed_data);
440    let mut chain_reader = create_filter_chain(base_reader, &filters, &properties);
441
442    let mut decompressed_data = Vec::new();
443    chain_reader.read_to_end(&mut decompressed_data)?;
444
445    Ok(decompressed_data)
446}
447
448impl<R: Read + Seek> Read for XzReaderMt<R> {
449    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
450        if buf.is_empty() {
451            return Ok(0);
452        }
453
454        let bytes_read = self.current_chunk.read(buf)?;
455
456        if bytes_read > 0 {
457            return Ok(bytes_read);
458        }
459
460        let chunk_data = self.get_next_uncompressed_chunk()?;
461
462        let Some(chunk_data) = chunk_data else {
463            // This is the clean end of the stream.
464            return Ok(0);
465        };
466
467        self.current_chunk = Cursor::new(chunk_data);
468
469        // Recursive call to read the new chunk data.
470        self.read(buf)
471    }
472}
473
impl<R: Read + Seek> Drop for XzReaderMt<R> {
    // Signals shutdown and closes the work queue so workers can exit.
    fn drop(&mut self) {
        self.shutdown_flag.store(true, Ordering::Release);
        self.work_queue.close();
        // Worker threads will exit when the work queue is closed.
        // JoinHandles will be dropped without joining, detaching the threads;
        // this is fine since we set the shutdown flag so they stop promptly.
        // NOTE(review): a worker blocked in `result_tx.send` is unblocked when
        // `result_rx` is dropped along with `self` — confirm there is no path
        // where a detached worker touches freed state (workers own their data).
    }
}