quantrs2_ml/sklearn_compatibility/
use super::SklearnEstimator;
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use std::collections::HashMap;

/// Principal component analysis with a scikit-learn-style interface.
pub struct PCA {
    /// Number of principal components to keep.
    n_components: usize,
    /// Principal axes in feature space, shape `(n_components, n_features)`.
    components: Option<Array2<f64>>,
    /// Per-feature mean of the training data, used to center inputs.
    mean: Option<Array1<f64>>,
    /// Variance explained by each selected component.
    explained_variance: Option<Array1<f64>>,
    /// Fraction of the total variance explained by each component.
    explained_variance_ratio: Option<Array1<f64>>,
    fitted: bool,
}

impl PCA {
    /// Creates an unfitted PCA that keeps `n_components` components.
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            components: None,
            mean: None,
            explained_variance: None,
            explained_variance_ratio: None,
            fitted: false,
        }
    }

    /// Fits the model by extracting the leading principal components of `X`
    /// using power iteration with Gram-Schmidt deflation.
    #[allow(non_snake_case)]
    pub fn fit(&mut self, X: &Array2<f64>) -> Result<()> {
        let n_samples = X.nrows();
        let n_features = X.ncols();
        // Guard against the degenerate case that would make the
        // (n_samples - 1) covariance denominator zero or underflow.
        if n_samples < 2 {
            return Err(MLError::InvalidConfiguration(
                "PCA requires at least two samples".to_string(),
            ));
        }

        let mean = X
            .mean_axis(Axis(0))
            .ok_or_else(|| MLError::InvalidConfiguration("Failed to compute mean".to_string()))?;

        // Center the data so the covariance is computed about the mean.
        let mut centered = X.clone();
        for i in 0..n_samples {
            for j in 0..n_features {
                centered[[i, j]] -= mean[j];
            }
        }

        // Sample covariance matrix of the centered data,
        // shape (n_features, n_features).
        let cov = centered.t().dot(&centered) / (n_samples - 1) as f64;

        let n_comp = self.n_components.min(n_features);
        let mut components = Array2::zeros((n_comp, n_features));
        let mut variances = Array1::zeros(n_comp);

        // Extract the top components one at a time: power iteration finds the
        // dominant eigenvector of `cov`, and deflation against previously
        // found components yields each subsequent orthogonal direction.
        for k in 0..n_comp {
            // Deterministic, component-dependent starting vector, normalized.
            let mut v =
                Array1::from_vec((0..n_features).map(|i| ((i + k) as f64).sin()).collect());
            let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
            v.mapv_inplace(|x| x / norm);

            for _ in 0..100 {
                let mut new_v = cov.dot(&v);

                // Project out the components already extracted (Gram-Schmidt).
                for prev_k in 0..k {
                    let prev_comp = components.row(prev_k);
                    let proj: f64 = new_v.iter().zip(prev_comp.iter()).map(|(a, b)| a * b).sum();
                    for (i, val) in new_v.iter_mut().enumerate() {
                        *val -= proj * prev_comp[i];
                    }
                }

                // Renormalize, guarding against a vanishing iterate.
                let norm: f64 = new_v.iter().map(|x| x * x).sum::<f64>().sqrt();
                if norm > 1e-10 {
                    new_v.mapv_inplace(|x| x / norm);
                }
                v = new_v;
            }

            for j in 0..n_features {
                components[[k, j]] = v[j];
            }
            // Rayleigh quotient v^T C v gives the variance along component k.
            variances[k] = cov.dot(&v).dot(&v);
        }

        // Report each component's share of the total variance (the trace of
        // the covariance matrix), so the ratios follow scikit-learn's
        // convention even when n_components < n_features.
        let total_var: f64 = cov.diag().sum();
        let variance_ratio = variances.mapv(|v| v / total_var);

        self.components = Some(components);
        self.mean = Some(mean);
        self.explained_variance = Some(variances);
        self.explained_variance_ratio = Some(variance_ratio);
        self.fitted = true;

        Ok(())
    }

    /// Projects `X` onto the fitted components, returning an array of shape
    /// `(n_samples, n_components)`.
    #[allow(non_snake_case)]
    pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
        let components = self
            .components
            .as_ref()
            .ok_or_else(|| MLError::ModelNotTrained("PCA not fitted".to_string()))?;
        let mean = self
            .mean
            .as_ref()
            .ok_or_else(|| MLError::ModelNotTrained("PCA not fitted".to_string()))?;
        if X.ncols() != mean.len() {
            return Err(MLError::InvalidConfiguration(
                "X has a different number of features than the training data".to_string(),
            ));
        }

        // Center with the training mean, then project onto the components.
        let n_samples = X.nrows();
        let mut centered = X.clone();
        for i in 0..n_samples {
            for j in 0..X.ncols() {
                centered[[i, j]] -= mean[j];
            }
        }

        Ok(centered.dot(&components.t()))
    }

    /// Convenience wrapper: fits on `X` and returns the projection of `X`.
    #[allow(non_snake_case)]
    pub fn fit_transform(&mut self, X: &Array2<f64>) -> Result<Array2<f64>> {
        self.fit(X)?;
        self.transform(X)
    }
132
133 pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
135 self.explained_variance_ratio.as_ref()
136 }
137
138 pub fn components(&self) -> Option<&Array2<f64>> {
140 self.components.as_ref()
141 }
142}

impl SklearnEstimator for PCA {
    #[allow(non_snake_case)]
    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
        // PCA is unsupervised, so the optional targets are ignored.
        PCA::fit(self, X)
    }

    fn get_params(&self) -> HashMap<String, String> {
        let mut params = HashMap::new();
        params.insert("n_components".to_string(), self.n_components.to_string());
        params
    }

    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
        if let Some(n) = params.get("n_components") {
            self.n_components = n
                .parse()
                .map_err(|_| MLError::InvalidConfiguration("Invalid n_components".to_string()))?;
        }
        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}
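
// A minimal usage sketch, not part of the original module: it exercises the
// fit/transform path above, assuming `scirs2_core`'s ndarray re-export
// provides `Array2::from_shape_vec` as in upstream ndarray. The dataset and
// the 0.95 variance threshold are illustrative values, not fixtures from
// this crate.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn fit_transform_projects_to_requested_dimension() {
        // Six 3-D points whose third coordinate carries almost no variance.
        let x = Array2::from_shape_vec(
            (6, 3),
            vec![
                2.5, 2.4, 0.0, //
                0.5, 0.7, 0.1, //
                2.2, 2.9, 0.0, //
                1.9, 2.2, 0.1, //
                3.1, 3.0, 0.0, //
                2.3, 2.7, 0.1,
            ],
        )
        .expect("shape matches data length");

        let mut pca = PCA::new(2);
        let projected = pca.fit_transform(&x).expect("fit_transform succeeds");

        // One row per sample, one column per kept component.
        assert_eq!(projected.dim(), (6, 2));
        assert!(pca.components().is_some());

        // The two kept directions should explain nearly all of the variance.
        let ratio = pca.explained_variance_ratio().expect("fitted");
        assert!(ratio.sum() > 0.95);
    }
}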