use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;
use super::config::{KFACConfig, KFACStats, LayerInfo};
use super::layer_state::KFACLayerState;
/// K-FAC optimizer state shared across all registered layers.
///
/// Holds the configuration, one `KFACLayerState` per registered layer (keyed by the
/// layer name), and the bookkeeping used for adaptive damping and statistics.
#[derive(Debug)]
pub struct KFAC<T: Float + Debug + Send + Sync + 'static> {
/// Configuration supplied at construction time.
config: KFACConfig<T>,
/// Per-layer state, keyed by the layer name given at registration.
layer_states: HashMap<String, KFACLayerState<T>>,
/// Number of `step` calls since construction or the last `reset`.
step_count: usize,
/// Smoothed acceptance ratio maintained by `update_damping`; starts at 1.
acceptance_ratio: T,
/// Loss value passed to the most recent `update_damping` call, if any.
previous_loss: Option<T>,
/// NOTE(review): nothing visible in this file pushes to this history; it is only
/// cleared by `reset` and counted by `estimate_memory_usage`.
eigenvalue_history: Vec<T>,
/// Running counters and averages exposed through `get_stats`.
stats: KFACStats<T>,
}
impl<
T: Float
+ Debug
+ Default
+ Clone
+ Send
+ Sync
+ std::iter::Sum
+ scirs2_core::ndarray::ScalarOperand
+ 'static
+ scirs2_core::numeric::FromPrimitive,
> KFAC<T>
{
/// Builds an optimizer from `config` with no layers registered.
///
/// The step counter starts at zero, the acceptance ratio at one, and the statistics
/// at their default values.
pub fn new(config: KFACConfig<T>) -> Self {
    let initial_ratio = T::from(1.0).unwrap_or_else(T::zero);
    Self {
        config,
        layer_states: HashMap::new(),
        step_count: 0,
        acceptance_ratio: initial_ratio,
        previous_loss: None,
        eigenvalue_history: Vec::new(),
        stats: KFACStats::default(),
    }
}
/// Registers a layer under `layer_info.name`, creating fresh state for it.
///
/// The new state is initialised with the configured damping. Registering a name
/// that already exists replaces the previous state for that name.
///
/// # Errors
///
/// This implementation always returns `Ok(())`.
pub fn register_layer(&mut self, layer_info: LayerInfo) -> Result<()> {
    let key = layer_info.name.clone();
    let fresh_state = KFACLayerState::new(layer_info, self.config.damping);
    self.layer_states.insert(key, fresh_state);
    Ok(())
}
/// Refreshes the running covariance statistics of `layer_name` when an update is due.
///
/// An update is due once at least `config.cov_update_freq` steps have passed since the
/// layer's `last_cov_update`. When it is due, both the input covariance (from
/// `activations`) and the output covariance (from `gradients`) are updated with
/// `config.stat_decay`, the layer's `last_cov_update` is set to the current step and
/// the `cov_updates` counter is incremented. Otherwise nothing changes.
///
/// # Errors
///
/// Returns `OptimError::InvalidParameter` if `layer_name` is not registered.
pub fn update_covariance_matrices(
    &mut self,
    layer_name: &str,
    activations: &Array2<T>,
    gradients: &Array2<T>,
) -> Result<()> {
    let current_step = self.step_count;
    let state = self.layer_states.get_mut(layer_name).ok_or_else(|| {
        OptimError::InvalidParameter(format!("Layer {} not found", layer_name))
    })?;

    // Not enough steps since the last refresh: leave the statistics untouched.
    let elapsed = current_step.saturating_sub(state.last_cov_update);
    if elapsed < self.config.cov_update_freq {
        return Ok(());
    }

    let decay = self.config.stat_decay;
    state.update_input_covariance(activations, decay);
    state.update_output_covariance(gradients, decay);
    state.last_cov_update = current_step;
    self.stats.cov_updates += 1;
    Ok(())
}
/// Recomputes the cached inverse factors of `layer_name` when an update is due.
///
/// An update is due once at least `config.inv_update_freq` steps have passed since the
/// layer's `last_inv_update`. When it is due, the inverses are recomputed with the
/// damping returned by `get_adaptive_damping` (used for both factors), the layer's
/// `last_inv_update` is set to the current step, the `inv_updates` counter is
/// incremented, and the layer's condition-number estimate is folded into
/// `stats.avg_condition_number` as an exponential moving average (decay 0.95).
///
/// # Errors
///
/// Returns `OptimError::InvalidParameter` if `layer_name` is not registered, and
/// propagates any error from computing the inverses.
pub fn update_inverse_matrices(&mut self, layer_name: &str) -> Result<()> {
    let current_step = self.step_count;
    let is_due = {
        let state = self.layer_states.get(layer_name).ok_or_else(|| {
            OptimError::InvalidParameter(format!("Layer {} not found", layer_name))
        })?;
        current_step.saturating_sub(state.last_inv_update) >= self.config.inv_update_freq
    };
    if !is_due {
        return Ok(());
    }

    // Resolved before taking the mutable borrow below because it borrows `self`.
    let current_damping = self.get_adaptive_damping(layer_name)?;
    let state = self.layer_states.get_mut(layer_name).ok_or_else(|| {
        OptimError::InvalidParameter(format!("Layer {} not found", layer_name))
    })?;
    state.compute_inverses(current_damping, current_damping)?;
    // Record the refresh so the next one waits for `inv_update_freq` steps, mirroring
    // how `update_covariance_matrices` tracks `last_cov_update`. Without this the
    // inverses would be recomputed on every step once the interval first elapsed.
    state.last_inv_update = current_step;
    self.stats.inv_updates += 1;

    let (a_cond, g_cond) = state.condition_number_estimate();
    let avg_cond = (a_cond + g_cond) / T::from(2.0).unwrap_or_else(|| T::zero());
    let decay = T::from(0.95).unwrap_or_else(|| T::zero());
    self.stats.avg_condition_number =
        decay * self.stats.avg_condition_number + (T::one() - decay) * avg_cond;
    Ok(())
}
/// Computes the scaled update for `layer_name` from its raw `gradients`.
///
/// If the layer state is not ready yet, the result is `gradients` multiplied by the
/// learning rate. Otherwise the gradient is multiplied on the left by the cached
/// `g_cov_inv` and on the right by the cached `a_cov_inv`, then scaled by the
/// learning rate.
///
/// `config.weight_decay` is not applied here: this method only receives gradients
/// and has no access to the parameter values that a decay term would need.
///
/// # Errors
///
/// Returns `OptimError::InvalidParameter` if `layer_name` is not registered, or if
/// the layer reports itself ready while one of its cached inverses is missing.
pub fn apply_update(&mut self, layer_name: &str, gradients: &Array2<T>) -> Result<Array2<T>> {
    let learning_rate = self.config.learning_rate;
    let state = self.layer_states.get(layer_name).ok_or_else(|| {
        OptimError::InvalidParameter(format!("Layer {} not found", layer_name))
    })?;
    if !state.is_ready() {
        return Ok(gradients * learning_rate);
    }

    // Surface an inconsistent layer state as an error instead of panicking.
    let a_inv = state.a_cov_inv.as_ref().ok_or_else(|| {
        OptimError::InvalidParameter(format!(
            "Layer {} is ready but has no inverse input covariance",
            layer_name
        ))
    })?;
    let g_inv = state.g_cov_inv.as_ref().ok_or_else(|| {
        OptimError::InvalidParameter(format!(
            "Layer {} is ready but has no inverse output covariance",
            layer_name
        ))
    })?;

    let natural_gradients = g_inv.dot(gradients).dot(a_inv);
    Ok(natural_gradients * learning_rate)
}
/// Runs one optimizer step over the supplied layers and returns their updates.
///
/// `layer_gradients` maps each layer name to an `(activations, gradients)` pair.
/// The step counter and `stats.total_steps` are incremented first. Then, in order:
/// covariance statistics are refreshed for every layer, inverse factors are
/// refreshed for every layer, and an update is computed for every layer.
///
/// `loss_fn` is evaluated only when `config.auto_damping` is enabled and a closure
/// was supplied; its value is passed to the damping adaptation.
///
/// # Errors
///
/// Propagates the first error from any per-layer call, for example an unregistered
/// layer name. State changed before the failure is not rolled back.
pub fn step<F>(
    &mut self,
    layer_gradients: HashMap<String, (&Array2<T>, &Array2<T>)>,
    loss_fn: Option<F>,
) -> Result<HashMap<String, Array2<T>>>
where
    F: FnOnce() -> T,
{
    self.step_count += 1;
    self.stats.total_steps += 1;

    // Phase 1: running statistics.
    for (name, (activations, gradients)) in layer_gradients.iter() {
        self.update_covariance_matrices(name, activations, gradients)?;
    }

    // Phase 2: cached inverses.
    for name in layer_gradients.keys() {
        self.update_inverse_matrices(name)?;
    }

    // Phase 3: per-layer updates, stopping at the first failure.
    let updates = layer_gradients
        .iter()
        .map(|(name, (_, gradients))| {
            self.apply_update(name, gradients)
                .map(|update| (name.clone(), update))
        })
        .collect::<Result<HashMap<String, Array2<T>>>>()?;

    match loss_fn {
        Some(evaluate_loss) if self.config.auto_damping => {
            let current_loss = evaluate_loss();
            self.update_damping(current_loss);
        }
        _ => {}
    }

    Ok(updates)
}
/// Returns a shared reference to the optimizer's running statistics.
pub fn get_stats(&self) -> &KFACStats<T> {
    let stats = &self.stats;
    stats
}
/// Restores the optimizer to its freshly constructed bookkeeping state.
///
/// Every registered layer state is reset (layers stay registered), the step counter
/// returns to zero, the acceptance ratio to one, the previous loss and eigenvalue
/// history are cleared, and the statistics are replaced by their defaults.
pub fn reset(&mut self) {
    self.layer_states.values_mut().for_each(|state| {
        state.reset();
    });
    self.stats = KFACStats::default();
    self.eigenvalue_history.clear();
    self.previous_loss = None;
    self.acceptance_ratio = T::from(1.0).unwrap_or_else(T::zero);
    self.step_count = 0;
}
/// Estimates the optimizer's memory footprint in bytes.
///
/// The estimate is the sum of each layer state's own `memory_usage()`, the size of
/// this struct, and the allocated capacity of the eigenvalue history.
pub fn estimate_memory_usage(&self) -> usize {
    let layers: usize = self
        .layer_states
        .values()
        .map(|state| state.memory_usage())
        .sum();
    let own = std::mem::size_of::<Self>();
    let history = self.eigenvalue_history.capacity() * std::mem::size_of::<T>();
    layers + own + history
}
/// Looks up the state of `layer_name`, returning `None` if it is not registered.
pub fn get_layer_state(&self, layer_name: &str) -> Option<&KFACLayerState<T>> {
    let states = &self.layer_states;
    states.get(layer_name)
}
/// Overrides the stored damping values of a single layer.
///
/// Writes `damping_a` and `damping_g` into the layer state's fields of the same name.
///
/// # Errors
///
/// Returns `OptimError::InvalidParameter` if `layer_name` is not registered.
pub fn set_layer_damping(
    &mut self,
    layer_name: &str,
    damping_a: T,
    damping_g: T,
) -> Result<()> {
    match self.layer_states.get_mut(layer_name) {
        Some(state) => {
            state.damping_a = damping_a;
            state.damping_g = damping_g;
            Ok(())
        }
        None => Err(OptimError::InvalidParameter(format!(
            "Layer {} not found",
            layer_name
        ))),
    }
}
/// Returns how many layers are currently registered.
pub fn num_layers(&self) -> usize {
    let registered = &self.layer_states;
    registered.len()
}
/// Returns the names of all registered layers as owned strings.
///
/// The order follows the hash map's iteration order and is therefore unspecified.
pub fn layer_names(&self) -> Vec<String> {
    let mut names = Vec::with_capacity(self.layer_states.len());
    names.extend(self.layer_states.keys().cloned());
    names
}
/// Reports whether a layer named `layer_name` has been registered.
pub fn has_layer(&self, layer_name: &str) -> bool {
    let registered = &self.layer_states;
    registered.contains_key(layer_name)
}
/// Returns the number of steps taken since construction or the last `reset`.
pub fn step_count(&self) -> usize {
    let steps = self.step_count;
    steps
}
/// Returns the current smoothed acceptance ratio used for adaptive damping.
pub fn acceptance_ratio(&self) -> T {
    let ratio = self.acceptance_ratio;
    ratio
}
/// Returns the damping to use for the next inverse computation.
///
/// With `config.auto_damping` disabled this is simply `config.damping`. Otherwise the
/// configured damping is scaled by 0.9 when the acceptance ratio exceeds
/// `config.target_acceptance_ratio`, and by 1.1 when it does not.
///
/// `layer_name` is currently not consulted: the result is the same for every layer.
fn get_adaptive_damping(&self, layer_name: &str) -> Result<T> {
    let _ = layer_name;
    let base_damping = self.config.damping;
    if !self.config.auto_damping {
        return Ok(base_damping);
    }

    let ratio_diff = self.acceptance_ratio - self.config.target_acceptance_ratio;
    let scale = if ratio_diff > T::zero() { 0.9 } else { 1.1 };
    Ok(base_damping * T::from(scale).unwrap_or_else(|| T::zero()))
}
/// Folds the latest loss into the smoothed acceptance ratio.
///
/// The first call only records `current_loss`. On later calls the loss is compared
/// with the previously recorded one: a loss that did not increase contributes 1.2,
/// an increased loss contributes 0.8. The contribution is blended into the ratio as
/// an exponential moving average (decay 0.95) and the result is clamped to
/// `[0.1, 2.0]`. `current_loss` then becomes the recorded loss.
fn update_damping(&mut self, current_loss: T) {
    // `replace` stores the new loss and hands back the one recorded before it.
    if let Some(prev_loss) = self.previous_loss.replace(current_loss) {
        // Compare the losses directly rather than through `current / previous`:
        // the quotient is infinite or NaN when the previous loss is zero and flips
        // meaning when it is negative, which misclassified genuine improvements.
        let observed = if current_loss <= prev_loss { 1.2 } else { 0.8 };

        let decay = T::from(0.95).unwrap_or_else(|| T::zero());
        let blended = decay * self.acceptance_ratio
            + (T::one() - decay) * T::from(observed).unwrap_or_else(|| T::zero());

        let min_ratio = T::from(0.1).unwrap_or_else(|| T::zero());
        let max_ratio = T::from(2.0).unwrap_or_else(|| T::zero());
        self.acceptance_ratio = blended.max(min_ratio).min(max_ratio);
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::second_order::kfac::config::{LayerInfo, LayerType};

    /// A freshly built optimizer has no layers and has taken no steps.
    #[test]
    fn test_kfac_creation() {
        let optimizer = KFAC::new(KFACConfig::<f32>::default());
        assert_eq!(optimizer.num_layers(), 0);
        assert_eq!(optimizer.step_count(), 0);
    }

    /// Registering a layer makes it visible by name and in the layer count.
    #[test]
    fn test_layer_registration() {
        let mut optimizer = KFAC::new(KFACConfig::<f64>::default());
        let dense = LayerInfo {
            name: "test_layer".to_string(),
            input_dim: 128,
            output_dim: 64,
            layer_type: LayerType::Dense,
            has_bias: true,
        };
        let outcome = optimizer.register_layer(dense);
        assert!(outcome.is_ok());
        assert_eq!(optimizer.num_layers(), 1);
        assert!(optimizer.has_layer("test_layer"));
    }

    /// With an update frequency of one, a single step refreshes the covariances.
    #[test]
    fn test_covariance_update() {
        let mut optimizer = KFAC::new(KFACConfig::<f32> {
            cov_update_freq: 1,
            ..Default::default()
        });
        optimizer
            .register_layer(LayerInfo {
                name: "test_layer".to_string(),
                input_dim: 4,
                output_dim: 2,
                layer_type: LayerType::Dense,
                has_bias: false,
            })
            .expect("unwrap failed");

        let input_values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let activations = Array2::from_shape_vec((2, 4), input_values).expect("unwrap failed");
        let gradient_values = vec![0.1, 0.2, 0.3, 0.4];
        let gradients = Array2::from_shape_vec((2, 2), gradient_values).expect("unwrap failed");

        let mut per_layer = HashMap::new();
        per_layer.insert("test_layer".to_string(), (&activations, &gradients));

        let outcome = optimizer.step(per_layer, None::<fn() -> f32>);
        assert!(outcome.is_ok());
        assert!(optimizer.get_stats().cov_updates > 0);
    }

    /// A large registered layer dominates the memory estimate.
    #[test]
    fn test_memory_usage_estimation() {
        let mut optimizer = KFAC::new(KFACConfig::<f64>::default());
        optimizer
            .register_layer(LayerInfo {
                name: "large_layer".to_string(),
                input_dim: 1000,
                output_dim: 500,
                layer_type: LayerType::Dense,
                has_bias: true,
            })
            .expect("unwrap failed");

        let estimated_bytes = optimizer.estimate_memory_usage();
        assert!(estimated_bytes > 0);
        assert!(estimated_bytes > 1_000_000);
    }

    /// The acceptance ratio rises after an improving loss and falls after a worse one.
    #[test]
    fn test_damping_adjustment() {
        let mut optimizer = KFAC::new(KFACConfig::<f32> {
            auto_damping: true,
            target_acceptance_ratio: 0.75,
            ..Default::default()
        });

        // The first call only records the loss; the second one sees an improvement.
        optimizer.update_damping(1.0);
        optimizer.update_damping(0.9);
        assert!(optimizer.acceptance_ratio() > 1.0);

        // A worse loss pulls the ratio back down.
        optimizer.update_damping(1.1);
        assert!(optimizer.acceptance_ratio() < 1.2);
    }
}