use crossbeam_channel::{Receiver, Sender, bounded};
use libdeflater::{CompressionLvl, Compressor};
use std::collections::BTreeMap;
use std::io::{self, Write};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
/// Block size, in bytes, used by `ParallelGzipConfig::default()` (64 KiB).
const DEFAULT_BLOCK_SIZE: usize = 64 * 1024;
/// A chunk of raw input handed from the writer to a compression worker.
struct UncompressedBlock {
/// Position of this block in the input stream; assigned consecutively starting at 0.
serial: u64,
/// The raw, not yet compressed bytes of the block.
data: Vec<u8>,
}
/// Outcome of compressing one block, sent from a compression worker to the output thread.
///
/// Both variants carry the serial number of the originating block so the output thread can
/// restore the input order.
enum CompressedBlock {
/// Compression succeeded; `data` holds the gzip-encoded bytes for block `serial`.
Ok { serial: u64, data: Vec<u8> },
/// Compression failed for block `serial`; `error` is the stringified failure.
Err { serial: u64, error: String },
}
#[allow(clippy::needless_pass_by_value)]
fn compress_block(
block: UncompressedBlock,
compressor: &mut Compressor,
) -> io::Result<CompressedBlock> {
let uncompressed = &block.data;
let max_compressed = compressor.gzip_compress_bound(uncompressed.len());
let mut compressed_data = vec![0u8; max_compressed];
let compressed_len = compressor
.gzip_compress(uncompressed, &mut compressed_data)
.map_err(|e| io::Error::other(format!("Gzip compression failed: {e:?}")))?;
compressed_data.truncate(compressed_len);
Ok(CompressedBlock::Ok { serial: block.serial, data: compressed_data })
}
/// Tuning knobs for `ParallelGzipWriter`.
#[derive(Debug, Clone)]
pub struct ParallelGzipConfig {
/// Number of compression worker threads to spawn.
pub compression_threads: usize,
/// Compression level handed to libdeflater's `CompressionLvl::new`; a level it rejects makes
/// `ParallelGzipWriter::new` fail with `InvalidInput`.
pub compression_level: u8,
/// Capacity of the bounded queues between the writer, the workers and the output thread.
/// `None` derives the capacity from `compression_threads` (twice the thread count).
pub queue_size: Option<usize>,
/// Number of buffered input bytes at which a block is handed to the workers; also used as
/// the initial capacity of the writer's buffer.
pub block_size: usize,
}
impl Default for ParallelGzipConfig {
    /// Four compression threads, compression level 6, a queue size derived from the thread
    /// count (`queue_size` is `None`), and blocks of `DEFAULT_BLOCK_SIZE` bytes.
    fn default() -> Self {
        let compression_threads = 4;
        let compression_level = 6;
        let queue_size = None;
        let block_size = DEFAULT_BLOCK_SIZE;
        Self { compression_threads, compression_level, queue_size, block_size }
    }
}
impl ParallelGzipConfig {
    /// Returns the default configuration with `threads` compression workers.
    ///
    /// A request for zero threads is raised to one, since at least one worker is needed to
    /// make progress.
    #[must_use]
    pub fn with_threads(threads: usize) -> Self {
        Self { compression_threads: threads.max(1), ..Default::default() }
    }

    /// Capacity to use for the bounded queues.
    ///
    /// An explicit `queue_size` is honored as given (including `Some(0)`). Otherwise the
    /// capacity is twice the thread count, where the thread count is treated as at least one
    /// so that a struct literal with `compression_threads: 0` does not silently produce a
    /// zero-capacity queue. The multiplication saturates instead of overflowing.
    fn effective_queue_size(&self) -> usize {
        self.queue_size
            .unwrap_or_else(|| self.compression_threads.max(1).saturating_mul(2))
    }
}
/// A `Write` implementation that gzip-compresses its input on background threads.
///
/// Input is collected into blocks, each block is compressed independently by a worker
/// thread, and a dedicated output thread writes the compressed blocks to the underlying
/// writer in their original order. Call `finish` to drain the pipeline and learn about
/// any error that occurred on the background threads.
pub struct ParallelGzipWriter {
/// Input bytes accepted by `write` that have not been dispatched to a worker yet.
block_buffer: Vec<u8>,
/// Buffered byte count at which `write` dispatches the buffer as a block.
block_size: usize,
/// Serial number that will be assigned to the next dispatched block.
next_serial: u64,
/// Sending side of the queue feeding the compression workers.
compress_tx: Sender<UncompressedBlock>,
/// Join handles of the compression worker threads.
compression_handles: Vec<JoinHandle<()>>,
/// Join handle of the output thread; an `Option` so `finish` can take it for joining.
io_handle: Option<JoinHandle<io::Result<()>>>,
}
impl ParallelGzipWriter {
pub fn new<W>(writer: W, config: &ParallelGzipConfig) -> io::Result<Self>
where
W: Write + Send + 'static,
{
let queue_size = config.effective_queue_size();
let compression_level = CompressionLvl::new(i32::from(config.compression_level))
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{e:?}")))?;
let (compress_tx, compress_rx) = bounded::<UncompressedBlock>(queue_size);
let (output_tx, output_rx) = bounded::<CompressedBlock>(queue_size);
let compress_rx = Arc::new(compress_rx);
let output_tx = Arc::new(output_tx);
let mut compression_handles = Vec::with_capacity(config.compression_threads);
for _ in 0..config.compression_threads {
let rx = Arc::clone(&compress_rx);
let tx = Arc::clone(&output_tx);
let level = compression_level;
let handle = thread::spawn(move || {
let mut compressor = Compressor::new(level);
while let Ok(block) = rx.recv() {
let serial = block.serial;
let result = match compress_block(block, &mut compressor) {
Ok(compressed) => compressed,
Err(e) => CompressedBlock::Err { serial, error: e.to_string() },
};
if tx.send(result).is_err() {
break;
}
}
});
compression_handles.push(handle);
}
drop(compress_rx);
drop(output_tx);
let io_handle =
thread::spawn(move || -> io::Result<()> { Self::io_writer_loop(writer, output_rx) });
Ok(Self {
block_buffer: Vec::with_capacity(config.block_size),
block_size: config.block_size,
next_serial: 0,
compress_tx,
compression_handles,
io_handle: Some(io_handle),
})
}
#[allow(clippy::needless_pass_by_value)]
fn io_writer_loop<W: Write>(
mut writer: W,
output_rx: Receiver<CompressedBlock>,
) -> io::Result<()> {
let mut next_expected: u64 = 0;
let mut pending: BTreeMap<u64, CompressedBlock> = BTreeMap::new();
while let Ok(block) = output_rx.recv() {
let serial = match &block {
CompressedBlock::Ok { serial, .. } | CompressedBlock::Err { serial, .. } => *serial,
};
pending.insert(serial, block);
while let Some(block) = pending.remove(&next_expected) {
match block {
CompressedBlock::Ok { data, .. } => writer.write_all(&data)?,
CompressedBlock::Err { error, .. } => {
return Err(io::Error::other(format!(
"compression failed for block {next_expected}: {error}"
)));
}
}
next_expected += 1;
}
}
for (_, block) in pending {
match block {
CompressedBlock::Ok { data, .. } => writer.write_all(&data)?,
CompressedBlock::Err { serial, error } => {
return Err(io::Error::other(format!(
"compression failed for block {serial}: {error}"
)));
}
}
}
writer.flush()?;
Ok(())
}
fn dispatch_block(&mut self) -> io::Result<()> {
if self.block_buffer.is_empty() {
return Ok(());
}
let block = UncompressedBlock {
serial: self.next_serial,
data: std::mem::replace(&mut self.block_buffer, Vec::with_capacity(self.block_size)),
};
self.next_serial += 1;
self.compress_tx
.send(block)
.map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "Compression channel closed"))
}
pub fn finish(mut self) -> io::Result<()> {
self.dispatch_block()?;
drop(self.compress_tx);
for handle in self.compression_handles {
handle.join().map_err(|_| io::Error::other("Compression worker thread panicked"))?;
}
if let Some(handle) = self.io_handle.take() {
handle.join().map_err(|_| io::Error::other("I/O thread panicked"))??;
}
Ok(())
}
}
impl Write for ParallelGzipWriter {
    /// Buffers `buf` and dispatches a block to the compression workers every time
    /// `block_size` bytes have accumulated.
    ///
    /// Input larger than the remaining room in the current block is split, so a single large
    /// write is spread over several blocks (and therefore several workers) instead of
    /// becoming one oversized block. A `block_size` of zero disables splitting: each call is
    /// dispatched as one block.
    ///
    /// Always consumes all of `buf` on success.
    ///
    /// # Errors
    ///
    /// Propagates the error from dispatching a block, i.e. the compression pipeline has shut
    /// down. The underlying cause is reported by `finish`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if self.block_size == 0 {
            self.block_buffer.extend_from_slice(buf);
            self.dispatch_block()?;
            return Ok(buf.len());
        }
        let mut rest = buf;
        while !rest.is_empty() {
            let room = self.block_size.saturating_sub(self.block_buffer.len());
            let (head, tail) = rest.split_at(room.min(rest.len()));
            self.block_buffer.extend_from_slice(head);
            rest = tail;
            if self.block_buffer.len() >= self.block_size {
                self.dispatch_block()?;
            }
        }
        Ok(buf.len())
    }

    /// Dispatches the partially filled buffer as a block of its own.
    ///
    /// This only queues the data for compression; it does not wait until the bytes have
    /// reached the underlying writer.
    fn flush(&mut self) -> io::Result<()> {
        self.dispatch_block()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::read::MultiGzDecoder;
    use std::fs::File;
    use std::io::Read;
    use tempfile::NamedTempFile;

    /// Compresses `payload` into a temporary file and returns what a multi-member gzip
    /// decoder reads back from it.
    ///
    /// The payload is fed to the writer in pieces of at most `write_len` bytes so a test can
    /// control how writes line up with the configured block size. `write_len` must be
    /// non-zero.
    fn round_trip(
        config: &ParallelGzipConfig,
        payload: &[u8],
        write_len: usize,
    ) -> io::Result<String> {
        let temp = NamedTempFile::new()?;
        let path = temp.path().to_path_buf();
        {
            let file = File::create(&path)?;
            let mut writer = ParallelGzipWriter::new(file, config)?;
            for piece in payload.chunks(write_len) {
                writer.write_all(piece)?;
            }
            writer.finish()?;
        }
        let mut decompressed = String::new();
        MultiGzDecoder::new(File::open(&path)?).read_to_string(&mut decompressed)?;
        Ok(decompressed)
    }

    #[test]
    fn test_basic_compression() -> io::Result<()> {
        let payload = b"Hello, World!";
        let config = ParallelGzipConfig::with_threads(2);
        assert_eq!(round_trip(&config, payload, payload.len())?, "Hello, World!");
        Ok(())
    }

    #[test]
    fn test_multi_block_compression() -> io::Result<()> {
        let test_data = "ACGT".repeat(100_000);
        let config = ParallelGzipConfig {
            compression_threads: 4,
            block_size: 16384,
            ..Default::default()
        };
        // Write in pieces smaller than the block size so the data really is spread over
        // many blocks and the output thread has to restore their order.
        assert_eq!(round_trip(&config, test_data.as_bytes(), 4096)?, test_data);
        Ok(())
    }

    #[test]
    fn test_single_thread() -> io::Result<()> {
        let payload = b"Single thread test";
        let config = ParallelGzipConfig::with_threads(1);
        assert_eq!(round_trip(&config, payload, payload.len())?, "Single thread test");
        Ok(())
    }
}