1use ferrolearn_core::error::FerroError;
33use ferrolearn_core::traits::Fit;
34use ndarray::Array2;
35
/// How the matrix handed to [`MDS`] is interpreted when fitting.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Dissimilarity {
    /// The input rows are samples; fitting computes the pairwise squared
    /// Euclidean distances between them.
    Euclidean,
    /// The input is a square matrix of dissimilarities; fitting squares each
    /// entry before embedding.
    Precomputed,
}
49
50#[derive(Debug, Clone)]
59pub struct MDS {
60 n_components: usize,
62 dissimilarity: Dissimilarity,
64}
65
66impl MDS {
67 #[must_use]
72 pub fn new(n_components: usize) -> Self {
73 Self {
74 n_components,
75 dissimilarity: Dissimilarity::Euclidean,
76 }
77 }
78
79 #[must_use]
81 pub fn with_dissimilarity(mut self, d: Dissimilarity) -> Self {
82 self.dissimilarity = d;
83 self
84 }
85
86 #[must_use]
88 pub fn n_components(&self) -> usize {
89 self.n_components
90 }
91
92 #[must_use]
94 pub fn dissimilarity(&self) -> Dissimilarity {
95 self.dissimilarity
96 }
97}
98
99#[derive(Debug, Clone)]
107pub struct FittedMDS {
108 embedding_: Array2<f64>,
110 stress_: f64,
112}
113
114impl FittedMDS {
115 #[must_use]
117 pub fn embedding(&self) -> &Array2<f64> {
118 &self.embedding_
119 }
120
121 #[must_use]
123 pub fn stress(&self) -> f64 {
124 self.stress_
125 }
126}
127
128pub(crate) fn pairwise_sq_distances(x: &Array2<f64>) -> Array2<f64> {
134 let n = x.nrows();
135 let mut d = Array2::<f64>::zeros((n, n));
136 for i in 0..n {
137 for j in (i + 1)..n {
138 let mut sq = 0.0;
139 for k in 0..x.ncols() {
140 let diff = x[[i, k]] - x[[j, k]];
141 sq += diff * diff;
142 }
143 d[[i, j]] = sq;
144 d[[j, i]] = sq;
145 }
146 }
147 d
148}
149
150fn kruskal_stress(dist_orig: &Array2<f64>, embedding: &Array2<f64>) -> f64 {
152 let n = embedding.nrows();
153 let mut numerator = 0.0;
154 let mut denominator = 0.0;
155 for i in 0..n {
156 for j in (i + 1)..n {
157 let d_orig = dist_orig[[i, j]].sqrt();
158 let mut sq = 0.0;
159 for k in 0..embedding.ncols() {
160 let diff = embedding[[i, k]] - embedding[[j, k]];
161 sq += diff * diff;
162 }
163 let d_embed = sq.sqrt();
164 let diff = d_orig - d_embed;
165 numerator += diff * diff;
166 denominator += d_orig * d_orig;
167 }
168 }
169 if denominator > 0.0 {
170 (numerator / denominator).sqrt()
171 } else {
172 0.0
173 }
174}
175
176pub(crate) fn eigh_faer(a: &Array2<f64>) -> Result<(Vec<f64>, Array2<f64>), FerroError> {
178 let n = a.nrows();
179 let mat = faer::Mat::from_fn(n, n, |i, j| a[[i, j]]);
180 let decomp = mat.self_adjoint_eigen(faer::Side::Lower).map_err(|e| {
181 FerroError::NumericalInstability {
182 message: format!("Symmetric eigendecomposition failed: {e:?}"),
183 }
184 })?;
185
186 let eigenvalues: Vec<f64> = decomp.S().column_vector().iter().copied().collect();
187 let eigenvectors = Array2::from_shape_fn((n, n), |(i, j)| decomp.U()[(i, j)]);
188
189 Ok((eigenvalues, eigenvectors))
190}
191
192pub(crate) fn classical_mds(
196 sq_dist: &Array2<f64>,
197 n_components: usize,
198) -> Result<(Array2<f64>, f64), FerroError> {
199 let n = sq_dist.nrows();
200
201 let n_f = n as f64;
203 let mut row_means = vec![0.0; n];
204 let mut col_means = vec![0.0; n];
205 let mut grand_mean = 0.0;
206
207 for i in 0..n {
208 for j in 0..n {
209 row_means[i] += sq_dist[[i, j]];
210 col_means[j] += sq_dist[[i, j]];
211 grand_mean += sq_dist[[i, j]];
212 }
213 }
214 for i in 0..n {
215 row_means[i] /= n_f;
216 col_means[i] /= n_f;
217 }
218 grand_mean /= n_f * n_f;
219
220 let mut b = Array2::<f64>::zeros((n, n));
221 for i in 0..n {
222 for j in 0..n {
223 b[[i, j]] = -0.5 * (sq_dist[[i, j]] - row_means[i] - col_means[j] + grand_mean);
224 }
225 }
226
227 let (eigenvalues, eigenvectors) = eigh_faer(&b)?;
229
230 let mut indices: Vec<usize> = (0..n).collect();
232 indices.sort_by(|&a, &b_idx| {
233 eigenvalues[b_idx]
234 .partial_cmp(&eigenvalues[a])
235 .unwrap_or(std::cmp::Ordering::Equal)
236 });
237
238 let n_comp = n_components.min(n);
240 let mut embedding = Array2::<f64>::zeros((n, n_comp));
241 for (k, &idx) in indices.iter().take(n_comp).enumerate() {
242 let eigval = eigenvalues[idx].max(0.0);
243 let scale = eigval.sqrt();
244 for i in 0..n {
245 embedding[[i, k]] = eigenvectors[[i, idx]] * scale;
246 }
247 }
248
249 let stress = kruskal_stress(sq_dist, &embedding);
251
252 Ok((embedding, stress))
253}
254
255impl Fit<Array2<f64>, ()> for MDS {
260 type Fitted = FittedMDS;
261 type Error = FerroError;
262
263 fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedMDS, FerroError> {
272 if self.n_components == 0 {
273 return Err(FerroError::InvalidParameter {
274 name: "n_components".into(),
275 reason: "must be at least 1".into(),
276 });
277 }
278
279 let sq_dist = match self.dissimilarity {
280 Dissimilarity::Euclidean => {
281 let n_samples = x.nrows();
282 if n_samples < 2 {
283 return Err(FerroError::InsufficientSamples {
284 required: 2,
285 actual: n_samples,
286 context: "MDS::fit requires at least 2 samples".into(),
287 });
288 }
289 if self.n_components > n_samples {
290 return Err(FerroError::InvalidParameter {
291 name: "n_components".into(),
292 reason: format!(
293 "n_components ({}) exceeds n_samples ({})",
294 self.n_components, n_samples
295 ),
296 });
297 }
298 pairwise_sq_distances(x)
299 }
300 Dissimilarity::Precomputed => {
301 if x.nrows() != x.ncols() {
302 return Err(FerroError::ShapeMismatch {
303 expected: vec![x.nrows(), x.nrows()],
304 actual: vec![x.nrows(), x.ncols()],
305 context: "MDS with Precomputed dissimilarity requires a square matrix"
306 .into(),
307 });
308 }
309 let n = x.nrows();
310 if n < 2 {
311 return Err(FerroError::InsufficientSamples {
312 required: 2,
313 actual: n,
314 context: "MDS::fit requires at least 2 samples".into(),
315 });
316 }
317 if self.n_components > n {
318 return Err(FerroError::InvalidParameter {
319 name: "n_components".into(),
320 reason: format!(
321 "n_components ({}) exceeds n_samples ({})",
322 self.n_components, n
323 ),
324 });
325 }
326 x.mapv(|v| v * v)
328 }
329 };
330
331 let (embedding, stress) = classical_mds(&sq_dist, self.n_components)?;
332
333 Ok(FittedMDS {
334 embedding_: embedding,
335 stress_: stress,
336 })
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Corners of the unit square in the plane.
    fn square_data() -> Array2<f64> {
        array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    }

    /// Fits `model` on `x`, panicking if fitting fails.
    fn fit_ok(model: &MDS, x: &Array2<f64>) -> FittedMDS {
        model.fit(x, &()).unwrap()
    }

    /// Euclidean distance between rows `a` and `b` of `points`.
    fn row_distance(points: &Array2<f64>, a: usize, b: usize) -> f64 {
        let mut sq = 0.0;
        for k in 0..points.ncols() {
            let diff = points[[a, k]] - points[[b, k]];
            sq += diff * diff;
        }
        sq.sqrt()
    }

    #[test]
    fn test_mds_basic_embedding_shape() {
        let fitted = fit_ok(&MDS::new(2), &square_data());
        assert_eq!(fitted.embedding().dim(), (4, 2));
    }

    #[test]
    fn test_mds_1d_embedding() {
        let line = array![[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]];
        let fitted = fit_ok(&MDS::new(1), &line);
        assert_eq!(fitted.embedding().ncols(), 1);
    }

    #[test]
    fn test_mds_stress_non_negative() {
        let fitted = fit_ok(&MDS::new(2), &square_data());
        assert!(fitted.stress() >= 0.0);
    }

    #[test]
    fn test_mds_perfect_embedding_low_stress() {
        // 2-D data embedded into 2-D should lose very little.
        let fitted = fit_ok(&MDS::new(2), &square_data());
        assert!(fitted.stress() < 0.1, "stress = {}", fitted.stress());
    }

    #[test]
    fn test_mds_preserves_distances() {
        let points = square_data();
        let fitted = fit_ok(&MDS::new(2), &points);
        let emb = fitted.embedding();
        let orig = pairwise_sq_distances(&points);

        for i in 0..4 {
            for j in (i + 1)..4 {
                let d_orig = orig[[i, j]].sqrt();
                let d_emb = row_distance(emb, i, j);
                assert_abs_diff_eq!(d_orig, d_emb, epsilon = 0.3);
            }
        }
    }

    #[test]
    fn test_mds_precomputed() {
        // Precomputed mode takes plain (unsquared) distances.
        let dist = pairwise_sq_distances(&square_data()).mapv(f64::sqrt);

        let model = MDS::new(2).with_dissimilarity(Dissimilarity::Precomputed);
        let fitted = fit_ok(&model, &dist);
        assert_eq!(fitted.embedding().dim(), (4, 2));
    }

    #[test]
    fn test_mds_invalid_n_components_zero() {
        let outcome = MDS::new(0).fit(&square_data(), &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_mds_invalid_n_components_too_large() {
        // Only four samples are available.
        let outcome = MDS::new(10).fit(&square_data(), &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_mds_insufficient_samples() {
        // A single sample.
        let lone = array![[1.0, 2.0]];
        let outcome = MDS::new(1).fit(&lone, &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_mds_precomputed_not_square() {
        let model = MDS::new(1).with_dissimilarity(Dissimilarity::Precomputed);
        let rect = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        assert!(model.fit(&rect, &()).is_err());
    }

    #[test]
    fn test_mds_collinear_data() {
        // Evenly spaced points on a line should stay evenly spaced in 1-D.
        let line = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]];
        let fitted = fit_ok(&MDS::new(1), &line);
        assert_eq!(fitted.embedding().ncols(), 1);

        let emb = fitted.embedding();
        let mut coords: Vec<f64> = (0..5).map(|i| emb[[i, 0]]).collect();
        coords.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let gaps: Vec<f64> = coords
            .windows(2)
            .map(|pair| (pair[1] - pair[0]).abs())
            .collect();
        for gap in &gaps {
            assert_abs_diff_eq!(*gap, gaps[0], epsilon = 0.1);
        }
    }

    #[test]
    fn test_mds_getters() {
        let model = MDS::new(3).with_dissimilarity(Dissimilarity::Precomputed);
        assert_eq!(model.n_components(), 3);
        assert_eq!(model.dissimilarity(), Dissimilarity::Precomputed);
    }

    #[test]
    fn test_mds_larger_dataset() {
        let n = 20;
        let d = 5;
        let data = Array2::<f64>::from_shape_fn((n, d), |(i, j)| {
            (i * d + j) as f64 / (n * d) as f64
        });

        let fitted = fit_ok(&MDS::new(2), &data);
        assert_eq!(fitted.embedding().dim(), (20, 2));
        assert!(fitted.stress() >= 0.0);
    }
}