1use anofox_ml_core::{FitUnsupervised, Float, InverseTransform, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2};
3
/// Configuration for a scaler that centers each feature on its median and
/// scales it by its interquartile range (IQR).
///
/// Both steps are enabled by [`RobustScaler::new`] and can be toggled
/// individually through the builder methods of the same names.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RobustScaler {
    /// When `true`, the fitted scaler subtracts each column's median.
    pub with_centering: bool,
    /// When `true`, the fitted scaler divides each column by its IQR.
    pub with_scaling: bool,
}
17
18impl RobustScaler {
19 pub fn new() -> Self {
21 Self {
22 with_centering: true,
23 with_scaling: true,
24 }
25 }
26
27 pub fn with_centering(mut self, with_centering: bool) -> Self {
29 self.with_centering = with_centering;
30 self
31 }
32
33 pub fn with_scaling(mut self, with_scaling: bool) -> Self {
35 self.with_scaling = with_scaling;
36 self
37 }
38}
39
40impl Default for RobustScaler {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
/// Per-column statistics produced by fitting a [`RobustScaler`], together with
/// the flags that decide which of them are applied when transforming data.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
// Explicit deserialize bound: `F` must implement `serde::de::DeserializeOwned`.
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedRobustScaler<F: Float> {
    /// Median of each input column, one entry per feature.
    median: Array1<F>,
    /// Interquartile range (Q3 - Q1) of each input column; every entry stays
    /// at one when the scaler was fitted with scaling disabled.
    iqr: Array1<F>,
    /// Whether `transform` subtracts the median (and `inverse_transform` adds it back).
    with_centering: bool,
    /// Whether `transform` divides by the IQR (and `inverse_transform` multiplies by it).
    with_scaling: bool,
}
55
56fn percentile<F: Float>(sorted: &[F], p: f64) -> F {
60 let n = sorted.len();
61 if n == 1 {
62 return sorted[0];
63 }
64 let idx = p * (n - 1) as f64;
65 let lo = idx.floor() as usize;
66 let hi = idx.ceil() as usize;
67 if lo == hi {
68 sorted[lo]
69 } else {
70 let frac = F::from_f64(idx - lo as f64).unwrap();
71 sorted[lo] * (F::one() - frac) + sorted[hi] * frac
72 }
73}
74
75impl<F: Float> FitUnsupervised<F> for RobustScaler {
76 type Fitted = FittedRobustScaler<F>;
77
78 fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
79 if x.is_empty() {
80 return Err(RustMlError::EmptyInput("input array is empty".into()));
81 }
82
83 let ncols = x.ncols();
84 let mut median = Array1::<F>::zeros(ncols);
85 let mut iqr = Array1::<F>::ones(ncols);
86
87 for j in 0..ncols {
88 let col = x.column(j);
89 let mut sorted: Vec<F> = col.to_vec();
90 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
91
92 median[j] = percentile(&sorted, 0.5);
93
94 if self.with_scaling {
95 let q1 = percentile(&sorted, 0.25);
96 let q3 = percentile(&sorted, 0.75);
97 iqr[j] = q3 - q1;
98 }
99 }
100
101 Ok(FittedRobustScaler {
102 median,
103 iqr,
104 with_centering: self.with_centering,
105 with_scaling: self.with_scaling,
106 })
107 }
108}
109
110impl<F: Float> Transform<F> for FittedRobustScaler<F> {
111 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
112 if x.ncols() != self.median.len() {
113 return Err(RustMlError::ShapeMismatch(format!(
114 "expected {} features, got {}",
115 self.median.len(),
116 x.ncols()
117 )));
118 }
119
120 let mut result = x.to_owned();
121 for mut row in result.rows_mut() {
122 for (j, val) in row.iter_mut().enumerate() {
123 if self.with_centering {
124 *val -= self.median[j];
125 }
126 if self.with_scaling && self.iqr[j] > F::from_f64(1e-15).unwrap() {
127 *val /= self.iqr[j];
128 }
129 }
130 }
131 Ok(result)
132 }
133}
134
135impl<F: Float> InverseTransform<F> for FittedRobustScaler<F> {
136 fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
137 if x.ncols() != self.median.len() {
138 return Err(RustMlError::ShapeMismatch(format!(
139 "expected {} features, got {}",
140 self.median.len(),
141 x.ncols()
142 )));
143 }
144
145 let mut result = x.to_owned();
146 for mut row in result.rows_mut() {
147 for (j, val) in row.iter_mut().enumerate() {
148 if self.with_scaling && self.iqr[j] > F::from_f64(1e-15).unwrap() {
149 *val *= self.iqr[j];
150 }
151 if self.with_centering {
152 *val += self.median[j];
153 }
154 }
155 }
156 Ok(result)
157 }
158}
159
impl<F: Float> FittedRobustScaler<F> {
    /// Returns the per-column medians stored in this fitted scaler.
    pub fn median(&self) -> &Array1<F> {
        &self.median
    }

    /// Returns the per-column interquartile ranges stored in this fitted scaler.
    ///
    /// Every entry is one when the scaler was fitted with scaling disabled.
    pub fn iqr(&self) -> &Array1<F> {
        &self.iqr
    }
}
171
// Unit and property tests for `RobustScaler` and `FittedRobustScaler`.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Default configuration: values are centered on the median and divided by the IQR.
    #[test]
    fn test_fit_transform() {
        let x = array![
            [1.0, 10.0],
            [2.0, 20.0],
            [3.0, 30.0],
            [4.0, 40.0],
            [5.0, 50.0]
        ];
        let scaler = RobustScaler::default();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // Column 0 is 1..=5: median 3, Q1 = 2, Q3 = 4, so IQR = 2.
        assert_abs_diff_eq!(fitted.median()[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.iqr()[0], 2.0, epsilon = 1e-10);
        // Transformed values are (x - 3) / 2.
        assert_abs_diff_eq!(transformed[[2, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 0]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[4, 0]], 1.0, epsilon = 1e-10);
    }

    /// `inverse_transform` applied to `transform` output recovers the input.
    #[test]
    fn test_inverse_transform_roundtrip() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let scaler = RobustScaler::default();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        let recovered = fitted.inverse_transform(&transformed).unwrap();

        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    /// With centering disabled the values are only divided by the IQR.
    #[test]
    fn test_without_centering() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let scaler = RobustScaler::new().with_centering(false);
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // IQR is 2, so 1 -> 0.5 and 3 -> 1.5.
        assert_abs_diff_eq!(transformed[[0, 0]], 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], 1.5, epsilon = 1e-10);
    }

    /// With scaling disabled the values are only shifted by the median.
    #[test]
    fn test_without_scaling() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let scaler = RobustScaler::new().with_scaling(false);
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // Median is 3, so the outputs are x - 3.
        assert_abs_diff_eq!(transformed[[0, 0]], -2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[4, 0]], 2.0, epsilon = 1e-10);
    }

    /// A constant column has an IQR of zero; the output must still be finite.
    #[test]
    fn test_constant_column() {
        let x = array![[5.0, 1.0], [5.0, 2.0], [5.0, 3.0], [5.0, 4.0]];
        let scaler = RobustScaler::default();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(v.is_finite(), "constant column produced non-finite: {}", v);
        }
    }

    /// Fitting a 0x0 array is rejected with an error.
    #[test]
    fn test_empty_input() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let scaler = RobustScaler::default();
        let result = FitUnsupervised::<f64>::fit(&scaler, &x);
        assert!(result.is_err());
    }

    /// Data with a different column count than the fitted data is rejected
    /// by both `transform` and `inverse_transform`.
    #[test]
    fn test_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let scaler = RobustScaler::default();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();

        // Fitted on two features, fed three.
        let x_wrong = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_wrong).is_err());
        assert!(fitted.inverse_transform(&x_wrong).is_err());
    }

    /// With an even number of rows the median is interpolated between the
    /// two middle values.
    #[test]
    fn test_even_number_of_rows() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let scaler = RobustScaler::default();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        // Halfway between 2 and 3.
        assert_abs_diff_eq!(fitted.median()[0], 2.5, epsilon = 1e-10);
    }

    /// Large-magnitude inputs still produce finite outputs.
    #[test]
    fn test_large_values() {
        let x = array![[1e10], [2e10], [3e10], [4e10], [5e10]];
        let scaler = RobustScaler::default();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(v.is_finite(), "large values produced non-finite: {}", v);
        }
    }

    /// A single row is its own median and has an IQR of zero, so every value
    /// is centered to zero and left unscaled.
    #[test]
    fn test_single_row() {
        let x = array![[1.0, 2.0]];
        let scaler = RobustScaler::default();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 1]], 0.0, epsilon = 1e-10);
    }

    /// The round trip also holds for `f32`, checked with a looser tolerance.
    #[test]
    fn test_f32() {
        let x = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let scaler = RobustScaler::default();
        let fitted = FitUnsupervised::<f32>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        let recovered = fitted.inverse_transform(&transformed).unwrap();

        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-5);
        }
    }

    // Property-based tests over generated matrices.
    mod prop_tests {
        use super::*;
        use proptest::prelude::*;

        /// Builds a `rows` x `cols` matrix whose values lie in `[-10, 10]`,
        /// each derived by hashing the pair `(seed, element index)`.
        fn make_data(rows: usize, cols: usize, seed: u64) -> Array2<f64> {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};
            let mut values = Vec::with_capacity(rows * cols);
            for i in 0..(rows * cols) {
                let mut h = DefaultHasher::new();
                seed.hash(&mut h);
                (i as u64).hash(&mut h);
                let bits = h.finish();
                // Map the 64-bit hash onto [0, 1], then stretch to [-10, 10].
                let v = (bits as f64 / u64::MAX as f64) * 20.0 - 10.0;
                values.push(v);
            }
            Array2::from_shape_vec((rows, cols), values).unwrap()
        }

        proptest! {
            /// `inverse_transform(transform(x))` reproduces `x` within 1e-8.
            #[test]
            fn robust_scaler_roundtrip(
                rows in 2..50usize,
                cols in 1..10usize,
                seed in 0u64..10000,
            ) {
                let x = make_data(rows, cols, seed);
                let scaler = RobustScaler::default();
                let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
                let transformed = fitted.transform(&x).unwrap();
                let recovered = fitted.inverse_transform(&transformed).unwrap();

                for (a, b) in x.iter().zip(recovered.iter()) {
                    prop_assert!((a - b).abs() < 1e-8,
                        "roundtrip failed: original={}, recovered={}", a, b);
                }
            }

            /// After transforming, the median of every column is approximately zero.
            #[test]
            fn robust_scaler_median_zero(
                rows in 4..50usize,
                cols in 1..10usize,
                seed in 0u64..10000,
            ) {
                let x = make_data(rows, cols, seed);
                let scaler = RobustScaler::default();
                let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
                let transformed = fitted.transform(&x).unwrap();

                for col_idx in 0..cols {
                    // Recompute the median of the transformed column with the same helper `fit` uses.
                    let col = transformed.column(col_idx);
                    let mut sorted: Vec<f64> = col.to_vec();
                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
                    let median = super::super::percentile(&sorted, 0.5);
                    prop_assert!(median.abs() < 1e-8,
                        "column {} median should be ~0, got {}", col_idx, median);
                }
            }
        }
    }
}