quantrs2_ml/sklearn_compatibility/
mod.rs

1//! Scikit-learn compatibility layer for QuantRS2-ML
2//!
3//! This module provides a compatibility layer that mimics scikit-learn APIs,
4//! allowing easy integration of quantum ML models with existing scikit-learn
5//! workflows and pipelines.
6
7mod classifiers;
8mod clustering;
9mod decomposition;
10mod feature_selection;
11pub mod metrics;
12pub mod model_selection;
13pub mod pipeline;
14mod preprocessing;
15mod regressors;
16
17pub use classifiers::*;
18pub use clustering::*;
19pub use decomposition::*;
20pub use feature_selection::*;
21pub use model_selection::*;
22pub use pipeline::*;
23pub use preprocessing::*;
24pub use regressors::*;
25
26use crate::error::Result;
27use scirs2_core::ndarray::{Array1, Array2};
28use std::collections::HashMap;
29
30/// Base estimator trait following scikit-learn conventions
31pub trait SklearnEstimator: Send + Sync {
32    /// Fit the model to training data
33    #[allow(non_snake_case)]
34    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()>;
35
36    /// Get model parameters
37    fn get_params(&self) -> HashMap<String, String>;
38
39    /// Set model parameters
40    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()>;
41
42    /// Check if model is fitted
43    fn is_fitted(&self) -> bool;
44
45    /// Get feature names
46    fn get_feature_names_out(&self) -> Vec<String> {
47        vec![]
48    }
49}
50
51/// Classifier mixin trait
52pub trait SklearnClassifier: SklearnEstimator {
53    /// Predict class labels
54    #[allow(non_snake_case)]
55    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>>;
56
57    /// Predict class probabilities
58    #[allow(non_snake_case)]
59    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>>;
60
61    /// Get unique class labels
62    fn classes(&self) -> &[i32];
63
64    /// Score the model (accuracy by default)
65    #[allow(non_snake_case)]
66    fn score(&self, X: &Array2<f64>, y: &Array1<i32>) -> Result<f64> {
67        let predictions = self.predict(X)?;
68        let correct = predictions
69            .iter()
70            .zip(y.iter())
71            .filter(|(&pred, &true_label)| pred == true_label)
72            .count();
73        Ok(correct as f64 / y.len() as f64)
74    }
75
76    /// Get feature importances (optional)
77    fn feature_importances(&self) -> Option<Array1<f64>> {
78        None
79    }
80
81    /// Save model to file (optional)
82    fn save(&self, _path: &str) -> Result<()> {
83        Ok(())
84    }
85}
86
87/// Regressor mixin trait
88pub trait SklearnRegressor: SklearnEstimator {
89    /// Predict continuous values
90    #[allow(non_snake_case)]
91    fn predict(&self, X: &Array2<f64>) -> Result<Array1<f64>>;
92
93    /// Score the model (R² by default)
94    #[allow(non_snake_case)]
95    fn score(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
96        let predictions = self.predict(X)?;
97        let y_mean = y.mean().unwrap_or(0.0);
98
99        let ss_res: f64 = y
100            .iter()
101            .zip(predictions.iter())
102            .map(|(&true_val, &pred)| (true_val - pred).powi(2))
103            .sum();
104
105        let ss_tot: f64 = y.iter().map(|&val| (val - y_mean).powi(2)).sum();
106
107        Ok(1.0 - ss_res / ss_tot)
108    }
109}
110
111/// Extension trait for fitting with Array1<f64> directly
112pub trait SklearnFit {
113    #[allow(non_snake_case)]
114    fn fit(&mut self, X: &Array2<f64>, y: &Array1<f64>) -> Result<()>;
115}
116
117/// Clusterer mixin trait
118pub trait SklearnClusterer: SklearnEstimator {
119    /// Predict cluster labels
120    #[allow(non_snake_case)]
121    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>>;
122
123    /// Fit and predict in one step
124    #[allow(non_snake_case)]
125    fn fit_predict(&mut self, X: &Array2<f64>) -> Result<Array1<i32>> {
126        SklearnEstimator::fit(self, X, None)?;
127        self.predict(X)
128    }
129
130    /// Get cluster centers (if applicable)
131    fn cluster_centers(&self) -> Option<&Array2<f64>> {
132        None
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Fitting a `StandardScaler` without targets marks it as fitted.
    #[test]
    fn test_standard_scaler() {
        let samples = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let mut scaler = StandardScaler::new();
        scaler.fit(&samples, None).expect("Fit should succeed");

        assert!(scaler.is_fitted());
    }

    /// A fresh `MinMaxScaler` exposes the `feature_range_min` parameter.
    #[test]
    fn test_minmax_scaler() {
        let params = MinMaxScaler::new().get_params();
        assert!(params.contains_key("feature_range_min"));
    }

    /// A fresh `LabelEncoder` reports that it is not fitted.
    #[test]
    fn test_label_encoder() {
        assert!(!LabelEncoder::new().is_fitted());
    }

    /// `PCA::new(2)` reports `n_components` as the string `"2"`.
    #[test]
    fn test_pca() {
        let params = PCA::new(2).get_params();
        let n_components = params.get("n_components").map(String::as_str);
        assert_eq!(n_components, Some("2"));
    }
}