1use std::{
2 collections::BTreeMap,
3 io::{self, Cursor, Seek, SeekFrom},
4 sync::{
5 atomic::{AtomicBool, AtomicU32, Ordering},
6 mpsc::{self, Receiver, SyncSender},
7 Arc, Mutex,
8 },
9 thread,
10};
11
12use super::{create_filter_chain, BlockHeader, CheckType, Index, StreamFooter, StreamHeader};
13use crate::{
14 error_invalid_data, set_error,
15 work_queue::{WorkStealingQueue, WorkerHandle},
16 ByteReader, Read,
17};
18
/// Location and size metadata for one XZ block, reconstructed from the
/// stream index during `scan_blocks`.
#[derive(Debug, Clone)]
struct XzBlock {
    // Absolute offset in the underlying reader where the block starts.
    start_pos: u64,
    // Block size without the trailing 4-byte alignment padding, as
    // recorded in the index.
    unpadded_size: u64,
    // Decompressed size of the block, as recorded in the index.
    uncompressed_size: u64,
}
25
/// Item handed to a worker: (block sequence number, raw block bytes
/// including alignment padding).
type WorkUnit = (u64, Vec<u8>);

/// Item sent back by a worker: (block sequence number, decompressed bytes).
type ResultUnit = (u64, Vec<u8>);
33
/// Reader-side state machine driven by `get_next_uncompressed_chunk`.
enum State {
    // Blocks remain to be pushed onto the work queue.
    Dispatching,
    // Every block has been dispatched; waiting for outstanding results.
    Draining,
    // All results delivered; further reads yield end of stream.
    Finished,
    // A failure occurred; the error lives in `error_store`.
    Error,
}
44
/// Multi-threaded XZ reader: scans the stream index up front, then
/// decompresses blocks on a pool of worker threads while returning the
/// output in stream order through the `Read` impl.
pub struct XzReaderMt<R: Read + Seek> {
    // Underlying source; temporarily `take()`n while seeking/reading so
    // methods can use it without borrowing all of `self`.
    inner: Option<R>,
    // Per-block offsets/sizes recovered from the stream index.
    blocks: Vec<XzBlock>,
    // Integrity-check type from the stream header.
    check_type: CheckType,
    // Receiving end for finished (sequence, data) results.
    result_rx: Receiver<ResultUnit>,
    // Kept so new workers can be handed a clone of the sender.
    result_tx: SyncSender<ResultUnit>,
    // Sequence id of the next block to push onto the work queue.
    next_sequence_to_dispatch: u64,
    // Sequence id the caller must receive next (in-order delivery).
    next_sequence_to_return: u64,
    // Highest sequence id dispatched, set once dispatching finishes.
    last_sequence_id: Option<u64>,
    // Results that arrived ahead of their turn, keyed by sequence id.
    out_of_order_chunks: BTreeMap<u64, Vec<u8>>,
    // Decompressed chunk currently being drained by `read`.
    current_chunk: Cursor<Vec<u8>>,
    // Signals workers to stop; also set when an error is recorded.
    shutdown_flag: Arc<AtomicBool>,
    // First error reported by any worker (or the dispatcher).
    error_store: Arc<Mutex<Option<io::Error>>>,
    state: State,
    // Queue the workers steal compressed blocks from.
    work_queue: WorkStealingQueue<WorkUnit>,
    // Count of workers currently processing a block; used to decide
    // when to grow the pool.
    active_workers: Arc<AtomicU32>,
    // Upper bound on spawned workers (clamped in `new`).
    max_workers: u32,
    worker_handles: Vec<thread::JoinHandle<()>>,
    // NOTE(review): stored but never read in this file — multi-stream
    // support presumably lives elsewhere or is unimplemented; verify.
    allow_multiple_streams: bool,
}
66
impl<R: Read + Seek> XzReaderMt<R> {
    /// Creates a multi-threaded XZ reader over `inner`.
    ///
    /// `num_workers` is clamped to `1..=256`; workers are spawned lazily
    /// as blocks are dispatched, not here. The stream index is scanned
    /// immediately, so `inner` must support seeking.
    ///
    /// # Errors
    /// Propagates any parse/I-O failure from `scan_blocks`.
    pub fn new(inner: R, allow_multiple_streams: bool, num_workers: u32) -> io::Result<Self> {
        let max_workers = num_workers.clamp(1, 256);

        let work_queue = WorkStealingQueue::new();
        // Capacity 1: a worker blocks in `send` until the reader consumes
        // a result, bounding how many decompressed chunks sit in flight.
        let (result_tx, result_rx) = mpsc::sync_channel::<ResultUnit>(1);
        let shutdown_flag = Arc::new(AtomicBool::new(false));
        let error_store = Arc::new(Mutex::new(None));
        let active_workers = Arc::new(AtomicU32::new(0));

        let mut reader = Self {
            inner: Some(inner),
            blocks: Vec::new(),
            // Placeholder; overwritten with the real value in `scan_blocks`.
            check_type: CheckType::None,
            result_rx,
            result_tx,
            next_sequence_to_dispatch: 0,
            next_sequence_to_return: 0,
            last_sequence_id: None,
            out_of_order_chunks: BTreeMap::new(),
            current_chunk: Cursor::new(Vec::new()),
            shutdown_flag,
            error_store,
            state: State::Dispatching,
            work_queue,
            active_workers,
            max_workers,
            worker_handles: Vec::new(),
            allow_multiple_streams,
        };

        reader.scan_blocks()?;

        Ok(reader)
    }

    /// Parses the stream header, footer, and index, and records every
    /// block's absolute offset and sizes in `self.blocks`.
    ///
    /// # Errors
    /// Fails if the file is too small, the header/footer stream flags
    /// disagree, the index indicator byte is non-zero, or the index
    /// contains no blocks.
    fn scan_blocks(&mut self) -> io::Result<()> {
        let mut reader = self.inner.take().expect("inner reader not set");

        let stream_header = StreamHeader::parse(&mut reader)?;
        self.check_type = stream_header.check_type;

        // The first block starts immediately after the stream header.
        let header_end_pos = reader.stream_position()?;

        let file_size = reader.seek(SeekFrom::End(0))?;

        // Smallest stream this layout can represent (header + index + footer).
        if file_size < 32 {
            return Err(error_invalid_data(
                "File too small to contain a valid XZ stream",
            ));
        }

        // The stream footer occupies the last 12 bytes of the file.
        reader.seek(SeekFrom::End(-12))?;

        let stream_footer = StreamFooter::parse(&mut reader)?;

        // Header stream flags: a reserved 0x00 byte then the check id;
        // the footer must carry an identical pair.
        let header_flags = [0, self.check_type as u8];

        if stream_footer.stream_flags != header_flags {
            return Err(error_invalid_data(
                "stream header and footer flags mismatch",
            ));
        }

        // The footer stores the index size as (real_size / 4) - 1.
        let index_size = (stream_footer.backward_size + 1) * 4;
        let index_start_pos = file_size - 12 - index_size as u64;

        reader.seek(SeekFrom::Start(index_start_pos))?;

        // The index begins with a mandatory 0x00 indicator byte.
        let index_indicator = reader.read_u8()?;

        if index_indicator != 0 {
            return Err(error_invalid_data("invalid XZ index indicator"));
        }

        let index = Index::parse(&mut reader)?;

        // Walk the index records to recover each block's offset: blocks
        // are stored back to back, each padded to 4-byte alignment.
        let mut block_start_pos = header_end_pos;

        for record in &index.records {
            self.blocks.push(XzBlock {
                start_pos: block_start_pos,
                unpadded_size: record.unpadded_size,
                uncompressed_size: record.uncompressed_size,
            });

            let padding_needed = (4 - (record.unpadded_size % 4)) % 4;
            let actual_block_size = record.unpadded_size + padding_needed;

            block_start_pos += actual_block_size;
        }

        if self.blocks.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "No valid XZ blocks found",
            ));
        }

        self.inner = Some(reader);
        Ok(())
    }

    /// Spawns one worker thread that steals work units from the queue,
    /// decompresses them, and sends results back over the channel.
    fn spawn_worker_thread(&mut self) {
        let worker_handle = self.work_queue.worker();
        let result_tx = self.result_tx.clone();
        let shutdown_flag = Arc::clone(&self.shutdown_flag);
        let error_store = Arc::clone(&self.error_store);
        let active_workers = Arc::clone(&self.active_workers);
        let check_type = self.check_type;

        let handle = thread::spawn(move || {
            worker_thread_logic(
                worker_handle,
                result_tx,
                check_type,
                shutdown_flag,
                error_store,
                active_workers,
            );
        });

        self.worker_handles.push(handle);
    }

    /// Number of blocks discovered in the stream index.
    pub fn block_count(&self) -> usize {
        self.blocks.len()
    }

    /// Reads the next compressed block from the source and pushes it onto
    /// the work queue; returns `Ok(false)` once every block has been
    /// dispatched. Also grows the worker pool (up to `max_workers`) when
    /// all spawned workers are busy and work is still queued.
    ///
    /// # Errors
    /// I/O failure while seeking/reading, or `BrokenPipe` if the work
    /// queue has been closed.
    fn dispatch_next_block(&mut self) -> io::Result<bool> {
        let block_index = self.next_sequence_to_dispatch as usize;

        if block_index >= self.blocks.len() {
            return Ok(false);
        }

        let block = &self.blocks[block_index];
        // Temporarily take the reader so seek/read don't fight the
        // borrow on `self`.
        let mut reader = self.inner.take().expect("inner reader not set");

        reader.seek(SeekFrom::Start(block.start_pos))?;

        // Read the whole block including its 4-byte alignment padding.
        let padding_needed = (4 - (block.unpadded_size % 4)) % 4;
        let total_block_size = block.unpadded_size + padding_needed;

        let mut block_data = vec![0u8; total_block_size as usize];
        reader.read_exact(&mut block_data)?;

        self.inner = Some(reader);

        if !self
            .work_queue
            .push((self.next_sequence_to_dispatch, block_data))
        {
            // A rejected push means the queue was closed under us.
            self.state = State::Error;
            set_error(
                io::Error::new(io::ErrorKind::BrokenPipe, "Worker threads have shut down"),
                &self.error_store,
                &self.shutdown_flag,
            );
            return Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "Worker threads have shut down",
            ));
        }

        // Pool scaling: spawn another worker only when work is queued,
        // every spawned worker is currently busy, and we're under the cap.
        let spawned_workers = self.worker_handles.len() as u32;
        let active_workers = self.active_workers.load(Ordering::Acquire);
        let queue_len = self.work_queue.len();

        if queue_len > 0 && active_workers == spawned_workers && spawned_workers < self.max_workers
        {
            self.spawn_worker_thread();
        }

        self.next_sequence_to_dispatch += 1;
        Ok(true)
    }

    /// Returns the decompressed bytes of the next block in stream order,
    /// or `None` at end of stream.
    ///
    /// Drives the `State` machine: while `Dispatching` it interleaves
    /// non-blocking result polls with dispatching new blocks; results that
    /// arrive out of order are parked in `out_of_order_chunks` until their
    /// sequence id comes up.
    fn get_next_uncompressed_chunk(&mut self) -> io::Result<Option<Vec<u8>>> {
        loop {
            // A previously parked result may now be the one we need.
            if let Some(result) = self
                .out_of_order_chunks
                .remove(&self.next_sequence_to_return)
            {
                self.next_sequence_to_return += 1;
                return Ok(Some(result));
            }

            // Surface any failure a worker recorded since the last pass.
            if let Some(err) = self.error_store.lock().unwrap().take() {
                self.state = State::Error;
                return Err(err);
            }

            match self.state {
                State::Dispatching => {
                    // Drain an already-finished result without blocking.
                    match self.result_rx.try_recv() {
                        Ok((seq, result)) => {
                            if seq == self.next_sequence_to_return {
                                self.next_sequence_to_return += 1;
                                return Ok(Some(result));
                            } else {
                                // Arrived early; park until its turn.
                                self.out_of_order_chunks.insert(seq, result);
                                continue; }
                        }
                        Err(mpsc::TryRecvError::Disconnected) => {
                            // Every sender is gone: no more results coming.
                            self.state = State::Draining;
                            continue;
                        }
                        Err(mpsc::TryRecvError::Empty) => {
                        }
                    }

                    // Only dispatch once workers have drained the queue,
                    // keeping at most one undispatched block buffered.
                    if self.work_queue.is_empty() {
                        match self.dispatch_next_block() {
                            Ok(true) => {
                                continue;
                            }
                            Ok(false) => {
                                // Everything dispatched; remember the last
                                // sequence so Draining knows when to stop.
                                self.last_sequence_id =
                                    Some(self.next_sequence_to_dispatch.saturating_sub(1));
                                self.state = State::Draining;
                                continue;
                            }
                            Err(error) => {
                                set_error(error, &self.error_store, &self.shutdown_flag);
                                self.state = State::Error;
                                continue;
                            }
                        }
                    }

                    // Queue still holds work: wait for a result to free up
                    // the bounded channel before dispatching more.
                    match self.result_rx.recv() {
                        Ok((seq, result)) => {
                            if seq == self.next_sequence_to_return {
                                self.next_sequence_to_return += 1;
                                return Ok(Some(result));
                            } else {
                                self.out_of_order_chunks.insert(seq, result);
                                continue;
                            }
                        }
                        Err(_) => {
                            self.state = State::Draining;
                        }
                    }
                }
                State::Draining => {
                    // Done once the last dispatched sequence was returned.
                    if let Some(last_seq) = self.last_sequence_id {
                        if self.next_sequence_to_return > last_seq {
                            self.state = State::Finished;
                            continue;
                        }
                    }

                    match self.result_rx.recv() {
                        Ok((seq, result)) => {
                            if seq == self.next_sequence_to_return {
                                self.next_sequence_to_return += 1;
                                return Ok(Some(result));
                            } else {
                                self.out_of_order_chunks.insert(seq, result);
                            }
                        }
                        Err(_) => {
                            // Channel closed with nothing left outstanding.
                            self.state = State::Finished;
                        }
                    }
                }
                State::Finished => {
                    return Ok(None);
                }
                State::Error => {
                    // The stored error may already have been taken by an
                    // earlier return; fall back to a generic message.
                    return Err(self.error_store.lock().unwrap().take().unwrap_or_else(|| {
                        io::Error::other("decompression failed with an unknown error")
                    }));
                }
            }
        }
    }
}
379
380fn worker_thread_logic(
382 worker_handle: WorkerHandle<WorkUnit>,
383 result_tx: SyncSender<ResultUnit>,
384 check_type: CheckType,
385 shutdown_flag: Arc<AtomicBool>,
386 error_store: Arc<Mutex<Option<io::Error>>>,
387 active_workers: Arc<AtomicU32>,
388) {
389 while !shutdown_flag.load(Ordering::Acquire) {
390 let (seq, work_unit_data) = match worker_handle.steal() {
391 Some(work) => {
392 active_workers.fetch_add(1, Ordering::Release);
393 work
394 }
395 None => {
396 break;
398 }
399 };
400
401 let result = decompress_xz_block(work_unit_data, check_type);
402
403 match result {
404 Ok(decompressed_data) => {
405 if result_tx.send((seq, decompressed_data)).is_err() {
406 active_workers.fetch_sub(1, Ordering::Release);
407 return;
408 }
409 }
410 Err(error) => {
411 active_workers.fetch_sub(1, Ordering::Release);
412 set_error(error, &error_store, &shutdown_flag);
413 return;
414 }
415 }
416
417 active_workers.fetch_sub(1, Ordering::Release);
418 }
419}
420
421fn decompress_xz_block(block_data: Vec<u8>, check_type: CheckType) -> io::Result<Vec<u8>> {
423 let (filters, properties, header_size) = BlockHeader::parse_from_slice(&block_data)?;
424
425 let checksum_size = check_type.checksum_size() as usize;
426 let padding_in_block_data = (4 - (block_data.len() % 4)) % 4;
427 let unpadded_size_in_data = block_data.len() - padding_in_block_data;
428 let compressed_data_end = unpadded_size_in_data - checksum_size;
429
430 if compressed_data_end <= header_size {
431 return Err(error_invalid_data(
432 "Block data too short for compressed content",
433 ));
434 }
435
436 let compressed_data = block_data[header_size..compressed_data_end].to_vec();
437 let mut compressed_data = compressed_data.as_slice();
438
439 let base_reader: Box<dyn Read> = Box::new(&mut compressed_data);
440 let mut chain_reader = create_filter_chain(base_reader, &filters, &properties);
441
442 let mut decompressed_data = Vec::new();
443 chain_reader.read_to_end(&mut decompressed_data)?;
444
445 Ok(decompressed_data)
446}
447
448impl<R: Read + Seek> Read for XzReaderMt<R> {
449 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
450 if buf.is_empty() {
451 return Ok(0);
452 }
453
454 let bytes_read = self.current_chunk.read(buf)?;
455
456 if bytes_read > 0 {
457 return Ok(bytes_read);
458 }
459
460 let chunk_data = self.get_next_uncompressed_chunk()?;
461
462 let Some(chunk_data) = chunk_data else {
463 return Ok(0);
465 };
466
467 self.current_chunk = Cursor::new(chunk_data);
468
469 self.read(buf)
471 }
472}
473
impl<R: Read + Seek> Drop for XzReaderMt<R> {
    fn drop(&mut self) {
        // Ask busy workers to stop at their next shutdown check...
        self.shutdown_flag.store(true, Ordering::Release);
        // ...and close the queue so idle workers get `None` from `steal()`
        // and exit. Worker handles are not joined here; a worker blocked in
        // `result_tx.send` unblocks once `result_rx` drops with the rest of
        // the struct's fields — NOTE(review): threads may briefly outlive
        // the reader; confirm detaching is acceptable for callers.
        self.work_queue.close();
    }
}