use crate::error::{MathError, Result};
use crate::utils::EPS;
/// Kronecker-factored curvature statistics for one fully connected layer.
///
/// `natural_gradient` preconditions an `output_dim × input_dim` weight gradient `W` as
/// `G⁻¹ · W · A⁻¹` (after damping), i.e. it treats the layer's curvature block as the
/// Kronecker product of `a_factor` and `g_factor`.
#[derive(Debug, Clone)]
pub struct KFACLayer {
/// Running batch-mean of `a aᵀ` over input activations (`input_dim × input_dim`).
/// Public, so callers may replace it; its length is used as `input_dim` everywhere.
pub a_factor: Vec<Vec<f64>>,
/// Running batch-mean of `g gᵀ` over output gradients (`output_dim × output_dim`).
/// Public, so callers may replace it; its length is used as `output_dim` everywhere.
pub g_factor: Vec<Vec<f64>>,
/// Damping coefficient; it is scaled by each factor's mean diagonal before being added.
damping: f64,
/// EMA decay weighting the previous factors on every update after the first.
ema_factor: f64,
/// Number of successful `update` calls since construction or the last `reset`.
num_updates: usize,
}
impl KFACLayer {
    /// Creates a layer whose factors are zero matrices of the given sizes.
    ///
    /// `a_factor` is `input_dim × input_dim` and `g_factor` is `output_dim × output_dim`.
    /// Damping defaults to `1e-3` and the EMA decay to `0.95`.
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        Self {
            a_factor: vec![vec![0.0; input_dim]; input_dim],
            g_factor: vec![vec![0.0; output_dim]; output_dim],
            damping: 1e-3,
            ema_factor: 0.95,
            num_updates: 0,
        }
    }

    /// Returns the layer with its damping set to `damping`, clamped from below to `EPS`.
    pub fn with_damping(mut self, damping: f64) -> Self {
        self.damping = damping.max(EPS);
        self
    }

    /// Returns the layer with its EMA decay set to `ema`, clamped to `[0, 1]`.
    ///
    /// The decay weights the previous factors: `0.0` keeps only the newest batch and `1.0`
    /// freezes the factors after the first update.
    pub fn with_ema(mut self, ema: f64) -> Self {
        self.ema_factor = ema.clamp(0.0, 1.0);
        self
    }

    /// Folds one mini-batch into the Kronecker factors.
    ///
    /// `activations` holds one `input_dim`-length vector per sample and `gradients` one
    /// `output_dim`-length vector per sample. The batch means of `a aᵀ` and `g gᵀ` replace
    /// the factors on the first update and are blended in with weight `1 - ema_factor`
    /// afterwards.
    ///
    /// # Errors
    ///
    /// Returns an empty-input error if either batch is empty, and a dimension-mismatch error
    /// if the batch sizes differ or any vector has the wrong length. The layer is left
    /// unchanged when an error is returned.
    pub fn update(&mut self, activations: &[Vec<f64>], gradients: &[Vec<f64>]) -> Result<()> {
        if activations.is_empty() || gradients.is_empty() {
            return Err(MathError::empty_input("batch"));
        }
        let batch_size = activations.len();
        if gradients.len() != batch_size {
            return Err(MathError::dimension_mismatch(batch_size, gradients.len()));
        }
        // Both statistics are computed (and validated) before either factor is touched, so a
        // malformed gradient cannot leave `a_factor` updated while `g_factor` stays stale.
        let new_a = Self::mean_outer_product(activations, self.a_factor.len())?;
        let new_g = Self::mean_outer_product(gradients, self.g_factor.len())?;
        if self.num_updates == 0 {
            // Blending with the zero-initialised factors would bias them towards zero.
            self.a_factor = new_a;
            self.g_factor = new_g;
        } else {
            Self::ema_blend(&mut self.a_factor, &new_a, self.ema_factor);
            Self::ema_blend(&mut self.g_factor, &new_g, self.ema_factor);
        }
        self.num_updates += 1;
        Ok(())
    }

    /// Returns `(1/B) · Σ v vᵀ` over the `B` vectors, each of which must have length `dim`.
    fn mean_outer_product(vectors: &[Vec<f64>], dim: usize) -> Result<Vec<Vec<f64>>> {
        let batch_size = vectors.len() as f64;
        let mut out = vec![vec![0.0; dim]; dim];
        for v in vectors {
            if v.len() != dim {
                return Err(MathError::dimension_mismatch(dim, v.len()));
            }
            for (row, &vi) in out.iter_mut().zip(v) {
                for (cell, &vj) in row.iter_mut().zip(v) {
                    *cell += vi * vj / batch_size;
                }
            }
        }
        Ok(out)
    }

    /// Blends `fresh` into `current` in place: `current = ema · current + (1 - ema) · fresh`.
    fn ema_blend(current: &mut [Vec<f64>], fresh: &[Vec<f64>], ema: f64) {
        for (cur_row, new_row) in current.iter_mut().zip(fresh) {
            for (cur, &new) in cur_row.iter_mut().zip(new_row) {
                *cur = ema * *cur + (1.0 - ema) * new;
            }
        }
    }

    /// Preconditions a weight gradient with the damped Kronecker factors.
    ///
    /// `weight_grad` must be `output_dim × input_dim`; the result has the same shape and
    /// equals `G_d⁻¹ · weight_grad · A_d⁻¹`, where `A_d`/`G_d` are the factors after
    /// `add_damping`. Before the first `update` the factors are zero, so only the `EPS`
    /// floor is added and the result is roughly `weight_grad / EPS²`.
    ///
    /// # Errors
    ///
    /// Returns a dimension-mismatch error if `weight_grad` has the wrong number of rows or
    /// any row has the wrong length, and a numerical-instability error if a damped factor is
    /// not positive definite.
    pub fn natural_gradient(&self, weight_grad: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        let output_dim = self.g_factor.len();
        let input_dim = self.a_factor.len();
        if weight_grad.len() != output_dim {
            return Err(MathError::dimension_mismatch(output_dim, weight_grad.len()));
        }
        // A short row would otherwise panic on indexing below, and a long row would have its
        // trailing entries silently ignored.
        if let Some(row) = weight_grad.iter().find(|row| row.len() != input_dim) {
            return Err(MathError::dimension_mismatch(input_dim, row.len()));
        }
        let a_inv = self.invert_matrix(&self.add_damping(&self.a_factor))?;
        let g_inv = self.invert_matrix(&self.add_damping(&self.g_factor))?;

        // Right-multiply by A⁻¹: (output_dim × input_dim) · (input_dim × input_dim).
        let mut grad_a_inv = vec![vec![0.0; input_dim]; output_dim];
        for i in 0..output_dim {
            for j in 0..input_dim {
                for k in 0..input_dim {
                    grad_a_inv[i][j] += weight_grad[i][k] * a_inv[k][j];
                }
            }
        }
        // Left-multiply by G⁻¹: (output_dim × output_dim) · (output_dim × input_dim).
        let mut nat_grad = vec![vec![0.0; input_dim]; output_dim];
        for i in 0..output_dim {
            for j in 0..input_dim {
                for k in 0..output_dim {
                    nat_grad[i][j] += g_inv[i][k] * grad_a_inv[k][j];
                }
            }
        }
        Ok(nat_grad)
    }

    /// Returns a copy of `matrix` with a trace-scaled multiple of the identity added.
    ///
    /// The added term is `damping · trace(matrix) / n`, floored at `EPS` so that an all-zero
    /// factor (as after `new` or `reset`) still becomes invertible.
    fn add_damping(&self, matrix: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let n = matrix.len();
        let mut damped = matrix.to_vec();
        let trace: f64 = (0..n).map(|i| matrix[i][i]).sum();
        // `f64::max` returns the non-NaN operand, so a NaN ratio (e.g. n == 0) also yields EPS.
        let scaled_damping = (self.damping * trace / n as f64).max(EPS);
        for (i, row) in damped.iter_mut().enumerate() {
            row[i] += scaled_damping;
        }
        damped
    }

    /// Inverts a symmetric positive-definite matrix via its Cholesky factorisation `L Lᵀ`.
    ///
    /// Only the lower triangle of `matrix` is read.
    ///
    /// # Errors
    ///
    /// Returns a numerical-instability error if any pivot is not a finite positive number,
    /// i.e. the matrix is not (numerically) positive definite or contains NaN/infinite entries.
    fn invert_matrix(&self, matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        let n = matrix.len();
        // Cholesky–Banachiewicz: build L row by row.
        let mut l = vec![vec![0.0; n]; n];
        for i in 0..n {
            for j in 0..=i {
                let mut sum = matrix[i][j];
                for k in 0..j {
                    sum -= l[i][k] * l[j][k];
                }
                if i == j {
                    // `sum <= 0.0` alone is false for NaN, which would then propagate silently
                    // into the result; reject every non-finite pivot as well.
                    if !sum.is_finite() || sum <= 0.0 {
                        return Err(MathError::numerical_instability(
                            "Matrix not positive definite in K-FAC",
                        ));
                    }
                    l[i][j] = sum.sqrt();
                } else {
                    l[i][j] = sum / l[j][j];
                }
            }
        }
        // Invert L column by column with forward substitution; L⁻¹ is lower triangular.
        let mut l_inv = vec![vec![0.0; n]; n];
        for i in 0..n {
            l_inv[i][i] = 1.0 / l[i][i];
            for j in (i + 1)..n {
                let mut sum = 0.0;
                for k in i..j {
                    sum -= l[j][k] * l_inv[k][i];
                }
                l_inv[j][i] = sum / l[j][j];
            }
        }
        // A⁻¹ = L⁻ᵀ L⁻¹ is symmetric, and column i of L⁻¹ is zero above row i, so for j ≤ i
        // only rows k ≥ i contribute.
        let mut inv = vec![vec![0.0; n]; n];
        for i in 0..n {
            for j in 0..=i {
                let mut value = 0.0;
                for k in i..n {
                    value += l_inv[k][i] * l_inv[k][j];
                }
                inv[i][j] = value;
                inv[j][i] = value;
            }
        }
        Ok(inv)
    }

    /// Zeroes both factors (keeping their sizes) and forgets all previous updates.
    ///
    /// Damping and EMA settings are left unchanged.
    pub fn reset(&mut self) {
        let input_dim = self.a_factor.len();
        let output_dim = self.g_factor.len();
        self.a_factor = vec![vec![0.0; input_dim]; input_dim];
        self.g_factor = vec![vec![0.0; output_dim]; output_dim];
        self.num_updates = 0;
    }
}
/// K-FAC preconditioner for a stack of layers, each tracked by its own [`KFACLayer`].
#[derive(Debug, Clone)]
pub struct KFACApproximation {
/// Per-layer factor state, addressed by position (`layer_idx`).
layers: Vec<KFACLayer>,
/// Step size; `natural_gradient_layer` scales its result by `-learning_rate`.
learning_rate: f64,
/// Damping configured via `with_damping`. The value each layer actually uses when
/// preconditioning is the one stored on that layer.
damping: f64,
}
impl KFACApproximation {
    /// Creates one zero-initialised [`KFACLayer`] per `(input_dim, output_dim)` pair.
    ///
    /// The learning rate defaults to `0.01` and the approximation-level damping to `1e-3`.
    pub fn new(layer_dims: &[(usize, usize)]) -> Self {
        Self {
            layers: layer_dims
                .iter()
                .map(|&(input, output)| KFACLayer::new(input, output))
                .collect(),
            learning_rate: 0.01,
            damping: 1e-3,
        }
    }

    /// Returns the approximation with its learning rate set to `lr`, clamped from below to `EPS`.
    pub fn with_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr.max(EPS);
        self
    }

    /// Returns the approximation with `damping`, clamped from below to `EPS`, applied both to
    /// the approximation and to every layer.
    pub fn with_damping(mut self, damping: f64) -> Self {
        // Clamp once and propagate the clamped value, so every layer stores the same damping
        // as `self.damping`, matching the clamp `KFACLayer::with_damping` applies.
        let damping = damping.max(EPS);
        self.damping = damping;
        for layer in &mut self.layers {
            layer.damping = damping;
        }
        self
    }

    /// Feeds a mini-batch into the layer at `layer_idx` (see [`KFACLayer::update`]).
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameter error if `layer_idx` is out of range; otherwise forwards
    /// any error from [`KFACLayer::update`].
    pub fn update_layer(
        &mut self,
        layer_idx: usize,
        activations: &[Vec<f64>],
        gradients: &[Vec<f64>],
    ) -> Result<()> {
        self.check_layer_index(layer_idx)?;
        self.layers[layer_idx].update(activations, gradients)
    }

    /// Returns the parameter step for layer `layer_idx`: its natural gradient (see
    /// [`KFACLayer::natural_gradient`]) multiplied by `-learning_rate`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameter error if `layer_idx` is out of range; otherwise forwards
    /// any error from [`KFACLayer::natural_gradient`].
    pub fn natural_gradient_layer(
        &self,
        layer_idx: usize,
        weight_grad: &[Vec<f64>],
    ) -> Result<Vec<Vec<f64>>> {
        self.check_layer_index(layer_idx)?;
        let scale = -self.learning_rate;
        let mut step = self.layers[layer_idx].natural_gradient(weight_grad)?;
        step.iter_mut().flatten().for_each(|value| *value *= scale);
        Ok(step)
    }

    /// Returns the number of layers being tracked.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Resets every layer's factor statistics (see [`KFACLayer::reset`]).
    pub fn reset(&mut self) {
        for layer in &mut self.layers {
            layer.reset();
        }
    }

    /// Returns `Ok(())` if `layer_idx` addresses an existing layer, else an invalid-parameter error.
    fn check_layer_index(&self, layer_idx: usize) -> Result<()> {
        if layer_idx < self.layers.len() {
            Ok(())
        } else {
            Err(MathError::invalid_parameter(
                "layer_idx",
                "index out of bounds",
            ))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-12;

    /// Two-sample batch of 2-D unit vectors, whose mean outer product is `0.5 · I`.
    fn identity_batch() -> Vec<Vec<f64>> {
        vec![vec![1.0, 0.0], vec![0.0, 1.0]]
    }

    #[test]
    fn test_kfac_layer_creation() {
        let layer = KFACLayer::new(10, 5);
        assert_eq!(layer.a_factor.len(), 10);
        assert_eq!(layer.g_factor.len(), 5);
        assert!(layer.a_factor.iter().all(|row| row.len() == 10));
        assert!(layer.g_factor.iter().all(|row| row.len() == 5));
    }

    #[test]
    fn test_kfac_layer_update() {
        let mut layer = KFACLayer::new(3, 2);
        let activations = vec![vec![1.0, 0.0, 1.0], vec![0.0, 1.0, 1.0]];
        let gradients = vec![vec![0.5, 0.5], vec![0.3, 0.7]];
        layer.update(&activations, &gradients).unwrap();
        assert!(layer.a_factor[0][0] > 0.0);
        assert!(layer.g_factor[0][0] > 0.0);
        // The first update stores the batch mean of the outer products directly.
        assert_eq!(layer.a_factor[0][0], 0.5);
        assert_eq!(layer.a_factor[0][1], 0.0);
        assert_eq!(layer.a_factor[0][2], 0.5);
        assert_eq!(layer.a_factor[2][2], 1.0);
        assert!((layer.g_factor[0][1] - (0.25 + 0.21) / 2.0).abs() < TOL);
    }

    #[test]
    fn test_kfac_layer_update_ema() {
        let mut layer = KFACLayer::new(1, 1).with_ema(0.5);
        layer.update(&[vec![2.0]], &[vec![1.0]]).unwrap();
        assert_eq!(layer.a_factor, vec![vec![4.0]]);
        assert_eq!(layer.g_factor, vec![vec![1.0]]);
        layer.update(&[vec![0.0]], &[vec![3.0]]).unwrap();
        // Subsequent updates blend: 0.5 · old + 0.5 · new.
        assert_eq!(layer.a_factor, vec![vec![2.0]]);
        assert_eq!(layer.g_factor, vec![vec![5.0]]);
    }

    #[test]
    fn test_kfac_layer_update_rejects_bad_input() {
        let mut layer = KFACLayer::new(2, 2);
        assert!(layer.update(&[], &identity_batch()).is_err());
        assert!(layer.update(&identity_batch(), &[vec![1.0, 0.0]]).is_err());
        assert!(layer
            .update(&[vec![1.0], vec![0.0, 1.0]], &identity_batch())
            .is_err());
        assert!(layer
            .update(&identity_batch(), &[vec![1.0, 0.0, 0.0], vec![0.0, 1.0]])
            .is_err());
        // Rejected batches must not touch the layer state.
        assert_eq!(layer.num_updates, 0);
        assert!(layer.a_factor.iter().flatten().all(|&v| v == 0.0));
        assert!(layer.g_factor.iter().flatten().all(|&v| v == 0.0));
    }

    #[test]
    fn test_kfac_natural_gradient() {
        let mut layer = KFACLayer::new(2, 2).with_damping(0.1);
        let activations = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let gradients = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        layer.update(&activations, &gradients).unwrap();
        let weight_grad = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let nat_grad = layer.natural_gradient(&weight_grad).unwrap();
        assert_eq!(nat_grad.len(), 2);
        assert_eq!(nat_grad[0].len(), 2);
        // Both factors are 0.5·I; damping adds 0.1 · trace / n = 0.05, so each damped factor
        // is 0.55·I and the preconditioned gradient is weight_grad / 0.55².
        let scale = 1.0 / (0.55 * 0.55);
        for (row, w_row) in nat_grad.iter().zip(&weight_grad) {
            for (&got, &w) in row.iter().zip(w_row) {
                let expected = w * scale;
                assert!(
                    (got - expected).abs() < TOL,
                    "got {}, expected {}",
                    got,
                    expected
                );
            }
        }
    }

    #[test]
    fn test_kfac_natural_gradient_rejects_wrong_row_count() {
        let layer = KFACLayer::new(2, 2);
        assert!(layer.natural_gradient(&[vec![0.1, 0.2]]).is_err());
    }

    #[test]
    fn test_kfac_layer_reset() {
        let mut layer = KFACLayer::new(2, 2);
        layer.update(&identity_batch(), &identity_batch()).unwrap();
        layer.reset();
        assert_eq!(layer.num_updates, 0);
        assert_eq!(layer.a_factor, vec![vec![0.0; 2]; 2]);
        assert_eq!(layer.g_factor, vec![vec![0.0; 2]; 2]);
    }

    #[test]
    fn test_kfac_full_network() {
        let kfac = KFACApproximation::new(&[(10, 20), (20, 5)])
            .with_learning_rate(0.01)
            .with_damping(0.001);
        assert_eq!(kfac.num_layers(), 2);
    }

    #[test]
    fn test_kfac_layer_index_out_of_bounds() {
        let mut kfac = KFACApproximation::new(&[(2, 2)]);
        assert!(kfac
            .update_layer(1, &identity_batch(), &identity_batch())
            .is_err());
        assert!(kfac.natural_gradient_layer(1, &identity_batch()).is_err());
    }

    #[test]
    fn test_kfac_natural_gradient_layer_scales_by_negative_lr() {
        let mut kfac = KFACApproximation::new(&[(2, 2)]).with_learning_rate(0.5);
        kfac.update_layer(0, &identity_batch(), &identity_batch())
            .unwrap();
        let weight_grad = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let step = kfac.natural_gradient_layer(0, &weight_grad).unwrap();
        let nat_grad = kfac.layers[0].natural_gradient(&weight_grad).unwrap();
        for (step_row, nat_row) in step.iter().zip(&nat_grad) {
            for (&s, &g) in step_row.iter().zip(nat_row) {
                assert_eq!(s, -0.5 * g);
            }
        }
    }
}