// god_graph/transformer/optimization/tensor_ring.rs
//! Tensor Ring Compression for LLM weights
//!
//! This module implements tensor ring compression for LLM weight matrices,
//! achieving parameter reduction while maintaining accuracy.
//!
//! ## Mathematical Foundation
//!
//! Tensor Ring decomposition represents a high-dimensional tensor as a ring of
//! 3D core tensors:
//!
//! W(i₁,...,iₙ) = Σ Tr[G₁(i₁) × G₂(i₂) × ... × Gₙ(iₙ)]
//!
//! where Gₖ(iₖ) ∈ R^(rₖ₋₁×rₖ) and rₖ are the TR ranks controlling compression.
//!
//! ## Compression Ratio
//!
//! For a weight matrix W ∈ R^(m×n) with TR ranks [r₀, r₁, r₂]:
//! - Original parameters: m × n
//! - TR parameters: r₀×m×r₁ + r₁×n×r₂
//! - Compression ratio: (m × n) / (r₀×m×r₁ + r₁×n×r₂)
//!
//! ## Example
//!
//! ```no_run
//! use god_graph::transformer::optimization::{TensorRingCompressor, CompressionConfig};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let config = CompressionConfig::new()
//!     .with_target_ranks(vec![32, 64])
//!     .with_layers(vec!["qkv".to_string(), "mlp".to_string()]);
//!
//! let compressor = TensorRingCompressor::new(config);
//!
//! // Compress a weight matrix
//! // let compressed_graph = compressor.compress_graph(&graph)?;
//!
//! // Query compression ratio
//! // println!("Compression ratio: {:.2}x", compressor.compression_ratio());
//! # Ok(())
//! # }
//! ```

43use crate::errors::{GraphError, GraphResult};
44use crate::graph::Graph;
45use crate::tensor::decomposition::tensor_ring::TensorRing;
46use crate::tensor::DenseTensor;
47use crate::tensor::TensorBase;
48use crate::transformer::optimization::switch::{OperatorType, WeightTensor};
49use std::cell::RefCell;
50use std::collections::HashMap;
51
/// Configuration for tensor ring compression
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Target TR ranks for each dimension
    pub target_ranks: Vec<usize>,
    /// Layers to compress (by name pattern)
    pub layers: Vec<String>,
    /// Minimum rank (lower bound)
    pub min_rank: usize,
    /// Maximum rank (upper bound)
    pub max_rank: usize,
    /// Target compression ratio (adaptive rank selection)
    pub target_ratio: Option<f64>,
}

impl CompressionConfig {
    /// Build a config with default settings: target rank 64, every layer
    /// matched, rank bounds [16, 256], and no target compression ratio.
    pub fn new() -> Self {
        CompressionConfig {
            target_ranks: vec![64],
            // ".*" is treated as a wildcard by `matches_layer`.
            layers: vec![".*".to_string()],
            min_rank: 16,
            max_rank: 256,
            target_ratio: None,
        }
    }

    /// Replace the target TR ranks (builder style).
    pub fn with_target_ranks(mut self, ranks: Vec<usize>) -> Self {
        self.target_ranks = ranks;
        self
    }

    /// Replace the set of layer name patterns to compress (builder style).
    pub fn with_layers(mut self, layers: Vec<String>) -> Self {
        self.layers = layers;
        self
    }

    /// Set the lower bound on selected ranks (builder style).
    pub fn with_min_rank(mut self, rank: usize) -> Self {
        self.min_rank = rank;
        self
    }

    /// Set the upper bound on selected ranks (builder style).
    pub fn with_max_rank(mut self, rank: usize) -> Self {
        self.max_rank = rank;
        self
    }

    /// Set the target compression ratio, clamped to the [1.5, 10.0] range.
    pub fn with_target_ratio(mut self, ratio: f64) -> Self {
        let bounded = ratio.clamp(1.5, 10.0);
        self.target_ratio = Some(bounded);
        self
    }

    /// Whether `layer_name` matches any configured pattern.
    ///
    /// The literal pattern ".*" matches every layer; any other pattern
    /// matches by plain substring containment (not regex).
    pub fn matches_layer(&self, layer_name: &str) -> bool {
        self.layers
            .iter()
            .any(|pattern| pattern == ".*" || layer_name.contains(pattern.as_str()))
    }
}

impl Default for CompressionConfig {
    fn default() -> Self {
        CompressionConfig::new()
    }
}
126
/// Tensor Ring compressor for LLM weights
///
/// Holds the compression configuration plus interior-mutable bookkeeping so
/// that `compress_graph` / `compress_weight` can record statistics through a
/// shared `&self` reference.
pub struct TensorRingCompressor {
    // Rank selection and layer-matching settings.
    config: CompressionConfig,
    // Compressed tensors keyed by layer name, filled by compress_graph /
    // compress_weight.
    compressed_tensors: RefCell<HashMap<String, TensorRing>>,
    // Total parameter count before compression (see compression_ratio()).
    original_params: RefCell<usize>,
    // Total parameter count after compression (see compression_ratio()).
    compressed_params: RefCell<usize>,
}
134
135impl TensorRingCompressor {
136    /// Create a new tensor ring compressor
137    pub fn new(config: CompressionConfig) -> Self {
138        Self {
139            config,
140            compressed_tensors: RefCell::new(HashMap::new()),
141            original_params: RefCell::new(0),
142            compressed_params: RefCell::new(0),
143        }
144    }
145
146    /// Get the compression configuration
147    pub fn config(&self) -> &CompressionConfig {
148        &self.config
149    }
150
151    /// Compress a single tensor
152    ///
153    /// # Arguments
154    ///
155    /// * `tensor` - Weight tensor to compress
156    ///
157    /// # Returns
158    ///
159    /// TensorRing decomposition of the input tensor
160    pub fn decompose(&self, tensor: &DenseTensor) -> Result<TensorRing, crate::tensor::TensorError> {
161        use crate::tensor::decomposition::tensor_ring::compress_tensor_ring;
162
163        let shape = tensor.shape();
164
165        // Select rank based on tensor dimensions and config
166        let rank = self.select_rank(shape);
167
168        compress_tensor_ring(tensor, rank)
169    }
170
171    /// Reconstruct a tensor from its Tensor Ring decomposition
172    ///
173    /// # Arguments
174    ///
175    /// * `ring` - TensorRing decomposition to reconstruct
176    ///
177    /// # Returns
178    ///
179    /// Reconstructed dense tensor
180    pub fn reconstruct(&self, ring: &TensorRing) -> Result<DenseTensor, crate::tensor::TensorError> {
181        ring.reconstruct()
182    }
183
184    /// Compress all weights in a graph
185    ///
186    /// # Arguments
187    ///
188    /// * `graph` - Graph containing weights to compress
189    ///
190    /// # Returns
191    ///
192    /// Compression report with statistics
193    pub fn compress_graph(
194        &self,
195        graph: &Graph<OperatorType, WeightTensor>,
196    ) -> GraphResult<CompressionReport> {
197        use crate::graph::traits::GraphQuery;
198        use crate::tensor::decomposition::tensor_ring::compress_tensor_ring;
199
200        let mut total_original_params = 0usize;
201        let mut total_compressed_params = 0usize;
202        let mut layer_reports = Vec::new();
203        let mut compressed_map = HashMap::new();
204
205        // Calculate compression statistics for each edge weight
206        for edge_ref in graph.edges() {
207            let weight = edge_ref.data();
208
209            // Decompose into Tensor Ring format
210            let weight_tensor = DenseTensor::new(
211                weight.data.to_vec(),
212                weight.shape.to_vec(),
213            );
214
215            // Select rank based on config
216            let rank = self.select_rank(weight_tensor.shape());
217
218            let ring = compress_tensor_ring(&weight_tensor, rank)
219                .map_err(|e| GraphError::InvalidFormat(e.to_string()))?;
220
221            // Count parameters
222            let original_params = weight_tensor.shape().iter().product::<usize>();
223            let compressed_params = ring.cores.iter()
224                .map(|c| c.shape().iter().product::<usize>())
225                .sum::<usize>();
226
227            total_original_params += original_params;
228            total_compressed_params += compressed_params;
229
230            // Store compressed tensor
231            compressed_map.insert(weight.name.clone(), ring.clone());
232
233            layer_reports.push(LayerCompressionReport {
234                layer_name: weight.name.clone(),
235                original_params,
236                compressed_params,
237                compression_ratio: original_params as f64 / compressed_params as f64,
238                ranks: ring.ranks.clone(),
239            });
240        }
241
242        let overall_ratio = if total_compressed_params > 0 {
243            total_original_params as f64 / total_compressed_params as f64
244        } else {
245            1.0
246        };
247
248        // Store compressed tensors and statistics
249        *self.compressed_tensors.borrow_mut() = compressed_map;
250        *self.original_params.borrow_mut() = total_original_params;
251        *self.compressed_params.borrow_mut() = total_compressed_params;
252
253        Ok(CompressionReport {
254            original_params: total_original_params,
255            compressed_params: total_compressed_params,
256            compression_ratio: overall_ratio,
257            layers: layer_reports,
258        })
259    }
260
261    /// Get the achieved compression ratio
262    pub fn compression_ratio(&self) -> f64 {
263        let compressed = *self.compressed_params.borrow();
264        if compressed == 0 {
265            return 1.0;
266        }
267        let original = *self.original_params.borrow();
268        original as f64 / compressed as f64
269    }
270
271    /// Get the number of original parameters
272    pub fn original_params(&self) -> usize {
273        *self.original_params.borrow()
274    }
275
276    /// Get the number of compressed parameters
277    pub fn compressed_params(&self) -> usize {
278        *self.compressed_params.borrow()
279    }
280
281    /// Get compressed tensors
282    pub fn compressed_tensors(&self) -> std::cell::Ref<'_, HashMap<String, TensorRing>> {
283        self.compressed_tensors.borrow()
284    }
285
286    /// Select optimal rank for a tensor based on config
287    fn select_rank(&self, shape: &[usize]) -> usize {
288        // Simple heuristic: use config rank or adapt based on dimensions
289        let min_dim = shape.iter().min().copied().unwrap_or(1024);
290        
291        let base_rank = self.config.target_ranks.first().copied().unwrap_or(64);
292        
293        // Clamp to min/max bounds
294        base_rank
295            .max(self.config.min_rank)
296            .min(self.config.max_rank)
297            .min(min_dim / 2)
298    }
299
300    /// Compress a weight tensor and store the result
301    #[allow(dead_code)]
302    fn compress_weight(
303        &self,
304        name: &str,
305        tensor: &DenseTensor,
306    ) -> Result<TensorRing, crate::tensor::TensorError> {
307        use crate::tensor::decomposition::tensor_ring::compress_tensor_ring;
308
309        let rank = self.select_rank(tensor.shape());
310        let ring = compress_tensor_ring(tensor, rank)?;
311
312        // Update parameter counts
313        let original = tensor.shape().iter().product::<usize>();
314        let compressed = ring
315            .cores
316            .iter()
317            .map(|c| c.shape().iter().product::<usize>())
318            .sum::<usize>();
319
320        *self.original_params.borrow_mut() += original;
321        *self.compressed_params.borrow_mut() += compressed;
322
323        // Store compressed tensor
324        self.compressed_tensors.borrow_mut().insert(name.to_string(), ring.clone());
325
326        Ok(ring)
327    }
328}
329
330impl Default for TensorRingCompressor {
331    fn default() -> Self {
332        Self::new(CompressionConfig::new())
333    }
334}
335
336/// Adaptive rank selection based on singular value decay
337///
338/// # Arguments
339///
340/// * `tensor` - Weight tensor to analyze
341/// * `energy_threshold` - Fraction of energy to preserve (e.g., 0.99)
342///
343/// # Returns
344///
345/// Recommended rank for compression
346pub fn adaptive_rank_selection(
347    tensor: &DenseTensor,
348    energy_threshold: f64,
349) -> Result<usize, crate::tensor::TensorError> {
350    use crate::tensor::decomposition::svd_decompose;
351
352    let shape = tensor.shape();
353    let min_dim = shape.iter().min().copied().unwrap_or(1);
354    
355    // Compute SVD
356    let (_, s, _) = svd_decompose(tensor, Some(min_dim))?;
357    
358    // Calculate cumulative energy
359    let s_data = s.data();
360    let total_energy: f64 = s_data.iter().map(|x| x * x).sum();
361    let threshold = total_energy * energy_threshold;
362    
363    let mut cumulative_energy = 0.0;
364    for (i, &sigma) in s_data.iter().enumerate() {
365        cumulative_energy += sigma * sigma;
366        if cumulative_energy >= threshold {
367            return Ok(i + 1);
368        }
369    }
370    
371    Ok(min_dim)
372}
373
374/// Mixed precision compression strategy
375///
376/// Compresses different layers with different ranks based on importance.
377///
378/// # Arguments
379///
380/// * `tensors` - Map of layer names to weight tensors
381/// * `base_rank` - Base compression rank
382/// * `importance_map` - Optional importance scores for each layer
383///
384/// # Returns
385///
386/// Map of layer names to TensorRing decompositions
387pub fn mixed_precision_compress(
388    tensors: &HashMap<String, DenseTensor>,
389    base_rank: usize,
390    importance_map: Option<&HashMap<String, f64>>,
391) -> Result<HashMap<String, TensorRing>, crate::tensor::TensorError> {
392    use crate::tensor::decomposition::tensor_ring::compress_tensor_ring;
393
394    let mut results = HashMap::new();
395    
396    for (name, tensor) in tensors {
397        // Adjust rank based on importance
398        let importance = importance_map
399            .and_then(|m| m.get(name))
400            .copied()
401            .unwrap_or(1.0);
402        
403        // Higher importance → higher rank
404        let rank = (base_rank as f64 * importance).ceil() as usize;
405        
406        let ring = compress_tensor_ring(tensor, rank)?;
407        results.insert(name.clone(), ring);
408    }
409    
410    Ok(results)
411}
412
/// Compression report for a single layer
///
/// One entry per edge weight processed by `TensorRingCompressor::compress_graph`.
#[derive(Debug, Clone)]
pub struct LayerCompressionReport {
    /// Layer name
    pub layer_name: String,
    /// Original parameter count (product of the weight's shape)
    pub original_params: usize,
    /// Compressed parameter count (sum over all TR core sizes)
    pub compressed_params: usize,
    /// Compression ratio (original / compressed; > 1.0 means the layer shrank)
    pub compression_ratio: f64,
    /// TR ranks used
    pub ranks: Vec<usize>,
}
427
/// Overall compression report
///
/// Aggregated result of `TensorRingCompressor::compress_graph`.
#[derive(Debug, Clone)]
pub struct CompressionReport {
    /// Total original parameters across all layers
    pub original_params: usize,
    /// Total compressed parameters across all layers
    pub compressed_params: usize,
    /// Overall compression ratio (original / compressed; 1.0 when nothing was compressed)
    pub compression_ratio: f64,
    /// Per-layer reports
    pub layers: Vec<LayerCompressionReport>,
}
440
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::traits::TensorOps;

    // Builder methods and substring-based layer matching.
    #[test]
    fn test_compression_config() {
        let config = CompressionConfig::new()
            .with_target_ranks(vec![32, 64])
            .with_layers(vec!["qkv".to_string(), "mlp".to_string()])
            .with_min_rank(16)
            .with_max_rank(128);

        assert!(config.matches_layer("model.layers.0.qkv.weight"));
        assert!(config.matches_layer("model.layers.0.mlp.gate_proj"));
        assert!(!config.matches_layer("model.norm.weight"));
    }

    #[test]
    fn test_tensor_ring_compressor() {
        // Use a smaller rank to achieve actual compression
        // For 64x64 matrix, we need rank < sqrt(64*64) = 64 for compression
        // TR params = 2 * rank * 64 * rank = 128 * rank^2
        // Original = 4096
        // For compression: 128 * rank^2 < 4096 => rank^2 < 32 => rank < 6
        let config = CompressionConfig::new()
            .with_target_ranks(vec![4])
            .with_min_rank(2)
            .with_max_rank(8);
        let compressor = TensorRingCompressor::new(config);

        let tensor = DenseTensor::from_vec(
            vec![1.0; 64 * 64],
            vec![64, 64],
        );

        let ring = compressor.decompose(&tensor).unwrap();

        // Diagnostics kept on stderr so `cargo test -- --nocapture` shows them.
        eprintln!("Original shape: {:?}", ring.original_shape);
        eprintln!("Ranks: {:?}", ring.ranks);
        eprintln!("Core shapes: {:?}", ring.cores.iter().map(|c| c.shape()).collect::<Vec<_>>());
        eprintln!("Compression ratio: {}", ring.compression_ratio());

        assert!(ring.compression_ratio() > 1.0, "Compression ratio should be > 1.0, got {}", ring.compression_ratio());
    }

    // A product of 100x5 and 5x50 factors has intrinsic rank <= 5, so the
    // energy-based selection should pick a small rank.
    #[test]
    fn test_adaptive_rank_selection() {
        // Create a low-rank tensor (rank 5)
        let u = DenseTensor::from_vec(
            (0..100 * 5).map(|i| (i % 10) as f64 / 10.0).collect(),
            vec![100, 5],
        );
        let v = DenseTensor::from_vec(
            (0..5 * 50).map(|i| (i % 7) as f64 / 10.0).collect(),
            vec![5, 50],
        );
        let tensor = u.matmul(&v);

        let rank = adaptive_rank_selection(&tensor, 0.99).unwrap();
        assert!(rank <= 10); // Should detect low intrinsic rank
    }

    // Private helper: decomposes, stores under the given name, and returns
    // the ring with the original shape preserved.
    #[test]
    fn test_compress_weight() {
        let config = CompressionConfig::new()
            .with_target_ranks(vec![4])
            .with_min_rank(2)
            .with_max_rank(8);
        let compressor = TensorRingCompressor::new(config);

        let tensor = DenseTensor::from_vec(
            vec![1.0; 16 * 16],
            vec![16, 16],
        );

        let ring = compressor.compress_weight("test_weight", &tensor).unwrap();

        assert_eq!(ring.original_shape, vec![16, 16]);
        assert!(!ring.cores.is_empty());
    }

    #[test]
    fn test_compression_ratio() {
        let config = CompressionConfig::new()
            .with_target_ranks(vec![4])
            .with_min_rank(2)
            .with_max_rank(8);
        let compressor = TensorRingCompressor::new(config);

        let tensor = DenseTensor::from_vec(
            vec![1.0; 32 * 32],
            vec![32, 32],
        );

        let ring = compressor.decompose(&tensor).unwrap();

        // Verify compression ratio is calculated correctly
        let ratio = ring.compression_ratio();
        assert!(ratio > 0.0);
    }

    // Round-trip: reconstruct() must produce a tensor of the original shape.
    #[test]
    fn test_reconstruct_tensor() {
        let config = CompressionConfig::new()
            .with_target_ranks(vec![4])
            .with_min_rank(2)
            .with_max_rank(8);
        let compressor = TensorRingCompressor::new(config);

        let tensor = DenseTensor::from_vec(
            vec![1.0; 8 * 8],
            vec![8, 8],
        );

        let ring = compressor.decompose(&tensor).unwrap();
        let reconstructed = ring.reconstruct().unwrap();

        // Check shapes match
        assert_eq!(reconstructed.shape(), tensor.shape());
    }
}