rs_ml/transformer/
scalers.rs

//! Scalers that limit or normalize the range of input data.
2
3use ndarray::{Array1, Array2, Axis};
4use num_traits::Float;
5use std::marker::PhantomData;
6
7use crate::Estimator;
8
9use super::Transformer;
10
/// Configuration for fitting a [`StandardScaler`] (zero mean, unit variance).
///
/// Carries no options; it is the value on which `Estimator::fit` is called to
/// produce a [`StandardScaler`].
#[derive(Clone, Copy, Debug)]
pub struct StandardScalerParams;
14
/// Configuration for fitting a [`MinMaxScaler`] over elements of type `F`.
///
/// Holds no data: the `PhantomData` marker only ties the element type `F`
/// to these params so `fit` knows which scaler type to build.
#[derive(Clone, Copy, Debug, Default)]
pub struct MinMaxScalerParams<F>(PhantomData<F>);
18
19/// Transforms input data to 0 mean, unit variance.
20#[derive(Debug, Clone)]
21pub struct StandardScaler {
22    means: Array1<f64>,
23    std_devs: Array1<f64>,
24}
25
/// Fitted scaler that linearly maps the learned minimum to 0 and the learned maximum to 1.
#[derive(Clone, Debug)]
pub struct MinMaxScaler<F> {
    /// Smallest value observed during fitting; maps to 0.
    min_value: F,
    /// Largest value observed during fitting; maps to 1.
    max_value: F,
}
32
33impl<F: Default> MinMaxScalerParams<F> {
34    /// Create new instance of MinMaxScaler
35    pub fn new() -> Self {
36        MinMaxScalerParams::default()
37    }
38}
39
40impl<A, F> Estimator<A> for MinMaxScalerParams<F>
41where
42    A: AsRef<[F]>,
43    F: Float,
44{
45    type Estimator = MinMaxScaler<F>;
46
47    fn fit(&self, input: &A) -> Option<Self::Estimator> {
48        let max_value = input
49            .as_ref()
50            .iter()
51            .fold(F::min_value(), |agg, curr| curr.max(agg));
52        let min_value = input
53            .as_ref()
54            .iter()
55            .fold(F::max_value(), |agg, curr| curr.min(agg));
56
57        Some(MinMaxScaler::<F> {
58            min_value,
59            max_value,
60        })
61    }
62}
63
64impl<A, F> Transformer<A, Vec<F>> for MinMaxScaler<F>
65where
66    A: AsRef<[F]>,
67    F: Float,
68{
69    fn transform(&self, arr: &A) -> Option<Vec<F>> {
70        Some(
71            arr.as_ref()
72                .iter()
73                .map(|elem| {
74                    elem.sub(self.min_value)
75                        .div(self.max_value - self.min_value)
76                })
77                .collect(),
78        )
79    }
80}
81
82impl Estimator<Array2<f64>> for StandardScalerParams {
83    type Estimator = StandardScaler;
84
85    fn fit(&self, input: &Array2<f64>) -> Option<StandardScaler> {
86        Some(StandardScaler {
87            means: input.mean_axis(Axis(0))?,
88            std_devs: input.std_axis(Axis(0), (input.shape()[0] - 1) as f64),
89        })
90    }
91}
92
93impl Transformer<Array2<f64>, Array2<f64>> for StandardScaler {
94    fn transform(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
95        Some((arr - &self.means) / &self.std_devs)
96    }
97}