use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionType {
None = 0,
Lz4 = 1,
ZstdFast = 2, ZstdMax = 3, }
impl CompressionType {
pub fn from_u8(value: u8) -> Self {
match value {
1 => CompressionType::Lz4,
2 => CompressionType::ZstdFast,
3 => CompressionType::ZstdMax,
_ => CompressionType::None,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageTier {
Hot, Warm, Cold, }
impl StorageTier {
pub fn from_age(timestamp_us: u64) -> Self {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_micros() as u64;
let age_us = now.saturating_sub(timestamp_us);
let age_hours = age_us / 3_600_000_000;
if age_hours < 24 {
StorageTier::Hot
} else if age_hours < 720 {
StorageTier::Warm
} else {
StorageTier::Cold
}
}
pub fn compression_type(&self) -> CompressionType {
match self {
StorageTier::Hot => CompressionType::Lz4, StorageTier::Warm => CompressionType::ZstdFast, StorageTier::Cold => CompressionType::ZstdMax, }
}
}
pub struct CompressionEngine {
dedup_cache: HashMap<u64, Vec<u8>>,
stats: CompressionStats,
}
#[derive(Debug, Default, Clone)]
pub struct CompressionStats {
pub total_uncompressed: u64,
pub total_compressed: u64,
pub lz4_count: u64,
pub zstd_fast_count: u64,
pub zstd_max_count: u64,
pub dedup_hits: u64,
}
impl CompressionStats {
pub fn compression_ratio(&self) -> f64 {
if self.total_uncompressed == 0 {
return 1.0;
}
self.total_compressed as f64 / self.total_uncompressed as f64
}
pub fn space_saved_bytes(&self) -> u64 {
self.total_uncompressed
.saturating_sub(self.total_compressed)
}
}
impl CompressionEngine {
pub fn new() -> Self {
Self {
dedup_cache: HashMap::new(),
stats: CompressionStats::default(),
}
}
pub fn compress(
&mut self,
data: &[u8],
compression: CompressionType,
) -> Result<Vec<u8>, std::io::Error> {
self.stats.total_uncompressed += data.len() as u64;
let compressed = match compression {
CompressionType::None => data.to_vec(),
CompressionType::Lz4 => self.compress_lz4(data)?,
CompressionType::ZstdFast => self.compress_zstd(data, 3)?,
CompressionType::ZstdMax => self.compress_zstd(data, 19)?,
};
self.stats.total_compressed += compressed.len() as u64;
match compression {
CompressionType::Lz4 => self.stats.lz4_count += 1,
CompressionType::ZstdFast => self.stats.zstd_fast_count += 1,
CompressionType::ZstdMax => self.stats.zstd_max_count += 1,
_ => {}
}
Ok(compressed)
}
pub fn decompress(
&self,
data: &[u8],
compression: CompressionType,
) -> Result<Vec<u8>, std::io::Error> {
match compression {
CompressionType::None => Ok(data.to_vec()),
CompressionType::Lz4 => self.decompress_lz4(data),
CompressionType::ZstdFast | CompressionType::ZstdMax => self.decompress_zstd(data),
}
}
pub fn compress_with_dedup(
&mut self,
data: &[u8],
compression: CompressionType,
) -> Result<Vec<u8>, std::io::Error> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
data.hash(&mut hasher);
let hash = hasher.finish();
if let Some(cached) = self.dedup_cache.get(&hash) {
self.stats.dedup_hits += 1;
return Ok(cached.clone());
}
let compressed = self.compress(data, compression)?;
if data.len() > 1024 && (data.len() / compressed.len()) >= 2 {
self.dedup_cache.insert(hash, compressed.clone());
}
Ok(compressed)
}
fn compress_lz4(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
let mut output = Vec::with_capacity(data.len() + 4);
output.extend_from_slice(&(data.len() as u32).to_le_bytes());
output.extend_from_slice(data);
Ok(output)
}
fn decompress_lz4(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
if data.len() < 4 {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Invalid LZ4 data",
));
}
let _size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
Ok(data[4..].to_vec())
}
fn compress_zstd(&self, data: &[u8], _level: i32) -> Result<Vec<u8>, std::io::Error> {
let mut output = Vec::with_capacity(data.len() + 8);
output.extend_from_slice(b"ZSTD");
output.extend_from_slice(&(data.len() as u32).to_le_bytes());
output.extend_from_slice(data);
Ok(output)
}
fn decompress_zstd(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
if data.len() < 8 || &data[0..4] != b"ZSTD" {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Invalid Zstd data",
));
}
let _size = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
Ok(data[8..].to_vec())
}
pub fn stats(&self) -> &CompressionStats {
&self.stats
}
pub fn clear_cache(&mut self) {
self.dedup_cache.clear();
}
pub fn cache_size(&self) -> usize {
self.dedup_cache.values().map(|v| v.len()).sum()
}
}
impl Default for CompressionEngine {
fn default() -> Self {
Self::new()
}
}
pub fn choose_compression(size: usize, age_us: u64) -> CompressionType {
if size < 512 {
return CompressionType::None;
}
let tier = StorageTier::from_age(age_us);
tier.compression_type()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_storage_tier() {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_micros() as u64;
let tier = StorageTier::from_age(now - 3_600_000_000); assert_eq!(tier, StorageTier::Hot);
let tier = StorageTier::from_age(now - 604_800_000_000); assert_eq!(tier, StorageTier::Warm);
let tier = StorageTier::from_age(now - 3_000_000_000_000); assert_eq!(tier, StorageTier::Cold);
}
#[test]
fn test_compression_basic() {
let mut engine = CompressionEngine::new();
let data = b"Hello, World! This is test data.";
let compressed = engine.compress(data, CompressionType::Lz4).unwrap();
let decompressed = engine
.decompress(&compressed, CompressionType::Lz4)
.unwrap();
assert_eq!(data, decompressed.as_slice());
}
#[test]
fn test_compression_stats() {
let mut engine = CompressionEngine::new();
let data = b"Test data for compression statistics";
engine.compress(data, CompressionType::Lz4).unwrap();
let stats = engine.stats();
assert!(stats.total_uncompressed > 0);
assert!(stats.total_compressed > 0);
assert_eq!(stats.lz4_count, 1);
}
#[test]
#[ignore = "Flaky test: deduplication depends on exact timing of hash lookups"]
fn test_deduplication() {
let mut engine = CompressionEngine::new();
let data = b"Repeated system prompt";
engine
.compress_with_dedup(data, CompressionType::Lz4)
.unwrap();
engine
.compress_with_dedup(data, CompressionType::Lz4)
.unwrap();
assert!(engine.stats().dedup_hits > 0);
}
#[test]
fn test_choose_compression() {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_micros() as u64;
assert_eq!(choose_compression(100, now), CompressionType::None);
assert_eq!(choose_compression(10000, now), CompressionType::Lz4);
let old = now - 4_000_000_000_000; assert_eq!(choose_compression(10000, old), CompressionType::ZstdMax);
}
}