Skip to main content

ftui_runtime/
conformal_predictor.rs

1#![forbid(unsafe_code)]
2
3//! Conformal predictor for frame-time risk (bd-3e1t.3.2).
4//!
5//! This module provides a distribution-free upper bound on frame time using
6//! Mondrian (bucketed) conformal prediction. It is intentionally lightweight
7//! and explainable: each prediction returns the bucket key, quantile, and
8//! fallback level used to produce the bound.
9//!
10//! See docs/spec/state-machines.md section 3.13 for the governing spec.
11
12use std::collections::{HashMap, VecDeque};
13use std::fmt;
14
15use ftui_render::diff_strategy::DiffStrategy;
16
17use crate::terminal_writer::ScreenMode;
18
/// Tuning parameters for [`ConformalPredictor`].
///
/// All time quantities are in microseconds.
#[derive(Debug, Clone)]
pub struct ConformalConfig {
    /// Significance level alpha; the target coverage is `1 - alpha`.
    /// Default: 0.05.
    pub alpha: f64,

    /// Residual count a bucket needs before its own quantile is trusted
    /// (otherwise coarser buckets are consulted).
    /// Default: 20.
    pub min_samples: usize,

    /// Capacity of each bucket's rolling residual window.
    /// Default: 256.
    pub window_size: usize,

    /// Residual used as the quantile when no calibration data exists at all.
    /// Default: 10_000.0 (10 ms).
    pub q_default: f64,
}

impl Default for ConformalConfig {
    fn default() -> Self {
        let alpha = 0.05; // 95% target coverage
        let min_samples = 20;
        let window_size = 256;
        let q_default = 10_000.0; // 10 ms expressed in microseconds
        Self {
            alpha,
            min_samples,
            window_size,
            q_default,
        }
    }
}
49
50/// Bucket identifier for conformal calibration.
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
52pub struct BucketKey {
53    pub mode: ModeBucket,
54    pub diff: DiffBucket,
55    pub size_bucket: u8,
56}
57
58impl BucketKey {
59    /// Create a bucket key from rendering context.
60    pub fn from_context(
61        screen_mode: ScreenMode,
62        diff_strategy: DiffStrategy,
63        cols: u16,
64        rows: u16,
65    ) -> Self {
66        Self {
67            mode: ModeBucket::from_screen_mode(screen_mode),
68            diff: DiffBucket::from(diff_strategy),
69            size_bucket: size_bucket(cols, rows),
70        }
71    }
72}
73
74/// Mode bucket for conformal calibration.
75#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
76pub enum ModeBucket {
77    Inline,
78    InlineAuto,
79    AltScreen,
80}
81
82impl ModeBucket {
83    pub fn as_str(self) -> &'static str {
84        match self {
85            Self::Inline => "inline",
86            Self::InlineAuto => "inline_auto",
87            Self::AltScreen => "altscreen",
88        }
89    }
90
91    pub fn from_screen_mode(mode: ScreenMode) -> Self {
92        match mode {
93            ScreenMode::Inline { .. } => Self::Inline,
94            ScreenMode::InlineAuto { .. } => Self::InlineAuto,
95            ScreenMode::AltScreen => Self::AltScreen,
96        }
97    }
98}
99
/// Diff-strategy component of a [`BucketKey`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiffBucket {
    /// Frames produced with `DiffStrategy::Full`.
    Full,
    /// Frames produced with `DiffStrategy::DirtyRows`.
    DirtyRows,
    /// Frames produced with `DiffStrategy::FullRedraw`.
    FullRedraw,
}

impl DiffBucket {
    /// Short lowercase label used in bucket keys and structured logs.
    pub fn as_str(self) -> &'static str {
        let label = match self {
            DiffBucket::FullRedraw => "redraw",
            DiffBucket::DirtyRows => "dirty",
            DiffBucket::Full => "full",
        };
        label
    }
}
117
118impl From<DiffStrategy> for DiffBucket {
119    fn from(strategy: DiffStrategy) -> Self {
120        match strategy {
121            DiffStrategy::Full => Self::Full,
122            DiffStrategy::DirtyRows => Self::DirtyRows,
123            DiffStrategy::FullRedraw => Self::FullRedraw,
124        }
125    }
126}
127
128impl fmt::Display for BucketKey {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(
131            f,
132            "{}:{}:{}",
133            self.mode.as_str(),
134            self.diff.as_str(),
135            self.size_bucket
136        )
137    }
138}
139
140/// Prediction output with full explainability.
141#[derive(Debug, Clone)]
142pub struct ConformalPrediction {
143    /// Upper bound on frame time (microseconds).
144    pub upper_us: f64,
145    /// Whether the bound exceeds the current budget.
146    pub risk: bool,
147    /// Coverage confidence (1 - alpha).
148    pub confidence: f64,
149    /// Bucket key used for calibration (may be fallback aggregate).
150    pub bucket: BucketKey,
151    /// Calibration sample count used for the quantile.
152    pub sample_count: usize,
153    /// Conformal quantile q_b.
154    pub quantile: f64,
155    /// Fallback level (0 = exact, 1 = mode+diff, 2 = mode, 3 = global/default).
156    pub fallback_level: u8,
157    /// Rolling window size.
158    pub window_size: usize,
159    /// Total reset count for this predictor.
160    pub reset_count: u64,
161    /// Base prediction f(x_t).
162    pub y_hat: f64,
163    /// Frame budget in microseconds.
164    pub budget_us: f64,
165}
166
167impl ConformalPrediction {
168    /// Format this prediction as a JSONL line for structured logging.
169    #[must_use]
170    pub fn to_jsonl(&self) -> String {
171        format!(
172            r#"{{"schema":"conformal-v1","upper_us":{:.1},"risk":{},"confidence":{:.4},"bucket":"{}","samples":{},"quantile":{:.2},"fallback_level":{},"window":{},"resets":{},"y_hat":{:.1},"budget_us":{:.1}}}"#,
173            self.upper_us,
174            self.risk,
175            self.confidence,
176            self.bucket,
177            self.sample_count,
178            self.quantile,
179            self.fallback_level,
180            self.window_size,
181            self.reset_count,
182            self.y_hat,
183            self.budget_us,
184        )
185    }
186}
187
188/// Update metadata after observing a frame.
189#[derive(Debug, Clone)]
190pub struct ConformalUpdate {
191    /// Residual (y_t - f(x_t)).
192    pub residual: f64,
193    /// Bucket updated.
194    pub bucket: BucketKey,
195    /// New sample count in the bucket.
196    pub sample_count: usize,
197}
198
/// Rolling window of residuals for one calibration bucket.
#[derive(Debug, Default)]
struct BucketState {
    /// Oldest residual at the front, newest at the back.
    residuals: VecDeque<f64>,
}

impl BucketState {
    /// Append `residual`, then evict the oldest entries until at most
    /// `window_size` remain. A `window_size` of 0 leaves the window empty.
    fn push(&mut self, residual: f64, window_size: usize) {
        self.residuals.push_back(residual);
        let overflow = self.residuals.len().saturating_sub(window_size);
        for _ in 0..overflow {
            self.residuals.pop_front();
        }
    }
}
212
213/// Conformal predictor with bucketed calibration.
214#[derive(Debug)]
215pub struct ConformalPredictor {
216    config: ConformalConfig,
217    buckets: HashMap<BucketKey, BucketState>,
218    reset_count: u64,
219}
220
221impl ConformalPredictor {
222    /// Create a new predictor with the given config.
223    pub fn new(config: ConformalConfig) -> Self {
224        Self {
225            config,
226            buckets: HashMap::new(),
227            reset_count: 0,
228        }
229    }
230
231    /// Access the configuration.
232    pub fn config(&self) -> &ConformalConfig {
233        &self.config
234    }
235
236    /// Number of samples currently stored for a bucket.
237    pub fn bucket_samples(&self, key: BucketKey) -> usize {
238        self.buckets
239            .get(&key)
240            .map(|state| state.residuals.len())
241            .unwrap_or(0)
242    }
243
244    /// Clear calibration for all buckets.
245    pub fn reset_all(&mut self) {
246        self.buckets.clear();
247        self.reset_count += 1;
248    }
249
250    /// Clear calibration for a single bucket.
251    pub fn reset_bucket(&mut self, key: BucketKey) {
252        if let Some(state) = self.buckets.get_mut(&key) {
253            state.residuals.clear();
254            self.reset_count += 1;
255        }
256    }
257
258    /// Observe a realized frame time and update calibration.
259    pub fn observe(&mut self, key: BucketKey, y_hat_us: f64, observed_us: f64) -> ConformalUpdate {
260        let residual = observed_us - y_hat_us;
261        if !residual.is_finite() {
262            return ConformalUpdate {
263                residual,
264                bucket: key,
265                sample_count: self.bucket_samples(key),
266            };
267        }
268
269        let window_size = self.config.window_size.max(1);
270        let state = self.buckets.entry(key).or_default();
271        state.push(residual, window_size);
272        ConformalUpdate {
273            residual,
274            bucket: key,
275            sample_count: state.residuals.len(),
276        }
277    }
278
279    /// Predict a conservative upper bound for frame time.
280    pub fn predict(&self, key: BucketKey, y_hat_us: f64, budget_us: f64) -> ConformalPrediction {
281        let span = tracing::info_span!(
282            "conformal.predict",
283            calibration_set_size = tracing::field::Empty,
284            predicted_upper_bound_us = tracing::field::Empty,
285            frame_budget_us = budget_us,
286            coverage_alpha = self.config.alpha,
287            gate_triggered = tracing::field::Empty,
288        );
289        let _guard = span.enter();
290
291        let QuantileDecision {
292            quantile,
293            sample_count,
294            fallback_level,
295        } = self.quantile_for(key);
296
297        let upper_us = y_hat_us + quantile.max(0.0);
298        let risk = upper_us > budget_us;
299
300        span.record("calibration_set_size", sample_count);
301        span.record("predicted_upper_bound_us", upper_us);
302        span.record("gate_triggered", risk);
303
304        tracing::debug!(
305            bucket = %key,
306            y_hat_us,
307            quantile,
308            interval_width_us = quantile.max(0.0),
309            fallback_level,
310            sample_count,
311            "prediction interval"
312        );
313
314        ConformalPrediction {
315            upper_us,
316            risk,
317            confidence: 1.0 - self.config.alpha,
318            bucket: key,
319            sample_count,
320            quantile,
321            fallback_level,
322            window_size: self.config.window_size,
323            reset_count: self.reset_count,
324            y_hat: y_hat_us,
325            budget_us,
326        }
327    }
328
329    fn quantile_for(&self, key: BucketKey) -> QuantileDecision {
330        let min_samples = self.config.min_samples.max(1);
331
332        let exact = self.collect_exact(key);
333        if exact.len() >= min_samples {
334            return QuantileDecision::new(self.config.alpha, exact, 0);
335        }
336
337        let mode_diff = self.collect_mode_diff(key.mode, key.diff);
338        if mode_diff.len() >= min_samples {
339            return QuantileDecision::new(self.config.alpha, mode_diff, 1);
340        }
341
342        let mode_only = self.collect_mode(key.mode);
343        if mode_only.len() >= min_samples {
344            return QuantileDecision::new(self.config.alpha, mode_only, 2);
345        }
346
347        let global = self.collect_all();
348        if !global.is_empty() {
349            return QuantileDecision::new(self.config.alpha, global, 3);
350        }
351
352        QuantileDecision {
353            quantile: self.config.q_default,
354            sample_count: 0,
355            fallback_level: 3,
356        }
357    }
358
359    fn collect_exact(&self, key: BucketKey) -> Vec<f64> {
360        self.buckets
361            .get(&key)
362            .map(|state| state.residuals.iter().copied().collect())
363            .unwrap_or_default()
364    }
365
366    fn collect_mode_diff(&self, mode: ModeBucket, diff: DiffBucket) -> Vec<f64> {
367        let mut values = Vec::new();
368        for (key, state) in &self.buckets {
369            if key.mode == mode && key.diff == diff {
370                values.extend(state.residuals.iter().copied());
371            }
372        }
373        values
374    }
375
376    fn collect_mode(&self, mode: ModeBucket) -> Vec<f64> {
377        let mut values = Vec::new();
378        for (key, state) in &self.buckets {
379            if key.mode == mode {
380                values.extend(state.residuals.iter().copied());
381            }
382        }
383        values
384    }
385
386    fn collect_all(&self) -> Vec<f64> {
387        let mut values = Vec::new();
388        for state in self.buckets.values() {
389            values.extend(state.residuals.iter().copied());
390        }
391        values
392    }
393}
394
395#[derive(Debug)]
396struct QuantileDecision {
397    quantile: f64,
398    sample_count: usize,
399    fallback_level: u8,
400}
401
402impl QuantileDecision {
403    fn new(alpha: f64, mut residuals: Vec<f64>, fallback_level: u8) -> Self {
404        let quantile = conformal_quantile(alpha, &mut residuals);
405        Self {
406            quantile,
407            sample_count: residuals.len(),
408            fallback_level,
409        }
410    }
411}
412
/// Split-conformal quantile of calibration residuals.
///
/// Returns the `ceil((n + 1) * (1 - alpha))`-th smallest residual (1-based).
/// - If that rank exceeds `n`, it is clamped to the largest residual, so small
///   calibration sets yield the sample maximum rather than an infinite bound.
/// - If the rank is 0 (`alpha >= 1` or NaN; `as usize` saturates), the
///   smallest residual is returned.
/// - An empty slice yields `0.0`.
///
/// The slice is reordered in place: it is partitioned around the selected
/// position, not fully sorted.
///
/// Elements are compared with [`f64::total_cmp`], a total order. Unlike
/// `partial_cmp(..).unwrap_or(Equal)`, it gives consistent comparisons for
/// NaN and orders `-0.0` before `0.0`, so the result does not depend on the
/// input order.
fn conformal_quantile(alpha: f64, residuals: &mut [f64]) -> f64 {
    let n = residuals.len();
    if n == 0 {
        return 0.0;
    }
    let rank = ((n as f64 + 1.0) * (1.0 - alpha)).ceil() as usize;
    let idx = rank.saturating_sub(1).min(n - 1);
    // Only one order statistic is needed: O(n) selection instead of a sort.
    let (_, value, _) = residuals.select_nth_unstable_by(idx, f64::total_cmp);
    *value
}
423
/// Coarse terminal-size class: `floor(log2(cols * rows))`.
///
/// Terminals whose cell area falls in the same power-of-two range share a
/// class. A zero-area terminal maps to class 0, as does a 1x1 terminal.
fn size_bucket(cols: u16, rows: u16) -> u8 {
    match u32::from(cols) * u32::from(rows) {
        0 => 0,
        // Position of the highest set bit == floor(log2(area)); at most 31,
        // since 65535 * 65535 < 2^32.
        area => (u32::BITS - 1 - area.leading_zeros()) as u8,
    }
}
431
// Unit tests: size bucketing, the conformal quantile rule, the fallback
// hierarchy, rolling windows, resets, prediction fields, and tracing output.
#[cfg(test)]
mod tests {
    use super::*;

    /// Bucket key for an inline-mode (ui_height 4), full-diff frame of the
    /// given terminal size.
    fn test_key(cols: u16, rows: u16) -> BucketKey {
        BucketKey::from_context(
            ScreenMode::Inline { ui_height: 4 },
            DiffStrategy::Full,
            cols,
            rows,
        )
    }

    #[test]
    fn quantile_n_plus_1_rule() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.2,
            min_samples: 1,
            window_size: 10,
            q_default: 0.0,
        });

        let key = test_key(80, 24);
        predictor.observe(key, 0.0, 1.0);
        predictor.observe(key, 0.0, 2.0);
        predictor.observe(key, 0.0, 3.0);

        // rank = ceil((3 + 1) * (1 - 0.2)) = 4 exceeds n = 3, so the index is
        // clamped to the largest residual.
        let decision = predictor.predict(key, 0.0, 1_000.0);
        assert_eq!(decision.quantile, 3.0);
    }

    #[test]
    fn fallback_hierarchy_mode_diff() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 4,
            window_size: 16,
            q_default: 0.0,
        });

        let key_a = test_key(80, 24);
        for value in [1.0, 2.0, 3.0, 4.0] {
            predictor.observe(key_a, 0.0, value);
        }

        // key_b differs from key_a only in size class, so its exact bucket is
        // empty and key_a's residuals are reached through the mode+diff tier.
        let key_b = test_key(120, 40);
        let decision = predictor.predict(key_b, 0.0, 1_000.0);
        assert_eq!(decision.fallback_level, 1);
        assert_eq!(decision.sample_count, 4);
    }

    #[test]
    fn fallback_hierarchy_mode_only() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 3,
            window_size: 16,
            q_default: 0.0,
        });

        let key_dirty = BucketKey::from_context(
            ScreenMode::Inline { ui_height: 4 },
            DiffStrategy::DirtyRows,
            80,
            24,
        );
        for value in [10.0, 20.0, 30.0] {
            predictor.observe(key_dirty, 0.0, value);
        }

        // Same mode but a different diff strategy: only the mode-only tier
        // pools enough samples.
        let key_full = BucketKey::from_context(
            ScreenMode::Inline { ui_height: 4 },
            DiffStrategy::Full,
            120,
            40,
        );
        let decision = predictor.predict(key_full, 0.0, 1_000.0);
        assert_eq!(decision.fallback_level, 2);
        assert_eq!(decision.sample_count, 3);
    }

    #[test]
    fn window_enforced() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 3,
            q_default: 0.0,
        });
        let key = test_key(80, 24);
        // Five observations into a window of three: the oldest two are evicted.
        for value in [1.0, 2.0, 3.0, 4.0, 5.0] {
            predictor.observe(key, 0.0, value);
        }
        assert_eq!(predictor.bucket_samples(key), 3);
    }

    #[test]
    fn predict_uses_default_when_empty() {
        let predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 2,
            window_size: 4,
            q_default: 42.0,
        });
        let key = test_key(120, 40);
        let prediction = predictor.predict(key, 5.0, 10_000.0);
        assert_eq!(prediction.quantile, 42.0);
        assert_eq!(prediction.sample_count, 0);
        assert_eq!(prediction.fallback_level, 3);
    }

    #[test]
    fn bucket_isolation_by_size() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.2,
            min_samples: 2,
            window_size: 10,
            q_default: 0.0,
        });

        let small = test_key(40, 10);
        predictor.observe(small, 0.0, 1.0);
        predictor.observe(small, 0.0, 2.0);

        let large = test_key(200, 60);
        predictor.observe(large, 0.0, 10.0);
        predictor.observe(large, 0.0, 12.0);

        // The large bucket has enough samples on its own, so the small
        // bucket's residuals must not leak into its quantile.
        let prediction = predictor.predict(large, 0.0, 1_000.0);
        assert_eq!(prediction.fallback_level, 0);
        assert_eq!(prediction.sample_count, 2);
        assert_eq!(prediction.quantile, 12.0);
    }

    #[test]
    fn reset_clears_bucket_and_raises_reset_count() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 7.0,
        });
        let key = test_key(80, 24);
        predictor.observe(key, 0.0, 3.0);
        assert_eq!(predictor.bucket_samples(key), 1);

        predictor.reset_bucket(key);
        assert_eq!(predictor.bucket_samples(key), 0);

        // With the only bucket emptied, prediction falls back to q_default.
        let prediction = predictor.predict(key, 0.0, 1_000.0);
        assert_eq!(prediction.quantile, 7.0);
        assert_eq!(prediction.reset_count, 1);
    }

    #[test]
    fn reset_all_forces_conservative_fallback() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 9.0,
        });
        let key = test_key(80, 24);
        predictor.observe(key, 0.0, 2.0);

        predictor.reset_all();
        let prediction = predictor.predict(key, 0.0, 1_000.0);
        assert_eq!(prediction.quantile, 9.0);
        assert_eq!(prediction.sample_count, 0);
        assert_eq!(prediction.fallback_level, 3);
        assert_eq!(prediction.reset_count, 1);
    }

    #[test]
    fn size_bucket_log2_area() {
        let a = size_bucket(8, 8); // area 64 -> log2 = 6
        let b = size_bucket(8, 16); // area 128 -> log2 = 7
        assert_eq!(a, 6);
        assert_eq!(b, 7);
    }

    // --- size_bucket edge cases ---

    #[test]
    fn size_bucket_zero_area() {
        assert_eq!(size_bucket(0, 0), 0);
        assert_eq!(size_bucket(0, 24), 0);
        assert_eq!(size_bucket(80, 0), 0);
    }

    #[test]
    fn size_bucket_one_by_one() {
        assert_eq!(size_bucket(1, 1), 0); // area 1, log2(1) = 0
    }

    #[test]
    fn size_bucket_typical_terminals() {
        let b80 = size_bucket(80, 24); // 1920 -> log2 ~ 10
        let b120 = size_bucket(120, 40); // 4800 -> log2 ~ 12
        assert_eq!(b80, 10);
        assert_eq!(b120, 12);
    }

    // --- conformal_quantile edge cases ---

    #[test]
    fn conformal_quantile_empty() {
        let mut data: Vec<f64> = vec![];
        assert_eq!(conformal_quantile(0.1, &mut data), 0.0);
    }

    #[test]
    fn conformal_quantile_single_element() {
        let mut data = vec![42.0];
        assert_eq!(conformal_quantile(0.1, &mut data), 42.0);
    }

    #[test]
    fn conformal_quantile_sorted_data() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let q = conformal_quantile(0.5, &mut data);
        // (5+1)*0.5 = 3.0 -> ceil = 3 -> idx = 2 -> data[2] = 3.0
        assert_eq!(q, 3.0);
    }

    #[test]
    fn conformal_quantile_alpha_half() {
        let mut data = vec![10.0, 20.0, 30.0, 40.0];
        let q = conformal_quantile(0.5, &mut data);
        // (4+1)*0.5 = 2.5 -> ceil = 3 -> idx = 2 -> data[2] = 30.0
        assert_eq!(q, 30.0);
    }

    // --- ModeBucket / DiffBucket ---

    #[test]
    fn mode_bucket_as_str_all_variants() {
        assert_eq!(ModeBucket::Inline.as_str(), "inline");
        assert_eq!(ModeBucket::InlineAuto.as_str(), "inline_auto");
        assert_eq!(ModeBucket::AltScreen.as_str(), "altscreen");
    }

    #[test]
    fn diff_bucket_as_str_all_variants() {
        assert_eq!(DiffBucket::Full.as_str(), "full");
        assert_eq!(DiffBucket::DirtyRows.as_str(), "dirty");
        assert_eq!(DiffBucket::FullRedraw.as_str(), "redraw");
    }

    #[test]
    fn diff_bucket_from_strategy() {
        assert_eq!(DiffBucket::from(DiffStrategy::Full), DiffBucket::Full);
        assert_eq!(
            DiffBucket::from(DiffStrategy::DirtyRows),
            DiffBucket::DirtyRows
        );
        assert_eq!(
            DiffBucket::from(DiffStrategy::FullRedraw),
            DiffBucket::FullRedraw
        );
    }

    // --- BucketKey Display ---

    #[test]
    fn bucket_key_display_format() {
        let key = BucketKey {
            mode: ModeBucket::AltScreen,
            diff: DiffBucket::DirtyRows,
            size_bucket: 12,
        };
        assert_eq!(format!("{key}"), "altscreen:dirty:12");
    }

    // --- observe edge cases ---

    #[test]
    fn observe_nan_residual_not_stored() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 5.0,
        });
        let key = test_key(80, 24);
        let update = predictor.observe(key, 0.0, f64::NAN);
        assert!(!update.residual.is_finite());
        assert_eq!(predictor.bucket_samples(key), 0);
    }

    #[test]
    fn observe_infinity_residual_not_stored() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 5.0,
        });
        let key = test_key(80, 24);
        predictor.observe(key, 0.0, f64::INFINITY);
        assert_eq!(predictor.bucket_samples(key), 0);
    }

    // --- prediction fields ---

    #[test]
    fn prediction_risk_flag() {
        let predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 50.0,
        });
        let key = test_key(80, 24);
        // No data -> q_default = 50.0, y_hat = 0 -> upper = 50
        let p = predictor.predict(key, 0.0, 100.0);
        assert!(!p.risk); // 50 <= 100
        let p2 = predictor.predict(key, 0.0, 30.0);
        assert!(p2.risk); // 50 > 30
    }

    #[test]
    fn prediction_confidence() {
        let predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.05,
            min_samples: 1,
            window_size: 8,
            q_default: 0.0,
        });
        let key = test_key(80, 24);
        let p = predictor.predict(key, 0.0, 100.0);
        assert!((p.confidence - 0.95).abs() < 1e-10);
    }

    // --- global fallback with data ---

    #[test]
    fn global_fallback_with_data() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 100, // impossibly high -> always fall through
            window_size: 256,
            q_default: 999.0,
        });
        // Use altscreen mode bucket, then query inline
        let alt_key = BucketKey::from_context(ScreenMode::AltScreen, DiffStrategy::Full, 80, 24);
        predictor.observe(alt_key, 0.0, 5.0);

        let inline_key = test_key(80, 24);
        let p = predictor.predict(inline_key, 0.0, 1000.0);
        // Falls all the way to global (level 3), has 1 sample
        assert_eq!(p.fallback_level, 3);
        assert_eq!(p.sample_count, 1);
        assert_eq!(p.quantile, 5.0);
    }

    // --- ModeBucket from_screen_mode ---

    #[test]
    fn mode_bucket_from_screen_modes() {
        assert_eq!(
            ModeBucket::from_screen_mode(ScreenMode::Inline { ui_height: 4 }),
            ModeBucket::Inline
        );
        assert_eq!(
            ModeBucket::from_screen_mode(ScreenMode::InlineAuto {
                min_height: 4,
                max_height: 24
            }),
            ModeBucket::InlineAuto
        );
        assert_eq!(
            ModeBucket::from_screen_mode(ScreenMode::AltScreen),
            ModeBucket::AltScreen
        );
    }

    // --- Config defaults ---

    #[test]
    fn config_defaults() {
        let config = ConformalConfig::default();
        assert!((config.alpha - 0.05).abs() < 1e-10);
        assert_eq!(config.min_samples, 20);
        assert_eq!(config.window_size, 256);
        assert!((config.q_default - 10_000.0).abs() < 1e-10);
    }

    #[test]
    fn predictor_config_accessor() {
        let config = ConformalConfig {
            alpha: 0.2,
            min_samples: 5,
            window_size: 32,
            q_default: 100.0,
        };
        let predictor = ConformalPredictor::new(config);
        assert!((predictor.config().alpha - 0.2).abs() < 1e-10);
        assert_eq!(predictor.config().min_samples, 5);
    }

    // --- negative residuals ---

    #[test]
    fn negative_residual_clamped_in_prediction() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 0.0,
        });
        let key = test_key(80, 24);
        // observed < y_hat -> negative residual
        predictor.observe(key, 10.0, 5.0);
        let p = predictor.predict(key, 10.0, 100.0);
        // quantile is -5.0, but clamped to 0.0 via .max(0.0)
        // so upper_us = 10.0 + 0.0 = 10.0
        assert_eq!(p.upper_us, 10.0);
    }

    // --- ConformalUpdate fields ---

    #[test]
    fn observe_returns_correct_update() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 0.0,
        });
        let key = test_key(80, 24);
        let update = predictor.observe(key, 3.0, 10.0);
        assert!((update.residual - 7.0).abs() < 1e-10);
        assert_eq!(update.bucket, key);
        assert_eq!(update.sample_count, 1);
    }

    // --- prediction y_hat and budget fields ---

    #[test]
    fn prediction_preserves_yhat_and_budget() {
        let predictor = ConformalPredictor::new(ConformalConfig::default());
        let key = test_key(80, 24);
        let p = predictor.predict(key, 42.5, 16666.0);
        assert!((p.y_hat - 42.5).abs() < 1e-10);
        assert!((p.budget_us - 16666.0).abs() < 1e-10);
    }

    // --- tracing span verification ---

    #[test]
    fn predict_emits_conformal_predict_span() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        // Minimal subscriber that only flags creation of a span named
        // "conformal.predict"; all other callbacks are no-ops.
        struct SpanChecker {
            saw_conformal_predict: Arc<AtomicBool>,
        }

        impl tracing::Subscriber for SpanChecker {
            fn enabled(&self, _metadata: &tracing::Metadata<'_>) -> bool {
                true
            }
            fn new_span(&self, span: &tracing::span::Attributes<'_>) -> tracing::span::Id {
                if span.metadata().name() == "conformal.predict" {
                    self.saw_conformal_predict.store(true, Ordering::Relaxed);
                }
                tracing::span::Id::from_u64(1)
            }
            fn record(&self, _span: &tracing::span::Id, _values: &tracing::span::Record<'_>) {}
            fn record_follows_from(&self, _span: &tracing::span::Id, _follows: &tracing::span::Id) {
            }
            fn event(&self, _event: &tracing::Event<'_>) {}
            fn enter(&self, _span: &tracing::span::Id) {}
            fn exit(&self, _span: &tracing::span::Id) {}
        }

        let saw_it = Arc::new(AtomicBool::new(false));
        let subscriber = SpanChecker {
            saw_conformal_predict: Arc::clone(&saw_it),
        };
        // Scoped to this thread for the lifetime of `_guard`.
        let _guard = tracing::subscriber::set_default(subscriber);

        let predictor = ConformalPredictor::new(ConformalConfig::default());
        let key = test_key(80, 24);
        let _ = predictor.predict(key, 100.0, 16666.0);

        assert!(
            saw_it.load(Ordering::Relaxed),
            "predict() must emit a 'conformal.predict' tracing span"
        );
    }

    #[test]
    fn predict_span_records_gate_triggered_true() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        // Subscriber that inspects values recorded onto a span after creation
        // (`Subscriber::record`) and flags `gate_triggered = true`.
        struct GateChecker {
            saw_gate_true: Arc<AtomicBool>,
        }

        // Field visitor: only boolean `gate_triggered` fields are examined.
        struct GateVisitor(Arc<AtomicBool>);

        impl tracing::field::Visit for GateVisitor {
            fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
                if field.name() == "gate_triggered" && value {
                    self.0.store(true, Ordering::Relaxed);
                }
            }
            fn record_debug(&mut self, _field: &tracing::field::Field, _value: &dyn fmt::Debug) {}
        }

        impl tracing::Subscriber for GateChecker {
            fn enabled(&self, _metadata: &tracing::Metadata<'_>) -> bool {
                true
            }
            fn new_span(&self, _span: &tracing::span::Attributes<'_>) -> tracing::span::Id {
                tracing::span::Id::from_u64(1)
            }
            fn record(&self, _span: &tracing::span::Id, values: &tracing::span::Record<'_>) {
                let mut visitor = GateVisitor(Arc::clone(&self.saw_gate_true));
                values.record(&mut visitor);
            }
            fn record_follows_from(&self, _span: &tracing::span::Id, _follows: &tracing::span::Id) {
            }
            fn event(&self, _event: &tracing::Event<'_>) {}
            fn enter(&self, _span: &tracing::span::Id) {}
            fn exit(&self, _span: &tracing::span::Id) {}
        }

        let saw_gate = Arc::new(AtomicBool::new(false));
        let subscriber = GateChecker {
            saw_gate_true: Arc::clone(&saw_gate),
        };
        let _guard = tracing::subscriber::set_default(subscriber);

        let predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 50_000.0, // large default to guarantee risk
        });
        let key = test_key(80, 24);
        // budget_us = 100 << q_default = 50_000 -> risk = true
        let p = predictor.predict(key, 0.0, 100.0);
        assert!(p.risk, "prediction should be risky");
        assert!(
            saw_gate.load(Ordering::Relaxed),
            "predict() must record gate_triggered=true when risk"
        );
    }
}