use crate::error::{MLError, Result};
use scirs2_core::ndarray::{s, Array1, Array2};
use std::collections::HashMap;
use super::super::config::{DRTrainedState, QPCAConfig};
#[derive(Debug)]
pub struct QPCA {
config: QPCAConfig,
trained_state: Option<DRTrainedState>,
}
impl QPCA {
pub fn new(config: QPCAConfig) -> Self {
Self {
config,
trained_state: None,
}
}
pub fn fit(&mut self, data: &Array2<f64>) -> Result<()> {
let n_samples = data.nrows();
let n_features = data.ncols();
let n_components = self.config.n_components.min(n_features);
let mean = data
.mean_axis(scirs2_core::ndarray::Axis(0))
.unwrap_or_else(|| scirs2_core::ndarray::Array1::zeros(data.ncols()));
let centered = data - &mean;
let cov = centered.t().dot(¢ered) / (n_samples - 1) as f64;
let components = Array2::eye(n_features)
.slice(s![..n_components, ..])
.to_owned();
let eigenvalues =
Array1::from_vec((0..n_components).map(|i| 1.0 / (i + 1) as f64).collect());
let explained_variance_ratio = &eigenvalues / eigenvalues.sum();
self.trained_state = Some(DRTrainedState {
components,
explained_variance_ratio,
mean,
scale: None,
quantum_parameters: HashMap::new(),
model_parameters: HashMap::new(),
training_statistics: HashMap::new(),
});
Ok(())
}
pub fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
if let Some(state) = &self.trained_state {
let centered = data - &state.mean;
Ok(centered.dot(&state.components.t()))
} else {
Err(MLError::ModelNotTrained(
"QPCA model must be fitted before transform".to_string(),
))
}
}
pub fn get_trained_state(&self) -> Option<DRTrainedState> {
self.trained_state.clone()
}
}