use crate::{
common::{BiasCorrection, OptimizerState, ParameterUpdate, StateMemoryStats},
traits::StatefulOptimizer,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::{errors::Result, tensor::Tensor, traits::Optimizer};
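/// Configuration for the NovoGrad optimizer (Ginsburg et al., 2019).
///
/// NovoGrad keeps a per-parameter first-moment buffer but only a single
/// second-moment scalar per layer, derived from the layer's gradient norm,
/// which is what gives it a smaller state footprint than Adam. The
/// `memory_factor`, `adaptive_weight_decay`, and `layer_wise_adaptation`
/// knobs appear to be extensions of this implementation rather than part
/// of the original algorithm.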
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NovoGradConfig {
pub learning_rate: f32,
pub beta1: f32,
pub beta2: f32,
pub epsilon: f32,
pub weight_decay: f32,
pub grad_clipping: Option<f32>,
pub bias_correction: bool,
pub adaptive_weight_decay: bool,
pub memory_factor: f32,
pub layer_wise_adaptation: bool,
}
impl Default for NovoGradConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
beta1: 0.95,
beta2: 0.98,
epsilon: 1e-8,
weight_decay: 0.0,
grad_clipping: Some(1.0),
bias_correction: true,
adaptive_weight_decay: true,
memory_factor: 0.8,
layer_wise_adaptation: true,
}
}
}
impl NovoGradConfig {
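/// Preset tuned for large language models: slower second-moment decay
/// and non-zero weight decay.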
pub fn for_large_language_models() -> Self {
Self {
learning_rate: 1e-3,
beta1: 0.95,
beta2: 0.999,
epsilon: 1e-6,
weight_decay: 1e-2,
grad_clipping: Some(1.0),
bias_correction: true,
adaptive_weight_decay: true,
memory_factor: 0.9,
layer_wise_adaptation: true,
}
}
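/// Preset tuned for vision models: Adam-like betas, looser gradient
/// clipping, and no layer-wise adaptation.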
pub fn for_vision_models() -> Self {
Self {
learning_rate: 1e-3,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
weight_decay: 1e-4,
grad_clipping: Some(2.0),
bias_correction: true,
adaptive_weight_decay: false,
memory_factor: 0.7,
layer_wise_adaptation: false,
}
}
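/// Preset that disables bias correction and layer-wise bookkeeping to
/// minimize optimizer state.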
pub fn for_memory_constrained() -> Self {
Self {
learning_rate: 1e-3,
beta1: 0.95,
beta2: 0.98,
epsilon: 1e-8,
weight_decay: 0.0,
grad_clipping: Some(1.0),
bias_correction: false,
adaptive_weight_decay: false,
memory_factor: 1.0,
layer_wise_adaptation: false,
}
}
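/// Preset for ill-conditioned scientific workloads: small learning rate,
/// tight gradient clipping, and a very small epsilon.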
pub fn for_scientific_computing() -> Self {
Self {
learning_rate: 1e-4,
beta1: 0.99,
beta2: 0.999,
epsilon: 1e-10,
weight_decay: 1e-6,
grad_clipping: Some(0.5),
bias_correction: true,
adaptive_weight_decay: true,
memory_factor: 0.8,
layer_wise_adaptation: true,
}
}
}
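/// NovoGrad optimizer with per-layer second moments and optional
/// layer-wise learning-rate adaptation.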
#[derive(Debug)]
pub struct NovoGrad {
config: NovoGradConfig,
state: OptimizerState,
layer_second_moments: HashMap<String, f32>,
layer_grad_norms: HashMap<String, f32>,
layer_lr_factors: HashMap<String, f32>,
current_step: usize,
total_parameters: usize,
}
impl NovoGrad {
pub fn new(config: NovoGradConfig) -> Self {
Self {
config,
state: OptimizerState::new(),
layer_second_moments: HashMap::new(),
layer_grad_norms: HashMap::new(),
layer_lr_factors: HashMap::new(),
current_step: 0,
total_parameters: 0,
}
}
pub fn for_large_language_models() -> Self {
Self::new(NovoGradConfig::for_large_language_models())
}
pub fn for_vision_models() -> Self {
Self::new(NovoGradConfig::for_vision_models())
}
pub fn for_memory_constrained() -> Self {
Self::new(NovoGradConfig::for_memory_constrained())
}
pub fn for_scientific_computing() -> Self {
Self::new(NovoGradConfig::for_scientific_computing())
}
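/// L2 norm of a layer's flattened gradient.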
fn compute_layer_grad_norm(&self, gradient: &[f32]) -> f32 {
let grad_norm_squared: f32 = gradient.iter().map(|g| g * g).sum();
grad_norm_squared.sqrt()
}
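/// Scales the base learning rate per layer by smoothing how much the
/// layer's gradient norm grew or shrank relative to the previous step.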
fn compute_adaptive_lr(&mut self, layer_id: &str, grad_norm: f32) -> f32 {
if !self.config.layer_wise_adaptation {
return self.config.learning_rate;
}
let base_lr = self.config.learning_rate;
let prev_norm = self.layer_grad_norms.get(layer_id).copied().unwrap_or(1.0);
let norm_ratio = if prev_norm > 1e-8 { grad_norm / prev_norm } else { 1.0 };
let adaptation_factor = if norm_ratio > 1.2 {
0.8
} else if norm_ratio < 0.8 {
1.1
} else {
1.0
};
let current_factor = self.layer_lr_factors.get(layer_id).copied().unwrap_or(1.0);
let new_factor = 0.9 * current_factor + 0.1 * adaptation_factor;
self.layer_lr_factors.insert(layer_id.to_string(), new_factor);
base_lr * new_factor
}
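/// Shrinks weight decay for larger layers (roughly by 1/sqrt(size)),
/// floored at 10% of the configured value.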
fn compute_adaptive_weight_decay(&self, layer_size: usize) -> f32 {
if !self.config.adaptive_weight_decay {
return self.config.weight_decay;
}
let size_factor = (layer_size as f32).sqrt();
let adapted_wd = self.config.weight_decay / (1.0 + size_factor * 0.001);
adapted_wd.max(self.config.weight_decay * 0.1)
}
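/// Estimates state-memory savings versus Adam, which stores two f32
/// moments per parameter; NovoGrad stores one momentum buffer per
/// parameter tensor plus one second-moment scalar per layer.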
pub fn memory_efficiency(&self) -> MemoryEfficiencyStats {
let traditional_adam_memory = self.total_parameters * 2 * std::mem::size_of::<f32>();
let novograd_memory = self.state.momentum.values().map(|v| v.len()).sum::<usize>()
* std::mem::size_of::<f32>()
+ self.layer_second_moments.len() * std::mem::size_of::<f32>();
let memory_savings = if traditional_adam_memory > 0 {
1.0 - (novograd_memory as f32) / (traditional_adam_memory as f32)
} else {
0.0
};
MemoryEfficiencyStats {
traditional_adam_memory_bytes: traditional_adam_memory,
novograd_memory_bytes: novograd_memory,
memory_savings_ratio: memory_savings,
layer_count: self.layer_second_moments.len(),
average_layer_size: if !self.layer_second_moments.is_empty() {
self.total_parameters / self.layer_second_moments.len()
} else {
0
},
}
}
pub fn learning_rate(&self) -> f32 {
self.config.learning_rate
}
pub fn set_learning_rate(&mut self, lr: f32) {
self.config.learning_rate = lr;
}
}
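/// Snapshot of optimizer-state memory relative to a traditional Adam
/// baseline.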
#[derive(Debug, Clone)]
pub struct MemoryEfficiencyStats {
pub traditional_adam_memory_bytes: usize,
pub novograd_memory_bytes: usize,
pub memory_savings_ratio: f32,
pub layer_count: usize,
pub average_layer_size: usize,
}
impl Optimizer for NovoGrad {
fn update(&mut self, _parameter: &mut Tensor, _gradient: &Tensor) -> Result<()> {
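// No-op placeholder: per-layer state updates happen in `step_batch`;
// applying updates to individual parameter tensors is not implemented.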
Ok(())
}
fn step(&mut self) {
self.current_step += 1;
self.state.step();
}
fn zero_grad(&mut self) {
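// No gradient buffers are stored on the optimizer, so there is nothing
// to clear here.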
}
fn get_lr(&self) -> f32 {
self.config.learning_rate
}
fn set_lr(&mut self, lr: f32) {
self.config.learning_rate = lr;
}
}
impl NovoGrad {
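/// Updates optimizer state (momenta, per-layer second moments, and
/// learning-rate factors) from a named batch of gradients. This pass
/// receives no parameter tensors, so the computed update is not applied
/// here.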
pub fn step_batch(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
self.current_step += 1;
for (param_name, gradient) in gradients.iter() {
let grad_data = gradient.data()?;
if grad_data.is_empty() {
continue;
}
let param_size = grad_data.len();
// Count each parameter tensor once, the first time it is seen.
if !self.state.momentum.contains_key(param_name) {
self.total_parameters += param_size;
}
let mut clipped_grad = grad_data.clone();
if let Some(clip_value) = self.config.grad_clipping {
let grad_norm = self.compute_layer_grad_norm(&clipped_grad);
if grad_norm > clip_value {
let scale = clip_value / grad_norm;
for g in clipped_grad.iter_mut() {
*g *= scale;
}
}
}
let grad_norm = self.compute_layer_grad_norm(&clipped_grad);
self.layer_grad_norms.insert(param_name.clone(), grad_norm);
let prev_layer_v = self.layer_second_moments.get(param_name).copied().unwrap_or(0.0);
let layer_v = self.config.beta2 * prev_layer_v
+ (1.0 - self.config.beta2) * grad_norm * grad_norm;
let (bias_correction1, bias_correction2) = if self.config.bias_correction {
BiasCorrection::compute_adam_corrections(
self.config.beta1,
self.config.beta2,
self.current_step,
)
} else {
(1.0, 1.0)
};
// Compute the per-layer factors before borrowing the momentum buffer so
// the EMA update below can run in place without cloning.
let adaptive_lr = self.compute_adaptive_lr(param_name, grad_norm);
let v_hat = layer_v / bias_correction2;
let layer_lr_scale = adaptive_lr / (v_hat.sqrt() + self.config.epsilon);
// Decoupled weight decay (adaptive_wd * w[i]) would need the parameter
// values, which this gradient-only pass does not receive; the term is
// therefore left out of the update below.
let _adaptive_wd = self.compute_adaptive_weight_decay(param_size);
let beta1 = self.config.beta1;
let memory_factor = self.config.memory_factor;
let momentum = self.state.get_or_create_momentum(param_name.clone(), param_size);
for i in 0..param_size {
ParameterUpdate::update_ema(&mut momentum[i], clipped_grad[i], beta1);
}
// The per-element update; `step_batch` has no parameter tensors to apply
// it to, so it is computed and discarded.
for i in 0..param_size {
let m_hat = momentum[i] / bias_correction1;
let _update = layer_lr_scale * (m_hat + memory_factor * clipped_grad[i]);
}
self.layer_second_moments.insert(param_name.clone(), layer_v);
self.layer_second_moments.insert(param_name.clone(), layer_v);
}
Ok(())
}
}
impl StatefulOptimizer for NovoGrad {
type Config = NovoGradConfig;
type State = OptimizerState;
fn config(&self) -> &Self::Config {
&self.config
}
fn state(&self) -> &Self::State {
&self.state
}
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}
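/// Serializes the step count, momentum buffers, per-layer second moments,
/// and learning-rate factors into named tensors.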
fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
let mut state = HashMap::new();
state.insert(
"step".to_string(),
Tensor::new(vec![self.current_step as f32])?,
);
for (name, momentum) in &self.state.momentum {
let shape = vec![momentum.len()];
state.insert(
format!("momentum_{}", name),
Tensor::from_vec(momentum.clone(), &shape)?,
);
}
for (name, v) in &self.layer_second_moments {
state.insert(format!("layer_v_{}", name), Tensor::new(vec![*v])?);
}
for (name, factor) in &self.layer_lr_factors {
state.insert(format!("lr_factor_{}", name), Tensor::new(vec![*factor])?);
}
Ok(state)
}
fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
if let Some(step_tensor) = state.get("step") {
if let Ok(step_data) = step_tensor.data() {
if !step_data.is_empty() {
self.current_step = step_data[0] as usize;
self.state.step = self.current_step;
}
}
}
for (key, tensor) in &state {
if let Some(name) = key.strip_prefix("momentum_") {
if let Ok(data) = tensor.data() {
self.state.momentum.insert(name.to_string(), data);
}
} else if let Some(name) = key.strip_prefix("layer_v_") {
if let Ok(data) = tensor.data() {
if !data.is_empty() {
self.layer_second_moments.insert(name.to_string(), data[0]);
}
}
} else if let Some(name) = key.strip_prefix("lr_factor_") {
if let Ok(data) = tensor.data() {
if !data.is_empty() {
self.layer_lr_factors.insert(name.to_string(), data[0]);
}
}
}
}
Ok(())
}
fn memory_usage(&self) -> StateMemoryStats {
let momentum_elements: usize = self.state.momentum.values().map(|v| v.len()).sum();
let layer_elements = self.layer_second_moments.len() + self.layer_lr_factors.len();
StateMemoryStats {
momentum_elements,
variance_elements: 0,
// Per-layer scalars (second moments and lr factors) are reported in the
// third-moment slot for lack of a dedicated field.
third_moment_elements: layer_elements,
total_bytes: momentum_elements * std::mem::size_of::<f32>()
+ layer_elements * std::mem::size_of::<f32>(),
// Number of tracked parameter tensors, not individual elements.
num_parameters: self.state.momentum.len(),
}
}
fn reset_state(&mut self) {
self.state.clear();
self.layer_second_moments.clear();
self.layer_grad_norms.clear();
self.layer_lr_factors.clear();
self.current_step = 0;
self.total_parameters = 0;
}
fn num_parameters(&self) -> usize {
self.state.momentum.len()
}
}
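/// Aggregate diagnostics over all layers seen so far.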
#[derive(Debug, Clone)]
pub struct NovoGradStats {
pub current_step: usize,
pub total_parameters: usize,
pub layer_count: usize,
pub average_grad_norm: f32,
pub max_grad_norm: f32,
pub min_grad_norm: f32,
pub memory_efficiency: MemoryEfficiencyStats,
pub adaptive_lr_range: (f32, f32),
}
impl NovoGrad {
pub fn reset(&mut self) {
self.reset_state();
}
pub fn get_stats(&self) -> NovoGradStats {
let grad_norms: Vec<f32> = self.layer_grad_norms.values().copied().collect();
let lr_factors: Vec<f32> = self.layer_lr_factors.values().copied().collect();
let avg_grad_norm = if !grad_norms.is_empty() {
grad_norms.iter().sum::<f32>() / grad_norms.len() as f32
} else {
0.0
};
let (min_grad_norm, max_grad_norm) = if !grad_norms.is_empty() {
let min = grad_norms.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max = grad_norms.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
(min, max)
} else {
(0.0, 0.0)
};
let adaptive_lr_range = if !lr_factors.is_empty() {
let min_factor = lr_factors.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_factor = lr_factors.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
(
self.config.learning_rate * min_factor,
self.config.learning_rate * max_factor,
)
} else {
(self.config.learning_rate, self.config.learning_rate)
};
NovoGradStats {
current_step: self.current_step,
total_parameters: self.total_parameters,
layer_count: self.layer_second_moments.len(),
average_grad_norm: avg_grad_norm,
max_grad_norm,
min_grad_norm,
memory_efficiency: self.memory_efficiency(),
adaptive_lr_range,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_novograd_creation() {
let optimizer = NovoGrad::new(NovoGradConfig::default());
assert_eq!(optimizer.learning_rate(), 1e-3);
assert_eq!(optimizer.config.beta1, 0.95);
assert_eq!(optimizer.config.beta2, 0.98);
}
#[test]
fn test_novograd_presets() {
let llm_opt = NovoGrad::for_large_language_models();
assert_eq!(llm_opt.config.beta2, 0.999);
assert_eq!(llm_opt.config.memory_factor, 0.9);
let vision_opt = NovoGrad::for_vision_models();
assert_eq!(vision_opt.config.beta1, 0.9);
assert!(!vision_opt.config.layer_wise_adaptation);
let memory_opt = NovoGrad::for_memory_constrained();
assert_eq!(memory_opt.config.memory_factor, 1.0);
assert!(!memory_opt.config.bias_correction);
let sci_opt = NovoGrad::for_scientific_computing();
assert_eq!(sci_opt.config.learning_rate, 1e-4);
assert_eq!(sci_opt.config.epsilon, 1e-10);
}
#[test]
fn test_layer_grad_norm_computation() {
let optimizer = NovoGrad::new(NovoGradConfig::default());
let gradient = vec![3.0, 4.0];
let norm = optimizer.compute_layer_grad_norm(&gradient);
assert!((norm - 5.0).abs() < 1e-6);
}
#[test]
fn test_adaptive_weight_decay() {
let optimizer = NovoGrad::new(NovoGradConfig {
adaptive_weight_decay: true,
weight_decay: 1e-4,
..Default::default()
});
let small_layer_wd = optimizer.compute_adaptive_weight_decay(100);
let large_layer_wd = optimizer.compute_adaptive_weight_decay(10000);
assert!(large_layer_wd < small_layer_wd);
assert!(large_layer_wd >= 1e-5);
}
#[test]
fn test_learning_rate_getter_setter() {
let mut optimizer = NovoGrad::new(NovoGradConfig::default());
assert_eq!(optimizer.learning_rate(), 1e-3);
optimizer.set_learning_rate(2e-3);
assert_eq!(optimizer.learning_rate(), 2e-3);
}
#[test]
fn test_memory_efficiency_tracking() {
let optimizer = NovoGrad::new(NovoGradConfig::default());
let efficiency = optimizer.memory_efficiency();
assert_eq!(efficiency.layer_count, 0);
assert_eq!(efficiency.average_layer_size, 0);
assert_eq!(efficiency.novograd_memory_bytes, 0);
}
#[test]
fn test_memory_usage_tracking() {
let optimizer = NovoGrad::new(NovoGradConfig::default());
let memory_stats = optimizer.memory_usage();
assert_eq!(memory_stats.momentum_elements, 0);
assert_eq!(memory_stats.variance_elements, 0);
assert_eq!(memory_stats.num_parameters, 0);
}
#[test]
fn test_stats_generation() {
let optimizer = NovoGrad::new(NovoGradConfig::default());
let stats = optimizer.get_stats();
assert_eq!(stats.current_step, 0);
assert_eq!(stats.total_parameters, 0);
assert_eq!(stats.layer_count, 0);
assert_eq!(stats.average_grad_norm, 0.0);
}
#[test]
fn test_reset_functionality() {
let mut optimizer = NovoGrad::new(NovoGradConfig::default());
optimizer.current_step = 100;
optimizer.layer_second_moments.insert("test".to_string(), 0.5);
optimizer.reset();
assert_eq!(optimizer.current_step, 0);
assert!(optimizer.layer_second_moments.is_empty());
}
#[test]
fn test_state_dict_operations() {
let optimizer = NovoGrad::new(NovoGradConfig::default());
let state_dict = optimizer.state_dict();
assert!(state_dict.is_ok());
let state = state_dict.expect("Operation failed in test");
assert!(state.contains_key("step"));
}
#[test]
fn test_config_serialization() {
let config = NovoGradConfig::for_large_language_models();
let serialized = serde_json::to_string(&config);
assert!(serialized.is_ok());
let deserialized: std::result::Result<NovoGradConfig, _> =
serde_json::from_str(&serialized.expect("Serialization failed"));
assert!(deserialized.is_ok());
assert_eq!(deserialized.expect("Operation failed in test").beta2, 0.999);
}
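// Exercises the state-update path of `step_batch`. `Tensor::from_vec` is
// used here as in `state_dict` above; the asserted fields are exactly the
// state that `step_batch` maintains.
#[test]
fn test_step_batch_updates_state() {
let mut optimizer = NovoGrad::new(NovoGradConfig::default());
let shape = vec![3];
let mut gradients = HashMap::new();
gradients.insert(
"layer1".to_string(),
Tensor::from_vec(vec![0.1, -0.2, 0.3], &shape).expect("tensor creation failed"),
);
optimizer.step_batch(&gradients).expect("step_batch failed");
assert_eq!(optimizer.current_step, 1);
assert_eq!(optimizer.num_parameters(), 1);
assert!(optimizer.layer_second_moments.contains_key("layer1"));
assert!(optimizer.layer_grad_norms.contains_key("layer1"));
}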
}