use crate::common::{OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::Result;
use trustformers_core::{tensor::Tensor, traits::Optimizer};
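/// Configuration for the `AdaMaxPlus` optimizer: an AdaMax-style update (infinity-norm
/// second moment) extended with optional adaptive momentum, linear learning-rate warmup,
/// gradient-variance tracking, and gradient outlier clamping.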
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaMaxPlusConfig {
pub learning_rate: f32,
pub betas: (f32, f32),
pub epsilon: f32,
pub weight_decay: f32,
pub adaptive_momentum: bool,
pub momentum_adaptation_strength: f32,
pub warmup_steps: usize,
pub variance_tracking: bool,
pub bias_correction_factor: f32,
pub outlier_threshold: f32,
}
impl Default for AdaMaxPlusConfig {
fn default() -> Self {
Self {
learning_rate: 0.001,
betas: (0.9, 0.999),
epsilon: 1e-8,
weight_decay: 0.0,
adaptive_momentum: true,
momentum_adaptation_strength: 0.1,
warmup_steps: 0,
variance_tracking: true,
bias_correction_factor: 1.0,
outlier_threshold: 10.0,
}
}
}
impl AdaMaxPlusConfig {
pub fn new() -> Self {
Self::default()
}
pub fn learning_rate(mut self, lr: f32) -> Self {
self.learning_rate = lr;
self
}
pub fn betas(mut self, betas: (f32, f32)) -> Self {
self.betas = betas;
self
}
pub fn epsilon(mut self, eps: f32) -> Self {
self.epsilon = eps;
self
}
pub fn weight_decay(mut self, wd: f32) -> Self {
self.weight_decay = wd;
self
}
pub fn enable_adaptive_momentum(mut self, enabled: bool) -> Self {
self.adaptive_momentum = enabled;
self
}
pub fn momentum_adaptation_strength(mut self, strength: f32) -> Self {
self.momentum_adaptation_strength = strength;
self
}
pub fn warmup_steps(mut self, steps: usize) -> Self {
self.warmup_steps = steps;
self
}
pub fn variance_tracking(mut self, enabled: bool) -> Self {
self.variance_tracking = enabled;
self
}
pub fn bias_correction_factor(mut self, factor: f32) -> Self {
self.bias_correction_factor = factor;
self
}
pub fn outlier_threshold(mut self, threshold: f32) -> Self {
self.outlier_threshold = threshold;
self
}
}
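/// Per-parameter state snapshot: first-moment vector, infinity norm, gradient-variance
/// estimate, per-parameter step count, and optional gradient EMA buffers.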
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaMaxPlusParameterState {
pub momentum: Vec<f32>,
pub inf_norm: f32,
pub gradient_variance: f32,
pub step_count: usize,
pub grad_ema: Option<Vec<f32>>,
pub grad_sq_ema: Option<Vec<f32>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaMaxPlusState {
pub state: OptimizerState,
pub config: AdaMaxPlusConfig,
pub step_count: usize,
pub inf_norms: HashMap<String, f32>,
pub gradient_variances: HashMap<String, f32>,
pub param_step_counts: HashMap<String, usize>,
}
impl AdaMaxPlusState {
pub fn new(config: AdaMaxPlusConfig) -> Self {
Self {
state: OptimizerState::new(),
config,
step_count: 0,
inf_norms: HashMap::new(),
gradient_variances: HashMap::new(),
param_step_counts: HashMap::new(),
}
}
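/// Approximate state size in bytes (4 bytes per `f32` entry, 8 bytes per `usize` step count).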
pub fn memory_usage(&self) -> usize {
let momentum_size = self.state.momentum.values().map(|v| v.len() * 4).sum::<usize>();
let variance_size = self.state.variance.values().map(|v| v.len() * 4).sum::<usize>();
let inf_norms_size = self.inf_norms.len() * 4;
let gradient_variances_size = self.gradient_variances.len() * 4;
let param_step_counts_size = self.param_step_counts.len() * 8;
momentum_size
+ variance_size
+ inf_norms_size
+ gradient_variances_size
+ param_step_counts_size
}
}
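/// AdaMax-style optimizer with adaptive momentum, linear warmup, and gradient-variance tracking.
///
/// Minimal usage sketch (mirrors the tests below; marked `ignore` because the exact re-export
/// paths for `Tensor` and `Optimizer` are assumed rather than known here):
///
/// ```ignore
/// let mut optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.01);
/// let mut param = Tensor::ones(&[2, 2])?;
/// let grad = Tensor::new(vec![0.1, 0.2, 0.3, 0.4])?;
/// optimizer.update(&mut param, &grad)?;
/// ```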
pub struct AdaMaxPlus {
state: AdaMaxPlusState,
}
impl AdaMaxPlus {
pub fn new(learning_rate: f32, betas: (f32, f32), epsilon: f32, weight_decay: f32) -> Self {
let config = AdaMaxPlusConfig {
learning_rate,
betas,
epsilon,
weight_decay,
..Default::default()
};
Self {
state: AdaMaxPlusState::new(config),
}
}
pub fn from_config(config: AdaMaxPlusConfig) -> Self {
Self {
state: AdaMaxPlusState::new(config),
}
}
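/// Preset for large-model training: lr 2e-4, 10,000 warmup steps, adaptive momentum,
/// variance tracking, and weight decay 0.1.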
pub fn for_large_models() -> Self {
let config = AdaMaxPlusConfig::new()
.learning_rate(0.0002)
.betas((0.9, 0.999))
.enable_adaptive_momentum(true)
.warmup_steps(10000)
.variance_tracking(true)
.weight_decay(0.1);
Self::from_config(config)
}
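/// Preset for fast training: lr 3e-3, beta1 0.95, stronger momentum adaptation (0.2),
/// and a short 500-step warmup.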
pub fn for_fast_training() -> Self {
let config = AdaMaxPlusConfig::new()
.learning_rate(0.003)
.betas((0.95, 0.999))
.enable_adaptive_momentum(true)
.momentum_adaptation_strength(0.2)
.warmup_steps(500);
Self::from_config(config)
}
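/// Preset for stable training: adaptive momentum and variance tracking disabled,
/// stronger bias correction (1.2), and a tighter outlier threshold (5.0).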
pub fn for_stable_training() -> Self {
let config = AdaMaxPlusConfig::new()
.learning_rate(0.001)
.betas((0.9, 0.999))
.enable_adaptive_momentum(false)
.variance_tracking(false)
.bias_correction_factor(1.2)
.outlier_threshold(5.0);
Self::from_config(config)
}
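/// Scales beta1 down when the tracked gradient variance is high:
/// `beta1_adaptive = beta1 * (1 - strength * min(variance, 1))`, clamped to [0.1, 0.99].
/// Falls back to the configured beta1 when adaptive momentum is disabled.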
fn compute_adaptive_momentum(&self, param_id: String) -> f32 {
if !self.state.config.adaptive_momentum {
return self.state.config.betas.0;
}
let base_beta1 = self.state.config.betas.0;
let adaptation_strength = self.state.config.momentum_adaptation_strength;
let variance_factor = if self.state.config.variance_tracking {
self.state.gradient_variances.get(&param_id).copied().unwrap_or(0.0).min(1.0)
} else {
0.0
};
let adaptive_beta1 = base_beta1 * (1.0 - adaptation_strength * variance_factor);
adaptive_beta1.clamp(0.1, 0.99)
}
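/// Linear warmup: returns `lr * step / warmup_steps` until `warmup_steps` is reached,
/// then the base learning rate.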
fn compute_effective_learning_rate(&self) -> f32 {
let base_lr = self.state.config.learning_rate;
if self.state.config.warmup_steps == 0 {
return base_lr;
}
let warmup_factor = if self.state.step_count <= self.state.config.warmup_steps {
self.state.step_count as f32 / self.state.config.warmup_steps as f32
} else {
1.0
};
base_lr * warmup_factor
}
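/// Tracks a per-parameter gradient-variance estimate via exponential moving averages:
/// `Var[g] ≈ E[g^2] - E[g]^2`, averaged over the parameter's elements. The EMAs are stored
/// under `{param_id}_grad_ema` / `{param_id}_grad_sq_ema` in the shared optimizer state.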
fn update_gradient_variance(&mut self, param_id: String, gradient: &Tensor) -> Result<()> {
if !self.state.config.variance_tracking {
return Ok(());
}
let beta1 = self.state.config.betas.0;
let beta2 = self.state.config.betas.1;
let gradient_data = gradient.data()?;
let param_size = gradient_data.len();
let grad_ema = self
.state
.state
.get_or_create_momentum(format!("{}_grad_ema", param_id), param_size)
.clone();
let grad_sq_ema = self
.state
.state
.get_or_create_variance(format!("{}_grad_sq_ema", param_id), param_size)
.clone();
let updated_grad_ema: Vec<f32> = grad_ema
.iter()
.zip(gradient_data.iter())
.map(|(&m, &g)| beta1 * m + (1.0 - beta1) * g)
.collect();
let updated_grad_sq_ema: Vec<f32> = grad_sq_ema
.iter()
.zip(gradient_data.iter())
.map(|(&v, &g)| beta2 * v + (1.0 - beta2) * g * g)
.collect();
let variance: f32 = updated_grad_sq_ema
.iter()
.zip(updated_grad_ema.iter())
.map(|(&sq_ema, &ema)| sq_ema - ema * ema)
.sum::<f32>()
/ param_size as f32;
self.state
.state
.momentum
.insert(format!("{}_grad_ema", param_id), updated_grad_ema);
self.state
.state
.variance
.insert(format!("{}_grad_sq_ema", param_id), updated_grad_sq_ema);
self.state.gradient_variances.insert(param_id, variance);
Ok(())
}
}
impl Optimizer for AdaMaxPlus {
fn step(&mut self) {
// No-op: parameter updates are applied per tensor through `update`.
}
fn zero_grad(&mut self) {
// No-op: gradients are passed directly to `update` and are not stored by the optimizer.
}
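/// Applies one AdaMax-style update to `parameter`: adds `weight_decay * p` to the gradient
/// (when enabled), updates the first moment with the adaptive beta1, tracks the infinity norm
/// `u = max(beta2 * u, min(|g|_inf, outlier_threshold))`, and steps with
/// `p -= lr_effective / (u + eps) * m_hat`, where `m_hat` is the bias-corrected momentum.
/// State is keyed by the address of the parameter's data buffer.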
fn update(&mut self, parameter: &mut Tensor, gradient: &Tensor) -> Result<()> {
let param_data = parameter.data()?;
let param_id = format!("{:p}", param_data.as_ptr());
let param_size = param_data.len();
self.state.step_count += 1;
let momentum_data = {
let momentum_buffer =
self.state.state.get_or_create_momentum(param_id.clone(), param_size);
momentum_buffer.clone()
};
if self.state.config.variance_tracking {
self.update_gradient_variance(param_id.clone(), gradient)?;
}
let effective_gradient = if self.state.config.weight_decay > 0.0 {
gradient.add(&parameter.mul_scalar(self.state.config.weight_decay)?)?
} else {
gradient.clone()
};
let adaptive_beta1 = self.compute_adaptive_momentum(param_id.clone());
let beta2 = self.state.config.betas.1;
let gradient_data = effective_gradient.data()?;
let updated_momentum: Vec<f32> = momentum_data
.iter()
.zip(gradient_data.iter())
.map(|(&m, &g)| adaptive_beta1 * m + (1.0 - adaptive_beta1) * g)
.collect();
let grad_inf_norm = gradient_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
let clamped_grad_norm = grad_inf_norm.min(self.state.config.outlier_threshold);
let current_inf_norm = self.state.inf_norms.get(&param_id).copied().unwrap_or(0.0);
let new_inf_norm = (beta2 * current_inf_norm).max(clamped_grad_norm);
self.state.inf_norms.insert(param_id.clone(), new_inf_norm);
let step_count = self.state.param_step_counts.entry(param_id.clone()).or_insert(0);
*step_count += 1;
let bias_correction = 1.0 - adaptive_beta1.powi(*step_count as i32);
let bias_corrected_momentum: Vec<f32> = updated_momentum
.iter()
.map(|&m| m / (bias_correction * self.state.config.bias_correction_factor))
.collect();
let effective_lr = self.compute_effective_learning_rate();
let step_size = effective_lr / (new_inf_norm + self.state.config.epsilon);
let param_data = parameter.data()?;
let updated_params: Vec<f32> = param_data
.iter()
.zip(bias_corrected_momentum.iter())
.map(|(&p, &m)| p - step_size * m)
.collect();
*parameter = Tensor::new(updated_params)?;
self.state.state.momentum.insert(param_id, updated_momentum);
Ok(())
}
fn set_lr(&mut self, lr: f32) {
self.state.config.learning_rate = lr;
}
fn get_lr(&self) -> f32 {
self.state.config.learning_rate
}
}
impl StatefulOptimizer for AdaMaxPlus {
type Config = AdaMaxPlusConfig;
type State = AdaMaxPlusState;
fn config(&self) -> &Self::Config {
&self.state.config
}
fn state(&self) -> &Self::State {
&self.state
}
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}
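/// Serializes all optimizer state into tensors, keyed by parameter id with the suffixes
/// `_momentum`, `_variance`, `_inf_norm`, `_gradient_variance`, and `_step_count`, plus a
/// global `step_count` entry. `load_state_dict` reverses this naming scheme.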
fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
let mut state_dict = HashMap::new();
for (key, buffer) in &self.state.state.momentum {
let tensor = Tensor::new(buffer.clone())?;
state_dict.insert(format!("{}_momentum", key), tensor);
}
for (key, buffer) in &self.state.state.variance {
let tensor = Tensor::new(buffer.clone())?;
state_dict.insert(format!("{}_variance", key), tensor);
}
for (key, &inf_norm) in &self.state.inf_norms {
let tensor = Tensor::new(vec![inf_norm])?;
state_dict.insert(format!("{}_inf_norm", key), tensor);
}
for (key, &variance) in &self.state.gradient_variances {
let tensor = Tensor::new(vec![variance])?;
state_dict.insert(format!("{}_gradient_variance", key), tensor);
}
for (key, &step_count) in &self.state.param_step_counts {
let tensor = Tensor::new(vec![step_count as f32])?;
state_dict.insert(format!("{}_step_count", key), tensor);
}
let step_tensor = Tensor::new(vec![self.state.step_count as f32])?;
state_dict.insert("step_count".to_string(), step_tensor);
Ok(state_dict)
}
fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
for (key, tensor) in state_dict {
let data = tensor.data()?;
if key == "step_count" {
if let Some(&count) = data.first() {
self.state.step_count = count as usize;
}
} else if let Some(param_id) = key.strip_suffix("_momentum") {
self.state.state.momentum.insert(param_id.to_string(), data.clone());
} else if let Some(param_id) = key.strip_suffix("_variance") {
self.state.state.variance.insert(param_id.to_string(), data.clone());
} else if let Some(param_id) = key.strip_suffix("_inf_norm") {
if let Some(&inf_norm) = data.first() {
self.state.inf_norms.insert(param_id.to_string(), inf_norm);
}
} else if let Some(param_id) = key.strip_suffix("_gradient_variance") {
if let Some(&variance) = data.first() {
self.state.gradient_variances.insert(param_id.to_string(), variance);
}
} else if let Some(param_id) = key.strip_suffix("_step_count") {
if let Some(&step_count) = data.first() {
self.state.param_step_counts.insert(param_id.to_string(), step_count as usize);
}
}
}
Ok(())
}
fn memory_usage(&self) -> StateMemoryStats {
StateMemoryStats {
momentum_elements: self.state.state.momentum.values().map(|v| v.len()).sum::<usize>(),
variance_elements: self.state.state.variance.values().map(|v| v.len()).sum::<usize>(),
third_moment_elements: 0,
total_bytes: self.state.memory_usage(),
num_parameters: self.state.state.momentum.len(),
}
}
fn reset_state(&mut self) {
self.state.state.clear();
self.state.step_count = 0;
self.state.inf_norms.clear();
self.state.gradient_variances.clear();
self.state.param_step_counts.clear();
}
fn num_parameters(&self) -> usize {
self.state.state.momentum.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
use trustformers_core::tensor::Tensor;
#[test]
fn test_adamax_plus_creation() {
let optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.01);
assert_eq!(optimizer.get_lr(), 0.001);
assert_eq!(optimizer.state.config.betas, (0.9, 0.999));
assert_eq!(optimizer.state.config.epsilon, 1e-8);
assert_eq!(optimizer.state.config.weight_decay, 0.01);
}
#[test]
fn test_adamax_plus_config() {
let config = AdaMaxPlusConfig::new()
.learning_rate(0.002)
.betas((0.95, 0.999))
.enable_adaptive_momentum(true)
.warmup_steps(1000);
let optimizer = AdaMaxPlus::from_config(config);
assert_eq!(optimizer.get_lr(), 0.002);
assert_eq!(optimizer.state.config.betas, (0.95, 0.999));
assert!(optimizer.state.config.adaptive_momentum);
assert_eq!(optimizer.state.config.warmup_steps, 1000);
}
#[test]
fn test_adamax_plus_presets() {
let llm_optimizer = AdaMaxPlus::for_large_models();
assert_eq!(llm_optimizer.get_lr(), 0.0002);
assert_eq!(llm_optimizer.state.config.warmup_steps, 10000);
assert!(llm_optimizer.state.config.adaptive_momentum);
let fast_optimizer = AdaMaxPlus::for_fast_training();
assert_eq!(fast_optimizer.get_lr(), 0.003);
assert_eq!(
fast_optimizer.state.config.momentum_adaptation_strength,
0.2
);
let stable_optimizer = AdaMaxPlus::for_stable_training();
assert!(!stable_optimizer.state.config.adaptive_momentum);
assert!(!stable_optimizer.state.config.variance_tracking);
}
#[test]
fn test_adamax_plus_step() -> Result<()> {
let mut optimizer = AdaMaxPlus::new(0.01, (0.9, 0.999), 1e-8, 0.0);
let mut param = Tensor::ones(&[2, 2])?;
let grad = Tensor::new(vec![0.1, 0.2, 0.3, 0.4])?;
let original_data = param.data()?.clone();
optimizer.update(&mut param, &grad)?;
let param_data = param.data()?;
assert!(param_data.iter().zip(original_data.iter()).all(|(&new, &orig)| new != orig));
Ok(())
}
#[test]
fn test_warmup_learning_rate() {
let mut optimizer =
AdaMaxPlus::from_config(AdaMaxPlusConfig::new().learning_rate(0.001).warmup_steps(100));
assert_eq!(optimizer.compute_effective_learning_rate(), 0.0);
optimizer.state.step_count = 50;
assert!((optimizer.compute_effective_learning_rate() - 0.0005).abs() < 1e-9);
optimizer.state.step_count = 100;
assert!((optimizer.compute_effective_learning_rate() - 0.001).abs() < 1e-9);
optimizer.state.step_count = 200;
assert!((optimizer.compute_effective_learning_rate() - 0.001).abs() < 1e-9);
}
#[test]
fn test_adaptive_momentum() {
let optimizer = AdaMaxPlus::from_config(
AdaMaxPlusConfig::new()
.enable_adaptive_momentum(true)
.momentum_adaptation_strength(0.2),
);
let param_id = "test_param".to_string();
let mut test_optimizer = optimizer;
test_optimizer.state.gradient_variances.insert(param_id.clone(), 0.1);
let adaptive_beta1 = test_optimizer.compute_adaptive_momentum(param_id.clone());
assert!(adaptive_beta1 > 0.85);
test_optimizer.state.gradient_variances.insert(param_id.clone(), 0.8);
let adaptive_beta1_high = test_optimizer.compute_adaptive_momentum(param_id);
assert!(adaptive_beta1_high < adaptive_beta1);
}
#[test]
fn test_state_dict_save_load() -> Result<()> {
let mut optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.01);
let mut param = Tensor::ones(&[2])?;
let grad = Tensor::new(vec![0.1, 0.2])?;
optimizer.update(&mut param, &grad)?;
let state_dict = optimizer.state_dict()?;
assert!(!state_dict.is_empty());
let mut new_optimizer = AdaMaxPlus::new(0.002, (0.8, 0.99), 1e-7, 0.02);
new_optimizer.load_state_dict(state_dict)?;
assert_eq!(new_optimizer.get_lr(), 0.002);
assert_eq!(new_optimizer.state.config.betas, (0.8, 0.99));
assert!(new_optimizer.state.step_count > 0);
Ok(())
}
#[test]
fn test_zero_grad() -> Result<()> {
let mut optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.0);
optimizer.zero_grad();
assert_eq!(optimizer.get_lr(), 0.001);
Ok(())
}
#[test]
fn test_memory_usage_tracking() {
let optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.0);
let memory_usage = optimizer.memory_usage();
assert_eq!(memory_usage.total_bytes, 0);
}
#[test]
fn test_lr_get_set() {
let mut optimizer = AdaMaxPlus::new(0.001, (0.9, 0.999), 1e-8, 0.0);
assert_eq!(optimizer.get_lr(), 0.001);
optimizer.set_lr(0.002);
assert_eq!(optimizer.get_lr(), 0.002);
}
}