/// Mean/variance pair describing a Normal posterior distribution.
///
/// This is a plain `Copy` value: both fields are public and can be read or
/// pattern-matched directly.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct NormalNormalPosterior {
    /// Posterior mean.
    pub mean: f64,
    /// Posterior variance.
    pub var: f64,
}
/// Stand-in variance used whenever a supplied variance cannot be used as-is
/// (for example zero, negative or NaN), so that the derived precision
/// `1 / var` stays finite.
const MIN_OBS_VAR: f64 = 1.0e-12;
/// Computes the posterior of an unknown mean under a Normal prior, given
/// Normal observations whose variances are known and may differ per observation.
///
/// The update is done in precision (inverse-variance) form:
///
/// ```text
/// posterior precision = 1 / prior_var + Σ 1 / obs_var_i
/// posterior mean      = (prior_mean / prior_var + Σ y_i / obs_var_i) / posterior precision
/// posterior var       = 1 / posterior precision
/// ```
///
/// # Arguments
///
/// * `prior_mean` - mean of the prior.
/// * `prior_var` - variance of the prior.
/// * `observations` - `(y, obs_var)` pairs: an observed value and its variance.
///
/// # Variance flooring
///
/// Every variance (prior and observation) is floored at `MIN_OBS_VAR`. Zero,
/// negative and NaN variances are replaced by the floor, and so are positive
/// variances smaller than it. Without the floor a tiny positive variance would
/// produce an infinite precision and a NaN mean, and an observation declared
/// with variance `0.0` would be trusted less than one with variance `1e-13`.
///
/// # Returns
///
/// The posterior mean and variance. With no observations the result is
/// `prior_mean` together with the floored prior variance.
pub fn posterior_normal_normal(
    prior_mean: f64,
    prior_var: f64,
    observations: &[(f64, f64)],
) -> NormalNormalPosterior {
    /// Raises `var` to at least `MIN_OBS_VAR`.
    ///
    /// `f64::max` returns the non-NaN operand, so a NaN variance also
    /// collapses to the floor.
    fn floored(var: f64) -> f64 {
        var.max(MIN_OBS_VAR)
    }

    let prior_var_safe = floored(prior_var);
    if observations.is_empty() {
        // Return the prior directly instead of round-tripping it through
        // 1/x twice, which could perturb the last bits.
        return NormalNormalPosterior {
            mean: prior_mean,
            var: prior_var_safe,
        };
    }

    let prior_precision = 1.0 / prior_var_safe;
    let mut total_precision = prior_precision;
    let mut weighted_sum = prior_precision * prior_mean;
    for &(y, obs_var) in observations {
        // The floor bounds each precision by 1 / MIN_OBS_VAR, keeping the
        // running sums finite for finite inputs.
        let obs_precision = 1.0 / floored(obs_var);
        total_precision += obs_precision;
        weighted_sum += obs_precision * y;
    }

    NormalNormalPosterior {
        mean: weighted_sum / total_precision,
        var: 1.0 / total_precision,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use proptest::prelude::*;

    /// Oracle for the batch update: absorbs the observations one at a time,
    /// performing a two-Gaussian fusion per step.
    ///
    /// Variances that are not strictly positive (including NaN) are replaced
    /// by `MIN_OBS_VAR` before use.
    fn reference_sequential(
        prior_mean: f64,
        prior_var: f64,
        observations: &[(f64, f64)],
    ) -> NormalNormalPosterior {
        let start_var = if prior_var > 0.0 { prior_var } else { MIN_OBS_VAR };
        let (mean, var) = observations.iter().fold(
            (prior_mean, start_var),
            |(mean, var), &(y, obs_var)| {
                let v = if obs_var > 0.0 { obs_var } else { MIN_OBS_VAR };
                let current_precision = 1.0 / var;
                let obs_precision = 1.0 / v;
                let fused_precision = current_precision + obs_precision;
                (
                    (current_precision * mean + obs_precision * y) / fused_precision,
                    1.0 / fused_precision,
                )
            },
        );
        NormalNormalPosterior { mean, var }
    }

    /// No data: the prior comes back as-is.
    #[test]
    fn empty_observations_returns_prior() {
        let NormalNormalPosterior { mean, var } = posterior_normal_normal(2.5, 4.0, &[]);
        assert_relative_eq!(mean, 2.5);
        assert_relative_eq!(var, 4.0);
    }

    /// Unit prior fused with one unit-variance observation: midpoint mean, halved variance.
    #[test]
    fn single_observation_matches_two_gaussian_fusion() {
        let NormalNormalPosterior { mean, var } =
            posterior_normal_normal(0.0, 1.0, &[(2.0, 1.0)]);
        assert_relative_eq!(mean, 1.0);
        assert_relative_eq!(var, 0.5);
    }

    /// Four equally weighted sources (prior + three observations).
    #[test]
    fn three_unit_observations_against_unit_prior() {
        let data = [(1.0, 1.0), (2.0, 1.0), (3.0, 1.0)];
        let NormalNormalPosterior { mean, var } = posterior_normal_normal(0.0, 1.0, &data);
        assert_relative_eq!(mean, 1.5);
        assert_relative_eq!(var, 0.25);
    }

    /// A noisy observation (variance 100) contributes only 0.01 of precision.
    #[test]
    fn heteroscedastic_observations_weight_by_precision() {
        let data = [(10.0, 1.0), (0.0, 100.0)];
        let NormalNormalPosterior { mean, var } = posterior_normal_normal(0.0, 1.0, &data);
        assert_relative_eq!(mean, 10.0 / 2.01, epsilon = 1e-12);
        assert_relative_eq!(var, 1.0 / 2.01, epsilon = 1e-12);
    }

    /// Each extra observation must strictly shrink the posterior variance.
    #[test]
    fn posterior_var_decreases_with_more_observations() {
        /// Posterior variance for a `Normal(0, 1)` prior and the given data.
        fn var_after(observations: &[(f64, f64)]) -> f64 {
            posterior_normal_normal(0.0, 1.0, observations).var
        }

        let zero = var_after(&[]);
        let one = var_after(&[(0.0, 1.0)]);
        let three = var_after(&[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)]);
        assert!(zero > one, "expected {zero} > {one}");
        assert!(one > three, "expected {one} > {three}");
    }

    /// Zero or negative observation variances are replaced, not divided by.
    #[test]
    fn obs_var_le_zero_is_clipped_not_panic() {
        let zero_var = posterior_normal_normal(0.0, 1.0, &[(7.0, 0.0)]);
        assert!(zero_var.mean.is_finite());
        assert!(zero_var.var.is_finite());
        // The clipped observation dominates the unit prior almost completely.
        assert_relative_eq!(zero_var.mean, 7.0, epsilon = 1e-9);
        assert!(zero_var.var < 1e-10);

        let negative_var = posterior_normal_normal(0.0, 1.0, &[(7.0, -1.0)]);
        assert!(negative_var.mean.is_finite());
        assert!(negative_var.var.is_finite());
    }

    /// A zero prior variance is clipped too, so the prior mean dominates.
    #[test]
    fn prior_var_le_zero_is_also_clipped() {
        let NormalNormalPosterior { mean, var } =
            posterior_normal_normal(0.0, 0.0, &[(5.0, 1.0)]);
        assert!(mean.is_finite());
        assert!(var.is_finite());
        assert!(mean.abs() < 1e-9);
    }

    /// Hand-picked cases: batch form versus the sequential oracle.
    #[test]
    fn matches_reference_sequential_implementation() {
        type HandCase = (f64, f64, &'static [(f64, f64)]);
        let cases: &[HandCase] = &[
            (0.0, 1.0, &[(1.0, 1.0)]),
            (5.0, 4.0, &[(3.0, 1.0), (7.0, 2.0)]),
            (-2.0, 0.25, &[(0.0, 1.0), (-1.0, 0.5), (1.0, 0.5)]),
            (
                0.0,
                100.0,
                &[(1.0, 0.01), (2.0, 0.01), (3.0, 0.01), (4.0, 0.01)],
            ),
        ];
        for &(m0, v0, obs) in cases {
            let direct = posterior_normal_normal(m0, v0, obs);
            let ref_seq = reference_sequential(m0, v0, obs);
            assert_relative_eq!(direct.mean, ref_seq.mean, epsilon = 1e-10);
            assert_relative_eq!(direct.var, ref_seq.var, epsilon = 1e-10);
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10_000))]

        /// Batch and one-at-a-time updates agree to within 1e-9.
        #[test]
        fn batch_form_matches_sequential_form(
            prior_mean in -100.0_f64..100.0,
            prior_var in 1e-3_f64..100.0,
            obs in proptest::collection::vec((-100.0_f64..100.0, 1e-3_f64..100.0), 0..16),
        ) {
            let direct = posterior_normal_normal(prior_mean, prior_var, &obs);
            let ref_seq = reference_sequential(prior_mean, prior_var, &obs);

            let mean_gap = (direct.mean - ref_seq.mean).abs();
            prop_assert!(
                mean_gap < 1e-9,
                "mean mismatch: direct={} ref={}",
                direct.mean,
                ref_seq.mean,
            );

            let var_gap = (direct.var - ref_seq.var).abs();
            prop_assert!(
                var_gap < 1e-9,
                "var mismatch: direct={} ref={}",
                direct.var,
                ref_seq.var,
            );
        }

        /// With at least one observation the variance drops below the prior-only variance.
        #[test]
        fn posterior_var_is_non_increasing(
            prior_mean in -10.0_f64..10.0,
            prior_var in 1e-3_f64..10.0,
            obs in proptest::collection::vec((-10.0_f64..10.0, 1e-3_f64..10.0), 1..8),
        ) {
            let baseline_var = posterior_normal_normal(prior_mean, prior_var, &[]).var;
            let updated_var = posterior_normal_normal(prior_mean, prior_var, &obs).var;
            prop_assert!(
                updated_var < baseline_var,
                "var did not decrease: with_obs={} prior_only={}",
                updated_var,
                baseline_var,
            );
        }

        /// The posterior mean lies between the smallest and largest of the
        /// prior mean and the observed values (up to 1e-9 of slack).
        #[test]
        fn posterior_mean_in_convex_hull(
            prior_mean in -10.0_f64..10.0,
            prior_var in 1e-3_f64..10.0,
            obs in proptest::collection::vec((-10.0_f64..10.0, 1e-3_f64..10.0), 1..8),
        ) {
            let post = posterior_normal_normal(prior_mean, prior_var, &obs);

            // Seed both folds with the prior mean so it is part of the hull.
            let lo = obs.iter().map(|&(y, _)| y).fold(prior_mean, f64::min);
            let hi = obs.iter().map(|&(y, _)| y).fold(prior_mean, f64::max);

            prop_assert!(
                post.mean >= lo - 1e-9,
                "post.mean={} below hull min={}",
                post.mean,
                lo,
            );
            prop_assert!(
                post.mean <= hi + 1e-9,
                "post.mean={} above hull max={}",
                post.mean,
                hi,
            );
        }

        /// With no observations the prior is returned bit-for-bit.
        #[test]
        fn empty_obs_is_identity(
            prior_mean in -100.0_f64..100.0,
            prior_var in 1e-3_f64..100.0,
        ) {
            let NormalNormalPosterior { mean, var } =
                posterior_normal_normal(prior_mean, prior_var, &[]);
            prop_assert_eq!(mean, prior_mean);
            prop_assert_eq!(var, prior_var);
        }
    }
}