//! Compression support for Titor storage.
//!
//! Content is stored either as size-prepended LZ4 behind a `LZ4T` magic prefix
//! or raw behind a four-byte zero marker, with a heuristic fallback for legacy
//! data written as bare size-prepended LZ4.

use crate::error::{Result, TitorError};
use lz4_flex::{compress_prepend_size, decompress_size_prepended};
use std::io::{BufReader, Read, Write};
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, trace};
/// Strategy used to decide whether a file's content should be LZ4-compressed
/// before it is stored.
#[derive(Clone)]
pub enum CompressionStrategy {
    /// Never compress; store every file raw.
    None,
    /// Compress any file of at least 1 KiB (the default).
    Fast,
    /// Compress files of at least `min_size` bytes, skipping extensions in
    /// `skip_extensions` (matched case-insensitively).
    Adaptive {
        min_size: usize,
        skip_extensions: Vec<String>,
    },
    /// Custom predicate: given the file path and size, return `true` to compress.
    Custom(Arc<dyn Fn(&Path, usize) -> bool + Send + Sync>),
}
impl Default for CompressionStrategy {
fn default() -> Self {
CompressionStrategy::Fast
}
}
impl std::fmt::Debug for CompressionStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::None => write!(f, "None"),
Self::Fast => write!(f, "Fast"),
Self::Adaptive { min_size, skip_extensions } => f
.debug_struct("Adaptive")
.field("min_size", min_size)
.field("skip_extensions", skip_extensions)
.finish(),
Self::Custom(_) => write!(f, "Custom(Fn)"),
}
}
}
/// Statistics accumulated while compressing and decompressing content.
#[derive(Debug, Default, Clone)]
pub struct CompressionStats {
    /// Number of files (or chunks) stored LZ4-compressed.
    pub files_compressed: usize,
    /// Number of files (or chunks) stored raw because compression was skipped
    /// or did not reduce the size.
    pub files_stored_raw: usize,
    /// Total bytes saved by compression.
    pub bytes_saved: usize,
    /// Cumulative time spent compressing, in milliseconds.
    pub compression_time_ms: u64,
    /// Cumulative time spent decompressing, in milliseconds.
    pub decompression_time_ms: u64,
}
impl CompressionStats {
    /// Fraction of processed files that ended up compressed (not a byte ratio):
    /// `files_compressed / (files_compressed + files_stored_raw)`.
    pub fn compression_ratio(&self) -> f64 {
        if self.files_compressed == 0 {
            return 0.0;
        }
        let total_files = self.files_compressed + self.files_stored_raw;
        self.files_compressed as f64 / total_files as f64
    }
    /// Average number of bytes saved per compressed file.
    pub fn avg_bytes_saved_per_file(&self) -> usize {
        if self.files_compressed == 0 {
            return 0;
        }
        self.bytes_saved / self.files_compressed
    }
}
/// Compresses and decompresses file content according to a [`CompressionStrategy`],
/// accumulating [`CompressionStats`] as it goes.
#[derive(Debug)]
pub struct CompressionEngine {
    strategy: CompressionStrategy,
    stats: CompressionStats,
}
/// Magic prefix marking content stored as size-prepended LZ4.
const LZ4_MAGIC: &[u8] = b"LZ4T";
impl CompressionEngine {
    /// Create an engine that applies the given strategy.
    pub fn new(strategy: CompressionStrategy) -> Self {
        Self {
            strategy,
            stats: CompressionStats::default(),
        }
    }
    /// Statistics accumulated so far.
    pub fn stats(&self) -> &CompressionStats {
        &self.stats
    }
    /// Reset the accumulated statistics to zero.
    pub fn reset_stats(&mut self) {
        self.stats = CompressionStats::default();
    }
    /// Compress `content` for storage. The result is either `LZ4T` followed by
    /// size-prepended LZ4 data, or a four-byte zero marker followed by the raw
    /// bytes when compression is skipped or not beneficial.
    pub fn compress(&mut self, path: &Path, content: &[u8]) -> Result<Vec<u8>> {
let start = Instant::now();
if !self.should_compress(path, content.len()) {
trace!("Skipping compression for {:?} (strategy)", path);
self.stats.files_stored_raw += 1;
let mut result = Vec::with_capacity(4 + content.len());
            result.extend_from_slice(&[0, 0, 0, 0]);
            result.extend_from_slice(content);
return Ok(result);
}
if content.len() < 64 {
trace!("File too small to benefit from compression: {:?}", path);
self.stats.files_stored_raw += 1;
let mut result = Vec::with_capacity(4 + content.len());
result.extend_from_slice(&[0, 0, 0, 0]);
result.extend_from_slice(content);
return Ok(result);
}
let compressed = compress_prepend_size(content);
if compressed.len() < content.len() {
let saved = content.len() - compressed.len();
self.stats.bytes_saved += saved;
self.stats.files_compressed += 1;
self.stats.compression_time_ms += start.elapsed().as_millis() as u64;
debug!(
"Compressed {:?}: {} -> {} bytes (saved {} bytes, {:.1}%)",
path,
content.len(),
compressed.len(),
saved,
(saved as f64 / content.len() as f64) * 100.0
);
let mut result = Vec::with_capacity(LZ4_MAGIC.len() + compressed.len());
result.extend_from_slice(LZ4_MAGIC);
result.extend_from_slice(&compressed);
Ok(result)
} else {
trace!("Compression not beneficial for {:?}, storing raw", path);
self.stats.files_stored_raw += 1;
let mut result = Vec::with_capacity(4 + content.len());
            result.extend_from_slice(&[0, 0, 0, 0]);
            result.extend_from_slice(content);
Ok(result)
}
}
    /// Decompress content previously produced by [`compress`](Self::compress),
    /// falling back to a heuristic for legacy data written as bare
    /// size-prepended LZ4 without the `LZ4T` prefix.
    pub fn decompress(&mut self, content: &[u8]) -> Result<Vec<u8>> {
let start = Instant::now();
if content.len() < 4 {
return Err(TitorError::decompression("Content too short"));
}
if content.starts_with(LZ4_MAGIC) {
let compressed_data = &content[LZ4_MAGIC.len()..];
match decompress_size_prepended(compressed_data) {
Ok(decompressed) => {
self.stats.decompression_time_ms += start.elapsed().as_millis() as u64;
trace!("Decompressed {} bytes to {} bytes", content.len(), decompressed.len());
Ok(decompressed)
}
Err(e) => Err(TitorError::decompression(format!(
"LZ4 decompression failed: {}",
e
))),
}
} else if content.starts_with(&[0, 0, 0, 0]) {
trace!("Content not compressed, returning as-is");
Ok(content[4..].to_vec())
} else {
if self.is_legacy_lz4_compressed(content) {
match decompress_size_prepended(content) {
Ok(decompressed) => {
self.stats.decompression_time_ms += start.elapsed().as_millis() as u64;
trace!("Decompressed legacy LZ4: {} bytes to {} bytes", content.len(), decompressed.len());
Ok(decompressed)
}
Err(_) => {
trace!("Legacy format not compressed, returning as-is");
Ok(content.to_vec())
}
}
} else {
trace!("Unknown format, returning as-is");
Ok(content.to_vec())
}
}
}
    /// Decide whether a file at `path` with `size` bytes should be compressed
    /// under the configured strategy.
    fn should_compress(&self, path: &Path, size: usize) -> bool {
        match &self.strategy {
            CompressionStrategy::None => false,
            CompressionStrategy::Fast => size >= 1024,
            CompressionStrategy::Adaptive { min_size, skip_extensions } => {
if size < *min_size {
return false;
}
if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
let should_skip = skip_extensions
.iter()
.any(|skip| skip.eq_ignore_ascii_case(ext));
!should_skip
} else {
true
}
}
CompressionStrategy::Custom(func) => func(path, size),
}
}
    /// Heuristic for legacy content stored as bare size-prepended LZ4 (no `LZ4T`
    /// prefix): the leading little-endian u32 is treated as the uncompressed size
    /// and must fall in a plausible range.
    fn is_legacy_lz4_compressed(&self, content: &[u8]) -> bool {
        if content.len() < 8 {
            return false;
        }
        let size = u32::from_le_bytes([content[0], content[1], content[2], content[3]]) as usize;
        size > 0 && size < 100_000_000
    }
#[cfg(test)]
fn is_lz4_compressed(&self, content: &[u8]) -> bool {
content.starts_with(LZ4_MAGIC)
}
}
/// Streams data from a reader and compresses it in fixed-size chunks.
#[derive(Debug)]
pub struct CompressedFileStream<R: Read> {
    reader: BufReader<R>,
    chunk_size: usize,
}
impl<R: Read> CompressedFileStream<R> {
    /// Wrap `reader`, processing the input in chunks of `chunk_size` bytes.
    pub fn new(reader: R, chunk_size: usize) -> Self {
        Self {
            reader: BufReader::with_capacity(chunk_size, reader),
            chunk_size,
        }
    }
    /// Read the input chunk by chunk and pass each encoded chunk (size-prepended
    /// LZ4, or a four-byte zero marker followed by the raw bytes) to `processor`.
    pub fn process_chunks<F>(&mut self, mut processor: F) -> Result<CompressionStats>
    where
        F: FnMut(&[u8]) -> Result<()>,
    {
let mut buffer = vec![0u8; self.chunk_size];
let mut stats = CompressionStats::default();
let mut total_uncompressed = 0u64;
let mut total_compressed = 0u64;
let start = Instant::now();
loop {
match self.reader.read(&mut buffer) {
                Ok(0) => break,
                Ok(bytes_read) => {
let chunk = &buffer[..bytes_read];
total_uncompressed += bytes_read as u64;
let compressed = compress_prepend_size(chunk);
if compressed.len() < chunk.len() {
processor(&compressed)?;
total_compressed += compressed.len() as u64;
stats.files_compressed += 1;
stats.bytes_saved += chunk.len() - compressed.len();
} else {
                        let mut uncompressed = Vec::with_capacity(4 + chunk.len());
                        uncompressed.extend_from_slice(&[0, 0, 0, 0]);
                        uncompressed.extend_from_slice(chunk);
processor(&uncompressed)?;
total_compressed += uncompressed.len() as u64;
stats.files_stored_raw += 1;
}
}
Err(e) => return Err(e.into()),
}
}
stats.compression_time_ms = start.elapsed().as_millis() as u64;
trace!(
"Streamed compression: {} bytes -> {} bytes (saved {} bytes)",
total_uncompressed,
total_compressed,
total_uncompressed.saturating_sub(total_compressed)
);
Ok(stats)
}
    /// Async variant of [`process_chunks`](Self::process_chunks): compression is
    /// offloaded to a blocking task and each encoded chunk is awaited through `processor`.
    #[allow(unused)]
    pub async fn process_chunks_async<F, Fut>(&mut self, mut processor: F) -> Result<CompressionStats>
where
F: FnMut(Vec<u8>) -> Fut,
Fut: std::future::Future<Output = Result<()>>,
{
let mut buffer = vec![0u8; self.chunk_size];
let mut stats = CompressionStats::default();
let mut total_uncompressed = 0u64;
let mut total_compressed = 0u64;
let start = Instant::now();
loop {
let bytes_read = match self.reader.read(&mut buffer) {
Ok(n) => n,
Err(e) => return Err(e.into()),
};
            if bytes_read == 0 {
                break;
            }
let chunk = buffer[..bytes_read].to_vec();
total_uncompressed += bytes_read as u64;
let compressed = tokio::task::spawn_blocking(move || {
compress_prepend_size(&chunk)
}).await.map_err(|e| TitorError::internal(format!("Compression task failed: {}", e)))?;
            if compressed.len() < bytes_read {
                let compressed_len = compressed.len();
                processor(compressed).await?;
                total_compressed += compressed_len as u64;
                stats.files_compressed += 1;
                stats.bytes_saved += bytes_read - compressed_len;
            } else {
                let mut uncompressed = Vec::with_capacity(4 + bytes_read);
                uncompressed.extend_from_slice(&[0, 0, 0, 0]);
                uncompressed.extend_from_slice(&buffer[..bytes_read]);
                let uncompressed_len = uncompressed.len();
                processor(uncompressed).await?;
                total_compressed += uncompressed_len as u64;
                stats.files_stored_raw += 1;
            }
}
stats.compression_time_ms = start.elapsed().as_millis() as u64;
Ok(stats)
}
}
/// A [`Write`] adapter that buffers data into fixed-size chunks and writes each
/// chunk either as size-prepended LZ4 or, when compression does not help,
/// prefixed with a four-byte zero marker.
#[derive(Debug)]
pub struct CompressedWriter<W: Write> {
    writer: Option<W>,
    buffer: Vec<u8>,
    chunk_size: usize,
    stats: CompressionStats,
    compression_start: Option<Instant>,
}
impl<W: Write> CompressedWriter<W> {
    /// Wrap `writer`, buffering up to `chunk_size` bytes before each chunk is encoded.
    pub fn new(writer: W, chunk_size: usize) -> Self {
        Self {
            writer: Some(writer),
            buffer: Vec::with_capacity(chunk_size),
            chunk_size,
            stats: CompressionStats::default(),
            compression_start: Some(Instant::now()),
        }
    }
    /// Statistics accumulated so far.
    pub fn stats(&self) -> &CompressionStats {
        &self.stats
    }
    /// Encode and write out whatever is currently buffered, if anything.
    pub fn flush_buffer(&mut self) -> Result<()> {
if !self.buffer.is_empty() {
let uncompressed_size = self.buffer.len();
let compressed = compress_prepend_size(&self.buffer);
if let Some(writer) = self.writer.as_mut() {
if compressed.len() < uncompressed_size {
writer.write_all(&compressed)?;
self.stats.files_compressed += 1;
self.stats.bytes_saved += uncompressed_size - compressed.len();
trace!(
"Compressed chunk: {} -> {} bytes (saved {} bytes)",
uncompressed_size,
compressed.len(),
uncompressed_size - compressed.len()
);
} else {
                    writer.write_all(&[0, 0, 0, 0])?;
                    writer.write_all(&self.buffer)?;
self.stats.files_stored_raw += 1;
trace!("Stored chunk uncompressed: {} bytes", uncompressed_size);
}
}
self.buffer.clear();
}
Ok(())
}
    /// Flush any remaining buffered data and return the inner writer together with
    /// the final statistics.
    pub fn finish(mut self) -> Result<(W, CompressionStats)> {
self.flush_buffer()?;
let mut writer = self.writer.take()
.ok_or_else(|| TitorError::internal("Writer already consumed"))?;
writer.flush()?;
if let Some(start) = self.compression_start.take() {
self.stats.compression_time_ms = start.elapsed().as_millis() as u64;
}
Ok((writer, self.stats.clone()))
}
}
impl<W: Write> Write for CompressedWriter<W> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
if self.writer.is_none() {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Writer already consumed"
));
}
let written = buf.len();
let mut offset = 0;
while offset < buf.len() {
let remaining_in_buffer = self.chunk_size - self.buffer.len();
let to_copy = remaining_in_buffer.min(buf.len() - offset);
self.buffer.extend_from_slice(&buf[offset..offset + to_copy]);
offset += to_copy;
if self.buffer.len() >= self.chunk_size {
self.flush_buffer()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
}
}
Ok(written)
}
fn flush(&mut self) -> std::io::Result<()> {
self.flush_buffer()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
if let Some(writer) = self.writer.as_mut() {
writer.flush()
} else {
Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Writer already consumed"
))
}
}
}
impl<W: Write> Drop for CompressedWriter<W> {
fn drop(&mut self) {
let _ = self.flush_buffer();
}
}
/// Default list of file extensions skipped by the adaptive strategy: mostly media,
/// archive, and other formats that are already compressed or rarely shrink further.
pub fn default_skip_extensions() -> Vec<String> {
    vec![
        // Images
        "jpg", "jpeg", "png", "gif", "webp", "ico", "bmp", "svg",
        // Video
        "mp4", "avi", "mkv", "mov", "wmv", "flv", "webm", "m4v", "mpg", "mpeg",
        // Audio
        "mp3", "wav", "flac", "aac", "ogg", "wma", "m4a", "opus",
        // Archives and compressed containers
        "zip", "rar", "7z", "tar", "gz", "bz2", "xz", "zst",
        "lz4", "lzo", "lzma", "br",
        // Document formats
        "pdf", "epub", "mobi",
    ]
.into_iter()
.map(String::from)
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
#[test]
fn test_compression_fast_strategy() {
let mut engine = CompressionEngine::new(CompressionStrategy::Fast);
let path = PathBuf::from("test.txt");
let small_content = b"hello";
let result = engine.compress(&path, small_content).unwrap();
assert_eq!(&result[0..4], &[0, 0, 0, 0]);
assert_eq!(&result[4..], small_content);
assert_eq!(engine.stats().files_stored_raw, 1);
let decompressed = engine.decompress(&result).unwrap();
assert_eq!(decompressed, small_content);
let large_content = "a".repeat(10000).into_bytes();
let compressed = engine.compress(&path, &large_content).unwrap();
assert!(compressed.len() < large_content.len() + LZ4_MAGIC.len());
assert_eq!(engine.stats().files_compressed, 1);
let decompressed = engine.decompress(&compressed).unwrap();
assert_eq!(decompressed, large_content);
}
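    // Illustrative check (added example): the `None` strategy stores content raw
    // regardless of size, and the raw marker round-trips through decompress().
    #[test]
    fn test_compression_none_strategy() {
        let mut engine = CompressionEngine::new(CompressionStrategy::None);
        let path = PathBuf::from("large.txt");
        let content = "a".repeat(10_000).into_bytes();
        let result = engine.compress(&path, &content).unwrap();
        // Raw storage uses a four-byte zero marker followed by the original bytes.
        assert_eq!(&result[0..4], &[0, 0, 0, 0]);
        assert_eq!(&result[4..], &content[..]);
        assert_eq!(engine.stats().files_stored_raw, 1);
        assert_eq!(engine.stats().files_compressed, 0);
        let decompressed = engine.decompress(&result).unwrap();
        assert_eq!(decompressed, content);
    }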
#[test]
fn test_compression_adaptive_strategy() {
let mut engine = CompressionEngine::new(CompressionStrategy::Adaptive {
min_size: 100,
skip_extensions: vec!["jpg".to_string()],
});
let small_path = PathBuf::from("small.txt");
let small_content = b"tiny";
let result = engine.compress(&small_path, small_content).unwrap();
assert_eq!(&result[0..4], &[0, 0, 0, 0]);
assert_eq!(&result[4..], small_content);
let jpg_path = PathBuf::from("image.jpg");
        let jpg_content = vec![0xFF; 1000];
        let result = engine.compress(&jpg_path, &jpg_content).unwrap();
assert_eq!(&result[0..4], &[0, 0, 0, 0]);
assert_eq!(&result[4..], &jpg_content[..]);
let txt_path = PathBuf::from("large.txt");
let txt_content = "x".repeat(1000).into_bytes();
let compressed = engine.compress(&txt_path, &txt_content).unwrap();
assert!(compressed.starts_with(LZ4_MAGIC));
assert!(compressed.len() < txt_content.len() + LZ4_MAGIC.len());
}
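    // Illustrative check (added example): a custom predicate receives the path and
    // size and fully controls the decision; here it only compresses ".log" files.
    #[test]
    fn test_compression_custom_strategy() {
        let strategy = CompressionStrategy::Custom(Arc::new(|path: &Path, size: usize| {
            size >= 10 && path.extension().and_then(|e| e.to_str()) == Some("log")
        }));
        let mut engine = CompressionEngine::new(strategy);
        let log_content = "repeat me ".repeat(100).into_bytes();
        let compressed = engine
            .compress(&PathBuf::from("server.log"), &log_content)
            .unwrap();
        assert!(compressed.starts_with(LZ4_MAGIC));
        let raw = engine
            .compress(&PathBuf::from("server.txt"), &log_content)
            .unwrap();
        assert!(raw.starts_with(&[0, 0, 0, 0]));
        assert_eq!(engine.decompress(&compressed).unwrap(), log_content);
        assert_eq!(engine.decompress(&raw).unwrap(), log_content);
    }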
#[test]
fn test_is_lz4_compressed() {
let engine = CompressionEngine::new(CompressionStrategy::Fast);
let mut compressed = Vec::new();
compressed.extend_from_slice(LZ4_MAGIC);
compressed.extend_from_slice(&compress_prepend_size(b"test data"));
assert!(engine.is_lz4_compressed(&compressed));
assert!(!engine.is_lz4_compressed(&[0, 0, 0, 0, 1, 2, 3]));
assert!(!engine.is_lz4_compressed(b"raw data"));
assert!(!engine.is_lz4_compressed(b"abc"));
}
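    // Illustrative check (added example): adaptive extension skipping is
    // case-insensitive and works with the bundled default_skip_extensions() list.
    #[test]
    fn test_default_skip_extensions_case_insensitive() {
        let mut engine = CompressionEngine::new(CompressionStrategy::Adaptive {
            min_size: 1,
            skip_extensions: default_skip_extensions(),
        });
        let content = "compressible ".repeat(100).into_bytes();
        // "PNG" differs in case from the listed "png" but is still skipped.
        let result = engine.compress(&PathBuf::from("photo.PNG"), &content).unwrap();
        assert!(result.starts_with(&[0, 0, 0, 0]));
        // A plain text file of the same size is compressed.
        let result = engine.compress(&PathBuf::from("notes.txt"), &content).unwrap();
        assert!(result.starts_with(LZ4_MAGIC));
    }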
#[test]
fn test_compression_with_edge_case_content() {
let mut engine = CompressionEngine::new(CompressionStrategy::Fast);
let path = PathBuf::from("test.bin");
let content = vec![57, 228, 255, 0, 1, 2, 3, 4, 5];
let compressed = engine.compress(&path, &content).unwrap();
assert!(compressed.starts_with(&[0, 0, 0, 0]) || compressed.starts_with(LZ4_MAGIC));
let decompressed = engine.decompress(&compressed).unwrap();
assert_eq!(decompressed, content);
}
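    // Illustrative check (added example): decompress() rejects payloads shorter
    // than the four-byte header.
    #[test]
    fn test_decompress_too_short() {
        let mut engine = CompressionEngine::new(CompressionStrategy::Fast);
        assert!(engine.decompress(&[0, 0, 0]).is_err());
    }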
#[test]
fn test_backward_compatibility() {
let mut engine = CompressionEngine::new(CompressionStrategy::Fast);
let content = b"Hello, World!";
let legacy_compressed = compress_prepend_size(content);
let decompressed = engine.decompress(&legacy_compressed).unwrap();
assert_eq!(decompressed, content);
let raw_data = b"raw content without any marker";
let decompressed = engine.decompress(raw_data).unwrap();
assert_eq!(decompressed, raw_data);
}
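    // Illustrative check (added example): CompressedWriter encodes buffered data on
    // finish(); a single compressed chunk is bare size-prepended LZ4, which the
    // engine accepts via its legacy decompression path.
    #[test]
    fn test_compressed_writer_single_chunk() {
        let mut writer = CompressedWriter::new(Vec::<u8>::new(), 4096);
        let data = "z".repeat(512).into_bytes();
        writer.write_all(&data).unwrap();
        let (output, stats) = writer.finish().unwrap();
        assert_eq!(stats.files_compressed, 1);
        assert!(output.len() < data.len());
        let mut engine = CompressionEngine::new(CompressionStrategy::Fast);
        assert_eq!(engine.decompress(&output).unwrap(), data);
    }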
#[test]
fn test_compression_stats() {
        let stats = CompressionStats {
            files_compressed: 8,
            files_stored_raw: 2,
            bytes_saved: 1000,
            ..Default::default()
        };
assert_eq!(stats.compression_ratio(), 0.8);
assert_eq!(stats.avg_bytes_saved_per_file(), 125);
}
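    // Illustrative check (added example): CompressedFileStream splits a reader into
    // fixed-size chunks and reports per-chunk statistics.
    #[test]
    fn test_compressed_file_stream_chunks() {
        let data = vec![b'x'; 4096];
        let mut stream = CompressedFileStream::new(std::io::Cursor::new(data), 1024);
        let mut chunks = Vec::new();
        let stats = stream
            .process_chunks(|encoded| {
                chunks.push(encoded.to_vec());
                Ok(())
            })
            .unwrap();
        // 4096 bytes of highly repetitive data yield four compressed chunks.
        assert_eq!(stats.files_compressed, 4);
        assert_eq!(stats.files_stored_raw, 0);
        assert!(chunks.iter().all(|c| c.len() < 1024));
    }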
}