quantrs2_ml/sklearn_compatibility/
mod.rs1mod classifiers;
8mod clustering;
9mod decomposition;
10mod feature_selection;
11pub mod metrics;
12pub mod model_selection;
13pub mod pipeline;
14mod preprocessing;
15mod regressors;
16
17pub use classifiers::*;
18pub use clustering::*;
19pub use decomposition::*;
20pub use feature_selection::*;
21pub use model_selection::*;
22pub use pipeline::*;
23pub use preprocessing::*;
24pub use regressors::*;
25
26use crate::error::Result;
27use scirs2_core::ndarray::{Array1, Array2};
28use std::collections::HashMap;
29
30pub trait SklearnEstimator: Send + Sync {
32 #[allow(non_snake_case)]
34 fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()>;
35
36 fn get_params(&self) -> HashMap<String, String>;
38
39 fn set_params(&mut self, params: HashMap<String, String>) -> Result<()>;
41
42 fn is_fitted(&self) -> bool;
44
45 fn get_feature_names_out(&self) -> Vec<String> {
47 vec![]
48 }
49}
50
51pub trait SklearnClassifier: SklearnEstimator {
53 #[allow(non_snake_case)]
55 fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>>;
56
57 #[allow(non_snake_case)]
59 fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>>;
60
61 fn classes(&self) -> &[i32];
63
64 #[allow(non_snake_case)]
66 fn score(&self, X: &Array2<f64>, y: &Array1<i32>) -> Result<f64> {
67 let predictions = self.predict(X)?;
68 let correct = predictions
69 .iter()
70 .zip(y.iter())
71 .filter(|(&pred, &true_label)| pred == true_label)
72 .count();
73 Ok(correct as f64 / y.len() as f64)
74 }
75
76 fn feature_importances(&self) -> Option<Array1<f64>> {
78 None
79 }
80
81 fn save(&self, _path: &str) -> Result<()> {
83 Ok(())
84 }
85}
86
87pub trait SklearnRegressor: SklearnEstimator {
89 #[allow(non_snake_case)]
91 fn predict(&self, X: &Array2<f64>) -> Result<Array1<f64>>;
92
93 #[allow(non_snake_case)]
95 fn score(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
96 let predictions = self.predict(X)?;
97 let y_mean = y.mean().unwrap_or(0.0);
98
99 let ss_res: f64 = y
100 .iter()
101 .zip(predictions.iter())
102 .map(|(&true_val, &pred)| (true_val - pred).powi(2))
103 .sum();
104
105 let ss_tot: f64 = y.iter().map(|&val| (val - y_mean).powi(2)).sum();
106
107 Ok(1.0 - ss_res / ss_tot)
108 }
109}
110
111pub trait SklearnFit {
113 #[allow(non_snake_case)]
114 fn fit(&mut self, X: &Array2<f64>, y: &Array1<f64>) -> Result<()>;
115}
116
117pub trait SklearnClusterer: SklearnEstimator {
119 #[allow(non_snake_case)]
121 fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>>;
122
123 #[allow(non_snake_case)]
125 fn fit_predict(&mut self, X: &Array2<f64>) -> Result<Array1<i32>> {
126 SklearnEstimator::fit(self, X, None)?;
127 self.predict(X)
128 }
129
130 fn cluster_centers(&self) -> Option<&Array2<f64>> {
132 None
133 }
134}
135
#[cfg(test)]
mod tests {
    //! Smoke tests for the re-exported estimators and the estimator traits.

    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_standard_scaler() {
        let mut scaler = StandardScaler::new();

        // Lower-case `x`: a local named `X` triggers rustc's `non_snake_case` warning.
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        scaler.fit(&x, None).expect("Fit should succeed");

        assert!(scaler.is_fitted());
    }

    #[test]
    fn test_minmax_scaler() {
        let scaler = MinMaxScaler::new();
        let params = scaler.get_params();
        assert!(params.contains_key("feature_range_min"));
    }

    #[test]
    fn test_label_encoder() {
        let encoder = LabelEncoder::new();
        assert!(!encoder.is_fitted());
    }

    #[test]
    fn test_pca() {
        let pca = PCA::new(2);
        let params = pca.get_params();
        assert_eq!(params.get("n_components"), Some(&"2".to_string()));
    }
}