use crate::common::RusTorchResult;
use crate::hybrid_f32::tensor::core::F32Tensor;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Encodings a [`CompressedTensor`] payload can be tagged with.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionFormat {
/// Uncompressed: the elements are stored as little-endian `f32` bytes.
None,
/// Only elements whose magnitude exceeds a small cutoff are stored, together with their coordinates.
Sparse,
/// 8-bit affine quantization: one byte per element plus a scale and a zero point.
Quantized8,
/// 16-bit quantization; the engine's codec for this format is still a stub.
Quantized16,
/// Run-length encoding; the engine's codec for this format is still a stub.
RLE,
/// Huffman coding; the engine's codec for this format is still a stub.
Huffman,
/// LZ4 compression; the engine's codec for this format is still a stub.
LZ4,
}
/// Settings that control how [`CompressionEngine`] encodes tensors.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
/// Format used by `CompressionEngine::compress` when `auto_select` is `false`.
pub format: CompressionFormat,
/// Size threshold. NOTE(review): presumably in bytes; the engine code visible here never reads it — confirm intended use.
pub threshold: usize,
/// Fraction of near-zero elements above which automatic selection picks the sparse format.
pub sparsity_threshold: f32,
/// Number of quantization levels. NOTE(review): not read by the engine code visible here — confirm intended use.
pub quantization_levels: u32,
/// When `true`, the engine picks a format per tensor and ignores `format`.
pub auto_select: bool,
}
impl Default for CompressionConfig {
    /// Returns the stock configuration: automatic format selection enabled,
    /// `CompressionFormat::None` as the explicit format, a 1 MiB threshold,
    /// a 0.8 sparsity threshold and 256 quantization levels.
    fn default() -> Self {
        const ONE_MIB: usize = 1024 * 1024;
        Self {
            auto_select: true,
            format: CompressionFormat::None,
            quantization_levels: 256,
            sparsity_threshold: 0.8,
            threshold: ONE_MIB,
        }
    }
}
/// A tensor encoded by [`CompressionEngine`], together with what is needed to decode it.
#[derive(Debug, Clone)]
pub struct CompressedTensor {
/// Encoded payload; its layout depends on `format`.
pub data: Vec<u8>,
/// Shape of the tensor that was compressed.
pub shape: Vec<usize>,
/// Encoding that produced `data`; selects the decoder in `CompressionEngine::decompress`.
pub format: CompressionFormat,
/// Compressed size divided by original size, as computed by the encoder (`1.0` for `CompressionFormat::None`).
pub compression_ratio: f32,
/// Free-form string annotations written by the encoder (the sparse encoder stores `density`,
/// the 8-bit quantizer stores `scale` and `zero_point`).
pub metadata: HashMap<String, String>,
}
impl CompressedTensor {
/// Returns `compressed_size / original_size` as an `f32`.
///
/// An `original_size` of zero yields `0.0` instead of dividing by zero.
pub fn calculate_ratio(original_size: usize, compressed_size: usize) -> f32 {
    match original_size {
        0 => 0.0,
        total => compressed_size as f32 / total as f32,
    }
}
/// Returns how many bytes the payload saves compared with storing every
/// element of `shape` as a 4-byte value.
///
/// The result saturates at zero when the payload is larger than the dense form.
pub fn memory_savings(&self) -> usize {
    let element_count: usize = self.shape.iter().product();
    let dense_bytes = element_count * 4;
    dense_bytes.saturating_sub(self.data.len())
}
}
/// Coordinate-list representation of a tensor: only selected elements are kept.
#[derive(Debug, Clone)]
pub struct SparseTensor {
/// One coordinate vector (one entry per dimension) for each stored value.
pub indices: Vec<Vec<usize>>,
/// Stored values, in the same order as `indices`.
pub values: Vec<f32>,
/// Shape of the dense tensor this was built from.
pub shape: Vec<usize>,
/// Number of stored values divided by the total element count (`0.0` for an empty shape product).
pub density: f32,
}
impl SparseTensor {
/// Builds a sparse view of `tensor`, keeping only the elements whose absolute
/// value is strictly greater than `threshold`.
///
/// Each kept element is stored with its multi-dimensional coordinate, derived
/// from its position in `tensor.data`'s iteration order with the last
/// dimension varying fastest. `density` is the kept count divided by the
/// product of the shape, or `0.0` when that product is zero.
pub fn from_dense(tensor: &F32Tensor, threshold: f32) -> RusTorchResult<Self> {
    /// Converts a flat position into per-dimension coordinates.
    fn unravel(mut flat: usize, shape: &[usize]) -> Vec<usize> {
        let mut coords = vec![0; shape.len()];
        for (slot, &extent) in coords.iter_mut().zip(shape).rev() {
            *slot = flat % extent;
            flat /= extent;
        }
        coords
    }

    let shape = tensor.shape().to_vec();
    let (indices, values): (Vec<Vec<usize>>, Vec<f32>) = tensor
        .data
        .iter()
        .enumerate()
        .filter(|(_, value)| value.abs() > threshold)
        .map(|(flat, &value)| (unravel(flat, &shape), value))
        .unzip();

    let total_elements: usize = shape.iter().product();
    let density = match total_elements {
        0 => 0.0,
        total => values.len() as f32 / total as f32,
    };

    Ok(SparseTensor {
        indices,
        values,
        shape,
        density,
    })
}
pub fn to_dense(&self) -> RusTorchResult<F32Tensor> {
let mut dense_tensor = F32Tensor::zeros(&self.shape)?;
for (idx_vec, &value) in self.indices.iter().zip(self.values.iter()) {
let mut flat_idx = 0;
let mut multiplier = 1;
for (&idx, &dim_size) in idx_vec.iter().zip(self.shape.iter()).rev() {
flat_idx += idx * multiplier;
multiplier *= dim_size;
}
dense_tensor.data[flat_idx] = value;
}
Ok(dense_tensor)
}
/// Estimates the bytes held by this sparse tensor.
///
/// Counts `shape.len()` `usize` coordinates per stored index, the `f32`
/// values, the shape vector, and one `f32` for `density`.
pub fn memory_usage(&self) -> usize {
    use std::mem::size_of;

    let rank = self.shape.len();
    let index_bytes = self.indices.len() * rank * size_of::<usize>();
    let value_bytes = self.values.len() * size_of::<f32>();
    let shape_bytes = rank * size_of::<usize>();
    let density_bytes = size_of::<f32>();
    index_bytes + value_bytes + shape_bytes + density_bytes
}
}
/// A tensor stored as 8-bit codes with an affine mapping back to `f32`.
///
/// A code `q` decodes to `(q - zero_point) * scale`.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
/// One code per tensor element.
pub quantized_data: Vec<u8>,
/// Step between adjacent codes, in the units of the original values.
pub scale: f32,
/// Code that decodes to `0.0`.
pub zero_point: u8,
/// Shape of the tensor that was quantized.
pub shape: Vec<usize>,
/// Bit width of the codes; `quantize_8bit` sets this to 8.
pub bits: u8,
}
impl QuantizedTensor {
/// Quantizes `tensor` to one byte per element with an affine mapping
/// (`q = round(value / scale + zero_point)`, clamped to `0..=255`).
///
/// The mapped range is the span of the finite input values, widened to
/// include `0.0`. Widening guarantees the zero point fits in a `u8`; without
/// it, data lying entirely above or below zero saturated the zero point and
/// collapsed every element to a single code.
///
/// Edge cases:
/// * an empty or all-zero tensor uses `scale = 1.0` and `zero_point = 0`;
/// * non-finite elements do not influence the range; `+inf` encodes as 255,
///   `-inf` and NaN encode as 0.
pub fn quantize_8bit(tensor: &F32Tensor) -> Self {
    let data = &tensor.data;

    // Starting the fold at (0, 0) both widens the range to include zero and
    // gives a sane result for empty input.
    let (min_val, max_val) = data
        .iter()
        .filter(|value| value.is_finite())
        .fold((0.0_f32, 0.0_f32), |(lo, hi), &x| (lo.min(x), hi.max(x)));

    let range = max_val - min_val;
    let step = if range.is_finite() {
        range / 255.0
    } else {
        // The subtraction overflowed; dividing first keeps the step finite.
        max_val / 255.0 - min_val / 255.0
    };
    // A zero step (constant-zero data or underflow) would divide by zero below.
    let scale = if step > 0.0 { step } else { 1.0 };

    let zero_point = (-min_val / scale).round().clamp(0.0, 255.0) as u8;
    let offset = f32::from(zero_point);

    let quantized_data: Vec<u8> = data
        .iter()
        // NaN survives `clamp` and is turned into 0 by the saturating cast.
        .map(|&value| (value / scale + offset).round().clamp(0.0, 255.0) as u8)
        .collect();

    QuantizedTensor {
        quantized_data,
        scale,
        zero_point,
        shape: tensor.shape().to_vec(),
        bits: 8,
    }
}
/// Reconstructs an `f32` tensor from the stored codes.
///
/// Each code `q` becomes `(q - zero_point) * scale`; the values are handed to
/// `F32Tensor::new` with the stored shape, and its result is returned as is.
pub fn dequantize(&self) -> RusTorchResult<F32Tensor> {
    let mut restored: Vec<f32> = Vec::with_capacity(self.quantized_data.len());
    for &code in &self.quantized_data {
        restored.push((code as f32 - self.zero_point as f32) * self.scale);
    }
    F32Tensor::new(restored, &self.shape)
}
/// Estimates the bytes held by this quantized tensor.
///
/// Counts one byte per code plus the `scale`, `zero_point`, shape entries
/// and `bits` fields.
pub fn memory_usage(&self) -> usize {
    use std::mem::size_of;

    let code_bytes = self.quantized_data.len();
    let scale_bytes = size_of::<f32>();
    let zero_point_bytes = size_of::<u8>();
    let shape_bytes = self.shape.len() * size_of::<usize>();
    let bits_bytes = size_of::<u8>();
    code_bytes + scale_bytes + zero_point_bytes + shape_bytes + bits_bytes
}
}
/// Compresses and decompresses tensors according to a [`CompressionConfig`],
/// keeping running statistics about the work done.
#[derive(Debug)]
pub struct CompressionEngine {
// Settings fixed at construction time.
config: CompressionConfig,
// Counters updated by `compress` and `decompress`; behind a mutex because both take `&self`.
stats: Arc<Mutex<CompressionStats>>,
}
/// Running counters maintained by [`CompressionEngine`].
#[derive(Debug, Default, Clone)]
pub struct CompressionStats {
/// Number of successful `compress` calls.
pub compressions: usize,
/// Number of successful `decompress` calls.
pub decompressions: usize,
/// Sum of `CompressedTensor::memory_savings` over all compressed tensors.
pub bytes_saved: usize,
/// Running mean of the `compression_ratio` of all compressed tensors.
pub avg_compression_ratio: f32,
/// How many times each format was used by `compress`.
pub format_usage: HashMap<CompressionFormat, usize>,
}
impl CompressionEngine {
pub fn new(config: CompressionConfig) -> Self {
CompressionEngine {
config,
stats: Arc::new(Mutex::new(CompressionStats::default())),
}
}
/// Creates an engine configured with `CompressionConfig::default()`.
pub fn with_default_config() -> Self {
    let config = CompressionConfig::default();
    Self::new(config)
}
/// Compresses `tensor` and records the operation in the engine statistics.
///
/// The format is chosen by `select_best_format` when `auto_select` is set,
/// otherwise `config.format` is used. On success the compression count,
/// bytes saved, per-format usage and running mean ratio are updated.
///
/// # Errors
///
/// Propagates any error from the selected encoder; statistics are left
/// untouched in that case.
///
/// # Panics
///
/// Panics if the statistics mutex is poisoned.
pub fn compress(&self, tensor: &F32Tensor) -> RusTorchResult<CompressedTensor> {
    let format = match self.config.auto_select {
        true => self.select_best_format(tensor),
        false => self.config.format,
    };

    let encoded = match format {
        CompressionFormat::None => self.compress_none(tensor),
        CompressionFormat::Sparse => self.compress_sparse(tensor),
        CompressionFormat::Quantized8 => self.compress_quantized_8(tensor),
        CompressionFormat::Quantized16 => self.compress_quantized_16(tensor),
        CompressionFormat::RLE => self.compress_rle(tensor),
        CompressionFormat::Huffman => self.compress_huffman(tensor),
        CompressionFormat::LZ4 => self.compress_lz4(tensor),
    }?;

    let mut stats = self.stats.lock().unwrap();
    stats.compressions += 1;
    stats.bytes_saved += encoded.memory_savings();
    *stats.format_usage.entry(format).or_default() += 1;

    // Incremental mean: rebuild the previous sum, add the new ratio, divide by the new count.
    let previous_sum = stats.avg_compression_ratio * (stats.compressions - 1) as f32;
    let sample_count = stats.compressions as f32;
    stats.avg_compression_ratio = (previous_sum + encoded.compression_ratio) / sample_count;

    Ok(encoded)
}
/// Decodes `compressed` back into a tensor using the decoder that matches
/// `compressed.format`, then counts the decompression in the statistics.
///
/// # Errors
///
/// Propagates any error from the decoder; the counter is not incremented in
/// that case.
///
/// # Panics
///
/// Panics if the statistics mutex is poisoned.
pub fn decompress(&self, compressed: &CompressedTensor) -> RusTorchResult<F32Tensor> {
    let restored = match compressed.format {
        CompressionFormat::None => self.decompress_none(compressed),
        CompressionFormat::Sparse => self
            .decompress_sparse(compressed)
            .and_then(|intermediate| self.sparse_to_dense(&intermediate)),
        CompressionFormat::Quantized8 => self.decompress_quantized_8(compressed),
        CompressionFormat::Quantized16 => self.decompress_quantized_16(compressed),
        CompressionFormat::RLE => self.decompress_rle(compressed),
        CompressionFormat::Huffman => self.decompress_huffman(compressed),
        CompressionFormat::LZ4 => self.decompress_lz4(compressed),
    }?;

    self.stats.lock().unwrap().decompressions += 1;
    Ok(restored)
}
/// Picks the format used by `compress` when `auto_select` is enabled.
///
/// * `Sparse` when the fraction of near-zero elements exceeds
///   `config.sparsity_threshold`;
/// * otherwise `Quantized8` when the dense size (4 bytes per element) is
///   below 1 MiB;
/// * otherwise `None`.
///
/// Only formats with a working codec are returned. Tensors above 10 MiB used
/// to be routed to `LZ4`, whose encoder is not implemented, so `compress`
/// could never succeed on them; they are stored uncompressed instead.
fn select_best_format(&self, tensor: &F32Tensor) -> CompressionFormat {
    const QUANTIZE_BELOW_BYTES: usize = 1024 * 1024;

    let sparsity = self.calculate_sparsity(tensor);
    let size = tensor.numel() * 4;

    if sparsity > self.config.sparsity_threshold {
        CompressionFormat::Sparse
    } else if size < QUANTIZE_BELOW_BYTES {
        CompressionFormat::Quantized8
    } else {
        // TODO: return `CompressionFormat::LZ4` for very large tensors once its codec exists.
        CompressionFormat::None
    }
}
/// Returns the fraction of elements whose absolute value is below `1e-6`.
///
/// An empty tensor reports `0.0`; dividing by its zero element count would
/// otherwise produce NaN.
fn calculate_sparsity(&self, tensor: &F32Tensor) -> f32 {
    let total = tensor.numel();
    if total == 0 {
        return 0.0;
    }
    let threshold = 1e-6;
    let zero_count = tensor.data.iter().filter(|&&x| x.abs() < threshold).count();
    zero_count as f32 / total as f32
}
/// Stores `tensor` uncompressed: every element as 4 little-endian bytes, in
/// `tensor.data`'s iteration order, with a compression ratio of `1.0`.
///
/// Elements are read through `data.iter()`, as the other functions in this
/// file do, rather than `as_slice().unwrap()`, which panicked whenever
/// `as_slice()` returned `None`.
/// NOTE(review): presumably that happens for non-contiguous storage — confirm
/// against `F32Tensor`.
fn compress_none(&self, tensor: &F32Tensor) -> RusTorchResult<CompressedTensor> {
    let data: Vec<u8> = tensor
        .data
        .iter()
        .flat_map(|&f| f.to_le_bytes())
        .collect();
    Ok(CompressedTensor {
        data,
        shape: tensor.shape().to_vec(),
        format: CompressionFormat::None,
        compression_ratio: 1.0,
        metadata: HashMap::new(),
    })
}
fn decompress_none(&self, compressed: &CompressedTensor) -> RusTorchResult<F32Tensor> {
let float_data: Vec<f32> = compressed
.data
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
F32Tensor::new(float_data, &compressed.shape)
}
fn compress_sparse(&self, tensor: &F32Tensor) -> RusTorchResult<CompressedTensor> {
let sparse = SparseTensor::from_dense(tensor, 1e-6)?;
let mut data = Vec::new();
data.extend_from_slice(&(sparse.indices.len() as u32).to_le_bytes());
for idx_vec in &sparse.indices {
data.extend_from_slice(&(idx_vec.len() as u32).to_le_bytes());
for &idx in idx_vec {
data.extend_from_slice(&(idx as u32).to_le_bytes());
}
}
for &value in &sparse.values {
data.extend_from_slice(&value.to_le_bytes());
}
let original_size = tensor.numel() * 4;
let compression_ratio = data.len() as f32 / original_size as f32;
let mut metadata = HashMap::new();
metadata.insert("density".to_string(), sparse.density.to_string());
Ok(CompressedTensor {
data,
shape: tensor.shape().to_vec(),
format: CompressionFormat::Sparse,
compression_ratio,
metadata,
})
}
fn decompress_sparse(&self, compressed: &CompressedTensor) -> RusTorchResult<CompressedTensor> {
Ok(compressed.clone())
}
fn sparse_to_dense(&self, sparse: &CompressedTensor) -> RusTorchResult<F32Tensor> {
F32Tensor::zeros(&sparse.shape)
}
fn compress_quantized_8(&self, tensor: &F32Tensor) -> RusTorchResult<CompressedTensor> {
let quantized = QuantizedTensor::quantize_8bit(tensor);
let mut data = Vec::new();
data.extend_from_slice(&quantized.scale.to_le_bytes());
data.push(quantized.zero_point);
data.extend_from_slice(&quantized.quantized_data);
let original_size = tensor.numel() * 4;
let compression_ratio = data.len() as f32 / original_size as f32;
let mut metadata = HashMap::new();
metadata.insert("scale".to_string(), quantized.scale.to_string());
metadata.insert("zero_point".to_string(), quantized.zero_point.to_string());
Ok(CompressedTensor {
data,
shape: tensor.shape().to_vec(),
format: CompressionFormat::Quantized8,
compression_ratio,
metadata,
})
}
/// Decodes an 8-bit quantized payload: a 4-byte little-endian `f32` scale,
/// a 1-byte zero point, then one code per element. Each code `q` becomes
/// `(q - zero_point) * scale`.
///
/// # Errors
///
/// Returns an error when the payload is shorter than the 5-byte header, and
/// propagates any error from `F32Tensor::new`.
fn decompress_quantized_8(&self, compressed: &CompressedTensor) -> RusTorchResult<F32Tensor> {
    let (scale, zero_point, codes) = match compressed.data.as_slice() {
        [s0, s1, s2, s3, zp, codes @ ..] => {
            (f32::from_le_bytes([*s0, *s1, *s2, *s3]), *zp, codes)
        }
        _ => {
            return Err(crate::error::RusTorchError::tensor_op(
                "Invalid quantized data",
            ));
        }
    };

    let mut restored: Vec<f32> = Vec::with_capacity(codes.len());
    for &code in codes {
        restored.push((code as f32 - zero_point as f32) * scale);
    }
    F32Tensor::new(restored, &compressed.shape)
}
fn compress_quantized_16(&self, _tensor: &F32Tensor) -> RusTorchResult<CompressedTensor> {
todo!("16-bit quantization implementation")
}
fn decompress_quantized_16(&self, _compressed: &CompressedTensor) -> RusTorchResult<F32Tensor> {
todo!("16-bit quantization decompression implementation")
}
fn compress_rle(&self, _tensor: &F32Tensor) -> RusTorchResult<CompressedTensor> {
todo!("RLE compression implementation")
}
fn decompress_rle(&self, _compressed: &CompressedTensor) -> RusTorchResult<F32Tensor> {
todo!("RLE decompression implementation")
}
fn compress_huffman(&self, _tensor: &F32Tensor) -> RusTorchResult<CompressedTensor> {
todo!("Huffman compression implementation")
}
fn decompress_huffman(&self, _compressed: &CompressedTensor) -> RusTorchResult<F32Tensor> {
todo!("Huffman decompression implementation")
}
fn compress_lz4(&self, _tensor: &F32Tensor) -> RusTorchResult<CompressedTensor> {
todo!("LZ4 compression implementation")
}
fn decompress_lz4(&self, _compressed: &CompressedTensor) -> RusTorchResult<F32Tensor> {
todo!("LZ4 decompression implementation")
}
/// Returns a snapshot (clone) of the current statistics.
///
/// # Panics
///
/// Panics if the statistics mutex is poisoned.
pub fn stats(&self) -> CompressionStats {
    let guard = self.stats.lock().unwrap();
    (*guard).clone()
}
pub fn config(&self) -> &CompressionConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An all-zero tensor must yield a sparse view with density below 1.
    #[test]
    fn test_sparse_tensor() {
        let dense = F32Tensor::zeros(&[3, 3]).unwrap();
        let density = SparseTensor::from_dense(&dense, 1e-6).unwrap().density;
        assert!(density < 1.0);
    }

    /// Quantizing and dequantizing must preserve the tensor shape.
    #[test]
    fn test_quantization() {
        let original = F32Tensor::randn(&[2, 2]).unwrap();
        let restored = QuantizedTensor::quantize_8bit(&original)
            .dequantize()
            .unwrap();
        assert_eq!(original.shape(), restored.shape());
    }

    /// A compress/decompress round trip must preserve the shape and be
    /// counted once in each statistic.
    #[test]
    fn test_compression_engine() {
        let engine = CompressionEngine::with_default_config();
        let original = F32Tensor::zeros(&[10, 10]).unwrap();

        let packed = engine.compress(&original).unwrap();
        let restored = engine.decompress(&packed).unwrap();
        assert_eq!(original.shape(), restored.shape());

        let snapshot = engine.stats();
        assert_eq!(snapshot.compressions, 1);
        assert_eq!(snapshot.decompressions, 1);
    }
}