pacmap/lib.rs
1#![allow(clippy::multiple_crate_versions)]
2
3//! # `PaCMAP`: Pairwise Controlled Manifold Approximation
4//!
5//! This crate provides a Rust implementation of `PaCMAP` (Pairwise Controlled
6//! Manifold Approximation), a dimensionality reduction technique that preserves
7//! both local and global structure of high-dimensional data.
8//!
9//! `PaCMAP` transforms high-dimensional data into a lower-dimensional
10//! representation while preserving important relationships between points. This
11//! is useful for visualization, analysis, and as preprocessing for other
12//! algorithms.
13//!
14//! ## Key Features
15//!
16//! `PaCMAP` preserves both local and global structure through three types of
17//! point relationships:
18//! - Nearest neighbor pairs preserve local structure
19//! - Mid-near pairs preserve intermediate structure
20//! - Far pairs prevent collapse and maintain separation
21//!
22//! The implementation provides:
23//! - Configurable optimization with adaptive learning rates via Adam
24//! optimization
25//! - Phase-based weight schedules to balance local and global preservation
26//! - Multiple initialization options including PCA and random seeding
27//! - Optional snapshot capture of intermediate states
28//!
29//! ## Examples
30//!
31//! Basic usage with default parameters:
32//! ```rust,no_run
33//! use ndarray::Array2;
34//! use pacmap::{Configuration, fit_transform};
35//!
36//! let data: Array2<f32> = // ... load your high-dimensional data
37//! # Array2::zeros((100, 50));
38//! let config = Configuration::default();
39//! let (embedding, _) = fit_transform(data.view(), config).unwrap();
40//! ```
41//!
42//! Customized embedding:
43//! ```rust,no_run
44//! use pacmap::{Configuration, Initialization};
45//!
46//! let config = Configuration::builder()
47//! .embedding_dimensions(3)
48//! .initialization(Initialization::Random(Some(42)))
49//! .learning_rate(0.8)
50//! .num_iters((50, 50, 100))
51//! .mid_near_ratio(0.3)
52//! .far_pair_ratio(2.0)
53//! .build();
54//! ```
55//!
56//! Capturing intermediate states:
57//! ```rust,no_run
58//! use pacmap::Configuration;
59//!
60//! let config = Configuration::builder()
61//! .snapshots(vec![100, 200, 300])
62//! .build();
63//! ```
64//!
65//! ## Configuration
66//!
67//! Core parameters:
68//! - `embedding_dimensions`: Output dimensionality (default: 2)
69//! - `initialization`: How to initialize coordinates:
70//! - `Pca` - Project data using PCA (default)
71//! - `Value(array)` - Use provided coordinates
72//! - `Random(seed)` - Random initialization with optional seed
73//! - `learning_rate`: Learning rate for Adam optimizer (default: 1.0)
74//! - `num_iters`: Iteration counts for three optimization phases (default:
75//! (100, 100, 250))
76//! - `snapshots`: Optional vector of iterations at which to save embedding
77//! states
//! - `approx_threshold`: Number of points above which approximate neighbor
//!   search is used (default: 8000)
80//!
81//! Pair sampling parameters:
82//! - `mid_near_ratio`: Ratio of mid-near to nearest neighbor pairs (default:
83//! 0.5)
84//! - `far_pair_ratio`: Ratio of far to nearest neighbor pairs (default: 2.0)
85//! - `override_neighbors`: Optional fixed neighbor count override
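//! - `seed`: Optional random seed for reproducible pair sampling; it also
//!   seeds the randomized PCA used when preprocessing inputs with more than
//!   100 dimensions under PCA initialization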
87//!
88//! ## Feature Flags
89//!
90//! ### BLAS/LAPACK Backends
91//!
92//! Only one BLAS/LAPACK backend feature should be enabled at a time. These are
93//! required for PCA operations except on macOS which uses Accelerate by
94//! default.
95//!
96//! - `intel-mkl-static` - Static linking with Intel MKL
97//! - `intel-mkl-system` - Dynamic linking with system Intel MKL
98//! - `openblas-static` - Static linking with `OpenBLAS`
99//! - `openblas-system` - Dynamic linking with system `OpenBLAS`
100//! - `netlib-static` - Static linking with Netlib
101//! - `netlib-system` - Dynamic linking with system Netlib
102//!
103//! For more details on BLAS/LAPACK configuration, see the [ndarray-linalg
104//! documentation](https://github.com/rust-ndarray/ndarray-linalg#backend-features).
105//!
106//! ### Performance Features
107//!
108//! - `simsimd` - Enable SIMD optimizations in `USearch` for faster approximate
109//! nearest neighbor search. Requires GCC 13+ for compilation and a recent
110//! glibc at runtime.
111//!
112//! ## Implementation Notes
113//!
114//! - Supports both exact and approximate nearest neighbor search
115//! - Uses Euclidean distances for pair relationships
116//! - Leverages ndarray for efficient matrix operations
117//! - Employs parallel iterators via rayon for performance
118//! - Provides detailed error handling with custom error types
119//!
120//! ## References
121//!
122//! [Understanding How Dimension Reduction Tools Work: An Empirical Approach to Deciphering t-SNE, UMAP, TriMap, and PaCMAP for Data Visualization](https://jmlr.org/papers/v22/20-1061.html).
123//! Wang, Y., Huang, H., Rudin, C., & Shaposhnik, Y. (2021).
124//! Journal of Machine Learning Research, 22(201), 1-73.
125//!
126//! Original Python implementation: <https://github.com/YingfanWang/PaCMAP>
127
128// Submodule imports
129mod adam;
130mod distance;
131mod gradient;
132pub mod knn;
133mod neighbors;
134mod sampling;
135mod weights;
136
137#[cfg(test)]
138mod tests;
139
140use bon::Builder;
141use ndarray::{s, Array1, Array2, Array3, ArrayView2, Axis, Zip};
142use ndarray_rand::rand_distr::{Normal, NormalError};
143use ndarray_rand::RandomExt;
144use petal_decomposition::{DecompositionError, Pca, RandomizedPca, RandomizedPcaBuilder};
145use rand::rngs::SmallRng;
146use rand::SeedableRng;
147use rand_pcg::Mcg128Xsl64;
148use std::cmp::min;
149use std::time::Instant;
150use thiserror::Error;
151use tracing::{debug, warn};
152
153use crate::adam::update_embedding_adam;
154use crate::gradient::pacmap_grad;
155use crate::knn::KnnError;
156use crate::neighbors::{generate_pair_no_neighbors, generate_pairs};
157use crate::weights::find_weights;
158
/// Configuration options for the `PaCMAP` embedding process.
///
/// Controls initialization, sampling ratios, optimization parameters, and
/// snapshot capture. Build one with `Configuration::builder()` (see the
/// crate-level examples) or `Configuration::default()`; every field either
/// has a builder default, shown below, or is optional.
#[derive(Builder, Clone, Debug)]
pub struct Configuration {
    /// Number of dimensions in the output embedding space, typically 2 or 3
    /// (default: 2)
    #[builder(default = 2)]
    pub embedding_dimensions: usize,

    /// Method for initializing the embedding coordinates (default:
    /// `Initialization::Pca`)
    #[builder(default)]
    pub initialization: Initialization,

    /// Ratio of mid-near pairs to nearest neighbor pairs (default: 0.5)
    #[builder(default = 0.5)]
    pub mid_near_ratio: f32,

    /// Ratio of far pairs to nearest neighbor pairs (default: 2.0)
    #[builder(default = 2.0)]
    pub far_pair_ratio: f32,

    /// Optional fixed neighbor count override. When `None`, the neighbor
    /// count is derived from the number of samples.
    pub override_neighbors: Option<usize>,

    /// Optional random seed for reproducibility. It is forwarded to both the
    /// preprocessing step and pair sampling.
    pub seed: Option<u64>,

    /// Controls how point pairs are sampled or provided (default:
    /// `PairConfiguration::Generate`)
    #[builder(default)]
    pub pair_configuration: PairConfiguration,

    /// Learning rate for the Adam optimizer (default: 1.0)
    #[builder(default = 1.0)]
    pub learning_rate: f32,

    /// Number of iterations for attraction, local structure, and global
    /// structure phases (default: (100, 100, 250)). The optimizer runs for
    /// the sum of the three counts.
    #[builder(default = (100, 100, 250))]
    pub num_iters: (usize, usize, usize),

    /// Optional iteration indices at which to save embedding states. An
    /// entry `k` captures the embedding right after the `k`-th optimizer
    /// update; the first snapshot slot is additionally filled with the
    /// initial embedding before optimization starts.
    pub snapshots: Option<Vec<usize>>,

    /// Number of points above which approximate neighbor search is used
    /// (default: 8000)
    #[builder(default = 8_000)]
    pub approx_threshold: usize,
}
207
208impl Default for Configuration {
209 fn default() -> Self {
210 Self {
211 embedding_dimensions: 2,
212 initialization: Initialization::default(),
213 mid_near_ratio: 0.5,
214 far_pair_ratio: 2.0,
215 override_neighbors: None,
216 seed: None,
217 pair_configuration: PairConfiguration::default(),
218 learning_rate: 1.0,
219 num_iters: (100, 100, 250),
220 snapshots: None,
221 approx_threshold: 8_000,
222 }
223 }
224}
225
/// Methods for initializing the embedding coordinates.
///
/// `Pca` is the default variant.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub enum Initialization {
    /// Project data using PCA
    #[default]
    Pca,

    /// Use provided coordinate values. The values are centered and rescaled
    /// per column before optimization starts.
    Value(Array2<f32>),

    /// Initialize randomly from a standard normal distribution, with an
    /// optional seed for reproducibility
    Random(Option<u64>),
}
240
/// Strategy for sampling pairs during optimization.
///
/// `Generate` is the default variant.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub enum PairConfiguration {
    /// Sample all pairs from scratch based on distances.
    /// Most computationally intensive but requires no prior information.
    #[default]
    Generate,

    /// Use provided nearest neighbors and generate mid-near and far pairs.
    /// Useful when nearest neighbors are pre-computed.
    NeighborsProvided {
        /// Matrix of shape (n * k, 2) containing nearest neighbor pair
        /// indices, where n is the number of samples and k the neighbor
        /// count in effect. Any other shape is rejected with
        /// `PaCMapError::InvalidNearestNeighborShape`.
        pair_neighbors: Array2<u32>,
    },

    /// Use all provided pair indices without additional sampling.
    /// Most efficient when all required pairs are pre-computed.
    /// The arrays are handed to the optimizer as given; no shape check is
    /// applied to them at the sampling stage.
    AllProvided {
        /// Nearest neighbor pair indices
        pair_neighbors: Array2<u32>,
        /// Mid-near pair indices
        pair_mn: Array2<u32>,
        /// Far pair indices
        pair_fp: Array2<u32>,
    },
}
268
269/// Reduces dimensionality of input data using `PaCMAP`.
270///
271/// # Arguments
272/// * `x` - Input data matrix where each row is a sample
273/// * `config` - Configuration options controlling the embedding process
274///
275/// # Returns
276/// A tuple containing:
277/// * Final embedding coordinates as a matrix
278/// * Optional array of intermediate embedding states if snapshots were
279/// requested
280///
281/// # Errors
282/// * `PaCMapError::SampleSize` - Input has <= 1 samples
283/// * `PaCMapError::InvalidNeighborCount` - Calculated neighbor count is invalid
284/// * `PaCMapError::InvalidFarPointCount` - Calculated far point count is
285/// invalid
286/// * `PaCMapError::InvalidNearestNeighborShape` - Provided neighbor pairs have
287/// wrong shape
288/// * `PaCMapError::EmptyArrayMean` - Mean cannot be calculated for
289/// preprocessing
290/// * `PaCMapError::EmptyArrayMinMax` - Min/max cannot be found during
291/// preprocessing
292/// * `PaCMapError::Pca` - PCA decomposition fails
293/// * `PaCMapError::Normal` - Random initialization fails
294pub fn fit_transform(
295 x: ArrayView2<f32>,
296 config: Configuration,
297) -> Result<(Array2<f32>, Option<Array3<f32>>), PaCMapError> {
298 // Input validation
299 let (n, dim) = x.dim();
300 if n <= 1 {
301 return Err(PaCMapError::SampleSize);
302 }
303
304 // Preprocess input data with optional dimensionality reduction
305 let PreprocessingResult {
306 x,
307 pca_solution,
308 maybe_transform,
309 ..
310 } = preprocess_x(
311 x,
312 matches!(config.initialization, Initialization::Pca),
313 dim,
314 config.embedding_dimensions,
315 config.seed,
316 )?;
317
318 // Initialize embedding coordinates
319 let embedding_init = if pca_solution {
320 YInit::Preprocessed
321 } else {
322 match config.initialization {
323 Initialization::Pca => {
324 let transform = maybe_transform.ok_or(PaCMapError::MissingTransform)?;
325 YInit::DimensionalReduction(transform)
326 }
327 Initialization::Value(value) => YInit::Value(value),
328 Initialization::Random(maybe_seed) => YInit::Random(maybe_seed),
329 }
330 };
331
332 // Calculate pair sampling parameters
333 let pair_decision = decide_num_pairs(
334 n,
335 config.override_neighbors,
336 config.mid_near_ratio,
337 config.far_pair_ratio,
338 )?;
339
340 if n - 1 < pair_decision.n_neighbors {
341 warn!("Sample size is smaller than n_neighbors. n_neighbors will be reduced.");
342 }
343
344 // Sample point pairs for optimization
345 let pairs = sample_pairs(
346 x.view(),
347 pair_decision.n_neighbors,
348 pair_decision.n_mn,
349 pair_decision.n_fp,
350 config.pair_configuration,
351 config.seed,
352 config.approx_threshold,
353 )?;
354
355 // Run optimization to compute embedding
356 pacmap(
357 x.view(),
358 config.embedding_dimensions,
359 pairs.pair_neighbors.view(),
360 pairs.pair_mn.view(),
361 pairs.pair_fp.view(),
362 config.learning_rate,
363 config.num_iters,
364 embedding_init,
365 config.snapshots.as_deref(),
366 )
367}
368
/// Results from preprocessing input data.
///
/// `fit_transform` only consumes `x`, `pca_solution`, and `maybe_transform`;
/// the remaining statistics are retained but unread there, hence the
/// `dead_code` allowance.
#[allow(dead_code)]
struct PreprocessingResult {
    /// Preprocessed data matrix (PCA projection or normalized copy of the
    /// input)
    x: Array2<f32>,

    /// Whether PCA dimensionality reduction was applied
    pca_solution: bool,

    /// Optional fitted dimensionality reduction transform
    maybe_transform: Option<Transform>,

    /// Minimum x value (set to 0.0 when PCA reduction was applied)
    x_min: f32,

    /// Maximum x value (set to 0.0 when PCA reduction was applied)
    x_max: f32,

    /// Mean of x along axis 0: of the raw input when PCA reduction was
    /// applied, otherwise of the shifted-and-scaled data
    x_mean: Array1<f32>,
}
390
/// Types of dimensionality reduction transforms used for initialization.
///
/// Each variant wraps a fitted model; the variants differ only in the PCA
/// flavor and the random number generator type it is parameterized with.
#[non_exhaustive]
enum Transform {
    /// Standard PCA
    Pca(Pca<f32>),

    /// Randomized PCA without fixed seed for efficiency on large datasets
    RandomizedPca(RandomizedPca<f32, Mcg128Xsl64>),

    /// Randomized PCA with fixed seed for reproducibility
    SeededPca(RandomizedPca<f32, SmallRng>),
}
403
404impl Transform {
405 /// Applies the transform to new data.
406 ///
407 /// # Arguments
408 /// * `x` - Input data to transform
409 ///
410 /// # Errors
411 /// * `DecompositionError` if transform fails
412 pub fn transform(&self, x: ArrayView2<f32>) -> Result<Array2<f32>, DecompositionError> {
413 match self {
414 Transform::Pca(pca) => pca.transform(&x),
415 Transform::RandomizedPca(pca) => pca.transform(&x),
416 Transform::SeededPca(pca) => pca.transform(&x),
417 }
418 }
419}
420
421/// Preprocesses input data through normalization and optional dimensionality
422/// reduction.
423///
424/// For high dimensional data (>100 dimensions), optionally applies PCA to
425/// reduce to 100 dimensions. Otherwise normalizes the data by centering and
426/// scaling.
427///
428/// # Arguments
429/// * `x` - Input data matrix
430/// * `apply_pca` - Whether to apply PCA dimensionality reduction
431/// * `high_dim` - Original data dimensionality
432/// * `low_dim` - Target dimensionality after reduction
433/// * `maybe_seed` - Optional random seed for reproducibility
434///
435/// # Returns
436/// A `PreprocessingResult` containing the processed data and transform
437///
438/// # Errors
439/// * `PaCMapError::EmptyArrayMean` if mean cannot be calculated
440/// * `PaCMapError::EmptyArrayMinMax` if min/max cannot be found
441/// * `PaCMapError::Pca` if PCA decomposition fails
442fn preprocess_x(
443 x: ArrayView2<f32>,
444 apply_pca: bool,
445 high_dim: usize,
446 low_dim: usize,
447 maybe_seed: Option<u64>,
448) -> Result<PreprocessingResult, PaCMapError> {
449 let mut pca_solution = false;
450 let mut x_out: Array2<f32>;
451 let x_mean: Array1<f32>;
452 let x_min: f32;
453 let x_max: f32;
454 let mut maybe_transform = None;
455
456 if high_dim > 100 && apply_pca {
457 let n_components = min(100, x.nrows());
458 // Compute the mean of x along axis 0
459 x_mean = x.mean_axis(Axis(0)).ok_or(PaCMapError::EmptyArrayMean)?;
460
461 // Initialize PCA and transform
462 match maybe_seed {
463 None => {
464 let mut pca = RandomizedPca::new(n_components);
465 x_out = pca.fit_transform(&x)?;
466 maybe_transform = Some(Transform::RandomizedPca(pca));
467 }
468 Some(seed) => {
469 let mut pca =
470 RandomizedPcaBuilder::with_rng(SmallRng::seed_from_u64(seed), n_components)
471 .build();
472
473 x_out = pca.fit_transform(&x)?;
474 maybe_transform = Some(Transform::SeededPca(pca));
475 }
476 };
477
478 pca_solution = true;
479
480 // Set x_min and x_max to zero
481 x_min = 0.0;
482 x_max = 0.0;
483
484 debug!("Applied PCA, the dimensionality becomes {n_components}");
485 } else {
486 x_out = x.to_owned();
487
488 // Compute x_min and x_max
489 x_min = *x_out
490 .iter()
491 .min_by(|&a, &b| f32::total_cmp(a, b))
492 .ok_or(PaCMapError::EmptyArrayMinMax)?;
493
494 x_max = *x_out
495 .iter()
496 .max_by(|&a, &b| f32::total_cmp(a, b))
497 .ok_or(PaCMapError::EmptyArrayMinMax)?;
498
499 // Subtract x_min from x
500 x_out.mapv_inplace(|val| val - x_min);
501
502 // Divide by x_max (not the range) to replicate the Python function
503 x_out.mapv_inplace(|val| val / x_max);
504
505 // Compute x_mean
506 x_mean = x_out
507 .mean_axis(Axis(0))
508 .ok_or(PaCMapError::EmptyArrayMean)?;
509
510 // Subtract x_mean from x
511 x_out -= &x_mean;
512
513 if apply_pca {
514 // Proceed with PCA
515 let n_components = min(x_out.nrows(), low_dim);
516 let mut pca = Pca::new(n_components);
517 pca.fit(&x_out)?;
518 maybe_transform = Some(Transform::Pca(pca));
519 }
520
521 debug!("x is normalized");
522 };
523
524 Ok(PreprocessingResult {
525 x: x_out,
526 pca_solution,
527 x_min,
528 x_max,
529 x_mean,
530 maybe_transform,
531 })
532}
533
/// Parameters controlling pair sampling based on dataset size.
///
/// Produced by `decide_num_pairs`; the mid-near and far counts are the
/// neighbor count multiplied by the configured ratios and rounded.
struct PairDecision {
    /// Number of nearest neighbors per point
    n_neighbors: usize,

    /// Number of mid-near pairs per point
    n_mn: usize,

    /// Number of far pairs per point
    n_fp: usize,
}
545
546/// Calculates number of pairs to use based on dataset size and configuration.
547///
548/// Automatically scales neighbor counts with dataset size unless overridden.
549///
550/// # Arguments
551/// * `n` - Number of samples in dataset
552/// * `n_neighbors` - Optional fixed neighbor count override
553/// * `mn_ratio` - Ratio of mid-near pairs to nearest neighbors
554/// * `fp_ratio` - Ratio of far pairs to nearest neighbors
555///
556/// # Returns
557/// A `PairDecision` containing the calculated pair counts
558///
559/// # Errors
560/// * `PaCMapError::InvalidNeighborCount` - Calculated neighbor count is less
561/// than 1
562/// * `PaCMapError::InvalidFarPointCount` - Calculated far pair count is less
563/// than 1
564#[allow(clippy::cast_precision_loss)]
565fn decide_num_pairs(
566 n: usize,
567 n_neighbors: Option<usize>,
568 mn_ratio: f32,
569 fp_ratio: f32,
570) -> Result<PairDecision, PaCMapError> {
571 // Scale neighbors with data size or use override
572 let n_neighbors = n_neighbors.unwrap_or_else(|| {
573 if n <= 10000 {
574 10
575 } else {
576 (10.0 + 15.0 * ((n as f32).log10() - 4.0)).round() as usize
577 }
578 });
579
580 let n_mn = (n_neighbors as f32 * mn_ratio).round() as usize;
581 let n_fp = (n_neighbors as f32 * fp_ratio).round() as usize;
582
583 // Validate calculated pair counts
584 if n_neighbors < 1 {
585 return Err(PaCMapError::InvalidNeighborCount);
586 }
587
588 if n_fp < 1 {
589 return Err(PaCMapError::InvalidFarPointCount);
590 }
591
592 Ok(PairDecision {
593 n_neighbors,
594 n_mn,
595 n_fp,
596 })
597}
598
/// Collection of sampled point pairs used during optimization.
///
/// Each array holds point-index pairs; the three sets are passed to the
/// optimizer together.
struct Pairs {
    /// Nearest neighbor pairs preserving local structure
    pair_neighbors: Array2<u32>,

    /// Mid-near pairs preserving medium-range structure
    pair_mn: Array2<u32>,

    /// Far pairs preventing collapse
    pair_fp: Array2<u32>,
}
610
611/// Samples point pairs according to the `PaCMAP` strategy.
612///
613/// # Arguments
614/// * `x` - Input data matrix
615/// * `n_neighbors` - Number of nearest neighbors per point
616/// * `n_mn` - Number of mid-near pairs per point
617/// * `n_fp` - Number of far pairs per point
618/// * `pair_config` - Configuration for pair sampling
619/// * `random_state` - Optional random seed
620/// * `approx_threshold` - Number of points above which approximate search is
621/// used
622///
623/// # Returns
624/// A `Pairs` struct containing the sampled pair indices
625///
626/// # Errors
627/// * `PaCMapError::InvalidNearestNeighborShape` if provided pairs have invalid
628/// shape
629fn sample_pairs(
630 x: ArrayView2<f32>,
631 n_neighbors: usize,
632 n_mn: usize,
633 n_fp: usize,
634 pair_config: PairConfiguration,
635 random_state: Option<u64>,
636 approx_threshold: usize,
637) -> Result<Pairs, PaCMapError> {
638 debug!("Finding pairs");
639 match pair_config {
640 // Generate all pairs from scratch
641 PairConfiguration::Generate => Ok(generate_pairs(
642 x,
643 n_neighbors,
644 n_mn,
645 n_fp,
646 random_state,
647 approx_threshold,
648 )?),
649
650 // Use provided nearest neighbors, generate remaining pairs
651 PairConfiguration::NeighborsProvided { pair_neighbors } => {
652 let expected_shape = [x.nrows() * n_neighbors, 2];
653 if pair_neighbors.shape() != expected_shape {
654 return Err(PaCMapError::InvalidNearestNeighborShape {
655 expected: expected_shape,
656 actual: pair_neighbors.shape().to_vec(),
657 });
658 }
659
660 debug!("Using provided nearest neighbor pairs.");
661 let (pair_mn, pair_fp) = generate_pair_no_neighbors(
662 x,
663 n_neighbors,
664 n_mn,
665 n_fp,
666 pair_neighbors.view(),
667 random_state,
668 );
669
670 debug!("Additional pairs sampled successfully.");
671 Ok(Pairs {
672 pair_neighbors,
673 pair_mn,
674 pair_fp,
675 })
676 }
677
678 // Use all provided pairs without additional sampling
679 PairConfiguration::AllProvided {
680 pair_neighbors,
681 pair_mn,
682 pair_fp,
683 } => {
684 debug!("Using all provided pairs.");
685 Ok(Pairs {
686 pair_neighbors,
687 pair_mn,
688 pair_fp,
689 })
690 }
691 }
692}
693
/// Methods for initializing embedding coordinates.
///
/// Internal counterpart of `Initialization`, resolved after preprocessing so
/// that it can carry a fitted transform.
enum YInit {
    /// Use provided coordinate values (centered and rescaled per column
    /// before use)
    Value(Array2<f32>),

    /// Use preprocessed data directly: the leading embedding-dimension
    /// columns of the preprocessed matrix, scaled by 0.01
    Preprocessed,

    /// Apply dimensionality reduction transform to the preprocessed data and
    /// scale the result by 0.01
    DimensionalReduction(Transform),

    /// Initialize randomly with optional seed
    Random(Option<u64>),
}
708
709/// Core `PaCMAP` optimization function.
710///
711/// Iteratively updates embedding coordinates through gradient descent to
712/// preserve data structure. Uses phase-based weight schedules to balance local
713/// and global structure preservation.
714///
715/// # Arguments
716/// * `x` - Input data matrix
717/// * `n_dims` - Desired output dimensionality
718/// * `pair_neighbors` - Nearest neighbor pairs
719/// * `pair_mn` - Mid-near pairs
720/// * `pair_fp` - Far pairs
721/// * `lr` - Learning rate
722/// * `num_iters` - Number of iterations for each optimization phase
723/// * `y_initialization` - Method for initializing coordinates
724/// * `inter_snapshots` - Optional indices for saving intermediate states
725///
726/// # Returns
727/// A tuple containing:
728/// * Final embedding coordinates
729/// * Optional intermediate states if snapshots were requested
730///
731/// # Errors
732/// * `PaCMapError::EmptyArrayMean` if mean cannot be calculated
733/// * `PaCMapError::Normal` if random initialization fails
734/// * `PaCMapError::Pca` if PCA transform fails
735#[allow(clippy::too_many_arguments)]
736fn pacmap<'a>(
737 x: ArrayView2<f32>,
738 n_dims: usize,
739 pair_neighbors: ArrayView2<'a, u32>,
740 pair_mn: ArrayView2<'a, u32>,
741 pair_fp: ArrayView2<'a, u32>,
742 lr: f32,
743 num_iters: (usize, usize, usize),
744 y_initialization: YInit,
745 inter_snapshots: Option<&[usize]>,
746) -> Result<(Array2<f32>, Option<Array3<f32>>), PaCMapError> {
747 let start_time = Instant::now();
748 let n = x.nrows();
749 let mut inter_snapshots = Snapshots::from(n_dims, n, inter_snapshots);
750
751 // Initialize embedding coordinates based on specified method
752 let mut y: Array2<f32> = match y_initialization {
753 YInit::Value(mut y) => {
754 let mean = y.mean_axis(Axis(0)).ok_or(PaCMapError::EmptyArrayMean)?;
755 let std = y.std_axis(Axis(0), 0.0);
756
757 // Center and scale provided coordinates
758 Zip::from(&mut y)
759 .and_broadcast(&mean)
760 .and_broadcast(&std)
761 .par_for_each(|y_elem, &mean_elem, &std| {
762 *y_elem = (*y_elem - mean_elem) * 0.0001 / std;
763 });
764
765 y
766 }
767 YInit::Preprocessed => x.slice(s![.., ..n_dims]).to_owned() * 0.01,
768 YInit::DimensionalReduction(transform) => transform.transform(x)? * 0.01,
769 YInit::Random(maybe_seed) => {
770 let normal = Normal::new(0.0, 1.0)?;
771
772 match maybe_seed {
773 None => Array2::random((n, n_dims), normal),
774 Some(seed) => {
775 Array2::random_using((n, n_dims), normal, &mut SmallRng::seed_from_u64(seed))
776 }
777 }
778 }
779 };
780
781 // Initialize optimizer parameters
782 let w_mn_init = 1000.0;
783 let beta1 = 0.9;
784 let beta2 = 0.999;
785 let mut m = Array2::zeros(y.dim());
786 let mut v = Array2::zeros(y.dim());
787
788 // Store initial state if snapshots requested
789 if let Some(ref mut snapshots) = inter_snapshots {
790 snapshots.states.slice_mut(s![0_usize, .., ..]).assign(&y);
791 }
792
793 debug!(
794 "Pair shapes: neighbors {:?}, MN {:?}, FP {:?}",
795 pair_neighbors.dim(),
796 pair_mn.dim(),
797 pair_fp.dim()
798 );
799
800 let num_iters_total = num_iters.0 + num_iters.1 + num_iters.2;
801
802 // Main optimization loop
803 for itr in 0..num_iters_total {
804 // Update weights based on phase
805 let weights = find_weights(w_mn_init, itr, num_iters.0, num_iters.1);
806 let grad = pacmap_grad(y.view(), pair_neighbors, pair_mn, pair_fp, &weights);
807
808 let c = grad[(n, 0)];
809 if itr == 0 {
810 debug!("Initial Loss: {}", c);
811 }
812
813 // Update embedding with Adam optimizer
814 update_embedding_adam(
815 y.view_mut(),
816 grad.view(),
817 m.view_mut(),
818 v.view_mut(),
819 beta1,
820 beta2,
821 lr,
822 itr,
823 );
824
825 if (itr + 1) % 10 == 0 {
826 debug!("Iteration: {:4}, Loss: {}", itr + 1, c);
827 }
828
829 // Store intermediate state if requested
830 let Some(ref mut snapshots) = &mut inter_snapshots else {
831 continue;
832 };
833
834 if let Some(index) = snapshots.indices.iter().position(|&x| x == itr + 1) {
835 snapshots.states.slice_mut(s![index, .., ..]).assign(&y);
836 }
837 }
838
839 let elapsed = start_time.elapsed();
840 debug!("Elapsed time: {:.2?}", elapsed);
841
842 Ok((y, inter_snapshots.map(|s| s.states)))
843}
844
/// Manages intermediate embedding states during optimization.
///
/// Borrows the requested iteration indices and owns one storage slot per
/// index.
struct Snapshots<'a> {
    /// Stored embedding states, shaped (number of indices, samples,
    /// embedding dimensions)
    states: Array3<f32>,

    /// Indices at which to take snapshots; position `i` in this slice maps
    /// to slot `i` of `states`
    indices: &'a [usize],
}
853
854impl<'a> Snapshots<'a> {
855 /// Creates new snapshot manager if indices are provided.
856 ///
857 /// # Arguments
858 /// * `n_dims` - Dimensionality of embedding
859 /// * `n` - Number of samples
860 /// * `maybe_snapshots` - Optional snapshot indices
861 fn from(n_dims: usize, n: usize, maybe_snapshots: Option<&'a [usize]>) -> Option<Self> {
862 let snapshots = maybe_snapshots?;
863 Some(Self {
864 states: Array3::zeros((snapshots.len(), n, n_dims)),
865 indices: snapshots,
866 })
867 }
868}
869
/// Errors that can occur during `PaCMAP` embedding.
///
/// Marked `#[non_exhaustive]`, so downstream `match` expressions need a
/// wildcard arm. Errors from the normal distribution, PCA decomposition, and
/// nearest neighbor search are wrapped transparently and convert
/// automatically through `?`.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum PaCMapError {
    /// Input data has 1 or fewer samples
    #[error("Sample size must be larger than one")]
    SampleSize,

    /// Provided nearest neighbor pairs have incorrect dimensions
    #[error("Invalid shape for nearest neighbor pairs. Expected {expected:?}, got {actual:?}")]
    InvalidNearestNeighborShape {
        /// Expected shape: [`n_samples` * `n_neighbors`, 2]
        expected: [usize; 2],
        /// Actual shape of provided matrix
        actual: Vec<usize>,
    },

    /// Mean calculation failed due to empty array
    #[error("Failed to calculate mean axis: the array is empty")]
    EmptyArrayMean,

    /// Normal distribution creation failed (raised while setting up random
    /// initialization)
    #[error(transparent)]
    Normal(#[from] NormalError),

    /// Calculated number of nearest neighbors is less than 1
    #[error("The number of nearest neighbors can't be less than 1")]
    InvalidNeighborCount,

    /// Calculated number of far points is less than 1
    #[error("The number of far points can't be less than 1")]
    InvalidFarPointCount,

    /// Min/max computation failed due to empty array
    #[error("Failed to compute min or max of X: the array is empty")]
    EmptyArrayMinMax,

    /// Data cannot be normalized due to zero range
    #[error("The range of X is zero (max - min = 0), cannot normalize")]
    ZeroRange,

    /// PCA decomposition failed (fitting or applying a PCA transform)
    #[error(transparent)]
    Pca(#[from] DecompositionError),

    /// K-nearest neighbors error
    #[error(transparent)]
    Neighbors(#[from] KnnError),

    /// PCA transform was not initialized: PCA initialization was requested
    /// but preprocessing returned no fitted transform
    #[error("Unable to perform PCA; transform not initialized")]
    MissingTransform,
}