pacmap/
lib.rs

1#![allow(clippy::multiple_crate_versions)]
2
3//! # `PaCMAP`: Pairwise Controlled Manifold Approximation
4//!
5//! This crate provides a Rust implementation of `PaCMAP` (Pairwise Controlled
6//! Manifold Approximation), a dimensionality reduction technique that preserves
7//! both local and global structure of high-dimensional data.
8//!
9//! `PaCMAP` transforms high-dimensional data into a lower-dimensional
10//! representation while preserving important relationships between points. This
11//! is useful for visualization, analysis, and as preprocessing for other
12//! algorithms.
13//!
14//! ## Key Features
15//!
16//! `PaCMAP` preserves both local and global structure through three types of
17//! point relationships:
18//! - Nearest neighbor pairs preserve local structure
19//! - Mid-near pairs preserve intermediate structure
20//! - Far pairs prevent collapse and maintain separation
21//!
22//! The implementation provides:
23//! - Configurable optimization with adaptive learning rates via Adam
24//!   optimization
25//! - Phase-based weight schedules to balance local and global preservation
26//! - Multiple initialization options including PCA and random seeding
27//! - Optional snapshot capture of intermediate states
28//!
29//! ## Examples
30//!
31//! Basic usage with default parameters:
32//! ```rust,no_run
33//! use ndarray::Array2;
34//! use pacmap::{Configuration, fit_transform};
35//!
36//! let data: Array2<f32> = // ... load your high-dimensional data
37//! # Array2::zeros((100, 50));
38//! let config = Configuration::default();
39//! let (embedding, _) = fit_transform(data.view(), config).unwrap();
40//! ```
41//!
42//! Customized embedding:
43//! ```rust,no_run
44//! use pacmap::{Configuration, Initialization};
45//!
46//! let config = Configuration::builder()
47//!     .embedding_dimensions(3)
48//!     .initialization(Initialization::Random(Some(42)))
49//!     .learning_rate(0.8)
50//!     .num_iters((50, 50, 100))
51//!     .mid_near_ratio(0.3)
52//!     .far_pair_ratio(2.0)
53//!     .build();
54//! ```
55//!
56//! Capturing intermediate states:
57//! ```rust,no_run
58//! use pacmap::Configuration;
59//!
60//! let config = Configuration::builder()
61//!     .snapshots(vec![100, 200, 300])
62//!     .build();
63//! ```
64//!
65//! ## Configuration
66//!
67//! Core parameters:
68//! - `embedding_dimensions`: Output dimensionality (default: 2)
69//! - `initialization`: How to initialize coordinates:
70//!   - `Pca` - Project data using PCA (default)
71//!   - `Value(array)` - Use provided coordinates
72//!   - `Random(seed)` - Random initialization with optional seed
73//! - `learning_rate`: Learning rate for Adam optimizer (default: 1.0)
74//! - `num_iters`: Iteration counts for three optimization phases (default:
75//!   (100, 100, 250))
76//! - `snapshots`: Optional vector of iterations at which to save embedding
77//!   states
//! - `approx_threshold`: Number of points above which approximate neighbor
//!   search is used (default: 8000)
80//!
81//! Pair sampling parameters:
82//! - `mid_near_ratio`: Ratio of mid-near to nearest neighbor pairs (default:
83//!   0.5)
84//! - `far_pair_ratio`: Ratio of far to nearest neighbor pairs (default: 2.0)
85//! - `override_neighbors`: Optional fixed neighbor count override
86//! - `seed`: Optional random seed for reproducible sampling
87//!
88//! ## Feature Flags
89//!
90//! ### BLAS/LAPACK Backends
91//!
92//! Only one BLAS/LAPACK backend feature should be enabled at a time. These are
93//! required for PCA operations except on macOS which uses Accelerate by
94//! default.
95//!
96//! - `intel-mkl-static` - Static linking with Intel MKL
97//! - `intel-mkl-system` - Dynamic linking with system Intel MKL
98//! - `openblas-static` - Static linking with `OpenBLAS`
99//! - `openblas-system` - Dynamic linking with system `OpenBLAS`
100//! - `netlib-static` - Static linking with Netlib
101//! - `netlib-system` - Dynamic linking with system Netlib
102//!
103//! For more details on BLAS/LAPACK configuration, see the [ndarray-linalg
104//! documentation](https://github.com/rust-ndarray/ndarray-linalg#backend-features).
105//!
106//! ### Performance Features
107//!
108//! - `simsimd` - Enable SIMD optimizations in `USearch` for faster approximate
109//!   nearest neighbor search. Requires GCC 13+ for compilation and a recent
110//!   glibc at runtime.
111//!
112//! ## Implementation Notes
113//!
114//! - Supports both exact and approximate nearest neighbor search
115//! - Uses Euclidean distances for pair relationships
116//! - Leverages ndarray for efficient matrix operations
117//! - Employs parallel iterators via rayon for performance
118//! - Provides detailed error handling with custom error types
119//!
120//! ## References
121//!
122//! [Understanding How Dimension Reduction Tools Work: An Empirical Approach to Deciphering t-SNE, UMAP, TriMap, and PaCMAP for Data Visualization](https://jmlr.org/papers/v22/20-1061.html).
123//! Wang, Y., Huang, H., Rudin, C., & Shaposhnik, Y. (2021).
124//! Journal of Machine Learning Research, 22(201), 1-73.
125//!
126//! Original Python implementation: <https://github.com/YingfanWang/PaCMAP>
127
128// Submodule imports
129mod adam;
130mod distance;
131mod gradient;
132pub mod knn;
133mod neighbors;
134mod sampling;
135mod weights;
136
137#[cfg(test)]
138mod tests;
139
140use bon::Builder;
141use ndarray::{s, Array1, Array2, Array3, ArrayView2, Axis, Zip};
142use ndarray_rand::rand_distr::{Normal, NormalError};
143use ndarray_rand::RandomExt;
144use petal_decomposition::{DecompositionError, Pca, RandomizedPca, RandomizedPcaBuilder};
145use rand::rngs::SmallRng;
146use rand::SeedableRng;
147use rand_pcg::Mcg128Xsl64;
148use std::cmp::min;
149use std::time::Instant;
150use thiserror::Error;
151use tracing::{debug, warn};
152
153use crate::adam::update_embedding_adam;
154use crate::gradient::pacmap_grad;
155use crate::knn::KnnError;
156use crate::neighbors::{generate_pair_no_neighbors, generate_pairs};
157use crate::weights::find_weights;
158
159/// Configuration options for the `PaCMAP` embedding process.
160///
161/// Controls initialization, sampling ratios, optimization parameters, and
162/// snapshot capture.
163#[derive(Builder, Clone, Debug)]
164pub struct Configuration {
165    /// Number of dimensions in the output embedding space, typically 2 or 3
166    #[builder(default = 2)]
167    pub embedding_dimensions: usize,
168
169    /// Method for initializing the embedding coordinates
170    #[builder(default)]
171    pub initialization: Initialization,
172
173    /// Ratio of mid-near pairs to nearest neighbor pairs
174    #[builder(default = 0.5)]
175    pub mid_near_ratio: f32,
176
177    /// Ratio of far pairs to nearest neighbor pairs
178    #[builder(default = 2.0)]
179    pub far_pair_ratio: f32,
180
181    /// Optional fixed neighbor count override
182    pub override_neighbors: Option<usize>,
183
184    /// Optional random seed for reproducibility
185    pub seed: Option<u64>,
186
187    /// Controls how point pairs are sampled or provided
188    #[builder(default)]
189    pub pair_configuration: PairConfiguration,
190
191    /// Learning rate for the Adam optimizer
192    #[builder(default = 1.0)]
193    pub learning_rate: f32,
194
195    /// Number of iterations for attraction, local structure, and global
196    /// structure phases
197    #[builder(default = (100, 100, 250))]
198    pub num_iters: (usize, usize, usize),
199
200    /// Optional iteration indices at which to save embedding states
201    pub snapshots: Option<Vec<usize>>,
202
203    /// Number of points above which approximate neighbor search is used
204    #[builder(default = 8_000)]
205    pub approx_threshold: usize,
206}
207
208impl Default for Configuration {
209    fn default() -> Self {
210        Self {
211            embedding_dimensions: 2,
212            initialization: Initialization::default(),
213            mid_near_ratio: 0.5,
214            far_pair_ratio: 2.0,
215            override_neighbors: None,
216            seed: None,
217            pair_configuration: PairConfiguration::default(),
218            learning_rate: 1.0,
219            num_iters: (100, 100, 250),
220            snapshots: None,
221            approx_threshold: 8_000,
222        }
223    }
224}
225
226/// Methods for initializing the embedding coordinates.
227#[derive(Clone, Debug, Default)]
228#[non_exhaustive]
229pub enum Initialization {
230    /// Project data using PCA
231    #[default]
232    Pca,
233
234    /// Use provided coordinate values
235    Value(Array2<f32>),
236
237    /// Initialize randomly with optional seed
238    Random(Option<u64>),
239}
240
241/// Strategy for sampling pairs during optimization.
242#[derive(Clone, Debug, Default)]
243#[non_exhaustive]
244pub enum PairConfiguration {
245    /// Sample all pairs from scratch based on distances.
246    /// Most computationally intensive but requires no prior information.
247    #[default]
248    Generate,
249
250    /// Use provided nearest neighbors and generate mid-near and far pairs.
251    /// Useful when nearest neighbors are pre-computed.
252    NeighborsProvided {
253        /// Matrix of shape (n * k, 2) containing nearest neighbor pair indices
254        pair_neighbors: Array2<u32>,
255    },
256
257    /// Use all provided pair indices without additional sampling.
258    /// Most efficient when all required pairs are pre-computed.
259    AllProvided {
260        /// Nearest neighbor pair indices
261        pair_neighbors: Array2<u32>,
262        /// Mid-near pair indices
263        pair_mn: Array2<u32>,
264        /// Far pair indices
265        pair_fp: Array2<u32>,
266    },
267}
268
269/// Reduces dimensionality of input data using `PaCMAP`.
270///
271/// # Arguments
272/// * `x` - Input data matrix where each row is a sample
273/// * `config` - Configuration options controlling the embedding process
274///
275/// # Returns
276/// A tuple containing:
277/// * Final embedding coordinates as a matrix
278/// * Optional array of intermediate embedding states if snapshots were
279///   requested
280///
281/// # Errors
282/// * `PaCMapError::SampleSize` - Input has <= 1 samples
283/// * `PaCMapError::InvalidNeighborCount` - Calculated neighbor count is invalid
284/// * `PaCMapError::InvalidFarPointCount` - Calculated far point count is
285///   invalid
286/// * `PaCMapError::InvalidNearestNeighborShape` - Provided neighbor pairs have
287///   wrong shape
288/// * `PaCMapError::EmptyArrayMean` - Mean cannot be calculated for
289///   preprocessing
290/// * `PaCMapError::EmptyArrayMinMax` - Min/max cannot be found during
291///   preprocessing
292/// * `PaCMapError::Pca` - PCA decomposition fails
293/// * `PaCMapError::Normal` - Random initialization fails
294pub fn fit_transform(
295    x: ArrayView2<f32>,
296    config: Configuration,
297) -> Result<(Array2<f32>, Option<Array3<f32>>), PaCMapError> {
298    // Input validation
299    let (n, dim) = x.dim();
300    if n <= 1 {
301        return Err(PaCMapError::SampleSize);
302    }
303
304    // Preprocess input data with optional dimensionality reduction
305    let PreprocessingResult {
306        x,
307        pca_solution,
308        maybe_transform,
309        ..
310    } = preprocess_x(
311        x,
312        matches!(config.initialization, Initialization::Pca),
313        dim,
314        config.embedding_dimensions,
315        config.seed,
316    )?;
317
318    // Initialize embedding coordinates
319    let embedding_init = if pca_solution {
320        YInit::Preprocessed
321    } else {
322        match config.initialization {
323            Initialization::Pca => {
324                let transform = maybe_transform.ok_or(PaCMapError::MissingTransform)?;
325                YInit::DimensionalReduction(transform)
326            }
327            Initialization::Value(value) => YInit::Value(value),
328            Initialization::Random(maybe_seed) => YInit::Random(maybe_seed),
329        }
330    };
331
332    // Calculate pair sampling parameters
333    let pair_decision = decide_num_pairs(
334        n,
335        config.override_neighbors,
336        config.mid_near_ratio,
337        config.far_pair_ratio,
338    )?;
339
340    if n - 1 < pair_decision.n_neighbors {
341        warn!("Sample size is smaller than n_neighbors. n_neighbors will be reduced.");
342    }
343
344    // Sample point pairs for optimization
345    let pairs = sample_pairs(
346        x.view(),
347        pair_decision.n_neighbors,
348        pair_decision.n_mn,
349        pair_decision.n_fp,
350        config.pair_configuration,
351        config.seed,
352        config.approx_threshold,
353    )?;
354
355    // Run optimization to compute embedding
356    pacmap(
357        x.view(),
358        config.embedding_dimensions,
359        pairs.pair_neighbors.view(),
360        pairs.pair_mn.view(),
361        pairs.pair_fp.view(),
362        config.learning_rate,
363        config.num_iters,
364        embedding_init,
365        config.snapshots.as_deref(),
366    )
367}
368
369/// Results from preprocessing input data.
370#[allow(dead_code)]
371struct PreprocessingResult {
372    /// Preprocessed data matrix
373    x: Array2<f32>,
374
375    /// Whether PCA dimensionality reduction was applied
376    pca_solution: bool,
377
378    /// Optional fitted dimensionality reduction transform
379    maybe_transform: Option<Transform>,
380
381    /// Minimum x value
382    x_min: f32,
383
384    /// Maximum x value
385    x_max: f32,
386
387    /// Mean of x along axis 0
388    x_mean: Array1<f32>,
389}
390
391/// Types of dimensionality reduction transforms used for initialization.
392#[non_exhaustive]
393enum Transform {
394    /// Standard PCA
395    Pca(Pca<f32>),
396
397    /// Randomized PCA without fixed seed for efficiency on large datasets
398    RandomizedPca(RandomizedPca<f32, Mcg128Xsl64>),
399
400    /// Randomized PCA with fixed seed for reproducibility
401    SeededPca(RandomizedPca<f32, SmallRng>),
402}
403
404impl Transform {
405    /// Applies the transform to new data.
406    ///
407    /// # Arguments
408    /// * `x` - Input data to transform
409    ///
410    /// # Errors
411    /// * `DecompositionError` if transform fails
412    pub fn transform(&self, x: ArrayView2<f32>) -> Result<Array2<f32>, DecompositionError> {
413        match self {
414            Transform::Pca(pca) => pca.transform(&x),
415            Transform::RandomizedPca(pca) => pca.transform(&x),
416            Transform::SeededPca(pca) => pca.transform(&x),
417        }
418    }
419}
420
421/// Preprocesses input data through normalization and optional dimensionality
422/// reduction.
423///
424/// For high dimensional data (>100 dimensions), optionally applies PCA to
425/// reduce to 100 dimensions. Otherwise normalizes the data by centering and
426/// scaling.
427///
428/// # Arguments
429/// * `x` - Input data matrix
430/// * `apply_pca` - Whether to apply PCA dimensionality reduction
431/// * `high_dim` - Original data dimensionality
432/// * `low_dim` - Target dimensionality after reduction
433/// * `maybe_seed` - Optional random seed for reproducibility
434///
435/// # Returns
436/// A `PreprocessingResult` containing the processed data and transform
437///
438/// # Errors
439/// * `PaCMapError::EmptyArrayMean` if mean cannot be calculated
440/// * `PaCMapError::EmptyArrayMinMax` if min/max cannot be found
441/// * `PaCMapError::Pca` if PCA decomposition fails
442fn preprocess_x(
443    x: ArrayView2<f32>,
444    apply_pca: bool,
445    high_dim: usize,
446    low_dim: usize,
447    maybe_seed: Option<u64>,
448) -> Result<PreprocessingResult, PaCMapError> {
449    let mut pca_solution = false;
450    let mut x_out: Array2<f32>;
451    let x_mean: Array1<f32>;
452    let x_min: f32;
453    let x_max: f32;
454    let mut maybe_transform = None;
455
456    if high_dim > 100 && apply_pca {
457        let n_components = min(100, x.nrows());
458        // Compute the mean of x along axis 0
459        x_mean = x.mean_axis(Axis(0)).ok_or(PaCMapError::EmptyArrayMean)?;
460
461        // Initialize PCA and transform
462        match maybe_seed {
463            None => {
464                let mut pca = RandomizedPca::new(n_components);
465                x_out = pca.fit_transform(&x)?;
466                maybe_transform = Some(Transform::RandomizedPca(pca));
467            }
468            Some(seed) => {
469                let mut pca =
470                    RandomizedPcaBuilder::with_rng(SmallRng::seed_from_u64(seed), n_components)
471                        .build();
472
473                x_out = pca.fit_transform(&x)?;
474                maybe_transform = Some(Transform::SeededPca(pca));
475            }
476        };
477
478        pca_solution = true;
479
480        // Set x_min and x_max to zero
481        x_min = 0.0;
482        x_max = 0.0;
483
484        debug!("Applied PCA, the dimensionality becomes {n_components}");
485    } else {
486        x_out = x.to_owned();
487
488        // Compute x_min and x_max
489        x_min = *x_out
490            .iter()
491            .min_by(|&a, &b| f32::total_cmp(a, b))
492            .ok_or(PaCMapError::EmptyArrayMinMax)?;
493
494        x_max = *x_out
495            .iter()
496            .max_by(|&a, &b| f32::total_cmp(a, b))
497            .ok_or(PaCMapError::EmptyArrayMinMax)?;
498
499        // Subtract x_min from x
500        x_out.mapv_inplace(|val| val - x_min);
501
502        // Divide by x_max (not the range) to replicate the Python function
503        x_out.mapv_inplace(|val| val / x_max);
504
505        // Compute x_mean
506        x_mean = x_out
507            .mean_axis(Axis(0))
508            .ok_or(PaCMapError::EmptyArrayMean)?;
509
510        // Subtract x_mean from x
511        x_out -= &x_mean;
512
513        if apply_pca {
514            // Proceed with PCA
515            let n_components = min(x_out.nrows(), low_dim);
516            let mut pca = Pca::new(n_components);
517            pca.fit(&x_out)?;
518            maybe_transform = Some(Transform::Pca(pca));
519        }
520
521        debug!("x is normalized");
522    };
523
524    Ok(PreprocessingResult {
525        x: x_out,
526        pca_solution,
527        x_min,
528        x_max,
529        x_mean,
530        maybe_transform,
531    })
532}
533
/// Per-point pair counts chosen for a dataset.
struct PairDecision {
    /// Nearest neighbors assigned to each point
    n_neighbors: usize,

    /// Mid-near pairs assigned to each point
    n_mn: usize,

    /// Far pairs assigned to each point
    n_fp: usize,
}
545
546/// Calculates number of pairs to use based on dataset size and configuration.
547///
548/// Automatically scales neighbor counts with dataset size unless overridden.
549///
550/// # Arguments
551/// * `n` - Number of samples in dataset
552/// * `n_neighbors` - Optional fixed neighbor count override
553/// * `mn_ratio` - Ratio of mid-near pairs to nearest neighbors
554/// * `fp_ratio` - Ratio of far pairs to nearest neighbors
555///
556/// # Returns
557/// A `PairDecision` containing the calculated pair counts
558///
559/// # Errors
560/// * `PaCMapError::InvalidNeighborCount` - Calculated neighbor count is less
561///   than 1
562/// * `PaCMapError::InvalidFarPointCount` - Calculated far pair count is less
563///   than 1
564#[allow(clippy::cast_precision_loss)]
565fn decide_num_pairs(
566    n: usize,
567    n_neighbors: Option<usize>,
568    mn_ratio: f32,
569    fp_ratio: f32,
570) -> Result<PairDecision, PaCMapError> {
571    // Scale neighbors with data size or use override
572    let n_neighbors = n_neighbors.unwrap_or_else(|| {
573        if n <= 10000 {
574            10
575        } else {
576            (10.0 + 15.0 * ((n as f32).log10() - 4.0)).round() as usize
577        }
578    });
579
580    let n_mn = (n_neighbors as f32 * mn_ratio).round() as usize;
581    let n_fp = (n_neighbors as f32 * fp_ratio).round() as usize;
582
583    // Validate calculated pair counts
584    if n_neighbors < 1 {
585        return Err(PaCMapError::InvalidNeighborCount);
586    }
587
588    if n_fp < 1 {
589        return Err(PaCMapError::InvalidFarPointCount);
590    }
591
592    Ok(PairDecision {
593        n_neighbors,
594        n_mn,
595        n_fp,
596    })
597}
598
599/// Collection of sampled point pairs used during optimization.
600struct Pairs {
601    /// Nearest neighbor pairs preserving local structure
602    pair_neighbors: Array2<u32>,
603
604    /// Mid-near pairs preserving medium-range structure
605    pair_mn: Array2<u32>,
606
607    /// Far pairs preventing collapse
608    pair_fp: Array2<u32>,
609}
610
611/// Samples point pairs according to the `PaCMAP` strategy.
612///
613/// # Arguments
614/// * `x` - Input data matrix
615/// * `n_neighbors` - Number of nearest neighbors per point
616/// * `n_mn` - Number of mid-near pairs per point
617/// * `n_fp` - Number of far pairs per point
618/// * `pair_config` - Configuration for pair sampling
619/// * `random_state` - Optional random seed
620/// * `approx_threshold` - Number of points above which approximate search is
621///   used
622///
623/// # Returns
624/// A `Pairs` struct containing the sampled pair indices
625///
626/// # Errors
627/// * `PaCMapError::InvalidNearestNeighborShape` if provided pairs have invalid
628///   shape
629fn sample_pairs(
630    x: ArrayView2<f32>,
631    n_neighbors: usize,
632    n_mn: usize,
633    n_fp: usize,
634    pair_config: PairConfiguration,
635    random_state: Option<u64>,
636    approx_threshold: usize,
637) -> Result<Pairs, PaCMapError> {
638    debug!("Finding pairs");
639    match pair_config {
640        // Generate all pairs from scratch
641        PairConfiguration::Generate => Ok(generate_pairs(
642            x,
643            n_neighbors,
644            n_mn,
645            n_fp,
646            random_state,
647            approx_threshold,
648        )?),
649
650        // Use provided nearest neighbors, generate remaining pairs
651        PairConfiguration::NeighborsProvided { pair_neighbors } => {
652            let expected_shape = [x.nrows() * n_neighbors, 2];
653            if pair_neighbors.shape() != expected_shape {
654                return Err(PaCMapError::InvalidNearestNeighborShape {
655                    expected: expected_shape,
656                    actual: pair_neighbors.shape().to_vec(),
657                });
658            }
659
660            debug!("Using provided nearest neighbor pairs.");
661            let (pair_mn, pair_fp) = generate_pair_no_neighbors(
662                x,
663                n_neighbors,
664                n_mn,
665                n_fp,
666                pair_neighbors.view(),
667                random_state,
668            );
669
670            debug!("Additional pairs sampled successfully.");
671            Ok(Pairs {
672                pair_neighbors,
673                pair_mn,
674                pair_fp,
675            })
676        }
677
678        // Use all provided pairs without additional sampling
679        PairConfiguration::AllProvided {
680            pair_neighbors,
681            pair_mn,
682            pair_fp,
683        } => {
684            debug!("Using all provided pairs.");
685            Ok(Pairs {
686                pair_neighbors,
687                pair_mn,
688                pair_fp,
689            })
690        }
691    }
692}
693
694/// Methods for initializing embedding coordinates.
695enum YInit {
696    /// Use provided coordinate values
697    Value(Array2<f32>),
698
699    /// Use preprocessed data directly
700    Preprocessed,
701
702    /// Apply dimensionality reduction transform
703    DimensionalReduction(Transform),
704
705    /// Initialize randomly with optional seed
706    Random(Option<u64>),
707}
708
709/// Core `PaCMAP` optimization function.
710///
711/// Iteratively updates embedding coordinates through gradient descent to
712/// preserve data structure. Uses phase-based weight schedules to balance local
713/// and global structure preservation.
714///
715/// # Arguments
716/// * `x` - Input data matrix
717/// * `n_dims` - Desired output dimensionality
718/// * `pair_neighbors` - Nearest neighbor pairs
719/// * `pair_mn` - Mid-near pairs
720/// * `pair_fp` - Far pairs
721/// * `lr` - Learning rate
722/// * `num_iters` - Number of iterations for each optimization phase
723/// * `y_initialization` - Method for initializing coordinates
724/// * `inter_snapshots` - Optional indices for saving intermediate states
725///
726/// # Returns
727/// A tuple containing:
728/// * Final embedding coordinates
729/// * Optional intermediate states if snapshots were requested
730///
731/// # Errors
732/// * `PaCMapError::EmptyArrayMean` if mean cannot be calculated
733/// * `PaCMapError::Normal` if random initialization fails
734/// * `PaCMapError::Pca` if PCA transform fails
735#[allow(clippy::too_many_arguments)]
736fn pacmap<'a>(
737    x: ArrayView2<f32>,
738    n_dims: usize,
739    pair_neighbors: ArrayView2<'a, u32>,
740    pair_mn: ArrayView2<'a, u32>,
741    pair_fp: ArrayView2<'a, u32>,
742    lr: f32,
743    num_iters: (usize, usize, usize),
744    y_initialization: YInit,
745    inter_snapshots: Option<&[usize]>,
746) -> Result<(Array2<f32>, Option<Array3<f32>>), PaCMapError> {
747    let start_time = Instant::now();
748    let n = x.nrows();
749    let mut inter_snapshots = Snapshots::from(n_dims, n, inter_snapshots);
750
751    // Initialize embedding coordinates based on specified method
752    let mut y: Array2<f32> = match y_initialization {
753        YInit::Value(mut y) => {
754            let mean = y.mean_axis(Axis(0)).ok_or(PaCMapError::EmptyArrayMean)?;
755            let std = y.std_axis(Axis(0), 0.0);
756
757            // Center and scale provided coordinates
758            Zip::from(&mut y)
759                .and_broadcast(&mean)
760                .and_broadcast(&std)
761                .par_for_each(|y_elem, &mean_elem, &std| {
762                    *y_elem = (*y_elem - mean_elem) * 0.0001 / std;
763                });
764
765            y
766        }
767        YInit::Preprocessed => x.slice(s![.., ..n_dims]).to_owned() * 0.01,
768        YInit::DimensionalReduction(transform) => transform.transform(x)? * 0.01,
769        YInit::Random(maybe_seed) => {
770            let normal = Normal::new(0.0, 1.0)?;
771
772            match maybe_seed {
773                None => Array2::random((n, n_dims), normal),
774                Some(seed) => {
775                    Array2::random_using((n, n_dims), normal, &mut SmallRng::seed_from_u64(seed))
776                }
777            }
778        }
779    };
780
781    // Initialize optimizer parameters
782    let w_mn_init = 1000.0;
783    let beta1 = 0.9;
784    let beta2 = 0.999;
785    let mut m = Array2::zeros(y.dim());
786    let mut v = Array2::zeros(y.dim());
787
788    // Store initial state if snapshots requested
789    if let Some(ref mut snapshots) = inter_snapshots {
790        snapshots.states.slice_mut(s![0_usize, .., ..]).assign(&y);
791    }
792
793    debug!(
794        "Pair shapes: neighbors {:?}, MN {:?}, FP {:?}",
795        pair_neighbors.dim(),
796        pair_mn.dim(),
797        pair_fp.dim()
798    );
799
800    let num_iters_total = num_iters.0 + num_iters.1 + num_iters.2;
801
802    // Main optimization loop
803    for itr in 0..num_iters_total {
804        // Update weights based on phase
805        let weights = find_weights(w_mn_init, itr, num_iters.0, num_iters.1);
806        let grad = pacmap_grad(y.view(), pair_neighbors, pair_mn, pair_fp, &weights);
807
808        let c = grad[(n, 0)];
809        if itr == 0 {
810            debug!("Initial Loss: {}", c);
811        }
812
813        // Update embedding with Adam optimizer
814        update_embedding_adam(
815            y.view_mut(),
816            grad.view(),
817            m.view_mut(),
818            v.view_mut(),
819            beta1,
820            beta2,
821            lr,
822            itr,
823        );
824
825        if (itr + 1) % 10 == 0 {
826            debug!("Iteration: {:4}, Loss: {}", itr + 1, c);
827        }
828
829        // Store intermediate state if requested
830        let Some(ref mut snapshots) = &mut inter_snapshots else {
831            continue;
832        };
833
834        if let Some(index) = snapshots.indices.iter().position(|&x| x == itr + 1) {
835            snapshots.states.slice_mut(s![index, .., ..]).assign(&y);
836        }
837    }
838
839    let elapsed = start_time.elapsed();
840    debug!("Elapsed time: {:.2?}", elapsed);
841
842    Ok((y, inter_snapshots.map(|s| s.states)))
843}
844
845/// Manages intermediate embedding states during optimization.
846struct Snapshots<'a> {
847    /// Stored embedding states
848    states: Array3<f32>,
849
850    /// Indices at which to take snapshots
851    indices: &'a [usize],
852}
853
854impl<'a> Snapshots<'a> {
855    /// Creates new snapshot manager if indices are provided.
856    ///
857    /// # Arguments
858    /// * `n_dims` - Dimensionality of embedding
859    /// * `n` - Number of samples
860    /// * `maybe_snapshots` - Optional snapshot indices
861    fn from(n_dims: usize, n: usize, maybe_snapshots: Option<&'a [usize]>) -> Option<Self> {
862        let snapshots = maybe_snapshots?;
863        Some(Self {
864            states: Array3::zeros((snapshots.len(), n, n_dims)),
865            indices: snapshots,
866        })
867    }
868}
869
870/// Errors that can occur during `PaCMAP` embedding.
871#[derive(Error, Debug)]
872#[non_exhaustive]
873pub enum PaCMapError {
874    /// Input data has 1 or fewer samples
875    #[error("Sample size must be larger than one")]
876    SampleSize,
877
878    /// Provided nearest neighbor pairs have incorrect dimensions
879    #[error("Invalid shape for nearest neighbor pairs. Expected {expected:?}, got {actual:?}")]
880    InvalidNearestNeighborShape {
881        /// Expected shape: [`n_samples` * `n_neighbors`, 2]
882        expected: [usize; 2],
883        /// Actual shape of provided matrix
884        actual: Vec<usize>,
885    },
886
887    /// Mean calculation failed due to empty array
888    #[error("Failed to calculate mean axis: the array is empty")]
889    EmptyArrayMean,
890
891    /// Normal distribution creation failed
892    #[error(transparent)]
893    Normal(#[from] NormalError),
894
895    /// Calculated number of nearest neighbors is less than 1
896    #[error("The number of nearest neighbors can't be less than 1")]
897    InvalidNeighborCount,
898
899    /// Calculated number of far points is less than 1
900    #[error("The number of far points can't be less than 1")]
901    InvalidFarPointCount,
902
903    /// Min/max computation failed due to empty array
904    #[error("Failed to compute min or max of X: the array is empty")]
905    EmptyArrayMinMax,
906
907    /// Data cannot be normalized due to zero range
908    #[error("The range of X is zero (max - min = 0), cannot normalize")]
909    ZeroRange,
910
911    /// PCA decomposition failed
912    #[error(transparent)]
913    Pca(#[from] DecompositionError),
914
915    /// K-nearest neighbors error
916    #[error(transparent)]
917    Neighbors(#[from] KnnError),
918
919    /// PCA transform was not initialized
920    #[error("Unable to perform PCA; transform not initialized")]
921    MissingTransform,
922}