// quantrs2_ml/sklearn_compatibility/decomposition.rs

//! Sklearn-compatible decomposition algorithms

use super::SklearnEstimator;
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use std::collections::HashMap;
/// Principal Component Analysis, fitted via power iteration with deflation.
pub struct PCA {
    /// Requested number of components (capped at `n_features` during `fit`).
    n_components: usize,
    /// Fitted components, shape `(n_components, n_features)`; rows are unit-norm eigenvectors.
    components: Option<Array2<f64>>,
    /// Per-feature mean of the training data, used to center inputs in `transform`.
    mean: Option<Array1<f64>>,
    /// Variance explained by each extracted component (Rayleigh quotients of the covariance).
    explained_variance: Option<Array1<f64>>,
    /// Each component's variance divided by the sum of the *extracted* variances
    /// (not the full data variance, since only `n_components` eigenvalues are computed).
    explained_variance_ratio: Option<Array1<f64>>,
    /// True once `fit` has completed successfully.
    fitted: bool,
}
17
18impl PCA {
19    /// Create new PCA
20    pub fn new(n_components: usize) -> Self {
21        Self {
22            n_components,
23            components: None,
24            mean: None,
25            explained_variance: None,
26            explained_variance_ratio: None,
27            fitted: false,
28        }
29    }
30
31    /// Fit PCA
32    #[allow(non_snake_case)]
33    pub fn fit(&mut self, X: &Array2<f64>) -> Result<()> {
34        let n_samples = X.nrows();
35        let n_features = X.ncols();
36
37        // Center data
38        let mean = X
39            .mean_axis(Axis(0))
40            .ok_or_else(|| MLError::InvalidConfiguration("Failed to compute mean".to_string()))?;
41
42        let mut centered = X.clone();
43        for i in 0..n_samples {
44            for j in 0..n_features {
45                centered[[i, j]] -= mean[j];
46            }
47        }
48
49        // Compute covariance matrix (simplified)
50        let cov = centered.t().dot(&centered) / (n_samples - 1) as f64;
51
52        // Power iteration for eigendecomposition (simplified)
53        let n_comp = self.n_components.min(n_features);
54        let mut components = Array2::zeros((n_comp, n_features));
55        let mut variances = Array1::zeros(n_comp);
56
57        for k in 0..n_comp {
58            // Initialize random vector
59            let mut v = Array1::from_vec((0..n_features).map(|i| ((i + k) as f64).sin()).collect());
60            let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
61            v.mapv_inplace(|x| x / norm);
62
63            // Power iteration
64            for _ in 0..100 {
65                let mut new_v = cov.dot(&v);
66
67                // Orthogonalize against previous components
68                for prev_k in 0..k {
69                    let prev_comp = components.row(prev_k);
70                    let proj: f64 = new_v.iter().zip(prev_comp.iter()).map(|(a, b)| a * b).sum();
71                    for (i, val) in new_v.iter_mut().enumerate() {
72                        *val -= proj * prev_comp[i];
73                    }
74                }
75
76                let norm: f64 = new_v.iter().map(|x| x * x).sum::<f64>().sqrt();
77                if norm > 1e-10 {
78                    new_v.mapv_inplace(|x| x / norm);
79                }
80                v = new_v;
81            }
82
83            // Store component and compute variance
84            for j in 0..n_features {
85                components[[k, j]] = v[j];
86            }
87            variances[k] = cov.dot(&v).dot(&v);
88        }
89
90        let total_var: f64 = variances.sum();
91        let variance_ratio = variances.mapv(|v| v / total_var);
92
93        self.components = Some(components);
94        self.mean = Some(mean);
95        self.explained_variance = Some(variances);
96        self.explained_variance_ratio = Some(variance_ratio);
97        self.fitted = true;
98
99        Ok(())
100    }
101
102    /// Transform data
103    #[allow(non_snake_case)]
104    pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
105        let components = self
106            .components
107            .as_ref()
108            .ok_or_else(|| MLError::ModelNotTrained("PCA not fitted".to_string()))?;
109        let mean = self
110            .mean
111            .as_ref()
112            .ok_or_else(|| MLError::ModelNotTrained("PCA not fitted".to_string()))?;
113
114        // Center and project
115        let n_samples = X.nrows();
116        let mut centered = X.clone();
117        for i in 0..n_samples {
118            for j in 0..X.ncols() {
119                centered[[i, j]] -= mean[j];
120            }
121        }
122
123        Ok(centered.dot(&components.t()))
124    }
125
126    /// Fit and transform
127    #[allow(non_snake_case)]
128    pub fn fit_transform(&mut self, X: &Array2<f64>) -> Result<Array2<f64>> {
129        self.fit(X)?;
130        self.transform(X)
131    }
132
133    /// Get explained variance ratio
134    pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
135        self.explained_variance_ratio.as_ref()
136    }
137
138    /// Get components
139    pub fn components(&self) -> Option<&Array2<f64>> {
140        self.components.as_ref()
141    }
142}
143
144impl SklearnEstimator for PCA {
145    #[allow(non_snake_case)]
146    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
147        PCA::fit(self, X)
148    }
149
150    fn get_params(&self) -> HashMap<String, String> {
151        let mut params = HashMap::new();
152        params.insert("n_components".to_string(), self.n_components.to_string());
153        params
154    }
155
156    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
157        if let Some(n) = params.get("n_components") {
158            self.n_components = n
159                .parse()
160                .map_err(|_| MLError::InvalidConfiguration("Invalid n_components".to_string()))?;
161        }
162        Ok(())
163    }
164
165    fn is_fitted(&self) -> bool {
166        self.fitted
167    }
168}