use crate::decomposition::cholesky;
use crate::error::{LinalgError, LinalgResult};
use crate::norm::matrix_norm;
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;

/// Converts an `ndarray` shape error into this crate's [`LinalgError`].
#[allow(dead_code)]
fn shape_err_to_linalg(err: scirs2_core::ndarray::ShapeError) -> LinalgError {
    LinalgError::ShapeError(err.to_string())
}
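/// Computes the Kronecker product `A ⊗ B` of an `(m, n)` matrix and a
/// `(p, q)` matrix, producing the `(m * p, n * q)` matrix with
/// `result[i * p + k, j * q + l] = a[i, j] * b[k, l]`.
///
/// The dense quadruple loop costs O(m * n * p * q) time and memory; prefer
/// [`kron_matvec`] or [`kron_matmul`] when only products with `A ⊗ B` are
/// needed.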
#[allow(dead_code)]
pub fn kron<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
let (m, n) = a.dim();
let (p, q) = b.dim();
let mut result = Array2::zeros((m * p, n * q));
for i in 0..m {
for j in 0..n {
for k in 0..p {
for l in 0..q {
result[[i * p + k, j * q + l]] = a[[i, j]] * b[[k, l]];
}
}
}
}
Ok(result)
}
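/// Applies `A ⊗ B` to a vector without materializing the Kronecker product,
/// using the row-major vectorization identity `(A ⊗ B) vec(X) = vec(A X Bᵀ)`,
/// where `X` is `x` reshaped to `(n, q)`.
///
/// This costs O(n * q * p + m * n * p) time instead of the O(m * n * p * q)
/// required to form `A ⊗ B` explicitly.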
#[allow(dead_code)]
pub fn kron_matvec<F>(
a: &ArrayView2<F>,
b: &ArrayView2<F>,
x: &scirs2_core::ndarray::ArrayView1<F>,
) -> LinalgResult<scirs2_core::ndarray::Array1<F>>
where
F: Float + NumAssign + Sum + ScalarOperand + Send + Sync,
{
let (m, n) = a.dim();
let (p, q) = b.dim();
if x.len() != n * q {
return Err(LinalgError::ShapeError(format!(
"Vector length ({}) must equal n*q ({}*{}={})",
x.len(),
n,
q,
n * q
)));
}
    // Reshape x (length n * q, row-major) into an (n, q) matrix X, then apply
    // the identity (A ⊗ B) vec(X) = vec(A X Bᵀ) for row-major vectorization.
    let x_mat = x
        .to_owned()
        .into_shape_with_order((n, q))
        .map_err(shape_err_to_linalg)?;
    let tmp = x_mat.dot(&b.t()); // X Bᵀ: shape (n, p)
    let result = a.dot(&tmp); // A X Bᵀ: shape (m, p)
    let result_vec = result
        .into_shape_with_order(m * p)
        .map_err(shape_err_to_linalg)?;
    Ok(result_vec)
}
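/// Applies `A ⊗ B` to each column of an `(n * q, r)` matrix by delegating to
/// [`kron_matvec`] column by column, returning an `(m * p, r)` result.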
#[allow(dead_code)]
pub fn kron_matmul<F>(
a: &ArrayView2<F>,
b: &ArrayView2<F>,
x: &ArrayView2<F>,
) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + ScalarOperand + Send + Sync,
{
let (m, n) = a.dim();
let (p, q) = b.dim();
let (nx, r) = x.dim();
if nx != n * q {
return Err(LinalgError::ShapeError(format!(
"First dimension of X ({}) must equal n*q ({}*{}={})",
nx,
n,
q,
n * q
)));
}
let mut result = Array2::zeros((m * p, r));
for col in 0..r {
let x_col = x.slice(scirs2_core::ndarray::s![.., col]);
let y_col = kron_matvec(a, b, &x_col)?;
for i in 0..m * p {
result[[i, col]] = y_col[i];
}
}
Ok(result)
}
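/// Approximately factorizes `M ≈ A ⊗ B`, with `A` of shape `(m_rows, n_cols)`
/// and `B` of shape `(total_rows / m_rows, total_cols / n_cols)`.
///
/// This is a cheap averaging/projection heuristic: `A` is proportional to the
/// per-block means of `M`, and `B` is the projection of the blocks onto the
/// normalized `A`. The Frobenius-optimal solution is Van Loan's SVD of the
/// rearranged matrix; this routine trades that accuracy for a single pass
/// over the elements of `M`.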
#[allow(dead_code)]
pub fn kron_factorize<F>(
m: &ArrayView2<F>,
m_rows: usize,
n_cols: usize,
) -> LinalgResult<(Array2<F>, Array2<F>)>
where
F: Float + NumAssign + Sum + ScalarOperand + Send + Sync,
{
let (total_rows, total_cols) = m.dim();
let p_rows = total_rows / m_rows;
let q_cols = total_cols / n_cols;
if m_rows * p_rows != total_rows || n_cols * q_cols != total_cols {
return Err(LinalgError::ShapeError(format!(
"Matrix of shape ({total_rows}, {total_cols}) cannot be factorized into ({m_rows}, {n_cols}) and ({p_rows}, {q_cols})"
)));
}
    // View M as the 4-tensor T[i, k, j, l] = M[i * p_rows + k, j * q_cols + l],
    // i.e. index (i, j) selects a (p_rows, q_cols) block.
    let m_tensor = (*m)
        .into_shape_with_order((m_rows, p_rows, n_cols, q_cols))
        .map_err(|_| {
            LinalgError::ShapeError(
                "Failed to reshape matrix for Kronecker factorization".to_string(),
            )
        })?;
let mut a = Array2::zeros((m_rows, n_cols));
let mut b = Array2::zeros((p_rows, q_cols));
for i in 0..m_rows {
for j in 0..n_cols {
let mut sum = F::zero();
let mut count = F::zero();
for k in 0..p_rows {
for l in 0..q_cols {
sum += m_tensor[[i, k, j, l]];
count += F::one();
}
}
a[[i, j]] = sum / count;
}
}
let a_norm = a.iter().map(|&x| x * x).sum::<F>().sqrt();
if a_norm > F::epsilon() {
for i in 0..m_rows {
for j in 0..n_cols {
a[[i, j]] /= a_norm;
}
}
}
for k in 0..p_rows {
for l in 0..q_cols {
let mut sum = F::zero();
for i in 0..m_rows {
for j in 0..n_cols {
sum += m_tensor[[i, k, j, l]] * a[[i, j]];
}
}
b[[k, l]] = sum;
}
}
    // Split the magnitude between the two factors: `a` was normalized above
    // (when its norm was non-negligible), so scaling `a` up by sqrt(a_norm)
    // and `b` down by the same factor leaves kron(a, b) unchanged while
    // balancing the factor magnitudes.
    let scaling_factor = a_norm;
for i in 0..m_rows {
for j in 0..n_cols {
a[[i, j]] *= scaling_factor.sqrt();
}
}
for k in 0..p_rows {
for l in 0..q_cols {
b[[k, l]] /= scaling_factor.sqrt();
}
}
Ok((a, b))
}
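/// Builds the two K-FAC covariance factors for a fully connected layer:
/// `A = E[â âᵀ]` over the batch, where `â` is the input activation with a
/// trailing bias 1, and `S = E[g gᵀ]` over the output gradients. The layer's
/// Fisher block is then approximated as `A ⊗ S` (Martens & Grosse, 2015).
/// Tikhonov damping is added to both diagonals.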
#[allow(dead_code)]
pub fn kfac_factorization<F>(
input_acts: &ArrayView2<F>,
output_grads: &ArrayView2<F>,
damping: Option<F>,
) -> LinalgResult<(Array2<F>, Array2<F>)>
where
F: Float + NumAssign + Sum + ScalarOperand + Send + Sync,
{
let (batchsize1, input_dim) = input_acts.dim();
let (batchsize2, output_dim) = output_grads.dim();
if batchsize1 != batchsize2 {
return Err(LinalgError::ShapeError(format!(
"Batch sizes must match: {batchsize1} vs {batchsize2}"
)));
}
let batchsize = batchsize1;
let damping_factor =
damping.unwrap_or_else(|| F::from(1e-4).expect("Failed to convert constant to float"));
let mut input_acts_with_bias = Array2::zeros((batchsize, input_dim + 1));
for i in 0..batchsize {
for j in 0..input_dim {
input_acts_with_bias[[i, j]] = input_acts[[i, j]];
}
input_acts_with_bias[[i, input_dim]] = F::one();
}
let mut a_cov = Array2::zeros((input_dim + 1, input_dim + 1));
for i in 0..(input_dim + 1) {
for j in 0..(input_dim + 1) {
let mut sum = F::zero();
for b in 0..batchsize {
sum += input_acts_with_bias[[b, i]] * input_acts_with_bias[[b, j]];
}
a_cov[[i, j]] = sum / F::from(batchsize).expect("Failed to convert to float");
}
}
for i in 0..(input_dim + 1) {
a_cov[[i, i]] += damping_factor;
}
let mut s_cov = Array2::zeros((output_dim, output_dim));
for i in 0..output_dim {
for j in 0..output_dim {
let mut sum = F::zero();
for b in 0..batchsize {
sum += output_grads[[b, i]] * output_grads[[b, j]];
}
s_cov[[i, j]] = sum / F::from(batchsize).expect("Failed to convert to float");
}
}
for i in 0..output_dim {
s_cov[[i, i]] += damping_factor;
}
Ok((a_cov, s_cov))
}
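/// Applies one natural-gradient step with precomputed inverse factors: since
/// `(A ⊗ S)⁻¹ vec(G) = vec(A⁻¹ G S⁻¹)`, the update is
/// `W ← W - learning_rate * A⁻¹ G S⁻¹` without ever forming the full Fisher
/// matrix. Note that `a_inv` must match the row dimension of `weights`;
/// factors from [`kfac_factorization`] carry an extra bias row, so the
/// weights are then expected to include a bias row as well.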
#[allow(dead_code)]
pub fn kfac_update<F>(
weights: &ArrayView2<F>,
gradients: &ArrayView2<F>,
a_inv: &ArrayView2<F>,
s_inv: &ArrayView2<F>,
learning_rate: F,
) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + ScalarOperand + Send + Sync,
{
let (input_dim, output_dim) = weights.dim();
let (grad_rows, grad_cols) = gradients.dim();
if input_dim != grad_rows || output_dim != grad_cols {
return Err(LinalgError::ShapeError(format!(
"Weights ({input_dim}, {output_dim}) and gradients ({grad_rows}, {grad_cols}) must have the same shape"
)));
}
if a_inv.dim().0 != input_dim || a_inv.dim().1 != input_dim {
return Err(LinalgError::ShapeError(format!(
"A inverse shape ({}, {}) must match input dimension {}",
a_inv.dim().0,
a_inv.dim().1,
input_dim
)));
}
if s_inv.dim().0 != output_dim || s_inv.dim().1 != output_dim {
return Err(LinalgError::ShapeError(format!(
"S inverse shape ({}, {}) must match output dimension {}",
s_inv.dim().0,
s_inv.dim().1,
output_dim
)));
}
let gradients_owned = gradients.to_owned();
let s_inv_owned = s_inv.to_owned();
let tmp = a_inv.to_owned().dot(&gradients_owned);
let natural_grad = tmp.dot(&s_inv_owned);
let mut new_weights = weights.to_owned();
for i in 0..input_dim {
for j in 0..output_dim {
new_weights[[i, j]] = weights[[i, j]] - learning_rate * natural_grad[[i, j]];
}
}
Ok(new_weights)
}
#[derive(Debug)]
pub struct KFACOptimizer<F> {
pub decay_factor: F,
pub base_damping: F,
pub adaptive_damping: F,
pub min_damping: F,
pub max_damping: F,
pub step_count: usize,
pub input_cov_avg: Option<Array2<F>>,
pub output_cov_avg: Option<Array2<F>>,
pub input_trace: Option<F>,
pub output_trace: Option<F>,
}
impl<F> KFACOptimizer<F>
where
F: Float + NumAssign + Sum + ScalarOperand + Send + Sync,
{
    pub fn new(decay_factor: Option<F>, base_damping: Option<F>) -> Self {
        let decay = decay_factor
            .unwrap_or_else(|| F::from(0.95).expect("Failed to convert constant to float"));
        let damping = base_damping
            .unwrap_or_else(|| F::from(1e-4).expect("Failed to convert constant to float"));
Self {
decay_factor: decay,
base_damping: damping,
adaptive_damping: damping,
min_damping: damping / F::from(10.0).expect("Failed to convert constant to float"),
max_damping: damping * F::from(100.0).expect("Failed to convert constant to float"),
step_count: 0,
input_cov_avg: None,
output_cov_avg: None,
input_trace: None,
output_trace: None,
}
}
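    /// Folds a fresh batch of statistics into the running covariance
    /// averages, applies Adam-style bias correction, and adds the current
    /// adaptive damping to both diagonals before returning the factors.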
pub fn update_covariances(
&mut self,
input_acts: &ArrayView2<F>,
output_grads: &ArrayView2<F>,
) -> LinalgResult<(Array2<F>, Array2<F>)> {
let (current_input_cov, current_output_cov) =
kfac_factorization(input_acts, output_grads, Some(F::zero()))?;
match (&mut self.input_cov_avg, &mut self.output_cov_avg) {
(Some(ref mut input_avg), Some(ref mut output_avg)) => {
for i in 0..input_avg.nrows() {
for j in 0..input_avg.ncols() {
input_avg[[i, j]] = self.decay_factor * input_avg[[i, j]]
+ (F::one() - self.decay_factor) * current_input_cov[[i, j]];
}
}
for i in 0..output_avg.nrows() {
for j in 0..output_avg.ncols() {
output_avg[[i, j]] = self.decay_factor * output_avg[[i, j]]
+ (F::one() - self.decay_factor) * current_output_cov[[i, j]];
}
}
}
            _ => {
                // Initialize the running averages as if starting from zero so
                // the Adam-style bias correction below is exact: store
                // (1 - decay) * current here; dividing by 1 - decay^(t+1)
                // then yields an unbiased estimate at every step.
                let w = F::one() - self.decay_factor;
                self.input_cov_avg = Some(current_input_cov.mapv(|v| v * w));
                self.output_cov_avg = Some(current_output_cov.mapv(|v| v * w));
            }
}
let bias_correction = F::one() - self.decay_factor.powi(self.step_count as i32 + 1);
        let mut corrected_input = self
            .input_cov_avg
            .as_ref()
            .expect("input covariance average was just initialized")
            .clone();
        let mut corrected_output = self
            .output_cov_avg
            .as_ref()
            .expect("output covariance average was just initialized")
            .clone();
for i in 0..corrected_input.nrows() {
for j in 0..corrected_input.ncols() {
corrected_input[[i, j]] /= bias_correction;
}
}
for i in 0..corrected_output.nrows() {
for j in 0..corrected_output.ncols() {
corrected_output[[i, j]] /= bias_correction;
}
}
for i in 0..corrected_input.nrows() {
corrected_input[[i, i]] += self.adaptive_damping;
}
for i in 0..corrected_output.nrows() {
corrected_output[[i, i]] += self.adaptive_damping;
}
let input_trace = (0..corrected_input.nrows())
.map(|i| corrected_input[[i, i]])
.sum::<F>();
let output_trace = (0..corrected_output.nrows())
.map(|i| corrected_output[[i, i]])
.sum::<F>();
self.input_trace = Some(input_trace);
self.output_trace = Some(output_trace);
self.step_count += 1;
Ok((corrected_input, corrected_output))
}
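    /// Adjusts damping in a Levenberg-Marquardt fashion: shrink it (down to
    /// `min_damping`) when the loss improved, grow it (up to `max_damping`)
    /// otherwise.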
    pub fn adjust_damping(&mut self, loss_improved: bool, improvement_ratio: Option<F>) {
        if loss_improved {
            if let Some(ratio) = improvement_ratio {
                if ratio > F::from(0.75).expect("Failed to convert constant to float") {
                    self.adaptive_damping = (self.adaptive_damping
                        / F::from(3.0).expect("Failed to convert constant to float"))
                    .max(self.min_damping);
                } else if ratio > F::from(0.25).expect("Failed to convert constant to float") {
                    self.adaptive_damping = (self.adaptive_damping
                        / F::from(2.0).expect("Failed to convert constant to float"))
                    .max(self.min_damping);
                }
            } else {
                self.adaptive_damping = (self.adaptive_damping
                    / F::from(1.5).expect("Failed to convert constant to float"))
                .max(self.min_damping);
            }
        } else {
            self.adaptive_damping = (self.adaptive_damping
                * F::from(2.0).expect("Failed to convert constant to float"))
            .min(self.max_damping);
        }
    }
pub fn get_damping(&self) -> F {
self.adaptive_damping
}
pub fn reset(&mut self) {
self.step_count = 0;
self.input_cov_avg = None;
self.output_cov_avg = None;
self.input_trace = None;
self.output_trace = None;
self.adaptive_damping = self.base_damping;
}
}
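/// Block-diagonal Fisher approximation: one Kronecker-factored block
/// `A_l ⊗ S_l` per layer, stored together with its (approximate) inverse
/// factors. Memory scales with the factor sizes rather than with each
/// layer's full `(in * out)²` Fisher block.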
#[derive(Debug)]
pub struct BlockDiagonalFisher<F> {
pub layer_factors: Vec<(Array2<F>, Array2<F>)>,
pub inverse_factors: Vec<(Array2<F>, Array2<F>)>,
pub layer_dims: Vec<(usize, usize)>,
pub damping: F,
}
impl<F> BlockDiagonalFisher<F>
where
F: Float + NumAssign + Sum + ScalarOperand + Send + Sync,
{
pub fn new(layer_dims: Vec<(usize, usize)>, damping: F) -> Self {
Self {
layer_factors: Vec::new(),
inverse_factors: Vec::new(),
layer_dims,
damping,
}
}
pub fn update_fisher(
&mut self,
layer_activations: &[ArrayView2<F>],
layer_gradients: &[ArrayView2<F>],
) -> LinalgResult<()> {
if layer_activations.len() != layer_gradients.len()
|| layer_activations.len() != self.layer_dims.len()
{
return Err(LinalgError::ShapeError(
"Mismatched number of layers".to_string(),
));
}
self.layer_factors.clear();
self.inverse_factors.clear();
for (i, (&(input_dim, output_dim), (acts, grads))) in self
.layer_dims
.iter()
.zip(layer_activations.iter().zip(layer_gradients.iter()))
.enumerate()
{
if acts.ncols() != input_dim || grads.ncols() != output_dim {
return Err(LinalgError::ShapeError(format!(
"Layer {} dimension mismatch: expected ({}, {}), got ({}, {})",
i,
input_dim,
output_dim,
acts.ncols(),
grads.ncols()
)));
}
let (input_cov, output_cov) = kfac_factorization(acts, grads, Some(self.damping))?;
let input_inv = self.stable_inverse(&input_cov.view())?;
let output_inv = self.stable_inverse(&output_cov.view())?;
self.layer_factors.push((input_cov, output_cov));
self.inverse_factors.push((input_inv, output_inv));
}
Ok(())
}
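    /// Inverts a symmetric positive-definite factor, preferring a Cholesky
    /// solve and falling back to a damped diagonal approximation when the
    /// factorization fails.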
fn stable_inverse(&self, matrix: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
let n = matrix.nrows();
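        // Solve L Lᵀ x = e_col for each unit vector to build the inverse column by column.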
if let Ok(l) = cholesky(matrix, None) {
let mut inv = Array2::eye(n);
for col in 0..n {
let mut b = Array1::zeros(n);
b[col] = F::one();
let mut y = Array1::zeros(n);
for i in 0..n {
let mut sum = F::zero();
for j in 0..i {
sum += l[[i, j]] * y[j];
}
y[i] = (b[i] - sum) / l[[i, i]];
}
let mut x = Array1::zeros(n);
for i in (0..n).rev() {
let mut sum = F::zero();
for j in (i + 1)..n {
sum += l[[j, i]] * x[j];
}
x[i] = (y[i] - sum) / l[[i, i]];
}
for i in 0..n {
inv[[i, col]] = x[i];
}
}
return Ok(inv);
}
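        // Cholesky failed: fall back to a heavily damped diagonal approximation.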
let mut regularized = matrix.to_owned();
for i in 0..n {
regularized[[i, i]] +=
self.damping * F::from(10.0).expect("Failed to convert constant to float");
}
let mut inv = Array2::eye(n);
for i in 0..n {
inv[[i, i]] = F::one() / (regularized[[i, i]] + self.damping);
}
Ok(inv)
}
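    /// Preconditions per-layer gradients with the stored inverse factors.
    /// Activations are not available at this point, so the per-layer weight
    /// gradient is approximated from batch-averaged output gradients (see the
    /// comments in the body); the result keeps each gradient's batch shape.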
pub fn precondition_gradients(
&self,
layer_gradients: &[ArrayView2<F>],
) -> LinalgResult<Vec<Array2<F>>> {
if layer_gradients.len() != self.inverse_factors.len() {
return Err(LinalgError::ShapeError(
"Number of gradient matrices must match number of layers".to_string(),
));
}
let mut preconditioned = Vec::new();
for ((grads, (input_inv, output_inv)), &(input_dim, output_dim)) in layer_gradients
.iter()
.zip(self.inverse_factors.iter())
.zip(self.layer_dims.iter())
{
let (batchsize, grad_output_dim) = grads.dim();
if grad_output_dim != output_dim {
return Err(LinalgError::ShapeError(format!(
"Gradient output dimension mismatch: expected {output_dim}, got {grad_output_dim}"
)));
}
            // Activations are not available here, so approximate the weight
            // gradient with the batch-averaged output gradients: every input
            // row (including the bias row at index input_dim) holds the
            // per-column mean of grads.
            let mut extended_grads = Array2::zeros((input_dim + 1, output_dim));
            for i in 0..input_dim {
                for j in 0..output_dim {
                    let mut sum = F::zero();
                    for b in 0..batchsize {
                        sum += grads[[b, j]];
                    }
                    extended_grads[[i, j]] =
                        sum / F::from(batchsize).expect("Failed to convert to float");
                }
            }
            for j in 0..output_dim {
                let mut sum = F::zero();
                for b in 0..batchsize {
                    sum += grads[[b, j]];
                }
                extended_grads[[input_dim, j]] =
                    sum / F::from(batchsize).expect("Failed to convert to float");
            }
            // Apply the Kronecker-factored preconditioner: A⁻¹ G S⁻¹.
            let temp = input_inv.dot(&extended_grads);
            let preconditioned_grad = temp.dot(output_inv);
            // Broadcast the first preconditioned row back to batch shape.
            let mut result = Array2::zeros((batchsize, output_dim));
            for b in 0..batchsize {
                for j in 0..output_dim {
                    result[[b, j]] = preconditioned_grad[[0, j]];
                }
            }
preconditioned.push(result);
}
Ok(preconditioned)
}
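    /// Reports how much memory the Kronecker factors use compared to storing
    /// each layer's full Fisher block explicitly.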
pub fn memory_info(&self) -> BlockFisherMemoryInfo {
let mut total_elements = 0;
let mut total_inverse_elements = 0;
let mut original_elements = 0;
for ((input_cov, output_cov), &(input_dim, output_dim)) in
self.layer_factors.iter().zip(self.layer_dims.iter())
{
total_elements += input_cov.len() + output_cov.len();
            total_inverse_elements += input_cov.len() + output_cov.len();
            original_elements += input_dim * output_dim * input_dim * output_dim;
}
let compression_ratio =
original_elements as f64 / (total_elements + total_inverse_elements) as f64;
BlockFisherMemoryInfo {
num_layers: self.layer_factors.len(),
total_factor_elements: total_elements,
total_inverse_elements,
compression_ratio,
estimated_full_fisher_elements: original_elements,
}
}
}
#[derive(Debug)]
pub struct BlockFisherMemoryInfo {
pub num_layers: usize,
pub total_factor_elements: usize,
pub total_inverse_elements: usize,
pub compression_ratio: f64,
pub estimated_full_fisher_elements: usize,
}
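/// One full K-FAC step: refresh the covariance averages, invert both factors
/// with the optimizer's current damping, form the natural gradient
/// `A⁻¹ G S⁻¹`, optionally clip its Frobenius norm, and apply the update.
/// The `_momentum` parameter is currently unused.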
#[allow(dead_code)]
pub fn advanced_kfac_step<F>(
weights: &ArrayView2<F>,
gradients: &ArrayView2<F>,
kfac_optimizer: &mut KFACOptimizer<F>,
input_acts: &ArrayView2<F>,
output_grads: &ArrayView2<F>,
learning_rate: F,
_momentum: Option<F>,
gradient_clip: Option<F>,
) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + ScalarOperand + Send + Sync,
{
let (input_cov, output_cov) = kfac_optimizer.update_covariances(input_acts, output_grads)?;
let input_inv = stablematrix_inverse(&input_cov.view(), kfac_optimizer.get_damping())?;
let output_inv = stablematrix_inverse(&output_cov.view(), kfac_optimizer.get_damping())?;
let temp = input_inv.dot(gradients);
let mut natural_grad = temp.dot(&output_inv);
if let Some(clip_threshold) = gradient_clip {
let grad_norm = matrix_norm(&natural_grad.view(), "fro", None)?;
if grad_norm > clip_threshold {
let scale_factor = clip_threshold / grad_norm;
for elem in natural_grad.iter_mut() {
*elem *= scale_factor;
}
}
}
let mut new_weights = weights.to_owned();
for i in 0..weights.nrows() {
for j in 0..weights.ncols() {
new_weights[[i, j]] = weights[[i, j]] - learning_rate * natural_grad[[i, j]];
}
}
Ok(new_weights)
}
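/// Damped inverse of a symmetric positive-definite matrix: adds `damping` to
/// the diagonal, attempts a Cholesky-based solve, and falls back to a
/// strongly damped diagonal approximation if the factorization fails.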
#[allow(dead_code)]
fn stablematrix_inverse<F>(matrix: &ArrayView2<F>, damping: F) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + ScalarOperand + Send + Sync,
{
let n = matrix.nrows();
let mut regularized = matrix.to_owned();
for i in 0..n {
regularized[[i, i]] += damping;
}
    match cholesky(&regularized.view(), None) {
Ok(l) => {
let mut inv = Array2::eye(n);
for col in 0..n {
let mut b = Array1::zeros(n);
b[col] = F::one();
let mut y = Array1::zeros(n);
for i in 0..n {
let mut sum = F::zero();
for j in 0..i {
sum += l[[i, j]] * y[j];
}
y[i] = (b[i] - sum) / l[[i, i]];
}
let mut x = Array1::zeros(n);
for i in (0..n).rev() {
let mut sum = F::zero();
for j in (i + 1)..n {
sum += l[[j, i]] * x[j];
}
x[i] = (y[i] - sum) / l[[i, i]];
}
for i in 0..n {
inv[[i, col]] = x[i];
}
}
Ok(inv)
}
Err(_) => {
let mut inv = Array2::zeros((n, n));
for i in 0..n {
let diag_val = matrix[[i, i]]
+ damping * F::from(100.0).expect("Failed to convert constant to float");
inv[[i, i]] = F::one() / diag_val;
}
Ok(inv)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use scirs2_core::ndarray::array;
#[test]
fn test_kron_simple() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
let b = array![[0.1, 0.2], [0.3, 0.4]];
let c = kron(&a.view(), &b.view()).expect("Operation failed");
assert_eq!(c.shape(), &[4, 4]);
assert_relative_eq!(c[[0, 0]], 0.1);
assert_relative_eq!(c[[0, 1]], 0.2);
assert_relative_eq!(c[[0, 2]], 0.2);
assert_relative_eq!(c[[0, 3]], 0.4);
assert_relative_eq!(c[[1, 0]], 0.3);
assert_relative_eq!(c[[1, 1]], 0.4);
assert_relative_eq!(c[[1, 2]], 0.6);
assert_relative_eq!(c[[1, 3]], 0.8);
assert_relative_eq!(c[[2, 0]], 0.3);
assert_relative_eq!(c[[2, 1]], 0.6);
assert_relative_eq!(c[[2, 2]], 0.4);
assert_relative_eq!(c[[2, 3]], 0.8);
assert_relative_eq!(c[[3, 0]], 0.9);
assert_relative_eq!(c[[3, 1]], 1.2);
assert_relative_eq!(c[[3, 2]], 1.2);
assert_relative_eq!(c[[3, 3]], 1.6);
}
#[test]
fn test_kron_matvec() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
let b = array![[0.1, 0.2], [0.3, 0.4]];
let x = array![1.0, 2.0, 3.0, 4.0];
let y = kron_matvec(&a.view(), &b.view(), &x.view()).expect("Operation failed");
let ab = kron(&a.view(), &b.view()).expect("Operation failed");
let y_direct = ab.dot(&x);
assert_eq!(y.shape(), y_direct.shape());
for i in 0..y.len() {
assert_relative_eq!(y[i], y_direct[i], epsilon = 1e-10);
}
}
#[test]
fn test_kron_matmul() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
let b = array![[0.1, 0.2], [0.3, 0.4]];
let x = array![[1.0, 5.0], [2.0, 6.0], [3.0, 7.0], [4.0, 8.0]];
let y = kron_matmul(&a.view(), &b.view(), &x.view()).expect("Operation failed");
let ab = kron(&a.view(), &b.view()).expect("Operation failed");
let y_direct = ab.dot(&x);
assert_eq!(y.shape(), y_direct.shape());
for i in 0..y.dim().0 {
for j in 0..y.dim().1 {
assert_relative_eq!(y[[i, j]], y_direct[[i, j]], epsilon = 1e-10);
}
}
}
#[test]
fn test_kron_factorize() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
let b = array![[0.1, 0.2], [0.3, 0.4]];
let ab = kron(&a.view(), &b.view()).expect("Operation failed");
let (a_hat, b_hat) = kron_factorize(&ab.view(), 2, 2).expect("Operation failed");
let ab_hat = kron(&a_hat.view(), &b_hat.view()).expect("Operation failed");
let mut error = 0.0f64;
for i in 0..4 {
for j in 0..4 {
error += (ab[[i, j]] - ab_hat[[i, j]]).abs() as f64;
}
}
error /= 16.0;
assert!(error < 0.1, "Average error was {}", error);
}
#[test]
fn test_kfac_factorization() {
let input_acts = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],];
let output_grads = array![[0.1, 0.2], [0.3, 0.4], [0.5, 0.6],];
let (a_cov, s_cov) =
kfac_factorization(&input_acts.view(), &output_grads.view(), Some(0.01))
.expect("Operation failed");
        assert_eq!(a_cov.shape(), &[4, 4]);
        assert_eq!(s_cov.shape(), &[2, 2]);
for i in 0..4 {
assert!(a_cov[[i, i]] > 0.0);
}
for i in 0..2 {
assert!(s_cov[[i, i]] > 0.0);
}
}
#[test]
fn test_kfac_update() {
let weights = array![[0.1, 0.2], [0.3, 0.4], [0.5, 0.6],];
let gradients = array![[0.01, 0.02], [0.03, 0.04], [0.05, 0.06],];
let a_inv = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0],];
let s_inv = array![[1.0, 0.0], [0.0, 1.0],];
let learning_rate = 0.1;
let new_weights = kfac_update(
&weights.view(),
&gradients.view(),
&a_inv.view(),
&s_inv.view(),
learning_rate,
)
.expect("Failed to apply natural gradient");
for i in 0..3 {
for j in 0..2 {
assert_relative_eq!(
new_weights[[i, j]],
weights[[i, j]] - learning_rate * gradients[[i, j]],
epsilon = 1e-10
);
}
}
}
#[test]
fn test_kfac_optimizer_basic() {
let mut optimizer = KFACOptimizer::<f64>::new(Some(0.9), Some(0.01));
let input_acts = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
let output_grads = array![[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]];
        let (input_cov1, _output_cov1) = optimizer
            .update_covariances(&input_acts.view(), &output_grads.view())
            .expect("Failed to update covariances");
        assert_eq!(optimizer.step_count, 1);
        assert!(optimizer.input_cov_avg.is_some());
        assert!(optimizer.output_cov_avg.is_some());
        // A second update with different statistics must move the running
        // average, so the bias-corrected covariance changes between steps.
        let input_acts2 = array![[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]];
        let (input_cov2, _output_cov2) = optimizer
            .update_covariances(&input_acts2.view(), &output_grads.view())
            .expect("Failed to update covariances");
        assert_eq!(optimizer.step_count, 2);
        assert!((input_cov1[[0, 0]] - input_cov2[[0, 0]]).abs() > 1e-10);
}
#[test]
fn test_kfac_optimizer_damping_adjustment() {
let mut optimizer = KFACOptimizer::<f64>::new(None, Some(0.01));
let initial_damping = optimizer.get_damping();
optimizer.adjust_damping(true, Some(0.8));
let after_improvement = optimizer.get_damping();
assert!(after_improvement < initial_damping);
optimizer.adjust_damping(false, None);
let after_deterioration = optimizer.get_damping();
assert!(after_deterioration > after_improvement);
for _ in 0..20 {
optimizer.adjust_damping(false, None);
}
assert!(optimizer.get_damping() <= optimizer.max_damping);
for _ in 0..20 {
optimizer.adjust_damping(true, Some(0.9));
}
assert!(optimizer.get_damping() >= optimizer.min_damping);
}
#[test]
fn test_block_diagonal_fisher() {
let layer_dims = vec![(10, 20), (20, 10)];
let mut fisher = BlockDiagonalFisher::<f64>::new(layer_dims, 0.01);
let layer1_acts = Array2::from_shape_fn((5, 10), |(i, j)| (i + j) as f64 * 0.1);
let layer1_grads = Array2::from_shape_fn((5, 20), |(i, j)| (i + j) as f64 * 0.01);
let layer2_acts = Array2::from_shape_fn((5, 20), |(i, j)| (i + j) as f64 * 0.05);
let layer2_grads = Array2::from_shape_fn((5, 10), |(i, j)| (i + j) as f64 * 0.02);
let activations = vec![layer1_acts.view(), layer2_acts.view()];
let gradients = vec![layer1_grads.view(), layer2_grads.view()];
fisher
.update_fisher(&activations, &gradients)
.expect("Operation failed");
assert_eq!(fisher.layer_factors.len(), 2);
assert_eq!(fisher.inverse_factors.len(), 2);
let grad_matrices = vec![layer1_grads.view(), layer2_grads.view()];
let preconditioned = fisher
.precondition_gradients(&grad_matrices)
.expect("Operation failed");
assert_eq!(preconditioned.len(), 2);
assert_eq!(preconditioned[0].shape(), layer1_grads.shape());
assert_eq!(preconditioned[1].shape(), layer2_grads.shape());
let memory_info = fisher.memory_info();
assert_eq!(memory_info.num_layers, 2);
assert!(memory_info.compression_ratio > 1.0);
}
#[test]
fn test_advanced_kfac_step() {
let mut optimizer = KFACOptimizer::<f64>::new(Some(0.95), Some(0.001));
let weights = array![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]];
let gradients = array![[0.01, 0.02, 0.03], [0.04, 0.05, 0.06], [0.07, 0.08, 0.09]];
let input_acts = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
let output_grads = array![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]];
let learning_rate = 0.01;
let new_weights = advanced_kfac_step(
&weights.view(),
&gradients.view(),
&mut optimizer,
&input_acts.view(),
&output_grads.view(),
learning_rate,
None,
            Some(1.0),
        )
.expect("Failed to compute step");
assert_eq!(new_weights.shape(), weights.shape());
let mut weights_changed = false;
for i in 0..weights.nrows() {
for j in 0..weights.ncols() {
if (weights[[i, j]] - new_weights[[i, j]]).abs() > 1e-10 {
weights_changed = true;
break;
}
}
}
assert!(weights_changed);
assert_eq!(optimizer.step_count, 1);
assert!(optimizer.input_cov_avg.is_some());
}
#[test]
fn test_stablematrix_inverse() {
let matrix = array![[2.0, 1.0], [1.0, 2.0]];
let damping = 0.01;
let inv = stablematrix_inverse(&matrix.view(), damping).expect("Operation failed");
let mut regularized = matrix.clone();
for i in 0..2 {
regularized[[i, i]] += damping;
}
let product = regularized.dot(&inv);
let identity = Array2::eye(2);
for i in 0..2 {
for j in 0..2 {
assert_relative_eq!(product[[i, j]], identity[[i, j]], epsilon = 1e-10);
}
}
}
#[test]
fn test_kfac_optimizer_reset() {
let mut optimizer = KFACOptimizer::<f64>::new(Some(0.9), Some(0.01));
let input_acts = array![[1.0, 2.0], [3.0, 4.0]];
let output_grads = array![[0.1, 0.2], [0.3, 0.4]];
optimizer
.update_covariances(&input_acts.view(), &output_grads.view())
.expect("Failed to update covariances");
optimizer.adjust_damping(false, None);
assert!(optimizer.step_count > 0);
assert!(optimizer.input_cov_avg.is_some());
optimizer.reset();
assert_eq!(optimizer.step_count, 0);
assert!(optimizer.input_cov_avg.is_none());
assert!(optimizer.output_cov_avg.is_none());
assert_eq!(optimizer.adaptive_damping, optimizer.base_damping);
}
#[test]
fn test_block_fisher_memory_info() {
let layer_dims = vec![(10, 5), (5, 3)];
let mut fisher = BlockDiagonalFisher::<f64>::new(layer_dims, 0.01);
let layer1_acts = Array2::zeros((8, 10));
let layer1_grads = Array2::zeros((8, 5));
let layer2_acts = Array2::zeros((8, 5));
let layer2_grads = Array2::zeros((8, 3));
let activations = vec![layer1_acts.view(), layer2_acts.view()];
let gradients = vec![layer1_grads.view(), layer2_grads.view()];
fisher
.update_fisher(&activations, &gradients)
.expect("Operation failed");
let memory_info = fisher.memory_info();
assert!(memory_info.compression_ratio > 1.0);
assert_eq!(memory_info.num_layers, 2);
        let layer1_full_fisher = (10 * 5) * (10 * 5);
        let layer2_full_fisher = (5 * 3) * (5 * 3);
let expected_full = layer1_full_fisher + layer2_full_fisher;
assert_eq!(memory_info.estimated_full_fisher_elements, expected_full);
}
}