1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::Array2;
3
/// Target distribution selected for the values produced by the fitted transformer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum OutputDistribution {
    /// Emit the interpolated quantile level itself (the transform clamps it to
    /// `[1e-7, 1 - 1e-7]`).
    Uniform,
    /// Pass the clamped quantile level through `inverse_normal_cdf`.
    Normal,
}
12
/// Unfitted configuration of a quantile transform.
///
/// Fitting produces a [`FittedQuantileTransformer`] holding, for every input column,
/// the column's empirical quantiles at a shared set of quantile levels.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct QuantileTransformer {
    /// Requested number of quantile levels; `fit` uses at most one level per input row.
    pub n_quantiles: usize,
    /// Distribution the transformed values are mapped onto.
    pub output_distribution: OutputDistribution,
}
24
25impl QuantileTransformer {
26 pub fn new() -> Self {
28 Self {
29 n_quantiles: 1000,
30 output_distribution: OutputDistribution::Uniform,
31 }
32 }
33
34 pub fn n_quantiles(mut self, n_quantiles: usize) -> Self {
36 self.n_quantiles = n_quantiles;
37 self
38 }
39
40 pub fn output_distribution(mut self, output_distribution: OutputDistribution) -> Self {
42 self.output_distribution = output_distribution;
43 self
44 }
45}
46
47impl Default for QuantileTransformer {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
/// Result of fitting a [`QuantileTransformer`]: per-column quantile tables plus the
/// quantile levels they were evaluated at.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
// Replace serde's inferred `Deserialize` bound on `F` with `DeserializeOwned`.
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedQuantileTransformer<F: Float> {
    /// One table per input column (outer index = column); entry `k` is that column's
    /// value at quantile level `references[k]`.
    quantiles: Vec<Vec<F>>,
    /// Quantile levels shared by all columns: evenly spaced over `[0, 1]`, or the
    /// single level `0.5` when only one quantile is used.
    references: Vec<f64>,
    /// Output distribution copied from the configuration at fit time.
    output_distribution: OutputDistribution,
}
63
/// Approximates the quantile function (inverse CDF) of the standard normal
/// distribution.
///
/// Three rational approximations are used: one for the central region
/// `[0.02425, 1 - 0.02425]` and a shared one, mirrored by sign, for the two tails.
/// The coefficients appear to be those of Peter J. Acklam's published algorithm
/// (not verifiable from this file alone).
///
/// Probabilities outside the open unit interval are not mapped to infinities:
/// `p <= 0` yields `-8.0` and `p >= 1` yields `8.0`.
fn inverse_normal_cdf(p: f64) -> f64 {
    /// Below this probability (and above its mirror) the tail formula is used.
    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    /// Numerator coefficients of the central-region rational function.
    const CENTRAL_NUM: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    /// Denominator coefficients of the central region (constant term 1 is implicit).
    const CENTRAL_DEN: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    /// Numerator coefficients of the tail rational function.
    const TAIL_NUM: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    /// Denominator coefficients of the tail (constant term 1 is implicit).
    const TAIL_DEN: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];

    /// Horner evaluation, highest-order coefficient first.
    fn horner(coeffs: &[f64], t: f64) -> f64 {
        coeffs[1..].iter().fold(coeffs[0], |acc, &c| acc * t + c)
    }

    /// Numerator and denominator of the tail formula for a tail mass `mass`.
    fn tail_parts(mass: f64) -> (f64, f64) {
        let q = (-2.0 * mass.ln()).sqrt();
        (horner(&TAIL_NUM, q), horner(&TAIL_DEN, q) * q + 1.0)
    }

    if p <= 0.0 {
        return -8.0;
    }
    if p >= 1.0 {
        return 8.0;
    }

    if p < P_LOW {
        // Lower tail.
        let (num, den) = tail_parts(p);
        num / den
    } else if p <= P_HIGH {
        // Central region: odd function of the offset from the median.
        let centered = p - 0.5;
        let squared = centered * centered;
        let num = horner(&CENTRAL_NUM, squared) * centered;
        let den = horner(&CENTRAL_DEN, squared) * squared + 1.0;
        num / den
    } else {
        // Upper tail: mirror image of the lower tail.
        let (num, den) = tail_parts(1.0 - p);
        -num / den
    }
}
126
/// Piecewise-linear interpolation of the samples `(xp[k], fp[k])`, evaluated at `x`.
///
/// `xp` must be sorted in ascending order and have the same length as `fp`
/// (checked only in debug builds).
///
/// * An empty table yields `0.0`.
/// * A NaN query yields NaN.
/// * Queries at or outside the table ends return the corresponding end value
///   (`fp[0]` or `fp[n - 1]`); no extrapolation is performed.
/// * When `x` equals a repeated knot in the interior, the value of the last
///   occurrence is returned.
///
/// Interpolation is scale-invariant: only an exactly zero-width (or inverted)
/// segment is treated as degenerate, so tables whose knots differ by very small
/// amounts are still interpolated.
fn interp(x: f64, xp: &[f64], fp: &[f64]) -> f64 {
    debug_assert_eq!(xp.len(), fp.len());
    let n = xp.len();
    if n == 0 {
        return 0.0;
    }
    // Every comparison below is false for NaN, so propagate it explicitly.
    if x.is_nan() {
        return f64::NAN;
    }
    if x <= xp[0] {
        return fp[0];
    }
    if x >= xp[n - 1] {
        return fp[n - 1];
    }

    // For a sorted table xp[0] < x < xp[n - 1] holds here, so the first knot greater
    // than `x` lies in 1..=n-1. The result is saturated so that an unsorted or
    // NaN-containing table can never index out of bounds.
    let hi = xp.partition_point(|&knot| knot <= x).min(n - 1);
    let lo = hi.saturating_sub(1);

    let dx = xp[hi] - xp[lo];
    if dx <= 0.0 {
        // Degenerate segment (only reachable with a malformed table).
        return fp[lo];
    }
    let t = (x - xp[lo]) / dx;
    fp[lo] + t * (fp[hi] - fp[lo])
}
161
162impl<F: Float> FitUnsupervised<F> for QuantileTransformer {
163 type Fitted = FittedQuantileTransformer<F>;
164
165 fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
166 if x.is_empty() {
167 return Err(RustMlError::EmptyInput("input array is empty".into()));
168 }
169
170 let n_samples = x.nrows();
171 let ncols = x.ncols();
172 let effective_n = self.n_quantiles.min(n_samples);
173
174 let references: Vec<f64> = if effective_n == 1 {
176 vec![0.5]
177 } else {
178 (0..effective_n)
179 .map(|i| i as f64 / (effective_n - 1) as f64)
180 .collect()
181 };
182
183 let mut quantiles = Vec::with_capacity(ncols);
184
185 for j in 0..ncols {
186 let mut col: Vec<F> = x.column(j).to_vec();
187 col.sort_by(|a, b| a.partial_cmp(b).unwrap());
188
189 let q: Vec<F> = references
191 .iter()
192 .map(|&p| percentile_sorted(&col, p))
193 .collect();
194
195 quantiles.push(q);
196 }
197
198 Ok(FittedQuantileTransformer {
199 quantiles,
200 references,
201 output_distribution: self.output_distribution,
202 })
203 }
204}
205
206fn percentile_sorted<F: Float>(sorted: &[F], p: f64) -> F {
208 let n = sorted.len();
209 if n == 1 {
210 return sorted[0];
211 }
212 let idx = p * (n - 1) as f64;
213 let lo = idx.floor() as usize;
214 let hi = idx.ceil().min((n - 1) as f64) as usize;
215 if lo == hi {
216 sorted[lo]
217 } else {
218 let frac = F::from_f64(idx - lo as f64).unwrap();
219 sorted[lo] * (F::one() - frac) + sorted[hi] * frac
220 }
221}
222
223impl<F: Float> Transform<F> for FittedQuantileTransformer<F> {
224 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
225 let expected_cols = self.quantiles.len();
226 if x.ncols() != expected_cols {
227 return Err(RustMlError::ShapeMismatch(format!(
228 "expected {} features, got {}",
229 expected_cols,
230 x.ncols()
231 )));
232 }
233
234 let mut result = Array2::<F>::zeros(x.raw_dim());
235
236 for j in 0..x.ncols() {
237 let q = &self.quantiles[j];
238 let xp: Vec<f64> = q.iter().map(|&v| v.to_f64().unwrap()).collect();
240 let fp = &self.references;
241
242 for i in 0..x.nrows() {
243 let val = x[[i, j]].to_f64().unwrap();
244 let mut u = interp(val, &xp, fp);
246
247 let eps = 1e-7;
249 u = u.max(eps).min(1.0 - eps);
250
251 let out = match self.output_distribution {
252 OutputDistribution::Uniform => u,
253 OutputDistribution::Normal => inverse_normal_cdf(u),
254 };
255
256 result[[i, j]] = F::from_f64(out).unwrap();
257 }
258 }
259
260 Ok(result)
261 }
262}
263
impl<F: Float> FittedQuantileTransformer<F> {
    /// Fitted quantile tables, one inner vector per input column.
    pub fn quantiles(&self) -> &Vec<Vec<F>> {
        &self.quantiles
    }

    /// Quantile levels at which every column's table was evaluated.
    pub fn references(&self) -> &Vec<f64> {
        &self.references
    }
}
275
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Five evenly spaced samples with five quantiles map onto evenly spaced levels.
    #[test]
    fn test_uniform_output() {
        let x = array![
            [1.0, 10.0],
            [2.0, 20.0],
            [3.0, 30.0],
            [4.0, 40.0],
            [5.0, 50.0],
        ];
        let qt = QuantileTransformer::new()
            .n_quantiles(5)
            .output_distribution(OutputDistribution::Uniform);
        let fitted = FitUnsupervised::<f64>::fit(&qt, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // The extremes are clamped to eps and 1 - eps rather than exactly 0 and 1.
        let eps = 1e-7;
        assert_abs_diff_eq!(transformed[[0, 0]], eps, epsilon = 1e-6);
        assert_abs_diff_eq!(transformed[[1, 0]], 0.25, epsilon = 1e-6);
        assert_abs_diff_eq!(transformed[[2, 0]], 0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(transformed[[3, 0]], 0.75, epsilon = 1e-6);
        assert_abs_diff_eq!(transformed[[4, 0]], 1.0 - eps, epsilon = 1e-6);
    }

    /// Normal output is negative at the minimum, positive at the maximum, and the two
    /// extremes mirror each other.
    #[test]
    fn test_normal_output() {
        let x = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0],
        ];
        let qt = QuantileTransformer::new()
            .n_quantiles(10)
            .output_distribution(OutputDistribution::Normal);
        let fitted = FitUnsupervised::<f64>::fit(&qt, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // Smallest sample lands in the lower tail, largest in the upper tail.
        assert!(transformed[[0, 0]] < 0.0);
        assert!(transformed[[9, 0]] > 0.0);

        // The two extremes are symmetric around zero.
        assert_abs_diff_eq!(transformed[[0, 0]], -transformed[[9, 0]], epsilon = 1e-6);
    }

    /// Uniform outputs stay strictly inside the open unit interval.
    #[test]
    fn test_output_range_uniform() {
        let x = array![
            [10.0],
            [20.0],
            [30.0],
            [40.0],
            [50.0],
            [60.0],
            [70.0],
            [80.0],
            [90.0],
            [100.0],
        ];
        let qt = QuantileTransformer::new()
            .n_quantiles(10)
            .output_distribution(OutputDistribution::Uniform);
        let fitted = FitUnsupervised::<f64>::fit(&qt, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            // Clamping keeps every value away from exactly 0 and 1.
            assert!(v > 0.0 && v < 1.0, "value out of range: {}", v);
        }
    }

    /// Fitting an array without elements is rejected.
    #[test]
    fn test_empty_input() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let qt = QuantileTransformer::default();
        assert!(FitUnsupervised::<f64>::fit(&qt, &x).is_err());
    }

    /// Transforming data with a different column count than the fitted data fails.
    #[test]
    fn test_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let qt = QuantileTransformer::default();
        let fitted = FitUnsupervised::<f64>::fit(&qt, &x).unwrap();

        let x_wrong = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_wrong).is_err());
    }

    /// Requesting more quantiles than there are samples still yields finite output.
    #[test]
    fn test_n_quantiles_larger_than_samples() {
        let x = array![[1.0], [2.0], [3.0]];
        // 1000 quantiles requested, only 3 samples available.
        let qt = QuantileTransformer::new()
            .n_quantiles(1000)
            .output_distribution(OutputDistribution::Uniform);
        let fitted = FitUnsupervised::<f64>::fit(&qt, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(v.is_finite(), "non-finite value: {}", v);
        }
    }

    /// Increasing inputs never produce decreasing outputs.
    #[test]
    fn test_monotonicity_preserved() {
        let x = array![[1.0], [3.0], [5.0], [7.0], [9.0]];
        let qt = QuantileTransformer::new()
            .n_quantiles(5)
            .output_distribution(OutputDistribution::Uniform);
        let fitted = FitUnsupervised::<f64>::fit(&qt, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for i in 1..x.nrows() {
            assert!(
                transformed[[i, 0]] >= transformed[[i - 1, 0]],
                "monotonicity violated at row {}",
                i
            );
        }
    }

    /// The inverse normal CDF is zero at the median and odd around it.
    #[test]
    fn test_inverse_normal_cdf_symmetry() {
        assert_abs_diff_eq!(inverse_normal_cdf(0.5), 0.0, epsilon = 1e-10);
        for &p in &[0.1, 0.2, 0.3, 0.4] {
            // Quantiles at p and 1 - p must mirror each other.
            assert_abs_diff_eq!(
                inverse_normal_cdf(p),
                -inverse_normal_cdf(1.0 - p),
                epsilon = 1e-10
            );
        }
    }
}