use scirs2_core::numeric::Float;
use std::fmt::Debug;
#[derive(Debug, Clone)]
pub struct KFACConfig<T: Float + Debug + Send + Sync + 'static> {
pub learning_rate: T,
pub damping: T,
pub weight_decay: T,
pub cov_update_freq: usize,
pub inv_update_freq: usize,
pub stat_decay: T,
pub min_eigenvalue: T,
pub max_inv_iterations: usize,
pub inv_tolerance: T,
pub use_tikhonov: bool,
pub auto_damping: bool,
pub target_acceptance_ratio: T,
}
impl<T: Float + Debug + Send + Sync + 'static> Default for KFACConfig<T> {
fn default() -> Self {
Self {
learning_rate: T::from(0.001).unwrap_or_else(|| T::zero()),
damping: T::from(0.001).unwrap_or_else(|| T::zero()),
weight_decay: T::from(0.0).unwrap_or_else(|| T::zero()),
cov_update_freq: 10,
inv_update_freq: 100,
stat_decay: T::from(0.95).unwrap_or_else(|| T::zero()),
min_eigenvalue: T::from(1e-7).unwrap_or_else(|| T::zero()),
max_inv_iterations: 50,
inv_tolerance: T::from(1e-6).unwrap_or_else(|| T::zero()),
use_tikhonov: true,
auto_damping: true,
target_acceptance_ratio: T::from(0.75).unwrap_or_else(|| T::zero()),
}
}
}
#[derive(Debug, Clone)]
pub struct LayerInfo {
pub name: String,
pub input_dim: usize,
pub output_dim: usize,
pub layer_type: LayerType,
pub has_bias: bool,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LayerType {
Dense,
Convolution,
GroupedConvolution { groups: usize },
Embedding,
BatchNorm,
}
#[derive(Debug, Clone, Default)]
pub struct KFACStats<T: Float + Debug + Send + Sync + 'static> {
pub total_steps: usize,
pub cov_updates: usize,
pub inv_updates: usize,
pub avg_condition_number: T,
pub time_cov_update: u64,
pub time_inv_update: u64,
pub time_gradient_update: u64,
pub memory_usage: usize,
}
impl<T: Float + Debug + Send + Sync + 'static> KFACConfig<T> {
pub fn for_large_model() -> Self {
Self {
cov_update_freq: 20,
inv_update_freq: 200,
stat_decay: T::from(0.99).unwrap_or_else(|| T::zero()),
damping: T::from(0.01).unwrap_or_else(|| T::zero()),
..Default::default()
}
}
pub fn for_small_model() -> Self {
Self {
cov_update_freq: 5,
inv_update_freq: 50,
stat_decay: T::from(0.9).unwrap_or_else(|| T::zero()),
damping: T::from(0.001).unwrap_or_else(|| T::zero()),
..Default::default()
}
}
pub fn for_stability() -> Self {
Self {
damping: T::from(0.1).unwrap_or_else(|| T::zero()),
min_eigenvalue: T::from(1e-5).unwrap_or_else(|| T::zero()),
auto_damping: false,
use_tikhonov: true,
..Default::default()
}
}
pub fn validate(&self) -> Result<(), String> {
if self.learning_rate <= T::zero() {
return Err("Learning rate must be positive".to_string());
}
if self.damping < T::zero() {
return Err("Damping must be non-negative".to_string());
}
if self.weight_decay < T::zero() {
return Err("Weight decay must be non-negative".to_string());
}
if self.cov_update_freq == 0 {
return Err("Covariance update frequency must be positive".to_string());
}
if self.inv_update_freq == 0 {
return Err("Inverse update frequency must be positive".to_string());
}
if self.stat_decay < T::zero() || self.stat_decay > T::one() {
return Err("Statistics decay must be between 0 and 1".to_string());
}
if self.min_eigenvalue <= T::zero() {
return Err("Minimum eigenvalue must be positive".to_string());
}
if self.inv_tolerance <= T::zero() {
return Err("Inverse tolerance must be positive".to_string());
}
if self.target_acceptance_ratio <= T::zero() || self.target_acceptance_ratio >= T::one() {
return Err("Target acceptance ratio must be between 0 and 1".to_string());
}
Ok(())
}
}
impl LayerInfo {
pub fn dense(name: String, input_dim: usize, output_dim: usize, has_bias: bool) -> Self {
Self {
name,
input_dim,
output_dim,
layer_type: LayerType::Dense,
has_bias,
}
}
pub fn convolution(name: String, input_dim: usize, output_dim: usize, has_bias: bool) -> Self {
Self {
name,
input_dim,
output_dim,
layer_type: LayerType::Convolution,
has_bias,
}
}
pub fn input_cov_size(&self) -> usize {
match self.layer_type {
LayerType::Dense => {
if self.has_bias {
self.input_dim + 1
} else {
self.input_dim
}
}
LayerType::Convolution => {
self.input_dim
}
LayerType::GroupedConvolution { .. } => self.input_dim,
LayerType::Embedding => self.input_dim,
LayerType::BatchNorm => self.input_dim,
}
}
pub fn output_cov_size(&self) -> usize {
self.output_dim
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_default() {
let config = KFACConfig::<f32>::default();
assert!(config.learning_rate > 0.0);
assert!(config.damping >= 0.0);
assert!(config.cov_update_freq > 0);
assert!(config.inv_update_freq > 0);
assert!(config.validate().is_ok());
}
#[test]
fn test_config_presets() {
let large_config = KFACConfig::<f64>::for_large_model();
assert!(large_config.validate().is_ok());
assert_eq!(large_config.cov_update_freq, 20);
let small_config = KFACConfig::<f64>::for_small_model();
assert!(small_config.validate().is_ok());
assert_eq!(small_config.cov_update_freq, 5);
let stable_config = KFACConfig::<f64>::for_stability();
assert!(stable_config.validate().is_ok());
assert!(!stable_config.auto_damping);
}
#[test]
fn test_layer_info_creation() {
let dense_layer = LayerInfo::dense("layer1".to_string(), 128, 64, true);
assert_eq!(dense_layer.layer_type, LayerType::Dense);
assert_eq!(dense_layer.input_cov_size(), 129);
let conv_layer = LayerInfo::convolution("conv1".to_string(), 64, 32, false);
assert_eq!(conv_layer.layer_type, LayerType::Convolution);
assert_eq!(conv_layer.input_cov_size(), 64);
}
#[test]
fn test_config_validation() {
let mut config = KFACConfig::<f32> {
learning_rate: -0.1,
..Default::default()
};
assert!(config.validate().is_err());
config.learning_rate = 0.001;
config.damping = -0.1;
assert!(config.validate().is_err());
config.damping = 0.001;
config.stat_decay = 1.5;
assert!(config.validate().is_err());
config.stat_decay = 0.95;
assert!(config.validate().is_ok());
}
}