Skip to main content

ferrolearn_preprocess/
quantile_transformer.rs

1//! Quantile transformer: map features to a uniform or normal distribution.
2//!
3//! [`QuantileTransformer`] transforms features by mapping each value through
4//! its empirical cumulative distribution function (CDF), producing values
5//! uniformly distributed in `[0, 1]`. Optionally, the result can be mapped
6//! to a standard normal distribution using the inverse normal CDF (probit).
7//!
8//! This is useful for making features more Gaussian-like, which can improve
9//! the performance of many machine learning algorithms.
10
11use ferrolearn_core::error::FerroError;
12use ferrolearn_core::traits::{Fit, FitTransform, Transform};
13use ndarray::Array2;
14use num_traits::Float;
15
16// ---------------------------------------------------------------------------
17// OutputDistribution
18// ---------------------------------------------------------------------------
19
/// Distribution that the transformed feature values should follow.
///
/// Chosen when a [`QuantileTransformer`] is constructed and copied into the
/// fitted transformer, which uses it to decide how each empirical CDF value
/// is emitted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OutputDistribution {
    /// Emit the empirical CDF value itself, i.e. a value in `[0, 1]`.
    Uniform,
    /// Pass the empirical CDF value through the probit function so that the
    /// output follows a standard normal distribution.
    Normal,
}
28
29// ---------------------------------------------------------------------------
30// QuantileTransformer (unfitted)
31// ---------------------------------------------------------------------------
32
33/// An unfitted quantile transformer.
34///
35/// Calling [`Fit::fit`] computes the quantiles for each feature and returns a
36/// [`FittedQuantileTransformer`].
37///
38/// # Parameters
39///
40/// - `n_quantiles` — number of quantile reference points (default 1000).
41/// - `output_distribution` — target distribution (default `Uniform`).
42/// - `subsample` — maximum number of samples used to compute quantiles
43///   (default 100_000; set to 0 to use all samples).
44///
45/// # Examples
46///
47/// ```
48/// use ferrolearn_preprocess::quantile_transformer::{QuantileTransformer, OutputDistribution};
49/// use ferrolearn_core::traits::{Fit, Transform};
50/// use ndarray::array;
51///
52/// let qt = QuantileTransformer::<f64>::new(100, OutputDistribution::Uniform, 0);
53/// let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
54/// let fitted = qt.fit(&x, &()).unwrap();
55/// let out = fitted.transform(&x).unwrap();
56/// // Values should be in [0, 1]
57/// for v in out.iter() {
58///     assert!(*v >= 0.0 && *v <= 1.0);
59/// }
60/// ```
61#[must_use]
62#[derive(Debug, Clone)]
63pub struct QuantileTransformer<F> {
64    /// Number of quantile reference points.
65    n_quantiles: usize,
66    /// Target output distribution.
67    output_distribution: OutputDistribution,
68    /// Maximum number of samples for quantile computation (0 = all).
69    subsample: usize,
70    _marker: std::marker::PhantomData<F>,
71}
72
73impl<F: Float + Send + Sync + 'static> QuantileTransformer<F> {
74    /// Create a new `QuantileTransformer`.
75    pub fn new(
76        n_quantiles: usize,
77        output_distribution: OutputDistribution,
78        subsample: usize,
79    ) -> Self {
80        Self {
81            n_quantiles,
82            output_distribution,
83            subsample,
84            _marker: std::marker::PhantomData,
85        }
86    }
87
88    /// Return the number of quantiles.
89    #[must_use]
90    pub fn n_quantiles(&self) -> usize {
91        self.n_quantiles
92    }
93
94    /// Return the target output distribution.
95    #[must_use]
96    pub fn output_distribution(&self) -> OutputDistribution {
97        self.output_distribution
98    }
99
100    /// Return the subsample size.
101    #[must_use]
102    pub fn subsample(&self) -> usize {
103        self.subsample
104    }
105}
106
107impl<F: Float + Send + Sync + 'static> Default for QuantileTransformer<F> {
108    fn default() -> Self {
109        Self::new(1000, OutputDistribution::Uniform, 100_000)
110    }
111}
112
113// ---------------------------------------------------------------------------
114// FittedQuantileTransformer
115// ---------------------------------------------------------------------------
116
117/// A fitted quantile transformer holding per-feature quantile references.
118///
119/// Created by calling [`Fit::fit`] on a [`QuantileTransformer`].
120#[derive(Debug, Clone)]
121pub struct FittedQuantileTransformer<F> {
122    /// Quantile reference values per feature: `quantiles[j]` is a sorted
123    /// vector of reference values for feature `j`.
124    quantiles: Vec<Vec<F>>,
125    /// The reference quantile levels (evenly spaced in [0, 1]).
126    references: Vec<F>,
127    /// Target output distribution.
128    output_distribution: OutputDistribution,
129}
130
131impl<F: Float + Send + Sync + 'static> FittedQuantileTransformer<F> {
132    /// Return the computed quantile reference values per feature.
133    #[must_use]
134    pub fn quantiles(&self) -> &[Vec<F>] {
135        &self.quantiles
136    }
137
138    /// Return the number of features.
139    #[must_use]
140    pub fn n_features(&self) -> usize {
141        self.quantiles.len()
142    }
143}
144
145// ---------------------------------------------------------------------------
146// Helpers
147// ---------------------------------------------------------------------------
148
149/// Approximate the inverse normal CDF (probit function) using the rational
150/// approximation by Abramowitz and Stegun.
151fn probit<F: Float>(p: F) -> F {
152    // Clamp to avoid infinities
153    let eps = F::from(1e-7).unwrap_or(F::min_positive_value());
154    let p = if p < eps {
155        eps
156    } else if p > F::one() - eps {
157        F::one() - eps
158    } else {
159        p
160    };
161
162    // Rational approximation for the probit function
163    let half = F::from(0.5).unwrap();
164    if p < half {
165        // Use symmetry: probit(p) = -probit(1-p)
166        let t = (-F::from(2.0).unwrap() * p.ln()).sqrt();
167        let c0 = F::from(2.515517).unwrap();
168        let c1 = F::from(0.802853).unwrap();
169        let c2 = F::from(0.010328).unwrap();
170        let d1 = F::from(1.432788).unwrap();
171        let d2 = F::from(0.189269).unwrap();
172        let d3 = F::from(0.001308).unwrap();
173        let num = c0 + c1 * t + c2 * t * t;
174        let den = F::one() + d1 * t + d2 * t * t + d3 * t * t * t;
175        -(t - num / den)
176    } else {
177        let t = (-F::from(2.0).unwrap() * (F::one() - p).ln()).sqrt();
178        let c0 = F::from(2.515517).unwrap();
179        let c1 = F::from(0.802853).unwrap();
180        let c2 = F::from(0.010328).unwrap();
181        let d1 = F::from(1.432788).unwrap();
182        let d2 = F::from(0.189269).unwrap();
183        let d3 = F::from(0.001308).unwrap();
184        let num = c0 + c1 * t + c2 * t * t;
185        let den = F::one() + d1 * t + d2 * t * t + d3 * t * t * t;
186        t - num / den
187    }
188}
189
190/// Linearly interpolate: find the quantile level for a given value in a
191/// sorted quantile reference vector.
192fn interpolate_cdf<F: Float>(value: F, quantiles: &[F], references: &[F]) -> F {
193    if quantiles.is_empty() {
194        return F::from(0.5).unwrap();
195    }
196
197    // Clamp to range
198    if value <= quantiles[0] {
199        return references[0];
200    }
201    if value >= quantiles[quantiles.len() - 1] {
202        return references[references.len() - 1];
203    }
204
205    // Binary search for the interval
206    let mut lo = 0;
207    let mut hi = quantiles.len() - 1;
208    while lo < hi - 1 {
209        let mid = (lo + hi) / 2;
210        if quantiles[mid] <= value {
211            lo = mid;
212        } else {
213            hi = mid;
214        }
215    }
216
217    // Linear interpolation
218    let denom = quantiles[hi] - quantiles[lo];
219    if denom == F::zero() {
220        references[lo]
221    } else {
222        let frac = (value - quantiles[lo]) / denom;
223        references[lo] + frac * (references[hi] - references[lo])
224    }
225}
226
227// ---------------------------------------------------------------------------
228// Trait implementations
229// ---------------------------------------------------------------------------
230
231impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for QuantileTransformer<F> {
232    type Fitted = FittedQuantileTransformer<F>;
233    type Error = FerroError;
234
235    /// Fit by computing per-feature quantile reference values.
236    ///
237    /// # Errors
238    ///
239    /// - [`FerroError::InsufficientSamples`] if the input has fewer than 2 rows.
240    /// - [`FerroError::InvalidParameter`] if `n_quantiles` is less than 2.
241    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedQuantileTransformer<F>, FerroError> {
242        let n_samples = x.nrows();
243        if n_samples < 2 {
244            return Err(FerroError::InsufficientSamples {
245                required: 2,
246                actual: n_samples,
247                context: "QuantileTransformer::fit".into(),
248            });
249        }
250        if self.n_quantiles < 2 {
251            return Err(FerroError::InvalidParameter {
252                name: "n_quantiles".into(),
253                reason: "n_quantiles must be at least 2".into(),
254            });
255        }
256
257        let n_features = x.ncols();
258        let effective_quantiles = self.n_quantiles.min(n_samples);
259
260        // Build evenly spaced reference levels in [0, 1]
261        let references: Vec<F> = (0..effective_quantiles)
262            .map(|i| F::from(i).unwrap() / F::from(effective_quantiles - 1).unwrap_or(F::one()))
263            .collect();
264
265        let mut quantiles = Vec::with_capacity(n_features);
266
267        for j in 0..n_features {
268            let mut col_vals: Vec<F> = x.column(j).iter().copied().collect();
269            // Remove NaN values
270            col_vals.retain(|v| !v.is_nan());
271            col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
272
273            // Subsample if needed
274            if self.subsample > 0 && col_vals.len() > self.subsample {
275                let step = col_vals.len() as f64 / self.subsample as f64;
276                let mut sampled = Vec::with_capacity(self.subsample);
277                for i in 0..self.subsample {
278                    let idx = (i as f64 * step) as usize;
279                    sampled.push(col_vals[idx.min(col_vals.len() - 1)]);
280                }
281                col_vals = sampled;
282            }
283
284            // Compute quantile reference values
285            let n = col_vals.len();
286            let mut feature_quantiles = Vec::with_capacity(effective_quantiles);
287            for &ref_level in &references {
288                let pos = ref_level * F::from(n.saturating_sub(1)).unwrap();
289                let lo = pos.floor().to_usize().unwrap_or(0).min(n.saturating_sub(1));
290                let hi = pos.ceil().to_usize().unwrap_or(0).min(n.saturating_sub(1));
291                let frac = pos - F::from(lo).unwrap();
292                let val = if lo == hi {
293                    col_vals[lo]
294                } else {
295                    col_vals[lo] * (F::one() - frac) + col_vals[hi] * frac
296                };
297                feature_quantiles.push(val);
298            }
299
300            quantiles.push(feature_quantiles);
301        }
302
303        Ok(FittedQuantileTransformer {
304            quantiles,
305            references,
306            output_distribution: self.output_distribution,
307        })
308    }
309}
310
311impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedQuantileTransformer<F> {
312    type Output = Array2<F>;
313    type Error = FerroError;
314
315    /// Transform data by mapping each value through the empirical CDF.
316    ///
317    /// # Errors
318    ///
319    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
320    /// from the number of features seen during fitting.
321    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
322        let n_features = self.quantiles.len();
323        if x.ncols() != n_features {
324            return Err(FerroError::ShapeMismatch {
325                expected: vec![x.nrows(), n_features],
326                actual: vec![x.nrows(), x.ncols()],
327                context: "FittedQuantileTransformer::transform".into(),
328            });
329        }
330
331        let mut out = x.to_owned();
332
333        for j in 0..n_features {
334            let feature_quantiles = &self.quantiles[j];
335            for i in 0..out.nrows() {
336                let val = out[[i, j]];
337                if val.is_nan() {
338                    continue;
339                }
340                let cdf_val = interpolate_cdf(val, feature_quantiles, &self.references);
341
342                out[[i, j]] = match self.output_distribution {
343                    OutputDistribution::Uniform => cdf_val,
344                    OutputDistribution::Normal => probit(cdf_val),
345                };
346            }
347        }
348
349        Ok(out)
350    }
351}
352
353/// Implement `Transform` on the unfitted transformer to satisfy the
354/// `FitTransform: Transform` supertrait bound.
355impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for QuantileTransformer<F> {
356    type Output = Array2<F>;
357    type Error = FerroError;
358
359    /// Always returns an error — the transformer must be fitted first.
360    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
361        Err(FerroError::InvalidParameter {
362            name: "QuantileTransformer".into(),
363            reason: "transformer must be fitted before calling transform; use fit() first".into(),
364        })
365    }
366}
367
368impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for QuantileTransformer<F> {
369    type FitError = FerroError;
370
371    /// Fit and transform in one step.
372    ///
373    /// # Errors
374    ///
375    /// Returns an error if fitting fails.
376    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
377        let fitted = self.fit(x, &())?;
378        fitted.transform(x)
379    }
380}
381
382// ---------------------------------------------------------------------------
383// Tests
384// ---------------------------------------------------------------------------
385
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Uniform output stays in `[0, 1]`, hits both ends, and puts the median
    /// of five evenly spaced values at 0.5.
    #[test]
    fn test_quantile_transformer_uniform() {
        let qt = QuantileTransformer::<f64>::new(100, OutputDistribution::Uniform, 0);
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let fitted = qt.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // All values should be in [0, 1]
        for v in out.iter() {
            assert!(*v >= 0.0 && *v <= 1.0, "Value {} not in [0,1]", v);
        }
        // First should be 0, median 0.5, last should be 1
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(out[[2, 0]], 0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(out[[4, 0]], 1.0, epsilon = 1e-6);
    }

    /// Normal output is centred on the median and has opposite signs at the
    /// two extremes.
    #[test]
    fn test_quantile_transformer_normal() {
        let qt = QuantileTransformer::<f64>::new(100, OutputDistribution::Normal, 0);
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let fitted = qt.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Middle value should be close to 0 (median → 0 in normal)
        assert!(out[[2, 0]].abs() < 0.5, "Median should map near 0");
        // First should be negative, last positive
        assert!(out[[0, 0]] < 0.0, "Minimum should map below 0");
        assert!(out[[4, 0]] > 0.0, "Maximum should map above 0");
        assert!(out[[0, 0]] < out[[4, 0]]);
    }

    /// The transform preserves the ordering of the inputs regardless of the
    /// row order they arrive in.
    #[test]
    fn test_quantile_transformer_monotonic() {
        let qt = QuantileTransformer::<f64>::new(100, OutputDistribution::Uniform, 0);
        let x = array![[5.0], [3.0], [1.0], [4.0], [2.0]];
        let fitted = qt.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Transform should preserve ordering: rank(5) > rank(3) > rank(1)
        assert!(out[[0, 0]] > out[[1, 0]]); // 5 > 3
        assert!(out[[1, 0]] > out[[2, 0]]); // 3 > 1
    }

    /// Each column is transformed on its own.
    #[test]
    fn test_quantile_transformer_multiple_features() {
        let qt = QuantileTransformer::<f64>::new(50, OutputDistribution::Uniform, 0);
        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let fitted = qt.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
        // Each feature independently transformed
        for j in 0..2 {
            assert!(out[[0, j]] <= out[[2, j]]);
        }
    }

    /// `fit_transform` gives the same end points as fit followed by transform.
    #[test]
    fn test_quantile_transformer_fit_transform() {
        let qt = QuantileTransformer::<f64>::new(100, OutputDistribution::Uniform, 0);
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let out = qt.fit_transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(out[[4, 0]], 1.0, epsilon = 1e-6);
    }

    /// A single row is not enough to fit.
    #[test]
    fn test_quantile_transformer_insufficient_samples_error() {
        let qt = QuantileTransformer::<f64>::new(100, OutputDistribution::Uniform, 0);
        let x = array![[1.0]];
        assert!(qt.fit(&x, &()).is_err());
    }

    /// Fewer than two quantiles is rejected at fit time.
    #[test]
    fn test_quantile_transformer_too_few_quantiles_error() {
        let qt = QuantileTransformer::<f64>::new(1, OutputDistribution::Uniform, 0);
        let x = array![[1.0], [2.0], [3.0]];
        assert!(qt.fit(&x, &()).is_err());
    }

    /// Transforming data with a different column count is an error.
    #[test]
    fn test_quantile_transformer_shape_mismatch() {
        let qt = QuantileTransformer::<f64>::new(100, OutputDistribution::Uniform, 0);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = qt.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    /// Calling `transform` on the unfitted transformer is an error.
    #[test]
    fn test_quantile_transformer_unfitted_error() {
        let qt = QuantileTransformer::<f64>::new(100, OutputDistribution::Uniform, 0);
        let x = array![[1.0]];
        assert!(qt.transform(&x).is_err());
    }

    /// `Default` uses the documented hyper-parameters.
    #[test]
    fn test_quantile_transformer_default() {
        let qt = QuantileTransformer::<f64>::default();
        assert_eq!(qt.n_quantiles(), 1000);
        assert_eq!(qt.output_distribution(), OutputDistribution::Uniform);
        assert_eq!(qt.subsample(), 100_000);
    }

    /// The transformer also works with `f32` data.
    #[test]
    fn test_quantile_transformer_f32() {
        let qt = QuantileTransformer::<f32>::new(50, OutputDistribution::Uniform, 0);
        let x: Array2<f32> = array![[1.0f32], [2.0], [3.0], [4.0], [5.0]];
        let fitted = qt.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert!(out[[0, 0]] >= 0.0f32);
        assert!(out[[4, 0]] <= 1.0f32);
    }
}
499}