//! oxits 0.1.0
//!
//! Time series classification and transformation library for Rust.
use crate::core::traits::Transformer;

/// Target distribution of the values produced by [`QuantileTransformer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputDistribution {
    /// Each value is replaced by its interpolated empirical quantile position in `[0, 1]`.
    Uniform,
    /// The quantile position is clamped away from 0 and 1 and then passed through the
    /// standard-normal inverse CDF ([`norm_ppf`]), giving finite normal scores.
    Normal,
}

/// Configuration for [`QuantileTransformer`].
#[derive(Debug, Clone)]
pub struct QuantileTransformerConfig {
    /// Maximum number of quantile landmarks estimated per sample (default: 1000).
    /// Must be at least 2 (checked by `transform`). Per sample it is capped by the
    /// number of values available to estimate quantiles from.
    pub n_quantiles: usize,
    /// Distribution the transformed values should follow (default: `Uniform`).
    pub output_distribution: OutputDistribution,
}

impl QuantileTransformerConfig {
    pub fn new() -> Self {
        Self {
            n_quantiles: 1000,
            output_distribution: OutputDistribution::Uniform,
        }
    }
}

impl Default for QuantileTransformerConfig {
    fn default() -> Self {
        Self::new()
    }
}

/// Quantile transformer that transforms each sample (inner `Vec`) of its input
/// independently, using quantiles estimated from that sample alone.
///
/// The type carries no state; all settings come from the [`QuantileTransformerConfig`]
/// passed to `transform`.
pub struct QuantileTransformer;

impl Transformer for QuantileTransformer {
    type Config = QuantileTransformerConfig;

    /// Quantile-transforms every sample in `x` independently.
    ///
    /// The output has one transformed sample per input sample, each with the same length
    /// as its input.
    ///
    /// # Panics
    ///
    /// Panics if `x` is empty, or if `config.n_quantiles` is less than 2.
    fn transform(config: &Self::Config, x: &[Vec<f64>]) -> Vec<Vec<f64>> {
        assert!(!x.is_empty(), "Input must have at least one sample");
        assert!(config.n_quantiles > 1, "n_quantiles must be at least 2");

        let mut transformed = Vec::with_capacity(x.len());
        for sample in x {
            transformed.push(quantile_transform_single(sample, config));
        }
        transformed
    }
}

/// Quantile-transforms one sample using quantile landmarks estimated from that sample.
///
/// Landmarks are placed at `k = min(config.n_quantiles, m)` evenly spaced probabilities
/// `0, 1/(k-1), ..., 1`, where `m` is the number of non-NaN values in `x`. Each landmark is
/// linearly interpolated between the two nearest order statistics. Every value is then
/// mapped to its interpolated probability in `[0, 1]` via [`interp_quantile`]. For
/// [`OutputDistribution::Normal`], that probability is clamped to `[1e-7, 1 - 1e-7]` and
/// passed through [`norm_ppf`].
///
/// NaN values are treated as missing. They are ignored when estimating the landmarks and
/// are returned as NaN, so an empty or all-NaN sample yields all-NaN output of the same
/// length. The output always has `x.len()` elements.
fn quantile_transform_single(x: &[f64], config: &QuantileTransformerConfig) -> Vec<f64> {
    // NaN has no rank, so estimate the landmarks from the remaining values only.
    // `total_cmp` gives a total order and cannot panic, unlike `partial_cmp().unwrap()`.
    let mut sorted: Vec<f64> = x.iter().copied().filter(|v| !v.is_nan()).collect();
    sorted.sort_unstable_by(f64::total_cmp);

    let n = sorted.len();
    if n == 0 {
        // Empty or all-NaN sample: there is nothing to rank against.
        return vec![f64::NAN; x.len()];
    }
    let n_quantiles = config.n_quantiles.min(n);

    // Evenly spaced probabilities; max(1) avoids dividing by zero when only one landmark fits.
    let references: Vec<f64> = (0..n_quantiles)
        .map(|i| i as f64 / (n_quantiles - 1).max(1) as f64)
        .collect();

    // Landmark values at those probabilities, interpolated between adjacent order statistics.
    // NOTE(review): infinite sample values can still yield NaN landmarks here
    // (e.g. `inf + 0.0 * (inf - inf)`); they are not specially handled.
    let quantiles: Vec<f64> = references
        .iter()
        .map(|&r| {
            let idx = r * (n - 1) as f64;
            let lo = idx.floor() as usize;
            let hi = (lo + 1).min(n - 1);
            let frac = idx - lo as f64;
            sorted[lo] + frac * (sorted[hi] - sorted[lo])
        })
        .collect();

    x.iter()
        .map(|&v| {
            if v.is_nan() {
                // Missing values pass through unchanged; `norm_ppf` would reject NaN anyway.
                return f64::NAN;
            }
            let u = interp_quantile(v, &quantiles, &references);
            match config.output_distribution {
                OutputDistribution::Uniform => u,
                // Keep the probability away from 0 and 1 so the normal score stays finite.
                OutputDistribution::Normal => norm_ppf(u.clamp(1e-7, 1.0 - 1e-7)),
            }
        })
        .collect()
}

/// Linear interpolation: given value v, find its position in the quantile table.
/// Maps `v` to its position on the `references` scale by piecewise-linear interpolation
/// over the landmark table `quantiles`.
///
/// `quantiles` is expected to be sorted in non-decreasing order and to pair element-wise
/// with `references`. The mapping works as follows:
///
/// - Values at or below the first landmark map to the first reference.
/// - Values at or above the last landmark map to the last reference.
/// - Values strictly in between are interpolated between the two landmarks that bracket
///   them. A value equal to a run of tied interior landmarks maps to (approximately) the
///   reference of the first landmark in that run.
/// - A NaN `v` has no position in the table and yields NaN.
///
/// # Panics
///
/// Panics if `v` is not NaN and either table is empty.
fn interp_quantile(v: f64, quantiles: &[f64], references: &[f64]) -> f64 {
    debug_assert_eq!(
        quantiles.len(),
        references.len(),
        "quantile and reference tables must have the same length"
    );

    // Propagate NaN; the search below would otherwise report the lowest reference for it.
    if v.is_nan() {
        return f64::NAN;
    }
    if v <= quantiles[0] {
        return references[0];
    }
    if v >= quantiles[quantiles.len() - 1] {
        return references[references.len() - 1];
    }

    // Index of the first landmark >= v. For a sorted table with q[0] < v < q[last] this
    // lies in 1..len. The guards below only matter for malformed (unsorted/NaN) tables.
    let pos = quantiles.partition_point(|&q| q < v);
    if pos == 0 {
        return references[0];
    }

    let lo = pos - 1;
    let hi = pos.min(quantiles.len() - 1);

    if quantiles[hi] == quantiles[lo] {
        // Degenerate interval: avoid dividing by zero.
        references[lo]
    } else {
        let frac = (v - quantiles[lo]) / (quantiles[hi] - quantiles[lo]);
        references[lo] + frac * (references[hi] - references[lo])
    }
}

/// Inverse CDF of the standard normal distribution (probit function).
/// Uses the rational approximation by Peter Acklam.
/// Inverse CDF (quantile function) of the standard normal distribution, i.e. the probit.
///
/// Uses Peter Acklam's rational approximation. A central rational function in
/// `(p - 0.5)^2` covers `P_LOW <= p <= 1 - P_LOW`. A tail rational function in
/// `sqrt(-2 ln p)` covers the lower tail, and the upper tail is obtained by symmetry.
///
/// Returns `-inf` for `p == 0`, `+inf` for `p == 1`, and exactly `0.0` when `p` is within
/// `1e-15` of `0.5`.
///
/// # Panics
///
/// Panics if `p` is outside `[0, 1]` (this includes NaN).
pub fn norm_ppf(p: f64) -> f64 {
    /// Evaluates a polynomial (coefficients highest-degree first) at `x` by Horner's rule.
    fn horner(coeffs: &[f64], x: f64) -> f64 {
        coeffs[1..].iter().fold(coeffs[0], |acc, &c| acc * x + c)
    }

    // Acklam's coefficients: A/B for the central region, C/D for the tails.
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383_577_518_672_69e2,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];

    // Boundaries between the lower tail, central region and upper tail.
    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    assert!((0.0..=1.0).contains(&p), "p must be in [0, 1]");

    if p == 0.0 {
        f64::NEG_INFINITY
    } else if p == 1.0 {
        f64::INFINITY
    } else if (p - 0.5).abs() < 1e-15 {
        0.0
    } else if p < P_LOW {
        let t = (-2.0 * p.ln()).sqrt();
        horner(&C, t) / (horner(&D, t) * t + 1.0)
    } else if p <= P_HIGH {
        let centered = p - 0.5;
        let sq = centered * centered;
        horner(&A, sq) * centered / (horner(&B, sq) * sq + 1.0)
    } else {
        let t = (-2.0 * (1.0 - p).ln()).sqrt();
        -horner(&C, t) / (horner(&D, t) * t + 1.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two float slices have equal length and match element-wise within `tol`.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < tol, "index {}: got {}, expected {}", i, a, e);
        }
    }

    #[test]
    fn test_uniform_basic() {
        let config = QuantileTransformerConfig::new();
        let x = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let result = QuantileTransformer::transform(&config, &x);
        // Min should map to 0, max to 1
        assert!((result[0][0] - 0.0).abs() < 1e-10);
        assert!((result[0][4] - 1.0).abs() < 1e-10);
        // Should be monotonically increasing
        for i in 1..result[0].len() {
            assert!(result[0][i] >= result[0][i - 1]);
        }
    }

    #[test]
    fn test_uniform_exact_values_unsorted_input() {
        let config = QuantileTransformerConfig::new();
        let x = vec![vec![3.0, 1.0, 5.0, 2.0, 4.0]];
        let result = QuantileTransformer::transform(&config, &x);
        assert_close(&result[0], &[0.5, 0.0, 1.0, 0.25, 0.75], 1e-12);
    }

    #[test]
    fn test_samples_are_transformed_independently() {
        let config = QuantileTransformerConfig::new();
        let x = vec![vec![1.0, 2.0, 3.0], vec![30.0, 10.0, 20.0]];
        let result = QuantileTransformer::transform(&config, &x);
        assert_eq!(result.len(), 2);
        assert_close(&result[0], &[0.0, 0.5, 1.0], 1e-12);
        assert_close(&result[1], &[1.0, 0.0, 0.5], 1e-12);
    }

    #[test]
    fn test_fewer_quantiles_than_values() {
        let config = QuantileTransformerConfig {
            n_quantiles: 2,
            output_distribution: OutputDistribution::Uniform,
        };
        let x = vec![vec![5.0, 1.0, 3.0]];
        let result = QuantileTransformer::transform(&config, &x);
        // Two landmarks (min and max) give a plain min-max rescaling.
        assert_close(&result[0], &[1.0, 0.0, 0.5], 1e-12);
    }

    #[test]
    fn test_single_value_sample() {
        let config = QuantileTransformerConfig::new();
        let x = vec![vec![42.0]];
        let result = QuantileTransformer::transform(&config, &x);
        assert_close(&result[0], &[0.0], 1e-12);
    }

    #[test]
    fn test_default_matches_new() {
        let config = QuantileTransformerConfig::default();
        assert_eq!(config.n_quantiles, 1000);
        assert!(matches!(
            config.output_distribution,
            OutputDistribution::Uniform
        ));
    }

    #[test]
    #[should_panic(expected = "Input must have at least one sample")]
    fn test_empty_input_panics() {
        let config = QuantileTransformerConfig::new();
        let x: Vec<Vec<f64>> = Vec::new();
        let _ = QuantileTransformer::transform(&config, &x);
    }

    #[test]
    #[should_panic(expected = "n_quantiles must be at least 2")]
    fn test_too_few_quantiles_panics() {
        let config = QuantileTransformerConfig {
            n_quantiles: 1,
            output_distribution: OutputDistribution::Uniform,
        };
        let x = vec![vec![1.0, 2.0]];
        let _ = QuantileTransformer::transform(&config, &x);
    }

    #[test]
    fn test_normal_output() {
        let config = QuantileTransformerConfig {
            n_quantiles: 1000,
            output_distribution: OutputDistribution::Normal,
        };
        let x = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let result = QuantileTransformer::transform(&config, &x);
        // Middle value (3.0 → p=0.5) should map to ~0.0
        assert!(result[0][2].abs() < 0.1);
        // Should be monotonically increasing
        for i in 1..result[0].len() {
            assert!(result[0][i] >= result[0][i - 1]);
        }
    }

    #[test]
    fn test_normal_output_values() {
        let config = QuantileTransformerConfig {
            n_quantiles: 1000,
            output_distribution: OutputDistribution::Normal,
        };
        let x = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let result = QuantileTransformer::transform(&config, &x);
        let r = &result[0];
        assert!(r.iter().all(|v| v.is_finite()));
        // Extremes are clamped to p = 1e-7 / 1 - 1e-7 before the inverse CDF.
        assert!((r[0] - (-5.199337582187471)).abs() < 1e-6);
        assert!((r[1] - (-0.6744897501960817)).abs() < 1e-6);
        assert_eq!(r[2], 0.0);
        assert!((r[1] + r[3]).abs() < 1e-12);
        assert!((r[0] + r[4]).abs() < 1e-6);
    }

    #[test]
    fn test_norm_ppf_center() {
        assert!((norm_ppf(0.5) - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_norm_ppf_endpoints() {
        assert_eq!(norm_ppf(0.0), f64::NEG_INFINITY);
        assert_eq!(norm_ppf(1.0), f64::INFINITY);
    }

    #[test]
    #[should_panic(expected = "p must be in [0, 1]")]
    fn test_norm_ppf_out_of_range_panics() {
        norm_ppf(1.5);
    }

    #[test]
    fn test_norm_ppf_tails() {
        // ppf(0.025) ≈ -1.96
        assert!((norm_ppf(0.025) - (-1.96)).abs() < 0.01);
        // ppf(0.975) ≈ 1.96
        assert!((norm_ppf(0.975) - 1.96).abs() < 0.01);
    }

    #[test]
    fn test_norm_ppf_reference_values() {
        // Covers the lower tail (p < 0.02425), the central region and the upper tail.
        let cases = [
            (0.001, -3.090232306167813),
            (0.01, -2.326347874040841),
            (0.25, -0.6744897501960817),
            (0.99, 2.326347874040841),
            (0.999, 3.090232306167813),
        ];
        for &(p, z) in cases.iter() {
            assert!((norm_ppf(p) - z).abs() < 1e-6, "ppf({}) = {}", p, norm_ppf(p));
        }
    }

    #[test]
    fn test_norm_ppf_symmetry() {
        for &p in &[0.1, 0.2, 0.3, 0.4] {
            let low = norm_ppf(p);
            let high = norm_ppf(1.0 - p);
            assert!(
                (low + high).abs() < 1e-10,
                "ppf({p}) + ppf({}) = {}",
                1.0 - p,
                low + high
            );
        }
    }
}