use crate::compress::parallel::{adjust_compression_level, GzipHeaderInfo};
use crate::infra::scheduler::compress_parallel;
use flate2::{Compress, Compression, FlushCompress, Status};
use std::cell::{RefCell, UnsafeCell};
use std::io::{self, Read, Write};
use std::mem::MaybeUninit;
use std::path::Path;
thread_local! {
    /// Per-thread cached raw-deflate compressor (no zlib wrapper), tagged with the
    /// level it was created for. `compress_block_with_dict` resets and reuses it
    /// across blocks instead of constructing a new `Compress` for every block.
    static PIPELINED_COMPRESS: RefCell<Option<(u32, Compress)>> = const { RefCell::new(None) };
}
/// Block size used for inputs below the large-input threshold in `pipelined_block_size`.
const DEFAULT_BLOCK_SIZE: usize = 128 * 1024;
/// Maximum preset-dictionary length: only the trailing `DICT_SIZE` bytes of the
/// preceding data are passed to `set_dictionary` (the DEFLATE window is 32 KiB, RFC 1951).
const DICT_SIZE: usize = 32 * 1024;
/// Slot holding the CRC-32 state of one input block.
///
/// Workers fill their block's slot from inside the `compress_parallel` callback;
/// the calling thread reads every slot back (via `assume_init_read`) after
/// `compress_parallel` returns, then combines them in block order.
struct CrcSlot(UnsafeCell<MaybeUninit<crc32fast::Hasher>>);
// SAFETY: sharing `&CrcSlot` across threads is sound only if each slot is written
// by at most one thread and every write completes before the slot is read. That
// relies on `compress_parallel` invoking its callback once per distinct block
// index and returning only after all workers have finished.
// NOTE(review): confirm this contract against `crate::infra::scheduler`.
unsafe impl Sync for CrcSlot {}
/// Picks the block size used to split `input_len` bytes of input.
///
/// Inputs of at least 50 MiB get 192 KiB blocks; anything smaller uses
/// `DEFAULT_BLOCK_SIZE`. The thread count and level parameters are currently
/// unused.
#[inline]
fn pipelined_block_size(input_len: usize, _num_threads: usize, _level: u32) -> usize {
    const LARGE_INPUT_THRESHOLD: usize = 50 * 1024 * 1024;
    const LARGE_INPUT_BLOCK_SIZE: usize = 192 * 1024;

    if input_len < LARGE_INPUT_THRESHOLD {
        DEFAULT_BLOCK_SIZE
    } else {
        LARGE_INPUT_BLOCK_SIZE
    }
}
/// Gzip encoder that compresses its input in fixed-size blocks and emits a single
/// gzip member. Non-final blocks end with a sync flush, and blocks can be primed
/// with a preset dictionary taken from the preceding data. With more than one
/// thread, blocks are compressed concurrently via `compress_parallel`.
pub struct PipelinedGzEncoder {
    // Caller-supplied compression level.
    compression_level: u32,
    // Worker thread count; values greater than 1 select the parallel pipeline.
    num_threads: usize,
    // File name, modification time and comment written into the gzip header.
    header_info: GzipHeaderInfo,
}
impl PipelinedGzEncoder {
pub fn new(compression_level: u32, num_threads: usize) -> Self {
Self {
compression_level,
num_threads,
header_info: GzipHeaderInfo::default(),
}
}
pub fn set_header_info(&mut self, info: GzipHeaderInfo) {
self.header_info = info;
}
fn gz_builder(&self) -> flate2::GzBuilder {
let mut builder = flate2::GzBuilder::new();
if let Some(ref name) = self.header_info.filename {
builder = builder.filename(name.as_bytes());
}
builder = builder.mtime(self.header_info.mtime);
if let Some(ref comment) = self.header_info.comment {
builder = builder.comment(comment.as_bytes());
}
builder
}
pub fn compress_buffer<W: Write + Send>(&self, data: &[u8], writer: W) -> io::Result<u64> {
if data.is_empty() {
let encoder = self
.gz_builder()
.write(writer, Compression::new(self.compression_level));
encoder.finish()?;
return Ok(0);
}
if self.num_threads > 1 {
self.compress_parallel_pipeline(data, writer)?;
} else {
self.compress_sequential(data, writer)?;
}
Ok(data.len() as u64)
}
pub fn compress<R: Read, W: Write + Send>(&self, mut reader: R, writer: W) -> io::Result<u64> {
let mut input_data = Vec::new();
let bytes_read = reader.read_to_end(&mut input_data)? as u64;
if input_data.is_empty() {
let encoder = self
.gz_builder()
.write(writer, Compression::new(self.compression_level));
encoder.finish()?;
return Ok(0);
}
if self.num_threads > 1 {
self.compress_parallel_pipeline(&input_data, writer)?;
} else {
self.compress_sequential(&input_data, writer)?;
}
Ok(bytes_read)
}
pub fn compress_file<P: AsRef<Path>, W: Write + Send>(
&self,
path: P,
writer: W,
) -> io::Result<u64> {
use memmap2::Mmap;
use std::fs::File;
let file = File::open(path.as_ref())?;
let file_len = file.metadata()?.len() as usize;
if file_len == 0 {
let encoder = self
.gz_builder()
.write(writer, Compression::new(self.compression_level));
encoder.finish()?;
return Ok(0);
}
let mmap = unsafe { Mmap::map(&file)? };
#[cfg(unix)]
{
let _ = mmap.advise(memmap2::Advice::Sequential);
}
if self.num_threads > 1 {
self.compress_parallel_pipeline(&mmap, writer)?;
} else {
self.compress_sequential(&mmap, writer)?;
}
Ok(file_len as u64)
}
fn compress_parallel_pipeline<W: Write + Send>(
&self,
data: &[u8],
mut writer: W,
) -> io::Result<()> {
let level = adjust_compression_level(self.compression_level);
let data_len = data.len();
let block_size = pipelined_block_size(data_len, self.num_threads, level);
let num_blocks = data_len.div_ceil(block_size);
let mut header = Vec::with_capacity(64);
let mut flags: u8 = 0x00;
if self.header_info.filename.is_some() {
flags |= 0x08;
}
if self.header_info.comment.is_some() {
flags |= 0x10;
}
header.extend_from_slice(&[0x1f, 0x8b, 0x08, flags]);
header.extend_from_slice(&self.header_info.mtime.to_le_bytes());
header.extend_from_slice(&[0x00, 0xff]);
if let Some(ref name) = self.header_info.filename {
header.extend_from_slice(name.as_bytes());
header.push(0);
}
if let Some(ref comment) = self.header_info.comment {
header.extend_from_slice(comment.as_bytes());
header.push(0);
}
writer.write_all(&header)?;
let crc_parts: Vec<CrcSlot> = (0..num_blocks)
.map(|_| CrcSlot(UnsafeCell::new(MaybeUninit::uninit())))
.collect();
let mut writer = compress_parallel(
data,
block_size,
self.num_threads,
writer,
|block_idx, block, dict, is_last, output| {
compress_block_with_dict(block, dict, level, block_size, is_last, output);
let mut hasher = crc32fast::Hasher::new();
hasher.update(block);
unsafe {
*crc_parts[block_idx].0.get() = MaybeUninit::new(hasher);
}
},
)?;
let mut combined_hasher = crc32fast::Hasher::new();
for part in &crc_parts {
let hasher = unsafe { (*part.0.get()).assume_init_read() };
combined_hasher.combine(&hasher);
}
let combined_crc = combined_hasher.finalize();
let isize = (data_len as u32).to_le_bytes();
writer.write_all(&combined_crc.to_le_bytes())?;
writer.write_all(&isize)?;
Ok(())
}
fn compress_sequential<W: Write>(&self, data: &[u8], mut writer: W) -> io::Result<()> {
use crc32fast::Hasher;
let level = adjust_compression_level(self.compression_level);
let mut header = Vec::with_capacity(64);
let mut flags: u8 = 0x00;
if self.header_info.filename.is_some() {
flags |= 0x08;
}
if self.header_info.comment.is_some() {
flags |= 0x10;
}
header.extend_from_slice(&[0x1f, 0x8b, 0x08, flags]);
header.extend_from_slice(&self.header_info.mtime.to_le_bytes());
header.extend_from_slice(&[0x00, 0xff]);
if let Some(ref name) = self.header_info.filename {
header.extend_from_slice(name.as_bytes());
header.push(0);
}
if let Some(ref comment) = self.header_info.comment {
header.extend_from_slice(comment.as_bytes());
header.push(0);
}
writer.write_all(&header)?;
let mut compress = Compress::new(Compression::new(level), false);
let block_size = pipelined_block_size(data.len(), 1, level);
let mut output_buf = vec![0u8; block_size * 2];
let mut crc_hasher = Hasher::new();
let blocks: Vec<&[u8]> = data.chunks(block_size).collect();
for (i, block) in blocks.iter().enumerate() {
crc_hasher.update(block);
if i > 0 {
let prev = blocks[i - 1];
let dict = if prev.len() > DICT_SIZE {
&prev[prev.len() - DICT_SIZE..]
} else {
prev
};
let _ = compress.set_dictionary(dict);
}
let flush = if i == blocks.len() - 1 {
FlushCompress::Finish
} else {
FlushCompress::Sync
};
let mut block_data = *block;
loop {
let before_in = compress.total_in();
let before_out = compress.total_out();
let status = compress.compress(block_data, &mut output_buf, flush)?;
let consumed = (compress.total_in() - before_in) as usize;
let produced = (compress.total_out() - before_out) as usize;
if produced > 0 {
writer.write_all(&output_buf[..produced])?;
}
block_data = &block_data[consumed..];
match status {
Status::Ok if block_data.is_empty() && flush != FlushCompress::Finish => break,
Status::BufError if produced == 0 => break,
Status::StreamEnd => break,
_ => {}
}
}
}
let crc = crc_hasher.finalize();
writer.write_all(&crc.to_le_bytes())?;
writer.write_all(&(data.len() as u32).to_le_bytes())?;
Ok(())
}
}
/// Compresses one block into `output` as a raw-deflate fragment meant to be
/// concatenated with the fragments of the neighbouring blocks.
///
/// Non-final blocks end with a sync flush; the final block (`is_last`) is
/// finished. When `dict` is given, at most its trailing `DICT_SIZE` bytes are
/// installed as a preset dictionary. The compressor is cached per thread in
/// `PIPELINED_COMPRESS` and reused (after `reset`) while `level` is unchanged.
///
/// `output` is cleared first; on return it holds exactly this block's fragment.
///
/// # Panics
///
/// Panics if the compressor reports an error, because the worker callback has
/// no error channel.
fn compress_block_with_dict(
    block: &[u8],
    dict: Option<&[u8]>,
    level: u32,
    block_size: usize,
    is_last: bool,
    output: &mut Vec<u8>,
) {
    PIPELINED_COMPRESS.with(|comp_cell| {
        let mut comp_opt = comp_cell.borrow_mut();
        output.clear();
        // Headroom for incompressible input plus flush markers. `Vec::reserve` is
        // relative to `len()` (0 after `clear`), so request the full amount; it is
        // a no-op when the buffer is already large enough.
        let initial_capacity = block_size + (block_size / 10) + 1024;
        output.reserve(initial_capacity);

        let compress = match comp_opt.as_mut() {
            Some((cached_level, comp)) if *cached_level == level => {
                comp.reset();
                comp
            }
            _ => {
                *comp_opt = Some((level, Compress::new(Compression::new(level), false)));
                &mut comp_opt.as_mut().unwrap().1
            }
        };

        if let Some(d) = dict {
            let dict_slice = &d[d.len().saturating_sub(DICT_SIZE)..];
            // Best-effort: if priming fails, the fragment is still valid, just
            // without back-references into the preceding data.
            let _ = compress.set_dictionary(dict_slice);
        }

        let flush = if is_last {
            FlushCompress::Finish
        } else {
            FlushCompress::Sync
        };

        let mut input = block;
        loop {
            let before_in = compress.total_in();
            let status = compress
                .compress_vec(input, output, flush)
                .expect("compression failed");
            let consumed = (compress.total_in() - before_in) as usize;
            input = &input[consumed..];
            match status {
                // The fragment is complete only when all input is consumed AND the
                // call returned with spare capacity. If `output` was filled exactly,
                // zlib may still hold pending flush bytes that the next `reset()`
                // would silently discard, truncating the stream.
                Status::Ok
                    if input.is_empty()
                        && flush != FlushCompress::Finish
                        && output.len() < output.capacity() =>
                {
                    break
                }
                Status::BufError => {
                    // Out of output space: grow geometrically.
                    let extra = output.capacity().max(1024);
                    output.reserve(extra);
                }
                Status::StreamEnd => break,
                _ => {}
            }
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::read::GzDecoder;
    use std::io::Read;

    /// Decompresses a gzip member with flate2, panicking on malformed input.
    fn gunzip(compressed: &[u8]) -> Vec<u8> {
        let mut decoder = GzDecoder::new(compressed);
        let mut decompressed = Vec::new();
        decoder.read_to_end(&mut decompressed).unwrap();
        decompressed
    }

    #[test]
    fn test_pipelined_compress() {
        // 140 000 bytes: more than one default-size block on the parallel path.
        let data = b"Hello, world! ".repeat(10000);
        let mut output = Vec::new();
        let encoder = PipelinedGzEncoder::new(9, 4);
        encoder
            .compress(std::io::Cursor::new(&data), &mut output)
            .unwrap();
        assert_eq!(gunzip(&output), data);
    }

    #[test]
    fn test_sequential_multi_block_roundtrip() {
        // Spans several blocks so the single-threaded dictionary hand-off and
        // sync-flush boundaries are exercised.
        let data: Vec<u8> = (0..400_000u32).map(|i| ((i * 31) % 251) as u8).collect();
        let mut output = Vec::new();
        let written = PipelinedGzEncoder::new(6, 1)
            .compress_buffer(&data, &mut output)
            .unwrap();
        assert_eq!(written, data.len() as u64);
        assert_eq!(gunzip(&output), data);
    }

    #[test]
    fn test_empty_input_roundtrip() {
        for threads in [1, 4] {
            let mut output = Vec::new();
            let written = PipelinedGzEncoder::new(6, threads)
                .compress_buffer(&[], &mut output)
                .unwrap();
            assert_eq!(written, 0);
            assert!(gunzip(&output).is_empty());
        }
    }

    #[test]
    fn test_pipelined_vs_parallel_size() {
        use crate::compress::parallel::ParallelGzEncoder;
        let data = b"The quick brown fox jumps over the lazy dog. ".repeat(5000);
        let mut pipelined_output = Vec::new();
        let pipelined = PipelinedGzEncoder::new(9, 4);
        pipelined
            .compress(std::io::Cursor::new(&data), &mut pipelined_output)
            .unwrap();
        let mut parallel_output = Vec::new();
        let parallel = ParallelGzEncoder::new(9, 4);
        parallel
            .compress(std::io::Cursor::new(&data), &mut parallel_output)
            .unwrap();
        println!(
            "Pipelined: {} bytes, Parallel: {} bytes",
            pipelined_output.len(),
            parallel_output.len()
        );
        assert!(
            pipelined_output.len() <= parallel_output.len(),
            "Pipelined output should be no larger than parallel output"
        );
    }
}