use anyhow::Result;
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;
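/// Configuration for the [`LoRARITE`] optimizer: Adam-style moment
/// hyperparameters plus the knobs controlling the transformation-invariant
/// preconditioning applied to LoRA factor matrices.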
#[derive(Debug, Clone)]
pub struct LoRARITEConfig {
pub learning_rate: f32,
pub lora_rank: usize,
pub beta1: f32,
pub beta2: f32,
pub epsilon: f32,
pub weight_decay: f32,
pub preconditioning_strength: f32,
pub bias_correction: bool,
pub transformation_invariance: bool,
pub adaptation_frequency: u64,
pub min_singular_value: f32,
pub max_condition_number: f32,
pub adaptive_rank: bool,
pub factorization_reg: f32,
}
impl Default for LoRARITEConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
lora_rank: 16,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
weight_decay: 0.0,
preconditioning_strength: 0.1,
bias_correction: true,
transformation_invariance: true,
adaptation_frequency: 10,
min_singular_value: 1e-6,
max_condition_number: 1e6,
adaptive_rank: false,
factorization_reg: 1e-6,
}
}
}
impl LoRARITEConfig {
pub fn new() -> Self {
Self::default()
}
pub fn learning_rate(mut self, lr: f32) -> Self {
self.learning_rate = lr;
self
}
pub fn lora_rank(mut self, rank: usize) -> Self {
self.lora_rank = rank;
self
}
pub fn beta1(mut self, beta1: f32) -> Self {
self.beta1 = beta1;
self
}
pub fn beta2(mut self, beta2: f32) -> Self {
self.beta2 = beta2;
self
}
pub fn preconditioning_strength(mut self, strength: f32) -> Self {
self.preconditioning_strength = strength;
self
}
pub fn weight_decay(mut self, decay: f32) -> Self {
self.weight_decay = decay;
self
}
    pub fn transformation_invariance(mut self, enable: bool) -> Self {
        self.transformation_invariance = enable;
        self
    }
    pub fn bias_correction(mut self, enable: bool) -> Self {
        self.bias_correction = enable;
        self
    }
pub fn build(self) -> Self {
self
}
}
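/// Mutable optimizer state: per-parameter first/second moments for the LoRA
/// A and B factors, cached preconditioners, and diagnostics (singular values,
/// condition numbers, effective ranks, transformation statistics).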
#[derive(Debug, Clone)]
pub struct LoRARITEState {
pub step: u64,
pub m_a: HashMap<String, Tensor>,
pub m_b: HashMap<String, Tensor>,
pub v_a: HashMap<String, Tensor>,
pub v_b: HashMap<String, Tensor>,
pub precond_a: HashMap<String, Tensor>,
pub precond_b: HashMap<String, Tensor>,
pub singular_values: HashMap<String, Tensor>,
pub condition_numbers: HashMap<String, f32>,
pub effective_ranks: HashMap<String, usize>,
pub transformation_stats: TransformationStats,
}
#[derive(Debug, Clone)]
pub struct TransformationStats {
pub num_transformations: u64,
pub condition_improvement: f32,
pub rank_stability: f32,
pub preconditioning_gain: f32,
}
impl Default for TransformationStats {
fn default() -> Self {
Self {
num_transformations: 0,
condition_improvement: 0.0,
rank_stability: 1.0,
preconditioning_gain: 1.0,
}
}
}
impl Default for LoRARITEState {
fn default() -> Self {
Self {
step: 0,
m_a: HashMap::new(),
m_b: HashMap::new(),
v_a: HashMap::new(),
v_b: HashMap::new(),
precond_a: HashMap::new(),
precond_b: HashMap::new(),
singular_values: HashMap::new(),
condition_numbers: HashMap::new(),
effective_ranks: HashMap::new(),
transformation_stats: TransformationStats::default(),
}
}
}
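/// LoRA-aware optimizer that combines Adam-style moment estimation with a
/// transformation-invariant preconditioner blended in by
/// `preconditioning_strength`.
///
/// Minimal usage sketch (shapes and parameter names are illustrative):
///
/// ```ignore
/// use std::collections::HashMap;
/// let config = LoRARITEConfig::new().learning_rate(1e-3).lora_rank(16).build();
/// let mut optimizer = LoRARITE::new(config);
/// let mut parameters: HashMap<String, Tensor> = HashMap::new();
/// parameters.insert("layer1_a".to_string(), Tensor::ones(&[16, 64])?);
/// parameters.insert("layer1_b".to_string(), Tensor::ones(&[32, 16])?);
/// let mut gradients: HashMap<String, Tensor> = HashMap::new();
/// gradients.insert("layer1_a".to_string(), Tensor::ones(&[16, 64])?);
/// gradients.insert("layer1_b".to_string(), Tensor::ones(&[32, 16])?);
/// optimizer.step(&mut parameters, &gradients)?;
/// ```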
pub struct LoRARITE {
config: LoRARITEConfig,
state: LoRARITEState,
}
impl LoRARITE {
pub fn new(config: LoRARITEConfig) -> Self {
Self {
config,
state: LoRARITEState::default(),
}
}
pub fn learning_rate(&self) -> f32 {
self.config.learning_rate
}
pub fn set_learning_rate(&mut self, lr: f32) {
self.config.learning_rate = lr;
}
fn is_lora_a_matrix(&self, param_name: &str) -> bool {
param_name.ends_with("_a") || param_name.contains("lora_a") || param_name.contains("lora_A")
}
fn is_lora_b_matrix(&self, param_name: &str) -> bool {
param_name.ends_with("_b") || param_name.contains("lora_b") || param_name.contains("lora_B")
}
fn get_lora_base_name(&self, param_name: &str) -> String {
if param_name.ends_with("_a") {
param_name.trim_end_matches("_a").to_string()
} else if param_name.ends_with("_b") {
param_name.trim_end_matches("_b").to_string()
} else if param_name.contains("lora_a") {
param_name.replace("lora_a", "lora")
} else if param_name.contains("lora_b") {
param_name.replace("lora_b", "lora")
} else {
param_name.to_string()
}
}
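/// Approximate SVD of the LoRA product `B·A`: singular values are estimated
/// from the diagonal of the Gram matrix, while `U` and `V` are returned as
/// identity placeholders rather than true singular vectors.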
fn compute_svd(&self, matrix_a: &Tensor, matrix_b: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let product = matrix_b.matmul(&matrix_a)?;
let product_t = product.transpose(-1, -2)?;
let gram_matrix = product.matmul(&product_t)?;
let eigenvalues = self.compute_eigenvalues(&gram_matrix)?;
let singular_values = eigenvalues.sqrt()?;
let u = Tensor::eye(matrix_b.shape()[0])?;
let v = Tensor::eye(matrix_a.shape()[1])?;
Ok((u, singular_values, v))
}
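/// Crude eigenvalue estimate: the absolute values of the matrix diagonal
/// (exact only for diagonal matrices).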
fn compute_eigenvalues(&self, matrix: &Tensor) -> Result<Tensor> {
let diagonal = matrix.diagonal()?;
Ok(diagonal.abs())
}
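/// Builds a diagonal preconditioner from the squared gradient plus
/// `factorization_reg`, optionally passed through the
/// transformation-invariance clamp.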
fn compute_lora_preconditioning(&self, param_name: &str, gradient: &Tensor) -> Result<Tensor> {
let grad_squared = gradient.pow(&Tensor::scalar(2.0)?)?;
let reg_tensor = Tensor::scalar(self.config.factorization_reg)?;
        let preconditioner = grad_squared.add(&reg_tensor)?;
if self.config.transformation_invariance {
self.apply_transformation_invariance(&preconditioner)
} else {
            preconditioner.sqrt()?.reciprocal()
}
}
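/// Clamps the (approximate) eigenvalues into
/// `[min_singular_value, min_singular_value * max_condition_number]` to bound
/// the condition number, then returns the inverse square root.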
fn apply_transformation_invariance(&self, preconditioner: &Tensor) -> Result<Tensor> {
let eigenvalues = self.compute_eigenvalues(preconditioner)?;
let min_val = Tensor::scalar(self.config.min_singular_value)?;
let max_val = Tensor::scalar(self.config.min_singular_value * self.config.max_condition_number)?;
let clamped_eigenvalues = eigenvalues.clamp(&min_val, &max_val)?;
let sqrt_eigenvalues = clamped_eigenvalues.sqrt()?;
sqrt_eigenvalues.reciprocal()
}
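/// Adam-style exponential moving averages:
/// `m_t = beta1 * m_{t-1} + (1 - beta1) * g_t` and
/// `v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2`.
/// A-matrix parameters use the `m_a`/`v_a` maps; everything else (including
/// B matrices) falls through to `m_b`/`v_b`.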
fn update_moments(&mut self, param_name: &str, gradient: &Tensor) -> Result<(Tensor, Tensor)> {
let beta1 = self.config.beta1;
let beta2 = self.config.beta2;
let (m_map, v_map) = if self.is_lora_a_matrix(param_name) {
(&mut self.state.m_a, &mut self.state.v_a)
} else {
(&mut self.state.m_b, &mut self.state.v_b)
};
let m = if let Some(prev_m) = m_map.get(param_name) {
let beta1_tensor = Tensor::scalar(beta1)?;
let one_minus_beta1 = Tensor::scalar(1.0 - beta1)?;
let weighted_prev = prev_m.mul(&beta1_tensor)?;
let weighted_grad = gradient.mul(&one_minus_beta1)?;
weighted_prev.add(&weighted_grad)?
} else {
gradient.mul(&Tensor::scalar(1.0 - beta1)?)?
};
let grad_squared = gradient.pow(&Tensor::scalar(2.0)?)?;
let v = if let Some(prev_v) = v_map.get(param_name) {
let beta2_tensor = Tensor::scalar(beta2)?;
let one_minus_beta2 = Tensor::scalar(1.0 - beta2)?;
let weighted_prev = prev_v.mul(&beta2_tensor)?;
let weighted_grad_sq = grad_squared.mul(&one_minus_beta2)?;
weighted_prev.add(&weighted_grad_sq)?
} else {
grad_squared.mul(&Tensor::scalar(1.0 - beta2)?)?
};
m_map.insert(param_name.to_string(), m.clone());
v_map.insert(param_name.to_string(), v.clone());
Ok((m, v))
}
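/// Standard Adam bias correction: divides the moment by `1 - beta^t`
/// (a no-op when `bias_correction` is disabled).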
fn apply_bias_correction(&self, moment: &Tensor, beta: f32) -> Result<Tensor> {
if !self.config.bias_correction {
return Ok(moment.clone());
}
        // Treat step 0 as step 1 so the correction factor never becomes zero
        // when this is called before the first optimization step.
        let step = self.state.step.max(1) as f32;
        let correction_factor = 1.0 - beta.powf(step);
let correction_tensor = Tensor::scalar(correction_factor)?;
moment.div(&correction_tensor)
}
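/// Number of singular values (in stored order) needed to accumulate 95% of
/// their total sum, capped at the configured `lora_rank`.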
fn compute_effective_rank(&self, singular_values: &Tensor) -> Result<usize> {
let sv_data = singular_values.data()?;
let total_variance: f32 = sv_data.iter().sum();
let threshold = 0.95 * total_variance;
let mut cumulative_variance = 0.0;
let mut effective_rank = 0;
for &sv in sv_data.iter() {
cumulative_variance += sv;
effective_rank += 1;
if cumulative_variance >= threshold {
break;
}
}
Ok(effective_rank.min(self.config.lora_rank))
}
fn update_lora_stats(&mut self, base_name: &str, matrix_a: &Tensor, matrix_b: &Tensor) -> Result<()> {
let (_, singular_values, _) = self.compute_svd(matrix_a, matrix_b)?;
let sv_data = singular_values.data()?;
let max_sv = sv_data.iter().fold(0.0f32, |a, &b| a.max(b));
let min_sv = sv_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let condition_number = max_sv / (min_sv + self.config.epsilon);
let effective_rank = self.compute_effective_rank(&singular_values)?;
self.state.singular_values.insert(base_name.to_string(), singular_values);
self.state.condition_numbers.insert(base_name.to_string(), condition_number);
self.state.effective_ranks.insert(base_name.to_string(), effective_rank);
Ok(())
}
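/// Performs one optimization step. Weight decay is folded into the gradient,
/// Adam moments are updated and bias-corrected, and the applied update is
/// `lr * ((1 - s) * adam + s * preconditioner * adam)` (elementwise), where
/// `adam = m_hat / (sqrt(v_hat) + eps)` and `s = preconditioning_strength`.
/// LoRA pair diagnostics are refreshed once per A/B pair, and aggregate
/// transformation statistics every `adaptation_frequency` steps.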
pub fn step(&mut self, parameters: &mut HashMap<String, Tensor>,
gradients: &HashMap<String, Tensor>) -> Result<()> {
self.state.step += 1;
let mut processed_pairs: std::collections::HashSet<String> = std::collections::HashSet::new();
for (param_name, gradient) in gradients.iter() {
if let Some(parameter) = parameters.get_mut(param_name) {
let base_name = self.get_lora_base_name(param_name);
if processed_pairs.contains(&base_name) {
continue;
}
let mut effective_gradient = gradient.clone();
if self.config.weight_decay > 0.0 {
let weight_decay_term = parameter.mul(&Tensor::scalar(self.config.weight_decay)?)?;
effective_gradient = effective_gradient.add(&weight_decay_term)?;
}
let (m, v) = self.update_moments(param_name, &effective_gradient)?;
let corrected_m = self.apply_bias_correction(&m, self.config.beta1)?;
let corrected_v = self.apply_bias_correction(&v, self.config.beta2)?;
let preconditioner = self.compute_lora_preconditioning(param_name, &effective_gradient)?;
let v_sqrt = corrected_v.sqrt()?;
let v_sqrt_eps = v_sqrt.add(&Tensor::scalar(self.config.epsilon)?)?;
let adam_update = corrected_m.div(&v_sqrt_eps)?;
let strength = Tensor::scalar(self.config.preconditioning_strength)?;
let one_minus_strength = Tensor::scalar(1.0 - self.config.preconditioning_strength)?;
let preconditioned_update = adam_update.mul(&strength)?.mul(&preconditioner)?
.add(&adam_update.mul(&one_minus_strength)?)?;
let lr_tensor = Tensor::scalar(self.config.learning_rate)?;
let param_update = preconditioned_update.mul(&lr_tensor)?;
                *parameter = parameter.sub(&param_update)?;
if self.is_lora_a_matrix(param_name) || self.is_lora_b_matrix(param_name) {
let a_name = format!("{}_a", base_name);
let b_name = format!("{}_b", base_name);
if let (Some(matrix_a), Some(matrix_b)) = (parameters.get(&a_name), parameters.get(&b_name)) {
self.update_lora_stats(&base_name, matrix_a, matrix_b)?;
processed_pairs.insert(base_name);
}
}
}
}
if self.state.step % self.config.adaptation_frequency == 0 {
self.update_transformation_stats()?;
}
Ok(())
}
fn update_transformation_stats(&mut self) -> Result<()> {
let mut total_condition_improvement = 0.0;
let mut count = 0;
for &condition_number in self.state.condition_numbers.values() {
if condition_number < self.config.max_condition_number {
total_condition_improvement += 1.0 / condition_number;
count += 1;
}
}
if count > 0 {
self.state.transformation_stats.condition_improvement = total_condition_improvement / count as f32;
self.state.transformation_stats.num_transformations += 1;
}
let mut rank_variance = 0.0;
let ranks: Vec<f32> = self.state.effective_ranks.values().map(|&r| r as f32).collect();
if !ranks.is_empty() {
let mean_rank: f32 = ranks.iter().sum::<f32>() / ranks.len() as f32;
rank_variance = ranks.iter().map(|&r| (r - mean_rank).powi(2)).sum::<f32>() / ranks.len() as f32;
self.state.transformation_stats.rank_stability = 1.0 / (1.0 + rank_variance.sqrt());
}
Ok(())
}
pub fn get_lora_stats(&self) -> LoRARITEStats {
let avg_condition_number = if self.state.condition_numbers.is_empty() {
1.0
} else {
self.state.condition_numbers.values().sum::<f32>() / self.state.condition_numbers.len() as f32
};
let avg_effective_rank = if self.state.effective_ranks.is_empty() {
self.config.lora_rank
} else {
self.state.effective_ranks.values().sum::<usize>() / self.state.effective_ranks.len()
};
LoRARITEStats {
step: self.state.step,
avg_condition_number,
avg_effective_rank,
num_lora_pairs: self.state.singular_values.len(),
transformation_invariance_score: self.state.transformation_stats.condition_improvement,
rank_stability: self.state.transformation_stats.rank_stability,
preconditioning_effectiveness: self.state.transformation_stats.preconditioning_gain,
}
}
pub fn reset_state(&mut self) {
self.state = LoRARITEState::default();
}
pub fn get_condition_numbers(&self) -> &HashMap<String, f32> {
&self.state.condition_numbers
}
pub fn get_effective_ranks(&self) -> &HashMap<String, usize> {
&self.state.effective_ranks
}
}
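/// Summary statistics returned by [`LoRARITE::get_lora_stats`].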
#[derive(Debug, Clone)]
pub struct LoRARITEStats {
pub step: u64,
pub avg_condition_number: f32,
pub avg_effective_rank: usize,
pub num_lora_pairs: usize,
pub transformation_invariance_score: f32,
pub rank_stability: f32,
pub preconditioning_effectiveness: f32,
}
#[cfg(test)]
mod tests {
use super::*;
use trustformers_core::tensor::Tensor;
#[test]
fn test_lora_rite_creation() {
let config = LoRARITEConfig::new()
.learning_rate(1e-3)
.lora_rank(16)
.beta1(0.9)
.build();
let optimizer = LoRARITE::new(config);
assert_eq!(optimizer.learning_rate(), 1e-3);
}
#[test]
fn test_lora_rite_config_builder() {
let config = LoRARITEConfig::new()
.learning_rate(2e-3)
.lora_rank(32)
.beta1(0.95)
.beta2(0.999)
.preconditioning_strength(0.2)
.weight_decay(1e-4)
.build();
assert_eq!(config.learning_rate, 2e-3);
assert_eq!(config.lora_rank, 32);
assert_eq!(config.beta1, 0.95);
assert_eq!(config.beta2, 0.999);
assert_eq!(config.preconditioning_strength, 0.2);
assert_eq!(config.weight_decay, 1e-4);
}
#[test]
fn test_lora_matrix_detection() {
let config = LoRARITEConfig::new().build();
let optimizer = LoRARITE::new(config);
assert!(optimizer.is_lora_a_matrix("layer1_a"));
assert!(optimizer.is_lora_b_matrix("layer1_b"));
assert!(optimizer.is_lora_a_matrix("attention.lora_a"));
assert!(optimizer.is_lora_b_matrix("attention.lora_b"));
assert!(!optimizer.is_lora_a_matrix("layer1.weight"));
}
#[test]
fn test_lora_base_name_extraction() {
let config = LoRARITEConfig::new().build();
let optimizer = LoRARITE::new(config);
assert_eq!(optimizer.get_lora_base_name("layer1_a"), "layer1");
assert_eq!(optimizer.get_lora_base_name("layer1_b"), "layer1");
assert_eq!(optimizer.get_lora_base_name("attention.lora_a"), "attention.lora");
assert_eq!(optimizer.get_lora_base_name("attention.lora_b"), "attention.lora");
}
#[test]
fn test_lora_rite_step() -> Result<()> {
let config = LoRARITEConfig::new()
.learning_rate(1e-2)
.lora_rank(4)
.build();
let mut optimizer = LoRARITE::new(config);
let mut parameters = HashMap::new();
parameters.insert("layer1_a".to_string(), Tensor::ones(&[4, 8])?); parameters.insert("layer1_b".to_string(), Tensor::ones(&[2, 4])?);
let mut gradients = HashMap::new();
gradients.insert("layer1_a".to_string(), Tensor::ones(&[4, 8])? * 0.1);
gradients.insert("layer1_b".to_string(), Tensor::ones(&[2, 4])? * 0.1);
let orig_a = parameters.get("layer1_a").expect("Key not found").clone();
let orig_b = parameters.get("layer1_b").expect("Key not found").clone();
optimizer.step(&mut parameters, &gradients)?;
let updated_a = parameters.get("layer1_a").expect("Key not found");
let updated_b = parameters.get("layer1_b").expect("Key not found");
assert_ne!(updated_a.mean()?.to_scalar::<f32>()?, orig_a.mean()?.to_scalar::<f32>()?);
assert_ne!(updated_b.mean()?.to_scalar::<f32>()?, orig_b.mean()?.to_scalar::<f32>()?);
Ok(())
}
#[test]
fn test_moment_updates() -> Result<()> {
let config = LoRARITEConfig::new().build();
let mut optimizer = LoRARITE::new(config);
let gradient = Tensor::ones(&[2, 2])? * 0.5;
let (m1, v1) = optimizer.update_moments("test_a", &gradient)?;
let (m2, v2) = optimizer.update_moments("test_a", &gradient)?;
assert_ne!(m1.mean()?.to_scalar::<f32>()?, m2.mean()?.to_scalar::<f32>()?);
assert_ne!(v1.mean()?.to_scalar::<f32>()?, v2.mean()?.to_scalar::<f32>()?);
Ok(())
}
#[test]
fn test_bias_correction() -> Result<()> {
let config = LoRARITEConfig::new().bias_correction(true).build();
let optimizer = LoRARITE::new(config);
let moment = Tensor::ones(&[2, 2])? * 0.5;
let beta = 0.9;
let corrected = optimizer.apply_bias_correction(&moment, beta)?;
assert!(corrected.mean()?.to_scalar::<f32>()? > moment.mean()?.to_scalar::<f32>()?);
Ok(())
}
#[test]
fn test_lora_stats() -> Result<()> {
let config = LoRARITEConfig::new().lora_rank(4).build();
let mut optimizer = LoRARITE::new(config);
optimizer.state.condition_numbers.insert("layer1".to_string(), 2.5);
optimizer.state.condition_numbers.insert("layer2".to_string(), 3.0);
optimizer.state.effective_ranks.insert("layer1".to_string(), 3);
optimizer.state.effective_ranks.insert("layer2".to_string(), 4);
let stats = optimizer.get_lora_stats();
        assert_eq!(stats.num_lora_pairs, 0);
        assert_eq!(stats.avg_condition_number, 2.75);
        assert_eq!(stats.avg_effective_rank, 3);
Ok(())
}
#[test]
fn test_learning_rate_methods() {
let config = LoRARITEConfig::new().learning_rate(1e-3).build();
let mut optimizer = LoRARITE::new(config);
assert_eq!(optimizer.learning_rate(), 1e-3);
optimizer.set_learning_rate(2e-3);
assert_eq!(optimizer.learning_rate(), 2e-3);
}
#[test]
fn test_weight_decay() -> Result<()> {
let config = LoRARITEConfig::new()
.learning_rate(1e-2)
.weight_decay(1e-2)
.build();
let mut optimizer = LoRARITE::new(config);
let mut parameters = HashMap::new();
parameters.insert("layer1_a".to_string(), Tensor::ones(&[2, 2])?);
let mut gradients = HashMap::new();
gradients.insert("layer1_a".to_string(), Tensor::zeros(&[2, 2])?);
let initial_param_value = parameters.get("layer1_a").expect("Key not found").mean()?.to_scalar::<f32>()?;
optimizer.step(&mut parameters, &gradients)?;
let final_param_value = parameters.get("layer1_a").expect("Key not found").mean()?.to_scalar::<f32>()?;
assert!(final_param_value < initial_param_value);
Ok(())
}
#[test]
fn test_transformation_invariance() -> Result<()> {
let config = LoRARITEConfig::new()
.transformation_invariance(true)
.build();
let optimizer = LoRARITE::new(config);
let preconditioner = Tensor::ones(&[2, 2])? * 2.0;
let transformed = optimizer.apply_transformation_invariance(&preconditioner)?;
let result_value = transformed.mean()?.to_scalar::<f32>()?;
assert!(result_value > 0.0);
assert!(result_value.is_finite());
Ok(())
}
}