rs_ml/transformer/
scalers.rs

//! Commonly used scalers to limit or normalize the range of data.
2
3use ndarray::{Array1, Array2, Axis};
4use num_traits::Float;
5
6use crate::Estimator;
7
8use super::Transformer;
9
/// Estimator that fits a [`StandardScaler`], which scales input data down to zero mean and unit
/// variance.
///
/// This is a unit struct: it carries no parameters of its own.
#[derive(Debug, Clone, Copy)]
pub struct StandardScalerEstimator;
13
/// Configuration used to fit a [`MinMaxScaler`].
///
/// Holds the bounds of the target range; the [`Default`] configuration targets the range from
/// 0 to 1. The mapping is linear, so outliers are kept and only the overall range is limited.
#[derive(Clone, Copy, Debug)]
pub struct MinMaxScalerParams<F> {
    /// Lower bound of the target range.
    min: F,
    /// Upper bound of the target range.
    max: F,
}
21
22/// Result of a [`StandardScalerEstimator`]. Scales data down based on the mean and variance
23/// observed during fitting stage.
24#[derive(Debug, Clone)]
25pub struct StandardScaler {
26    means: Array1<f64>,
27    std_devs: Array1<f64>,
28}
29
/// Fitted scaler produced by a [`MinMaxScalerParams`] estimator.
///
/// Maps values linearly, based on the minimum and maximum values observed during training,
/// onto the configured target range.
#[derive(Clone, Debug)]
pub struct MinMaxScaler<F> {
    /// Lower bound of the target range.
    min: F,
    /// Width of the target range (target maximum minus target minimum).
    diff: F,
    /// Smallest value observed while fitting.
    min_value: F,
    /// Width of the observed range (largest minus smallest observed value).
    diff_value: F,
}
39
40impl<F: Float> Default for MinMaxScalerParams<F> {
41    fn default() -> Self {
42        Self {
43            min: F::zero(),
44            max: F::one(),
45        }
46    }
47}
48
49impl<F: Float> MinMaxScalerParams<F> {
50    /// Create new instance of MinMaxScaler
51    pub fn new(min: F, max: F) -> Self {
52        MinMaxScalerParams { min, max }
53    }
54}
55
56impl<A, F> Estimator<A> for MinMaxScalerParams<F>
57where
58    A: AsRef<[F]>,
59    F: Float,
60{
61    type Estimator = MinMaxScaler<F>;
62
63    fn fit(&self, input: &A) -> Option<Self::Estimator> {
64        let max_value = input
65            .as_ref()
66            .iter()
67            .fold(F::min_value(), |agg, curr| curr.max(agg));
68        let min_value = input
69            .as_ref()
70            .iter()
71            .fold(F::max_value(), |agg, curr| curr.min(agg));
72
73        Some(MinMaxScaler::<F> {
74            min: self.min,
75            diff: self.max - self.min,
76            min_value,
77            diff_value: max_value - min_value,
78        })
79    }
80}
81
82impl<A, F> Transformer<A, Vec<F>> for MinMaxScaler<F>
83where
84    A: AsRef<[F]>,
85    F: Float,
86{
87    fn transform(&self, arr: &A) -> Option<Vec<F>> {
88        Some(
89            arr.as_ref()
90                .iter()
91                .map(|elem| {
92                    elem.sub(self.min_value)
93                        .div(self.diff_value)
94                        .mul(self.diff)
95                        .add(self.min)
96                })
97                .collect(),
98        )
99    }
100}
101
102impl Estimator<Array2<f64>> for StandardScalerEstimator {
103    type Estimator = StandardScaler;
104
105    fn fit(&self, input: &Array2<f64>) -> Option<StandardScaler> {
106        Some(StandardScaler {
107            means: input.mean_axis(Axis(0))?,
108            std_devs: input.std_axis(Axis(0), (input.shape()[0] - 1) as f64),
109        })
110    }
111}
112
113impl Transformer<Array2<f64>, Array2<f64>> for StandardScaler {
114    fn transform(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
115        Some((arr - &self.means) / &self.std_devs)
116    }
117}