use std::collections::HashMap;
use crate::Optimizer;
use crate::common::{jacobi_eigh_sym, matmul, zeros_entry};
/// Per-parameter state kept for each 2-D tensor optimized by [`Soap`].
#[derive(Debug, Clone)]
struct SoapState {
    /// Running average of the left Gram matrix `G·Gᵀ`, row-major `m × m`.
    l: Vec<f32>,
    /// Running average of the right Gram matrix `Gᵀ·G`, row-major `n × n`.
    r: Vec<f32>,
    /// Left basis (`m × m`) computed from `l` by `jacobi_eigh_sym`; starts as identity.
    ql: Vec<f32>,
    /// Right basis (`n × n`) computed from `r` by `jacobi_eigh_sym`; starts as identity.
    qr: Vec<f32>,
    /// First-moment EMA in rotated (`QLᵀ·G·QR`) coordinates, `m × n`.
    m_rot: Vec<f32>,
    /// Second-moment EMA in rotated coordinates, `m × n`.
    v_rot: Vec<f32>,
    /// Whether `ql`/`qr` have been computed at least once.
    initialized_basis: bool,
}
/// SOAP optimizer: Adam applied in the basis obtained from running averages of
/// the left/right Gram matrices of each 2-D gradient, with an AdamW fallback for
/// parameters of any other rank.
#[derive(Debug, Clone)]
pub struct Soap {
    /// Learning rate.
    pub lr: f32,
    /// Decay of the first-moment EMA.
    pub beta1: f32,
    /// Decay of the second-moment EMA.
    pub beta2: f32,
    /// Decay of the left/right Gram-matrix EMAs.
    pub shampoo_beta: f32,
    /// Added to the square root of the bias-corrected second moment.
    pub eps: f32,
    /// Weight-decay coefficient, applied as `lr * weight_decay * param`.
    pub weight_decay: f32,
    /// Bases are recomputed whenever the global step is a multiple of this
    /// value (and on the first update of each 2-D parameter).
    pub precond_freq: u64,
    /// Number of sweeps passed to `jacobi_eigh_sym`.
    pub jacobi_sweeps: u32,
    /// Global step counter, advanced by `end_iteration`.
    step: u64,
    /// SOAP state of 2-D parameters, keyed by parameter name.
    state: HashMap<String, SoapState>,
    /// AdamW first moments of non-2-D parameters, keyed by parameter name.
    fb_m: HashMap<String, Vec<f32>>,
    /// AdamW second moments of non-2-D parameters, keyed by parameter name.
    fb_v: HashMap<String, Vec<f32>>,
}
impl Soap {
    /// Builds an optimizer with learning rate `lr` and empty state.
    ///
    /// Defaults: `beta1 = beta2 = shampoo_beta = 0.95`, `eps = 1e-8`,
    /// `weight_decay = 0.01`, `precond_freq = 10`, `jacobi_sweeps = 30`.
    pub fn new(lr: f32) -> Self {
        // The three EMA decays share one default value.
        const DEFAULT_DECAY: f32 = 0.95;
        Self {
            lr,
            beta1: DEFAULT_DECAY,
            beta2: DEFAULT_DECAY,
            shampoo_beta: DEFAULT_DECAY,
            eps: 1e-8,
            weight_decay: 0.01,
            precond_freq: 10,
            jacobi_sweeps: 30,
            step: 0,
            state: HashMap::new(),
            fb_m: HashMap::new(),
            fb_v: HashMap::new(),
        }
    }

    /// Returns `self` with `weight_decay` set to `wd`; every other setting and
    /// any accumulated state is carried over unchanged.
    pub fn with_weight_decay(self, wd: f32) -> Self {
        Self {
            weight_decay: wd,
            ..self
        }
    }
}
impl Optimizer for Soap {
    /// Applies one SOAP update to the parameter `name` in place.
    ///
    /// Parameters whose `shape` is not 2-D are delegated to `adamw_fallback`.
    /// For `shape == [m, n]` (`param`/`grad` read row-major, element `(i, j)` at
    /// `i * n + j`) the update:
    ///
    /// 1. folds `G·Gᵀ` (`m × m`) and `Gᵀ·G` (`n × n`) into running averages with
    ///    decay `shampoo_beta`;
    /// 2. on the first call for this parameter, and whenever the global step is a
    ///    multiple of `precond_freq`, recomputes the bases `QL`, `QR` from those
    ///    averages with `jacobi_eigh_sym`, carrying the first moment over into the
    ///    new basis;
    /// 3. runs bias-corrected Adam on the rotated gradient `QLᵀ·G·QR`;
    /// 4. rotates the Adam direction back (`QL·U·QRᵀ`) and applies
    ///    `p -= lr * (u + weight_decay * p)`.
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
        debug_assert_eq!(param.len(), grad.len());
        if shape.len() != 2 {
            adamw_fallback(self, name, param, grad);
            return;
        }
        let (m, n) = (shape[0], shape[1]);
        debug_assert_eq!(m * n, param.len());

        // Bias corrections use the global step counter advanced by `end_iteration`.
        let t = (self.step + 1) as f64;
        let b1 = self.beta1 as f64;
        let b2 = self.beta2 as f64;
        let bc1 = 1.0 - b1.powf(t);
        let bc2 = 1.0 - b2.powf(t);
        let sb = self.shampoo_beta as f64;
        let eps = self.eps as f64;
        let lr = self.lr;
        let wd = self.weight_decay;
        let sweeps = self.jacobi_sweeps;
        let scheduled_refresh = self.step.is_multiple_of(self.precond_freq);

        let st = self
            .state
            .entry(name.to_owned())
            .or_insert_with(|| SoapState {
                l: vec![0.0; m * m],
                r: vec![0.0; n * n],
                ql: identity(m),
                qr: identity(n),
                m_rot: vec![0.0; m * n],
                v_rot: vec![0.0; m * n],
                initialized_basis: false,
            });

        ema_gram_left(&mut st.l, grad, m, n, sb);
        ema_gram_right(&mut st.r, grad, m, n, sb);

        if !st.initialized_basis || scheduled_refresh {
            // `m_rot` is a linear EMA of rotated gradients, so it can be moved exactly
            // into the new basis: undo the old rotation now and apply the new one
            // after diagonalization. Without this, coordinates from the old basis
            // would be mixed with gradients rotated by the new one after every refresh.
            let first_moment = st
                .initialized_basis
                .then(|| from_eigenbasis(&st.ql, &st.qr, &st.m_rot, m, n));

            // The solver takes its matrix by `&mut`, so work on copies to keep the
            // running averages intact.
            let mut l_work = st.l.clone();
            let mut r_work = st.r.clone();
            // NOTE(review): assumes jacobi_eigh_sym overwrites `v` with the eigenvectors
            // as columns (`v[row * dim + col]`), matching how `ql`/`qr` are read by the
            // rotation helpers. Confirm it does not accumulate onto the previous basis.
            jacobi_eigh_sym(&mut l_work, m, &mut st.ql, sweeps, 1e-6);
            jacobi_eigh_sym(&mut r_work, n, &mut st.qr, sweeps, 1e-6);

            if let Some(m_param) = first_moment {
                st.m_rot = to_eigenbasis(&st.ql, &st.qr, &m_param, m, n);
            }
            // NOTE(review): `v_rot` holds per-coordinate second moments, which have no
            // exact counterpart in a new basis; it is carried over unchanged.
            st.initialized_basis = true;
        }

        let g_rot = to_eigenbasis(&st.ql, &st.qr, grad, m, n);
        let mut u_rot = vec![0.0f32; m * n];
        for k in 0..m * n {
            let g = g_rot[k] as f64;
            let mi = b1 * st.m_rot[k] as f64 + (1.0 - b1) * g;
            let vi = b2 * st.v_rot[k] as f64 + (1.0 - b2) * g * g;
            st.m_rot[k] = mi as f32;
            st.v_rot[k] = vi as f32;
            u_rot[k] = (mi / bc1 / ((vi / bc2).sqrt() + eps)) as f32;
        }

        let u = from_eigenbasis(&st.ql, &st.qr, &u_rot, m, n);
        // Decoupled weight decay, scaled by the learning rate.
        for (p, &uk) in param[..m * n].iter_mut().zip(&u) {
            *p -= lr * (uk + wd * *p);
        }
    }

    /// Advances the global step counter shared by every parameter. The counter
    /// drives the bias corrections and the basis-refresh schedule in `step`.
    fn end_iteration(&mut self) {
        self.step += 1;
    }
}

/// Blends `G·Gᵀ` into `acc` in place: `acc ← beta·acc + (1 − beta)·G·Gᵀ`.
///
/// `grad` is row-major `m × n` and `acc` is row-major `m × m`. Dot products are
/// accumulated in `f64`. The Gram matrix is symmetric, and `acc` stays symmetric
/// when it starts at zero and is only updated here. Each `(i, j)` / `(j, i)` pair
/// is therefore computed once and written to both slots; the result is
/// bit-identical to computing them separately.
fn ema_gram_left(acc: &mut [f32], grad: &[f32], m: usize, n: usize, beta: f64) {
    for i in 0..m {
        let row_i = &grad[i * n..(i + 1) * n];
        for j in i..m {
            let row_j = &grad[j * n..(j + 1) * n];
            let dot = row_i
                .iter()
                .zip(row_j)
                .fold(0.0f64, |s, (&a, &b)| s + a as f64 * b as f64);
            let blended = (beta * acc[i * m + j] as f64 + (1.0 - beta) * dot) as f32;
            acc[i * m + j] = blended;
            acc[j * m + i] = blended;
        }
    }
}

/// Blends `Gᵀ·G` into `acc` in place: `acc ← beta·acc + (1 − beta)·Gᵀ·G`.
///
/// `grad` is row-major `m × n` and `acc` is row-major `n × n`. It uses the same
/// symmetric shortcut as [`ema_gram_left`].
fn ema_gram_right(acc: &mut [f32], grad: &[f32], m: usize, n: usize, beta: f64) {
    for i in 0..n {
        for j in i..n {
            let dot = (0..m).fold(0.0f64, |s, p| {
                s + grad[p * n + i] as f64 * grad[p * n + j] as f64
            });
            let blended = (beta * acc[i * n + j] as f64 + (1.0 - beta) * dot) as f32;
            acc[i * n + j] = blended;
            acc[j * n + i] = blended;
        }
    }
}

/// Returns `QLᵀ · x · QR` for a row-major `m × n` matrix `x`, with `ql` (`m × m`)
/// and `qr` (`n × n`) stored row-major.
fn to_eigenbasis(ql: &[f32], qr: &[f32], x: &[f32], m: usize, n: usize) -> Vec<f32> {
    // QLᵀ · x, reading QL transposed in place instead of materializing it.
    let mut left = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            left[i * n + j] = (0..m).fold(0.0f32, |s, p| s + ql[p * m + i] * x[p * n + j]);
        }
    }
    // Fresh zeroed output: correct whether `matmul` overwrites or accumulates.
    let mut out = vec![0.0f32; m * n];
    matmul(&left, qr, m, n, n, &mut out);
    out
}

/// Returns `QL · x · QRᵀ` for a row-major `m × n` matrix `x`. This is the inverse
/// of [`to_eigenbasis`] when `ql` and `qr` are orthogonal.
fn from_eigenbasis(ql: &[f32], qr: &[f32], x: &[f32], m: usize, n: usize) -> Vec<f32> {
    // Fresh zeroed output: correct whether `matmul` overwrites or accumulates.
    let mut left = vec![0.0f32; m * n];
    matmul(ql, x, m, m, n, &mut left);
    // (QL · x) · QRᵀ, reading QR transposed in place.
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            out[i * n + j] = (0..n).fold(0.0f32, |s, p| s + left[i * n + p] * qr[j * n + p]);
        }
    }
    out
}
/// Returns the `n × n` identity matrix in row-major order.
fn identity(n: usize) -> Vec<f32> {
    let mut eye = vec![0.0f32; n * n];
    // Consecutive diagonal entries sit `n + 1` slots apart in row-major storage.
    eye.iter_mut().step_by(n + 1).for_each(|d| *d = 1.0);
    eye
}
/// AdamW update used for parameters that are not 2-D matrices.
///
/// First and second moments live in `opt.fb_m` / `opt.fb_v` under `name`, obtained
/// through `zeros_entry`. Bias correction uses the global `opt.step`. Arithmetic is
/// done in `f64` and stored back as `f32`:
/// `p ← p − lr · (m̂ / (√v̂ + eps) + weight_decay · p)`.
fn adamw_fallback(opt: &mut Soap, name: &str, param: &mut [f32], grad: &[f32]) {
    let lr = opt.lr as f64;
    let eps = opt.eps as f64;
    let wd = opt.weight_decay as f64;
    let (beta1, beta2) = (opt.beta1 as f64, opt.beta2 as f64);
    let t = (opt.step + 1) as f64;
    let (bias1, bias2) = (1.0 - beta1.powf(t), 1.0 - beta2.powf(t));

    let len = param.len();
    let first = zeros_entry(&mut opt.fb_m, name, len);
    let second = zeros_entry(&mut opt.fb_v, name, len);

    for k in 0..len {
        let g = grad[k] as f64;
        let m_new = beta1 * first[k] as f64 + (1.0 - beta1) * g;
        let v_new = beta2 * second[k] as f64 + (1.0 - beta2) * g * g;
        first[k] = m_new as f32;
        second[k] = v_new as f32;

        let old = param[k] as f64;
        let direction = m_new / bias1 / ((v_new / bias2).sqrt() + eps);
        param[k] = (old - lr * (direction + wd * old)) as f32;
    }
}