use crate::{Error, Result};
/// A one-dimensional Gaussian (normal) distribution, parameterized by its
/// mean and standard deviation. `std == 0.0` represents a point mass
/// (see `Gaussian::point`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Gaussian {
/// Mean (location parameter) of the distribution.
pub mean: f64,
/// Standard deviation (scale parameter); validated non-negative by `Gaussian::new`.
pub std: f64,
}
impl Gaussian {
pub fn new(mean: f64, std: f64) -> Result<Self> {
if !mean.is_finite() {
return Err(Error::Domain("mean must be finite"));
}
if !std.is_finite() || std < 0.0 {
return Err(Error::Domain("std must be finite and non-negative"));
}
Ok(Self { mean, std })
}
pub fn point(value: f64) -> Result<Self> {
Self::new(value, 0.0)
}
}
/// A scalar function bundled with its first derivative, both boxed so that
/// different activations can share a single concrete type.
pub struct DifferentiableFunc {
/// The function f: R -> R.
pub f: Box<dyn Fn(f64) -> f64>,
/// Its derivative df/dx, used as the Jacobian in `propagate_linearized`.
pub df: Box<dyn Fn(f64) -> f64>,
}
/// First-order (delta-method) propagation of a Gaussian through `func`.
///
/// The mean is mapped through `f`, and the std is scaled by the absolute
/// value of the derivative evaluated at the input mean. Exact for affine
/// functions; a local linear approximation otherwise.
pub fn propagate_linearized(input: &Gaussian, func: &DifferentiableFunc) -> Gaussian {
    let mapped_mean = (func.f)(input.mean);
    let slope = (func.df)(input.mean);
    let mapped_std = (slope * input.std).abs();
    Gaussian {
        mean: mapped_mean,
        std: mapped_std,
    }
}
/// Propagates a Gaussian through `f` using a one-dimensional unscented
/// transform.
///
/// Three sigma points — the mean and `mean ± sqrt(n + kappa) * std` with
/// `n = 1` — are mapped through `f` and recombined with the standard
/// sigma-point weights. Exact for affine `f`; a higher-order-aware
/// approximation than linearization otherwise.
///
/// `kappa` is the spread parameter; `kappa >= 0` is the usual choice and
/// `kappa = -1` (which makes `n + kappa` zero) is invalid. For `kappa` in
/// `(-1, 0)` the center weight is negative and the recombined variance can
/// go below zero; it is clamped at 0 before the square root so the returned
/// std is never NaN (the previous implementation returned NaN in that case).
pub fn propagate_unscented(input: &Gaussian, f: impl Fn(f64) -> f64, kappa: f64) -> Gaussian {
    let n = 1.0; // state dimensionality (scalar transform)
    let lambda = kappa;
    let spread = (n + lambda).sqrt() * input.std;

    // Sigma points: the mean plus a symmetric pair of offsets.
    let xs = [input.mean, input.mean + spread, input.mean - spread];
    let ys = xs.map(&f);

    // Standard sigma-point weights: center weight lambda/(n+lambda), each
    // symmetric point 1/(2(n+lambda)); the mean weights sum to 1 for n = 1.
    let w0 = lambda / (n + lambda);
    let wi = 1.0 / (2.0 * (n + lambda));

    let output_mean = w0 * ys[0] + wi * ys[1] + wi * ys[2];
    let output_var = w0 * (ys[0] - output_mean).powi(2)
        + wi * (ys[1] - output_mean).powi(2)
        + wi * (ys[2] - output_mean).powi(2);
    Gaussian {
        mean: output_mean,
        // Clamp: a negative center weight (kappa < 0) can drive the
        // recombined variance slightly negative; sqrt would yield NaN.
        std: output_var.max(0.0).sqrt(),
    }
}
/// Applies linearized propagation independently to each (mean, std) pair.
///
/// # Errors
/// Returns `Error::LengthMismatch` when `means` and `stds` differ in length.
pub fn propagate_elementwise(
    means: &[f64],
    stds: &[f64],
    func: &DifferentiableFunc,
) -> Result<(Vec<f64>, Vec<f64>)> {
    if means.len() != stds.len() {
        return Err(Error::LengthMismatch(means.len(), stds.len()));
    }
    let (out_means, out_stds): (Vec<f64>, Vec<f64>) = means
        .iter()
        .zip(stds)
        .map(|(&mean, &std)| {
            let propagated = propagate_linearized(&Gaussian { mean, std }, func);
            (propagated.mean, propagated.std)
        })
        .unzip();
    Ok((out_means, out_stds))
}
pub mod activations {
use super::DifferentiableFunc;
pub fn relu() -> DifferentiableFunc {
DifferentiableFunc {
f: Box::new(|x| x.max(0.0)),
df: Box::new(|x| if x > 0.0 { 1.0 } else { 0.0 }),
}
}
pub fn sigmoid() -> DifferentiableFunc {
DifferentiableFunc {
f: Box::new(|x| 1.0 / (1.0 + (-x).exp())),
df: Box::new(|x| {
let s = 1.0 / (1.0 + (-x).exp());
s * (1.0 - s)
}),
}
}
pub fn tanh_act() -> DifferentiableFunc {
DifferentiableFunc {
f: Box::new(|x| x.tanh()),
df: Box::new(|x| 1.0 - x.tanh().powi(2)),
}
}
pub fn softplus() -> DifferentiableFunc {
DifferentiableFunc {
f: Box::new(|x| (1.0 + x.exp()).ln()),
df: Box::new(|x| 1.0 / (1.0 + (-x).exp())),
}
}
pub fn square() -> DifferentiableFunc {
DifferentiableFunc {
f: Box::new(|x| x * x),
df: Box::new(|x| 2.0 * x),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
// The identity map must pass mean and std through unchanged.
#[test]
fn identity_propagation() {
let g = Gaussian::new(1.0, 0.5).unwrap();
let id = DifferentiableFunc {
f: Box::new(|x| x),
df: Box::new(|_| 1.0),
};
let out = propagate_linearized(&g, &id);
assert!((out.mean - 1.0).abs() < 1e-12);
assert!((out.std - 0.5).abs() < 1e-12);
}
// f(x) = 3x: mean 3*2 = 6, std |3|*1 = 3 — linearization is exact here.
#[test]
fn linear_scaling() {
let g = Gaussian::new(2.0, 1.0).unwrap();
let scale3 = DifferentiableFunc {
f: Box::new(|x| 3.0 * x),
df: Box::new(|_| 3.0),
};
let out = propagate_linearized(&g, &scale3);
assert!((out.mean - 6.0).abs() < 1e-12);
assert!((out.std - 3.0).abs() < 1e-12);
}
// Sigmoid at 0: s(0) = 0.5 and s'(0) = 0.25, so std shrinks to 0.25 * 1.
#[test]
fn sigmoid_narrows_variance() {
let g = Gaussian::new(0.0, 1.0).unwrap();
let sig = activations::sigmoid();
let out = propagate_linearized(&g, &sig);
assert!((out.mean - 0.5).abs() < 1e-12);
assert!((out.std - 0.25).abs() < 1e-12);
}
// Mean well above 0: ReLU is locally the identity (slope 1).
#[test]
fn relu_at_positive_mean() {
let g = Gaussian::new(3.0, 0.5).unwrap();
let relu = activations::relu();
let out = propagate_linearized(&g, &relu);
assert!((out.mean - 3.0).abs() < 1e-12);
assert!((out.std - 0.5).abs() < 1e-12);
}
// The unscented transform is exact for affine maps:
// f(x) = 3x + 1 gives mean 3*2 + 1 = 7 and std 3*1.5 = 4.5.
#[test]
fn unscented_linear_exact() {
let g = Gaussian::new(2.0, 1.5).unwrap();
let out = propagate_unscented(&g, |x| 3.0 * x + 1.0, 1.0);
assert!((out.mean - 7.0).abs() < 1e-10);
assert!((out.std - 4.5).abs() < 1e-10);
}
// For f(x) = x^2 at mean 0: linearization sees slope 0 and predicts mean 0,
// while the unscented transform captures E[x^2] = var = 1.
#[test]
fn unscented_vs_linearized_square() {
let g = Gaussian::new(0.0, 1.0).unwrap();
let lin = propagate_linearized(&g, &activations::square());
let uns = propagate_unscented(&g, |x| x * x, 1.0);
assert!((lin.mean - 0.0).abs() < 1e-12);
assert!(
(uns.mean - 1.0).abs() < 0.1,
"unscented mean = {}",
uns.mean
);
}
// f(x) = 2x applied elementwise doubles every mean and every std.
#[test]
fn elementwise_propagation() {
let means = [1.0, 2.0, 3.0];
let stds = [0.1, 0.2, 0.3];
let scale = DifferentiableFunc {
f: Box::new(|x| 2.0 * x),
df: Box::new(|_| 2.0),
};
let (out_m, out_s) = propagate_elementwise(&means, &stds, &scale).unwrap();
assert!((out_m[0] - 2.0).abs() < 1e-12);
assert!((out_s[1] - 0.4).abs() < 1e-12);
}
// A point mass (std = 0) must stay a point mass after propagation.
#[test]
fn point_mass_propagation() {
let g = Gaussian::point(5.0).unwrap();
let sig = activations::sigmoid();
let out = propagate_linearized(&g, &sig);
assert!(out.std.abs() < 1e-15, "point mass should have zero std");
}
// Mismatched slice lengths must produce an error, not a partial result.
#[test]
fn length_mismatch_error() {
let scale = DifferentiableFunc {
f: Box::new(|x| x),
df: Box::new(|_| 1.0),
};
let r = propagate_elementwise(&[1.0, 2.0], &[1.0], &scale);
assert!(r.is_err());
}
}