use crate::errors::{GraphError, GraphResult};
use crate::graph::Graph;
use crate::tensor::{DenseTensor, TensorBase};
use crate::transformer::optimization::error_analysis::ErrorAccumulator;
use crate::transformer::optimization::switch::{OperatorType, WeightTensor};
use std::cell::RefCell;
use std::collections::HashMap;
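/// Configuration for Lie-group-based weight optimization.
///
/// Weight matrices are treated as points near the orthogonal group O(n):
/// they can be orthogonalized directly via QR, tiled into SO(k) blocks of
/// side `block_size`, or retracted from the Lie algebra back onto the
/// group. `target_layers` holds substring patterns (".*" matches every
/// layer) selecting which layers are touched.
///
/// A minimal builder sketch (mirrors `test_lie_group_config` below):
///
/// ```ignore
/// let config = LieGroupConfig::new()
///     .with_block_size(128)
///     .with_target_layers(vec!["q_proj".to_string()]);
/// assert!(config.matches_layer("model.layers.0.attn.q_proj"));
/// ```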
#[derive(Debug, Clone)]
pub struct LieGroupConfig {
pub block_size: usize,
pub orthogonalize: bool,
pub target_layers: Vec<String>,
pub use_cayley: bool,
pub iterations: usize,
pub tolerance: f64,
}
impl LieGroupConfig {
pub fn new() -> Self {
Self {
block_size: 64,
orthogonalize: true,
target_layers: vec![".*".to_string()],
use_cayley: false,
iterations: 10,
tolerance: 1e-6,
}
}
pub fn with_block_size(mut self, size: usize) -> Self {
self.block_size = size;
self
}
pub fn with_orthogonalize(mut self, ortho: bool) -> Self {
self.orthogonalize = ortho;
self
}
pub fn with_target_layers(mut self, layers: Vec<String>) -> Self {
self.target_layers = layers;
self
}
pub fn with_cayley(mut self, use_cayley: bool) -> Self {
self.use_cayley = use_cayley;
self
}
pub fn with_iterations(mut self, iterations: usize) -> Self {
self.iterations = iterations;
self
}
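    /// Sets the numerical tolerance, mirroring the other `with_*` setters.
    pub fn with_tolerance(mut self, tolerance: f64) -> Self {
        self.tolerance = tolerance;
        self
    }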
pub fn matches_layer(&self, layer_name: &str) -> bool {
self.target_layers.iter().any(|pattern| {
if pattern == ".*" {
true
} else {
layer_name.contains(pattern)
}
})
}
}
impl Default for LieGroupConfig {
fn default() -> Self {
Self::new()
}
}
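/// Applies the configured Lie-group transformations to the weight edges of
/// an operator graph, recording per-layer orthogonalization errors and
/// aggregate statistics. `RefCell` gives the optimizer interior mutability,
/// so statistics can be collected behind `&self`.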
pub struct LieGroupOptimizer {
config: LieGroupConfig,
statistics: RefCell<HashMap<String, f64>>,
error_accumulator: RefCell<ErrorAccumulator>,
}
impl LieGroupOptimizer {
pub fn new(config: LieGroupConfig) -> Self {
Self {
config,
statistics: RefCell::new(HashMap::new()),
error_accumulator: RefCell::new(ErrorAccumulator::new()),
}
}
pub fn config(&self) -> &LieGroupConfig {
&self.config
}
pub fn statistics(&self) -> std::cell::Ref<'_, HashMap<String, f64>> {
self.statistics.borrow()
}
pub fn error_accumulator(&self) -> std::cell::Ref<'_, ErrorAccumulator> {
self.error_accumulator.borrow()
}
pub fn error_accumulator_mut(&self) -> std::cell::RefMut<'_, ErrorAccumulator> {
self.error_accumulator.borrow_mut()
}
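    /// Orthogonalizes every weight edge whose layer name matches the
    /// configured patterns, recording per-layer errors in the accumulator
    /// and the mean error under the `"orthogonalization_error"` key.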
pub fn orthogonalize_weights(
&self,
graph: &mut Graph<OperatorType, WeightTensor>,
) -> GraphResult<()> {
use crate::graph::traits::GraphQuery;
        // No-op when orthogonalization is disabled in the config.
        if !self.config.orthogonalize {
            return Ok(());
        }
        let mut orthogonalized_count = 0;
        let mut total_error = 0.0;
        let edge_indices: Vec<_> = graph.edges().map(|e| e.index()).collect();
        for edge_idx in edge_indices {
            // Honor the configured layer filter, mirroring `block_decompose`.
            let layer_name = graph[edge_idx].name.clone();
            if !self.config.matches_layer(&layer_name) {
                continue;
            }
            let error = self.orthogonalize_single_weight(graph, edge_idx)?;
            self.error_accumulator
                .borrow_mut()
                .record_error(&layer_name, error);
            total_error += error;
            orthogonalized_count += 1;
        }
if orthogonalized_count > 0 {
self.statistics.borrow_mut().insert(
"orthogonalization_error".to_string(),
total_error / orthogonalized_count as f64
);
}
Ok(())
}
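    /// Orthogonalizes a single weight edge in place via QR and returns the
    /// orthogonalization error. Non-2D and wide (rows < cols) tensors are
    /// skipped with an error of 0.0, since QR cannot produce orthonormal
    /// columns for them.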
pub fn orthogonalize_single_weight(
&self,
graph: &mut Graph<OperatorType, WeightTensor>,
edge_idx: crate::edge::EdgeIndex,
) -> GraphResult<f64> {
use crate::tensor::decomposition::qr::orthogonalize_in_place;
let weight = &mut graph[edge_idx];
let shape = weight.shape.to_vec();
if shape.len() != 2 {
eprintln!("Skipping orthogonalization for {}: shape={:?} (not 2D)", weight.name, shape);
return Ok(0.0);
}
if shape[0] < shape[1] {
eprintln!("Skipping orthogonalization for {}: shape={:?} (m < n)", weight.name, shape);
return Ok(0.0);
}
let error = orthogonalize_in_place(&mut weight.data, &shape)
.map_err(|e| GraphError::InvalidFormat(e.to_string()))?;
Ok(error)
}
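    /// Returns the maximum absolute deviation of A^T A from the identity,
    /// i.e. how far the columns of `tensor` are from orthonormal
    /// (`f64::MAX` for non-2D tensors).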
#[allow(dead_code)]
fn check_orthogonality(tensor: &DenseTensor) -> f64 {
let shape = tensor.shape();
if shape.len() != 2 {
return f64::MAX;
}
let n = shape[0];
let m = shape[1];
let data = tensor.data();
let mut max_error: f64 = 0.0;
for i in 0..m {
for j in 0..m {
let mut dot = 0.0;
for k in 0..n {
dot += data[k * m + i] * data[k * m + j];
}
let expected = if i == j { 1.0 } else { 0.0 };
let error = (dot - expected).abs();
max_error = max_error.max(error);
}
}
max_error
}
pub fn block_decompose(
&self,
graph: &mut Graph<OperatorType, WeightTensor>,
) -> GraphResult<DecomposedWeights> {
use crate::graph::traits::GraphQuery;
let block_size = self.config.block_size;
let mut decomposed_blocks = Vec::new();
let mut total_blocks = 0;
let edge_data: Vec<_> = graph.edges().map(|e| {
(e.index(), e.data().name.clone(), e.data().data.to_vec(), e.data().shape.to_vec())
}).collect();
for (_edge_idx, layer_name, weight_data, weight_shape) in edge_data {
if !self.config.matches_layer(&layer_name) {
continue;
}
let weight_tensor = DenseTensor::new(weight_data, weight_shape);
let blocks = decompose_into_so_blocks(&weight_tensor, block_size)
.map_err(|e| GraphError::InvalidFormat(e.to_string()))?;
total_blocks += blocks.len();
decomposed_blocks.push(BlockDecomposition {
layer_name,
num_blocks: blocks.len(),
block_size,
});
}
self.statistics.borrow_mut().insert(
"total_blocks".to_string(),
total_blocks as f64
);
Ok(DecomposedWeights {
blocks: decomposed_blocks,
total_blocks,
})
}
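    /// Regularizes a tensor onto the special orthogonal group: project onto
    /// the Lie algebra so(n) of skew-symmetric matrices, then apply the
    /// exponential map, which lands in SO(n).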
pub fn lie_algebra_regularize(
&self,
tensor: &DenseTensor,
) -> Result<DenseTensor, crate::tensor::TensorError> {
use crate::tensor::decomposition::lie_algebra::skew_symmetric_projection;
let skew = skew_symmetric_projection(tensor)?;
crate::tensor::decomposition::lie_algebra::lie_exponential(&skew)
}
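    /// Maps a tensor onto the orthogonal group. With `use_cayley` set this
    /// goes through the skew-symmetric projection followed by a Lie
    /// retraction (see the note in the body); otherwise it falls back to
    /// plain QR orthogonalization.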
pub fn cayley_transform(
&self,
tensor: &DenseTensor,
) -> Result<DenseTensor, crate::tensor::TensorError> {
use crate::tensor::decomposition::lie_algebra::{
lie_exponential, skew_symmetric_projection,
};
        if self.config.use_cayley {
            // NOTE: this path currently retracts through the matrix
            // exponential exp(S) rather than the classical Cayley map
            // (I - S)^{-1}(I + S); both send skew-symmetric S into SO(n).
            let skew = skew_symmetric_projection(tensor)?;
            lie_exponential(&skew)
} else {
crate::tensor::decomposition::qr::orthogonalize(tensor)
}
}
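    /// Estimates the condition number kappa = sigma_max / sigma_min of a 2D
    /// tensor and returns whether it is below `threshold`. sigma_max comes
    /// from power iteration on A^T A; sigma_min is approximated (see the
    /// inline comments). Returns `false` for non-2D tensors.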
pub fn is_well_conditioned(&self, tensor: &DenseTensor, threshold: f64) -> bool {
let shape = tensor.shape();
if shape.len() != 2 {
return false;
}
let data = tensor.data();
let (m, n) = (shape[0], shape[1]);
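        // Power iteration on A^T A to estimate the dominant singular value.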
let mut v = vec![1.0 / (n as f64).sqrt(); n];
for _ in 0..20 {
let mut av = vec![0.0; m];
for i in 0..m {
for j in 0..n {
av[i] += data[i * n + j] * v[j];
}
}
let mut atav = vec![0.0; n];
for i in 0..n {
for j in 0..m {
atav[i] += data[j * n + i] * av[j];
}
}
            let norm: f64 = atav.iter().map(|x| x * x).sum::<f64>().sqrt();
            if norm < 1e-10 {
                // A^T A collapsed the iterate to (near) zero, so the matrix
                // is close to singular in this direction: badly conditioned.
                return false;
            }
            v = atav.into_iter().map(|x| x / norm).collect();
}
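        // Rayleigh quotient v^T (A^T A) v of the converged unit vector,
        // which approximates sigma_max^2.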
let sigma_max_sq: f64 = v
.iter()
.enumerate()
.map(|(i, &vi)| {
let mut sum = 0.0;
for j in 0..n {
let mut aj = 0.0;
for k in 0..m {
aj += data[k * n + j] * data[k * n + i];
}
sum += aj * v[j];
}
sum * vi
})
.sum();
        let sigma_max = sigma_max_sq.sqrt();
        // Heuristic, assuming the near-orthogonal regime this optimizer
        // targets: singular values pair up around 1, so sigma_min is
        // approximated as 1 / sigma_max instead of running a second power
        // iteration. The estimated condition number is then
        // kappa = sigma_max / sigma_min ~= sigma_max^2.
        let sigma_min = 1.0 / sigma_max;
        let condition_number = sigma_max / sigma_min;
        condition_number < threshold
}
}
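/// Per-layer metadata describing how a weight matrix was tiled into blocks.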
#[derive(Debug, Clone)]
pub struct BlockDecomposition {
pub layer_name: String,
pub num_blocks: usize,
pub block_size: usize,
}
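/// Aggregate result of [`LieGroupOptimizer::block_decompose`].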
#[derive(Debug, Clone)]
pub struct DecomposedWeights {
pub blocks: Vec<BlockDecomposition>,
pub total_blocks: usize,
}
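/// A square block intended to lie in (or near) SO(k); `data` is a row-major
/// `size` x `size` matrix.
///
/// A 2D rotation is the canonical SO(2) element (mirrors `test_sok_block`):
///
/// ```ignore
/// let theta = std::f64::consts::FRAC_PI_4;
/// let block = SOkBlock::new(
///     vec![theta.cos(), -theta.sin(), theta.sin(), theta.cos()],
///     2,
/// ).unwrap();
/// assert!(block.is_orthogonal(1e-9));
/// ```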
#[derive(Debug, Clone)]
pub struct SOkBlock {
pub data: Vec<f64>,
pub size: usize,
}
impl SOkBlock {
pub fn new(data: Vec<f64>, size: usize) -> Result<Self, crate::tensor::TensorError> {
if data.len() != size * size {
return Err(crate::tensor::TensorError::DimensionMismatch {
expected: size * size,
got: data.len(),
});
}
Ok(Self { data, size })
}
pub fn is_orthogonal(&self, tolerance: f64) -> bool {
let n = self.size;
let data = &self.data;
for i in 0..n {
for j in 0..n {
let mut dot = 0.0;
for k in 0..n {
dot += data[k * n + i] * data[k * n + j];
}
let expected = if i == j { 1.0 } else { 0.0 };
if (dot - expected).abs() > tolerance {
return false;
}
}
}
true
}
}
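/// Tiles a 2D tensor into `block_size` x `block_size` blocks, pads edge
/// blocks to squares (see the inline comment), and QR-orthogonalizes each
/// block. Returns a `DimensionMismatch` error for non-2D tensors.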
pub fn decompose_into_so_blocks(
tensor: &DenseTensor,
block_size: usize,
) -> Result<Vec<SOkBlock>, crate::tensor::TensorError> {
use crate::tensor::decomposition::qr::orthogonalize;
let shape = tensor.shape();
if shape.len() != 2 {
return Err(crate::tensor::TensorError::DimensionMismatch {
expected: 2,
got: shape.len(),
});
}
let (m, n) = (shape[0], shape[1]);
let mut blocks = Vec::new();
for i in (0..m).step_by(block_size) {
for j in (0..n).step_by(block_size) {
let block_m = std::cmp::min(block_size, m - i);
let block_n = std::cmp::min(block_size, n - j);
let mut block_data = vec![0.0; block_m * block_n];
for bi in 0..block_m {
for bj in 0..block_n {
block_data[bi * block_n + bj] =
tensor.data()[(i + bi) * n + (j + bj)];
}
}
            if block_m != block_n {
                // Pad edge blocks to a square. Seeding the out-of-block part
                // of the diagonal with 1s keeps the padded matrix generically
                // full rank, so QR never has to orthogonalize all-zero rows
                // or columns.
                let size = std::cmp::max(block_m, block_n);
                let mut square_block = vec![0.0; size * size];
                for k in block_m.min(block_n)..size {
                    square_block[k * size + k] = 1.0;
                }
                for bi in 0..block_m {
                    for bj in 0..block_n {
                        square_block[bi * size + bj] = block_data[bi * block_n + bj];
                    }
                }
                block_data = square_block;
            }
let block_tensor = DenseTensor::from_vec(
block_data,
vec![block_m.max(block_n), block_m.max(block_n)],
);
let ortho = orthogonalize(&block_tensor)?;
blocks.push(SOkBlock::new(ortho.data().to_vec(), ortho.shape()[0])?);
}
}
Ok(blocks)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_lie_group_config() {
let config = LieGroupConfig::new()
.with_block_size(128)
.with_orthogonalize(true)
.with_target_layers(vec!["q_proj".to_string(), "k_proj".to_string()]);
assert_eq!(config.block_size, 128);
assert!(config.orthogonalize);
assert!(config.matches_layer("model.layers.0.attn.q_proj"));
assert!(config.matches_layer("model.layers.0.attn.k_proj"));
assert!(!config.matches_layer("model.layers.0.mlp"));
}
#[test]
fn test_sok_block() {
let theta = std::f64::consts::PI / 4.0;
let cos_t = theta.cos();
let sin_t = theta.sin();
let block = SOkBlock::new(
vec![cos_t, -sin_t, sin_t, cos_t],
2,
).unwrap();
assert!(block.is_orthogonal(1e-5));
}
#[test]
fn test_decompose_into_so_blocks() {
let tensor = DenseTensor::from_vec(
vec![1.0, 0.0, 0.0, 1.0],
vec![2, 2],
);
let blocks = decompose_into_so_blocks(&tensor, 2).unwrap();
assert_eq!(blocks.len(), 1);
assert!(blocks[0].is_orthogonal(1e-5));
}
#[test]
fn test_lie_optimizer() {
let config = LieGroupConfig::new()
.with_block_size(64)
.with_orthogonalize(true);
let optimizer = LieGroupOptimizer::new(config);
let tensor = DenseTensor::from_vec(
vec![1.0, 2.0, 3.0, 4.0],
vec![2, 2],
);
let result = optimizer.cayley_transform(&tensor);
assert!(result.is_ok());
}
#[test]
fn test_orthogonalize_single_weight() {
use crate::graph::Graph;
use crate::graph::traits::GraphOps;
let config = LieGroupConfig::new()
.with_block_size(2)
.with_orthogonalize(true);
let optimizer = LieGroupOptimizer::new(config);
let mut graph = Graph::<OperatorType, WeightTensor>::directed();
let from = graph.add_node(OperatorType::Linear { in_features: 2, out_features: 2 }).unwrap();
let to = graph.add_node(OperatorType::Linear { in_features: 2, out_features: 2 }).unwrap();
let weight = WeightTensor::new(
"test".to_string(),
vec![1.0, 0.0, 0.0, 1.0],
vec![2, 2],
);
let edge = graph.add_edge(from, to, weight).unwrap();
let error = optimizer.orthogonalize_single_weight(&mut graph, edge);
assert!(error.is_ok());
}
#[test]
fn test_error_accumulator() {
let config = LieGroupConfig::new().with_orthogonalize(true);
let optimizer = LieGroupOptimizer::new(config);
{
let mut acc = optimizer.error_accumulator_mut();
acc.record_error("layer1", 0.01);
acc.record_error("layer2", 0.02);
}
let acc = optimizer.error_accumulator();
assert_eq!(acc.num_layers(), 2);
assert!(acc.get_layer_errors("layer1").is_some());
assert!(acc.get_layer_errors("layer2").is_some());
}
#[test]
fn test_check_orthogonality() {
let identity = DenseTensor::from_vec(
vec![1.0, 0.0, 0.0, 1.0],
vec![2, 2],
);
let error = LieGroupOptimizer::check_orthogonality(&identity);
assert!(error < 1e-10);
let non_ortho = DenseTensor::from_vec(
vec![1.0, 1.0, 1.0, 1.0],
vec![2, 2],
);
let error = LieGroupOptimizer::check_orthogonality(&non_ortho);
assert!(error > 0.1);
}
}
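/// Free-function convenience wrapper: orthogonalizes all matching weight
/// edges of `graph` with a temporary [`LieGroupOptimizer`] built from
/// `config`, returning the raw per-edge errors rather than aggregated
/// statistics.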
pub fn orthogonalize_weights_in_place(
config: &LieGroupConfig,
graph: &mut Graph<OperatorType, WeightTensor>,
) -> GraphResult<Vec<f64>> {
use crate::graph::traits::GraphQuery;
let mut errors = Vec::new();
let optimizer = LieGroupOptimizer::new(config.clone());
    let edge_indices: Vec<_> = graph.edges().map(|e| e.index()).collect();
    for edge_idx in edge_indices {
        // Apply the same layer filter as `LieGroupOptimizer::orthogonalize_weights`.
        if !config.matches_layer(&graph[edge_idx].name) {
            continue;
        }
        let error = optimizer.orthogonalize_single_weight(graph, edge_idx)?;
        errors.push(error);
    }
Ok(errors)
}