use super::FisherInformation;
use crate::error::{MathError, Result};
use crate::utils::EPS;
/// Natural-gradient optimizer that preconditions gradients with a running estimate of the
/// Fisher information matrix (FIM).
#[derive(Debug, Clone)]
pub struct NaturalGradient {
    /// Step size that scales every update.
    learning_rate: f64,
    /// Regularization applied when the Fisher estimate is used to precondition a gradient.
    damping: f64,
    /// When `true`, only the diagonal of the FIM is tracked; otherwise the dense matrix is kept.
    use_diagonal: bool,
    /// Weight on the previous estimate when blending in a freshly computed one.
    ema_factor: f64,
    /// Current Fisher estimate; `None` until gradient samples have been supplied.
    fim_estimate: Option<FimEstimate>,
}
/// Running Fisher-information estimate held by [`NaturalGradient`].
#[derive(Debug, Clone)]
enum FimEstimate {
    /// Dense estimate stored as a vector of rows.
    Full(Vec<Vec<f64>>),
    /// Diagonal entries only.
    Diagonal(Vec<f64>),
}
impl NaturalGradient {
    /// Creates an optimizer with step size `learning_rate` (floored at `EPS`).
    ///
    /// Defaults: damping `1e-4`, dense (non-diagonal) Fisher estimate, EMA factor `0.9`,
    /// and no Fisher estimate until gradient samples are passed to [`NaturalGradient::step`].
    pub fn new(learning_rate: f64) -> Self {
        Self {
            learning_rate: learning_rate.max(EPS),
            damping: 1e-4,
            use_diagonal: false,
            ema_factor: 0.9,
            fim_estimate: None,
        }
    }

    /// Sets the damping used when the Fisher estimate preconditions a gradient.
    ///
    /// Values below `EPS` (including `0.0`) are raised to `EPS`.
    pub fn with_damping(mut self, damping: f64) -> Self {
        self.damping = damping.max(EPS);
        self
    }

    /// Selects whether only the Fisher diagonal (`true`) or the dense matrix (`false`) is tracked.
    pub fn with_diagonal(mut self, use_diagonal: bool) -> Self {
        self.use_diagonal = use_diagonal;
        self
    }

    /// Sets the weight given to the previous Fisher estimate when a new one is blended in:
    /// `estimate = ema * old + (1 - ema) * new`. The value is clamped to `[0, 1]`.
    pub fn with_ema(mut self, ema: f64) -> Self {
        self.ema_factor = ema.clamp(0.0, 1.0);
        self
    }

    /// Computes (but does not apply) the update `-learning_rate * ĝ`, where `ĝ` is `gradient`
    /// preconditioned by the damped Fisher estimate.
    ///
    /// If `gradient_samples` is given, the samples are first folded into the running Fisher
    /// estimate. Without any estimate, `ĝ` is `gradient` itself.
    ///
    /// # Errors
    /// - Dimension mismatch if any sample's length differs from `gradient.len()`. This is checked
    ///   before the running estimate is modified.
    /// - Dimension mismatch if a stored diagonal estimate has a different length than `gradient`.
    /// - Any error propagated from `FisherInformation`.
    pub fn step(
        &mut self,
        gradient: &[f64],
        gradient_samples: Option<&[Vec<f64>]>,
    ) -> Result<Vec<f64>> {
        if let Some(samples) = gradient_samples {
            // Validate before mutating so a malformed call cannot corrupt the running estimate.
            if let Some(bad) = samples.iter().find(|s| s.len() != gradient.len()) {
                return Err(MathError::dimension_mismatch(gradient.len(), bad.len()));
            }
            self.update_fim(samples)?;
        }
        let nat_grad = match &self.fim_estimate {
            Some(FimEstimate::Full(fim)) => {
                let fisher = FisherInformation::new().with_damping(self.damping);
                fisher.natural_gradient(fim, gradient)?
            }
            Some(FimEstimate::Diagonal(diag)) => {
                // A stale estimate of another dimension would otherwise be silently truncated by zip.
                if diag.len() != gradient.len() {
                    return Err(MathError::dimension_mismatch(diag.len(), gradient.len()));
                }
                gradient
                    .iter()
                    .zip(diag)
                    .map(|(&g, &d)| g / (d + self.damping))
                    .collect()
            }
            None => gradient.to_vec(),
        };
        Ok(nat_grad.iter().map(|&g| -self.learning_rate * g).collect())
    }

    /// Folds `gradient_samples` into the running Fisher estimate.
    ///
    /// The new estimate is blended with the previous one via the EMA factor when both have the
    /// same kind (diagonal/dense) and dimension. Otherwise the new estimate replaces the old one.
    fn update_fim(&mut self, gradient_samples: &[Vec<f64>]) -> Result<()> {
        // Undamped here: damping is applied only when the estimate is used in `step`.
        let fisher = FisherInformation::new().with_damping(0.0);
        let ema = self.ema_factor;
        let blend = |old: f64, new: f64| ema * old + (1.0 - ema) * new;

        let estimate = if self.use_diagonal {
            let new_diag = fisher.diagonal_fim(gradient_samples)?;
            let dim = new_diag.len();
            let diag: Vec<f64> = match &self.fim_estimate {
                // Same length guard as the dense branch, so a dimension change resets the estimate.
                Some(FimEstimate::Diagonal(old)) if old.len() == dim => old
                    .iter()
                    .zip(&new_diag)
                    .map(|(&o, &n)| blend(o, n))
                    .collect(),
                _ => new_diag,
            };
            FimEstimate::Diagonal(diag)
        } else {
            let new_fim = fisher.empirical_fim(gradient_samples)?;
            let dim = new_fim.len();
            let full: Vec<Vec<f64>> = match &self.fim_estimate {
                Some(FimEstimate::Full(old)) if old.len() == dim => old
                    .iter()
                    .zip(&new_fim)
                    .map(|(old_row, new_row)| {
                        old_row
                            .iter()
                            .zip(new_row)
                            .map(|(&o, &n)| blend(o, n))
                            .collect::<Vec<f64>>()
                    })
                    .collect(),
                _ => new_fim,
            };
            FimEstimate::Full(full)
        };
        self.fim_estimate = Some(estimate);
        Ok(())
    }

    /// Adds `update` element-wise to `parameters`.
    ///
    /// # Errors
    /// Returns a dimension-mismatch error if the two slices differ in length. In that case
    /// `parameters` is left unchanged.
    pub fn apply_update(parameters: &mut [f64], update: &[f64]) -> Result<()> {
        if parameters.len() != update.len() {
            return Err(MathError::dimension_mismatch(
                parameters.len(),
                update.len(),
            ));
        }
        for (p, &u) in parameters.iter_mut().zip(update.iter()) {
            *p += u;
        }
        Ok(())
    }

    /// Computes the update for `gradient` via [`NaturalGradient::step`], applies it to
    /// `parameters`, and returns the update's Euclidean norm.
    ///
    /// # Errors
    /// Returns a dimension-mismatch error if `parameters` and `gradient` differ in length. This
    /// is checked before the Fisher estimate is touched. Also propagates any error from `step`.
    pub fn optimize_step(
        &mut self,
        parameters: &mut [f64],
        gradient: &[f64],
        gradient_samples: Option<&[Vec<f64>]>,
    ) -> Result<f64> {
        // Fail fast so a doomed call does not fold its samples into the running estimate.
        if parameters.len() != gradient.len() {
            return Err(MathError::dimension_mismatch(
                parameters.len(),
                gradient.len(),
            ));
        }
        let update = self.step(gradient, gradient_samples)?;
        let update_norm = update.iter().map(|&u| u * u).sum::<f64>().sqrt();
        Self::apply_update(parameters, &update)?;
        Ok(update_norm)
    }

    /// Discards the accumulated Fisher estimate. The optimizer's configuration is kept.
    pub fn reset(&mut self) {
        self.fim_estimate = None;
    }
}
/// Per-coordinate adaptive optimizer that scales each gradient component by the inverse square
/// root of its accumulated squared gradients (an AdaGrad-style diagonal preconditioner).
#[derive(Debug, Clone)]
pub struct DiagonalNaturalGradient {
    /// Step size that scales every update.
    learning_rate: f64,
    /// Added to `sqrt(accumulator[i])` in the denominator to avoid division by zero.
    damping: f64,
    /// Per-coordinate running sum of squared gradients; its length fixes the expected dimension.
    accumulator: Vec<f64>,
}
impl DiagonalNaturalGradient {
    /// Creates an optimizer for `dim`-dimensional parameter vectors.
    ///
    /// `learning_rate` is floored at `EPS`. Damping defaults to `1e-8`, and the squared-gradient
    /// accumulator starts at zero.
    pub fn new(learning_rate: f64, dim: usize) -> Self {
        Self {
            learning_rate: learning_rate.max(EPS),
            damping: 1e-8,
            accumulator: vec![0.0; dim],
        }
    }

    /// Sets the denominator damping. Values below `EPS` are raised to `EPS`.
    pub fn with_damping(mut self, damping: f64) -> Self {
        self.damping = damping.max(EPS);
        self
    }

    /// Applies one in-place update to `parameters` and returns the update's Euclidean norm.
    ///
    /// For each coordinate, `g²` is added to the accumulator, and the parameter moves by
    /// `-learning_rate * g / (sqrt(accumulator) + damping)`.
    ///
    /// # Errors
    /// Returns a dimension-mismatch error, leaving all state untouched, when `gradient.len()`
    /// or the optimizer's configured dimension differs from `parameters.len()`.
    pub fn step(&mut self, parameters: &mut [f64], gradient: &[f64]) -> Result<f64> {
        if parameters.len() != gradient.len() {
            return Err(MathError::dimension_mismatch(
                parameters.len(),
                gradient.len(),
            ));
        }
        if parameters.len() != self.accumulator.len() {
            // Report the configured dimension; reporting params vs. gradient here would show two
            // equal numbers and hide the real mismatch.
            return Err(MathError::dimension_mismatch(
                self.accumulator.len(),
                parameters.len(),
            ));
        }
        let mut update_norm_sq = 0.0;
        for ((p, &g), acc) in parameters
            .iter_mut()
            .zip(gradient)
            .zip(self.accumulator.iter_mut())
        {
            *acc += g * g;
            let update = -self.learning_rate * g / (acc.sqrt() + self.damping);
            *p += update;
            update_norm_sq += update * update;
        }
        Ok(update_norm_sq.sqrt())
    }

    /// Zeros the squared-gradient accumulator, keeping its dimension and the configuration.
    pub fn reset(&mut self) {
        self.accumulator.fill(0.0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With no Fisher estimate, the update is `-learning_rate * gradient`.
    #[test]
    fn test_natural_gradient_step() {
        let mut optimizer = NaturalGradient::new(0.1).with_diagonal(true);
        let grad = [1.0, 2.0, 3.0];
        let delta = optimizer.step(&grad, None).unwrap();
        assert_eq!(delta.len(), grad.len());
        assert!((delta[0] - (-0.1)).abs() < 1e-10);
    }

    /// Supplying samples builds a diagonal estimate and still yields a full-length update.
    #[test]
    fn test_natural_gradient_with_fim() {
        let mut optimizer = NaturalGradient::new(0.1)
            .with_diagonal(true)
            .with_damping(0.0);
        let samples: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let grad = [2.0, 4.0];
        let delta = optimizer.step(&grad, Some(samples.as_slice())).unwrap();
        assert_eq!(delta.len(), grad.len());
    }

    /// A positive gradient must push the parameter downward and produce a non-zero step.
    #[test]
    fn test_diagonal_natural_gradient() {
        let mut optimizer = DiagonalNaturalGradient::new(1.0, 2);
        let mut params = [0.0, 0.0];
        let step_norm = optimizer.step(&mut params, &[1.0, 2.0]).unwrap();
        assert!(step_norm > 0.0);
        assert!(params[0] < 0.0);
    }

    /// `reset` must drop whatever Fisher estimate was accumulated.
    #[test]
    fn test_optimizer_reset() {
        let mut optimizer = NaturalGradient::new(0.1);
        let samples: Vec<Vec<f64>> = vec![vec![1.0, 2.0]];
        // The outcome is irrelevant here; only the post-reset state is checked.
        let _ = optimizer.step(&[1.0, 1.0], Some(samples.as_slice()));
        optimizer.reset();
        assert!(optimizer.fim_estimate.is_none());
    }
}