use crate::{Error, Result};
#[cfg(feature = "ndarray")]
use ndarray::{Array2, ArrayView2};
#[derive(Debug, Clone)]
pub struct DistributionDistanceConfig {
pub rbf_sigma: f64,
pub sw_projections: usize,
}
impl Default for DistributionDistanceConfig {
fn default() -> Self {
Self {
rbf_sigma: 1.0,
sw_projections: 32,
}
}
}
#[derive(Debug, Clone, Default)]
pub struct DistributionDistance {
pub mmd_rbf_biased: Option<f64>,
pub sliced_wasserstein: Option<f32>,
}
impl DistributionDistance {
pub fn compute(
x: ArrayView2<'_, f32>,
y: ArrayView2<'_, f32>,
cfg: &DistributionDistanceConfig,
) -> Result<Self> {
let dx = x.ncols();
let dy = y.ncols();
if dx != dy {
return Err(Error::DimensionMismatch {
expected: dx,
found: dy,
});
}
if x.nrows() == 0 || y.nrows() == 0 {
return Ok(Self::default());
}
let mut out = Self::default();
#[cfg(feature = "rkhs")]
{
out.mmd_rbf_biased = Some(mmd_rbf_biased(x, y, cfg.rbf_sigma)?);
}
#[cfg(feature = "wass")]
{
out.sliced_wasserstein = Some(sliced_wasserstein(x, y, cfg.sw_projections)?);
}
Ok(out)
}
}
#[cfg(feature = "rkhs")]
fn mmd_rbf_biased(x: ArrayView2<'_, f32>, y: ArrayView2<'_, f32>, sigma: f64) -> Result<f64> {
if sigma <= 0.0 {
return Err(Error::InvalidParameter {
name: "rbf_sigma",
message: "must be > 0",
});
}
fn to_vecs(a: ArrayView2<'_, f32>) -> Vec<Vec<f64>> {
a.outer_iter()
.map(|row| row.iter().map(|v| *v as f64).collect::<Vec<f64>>())
.collect()
}
let xa = to_vecs(x);
let ya = to_vecs(y);
let mmd = rkhs::mmd_biased(&xa, &ya, |a, b| rkhs::rbf(a, b, sigma));
Ok(mmd)
}
#[cfg(feature = "wass")]
fn sliced_wasserstein(
x: ArrayView2<'_, f32>,
y: ArrayView2<'_, f32>,
n_projections: usize,
) -> Result<f32> {
if n_projections == 0 {
return Ok(0.0);
}
let x_owned: Array2<f32> = x.to_owned();
let y_owned: Array2<f32> = y.to_owned();
Ok(wass::sliced_wasserstein(
&x_owned,
&y_owned,
n_projections,
42,
2.0,
))
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
#[test]
fn empty_inputs_return_zeroes() {
let x = Array2::<f32>::zeros((0, 3));
let y = Array2::<f32>::zeros((0, 3));
let d = DistributionDistance::compute(
x.view(),
y.view(),
&DistributionDistanceConfig::default(),
)
.unwrap();
assert!(d.mmd_rbf_biased.is_none() || d.mmd_rbf_biased == Some(0.0));
assert!(d.sliced_wasserstein.is_none() || d.sliced_wasserstein == Some(0.0));
}
#[test]
fn dim_mismatch_errors() {
let x = Array2::<f32>::zeros((5, 2));
let y = Array2::<f32>::zeros((5, 3));
let err = DistributionDistance::compute(
x.view(),
y.view(),
&DistributionDistanceConfig::default(),
)
.unwrap_err();
assert!(matches!(err, Error::DimensionMismatch { .. }));
}
#[cfg(feature = "wass")]
#[test]
fn sliced_wasserstein_is_symmetric_on_simple_case() {
let x = array![[0.0f32, 0.0], [1.0, 0.0], [0.0, 1.0]];
let y = array![[10.0f32, 10.0], [10.5, 10.0], [10.0, 10.5]];
let cfg = DistributionDistanceConfig {
sw_projections: 16,
..Default::default()
};
let d_xy = DistributionDistance::compute(x.view(), y.view(), &cfg)
.unwrap()
.sliced_wasserstein
.unwrap();
let d_yx = DistributionDistance::compute(y.view(), x.view(), &cfg)
.unwrap()
.sliced_wasserstein
.unwrap();
assert!((d_xy - d_yx).abs() < 1e-4);
assert!(d_xy > 0.0);
}
#[cfg(feature = "rkhs")]
#[test]
fn mmd_rbf_biased_is_nonnegative_and_small_for_same_distribution() {
let x = array![[0.0f32, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
let y = x.clone();
let cfg = DistributionDistanceConfig {
rbf_sigma: 1.0,
..Default::default()
};
let m = DistributionDistance::compute(x.view(), y.view(), &cfg)
.unwrap()
.mmd_rbf_biased
.unwrap();
assert!(m >= -1e-12);
assert!(m < 1e-6, "expected ~0 for identical point clouds; got {m}");
}
#[cfg(feature = "rkhs")]
#[test]
fn mmd_rbf_biased_increases_for_separated_clouds() {
let x = array![[0.0f32, 0.0], [0.1, 0.0], [0.0, 0.1], [0.1, 0.1]];
let y = array![[10.0f32, 10.0], [10.1, 10.0], [10.0, 10.1], [10.1, 10.1]];
let cfg = DistributionDistanceConfig {
rbf_sigma: 1.0,
..Default::default()
};
let m = DistributionDistance::compute(x.view(), y.view(), &cfg)
.unwrap()
.mmd_rbf_biased
.unwrap();
assert!(
m > 0.1,
"expected noticeably positive MMD for separated clouds; got {m}"
);
}
}