use anofox_ml_core::{FitUnsupervised, Predict, Result, RustMlError};
use ndarray::{Array1, Array2};
#[derive(Debug, Clone)]
pub struct MeanShift {
pub bandwidth: f64,
pub max_iter: usize,
pub tol: f64,
}
impl MeanShift {
pub fn new(bandwidth: f64) -> Self {
Self {
bandwidth,
max_iter: 300,
tol: 1e-3,
}
}
pub fn with_max_iter(mut self, m: usize) -> Self {
self.max_iter = m;
self
}
pub fn with_tol(mut self, t: f64) -> Self {
self.tol = t;
self
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FittedMeanShift {
pub cluster_centers: Array2<f64>,
pub labels: Array1<f64>,
}
fn sq_euclid(a: &[f64], b: &[f64]) -> f64 {
a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
}
fn shift_one(seed: &[f64], x: &Array2<f64>, bw_sq: f64) -> Vec<f64> {
let d = seed.len();
let mut num = vec![0.0_f64; d];
let mut den = 0.0_f64;
for i in 0..x.nrows() {
let xi: Vec<f64> = x.row(i).iter().copied().collect();
let sd = sq_euclid(seed, &xi);
if sd <= bw_sq {
for j in 0..d {
num[j] += xi[j];
}
den += 1.0;
}
}
if den == 0.0 {
seed.to_vec()
} else {
num.iter().map(|v| v / den).collect()
}
}
impl FitUnsupervised<f64> for MeanShift {
type Fitted = FittedMeanShift;
fn fit(&self, x: &Array2<f64>) -> Result<Self::Fitted> {
let n = x.nrows();
if n == 0 {
return Err(RustMlError::EmptyInput("empty input".into()));
}
if self.bandwidth <= 0.0 {
return Err(RustMlError::InvalidParameter(
"bandwidth must be > 0".into(),
));
}
let d = x.ncols();
let bw_sq = self.bandwidth * self.bandwidth;
let mut modes: Vec<Vec<f64>> = Vec::with_capacity(n);
for i in 0..n {
let mut cur: Vec<f64> = x.row(i).iter().copied().collect();
for _ in 0..self.max_iter {
let next = shift_one(&cur, x, bw_sq);
let shift = sq_euclid(&cur, &next).sqrt();
cur = next;
if shift < self.tol {
break;
}
}
modes.push(cur);
}
let merge_radius_sq = (self.bandwidth / 2.0).powi(2);
let mut centers: Vec<Vec<f64>> = Vec::new();
let mut labels = Array1::<f64>::zeros(n);
for (i, m) in modes.iter().enumerate() {
let mut assigned = None;
for (k, c) in centers.iter().enumerate() {
if sq_euclid(m, c) < merge_radius_sq {
assigned = Some(k);
break;
}
}
match assigned {
Some(k) => labels[i] = k as f64,
None => {
labels[i] = centers.len() as f64;
centers.push(m.clone());
}
}
}
let n_clusters = centers.len();
let mut cc = Array2::<f64>::zeros((n_clusters, d));
for (k, c) in centers.iter().enumerate() {
for j in 0..d {
cc[[k, j]] = c[j];
}
}
Ok(FittedMeanShift {
cluster_centers: cc,
labels,
})
}
}
impl Predict<f64> for FittedMeanShift {
fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
let d = self.cluster_centers.ncols();
if x.ncols() != d {
return Err(RustMlError::ShapeMismatch(format!(
"expected {} features, got {}",
d,
x.ncols()
)));
}
let n = x.nrows();
let mut out = Array1::<f64>::zeros(n);
for i in 0..n {
let mut best = f64::INFINITY;
let mut best_k = 0;
let xi: Vec<f64> = x.row(i).iter().copied().collect();
for k in 0..self.cluster_centers.nrows() {
let ck: Vec<f64> = self.cluster_centers.row(k).iter().copied().collect();
let d2 = sq_euclid(&xi, &ck);
if d2 < best {
best = d2;
best_k = k;
}
}
out[i] = best_k as f64;
}
Ok(out)
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
#[test]
fn test_mean_shift_two_blobs() {
let x = array![
[0.0_f64, 0.0],
[0.1, 0.1],
[-0.1, 0.2],
[0.1, -0.1],
[10.0, 10.0],
[10.1, 9.9],
[9.9, 10.2],
[10.0, 10.1],
];
let ms = MeanShift::new(2.0);
let fitted = ms.fit(&x).unwrap();
assert_eq!(fitted.cluster_centers.nrows(), 2);
let l0 = fitted.labels[0];
for i in 1..4 {
assert_eq!(fitted.labels[i], l0);
}
for i in 4..8 {
assert_ne!(fitted.labels[i], l0);
}
}
#[test]
fn test_mean_shift_centers_near_blob_means() {
let x = array![
[0.0_f64, 0.0],
[0.1, 0.1],
[-0.1, 0.2],
[0.1, -0.1],
[10.0, 10.0],
[10.1, 9.9],
[9.9, 10.2],
[10.0, 10.1],
];
let fitted = MeanShift::new(2.0).fit(&x).unwrap();
let mut has_low = false;
let mut has_high = false;
for k in 0..fitted.cluster_centers.nrows() {
let cx = fitted.cluster_centers[[k, 0]];
if cx.abs() < 1.0 {
has_low = true;
}
if (cx - 10.0).abs() < 1.0 {
has_high = true;
}
}
assert!(has_low && has_high);
}
}
impl anofox_ml_core::ClassifierScore<f64> for FittedMeanShift {}