use crate::error::AutogradError;
use crate::optim::lbfgs::{dot, l2_norm};
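/// Natural-gradient descent optimizer. Parameters are updated along
/// `F^{-1} g`, where `g` is the gradient and `F` is a damped empirical Fisher
/// matrix estimated from per-sample gradient vectors; the solve is done with
/// conjugate gradients, so `F` is never materialized.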
pub struct NaturalGradient {
pub learning_rate: f64,
pub damping: f64,
pub cg_max_iter: usize,
pub cg_tol: f64,
}
impl NaturalGradient {
pub fn new(lr: f64) -> Self {
Self {
learning_rate: lr,
damping: 1e-4,
cg_max_iter: 50,
cg_tol: 1e-8,
}
}
pub fn with_damping(mut self, damping: f64) -> Self {
self.damping = damping;
self
}
pub fn with_cg_max_iter(mut self, n: usize) -> Self {
self.cg_max_iter = n;
self
}
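    /// Performs one natural-gradient step: builds the damped empirical Fisher
    /// `F = (1/m) * sum_i s_i s_i^T + damping * I` from `fisher_samples`,
    /// solves `F x = grad` by conjugate gradients, and applies
    /// `params -= learning_rate * x`. With no samples it falls back to a
    /// plain gradient-descent step.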
    pub fn step(&self, params: &mut [f64], grad: &[f64], fisher_samples: &[Vec<f64>]) {
let m = fisher_samples.len();
if m == 0 {
for (p, g) in params.iter_mut().zip(grad.iter()) {
*p -= self.learning_rate * g;
}
return;
}
let n = params.len();
let nat_grad = self.solve_fisher_cg(grad, fisher_samples, n, m);
for (p, ng) in params.iter_mut().zip(nat_grad.iter()) {
*p -= self.learning_rate * ng;
}
}
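    /// Runs natural-gradient descent until the gradient norm falls below
    /// `tol` or `max_iter` iterations are exhausted. The closure returns
    /// `(loss, gradient, fisher_samples)` at the current point; the loss is
    /// appended to `loss_history` every iteration.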
pub fn minimize<F>(
&self,
grad_and_samples_fn: F,
mut x: Vec<f64>,
max_iter: usize,
tol: f64,
) -> Result<NaturalGradientResult, AutogradError>
where
F: Fn(&[f64]) -> (f64, Vec<f64>, Vec<Vec<f64>>),
{
let mut loss_history = Vec::with_capacity(max_iter + 1);
for iter in 0..max_iter {
let (f, g, samples) = grad_and_samples_fn(&x);
loss_history.push(f);
let grad_norm = l2_norm(&g);
if grad_norm < tol {
return Ok(NaturalGradientResult {
x,
f,
grad_norm,
iterations: iter,
converged: true,
loss_history,
});
}
self.step(&mut x, &g, &samples);
}
let (f, g, _) = grad_and_samples_fn(&x);
let grad_norm = l2_norm(&g);
loss_history.push(f);
Ok(NaturalGradientResult {
x,
f,
grad_norm,
iterations: max_iter,
converged: grad_norm < tol,
loss_history,
})
}
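    /// Solves `(F + damping * I) x = g` by conjugate gradients, where
    /// `F = (1/m) * sum_i s_i s_i^T`. Fisher-vector products are formed as
    /// `(1/m) * sum_i s_i (s_i . v)`, so each product costs O(m * n) and the
    /// n x n matrix is never built.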
fn solve_fisher_cg(
&self,
g: &[f64],
samples: &[Vec<f64>],
n: usize,
m: usize,
) -> Vec<f64> {
let m_f = m as f64;
let damping = self.damping;
let fv = |v: &[f64]| -> Vec<f64> {
let mut result = vec![0.0_f64; n];
for s in samples {
let sv: f64 = dot(s, v);
for j in 0..n {
result[j] += sv * s[j] / m_f;
}
}
for j in 0..n {
result[j] += damping * v[j];
}
result
};
let mut x = vec![0.0_f64; n];
        let mut r = g.to_vec();
        let mut p = r.clone();
let mut rr: f64 = dot(&r, &r);
for _ in 0..self.cg_max_iter {
if rr < self.cg_tol * self.cg_tol {
break;
}
let ap = fv(&p);
let pap: f64 = dot(&p, &ap);
if pap < 1e-20 {
break;
}
let alpha = rr / pap;
for i in 0..n {
x[i] += alpha * p[i];
r[i] -= alpha * ap[i];
}
let rr_new: f64 = dot(&r, &r);
let beta = rr_new / rr.max(1e-20);
for i in 0..n {
p[i] = r[i] + beta * p[i];
}
rr = rr_new;
}
x
}
}
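/// Per-layer state for K-FAC (Kronecker-Factored Approximate Curvature). The
/// layer's Fisher block is approximated by the Kronecker product `A (x) G`,
/// where `A` is the (damped) covariance of the layer's input activations and
/// `G` the covariance of its output gradients; both are stored row-major.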
#[derive(Debug, Clone)]
pub struct KFACLayer {
pub in_dim: usize,
pub out_dim: usize,
pub a_factor: Vec<f64>,
pub g_factor: Vec<f64>,
pub damping: f64,
}
impl KFACLayer {
pub fn new(in_dim: usize, out_dim: usize, damping: f64) -> Self {
Self {
in_dim,
out_dim,
a_factor: vec![0.0_f64; in_dim * in_dim],
g_factor: vec![0.0_f64; out_dim * out_dim],
damping,
}
}
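    /// Updates both Kronecker factors from a minibatch with an exponential
    /// moving average, `factor = (1 - momentum) * factor + momentum * batch_cov`,
    /// then adds `damping` to each diagonal. `activations` is laid out as
    /// `batch * in_dim` and `grad_outputs` as `batch * out_dim`, row-major.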
    pub fn update(
        &mut self,
        activations: &[f64],
        grad_outputs: &[f64],
        batch: usize,
        momentum: f64,
    ) {
let batch_f = batch as f64;
let d = self.damping;
for i in 0..self.in_dim {
for j in 0..self.in_dim {
let mut aij = 0.0_f64;
for b in 0..batch {
aij += activations[b * self.in_dim + i] * activations[b * self.in_dim + j];
}
aij /= batch_f;
self.a_factor[i * self.in_dim + j] =
(1.0 - momentum) * self.a_factor[i * self.in_dim + j] + momentum * aij;
}
self.a_factor[i * self.in_dim + i] += d;
}
for i in 0..self.out_dim {
for j in 0..self.out_dim {
let mut gij = 0.0_f64;
for b in 0..batch {
gij +=
grad_outputs[b * self.out_dim + i] * grad_outputs[b * self.out_dim + j];
}
gij /= batch_f;
self.g_factor[i * self.out_dim + j] =
(1.0 - momentum) * self.g_factor[i * self.out_dim + j] + momentum * gij;
}
self.g_factor[i * self.out_dim + i] += d;
}
}
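    /// Applies the inverse Kronecker-factored curvature to a weight gradient
    /// using `(A (x) G)^{-1} vec(dW) = vec(G^{-1} dW A^{-1})`. `dw` is the
    /// `out_dim x in_dim` gradient in row-major order; the factor inverses
    /// are computed via Cholesky factorization.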
pub fn precondition(&self, dw: &[f64]) -> Result<Vec<f64>, AutogradError> {
let a_inv = cholesky_invert(&self.a_factor, self.in_dim)?;
let g_inv = cholesky_invert(&self.g_factor, self.out_dim)?;
let mut tmp = vec![0.0_f64; self.out_dim * self.in_dim];
for i in 0..self.out_dim {
for j in 0..self.in_dim {
let mut v = 0.0_f64;
for k in 0..self.in_dim {
v += dw[i * self.in_dim + k] * a_inv[k * self.in_dim + j];
}
tmp[i * self.in_dim + j] = v;
}
}
let mut result = vec![0.0_f64; self.out_dim * self.in_dim];
for i in 0..self.out_dim {
for j in 0..self.in_dim {
let mut v = 0.0_f64;
for k in 0..self.out_dim {
v += g_inv[i * self.out_dim + k] * tmp[k * self.in_dim + j];
}
result[i * self.in_dim + j] = v;
}
}
Ok(result)
}
}
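/// Result of `NaturalGradient::minimize`: the final point, loss, gradient
/// norm, iteration count, convergence flag, and per-iteration loss history.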
#[derive(Debug, Clone)]
pub struct NaturalGradientResult {
pub x: Vec<f64>,
pub f: f64,
pub grad_norm: f64,
pub iterations: usize,
pub converged: bool,
pub loss_history: Vec<f64>,
}
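/// Inverts a symmetric positive-definite `n x n` row-major matrix via
/// Cholesky: factor `A = L L^T`, forward-solve `L Y = I`, then back-solve
/// `L^T X = Y`, so that `X = A^{-1}`.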
fn cholesky_invert(a: &[f64], n: usize) -> Result<Vec<f64>, AutogradError> {
let mut l = vec![0.0_f64; n * n];
for i in 0..n {
for j in 0..=i {
let mut sum = a[i * n + j];
for k in 0..j {
sum -= l[i * n + k] * l[j * n + k];
}
if i == j {
if sum <= 0.0 {
return Err(AutogradError::OperationError(
"Cholesky: matrix is not positive definite".to_string(),
));
}
l[i * n + i] = sum.sqrt();
} else {
let lii = l[j * n + j];
if lii.abs() < 1e-20 {
return Err(AutogradError::OperationError(
"Cholesky: near-zero diagonal".to_string(),
));
}
l[i * n + j] = sum / lii;
}
}
}
let mut y = vec![0.0_f64; n * n];
for col in 0..n {
for row in 0..n {
let mut val = if row == col { 1.0_f64 } else { 0.0_f64 };
for k in 0..row {
val -= l[row * n + k] * y[k * n + col];
}
let lrr = l[row * n + row];
if lrr.abs() < 1e-20 {
return Err(AutogradError::OperationError(
"Cholesky invert: singular L".to_string(),
));
}
y[row * n + col] = val / lrr;
}
}
let mut inv = vec![0.0_f64; n * n];
for col in 0..n {
let mut x = vec![0.0_f64; n];
for row in (0..n).rev() {
let mut val = y[row * n + col];
for k in (row + 1)..n {
val -= l[k * n + row] * x[k];
}
            // Diagonals were already validated against 1e-20 in the forward
            // solve above, so this division is safe.
            let lrr = l[row * n + row];
            x[row] = val / lrr;
}
for row in 0..n {
inv[row * n + col] = x[row];
}
}
Ok(inv)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_natural_gradient_step_reduces_gradient_norm() {
let params_start = vec![2.0_f64, -3.0];
let gradient = vec![4.0_f64, -6.0];
let fisher_samples = vec![gradient.clone()];
let ng = NaturalGradient::new(0.1).with_damping(1e-3);
let mut params = params_start.clone();
ng.step(&mut params, &gradient, &fisher_samples);
let grad_norm_before = l2_norm(&gradient);
let grad_after: Vec<f64> = params.iter().map(|p| 2.0 * p).collect();
let grad_norm_after = l2_norm(&grad_after);
assert!(
grad_norm_after < grad_norm_before,
"grad norm did not decrease: before={grad_norm_before} after={grad_norm_after}"
);
}
#[test]
fn test_natural_gradient_fallback_no_samples() {
let mut params = vec![1.0_f64, 2.0];
let grad = vec![0.5_f64, -1.0];
let ng = NaturalGradient::new(0.1);
ng.step(&mut params, &grad, &[]);
assert!((params[0] - (1.0 - 0.1 * 0.5)).abs() < 1e-10);
assert!((params[1] - (2.0 + 0.1 * 1.0)).abs() < 1e-10);
}
#[test]
fn test_cholesky_invert_identity() {
let mut a = vec![0.0_f64; 9];
a[0] = 1.0;
a[4] = 1.0;
a[8] = 1.0;
let inv = cholesky_invert(&a, 3).expect("cholesky invert failed");
for i in 0..3 {
for j in 0..3 {
let expected = if i == j { 1.0_f64 } else { 0.0_f64 };
assert!((inv[i * 3 + j] - expected).abs() < 1e-10);
}
}
}
#[test]
fn test_kfac_layer_update() {
let mut layer = KFACLayer::new(2, 2, 1e-3);
let acts = vec![1.0_f64, 0.0];
let grads = vec![0.0_f64, 1.0];
layer.update(&acts, &grads, 1, 1.0);
assert!((layer.a_factor[0] - (1.0 + 1e-3)).abs() < 1e-8);
}
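    // Illustrative sketch (not part of the original suite): minimize a simple
    // quadratic with `minimize`, using the gradient itself as the single
    // Fisher sample. This only demonstrates the API, not a statistically
    // meaningful Fisher estimate; the damped single-sample Fisher simply
    // rescales the step by 1 / (||g||^2 + damping).
    #[test]
    fn test_minimize_quadratic_single_sample_fisher() {
        let ng = NaturalGradient::new(0.1).with_damping(1.0);
        let result = ng
            .minimize(
                |x| {
                    // f(x) = x0^2 + x1^2 with gradient 2x; the gradient
                    // doubles as the lone Fisher sample.
                    let f = x[0] * x[0] + x[1] * x[1];
                    let g = vec![2.0 * x[0], 2.0 * x[1]];
                    (f, g.clone(), vec![g])
                },
                vec![1.0_f64, -2.0],
                500,
                1e-6,
            )
            .expect("minimize failed");
        assert!(
            result.converged,
            "did not converge: grad_norm={}",
            result.grad_norm
        );
        assert!(result.f < 1e-6);
    }
    // Illustrative sketch: with identity Kronecker factors the K-FAC
    // preconditioner should leave the weight gradient unchanged, since
    // G^{-1} dW A^{-1} = dW.
    #[test]
    fn test_kfac_precondition_identity_factors() {
        let mut layer = KFACLayer::new(2, 2, 0.0);
        layer.a_factor = vec![1.0, 0.0, 0.0, 1.0];
        layer.g_factor = vec![1.0, 0.0, 0.0, 1.0];
        let dw = vec![1.0_f64, 2.0, 3.0, 4.0];
        let pre = layer.precondition(&dw).expect("precondition failed");
        for (p, d) in pre.iter().zip(dw.iter()) {
            assert!((p - d).abs() < 1e-10);
        }
    }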
}