use nalgebra::{DMatrix, DVector};
use std::fmt::Debug;
/// A pluggable measure of distance between two points given as `f64` slices.
///
/// Implementors must be `Send + Sync + Debug`, so a metric can be shared
/// across threads and printed for diagnostics.
pub trait DistanceMetric: Send + Sync + Debug {
    /// Returns the distance between points `a` and `b`.
    fn distance(&self, a: &[f64], b: &[f64]) -> f64;

    /// Reports whether this metric is a Mahalanobis metric.
    ///
    /// The default answer is `false`; implementors override it when applicable.
    fn is_mahalanobis(&self) -> bool {
        false
    }
}
/// Stateless marker type for the Euclidean distance metric.
///
/// It carries no data, so it can be created directly or via `Default`.
#[derive(Debug, Clone, Default)]
pub struct EuclideanDistance;
impl DistanceMetric for EuclideanDistance {
    /// Returns the square root of the summed squared coordinate differences.
    ///
    /// Coordinates are paired with `zip`, so when the slices differ in length
    /// the surplus tail of the longer one is ignored.
    fn distance(&self, a: &[f64], b: &[f64]) -> f64 {
        let squared_sum: f64 = a
            .iter()
            .zip(b)
            .map(|(lhs, rhs)| {
                let delta = lhs - rhs;
                delta.powi(2)
            })
            .sum();
        squared_sum.sqrt()
    }
}
/// Distance metric parameterised by an inverse covariance matrix.
#[derive(Debug, Clone)]
pub struct MahalanobisDistance {
/// Inverse covariance matrix used as the weighting matrix when measuring distance.
pub inv_covariance: DMatrix<f64>,
}
impl MahalanobisDistance {
pub fn new(data: &DMatrix<f64>) -> Result<Self, String> {
let n = data.nrows();
if n < 2 {
return Err("Not enough data points to calculate covariance".to_string());
}
let centered = data.row_mean().transpose();
let mut centered_data = data.clone();
for i in 0..n {
let mut row = centered_data.row_mut(i);
row -= ¢ered;
}
let covariance = (centered_data.transpose() * centered_data) / ((n - 1) as f64);
let inv_covariance = covariance
.try_inverse()
.ok_or_else(|| "Covariance matrix is singular and cannot be inverted".to_string())?;
Ok(Self { inv_covariance })
}
pub fn from_inv_covariance(inv_covariance: DMatrix<f64>) -> Self {
Self { inv_covariance }
}
}
impl DistanceMetric for MahalanobisDistance {
    /// Returns `sqrt(d^T * inv_covariance * d)` where `d = a - b`.
    ///
    /// Coordinates are paired with `zip`, so `d` has the length of the shorter
    /// slice. A negative quadratic form (possible through rounding or a
    /// non-positive-definite matrix) is reported as `0.0` instead of NaN.
    ///
    /// NOTE(review): the matrix-vector product presumably panics when `d` does
    /// not match the dimensions of `inv_covariance` — confirm against nalgebra.
    fn distance(&self, a: &[f64], b: &[f64]) -> f64 {
        let delta = DVector::from_vec(
            a.iter()
                .zip(b)
                .map(|(lhs, rhs)| lhs - rhs)
                .collect::<Vec<f64>>(),
        );
        let weighted = &self.inv_covariance * &delta;

        match delta.dot(&weighted) {
            quad if quad < 0.0 => 0.0,
            quad => quad.sqrt(),
        }
    }

    /// Always `true`: this metric is the Mahalanobis metric.
    fn is_mahalanobis(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: the two points differ by (-3, -4), a 3-4-5 triangle.
    const POINT_A: [f64; 2] = [1.0, 2.0];
    const POINT_B: [f64; 2] = [4.0, 6.0];

    /// Plain Euclidean distance over the 3-4-5 fixture is exactly 5.
    #[test]
    fn test_euclidean_distance() {
        let euclid = EuclideanDistance;
        assert_eq!(euclid.distance(&POINT_A, &POINT_B), 5.0);
    }

    /// With an identity weighting the Mahalanobis metric matches Euclidean;
    /// with a diagonal weighting each squared difference is scaled.
    #[test]
    fn test_mahalanobis_distance() {
        let unweighted = MahalanobisDistance::from_inv_covariance(DMatrix::identity(2, 2));
        assert_eq!(unweighted.distance(&POINT_A, &POINT_B), 5.0);

        // 0.25 * 3^2 + 1.0 * 4^2 = 18.25
        let weights = DMatrix::from_row_slice(2, 2, &[0.25, 0.0, 0.0, 1.0]);
        let weighted = MahalanobisDistance::from_inv_covariance(weights);
        let expected = 18.25f64.sqrt();
        let actual = weighted.distance(&POINT_A, &POINT_B);
        assert!((actual - expected).abs() < 1e-6);
    }
}