rs_ml/transformer/
scalers.rs

//! Commonly used scalers that limit or normalize the range of input data.
2
3use ndarray::{Array1, Array2, Axis};
4use num_traits::Float;
5use std::marker::PhantomData;
6
7use crate::Estimator;
8
9use super::Transformer;
10
/// Parameters for fitting a [`StandardScaler`], which rescales each column to zero mean and
/// unit variance.
///
/// The type carries no configuration. It derives `Default` so it can be constructed the same
/// way as [`MinMaxScalerParams`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct StandardScalerParams;
14
/// Configuration used to fit a [`MinMaxScaler`] over values of type `F`.
///
/// It holds no data; the `PhantomData` only records the element type.
#[derive(Clone, Copy, Debug, Default)]
pub struct MinMaxScalerParams<F>(PhantomData<F>);
18
19/// Transforms input data to 0 mean, unit variance.
20#[derive(Debug, Clone)]
21pub struct StandardScaler {
22    means: Array1<f64>,
23    std_devs: Array1<f64>,
24}
25
/// A fitted min-max scaler that maps values linearly so the fitted range becomes `[0, 1]`.
///
/// Values outside the fitted range are not clipped: outliers are kept but land outside `[0, 1]`.
#[derive(Clone, Debug)]
pub struct MinMaxScaler<F> {
    /// Smallest value seen while fitting; maps to 0.
    min_value: F,
    /// Largest value seen while fitting; maps to 1.
    max_value: F,
}
33
34impl<F: Default> MinMaxScalerParams<F> {
35    /// Create new instance of MinMaxScaler
36    pub fn new() -> Self {
37        MinMaxScalerParams::default()
38    }
39}
40
41impl<A, F> Estimator<A> for MinMaxScalerParams<F>
42where
43    A: AsRef<[F]>,
44    F: Float,
45{
46    type Estimator = MinMaxScaler<F>;
47
48    fn fit(&self, input: &A) -> Option<Self::Estimator> {
49        let max_value = input
50            .as_ref()
51            .iter()
52            .fold(F::min_value(), |agg, curr| curr.max(agg));
53        let min_value = input
54            .as_ref()
55            .iter()
56            .fold(F::max_value(), |agg, curr| curr.min(agg));
57
58        Some(MinMaxScaler::<F> {
59            min_value,
60            max_value,
61        })
62    }
63}
64
65impl<A, F> Transformer<A, Vec<F>> for MinMaxScaler<F>
66where
67    A: AsRef<[F]>,
68    F: Float,
69{
70    fn transform(&self, arr: &A) -> Option<Vec<F>> {
71        Some(
72            arr.as_ref()
73                .iter()
74                .map(|elem| {
75                    elem.sub(self.min_value)
76                        .div(self.max_value - self.min_value)
77                })
78                .collect(),
79        )
80    }
81}
82
83impl Estimator<Array2<f64>> for StandardScalerParams {
84    type Estimator = StandardScaler;
85
86    fn fit(&self, input: &Array2<f64>) -> Option<StandardScaler> {
87        Some(StandardScaler {
88            means: input.mean_axis(Axis(0))?,
89            std_devs: input.std_axis(Axis(0), (input.shape()[0] - 1) as f64),
90        })
91    }
92}
93
94impl Transformer<Array2<f64>, Array2<f64>> for StandardScaler {
95    fn transform(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
96        Some((arr - &self.means) / &self.std_devs)
97    }
98}