use anofox_ml_core::{FitUnsupervised, Float, InverseTransform, Result, RustMlError, Transform};
use ndarray::{Array1, Array2};
/// Configuration for a scaler that centers each column on its median and
/// scales it by its interquartile range (75th minus 25th percentile).
///
/// Fitting this configuration produces a `FittedRobustScaler` holding the
/// per-column statistics.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RobustScaler {
/// When `true`, the fitted scaler subtracts the per-column median in
/// `transform` and adds it back in `inverse_transform`.
pub with_centering: bool,
/// When `true`, the interquartile range is computed during `fit` and the
/// fitted scaler divides by it in `transform`; columns whose range is not
/// greater than `1e-15` are left unscaled.
pub with_scaling: bool,
}
impl RobustScaler {
    /// Creates a scaler configuration with both centering and scaling enabled.
    pub fn new() -> Self {
        let (with_centering, with_scaling) = (true, true);
        Self {
            with_centering,
            with_scaling,
        }
    }

    /// Returns the configuration with the centering flag replaced by
    /// `with_centering`; the scaling flag is carried over unchanged.
    pub fn with_centering(self, with_centering: bool) -> Self {
        Self {
            with_centering,
            ..self
        }
    }

    /// Returns the configuration with the scaling flag replaced by
    /// `with_scaling`; the centering flag is carried over unchanged.
    pub fn with_scaling(self, with_scaling: bool) -> Self {
        Self {
            with_scaling,
            ..self
        }
    }
}
impl Default for RobustScaler {
fn default() -> Self {
Self::new()
}
}
/// Fitted state produced by fitting a [`RobustScaler`]: the per-column
/// statistics plus the flags that decide which of them are applied.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedRobustScaler<F: Float> {
/// Per-column 50th percentile; its length is the expected feature count
/// checked by `transform` and `inverse_transform`.
median: Array1<F>,
/// Per-column 75th minus 25th percentile. Left at one for every column
/// when the scaler was fitted with scaling disabled.
iqr: Array1<F>,
/// Whether `transform` subtracts `median` (and `inverse_transform` adds it back).
with_centering: bool,
/// Whether `transform` divides by `iqr` (and `inverse_transform` multiplies by it).
with_scaling: bool,
}
/// Returns the `p`-quantile of an ascending-sorted slice using linear
/// interpolation between neighbouring elements.
///
/// The fractional rank is `p * (n - 1)`. When it lands exactly on an index
/// that element is returned; otherwise the two surrounding elements are
/// blended by the fractional part of the rank.
///
/// `p` is a fraction (for example `0.5` for the median) and is clamped to
/// `[0, 1]`, so an out-of-range fraction resolves to the first or last element
/// instead of indexing past the end of the slice.
///
/// # Panics
///
/// Panics if `sorted` is empty, or if the interpolation weight cannot be
/// converted to `F`.
fn percentile<F: Float>(sorted: &[F], p: f64) -> F {
    // Fail with a descriptive message rather than an arithmetic-overflow or
    // wrapped-index panic from `len() - 1` below.
    assert!(
        !sorted.is_empty(),
        "percentile requires a non-empty slice"
    );
    let last = sorted.len() - 1;
    if last == 0 {
        return sorted[0];
    }
    // Clamping keeps `ceil(rank)` within `0..=last`.
    let rank = p.clamp(0.0, 1.0) * last as f64;
    let lo = rank.floor() as usize;
    let hi = rank.ceil() as usize;
    if lo == hi {
        sorted[lo]
    } else {
        let frac = F::from_f64(rank - lo as f64)
            .expect("interpolation weight in (0, 1) must be representable");
        sorted[lo] * (F::one() - frac) + sorted[hi] * frac
    }
}
impl<F: Float> FitUnsupervised<F> for RobustScaler {
    type Fitted = FittedRobustScaler<F>;

    /// Learns the per-column median and interquartile range of `x`.
    ///
    /// Rows are samples and columns are features. The median is always
    /// computed; the interquartile range (75th minus 25th percentile) is only
    /// computed when `with_scaling` is set, otherwise it stays at one.
    ///
    /// Entries that do not compare equal to themselves (NaN for IEEE floats)
    /// are treated as missing and skipped when computing the statistics, so a
    /// stray NaN no longer aborts fitting with a panic.
    ///
    /// # Errors
    ///
    /// Returns `RustMlError::EmptyInput` when `x` has no elements, or when a
    /// column contains no comparable (non-NaN) value to compute statistics from.
    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
        if x.is_empty() {
            return Err(RustMlError::EmptyInput("input array is empty".into()));
        }
        let ncols = x.ncols();
        let mut median = Array1::<F>::zeros(ncols);
        // Stays at one for every column when scaling is disabled.
        let mut iqr = Array1::<F>::ones(ncols);
        for j in 0..ncols {
            // A value that is not comparable to itself cannot be ordered;
            // dropping it here keeps the sort below total.
            let mut sorted: Vec<F> = x
                .column(j)
                .iter()
                .copied()
                .filter(|v| v.partial_cmp(v).is_some())
                .collect();
            if sorted.is_empty() {
                return Err(RustMlError::EmptyInput(format!(
                    "column {} contains no non-NaN values",
                    j
                )));
            }
            sorted.sort_by(|a, b| {
                a.partial_cmp(b)
                    .expect("incomparable values were filtered out before sorting")
            });
            median[j] = percentile(&sorted, 0.5);
            if self.with_scaling {
                let q1 = percentile(&sorted, 0.25);
                let q3 = percentile(&sorted, 0.75);
                iqr[j] = q3 - q1;
            }
        }
        Ok(FittedRobustScaler {
            median,
            iqr,
            with_centering: self.with_centering,
            with_scaling: self.with_scaling,
        })
    }
}
impl<F: Float> Transform<F> for FittedRobustScaler<F> {
    /// Returns a copy of `x` with each column centered on its fitted median
    /// and then divided by its fitted interquartile range.
    ///
    /// Centering and scaling are each applied only when the corresponding flag
    /// is set. A column whose interquartile range is not greater than `1e-15`
    /// is never divided.
    ///
    /// # Errors
    ///
    /// Returns `RustMlError::ShapeMismatch` when the column count of `x`
    /// differs from the number of fitted features.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
        let expected = self.median.len();
        let actual = x.ncols();
        if actual != expected {
            return Err(RustMlError::ShapeMismatch(format!(
                "expected {} features, got {}",
                expected, actual
            )));
        }

        let mut scaled = x.to_owned();
        for (j, mut feature) in scaled.columns_mut().into_iter().enumerate() {
            // Center first, then scale: each element becomes (v - median) / iqr.
            if self.with_centering {
                let center = self.median[j];
                feature.iter_mut().for_each(|v| *v -= center);
            }
            if self.with_scaling {
                let spread = self.iqr[j];
                if spread > F::from_f64(1e-15).unwrap() {
                    feature.iter_mut().for_each(|v| *v /= spread);
                }
            }
        }
        Ok(scaled)
    }
}
impl<F: Float> InverseTransform<F> for FittedRobustScaler<F> {
    /// Returns a copy of `x` with the scaling and centering undone: each
    /// column is multiplied by its fitted interquartile range and then shifted
    /// by its fitted median.
    ///
    /// Each step runs only when the corresponding flag is set. A column whose
    /// interquartile range is not greater than `1e-15` is never multiplied,
    /// matching the columns that `transform` leaves undivided.
    ///
    /// # Errors
    ///
    /// Returns `RustMlError::ShapeMismatch` when the column count of `x`
    /// differs from the number of fitted features.
    fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
        let expected = self.median.len();
        let actual = x.ncols();
        if actual != expected {
            return Err(RustMlError::ShapeMismatch(format!(
                "expected {} features, got {}",
                expected, actual
            )));
        }

        let mut restored = x.to_owned();
        for (j, mut feature) in restored.columns_mut().into_iter().enumerate() {
            // Reverse order of `transform`: each element becomes v * iqr + median.
            if self.with_scaling {
                let spread = self.iqr[j];
                if spread > F::from_f64(1e-15).unwrap() {
                    feature.iter_mut().for_each(|v| *v *= spread);
                }
            }
            if self.with_centering {
                let center = self.median[j];
                feature.iter_mut().for_each(|v| *v += center);
            }
        }
        Ok(restored)
    }
}
impl<F: Float> FittedRobustScaler<F> {
    /// Borrows the fitted per-column medians.
    pub fn median(&self) -> &Array1<F> {
        let Self { median, .. } = self;
        median
    }

    /// Borrows the fitted per-column interquartile ranges (all ones when the
    /// scaler was fitted with scaling disabled).
    pub fn iqr(&self) -> &Array1<F> {
        let Self { iqr, .. } = self;
        iqr
    }
}
// Unit and property tests for `RobustScaler` / `FittedRobustScaler`.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use ndarray::array;
// Column 0 is 1..=5: median 3, quartiles 2 and 4, so the IQR is 2 and the
// extremes map to -1 and +1.
#[test]
fn test_fit_transform() {
let x = array![
[1.0, 10.0],
[2.0, 20.0],
[3.0, 30.0],
[4.0, 40.0],
[5.0, 50.0]
];
let scaler = RobustScaler::default();
let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
assert_abs_diff_eq!(fitted.median()[0], 3.0, epsilon = 1e-10);
assert_abs_diff_eq!(fitted.iqr()[0], 2.0, epsilon = 1e-10);
assert_abs_diff_eq!(transformed[[2, 0]], 0.0, epsilon = 1e-10);
assert_abs_diff_eq!(transformed[[0, 0]], -1.0, epsilon = 1e-10);
assert_abs_diff_eq!(transformed[[4, 0]], 1.0, epsilon = 1e-10);
}
// `inverse_transform` must undo `transform` element by element.
#[test]
fn test_inverse_transform_roundtrip() {
let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
let scaler = RobustScaler::default();
let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
let recovered = fitted.inverse_transform(&transformed).unwrap();
for (a, b) in x.iter().zip(recovered.iter()) {
assert_abs_diff_eq!(a, b, epsilon = 1e-10);
}
}
// With centering off, values are only divided by the IQR (2 for 1..=5).
#[test]
fn test_without_centering() {
let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
let scaler = RobustScaler::new().with_centering(false);
let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
assert_abs_diff_eq!(transformed[[0, 0]], 0.5, epsilon = 1e-10);
assert_abs_diff_eq!(transformed[[2, 0]], 1.5, epsilon = 1e-10);
}
// With scaling off, values are only shifted by the median (3 for 1..=5).
#[test]
fn test_without_scaling() {
let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
let scaler = RobustScaler::new().with_scaling(false);
let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
assert_abs_diff_eq!(transformed[[0, 0]], -2.0, epsilon = 1e-10);
assert_abs_diff_eq!(transformed[[2, 0]], 0.0, epsilon = 1e-10);
assert_abs_diff_eq!(transformed[[4, 0]], 2.0, epsilon = 1e-10);
}
// A constant column has a zero IQR; the output must still be finite
// (no division by zero).
#[test]
fn test_constant_column() {
let x = array![[5.0, 1.0], [5.0, 2.0], [5.0, 3.0], [5.0, 4.0]];
let scaler = RobustScaler::default();
let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
for &v in transformed.iter() {
assert!(v.is_finite(), "constant column produced non-finite: {}", v);
}
}
// Fitting a 0x0 array must return an error rather than succeed.
#[test]
fn test_empty_input() {
let x: Array2<f64> = Array2::zeros((0, 0));
let scaler = RobustScaler::default();
let result = FitUnsupervised::<f64>::fit(&scaler, &x);
assert!(result.is_err());
}
// Fitted on 2 features; a 3-feature input must be rejected in both directions.
#[test]
fn test_shape_mismatch() {
let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
let scaler = RobustScaler::default();
let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
let x_wrong = array![[1.0, 2.0, 3.0]];
assert!(fitted.transform(&x_wrong).is_err());
assert!(fitted.inverse_transform(&x_wrong).is_err());
}
// With an even row count the median is interpolated between the two
// middle values (2 and 3 here).
#[test]
fn test_even_number_of_rows() {
let x = array![[1.0], [2.0], [3.0], [4.0]];
let scaler = RobustScaler::default();
let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
assert_abs_diff_eq!(fitted.median()[0], 2.5, epsilon = 1e-10);
}
// Inputs around 1e10 must still produce finite outputs.
#[test]
fn test_large_values() {
let x = array![[1e10], [2e10], [3e10], [4e10], [5e10]];
let scaler = RobustScaler::default();
let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
for &v in transformed.iter() {
assert!(v.is_finite(), "large values produced non-finite: {}", v);
}
}
// A single row is its own median and has zero IQR, so it maps to zeros.
#[test]
fn test_single_row() {
let x = array![[1.0, 2.0]];
let scaler = RobustScaler::default();
let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
assert_abs_diff_eq!(transformed[[0, 0]], 0.0, epsilon = 1e-10);
assert_abs_diff_eq!(transformed[[0, 1]], 0.0, epsilon = 1e-10);
}
// Same round-trip check as above, instantiated for `f32` with a looser tolerance.
#[test]
fn test_f32() {
let x = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
let scaler = RobustScaler::default();
let fitted = FitUnsupervised::<f32>::fit(&scaler, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
let recovered = fitted.inverse_transform(&transformed).unwrap();
for (a, b) in x.iter().zip(recovered.iter()) {
assert_abs_diff_eq!(a, b, epsilon = 1e-5);
}
}
// Property-based tests over generated matrix shapes and seeds.
mod prop_tests {
use super::*;
use proptest::prelude::*;
// Builds a `rows` x `cols` matrix whose entries are derived by hashing
// `(seed, index)`, so the same arguments always give the same data.
// Each hash is mapped linearly onto the range -10..=10.
fn make_data(rows: usize, cols: usize, seed: u64) -> Array2<f64> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut values = Vec::with_capacity(rows * cols);
for i in 0..(rows * cols) {
let mut h = DefaultHasher::new();
seed.hash(&mut h);
(i as u64).hash(&mut h);
let bits = h.finish();
let v = (bits as f64 / u64::MAX as f64) * 20.0 - 10.0;
values.push(v);
}
Array2::from_shape_vec((rows, cols), values).unwrap()
}
proptest! {
// transform followed by inverse_transform reproduces the input.
#[test]
fn robust_scaler_roundtrip(
rows in 2..50usize,
cols in 1..10usize,
seed in 0u64..10000,
) {
let x = make_data(rows, cols, seed);
let scaler = RobustScaler::default();
let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
let recovered = fitted.inverse_transform(&transformed).unwrap();
for (a, b) in x.iter().zip(recovered.iter()) {
prop_assert!((a - b).abs() < 1e-8,
"roundtrip failed: original={}, recovered={}", a, b);
}
}
// After transforming, every column's median (recomputed with the
// module-level `percentile` helper) is approximately zero.
#[test]
fn robust_scaler_median_zero(
rows in 4..50usize,
cols in 1..10usize,
seed in 0u64..10000,
) {
let x = make_data(rows, cols, seed);
let scaler = RobustScaler::default();
let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
let transformed = fitted.transform(&x).unwrap();
for col_idx in 0..cols {
let col = transformed.column(col_idx);
let mut sorted: Vec<f64> = col.to_vec();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
let median = super::super::percentile(&sorted, 0.5);
prop_assert!(median.abs() < 1e-8,
"column {} median should be ~0, got {}", col_idx, median);
}
}
}
}
}