use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use torsh_core::error::{Result, TorshError};
use crate::resources::{Resource, ResourceType};
/// Compression algorithms selectable per resource type.
///
/// NOTE(review): in this file `Brotli` and `Lz4` are currently backed by the
/// gzip codec (see `AdvancedCompressor::compress_brotli` / `compress_lz4`),
/// so streams tagged with them are only round-trippable by this same
/// implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CompressionAlgorithm {
    /// Pass-through: data is stored uncompressed.
    None,
    /// DEFLATE with gzip framing (levels clamped to 0-9).
    Gzip,
    /// Zstandard (levels clamped to 0-22).
    Zstd,
    /// LZMA (the backend used here takes no level parameter).
    Lzma,
    /// Brotli (levels clamped to 0-11).
    Brotli,
    /// LZ4 (levels clamped to 0-16).
    Lz4,
}
/// Raw compression level. The meaningful range is algorithm-specific; use
/// [`CompressionLevel::for_algorithm`] to clamp it. Defaults to 6.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompressionLevel(pub u32);
/// Policy for choosing a compression level.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionStrategy {
    /// Favor throughput: maps to level 1.
    Speed,
    /// Favor smallest output: maps to level 9.
    Size,
    /// Middle ground: maps to level 6.
    Balanced,
    /// Pick a level per resource from its type and size
    /// (see `AdvancedCompressor::adaptive_level`).
    Adaptive,
}
/// Tunable settings for [`AdvancedCompressor`].
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Fallback algorithm when no per-resource-type override matches.
    pub default_algorithm: CompressionAlgorithm,
    /// Fallback level, used by the `Adaptive` strategy path.
    pub default_level: CompressionLevel,
    /// How compression levels are chosen.
    pub strategy: CompressionStrategy,
    /// Per-resource-type algorithm overrides.
    pub algorithm_overrides: HashMap<ResourceType, CompressionAlgorithm>,
    /// Inputs smaller than this many bytes are stored uncompressed.
    pub min_size_threshold: usize,
    /// Upper bound on in-memory buffer size, in bytes.
    /// NOTE(review): not enforced anywhere in this file — confirm callers use it.
    pub max_memory_size: usize,
    /// Whether chunked/parallel compression is preferred.
    /// NOTE(review): not consulted in this file; `ParallelCompressor` is opt-in.
    pub parallel_compression: bool,
}
/// Multi-algorithm compressor driven by a [`CompressionConfig`].
pub struct AdvancedCompressor {
    // Immutable after construction; every method takes `&self`.
    config: CompressionConfig,
}
/// Outcome of a single compression operation.
#[derive(Debug, Clone)]
pub struct CompressionResult {
    /// Compressed bytes (the input verbatim when `algorithm` is `None`).
    pub data: Vec<u8>,
    /// Algorithm that produced `data`.
    pub algorithm: CompressionAlgorithm,
    /// Requested (pre-clamp) level.
    pub level: CompressionLevel,
    /// Input size in bytes.
    pub original_size: usize,
    /// Output size in bytes.
    pub compressed_size: usize,
    /// `compressed_size / original_size`; values below 1.0 indicate savings.
    pub ratio: f32,
    /// Wall-clock time spent compressing, in milliseconds.
    pub compression_time_ms: u64,
}
/// Outcome of a single decompression operation.
#[derive(Debug, Clone)]
pub struct DecompressionResult {
    /// The recovered original bytes.
    pub data: Vec<u8>,
    /// Algorithm the caller supplied for decoding.
    pub algorithm: CompressionAlgorithm,
    /// Wall-clock time spent decompressing, in milliseconds.
    pub decompression_time_ms: u64,
}
impl Default for CompressionLevel {
fn default() -> Self {
CompressionLevel(6)
}
}
impl CompressionLevel {
pub fn new(level: u32) -> Self {
CompressionLevel(level)
}
pub fn for_algorithm(&self, algorithm: CompressionAlgorithm) -> u32 {
match algorithm {
CompressionAlgorithm::None => 0,
CompressionAlgorithm::Gzip => self.0.min(9),
CompressionAlgorithm::Zstd => self.0.min(22),
CompressionAlgorithm::Lzma => self.0.min(9),
CompressionAlgorithm::Brotli => self.0.min(11),
CompressionAlgorithm::Lz4 => self.0.min(16),
}
}
}
impl Default for CompressionConfig {
    fn default() -> Self {
        // Text-like resources default to Brotli; model/binary-like resources
        // default to Zstandard.
        let algorithm_overrides = HashMap::from([
            (ResourceType::Source, CompressionAlgorithm::Brotli),
            (ResourceType::Config, CompressionAlgorithm::Brotli),
            (ResourceType::Documentation, CompressionAlgorithm::Brotli),
            (ResourceType::Text, CompressionAlgorithm::Brotli),
            (ResourceType::Metadata, CompressionAlgorithm::Brotli),
            (ResourceType::Model, CompressionAlgorithm::Zstd),
            (ResourceType::Data, CompressionAlgorithm::Zstd),
            (ResourceType::Binary, CompressionAlgorithm::Zstd),
        ]);
        Self {
            default_algorithm: CompressionAlgorithm::Zstd,
            default_level: CompressionLevel::default(),
            strategy: CompressionStrategy::Balanced,
            algorithm_overrides,
            // Skip compression below 256 bytes; cap buffers at 100 MiB.
            min_size_threshold: 256,
            max_memory_size: 100 * 1024 * 1024,
            parallel_compression: true,
        }
    }
}
impl CompressionConfig {
    /// Equivalent to [`CompressionConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the fallback algorithm used when no per-type override matches.
    pub fn with_algorithm(mut self, algorithm: CompressionAlgorithm) -> Self {
        self.default_algorithm = algorithm;
        self
    }

    /// Sets the fallback compression level.
    pub fn with_level(mut self, level: CompressionLevel) -> Self {
        self.default_level = level;
        self
    }

    /// Sets the level-selection strategy.
    pub fn with_strategy(mut self, strategy: CompressionStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Sets the minimum input size (bytes) below which data is stored as-is.
    pub fn with_min_threshold(mut self, threshold: usize) -> Self {
        self.min_size_threshold = threshold;
        self
    }

    /// Enables or disables the parallel-compression preference flag.
    pub fn with_parallel(mut self, parallel: bool) -> Self {
        self.parallel_compression = parallel;
        self
    }

    /// Returns the algorithm for `resource_type`, falling back to the default
    /// when no override is registered.
    pub fn algorithm_for_resource(&self, resource_type: ResourceType) -> CompressionAlgorithm {
        match self.algorithm_overrides.get(&resource_type) {
            Some(&algorithm) => algorithm,
            None => self.default_algorithm,
        }
    }

    /// Maps a strategy to a concrete compression level.
    pub fn level_for_strategy(&self, strategy: CompressionStrategy) -> CompressionLevel {
        match strategy {
            CompressionStrategy::Adaptive => self.default_level,
            CompressionStrategy::Speed => CompressionLevel(1),
            CompressionStrategy::Balanced => CompressionLevel(6),
            CompressionStrategy::Size => CompressionLevel(9),
        }
    }
}
impl AdvancedCompressor {
    /// Creates a compressor with [`CompressionConfig::default`].
    pub fn new() -> Self {
        Self {
            config: CompressionConfig::default(),
        }
    }
    /// Creates a compressor with an explicit configuration.
    pub fn with_config(config: CompressionConfig) -> Self {
        Self { config }
    }
    /// Compresses one resource, choosing algorithm and level from the config.
    ///
    /// Resources smaller than `min_size_threshold` bytes are returned as-is
    /// (algorithm `None`, ratio 1.0): at that size the codec overhead would
    /// outweigh any saving.
    pub fn compress_resource(&self, resource: &Resource) -> Result<CompressionResult> {
        if resource.data.len() < self.config.min_size_threshold {
            return Ok(CompressionResult {
                data: resource.data.clone(),
                algorithm: CompressionAlgorithm::None,
                level: CompressionLevel(0),
                original_size: resource.data.len(),
                compressed_size: resource.data.len(),
                ratio: 1.0,
                compression_time_ms: 0,
            });
        }
        let algorithm = self.config.algorithm_for_resource(resource.resource_type);
        let level = match self.config.strategy {
            // Adaptive derives the level from resource type and size.
            CompressionStrategy::Adaptive => self.adaptive_level(resource),
            strategy => self.config.level_for_strategy(strategy),
        };
        self.compress_data(&resource.data, algorithm, level)
    }
    /// Compresses raw bytes with the given algorithm.
    ///
    /// The level is clamped to the algorithm's range via
    /// [`CompressionLevel::for_algorithm`]; the returned `ratio` is
    /// `compressed / original` (1.0 for empty input).
    ///
    /// # Errors
    /// Returns `TorshError::SerializationError` if the backend codec fails.
    pub fn compress_data(
        &self,
        data: &[u8],
        algorithm: CompressionAlgorithm,
        level: CompressionLevel,
    ) -> Result<CompressionResult> {
        let start_time = std::time::Instant::now();
        let compressed_data = match algorithm {
            CompressionAlgorithm::None => data.to_vec(),
            CompressionAlgorithm::Gzip => {
                self.compress_gzip(data, level.for_algorithm(algorithm))?
            }
            CompressionAlgorithm::Zstd => {
                self.compress_zstd(data, level.for_algorithm(algorithm))?
            }
            CompressionAlgorithm::Lzma => {
                self.compress_lzma(data, level.for_algorithm(algorithm))?
            }
            CompressionAlgorithm::Brotli => {
                self.compress_brotli(data, level.for_algorithm(algorithm))?
            }
            CompressionAlgorithm::Lz4 => self.compress_lz4(data, level.for_algorithm(algorithm))?,
        };
        let compression_time_ms = start_time.elapsed().as_millis() as u64;
        // Guard against division by zero for empty input.
        let ratio = if data.is_empty() {
            1.0
        } else {
            compressed_data.len() as f32 / data.len() as f32
        };
        let compressed_size = compressed_data.len();
        Ok(CompressionResult {
            data: compressed_data,
            algorithm,
            level,
            original_size: data.len(),
            compressed_size,
            ratio,
            compression_time_ms,
        })
    }
    /// Decompresses bytes previously produced by [`Self::compress_data`].
    ///
    /// The caller must supply the same `algorithm` that produced the data;
    /// the stream itself carries no algorithm tag.
    ///
    /// # Errors
    /// Returns `TorshError::SerializationError` if the backend codec fails.
    pub fn decompress_data(
        &self,
        compressed_data: &[u8],
        algorithm: CompressionAlgorithm,
    ) -> Result<DecompressionResult> {
        let start_time = std::time::Instant::now();
        let decompressed_data = match algorithm {
            CompressionAlgorithm::None => compressed_data.to_vec(),
            CompressionAlgorithm::Gzip => self.decompress_gzip(compressed_data)?,
            CompressionAlgorithm::Zstd => self.decompress_zstd(compressed_data)?,
            CompressionAlgorithm::Lzma => self.decompress_lzma(compressed_data)?,
            CompressionAlgorithm::Brotli => self.decompress_brotli(compressed_data)?,
            CompressionAlgorithm::Lz4 => self.decompress_lz4(compressed_data)?,
        };
        let decompression_time_ms = start_time.elapsed().as_millis() as u64;
        Ok(DecompressionResult {
            data: decompressed_data,
            algorithm,
            decompression_time_ms,
        })
    }
    /// Compresses `data` once with every non-`None` algorithm at level 6 and
    /// returns the results sorted by ratio, best (smallest output) first.
    pub fn benchmark_algorithms(&self, data: &[u8]) -> Result<Vec<CompressionResult>> {
        let algorithms = [
            CompressionAlgorithm::Gzip,
            CompressionAlgorithm::Zstd,
            CompressionAlgorithm::Lzma,
            CompressionAlgorithm::Brotli,
            CompressionAlgorithm::Lz4,
        ];
        let mut results = Vec::new();
        for algorithm in &algorithms {
            let result = self.compress_data(data, *algorithm, CompressionLevel(6))?;
            results.push(result);
        }
        // `partial_cmp` only fails on NaN; treat that case as equal.
        results.sort_by(|a, b| {
            a.ratio
                .partial_cmp(&b.ratio)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(results)
    }
    /// Heuristic level for the `Adaptive` strategy: large binary payloads get
    /// fast/low levels, text-like resources get higher ones.
    fn adaptive_level(&self, resource: &Resource) -> CompressionLevel {
        let data_size = resource.data.len();
        match resource.resource_type {
            ResourceType::Model | ResourceType::Binary => {
                if data_size > 10 * 1024 * 1024 {
                    // > 10 MiB: favor speed.
                    CompressionLevel(3)
                } else if data_size > 1024 * 1024 {
                    // 1-10 MiB: balanced.
                    CompressionLevel(6)
                } else {
                    // Small binaries: squeeze harder.
                    CompressionLevel(9)
                }
            }
            ResourceType::Source | ResourceType::Config | ResourceType::Documentation => {
                CompressionLevel(8)
            }
            ResourceType::Text | ResourceType::Metadata => {
                CompressionLevel(7)
            }
            _ => CompressionLevel(6),
        }
    }
    /// Gzip compression via the `oxiarc_deflate` backend.
    fn compress_gzip(&self, data: &[u8], level: u32) -> Result<Vec<u8>> {
        use oxiarc_deflate::gzip::gzip_compress;
        // `level` is pre-clamped to <= 9, so the u8 cast cannot truncate.
        gzip_compress(data, level as u8)
            .map_err(|e| TorshError::SerializationError(format!("Gzip compression failed: {}", e)))
    }
    /// Inverse of [`Self::compress_gzip`].
    fn decompress_gzip(&self, data: &[u8]) -> Result<Vec<u8>> {
        use oxiarc_deflate::gzip::gzip_decompress;
        gzip_decompress(data).map_err(|e| {
            TorshError::SerializationError(format!("Gzip decompression failed: {}", e))
        })
    }
    /// Zstandard compression via the `oxiarc_zstd` backend.
    fn compress_zstd(&self, data: &[u8], level: u32) -> Result<Vec<u8>> {
        use oxiarc_zstd::encode_all;
        encode_all(data, level as i32).map_err(|e| {
            TorshError::SerializationError(format!("Zstandard compression failed: {}", e))
        })
    }
    /// Inverse of [`Self::compress_zstd`].
    fn decompress_zstd(&self, data: &[u8]) -> Result<Vec<u8>> {
        use oxiarc_zstd::decode_all;
        decode_all(data).map_err(|e| {
            TorshError::SerializationError(format!("Zstandard decompression failed: {}", e))
        })
    }
    /// LZMA compression. NOTE: the `lzma_rs` call takes no level parameter,
    /// so `_level` is ignored.
    fn compress_lzma(&self, data: &[u8], _level: u32) -> Result<Vec<u8>> {
        let mut output = Vec::new();
        lzma_rs::lzma_compress(&mut std::io::Cursor::new(data), &mut output).map_err(|e| {
            TorshError::SerializationError(format!("LZMA compression failed: {}", e))
        })?;
        Ok(output)
    }
    /// Inverse of [`Self::compress_lzma`].
    fn decompress_lzma(&self, data: &[u8]) -> Result<Vec<u8>> {
        let mut output = Vec::new();
        lzma_rs::lzma_decompress(&mut std::io::Cursor::new(data), &mut output).map_err(|e| {
            TorshError::SerializationError(format!("LZMA decompression failed: {}", e))
        })?;
        Ok(output)
    }
    /// NOTE(review): no real Brotli backend is wired up — this falls back to
    /// gzip (level clamped to 9). Round-trips stay self-consistent because
    /// `decompress_brotli` uses the gzip decoder too, but the bytes are NOT a
    /// Brotli stream despite the `Brotli` tag.
    fn compress_brotli(&self, data: &[u8], level: u32) -> Result<Vec<u8>> {
        self.compress_gzip(data, level.min(9))
    }
    /// Gzip fallback matching [`Self::compress_brotli`].
    fn decompress_brotli(&self, data: &[u8]) -> Result<Vec<u8>> {
        self.decompress_gzip(data)
    }
    /// NOTE(review): no real LZ4 backend is wired up — this falls back to
    /// gzip at level 1 (fastest) and ignores `_level`.
    fn compress_lz4(&self, data: &[u8], _level: u32) -> Result<Vec<u8>> {
        self.compress_gzip(data, 1)
    }
    /// Gzip fallback matching [`Self::compress_lz4`].
    fn decompress_lz4(&self, data: &[u8]) -> Result<Vec<u8>> {
        self.decompress_gzip(data)
    }
}
impl Default for AdvancedCompressor {
fn default() -> Self {
Self::new()
}
}
/// Chunked, multi-threaded wrapper around [`AdvancedCompressor`].
pub struct ParallelCompressor {
    /// Underlying single-shot compressor applied to each chunk.
    compressor: AdvancedCompressor,
    /// Chunk size in bytes (constructor default: 1 MiB).
    chunk_size: usize,
    /// Advisory thread count.
    /// NOTE(review): stored but never read in this file; actual parallelism
    /// comes from the global `scirs2_core::parallel_ops` pool — confirm
    /// intended use.
    num_threads: usize,
}
impl ParallelCompressor {
    /// Wraps `compressor` with 1 MiB chunks and the global thread-pool width.
    pub fn new(compressor: AdvancedCompressor) -> Self {
        Self {
            compressor,
            chunk_size: 1024 * 1024,
            num_threads: scirs2_core::parallel_ops::num_threads(),
        }
    }

    /// Sets the chunk size in bytes. A value of 0 disables chunking
    /// (everything is compressed in one shot).
    pub fn with_chunk_size(mut self, chunk_size: usize) -> Self {
        self.chunk_size = chunk_size;
        self
    }

    /// Sets the advisory thread count.
    /// NOTE(review): the field is not read in this file; parallelism comes
    /// from the global pool — confirm intended use.
    pub fn with_num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// Compresses `data` in parallel chunks.
    ///
    /// Output format (all little-endian): a `u64` chunk count, then for each
    /// chunk a `u64` payload length followed by the chunk's compressed bytes.
    ///
    /// Inputs shorter than two chunks are delegated to the plain compressor
    /// and are NOT wrapped in the chunked format — callers must decode with
    /// [`Self::decompress_parallel`], which applies the matching heuristic.
    ///
    /// # Errors
    /// Propagates any per-chunk compression error.
    pub fn compress_parallel(
        &self,
        data: &[u8],
        algorithm: CompressionAlgorithm,
        level: CompressionLevel,
    ) -> Result<CompressionResult> {
        // A zero chunk size would divide by zero below; treat it as
        // "no chunking". saturating_mul guards the pathological huge case.
        if self.chunk_size == 0 || data.len() < self.chunk_size.saturating_mul(2) {
            return self.compressor.compress_data(data, algorithm, level);
        }
        let start_time = std::time::Instant::now();
        // Ceiling division: the last chunk may be shorter than chunk_size.
        let num_chunks = (data.len() + self.chunk_size - 1) / self.chunk_size;
        let chunks: Vec<&[u8]> = (0..num_chunks)
            .map(|i| {
                let start = i * self.chunk_size;
                let end = (start + self.chunk_size).min(data.len());
                &data[start..end]
            })
            .collect();
        use scirs2_core::parallel_ops::{IntoParallelIterator, ParallelIterator};
        let compressed_chunks: Vec<_> = chunks
            .into_par_iter()
            .map(|chunk| {
                self.compressor
                    .compress_data(chunk, algorithm, level)
                    .map(|result| result.data)
            })
            .collect::<Result<Vec<_>>>()?;
        // Serialize: chunk count header, then (length, payload) per chunk.
        let mut combined_data = Vec::new();
        combined_data.extend_from_slice(&(compressed_chunks.len() as u64).to_le_bytes());
        for chunk in &compressed_chunks {
            combined_data.extend_from_slice(&(chunk.len() as u64).to_le_bytes());
            combined_data.extend_from_slice(chunk);
        }
        let compression_time_ms = start_time.elapsed().as_millis() as u64;
        let compressed_size = combined_data.len();
        // ratio = output/input; < 1.0 means the chunked form saved space.
        let ratio = if data.is_empty() {
            1.0
        } else {
            compressed_size as f32 / data.len() as f32
        };
        Ok(CompressionResult {
            data: combined_data,
            algorithm,
            level,
            original_size: data.len(),
            compressed_size,
            ratio,
            compression_time_ms,
        })
    }

    /// Decompresses data produced by [`Self::compress_parallel`].
    ///
    /// All header fields are validated against the buffer length *before* use,
    /// so malformed or hostile input yields `Err(InvalidArgument)` instead of
    /// a slice panic, wrapped `offset + size` arithmetic in release builds, a
    /// u64 -> usize truncation on 32-bit targets, or an attacker-controlled
    /// oversized `Vec::with_capacity`.
    ///
    /// # Errors
    /// `TorshError::InvalidArgument` for a malformed container; any per-chunk
    /// codec error is propagated.
    pub fn decompress_parallel(
        &self,
        compressed_data: &[u8],
        algorithm: CompressionAlgorithm,
    ) -> Result<DecompressionResult> {
        if compressed_data.len() < 8 {
            // Too short for a chunk-count header; assume the plain format.
            return self.compressor.decompress_data(compressed_data, algorithm);
        }
        let start_time = std::time::Instant::now();
        let declared_chunks = u64::from_le_bytes(
            compressed_data[0..8]
                .try_into()
                .expect("slice of 8 bytes should convert to [u8; 8]"),
        );
        // Every chunk needs at least an 8-byte length header, so any count
        // above this bound cannot describe `compressed_data`. Checking in u64
        // space (before the cast) also prevents truncation on 32-bit targets
        // and bounds the Vec allocation below.
        let payload_len = compressed_data.len() - 8;
        if declared_chunks > (payload_len / 8) as u64 {
            return Err(TorshError::InvalidArgument(
                "Invalid parallel-compressed data format".to_string(),
            ));
        }
        let num_chunks = declared_chunks as usize;
        let mut offset = 8usize;
        let mut chunks = Vec::with_capacity(num_chunks);
        for _ in 0..num_chunks {
            // `offset` never exceeds the slice length, so `offset + 8` cannot
            // overflow for a real in-memory slice (len <= isize::MAX).
            if offset + 8 > compressed_data.len() {
                return Err(TorshError::InvalidArgument(
                    "Invalid parallel-compressed data format".to_string(),
                ));
            }
            let declared_len = u64::from_le_bytes(
                compressed_data[offset..offset + 8]
                    .try_into()
                    .expect("slice of 8 bytes should convert to [u8; 8]"),
            );
            offset += 8;
            // Compare in u64 space so a huge declared length cannot wrap the
            // `offset + chunk_size` arithmetic the previous code relied on.
            if declared_len > (compressed_data.len() - offset) as u64 {
                return Err(TorshError::InvalidArgument(
                    "Invalid chunk size in parallel-compressed data".to_string(),
                ));
            }
            let chunk_size = declared_len as usize;
            chunks.push(&compressed_data[offset..offset + chunk_size]);
            offset += chunk_size;
        }
        use scirs2_core::parallel_ops::{IntoParallelIterator, ParallelIterator};
        let decompressed_chunks: Vec<_> = chunks
            .into_par_iter()
            .map(|chunk| {
                self.compressor
                    .decompress_data(chunk, algorithm)
                    .map(|result| result.data)
            })
            .collect::<Result<Vec<_>>>()?;
        let combined_data = decompressed_chunks.into_iter().flatten().collect();
        let decompression_time_ms = start_time.elapsed().as_millis() as u64;
        Ok(DecompressionResult {
            data: combined_data,
            algorithm,
            decompression_time_ms,
        })
    }
}
/// Aggregate statistics over a sequence of recorded [`CompressionResult`]s.
#[derive(Debug, Clone, Default)]
pub struct CompressionStats {
    /// Total ORIGINAL (pre-compression) bytes recorded.
    /// NOTE(review): name suggests post-compression bytes, but `record` adds
    /// `original_size` here — renaming would break the public API.
    pub total_compressed: usize,
    /// Total bytes after compression.
    pub total_after_compression: usize,
    /// Accumulated compression time, in milliseconds.
    pub total_time_ms: u64,
    /// Number of results recorded per algorithm.
    pub algorithm_usage: HashMap<CompressionAlgorithm, u32>,
    /// Running mean compression ratio per algorithm.
    pub algorithm_ratios: HashMap<CompressionAlgorithm, f32>,
}
impl CompressionStats {
    /// Creates an empty statistics accumulator.
    pub fn new() -> Self {
        Self::default()
    }

    /// Folds one compression result into the totals and per-algorithm stats.
    pub fn record(&mut self, result: &CompressionResult) {
        self.total_compressed += result.original_size;
        self.total_after_compression += result.compressed_size;
        self.total_time_ms += result.compression_time_ms;
        let uses = self.algorithm_usage.entry(result.algorithm).or_insert(0);
        *uses += 1;
        // Running mean: new = (old * (n - 1) + sample) / n.
        let samples = *uses as f32;
        let previous = self
            .algorithm_ratios
            .get(&result.algorithm)
            .copied()
            .unwrap_or(0.0);
        let updated = (previous * (samples - 1.0) + result.ratio) / samples;
        self.algorithm_ratios.insert(result.algorithm, updated);
    }

    /// Overall compressed/original byte ratio (1.0 when nothing recorded).
    pub fn overall_ratio(&self) -> f32 {
        if self.total_compressed == 0 {
            return 1.0;
        }
        self.total_after_compression as f32 / self.total_compressed as f32
    }

    /// Total bytes saved by compression (never negative).
    pub fn space_saved(&self) -> usize {
        self.total_compressed
            .saturating_sub(self.total_after_compression)
    }

    /// Percentage of input bytes eliminated (0.0 when nothing recorded).
    pub fn space_saved_percent(&self) -> f32 {
        if self.total_compressed == 0 {
            return 0.0;
        }
        (self.space_saved() as f32 / self.total_compressed as f32) * 100.0
    }

    /// Algorithm recorded most often, if any (ties broken arbitrarily).
    pub fn most_used_algorithm(&self) -> Option<CompressionAlgorithm> {
        self.algorithm_usage
            .iter()
            .max_by_key(|&(_, &uses)| uses)
            .map(|(&algorithm, _)| algorithm)
    }

    /// Algorithm with the lowest (best) mean ratio, if any.
    pub fn best_performing_algorithm(&self) -> Option<CompressionAlgorithm> {
        self.algorithm_ratios
            .iter()
            .min_by(|(_, &lhs), (_, &rhs)| {
                lhs.partial_cmp(&rhs).unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(&algorithm, _)| algorithm)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builder methods must overwrite the corresponding config fields.
    #[test]
    fn test_compression_config() {
        let config = CompressionConfig::new()
            .with_algorithm(CompressionAlgorithm::Zstd)
            .with_level(CompressionLevel(8))
            .with_strategy(CompressionStrategy::Size);
        assert_eq!(config.default_algorithm, CompressionAlgorithm::Zstd);
        assert_eq!(config.default_level.0, 8);
        assert_eq!(config.strategy, CompressionStrategy::Size);
    }

    // Levels are clamped per algorithm (gzip max 9, zstd max 22, brotli 11).
    #[test]
    fn test_compression_level() {
        let level = CompressionLevel(15);
        assert_eq!(level.for_algorithm(CompressionAlgorithm::Gzip), 9);
        assert_eq!(level.for_algorithm(CompressionAlgorithm::Zstd), 15);
        assert_eq!(level.for_algorithm(CompressionAlgorithm::Brotli), 11);
    }

    // Highly repetitive input must shrink under gzip (ratio < 1.0).
    #[test]
    fn test_basic_compression() {
        let compressor = AdvancedCompressor::new();
        let test_data = "Hello, World! ".repeat(100);
        let result = compressor
            .compress_data(
                test_data.as_bytes(),
                CompressionAlgorithm::Gzip,
                CompressionLevel(6),
            )
            .unwrap();
        assert_eq!(result.algorithm, CompressionAlgorithm::Gzip);
        assert_eq!(result.original_size, test_data.len());
        assert!(result.compressed_size < result.original_size);
        assert!(result.ratio < 1.0);
    }

    // Compress-then-decompress must reproduce the input byte-for-byte.
    #[test]
    fn test_decompression() {
        let compressor = AdvancedCompressor::new();
        let test_data = "This is test data for compression and decompression.".repeat(10);
        let compression_result = compressor
            .compress_data(
                test_data.as_bytes(),
                CompressionAlgorithm::Gzip,
                CompressionLevel(6),
            )
            .unwrap();
        let decompression_result = compressor
            .decompress_data(&compression_result.data, CompressionAlgorithm::Gzip)
            .unwrap();
        assert_eq!(decompression_result.data, test_data.as_bytes());
        assert_eq!(decompression_result.algorithm, CompressionAlgorithm::Gzip);
    }

    // A repetitive Text resource goes through the per-type override path and
    // should compress to less than half its size.
    #[test]
    fn test_resource_compression() {
        let compressor = AdvancedCompressor::new();
        let resource = Resource::new(
            "test.txt".to_string(),
            ResourceType::Text,
            "This is a test text file with some content that should compress well."
                .repeat(20)
                .as_bytes()
                .to_vec(),
        );
        let result = compressor.compress_resource(&resource).unwrap();
        assert!(result.ratio < 0.5);
        assert_eq!(result.original_size, resource.data.len());
    }

    // Stats accumulate totals and track per-algorithm usage and best ratio.
    #[test]
    fn test_compression_stats() {
        let mut stats = CompressionStats::new();
        let result1 = CompressionResult {
            data: vec![0; 100],
            algorithm: CompressionAlgorithm::Gzip,
            level: CompressionLevel(6),
            original_size: 200,
            compressed_size: 100,
            ratio: 0.5,
            compression_time_ms: 10,
        };
        let result2 = CompressionResult {
            data: vec![0; 80],
            algorithm: CompressionAlgorithm::Zstd,
            level: CompressionLevel(6),
            original_size: 200,
            compressed_size: 80,
            ratio: 0.4,
            compression_time_ms: 8,
        };
        stats.record(&result1);
        stats.record(&result2);
        assert_eq!(stats.total_compressed, 400);
        assert_eq!(stats.total_after_compression, 180);
        assert_eq!(stats.space_saved(), 220);
        assert!((stats.space_saved_percent() - 55.0).abs() < 0.1);
        assert_eq!(stats.algorithm_usage[&CompressionAlgorithm::Gzip], 1);
        assert_eq!(stats.algorithm_usage[&CompressionAlgorithm::Zstd], 1);
        assert_eq!(
            stats.best_performing_algorithm(),
            Some(CompressionAlgorithm::Zstd)
        );
    }

    // Inputs below min_size_threshold (default 256) bypass compression.
    #[test]
    fn test_small_file_skip() {
        let compressor = AdvancedCompressor::new();
        let small_data = b"tiny";
        let small_resource = Resource::new(
            "small.txt".to_string(),
            ResourceType::Text,
            small_data.to_vec(),
        );
        let result = compressor.compress_resource(&small_resource).unwrap();
        assert_eq!(result.algorithm, CompressionAlgorithm::None);
        assert_eq!(result.data, small_data);
        assert_eq!(result.ratio, 1.0);
    }

    // benchmark_algorithms must return results sorted by ratio, best first.
    #[test]
    fn test_benchmark_algorithms() {
        let compressor = AdvancedCompressor::new();
        let test_data = "This is benchmark data. ".repeat(100);
        let results = compressor
            .benchmark_algorithms(test_data.as_bytes())
            .unwrap();
        assert!(results.len() >= 2);
        for i in 1..results.len() {
            assert!(results[i - 1].ratio <= results[i].ratio);
        }
        for result in &results {
            assert_eq!(result.original_size, test_data.len());
        }
    }
}