use crate::errors::{GraphError, GraphResult};
use crate::graph::Graph;
use crate::tensor::decomposition::tensor_ring::TensorRing;
use crate::tensor::DenseTensor;
use crate::tensor::TensorBase;
use crate::transformer::optimization::switch::{OperatorType, WeightTensor};
use std::cell::RefCell;
use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct CompressionConfig {
pub target_ranks: Vec<usize>,
pub layers: Vec<String>,
pub min_rank: usize,
pub max_rank: usize,
pub target_ratio: Option<f64>,
}
impl CompressionConfig {
pub fn new() -> Self {
Self {
target_ranks: vec![64],
layers: vec![".*".to_string()], min_rank: 16,
max_rank: 256,
target_ratio: None,
}
}
pub fn with_target_ranks(mut self, ranks: Vec<usize>) -> Self {
self.target_ranks = ranks;
self
}
pub fn with_layers(mut self, layers: Vec<String>) -> Self {
self.layers = layers;
self
}
pub fn with_min_rank(mut self, rank: usize) -> Self {
self.min_rank = rank;
self
}
pub fn with_max_rank(mut self, rank: usize) -> Self {
self.max_rank = rank;
self
}
pub fn with_target_ratio(mut self, ratio: f64) -> Self {
self.target_ratio = Some(ratio.clamp(1.5, 10.0));
self
}
pub fn matches_layer(&self, layer_name: &str) -> bool {
self.layers.iter().any(|pattern| {
if pattern == ".*" {
true
} else {
layer_name.contains(pattern)
}
})
}
}
impl Default for CompressionConfig {
fn default() -> Self {
Self::new()
}
}
/// Compresses weight tensors into tensor-ring form and records statistics.
///
/// The results and parameter counters are kept in `RefCell`s so that they can
/// be updated through a shared `&self` reference.
pub struct TensorRingCompressor {
    /// Settings used for rank selection and layer filtering.
    config: CompressionConfig,
    /// Compressed rings keyed by weight name.
    compressed_tensors: RefCell<HashMap<String, TensorRing>>,
    /// Number of parameters in the tensors that were compressed.
    original_params: RefCell<usize>,
    /// Number of parameters held by the resulting ring cores.
    compressed_params: RefCell<usize>,
}

impl TensorRingCompressor {
    /// Creates a compressor with empty results and zeroed counters.
    pub fn new(config: CompressionConfig) -> Self {
        Self {
            config,
            compressed_tensors: RefCell::new(HashMap::new()),
            original_params: RefCell::new(0),
            compressed_params: RefCell::new(0),
        }
    }

    /// Returns the configuration this compressor was built with.
    pub fn config(&self) -> &CompressionConfig {
        &self.config
    }

    /// Decomposes `tensor` into a tensor ring using the rank chosen by
    /// `select_rank`. Does not touch the stored statistics.
    ///
    /// # Errors
    /// Propagates any error returned by `compress_tensor_ring`.
    pub fn decompose(&self, tensor: &DenseTensor) -> Result<TensorRing, crate::tensor::TensorError> {
        use crate::tensor::decomposition::tensor_ring::compress_tensor_ring;

        let rank = self.select_rank(tensor.shape());
        compress_tensor_ring(tensor, rank)
    }

    /// Rebuilds a dense tensor from `ring` by delegating to `TensorRing::reconstruct`.
    ///
    /// # Errors
    /// Propagates any error returned by the reconstruction.
    pub fn reconstruct(&self, ring: &TensorRing) -> Result<DenseTensor, crate::tensor::TensorError> {
        ring.reconstruct()
    }

    /// Compresses the weight attached to every edge of `graph` whose name is
    /// selected by the configuration's layer patterns.
    ///
    /// On success the stored rings and counters are replaced by the results of
    /// this run, and a per-layer report is returned. Weights whose name does
    /// not satisfy `CompressionConfig::matches_layer` are skipped and appear
    /// neither in the report nor in the totals. With the default pattern
    /// (`".*"`) every weight is processed.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidFormat` carrying the decomposition error's
    /// message if any weight fails to compress. In that case the previously
    /// stored state is left untouched.
    ///
    /// # Panics
    /// Panics if a guard obtained from [`Self::compressed_tensors`] is still
    /// alive, because the stored map is then already borrowed.
    pub fn compress_graph(
        &self,
        graph: &Graph<OperatorType, WeightTensor>,
    ) -> GraphResult<CompressionReport> {
        use crate::graph::traits::GraphQuery;
        use crate::tensor::decomposition::tensor_ring::compress_tensor_ring;

        let mut total_original_params = 0usize;
        let mut total_compressed_params = 0usize;
        let mut layer_reports = Vec::new();
        let mut compressed_map = HashMap::new();

        for edge_ref in graph.edges() {
            let weight = edge_ref.data();
            // Honour the configured layer filter; previously it was ignored here.
            if !self.config.matches_layer(&weight.name) {
                continue;
            }

            let weight_tensor = DenseTensor::new(
                weight.data.to_vec(),
                weight.shape.to_vec(),
            );
            let rank = self.select_rank(weight_tensor.shape());
            let ring = compress_tensor_ring(&weight_tensor, rank)
                .map_err(|e| GraphError::InvalidFormat(e.to_string()))?;

            let original_params = weight_tensor.shape().iter().product::<usize>();
            let compressed_params = Self::ring_param_count(&ring);
            total_original_params += original_params;
            total_compressed_params += compressed_params;

            layer_reports.push(LayerCompressionReport {
                layer_name: weight.name.clone(),
                original_params,
                compressed_params,
                compression_ratio: Self::params_ratio(original_params, compressed_params),
                ranks: ring.ranks.clone(),
            });
            // The ring is moved into the map last so no deep copy is needed.
            compressed_map.insert(weight.name.clone(), ring);
        }

        let overall_ratio = Self::params_ratio(total_original_params, total_compressed_params);

        // State is only replaced once every selected weight compressed successfully.
        *self.compressed_tensors.borrow_mut() = compressed_map;
        *self.original_params.borrow_mut() = total_original_params;
        *self.compressed_params.borrow_mut() = total_compressed_params;

        Ok(CompressionReport {
            original_params: total_original_params,
            compressed_params: total_compressed_params,
            compression_ratio: overall_ratio,
            layers: layer_reports,
        })
    }

    /// Returns stored original parameters divided by stored compressed
    /// parameters, or `1.0` when nothing has been compressed yet.
    pub fn compression_ratio(&self) -> f64 {
        Self::params_ratio(
            *self.original_params.borrow(),
            *self.compressed_params.borrow(),
        )
    }

    /// Returns the stored count of original parameters.
    pub fn original_params(&self) -> usize {
        *self.original_params.borrow()
    }

    /// Returns the stored count of compressed parameters.
    pub fn compressed_params(&self) -> usize {
        *self.compressed_params.borrow()
    }

    /// Borrows the map of compressed rings keyed by weight name.
    ///
    /// The returned guard must be dropped before calling a method that updates
    /// the map, otherwise that call panics.
    pub fn compressed_tensors(&self) -> std::cell::Ref<'_, HashMap<String, TensorRing>> {
        self.compressed_tensors.borrow()
    }

    /// Chooses the ring rank for a tensor of the given `shape`.
    ///
    /// Starts from the first configured target rank (64 if none), raises it to
    /// `min_rank`, caps it at `max_rank` and at half of the smallest dimension,
    /// and finally floors the result at 1 so that a degenerate rank of 0 is
    /// never produced (e.g. for a dimension of size 1).
    fn select_rank(&self, shape: &[usize]) -> usize {
        let min_dim = shape.iter().min().copied().unwrap_or(1024);
        let base_rank = self.config.target_ranks.first().copied().unwrap_or(64);
        base_rank
            .max(self.config.min_rank)
            .min(self.config.max_rank)
            .min(min_dim / 2)
            .max(1)
    }

    /// Compresses a single named tensor, adds its parameter counts to the
    /// running totals and stores the ring under `name`.
    ///
    /// # Errors
    /// Propagates any error returned by `compress_tensor_ring`; nothing is
    /// recorded in that case.
    // Within this file the method is only called from the unit tests.
    #[allow(dead_code)]
    fn compress_weight(
        &self,
        name: &str,
        tensor: &DenseTensor,
    ) -> Result<TensorRing, crate::tensor::TensorError> {
        use crate::tensor::decomposition::tensor_ring::compress_tensor_ring;

        let rank = self.select_rank(tensor.shape());
        let ring = compress_tensor_ring(tensor, rank)?;

        let original = tensor.shape().iter().product::<usize>();
        let compressed = Self::ring_param_count(&ring);
        *self.original_params.borrow_mut() += original;
        *self.compressed_params.borrow_mut() += compressed;

        // One copy is stored, the other is handed back to the caller.
        self.compressed_tensors
            .borrow_mut()
            .insert(name.to_string(), ring.clone());
        Ok(ring)
    }

    /// Sums the element counts of all cores in `ring`.
    fn ring_param_count(ring: &TensorRing) -> usize {
        ring.cores
            .iter()
            .map(|c| c.shape().iter().product::<usize>())
            .sum::<usize>()
    }

    /// Returns `original / compressed`, or `1.0` when `compressed` is zero.
    fn params_ratio(original: usize, compressed: usize) -> f64 {
        if compressed == 0 {
            1.0
        } else {
            original as f64 / compressed as f64
        }
    }
}

impl Default for TensorRingCompressor {
    /// Builds a compressor from `CompressionConfig::new()`.
    fn default() -> Self {
        Self::new(CompressionConfig::new())
    }
}
/// Picks the smallest number of leading singular values whose cumulative
/// squared magnitude reaches `energy_threshold` times the total.
///
/// The singular values come from `svd_decompose`, which is asked for as many
/// values as the smallest dimension of `tensor` (1 if the shape is empty).
/// If the threshold is never reached, that smallest dimension is returned.
///
/// # Errors
/// Propagates any error returned by `svd_decompose`.
pub fn adaptive_rank_selection(
    tensor: &DenseTensor,
    energy_threshold: f64,
) -> Result<usize, crate::tensor::TensorError> {
    use crate::tensor::decomposition::svd_decompose;

    let smallest_dim = tensor.shape().iter().copied().min().unwrap_or(1);
    let (_, singular, _) = svd_decompose(tensor, Some(smallest_dim))?;
    let spectrum = singular.data();

    let energy_goal = spectrum.iter().map(|x| x * x).sum::<f64>() * energy_threshold;

    // Walk the spectrum in order, stopping at the first index where the
    // running energy meets the goal.
    let mut running = 0.0;
    let cutoff = spectrum.iter().position(|&sigma| {
        running += sigma * sigma;
        running >= energy_goal
    });

    Ok(cutoff.map_or(smallest_dim, |idx| idx + 1))
}
/// Compresses every tensor in `tensors`, scaling the ring rank per tensor by
/// an optional importance factor.
///
/// For each entry the rank is `ceil(base_rank * importance)`, where
/// `importance` is looked up by name in `importance_map` and defaults to `1.0`
/// when the map is absent or has no entry for that name. The rank is floored
/// at 1: a zero, negative or NaN importance (or a `base_rank` of 0) would
/// otherwise yield a degenerate rank of 0.
///
/// # Errors
/// Returns the first error produced by `compress_tensor_ring`; no partial
/// result is returned in that case.
pub fn mixed_precision_compress(
    tensors: &HashMap<String, DenseTensor>,
    base_rank: usize,
    importance_map: Option<&HashMap<String, f64>>,
) -> Result<HashMap<String, TensorRing>, crate::tensor::TensorError> {
    use crate::tensor::decomposition::tensor_ring::compress_tensor_ring;

    let mut results = HashMap::with_capacity(tensors.len());
    for (name, tensor) in tensors {
        let importance = importance_map
            .and_then(|m| m.get(name))
            .copied()
            .unwrap_or(1.0);
        // The float-to-usize cast saturates (negative and NaN become 0), so the
        // floor of 1 is what keeps the rank usable.
        let rank = ((base_rank as f64 * importance).ceil() as usize).max(1);
        let ring = compress_tensor_ring(tensor, rank)?;
        results.insert(name.clone(), ring);
    }
    Ok(results)
}
/// Compression figures for a single weight, as produced by
/// `TensorRingCompressor::compress_graph`.
#[derive(Debug, Clone)]
pub struct LayerCompressionReport {
/// Name of the weight these figures describe.
pub layer_name: String,
/// Element count of the original tensor (product of its shape).
pub original_params: usize,
/// Total element count across all ring cores.
pub compressed_params: usize,
/// Ratio of `original_params` to `compressed_params`.
pub compression_ratio: f64,
/// Ranks reported by the resulting tensor ring.
pub ranks: Vec<usize>,
}
/// Aggregate result of `TensorRingCompressor::compress_graph`.
#[derive(Debug, Clone)]
pub struct CompressionReport {
/// Sum of `original_params` over all reported layers.
pub original_params: usize,
/// Sum of `compressed_params` over all reported layers.
pub compressed_params: usize,
/// `original_params / compressed_params`, or `1.0` when the compressed total is zero.
pub compression_ratio: f64,
/// One entry per compressed weight.
pub layers: Vec<LayerCompressionReport>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::traits::TensorOps;

    /// Compressor with target rank 4 and rank bounds `[2, 8]`, shared by most tests.
    fn small_rank_compressor() -> TensorRingCompressor {
        TensorRingCompressor::new(
            CompressionConfig::new()
                .with_target_ranks(vec![4])
                .with_min_rank(2)
                .with_max_rank(8),
        )
    }

    /// Square `side x side` tensor filled with ones.
    fn ones_square(side: usize) -> DenseTensor {
        DenseTensor::from_vec(vec![1.0; side * side], vec![side, side])
    }

    /// Builder methods chain, and layer matching is substring based.
    #[test]
    fn test_compression_config() {
        let patterns = vec!["qkv".to_string(), "mlp".to_string()];
        let config = CompressionConfig::new()
            .with_target_ranks(vec![32, 64])
            .with_layers(patterns)
            .with_min_rank(16)
            .with_max_rank(128);

        assert!(config.matches_layer("model.layers.0.qkv.weight"));
        assert!(config.matches_layer("model.layers.0.mlp.gate_proj"));
        assert!(!config.matches_layer("model.norm.weight"));
    }

    /// A 64x64 tensor of ones must end up smaller than the original.
    #[test]
    fn test_tensor_ring_compressor() {
        let ring = small_rank_compressor()
            .decompose(&ones_square(64))
            .unwrap();

        let core_shapes = ring.cores.iter().map(|c| c.shape()).collect::<Vec<_>>();
        eprintln!("Original shape: {:?}", ring.original_shape);
        eprintln!("Ranks: {:?}", ring.ranks);
        eprintln!("Core shapes: {:?}", core_shapes);
        eprintln!("Compression ratio: {}", ring.compression_ratio());

        assert!(
            ring.compression_ratio() > 1.0,
            "Compression ratio should be > 1.0, got {}",
            ring.compression_ratio()
        );
    }

    /// The product of a 100x5 and a 5x50 factor should need only a small rank.
    #[test]
    fn test_adaptive_rank_selection() {
        let left = DenseTensor::from_vec(
            (0..100 * 5).map(|i| (i % 10) as f64 / 10.0).collect(),
            vec![100, 5],
        );
        let right = DenseTensor::from_vec(
            (0..5 * 50).map(|i| (i % 7) as f64 / 10.0).collect(),
            vec![5, 50],
        );
        let low_rank = left.matmul(&right);

        let rank = adaptive_rank_selection(&low_rank, 0.99).unwrap();
        assert!(rank <= 10);
    }

    /// `compress_weight` keeps the original shape and yields at least one core.
    #[test]
    fn test_compress_weight() {
        let compressor = small_rank_compressor();
        let input = ones_square(16);

        let ring = compressor.compress_weight("test_weight", &input).unwrap();
        assert_eq!(ring.original_shape, vec![16, 16]);
        assert!(!ring.cores.is_empty());
    }

    /// The ratio reported by a ring is strictly positive.
    #[test]
    fn test_compression_ratio() {
        let ring = small_rank_compressor()
            .decompose(&ones_square(32))
            .unwrap();

        let ratio = ring.compression_ratio();
        assert!(ratio > 0.0);
    }

    /// Reconstruction restores the original shape.
    #[test]
    fn test_reconstruct_tensor() {
        let input = ones_square(8);
        let ring = small_rank_compressor().decompose(&input).unwrap();

        let reconstructed = ring.reconstruct().unwrap();
        assert_eq!(reconstructed.shape(), input.shape());
    }
}