1use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
8use scirs2_core::validation::*;
9
10#[derive(Debug, Clone)]
12pub struct PCA {
13 pub n_components: Option<usize>,
15 pub svd_solver: SvdSolver,
17 pub center: bool,
19 pub scale: bool,
21 pub random_state: Option<u64>,
23}
24
/// Strategy used by `PCA::fit` to compute the singular value decomposition.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SvdSolver {
    /// Exact SVD of the whole preprocessed data matrix.
    Full,
    /// Randomized range-finder SVD with power iterations: approximate, but
    /// cheaper when only a few components are requested.
    Randomized,
    /// Picks `Randomized` when both dimensions are at least 500 and fewer than
    /// half of the possible components are requested, `Full` otherwise.
    Auto,
}
35
36#[derive(Debug, Clone)]
38pub struct PCAResult {
39 pub components: Array2<f64>,
41 pub explained_variance: Array1<f64>,
43 pub explained_variance_ratio: Array1<f64>,
45 pub singular_values: Array1<f64>,
47 pub mean: Array1<f64>,
49 pub scale: Option<Array1<f64>>,
51 pub n_samples_: usize,
53 pub n_features: usize,
55}
56
57impl Default for PCA {
58 fn default() -> Self {
59 Self {
60 n_components: None,
61 svd_solver: SvdSolver::Auto,
62 center: true,
63 scale: false,
64 random_state: None,
65 }
66 }
67}
68
69impl PCA {
70 pub fn new() -> Self {
72 Self::default()
73 }
74
75 pub fn with_n_components(mut self, n_components: usize) -> Self {
77 self.n_components = Some(n_components);
78 self
79 }
80
81 pub fn with_svd_solver(mut self, solver: SvdSolver) -> Self {
83 self.svd_solver = solver;
84 self
85 }
86
87 pub fn with_center(mut self, center: bool) -> Self {
89 self.center = center;
90 self
91 }
92
93 pub fn with_scale(mut self, scale: bool) -> Self {
95 self.scale = scale;
96 self
97 }
98
99 pub fn with_random_state(mut self, seed: u64) -> Self {
101 self.random_state = Some(seed);
102 self
103 }
104
105 pub fn fit(&self, data: ArrayView2<f64>) -> Result<PCAResult> {
107 checkarray_finite(&data, "data")?;
108 let (n_samples, n_features) = data.dim();
109 if n_samples < 2 {
110 return Err(StatsError::InvalidArgument(
111 "n_samples must be at least 2".to_string(),
112 ));
113 }
114 if n_features < 1 {
115 return Err(StatsError::InvalidArgument(
116 "n_features must be at least 1".to_string(),
117 ));
118 }
119
120 let max_components = n_samples.min(n_features);
122 let n_components = match self.n_components {
123 Some(k) => {
124 check_positive(k, "n_components")?;
125 if k > max_components {
126 return Err(StatsError::InvalidArgument(format!(
127 "n_components ({}) cannot be larger than min(n_samples, n_features) = {}",
128 k, max_components
129 )));
130 }
131 k
132 }
133 None => max_components,
134 };
135
136 let mean = if self.center {
138 data.mean_axis(Axis(0)).unwrap()
139 } else {
140 Array1::zeros(n_features)
141 };
142
143 let mut centereddata = data.to_owned();
144 if self.center {
145 for mut row in centereddata.rows_mut() {
146 row -= &mean;
147 }
148 }
149
150 let scale = if self.scale {
152 let std = centereddata.std_axis(Axis(0), 1.0);
153 let std = std.mapv(|s| if s > 1e-10 { s } else { 1.0 });
155
156 for (mut col, &s) in centereddata.columns_mut().into_iter().zip(std.iter()) {
157 col /= s;
158 }
159 Some(std)
160 } else {
161 None
162 };
163
164 let solver = match self.svd_solver {
166 SvdSolver::Auto => {
167 if n_samples >= 500 && n_features >= 500 && n_components < max_components / 2 {
168 SvdSolver::Randomized
169 } else {
170 SvdSolver::Full
171 }
172 }
173 solver => solver,
174 };
175
176 let result = match solver {
178 SvdSolver::Full => self.pca_svd(¢ereddata, n_components, n_samples)?,
179 SvdSolver::Randomized => self.pca_randomized(¢ereddata, n_components, n_samples)?,
180 _ => unreachable!(),
181 };
182
183 Ok(PCAResult {
184 components: result.0,
185 explained_variance: result.1,
186 explained_variance_ratio: result.2,
187 singular_values: result.3,
188 mean,
189 scale,
190 n_samples_: n_samples,
191 n_features,
192 })
193 }
194
195 fn pca_svd(
197 &self,
198 data: &Array2<f64>,
199 n_components: usize,
200 n_samples: usize,
201 ) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>)> {
202 use scirs2_core::ndarray::ndarray_linalg::SVD;
203
204 let (_u, s, vt) = data
206 .svd(true, true)
207 .map_err(|e| StatsError::ComputationError(format!("SVD failed: {}", e)))?;
208 let v = vt.unwrap().t().to_owned();
209
210 let components = v
212 .slice(scirs2_core::ndarray::s![.., ..n_components])
213 .to_owned();
214
215 let singular_values = s.slice(scirs2_core::ndarray::s![..n_components]).to_owned();
217 let explained_variance = &singular_values * &singular_values / (n_samples - 1) as f64;
218
219 let total_variance = explained_variance.sum();
221 let explained_variance_ratio = &explained_variance / total_variance;
222
223 Ok((
224 components.t().to_owned(),
225 explained_variance,
226 explained_variance_ratio,
227 singular_values,
228 ))
229 }
230
231 fn pca_randomized(
233 &self,
234 data: &Array2<f64>,
235 n_components: usize,
236 n_samples: usize,
237 ) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>)> {
238 use scirs2_core::ndarray::ndarray_linalg::{QR, SVD};
239 use scirs2_core::random::{rngs::StdRng, SeedableRng};
240 use scirs2_core::random::{Distribution, Normal};
241
242 let n_features = data.ncols();
243 let n_oversamples = 10.min((n_features - n_components) / 2);
244 let n_random = n_components + n_oversamples;
245
246 let mut rng = match self.random_state {
248 Some(seed) => StdRng::seed_from_u64(seed),
249 None => {
250 use std::time::{SystemTime, UNIX_EPOCH};
252 let seed = SystemTime::now()
253 .duration_since(UNIX_EPOCH)
254 .unwrap_or_default()
255 .as_secs();
256 StdRng::seed_from_u64(seed)
257 }
258 };
259
260 let normal = Normal::new(0.0, 1.0).map_err(|e| {
262 StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
263 })?;
264 let omega = Array2::from_shape_fn((n_features, n_random), |_| normal.sample(&mut rng));
265
266 let n_iter = 4;
268 let mut q = data.dot(&omega);
269
270 for _ in 0..n_iter {
271 let (q_mat, r) = q.qr().map_err(|e| {
273 StatsError::ComputationError(format!("QR decomposition failed: {}", e))
274 })?;
275 q = q_mat;
276
277 let z = data.t().dot(&q);
279 let (q_mat, r) = z.qr().map_err(|e| {
280 StatsError::ComputationError(format!("QR decomposition failed: {}", e))
281 })?;
282 q = data.dot(&q_mat);
283 }
284
285 let (q_final, r) = q.qr().map_err(|e| {
287 StatsError::ComputationError(format!("Final QR decomposition failed: {}", e))
288 })?;
289
290 let b = q_final.t().dot(data);
292
293 let (_u_small, s, vt) = b.svd(true, true).map_err(|e| {
295 StatsError::ComputationError(format!("SVD of projected matrix failed: {}", e))
296 })?;
297
298 let v = vt.unwrap().t().to_owned();
299
300 let components = v
302 .slice(scirs2_core::ndarray::s![.., ..n_components])
303 .to_owned();
304
305 let singular_values = s.slice(scirs2_core::ndarray::s![..n_components]).to_owned();
307 let explained_variance = &singular_values * &singular_values / (n_samples - 1) as f64;
308
309 let total_variance = explained_variance.sum();
311 let explained_variance_ratio = &explained_variance / total_variance;
312
313 Ok((
314 components.t().to_owned(),
315 explained_variance,
316 explained_variance_ratio,
317 singular_values,
318 ))
319 }
320
321 pub fn transform(&self, data: ArrayView2<f64>, result: &PCAResult) -> Result<Array2<f64>> {
323 checkarray_finite(&data, "data")?;
324 if data.ncols() != result.n_features {
325 return Err(StatsError::DimensionMismatch(format!(
326 "data has {} features, expected {}",
327 data.ncols(),
328 result.n_features
329 )));
330 }
331
332 let mut transformed = data.to_owned();
333
334 if self.center {
336 for mut row in transformed.rows_mut() {
337 row -= &result.mean;
338 }
339 }
340
341 if let Some(ref scale) = result.scale {
343 for (mut col, &s) in transformed.columns_mut().into_iter().zip(scale.iter()) {
344 col /= s;
345 }
346 }
347
348 Ok(transformed.dot(&result.components.t()))
350 }
351
352 pub fn inverse_transform(
354 &self,
355 data: ArrayView2<f64>,
356 result: &PCAResult,
357 ) -> Result<Array2<f64>> {
358 checkarray_finite(&data, "data")?;
359 let n_components = result.components.nrows();
360 if data.ncols() != n_components {
361 return Err(StatsError::DimensionMismatch(format!(
362 "data has {} components, expected {}",
363 data.ncols(),
364 n_components
365 )));
366 }
367
368 let mut reconstructed = data.dot(&result.components);
370
371 if let Some(ref scale) = result.scale {
373 for (mut col, &s) in reconstructed.columns_mut().into_iter().zip(scale.iter()) {
374 col *= s;
375 }
376 }
377
378 if self.center {
380 for mut row in reconstructed.rows_mut() {
381 row += &result.mean;
382 }
383 }
384
385 Ok(reconstructed)
386 }
387
388 pub fn fit_transform(&self, data: ArrayView2<f64>) -> Result<(Array2<f64>, PCAResult)> {
390 let result = self.fit(data)?;
391 let transformed = self.transform(data, &result)?;
392 Ok((transformed, result))
393 }
394}
395
396#[allow(dead_code)]
398pub fn mle_components(data: ArrayView2<f64>, maxcomponents: Option<usize>) -> Result<usize> {
399 checkarray_finite(&data, "data")?;
400 let (n_samples, n_features) = data.dim();
401
402 let pca = PCA::new().with_n_components(maxcomponents.unwrap_or(n_features.min(n_samples)));
403 let result = pca.fit(data)?;
404
405 let eigenvalues = &result.explained_variance;
406 let n = n_samples as f64;
407 let p = n_features as f64;
408
409 let mut best_k = 0;
411 let mut best_ll = f64::NEG_INFINITY;
412
413 for k in 0..eigenvalues.len() {
414 let k_f64 = k as f64;
415
416 let sigma2 = if k < eigenvalues.len() - 1 {
418 eigenvalues.slice(scirs2_core::ndarray::s![k + 1..]).sum() / (p - k_f64 - 1.0)
419 } else {
420 1e-10
421 };
422
423 let ll = -n / 2.0
425 * (eigenvalues
426 .slice(scirs2_core::ndarray::s![..=k])
427 .mapv(f64::ln)
428 .sum()
429 + (p - k_f64 - 1.0) * sigma2.ln()
430 + p * (2.0 * std::f64::consts::PI).ln());
431
432 let aic_penalty = k_f64 * (2.0 * p - k_f64 - 1.0);
434 let aic = ll - aic_penalty;
435
436 if aic > best_ll {
437 best_ll = aic;
438 best_k = k + 1;
439 }
440 }
441
442 Ok(best_k)
443}
444
445#[derive(Debug, Clone)]
447pub struct IncrementalPCA {
448 pub pca: PCA,
450 pub batchsize: usize,
452 mean: Option<Array1<f64>>,
454 components: Option<Array2<f64>>,
456 singular_values: Option<Array1<f64>>,
458 n_samples_seen: usize,
460 svd_u: Option<Array2<f64>>,
462 svd_s: Option<Array1<f64>>,
463 svd_v: Option<Array2<f64>>,
464}
465
466impl IncrementalPCA {
467 pub fn new(n_components: usize, batchsize: usize) -> Result<Self> {
469 check_positive(n_components, "n_components")?;
470 check_positive(batchsize, "batchsize")?;
471
472 Ok(Self {
473 pca: PCA::new().with_n_components(n_components),
474 batchsize,
475 mean: None,
476 components: None,
477 singular_values: None,
478 n_samples_seen: 0,
479 svd_u: None,
480 svd_s: None,
481 svd_v: None,
482 })
483 }
484
485 pub fn partial_fit(&mut self, batch: ArrayView2<f64>) -> Result<()> {
487 checkarray_finite(&batch, "batch")?;
488 let (batchsize, n_features) = batch.dim();
489
490 let batch_mean = batch.mean_axis(Axis(0)).unwrap();
492 let old_n = self.n_samples_seen;
493 self.n_samples_seen += batchsize;
494
495 self.mean = match &self.mean {
496 None => Some(batch_mean.clone()),
497 Some(mean) => {
498 let updated = (mean * old_n as f64 + &batch_mean * batchsize as f64)
499 / self.n_samples_seen as f64;
500 Some(updated)
501 }
502 };
503
504 let mut centered_batch = batch.to_owned();
506 for mut row in centered_batch.rows_mut() {
507 row -= &batch_mean;
508 }
509
510 let n_components = self
512 .pca
513 .n_components
514 .unwrap_or(n_features.min(self.n_samples_seen));
515
516 if self.svd_u.is_none() {
517 use scirs2_core::ndarray::ndarray_linalg::SVD;
519 let (u, s, vt) = centered_batch
520 .svd(true, true)
521 .map_err(|e| StatsError::ComputationError(format!("Initial SVD failed: {}", e)))?;
522
523 let u = u.unwrap();
524 let vt = vt.unwrap();
525
526 self.svd_u = Some(
528 u.slice(scirs2_core::ndarray::s![.., ..n_components])
529 .to_owned(),
530 );
531 self.svd_s = Some(s.slice(scirs2_core::ndarray::s![..n_components]).to_owned());
532 self.svd_v = Some(
533 vt.slice(scirs2_core::ndarray::s![..n_components, ..])
534 .t()
535 .to_owned(),
536 );
537
538 self.components = Some(self.svd_v.as_ref().unwrap().t().to_owned());
539 self.singular_values = Some(self.svd_s.as_ref().unwrap().clone());
540 } else {
541 let u_old = self.svd_u.as_ref().unwrap();
543 let s_old = self.svd_s.as_ref().unwrap();
544 let v_old = self.svd_v.as_ref().unwrap();
545
546 let projection = centered_batch.dot(v_old);
548 let residual = ¢ered_batch - &projection.dot(&v_old.t());
549
550 use scirs2_core::ndarray::ndarray_linalg::QR;
552 let (q_res, r_res) = residual.qr().map_err(|e| {
553 StatsError::ComputationError(format!("QR decomposition failed: {}", e))
554 })?;
555
556 let k = s_old.len();
558 let p = r_res.ncols();
559
560 let mut augmented = Array2::zeros((k + p, k + p));
562 for i in 0..k {
563 augmented[[i, i]] = s_old[i];
564 }
565 for i in 0..projection.nrows() {
566 for j in 0..k {
567 augmented[[j, k + i]] = projection[[i, j]];
568 }
569 }
570 for i in 0..p {
571 for j in 0..p {
572 augmented[[k + i, k + j]] = r_res[[i, j]];
573 }
574 }
575
576 use scirs2_core::ndarray::ndarray_linalg::SVD;
578 let (u_aug, s_aug, vt_aug) = augmented.svd(true, true).map_err(|e| {
579 StatsError::ComputationError(format!("Augmented SVD failed: {}", e))
580 })?;
581
582 let u_aug = u_aug.unwrap();
583 let vt_aug = vt_aug.unwrap();
584
585 let mut u_new = Array2::zeros((old_n + batchsize, n_components));
587 let u_aug_slice = u_aug.slice(scirs2_core::ndarray::s![..n_components, ..n_components]);
588
589 let u_old_part = u_old.dot(&u_aug_slice.t());
591 u_new
592 .slice_mut(scirs2_core::ndarray::s![..old_n, ..])
593 .assign(&u_old_part);
594
595 let u_batch_part =
597 projection.dot(&u_aug_slice.slice(scirs2_core::ndarray::s![.., ..k]).t());
598 let u_res_part = q_res.dot(&u_aug_slice.slice(scirs2_core::ndarray::s![.., k..]).t());
599 u_new
600 .slice_mut(scirs2_core::ndarray::s![old_n.., ..])
601 .assign(&(&u_batch_part + &u_res_part));
602
603 self.svd_s = Some(
605 s_aug
606 .slice(scirs2_core::ndarray::s![..n_components])
607 .to_owned(),
608 );
609
610 let v_aug_slice =
612 vt_aug.slice(scirs2_core::ndarray::s![..n_components, ..n_components]);
613 let mut v_new = Array2::zeros((n_features, n_components));
614
615 let v_old_part = v_old.dot(&v_aug_slice.slice(scirs2_core::ndarray::s![.., ..k]).t());
616 let v_res_part = q_res
617 .t()
618 .dot(¢ered_batch)
619 .t()
620 .dot(&v_aug_slice.slice(scirs2_core::ndarray::s![.., k..]).t());
621 v_new.assign(&(&v_old_part + &v_res_part));
622
623 self.svd_u = Some(u_new);
624 self.svd_v = Some(v_new.clone());
625 self.components = Some(v_new.t().to_owned());
626 self.singular_values = Some(self.svd_s.as_ref().unwrap().clone());
627 }
628
629 Ok(())
630 }
631
632 pub fn transform(&self, data: ArrayView2<f64>) -> Result<Array2<f64>> {
634 if self.components.is_none() || self.mean.is_none() {
635 return Err(StatsError::ComputationError(
636 "IncrementalPCA must be fitted before transform".to_string(),
637 ));
638 }
639
640 let mut centered = data.to_owned();
641 for mut row in centered.rows_mut() {
642 row -= self.mean.as_ref().unwrap();
643 }
644
645 Ok(centered.dot(&self.components.as_ref().unwrap().t()))
646 }
647}