use crate::BoxError;
use serde::Serialize;
/// A diagonal Gaussian: an independent mean and standard deviation for every dimension.
///
/// Fields are private. Instances are built through [`GaussianBox::new`],
/// [`GaussianBox::from_center_offset`], or deserialization (which routes through `new`),
/// all of which validate their inputs.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct GaussianBox {
    /// Per-dimension means.
    mu: Vec<f32>,
    /// Per-dimension standard deviations; the variance of dimension `i` is `sigma[i]²`.
    sigma: Vec<f32>,
}
impl GaussianBox {
    /// Creates a Gaussian from per-dimension means `mu` and standard deviations `sigma`.
    ///
    /// # Errors
    ///
    /// Checks run in this order; the first failure is returned:
    /// - [`BoxError::DimensionMismatch`] if the lengths differ
    ///   (`expected = mu.len()`, `actual = sigma.len()`);
    /// - [`BoxError::InvalidBounds`] for the first sigma that is non-finite or not strictly
    ///   positive (`min = 0.0`, `max = sigma[dim]`);
    /// - [`BoxError::InvalidBounds`] for the first non-finite mean (`min = max = mu[dim]`).
    pub fn new(mu: Vec<f32>, sigma: Vec<f32>) -> Result<Self, BoxError> {
        if mu.len() != sigma.len() {
            return Err(BoxError::DimensionMismatch {
                expected: mu.len(),
                actual: sigma.len(),
            });
        }
        for (dim, &s) in sigma.iter().enumerate() {
            // `is_finite` also rejects NaN, which `s <= 0.0` alone would let through.
            if !s.is_finite() || s <= 0.0 {
                return Err(BoxError::InvalidBounds {
                    dim,
                    min: 0.0,
                    max: f64::from(s),
                });
            }
        }
        Self::ensure_finite(&mu)?;
        Ok(Self { mu, sigma })
    }

    /// Number of dimensions.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.mu.len()
    }

    /// Per-dimension means.
    #[must_use]
    pub fn mu(&self) -> &[f32] {
        &self.mu
    }

    /// Per-dimension standard deviations.
    #[must_use]
    pub fn sigma(&self) -> &[f32] {
        &self.sigma
    }

    /// Sum of `ln(sigma_i)` over all dimensions (log of the product of the sigmas).
    #[must_use]
    pub fn log_volume(&self) -> f32 {
        self.sigma.iter().map(|s| s.ln()).sum()
    }

    /// Creates a Gaussian from unconstrained parameters.
    ///
    /// `center` becomes the mean. Each `offset` is mapped to a standard deviation by
    /// softplus, `ln(1 + e^offset)`, floored at `1e-7` so the result is always strictly
    /// positive. The map is monotone non-decreasing in the offset.
    ///
    /// # Errors
    ///
    /// - [`BoxError::DimensionMismatch`] if the lengths differ
    ///   (`expected = center.len()`, `actual = offset.len()`);
    /// - [`BoxError::InvalidBounds`] (`min = max = value`) for the first non-finite center
    ///   entry, then for the first non-finite offset entry.
    pub fn from_center_offset(center: Vec<f32>, offset: Vec<f32>) -> Result<Self, BoxError> {
        // Lower bound on produced sigmas; keeps `ln(sigma)` finite.
        const MIN_SIGMA: f32 = 1e-7;

        if center.len() != offset.len() {
            return Err(BoxError::DimensionMismatch {
                expected: center.len(),
                actual: offset.len(),
            });
        }
        Self::ensure_finite(&center)?;
        Self::ensure_finite(&offset)?;

        let sigma = offset
            .iter()
            .map(|&o| {
                if o > 20.0 {
                    // softplus(o) == o to f32 precision here; skipping `exp` avoids overflow.
                    o
                } else {
                    // softplus falls below MIN_SIGMA once o < ~-16.1 (and `exp` underflows to 0
                    // for very negative o), so clamp uniformly instead of only past -20.
                    o.exp().ln_1p().max(MIN_SIGMA)
                }
            })
            .collect();
        Ok(Self { mu: center, sigma })
    }

    /// Returns [`BoxError::InvalidBounds`] (with `min = max = value`) for the first
    /// non-finite entry of `values`, or `Ok(())` if every entry is finite.
    fn ensure_finite(values: &[f32]) -> Result<(), BoxError> {
        for (dim, &v) in values.iter().enumerate() {
            if !v.is_finite() {
                let v = f64::from(v);
                return Err(BoxError::InvalidBounds { dim, min: v, max: v });
            }
        }
        Ok(())
    }
}
impl<'de> serde::Deserialize<'de> for GaussianBox {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(serde::Deserialize)]
struct Raw {
mu: Vec<f32>,
sigma: Vec<f32>,
}
let raw = Raw::deserialize(deserializer)?;
GaussianBox::new(raw.mu, raw.sigma).map_err(serde::de::Error::custom)
}
}
/// KL divergence `D_KL(child || parent)` between two diagonal Gaussians, summed over dimensions.
///
/// Per dimension: `½ · [ (σc/σp)² + ((μp − μc)/σp)² − 1 + 2·ln(σp/σc) ]`.
///
/// Sigmas are clamped below at `1e-7` before use. Both squared ratios divide before
/// squaring. The mean term is `((μp − μc)/σp)²` rather than `(μp − μc)² / σp²`, which
/// overflows `f32` (≈ 1.8e19 squared) and yields `inf / inf = NaN` for large means or sigmas.
///
/// # Errors
///
/// [`BoxError::DimensionMismatch`] if the dimensions differ
/// (`expected = child.dim()`, `actual = parent.dim()`).
pub fn kl_divergence(child: &GaussianBox, parent: &GaussianBox) -> Result<f32, BoxError> {
    if child.dim() != parent.dim() {
        return Err(BoxError::DimensionMismatch {
            expected: child.dim(),
            actual: parent.dim(),
        });
    }
    const EPS: f32 = 1e-7;
    let mut sum = 0.0f32;
    for i in 0..child.dim() {
        let sc = child.sigma[i].max(EPS);
        let sp = parent.sigma[i].max(EPS);
        let ratio_sq = (sc / sp).powi(2);
        // Divide before squaring so the term stays finite when |Δμ| or σp is huge.
        let mean_sq = ((parent.mu[i] - child.mu[i]) / sp).powi(2);
        let log_ratio = 2.0 * (sp / sc).ln();
        sum += ratio_sq + mean_sq - 1.0 + log_ratio;
    }
    Ok(0.5 * sum)
}
/// Bhattacharyya distance between two diagonal Gaussians, summed over dimensions.
///
/// Per dimension, with pooled variance `σ² = (σ1² + σ2²) / 2`:
/// `D_B = (μ1 − μ2)² / (8σ²) + ½ · ln(σ² / (σ1·σ2))`.
///
/// Sigmas are clamped below at `1e-7`. Every quantity is computed relative to
/// `max(σ1, σ2)`, so no sigma is squared directly. This keeps the result finite for sigmas
/// too large to square in `f32` (≈ 1.8e19); there the direct form gives
/// `ln(inf) − ln(inf) = NaN`. It also makes identical inputs give exactly `0`.
///
/// # Errors
///
/// [`BoxError::DimensionMismatch`] if the dimensions differ
/// (`expected = a.dim()`, `actual = b.dim()`).
pub fn bhattacharyya_distance(a: &GaussianBox, b: &GaussianBox) -> Result<f32, BoxError> {
    if a.dim() != b.dim() {
        return Err(BoxError::DimensionMismatch {
            expected: a.dim(),
            actual: b.dim(),
        });
    }
    const EPS: f32 = 1e-7;
    let mut term1 = 0.0f32;
    let mut term2 = 0.0f32;
    for i in 0..a.dim() {
        let s1 = a.sigma[i].max(EPS);
        let s2 = b.sigma[i].max(EPS);
        // Normalize by the larger sigma: one ratio is exactly 1, the other is in (0, 1].
        let scale = s1.max(s2);
        let r1 = s1 / scale;
        let r2 = s2 / scale;
        // Pooled variance divided by scale²; always within [0.5, 1].
        let pooled = (r1 * r1 + r2 * r2) / 2.0;
        let dm = (a.mu[i] - b.mu[i]) / scale;
        term1 += dm * dm / pooled;
        // ln(s / scale). The ratio is 0 only if it underflowed (tiny s vs. scale > ~7e37);
        // then fall back to a difference of logs, which stays finite.
        let ln_ratio = |r: f32, s: f32| if r > 0.0 { r.ln() } else { s.ln() - scale.ln() };
        // ln(σ²) − ln σ1 − ln σ2, expressed in scale-relative terms.
        term2 += pooled.ln() - (ln_ratio(r1, s1) + ln_ratio(r2, s2));
    }
    Ok(0.125 * term1 + 0.5 * term2)
}
/// Bhattacharyya coefficient `BC = exp(−D_B)`, where `D_B` is [`bhattacharyya_distance`].
///
/// Mathematically in `[0, 1]`: 1 for identical distributions, tending to 0 as they separate.
///
/// # Errors
///
/// Propagates any error returned by [`bhattacharyya_distance`].
pub fn bhattacharyya_coefficient(a: &GaussianBox, b: &GaussianBox) -> Result<f32, BoxError> {
    let distance = bhattacharyya_distance(a, b)?;
    Ok(f32::exp(-distance))
}
/// Mean per-dimension squared-hinge penalty on variances that fall short of `min_var`.
///
/// Each dimension contributes `max(min_var − sigma², 0)²`. The contributions are
/// averaged over dimensions. A zero-dimensional Gaussian yields `0.0`.
#[must_use]
pub fn volume_regularization(g: &GaussianBox, min_var: f32) -> f32 {
    if g.sigma.is_empty() {
        return 0.0;
    }
    let squared_hinge = |s: f32| {
        let shortfall = (min_var - s * s).max(0.0);
        shortfall * shortfall
    };
    let total: f32 = g.sigma.iter().copied().map(squared_hinge).sum();
    total / g.sigma.len() as f32
}
/// Mean per-dimension linear-hinge penalty on variances that exceed `max_var`.
///
/// Each dimension contributes `max(sigma² − max_var, 0)`. The contributions are
/// averaged over dimensions. A zero-dimensional Gaussian yields `0.0`.
#[must_use]
pub fn sigma_ceiling_loss(g: &GaussianBox, max_var: f32) -> f32 {
    match g.sigma.len() {
        0 => 0.0,
        dims => {
            let excess: f32 = g
                .sigma
                .iter()
                .map(|&s| (s * s - max_var).max(0.0))
                .sum();
            excess / dims as f32
        }
    }
}
// Unit tests plus proptest-based property tests for `GaussianBox` and the
// divergence / regularization functions in this file.
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
// Strategy for a valid `dim`-dimensional Gaussian: means drawn from [-100, 100),
// sigmas from [0.01, 100), so `GaussianBox::new` always succeeds.
fn arb_gaussian(dim: usize) -> impl Strategy<Value = GaussianBox> {
let mus = prop::collection::vec(-100.0f32..100.0, dim);
let sigmas = prop::collection::vec(0.01f32..100.0, dim);
(mus, sigmas).prop_map(|(mu, sigma)| GaussianBox::new(mu, sigma).unwrap())
}
// Two independently drawn Gaussians of the same dimension.
fn arb_gaussian_pair(dim: usize) -> impl Strategy<Value = (GaussianBox, GaussianBox)> {
(arb_gaussian(dim), arb_gaussian(dim))
}
proptest! {
// KL divergence is non-negative (small slack for f32 rounding).
#[test]
fn prop_kl_nonnegative(
(a, b) in arb_gaussian_pair(8)
) {
let kl = kl_divergence(&a, &b).unwrap();
prop_assert!(kl >= -1e-5, "KL divergence should be non-negative, got {}", kl);
}
// KL(g || g) is zero.
#[test]
fn prop_kl_identical_is_zero(
g in arb_gaussian(8)
) {
let kl = kl_divergence(&g, &g).unwrap();
prop_assert!(kl.abs() < 1e-4, "KL(g, g) should be 0, got {}", kl);
}
// NOTE(review): despite its name, this only checks that KL is non-negative in both
// directions; it does not assert that KL(a || b) differs from KL(b || a).
#[test]
fn prop_kl_asymmetric(
(a, b) in arb_gaussian_pair(4)
) {
let kl_ab = kl_divergence(&a, &b).unwrap();
let kl_ba = kl_divergence(&b, &a).unwrap();
prop_assert!(kl_ab >= -1e-5);
prop_assert!(kl_ba >= -1e-5);
}
// The Bhattacharyya coefficient is symmetric in its arguments.
#[test]
fn prop_bc_symmetric(
(a, b) in arb_gaussian_pair(8)
) {
let bc_ab = bhattacharyya_coefficient(&a, &b).unwrap();
let bc_ba = bhattacharyya_coefficient(&b, &a).unwrap();
prop_assert!(
(bc_ab - bc_ba).abs() < 1e-5,
"BC should be symmetric: {} != {}", bc_ab, bc_ba
);
}
// BC stays within [0, 1] up to rounding slack.
#[test]
fn prop_bc_in_unit_interval(
(a, b) in arb_gaussian_pair(8)
) {
let bc = bhattacharyya_coefficient(&a, &b).unwrap();
prop_assert!((-1e-6..=1.0 + 1e-6).contains(&bc),
"BC should be in [0, 1], got {}", bc);
}
// BC(g, g) is one.
#[test]
fn prop_bc_identical_is_one(
g in arb_gaussian(8)
) {
let bc = bhattacharyya_coefficient(&g, &g).unwrap();
prop_assert!((bc - 1.0).abs() < 1e-4,
"BC(g, g) should be 1.0, got {}", bc);
}
// Loss is non-negative for any target, including negative targets (every deficit clamps to 0).
#[test]
fn prop_volume_regularization_nonneg(
g in arb_gaussian(8),
target in -10.0f32..10.0,
) {
let loss = volume_regularization(&g, target);
prop_assert!(loss >= 0.0, "Volume regularization should be non-negative, got {}", loss);
}
// `from_center_offset` yields strictly positive sigmas across a wide offset range.
#[test]
fn prop_from_center_offset_positive_sigma(
center in prop::collection::vec(-100.0f32..100.0, 8),
offset in prop::collection::vec(-50.0f32..50.0, 8),
) {
let g = GaussianBox::from_center_offset(center, offset).unwrap();
for (i, &s) in g.sigma().iter().enumerate() {
prop_assert!(s > 0.0, "sigma[{}] should be positive, got {}", i, s);
}
}
// `new` rejects all-zero and all-negative sigmas for every dimension count in 1..=8.
#[test]
fn prop_new_rejects_nonpositive_sigma(
mu in prop::collection::vec(-10.0f32..10.0, 1..=8usize),
) {
let dim = mu.len();
let sigma = vec![0.0f32; dim];
prop_assert!(GaussianBox::new(mu.clone(), sigma).is_err());
let sigma_neg = vec![-1.0f32; dim];
prop_assert!(GaussianBox::new(mu, sigma_neg).is_err());
}
// The Bhattacharyya distance is non-negative.
#[test]
fn prop_bhattacharyya_distance_nonneg(
(a, b) in arb_gaussian_pair(8)
) {
let bd = bhattacharyya_distance(&a, &b).unwrap();
prop_assert!(bd >= -1e-5, "BD should be non-negative, got {bd}");
}
// The ceiling loss is non-negative.
#[test]
fn prop_sigma_ceiling_nonneg(
g in arb_gaussian(8),
max_var in 0.01f32..100.0,
) {
let loss = sigma_ceiling_loss(&g, max_var);
prop_assert!(loss >= 0.0, "sigma_ceiling_loss should be non-negative, got {loss}");
}
// Linear-hinge check: g2's variance is 2*base^2 - max_var, so its excess over max_var
// is exactly twice g1's. With base_sigma >= 1.1 and max_var < 1, loss1 > 0 and the
// early-return guard below never triggers.
#[test]
fn sigma_ceiling_is_linear_hinge(
base_sigma in 1.1f32..10.0,
max_var in 0.01f32..1.0,
) {
let g1 = GaussianBox::new(vec![0.0], vec![base_sigma]).unwrap();
let doubled_var = 2.0 * base_sigma * base_sigma - max_var;
if doubled_var <= 0.0 {
return Ok(());
}
let g2 = GaussianBox::new(vec![0.0], vec![doubled_var.sqrt()]).unwrap();
let loss1 = sigma_ceiling_loss(&g1, max_var);
let loss2 = sigma_ceiling_loss(&g2, max_var);
let ratio = loss2 / loss1;
prop_assert!(
(ratio - 2.0).abs() < 0.01,
"linear hinge ratio should be ~2.0, got {ratio} (loss1={loss1}, loss2={loss2})"
);
}
// Cross-checks the loss against an inline recomputation of the mean squared hinge.
#[test]
fn prop_volume_regularization_is_per_dim_squared_hinge(
g in arb_gaussian(8),
min_var in 0.01f32..10.0,
) {
let loss = volume_regularization(&g, min_var);
prop_assert!(loss >= 0.0, "volume_regularization should be non-negative, got {loss}");
let d = g.sigma.len() as f32;
let expected: f32 = g.sigma.iter()
.map(|&s| { let deficit = (min_var - s * s).max(0.0); deficit * deficit })
.sum::<f32>() / d;
prop_assert!(
(loss - expected).abs() < 1e-3,
"volume_regularization mismatch: {loss} vs expected {expected}"
);
}
}
// Identical 256-d unit Gaussians: KL ~ 0 and BC ~ 1 without accumulated drift.
#[test]
fn test_high_dim_256() {
let a = GaussianBox::new(vec![0.0; 256], vec![1.0; 256]).unwrap();
let b = GaussianBox::new(vec![0.0; 256], vec![1.0; 256]).unwrap();
let kl = kl_divergence(&a, &b).unwrap();
assert!(
kl.abs() < 1e-4,
"KL of identical 256-d unit Gaussians: {kl}"
);
let bc = bhattacharyya_coefficient(&a, &b).unwrap();
assert!((bc - 1.0).abs() < 1e-4, "BC of identical 256-d: {bc}");
}
// Same KL check in 1024 dimensions, with a looser tolerance.
#[test]
fn test_high_dim_1024() {
let a = GaussianBox::new(vec![0.0; 1024], vec![1.0; 1024]).unwrap();
let b = GaussianBox::new(vec![0.0; 1024], vec![1.0; 1024]).unwrap();
let kl = kl_divergence(&a, &b).unwrap();
assert!(
kl.abs() < 1e-3,
"KL of identical 1024-d unit Gaussians: {kl}"
);
}
// 1-d sanity check: distinct Gaussians give KL > 0 and BC strictly inside (0, 1).
#[test]
fn test_single_dim() {
let a = GaussianBox::new(vec![3.0], vec![0.5]).unwrap();
let b = GaussianBox::new(vec![5.0], vec![1.0]).unwrap();
let kl = kl_divergence(&a, &b).unwrap();
assert!(kl > 0.0);
let bc = bhattacharyya_coefficient(&a, &b).unwrap();
assert!(bc > 0.0 && bc < 1.0);
}
// A child sigma of 1e-6 (above the 1e-7 clamp in kl_divergence) still gives finite, non-negative KL.
#[test]
fn test_very_small_sigma_stability() {
let a = GaussianBox::new(vec![0.0], vec![1e-6]).unwrap();
let b = GaussianBox::new(vec![0.0], vec![1.0]).unwrap();
let kl = kl_divergence(&a, &b).unwrap();
assert!(
kl.is_finite(),
"KL should be finite with small sigma, got {kl}"
);
assert!(kl >= 0.0);
}
// A child sigma of 1e6 against a unit parent still gives finite KL.
#[test]
fn test_very_large_sigma_stability() {
let a = GaussianBox::new(vec![0.0], vec![1e6]).unwrap();
let b = GaussianBox::new(vec![0.0], vec![1.0]).unwrap();
let kl = kl_divergence(&a, &b).unwrap();
assert!(
kl.is_finite(),
"KL should be finite with large sigma, got {kl}"
);
}
// Means 1000 apart in every dimension with unit sigmas: BC is effectively 0.
#[test]
fn test_large_mu_difference_bc_near_zero() {
let a = GaussianBox::new(vec![0.0; 8], vec![1.0; 8]).unwrap();
let mu_far: Vec<f32> = vec![1000.0; 8];
let b = GaussianBox::new(mu_far, vec![1.0; 8]).unwrap();
let bc = bhattacharyya_coefficient(&a, &b).unwrap();
assert!(
bc < 1e-10,
"BC for very distant Gaussians should be ~0, got {bc}"
);
}
// Compares against the hand-expanded closed form
// 0.5 * [(sc/sp)^2 + (mp - mc)^2 / sp^2 - 1 + 2 ln(sp/sc)], summed per dimension.
#[test]
fn test_kl_divergence_formula_matches_standard() {
let child = GaussianBox::new(vec![1.0, 2.0], vec![0.5, 1.5]).unwrap();
let parent = GaussianBox::new(vec![3.0, 0.0], vec![2.0, 1.0]).unwrap();
let kl = kl_divergence(&child, &parent).unwrap();
let dim0 = 0.5
* ((0.5_f32 / 2.0).powi(2) + (3.0 - 1.0_f32).powi(2) / 4.0 - 1.0
+ 2.0 * (2.0_f32 / 0.5).ln());
let dim1 = 0.5
* ((1.5_f32 / 1.0).powi(2) + (0.0 - 2.0_f32).powi(2) / 1.0 - 1.0
+ 2.0 * (1.0_f32 / 1.5).ln());
let expected = dim0 + dim1;
assert!(
(kl - expected).abs() < 1e-4,
"KL formula mismatch: got {kl}, expected {expected}"
);
}
// BC of a distribution with itself is 1 to a tight tolerance.
#[test]
fn test_bhattacharyya_coefficient_identical_distributions() {
let g = GaussianBox::new(vec![1.5, -2.3, 0.7], vec![0.3, 2.1, 1.0]).unwrap();
let bc = bhattacharyya_coefficient(&g, &g).unwrap();
assert!((bc - 1.0).abs() < 1e-6, "BC(N,N) should be 1.0, got {bc}");
}
proptest! {
// BC range check in 16 dimensions.
#[test]
fn test_bhattacharyya_coefficient_range_invariant(
(a, b) in arb_gaussian_pair(16)
) {
let bc = bhattacharyya_coefficient(&a, &b).unwrap();
prop_assert!(bc >= -1e-7, "BC below 0: {bc}");
prop_assert!(bc <= 1.0 + 1e-6, "BC above 1: {bc}");
}
}
// Only sigma = 5 exceeds max_var = 1: (25 - 1) / 3 dimensions = 8.
#[test]
fn test_sigma_ceiling_loss_hand_computed() {
let g = GaussianBox::new(vec![0.0; 3], vec![0.05, 0.3, 5.0]).unwrap();
let ceil_loss = sigma_ceiling_loss(&g, 1.0);
let expected_ceil = 24.0 / 3.0;
assert!(
(ceil_loss - expected_ceil).abs() < 1e-3,
"ceiling loss: expected {expected_ceil}, got {ceil_loss}"
);
}
// Offsets below -20 must still yield sigmas at or above the 1e-7 floor.
#[test]
fn test_softplus_floor_not_too_small() {
let g = GaussianBox::from_center_offset(vec![0.0; 4], vec![-100.0, -50.0, -25.0, -21.0])
.unwrap();
for (i, &s) in g.sigma().iter().enumerate() {
assert!(
s >= 1e-7,
"sigma[{i}] = {s} is below 1e-7 floor for extreme negative offset"
);
}
}
// A 1e6 sigma ratio in either direction keeps KL finite and non-negative.
#[test]
fn test_kl_extreme_sigma_ratio() {
let a = GaussianBox::new(vec![0.0, 0.0], vec![1e-3, 1e-3]).unwrap();
let b = GaussianBox::new(vec![0.0, 0.0], vec![1e3, 1e3]).unwrap();
let kl_ab = kl_divergence(&a, &b).unwrap();
let kl_ba = kl_divergence(&b, &a).unwrap();
assert!(kl_ab.is_finite(), "KL(small||large) is not finite: {kl_ab}");
assert!(kl_ba.is_finite(), "KL(large||small) is not finite: {kl_ba}");
assert!(kl_ab >= 0.0);
assert!(kl_ba >= 0.0);
}
// Zero loss when every variance is at least min_var; with min_var = 5 the
// deficits are (5 - 4, 0, 5 - 2.25) = (1, 0, 2.75), squared and averaged.
#[test]
fn test_volume_regularization_zero_above_threshold() {
let g = GaussianBox::new(vec![0.0; 3], vec![2.0, 3.0, 1.5]).unwrap();
let loss = volume_regularization(&g, 1.0);
assert!(
loss.abs() < 1e-10,
"loss should be 0 when all variances >= min_var, got {loss}"
);
let loss_nonzero = volume_regularization(&g, 5.0);
let expected = (1.0 + 0.0 + 2.75 * 2.75) / 3.0;
assert!(
(loss_nonzero - expected).abs() < 1e-4,
"loss should be {expected}, got {loss_nonzero}"
);
}
// Constructor validation basics.
#[test]
fn test_gaussian_new_valid() {
let g = GaussianBox::new(vec![0.0, 1.0], vec![1.0, 2.0]).unwrap();
assert_eq!(g.dim(), 2);
}
#[test]
fn test_gaussian_new_dim_mismatch() {
let err = GaussianBox::new(vec![0.0], vec![1.0, 2.0]).unwrap_err();
assert!(matches!(err, BoxError::DimensionMismatch { .. }));
}
#[test]
fn test_gaussian_new_negative_sigma() {
let err = GaussianBox::new(vec![0.0], vec![-1.0]).unwrap_err();
assert!(matches!(err, BoxError::InvalidBounds { .. }));
}
// KL of a distribution with itself is 0.
#[test]
fn test_kl_identical() {
let g = GaussianBox::new(vec![0.0; 4], vec![1.0; 4]).unwrap();
let kl = kl_divergence(&g, &g).unwrap();
assert!(
(kl).abs() < 1e-6,
"KL of identical Gaussians should be 0, got {kl}"
);
}
// For concentric 1-d Gaussians, D_KL(narrow || wide) < D_KL(wide || narrow).
#[test]
fn test_kl_asymmetric() {
let child = GaussianBox::new(vec![0.0], vec![0.5]).unwrap();
let parent = GaussianBox::new(vec![0.0], vec![2.0]).unwrap();
let kl_cp = kl_divergence(&child, &parent).unwrap();
let kl_pc = kl_divergence(&parent, &child).unwrap();
assert!(kl_cp < kl_pc, "D_KL(narrow||wide) < D_KL(wide||narrow)");
}
// child N(0, 1), parent N(1, 2): 0.5 * (1/4 + 1/4 - 1 + 2 ln 2).
#[test]
fn test_kl_known_value() {
let child = GaussianBox::new(vec![0.0], vec![1.0]).unwrap();
let parent = GaussianBox::new(vec![1.0], vec![2.0]).unwrap();
let kl = kl_divergence(&child, &parent).unwrap();
let expected = 0.5 * (0.25 + 0.25 - 1.0 + 2.0 * 2.0_f32.ln());
assert!(
(kl - expected).abs() < 1e-5,
"expected {expected}, got {kl}"
);
}
#[test]
fn test_bhattacharyya_identical() {
let g = GaussianBox::new(vec![1.0, 2.0], vec![0.5, 1.5]).unwrap();
let bc = bhattacharyya_coefficient(&g, &g).unwrap();
assert!(
(bc - 1.0).abs() < 1e-6,
"BC of identical Gaussians should be 1.0, got {bc}"
);
}
#[test]
fn test_bhattacharyya_symmetric() {
let a = GaussianBox::new(vec![0.0, 0.0], vec![1.0, 1.0]).unwrap();
let b = GaussianBox::new(vec![2.0, 1.0], vec![0.5, 2.0]).unwrap();
let bc_ab = bhattacharyya_coefficient(&a, &b).unwrap();
let bc_ba = bhattacharyya_coefficient(&b, &a).unwrap();
assert!(
(bc_ab - bc_ba).abs() < 1e-6,
"BC should be symmetric: {bc_ab} != {bc_ba}"
);
}
// Means 100 apart with sigma 0.1: BC is effectively 0.
#[test]
fn test_bhattacharyya_distant() {
let a = GaussianBox::new(vec![0.0], vec![0.1]).unwrap();
let b = GaussianBox::new(vec![100.0], vec![0.1]).unwrap();
let bc = bhattacharyya_coefficient(&a, &b).unwrap();
assert!(bc < 1e-10, "distant Gaussians should have BC ~0, got {bc}");
}
// Center is copied as the mean; softplus(0) = ln 2 and softplus(0.5) = ln(1 + e^0.5).
#[test]
fn test_from_center_offset() {
let g = GaussianBox::from_center_offset(vec![1.0, -1.0], vec![0.0, 0.5]).unwrap();
assert_eq!(g.mu(), [1.0, -1.0]);
assert!((g.sigma()[0] - std::f32::consts::LN_2).abs() < 0.01);
let expected_sp_half = (0.5_f32.exp() + 1.0).ln();
assert!((g.sigma()[1] - expected_sp_half).abs() < 0.01);
}
// Unit variances: min_var 0.5 gives 0; min_var 2 gives a deficit of 1 per dimension, so mean 1.
#[test]
fn test_volume_regularization() {
let g = GaussianBox::new(vec![0.0; 4], vec![1.0; 4]).unwrap();
let loss = volume_regularization(&g, 0.5);
assert!(
loss.abs() < 1e-6,
"unit Gaussian with min_var=0.5 should have loss=0, got {loss}"
);
let loss2 = volume_regularization(&g, 2.0);
assert!((loss2 - 1.0).abs() < 1e-6, "expected 1.0, got {loss2}");
}
// All variances (0.25, 0.64) below max_var = 1: loss is 0.
#[test]
fn test_sigma_ceiling_loss_below_threshold() {
let g = GaussianBox::new(vec![0.0, 0.0], vec![0.5, 0.8]).unwrap();
let loss = sigma_ceiling_loss(&g, 1.0);
assert!(
loss.abs() < 1e-10,
"all below threshold: expected 0, got {loss}"
);
}
// Variances (4, 0.25) with max_var = 1: (3 + 0) / 2 = 1.5.
#[test]
fn test_sigma_ceiling_loss_above_threshold() {
let g = GaussianBox::new(vec![0.0, 0.0], vec![2.0, 0.5]).unwrap();
let loss = sigma_ceiling_loss(&g, 1.0);
assert!((loss - 1.5).abs() < 1e-5, "expected 1.5, got {loss}");
}
// Non-finite inputs are rejected by both constructors.
#[test]
fn nan_sigma_returns_err() {
let result = GaussianBox::new(vec![0.0], vec![f32::NAN]);
assert!(result.is_err(), "NaN sigma should be rejected");
}
#[test]
fn nan_mu_returns_err() {
let result = GaussianBox::new(vec![f32::NAN], vec![1.0]);
assert!(result.is_err(), "NaN mu should be rejected");
}
#[test]
fn from_center_offset_nan_center_returns_err() {
let result = GaussianBox::from_center_offset(vec![f32::NAN, 0.0], vec![1.0, 1.0]);
assert!(result.is_err(), "NaN center should be rejected");
}
#[test]
fn from_center_offset_nan_offset_returns_err() {
let result = GaussianBox::from_center_offset(vec![0.0, 0.0], vec![1.0, f32::NAN]);
assert!(result.is_err(), "NaN offset should be rejected");
}
}