// irithyll_core/ensemble/quantile_regressor.rs
//! Non-crossing multi-quantile regression via parallel SGBT ensembles.
//!
//! Wraps K independent [`SGBT<QuantileLoss>`] instances, one per quantile
//! level. Predictions are post-processed with the Pool Adjacent Violators
//! Algorithm (PAVA) to enforce monotonicity -- guaranteeing that
//! tau_i <= tau_j implies q_hat(tau_i) <= q_hat(tau_j).
//!
//! Without PAVA enforcement, independently trained quantile models can
//! produce *crossing* predictions (e.g., the 90th percentile prediction
//! falls below the 10th percentile), which is incoherent. PAVA resolves
//! crossings in O(K) time via isotonic regression.
//!
//! # Example
//!
//! ```text
//! use irithyll::ensemble::quantile_regressor::QuantileRegressorSGBT;
//! use irithyll::SGBTConfig;
//!
//! let config = SGBTConfig::builder()
//!     .n_steps(10)
//!     .learning_rate(0.1)
//!     .grace_period(10)
//!     .build()
//!     .unwrap();
//!
//! // 90% prediction interval: [5th, 50th, 95th] percentiles
//! let quantiles = vec![0.05, 0.5, 0.95];
//! let mut model = QuantileRegressorSGBT::new(config, &quantiles).unwrap();
//!
//! model.train_one(&irithyll::Sample::new(vec![1.0, 2.0], 3.0));
//! let preds = model.predict(&[1.0, 2.0]);
//! assert_eq!(preds.len(), 3);
//! // Guaranteed: preds[0] <= preds[1] <= preds[2]
//! ```

use alloc::vec::Vec;

use crate::ensemble::config::SGBTConfig;
use crate::ensemble::SGBT;
use crate::error::{ConfigError, IrithyllError};
use crate::loss::quantile::QuantileLoss;
use crate::sample::Observation;

// ---------------------------------------------------------------------------
// PAVA -- Pool Adjacent Violators Algorithm
// ---------------------------------------------------------------------------

/// Enforce a non-decreasing constraint on `values` via isotonic regression.
///
/// Implements the Pool Adjacent Violators Algorithm (PAVA): scan left to
/// right, keeping a stack of pooled blocks. Whenever the newest block's
/// mean drops below the mean of the block beneath it, the two are merged
/// (their sums and counts are added), and merging repeats until the stack
/// of block means is non-decreasing. Finally each block's mean is written
/// back over the indices it covers.
///
/// Runs in O(K) amortized time and O(K) space, where K = `values.len()`.
///
/// # Arguments
///
/// * `values` -- mutable slice modified in-place to be non-decreasing
fn enforce_monotonicity(values: &mut [f64]) {
    if values.len() < 2 {
        return;
    }

    // Stack of pooled blocks: (sum of members, member count, first index).
    let mut stack: Vec<(f64, usize, usize)> = Vec::with_capacity(values.len());

    for (idx, &v) in values.iter().enumerate() {
        // Start a singleton block for this element, then absorb any
        // preceding blocks whose mean violates the ordering.
        let mut top = (v, 1usize, idx);
        while let Some(&(prev_sum, prev_cnt, prev_start)) = stack.last() {
            if prev_sum / prev_cnt as f64 <= top.0 / top.1 as f64 {
                break; // means are ordered -- no violation
            }
            // Pool: previous block absorbs the current one.
            stack.pop();
            top = (prev_sum + top.0, prev_cnt + top.1, prev_start);
        }
        stack.push(top);
    }

    // Flatten: overwrite each block's span with its pooled mean.
    for (b, &(sum, cnt, start)) in stack.iter().enumerate() {
        let mean = sum / cnt as f64;
        let end = stack.get(b + 1).map_or(values.len(), |next| next.2);
        for slot in &mut values[start..end] {
            *slot = mean;
        }
    }
}
112
// ---------------------------------------------------------------------------
// QuantileRegressorSGBT
// ---------------------------------------------------------------------------
116
117/// Non-crossing multi-quantile SGBT regressor.
118///
119/// Maintains K independent `SGBT<QuantileLoss>` models -- one per quantile
120/// level -- and applies the Pool Adjacent Violators Algorithm (PAVA) to
121/// predictions to guarantee non-crossing quantile estimates.
122///
123/// # Non-crossing guarantee
124///
125/// For quantile levels tau_1 < tau_2 < ... < tau_K, the predicted values
126/// q_hat(tau_1) <= q_hat(tau_2) <= ... <= q_hat(tau_K) after PAVA
127/// enforcement. This makes the output suitable for prediction intervals,
128/// conditional density estimation, and risk quantification.
129pub struct QuantileRegressorSGBT {
130    /// One SGBT per quantile level, each with its own QuantileLoss.
131    models: Vec<SGBT<QuantileLoss>>,
132    /// Sorted quantile levels in (0, 1).
133    quantiles: Vec<f64>,
134    /// Number of quantile levels.
135    n_quantiles: usize,
136    /// Total samples seen.
137    samples_seen: u64,
138}
139
140impl Clone for QuantileRegressorSGBT {
141    fn clone(&self) -> Self {
142        Self {
143            models: self.models.clone(),
144            quantiles: self.quantiles.clone(),
145            n_quantiles: self.n_quantiles,
146            samples_seen: self.samples_seen,
147        }
148    }
149}
150
// Manual Debug: only summary fields are printed. The `models` field is
// deliberately omitted -- presumably because `SGBT` does not implement
// Debug or its dump would be too verbose (NOTE(review): confirm which).
impl core::fmt::Debug for QuantileRegressorSGBT {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("QuantileRegressorSGBT")
            .field("quantiles", &self.quantiles)
            .field("n_quantiles", &self.n_quantiles)
            .field("samples_seen", &self.samples_seen)
            .finish()
    }
}
160
161impl QuantileRegressorSGBT {
162    /// Create a new multi-quantile regressor.
163    ///
164    /// Quantile levels are automatically sorted. Each level must be in (0, 1)
165    /// and there must be at least one level.
166    ///
167    /// # Errors
168    ///
169    /// Returns [`IrithyllError::InvalidConfig`] if:
170    /// - `quantiles` is empty
171    /// - any quantile is not in (0, 1)
172    /// - duplicate quantile levels exist
173    pub fn new(config: SGBTConfig, quantiles: &[f64]) -> crate::error::Result<Self> {
174        if quantiles.is_empty() {
175            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
176                "quantiles",
177                "must have at least one quantile level",
178                0usize,
179            )));
180        }
181
182        // Validate and sort
183        let mut sorted: Vec<f64> = quantiles.to_vec();
184        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
185
186        for (i, &tau) in sorted.iter().enumerate() {
187            if tau <= 0.0 || tau >= 1.0 {
188                return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
189                    "quantiles",
190                    "each quantile must be in (0, 1)",
191                    tau,
192                )));
193            }
194            // Check for duplicates
195            if i > 0 && crate::math::abs(sorted[i] - sorted[i - 1]) < 1e-15 {
196                return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
197                    "quantiles",
198                    "duplicate quantile levels are not allowed",
199                    tau,
200                )));
201            }
202        }
203
204        let n_quantiles = sorted.len();
205        let models = sorted
206            .iter()
207            .map(|&tau| SGBT::with_loss(config.clone(), QuantileLoss::new(tau)))
208            .collect();
209
210        Ok(Self {
211            models,
212            quantiles: sorted,
213            n_quantiles,
214            samples_seen: 0,
215        })
216    }
217
218    /// Train all quantile models on a single observation.
219    ///
220    /// Each model independently learns from the same sample, targeting
221    /// its respective quantile level.
222    pub fn train_one(&mut self, sample: &impl Observation) {
223        self.samples_seen += 1;
224        for model in &mut self.models {
225            model.train_one(sample);
226        }
227    }
228
229    /// Train on a batch of observations.
230    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
231        for sample in samples {
232            self.train_one(sample);
233        }
234    }
235
236    /// Predict quantile values with PAVA non-crossing enforcement.
237    ///
238    /// Returns a `Vec<f64>` of length `n_quantiles` where:
239    /// - `result[i]` is the predicted quantile at level `quantiles[i]`
240    /// - `result[0] <= result[1] <= ... <= result[K-1]` (guaranteed)
241    pub fn predict(&self, features: &[f64]) -> Vec<f64> {
242        let mut preds: Vec<f64> = self.models.iter().map(|m| m.predict(features)).collect();
243        enforce_monotonicity(&mut preds);
244        preds
245    }
246
247    /// Predict raw quantile values WITHOUT non-crossing enforcement.
248    ///
249    /// This may produce crossing predictions. Use [`predict`](Self::predict)
250    /// for the PAVA-enforced version.
251    pub fn predict_raw(&self, features: &[f64]) -> Vec<f64> {
252        self.models.iter().map(|m| m.predict(features)).collect()
253    }
254
255    /// Predict a symmetric prediction interval at coverage level `1 - alpha`.
256    ///
257    /// Returns `(lower, median, upper)` if the model was constructed with
258    /// quantile levels that include `alpha/2`, `0.5`, and `1 - alpha/2`.
259    ///
260    /// If the exact levels are not present, returns the closest available
261    /// quantile predictions with PAVA enforcement applied.
262    pub fn predict_interval(&self, features: &[f64]) -> (f64, f64, f64) {
263        let preds = self.predict(features);
264        let lower = preds[0];
265        let upper = preds[preds.len() - 1];
266        // Use the middle quantile as the point estimate
267        let mid_idx = preds.len() / 2;
268        let median = preds[mid_idx];
269        (lower, median, upper)
270    }
271
272    /// Batch prediction with PAVA enforcement.
273    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<Vec<f64>> {
274        feature_matrix.iter().map(|f| self.predict(f)).collect()
275    }
276
277    /// Number of quantile levels.
278    #[inline]
279    pub fn n_quantiles(&self) -> usize {
280        self.n_quantiles
281    }
282
283    /// The sorted quantile levels.
284    pub fn quantiles(&self) -> &[f64] {
285        &self.quantiles
286    }
287
288    /// Total samples seen.
289    #[inline]
290    pub fn n_samples_seen(&self) -> u64 {
291        self.samples_seen
292    }
293
294    /// Access the model for a specific quantile index.
295    ///
296    /// # Panics
297    ///
298    /// Panics if `idx >= n_quantiles`.
299    pub fn model(&self, idx: usize) -> &SGBT<QuantileLoss> {
300        &self.models[idx]
301    }
302
303    /// Access all quantile models.
304    pub fn models(&self) -> &[SGBT<QuantileLoss>] {
305        &self.models
306    }
307
308    /// Reset all quantile models.
309    pub fn reset(&mut self) {
310        for model in &mut self.models {
311            model.reset();
312        }
313        self.samples_seen = 0;
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Small, fast SGBT configuration shared by the model-level tests below.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap()
    }

    // -----------------------------------------------------------------------
    // PAVA unit tests
    // -----------------------------------------------------------------------

    #[test]
    fn pava_already_sorted() {
        // A non-decreasing input must pass through unchanged.
        let mut values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        enforce_monotonicity(&mut values);
        assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn pava_single_element() {
        // Trivial case: nothing to enforce.
        let mut values = vec![42.0];
        enforce_monotonicity(&mut values);
        assert_eq!(values, vec![42.0]);
    }

    #[test]
    fn pava_empty() {
        // Empty slice is a no-op (exercises the early-return path).
        let mut values: Vec<f64> = vec![];
        enforce_monotonicity(&mut values);
        assert!(values.is_empty());
    }

    #[test]
    fn pava_simple_violation() {
        // [3, 1, 2] -> 3 > 1 violation
        // Pool first two: mean(3,1) = 2.0 -> [2, 2, 2]
        // Then 2 <= 2 OK
        let mut values = vec![3.0, 1.0, 2.0];
        enforce_monotonicity(&mut values);
        // After pooling 3,1 -> 2.0, then 2.0 <= 2.0, so all become 2.0
        assert!((values[0] - 2.0).abs() < 1e-10);
        assert!((values[1] - 2.0).abs() < 1e-10);
        assert!((values[2] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn pava_reversed() {
        // Fully reversed: [5, 4, 3, 2, 1] -> all become mean = 3.0
        let mut values = vec![5.0, 4.0, 3.0, 2.0, 1.0];
        enforce_monotonicity(&mut values);
        let mean = 3.0;
        for v in &values {
            assert!((v - mean).abs() < 1e-10, "expected {mean}, got {v}");
        }
    }

    #[test]
    fn pava_partial_violation() {
        // [1, 5, 3, 4, 6] -- violation at 5 > 3
        // Pool 5,3 -> 4.0: [1, 4, 4, 4, 6]
        let mut values = vec![1.0, 5.0, 3.0, 4.0, 6.0];
        enforce_monotonicity(&mut values);
        // Result should be non-decreasing
        for i in 1..values.len() {
            assert!(
                values[i] >= values[i - 1] - 1e-10,
                "violation at index {i}: {} < {}",
                values[i],
                values[i - 1]
            );
        }
        // First should still be 1.0
        assert!((values[0] - 1.0).abs() < 1e-10);
        // Last should still be 6.0
        assert!((values[4] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn pava_equal_values() {
        // Ties are allowed: equal values already satisfy non-decreasing.
        let mut values = vec![3.0, 3.0, 3.0];
        enforce_monotonicity(&mut values);
        assert_eq!(values, vec![3.0, 3.0, 3.0]);
    }

    #[test]
    fn pava_two_elements_violation() {
        // Smallest violating case: pool to the pair's mean.
        let mut values = vec![5.0, 1.0];
        enforce_monotonicity(&mut values);
        assert!((values[0] - 3.0).abs() < 1e-10);
        assert!((values[1] - 3.0).abs() < 1e-10);
    }

    // -----------------------------------------------------------------------
    // QuantileRegressorSGBT tests
    // -----------------------------------------------------------------------

    #[test]
    fn creates_correct_number_of_models() {
        let model = QuantileRegressorSGBT::new(test_config(), &[0.1, 0.5, 0.9]).unwrap();
        assert_eq!(model.n_quantiles(), 3);
        assert_eq!(model.models().len(), 3);
        assert_eq!(model.n_samples_seen(), 0);
    }

    #[test]
    fn quantiles_are_sorted() {
        // Constructor sorts unordered input levels.
        let model = QuantileRegressorSGBT::new(test_config(), &[0.9, 0.1, 0.5]).unwrap();
        assert_eq!(model.quantiles(), &[0.1, 0.5, 0.9]);
    }

    #[test]
    fn rejects_empty_quantiles() {
        let result = QuantileRegressorSGBT::new(test_config(), &[]);
        assert!(result.is_err());
    }

    #[test]
    fn rejects_invalid_quantile_zero() {
        // Levels are an OPEN interval: 0.0 itself is invalid.
        let result = QuantileRegressorSGBT::new(test_config(), &[0.0, 0.5]);
        assert!(result.is_err());
    }

    #[test]
    fn rejects_invalid_quantile_one() {
        // 1.0 is likewise excluded.
        let result = QuantileRegressorSGBT::new(test_config(), &[0.5, 1.0]);
        assert!(result.is_err());
    }

    #[test]
    fn rejects_duplicate_quantiles() {
        let result = QuantileRegressorSGBT::new(test_config(), &[0.5, 0.5, 0.9]);
        assert!(result.is_err());
    }

    #[test]
    fn single_quantile_works() {
        // K = 1 is a degenerate but valid configuration.
        let mut model = QuantileRegressorSGBT::new(test_config(), &[0.5]).unwrap();
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * 2.0));
        }
        let preds = model.predict(&[0.5]);
        assert_eq!(preds.len(), 1);
        assert!(preds[0].is_finite());
    }

    #[test]
    fn predictions_are_non_crossing() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = QuantileRegressorSGBT::new(config, &[0.05, 0.25, 0.5, 0.75, 0.95]).unwrap();

        // Train on noisy linear data (deterministic LCG so the test is stable).
        let mut rng: u64 = 42;
        for _ in 0..200 {
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let x = (rng >> 33) as f64 / (u32::MAX as f64) * 10.0;
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let noise = ((rng >> 33) as f64 / (u32::MAX as f64) - 0.5) * 2.0;
            let y = 3.0 * x + noise;
            model.train_one(&Sample::new(vec![x], y));
        }

        // Test at multiple points -- predictions must be non-decreasing
        let test_points = [0.0, 1.0, 3.0, 5.0, 8.0, 10.0];
        for &x in &test_points {
            let preds = model.predict(&[x]);
            for i in 1..preds.len() {
                assert!(
                    preds[i] >= preds[i - 1] - 1e-10,
                    "crossing at x={x}: q[{i}]={} < q[{}]={}",
                    preds[i],
                    i - 1,
                    preds[i - 1]
                );
            }
        }
    }

    #[test]
    fn raw_predict_may_cross() {
        // Raw predictions don't have PAVA -- they may cross
        // (Just verify the method works, crossings aren't guaranteed)
        let mut model = QuantileRegressorSGBT::new(test_config(), &[0.1, 0.5, 0.9]).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x));
        }

        let raw = model.predict_raw(&[0.5]);
        assert_eq!(raw.len(), 3);
        for v in &raw {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn predict_interval_returns_triple() {
        let mut model = QuantileRegressorSGBT::new(test_config(), &[0.05, 0.5, 0.95]).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * 2.0 + 1.0));
        }

        // PAVA guarantees the ordering lower <= median <= upper.
        let (lower, median, upper) = model.predict_interval(&[0.5]);
        assert!(lower <= median, "lower={lower} > median={median}");
        assert!(median <= upper, "median={median} > upper={upper}");
    }

    #[test]
    fn batch_prediction() {
        let mut model = QuantileRegressorSGBT::new(test_config(), &[0.1, 0.5, 0.9]).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x));
        }

        // One inner Vec per input row, one value per quantile level.
        let features = vec![vec![0.5], vec![1.0], vec![2.0]];
        let batch = model.predict_batch(&features);
        assert_eq!(batch.len(), 3);
        for preds in &batch {
            assert_eq!(preds.len(), 3);
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut model = QuantileRegressorSGBT::new(test_config(), &[0.1, 0.5, 0.9]).unwrap();

        for i in 0..100 {
            let x = i as f64;
            model.train_one(&Sample::new(vec![x], x));
        }
        assert!(model.n_samples_seen() > 0);

        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
    }

    #[test]
    fn deterministic_with_same_config() {
        // Two models with identical config and data must agree exactly.
        let config = test_config();
        let quantiles = [0.1, 0.5, 0.9];
        let mut model1 = QuantileRegressorSGBT::new(config.clone(), &quantiles).unwrap();
        let mut model2 = QuantileRegressorSGBT::new(config, &quantiles).unwrap();

        let samples: Vec<Sample> = (0..50)
            .map(|i| {
                let x = i as f64 * 0.1;
                Sample::new(vec![x], x * 3.0)
            })
            .collect();

        for s in &samples {
            model1.train_one(s);
            model2.train_one(s);
        }

        let pred1 = model1.predict(&[0.5]);
        let pred2 = model2.predict(&[0.5]);
        for (a, b) in pred1.iter().zip(pred2.iter()) {
            assert!(
                (a - b).abs() < 1e-10,
                "same config should give identical predictions: {a} vs {b}"
            );
        }
    }

    #[test]
    fn higher_quantile_predicts_higher_after_training() {
        let config = SGBTConfig::builder()
            .n_steps(20)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = QuantileRegressorSGBT::new(config, &[0.1, 0.5, 0.9]).unwrap();

        // Train on data with spread (deterministic LCG, noise amplitude 4).
        let mut rng: u64 = 99;
        for _ in 0..500 {
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let x = (rng >> 33) as f64 / (u32::MAX as f64) * 10.0;
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let noise = ((rng >> 33) as f64 / (u32::MAX as f64) - 0.5) * 4.0;
            model.train_one(&Sample::new(vec![x], x + noise));
        }

        let preds = model.predict(&[5.0]);
        // With enough training, higher quantiles should predict higher
        assert!(
            preds[2] > preds[0],
            "90th percentile ({}) should be > 10th percentile ({})",
            preds[2],
            preds[0]
        );
    }
}