use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::Fit;
use ndarray::Array2;
/// How [`MDS`] interprets the matrix passed to `fit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dissimilarity {
    /// Rows of the input are samples; squared Euclidean distances between
    /// rows are computed internally.
    Euclidean,
    /// The input is already a square matrix of pairwise dissimilarities.
    /// Its entries are squared before double-centring.
    Precomputed,
}

impl Default for Dissimilarity {
    /// Defaults to [`Dissimilarity::Euclidean`], the mode used by [`MDS::new`].
    fn default() -> Self {
        Self::Euclidean
    }
}

/// Classical (Torgerson) multidimensional scaling estimator.
///
/// Configure with [`MDS::new`] and [`MDS::with_dissimilarity`], then call
/// `fit` to obtain the low-dimensional embedding. `n_components` is validated
/// at fit time: it must be at least 1 and must not exceed the number of samples.
#[derive(Debug, Clone)]
pub struct MDS {
    /// Number of embedding dimensions to produce.
    n_components: usize,
    /// How the input matrix to `fit` is interpreted.
    dissimilarity: Dissimilarity,
}

impl MDS {
    /// Creates an estimator producing `n_components`-dimensional embeddings
    /// from raw samples (i.e. [`Dissimilarity::Euclidean`]).
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            dissimilarity: Dissimilarity::default(),
        }
    }

    /// Returns a copy of this estimator using the given dissimilarity mode.
    #[must_use]
    pub fn with_dissimilarity(mut self, d: Dissimilarity) -> Self {
        self.dissimilarity = d;
        self
    }

    /// Number of embedding dimensions requested.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }

    /// Dissimilarity mode used to interpret the input to `fit`.
    #[must_use]
    pub fn dissimilarity(&self) -> Dissimilarity {
        self.dissimilarity
    }
}
/// Output of fitting an [`MDS`] estimator: the embedding and its stress.
#[derive(Debug, Clone)]
pub struct FittedMDS {
    /// Embedded coordinates: one row per input sample, one column per component.
    embedding_: Array2<f64>,
    /// Kruskal stress of `embedding_` measured against the input dissimilarities.
    stress_: f64,
}

impl FittedMDS {
    /// Borrows the `n_samples x n_components` embedding.
    #[must_use]
    pub fn embedding(&self) -> &Array2<f64> {
        let coords = &self.embedding_;
        coords
    }

    /// Kruskal stress of the embedding.
    ///
    /// `0.0` means the input distances are reproduced exactly. It is also
    /// returned when every input dissimilarity is zero.
    #[must_use]
    pub fn stress(&self) -> f64 {
        let value = self.stress_;
        value
    }
}
/// Returns the `n x n` matrix of squared Euclidean distances between the rows of `x`.
///
/// Each unordered pair is computed once and mirrored, so the result is
/// symmetric with a zero diagonal.
pub(crate) fn pairwise_sq_distances(x: &Array2<f64>) -> Array2<f64> {
    let n_rows = x.nrows();
    let mut out = Array2::<f64>::zeros((n_rows, n_rows));
    for a in 0..n_rows {
        let row_a = x.row(a);
        for b in (a + 1)..n_rows {
            // Summed in column order, exactly like a plain indexed loop.
            let dist_sq = row_a
                .iter()
                .zip(x.row(b).iter())
                .fold(0.0, |acc, (&p, &q)| {
                    let delta = p - q;
                    acc + delta * delta
                });
            out[[a, b]] = dist_sq;
            out[[b, a]] = dist_sq;
        }
    }
    out
}
fn kruskal_stress(dist_orig: &Array2<f64>, embedding: &Array2<f64>) -> f64 {
let n = embedding.nrows();
let mut numerator = 0.0;
let mut denominator = 0.0;
for i in 0..n {
for j in (i + 1)..n {
let d_orig = dist_orig[[i, j]].sqrt();
let mut sq = 0.0;
for k in 0..embedding.ncols() {
let diff = embedding[[i, k]] - embedding[[j, k]];
sq += diff * diff;
}
let d_embed = sq.sqrt();
let diff = d_orig - d_embed;
numerator += diff * diff;
denominator += d_orig * d_orig;
}
}
if denominator > 0.0 {
(numerator / denominator).sqrt()
} else {
0.0
}
}
/// Symmetric eigendecomposition of `a` using faer.
///
/// Only the lower triangle of `a` is read (`faer::Side::Lower`), so `a` is
/// expected to be symmetric. Only `a.nrows()` is used for the size, so `a`
/// must also be square.
///
/// Returns `(eigenvalues, eigenvectors)`. Column `k` of `eigenvectors` pairs
/// with `eigenvalues[k]`. Callers in this module sort the eigenvalues
/// themselves rather than relying on faer's output order.
///
/// # Errors
/// [`FerroError::NumericalInstability`] if faer reports a failure.
pub(crate) fn eigh_faer(a: &Array2<f64>) -> Result<(Vec<f64>, Array2<f64>), FerroError> {
    let dim = a.nrows();
    let faer_mat = faer::Mat::from_fn(dim, dim, |r, c| a[[r, c]]);
    let evd = faer_mat
        .self_adjoint_eigen(faer::Side::Lower)
        .map_err(|err| FerroError::NumericalInstability {
            message: format!("Symmetric eigendecomposition failed: {err:?}"),
        })?;
    let values = evd
        .S()
        .column_vector()
        .iter()
        .copied()
        .collect::<Vec<f64>>();
    let vectors = Array2::from_shape_fn((dim, dim), |(r, c)| evd.U()[(r, c)]);
    Ok((values, vectors))
}
pub(crate) fn classical_mds(
sq_dist: &Array2<f64>,
n_components: usize,
) -> Result<(Array2<f64>, f64), FerroError> {
let n = sq_dist.nrows();
let n_f = n as f64;
let mut row_means = vec![0.0; n];
let mut col_means = vec![0.0; n];
let mut grand_mean = 0.0;
for i in 0..n {
for j in 0..n {
row_means[i] += sq_dist[[i, j]];
col_means[j] += sq_dist[[i, j]];
grand_mean += sq_dist[[i, j]];
}
}
for i in 0..n {
row_means[i] /= n_f;
col_means[i] /= n_f;
}
grand_mean /= n_f * n_f;
let mut b = Array2::<f64>::zeros((n, n));
for i in 0..n {
for j in 0..n {
b[[i, j]] = -0.5 * (sq_dist[[i, j]] - row_means[i] - col_means[j] + grand_mean);
}
}
let (eigenvalues, eigenvectors) = eigh_faer(&b)?;
let mut indices: Vec<usize> = (0..n).collect();
indices.sort_by(|&a, &b_idx| {
eigenvalues[b_idx]
.partial_cmp(&eigenvalues[a])
.unwrap_or(std::cmp::Ordering::Equal)
});
let n_comp = n_components.min(n);
let mut embedding = Array2::<f64>::zeros((n, n_comp));
for (k, &idx) in indices.iter().take(n_comp).enumerate() {
let eigval = eigenvalues[idx].max(0.0);
let scale = eigval.sqrt();
for i in 0..n {
embedding[[i, k]] = eigenvectors[[i, idx]] * scale;
}
}
let stress = kruskal_stress(sq_dist, &embedding);
Ok((embedding, stress))
}
impl Fit<Array2<f64>, ()> for MDS {
    type Fitted = FittedMDS;
    type Error = FerroError;

    /// Fits classical MDS to `x`.
    ///
    /// - [`Dissimilarity::Euclidean`]: `x` is an `n_samples x n_features` data
    ///   matrix, and squared Euclidean distances between its rows are used.
    /// - [`Dissimilarity::Precomputed`]: `x` must be a square, symmetric matrix
    ///   of pairwise dissimilarities. Its entries are squared before
    ///   double-centring.
    ///
    /// # Errors
    /// - `InvalidParameter` (`n_components`): `n_components` is 0 or exceeds
    ///   the number of samples.
    /// - `InsufficientSamples`: there are fewer than 2 samples.
    /// - `ShapeMismatch`: a precomputed matrix is not square.
    /// - `InvalidParameter` (`x`): `x` contains NaN or infinite values, or a
    ///   precomputed matrix is not symmetric.
    /// - `NumericalInstability`: the eigendecomposition fails.
    fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedMDS, FerroError> {
        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }
        let sq_dist = match self.dissimilarity {
            Dissimilarity::Euclidean => {
                check_sample_count(x.nrows(), self.n_components)?;
                check_all_finite(x)?;
                pairwise_sq_distances(x)
            }
            Dissimilarity::Precomputed => {
                if x.nrows() != x.ncols() {
                    return Err(FerroError::ShapeMismatch {
                        expected: vec![x.nrows(), x.nrows()],
                        actual: vec![x.nrows(), x.ncols()],
                        context: "MDS with Precomputed dissimilarity requires a square matrix"
                            .into(),
                    });
                }
                check_sample_count(x.nrows(), self.n_components)?;
                check_all_finite(x)?;
                // The eigensolver reads only the lower triangle of the centred
                // matrix and the stress reads only the upper triangle of the
                // input. An asymmetric matrix would be silently misread, so
                // it is rejected.
                check_symmetric(x)?;
                x.mapv(|v| v * v)
            }
        };
        let (embedding, stress) = classical_mds(&sq_dist, self.n_components)?;
        Ok(FittedMDS {
            embedding_: embedding,
            stress_: stress,
        })
    }
}

/// Sample-count and `n_components` validation shared by both dissimilarity modes.
fn check_sample_count(n_samples: usize, n_components: usize) -> Result<(), FerroError> {
    if n_samples < 2 {
        return Err(FerroError::InsufficientSamples {
            required: 2,
            actual: n_samples,
            context: "MDS::fit requires at least 2 samples".into(),
        });
    }
    if n_components > n_samples {
        return Err(FerroError::InvalidParameter {
            name: "n_components".into(),
            reason: format!("n_components ({n_components}) exceeds n_samples ({n_samples})"),
        });
    }
    Ok(())
}

/// Rejects inputs containing NaN or infinite values, which would otherwise
/// propagate silently into the embedding.
fn check_all_finite(x: &Array2<f64>) -> Result<(), FerroError> {
    match x.indexed_iter().find(|(_, v)| !v.is_finite()) {
        Some(((i, j), v)) => Err(FerroError::InvalidParameter {
            name: "x".into(),
            reason: format!("contains a non-finite value ({v}) at ({i}, {j})"),
        }),
        None => Ok(()),
    }
}

/// Rejects a square matrix that is not symmetric.
///
/// Two mirrored entries `a` and `b` are accepted when
/// `|a - b| <= 1e-10 + 1e-5 * max(|a|, |b|)`, so round-off noise is tolerated.
fn check_symmetric(x: &Array2<f64>) -> Result<(), FerroError> {
    const ATOL: f64 = 1e-10;
    const RTOL: f64 = 1e-5;
    let n = x.nrows();
    for i in 0..n {
        for j in (i + 1)..n {
            let (upper, lower) = (x[[i, j]], x[[j, i]]);
            if (upper - lower).abs() > ATOL + RTOL * upper.abs().max(lower.abs()) {
                return Err(FerroError::InvalidParameter {
                    name: "x".into(),
                    reason: format!(
                        "precomputed dissimilarity matrix must be symmetric; \
                         x[{i}, {j}] = {upper} but x[{j}, {i}] = {lower}"
                    ),
                });
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Corners of the unit square; exactly embeddable in two dimensions.
    fn square_data() -> Array2<f64> {
        array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0],]
    }

    /// Euclidean distance between rows `i` and `j` of `m`.
    fn row_distance(m: &Array2<f64>, i: usize, j: usize) -> f64 {
        m.row(i)
            .iter()
            .zip(m.row(j).iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum::<f64>()
            .sqrt()
    }

    /// Asserts that every pairwise distance between rows of `x` is reproduced
    /// by the corresponding rows of `emb` to within `eps`.
    fn assert_distances_preserved(x: &Array2<f64>, emb: &Array2<f64>, eps: f64) {
        let orig = pairwise_sq_distances(x);
        let n = x.nrows();
        for i in 0..n {
            for j in (i + 1)..n {
                assert_abs_diff_eq!(orig[[i, j]].sqrt(), row_distance(emb, i, j), epsilon = eps);
            }
        }
    }

    #[test]
    fn test_mds_basic_embedding_shape() {
        let fitted = MDS::new(2).fit(&square_data(), &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (4, 2));
    }

    #[test]
    fn test_mds_1d_embedding() {
        let x = array![[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0],];
        let fitted = MDS::new(1).fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().ncols(), 1);
        // Points on a line are exactly representable in one dimension.
        assert_distances_preserved(&x, fitted.embedding(), 1e-6);
    }

    #[test]
    fn test_mds_stress_non_negative() {
        let fitted = MDS::new(2).fit(&square_data(), &()).unwrap();
        assert!(fitted.stress() >= 0.0);
    }

    #[test]
    fn test_mds_perfect_embedding_low_stress() {
        let fitted = MDS::new(2).fit(&square_data(), &()).unwrap();
        // Planar data embedded in 2-D: classical MDS is exact up to round-off.
        assert!(fitted.stress() < 1e-6, "stress = {}", fitted.stress());
    }

    #[test]
    fn test_mds_preserves_distances() {
        let x = square_data();
        let fitted = MDS::new(2).fit(&x, &()).unwrap();
        assert_distances_preserved(&x, fitted.embedding(), 1e-6);
    }

    #[test]
    fn test_mds_precomputed() {
        let x = square_data();
        let dist = pairwise_sq_distances(&x).mapv(f64::sqrt);
        let mds = MDS::new(2).with_dissimilarity(Dissimilarity::Precomputed);
        let fitted = mds.fit(&dist, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (4, 2));
        // Precomputed distances of the same points must give the same geometry.
        assert_distances_preserved(&x, fitted.embedding(), 1e-6);
    }

    #[test]
    fn test_mds_invalid_n_components_zero() {
        let err = MDS::new(0).fit(&square_data(), &()).unwrap_err();
        assert!(matches!(err, FerroError::InvalidParameter { .. }));
    }

    #[test]
    fn test_mds_invalid_n_components_too_large() {
        let err = MDS::new(10).fit(&square_data(), &()).unwrap_err();
        assert!(matches!(err, FerroError::InvalidParameter { .. }));
    }

    #[test]
    fn test_mds_insufficient_samples() {
        let x = array![[1.0, 2.0]];
        let err = MDS::new(1).fit(&x, &()).unwrap_err();
        assert!(matches!(err, FerroError::InsufficientSamples { .. }));
    }

    #[test]
    fn test_mds_precomputed_not_square() {
        let mds = MDS::new(1).with_dissimilarity(Dissimilarity::Precomputed);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let err = mds.fit(&x, &()).unwrap_err();
        assert!(matches!(err, FerroError::ShapeMismatch { .. }));
    }

    #[test]
    fn test_mds_collinear_data() {
        let x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0],];
        let fitted = MDS::new(1).fit(&x, &()).unwrap();
        let emb = fitted.embedding();
        assert_eq!(emb.ncols(), 1);
        let mut vals: Vec<f64> = emb.column(0).to_vec();
        vals.sort_by(f64::total_cmp);
        // Consecutive points are sqrt(2) apart along the diagonal.
        for w in vals.windows(2) {
            assert_abs_diff_eq!(w[1] - w[0], 2.0_f64.sqrt(), epsilon = 1e-6);
        }
    }

    #[test]
    fn test_mds_getters() {
        let mds = MDS::new(3).with_dissimilarity(Dissimilarity::Precomputed);
        assert_eq!(mds.n_components(), 3);
        assert_eq!(mds.dissimilarity(), Dissimilarity::Precomputed);
        assert_eq!(MDS::new(2).dissimilarity(), Dissimilarity::Euclidean);
    }

    #[test]
    fn test_mds_larger_dataset() {
        let (n, d) = (20, 5);
        let data = Array2::from_shape_fn((n, d), |(i, j)| (i * d + j) as f64 / (n * d) as f64);
        let fitted = MDS::new(2).fit(&data, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (20, 2));
        assert!(fitted.stress() >= 0.0);
        assert!(fitted.embedding().iter().all(|v| v.is_finite()));
    }
}