use super::{Transport, TransportMetadata, TransportResponse};
use crate::compression::CompressionAlgorithm;
use crate::transport::compression::{
TransportCompressionConfig, TransportCompressionError, TransportCompressionStats,
};
#[path = "decompression.rs"]
mod decompression;
#[path = "wrapper_helpers.rs"]
mod helpers;
use async_trait::async_trait;
use helpers::CompressionResult;
use std::fmt::Debug;
use std::sync::{Arc, Mutex};
use tracing::{debug, trace, warn};
/// Transport decorator that compresses outgoing payloads before delegating to an inner
/// [`Transport`], and decodes responses whose `Content-Encoding` header names a supported
/// codec.
#[derive(Debug)]
pub struct CompressedTransport {
    /// Wrapped transport that receives the (possibly compressed) payload bytes.
    inner: Box<dyn Transport>,
    /// Compression settings: `compression_threshold` seeds `min_compress_size`, and
    /// `warning_ratio_threshold` controls when a poor-ratio warning is emitted.
    config: TransportCompressionConfig,
    /// Running totals shared behind a mutex; updated by `send` and by response decoding.
    stats: Arc<Mutex<TransportCompressionStats>>,
    /// Level passed to the zstd encoder; 3 unless set via `with_level` (clamped to 1..=22).
    compression_level: i32,
    /// Payloads shorter than this many bytes are sent uncompressed.
    min_compress_size: usize,
}
impl CompressedTransport {
pub fn new(inner: Box<dyn Transport>, config: TransportCompressionConfig) -> Self {
let stats = Arc::new(Mutex::new(TransportCompressionStats::new()));
Self {
inner,
config: config.clone(),
stats,
compression_level: 3, min_compress_size: config.compression_threshold,
}
}
pub fn with_level(
inner: Box<dyn Transport>,
config: TransportCompressionConfig,
level: i32,
) -> Self {
let mut transport = Self::new(inner, config);
transport.compression_level = level.clamp(1, 22);
transport
}
pub fn with_threshold(
inner: Box<dyn Transport>,
config: TransportCompressionConfig,
threshold: usize,
) -> Self {
let mut transport = Self::new(inner, config);
transport.min_compress_size = threshold;
transport
}
fn compress_data(
&self,
data: &[u8],
) -> anyhow::Result<(CompressionResult, TransportCompressionStats)> {
let start = std::time::Instant::now();
let original_size = data.len();
if original_size < self.min_compress_size {
trace!(
"Data size {} < threshold {}, skipping compression",
original_size, self.min_compress_size
);
let mut stats = TransportCompressionStats::new();
stats.base.record_skipped();
return Ok((
CompressionResult {
data: data.to_vec(),
original_size,
compressed_size: original_size,
algorithm: CompressionAlgorithm::None,
compression_ratio: 1.0,
},
stats,
));
}
match self.compress_with_zstd(data) {
Ok(result) => {
let elapsed = start.elapsed().as_micros() as u64;
let mut stats = TransportCompressionStats::new();
stats
.base
.record_compression(original_size, result.compressed_size, elapsed);
stats.record_compression_time(elapsed);
let savings = (1.0 - result.compression_ratio) * 100.0;
debug!(
"Compressed {} bytes to {} bytes ({:.1}% savings, ratio: {:.2})",
original_size, result.compressed_size, savings, result.compression_ratio
);
if result.compression_ratio > self.config.warning_ratio_threshold {
stats.record_warning_threshold();
warn!(
"Poor compression ratio: {:.2}% for {} bytes",
result.compression_ratio * 100.0,
original_size
);
}
Ok((result, stats))
}
Err(e) => {
warn!("Compression failed, falling back to uncompressed: {}", e);
let mut stats = TransportCompressionStats::new();
stats.base.record_failed();
Ok((
CompressionResult {
data: data.to_vec(),
original_size,
compressed_size: original_size,
algorithm: CompressionAlgorithm::None,
compression_ratio: 1.0,
},
stats,
))
}
}
}
fn compress_with_zstd(&self, data: &[u8]) -> anyhow::Result<CompressionResult> {
#[cfg(feature = "compression-zstd")]
{
let compressed = zstd::stream::encode_all(data, self.compression_level)?;
let compressed_size = compressed.len();
let original_size = data.len();
let ratio = compressed_size as f64 / original_size as f64;
Ok(CompressionResult {
data: compressed,
original_size,
compressed_size,
algorithm: CompressionAlgorithm::Zstd,
compression_ratio: ratio,
})
}
#[cfg(not(feature = "compression-zstd"))]
{
self.compress_with_lz4(data)
}
}
#[cfg(not(feature = "compression-zstd"))]
fn compress_with_lz4(&self, data: &[u8]) -> anyhow::Result<CompressionResult> {
#[cfg(feature = "compression-lz4")]
{
let compressed = lz4_flex::compress(data);
let compressed_size = compressed.len();
let original_size = data.len();
let ratio = compressed_size as f64 / original_size as f64;
Ok(CompressionResult {
data: compressed,
original_size,
compressed_size,
algorithm: CompressionAlgorithm::Lz4,
compression_ratio: ratio,
})
}
#[cfg(not(feature = "compression-lz4"))]
{
Err(anyhow::anyhow!(
"No compression algorithm available (enable compression-zstd or compression-lz4 feature)"
))
}
}
async fn decompress_response(
&self,
response: TransportResponse,
) -> anyhow::Result<TransportResponse> {
let content_encoding = response
.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("content-encoding"))
.map(|(_, v)| v.as_str());
if let Some(encoding) = content_encoding {
let start = std::time::Instant::now();
let decompressed = match encoding {
"zstd" => self.decompress_zstd(&response.body).await?,
"lz4" => self.decompress_lz4(&response.body).await?,
"gzip" => self.decompress_gzip(&response.body).await?,
_ => {
trace!("Unknown encoding: {}, returning as-is", encoding);
return Ok(response);
}
};
let elapsed = start.elapsed().as_micros() as u64;
let mut stats = self.stats.lock().map_err(|e| {
anyhow::anyhow!(TransportCompressionError::StreamFailed(format!(
"Failed to acquire stats lock: {}",
e
)))
})?;
stats.base.record_decompression(elapsed);
stats.record_decompression_time(elapsed);
debug!(
"Decompressed {} bytes to {} bytes in {}µs",
response.body.len(),
decompressed.len(),
elapsed
);
Ok(TransportResponse {
status: response.status,
body: decompressed,
headers: response.headers,
})
} else {
Ok(response)
}
}
}
#[async_trait]
impl Transport for CompressedTransport {
async fn send(&self, data: &[u8]) -> anyhow::Result<TransportResponse> {
let (compression_result, operation_stats) = self.compress_data(data)?;
{
let mut stats = self.stats.lock().map_err(|e| {
anyhow::anyhow!(TransportCompressionError::StreamFailed(format!(
"Failed to acquire stats lock: {}",
e
)))
})?;
stats.base.total_original_bytes += operation_stats.base.total_original_bytes;
stats.base.total_compressed_bytes += operation_stats.base.total_compressed_bytes;
stats.base.compression_count += operation_stats.base.compression_count;
stats.base.skipped_count += operation_stats.base.skipped_count;
stats.base.failed_count += operation_stats.base.failed_count;
stats.total_compressions += operation_stats.total_compressions;
stats.total_decompressions += operation_stats.total_decompressions;
}
let mut headers = Vec::new();
if compression_result.algorithm != CompressionAlgorithm::None {
let encoding = match compression_result.algorithm {
CompressionAlgorithm::Zstd => "zstd",
CompressionAlgorithm::Lz4 => "lz4",
CompressionAlgorithm::Gzip => "gzip",
CompressionAlgorithm::None => unreachable!(),
};
headers.push(("Content-Encoding".to_string(), encoding.to_string()));
}
let mut response = self.inner.send(&compression_result.data).await?;
response.headers.extend(headers);
self.decompress_response(response).await
}
async fn send_async(&self, data: &[u8]) -> anyhow::Result<()> {
let (compression_result, _) = self.compress_data(data)?;
self.inner.send_async(&compression_result.data).await
}
async fn health_check(&self) -> anyhow::Result<bool> {
self.inner.health_check().await
}
fn metadata(&self) -> TransportMetadata {
let mut metadata = self.inner.metadata();
metadata.supports_compression = true;
metadata
}
}
#[cfg(test)]
#[path = "wrapper_tests.rs"]
mod tests;