use std::{
    collections::BTreeMap,
    io,
    io::{Cursor, Read},
    sync::{
        atomic::{AtomicBool, AtomicU32, Ordering},
        mpsc::{self, Receiver, SyncSender},
        Arc, Mutex,
    },
    thread,
};

use crate::{
    set_error,
    work_queue::{WorkStealingQueue, WorkerHandle},
    Lzma2Reader,
};

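/// A unit of work for a worker thread: the sequence number of the work unit
/// plus the raw bytes of one or more LZMA2 chunks, terminated with an
/// end-of-stream marker (`0x00`).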
type WorkUnit = (u64, Vec<u8>);

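/// A finished result: the sequence number of the work unit plus its
/// decompressed bytes.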
type ResultUnit = (u64, Vec<u8>);

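/// Lifecycle of the reader as it moves through the compressed stream.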
enum State {
    /// Reading compressed chunks from `inner` and dispatching them to workers.
    Reading,
    /// The input is exhausted; waiting for outstanding results from workers.
    Draining,
    /// All chunks have been returned to the caller.
    Finished,
    /// A fatal error occurred; details live in `error_store`.
    Error,
}

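/// A multithreaded LZMA2 reader.
///
/// The reader splits the compressed stream into independently decompressible
/// work units at dictionary-reset boundaries, fans them out to a pool of
/// worker threads, and reassembles the decompressed chunks in sequence order.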
pub struct Lzma2ReaderMt<R: Read> {
    inner: R,
    result_rx: Receiver<ResultUnit>,
    result_tx: SyncSender<ResultUnit>,
    /// Chunk bytes accumulated for the work unit currently being assembled.
    current_work_unit: Vec<u8>,
    next_sequence_to_dispatch: u64,
    next_sequence_to_return: u64,
    /// Sequence number of the final work unit, known once the input ends.
    last_sequence_id: Option<u64>,
    /// Results that arrived ahead of `next_sequence_to_return`.
    out_of_order_chunks: BTreeMap<u64, Vec<u8>>,
    /// The decompressed chunk currently being served to the caller.
    current_chunk: Cursor<Vec<u8>>,
    shutdown_flag: Arc<AtomicBool>,
    error_store: Arc<Mutex<Option<io::Error>>>,
    state: State,
    work_queue: WorkStealingQueue<WorkUnit>,
    active_workers: Arc<AtomicU32>,
    max_workers: u32,
    dict_size: u32,
    preset_dict: Option<Arc<Vec<u8>>>,
    worker_handles: Vec<thread::JoinHandle<()>>,
}

impl<R: Read> Lzma2ReaderMt<R> {
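    /// Creates a new multithreaded LZMA2 reader over `inner`.
    ///
    /// `num_workers` is clamped to `1..=256`; one worker is spawned up front
    /// and more are added on demand, up to the clamped maximum.
    ///
    /// # Example
    ///
    /// A minimal sketch; `compressed` is assumed to hold a valid raw LZMA2
    /// stream, and the crate path, input file, and dictionary size are
    /// placeholders:
    ///
    /// ```no_run
    /// # use lzma_rs_mt::Lzma2ReaderMt; // hypothetical crate path
    /// use std::io::Read;
    ///
    /// let compressed: Vec<u8> = std::fs::read("data.lzma2").unwrap();
    /// let mut reader = Lzma2ReaderMt::new(compressed.as_slice(), 1 << 24, None, 4);
    /// let mut decompressed = Vec::new();
    /// reader.read_to_end(&mut decompressed).unwrap();
    /// ```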
    pub fn new(inner: R, dict_size: u32, preset_dict: Option<&[u8]>, num_workers: u32) -> Self {
        let max_workers = num_workers.clamp(1, 256);

        let work_queue = WorkStealingQueue::new();
        // Capacity 1 keeps backpressure on the workers: at most one finished
        // result sits in the channel while the caller catches up.
        let (result_tx, result_rx) = mpsc::sync_channel::<ResultUnit>(1);
        let shutdown_flag = Arc::new(AtomicBool::new(false));
        let error_store = Arc::new(Mutex::new(None));
        let active_workers = Arc::new(AtomicU32::new(0));
        let preset_dict = preset_dict.map(|s| s.to_vec()).map(Arc::new);

        let mut reader = Self {
            inner,
            result_rx,
            result_tx,
            current_work_unit: Vec::with_capacity(1024 * 1024),
            next_sequence_to_dispatch: 0,
            next_sequence_to_return: 0,
            last_sequence_id: None,
            out_of_order_chunks: BTreeMap::new(),
            current_chunk: Cursor::new(Vec::new()),
            shutdown_flag,
            error_store,
            state: State::Reading,
            work_queue,
            active_workers,
            max_workers,
            dict_size,
            preset_dict,
            worker_handles: Vec::new(),
        };

        reader.spawn_worker_thread();

        reader
    }

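    /// Spawns one more worker thread that steals work units from the shared queue.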
    fn spawn_worker_thread(&mut self) {
        let worker_handle = self.work_queue.worker();
        let result_tx = self.result_tx.clone();
        let shutdown_flag = Arc::clone(&self.shutdown_flag);
        let error_store = Arc::clone(&self.error_store);
        let active_workers = Arc::clone(&self.active_workers);
        let preset_dict = self.preset_dict.clone();
        let dict_size = self.dict_size;

        let handle = thread::spawn(move || {
            worker_thread_logic(
                worker_handle,
                result_tx,
                dict_size,
                preset_dict,
                shutdown_flag,
                error_store,
                active_workers,
            );
        });

        self.worker_handles.push(handle);
    }

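    /// Returns the number of decompressed chunks handed back to the caller so far.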
    pub fn chunk_count(&self) -> u64 {
        self.next_sequence_to_return
    }

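    /// Reads one LZMA2 chunk from `inner` into the current work unit,
    /// dispatching the accumulated unit first whenever the new chunk starts at
    /// a dictionary reset (and can therefore be decompressed independently).
    ///
    /// Returns `Ok(false)` once the end-of-stream marker or EOF is reached.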
    fn read_and_dispatch_chunk(&mut self) -> io::Result<bool> {
        let mut control_buf = [0u8; 1];
        match self.inner.read_exact(&mut control_buf) {
            Ok(_) => (),
            Err(error) if error.kind() == io::ErrorKind::UnexpectedEof => {
                // EOF before an end-of-stream marker: treat as end of input.
                return Ok(false);
            }
            Err(error) => return Err(error),
        }

        let control = control_buf[0];

        if control == 0x00 {
            // End-of-stream marker: terminate and dispatch the final work unit.
            self.current_work_unit.push(0x00);
            self.send_work_unit();
            return Ok(false);
        }

        // Chunks with a dictionary reset (compressed with control >= 0xE0, or
        // uncompressed with control == 0x01) do not depend on earlier data, so
        // they start a new work unit.
        let is_independent_chunk = control >= 0xE0 || control == 0x01;

        if is_independent_chunk && !self.current_work_unit.is_empty() {
            self.current_work_unit.push(0x00);
            self.send_work_unit();
        }

        self.current_work_unit.push(control);

        let chunk_data_size = if control >= 0x80 {
            // Compressed chunk: two bytes of unpacked size, two bytes of
            // compressed size, plus a properties byte when control >= 0xC0.
            let header_len = if control >= 0xC0 { 5 } else { 4 };
            let mut header_buf = [0; 5];
            self.inner.read_exact(&mut header_buf[..header_len])?;
            self.current_work_unit
                .extend_from_slice(&header_buf[..header_len]);
            // The compressed size field is stored minus one.
            u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize + 1
        } else if control == 0x01 || control == 0x02 {
            // Uncompressed chunk: a two-byte size field, stored minus one.
            let mut size_buf = [0u8; 2];
            self.inner.read_exact(&mut size_buf)?;
            self.current_work_unit.extend_from_slice(&size_buf);
            u16::from_be_bytes(size_buf) as usize + 1
        } else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("invalid LZMA2 control byte: {control:X}"),
            ));
        };

        if chunk_data_size > 0 {
            let start_len = self.current_work_unit.len();
            self.current_work_unit
                .resize(start_len + chunk_data_size, 0);
            self.inner
                .read_exact(&mut self.current_work_unit[start_len..])?;
        }

        Ok(true)
    }

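    /// Dispatches the accumulated work unit to the worker pool, growing the
    /// pool by one thread when every spawned worker is already busy.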
    fn send_work_unit(&mut self) {
        if self.current_work_unit.is_empty() {
            return;
        }

        let work_unit =
            core::mem::replace(&mut self.current_work_unit, Vec::with_capacity(1024 * 1024));

        if !self
            .work_queue
            .push((self.next_sequence_to_dispatch, work_unit))
        {
            self.state = State::Error;
            set_error(
                io::Error::new(io::ErrorKind::BrokenPipe, "worker threads have shut down"),
                &self.error_store,
                &self.shutdown_flag,
            );
        }

        // Scale the pool up only when queued work exists that no idle worker
        // could pick up.
        let spawned_workers = self.worker_handles.len() as u32;
        let active_workers = self.active_workers.load(Ordering::Acquire);
        let queue_len = self.work_queue.len();

        if queue_len > 0 && active_workers == spawned_workers && spawned_workers < self.max_workers
        {
            self.spawn_worker_thread();
        }

        self.next_sequence_to_dispatch += 1;
    }

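    /// Returns the next decompressed chunk in sequence order, or `None` once
    /// the stream is finished.
    ///
    /// Runs the reader's state machine: in `Reading` it interleaves polling
    /// for results with reading and dispatching more input; in `Draining` it
    /// only waits for outstanding results.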
    fn get_next_uncompressed_chunk(&mut self) -> io::Result<Option<Vec<u8>>> {
        loop {
            // Serve a buffered result if it is the next one in sequence.
            if let Some(result) = self
                .out_of_order_chunks
                .remove(&self.next_sequence_to_return)
            {
                self.next_sequence_to_return += 1;
                return Ok(Some(result));
            }

            if let Some(err) = self.error_store.lock().unwrap().take() {
                self.state = State::Error;
                return Err(err);
            }

            match self.state {
                State::Reading => {
                    // Drain any result that is already available without blocking.
                    match self.result_rx.try_recv() {
                        Ok((seq, result)) => {
                            if seq == self.next_sequence_to_return {
                                self.next_sequence_to_return += 1;
                                return Ok(Some(result));
                            } else {
                                self.out_of_order_chunks.insert(seq, result);
                                continue;
                            }
                        }
                        Err(mpsc::TryRecvError::Disconnected) => {
                            self.state = State::Draining;
                            continue;
                        }
                        Err(mpsc::TryRecvError::Empty) => {}
                    }

                    // Keep the workers fed before blocking on a result.
                    if self.work_queue.is_empty() {
                        match self.read_and_dispatch_chunk() {
                            Ok(true) => {
                                continue;
                            }
                            Ok(false) => {
                                // Input exhausted: flush the last work unit and
                                // remember the final sequence number.
                                self.send_work_unit();
                                self.last_sequence_id =
                                    Some(self.next_sequence_to_dispatch.saturating_sub(1));
                                self.state = State::Draining;
                                continue;
                            }
                            Err(error) => {
                                set_error(error, &self.error_store, &self.shutdown_flag);
                                self.state = State::Error;
                                continue;
                            }
                        }
                    }

                    match self.result_rx.recv() {
                        Ok((seq, result)) => {
                            if seq == self.next_sequence_to_return {
                                self.next_sequence_to_return += 1;
                                return Ok(Some(result));
                            } else {
                                self.out_of_order_chunks.insert(seq, result);
                                continue;
                            }
                        }
                        Err(_) => {
                            self.state = State::Draining;
                        }
                    }
                }
                State::Draining => {
                    if let Some(last_seq) = self.last_sequence_id {
                        if self.next_sequence_to_return > last_seq {
                            self.state = State::Finished;
                            continue;
                        }
                    }

                    match self.result_rx.recv() {
                        Ok((seq, result)) => {
                            if seq == self.next_sequence_to_return {
                                self.next_sequence_to_return += 1;
                                return Ok(Some(result));
                            } else {
                                self.out_of_order_chunks.insert(seq, result);
                            }
                        }
                        Err(_) => {
                            self.state = State::Finished;
                        }
                    }
                }
                State::Finished => {
                    return Ok(None);
                }
                State::Error => {
                    return Err(self.error_store.lock().unwrap().take().unwrap_or_else(|| {
                        io::Error::other("decompression failed with an unknown error")
                    }));
                }
            }
        }
    }
}

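/// The main loop of a worker thread: steal a work unit, decompress it with a
/// single-threaded [`Lzma2Reader`], and send the result back tagged with its
/// sequence number. The `active_workers` counter is held high only while a
/// unit is being processed, so the reader can tell busy workers from idle ones.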
fn worker_thread_logic(
    worker_handle: WorkerHandle<WorkUnit>,
    result_tx: SyncSender<ResultUnit>,
    dict_size: u32,
    preset_dict: Option<Arc<Vec<u8>>>,
    shutdown_flag: Arc<AtomicBool>,
    error_store: Arc<Mutex<Option<io::Error>>>,
    active_workers: Arc<AtomicU32>,
) {
    while !shutdown_flag.load(Ordering::Acquire) {
        let (seq, work_unit_data) = match worker_handle.steal() {
            Some(work) => {
                active_workers.fetch_add(1, Ordering::Release);
                work
            }
            None => {
                // No more work will arrive; exit the loop.
                break;
            }
        };

        // Each work unit is a self-contained LZMA2 stream, so a fresh
        // single-threaded reader can decompress it in isolation.
        let mut reader = Lzma2Reader::new(
            work_unit_data.as_slice(),
            dict_size,
            preset_dict.as_deref().map(|v| v.as_slice()),
        );

        let mut decompressed_data = Vec::with_capacity(work_unit_data.len());
        let result = match reader.read_to_end(&mut decompressed_data) {
            Ok(_) => decompressed_data,
            Err(error) => {
                active_workers.fetch_sub(1, Ordering::Release);
                set_error(error, &error_store, &shutdown_flag);
                return;
            }
        };

        if result_tx.send((seq, result)).is_err() {
            // The receiver is gone; the reader has been dropped.
            active_workers.fetch_sub(1, Ordering::Release);
            return;
        }

        active_workers.fetch_sub(1, Ordering::Release);
    }
}

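// Serves decompressed bytes from the current chunk, pulling the next in-order
// chunk from the pipeline whenever the current one runs dry.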
impl<R: Read> Read for Lzma2ReaderMt<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        let bytes_read = self.current_chunk.read(buf)?;

        if bytes_read > 0 {
            return Ok(bytes_read);
        }

        let chunk_data = self.get_next_uncompressed_chunk()?;

        let Some(chunk_data) = chunk_data else {
            // Stream finished: signal EOF.
            return Ok(0);
        };

        self.current_chunk = Cursor::new(chunk_data);

        // Retry with the freshly fetched chunk.
        self.read(buf)
    }
}

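// Dropping the reader signals shutdown and closes the work queue; workers
// notice and exit on their own, so their handles are not joined here.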
impl<R: Read> Drop for Lzma2ReaderMt<R> {
    fn drop(&mut self) {
        self.shutdown_flag.store(true, Ordering::Release);
        self.work_queue.close();
    }
}