1use ferrolearn_core::error::FerroError;
9use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
10use ferrolearn_core::traits::{Fit, FitTransform, Transform};
11use ndarray::{Array1, Array2};
12use num_traits::Float;
13
14fn quantile_sorted<F: Float>(sorted: &[F], q: f64) -> F {
22 let n = sorted.len();
23 if n == 1 {
24 return sorted[0];
25 }
26 let idx = q * (n - 1) as f64;
27 let lo = idx.floor() as usize;
28 let hi = idx.ceil() as usize;
29 if lo == hi {
30 return sorted[lo];
31 }
32 let frac = F::from(idx - lo as f64).unwrap_or(F::zero());
33 sorted[lo] + (sorted[hi] - sorted[lo]) * frac
34}
35
/// An unfitted robust feature scaler.
///
/// Fitting computes, for every column of the input matrix, the median and the
/// interquartile range (75th minus 25th percentile). The scaler itself holds
/// no configuration or state; the type parameter `F` only fixes the float
/// type it operates on.
#[derive(Debug, Clone)]
pub struct RobustScaler<F> {
    /// Zero-sized marker tying the scaler to the float type `F`.
    _marker: std::marker::PhantomData<F>,
}
64
65impl<F: Float + Send + Sync + 'static> RobustScaler<F> {
66 #[must_use]
68 pub fn new() -> Self {
69 Self {
70 _marker: std::marker::PhantomData,
71 }
72 }
73}
74
75impl<F: Float + Send + Sync + 'static> Default for RobustScaler<F> {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81#[derive(Debug, Clone)]
89pub struct FittedRobustScaler<F> {
90 pub(crate) median: Array1<F>,
92 pub(crate) iqr: Array1<F>,
94}
95
96impl<F: Float + Send + Sync + 'static> FittedRobustScaler<F> {
97 #[must_use]
99 pub fn median(&self) -> &Array1<F> {
100 &self.median
101 }
102
103 #[must_use]
105 pub fn iqr(&self) -> &Array1<F> {
106 &self.iqr
107 }
108}
109
110impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for RobustScaler<F> {
115 type Fitted = FittedRobustScaler<F>;
116 type Error = FerroError;
117
118 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedRobustScaler<F>, FerroError> {
124 let n_samples = x.nrows();
125 if n_samples == 0 {
126 return Err(FerroError::InsufficientSamples {
127 required: 1,
128 actual: 0,
129 context: "RobustScaler::fit".into(),
130 });
131 }
132
133 let n_features = x.ncols();
134 let mut median_arr = Array1::zeros(n_features);
135 let mut iqr_arr = Array1::zeros(n_features);
136
137 for j in 0..n_features {
138 let mut col: Vec<F> = x.column(j).iter().copied().collect();
139 col.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
140
141 let med = quantile_sorted(&col, 0.5);
142 let q25 = quantile_sorted(&col, 0.25);
143 let q75 = quantile_sorted(&col, 0.75);
144
145 median_arr[j] = med;
146 iqr_arr[j] = q75 - q25;
147 }
148
149 Ok(FittedRobustScaler {
150 median: median_arr,
151 iqr: iqr_arr,
152 })
153 }
154}
155
156impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedRobustScaler<F> {
157 type Output = Array2<F>;
158 type Error = FerroError;
159
160 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
169 let n_features = self.median.len();
170 if x.ncols() != n_features {
171 return Err(FerroError::ShapeMismatch {
172 expected: vec![x.nrows(), n_features],
173 actual: vec![x.nrows(), x.ncols()],
174 context: "FittedRobustScaler::transform".into(),
175 });
176 }
177
178 let mut out = x.to_owned();
179 for (j, mut col) in out.columns_mut().into_iter().enumerate() {
180 let med = self.median[j];
181 let iqr = self.iqr[j];
182 if iqr == F::zero() {
183 continue;
185 }
186 for v in col.iter_mut() {
187 *v = (*v - med) / iqr;
188 }
189 }
190 Ok(out)
191 }
192}
193
194impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for RobustScaler<F> {
197 type Output = Array2<F>;
198 type Error = FerroError;
199
200 fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
205 Err(FerroError::InvalidParameter {
206 name: "RobustScaler".into(),
207 reason: "scaler must be fitted before calling transform; use fit() first".into(),
208 })
209 }
210}
211
212impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for RobustScaler<F> {
213 type FitError = FerroError;
214
215 fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
221 let fitted = self.fit(x, &())?;
222 fitted.transform(x)
223 }
224}
225
226impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for RobustScaler<F> {
231 fn fit_pipeline(
239 &self,
240 x: &Array2<F>,
241 _y: &Array1<F>,
242 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
243 let fitted = self.fit(x, &())?;
244 Ok(Box::new(fitted))
245 }
246}
247
248impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedRobustScaler<F> {
249 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
255 self.transform(x)
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Median and IQR of 1..=5 are 3 and 2; the median row maps to zero.
    #[test]
    fn test_robust_scaler_basic() {
        let data = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let model = RobustScaler::<f64>::new().fit(&data, &()).unwrap();
        assert_abs_diff_eq!(model.median()[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(model.iqr()[0], 2.0, epsilon = 1e-10);

        let result = model.transform(&data).unwrap();
        assert_abs_diff_eq!(result[[2, 0]], 0.0, epsilon = 1e-10);
    }

    /// A constant column has zero IQR and must pass through unmodified.
    #[test]
    fn test_zero_iqr_column_unchanged() {
        let data = array![[7.0, 1.0], [7.0, 2.0], [7.0, 3.0]];
        let model = RobustScaler::<f64>::new().fit(&data, &()).unwrap();
        assert_abs_diff_eq!(model.iqr()[0], 0.0, epsilon = 1e-15);

        let result = model.transform(&data).unwrap();
        for &value in result.column(0).iter() {
            assert_abs_diff_eq!(value, 7.0, epsilon = 1e-10);
        }
    }

    /// A single extreme value must not move the median.
    #[test]
    fn test_outlier_robustness() {
        let data = array![[1.0], [2.0], [3.0], [4.0], [1000.0]];
        let model = RobustScaler::<f64>::new().fit(&data, &()).unwrap();
        assert_abs_diff_eq!(model.median()[0], 3.0, epsilon = 1e-10);
    }

    /// `fit_transform` must agree element-wise with `fit` followed by `transform`.
    #[test]
    fn test_fit_transform_equivalence() {
        let scaler = RobustScaler::<f64>::new();
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let combined = scaler.fit_transform(&data).unwrap();
        let two_step = scaler.fit(&data, &()).unwrap().transform(&data).unwrap();

        for (a, b) in combined.iter().zip(two_step.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    /// Transforming data with a different column count is rejected.
    #[test]
    fn test_shape_mismatch_error() {
        let train = array![[1.0, 2.0], [3.0, 4.0]];
        let model = RobustScaler::<f64>::new().fit(&train, &()).unwrap();

        let wrong_width = array![[1.0, 2.0, 3.0]];
        assert!(model.transform(&wrong_width).is_err());
    }

    /// Fitting on a matrix with zero rows is rejected.
    #[test]
    fn test_insufficient_samples_error() {
        let no_rows: Array2<f64> = Array2::zeros((0, 3));
        let outcome = RobustScaler::<f64>::new().fit(&no_rows, &());
        assert!(outcome.is_err());
    }
}