#![allow(missing_docs)]
use crate::BoxError;
use serde::{Deserialize, Serialize};
/// A center vector paired with a square factor matrix stored in a flat buffer.
///
/// The helpers in this file read the factor as lower-triangular: entries above
/// the diagonal are treated as zero and diagonal entries are exponentiated
/// before use, i.e. the diagonal is kept in log space.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Ellipsoid {
/// Center of the ellipsoid; its length defines `dim`.
mu: Vec<f32>,
/// Row-major `dim * dim` factor, indexed as `cholesky[i * dim + j]`.
/// Diagonal slots hold the logarithm of the actual diagonal value.
cholesky: Vec<f32>,
/// Number of dimensions; the constructors set this to `mu.len()`.
dim: usize,
}
impl Ellipsoid {
    /// Creates an ellipsoid from a center and a full row-major factor matrix.
    ///
    /// `cholesky` must contain `dim * dim` entries with `dim = mu.len()`.
    /// Diagonal entries are stored in log space (they are exponentiated when
    /// the factor is used); entries above the diagonal are never read by the
    /// numeric helpers in this file.
    ///
    /// # Errors
    ///
    /// * [`BoxError::DimensionMismatch`] if `cholesky.len() != dim * dim`.
    /// * [`BoxError::InvalidBounds`] if any entry of `mu` or `cholesky` is not
    ///   finite (for `cholesky` the reported `dim` is the row of the entry),
    ///   or if a log-diagonal entry is `<= -50.0`.
    pub fn new(mu: Vec<f32>, cholesky: Vec<f32>) -> Result<Self, BoxError> {
        let dim = mu.len();
        if cholesky.len() != dim * dim {
            return Err(BoxError::DimensionMismatch {
                expected: dim * dim,
                actual: cholesky.len(),
            });
        }
        Self::check_all_finite(&mu, |i| i)?;
        // The closure only runs for an existing entry, so `dim` is non-zero here.
        Self::check_all_finite(&cholesky, |flat| flat / dim)?;

        // Reject log-diagonal values so small that exp() is effectively zero.
        let too_small = (0..dim)
            .map(|i| (i, cholesky[i * dim + i]))
            .find(|&(_, diag)| diag <= -50.0);
        if let Some((i, diag)) = too_small {
            let diag = f64::from(diag);
            return Err(BoxError::InvalidBounds {
                dim: i,
                min: diag,
                max: diag,
            });
        }
        Ok(Self { mu, cholesky, dim })
    }

    /// Creates an axis-aligned ellipsoid from a center and a log-diagonal.
    ///
    /// The resulting factor is zero everywhere except the diagonal, which is
    /// set to `log_diag` verbatim. Unlike [`Ellipsoid::new`], no lower bound is
    /// applied to the log-diagonal values; they only have to be finite.
    ///
    /// # Errors
    ///
    /// * [`BoxError::DimensionMismatch`] if `log_diag.len() != mu.len()`.
    /// * [`BoxError::InvalidBounds`] if any entry of `mu` or `log_diag` is not
    ///   finite.
    pub fn from_log_diagonal(mu: Vec<f32>, log_diag: Vec<f32>) -> Result<Self, BoxError> {
        let dim = mu.len();
        if log_diag.len() != dim {
            return Err(BoxError::DimensionMismatch {
                expected: dim,
                actual: log_diag.len(),
            });
        }
        Self::check_all_finite(&mu, |i| i)?;
        Self::check_all_finite(&log_diag, |i| i)?;

        let mut cholesky = vec![0.0f32; dim * dim];
        // Stepping by `dim + 1` through a row-major square visits the diagonal.
        for (slot, &value) in cholesky.iter_mut().step_by(dim + 1).zip(&log_diag) {
            *slot = value;
        }
        Ok(Self { mu, cholesky, dim })
    }

    /// Returns `InvalidBounds` for the first non-finite entry of `values`.
    ///
    /// `reported_dim` maps the flat index of the offending entry to the
    /// dimension recorded in the error.
    fn check_all_finite(
        values: &[f32],
        reported_dim: impl Fn(usize) -> usize,
    ) -> Result<(), BoxError> {
        match values.iter().position(|v| !v.is_finite()) {
            Some(idx) => {
                let offending = f64::from(values[idx]);
                Err(BoxError::InvalidBounds {
                    dim: reported_dim(idx),
                    min: offending,
                    max: offending,
                })
            }
            None => Ok(()),
        }
    }

    /// Number of dimensions.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Center vector.
    #[must_use]
    pub fn mu(&self) -> &[f32] {
        &self.mu
    }

    /// Raw row-major factor buffer, with the diagonal in log space.
    #[must_use]
    pub fn cholesky(&self) -> &[f32] {
        &self.cholesky
    }

    /// Twice the sum of the stored log-diagonal entries.
    ///
    /// With the diagonal stored as `ln(L_ii)` this is `ln det(L L^T)`.
    #[must_use]
    pub fn log_det(&self) -> f32 {
        let log_diag_total =
            (0..self.dim).fold(0.0f32, |acc, i| acc + self.cholesky[i * self.dim + i]);
        2.0 * log_diag_total
    }

    /// Half of [`Ellipsoid::log_det`], i.e. the sum of the log-diagonal.
    #[must_use]
    pub fn log_volume(&self) -> f32 {
        0.5 * self.log_det()
    }

    /// Mutable view of the center.
    ///
    /// Values written through this slice are not re-validated.
    pub fn mu_mut(&mut self) -> &mut [f32] {
        &mut self.mu
    }

    /// Copies the stored log-diagonal into a new vector of length `dim`.
    #[must_use]
    pub fn log_diag(&self) -> Vec<f32> {
        (0..self.dim)
            .map(|i| self.cholesky[i * self.dim + i])
            .collect()
    }

    /// Overwrites the log-diagonal, clamping every value to `[-10.0, 5.0]`.
    ///
    /// Entries are applied positionally. If `log_diag` is shorter than `dim`
    /// only the leading diagonal entries change; entries beyond `dim` are
    /// ignored instead of indexing out of bounds. NaN inputs are skipped and
    /// leave the current value in place, because `f32::clamp` would otherwise
    /// store the NaN and break the finite invariant the constructors enforce.
    pub fn set_log_diag(&mut self, log_diag: &[f32]) {
        let dim = self.dim;
        let diagonal = self.cholesky.iter_mut().step_by(dim + 1).take(dim);
        for (slot, &value) in diagonal.zip(log_diag) {
            if !value.is_nan() {
                *slot = value.clamp(-10.0, 5.0);
            }
        }
    }
}
/// Reads entry `(i, j)` of the lower-triangular factor stored in `cholesky`.
///
/// Entries above the diagonal are reported as `0.0` without touching the
/// buffer, diagonal entries are exponentiated (they are stored in log space),
/// and entries below the diagonal are returned as stored.
fn get_l(cholesky: &[f32], dim: usize, i: usize, j: usize) -> f32 {
    use std::cmp::Ordering;

    match j.cmp(&i) {
        Ordering::Greater => 0.0,
        Ordering::Equal => cholesky[i * dim + j].exp(),
        Ordering::Less => cholesky[i * dim + j],
    }
}
/// Solves `L x = b` by forward substitution, where `L` is read through `get_l`.
///
/// A row whose diagonal magnitude is below `1e-8` is treated as singular and
/// its solution component is set to `0.0` instead of dividing.
fn solve_lower(cholesky: &[f32], dim: usize, b: &[f32]) -> Vec<f32> {
    let mut solution = vec![0.0f32; dim];
    for row in 0..dim {
        // Contribution of the components that have already been solved.
        let known = solution[..row]
            .iter()
            .enumerate()
            .fold(0.0f32, |acc, (col, &value)| {
                acc + get_l(cholesky, dim, row, col) * value
            });
        let pivot = get_l(cholesky, dim, row, row);
        solution[row] = if pivot.abs() < 1e-8 {
            0.0
        } else {
            (b[row] - known) / pivot
        };
    }
    solution
}
/// Squared norm of the forward-substitution solution of `L y = x`.
///
/// With `L` the lower factor of a covariance `L L^T`, this equals
/// `x^T (L L^T)^{-1} x`.
fn mahalanobis_sq(cholesky: &[f32], dim: usize, x: &[f32]) -> f32 {
    solve_lower(cholesky, dim, x)
        .into_iter()
        .map(|component| component * component)
        .sum()
}
fn trace_inv_product(cholesky_a: &[f32], cholesky_b: &[f32], dim: usize) -> f32 {
let mut total = 0.0f32;
for j in 0..dim {
let col_b: Vec<f32> = (0..dim).map(|i| get_l(cholesky_b, dim, i, j)).collect();
let y = solve_lower(cholesky_a, dim, &col_b);
total += y.iter().map(|v| v * v).sum::<f32>();
}
total
}
/// Builds the lower triangle of `L_a L_a^T + L_b L_b^T` and returns its factor
/// as produced by `cholesky_decompose`.
fn cholesky_sum(cholesky_a: &[f32], cholesky_b: &[f32], dim: usize) -> Vec<f32> {
    let mut gram = vec![0.0f32; dim * dim];
    for row in 0..dim {
        for col in 0..=row {
            // Both factors are lower-triangular, so terms beyond `col` vanish.
            gram[row * dim + col] = (0..=col).fold(0.0f32, |acc, k| {
                let from_a = get_l(cholesky_a, dim, row, k) * get_l(cholesky_a, dim, col, k);
                let from_b = get_l(cholesky_b, dim, row, k) * get_l(cholesky_b, dim, col, k);
                (acc + from_a) + from_b
            });
        }
    }
    cholesky_decompose(&gram, dim)
}
/// Factors the lower triangle of the row-major matrix `s` into the storage
/// convention used in this file (off-diagonals raw, diagonal in log space).
///
/// The diagonal residual is floored at `1e-10` before taking the logarithm,
/// and an off-diagonal entry is left at `0.0` when its pivot is not above
/// `1e-10`.
fn cholesky_decompose(s: &[f32], dim: usize) -> Vec<f32> {
    let mut factor = vec![0.0f32; dim * dim];
    for row in 0..dim {
        for col in 0..=row {
            // Only strictly-lower entries are read here, so no exp() is needed.
            let partial = (0..col).fold(0.0f32, |acc, k| {
                acc + factor[row * dim + k] * factor[col * dim + k]
            });
            if row == col {
                let residual = s[row * dim + row] - partial;
                // ln(sqrt(r)) == 0.5 * ln(r): store the diagonal as a log.
                factor[row * dim + row] = 0.5 * residual.max(1e-10).ln();
            } else {
                let pivot = get_l(&factor, dim, col, col);
                if pivot > 1e-10 {
                    factor[row * dim + col] = (s[row * dim + col] - partial) / pivot;
                }
            }
        }
    }
    factor
}
/// Computes `0.5 * (trace + mahalanobis - d + log_det_ratio)` for `child`
/// relative to `parent`, the closed-form Gaussian `KL(child || parent)`.
///
/// * `trace` is `tr(Sigma_parent^{-1} Sigma_child)`.
/// * `mahalanobis` is the squared distance between the centers under the
///   parent factor.
/// * `log_det_ratio` is `parent.log_det() - child.log_det()`.
///
/// # Errors
///
/// Returns [`BoxError::DimensionMismatch`] when the two dimensions differ.
pub fn kl_divergence(child: &Ellipsoid, parent: &Ellipsoid) -> Result<f32, BoxError> {
    let n = child.dim;
    if n != parent.dim {
        return Err(BoxError::DimensionMismatch {
            expected: n,
            actual: parent.dim,
        });
    }
    let d = n as f32;
    let trace_term = trace_inv_product(&parent.cholesky, &child.cholesky, n);

    let mut center_gap = Vec::with_capacity(parent.mu.len().min(child.mu.len()));
    for (p, c) in parent.mu.iter().zip(&child.mu) {
        center_gap.push(p - c);
    }
    let mahal = mahalanobis_sq(&parent.cholesky, n, &center_gap);

    let log_det_ratio = parent.log_det() - child.log_det();
    let bracket = trace_term + mahal - d + log_det_ratio;
    Ok(0.5 * bracket)
}
/// Bhattacharyya distance between two ellipsoids viewed as Gaussians.
///
/// With `S = Sigma_a + Sigma_b` and `M = S / 2` the result is
/// `1/8 * diff^T M^{-1} diff + 1/2 * (ln det M - 1/2 (ln det Sigma_a + ln det Sigma_b))`.
///
/// # Errors
///
/// Returns [`BoxError::DimensionMismatch`] when the two dimensions differ.
pub fn bhattacharyya_distance(a: &Ellipsoid, b: &Ellipsoid) -> Result<f32, BoxError> {
    if a.dim != b.dim {
        return Err(BoxError::DimensionMismatch {
            expected: a.dim,
            actual: b.dim,
        });
    }
    let dim = a.dim;
    let l_sum = cholesky_sum(&a.cholesky, &b.cholesky, dim);

    // The factor's diagonal is already stored as ln(L_ii), so read it directly
    // rather than exponentiating and taking the logarithm again: the round trip
    // loses precision and overflows for large entries.
    let log_det_s: f32 = (0..dim).map(|i| 2.0 * l_sum[i * dim + i]).sum();
    let d = dim as f32;
    // det(S / 2) = det(S) / 2^d.
    let log_det_m = log_det_s - d * 2.0f32.ln();

    let diff: Vec<f32> = a.mu.iter().zip(b.mu.iter()).map(|(&x, &y)| x - y).collect();
    // `mahal` uses S^{-1}; M^{-1} = 2 S^{-1}, hence 1/8 * 2 = 1/4.
    let mahal = mahalanobis_sq(&l_sum, dim, &diff);
    let mahal_term = 0.25 * mahal;

    let log_det_term = 0.5 * (log_det_m - 0.5 * (a.log_det() + b.log_det()));
    Ok(mahal_term + log_det_term)
}
/// Maps `kl_divergence(child, parent)` to a score by passing `-k * kl` through
/// `crate::utils::stable_sigmoid`.
///
/// # Errors
///
/// Propagates any error from `kl_divergence`.
pub fn containment_prob(child: &Ellipsoid, parent: &Ellipsoid, k: f32) -> Result<f32, BoxError> {
    kl_divergence(child, parent).map(|kl| crate::utils::stable_sigmoid(-k * kl))
}
/// Square root of the sum of both KL directions, with the sum floored at zero
/// before the root is taken.
///
/// # Errors
///
/// Propagates any error from `kl_divergence`.
pub fn surface_distance(a: &Ellipsoid, b: &Ellipsoid) -> Result<f32, BoxError> {
    let forward = kl_divergence(a, b)?;
    let backward = kl_divergence(b, a)?;
    let symmetrised = forward + backward;
    Ok(symmetrised.max(0.0).sqrt())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an axis-aligned ellipsoid from slices, panicking on invalid input.
    fn diag(mu: &[f32], log_diag: &[f32]) -> Ellipsoid {
        Ellipsoid::from_log_diagonal(mu.to_vec(), log_diag.to_vec())
            .expect("fixture inputs are finite and equally sized")
    }

    /// 2-D ellipsoid centred at the origin with a zero log-diagonal.
    fn unit_2d() -> Ellipsoid {
        diag(&[0.0, 0.0], &[0.0, 0.0])
    }

    #[test]
    fn ellipsoid_new_diagonal() {
        assert_eq!(unit_2d().dim(), 2);
    }

    #[test]
    fn ellipsoid_new_full() {
        // Row-major 2x2 factor: zero log-diagonal, 0.5 below the diagonal.
        let cholesky = vec![0.0, 0.0, 0.5, 0.0];
        let e = Ellipsoid::new(vec![0.0, 0.0], cholesky).unwrap();
        assert_eq!(e.dim(), 2);
    }

    #[test]
    fn ellipsoid_rejects_dim_mismatch() {
        assert!(Ellipsoid::new(vec![0.0], vec![0.0, 0.0, 0.0, 0.0]).is_err());
    }

    #[test]
    fn ellipsoid_rejects_non_finite() {
        // Both constructors must reject non-finite centers and factor entries.
        assert!(Ellipsoid::from_log_diagonal(vec![f32::NAN], vec![0.0]).is_err());
        assert!(Ellipsoid::from_log_diagonal(vec![0.0], vec![f32::INFINITY]).is_err());
        assert!(Ellipsoid::new(vec![f32::NAN], vec![0.0]).is_err());
        assert!(Ellipsoid::new(vec![0.0], vec![f32::NEG_INFINITY]).is_err());
    }

    #[test]
    fn kl_identical_is_zero() {
        let e = unit_2d();
        let kl = kl_divergence(&e, &e).unwrap();
        assert!(kl.abs() < 1e-4, "KL(identical) = {kl}, expected 0");
    }

    #[test]
    fn kl_same_center_different_scale() {
        let child = diag(&[0.0, 0.0], &[-1.0, -1.0]);
        let parent = unit_2d();
        let kl = kl_divergence(&child, &parent).unwrap();
        assert!(kl > 0.0, "KL = {kl}, expected > 0");
    }

    #[test]
    fn kl_different_center() {
        // Identity factors and a unit shift: 0.5 * (2 + 1 - 2 + 0) = 0.5.
        let child = diag(&[1.0, 0.0], &[0.0, 0.0]);
        let parent = unit_2d();
        let kl = kl_divergence(&child, &parent).unwrap();
        assert!((kl - 0.5).abs() < 1e-4, "KL = {kl}, expected 0.5");
    }

    #[test]
    fn kl_asymmetric() {
        let a = diag(&[0.0, 0.0], &[-1.0, -1.0]);
        let b = unit_2d();
        let kl_ab = kl_divergence(&a, &b).unwrap();
        let kl_ba = kl_divergence(&b, &a).unwrap();
        assert!(
            (kl_ab - kl_ba).abs() > 0.01,
            "KL should be asymmetric: {kl_ab} vs {kl_ba}"
        );
    }

    #[test]
    fn kl_dimension_mismatch() {
        let a = unit_2d();
        let b = diag(&[0.0], &[0.0]);
        assert!(kl_divergence(&a, &b).is_err());
    }

    #[test]
    fn bhattacharyya_identical_is_zero() {
        let e = unit_2d();
        let d = bhattacharyya_distance(&e, &e).unwrap();
        assert!(d.abs() < 1e-4, "B(identical) = {d}, expected 0");
    }

    #[test]
    fn bhattacharyya_symmetric() {
        let a = unit_2d();
        let b = diag(&[1.0, 0.0], &[0.5, 0.5]);
        let d_ab = bhattacharyya_distance(&a, &b).unwrap();
        let d_ba = bhattacharyya_distance(&b, &a).unwrap();
        assert!(
            (d_ab - d_ba).abs() < 1e-4,
            "Bhattacharyya should be symmetric: {d_ab} != {d_ba}"
        );
    }

    #[test]
    fn containment_prob_identical_is_half() {
        let e = unit_2d();
        let p = containment_prob(&e, &e, 1.0).unwrap();
        assert!((p - 0.5).abs() < 1e-4);
    }

    #[test]
    fn containment_prob_near_identical_is_half() {
        let child = diag(&[0.0, 0.0], &[-0.01, -0.01]);
        let parent = unit_2d();
        let p = containment_prob(&child, &parent, 1.0).unwrap();
        assert!(
            (p - 0.5).abs() < 0.05,
            "near-identical containment = {p}, expected ~0.5"
        );
    }

    #[test]
    fn containment_prob_widely_different_is_low() {
        let child = diag(&[0.0, 0.0], &[-3.0, -3.0]);
        let parent = unit_2d();
        let p = containment_prob(&child, &parent, 1.0).unwrap();
        assert!(p < 0.5, "narrower child containment = {p}, expected < 0.5");
    }

    #[test]
    fn surface_distance_identical_is_zero() {
        let e = unit_2d();
        let d = surface_distance(&e, &e).unwrap();
        assert!(d.abs() < 1e-4);
    }

    #[test]
    fn log_volume_diagonal() {
        assert!(unit_2d().log_volume().abs() < 1e-6);
    }

    #[test]
    fn log_volume_scales() {
        // Raising both log-diagonal entries by 1 must raise the log-volume by 2
        // relative to the unit baseline.
        let baseline = unit_2d().log_volume();
        let lv2 = diag(&[0.0, 0.0], &[1.0, 1.0]).log_volume();
        assert!(
            (lv2 - baseline - 2.0).abs() < 1e-4,
            "log_volume = {lv2}, expected 2"
        );
    }

    #[test]
    fn sigmoid_large_positive() {
        assert!((crate::utils::stable_sigmoid(100.0) - 1.0).abs() < 1e-4);
    }

    #[test]
    fn sigmoid_large_negative() {
        assert!(crate::utils::stable_sigmoid(-100.0).abs() < 1e-4);
    }
}
// Property-based tests over randomly generated axis-aligned ellipsoids.
#[cfg(test)]
mod proptests {
use super::*;
use proptest::prelude::*;
// Strategy for a `dim`-dimensional ellipsoid: centers are drawn from
// [-5, 5) and log-diagonal entries from [-3, 3); inputs rejected by
// `from_log_diagonal` are filtered out.
fn arb_ellipsoid(dim: usize) -> impl Strategy<Value = Ellipsoid> {
(
prop::collection::vec(-5.0f32..5.0, dim),
prop::collection::vec(-3.0f32..3.0, dim),
)
.prop_filter_map("valid ellipsoid", move |(mu, log_d)| {
Ellipsoid::from_log_diagonal(mu, log_d).ok()
})
}
// Two independently drawn ellipsoids of the same dimension.
fn arb_ellipsoid_pair(dim: usize) -> impl Strategy<Value = (Ellipsoid, Ellipsoid)> {
(arb_ellipsoid(dim), arb_ellipsoid(dim))
}
proptest! {
// NOTE(review): despite the name, the two centers are drawn independently;
// the bound checked is the loose `kl > -1.0`, not strict non-negativity.
#[test]
fn prop_kl_nonneg_for_same_center(
(a, b) in arb_ellipsoid_pair(3)
) {
let kl = kl_divergence(&a, &b).unwrap();
prop_assert!(kl > -1.0, "KL unexpectedly negative: {kl}");
}
// KL of an ellipsoid against itself stays within 1e-3 of zero.
#[test]
fn prop_self_kl_is_zero(
e in arb_ellipsoid(3)
) {
let kl = kl_divergence(&e, &e).unwrap();
prop_assert!(kl.abs() < 1e-3, "KL(self) = {kl}, expected 0");
}
// Bhattacharyya distance is non-negative up to a 1e-4 tolerance.
#[test]
fn prop_bhattacharyya_nonneg(
(a, b) in arb_ellipsoid_pair(3)
) {
let d = bhattacharyya_distance(&a, &b).unwrap();
prop_assert!(d > -1e-4, "Bhattacharyya < 0: {d}");
}
// Swapping the arguments changes the Bhattacharyya distance by < 1e-3.
#[test]
fn prop_bhattacharyya_symmetric(
(a, b) in arb_ellipsoid_pair(3)
) {
let d_ab = bhattacharyya_distance(&a, &b).unwrap();
let d_ba = bhattacharyya_distance(&b, &a).unwrap();
prop_assert!(
(d_ab - d_ba).abs() < 1e-3,
"Bhattacharyya should be symmetric: {d_ab} != {d_ba}"
);
}
// The containment score with k = 1 stays inside [0, 1] (1e-6 slack).
#[test]
fn prop_containment_in_unit_interval(
(a, b) in arb_ellipsoid_pair(3)
) {
let p = containment_prob(&a, &b, 1.0).unwrap();
prop_assert!(p >= -1e-6, "containment_prob < 0: {p}");
prop_assert!(p <= 1.0 + 1e-6, "containment_prob > 1: {p}");
}
// surface_distance never goes below zero (1e-6 slack).
#[test]
fn prop_surface_distance_nonneg(
(a, b) in arb_ellipsoid_pair(3)
) {
let d = surface_distance(&a, &b).unwrap();
prop_assert!(d >= -1e-6, "surface_distance < 0: {d}");
}
// Swapping the arguments changes surface_distance by < 1e-3.
#[test]
fn prop_surface_distance_symmetric(
(a, b) in arb_ellipsoid_pair(3)
) {
let d_ab = surface_distance(&a, &b).unwrap();
let d_ba = surface_distance(&b, &a).unwrap();
prop_assert!(
(d_ab - d_ba).abs() < 1e-3,
"surface_distance should be symmetric: {d_ab} != {d_ba}"
);
}
// log_volume is finite for every generated 4-dimensional ellipsoid.
#[test]
fn prop_log_volume_finite(
e in arb_ellipsoid(4)
) {
let lv = e.log_volume();
prop_assert!(lv.is_finite(), "log_volume not finite: {lv}");
}
}
}