Skip to main content

anofox_ml_preprocessing/
quantile_transformer.rs

1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::Array2;
3
4/// Output distribution for the quantile transformer.
5#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
6pub enum OutputDistribution {
7    /// Map to a uniform distribution on [0, 1].
8    Uniform,
9    /// Map to a standard normal distribution.
10    Normal,
11}
12
13/// Parameters for QuantileTransformer (unfitted state).
14///
15/// Transforms features to follow a uniform or normal distribution by
16/// estimating the cumulative distribution function (CDF) via quantiles.
17#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
18pub struct QuantileTransformer {
19    /// Number of quantiles to compute. Clamped to n_samples if larger.
20    pub n_quantiles: usize,
21    /// Target output distribution.
22    pub output_distribution: OutputDistribution,
23}
24
25impl QuantileTransformer {
26    /// Create a new `QuantileTransformer` with defaults (1000 quantiles, uniform output).
27    pub fn new() -> Self {
28        Self {
29            n_quantiles: 1000,
30            output_distribution: OutputDistribution::Uniform,
31        }
32    }
33
34    /// Set the number of quantiles to compute.
35    pub fn n_quantiles(mut self, n_quantiles: usize) -> Self {
36        self.n_quantiles = n_quantiles;
37        self
38    }
39
40    /// Set the output distribution.
41    pub fn output_distribution(mut self, output_distribution: OutputDistribution) -> Self {
42        self.output_distribution = output_distribution;
43        self
44    }
45}
46
47impl Default for QuantileTransformer {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53/// Fitted QuantileTransformer -- holds quantile references per feature.
54#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
55#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
56pub struct FittedQuantileTransformer<F: Float> {
57    /// For each feature, sorted quantile values (length = effective n_quantiles).
58    quantiles: Vec<Vec<F>>,
59    /// The corresponding CDF positions for each quantile, in [0, 1].
60    references: Vec<f64>,
61    output_distribution: OutputDistribution,
62}
63
/// Approximate the inverse of the standard normal CDF (the probit function).
///
/// Uses Peter Acklam's rational approximation, with separate formulas for the
/// lower tail (`p < 0.02425`), the central region, and the upper tail. The
/// upper-tail formula mirrors the lower-tail one.
///
/// The result is always within `[-8, 8]`:
/// - `p <= 0` returns `-8` and `p >= 1` returns `8`.
/// - The approximation itself is clamped to the same bounds. Without this,
///   a tiny positive `p` would produce a value below the `-8` returned for
///   `p = 0`, making the function non-monotone (and likewise near `p = 1`).
///
/// A NaN input yields NaN.
fn inverse_normal_cdf(p: f64) -> f64 {
    // Bound on |z|, used both for p outside (0, 1) and for the approximation.
    const Z_LIMIT: f64 = 8.0;

    // Coefficients of the rational approximation.
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];

    // Breakpoints between the tail and central approximations.
    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    if p <= 0.0 {
        return -Z_LIMIT;
    }
    if p >= 1.0 {
        return Z_LIMIT;
    }

    // Tail formula evaluated at q = sqrt(-2 ln(tail probability)); it yields
    // the (negative) quantile of the lower tail.
    let tail = |q: f64| {
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    };

    let z = if p < P_LOW {
        tail((-2.0 * p.ln()).sqrt())
    } else if p <= P_HIGH {
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        // Upper tail is the mirror image of the lower tail.
        -tail((-2.0 * (1.0 - p).ln()).sqrt())
    };

    // Keeps the far tails consistent with the p <= 0 / p >= 1 returns above;
    // NaN passes through `clamp` unchanged.
    z.clamp(-Z_LIMIT, Z_LIMIT)
}
126
/// Piecewise-linear interpolation of the points `(xp[k], fp[k])` at `x`.
///
/// `xp` must be sorted in non-decreasing order and have the same length as
/// `fp`. Behavior at the edges:
/// - `x` at or beyond either end of `xp` returns the matching end of `fp`
///   (no extrapolation).
/// - An empty table yields `0.0`.
/// - A NaN `x` yields NaN.
/// - If `x` equals a run of repeated interior `xp` values, the `fp` entry of
///   the last element of that run is returned.
fn interp(x: f64, xp: &[f64], fp: &[f64]) -> f64 {
    debug_assert_eq!(xp.len(), fp.len());
    if xp.is_empty() {
        return 0.0;
    }
    if x.is_nan() {
        return f64::NAN;
    }

    let last = xp.len() - 1;
    if x <= xp[0] {
        return fp[0];
    }
    if x >= xp[last] {
        return fp[last];
    }

    // Here xp[0] < x < xp[last], so the first index whose value exceeds x lies
    // in 1..=last, and xp[lo] <= x < xp[hi] brackets x.
    let hi = xp.partition_point(|&v| v <= x);
    let lo = hi - 1;

    // Strictly positive for a sorted table because of the bracketing. There
    // is deliberately no absolute threshold, so tiny-scale data still
    // interpolates. This guard only protects against a mis-ordered table.
    let dx = xp[hi] - xp[lo];
    if dx <= 0.0 {
        return fp[lo];
    }
    let t = (x - xp[lo]) / dx;
    fp[lo] + t * (fp[hi] - fp[lo])
}
161
162impl<F: Float> FitUnsupervised<F> for QuantileTransformer {
163    type Fitted = FittedQuantileTransformer<F>;
164
165    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
166        if x.is_empty() {
167            return Err(RustMlError::EmptyInput("input array is empty".into()));
168        }
169
170        let n_samples = x.nrows();
171        let ncols = x.ncols();
172        let effective_n = self.n_quantiles.min(n_samples);
173
174        // Compute reference positions in [0, 1]
175        let references: Vec<f64> = if effective_n == 1 {
176            vec![0.5]
177        } else {
178            (0..effective_n)
179                .map(|i| i as f64 / (effective_n - 1) as f64)
180                .collect()
181        };
182
183        let mut quantiles = Vec::with_capacity(ncols);
184
185        for j in 0..ncols {
186            let mut col: Vec<F> = x.column(j).to_vec();
187            col.sort_by(|a, b| a.partial_cmp(b).unwrap());
188
189            // Compute quantiles at the reference positions
190            let q: Vec<F> = references
191                .iter()
192                .map(|&p| percentile_sorted(&col, p))
193                .collect();
194
195            quantiles.push(q);
196        }
197
198        Ok(FittedQuantileTransformer {
199            quantiles,
200            references,
201            output_distribution: self.output_distribution,
202        })
203    }
204}
205
206/// Compute a percentile from a sorted slice using linear interpolation.
207fn percentile_sorted<F: Float>(sorted: &[F], p: f64) -> F {
208    let n = sorted.len();
209    if n == 1 {
210        return sorted[0];
211    }
212    let idx = p * (n - 1) as f64;
213    let lo = idx.floor() as usize;
214    let hi = idx.ceil().min((n - 1) as f64) as usize;
215    if lo == hi {
216        sorted[lo]
217    } else {
218        let frac = F::from_f64(idx - lo as f64).unwrap();
219        sorted[lo] * (F::one() - frac) + sorted[hi] * frac
220    }
221}
222
223impl<F: Float> Transform<F> for FittedQuantileTransformer<F> {
224    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
225        let expected_cols = self.quantiles.len();
226        if x.ncols() != expected_cols {
227            return Err(RustMlError::ShapeMismatch(format!(
228                "expected {} features, got {}",
229                expected_cols,
230                x.ncols()
231            )));
232        }
233
234        let mut result = Array2::<F>::zeros(x.raw_dim());
235
236        for j in 0..x.ncols() {
237            let q = &self.quantiles[j];
238            // Build xp (quantile values as f64) and fp (references)
239            let xp: Vec<f64> = q.iter().map(|&v| v.to_f64().unwrap()).collect();
240            let fp = &self.references;
241
242            for i in 0..x.nrows() {
243                let val = x[[i, j]].to_f64().unwrap();
244                // Interpolate: map from data space to [0, 1]
245                let mut u = interp(val, &xp, fp);
246
247                // Clip to (epsilon, 1 - epsilon) to avoid infinities in normal transform
248                let eps = 1e-7;
249                u = u.max(eps).min(1.0 - eps);
250
251                let out = match self.output_distribution {
252                    OutputDistribution::Uniform => u,
253                    OutputDistribution::Normal => inverse_normal_cdf(u),
254                };
255
256                result[[i, j]] = F::from_f64(out).unwrap();
257            }
258        }
259
260        Ok(result)
261    }
262}
263
264impl<F: Float> FittedQuantileTransformer<F> {
265    /// Return the quantile values per feature.
266    pub fn quantiles(&self) -> &Vec<Vec<F>> {
267        &self.quantiles
268    }
269
270    /// Return the reference positions used for interpolation.
271    pub fn references(&self) -> &Vec<f64> {
272        &self.references
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Build an `n x 1` matrix whose single column holds `values`.
    fn column_matrix(values: &[f64]) -> Array2<f64> {
        Array2::from_shape_fn((values.len(), 1), |(row, _)| values[row])
    }

    /// Transformer with the given quantile count and output distribution.
    fn transformer(n_quantiles: usize, output: OutputDistribution) -> QuantileTransformer {
        QuantileTransformer::new()
            .n_quantiles(n_quantiles)
            .output_distribution(output)
    }

    /// Fit `qt` on `x`, then transform that same `x`.
    fn fit_then_transform(qt: &QuantileTransformer, x: &Array2<f64>) -> Array2<f64> {
        let fitted = FitUnsupervised::<f64>::fit(qt, x).unwrap();
        fitted.transform(x).unwrap()
    }

    #[test]
    fn test_uniform_output() {
        let x = array![
            [1.0, 10.0],
            [2.0, 20.0],
            [3.0, 30.0],
            [4.0, 40.0],
            [5.0, 50.0],
        ];
        let out = fit_then_transform(&transformer(5, OutputDistribution::Uniform), &x);

        // Five distinct samples with five quantiles land on 0, 0.25, 0.5,
        // 0.75 and 1; the two ends are then clipped to (eps, 1 - eps).
        let eps = 1e-7;
        let expected = [eps, 0.25, 0.5, 0.75, 1.0 - eps];
        for (row, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(out[[row, 0]], want, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_normal_output() {
        let values: Vec<f64> = (1..=10u32).map(f64::from).collect();
        let out = fit_then_transform(
            &transformer(10, OutputDistribution::Normal),
            &column_matrix(&values),
        );

        // The lowest sample maps below zero and the highest above it; the two
        // mirror each other around the median (5.5).
        let (lowest, highest) = (out[[0, 0]], out[[9, 0]]);
        assert!(lowest < 0.0);
        assert!(highest > 0.0);
        assert_abs_diff_eq!(lowest, -highest, epsilon = 1e-6);
    }

    #[test]
    fn test_output_range_uniform() {
        let values: Vec<f64> = (1..=10u32).map(|k| f64::from(k * 10)).collect();
        let out = fit_then_transform(
            &transformer(10, OutputDistribution::Uniform),
            &column_matrix(&values),
        );

        // Clipping keeps every output strictly inside (0, 1).
        if let Some(v) = out.iter().find(|&&v| !(v > 0.0 && v < 1.0)) {
            panic!("value out of range: {}", v);
        }
    }

    #[test]
    fn test_empty_input() {
        let empty = Array2::<f64>::zeros((0, 0));
        let result = FitUnsupervised::<f64>::fit(&QuantileTransformer::default(), &empty);
        assert!(result.is_err());
    }

    #[test]
    fn test_shape_mismatch() {
        let train = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = FitUnsupervised::<f64>::fit(&QuantileTransformer::default(), &train).unwrap();

        // Fitted on two features, so a three-feature input must be rejected.
        assert!(fitted.transform(&array![[1.0, 2.0, 3.0]]).is_err());
    }

    #[test]
    fn test_n_quantiles_larger_than_samples() {
        // 1000 requested quantiles get clamped to the 3 available samples.
        let x = column_matrix(&[1.0, 2.0, 3.0]);
        let out = fit_then_transform(&transformer(1000, OutputDistribution::Uniform), &x);

        if let Some(v) = out.iter().find(|v| !v.is_finite()) {
            panic!("non-finite value: {}", v);
        }
    }

    #[test]
    fn test_monotonicity_preserved() {
        let x = column_matrix(&[1.0, 3.0, 5.0, 7.0, 9.0]);
        let out = fit_then_transform(&transformer(5, OutputDistribution::Uniform), &x);

        // Ascending inputs must yield non-decreasing outputs.
        let col = out.column(0);
        for row in 1..col.len() {
            assert!(col[row] >= col[row - 1], "monotonicity violated at row {}", row);
        }
    }

    #[test]
    fn test_inverse_normal_cdf_symmetry() {
        assert_abs_diff_eq!(inverse_normal_cdf(0.5), 0.0, epsilon = 1e-10);

        // The probit is odd around p = 0.5: inv(p) == -inv(1 - p).
        for p in [0.1, 0.2, 0.3, 0.4] {
            let (left, right) = (inverse_normal_cdf(p), inverse_normal_cdf(1.0 - p));
            assert_abs_diff_eq!(left, -right, epsilon = 1e-10);
        }
    }
}