//! LoRA-RITE optimizer: an Adam-style update with an extra step-size boost
//! for LoRA factor matrices, plus bookkeeping for rank and conditioning
//! diagnostics.
use anyhow::Result;
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;
/// Configuration for the [`LoRARITE`] optimizer.
#[derive(Debug, Clone)]
pub struct LoRARITEConfig {
    /// Base step size for the Adam-style update.
    pub learning_rate: f32,
    /// Rank of the LoRA factors; also the fallback effective rank in stats.
    pub lora_rank: usize,
    /// Decay rate for the first-moment (mean) estimate.
    pub beta1: f32,
    /// Decay rate for the second-moment (uncentered variance) estimate.
    pub beta2: f32,
    /// Numerical floor added to the denominator of the update.
    pub epsilon: f32,
    /// L2 penalty folded into the gradient (coupled, classic-Adam style).
    pub weight_decay: f32,
    /// Extra step-size multiplier (`1 + strength`) applied only to LoRA factor matrices.
    pub preconditioning_strength: f32,
    /// Whether to bias-correct the moment estimates, as in Adam.
    pub bias_correction: bool,
    // The remaining knobs are reserved for SVD-based invariant preconditioning
    // and are not yet consulted by `step`.
    pub transformation_invariance: bool,
    pub adaptation_frequency: u64,
    pub min_singular_value: f32,
    pub max_condition_number: f32,
    pub adaptive_rank: bool,
    pub factorization_reg: f32,
}
impl Default for LoRARITEConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
lora_rank: 16,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
weight_decay: 0.0,
preconditioning_strength: 0.1,
bias_correction: true,
transformation_invariance: true,
adaptation_frequency: 10,
min_singular_value: 1e-6,
max_condition_number: 1e6,
adaptive_rank: false,
factorization_reg: 1e-6,
}
}
}
impl LoRARITEConfig {
    /// Start a builder chain from the [`Default`] values.
    pub fn new() -> Self {
        Self::default()
    }
pub fn learning_rate(mut self, lr: f32) -> Self {
self.learning_rate = lr;
self
}
pub fn lora_rank(mut self, rank: usize) -> Self {
self.lora_rank = rank;
self
}
pub fn beta1(mut self, beta1: f32) -> Self {
self.beta1 = beta1;
self
}
pub fn beta2(mut self, beta2: f32) -> Self {
self.beta2 = beta2;
self
}
pub fn preconditioning_strength(mut self, strength: f32) -> Self {
self.preconditioning_strength = strength;
self
}
pub fn weight_decay(mut self, decay: f32) -> Self {
self.weight_decay = decay;
self
}
pub fn transformation_invariance(mut self, enable: bool) -> Self {
self.transformation_invariance = enable;
self
}
    /// Finish the builder chain. A no-op that returns the config unchanged,
    /// kept so call sites read naturally.
    pub fn build(self) -> Self {
        self
    }
}
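/// Mutable optimizer state: step counter, per-parameter Adam moment buffers,
/// and per-parameter conditioning/rank diagnostics, all keyed by parameter name.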
#[derive(Debug, Clone, Default)]
pub struct LoRARITEState {
pub step: u64,
pub m_buffers: HashMap<String, Vec<f32>>,
pub v_buffers: HashMap<String, Vec<f32>>,
pub condition_numbers: HashMap<String, f32>,
pub effective_ranks: HashMap<String, usize>,
}
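/// Aggregate statistics about invariant transformations. Not currently
/// populated by this module.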
#[derive(Debug, Clone, Default)]
pub struct TransformationStats {
pub num_transformations: u64,
pub condition_improvement: f32,
pub rank_stability: f32,
pub preconditioning_gain: f32,
}
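/// LoRA-RITE optimizer: element-wise Adam plus a preconditioning boost for
/// parameters recognized (by name) as LoRA factor matrices.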
pub struct LoRARITE {
config: LoRARITEConfig,
state: LoRARITEState,
}
impl LoRARITE {
    /// Create an optimizer with the given configuration and empty state.
    pub fn new(config: LoRARITEConfig) -> Self {
        Self {
            config,
            state: LoRARITEState::default(),
        }
    }
    /// Current base learning rate.
    pub fn learning_rate(&self) -> f32 {
        self.config.learning_rate
    }
    /// Override the base learning rate, e.g. from an external scheduler.
    pub fn set_learning_rate(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }
    /// Heuristic name test for LoRA "A" factors. Guards against `lora_alpha`,
    /// which would otherwise match the `lora_a` substring check.
    fn is_lora_a_matrix(&self, param_name: &str) -> bool {
        param_name.ends_with("_a")
            || (param_name.contains("lora_a") && !param_name.contains("lora_alpha"))
            || param_name.contains("lora_A")
    }
    /// Heuristic name test for LoRA "B" factors.
    fn is_lora_b_matrix(&self, param_name: &str) -> bool {
        param_name.ends_with("_b") || param_name.contains("lora_b") || param_name.contains("lora_B")
    }
    /// Perform one optimization step over all `(parameter, gradient)` pairs.
    ///
    /// The update is element-wise Adam with an extra step-size multiplier
    /// (`1 + preconditioning_strength`) applied to parameters recognized as
    /// LoRA factor matrices. Bias correction is applied only when
    /// `config.bias_correction` is set.
    pub fn step(
        &mut self,
        parameters: &mut HashMap<String, Tensor>,
        gradients: &HashMap<String, Tensor>,
    ) -> Result<()> {
        self.state.step += 1;
        for (param_name, gradient) in gradients.iter() {
            if let Some(parameter) = parameters.get_mut(param_name) {
                let param_data = parameter.data()?;
                let grad_data = gradient.data()?;
                let is_lora_matrix =
                    self.is_lora_a_matrix(param_name) || self.is_lora_b_matrix(param_name);
                // LoRA factors receive a boosted step size; everything else
                // falls back to plain Adam.
                let preconditioning_factor =
                    if is_lora_matrix { 1.0 + self.config.preconditioning_strength } else { 1.0 };
                // Lazily allocate the moment buffers on first sight of a parameter.
                if !self.state.m_buffers.contains_key(param_name) {
                    self.state.m_buffers.insert(param_name.clone(), vec![0.0; param_data.len()]);
                    self.state.v_buffers.insert(param_name.clone(), vec![0.0; param_data.len()]);
                }
                let m_buffer = self
                    .state
                    .m_buffers
                    .get_mut(param_name)
                    .expect("m_buffer should exist after initialization");
                let v_buffer = self
                    .state
                    .v_buffers
                    .get_mut(param_name)
                    .expect("v_buffer should exist after initialization");
                // Bias-correction denominators are constant across elements,
                // so compute them once per parameter.
                let step = self.state.step as f32;
                let (bc1, bc2) = if self.config.bias_correction {
                    (
                        1.0 - self.config.beta1.powf(step),
                        1.0 - self.config.beta2.powf(step),
                    )
                } else {
                    (1.0, 1.0)
                };
                let mut updated_params = param_data.clone();
                for i in 0..param_data.len() {
                    // Coupled L2 weight decay folded into the gradient.
                    let effective_grad = if self.config.weight_decay > 0.0 {
                        grad_data[i] + self.config.weight_decay * param_data[i]
                    } else {
                        grad_data[i]
                    };
                    // Exponential moving averages of the gradient and its square.
                    m_buffer[i] = self.config.beta1 * m_buffer[i]
                        + (1.0 - self.config.beta1) * effective_grad;
                    v_buffer[i] = self.config.beta2 * v_buffer[i]
                        + (1.0 - self.config.beta2) * effective_grad * effective_grad;
                    let m_hat = m_buffer[i] / bc1;
                    let v_hat = v_buffer[i] / bc2;
                    updated_params[i] = param_data[i]
                        - self.config.learning_rate * preconditioning_factor * m_hat
                            / (v_hat.sqrt() + self.config.epsilon);
                }
                *parameter = Tensor::new(updated_params)?;
                if is_lora_matrix {
                    // Placeholder diagnostics: a full RITE implementation would
                    // derive these from the SVD of the LoRA factors.
                    self.state.condition_numbers.insert(param_name.clone(), 1.5);
                    self.state
                        .effective_ranks
                        .insert(param_name.clone(), self.config.lora_rank.min(8));
                }
            }
        }
        Ok(())
    }
    /// Summarize optimizer diagnostics across all tracked LoRA pairs.
    ///
    /// Only `step`, the averages, and the pair count are derived from live
    /// state; the invariance, stability, and effectiveness scores are
    /// currently fixed placeholder values.
    pub fn get_lora_stats(&self) -> LoRARITEStats {
        let avg_condition_number = if self.state.condition_numbers.is_empty() {
            1.0
        } else {
            self.state.condition_numbers.values().sum::<f32>()
                / self.state.condition_numbers.len() as f32
        };
        let avg_effective_rank = if self.state.effective_ranks.is_empty() {
            self.config.lora_rank
        } else {
            // Integer division: reports the floor of the mean effective rank.
            self.state.effective_ranks.values().sum::<usize>() / self.state.effective_ranks.len()
        };
        LoRARITEStats {
            step: self.state.step,
            avg_condition_number,
            avg_effective_rank,
            num_lora_pairs: self.state.condition_numbers.len(),
            // Placeholder scores pending a real invariance/rank analysis.
            transformation_invariance_score: 0.9,
            rank_stability: 0.95,
            preconditioning_effectiveness: 1.2,
        }
    }
    /// Clear all optimizer state (moments, diagnostics, step count).
    pub fn reset_state(&mut self) {
        self.state = LoRARITEState::default();
    }
    /// Per-parameter condition-number estimates for tracked LoRA matrices.
    pub fn get_condition_numbers(&self) -> &HashMap<String, f32> {
        &self.state.condition_numbers
    }
    /// Per-parameter effective-rank estimates for tracked LoRA matrices.
    pub fn get_effective_ranks(&self) -> &HashMap<String, usize> {
        &self.state.effective_ranks
    }
}
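/// Snapshot of optimizer diagnostics returned by [`LoRARITE::get_lora_stats`].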
#[derive(Debug, Clone)]
pub struct LoRARITEStats {
pub step: u64,
pub avg_condition_number: f32,
pub avg_effective_rank: usize,
pub num_lora_pairs: usize,
pub transformation_invariance_score: f32,
pub rank_stability: f32,
pub preconditioning_effectiveness: f32,
}
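
// A minimal usage sketch: one smoke test for the config builder and one for a
// single optimizer step. These assume `Tensor::new(Vec<f32>) -> Result<Tensor>`
// and `Tensor::data() -> Result<Vec<f32>>`, matching how `Tensor` is used in
// `step` above; adjust the constructors if the actual tensor API differs.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_sets_fields() {
        let config = LoRARITEConfig::new()
            .learning_rate(5e-4)
            .lora_rank(8)
            .preconditioning_strength(0.2)
            .build();
        assert_eq!(config.lora_rank, 8);
        assert_eq!(config.preconditioning_strength, 0.2);
        // Untouched fields keep their defaults.
        assert!(config.bias_correction);
    }

    #[test]
    fn step_updates_lora_parameter() -> Result<()> {
        let mut optimizer = LoRARITE::new(LoRARITEConfig::default());
        let name = "layer0.lora_a".to_string();
        let mut parameters = HashMap::new();
        parameters.insert(name.clone(), Tensor::new(vec![1.0, -1.0])?);
        let mut gradients = HashMap::new();
        gradients.insert(name.clone(), Tensor::new(vec![0.5, -0.5])?);

        optimizer.step(&mut parameters, &gradients)?;

        // The parameter must move against the gradient, and the name must be
        // recognized as a LoRA matrix and tracked in the diagnostics.
        let updated = parameters[name.as_str()].data()?;
        assert!(updated[0] < 1.0 && updated[1] > -1.0);
        assert_eq!(optimizer.get_lora_stats().num_lora_pairs, 1);
        Ok(())
    }
}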