// quantrs2_ml/dimensionality_reduction/linear/qpca.rs
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{s, Array1, Array2, Axis};
use std::collections::HashMap;

use super::super::config::{DRTrainedState, QPCAConfig};

9#[derive(Debug)]
11pub struct QPCA {
12 config: QPCAConfig,
13 trained_state: Option<DRTrainedState>,
14}
15
16impl QPCA {
17 pub fn new(config: QPCAConfig) -> Self {
19 Self {
20 config,
21 trained_state: None,
22 }
23 }
24
25 pub fn fit(&mut self, data: &Array2<f64>) -> Result<()> {
27 let n_samples = data.nrows();
28 let n_features = data.ncols();
29 let n_components = self.config.n_components.min(n_features);
30
31 let mean = data
33 .mean_axis(scirs2_core::ndarray::Axis(0))
34 .unwrap_or_else(|| scirs2_core::ndarray::Array1::zeros(data.ncols()));
35
36 let centered = data - &mean;
38
39 let cov = centered.t().dot(¢ered) / (n_samples - 1) as f64;
41
42 let components = Array2::eye(n_features)
45 .slice(s![..n_components, ..])
46 .to_owned();
47 let eigenvalues =
48 Array1::from_vec((0..n_components).map(|i| 1.0 / (i + 1) as f64).collect());
49 let explained_variance_ratio = &eigenvalues / eigenvalues.sum();
50
51 self.trained_state = Some(DRTrainedState {
53 components,
54 explained_variance_ratio,
55 mean,
56 scale: None,
57 quantum_parameters: HashMap::new(),
58 model_parameters: HashMap::new(),
59 training_statistics: HashMap::new(),
60 });
61
62 Ok(())
63 }
64
65 pub fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
67 if let Some(state) = &self.trained_state {
68 let centered = data - &state.mean;
69 Ok(centered.dot(&state.components.t()))
70 } else {
71 Err(MLError::ModelNotTrained(
72 "QPCA model must be fitted before transform".to_string(),
73 ))
74 }
75 }
76
77 pub fn get_trained_state(&self) -> Option<DRTrainedState> {
79 self.trained_state.clone()
80 }
81}