pub mod config;
pub mod core;
pub mod layer_state;
pub mod natural_gradients;
pub mod utils;
pub use config::{KFACConfig, KFACStats, LayerInfo, LayerType};
pub use core::KFAC;
pub use layer_state::KFACLayerState;
pub use natural_gradients::{NaturalGradientCompute, NaturalGradientConfig};
pub use utils::{KFACUtils, OrderedFloat};
pub use natural_gradients::NaturalGradientCompute as NGCompute;
pub use utils::KFACUtils as Utils;
#[cfg(test)]
mod integration_tests {
    //! End-to-end tests exercising the `KFAC` API re-exported by this module.

    use super::*;
    use scirs2_core::ndarray::Array2;
    use std::collections::HashMap;

    /// Registers one dense layer (4 inputs, 2 outputs, with bias) and checks that a
    /// single step without a loss closure returns an update for that layer whose
    /// shape matches the supplied gradient.
    #[test]
    fn test_kfac_integration_dense_layer() {
        let config = KFACConfig::<f32>::default();
        let mut kfac = KFAC::new(config);
        let layer_info = LayerInfo::dense("dense1".to_string(), 4, 2, true);
        assert!(kfac.register_layer(layer_info).is_ok());

        let activations = Array2::from_shape_vec(
            (3, 4),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
        )
        .expect("12 values should fill a 3x4 activation matrix");
        let gradients = Array2::from_shape_vec((3, 2), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
            .expect("6 values should fill a 3x2 gradient matrix");

        let mut layer_gradients = HashMap::new();
        layer_gradients.insert("dense1".to_string(), (&activations, &gradients));

        // `None` carries no closure type, so a concrete `fn` type is named for inference.
        let updates = kfac
            .step::<fn() -> f32>(layer_gradients, None)
            .expect("KFAC step on a registered dense layer should succeed");
        assert!(updates.contains_key("dense1"));
        assert_eq!(updates["dense1"].dim(), gradients.dim());
    }

    /// Registers two dense layers and steps five times. Each step must return one
    /// update per layer with the gradient's shape. The step counter must advance by
    /// one per step, and the stats must record the steps and at least one covariance
    /// update.
    #[test]
    fn test_kfac_integration_multiple_layers() {
        let config = KFACConfig::<f64>::for_small_model();
        let mut kfac = KFAC::new(config);
        let layer1 = LayerInfo::dense("layer1".to_string(), 8, 4, true);
        let layer2 = LayerInfo::dense("layer2".to_string(), 4, 2, false);
        assert!(kfac.register_layer(layer1).is_ok());
        assert!(kfac.register_layer(layer2).is_ok());
        assert_eq!(kfac.num_layers(), 2);
        assert!(kfac.has_layer("layer1"));
        assert!(kfac.has_layer("layer2"));

        let batch_size = 5;
        let activations1 = Array2::ones((batch_size, 8));
        let gradients1 = Array2::ones((batch_size, 4)) * 0.1;
        let activations2 = Array2::ones((batch_size, 4));
        let gradients2 = Array2::ones((batch_size, 2)) * 0.2;

        let mut layer_gradients = HashMap::new();
        layer_gradients.insert("layer1".to_string(), (&activations1, &gradients1));
        layer_gradients.insert("layer2".to_string(), (&activations2, &gradients2));

        for step in 0..5 {
            // The map only holds references, so cloning it for each step is cheap.
            let updates = kfac
                .step::<fn() -> f64>(layer_gradients.clone(), None)
                .expect("KFAC step on registered dense layers should succeed");
            assert_eq!(updates.len(), 2);
            assert!(updates.contains_key("layer1"));
            assert!(updates.contains_key("layer2"));
            assert_eq!(updates["layer1"].dim(), gradients1.dim());
            assert_eq!(updates["layer2"].dim(), gradients2.dim());
            assert_eq!(kfac.step_count(), step + 1);
        }

        let stats = kfac.get_stats();
        assert_eq!(stats.total_steps, 5);
        assert!(stats.cov_updates > 0);
    }

    /// Checks that the memory estimate for a 512 -> 256 dense layer with bias is at
    /// least the size of a 513x513 plus a 256x256 `f32` matrix.
    #[test]
    fn test_kfac_memory_usage() {
        let config = KFACConfig::<f32>::default();
        let mut kfac = KFAC::new(config);
        let layer_info = LayerInfo::dense("large_layer".to_string(), 512, 256, true);
        kfac.register_layer(layer_info)
            .expect("registering a fresh dense layer should succeed");

        let memory_usage = kfac.estimate_memory_usage();
        assert!(memory_usage > 0);
        // 513 presumably = 512 inputs + 1 bias column in the activation factor.
        let expected_minimum = (513 * 513 + 256 * 256) * std::mem::size_of::<f32>();
        assert!(memory_usage >= expected_minimum);
    }

    /// With automatic damping enabled, steps with a constant, an improving, and a
    /// worsening loss closure must all succeed. The test also checks the acceptance
    /// ratio after the improving and the worsening step.
    #[test]
    fn test_kfac_adaptive_damping() {
        let config = KFACConfig::<f32> {
            auto_damping: true,
            target_acceptance_ratio: 0.8,
            ..Default::default()
        };
        let mut kfac = KFAC::new(config);
        let layer_info = LayerInfo::dense("test_layer".to_string(), 4, 2, false);
        kfac.register_layer(layer_info)
            .expect("registering a fresh dense layer should succeed");

        let activations = Array2::ones((2, 4));
        let gradients = Array2::ones((2, 2)) * 0.1;
        let mut layer_gradients = HashMap::new();
        layer_gradients.insert("test_layer".to_string(), (&activations, &gradients));

        let loss_fn = || 1.0;
        kfac.step(layer_gradients.clone(), Some(loss_fn))
            .expect("KFAC step with a constant loss should succeed");

        let improving_loss_fn = || 0.8;
        kfac.step(layer_gradients.clone(), Some(improving_loss_fn))
            .expect("KFAC step with an improving loss should succeed");
        assert!(kfac.acceptance_ratio() >= 1.0);

        let worsening_loss_fn = || 1.2;
        kfac.step(layer_gradients, Some(worsening_loss_fn))
            .expect("KFAC step with a worsening loss should succeed");
        // NOTE(review): 1.2 is the loss value used above, not an obvious ratio bound;
        // this assertion is very loose. Confirm the intended invariant.
        assert!(kfac.acceptance_ratio() < 1.2);
    }

    /// Per-layer damping can be set on a registered layer and is reflected in its
    /// state. Setting it on an unknown layer must return an error.
    #[test]
    fn test_kfac_layer_specific_damping() {
        let config = KFACConfig::<f64>::default();
        let mut kfac = KFAC::new(config);
        let layer_info = LayerInfo::dense("test_layer".to_string(), 3, 2, false);
        kfac.register_layer(layer_info)
            .expect("registering a fresh dense layer should succeed");

        assert!(kfac.set_layer_damping("test_layer", 0.01, 0.02).is_ok());
        let state = kfac
            .get_layer_state("test_layer")
            .expect("state should exist for a registered layer");
        assert!((state.damping_a - 0.01).abs() < 1e-10);
        assert!((state.damping_g - 0.02).abs() < 1e-10);

        assert!(kfac.set_layer_damping("nonexistent", 0.01, 0.02).is_err());
    }

    /// After two steps, `reset` must zero the step count and stats and restore the
    /// acceptance ratio to 1.0. It must keep the layer registered, but with no
    /// recorded updates and not ready.
    #[test]
    fn test_kfac_reset() {
        let config = KFACConfig::<f32>::default();
        let mut kfac = KFAC::new(config);
        let layer_info = LayerInfo::dense("test_layer".to_string(), 2, 2, false);
        kfac.register_layer(layer_info)
            .expect("registering a fresh dense layer should succeed");

        let activations = Array2::ones((2, 2));
        let gradients = Array2::ones((2, 2)) * 0.1;
        let mut layer_gradients = HashMap::new();
        layer_gradients.insert("test_layer".to_string(), (&activations, &gradients));

        kfac.step::<fn() -> f32>(layer_gradients.clone(), None)
            .expect("first KFAC step should succeed");
        kfac.step::<fn() -> f32>(layer_gradients, None)
            .expect("second KFAC step should succeed");
        assert_eq!(kfac.step_count(), 2);
        assert!(kfac.get_stats().total_steps > 0);

        kfac.reset();
        assert_eq!(kfac.step_count(), 0);
        assert_eq!(kfac.get_stats().total_steps, 0);
        assert!((kfac.acceptance_ratio() - 1.0).abs() < 1e-6);
        assert!(kfac.has_layer("test_layer"));

        let state = kfac
            .get_layer_state("test_layer")
            .expect("layer should remain registered after reset");
        assert_eq!(state.num_updates, 0);
        assert!(!state.is_ready());
    }
}