use anyhow::Result;
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;
#[derive(Debug, Clone)]
pub struct GENIEConfig {
pub learning_rate: f32,
pub osgr_momentum: f32,
pub alignment_weight: f32,
pub preconditioning_eps: f32,
pub min_osgr: f32,
pub max_osgr: f32,
pub adaptive_alignment: bool,
pub weight_decay: f32,
pub normalize_osgr: bool,
pub warmup_steps: u64,
}
impl Default for GENIEConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
osgr_momentum: 0.9,
alignment_weight: 0.1,
preconditioning_eps: 1e-8,
min_osgr: 1e-6,
max_osgr: 1e6,
adaptive_alignment: true,
weight_decay: 0.0,
normalize_osgr: true,
warmup_steps: 100,
}
}
}
impl GENIEConfig {
pub fn new() -> Self {
Self::default()
}
pub fn learning_rate(mut self, lr: f32) -> Self {
self.learning_rate = lr;
self
}
pub fn osgr_momentum(mut self, momentum: f32) -> Self {
self.osgr_momentum = momentum;
self
}
pub fn alignment_weight(mut self, weight: f32) -> Self {
self.alignment_weight = weight;
self
}
pub fn preconditioning_eps(mut self, eps: f32) -> Self {
self.preconditioning_eps = eps;
self
}
pub fn weight_decay(mut self, decay: f32) -> Self {
self.weight_decay = decay;
self
}
pub fn build(self) -> Self {
self
}
}
#[derive(Debug, Clone, Default)]
pub struct GENIEState {
pub step: u64,
pub momentum_buffers: HashMap<String, Vec<f32>>,
}
#[derive(Debug, Clone, Default)]
pub struct DomainStats {
pub domain_losses: Vec<f32>,
pub domain_variance: f32,
pub cross_domain_alignment: f32,
}
pub struct GENIE {
config: GENIEConfig,
state: GENIEState,
}
impl GENIE {
pub fn new(config: GENIEConfig) -> Self {
Self {
config,
state: GENIEState::default(),
}
}
pub fn learning_rate(&self) -> f32 {
self.config.learning_rate
}
pub fn set_learning_rate(&mut self, lr: f32) {
self.config.learning_rate = lr;
}
pub fn step(
&mut self,
parameters: &mut HashMap<String, Tensor>,
gradients: &HashMap<String, Tensor>,
_current_loss: f32,
) -> Result<()> {
self.state.step += 1;
for (param_name, gradient) in gradients.iter() {
if let Some(parameter) = parameters.get_mut(param_name) {
let param_data = parameter.data()?;
let grad_data = gradient.data()?;
if !self.state.momentum_buffers.contains_key(param_name) {
self.state
.momentum_buffers
.insert(param_name.clone(), vec![0.0; param_data.len()]);
}
let momentum_buffer = self
.state
.momentum_buffers
.get_mut(param_name)
.expect("GENIE momentum buffer must exist after initialization");
let mut updated_params = param_data.clone();
for i in 0..param_data.len() {
momentum_buffer[i] = self.config.osgr_momentum * momentum_buffer[i]
+ (1.0 - self.config.osgr_momentum) * grad_data[i];
let effective_grad = if self.config.weight_decay > 0.0 {
grad_data[i] + self.config.weight_decay * param_data[i]
} else {
grad_data[i]
};
updated_params[i] = param_data[i]
- self.config.learning_rate
* (momentum_buffer[i] + self.config.alignment_weight * effective_grad);
}
*parameter = Tensor::new(updated_params)?;
}
}
Ok(())
}
pub fn get_osgr_stats(&self) -> HashMap<String, f32> {
HashMap::new()
}
pub fn get_alignment_stats(&self) -> HashMap<String, f32> {
HashMap::new()
}
pub fn get_domain_stats(&self) -> DomainStats {
DomainStats {
domain_losses: Vec::new(),
domain_variance: 0.0,
cross_domain_alignment: 0.0,
}
}
pub fn reset_state(&mut self) {
self.state = GENIEState::default();
}
pub fn get_stats(&self) -> GENIEStats {
GENIEStats {
step: self.state.step,
mean_osgr: 1.0,
mean_alignment: 0.0,
num_parameters: self.state.momentum_buffers.len(),
adaptive_alignment_weight: self.config.alignment_weight,
domain_variance: 0.0,
}
}
}
#[derive(Debug, Clone)]
pub struct GENIEStats {
pub step: u64,
pub mean_osgr: f32,
pub mean_alignment: f32,
pub num_parameters: usize,
pub adaptive_alignment_weight: f32,
pub domain_variance: f32,
}