// god_graph/transformer/optimization/lie_group.rs

use crate::errors::{GraphError, GraphResult};
use crate::graph::Graph;
use crate::tensor::DenseTensor;
use crate::tensor::TensorBase;
use crate::transformer::optimization::error_analysis::ErrorAccumulator;
use crate::transformer::optimization::switch::{OperatorType, WeightTensor};
use std::cell::RefCell;
use std::collections::HashMap;

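/// Configuration for Lie-group-based weight optimization: which layers to
/// target, the SO(k) block size, and whether to use the exponential-map
/// (Cayley) path or plain QR orthogonalization.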
#[derive(Debug, Clone)]
pub struct LieGroupConfig {
    /// Side length of the square SO(k) blocks used by `block_decompose`.
    pub block_size: usize,
    /// Whether weight matrices should be orthogonalized.
    pub orthogonalize: bool,
    /// Layer-name patterns selecting which layers are processed (".*" matches all).
    pub target_layers: Vec<String>,
    /// Use the skew-symmetric/exponential-map path in `cayley_transform`
    /// instead of QR orthogonalization.
    pub use_cayley: bool,
    /// Iteration budget for iterative routines.
    pub iterations: usize,
    /// Numerical tolerance.
    pub tolerance: f64,
}

impl LieGroupConfig {
    /// Creates a configuration with default settings.
    pub fn new() -> Self {
        Self {
            block_size: 64,
            orthogonalize: true,
            target_layers: vec![".*".to_string()],
            use_cayley: false,
            iterations: 10,
            tolerance: 1e-6,
        }
    }

    /// Sets the SO(k) block size.
    pub fn with_block_size(mut self, size: usize) -> Self {
        self.block_size = size;
        self
    }

    /// Enables or disables orthogonalization.
    pub fn with_orthogonalize(mut self, ortho: bool) -> Self {
        self.orthogonalize = ortho;
        self
    }

    /// Sets the layer-name patterns to match.
    pub fn with_target_layers(mut self, layers: Vec<String>) -> Self {
        self.target_layers = layers;
        self
    }

    /// Selects the Cayley/exponential-map path.
    pub fn with_cayley(mut self, use_cayley: bool) -> Self {
        self.use_cayley = use_cayley;
        self
    }

    /// Sets the iteration budget.
    pub fn with_iterations(mut self, iterations: usize) -> Self {
        self.iterations = iterations;
        self
    }

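    /// Returns `true` if `layer_name` matches any configured pattern.
    /// The pattern ".*" matches every layer; any other pattern matches by
    /// substring containment (not full regex matching).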
    pub fn matches_layer(&self, layer_name: &str) -> bool {
        self.target_layers.iter().any(|pattern| {
            if pattern == ".*" {
                true
            } else {
                layer_name.contains(pattern)
            }
        })
    }
}

impl Default for LieGroupConfig {
    fn default() -> Self {
        Self::new()
    }
}

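/// Applies Lie-group-inspired transformations to the weight tensors stored on
/// the edges of a transformer graph: in-place orthogonalization, SO(k) block
/// decomposition, and exponential-map / Cayley-style re-projection.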
pub struct LieGroupOptimizer {
    /// The active configuration.
    config: LieGroupConfig,
    statistics: RefCell<HashMap<String, f64>>,
    error_accumulator: RefCell<ErrorAccumulator>,
}

impl LieGroupOptimizer {
    /// Creates an optimizer from the given configuration.
    pub fn new(config: LieGroupConfig) -> Self {
        Self {
            config,
            statistics: RefCell::new(HashMap::new()),
            error_accumulator: RefCell::new(ErrorAccumulator::new()),
        }
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &LieGroupConfig {
        &self.config
    }

    /// Returns a read-only view of the collected statistics.
    pub fn statistics(&self) -> std::cell::Ref<'_, HashMap<String, f64>> {
        self.statistics.borrow()
    }

    /// Returns a read-only view of the error accumulator.
    pub fn error_accumulator(&self) -> std::cell::Ref<'_, ErrorAccumulator> {
        self.error_accumulator.borrow()
    }

    /// Returns a mutable view of the error accumulator.
    pub fn error_accumulator_mut(&self) -> std::cell::RefMut<'_, ErrorAccumulator> {
        self.error_accumulator.borrow_mut()
    }

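    /// Orthogonalizes every weight tensor in the graph in place.
    ///
    /// Iterates over all edges, applies QR-based orthogonalization to each 2D
    /// weight, records the per-layer error in the error accumulator, and
    /// stores the mean error under the `"orthogonalization_error"` statistic.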
    pub fn orthogonalize_weights(
        &self,
        graph: &mut Graph<OperatorType, WeightTensor>,
    ) -> GraphResult<()> {
        use crate::graph::traits::GraphQuery;

        let mut orthogonalized_count = 0;
        let mut total_error = 0.0;

        // Collect edge indices up front so the graph can be mutated while iterating.
        let edge_indices: Vec<_> = graph.edges().map(|e| e.index()).collect();

        for edge_idx in edge_indices {
            // Orthogonalize the weight on this edge and record its error.
            let error = self.orthogonalize_single_weight(graph, edge_idx)?;

            let weight = &graph[edge_idx];
            self.error_accumulator
                .borrow_mut()
                .record_error(&weight.name, error);

            total_error += error;
            orthogonalized_count += 1;
        }

        // Store the mean orthogonalization error as a statistic.
        if orthogonalized_count > 0 {
            self.statistics.borrow_mut().insert(
                "orthogonalization_error".to_string(),
                total_error / orthogonalized_count as f64,
            );
        }

        Ok(())
    }

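    /// Orthogonalizes the weight tensor on a single edge in place, returning
    /// the orthogonalization error.
    ///
    /// Non-2D tensors and matrices with fewer rows than columns are skipped
    /// and report an error of `0.0`.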
    pub fn orthogonalize_single_weight(
        &self,
        graph: &mut Graph<OperatorType, WeightTensor>,
        edge_idx: crate::edge::EdgeIndex,
    ) -> GraphResult<f64> {
        use crate::tensor::decomposition::qr::orthogonalize_in_place;

        let weight = &mut graph[edge_idx];
        let shape = weight.shape.to_vec();

        // Only 2D weight matrices can be orthogonalized.
        if shape.len() != 2 {
            eprintln!(
                "Skipping orthogonalization for {}: shape={:?} (not 2D)",
                weight.name, shape
            );
            return Ok(0.0);
        }

        // QR orthogonalization requires at least as many rows as columns.
        if shape[0] < shape[1] {
            eprintln!(
                "Skipping orthogonalization for {}: shape={:?} (m < n)",
                weight.name, shape
            );
            return Ok(0.0);
        }

        let error = orthogonalize_in_place(&mut weight.data, &shape)
            .map_err(|e| GraphError::InvalidFormat(e.to_string()))?;

        Ok(error)
    }

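    /// Returns the maximum absolute deviation of the column Gram matrix of
    /// `tensor` from the identity, i.e. how far the columns are from being
    /// orthonormal. Non-2D tensors return `f64::MAX`.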
    #[allow(dead_code)]
    fn check_orthogonality(tensor: &DenseTensor) -> f64 {
        let shape = tensor.shape();
        if shape.len() != 2 {
            return f64::MAX;
        }

        let n = shape[0];
        let m = shape[1];
        let data = tensor.data();

        let mut max_error: f64 = 0.0;
        // Compare every pairwise column dot product against the identity.
        for i in 0..m {
            for j in 0..m {
                let mut dot = 0.0;
                for k in 0..n {
                    dot += data[k * m + i] * data[k * m + j];
                }
                let expected = if i == j { 1.0 } else { 0.0 };
                let error = (dot - expected).abs();
                max_error = max_error.max(error);
            }
        }

        max_error
    }

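    /// Decomposes the weights of all matching layers into SO(k) blocks of the
    /// configured block size.
    ///
    /// Returns a summary of the decomposition and records the total number of
    /// blocks under the `"total_blocks"` statistic.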
    pub fn block_decompose(
        &self,
        graph: &mut Graph<OperatorType, WeightTensor>,
    ) -> GraphResult<DecomposedWeights> {
        use crate::graph::traits::GraphQuery;

        let block_size = self.config.block_size;
        let mut decomposed_blocks = Vec::new();
        let mut total_blocks = 0;

        // Snapshot edge data so the graph is not borrowed during decomposition.
        let edge_data: Vec<_> = graph
            .edges()
            .map(|e| {
                (
                    e.index(),
                    e.data().name.clone(),
                    e.data().data.to_vec(),
                    e.data().shape.to_vec(),
                )
            })
            .collect();

        for (_edge_idx, layer_name, weight_data, weight_shape) in edge_data {
            // Skip layers that do not match the configured patterns.
            if !self.config.matches_layer(&layer_name) {
                continue;
            }

            let weight_tensor = DenseTensor::new(weight_data, weight_shape);

            // Split the weight into block_size x block_size orthogonal blocks.
            let blocks = decompose_into_so_blocks(&weight_tensor, block_size)
                .map_err(|e| GraphError::InvalidFormat(e.to_string()))?;

            total_blocks += blocks.len();

            decomposed_blocks.push(BlockDecomposition {
                layer_name,
                num_blocks: blocks.len(),
                block_size,
            });
        }

        self.statistics
            .borrow_mut()
            .insert("total_blocks".to_string(), total_blocks as f64);

        Ok(DecomposedWeights {
            blocks: decomposed_blocks,
            total_blocks,
        })
    }

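    /// Projects a tensor onto its skew-symmetric (Lie algebra) part and maps
    /// it back to the group via the matrix exponential, yielding an
    /// orthogonal matrix near the input.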
    pub fn lie_algebra_regularize(
        &self,
        tensor: &DenseTensor,
    ) -> Result<DenseTensor, crate::tensor::TensorError> {
        use crate::tensor::decomposition::lie_algebra::skew_symmetric_projection;

        // Project onto the skew-symmetric (Lie algebra) component.
        let skew = skew_symmetric_projection(tensor)?;

        // Map back onto the group via the matrix exponential.
        crate::tensor::decomposition::lie_algebra::lie_exponential(&skew)
    }

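    /// Re-projects a tensor onto the orthogonal group.
    ///
    /// When `use_cayley` is set, the tensor's skew-symmetric projection is
    /// mapped through the exponential map; otherwise plain QR
    /// orthogonalization is used.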
    pub fn cayley_transform(
        &self,
        tensor: &DenseTensor,
    ) -> Result<DenseTensor, crate::tensor::TensorError> {
        use crate::tensor::decomposition::lie_algebra::{
            lie_exponential, skew_symmetric_projection,
        };

        if self.config.use_cayley {
            // Exponential map of the skew-symmetric projection.
            let skew = skew_symmetric_projection(tensor)?;
            lie_exponential(&skew)
        } else {
            // Fall back to QR orthogonalization.
            crate::tensor::decomposition::qr::orthogonalize(tensor)
        }
    }

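    /// Estimates whether a 2D tensor is well conditioned.
    ///
    /// Runs a fixed number of power-iteration steps on the Gram matrix to
    /// estimate the largest singular value and compares a crude
    /// condition-number proxy against `threshold`. Non-2D tensors are
    /// reported as ill conditioned.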
    pub fn is_well_conditioned(&self, tensor: &DenseTensor, threshold: f64) -> bool {
        let shape = tensor.shape();
        if shape.len() != 2 {
            return false;
        }

        let data = tensor.data();
        let (m, n) = (shape[0], shape[1]);

        // Power iteration on A^T A to find the dominant right singular vector.
        let mut v = vec![1.0 / (n as f64).sqrt(); n];
        for _ in 0..20 {
            // av = A v
            let mut av = vec![0.0; m];
            for i in 0..m {
                for j in 0..n {
                    av[i] += data[i * n + j] * v[j];
                }
            }

            // atav = A^T (A v)
            let mut atav = vec![0.0; n];
            for i in 0..n {
                for j in 0..m {
                    atav[i] += data[j * n + i] * av[j];
                }
            }

            let norm: f64 = atav.iter().map(|x| x * x).sum::<f64>().sqrt();
            if norm < 1e-10 {
                return true;
            }
            v = atav.into_iter().map(|x| x / norm).collect();
        }

        // Rayleigh quotient v^T (A^T A) v approximates the largest squared singular value.
        let sigma_max_sq: f64 = v
            .iter()
            .enumerate()
            .map(|(i, &vi)| {
                let mut sum = 0.0;
                for j in 0..n {
                    let mut aj = 0.0;
                    for k in 0..m {
                        aj += data[k * n + j] * data[k * n + i];
                    }
                    sum += aj * v[j];
                }
                sum * vi
            })
            .sum();

473 let sigma_min = 1.0 / sigma_max; let condition_number = sigma_max / sigma_min;
476 condition_number < threshold
477 }
478}
479
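/// Summary of the SO(k) block decomposition of a single layer.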
#[derive(Debug, Clone)]
pub struct BlockDecomposition {
    /// Name of the decomposed layer.
    pub layer_name: String,
    /// Number of blocks the layer was split into.
    pub num_blocks: usize,
    /// Side length of each block.
    pub block_size: usize,
}

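/// Result of decomposing all matching layers in a graph.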
#[derive(Debug, Clone)]
pub struct DecomposedWeights {
    /// Per-layer decomposition summaries.
    pub blocks: Vec<BlockDecomposition>,
    /// Total number of blocks across all layers.
    pub total_blocks: usize,
}

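/// A square block, stored in row-major order, produced by
/// `decompose_into_so_blocks` and intended to be (approximately) orthogonal.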
#[derive(Debug, Clone)]
pub struct SOkBlock {
    /// Row-major block data of length `size * size`.
    pub data: Vec<f64>,
    /// Side length of the square block.
    pub size: usize,
}

impl SOkBlock {
    /// Creates a block, checking that `data` holds exactly `size * size` values.
    pub fn new(data: Vec<f64>, size: usize) -> Result<Self, crate::tensor::TensorError> {
        if data.len() != size * size {
            return Err(crate::tensor::TensorError::DimensionMismatch {
                expected: size * size,
                got: data.len(),
            });
        }

        Ok(Self { data, size })
    }

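    /// Returns `true` if the block is orthogonal to within `tolerance`,
    /// i.e. every pairwise column dot product deviates from the identity by
    /// at most `tolerance`.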
    pub fn is_orthogonal(&self, tolerance: f64) -> bool {
        let n = self.size;
        let data = &self.data;

        for i in 0..n {
            for j in 0..n {
                let mut dot = 0.0;
                for k in 0..n {
                    dot += data[k * n + i] * data[k * n + j];
                }
                let expected = if i == j { 1.0 } else { 0.0 };
                if (dot - expected).abs() > tolerance {
                    return false;
                }
            }
        }

        true
    }
}

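/// Splits a 2D tensor into `block_size` x `block_size` tiles and
/// orthogonalizes each tile via QR, producing a list of `SOkBlock`s.
///
/// Edge tiles that are not square are zero-padded to the larger dimension
/// before orthogonalization. Returns an error for non-2D tensors.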
pub fn decompose_into_so_blocks(
    tensor: &DenseTensor,
    block_size: usize,
) -> Result<Vec<SOkBlock>, crate::tensor::TensorError> {
    use crate::tensor::decomposition::qr::orthogonalize;

    let shape = tensor.shape();
    if shape.len() != 2 {
        return Err(crate::tensor::TensorError::DimensionMismatch {
            expected: 2,
            got: shape.len(),
        });
    }

    let (m, n) = (shape[0], shape[1]);
    let mut blocks = Vec::new();

    // Tile the matrix into block_size x block_size pieces.
    for i in (0..m).step_by(block_size) {
        for j in (0..n).step_by(block_size) {
            let block_m = std::cmp::min(block_size, m - i);
            let block_n = std::cmp::min(block_size, n - j);

            // Copy this tile out of the row-major tensor data.
            let mut block_data = vec![0.0; block_m * block_n];
            for bi in 0..block_m {
                for bj in 0..block_n {
                    block_data[bi * block_n + bj] =
                        tensor.data()[(i + bi) * n + (j + bj)];
                }
            }

            // Zero-pad non-square edge tiles to the larger dimension.
            if block_m != block_n {
                let size = std::cmp::max(block_m, block_n);
                let mut square_block = vec![0.0; size * size];
                for bi in 0..block_m {
                    for bj in 0..block_n {
                        square_block[bi * size + bj] = block_data[bi * block_n + bj];
                    }
                }
                block_data = square_block;
            }

            // Orthogonalize the (square) tile via QR.
            let block_tensor = DenseTensor::from_vec(
                block_data,
                vec![block_m.max(block_n), block_m.max(block_n)],
            );
            let ortho = orthogonalize(&block_tensor)?;

            blocks.push(SOkBlock::new(ortho.data().to_vec(), ortho.shape()[0])?);
        }
    }

    Ok(blocks)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lie_group_config() {
        let config = LieGroupConfig::new()
            .with_block_size(128)
            .with_orthogonalize(true)
            .with_target_layers(vec!["q_proj".to_string(), "k_proj".to_string()]);

        assert_eq!(config.block_size, 128);
        assert!(config.orthogonalize);
        assert!(config.matches_layer("model.layers.0.attn.q_proj"));
        assert!(config.matches_layer("model.layers.0.attn.k_proj"));
        assert!(!config.matches_layer("model.layers.0.mlp"));
    }

    #[test]
    fn test_sok_block() {
        // A 2x2 rotation matrix is orthogonal.
        let theta = std::f64::consts::PI / 4.0;
        let cos_t = theta.cos();
        let sin_t = theta.sin();

        let block = SOkBlock::new(
            vec![cos_t, -sin_t, sin_t, cos_t],
            2,
        ).unwrap();

        assert!(block.is_orthogonal(1e-5));
    }

    #[test]
    fn test_decompose_into_so_blocks() {
        let tensor = DenseTensor::from_vec(
            vec![1.0, 0.0, 0.0, 1.0],
            vec![2, 2],
        );

        let blocks = decompose_into_so_blocks(&tensor, 2).unwrap();
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].is_orthogonal(1e-5));
    }

    #[test]
    fn test_lie_optimizer() {
        let config = LieGroupConfig::new()
            .with_block_size(64)
            .with_orthogonalize(true);

        let optimizer = LieGroupOptimizer::new(config);

        let tensor = DenseTensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
        );

        let result = optimizer.cayley_transform(&tensor);
        assert!(result.is_ok());
    }

    #[test]
    fn test_orthogonalize_single_weight() {
        use crate::graph::Graph;
        use crate::graph::traits::GraphOps;

        let config = LieGroupConfig::new()
            .with_block_size(2)
            .with_orthogonalize(true);

        let optimizer = LieGroupOptimizer::new(config);
        let mut graph = Graph::<OperatorType, WeightTensor>::directed();

        let from = graph.add_node(OperatorType::Linear { in_features: 2, out_features: 2 }).unwrap();
        let to = graph.add_node(OperatorType::Linear { in_features: 2, out_features: 2 }).unwrap();

        let weight = WeightTensor::new(
            "test".to_string(),
            vec![1.0, 0.0, 0.0, 1.0],
            vec![2, 2],
        );
        let edge = graph.add_edge(from, to, weight).unwrap();

        let error = optimizer.orthogonalize_single_weight(&mut graph, edge);
        assert!(error.is_ok());
    }

    #[test]
    fn test_error_accumulator() {
        let config = LieGroupConfig::new().with_orthogonalize(true);
        let optimizer = LieGroupOptimizer::new(config);

        {
            let mut acc = optimizer.error_accumulator_mut();
            acc.record_error("layer1", 0.01);
            acc.record_error("layer2", 0.02);
        }

        let acc = optimizer.error_accumulator();
        assert_eq!(acc.num_layers(), 2);
        assert!(acc.get_layer_errors("layer1").is_some());
        assert!(acc.get_layer_errors("layer2").is_some());
    }

    #[test]
    fn test_check_orthogonality() {
        // The identity matrix is exactly orthogonal.
        let identity = DenseTensor::from_vec(
            vec![1.0, 0.0, 0.0, 1.0],
            vec![2, 2],
        );

        let error = LieGroupOptimizer::check_orthogonality(&identity);
        assert!(error < 1e-10);

        // A rank-deficient matrix of ones is far from orthogonal.
        let non_ortho = DenseTensor::from_vec(
            vec![1.0, 1.0, 1.0, 1.0],
            vec![2, 2],
        );

        let error = LieGroupOptimizer::check_orthogonality(&non_ortho);
        assert!(error > 0.1);
    }
}

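/// Convenience wrapper: builds a `LieGroupOptimizer` from `config` and
/// orthogonalizes every edge weight in `graph` in place, returning the
/// per-edge orthogonalization errors in iteration order.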
pub fn orthogonalize_weights_in_place(
    config: &LieGroupConfig,
    graph: &mut Graph<OperatorType, WeightTensor>,
) -> GraphResult<Vec<f64>> {
    use crate::graph::traits::GraphQuery;

    let mut errors = Vec::new();
    let optimizer = LieGroupOptimizer::new(config.clone());

    // Collect edge indices first so the graph can be mutated in the loop.
    let edge_indices: Vec<_> = graph.edges().map(|e| e.index()).collect();

    for edge_idx in edge_indices {
        let error = optimizer.orthogonalize_single_weight(graph, edge_idx)?;
        errors.push(error);
    }

    Ok(errors)
}