rs_ml/transformer/
scalers.rs1use ndarray::{Array1, Array2, Axis};
4use num_traits::Float;
5
6use crate::Estimator;
7
8use super::Transformer;
9
/// Estimator that fits a [`StandardScaler`] to a 2-D `f64` array.
///
/// It carries no configuration of its own, so it is a unit struct. It also
/// implements `Default` so it can be constructed generically, like
/// `MinMaxScalerParams`.
#[derive(Debug, Clone, Copy, Default)]
pub struct StandardScalerEstimator;
13
/// Configuration for min-max scaling: the target interval `[min, max]`
/// onto which fitted data will be mapped.
#[derive(Clone, Copy, Debug)]
pub struct MinMaxScalerParams<F> {
    /// Lower end of the output interval.
    min: F,
    /// Upper end of the output interval.
    max: F,
}
21
22#[derive(Debug, Clone)]
25pub struct StandardScaler {
26 means: Array1<f64>,
27 std_devs: Array1<f64>,
28}
29
/// A fitted min-max scaler that maps values from the range observed during
/// fitting onto a configured target range.
///
/// Every field is a plain `F`, so the type is `Copy`, matching
/// `MinMaxScalerParams`.
#[derive(Debug, Clone, Copy)]
pub struct MinMaxScaler<F> {
    /// Lower bound of the target range.
    min: F,
    /// Width of the target range (`max - min` of the parameters).
    diff: F,
    /// Smallest value observed in the fitting data.
    min_value: F,
    /// Divisor for the observed data range (normally `max_value - min_value`
    /// of the fitting data).
    diff_value: F,
}
39
40impl<F: Float> Default for MinMaxScalerParams<F> {
41 fn default() -> Self {
42 Self {
43 min: F::zero(),
44 max: F::one(),
45 }
46 }
47}
48
49impl<F: Float> MinMaxScalerParams<F> {
50 pub fn new(min: F, max: F) -> Self {
52 MinMaxScalerParams { min, max }
53 }
54}
55
56impl<A, F> Estimator<A> for MinMaxScalerParams<F>
57where
58 A: AsRef<[F]>,
59 F: Float,
60{
61 type Estimator = MinMaxScaler<F>;
62
63 fn fit(&self, input: &A) -> Option<Self::Estimator> {
64 let max_value = input
65 .as_ref()
66 .iter()
67 .fold(F::min_value(), |agg, curr| curr.max(agg));
68 let min_value = input
69 .as_ref()
70 .iter()
71 .fold(F::max_value(), |agg, curr| curr.min(agg));
72
73 Some(MinMaxScaler::<F> {
74 min: self.min,
75 diff: self.max - self.min,
76 min_value,
77 diff_value: max_value - min_value,
78 })
79 }
80}
81
82impl<A, F> Transformer<A, Vec<F>> for MinMaxScaler<F>
83where
84 A: AsRef<[F]>,
85 F: Float,
86{
87 fn transform(&self, arr: &A) -> Option<Vec<F>> {
88 Some(
89 arr.as_ref()
90 .iter()
91 .map(|elem| {
92 elem.sub(self.min_value)
93 .div(self.diff_value)
94 .mul(self.diff)
95 .add(self.min)
96 })
97 .collect(),
98 )
99 }
100}
101
102impl Estimator<Array2<f64>> for StandardScalerEstimator {
103 type Estimator = StandardScaler;
104
105 fn fit(&self, input: &Array2<f64>) -> Option<StandardScaler> {
106 Some(StandardScaler {
107 means: input.mean_axis(Axis(0))?,
108 std_devs: input.std_axis(Axis(0), (input.shape()[0] - 1) as f64),
109 })
110 }
111}
112
113impl Transformer<Array2<f64>, Array2<f64>> for StandardScaler {
114 fn transform(&self, arr: &Array2<f64>) -> Option<Array2<f64>> {
115 Some((arr - &self.means) / &self.std_devs)
116 }
117}