1use std::fmt;
4
5use nabled_core::scalar::NabledReal;
6use nabled_linalg::svd;
7use ndarray::{Array1, Array2, ArrayView2, Axis, s};
8use num_complex::Complex64;
9
10#[derive(Debug, Clone)]
12pub struct NdarrayPCAResult<T: NabledReal> {
13 pub components: Array2<T>,
15 pub explained_variance: Array1<T>,
17 pub explained_variance_ratio: Array1<T>,
19 pub mean: Array1<T>,
21 pub scores: Array2<T>,
23}
24
25#[derive(Debug, Clone)]
27pub struct NdarrayComplexPCAResult {
28 pub components: Array2<Complex64>,
30 pub explained_variance: Array1<f64>,
32 pub explained_variance_ratio: Array1<f64>,
34 pub mean: Array1<Complex64>,
36 pub scores: Array2<Complex64>,
38}
39
/// Failure modes of the PCA routines in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum PCAError {
    /// The input matrix contains no elements.
    EmptyMatrix,
    /// The request was rejected; the payload explains why.
    InvalidInput(String),
    /// The underlying singular value decomposition did not succeed.
    DecompositionFailed,
}

impl fmt::Display for PCAError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed messages share a single write; only `InvalidInput` needs formatting.
        let fixed = match self {
            Self::EmptyMatrix => "Matrix cannot be empty",
            Self::DecompositionFailed => "PCA decomposition failed",
            Self::InvalidInput(message) => return write!(f, "Invalid input: {message}"),
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for PCAError {}
62
63fn usize_to_real<T: NabledReal>(value: usize) -> T {
64 let fallback = T::from_u32(u32::MAX).unwrap_or(T::one());
65 T::from_usize(value).unwrap_or(fallback)
66}
67
68fn center_columns<T: NabledReal>(
69 matrix: &ArrayView2<'_, T>,
70) -> Result<(Array2<T>, Array1<T>), PCAError> {
71 if matrix.is_empty() {
72 return Err(PCAError::EmptyMatrix);
73 }
74 let mean = matrix
75 .mean_axis(Axis(0))
76 .ok_or_else(|| PCAError::InvalidInput("failed to compute column means".to_string()))?;
77 let mut centered = matrix.to_owned();
78 for row in 0..matrix.nrows() {
79 for col in 0..matrix.ncols() {
80 centered[[row, col]] -= mean[col];
81 }
82 }
83 Ok((centered, mean))
84}
85
86fn transform_impl<T: NabledReal>(
87 matrix: &ArrayView2<'_, T>,
88 pca: &NdarrayPCAResult<T>,
89) -> Array2<T> {
90 let mut centered = Array2::<T>::zeros((matrix.nrows(), matrix.ncols()));
91 for row in 0..matrix.nrows() {
92 for col in 0..matrix.ncols() {
93 centered[[row, col]] = matrix[[row, col]] - pca.mean[col];
94 }
95 }
96 centered.dot(&pca.components.t())
97}
98
99fn inverse_transform_impl<T: NabledReal>(
100 scores: &ArrayView2<'_, T>,
101 pca: &NdarrayPCAResult<T>,
102) -> Array2<T> {
103 let mut reconstructed = scores.dot(&pca.components);
104 for row in 0..reconstructed.nrows() {
105 for col in 0..reconstructed.ncols() {
106 reconstructed[[row, col]] += pca.mean[col];
107 }
108 }
109 reconstructed
110}
111
112fn center_columns_complex(
113 matrix: &ArrayView2<'_, Complex64>,
114) -> Result<(Array2<Complex64>, Array1<Complex64>), PCAError> {
115 if matrix.is_empty() {
116 return Err(PCAError::EmptyMatrix);
117 }
118 let mut mean = Array1::<Complex64>::zeros(matrix.ncols());
119 for col in 0..matrix.ncols() {
120 let mut sum = Complex64::new(0.0, 0.0);
121 for row in 0..matrix.nrows() {
122 sum += matrix[[row, col]];
123 }
124 mean[col] = sum / usize_to_real::<f64>(matrix.nrows());
125 }
126
127 let mut centered = matrix.to_owned();
128 for row in 0..matrix.nrows() {
129 for col in 0..matrix.ncols() {
130 centered[[row, col]] -= mean[col];
131 }
132 }
133 Ok((centered, mean))
134}
135
136fn transform_complex_impl(
137 matrix: &ArrayView2<'_, Complex64>,
138 pca: &NdarrayComplexPCAResult,
139) -> Array2<Complex64> {
140 let mut centered = Array2::<Complex64>::zeros((matrix.nrows(), matrix.ncols()));
141 for row in 0..matrix.nrows() {
142 for col in 0..matrix.ncols() {
143 centered[[row, col]] = matrix[[row, col]] - pca.mean[col];
144 }
145 }
146
147 let projection = pca.components.t().mapv(|value| value.conj());
148 centered.dot(&projection)
149}
150
151fn inverse_transform_complex_impl(
152 scores: &ArrayView2<'_, Complex64>,
153 pca: &NdarrayComplexPCAResult,
154) -> Array2<Complex64> {
155 let mut reconstructed = scores.dot(&pca.components);
156 for row in 0..reconstructed.nrows() {
157 for col in 0..reconstructed.ncols() {
158 reconstructed[[row, col]] += pca.mean[col];
159 }
160 }
161 reconstructed
162}
163
/// Runs PCA on an owned real matrix whose rows are samples and columns are features.
///
/// `n_components` selects how many leading components to keep. `None` keeps
/// `min(n_rows, n_cols)`, and larger requests are clamped to that bound.
///
/// # Errors
/// Returns a [`PCAError`] for an empty matrix, a request for zero components, or a
/// failed SVD.
#[cfg(feature = "lapack-provider")]
pub fn compute_pca<T>(
    matrix: &Array2<T>,
    n_components: Option<usize>,
) -> Result<NdarrayPCAResult<T>, PCAError>
where
    T: NabledReal + ndarray_linalg::Lapack<Real = T>,
{
    let view = matrix.view();
    compute_pca_impl(&view, n_components)
}
178
179#[cfg(not(feature = "lapack-provider"))]
184pub fn compute_pca<T: NabledReal>(
185 matrix: &Array2<T>,
186 n_components: Option<usize>,
187) -> Result<NdarrayPCAResult<T>, PCAError> {
188 compute_pca_impl(&matrix.view(), n_components)
189}
190
/// Shared real-valued PCA implementation (LAPACK-backed build).
///
/// Centers the columns of `matrix`, takes the SVD of the centered data, and keeps the
/// leading `n_components` rows of `V^T` (`None` keeps `min(n_rows, n_cols)`; larger
/// requests are clamped to that bound).
///
/// `explained_variance[i]` is `sigma_i^2 / max(n_rows - 1, 1)`.
/// `explained_variance_ratio[i]` divides that by the variance summed over *all*
/// singular values, so the ratios sum to one only when every component is kept.
///
/// # Errors
/// * `PCAError::EmptyMatrix` if `matrix` has no elements.
/// * `PCAError::InvalidInput` if `n_components == Some(0)`.
/// * `PCAError::DecompositionFailed` if the SVD fails.
#[cfg(feature = "lapack-provider")]
fn compute_pca_impl<T>(
    matrix: &ArrayView2<'_, T>,
    n_components: Option<usize>,
) -> Result<NdarrayPCAResult<T>, PCAError>
where
    T: NabledReal + ndarray_linalg::Lapack<Real = T>,
{
    let (centered, mean) = center_columns(matrix)?;

    // Reject a zero-component request before paying for the decomposition.
    // (A non-empty matrix always has max_components >= 1, so this is the only way to
    // end up keeping nothing.)
    if n_components == Some(0) {
        return Err(PCAError::InvalidInput("n_components must be greater than 0".to_string()));
    }

    let svd = svd::decompose(&centered).map_err(|_| PCAError::DecompositionFailed)?;

    let max_components = centered.nrows().min(centered.ncols());
    let keep = n_components.unwrap_or(max_components).min(max_components);

    let components = svd.vt.slice(s![..keep, ..]).to_owned();
    let scores = centered.dot(&components.t());

    let one = T::one();
    // Sample (n - 1) normalization, floored at 1 so a single row does not divide by zero.
    let denominator = (usize_to_real::<T>(centered.nrows()) - one).max(one);
    let explained_variance = Array1::<T>::from_shape_fn(keep, |i| {
        (svd.singular_values[i] * svd.singular_values[i]) / denominator
    });

    // Normalize by the variance of *every* component. Summing only the retained ones
    // would make the ratios add up to one no matter how much variance was discarded.
    let total_variance = svd
        .singular_values
        .iter()
        .fold(T::zero(), |acc, &sigma| acc + (sigma * sigma) / denominator)
        .max(T::epsilon());
    let explained_variance_ratio = explained_variance.map(|value| *value / total_variance);

    Ok(NdarrayPCAResult { components, explained_variance, explained_variance_ratio, mean, scores })
}
227
228#[cfg(not(feature = "lapack-provider"))]
229fn compute_pca_impl<T: NabledReal>(
230 matrix: &ArrayView2<'_, T>,
231 n_components: Option<usize>,
232) -> Result<NdarrayPCAResult<T>, PCAError> {
233 let (centered, mean) = center_columns(matrix)?;
234 let svd = svd::decompose(¢ered).map_err(|_| PCAError::DecompositionFailed)?;
235
236 let max_components = centered.nrows().min(centered.ncols());
237 let keep = n_components.unwrap_or(max_components).min(max_components);
238 if keep == 0 {
239 return Err(PCAError::InvalidInput("n_components must be greater than 0".to_string()));
240 }
241
242 let components = svd.vt.slice(s![..keep, ..]).to_owned();
243 let scores = centered.dot(&components.t());
244
245 let one = T::one();
246 let denominator = (usize_to_real::<T>(centered.nrows()) - one).max(one);
247 let mut explained_variance = Array1::<T>::zeros(keep);
248 for i in 0..keep {
249 explained_variance[i] = (svd.singular_values[i] * svd.singular_values[i]) / denominator;
250 }
251
252 let total_variance = explained_variance
253 .iter()
254 .copied()
255 .fold(T::zero(), |acc, value| acc + value)
256 .max(T::epsilon());
257 let explained_variance_ratio = explained_variance.map(|value| *value / total_variance);
258
259 Ok(NdarrayPCAResult { components, explained_variance, explained_variance_ratio, mean, scores })
260}
261
/// Runs PCA on a borrowed real matrix view (rows = samples, columns = features).
///
/// Identical to `compute_pca` but avoids requiring an owned `Array2`.
///
/// # Errors
/// Returns a [`PCAError`] for an empty matrix, a request for zero components, or a
/// failed SVD.
#[cfg(feature = "lapack-provider")]
pub fn compute_pca_view<T>(
    matrix: &ArrayView2<'_, T>,
    n_components: Option<usize>,
) -> Result<NdarrayPCAResult<T>, PCAError>
where
    T: ndarray_linalg::Lapack<Real = T> + NabledReal,
{
    compute_pca_impl(matrix, n_components)
}
276
277#[cfg(not(feature = "lapack-provider"))]
282pub fn compute_pca_view<T: NabledReal>(
283 matrix: &ArrayView2<'_, T>,
284 n_components: Option<usize>,
285) -> Result<NdarrayPCAResult<T>, PCAError> {
286 compute_pca_impl(matrix, n_components)
287}
288
289pub fn compute_pca_complex(
294 matrix: &Array2<Complex64>,
295 n_components: Option<usize>,
296) -> Result<NdarrayComplexPCAResult, PCAError> {
297 compute_pca_complex_impl(&matrix.view(), n_components)
298}
299
300fn compute_pca_complex_impl(
301 matrix: &ArrayView2<'_, Complex64>,
302 n_components: Option<usize>,
303) -> Result<NdarrayComplexPCAResult, PCAError> {
304 let (centered, mean) = center_columns_complex(matrix)?;
305 let svd = svd::decompose_complex(¢ered).map_err(|_| PCAError::DecompositionFailed)?;
306
307 let max_components = centered.nrows().min(centered.ncols());
308 let keep = n_components.unwrap_or(max_components).min(max_components);
309 if keep == 0 {
310 return Err(PCAError::InvalidInput("n_components must be greater than 0".to_string()));
311 }
312
313 let components = svd.vt.slice(s![..keep, ..]).to_owned();
314 let projection = components.t().mapv(|value| value.conj());
315 let scores = centered.dot(&projection);
316
317 let denominator = (usize_to_real::<f64>(centered.nrows()) - 1.0_f64).max(1.0_f64);
318 let mut explained_variance = Array1::<f64>::zeros(keep);
319 for i in 0..keep {
320 explained_variance[i] = (svd.singular_values[i] * svd.singular_values[i]) / denominator;
321 }
322
323 let total_variance = explained_variance.iter().sum::<f64>().max(f64::EPSILON);
324 let explained_variance_ratio = explained_variance.map(|value| *value / total_variance);
325
326 Ok(NdarrayComplexPCAResult {
327 components,
328 explained_variance,
329 explained_variance_ratio,
330 mean,
331 scores,
332 })
333}
334
335pub fn compute_pca_complex_view(
340 matrix: &ArrayView2<'_, Complex64>,
341 n_components: Option<usize>,
342) -> Result<NdarrayComplexPCAResult, PCAError> {
343 compute_pca_complex_impl(matrix, n_components)
344}
345
346#[must_use]
348pub fn transform<T: NabledReal>(matrix: &Array2<T>, pca: &NdarrayPCAResult<T>) -> Array2<T> {
349 transform_impl(&matrix.view(), pca)
350}
351
352#[must_use]
354pub fn transform_view<T: NabledReal>(
355 matrix: &ArrayView2<'_, T>,
356 pca: &NdarrayPCAResult<T>,
357) -> Array2<T> {
358 transform_impl(matrix, pca)
359}
360
361#[must_use]
363pub fn inverse_transform<T: NabledReal>(
364 scores: &Array2<T>,
365 pca: &NdarrayPCAResult<T>,
366) -> Array2<T> {
367 inverse_transform_impl(&scores.view(), pca)
368}
369
370#[must_use]
372pub fn inverse_transform_view<T: NabledReal>(
373 scores: &ArrayView2<'_, T>,
374 pca: &NdarrayPCAResult<T>,
375) -> Array2<T> {
376 inverse_transform_impl(scores, pca)
377}
378
379#[must_use]
381pub fn transform_complex(
382 matrix: &Array2<Complex64>,
383 pca: &NdarrayComplexPCAResult,
384) -> Array2<Complex64> {
385 transform_complex_impl(&matrix.view(), pca)
386}
387
388#[must_use]
390pub fn transform_complex_view(
391 matrix: &ArrayView2<'_, Complex64>,
392 pca: &NdarrayComplexPCAResult,
393) -> Array2<Complex64> {
394 transform_complex_impl(matrix, pca)
395}
396
397#[must_use]
399pub fn inverse_transform_complex(
400 scores: &Array2<Complex64>,
401 pca: &NdarrayComplexPCAResult,
402) -> Array2<Complex64> {
403 inverse_transform_complex_impl(&scores.view(), pca)
404}
405
406#[must_use]
408pub fn inverse_transform_complex_view(
409 scores: &ArrayView2<'_, Complex64>,
410 pca: &NdarrayComplexPCAResult,
411) -> Array2<Complex64> {
412 inverse_transform_complex_impl(scores, pca)
413}
414
#[cfg(test)]
mod tests {
    use ndarray::Array2;
    use num_complex::Complex64;

    use super::*;

    /// 4 x 2 real fixture of full column rank, shared by the real-valued tests.
    fn real_fixture() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 2.0, 1.0, 3.0, 4.0, 4.0, 3.0]).unwrap()
    }

    /// 4 x 2 complex fixture shared by the complex-valued tests.
    fn complex_fixture() -> Array2<Complex64> {
        Array2::from_shape_vec((4, 2), vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.5),
            Complex64::new(2.0, -1.0),
            Complex64::new(1.0, 0.2),
            Complex64::new(3.0, 1.1),
            Complex64::new(4.0, -0.3),
            Complex64::new(4.0, 0.9),
            Complex64::new(3.0, 0.4),
        ])
        .unwrap()
    }

    /// Asserts equal shapes and element-wise agreement within `tol`.
    fn assert_close_f64(actual: &Array2<f64>, expected: &Array2<f64>, tol: f64) {
        assert_eq!(actual.dim(), expected.dim());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < tol, "{a} differs from {e} by at least {tol}");
        }
    }

    /// `f32` counterpart of [`assert_close_f64`].
    fn assert_close_f32(actual: &Array2<f32>, expected: &Array2<f32>, tol: f32) {
        assert_eq!(actual.dim(), expected.dim());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < tol, "{a} differs from {e} by at least {tol}");
        }
    }

    /// Complex counterpart of [`assert_close_f64`], measuring distance with `norm`.
    fn assert_close_complex(actual: &Array2<Complex64>, expected: &Array2<Complex64>, tol: f64) {
        assert_eq!(actual.dim(), expected.dim());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).norm() < tol, "{a} differs from {e} by at least {tol}");
        }
    }

    #[test]
    fn pca_roundtrip_is_consistent() {
        let matrix = real_fixture();
        let pca = compute_pca(&matrix, Some(2)).unwrap();
        let reconstructed = inverse_transform(&transform(&matrix, &pca), &pca);
        assert_close_f64(&reconstructed, &matrix, 1e-8);
    }

    #[test]
    fn pca_rejects_zero_components() {
        let result = compute_pca(&real_fixture(), Some(0));
        assert!(matches!(result, Err(PCAError::InvalidInput(_))));
    }

    #[test]
    fn pca_rejects_empty_matrix() {
        let empty = Array2::<f64>::zeros((0, 3));
        assert!(matches!(compute_pca(&empty, None), Err(PCAError::EmptyMatrix)));
        assert!(matches!(compute_pca(&empty, Some(0)), Err(PCAError::EmptyMatrix)));
    }

    #[test]
    fn pca_clamps_and_defaults_component_count() {
        let matrix = real_fixture();
        let clamped = compute_pca(&matrix, Some(10)).unwrap();
        assert_eq!(clamped.components.dim(), (2, 2));
        assert_eq!(clamped.scores.dim(), (4, 2));

        let all = compute_pca(&matrix, None).unwrap();
        assert_eq!(all.components.dim(), (2, 2));
        assert_eq!(all.explained_variance.len(), 2);
    }

    #[test]
    fn explained_variance_ratio_sums_to_one() {
        let pca = compute_pca(&real_fixture(), Some(2)).unwrap();
        let sum = pca.explained_variance_ratio.iter().sum::<f64>();
        assert!((sum - 1.0_f64).abs() < 1e-10_f64);
    }

    #[test]
    fn pca_view_variants_match_owned() {
        let matrix = real_fixture();
        let pca_owned = compute_pca(&matrix, Some(2)).unwrap();
        let pca_view = compute_pca_view(&matrix.view(), Some(2)).unwrap();

        assert_eq!(pca_owned.components.dim(), pca_view.components.dim());
        assert_eq!(pca_owned.scores.dim(), pca_view.scores.dim());

        let transformed_owned = transform(&matrix, &pca_owned);
        let transformed_view = transform_view(&matrix.view(), &pca_owned);
        assert_close_f64(&transformed_view, &transformed_owned, 1e-12);

        let reconstructed_owned = inverse_transform(&transformed_owned, &pca_owned);
        let reconstructed_view = inverse_transform_view(&transformed_owned.view(), &pca_owned);
        assert_close_f64(&reconstructed_view, &reconstructed_owned, 1e-12);
    }

    #[test]
    fn pca_real_f32_paths_are_consistent() {
        // Small integers convert to f32 exactly, so this matches the f64 fixture.
        let matrix = real_fixture().mapv(|value| value as f32);
        let pca = compute_pca(&matrix, Some(2)).unwrap();
        let reconstructed = inverse_transform(&transform(&matrix, &pca), &pca);

        assert_eq!(pca.components.dim(), (2, 2));
        assert_eq!(pca.explained_variance.len(), 2);
        assert_eq!(pca.explained_variance_ratio.len(), 2);
        assert_close_f32(&reconstructed, &matrix, 1e-4);
    }

    #[test]
    fn complex_pca_roundtrip_is_consistent() {
        let matrix = complex_fixture();
        let pca = compute_pca_complex(&matrix, Some(2)).unwrap();
        let reconstructed = inverse_transform_complex(&transform_complex(&matrix, &pca), &pca);
        assert_close_complex(&reconstructed, &matrix, 1e-8);
    }

    #[test]
    fn complex_pca_rejects_zero_components_and_empty_input() {
        let result = compute_pca_complex(&complex_fixture(), Some(0));
        assert!(matches!(result, Err(PCAError::InvalidInput(_))));

        let empty = Array2::<Complex64>::zeros((0, 2));
        assert!(matches!(compute_pca_complex(&empty, None), Err(PCAError::EmptyMatrix)));
    }

    #[test]
    fn complex_pca_view_variants_match_owned() {
        let matrix = complex_fixture();
        let pca_owned = compute_pca_complex(&matrix, Some(2)).unwrap();
        let pca_view = compute_pca_complex_view(&matrix.view(), Some(2)).unwrap();
        assert_eq!(pca_owned.components.dim(), pca_view.components.dim());
        assert_eq!(pca_owned.scores.dim(), pca_view.scores.dim());

        let transformed_owned = transform_complex(&matrix, &pca_owned);
        let transformed_view = transform_complex_view(&matrix.view(), &pca_owned);
        assert_close_complex(&transformed_view, &transformed_owned, 1e-12);

        let reconstructed_owned = inverse_transform_complex(&transformed_owned, &pca_owned);
        let reconstructed_view =
            inverse_transform_complex_view(&transformed_owned.view(), &pca_owned);
        assert_close_complex(&reconstructed_view, &reconstructed_owned, 1e-12);
    }
}