1use ferrolearn_core::error::FerroError;
37use ferrolearn_core::traits::{Fit, Transform};
38use ndarray::Array2;
39use rand::SeedableRng;
40use rand_distr::{Distribution, Uniform};
41use rand_xoshiro::Xoshiro256PlusPlus;
42
/// Distance metric used to build the k-nearest-neighbour graph in input space.
///
/// [`UmapMetric::Euclidean`] is the default, matching the metric chosen by
/// [`Umap::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum UmapMetric {
    /// Straight-line (L2) distance: `sqrt(sum((x_k - y_k)^2))`.
    #[default]
    Euclidean,
    /// City-block (L1) distance: `sum(|x_k - y_k|)`.
    Manhattan,
    /// Cosine distance `1 - cos(angle(x, y))`; a pair involving a (near-)zero
    /// vector is assigned distance `1.0`.
    Cosine,
}
57
58#[derive(Debug, Clone)]
67pub struct Umap {
68 n_components: usize,
70 n_neighbors: usize,
72 min_dist: f64,
74 spread: f64,
76 learning_rate: f64,
78 n_epochs: usize,
80 metric: UmapMetric,
82 negative_sample_rate: usize,
84 random_state: Option<u64>,
86}
87
88impl Umap {
89 #[must_use]
95 pub fn new() -> Self {
96 Self {
97 n_components: 2,
98 n_neighbors: 15,
99 min_dist: 0.1,
100 spread: 1.0,
101 learning_rate: 1.0,
102 n_epochs: 200,
103 metric: UmapMetric::Euclidean,
104 negative_sample_rate: 5,
105 random_state: None,
106 }
107 }
108
109 #[must_use]
111 pub fn with_n_components(mut self, n: usize) -> Self {
112 self.n_components = n;
113 self
114 }
115
116 #[must_use]
118 pub fn with_n_neighbors(mut self, k: usize) -> Self {
119 self.n_neighbors = k;
120 self
121 }
122
123 #[must_use]
125 pub fn with_min_dist(mut self, d: f64) -> Self {
126 self.min_dist = d;
127 self
128 }
129
130 #[must_use]
132 pub fn with_spread(mut self, s: f64) -> Self {
133 self.spread = s;
134 self
135 }
136
137 #[must_use]
139 pub fn with_learning_rate(mut self, lr: f64) -> Self {
140 self.learning_rate = lr;
141 self
142 }
143
144 #[must_use]
146 pub fn with_n_epochs(mut self, n: usize) -> Self {
147 self.n_epochs = n;
148 self
149 }
150
151 #[must_use]
153 pub fn with_metric(mut self, m: UmapMetric) -> Self {
154 self.metric = m;
155 self
156 }
157
158 #[must_use]
160 pub fn with_negative_sample_rate(mut self, rate: usize) -> Self {
161 self.negative_sample_rate = rate;
162 self
163 }
164
165 #[must_use]
167 pub fn with_random_state(mut self, seed: u64) -> Self {
168 self.random_state = Some(seed);
169 self
170 }
171
172 #[must_use]
174 pub fn n_components(&self) -> usize {
175 self.n_components
176 }
177
178 #[must_use]
180 pub fn n_neighbors(&self) -> usize {
181 self.n_neighbors
182 }
183
184 #[must_use]
186 pub fn min_dist(&self) -> f64 {
187 self.min_dist
188 }
189
190 #[must_use]
192 pub fn spread(&self) -> f64 {
193 self.spread
194 }
195
196 #[must_use]
198 pub fn learning_rate(&self) -> f64 {
199 self.learning_rate
200 }
201
202 #[must_use]
204 pub fn n_epochs(&self) -> usize {
205 self.n_epochs
206 }
207
208 #[must_use]
210 pub fn metric(&self) -> UmapMetric {
211 self.metric
212 }
213
214 #[must_use]
216 pub fn negative_sample_rate(&self) -> usize {
217 self.negative_sample_rate
218 }
219
220 #[must_use]
222 pub fn random_state(&self) -> Option<u64> {
223 self.random_state
224 }
225}
226
227impl Default for Umap {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
233#[derive(Debug, Clone)]
243pub struct FittedUmap {
244 embedding_: Array2<f64>,
246 x_train_: Array2<f64>,
248 a_: f64,
250 b_: f64,
252 n_neighbors_: usize,
254 metric_: UmapMetric,
256}
257
258impl FittedUmap {
259 #[must_use]
261 pub fn embedding(&self) -> &Array2<f64> {
262 &self.embedding_
263 }
264
265 #[must_use]
267 pub fn a(&self) -> f64 {
268 self.a_
269 }
270
271 #[must_use]
273 pub fn b(&self) -> f64 {
274 self.b_
275 }
276}
277
278fn compute_distance(x: &Array2<f64>, i: usize, j: usize, metric: UmapMetric) -> f64 {
284 let ncols = x.ncols();
285 match metric {
286 UmapMetric::Euclidean => {
287 let mut sq = 0.0;
288 for k in 0..ncols {
289 let diff = x[[i, k]] - x[[j, k]];
290 sq += diff * diff;
291 }
292 sq.sqrt()
293 }
294 UmapMetric::Manhattan => {
295 let mut sum = 0.0;
296 for k in 0..ncols {
297 sum += (x[[i, k]] - x[[j, k]]).abs();
298 }
299 sum
300 }
301 UmapMetric::Cosine => {
302 let mut dot = 0.0;
303 let mut norm_i = 0.0;
304 let mut norm_j = 0.0;
305 for k in 0..ncols {
306 dot += x[[i, k]] * x[[j, k]];
307 norm_i += x[[i, k]] * x[[i, k]];
308 norm_j += x[[j, k]] * x[[j, k]];
309 }
310 let denom = (norm_i * norm_j).sqrt();
311 if denom < 1e-16 {
312 1.0
313 } else {
314 1.0 - dot / denom
315 }
316 }
317 }
318}
319
320fn compute_distance_cross(
322 x_new: &Array2<f64>,
323 i: usize,
324 x_train: &Array2<f64>,
325 j: usize,
326 metric: UmapMetric,
327) -> f64 {
328 let ncols = x_new.ncols();
329 match metric {
330 UmapMetric::Euclidean => {
331 let mut sq = 0.0;
332 for k in 0..ncols {
333 let diff = x_new[[i, k]] - x_train[[j, k]];
334 sq += diff * diff;
335 }
336 sq.sqrt()
337 }
338 UmapMetric::Manhattan => {
339 let mut sum = 0.0;
340 for k in 0..ncols {
341 sum += (x_new[[i, k]] - x_train[[j, k]]).abs();
342 }
343 sum
344 }
345 UmapMetric::Cosine => {
346 let mut dot = 0.0;
347 let mut norm_i = 0.0;
348 let mut norm_j = 0.0;
349 for k in 0..ncols {
350 dot += x_new[[i, k]] * x_train[[j, k]];
351 norm_i += x_new[[i, k]] * x_new[[i, k]];
352 norm_j += x_train[[j, k]] * x_train[[j, k]];
353 }
354 let denom = (norm_i * norm_j).sqrt();
355 if denom < 1e-16 {
356 1.0
357 } else {
358 1.0 - dot / denom
359 }
360 }
361 }
362}
363
364fn build_knn(x: &Array2<f64>, k: usize, metric: UmapMetric) -> Vec<Vec<(usize, f64)>> {
367 let n = x.nrows();
368 let mut knn = Vec::with_capacity(n);
369 for i in 0..n {
370 let mut dists: Vec<(usize, f64)> = (0..n)
371 .filter(|&j| j != i)
372 .map(|j| (j, compute_distance(x, i, j, metric)))
373 .collect();
374 dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
375 dists.truncate(k);
376 knn.push(dists);
377 }
378 knn
379}
380
/// Compute `(rho, sigma)` for one non-empty neighbour list.
///
/// `rho` is the first distance greater than `1e-16`, or the first distance if
/// none is. `sigma` is found by bisection on `[1e-20, 1e4]` so that
/// `sum_j exp(-max(0, d_j - rho) / sigma)` approaches `target`. The search stops
/// after 64 halvings or once the relative bracket width drops below `1e-5`.
fn smooth_knn_row(neighbors: &[(usize, f64)], target: f64) -> (f64, f64) {
    let Some(&(_, first)) = neighbors.first() else {
        return (0.0, 1.0);
    };
    let rho = if first < 1e-16 {
        neighbors
            .iter()
            .map(|&(_, d)| d)
            .find(|&d| d > 1e-16)
            .unwrap_or(first)
    } else {
        first
    };

    let (mut lo, mut hi) = (1e-20_f64, 1e4_f64);
    for _ in 0..64 {
        let mid = f64::midpoint(lo, hi);
        let val: f64 = neighbors
            .iter()
            .map(|&(_, d)| (-(d - rho).max(0.0) / mid).exp())
            .sum();
        // The membership sum grows with sigma, so shrink sigma if it overshoots.
        if val > target {
            hi = mid;
        } else {
            lo = mid;
        }
        if (hi - lo) / (lo + 1e-16) < 1e-5 {
            break;
        }
    }
    (rho, f64::midpoint(lo, hi))
}

/// Build the symmetric fuzzy simplicial set (weighted graph) from a kNN list.
///
/// For each point `i`, [`smooth_knn_row`] yields `rho_i` and `sigma_i`, with
/// target `log2(k)`; `k` is the length of the first neighbour list. Directed
/// memberships `w_ij = exp(-max(0, d_ij - rho_i) / sigma_i)` are combined with
/// the probabilistic t-conorm `w = w_ij + w_ji - w_ij * w_ji`.
///
/// Returns `(i, j, w)` triples with `i <= j` and `w > 1e-16`, sorted by
/// `(i, j)`. The fixed order matters: edges drive the order of RNG draws
/// during SGD, so a hash-ordered result would make a seeded fit
/// non-reproducible across processes.
///
/// `knn` is expected to hold exactly `n` neighbour lists, one per sample.
///
/// NOTE(review): `k` excludes the point itself. umap-learn counts self in
/// `n_neighbors` (target `log2(k + 1)` in these terms) — confirm which
/// calibration is intended.
fn compute_fuzzy_simplicial_set(knn: &[Vec<(usize, f64)>], n: usize) -> Vec<(usize, usize, f64)> {
    debug_assert_eq!(knn.len(), n, "expected one neighbour list per sample");

    let k = knn.first().map_or(0, Vec::len);
    let target = (k as f64).log2();

    // Keyed by (min, max) index.
    // Value: (weight min -> max, weight max -> min).
    // A BTreeMap gives deterministic, sorted iteration order.
    let mut pair_weights: std::collections::BTreeMap<(usize, usize), (f64, f64)> =
        std::collections::BTreeMap::new();

    for (i, neighbors) in knn.iter().enumerate() {
        if neighbors.is_empty() {
            continue;
        }
        let (rho, sigma) = smooth_knn_row(neighbors, target);
        for &(j, d) in neighbors {
            let w = (-(d - rho).max(0.0) / sigma).exp();
            if i < j {
                pair_weights.entry((i, j)).or_insert((0.0, 0.0)).0 = w;
            } else {
                pair_weights.entry((j, i)).or_insert((0.0, 0.0)).1 = w;
            }
        }
    }

    pair_weights
        .into_iter()
        .filter_map(|((lo, hi), (fwd, bwd))| {
            let w = fwd + bwd - fwd * bwd;
            (w > 1e-16).then_some((lo, hi, w))
        })
        .collect()
}
485
/// Fit the low-dimensional kernel `1 / (1 + a * d^(2b))` to UMAP's target curve.
///
/// The target is `1` for `d <= min_dist` and `exp(-(d - min_dist) / spread)`
/// beyond it. It is sampled at 300 cell midpoints of `[0, 3 * spread]`.
///
/// `(a, b)` minimises the sum of squared residuals over two grids:
/// - a coarse grid, with `a` in `0.25..=10.0` (step `0.25`) and `b` in
///   `0.1..=3.0` (step `0.1`);
/// - a 20 x 20 refinement grid spanning `a ± 0.3` and `b ± 0.15` around the
///   coarse optimum, with lower bounds floored at `0.01`.
///
/// The search starts from `(1.0, 1.0)`.
fn find_ab_params(min_dist: f64, spread: f64) -> (f64, f64) {
    const SAMPLES: u32 = 300;
    let d_max = 3.0 * spread;

    // Sum of squared residuals between the kernel and the target curve.
    let sse = |a: f64, b: f64| -> f64 {
        (0..SAMPLES).fold(0.0, |acc, s| {
            let d = (f64::from(s) + 0.5) / f64::from(SAMPLES) * d_max;
            let target = if d <= min_dist {
                1.0
            } else {
                (-(d - min_dist) / spread).exp()
            };
            let residual = 1.0 / (1.0 + a * d.powf(2.0 * b)) - target;
            acc + residual * residual
        })
    };

    // (a, b, error) of the best candidate seen so far.
    let mut best = (1.0_f64, 1.0_f64, f64::MAX);

    for ia in 1..=40_u32 {
        let a = f64::from(ia) * 0.25;
        for ib in 1..=30_u32 {
            let b = f64::from(ib) * 0.1;
            let err = sse(a, b);
            if err < best.2 {
                best = (a, b, err);
            }
        }
    }

    let (a_lo, a_hi) = ((best.0 - 0.3).max(0.01), best.0 + 0.3);
    let (b_lo, b_hi) = ((best.1 - 0.15).max(0.01), best.1 + 0.15);

    for ia in 0..20_u32 {
        let a = a_lo + (a_hi - a_lo) * f64::from(ia) / 19.0;
        for ib in 0..20_u32 {
            let b = b_lo + (b_hi - b_lo) * f64::from(ib) / 19.0;
            let err = sse(a, b);
            if err < best.2 {
                best = (a, b, err);
            }
        }
    }

    (best.0, best.1)
}
558
/// Limit `val` to the interval `[lo, hi]`.
///
/// The lower bound is tested first, then the upper bound. Unlike
/// `f64::clamp`, this does not panic when `lo > hi`. A NaN `val` is returned
/// unchanged because both comparisons are false.
fn clip(val: f64, lo: f64, hi: f64) -> f64 {
    match val {
        v if v < lo => lo,
        v if v > hi => hi,
        v => v,
    }
}
569
570impl Fit<Array2<f64>, ()> for Umap {
575 type Fitted = FittedUmap;
576 type Error = FerroError;
577
578 fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedUmap, FerroError> {
589 let n = x.nrows();
590
591 if self.n_components == 0 {
593 return Err(FerroError::InvalidParameter {
594 name: "n_components".into(),
595 reason: "must be at least 1".into(),
596 });
597 }
598 if self.n_neighbors == 0 {
599 return Err(FerroError::InvalidParameter {
600 name: "n_neighbors".into(),
601 reason: "must be at least 1".into(),
602 });
603 }
604 if n < 2 {
605 return Err(FerroError::InsufficientSamples {
606 required: 2,
607 actual: n,
608 context: "Umap::fit requires at least 2 samples".into(),
609 });
610 }
611 let effective_k = self.n_neighbors.min(n - 1);
612 if self.min_dist < 0.0 {
613 return Err(FerroError::InvalidParameter {
614 name: "min_dist".into(),
615 reason: "must be non-negative".into(),
616 });
617 }
618 if self.spread <= 0.0 {
619 return Err(FerroError::InvalidParameter {
620 name: "spread".into(),
621 reason: "must be positive".into(),
622 });
623 }
624 if self.learning_rate <= 0.0 {
625 return Err(FerroError::InvalidParameter {
626 name: "learning_rate".into(),
627 reason: "must be positive".into(),
628 });
629 }
630
631 let dim = self.n_components;
632 let seed = self.random_state.unwrap_or(0);
633
634 let knn = build_knn(x, effective_k, self.metric);
636
637 let edges = compute_fuzzy_simplicial_set(&knn, n);
639
640 let (a, b) = find_ab_params(self.min_dist, self.spread);
642
643 let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
645 let uniform = Uniform::new(-10.0, 10.0).unwrap();
646 let mut y = Array2::<f64>::zeros((n, dim));
647 for elem in &mut y {
648 *elem = uniform.sample(&mut rng);
649 }
650
651 if edges.is_empty() {
653 return Ok(FittedUmap {
654 embedding_: y,
655 x_train_: x.to_owned(),
656 a_: a,
657 b_: b,
658 n_neighbors_: effective_k,
659 metric_: self.metric,
660 });
661 }
662
663 let max_weight = edges.iter().map(|e| e.2).fold(0.0_f64, f64::max);
664
665 let epochs_per_sample: Vec<f64> = edges
667 .iter()
668 .map(|e| {
669 let ratio = e.2 / max_weight;
670 if ratio > 0.0 {
671 (self.n_epochs as f64) / ((self.n_epochs as f64) * ratio).max(1.0)
672 } else {
673 f64::MAX
674 }
675 })
676 .collect();
677
678 let mut epoch_of_next_sample: Vec<f64> = epochs_per_sample.clone();
679
680 let neg_rate = self.negative_sample_rate;
681 let idx_uniform = Uniform::new(0usize, n).unwrap();
682
683 for epoch in 0..self.n_epochs {
685 let alpha = self.learning_rate * (1.0 - epoch as f64 / self.n_epochs as f64);
686 let alpha = alpha.max(0.0);
687
688 for (edge_idx, &(ei, ej, _weight)) in edges.iter().enumerate() {
689 if epoch_of_next_sample[edge_idx] > epoch as f64 {
690 continue;
691 }
692
693 let mut dist_sq = 0.0;
695 for d in 0..dim {
696 let diff = y[[ei, d]] - y[[ej, d]];
697 dist_sq += diff * diff;
698 }
699 let dist_sq = dist_sq.max(1e-16);
700
701 let grad_coeff = -2.0 * a * b * dist_sq.powf(b - 1.0) / (1.0 + a * dist_sq.powf(b));
702
703 for d in 0..dim {
704 let diff = y[[ei, d]] - y[[ej, d]];
705 let grad = clip(grad_coeff * diff, -4.0, 4.0);
706 y[[ei, d]] += alpha * grad;
707 y[[ej, d]] -= alpha * grad;
708 }
709
710 for _ in 0..neg_rate {
712 let neg = idx_uniform.sample(&mut rng);
713 if neg == ei {
714 continue;
715 }
716 let mut neg_dist_sq = 0.0;
717 for d in 0..dim {
718 let diff = y[[ei, d]] - y[[neg, d]];
719 neg_dist_sq += diff * diff;
720 }
721 let neg_dist_sq = neg_dist_sq.max(1e-16);
722
723 let rep_coeff =
724 2.0 * b / ((0.001 + neg_dist_sq) * (1.0 + a * neg_dist_sq.powf(b)));
725
726 for d in 0..dim {
727 let diff = y[[ei, d]] - y[[neg, d]];
728 let grad = clip(rep_coeff * diff, -4.0, 4.0);
729 y[[ei, d]] += alpha * grad;
730 }
731 }
732
733 epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];
734 }
735 }
736
737 Ok(FittedUmap {
738 embedding_: y,
739 x_train_: x.to_owned(),
740 a_: a,
741 b_: b,
742 n_neighbors_: effective_k,
743 metric_: self.metric,
744 })
745 }
746}
747
748impl Transform<Array2<f64>> for FittedUmap {
749 type Output = Array2<f64>;
750 type Error = FerroError;
751
752 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
763 let n_features = self.x_train_.ncols();
764 if x.ncols() != n_features {
765 return Err(FerroError::ShapeMismatch {
766 expected: vec![x.nrows(), n_features],
767 actual: vec![x.nrows(), x.ncols()],
768 context: "FittedUmap::transform".into(),
769 });
770 }
771
772 let n_test = x.nrows();
773 let n_train = self.x_train_.nrows();
774 let dim = self.embedding_.ncols();
775 let k = self.n_neighbors_.min(n_train);
776
777 let mut result = Array2::<f64>::zeros((n_test, dim));
778
779 for t in 0..n_test {
780 let mut dists: Vec<(usize, f64)> = (0..n_train)
782 .map(|j| {
783 (
784 j,
785 compute_distance_cross(x, t, &self.x_train_, j, self.metric_),
786 )
787 })
788 .collect();
789 dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
790 dists.truncate(k);
791
792 let mut weights = Vec::with_capacity(k);
794 let mut weight_sum = 0.0;
795 for &(_, d) in &dists {
796 let w = 1.0 / (1.0 + self.a_ * d.powf(2.0 * self.b_));
797 weights.push(w);
798 weight_sum += w;
799 }
800
801 if weight_sum < 1e-16 {
802 weight_sum = k as f64;
804 weights = vec![1.0; k];
805 }
806
807 for (idx, &(train_idx, _)) in dists.iter().enumerate() {
809 let w = weights[idx] / weight_sum;
810 for d in 0..dim {
811 result[[t, d]] += w * self.embedding_[[train_idx, d]];
812 }
813 }
814 }
815
816 Ok(result)
817 }
818}
819
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use rand::SeedableRng;
    use rand_distr::{Distribution, Normal};
    use rand_xoshiro::Xoshiro256PlusPlus;

    /// Three well-separated Gaussian blobs with their cluster labels.
    ///
    /// Each blob has 10 points with 5 features and noise std 0.3.
    fn make_blobs(seed: u64) -> (Array2<f64>, Vec<usize>) {
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
        let normal = Normal::new(0.0, 0.3).unwrap();
        let n_per_cluster = 10;
        let n_features = 5;
        let centers = [
            vec![0.0, 0.0, 0.0, 0.0, 0.0],
            vec![5.0, 5.0, 5.0, 5.0, 5.0],
            vec![10.0, 0.0, 10.0, 0.0, 10.0],
        ];
        let n = centers.len() * n_per_cluster;
        let mut x = Array2::<f64>::zeros((n, n_features));
        let mut labels = Vec::with_capacity(n);
        for (c_idx, center) in centers.iter().enumerate() {
            for i in 0..n_per_cluster {
                let row = c_idx * n_per_cluster + i;
                for (f, &c) in center.iter().enumerate() {
                    x[[row, f]] = c + normal.sample(&mut rng);
                }
                labels.push(c_idx);
            }
        }
        (x, labels)
    }

    #[test]
    fn test_umap_basic_shape() {
        let x = Array2::<f64>::from_shape_fn((30, 5), |(i, j)| (i + j) as f64);
        let umap = Umap::new().with_n_epochs(10).with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (30, 2));
    }

    #[test]
    fn test_umap_3d_embedding() {
        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j) as f64);
        let umap = Umap::new()
            .with_n_components(3)
            .with_n_epochs(10)
            .with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().ncols(), 3);
    }

    #[test]
    fn test_umap_separates_clusters() {
        let (x, labels) = make_blobs(42);
        let umap = Umap::new()
            .with_n_neighbors(5)
            .with_n_epochs(100)
            .with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        let emb = fitted.embedding();

        // Leave-one-out 3-NN classification in the embedding.
        let n = emb.nrows();
        let mut correct = 0;
        for i in 0..n {
            let mut dists: Vec<(f64, usize)> = (0..n)
                .filter(|&j| j != i)
                .map(|j| {
                    let mut d = 0.0;
                    for dd in 0..emb.ncols() {
                        let diff = emb[[i, dd]] - emb[[j, dd]];
                        d += diff * diff;
                    }
                    (d, labels[j])
                })
                .collect();
            dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
            let mut votes = [0usize; 3];
            for &(_, lbl) in dists.iter().take(3) {
                votes[lbl] += 1;
            }
            let pred = votes.iter().enumerate().max_by_key(|&(_, v)| v).unwrap().0;
            if pred == labels[i] {
                correct += 1;
            }
        }
        let accuracy = f64::from(correct) / n as f64;
        assert!(
            accuracy > 0.8,
            "UMAP k-NN accuracy should be > 80%, got {:.1}%",
            accuracy * 100.0
        );
    }

    #[test]
    fn test_umap_transform_new_data() {
        let (x, _) = make_blobs(42);
        let umap = Umap::new()
            .with_n_neighbors(5)
            .with_n_epochs(50)
            .with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();

        let x_test = x.slice(ndarray::s![0..5, ..]).to_owned();
        let projected = fitted.transform(&x_test).unwrap();
        assert_eq!(projected.dim(), (5, 2));
    }

    #[test]
    fn test_umap_transform_shape_mismatch() {
        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j) as f64);
        let umap = Umap::new().with_n_epochs(10).with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        let x_bad = Array2::<f64>::zeros((5, 3));
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_umap_transform_empty_input() {
        let (x, _) = make_blobs(3);
        let fitted = Umap::new()
            .with_n_neighbors(5)
            .with_n_epochs(10)
            .with_random_state(3)
            .fit(&x, &())
            .unwrap();
        let empty = Array2::<f64>::zeros((0, 5));
        assert_eq!(fitted.transform(&empty).unwrap().dim(), (0, 2));
    }

    #[test]
    fn test_umap_transform_training_points_finite() {
        let (x, _) = make_blobs(5);
        let fitted = Umap::new()
            .with_n_neighbors(5)
            .with_n_epochs(20)
            .with_random_state(5)
            .fit(&x, &())
            .unwrap();
        let projected = fitted.transform(&x).unwrap();
        assert!(projected.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_umap_ab_params_reasonable() {
        let (a, b) = find_ab_params(0.1, 1.0);
        assert!(a > 0.0, "a should be positive, got {a}");
        assert!(b > 0.0, "b should be positive, got {b}");
        let val_at_min = 1.0 / (1.0 + a * (0.1_f64).powf(2.0 * b));
        assert!(
            val_at_min > 0.5,
            "kernel at min_dist should be > 0.5, got {val_at_min}"
        );
    }

    #[test]
    fn test_umap_invalid_n_components_zero() {
        let x = Array2::<f64>::zeros((10, 3));
        let umap = Umap::new().with_n_components(0);
        assert!(umap.fit(&x, &()).is_err());
    }

    #[test]
    fn test_umap_invalid_n_neighbors_zero() {
        let x = Array2::<f64>::zeros((10, 3));
        let umap = Umap::new().with_n_neighbors(0);
        assert!(umap.fit(&x, &()).is_err());
    }

    #[test]
    fn test_umap_invalid_min_dist() {
        let x = Array2::<f64>::zeros((10, 3));
        let umap = Umap::new().with_min_dist(-0.1);
        assert!(umap.fit(&x, &()).is_err());
    }

    #[test]
    fn test_umap_invalid_spread() {
        let x = Array2::<f64>::zeros((10, 3));
        let umap = Umap::new().with_spread(0.0);
        assert!(umap.fit(&x, &()).is_err());
    }

    #[test]
    fn test_umap_invalid_learning_rate() {
        let x = Array2::<f64>::zeros((10, 3));
        let umap = Umap::new().with_learning_rate(-1.0);
        assert!(umap.fit(&x, &()).is_err());
    }

    #[test]
    fn test_umap_zero_learning_rate_rejected() {
        let x = Array2::<f64>::zeros((10, 3));
        let umap = Umap::new().with_learning_rate(0.0);
        assert!(umap.fit(&x, &()).is_err());
    }

    #[test]
    fn test_umap_insufficient_samples() {
        let x = Array2::<f64>::zeros((1, 3));
        let umap = Umap::new();
        assert!(umap.fit(&x, &()).is_err());
    }

    #[test]
    fn test_umap_zero_epochs_keeps_initialisation() {
        let x = Array2::<f64>::from_shape_fn((10, 3), |(i, j)| (i + 2 * j) as f64);
        let fitted = Umap::new()
            .with_n_epochs(0)
            .with_random_state(1)
            .fit(&x, &())
            .unwrap();
        // The uniform initialisation draws from [-10, 10).
        assert!(fitted
            .embedding()
            .iter()
            .all(|v| v.is_finite() && v.abs() <= 10.0));
    }

    #[test]
    fn test_umap_getters() {
        let umap = Umap::new()
            .with_n_components(3)
            .with_n_neighbors(10)
            .with_min_dist(0.2)
            .with_spread(1.5)
            .with_learning_rate(0.5)
            .with_n_epochs(100)
            .with_metric(UmapMetric::Manhattan)
            .with_negative_sample_rate(3)
            .with_random_state(99);
        assert_eq!(umap.n_components(), 3);
        assert_eq!(umap.n_neighbors(), 10);
        assert!((umap.min_dist() - 0.2).abs() < 1e-10);
        assert!((umap.spread() - 1.5).abs() < 1e-10);
        assert!((umap.learning_rate() - 0.5).abs() < 1e-10);
        assert_eq!(umap.n_epochs(), 100);
        assert_eq!(umap.metric(), UmapMetric::Manhattan);
        assert_eq!(umap.negative_sample_rate(), 3);
        assert_eq!(umap.random_state(), Some(99));
    }

    #[test]
    fn test_umap_default() {
        let umap = Umap::default();
        assert_eq!(umap.n_components(), 2);
        assert_eq!(umap.n_neighbors(), 15);
    }

    #[test]
    fn test_umap_cosine_metric() {
        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j + 1) as f64);
        let umap = Umap::new()
            .with_metric(UmapMetric::Cosine)
            .with_n_epochs(10)
            .with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (20, 2));
    }

    #[test]
    fn test_umap_manhattan_metric() {
        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i * j) as f64);
        let umap = Umap::new()
            .with_metric(UmapMetric::Manhattan)
            .with_n_epochs(10)
            .with_random_state(7);
        let fitted = umap.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (20, 2));
    }

    #[test]
    fn test_umap_small_n_neighbors_capped() {
        // n_neighbors far exceeds n_samples - 1; fitting must cap it, not fail.
        let x = Array2::<f64>::from_shape_fn((5, 3), |(i, j)| (i + j) as f64);
        let umap = Umap::new()
            .with_n_neighbors(100)
            .with_n_epochs(10)
            .with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (5, 2));
    }

    #[test]
    fn test_umap_fitted_accessors() {
        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j) as f64);
        let umap = Umap::new().with_n_epochs(10).with_random_state(42);
        let fitted = umap.fit(&x, &()).unwrap();
        assert!(fitted.a() > 0.0);
        assert!(fitted.b() > 0.0);
    }
}