use crate::error::AutogradError;
use scirs2_core::ndarray::{Array1, Array2};
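/// Step size for central finite differences; `1e-5` roughly balances
/// truncation error (O(H²)) against floating-point cancellation for `f64`.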
const H: f64 = 1e-5;
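/// Approximates the gradient of `f` at `x` with central differences,
/// perturbing one coordinate at a time and restoring it afterwards.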
fn gradient_fd(f: &dyn Fn(&[f64]) -> f64, x: &[f64]) -> Vec<f64> {
let n = x.len();
let mut g = vec![0.0_f64; n];
let mut xp = x.to_vec();
let mut xm = x.to_vec();
for i in 0..n {
xp[i] = x[i] + H;
xm[i] = x[i] - H;
g[i] = (f(&xp) - f(&xm)) / (2.0 * H);
xp[i] = x[i];
xm[i] = x[i];
}
g
}
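/// Solves `a · x = b` by Gaussian elimination with partial pivoting on a
/// row-major augmented matrix `[a | b]`. Returns an error if `a` is not
/// `n×n` for an `n`-element `b`, or if the largest available pivot falls
/// below `tol` (singular or nearly singular system).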
fn solve_linear_system(
a: &Array2<f64>,
b: &Array1<f64>,
tol: f64,
) -> Result<Array1<f64>, AutogradError> {
let n = b.len();
if a.nrows() != n || a.ncols() != n {
return Err(AutogradError::ShapeMismatch(format!(
"solve_linear_system: expected {}×{} matrix, got {}×{}",
n,
n,
a.nrows(),
a.ncols()
)));
}
let mut aug = vec![0.0_f64; n * (n + 1)];
for i in 0..n {
for j in 0..n {
aug[i * (n + 1) + j] = a[[i, j]];
}
aug[i * (n + 1) + n] = b[i];
}
for col in 0..n {
let mut max_val = aug[col * (n + 1) + col].abs();
let mut max_row = col;
for row in (col + 1)..n {
let v = aug[row * (n + 1) + col].abs();
if v > max_val {
max_val = v;
max_row = row;
}
}
if max_val < tol {
return Err(AutogradError::OperationError(format!(
"solve_linear_system: matrix is singular or nearly singular (pivot={max_val})"
)));
}
if max_row != col {
for k in 0..=n {
aug.swap(col * (n + 1) + k, max_row * (n + 1) + k);
}
}
let pivot = aug[col * (n + 1) + col];
for row in (col + 1)..n {
let factor = aug[row * (n + 1) + col] / pivot;
for k in col..=n {
let delta = aug[col * (n + 1) + k] * factor;
aug[row * (n + 1) + k] -= delta;
}
}
}
let mut x = vec![0.0_f64; n];
for i in (0..n).rev() {
let mut sum = aug[i * (n + 1) + n];
for j in (i + 1)..n {
sum -= aug[i * (n + 1) + j] * x[j];
}
let diag = aug[i * (n + 1) + i];
if diag.abs() < tol {
return Err(AutogradError::OperationError(
"solve_linear_system: zero diagonal during back-substitution".to_string(),
));
}
x[i] = sum / diag;
}
Ok(Array1::from(x))
}
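/// Inverts a square matrix by Gauss-Jordan elimination on the augmented
/// block `[a | I]`; once the left half is reduced to the identity, the
/// right half holds `a⁻¹`. Pivots below `tol` are treated as singular.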
fn invert_matrix(a: &Array2<f64>, tol: f64) -> Result<Array2<f64>, AutogradError> {
let n = a.nrows();
if n != a.ncols() {
return Err(AutogradError::ShapeMismatch(format!(
"invert_matrix: expected square matrix, got {}×{}",
n,
a.ncols()
)));
}
let cols = 2 * n;
let mut aug = vec![0.0_f64; n * cols];
for i in 0..n {
for j in 0..n {
aug[i * cols + j] = a[[i, j]];
}
aug[i * cols + n + i] = 1.0;
}
for col in 0..n {
let mut max_val = aug[col * cols + col].abs();
let mut max_row = col;
for row in (col + 1)..n {
let v = aug[row * cols + col].abs();
if v > max_val {
max_val = v;
max_row = row;
}
}
if max_val < tol {
return Err(AutogradError::OperationError(format!(
"invert_matrix: matrix is singular (pivot={max_val})"
)));
}
if max_row != col {
for k in 0..cols {
aug.swap(col * cols + k, max_row * cols + k);
}
}
let pivot = aug[col * cols + col];
for k in 0..cols {
aug[col * cols + k] /= pivot;
}
for row in 0..n {
if row == col {
continue;
}
let factor = aug[row * cols + col];
for k in 0..cols {
let delta = aug[col * cols + k] * factor;
aug[row * cols + k] -= delta;
}
}
}
let mut inv = Array2::<f64>::zeros((n, n));
for i in 0..n {
for j in 0..n {
inv[[i, j]] = aug[i * cols + n + j];
}
}
Ok(inv)
}
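/// Computes the natural gradient `(F + λI)⁻¹ g`, where `F` is the Fisher
/// information matrix, `g` the ordinary gradient, and `λ` (`damping`) a
/// non-negative Tikhonov regularizer that keeps the system well-conditioned.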
pub fn natural_gradient(
grad: &Array1<f64>,
fisher: &Array2<f64>,
damping: f64,
) -> Result<Array1<f64>, AutogradError> {
let n = grad.len();
if fisher.nrows() != n || fisher.ncols() != n {
return Err(AutogradError::ShapeMismatch(format!(
"natural_gradient: grad has length {n} but fisher is {}×{}",
fisher.nrows(),
fisher.ncols()
)));
}
if damping < 0.0 {
return Err(AutogradError::OperationError(
"natural_gradient: damping must be non-negative".to_string(),
));
}
let mut f_reg = fisher.clone();
for i in 0..n {
f_reg[[i, i]] += damping;
}
solve_linear_system(&f_reg, grad, 1e-12)
}
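/// Builds the empirical Fisher information matrix
/// `F = (1/N) Σₙ ∇log p(xₙ; θ) ∇log p(xₙ; θ)ᵀ`,
/// estimating each per-sample score with finite differences. `model_fn`
/// is assumed to return the per-sample log-likelihood (up to an additive
/// constant), as in the Gaussian-mean test below.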
pub fn fisher_information_matrix(
model_fn: impl Fn(&[f64], &[f64]) -> f64,
params: &[f64],
data: &[Vec<f64>],
) -> Result<Array2<f64>, AutogradError> {
let p = params.len();
if p == 0 {
return Err(AutogradError::OperationError(
"fisher_information_matrix: params must be non-empty".to_string(),
));
}
if data.is_empty() {
return Err(AutogradError::OperationError(
"fisher_information_matrix: data must be non-empty".to_string(),
));
}
let n = data.len() as f64;
let mut fim = Array2::<f64>::zeros((p, p));
for sample in data.iter() {
let grad = {
let f = |theta: &[f64]| model_fn(theta, sample);
gradient_fd(&f, params)
};
for i in 0..p {
for j in 0..p {
fim[[i, j]] += grad[i] * grad[j];
}
}
}
fim.mapv_inplace(|v| v / n);
Ok(fim)
}
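/// Computes the Gauss-Newton approximation to the Hessian, `G = JᵀJ`.
/// `residuals` is used only to validate that the Jacobian has one row per
/// residual; its values do not enter the product.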
pub fn gauss_newton_matrix(
jacobian: &Array2<f64>,
residuals: &Array1<f64>,
) -> Result<Array2<f64>, AutogradError> {
let (m, p) = (jacobian.nrows(), jacobian.ncols());
if residuals.len() != m {
return Err(AutogradError::ShapeMismatch(format!(
"gauss_newton_matrix: jacobian has {} rows but residuals has {} elements",
m,
residuals.len()
)));
}
let mut g = Array2::<f64>::zeros((p, p));
for i in 0..p {
for j in 0..p {
let mut s = 0.0_f64;
for k in 0..m {
s += jacobian[[k, i]] * jacobian[[k, j]];
}
g[[i, j]] = s;
}
}
Ok(g)
}
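/// Applies the K-FAC preconditioner layer by layer: for a gradient block
/// `ΔW` with Kronecker factors `A ≈ E[aaᵀ]` (layer inputs) and
/// `G ≈ E[ggᵀ]` (output gradients), the update is `G⁻¹ · ΔW · A⁻¹`,
/// which approximates `F⁻¹ vec(ΔW)` without forming the full Fisher matrix.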
pub fn kfac_update(
grads: &[Array2<f64>],
a_inv: &[Array2<f64>],
g_inv: &[Array2<f64>],
) -> Result<Vec<Array2<f64>>, AutogradError> {
let num_layers = grads.len();
if a_inv.len() != num_layers {
return Err(AutogradError::ShapeMismatch(format!(
"kfac_update: grads has {} layers but a_inv has {}",
num_layers,
a_inv.len()
)));
}
if g_inv.len() != num_layers {
return Err(AutogradError::ShapeMismatch(format!(
"kfac_update: grads has {} layers but g_inv has {}",
num_layers,
g_inv.len()
)));
}
let mut result = Vec::with_capacity(num_layers);
for l in 0..num_layers {
let delta = &grads[l];
let ai = &a_inv[l];
let gi = &g_inv[l];
let (p, q) = (delta.nrows(), delta.ncols());
if gi.nrows() != p || gi.ncols() != p {
return Err(AutogradError::ShapeMismatch(format!(
"kfac_update: layer {l}: gradient is {p}×{q} but g_inv is {}×{}",
gi.nrows(),
gi.ncols()
)));
}
if ai.nrows() != q || ai.ncols() != q {
return Err(AutogradError::ShapeMismatch(format!(
"kfac_update: layer {l}: gradient is {p}×{q} but a_inv is {}×{}",
ai.nrows(),
ai.ncols()
)));
}
let mut tmp = Array2::<f64>::zeros((p, q));
for i in 0..p {
for j in 0..q {
let mut s = 0.0_f64;
for k in 0..p {
s += gi[[i, k]] * delta[[k, j]];
}
tmp[[i, j]] = s;
}
}
let mut precond = Array2::<f64>::zeros((p, q));
for i in 0..p {
for j in 0..q {
let mut s = 0.0_f64;
for k in 0..q {
s += tmp[[i, k]] * ai[[k, j]];
}
precond[[i, j]] = s;
}
}
result.push(precond);
}
Ok(result)
}
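/// Estimates the per-layer Kronecker factors from a batch:
/// `A = (1/n) XᵀX + λI` over layer inputs and `G = (1/n) DᵀD + λI` over
/// output gradients, returning their inverses `(A⁻¹, G⁻¹)` for use in
/// `kfac_update`.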
pub fn kfac_factors(
layer_inputs: &Array2<f64>,
layer_grads: &Array2<f64>,
damping: f64,
) -> Result<(Array2<f64>, Array2<f64>), AutogradError> {
let (n, d_in) = (layer_inputs.nrows(), layer_inputs.ncols());
let (n2, d_out) = (layer_grads.nrows(), layer_grads.ncols());
if n != n2 {
return Err(AutogradError::ShapeMismatch(format!(
"kfac_factors: layer_inputs has {n} rows but layer_grads has {n2} rows"
)));
}
if n == 0 {
return Err(AutogradError::OperationError(
"kfac_factors: batch size must be > 0".to_string(),
));
}
let nf = n as f64;
let mut a_cov = Array2::<f64>::zeros((d_in, d_in));
for k in 0..n {
for i in 0..d_in {
for j in 0..d_in {
a_cov[[i, j]] += layer_inputs[[k, i]] * layer_inputs[[k, j]];
}
}
}
a_cov.mapv_inplace(|v| v / nf);
let mut g_cov = Array2::<f64>::zeros((d_out, d_out));
for k in 0..n {
for i in 0..d_out {
for j in 0..d_out {
g_cov[[i, j]] += layer_grads[[k, i]] * layer_grads[[k, j]];
}
}
}
g_cov.mapv_inplace(|v| v / nf);
for i in 0..d_in {
a_cov[[i, i]] += damping;
}
for i in 0..d_out {
g_cov[[i, i]] += damping;
}
let a_inv = invert_matrix(&a_cov, 1e-12)?;
let g_inv = invert_matrix(&g_cov, 1e-12)?;
Ok((a_inv, g_inv))
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::{arr1, arr2, Array2};
const TOL: f64 = 1e-4;
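// A minimal sanity check for the finite-difference helper: central
// differences are exact (up to round-off) for the quadratic
// f(x) = x₀² + 3·x₁, whose gradient is (2·x₀, 3).
#[test]
fn test_gradient_fd_quadratic() {
let f = |x: &[f64]| x[0] * x[0] + 3.0 * x[1];
let g = gradient_fd(&f, &[2.0, -1.0]);
assert!((g[0] - 4.0).abs() < TOL, "g[0]={}", g[0]);
assert!((g[1] - 3.0).abs() < TOL, "g[1]={}", g[1]);
}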
#[test]
fn test_solve_linear_system_identity() {
let a = Array2::<f64>::eye(3);
let b = arr1(&[1.0_f64, 2.0, 3.0]);
let x = solve_linear_system(&a, &b, 1e-12).expect("solve identity");
assert!((x[0] - 1.0).abs() < 1e-10);
assert!((x[1] - 2.0).abs() < 1e-10);
assert!((x[2] - 3.0).abs() < 1e-10);
}
#[test]
fn test_solve_linear_system_2x2() {
let a = arr2(&[[2.0_f64, 1.0], [1.0, 3.0]]);
let b = arr1(&[5.0_f64, 10.0]);
let x = solve_linear_system(&a, &b, 1e-12).expect("solve 2x2");
assert!((x[0] - 1.0).abs() < TOL, "x[0]={}", x[0]);
assert!((x[1] - 3.0).abs() < TOL, "x[1]={}", x[1]);
}
#[test]
fn test_solve_linear_system_singular_err() {
let a = arr2(&[[1.0_f64, 2.0], [2.0, 4.0]]);
let b = arr1(&[1.0_f64, 1.0]);
let r = solve_linear_system(&a, &b, 1e-8);
assert!(r.is_err(), "Singular matrix should return error");
}
#[test]
fn test_invert_matrix_identity() {
let a = Array2::<f64>::eye(3);
let inv = invert_matrix(&a, 1e-12).expect("invert identity");
for i in 0..3 {
for j in 0..3 {
let expected = if i == j { 1.0 } else { 0.0 };
assert!(
(inv[[i, j]] - expected).abs() < TOL,
"inv[{i},{j}]={}", inv[[i, j]]
);
}
}
}
#[test]
fn test_invert_matrix_diagonal() {
let a = arr2(&[[2.0_f64, 0.0], [0.0, 4.0]]);
let inv = invert_matrix(&a, 1e-12).expect("invert diagonal");
assert!((inv[[0, 0]] - 0.5).abs() < TOL);
assert!((inv[[1, 1]] - 0.25).abs() < TOL);
assert!(inv[[0, 1]].abs() < TOL);
}
#[test]
fn test_natural_gradient_identity_fisher() {
let g = arr1(&[1.0_f64, 2.0, 3.0]);
let f = Array2::<f64>::eye(3);
let ng = natural_gradient(&g, &f, 0.0).expect("natural gradient identity");
assert!((ng[0] - 1.0).abs() < TOL);
assert!((ng[1] - 2.0).abs() < TOL);
assert!((ng[2] - 3.0).abs() < TOL);
}
#[test]
fn test_natural_gradient_diagonal_fisher() {
let g = arr1(&[2.0_f64, 4.0]);
let f = arr2(&[[2.0_f64, 0.0], [0.0, 4.0]]);
let ng = natural_gradient(&g, &f, 1e-10).expect("natural gradient diagonal");
assert!((ng[0] - 1.0).abs() < TOL, "ng[0]={}", ng[0]);
assert!((ng[1] - 1.0).abs() < TOL, "ng[1]={}", ng[1]);
}
#[test]
fn test_natural_gradient_damping() {
let g = arr1(&[1.0_f64]);
let f = Array2::<f64>::zeros((1, 1));
let ng = natural_gradient(&g, &f, 2.0).expect("natural gradient damping");
assert!((ng[0] - 0.5).abs() < TOL, "ng[0]={}", ng[0]);
}
#[test]
fn test_natural_gradient_shape_error() {
let g = arr1(&[1.0_f64, 2.0]);
let f = Array2::<f64>::eye(3);
let r = natural_gradient(&g, &f, 0.0);
assert!(r.is_err());
}
#[test]
fn test_fisher_information_gaussian_mean() {
let data = vec![
vec![0.0_f64],
vec![1.0],
vec![-1.0],
vec![2.0],
vec![-2.0],
];
let params = vec![0.0_f64];
let fim = fisher_information_matrix(
|theta: &[f64], x: &[f64]| -0.5 * (x[0] - theta[0]).powi(2),
&params,
&data,
)
.expect("FIM gaussian");
// Analytic value: E[x²] = (0 + 1 + 1 + 4 + 4) / 5 = 2.0 for this data set,
// and central differences recover the score exactly for a quadratic.
assert!((fim[[0, 0]] - 2.0).abs() < 1e-3, "FIM should be ~2.0: {}", fim[[0, 0]]);
}
#[test]
fn test_fisher_information_empty_params_err() {
let data = vec![vec![1.0_f64]];
let r = fisher_information_matrix(|_, _| 0.0, &[], &data);
assert!(r.is_err());
}
#[test]
fn test_fisher_information_empty_data_err() {
let params = vec![1.0_f64];
let r = fisher_information_matrix(|_, _| 0.0, &params, &[]);
assert!(r.is_err());
}
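// An end-to-end sketch tying the pieces together: estimate the empirical
// Fisher for the Gaussian-mean model, then precondition a gradient with
// it. For this data set E[x²] = 2.5, so the natural gradient of g = 1.0
// should be close to 1/2.5 = 0.4.
#[test]
fn test_natural_gradient_with_empirical_fisher() {
let data = vec![vec![1.0_f64], vec![-1.0], vec![2.0], vec![-2.0]];
let params = vec![0.0_f64];
let fim = fisher_information_matrix(
|theta: &[f64], x: &[f64]| -0.5 * (x[0] - theta[0]).powi(2),
&params,
&data,
)
.expect("FIM");
let g = arr1(&[1.0_f64]);
let ng = natural_gradient(&g, &fim, 1e-10).expect("natural gradient");
assert!((ng[0] - 0.4).abs() < TOL, "ng[0]={}", ng[0]);
}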
#[test]
fn test_gauss_newton_identity_jacobian() {
let j = Array2::<f64>::eye(2);
let r = arr1(&[1.0_f64, 1.0]);
let gn = gauss_newton_matrix(&j, &r).expect("GN identity");
assert!((gn[[0, 0]] - 1.0).abs() < 1e-10);
assert!((gn[[1, 1]] - 1.0).abs() < 1e-10);
assert!(gn[[0, 1]].abs() < 1e-10);
}
#[test]
fn test_gauss_newton_diagonal_jacobian() {
let j = arr2(&[[1.0_f64, 0.0], [0.0, 2.0]]);
let r = arr1(&[1.0_f64, 2.0]);
let gn = gauss_newton_matrix(&j, &r).expect("GN diagonal");
assert!((gn[[0, 0]] - 1.0).abs() < 1e-10, "G[0,0]={}", gn[[0, 0]]);
assert!((gn[[1, 1]] - 4.0).abs() < 1e-10, "G[1,1]={}", gn[[1, 1]]);
assert!(gn[[0, 1]].abs() < 1e-10);
}
#[test]
fn test_gauss_newton_shape_error() {
let j = arr2(&[[1.0_f64, 0.0], [0.0, 1.0]]);
let r = arr1(&[1.0_f64]);
let res = gauss_newton_matrix(&j, &r);
assert!(res.is_err());
}
#[test]
fn test_gauss_newton_rectangular_jacobian() {
let j = arr2(&[[1.0_f64, 0.0], [0.0, 1.0], [1.0, 1.0]]);
let r = arr1(&[1.0_f64, 1.0, 1.0]);
let gn = gauss_newton_matrix(&j, &r).expect("GN rectangular");
assert!((gn[[0, 0]] - 2.0).abs() < 1e-10, "G[0,0]={}", gn[[0, 0]]);
assert!((gn[[0, 1]] - 1.0).abs() < 1e-10, "G[0,1]={}", gn[[0, 1]]);
assert!((gn[[1, 0]] - 1.0).abs() < 1e-10, "G[1,0]={}", gn[[1, 0]]);
assert!((gn[[1, 1]] - 2.0).abs() < 1e-10, "G[1,1]={}", gn[[1, 1]]);
}
#[test]
fn test_kfac_update_identity_factors() {
let delta = arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]);
let ai = Array2::<f64>::eye(2);
let gi = Array2::<f64>::eye(2);
let result = kfac_update(&[delta.clone()], &[ai], &[gi]).expect("K-FAC identity");
assert_eq!(result.len(), 1);
let p = &result[0];
assert!((p[[0, 0]] - 1.0).abs() < 1e-10);
assert!((p[[0, 1]] - 2.0).abs() < 1e-10);
assert!((p[[1, 0]] - 3.0).abs() < 1e-10);
assert!((p[[1, 1]] - 4.0).abs() < 1e-10);
}
#[test]
fn test_kfac_update_scaling_factors() {
let delta = arr2(&[[1.0_f64, 0.0], [0.0, 1.0]]);
let ai = arr2(&[[3.0_f64, 0.0], [0.0, 3.0]]);
let gi = arr2(&[[2.0_f64, 0.0], [0.0, 2.0]]);
let result = kfac_update(&[delta], &[ai], &[gi]).expect("K-FAC scaling");
let p = &result[0];
assert!((p[[0, 0]] - 6.0).abs() < 1e-10, "p[0,0]={}", p[[0, 0]]);
assert!((p[[1, 1]] - 6.0).abs() < 1e-10, "p[1,1]={}", p[[1, 1]]);
}
#[test]
fn test_kfac_update_multi_layer() {
let d1 = Array2::<f64>::eye(2);
let d2 = arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]);
let ai1 = Array2::<f64>::eye(2);
let gi1 = Array2::<f64>::eye(2);
let ai2 = Array2::<f64>::eye(2);
let gi2 = Array2::<f64>::eye(2);
let result = kfac_update(&[d1, d2], &[ai1, ai2], &[gi1, gi2]).expect("multi-layer");
assert_eq!(result.len(), 2);
assert!((result[0][[0, 0]] - 1.0).abs() < 1e-10);
assert!((result[1][[0, 1]] - 2.0).abs() < 1e-10);
}
#[test]
fn test_kfac_update_length_mismatch_err() {
let d = Array2::<f64>::eye(2);
let ai = Array2::<f64>::eye(2);
let gi = Array2::<f64>::eye(2);
let r = kfac_update(&[d], &[ai.clone(), ai], &[gi]);
assert!(r.is_err());
}
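// A follow-up to the length-mismatch case above: factor shapes are also
// validated per layer, so a 3×3 g_inv paired with a 2×2 gradient block
// should be rejected.
#[test]
fn test_kfac_update_factor_shape_err() {
let d = Array2::<f64>::eye(2);
let ai = Array2::<f64>::eye(2);
let gi = Array2::<f64>::eye(3);
let r = kfac_update(&[d], &[ai], &[gi]);
assert!(r.is_err());
}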
#[test]
fn test_kfac_factors_identity_inputs() {
let n = 3usize;
let inputs = Array2::<f64>::eye(n);
let grads_m = Array2::<f64>::eye(n);
let damping = 1e-4;
let (ai, gi) = kfac_factors(&inputs, &grads_m, damping).expect("kfac factors");
let expected = 1.0 / (1.0 / (n as f64) + damping);
for i in 0..n {
assert!(
(ai[[i, i]] - expected).abs() < 0.01 * expected,
"ai[{i},{i}]={} expected~{expected}", ai[[i, i]]
);
assert!(
(gi[[i, i]] - expected).abs() < 0.01 * expected,
"gi[{i},{i}]={} expected~{expected}", gi[[i, i]]
);
}
}
#[test]
fn test_kfac_factors_shape_error() {
let inputs = Array2::<f64>::eye(3);
let grads_m = Array2::<f64>::eye(4);
let r = kfac_factors(&inputs, &grads_m, 1e-4);
assert!(r.is_err());
}
}