Skip to main content

ferrolearn_decomp/
umap.rs

1//! Uniform Manifold Approximation and Projection (UMAP).
2//!
3//! [`Umap`] performs non-linear dimensionality reduction based on the
4//! mathematical framework of Riemannian geometry and algebraic topology,
5//! as described by McInnes, Healy, and Melville (2018).
6//!
7//! # Algorithm
8//!
9//! 1. Build a k-nearest-neighbor graph with `n_neighbors` neighbors.
10//! 2. Compute the fuzzy simplicial set by smoothing kNN distances: for each
11//!    point find a local connectivity parameter `rho` (distance to nearest
12//!    neighbor) and a bandwidth `sigma` such that the sum of the membership
13//!    strengths equals `log2(n_neighbors)`.
14//! 3. Symmetrise the fuzzy graph: `w_ij = w_i|j + w_j|i - w_i|j * w_j|i`.
15//! 4. Determine curve parameters `a` and `b` from `min_dist` and `spread`
16//!    that define the target distribution in the embedding space:
17//!    `phi(d) = 1 / (1 + a * d^(2b))`.
//! 5. Initialise the embedding with uniform random coordinates in
//!    `[-10, 10)`.
//! 6. Optimise via SGD with attractive forces on positive edges and
//!    repulsive forces via negative sampling (`negative_sample_rate`
//!    negatives per positive update, 5 by default).
21//!
22//! # Examples
23//!
24//! ```
25//! use ferrolearn_decomp::Umap;
26//! use ferrolearn_core::traits::{Fit, Transform};
27//! use ndarray::Array2;
28//!
29//! let x = Array2::<f64>::from_shape_fn((30, 5), |(i, j)| (i + j) as f64);
30//! let umap = Umap::new().with_random_state(42).with_n_epochs(50);
31//! let fitted = umap.fit(&x, &()).unwrap();
32//! let emb = fitted.embedding();
33//! assert_eq!(emb.ncols(), 2);
34//! ```
35
36use ferrolearn_core::error::FerroError;
37use ferrolearn_core::traits::{Fit, Transform};
38use ndarray::Array2;
39use rand::SeedableRng;
40use rand_distr::{Distribution, Uniform};
41use rand_xoshiro::Xoshiro256PlusPlus;
42
43// ---------------------------------------------------------------------------
44// Metric enum
45// ---------------------------------------------------------------------------
46
/// Distance metric used to compare samples when building the UMAP
/// neighbor graph and when projecting new data.
///
/// [`UmapMetric::Euclidean`] is the [`Default`] variant, matching the metric
/// selected by `Umap::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum UmapMetric {
    /// Standard Euclidean (L2) distance.
    #[default]
    Euclidean,
    /// Manhattan (L1) distance.
    Manhattan,
    /// Cosine distance (1 - cosine similarity).
    Cosine,
}
57
58// ---------------------------------------------------------------------------
59// Umap (unfitted)
60// ---------------------------------------------------------------------------
61
62/// UMAP configuration.
63///
64/// Holds hyperparameters for the UMAP algorithm. Calling [`Fit::fit`]
65/// computes the embedding and returns a [`FittedUmap`].
66#[derive(Debug, Clone)]
67pub struct Umap {
68    /// Number of embedding dimensions (default 2).
69    n_components: usize,
70    /// Number of nearest neighbors for the kNN graph (default 15).
71    n_neighbors: usize,
72    /// Minimum distance in the embedding space (default 0.1).
73    min_dist: f64,
74    /// Spread of the embedding (default 1.0).
75    spread: f64,
76    /// Learning rate for SGD (default 1.0).
77    learning_rate: f64,
78    /// Number of SGD epochs (default 200).
79    n_epochs: usize,
80    /// Distance metric (default Euclidean).
81    metric: UmapMetric,
82    /// Number of negative samples per positive edge (default 5).
83    negative_sample_rate: usize,
84    /// Optional random seed for reproducibility.
85    random_state: Option<u64>,
86}
87
88impl Umap {
89    /// Create a new `Umap` with default parameters.
90    ///
91    /// Defaults: `n_components=2`, `n_neighbors=15`, `min_dist=0.1`,
92    /// `spread=1.0`, `learning_rate=1.0`, `n_epochs=200`, metric=`Euclidean`,
93    /// `negative_sample_rate=5`.
94    #[must_use]
95    pub fn new() -> Self {
96        Self {
97            n_components: 2,
98            n_neighbors: 15,
99            min_dist: 0.1,
100            spread: 1.0,
101            learning_rate: 1.0,
102            n_epochs: 200,
103            metric: UmapMetric::Euclidean,
104            negative_sample_rate: 5,
105            random_state: None,
106        }
107    }
108
109    /// Set the number of embedding dimensions.
110    #[must_use]
111    pub fn with_n_components(mut self, n: usize) -> Self {
112        self.n_components = n;
113        self
114    }
115
116    /// Set the number of nearest neighbors.
117    #[must_use]
118    pub fn with_n_neighbors(mut self, k: usize) -> Self {
119        self.n_neighbors = k;
120        self
121    }
122
123    /// Set the minimum distance in the embedding.
124    #[must_use]
125    pub fn with_min_dist(mut self, d: f64) -> Self {
126        self.min_dist = d;
127        self
128    }
129
130    /// Set the spread.
131    #[must_use]
132    pub fn with_spread(mut self, s: f64) -> Self {
133        self.spread = s;
134        self
135    }
136
137    /// Set the learning rate.
138    #[must_use]
139    pub fn with_learning_rate(mut self, lr: f64) -> Self {
140        self.learning_rate = lr;
141        self
142    }
143
144    /// Set the number of SGD epochs.
145    #[must_use]
146    pub fn with_n_epochs(mut self, n: usize) -> Self {
147        self.n_epochs = n;
148        self
149    }
150
151    /// Set the distance metric.
152    #[must_use]
153    pub fn with_metric(mut self, m: UmapMetric) -> Self {
154        self.metric = m;
155        self
156    }
157
158    /// Set the negative sample rate.
159    #[must_use]
160    pub fn with_negative_sample_rate(mut self, rate: usize) -> Self {
161        self.negative_sample_rate = rate;
162        self
163    }
164
165    /// Set the random seed.
166    #[must_use]
167    pub fn with_random_state(mut self, seed: u64) -> Self {
168        self.random_state = Some(seed);
169        self
170    }
171
172    /// Return the configured number of components.
173    #[must_use]
174    pub fn n_components(&self) -> usize {
175        self.n_components
176    }
177
178    /// Return the configured number of neighbors.
179    #[must_use]
180    pub fn n_neighbors(&self) -> usize {
181        self.n_neighbors
182    }
183
184    /// Return the configured minimum distance.
185    #[must_use]
186    pub fn min_dist(&self) -> f64 {
187        self.min_dist
188    }
189
190    /// Return the configured spread.
191    #[must_use]
192    pub fn spread(&self) -> f64 {
193        self.spread
194    }
195
196    /// Return the configured learning rate.
197    #[must_use]
198    pub fn learning_rate(&self) -> f64 {
199        self.learning_rate
200    }
201
202    /// Return the configured number of epochs.
203    #[must_use]
204    pub fn n_epochs(&self) -> usize {
205        self.n_epochs
206    }
207
208    /// Return the configured metric.
209    #[must_use]
210    pub fn metric(&self) -> UmapMetric {
211        self.metric
212    }
213
214    /// Return the configured negative sample rate.
215    #[must_use]
216    pub fn negative_sample_rate(&self) -> usize {
217        self.negative_sample_rate
218    }
219
220    /// Return the configured random state, if any.
221    #[must_use]
222    pub fn random_state(&self) -> Option<u64> {
223        self.random_state
224    }
225}
226
227impl Default for Umap {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233// ---------------------------------------------------------------------------
234// FittedUmap
235// ---------------------------------------------------------------------------
236
237/// A fitted UMAP model holding the learned embedding and training data.
238///
239/// Created by calling [`Fit::fit`] on a [`Umap`]. Implements
240/// [`Transform<Array2<f64>>`] for projecting new data via nearest-neighbor
241/// lookup.
242#[derive(Debug, Clone)]
243pub struct FittedUmap {
244    /// The embedding, shape `(n_samples, n_components)`.
245    embedding_: Array2<f64>,
246    /// Training data, stored for out-of-sample extension.
247    x_train_: Array2<f64>,
248    /// Curve parameter `a`.
249    a_: f64,
250    /// Curve parameter `b`.
251    b_: f64,
252    /// Number of neighbors used.
253    n_neighbors_: usize,
254    /// The metric used.
255    metric_: UmapMetric,
256}
257
258impl FittedUmap {
259    /// The embedding coordinates, shape `(n_samples, n_components)`.
260    #[must_use]
261    pub fn embedding(&self) -> &Array2<f64> {
262        &self.embedding_
263    }
264
265    /// The curve parameter `a`.
266    #[must_use]
267    pub fn a(&self) -> f64 {
268        self.a_
269    }
270
271    /// The curve parameter `b`.
272    #[must_use]
273    pub fn b(&self) -> f64 {
274        self.b_
275    }
276}
277
278// ---------------------------------------------------------------------------
279// Internal helpers
280// ---------------------------------------------------------------------------
281
282/// Compute distance between two rows of a matrix using the given metric.
283fn compute_distance(x: &Array2<f64>, i: usize, j: usize, metric: UmapMetric) -> f64 {
284    let ncols = x.ncols();
285    match metric {
286        UmapMetric::Euclidean => {
287            let mut sq = 0.0;
288            for k in 0..ncols {
289                let diff = x[[i, k]] - x[[j, k]];
290                sq += diff * diff;
291            }
292            sq.sqrt()
293        }
294        UmapMetric::Manhattan => {
295            let mut sum = 0.0;
296            for k in 0..ncols {
297                sum += (x[[i, k]] - x[[j, k]]).abs();
298            }
299            sum
300        }
301        UmapMetric::Cosine => {
302            let mut dot = 0.0;
303            let mut norm_i = 0.0;
304            let mut norm_j = 0.0;
305            for k in 0..ncols {
306                dot += x[[i, k]] * x[[j, k]];
307                norm_i += x[[i, k]] * x[[i, k]];
308                norm_j += x[[j, k]] * x[[j, k]];
309            }
310            let denom = (norm_i * norm_j).sqrt();
311            if denom < 1e-16 {
312                1.0
313            } else {
314                1.0 - dot / denom
315            }
316        }
317    }
318}
319
320/// Compute distance between a point (row of x_new) and a training point.
321fn compute_distance_cross(
322    x_new: &Array2<f64>,
323    i: usize,
324    x_train: &Array2<f64>,
325    j: usize,
326    metric: UmapMetric,
327) -> f64 {
328    let ncols = x_new.ncols();
329    match metric {
330        UmapMetric::Euclidean => {
331            let mut sq = 0.0;
332            for k in 0..ncols {
333                let diff = x_new[[i, k]] - x_train[[j, k]];
334                sq += diff * diff;
335            }
336            sq.sqrt()
337        }
338        UmapMetric::Manhattan => {
339            let mut sum = 0.0;
340            for k in 0..ncols {
341                sum += (x_new[[i, k]] - x_train[[j, k]]).abs();
342            }
343            sum
344        }
345        UmapMetric::Cosine => {
346            let mut dot = 0.0;
347            let mut norm_i = 0.0;
348            let mut norm_j = 0.0;
349            for k in 0..ncols {
350                dot += x_new[[i, k]] * x_train[[j, k]];
351                norm_i += x_new[[i, k]] * x_new[[i, k]];
352                norm_j += x_train[[j, k]] * x_train[[j, k]];
353            }
354            let denom = (norm_i * norm_j).sqrt();
355            if denom < 1e-16 {
356                1.0
357            } else {
358                1.0 - dot / denom
359            }
360        }
361    }
362}
363
364/// Build k-nearest-neighbor graph. Returns for each point the sorted list of
365/// (neighbor_index, distance) pairs.
366fn build_knn(x: &Array2<f64>, k: usize, metric: UmapMetric) -> Vec<Vec<(usize, f64)>> {
367    let n = x.nrows();
368    let mut knn = Vec::with_capacity(n);
369    for i in 0..n {
370        let mut dists: Vec<(usize, f64)> = (0..n)
371            .filter(|&j| j != i)
372            .map(|j| (j, compute_distance(x, i, j, metric)))
373            .collect();
374        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
375        dists.truncate(k);
376        knn.push(dists);
377    }
378    knn
379}
380
/// Compute the fuzzy simplicial set: smooth kNN distances into membership
/// strengths and symmetrise them into an undirected weighted graph.
///
/// For each point `i`, `rho_i` is the distance to its nearest neighbor (or
/// the first neighbor distance above `1e-16` when the nearest one is zero)
/// and `sigma_i` is found by bisection so that
/// `sum_j exp(-max(d(i,j) - rho_i, 0) / sigma_i)` approximates `log2(k)`,
/// where `k` is the length of the first neighbor list. The bisection is
/// bracketed to `[1e-20, 1e4]` and runs for at most 64 iterations.
///
/// Directed strengths are merged with the probabilistic t-conorm
/// `w_ij = w_i|j + w_j|i - w_i|j * w_j|i`.
///
/// # Arguments
///
/// * `knn` - one neighbor list per sample, each sorted by ascending distance.
/// * `n` - number of samples; expected to equal `knn.len()`.
///
/// # Returns
///
/// `(i, j, weight)` edges with `i <= j`, keeping only weights above `1e-16`.
/// Edges are returned sorted by `(i, j)`, so the output (and anything that
/// iterates over it, such as a seeded SGD pass) is reproducible across calls
/// and processes.
fn compute_fuzzy_simplicial_set(knn: &[Vec<(usize, f64)>], n: usize) -> Vec<(usize, usize, f64)> {
    debug_assert_eq!(knn.len(), n, "expected one neighbor list per sample");

    let k = knn.first().map_or(0, Vec::len);
    let target = (k as f64).ln() / std::f64::consts::LN_2; // log2(k)

    // Undirected edge (lo, hi) -> (strength seen from lo, strength seen from hi).
    // A BTreeMap keeps iteration order deterministic, unlike a hash map whose
    // iteration order depends on a randomly seeded hasher.
    let mut memberships: std::collections::BTreeMap<(usize, usize), (f64, f64)> =
        std::collections::BTreeMap::new();

    for (i, neighbors) in knn.iter().enumerate() {
        let Some(&(_, nearest)) = neighbors.first() else {
            // A point without neighbors contributes no directed edges.
            continue;
        };

        // rho_i: distance to the nearest neighbor, skipping exact duplicates.
        let rho = if nearest < 1e-16 {
            neighbors
                .iter()
                .map(|&(_, d)| d)
                .find(|&d| d > 1e-16)
                .unwrap_or(nearest)
        } else {
            nearest
        };

        // Bisection for sigma_i: the strength sum grows with sigma, so a sum
        // above the target means sigma is too large.
        let mut lo = 1e-20_f64;
        let mut hi = 1e4_f64;
        for _ in 0..64 {
            let mid = f64::midpoint(lo, hi);
            let strength_sum = neighbors.iter().fold(0.0_f64, |acc, &(_, d)| {
                let adjusted = (d - rho).max(0.0);
                acc + (-adjusted / mid).exp()
            });
            if strength_sum > target {
                hi = mid;
            } else {
                lo = mid;
            }
            if (hi - lo) / (lo + 1e-16) < 1e-5 {
                break;
            }
        }
        let sigma = f64::midpoint(lo, hi);

        // Directed membership strengths w_{i|j} for j in knn(i).
        for &(j, d) in neighbors {
            let adjusted = (d - rho).max(0.0);
            let w = (-adjusted / sigma).exp();
            let slot = memberships
                .entry((i.min(j), i.max(j)))
                .or_insert((0.0, 0.0));
            if i < j {
                slot.0 = w;
            } else {
                slot.1 = w;
            }
        }
    }

    // Symmetrise: w_ij = w_{i|j} + w_{j|i} - w_{i|j} * w_{j|i}; a direction
    // that was never observed counts as strength 0.
    memberships
        .into_iter()
        .filter_map(|((lo_idx, hi_idx), (w_fwd, w_bwd))| {
            let w = w_fwd + w_bwd - w_fwd * w_bwd;
            (w > 1e-16).then_some((lo_idx, hi_idx, w))
        })
        .collect()
}
485
/// Fit the curve parameters `a` and `b` of `1 / (1 + a * d^(2b))` to the
/// target membership curve defined by `min_dist` and `spread`.
///
/// The target is `1` for `d <= min_dist` and
/// `exp(-(d - min_dist) / spread)` beyond it. It is sampled at 300 distances
/// spread evenly over `(0, 3 * spread)`, and the pair minimising the summed
/// squared error is found with a coarse grid search (`a` in `0.25..=10.0`,
/// `b` in `0.1..=3.0`) followed by a finer 20x20 grid around the coarse
/// optimum. If no candidate improves on the initial error bound (for
/// example when the inputs are NaN), `(1.0, 1.0)` is returned.
fn find_ab_params(min_dist: f64, spread: f64) -> (f64, f64) {
    const N_SAMPLES: i32 = 300;
    let d_max = 3.0 * spread;

    // Sampled distances paired with the target value at each distance.
    let samples: Vec<(f64, f64)> = (0..N_SAMPLES)
        .map(|k| {
            let d = (f64::from(k) + 0.5) / f64::from(N_SAMPLES) * d_max;
            let target = if d <= min_dist {
                1.0
            } else {
                (-(d - min_dist) / spread).exp()
            };
            (d, target)
        })
        .collect();

    // Summed squared error of the candidate curve against the target.
    let squared_error = |a: f64, b: f64| -> f64 {
        samples.iter().fold(0.0_f64, |acc, &(d, target)| {
            let residual = 1.0 / (1.0 + a * d.powf(2.0 * b)) - target;
            acc + residual * residual
        })
    };

    let mut best = (1.0_f64, 1.0_f64);
    let mut best_err = f64::MAX;

    // Coarse pass: a in steps of 0.25, b in steps of 0.1.
    for a in (1..=40_i32).map(|i| f64::from(i) * 0.25) {
        for b in (1..=30_i32).map(|i| f64::from(i) * 0.1) {
            let err = squared_error(a, b);
            if err < best_err {
                best_err = err;
                best = (a, b);
            }
        }
    }

    // Fine pass: a 20x20 grid centred on the coarse optimum. The window is
    // fixed before the pass starts, even if `best` moves during it.
    let (coarse_a, coarse_b) = best;
    let a_lo = (coarse_a - 0.3).max(0.01);
    let a_hi = coarse_a + 0.3;
    let b_lo = (coarse_b - 0.15).max(0.01);
    let b_hi = coarse_b + 0.15;

    for ia in 0..20_i32 {
        let a = a_lo + (a_hi - a_lo) * f64::from(ia) / 19.0;
        for ib in 0..20_i32 {
            let b = b_lo + (b_hi - b_lo) * f64::from(ib) / 19.0;
            let err = squared_error(a, b);
            if err < best_err {
                best_err = err;
                best = (a, b);
            }
        }
    }

    best
}
558
/// Limit `val` to the range `[lo, hi]` so gradient steps stay bounded.
///
/// The lower bound is checked first, so with inverted bounds (`lo > hi`) a
/// value below `lo` yields `lo`. A NaN `val` fails both comparisons and is
/// returned unchanged.
fn clip(val: f64, lo: f64, hi: f64) -> f64 {
    if val < lo {
        return lo;
    }
    if val > hi {
        return hi;
    }
    val
}
569
570// ---------------------------------------------------------------------------
571// Trait implementations
572// ---------------------------------------------------------------------------
573
574impl Fit<Array2<f64>, ()> for Umap {
575    type Fitted = FittedUmap;
576    type Error = FerroError;
577
578    /// Fit UMAP by computing the fuzzy simplicial set and optimising the
579    /// low-dimensional embedding via SGD.
580    ///
581    /// # Errors
582    ///
583    /// - [`FerroError::InvalidParameter`] if `n_components` is zero,
584    ///   `n_neighbors` is zero or too large, `min_dist` is negative,
585    ///   `spread` is non-positive, or `learning_rate` is non-positive.
586    /// - [`FerroError::InsufficientSamples`] if there are fewer samples than
587    ///   `n_neighbors + 1`.
588    fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedUmap, FerroError> {
589        let n = x.nrows();
590
591        // Validate parameters.
592        if self.n_components == 0 {
593            return Err(FerroError::InvalidParameter {
594                name: "n_components".into(),
595                reason: "must be at least 1".into(),
596            });
597        }
598        if self.n_neighbors == 0 {
599            return Err(FerroError::InvalidParameter {
600                name: "n_neighbors".into(),
601                reason: "must be at least 1".into(),
602            });
603        }
604        if n < 2 {
605            return Err(FerroError::InsufficientSamples {
606                required: 2,
607                actual: n,
608                context: "Umap::fit requires at least 2 samples".into(),
609            });
610        }
611        let effective_k = self.n_neighbors.min(n - 1);
612        if self.min_dist < 0.0 {
613            return Err(FerroError::InvalidParameter {
614                name: "min_dist".into(),
615                reason: "must be non-negative".into(),
616            });
617        }
618        if self.spread <= 0.0 {
619            return Err(FerroError::InvalidParameter {
620                name: "spread".into(),
621                reason: "must be positive".into(),
622            });
623        }
624        if self.learning_rate <= 0.0 {
625            return Err(FerroError::InvalidParameter {
626                name: "learning_rate".into(),
627                reason: "must be positive".into(),
628            });
629        }
630
631        let dim = self.n_components;
632        let seed = self.random_state.unwrap_or(0);
633
634        // Step 1: Build kNN graph.
635        let knn = build_knn(x, effective_k, self.metric);
636
637        // Step 2: Compute fuzzy simplicial set.
638        let edges = compute_fuzzy_simplicial_set(&knn, n);
639
640        // Step 3: Find a, b parameters.
641        let (a, b) = find_ab_params(self.min_dist, self.spread);
642
643        // Step 4: Initialise embedding (random uniform).
644        let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
645        let uniform = Uniform::new(-10.0, 10.0).unwrap();
646        let mut y = Array2::<f64>::zeros((n, dim));
647        for elem in &mut y {
648            *elem = uniform.sample(&mut rng);
649        }
650
651        // Pre-compute epochs per edge: spread epochs proportional to weight.
652        if edges.is_empty() {
653            return Ok(FittedUmap {
654                embedding_: y,
655                x_train_: x.to_owned(),
656                a_: a,
657                b_: b,
658                n_neighbors_: effective_k,
659                metric_: self.metric,
660            });
661        }
662
663        let max_weight = edges.iter().map(|e| e.2).fold(0.0_f64, f64::max);
664
665        // Each edge gets `n_epochs * (weight / max_weight)` total updates.
666        let epochs_per_sample: Vec<f64> = edges
667            .iter()
668            .map(|e| {
669                let ratio = e.2 / max_weight;
670                if ratio > 0.0 {
671                    (self.n_epochs as f64) / ((self.n_epochs as f64) * ratio).max(1.0)
672                } else {
673                    f64::MAX
674                }
675            })
676            .collect();
677
678        let mut epoch_of_next_sample: Vec<f64> = epochs_per_sample.clone();
679
680        let neg_rate = self.negative_sample_rate;
681        let idx_uniform = Uniform::new(0usize, n).unwrap();
682
683        // Step 5: SGD optimisation.
684        for epoch in 0..self.n_epochs {
685            let alpha = self.learning_rate * (1.0 - epoch as f64 / self.n_epochs as f64);
686            let alpha = alpha.max(0.0);
687
688            for (edge_idx, &(ei, ej, _weight)) in edges.iter().enumerate() {
689                if epoch_of_next_sample[edge_idx] > epoch as f64 {
690                    continue;
691                }
692
693                // Attractive force.
694                let mut dist_sq = 0.0;
695                for d in 0..dim {
696                    let diff = y[[ei, d]] - y[[ej, d]];
697                    dist_sq += diff * diff;
698                }
699                let dist_sq = dist_sq.max(1e-16);
700
701                let grad_coeff = -2.0 * a * b * dist_sq.powf(b - 1.0) / (1.0 + a * dist_sq.powf(b));
702
703                for d in 0..dim {
704                    let diff = y[[ei, d]] - y[[ej, d]];
705                    let grad = clip(grad_coeff * diff, -4.0, 4.0);
706                    y[[ei, d]] += alpha * grad;
707                    y[[ej, d]] -= alpha * grad;
708                }
709
710                // Negative sampling.
711                for _ in 0..neg_rate {
712                    let neg = idx_uniform.sample(&mut rng);
713                    if neg == ei {
714                        continue;
715                    }
716                    let mut neg_dist_sq = 0.0;
717                    for d in 0..dim {
718                        let diff = y[[ei, d]] - y[[neg, d]];
719                        neg_dist_sq += diff * diff;
720                    }
721                    let neg_dist_sq = neg_dist_sq.max(1e-16);
722
723                    let rep_coeff =
724                        2.0 * b / ((0.001 + neg_dist_sq) * (1.0 + a * neg_dist_sq.powf(b)));
725
726                    for d in 0..dim {
727                        let diff = y[[ei, d]] - y[[neg, d]];
728                        let grad = clip(rep_coeff * diff, -4.0, 4.0);
729                        y[[ei, d]] += alpha * grad;
730                    }
731                }
732
733                epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];
734            }
735        }
736
737        Ok(FittedUmap {
738            embedding_: y,
739            x_train_: x.to_owned(),
740            a_: a,
741            b_: b,
742            n_neighbors_: effective_k,
743            metric_: self.metric,
744        })
745    }
746}
747
748impl Transform<Array2<f64>> for FittedUmap {
749    type Output = Array2<f64>;
750    type Error = FerroError;
751
752    /// Project new data into the UMAP embedding space.
753    ///
754    /// For each new point, find the nearest neighbors in the training data
755    /// and compute a weighted average of their embeddings (weighted by the
756    /// UMAP kernel).
757    ///
758    /// # Errors
759    ///
760    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
761    /// not match the training data.
762    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
763        let n_features = self.x_train_.ncols();
764        if x.ncols() != n_features {
765            return Err(FerroError::ShapeMismatch {
766                expected: vec![x.nrows(), n_features],
767                actual: vec![x.nrows(), x.ncols()],
768                context: "FittedUmap::transform".into(),
769            });
770        }
771
772        let n_test = x.nrows();
773        let n_train = self.x_train_.nrows();
774        let dim = self.embedding_.ncols();
775        let k = self.n_neighbors_.min(n_train);
776
777        let mut result = Array2::<f64>::zeros((n_test, dim));
778
779        for t in 0..n_test {
780            // Find k nearest training neighbors.
781            let mut dists: Vec<(usize, f64)> = (0..n_train)
782                .map(|j| {
783                    (
784                        j,
785                        compute_distance_cross(x, t, &self.x_train_, j, self.metric_),
786                    )
787                })
788                .collect();
789            dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
790            dists.truncate(k);
791
792            // Compute weights using the UMAP kernel: 1/(1 + a * d^(2b)).
793            let mut weights = Vec::with_capacity(k);
794            let mut weight_sum = 0.0;
795            for &(_, d) in &dists {
796                let w = 1.0 / (1.0 + self.a_ * d.powf(2.0 * self.b_));
797                weights.push(w);
798                weight_sum += w;
799            }
800
801            if weight_sum < 1e-16 {
802                // Fallback: uniform weights.
803                weight_sum = k as f64;
804                weights = vec![1.0; k];
805            }
806
807            // Weighted average of neighbor embeddings.
808            for (idx, &(train_idx, _)) in dists.iter().enumerate() {
809                let w = weights[idx] / weight_sum;
810                for d in 0..dim {
811                    result[[t, d]] += w * self.embedding_[[train_idx, d]];
812                }
813            }
814        }
815
816        Ok(result)
817    }
818}
819
820// ---------------------------------------------------------------------------
821// Tests
822// ---------------------------------------------------------------------------
823
824#[cfg(test)]
825mod tests {
826    use super::*;
827    use ndarray::Array2;
828    use rand::SeedableRng;
829    use rand_distr::{Distribution, Normal};
830    use rand_xoshiro::Xoshiro256PlusPlus;
831
832    /// Generate small blobs dataset.
833    fn make_blobs(seed: u64) -> (Array2<f64>, Vec<usize>) {
834        let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
835        let normal = Normal::new(0.0, 0.3).unwrap();
836        let n_per_cluster = 10;
837        let n_features = 5;
838        let centers = [
839            vec![0.0, 0.0, 0.0, 0.0, 0.0],
840            vec![5.0, 5.0, 5.0, 5.0, 5.0],
841            vec![10.0, 0.0, 10.0, 0.0, 10.0],
842        ];
843        let n = centers.len() * n_per_cluster;
844        let mut x = Array2::<f64>::zeros((n, n_features));
845        let mut labels = Vec::with_capacity(n);
846        for (c_idx, center) in centers.iter().enumerate() {
847            for i in 0..n_per_cluster {
848                let row = c_idx * n_per_cluster + i;
849                for (f, &c) in center.iter().enumerate() {
850                    x[[row, f]] = c + normal.sample(&mut rng);
851                }
852                labels.push(c_idx);
853            }
854        }
855        (x, labels)
856    }
857
858    #[test]
859    fn test_umap_basic_shape() {
860        let x = Array2::<f64>::from_shape_fn((30, 5), |(i, j)| (i + j) as f64);
861        let umap = Umap::new().with_n_epochs(10).with_random_state(42);
862        let fitted = umap.fit(&x, &()).unwrap();
863        assert_eq!(fitted.embedding().dim(), (30, 2));
864    }
865
866    #[test]
867    fn test_umap_3d_embedding() {
868        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j) as f64);
869        let umap = Umap::new()
870            .with_n_components(3)
871            .with_n_epochs(10)
872            .with_random_state(42);
873        let fitted = umap.fit(&x, &()).unwrap();
874        assert_eq!(fitted.embedding().ncols(), 3);
875    }
876
877    #[test]
878    fn test_umap_separates_clusters() {
879        let (x, labels) = make_blobs(42);
880        let umap = Umap::new()
881            .with_n_neighbors(5)
882            .with_n_epochs(100)
883            .with_random_state(42);
884        let fitted = umap.fit(&x, &()).unwrap();
885        let emb = fitted.embedding();
886
887        // Check cluster separation with k-NN accuracy (k=3).
888        let n = emb.nrows();
889        let mut correct = 0;
890        for i in 0..n {
891            let mut dists: Vec<(f64, usize)> = (0..n)
892                .filter(|&j| j != i)
893                .map(|j| {
894                    let mut d = 0.0;
895                    for dd in 0..emb.ncols() {
896                        let diff = emb[[i, dd]] - emb[[j, dd]];
897                        d += diff * diff;
898                    }
899                    (d, labels[j])
900                })
901                .collect();
902            dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
903            let mut votes = [0usize; 3];
904            for &(_, lbl) in dists.iter().take(3) {
905                votes[lbl] += 1;
906            }
907            let pred = votes.iter().enumerate().max_by_key(|&(_, v)| v).unwrap().0;
908            if pred == labels[i] {
909                correct += 1;
910            }
911        }
912        let accuracy = f64::from(correct) / n as f64;
913        assert!(
914            accuracy > 0.8,
915            "UMAP k-NN accuracy should be > 80%, got {:.1}%",
916            accuracy * 100.0
917        );
918    }
919
920    #[test]
921    fn test_umap_transform_new_data() {
922        let (x, _) = make_blobs(42);
923        let umap = Umap::new()
924            .with_n_neighbors(5)
925            .with_n_epochs(50)
926            .with_random_state(42);
927        let fitted = umap.fit(&x, &()).unwrap();
928
929        // Transform a subset of training data.
930        let x_test = x.slice(ndarray::s![0..5, ..]).to_owned();
931        let projected = fitted.transform(&x_test).unwrap();
932        assert_eq!(projected.dim(), (5, 2));
933    }
934
935    #[test]
936    fn test_umap_transform_shape_mismatch() {
937        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j) as f64);
938        let umap = Umap::new().with_n_epochs(10).with_random_state(42);
939        let fitted = umap.fit(&x, &()).unwrap();
940        let x_bad = Array2::<f64>::zeros((5, 3)); // wrong number of features
941        assert!(fitted.transform(&x_bad).is_err());
942    }
943
944    #[test]
945    fn test_umap_ab_params_reasonable() {
946        let (a, b) = find_ab_params(0.1, 1.0);
947        // a and b should be positive.
948        assert!(a > 0.0, "a should be positive, got {a}");
949        assert!(b > 0.0, "b should be positive, got {b}");
950        // At d=0, 1/(1+a*0) = 1, which is correct.
951        // At d=min_dist, should be close to 1.
952        let val_at_min = 1.0 / (1.0 + a * (0.1_f64).powf(2.0 * b));
953        assert!(
954            val_at_min > 0.5,
955            "kernel at min_dist should be > 0.5, got {val_at_min}"
956        );
957    }
958
959    #[test]
960    fn test_umap_invalid_n_components_zero() {
961        let x = Array2::<f64>::zeros((10, 3));
962        let umap = Umap::new().with_n_components(0);
963        assert!(umap.fit(&x, &()).is_err());
964    }
965
966    #[test]
967    fn test_umap_invalid_n_neighbors_zero() {
968        let x = Array2::<f64>::zeros((10, 3));
969        let umap = Umap::new().with_n_neighbors(0);
970        assert!(umap.fit(&x, &()).is_err());
971    }
972
973    #[test]
974    fn test_umap_invalid_min_dist() {
975        let x = Array2::<f64>::zeros((10, 3));
976        let umap = Umap::new().with_min_dist(-0.1);
977        assert!(umap.fit(&x, &()).is_err());
978    }
979
980    #[test]
981    fn test_umap_invalid_spread() {
982        let x = Array2::<f64>::zeros((10, 3));
983        let umap = Umap::new().with_spread(0.0);
984        assert!(umap.fit(&x, &()).is_err());
985    }
986
987    #[test]
988    fn test_umap_invalid_learning_rate() {
989        let x = Array2::<f64>::zeros((10, 3));
990        let umap = Umap::new().with_learning_rate(-1.0);
991        assert!(umap.fit(&x, &()).is_err());
992    }
993
994    #[test]
995    fn test_umap_insufficient_samples() {
996        let x = Array2::<f64>::zeros((1, 3));
997        let umap = Umap::new();
998        assert!(umap.fit(&x, &()).is_err());
999    }
1000
1001    #[test]
1002    fn test_umap_getters() {
1003        let umap = Umap::new()
1004            .with_n_components(3)
1005            .with_n_neighbors(10)
1006            .with_min_dist(0.2)
1007            .with_spread(1.5)
1008            .with_learning_rate(0.5)
1009            .with_n_epochs(100)
1010            .with_metric(UmapMetric::Manhattan)
1011            .with_negative_sample_rate(3)
1012            .with_random_state(99);
1013        assert_eq!(umap.n_components(), 3);
1014        assert_eq!(umap.n_neighbors(), 10);
1015        assert!((umap.min_dist() - 0.2).abs() < 1e-10);
1016        assert!((umap.spread() - 1.5).abs() < 1e-10);
1017        assert!((umap.learning_rate() - 0.5).abs() < 1e-10);
1018        assert_eq!(umap.n_epochs(), 100);
1019        assert_eq!(umap.metric(), UmapMetric::Manhattan);
1020        assert_eq!(umap.negative_sample_rate(), 3);
1021        assert_eq!(umap.random_state(), Some(99));
1022    }
1023
1024    #[test]
1025    fn test_umap_default() {
1026        let umap = Umap::default();
1027        assert_eq!(umap.n_components(), 2);
1028        assert_eq!(umap.n_neighbors(), 15);
1029    }
1030
1031    #[test]
1032    fn test_umap_cosine_metric() {
1033        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j + 1) as f64);
1034        let umap = Umap::new()
1035            .with_metric(UmapMetric::Cosine)
1036            .with_n_epochs(10)
1037            .with_random_state(42);
1038        let fitted = umap.fit(&x, &()).unwrap();
1039        assert_eq!(fitted.embedding().dim(), (20, 2));
1040    }
1041
1042    #[test]
1043    fn test_umap_small_n_neighbors_capped() {
1044        // n_neighbors > n-1 should be automatically capped
1045        let x = Array2::<f64>::from_shape_fn((5, 3), |(i, j)| (i + j) as f64);
1046        let umap = Umap::new()
1047            .with_n_neighbors(100)
1048            .with_n_epochs(10)
1049            .with_random_state(42);
1050        let fitted = umap.fit(&x, &()).unwrap();
1051        assert_eq!(fitted.embedding().dim(), (5, 2));
1052    }
1053
1054    #[test]
1055    fn test_umap_fitted_accessors() {
1056        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j) as f64);
1057        let umap = Umap::new().with_n_epochs(10).with_random_state(42);
1058        let fitted = umap.fit(&x, &()).unwrap();
1059        assert!(fitted.a() > 0.0);
1060        assert!(fitted.b() > 0.0);
1061    }
1062}