use super::types::{
CompressionStreamResult, TransportCompressionConfig, TransportCompressionError,
TransportCompressionStats,
};
use crate::compression::{CompressedPayload, CompressionAlgorithm};
use anyhow::Result;
use std::sync::{Arc, Mutex};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::warn;
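/// Async compression front-end for transport payloads.
///
/// Wraps the `CompressedPayload` codecs behind a `TransportCompressionConfig`
/// (size thresholds, preferred algorithm, optional auto-selection) and records
/// every outcome in a `TransportCompressionStats` shared via `Arc<Mutex<_>>`,
/// so compressors created with [`AsyncCompressor::with_stats`] can report into
/// the same counters.
///
/// Minimal usage sketch (marked `ignore` because the import path depends on
/// the crate layout):
///
/// ```ignore
/// let compressor = AsyncCompressor::new(TransportCompressionConfig::default());
/// let payload = compressor.compress(b"some payload bytes").await?;
/// let restored = compressor.decompress(&payload).await?;
/// assert_eq!(restored, b"some payload bytes");
/// ```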
#[derive(Debug)]
pub struct AsyncCompressor {
config: TransportCompressionConfig,
stats: Arc<Mutex<TransportCompressionStats>>,
}
impl AsyncCompressor {
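    /// Creates a compressor with the given configuration and a fresh set of statistics.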
pub fn new(config: TransportCompressionConfig) -> Self {
Self {
config,
stats: Arc::new(Mutex::new(TransportCompressionStats::new())),
}
}
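    /// Creates a compressor that records into an existing, shared statistics handle.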
pub fn with_stats(
config: TransportCompressionConfig,
stats: Arc<Mutex<TransportCompressionStats>>,
) -> Self {
Self { config, stats }
}
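    /// Compresses `data` according to the configured policy.
    ///
    /// Payloads smaller than `compression_threshold` are passed through
    /// uncompressed and counted as skipped; inputs larger than
    /// `max_compressed_size` are rejected. The algorithm comes from
    /// `select_algorithm` when `auto_algorithm_selection` is enabled, otherwise
    /// from `preferred_algorithm`. A ratio above `warning_ratio_threshold` is
    /// logged, and a ratio above `min_acceptable_ratio` may trigger an LZ4
    /// fallback.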
pub async fn compress(&self, data: &[u8]) -> Result<CompressedPayload> {
let start = std::time::Instant::now();
if data.len() < self.config.compression_threshold {
let mut stats = self.stats.lock().map_err(|e| {
TransportCompressionError::StreamFailed(format!(
"Failed to acquire stats lock: {}",
e
))
})?;
stats.base.record_skipped();
return Ok(CompressedPayload {
original_size: data.len(),
compressed_size: data.len(),
compression_ratio: 1.0,
data: data.to_vec(),
algorithm: CompressionAlgorithm::None,
});
}
if data.len() > self.config.max_compressed_size {
return Err(TransportCompressionError::ThresholdExceeded(data.len()).into());
}
let algorithm = if self.config.auto_algorithm_selection {
self.select_algorithm(data)?
} else {
self.config.preferred_algorithm
};
        let mut payload = match algorithm {
            CompressionAlgorithm::None => CompressedPayload {
                original_size: data.len(),
                compressed_size: data.len(),
                compression_ratio: 1.0,
                data: data.to_vec(),
                algorithm: CompressionAlgorithm::None,
            },
            CompressionAlgorithm::Lz4 => CompressedPayload::compress_lz4(data)?,
            CompressionAlgorithm::Zstd => CompressedPayload::compress_zstd(data)?,
            CompressionAlgorithm::Gzip => CompressedPayload::compress_gzip(data)?,
        };
        // If the selected algorithm compressed poorly, try LZ4 as a cheaper fallback and
        // keep whichever payload achieved the better (lower) ratio, so the statistics
        // recorded below describe the payload that is actually returned.
        let mut used_fallback = false;
        if payload.compression_ratio > self.config.min_acceptable_ratio
            && algorithm != CompressionAlgorithm::Lz4
        {
            if let Ok(lz4_payload) = CompressedPayload::compress_lz4(data) {
                if lz4_payload.compression_ratio < payload.compression_ratio {
                    used_fallback = true;
                    payload = lz4_payload;
                }
            }
        }
        let elapsed = start.elapsed().as_micros() as u64;
        let mut stats = self.stats.lock().map_err(|e| {
            TransportCompressionError::StreamFailed(format!("Failed to acquire stats lock: {}", e))
        })?;
        stats
            .base
            .record_compression(payload.original_size, payload.compressed_size, elapsed);
        if used_fallback {
            stats.record_algorithm_fallback();
        }
        if payload.compression_ratio > self.config.warning_ratio_threshold {
            stats.record_warning_threshold();
            warn!(
                "Poor compression ratio: {:.2}%",
                payload.compression_ratio * 100.0
            );
        }
        stats.record_compression_time(elapsed);
        Ok(payload)
}
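    /// Decompresses a payload and records the time taken in the shared statistics.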
pub async fn decompress(&self, payload: &CompressedPayload) -> Result<Vec<u8>> {
let start = std::time::Instant::now();
let result = payload.decompress();
let elapsed = start.elapsed().as_micros() as u64;
let mut stats = self.stats.lock().map_err(|e| {
TransportCompressionError::StreamFailed(format!("Failed to acquire stats lock: {}", e))
})?;
stats.base.record_decompression(elapsed);
stats.record_decompression_time(elapsed);
result.map_err(|e| anyhow::anyhow!(e))
}
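    /// Compresses everything readable from `reader` and writes a framed result to `writer`.
    ///
    /// The frame is a 17-byte header (see `create_stream_header`) followed by the
    /// compressed bytes. Note that the current implementation buffers the whole
    /// input in memory before compressing rather than compressing incrementally.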
pub async fn compress_stream<R, W>(
&self,
reader: &mut R,
writer: &mut W,
) -> Result<CompressionStreamResult>
where
R: AsyncReadExt + Unpin,
W: AsyncWriteExt + Unpin,
{
let start = std::time::Instant::now();
let mut content = Vec::new();
reader
.read_to_end(&mut content)
.await
.map_err(|e| TransportCompressionError::StreamFailed(e.to_string()))?;
let compressed = self.compress(&content).await?;
let header = Self::create_stream_header(&compressed);
writer
.write_all(&header)
.await
.map_err(|e| TransportCompressionError::StreamFailed(e.to_string()))?;
writer
.write_all(&compressed.data)
.await
.map_err(|e| TransportCompressionError::StreamFailed(e.to_string()))?;
writer
.flush()
.await
.map_err(|e| TransportCompressionError::StreamFailed(e.to_string()))?;
let total_read = content.len();
let total_written = header.len() + compressed.data.len();
let elapsed = start.elapsed().as_micros() as u64;
let mut stats = self.stats.lock().map_err(|e| {
TransportCompressionError::StreamFailed(format!("Failed to acquire stats lock: {}", e))
})?;
stats.record_streaming_compression(total_read, total_written);
Ok(CompressionStreamResult {
original_size: total_read,
compressed_size: total_written,
algorithm: compressed.algorithm,
compression_ratio: compressed.compression_ratio,
time_us: elapsed,
})
}
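    /// Reads one frame produced by `compress_stream` from `reader`, decompresses it,
    /// and writes the original bytes to `writer`, returning the decompressed length.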
pub async fn decompress_stream<R, W>(&self, reader: &mut R, writer: &mut W) -> Result<usize>
where
R: AsyncReadExt + Unpin,
W: AsyncWriteExt + Unpin,
{
let start = std::time::Instant::now();
let mut header_buf = [0u8; 17];
reader
.read_exact(&mut header_buf)
.await
.map_err(|_e| TransportCompressionError::InvalidHeader)?;
let (original_size, compressed_size, algorithm) = Self::parse_stream_header(&header_buf)?;
        // Guard against absurd sizes from a corrupt or hostile header before allocating.
        if compressed_size > self.config.max_compressed_size {
            return Err(TransportCompressionError::ThresholdExceeded(compressed_size).into());
        }
        let mut compressed_data = Vec::with_capacity(compressed_size);
        // read_to_end consumes the remainder of the stream; the length is validated below.
reader
.read_to_end(&mut compressed_data)
.await
.map_err(|e| TransportCompressionError::StreamFailed(e.to_string()))?;
if compressed_data.len() != compressed_size {
return Err(TransportCompressionError::SizeMismatch {
expected: compressed_size,
actual: compressed_data.len(),
}
.into());
}
let payload = CompressedPayload {
original_size,
compressed_size,
compression_ratio: compressed_size as f64 / original_size as f64,
data: compressed_data,
algorithm,
};
let decompressed = self.decompress(&payload).await?;
writer
.write_all(&decompressed)
.await
.map_err(|e| TransportCompressionError::StreamFailed(e.to_string()))?;
writer
.flush()
.await
.map_err(|e| TransportCompressionError::StreamFailed(e.to_string()))?;
let elapsed = start.elapsed().as_micros() as u64;
let mut stats = self.stats.lock().map_err(|e| {
TransportCompressionError::StreamFailed(format!("Failed to acquire stats lock: {}", e))
})?;
stats.record_streaming_decompression();
stats.record_decompression_time(elapsed);
Ok(decompressed.len())
}
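    /// Picks an algorithm heuristically: payloads below the compression threshold
    /// stay uncompressed, payloads over 1 MiB prefer LZ4 for speed when that
    /// feature is enabled, and everything else prefers Zstd, then LZ4, then Gzip,
    /// depending on which features are compiled in.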
fn select_algorithm(&self, data: &[u8]) -> Result<CompressionAlgorithm> {
if data.len() < self.config.compression_threshold {
return Ok(CompressionAlgorithm::None);
}
if data.len() > 1024 * 1024 {
#[cfg(feature = "compression-lz4")]
return Ok(CompressionAlgorithm::Lz4);
#[cfg(not(feature = "compression-lz4"))]
return Ok(CompressionAlgorithm::Gzip);
}
#[cfg(feature = "compression-zstd")]
return Ok(CompressionAlgorithm::Zstd);
#[cfg(not(feature = "compression-zstd"))]
#[cfg(feature = "compression-lz4")]
return Ok(CompressionAlgorithm::Lz4);
#[cfg(not(feature = "compression-zstd"))]
#[cfg(not(feature = "compression-lz4"))]
return Ok(CompressionAlgorithm::Gzip);
}
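    /// Builds the 17-byte stream header: original size (u64 LE), compressed size
    /// (u64 LE), then a one-byte algorithm tag.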
pub fn create_stream_header(payload: &CompressedPayload) -> [u8; 17] {
let mut header = [0u8; 17];
header[0..8].copy_from_slice(&(payload.original_size as u64).to_le_bytes());
header[8..16].copy_from_slice(&(payload.compressed_size as u64).to_le_bytes());
header[16] = match payload.algorithm {
CompressionAlgorithm::None => 0,
CompressionAlgorithm::Lz4 => 1,
CompressionAlgorithm::Zstd => 2,
CompressionAlgorithm::Gzip => 3,
};
header
}
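    /// Parses a header produced by `create_stream_header`; unknown algorithm tags
    /// are rejected as `InvalidHeader`.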
fn parse_stream_header(header: &[u8; 17]) -> Result<(usize, usize, CompressionAlgorithm)> {
let original_size = u64::from_le_bytes(
header[0..8]
.try_into()
.expect("Header slice must be exactly 8 bytes"),
) as usize;
let compressed_size = u64::from_le_bytes(
header[8..16]
.try_into()
.expect("Header slice must be exactly 8 bytes"),
) as usize;
let algorithm = match header[16] {
0 => CompressionAlgorithm::None,
1 => CompressionAlgorithm::Lz4,
2 => CompressionAlgorithm::Zstd,
3 => CompressionAlgorithm::Gzip,
_ => {
return Err(TransportCompressionError::InvalidHeader.into());
}
};
Ok((original_size, compressed_size, algorithm))
}
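    /// Returns a snapshot of the current statistics. If the lock is poisoned, the
    /// error is logged and empty statistics are returned instead.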
pub fn stats(&self) -> TransportCompressionStats {
self.stats
.lock()
.map(|guard| guard.clone())
.unwrap_or_else(|e| {
tracing::error!("Failed to acquire stats lock for reading: {}", e);
TransportCompressionStats::new()
})
}
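    /// Resets the shared statistics to their initial state, logging if the lock is poisoned.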
pub fn reset_stats(&self) {
if let Ok(mut stats) = self.stats.lock() {
*stats = TransportCompressionStats::new();
} else {
tracing::error!("Failed to acquire stats lock for reset");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
fn test_config_no_compression() -> TransportCompressionConfig {
TransportCompressionConfig {
compression_threshold: 1024,
auto_algorithm_selection: false,
preferred_algorithm: CompressionAlgorithm::None,
..Default::default()
}
}
#[tokio::test]
async fn compress_decompress_roundtrip_without_compression() {
let compressor = AsyncCompressor::new(test_config_no_compression());
let data = b"hello-world";
let payload = compressor.compress(data).await.expect("compress succeeds");
assert_eq!(payload.algorithm, CompressionAlgorithm::None);
let decompressed = compressor
.decompress(&payload)
.await
.expect("decompress succeeds");
assert_eq!(decompressed, data);
}
#[test]
fn create_and_parse_stream_header() {
let payload = CompressedPayload {
original_size: 100,
compressed_size: 40,
compression_ratio: 0.4,
data: vec![0; 40],
algorithm: CompressionAlgorithm::Zstd,
};
let header = AsyncCompressor::create_stream_header(&payload);
let (original, compressed, algorithm) =
AsyncCompressor::parse_stream_header(&header).expect("header parses");
assert_eq!(original, payload.original_size);
assert_eq!(compressed, payload.compressed_size);
assert_eq!(algorithm, CompressionAlgorithm::Zstd);
}
#[tokio::test]
async fn stream_roundtrip_without_compression() {
let compressor = AsyncCompressor::new(test_config_no_compression());
let input = b"streamed-payload";
let (mut input_reader, mut input_writer) = tokio::io::duplex(2048);
input_writer.write_all(input).await.expect("write input");
input_writer.shutdown().await.expect("close input");
let (mut compressed_reader, mut compressed_writer) = tokio::io::duplex(4096);
let result = compressor
.compress_stream(&mut input_reader, &mut compressed_writer)
.await
.expect("compress stream");
compressed_writer
.shutdown()
.await
.expect("close compressed");
let mut compressed_bytes = Vec::new();
compressed_reader
.read_to_end(&mut compressed_bytes)
.await
.expect("read compressed");
let (mut compressed_reader, mut compressed_writer) = tokio::io::duplex(4096);
compressed_writer
.write_all(&compressed_bytes)
.await
.expect("write compressed");
compressed_writer
.shutdown()
.await
.expect("close compressed input");
let (mut output_reader, mut output_writer) = tokio::io::duplex(4096);
let decompressed_size = compressor
.decompress_stream(&mut compressed_reader, &mut output_writer)
.await
.expect("decompress stream");
output_writer.shutdown().await.expect("close output");
let mut output = Vec::new();
output_reader
.read_to_end(&mut output)
.await
.expect("read output");
assert_eq!(decompressed_size, input.len());
assert_eq!(output, input);
assert_eq!(result.original_size, input.len());
}
}