use bytes::Bytes;
use flate2::read::{DeflateDecoder, GzDecoder};
use std::io::Read;
use crate::types::error::{Error, Result};
/// Content coding applied to a body, as named by a `Content-Encoding` header value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionType {
    /// No coding, or a coding this module does not recognise.
    None,
    /// `gzip` / `x-gzip`.
    Gzip,
    /// `deflate`.
    Deflate,
    /// `br` (Brotli).
    Brotli,
}

impl CompressionType {
    /// Maps a `Content-Encoding` header value to a [`CompressionType`].
    ///
    /// Matching is ASCII case-insensitive and ignores leading/trailing whitespace
    /// (header values commonly carry optional whitespace). A missing header,
    /// `identity`, unknown codings, and multi-coding lists such as `"gzip, br"`
    /// all map to [`CompressionType::None`].
    pub fn from_header(content_encoding: Option<&str>) -> Self {
        // Trim first so values like " gzip" or "br\t" are still recognised; compare
        // without allocating a lowercase copy.
        match content_encoding.map(str::trim) {
            Some(c) if c.eq_ignore_ascii_case("gzip") || c.eq_ignore_ascii_case("x-gzip") => {
                CompressionType::Gzip
            }
            Some(c) if c.eq_ignore_ascii_case("deflate") => CompressionType::Deflate,
            Some(c) if c.eq_ignore_ascii_case("br") => CompressionType::Brotli,
            _ => CompressionType::None,
        }
    }
}
/// Decodes compressed payloads while enforcing an output-size cap and a
/// decompressed-to-compressed size-ratio cap.
pub struct Decompressor {
    /// Maximum number of decompressed bytes that may be produced.
    max_size: u64,
    /// Maximum allowed ratio of decompressed bytes to compressed bytes.
    max_ratio: f64,
}

impl Default for Decompressor {
    /// Default limits: 100 MiB of output and a 100:1 size ratio.
    fn default() -> Self {
        const ONE_MIB: u64 = 1 << 20;
        let max_size = 100 * ONE_MIB;
        let max_ratio = 100.0;
        Decompressor {
            max_size,
            max_ratio,
        }
    }
}
impl Decompressor {
pub fn new(max_size: u64, max_ratio: f64) -> Self {
Self { max_size, max_ratio }
}
pub fn decompress(&self, data: &[u8], compression: CompressionType) -> Result<Bytes> {
match compression {
CompressionType::None => Ok(Bytes::copy_from_slice(data)),
CompressionType::Gzip => self.decompress_gzip(data),
CompressionType::Deflate => self.decompress_deflate(data),
CompressionType::Brotli => self.decompress_brotli(data),
}
}
fn decompress_gzip(&self, data: &[u8]) -> Result<Bytes> {
let mut decoder = GzDecoder::new(data);
self.read_with_limits(&mut decoder, data.len())
}
fn decompress_deflate(&self, data: &[u8]) -> Result<Bytes> {
let mut decoder = DeflateDecoder::new(data);
self.read_with_limits(&mut decoder, data.len())
}
fn decompress_brotli(&self, data: &[u8]) -> Result<Bytes> {
let mut decoder = brotli::Decompressor::new(data, 4096);
self.read_with_limits(&mut decoder, data.len())
}
fn read_with_limits<R: Read>(&self, reader: &mut R, compressed_size: usize) -> Result<Bytes> {
let mut output = Vec::new();
let mut buffer = [0u8; 8192];
let mut total_read: u64 = 0;
loop {
let n = reader.read(&mut buffer).map_err(|e| {
Error::Decompression(format!("Read error: {}", e))
})?;
if n == 0 {
break;
}
total_read += n as u64;
if total_read > self.max_size {
return Err(Error::SizeExceeded {
max: self.max_size,
actual: total_read,
});
}
if compressed_size > 0 {
let ratio = total_read as f64 / compressed_size as f64;
if ratio > self.max_ratio {
return Err(Error::Decompression(format!(
"Compression ratio {} exceeds limit {}",
ratio, self.max_ratio
)));
}
}
output.extend_from_slice(&buffer[..n]);
}
Ok(Bytes::from(output))
}
}
/// A decoded payload together with its before/after sizes.
#[derive(Debug)]
pub struct DecompressionResult {
    /// The decoded bytes.
    pub data: Bytes,
    /// Input size in bytes, as supplied by the caller.
    pub compressed_size: u64,
    /// Length of `data` in bytes.
    pub decompressed_size: u64,
    /// `decompressed_size / compressed_size`, or `1.0` when `compressed_size` is 0.
    pub ratio: f64,
}

impl DecompressionResult {
    /// Wraps `data`, taking `decompressed_size` from its length and computing `ratio`.
    ///
    /// A `compressed_size` of 0 yields a ratio of `1.0` instead of dividing by zero.
    pub fn new(data: Bytes, compressed_size: u64) -> Self {
        let decompressed_size = data.len() as u64;
        let ratio = match compressed_size {
            0 => 1.0,
            nonzero => decompressed_size as f64 / nonzero as f64,
        };
        Self {
            ratio,
            decompressed_size,
            compressed_size,
            data,
        }
    }
}