use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;
use super::super::config::{DRTrainedState, QAutoencoderConfig};
#[derive(Debug)]
pub struct QSparseAE {
config: QAutoencoderConfig,
trained_state: Option<DRTrainedState>,
}
impl QSparseAE {
pub fn new(config: QAutoencoderConfig) -> Self {
Self {
config,
trained_state: None,
}
}
pub fn fit(&mut self, data: &Array2<f64>) -> Result<()> {
let n_components = self.config.latent_dim;
let mean = data
.mean_axis(scirs2_core::ndarray::Axis(0))
.unwrap_or_else(|| scirs2_core::ndarray::Array1::zeros(data.ncols()));
let components = Array2::zeros((n_components, data.ncols()));
let explained_variance_ratio = Array1::ones(n_components) / n_components as f64;
self.trained_state = Some(DRTrainedState {
components,
explained_variance_ratio,
mean,
scale: None,
quantum_parameters: HashMap::new(),
model_parameters: HashMap::new(),
training_statistics: HashMap::new(),
});
Ok(())
}
pub fn get_trained_state(&self) -> Option<DRTrainedState> {
self.trained_state.clone()
}
}