1use ferrolearn_core::error::FerroError;
37use ferrolearn_core::traits::{Fit, Transform};
38use ndarray::Array2;
39use rand::SeedableRng;
40use rand_distr::{Distribution, Uniform};
41use rand_xoshiro::Xoshiro256PlusPlus;
42
/// Distance metric used to compare input samples.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UmapMetric {
    /// Straight-line (L2) distance.
    Euclidean,
    /// Sum of absolute coordinate differences (L1 distance).
    Manhattan,
    /// One minus the cosine similarity of the two samples.
    Cosine,
}

/// Hyper-parameters for Uniform Manifold Approximation and Projection (UMAP).
///
/// Start from [`Umap::new`] (or [`Default`]), adjust with the `with_*` builder methods, then fit.
/// Parameter values are validated when fitting, not when they are set.
#[derive(Clone, Debug)]
pub struct Umap {
    /// Number of dimensions of the embedding.
    n_components: usize,
    /// Number of nearest neighbours used to build the neighbour graph.
    n_neighbors: usize,
    /// Minimum distance between embedded points; shapes the low-dimensional kernel.
    min_dist: f64,
    /// Scale of the embedded points; shapes the low-dimensional kernel together with `min_dist`.
    spread: f64,
    /// Initial step size of the stochastic gradient descent.
    learning_rate: f64,
    /// Number of optimisation epochs.
    n_epochs: usize,
    /// Metric used for distances between input samples.
    metric: UmapMetric,
    /// Number of repulsive (negative) samples drawn per positive edge sample.
    negative_sample_rate: usize,
    /// Seed for the random number generator; `None` falls back to seed 0 when fitting.
    random_state: Option<u64>,
}

impl Umap {
    /// Creates a configuration with the default hyper-parameters: 2 components, 15 neighbours,
    /// `min_dist = 0.1`, `spread = 1.0`, `learning_rate = 1.0`, 200 epochs, Euclidean metric,
    /// 5 negative samples per edge sample and no explicit seed.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metric: UmapMetric::Euclidean,
            n_components: 2,
            n_neighbors: 15,
            n_epochs: 200,
            negative_sample_rate: 5,
            min_dist: 0.1,
            spread: 1.0,
            learning_rate: 1.0,
            random_state: None,
        }
    }

    /// Sets the number of embedding dimensions.
    #[must_use]
    pub fn with_n_components(self, n: usize) -> Self {
        Self {
            n_components: n,
            ..self
        }
    }

    /// Sets the number of nearest neighbours.
    #[must_use]
    pub fn with_n_neighbors(self, k: usize) -> Self {
        Self {
            n_neighbors: k,
            ..self
        }
    }

    /// Sets the minimum distance between embedded points.
    #[must_use]
    pub fn with_min_dist(self, d: f64) -> Self {
        Self { min_dist: d, ..self }
    }

    /// Sets the scale of the embedded points.
    #[must_use]
    pub fn with_spread(self, s: f64) -> Self {
        Self { spread: s, ..self }
    }

    /// Sets the initial learning rate.
    #[must_use]
    pub fn with_learning_rate(self, lr: f64) -> Self {
        Self {
            learning_rate: lr,
            ..self
        }
    }

    /// Sets the number of optimisation epochs.
    #[must_use]
    pub fn with_n_epochs(self, n: usize) -> Self {
        Self { n_epochs: n, ..self }
    }

    /// Sets the distance metric for input samples.
    #[must_use]
    pub fn with_metric(self, m: UmapMetric) -> Self {
        Self { metric: m, ..self }
    }

    /// Sets how many negative samples are drawn per positive edge sample.
    #[must_use]
    pub fn with_negative_sample_rate(self, rate: usize) -> Self {
        Self {
            negative_sample_rate: rate,
            ..self
        }
    }

    /// Sets the random seed.
    #[must_use]
    pub fn with_random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }

    /// Number of embedding dimensions.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }

    /// Requested number of nearest neighbours.
    #[must_use]
    pub fn n_neighbors(&self) -> usize {
        self.n_neighbors
    }

    /// Minimum distance between embedded points.
    #[must_use]
    pub fn min_dist(&self) -> f64 {
        self.min_dist
    }

    /// Scale of the embedded points.
    #[must_use]
    pub fn spread(&self) -> f64 {
        self.spread
    }

    /// Initial learning rate.
    #[must_use]
    pub fn learning_rate(&self) -> f64 {
        self.learning_rate
    }

    /// Number of optimisation epochs.
    #[must_use]
    pub fn n_epochs(&self) -> usize {
        self.n_epochs
    }

    /// Distance metric for input samples.
    #[must_use]
    pub fn metric(&self) -> UmapMetric {
        self.metric
    }

    /// Negative samples drawn per positive edge sample.
    #[must_use]
    pub fn negative_sample_rate(&self) -> usize {
        self.negative_sample_rate
    }

    /// Configured random seed, if any.
    #[must_use]
    pub fn random_state(&self) -> Option<u64> {
        self.random_state
    }
}

impl Default for Umap {
    /// Equivalent to [`Umap::new`].
    fn default() -> Self {
        Umap::new()
    }
}
232
233#[derive(Debug, Clone)]
243pub struct FittedUmap {
244 embedding_: Array2<f64>,
246 x_train_: Array2<f64>,
248 a_: f64,
250 b_: f64,
252 n_neighbors_: usize,
254 metric_: UmapMetric,
256}
257
258impl FittedUmap {
259 #[must_use]
261 pub fn embedding(&self) -> &Array2<f64> {
262 &self.embedding_
263 }
264
265 #[must_use]
267 pub fn a(&self) -> f64 {
268 self.a_
269 }
270
271 #[must_use]
273 pub fn b(&self) -> f64 {
274 self.b_
275 }
276}
277
278fn compute_distance(x: &Array2<f64>, i: usize, j: usize, metric: UmapMetric) -> f64 {
284 let ncols = x.ncols();
285 match metric {
286 UmapMetric::Euclidean => {
287 let mut sq = 0.0;
288 for k in 0..ncols {
289 let diff = x[[i, k]] - x[[j, k]];
290 sq += diff * diff;
291 }
292 sq.sqrt()
293 }
294 UmapMetric::Manhattan => {
295 let mut sum = 0.0;
296 for k in 0..ncols {
297 sum += (x[[i, k]] - x[[j, k]]).abs();
298 }
299 sum
300 }
301 UmapMetric::Cosine => {
302 let mut dot = 0.0;
303 let mut norm_i = 0.0;
304 let mut norm_j = 0.0;
305 for k in 0..ncols {
306 dot += x[[i, k]] * x[[j, k]];
307 norm_i += x[[i, k]] * x[[i, k]];
308 norm_j += x[[j, k]] * x[[j, k]];
309 }
310 let denom = (norm_i * norm_j).sqrt();
311 if denom < 1e-16 {
312 1.0
313 } else {
314 1.0 - dot / denom
315 }
316 }
317 }
318}
319
320fn compute_distance_cross(
322 x_new: &Array2<f64>,
323 i: usize,
324 x_train: &Array2<f64>,
325 j: usize,
326 metric: UmapMetric,
327) -> f64 {
328 let ncols = x_new.ncols();
329 match metric {
330 UmapMetric::Euclidean => {
331 let mut sq = 0.0;
332 for k in 0..ncols {
333 let diff = x_new[[i, k]] - x_train[[j, k]];
334 sq += diff * diff;
335 }
336 sq.sqrt()
337 }
338 UmapMetric::Manhattan => {
339 let mut sum = 0.0;
340 for k in 0..ncols {
341 sum += (x_new[[i, k]] - x_train[[j, k]]).abs();
342 }
343 sum
344 }
345 UmapMetric::Cosine => {
346 let mut dot = 0.0;
347 let mut norm_i = 0.0;
348 let mut norm_j = 0.0;
349 for k in 0..ncols {
350 dot += x_new[[i, k]] * x_train[[j, k]];
351 norm_i += x_new[[i, k]] * x_new[[i, k]];
352 norm_j += x_train[[j, k]] * x_train[[j, k]];
353 }
354 let denom = (norm_i * norm_j).sqrt();
355 if denom < 1e-16 {
356 1.0
357 } else {
358 1.0 - dot / denom
359 }
360 }
361 }
362}
363
364fn build_knn(x: &Array2<f64>, k: usize, metric: UmapMetric) -> Vec<Vec<(usize, f64)>> {
367 let n = x.nrows();
368 let mut knn = Vec::with_capacity(n);
369 for i in 0..n {
370 let mut dists: Vec<(usize, f64)> = (0..n)
371 .filter(|&j| j != i)
372 .map(|j| (j, compute_distance(x, i, j, metric)))
373 .collect();
374 dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
375 dists.truncate(k);
376 knn.push(dists);
377 }
378 knn
379}
380
/// Builds UMAP's symmetric fuzzy simplicial set (a weighted, undirected neighbour graph) from
/// a k-NN graph.
///
/// For every point `i` two local parameters come from its neighbour distances `d_ij`:
/// * `rho_i` – the nearest-neighbour distance, or, if that is below `1e-16` (e.g. duplicate
///   points), the first neighbour distance above `1e-16` (unchanged if there is none);
/// * `sigma_i` – found by bisection over `[1e-20, 1e4]` so that
///   `Σ_j exp(-max(d_ij - rho_i, 0) / sigma_i)` approaches `log2(k)`.
///
/// Directed strengths `w_ij = exp(-max(d_ij - rho_i, 0) / sigma_i)` are combined with the fuzzy
/// union `w = w_ij + w_ji - w_ij * w_ji`.
///
/// # Arguments
/// * `knn` – neighbours of every point as `(index, distance)` pairs, nearest first; `k` is taken
///   from the length of the first row.
/// * `n` – number of points; expected to equal `knn.len()` (checked in debug builds).
///
/// # Returns
/// Each unordered pair once as `(low, high, weight)` with `weight > 1e-16`, sorted by
/// `(low, high)`. The order is deterministic, so optimisation that walks the edges in order with
/// a seeded RNG is reproducible.
fn compute_fuzzy_simplicial_set(knn: &[Vec<(usize, f64)>], n: usize) -> Vec<(usize, usize, f64)> {
    debug_assert_eq!(knn.len(), n, "one k-NN row is expected per point");

    let k = knn.first().map_or(0, Vec::len);
    let target = (k as f64).ln() / std::f64::consts::LN_2;

    // Directed strengths keyed by the unordered pair (low, high): `.0` is the strength seen from
    // the lower index, `.1` from the higher one. A BTreeMap iterates in key order, unlike
    // HashMap/HashSet whose iteration order is randomised per instance.
    let mut strengths: std::collections::BTreeMap<(usize, usize), (f64, f64)> =
        std::collections::BTreeMap::new();

    for (i, neighbors) in knn.iter().enumerate().take(n) {
        let (rho, sigma) = smooth_knn_params(neighbors, target);
        for &(j, d) in neighbors {
            let adjusted = (d - rho).max(0.0);
            let w = (-adjusted / sigma).exp();
            if i < j {
                strengths.entry((i, j)).or_insert((0.0, 0.0)).0 = w;
            } else {
                strengths.entry((j, i)).or_insert((0.0, 0.0)).1 = w;
            }
        }
    }

    strengths
        .into_iter()
        .filter_map(|((low, high), (w_low, w_high))| {
            // Fuzzy set union (probabilistic t-conorm) of the two directed strengths.
            let w = w_low + w_high - w_low * w_high;
            (w > 1e-16).then_some((low, high, w))
        })
        .collect()
}

/// Computes `(rho, sigma)` for one point from its neighbour distances (nearest first);
/// returns `(0.0, 1.0)` for a point without neighbours.
fn smooth_knn_params(neighbors: &[(usize, f64)], target: f64) -> (f64, f64) {
    let nearest = match neighbors.first() {
        Some(&(_, d)) => d,
        None => return (0.0, 1.0),
    };
    // A (near-)zero nearest distance, e.g. from duplicate points, is useless as a local offset;
    // use the first clearly positive distance instead, if any.
    let rho = if nearest < 1e-16 {
        neighbors
            .iter()
            .map(|&(_, d)| d)
            .find(|&d| d > 1e-16)
            .unwrap_or(nearest)
    } else {
        nearest
    };

    // Total membership strength for a candidate sigma; non-decreasing in sigma.
    let strength_sum = |sigma: f64| {
        neighbors.iter().fold(0.0_f64, |acc, &(_, d)| {
            let adjusted = (d - rho).max(0.0);
            acc + (-adjusted / sigma).exp()
        })
    };

    let (mut lo, mut hi) = (1e-20_f64, 1e4_f64);
    for _ in 0..64 {
        let mid = (lo + hi) / 2.0;
        if strength_sum(mid) > target {
            hi = mid;
        } else {
            lo = mid;
        }
        // Stop once the bracket is tight relative to its lower end.
        if (hi - lo) / (lo + 1e-16) < 1e-5 {
            break;
        }
    }
    (rho, (lo + hi) / 2.0)
}
485
/// Fits the parameters `(a, b)` of the low-dimensional similarity kernel `1 / (1 + a * d^(2b))`.
///
/// The target curve is `1` for `d <= min_dist` and `exp(-(d - min_dist) / spread)` beyond it,
/// sampled at the 300 bin midpoints of `[0, 3 * spread]`. The summed squared error is minimised
/// by a coarse grid search (`a` in `0.25, 0.5, …, 10`, `b` in `0.1, 0.2, …, 3.0`), followed by a
/// 20 × 20 refinement grid spanning ±0.3 in `a` and ±0.15 in `b` (lower ends floored at 0.01)
/// around the coarse optimum. A candidate replaces the current best only if its error is strictly
/// smaller, so if no error ever compares below `f64::MAX` (e.g. all NaN) `(1.0, 1.0)` is returned.
fn find_ab_params(min_dist: f64, spread: f64) -> (f64, f64) {
    const SAMPLES: usize = 300;
    let d_max = 3.0 * spread;

    // Summed squared difference between the kernel with parameters (a, b) and the target curve.
    let squared_error = |a: f64, b: f64| -> f64 {
        (0..SAMPLES).fold(0.0, |acc, s| {
            let d = (s as f64 + 0.5) / SAMPLES as f64 * d_max;
            let target = if d <= min_dist {
                1.0
            } else {
                (-(d - min_dist) / spread).exp()
            };
            let residual = 1.0 / (1.0 + a * d.powf(2.0 * b)) - target;
            acc + residual * residual
        })
    };

    // Keeps the incumbent `(a, b, error)` unless the candidate is strictly better.
    let keep_better = |best: (f64, f64, f64), (a, b): (f64, f64)| {
        let err = squared_error(a, b);
        if err < best.2 {
            (a, b, err)
        } else {
            best
        }
    };

    // Coarse pass: `a` varies in the outer loop, `b` in the inner one.
    let coarse = (1..=40)
        .flat_map(|ia| (1..=30).map(move |ib| (ia as f64 * 0.25, ib as f64 * 0.1)))
        .fold((1.0, 1.0, f64::MAX), &keep_better);

    let (coarse_a, coarse_b, _) = coarse;
    let (a_lo, a_hi) = ((coarse_a - 0.3).max(0.01), coarse_a + 0.3);
    let (b_lo, b_hi) = ((coarse_b - 0.15).max(0.01), coarse_b + 0.15);

    // Refinement pass continues from the coarse optimum and its error.
    let (best_a, best_b, _) = (0..20)
        .flat_map(|ia| (0..20).map(move |ib| (ia, ib)))
        .map(|(ia, ib)| {
            (
                a_lo + (a_hi - a_lo) * ia as f64 / 19.0,
                b_lo + (b_hi - b_lo) * ib as f64 / 19.0,
            )
        })
        .fold(coarse, &keep_better);

    (best_a, best_b)
}
558
/// Limits `val` to `[lo, hi]`: returns `lo` when `val < lo`, `hi` when `val > hi`, otherwise
/// `val` itself. NaN passes through unchanged, and unlike `f64::clamp` this never panics when
/// `lo > hi` (the `lo` test simply wins).
fn clip(val: f64, lo: f64, hi: f64) -> f64 {
    match val {
        below if below < lo => lo,
        above if above > hi => hi,
        within => within,
    }
}
569
570impl Fit<Array2<f64>, ()> for Umap {
575 type Fitted = FittedUmap;
576 type Error = FerroError;
577
578 fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedUmap, FerroError> {
589 let n = x.nrows();
590
591 if self.n_components == 0 {
593 return Err(FerroError::InvalidParameter {
594 name: "n_components".into(),
595 reason: "must be at least 1".into(),
596 });
597 }
598 if self.n_neighbors == 0 {
599 return Err(FerroError::InvalidParameter {
600 name: "n_neighbors".into(),
601 reason: "must be at least 1".into(),
602 });
603 }
604 if n < 2 {
605 return Err(FerroError::InsufficientSamples {
606 required: 2,
607 actual: n,
608 context: "Umap::fit requires at least 2 samples".into(),
609 });
610 }
611 let effective_k = self.n_neighbors.min(n - 1);
612 if self.min_dist < 0.0 {
613 return Err(FerroError::InvalidParameter {
614 name: "min_dist".into(),
615 reason: "must be non-negative".into(),
616 });
617 }
618 if self.spread <= 0.0 {
619 return Err(FerroError::InvalidParameter {
620 name: "spread".into(),
621 reason: "must be positive".into(),
622 });
623 }
624 if self.learning_rate <= 0.0 {
625 return Err(FerroError::InvalidParameter {
626 name: "learning_rate".into(),
627 reason: "must be positive".into(),
628 });
629 }
630
631 let dim = self.n_components;
632 let seed = self.random_state.unwrap_or(0);
633
634 let knn = build_knn(x, effective_k, self.metric);
636
637 let edges = compute_fuzzy_simplicial_set(&knn, n);
639
640 let (a, b) = find_ab_params(self.min_dist, self.spread);
642
643 let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
645 let uniform = Uniform::new(-10.0, 10.0).unwrap();
646 let mut y = Array2::<f64>::zeros((n, dim));
647 for elem in y.iter_mut() {
648 *elem = uniform.sample(&mut rng);
649 }
650
651 if edges.is_empty() {
653 return Ok(FittedUmap {
654 embedding_: y,
655 x_train_: x.to_owned(),
656 a_: a,
657 b_: b,
658 n_neighbors_: effective_k,
659 metric_: self.metric,
660 });
661 }
662
663 let max_weight = edges
664 .iter()
665 .map(|e| e.2)
666 .fold(0.0_f64, |a_val, b_val| a_val.max(b_val));
667
668 let epochs_per_sample: Vec<f64> = edges
670 .iter()
671 .map(|e| {
672 let ratio = e.2 / max_weight;
673 if ratio > 0.0 {
674 (self.n_epochs as f64) / ((self.n_epochs as f64) * ratio).max(1.0)
675 } else {
676 f64::MAX
677 }
678 })
679 .collect();
680
681 let mut epoch_of_next_sample: Vec<f64> = epochs_per_sample.clone();
682
683 let neg_rate = self.negative_sample_rate;
684 let idx_uniform = Uniform::new(0usize, n).unwrap();
685
686 for epoch in 0..self.n_epochs {
688 let alpha = self.learning_rate * (1.0 - epoch as f64 / self.n_epochs as f64);
689 let alpha = alpha.max(0.0);
690
691 for (edge_idx, &(ei, ej, _weight)) in edges.iter().enumerate() {
692 if epoch_of_next_sample[edge_idx] > epoch as f64 {
693 continue;
694 }
695
696 let mut dist_sq = 0.0;
698 for d in 0..dim {
699 let diff = y[[ei, d]] - y[[ej, d]];
700 dist_sq += diff * diff;
701 }
702 let dist_sq = dist_sq.max(1e-16);
703
704 let grad_coeff = -2.0 * a * b * dist_sq.powf(b - 1.0) / (1.0 + a * dist_sq.powf(b));
705
706 for d in 0..dim {
707 let diff = y[[ei, d]] - y[[ej, d]];
708 let grad = clip(grad_coeff * diff, -4.0, 4.0);
709 y[[ei, d]] += alpha * grad;
710 y[[ej, d]] -= alpha * grad;
711 }
712
713 for _ in 0..neg_rate {
715 let neg = idx_uniform.sample(&mut rng);
716 if neg == ei {
717 continue;
718 }
719 let mut neg_dist_sq = 0.0;
720 for d in 0..dim {
721 let diff = y[[ei, d]] - y[[neg, d]];
722 neg_dist_sq += diff * diff;
723 }
724 let neg_dist_sq = neg_dist_sq.max(1e-16);
725
726 let rep_coeff =
727 2.0 * b / ((0.001 + neg_dist_sq) * (1.0 + a * neg_dist_sq.powf(b)));
728
729 for d in 0..dim {
730 let diff = y[[ei, d]] - y[[neg, d]];
731 let grad = clip(rep_coeff * diff, -4.0, 4.0);
732 y[[ei, d]] += alpha * grad;
733 }
734 }
735
736 epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];
737 }
738 }
739
740 Ok(FittedUmap {
741 embedding_: y,
742 x_train_: x.to_owned(),
743 a_: a,
744 b_: b,
745 n_neighbors_: effective_k,
746 metric_: self.metric,
747 })
748 }
749}
750
751impl Transform<Array2<f64>> for FittedUmap {
752 type Output = Array2<f64>;
753 type Error = FerroError;
754
755 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
766 let n_features = self.x_train_.ncols();
767 if x.ncols() != n_features {
768 return Err(FerroError::ShapeMismatch {
769 expected: vec![x.nrows(), n_features],
770 actual: vec![x.nrows(), x.ncols()],
771 context: "FittedUmap::transform".into(),
772 });
773 }
774
775 let n_test = x.nrows();
776 let n_train = self.x_train_.nrows();
777 let dim = self.embedding_.ncols();
778 let k = self.n_neighbors_.min(n_train);
779
780 let mut result = Array2::<f64>::zeros((n_test, dim));
781
782 for t in 0..n_test {
783 let mut dists: Vec<(usize, f64)> = (0..n_train)
785 .map(|j| {
786 (
787 j,
788 compute_distance_cross(x, t, &self.x_train_, j, self.metric_),
789 )
790 })
791 .collect();
792 dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
793 dists.truncate(k);
794
795 let mut weights = Vec::with_capacity(k);
797 let mut weight_sum = 0.0;
798 for &(_, d) in &dists {
799 let w = 1.0 / (1.0 + self.a_ * d.powf(2.0 * self.b_));
800 weights.push(w);
801 weight_sum += w;
802 }
803
804 if weight_sum < 1e-16 {
805 weight_sum = k as f64;
807 weights = vec![1.0; k];
808 }
809
810 for (idx, &(train_idx, _)) in dists.iter().enumerate() {
812 let w = weights[idx] / weight_sum;
813 for d in 0..dim {
814 result[[t, d]] += w * self.embedding_[[train_idx, d]];
815 }
816 }
817 }
818
819 Ok(result)
820 }
821}
822
// Unit tests for the UMAP estimator: output shapes, cluster separation, out-of-sample transform,
// kernel-parameter fitting, parameter validation and accessors. Tests that fit a model pass a
// fixed `random_state` and use few epochs so they stay fast.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use rand::SeedableRng;
    use rand_distr::{Distribution, Normal};
    use rand_xoshiro::Xoshiro256PlusPlus;

    // Builds three Gaussian blobs of 10 samples each in 5 dimensions (noise std-dev 0.3) around
    // fixed centres, returning the data and the blob index of every row. Rows are filled blob by
    // blob and feature by feature, so the data for a given seed depends on that draw order.
    fn make_blobs(seed: u64) -> (Array2<f64>, Vec<usize>) {
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
        let normal = Normal::new(0.0, 0.3).unwrap();
        let n_per_cluster = 10;
        let n_features = 5;
        let centers = vec![
            vec![0.0, 0.0, 0.0, 0.0, 0.0],
            vec![5.0, 5.0, 5.0, 5.0, 5.0],
            vec![10.0, 0.0, 10.0, 0.0, 10.0],
        ];
        let n = centers.len() * n_per_cluster;
        let mut x = Array2::<f64>::zeros((n, n_features));
        let mut labels = Vec::with_capacity(n);
        for (c_idx, center) in centers.iter().enumerate() {
            for i in 0..n_per_cluster {
                let row = c_idx * n_per_cluster + i;
                for (f, &c) in center.iter().enumerate() {
                    x[[row, f]] = c + normal.sample(&mut rng);
                }
                labels.push(c_idx);
            }
        }
        (x, labels)
    }

    // Default settings give a 2-D embedding with one row per sample.
    #[test]
    fn test_umap_basic_shape() {
        let x = Array2::<f64>::from_shape_fn((30, 5), |(i, j)| (i + j) as f64);
        let umap = Umap::new().with_n_epochs(10).with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (30, 2));
    }

    // `with_n_components` controls the embedding width.
    #[test]
    fn test_umap_3d_embedding() {
        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j) as f64);
        let umap = Umap::new()
            .with_n_components(3)
            .with_n_epochs(10)
            .with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().ncols(), 3);
    }

    // Well-separated blobs must stay separated: a 3-nearest-neighbour majority vote in the
    // embedding has to recover the true blob label for more than 80% of the samples.
    #[test]
    fn test_umap_separates_clusters() {
        let (x, labels) = make_blobs(42);
        let umap = Umap::new()
            .with_n_neighbors(5)
            .with_n_epochs(100)
            .with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        let emb = fitted.embedding();

        let n = emb.nrows();
        let mut correct = 0;
        for i in 0..n {
            // Squared embedding distance from sample i to every other sample, with its label.
            let mut dists: Vec<(f64, usize)> = (0..n)
                .filter(|&j| j != i)
                .map(|j| {
                    let mut d = 0.0;
                    for dd in 0..emb.ncols() {
                        let diff = emb[[i, dd]] - emb[[j, dd]];
                        d += diff * diff;
                    }
                    (d, labels[j])
                })
                .collect();
            dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
            // Majority vote among the 3 closest embedded samples.
            let mut votes = [0usize; 3];
            for &(_, lbl) in dists.iter().take(3) {
                votes[lbl] += 1;
            }
            let pred = votes.iter().enumerate().max_by_key(|&(_, v)| v).unwrap().0;
            if pred == labels[i] {
                correct += 1;
            }
        }
        let accuracy = correct as f64 / n as f64;
        assert!(
            accuracy > 0.8,
            "UMAP k-NN accuracy should be > 80%, got {:.1}%",
            accuracy * 100.0
        );
    }

    // Transforming new rows yields one embedding row per input row.
    #[test]
    fn test_umap_transform_new_data() {
        let (x, _) = make_blobs(42);
        let umap = Umap::new()
            .with_n_neighbors(5)
            .with_n_epochs(50)
            .with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();

        // Re-use the first five training rows as "new" data.
        let x_test = x.slice(ndarray::s![0..5, ..]).to_owned();
        let projected = fitted.transform(&x_test).unwrap();
        assert_eq!(projected.dim(), (5, 2));
    }

    // `transform` rejects input whose column count differs from the training data.
    #[test]
    fn test_umap_transform_shape_mismatch() {
        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j) as f64);
        let umap = Umap::new().with_n_epochs(10).with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        let x_bad = Array2::<f64>::zeros((5, 3));
        assert!(fitted.transform(&x_bad).is_err());
    }

    // The fitted kernel parameters are positive, and the kernel is still above 0.5 at
    // `min_dist` for the default `min_dist = 0.1`, `spread = 1.0`.
    #[test]
    fn test_umap_ab_params_reasonable() {
        let (a, b) = find_ab_params(0.1, 1.0);
        assert!(a > 0.0, "a should be positive, got {a}");
        assert!(b > 0.0, "b should be positive, got {b}");
        let val_at_min = 1.0 / (1.0 + a * (0.1_f64).powf(2.0 * b));
        assert!(
            val_at_min > 0.5,
            "kernel at min_dist should be > 0.5, got {val_at_min}"
        );
    }

    // Parameter validation: each invalid setting makes `fit` return an error.
    #[test]
    fn test_umap_invalid_n_components_zero() {
        let x = Array2::<f64>::zeros((10, 3));
        let umap = Umap::new().with_n_components(0);
        assert!(umap.fit(&x, &()).is_err());
    }

    #[test]
    fn test_umap_invalid_n_neighbors_zero() {
        let x = Array2::<f64>::zeros((10, 3));
        let umap = Umap::new().with_n_neighbors(0);
        assert!(umap.fit(&x, &()).is_err());
    }

    #[test]
    fn test_umap_invalid_min_dist() {
        let x = Array2::<f64>::zeros((10, 3));
        let umap = Umap::new().with_min_dist(-0.1);
        assert!(umap.fit(&x, &()).is_err());
    }

    #[test]
    fn test_umap_invalid_spread() {
        let x = Array2::<f64>::zeros((10, 3));
        let umap = Umap::new().with_spread(0.0);
        assert!(umap.fit(&x, &()).is_err());
    }

    #[test]
    fn test_umap_invalid_learning_rate() {
        let x = Array2::<f64>::zeros((10, 3));
        let umap = Umap::new().with_learning_rate(-1.0);
        assert!(umap.fit(&x, &()).is_err());
    }

    // A single sample is not enough to build a neighbour graph.
    #[test]
    fn test_umap_insufficient_samples() {
        let x = Array2::<f64>::zeros((1, 3));
        let umap = Umap::new();
        assert!(umap.fit(&x, &()).is_err());
    }

    // Every builder value is reported back by the matching getter.
    #[test]
    fn test_umap_getters() {
        let umap = Umap::new()
            .with_n_components(3)
            .with_n_neighbors(10)
            .with_min_dist(0.2)
            .with_spread(1.5)
            .with_learning_rate(0.5)
            .with_n_epochs(100)
            .with_metric(UmapMetric::Manhattan)
            .with_negative_sample_rate(3)
            .with_random_state(99);
        assert_eq!(umap.n_components(), 3);
        assert_eq!(umap.n_neighbors(), 10);
        assert!((umap.min_dist() - 0.2).abs() < 1e-10);
        assert!((umap.spread() - 1.5).abs() < 1e-10);
        assert!((umap.learning_rate() - 0.5).abs() < 1e-10);
        assert_eq!(umap.n_epochs(), 100);
        assert_eq!(umap.metric(), UmapMetric::Manhattan);
        assert_eq!(umap.negative_sample_rate(), 3);
        assert_eq!(umap.random_state(), Some(99));
    }

    // `Default` matches the documented defaults of `Umap::new`.
    #[test]
    fn test_umap_default() {
        let umap = Umap::default();
        assert_eq!(umap.n_components(), 2);
        assert_eq!(umap.n_neighbors(), 15);
    }

    // The cosine metric fits without error on strictly positive data (no zero-norm rows).
    #[test]
    fn test_umap_cosine_metric() {
        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j + 1) as f64);
        let umap = Umap::new()
            .with_metric(UmapMetric::Cosine)
            .with_n_epochs(10)
            .with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (20, 2));
    }

    // `n_neighbors` larger than `n_samples - 1` is capped instead of failing.
    #[test]
    fn test_umap_small_n_neighbors_capped() {
        let x = Array2::<f64>::from_shape_fn((5, 3), |(i, j)| (i + j) as f64);
        let umap = Umap::new()
            .with_n_neighbors(100)
            .with_n_epochs(10)
            .with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (5, 2));
    }

    // The fitted model exposes positive kernel parameters.
    #[test]
    fn test_umap_fitted_accessors() {
        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j) as f64);
        let umap = Umap::new().with_n_epochs(10).with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        assert!(fitted.a() > 0.0);
        assert!(fitted.b() > 0.0);
    }
}