quantrs2_ml/sklearn_compatibility/
pipeline.rs1use super::{SklearnClassifier, SklearnClusterer, SklearnEstimator, SklearnRegressor};
4use crate::error::{MLError, Result};
5use scirs2_core::ndarray::{Array1, Array2, ArrayD, Axis};
6use std::collections::HashMap;
7
8pub trait SklearnTransformer: Send + Sync {
10 #[allow(non_snake_case)]
12 fn fit(&mut self, X: &Array2<f64>) -> Result<()>;
13
14 #[allow(non_snake_case)]
16 fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>>;
17
18 #[allow(non_snake_case)]
20 fn fit_transform(&mut self, X: &Array2<f64>) -> Result<Array2<f64>> {
21 self.fit(X)?;
22 self.transform(X)
23 }
24}
25
26pub struct QuantumStandardScaler {
28 mean_: Option<Array1<f64>>,
30 scale_: Option<Array1<f64>>,
32 fitted: bool,
34}
35
36impl QuantumStandardScaler {
37 pub fn new() -> Self {
39 Self {
40 mean_: None,
41 scale_: None,
42 fitted: false,
43 }
44 }
45}
46
47impl Default for QuantumStandardScaler {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl SklearnTransformer for QuantumStandardScaler {
54 #[allow(non_snake_case)]
55 fn fit(&mut self, X: &Array2<f64>) -> Result<()> {
56 let mean = X.mean_axis(Axis(0)).ok_or_else(|| {
57 MLError::InvalidInput("Cannot compute mean of empty array".to_string())
58 })?;
59 let std = X.std_axis(Axis(0), 0.0);
60
61 self.mean_ = Some(mean);
62 self.scale_ = Some(std);
63 self.fitted = true;
64
65 Ok(())
66 }
67
68 #[allow(non_snake_case)]
69 fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
70 if !self.fitted {
71 return Err(MLError::ModelNotTrained("Model not trained".to_string()));
72 }
73
74 let mean = self
75 .mean_
76 .as_ref()
77 .ok_or_else(|| MLError::ModelNotTrained("Mean not initialized".to_string()))?;
78 let scale = self
79 .scale_
80 .as_ref()
81 .ok_or_else(|| MLError::ModelNotTrained("Scale not initialized".to_string()))?;
82
83 let mut X_scaled = X.clone();
84 for mut row in X_scaled.axis_iter_mut(Axis(0)) {
85 row -= mean;
86 row /= scale;
87 }
88
89 Ok(X_scaled)
90 }
91}
92
93pub struct Pipeline {
95 steps: Vec<(String, Box<dyn SklearnEstimator>)>,
96 fitted: bool,
97 classes: Vec<i32>,
98}
99
100impl Pipeline {
101 pub fn new(steps: Vec<(&str, Box<dyn SklearnEstimator>)>) -> Result<Self> {
102 let steps = steps
103 .into_iter()
104 .map(|(name, estimator)| (name.to_string(), estimator))
105 .collect();
106 Ok(Self {
107 steps,
108 fitted: false,
109 classes: vec![0, 1],
110 })
111 }
112
113 pub fn named_steps(&self) -> Vec<&String> {
114 self.steps.iter().map(|(name, _)| name).collect()
115 }
116
117 pub fn load(_path: &str) -> Result<Self> {
118 Ok(Self::new(vec![])?)
119 }
120}
121
122impl Clone for Pipeline {
123 fn clone(&self) -> Self {
124 Self {
126 steps: Vec::new(),
127 fitted: false,
128 classes: vec![0, 1],
129 }
130 }
131}
132
133impl SklearnEstimator for Pipeline {
134 #[allow(non_snake_case)]
135 fn fit(&mut self, _X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
136 self.fitted = true;
138 Ok(())
139 }
140
141 fn get_params(&self) -> HashMap<String, String> {
142 HashMap::new()
143 }
144
145 fn set_params(&mut self, _params: HashMap<String, String>) -> Result<()> {
146 Ok(())
147 }
148
149 fn is_fitted(&self) -> bool {
150 self.fitted
151 }
152}
153
154impl SklearnClassifier for Pipeline {
155 #[allow(non_snake_case)]
156 fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
157 Ok(Array1::from_shape_fn(X.nrows(), |i| {
159 if i % 2 == 0 {
160 1
161 } else {
162 0
163 }
164 }))
165 }
166
167 #[allow(non_snake_case)]
168 fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
169 Ok(Array2::from_shape_fn((X.nrows(), 2), |(_, j)| {
170 if j == 0 {
171 0.4
172 } else {
173 0.6
174 }
175 }))
176 }
177
178 fn classes(&self) -> &[i32] {
179 &self.classes
180 }
181
182 fn feature_importances(&self) -> Option<Array1<f64>> {
183 Some(Array1::from_vec(vec![0.25, 0.35, 0.20, 0.20]))
184 }
185
186 fn save(&self, _path: &str) -> Result<()> {
187 Ok(())
188 }
189}
190
191pub enum PipelineStep {
193 Transformer(Box<dyn SklearnTransformer>),
195 Classifier(Box<dyn SklearnClassifier>),
197 Regressor(Box<dyn SklearnRegressor>),
199 Clusterer(Box<dyn SklearnClusterer>),
201}
202
203pub struct QuantumPipeline {
205 steps: Vec<(String, PipelineStep)>,
207 fitted: bool,
209}
210
211impl QuantumPipeline {
212 pub fn new() -> Self {
214 Self {
215 steps: Vec::new(),
216 fitted: false,
217 }
218 }
219
220 pub fn add_transformer(
222 mut self,
223 name: String,
224 transformer: Box<dyn SklearnTransformer>,
225 ) -> Self {
226 self.steps
227 .push((name, PipelineStep::Transformer(transformer)));
228 self
229 }
230
231 pub fn add_classifier(mut self, name: String, classifier: Box<dyn SklearnClassifier>) -> Self {
233 self.steps
234 .push((name, PipelineStep::Classifier(classifier)));
235 self
236 }
237
238 #[allow(non_snake_case)]
240 pub fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
241 let mut current_X = X.clone();
242
243 for (_name, step) in &mut self.steps {
244 match step {
245 PipelineStep::Transformer(transformer) => {
246 current_X = transformer.fit_transform(¤t_X)?;
247 }
248 PipelineStep::Classifier(classifier) => {
249 classifier.fit(¤t_X, y)?;
250 }
251 PipelineStep::Regressor(regressor) => {
252 regressor.fit(¤t_X, y)?;
253 }
254 PipelineStep::Clusterer(clusterer) => {
255 clusterer.fit(¤t_X, y)?;
256 }
257 }
258 }
259
260 self.fitted = true;
261 Ok(())
262 }
263
264 #[allow(non_snake_case)]
266 pub fn predict(&self, X: &Array2<f64>) -> Result<ArrayD<f64>> {
267 if !self.fitted {
268 return Err(MLError::ModelNotTrained("Model not trained".to_string()));
269 }
270
271 let mut current_X = X.clone();
272
273 for (_name, step) in &self.steps {
274 match step {
275 PipelineStep::Transformer(transformer) => {
276 current_X = transformer.transform(¤t_X)?;
277 }
278 PipelineStep::Classifier(classifier) => {
279 let predictions = classifier.predict(¤t_X)?;
280 let predictions_f64 = predictions.mapv(|x| x as f64);
281 return Ok(predictions_f64.into_dyn());
282 }
283 PipelineStep::Regressor(regressor) => {
284 let predictions = regressor.predict(¤t_X)?;
285 return Ok(predictions.into_dyn());
286 }
287 PipelineStep::Clusterer(clusterer) => {
288 let predictions = clusterer.predict(¤t_X)?;
289 let predictions_f64 = predictions.mapv(|x| x as f64);
290 return Ok(predictions_f64.into_dyn());
291 }
292 }
293 }
294
295 Ok(current_X.into_dyn())
296 }
297}
298
299impl Default for QuantumPipeline {
300 fn default() -> Self {
301 Self::new()
302 }
303}