Skip to main content

nabled_ml/
pca.rs

1//! Principal component analysis over ndarray matrices.
2
3use std::fmt;
4
5use nabled_core::scalar::NabledReal;
6use nabled_linalg::svd;
7use ndarray::{Array1, Array2, ArrayView2, Axis, s};
8use num_complex::Complex64;
9
10/// PCA result for ndarray matrices.
11#[derive(Debug, Clone)]
12pub struct NdarrayPCAResult<T: NabledReal> {
13    /// Principal components as rows (`k x features`).
14    pub components:               Array2<T>,
15    /// Explained variance for each retained component.
16    pub explained_variance:       Array1<T>,
17    /// Explained variance ratio for each retained component.
18    pub explained_variance_ratio: Array1<T>,
19    /// Column means used for centering.
20    pub mean:                     Array1<T>,
21    /// Scores (`samples x k`).
22    pub scores:                   Array2<T>,
23}
24
25/// PCA result for complex ndarray matrices.
26#[derive(Debug, Clone)]
27pub struct NdarrayComplexPCAResult {
28    /// Principal components as rows (`k x features`).
29    pub components:               Array2<Complex64>,
30    /// Explained variance for each retained component.
31    pub explained_variance:       Array1<f64>,
32    /// Explained variance ratio for each retained component.
33    pub explained_variance_ratio: Array1<f64>,
34    /// Column means used for centering.
35    pub mean:                     Array1<Complex64>,
36    /// Scores (`samples x k`).
37    pub scores:                   Array2<Complex64>,
38}
39
/// Failure modes reported by the PCA routines in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum PCAError {
    /// The input matrix has no elements.
    EmptyMatrix,
    /// An argument was rejected; the payload describes what was wrong.
    InvalidInput(String),
    /// The underlying singular value decomposition reported an error.
    DecompositionFailed,
}

impl fmt::Display for PCAError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmptyMatrix => f.write_str("Matrix cannot be empty"),
            Self::InvalidInput(detail) => write!(f, "Invalid input: {detail}"),
            Self::DecompositionFailed => f.write_str("PCA decomposition failed"),
        }
    }
}

impl std::error::Error for PCAError {}
62
63fn usize_to_real<T: NabledReal>(value: usize) -> T {
64    let fallback = T::from_u32(u32::MAX).unwrap_or(T::one());
65    T::from_usize(value).unwrap_or(fallback)
66}
67
68fn center_columns<T: NabledReal>(
69    matrix: &ArrayView2<'_, T>,
70) -> Result<(Array2<T>, Array1<T>), PCAError> {
71    if matrix.is_empty() {
72        return Err(PCAError::EmptyMatrix);
73    }
74    let mean = matrix
75        .mean_axis(Axis(0))
76        .ok_or_else(|| PCAError::InvalidInput("failed to compute column means".to_string()))?;
77    let mut centered = matrix.to_owned();
78    for row in 0..matrix.nrows() {
79        for col in 0..matrix.ncols() {
80            centered[[row, col]] -= mean[col];
81        }
82    }
83    Ok((centered, mean))
84}
85
86fn transform_impl<T: NabledReal>(
87    matrix: &ArrayView2<'_, T>,
88    pca: &NdarrayPCAResult<T>,
89) -> Array2<T> {
90    let mut centered = Array2::<T>::zeros((matrix.nrows(), matrix.ncols()));
91    for row in 0..matrix.nrows() {
92        for col in 0..matrix.ncols() {
93            centered[[row, col]] = matrix[[row, col]] - pca.mean[col];
94        }
95    }
96    centered.dot(&pca.components.t())
97}
98
99fn inverse_transform_impl<T: NabledReal>(
100    scores: &ArrayView2<'_, T>,
101    pca: &NdarrayPCAResult<T>,
102) -> Array2<T> {
103    let mut reconstructed = scores.dot(&pca.components);
104    for row in 0..reconstructed.nrows() {
105        for col in 0..reconstructed.ncols() {
106            reconstructed[[row, col]] += pca.mean[col];
107        }
108    }
109    reconstructed
110}
111
112fn center_columns_complex(
113    matrix: &ArrayView2<'_, Complex64>,
114) -> Result<(Array2<Complex64>, Array1<Complex64>), PCAError> {
115    if matrix.is_empty() {
116        return Err(PCAError::EmptyMatrix);
117    }
118    let mut mean = Array1::<Complex64>::zeros(matrix.ncols());
119    for col in 0..matrix.ncols() {
120        let mut sum = Complex64::new(0.0, 0.0);
121        for row in 0..matrix.nrows() {
122            sum += matrix[[row, col]];
123        }
124        mean[col] = sum / usize_to_real::<f64>(matrix.nrows());
125    }
126
127    let mut centered = matrix.to_owned();
128    for row in 0..matrix.nrows() {
129        for col in 0..matrix.ncols() {
130            centered[[row, col]] -= mean[col];
131        }
132    }
133    Ok((centered, mean))
134}
135
136fn transform_complex_impl(
137    matrix: &ArrayView2<'_, Complex64>,
138    pca: &NdarrayComplexPCAResult,
139) -> Array2<Complex64> {
140    let mut centered = Array2::<Complex64>::zeros((matrix.nrows(), matrix.ncols()));
141    for row in 0..matrix.nrows() {
142        for col in 0..matrix.ncols() {
143            centered[[row, col]] = matrix[[row, col]] - pca.mean[col];
144        }
145    }
146
147    let projection = pca.components.t().mapv(|value| value.conj());
148    centered.dot(&projection)
149}
150
151fn inverse_transform_complex_impl(
152    scores: &ArrayView2<'_, Complex64>,
153    pca: &NdarrayComplexPCAResult,
154) -> Array2<Complex64> {
155    let mut reconstructed = scores.dot(&pca.components);
156    for row in 0..reconstructed.nrows() {
157        for col in 0..reconstructed.ncols() {
158            reconstructed[[row, col]] += pca.mean[col];
159        }
160    }
161    reconstructed
162}
163
/// Fits a PCA model to `matrix`, whose rows are samples and columns are features.
///
/// `n_components` limits how many principal components are retained. `None`
/// keeps all `min(samples, features)` of them, and larger requests are clamped
/// to that bound.
///
/// # Errors
/// - [`PCAError::EmptyMatrix`] if `matrix` has no elements.
/// - [`PCAError::InvalidInput`] if `n_components` is `Some(0)`.
/// - [`PCAError::DecompositionFailed`] if the singular value decomposition fails.
#[cfg(feature = "lapack-provider")]
pub fn compute_pca<T>(
    matrix: &Array2<T>,
    n_components: Option<usize>,
) -> Result<NdarrayPCAResult<T>, PCAError>
where
    T: NabledReal + ndarray_linalg::Lapack<Real = T>,
{
    compute_pca_view(&matrix.view(), n_components)
}
178
179/// Compute principal component analysis.
180///
181/// # Errors
182/// Returns an error for invalid input or decomposition failure.
183#[cfg(not(feature = "lapack-provider"))]
184pub fn compute_pca<T: NabledReal>(
185    matrix: &Array2<T>,
186    n_components: Option<usize>,
187) -> Result<NdarrayPCAResult<T>, PCAError> {
188    compute_pca_impl(&matrix.view(), n_components)
189}
190
/// Shared PCA implementation for the LAPACK-backed build.
///
/// Centers the columns of `matrix`, runs an SVD on the centered data, and keeps
/// the leading `keep` rows of the SVD's `vt` factor as components.
///
/// `explained_variance[i]` is `s_i^2 / (n - 1)`, where `n - 1` is clamped to at
/// least one. `explained_variance_ratio[i]` divides it by the variance of *all*
/// `min(samples, features)` components, so the ratios sum to one only when every
/// component is retained.
///
/// # Errors
/// - [`PCAError::EmptyMatrix`] if `matrix` has no elements.
/// - [`PCAError::InvalidInput`] if `n_components` is `Some(0)`.
/// - [`PCAError::DecompositionFailed`] if the SVD fails.
#[cfg(feature = "lapack-provider")]
fn compute_pca_impl<T>(
    matrix: &ArrayView2<'_, T>,
    n_components: Option<usize>,
) -> Result<NdarrayPCAResult<T>, PCAError>
where
    T: NabledReal + ndarray_linalg::Lapack<Real = T>,
{
    let (centered, mean) = center_columns(matrix)?;

    // `center_columns` rejects empty input, so `max_components >= 1` and only an
    // explicit `Some(0)` can make `keep` zero. Validate before the costly SVD.
    let max_components = centered.nrows().min(centered.ncols());
    let keep = n_components.unwrap_or(max_components).min(max_components);
    if keep == 0 {
        return Err(PCAError::InvalidInput("n_components must be greater than 0".to_string()));
    }

    let svd = svd::decompose(&centered).map_err(|_| PCAError::DecompositionFailed)?;

    let components = svd.vt.slice(s![..keep, ..]).to_owned();
    let scores = centered.dot(&components.t());

    let one = T::one();
    // Sample (n - 1) variance, clamped so a single sample never divides by zero.
    let denominator = (usize_to_real::<T>(centered.nrows()) - one).max(one);
    // Compute the variance along every available component, including discarded
    // ones, so the ratio is relative to the total variance of the data rather
    // than to the retained subset.
    let all_variance = Array1::<T>::from_shape_fn(max_components, |i| {
        (svd.singular_values[i] * svd.singular_values[i]) / denominator
    });
    let explained_variance = all_variance.slice(s![..keep]).to_owned();

    let total_variance = all_variance
        .iter()
        .copied()
        .fold(T::zero(), |acc, value| acc + value)
        .max(T::epsilon());
    let explained_variance_ratio = explained_variance.map(|value| *value / total_variance);

    Ok(NdarrayPCAResult { components, explained_variance, explained_variance_ratio, mean, scores })
}
227
228#[cfg(not(feature = "lapack-provider"))]
229fn compute_pca_impl<T: NabledReal>(
230    matrix: &ArrayView2<'_, T>,
231    n_components: Option<usize>,
232) -> Result<NdarrayPCAResult<T>, PCAError> {
233    let (centered, mean) = center_columns(matrix)?;
234    let svd = svd::decompose(&centered).map_err(|_| PCAError::DecompositionFailed)?;
235
236    let max_components = centered.nrows().min(centered.ncols());
237    let keep = n_components.unwrap_or(max_components).min(max_components);
238    if keep == 0 {
239        return Err(PCAError::InvalidInput("n_components must be greater than 0".to_string()));
240    }
241
242    let components = svd.vt.slice(s![..keep, ..]).to_owned();
243    let scores = centered.dot(&components.t());
244
245    let one = T::one();
246    let denominator = (usize_to_real::<T>(centered.nrows()) - one).max(one);
247    let mut explained_variance = Array1::<T>::zeros(keep);
248    for i in 0..keep {
249        explained_variance[i] = (svd.singular_values[i] * svd.singular_values[i]) / denominator;
250    }
251
252    let total_variance = explained_variance
253        .iter()
254        .copied()
255        .fold(T::zero(), |acc, value| acc + value)
256        .max(T::epsilon());
257    let explained_variance_ratio = explained_variance.map(|value| *value / total_variance);
258
259    Ok(NdarrayPCAResult { components, explained_variance, explained_variance_ratio, mean, scores })
260}
261
/// Fits a PCA model to a borrowed matrix view (rows are samples, columns are
/// features).
///
/// Same semantics as `compute_pca`: `n_components` is clamped to
/// `min(samples, features)`, and `None` keeps every component.
///
/// # Errors
/// - [`PCAError::EmptyMatrix`] if `matrix` has no elements.
/// - [`PCAError::InvalidInput`] if `n_components` is `Some(0)`.
/// - [`PCAError::DecompositionFailed`] if the singular value decomposition fails.
#[cfg(feature = "lapack-provider")]
pub fn compute_pca_view<T>(
    matrix: &ArrayView2<'_, T>,
    n_components: Option<usize>,
) -> Result<NdarrayPCAResult<T>, PCAError>
where
    T: NabledReal + ndarray_linalg::Lapack<Real = T>,
{
    let fitted = compute_pca_impl(matrix, n_components)?;
    Ok(fitted)
}
276
277/// Compute principal component analysis from a matrix view.
278///
279/// # Errors
280/// Returns an error for invalid input or decomposition failure.
281#[cfg(not(feature = "lapack-provider"))]
282pub fn compute_pca_view<T: NabledReal>(
283    matrix: &ArrayView2<'_, T>,
284    n_components: Option<usize>,
285) -> Result<NdarrayPCAResult<T>, PCAError> {
286    compute_pca_impl(matrix, n_components)
287}
288
289/// Compute principal component analysis for complex matrices.
290///
291/// # Errors
292/// Returns an error for invalid input or decomposition failure.
293pub fn compute_pca_complex(
294    matrix: &Array2<Complex64>,
295    n_components: Option<usize>,
296) -> Result<NdarrayComplexPCAResult, PCAError> {
297    compute_pca_complex_impl(&matrix.view(), n_components)
298}
299
300fn compute_pca_complex_impl(
301    matrix: &ArrayView2<'_, Complex64>,
302    n_components: Option<usize>,
303) -> Result<NdarrayComplexPCAResult, PCAError> {
304    let (centered, mean) = center_columns_complex(matrix)?;
305    let svd = svd::decompose_complex(&centered).map_err(|_| PCAError::DecompositionFailed)?;
306
307    let max_components = centered.nrows().min(centered.ncols());
308    let keep = n_components.unwrap_or(max_components).min(max_components);
309    if keep == 0 {
310        return Err(PCAError::InvalidInput("n_components must be greater than 0".to_string()));
311    }
312
313    let components = svd.vt.slice(s![..keep, ..]).to_owned();
314    let projection = components.t().mapv(|value| value.conj());
315    let scores = centered.dot(&projection);
316
317    let denominator = (usize_to_real::<f64>(centered.nrows()) - 1.0_f64).max(1.0_f64);
318    let mut explained_variance = Array1::<f64>::zeros(keep);
319    for i in 0..keep {
320        explained_variance[i] = (svd.singular_values[i] * svd.singular_values[i]) / denominator;
321    }
322
323    let total_variance = explained_variance.iter().sum::<f64>().max(f64::EPSILON);
324    let explained_variance_ratio = explained_variance.map(|value| *value / total_variance);
325
326    Ok(NdarrayComplexPCAResult {
327        components,
328        explained_variance,
329        explained_variance_ratio,
330        mean,
331        scores,
332    })
333}
334
335/// Compute principal component analysis for complex matrices from a matrix view.
336///
337/// # Errors
338/// Returns an error for invalid input or decomposition failure.
339pub fn compute_pca_complex_view(
340    matrix: &ArrayView2<'_, Complex64>,
341    n_components: Option<usize>,
342) -> Result<NdarrayComplexPCAResult, PCAError> {
343    compute_pca_complex_impl(matrix, n_components)
344}
345
346/// Project data to PCA score space.
347#[must_use]
348pub fn transform<T: NabledReal>(matrix: &Array2<T>, pca: &NdarrayPCAResult<T>) -> Array2<T> {
349    transform_impl(&matrix.view(), pca)
350}
351
352/// Project data to PCA score space from a matrix view.
353#[must_use]
354pub fn transform_view<T: NabledReal>(
355    matrix: &ArrayView2<'_, T>,
356    pca: &NdarrayPCAResult<T>,
357) -> Array2<T> {
358    transform_impl(matrix, pca)
359}
360
361/// Reconstruct from PCA scores.
362#[must_use]
363pub fn inverse_transform<T: NabledReal>(
364    scores: &Array2<T>,
365    pca: &NdarrayPCAResult<T>,
366) -> Array2<T> {
367    inverse_transform_impl(&scores.view(), pca)
368}
369
370/// Reconstruct from PCA scores provided as a matrix view.
371#[must_use]
372pub fn inverse_transform_view<T: NabledReal>(
373    scores: &ArrayView2<'_, T>,
374    pca: &NdarrayPCAResult<T>,
375) -> Array2<T> {
376    inverse_transform_impl(scores, pca)
377}
378
379/// Project complex data to PCA score space.
380#[must_use]
381pub fn transform_complex(
382    matrix: &Array2<Complex64>,
383    pca: &NdarrayComplexPCAResult,
384) -> Array2<Complex64> {
385    transform_complex_impl(&matrix.view(), pca)
386}
387
388/// Project complex data to PCA score space from a matrix view.
389#[must_use]
390pub fn transform_complex_view(
391    matrix: &ArrayView2<'_, Complex64>,
392    pca: &NdarrayComplexPCAResult,
393) -> Array2<Complex64> {
394    transform_complex_impl(matrix, pca)
395}
396
397/// Reconstruct complex inputs from PCA scores.
398#[must_use]
399pub fn inverse_transform_complex(
400    scores: &Array2<Complex64>,
401    pca: &NdarrayComplexPCAResult,
402) -> Array2<Complex64> {
403    inverse_transform_complex_impl(&scores.view(), pca)
404}
405
406/// Reconstruct complex inputs from PCA scores provided as a matrix view.
407#[must_use]
408pub fn inverse_transform_complex_view(
409    scores: &ArrayView2<'_, Complex64>,
410    pca: &NdarrayComplexPCAResult,
411) -> Array2<Complex64> {
412    inverse_transform_complex_impl(scores, pca)
413}
414
#[cfg(test)]
mod tests {
    use ndarray::Array2;
    use num_complex::Complex64;

    use super::*;

    /// Four samples of two real features, shared by the `f64` tests.
    fn real_sample() -> Array2<f64> {
        Array2::<f64>::from_shape_vec((4, 2), vec![
            1.0_f64, 2.0_f64, 2.0_f64, 1.0_f64, 3.0_f64, 4.0_f64, 4.0_f64, 3.0_f64,
        ])
        .unwrap()
    }

    /// Four samples of two complex features, shared by the complex tests.
    fn complex_sample() -> Array2<Complex64> {
        Array2::from_shape_vec((4, 2), vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.5),
            Complex64::new(2.0, -1.0),
            Complex64::new(1.0, 0.2),
            Complex64::new(3.0, 1.1),
            Complex64::new(4.0, -0.3),
            Complex64::new(4.0, 0.9),
            Complex64::new(3.0, 0.4),
        ])
        .unwrap()
    }

    #[test]
    fn pca_roundtrip_is_consistent() {
        let matrix = real_sample();
        let pca = compute_pca(&matrix, Some(2)).unwrap();
        let transformed = transform(&matrix, &pca);
        let reconstructed = inverse_transform(&transformed, &pca);
        for i in 0..matrix.nrows() {
            for j in 0..matrix.ncols() {
                assert!((matrix[[i, j]] - reconstructed[[i, j]]).abs() < 1e-8_f64);
            }
        }
    }

    #[test]
    fn pca_rejects_zero_components() {
        let matrix = real_sample();
        let result = compute_pca(&matrix, Some(0));
        assert!(matches!(result, Err(PCAError::InvalidInput(_))));
    }

    #[test]
    fn pca_rejects_empty_matrix() {
        let empty = Array2::<f64>::zeros((0, 3));
        assert!(matches!(compute_pca(&empty, None), Err(PCAError::EmptyMatrix)));

        let empty_complex = Array2::<Complex64>::zeros((0, 2));
        assert!(matches!(compute_pca_complex(&empty_complex, None), Err(PCAError::EmptyMatrix)));
    }

    #[test]
    fn pca_clamps_requested_components() {
        let matrix = real_sample();
        let pca = compute_pca(&matrix, Some(10)).unwrap();
        assert_eq!(pca.components.dim(), (2, 2));
        assert_eq!(pca.explained_variance.len(), 2);
        assert_eq!(pca.explained_variance_ratio.len(), 2);
        assert_eq!(pca.scores.dim(), (4, 2));
    }

    #[test]
    fn pca_scores_match_transform_of_training_data() {
        let matrix = real_sample();
        let pca = compute_pca(&matrix, Some(1)).unwrap();
        let transformed = transform(&matrix, &pca);
        assert_eq!(transformed.dim(), pca.scores.dim());
        for (projected, stored) in transformed.iter().zip(pca.scores.iter()) {
            assert!((projected - stored).abs() < 1e-12_f64);
        }
    }

    #[test]
    fn explained_variance_ratio_sums_to_one() {
        let matrix = real_sample();
        let pca = compute_pca(&matrix, Some(2)).unwrap();
        let sum = pca.explained_variance_ratio.iter().sum::<f64>();
        assert!((sum - 1.0_f64).abs() < 1e-10_f64);
    }

    #[test]
    fn pca_view_variants_match_owned() {
        let matrix = real_sample();
        let pca_owned = compute_pca(&matrix, Some(2)).unwrap();
        let pca_view = compute_pca_view(&matrix.view(), Some(2)).unwrap();

        assert_eq!(pca_owned.components.dim(), pca_view.components.dim());
        assert_eq!(pca_owned.scores.dim(), pca_view.scores.dim());

        let transformed_owned = transform(&matrix, &pca_owned);
        let transformed_view = transform_view(&matrix.view(), &pca_owned);
        let reconstructed_owned = inverse_transform(&transformed_owned, &pca_owned);
        let reconstructed_view = inverse_transform_view(&transformed_owned.view(), &pca_owned);

        for i in 0..matrix.nrows() {
            for j in 0..matrix.ncols() {
                assert!((transformed_owned[[i, j]] - transformed_view[[i, j]]).abs() < 1e-12_f64);
                assert!(
                    (reconstructed_owned[[i, j]] - reconstructed_view[[i, j]]).abs() < 1e-12_f64
                );
            }
        }
    }

    #[test]
    fn pca_real_f32_paths_are_consistent() {
        let matrix = Array2::<f32>::from_shape_vec((4, 2), vec![
            1.0_f32, 2.0_f32, 2.0_f32, 1.0_f32, 3.0_f32, 4.0_f32, 4.0_f32, 3.0_f32,
        ])
        .unwrap();
        let pca = compute_pca(&matrix, Some(2)).unwrap();
        let transformed = transform(&matrix, &pca);
        let reconstructed = inverse_transform(&transformed, &pca);

        assert_eq!(pca.components.dim(), (2, 2));
        assert_eq!(pca.explained_variance.len(), 2);
        assert_eq!(pca.explained_variance_ratio.len(), 2);
        for i in 0..matrix.nrows() {
            for j in 0..matrix.ncols() {
                assert!((matrix[[i, j]] - reconstructed[[i, j]]).abs() < 1e-4_f32);
            }
        }
    }

    #[test]
    fn complex_pca_roundtrip_is_consistent() {
        let matrix = complex_sample();
        let pca = compute_pca_complex(&matrix, Some(2)).unwrap();
        let transformed = transform_complex(&matrix, &pca);
        let reconstructed = inverse_transform_complex(&transformed, &pca);
        for i in 0..matrix.nrows() {
            for j in 0..matrix.ncols() {
                assert!((matrix[[i, j]] - reconstructed[[i, j]]).norm() < 1e-8);
            }
        }
    }

    #[test]
    fn complex_pca_view_variants_match_owned() {
        let matrix = complex_sample();
        let pca_owned = compute_pca_complex(&matrix, Some(2)).unwrap();
        let pca_view = compute_pca_complex_view(&matrix.view(), Some(2)).unwrap();
        assert_eq!(pca_owned.components.dim(), pca_view.components.dim());
        assert_eq!(pca_owned.scores.dim(), pca_view.scores.dim());

        let transformed_owned = transform_complex(&matrix, &pca_owned);
        let transformed_view = transform_complex_view(&matrix.view(), &pca_owned);
        let reconstructed_owned = inverse_transform_complex(&transformed_owned, &pca_owned);
        let reconstructed_view =
            inverse_transform_complex_view(&transformed_owned.view(), &pca_owned);

        for i in 0..matrix.nrows() {
            for j in 0..matrix.ncols() {
                assert!((transformed_owned[[i, j]] - transformed_view[[i, j]]).norm() < 1e-12);
                assert!((reconstructed_owned[[i, j]] - reconstructed_view[[i, j]]).norm() < 1e-12);
            }
        }
    }

    #[test]
    fn pca_error_display_messages() {
        assert_eq!(PCAError::EmptyMatrix.to_string(), "Matrix cannot be empty");
        assert_eq!(
            PCAError::InvalidInput("bad".to_string()).to_string(),
            "Invalid input: bad"
        );
        assert_eq!(PCAError::DecompositionFailed.to_string(), "PCA decomposition failed");
    }
}