use ndarray::prelude::*;
/// Bookkeeping for the error between produced and expected values.
///
/// Holds the error figure of the most recent comparison together with a
/// running, exponentially weighted blend of all figures seen so far.
#[derive(Clone, Copy, Debug)]
pub struct Deviation {
    /// Error figure computed from the most recent comparison.
    latest_mse: f64,
    /// Running blend of past error figures, weighted by `recent_factor`.
    recent_mse: f64,
    /// Weight that the previous running value keeps when a new figure is
    /// blended in; the new figure receives `1.0 - recent_factor`.
    recent_factor: f64,
}
impl Deviation {
    /// Creates a tracker with a latest error of `0.0` and a running error of `1.0`.
    ///
    /// `recent_factor` is the weight the previous running value keeps on every
    /// update; the newest error figure is weighted with `1.0 - recent_factor`.
    ///
    /// # Panics
    ///
    /// Panics if `recent_factor` is not strictly between `0.0` and `1.0`.
    pub fn new(recent_factor: f64) -> Self {
        assert!(
            0.0 < recent_factor && recent_factor < 1.0,
            "recent_factor must lie strictly between 0.0 and 1.0, got {}",
            recent_factor
        );
        Deviation {
            latest_mse: 0.0,
            recent_mse: 1.0,
            recent_factor,
        }
    }

    /// Recomputes the latest error figure from `actual` and `expected`.
    ///
    /// The figure is the square root of the mean squared element-wise
    /// difference (i.e. a root-mean-square value, despite the field name).
    /// Elements are compared pairwise up to the length of the shorter view and
    /// the mean is taken over exactly the pairs that were compared.
    ///
    /// If there is no pair to compare, the stored figure is left untouched
    /// instead of being overwritten with `0.0 / 0.0 = NaN`.
    fn update_mse<F>(&mut self, actual: ArrayView1<F>, expected: ArrayView1<F>)
    where
        F: NdFloat,
    {
        // `zip` stops at the shorter input, so the divisor has to be the
        // number of pairs actually visited rather than `actual.len()`.
        let pairs = std::cmp::min(actual.len(), expected.len());
        if pairs == 0 {
            return;
        }
        let squared_sum: f64 = actual
            .iter()
            .zip(expected.iter())
            .map(|(&actual, &expected)| {
                let dx = expected - actual;
                (dx * dx)
                    .to_f64()
                    .expect("floating point values are convertible to f64")
            })
            .sum();
        self.latest_mse = (squared_sum / pairs as f64).sqrt();
    }

    /// Blends the latest error figure into the running one.
    ///
    /// The previous running value is weighted with `recent_factor`, the latest
    /// figure with `1.0 - recent_factor`.
    fn update_recent_mse(&mut self) {
        self.recent_mse =
            self.recent_factor * self.recent_mse + (1.0 - self.recent_factor) * self.latest_mse;
    }

    /// Feeds one comparison of `actual` against `expected` into the tracker.
    ///
    /// Refreshes the latest error figure and then folds it into the running
    /// one. When either view is empty there is nothing to compare, so the call
    /// leaves both figures unchanged; this keeps a single empty input from
    /// turning the running value into `NaN` for good.
    pub fn update<F>(&mut self, actual: ArrayView1<F>, expected: ArrayView1<F>)
    where
        F: NdFloat,
    {
        if std::cmp::min(actual.len(), expected.len()) == 0 {
            return;
        }
        self.update_mse(actual, expected);
        self.update_recent_mse();
    }

    /// Returns the error figure of the most recent non-empty comparison
    /// (`0.0` before the first one).
    pub fn latest_mse(&self) -> f64 {
        self.latest_mse
    }

    /// Returns the running, exponentially weighted error figure
    /// (`1.0` before the first comparison).
    pub fn recent_mse(&self) -> f64 {
        self.recent_mse
    }
}
impl Default for Deviation {
fn default() -> Self {
Deviation::new(0.95)
}
}