use std::{
io::{self, Write},
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
mpsc::SyncSender,
Arc, Mutex,
},
};
use super::{
add_padding, write_xz_block_header, write_xz_index, write_xz_stream_footer,
write_xz_stream_header, CheckType, ChecksumCalculator, FilterConfig, FilterType, IndexRecord,
};
use crate::{
enc::{Lzma2Writer, LzmaOptions},
error_invalid_input, set_error,
work_pool::{WorkPool, WorkPoolConfig},
work_queue::WorkerHandle,
AutoFinish, AutoFinisher, Lzma2Options, Result, XzOptions,
};
/// One block of input handed to a compression worker.
#[derive(Debug, Clone)]
struct WorkUnit {
// Raw bytes of the block, taken from the writer's accumulation buffer.
uncompressed_data: Vec<u8>,
// Encoder settings the worker forwards to the LZMA2 writer.
lzma_options: LzmaOptions,
// Selects the check the worker computes over `uncompressed_data`.
check_type: CheckType,
}
/// Output of a compression worker for a single block.
#[derive(Debug)]
struct ResultUnit {
// Bytes the worker's LZMA2 writer produced for the block.
compressed_data: Vec<u8>,
// Check value computed over the block's uncompressed bytes.
checksum: Vec<u8>,
// Length of the block's uncompressed input, in bytes.
uncompressed_size: u64,
}
/// XZ stream writer that hands blocks of input to a pool of worker threads
/// for compression and writes the finished blocks to `inner`.
pub struct XzWriterMt<W: Write> {
// Destination for the stream header, blocks, index and footer.
inner: W,
// Options supplied at construction (filters, check type, LZMA settings).
options: XzOptions,
// Input bytes accumulated for the block currently being filled.
current_work_unit: Vec<u8>,
// Number of buffered bytes at which the current block is dispatched.
block_size: usize,
// Pool that turns `WorkUnit`s into `ResultUnit`s.
work_pool: WorkPool<WorkUnit, ResultUnit>,
// One record per block written; consumed when writing the index and footer.
index_records: Vec<IndexRecord>,
// Created in `new` for the configured check type; not read anywhere in the
// code shown here.
checksum_calculator: ChecksumCalculator,
// Set once the stream header has been written.
header_written: bool,
// Running sum of the uncompressed sizes of all blocks written so far.
total_uncompressed_pos: u64,
}
impl<W: Write> XzWriterMt<W> {
/// Creates a multi-threaded XZ writer that writes to `inner`.
///
/// `num_workers` is forwarded to the work pool configuration. The effective
/// block size is the configured block size, raised to the LZMA dictionary
/// size when that is larger.
///
/// # Errors
///
/// Returns an error built by `error_invalid_input` when more than three
/// pre-filters are configured, when no block size is set, or when the
/// effective block size does not fit into `usize`.
pub fn new(inner: W, options: XzOptions, num_workers: u32) -> Result<Self> {
    if options.filters.len() > 3 {
        return Err(error_invalid_input(
            "XZ allows only at most 3 pre-filters plus LZMA2",
        ));
    }

    // A block is never smaller than the dictionary.
    let requested = options
        .block_size
        .ok_or_else(|| error_invalid_input("block size must be set"))?
        .get()
        .max(options.lzma_options.dict_size as u64);
    let block_size = usize::try_from(requested)
        .map_err(|_| error_invalid_input("block size bigger than usize"))?;

    let checksum_calculator = ChecksumCalculator::new(options.check_type);

    // Cap the up-front allocation at 1 MiB; the buffer grows on demand.
    let current_work_unit = Vec::with_capacity(block_size.min(1024 * 1024));

    // The number of work items is not known in advance, hence `u64::MAX`.
    let work_pool = WorkPool::new(
        WorkPoolConfig::new(num_workers, u64::MAX),
        worker_thread_logic,
    );

    Ok(Self {
        inner,
        options,
        current_work_unit,
        block_size,
        work_pool,
        index_records: Vec::new(),
        checksum_calculator,
        header_written: false,
        total_uncompressed_pos: 0,
    })
}
/// Writes the XZ stream header on the first call; later calls do nothing.
fn write_stream_header(&mut self) -> Result<()> {
    if !self.header_written {
        write_xz_stream_header(&mut self.inner, self.options.check_type)?;
        // Only mark the header as written once the write succeeded.
        self.header_written = true;
    }
    Ok(())
}
/// Writes a block header describing the configured pre-filters followed by
/// LZMA2.
///
/// The uncompressed size argument is currently unused. Returns the value
/// produced by `write_xz_block_header`, which callers treat as the header
/// size in bytes.
fn write_block_header(&mut self, _block_uncompressed_size: u64) -> Result<u64> {
    let lzma2_filter = FilterConfig {
        filter_type: FilterType::Lzma2,
        property: 0,
    };

    // LZMA2 always comes last, after the user-supplied filters.
    let mut filter_chain = self.options.filters.clone();
    filter_chain.push(lzma2_filter);

    let dict_size = self.options.lzma_options.dict_size;
    write_xz_block_header(&mut self.inner, &filter_chain, dict_size)
}
/// Hands the currently buffered bytes to the work pool as one block.
///
/// Does nothing when the buffer is empty. Otherwise the stream header is
/// written if still missing, and results that are already available are
/// written out both before and after the dispatch.
fn send_work_unit(&mut self) -> Result<()> {
    if self.current_work_unit.is_empty() {
        return Ok(());
    }

    self.write_stream_header()?;
    self.drain_available_results()?;

    // The buffer is moved out; `Option` lets the `FnMut` closure give it away
    // exactly once.
    let mut pending_block = Some(core::mem::take(&mut self.current_work_unit));
    self.work_pool.dispatch_next_work(&mut |_seq| {
        let uncompressed_data = pending_block.take().ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidInput, "work already provided")
        })?;
        Ok(WorkUnit {
            uncompressed_data,
            lzma_options: self.options.lzma_options.clone(),
            check_type: self.options.check_type,
        })
    })?;

    self.drain_available_results()
}
/// Writes every result that `try_get_result` currently hands back, stopping
/// at the first `None`.
fn drain_available_results(&mut self) -> Result<()> {
    loop {
        let unit = match self.work_pool.try_get_result()? {
            Some(unit) => unit,
            None => return Ok(()),
        };
        self.write_compressed_block(
            unit.compressed_data,
            unit.checksum,
            unit.uncompressed_size,
        )?;
    }
}
/// Writes one complete block: header, compressed payload, padding to a
/// multiple of four bytes, and the check value.
///
/// Also appends the block's index record and adds its uncompressed size to
/// the running total.
fn write_compressed_block(
    &mut self,
    compressed_data: Vec<u8>,
    checksum: Vec<u8>,
    block_uncompressed_size: u64,
) -> Result<()> {
    let header_len = self.write_block_header(block_uncompressed_size)?;

    let payload_len = compressed_data.len() as u64;
    // Number of bytes needed to round the payload up to a multiple of four.
    let pad_len = match payload_len % 4 {
        0 => 0,
        rem => 4 - rem,
    };

    self.inner.write_all(&compressed_data)?;
    add_padding(&mut self.inner, pad_len as usize)?;
    self.inner.write_all(&checksum)?;

    // The recorded unpadded size leaves the padding bytes out.
    let check_len = self.options.check_type.checksum_size();
    self.index_records.push(IndexRecord {
        unpadded_size: header_len + payload_len + check_len,
        uncompressed_size: block_uncompressed_size,
    });
    self.total_uncompressed_pos += block_uncompressed_size;

    Ok(())
}
/// Wraps this writer in an `AutoFinisher`.
// NOTE(review): presumably the wrapper finishes the stream when dropped —
// confirm against `AutoFinisher`.
pub fn auto_finish(self) -> AutoFinisher<Self> {
AutoFinisher(Some(self))
}
/// Consumes the writer and returns the underlying writer as-is.
///
/// Unlike `finish`, this writes nothing further: no buffered input, index
/// or footer is emitted here.
pub fn into_inner(self) -> W {
self.inner
}
/// Writes the XZ index built from the records of all blocks written so far.
#[inline(always)]
fn write_index(&mut self) -> Result<()> {
write_xz_index(&mut self.inner, &self.index_records)
}
/// Writes the XZ stream footer, passing along the index records and the
/// configured check type.
#[inline(always)]
fn write_stream_footer(&mut self) -> Result<()> {
write_xz_stream_footer(
&mut self.inner,
&self.index_records,
self.options.check_type,
)
}
/// Completes the XZ stream and returns the underlying writer.
///
/// Any buffered input is dispatched as a final block. If at least one block
/// was dispatched, the pool is told that no more work follows and every
/// remaining result is written. The index and stream footer are then
/// written and the underlying writer is flushed.
///
/// # Errors
///
/// Propagates any error from the work pool or from the underlying writer.
pub fn finish(mut self) -> Result<W> {
    self.write_stream_header()?;

    // `send_work_unit` returns immediately when nothing is buffered.
    self.send_work_unit()?;

    // With no block ever dispatched there is nothing to collect from the pool.
    if self.work_pool.next_index_to_dispatch() != 0 {
        self.work_pool.finish();

        loop {
            let pending = self.work_pool.get_result(|_| {
                Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "no more work to dispatch",
                ))
            })?;
            let unit = match pending {
                Some(unit) => unit,
                None => break,
            };
            self.write_compressed_block(
                unit.compressed_data,
                unit.checksum,
                unit.uncompressed_size,
            )?;
        }
    }

    self.write_index()?;
    self.write_stream_footer()?;
    self.inner.flush()?;
    Ok(self.inner)
}
}
/// Body of a compression worker thread.
///
/// Repeatedly steals an `(index, WorkUnit)` pair, computes the check value
/// over the raw bytes, compresses them with an `Lzma2Writer` into a fresh
/// buffer and sends `(index, ResultUnit)` over `result_tx`. `active_workers`
/// is incremented while a job is being processed and decremented afterwards.
///
/// The loop ends when `shutdown_flag` is set, when `steal` returns `None`,
/// when the result channel rejects a send, or after a compression error. A
/// compression error is handed to `set_error` together with `error_store`
/// and `shutdown_flag`.
fn worker_thread_logic(
    worker_handle: WorkerHandle<(u64, WorkUnit)>,
    result_tx: SyncSender<(u64, ResultUnit)>,
    shutdown_flag: Arc<AtomicBool>,
    error_store: Arc<Mutex<Option<io::Error>>>,
    active_workers: Arc<AtomicU32>,
) {
    while !shutdown_flag.load(Ordering::Acquire) {
        let (index, job) = match worker_handle.steal() {
            Some(job) => job,
            None => break,
        };
        active_workers.fetch_add(1, Ordering::Release);

        let mut compressed = Vec::new();
        let uncompressed_size = job.uncompressed_data.len() as u64;

        // The check value covers the uncompressed bytes of the block.
        let mut check = ChecksumCalculator::new(job.check_type);
        check.update(&job.uncompressed_data);
        let checksum = check.finalize_to_bytes();

        let encoder_options = Lzma2Options {
            lzma_options: job.lzma_options,
            ..Default::default()
        };
        let mut encoder = Lzma2Writer::new(&mut compressed, encoder_options);

        if let Err(error) = encoder.write_all(&job.uncompressed_data) {
            active_workers.fetch_sub(1, Ordering::Release);
            set_error(error, &error_store, &shutdown_flag);
            return;
        }
        if let Err(error) = encoder.finish() {
            active_workers.fetch_sub(1, Ordering::Release);
            set_error(error, &error_store, &shutdown_flag);
            return;
        }

        let result = ResultUnit {
            compressed_data: compressed,
            checksum,
            uncompressed_size,
        };

        // The job counts as done whether or not the result could be delivered.
        let delivered = result_tx.send((index, result)).is_ok();
        active_workers.fetch_sub(1, Ordering::Release);
        if !delivered {
            return;
        }
    }
}
impl<W: Write> Write for XzWriterMt<W> {
/// Buffers all of `buf`, dispatching a block each time the buffer reaches
/// the block size, and writes any results that are already available.
///
/// Returns `buf.len()` on success; an empty `buf` yields `Ok(0)` without
/// touching the pool.
fn write(&mut self, buf: &[u8]) -> Result<usize> {
    let mut pending = buf;

    while !pending.is_empty() {
        // Take as much as still fits into the current block.
        let room = self.block_size.saturating_sub(self.current_work_unit.len());
        let (chunk, rest) = pending.split_at(pending.len().min(room));
        self.current_work_unit.extend_from_slice(chunk);
        pending = rest;

        if self.current_work_unit.len() >= self.block_size {
            self.send_work_unit()?;
        }
        self.drain_available_results()?;
    }

    // The loop only ends once every byte of `buf` has been buffered.
    Ok(buf.len())
}
/// Dispatches any buffered input as a block, writes the results that are
/// currently available and flushes the underlying writer.
///
/// NOTE(review): only results handed back by `try_get_result` are written;
/// blocks still being compressed presumably stay pending until a later call
/// or `finish` — confirm against `WorkPool`.
fn flush(&mut self) -> Result<()> {
    // `send_work_unit` returns immediately when nothing is buffered.
    self.send_work_unit()?;
    self.drain_available_results()?;
    self.inner.flush()
}
}
impl<W: Write> AutoFinish for XzWriterMt<W> {
/// Finishes the stream and deliberately discards both the returned writer
/// and any error from `finish`.
fn finish_ignore_error(self) {
let _ = self.finish();
}
}