use crate::core::traits::Transformer;
/// Shape of the distribution the transformed values are mapped onto.
///
/// The type is a plain field-less tag, so it also derives `PartialEq`/`Eq`
/// to let callers and tests compare variants with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputDistribution {
    /// Emit the interpolated quantile level itself, a value in `[0, 1]`.
    Uniform,
    /// Clamp the quantile level away from 0 and 1 and pass it through the
    /// inverse standard normal CDF.
    Normal,
}
/// Settings consumed by [`QuantileTransformer`].
#[derive(Debug, Clone)]
pub struct QuantileTransformerConfig {
    /// Upper bound on the number of quantile knots used per series.
    /// The transform panics when this is below 2.
    pub n_quantiles: usize,
    /// Distribution the transformed values should follow.
    pub output_distribution: OutputDistribution,
}

impl QuantileTransformerConfig {
    /// Returns the default configuration (same value as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for QuantileTransformerConfig {
    /// 1000 quantiles with a uniform output distribution.
    fn default() -> Self {
        let n_quantiles = 1000;
        let output_distribution = OutputDistribution::Uniform;
        Self {
            n_quantiles,
            output_distribution,
        }
    }
}
/// Stateless transformer that replaces each value by its empirical quantile
/// within the series it belongs to.
pub struct QuantileTransformer;

impl Transformer for QuantileTransformer {
    type Config = QuantileTransformerConfig;

    /// Transforms every inner vector of `x` independently of the others.
    ///
    /// # Panics
    ///
    /// Panics if `x` is empty or if `config.n_quantiles` is smaller than 2.
    fn transform(config: &Self::Config, x: &[Vec<f64>]) -> Vec<Vec<f64>> {
        assert!(!x.is_empty(), "Input must have at least one sample");
        assert!(config.n_quantiles >= 2, "n_quantiles must be at least 2");

        let mut transformed = Vec::with_capacity(x.len());
        for series in x {
            transformed.push(quantile_transform_single(series, config));
        }
        transformed
    }
}
/// Maps every value of one series to its empirical quantile level and, for
/// [`OutputDistribution::Normal`], on to the matching standard-normal quantile.
///
/// The knots are `min(config.n_quantiles, number of non-NaN values)` evenly
/// spaced levels in `[0, 1]`, paired with the linearly interpolated order
/// statistics of the series at those levels.
///
/// NaN values carry no rank: they are left out when the knots are built and
/// are returned unchanged at their original positions. An empty or all-NaN
/// series is returned as-is.
fn quantile_transform_single(x: &[f64], config: &QuantileTransformerConfig) -> Vec<f64> {
    // Fit only on values that can be ordered; sorting NaN would otherwise panic.
    let mut sorted: Vec<f64> = x.iter().copied().filter(|v| !v.is_nan()).collect();
    let n = sorted.len();
    if n == 0 {
        // Nothing to rank against: empty input stays empty, NaNs stay NaN.
        return x.to_vec();
    }
    sorted.sort_by(|a, b| {
        a.partial_cmp(b)
            .expect("NaN values were filtered out before sorting")
    });

    let n_quantiles = config.n_quantiles.min(n);
    // `max(1)` keeps the single-knot case (one usable value) from dividing by zero.
    let level_step_count = n_quantiles.saturating_sub(1).max(1) as f64;
    let references: Vec<f64> = (0..n_quantiles)
        .map(|i| i as f64 / level_step_count)
        .collect();

    let max_rank = (n - 1) as f64;
    let quantiles: Vec<f64> = references
        .iter()
        .map(|&r| {
            let idx = r * max_rank;
            let lo = idx.floor() as usize;
            let hi = (lo + 1).min(n - 1);
            let frac = idx - lo as f64;
            let (below, above) = (sorted[lo], sorted[hi]);
            if frac == 0.0 || below == above {
                // Exact hit or flat segment: return the order statistic itself.
                // Evaluating the interpolation formula here would compute
                // `0 * inf` or `inf - inf` for infinite values and turn the
                // knot into NaN.
                below
            } else {
                below + frac * (above - below)
            }
        })
        .collect();

    x.iter()
        .map(|&v| {
            if v.is_nan() {
                return v;
            }
            let u = interp_quantile(v, &quantiles, &references);
            match config.output_distribution {
                OutputDistribution::Uniform => u,
                OutputDistribution::Normal => {
                    // Keep the level strictly inside (0, 1) so the result stays finite.
                    let clamped = u.clamp(1e-7, 1.0 - 1e-7);
                    norm_ppf(clamped)
                }
            }
        })
        .collect()
}
/// Evaluates at `v` the piecewise-linear curve through the points
/// (`quantiles[i]`, `references[i]`).
///
/// Values at or beyond either end of `quantiles` saturate to the first or
/// last reference. A value that lands exactly on a run of equal knots gets
/// the reference of the first knot of that run.
///
/// Panics if `quantiles` or `references` is empty.
fn interp_quantile(v: f64, quantiles: &[f64], references: &[f64]) -> f64 {
    // Saturate below the first knot ...
    if v <= quantiles[0] {
        return references[0];
    }
    // ... and above the last one.
    let last = quantiles.len() - 1;
    if v >= quantiles[last] {
        return references[references.len() - 1];
    }

    // First knot that is not strictly below `v`.
    let first_not_below = quantiles.partition_point(|&knot| knot < v);
    let lower = match first_not_below.checked_sub(1) {
        Some(index) => index,
        // No knot compared below `v` although the range checks passed
        // (happens when `v` is NaN): fall back to the first reference.
        None => return references[0],
    };
    let upper = first_not_below.min(last);

    let q_hi = quantiles[upper];
    let q_lo = quantiles[lower];
    if q_hi != q_lo {
        let weight = (v - q_lo) / (q_hi - q_lo);
        let r_lo = references[lower];
        r_lo + weight * (references[upper] - r_lo)
    } else {
        // Degenerate segment: nothing to interpolate across.
        references[lower]
    }
}
/// Approximates the quantile function (inverse CDF) of the standard normal
/// distribution at probability `p`.
///
/// Returns negative/positive infinity for `p == 0.0` / `p == 1.0` and exactly
/// `0.0` for `p` within `1e-15` of one half. Everywhere else a piecewise
/// rational approximation is used: one in `sqrt(-2 ln(tail mass))` for the
/// two tails and one in `p - 0.5` for the central region.
///
/// NOTE(review): the coefficient tables look like those of Peter J. Acklam's
/// inverse-normal algorithm; confirm before relying on its published error
/// bound.
///
/// # Panics
///
/// Panics if `p` is outside `[0, 1]` (this includes NaN).
pub fn norm_ppf(p: f64) -> f64 {
    // Numerator / denominator coefficients for the central region.
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383_577_518_672_69e2,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    // Numerator / denominator coefficients for the tails.
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];
    // Probabilities below P_LOW or above P_HIGH are handled by the tail formula.
    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    /// Horner evaluation with the highest-order coefficient first.
    fn horner(coefficients: &[f64], t: f64) -> f64 {
        coefficients.iter().fold(0.0, |acc, &c| acc * t + c)
    }

    /// Quantile for a lower-tail probability `mass`; the upper tail is its mirror image.
    fn lower_tail(mass: f64) -> f64 {
        let q = (-2.0 * mass.ln()).sqrt();
        // The denominator polynomial has an implicit trailing coefficient of 1.
        horner(&C, q) / (horner(&D, q) * q + 1.0)
    }

    assert!((0.0..=1.0).contains(&p), "p must be in [0, 1]");

    if p == 0.0 {
        return f64::NEG_INFINITY;
    }
    if p == 1.0 {
        return f64::INFINITY;
    }
    let centred = p - 0.5;
    if centred.abs() < 1e-15 {
        return 0.0;
    }

    if p < P_LOW {
        lower_tail(p)
    } else if p > P_HIGH {
        -lower_tail(1.0 - p)
    } else {
        let r = centred * centred;
        horner(&A, r) * centred / (horner(&B, r) * r + 1.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless every element is at least as large as its predecessor.
    fn assert_non_decreasing(values: &[f64]) {
        for pair in values.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
    }

    /// With the default (uniform) config the extremes land on 0 and 1 and
    /// the input order is preserved.
    #[test]
    fn test_uniform_basic() {
        let series = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let result = QuantileTransformer::transform(&QuantileTransformerConfig::new(), &series);
        let row = &result[0];
        assert!((row[0] - 0.0).abs() < 1e-10);
        assert!((row[4] - 1.0).abs() < 1e-10);
        assert_non_decreasing(row);
    }

    /// With normal output the median maps close to zero and order is preserved.
    #[test]
    fn test_normal_output() {
        let config = QuantileTransformerConfig {
            n_quantiles: 1000,
            output_distribution: OutputDistribution::Normal,
        };
        let series = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let result = QuantileTransformer::transform(&config, &series);
        let row = &result[0];
        assert!(row[2].abs() < 0.1);
        assert_non_decreasing(row);
    }

    /// The median of the standard normal is zero.
    #[test]
    fn test_norm_ppf_center() {
        let median = norm_ppf(0.5);
        assert!((median - 0.0).abs() < 1e-10);
    }

    /// The 2.5% and 97.5% quantiles are approximately -1.96 and +1.96.
    #[test]
    fn test_norm_ppf_tails() {
        let lower = norm_ppf(0.025);
        let upper = norm_ppf(0.975);
        assert!((lower - (-1.96)).abs() < 0.01);
        assert!((upper - 1.96).abs() < 0.01);
    }

    /// Quantiles at `p` and `1 - p` cancel each other out.
    #[test]
    fn test_norm_ppf_symmetry() {
        let probabilities = [0.1, 0.2, 0.3, 0.4];
        for p in probabilities.iter().copied() {
            let total = norm_ppf(p) + norm_ppf(1.0 - p);
            assert!(
                total.abs() < 1e-10,
                "ppf({p}) + ppf({}) = {}",
                1.0 - p,
                total
            );
        }
    }
}