quantrs2_ml/sklearn_compatibility/
preprocessing.rs1use super::SklearnEstimator;
4use crate::error::{MLError, Result};
5use scirs2_core::ndarray::{Array1, Array2, Axis};
6use std::collections::HashMap;
7
8pub struct StandardScaler {
10 mean_: Option<Array1<f64>>,
11 scale_: Option<Array1<f64>>,
12 fitted: bool,
13}
14
15impl StandardScaler {
16 pub fn new() -> Self {
17 Self {
18 mean_: None,
19 scale_: None,
20 fitted: false,
21 }
22 }
23}
24
25impl Default for StandardScaler {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl SklearnEstimator for StandardScaler {
32 #[allow(non_snake_case)]
33 fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
34 let mean = X.mean_axis(Axis(0)).ok_or_else(|| {
35 MLError::InvalidInput("Cannot compute mean of empty array".to_string())
36 })?;
37 let std = X.std_axis(Axis(0), 0.0);
38
39 self.mean_ = Some(mean);
40 self.scale_ = Some(std);
41 self.fitted = true;
42
43 Ok(())
44 }
45
46 fn get_params(&self) -> HashMap<String, String> {
47 HashMap::new()
48 }
49
50 fn set_params(&mut self, _params: HashMap<String, String>) -> Result<()> {
51 Ok(())
52 }
53
54 fn is_fitted(&self) -> bool {
55 self.fitted
56 }
57}
58
59pub struct MinMaxScaler {
61 min: Option<Array1<f64>>,
62 max: Option<Array1<f64>>,
63 feature_range: (f64, f64),
64 fitted: bool,
65}
66
67impl MinMaxScaler {
68 pub fn new() -> Self {
70 Self {
71 min: None,
72 max: None,
73 feature_range: (0.0, 1.0),
74 fitted: false,
75 }
76 }
77
78 pub fn feature_range(mut self, min_val: f64, max_val: f64) -> Self {
80 self.feature_range = (min_val, max_val);
81 self
82 }
83
84 #[allow(non_snake_case)]
86 pub fn fit(&mut self, X: &Array2<f64>) -> Result<()> {
87 let n_features = X.ncols();
88 let mut min = Array1::zeros(n_features);
89 let mut max = Array1::zeros(n_features);
90
91 for j in 0..n_features {
92 let col = X.column(j);
93 min[j] = col.iter().cloned().fold(f64::INFINITY, f64::min);
94 max[j] = col.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
95 }
96
97 self.min = Some(min);
98 self.max = Some(max);
99 self.fitted = true;
100 Ok(())
101 }
102
103 #[allow(non_snake_case)]
105 pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
106 let min = self
107 .min
108 .as_ref()
109 .ok_or_else(|| MLError::ModelNotTrained("Scaler not fitted".to_string()))?;
110 let max = self
111 .max
112 .as_ref()
113 .ok_or_else(|| MLError::ModelNotTrained("Scaler not fitted".to_string()))?;
114
115 let (range_min, range_max) = self.feature_range;
116 let mut result = X.clone();
117
118 for j in 0..X.ncols() {
119 let scale = if (max[j] - min[j]).abs() > 1e-10 {
120 (range_max - range_min) / (max[j] - min[j])
121 } else {
122 1.0
123 };
124
125 for i in 0..X.nrows() {
126 result[[i, j]] = (X[[i, j]] - min[j]) * scale + range_min;
127 }
128 }
129
130 Ok(result)
131 }
132
133 #[allow(non_snake_case)]
135 pub fn fit_transform(&mut self, X: &Array2<f64>) -> Result<Array2<f64>> {
136 self.fit(X)?;
137 self.transform(X)
138 }
139}
140
141impl Default for MinMaxScaler {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147impl SklearnEstimator for MinMaxScaler {
148 #[allow(non_snake_case)]
149 fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
150 MinMaxScaler::fit(self, X)
151 }
152
153 fn get_params(&self) -> HashMap<String, String> {
154 let mut params = HashMap::new();
155 params.insert(
156 "feature_range_min".to_string(),
157 self.feature_range.0.to_string(),
158 );
159 params.insert(
160 "feature_range_max".to_string(),
161 self.feature_range.1.to_string(),
162 );
163 params
164 }
165
166 fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
167 if let Some(min_str) = params.get("feature_range_min") {
168 if let Some(max_str) = params.get("feature_range_max") {
169 let min_val: f64 = min_str.parse().map_err(|_| {
170 MLError::InvalidConfiguration("Invalid feature_range_min".to_string())
171 })?;
172 let max_val: f64 = max_str.parse().map_err(|_| {
173 MLError::InvalidConfiguration("Invalid feature_range_max".to_string())
174 })?;
175 self.feature_range = (min_val, max_val);
176 }
177 }
178 Ok(())
179 }
180
181 fn is_fitted(&self) -> bool {
182 self.fitted
183 }
184}
185
186pub struct RobustScaler {
188 center: Option<Array1<f64>>,
189 scale: Option<Array1<f64>>,
190 with_centering: bool,
191 with_scaling: bool,
192 fitted: bool,
193}
194
195impl RobustScaler {
196 pub fn new() -> Self {
198 Self {
199 center: None,
200 scale: None,
201 with_centering: true,
202 with_scaling: true,
203 fitted: false,
204 }
205 }
206
207 #[allow(non_snake_case)]
209 pub fn fit(&mut self, X: &Array2<f64>) -> Result<()> {
210 let n_features = X.ncols();
211 let mut center = Array1::zeros(n_features);
212 let mut scale = Array1::zeros(n_features);
213
214 for j in 0..n_features {
215 let mut col: Vec<f64> = X.column(j).iter().cloned().collect();
216 col.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
217
218 let n = col.len();
219 center[j] = col[n / 2]; let q1 = col[n / 4];
222 let q3 = col[3 * n / 4];
223 scale[j] = if (q3 - q1).abs() > 1e-10 {
224 q3 - q1
225 } else {
226 1.0
227 };
228 }
229
230 self.center = Some(center);
231 self.scale = Some(scale);
232 self.fitted = true;
233 Ok(())
234 }
235
236 #[allow(non_snake_case)]
238 pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
239 let center = self
240 .center
241 .as_ref()
242 .ok_or_else(|| MLError::ModelNotTrained("Scaler not fitted".to_string()))?;
243 let scale = self
244 .scale
245 .as_ref()
246 .ok_or_else(|| MLError::ModelNotTrained("Scaler not fitted".to_string()))?;
247
248 let mut result = X.clone();
249
250 for j in 0..X.ncols() {
251 for i in 0..X.nrows() {
252 result[[i, j]] = if self.with_centering {
253 (X[[i, j]] - center[j]) / scale[j]
254 } else {
255 X[[i, j]] / scale[j]
256 };
257 }
258 }
259
260 Ok(result)
261 }
262
263 #[allow(non_snake_case)]
265 pub fn fit_transform(&mut self, X: &Array2<f64>) -> Result<Array2<f64>> {
266 self.fit(X)?;
267 self.transform(X)
268 }
269}
270
271impl Default for RobustScaler {
272 fn default() -> Self {
273 Self::new()
274 }
275}
276
/// Encodes string labels as integer indices into a sorted list of known classes.
#[derive(Debug, Clone)]
pub struct LabelEncoder {
    /// Sorted, deduplicated labels seen during `fit`; a label's index is its encoding.
    classes: Vec<String>,
    /// Set to `true` once `fit` has been called.
    fitted: bool,
}
282
283impl LabelEncoder {
284 pub fn new() -> Self {
286 Self {
287 classes: Vec::new(),
288 fitted: false,
289 }
290 }
291
292 pub fn fit(&mut self, y: &[String]) {
294 let mut classes: Vec<String> = y.iter().cloned().collect();
295 classes.sort();
296 classes.dedup();
297 self.classes = classes;
298 self.fitted = true;
299 }
300
301 pub fn transform(&self, y: &[String]) -> Result<Array1<i32>> {
303 if !self.fitted {
304 return Err(MLError::ModelNotTrained("Encoder not fitted".to_string()));
305 }
306
307 let encoded: Vec<i32> = y
308 .iter()
309 .map(|label| {
310 self.classes
311 .iter()
312 .position(|c| c == label)
313 .map(|p| p as i32)
314 .unwrap_or(-1)
315 })
316 .collect();
317
318 Ok(Array1::from_vec(encoded))
319 }
320
321 pub fn inverse_transform(&self, y: &Array1<i32>) -> Result<Vec<String>> {
323 if !self.fitted {
324 return Err(MLError::ModelNotTrained("Encoder not fitted".to_string()));
325 }
326
327 let decoded: Vec<String> = y
328 .iter()
329 .map(|&idx| {
330 if idx >= 0 && (idx as usize) < self.classes.len() {
331 self.classes[idx as usize].clone()
332 } else {
333 "unknown".to_string()
334 }
335 })
336 .collect();
337
338 Ok(decoded)
339 }
340
341 pub fn fit_transform(&mut self, y: &[String]) -> Result<Array1<i32>> {
343 self.fit(y);
344 self.transform(y)
345 }
346
347 pub fn classes(&self) -> &[String] {
349 &self.classes
350 }
351
352 pub fn is_fitted(&self) -> bool {
354 self.fitted
355 }
356}
357
358impl Default for LabelEncoder {
359 fn default() -> Self {
360 Self::new()
361 }
362}
363
/// One-hot encodes each string column of a matrix against the categories seen during `fit`.
#[derive(Debug, Clone)]
pub struct OneHotEncoder {
    /// For each input column, its sorted, deduplicated categories.
    categories: Vec<Vec<String>>,
    /// Set to `true` once `fit` has been called.
    fitted: bool,
    /// Requested sparse output. It is not currently read by `transform`, which always
    /// returns a dense matrix.
    sparse: bool,
}
370
371impl OneHotEncoder {
372 pub fn new() -> Self {
374 Self {
375 categories: Vec::new(),
376 fitted: false,
377 sparse: false,
378 }
379 }
380
381 pub fn sparse(mut self, sparse: bool) -> Self {
383 self.sparse = sparse;
384 self
385 }
386
387 #[allow(non_snake_case)]
389 pub fn fit(&mut self, X: &Array2<String>) {
390 self.categories = Vec::new();
391
392 for j in 0..X.ncols() {
393 let mut cats: Vec<String> = X.column(j).iter().cloned().collect();
394 cats.sort();
395 cats.dedup();
396 self.categories.push(cats);
397 }
398
399 self.fitted = true;
400 }
401
402 #[allow(non_snake_case)]
404 pub fn transform(&self, X: &Array2<String>) -> Result<Array2<f64>> {
405 if !self.fitted {
406 return Err(MLError::ModelNotTrained("Encoder not fitted".to_string()));
407 }
408
409 let total_cols: usize = self.categories.iter().map(|c| c.len()).sum();
410 let mut result = Array2::zeros((X.nrows(), total_cols));
411
412 let mut col_offset = 0;
413 for j in 0..X.ncols() {
414 let cats = &self.categories[j];
415 for i in 0..X.nrows() {
416 if let Some(idx) = cats.iter().position(|c| c == &X[[i, j]]) {
417 result[[i, col_offset + idx]] = 1.0;
418 }
419 }
420 col_offset += cats.len();
421 }
422
423 Ok(result)
424 }
425
426 #[allow(non_snake_case)]
428 pub fn fit_transform(&mut self, X: &Array2<String>) -> Result<Array2<f64>> {
429 self.fit(X);
430 self.transform(X)
431 }
432}
433
434impl Default for OneHotEncoder {
435 fn default() -> Self {
436 Self::new()
437 }
438}
439
/// Holds string-valued configuration for a quantum feature encoding step.
#[derive(Debug, Clone)]
pub struct QuantumFeatureEncoder {
    /// Name of the encoding scheme, exposed through `get_params` / `set_params`.
    encoding_type: String,
    /// Name of the normalization scheme, exposed through `get_params` / `set_params`.
    normalization: String,
    /// Set to `true` once `fit` has been called.
    fitted: bool,
}
446
447impl QuantumFeatureEncoder {
448 pub fn new(encoding_type: &str, normalization: &str) -> Self {
449 Self {
450 encoding_type: encoding_type.to_string(),
451 normalization: normalization.to_string(),
452 fitted: false,
453 }
454 }
455}
456
457impl SklearnEstimator for QuantumFeatureEncoder {
458 #[allow(non_snake_case)]
459 fn fit(&mut self, _X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
460 self.fitted = true;
461 Ok(())
462 }
463
464 fn get_params(&self) -> HashMap<String, String> {
465 let mut params = HashMap::new();
466 params.insert("encoding_type".to_string(), self.encoding_type.clone());
467 params.insert("normalization".to_string(), self.normalization.clone());
468 params
469 }
470
471 fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
472 for (key, value) in params {
473 match key.as_str() {
474 "encoding_type" => {
475 self.encoding_type = value;
476 }
477 "normalization" => {
478 self.normalization = value;
479 }
480 _ => {}
481 }
482 }
483 Ok(())
484 }
485
486 fn is_fitted(&self) -> bool {
487 self.fitted
488 }
489}