use std::fs::File;
use std::io::{self, Read, Write};
use std::path::Path;
use std::str::FromStr;
/// Compression codec selectable by the functions in this module.
///
/// Parsed case-insensitively from `"none"`, `"lz4"` or `"zstd"` via `FromStr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Compression {
/// Bytes are passed through unchanged.
None,
/// LZ4, implemented with the `lz4_flex` crate.
Lz4,
/// Zstandard, implemented with the `zstd` crate.
Zstd,
}
impl FromStr for Compression {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"none" => Ok(Self::None),
"lz4" => Ok(Self::Lz4),
"zstd" => Ok(Self::Zstd),
_ => Err(format!("Unknown compression type: {}", s)),
}
}
}
impl Compression {
#[allow(dead_code)] pub fn as_str(&self) -> &'static str {
match self {
Self::None => "none",
Self::Lz4 => "lz4",
Self::Zstd => "zstd",
}
}
}
/// Compresses `data` in memory with the requested codec.
///
/// `Compression::None` returns an owned copy of the input. `Lz4` delegates to
/// `compress_lz4`, which produces the size-prepended LZ4 block format. `Zstd`
/// delegates to `compress_zstd`.
///
/// # Errors
/// Propagates any error returned by the codec helper.
pub fn compress(data: &[u8], compression: Compression) -> io::Result<Vec<u8>> {
    let compressed = match compression {
        Compression::Lz4 => compress_lz4(data)?,
        Compression::Zstd => compress_zstd(data)?,
        Compression::None => data.to_vec(),
    };
    Ok(compressed)
}
/// Compresses everything readable from `reader` into `writer`, streaming with
/// `io::copy` rather than reading the whole input into memory first.
///
/// - `Compression::None` copies the bytes through unchanged.
/// - `Compression::Lz4` writes an LZ4 frame stream via `lz4_flex::frame::FrameEncoder`.
/// - `Compression::Zstd` writes a Zstandard stream at level 3.
///
/// NOTE(review): the LZ4 frame format written here is not the size-prepended
/// LZ4 block format that [`compress`] produces. That block format is what
/// [`decompress`] with `Compression::Lz4` expects, so `decompress` cannot read
/// this function's LZ4 output. Zstd output is an ordinary Zstandard stream on
/// both paths.
///
/// # Errors
/// Returns any error from reading `reader`, writing `writer`, or finalizing
/// the encoder.
// The allow silences the unused-item lint; no caller is visible in this file.
#[allow(dead_code)]
pub fn compress_streaming<R: Read, W: Write>(
    reader: &mut R,
    writer: &mut W,
    compression: Compression,
) -> io::Result<()> {
    match compression {
        Compression::None => {
            io::copy(reader, writer)?;
        }
        Compression::Lz4 => {
            let mut lz4_out = lz4_flex::frame::FrameEncoder::new(writer);
            io::copy(reader, &mut lz4_out)?;
            // Finalize the encoder so the trailing frame data is written.
            lz4_out.finish()?;
        }
        Compression::Zstd => {
            let mut zstd_out = zstd::Encoder::new(writer, 3)?;
            io::copy(reader, &mut zstd_out)?;
            zstd_out.finish()?;
        }
    }
    Ok(())
}
/// Reverses [`compress`]: decodes `data` that was compressed with `compression`.
///
/// `Compression::None` returns an owned copy of `data`. For `Lz4` the input must
/// be in the size-prepended block format (see `decompress_lz4`). For `Zstd` it
/// must be a Zstandard stream (see `decompress_zstd`).
///
/// # Errors
/// Propagates any error returned by the codec helper, e.g. for malformed input.
// The allow silences the unused-item lint; only the test module calls this here.
#[allow(dead_code)]
pub fn decompress(data: &[u8], compression: Compression) -> io::Result<Vec<u8>> {
    let decoded = match compression {
        Compression::Zstd => decompress_zstd(data)?,
        Compression::Lz4 => decompress_lz4(data)?,
        Compression::None => data.to_vec(),
    };
    Ok(decoded)
}
/// LZ4-compresses `data` in block format, with the uncompressed size
/// prepended by `lz4_flex::compress_prepend_size`.
///
/// This cannot fail. The `io::Result` return type keeps the signature
/// uniform with `compress_zstd`.
fn compress_lz4(data: &[u8]) -> io::Result<Vec<u8>> {
    let block = lz4_flex::compress_prepend_size(data);
    Ok(block)
}
/// Decodes LZ4 block data whose uncompressed size is prepended, as produced
/// by `compress_lz4`.
///
/// NOTE(review): the output size is read from the prefix of `data` itself. If
/// `data` can come from an untrusted peer, consider bounding that size before
/// decoding. Confirm how `lz4_flex` sizes its output allocation.
///
/// # Errors
/// Malformed input is reported as `io::ErrorKind::InvalidData`, wrapping the
/// `lz4_flex` error.
// The allow silences the unused-item lint; reachable only via `decompress`.
#[allow(dead_code)]
fn decompress_lz4(data: &[u8]) -> io::Result<Vec<u8>> {
    match lz4_flex::decompress_size_prepended(data) {
        Ok(plain) => Ok(plain),
        Err(err) => Err(io::Error::new(io::ErrorKind::InvalidData, err)),
    }
}
/// Compresses `data` into a single in-memory Zstandard stream at level 3.
///
/// # Errors
/// Propagates I/O errors reported by the zstd encoder.
fn compress_zstd(data: &[u8]) -> io::Result<Vec<u8>> {
    // Compression level handed to `zstd::Encoder::new`.
    const ZSTD_LEVEL: i32 = 3;
    let mut zstd_sink = zstd::Encoder::new(Vec::new(), ZSTD_LEVEL)?;
    zstd_sink.write_all(data)?;
    // `finish` finalizes the stream and returns the filled Vec.
    zstd_sink.finish()
}
/// Decodes a Zstandard stream held entirely in memory.
///
/// NOTE(review): the decoded output is unbounded. If `data` may be untrusted,
/// a size cap may be warranted.
///
/// # Errors
/// Propagates decoder construction and decoding errors.
// The allow silences the unused-item lint; reachable only via `decompress`.
#[allow(dead_code)]
fn decompress_zstd(data: &[u8]) -> io::Result<Vec<u8>> {
    let mut plain = Vec::new();
    zstd::Decoder::new(data)?.read_to_end(&mut plain)?;
    Ok(plain)
}
/// File extensions (lowercase, without the leading dot) whose contents are
/// usually already compressed, so compressing them again rarely pays off.
///
/// Compared ASCII-case-insensitively by `is_compressed_extension`.
const COMPRESSED_EXTENSIONS: &[&str] = &[
    // Images
    "jpg", "jpeg", "png", "gif", "webp", "avif", "heic", "heif",
    // Video
    "mp4", "mkv", "avi", "mov", "webm", "m4v", "flv", "wmv",
    // Audio
    "mp3", "m4a", "aac", "ogg", "opus", "flac", "wma",
    // Archives
    "zip", "gz", "bz2", "xz", "7z", "rar", "tar.gz", "tgz", "tar.bz2",
    // Documents
    "pdf", "docx", "xlsx", "pptx",
    // Other binary / compressed formats
    "wasm", "br", "zst",
];
/// Returns `true` if `filename` ends with one of the already-compressed
/// extensions in `COMPRESSED_EXTENSIONS`, compared ASCII-case-insensitively.
///
/// The extension must be preceded by a `.`:
/// - `"photo.JPG"` and `"backup.tar.gz"` match.
/// - A bare name such as `"zip"` does not match, nor does `"file."` or any
///   other name without a listed extension.
///
/// Multi-part entries such as `"tar.gz"` are matched as whole suffixes.
/// `filename` may also be a path; only its trailing suffix is inspected.
pub fn is_compressed_extension(filename: &str) -> bool {
    let name = filename.as_bytes();
    COMPRESSED_EXTENSIONS.iter().any(|ext| {
        // Strictly longer than the extension, so there is room for the
        // separating '.'; a name equal to "zip" has no extension at all.
        name.len() > ext.len() && {
            let (stem, suffix) = name.split_at(name.len() - ext.len());
            stem.last() == Some(&b'.') && suffix.eq_ignore_ascii_case(ext.as_bytes())
        }
    })
}
/// Picks a codec for a transfer from the file, its locality and the measured
/// link speed.
///
/// Returns [`Compression::None`] when any of the following holds:
/// - the transfer is local;
/// - `network_speed_mbps` is known and above 500;
/// - `file_size` is under 1 MiB or over 256 MiB;
/// - `filename` has an extension that `is_compressed_extension` reports as
///   already compressed.
///
/// Otherwise returns [`Compression::Zstd`]. An unknown link speed (`None`)
/// does not by itself prevent compression.
// The allow silences the unused-item lint; callers visible here are `should_compress` and tests.
#[allow(dead_code)]
pub fn should_compress_adaptive(
    filename: &str,
    file_size: u64,
    is_local: bool,
    network_speed_mbps: Option<u64>,
) -> Compression {
    const HIGH_SPEED_THRESHOLD_MBPS: u64 = 500;
    const MIN_COMPRESSIBLE_SIZE: u64 = 1024 * 1024;
    const MAX_COMPRESSIBLE_SIZE: u64 = 256 * 1024 * 1024;

    // Presumably, above this speed compression costs more CPU time than the
    // bandwidth it saves.
    let link_is_fast =
        matches!(network_speed_mbps, Some(mbps) if mbps > HIGH_SPEED_THRESHOLD_MBPS);
    let size_in_range = (MIN_COMPRESSIBLE_SIZE..=MAX_COMPRESSIBLE_SIZE).contains(&file_size);

    let skip = is_local || link_is_fast || !size_in_range || is_compressed_extension(filename);
    if skip {
        Compression::None
    } else {
        Compression::Zstd
    }
}
/// Codec choice for a remote transfer whose link speed is unknown.
///
/// Equivalent to `should_compress_adaptive(filename, file_size, false, None)`.
// The allow silences the unused-item lint; only the test module calls this here.
#[allow(dead_code)]
pub fn should_compress(filename: &str, file_size: u64) -> Compression {
    let is_local = false;
    let measured_speed_mbps: Option<u64> = None;
    should_compress_adaptive(filename, file_size, is_local, measured_speed_mbps)
}
/// Estimates how well the file at `file_path` compresses by LZ4-compressing a
/// sample taken from its beginning.
///
/// Reads up to the first 64 KiB and returns `compressed_len / sample_len`.
/// Values well below 1.0 mean the sample shrank. Values near or above 1.0
/// usually indicate already-compressed or random data. An empty file yields
/// exactly `1.0`, i.e. it is treated as incompressible.
///
/// Only the leading sample is examined, so a file whose content changes
/// character later on may be misjudged.
///
/// # Errors
/// Returns any error from opening or reading the file.
pub fn detect_compressibility(file_path: &Path) -> io::Result<f64> {
    const SAMPLE_SIZE: usize = 64 * 1024;
    let mut sample = Vec::with_capacity(SAMPLE_SIZE);
    // A single `read` may return fewer bytes than requested (or fail with
    // `Interrupted`). `take` + `read_to_end` keeps reading until the sample is
    // full or EOF is reached.
    File::open(file_path)?
        .take(SAMPLE_SIZE as u64)
        .read_to_end(&mut sample)?;
    if sample.is_empty() {
        return Ok(1.0);
    }
    let compressed = compress_lz4(&sample)?;
    Ok(compressed.len() as f64 / sample.len() as f64)
}
// How `should_compress_smart` decides whether to compress a non-local transfer.
// NOTE: plain `//` comments are used on purpose. clap's derive can pick up `///`
// doc comments on `ValueEnum` variants as CLI help text, which would change
// user-visible output.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, clap::ValueEnum)]
pub enum CompressionDetection {
// Default mode: filter by size and extension, then sample the file contents.
#[default]
Auto,
// Filter by size and extension only; no content sampling.
Extension,
// Always use Zstd for non-local transfers, skipping size and extension checks.
Always,
// Never compress.
Never,
}
pub fn should_compress_smart(file_path: Option<&Path>, filename: &str, file_size: u64, is_local: bool, detection_mode: CompressionDetection) -> Compression {
if is_local {
return Compression::None;
}
match detection_mode {
CompressionDetection::Always => return Compression::Zstd,
CompressionDetection::Never => return Compression::None,
_ => {} }
if file_size < 1024 * 1024 {
return Compression::None;
}
const MAX_COMPRESSIBLE_SIZE: u64 = 256 * 1024 * 1024;
if file_size > MAX_COMPRESSIBLE_SIZE {
return Compression::None;
}
if is_compressed_extension(filename) {
return Compression::None;
}
if detection_mode == CompressionDetection::Extension {
return Compression::Zstd;
}
if let Some(path) = file_path {
match detect_compressibility(path) {
Ok(ratio) if ratio < 0.9 => {
Compression::Zstd
}
Ok(_ratio) => {
Compression::None
}
Err(_) => {
Compression::Zstd
}
}
} else {
Compression::Zstd
}
}
// Unit tests for codec round-trips, extension detection and the
// compression-decision heuristics. Filesystem tests use the `tempfile` crate.
#[cfg(test)]
mod tests {
use super::*;
// Repetitive text must round-trip through LZ4 and actually shrink.
#[test]
fn test_compress_decompress_lz4() {
let original = b"Hello, world! This is a test of LZ4 compression. ".repeat(100);
let compressed = compress(&original, Compression::Lz4).unwrap();
let decompressed = decompress(&compressed, Compression::Lz4).unwrap();
assert_eq!(original.as_slice(), decompressed.as_slice());
assert!(compressed.len() < original.len());
}
// Same round-trip check for Zstd.
#[test]
fn test_compress_decompress_zstd() {
let original = b"Hello, world! This is a test of Zstd compression. ".repeat(100);
let compressed = compress(&original, Compression::Zstd).unwrap();
let decompressed = decompress(&compressed, Compression::Zstd).unwrap();
assert_eq!(original.as_slice(), decompressed.as_slice());
assert!(compressed.len() < original.len());
}
// `Compression::None` must be an exact pass-through in both directions.
#[test]
fn test_compress_decompress_none() {
let original = b"No compression test";
let compressed = compress(original, Compression::None).unwrap();
let decompressed = decompress(&compressed, Compression::None).unwrap();
assert_eq!(original.as_slice(), decompressed.as_slice());
assert_eq!(compressed.len(), original.len());
}
// Highly repetitive input should compress to under 10% of its size.
#[test]
fn test_zstd_compression_ratio() {
let repetitive = b"AAAA".repeat(1000);
let compressed = compress(&repetitive, Compression::Zstd).unwrap();
let ratio = compressed.len() as f64 / repetitive.len() as f64;
assert!(ratio < 0.1); }
// Extension matching is case-insensitive; non-media text/code files are not flagged.
#[test]
fn test_is_compressed_extension() {
assert!(is_compressed_extension("file.jpg"));
assert!(is_compressed_extension("video.mp4"));
assert!(is_compressed_extension("archive.zip"));
assert!(is_compressed_extension("document.pdf"));
assert!(is_compressed_extension("file.JPG"));
assert!(is_compressed_extension("video.MP4"));
assert!(is_compressed_extension("archive.ZIP"));
assert!(is_compressed_extension("file.JpG"));
assert!(is_compressed_extension("video.Mp4"));
assert!(!is_compressed_extension("file.txt"));
assert!(!is_compressed_extension("code.rs"));
assert!(!is_compressed_extension("data.csv"));
}
// Files under 1 MiB are never compressed.
#[test]
fn test_should_compress_small_file() {
assert_eq!(should_compress("test.txt", 1024), Compression::None);
}
// Known already-compressed extensions are skipped regardless of size.
#[test]
fn test_should_compress_already_compressed() {
assert_eq!(should_compress("image.jpg", 10_000_000), Compression::None);
assert_eq!(should_compress("video.mp4", 100_000_000), Compression::None);
}
// Mid-sized files with other extensions get Zstd.
#[test]
fn test_should_compress_large_text() {
assert_eq!(should_compress("data.txt", 10_000_000), Compression::Zstd);
assert_eq!(should_compress("log.log", 50_000_000), Compression::Zstd);
}
// Every codec must round-trip zero-length input.
#[test]
fn test_roundtrip_empty_data() {
let empty: &[u8] = &[];
for compression in [Compression::None, Compression::Lz4, Compression::Zstd] {
let compressed = compress(empty, compression).unwrap();
let decompressed = decompress(&compressed, compression).unwrap();
assert_eq!(decompressed.as_slice(), empty);
}
}
// Every codec must round-trip ~1 MB of cyclic byte data.
#[test]
fn test_roundtrip_large_data() {
let large: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();
for compression in [Compression::None, Compression::Lz4, Compression::Zstd] {
let compressed = compress(&large, compression).unwrap();
let decompressed = decompress(&compressed, compression).unwrap();
assert_eq!(decompressed, large);
}
}
// Highly repetitive input should also compress to under 10% with LZ4.
#[test]
fn test_lz4_compression_ratio() {
let repetitive = b"AAAA".repeat(1000);
let compressed = compress(&repetitive, Compression::Lz4).unwrap();
let ratio = compressed.len() as f64 / repetitive.len() as f64;
assert!(ratio < 0.1); }
// Local transfers are never compressed.
#[test]
fn test_adaptive_compression_local() {
assert_eq!(should_compress_adaptive("test.txt", 10_000_000, true, None), Compression::None);
}
// Link speeds above 500 Mbps disable compression; slower or unknown speeds allow Zstd.
#[test]
fn test_adaptive_compression_any_network() {
assert_eq!(should_compress_adaptive("test.txt", 10_000_000, false, Some(100_000)), Compression::None);
assert_eq!(should_compress_adaptive("test.txt", 10_000_000, false, Some(1000)), Compression::None);
assert_eq!(should_compress_adaptive("test.txt", 10_000_000, false, Some(100)), Compression::Zstd);
assert_eq!(should_compress_adaptive("test.txt", 10_000_000, false, None), Compression::Zstd);
}
// Even on a slow link, already-compressed extensions are skipped.
#[test]
fn test_adaptive_compression_respects_precompressed() {
assert_eq!(should_compress_adaptive("video.mp4", 100_000_000, false, Some(10)), Compression::None);
}
// Even on a slow link, files under 1 MiB are skipped.
#[test]
fn test_adaptive_compression_small_files() {
assert_eq!(should_compress_adaptive("test.txt", 512_000, false, Some(10)), Compression::None);
}
// Repetitive text should sample to a ratio well under 0.5.
#[test]
fn test_detect_compressibility_text() {
use std::io::Write;
use tempfile::NamedTempFile;
let mut temp_file = NamedTempFile::new().unwrap();
let repetitive_text = "Hello world! ".repeat(5000); temp_file.write_all(repetitive_text.as_bytes()).unwrap();
temp_file.flush().unwrap();
let ratio = detect_compressibility(temp_file.path()).unwrap();
assert!(ratio < 0.5, "Ratio: {}", ratio);
}
// Deterministic pseudo-random bytes from a multiplicative hash; LZ4 should barely shrink them.
#[test]
fn test_detect_compressibility_random() {
use std::io::Write;
use tempfile::NamedTempFile;
let mut temp_file = NamedTempFile::new().unwrap();
let random_data: Vec<u8> = (0u32..65536)
.map(|i| {
let x = i.wrapping_mul(2654435761); ((x ^ (x >> 16)) & 0xFF) as u8
})
.collect();
temp_file.write_all(&random_data).unwrap();
temp_file.flush().unwrap();
let ratio = detect_compressibility(temp_file.path()).unwrap();
assert!(ratio > 0.85, "Ratio: {}", ratio);
}
// An empty file reports a ratio of exactly 1.0.
#[test]
fn test_detect_compressibility_empty() {
use tempfile::NamedTempFile;
let temp_file = NamedTempFile::new().unwrap();
let ratio = detect_compressibility(temp_file.path()).unwrap();
assert_eq!(ratio, 1.0);
}
// 24-byte phrase x 50_000 = 1_200_000 bytes, matching the `file_size` argument.
#[test]
fn test_should_compress_smart_auto_compressible() {
use std::io::Write;
use tempfile::NamedTempFile;
let mut temp_file = NamedTempFile::new().unwrap();
let text = "Compressible text data! ".repeat(50000); temp_file.write_all(text.as_bytes()).unwrap();
temp_file.flush().unwrap();
let result = should_compress_smart(Some(temp_file.path()), "test.txt", 1_200_000, false, CompressionDetection::Auto);
assert_eq!(result, Compression::Zstd);
}
// Pseudo-random content with a neutral extension: sampling should reject compression.
#[test]
fn test_should_compress_smart_auto_incompressible() {
use std::io::Write;
use tempfile::NamedTempFile;
let mut temp_file = NamedTempFile::new().unwrap();
let random_data: Vec<u8> = (0u32..1_200_000)
.map(|i| {
let x = i.wrapping_mul(2654435761);
((x ^ (x >> 16)) & 0xFF) as u8
})
.collect();
temp_file.write_all(&random_data).unwrap();
temp_file.flush().unwrap();
let result = should_compress_smart(Some(temp_file.path()), "data.bin", 1_200_000, false, CompressionDetection::Auto);
assert_eq!(result, Compression::None);
}
// `Always` overrides the `.jpg` extension check.
#[test]
fn test_should_compress_smart_always() {
let result = should_compress_smart(
None,
"test.jpg", 10_000_000,
false,
CompressionDetection::Always,
);
assert_eq!(result, Compression::Zstd);
}
// `Never` wins even for a compressible-looking file.
#[test]
fn test_should_compress_smart_never() {
let result = should_compress_smart(
None,
"test.txt", 10_000_000,
false,
CompressionDetection::Never,
);
assert_eq!(result, Compression::None);
}
// `Extension` mode decides purely by extension (and size), without sampling.
#[test]
fn test_should_compress_smart_extension_mode() {
let result = should_compress_smart(None, "test.txt", 10_000_000, false, CompressionDetection::Extension);
assert_eq!(result, Compression::Zstd);
let result = should_compress_smart(None, "test.jpg", 10_000_000, false, CompressionDetection::Extension);
assert_eq!(result, Compression::None);
}
// Local transfers are never compressed in `Auto` mode.
#[test]
fn test_should_compress_smart_local() {
let result = should_compress_smart(
None,
"test.txt",
10_000_000,
true, CompressionDetection::Auto,
);
assert_eq!(result, Compression::None);
}
// Files under 1 MiB are skipped in `Auto` mode.
#[test]
fn test_should_compress_smart_small_file() {
let result = should_compress_smart(
None,
"test.txt",
512_000, false,
CompressionDetection::Auto,
);
assert_eq!(result, Compression::None);
}
// Known compressed extensions are skipped in `Auto` mode before any sampling.
#[test]
fn test_should_compress_smart_known_compressed_extension() {
let result = should_compress_smart(None, "video.mp4", 100_000_000, false, CompressionDetection::Auto);
assert_eq!(result, Compression::None);
}
// Without a path to sample, `Auto` falls back to Zstd.
#[test]
fn test_should_compress_smart_no_path_fallback() {
let result = should_compress_smart(
None, "data.bin",
10_000_000,
false,
CompressionDetection::Auto,
);
assert_eq!(result, Compression::Zstd);
}
}