use std::collections::HashMap;
use crate::Optimizer;
use crate::common::{l2_norm, zeros_entry};
/// Adafactor optimizer: hyperparameters plus per-parameter running statistics.
///
/// Statistics are keyed by the parameter name passed to `step`. Parameters
/// whose shape has exactly two dimensions use the factored `r`/`c`
/// accumulators; every other shape uses the per-element accumulator `v`.
#[derive(Debug, Clone)]
pub struct Adafactor {
    /// Fixed learning rate. When `None`, `step` derives the rate from the
    /// step count and the RMS of the parameter being updated.
    pub lr: Option<f32>,
    /// Exponent of the decay schedule: `beta2_t = 1 - t^beta2_decay`.
    pub beta2_decay: f32,
    /// Added to every squared gradient before accumulation; its square root
    /// is also the lower bound on the denominator of the update.
    pub eps1: f32,
    /// Lower bound on the parameter RMS used when `lr` is `None`.
    pub eps2: f32,
    /// Threshold applied to the RMS of the update when it is clipped.
    pub clip_threshold: f32,
    /// Weight-decay coefficient; `weight_decay * param` is added to the
    /// update before it is scaled by the learning rate.
    pub weight_decay: f32,
    /// Number of finished iterations; advanced by `end_iteration`.
    step: u64,
    /// Row accumulators for 2-D parameters, keyed by parameter name.
    r: HashMap<String, Vec<f32>>,
    /// Column accumulators for 2-D parameters, keyed by parameter name.
    c: HashMap<String, Vec<f32>>,
    /// Per-element accumulators for non-2-D parameters, keyed by parameter name.
    v: HashMap<String, Vec<f32>>,
}
impl Adafactor {
    /// Builds an optimizer with the default hyperparameters and no
    /// accumulated state.
    ///
    /// `lr` starts as `None`, so the learning rate is derived inside `step`
    /// until [`Adafactor::with_lr`] supplies a fixed one.
    pub fn new() -> Self {
        Self {
            // Hyperparameters.
            lr: None,
            weight_decay: 0.0,
            clip_threshold: 1.0,
            beta2_decay: -0.8,
            eps1: 1e-30,
            eps2: 1e-3,
            // Running state.
            step: 0,
            r: HashMap::default(),
            c: HashMap::default(),
            v: HashMap::default(),
        }
    }

    /// Returns the optimizer with a fixed learning rate of `lr`.
    pub fn with_lr(self, lr: f32) -> Self {
        Self {
            lr: Some(lr),
            ..self
        }
    }

    /// Returns the optimizer with its weight-decay coefficient set to `wd`.
    pub fn with_weight_decay(self, wd: f32) -> Self {
        Self {
            weight_decay: wd,
            ..self
        }
    }
}
impl Default for Adafactor {
fn default() -> Self {
Self::new()
}
}
impl Optimizer for Adafactor {
    /// Applies one Adafactor update to `param` in place, driven by `grad`.
    ///
    /// `name` keys this parameter's running second-moment statistics. When
    /// `shape` has exactly two dimensions (`[rows, cols]`, row-major), the
    /// statistics are factored into one accumulator per row and one per
    /// column; any other shape keeps one accumulator per element.
    ///
    /// The time index is `self.step + 1`. It only advances in
    /// `end_iteration`, so every parameter updated within the same iteration
    /// sees the same decay rate and relative step size.
    ///
    /// # Panics
    ///
    /// Debug builds assert that `param` and `grad` have equal lengths and, for
    /// 2-D shapes, that `rows * cols == param.len()`. In any build, indexing
    /// panics if `grad` is shorter than the indices implied by `shape` or
    /// `param`.
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
        debug_assert_eq!(param.len(), grad.len());

        let t = (self.step + 1) as f64;
        // beta2_t = 1 - t^beta2_decay. With the default negative exponent this
        // is 0 at t = 1, so the first step uses the current gradient alone.
        let beta2_t = 1.0 - t.powf(f64::from(self.beta2_decay));
        let eps1 = f64::from(self.eps1);
        let clip = f64::from(self.clip_threshold);
        // Lower bound on sqrt(v) so the division below never sees zero.
        let rms_floor = eps1.sqrt();
        let n = param.len();
        let mut update = vec![0.0f32; n];

        if shape.len() == 2 {
            let (rows, cols) = (shape[0], shape[1]);
            debug_assert_eq!(rows * cols, n);

            // Running mean of (g^2 + eps1) along each row.
            let r = zeros_entry(&mut self.r, name, rows);
            for i in 0..rows {
                let mut s = 0.0f64;
                for j in 0..cols {
                    let g = f64::from(grad[i * cols + j]);
                    s += g * g + eps1;
                }
                let row_mean = s / cols as f64;
                r[i] = (beta2_t * f64::from(r[i]) + (1.0 - beta2_t) * row_mean) as f32;
            }

            // Running mean of (g^2 + eps1) along each column. `self.r` and
            // `self.c` are distinct fields, so both borrows can stay live and
            // no snapshot copy of `r` is required.
            let c = zeros_entry(&mut self.c, name, cols);
            for j in 0..cols {
                let mut s = 0.0f64;
                for i in 0..rows {
                    let g = f64::from(grad[i * cols + j]);
                    s += g * g + eps1;
                }
                let col_mean = s / rows as f64;
                c[j] = (beta2_t * f64::from(c[j]) + (1.0 - beta2_t) * col_mean) as f32;
            }

            // Rank-1 reconstruction of the second moment:
            //     v_ij = r_i * c_j / mean(r).
            // `r` and `c` hold means rather than sums, so the normaliser has to
            // be the mean of `r`; dividing by its sum would make every v_ij
            // `rows` times too small.
            let r_mean = r.iter().map(|&x| f64::from(x)).sum::<f64>() / rows as f64;
            let normaliser = r_mean.max(eps1);
            for i in 0..rows {
                let r_i = f64::from(r[i]);
                for j in 0..cols {
                    let v_ij = r_i * f64::from(c[j]) / normaliser;
                    let g = f64::from(grad[i * cols + j]);
                    update[i * cols + j] = (g / v_ij.sqrt().max(rms_floor)) as f32;
                }
            }
        } else {
            // Unfactored path: one running second moment per element.
            let v = zeros_entry(&mut self.v, name, n);
            for i in 0..n {
                let g = f64::from(grad[i]);
                v[i] = (beta2_t * f64::from(v[i]) + (1.0 - beta2_t) * (g * g + eps1)) as f32;
                update[i] = (g / f64::from(v[i]).sqrt().max(rms_floor)) as f32;
            }
        }

        // Update clipping: divide by max(1, RMS(update) / clip), which leaves
        // the update untouched while its RMS is at most `clip` and otherwise
        // rescales it so the RMS equals `clip`.
        let u_rms = l2_norm(&update) as f64 / (n as f64).sqrt();
        let scale = 1.0 / (u_rms / clip).max(1.0);
        for u in update.iter_mut() {
            *u = (f64::from(*u) * scale) as f32;
        }

        let lr = match self.lr {
            Some(fixed) => f64::from(fixed),
            None => {
                // Relative step size: min(1e-2, 1/sqrt(t)) scaled by the
                // parameter's RMS, which is bounded below by eps2.
                let p_rms =
                    (l2_norm(param) as f64 / (n as f64).sqrt()).max(f64::from(self.eps2));
                (1.0 / t.sqrt()).min(1e-2) * p_rms
            }
        };

        // Weight decay is folded into the update, so it is scaled by `lr` too.
        let wd = f64::from(self.weight_decay);
        for i in 0..n {
            let p = f64::from(param[i]);
            param[i] = (p - lr * (f64::from(update[i]) + wd * p)) as f32;
        }
    }

    /// Marks the end of an iteration by advancing the step counter that
    /// `step` reads as its time index.
    fn end_iteration(&mut self) {
        self.step += 1;
    }
}