rs_ml/transformer/
scalers.rs1use ndarray::{Array1, Array2, Axis};
4use num_traits::Float;
5use std::marker::PhantomData;
6
7use crate::Estimator;
8
9use super::Transformer;
10
/// Hyper-parameters for fitting a [`StandardScaler`].
///
/// The type currently carries no options. Build one with `StandardScalerParams`
/// or `StandardScalerParams::default()` and pass it to [`Estimator::fit`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct StandardScalerParams;
14
/// Hyper-parameters for fitting a [`MinMaxScaler`] over values of float type `F`.
///
/// The type stores no runtime data. The `PhantomData` only ties the parameter set to `F`,
/// so fitting produces a `MinMaxScaler<F>`.
#[derive(Clone, Copy, Debug, Default)]
pub struct MinMaxScalerParams<F>(PhantomData<F>);
18
19#[derive(Debug, Clone)]
21pub struct StandardScaler {
22 means: Array1<f64>,
23 std_devs: Array1<f64>,
24}
25
/// Fitted min-max scaler produced by [`MinMaxScalerParams`].
///
/// Values are rescaled linearly, so the fitted minimum maps to `0` and the fitted
/// maximum maps to `1`.
#[derive(Clone, Debug)]
pub struct MinMaxScaler<F> {
    /// Smallest value observed while fitting.
    min_value: F,
    /// Largest value observed while fitting.
    max_value: F,
}
33
34impl<F: Default> MinMaxScalerParams<F> {
35 pub fn new() -> Self {
37 MinMaxScalerParams::default()
38 }
39}
40
41impl<A, F> Estimator<A> for MinMaxScalerParams<F>
42where
43 A: AsRef<[F]>,
44 F: Float,
45{
46 type Estimator = MinMaxScaler<F>;
47
48 fn fit(&self, input: &A) -> Option<Self::Estimator> {
49 let max_value = input
50 .as_ref()
51 .iter()
52 .fold(F::min_value(), |agg, curr| curr.max(agg));
53 let min_value = input
54 .as_ref()
55 .iter()
56 .fold(F::max_value(), |agg, curr| curr.min(agg));
57
58 Some(MinMaxScaler::<F> {
59 min_value,
60 max_value,
61 })
62 }
63}
64
65impl<A, F> Transformer<A, Vec<F>> for MinMaxScaler<F>
66where
67 A: AsRef<[F]>,
68 F: Float,
69{
70 fn transform(&self, arr: &A) -> Option<Vec<F>> {
71 Some(
72 arr.as_ref()
73 .iter()
74 .map(|elem| {
75 elem.sub(self.min_value)
76 .div(self.max_value - self.min_value)
77 })
78 .collect(),
79 )
80 }
81}
82
83impl Estimator<Array2<f64>> for StandardScalerParams {
84 type Estimator = StandardScaler;
85
86 fn fit(&self, input: &Array2<f64>) -> Option<StandardScaler> {
87 Some(StandardScaler {
88 means: input.mean_axis(Axis(0))?,
89 std_devs: input.std_axis(Axis(0), (input.shape()[0] - 1) as f64),
90 })
91 }
92}
93
94impl Transformer<Array2<f64>, Array2<f64>> for StandardScaler {
95 fn transform(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
96 Some((arr - &self.means) / &self.std_devs)
97 }
98}