1use crate::error::{PreprocessingError, Result};
4use approx::abs_diff_eq;
5use linfa::dataset::{AsTargets, DatasetBase, Float, WithLapack};
6use linfa::traits::{Fit, Transformer};
7#[cfg(not(feature = "blas"))]
8use linfa_linalg::norm::Norm;
9use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2, Zip};
10#[cfg(feature = "blas")]
11use ndarray_linalg::norm::Norm;
12
13#[cfg(feature = "serde")]
14use serde_crate::{Deserialize, Serialize};
15
16#[cfg_attr(
17 feature = "serde",
18 derive(Serialize, Deserialize),
19 serde(crate = "serde_crate")
20)]
21#[derive(Clone, Debug, PartialEq, Eq)]
22pub enum ScalingMethod<F: Float> {
28 Standard(bool, bool),
29 MinMax(F, F),
30 MaxAbs,
31}
32
33impl<F: Float> ScalingMethod<F> {
34 pub(crate) fn fit<D: Data<Elem = F>>(
35 &self,
36 records: &ArrayBase<D, Ix2>,
37 ) -> Result<LinearScaler<F>> {
38 match self {
39 ScalingMethod::Standard(a, b) => Self::standardize(records, *a, *b),
40 ScalingMethod::MinMax(a, b) => Self::min_max(records, *a, *b),
41 ScalingMethod::MaxAbs => Self::max_abs(records),
42 }
43 }
44
45 fn standardize<D: Data<Elem = F>>(
46 records: &ArrayBase<D, Ix2>,
47 with_mean: bool,
48 with_std: bool,
49 ) -> Result<LinearScaler<F>> {
50 if records.dim().0 == 0 {
51 return Err(PreprocessingError::NotEnoughSamples);
52 }
53 let means = records.mean_axis(Axis(0)).unwrap();
55 let std_devs = if with_std {
56 records.std_axis(Axis(0), F::zero()).mapv(|s| {
57 if abs_diff_eq!(s, F::zero()) {
58 F::one()
60 } else {
61 F::one() / s
62 }
63 })
64 } else {
65 Array1::ones(records.dim().1)
66 };
67 Ok(LinearScaler {
68 offsets: means,
69 scales: std_devs,
70 method: ScalingMethod::Standard(with_mean, with_std),
71 })
72 }
73
74 fn min_max<D: Data<Elem = F>>(
75 records: &ArrayBase<D, Ix2>,
76 min: F,
77 max: F,
78 ) -> Result<LinearScaler<F>> {
79 if records.dim().0 == 0 {
80 return Err(PreprocessingError::NotEnoughSamples);
81 } else if min > max {
82 return Err(PreprocessingError::FlippedMinMaxRange);
83 }
84
85 let mins = records.fold_axis(
86 Axis(0),
87 F::infinity(),
88 |&x, &prev| if x < prev { x } else { prev },
89 );
90 let mut scales =
91 records.fold_axis(
92 Axis(0),
93 F::neg_infinity(),
94 |&x, &prev| if x > prev { x } else { prev },
95 );
96 Zip::from(&mut scales).and(&mins).for_each(|max, min| {
97 if abs_diff_eq!(*max - *min, F::zero()) {
98 *max = F::one();
100 } else {
101 *max = F::one() / (*max - *min);
102 }
103 });
104 Ok(LinearScaler {
105 offsets: mins,
106 scales,
107 method: ScalingMethod::MinMax(min, max),
108 })
109 }
110
111 fn max_abs<D: Data<Elem = F>>(records: &ArrayBase<D, Ix2>) -> Result<LinearScaler<F>> {
112 if records.dim().0 == 0 {
113 return Err(PreprocessingError::NotEnoughSamples);
114 }
115 let scales: Array1<F> = records.map_axis(Axis(0), |col| {
116 let norm_max = F::cast(col.with_lapack().norm_max());
117 if abs_diff_eq!(norm_max, F::zero()) {
118 F::one()
120 } else {
121 F::one() / norm_max
122 }
123 });
124
125 let offsets = Array1::zeros(records.dim().1);
126 Ok(LinearScaler {
127 offsets,
128 scales,
129 method: ScalingMethod::MaxAbs,
130 })
131 }
132}
133
134impl<F: Float> std::fmt::Display for ScalingMethod<F> {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 match self {
137 ScalingMethod::Standard(with_mean, with_std) => write!(
138 f,
139 "Standard scaler (with_mean = {}, with_std = {})",
140 with_mean, with_std
141 ),
142 ScalingMethod::MinMax(min, max) => {
143 write!(f, "Min-Max scaler (min = {}, max = {})", min, max)
144 }
145 ScalingMethod::MaxAbs => write!(f, "MaxAbs scaler"),
146 }
147 }
148}
149
150#[cfg_attr(
168 feature = "serde",
169 derive(Serialize, Deserialize),
170 serde(crate = "serde_crate")
171)]
172#[derive(Debug, Clone, PartialEq, Eq)]
173pub struct LinearScalerParams<F: Float> {
174 method: ScalingMethod<F>,
175}
176
177impl<F: Float> LinearScalerParams<F> {
178 pub fn new(method: ScalingMethod<F>) -> Self {
180 Self { method }
181 }
182
183 pub fn method(mut self, method: ScalingMethod<F>) -> Self {
185 self.method = method;
186 self
187 }
188}
189
190impl<F: Float> LinearScaler<F> {
191 pub fn standard() -> LinearScalerParams<F> {
193 LinearScalerParams {
194 method: ScalingMethod::Standard(true, true),
195 }
196 }
197
198 pub fn standard_no_mean() -> LinearScalerParams<F> {
200 LinearScalerParams {
201 method: ScalingMethod::Standard(false, true),
202 }
203 }
204
205 pub fn standard_no_std() -> LinearScalerParams<F> {
207 LinearScalerParams {
208 method: ScalingMethod::Standard(true, false),
209 }
210 }
211
212 pub fn min_max() -> LinearScalerParams<F> {
214 LinearScalerParams {
215 method: ScalingMethod::MinMax(F::zero(), F::one()),
216 }
217 }
218
219 pub fn min_max_range(min: F, max: F) -> LinearScalerParams<F> {
223 LinearScalerParams {
224 method: ScalingMethod::MinMax(min, max),
225 }
226 }
227
228 pub fn max_abs() -> LinearScalerParams<F> {
230 LinearScalerParams {
231 method: ScalingMethod::MaxAbs,
232 }
233 }
234}
235
236impl<F: Float, D: Data<Elem = F>, T: AsTargets> Fit<ArrayBase<D, Ix2>, T, PreprocessingError>
237 for LinearScalerParams<F>
238{
239 type Object = LinearScaler<F>;
240
241 fn fit(&self, x: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
244 self.method.fit(x.records())
245 }
246}
247
248#[cfg_attr(
249 feature = "serde",
250 derive(Serialize, Deserialize),
251 serde(crate = "serde_crate")
252)]
253#[derive(Debug, Clone, PartialEq, Eq)]
254pub struct LinearScaler<F: Float> {
257 offsets: Array1<F>,
258 scales: Array1<F>,
259 method: ScalingMethod<F>,
260}
261
262impl<F: Float> LinearScaler<F> {
263 pub fn offsets(&self) -> &Array1<F> {
265 &self.offsets
266 }
267
268 pub fn scales(&self) -> &Array1<F> {
270 &self.scales
271 }
272
273 pub fn method(&self) -> &ScalingMethod<F> {
275 &self.method
276 }
277}
278
279impl<F: Float> Transformer<Array2<F>, Array2<F>> for LinearScaler<F> {
280 fn transform(&self, x: Array2<F>) -> Array2<F> {
283 if x.is_empty() {
284 return x;
285 }
286 let mut x = x;
287 Zip::from(x.columns_mut())
288 .and(self.offsets())
289 .and(self.scales())
290 .for_each(|mut col, &offset, &scale| {
291 if let ScalingMethod::Standard(false, _) = self.method {
292 col.mapv_inplace(|el| (el - offset) * scale + offset);
293 } else {
294 col.mapv_inplace(|el| (el - offset) * scale);
295 }
296 });
297 match &self.method {
298 ScalingMethod::MinMax(min, max) => x * (*max - *min) + *min,
299 _ => x,
300 }
301 }
302}
303
304impl<F: Float, D: Data<Elem = F>, T: AsTargets>
305 Transformer<DatasetBase<ArrayBase<D, Ix2>, T>, DatasetBase<Array2<F>, T>> for LinearScaler<F>
306{
307 fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
310 let feature_names = x.feature_names();
311 let (records, targets, weights) = (x.records, x.targets, x.weights);
312 let records = self.transform(records.to_owned());
313 DatasetBase::new(records, targets)
314 .with_weights(weights)
315 .with_feature_names(feature_names)
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::{LinearScaler, LinearScalerParams, ScalingMethod};
    use crate::error::PreprocessingError;
    use approx::assert_abs_diff_eq;
    use linfa::dataset::DatasetBase;
    use linfa::traits::{Fit, Transformer};
    use ndarray::{array, Array1, Array2, Axis};

    /// Column-wise minimum of `records`.
    fn column_mins(records: &Array2<f64>) -> Array1<f64> {
        records.fold_axis(
            Axis(0),
            f64::INFINITY,
            |&acc, &x| if acc < x { acc } else { x },
        )
    }

    /// Column-wise maximum of `records`.
    fn column_maxes(records: &Array2<f64>) -> Array1<f64> {
        records.fold_axis(
            Axis(0),
            f64::NEG_INFINITY,
            |&acc, &x| if acc > x { acc } else { x },
        )
    }

    /// Fits `params` on a three-feature dataset and then transforms a
    /// two-feature array; the calling tests expect this to panic.
    fn transform_wrong_size(params: LinearScalerParams<f64>) {
        let wrong_size = Array2::from_shape_vec((1, 2), vec![0., 0.]).unwrap();
        let dataset = array![[1., -1., 2.], [2., 0., 2.], [0., 1., 2.]].into();
        let scaler = params.fit(&dataset).unwrap();
        let _transformed = scaler.transform(wrong_size);
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<LinearScaler<f64>>();
        has_autotraits::<LinearScalerParams<f64>>();
        has_autotraits::<ScalingMethod<f64>>();
    }

    #[test]
    fn test_max_abs() {
        let dataset = array![[1., -1.], [2., -2.], [3., -3.], [4., -5.]].into();
        let scaler = LinearScaler::max_abs().fit(&dataset).unwrap();
        let scaled = scaler.transform(dataset);
        let col0 = scaled.records().column(0);
        let col1 = scaled.records().column(1);
        assert_abs_diff_eq!(col0, array![1. / 4., 2. / 4., 3. / 4., 1.]);
        assert_abs_diff_eq!(col1, array![-1. / 5., -2. / 5., -3. / 5., -1.]);
    }

    #[test]
    fn test_standard_scaler() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::standard().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![1., 0., 1. / 3.]);
        assert_abs_diff_eq!(
            *scaler.scales(),
            array![1. / 0.81, 1. / 0.81, 1. / 1.24],
            epsilon = 1e-2
        );
        let transformed = scaler.transform(dataset);
        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);
        assert_abs_diff_eq!(means, array![0., 0., 0.]);
        assert_abs_diff_eq!(std_devs, array![1., 1., 1.]);
    }

    #[test]
    fn test_standard_scaler_no_mean() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::standard_no_mean().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![1., 0., 1. / 3.]);
        assert_abs_diff_eq!(
            *scaler.scales(),
            array![1. / 0.81, 1. / 0.81, 1. / 1.24],
            epsilon = 1e-2
        );
        let transformed = scaler.transform(dataset);
        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);
        assert_abs_diff_eq!(means, array![1., 0., (1. / 3.)], epsilon = 1e-2);
        assert_abs_diff_eq!(std_devs, array![1., 1., 1.]);
    }

    #[test]
    fn test_standard_scaler_no_std() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::standard_no_std().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![1., 0., 1. / 3.]);
        assert_abs_diff_eq!(*scaler.scales(), array![1., 1., 1.]);
        let transformed = scaler.transform(dataset);
        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);
        assert_abs_diff_eq!(means, array![0., 0., 0.]);
        assert_abs_diff_eq!(std_devs, array![0.81, 0.81, 1.24], epsilon = 1e-2);
    }

    #[test]
    fn test_standard_scaler_no_both() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScalerParams::new(ScalingMethod::Standard(false, false))
            .fit(&dataset)
            .unwrap();

        let original_means = dataset.records().mean_axis(Axis(0)).unwrap();
        let original_stds = dataset.records().std_axis(Axis(0), 0.);

        assert_abs_diff_eq!(*scaler.offsets(), original_means);
        assert_abs_diff_eq!(*scaler.scales(), array![1., 1., 1.]);

        let transformed = scaler.transform(dataset);

        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);

        assert_abs_diff_eq!(means, original_means);
        assert_abs_diff_eq!(std_devs, original_stds, epsilon = 1e-2);
    }

    #[test]
    fn test_min_max_scaler() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::min_max().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![0., -1., -1.]);
        assert_abs_diff_eq!(*scaler.scales(), array![1. / 2., 1. / 2., 1. / 3.]);
        let transformed = scaler.transform(dataset);
        assert_abs_diff_eq!(column_maxes(transformed.records()), array![1., 1., 1.]);
        assert_abs_diff_eq!(column_mins(transformed.records()), array![0., 0., 0.]);
    }

    #[test]
    fn test_min_max_scaler_range() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::min_max_range(5., 10.).fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![0., -1., -1.]);
        assert_abs_diff_eq!(*scaler.scales(), array![1. / 2., 1. / 2., 1. / 3.]);
        let transformed = scaler.transform(dataset);
        assert_abs_diff_eq!(column_mins(transformed.records()), array![5., 5., 5.]);
        assert_abs_diff_eq!(column_maxes(transformed.records()), array![10., 10., 10.]);
    }

    #[test]
    fn test_standard_const_feature() {
        let dataset = array![[1., 2., 2.], [2., 2., 0.], [0., 2., -1.]].into();
        let scaler = LinearScaler::standard().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![1., 2., 1. / 3.]);
        assert_abs_diff_eq!(
            *scaler.scales(),
            array![1. / 0.81, 1., 1. / 1.24],
            epsilon = 1e-2
        );
        let transformed = scaler.transform(dataset);
        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);
        assert_abs_diff_eq!(means, array![0., 0., 0.]);
        assert_abs_diff_eq!(std_devs, array![1., 0., 1.]);
    }

    #[test]
    fn test_max_abs_const_null_feature() {
        let dataset = array![[1., 0.], [2., 0.], [3., 0.], [4., 0.]].into();
        let scaler = LinearScaler::max_abs().fit(&dataset).unwrap();
        let scaled = scaler.transform(dataset);
        let col0 = scaled.records().column(0);
        let col1 = scaled.records().column(1);
        assert_abs_diff_eq!(col0, array![1. / 4., 2. / 4., 3. / 4., 1.]);
        assert_abs_diff_eq!(col1, array![0., 0., 0., 0.]);
    }

    #[test]
    fn test_min_max_scaler_const_feature() {
        let dataset = array![[1., -1., 2.], [2., 0., 2.], [0., 1., 2.]].into();
        let scaler = LinearScaler::min_max().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![0., -1., 2.]);
        assert_abs_diff_eq!(*scaler.scales(), array![1. / 2., 1. / 2., 1.]);
        let transformed = scaler.transform(dataset);
        assert_abs_diff_eq!(column_maxes(transformed.records()), array![1., 1., 0.]);
        assert_abs_diff_eq!(column_mins(transformed.records()), array![0., 0., 0.]);
    }

    #[test]
    fn test_empty_input() {
        let dataset: DatasetBase<Array2<f64>, _> =
            Array2::from_shape_vec((0, 0), vec![]).unwrap().into();
        let all_params = [
            LinearScaler::standard(),
            LinearScaler::standard_no_mean(),
            LinearScaler::standard_no_std(),
            LinearScaler::min_max(),
            LinearScaler::max_abs(),
        ];
        for params in all_params.iter() {
            let err = params.fit(&dataset).err().unwrap();
            assert_eq!(
                err.to_string(),
                "not enough samples".to_string(),
                "method: {}",
                params.method
            );
        }
    }

    #[test]
    fn test_transform_empty_array() {
        let empty: Array2<f64> = Array2::from_shape_vec((0, 0), vec![]).unwrap();
        let dataset = array![[1., -1., 2.], [2., 0., 2.], [0., 1., 2.]].into();
        let all_params = [
            LinearScaler::standard(),
            LinearScaler::standard_no_mean(),
            LinearScaler::standard_no_std(),
            LinearScaler::min_max(),
            LinearScaler::max_abs(),
        ];
        for params in all_params.iter() {
            let scaler = params.fit(&dataset).unwrap();
            let transformed = scaler.transform(empty.clone());
            assert!(transformed.is_empty(), "method: {}", params.method);
        }
    }

    #[test]
    fn test_retain_feature_names() {
        let dataset = linfa_datasets::diabetes();
        let original_feature_names = dataset.feature_names();
        let transformed = LinearScaler::standard()
            .fit(&dataset)
            .unwrap()
            .transform(dataset);
        assert_eq!(original_feature_names, transformed.feature_names())
    }

    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_standard() {
        transform_wrong_size(LinearScaler::standard());
    }

    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_standard_no_mean() {
        transform_wrong_size(LinearScaler::standard_no_mean());
    }

    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_standard_no_std() {
        transform_wrong_size(LinearScaler::standard_no_std());
    }

    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_min_max() {
        transform_wrong_size(LinearScaler::min_max());
    }

    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_max_abs() {
        transform_wrong_size(LinearScaler::max_abs());
    }

    #[test]
    fn test_min_max_wrong_range() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let result = LinearScaler::min_max_range(10., 5.).fit(&dataset);
        // Pin the specific error instead of accepting any panic.
        assert!(matches!(
            result,
            Err(PreprocessingError::FlippedMinMaxRange)
        ));
    }
}