Skip to main content

anofox_ml_preprocessing/
max_abs_scaler.rs

1use anofox_ml_core::{FitUnsupervised, Float, InverseTransform, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2, Axis};
3
4/// Parameters for MaxAbsScaler (unfitted state).
5///
6/// Scales each feature by its maximum absolute value so that the resulting
7/// values lie in the range [-1, 1]. Unlike `StandardScaler` or
8/// `RobustScaler`, this scaler does **not** center the data, which makes
9/// it suitable for sparse data.
10///
11/// `x_scaled[i, j] = x[i, j] / max_abs[j]`
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct MaxAbsScaler;
14
15impl MaxAbsScaler {
16    /// Create a new `MaxAbsScaler`.
17    pub fn new() -> Self {
18        Self
19    }
20}
21
22impl Default for MaxAbsScaler {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28/// Fitted MaxAbsScaler — holds the maximum absolute value per feature.
29#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
30#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
31pub struct FittedMaxAbsScaler<F: Float> {
32    max_abs: Array1<F>,
33}
34
35impl<F: Float> FitUnsupervised<F> for MaxAbsScaler {
36    type Fitted = FittedMaxAbsScaler<F>;
37
38    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
39        if x.is_empty() {
40            return Err(RustMlError::EmptyInput("input array is empty".into()));
41        }
42
43        let max_abs = x
44            .mapv(|v| v.abs())
45            .fold_axis(Axis(0), F::zero(), |&a, &b| a.max(b));
46
47        Ok(FittedMaxAbsScaler { max_abs })
48    }
49}
50
51impl<F: Float> Transform<F> for FittedMaxAbsScaler<F> {
52    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
53        if x.ncols() != self.max_abs.len() {
54            return Err(RustMlError::ShapeMismatch(format!(
55                "expected {} features, got {}",
56                self.max_abs.len(),
57                x.ncols()
58            )));
59        }
60
61        let eps = F::from_f64(1e-15).unwrap();
62        let mut result = x.to_owned();
63        for mut row in result.rows_mut() {
64            for (j, val) in row.iter_mut().enumerate() {
65                if self.max_abs[j] > eps {
66                    *val = *val / self.max_abs[j];
67                }
68            }
69        }
70        Ok(result)
71    }
72}
73
74impl<F: Float> InverseTransform<F> for FittedMaxAbsScaler<F> {
75    fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
76        if x.ncols() != self.max_abs.len() {
77            return Err(RustMlError::ShapeMismatch(format!(
78                "expected {} features, got {}",
79                self.max_abs.len(),
80                x.ncols()
81            )));
82        }
83
84        let eps = F::from_f64(1e-15).unwrap();
85        let mut result = x.to_owned();
86        for mut row in result.rows_mut() {
87            for (j, val) in row.iter_mut().enumerate() {
88                if self.max_abs[j] > eps {
89                    *val = *val * self.max_abs[j];
90                }
91            }
92        }
93        Ok(result)
94    }
95}
96
97impl<F: Float> FittedMaxAbsScaler<F> {
98    /// Return the maximum absolute value per feature.
99    pub fn max_abs(&self) -> &Array1<F> {
100        &self.max_abs
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fit a fresh `MaxAbsScaler` on `x` in `f64` precision.
    fn fit_f64(x: &Array2<f64>) -> FittedMaxAbsScaler<f64> {
        FitUnsupervised::<f64>::fit(&MaxAbsScaler::new(), x).unwrap()
    }

    #[test]
    fn test_basic_scaling() {
        let x = array![[1.0, -4.0], [2.0, 2.0], [-3.0, 1.0]];
        let fitted = fit_f64(&x);
        let transformed = fitted.transform(&x).unwrap();

        // max_abs for col 0 = 3.0, col 1 = 4.0
        assert_abs_diff_eq!(fitted.max_abs()[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.max_abs()[1], 4.0, epsilon = 1e-10);

        assert_abs_diff_eq!(transformed[[0, 0]], 1.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 1]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], -1.0, epsilon = 1e-10);

        // On the training data every value must land in [-1, 1].
        for &v in transformed.iter() {
            assert!((-1.0..=1.0).contains(&v), "value {} not in [-1, 1]", v);
        }
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let x = array![[1.0, -4.0], [2.0, 2.0], [-3.0, 1.0]];
        let fitted = fit_f64(&x);
        let transformed = fitted.transform(&x).unwrap();
        let recovered = fitted.inverse_transform(&transformed).unwrap();

        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_zero_column() {
        // A column of all zeros should pass through unchanged (no division by zero)
        let x = array![[0.0, 2.0], [0.0, -4.0], [0.0, 1.0]];
        let fitted = fit_f64(&x);
        let transformed = fitted.transform(&x).unwrap();

        // Zero column stays zero
        assert_abs_diff_eq!(transformed[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], 0.0, epsilon = 1e-10);

        // Non-zero column is scaled
        assert_abs_diff_eq!(transformed[[1, 1]], -1.0, epsilon = 1e-10);

        for &v in transformed.iter() {
            assert!(v.is_finite(), "zero column produced non-finite: {}", v);
        }
    }

    #[test]
    fn test_zero_column_inverse_passthrough() {
        // inverse_transform must leave an all-zero training column untouched too,
        // mirroring transform, while still rescaling the other columns.
        let x = array![[0.0, 2.0], [0.0, -4.0]];
        let fitted = fit_f64(&x);
        let recovered = fitted.inverse_transform(&array![[5.0, 0.5]]).unwrap();

        assert_abs_diff_eq!(recovered[[0, 0]], 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(recovered[[0, 1]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_unseen_data_is_not_clipped() {
        // Data beyond the fitted range is scaled, not clamped to [-1, 1].
        let fitted = fit_f64(&array![[2.0], [-4.0]]);
        let transformed = fitted.transform(&array![[8.0], [-2.0]]).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 0]], -0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_f32_support() {
        let x = array![[1.0f32, -4.0], [2.0, 2.0], [-3.0, 1.0]];
        let scaler = MaxAbsScaler::new();
        let fitted = FitUnsupervised::<f32>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        let recovered = fitted.inverse_transform(&transformed).unwrap();

        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_already_scaled() {
        // Data already in [-1, 1] should be unchanged when max_abs == 1
        let x = array![[-1.0, 0.5], [0.0, 1.0], [0.5, -1.0]];
        let fitted = fit_f64(&x);
        let transformed = fitted.transform(&x).unwrap();

        for (a, b) in x.iter().zip(transformed.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_empty_input() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let scaler = MaxAbsScaler::new();
        let result = FitUnsupervised::<f64>::fit(&scaler, &x);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_rows() {
        // Columns but no samples: nothing to learn a maximum from.
        let x: Array2<f64> = Array2::zeros((0, 3));
        let result = FitUnsupervised::<f64>::fit(&MaxAbsScaler::new(), &x);
        assert!(result.is_err());
    }

    #[test]
    fn test_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = fit_f64(&x);

        let x_wrong = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_wrong).is_err());
        assert!(fitted.inverse_transform(&x_wrong).is_err());
    }
}