// ferrolearn_preprocess/robust_scaler.rs

//! Robust scaler: median- and IQR-based feature scaling.
//!
//! Each feature is transformed as `(x - median) / IQR`, where
//! `IQR = Q75 - Q25`. Because the median and the quartiles are barely affected
//! by extreme values, this scaler is robust to outliers.
//!
//! Columns whose IQR is 0 are left unchanged by the transformation.

8use ferrolearn_core::error::FerroError;
9use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
10use ferrolearn_core::traits::{Fit, FitTransform, Transform};
11use ndarray::{Array1, Array2};
12use num_traits::Float;

// ---------------------------------------------------------------------------
// Helper: quantile of a sorted slice
// ---------------------------------------------------------------------------

18/// Compute the `q`-th quantile (0.0–1.0) of a sorted slice using linear interpolation.
19///
20/// Panics if `sorted` is empty.
21fn quantile_sorted<F: Float>(sorted: &[F], q: f64) -> F {
22    let n = sorted.len();
23    if n == 1 {
24        return sorted[0];
25    }
26    let idx = q * (n - 1) as f64;
27    let lo = idx.floor() as usize;
28    let hi = idx.ceil() as usize;
29    if lo == hi {
30        return sorted[lo];
31    }
32    let frac = F::from(idx - lo as f64).unwrap_or(F::zero());
33    sorted[lo] + (sorted[hi] - sorted[lo]) * frac
34}

// ---------------------------------------------------------------------------
// RobustScaler (unfitted)
// ---------------------------------------------------------------------------

40/// An unfitted robust scaler.
41///
42/// Calling [`Fit::fit`] learns the per-column medians and interquartile ranges
43/// (IQR = Q75 − Q25) and returns a [`FittedRobustScaler`] that can transform
44/// new data.
45///
46/// Columns with IQR = 0 are left unchanged after transformation.
47///
48/// # Examples
49///
50/// ```
51/// use ferrolearn_preprocess::RobustScaler;
52/// use ferrolearn_core::traits::{Fit, Transform};
53/// use ndarray::array;
54///
55/// let scaler = RobustScaler::<f64>::new();
56/// let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [100.0, 40.0]];
57/// let fitted = scaler.fit(&x, &()).unwrap();
58/// let scaled = fitted.transform(&x).unwrap();
59/// ```
60#[derive(Debug, Clone)]
61pub struct RobustScaler<F> {
62    _marker: std::marker::PhantomData<F>,
63}
64
65impl<F: Float + Send + Sync + 'static> RobustScaler<F> {
66    /// Create a new `RobustScaler`.
67    #[must_use]
68    pub fn new() -> Self {
69        Self {
70            _marker: std::marker::PhantomData,
71        }
72    }
73}
74
75impl<F: Float + Send + Sync + 'static> Default for RobustScaler<F> {
76    fn default() -> Self {
77        Self::new()
78    }
79}

// ---------------------------------------------------------------------------
// FittedRobustScaler
// ---------------------------------------------------------------------------

85/// A fitted robust scaler holding per-column medians and IQRs.
86///
87/// Created by calling [`Fit::fit`] on a [`RobustScaler`].
88#[derive(Debug, Clone)]
89pub struct FittedRobustScaler<F> {
90    /// Per-column medians learned during fitting.
91    pub(crate) median: Array1<F>,
92    /// Per-column interquartile ranges (Q75 − Q25) learned during fitting.
93    pub(crate) iqr: Array1<F>,
94}
95
96impl<F: Float + Send + Sync + 'static> FittedRobustScaler<F> {
97    /// Return the per-column medians learned during fitting.
98    #[must_use]
99    pub fn median(&self) -> &Array1<F> {
100        &self.median
101    }
102
103    /// Return the per-column IQR values learned during fitting.
104    #[must_use]
105    pub fn iqr(&self) -> &Array1<F> {
106        &self.iqr
107    }
108}

// ---------------------------------------------------------------------------
// Trait implementations
// ---------------------------------------------------------------------------

114impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for RobustScaler<F> {
115    type Fitted = FittedRobustScaler<F>;
116    type Error = FerroError;
117
118    /// Fit the scaler by computing per-column medians and IQRs.
119    ///
120    /// # Errors
121    ///
122    /// Returns [`FerroError::InsufficientSamples`] if the input has zero rows.
123    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedRobustScaler<F>, FerroError> {
124        let n_samples = x.nrows();
125        if n_samples == 0 {
126            return Err(FerroError::InsufficientSamples {
127                required: 1,
128                actual: 0,
129                context: "RobustScaler::fit".into(),
130            });
131        }
132
133        let n_features = x.ncols();
134        let mut median_arr = Array1::zeros(n_features);
135        let mut iqr_arr = Array1::zeros(n_features);
136
137        for j in 0..n_features {
138            let mut col: Vec<F> = x.column(j).iter().copied().collect();
139            col.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
140
141            let med = quantile_sorted(&col, 0.5);
142            let q25 = quantile_sorted(&col, 0.25);
143            let q75 = quantile_sorted(&col, 0.75);
144
145            median_arr[j] = med;
146            iqr_arr[j] = q75 - q25;
147        }
148
149        Ok(FittedRobustScaler {
150            median: median_arr,
151            iqr: iqr_arr,
152        })
153    }
154}
155
156impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedRobustScaler<F> {
157    type Output = Array2<F>;
158    type Error = FerroError;
159
160    /// Transform data by subtracting the median and dividing by the IQR.
161    ///
162    /// Columns with IQR = 0 are left unchanged.
163    ///
164    /// # Errors
165    ///
166    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
167    /// match the number of features seen during fitting.
168    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
169        let n_features = self.median.len();
170        if x.ncols() != n_features {
171            return Err(FerroError::ShapeMismatch {
172                expected: vec![x.nrows(), n_features],
173                actual: vec![x.nrows(), x.ncols()],
174                context: "FittedRobustScaler::transform".into(),
175            });
176        }
177
178        let mut out = x.to_owned();
179        for (j, mut col) in out.columns_mut().into_iter().enumerate() {
180            let med = self.median[j];
181            let iqr = self.iqr[j];
182            if iqr == F::zero() {
183                // Zero-IQR column: leave unchanged.
184                continue;
185            }
186            for v in col.iter_mut() {
187                *v = (*v - med) / iqr;
188            }
189        }
190        Ok(out)
191    }
192}
193
194/// Implement `Transform` on the unfitted scaler to satisfy the `FitTransform: Transform`
195/// supertrait bound. Calling `transform` on an unfitted scaler always returns an error.
196impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for RobustScaler<F> {
197    type Output = Array2<F>;
198    type Error = FerroError;
199
200    /// Always returns an error — the scaler must be fitted first.
201    ///
202    /// Use [`Fit::fit`] to produce a [`FittedRobustScaler`], then call
203    /// [`Transform::transform`] on that.
204    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
205        Err(FerroError::InvalidParameter {
206            name: "RobustScaler".into(),
207            reason: "scaler must be fitted before calling transform; use fit() first".into(),
208        })
209    }
210}
211
212impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for RobustScaler<F> {
213    type FitError = FerroError;
214
215    /// Fit the scaler on `x` and return the scaled output in one step.
216    ///
217    /// # Errors
218    ///
219    /// Returns an error if fitting fails (e.g., zero rows).
220    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
221        let fitted = self.fit(x, &())?;
222        fitted.transform(x)
223    }
224}

// ---------------------------------------------------------------------------
// Pipeline integration (generic)
// ---------------------------------------------------------------------------

230impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for RobustScaler<F> {
231    /// Fit the scaler using the pipeline interface.
232    ///
233    /// The `y` argument is ignored; it exists only for API compatibility.
234    ///
235    /// # Errors
236    ///
237    /// Propagates errors from [`Fit::fit`].
238    fn fit_pipeline(
239        &self,
240        x: &Array2<F>,
241        _y: &Array1<F>,
242    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
243        let fitted = self.fit(x, &())?;
244        Ok(Box::new(fitted))
245    }
246}
247
248impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedRobustScaler<F> {
249    /// Transform data using the pipeline interface.
250    ///
251    /// # Errors
252    ///
253    /// Propagates errors from [`Transform::transform`].
254    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
255        self.transform(x)
256    }
257}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_robust_scaler_basic() {
        let scaler = RobustScaler::<f64>::new();
        // Symmetric distribution: median = 3, Q25 = 2, Q75 = 4, IQR = 2
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let fitted = scaler.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.median()[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.iqr()[0], 2.0, epsilon = 1e-10);

        let scaled = fitted.transform(&x).unwrap();
        // Median should be 0 after scaling
        assert_abs_diff_eq!(scaled[[2, 0]], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_quantiles_interpolate_between_ranks() {
        let scaler = RobustScaler::<f64>::new();
        // Even-length, unsorted column: every quantile falls between two ranks
        // of the sorted values [1, 2, 3, 4].
        // median = 2.5, Q25 = 1.75, Q75 = 3.25, IQR = 1.5
        let x = array![[4.0], [1.0], [3.0], [2.0]];
        let fitted = scaler.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.median()[0], 2.5, epsilon = 1e-12);
        assert_abs_diff_eq!(fitted.iqr()[0], 1.5, epsilon = 1e-12);

        let scaled = fitted.transform(&x).unwrap();
        // (4 - 2.5) / 1.5 = 1 and (1 - 2.5) / 1.5 = -1
        assert_abs_diff_eq!(scaled[[0, 0]], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(scaled[[1, 0]], -1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_zero_iqr_column_unchanged() {
        let scaler = RobustScaler::<f64>::new();
        // Column 0 is constant: IQR = 0
        let x = array![[7.0, 1.0], [7.0, 2.0], [7.0, 3.0]];
        let fitted = scaler.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.iqr()[0], 0.0, epsilon = 1e-15);
        let scaled = fitted.transform(&x).unwrap();
        // Constant column should remain 7.0
        for i in 0..3 {
            assert_abs_diff_eq!(scaled[[i, 0]], 7.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_outlier_robustness() {
        let scaler = RobustScaler::<f64>::new();
        // Add a large outlier; median should not shift much
        let x = array![[1.0], [2.0], [3.0], [4.0], [1000.0]];
        let fitted = scaler.fit(&x, &()).unwrap();
        // Median of sorted [1,2,3,4,1000] = 3.0
        assert_abs_diff_eq!(fitted.median()[0], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_fit_transform_equivalence() {
        let scaler = RobustScaler::<f64>::new();
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let via_fit_transform = scaler.fit_transform(&x).unwrap();
        let fitted = scaler.fit(&x, &()).unwrap();
        let via_separate = fitted.transform(&x).unwrap();
        for (a, b) in via_fit_transform.iter().zip(via_separate.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_pipeline_matches_direct_transform() {
        let scaler = RobustScaler::<f64>::new();
        let x = array![[1.0, 5.0], [2.0, 3.0], [4.0, 9.0], [8.0, 1.0]];
        let y = Array1::<f64>::zeros(4);
        let via_pipeline = scaler
            .fit_pipeline(&x, &y)
            .unwrap()
            .transform_pipeline(&x)
            .unwrap();
        let direct = scaler.fit(&x, &()).unwrap().transform(&x).unwrap();
        assert_eq!(via_pipeline.dim(), direct.dim());
        for (a, b) in via_pipeline.iter().zip(direct.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_unfitted_transform_is_error() {
        let scaler = RobustScaler::<f64>::new();
        let x = array![[1.0, 2.0]];
        assert!(matches!(
            scaler.transform(&x),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_shape_mismatch_error() {
        let scaler = RobustScaler::<f64>::new();
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = scaler.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(matches!(
            fitted.transform(&x_bad),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_insufficient_samples_error() {
        let scaler = RobustScaler::<f64>::new();
        let x: Array2<f64> = Array2::zeros((0, 3));
        assert!(matches!(
            scaler.fit(&x, &()),
            Err(FerroError::InsufficientSamples { .. })
        ));
    }
}