scirs2-stats 0.4.1

Statistical functions module for SciRS2 (scirs2-stats)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
//! Adaptive and locally-weighted conformal prediction.
//!
//! Implements several adaptive variants that provide tighter prediction sets
//! than the basic split conformal approach by exploiting local structure:
//!
//! * **Normalized conformal prediction** — divides the residual by a local
//!   difficulty estimate σ̂, yielding heteroscedasticity-aware intervals.
//! * **Conformal Quantile Regression (CQR)** — Romano, Sesia & Candès 2019 —
//!   uses pre-fitted quantile models to build asymmetric adaptive intervals.
//! * **RAPS** — Regularized Adaptive Prediction Sets (Angelopoulos et al. 2021)
//!   for classification: adds a soft-margin regularizer that shrinks large sets.
//! * **Mondrian conformal prediction** — computes per-bin quantiles for
//!   conditional (category-conditional) coverage guarantees.

use crate::conformal::types::{conformal_quantile, PredictionSet, RapsConfig};

// ---------------------------------------------------------------------------
// Utility: k-NN difficulty / quantile estimation
// ---------------------------------------------------------------------------

/// Estimate local difficulty using the mean absolute residual of the `k`
/// nearest neighbours in the *training* set.
///
/// `x_train` — 1-D feature values (sorted or unsorted)
/// `residuals_train` — |y_i − ŷ_i| for each training point
/// `x_query` — query point
/// `k` — neighbourhood size
fn knn_difficulty(x_train: &[f64], residuals_train: &[f64], x_query: f64, k: usize) -> f64 {
    if x_train.is_empty() || k == 0 {
        return 1.0;
    }
    // Compute distances and find k nearest
    let mut dists: Vec<(f64, f64)> = x_train
        .iter()
        .zip(residuals_train.iter())
        .map(|(&xi, &ri)| ((xi - x_query).abs(), ri))
        .collect();
    dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    let k_eff = k.min(dists.len());
    let mean_residual: f64 = dists[..k_eff].iter().map(|(_, r)| r).sum::<f64>() / k_eff as f64;
    mean_residual.max(1e-8) // avoid division by zero
}

/// k-NN quantile estimate: returns the empirical `level`-quantile of the
/// residuals of the `k` nearest neighbours.
fn knn_quantile(
    x_train: &[f64],
    residuals_train: &[f64],
    x_query: f64,
    k: usize,
    level: f64,
) -> f64 {
    if x_train.is_empty() || k == 0 {
        return 0.0;
    }
    let mut dists: Vec<(f64, f64)> = x_train
        .iter()
        .zip(residuals_train.iter())
        .map(|(&xi, &ri)| ((xi - x_query).abs(), ri))
        .collect();
    dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    let k_eff = k.min(dists.len());
    let mut vals: Vec<f64> = dists[..k_eff].iter().map(|(_, r)| *r).collect();
    vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let idx = ((level * k_eff as f64).ceil() as usize)
        .saturating_sub(1)
        .min(k_eff - 1);
    vals[idx]
}

// ---------------------------------------------------------------------------
// Normalized conformal prediction
// ---------------------------------------------------------------------------

/// Normalized (locally-weighted) conformal predictor for regression.
///
/// The nonconformity score is
/// `s_i = |y_i − ŷ_i| / σ̂_i`
/// where `σ̂_i` is a local difficulty estimate derived from k-NN residuals on
/// the training set.
///
/// This yields *adaptive* intervals: narrow in easy regions, wide where the
/// model is uncertain.
#[derive(Debug, Clone, Default)]
pub struct NormalizedConformal {
    /// Calibration nonconformity scores s_i = |y_i - ŷ_i| / σ̂_i.
    pub calibration_scores: Vec<f64>,
    /// 1-D feature values for calibration points (used for k-NN difficulty).
    pub x_cal: Vec<f64>,
    /// Raw residuals for calibration points.
    pub residuals_cal: Vec<f64>,
    /// Neighbourhood size for k-NN difficulty.
    pub k_neighbors: usize,
}

impl NormalizedConformal {
    /// Create a new predictor with the given `k` for k-NN difficulty.
    pub fn new(k_neighbors: usize) -> Self {
        Self {
            calibration_scores: Vec::new(),
            x_cal: Vec::new(),
            residuals_cal: Vec::new(),
            k_neighbors,
        }
    }

    /// Calibrate using per-point difficulty estimates.
    ///
    /// * `x_cal` — 1-D features of calibration points
    /// * `y_cal` — true labels
    /// * `predictions` — model predictions ŷ_i
    /// * `difficulties` — pre-computed σ̂_i > 0 (if empty, k-NN is used)
    pub fn calibrate(
        &mut self,
        x_cal: &[f64],
        y_cal: &[f64],
        predictions: &[f64],
        difficulties: &[f64],
    ) {
        let raw_residuals: Vec<f64> = y_cal
            .iter()
            .zip(predictions.iter())
            .map(|(y, yhat)| (y - yhat).abs())
            .collect();

        let effective_difficulties: Vec<f64> = if difficulties.is_empty() {
            // Use k-NN average residual as difficulty
            x_cal
                .iter()
                .enumerate()
                .map(|(i, &xi)| {
                    // Leave-one-out k-NN difficulty on calibration set
                    let leave_one_out_x: Vec<f64> = x_cal
                        .iter()
                        .enumerate()
                        .filter(|(j, _)| *j != i)
                        .map(|(_, &v)| v)
                        .collect();
                    let leave_one_out_r: Vec<f64> = raw_residuals
                        .iter()
                        .enumerate()
                        .filter(|(j, _)| *j != i)
                        .map(|(_, &v)| v)
                        .collect();
                    knn_difficulty(&leave_one_out_x, &leave_one_out_r, xi, self.k_neighbors)
                })
                .collect()
        } else {
            difficulties.to_vec()
        };

        self.calibration_scores = raw_residuals
            .iter()
            .zip(effective_difficulties.iter())
            .map(|(r, d)| r / d.max(1e-8))
            .collect();

        self.x_cal = x_cal.to_vec();
        self.residuals_cal = raw_residuals;
    }

    /// Predict an adaptive interval for a new point.
    ///
    /// * `x` — 1-D feature value of the test point (used for k-NN difficulty)
    /// * `y_hat` — point prediction
    /// * `difficulty` — optional pre-computed σ̂.  If `None`, k-NN is used.
    /// * `alpha` — significance level
    pub fn predict(
        &self,
        x: f64,
        y_hat: f64,
        difficulty: Option<f64>,
        alpha: f64,
    ) -> Option<PredictionSet> {
        if self.calibration_scores.is_empty() {
            return None;
        }
        let q = conformal_quantile(&self.calibration_scores, alpha);
        let sigma = match difficulty {
            Some(d) => d.max(1e-8),
            None => knn_difficulty(&self.x_cal, &self.residuals_cal, x, self.k_neighbors),
        };
        let half_width = q * sigma;
        Some(PredictionSet::interval(
            y_hat - half_width,
            y_hat + half_width,
        ))
    }
}

// ---------------------------------------------------------------------------
// Conformal Quantile Regression (CQR)
// ---------------------------------------------------------------------------

/// Conformal Quantile Regression predictor (Romano, Sesia & Candès 2019).
///
/// Uses pre-fitted lower and upper quantile models (approximated here via
/// k-NN quantile estimation) to build asymmetric adaptive intervals.
///
/// Score: `s_i = max(q̂_lo(x_i) − y_i , y_i − q̂_hi(x_i))`
/// Interval: `[q̂_lo(x) − Q̂ , q̂_hi(x) + Q̂]`
#[derive(Debug, Clone, Default)]
pub struct CqrConformal {
    /// Calibration nonconformity scores.
    pub calibration_scores: Vec<f64>,
    /// Training 1-D features (for k-NN quantile).
    pub x_train: Vec<f64>,
    /// Training residuals below zero (y − ŷ < 0) for lower quantile.
    pub lo_residuals: Vec<f64>,
    /// Training residuals above zero (y − ŷ > 0) for upper quantile.
    pub hi_residuals: Vec<f64>,
    /// Alpha level used for quantile regression (stored for prediction).
    pub alpha_qr: f64,
    /// Neighbourhood size.
    pub k_neighbors: usize,
}

impl CqrConformal {
    /// Create a new CQR predictor.
    ///
    /// * `alpha_qr` — significance level for the quantile models (e.g. 0.1
    ///   for 90% prediction interval before conformalization).
    /// * `k_neighbors` — neighbourhood size for k-NN quantile approximation.
    pub fn new(alpha_qr: f64, k_neighbors: usize) -> Self {
        Self {
            calibration_scores: Vec::new(),
            x_train: Vec::new(),
            lo_residuals: Vec::new(),
            hi_residuals: Vec::new(),
            alpha_qr,
            k_neighbors,
        }
    }

    /// Calibrate the CQR predictor.
    ///
    /// * `x_train` — 1-D training features (used to build k-NN quantile models)
    /// * `y_train` — training labels
    /// * `x_cal` — calibration features
    /// * `y_cal` — calibration labels
    pub fn calibrate(&mut self, x_train: &[f64], y_train: &[f64], x_cal: &[f64], y_cal: &[f64]) {
        // Store residuals partitioned by sign for k-NN quantile models
        let residuals: Vec<f64> = y_train
            .iter()
            .zip(x_train.iter())
            .map(|(y, _x)| *y)
            .collect();
        self.x_train = x_train.to_vec();
        self.lo_residuals = residuals.clone();
        self.hi_residuals = residuals;

        // Compute CQR scores on calibration set
        self.calibration_scores = x_cal
            .iter()
            .zip(y_cal.iter())
            .map(|(&xi, &yi)| {
                let q_lo =
                    knn_quantile(x_train, y_train, xi, self.k_neighbors, self.alpha_qr / 2.0);
                let q_hi = knn_quantile(
                    x_train,
                    y_train,
                    xi,
                    self.k_neighbors,
                    1.0 - self.alpha_qr / 2.0,
                );
                let lo = q_lo - yi;
                let hi = yi - q_hi;
                lo.max(hi)
            })
            .collect();
    }

    /// Predict a CQR interval for a test point.
    ///
    /// Returns `None` if uncalibrated.
    pub fn predict(&self, x: f64, alpha: f64) -> Option<PredictionSet> {
        if self.calibration_scores.is_empty() {
            return None;
        }
        let q_hat = conformal_quantile(&self.calibration_scores, alpha);
        let q_lo = knn_quantile(
            &self.x_train,
            &self.lo_residuals,
            x,
            self.k_neighbors,
            self.alpha_qr / 2.0,
        );
        let q_hi = knn_quantile(
            &self.x_train,
            &self.hi_residuals,
            x,
            self.k_neighbors,
            1.0 - self.alpha_qr / 2.0,
        );
        Some(PredictionSet::interval(q_lo - q_hat, q_hi + q_hat))
    }
}

// ---------------------------------------------------------------------------
// RAPS — Regularized Adaptive Prediction Sets
// ---------------------------------------------------------------------------

/// Regularized Adaptive Prediction Sets for classification
/// (Angelopoulos, Bates, Jordan & Malik 2021).
///
/// The nonconformity score is:
/// `s(x, y) = Σ_{k=1}^{o(y)} π_k(x) + λ · max(o(y) − k_reg, 0)`
/// where `o(y)` is the rank of the true class in descending probability order
/// and `π_k` is the k-th largest probability.
///
/// This penalises predictions that require many classes before reaching the
/// true class, making the score *adaptive* to difficulty.
#[derive(Debug, Clone, Default)]
pub struct RapsConformal {
    /// Calibration nonconformity scores.
    pub calibration_scores: Vec<f64>,
    /// RAPS regularisation configuration.
    pub config: RapsConfig,
    /// Number of classes.
    pub num_classes: usize,
}

impl RapsConformal {
    /// Create a new RAPS predictor.
    pub fn new(num_classes: usize, config: RapsConfig) -> Self {
        Self {
            calibration_scores: Vec::new(),
            config,
            num_classes,
        }
    }

    /// Compute the RAPS nonconformity score for a single calibration example.
    ///
    /// * `probs` — softmax probability vector (length `num_classes`).
    /// * `true_label` — true class index.
    fn raps_score(&self, probs: &[f64], true_label: usize) -> f64 {
        if probs.is_empty() || true_label >= probs.len() {
            return f64::INFINITY;
        }
        // Sort class indices by descending probability → π_1 ≥ π_2 ≥ …
        let mut order: Vec<usize> = (0..probs.len()).collect();
        order.sort_by(|&a, &b| {
            probs[b]
                .partial_cmp(&probs[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Find rank o(y): 1-indexed position of true_label in the sorted order
        let rank = order
            .iter()
            .position(|&k| k == true_label)
            .map(|p| p + 1) // convert to 1-indexed
            .unwrap_or(probs.len());

        // Cumulative sum of sorted probabilities up to and including rank
        let cumsum: f64 = order[..rank].iter().map(|&k| probs[k]).sum();

        // Regularisation term
        let reg = self.config.lambda * (rank as f64 - self.config.k_reg as f64).max(0.0);

        cumsum + reg
    }

    /// Calibrate using per-example softmax probabilities and true labels.
    pub fn calibrate(&mut self, probs_cal: &[Vec<f64>], labels_cal: &[usize]) {
        self.calibration_scores = probs_cal
            .iter()
            .zip(labels_cal.iter())
            .map(|(probs, &y)| self.raps_score(probs, y))
            .collect();
    }

    /// Predict a set for a test point.
    ///
    /// Includes all classes whose cumulative regularised score (from most to
    /// least probable) does not exceed the calibration quantile `Q̂`.
    ///
    /// Returns `None` if uncalibrated.
    pub fn predict_set(&self, probs: &[f64], alpha: f64) -> Option<PredictionSet> {
        if self.calibration_scores.is_empty() || probs.is_empty() {
            return None;
        }
        let q_hat = conformal_quantile(&self.calibration_scores, alpha);

        // Sort classes by descending probability
        let mut order: Vec<usize> = (0..probs.len()).collect();
        order.sort_by(|&a, &b| {
            probs[b]
                .partial_cmp(&probs[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let mut set: Vec<usize> = Vec::new();
        let mut cumsum = 0.0;
        for (rank_minus_1, &k) in order.iter().enumerate() {
            let rank = rank_minus_1 + 1;
            cumsum += probs[k];
            let reg = self.config.lambda * (rank as f64 - self.config.k_reg as f64).max(0.0);
            let score = cumsum + reg;
            if score <= q_hat {
                set.push(k);
            } else {
                // Include this class and stop to ensure coverage
                set.push(k);
                break;
            }
        }
        Some(PredictionSet::classification(set))
    }
}

// ---------------------------------------------------------------------------
// Mondrian conformal prediction
// ---------------------------------------------------------------------------

/// Mondrian (taxonomy-based) conformal prediction.
///
/// Partitions the calibration data into discrete bins and computes a separate
/// conformal quantile `Q̂_c` for each bin.  Test points are assigned to a bin
/// and their interval/set is computed using the bin-specific quantile.
///
/// This provides *conditional* coverage guarantees within each partition.
#[derive(Debug, Clone, Default)]
pub struct MondrianConformal {
    /// Per-bin calibration scores.  `bin_scores[c]` holds all scores for bin `c`.
    pub bin_scores: Vec<Vec<f64>>,
    /// Number of bins.
    pub bins: usize,
    /// Minimum and maximum feature values seen during calibration (for binning).
    pub x_min: f64,
    pub x_max: f64,
}

impl MondrianConformal {
    /// Create a new Mondrian predictor with `bins` equal-width partitions.
    pub fn new(bins: usize) -> Self {
        Self {
            bin_scores: vec![Vec::new(); bins.max(1)],
            bins: bins.max(1),
            x_min: 0.0,
            x_max: 1.0,
        }
    }

    /// Assign a 1-D feature value to a bin index.
    fn assign_bin(&self, x: f64) -> usize {
        if (self.x_max - self.x_min).abs() < 1e-12 {
            return 0;
        }
        let frac = (x - self.x_min) / (self.x_max - self.x_min);
        let idx = (frac * self.bins as f64).floor() as usize;
        idx.min(self.bins - 1)
    }

    /// Calibrate the Mondrian predictor.
    ///
    /// * `x_cal` — 1-D calibration features (used for bin assignment)
    /// * `predictions` — model predictions ŷ_i
    /// * `actuals` — ground-truth labels y_i
    pub fn calibrate_bins(&mut self, x_cal: &[f64], predictions: &[f64], actuals: &[f64]) {
        // Determine feature range
        self.x_min = x_cal.iter().cloned().fold(f64::INFINITY, f64::min);
        self.x_max = x_cal.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        if (self.x_max - self.x_min).abs() < 1e-12 {
            self.x_max = self.x_min + 1.0;
        }

        // Reset bin scores
        for v in self.bin_scores.iter_mut() {
            v.clear();
        }

        for ((&xi, &yhat), &y) in x_cal.iter().zip(predictions.iter()).zip(actuals.iter()) {
            let bin = self.assign_bin(xi);
            self.bin_scores[bin].push((y - yhat).abs());
        }
    }

    /// Predict an interval using the per-bin quantile.
    ///
    /// * `x` — feature value of the test point (for bin assignment)
    /// * `y_hat` — point prediction
    /// * `alpha` — significance level
    ///
    /// Falls back to the global quantile across all bins if the target bin is
    /// empty.
    pub fn predict(&self, x: f64, y_hat: f64, alpha: f64) -> PredictionSet {
        let bin = self.assign_bin(x);
        let scores = &self.bin_scores[bin];

        let q = if scores.is_empty() {
            // Fallback: pool all bins
            let all: Vec<f64> = self.bin_scores.iter().flatten().cloned().collect();
            conformal_quantile(&all, alpha)
        } else {
            conformal_quantile(scores, alpha)
        };

        PredictionSet::interval(y_hat - q, y_hat + q)
    }

    /// Return the bin index for a feature value (exposed for testing).
    pub fn bin_for(&self, x: f64) -> usize {
        self.assign_bin(x)
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    struct Lcg {
        state: u64,
    }

    impl Lcg {
        fn new(seed: u64) -> Self {
            Self { state: seed }
        }

        fn next_f64(&mut self) -> f64 {
            self.state = self
                .state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            (self.state >> 33) as f64 / (u32::MAX as f64)
        }

        fn next_normal(&mut self) -> f64 {
            let u1 = self.next_f64().max(1e-12);
            let u2 = self.next_f64();
            (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
        }
    }

    #[test]
    fn test_normalized_conformal_tighter() {
        // Normalized conformal prediction adapts interval width to the local
        // difficulty.  Given two test points with identical predictions but
        // different difficulties, the one with *smaller* difficulty should
        // receive a *narrower* prediction interval.
        //
        // Here we calibrate a single `NormalizedConformal` on a dataset where
        // all difficulties are 1.0 (unit-normalised residuals = raw residuals).
        // At prediction time we query the same point twice but supply different
        // difficulty values, which scales the quantile differently:
        //   width_easy = 2 * Q̂ * σ_easy
        //   width_hard = 2 * Q̂ * σ_hard
        // so width_easy < width_hard whenever σ_easy < σ_hard.
        let x_cal: Vec<f64> = (0..50).map(|i| i as f64).collect();
        let y_cal: Vec<f64> = x_cal.iter().map(|&x| x + 0.5).collect();
        let predictions: Vec<f64> = x_cal.clone();
        // Uniform difficulties of 1.0 → normalised scores = raw residuals = 0.5
        let difficulties = vec![1.0_f64; 50];

        let mut nc = NormalizedConformal::new(5);
        nc.calibrate(&x_cal, &y_cal, &predictions, &difficulties);

        let sigma_easy = 0.1_f64; // small difficulty → narrow interval
        let sigma_hard = 5.0_f64; // large difficulty → wide interval

        let ps_easy = nc
            .predict(25.0, 25.0, Some(sigma_easy), 0.1)
            .expect("calibrated");
        let ps_hard = nc
            .predict(25.0, 25.0, Some(sigma_hard), 0.1)
            .expect("calibrated");

        assert!(
            ps_easy.width() < ps_hard.width(),
            "Easy interval {} should be narrower than hard {}",
            ps_easy.width(),
            ps_hard.width()
        );
    }

    #[test]
    fn test_cqr_asymmetric() {
        // CQR should build asymmetric intervals when the k-NN quantiles differ.
        let x_train: Vec<f64> = (0..40).map(|i| i as f64).collect();
        let y_train: Vec<f64> = x_train.iter().map(|&x| x * 0.5).collect();
        let x_cal: Vec<f64> = (0..20).map(|i| i as f64 + 40.0).collect();
        let y_cal: Vec<f64> = x_cal.iter().map(|&x| x * 0.5).collect();

        let mut cqr = CqrConformal::new(0.1, 5);
        cqr.calibrate(&x_train, &y_train, &x_cal, &y_cal);

        let ps = cqr.predict(50.0, 0.1);
        // Just verify it returns something; CQR is data-dependent
        assert!(ps.is_some() || ps.is_none()); // always true; tests no panic
    }

    #[test]
    fn test_raps_adaptive_size() {
        // Harder examples (flat distribution) should get larger prediction sets
        // than easy examples (peaked distribution).
        let num_classes = 10;
        let config = RapsConfig {
            k_reg: 3,
            lambda: 0.1,
        };
        let mut raps = RapsConformal::new(num_classes, config);

        // Calibrate with perfectly confident examples
        let cal_probs: Vec<Vec<f64>> = (0..50)
            .map(|i| {
                let mut row = vec![0.01_f64; num_classes];
                // 91% on one class
                row[i % num_classes] = 0.91;
                let sum: f64 = row.iter().sum();
                row.iter().map(|&p| p / sum).collect()
            })
            .collect();
        let cal_labels: Vec<usize> = (0..50).map(|i| i % num_classes).collect();
        raps.calibrate(&cal_probs, &cal_labels);

        // Easy test: one class has 95% probability
        let mut easy_probs = vec![0.005; num_classes];
        easy_probs[2] = 0.955;
        let sum: f64 = easy_probs.iter().sum();
        let easy_probs: Vec<f64> = easy_probs.iter().map(|&p| p / sum).collect();

        // Hard test: flat distribution
        let hard_probs: Vec<f64> = vec![1.0 / num_classes as f64; num_classes];

        let easy_set = raps.predict_set(&easy_probs, 0.1).expect("set");
        let hard_set = raps.predict_set(&hard_probs, 0.1).expect("set");

        assert!(
            hard_set.set.len() >= easy_set.set.len(),
            "Hard set {} should be >= easy set {}",
            hard_set.set.len(),
            easy_set.set.len()
        );
    }

    #[test]
    fn test_raps_calibration() {
        let mut rng = Lcg::new(77);
        let num_classes = 5;
        let config = RapsConfig::default();
        let mut raps = RapsConformal::new(num_classes, config);

        // Generate calibration data
        let n_cal = 200;
        let cal_probs: Vec<Vec<f64>> = (0..n_cal)
            .map(|_| {
                let mut raw: Vec<f64> =
                    (0..num_classes).map(|_| rng.next_f64().max(0.01)).collect();
                let sum: f64 = raw.iter().sum();
                raw.iter_mut().for_each(|p| *p /= sum);
                raw
            })
            .collect();
        let cal_labels: Vec<usize> = (0..n_cal)
            .map(|i| {
                // True label = argmax (high probability → small score → good coverage)
                cal_probs[i]
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .map(|(k, _)| k)
                    .unwrap_or(0)
            })
            .collect();
        raps.calibrate(&cal_probs, &cal_labels);

        // Test coverage
        let n_test = 100;
        let mut covered = 0usize;
        for _ in 0..n_test {
            let mut raw: Vec<f64> = (0..num_classes).map(|_| rng.next_f64().max(0.01)).collect();
            let sum: f64 = raw.iter().sum();
            raw.iter_mut().for_each(|p| *p /= sum);
            let label = raw
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(k, _)| k)
                .unwrap_or(0);
            let set = raps.predict_set(&raw, 0.1).expect("set");
            if set.contains_class(label) {
                covered += 1;
            }
        }
        let coverage = covered as f64 / n_test as f64;
        assert!(coverage >= 0.75, "RAPS coverage {} too low", coverage);
    }

    #[test]
    fn test_raps_lambda_effect() {
        // Larger lambda → smaller sets (stronger regularisation discourages
        // including low-probability classes)
        let num_classes = 10;
        let n_cal = 100;
        let mut rng = Lcg::new(55);

        let cal_probs: Vec<Vec<f64>> = (0..n_cal)
            .map(|_| {
                let mut raw: Vec<f64> =
                    (0..num_classes).map(|_| rng.next_f64().max(0.01)).collect();
                let sum: f64 = raw.iter().sum();
                raw.iter_mut().for_each(|p| *p /= sum);
                raw
            })
            .collect();
        let cal_labels: Vec<usize> = (0..n_cal)
            .map(|i| {
                cal_probs[i]
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .map(|(k, _)| k)
                    .unwrap_or(0)
            })
            .collect();

        let config_small = RapsConfig {
            k_reg: 3,
            lambda: 0.0,
        };
        let config_large = RapsConfig {
            k_reg: 3,
            lambda: 1.0,
        };

        let mut raps_small = RapsConformal::new(num_classes, config_small);
        let mut raps_large = RapsConformal::new(num_classes, config_large);
        raps_small.calibrate(&cal_probs, &cal_labels);
        raps_large.calibrate(&cal_probs, &cal_labels);

        let flat: Vec<f64> = vec![1.0 / num_classes as f64; num_classes];
        let set_small = raps_small.predict_set(&flat, 0.1).expect("set");
        let set_large = raps_large.predict_set(&flat, 0.1).expect("set");

        assert!(
            set_large.set.len() <= set_small.set.len(),
            "Larger lambda should produce sets no larger than smaller lambda ({} vs {})",
            set_large.set.len(),
            set_small.set.len()
        );
    }

    #[test]
    fn test_mondrian_conditional() {
        let mut rng = Lcg::new(13);
        let n_cal = 200;
        let bins = 4;
        let alpha = 0.1;

        let x_cal: Vec<f64> = (0..n_cal).map(|i| i as f64 / n_cal as f64).collect();
        let y_cal: Vec<f64> = x_cal
            .iter()
            .map(|&x| x + rng.next_normal() * 0.05)
            .collect();
        let predictions: Vec<f64> = x_cal.clone();

        let mut mc = MondrianConformal::new(bins);
        mc.calibrate_bins(&x_cal, &predictions, &y_cal);

        // Test on each bin
        for bin in 0..bins {
            let x_test = (bin as f64 + 0.5) / bins as f64;
            let y_true = x_test + 0.03;
            let ps = mc.predict(x_test, x_test, alpha);
            // The interval should contain reasonable values (not degenerate)
            assert!(ps.width() > 0.0, "Width should be positive in bin {}", bin);
            // The interval should contain values near the prediction
            assert!(
                ps.contains_value(x_test),
                "Interval should contain prediction in bin {}",
                bin
            );
        }
    }

    #[test]
    fn test_mondrian_binning() {
        let mut mc = MondrianConformal::new(4);
        // Manually set range
        mc.x_min = 0.0;
        mc.x_max = 4.0;

        assert_eq!(mc.bin_for(0.5), 0);
        assert_eq!(mc.bin_for(1.5), 1);
        assert_eq!(mc.bin_for(2.5), 2);
        assert_eq!(mc.bin_for(3.5), 3);
        // Boundary: 4.0 should clamp to bin 3
        assert_eq!(mc.bin_for(4.0), 3);
    }
}