/// Arithmetic a value must support to be accumulated by [`ConvergenceTracker`].
///
/// The tracker applies Welford-style update formulas through these methods, so
/// its statistics are only meaningful if the `_elem` operations act element by
/// element on values of the same shape and `scale` multiplies every element by
/// the given scalar.
pub trait Convergeable: Clone + Sized {
    /// Returns a value shaped like `self` with every element equal to zero.
    fn zero_like(&self) -> Self;
    /// Weighted combination of `self` and `other`
    /// (presumably `self * self_weight + other * other_weight`).
    /// Not used by [`ConvergenceTracker`].
    fn weighted_add(&self, other: &Self, self_weight: f32, other_weight: f32) -> Self;
    /// Element-wise product.
    fn mul_elem(&self, other: &Self) -> Self;
    /// Element-wise quotient.
    fn div_elem(&self, other: &Self) -> Self;
    /// Element-wise sum.
    fn add_elem(&self, other: &Self) -> Self;
    /// Element-wise difference `self - other`.
    fn sub_elem(&self, other: &Self) -> Self;
    /// Multiplies every element by `scalar`.
    fn scale(&self, scalar: f32) -> Self;
    /// Element-wise square root.
    fn sqrt_elem(&self) -> Self;
    /// The quantity whose running mean and spread the tracker accumulates
    /// (presumably the sample already multiplied by its weight).
    fn to_weighted(&self) -> Self;
    /// The weight associated with this sample; the tracker reports
    /// `mean(to_weighted) / mean(weights)` as the estimate.
    fn weights(&self) -> Self;
}

/// Streaming estimator of a (weighted) mean and its standard error.
///
/// Samples are absorbed one at a time with Welford's online update, and two
/// trackers can be combined with the pairwise (Chan et al.) formula in
/// [`merge`](Self::merge), so individual samples never have to be stored.
#[derive(Debug, Clone)]
pub struct ConvergenceTracker<T: Convergeable> {
    /// Number of samples absorbed so far.
    i: usize,
    /// Running mean of `to_weighted()` over all samples.
    m: T,
    /// Sum of squared deviations of `to_weighted()` from `m` (Welford's M2).
    s: T,
    /// Running mean of `weights()` over all samples.
    w: T,
}

impl<T: Convergeable> ConvergenceTracker<T> {
    /// Creates an empty tracker whose accumulators are shaped like `template`.
    pub fn new(template: &T) -> Self {
        Self {
            i: 0,
            m: template.zero_like(),
            s: template.zero_like(),
            w: template.zero_like(),
        }
    }

    /// Returns `num / den` computed in `f64` and rounded once to `f32`.
    ///
    /// Sample counts above 2^24 are not exactly representable in `f32`, so
    /// converting each count to `f32` first would distort the ratio.
    fn ratio(num: usize, den: usize) -> f32 {
        (num as f64 / den as f64) as f32
    }

    /// Absorbs one sample into the running statistics.
    pub fn update(&mut self, result: &T) {
        self.i += 1;
        let value = result.to_weighted();
        let weight = result.weights();

        if self.i == 1 {
            // First sample: the means are the sample itself and M2 stays zero.
            self.m = value;
            self.w = weight;
            return;
        }

        let inv_n = Self::ratio(1, self.i);
        // Welford: `delta` is taken against the *old* mean; the M2 increment
        // delta * (x - new_mean) simplifies to delta^2 * (n - 1) / n.
        let delta = value.sub_elem(&self.m);
        self.m = self.m.add_elem(&delta.scale(inv_n));
        let m2_factor = Self::ratio(self.i - 1, self.i);
        self.s = self.s.add_elem(&delta.mul_elem(&delta).scale(m2_factor));

        let dw = weight.sub_elem(&self.w);
        self.w = self.w.add_elem(&dw.scale(inv_n));
    }

    /// Number of samples absorbed so far (including merged trackers).
    pub fn count(&self) -> usize {
        self.i
    }

    /// Estimated mean: the running mean of `to_weighted()` divided
    /// element-wise by the running mean of `weights()`.
    ///
    /// Returns zeros when no samples have been absorbed. Elements whose mean
    /// weight is zero yield whatever `div_elem` produces for a zero divisor.
    pub fn mean(&self) -> T {
        if self.i == 0 {
            return self.m.zero_like();
        }
        self.m.div_elem(&self.w)
    }

    /// Standard error of [`mean`](Self::mean).
    ///
    /// Computed as `sqrt(M2 / (n (n - 1))) / mean_weight`, i.e. the standard
    /// error of the mean of `to_weighted()` scaled by the mean weight. The
    /// variability of the weights themselves is not included. Returns zeros
    /// when fewer than two samples have been absorbed.
    pub fn sem(&self) -> T {
        if self.i < 2 {
            return self.m.zero_like();
        }
        // Sample variance is M2 / (n - 1); the variance of the mean of n
        // samples is that divided by n, hence the n * (n - 1) denominator.
        let n = self.i as f64;
        let inv_n_nm1 = (1.0 / (n * (n - 1.0))) as f32;
        let w_sq = self.w.mul_elem(&self.w);
        self.s.scale(inv_n_nm1).div_elem(&w_sq).sqrt_elem()
    }

    /// Folds the statistics of `other` into `self`, as if every sample seen by
    /// `other` had also been passed to [`update`](Self::update) on `self`.
    pub fn merge(&mut self, other: &Self) {
        if other.i == 0 {
            return;
        }
        if self.i == 0 {
            *self = other.clone();
            return;
        }

        // Keep the count in exact integer arithmetic; an f32 round-trip
        // would corrupt totals above 2^24.
        let total = self.i + other.i;
        let frac_other = Self::ratio(other.i, total);
        // n_a * n_b / n for the cross term of the pairwise M2 formula.
        let cross = (self.i as f64 * other.i as f64 / total as f64) as f32;

        let delta = other.m.sub_elem(&self.m);
        self.m = self.m.add_elem(&delta.scale(frac_other));
        self.s = self
            .s
            .add_elem(&other.s)
            .add_elem(&delta.mul_elem(&delta).scale(cross));

        let dw = other.w.sub_elem(&self.w);
        self.w = self.w.add_elem(&dw.scale(frac_other));
        self.i = total;
    }
}