use anyhow::Result;
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;
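
/// Configuration for the [`GENIE`] optimizer.
///
/// The OSGR fields control the squared-gradient / loss-reduction ratio used
/// for preconditioning: `osgr_momentum` smooths it with an EMA,
/// `min_osgr`/`max_osgr` clamp it, and `normalize_osgr` divides it by the
/// mean OSGR across parameters. `alignment_weight` scales the bonus applied
/// to updates with consistent gradient direction across steps, and OSGR
/// preconditioning only activates after `warmup_steps` plain-SGD steps.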
#[derive(Debug, Clone)]
pub struct GENIEConfig {
    pub learning_rate: f32,
    pub osgr_momentum: f32,
    pub alignment_weight: f32,
    pub preconditioning_eps: f32,
    pub min_osgr: f32,
    pub max_osgr: f32,
    pub adaptive_alignment: bool,
    pub weight_decay: f32,
    pub normalize_osgr: bool,
    pub warmup_steps: u64,
}

impl Default for GENIEConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            osgr_momentum: 0.9,
            alignment_weight: 0.1,
            preconditioning_eps: 1e-8,
            min_osgr: 1e-6,
            max_osgr: 1e6,
            adaptive_alignment: true,
            weight_decay: 0.0,
            normalize_osgr: true,
            warmup_steps: 100,
        }
    }
}

impl GENIEConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn osgr_momentum(mut self, momentum: f32) -> Self {
        self.osgr_momentum = momentum;
        self
    }

    pub fn alignment_weight(mut self, weight: f32) -> Self {
        self.alignment_weight = weight;
        self
    }

    pub fn preconditioning_eps(mut self, eps: f32) -> Self {
        self.preconditioning_eps = eps;
        self
    }

    pub fn weight_decay(mut self, decay: f32) -> Self {
        self.weight_decay = decay;
        self
    }

    // Required by the warmup tests below; this builder method was missing.
    pub fn warmup_steps(mut self, steps: u64) -> Self {
        self.warmup_steps = steps;
        self
    }

    pub fn build(self) -> Self {
        self
    }
}
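
/// Per-parameter running state tracked across optimization steps, keyed by
/// parameter name.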
#[derive(Debug, Clone, Default)]
pub struct GENIEState {
    pub step: u64,
    pub osgr: HashMap<String, Tensor>,
    pub osgr_ema: HashMap<String, Tensor>,
    pub prev_gradients: HashMap<String, Tensor>,
    pub prev_loss: Option<f32>,
    pub alignment_stats: HashMap<String, f32>,
    pub preconditioning_factors: HashMap<String, Tensor>,
    pub domain_stats: DomainStats,
}
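
/// Cross-domain training statistics. Currently only surfaced through
/// [`GENIE::get_domain_stats`] and [`GENIEStats`]; nothing in this module
/// updates it yet.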
#[derive(Debug, Clone, Default)]
pub struct DomainStats {
    pub domain_losses: Vec<f32>,
    pub domain_variance: f32,
    pub cross_domain_alignment: f32,
}
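
/// GENIE optimizer.
///
/// Preconditions each parameter's gradient by the inverse square root of its
/// observed squared-gradient / loss-reduction ratio (OSGR) and scales updates
/// by how consistently consecutive gradients align. A minimal usage sketch,
/// mirroring the tests below (`parameters` and `gradients` are
/// `HashMap<String, Tensor>` maps):
///
/// ```ignore
/// let mut optimizer = GENIE::new(GENIEConfig::new().learning_rate(1e-3).build());
/// optimizer.step(&mut parameters, &gradients, current_loss)?;
/// ```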
pub struct GENIE {
    config: GENIEConfig,
    state: GENIEState,
}

impl GENIE {
    pub fn new(config: GENIEConfig) -> Self {
        Self {
            config,
            state: GENIEState::default(),
        }
    }

    pub fn learning_rate(&self) -> f32 {
        self.config.learning_rate
    }

    pub fn set_learning_rate(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }
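
    /// Element-wise OSGR: the squared gradient divided by the observed loss
    /// reduction since the previous step. Large values mean the gradients are
    /// big relative to the loss improvement they produced, so the
    /// preconditioner will shrink those updates. Before any loss history
    /// exists, the raw squared gradient is returned.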
    fn compute_osgr(&self, param_name: &str, gradient: &Tensor, current_loss: f32) -> Result<Tensor> {
        let grad_sq = gradient.pow(&Tensor::scalar(2.0)?)?;
        if let Some(prev_loss) = self.state.prev_loss {
            // Floor the loss reduction at eps so the ratio stays finite when
            // the loss plateaus or increases.
            let loss_reduction = (prev_loss - current_loss).max(self.config.preconditioning_eps);
            let osgr = grad_sq.div(&Tensor::scalar(loss_reduction)?)?;
            let min_osgr = Tensor::scalar(self.config.min_osgr)?;
            let max_osgr = Tensor::scalar(self.config.max_osgr)?;
            Ok(osgr.clamp(&min_osgr, &max_osgr)?)
        } else {
            // No loss history yet: fall back to the raw squared gradient.
            Ok(grad_sq)
        }
    }
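
    /// Exponential moving average of the OSGR:
    /// `ema = momentum * ema + (1 - momentum) * osgr`, seeded with the first
    /// observation.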
    fn update_osgr_ema(&mut self, param_name: &str, current_osgr: &Tensor) -> Result<()> {
        let momentum = self.config.osgr_momentum;
        if let Some(prev_ema) = self.state.osgr_ema.get(param_name) {
            let momentum_tensor = Tensor::scalar(momentum)?;
            let one_minus_momentum = Tensor::scalar(1.0 - momentum)?;
            let weighted_prev = prev_ema.mul(&momentum_tensor)?;
            let weighted_current = current_osgr.mul(&one_minus_momentum)?;
            let new_ema = weighted_prev.add(&weighted_current)?;
            self.state.osgr_ema.insert(param_name.to_string(), new_ema);
        } else {
            self.state.osgr_ema.insert(param_name.to_string(), current_osgr.clone());
        }
        Ok(())
    }
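
    /// Cosine similarity between the current gradient and the previous step's
    /// gradient for the same parameter, clamped to `[-1, 1]`; returns 0.0
    /// when there is no history.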
    fn compute_gradient_alignment(&self, param_name: &str, gradient: &Tensor) -> Result<f32> {
        if let Some(prev_grad) = self.state.prev_gradients.get(param_name) {
            let current = gradient.flatten()?;
            let previous = prev_grad.flatten()?;
            let dot_product = current.dot(&previous)?.to_scalar::<f32>()?;
            let current_norm = current.norm()?.to_scalar::<f32>()?;
            let prev_norm = previous.norm()?.to_scalar::<f32>()?;
            let alignment = dot_product / (current_norm * prev_norm + self.config.preconditioning_eps);
            Ok(alignment.clamp(-1.0, 1.0))
        } else {
            Ok(0.0)
        }
    }
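
    /// Inverse-square-root preconditioner, `1 / sqrt(osgr_ema)`, optionally
    /// normalized by the mean OSGR across all parameters. Falls back to 1.0
    /// for parameters with no EMA yet.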
    fn compute_preconditioning_factor(&self, param_name: &str) -> Result<Tensor> {
        if let Some(osgr_ema) = self.state.osgr_ema.get(param_name) {
            if self.config.normalize_osgr {
                let mean_osgr = self.compute_mean_osgr()?;
                let normalized_osgr =
                    osgr_ema.div(&Tensor::scalar(mean_osgr + self.config.preconditioning_eps)?)?;
                normalized_osgr.sqrt()?.reciprocal()
            } else {
                osgr_ema.sqrt()?.reciprocal()
            }
        } else {
            Ok(Tensor::scalar(1.0)?)
        }
    }
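
    /// Mean of the per-parameter OSGR EMA means; 1.0 when no state exists.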
    fn compute_mean_osgr(&self) -> Result<f32> {
        if self.state.osgr_ema.is_empty() {
            return Ok(1.0);
        }
        let mut total_osgr = 0.0;
        let mut count = 0;
        for osgr_tensor in self.state.osgr_ema.values() {
            total_osgr += osgr_tensor.mean()?.to_scalar::<f32>()?;
            count += 1;
        }
        Ok(total_osgr / count as f32)
    }
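
    /// Alignment weight, ramped linearly over the warmup period when
    /// `adaptive_alignment` is enabled.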
    fn compute_adaptive_alignment_weight(&self) -> f32 {
        if !self.config.adaptive_alignment {
            return self.config.alignment_weight;
        }
        let progress = (self.state.step as f32 / (self.config.warmup_steps as f32 + 1.0)).min(1.0);
        self.config.alignment_weight * progress
    }
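
    /// Applies one optimization step in place.
    ///
    /// Per parameter: weight decay is folded into the gradient, and after
    /// `warmup_steps` the gradient is preconditioned by `1/sqrt(OSGR EMA)`
    /// and scaled by `1 + alignment_weight * |alignment|` before the SGD
    /// update `param -= lr * update`. During warmup this reduces to plain
    /// SGD (plus weight decay).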
    pub fn step(
        &mut self,
        parameters: &mut HashMap<String, Tensor>,
        gradients: &HashMap<String, Tensor>,
        current_loss: f32,
    ) -> Result<()> {
        self.state.step += 1;
        let use_osgr = self.state.step > self.config.warmup_steps;
        for (param_name, gradient) in gradients.iter() {
            if let Some(parameter) = parameters.get_mut(param_name) {
                // L2 weight decay folded directly into the gradient.
                let mut effective_gradient = gradient.clone();
                if self.config.weight_decay > 0.0 {
                    let weight_decay_term =
                        parameter.mul(&Tensor::scalar(self.config.weight_decay)?)?;
                    effective_gradient = effective_gradient.add(&weight_decay_term)?;
                }
                let mut update = effective_gradient.clone();
                if use_osgr {
                    let osgr = self.compute_osgr(param_name, &effective_gradient, current_loss)?;
                    self.state.osgr.insert(param_name.to_string(), osgr.clone());
                    self.update_osgr_ema(param_name, &osgr)?;
                    let alignment =
                        self.compute_gradient_alignment(param_name, &effective_gradient)?;
                    self.state.alignment_stats.insert(param_name.to_string(), alignment);
                    // Precondition by 1/sqrt(OSGR EMA).
                    let preconditioning = self.compute_preconditioning_factor(param_name)?;
                    self.state
                        .preconditioning_factors
                        .insert(param_name.to_string(), preconditioning.clone());
                    update = effective_gradient.mul(&preconditioning)?;
                    // Scale by 1 + w * |alignment|; only the magnitude of the
                    // alignment is used, so strongly anti-aligned gradients
                    // are boosted as well.
                    let alignment_weight = self.compute_adaptive_alignment_weight();
                    if alignment_weight > 0.0 {
                        let alignment_factor = 1.0 + alignment_weight * alignment.abs();
                        update = update.mul(&Tensor::scalar(alignment_factor)?)?;
                    }
                }
                let lr_tensor = Tensor::scalar(self.config.learning_rate)?;
                let param_update = update.mul(&lr_tensor)?;
                *parameter = parameter.sub(&param_update)?;
                self.state
                    .prev_gradients
                    .insert(param_name.to_string(), effective_gradient);
            }
        }
        self.state.prev_loss = Some(current_loss);
        Ok(())
    }
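
    /// Mean OSGR EMA per parameter, e.g. for logging.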
    pub fn get_osgr_stats(&self) -> HashMap<String, f32> {
        let mut stats = HashMap::new();
        for (param_name, osgr_tensor) in &self.state.osgr_ema {
            if let Ok(mean_osgr) = osgr_tensor.mean().and_then(|t| t.to_scalar::<f32>()) {
                stats.insert(param_name.clone(), mean_osgr);
            }
        }
        stats
    }

    pub fn get_alignment_stats(&self) -> &HashMap<String, f32> {
        &self.state.alignment_stats
    }

    pub fn get_domain_stats(&self) -> &DomainStats {
        &self.state.domain_stats
    }

    pub fn reset_state(&mut self) {
        self.state = GENIEState::default();
    }
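
    /// Snapshot of aggregate optimizer statistics.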
    pub fn get_stats(&self) -> GENIEStats {
        let mean_osgr = self.compute_mean_osgr().unwrap_or(1.0);
        let mean_alignment = self.state.alignment_stats.values().sum::<f32>()
            / self.state.alignment_stats.len().max(1) as f32;
        GENIEStats {
            step: self.state.step,
            mean_osgr,
            mean_alignment,
            num_parameters: self.state.osgr_ema.len(),
            adaptive_alignment_weight: self.compute_adaptive_alignment_weight(),
            domain_variance: self.state.domain_stats.domain_variance,
        }
    }
}
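
/// Aggregate statistics returned by [`GENIE::get_stats`].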
#[derive(Debug, Clone)]
pub struct GENIEStats {
    pub step: u64,
    pub mean_osgr: f32,
    pub mean_alignment: f32,
    pub num_parameters: usize,
    pub adaptive_alignment_weight: f32,
    pub domain_variance: f32,
}

#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::tensor::Tensor;

    #[test]
    fn test_genie_creation() {
        let config = GENIEConfig::new()
            .learning_rate(1e-3)
            .osgr_momentum(0.9)
            .alignment_weight(0.1)
            .build();
        let optimizer = GENIE::new(config);
        assert_eq!(optimizer.learning_rate(), 1e-3);
    }

    #[test]
    fn test_genie_config_builder() {
        let config = GENIEConfig::new()
            .learning_rate(2e-3)
            .osgr_momentum(0.95)
            .alignment_weight(0.2)
            .preconditioning_eps(1e-6)
            .weight_decay(1e-4)
            .build();
        assert_eq!(config.learning_rate, 2e-3);
        assert_eq!(config.osgr_momentum, 0.95);
        assert_eq!(config.alignment_weight, 0.2);
        assert_eq!(config.preconditioning_eps, 1e-6);
        assert_eq!(config.weight_decay, 1e-4);
    }

    #[test]
    fn test_genie_step() -> Result<()> {
        let config = GENIEConfig::new().learning_rate(1e-2).build();
        let mut optimizer = GENIE::new(config);
        let mut parameters = HashMap::new();
        parameters.insert("weight".to_string(), Tensor::ones(&[2, 2])?);
        let mut gradients = HashMap::new();
        gradients.insert(
            "weight".to_string(),
            Tensor::ones(&[2, 2])?.mul(&Tensor::scalar(0.1)?)?,
        );
        let initial_loss = 1.0;
        optimizer.step(&mut parameters, &gradients, initial_loss)?;
        let updated_param = parameters.get("weight").expect("Key not found");
        // Step 1 falls inside the warmup window, so this is plain SGD:
        // 1.0 - lr * grad = 1.0 - 1e-2 * 0.1.
        let expected_value = 1.0 - 1e-2 * 0.1;
        let actual = updated_param.mean()?.to_scalar::<f32>()?;
        assert!((actual - expected_value).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_genie_osgr_computation() -> Result<()> {
        let config = GENIEConfig::new().build();
        let mut optimizer = GENIE::new(config);
        optimizer.state.prev_loss = Some(2.0);
        let gradient = Tensor::ones(&[2, 2])?;
        let current_loss = 1.5;
        let osgr = optimizer.compute_osgr("test", &gradient, current_loss)?;
        // The OSGR is element-wise: g^2 / (prev_loss - current_loss) = 1.0 / 0.5.
        let expected_osgr = 1.0 / 0.5;
        let computed_osgr = osgr.mean()?.to_scalar::<f32>()?;
        assert!((computed_osgr - expected_osgr).abs() < 1e-5);
        Ok(())
    }

    #[test]
    fn test_genie_gradient_alignment() -> Result<()> {
        let config = GENIEConfig::new().build();
        let mut optimizer = GENIE::new(config);
        let prev_grad = Tensor::ones(&[2, 2])?;
        optimizer.state.prev_gradients.insert("test".to_string(), prev_grad);
        let current_grad = Tensor::ones(&[2, 2])?;
        let alignment = optimizer.compute_gradient_alignment("test", &current_grad)?;
        assert!((alignment - 1.0).abs() < 1e-5);
        let opposite_grad = Tensor::ones(&[2, 2])?.mul(&Tensor::scalar(-1.0)?)?;
        let alignment = optimizer.compute_gradient_alignment("test", &opposite_grad)?;
        assert!((alignment - (-1.0)).abs() < 1e-5);
        Ok(())
    }

    #[test]
    fn test_genie_stats() -> Result<()> {
        let config = GENIEConfig::new().warmup_steps(1).build();
        let mut optimizer = GENIE::new(config);
        let mut parameters = HashMap::new();
        parameters.insert("weight".to_string(), Tensor::ones(&[2, 2])?);
        let mut gradients = HashMap::new();
        gradients.insert(
            "weight".to_string(),
            Tensor::ones(&[2, 2])?.mul(&Tensor::scalar(0.1)?)?,
        );
        for i in 0..5 {
            let loss = 1.0 - i as f32 * 0.1;
            optimizer.step(&mut parameters, &gradients, loss)?;
        }
        let stats = optimizer.get_stats();
        assert_eq!(stats.step, 5);
        assert!(stats.num_parameters > 0);
        assert!(stats.mean_osgr > 0.0);
        Ok(())
    }

    #[test]
    fn test_genie_learning_rate_methods() {
        let config = GENIEConfig::new().learning_rate(1e-3).build();
        let mut optimizer = GENIE::new(config);
        assert_eq!(optimizer.learning_rate(), 1e-3);
        optimizer.set_learning_rate(2e-3);
        assert_eq!(optimizer.learning_rate(), 2e-3);
    }

    #[test]
    fn test_genie_weight_decay() -> Result<()> {
        let config = GENIEConfig::new()
            .learning_rate(1e-2)
            .weight_decay(1e-2)
            .build();
        let mut optimizer = GENIE::new(config);
        let mut parameters = HashMap::new();
        parameters.insert("weight".to_string(), Tensor::ones(&[2, 2])?);
        let mut gradients = HashMap::new();
        gradients.insert("weight".to_string(), Tensor::zeros(&[2, 2])?);
        let initial_param_value = parameters
            .get("weight")
            .expect("Key not found")
            .mean()?
            .to_scalar::<f32>()?;
        optimizer.step(&mut parameters, &gradients, 1.0)?;
        let final_param_value = parameters
            .get("weight")
            .expect("Key not found")
            .mean()?
            .to_scalar::<f32>()?;
        // With zero gradients, the entire update comes from weight decay.
        assert!(final_param_value < initial_param_value);
        Ok(())
    }

    #[test]
    fn test_genie_warmup() -> Result<()> {
        let config = GENIEConfig::new()
            .learning_rate(1e-2)
            .warmup_steps(5)
            .build();
        let mut optimizer = GENIE::new(config);
        let mut parameters = HashMap::new();
        parameters.insert("weight".to_string(), Tensor::ones(&[2, 2])?);
        let mut gradients = HashMap::new();
        gradients.insert(
            "weight".to_string(),
            Tensor::ones(&[2, 2])?.mul(&Tensor::scalar(0.1)?)?,
        );
        // During warmup, no OSGR state should be recorded.
        optimizer.step(&mut parameters, &gradients, 1.0)?;
        assert!(optimizer.state.osgr.is_empty());
        // After the warmup window, OSGR preconditioning kicks in.
        for _ in 0..6 {
            optimizer.step(&mut parameters, &gradients, 1.0)?;
        }
        assert!(!optimizer.state.osgr.is_empty());
        Ok(())
    }
}