use ferrum_types::{DataType, FerrumError, Result};
use std::sync::Arc;
/// A pluggable byte-level compression scheme.
///
/// Implementations must be `Send + Sync` so they can be shared behind an
/// `Arc`, and `Debug` so containers holding them can derive `Debug`.
pub trait CompressionStrategy: Send + Sync + std::fmt::Debug {
    /// Encodes `data` and returns the payload together with the metadata
    /// needed to decode it again.
    ///
    /// `original_dtype` describes the element type stored in `data`;
    /// implementations may reject data types they do not support.
    ///
    /// # Errors
    ///
    /// Implementation-defined; returned when the input cannot be encoded.
    fn compress(&self, data: &[u8], original_dtype: DataType) -> Result<CompressedData>;

    /// Decodes a payload previously produced by [`compress`](Self::compress).
    ///
    /// # Errors
    ///
    /// Implementation-defined; returned when `compressed` cannot be decoded
    /// by this strategy.
    fn decompress(&self, compressed: &CompressedData) -> Result<Vec<u8>>;

    /// Nominal compression ratio reported by the strategy.
    ///
    /// This is whatever the implementation chooses to report; it is not
    /// measured from any particular input.
    fn compression_ratio(&self) -> f32;

    /// Short identifier of the strategy.
    fn name(&self) -> &str;
}
/// An encoded payload plus the metadata recorded when it was produced.
#[derive(Debug, Clone)]
pub struct CompressedData {
    /// The encoded bytes.
    pub data: Vec<u8>,
    /// Data type of the input that was compressed.
    pub original_dtype: DataType,
    /// Length in bytes of the input that was compressed.
    pub original_size: usize,
    /// Identifier written by the strategy that produced this payload
    /// (e.g. `"none"`, `"int4"`, `"fp8"`).
    pub algorithm: String,
    /// Parameters the producing strategy attached to the payload.
    pub params: CompressionParams,
}
/// Optional tuning knobs attached to a strategy and to the data it produces.
///
/// `Default` leaves both options unset and the custom map empty.
#[derive(Debug, Clone, Default)]
pub struct CompressionParams {
    /// Bit width used for quantization, when the strategy quantizes.
    pub quantization_bits: Option<u8>,
    /// Block size setting, if any.
    // NOTE(review): units and consumers are not visible in this file — confirm.
    pub block_size: Option<usize>,
    /// Free-form string key/value settings.
    pub custom: std::collections::HashMap<String, String>,
}
/// Pass-through strategy: the payload is a verbatim copy of the input.
#[derive(Debug, Clone, Default)]
pub struct NoCompression;

impl CompressionStrategy for NoCompression {
    /// Copies `data` unchanged into a [`CompressedData`] tagged `"none"`.
    /// Never fails and accepts every data type.
    fn compress(&self, data: &[u8], original_dtype: DataType) -> Result<CompressedData> {
        let payload = Vec::from(data);
        let original_size = payload.len();
        Ok(CompressedData {
            data: payload,
            original_dtype,
            original_size,
            algorithm: String::from("none"),
            params: CompressionParams::default(),
        })
    }

    /// Returns a copy of the stored payload. The `algorithm` tag is not
    /// inspected.
    fn decompress(&self, compressed: &CompressedData) -> Result<Vec<u8>> {
        let bytes = compressed.data.to_vec();
        Ok(bytes)
    }

    /// Always `1.0`: output is the same size as the input.
    fn compression_ratio(&self) -> f32 {
        1.0
    }

    /// Returns `"none"`.
    fn name(&self) -> &str {
        "none"
    }
}
#[derive(Debug, Clone)]
pub struct Int4Compression {
params: CompressionParams,
}
impl Int4Compression {
pub fn new() -> Self {
let mut params = CompressionParams::default();
params.quantization_bits = Some(4);
Self { params }
}
}
impl Default for Int4Compression {
fn default() -> Self {
Self::new()
}
}
impl CompressionStrategy for Int4Compression {
    /// Packs every pair of input bytes into a single byte.
    ///
    /// The high nibble of the first byte of a pair becomes the high nibble of
    /// the output byte and the high nibble of the second byte becomes the low
    /// nibble. A trailing unpaired byte contributes only its high nibble. The
    /// low nibble of every input byte is discarded, so the encoding is lossy.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameter error for any data type other than
    /// `FP16` or `FP32`.
    fn compress(&self, data: &[u8], original_dtype: DataType) -> Result<CompressedData> {
        match original_dtype {
            DataType::FP16 | DataType::FP32 => {
                let packed: Vec<u8> = data
                    .chunks(2)
                    .map(|pair| {
                        // `chunks` never yields an empty slice, so index 0 exists.
                        let high = pair[0] & 0xF0;
                        let low = pair.get(1).map_or(0, |second| (second & 0xF0) >> 4);
                        high | low
                    })
                    .collect();
                Ok(CompressedData {
                    data: packed,
                    original_dtype,
                    original_size: data.len(),
                    algorithm: "int4".to_string(),
                    params: self.params.clone(),
                })
            }
            _ => Err(FerrumError::invalid_parameter(format!(
                "INT4 compression not supported for {:?}",
                original_dtype
            ))),
        }
    }

    /// Expands every payload byte back into two bytes, each carrying one
    /// nibble in its high half and zeros in its low half, and returns exactly
    /// `original_size` bytes.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameter error when the payload is not tagged
    /// `"int4"`, or when the payload length does not match the length that
    /// `original_size` implies.
    fn decompress(&self, compressed: &CompressedData) -> Result<Vec<u8>> {
        if compressed.algorithm != "int4" {
            return Err(FerrumError::invalid_parameter(
                "Expected INT4 compressed data",
            ));
        }

        // Validate before allocating: `original_size` comes from the payload
        // header, and a mismatch would otherwise produce an output of the
        // wrong length without any error.
        let expected_len = compressed.original_size.div_ceil(2);
        if compressed.data.len() != expected_len {
            return Err(FerrumError::invalid_parameter(format!(
                "INT4 payload has {} bytes but original size {} requires {}",
                compressed.data.len(),
                compressed.original_size,
                expected_len
            )));
        }

        let mut decompressed = Vec::with_capacity(compressed.data.len() * 2);
        for &byte in &compressed.data {
            decompressed.push(byte & 0xF0);
            decompressed.push((byte & 0x0F) << 4);
        }
        // An odd `original_size` leaves one padding byte at the end.
        decompressed.truncate(compressed.original_size);
        Ok(decompressed)
    }

    /// Always `2.0`: two input bytes are stored per output byte.
    fn compression_ratio(&self) -> f32 {
        2.0
    }

    /// Returns `"int4"`.
    fn name(&self) -> &str {
        "int4"
    }
}
#[derive(Debug, Clone)]
pub struct Fp8Compression {
params: CompressionParams,
}
impl Fp8Compression {
pub fn new() -> Self {
let mut params = CompressionParams::default();
params.quantization_bits = Some(8);
Self { params }
}
}
impl Default for Fp8Compression {
fn default() -> Self {
Self::new()
}
}
impl CompressionStrategy for Fp8Compression {
    /// Keeps the first byte of every element and drops the rest.
    ///
    /// Elements are taken to be 4 bytes wide for `FP32` and 2 bytes wide for
    /// `FP16`, so the payload holds one byte per element (a trailing partial
    /// element still contributes its first byte).
    ///
    /// NOTE(review): byte 0 of each element is kept; if the input is
    /// little-endian that is the least-significant byte — confirm this is
    /// the intended byte.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameter error for any data type other than
    /// `FP32` or `FP16`.
    fn compress(&self, data: &[u8], original_dtype: DataType) -> Result<CompressedData> {
        let element_width: usize = match original_dtype {
            DataType::FP32 => 4,
            DataType::FP16 => 2,
            _ => {
                return Err(FerrumError::invalid_parameter(format!(
                    "FP8 compression not supported for {:?}",
                    original_dtype
                )))
            }
        };

        Ok(CompressedData {
            data: data.iter().step_by(element_width).copied().collect(),
            original_dtype,
            original_size: data.len(),
            algorithm: "fp8".to_string(),
            params: self.params.clone(),
        })
    }

    /// Rebuilds a buffer of `original_size` bytes in which the first byte of
    /// every element is the stored byte and all remaining bytes are zero.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameter error when the payload is not tagged
    /// `"fp8"`, when `original_dtype` is neither `FP32` nor `FP16`, or when
    /// the payload length does not match the length that `original_size`
    /// implies for that data type.
    fn decompress(&self, compressed: &CompressedData) -> Result<Vec<u8>> {
        if compressed.algorithm != "fp8" {
            return Err(FerrumError::invalid_parameter(
                "Expected FP8 compressed data",
            ));
        }

        let element_width: usize = match compressed.original_dtype {
            DataType::FP32 => 4,
            DataType::FP16 => 2,
            _ => {
                return Err(FerrumError::invalid_parameter(format!(
                    "Cannot decompress FP8 to {:?}",
                    compressed.original_dtype
                )))
            }
        };

        // Validate before allocating: `original_size` comes from the payload
        // header, and a mismatch would otherwise produce an output of the
        // wrong length without any error.
        let expected_len = compressed.original_size.div_ceil(element_width);
        if compressed.data.len() != expected_len {
            return Err(FerrumError::invalid_parameter(format!(
                "FP8 payload has {} bytes but original size {} requires {}",
                compressed.data.len(),
                compressed.original_size,
                expected_len
            )));
        }

        // Start from zeros and drop each stored byte into the first slot of
        // its element; the length check guarantees a one-to-one pairing.
        let mut decompressed = vec![0u8; compressed.original_size];
        for (element, &byte) in decompressed
            .chunks_mut(element_width)
            .zip(&compressed.data)
        {
            element[0] = byte;
        }
        Ok(decompressed)
    }

    /// Reports `2.0` when `quantization_bits` is 8 or unset, `1.5` otherwise.
    ///
    /// NOTE(review): `compress` keeps one byte out of four for `FP32`, so the
    /// size reduction there is 4:1 rather than the 2.0 reported here —
    /// confirm which figure callers expect.
    fn compression_ratio(&self) -> f32 {
        match self.params.quantization_bits.unwrap_or(8) {
            8 => 2.0,
            _ => 1.5,
        }
    }

    /// Returns `"fp8"`.
    fn name(&self) -> &str {
        "fp8"
    }
}
/// Registry of named compression strategies with a configurable default.
#[derive(Debug)]
pub struct CompressionManager {
    /// Registered strategies, keyed by the name they were registered under.
    strategies: std::collections::HashMap<String, Arc<dyn CompressionStrategy>>,
    /// Key of the strategy used when a caller does not name one.
    default_strategy: String,
}
impl CompressionManager {
    /// Creates a manager pre-populated with the `"none"`, `"int4"` and
    /// `"fp8"` strategies, with `"none"` as the default.
    pub fn new() -> Self {
        let mut strategies: std::collections::HashMap<String, Arc<dyn CompressionStrategy>> =
            std::collections::HashMap::new();
        strategies.insert("none".to_string(), Arc::new(NoCompression));
        strategies.insert("int4".to_string(), Arc::new(Int4Compression::new()));
        strategies.insert("fp8".to_string(), Arc::new(Fp8Compression::new()));
        Self {
            strategies,
            default_strategy: "none".to_string(),
        }
    }

    /// Registers `strategy` under `name`, replacing any strategy already
    /// registered under that name.
    ///
    /// [`decompress`](Self::decompress) looks strategies up by the
    /// `algorithm` string stored in the payload, so `name` should match the
    /// string the strategy writes there; otherwise its payloads cannot be
    /// decoded through the manager.
    pub fn register_strategy<S>(&mut self, name: String, strategy: S)
    where
        S: CompressionStrategy + 'static,
    {
        self.strategies.insert(name, Arc::new(strategy));
    }

    /// Makes the strategy registered under `name` the default.
    ///
    /// # Errors
    ///
    /// Returns a not-found error, leaving the default unchanged, when no
    /// strategy is registered under `name`.
    pub fn set_default_strategy(&mut self, name: &str) -> Result<()> {
        if self.strategies.contains_key(name) {
            self.default_strategy = name.to_string();
            Ok(())
        } else {
            Err(FerrumError::not_found(format!(
                "Compression strategy not found: {}",
                name
            )))
        }
    }

    /// Returns a shared handle to the strategy registered under `name`.
    ///
    /// # Errors
    ///
    /// Returns a not-found error when no strategy is registered under `name`.
    pub fn get_strategy(&self, name: &str) -> Result<Arc<dyn CompressionStrategy>> {
        self.strategies.get(name).cloned().ok_or_else(|| {
            FerrumError::not_found(format!("Compression strategy not found: {}", name))
        })
    }

    /// Returns a shared handle to the current default strategy.
    ///
    /// # Errors
    ///
    /// Propagates the not-found error from [`get_strategy`](Self::get_strategy).
    pub fn default_strategy(&self) -> Result<Arc<dyn CompressionStrategy>> {
        self.get_strategy(&self.default_strategy)
    }

    /// Returns the names of all registered strategies in ascending order.
    ///
    /// The names are sorted because the backing map iterates in arbitrary
    /// order, which would otherwise make the result differ between runs.
    pub fn available_strategies(&self) -> Vec<String> {
        let mut names: Vec<String> = self.strategies.keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Compresses `data` with the strategy named by `strategy_name`, or with
    /// the default strategy when `strategy_name` is `None`.
    ///
    /// # Errors
    ///
    /// Returns a not-found error when the chosen strategy is not registered,
    /// or whatever error the strategy's own `compress` reports.
    pub fn compress(
        &self,
        data: &[u8],
        original_dtype: DataType,
        strategy_name: Option<&str>,
    ) -> Result<CompressedData> {
        let strategy_name = strategy_name.unwrap_or(&self.default_strategy);
        let strategy = self.get_strategy(strategy_name)?;
        strategy.compress(data, original_dtype)
    }

    /// Decompresses `compressed` with the strategy registered under the
    /// payload's `algorithm` string.
    ///
    /// # Errors
    ///
    /// Returns a not-found error when no strategy is registered under that
    /// string, or whatever error the strategy's own `decompress` reports.
    pub fn decompress(&self, compressed: &CompressedData) -> Result<Vec<u8>> {
        let strategy = self.get_strategy(&compressed.algorithm)?;
        strategy.decompress(compressed)
    }
}
impl Default for CompressionManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The pass-through strategy must return the input unchanged both ways.
    #[test]
    fn test_no_compression() {
        let compression = NoCompression;
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let compressed = compression.compress(&data, DataType::FP32).unwrap();
        assert_eq!(compressed.data, data);
        assert_eq!(compressed.algorithm, "none");
        assert_eq!(compressed.original_size, data.len());
        let decompressed = compression.decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    /// Pins the exact nibble packing, not just the output lengths.
    #[test]
    fn test_int4_compression() {
        let compression = Int4Compression::new();
        let data: Vec<u8> = vec![0xAB, 0xCD, 0xEF, 0x12];
        let compressed = compression.compress(&data, DataType::FP16).unwrap();
        assert_eq!(compressed.algorithm, "int4");
        assert_eq!(compressed.original_size, data.len());
        assert_eq!(compressed.data, vec![0xAC, 0xE1]);
        let decompressed = compression.decompress(&compressed).unwrap();
        // Only the high nibble of each input byte survives.
        assert_eq!(decompressed, vec![0xA0, 0xC0, 0xE0, 0x10]);
    }

    /// An odd-length input packs its last byte alone and round-trips to the
    /// original length.
    #[test]
    fn test_int4_odd_length() {
        let compression = Int4Compression::new();
        let data: Vec<u8> = vec![0xAB, 0xCD, 0xEF];
        let compressed = compression.compress(&data, DataType::FP32).unwrap();
        assert_eq!(compressed.data, vec![0xAC, 0xE0]);
        let decompressed = compression.decompress(&compressed).unwrap();
        assert_eq!(decompressed, vec![0xA0, 0xC0, 0xE0]);
    }

    /// FP8 keeps the first byte of each element and zero-fills the rest.
    #[test]
    fn test_fp8_compression() {
        let compression = Fp8Compression::new();

        let fp32: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let compressed = compression.compress(&fp32, DataType::FP32).unwrap();
        assert_eq!(compressed.algorithm, "fp8");
        assert_eq!(compressed.data, vec![1, 5]);
        let decompressed = compression.decompress(&compressed).unwrap();
        assert_eq!(decompressed, vec![1, 0, 0, 0, 5, 0, 0, 0]);

        let fp16: Vec<u8> = vec![1, 2, 3, 4];
        let compressed = compression.compress(&fp16, DataType::FP16).unwrap();
        assert_eq!(compressed.data, vec![1, 3]);
        let decompressed = compression.decompress(&compressed).unwrap();
        assert_eq!(decompressed, vec![1, 0, 3, 0]);
    }

    /// A strategy must refuse payloads tagged with another algorithm.
    #[test]
    fn test_algorithm_mismatch_is_rejected() {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let plain = NoCompression.compress(&data, DataType::FP32).unwrap();
        assert!(Int4Compression::new().decompress(&plain).is_err());
        assert!(Fp8Compression::new().decompress(&plain).is_err());
    }

    #[test]
    fn test_compression_manager() {
        let mut manager = CompressionManager::new();
        let strategies = manager.available_strategies();
        assert_eq!(strategies.len(), 3);
        assert!(strategies.contains(&"none".to_string()));
        assert!(strategies.contains(&"int4".to_string()));
        assert!(strategies.contains(&"fp8".to_string()));
        let strategy = manager.default_strategy().unwrap();
        assert_eq!(strategy.name(), "none");

        // Unknown names are rejected and leave the default untouched.
        assert!(manager.set_default_strategy("missing").is_err());
        assert_eq!(manager.default_strategy().unwrap().name(), "none");

        manager.set_default_strategy("int4").unwrap();
        assert_eq!(manager.default_strategy().unwrap().name(), "int4");
    }

    #[test]
    fn test_manager_compress_decompress() {
        let manager = CompressionManager::new();
        let data: Vec<u8> = vec![1, 2, 3, 4];
        let compressed = manager
            .compress(&data, DataType::FP32, Some("none"))
            .unwrap();
        let decompressed = manager.decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);

        // Omitting the name falls back to the default ("none") strategy.
        let via_default = manager.compress(&data, DataType::FP32, None).unwrap();
        assert_eq!(via_default.algorithm, "none");

        // Requesting an unregistered strategy is an error.
        assert!(manager
            .compress(&data, DataType::FP32, Some("missing"))
            .is_err());
    }

    #[test]
    fn test_unsupported_compression() {
        let data: Vec<u8> = vec![1, 2, 3, 4];
        assert!(Int4Compression::new()
            .compress(&data, DataType::INT32)
            .is_err());
        assert!(Fp8Compression::new()
            .compress(&data, DataType::INT32)
            .is_err());
    }
}