1#![forbid(unsafe_code)]
2
3use std::collections::{HashMap, VecDeque};
13use std::fmt;
14
15use ftui_render::diff_strategy::DiffStrategy;
16
17use crate::terminal_writer::ScreenMode;
18
/// Tuning knobs for [`ConformalPredictor`].
#[derive(Debug, Clone)]
pub struct ConformalConfig {
    /// Miscoverage level; predictions report `1 - alpha` as their confidence.
    pub alpha: f64,

    /// Residuals a fallback level must hold before its quantile is used
    /// (the predictor treats values below 1 as 1).
    pub min_samples: usize,

    /// Maximum residuals retained per bucket; the oldest are evicted first
    /// (the predictor treats 0 as 1 when recording).
    pub window_size: usize,

    /// Quantile used when no residuals have been recorded in any bucket.
    pub q_default: f64,
}

impl Default for ConformalConfig {
    /// `alpha = 0.05` (confidence 0.95), 20 minimum samples, a 256-residual
    /// window and a large default quantile of 10 000.
    fn default() -> Self {
        let alpha = 0.05;
        let min_samples = 20;
        let window_size = 256;
        let q_default = 10_000.0;
        Self {
            alpha,
            min_samples,
            window_size,
            q_default,
        }
    }
}
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
52pub struct BucketKey {
53 pub mode: ModeBucket,
54 pub diff: DiffBucket,
55 pub size_bucket: u8,
56}
57
58impl BucketKey {
59 pub fn from_context(
61 screen_mode: ScreenMode,
62 diff_strategy: DiffStrategy,
63 cols: u16,
64 rows: u16,
65 ) -> Self {
66 Self {
67 mode: ModeBucket::from_screen_mode(screen_mode),
68 diff: DiffBucket::from(diff_strategy),
69 size_bucket: size_bucket(cols, rows),
70 }
71 }
72}
73
74#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
76pub enum ModeBucket {
77 Inline,
78 InlineAuto,
79 AltScreen,
80}
81
82impl ModeBucket {
83 pub fn as_str(self) -> &'static str {
84 match self {
85 Self::Inline => "inline",
86 Self::InlineAuto => "inline_auto",
87 Self::AltScreen => "altscreen",
88 }
89 }
90
91 pub fn from_screen_mode(mode: ScreenMode) -> Self {
92 match mode {
93 ScreenMode::Inline { .. } => Self::Inline,
94 ScreenMode::InlineAuto { .. } => Self::InlineAuto,
95 ScreenMode::AltScreen => Self::AltScreen,
96 }
97 }
98}
99
/// Diff-strategy component of a [`BucketKey`], one variant per `DiffStrategy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiffBucket {
    /// `DiffStrategy::Full`.
    Full,
    /// `DiffStrategy::DirtyRows`.
    DirtyRows,
    /// `DiffStrategy::FullRedraw`.
    FullRedraw,
}

impl DiffBucket {
    /// Short label; the middle segment of `BucketKey`'s `Display` output.
    pub fn as_str(self) -> &'static str {
        match self {
            DiffBucket::FullRedraw => "redraw",
            DiffBucket::DirtyRows => "dirty",
            DiffBucket::Full => "full",
        }
    }
}
117
118impl From<DiffStrategy> for DiffBucket {
119 fn from(strategy: DiffStrategy) -> Self {
120 match strategy {
121 DiffStrategy::Full => Self::Full,
122 DiffStrategy::DirtyRows => Self::DirtyRows,
123 DiffStrategy::FullRedraw => Self::FullRedraw,
124 }
125 }
126}
127
128impl fmt::Display for BucketKey {
129 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130 write!(
131 f,
132 "{}:{}:{}",
133 self.mode.as_str(),
134 self.diff.as_str(),
135 self.size_bucket
136 )
137 }
138}
139
140#[derive(Debug, Clone)]
142pub struct ConformalPrediction {
143 pub upper_us: f64,
145 pub risk: bool,
147 pub confidence: f64,
149 pub bucket: BucketKey,
151 pub sample_count: usize,
153 pub quantile: f64,
155 pub fallback_level: u8,
157 pub window_size: usize,
159 pub reset_count: u64,
161 pub y_hat: f64,
163 pub budget_us: f64,
165}
166
167impl ConformalPrediction {
168 #[must_use]
170 pub fn to_jsonl(&self) -> String {
171 format!(
172 r#"{{"schema":"conformal-v1","upper_us":{:.1},"risk":{},"confidence":{:.4},"bucket":"{}","samples":{},"quantile":{:.2},"fallback_level":{},"window":{},"resets":{},"y_hat":{:.1},"budget_us":{:.1}}}"#,
173 self.upper_us,
174 self.risk,
175 self.confidence,
176 self.bucket,
177 self.sample_count,
178 self.quantile,
179 self.fallback_level,
180 self.window_size,
181 self.reset_count,
182 self.y_hat,
183 self.budget_us,
184 )
185 }
186}
187
/// Result of recording one observation with `ConformalPredictor::observe`.
#[derive(Debug, Clone)]
pub struct ConformalUpdate {
    /// `observed_us - y_hat_us`. If this is not finite it was not stored.
    pub residual: f64,
    /// Bucket the observation was filed under.
    pub bucket: BucketKey,
    /// Residuals held in that bucket after the update.
    pub sample_count: usize,
}
198
/// Rolling window of residuals for one bucket, oldest at the front.
#[derive(Debug, Default)]
struct BucketState {
    residuals: VecDeque<f64>,
}

impl BucketState {
    /// Appends `residual`, then evicts the oldest entries so that at most
    /// `window_size` remain.
    fn push(&mut self, residual: f64, window_size: usize) {
        self.residuals.push_back(residual);
        let overflow = self.residuals.len().saturating_sub(window_size);
        if overflow > 0 {
            self.residuals.drain(..overflow);
        }
    }
}
212
/// Bucketed conformal predictor that widens a point estimate into an upper
/// bound using residuals recorded per [`BucketKey`].
#[derive(Debug)]
pub struct ConformalPredictor {
    config: ConformalConfig,
    // Rolling residual window for each bucket that has been observed.
    buckets: HashMap<BucketKey, BucketState>,
    // Incremented by reset_all and by reset_bucket on a known bucket;
    // surfaced in every prediction.
    reset_count: u64,
}
220
221impl ConformalPredictor {
222 pub fn new(config: ConformalConfig) -> Self {
224 Self {
225 config,
226 buckets: HashMap::new(),
227 reset_count: 0,
228 }
229 }
230
231 pub fn config(&self) -> &ConformalConfig {
233 &self.config
234 }
235
236 pub fn bucket_samples(&self, key: BucketKey) -> usize {
238 self.buckets
239 .get(&key)
240 .map(|state| state.residuals.len())
241 .unwrap_or(0)
242 }
243
244 pub fn reset_all(&mut self) {
246 self.buckets.clear();
247 self.reset_count += 1;
248 }
249
250 pub fn reset_bucket(&mut self, key: BucketKey) {
252 if let Some(state) = self.buckets.get_mut(&key) {
253 state.residuals.clear();
254 self.reset_count += 1;
255 }
256 }
257
258 pub fn observe(&mut self, key: BucketKey, y_hat_us: f64, observed_us: f64) -> ConformalUpdate {
260 let residual = observed_us - y_hat_us;
261 if !residual.is_finite() {
262 return ConformalUpdate {
263 residual,
264 bucket: key,
265 sample_count: self.bucket_samples(key),
266 };
267 }
268
269 let window_size = self.config.window_size.max(1);
270 let state = self.buckets.entry(key).or_default();
271 state.push(residual, window_size);
272 ConformalUpdate {
273 residual,
274 bucket: key,
275 sample_count: state.residuals.len(),
276 }
277 }
278
279 pub fn predict(&self, key: BucketKey, y_hat_us: f64, budget_us: f64) -> ConformalPrediction {
281 let QuantileDecision {
282 quantile,
283 sample_count,
284 fallback_level,
285 } = self.quantile_for(key);
286
287 let upper_us = y_hat_us + quantile.max(0.0);
288 let risk = upper_us > budget_us;
289
290 ConformalPrediction {
291 upper_us,
292 risk,
293 confidence: 1.0 - self.config.alpha,
294 bucket: key,
295 sample_count,
296 quantile,
297 fallback_level,
298 window_size: self.config.window_size,
299 reset_count: self.reset_count,
300 y_hat: y_hat_us,
301 budget_us,
302 }
303 }
304
305 fn quantile_for(&self, key: BucketKey) -> QuantileDecision {
306 let min_samples = self.config.min_samples.max(1);
307
308 let exact = self.collect_exact(key);
309 if exact.len() >= min_samples {
310 return QuantileDecision::new(self.config.alpha, exact, 0);
311 }
312
313 let mode_diff = self.collect_mode_diff(key.mode, key.diff);
314 if mode_diff.len() >= min_samples {
315 return QuantileDecision::new(self.config.alpha, mode_diff, 1);
316 }
317
318 let mode_only = self.collect_mode(key.mode);
319 if mode_only.len() >= min_samples {
320 return QuantileDecision::new(self.config.alpha, mode_only, 2);
321 }
322
323 let global = self.collect_all();
324 if !global.is_empty() {
325 return QuantileDecision::new(self.config.alpha, global, 3);
326 }
327
328 QuantileDecision {
329 quantile: self.config.q_default,
330 sample_count: 0,
331 fallback_level: 3,
332 }
333 }
334
335 fn collect_exact(&self, key: BucketKey) -> Vec<f64> {
336 self.buckets
337 .get(&key)
338 .map(|state| state.residuals.iter().copied().collect())
339 .unwrap_or_default()
340 }
341
342 fn collect_mode_diff(&self, mode: ModeBucket, diff: DiffBucket) -> Vec<f64> {
343 let mut values = Vec::new();
344 for (key, state) in &self.buckets {
345 if key.mode == mode && key.diff == diff {
346 values.extend(state.residuals.iter().copied());
347 }
348 }
349 values
350 }
351
352 fn collect_mode(&self, mode: ModeBucket) -> Vec<f64> {
353 let mut values = Vec::new();
354 for (key, state) in &self.buckets {
355 if key.mode == mode {
356 values.extend(state.residuals.iter().copied());
357 }
358 }
359 values
360 }
361
362 fn collect_all(&self) -> Vec<f64> {
363 let mut values = Vec::new();
364 for state in self.buckets.values() {
365 values.extend(state.residuals.iter().copied());
366 }
367 values
368 }
369}
370
371#[derive(Debug)]
372struct QuantileDecision {
373 quantile: f64,
374 sample_count: usize,
375 fallback_level: u8,
376}
377
378impl QuantileDecision {
379 fn new(alpha: f64, mut residuals: Vec<f64>, fallback_level: u8) -> Self {
380 let quantile = conformal_quantile(alpha, &mut residuals);
381 Self {
382 quantile,
383 sample_count: residuals.len(),
384 fallback_level,
385 }
386 }
387}
388
/// Split-conformal quantile of `residuals` at miscoverage level `alpha`.
///
/// Returns the residual of 1-based ascending rank `ceil((n + 1) * (1 - alpha))`.
/// When that rank exceeds `n`, it is clamped to `n` and the largest residual
/// is returned. Returns `0.0` for an empty slice.
///
/// The slice is reordered in place (partially, by selection); callers must not
/// rely on its order afterwards.
fn conformal_quantile(alpha: f64, residuals: &mut [f64]) -> f64 {
    let n = residuals.len();
    if n == 0 {
        return 0.0;
    }
    // `as usize` saturates, so a negative or NaN product maps to rank 0.
    let rank = ((n as f64 + 1.0) * (1.0 - alpha)).ceil() as usize;
    let idx = rank.saturating_sub(1).min(n - 1);
    // `total_cmp` is a true total order (positive NaN sorts above +inf), so
    // the selection is well-defined even if a non-finite value slips in.
    // Selection is O(n) on average versus O(n log n) for a full sort.
    *residuals.select_nth_unstable_by(idx, f64::total_cmp).1
}
399
/// Coarse size class of a `cols` x `rows` area: `floor(log2(cols * rows))`.
/// An empty area maps to 0, the same class as a 1x1 area.
fn size_bucket(cols: u16, rows: u16) -> u8 {
    // Widening first keeps the product (at most (2^16 - 1)^2) within u32.
    let cells = u32::from(cols) * u32::from(rows);
    match cells.checked_ilog2() {
        Some(exponent) => exponent as u8,
        None => 0,
    }
}
407
// Unit tests: quantile rank rule, fallback hierarchy, windowing, resets,
// size bucketing, labels and prediction metadata.
#[cfg(test)]
mod tests {
    use super::*;

    // Inline-mode, Full-diff key; only the terminal size varies between calls.
    fn test_key(cols: u16, rows: u16) -> BucketKey {
        BucketKey::from_context(
            ScreenMode::Inline { ui_height: 4 },
            DiffStrategy::Full,
            cols,
            rows,
        )
    }

    // n = 3, alpha = 0.2: rank = ceil(4 * 0.8) = 4 > n, so the rank is
    // clamped and the largest residual (3.0) is used.
    #[test]
    fn quantile_n_plus_1_rule() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.2,
            min_samples: 1,
            window_size: 10,
            q_default: 0.0,
        });

        let key = test_key(80, 24);
        predictor.observe(key, 0.0, 1.0);
        predictor.observe(key, 0.0, 2.0);
        predictor.observe(key, 0.0, 3.0);

        let decision = predictor.predict(key, 0.0, 1_000.0);
        assert_eq!(decision.quantile, 3.0);
    }

    // 120x40 shares mode and diff with 80x24 but has a different size bucket,
    // so its exact bucket is empty and level 1 supplies the 4 samples.
    #[test]
    fn fallback_hierarchy_mode_diff() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 4,
            window_size: 16,
            q_default: 0.0,
        });

        let key_a = test_key(80, 24);
        for value in [1.0, 2.0, 3.0, 4.0] {
            predictor.observe(key_a, 0.0, value);
        }

        let key_b = test_key(120, 40);
        let decision = predictor.predict(key_b, 0.0, 1_000.0);
        assert_eq!(decision.fallback_level, 1);
        assert_eq!(decision.sample_count, 4);
    }

    // Only DirtyRows residuals exist in Inline mode; a Full-diff key in the
    // same mode falls through to level 2 (mode only).
    #[test]
    fn fallback_hierarchy_mode_only() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 3,
            window_size: 16,
            q_default: 0.0,
        });

        let key_dirty = BucketKey::from_context(
            ScreenMode::Inline { ui_height: 4 },
            DiffStrategy::DirtyRows,
            80,
            24,
        );
        for value in [10.0, 20.0, 30.0] {
            predictor.observe(key_dirty, 0.0, value);
        }

        let key_full = BucketKey::from_context(
            ScreenMode::Inline { ui_height: 4 },
            DiffStrategy::Full,
            120,
            40,
        );
        let decision = predictor.predict(key_full, 0.0, 1_000.0);
        assert_eq!(decision.fallback_level, 2);
        assert_eq!(decision.sample_count, 3);
    }

    // Five observations into a 3-wide window keep only the newest three.
    #[test]
    fn window_enforced() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 3,
            q_default: 0.0,
        });
        let key = test_key(80, 24);
        for value in [1.0, 2.0, 3.0, 4.0, 5.0] {
            predictor.observe(key, 0.0, value);
        }
        assert_eq!(predictor.bucket_samples(key), 3);
    }

    // With no residuals anywhere, q_default is used at level 3 with 0 samples.
    #[test]
    fn predict_uses_default_when_empty() {
        let predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 2,
            window_size: 4,
            q_default: 42.0,
        });
        let key = test_key(120, 40);
        let prediction = predictor.predict(key, 5.0, 10_000.0);
        assert_eq!(prediction.quantile, 42.0);
        assert_eq!(prediction.sample_count, 0);
        assert_eq!(prediction.fallback_level, 3);
    }

    // The large bucket has enough samples of its own, so the small bucket's
    // residuals are not mixed in. n = 2, alpha = 0.2 gives rank
    // ceil(3 * 0.8) = 3, which is clamped to the maximum (12.0).
    #[test]
    fn bucket_isolation_by_size() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.2,
            min_samples: 2,
            window_size: 10,
            q_default: 0.0,
        });

        let small = test_key(40, 10);
        predictor.observe(small, 0.0, 1.0);
        predictor.observe(small, 0.0, 2.0);

        let large = test_key(200, 60);
        predictor.observe(large, 0.0, 10.0);
        predictor.observe(large, 0.0, 12.0);

        let prediction = predictor.predict(large, 0.0, 1_000.0);
        assert_eq!(prediction.fallback_level, 0);
        assert_eq!(prediction.sample_count, 2);
        assert_eq!(prediction.quantile, 12.0);
    }

    // Resetting the only populated bucket empties every pool, so the default
    // quantile returns and the reset is counted.
    #[test]
    fn reset_clears_bucket_and_raises_reset_count() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 7.0,
        });
        let key = test_key(80, 24);
        predictor.observe(key, 0.0, 3.0);
        assert_eq!(predictor.bucket_samples(key), 1);

        predictor.reset_bucket(key);
        assert_eq!(predictor.bucket_samples(key), 0);

        let prediction = predictor.predict(key, 0.0, 1_000.0);
        assert_eq!(prediction.quantile, 7.0);
        assert_eq!(prediction.reset_count, 1);
    }

    // reset_all drops every bucket, so predictions fall back to q_default.
    #[test]
    fn reset_all_forces_conservative_fallback() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 9.0,
        });
        let key = test_key(80, 24);
        predictor.observe(key, 0.0, 2.0);

        predictor.reset_all();
        let prediction = predictor.predict(key, 0.0, 1_000.0);
        assert_eq!(prediction.quantile, 9.0);
        assert_eq!(prediction.sample_count, 0);
        assert_eq!(prediction.fallback_level, 3);
        assert_eq!(prediction.reset_count, 1);
    }

    // 64 cells -> 2^6, 128 cells -> 2^7.
    #[test]
    fn size_bucket_log2_area() {
        let a = size_bucket(8, 8);
        let b = size_bucket(8, 16);
        assert_eq!(a, 6);
        assert_eq!(b, 7);
    }

    #[test]
    fn size_bucket_zero_area() {
        assert_eq!(size_bucket(0, 0), 0);
        assert_eq!(size_bucket(0, 24), 0);
        assert_eq!(size_bucket(80, 0), 0);
    }

    // log2(1) = 0, the same class as an empty area.
    #[test]
    fn size_bucket_one_by_one() {
        assert_eq!(size_bucket(1, 1), 0);
    }

    // 1920 cells -> floor(log2) = 10; 4800 cells -> 12.
    #[test]
    fn size_bucket_typical_terminals() {
        let b80 = size_bucket(80, 24);
        let b120 = size_bucket(120, 40);
        assert_eq!(b80, 10);
        assert_eq!(b120, 12);
    }

    #[test]
    fn conformal_quantile_empty() {
        let mut data: Vec<f64> = vec![];
        assert_eq!(conformal_quantile(0.1, &mut data), 0.0);
    }

    #[test]
    fn conformal_quantile_single_element() {
        let mut data = vec![42.0];
        assert_eq!(conformal_quantile(0.1, &mut data), 42.0);
    }

    // n = 5, alpha = 0.5: rank = ceil(6 * 0.5) = 3 -> third smallest.
    #[test]
    fn conformal_quantile_sorted_data() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let q = conformal_quantile(0.5, &mut data);
        assert_eq!(q, 3.0);
    }

    // n = 4, alpha = 0.5: rank = ceil(5 * 0.5) = 3 -> 30.0.
    #[test]
    fn conformal_quantile_alpha_half() {
        let mut data = vec![10.0, 20.0, 30.0, 40.0];
        let q = conformal_quantile(0.5, &mut data);
        assert_eq!(q, 30.0);
    }

    #[test]
    fn mode_bucket_as_str_all_variants() {
        assert_eq!(ModeBucket::Inline.as_str(), "inline");
        assert_eq!(ModeBucket::InlineAuto.as_str(), "inline_auto");
        assert_eq!(ModeBucket::AltScreen.as_str(), "altscreen");
    }

    #[test]
    fn diff_bucket_as_str_all_variants() {
        assert_eq!(DiffBucket::Full.as_str(), "full");
        assert_eq!(DiffBucket::DirtyRows.as_str(), "dirty");
        assert_eq!(DiffBucket::FullRedraw.as_str(), "redraw");
    }

    #[test]
    fn diff_bucket_from_strategy() {
        assert_eq!(DiffBucket::from(DiffStrategy::Full), DiffBucket::Full);
        assert_eq!(
            DiffBucket::from(DiffStrategy::DirtyRows),
            DiffBucket::DirtyRows
        );
        assert_eq!(
            DiffBucket::from(DiffStrategy::FullRedraw),
            DiffBucket::FullRedraw
        );
    }

    #[test]
    fn bucket_key_display_format() {
        let key = BucketKey {
            mode: ModeBucket::AltScreen,
            diff: DiffBucket::DirtyRows,
            size_bucket: 12,
        };
        assert_eq!(format!("{key}"), "altscreen:dirty:12");
    }

    // A NaN residual is reported back but never stored.
    #[test]
    fn observe_nan_residual_not_stored() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 5.0,
        });
        let key = test_key(80, 24);
        let update = predictor.observe(key, 0.0, f64::NAN);
        assert!(!update.residual.is_finite());
        assert_eq!(predictor.bucket_samples(key), 0);
    }

    #[test]
    fn observe_infinity_residual_not_stored() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 5.0,
        });
        let key = test_key(80, 24);
        predictor.observe(key, 0.0, f64::INFINITY);
        assert_eq!(predictor.bucket_samples(key), 0);
    }

    // Upper bound 0 + 50 fits a budget of 100 but exceeds a budget of 30.
    #[test]
    fn prediction_risk_flag() {
        let predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 50.0,
        });
        let key = test_key(80, 24);
        let p = predictor.predict(key, 0.0, 100.0);
        assert!(!p.risk);
        let p2 = predictor.predict(key, 0.0, 30.0);
        assert!(p2.risk);
    }

    #[test]
    fn prediction_confidence() {
        let predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.05,
            min_samples: 1,
            window_size: 8,
            q_default: 0.0,
        });
        let key = test_key(80, 24);
        let p = predictor.predict(key, 0.0, 100.0);
        assert!((p.confidence - 0.95).abs() < 1e-10);
    }

    // min_samples is unreachable, but the single AltScreen residual is still
    // used at level 3 because the global pool is non-empty.
    #[test]
    fn global_fallback_with_data() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 100,
            window_size: 256,
            q_default: 999.0,
        });
        let alt_key = BucketKey::from_context(ScreenMode::AltScreen, DiffStrategy::Full, 80, 24);
        predictor.observe(alt_key, 0.0, 5.0);

        let inline_key = test_key(80, 24);
        let p = predictor.predict(inline_key, 0.0, 1000.0);
        assert_eq!(p.fallback_level, 3);
        assert_eq!(p.sample_count, 1);
        assert_eq!(p.quantile, 5.0);
    }

    #[test]
    fn mode_bucket_from_screen_modes() {
        assert_eq!(
            ModeBucket::from_screen_mode(ScreenMode::Inline { ui_height: 4 }),
            ModeBucket::Inline
        );
        assert_eq!(
            ModeBucket::from_screen_mode(ScreenMode::InlineAuto {
                min_height: 4,
                max_height: 24
            }),
            ModeBucket::InlineAuto
        );
        assert_eq!(
            ModeBucket::from_screen_mode(ScreenMode::AltScreen),
            ModeBucket::AltScreen
        );
    }

    #[test]
    fn config_defaults() {
        let config = ConformalConfig::default();
        assert!((config.alpha - 0.05).abs() < 1e-10);
        assert_eq!(config.min_samples, 20);
        assert_eq!(config.window_size, 256);
        assert!((config.q_default - 10_000.0).abs() < 1e-10);
    }

    #[test]
    fn predictor_config_accessor() {
        let config = ConformalConfig {
            alpha: 0.2,
            min_samples: 5,
            window_size: 32,
            q_default: 100.0,
        };
        let predictor = ConformalPredictor::new(config);
        assert!((predictor.config().alpha - 0.2).abs() < 1e-10);
        assert_eq!(predictor.config().min_samples, 5);
    }

    // Residual 5 - 10 = -5 gives quantile -5; predict clamps it to 0, so the
    // upper bound equals the estimate.
    #[test]
    fn negative_residual_clamped_in_prediction() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 0.0,
        });
        let key = test_key(80, 24);
        predictor.observe(key, 10.0, 5.0);
        let p = predictor.predict(key, 10.0, 100.0);
        assert_eq!(p.upper_us, 10.0);
    }

    #[test]
    fn observe_returns_correct_update() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 0.0,
        });
        let key = test_key(80, 24);
        let update = predictor.observe(key, 3.0, 10.0);
        assert!((update.residual - 7.0).abs() < 1e-10);
        assert_eq!(update.bucket, key);
        assert_eq!(update.sample_count, 1);
    }

    #[test]
    fn prediction_preserves_yhat_and_budget() {
        let predictor = ConformalPredictor::new(ConformalConfig::default());
        let key = test_key(80, 24);
        let p = predictor.predict(key, 42.5, 16666.0);
        assert!((p.y_hat - 42.5).abs() < 1e-10);
        assert!((p.budget_us - 16666.0).abs() < 1e-10);
    }
}