use crate::{Error, Result, MAGIC_COMPRESSED};
use alloc::vec::Vec;
use serde::{Deserialize, Serialize};
/// Compression scheme applied to a payload.
///
/// The explicit `u8` discriminants are the identifiers written into the header by
/// `serialize_with_header` (via `as u8`) and decoded by `CompressionAlgorithm::from_byte`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum CompressionAlgorithm {
    /// Payload stored verbatim.
    None = 0,
    /// Payload encoded with gzip.
    Gzip = 1,
    /// Payload encoded with Brotli.
    Brotli = 2,
}
impl CompressionAlgorithm {
pub fn from_byte(byte: u8) -> Result<Self> {
match byte {
0x00 => Ok(CompressionAlgorithm::None),
0x01 => Ok(CompressionAlgorithm::Gzip),
0x02 => Ok(CompressionAlgorithm::Brotli),
_ => Err(Error::UnsupportedAlgorithm(byte)),
}
}
pub fn name(&self) -> &'static str {
match self {
CompressionAlgorithm::None => "none",
CompressionAlgorithm::Gzip => "gzip",
CompressionAlgorithm::Brotli => "brotli",
}
}
pub fn from_name(name: &str) -> Result<Self> {
match name.to_lowercase().as_str() {
"none" => Ok(CompressionAlgorithm::None),
"gzip" => Ok(CompressionAlgorithm::Gzip),
"brotli" => Ok(CompressionAlgorithm::Brotli),
_ => Err(Error::UnsupportedAlgorithm(0)),
}
}
}
/// Output of [`compress`]: the (possibly uncompressed) payload plus size statistics.
///
/// When `compress` decides not to compress, `compressed` holds a copy of the input,
/// `algorithm` is [`CompressionAlgorithm::None`] and `compression_ratio` is `1.0`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionResult {
    /// Encoded bytes, or the original bytes when `algorithm` is `None`.
    pub compressed: Vec<u8>,
    /// Algorithm that produced `compressed`; pass it to [`decompress`] to recover the input.
    pub algorithm: CompressionAlgorithm,
    /// Length of the input, in bytes.
    pub original_size: usize,
    /// Length of `compressed`, in bytes.
    pub compressed_size: usize,
    /// `compressed_size / original_size`; `1.0` for passthrough results.
    pub compression_ratio: f64,
}
/// Tuning knobs for [`compress`]; `CompressionOptions::default()` gives Brotli, a
/// 100-byte threshold and level 6.
#[derive(Debug, Clone)]
pub struct CompressionOptions {
    /// Algorithm to attempt. `CompressionAlgorithm::None` disables compression.
    pub algorithm: CompressionAlgorithm,
    /// Inputs shorter than this many bytes are returned uncompressed.
    pub min_size_threshold: usize,
    /// Encoder level handed to the selected backend (gzip level / Brotli quality).
    /// NOTE(review): the useful range differs per backend (presumably 0–9 for gzip,
    /// 0–11 for Brotli — confirm); this struct does not validate it.
    pub level: u32,
}
impl Default for CompressionOptions {
fn default() -> Self {
Self {
algorithm: CompressionAlgorithm::Brotli,
min_size_threshold: 100,
level: 6,
}
}
}
/// Largest `compressed_size / original_size` ratio at which the compressed form is
/// kept; at or above it the input is stored uncompressed instead.
const MAX_KEPT_RATIO: f64 = 0.9;

/// Compresses `data` according to `options` (`None` means `CompressionOptions::default()`).
///
/// The input is returned verbatim, tagged [`CompressionAlgorithm::None`] with a ratio
/// of `1.0`, when any of the following holds:
/// * `data.len()` is below `min_size_threshold`,
/// * the requested algorithm is `CompressionAlgorithm::None`,
/// * the encoded output is not smaller than 90% of the input.
///
/// # Errors
/// Propagates any error reported by the gzip or Brotli encoder.
pub fn compress(data: &[u8], options: Option<CompressionOptions>) -> Result<CompressionResult> {
    let opts = options.unwrap_or_default();
    let original_size = data.len();

    if original_size < opts.min_size_threshold {
        return Ok(uncompressed_result(data));
    }

    let (compressed, algorithm) = match opts.algorithm {
        CompressionAlgorithm::None => return Ok(uncompressed_result(data)),
        CompressionAlgorithm::Gzip => compress_gzip(data, opts.level)?,
        CompressionAlgorithm::Brotli => compress_brotli(data, opts.level)?,
    };

    let compressed_size = compressed.len();
    // Empty input only gets here with a zero threshold; the ratio is then non-finite
    // (inf or NaN), fails the `<` test, and falls back to passthrough.
    let compression_ratio = compressed_size as f64 / original_size as f64;
    if compression_ratio < MAX_KEPT_RATIO {
        Ok(CompressionResult {
            compressed,
            algorithm,
            original_size,
            compressed_size,
            compression_ratio,
        })
    } else {
        Ok(uncompressed_result(data))
    }
}

/// Builds the passthrough result: `data` copied verbatim, tagged `None`, ratio `1.0`.
fn uncompressed_result(data: &[u8]) -> CompressionResult {
    CompressionResult {
        compressed: data.to_vec(),
        algorithm: CompressionAlgorithm::None,
        original_size: data.len(),
        compressed_size: data.len(),
        compression_ratio: 1.0,
    }
}
/// Reverses [`compress`]: decodes `data` that was produced with `algorithm`.
///
/// `CompressionAlgorithm::None` simply returns a copy of `data`.
/// NOTE(review): no bound is placed on the decoded size, so untrusted input could
/// expand to a very large allocation; consider checking against the header's
/// original size.
///
/// # Errors
/// Propagates `Error::DecompressionFailed` from the gzip/Brotli decoders.
pub fn decompress(data: &[u8], algorithm: CompressionAlgorithm) -> Result<Vec<u8>> {
    let decoded = match algorithm {
        CompressionAlgorithm::Brotli => decompress_brotli(data)?,
        CompressionAlgorithm::Gzip => decompress_gzip(data)?,
        CompressionAlgorithm::None => data.to_vec(),
    };
    Ok(decoded)
}
fn compress_brotli(data: &[u8], level: u32) -> Result<(Vec<u8>, CompressionAlgorithm)> {
use brotli::enc::BrotliEncoderParams;
let mut output = Vec::new();
let mut params = BrotliEncoderParams::default();
params.quality = level as i32;
brotli::BrotliCompress(
&mut std::io::Cursor::new(data),
&mut output,
¶ms,
).map_err(|e| Error::CompressionFailed(e.to_string()))?;
Ok((output, CompressionAlgorithm::Brotli))
}
fn decompress_brotli(data: &[u8]) -> Result<Vec<u8>> {
let mut output = Vec::new();
brotli::BrotliDecompress(
&mut std::io::Cursor::new(data),
&mut output,
).map_err(|e| Error::DecompressionFailed(e.to_string()))?;
Ok(output)
}
/// Gzip-encodes `data` at compression `level`.
///
/// `CompressionOptions::level` is shared with Brotli (whose scale goes up to 11),
/// but DEFLATE defines levels 0..=9, so `level` is clamped to `Compression::best()`.
///
/// # Errors
/// Returns `Error::CompressionFailed` if writing or finishing the gzip stream fails.
fn compress_gzip(data: &[u8], level: u32) -> Result<(Vec<u8>, CompressionAlgorithm)> {
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::Write;

    // NOTE(review): out-of-range levels may be rejected (or asserted on) when some
    // flate2 backends construct the encoder; clamping keeps every backend in range.
    let compression = Compression::new(level.min(Compression::best().level()));
    let mut encoder = GzEncoder::new(Vec::new(), compression);
    encoder
        .write_all(data)
        .map_err(|e| Error::CompressionFailed(e.to_string()))?;
    let output = encoder
        .finish()
        .map_err(|e| Error::CompressionFailed(e.to_string()))?;
    Ok((output, CompressionAlgorithm::Gzip))
}
fn decompress_gzip(data: &[u8]) -> Result<Vec<u8>> {
use flate2::read::GzDecoder;
use std::io::Read;
let mut decoder = GzDecoder::new(data);
let mut output = Vec::new();
decoder.read_to_end(&mut output)
.map_err(|e| Error::DecompressionFailed(e.to_string()))?;
Ok(output)
}
/// Frames a [`CompressionResult`] for storage or transport.
///
/// Layout: `MAGIC_COMPRESSED`, one algorithm byte, `original_size` as a big-endian
/// `u32`, then the compressed payload.
///
/// The size field is 32 bits wide. An `original_size` above `u32::MAX` is saturated
/// to `u32::MAX` rather than silently wrapped to a small, misleading value.
pub fn serialize_with_header(result: &CompressionResult) -> Vec<u8> {
    // `min` first, so the cast cannot truncate (on 16-bit targets `u32::MAX as usize`
    // is already `usize::MAX`, which is still correct).
    let original_size = result.original_size.min(u32::MAX as usize) as u32;
    // magic + 1 algorithm byte + 4-byte size
    let header_len = MAGIC_COMPRESSED.len() + 1 + 4;

    let mut output = Vec::with_capacity(header_len + result.compressed.len());
    output.extend_from_slice(MAGIC_COMPRESSED);
    output.push(result.algorithm as u8);
    output.extend_from_slice(&original_size.to_be_bytes());
    output.extend_from_slice(&result.compressed);
    output
}
/// Parses a buffer produced by [`serialize_with_header`].
///
/// Layout: `MAGIC_COMPRESSED`, one algorithm byte, the original size as a big-endian
/// `u32`, then the payload. Returns `(payload, algorithm, original_size)`.
///
/// Header offsets are derived from `MAGIC_COMPRESSED.len()` instead of hard-coded, so
/// the parser stays in step with the writer if the magic changes length.
/// `original_size` is returned as read; it is not checked against the payload.
///
/// # Errors
/// * `Error::TruncatedPayload` if `data` is shorter than the header.
/// * `Error::InvalidFormat` if the magic bytes do not match.
/// * `Error::UnsupportedAlgorithm` for an unknown algorithm byte.
pub fn deserialize_with_header(data: &[u8]) -> Result<(Vec<u8>, CompressionAlgorithm, usize)> {
    let magic_len = MAGIC_COMPRESSED.len();
    // magic + 1 algorithm byte + 4-byte big-endian size
    let header_len = magic_len + 1 + 4;
    if data.len() < header_len {
        return Err(Error::TruncatedPayload {
            expected: header_len,
            actual: data.len(),
        });
    }

    let (header, payload) = data.split_at(header_len);
    if &header[..magic_len] != MAGIC_COMPRESSED {
        return Err(Error::InvalidFormat);
    }
    let algorithm = CompressionAlgorithm::from_byte(header[magic_len])?;
    let size = &header[magic_len + 1..];
    let original_size = u32::from_be_bytes([size[0], size[1], size[2], size[3]]) as usize;
    Ok((payload.to_vec(), algorithm, original_size))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Highly repetitive input, long enough that both encoders clearly beat the 0.9
    /// ratio gate in `compress`. Short one-off sentences do not (gzip framing alone
    /// is ~18 bytes), so they would silently round-trip through the passthrough path.
    fn compressible_input() -> Vec<u8> {
        b"Hello, World! This is a test message that should be compressed. ".repeat(32)
    }

    /// Options selecting `algorithm` with a threshold low enough for every test input.
    fn options(algorithm: CompressionAlgorithm) -> Option<CompressionOptions> {
        Some(CompressionOptions {
            algorithm,
            min_size_threshold: 10,
            level: 6,
        })
    }

    #[test]
    fn test_gzip_roundtrip() {
        let data = compressible_input();
        let result = compress(&data, options(CompressionAlgorithm::Gzip)).unwrap();
        assert_eq!(result.algorithm, CompressionAlgorithm::Gzip);
        assert!(result.compressed_size < result.original_size);
        let decompressed = decompress(&result.compressed, result.algorithm).unwrap();
        assert_eq!(data, decompressed);
    }

    #[test]
    fn test_brotli_roundtrip() {
        let data = compressible_input();
        let result = compress(&data, options(CompressionAlgorithm::Brotli)).unwrap();
        assert_eq!(result.algorithm, CompressionAlgorithm::Brotli);
        assert!(result.compressed_size < result.original_size);
        let decompressed = decompress(&result.compressed, result.algorithm).unwrap();
        assert_eq!(data, decompressed);
    }

    #[test]
    fn test_algorithm_none_is_passthrough() {
        let data = compressible_input();
        let result = compress(&data, options(CompressionAlgorithm::None)).unwrap();
        assert_eq!(result.algorithm, CompressionAlgorithm::None);
        assert_eq!(result.compressed, data);
        assert_eq!(result.compressed_size, data.len());
    }

    #[test]
    fn test_skip_small_data() {
        let data = b"tiny";
        let result = compress(data, Some(CompressionOptions {
            algorithm: CompressionAlgorithm::Brotli,
            min_size_threshold: 100,
            level: 6,
        }))
        .unwrap();
        assert_eq!(result.algorithm, CompressionAlgorithm::None);
        assert_eq!(result.compressed, data);
    }

    #[test]
    fn test_header_serialization() {
        let data = compressible_input();
        let result = compress(&data, options(CompressionAlgorithm::Gzip)).unwrap();
        assert_eq!(result.algorithm, CompressionAlgorithm::Gzip);

        let serialized = serialize_with_header(&result);
        let (compressed, algorithm, original_size) = deserialize_with_header(&serialized).unwrap();
        assert_eq!(algorithm, result.algorithm);
        assert_eq!(original_size, result.original_size);
        assert_eq!(compressed, result.compressed);
        assert_eq!(decompress(&compressed, algorithm).unwrap(), data);
    }

    #[test]
    fn test_deserialize_rejects_truncated_header() {
        assert!(deserialize_with_header(&[]).is_err());
        let serialized = serialize_with_header(&compress(b"tiny", None).unwrap());
        assert!(deserialize_with_header(&serialized[..serialized.len() - 5]).is_err());
    }

    #[test]
    fn test_from_name_is_case_insensitive() {
        assert_eq!(CompressionAlgorithm::from_name("GZip").unwrap(), CompressionAlgorithm::Gzip);
        assert_eq!(CompressionAlgorithm::from_name("brotli").unwrap(), CompressionAlgorithm::Brotli);
        assert!(CompressionAlgorithm::from_name("zstd").is_err());
    }
}