1use alloc::vec::Vec;
37
38use crate::ensemble::config::SGBTConfig;
39use crate::ensemble::SGBT;
40use crate::error::{ConfigError, IrithyllError};
41use crate::loss::quantile::QuantileLoss;
42use crate::sample::Observation;
43
/// Makes `values` non-decreasing in place using the pool-adjacent-violators
/// algorithm (PAVA): any run of out-of-order neighbours is replaced by the
/// mean of the pooled run, which is the L2-optimal isotonic fit.
fn enforce_monotonicity(values: &mut [f64]) {
    if values.len() <= 1 {
        return;
    }

    // Each pooled block is (sum of members, member count, start index).
    let mut blocks: Vec<(f64, usize, usize)> = Vec::with_capacity(values.len());

    for (idx, &v) in values.iter().enumerate() {
        blocks.push((v, 1, idx));

        // Merge backwards while the previous block's mean exceeds the last's.
        loop {
            let k = blocks.len();
            if k < 2 {
                break;
            }
            let (last_sum, last_cnt, _) = blocks[k - 1];
            let (prev_sum, prev_cnt, _) = blocks[k - 2];
            if prev_sum / prev_cnt as f64 <= last_sum / last_cnt as f64 {
                break;
            }
            // Violation: pool the last block into its predecessor.
            blocks.pop();
            blocks[k - 2].0 = prev_sum + last_sum;
            blocks[k - 2].1 = prev_cnt + last_cnt;
        }
    }

    // Write each block's mean over the positions it covers; a block ends
    // where the next one starts (or at the end of the slice).
    for (b, &(sum, cnt, start)) in blocks.iter().enumerate() {
        let end = blocks.get(b + 1).map_or(values.len(), |&(_, _, s)| s);
        let mean = sum / cnt as f64;
        for slot in &mut values[start..end] {
            *slot = mean;
        }
    }
}
112
/// Multi-quantile regressor: one online SGBT model per quantile level, with
/// predictions post-processed (PAVA) so quantiles never cross.
pub struct QuantileRegressorSGBT {
    // One fitted model per quantile level, in ascending quantile order.
    models: Vec<SGBT<QuantileLoss>>,
    // Quantile levels, sorted ascending; each validated to lie in (0, 1).
    quantiles: Vec<f64>,
    // Cached `quantiles.len()`.
    n_quantiles: usize,
    // Observations passed to `train_one` since construction or `reset`.
    samples_seen: u64,
}
139
140impl Clone for QuantileRegressorSGBT {
141 fn clone(&self) -> Self {
142 Self {
143 models: self.models.clone(),
144 quantiles: self.quantiles.clone(),
145 n_quantiles: self.n_quantiles,
146 samples_seen: self.samples_seen,
147 }
148 }
149}
150
151impl core::fmt::Debug for QuantileRegressorSGBT {
152 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
153 f.debug_struct("QuantileRegressorSGBT")
154 .field("quantiles", &self.quantiles)
155 .field("n_quantiles", &self.n_quantiles)
156 .field("samples_seen", &self.samples_seen)
157 .finish()
158 }
159}
160
161impl QuantileRegressorSGBT {
162 pub fn new(config: SGBTConfig, quantiles: &[f64]) -> crate::error::Result<Self> {
174 if quantiles.is_empty() {
175 return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
176 "quantiles",
177 "must have at least one quantile level",
178 0usize,
179 )));
180 }
181
182 let mut sorted: Vec<f64> = quantiles.to_vec();
184 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
185
186 for (i, &tau) in sorted.iter().enumerate() {
187 if tau <= 0.0 || tau >= 1.0 {
188 return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
189 "quantiles",
190 "each quantile must be in (0, 1)",
191 tau,
192 )));
193 }
194 if i > 0 && crate::math::abs(sorted[i] - sorted[i - 1]) < 1e-15 {
196 return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
197 "quantiles",
198 "duplicate quantile levels are not allowed",
199 tau,
200 )));
201 }
202 }
203
204 let n_quantiles = sorted.len();
205 let models = sorted
206 .iter()
207 .map(|&tau| SGBT::with_loss(config.clone(), QuantileLoss::new(tau)))
208 .collect();
209
210 Ok(Self {
211 models,
212 quantiles: sorted,
213 n_quantiles,
214 samples_seen: 0,
215 })
216 }
217
218 pub fn train_one(&mut self, sample: &impl Observation) {
223 self.samples_seen += 1;
224 for model in &mut self.models {
225 model.train_one(sample);
226 }
227 }
228
229 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
231 for sample in samples {
232 self.train_one(sample);
233 }
234 }
235
236 pub fn predict(&self, features: &[f64]) -> Vec<f64> {
242 let mut preds: Vec<f64> = self.models.iter().map(|m| m.predict(features)).collect();
243 enforce_monotonicity(&mut preds);
244 preds
245 }
246
247 pub fn predict_raw(&self, features: &[f64]) -> Vec<f64> {
252 self.models.iter().map(|m| m.predict(features)).collect()
253 }
254
255 pub fn predict_interval(&self, features: &[f64]) -> (f64, f64, f64) {
263 let preds = self.predict(features);
264 let lower = preds[0];
265 let upper = preds[preds.len() - 1];
266 let mid_idx = preds.len() / 2;
268 let median = preds[mid_idx];
269 (lower, median, upper)
270 }
271
272 pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<Vec<f64>> {
274 feature_matrix.iter().map(|f| self.predict(f)).collect()
275 }
276
277 #[inline]
279 pub fn n_quantiles(&self) -> usize {
280 self.n_quantiles
281 }
282
283 pub fn quantiles(&self) -> &[f64] {
285 &self.quantiles
286 }
287
288 #[inline]
290 pub fn n_samples_seen(&self) -> u64 {
291 self.samples_seen
292 }
293
294 pub fn model(&self, idx: usize) -> &SGBT<QuantileLoss> {
300 &self.models[idx]
301 }
302
303 pub fn models(&self) -> &[SGBT<QuantileLoss>] {
305 &self.models
306 }
307
308 pub fn reset(&mut self) {
310 for model in &mut self.models {
311 model.reset();
312 }
313 self.samples_seen = 0;
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;
    use alloc::vec;
    use alloc::vec::Vec;

    // Small, shared configuration so individual tests stay fast and deterministic.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap()
    }

    // ---- PAVA (enforce_monotonicity) unit tests ----

    // An already non-decreasing input must be left untouched.
    #[test]
    fn pava_already_sorted() {
        let mut values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        enforce_monotonicity(&mut values);
        assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    // A single element is trivially monotone.
    #[test]
    fn pava_single_element() {
        let mut values = vec![42.0];
        enforce_monotonicity(&mut values);
        assert_eq!(values, vec![42.0]);
    }

    // Empty input is a no-op, not a panic.
    #[test]
    fn pava_empty() {
        let mut values: Vec<f64> = vec![];
        enforce_monotonicity(&mut values);
        assert!(values.is_empty());
    }

    // [3, 1, 2] pools into a single block with mean 2.
    #[test]
    fn pava_simple_violation() {
        let mut values = vec![3.0, 1.0, 2.0];
        enforce_monotonicity(&mut values);
        assert!((values[0] - 2.0).abs() < 1e-10);
        assert!((values[1] - 2.0).abs() < 1e-10);
        assert!((values[2] - 2.0).abs() < 1e-10);
    }

    // A strictly decreasing input collapses to one block at the overall mean.
    #[test]
    fn pava_reversed() {
        let mut values = vec![5.0, 4.0, 3.0, 2.0, 1.0];
        enforce_monotonicity(&mut values);
        let mean = 3.0;
        for v in &values {
            assert!((v - mean).abs() < 1e-10, "expected {mean}, got {v}");
        }
    }

    // Only the violating middle run is pooled; the untouched endpoints keep
    // their original values.
    #[test]
    fn pava_partial_violation() {
        let mut values = vec![1.0, 5.0, 3.0, 4.0, 6.0];
        enforce_monotonicity(&mut values);
        for i in 1..values.len() {
            assert!(
                values[i] >= values[i - 1] - 1e-10,
                "violation at index {i}: {} < {}",
                values[i],
                values[i - 1]
            );
        }
        assert!((values[0] - 1.0).abs() < 1e-10);
        assert!((values[4] - 6.0).abs() < 1e-10);
    }

    // Ties are monotone (non-decreasing) and must not be pooled or altered.
    #[test]
    fn pava_equal_values() {
        let mut values = vec![3.0, 3.0, 3.0];
        enforce_monotonicity(&mut values);
        assert_eq!(values, vec![3.0, 3.0, 3.0]);
    }

    // Smallest possible violation: two elements averaged.
    #[test]
    fn pava_two_elements_violation() {
        let mut values = vec![5.0, 1.0];
        enforce_monotonicity(&mut values);
        assert!((values[0] - 3.0).abs() < 1e-10);
        assert!((values[1] - 3.0).abs() < 1e-10);
    }

    // ---- QuantileRegressorSGBT construction and validation ----

    // One underlying SGBT model per requested quantile level.
    #[test]
    fn creates_correct_number_of_models() {
        let model = QuantileRegressorSGBT::new(test_config(), &[0.1, 0.5, 0.9]).unwrap();
        assert_eq!(model.n_quantiles(), 3);
        assert_eq!(model.models().len(), 3);
        assert_eq!(model.n_samples_seen(), 0);
    }

    // Levels are sorted ascending regardless of input order.
    #[test]
    fn quantiles_are_sorted() {
        let model = QuantileRegressorSGBT::new(test_config(), &[0.9, 0.1, 0.5]).unwrap();
        assert_eq!(model.quantiles(), &[0.1, 0.5, 0.9]);
    }

    #[test]
    fn rejects_empty_quantiles() {
        let result = QuantileRegressorSGBT::new(test_config(), &[]);
        assert!(result.is_err());
    }

    // Quantile levels must lie strictly inside (0, 1): 0.0 is invalid.
    #[test]
    fn rejects_invalid_quantile_zero() {
        let result = QuantileRegressorSGBT::new(test_config(), &[0.0, 0.5]);
        assert!(result.is_err());
    }

    // ... and 1.0 is invalid too.
    #[test]
    fn rejects_invalid_quantile_one() {
        let result = QuantileRegressorSGBT::new(test_config(), &[0.5, 1.0]);
        assert!(result.is_err());
    }

    #[test]
    fn rejects_duplicate_quantiles() {
        let result = QuantileRegressorSGBT::new(test_config(), &[0.5, 0.5, 0.9]);
        assert!(result.is_err());
    }

    // ---- Training and prediction behavior ----

    // Degenerate case: a single quantile still trains and predicts.
    #[test]
    fn single_quantile_works() {
        let mut model = QuantileRegressorSGBT::new(test_config(), &[0.5]).unwrap();
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * 2.0));
        }
        let preds = model.predict(&[0.5]);
        assert_eq!(preds.len(), 1);
        assert!(preds[0].is_finite());
    }

    // After PAVA post-processing, predicted quantiles never cross at any
    // test point. Uses a hand-rolled 64-bit LCG so the test is fully
    // deterministic without an RNG dependency.
    #[test]
    fn predictions_are_non_crossing() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = QuantileRegressorSGBT::new(config, &[0.05, 0.25, 0.5, 0.75, 0.95]).unwrap();

        let mut rng: u64 = 42;
        for _ in 0..200 {
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let x = (rng >> 33) as f64 / (u32::MAX as f64) * 10.0;
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let noise = ((rng >> 33) as f64 / (u32::MAX as f64) - 0.5) * 2.0;
            let y = 3.0 * x + noise;
            model.train_one(&Sample::new(vec![x], y));
        }

        let test_points = [0.0, 1.0, 3.0, 5.0, 8.0, 10.0];
        for &x in &test_points {
            let preds = model.predict(&[x]);
            for i in 1..preds.len() {
                assert!(
                    preds[i] >= preds[i - 1] - 1e-10,
                    "crossing at x={x}: q[{i}]={} < q[{}]={}",
                    preds[i],
                    i - 1,
                    preds[i - 1]
                );
            }
        }
    }

    // predict_raw skips the monotonicity fix; we only check shape/finiteness
    // since raw outputs are allowed to cross.
    #[test]
    fn raw_predict_may_cross() {
        let mut model = QuantileRegressorSGBT::new(test_config(), &[0.1, 0.5, 0.9]).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x));
        }

        let raw = model.predict_raw(&[0.5]);
        assert_eq!(raw.len(), 3);
        for v in &raw {
            assert!(v.is_finite());
        }
    }

    // The (lower, median, upper) triple must be ordered after PAVA.
    #[test]
    fn predict_interval_returns_triple() {
        let mut model = QuantileRegressorSGBT::new(test_config(), &[0.05, 0.5, 0.95]).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * 2.0 + 1.0));
        }

        let (lower, median, upper) = model.predict_interval(&[0.5]);
        assert!(lower <= median, "lower={lower} > median={median}");
        assert!(median <= upper, "median={median} > upper={upper}");
    }

    // Batch prediction returns one quantile vector per feature row.
    #[test]
    fn batch_prediction() {
        let mut model = QuantileRegressorSGBT::new(test_config(), &[0.1, 0.5, 0.9]).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x));
        }

        let features = vec![vec![0.5], vec![1.0], vec![2.0]];
        let batch = model.predict_batch(&features);
        assert_eq!(batch.len(), 3);
        for preds in &batch {
            assert_eq!(preds.len(), 3);
        }
    }

    // reset() clears the sample counter (and, per its contract, the models).
    #[test]
    fn reset_clears_state() {
        let mut model = QuantileRegressorSGBT::new(test_config(), &[0.1, 0.5, 0.9]).unwrap();

        for i in 0..100 {
            let x = i as f64;
            model.train_one(&Sample::new(vec![x], x));
        }
        assert!(model.n_samples_seen() > 0);

        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
    }

    // Two models built from the same config and fed the same stream must
    // agree to within floating-point tolerance.
    #[test]
    fn deterministic_with_same_config() {
        let config = test_config();
        let quantiles = [0.1, 0.5, 0.9];
        let mut model1 = QuantileRegressorSGBT::new(config.clone(), &quantiles).unwrap();
        let mut model2 = QuantileRegressorSGBT::new(config, &quantiles).unwrap();

        let samples: Vec<Sample> = (0..50)
            .map(|i| {
                let x = i as f64 * 0.1;
                Sample::new(vec![x], x * 3.0)
            })
            .collect();

        for s in &samples {
            model1.train_one(s);
            model2.train_one(s);
        }

        let pred1 = model1.predict(&[0.5]);
        let pred2 = model2.predict(&[0.5]);
        for (a, b) in pred1.iter().zip(pred2.iter()) {
            assert!(
                (a - b).abs() < 1e-10,
                "same config should give identical predictions: {a} vs {b}"
            );
        }
    }

    // Sanity check on noisy linear data: with enough samples the fitted
    // 90th-percentile prediction should exceed the 10th-percentile one.
    #[test]
    fn higher_quantile_predicts_higher_after_training() {
        let config = SGBTConfig::builder()
            .n_steps(20)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = QuantileRegressorSGBT::new(config, &[0.1, 0.5, 0.9]).unwrap();

        let mut rng: u64 = 99;
        for _ in 0..500 {
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let x = (rng >> 33) as f64 / (u32::MAX as f64) * 10.0;
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let noise = ((rng >> 33) as f64 / (u32::MAX as f64) - 0.5) * 4.0;
            model.train_one(&Sample::new(vec![x], x + noise));
        }

        let preds = model.predict(&[5.0]);
        assert!(
            preds[2] > preds[0],
            "90th percentile ({}) should be > 10th percentile ({})",
            preds[2],
            preds[0]
        );
    }
}