Skip to main content

ftui_runtime/
conformal_predictor.rs

1#![forbid(unsafe_code)]
2
3//! Conformal predictor for frame-time risk (bd-3e1t.3.2).
4//!
5//! This module provides a distribution-free upper bound on frame time using
6//! Mondrian (bucketed) conformal prediction. It is intentionally lightweight
7//! and explainable: each prediction returns the bucket key, quantile, and
8//! fallback level used to produce the bound.
9//!
10//! See docs/spec/state-machines.md section 3.13 for the governing spec.
11
12use std::collections::{HashMap, VecDeque};
13use std::fmt;
14
15use ftui_render::diff_strategy::DiffStrategy;
16
17use crate::terminal_writer::ScreenMode;
18
/// Tuning knobs for the conformal frame-time predictor.
///
/// All durations are in microseconds.
#[derive(Clone, Debug)]
pub struct ConformalConfig {
    /// Miscoverage level; the intended coverage is at least `1 - alpha`.
    /// Default: 0.05.
    pub alpha: f64,

    /// Residuals a bucket level must hold before its quantile is trusted.
    /// Values below 1 are treated as 1. Default: 20.
    pub min_samples: usize,

    /// Capacity of each bucket's rolling residual window; older residuals are
    /// evicted first. Values below 1 are treated as 1. Default: 256.
    pub window_size: usize,

    /// Residual used when no calibration data has been observed at all.
    /// Default: 10_000.0 (10ms).
    pub q_default: f64,
}
38
39impl Default for ConformalConfig {
40    fn default() -> Self {
41        Self {
42            alpha: 0.05,
43            min_samples: 20,
44            window_size: 256,
45            q_default: 10_000.0,
46        }
47    }
48}
49
50/// Bucket identifier for conformal calibration.
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
52pub struct BucketKey {
53    pub mode: ModeBucket,
54    pub diff: DiffBucket,
55    pub size_bucket: u8,
56}
57
58impl BucketKey {
59    /// Create a bucket key from rendering context.
60    pub fn from_context(
61        screen_mode: ScreenMode,
62        diff_strategy: DiffStrategy,
63        cols: u16,
64        rows: u16,
65    ) -> Self {
66        Self {
67            mode: ModeBucket::from_screen_mode(screen_mode),
68            diff: DiffBucket::from(diff_strategy),
69            size_bucket: size_bucket(cols, rows),
70        }
71    }
72}
73
/// Screen-mode dimension of a calibration bucket.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ModeBucket {
    /// Inline rendering.
    Inline,
    /// Inline rendering with automatic sizing.
    InlineAuto,
    /// Alternate-screen rendering.
    AltScreen,
}
81
82impl ModeBucket {
83    pub fn as_str(self) -> &'static str {
84        match self {
85            Self::Inline => "inline",
86            Self::InlineAuto => "inline_auto",
87            Self::AltScreen => "altscreen",
88        }
89    }
90
91    pub fn from_screen_mode(mode: ScreenMode) -> Self {
92        match mode {
93            ScreenMode::Inline { .. } => Self::Inline,
94            ScreenMode::InlineAuto { .. } => Self::InlineAuto,
95            ScreenMode::AltScreen => Self::AltScreen,
96        }
97    }
98}
99
/// Diff-strategy dimension of a calibration bucket.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DiffBucket {
    /// Mirrors `DiffStrategy::Full`.
    Full,
    /// Mirrors `DiffStrategy::DirtyRows`.
    DirtyRows,
    /// Mirrors `DiffStrategy::FullRedraw`.
    FullRedraw,
}
107
108impl DiffBucket {
109    pub fn as_str(self) -> &'static str {
110        match self {
111            Self::Full => "full",
112            Self::DirtyRows => "dirty",
113            Self::FullRedraw => "redraw",
114        }
115    }
116}
117
118impl From<DiffStrategy> for DiffBucket {
119    fn from(strategy: DiffStrategy) -> Self {
120        match strategy {
121            DiffStrategy::Full => Self::Full,
122            DiffStrategy::DirtyRows => Self::DirtyRows,
123            DiffStrategy::FullRedraw => Self::FullRedraw,
124        }
125    }
126}
127
128impl fmt::Display for BucketKey {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(
131            f,
132            "{}:{}:{}",
133            self.mode.as_str(),
134            self.diff.as_str(),
135            self.size_bucket
136        )
137    }
138}
139
140/// Prediction output with full explainability.
141#[derive(Debug, Clone)]
142pub struct ConformalPrediction {
143    /// Upper bound on frame time (microseconds).
144    pub upper_us: f64,
145    /// Whether the bound exceeds the current budget.
146    pub risk: bool,
147    /// Coverage confidence (1 - alpha).
148    pub confidence: f64,
149    /// Bucket key used for calibration (may be fallback aggregate).
150    pub bucket: BucketKey,
151    /// Calibration sample count used for the quantile.
152    pub sample_count: usize,
153    /// Conformal quantile q_b.
154    pub quantile: f64,
155    /// Fallback level (0 = exact, 1 = mode+diff, 2 = mode, 3 = global/default).
156    pub fallback_level: u8,
157    /// Rolling window size.
158    pub window_size: usize,
159    /// Total reset count for this predictor.
160    pub reset_count: u64,
161    /// Base prediction f(x_t).
162    pub y_hat: f64,
163    /// Frame budget in microseconds.
164    pub budget_us: f64,
165}
166
167/// Update metadata after observing a frame.
168#[derive(Debug, Clone)]
169pub struct ConformalUpdate {
170    /// Residual (y_t - f(x_t)).
171    pub residual: f64,
172    /// Bucket updated.
173    pub bucket: BucketKey,
174    /// New sample count in the bucket.
175    pub sample_count: usize,
176}
177
/// Rolling residual window for one calibration bucket.
#[derive(Default, Debug)]
struct BucketState {
    /// Residuals in arrival order: newest at the back, oldest at the front.
    residuals: VecDeque<f64>,
}
182
183impl BucketState {
184    fn push(&mut self, residual: f64, window_size: usize) {
185        self.residuals.push_back(residual);
186        while self.residuals.len() > window_size {
187            self.residuals.pop_front();
188        }
189    }
190}
191
192/// Conformal predictor with bucketed calibration.
193#[derive(Debug)]
194pub struct ConformalPredictor {
195    config: ConformalConfig,
196    buckets: HashMap<BucketKey, BucketState>,
197    reset_count: u64,
198}
199
200impl ConformalPredictor {
201    /// Create a new predictor with the given config.
202    pub fn new(config: ConformalConfig) -> Self {
203        Self {
204            config,
205            buckets: HashMap::new(),
206            reset_count: 0,
207        }
208    }
209
210    /// Access the configuration.
211    pub fn config(&self) -> &ConformalConfig {
212        &self.config
213    }
214
215    /// Number of samples currently stored for a bucket.
216    pub fn bucket_samples(&self, key: BucketKey) -> usize {
217        self.buckets
218            .get(&key)
219            .map(|state| state.residuals.len())
220            .unwrap_or(0)
221    }
222
223    /// Clear calibration for all buckets.
224    pub fn reset_all(&mut self) {
225        self.buckets.clear();
226        self.reset_count += 1;
227    }
228
229    /// Clear calibration for a single bucket.
230    pub fn reset_bucket(&mut self, key: BucketKey) {
231        if let Some(state) = self.buckets.get_mut(&key) {
232            state.residuals.clear();
233            self.reset_count += 1;
234        }
235    }
236
237    /// Observe a realized frame time and update calibration.
238    pub fn observe(&mut self, key: BucketKey, y_hat_us: f64, observed_us: f64) -> ConformalUpdate {
239        let residual = observed_us - y_hat_us;
240        if !residual.is_finite() {
241            return ConformalUpdate {
242                residual,
243                bucket: key,
244                sample_count: self.bucket_samples(key),
245            };
246        }
247
248        let window_size = self.config.window_size.max(1);
249        let state = self.buckets.entry(key).or_default();
250        state.push(residual, window_size);
251        ConformalUpdate {
252            residual,
253            bucket: key,
254            sample_count: state.residuals.len(),
255        }
256    }
257
258    /// Predict a conservative upper bound for frame time.
259    pub fn predict(&self, key: BucketKey, y_hat_us: f64, budget_us: f64) -> ConformalPrediction {
260        let QuantileDecision {
261            quantile,
262            sample_count,
263            fallback_level,
264        } = self.quantile_for(key);
265
266        let upper_us = y_hat_us + quantile.max(0.0);
267        let risk = upper_us > budget_us;
268
269        ConformalPrediction {
270            upper_us,
271            risk,
272            confidence: 1.0 - self.config.alpha,
273            bucket: key,
274            sample_count,
275            quantile,
276            fallback_level,
277            window_size: self.config.window_size,
278            reset_count: self.reset_count,
279            y_hat: y_hat_us,
280            budget_us,
281        }
282    }
283
284    fn quantile_for(&self, key: BucketKey) -> QuantileDecision {
285        let min_samples = self.config.min_samples.max(1);
286
287        let exact = self.collect_exact(key);
288        if exact.len() >= min_samples {
289            return QuantileDecision::new(self.config.alpha, exact, 0);
290        }
291
292        let mode_diff = self.collect_mode_diff(key.mode, key.diff);
293        if mode_diff.len() >= min_samples {
294            return QuantileDecision::new(self.config.alpha, mode_diff, 1);
295        }
296
297        let mode_only = self.collect_mode(key.mode);
298        if mode_only.len() >= min_samples {
299            return QuantileDecision::new(self.config.alpha, mode_only, 2);
300        }
301
302        let global = self.collect_all();
303        if !global.is_empty() {
304            return QuantileDecision::new(self.config.alpha, global, 3);
305        }
306
307        QuantileDecision {
308            quantile: self.config.q_default,
309            sample_count: 0,
310            fallback_level: 3,
311        }
312    }
313
314    fn collect_exact(&self, key: BucketKey) -> Vec<f64> {
315        self.buckets
316            .get(&key)
317            .map(|state| state.residuals.iter().copied().collect())
318            .unwrap_or_default()
319    }
320
321    fn collect_mode_diff(&self, mode: ModeBucket, diff: DiffBucket) -> Vec<f64> {
322        let mut values = Vec::new();
323        for (key, state) in &self.buckets {
324            if key.mode == mode && key.diff == diff {
325                values.extend(state.residuals.iter().copied());
326            }
327        }
328        values
329    }
330
331    fn collect_mode(&self, mode: ModeBucket) -> Vec<f64> {
332        let mut values = Vec::new();
333        for (key, state) in &self.buckets {
334            if key.mode == mode {
335                values.extend(state.residuals.iter().copied());
336            }
337        }
338        values
339    }
340
341    fn collect_all(&self) -> Vec<f64> {
342        let mut values = Vec::new();
343        for state in self.buckets.values() {
344            values.extend(state.residuals.iter().copied());
345        }
346        values
347    }
348}
349
/// Outcome of the fallback search performed for a prediction.
#[derive(Debug)]
struct QuantileDecision {
    /// Conformal quantile of the chosen residual pool, or `q_default`.
    quantile: f64,
    /// Residuals in the chosen pool (0 when `q_default` was used).
    sample_count: usize,
    /// Hierarchy level that supplied the pool (0 = exact .. 3 = global/default).
    fallback_level: u8,
}
356
357impl QuantileDecision {
358    fn new(alpha: f64, mut residuals: Vec<f64>, fallback_level: u8) -> Self {
359        let quantile = conformal_quantile(alpha, &mut residuals);
360        Self {
361            quantile,
362            sample_count: residuals.len(),
363            fallback_level,
364        }
365    }
366}
367
/// Split-conformal quantile: the `ceil((n + 1) * (1 - alpha))`-th smallest residual.
///
/// Returns 0.0 for an empty slice. When the rank exceeds `n`, it is clamped and
/// the largest residual is returned. A NaN `alpha` is treated as 0, which also
/// yields the largest residual. `residuals` is sorted in place. NaN residuals
/// are not expected (`ConformalPredictor::observe` drops non-finite residuals).
fn conformal_quantile(alpha: f64, residuals: &mut [f64]) -> f64 {
    if residuals.is_empty() {
        return 0.0;
    }
    residuals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let n = residuals.len();
    // A NaN alpha makes the rank NaN, which `as usize` turns into 0, selecting
    // the *smallest* residual (the least conservative bound). Fall back to full
    // coverage instead.
    let coverage = if alpha.is_nan() { 1.0 } else { 1.0 - alpha };
    let rank = ((n as f64 + 1.0) * coverage).ceil() as usize;
    let idx = rank.saturating_sub(1).min(n - 1);
    residuals[idx]
}
378
/// Log2 size class of a terminal: `floor(log2(cols * rows))`, with 0 for an
/// empty area. The `u16 * u16` product always fits in `u32`.
fn size_bucket(cols: u16, rows: u16) -> u8 {
    match u32::from(cols) * u32::from(rows) {
        0 => 0,
        area => (31 - area.leading_zeros()) as u8,
    }
}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Spell out every config knob so each test states its own calibration setup.
    fn config(alpha: f64, min_samples: usize, window_size: usize, q_default: f64) -> ConformalConfig {
        ConformalConfig {
            alpha,
            min_samples,
            window_size,
            q_default,
        }
    }

    /// Key for an inline (`ui_height: 4`) frame rendered with the full diff strategy.
    fn test_key(cols: u16, rows: u16) -> BucketKey {
        BucketKey::from_context(
            ScreenMode::Inline { ui_height: 4 },
            DiffStrategy::Full,
            cols,
            rows,
        )
    }

    /// Record each value against a zero base prediction, so residual == value.
    fn observe_values(predictor: &mut ConformalPredictor, key: BucketKey, values: &[f64]) {
        for &observed in values {
            predictor.observe(key, 0.0, observed);
        }
    }

    #[test]
    fn quantile_n_plus_1_rule() {
        let mut predictor = ConformalPredictor::new(config(0.2, 1, 10, 0.0));
        let key = test_key(80, 24);
        observe_values(&mut predictor, key, &[1.0, 2.0, 3.0]);

        // ceil((n + 1) * (1 - alpha)) = ceil(3.2) = 4 > n, so the largest residual wins.
        let prediction = predictor.predict(key, 0.0, 1_000.0);
        assert_eq!(prediction.quantile, 3.0);
    }

    #[test]
    fn fallback_hierarchy_mode_diff() {
        let mut predictor = ConformalPredictor::new(config(0.1, 4, 16, 0.0));
        observe_values(&mut predictor, test_key(80, 24), &[1.0, 2.0, 3.0, 4.0]);

        // A different size has no exact samples but shares mode + diff.
        let prediction = predictor.predict(test_key(120, 40), 0.0, 1_000.0);
        assert_eq!(prediction.fallback_level, 1);
        assert_eq!(prediction.sample_count, 4);
    }

    #[test]
    fn fallback_hierarchy_mode_only() {
        let mut predictor = ConformalPredictor::new(config(0.1, 3, 16, 0.0));

        let dirty_key = BucketKey::from_context(
            ScreenMode::Inline { ui_height: 4 },
            DiffStrategy::DirtyRows,
            80,
            24,
        );
        observe_values(&mut predictor, dirty_key, &[10.0, 20.0, 30.0]);

        // Same mode, different diff strategy and size: only the mode-level pool matches.
        let full_key = BucketKey::from_context(
            ScreenMode::Inline { ui_height: 4 },
            DiffStrategy::Full,
            120,
            40,
        );
        let prediction = predictor.predict(full_key, 0.0, 1_000.0);
        assert_eq!(prediction.fallback_level, 2);
        assert_eq!(prediction.sample_count, 3);
    }

    #[test]
    fn window_enforced() {
        let mut predictor = ConformalPredictor::new(config(0.1, 1, 3, 0.0));
        let key = test_key(80, 24);
        observe_values(&mut predictor, key, &[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(predictor.bucket_samples(key), 3);
    }

    #[test]
    fn predict_uses_default_when_empty() {
        let predictor = ConformalPredictor::new(config(0.1, 2, 4, 42.0));
        let prediction = predictor.predict(test_key(120, 40), 5.0, 10_000.0);
        assert_eq!(prediction.quantile, 42.0);
        assert_eq!(prediction.sample_count, 0);
        assert_eq!(prediction.fallback_level, 3);
    }

    #[test]
    fn bucket_isolation_by_size() {
        let mut predictor = ConformalPredictor::new(config(0.2, 2, 10, 0.0));
        observe_values(&mut predictor, test_key(40, 10), &[1.0, 2.0]);

        let large = test_key(200, 60);
        observe_values(&mut predictor, large, &[10.0, 12.0]);

        let prediction = predictor.predict(large, 0.0, 1_000.0);
        assert_eq!(prediction.fallback_level, 0);
        assert_eq!(prediction.sample_count, 2);
        assert_eq!(prediction.quantile, 12.0);
    }

    #[test]
    fn reset_clears_bucket_and_raises_reset_count() {
        let mut predictor = ConformalPredictor::new(config(0.1, 1, 8, 7.0));
        let key = test_key(80, 24);
        observe_values(&mut predictor, key, &[3.0]);
        assert_eq!(predictor.bucket_samples(key), 1);

        predictor.reset_bucket(key);
        assert_eq!(predictor.bucket_samples(key), 0);

        let prediction = predictor.predict(key, 0.0, 1_000.0);
        assert_eq!(prediction.quantile, 7.0);
        assert_eq!(prediction.reset_count, 1);
    }

    #[test]
    fn reset_all_forces_conservative_fallback() {
        let mut predictor = ConformalPredictor::new(config(0.1, 1, 8, 9.0));
        let key = test_key(80, 24);
        observe_values(&mut predictor, key, &[2.0]);

        predictor.reset_all();
        let prediction = predictor.predict(key, 0.0, 1_000.0);
        assert_eq!(prediction.quantile, 9.0);
        assert_eq!(prediction.sample_count, 0);
        assert_eq!(prediction.fallback_level, 3);
        assert_eq!(prediction.reset_count, 1);
    }

    #[test]
    fn size_bucket_log2_area() {
        assert_eq!(size_bucket(8, 8), 6); // area 64 = 2^6
        assert_eq!(size_bucket(8, 16), 7); // area 128 = 2^7
    }
}