Skip to main content

anofox_ml_preprocessing/
robust_scaler.rs

1use anofox_ml_core::{FitUnsupervised, Float, InverseTransform, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2};
3
4/// Parameters for RobustScaler (unfitted state).
5///
6/// Scales features using statistics that are robust to outliers.
7/// Uses the median and interquartile range (IQR = Q3 - Q1) instead of
8/// mean and standard deviation:
9/// `z = (x - median) / IQR`
10#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
11pub struct RobustScaler {
12    /// If true, center data by subtracting the median.
13    pub with_centering: bool,
14    /// If true, scale data by dividing by the IQR.
15    pub with_scaling: bool,
16}
17
18impl RobustScaler {
19    /// Create a new `RobustScaler` with defaults (both centering and scaling enabled).
20    pub fn new() -> Self {
21        Self {
22            with_centering: true,
23            with_scaling: true,
24        }
25    }
26
27    /// Set whether to center data by subtracting the median.
28    pub fn with_centering(mut self, with_centering: bool) -> Self {
29        self.with_centering = with_centering;
30        self
31    }
32
33    /// Set whether to scale data by dividing by the IQR.
34    pub fn with_scaling(mut self, with_scaling: bool) -> Self {
35        self.with_scaling = with_scaling;
36        self
37    }
38}
39
40impl Default for RobustScaler {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46/// Fitted RobustScaler — holds learned median and IQR per feature.
47#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
48#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
49pub struct FittedRobustScaler<F: Float> {
50    median: Array1<F>,
51    iqr: Array1<F>,
52    with_centering: bool,
53    with_scaling: bool,
54}
55
56/// Compute a percentile value using linear interpolation.
57///
58/// `sorted` must be a sorted slice of values and `p` must be in [0, 1].
59fn percentile<F: Float>(sorted: &[F], p: f64) -> F {
60    let n = sorted.len();
61    if n == 1 {
62        return sorted[0];
63    }
64    let idx = p * (n - 1) as f64;
65    let lo = idx.floor() as usize;
66    let hi = idx.ceil() as usize;
67    if lo == hi {
68        sorted[lo]
69    } else {
70        let frac = F::from_f64(idx - lo as f64).unwrap();
71        sorted[lo] * (F::one() - frac) + sorted[hi] * frac
72    }
73}
74
75impl<F: Float> FitUnsupervised<F> for RobustScaler {
76    type Fitted = FittedRobustScaler<F>;
77
78    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
79        if x.is_empty() {
80            return Err(RustMlError::EmptyInput("input array is empty".into()));
81        }
82
83        let ncols = x.ncols();
84        let mut median = Array1::<F>::zeros(ncols);
85        let mut iqr = Array1::<F>::ones(ncols);
86
87        for j in 0..ncols {
88            let col = x.column(j);
89            let mut sorted: Vec<F> = col.to_vec();
90            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
91
92            median[j] = percentile(&sorted, 0.5);
93
94            if self.with_scaling {
95                let q1 = percentile(&sorted, 0.25);
96                let q3 = percentile(&sorted, 0.75);
97                iqr[j] = q3 - q1;
98            }
99        }
100
101        Ok(FittedRobustScaler {
102            median,
103            iqr,
104            with_centering: self.with_centering,
105            with_scaling: self.with_scaling,
106        })
107    }
108}
109
110impl<F: Float> Transform<F> for FittedRobustScaler<F> {
111    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
112        if x.ncols() != self.median.len() {
113            return Err(RustMlError::ShapeMismatch(format!(
114                "expected {} features, got {}",
115                self.median.len(),
116                x.ncols()
117            )));
118        }
119
120        let mut result = x.to_owned();
121        for mut row in result.rows_mut() {
122            for (j, val) in row.iter_mut().enumerate() {
123                if self.with_centering {
124                    *val -= self.median[j];
125                }
126                if self.with_scaling && self.iqr[j] > F::from_f64(1e-15).unwrap() {
127                    *val /= self.iqr[j];
128                }
129            }
130        }
131        Ok(result)
132    }
133}
134
135impl<F: Float> InverseTransform<F> for FittedRobustScaler<F> {
136    fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
137        if x.ncols() != self.median.len() {
138            return Err(RustMlError::ShapeMismatch(format!(
139                "expected {} features, got {}",
140                self.median.len(),
141                x.ncols()
142            )));
143        }
144
145        let mut result = x.to_owned();
146        for mut row in result.rows_mut() {
147            for (j, val) in row.iter_mut().enumerate() {
148                if self.with_scaling && self.iqr[j] > F::from_f64(1e-15).unwrap() {
149                    *val *= self.iqr[j];
150                }
151                if self.with_centering {
152                    *val += self.median[j];
153                }
154            }
155        }
156        Ok(result)
157    }
158}
159
160impl<F: Float> FittedRobustScaler<F> {
161    /// Return the median per feature.
162    pub fn median(&self) -> &Array1<F> {
163        &self.median
164    }
165
166    /// Return the IQR per feature.
167    pub fn iqr(&self) -> &Array1<F> {
168        &self.iqr
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fit `scaler` on `data` as an `f64` model, panicking if fitting fails.
    fn fit_f64(scaler: &RobustScaler, data: &Array2<f64>) -> FittedRobustScaler<f64> {
        FitUnsupervised::<f64>::fit(scaler, data).unwrap()
    }

    #[test]
    fn test_fit_transform() {
        let data = array![
            [1.0, 10.0],
            [2.0, 20.0],
            [3.0, 30.0],
            [4.0, 40.0],
            [5.0, 50.0]
        ];
        let model = fit_f64(&RobustScaler::default(), &data);
        let scaled = model.transform(&data).unwrap();

        // First column is [1,2,3,4,5]: median 3, Q1 2, Q3 4, hence IQR 2.
        assert_abs_diff_eq!(model.median()[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(model.iqr()[0], 2.0, epsilon = 1e-10);

        // (3 - 3)/2 = 0, (1 - 3)/2 = -1, (5 - 3)/2 = 1
        assert_abs_diff_eq!(scaled[[2, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[0, 0]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[4, 0]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let model = fit_f64(&RobustScaler::default(), &data);
        let scaled = model.transform(&data).unwrap();
        let restored = model.inverse_transform(&scaled).unwrap();

        for (original, back) in data.iter().zip(restored.iter()) {
            assert_abs_diff_eq!(original, back, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_without_centering() {
        let data = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let config = RobustScaler::new().with_centering(false);
        let model = fit_f64(&config, &data);
        let scaled = model.transform(&data).unwrap();

        // No centering, so values are only divided by IQR = 2:
        // 1/2 = 0.5 and 3/2 = 1.5.
        assert_abs_diff_eq!(scaled[[0, 0]], 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[2, 0]], 1.5, epsilon = 1e-10);
    }

    #[test]
    fn test_without_scaling() {
        let data = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let config = RobustScaler::new().with_scaling(false);
        let model = fit_f64(&config, &data);
        let scaled = model.transform(&data).unwrap();

        // No scaling, so values are only shifted by the median 3:
        // 1 - 3 = -2, 3 - 3 = 0, 5 - 3 = 2.
        assert_abs_diff_eq!(scaled[[0, 0]], -2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[2, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[4, 0]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_constant_column() {
        let data = array![[5.0, 1.0], [5.0, 2.0], [5.0, 3.0], [5.0, 4.0]];
        let model = fit_f64(&RobustScaler::default(), &data);
        let scaled = model.transform(&data).unwrap();

        // A zero IQR must not turn into a division by zero.
        for v in scaled.iter().copied() {
            assert!(v.is_finite(), "constant column produced non-finite: {}", v);
        }
    }

    #[test]
    fn test_empty_input() {
        let empty = Array2::<f64>::zeros((0, 0));
        let outcome = FitUnsupervised::<f64>::fit(&RobustScaler::default(), &empty);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_shape_mismatch() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let model = fit_f64(&RobustScaler::default(), &data);

        // Three columns against a model fitted on two.
        let too_wide = array![[1.0, 2.0, 3.0]];
        assert!(model.transform(&too_wide).is_err());
        assert!(model.inverse_transform(&too_wide).is_err());
    }

    #[test]
    fn test_even_number_of_rows() {
        // With an even row count the median is interpolated between the two
        // middle values: (2 + 3) / 2 = 2.5.
        let data = array![[1.0], [2.0], [3.0], [4.0]];
        let model = fit_f64(&RobustScaler::default(), &data);
        assert_abs_diff_eq!(model.median()[0], 2.5, epsilon = 1e-10);
    }

    #[test]
    fn test_large_values() {
        let data = array![[1e10], [2e10], [3e10], [4e10], [5e10]];
        let model = fit_f64(&RobustScaler::default(), &data);
        let scaled = model.transform(&data).unwrap();

        for v in scaled.iter().copied() {
            assert!(v.is_finite(), "large values produced non-finite: {}", v);
        }
    }

    #[test]
    fn test_single_row() {
        let data = array![[1.0, 2.0]];
        let model = fit_f64(&RobustScaler::default(), &data);
        let scaled = model.transform(&data).unwrap();

        // One row: the median is the value itself and the IQR is 0, so
        // centering yields 0 and scaling is skipped.
        assert_abs_diff_eq!(scaled[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[0, 1]], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_f32() {
        let data = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let model = FitUnsupervised::<f32>::fit(&RobustScaler::default(), &data).unwrap();
        let scaled = model.transform(&data).unwrap();
        let restored = model.inverse_transform(&scaled).unwrap();

        for (original, back) in data.iter().zip(restored.iter()) {
            assert_abs_diff_eq!(original, back, epsilon = 1e-5);
        }
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;

        /// Deterministic pseudo-random `rows x cols` matrix with entries in
        /// [-10, 10], derived by hashing `seed` together with the flat index.
        fn make_data(rows: usize, cols: usize, seed: u64) -> Array2<f64> {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};

            let values: Vec<f64> = (0..(rows * cols))
                .map(|i| {
                    let mut hasher = DefaultHasher::new();
                    seed.hash(&mut hasher);
                    (i as u64).hash(&mut hasher);
                    let bits = hasher.finish();
                    (bits as f64 / u64::MAX as f64) * 20.0 - 10.0
                })
                .collect();
            Array2::from_shape_vec((rows, cols), values).unwrap()
        }

        proptest! {
            #[test]
            fn robust_scaler_roundtrip(
                rows in 2..50usize,
                cols in 1..10usize,
                seed in 0u64..10000,
            ) {
                let data = make_data(rows, cols, seed);
                let config = RobustScaler::default();
                let model = FitUnsupervised::<f64>::fit(&config, &data).unwrap();
                let scaled = model.transform(&data).unwrap();
                let restored = model.inverse_transform(&scaled).unwrap();

                for (a, b) in data.iter().zip(restored.iter()) {
                    prop_assert!((a - b).abs() < 1e-8,
                        "roundtrip failed: original={}, recovered={}", a, b);
                }
            }

            #[test]
            fn robust_scaler_median_zero(
                rows in 4..50usize,
                cols in 1..10usize,
                seed in 0u64..10000,
            ) {
                let data = make_data(rows, cols, seed);
                let config = RobustScaler::default();
                let model = FitUnsupervised::<f64>::fit(&config, &data).unwrap();
                let scaled = model.transform(&data).unwrap();

                // Centering should leave every column with a median of ~0.
                for col_idx in 0..cols {
                    let mut ordered: Vec<f64> = scaled.column(col_idx).to_vec();
                    ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
                    let median = super::super::percentile(&ordered, 0.5);
                    prop_assert!(median.abs() < 1e-8,
                        "column {} median should be ~0, got {}", col_idx, median);
                }
            }
        }
    }
}