use std::collections::HashMap;
use crate::Optimizer;
use crate::common::{matmul, zeros_entry};
/// Per-parameter preconditioner factors for a rank-2 (`m x n`) parameter.
#[derive(Clone, Debug)]
struct KronState {
    /// Left factor: row-major `m x m`, initialised to the identity and kept
    /// upper triangular by the factor updates.
    ql: Vec<f32>,
    /// Right factor: row-major `n x n`, initialised to the identity and kept
    /// upper triangular by the factor updates.
    qr: Vec<f32>,
}
/// Optimizer that preconditions rank-2 gradients with a pair of upper-triangular
/// factors (`Ql`, `Qr`) and applies heavy-ball momentum to every parameter.
#[derive(Clone, Debug)]
pub struct KronPsgd {
    /// Step size applied to the momentum buffer when updating parameters.
    pub lr: f32,
    /// Step size passed to the Kronecker-factor updates.
    pub precond_lr: f32,
    /// Decay coefficient of the momentum (velocity) buffer.
    pub momentum: f32,
    /// Coefficient of the current parameter value added to the update direction.
    pub weight_decay: f32,
    /// Added to the factor-gradient norm so the factor step never divides by zero.
    pub eps: f32,
    /// Upper bound on the largest absolute entry of the preconditioned gradient
    /// (rank-2 parameters only).
    pub clip: f32,
    /// Kronecker factors of rank-2 parameters, keyed by parameter name.
    state: HashMap<String, KronState>,
    /// Momentum buffers of all parameters, keyed by parameter name.
    mom: HashMap<String, Vec<f32>>,
}
impl KronPsgd {
    /// Creates an optimizer with learning rate `lr`.
    ///
    /// The other hyper-parameters start at `precond_lr = 0.1`, `momentum = 0.9`,
    /// `weight_decay = 0.0`, `eps = 1e-8` and `clip = 1.0`. No per-parameter
    /// state exists until `step` is first called for a given name.
    pub fn new(lr: f32) -> Self {
        let (state, mom) = (HashMap::new(), HashMap::new());
        Self {
            lr,
            precond_lr: 0.1,
            momentum: 0.9,
            weight_decay: 0.0,
            eps: 1e-8,
            clip: 1.0,
            state,
            mom,
        }
    }

    /// Returns the optimizer with its momentum coefficient replaced by `mu`.
    pub fn with_momentum(self, mu: f32) -> Self {
        Self { momentum: mu, ..self }
    }

    /// Returns the optimizer with its weight-decay coefficient replaced by `wd`.
    pub fn with_weight_decay(self, wd: f32) -> Self {
        Self { weight_decay: wd, ..self }
    }
}
impl Optimizer for KronPsgd {
    /// Applies one update to the parameter tensor registered under `name`.
    ///
    /// * If `shape` is not rank 2, this is heavy-ball SGD with coupled weight
    ///   decay: `v ← μ·v + g + wd·θ`, then `θ ← θ − lr·v`.
    /// * If `shape == [m, n]` (row-major `param` / `grad`), per-name
    ///   upper-triangular factors `Ql` (`m x m`) and `Qr` (`n x n`) are created
    ///   as the identity on first use and refreshed on every call. The gradient
    ///   is then preconditioned as `Qlᵀ·Ql · G · Qrᵀ·Qr`. It is rescaled so its
    ///   largest absolute entry does not exceed `clip`, then fed through the
    ///   same momentum / weight-decay update.
    ///
    /// Factors and momentum buffers are looked up by `name`, so each parameter
    /// should be stepped under a stable, distinct name. `param.len()` must equal
    /// `grad.len()` (and `m * n` for rank-2 shapes); this is only checked in
    /// debug builds.
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
        debug_assert_eq!(param.len(), grad.len());
        let lr = self.lr;
        let wd = self.weight_decay;
        let mu = self.momentum;

        if shape.len() != 2 {
            // Non-matrix tensors: unpreconditioned momentum SGD.
            let v = zeros_entry(&mut self.mom, name, param.len());
            for i in 0..param.len() {
                v[i] = mu * v[i] + grad[i] + wd * param[i];
                param[i] -= lr * v[i];
            }
            return;
        }

        let (m, n) = (shape[0], shape[1]);
        debug_assert_eq!(m * n, param.len());
        let st = self
            .state
            .entry(name.to_owned())
            .or_insert_with(|| KronState {
                ql: identity_triangular(m),
                qr: identity_triangular(n),
            });

        // Refresh the factors from A = Ql·G·Qrᵀ and B = Ql⁻ᵀ·G·Qr⁻¹.
        //
        // NOTE(review): `grad` is used as BOTH probe vectors. While Ql and Qr are
        // the identity (their initial value), A and B are mathematically
        // identical. The factor gradients A·Aᵀ − B·Bᵀ and Aᵀ·A − Bᵀ·B are then
        // zero, and the factors never leave the identity. The PSGD fitting
        // criterion pairs the gradient with an independent probe (e.g. a random
        // vector or the parameter change). Confirm whether this is intended.
        let a = matmul_3(&st.ql, grad, &st.qr, m, n, true);
        let b = matmul_3_inv(&st.ql, grad, &st.qr, m, n);
        update_factor(&mut st.ql, &a, &b, m, n, true, self.precond_lr, self.eps);
        update_factor(&mut st.qr, &a, &b, m, n, false, self.precond_lr, self.eps);

        // A = Ql·G·Qrᵀ means vec(A) = (Qr ⊗ Ql)·vec(G), i.e. Q = Qr ⊗ Ql. Hence
        // P = QᵀQ = (Qrᵀ·Qr) ⊗ (Qlᵀ·Ql) and P·vec(G) = vec(Qlᵀ·Ql·G·Qrᵀ·Qr).
        // Both sides need the transpose-first product. Qr·Qrᵀ on the right would
        // not match the factorisation A and B are built from.
        let left = factor_gram(&st.ql, m);
        let right = factor_gram(&st.qr, n);
        let mut tmp = vec![0.0f32; m * n];
        matmul(&left, grad, m, m, n, &mut tmp);
        let mut p_g = vec![0.0f32; m * n];
        matmul(&tmp, &right, m, n, n, &mut p_g);

        // Rescale so the largest absolute entry is at most `clip`.
        let max_abs = p_g.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
        let scale = if max_abs > self.clip { self.clip / max_abs } else { 1.0 };

        let v = zeros_entry(&mut self.mom, name, param.len());
        for i in 0..param.len() {
            let g = scale * p_g[i] + wd * param[i];
            v[i] = mu * v[i] + g;
            param[i] -= lr * v[i];
        }
    }
}

/// Returns `Qᵀ·Q` for the row-major `dim x dim` matrix `q`.
fn factor_gram(q: &[f32], dim: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; dim * dim];
    for i in 0..dim {
        for j in 0..dim {
            let mut s = 0.0f32;
            for p in 0..dim {
                s += q[p * dim + i] * q[p * dim + j];
            }
            out[i * dim + j] = s;
        }
    }
    out
}
/// Returns the row-major `n x n` identity matrix. This is the initial value of
/// each Kronecker factor; the identity is trivially upper triangular.
fn identity_triangular(n: usize) -> Vec<f32> {
    // Diagonal entries live at flat indices 0, n+1, 2(n+1), ...
    (0..n * n)
        .map(|k| if k % (n + 1) == 0 { 1.0 } else { 0.0 })
        .collect()
}
/// Returns `Ql · G · Qrᵀ` when `trans_q_r` is `true`, or `Ql · G · Qr` otherwise.
///
/// `ql` is `m x m`, `g` is `m x n` and `qr` is `n x n`, all row-major. The
/// result is a freshly allocated row-major `m x n` matrix.
fn matmul_3(ql: &[f32], g: &[f32], qr: &[f32], m: usize, n: usize, trans_q_r: bool) -> Vec<f32> {
    // Ql·G is shared by both variants.
    let mut left = vec![0.0f32; m * n];
    matmul(ql, g, m, m, n, &mut left);

    let mut result = vec![0.0f32; m * n];
    if !trans_q_r {
        matmul(&left, qr, m, n, n, &mut result);
        return result;
    }

    // (Ql·G)·Qrᵀ: entry (row, col) is the dot product of row `row` of Ql·G
    // with row `col` of Qr, so the transpose is never materialised.
    for row in 0..m {
        let left_row = &left[row * n..(row + 1) * n];
        for col in 0..n {
            let qr_row = &qr[col * n..(col + 1) * n];
            result[row * n + col] = left_row
                .iter()
                .zip(qr_row)
                .fold(0.0f32, |acc, (&x, &y)| acc + x * y);
        }
    }
    result
}
/// Returns `B = Ql⁻ᵀ · G · Qr⁻¹` for upper-triangular `ql` (`m x m`) and `qr`
/// (`n x n`) without forming either inverse.
///
/// `g` is the row-major `m x n` right-hand side. Two triangular solves are used:
/// forward substitution for `Qlᵀ·X = G` (column by column), then a
/// left-to-right solve of `Y·Qr = X` (row by row). Any pivot whose magnitude is
/// not greater than `1e-12` produces `0.0` for that entry instead of a division.
fn matmul_3_inv(ql: &[f32], g: &[f32], qr: &[f32], m: usize, n: usize) -> Vec<f32> {
    let guarded_div = |num: f32, pivot: f32| if pivot.abs() > 1e-12 { num / pivot } else { 0.0 };
    let mut sol = g.to_vec();

    // Qlᵀ is lower triangular with (Qlᵀ)[r][k] = ql[k][r].
    for col in 0..n {
        for row in 0..m {
            let residual = (0..row).fold(sol[row * n + col], |acc, k| {
                acc - ql[k * m + row] * sol[k * n + col]
            });
            sol[row * n + col] = guarded_div(residual, ql[row * m + row]);
        }
    }

    // Y·Qr = X with Qr upper triangular: Y[row][col] depends only on Y[row][..col].
    for row in 0..m {
        for col in 0..n {
            let residual = (0..col).fold(sol[row * n + col], |acc, k| {
                acc - sol[row * n + k] * qr[k * n + col]
            });
            sol[row * n + col] = guarded_div(residual, qr[col * n + col]);
        }
    }
    sol
}
/// Takes one normalised descent step on an upper-triangular Kronecker factor.
///
/// `a` and `b` are row-major `m x n` matrices; the caller passes
/// `A = Ql·G·Qrᵀ` and `B = Ql⁻ᵀ·G·Qr⁻¹`.
/// * With `which == true`, the left factor `q` (`m x m`) is updated using
///   `∇ = A·Aᵀ − B·Bᵀ`.
/// * With `which == false`, the right factor `q` (`n x n`) is updated using
///   `∇ = Aᵀ·A − Bᵀ·B`.
///
/// The step is `Q ← Q − plr / (‖∇‖_F + eps) · triu(∇) · Q`. The Frobenius norm
/// is taken over the full `∇`, but only its upper triangle is applied.
/// `triu(∇)` multiplies `Q` from the *left*, which is the relative (Lie-group)
/// gradient step for `A = Q·h`, `B = Q⁻ᵀ·v`. A product of two upper-triangular
/// matrices is upper triangular, so `Q` keeps its shape.
fn update_factor(
    q: &mut [f32],
    a: &[f32],
    b: &[f32],
    m: usize,
    n: usize,
    which: bool,
    plr: f32,
    eps: f32,
) {
    let dim = if which { m } else { n };
    debug_assert_eq!(q.len(), dim * dim);
    debug_assert_eq!(a.len(), m * n);
    debug_assert_eq!(b.len(), m * n);

    // Factor gradient: its full Frobenius norm (accumulated in f64) normalises
    // the step, while only the upper triangle is kept for the update itself.
    let mut grad_q = vec![0.0f32; dim * dim];
    let mut norm_sq = 0.0f64;
    for i in 0..dim {
        for j in 0..dim {
            let (mut a_term, mut b_term) = (0.0f32, 0.0f32);
            if which {
                // Row i · row j: (A·Aᵀ)[i][j] and (B·Bᵀ)[i][j].
                for p in 0..n {
                    a_term += a[i * n + p] * a[j * n + p];
                    b_term += b[i * n + p] * b[j * n + p];
                }
            } else {
                // Column i · column j: (Aᵀ·A)[i][j] and (Bᵀ·B)[i][j].
                for p in 0..m {
                    a_term += a[p * n + i] * a[p * n + j];
                    b_term += b[p * n + i] * b[p * n + j];
                }
            }
            let d = a_term - b_term;
            norm_sq += f64::from(d) * f64::from(d);
            if j >= i {
                grad_q[i * dim + j] = d;
            }
        }
    }
    let scale = plr / ((norm_sq.sqrt() as f32) + eps);

    // delta = triu(∇)·Q. The gradient must sit on the left of Q.
    let mut delta = vec![0.0f32; dim * dim];
    for i in 0..dim {
        for j in 0..dim {
            let mut s = 0.0f32;
            for p in 0..dim {
                s += grad_q[i * dim + p] * q[p * dim + j];
            }
            delta[i * dim + j] = s;
        }
    }
    for (qk, &dk) in q.iter_mut().zip(delta.iter()) {
        *qk -= scale * dk;
    }
}