god_graph/transformer/optimization/
tensor_ring.rs1use crate::errors::{GraphError, GraphResult};
44use crate::graph::Graph;
45use crate::tensor::decomposition::tensor_ring::TensorRing;
46use crate::tensor::DenseTensor;
47use crate::tensor::TensorBase;
48use crate::transformer::optimization::switch::{OperatorType, WeightTensor};
49use std::cell::RefCell;
50use std::collections::HashMap;
51
/// Tunable settings for tensor-ring compression of model weights.
///
/// Values are normally assembled through the `with_*` builder methods on
/// this type, starting from [`CompressionConfig::new`].
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Requested ring ranks. The rank selection in this file reads only the
    /// first entry.
    pub target_ranks: Vec<usize>,
    /// Layer-name patterns. The literal `".*"` matches every layer; any
    /// other entry is treated as a plain substring, not a regular expression.
    pub layers: Vec<String>,
    /// Lower bound applied to the requested rank.
    pub min_rank: usize,
    /// Upper bound applied to the requested rank.
    pub max_rank: usize,
    /// Optional target compression ratio. No code visible in this file reads
    /// it yet.
    pub target_ratio: Option<f64>,
}
66
67impl CompressionConfig {
68 pub fn new() -> Self {
70 Self {
71 target_ranks: vec![64],
72 layers: vec![".*".to_string()], min_rank: 16,
74 max_rank: 256,
75 target_ratio: None,
76 }
77 }
78
79 pub fn with_target_ranks(mut self, ranks: Vec<usize>) -> Self {
81 self.target_ranks = ranks;
82 self
83 }
84
85 pub fn with_layers(mut self, layers: Vec<String>) -> Self {
87 self.layers = layers;
88 self
89 }
90
91 pub fn with_min_rank(mut self, rank: usize) -> Self {
93 self.min_rank = rank;
94 self
95 }
96
97 pub fn with_max_rank(mut self, rank: usize) -> Self {
99 self.max_rank = rank;
100 self
101 }
102
103 pub fn with_target_ratio(mut self, ratio: f64) -> Self {
105 self.target_ratio = Some(ratio.clamp(1.5, 10.0));
106 self
107 }
108
109 pub fn matches_layer(&self, layer_name: &str) -> bool {
111 self.layers.iter().any(|pattern| {
112 if pattern == ".*" {
113 true
114 } else {
115 layer_name.contains(pattern)
116 }
117 })
118 }
119}
120
121impl Default for CompressionConfig {
122 fn default() -> Self {
123 Self::new()
124 }
125}
126
127pub struct TensorRingCompressor {
129 config: CompressionConfig,
130 compressed_tensors: RefCell<HashMap<String, TensorRing>>,
131 original_params: RefCell<usize>,
132 compressed_params: RefCell<usize>,
133}
134
135impl TensorRingCompressor {
136 pub fn new(config: CompressionConfig) -> Self {
138 Self {
139 config,
140 compressed_tensors: RefCell::new(HashMap::new()),
141 original_params: RefCell::new(0),
142 compressed_params: RefCell::new(0),
143 }
144 }
145
146 pub fn config(&self) -> &CompressionConfig {
148 &self.config
149 }
150
151 pub fn decompose(&self, tensor: &DenseTensor) -> Result<TensorRing, crate::tensor::TensorError> {
161 use crate::tensor::decomposition::tensor_ring::compress_tensor_ring;
162
163 let shape = tensor.shape();
164
165 let rank = self.select_rank(shape);
167
168 compress_tensor_ring(tensor, rank)
169 }
170
171 pub fn reconstruct(&self, ring: &TensorRing) -> Result<DenseTensor, crate::tensor::TensorError> {
181 ring.reconstruct()
182 }
183
184 pub fn compress_graph(
194 &self,
195 graph: &Graph<OperatorType, WeightTensor>,
196 ) -> GraphResult<CompressionReport> {
197 use crate::graph::traits::GraphQuery;
198 use crate::tensor::decomposition::tensor_ring::compress_tensor_ring;
199
200 let mut total_original_params = 0usize;
201 let mut total_compressed_params = 0usize;
202 let mut layer_reports = Vec::new();
203 let mut compressed_map = HashMap::new();
204
205 for edge_ref in graph.edges() {
207 let weight = edge_ref.data();
208
209 let weight_tensor = DenseTensor::new(
211 weight.data.to_vec(),
212 weight.shape.to_vec(),
213 );
214
215 let rank = self.select_rank(weight_tensor.shape());
217
218 let ring = compress_tensor_ring(&weight_tensor, rank)
219 .map_err(|e| GraphError::InvalidFormat(e.to_string()))?;
220
221 let original_params = weight_tensor.shape().iter().product::<usize>();
223 let compressed_params = ring.cores.iter()
224 .map(|c| c.shape().iter().product::<usize>())
225 .sum::<usize>();
226
227 total_original_params += original_params;
228 total_compressed_params += compressed_params;
229
230 compressed_map.insert(weight.name.clone(), ring.clone());
232
233 layer_reports.push(LayerCompressionReport {
234 layer_name: weight.name.clone(),
235 original_params,
236 compressed_params,
237 compression_ratio: original_params as f64 / compressed_params as f64,
238 ranks: ring.ranks.clone(),
239 });
240 }
241
242 let overall_ratio = if total_compressed_params > 0 {
243 total_original_params as f64 / total_compressed_params as f64
244 } else {
245 1.0
246 };
247
248 *self.compressed_tensors.borrow_mut() = compressed_map;
250 *self.original_params.borrow_mut() = total_original_params;
251 *self.compressed_params.borrow_mut() = total_compressed_params;
252
253 Ok(CompressionReport {
254 original_params: total_original_params,
255 compressed_params: total_compressed_params,
256 compression_ratio: overall_ratio,
257 layers: layer_reports,
258 })
259 }
260
261 pub fn compression_ratio(&self) -> f64 {
263 let compressed = *self.compressed_params.borrow();
264 if compressed == 0 {
265 return 1.0;
266 }
267 let original = *self.original_params.borrow();
268 original as f64 / compressed as f64
269 }
270
271 pub fn original_params(&self) -> usize {
273 *self.original_params.borrow()
274 }
275
276 pub fn compressed_params(&self) -> usize {
278 *self.compressed_params.borrow()
279 }
280
281 pub fn compressed_tensors(&self) -> std::cell::Ref<'_, HashMap<String, TensorRing>> {
283 self.compressed_tensors.borrow()
284 }
285
286 fn select_rank(&self, shape: &[usize]) -> usize {
288 let min_dim = shape.iter().min().copied().unwrap_or(1024);
290
291 let base_rank = self.config.target_ranks.first().copied().unwrap_or(64);
292
293 base_rank
295 .max(self.config.min_rank)
296 .min(self.config.max_rank)
297 .min(min_dim / 2)
298 }
299
300 #[allow(dead_code)]
302 fn compress_weight(
303 &self,
304 name: &str,
305 tensor: &DenseTensor,
306 ) -> Result<TensorRing, crate::tensor::TensorError> {
307 use crate::tensor::decomposition::tensor_ring::compress_tensor_ring;
308
309 let rank = self.select_rank(tensor.shape());
310 let ring = compress_tensor_ring(tensor, rank)?;
311
312 let original = tensor.shape().iter().product::<usize>();
314 let compressed = ring
315 .cores
316 .iter()
317 .map(|c| c.shape().iter().product::<usize>())
318 .sum::<usize>();
319
320 *self.original_params.borrow_mut() += original;
321 *self.compressed_params.borrow_mut() += compressed;
322
323 self.compressed_tensors.borrow_mut().insert(name.to_string(), ring.clone());
325
326 Ok(ring)
327 }
328}
329
330impl Default for TensorRingCompressor {
331 fn default() -> Self {
332 Self::new(CompressionConfig::new())
333 }
334}
335
336pub fn adaptive_rank_selection(
347 tensor: &DenseTensor,
348 energy_threshold: f64,
349) -> Result<usize, crate::tensor::TensorError> {
350 use crate::tensor::decomposition::svd_decompose;
351
352 let shape = tensor.shape();
353 let min_dim = shape.iter().min().copied().unwrap_or(1);
354
355 let (_, s, _) = svd_decompose(tensor, Some(min_dim))?;
357
358 let s_data = s.data();
360 let total_energy: f64 = s_data.iter().map(|x| x * x).sum();
361 let threshold = total_energy * energy_threshold;
362
363 let mut cumulative_energy = 0.0;
364 for (i, &sigma) in s_data.iter().enumerate() {
365 cumulative_energy += sigma * sigma;
366 if cumulative_energy >= threshold {
367 return Ok(i + 1);
368 }
369 }
370
371 Ok(min_dim)
372}
373
374pub fn mixed_precision_compress(
388 tensors: &HashMap<String, DenseTensor>,
389 base_rank: usize,
390 importance_map: Option<&HashMap<String, f64>>,
391) -> Result<HashMap<String, TensorRing>, crate::tensor::TensorError> {
392 use crate::tensor::decomposition::tensor_ring::compress_tensor_ring;
393
394 let mut results = HashMap::new();
395
396 for (name, tensor) in tensors {
397 let importance = importance_map
399 .and_then(|m| m.get(name))
400 .copied()
401 .unwrap_or(1.0);
402
403 let rank = (base_rank as f64 * importance).ceil() as usize;
405
406 let ring = compress_tensor_ring(tensor, rank)?;
407 results.insert(name.clone(), ring);
408 }
409
410 Ok(results)
411}
412
/// Outcome of compressing a single weight tensor.
#[derive(Debug, Clone)]
pub struct LayerCompressionReport {
    /// Name of the weight this entry describes.
    pub layer_name: String,
    /// Element count of the dense tensor before compression.
    pub original_params: usize,
    /// Element count summed over the ring cores after compression.
    pub compressed_params: usize,
    /// `original_params` divided by `compressed_params`.
    pub compression_ratio: f64,
    /// Ranks of the resulting tensor ring.
    pub ranks: Vec<usize>,
}
427
428#[derive(Debug, Clone)]
430pub struct CompressionReport {
431 pub original_params: usize,
433 pub compressed_params: usize,
435 pub compression_ratio: f64,
437 pub layers: Vec<LayerCompressionReport>,
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::traits::TensorOps;

    /// Compressor with deliberately small ranks (target 4, bounds 2..=8) so
    /// the small fixtures below can actually be compressed.
    fn small_rank_compressor() -> TensorRingCompressor {
        let config = CompressionConfig::new()
            .with_target_ranks(vec![4])
            .with_min_rank(2)
            .with_max_rank(8);
        TensorRingCompressor::new(config)
    }

    /// Square `side x side` tensor filled with ones.
    fn ones(side: usize) -> DenseTensor {
        DenseTensor::from_vec(vec![1.0; side * side], vec![side, side])
    }

    /// Layer patterns select by substring; unrelated names are rejected.
    #[test]
    fn test_compression_config() {
        let config = CompressionConfig::new()
            .with_target_ranks(vec![32, 64])
            .with_layers(vec!["qkv".to_string(), "mlp".to_string()])
            .with_min_rank(16)
            .with_max_rank(128);

        assert!(config.matches_layer("model.layers.0.qkv.weight"));
        assert!(config.matches_layer("model.layers.0.mlp.gate_proj"));
        assert!(!config.matches_layer("model.norm.weight"));
    }

    /// A 64x64 tensor decomposed at small rank must end up smaller.
    #[test]
    fn test_tensor_ring_compressor() {
        let compressor = small_rank_compressor();
        let tensor = ones(64);

        let ring = compressor.decompose(&tensor).unwrap();

        // Diagnostics shown only when the test fails or runs with --nocapture.
        eprintln!("Original shape: {:?}", ring.original_shape);
        eprintln!("Ranks: {:?}", ring.ranks);
        eprintln!("Core shapes: {:?}", ring.cores.iter().map(|c| c.shape()).collect::<Vec<_>>());
        eprintln!("Compression ratio: {}", ring.compression_ratio());

        assert!(ring.compression_ratio() > 1.0, "Compression ratio should be > 1.0, got {}", ring.compression_ratio());
    }

    /// The product of a 100x5 and a 5x50 matrix has low rank, so the
    /// selected rank must stay small.
    #[test]
    fn test_adaptive_rank_selection() {
        let left_values: Vec<f64> = (0..100 * 5).map(|i| (i % 10) as f64 / 10.0).collect();
        let right_values: Vec<f64> = (0..5 * 50).map(|i| (i % 7) as f64 / 10.0).collect();

        let u = DenseTensor::from_vec(left_values, vec![100, 5]);
        let v = DenseTensor::from_vec(right_values, vec![5, 50]);
        let tensor = u.matmul(&v);

        let rank = adaptive_rank_selection(&tensor, 0.99).unwrap();
        assert!(rank <= 10);
    }

    /// `compress_weight` returns a ring that remembers the original shape.
    #[test]
    fn test_compress_weight() {
        let compressor = small_rank_compressor();
        let tensor = ones(16);

        let ring = compressor.compress_weight("test_weight", &tensor).unwrap();

        assert_eq!(ring.original_shape, vec![16, 16]);
        assert!(!ring.cores.is_empty());
    }

    /// The ring's own compression ratio is a positive number.
    #[test]
    fn test_compression_ratio() {
        let compressor = small_rank_compressor();
        let tensor = ones(32);

        let ring = compressor.decompose(&tensor).unwrap();

        let ratio = ring.compression_ratio();
        assert!(ratio > 0.0);
    }

    /// Reconstruction restores the original shape.
    #[test]
    fn test_reconstruct_tensor() {
        let compressor = small_rank_compressor();
        let tensor = ones(8);

        let ring = compressor.decompose(&tensor).unwrap();
        let reconstructed = ring.reconstruct().unwrap();

        assert_eq!(reconstructed.shape(), tensor.shape());
    }
}