// ferrolearn_linear — src/quantile_regressor.rs
//! Quantile Regression via IRLS on the pinball loss.
//!
//! This module provides [`QuantileRegressor`], which estimates conditional
//! quantiles of the response variable. The default `quantile = 0.5`
//! corresponds to the conditional median, which is more robust to outliers
//! than the conditional mean (OLS).
//!
//! The pinball (check) loss for quantile `q` is:
//!
//! ```text
//! L_q(r) = q * max(r, 0) + (1 - q) * max(-r, 0)
//! ```
//!
//! The model is fitted via IRLS with weights `w_i = 1 / (2 * max(|r_i|, eps))`
//! and optional L1 regularization (`alpha`).
//!
//! # Examples
//!
//! ```
//! use ferrolearn_linear::QuantileRegressor;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
//! let y = array![2.0, 4.0, 6.0, 8.0, 10.0];
//!
//! let model = QuantileRegressor::<f64>::new(); // median regression
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 5);
//! ```

use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::HasCoefficients;
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2, Axis, ScalarOperand};
use num_traits::{Float, FromPrimitive};

/// Quantile Regressor — conditional quantile estimation via IRLS.
///
/// Minimises the pinball loss with optional L1 regularization. The IRLS
/// weights are `w_i = 1 / (2 * max(|r_i|, eps))`, which gives the
/// iteratively reweighted least absolute deviations procedure.
///
/// Construct via [`QuantileRegressor::new`] (or `Default`) and adjust with
/// the `with_*` builder methods. All parameters are validated at fit time,
/// not at construction.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct QuantileRegressor<F> {
    /// The quantile to estimate (must be in (0, 1)). Default 0.5 (median).
    pub quantile: F,
    /// L1 regularization strength (must be non-negative). Default 1.0.
    pub alpha: F,
    /// Maximum number of IRLS iterations. Default 1000.
    pub max_iter: usize,
    /// Convergence tolerance on the maximum coefficient change. Default 1e-5.
    pub tol: F,
    /// Whether to fit an intercept (bias) term. Default true.
    pub fit_intercept: bool,
}
62
63impl<F: Float + FromPrimitive> QuantileRegressor<F> {
64    /// Create a new `QuantileRegressor` with default settings.
65    ///
66    /// Defaults: `quantile = 0.5`, `alpha = 1.0`, `max_iter = 1000`,
67    /// `tol = 1e-5`, `fit_intercept = true`.
68    #[must_use]
69    pub fn new() -> Self {
70        Self {
71            quantile: F::from(0.5).unwrap(),
72            alpha: F::one(),
73            max_iter: 1000,
74            tol: F::from(1e-5).unwrap(),
75            fit_intercept: true,
76        }
77    }
78
79    /// Set the quantile to estimate.
80    ///
81    /// Must be strictly between 0 and 1.
82    #[must_use]
83    pub fn with_quantile(mut self, quantile: F) -> Self {
84        self.quantile = quantile;
85        self
86    }
87
88    /// Set the L1 regularization strength.
89    #[must_use]
90    pub fn with_alpha(mut self, alpha: F) -> Self {
91        self.alpha = alpha;
92        self
93    }
94
95    /// Set the maximum number of IRLS iterations.
96    #[must_use]
97    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
98        self.max_iter = max_iter;
99        self
100    }
101
102    /// Set the convergence tolerance.
103    #[must_use]
104    pub fn with_tol(mut self, tol: F) -> Self {
105        self.tol = tol;
106        self
107    }
108
109    /// Set whether to fit an intercept term.
110    #[must_use]
111    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
112        self.fit_intercept = fit_intercept;
113        self
114    }
115}
116
impl<F: Float + FromPrimitive> Default for QuantileRegressor<F> {
    /// Equivalent to [`QuantileRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
122
/// Fitted Quantile Regressor model.
///
/// Produced by [`QuantileRegressor`]'s `fit`; stores the learned coefficients
/// and intercept used for prediction.
#[derive(Debug, Clone)]
pub struct FittedQuantileRegressor<F> {
    /// Learned coefficient vector (one entry per feature column).
    coefficients: Array1<F>,
    /// Learned intercept (bias) term; zero when fitted without an intercept.
    intercept: F,
}
133
134// ---------------------------------------------------------------------------
135// Internal helpers
136// ---------------------------------------------------------------------------
137
138/// Cholesky solve for `A x = b`.
139fn cholesky_solve<F: Float>(a: &Array2<F>, b: &Array1<F>) -> Result<Array1<F>, FerroError> {
140    let n = a.nrows();
141    let mut l = Array2::<F>::zeros((n, n));
142
143    for i in 0..n {
144        for j in 0..=i {
145            let mut s = a[[i, j]];
146            for k in 0..j {
147                s = s - l[[i, k]] * l[[j, k]];
148            }
149            if i == j {
150                if s <= F::zero() {
151                    return Err(FerroError::NumericalInstability {
152                        message: "Cholesky: matrix not positive definite".into(),
153                    });
154                }
155                l[[i, j]] = s.sqrt();
156            } else {
157                l[[i, j]] = s / l[[j, j]];
158            }
159        }
160    }
161
162    let mut z = Array1::<F>::zeros(n);
163    for i in 0..n {
164        let mut s = b[i];
165        for k in 0..i {
166            s = s - l[[i, k]] * z[k];
167        }
168        z[i] = s / l[[i, i]];
169    }
170
171    let mut x_sol = Array1::<F>::zeros(n);
172    for i in (0..n).rev() {
173        let mut s = z[i];
174        for k in (i + 1)..n {
175            s = s - l[[k, i]] * x_sol[k];
176        }
177        x_sol[i] = s / l[[i, i]];
178    }
179
180    Ok(x_sol)
181}
182
183/// Gaussian elimination with partial pivoting.
184fn gaussian_solve<F: Float>(
185    n: usize,
186    a: &Array2<F>,
187    b: &Array1<F>,
188) -> Result<Array1<F>, FerroError> {
189    let mut aug = Array2::<F>::zeros((n, n + 1));
190    for i in 0..n {
191        for j in 0..n {
192            aug[[i, j]] = a[[i, j]];
193        }
194        aug[[i, n]] = b[i];
195    }
196
197    for col in 0..n {
198        let mut max_val = aug[[col, col]].abs();
199        let mut max_row = col;
200        for row in (col + 1)..n {
201            let v = aug[[row, col]].abs();
202            if v > max_val {
203                max_val = v;
204                max_row = row;
205            }
206        }
207
208        if max_val < F::from(1e-12).unwrap_or_else(F::epsilon) {
209            return Err(FerroError::NumericalInstability {
210                message: "singular matrix in Gaussian elimination".into(),
211            });
212        }
213
214        if max_row != col {
215            for j in 0..=n {
216                let tmp = aug[[col, j]];
217                aug[[col, j]] = aug[[max_row, j]];
218                aug[[max_row, j]] = tmp;
219            }
220        }
221
222        let pivot = aug[[col, col]];
223        for row in (col + 1)..n {
224            let factor = aug[[row, col]] / pivot;
225            for j in col..=n {
226                let above = aug[[col, j]];
227                aug[[row, j]] = aug[[row, j]] - factor * above;
228            }
229        }
230    }
231
232    let mut x_sol = Array1::<F>::zeros(n);
233    for i in (0..n).rev() {
234        let mut s = aug[[i, n]];
235        for j in (i + 1)..n {
236            s = s - aug[[i, j]] * x_sol[j];
237        }
238        if aug[[i, i]].abs() < F::from(1e-12).unwrap_or_else(F::epsilon) {
239            return Err(FerroError::NumericalInstability {
240                message: "near-zero pivot in back substitution".into(),
241            });
242        }
243        x_sol[i] = s / aug[[i, i]];
244    }
245
246    Ok(x_sol)
247}
248
249/// Solve the weighted least-squares problem with L1 penalty approximation.
250///
251/// `(X^T W X + n_samples * alpha * diag) w = X^T W y`
252///
253/// For the quantile regression IRLS, the L1 penalty is linearised around
254/// the current coefficients. The penalty is scaled by `n_samples` so that
255/// `alpha` has the same meaning as in scikit-learn — sklearn's
256/// `QuantileRegressor` averages the data-fit term by `1/n` and adds an
257/// unscaled `alpha * ||w||_1`, which is mathematically equivalent to our
258/// unaveraged data fit plus `n_samples * alpha * ||w||_1`. Without this
259/// factor, `alpha = 1.0` in ferrolearn would be roughly `n_samples` times
260/// weaker than the same value in sklearn.
261fn weighted_l1_solve<F: Float + FromPrimitive>(
262    x: &Array2<F>,
263    y: &Array1<F>,
264    weights: &Array1<F>,
265    alpha: F,
266    prev_coef: &Array1<F>,
267) -> Result<Array1<F>, FerroError> {
268    let (n_samples, n_features) = x.dim();
269    let eps = F::from(1e-8).unwrap();
270    let n_f = F::from(n_samples).unwrap_or_else(F::one);
271    let scaled_alpha = alpha * n_f;
272
273    let mut xtwx = Array2::<F>::zeros((n_features, n_features));
274    let mut xtwy = Array1::<F>::zeros(n_features);
275
276    for i in 0..n_samples {
277        let wi = weights[i];
278        let xi = x.row(i);
279        for r in 0..n_features {
280            xtwy[r] = xtwy[r] + wi * xi[r] * y[i];
281            for c in 0..n_features {
282                xtwx[[r, c]] = xtwx[[r, c]] + wi * xi[r] * xi[c];
283            }
284        }
285    }
286
287    // Add L1 penalty via IRLS approximation: penalise with
288    // (n_samples * alpha) / max(|w_j|, eps). The n_samples factor keeps
289    // `alpha` numerically equivalent to scikit-learn's `alpha` parameter
290    // (see function-level docstring).
291    for j in 0..n_features {
292        let pen = scaled_alpha / prev_coef[j].abs().max(eps);
293        xtwx[[j, j]] = xtwx[[j, j]] + pen;
294    }
295
296    cholesky_solve(&xtwx, &xtwy).or_else(|_| gaussian_solve(n_features, &xtwx, &xtwy))
297}
298
299// ---------------------------------------------------------------------------
300// Fit
301// ---------------------------------------------------------------------------
302
impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
    for QuantileRegressor<F>
{
    type Fitted = FittedQuantileRegressor<F>;
    type Error = FerroError;

    /// Fit the quantile regression model via IRLS.
    ///
    /// Validates inputs, optionally centers X and y, then iterates up to
    /// `max_iter` times: compute residuals, derive asymmetric pinball
    /// weights, re-solve the penalised weighted least-squares system, and
    /// stop once the infinity-norm of the coefficient change drops below
    /// `tol`.
    ///
    /// # Errors
    ///
    /// - [`FerroError::ShapeMismatch`] — sample count mismatch.
    /// - [`FerroError::InsufficientSamples`] — zero samples.
    /// - [`FerroError::InvalidParameter`] — quantile outside (0, 1) or
    ///   negative alpha.
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedQuantileRegressor<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        // --- Validation: shapes first, then hyperparameters. ---
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }

        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "QuantileRegressor requires at least one sample".into(),
            });
        }

        if self.quantile <= F::zero() || self.quantile >= F::one() {
            return Err(FerroError::InvalidParameter {
                name: "quantile".into(),
                reason: "must be strictly between 0 and 1".into(),
            });
        }

        if self.alpha < F::zero() {
            return Err(FerroError::InvalidParameter {
                name: "alpha".into(),
                reason: "must be non-negative".into(),
            });
        }

        // eps serves two roles: floor for |residual| in the IRLS weights,
        // and the initial |w| used to linearise the L1 penalty.
        let eps = F::from(1e-8).unwrap();
        let one = F::one();
        let q = self.quantile;

        // Center data if fitting intercept, so the solve needs no bias column.
        // NOTE(review): centering uses *unweighted* column means, but the
        // inner solve is *weighted*; for quantiles far from 0.5 the intercept
        // recovered below may be slightly biased — confirm against a
        // reference implementation.
        let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
            let x_mean = x
                .mean_axis(Axis(0))
                .ok_or_else(|| FerroError::NumericalInstability {
                    message: "failed to compute column means".into(),
                })?;
            let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
                message: "failed to compute target mean".into(),
            })?;
            let x_c = x - &x_mean;
            let y_c = y - y_mean;
            (x_c, y_c, Some(x_mean), Some(y_mean))
        } else {
            (x.clone(), y.clone(), None, None)
        };

        // Initialise coefficients to zero.
        let mut w = Array1::<F>::zeros(n_features);
        // Initialise with small values for L1 linearisation. Note: on the
        // first iteration the penalty is alpha * n / eps per coordinate,
        // i.e. very strong, which shrinks the first solve toward zero.
        let mut w_prev = Array1::from_elem(n_features, eps);

        for _iter in 0..self.max_iter {
            // Snapshot for the convergence check below.
            let w_old = w.clone();

            // Compute residuals.
            let residuals = &y_work - x_work.dot(&w);

            // Compute IRLS weights for pinball loss.
            // weight_i = asymmetric_weight_i / (2 * max(|r_i|, eps))
            // For q = 0.5 this is 1 / (2 * max(|r_i|, eps)), matching the
            // module-level docs (the factor 1/2 is absorbed into asym = 0.5).
            let mut weights = Array1::<F>::zeros(n_samples);
            for i in 0..n_samples {
                let abs_r = residuals[i].abs().max(eps);
                // Asymmetric weight: q for positive residuals, (1-q) for negative.
                let asym = if residuals[i] >= F::zero() { q } else { one - q };
                weights[i] = asym / abs_r;
            }

            // Working response is y_work itself (we re-solve for w directly).
            w = weighted_l1_solve(&x_work, &y_work, &weights, self.alpha, &w_prev)?;
            // Floor |w| at eps so the next L1 linearisation never divides by ~0.
            w_prev = w.mapv(|v| v.abs().max(eps));

            // Check convergence: infinity-norm of the coefficient change.
            let max_change = w
                .iter()
                .zip(w_old.iter())
                .map(|(&wn, &wo)| (wn - wo).abs())
                .fold(F::zero(), |a, b| if b > a { b } else { a });

            if max_change < self.tol {
                break;
            }
        }

        // Recover the intercept from the centering means; zero when the
        // model was fitted without an intercept.
        let intercept = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
            *ym - xm.dot(&w)
        } else {
            F::zero()
        };

        Ok(FittedQuantileRegressor {
            coefficients: w,
            intercept,
        })
    }
}
424
425// ---------------------------------------------------------------------------
426// Predict / HasCoefficients / Pipeline
427// ---------------------------------------------------------------------------
428
429impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
430    for FittedQuantileRegressor<F>
431{
432    type Output = Array1<F>;
433    type Error = FerroError;
434
435    /// Predict target values for the given feature matrix.
436    ///
437    /// Computes `X @ coefficients + intercept`.
438    ///
439    /// # Errors
440    ///
441    /// Returns [`FerroError::ShapeMismatch`] if the number of features
442    /// does not match the fitted model.
443    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
444        if x.ncols() != self.coefficients.len() {
445            return Err(FerroError::ShapeMismatch {
446                expected: vec![self.coefficients.len()],
447                actual: vec![x.ncols()],
448                context: "number of features must match fitted model".into(),
449            });
450        }
451        Ok(x.dot(&self.coefficients) + self.intercept)
452    }
453}
454
impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedQuantileRegressor<F>
{
    /// Borrow the learned coefficient vector.
    fn coefficients(&self) -> &Array1<F> {
        &self.coefficients
    }

    /// Return the learned intercept term.
    fn intercept(&self) -> F {
        self.intercept
    }
}
466
467impl<F> PipelineEstimator<F> for QuantileRegressor<F>
468where
469    F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
470{
471    fn fit_pipeline(
472        &self,
473        x: &Array2<F>,
474        y: &Array1<F>,
475    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
476        let fitted = self.fit(x, y)?;
477        Ok(Box::new(fitted))
478    }
479}
480
impl<F> FittedPipelineEstimator<F> for FittedQuantileRegressor<F>
where
    F: Float + ScalarOperand + Send + Sync + 'static,
{
    /// Delegate to [`Predict::predict`] so the fitted model can serve as
    /// the final step of a pipeline.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
489
490// ---------------------------------------------------------------------------
491// Tests
492// ---------------------------------------------------------------------------
493
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Defaults must match the values documented on `QuantileRegressor::new`.
    /// (Fix: `tol` was previously configured but never asserted.)
    #[test]
    fn test_defaults() {
        let m = QuantileRegressor::<f64>::new();
        assert_relative_eq!(m.quantile, 0.5);
        assert_relative_eq!(m.alpha, 1.0);
        assert_eq!(m.max_iter, 1000);
        assert_relative_eq!(m.tol, 1e-5);
        assert!(m.fit_intercept);
    }

    /// Every builder setter must be reflected in the configuration,
    /// including `with_tol` (previously unchecked).
    #[test]
    fn test_builder() {
        let m = QuantileRegressor::<f64>::new()
            .with_quantile(0.9)
            .with_alpha(0.5)
            .with_max_iter(500)
            .with_tol(1e-8)
            .with_fit_intercept(false);
        assert_relative_eq!(m.quantile, 0.9);
        assert_relative_eq!(m.alpha, 0.5);
        assert_eq!(m.max_iter, 500);
        assert_relative_eq!(m.tol, 1e-8);
        assert!(!m.fit_intercept);
    }

    /// Mismatched X/y sample counts must be rejected.
    #[test]
    fn test_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];
        assert!(QuantileRegressor::<f64>::new().fit(&x, &y).is_err());
    }

    /// quantile = 0.0 is outside the open interval (0, 1).
    #[test]
    fn test_invalid_quantile_zero() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        assert!(QuantileRegressor::<f64>::new()
            .with_quantile(0.0)
            .fit(&x, &y)
            .is_err());
    }

    /// quantile = 1.0 is outside the open interval (0, 1).
    #[test]
    fn test_invalid_quantile_one() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        assert!(QuantileRegressor::<f64>::new()
            .with_quantile(1.0)
            .fit(&x, &y)
            .is_err());
    }

    /// Negative regularization strength must be rejected.
    #[test]
    fn test_negative_alpha() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        assert!(QuantileRegressor::<f64>::new()
            .with_alpha(-1.0)
            .fit(&x, &y)
            .is_err());
    }

    /// On clean linear data, median regression should approximate OLS.
    #[test]
    fn test_median_regression_clean_data() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = QuantileRegressor::<f64>::new()
            .with_alpha(0.0)
            .with_max_iter(2000)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 0.5);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1.0);
    }

    /// Predictions have one entry per input row.
    #[test]
    fn test_predict_length() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = QuantileRegressor::<f64>::new()
            .with_alpha(0.0)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    /// Predicting with the wrong feature count must fail.
    #[test]
    fn test_predict_feature_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = QuantileRegressor::<f64>::new().fit(&x, &y).unwrap();
        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    /// The coefficient vector length matches the feature count.
    #[test]
    fn test_has_coefficients() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = QuantileRegressor::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    /// With fit_intercept = false the intercept stays exactly zero.
    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = QuantileRegressor::<f64>::new()
            .with_alpha(0.0)
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    /// A higher quantile should generally yield higher predicted values.
    #[test]
    fn test_high_quantile_higher_prediction() {
        let x = Array2::from_shape_vec(
            (10, 1),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
        )
        .unwrap();
        // y with some noise.
        let y = array![2.5, 3.8, 6.2, 7.9, 10.5, 12.1, 14.3, 15.8, 18.2, 20.5];

        let fitted_low = QuantileRegressor::<f64>::new()
            .with_quantile(0.1)
            .with_alpha(0.0)
            .fit(&x, &y)
            .unwrap();
        let fitted_high = QuantileRegressor::<f64>::new()
            .with_quantile(0.9)
            .with_alpha(0.0)
            .fit(&x, &y)
            .unwrap();

        let x_test = Array2::from_shape_vec((1, 1), vec![5.5]).unwrap();
        let pred_low = fitted_low.predict(&x_test).unwrap()[0];
        let pred_high = fitted_high.predict(&x_test).unwrap()[0];

        assert!(
            pred_high >= pred_low,
            "q=0.9 prediction ({pred_high}) should be >= q=0.1 prediction ({pred_low})"
        );
    }

    /// Fitting and predicting through the boxed pipeline traits works.
    #[test]
    fn test_pipeline() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];
        let model = QuantileRegressor::<f64>::new().with_alpha(0.0);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
}