use std::collections::HashMap;
use crate::Optimizer;
use crate::common::zeros_entry;
/// Optimizer state and hyper-parameters for a Sophia-style update rule.
///
/// Per-parameter state is stored in maps keyed by the parameter name passed to
/// `step` / `update_hessian`.
#[derive(Clone, Debug)]
pub struct Sophia {
    /// Step size multiplying both the clipped update and the weight-decay term.
    pub lr: f32,
    /// EMA coefficient for the gradient momentum `m`.
    pub beta1: f32,
    /// EMA coefficient for the curvature estimate `h`.
    pub beta2: f32,
    /// Scale applied to `h` in the update denominator.
    pub gamma: f32,
    /// Magnitude bound used to clip each element of the preconditioned update.
    pub rho: f32,
    /// Lower bound on the update denominator.
    pub eps: f32,
    /// Decoupled weight-decay coefficient (scaled by `lr` when applied).
    pub weight_decay: f32,
    /// Incremented once per `end_iteration` call.
    step: u64,
    /// Per-parameter gradient momentum, keyed by parameter name.
    m: HashMap<String, Vec<f32>>,
    /// Per-parameter EMA of supplied curvature estimates (presumably diagonal
    /// Hessian estimates — confirm against callers of `update_hessian`).
    h: HashMap<String, Vec<f32>>,
}
impl Sophia {
    /// Creates an optimizer with learning rate `lr` and the built-in defaults
    /// for every other hyper-parameter; no per-parameter state yet.
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.965,
            beta2: 0.99,
            gamma: 0.01,
            rho: 0.04,
            eps: 1e-12,
            weight_decay: 0.1,
            step: 0,
            m: HashMap::default(),
            h: HashMap::default(),
        }
    }

    /// Builder: replaces the momentum (`b1`) and curvature (`b2`) EMA coefficients.
    pub fn with_betas(self, b1: f32, b2: f32) -> Self {
        Self {
            beta1: b1,
            beta2: b2,
            ..self
        }
    }

    /// Builder: replaces the decoupled weight-decay coefficient.
    pub fn with_weight_decay(self, wd: f32) -> Self {
        Self {
            weight_decay: wd,
            ..self
        }
    }

    /// Folds a fresh curvature estimate `h_hat` into the running EMA for `name`:
    /// `h <- beta2 * h + (1 - beta2) * h_hat`, element-wise.
    ///
    /// The entry is obtained via `zeros_entry` with length `h_hat.len()`; the loop
    /// runs over the stored entry's length and indexes `h_hat` directly, so it
    /// panics if the stored entry is longer than `h_hat`.
    pub fn update_hessian(&mut self, name: &str, h_hat: &[f32]) {
        let ema = zeros_entry(&mut self.h, name, h_hat.len());
        let decay = self.beta2;
        for (idx, slot) in ema.iter_mut().enumerate() {
            *slot = decay * *slot + (1.0 - decay) * h_hat[idx];
        }
    }
}
impl Optimizer for Sophia {
    /// Applies one Sophia update to `param` in place, using `grad` for parameter `name`.
    ///
    /// For each element `i`:
    /// 1. `m[i] <- beta1 * m[i] + (1 - beta1) * grad[i]`
    /// 2. `u = clip(m[i] / max(gamma' * h[i], eps), [-rho, rho])`, with
    ///    `gamma' = max(gamma, eps)`
    /// 3. `param[i] -= lr * (u + weight_decay * param[i])`
    ///
    /// If no curvature estimate has been recorded for `name` yet, `h[i]` is
    /// treated as `0.0`. The denominator then falls back to `eps`, so the step is
    /// effectively the clip bound applied in the direction of `m[i]`.
    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        debug_assert_eq!(param.len(), grad.len());
        let b1 = self.beta1;
        let gamma = self.gamma.max(self.eps);
        let rho = self.rho;
        let eps = self.eps;
        let lr = self.lr;
        let wd = self.weight_decay;

        let m = zeros_entry(&mut self.m, name, param.len());
        for i in 0..param.len() {
            m[i] = b1 * m[i] + (1.0 - b1) * grad[i];
        }

        // Borrow the stored estimate if there is one, instead of allocating a
        // zero-filled buffer of `param.len()` on every call just to stand in
        // for a missing entry.
        let hessian = self.h.get(name);
        for i in 0..param.len() {
            let curvature = hessian.map_or(0.0, |est| est[i]);
            let denom = (gamma * curvature).max(eps);
            let raw = m[i] / denom;
            // Manual clip (not `f32::clamp`) so a negative `rho` cannot panic and
            // NaN propagates unchanged.
            let u = if raw > rho {
                rho
            } else if raw < -rho {
                -rho
            } else {
                raw
            };
            // Decoupled weight decay uses the pre-update parameter value.
            param[i] -= lr * (u + wd * param[i]);
        }
    }

    /// Marks the end of an optimizer iteration by advancing the step counter.
    fn end_iteration(&mut self) {
        self.step += 1;
    }
}