use ferrolearn_core::error::FerroError;
use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
use ferrolearn_core::traits::{Fit, FitTransform, Transform};
use ndarray::{Array1, Array2};
use num_traits::Float;
fn quantile_sorted<F: Float>(sorted: &[F], q: f64) -> F {
let n = sorted.len();
if n == 1 {
return sorted[0];
}
let idx = q * (n - 1) as f64;
let lo = idx.floor() as usize;
let hi = idx.ceil() as usize;
if lo == hi {
return sorted[lo];
}
let frac = F::from(idx - lo as f64).unwrap_or_else(F::zero);
sorted[lo] + (sorted[hi] - sorted[lo]) * frac
}
#[derive(Debug, Clone)]
pub struct RobustScaler<F> {
_marker: std::marker::PhantomData<F>,
}
impl<F: Float + Send + Sync + 'static> RobustScaler<F> {
#[must_use]
pub fn new() -> Self {
Self {
_marker: std::marker::PhantomData,
}
}
}
impl<F: Float + Send + Sync + 'static> Default for RobustScaler<F> {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct FittedRobustScaler<F> {
pub(crate) median: Array1<F>,
pub(crate) iqr: Array1<F>,
}
impl<F: Float + Send + Sync + 'static> FittedRobustScaler<F> {
#[must_use]
pub fn median(&self) -> &Array1<F> {
&self.median
}
#[must_use]
pub fn iqr(&self) -> &Array1<F> {
&self.iqr
}
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for RobustScaler<F> {
type Fitted = FittedRobustScaler<F>;
type Error = FerroError;
fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedRobustScaler<F>, FerroError> {
let n_samples = x.nrows();
if n_samples == 0 {
return Err(FerroError::InsufficientSamples {
required: 1,
actual: 0,
context: "RobustScaler::fit".into(),
});
}
let n_features = x.ncols();
let mut median_arr = Array1::zeros(n_features);
let mut iqr_arr = Array1::zeros(n_features);
for j in 0..n_features {
let mut col: Vec<F> = x.column(j).iter().copied().collect();
col.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let med = quantile_sorted(&col, 0.5);
let q25 = quantile_sorted(&col, 0.25);
let q75 = quantile_sorted(&col, 0.75);
median_arr[j] = med;
iqr_arr[j] = q75 - q25;
}
Ok(FittedRobustScaler {
median: median_arr,
iqr: iqr_arr,
})
}
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedRobustScaler<F> {
type Output = Array2<F>;
type Error = FerroError;
fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
let n_features = self.median.len();
if x.ncols() != n_features {
return Err(FerroError::ShapeMismatch {
expected: vec![x.nrows(), n_features],
actual: vec![x.nrows(), x.ncols()],
context: "FittedRobustScaler::transform".into(),
});
}
let mut out = x.to_owned();
for (j, mut col) in out.columns_mut().into_iter().enumerate() {
let med = self.median[j];
let iqr = self.iqr[j];
if iqr == F::zero() {
continue;
}
for v in &mut col {
*v = (*v - med) / iqr;
}
}
Ok(out)
}
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for RobustScaler<F> {
type Output = Array2<F>;
type Error = FerroError;
fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
Err(FerroError::InvalidParameter {
name: "RobustScaler".into(),
reason: "scaler must be fitted before calling transform; use fit() first".into(),
})
}
}
impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for RobustScaler<F> {
type FitError = FerroError;
fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
let fitted = self.fit(x, &())?;
fitted.transform(x)
}
}
impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for RobustScaler<F> {
fn fit_pipeline(
&self,
x: &Array2<F>,
_y: &Array1<F>,
) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
let fitted = self.fit(x, &())?;
Ok(Box::new(fitted))
}
}
impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedRobustScaler<F> {
fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
self.transform(x)
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use ndarray::array;
#[test]
fn test_robust_scaler_basic() {
let scaler = RobustScaler::<f64>::new();
let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
let fitted = scaler.fit(&x, &()).unwrap();
assert_abs_diff_eq!(fitted.median()[0], 3.0, epsilon = 1e-10);
assert_abs_diff_eq!(fitted.iqr()[0], 2.0, epsilon = 1e-10);
let scaled = fitted.transform(&x).unwrap();
assert_abs_diff_eq!(scaled[[2, 0]], 0.0, epsilon = 1e-10);
}
#[test]
fn test_zero_iqr_column_unchanged() {
let scaler = RobustScaler::<f64>::new();
let x = array![[7.0, 1.0], [7.0, 2.0], [7.0, 3.0]];
let fitted = scaler.fit(&x, &()).unwrap();
assert_abs_diff_eq!(fitted.iqr()[0], 0.0, epsilon = 1e-15);
let scaled = fitted.transform(&x).unwrap();
for i in 0..3 {
assert_abs_diff_eq!(scaled[[i, 0]], 7.0, epsilon = 1e-10);
}
}
#[test]
fn test_outlier_robustness() {
let scaler = RobustScaler::<f64>::new();
let x = array![[1.0], [2.0], [3.0], [4.0], [1000.0]];
let fitted = scaler.fit(&x, &()).unwrap();
assert_abs_diff_eq!(fitted.median()[0], 3.0, epsilon = 1e-10);
}
#[test]
fn test_fit_transform_equivalence() {
let scaler = RobustScaler::<f64>::new();
let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
let via_fit_transform = scaler.fit_transform(&x).unwrap();
let fitted = scaler.fit(&x, &()).unwrap();
let via_separate = fitted.transform(&x).unwrap();
for (a, b) in via_fit_transform.iter().zip(via_separate.iter()) {
assert_abs_diff_eq!(a, b, epsilon = 1e-15);
}
}
#[test]
fn test_shape_mismatch_error() {
let scaler = RobustScaler::<f64>::new();
let x_train = array![[1.0, 2.0], [3.0, 4.0]];
let fitted = scaler.fit(&x_train, &()).unwrap();
let x_bad = array![[1.0, 2.0, 3.0]];
assert!(fitted.transform(&x_bad).is_err());
}
#[test]
fn test_insufficient_samples_error() {
let scaler = RobustScaler::<f64>::new();
let x: Array2<f64> = Array2::zeros((0, 3));
assert!(scaler.fit(&x, &()).is_err());
}
}