use std::collections::HashMap;
use crate::Optimizer;
use crate::common::zeros_entry;
/// Hyperparameters and per-parameter state for the MARS optimizer: an AdamW-style
/// update driven by a variance-corrected gradient instead of the raw gradient.
///
/// Every trainable tensor is identified by a string name. The `HashMap` fields hold
/// one `f32` vector per name. They start empty and are filled through `zeros_entry`
/// (presumably zero-initialised on first use — confirm in `crate::common`).
#[derive(Clone, Debug)]
pub struct Mars {
    /// Learning rate; multiplies the whole update, including the weight-decay term.
    pub lr: f32,
    /// Decay rate of the first-moment average. Also sets the correction weight
    /// `gamma * beta1 / (1 - beta1)`, so it must stay below 1.
    pub beta1: f32,
    /// Decay rate of the second-moment (squared corrected gradient) average.
    pub beta2: f32,
    /// Added to the square root of the bias-corrected second moment in the denominator.
    pub eps: f32,
    /// Decoupled weight-decay coefficient, applied to the parameter's pre-update value.
    pub weight_decay: f32,
    /// Strength of the gradient-difference (variance-reduction) correction.
    pub gamma: f32,
    /// When true, a corrected gradient whose L2 norm exceeds 1 is rescaled to unit norm.
    pub clip_c: bool,
    /// Number of `end_iteration` calls so far; bias correction uses `step + 1`.
    step: u64,
    /// First-moment estimates, keyed by parameter name.
    m: HashMap<String, Vec<f32>>,
    /// Second-moment estimates, keyed by parameter name.
    v: HashMap<String, Vec<f32>>,
    /// Gradient received for each name on the previous update.
    prev_g: HashMap<String, Vec<f32>>,
    /// Per-name buffer holding the corrected gradient while an update runs.
    scratch: HashMap<String, Vec<f32>>,
}
impl Mars {
    /// Creates a MARS optimizer with learning rate `lr` and these defaults:
    /// `beta1 = 0.95`, `beta2 = 0.99`, `eps = 1e-8`, `weight_decay = 0.0`,
    /// `gamma = 0.025`, and corrected-gradient clipping enabled.
    ///
    /// All per-parameter state starts empty, and the step counter starts at zero.
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.95,
            beta2: 0.99,
            eps: 1e-8,
            weight_decay: 0.0,
            gamma: 0.025,
            clip_c: true,
            step: 0,
            m: HashMap::new(),
            v: HashMap::new(),
            prev_g: HashMap::new(),
            scratch: HashMap::new(),
        }
    }

    /// Sets the first-moment (`b1`) and second-moment (`b2`) decay rates.
    ///
    /// `b1` also weights the gradient-difference correction as
    /// `gamma * b1 / (1 - b1)`, so it should stay strictly below 1.
    /// At 1 that weight is infinite and the update produces non-finite values.
    #[must_use]
    pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
        self.beta1 = b1;
        self.beta2 = b2;
        self
    }

    /// Sets the decoupled weight-decay coefficient.
    /// It is applied to the parameter value, not folded into the gradient.
    #[must_use]
    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    /// Sets `gamma`, the strength of the variance-reduction term
    /// `gamma * beta1 / (1 - beta1) * (g - g_prev)` added to each gradient.
    /// `0.0` disables the correction.
    #[must_use]
    pub fn with_gamma(mut self, gamma: f32) -> Self {
        self.gamma = gamma;
        self
    }

    /// Sets `eps`, the constant added to the square root of the bias-corrected
    /// second moment before dividing.
    #[must_use]
    pub fn with_eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Enables or disables rescaling the corrected gradient to unit L2 norm
    /// whenever its norm exceeds 1.
    #[must_use]
    pub fn with_clip_c(mut self, clip_c: bool) -> Self {
        self.clip_c = clip_c;
        self
    }
}
impl Optimizer for Mars {
    /// Applies one MARS update to `param` in place, using `grad` and the state kept
    /// under `name`.
    ///
    /// 1. Corrected gradient: `c = g + gamma * beta1 / (1 - beta1) * (g - g_prev)`,
    ///    where `g_prev` is the gradient stored for `name` by the previous call.
    /// 2. If `clip_c` is set and `||c||_2 > 1`, `c` is rescaled to unit L2 norm.
    /// 3. Moments: `m = beta1 * m + (1 - beta1) * c` and
    ///    `v = beta2 * v + (1 - beta2) * c^2`, bias-corrected with `t = self.step + 1`.
    /// 4. `param -= lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * param)`.
    ///
    /// Arithmetic runs in `f64`; stored state and `param` stay `f32`. `_shape` is
    /// unused. This method does not advance `self.step`; `end_iteration` does.
    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        debug_assert_eq!(param.len(), grad.len());
        let len = param.len();

        // Bias-correction denominators for this (1-based) iteration.
        let t = (self.step + 1) as f64;
        let beta1 = f64::from(self.beta1);
        let beta2 = f64::from(self.beta2);
        let bias1 = 1.0 - beta1.powf(t);
        let bias2 = 1.0 - beta2.powf(t);
        // Weight on (g - g_prev); grows without bound as beta1 approaches 1.
        let mix = f64::from(self.gamma) * beta1 / (1.0 - beta1);
        let eps = f64::from(self.eps);
        let lr = f64::from(self.lr);
        let decay = f64::from(self.weight_decay);

        let prev_grad = zeros_entry(&mut self.prev_g, name, len);
        let corrected = zeros_entry(&mut self.scratch, name, len);
        let exp_avg = zeros_entry(&mut self.m, name, len);
        let exp_avg_sq = zeros_entry(&mut self.v, name, len);

        // Pass 1: build the corrected gradient, accumulate its squared norm, and
        // remember the raw gradient for the next call.
        let mut norm_sq = 0.0f64;
        let pass1 = corrected[..len]
            .iter_mut()
            .zip(prev_grad[..len].iter_mut())
            .zip(&grad[..len]);
        for ((slot, last), &g32) in pass1 {
            let g = f64::from(g32);
            let c = g + mix * (g - f64::from(*last));
            *slot = c as f32;
            norm_sq += c * c;
            *last = g32;
        }

        // Testing the squared norm against 1 is equivalent to testing the norm itself.
        if self.clip_c && norm_sq > 1.0 {
            let inv_norm = (1.0 / norm_sq.sqrt()) as f32;
            corrected.iter_mut().for_each(|c| *c *= inv_norm);
        }

        // Pass 2: fold the (possibly clipped) correction into the moments, then step.
        let pass2 = param
            .iter_mut()
            .zip(&corrected[..len])
            .zip(exp_avg[..len].iter_mut())
            .zip(exp_avg_sq[..len].iter_mut());
        for (((p, &c32), m), v) in pass2 {
            let c = f64::from(c32);
            let m_next = beta1 * f64::from(*m) + (1.0 - beta1) * c;
            let v_next = beta2 * f64::from(*v) + (1.0 - beta2) * c * c;
            *m = m_next as f32;
            *v = v_next as f32;
            // Bias correction uses the unrounded f64 moments, not the stored f32 copies.
            let adaptive = (m_next / bias1) / ((v_next / bias2).sqrt() + eps);
            let p_old = f64::from(*p);
            *p = (p_old - lr * (adaptive + decay * p_old)) as f32;
        }
    }

    /// Advances the step counter shared by all parameters.
    /// The next `step` call uses it for bias correction.
    fn end_iteration(&mut self) {
        self.step += 1;
    }
}