1use ferrolearn_core::error::FerroError;
12use ferrolearn_core::traits::{Fit, FitTransform, Transform};
13use ndarray::Array2;
14use num_traits::Float;
15
/// Target distribution that transformed values follow.
///
/// The variant is chosen when a [`QuantileTransformer`] is constructed and is
/// copied into the fitted transformer, where it decides how the interpolated
/// CDF level of each value is emitted.
///
/// `Hash` is derived alongside `Eq` so the variant can be used as a key in
/// hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutputDistribution {
    /// Emit the interpolated CDF level itself, a value in `[0, 1]`.
    Uniform,
    /// Pass the interpolated CDF level through the inverse standard normal
    /// CDF (probit) before emitting it.
    Normal,
}
28
29#[must_use]
62#[derive(Debug, Clone)]
63pub struct QuantileTransformer<F> {
64 n_quantiles: usize,
66 output_distribution: OutputDistribution,
68 subsample: usize,
70 _marker: std::marker::PhantomData<F>,
71}
72
73impl<F: Float + Send + Sync + 'static> QuantileTransformer<F> {
74 pub fn new(
76 n_quantiles: usize,
77 output_distribution: OutputDistribution,
78 subsample: usize,
79 ) -> Self {
80 Self {
81 n_quantiles,
82 output_distribution,
83 subsample,
84 _marker: std::marker::PhantomData,
85 }
86 }
87
88 #[must_use]
90 pub fn n_quantiles(&self) -> usize {
91 self.n_quantiles
92 }
93
94 #[must_use]
96 pub fn output_distribution(&self) -> OutputDistribution {
97 self.output_distribution
98 }
99
100 #[must_use]
102 pub fn subsample(&self) -> usize {
103 self.subsample
104 }
105}
106
107impl<F: Float + Send + Sync + 'static> Default for QuantileTransformer<F> {
108 fn default() -> Self {
109 Self::new(1000, OutputDistribution::Uniform, 100_000)
110 }
111}
112
113#[derive(Debug, Clone)]
121pub struct FittedQuantileTransformer<F> {
122 quantiles: Vec<Vec<F>>,
125 references: Vec<F>,
127 output_distribution: OutputDistribution,
129}
130
131impl<F: Float + Send + Sync + 'static> FittedQuantileTransformer<F> {
132 #[must_use]
134 pub fn quantiles(&self) -> &[Vec<F>] {
135 &self.quantiles
136 }
137
138 #[must_use]
140 pub fn n_features(&self) -> usize {
141 self.quantiles.len()
142 }
143}
144
145fn probit<F: Float>(p: F) -> F {
152 let eps = F::from(1e-7).unwrap_or(F::min_positive_value());
154 let p = if p < eps {
155 eps
156 } else if p > F::one() - eps {
157 F::one() - eps
158 } else {
159 p
160 };
161
162 let half = F::from(0.5).unwrap();
164 if p < half {
165 let t = (-F::from(2.0).unwrap() * p.ln()).sqrt();
167 let c0 = F::from(2.515517).unwrap();
168 let c1 = F::from(0.802853).unwrap();
169 let c2 = F::from(0.010328).unwrap();
170 let d1 = F::from(1.432788).unwrap();
171 let d2 = F::from(0.189269).unwrap();
172 let d3 = F::from(0.001308).unwrap();
173 let num = c0 + c1 * t + c2 * t * t;
174 let den = F::one() + d1 * t + d2 * t * t + d3 * t * t * t;
175 -(t - num / den)
176 } else {
177 let t = (-F::from(2.0).unwrap() * (F::one() - p).ln()).sqrt();
178 let c0 = F::from(2.515517).unwrap();
179 let c1 = F::from(0.802853).unwrap();
180 let c2 = F::from(0.010328).unwrap();
181 let d1 = F::from(1.432788).unwrap();
182 let d2 = F::from(0.189269).unwrap();
183 let d3 = F::from(0.001308).unwrap();
184 let num = c0 + c1 * t + c2 * t * t;
185 let den = F::one() + d1 * t + d2 * t * t + d3 * t * t * t;
186 t - num / den
187 }
188}
189
190fn interpolate_cdf<F: Float>(value: F, quantiles: &[F], references: &[F]) -> F {
193 if quantiles.is_empty() {
194 return F::from(0.5).unwrap();
195 }
196
197 if value <= quantiles[0] {
199 return references[0];
200 }
201 if value >= quantiles[quantiles.len() - 1] {
202 return references[references.len() - 1];
203 }
204
205 let mut lo = 0;
207 let mut hi = quantiles.len() - 1;
208 while lo < hi - 1 {
209 let mid = (lo + hi) / 2;
210 if quantiles[mid] <= value {
211 lo = mid;
212 } else {
213 hi = mid;
214 }
215 }
216
217 let denom = quantiles[hi] - quantiles[lo];
219 if denom == F::zero() {
220 references[lo]
221 } else {
222 let frac = (value - quantiles[lo]) / denom;
223 references[lo] + frac * (references[hi] - references[lo])
224 }
225}
226
227impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for QuantileTransformer<F> {
232 type Fitted = FittedQuantileTransformer<F>;
233 type Error = FerroError;
234
235 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedQuantileTransformer<F>, FerroError> {
242 let n_samples = x.nrows();
243 if n_samples < 2 {
244 return Err(FerroError::InsufficientSamples {
245 required: 2,
246 actual: n_samples,
247 context: "QuantileTransformer::fit".into(),
248 });
249 }
250 if self.n_quantiles < 2 {
251 return Err(FerroError::InvalidParameter {
252 name: "n_quantiles".into(),
253 reason: "n_quantiles must be at least 2".into(),
254 });
255 }
256
257 let n_features = x.ncols();
258 let effective_quantiles = self.n_quantiles.min(n_samples);
259
260 let references: Vec<F> = (0..effective_quantiles)
262 .map(|i| F::from(i).unwrap() / F::from(effective_quantiles - 1).unwrap_or(F::one()))
263 .collect();
264
265 let mut quantiles = Vec::with_capacity(n_features);
266
267 for j in 0..n_features {
268 let mut col_vals: Vec<F> = x.column(j).iter().copied().collect();
269 col_vals.retain(|v| !v.is_nan());
271 col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
272
273 if self.subsample > 0 && col_vals.len() > self.subsample {
275 let step = col_vals.len() as f64 / self.subsample as f64;
276 let mut sampled = Vec::with_capacity(self.subsample);
277 for i in 0..self.subsample {
278 let idx = (i as f64 * step) as usize;
279 sampled.push(col_vals[idx.min(col_vals.len() - 1)]);
280 }
281 col_vals = sampled;
282 }
283
284 let n = col_vals.len();
286 let mut feature_quantiles = Vec::with_capacity(effective_quantiles);
287 for &ref_level in &references {
288 let pos = ref_level * F::from(n.saturating_sub(1)).unwrap();
289 let lo = pos.floor().to_usize().unwrap_or(0).min(n.saturating_sub(1));
290 let hi = pos.ceil().to_usize().unwrap_or(0).min(n.saturating_sub(1));
291 let frac = pos - F::from(lo).unwrap();
292 let val = if lo == hi {
293 col_vals[lo]
294 } else {
295 col_vals[lo] * (F::one() - frac) + col_vals[hi] * frac
296 };
297 feature_quantiles.push(val);
298 }
299
300 quantiles.push(feature_quantiles);
301 }
302
303 Ok(FittedQuantileTransformer {
304 quantiles,
305 references,
306 output_distribution: self.output_distribution,
307 })
308 }
309}
310
311impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedQuantileTransformer<F> {
312 type Output = Array2<F>;
313 type Error = FerroError;
314
315 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
322 let n_features = self.quantiles.len();
323 if x.ncols() != n_features {
324 return Err(FerroError::ShapeMismatch {
325 expected: vec![x.nrows(), n_features],
326 actual: vec![x.nrows(), x.ncols()],
327 context: "FittedQuantileTransformer::transform".into(),
328 });
329 }
330
331 let mut out = x.to_owned();
332
333 for j in 0..n_features {
334 let feature_quantiles = &self.quantiles[j];
335 for i in 0..out.nrows() {
336 let val = out[[i, j]];
337 if val.is_nan() {
338 continue;
339 }
340 let cdf_val = interpolate_cdf(val, feature_quantiles, &self.references);
341
342 out[[i, j]] = match self.output_distribution {
343 OutputDistribution::Uniform => cdf_val,
344 OutputDistribution::Normal => probit(cdf_val),
345 };
346 }
347 }
348
349 Ok(out)
350 }
351}
352
353impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for QuantileTransformer<F> {
356 type Output = Array2<F>;
357 type Error = FerroError;
358
359 fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
361 Err(FerroError::InvalidParameter {
362 name: "QuantileTransformer".into(),
363 reason: "transformer must be fitted before calling transform; use fit() first".into(),
364 })
365 }
366}
367
368impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for QuantileTransformer<F> {
369 type FitError = FerroError;
370
371 fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
377 let fitted = self.fit(x, &())?;
378 fitted.transform(x)
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Single-feature data set holding 1.0 through 5.0 in ascending order.
    fn ascending_column() -> Array2<f64> {
        array![[1.0], [2.0], [3.0], [4.0], [5.0]]
    }

    /// `f64` transformer with subsampling disabled.
    fn transformer(
        n_quantiles: usize,
        output_distribution: OutputDistribution,
    ) -> QuantileTransformer<f64> {
        QuantileTransformer::<f64>::new(n_quantiles, output_distribution, 0)
    }

    #[test]
    fn test_quantile_transformer_uniform() {
        let x = ascending_column();
        let fitted = transformer(100, OutputDistribution::Uniform)
            .fit(&x, &())
            .unwrap();
        let out = fitted.transform(&x).unwrap();

        // Every output must lie in the unit interval.
        for v in out.iter() {
            assert!(*v >= 0.0 && *v <= 1.0, "Value {} not in [0,1]", v);
        }
        // Minimum and maximum map to the ends of the interval.
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(out[[4, 0]], 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_quantile_transformer_normal() {
        let x = ascending_column();
        let fitted = transformer(100, OutputDistribution::Normal)
            .fit(&x, &())
            .unwrap();
        let out = fitted.transform(&x).unwrap();

        assert!(out[[2, 0]].abs() < 0.5, "Median should map near 0");
        assert!(out[[0, 0]] < out[[4, 0]]);
    }

    #[test]
    fn test_quantile_transformer_monotonic() {
        // Rows hold 5, 3, 1, ... so the first three outputs must be strictly
        // decreasing.
        let x = array![[5.0], [3.0], [1.0], [4.0], [2.0]];
        let fitted = transformer(100, OutputDistribution::Uniform)
            .fit(&x, &())
            .unwrap();
        let out = fitted.transform(&x).unwrap();

        assert!(out[[0, 0]] > out[[1, 0]]);
        assert!(out[[1, 0]] > out[[2, 0]]);
    }

    #[test]
    fn test_quantile_transformer_multiple_features() {
        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let fitted = transformer(50, OutputDistribution::Uniform)
            .fit(&x, &())
            .unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.ncols(), 2);
        // Each feature is transformed independently and stays ordered.
        for j in 0..2 {
            assert!(out[[0, j]] <= out[[2, j]]);
        }
    }

    #[test]
    fn test_quantile_transformer_fit_transform() {
        let x = ascending_column();
        let out = transformer(100, OutputDistribution::Uniform)
            .fit_transform(&x)
            .unwrap();

        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(out[[4, 0]], 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_quantile_transformer_insufficient_samples_error() {
        let single_row = array![[1.0]];
        let result = transformer(100, OutputDistribution::Uniform).fit(&single_row, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_quantile_transformer_too_few_quantiles_error() {
        let x = array![[1.0], [2.0], [3.0]];
        let result = transformer(1, OutputDistribution::Uniform).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_quantile_transformer_shape_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = transformer(100, OutputDistribution::Uniform)
            .fit(&x_train, &())
            .unwrap();

        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_quantile_transformer_unfitted_error() {
        let unfitted = transformer(100, OutputDistribution::Uniform);
        let x = array![[1.0]];
        assert!(unfitted.transform(&x).is_err());
    }

    #[test]
    fn test_quantile_transformer_default() {
        let qt = QuantileTransformer::<f64>::default();
        assert_eq!(qt.n_quantiles(), 1000);
        assert_eq!(qt.output_distribution(), OutputDistribution::Uniform);
        assert_eq!(qt.subsample(), 100_000);
    }

    #[test]
    fn test_quantile_transformer_f32() {
        let qt = QuantileTransformer::<f32>::new(50, OutputDistribution::Uniform, 0);
        let x: Array2<f32> = array![[1.0f32], [2.0], [3.0], [4.0], [5.0]];
        let fitted = qt.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();

        assert!(out[[0, 0]] >= 0.0f32);
        assert!(out[[4, 0]] <= 1.0f32);
    }
}