1use std::{
2 io::{self, Write},
3 sync::{
4 atomic::{AtomicBool, AtomicU32, Ordering},
5 mpsc::SyncSender,
6 Arc, Mutex,
7 },
8};
9
10use super::{
11 add_padding, write_xz_block_header, write_xz_index, write_xz_stream_footer,
12 write_xz_stream_header, CheckType, ChecksumCalculator, FilterConfig, FilterType, IndexRecord,
13};
14use crate::{
15 enc::{Lzma2Writer, LzmaOptions},
16 error_invalid_input, set_error,
17 work_pool::{WorkPool, WorkPoolConfig},
18 work_queue::WorkerHandle,
19 AutoFinish, AutoFinisher, Lzma2Options, Result, XzOptions,
20};
21
/// Input handed to a worker thread: one block's worth of uncompressed bytes
/// plus everything needed to compress that block independently of the others.
#[derive(Debug, Clone)]
struct WorkUnit {
    /// Raw input bytes for a single XZ block.
    uncompressed_data: Vec<u8>,
    /// LZMA parameters the worker uses for this block's LZMA2 stream.
    lzma_options: LzmaOptions,
    /// Which integrity check to compute over the uncompressed data.
    check_type: CheckType,
}
29
/// Output produced by a worker thread for one block, sent back to the writer
/// which emits it (in block order) to the underlying stream.
#[derive(Debug)]
struct ResultUnit {
    /// The block's LZMA2-compressed payload.
    compressed_data: Vec<u8>,
    /// Finalized checksum bytes over the block's uncompressed data.
    checksum: Vec<u8>,
    /// Number of uncompressed bytes this block represents (for the index).
    uncompressed_size: u64,
}
37
/// Multi-threaded XZ compressor: splits incoming data into fixed-size blocks,
/// compresses each block on a worker-pool thread, and writes the finished
/// blocks, followed by the XZ index and stream footer, to `inner`.
pub struct XzWriterMt<W: Write> {
    /// Destination for the encoded XZ stream.
    inner: W,
    /// Stream-wide settings (filters, check type, LZMA options, block size).
    options: XzOptions,
    /// Accumulates uncompressed bytes for the block currently being filled.
    current_work_unit: Vec<u8>,
    /// Uncompressed size of each dispatched block; never below the LZMA
    /// dictionary size (see `new`).
    block_size: usize,
    /// Worker pool that compresses blocks in parallel and returns results.
    work_pool: WorkPool<WorkUnit, ResultUnit>,
    /// One record per finished block, later serialized into the stream index.
    index_records: Vec<IndexRecord>,
    /// Checksum state built from `options.check_type`.
    /// NOTE(review): never updated in this file — workers compute per-block
    /// checksums themselves; confirm whether this field is actually needed.
    checksum_calculator: ChecksumCalculator,
    /// True once the XZ stream header has been emitted to `inner`.
    header_written: bool,
    /// Running total of uncompressed bytes across all blocks written out.
    total_uncompressed_pos: u64,
}
50
impl<W: Write> XzWriterMt<W> {
    /// Creates a multi-threaded XZ writer over `inner` using `num_workers`
    /// compression threads.
    ///
    /// # Errors
    /// Fails if more than 3 pre-filters are configured, if no block size is
    /// set in `options`, or if the block size does not fit in `usize`.
    pub fn new(inner: W, options: XzOptions, num_workers: u32) -> Result<Self> {
        // The XZ format limits a block to 4 filters total; the last slot is
        // always taken by LZMA2 (appended in `write_block_header`).
        if options.filters.len() > 3 {
            return Err(error_invalid_input(
                "XZ allows only at most 3 pre-filters plus LZMA2",
            ));
        }

        // A block must be at least as large as the LZMA dictionary so the
        // dictionary can actually be filled within one block.
        let block_size = match options.block_size {
            None => return Err(error_invalid_input("block size must be set")),
            Some(block_size) => block_size.get().max(options.lzma_options.dict_size as u64),
        };

        let block_size = usize::try_from(block_size)
            .map_err(|_| error_invalid_input("block size bigger than usize"))?;

        let checksum_calculator = ChecksumCalculator::new(options.check_type);

        // Total number of blocks is unknown up front, so the pool is told the
        // maximum; `finish()` terminates dispatch explicitly.
        let num_work = u64::MAX;

        Ok(Self {
            inner,
            options,
            // Cap the initial allocation at 1 MiB; the buffer grows on demand
            // if blocks are larger.
            current_work_unit: Vec::with_capacity(block_size.min(1024 * 1024)),
            block_size,
            work_pool: WorkPool::new(
                WorkPoolConfig::new(num_workers, num_work),
                worker_thread_logic,
            ),
            index_records: Vec::new(),
            checksum_calculator,
            header_written: false,
            total_uncompressed_pos: 0,
        })
    }

    /// Writes the XZ stream header exactly once; later calls are no-ops.
    fn write_stream_header(&mut self) -> Result<()> {
        if self.header_written {
            return Ok(());
        }

        write_xz_stream_header(&mut self.inner, self.options.check_type)?;
        self.header_written = true;

        Ok(())
    }

    /// Writes one block header (configured pre-filters plus LZMA2) and
    /// returns its encoded size in bytes.
    ///
    /// NOTE(review): `_block_uncompressed_size` is currently unused — the
    /// optional size fields of the XZ block header are apparently not
    /// emitted; confirm this is intentional.
    fn write_block_header(&mut self, _block_uncompressed_size: u64) -> Result<u64> {
        // LZMA2 must be the last filter in the chain.
        let mut filters = self.options.filters.clone();
        filters.push(FilterConfig {
            filter_type: FilterType::Lzma2,
            property: 0,
        });

        write_xz_block_header(
            &mut self.inner,
            &filters,
            self.options.lzma_options.dict_size,
        )
    }

    /// Dispatches the currently buffered block to the worker pool.
    ///
    /// Ensures the stream header is out first, and drains any finished
    /// results before and after dispatching so output keeps flowing while
    /// workers are busy.
    fn send_work_unit(&mut self) -> Result<()> {
        if self.current_work_unit.is_empty() {
            return Ok(());
        }

        self.write_stream_header()?;

        self.drain_available_results()?;

        // Move the buffered bytes out, leaving an empty buffer for the next
        // block; the Option guards against the pool calling the closure twice.
        let work_data = core::mem::take(&mut self.current_work_unit);
        let mut work_data_opt = Some(work_data);

        self.work_pool.dispatch_next_work(&mut |_seq| {
            let data = work_data_opt.take().ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "work already provided")
            })?;
            Ok(WorkUnit {
                uncompressed_data: data,
                lzma_options: self.options.lzma_options.clone(),
                check_type: self.options.check_type,
            })
        })?;

        self.drain_available_results()?;

        Ok(())
    }

    /// Writes out every finished block the pool has ready, without blocking.
    fn drain_available_results(&mut self) -> Result<()> {
        while let Some(result) = self.work_pool.try_get_result()? {
            self.write_compressed_block(
                result.compressed_data,
                result.checksum,
                result.uncompressed_size,
            )?;
        }
        Ok(())
    }

    /// Emits one finished block: header, compressed payload, padding to a
    /// 4-byte boundary, and checksum — then records it for the stream index.
    fn write_compressed_block(
        &mut self,
        compressed_data: Vec<u8>,
        checksum: Vec<u8>,
        block_uncompressed_size: u64,
    ) -> Result<()> {
        let block_header_size = self.write_block_header(block_uncompressed_size)?;

        let data_size = compressed_data.len() as u64;
        // Block padding aligns the next field to 4 bytes.
        let padding_needed = (4 - (data_size % 4)) % 4;

        self.inner.write_all(&compressed_data)?;

        add_padding(&mut self.inner, padding_needed as usize)?;

        self.inner.write_all(&checksum)?;

        // "Unpadded size" per the XZ spec: header + compressed data +
        // checksum, explicitly excluding the block padding.
        let unpadded_size = block_header_size + data_size + self.options.check_type.checksum_size();
        self.index_records.push(IndexRecord {
            unpadded_size,
            uncompressed_size: block_uncompressed_size,
        });

        self.total_uncompressed_pos += block_uncompressed_size;

        Ok(())
    }

    /// Wraps the writer so the stream is finished automatically on drop,
    /// ignoring any error from `finish`.
    pub fn auto_finish(self) -> AutoFinisher<Self> {
        AutoFinisher(Some(self))
    }

    /// Consumes the writer and returns the underlying sink without finishing
    /// the stream (no index/footer is written).
    pub fn into_inner(self) -> W {
        self.inner
    }

    /// Serializes the accumulated index records as the XZ stream index.
    #[inline(always)]
    fn write_index(&mut self) -> Result<()> {
        write_xz_index(&mut self.inner, &self.index_records)
    }

    /// Writes the XZ stream footer (backward size derived from the index).
    #[inline(always)]
    fn write_stream_footer(&mut self) -> Result<()> {
        write_xz_stream_footer(
            &mut self.inner,
            &self.index_records,
            self.options.check_type,
        )
    }

    /// Finishes the stream: flushes the partial block, waits for all workers,
    /// writes the index and footer, and returns the inner writer.
    pub fn finish(mut self) -> Result<W> {
        self.write_stream_header()?;

        if !self.current_work_unit.is_empty() {
            self.send_work_unit()?;
        }

        // Nothing was ever dispatched: emit a valid empty stream
        // (header + empty index + footer) and return early.
        if self.work_pool.next_index_to_dispatch() == 0 {
            self.write_index()?;
            self.write_stream_footer()?;

            self.inner.flush()?;

            return Ok(self.inner);
        }

        // Tell the pool no further work is coming so workers can wind down.
        self.work_pool.finish();

        // Block on remaining results in order. The dispatch closure should
        // never be invoked here — all work has already been dispatched — so
        // it reports an error if the pool asks for more.
        while let Some(result) = self.work_pool.get_result(|_| {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "no more work to dispatch",
            ))
        })? {
            self.write_compressed_block(
                result.compressed_data,
                result.checksum,
                result.uncompressed_size,
            )?;
        }

        self.write_index()?;
        self.write_stream_footer()?;

        self.inner.flush()?;

        Ok(self.inner)
    }
}
259
/// Body of one compression worker thread.
///
/// Repeatedly steals `(sequence, WorkUnit)` pairs, compresses each block with
/// LZMA2, computes its checksum, and sends the `ResultUnit` back tagged with
/// its sequence number so the writer can emit blocks in order.
///
/// Error protocol: on any compression failure the error is stored in
/// `error_store` and the shared `shutdown_flag` is raised (via `set_error`),
/// then the thread exits. `active_workers` is incremented while a unit is
/// being processed and decremented on every exit path.
fn worker_thread_logic(
    worker_handle: WorkerHandle<(u64, WorkUnit)>,
    result_tx: SyncSender<(u64, ResultUnit)>,
    shutdown_flag: Arc<AtomicBool>,
    error_store: Arc<Mutex<Option<io::Error>>>,
    active_workers: Arc<AtomicU32>,
) {
    while !shutdown_flag.load(Ordering::Acquire) {
        let (index, work_unit) = match worker_handle.steal() {
            Some(work) => {
                // Mark this worker busy for the duration of the unit.
                active_workers.fetch_add(1, Ordering::Release);
                work
            }
            None => {
                // No more work available: the queue is closed.
                break;
            }
        };

        let mut compressed_buffer = Vec::new();
        let uncompressed_size = work_unit.uncompressed_data.len() as u64;

        // The block checksum is computed over the *uncompressed* data.
        let mut checksum_calculator = ChecksumCalculator::new(work_unit.check_type);
        checksum_calculator.update(&work_unit.uncompressed_data);
        let checksum = checksum_calculator.finalize_to_bytes();

        let options = Lzma2Options {
            lzma_options: work_unit.lzma_options,
            ..Default::default()
        };

        let mut writer = Lzma2Writer::new(&mut compressed_buffer, options);
        let result = match writer.write_all(&work_unit.uncompressed_data) {
            Ok(_) => match writer.finish() {
                Ok(_) => ResultUnit {
                    compressed_data: compressed_buffer,
                    checksum,
                    uncompressed_size,
                },
                Err(error) => {
                    // Balance the busy counter before reporting the failure.
                    active_workers.fetch_sub(1, Ordering::Release);
                    set_error(error, &error_store, &shutdown_flag);
                    return;
                }
            },
            Err(error) => {
                active_workers.fetch_sub(1, Ordering::Release);
                set_error(error, &error_store, &shutdown_flag);
                return;
            }
        };

        // A send failure means the receiving side is gone; just exit quietly.
        if result_tx.send((index, result)).is_err() {
            active_workers.fetch_sub(1, Ordering::Release);
            return;
        }

        active_workers.fetch_sub(1, Ordering::Release);
    }
}
321
322impl<W: Write> Write for XzWriterMt<W> {
323 fn write(&mut self, buf: &[u8]) -> Result<usize> {
324 if buf.is_empty() {
325 return Ok(0);
326 }
327
328 let mut total_written = 0;
329 let mut remaining_buf = buf;
330
331 while !remaining_buf.is_empty() {
332 let block_remaining = self.block_size.saturating_sub(self.current_work_unit.len());
333 let to_write = remaining_buf.len().min(block_remaining);
334
335 if to_write > 0 {
336 self.current_work_unit
337 .extend_from_slice(&remaining_buf[..to_write]);
338 total_written += to_write;
339 remaining_buf = &remaining_buf[to_write..];
340 }
341
342 if self.current_work_unit.len() >= self.block_size {
343 self.send_work_unit()?;
344 }
345
346 self.drain_available_results()?;
347 }
348
349 Ok(total_written)
350 }
351
352 fn flush(&mut self) -> Result<()> {
353 if !self.current_work_unit.is_empty() {
354 self.send_work_unit()?;
355 }
356
357 while let Some(result) = self.work_pool.try_get_result()? {
359 self.write_compressed_block(
360 result.compressed_data,
361 result.checksum,
362 result.uncompressed_size,
363 )?;
364 }
365
366 self.inner.flush()
367 }
368}
369
370impl<W: Write> AutoFinish for XzWriterMt<W> {
371 fn finish_ignore_error(self) {
372 let _ = self.finish();
373 }
374}