use std::collections::HashMap;
#[cfg(feature = "streaming")]
use std::pin::Pin;
use std::sync::Arc;
use bytes::Bytes;
#[cfg(feature = "streaming")]
use tokio::io::AsyncBufRead;
#[cfg(feature = "streaming")]
use tokio::io::AsyncRead;
use crate::error::ConnectError;
#[cfg(feature = "streaming")]
pub type BoxedAsyncRead = Pin<Box<dyn AsyncRead + Send>>;
#[cfg(feature = "streaming")]
pub type BoxedAsyncBufRead = Pin<Box<dyn AsyncBufRead + Send>>;
/// A whole-message compression codec, keyed by its encoding token.
///
/// Implementations must be shareable across threads (`Send + Sync + 'static`)
/// because the registry stores them behind `Arc` and serves concurrent RPCs.
pub trait CompressionProvider: Send + Sync + 'static {
    /// Canonical encoding token (e.g. `"gzip"`) used in headers and lookups.
    fn name(&self) -> &'static str;
    /// Compresses `data` into a freshly allocated buffer.
    fn compress(&self, data: &[u8]) -> Result<Bytes, ConnectError>;
    /// Returns a streaming reader that yields the decompressed form of `data`.
    fn decompressor<'a>(&self, data: &'a [u8])
        -> Result<Box<dyn std::io::Read + 'a>, ConnectError>;
    /// Decompresses `data`, failing with `resource_exhausted` if the output
    /// would exceed `max_size` bytes (decompression-bomb guard).
    fn decompress_with_limit(&self, data: &[u8], max_size: usize) -> Result<Bytes, ConnectError> {
        use std::io::Read;
        // Pre-size the buffer from the cap when it is modest; for huge or
        // effectively-unbounded caps start small and let the Vec grow.
        const PREALLOC_CEILING: usize = 64 * 1024 * 1024;
        let capacity = if max_size < PREALLOC_CEILING {
            max_size.saturating_add(1)
        } else {
            256
        };
        let mut out = Vec::with_capacity(capacity);
        // Read at most limit+1 bytes: one extra byte is enough to detect
        // that the payload exceeded the cap without decoding all of it.
        let budget = (max_size as u64).saturating_add(1);
        self.decompressor(data)?
            .take(budget)
            .read_to_end(&mut out)
            .map_err(|e| ConnectError::internal(format!("decompression failed: {e}")))?;
        if out.len() > max_size {
            Err(ConnectError::resource_exhausted(format!(
                "decompressed size exceeds limit {max_size}"
            )))
        } else {
            Ok(Bytes::from(out))
        }
    }
}
#[cfg(feature = "streaming")]
/// Extension of [`CompressionProvider`] for codecs that can also transform
/// async byte streams; only compiled with the "streaming" feature.
pub trait StreamingCompressionProvider: CompressionProvider {
/// Wraps `reader` so reads yield the decompressed byte stream.
fn decompress_stream(&self, reader: BoxedAsyncBufRead) -> BoxedAsyncRead;
/// Wraps `reader` so reads yield the compressed byte stream.
fn compress_stream(&self, reader: BoxedAsyncBufRead) -> BoxedAsyncRead;
}
#[derive(Clone)]
/// Registry of named compression providers plus a precomputed
/// accept-encoding header value.
///
/// Cloning is cheap: the maps are shared through `Arc` and mutated
/// copy-on-write (`Arc::make_mut`) during registration.
pub struct CompressionRegistry {
/// Whole-message providers, keyed by encoding token.
providers: Arc<HashMap<&'static str, Arc<dyn CompressionProvider>>>,
#[cfg(feature = "streaming")]
/// Subset of providers that can also transform async streams.
streaming_providers: Arc<HashMap<&'static str, Arc<dyn StreamingCompressionProvider>>>,
/// Cached comma-separated, sorted list of registered encodings.
accept_encoding: Arc<str>,
}
impl std::fmt::Debug for CompressionRegistry {
    /// Debug output lists the registered encodings and the cached
    /// accept-encoding header.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Sort the names: HashMap iteration order is randomized per process,
        // which would make Debug output (and any log/snapshot built on it)
        // nondeterministic.
        let mut names: Vec<_> = self.providers.keys().copied().collect();
        names.sort_unstable();
        f.debug_struct("CompressionRegistry")
            .field("providers", &names)
            .field("accept_encoding", &self.accept_encoding)
            // Streaming providers (feature-gated) are intentionally omitted.
            .finish_non_exhaustive()
    }
}
impl CompressionRegistry {
    /// Creates an empty registry: no encodings are supported other than the
    /// implicit `identity`.
    pub fn new() -> Self {
        Self {
            providers: Arc::new(HashMap::new()),
            #[cfg(feature = "streaming")]
            streaming_providers: Arc::new(HashMap::new()),
            accept_encoding: Arc::from(""),
        }
    }
    /// Recomputes the cached accept-encoding header value; sorted so the
    /// header is deterministic regardless of registration order.
    fn rebuild_accept_encoding(&mut self) {
        let mut encodings: Vec<_> = self.providers.keys().copied().collect();
        encodings.sort_unstable();
        self.accept_encoding = Arc::from(encodings.join(", "));
    }
    /// Registers a whole-message provider under its `name()`, replacing any
    /// existing provider with the same name.
    #[must_use]
    pub fn register<P: CompressionProvider>(mut self, provider: P) -> Self {
        let name = provider.name();
        let providers = Arc::make_mut(&mut self.providers);
        providers.insert(name, Arc::new(provider));
        // Evict any stale streaming provider registered under the same name,
        // so the streaming path can never serve a different implementation
        // than the unary path after a re-registration.
        #[cfg(feature = "streaming")]
        {
            Arc::make_mut(&mut self.streaming_providers).remove(name);
        }
        self.rebuild_accept_encoding();
        self
    }
    /// Looks up a provider by encoding token.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<Arc<dyn CompressionProvider>> {
        self.providers.get(name).cloned()
    }
    /// Returns true if `name` is a registered encoding.
    pub fn supports(&self, name: &str) -> bool {
        self.providers.contains_key(name)
    }
    /// Returns the registered encoding tokens (unordered).
    pub fn supported_encodings(&self) -> Vec<&'static str> {
        self.providers.keys().copied().collect()
    }
    /// Returns the precomputed, comma-separated accept-encoding header value.
    pub fn accept_encoding_header(&self) -> &str {
        &self.accept_encoding
    }
    /// Picks the response encoding.
    ///
    /// When the client sent an accept-encoding list, the first entry we
    /// support wins (client preference order). Entries may carry HTTP-style
    /// parameters (e.g. `gzip;q=0.5`); parameters are stripped, and an
    /// explicit `q=0` marks the coding as unacceptable (RFC 9110 §12.4.2).
    /// Without an accept-encoding list, we mirror the request's encoding if
    /// we support it. `None` means respond with identity.
    pub fn negotiate_encoding(
        &self,
        accept_encoding: Option<&str>,
        request_encoding: Option<&str>,
    ) -> Option<&'static str> {
        if let Some(accept) = accept_encoding {
            for entry in accept.split(',') {
                let mut parts = entry.split(';');
                let name = parts.next().unwrap_or("").trim();
                // Identity needs no provider; skip it (and empty entries).
                if name.is_empty() || name == "identity" {
                    continue;
                }
                // A weight of exactly zero means "do not use this coding".
                let refused = parts.any(|p| {
                    p.trim()
                        .strip_prefix("q=")
                        .and_then(|v| v.trim().parse::<f32>().ok())
                        .is_some_and(|q| q == 0.0)
                });
                if refused {
                    continue;
                }
                if let Some((key, _)) = self.providers.get_key_value(name) {
                    return Some(*key);
                }
            }
            return None;
        }
        if let Some(req_enc) = request_encoding
            && req_enc != "identity"
            && let Some((key, _)) = self.providers.get_key_value(req_enc)
        {
            return Some(*key);
        }
        None
    }
    /// Decompresses `data` (encoded with `encoding`), enforcing `max_size`
    /// on the decompressed payload.
    ///
    /// `identity` only checks the raw size. An unknown encoding is
    /// `unimplemented` even for empty payloads, while a known encoding with
    /// an empty payload passes through unchanged.
    pub fn decompress_with_limit(
        &self,
        encoding: &str,
        data: Bytes,
        max_size: usize,
    ) -> Result<Bytes, ConnectError> {
        if encoding == "identity" {
            if data.len() > max_size {
                return Err(ConnectError::resource_exhausted(format!(
                    "message size {} exceeds limit {}",
                    data.len(),
                    max_size
                )));
            }
            return Ok(data);
        }
        let provider = self.get(encoding).ok_or_else(|| {
            ConnectError::unimplemented(format!("unsupported compression encoding: {encoding}"))
        })?;
        if data.is_empty() {
            return Ok(data);
        }
        provider.decompress_with_limit(&data, max_size)
    }
    /// Compresses `data` with `encoding`; `identity` is a plain copy.
    pub fn compress(&self, encoding: &str, data: &[u8]) -> Result<Bytes, ConnectError> {
        if encoding == "identity" {
            return Ok(Bytes::copy_from_slice(data));
        }
        let provider = self.get(encoding).ok_or_else(|| {
            ConnectError::unimplemented(format!("unsupported compression encoding: {encoding}"))
        })?;
        provider.compress(data)
    }
    /// Registers a provider for both the unary and streaming paths.
    #[cfg(feature = "streaming")]
    #[must_use]
    pub fn register_streaming<P: StreamingCompressionProvider>(mut self, provider: P) -> Self {
        let name = provider.name();
        let provider = Arc::new(provider);
        let providers = Arc::make_mut(&mut self.providers);
        providers.insert(name, provider.clone());
        let streaming_providers = Arc::make_mut(&mut self.streaming_providers);
        streaming_providers.insert(name, provider);
        self.rebuild_accept_encoding();
        self
    }
    /// Looks up a streaming-capable provider by encoding token.
    #[cfg(feature = "streaming")]
    pub fn get_streaming(&self, name: &str) -> Option<Arc<dyn StreamingCompressionProvider>> {
        self.streaming_providers.get(name).cloned()
    }
    /// Returns true if `name` has a streaming-capable provider.
    #[cfg(feature = "streaming")]
    pub fn supports_streaming(&self, name: &str) -> bool {
        self.streaming_providers.contains_key(name)
    }
    /// Wraps `reader` in a streaming decoder for `encoding`; `identity`
    /// passes the reader through unchanged.
    #[cfg(feature = "streaming")]
    pub fn decompress_stream(
        &self,
        encoding: &str,
        reader: BoxedAsyncBufRead,
    ) -> Result<BoxedAsyncRead, ConnectError> {
        if encoding == "identity" {
            return Ok(reader);
        }
        let provider = self.get_streaming(encoding).ok_or_else(|| {
            ConnectError::unimplemented(format!(
                "streaming decompression not supported for encoding: {encoding}"
            ))
        })?;
        Ok(provider.decompress_stream(reader))
    }
    /// Wraps `reader` in a streaming encoder for `encoding`; `identity`
    /// passes the reader through unchanged.
    #[cfg(feature = "streaming")]
    pub fn compress_stream(
        &self,
        encoding: &str,
        reader: BoxedAsyncBufRead,
    ) -> Result<BoxedAsyncRead, ConnectError> {
        if encoding == "identity" {
            return Ok(reader);
        }
        let provider = self.get_streaming(encoding).ok_or_else(|| {
            ConnectError::unimplemented(format!(
                "streaming compression not supported for encoding: {encoding}"
            ))
        })?;
        Ok(provider.compress_stream(reader))
    }
}
#[derive(Debug, Clone, Copy)]
/// Decides whether an outgoing message should be compressed, based on an
/// on/off switch and a minimum message size.
pub struct CompressionPolicy {
/// Master switch; when false nothing is compressed.
enabled: bool,
/// Messages smaller than this many bytes are sent uncompressed.
min_size: usize,
}
/// Default threshold (1 KiB) below which compression is skipped; for small
/// messages the framing/CPU overhead tends to outweigh the size savings.
pub const DEFAULT_COMPRESSION_MIN_SIZE: usize = 1024;
impl Default for CompressionPolicy {
    /// Compression enabled, with the standard minimum-size threshold.
    fn default() -> Self {
        Self {
            min_size: DEFAULT_COMPRESSION_MIN_SIZE,
            enabled: true,
        }
    }
}
impl CompressionPolicy {
    /// A policy that never compresses anything.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            min_size: 0,
        }
    }
    /// Builder-style setter for the minimum message size (in bytes) that
    /// triggers compression.
    #[must_use]
    pub fn min_size(mut self, size: usize) -> Self {
        self.min_size = size;
        self
    }
    /// True when compression is enabled and the message meets the size
    /// threshold.
    #[inline]
    pub fn should_compress(&self, message_size: usize) -> bool {
        if !self.enabled {
            return false;
        }
        message_size >= self.min_size
    }
    /// Applies a per-call override: `Some(true)` forces compression for any
    /// size, `Some(false)` disables it, `None` keeps this policy as-is.
    pub(crate) fn with_override(&self, override_compress: Option<bool>) -> Self {
        match override_compress {
            Some(false) => Self::disabled(),
            Some(true) => Self {
                enabled: true,
                min_size: 0,
            },
            None => *self,
        }
    }
}
impl Default for CompressionRegistry {
/// Builds a registry pre-populated with every codec enabled at compile
/// time (gzip and/or zstd), using streaming-capable registration when the
/// "streaming" feature is also on.
// `mut` is unused when no codec feature is enabled, hence the allow.
#[allow(unused_mut)]
fn default() -> Self {
let mut registry = Self::new();
#[cfg(all(feature = "gzip", feature = "streaming"))]
{
registry = registry.register_streaming(GzipProvider::default());
}
#[cfg(all(feature = "gzip", not(feature = "streaming")))]
{
registry = registry.register(GzipProvider::default());
}
#[cfg(all(feature = "zstd", feature = "streaming"))]
{
registry = registry.register_streaming(ZstdProvider::default());
}
#[cfg(all(feature = "zstd", not(feature = "streaming")))]
{
registry = registry.register(ZstdProvider::default());
}
registry
}
}
#[cfg(feature = "gzip")]
/// Gzip codec backed by `flate2`, with small pools of reusable raw-deflate
/// compressor/decompressor states to avoid re-allocating zlib streams for
/// every message.
pub struct GzipProvider {
/// flate2 compression level used for new compressor states.
level: u32,
/// Pooled raw-deflate compressor states, reused across messages.
compressors: std::sync::Mutex<Vec<flate2::Compress>>,
/// Pooled raw-deflate decompressor states, reused across messages.
decompressors: std::sync::Mutex<Vec<flate2::Decompress>>,
}
#[cfg(feature = "gzip")]
impl std::fmt::Debug for GzipProvider {
    /// Reports the level and current pool sizes; a poisoned pool lock is
    /// shown as an empty pool rather than panicking inside Debug.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pooled_compressors = self.compressors.lock().map(|v| v.len()).unwrap_or(0);
        let pooled_decompressors = self.decompressors.lock().map(|v| v.len()).unwrap_or(0);
        f.debug_struct("GzipProvider")
            .field("level", &self.level)
            .field("pool_compressors", &pooled_compressors)
            .field("pool_decompressors", &pooled_decompressors)
            .finish()
    }
}
#[cfg(feature = "gzip")]
impl Default for GzipProvider {
    /// Provider at level 6 (flate2's default speed/ratio trade-off), with
    /// empty state pools. Delegates to `with_level` so the construction
    /// logic lives in one place (mirrors `ZstdProvider`).
    fn default() -> Self {
        Self::with_level(6)
    }
}
#[cfg(feature = "gzip")]
impl GzipProvider {
    /// Maximum number of pooled (de)compressor states kept for reuse.
    const MAX_POOL_SIZE: usize = 64;
    /// Creates a provider with the default compression level.
    pub fn new() -> Self {
        Self::default()
    }
    /// Creates a provider with an explicit flate2 compression level.
    pub fn with_level(level: u32) -> Self {
        Self {
            level,
            compressors: std::sync::Mutex::new(Vec::new()),
            decompressors: std::sync::Mutex::new(Vec::new()),
        }
    }
    /// Pops a pooled compressor state or creates a fresh one.
    /// `false` = raw deflate (no zlib wrapper): the gzip framing is written
    /// by hand in `compress_inner`.
    fn take_compressor(&self) -> flate2::Compress {
        self.compressors
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .pop()
            .unwrap_or_else(|| flate2::Compress::new(flate2::Compression::new(self.level), false))
    }
    /// Resets a compressor state and returns it to the pool (dropped if the
    /// pool is full).
    fn return_compressor(&self, mut c: flate2::Compress) {
        c.reset();
        let mut pool = self.compressors.lock().unwrap_or_else(|e| e.into_inner());
        if pool.len() < Self::MAX_POOL_SIZE {
            pool.push(c);
        }
    }
    /// Pops a pooled decompressor state or creates a fresh raw-deflate one.
    fn take_decompressor(&self) -> flate2::Decompress {
        self.decompressors
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .pop()
            .unwrap_or_else(|| flate2::Decompress::new(false))
    }
    /// Resets a decompressor state and returns it to the pool (dropped if
    /// the pool is full).
    fn return_decompressor(&self, mut d: flate2::Decompress) {
        d.reset(false);
        let mut pool = self.decompressors.lock().unwrap_or_else(|e| e.into_inner());
        if pool.len() < Self::MAX_POOL_SIZE {
            pool.push(d);
        }
    }
    /// Compresses `data` into a complete gzip member: hand-written 10-byte
    /// header, raw deflate stream, then CRC32/ISIZE trailer (RFC 1952).
    fn compress_inner(
        compressor: &mut flate2::Compress,
        data: &[u8],
    ) -> Result<Bytes, ConnectError> {
        let mut output = Vec::with_capacity(data.len() + 32);
        // Fixed header: magic, deflate method, no flags, zero mtime, no
        // extra flags, unknown OS (0xff).
        output.extend_from_slice(&[
            0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff,
        ]);
        // Pooled states carry totals across uses; measure input consumption
        // relative to where this call started.
        let start_in = compressor.total_in();
        loop {
            let consumed = (compressor.total_in() - start_in) as usize;
            // Grow geometrically so compress_vec always has output room.
            output.reserve(output.capacity().max(4096));
            let status = compressor
                .compress_vec(
                    &data[consumed..],
                    &mut output,
                    flate2::FlushCompress::Finish,
                )
                .map_err(|e| ConnectError::internal(format!("gzip compression failed: {e}")))?;
            if status == flate2::Status::StreamEnd {
                break;
            }
            // BufError means no forward progress is possible; bail out
            // instead of spinning forever.
            if status == flate2::Status::BufError {
                return Err(ConnectError::internal(
                    "gzip compression stalled without progress",
                ));
            }
        }
        // Trailer: CRC32 of the input and input length mod 2^32, both LE.
        let mut crc = flate2::Crc::new();
        crc.update(data);
        output.extend_from_slice(&crc.sum().to_le_bytes());
        output.extend_from_slice(&(data.len() as u32).to_le_bytes());
        Ok(Bytes::from(output))
    }
    /// Decompresses a complete gzip member and verifies its CRC32/ISIZE
    /// trailer. `max_size` caps the decompressed size (bomb guard).
    fn decompress_inner(
        decompressor: &mut flate2::Decompress,
        data: &[u8],
        max_size: Option<usize>,
    ) -> Result<Bytes, ConnectError> {
        // Skip the variable-length gzip header; the pooled state decodes
        // raw deflate only.
        let deflate_start = gzip_header_len(data)?;
        let stream_data = &data[deflate_start..];
        // Pre-size from the cap when modest; otherwise guess 2x the input.
        let capacity = match max_size {
            Some(limit) if limit < 64 * 1024 * 1024 => limit.saturating_add(1),
            _ => data.len().saturating_mul(2).max(256),
        };
        let mut output = Vec::with_capacity(capacity);
        let start_in = decompressor.total_in();
        loop {
            let consumed = (decompressor.total_in() - start_in) as usize;
            if output.capacity() == output.len() {
                // About to grow: enforce the cap early so a bomb cannot
                // force unbounded allocation before the final check.
                if let Some(limit) = max_size
                    && output.len() > limit
                {
                    return Err(ConnectError::resource_exhausted(format!(
                        "decompressed size exceeds limit {limit}"
                    )));
                }
                output.reserve(output.len().max(4096));
            }
            let status = decompressor
                .decompress_vec(
                    &stream_data[consumed..],
                    &mut output,
                    flate2::FlushDecompress::None,
                )
                .map_err(|e| ConnectError::internal(format!("gzip decompression failed: {e}")))?;
            if status == flate2::Status::StreamEnd {
                break;
            }
            // BufError means no progress is possible — typically truncated
            // input with no more bytes to feed. The original loop spun
            // forever here; fail instead.
            if status == flate2::Status::BufError {
                return Err(ConnectError::internal("corrupt or truncated gzip data"));
            }
        }
        if let Some(limit) = max_size
            && output.len() > limit
        {
            return Err(ConnectError::resource_exhausted(format!(
                "decompressed size exceeds limit {limit}"
            )));
        }
        // Validate the 8-byte trailer that follows the deflate stream.
        let deflate_consumed = (decompressor.total_in() - start_in) as usize;
        let trailer_start = deflate_consumed;
        if stream_data.len() < trailer_start + 8 {
            return Err(ConnectError::internal("gzip data too short for trailer"));
        }
        let trailer = &stream_data[trailer_start..trailer_start + 8];
        let expected_crc = u32::from_le_bytes([trailer[0], trailer[1], trailer[2], trailer[3]]);
        let expected_size = u32::from_le_bytes([trailer[4], trailer[5], trailer[6], trailer[7]]);
        let mut crc = flate2::Crc::new();
        crc.update(&output);
        if crc.sum() != expected_crc {
            return Err(ConnectError::internal("gzip CRC32 mismatch"));
        }
        if expected_size != (output.len() as u32) {
            return Err(ConnectError::internal("gzip size mismatch"));
        }
        Ok(Bytes::from(output))
    }
}
#[cfg(feature = "gzip")]
/// Returns the byte length of the gzip member header at the start of `data`,
/// i.e. the offset where the raw deflate stream begins (RFC 1952: fixed
/// 10-byte prefix plus optional FEXTRA/FNAME/FCOMMENT/FHCRC fields).
fn gzip_header_len(data: &[u8]) -> Result<usize, ConnectError> {
    const FHCRC: u8 = 0x02;
    const FEXTRA: u8 = 0x04;
    const FNAME: u8 = 0x08;
    const FCOMMENT: u8 = 0x10;
    // Advances past a NUL-terminated field (FNAME/FCOMMENT), returning the
    // position just after the terminator.
    fn skip_cstr(data: &[u8], mut pos: usize) -> Result<usize, ConnectError> {
        loop {
            match data.get(pos) {
                Some(0) => return Ok(pos + 1),
                Some(_) => pos += 1,
                None => return Err(ConnectError::internal("truncated gzip header")),
            }
        }
    }
    if data.len() < 10 {
        return Err(ConnectError::internal("gzip data too short for header"));
    }
    if data[0] != 0x1f || data[1] != 0x8b {
        return Err(ConnectError::internal("invalid gzip magic"));
    }
    if data[2] != 0x08 {
        return Err(ConnectError::internal(
            "unsupported gzip compression method",
        ));
    }
    let flags = data[3];
    let mut pos = 10usize;
    if flags & FEXTRA != 0 {
        if pos + 2 > data.len() {
            return Err(ConnectError::internal("truncated gzip header"));
        }
        let xlen = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
        pos += 2 + xlen;
    }
    if flags & FNAME != 0 {
        pos = skip_cstr(data, pos)?;
    }
    if flags & FCOMMENT != 0 {
        pos = skip_cstr(data, pos)?;
    }
    if flags & FHCRC != 0 {
        pos += 2;
    }
    // Catches FEXTRA/FHCRC fields that claim to run past the end of `data`.
    if pos > data.len() {
        return Err(ConnectError::internal("truncated gzip header"));
    }
    Ok(pos)
}
#[cfg(feature = "gzip")]
impl CompressionProvider for GzipProvider {
    fn name(&self) -> &'static str {
        "gzip"
    }
    /// Compresses using a pooled raw-deflate state; the state goes back to
    /// the pool whether or not compression succeeded.
    fn compress(&self, data: &[u8]) -> Result<Bytes, ConnectError> {
        let mut state = self.take_compressor();
        let result = Self::compress_inner(&mut state, data);
        self.return_compressor(state);
        result
    }
    /// Returns a plain flate2 gzip reader over `data` (no pooling on this
    /// path).
    fn decompressor<'a>(
        &self,
        data: &'a [u8],
    ) -> Result<Box<dyn std::io::Read + 'a>, ConnectError> {
        let decoder = flate2::read::GzDecoder::new(data);
        Ok(Box::new(decoder))
    }
    /// Decompresses via a pooled state with trailer verification and a size
    /// cap.
    fn decompress_with_limit(&self, data: &[u8], max_size: usize) -> Result<Bytes, ConnectError> {
        let mut state = self.take_decompressor();
        let result = Self::decompress_inner(&mut state, data, Some(max_size));
        self.return_decompressor(state);
        result
    }
}
#[cfg(all(feature = "gzip", feature = "streaming"))]
impl StreamingCompressionProvider for GzipProvider {
    /// Async gzip decoding via `async_compression`.
    fn decompress_stream(&self, reader: BoxedAsyncBufRead) -> BoxedAsyncRead {
        use async_compression::tokio::bufread::GzipDecoder;
        Box::pin(GzipDecoder::new(reader))
    }
    /// Async gzip encoding via `async_compression`.
    fn compress_stream(&self, reader: BoxedAsyncBufRead) -> BoxedAsyncRead {
        use async_compression::tokio::bufread::GzipEncoder;
        Box::pin(GzipEncoder::new(reader))
    }
}
#[cfg(feature = "zstd")]
/// Zstandard codec backed by the `zstd` crate, pooling bulk-compressor
/// states; decompression creates a fresh streaming decoder per call.
pub struct ZstdProvider {
/// zstd compression level used for new compressor states.
level: i32,
/// Pooled bulk-compressor states, reused across messages.
compressors: std::sync::Mutex<Vec<zstd::bulk::Compressor<'static>>>,
}
#[cfg(feature = "zstd")]
impl std::fmt::Debug for ZstdProvider {
    /// Reports the level and current pool size; a poisoned pool lock is
    /// shown as an empty pool rather than panicking inside Debug.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pooled = self.compressors.lock().map(|v| v.len()).unwrap_or(0);
        f.debug_struct("ZstdProvider")
            .field("level", &self.level)
            .field("pool_compressors", &pooled)
            .finish()
    }
}
#[cfg(feature = "zstd")]
impl ZstdProvider {
    /// Default zstd compression level used by `new`/`default`.
    const DEFAULT_LEVEL: i32 = 3;
    /// Provider at the default compression level.
    pub fn new() -> Self {
        Self::default()
    }
    /// Provider with an explicit zstd compression level.
    pub fn with_level(level: i32) -> Self {
        Self {
            compressors: std::sync::Mutex::new(Vec::new()),
            level,
        }
    }
}
#[cfg(feature = "zstd")]
impl Default for ZstdProvider {
    /// Equivalent to `ZstdProvider::with_level(DEFAULT_LEVEL)`.
    fn default() -> Self {
        Self::with_level(Self::DEFAULT_LEVEL)
    }
}
#[cfg(feature = "zstd")]
impl ZstdProvider {
    /// Maximum number of pooled compressor states kept for reuse.
    const MAX_POOL_SIZE: usize = 64;
    /// Pops a pooled compressor state, or creates a fresh one at this
    /// provider's level.
    fn take_compressor(&self) -> Result<zstd::bulk::Compressor<'static>, ConnectError> {
        let pooled = self
            .compressors
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .pop();
        match pooled {
            Some(c) => Ok(c),
            None => zstd::bulk::Compressor::new(self.level).map_err(|e| {
                ConnectError::internal(format!("failed to create zstd compressor: {e}"))
            }),
        }
    }
    /// Returns a compressor state to the pool (dropped if the pool is full).
    fn return_compressor(&self, c: zstd::bulk::Compressor<'static>) {
        let mut pool = self.compressors.lock().unwrap_or_else(|e| e.into_inner());
        if pool.len() < Self::MAX_POOL_SIZE {
            pool.push(c);
        }
    }
    /// Streaming zstd decode of `data`, optionally capped at `max_size`
    /// decompressed bytes (reads at most limit+1 so the overflow is caught
    /// without decoding the whole bomb).
    fn decompress_impl(data: &[u8], max_size: Option<usize>) -> Result<Bytes, ConnectError> {
        use std::io::Read;
        let mut decoder = zstd::Decoder::new(data)
            .map_err(|e| ConnectError::internal(format!("zstd decompression failed: {e}")))?;
        // Pre-size from the cap when modest; otherwise guess 4x the input.
        let capacity = match max_size {
            Some(limit) if limit < 64 * 1024 * 1024 => limit.saturating_add(1),
            _ => data.len().saturating_mul(4).max(256),
        };
        let mut out = Vec::with_capacity(capacity);
        if let Some(limit) = max_size {
            decoder
                .take((limit as u64).saturating_add(1))
                .read_to_end(&mut out)
                .map_err(|e| ConnectError::internal(format!("zstd decompression failed: {e}")))?;
            if out.len() > limit {
                return Err(ConnectError::resource_exhausted(format!(
                    "decompressed size exceeds limit {limit}"
                )));
            }
        } else {
            decoder
                .read_to_end(&mut out)
                .map_err(|e| ConnectError::internal(format!("zstd decompression failed: {e}")))?;
        }
        Ok(Bytes::from(out))
    }
}
#[cfg(feature = "zstd")]
impl CompressionProvider for ZstdProvider {
    fn name(&self) -> &'static str {
        "zstd"
    }
    /// Compresses via a pooled bulk-compressor state; the state goes back to
    /// the pool whether or not compression succeeded.
    fn compress(&self, data: &[u8]) -> Result<Bytes, ConnectError> {
        let mut state = self.take_compressor()?;
        let result = state
            .compress(data)
            .map(Bytes::from)
            .map_err(|e| ConnectError::internal(format!("zstd compression failed: {e}")));
        self.return_compressor(state);
        result
    }
    /// Returns a streaming zstd decoder over `data`.
    fn decompressor<'a>(
        &self,
        data: &'a [u8],
    ) -> Result<Box<dyn std::io::Read + 'a>, ConnectError> {
        match zstd::Decoder::new(data) {
            Ok(decoder) => Ok(Box::new(decoder)),
            Err(e) => Err(ConnectError::internal(format!(
                "zstd decompression failed: {e}"
            ))),
        }
    }
    /// Size-capped decompression (see `decompress_impl`).
    fn decompress_with_limit(&self, data: &[u8], max_size: usize) -> Result<Bytes, ConnectError> {
        Self::decompress_impl(data, Some(max_size))
    }
}
#[cfg(all(feature = "zstd", feature = "streaming"))]
impl StreamingCompressionProvider for ZstdProvider {
    /// Async zstd decoding via `async_compression`.
    fn decompress_stream(&self, reader: BoxedAsyncBufRead) -> BoxedAsyncRead {
        use async_compression::tokio::bufread::ZstdDecoder;
        Box::pin(ZstdDecoder::new(reader))
    }
    /// Async zstd encoding via `async_compression`.
    fn compress_stream(&self, reader: BoxedAsyncBufRead) -> BoxedAsyncRead {
        use async_compression::tokio::bufread::ZstdEncoder;
        Box::pin(ZstdEncoder::new(reader))
    }
}
#[cfg(test)]
mod tests {
use super::*;
// --- Registry basics ---------------------------------------------------
#[test]
fn test_empty_registry() {
let registry = CompressionRegistry::new();
assert!(!registry.supports("gzip"));
assert!(!registry.supports("zstd"));
assert!(registry.supported_encodings().is_empty());
}
// "identity" is handled by the registry itself, so it works even with no
// providers registered.
#[test]
fn test_identity_always_works() {
let registry = CompressionRegistry::new();
let data = b"hello world";
let result = registry
.decompress_with_limit("identity", Bytes::from_static(data), usize::MAX)
.unwrap();
assert_eq!(&result[..], data);
}
// --- Gzip round-trips and cross-compatibility with flate2's own codecs --
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_large_roundtrip() {
let provider = GzipProvider::default();
let data: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();
let compressed = provider.compress(&data).unwrap();
let decompressed = provider
.decompress_with_limit(&compressed, usize::MAX)
.unwrap();
assert_eq!(&decompressed[..], &data[..]);
}
// Output of the pooled compressor must be readable by a stock GzDecoder.
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_pooled_cross_compat_with_gz_decoder() {
use std::io::Read;
let provider = GzipProvider::default();
let data: Vec<u8> = (0..100_000).map(|i| (i % 256) as u8).collect();
let compressed = provider.compress(&data).unwrap();
let mut decoder = flate2::read::GzDecoder::new(&compressed[..]);
let mut decompressed = Vec::new();
decoder.read_to_end(&mut decompressed).unwrap();
assert_eq!(&decompressed[..], &data[..]);
}
// The pooled decompressor must accept output of a stock GzEncoder.
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_pooled_cross_compat_with_gz_encoder() {
use std::io::Write;
let provider = GzipProvider::default();
let data: Vec<u8> = (0..100_000).map(|i| (i % 256) as u8).collect();
let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::new(6));
encoder.write_all(&data).unwrap();
let compressed = encoder.finish().unwrap();
let decompressed = provider
.decompress_with_limit(&compressed, usize::MAX)
.unwrap();
assert_eq!(&decompressed[..], &data[..]);
}
// --- gzip_header_len: header-parsing corner cases -----------------------
// Builds a 10-byte gzip header with the given FLG byte, followed by `extra`.
#[cfg(feature = "gzip")]
fn gz_hdr(flags: u8, extra: &[u8]) -> Vec<u8> {
let mut h = vec![
0x1f, 0x8b, 0x08, flags, 0, 0, 0, 0, 0, 0xff, ];
h.extend_from_slice(extra);
h
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_header_len_basic() {
assert_eq!(gzip_header_len(&gz_hdr(0x00, &[])).unwrap(), 10);
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_header_len_fextra() {
let extra = [3u8, 0, 0xAA, 0xBB, 0xCC]; assert_eq!(gzip_header_len(&gz_hdr(0x04, &extra)).unwrap(), 10 + 2 + 3);
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_header_len_fextra_truncated() {
let extra = [100u8, 0, 0xAA, 0xBB];
assert!(gzip_header_len(&gz_hdr(0x04, &extra)).is_err());
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_header_len_fname() {
let extra = b"test.txt\0";
assert_eq!(gzip_header_len(&gz_hdr(0x08, extra)).unwrap(), 10 + 9);
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_header_len_fname_truncated() {
assert!(gzip_header_len(&gz_hdr(0x08, b"nonul")).is_err());
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_header_len_fcomment() {
let extra = b"a comment\0";
assert_eq!(gzip_header_len(&gz_hdr(0x10, extra)).unwrap(), 10 + 10);
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_header_len_fcomment_truncated() {
assert!(gzip_header_len(&gz_hdr(0x10, b"nonul")).is_err());
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_header_len_fhcrc() {
assert_eq!(gzip_header_len(&gz_hdr(0x02, &[0xAB, 0xCD])).unwrap(), 12);
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_header_len_fhcrc_truncated() {
assert!(gzip_header_len(&gz_hdr(0x02, &[0xAB])).is_err());
}
// FEXTRA + FNAME + FCOMMENT + FHCRC combined in one header.
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_header_len_all_flags() {
let mut extra = Vec::new();
extra.extend_from_slice(&[2u8, 0, 0xAA, 0xBB]); extra.extend_from_slice(b"name\0"); extra.extend_from_slice(b"cmt\0"); extra.extend_from_slice(&[0x12, 0x34]); let flags = 0x04 | 0x08 | 0x10 | 0x02;
assert_eq!(
gzip_header_len(&gz_hdr(flags, &extra)).unwrap(),
10 + 4 + 5 + 4 + 2
);
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_header_len_bad_magic() {
let mut hdr = gz_hdr(0x00, &[]);
hdr[0] = 0x00;
assert!(gzip_header_len(&hdr).is_err());
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_header_len_bad_method() {
let mut hdr = gz_hdr(0x00, &[]);
hdr[2] = 0x07; assert!(gzip_header_len(&hdr).is_err());
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_header_len_too_short() {
assert!(gzip_header_len(&[0x1f, 0x8b, 0x08]).is_err());
}
// --- Provider / registry round-trips ------------------------------------
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_provider() {
let provider = GzipProvider::default();
let data = b"hello world, this is a test of gzip compression";
let compressed = provider.compress(data).unwrap();
assert_ne!(&compressed[..], data);
let decompressed = provider
.decompress_with_limit(&compressed, usize::MAX)
.unwrap();
assert_eq!(&decompressed[..], data);
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_registry() {
let registry = CompressionRegistry::new().register(GzipProvider::default());
assert!(registry.supports("gzip"));
assert!(!registry.supports("zstd"));
let data = b"test data";
let compressed = registry.compress("gzip", data).unwrap();
let decompressed = registry
.decompress_with_limit("gzip", compressed, usize::MAX)
.unwrap();
assert_eq!(&decompressed[..], data);
}
#[cfg(feature = "zstd")]
#[test]
fn test_zstd_provider() {
let provider = ZstdProvider::default();
let data = b"hello world, this is a test of zstd compression";
let compressed = provider.compress(data).unwrap();
assert_ne!(&compressed[..], data);
let decompressed = provider
.decompress_with_limit(&compressed, usize::MAX)
.unwrap();
assert_eq!(&decompressed[..], data);
}
// Highly repetitive input should shrink substantially.
#[cfg(feature = "zstd")]
#[test]
fn test_zstd_high_compression_ratio() {
let provider = ZstdProvider::default();
let data = vec![0u8; 100_000];
let compressed = provider.compress(&data).unwrap();
assert!(
compressed.len() * 4 < data.len(),
"expected high compression ratio; got {} bytes -> {} bytes",
data.len(),
compressed.len()
);
let decompressed = provider
.decompress_with_limit(&compressed, usize::MAX)
.unwrap();
assert_eq!(decompressed.len(), data.len());
assert!(decompressed.iter().all(|&b| b == 0));
}
#[cfg(feature = "zstd")]
#[test]
fn test_zstd_registry() {
let registry = CompressionRegistry::new().register(ZstdProvider::default());
assert!(registry.supports("zstd"));
assert!(!registry.supports("gzip"));
let data = b"test data";
let compressed = registry.compress("zstd", data).unwrap();
let decompressed = registry
.decompress_with_limit("zstd", compressed, usize::MAX)
.unwrap();
assert_eq!(&decompressed[..], data);
}
// --- Unknown encodings and empty payloads --------------------------------
#[test]
fn test_unsupported_encoding() {
let registry = CompressionRegistry::new();
let result =
registry.decompress_with_limit("unknown", Bytes::from_static(b"data"), usize::MAX);
assert!(result.is_err());
}
// Empty bodies with a known encoding header pass through untouched.
#[test]
#[cfg(feature = "zstd")]
fn test_decompress_empty_body_with_encoding_header() {
let registry = CompressionRegistry::default();
let result = registry.decompress_with_limit("zstd", Bytes::new(), usize::MAX);
assert_eq!(result.unwrap().len(), 0);
let result = registry.decompress_with_limit("gzip", Bytes::new(), usize::MAX);
assert_eq!(result.unwrap().len(), 0);
let result = registry.decompress_with_limit("zstd", Bytes::new(), 1024);
assert_eq!(result.unwrap().len(), 0);
}
// ...but an unknown encoding is rejected even when the body is empty.
#[test]
fn test_decompress_empty_body_unknown_encoding_still_errors() {
let registry = CompressionRegistry::default();
let result = registry.decompress_with_limit("foo", Bytes::new(), usize::MAX);
let err = result.unwrap_err();
assert_eq!(err.code, crate::error::ErrorCode::Unimplemented);
}
#[cfg(all(feature = "gzip", feature = "zstd"))]
#[test]
fn test_default_registry() {
let registry = CompressionRegistry::default();
assert!(registry.supports("gzip"));
assert!(registry.supports("zstd"));
}
// --- Accept-encoding header caching --------------------------------------
#[test]
fn test_accept_encoding_header() {
let registry = CompressionRegistry::new();
assert_eq!(registry.accept_encoding_header(), "");
#[cfg(feature = "gzip")]
{
let registry = CompressionRegistry::new().register(GzipProvider::default());
assert_eq!(registry.accept_encoding_header(), "gzip");
}
}
// Header must be sorted, independent of registration order.
#[cfg(all(feature = "gzip", feature = "zstd"))]
#[test]
fn test_accept_encoding_header_sorted_deterministic() {
let r1 = CompressionRegistry::new()
.register(GzipProvider::default())
.register(ZstdProvider::default());
let r2 = CompressionRegistry::new()
.register(ZstdProvider::default())
.register(GzipProvider::default());
assert_eq!(r1.accept_encoding_header(), "gzip, zstd");
assert_eq!(r2.accept_encoding_header(), "gzip, zstd");
}
// --- Custom providers -----------------------------------------------------
// Trivial provider whose "compression" just reverses the bytes.
struct MockProvider;
impl CompressionProvider for MockProvider {
fn name(&self) -> &'static str {
"mock"
}
fn compress(&self, data: &[u8]) -> Result<Bytes, ConnectError> {
Ok(Bytes::from(data.iter().rev().copied().collect::<Vec<_>>()))
}
fn decompressor<'a>(
&self,
data: &'a [u8],
) -> Result<Box<dyn std::io::Read + 'a>, ConnectError> {
let reversed: Vec<u8> = data.iter().rev().copied().collect();
Ok(Box::new(std::io::Cursor::new(reversed)))
}
}
#[test]
fn test_custom_provider() {
let registry = CompressionRegistry::new().register(MockProvider);
assert!(registry.supports("mock"));
let data = b"hello";
let compressed = registry.compress("mock", data).unwrap();
assert_eq!(&compressed[..], b"olleh");
let decompressed = registry
.decompress_with_limit("mock", compressed, usize::MAX)
.unwrap();
assert_eq!(&decompressed[..], data);
}
// --- Streaming paths ------------------------------------------------------
#[cfg(all(feature = "gzip", feature = "streaming"))]
#[tokio::test]
async fn test_gzip_streaming() {
use tokio::io::AsyncReadExt;
let registry = CompressionRegistry::default();
assert!(registry.supports_streaming("gzip"));
let data = b"hello world, this is a test of streaming gzip compression";
let compressed = registry.compress("gzip", data).unwrap();
let reader: BoxedAsyncBufRead = Box::pin(std::io::Cursor::new(compressed.to_vec()));
let mut decompressor = registry.decompress_stream("gzip", reader).unwrap();
let mut decompressed = Vec::new();
decompressor.read_to_end(&mut decompressed).await.unwrap();
assert_eq!(&decompressed[..], data);
}
#[cfg(all(feature = "zstd", feature = "streaming"))]
#[tokio::test]
async fn test_zstd_streaming() {
use tokio::io::AsyncReadExt;
let registry = CompressionRegistry::default();
assert!(registry.supports_streaming("zstd"));
let data = b"hello world, this is a test of streaming zstd compression";
let compressed = registry.compress("zstd", data).unwrap();
let reader: BoxedAsyncBufRead = Box::pin(std::io::Cursor::new(compressed.to_vec()));
let mut decompressor = registry.decompress_stream("zstd", reader).unwrap();
let mut decompressed = Vec::new();
decompressor.read_to_end(&mut decompressed).await.unwrap();
assert_eq!(&decompressed[..], data);
}
#[cfg(all(feature = "gzip", feature = "streaming"))]
#[tokio::test]
async fn test_streaming_compress_decompress_roundtrip() {
use tokio::io::AsyncReadExt;
let registry = CompressionRegistry::default();
let data = b"hello world, this is a roundtrip test of streaming compression";
let input: BoxedAsyncBufRead = Box::pin(std::io::Cursor::new(data.to_vec()));
let mut compressor = registry.compress_stream("gzip", input).unwrap();
let mut compressed = Vec::new();
compressor.read_to_end(&mut compressed).await.unwrap();
let reader: BoxedAsyncBufRead = Box::pin(std::io::Cursor::new(compressed));
let mut decompressor = registry.decompress_stream("gzip", reader).unwrap();
let mut decompressed = Vec::new();
decompressor.read_to_end(&mut decompressed).await.unwrap();
assert_eq!(&decompressed[..], data);
}
// --- Size-limit enforcement (under / exact / exceeded) -------------------
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_decompress_with_limit_under() {
let provider = GzipProvider::default();
let data = b"hello world";
let compressed = provider.compress(data).unwrap();
let result = provider.decompress_with_limit(&compressed, 1024);
assert!(result.is_ok());
assert_eq!(&result.unwrap()[..], data);
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_decompress_with_limit_exact() {
let provider = GzipProvider::default();
let data = b"hello world";
let compressed = provider.compress(data).unwrap();
let result = provider.decompress_with_limit(&compressed, data.len());
assert!(result.is_ok());
assert_eq!(&result.unwrap()[..], data);
}
#[cfg(feature = "gzip")]
#[test]
fn test_gzip_decompress_with_limit_exceeded() {
let provider = GzipProvider::default();
let data = vec![0u8; 1024];
let compressed = provider.compress(&data).unwrap();
let result = provider.decompress_with_limit(&compressed, 512);
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.code, crate::ErrorCode::ResourceExhausted);
}
#[cfg(feature = "zstd")]
#[test]
fn test_zstd_decompress_with_limit_under() {
let provider = ZstdProvider::default();
let data = b"hello world";
let compressed = provider.compress(data).unwrap();
let result = provider.decompress_with_limit(&compressed, 1024);
assert!(result.is_ok());
assert_eq!(&result.unwrap()[..], data);
}
#[cfg(feature = "zstd")]
#[test]
fn test_zstd_decompress_with_limit_exact() {
let provider = ZstdProvider::default();
let data = b"hello world";
let compressed = provider.compress(data).unwrap();
let result = provider.decompress_with_limit(&compressed, data.len());
assert!(result.is_ok());
assert_eq!(&result.unwrap()[..], data);
}
#[cfg(feature = "zstd")]
#[test]
fn test_zstd_decompress_with_limit_exceeded() {
let provider = ZstdProvider::default();
let data = vec![0u8; 1024];
let compressed = provider.compress(&data).unwrap();
let result = provider.decompress_with_limit(&compressed, 512);
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.code, crate::ErrorCode::ResourceExhausted);
}
// --- CompressionPolicy ----------------------------------------------------
#[test]
fn test_compression_policy_default() {
let policy = CompressionPolicy::default();
assert!(!policy.should_compress(512));
assert!(!policy.should_compress(1023));
assert!(policy.should_compress(1024));
assert!(policy.should_compress(4096));
}
#[test]
fn test_compression_policy_disabled() {
let policy = CompressionPolicy::disabled();
assert!(!policy.should_compress(0));
assert!(!policy.should_compress(1024));
assert!(!policy.should_compress(1_000_000));
}
#[test]
fn test_compression_policy_custom_min_size() {
let policy = CompressionPolicy::default().min_size(4096);
assert!(!policy.should_compress(1024));
assert!(!policy.should_compress(4095));
assert!(policy.should_compress(4096));
assert!(policy.should_compress(8192));
}
#[test]
fn test_compression_policy_empty_message() {
let default_policy = CompressionPolicy::default();
assert!(!default_policy.should_compress(0));
let zero_min = CompressionPolicy::default().min_size(0);
assert!(zero_min.should_compress(0));
let disabled = CompressionPolicy::disabled();
assert!(!disabled.should_compress(0));
}
#[test]
fn test_compression_policy_with_override() {
let policy = CompressionPolicy::default();
let effective = policy.with_override(None);
assert!(!effective.should_compress(512));
assert!(effective.should_compress(2048));
let forced = policy.with_override(Some(true));
assert!(forced.should_compress(0));
assert!(forced.should_compress(1));
let disabled = policy.with_override(Some(false));
assert!(!disabled.should_compress(0));
assert!(!disabled.should_compress(1_000_000));
}
// Identity path still enforces the size limit on the raw payload.
#[test]
fn test_identity_decompress_with_limit() {
let registry = CompressionRegistry::new();
let data = Bytes::from_static(b"hello world");
let result = registry.decompress_with_limit("identity", data.clone(), 1024);
assert!(result.is_ok());
let result = registry.decompress_with_limit("identity", data.clone(), data.len());
assert!(result.is_ok());
let result = registry.decompress_with_limit("identity", data, 5);
assert!(result.is_err());
}
}