Skip to main content

ferrolearn_decomp/
mds.rs

1//! Multidimensional Scaling (MDS).
2//!
3//! Classical (metric) MDS embeds data into a low-dimensional space such that
4//! pairwise distances are preserved as well as possible.
5//!
6//! # Algorithm
7//!
//! 1. Compute the pairwise squared-distance matrix `D^2` from the feature
//!    matrix (or, for a precomputed distance matrix, square it element-wise).
10//! 2. Double-centre `D^2`:  `B = -0.5 * J D^2 J`  where `J = I - (1/n) 11^T`.
11//! 3. Eigendecompose `B` and retain the top `n_components` eigenvectors
12//!    scaled by the square root of their eigenvalues.
13//!
14//! # Examples
15//!
16//! ```
17//! use ferrolearn_decomp::{MDS, Dissimilarity};
18//! use ferrolearn_core::traits::Fit;
19//! use ndarray::array;
20//!
21//! let mds = MDS::new(2);
22//! let x = array![
23//!     [0.0, 0.0],
24//!     [1.0, 0.0],
25//!     [0.0, 1.0],
26//!     [1.0, 1.0],
27//! ];
28//! let fitted = mds.fit(&x, &()).unwrap();
29//! assert_eq!(fitted.embedding().ncols(), 2);
30//! ```
31
32use ferrolearn_core::error::FerroError;
33use ferrolearn_core::traits::Fit;
34use ndarray::Array2;
35
36// ---------------------------------------------------------------------------
37// Dissimilarity type
38// ---------------------------------------------------------------------------
39
/// Interpretation of the matrix handed to [`MDS`].
///
/// The default is [`Dissimilarity::Euclidean`], the same mode that
/// [`MDS::new`] selects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Dissimilarity {
    /// The input is a feature matrix (one sample per row); pairwise Euclidean
    /// distances are computed internally.
    #[default]
    Euclidean,
    /// The input is already a square matrix of pairwise distances. The
    /// entries are plain (not squared) distances; they are squared during
    /// fitting.
    Precomputed,
}
49
50// ---------------------------------------------------------------------------
51// MDS (unfitted)
52// ---------------------------------------------------------------------------
53
54/// Classical Multidimensional Scaling configuration.
55///
56/// Holds hyperparameters for the MDS algorithm. Call [`Fit::fit`] to compute
57/// the embedding and obtain a [`FittedMDS`].
58#[derive(Debug, Clone)]
59pub struct MDS {
60    /// Number of embedding dimensions.
61    n_components: usize,
62    /// Whether input is a feature matrix or a precomputed distance matrix.
63    dissimilarity: Dissimilarity,
64}
65
66impl MDS {
67    /// Create a new `MDS` that embeds into `n_components` dimensions.
68    ///
69    /// By default the input is treated as a feature matrix
70    /// ([`Dissimilarity::Euclidean`]).
71    #[must_use]
72    pub fn new(n_components: usize) -> Self {
73        Self {
74            n_components,
75            dissimilarity: Dissimilarity::Euclidean,
76        }
77    }
78
79    /// Set the dissimilarity mode.
80    #[must_use]
81    pub fn with_dissimilarity(mut self, d: Dissimilarity) -> Self {
82        self.dissimilarity = d;
83        self
84    }
85
86    /// Return the configured number of components.
87    #[must_use]
88    pub fn n_components(&self) -> usize {
89        self.n_components
90    }
91
92    /// Return the configured dissimilarity mode.
93    #[must_use]
94    pub fn dissimilarity(&self) -> Dissimilarity {
95        self.dissimilarity
96    }
97}
98
99// ---------------------------------------------------------------------------
100// FittedMDS
101// ---------------------------------------------------------------------------
102
103/// A fitted MDS model holding the learned embedding.
104///
105/// Created by calling [`Fit::fit`] on an [`MDS`].
106#[derive(Debug, Clone)]
107pub struct FittedMDS {
108    /// The embedding, shape `(n_samples, n_components)`.
109    embedding_: Array2<f64>,
110    /// Kruskal's stress-1 measure of fit quality.
111    stress_: f64,
112}
113
114impl FittedMDS {
115    /// The embedding coordinates, shape `(n_samples, n_components)`.
116    #[must_use]
117    pub fn embedding(&self) -> &Array2<f64> {
118        &self.embedding_
119    }
120
121    /// Kruskal's stress-1 (lower is better).
122    #[must_use]
123    pub fn stress(&self) -> f64 {
124        self.stress_
125    }
126}
127
128// ---------------------------------------------------------------------------
129// Helper: pairwise squared-Euclidean distance matrix
130// ---------------------------------------------------------------------------
131
132/// Compute the pairwise squared-Euclidean distance matrix.
133pub(crate) fn pairwise_sq_distances(x: &Array2<f64>) -> Array2<f64> {
134    let n = x.nrows();
135    let mut d = Array2::<f64>::zeros((n, n));
136    for i in 0..n {
137        for j in (i + 1)..n {
138            let mut sq = 0.0;
139            for k in 0..x.ncols() {
140                let diff = x[[i, k]] - x[[j, k]];
141                sq += diff * diff;
142            }
143            d[[i, j]] = sq;
144            d[[j, i]] = sq;
145        }
146    }
147    d
148}
149
150/// Compute Kruskal's stress-1.
151fn kruskal_stress(dist_orig: &Array2<f64>, embedding: &Array2<f64>) -> f64 {
152    let n = embedding.nrows();
153    let mut numerator = 0.0;
154    let mut denominator = 0.0;
155    for i in 0..n {
156        for j in (i + 1)..n {
157            let d_orig = dist_orig[[i, j]].sqrt();
158            let mut sq = 0.0;
159            for k in 0..embedding.ncols() {
160                let diff = embedding[[i, k]] - embedding[[j, k]];
161                sq += diff * diff;
162            }
163            let d_embed = sq.sqrt();
164            let diff = d_orig - d_embed;
165            numerator += diff * diff;
166            denominator += d_orig * d_orig;
167        }
168    }
169    if denominator > 0.0 {
170        (numerator / denominator).sqrt()
171    } else {
172        0.0
173    }
174}
175
176/// Eigendecompose a symmetric matrix using faer's self-adjoint eigen.
177pub(crate) fn eigh_faer(a: &Array2<f64>) -> Result<(Vec<f64>, Array2<f64>), FerroError> {
178    let n = a.nrows();
179    let mat = faer::Mat::from_fn(n, n, |i, j| a[[i, j]]);
180    let decomp = mat.self_adjoint_eigen(faer::Side::Lower).map_err(|e| {
181        FerroError::NumericalInstability {
182            message: format!("Symmetric eigendecomposition failed: {e:?}"),
183        }
184    })?;
185
186    let eigenvalues: Vec<f64> = decomp.S().column_vector().iter().copied().collect();
187    let eigenvectors = Array2::from_shape_fn((n, n), |(i, j)| decomp.U()[(i, j)]);
188
189    Ok((eigenvalues, eigenvectors))
190}
191
192/// Core classical MDS on a squared-distance matrix.
193///
194/// Returns `(embedding, stress)`.
195pub(crate) fn classical_mds(
196    sq_dist: &Array2<f64>,
197    n_components: usize,
198) -> Result<(Array2<f64>, f64), FerroError> {
199    let n = sq_dist.nrows();
200
201    // Double-centre: B = -0.5 * J * D^2 * J, where J = I - (1/n) * 11^T
202    let n_f = n as f64;
203    let mut row_means = vec![0.0; n];
204    let mut col_means = vec![0.0; n];
205    let mut grand_mean = 0.0;
206
207    for i in 0..n {
208        for j in 0..n {
209            row_means[i] += sq_dist[[i, j]];
210            col_means[j] += sq_dist[[i, j]];
211            grand_mean += sq_dist[[i, j]];
212        }
213    }
214    for i in 0..n {
215        row_means[i] /= n_f;
216        col_means[i] /= n_f;
217    }
218    grand_mean /= n_f * n_f;
219
220    let mut b = Array2::<f64>::zeros((n, n));
221    for i in 0..n {
222        for j in 0..n {
223            b[[i, j]] = -0.5 * (sq_dist[[i, j]] - row_means[i] - col_means[j] + grand_mean);
224        }
225    }
226
227    // Eigendecompose B
228    let (eigenvalues, eigenvectors) = eigh_faer(&b)?;
229
230    // Sort eigenvalues descending
231    let mut indices: Vec<usize> = (0..n).collect();
232    indices.sort_by(|&a, &b_idx| {
233        eigenvalues[b_idx]
234            .partial_cmp(&eigenvalues[a])
235            .unwrap_or(std::cmp::Ordering::Equal)
236    });
237
238    // Build embedding: X_k = v_k * sqrt(lambda_k)
239    let n_comp = n_components.min(n);
240    let mut embedding = Array2::<f64>::zeros((n, n_comp));
241    for (k, &idx) in indices.iter().take(n_comp).enumerate() {
242        let eigval = eigenvalues[idx].max(0.0);
243        let scale = eigval.sqrt();
244        for i in 0..n {
245            embedding[[i, k]] = eigenvectors[[i, idx]] * scale;
246        }
247    }
248
249    // Compute stress
250    let stress = kruskal_stress(sq_dist, &embedding);
251
252    Ok((embedding, stress))
253}
254
255// ---------------------------------------------------------------------------
256// Trait implementations
257// ---------------------------------------------------------------------------
258
259impl Fit<Array2<f64>, ()> for MDS {
260    type Fitted = FittedMDS;
261    type Error = FerroError;
262
263    /// Fit classical MDS.
264    ///
265    /// # Errors
266    ///
267    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or too large.
268    /// - [`FerroError::InsufficientSamples`] if there are fewer than 2 samples.
269    /// - [`FerroError::ShapeMismatch`] if `Precomputed` is set but the matrix
270    ///   is not square.
271    fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedMDS, FerroError> {
272        if self.n_components == 0 {
273            return Err(FerroError::InvalidParameter {
274                name: "n_components".into(),
275                reason: "must be at least 1".into(),
276            });
277        }
278
279        let sq_dist = match self.dissimilarity {
280            Dissimilarity::Euclidean => {
281                let n_samples = x.nrows();
282                if n_samples < 2 {
283                    return Err(FerroError::InsufficientSamples {
284                        required: 2,
285                        actual: n_samples,
286                        context: "MDS::fit requires at least 2 samples".into(),
287                    });
288                }
289                if self.n_components > n_samples {
290                    return Err(FerroError::InvalidParameter {
291                        name: "n_components".into(),
292                        reason: format!(
293                            "n_components ({}) exceeds n_samples ({})",
294                            self.n_components, n_samples
295                        ),
296                    });
297                }
298                pairwise_sq_distances(x)
299            }
300            Dissimilarity::Precomputed => {
301                if x.nrows() != x.ncols() {
302                    return Err(FerroError::ShapeMismatch {
303                        expected: vec![x.nrows(), x.nrows()],
304                        actual: vec![x.nrows(), x.ncols()],
305                        context: "MDS with Precomputed dissimilarity requires a square matrix"
306                            .into(),
307                    });
308                }
309                let n = x.nrows();
310                if n < 2 {
311                    return Err(FerroError::InsufficientSamples {
312                        required: 2,
313                        actual: n,
314                        context: "MDS::fit requires at least 2 samples".into(),
315                    });
316                }
317                if self.n_components > n {
318                    return Err(FerroError::InvalidParameter {
319                        name: "n_components".into(),
320                        reason: format!(
321                            "n_components ({}) exceeds n_samples ({})",
322                            self.n_components, n
323                        ),
324                    });
325                }
326                // Input is already distances; square them for classical MDS
327                x.mapv(|v| v * v)
328            }
329        };
330
331        let (embedding, stress) = classical_mds(&sq_dist, self.n_components)?;
332
333        Ok(FittedMDS {
334            embedding_: embedding,
335            stress_: stress,
336        })
337    }
338}
339
340// ---------------------------------------------------------------------------
341// Tests
342// ---------------------------------------------------------------------------
343
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Tolerance for quantities that classical MDS reproduces exactly up to
    /// floating-point round-off (Euclidean input embedded at full rank).
    const TOL: f64 = 1e-8;

    /// Helper: the four corners of the unit square.
    fn square_data() -> Array2<f64> {
        array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    }

    /// Helper: Euclidean distance between rows `i` and `j` of `points`.
    fn row_distance(points: &Array2<f64>, i: usize, j: usize) -> f64 {
        let a = points.row(i);
        let b = points.row(j);
        let sq: f64 = a.iter().zip(b.iter()).map(|(p, q)| (p - q) * (p - q)).sum();
        sq.sqrt()
    }

    /// Helper: assert that every pairwise distance in `emb` matches the
    /// corresponding distance in `reference`.
    fn assert_distances_preserved(reference: &Array2<f64>, emb: &Array2<f64>) {
        let n = reference.nrows();
        assert_eq!(emb.nrows(), n);
        for i in 0..n {
            for j in (i + 1)..n {
                assert_abs_diff_eq!(
                    row_distance(reference, i, j),
                    row_distance(emb, i, j),
                    epsilon = TOL
                );
            }
        }
    }

    #[test]
    fn test_mds_basic_embedding_shape() {
        let mds = MDS::new(2);
        let x = square_data();
        let fitted = mds.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (4, 2));
    }

    #[test]
    fn test_mds_1d_embedding() {
        // Points already lie on a line, so one component is lossless.
        let mds = MDS::new(1);
        let x = array![[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]];
        let fitted = mds.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (4, 1));
        assert!(fitted.stress() < TOL, "stress = {}", fitted.stress());
    }

    #[test]
    fn test_mds_stress_non_negative() {
        let mds = MDS::new(2);
        let x = square_data();
        let fitted = mds.fit(&x, &()).unwrap();
        assert!(fitted.stress().is_finite());
        assert!(fitted.stress() >= 0.0);
    }

    #[test]
    fn test_mds_perfect_embedding_low_stress() {
        // 2D points embedded into 2D are reproduced exactly, so the stress
        // must vanish up to round-off.
        let mds = MDS::new(2);
        let x = square_data();
        let fitted = mds.fit(&x, &()).unwrap();
        assert!(fitted.stress() < TOL, "stress = {}", fitted.stress());
    }

    #[test]
    fn test_mds_preserves_distances() {
        let mds = MDS::new(2);
        let x = square_data();
        let fitted = mds.fit(&x, &()).unwrap();
        assert_distances_preserved(&x, fitted.embedding());
    }

    #[test]
    fn test_mds_precomputed() {
        // Build a precomputed (non-squared) distance matrix.
        let x = square_data();
        let dist = pairwise_sq_distances(&x).mapv(f64::sqrt);

        let mds = MDS::new(2).with_dissimilarity(Dissimilarity::Precomputed);
        let fitted = mds.fit(&dist, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (4, 2));

        // The precomputed path must recover the same geometry as the
        // feature-matrix path (coordinates may differ by a rotation, so
        // compare distances rather than coordinates).
        assert_distances_preserved(&x, fitted.embedding());
        assert!(fitted.stress() < TOL, "stress = {}", fitted.stress());
    }

    #[test]
    fn test_mds_invalid_n_components_zero() {
        let result = MDS::new(0).fit(&square_data(), &());
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    #[test]
    fn test_mds_invalid_n_components_too_large() {
        // 4 samples cannot support 10 components.
        let result = MDS::new(10).fit(&square_data(), &());
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    #[test]
    fn test_mds_insufficient_samples() {
        let x = array![[1.0, 2.0]]; // 1 sample
        let result = MDS::new(1).fit(&x, &());
        assert!(matches!(
            result,
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    #[test]
    fn test_mds_precomputed_not_square() {
        let mds = MDS::new(1).with_dissimilarity(Dissimilarity::Precomputed);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let result = mds.fit(&x, &());
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_mds_collinear_data() {
        // Points on the diagonal, spaced sqrt(2) apart, embed exactly in 1D.
        let mds = MDS::new(1);
        let x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]];
        let fitted = mds.fit(&x, &()).unwrap();
        let emb = fitted.embedding();
        assert_eq!(emb.dim(), (5, 1));

        // The sign of the axis is arbitrary, so sort before differencing.
        let mut vals: Vec<f64> = emb.column(0).iter().copied().collect();
        vals.sort_by(f64::total_cmp);
        for pair in vals.windows(2) {
            assert_abs_diff_eq!(
                pair[1] - pair[0],
                std::f64::consts::SQRT_2,
                epsilon = TOL
            );
        }
    }

    #[test]
    fn test_mds_getters() {
        let mds = MDS::new(3).with_dissimilarity(Dissimilarity::Precomputed);
        assert_eq!(mds.n_components(), 3);
        assert_eq!(mds.dissimilarity(), Dissimilarity::Precomputed);
    }

    #[test]
    fn test_mds_larger_dataset() {
        let n = 20;
        let d = 5;
        let mut data = Array2::<f64>::zeros((n, d));
        for i in 0..n {
            for j in 0..d {
                data[[i, j]] = (i * d + j) as f64 / (n * d) as f64;
            }
        }
        let mds = MDS::new(2);
        let fitted = mds.fit(&data, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (20, 2));
        assert!(fitted.stress().is_finite());
        assert!(fitted.stress() >= 0.0);
        // Every row is the first row shifted along (1, 1, 1, 1, 1), so the
        // data are collinear and two components reproduce them exactly.
        assert!(fitted.stress() < 1e-6, "stress = {}", fitted.stress());
    }
}