mod common;
use common::{assert_batch_close, load_golden_data};
use oxits::core::config::BinStrategy;
use oxits::preprocessing::discretizer::{KBinsDiscretizer, KBinsDiscretizerConfig};
use oxits::preprocessing::scaler::{
MaxAbsScaler, MinMaxScaler, MinMaxScalerConfig, RobustScaler, RobustScalerConfig,
StandardScaler, StandardScalerConfig,
};
use oxits::Transformer;
/// Tolerance handed to `assert_batch_close` by every golden test in this file.
const TOL: f64 = 1e-10;
/// Loads all golden fixtures stored in the preprocessing fixture file.
fn fixtures() -> Vec<common::GoldenData> {
    const FIXTURE_PATH: &str = "preprocessing/preprocessing.json";
    load_golden_data(FIXTURE_PATH)
}
/// Returns the first fixture whose `test_name` equals `name`.
///
/// The fixture file is reloaded through [`fixtures`] on every call.
///
/// # Panics
///
/// Panics with "`{name}` fixture not found" when no fixture matches.
fn find(name: &str) -> common::GoldenData {
    for fixture in fixtures() {
        if fixture.test_name == name {
            return fixture;
        }
    }
    panic!("{name} fixture not found")
}
/// Runs `StandardScaler::transform` with the configuration returned by
/// `StandardScalerConfig::new()` and compares it with the golden output.
#[test]
fn test_standard_scaler_golden() {
    const NAME: &str = "standard_scaler_basic";
    let fixture = find(NAME);
    let input = fixture.input.x.as_ref().unwrap();
    let want = fixture.expected.output.as_ref().unwrap();
    let got = StandardScaler::transform(&StandardScalerConfig::new(), input);
    assert_batch_close(NAME, &got, want, TOL);
}
/// Runs `StandardScaler::transform` with `with_mean` set to `false` (all other
/// settings from `StandardScalerConfig::new()`) and compares it with the
/// golden output.
#[test]
fn test_standard_scaler_no_mean_golden() {
    const NAME: &str = "standard_scaler_no_mean";
    let fixture = find(NAME);
    let input = fixture.input.x.as_ref().unwrap();
    let want = fixture.expected.output.as_ref().unwrap();
    let no_mean = StandardScalerConfig {
        with_mean: false,
        ..StandardScalerConfig::new()
    };
    let got = StandardScaler::transform(&no_mean, input);
    assert_batch_close(NAME, &got, want, TOL);
}
/// Runs `MinMaxScaler::transform` with the configuration returned by
/// `MinMaxScalerConfig::new()` and compares it with the golden output.
#[test]
fn test_minmax_scaler_golden() {
    const NAME: &str = "minmax_scaler_basic";
    let fixture = find(NAME);
    let input = fixture.input.x.as_ref().unwrap();
    let want = fixture.expected.output.as_ref().unwrap();
    let got = MinMaxScaler::transform(&MinMaxScalerConfig::new(), input);
    assert_batch_close(NAME, &got, want, TOL);
}
/// Runs `MaxAbsScaler::transform` (which takes the unit value as its
/// configuration) and compares it with the golden output.
#[test]
fn test_maxabs_scaler_golden() {
    const NAME: &str = "maxabs_scaler_basic";
    let fixture = find(NAME);
    let input = fixture.input.x.as_ref().unwrap();
    let want = fixture.expected.output.as_ref().unwrap();
    let got = MaxAbsScaler::transform(&(), input);
    assert_batch_close(NAME, &got, want, TOL);
}
/// Runs `RobustScaler::transform` with the configuration returned by
/// `RobustScalerConfig::new()` and compares it with the golden output.
#[test]
fn test_robust_scaler_golden() {
    const NAME: &str = "robust_scaler_basic";
    let fixture = find(NAME);
    let input = fixture.input.x.as_ref().unwrap();
    let want = fixture.expected.output.as_ref().unwrap();
    let got = RobustScaler::transform(&RobustScalerConfig::new(), input);
    assert_batch_close(NAME, &got, want, TOL);
}
/// Runs `KBinsDiscretizer::transform` with 4 bins and `BinStrategy::Uniform`.
/// The bin values are cast to `f64` before being handed to
/// `assert_batch_close` together with the golden output.
#[test]
fn test_discretizer_uniform_golden() {
    const NAME: &str = "discretizer_uniform";
    let fixture = find(NAME);
    let input = fixture.input.x.as_ref().unwrap();
    let want = fixture.expected.output.as_ref().unwrap();
    let uniform = KBinsDiscretizerConfig {
        n_bins: 4,
        strategy: BinStrategy::Uniform,
    };
    let bins = KBinsDiscretizer::transform(&uniform, input);
    let mut got: Vec<Vec<f64>> = Vec::new();
    for row in bins.iter() {
        got.push(row.iter().map(|&bin| bin as f64).collect());
    }
    assert_batch_close(NAME, &got, want, TOL);
}
/// Runs `KBinsDiscretizer::transform` with 3 bins and `BinStrategy::Quantile`.
/// The bin values are cast to `f64` before being handed to
/// `assert_batch_close` together with the golden output.
#[test]
fn test_discretizer_quantile_golden() {
    const NAME: &str = "discretizer_quantile";
    let fixture = find(NAME);
    let input = fixture.input.x.as_ref().unwrap();
    let want = fixture.expected.output.as_ref().unwrap();
    let quantile = KBinsDiscretizerConfig {
        n_bins: 3,
        strategy: BinStrategy::Quantile,
    };
    let bins = KBinsDiscretizer::transform(&quantile, input);
    let mut got: Vec<Vec<f64>> = Vec::new();
    for row in bins.iter() {
        got.push(row.iter().map(|&bin| bin as f64).collect());
    }
    assert_batch_close(NAME, &got, want, TOL);
}
/// Runs `KBinsDiscretizer::transform` with 4 bins and `BinStrategy::Normal`.
/// The bin values are cast to `f64` before being handed to
/// `assert_batch_close` together with the golden output.
#[test]
fn test_discretizer_normal_golden() {
    const NAME: &str = "discretizer_normal";
    let fixture = find(NAME);
    let input = fixture.input.x.as_ref().unwrap();
    let want = fixture.expected.output.as_ref().unwrap();
    let normal = KBinsDiscretizerConfig {
        n_bins: 4,
        strategy: BinStrategy::Normal,
    };
    let bins = KBinsDiscretizer::transform(&normal, input);
    let mut got: Vec<Vec<f64>> = Vec::new();
    for row in bins.iter() {
        got.push(row.iter().map(|&bin| bin as f64).collect());
    }
    assert_batch_close(NAME, &got, want, TOL);
}