Skip to main content

anofox_ml_preprocessing/
minmax_scaler.rs

1use anofox_ml_core::{FitUnsupervised, Float, InverseTransform, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2, Axis};
3
4/// Parameters for MinMaxScaler (unfitted state).
5///
6/// Scales features to a given range (default [0, 1]):
7/// `x_scaled = (x - min) / (max - min) * (feature_max - feature_min) + feature_min`
8#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
9#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
10pub struct MinMaxScaler<F: Float> {
11    pub feature_min: F,
12    pub feature_max: F,
13}
14
15impl<F: Float> MinMaxScaler<F> {
16    /// Create a new `MinMaxScaler` with the default range [0, 1].
17    pub fn new() -> Self {
18        Self {
19            feature_min: F::zero(),
20            feature_max: F::one(),
21        }
22    }
23
24    /// Set the target feature range (min, max).
25    pub fn with_range(mut self, min: F, max: F) -> Self {
26        self.feature_min = min;
27        self.feature_max = max;
28        self
29    }
30}
31
32impl<F: Float> Default for MinMaxScaler<F> {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
/// Fitted MinMaxScaler — holds learned min/max per feature.
///
/// Produced by `MinMaxScaler::fit`. `transform` and `inverse_transform` use the
/// stored per-column extrema together with the target range copied from the
/// unfitted scaler.
///
/// NOTE(review): field names and declaration order are part of the serde
/// representation (order matters for non-self-describing formats); changing
/// them may break previously serialized models.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedMinMaxScaler<F: Float> {
    /// Per-column minimum of the training data, one entry per feature.
    data_min: Array1<F>,
    /// Per-column maximum of the training data, one entry per feature.
    data_max: Array1<F>,
    /// Lower bound of the target range, copied from the unfitted scaler.
    feature_min: F,
    /// Upper bound of the target range, copied from the unfitted scaler.
    feature_max: F,
}
47
48impl<F: Float> FitUnsupervised<F> for MinMaxScaler<F> {
49    type Fitted = FittedMinMaxScaler<F>;
50
51    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
52        if x.is_empty() {
53            return Err(RustMlError::EmptyInput("input array is empty".into()));
54        }
55        if self.feature_min >= self.feature_max {
56            return Err(RustMlError::InvalidParameter(
57                "feature_min must be less than feature_max".into(),
58            ));
59        }
60
61        let data_min = x.fold_axis(Axis(0), F::infinity(), |&a, &b| a.min(b));
62        let data_max = x.fold_axis(Axis(0), F::neg_infinity(), |&a, &b| a.max(b));
63
64        Ok(FittedMinMaxScaler {
65            data_min,
66            data_max,
67            feature_min: self.feature_min,
68            feature_max: self.feature_max,
69        })
70    }
71}
72
73impl<F: Float> Transform<F> for FittedMinMaxScaler<F> {
74    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
75        if x.ncols() != self.data_min.len() {
76            return Err(RustMlError::ShapeMismatch(format!(
77                "expected {} features, got {}",
78                self.data_min.len(),
79                x.ncols()
80            )));
81        }
82
83        let range = self.feature_max - self.feature_min;
84        let mut result = x.to_owned();
85
86        for mut row in result.rows_mut() {
87            for (j, val) in row.iter_mut().enumerate() {
88                let data_range = self.data_max[j] - self.data_min[j];
89                if data_range > F::from_f64(1e-15).unwrap() {
90                    *val = (*val - self.data_min[j]) / data_range * range + self.feature_min;
91                } else {
92                    *val = self.feature_min;
93                }
94            }
95        }
96        Ok(result)
97    }
98}
99
100impl<F: Float> InverseTransform<F> for FittedMinMaxScaler<F> {
101    fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
102        if x.ncols() != self.data_min.len() {
103            return Err(RustMlError::ShapeMismatch(format!(
104                "expected {} features, got {}",
105                self.data_min.len(),
106                x.ncols()
107            )));
108        }
109
110        let range = self.feature_max - self.feature_min;
111        let mut result = x.to_owned();
112
113        for mut row in result.rows_mut() {
114            for (j, val) in row.iter_mut().enumerate() {
115                let data_range = self.data_max[j] - self.data_min[j];
116                if data_range > F::from_f64(1e-15).unwrap() {
117                    *val = (*val - self.feature_min) / range * data_range + self.data_min[j];
118                } else {
119                    *val = self.data_min[j];
120                }
121            }
122        }
123        Ok(result)
124    }
125}
126
127impl<F: Float> FittedMinMaxScaler<F> {
128    pub fn data_min(&self) -> &Array1<F> {
129        &self.data_min
130    }
131
132    pub fn data_max(&self) -> &Array1<F> {
133        &self.data_max
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_fit_transform_default() {
        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let scaler = MinMaxScaler::<f64>::default();
        let fitted = scaler.fit(&x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 1]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_range() {
        let x = array![[1.0], [2.0], [3.0]];
        let scaler = MinMaxScaler {
            feature_min: -1.0,
            feature_max: 1.0,
        };
        let fitted = scaler.fit(&x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let scaler = MinMaxScaler::<f64>::default();
        let fitted = scaler.fit(&x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        let recovered = fitted.inverse_transform(&transformed).unwrap();

        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_custom_range_roundtrip() {
        let x = array![[-4.0, 0.5], [0.0, 1.5], [6.0, 2.5]];
        let fitted = MinMaxScaler::<f64>::new()
            .with_range(-3.0, 7.0)
            .fit(&x)
            .unwrap();
        let transformed = fitted.transform(&x).unwrap();
        for &v in transformed.iter() {
            assert!((-3.0..=7.0).contains(&v), "value {} outside target range", v);
        }

        let recovered = fitted.inverse_transform(&transformed).unwrap();
        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_large_values() {
        let x = array![[1e10], [2e10], [3e10]];
        let scaler = MinMaxScaler::<f64>::default();
        let fitted = scaler.fit(&x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], 0.0, epsilon = 1e-8);
        assert_abs_diff_eq!(transformed[[2, 0]], 1.0, epsilon = 1e-8);
        for &v in transformed.iter() {
            assert!(v.is_finite(), "large values produced non-finite: {}", v);
        }
    }

    #[test]
    fn test_small_values() {
        let x = array![[1e-10], [2e-10], [3e-10]];
        let scaler = MinMaxScaler::<f64>::default();
        let fitted = scaler.fit(&x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], 0.0, epsilon = 1e-8);
        assert_abs_diff_eq!(transformed[[2, 0]], 1.0, epsilon = 1e-8);
    }

    #[test]
    fn test_constant_column() {
        // Column with zero range should not produce NaN
        let x = array![[5.0, 1.0], [5.0, 2.0], [5.0, 3.0]];
        let scaler = MinMaxScaler::<f64>::default();
        let fitted = scaler.fit(&x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(v.is_finite(), "constant column produced non-finite: {}", v);
        }
    }

    #[test]
    fn test_constant_column_maps_to_feature_min() {
        let x = array![[5.0, 1.0], [5.0, 2.0], [5.0, 3.0]];
        let fitted = MinMaxScaler::<f64>::new()
            .with_range(-2.0, 2.0)
            .fit(&x)
            .unwrap();

        let transformed = fitted.transform(&x).unwrap();
        for &v in transformed.column(0).iter() {
            assert_abs_diff_eq!(v, -2.0, epsilon = 1e-12);
        }

        // Inverting a constant column restores its learned value.
        let recovered = fitted.inverse_transform(&transformed).unwrap();
        for &v in recovered.column(0).iter() {
            assert_abs_diff_eq!(v, 5.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_fitted_extrema() {
        let x = array![[3.0, -1.0], [1.0, 4.0], [2.0, 0.0]];
        let fitted = MinMaxScaler::<f64>::default().fit(&x).unwrap();

        assert_eq!(fitted.data_min(), &array![1.0, -1.0]);
        assert_eq!(fitted.data_max(), &array![3.0, 4.0]);
    }

    #[test]
    fn test_empty_input_is_rejected() {
        let x = Array2::<f64>::zeros((0, 3));
        let result = MinMaxScaler::<f64>::default().fit(&x);
        assert!(matches!(result, Err(RustMlError::EmptyInput(_))));
    }

    #[test]
    fn test_invalid_range_is_rejected() {
        let x = array![[1.0], [2.0]];

        let inverted = MinMaxScaler::<f64>::new().with_range(1.0, 0.0).fit(&x);
        assert!(matches!(inverted, Err(RustMlError::InvalidParameter(_))));

        let degenerate = MinMaxScaler::<f64>::new().with_range(0.5, 0.5).fit(&x);
        assert!(matches!(degenerate, Err(RustMlError::InvalidParameter(_))));
    }

    #[test]
    fn test_feature_count_mismatch_is_rejected() {
        let fitted = MinMaxScaler::<f64>::default()
            .fit(&array![[1.0, 2.0], [3.0, 4.0]])
            .unwrap();
        let wrong = array![[1.0, 2.0, 3.0]];

        assert!(matches!(
            fitted.transform(&wrong),
            Err(RustMlError::ShapeMismatch(_))
        ));
        assert!(matches!(
            fitted.inverse_transform(&wrong),
            Err(RustMlError::ShapeMismatch(_))
        ));
    }
}