use std::collections::BTreeMap;
use std::io::{self, Write};
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::thread::JoinHandle;
use flate2::write::GzEncoder;
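/// Test-only hooks for injecting a panic into the compression worker pool
/// after a configurable number of chunk operations.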
#[cfg(feature = "test-hooks")]
pub mod test_hooks {
use std::sync::atomic::{AtomicU64, Ordering};
pub static PANIC_AT_POOL_OP_COUNT: AtomicU64 = AtomicU64::new(u64::MAX);
pub static POOL_OP_COUNT: AtomicU64 = AtomicU64::new(0);
pub fn reset() {
PANIC_AT_POOL_OP_COUNT.store(u64::MAX, Ordering::Relaxed);
POOL_OP_COUNT.store(0, Ordering::Relaxed);
}
}
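/// Default size of each uncompressed chunk handed to a worker thread (2 MiB).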
pub const DEFAULT_CHUNK_SIZE: usize = 2 * 1024 * 1024;
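/// Bounded-channel capacity multiplier: each worker gets this many in-flight
/// slots on both the raw and compressed queues.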
const QUEUE_SLOTS_PER_WORKER: usize = 2;
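/// Compression level applied to every chunk; `fast` trades ratio for throughput.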
fn compression_level() -> flate2::Compression {
flate2::Compression::fast()
}
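/// Multi-threaded gzip writer: buffers incoming bytes into fixed-size chunks,
/// compresses each chunk on a worker thread as an independent gzip member, and
/// lets a dedicated writer thread reassemble the members in sequence order.
/// The concatenated output is a valid multi-member gzip stream (decodable with
/// e.g. `MultiGzDecoder`).
///
/// A minimal usage sketch (marked `ignore`; error handling and the enclosing
/// function are omitted):
///
/// ```ignore
/// use std::io::Write;
///
/// let file = std::fs::File::create("out.gz")?;
/// let mut gz = ParallelGzipWriter::new(file, DEFAULT_CHUNK_SIZE, 4);
/// gz.write_all(b"some payload")?;
/// let file = gz.finish()?; // returns the inner writer once all chunks are written
/// ```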
pub struct ParallelGzipWriter<W: Write + Send + 'static> {
chunk_size: usize,
current: Vec<u8>,
next_seq: u64,
raw_tx: Option<SyncSender<(u64, Vec<u8>)>>,
writer_handle: Option<JoinHandle<io::Result<W>>>,
worker_handles: Vec<JoinHandle<()>>,
}
impl<W: Write + Send + 'static> ParallelGzipWriter<W> {
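/// Creates a writer backed by `worker_count` compressor threads (at least one)
/// plus one writer thread that restores chunk order before writing to `inner`.
/// `chunk_size` is the uncompressed size of each gzip member.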
pub fn new(inner: W, chunk_size: usize, worker_count: usize) -> Self {
let worker_count = worker_count.max(1);
let raw_cap = worker_count * QUEUE_SLOTS_PER_WORKER;
let compressed_cap = worker_count * QUEUE_SLOTS_PER_WORKER;
let (raw_tx, raw_rx) = sync_channel::<(u64, Vec<u8>)>(raw_cap);
let (compressed_tx, compressed_rx) = sync_channel::<(u64, Vec<u8>)>(compressed_cap);
let raw_rx = std::sync::Arc::new(std::sync::Mutex::new(raw_rx));
let mut worker_handles = Vec::with_capacity(worker_count);
for _ in 0..worker_count {
let raw_rx = std::sync::Arc::clone(&raw_rx);
let compressed_tx = compressed_tx.clone();
worker_handles.push(std::thread::spawn(move || {
worker_loop(&raw_rx, &compressed_tx);
}));
}
drop(compressed_tx);
let writer_handle = std::thread::spawn(move || writer_loop(inner, compressed_rx));
Self {
chunk_size,
current: Vec::with_capacity(chunk_size),
next_seq: 0,
raw_tx: Some(raw_tx),
writer_handle: Some(writer_handle),
worker_handles,
}
}
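/// Hands the buffered chunk to the worker pool, tagged with the next sequence
/// number. A no-op when nothing is buffered.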
fn flush_current(&mut self) -> io::Result<()> {
if self.current.is_empty() {
return Ok(());
}
let chunk = std::mem::replace(&mut self.current, Vec::with_capacity(self.chunk_size));
let seq = self.next_seq;
self.next_seq += 1;
let tx = self
.raw_tx
.as_ref()
.ok_or_else(|| io::Error::other("ParallelGzipWriter: already finished"))?;
tx.send((seq, chunk))
.map_err(|_| io::Error::other("ParallelGzipWriter: worker pool dropped"))?;
Ok(())
}
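/// Flushes any partially filled chunk, shuts down the worker pool and the
/// writer thread, and returns the inner writer once every chunk has been written.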
pub fn finish(mut self) -> io::Result<W> {
self.flush_current()?;
self.raw_tx = None;
for h in std::mem::take(&mut self.worker_handles) {
h.join()
.map_err(|_| io::Error::other("parallel gzip worker panicked"))?;
}
let handle = self
.writer_handle
.take()
.ok_or_else(|| io::Error::other("ParallelGzipWriter: writer handle missing"))?;
handle
.join()
.map_err(|_| io::Error::other("parallel gzip writer thread panicked"))?
}
}
impl<W: Write + Send + 'static> Write for ParallelGzipWriter<W> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut offset = 0;
while offset < buf.len() {
let space = self.chunk_size - self.current.len();
let take = space.min(buf.len() - offset);
self.current.extend_from_slice(&buf[offset..offset + take]);
offset += take;
if self.current.len() >= self.chunk_size {
self.flush_current()?;
}
}
Ok(buf.len())
}
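// Note: bytes in a partially filled chunk are not compressed here; they are
// only emitted once the chunk reaches `chunk_size` or on `finish()`.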
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
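// Best-effort shutdown: errors and worker panics are ignored because `Drop`
// cannot report them. Call `finish()` to observe failures and recover the
// inner writer.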
impl<W: Write + Send + 'static> Drop for ParallelGzipWriter<W> {
fn drop(&mut self) {
drop(self.flush_current());
self.raw_tx = None;
for h in std::mem::take(&mut self.worker_handles) {
drop(h.join());
}
if let Some(h) = self.writer_handle.take() {
drop(h.join());
}
}
}
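// Worker loop: repeatedly takes the next `(sequence, raw chunk)` pair off the
// shared receiver, compresses it into an independent gzip member, and forwards
// it (still tagged with its sequence number) to the writer thread. Returns when
// the raw channel closes or a compress/send step fails.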
fn worker_loop(
raw_rx: &std::sync::Mutex<Receiver<(u64, Vec<u8>)>>,
compressed_tx: &SyncSender<(u64, Vec<u8>)>,
) {
loop {
let item = {
let guard = match raw_rx.lock() {
Ok(g) => g,
Err(_) => return,
};
guard.recv()
};
let (seq, raw) = match item {
Ok(v) => v,
Err(_) => return,
};
#[cfg(feature = "test-hooks")]
{
use std::sync::atomic::Ordering;
let n = test_hooks::POOL_OP_COUNT.fetch_add(1, Ordering::Relaxed) + 1;
if n == test_hooks::PANIC_AT_POOL_OP_COUNT.load(Ordering::Relaxed) {
panic!("test-hooks: parallel_gzip pool op #{n} (seq {seq}) panicking");
}
}
let compressed = match compress_one(&raw) {
Ok(v) => v,
Err(_) => return,
};
if compressed_tx.send((seq, compressed)).is_err() {
return;
}
}
}
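// Compresses one chunk into a complete, self-contained gzip member;
// concatenating members yields a valid multi-member gzip stream.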
fn compress_one(raw: &[u8]) -> io::Result<Vec<u8>> {
let mut enc = GzEncoder::new(Vec::with_capacity(raw.len() / 2), compression_level());
enc.write_all(raw)?;
enc.finish()
}
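// Writer loop: receives compressed members in whatever order the workers finish
// them, buffers out-of-order members in a `BTreeMap`, and writes them to `inner`
// strictly in sequence order. If sequence numbers are still missing when the
// channel closes (e.g. a worker panicked), an error is returned instead of
// emitting a corrupt stream.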
#[allow(clippy::needless_pass_by_value)]
fn writer_loop<W: Write>(
mut inner: W,
compressed_rx: Receiver<(u64, Vec<u8>)>,
) -> io::Result<W> {
let mut pending: BTreeMap<u64, Vec<u8>> = BTreeMap::new();
let mut expected: u64 = 0;
while let Ok((seq, bytes)) = compressed_rx.recv() {
pending.insert(seq, bytes);
while let Some(b) = pending.remove(&expected) {
inner.write_all(&b)?;
expected += 1;
}
}
if !pending.is_empty() {
return Err(io::Error::other(format!(
"parallel gzip writer: {} chunks missing at seq {expected}",
pending.len()
)));
}
inner.flush()?;
Ok(inner)
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use flate2::read::MultiGzDecoder;
use std::io::Read;
fn roundtrip(bytes: &[u8], chunk_size: usize, workers: usize) -> Vec<u8> {
let sink: Vec<u8> = Vec::new();
let mut gz = ParallelGzipWriter::new(sink, chunk_size, workers);
gz.write_all(bytes).unwrap();
let compressed = gz.finish().unwrap();
let mut out = Vec::new();
MultiGzDecoder::new(compressed.as_slice()).read_to_end(&mut out).unwrap();
out
}
#[test]
fn empty_stream_produces_empty_output() {
let sink: Vec<u8> = Vec::new();
let mut gz = ParallelGzipWriter::new(sink, 1024, 2);
gz.write_all(&[]).unwrap();
let compressed = gz.finish().unwrap();
assert!(compressed.is_empty(), "empty input should produce no output");
}
#[test]
fn small_stream_single_chunk() {
let payload = b"hello world!".repeat(32);
let out = roundtrip(&payload, 4096, 2);
assert_eq!(out, payload);
}
#[test]
fn stream_spans_multiple_chunks() {
#[allow(clippy::cast_possible_truncation)]
let payload: Vec<u8> = (0..1_000_000u32).map(|i| (i % 251) as u8).collect();
let out = roundtrip(&payload, 64 * 1024, 4);
assert_eq!(out, payload);
}
#[test]
fn concatenated_members_decode_as_one() {
let sink: Vec<u8> = Vec::new();
let mut gz = ParallelGzipWriter::new(sink, 1024, 2);
for i in 0..5u8 {
let block = [i; 1024];
gz.write_all(&block).unwrap();
}
let compressed = gz.finish().unwrap();
let mut out = Vec::new();
MultiGzDecoder::new(compressed.as_slice()).read_to_end(&mut out).unwrap();
assert_eq!(out.len(), 5 * 1024);
for (i, chunk) in out.chunks(1024).enumerate() {
#[allow(clippy::cast_possible_truncation)]
let want = i as u8;
assert!(chunk.iter().all(|&b| b == want), "chunk {i} not all {want}");
}
}
#[test]
fn finish_returns_inner_writer() {
let sink: Vec<u8> = Vec::new();
let mut gz = ParallelGzipWriter::new(sink, 4096, 2);
gz.write_all(b"abc").unwrap();
gz.write_all(b"def").unwrap();
let bytes = gz.finish().unwrap();
let mut out = Vec::new();
MultiGzDecoder::new(bytes.as_slice()).read_to_end(&mut out).unwrap();
assert_eq!(out, b"abcdef");
}
#[test]
fn write_at_exact_chunk_boundary() {
let sink: Vec<u8> = Vec::new();
let mut gz = ParallelGzipWriter::new(sink, 128, 2);
gz.write_all(&[7u8; 128]).unwrap();
gz.write_all(&[9u8; 128]).unwrap();
gz.write_all(&[11u8; 64]).unwrap();
let bytes = gz.finish().unwrap();
let mut out = Vec::new();
MultiGzDecoder::new(bytes.as_slice()).read_to_end(&mut out).unwrap();
assert_eq!(out.len(), 128 + 128 + 64);
assert!(out[0..128].iter().all(|&b| b == 7));
assert!(out[128..256].iter().all(|&b| b == 9));
assert!(out[256..320].iter().all(|&b| b == 11));
}
#[test]
fn many_workers_preserve_order() {
let mut payload = Vec::with_capacity(200_000);
for i in 0..200_000u32 {
payload.extend_from_slice(&i.to_le_bytes());
}
let out = roundtrip(&payload, 1024, 16);
assert_eq!(out, payload);
}
}