Skip to main content

ferrolearn_decomp/
umap.rs

1//! Uniform Manifold Approximation and Projection (UMAP).
2//!
3//! [`Umap`] performs non-linear dimensionality reduction based on the
4//! mathematical framework of Riemannian geometry and algebraic topology,
5//! as described by McInnes, Healy, and Melville (2018).
6//!
7//! # Algorithm
8//!
9//! 1. Build a k-nearest-neighbor graph with `n_neighbors` neighbors.
10//! 2. Compute the fuzzy simplicial set by smoothing kNN distances: for each
11//!    point find a local connectivity parameter `rho` (distance to nearest
12//!    neighbor) and a bandwidth `sigma` such that the sum of the membership
13//!    strengths equals `log2(n_neighbors)`.
14//! 3. Symmetrise the fuzzy graph: `w_ij = w_i|j + w_j|i - w_i|j * w_j|i`.
15//! 4. Determine curve parameters `a` and `b` from `min_dist` and `spread`
16//!    that define the target distribution in the embedding space:
17//!    `phi(d) = 1 / (1 + a * d^(2b))`.
18//! 5. Initialise the embedding with uniform random coordinates in `[-10, 10)`.
19//! 6. Optimise via SGD with attractive forces on positive edges and
20//!    repulsive forces via negative sampling (5 negatives per positive).
21//!
22//! # Examples
23//!
24//! ```
25//! use ferrolearn_decomp::Umap;
26//! use ferrolearn_core::traits::{Fit, Transform};
27//! use ndarray::Array2;
28//!
29//! let x = Array2::<f64>::from_shape_fn((30, 5), |(i, j)| (i + j) as f64);
30//! let umap = Umap::new().with_random_state(42).with_n_epochs(50);
31//! let fitted = umap.fit(&x, &()).unwrap();
32//! let emb = fitted.embedding();
33//! assert_eq!(emb.ncols(), 2);
34//! ```
35
36use ferrolearn_core::error::FerroError;
37use ferrolearn_core::traits::{Fit, Transform};
38use ndarray::Array2;
39use rand::SeedableRng;
40use rand_distr::{Distribution, Uniform};
41use rand_xoshiro::Xoshiro256PlusPlus;
42
43// ---------------------------------------------------------------------------
44// Metric enum
45// ---------------------------------------------------------------------------
46
/// Dissimilarity measure used when comparing rows of the input data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UmapMetric {
    /// Straight-line (L2) distance.
    Euclidean,
    /// City-block (L1) distance: the sum of absolute coordinate differences.
    Manhattan,
    /// Cosine distance, i.e. `1 - cosine similarity`.
    Cosine,
}
57
58// ---------------------------------------------------------------------------
59// Umap (unfitted)
60// ---------------------------------------------------------------------------
61
62/// UMAP configuration.
63///
64/// Holds hyperparameters for the UMAP algorithm. Calling [`Fit::fit`]
65/// computes the embedding and returns a [`FittedUmap`].
66#[derive(Debug, Clone)]
67pub struct Umap {
68    /// Number of embedding dimensions (default 2).
69    n_components: usize,
70    /// Number of nearest neighbors for the kNN graph (default 15).
71    n_neighbors: usize,
72    /// Minimum distance in the embedding space (default 0.1).
73    min_dist: f64,
74    /// Spread of the embedding (default 1.0).
75    spread: f64,
76    /// Learning rate for SGD (default 1.0).
77    learning_rate: f64,
78    /// Number of SGD epochs (default 200).
79    n_epochs: usize,
80    /// Distance metric (default Euclidean).
81    metric: UmapMetric,
82    /// Number of negative samples per positive edge (default 5).
83    negative_sample_rate: usize,
84    /// Optional random seed for reproducibility.
85    random_state: Option<u64>,
86}
87
88impl Umap {
89    /// Create a new `Umap` with default parameters.
90    ///
91    /// Defaults: `n_components=2`, `n_neighbors=15`, `min_dist=0.1`,
92    /// `spread=1.0`, `learning_rate=1.0`, `n_epochs=200`, metric=`Euclidean`,
93    /// `negative_sample_rate=5`.
94    #[must_use]
95    pub fn new() -> Self {
96        Self {
97            n_components: 2,
98            n_neighbors: 15,
99            min_dist: 0.1,
100            spread: 1.0,
101            learning_rate: 1.0,
102            n_epochs: 200,
103            metric: UmapMetric::Euclidean,
104            negative_sample_rate: 5,
105            random_state: None,
106        }
107    }
108
109    /// Set the number of embedding dimensions.
110    #[must_use]
111    pub fn with_n_components(mut self, n: usize) -> Self {
112        self.n_components = n;
113        self
114    }
115
116    /// Set the number of nearest neighbors.
117    #[must_use]
118    pub fn with_n_neighbors(mut self, k: usize) -> Self {
119        self.n_neighbors = k;
120        self
121    }
122
123    /// Set the minimum distance in the embedding.
124    #[must_use]
125    pub fn with_min_dist(mut self, d: f64) -> Self {
126        self.min_dist = d;
127        self
128    }
129
130    /// Set the spread.
131    #[must_use]
132    pub fn with_spread(mut self, s: f64) -> Self {
133        self.spread = s;
134        self
135    }
136
137    /// Set the learning rate.
138    #[must_use]
139    pub fn with_learning_rate(mut self, lr: f64) -> Self {
140        self.learning_rate = lr;
141        self
142    }
143
144    /// Set the number of SGD epochs.
145    #[must_use]
146    pub fn with_n_epochs(mut self, n: usize) -> Self {
147        self.n_epochs = n;
148        self
149    }
150
151    /// Set the distance metric.
152    #[must_use]
153    pub fn with_metric(mut self, m: UmapMetric) -> Self {
154        self.metric = m;
155        self
156    }
157
158    /// Set the negative sample rate.
159    #[must_use]
160    pub fn with_negative_sample_rate(mut self, rate: usize) -> Self {
161        self.negative_sample_rate = rate;
162        self
163    }
164
165    /// Set the random seed.
166    #[must_use]
167    pub fn with_random_state(mut self, seed: u64) -> Self {
168        self.random_state = Some(seed);
169        self
170    }
171
172    /// Return the configured number of components.
173    #[must_use]
174    pub fn n_components(&self) -> usize {
175        self.n_components
176    }
177
178    /// Return the configured number of neighbors.
179    #[must_use]
180    pub fn n_neighbors(&self) -> usize {
181        self.n_neighbors
182    }
183
184    /// Return the configured minimum distance.
185    #[must_use]
186    pub fn min_dist(&self) -> f64 {
187        self.min_dist
188    }
189
190    /// Return the configured spread.
191    #[must_use]
192    pub fn spread(&self) -> f64 {
193        self.spread
194    }
195
196    /// Return the configured learning rate.
197    #[must_use]
198    pub fn learning_rate(&self) -> f64 {
199        self.learning_rate
200    }
201
202    /// Return the configured number of epochs.
203    #[must_use]
204    pub fn n_epochs(&self) -> usize {
205        self.n_epochs
206    }
207
208    /// Return the configured metric.
209    #[must_use]
210    pub fn metric(&self) -> UmapMetric {
211        self.metric
212    }
213
214    /// Return the configured negative sample rate.
215    #[must_use]
216    pub fn negative_sample_rate(&self) -> usize {
217        self.negative_sample_rate
218    }
219
220    /// Return the configured random state, if any.
221    #[must_use]
222    pub fn random_state(&self) -> Option<u64> {
223        self.random_state
224    }
225}
226
227impl Default for Umap {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233// ---------------------------------------------------------------------------
234// FittedUmap
235// ---------------------------------------------------------------------------
236
237/// A fitted UMAP model holding the learned embedding and training data.
238///
239/// Created by calling [`Fit::fit`] on a [`Umap`]. Implements
240/// [`Transform<Array2<f64>>`] for projecting new data via nearest-neighbor
241/// lookup.
242#[derive(Debug, Clone)]
243pub struct FittedUmap {
244    /// The embedding, shape `(n_samples, n_components)`.
245    embedding_: Array2<f64>,
246    /// Training data, stored for out-of-sample extension.
247    x_train_: Array2<f64>,
248    /// Curve parameter `a`.
249    a_: f64,
250    /// Curve parameter `b`.
251    b_: f64,
252    /// Number of neighbors used.
253    n_neighbors_: usize,
254    /// The metric used.
255    metric_: UmapMetric,
256}
257
258impl FittedUmap {
259    /// The embedding coordinates, shape `(n_samples, n_components)`.
260    #[must_use]
261    pub fn embedding(&self) -> &Array2<f64> {
262        &self.embedding_
263    }
264
265    /// The curve parameter `a`.
266    #[must_use]
267    pub fn a(&self) -> f64 {
268        self.a_
269    }
270
271    /// The curve parameter `b`.
272    #[must_use]
273    pub fn b(&self) -> f64 {
274        self.b_
275    }
276}
277
278// ---------------------------------------------------------------------------
279// Internal helpers
280// ---------------------------------------------------------------------------
281
282/// Compute distance between two rows of a matrix using the given metric.
283fn compute_distance(x: &Array2<f64>, i: usize, j: usize, metric: UmapMetric) -> f64 {
284    let ncols = x.ncols();
285    match metric {
286        UmapMetric::Euclidean => {
287            let mut sq = 0.0;
288            for k in 0..ncols {
289                let diff = x[[i, k]] - x[[j, k]];
290                sq += diff * diff;
291            }
292            sq.sqrt()
293        }
294        UmapMetric::Manhattan => {
295            let mut sum = 0.0;
296            for k in 0..ncols {
297                sum += (x[[i, k]] - x[[j, k]]).abs();
298            }
299            sum
300        }
301        UmapMetric::Cosine => {
302            let mut dot = 0.0;
303            let mut norm_i = 0.0;
304            let mut norm_j = 0.0;
305            for k in 0..ncols {
306                dot += x[[i, k]] * x[[j, k]];
307                norm_i += x[[i, k]] * x[[i, k]];
308                norm_j += x[[j, k]] * x[[j, k]];
309            }
310            let denom = (norm_i * norm_j).sqrt();
311            if denom < 1e-16 {
312                1.0
313            } else {
314                1.0 - dot / denom
315            }
316        }
317    }
318}
319
320/// Compute distance between a point (row of x_new) and a training point.
321fn compute_distance_cross(
322    x_new: &Array2<f64>,
323    i: usize,
324    x_train: &Array2<f64>,
325    j: usize,
326    metric: UmapMetric,
327) -> f64 {
328    let ncols = x_new.ncols();
329    match metric {
330        UmapMetric::Euclidean => {
331            let mut sq = 0.0;
332            for k in 0..ncols {
333                let diff = x_new[[i, k]] - x_train[[j, k]];
334                sq += diff * diff;
335            }
336            sq.sqrt()
337        }
338        UmapMetric::Manhattan => {
339            let mut sum = 0.0;
340            for k in 0..ncols {
341                sum += (x_new[[i, k]] - x_train[[j, k]]).abs();
342            }
343            sum
344        }
345        UmapMetric::Cosine => {
346            let mut dot = 0.0;
347            let mut norm_i = 0.0;
348            let mut norm_j = 0.0;
349            for k in 0..ncols {
350                dot += x_new[[i, k]] * x_train[[j, k]];
351                norm_i += x_new[[i, k]] * x_new[[i, k]];
352                norm_j += x_train[[j, k]] * x_train[[j, k]];
353            }
354            let denom = (norm_i * norm_j).sqrt();
355            if denom < 1e-16 {
356                1.0
357            } else {
358                1.0 - dot / denom
359            }
360        }
361    }
362}
363
364/// Build k-nearest-neighbor graph. Returns for each point the sorted list of
365/// (neighbor_index, distance) pairs.
366fn build_knn(x: &Array2<f64>, k: usize, metric: UmapMetric) -> Vec<Vec<(usize, f64)>> {
367    let n = x.nrows();
368    let mut knn = Vec::with_capacity(n);
369    for i in 0..n {
370        let mut dists: Vec<(usize, f64)> = (0..n)
371            .filter(|&j| j != i)
372            .map(|j| (j, compute_distance(x, i, j, metric)))
373            .collect();
374        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
375        dists.truncate(k);
376        knn.push(dists);
377    }
378    knn
379}
380
/// Smooth one point's kNN distances.
///
/// Returns `(rho, sigma)` where `rho` is the distance to the nearest neighbor
/// (or the first distance above `1e-16` when the nearest one is zero) and
/// `sigma` is found by bisection so that
/// `sum_j exp(-max(d_j - rho, 0) / sigma)` approaches `target`.
/// An empty neighbor list yields the neutral pair `(0.0, 1.0)`.
fn smooth_knn_row(neighbors: &[(usize, f64)], target: f64) -> (f64, f64) {
    let nearest = match neighbors.first() {
        Some(&(_, d)) => d,
        None => return (0.0, 1.0),
    };
    let rho = if nearest < 1e-16 {
        // Duplicated points: fall back to the first strictly positive distance.
        neighbors
            .iter()
            .map(|&(_, d)| d)
            .find(|&d| d > 1e-16)
            .unwrap_or(nearest)
    } else {
        nearest
    };

    // The membership sum grows monotonically with sigma, so bisect on it.
    let mut lo = 1e-20_f64;
    let mut hi = 1e4_f64;
    for _ in 0..64 {
        let mid = (lo + hi) / 2.0;
        let val = neighbors.iter().fold(0.0_f64, |acc, &(_, d)| {
            acc + (-(d - rho).max(0.0) / mid).exp()
        });
        if val > target {
            hi = mid;
        } else {
            lo = mid;
        }
        if (hi - lo) / (lo + 1e-16) < 1e-5 {
            break;
        }
    }
    (rho, (lo + hi) / 2.0)
}

/// Compute the fuzzy simplicial set: smooth kNN distances to get membership
/// strengths, then symmetrise them.
///
/// For each point `i`, `rho_i` (distance to the nearest neighbor) and
/// `sigma_i` are chosen so that
/// `sum_j exp(-(d(i,j) - rho_i) / sigma_i) = log2(k)`, where `k` is the
/// length of the first neighbor list. The directed strengths
/// `w_{i|j} = exp(-max(d(i,j) - rho_i, 0) / sigma_i)` are combined into
/// undirected weights `w_ij = w_{i|j} + w_{j|i} - w_{i|j} * w_{j|i}`.
///
/// `knn` holds one neighbor list per point and `n` is the number of points
/// (expected to equal `knn.len()`).
///
/// Returns `(i, j, weight)` edges with `i < j` and `weight > 1e-16`, sorted
/// by `(i, j)`. The order is deterministic on purpose: the SGD loop walks the
/// edges in this order while drawing random negative samples, so a
/// hash-ordered edge list would make a seeded fit irreproducible.
fn compute_fuzzy_simplicial_set(knn: &[Vec<(usize, f64)>], n: usize) -> Vec<(usize, usize, f64)> {
    debug_assert_eq!(knn.len(), n, "one neighbor list is expected per point");

    let k = knn.first().map_or(0, Vec::len);
    let target = (k as f64).ln() / std::f64::consts::LN_2; // log2(k)

    // (min, max) index pair -> (strength stored by the lower index,
    //                           strength stored by the higher index).
    let mut memberships: std::collections::BTreeMap<(usize, usize), (f64, f64)> =
        std::collections::BTreeMap::new();

    for (i, neighbors) in knn.iter().enumerate() {
        let (rho, sigma) = smooth_knn_row(neighbors, target);
        for &(j, d) in neighbors {
            let adjusted = (d - rho).max(0.0);
            let w = (-adjusted / sigma).exp();
            if i < j {
                memberships.entry((i, j)).or_insert((0.0, 0.0)).0 = w;
            } else {
                memberships.entry((j, i)).or_insert((0.0, 0.0)).1 = w;
            }
        }
    }

    // Probabilistic t-conorm (fuzzy union) of the two directed strengths.
    memberships
        .into_iter()
        .filter_map(|((i, j), (w_fwd, w_bwd))| {
            let w = w_fwd + w_bwd - w_fwd * w_bwd;
            (w > 1e-16).then(|| (i, j, w))
        })
        .collect()
}
485
/// Fit the curve parameters `a` and `b` for a given `min_dist` and `spread`.
///
/// The embedding kernel `1 / (1 + a * d^(2b))` is matched, in the
/// least-squares sense, against a target that is `1` for `d <= min_dist` and
/// `exp(-(d - min_dist) / spread)` beyond it. The fit uses 300 distances
/// spread over `(0, 3 * spread)` and a two-stage grid search: a coarse grid
/// (`a` in `0.25..=10.0`, `b` in `0.1..=3.0`) followed by a 20x20 refinement
/// around the coarse optimum.
fn find_ab_params(min_dist: f64, spread: f64) -> (f64, f64) {
    let n_samples = 300;
    let d_max = 3.0 * spread;

    // Sample distances (bin midpoints) paired with their target values.
    let samples: Vec<(f64, f64)> = (0..n_samples)
        .map(|k| {
            let d = (k as f64 + 0.5) / n_samples as f64 * d_max;
            let target = if d <= min_dist {
                1.0
            } else {
                (-(d - min_dist) / spread).exp()
            };
            (d, target)
        })
        .collect();

    // Sum of squared residuals of the kernel for one candidate `(a, b)`.
    let loss = |a: f64, b: f64| -> f64 {
        samples.iter().fold(0.0_f64, |err, &(d, target)| {
            let diff = 1.0 / (1.0 + a * d.powf(2.0 * b)) - target;
            err + diff * diff
        })
    };

    let mut best = (1.0_f64, 1.0_f64);
    let mut best_err = f64::MAX;

    // Stage 1: coarse grid.
    for ia in 1..=40 {
        let a = ia as f64 * 0.25;
        for ib in 1..=30 {
            let b = ib as f64 * 0.1;
            let err = loss(a, b);
            if err < best_err {
                best_err = err;
                best = (a, b);
            }
        }
    }

    // Stage 2: finer grid centred on the coarse winner.
    let (coarse_a, coarse_b) = best;
    let a_lo = (coarse_a - 0.3).max(0.01);
    let a_hi = coarse_a + 0.3;
    let b_lo = (coarse_b - 0.15).max(0.01);
    let b_hi = coarse_b + 0.15;

    for ia in 0..20 {
        let a = a_lo + (a_hi - a_lo) * ia as f64 / 19.0;
        for ib in 0..20 {
            let b = b_lo + (b_hi - b_lo) * ib as f64 / 19.0;
            let err = loss(a, b);
            if err < best_err {
                best_err = err;
                best = (a, b);
            }
        }
    }

    best
}
558
/// Restrict `val` to the closed interval `[lo, hi]`.
///
/// Used to bound individual gradient components. A NaN `val` is returned
/// unchanged because neither comparison holds for it.
fn clip(val: f64, lo: f64, hi: f64) -> f64 {
    match val {
        v if v < lo => lo,
        v if v > hi => hi,
        v => v,
    }
}
569
570// ---------------------------------------------------------------------------
571// Trait implementations
572// ---------------------------------------------------------------------------
573
574impl Fit<Array2<f64>, ()> for Umap {
575    type Fitted = FittedUmap;
576    type Error = FerroError;
577
578    /// Fit UMAP by computing the fuzzy simplicial set and optimising the
579    /// low-dimensional embedding via SGD.
580    ///
581    /// # Errors
582    ///
583    /// - [`FerroError::InvalidParameter`] if `n_components` is zero,
584    ///   `n_neighbors` is zero or too large, `min_dist` is negative,
585    ///   `spread` is non-positive, or `learning_rate` is non-positive.
586    /// - [`FerroError::InsufficientSamples`] if there are fewer samples than
587    ///   `n_neighbors + 1`.
588    fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedUmap, FerroError> {
589        let n = x.nrows();
590
591        // Validate parameters.
592        if self.n_components == 0 {
593            return Err(FerroError::InvalidParameter {
594                name: "n_components".into(),
595                reason: "must be at least 1".into(),
596            });
597        }
598        if self.n_neighbors == 0 {
599            return Err(FerroError::InvalidParameter {
600                name: "n_neighbors".into(),
601                reason: "must be at least 1".into(),
602            });
603        }
604        if n < 2 {
605            return Err(FerroError::InsufficientSamples {
606                required: 2,
607                actual: n,
608                context: "Umap::fit requires at least 2 samples".into(),
609            });
610        }
611        let effective_k = self.n_neighbors.min(n - 1);
612        if self.min_dist < 0.0 {
613            return Err(FerroError::InvalidParameter {
614                name: "min_dist".into(),
615                reason: "must be non-negative".into(),
616            });
617        }
618        if self.spread <= 0.0 {
619            return Err(FerroError::InvalidParameter {
620                name: "spread".into(),
621                reason: "must be positive".into(),
622            });
623        }
624        if self.learning_rate <= 0.0 {
625            return Err(FerroError::InvalidParameter {
626                name: "learning_rate".into(),
627                reason: "must be positive".into(),
628            });
629        }
630
631        let dim = self.n_components;
632        let seed = self.random_state.unwrap_or(0);
633
634        // Step 1: Build kNN graph.
635        let knn = build_knn(x, effective_k, self.metric);
636
637        // Step 2: Compute fuzzy simplicial set.
638        let edges = compute_fuzzy_simplicial_set(&knn, n);
639
640        // Step 3: Find a, b parameters.
641        let (a, b) = find_ab_params(self.min_dist, self.spread);
642
643        // Step 4: Initialise embedding (random uniform).
644        let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
645        let uniform = Uniform::new(-10.0, 10.0).unwrap();
646        let mut y = Array2::<f64>::zeros((n, dim));
647        for elem in y.iter_mut() {
648            *elem = uniform.sample(&mut rng);
649        }
650
651        // Pre-compute epochs per edge: spread epochs proportional to weight.
652        if edges.is_empty() {
653            return Ok(FittedUmap {
654                embedding_: y,
655                x_train_: x.to_owned(),
656                a_: a,
657                b_: b,
658                n_neighbors_: effective_k,
659                metric_: self.metric,
660            });
661        }
662
663        let max_weight = edges
664            .iter()
665            .map(|e| e.2)
666            .fold(0.0_f64, |a_val, b_val| a_val.max(b_val));
667
668        // Each edge gets `n_epochs * (weight / max_weight)` total updates.
669        let epochs_per_sample: Vec<f64> = edges
670            .iter()
671            .map(|e| {
672                let ratio = e.2 / max_weight;
673                if ratio > 0.0 {
674                    (self.n_epochs as f64) / ((self.n_epochs as f64) * ratio).max(1.0)
675                } else {
676                    f64::MAX
677                }
678            })
679            .collect();
680
681        let mut epoch_of_next_sample: Vec<f64> = epochs_per_sample.clone();
682
683        let neg_rate = self.negative_sample_rate;
684        let idx_uniform = Uniform::new(0usize, n).unwrap();
685
686        // Step 5: SGD optimisation.
687        for epoch in 0..self.n_epochs {
688            let alpha = self.learning_rate * (1.0 - epoch as f64 / self.n_epochs as f64);
689            let alpha = alpha.max(0.0);
690
691            for (edge_idx, &(ei, ej, _weight)) in edges.iter().enumerate() {
692                if epoch_of_next_sample[edge_idx] > epoch as f64 {
693                    continue;
694                }
695
696                // Attractive force.
697                let mut dist_sq = 0.0;
698                for d in 0..dim {
699                    let diff = y[[ei, d]] - y[[ej, d]];
700                    dist_sq += diff * diff;
701                }
702                let dist_sq = dist_sq.max(1e-16);
703
704                let grad_coeff = -2.0 * a * b * dist_sq.powf(b - 1.0) / (1.0 + a * dist_sq.powf(b));
705
706                for d in 0..dim {
707                    let diff = y[[ei, d]] - y[[ej, d]];
708                    let grad = clip(grad_coeff * diff, -4.0, 4.0);
709                    y[[ei, d]] += alpha * grad;
710                    y[[ej, d]] -= alpha * grad;
711                }
712
713                // Negative sampling.
714                for _ in 0..neg_rate {
715                    let neg = idx_uniform.sample(&mut rng);
716                    if neg == ei {
717                        continue;
718                    }
719                    let mut neg_dist_sq = 0.0;
720                    for d in 0..dim {
721                        let diff = y[[ei, d]] - y[[neg, d]];
722                        neg_dist_sq += diff * diff;
723                    }
724                    let neg_dist_sq = neg_dist_sq.max(1e-16);
725
726                    let rep_coeff =
727                        2.0 * b / ((0.001 + neg_dist_sq) * (1.0 + a * neg_dist_sq.powf(b)));
728
729                    for d in 0..dim {
730                        let diff = y[[ei, d]] - y[[neg, d]];
731                        let grad = clip(rep_coeff * diff, -4.0, 4.0);
732                        y[[ei, d]] += alpha * grad;
733                    }
734                }
735
736                epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];
737            }
738        }
739
740        Ok(FittedUmap {
741            embedding_: y,
742            x_train_: x.to_owned(),
743            a_: a,
744            b_: b,
745            n_neighbors_: effective_k,
746            metric_: self.metric,
747        })
748    }
749}
750
751impl Transform<Array2<f64>> for FittedUmap {
752    type Output = Array2<f64>;
753    type Error = FerroError;
754
755    /// Project new data into the UMAP embedding space.
756    ///
757    /// For each new point, find the nearest neighbors in the training data
758    /// and compute a weighted average of their embeddings (weighted by the
759    /// UMAP kernel).
760    ///
761    /// # Errors
762    ///
763    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
764    /// not match the training data.
765    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
766        let n_features = self.x_train_.ncols();
767        if x.ncols() != n_features {
768            return Err(FerroError::ShapeMismatch {
769                expected: vec![x.nrows(), n_features],
770                actual: vec![x.nrows(), x.ncols()],
771                context: "FittedUmap::transform".into(),
772            });
773        }
774
775        let n_test = x.nrows();
776        let n_train = self.x_train_.nrows();
777        let dim = self.embedding_.ncols();
778        let k = self.n_neighbors_.min(n_train);
779
780        let mut result = Array2::<f64>::zeros((n_test, dim));
781
782        for t in 0..n_test {
783            // Find k nearest training neighbors.
784            let mut dists: Vec<(usize, f64)> = (0..n_train)
785                .map(|j| {
786                    (
787                        j,
788                        compute_distance_cross(x, t, &self.x_train_, j, self.metric_),
789                    )
790                })
791                .collect();
792            dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
793            dists.truncate(k);
794
795            // Compute weights using the UMAP kernel: 1/(1 + a * d^(2b)).
796            let mut weights = Vec::with_capacity(k);
797            let mut weight_sum = 0.0;
798            for &(_, d) in &dists {
799                let w = 1.0 / (1.0 + self.a_ * d.powf(2.0 * self.b_));
800                weights.push(w);
801                weight_sum += w;
802            }
803
804            if weight_sum < 1e-16 {
805                // Fallback: uniform weights.
806                weight_sum = k as f64;
807                weights = vec![1.0; k];
808            }
809
810            // Weighted average of neighbor embeddings.
811            for (idx, &(train_idx, _)) in dists.iter().enumerate() {
812                let w = weights[idx] / weight_sum;
813                for d in 0..dim {
814                    result[[t, d]] += w * self.embedding_[[train_idx, d]];
815                }
816            }
817        }
818
819        Ok(result)
820    }
821}
822
823// ---------------------------------------------------------------------------
824// Tests
825// ---------------------------------------------------------------------------
826
827#[cfg(test)]
828mod tests {
829    use super::*;
830    use ndarray::Array2;
831    use rand::SeedableRng;
832    use rand_distr::{Distribution, Normal};
833    use rand_xoshiro::Xoshiro256PlusPlus;
834
835    /// Generate small blobs dataset.
836    fn make_blobs(seed: u64) -> (Array2<f64>, Vec<usize>) {
837        let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
838        let normal = Normal::new(0.0, 0.3).unwrap();
839        let n_per_cluster = 10;
840        let n_features = 5;
841        let centers = vec![
842            vec![0.0, 0.0, 0.0, 0.0, 0.0],
843            vec![5.0, 5.0, 5.0, 5.0, 5.0],
844            vec![10.0, 0.0, 10.0, 0.0, 10.0],
845        ];
846        let n = centers.len() * n_per_cluster;
847        let mut x = Array2::<f64>::zeros((n, n_features));
848        let mut labels = Vec::with_capacity(n);
849        for (c_idx, center) in centers.iter().enumerate() {
850            for i in 0..n_per_cluster {
851                let row = c_idx * n_per_cluster + i;
852                for (f, &c) in center.iter().enumerate() {
853                    x[[row, f]] = c + normal.sample(&mut rng);
854                }
855                labels.push(c_idx);
856            }
857        }
858        (x, labels)
859    }
860
861    #[test]
862    fn test_umap_basic_shape() {
863        let x = Array2::<f64>::from_shape_fn((30, 5), |(i, j)| (i + j) as f64);
864        let umap = Umap::new().with_n_epochs(10).with_random_state(42);
865        let fitted = umap.fit(&x, &()).unwrap();
866        assert_eq!(fitted.embedding().dim(), (30, 2));
867    }
868
869    #[test]
870    fn test_umap_3d_embedding() {
871        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j) as f64);
872        let umap = Umap::new()
873            .with_n_components(3)
874            .with_n_epochs(10)
875            .with_random_state(42);
876        let fitted = umap.fit(&x, &()).unwrap();
877        assert_eq!(fitted.embedding().ncols(), 3);
878    }
879
880    #[test]
881    fn test_umap_separates_clusters() {
882        let (x, labels) = make_blobs(42);
883        let umap = Umap::new()
884            .with_n_neighbors(5)
885            .with_n_epochs(100)
886            .with_random_state(42);
887        let fitted = umap.fit(&x, &()).unwrap();
888        let emb = fitted.embedding();
889
890        // Check cluster separation with k-NN accuracy (k=3).
891        let n = emb.nrows();
892        let mut correct = 0;
893        for i in 0..n {
894            let mut dists: Vec<(f64, usize)> = (0..n)
895                .filter(|&j| j != i)
896                .map(|j| {
897                    let mut d = 0.0;
898                    for dd in 0..emb.ncols() {
899                        let diff = emb[[i, dd]] - emb[[j, dd]];
900                        d += diff * diff;
901                    }
902                    (d, labels[j])
903                })
904                .collect();
905            dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
906            let mut votes = [0usize; 3];
907            for &(_, lbl) in dists.iter().take(3) {
908                votes[lbl] += 1;
909            }
910            let pred = votes.iter().enumerate().max_by_key(|&(_, v)| v).unwrap().0;
911            if pred == labels[i] {
912                correct += 1;
913            }
914        }
915        let accuracy = correct as f64 / n as f64;
916        assert!(
917            accuracy > 0.8,
918            "UMAP k-NN accuracy should be > 80%, got {:.1}%",
919            accuracy * 100.0
920        );
921    }
922
923    #[test]
924    fn test_umap_transform_new_data() {
925        let (x, _) = make_blobs(42);
926        let umap = Umap::new()
927            .with_n_neighbors(5)
928            .with_n_epochs(50)
929            .with_random_state(42);
930        let fitted = umap.fit(&x, &()).unwrap();
931
932        // Transform a subset of training data.
933        let x_test = x.slice(ndarray::s![0..5, ..]).to_owned();
934        let projected = fitted.transform(&x_test).unwrap();
935        assert_eq!(projected.dim(), (5, 2));
936    }
937
938    #[test]
939    fn test_umap_transform_shape_mismatch() {
940        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j) as f64);
941        let umap = Umap::new().with_n_epochs(10).with_random_state(42);
942        let fitted = umap.fit(&x, &()).unwrap();
943        let x_bad = Array2::<f64>::zeros((5, 3)); // wrong number of features
944        assert!(fitted.transform(&x_bad).is_err());
945    }
946
947    #[test]
948    fn test_umap_ab_params_reasonable() {
949        let (a, b) = find_ab_params(0.1, 1.0);
950        // a and b should be positive.
951        assert!(a > 0.0, "a should be positive, got {a}");
952        assert!(b > 0.0, "b should be positive, got {b}");
953        // At d=0, 1/(1+a*0) = 1, which is correct.
954        // At d=min_dist, should be close to 1.
955        let val_at_min = 1.0 / (1.0 + a * (0.1_f64).powf(2.0 * b));
956        assert!(
957            val_at_min > 0.5,
958            "kernel at min_dist should be > 0.5, got {val_at_min}"
959        );
960    }
961
962    #[test]
963    fn test_umap_invalid_n_components_zero() {
964        let x = Array2::<f64>::zeros((10, 3));
965        let umap = Umap::new().with_n_components(0);
966        assert!(umap.fit(&x, &()).is_err());
967    }
968
969    #[test]
970    fn test_umap_invalid_n_neighbors_zero() {
971        let x = Array2::<f64>::zeros((10, 3));
972        let umap = Umap::new().with_n_neighbors(0);
973        assert!(umap.fit(&x, &()).is_err());
974    }
975
976    #[test]
977    fn test_umap_invalid_min_dist() {
978        let x = Array2::<f64>::zeros((10, 3));
979        let umap = Umap::new().with_min_dist(-0.1);
980        assert!(umap.fit(&x, &()).is_err());
981    }
982
983    #[test]
984    fn test_umap_invalid_spread() {
985        let x = Array2::<f64>::zeros((10, 3));
986        let umap = Umap::new().with_spread(0.0);
987        assert!(umap.fit(&x, &()).is_err());
988    }
989
990    #[test]
991    fn test_umap_invalid_learning_rate() {
992        let x = Array2::<f64>::zeros((10, 3));
993        let umap = Umap::new().with_learning_rate(-1.0);
994        assert!(umap.fit(&x, &()).is_err());
995    }
996
997    #[test]
998    fn test_umap_insufficient_samples() {
999        let x = Array2::<f64>::zeros((1, 3));
1000        let umap = Umap::new();
1001        assert!(umap.fit(&x, &()).is_err());
1002    }
1003
1004    #[test]
1005    fn test_umap_getters() {
1006        let umap = Umap::new()
1007            .with_n_components(3)
1008            .with_n_neighbors(10)
1009            .with_min_dist(0.2)
1010            .with_spread(1.5)
1011            .with_learning_rate(0.5)
1012            .with_n_epochs(100)
1013            .with_metric(UmapMetric::Manhattan)
1014            .with_negative_sample_rate(3)
1015            .with_random_state(99);
1016        assert_eq!(umap.n_components(), 3);
1017        assert_eq!(umap.n_neighbors(), 10);
1018        assert!((umap.min_dist() - 0.2).abs() < 1e-10);
1019        assert!((umap.spread() - 1.5).abs() < 1e-10);
1020        assert!((umap.learning_rate() - 0.5).abs() < 1e-10);
1021        assert_eq!(umap.n_epochs(), 100);
1022        assert_eq!(umap.metric(), UmapMetric::Manhattan);
1023        assert_eq!(umap.negative_sample_rate(), 3);
1024        assert_eq!(umap.random_state(), Some(99));
1025    }
1026
1027    #[test]
1028    fn test_umap_default() {
1029        let umap = Umap::default();
1030        assert_eq!(umap.n_components(), 2);
1031        assert_eq!(umap.n_neighbors(), 15);
1032    }
1033
1034    #[test]
1035    fn test_umap_cosine_metric() {
1036        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j + 1) as f64);
1037        let umap = Umap::new()
1038            .with_metric(UmapMetric::Cosine)
1039            .with_n_epochs(10)
1040            .with_random_state(42);
1041        let fitted = umap.fit(&x, &()).unwrap();
1042        assert_eq!(fitted.embedding().dim(), (20, 2));
1043    }
1044
1045    #[test]
1046    fn test_umap_small_n_neighbors_capped() {
1047        // n_neighbors > n-1 should be automatically capped
1048        let x = Array2::<f64>::from_shape_fn((5, 3), |(i, j)| (i + j) as f64);
1049        let umap = Umap::new()
1050            .with_n_neighbors(100)
1051            .with_n_epochs(10)
1052            .with_random_state(42);
1053        let fitted = umap.fit(&x, &()).unwrap();
1054        assert_eq!(fitted.embedding().dim(), (5, 2));
1055    }
1056
1057    #[test]
1058    fn test_umap_fitted_accessors() {
1059        let x = Array2::<f64>::from_shape_fn((20, 4), |(i, j)| (i + j) as f64);
1060        let umap = Umap::new().with_n_epochs(10).with_random_state(42);
1061        let fitted = umap.fit(&x, &()).unwrap();
1062        assert!(fitted.a() > 0.0);
1063        assert!(fitted.b() > 0.0);
1064    }
1065}