1use ferrolearn_core::error::FerroError;
10use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
11use ferrolearn_core::traits::{Fit, FitTransform, Transform};
12use ndarray::{Array1, Array2};
13use num_traits::Float;
14
/// Scales each feature by its maximum absolute value.
///
/// Fitting records, for every column, the largest absolute value seen in the
/// training data. Transforming divides each column by that value, so the
/// training data is mapped into `[-1, 1]`. The data is neither shifted nor
/// centred, so zero entries stay zero.
///
/// Columns whose recorded maximum is zero are passed through unchanged.
///
/// This type is only the unfitted configuration. Call [`Fit::fit`] to obtain
/// a [`FittedMaxAbsScaler`].
#[derive(Clone, Debug)]
pub struct MaxAbsScaler<F> {
    /// Binds the scaler to its float type; no data is stored before fitting.
    _marker: std::marker::PhantomData<F>,
}
44
45impl<F: Float + Send + Sync + 'static> MaxAbsScaler<F> {
46 #[must_use]
48 pub fn new() -> Self {
49 Self {
50 _marker: std::marker::PhantomData,
51 }
52 }
53}
54
55impl<F: Float + Send + Sync + 'static> Default for MaxAbsScaler<F> {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61#[derive(Debug, Clone)]
69pub struct FittedMaxAbsScaler<F> {
70 pub(crate) max_abs: Array1<F>,
72}
73
74impl<F: Float + Send + Sync + 'static> FittedMaxAbsScaler<F> {
75 #[must_use]
77 pub fn max_abs(&self) -> &Array1<F> {
78 &self.max_abs
79 }
80
81 pub fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
90 let n_features = self.max_abs.len();
91 if x.ncols() != n_features {
92 return Err(FerroError::ShapeMismatch {
93 expected: vec![x.nrows(), n_features],
94 actual: vec![x.nrows(), x.ncols()],
95 context: "FittedMaxAbsScaler::inverse_transform".into(),
96 });
97 }
98 let mut out = x.to_owned();
99 for (j, mut col) in out.columns_mut().into_iter().enumerate() {
100 let ma = self.max_abs[j];
101 if ma == F::zero() {
102 continue;
103 }
104 for v in col.iter_mut() {
105 *v = *v * ma;
106 }
107 }
108 Ok(out)
109 }
110}
111
112impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for MaxAbsScaler<F> {
117 type Fitted = FittedMaxAbsScaler<F>;
118 type Error = FerroError;
119
120 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedMaxAbsScaler<F>, FerroError> {
126 let n_samples = x.nrows();
127 if n_samples == 0 {
128 return Err(FerroError::InsufficientSamples {
129 required: 1,
130 actual: 0,
131 context: "MaxAbsScaler::fit".into(),
132 });
133 }
134
135 let n_features = x.ncols();
136 let mut max_abs = Array1::zeros(n_features);
137
138 for j in 0..n_features {
139 let col_max_abs = x
140 .column(j)
141 .iter()
142 .copied()
143 .map(|v| v.abs())
144 .fold(F::zero(), |acc, v| if v > acc { v } else { acc });
145 max_abs[j] = col_max_abs;
146 }
147
148 Ok(FittedMaxAbsScaler { max_abs })
149 }
150}
151
152impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedMaxAbsScaler<F> {
153 type Output = Array2<F>;
154 type Error = FerroError;
155
156 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
165 let n_features = self.max_abs.len();
166 if x.ncols() != n_features {
167 return Err(FerroError::ShapeMismatch {
168 expected: vec![x.nrows(), n_features],
169 actual: vec![x.nrows(), x.ncols()],
170 context: "FittedMaxAbsScaler::transform".into(),
171 });
172 }
173
174 let mut out = x.to_owned();
175 for (j, mut col) in out.columns_mut().into_iter().enumerate() {
176 let ma = self.max_abs[j];
177 if ma == F::zero() {
178 continue;
180 }
181 for v in col.iter_mut() {
182 *v = *v / ma;
183 }
184 }
185 Ok(out)
186 }
187}
188
189impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for MaxAbsScaler<F> {
192 type Output = Array2<F>;
193 type Error = FerroError;
194
195 fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
200 Err(FerroError::InvalidParameter {
201 name: "MaxAbsScaler".into(),
202 reason: "scaler must be fitted before calling transform; use fit() first".into(),
203 })
204 }
205}
206
207impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for MaxAbsScaler<F> {
208 type FitError = FerroError;
209
210 fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
216 let fitted = self.fit(x, &())?;
217 fitted.transform(x)
218 }
219}
220
221impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for MaxAbsScaler<F> {
226 fn fit_pipeline(
234 &self,
235 x: &Array2<F>,
236 _y: &Array1<F>,
237 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
238 let fitted = self.fit(x, &())?;
239 Ok(Box::new(fitted))
240 }
241}
242
243impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedMaxAbsScaler<F> {
244 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
250 self.transform(x)
251 }
252}
253
#[cfg(test)]
mod tests {
    // Unit tests for `MaxAbsScaler` / `FittedMaxAbsScaler`: fitted statistics,
    // output range, zero-column passthrough, inverse round-trips, NaN handling,
    // error paths, pipeline integration and f32 support.
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_max_abs_scaler_basic() {
        let scaler = MaxAbsScaler::<f64>::new();
        let x = array![[-3.0, 1.0], [0.0, -2.0], [2.0, 4.0]];
        let fitted = scaler.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.max_abs()[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.max_abs()[1], 4.0, epsilon = 1e-10);

        let scaled = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(scaled[[0, 0]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[1, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[2, 0]], 2.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[2, 1]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_values_in_range() {
        let scaler = MaxAbsScaler::<f64>::new();
        let x = array![[-10.0, 5.0], [3.0, -8.0], [7.0, 2.0]];
        let fitted = scaler.fit(&x, &()).unwrap();
        let scaled = fitted.transform(&x).unwrap();
        for v in scaled.iter() {
            assert!(
                *v >= -1.0 - 1e-10 && *v <= 1.0 + 1e-10,
                "value {v} out of [-1, 1]"
            );
        }
    }

    #[test]
    fn test_zero_column_unchanged() {
        let scaler = MaxAbsScaler::<f64>::new();
        let x = array![[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]];
        let fitted = scaler.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.max_abs()[0], 0.0, epsilon = 1e-15);
        let scaled = fitted.transform(&x).unwrap();
        for i in 0..3 {
            // An all-zero column has max_abs == 0 and must pass through as-is.
            assert_abs_diff_eq!(scaled[[i, 0]], 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let scaler = MaxAbsScaler::<f64>::new();
        let x = array![[-3.0, 1.0], [0.0, -2.0], [2.0, 4.0]];
        let fitted = scaler.fit(&x, &()).unwrap();
        let scaled = fitted.transform(&x).unwrap();
        let recovered = fitted.inverse_transform(&scaled).unwrap();
        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_zero_column_inverse_roundtrip() {
        let scaler = MaxAbsScaler::<f64>::new();
        let x = array![[0.0, -6.0], [0.0, 3.0]];
        let fitted = scaler.fit(&x, &()).unwrap();
        let scaled = fitted.transform(&x).unwrap();
        let recovered = fitted.inverse_transform(&scaled).unwrap();
        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_fit_transform_equivalence() {
        let scaler = MaxAbsScaler::<f64>::new();
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let via_fit_transform = scaler.fit_transform(&x).unwrap();
        let fitted = scaler.fit(&x, &()).unwrap();
        let via_separate = fitted.transform(&x).unwrap();
        for (a, b) in via_fit_transform.iter().zip(via_separate.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_shape_mismatch_error() {
        let scaler = MaxAbsScaler::<f64>::new();
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = scaler.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_inverse_transform_shape_mismatch_error() {
        let scaler = MaxAbsScaler::<f64>::new();
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = scaler.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.inverse_transform(&x_bad).is_err());
    }

    #[test]
    fn test_insufficient_samples_error() {
        let scaler = MaxAbsScaler::<f64>::new();
        let x: Array2<f64> = Array2::zeros((0, 3));
        assert!(scaler.fit(&x, &()).is_err());
    }

    #[test]
    fn test_zero_features() {
        let scaler = MaxAbsScaler::<f64>::new();
        let x: Array2<f64> = Array2::zeros((2, 0));
        let fitted = scaler.fit(&x, &()).unwrap();
        assert_eq!(fitted.max_abs().len(), 0);
        let scaled = fitted.transform(&x).unwrap();
        assert_eq!(scaled.dim(), (2, 0));
    }

    #[test]
    fn test_nan_ignored_during_fit() {
        let scaler = MaxAbsScaler::<f64>::new();
        let x = array![[f64::NAN, 1.0], [-2.0, -4.0]];
        let fitted = scaler.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.max_abs()[0], 2.0, epsilon = 1e-10);
        let scaled = fitted.transform(&x).unwrap();
        assert!(scaled[[0, 0]].is_nan());
        assert_abs_diff_eq!(scaled[[1, 0]], -1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_unfitted_transform_error() {
        let scaler = MaxAbsScaler::<f64>::new();
        let x = array![[1.0, 2.0]];
        assert!(scaler.transform(&x).is_err());
    }

    #[test]
    fn test_default_matches_new() {
        let x = array![[1.0, -8.0], [-3.0, 2.0]];
        let a = MaxAbsScaler::<f64>::default().fit(&x, &()).unwrap();
        let b = MaxAbsScaler::<f64>::new().fit(&x, &()).unwrap();
        assert_eq!(a.max_abs(), b.max_abs());
    }

    #[test]
    fn test_negative_values() {
        let scaler = MaxAbsScaler::<f64>::new();
        let x = array![[-5.0], [-3.0], [-1.0]];
        let fitted = scaler.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.max_abs()[0], 5.0, epsilon = 1e-10);
        let scaled = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(scaled[[0, 0]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[1, 0]], -0.6, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[2, 0]], -0.2, epsilon = 1e-10);
    }

    #[test]
    fn test_pipeline_integration() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let scaler = MaxAbsScaler::<f64>::new();
        let x = array![[2.0, 4.0], [1.0, -2.0]];
        let y = Array1::zeros(2);
        let fitted = scaler.fit_pipeline(&x, &y).unwrap();
        let result = fitted.transform_pipeline(&x).unwrap();
        assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[1, 1]], -0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_f32_scaler() {
        let scaler = MaxAbsScaler::<f32>::new();
        let x: Array2<f32> = array![[2.0f32, -4.0], [1.0, 3.0]];
        let fitted = scaler.fit(&x, &()).unwrap();
        let scaled = fitted.transform(&x).unwrap();
        assert!((scaled[[0, 0]] - 1.0f32).abs() < 1e-6);
        assert!((scaled[[0, 1]] - (-1.0f32)).abs() < 1e-6);
    }
}