use std::io::{self, Cursor, Read, Write};
use std::result::Result as StdResult;
use crate::traits::{DictError, Result};
/// Upper bound (128 MiB) on the number of decompressed bytes the
/// decompression routines in this file will produce before failing.
const MAX_DECOMPRESSED_BYTES: usize = 128 * 1024 * 1024;
/// Checks `len` against [`MAX_DECOMPRESSED_BYTES`].
///
/// # Errors
///
/// Returns `DictError::DecompressionError` when `len` is strictly greater
/// than the limit; a length equal to the limit is accepted.
fn ensure_within_limit(len: usize) -> Result<()> {
    // Guard clause: the common, in-bounds case exits first.
    if len <= MAX_DECOMPRESSED_BYTES {
        return Ok(());
    }

    let message = format!(
        "Decompressed data exceeds safety limit ({} bytes)",
        MAX_DECOMPRESSED_BYTES
    );
    Err(DictError::DecompressionError(message))
}
fn read_to_end_with_limit<R: Read>(mut reader: R, limit: usize) -> Result<Vec<u8>> {
let mut out = Vec::new();
let mut buf = [0u8; 8192];
while let Ok(n) = reader.read(&mut buf) {
if n == 0 {
break;
}
if out.len().saturating_add(n) > limit {
return Err(DictError::DecompressionError(
"Decompressed data exceeds safety limit".to_string(),
));
}
out.extend_from_slice(&buf[..n]);
}
Ok(out)
}
/// Compression codecs selectable by the functions in this file.
///
/// The enum is fieldless, so it is cheap to copy and can be used as a
/// hash-map key or set member.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionAlgorithm {
    /// No compression: bytes are passed through unchanged.
    None,
    /// Gzip (handled through the `flate2` crate in this file).
    Gzip,
    /// LZ4 frame format (handled through the `lz4_flex` crate in this file).
    Lz4,
    /// Zstandard (handled through the `zstd` crate in this file).
    Zstd,
}
impl Default for CompressionAlgorithm {
fn default() -> Self {
CompressionAlgorithm::Zstd
}
}
pub fn compress(data: &[u8], algorithm: CompressionAlgorithm) -> Result<Vec<u8>> {
match algorithm {
CompressionAlgorithm::None => Ok(data.to_vec()),
CompressionAlgorithm::Gzip => compress_gzip(data),
CompressionAlgorithm::Lz4 => compress_lz4(data),
CompressionAlgorithm::Zstd => compress_zstd(data),
}
}
pub fn decompress(compressed: &[u8], algorithm: CompressionAlgorithm) -> Result<Vec<u8>> {
match algorithm {
CompressionAlgorithm::None => {
ensure_within_limit(compressed.len())?;
Ok(compressed.to_vec())
}
CompressionAlgorithm::Gzip => decompress_gzip(compressed),
CompressionAlgorithm::Lz4 => decompress_lz4(compressed),
CompressionAlgorithm::Zstd => decompress_zstd(compressed),
}
}
/// Clamps `level` to the range accepted by `algorithm`.
///
/// Gzip levels are capped at 9 and Zstd levels at 19; for every other
/// algorithm `level` is returned unchanged.
pub fn compression_level(level: u32, algorithm: &CompressionAlgorithm) -> u32 {
    let ceiling = match algorithm {
        CompressionAlgorithm::Gzip => Some(9),
        CompressionAlgorithm::Zstd => Some(19),
        CompressionAlgorithm::None | CompressionAlgorithm::Lz4 => None,
    };
    ceiling.map_or(level, |cap| level.min(cap))
}
/// Returns the highest compression level reported for `algorithm`:
/// 9 for Gzip, 19 for Zstd, and 1 for everything else.
pub fn max_compression_level(algorithm: &CompressionAlgorithm) -> u32 {
    match *algorithm {
        CompressionAlgorithm::Zstd => 19,
        CompressionAlgorithm::Gzip => 9,
        CompressionAlgorithm::None | CompressionAlgorithm::Lz4 => 1,
    }
}
/// Returns the suggested compression level for `algorithm`:
/// 6 for Gzip and Zstd, 1 for everything else.
pub fn suggested_compression_level(algorithm: &CompressionAlgorithm) -> u32 {
    match *algorithm {
        CompressionAlgorithm::Gzip | CompressionAlgorithm::Zstd => 6,
        CompressionAlgorithm::None | CompressionAlgorithm::Lz4 => 1,
    }
}
fn compress_gzip(data: &[u8]) -> Result<Vec<u8>> {
let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::new(6));
encoder
.write_all(data)
.map_err(|e| DictError::DecompressionError(e.to_string()))?;
encoder
.finish()
.map_err(|e| DictError::DecompressionError(e.to_string()))
}
/// Decompresses a gzip stream held in memory, capped at
/// `MAX_DECOMPRESSED_BYTES` of output.
///
/// # Errors
///
/// Returns the `DictError` produced by `read_to_end_with_limit` unchanged.
fn decompress_gzip(compressed: &[u8]) -> Result<Vec<u8>> {
    let decoder = flate2::read::GzDecoder::new(compressed);
    // The helper already yields a `DictError`; re-wrapping it would only
    // stringify one `DecompressionError` inside another.
    read_to_end_with_limit(decoder, MAX_DECOMPRESSED_BYTES)
}
/// Compresses `data` into a new buffer using the LZ4 frame encoder.
///
/// # Errors
///
/// Failures while writing or finishing the frame are reported as
/// `DictError::DecompressionError` carrying the error's message.
fn compress_lz4(data: &[u8]) -> Result<Vec<u8>> {
    let mut frame = lz4_flex::frame::FrameEncoder::new(Vec::new());

    if let Err(e) = frame.write_all(data) {
        return Err(DictError::DecompressionError(e.to_string()));
    }

    match frame.finish() {
        Ok(bytes) => Ok(bytes),
        Err(e) => Err(DictError::DecompressionError(e.to_string())),
    }
}
/// Decompresses an LZ4 frame held in memory, capped at
/// `MAX_DECOMPRESSED_BYTES` of output.
///
/// # Errors
///
/// Returns the `DictError` produced by `read_to_end_with_limit` unchanged.
fn decompress_lz4(compressed: &[u8]) -> Result<Vec<u8>> {
    let decoder = lz4_flex::frame::FrameDecoder::new(Cursor::new(compressed));
    // The helper already yields a `DictError`; re-wrapping it would only
    // stringify one `DecompressionError` inside another.
    read_to_end_with_limit(decoder, MAX_DECOMPRESSED_BYTES)
}
fn compress_zstd(data: &[u8]) -> Result<Vec<u8>> {
let mut encoder = zstd::Encoder::new(Vec::new(), 6)
.map_err(|e| DictError::DecompressionError(e.to_string()))?;
encoder
.write_all(data)
.map_err(|e| DictError::DecompressionError(e.to_string()))?;
encoder
.finish()
.map_err(|e| DictError::DecompressionError(e.to_string()))
}
/// Decompresses a zstd stream held in memory, capped at
/// `MAX_DECOMPRESSED_BYTES` of output.
///
/// # Errors
///
/// Returns `DictError::DecompressionError` if the decoder cannot be created,
/// otherwise the `DictError` produced by `read_to_end_with_limit` unchanged.
fn decompress_zstd(compressed: &[u8]) -> Result<Vec<u8>> {
    let decoder = zstd::Decoder::new(compressed)
        .map_err(|e| DictError::DecompressionError(e.to_string()))?;
    // The helper already yields a `DictError`; re-wrapping it would only
    // stringify one `DecompressionError` inside another.
    read_to_end_with_limit(decoder, MAX_DECOMPRESSED_BYTES)
}
/// Heuristic estimate of the compression ratio for `algorithm` at `level`.
///
/// * `None` is always `1.0`.
/// * `Gzip` uses the clamped level: `level / 9 * 2 + 1`, capped at `4.0`.
/// * `Lz4` uses the raw level: `2 + level / 10` (no cap).
/// * `Zstd` uses the clamped level: `level / 19 * 4 + 1.5`, capped at `6.0`.
///
/// `original_size` is accepted but does not influence the result.
pub fn estimate_compression_ratio(
    original_size: u64,
    algorithm: &CompressionAlgorithm,
    level: u32,
) -> f32 {
    // Part of the signature only; the heuristic does not look at the size.
    let _ = original_size;

    match algorithm {
        CompressionAlgorithm::None => 1.0,
        CompressionAlgorithm::Gzip => {
            let clamped = compression_level(level, algorithm) as f32;
            let estimate = clamped / 9.0 * 2.0 + 1.0;
            estimate.min(4.0)
        }
        CompressionAlgorithm::Lz4 => {
            let raw = level as f32;
            2.0 + (raw / 10.0)
        }
        CompressionAlgorithm::Zstd => {
            let clamped = compression_level(level, algorithm) as f32;
            let estimate = clamped / 19.0 * 4.0 + 1.5;
            estimate.min(6.0)
        }
    }
}
/// Returns the hard-coded [`AlgorithmSettings`] describing `algorithm`.
pub fn get_algorithm_settings(algorithm: &CompressionAlgorithm) -> AlgorithmSettings {
    // Columns: streaming, dictionary, typical ratio, speed, memory overhead.
    let (supports_streaming, supports_dictionary, typical_ratio, speed_category, memory_overhead) =
        match algorithm {
            CompressionAlgorithm::None => (false, false, 1.0, SpeedCategory::VeryFast, 0),
            CompressionAlgorithm::Gzip => (true, false, 2.5, SpeedCategory::Fast, 256 * 1024),
            CompressionAlgorithm::Lz4 => (true, true, 2.0, SpeedCategory::VeryFast, 64 * 1024),
            CompressionAlgorithm::Zstd => (true, true, 3.5, SpeedCategory::Medium, 512 * 1024),
        };

    AlgorithmSettings {
        supports_streaming,
        supports_dictionary,
        typical_ratio,
        speed_category,
        memory_overhead,
    }
}
/// Descriptive properties of a compression algorithm, as returned by
/// `get_algorithm_settings` with fixed per-algorithm values.
#[derive(Debug, Clone)]
pub struct AlgorithmSettings {
/// Streaming-support flag; `false` only for `CompressionAlgorithm::None`.
pub supports_streaming: bool,
/// Dictionary-support flag; `true` for `Lz4` and `Zstd`.
pub supports_dictionary: bool,
/// Nominal compression ratio (`1.0` for `None`).
/// NOTE(review): presumably original size / compressed size — confirm.
pub typical_ratio: f32,
/// Coarse speed classification of the algorithm.
pub speed_category: SpeedCategory,
/// Nominal memory overhead (`0` for `None`).
/// NOTE(review): unit is presumably bytes — confirm.
pub memory_overhead: u64,
}
/// Coarse speed classification attached to an algorithm's settings.
///
/// The enum is fieldless, so it is cheap to copy and can be used as a
/// hash-map key or set member.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpeedCategory {
    /// Fastest tier (assigned to `None` and `Lz4` in this file).
    VeryFast,
    /// Assigned to `Gzip` in this file.
    Fast,
    /// Assigned to `Zstd` in this file.
    Medium,
    /// Slowest tier; not assigned to any algorithm in this file.
    Slow,
}
pub fn compress_stream<R: Read, W: Write>(
input: &mut R,
output: &mut W,
algorithm: CompressionAlgorithm,
) -> Result<u64> {
match algorithm {
CompressionAlgorithm::None => {
let mut buffer = vec![0u8; 8192];
let mut total_written = 0u64;
loop {
match input.read(&mut buffer) {
Ok(0) => break,
Ok(n) => {
output
.write_all(&buffer[..n])
.map_err(|e| DictError::IoError(e.to_string()))?;
total_written += n as u64;
}
Err(e) => return Err(DictError::IoError(e.to_string())),
}
}
Ok(total_written)
}
CompressionAlgorithm::Gzip => {
let mut encoder = flate2::write::GzEncoder::new(output, flate2::Compression::new(6));
let mut total_written = 0u64;
let mut buffer = vec![0u8; 8192];
loop {
match input.read(&mut buffer) {
Ok(0) => break,
Ok(n) => {
encoder
.write_all(&buffer[..n])
.map_err(|e| DictError::IoError(e.to_string()))?;
total_written += n as u64;
}
Err(e) => return Err(DictError::IoError(e.to_string())),
}
}
encoder
.finish()
.map_err(|e| DictError::IoError(e.to_string()))?;
Ok(total_written)
}
CompressionAlgorithm::Lz4 => {
let mut encoder = lz4_flex::frame::FrameEncoder::new(output);
let mut total_written = 0u64;
let mut buffer = vec![0u8; 8192];
loop {
match input.read(&mut buffer) {
Ok(0) => break,
Ok(n) => {
encoder
.write_all(&buffer[..n])
.map_err(|e| DictError::IoError(e.to_string()))?;
total_written += n as u64;
}
Err(e) => return Err(DictError::IoError(e.to_string())),
}
}
encoder
.finish()
.map_err(|e| DictError::IoError(e.to_string()))?;
Ok(total_written)
}
CompressionAlgorithm::Zstd => {
let mut encoder =
zstd::Encoder::new(output, 6).map_err(|e| DictError::IoError(e.to_string()))?;
let mut total_written = 0u64;
let mut buffer = vec![0u8; 8192];
loop {
match input.read(&mut buffer) {
Ok(0) => break,
Ok(n) => {
encoder
.write_all(&buffer[..n])
.map_err(|e| DictError::IoError(e.to_string()))?;
total_written += n as u64;
}
Err(e) => return Err(DictError::IoError(e.to_string())),
}
}
encoder
.finish()
.map_err(|e| DictError::IoError(e.to_string()))?;
Ok(total_written)
}
}
}
pub fn decompress_stream<R: Read, W: Write>(
input: &mut R,
output: &mut W,
algorithm: CompressionAlgorithm,
) -> Result<u64> {
match algorithm {
CompressionAlgorithm::None => {
let mut buffer = vec![0u8; 8192];
let mut total_written = 0u64;
loop {
match input.read(&mut buffer) {
Ok(0) => break,
Ok(n) => {
ensure_within_limit(total_written as usize + n)?;
output
.write_all(&buffer[..n])
.map_err(|e| DictError::IoError(e.to_string()))?;
total_written += n as u64;
}
Err(e) => return Err(DictError::IoError(e.to_string())),
}
}
Ok(total_written)
}
CompressionAlgorithm::Gzip => {
let mut decoder = flate2::read::GzDecoder::new(input);
let mut total_written = 0u64;
let mut buffer = vec![0u8; 8192];
loop {
match decoder.read(&mut buffer) {
Ok(0) => break,
Ok(n) => {
ensure_within_limit(total_written as usize + n)?;
output
.write_all(&buffer[..n])
.map_err(|e| DictError::IoError(e.to_string()))?;
total_written += n as u64;
}
Err(e) => return Err(DictError::IoError(e.to_string())),
}
}
Ok(total_written)
}
CompressionAlgorithm::Lz4 => {
let mut decoder = lz4_flex::frame::FrameDecoder::new(input);
let mut total_written = 0u64;
let mut buffer = vec![0u8; 8192];
loop {
match decoder.read(&mut buffer) {
Ok(0) => break,
Ok(n) => {
ensure_within_limit(total_written as usize + n)?;
output
.write_all(&buffer[..n])
.map_err(|e| DictError::IoError(e.to_string()))?;
total_written += n as u64;
}
Err(e) => return Err(DictError::IoError(e.to_string())),
}
}
Ok(total_written)
}
CompressionAlgorithm::Zstd => {
let mut decoder =
zstd::Decoder::new(input).map_err(|e| DictError::IoError(e.to_string()))?;
let mut total_written = 0u64;
let mut buffer = vec![0u8; 8192];
loop {
match decoder.read(&mut buffer) {
Ok(0) => break,
Ok(n) => {
ensure_within_limit(total_written as usize + n)?;
output
.write_all(&buffer[..n])
.map_err(|e| DictError::IoError(e.to_string()))?;
total_written += n as u64;
}
Err(e) => return Err(DictError::IoError(e.to_string())),
}
}
Ok(total_written)
}
}
}