use scirs2_core::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD, Axis, IxDyn};
use std::collections::HashMap;
use trustformers_core::traits::Optimizer;
use crate::optimizer::OptimizerState;
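/// AdaFactor (Shazeer & Stern, 2018): an Adam-family optimizer with sublinear
/// memory cost. For 2-D parameters the second-moment estimate is stored as
/// factored per-row and per-column means instead of a full matrix, and the
/// step size can be derived from the training step (`relative_step_size`)
/// rather than taken as a fixed constant.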
#[derive(Debug, Clone)]
pub struct AdaFactor {
pub learning_rate: f32,
    pub beta1: Option<f32>,
    pub beta2: f32,
pub epsilon: f32,
pub clip_threshold: f32,
pub decay_rate_sqrt: f32,
pub beta1_decay: bool,
pub scale_parameter: bool,
pub relative_step_size: bool,
pub warmup_init: bool,
state: HashMap<String, AdaFactorParamState>,
step: usize,
}
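/// Per-parameter state: a factored second moment (`exp_avg_sq_row`,
/// `exp_avg_sq_col`) for 2-D parameters, a full second moment (`exp_avg_sq`)
/// otherwise, and an optional first moment (`exp_avg`) when beta1 is set.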
#[derive(Debug, Clone)]
struct AdaFactorParamState {
step: usize,
    exp_avg: Option<ArrayD<f32>>,
    exp_avg_sq_row: Option<ArrayD<f32>>,
    exp_avg_sq_col: Option<ArrayD<f32>>,
    exp_avg_sq: Option<ArrayD<f32>>,
}
impl AdaFactor {
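    /// Creates an AdaFactor optimizer with the paper's defaults: no momentum,
    /// relative step sizes, and parameter scaling. `learning_rate` is used
    /// directly only when `relative_step_size` is disabled.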
pub fn new(learning_rate: f32) -> Self {
Self {
learning_rate,
            beta1: None,
            // Exponent of the beta2 schedule: beta2_t = 1 - (t + 1)^beta2.
            beta2: -0.8,
epsilon: 1e-30,
clip_threshold: 1.0,
decay_rate_sqrt: 0.8,
beta1_decay: true,
scale_parameter: true,
relative_step_size: true,
warmup_init: false,
state: HashMap::new(),
step: 0,
}
}
pub fn with_beta1(mut self, beta1: f32) -> Self {
self.beta1 = Some(beta1);
self
}
pub fn with_beta2(mut self, beta2: f32) -> Self {
self.beta2 = beta2;
self
}
pub fn with_epsilon(mut self, epsilon: f32) -> Self {
self.epsilon = epsilon;
self
}
pub fn with_clip_threshold(mut self, clip_threshold: f32) -> Self {
self.clip_threshold = clip_threshold;
self
}
pub fn with_relative_step_size(mut self, relative_step_size: bool) -> Self {
self.relative_step_size = relative_step_size;
self
}
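    /// Effective learning rate. With `relative_step_size` this follows the
    /// paper's schedule rho_t = min(10^-2, 1/sqrt(t)), with a linearly growing
    /// cap under `warmup_init`; otherwise the fixed `learning_rate` is used.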
    fn get_lr(&self) -> f32 {
        if self.relative_step_size {
            // A constant 1e-6 cap would freeze the step size forever; under
            // warmup the cap grows linearly with the step count instead.
            let min_step = if self.warmup_init {
                1e-6 * (self.step as f32 + 1.0)
            } else {
                1e-2
            };
            let rel_step_sz = (self.step as f32 + 1.0).powf(-0.5).min(min_step);
            if self.scale_parameter {
                // Simplification: the paper scales by max(eps2, RMS(param));
                // here the relative step is scaled by sqrt(learning_rate).
                rel_step_sz * self.learning_rate.sqrt()
            } else {
                rel_step_sz
            }
        } else {
            self.learning_rate
        }
    }
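    /// Momentum coefficient. With `beta1_decay` it ramps from 0 toward the
    /// configured value as training progresses, mirroring the beta2 schedule.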
fn get_beta1(&self) -> f32 {
if let Some(beta1) = self.beta1 {
if self.beta1_decay {
                // `step + 1` avoids 0^(-c) = inf on the first step.
                beta1 * (1.0 - (self.step as f32 + 1.0).powf(-self.decay_rate_sqrt))
} else {
beta1
}
} else {
0.0
}
}
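    /// Time-dependent second-moment decay from the paper:
    /// beta2_t = 1 - (t + 1)^beta2, which approaches 1 as training proceeds.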
fn get_beta2(&self) -> f32 {
1.0 - (self.step as f32 + 1.0).powf(self.beta2)
}
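    /// Whether the second moment for a parameter of this shape is kept in
    /// factored (row/column) form rather than as a full tensor.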
    fn should_use_factored_second_moment(shape: &[usize]) -> bool {
        // This implementation factors exactly 2-D parameters; higher-rank
        // tensors fall back to the full second moment.
        shape.len() == 2
    }
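    /// Reconstructs the rank-1 approximation of the squared-gradient matrix
    /// from its row and column moment vectors: V_ij ~= r_i * c_j / mean(r),
    /// which preserves the stored row and column means.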
    fn approximate_sq_grad(
        exp_avg_sq_row: &ArrayD<f32>,
        exp_avg_sq_col: &ArrayD<f32>,
    ) -> ArrayD<f32> {
        // Associated fn (no `&self`): it needs no optimizer state, and this
        // avoids borrowing `self` while per-parameter state is mutably
        // borrowed inside `step`.
        let rows = exp_avg_sq_row.len();
        let cols = exp_avg_sq_col.len();
        // Outer product of the row and column moment vectors; the result must
        // be 2-D [rows, cols], not the 1-D shape of the row vector.
        let mut result = ArrayD::zeros(IxDyn(&[rows, cols]));
        for i in 0..rows {
            for j in 0..cols {
                result[[i, j]] = exp_avg_sq_row[[i]] * exp_avg_sq_col[[j]];
            }
        }
        // Normalize by the grand mean of the squared gradients
        // (mean(r) == mean(c), so their geometric mean equals either one).
        let r_factor = exp_avg_sq_row.sum() / rows as f32;
        let c_factor = exp_avg_sq_col.sum() / cols as f32;
        let norm_factor = (r_factor * c_factor).sqrt();
        if norm_factor > 0.0 {
            result /= norm_factor;
        }
        result
    }
}
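// NOTE: this impl assumes the `Optimizer` trait passes parameters as mutable
// views (`ArrayViewMutD`), so the computed update can be written back in place.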
impl Optimizer for AdaFactor {
    fn step(
        &mut self,
        params: &mut HashMap<String, ArrayViewMutD<f32>>,
        gradients: &HashMap<String, ArrayViewD<f32>>,
    ) -> Result<(), Box<dyn std::error::Error>> {
self.step += 1;
let lr = self.get_lr();
let beta1 = self.get_beta1();
let beta2 = self.get_beta2();
for (name, grad) in gradients {
if let Some(param) = params.get_mut(name) {
let grad_shape = grad.shape().to_vec();
if !self.state.contains_key(name) {
let factored = Self::should_use_factored_second_moment(&grad_shape);
let state = AdaFactorParamState {
step: 0,
exp_avg: if self.beta1.is_some() {
Some(ArrayD::zeros(grad.raw_dim()))
} else {
None
},
                        exp_avg_sq_row: if factored {
                            Some(ArrayD::zeros(IxDyn(&[grad_shape[0]])))
                        } else {
                            None
                        },
                        exp_avg_sq_col: if factored {
                            Some(ArrayD::zeros(IxDyn(&[grad_shape[1]])))
                        } else {
                            None
                        },
exp_avg_sq: if !factored {
Some(ArrayD::zeros(grad.raw_dim()))
} else {
None
},
};
self.state.insert(name.clone(), state);
}
let state = self.state.get_mut(name)
.expect("state must exist after insert");
state.step += 1;
                // Clip the gradient by its global L2 norm. (Simplification:
                // the paper clips the *update* by its RMS instead.)
                let grad_norm = grad.iter().map(|x| x * x).sum::<f32>().sqrt();
                let clip_coeff = (self.clip_threshold / grad_norm.max(1e-8)).min(1.0);
                let clipped_grad = grad.mapv(|x| x * clip_coeff);
if let Some(ref mut exp_avg) = state.exp_avg {
*exp_avg = exp_avg.mapv(|x| x * beta1) + clipped_grad.mapv(|x| x * (1.0 - beta1));
}
                if let (Some(exp_avg_sq_row), Some(exp_avg_sq_col)) =
                    (state.exp_avg_sq_row.as_mut(), state.exp_avg_sq_col.as_mut())
                {
                    // Factored path (2-D only): EMA of per-row and per-column
                    // means of the squared gradient.
                    let grad_sq_row = clipped_grad.map_axis(Axis(1), |row| {
                        row.iter().map(|x| x * x).sum::<f32>() / row.len() as f32
                    });
                    let grad_sq_col = clipped_grad.map_axis(Axis(0), |col| {
                        col.iter().map(|x| x * x).sum::<f32>() / col.len() as f32
                    });
                    *exp_avg_sq_row = exp_avg_sq_row.mapv(|x| x * beta2) + grad_sq_row.mapv(|x| x * (1.0 - beta2));
                    *exp_avg_sq_col = exp_avg_sq_col.mapv(|x| x * beta2) + grad_sq_col.mapv(|x| x * (1.0 - beta2));
                } else if let Some(exp_avg_sq) = state.exp_avg_sq.as_mut() {
                    // Unfactored fallback: full elementwise second-moment EMA.
                    *exp_avg_sq = exp_avg_sq.mapv(|x| x * beta2) + clipped_grad.mapv(|x| x * x * (1.0 - beta2));
                }
                // Preconditioned update m / (sqrt(v) + eps), with v taken from
                // the factored approximation or the full second moment.
                let epsilon = self.epsilon;
                let update = if let (Some(exp_avg_sq_row), Some(exp_avg_sq_col)) =
                    (state.exp_avg_sq_row.as_ref(), state.exp_avg_sq_col.as_ref())
                {
                    let exp_avg_sq_approx = Self::approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col);
                    let denominator = exp_avg_sq_approx.mapv(|x| x.sqrt() + epsilon);
                    if let Some(exp_avg) = state.exp_avg.as_ref() {
                        exp_avg / &denominator
                    } else {
                        clipped_grad / denominator
                    }
                } else if let Some(exp_avg_sq) = state.exp_avg_sq.as_ref() {
                    let denominator = exp_avg_sq.mapv(|x| x.sqrt() + epsilon);
                    if let Some(exp_avg) = state.exp_avg.as_ref() {
                        exp_avg / &denominator
                    } else {
                        clipped_grad / denominator
                    }
                } else {
                    return Err("Invalid AdaFactor state".into());
                };
                // Apply the update in place: p <- p - lr * u.
                param.zip_mut_with(&update, |p, u| *p -= lr * u);
}
}
Ok(())
}
fn get_state(&self) -> OptimizerState {
OptimizerState {
step: self.step,
learning_rate: self.learning_rate,
}
}
fn set_learning_rate(&mut self, lr: f32) {
self.learning_rate = lr;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_adafactor_creation() {
let optimizer = AdaFactor::new(0.001);
assert_eq!(optimizer.learning_rate, 0.001);
assert_eq!(optimizer.beta1, None);
        assert!(optimizer.beta2 < 0.0);
    }
#[test]
fn test_adafactor_with_beta1() {
let optimizer = AdaFactor::new(0.001).with_beta1(0.9);
assert_eq!(optimizer.beta1, Some(0.9));
}
#[test]
fn test_factored_second_moment_check() {
        assert!(!AdaFactor::should_use_factored_second_moment(&[100]));
        assert!(AdaFactor::should_use_factored_second_moment(&[10, 20]));
        // Higher-rank tensors are not factored in this implementation.
        assert!(!AdaFactor::should_use_factored_second_moment(&[5, 10, 15]));
    }
#[test]
fn test_learning_rate_scaling() {
let mut optimizer = AdaFactor::new(1.0)
.with_relative_step_size(true);
optimizer.step = 100;
let lr = optimizer.get_lr();
assert!(lr > 0.0);
        assert!(lr < 1.0);
    }
    #[test]
    fn test_beta1_schedule() {
        let mut optimizer = AdaFactor::new(0.001)
            .with_beta1(0.9);
        optimizer.step = 0;
        let beta1_0 = optimizer.get_beta1();
        optimizer.step = 100;
        let beta1_100 = optimizer.get_beta1();
        // beta1 ramps up from 0 toward its nominal value over training.
        assert!(beta1_100 > beta1_0);
        assert!(beta1_100 < 0.9);
    }
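    // Smoke test: one end-to-end `step` call on a small 2-D parameter,
    // assuming the in-place, mutable-view `step` signature used above.
    // The values and the "w" name are illustrative.
    #[test]
    fn test_step_updates_parameters() {
        let mut optimizer = AdaFactor::new(0.001);
        let mut param = ArrayD::from_elem(IxDyn(&[2, 3]), 1.0f32);
        let before = param.clone();
        let grad = ArrayD::from_elem(IxDyn(&[2, 3]), 0.5f32);
        let mut params = HashMap::new();
        params.insert("w".to_string(), param.view_mut());
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), grad.view());
        optimizer.step(&mut params, &grads).unwrap();
        drop(params);
        // A uniformly positive gradient must move every element downward.
        assert!(param.iter().zip(before.iter()).all(|(p, b)| p < b));
    }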
}