use crate::error::InsightError;
use u_numflow::matrix::Matrix;
/// Configuration for Mahalanobis-distance outlier detection.
#[derive(Debug, Clone)]
pub struct MahalanobisConfig {
    /// Chi-squared quantile from which the anomaly threshold is derived.
    ///
    /// Only values strictly between 0 and 1 yield a usable threshold.
    /// The default is `0.975`.
    pub chi2_quantile: f64,
}

impl Default for MahalanobisConfig {
    /// Returns a configuration with `chi2_quantile` set to `0.975`.
    fn default() -> Self {
        Self {
            chi2_quantile: 0.975,
        }
    }
}

impl MahalanobisConfig {
    /// Builder-style setter: consumes the configuration and returns it with
    /// `chi2_quantile` replaced by `q`.
    ///
    /// The value is stored as given; no range check happens here.
    // `#[must_use]` because the method takes `self` by value: dropping the
    // return value silently discards the updated configuration.
    #[must_use]
    pub fn chi2_quantile(mut self, q: f64) -> Self {
        self.chi2_quantile = q;
        self
    }
}
/// Outcome of a [`mahalanobis`] run.
#[derive(Debug, Clone)]
pub struct MahalanobisResult {
/// Squared Mahalanobis distance of each input row from the sample mean,
/// in input order. Values are clamped so they are never negative.
pub distances: Vec<f64>,
/// `true` where the matching entry of `distances` is strictly greater than
/// `threshold`.
pub anomalies: Vec<bool>,
/// Chi-squared cutoff the squared distances were compared against.
pub threshold: f64,
/// Number of `true` entries in `anomalies`.
pub outlier_count: usize,
/// `outlier_count` divided by the number of input rows.
pub outlier_fraction: f64,
/// Per-column arithmetic mean of the input rows.
pub mean: Vec<f64>,
}
/// Flags multivariate outliers using squared Mahalanobis distances.
///
/// Each element of `data` is one observation (row); every row must hold the
/// same number of finite values (columns). The function computes the column
/// means and the sample covariance matrix (divisor `n - 1`), inverts the
/// covariance, and scores every row by its squared Mahalanobis distance to
/// the mean. A row is flagged when its score is strictly greater than the
/// chi-squared quantile `config.chi2_quantile` with as many degrees of
/// freedom as there are columns.
///
/// Note that [`MahalanobisResult::distances`] holds *squared* distances
/// (clamped at zero), which is the quantity compared against the threshold.
///
/// # Errors
///
/// * [`InsightError::InsufficientData`] – `data` is empty, or the number of
///   rows does not exceed the number of columns.
/// * [`InsightError::DimensionMismatch`] – the first row is empty, or a row's
///   length differs from the first row's.
/// * [`InsightError::DegenerateData`] – a value is NaN or infinite, a column
///   has near-zero variance, or the covariance matrix cannot be inverted.
/// * [`InsightError::ComputationFailed`] – the covariance matrix could not be
///   constructed, or `config.chi2_quantile` is not strictly between 0 and 1.
pub fn mahalanobis(
    data: &[Vec<f64>],
    config: &MahalanobisConfig,
) -> Result<MahalanobisResult, InsightError> {
    // ---- Shape validation -------------------------------------------------
    let n = data.len();
    if n == 0 {
        return Err(InsightError::InsufficientData {
            min_required: 2,
            actual: 0,
        });
    }
    let p = data[0].len();
    if p == 0 {
        return Err(InsightError::DimensionMismatch {
            expected: 1,
            actual: 0,
        });
    }
    // The sample covariance of n rows has rank at most n - 1, so it can only
    // be inverted when there are more rows than columns.
    if n <= p {
        return Err(InsightError::InsufficientData {
            min_required: p + 1,
            actual: n,
        });
    }

    // ---- Per-row validation -----------------------------------------------
    for (row_idx, point) in data.iter().enumerate() {
        if point.len() != p {
            return Err(InsightError::DimensionMismatch {
                expected: p,
                actual: point.len(),
            });
        }
        if let Some(col_idx) = point.iter().position(|v| !v.is_finite()) {
            return Err(InsightError::DegenerateData {
                reason: format!(
                    "non-finite value at row {row_idx}, column {col_idx}"
                ),
            });
        }
    }

    // ---- Column means -----------------------------------------------------
    let mut mean = vec![0.0; p];
    for point in data {
        for (m, &val) in mean.iter_mut().zip(point) {
            *m += val;
        }
    }
    let row_count = n as f64;
    for m in &mut mean {
        *m /= row_count;
    }

    // ---- Sample covariance (row-major, p x p) -----------------------------
    // Scratch buffer for one centered row; reused here and in the distance
    // pass so no per-row allocation is needed.
    let mut centered = vec![0.0; p];
    let mut cov_data = vec![0.0; p * p];
    for point in data {
        for ((d, &x), &m) in centered.iter_mut().zip(point).zip(&mean) {
            *d = x - m;
        }
        // Accumulate the upper triangle only; it is mirrored below.
        for r in 0..p {
            for c in r..p {
                cov_data[r * p + c] += centered[r] * centered[c];
            }
        }
    }
    let denom = (n - 1) as f64;
    for r in 0..p {
        for c in r..p {
            let val = cov_data[r * p + c] / denom;
            cov_data[r * p + c] = val;
            cov_data[c * p + r] = val;
        }
    }

    // A (near-)constant column makes the covariance singular; report which
    // columns are responsible instead of a generic inversion failure.
    let zero_var_cols: Vec<usize> = (0..p)
        .filter(|&c| cov_data[c * p + c].abs() < 1e-12)
        .collect();
    if !zero_var_cols.is_empty() {
        return Err(InsightError::DegenerateData {
            reason: format!(
                "near-zero variance in column(s) {zero_var_cols:?}; covariance matrix would be singular"
            ),
        });
    }

    // ---- Inverse covariance -----------------------------------------------
    let cov_mat = Matrix::new(p, p, cov_data).map_err(|e| InsightError::ComputationFailed {
        operation: "covariance matrix construction".into(),
        detail: e.to_string(),
    })?;
    let inv_cov = cov_mat.inverse().map_err(|_| InsightError::DegenerateData {
        reason: "covariance matrix is singular or near-singular (likely collinear columns)".into(),
    })?;

    // ---- Squared distances: (x - mean)^T * inv_cov * (x - mean) -----------
    let mut distances = Vec::with_capacity(n);
    for point in data {
        for ((d, &x), &m) in centered.iter_mut().zip(point).zip(&mean) {
            *d = x - m;
        }
        let mut d2 = 0.0;
        for (r, &dr) in centered.iter().enumerate() {
            let mut row_sum = 0.0;
            for (c, &dc) in centered.iter().enumerate() {
                row_sum += inv_cov.get(r, c) * dc;
            }
            d2 += dr * row_sum;
        }
        // Rounding can push a tiny quadratic form slightly below zero.
        distances.push(d2.max(0.0));
    }

    // ---- Threshold and flags ----------------------------------------------
    // An out-of-range (or NaN) quantile would give a NaN threshold, and
    // `d > NaN` is false for every row, silently reporting "no outliers".
    // Reject it explicitly. The check sits after all data checks so every
    // data-related error keeps the precedence it always had.
    let quantile = config.chi2_quantile;
    let quantile_in_range = quantile > 0.0 && quantile < 1.0;
    if !quantile_in_range {
        return Err(InsightError::ComputationFailed {
            operation: "chi-squared threshold".into(),
            detail: format!(
                "chi2_quantile must lie strictly between 0 and 1, got {quantile}"
            ),
        });
    }
    let threshold = chi2_quantile(p as f64, quantile);
    let anomalies: Vec<bool> = distances.iter().map(|&d| d > threshold).collect();
    let outlier_count = anomalies.iter().filter(|&&a| a).count();
    // `n` is at least 2 here (n > p >= 1), so the division is well defined.
    let outlier_fraction = outlier_count as f64 / row_count;

    Ok(MahalanobisResult {
        distances,
        anomalies,
        threshold,
        outlier_count,
        outlier_fraction,
        mean,
    })
}
/// Approximates the `p`-quantile of the chi-squared distribution with `df`
/// degrees of freedom using the Wilson–Hilferty cube-root transformation:
/// `df * (1 - 2/(9 df) + z * sqrt(2/(9 df)))^3`, where `z` is the standard
/// normal quantile of `p`.
///
/// Returns `NaN` when `df` is not positive, when `p` is not strictly between
/// 0 and 1, or when either argument is NaN. Returns `0.0` when the cubed
/// term would be non-positive.
fn chi2_quantile(df: f64, p: f64) -> f64 {
    // Phrased as a positive condition on purpose: every comparison with NaN
    // is false, so NaN inputs fail this test and are rejected here instead of
    // being forwarded to `inverse_normal_cdf`.
    let inputs_valid = df > 0.0 && p > 0.0 && p < 1.0;
    if !inputs_valid {
        return f64::NAN;
    }
    let z = u_numflow::special::inverse_normal_cdf(p);
    let term = 2.0 / (9.0 * df);
    let cube = 1.0 - term + z * term.sqrt();
    // A chi-squared quantile cannot be negative; clamp instead of cubing a
    // non-positive base.
    if cube <= 0.0 {
        return 0.0;
    }
    df * cube * cube * cube
}
#[cfg(test)]
mod tests {
use super::*;
/// Builds `n` deterministic points around `center`.
///
/// Point `i` offsets coordinate `j` of `center` by
/// `spread * sin((i / n) * (j + 1))`.
fn make_cluster(center: &[f64], n: usize, spread: f64) -> Vec<Vec<f64>> {
    (0..n)
        .map(|i| {
            let t = i as f64 / n as f64;
            center
                .iter()
                .enumerate()
                .map(|(j, &c)| c + spread * (t * (j as f64 + 1.0)).sin())
                .collect()
        })
        .collect()
}
/// A point far away from a 30-point cluster must have the largest distance
/// and be flagged as an anomaly.
#[test]
fn detects_obvious_outlier() {
    let mut data = make_cluster(&[0.0, 0.0], 30, 1.0);
    data.push(vec![50.0, 50.0]);

    let result = mahalanobis(&data, &MahalanobisConfig::default()).unwrap();

    // Index of the row with the largest squared distance.
    let max_idx = result
        .distances
        .iter()
        .enumerate()
        .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap())
        .map(|(idx, _)| idx)
        .unwrap();

    assert_eq!(max_idx, 30);
    assert!(result.anomalies[30]);
}
/// For 20 three-column rows with bounded sinusoidal noise, fewer than half
/// of the rows may be reported as outliers.
#[test]
fn inliers_below_threshold() {
    let mut data: Vec<Vec<f64>> = Vec::with_capacity(20);
    for i in 0..20 {
        let step = i as f64;
        let t = step * 0.1;
        let n1 = (step * 0.73).sin() * 0.5;
        let n2 = (step * 1.41).cos() * 0.5;
        let n3 = (step * 2.17).sin() * 0.5;
        data.push(vec![t + n1, n2 + 1.0, n3 + 2.0]);
    }

    let result = mahalanobis(&data, &MahalanobisConfig::default()).unwrap();

    assert!(
        result.outlier_fraction < 0.5,
        "too many outliers in tight cluster: {}",
        result.outlier_fraction
    );
}
/// The reported mean must equal the per-column average of the rows.
#[test]
fn mean_computed_correctly() {
    let data = vec![
        vec![2.0, 4.0],
        vec![4.0, 7.0],
        vec![6.0, 13.0],
        vec![4.0, 8.0],
    ];
    let result = mahalanobis(&data, &MahalanobisConfig::default()).unwrap();

    // Column 0 averages to 16 / 4, column 1 to 32 / 4.
    let expected = [4.0, 8.0];
    for (col, want) in expected.iter().enumerate() {
        assert!((result.mean[col] - want).abs() < 1e-10);
    }
}
/// An input without any rows is rejected.
#[test]
fn empty_data() {
    let no_rows: Vec<Vec<f64>> = vec![];
    let outcome = mahalanobis(&no_rows, &MahalanobisConfig::default());
    assert!(outcome.is_err());
}
/// Two rows for two columns is not enough data and is rejected.
#[test]
fn insufficient_data() {
    let two_by_two = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let outcome = mahalanobis(&two_by_two, &MahalanobisConfig::default());
    assert!(outcome.is_err());
}
/// A row whose length differs from the first row's must be reported as a
/// `DimensionMismatch`.
#[test]
fn dimension_mismatch() {
    // Three rows for two columns: the row count exceeds the column count, so
    // the "not enough rows" check passes and the ragged-row check is the one
    // that actually fires. With only two rows the call would fail earlier
    // with `InsufficientData` and this test would pass for the wrong reason.
    let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0]];
    let err = mahalanobis(&data, &MahalanobisConfig::default()).unwrap_err();
    assert!(
        matches!(
            err,
            InsightError::DimensionMismatch {
                expected: 2,
                actual: 1
            }
        ),
        "ragged rows must yield DimensionMismatch {{ expected: 2, actual: 1 }}, got {err:?}"
    );
}
/// A NaN cell is rejected as `DegenerateData`, and the message names the
/// offending row and column.
#[test]
fn nan_rejected() {
    let data = vec![
        vec![1.0, 2.0],
        vec![f64::NAN, 3.0],
        vec![4.0, 5.0],
        vec![6.0, 7.0],
    ];
    let err = mahalanobis(&data, &MahalanobisConfig::default()).unwrap_err();
    match &err {
        InsightError::DegenerateData { reason } => {
            let pinpointed = ["row 1", "column 0"]
                .iter()
                .all(|needle| reason.contains(*needle));
            assert!(
                pinpointed,
                "reason should pinpoint the offending cell, got: {reason}"
            );
        }
        other => panic!("expected DegenerateData, got {other:?}"),
    }
}
/// An infinite cell is rejected with the `DegenerateData` variant.
#[test]
fn infinity_rejected_as_degenerate() {
    let data = vec![
        vec![1.0, 2.0],
        vec![3.0, f64::INFINITY],
        vec![4.0, 5.0],
        vec![6.0, 7.0],
    ];
    let outcome = mahalanobis(&data, &MahalanobisConfig::default());
    let err = outcome.unwrap_err();
    let is_degenerate = matches!(err, InsightError::DegenerateData { .. });
    assert!(is_degenerate);
}
/// A constant first column must be reported as `DegenerateData` with a
/// message that mentions the column and explains the cause.
#[test]
fn zero_variance_column_diagnosed() {
    let mut data: Vec<Vec<f64>> = Vec::with_capacity(20);
    for i in 0..20 {
        let t = i as f64 * 0.1;
        // Column 0 never changes; the other two columns vary.
        data.push(vec![5.0, t.sin(), t.cos()]);
    }

    let err = mahalanobis(&data, &MahalanobisConfig::default()).unwrap_err();
    match &err {
        InsightError::DegenerateData { reason } => {
            assert!(
                reason.contains('0'),
                "reason should mention column 0, got: {reason}"
            );
            let explains_cause = reason.contains("variance") || reason.contains("singular");
            assert!(
                explains_cause,
                "reason should explain the cause, got: {reason}"
            );
        }
        other => panic!("expected DegenerateData, got {other:?}"),
    }
}
/// When the second column is exactly twice the first, the input must be
/// rejected as `DegenerateData`.
#[test]
fn collinear_columns_diagnosed_as_singular() {
    let data: Vec<Vec<f64>> = (0..20)
        .map(|i| i as f64 * 0.5)
        .map(|t| vec![t, 2.0 * t])
        .collect();

    let outcome = mahalanobis(&data, &MahalanobisConfig::default());
    let err = outcome.unwrap_err();
    assert!(
        matches!(err, InsightError::DegenerateData { .. }),
        "collinear input must yield DegenerateData, got {err:?}"
    );
}
/// Every reported distance must be greater than or equal to zero.
#[test]
fn distances_non_negative() {
    let mut data: Vec<Vec<f64>> = Vec::with_capacity(15);
    for i in 0..15 {
        let step = i as f64;
        let t = step * 0.5;
        let noise = (step * 1.37).sin() * 0.2;
        data.push(vec![t + noise, t * 1.5 - noise]);
    }

    let result = mahalanobis(&data, &MahalanobisConfig::default()).unwrap();

    for d in result.distances.iter().copied() {
        assert!(d >= 0.0, "distance should be non-negative, got {d}");
    }
}
/// With five columns, a far-away 31st row must still have the largest
/// distance.
#[test]
fn high_dimensional() {
    let mut data: Vec<Vec<f64>> = Vec::with_capacity(31);
    for i in 0..30 {
        let step = i as f64;
        let t = step * 0.1;
        let n = (step * 0.97).sin() * 0.3;
        data.push(vec![t + n, t * 0.5 - n, t * 2.0 + n * 0.5, t.sin(), t.cos()]);
    }
    data.push(vec![100.0, 100.0, 100.0, 100.0, 100.0]);

    let result = mahalanobis(&data, &MahalanobisConfig::default()).unwrap();

    // Index of the row with the largest squared distance.
    let max_idx = result
        .distances
        .iter()
        .enumerate()
        .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap())
        .map(|(idx, _)| idx)
        .unwrap();

    assert_eq!(max_idx, 30);
}
/// A higher configured quantile must not produce a lower threshold.
#[test]
fn custom_quantile() {
    let mut data = make_cluster(&[0.0, 0.0], 30, 1.0);
    data.push(vec![10.0, 10.0]);

    // Runs the detector with the given quantile and returns its threshold.
    let threshold_at = |q: f64| {
        mahalanobis(&data, &MahalanobisConfig::default().chi2_quantile(q))
            .unwrap()
            .threshold
    };

    let strict = threshold_at(0.99);
    let lenient = threshold_at(0.95);
    assert!(strict >= lenient);
}
/// Spot-checks the chi-squared quantile helper against three reference
/// values, each with its own tolerance.
#[test]
fn chi2_quantile_known_values() {
    let q = chi2_quantile(2.0, 0.95);
    let off_by = (q - 5.991).abs();
    assert!(off_by < 0.1, "chi2(2, 0.95) expected ~5.991, got {q}");

    let q1 = chi2_quantile(1.0, 0.95);
    let off_by = (q1 - 3.841).abs();
    assert!(off_by < 0.1, "chi2(1, 0.95) expected ~3.841, got {q1}");

    let q5 = chi2_quantile(5.0, 0.975);
    let off_by = (q5 - 12.833).abs();
    assert!(off_by < 0.2, "chi2(5, 0.975) expected ~12.833, got {q5}");
}
}