// fast_umap/train/config.rs
1use serde::{Deserialize, Serialize};
2use std::fmt;
3use std::path::PathBuf;
4
5// ─── UMAP a, b curve fitting ─────────────────────────────────────────────────
6
/// Fit the UMAP kernel parameters `a` and `b` from `min_dist` and `spread`.
///
/// The low-dimensional similarity kernel is `phi(d) = 1 / (1 + a * d^(2b))`.
/// It is fitted against the piecewise target curve:
/// - `phi(d) = 1.0` for `d <= min_dist`
/// - `phi(d) = exp(-(d - min_dist) / spread)` for `d > min_dist`
///
/// The fit is a coarse grid search followed by a relative-step coordinate
/// descent. It runs once at initialisation and takes on the order of
/// milliseconds, so no fancy solver is needed.
pub fn fit_ab(min_dist: f32, spread: f32) -> (f32, f32) {
    // Sample the target curve at 300 midpoints over [0, 3 * spread].
    const SAMPLES: usize = 300;
    let x_max = 3.0 * spread;
    let curve: Vec<(f32, f32)> = (0..SAMPLES)
        .map(|i| {
            let d = (i as f32 + 0.5) / SAMPLES as f32 * x_max;
            let target = if d <= min_dist {
                1.0
            } else {
                (-(d - min_dist) / spread).exp()
            };
            (d, target)
        })
        .collect();

    // Sum of squared residuals of phi(d; a, b) against the target samples.
    let sse = |a: f32, b: f32| -> f32 {
        curve
            .iter()
            .map(|&(d, target)| {
                let q = 1.0 / (1.0 + a * d.powf(2.0 * b));
                let r = q - target;
                r * r
            })
            .sum::<f32>()
    };

    let mut a_best = 1.0f32;
    let mut b_best = 1.0f32;
    let mut err_best = f32::INFINITY;

    // Coarse grid: a in (0.08 ..= 6.4), b in (0.06 ..= 3.0).
    for (a, b) in (1..=80).flat_map(|ai| (1..=50).map(move |bi| (ai as f32 * 0.08, bi as f32 * 0.06))) {
        let err = sse(a, b);
        if err < err_best {
            err_best = err;
            a_best = a;
            b_best = b;
        }
    }

    // Refinement: coordinate descent with steps proportional to the current
    // optimum (2 % per iteration), clamped away from zero.
    for _ in 0..100 {
        let step_a = a_best * 0.02;
        let step_b = b_best * 0.02;
        for &da in &[-step_a, 0.0, step_a] {
            for &db in &[-step_b, 0.0, step_b] {
                let a = (a_best + da).max(1e-4);
                let b = (b_best + db).max(1e-4);
                let err = sse(a, b);
                if err < err_best {
                    err_best = err;
                    a_best = a;
                    b_best = b;
                }
            }
        }
    }

    (a_best, b_best)
}
81
82// ─── Metric ──────────────────────────────────────────────────────────────────
83
/// Distance metric used to build the high-dimensional k-NN graph during the
/// precomputation phase.
///
/// The choice of metric determines how "closeness" is measured in the original
/// feature space. [`Euclidean`](Metric::Euclidean) (L2) is the default and
/// works well for most continuous data.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Metric {
    /// Standard L2 (Euclidean) distance — default.
    Euclidean,
    /// Euclidean distance computed via the GPU k-NN kernel path.
    EuclideanKNN,
    /// L1 (Manhattan / taxicab) distance.
    Manhattan,
    /// Cosine dissimilarity `1 − cos(θ)`.
    Cosine,
    /// Generalised Minkowski distance of order `p`
    /// (`p = 1` → Manhattan, `p = 2` → Euclidean); the order is taken from
    /// [`GraphParams::minkowski_p`].
    Minkowski,
}
104
105impl From<&str> for Metric {
106 fn from(s: &str) -> Self {
107 match s.to_lowercase().as_str() {
108 "euclidean" => Metric::Euclidean,
109 "euclideanknn" | "euclidean_knn" => Metric::EuclideanKNN,
110 "manhattan" => Metric::Manhattan,
111 "cosine" => Metric::Cosine,
112 "minkowski" => Metric::Minkowski,
113 _ => panic!("Invalid metric type: {}", s),
114 }
115 }
116}
117
118impl fmt::Display for Metric {
119 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
120 match self {
121 Metric::Euclidean => write!(f, "Euclidean"),
122 Metric::EuclideanKNN => write!(f, "Euclidean KNN"),
123 Metric::Manhattan => write!(f, "Manhattan"),
124 Metric::Cosine => write!(f, "cosine"),
125 Metric::Minkowski => write!(f, "minkowski"),
126 }
127 }
128}
129
130// ─── LossReduction ───────────────────────────────────────────────────────────
131
/// How the per-sample losses are combined into a single scalar for
/// backpropagation.
///
/// * [`Mean`](LossReduction::Mean) - divide by the number of elements
///   (scale-invariant, recommended for most use cases).
/// * [`Sum`](LossReduction::Sum) - sum without normalisation (sensitive to
///   batch size; may require a lower learning rate).
///
/// NOTE(review): the configured default elsewhere in this file is `Sum`,
/// even though `Mean` is recommended above — confirm this is intentional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LossReduction {
    /// Average the loss over all contributing pairs.
    Mean,
    /// Sum the loss over all contributing pairs without normalisation.
    Sum,
}
146
147// ─── ManifoldParams ──────────────────────────────────────────────────────────
148
/// Configuration for manifold shape and embedding space properties.
///
/// These parameters control the geometric properties of the low-dimensional
/// embedding space and how the manifold is shaped. They feed the kernel
/// curve fit in [`fit_ab`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ManifoldParams {
    /// Minimum distance between points in the embedding space.
    ///
    /// Controls how tightly points can be packed together. Smaller values
    /// create more clustered embeddings, larger values spread points out more.
    ///
    /// Default: 0.1
    pub min_dist: f32,

    /// The effective scale of embedded points.
    ///
    /// Together with `min_dist`, this determines the embedding's overall spread.
    ///
    /// Default: 1.0
    pub spread: f32,
}
170
171impl Default for ManifoldParams {
172 fn default() -> Self {
173 Self {
174 min_dist: 0.1,
175 spread: 1.0,
176 }
177 }
178}
179
180// ─── GraphParams ─────────────────────────────────────────────────────────────
181
/// Configuration for k-nearest neighbor graph construction.
///
/// These parameters control how the high-dimensional manifold structure
/// is captured via a fuzzy topological representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphParams {
    /// Number of nearest neighbors to use for manifold approximation.
    ///
    /// Larger values capture more global structure but may miss fine details.
    /// Smaller values focus on local structure but may fragment the manifold.
    ///
    /// Default: 15
    pub n_neighbors: usize,

    /// The distance metric to use for building the k-NN graph.
    ///
    /// Default: Euclidean
    pub metric: Metric,

    /// Whether to normalize distance outputs before use in the loss.
    ///
    /// Default: true
    pub normalized: bool,

    /// The Minkowski `p` parameter (only used when metric is
    /// [`Metric::Minkowski`]).
    ///
    /// Default: 1.0
    pub minkowski_p: f64,
}
211
212impl Default for GraphParams {
213 fn default() -> Self {
214 Self {
215 n_neighbors: 15,
216 metric: Metric::Euclidean,
217 normalized: true,
218 minkowski_p: 1.0,
219 }
220 }
221}
222
223// ─── OptimizationParams ──────────────────────────────────────────────────────
224
/// Configuration for stochastic gradient descent optimization.
///
/// These parameters control the embedding optimization process.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationParams {
    /// Number of optimization epochs.
    ///
    /// Default: 100
    pub n_epochs: usize,

    /// The number of samples to process in each training batch.
    ///
    /// Default: 1000
    pub batch_size: usize,

    /// Initial learning rate for the Adam optimizer.
    ///
    /// Default: 0.001
    pub learning_rate: f64,

    /// Beta1 parameter for the Adam optimizer.
    ///
    /// Default: 0.9
    pub beta1: f64,

    /// Beta2 parameter for the Adam optimizer.
    ///
    /// Default: 0.999
    pub beta2: f64,

    /// L2 regularization (weight decay) penalty.
    ///
    /// Default: 1e-5
    pub penalty: f32,

    /// Weight applied to the repulsion term of the UMAP cross-entropy loss.
    ///
    /// Default: 1.0
    pub repulsion_strength: f32,

    /// Number of epochs to wait for improvement before triggering early stopping.
    /// `None` disables early stopping.
    ///
    /// Default: None
    pub patience: Option<i32>,

    /// The method used to reduce the loss (mean or sum).
    ///
    /// Default: Sum
    pub loss_reduction: LossReduction,

    /// Minimum desired loss to achieve before stopping early.
    ///
    /// Default: None
    pub min_desired_loss: Option<f64>,

    /// Maximum training time in seconds. `None` means no limit.
    ///
    /// Default: None
    pub timeout: Option<u64>,

    /// Whether to show detailed progress information during training.
    ///
    /// Default: false
    pub verbose: bool,

    /// Number of negative (repulsion) samples drawn per positive (attraction)
    /// edge each epoch.
    ///
    /// Higher values produce stronger repulsion and better cluster separation
    /// at the cost of more computation per epoch.
    ///
    /// Default: 5
    pub neg_sample_rate: usize,

    /// Milliseconds to sleep at the end of every training epoch.
    ///
    /// Inserting a small pause between epochs lets the GPU scheduler breathe,
    /// preventing the device from being pinned at 100 % utilisation for the
    /// entire run. Typical values:
    ///
    /// | `cooldown_ms` | Effect |
    /// |---------------|--------|
    /// | `0` (default) | No pause — maximum throughput |
    /// | `1–5`         | Barely perceptible pause, ~10–20 % GPU headroom |
    /// | `10–50`       | Noticeable slowdown, significant GPU headroom |
    ///
    /// Default: 0 (disabled)
    pub cooldown_ms: u64,

    /// Directory where loss-curve and embedding snapshot plots are written
    /// when `verbose` is `true` (or the `verbose` feature flag is enabled).
    ///
    /// Defaults to `"figures"` (relative to the current working directory).
    /// Set this to an absolute path — or any writable location — if the
    /// process runs on a read-only filesystem.
    ///
    /// When `None` the default `"figures"` directory is used.
    ///
    /// Not serialized: the plot directory is a runtime/host concern, so it is
    /// skipped and falls back to `None` (the `Option` default) on deserialize.
    #[serde(skip)]
    pub figures_dir: Option<PathBuf>,
}
326
327impl Default for OptimizationParams {
328 fn default() -> Self {
329 Self {
330 n_epochs: 100,
331 batch_size: 1000,
332 learning_rate: 0.001,
333 beta1: 0.9,
334 beta2: 0.999,
335 penalty: 1e-5,
336 repulsion_strength: 1.0,
337 patience: None,
338 loss_reduction: LossReduction::Sum,
339 min_desired_loss: None,
340 timeout: None,
341 verbose: false,
342 neg_sample_rate: 5,
343 cooldown_ms: 0,
344 figures_dir: None,
345 }
346 }
347}
348
349// ─── UmapConfig ──────────────────────────────────────────────────────────────
350
/// Complete UMAP configuration.
///
/// Groups all parameters for dimensionality reduction into a coherent structure.
/// All parameter groups have sensible defaults and can be customized individually.
///
/// This struct mirrors the configuration style of
/// [`umap-rs`](https://crates.io/crates/umap-rs) with nested parameter groups.
///
/// # Example
///
/// ```ignore
/// use fast_umap::prelude::*;
///
/// // Use all defaults (2-D output, Euclidean metric)
/// let config = UmapConfig::default();
///
/// // Customize specific groups
/// let config = UmapConfig {
///     n_components: 3,
///     graph: GraphParams {
///         n_neighbors: 30,
///         ..Default::default()
///     },
///     optimization: OptimizationParams {
///         n_epochs: 500,
///         learning_rate: 1e-3,
///         ..Default::default()
///     },
///     ..Default::default()
/// };
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UmapConfig {
    /// Number of dimensions in the output embedding.
    ///
    /// Typically 2 for visualization or 3-50 for downstream ML tasks.
    ///
    /// Default: 2
    pub n_components: usize,

    /// Hidden layer sizes for the parametric neural network.
    ///
    /// Default: [100]
    pub hidden_sizes: Vec<usize>,

    /// Manifold shape configuration.
    pub manifold: ManifoldParams,

    /// Graph construction configuration.
    pub graph: GraphParams,

    /// Optimization configuration.
    pub optimization: OptimizationParams,
}
405
406impl Default for UmapConfig {
407 fn default() -> Self {
408 Self {
409 n_components: 2,
410 hidden_sizes: vec![100],
411 manifold: ManifoldParams::default(),
412 graph: GraphParams::default(),
413 optimization: OptimizationParams::default(),
414 }
415 }
416}
417
418// ─── TrainingConfig (backward compatibility) ─────────────────────────────────
419
/// Configuration for training the UMAP model.
///
/// **Deprecated**: Use [`UmapConfig`] instead. This type is provided for
/// backward compatibility and converts to/from `UmapConfig`.
///
/// Unlike [`UmapConfig`], this flat struct also carries the fitted kernel
/// parameters `kernel_a` / `kernel_b` directly.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// The distance metric to use for training the model.
    pub metric: Metric,
    /// The total number of epochs to run during training.
    pub epochs: usize,
    /// The number of samples to process in each training batch.
    pub batch_size: usize,
    /// The learning rate for the optimizer.
    pub learning_rate: f64,
    /// The Beta1 parameter for the Adam optimizer.
    pub beta1: f64,
    /// The Beta2 parameter for the Adam optimizer.
    pub beta2: f64,
    /// The L2 regularization (weight decay) penalty.
    pub penalty: f32,
    /// Whether to show detailed progress information during training.
    pub verbose: bool,
    /// The number of epochs to wait for improvement before triggering early stopping.
    pub patience: Option<i32>,
    /// The method used to reduce the loss during training.
    pub loss_reduction: LossReduction,
    /// The number of nearest neighbors to consider.
    pub k_neighbors: usize,
    /// Minimum desired loss to achieve before stopping early.
    pub min_desired_loss: Option<f64>,
    /// Maximum training time in seconds.
    pub timeout: Option<u64>,
    /// Normalize distance output.
    pub normalized: bool,
    /// Minkowski p parameter.
    pub minkowski_p: f64,
    /// Weight applied to the repulsion term.
    pub repulsion_strength: f32,
    /// UMAP kernel parameter `a`, fitted from `min_dist` and `spread`.
    /// Controls the width of the kernel: `q = 1 / (1 + a * d^(2b))`.
    pub kernel_a: f32,
    /// UMAP kernel parameter `b`, fitted from `min_dist` and `spread`.
    /// Controls the decay shape: `q = 1 / (1 + a * d^(2b))`.
    pub kernel_b: f32,
    /// Number of negative samples per positive edge per epoch.
    pub neg_sample_rate: usize,

    /// Milliseconds to sleep at the end of every training epoch.
    ///
    /// `0` (the default) disables the pause and gives maximum throughput.
    /// Increase this value to reduce GPU utilisation at the cost of longer
    /// training time (e.g. `cooldown_ms = 5` for ~10–20 % GPU headroom).
    pub cooldown_ms: u64,

    /// Directory where loss-curve and embedding snapshot plots are written.
    ///
    /// Defaults to `None`, which resolves to `"figures"` in the current working
    /// directory. Set to a writable [`PathBuf`] when the process runs on a
    /// read-only filesystem.
    pub figures_dir: Option<PathBuf>,
}
481
482impl TrainingConfig {
483 /// Creates a new builder for constructing a `TrainingConfig`.
484 pub fn builder() -> TrainingConfigBuilder {
485 TrainingConfigBuilder::default()
486 }
487}
488
impl From<&UmapConfig> for TrainingConfig {
    /// Flattens a nested [`UmapConfig`] into the legacy flat [`TrainingConfig`].
    ///
    /// `UmapConfig` does not store the kernel parameters `a`/`b`; they are
    /// fitted here from the manifold's `min_dist` and `spread` (cheap, runs
    /// once per conversion).
    fn from(config: &UmapConfig) -> Self {
        // Fit the UMAP kernel curve from the manifold shape parameters.
        let (kernel_a, kernel_b) = fit_ab(config.manifold.min_dist, config.manifold.spread);
        TrainingConfig {
            metric: config.graph.metric.clone(),
            epochs: config.optimization.n_epochs,
            batch_size: config.optimization.batch_size,
            learning_rate: config.optimization.learning_rate,
            beta1: config.optimization.beta1,
            beta2: config.optimization.beta2,
            penalty: config.optimization.penalty,
            verbose: config.optimization.verbose,
            patience: config.optimization.patience,
            loss_reduction: config.optimization.loss_reduction.clone(),
            k_neighbors: config.graph.n_neighbors,
            min_desired_loss: config.optimization.min_desired_loss,
            timeout: config.optimization.timeout,
            normalized: config.graph.normalized,
            minkowski_p: config.graph.minkowski_p,
            repulsion_strength: config.optimization.repulsion_strength,
            kernel_a,
            kernel_b,
            neg_sample_rate: config.optimization.neg_sample_rate,
            cooldown_ms: config.optimization.cooldown_ms,
            figures_dir: config.optimization.figures_dir.clone(),
        }
    }
}
517
518impl From<UmapConfig> for TrainingConfig {
519 fn from(config: UmapConfig) -> Self {
520 TrainingConfig::from(&config)
521 }
522}
523
impl From<&TrainingConfig> for UmapConfig {
    /// Re-nests a flat [`TrainingConfig`] into a [`UmapConfig`].
    ///
    /// This conversion is lossy: `TrainingConfig` carries no `n_components`,
    /// `hidden_sizes`, or manifold shape (`min_dist`/`spread`), so those fall
    /// back to their defaults below. In particular the fitted
    /// `kernel_a`/`kernel_b` values are dropped and would be re-fitted from
    /// the default manifold on a round-trip conversion.
    fn from(config: &TrainingConfig) -> Self {
        UmapConfig {
            // Not representable in TrainingConfig — use the defaults.
            n_components: 2,
            hidden_sizes: vec![100],
            manifold: ManifoldParams::default(),
            graph: GraphParams {
                n_neighbors: config.k_neighbors,
                metric: config.metric.clone(),
                normalized: config.normalized,
                minkowski_p: config.minkowski_p,
            },
            optimization: OptimizationParams {
                n_epochs: config.epochs,
                batch_size: config.batch_size,
                learning_rate: config.learning_rate,
                beta1: config.beta1,
                beta2: config.beta2,
                penalty: config.penalty,
                repulsion_strength: config.repulsion_strength,
                patience: config.patience,
                loss_reduction: config.loss_reduction.clone(),
                min_desired_loss: config.min_desired_loss,
                timeout: config.timeout,
                verbose: config.verbose,
                neg_sample_rate: config.neg_sample_rate,
                cooldown_ms: config.cooldown_ms,
                figures_dir: config.figures_dir.clone(),
            },
        }
    }
}
556
557impl From<TrainingConfig> for UmapConfig {
558 fn from(config: TrainingConfig) -> Self {
559 UmapConfig::from(&config)
560 }
561}
562
/// Builder pattern for constructing a `TrainingConfig` with optional parameters.
///
/// Every field is optional; anything left unset falls back to the default
/// applied in [`TrainingConfigBuilder::build`]. Obtain one via
/// [`TrainingConfig::builder`] or `TrainingConfigBuilder::default()`.
#[derive(Default)]
pub struct TrainingConfigBuilder {
    metric: Option<Metric>,
    epochs: Option<usize>,
    batch_size: Option<usize>,
    learning_rate: Option<f64>,
    beta1: Option<f64>,
    beta2: Option<f64>,
    penalty: Option<f32>,
    verbose: Option<bool>,
    patience: Option<i32>,
    loss_reduction: Option<LossReduction>,
    k_neighbors: Option<usize>,
    min_desired_loss: Option<f64>,
    timeout: Option<u64>,
    normalized: Option<bool>,
    minkowski_p: Option<f64>,
    repulsion_strength: Option<f32>,
    neg_sample_rate: Option<usize>,
    cooldown_ms: Option<u64>,
    figures_dir: Option<PathBuf>,
}
586
587impl TrainingConfigBuilder {
588 pub fn with_metric(mut self, metric: Metric) -> Self {
589 self.metric = Some(metric);
590 self
591 }
592
593 pub fn with_epochs(mut self, epochs: usize) -> Self {
594 self.epochs = Some(epochs);
595 self
596 }
597
598 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
599 self.batch_size = Some(batch_size);
600 self
601 }
602
603 pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
604 self.learning_rate = Some(learning_rate);
605 self
606 }
607
608 pub fn with_beta1(mut self, beta1: f64) -> Self {
609 self.beta1 = Some(beta1);
610 self
611 }
612
613 pub fn with_beta2(mut self, beta2: f64) -> Self {
614 self.beta2 = Some(beta2);
615 self
616 }
617
618 pub fn with_penalty(mut self, penalty: f32) -> Self {
619 self.penalty = Some(penalty);
620 self
621 }
622
623 pub fn with_verbose(mut self, verbose: bool) -> Self {
624 self.verbose = Some(verbose);
625 self
626 }
627
628 pub fn with_patience(mut self, patience: i32) -> Self {
629 self.patience = Some(patience);
630 self
631 }
632
633 pub fn with_loss_reduction(mut self, loss_reduction: LossReduction) -> Self {
634 self.loss_reduction = Some(loss_reduction);
635 self
636 }
637
638 pub fn with_k_neighbors(mut self, k_neighbors: usize) -> Self {
639 self.k_neighbors = Some(k_neighbors);
640 self
641 }
642
643 pub fn with_min_desired_loss(mut self, min_desired_loss: f64) -> Self {
644 self.min_desired_loss = Some(min_desired_loss);
645 self
646 }
647
648 pub fn with_timeout(mut self, timeout: u64) -> Self {
649 self.timeout = Some(timeout);
650 self
651 }
652
653 pub fn with_normalized(mut self, normalized: bool) -> Self {
654 self.normalized = Some(normalized);
655 self
656 }
657
658 pub fn with_minkowski_p(mut self, minkowski_p: f64) -> Self {
659 self.minkowski_p = Some(minkowski_p);
660 self
661 }
662
663 pub fn with_repulsion_strength(mut self, repulsion_strength: f32) -> Self {
664 self.repulsion_strength = Some(repulsion_strength);
665 self
666 }
667
668 pub fn with_neg_sample_rate(mut self, neg_sample_rate: usize) -> Self {
669 self.neg_sample_rate = Some(neg_sample_rate);
670 self
671 }
672
673 /// Set the per-epoch cooldown sleep in milliseconds.
674 ///
675 /// Inserting a pause between epochs prevents the GPU from being pinned at
676 /// 100 % utilisation. `0` (the default) disables the sleep entirely.
677 pub fn with_cooldown_ms(mut self, cooldown_ms: u64) -> Self {
678 self.cooldown_ms = Some(cooldown_ms);
679 self
680 }
681
682 /// Set the directory where loss-curve and snapshot plots are saved.
683 ///
684 /// Use this to redirect output away from a read-only working directory:
685 /// ```ignore
686 /// .with_figures_dir(std::env::temp_dir().join("umap_figures"))
687 /// ```
688 pub fn with_figures_dir(mut self, dir: impl Into<PathBuf>) -> Self {
689 self.figures_dir = Some(dir.into());
690 self
691 }
692
693 pub fn build(self) -> Option<TrainingConfig> {
694 let defaults = ManifoldParams::default();
695 let (kernel_a, kernel_b) = fit_ab(defaults.min_dist, defaults.spread);
696 Some(TrainingConfig {
697 metric: self.metric.unwrap_or(Metric::Euclidean),
698 epochs: self.epochs.unwrap_or(1000),
699 batch_size: self.batch_size.unwrap_or(1000),
700 learning_rate: self.learning_rate.unwrap_or(0.001),
701 beta1: self.beta1.unwrap_or(0.9),
702 beta2: self.beta2.unwrap_or(0.999),
703 penalty: self.penalty.unwrap_or(1e-5),
704 verbose: self.verbose.unwrap_or(false),
705 patience: self.patience,
706 loss_reduction: self.loss_reduction.unwrap_or(LossReduction::Sum),
707 k_neighbors: self.k_neighbors.unwrap_or(15),
708 min_desired_loss: self.min_desired_loss,
709 timeout: self.timeout,
710 normalized: self.normalized.unwrap_or(true),
711 minkowski_p: self.minkowski_p.unwrap_or(1.0),
712 repulsion_strength: self.repulsion_strength.unwrap_or(1.0),
713 kernel_a,
714 kernel_b,
715 neg_sample_rate: self.neg_sample_rate.unwrap_or(5),
716 cooldown_ms: self.cooldown_ms.unwrap_or(0),
717 figures_dir: self.figures_dir,
718 })
719 }
720}