use crate::error::{MLError, Result};
use scirs2_core::ndarray::{s, Array1, Array2};
use std::collections::HashMap;
use super::super::config::{DRTrainedState, QICAConfig};
#[derive(Debug)]
pub struct QICA {
config: QICAConfig,
trained_state: Option<DRTrainedState>,
}
impl QICA {
pub fn new(config: QICAConfig) -> Self {
Self {
config,
trained_state: None,
}
}
pub fn fit(&mut self, data: &Array2<f64>) -> Result<()> {
let n_features = data.ncols();
let n_components = self.config.n_components.min(n_features);
let mean = data
.mean_axis(scirs2_core::ndarray::Axis(0))
.unwrap_or_else(|| scirs2_core::ndarray::Array1::zeros(data.ncols()));
let components = Array2::eye(n_features)
.slice(s![..n_components, ..])
.to_owned();
let explained_variance_ratio = Array1::ones(n_components) / n_components as f64;
self.trained_state = Some(DRTrainedState {
components,
explained_variance_ratio,
mean,
scale: None,
quantum_parameters: HashMap::new(),
model_parameters: HashMap::new(),
training_statistics: HashMap::new(),
});
Ok(())
}
pub fn get_trained_state(&self) -> Option<DRTrainedState> {
self.trained_state.clone()
}
}