1use anofox_ml_core::{FitUnsupervised, Float, InverseTransform, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2, Axis};
3
4#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct MaxAbsScaler;
14
15impl MaxAbsScaler {
16 pub fn new() -> Self {
18 Self
19 }
20}
21
22impl Default for MaxAbsScaler {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
30#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
31pub struct FittedMaxAbsScaler<F: Float> {
32 max_abs: Array1<F>,
33}
34
35impl<F: Float> FitUnsupervised<F> for MaxAbsScaler {
36 type Fitted = FittedMaxAbsScaler<F>;
37
38 fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
39 if x.is_empty() {
40 return Err(RustMlError::EmptyInput("input array is empty".into()));
41 }
42
43 let max_abs = x
44 .mapv(|v| v.abs())
45 .fold_axis(Axis(0), F::zero(), |&a, &b| a.max(b));
46
47 Ok(FittedMaxAbsScaler { max_abs })
48 }
49}
50
51impl<F: Float> Transform<F> for FittedMaxAbsScaler<F> {
52 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
53 if x.ncols() != self.max_abs.len() {
54 return Err(RustMlError::ShapeMismatch(format!(
55 "expected {} features, got {}",
56 self.max_abs.len(),
57 x.ncols()
58 )));
59 }
60
61 let eps = F::from_f64(1e-15).unwrap();
62 let mut result = x.to_owned();
63 for mut row in result.rows_mut() {
64 for (j, val) in row.iter_mut().enumerate() {
65 if self.max_abs[j] > eps {
66 *val = *val / self.max_abs[j];
67 }
68 }
69 }
70 Ok(result)
71 }
72}
73
74impl<F: Float> InverseTransform<F> for FittedMaxAbsScaler<F> {
75 fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
76 if x.ncols() != self.max_abs.len() {
77 return Err(RustMlError::ShapeMismatch(format!(
78 "expected {} features, got {}",
79 self.max_abs.len(),
80 x.ncols()
81 )));
82 }
83
84 let eps = F::from_f64(1e-15).unwrap();
85 let mut result = x.to_owned();
86 for mut row in result.rows_mut() {
87 for (j, val) in row.iter_mut().enumerate() {
88 if self.max_abs[j] > eps {
89 *val = *val * self.max_abs[j];
90 }
91 }
92 }
93 Ok(result)
94 }
95}
96
97impl<F: Float> FittedMaxAbsScaler<F> {
98 pub fn max_abs(&self) -> &Array1<F> {
100 &self.max_abs
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_basic_scaling() {
        let x = array![[1.0, -4.0], [2.0, 2.0], [-3.0, 1.0]];
        let scaler = MaxAbsScaler::new();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(fitted.max_abs()[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.max_abs()[1], 4.0, epsilon = 1e-10);

        assert_abs_diff_eq!(transformed[[0, 0]], 1.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 1]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], -1.0, epsilon = 1e-10);

        // Every scaled training value must land inside [-1, 1].
        for &v in transformed.iter() {
            assert!((-1.0..=1.0).contains(&v), "value {} not in [-1, 1]", v);
        }
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let x = array![[1.0, -4.0], [2.0, 2.0], [-3.0, 1.0]];
        let scaler = MaxAbsScaler::new();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        let recovered = fitted.inverse_transform(&transformed).unwrap();

        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_zero_column() {
        let x = array![[0.0, 2.0], [0.0, -4.0], [0.0, 1.0]];
        let scaler = MaxAbsScaler::new();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], 0.0, epsilon = 1e-10);

        assert_abs_diff_eq!(transformed[[1, 1]], -1.0, epsilon = 1e-10);

        for &v in transformed.iter() {
            assert!(v.is_finite(), "zero column produced non-finite: {}", v);
        }
    }

    #[test]
    fn test_zero_column_roundtrip() {
        // A constant-zero column has scale 0 and must pass through both
        // directions untouched.
        let x = array![[0.0, 2.0], [0.0, -4.0], [0.0, 1.0]];
        let fitted = FitUnsupervised::<f64>::fit(&MaxAbsScaler::new(), &x).unwrap();
        assert_abs_diff_eq!(fitted.max_abs()[0], 0.0, epsilon = 1e-12);

        let recovered = fitted
            .inverse_transform(&fitted.transform(&x).unwrap())
            .unwrap();
        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_f32_support() {
        let x = array![[1.0f32, -4.0], [2.0, 2.0], [-3.0, 1.0]];
        let scaler = MaxAbsScaler::new();
        let fitted = FitUnsupervised::<f32>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        let recovered = fitted.inverse_transform(&transformed).unwrap();

        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_already_scaled() {
        let x = array![[-1.0, 0.5], [0.0, 1.0], [0.5, -1.0]];
        let scaler = MaxAbsScaler::new();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for (a, b) in x.iter().zip(transformed.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_transform_unseen_data_uses_fitted_scale() {
        // New data is scaled by the training factors, not refit, so values
        // may fall outside [-1, 1].
        let x_train = array![[2.0, -5.0], [-1.0, 5.0]];
        let fitted = FitUnsupervised::<f64>::fit(&MaxAbsScaler::new(), &x_train).unwrap();
        let out = fitted.transform(&array![[4.0, 10.0]]).unwrap();

        assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_empty_input() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let scaler = MaxAbsScaler::new();
        let result = FitUnsupervised::<f64>::fit(&scaler, &x);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_rows_with_columns() {
        // Zero rows but a non-zero number of columns is still empty input.
        let x: Array2<f64> = Array2::zeros((0, 3));
        let scaler = MaxAbsScaler::new();
        assert!(FitUnsupervised::<f64>::fit(&scaler, &x).is_err());
    }

    #[test]
    fn test_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let scaler = MaxAbsScaler::new();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();

        let x_wrong = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_wrong).is_err());
        assert!(fitted.inverse_transform(&x_wrong).is_err());
    }
}