use ndarray::{ArrayView2, Axis};
use rand::seq::IndexedRandom;
use rand::SeedableRng;
use crate::error::{RagDriftError, Result};
use crate::types::{check_min_samples, check_same_cols};
/// Selects which estimator of squared MMD `mmd_rbf` computes.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum MmdEstimator {
/// Keeps the self-pair (diagonal) kernel terms and normalizes the
/// within-sample kernel sums by `n * n` and `m * m`.
Biased,
/// Skips the self-pair (diagonal) kernel terms and normalizes the
/// within-sample kernel sums by `n * (n - 1)` and `m * (m - 1)`.
/// This is the `Default` variant.
#[default]
Unbiased,
}
/// Estimates the squared Maximum Mean Discrepancy (MMD^2) between two sample
/// sets using the RBF kernel `k(x, y) = exp(-||x - y||^2 / (2 h^2))`.
///
/// Rows of `baseline` and `current` are the samples; both must have the same
/// number of columns. The bandwidth `h` comes from `median_pairwise_bandwidth`,
/// which subsamples rows with an RNG seeded from `seed`, so the result is
/// reproducible for a fixed seed.
///
/// `estimator` picks the normalization: [`MmdEstimator::Biased`] keeps the
/// diagonal kernel terms and divides by `n * n` / `m * m`, while
/// [`MmdEstimator::Unbiased`] drops them and divides by `n * (n - 1)` /
/// `m * (m - 1)`. The unbiased estimate may be slightly negative.
///
/// # Errors
///
/// * Propagates the error from `check_same_cols` when the column counts differ.
/// * Propagates the error from `check_min_samples` when either input has fewer
///   than the 2 rows requested here.
/// * Propagates any error from the bandwidth estimation.
/// * Returns [`RagDriftError::NumericalInstability`] when the computed value is
///   not finite.
pub fn mmd_rbf(
    baseline: &ArrayView2<'_, f32>,
    current: &ArrayView2<'_, f32>,
    estimator: MmdEstimator,
    seed: u64,
) -> Result<f64> {
    check_same_cols(baseline, current)?;
    check_min_samples(baseline.nrows(), 2)?;
    check_min_samples(current.nrows(), 2)?;

    let bandwidth = median_pairwise_bandwidth(baseline, current, seed)?;
    if bandwidth <= 0.0 {
        // A non-positive bandwidth carries no scale information; report no drift.
        return Ok(0.0);
    }
    let inv_two_h2 = 1.0 / (2.0 * bandwidth * bandwidth);

    // Convert the sample counts to f64 *before* multiplying. Products such as
    // `n * n` computed in `usize` overflow on 32-bit targets once a sample has
    // 65_536 rows (panic in debug builds, silent wrap in release builds).
    let n = baseline.nrows() as f64;
    let m = current.nrows() as f64;

    // Only the unbiased estimator omits the self-pair terms k(x_i, x_i).
    let exclude_diagonal = estimator == MmdEstimator::Unbiased;
    let kxx = sum_kernel(baseline, baseline, inv_two_h2, exclude_diagonal);
    let kyy = sum_kernel(current, current, inv_two_h2, exclude_diagonal);
    let kxy = sum_kernel(baseline, current, inv_two_h2, false);

    // The cross term is normalized identically for both estimators.
    let cross = 2.0 * kxy / (n * m);
    let mmd2 = match estimator {
        MmdEstimator::Biased => kxx / (n * n) + kyy / (m * m) - cross,
        // `n - 1` and `m - 1` are at least 1 thanks to the minimum-sample checks.
        MmdEstimator::Unbiased => kxx / (n * (n - 1.0)) + kyy / (m * (m - 1.0)) - cross,
    };

    if !mmd2.is_finite() {
        return Err(RagDriftError::NumericalInstability {
            step: "mmd".into(),
            reason: "non-finite mmd^2".into(),
        });
    }
    Ok(mmd2)
}
/// Sums the RBF kernel `exp(-||x - y||^2 * inv_two_h2)` over every pair of a
/// row `x` from `xs` and a row `y` from `ys`, processing the rows of `xs` in
/// parallel with rayon.
///
/// When `exclude_diagonal` is set, pairs whose row indices coincide are left
/// out of the sum. Differences are accumulated in `f64`.
///
/// NOTE(review): the per-row partial sums are combined by rayon, so the last
/// bits of the result may differ from the sequential build — confirm if
/// bit-exact reproducibility matters.
#[cfg(feature = "parallel")]
fn sum_kernel(
    xs: &ArrayView2<'_, f32>,
    ys: &ArrayView2<'_, f32>,
    inv_two_h2: f64,
    exclude_diagonal: bool,
) -> f64 {
    use rayon::prelude::*;

    xs.axis_iter(Axis(0))
        .into_par_iter()
        .enumerate()
        .map(|(row_idx, lhs)| {
            ys.axis_iter(Axis(0))
                .enumerate()
                // Drop the self-pair only when the caller asked for it.
                .filter(|(col_idx, _)| !(exclude_diagonal && row_idx == *col_idx))
                .fold(0.0_f64, |running, (_, rhs)| {
                    // Squared Euclidean distance between the two rows.
                    let sq_dist = lhs.iter().zip(rhs.iter()).fold(0.0_f64, |acc, (&p, &q)| {
                        let delta = f64::from(p) - f64::from(q);
                        acc + delta * delta
                    });
                    running + (-sq_dist * inv_two_h2).exp()
                })
        })
        .sum()
}
/// Sums the RBF kernel `exp(-||x - y||^2 * inv_two_h2)` over every pair of a
/// row `x` from `xs` and a row `y` from `ys` (sequential build).
///
/// When `exclude_diagonal` is set, pairs whose row indices coincide are left
/// out of the sum. Differences are accumulated in `f64`.
#[cfg(not(feature = "parallel"))]
fn sum_kernel(
    xs: &ArrayView2<'_, f32>,
    ys: &ArrayView2<'_, f32>,
    inv_two_h2: f64,
    exclude_diagonal: bool,
) -> f64 {
    // A single running accumulator is threaded through both folds so the
    // additions happen in row-major pair order.
    xs.axis_iter(Axis(0))
        .enumerate()
        .fold(0.0_f64, |outer, (row_idx, lhs)| {
            ys.axis_iter(Axis(0))
                .enumerate()
                // Drop the self-pair only when the caller asked for it.
                .filter(|(col_idx, _)| !(exclude_diagonal && row_idx == *col_idx))
                .fold(outer, |running, (_, rhs)| {
                    // Squared Euclidean distance between the two rows.
                    let sq_dist = lhs.iter().zip(rhs.iter()).fold(0.0_f64, |acc, (&p, &q)| {
                        let delta = f64::from(p) - f64::from(q);
                        acc + delta * delta
                    });
                    running + (-sq_dist * inv_two_h2).exp()
                })
        })
}
/// Median-heuristic RBF bandwidth: the median Euclidean distance between pairs
/// of rows sampled from the union of `a` and `b`.
///
/// Up to `MAX_ROWS` rows are drawn from the combined rows with an RNG seeded
/// from `seed`, and at most `MAX_PAIRS` pairwise distances among them are
/// collected (in `(i, j)` order with `i < j`). The element at index `len / 2`
/// of the sorted distances is returned, floored at `1e-12` so callers never
/// divide by zero.
///
/// # Errors
///
/// * [`RagDriftError::InsufficientSamples`] when `a` and `b` together hold
///   fewer than two rows.
/// * [`RagDriftError::NumericalInstability`] when a sampled distance is NaN
///   (for example because the input contains NaN); previously this case
///   panicked while sorting.
fn median_pairwise_bandwidth(
    a: &ArrayView2<'_, f32>,
    b: &ArrayView2<'_, f32>,
    seed: u64,
) -> Result<f64> {
    const MAX_PAIRS: usize = 500;
    const MAX_ROWS: usize = 64;

    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let combined: Vec<ndarray::ArrayView1<'_, f32>> =
        a.axis_iter(Axis(0)).chain(b.axis_iter(Axis(0))).collect();
    let total = combined.len();
    if total < 2 {
        return Err(RagDriftError::InsufficientSamples {
            needed: 2,
            got: total,
        });
    }

    let n_pick = total.min(MAX_ROWS);
    let picked: Vec<&ndarray::ArrayView1<'_, f32>> =
        combined.choose_multiple(&mut rng, n_pick).collect();

    // Never reserve more than will actually be stored.
    let pair_budget = (n_pick * (n_pick - 1) / 2).min(MAX_PAIRS);
    let mut dists: Vec<f64> = Vec::with_capacity(pair_budget);
    'pairs: for (i, lhs) in picked.iter().enumerate() {
        for rhs in &picked[i + 1..] {
            if dists.len() >= MAX_PAIRS {
                break 'pairs;
            }
            let sq_dist = lhs.iter().zip(rhs.iter()).fold(0.0_f64, |acc, (&x, &y)| {
                let diff = f64::from(x) - f64::from(y);
                acc + diff * diff
            });
            let dist = sq_dist.sqrt();
            if dist.is_nan() {
                // Surface bad input as an error instead of panicking in the sort.
                return Err(RagDriftError::NumericalInstability {
                    step: "mmd_bandwidth".into(),
                    reason: "NaN pairwise distance".into(),
                });
            }
            dists.push(dist);
        }
    }

    // Defensive: with at least two picked rows there is always one pair.
    if dists.is_empty() {
        return Ok(0.0);
    }

    // NaN was rejected above, so `total_cmp` orders exactly like `partial_cmp`
    // would, without an `unwrap` that could panic.
    dists.sort_unstable_by(f64::total_cmp);
    let mid = dists.len() / 2;
    Ok(dists[mid].max(1e-12))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array2;
    use ndarray_rand::rand_distr::StandardNormal;
    use ndarray_rand::RandomExt;

    /// Draws a `rows x cols` matrix of standard-normal samples.
    fn gaussian(rows: usize, cols: usize) -> Array2<f32> {
        Array2::<f32>::random((rows, cols), StandardNormal)
    }

    /// The biased estimate of an all-zero sample against its copy is zero.
    #[test]
    fn identical_arrays_zero_mmd_biased() {
        let zeros = Array2::<f32>::zeros((16, 4));
        let copy = zeros.clone();
        let mmd2 = mmd_rbf(&zeros.view(), &copy.view(), MmdEstimator::Biased, 0).unwrap();
        assert_abs_diff_eq!(mmd2, 0.0, epsilon = 1e-9);
    }

    /// Two independent draws from the same distribution give a near-zero estimate.
    #[test]
    fn same_distribution_small_mmd() {
        let first = gaussian(128, 8);
        let second = gaussian(128, 8);
        let v = mmd_rbf(
            &first.view(),
            &second.view(),
            MmdEstimator::Unbiased,
            1,
        )
        .unwrap();
        assert!(v.abs() < 0.05, "expected small MMD^2 under H0, got {v}");
    }

    /// Shifting one sample by a constant pushes the estimate above the
    /// self-comparison value and above a fixed threshold.
    #[test]
    fn shifted_distribution_larger_mmd() {
        let reference = gaussian(128, 8);
        let shifted = gaussian(128, 8).mapv(|x| x + 2.0);
        let v0 = mmd_rbf(
            &reference.view(),
            &reference.view(),
            MmdEstimator::Biased,
            1,
        )
        .unwrap();
        let v1 = mmd_rbf(
            &reference.view(),
            &shifted.view(),
            MmdEstimator::Biased,
            1,
        )
        .unwrap();
        assert!(v1 > v0, "shifted MMD should exceed identical, {v1} vs {v0}");
        assert!(v1 > 0.05);
    }

    /// Inputs with different column counts are rejected.
    #[test]
    fn rejects_dim_mismatch() {
        let narrow = Array2::<f32>::zeros((4, 4));
        let wide = Array2::<f32>::zeros((4, 8));
        let outcome = mmd_rbf(&narrow.view(), &wide.view(), MmdEstimator::Biased, 0);
        assert!(outcome.is_err());
    }

    /// A sample with a single row is rejected.
    #[test]
    fn rejects_too_few_samples() {
        let single_row = Array2::<f32>::zeros((1, 4));
        let enough_rows = Array2::<f32>::zeros((4, 4));
        let outcome = mmd_rbf(
            &single_row.view(),
            &enough_rows.view(),
            MmdEstimator::Biased,
            0,
        );
        assert!(outcome.is_err());
    }
}