rs_ml/transformer/
scalers.rs1use ndarray::{Array1, Array2, Axis};
4use num_traits::Float;
5use std::marker::PhantomData;
6
7use crate::Estimator;
8
9use super::Transformer;
10
/// Hyper-parameters for fitting a [`StandardScaler`].
///
/// The type has no fields. Fitting is provided by its
/// `Estimator<Array2<f64>>` implementation. `Default` is derived for
/// consistency with `MinMaxScalerParams`, so both parameter types can be
/// constructed with `Default::default()`.
#[derive(Debug, Default, Clone, Copy)]
pub struct StandardScalerParams;
14
/// Hyper-parameters for fitting a [`MinMaxScaler`] over elements of type `F`.
///
/// The struct stores no data. The `PhantomData` marker only records the
/// element type `F`, so that fitting can produce a `MinMaxScaler<F>`.
#[derive(Copy, Clone, Debug, Default)]
pub struct MinMaxScalerParams<F>(PhantomData<F>);
18
19#[derive(Debug, Clone)]
21pub struct StandardScaler {
22 means: Array1<f64>,
23 std_devs: Array1<f64>,
24}
25
/// Fitted min-max scaling model over elements of type `F`.
///
/// It holds a single global lower and upper bound learned from the training
/// values. `transform` rescales inputs relative to these bounds.
#[derive(Clone, Debug)]
pub struct MinMaxScaler<F> {
    /// Smallest value observed while fitting.
    min_value: F,
    /// Largest value observed while fitting.
    max_value: F,
}
32
33impl<F: Default> MinMaxScalerParams<F> {
34 pub fn new() -> Self {
36 MinMaxScalerParams::default()
37 }
38}
39
40impl<A, F> Estimator<A> for MinMaxScalerParams<F>
41where
42 A: AsRef<[F]>,
43 F: Float,
44{
45 type Estimator = MinMaxScaler<F>;
46
47 fn fit(&self, input: &A) -> Option<Self::Estimator> {
48 let max_value = input
49 .as_ref()
50 .iter()
51 .fold(F::min_value(), |agg, curr| curr.max(agg));
52 let min_value = input
53 .as_ref()
54 .iter()
55 .fold(F::max_value(), |agg, curr| curr.min(agg));
56
57 Some(MinMaxScaler::<F> {
58 min_value,
59 max_value,
60 })
61 }
62}
63
64impl<A, F> Transformer<A, Vec<F>> for MinMaxScaler<F>
65where
66 A: AsRef<[F]>,
67 F: Float,
68{
69 fn transform(&self, arr: &A) -> Option<Vec<F>> {
70 Some(
71 arr.as_ref()
72 .iter()
73 .map(|elem| {
74 elem.sub(self.min_value)
75 .div(self.max_value - self.min_value)
76 })
77 .collect(),
78 )
79 }
80}
81
82impl Estimator<Array2<f64>> for StandardScalerParams {
83 type Estimator = StandardScaler;
84
85 fn fit(&self, input: &Array2<f64>) -> Option<StandardScaler> {
86 Some(StandardScaler {
87 means: input.mean_axis(Axis(0))?,
88 std_devs: input.std_axis(Axis(0), (input.shape()[0] - 1) as f64),
89 })
90 }
91}
92
93impl Transformer<Array2<f64>, Array2<f64>> for StandardScaler {
94 fn transform(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
95 Some((arr - &self.means) / &self.std_devs)
96 }
97}