rs_ml/transformer/
scalers.rs

//! Scalers that limit and normalize the range of input data.
2
3use ndarray::{Array1, Array2, Axis};
4use num_traits::Float;
5use std::marker::PhantomData;
6
7use crate::Estimator;
8
9use super::Transformer;
10
/// Parameters needed to fit a standard scaler (zero mean, unit variance).
///
/// The standard scaler has nothing to configure, so this is a unit struct that
/// only serves as the entry point for fitting.
#[derive(Debug, Clone, Copy, Default)]
pub struct StandardScalerParams;
13
/// Parameters required to fit a min-max scaler.
///
/// There is nothing to configure; the type parameter `F` only records the
/// element type of the data the scaler will be fitted on.
#[derive(Debug, Clone, Copy, Default)]
pub struct MinMaxScalerParams<F> {
    /// Zero-sized marker that ties these params to the element type `F`.
    _data: PhantomData<F>,
}
19
20/// Transforms input data to 0 mean, unit variance.
21pub struct StandardScaler {
22    means: Array1<f64>,
23    std_devs: Array1<f64>,
24}
25
/// Linearly rescales input data so that the fitted range maps onto `[0, 1]`.
///
/// Values outside the fitted range land outside `[0, 1]`; no clamping is done.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MinMaxScaler<F> {
    /// Smallest value seen while fitting; mapped to 0 by the transform.
    min_value: F,
    /// Largest value seen while fitting; mapped to 1 by the transform.
    max_value: F,
}
31
32impl<F: Default> MinMaxScalerParams<F> {
33    /// Create new instance of MinMaxScaler
34    pub fn new() -> Self {
35        MinMaxScalerParams::default()
36    }
37}
38
39impl<A, F> Estimator<A> for MinMaxScalerParams<F>
40where
41    A: AsRef<[F]>,
42    F: Float,
43{
44    type Estimator = MinMaxScaler<F>;
45
46    fn fit(&self, input: &A) -> Option<Self::Estimator> {
47        let max_value = input
48            .as_ref()
49            .iter()
50            .fold(F::min_value(), |agg, curr| curr.max(agg));
51        let min_value = input
52            .as_ref()
53            .iter()
54            .fold(F::max_value(), |agg, curr| curr.min(agg));
55
56        Some(MinMaxScaler::<F> {
57            min_value,
58            max_value,
59        })
60    }
61}
62
63impl<A, F> Transformer<A, Vec<F>> for MinMaxScaler<F>
64where
65    A: AsRef<[F]>,
66    F: Float,
67{
68    fn transform(&self, arr: &A) -> Option<Vec<F>> {
69        Some(
70            arr.as_ref()
71                .iter()
72                .map(|elem| {
73                    elem.sub(self.min_value)
74                        .div(self.max_value - self.min_value)
75                })
76                .collect(),
77        )
78    }
79}
80
81impl Estimator<Array2<f64>> for StandardScalerParams {
82    type Estimator = StandardScaler;
83
84    fn fit(&self, input: &Array2<f64>) -> Option<StandardScaler> {
85        Some(StandardScaler {
86            means: input.mean_axis(Axis(0))?,
87            std_devs: input.std_axis(Axis(0), (input.shape()[0] - 1) as f64),
88        })
89    }
90}
91
92impl Transformer<Array2<f64>, Array2<f64>> for StandardScaler {
93    fn transform(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
94        Some((arr - &self.means) / &self.std_devs)
95    }
96}