1#![forbid(unsafe_code)]
2
3use std::collections::{HashMap, VecDeque};
13use std::fmt;
14
15use ftui_render::diff_strategy::DiffStrategy;
16
17use crate::terminal_writer::ScreenMode;
18
/// Tuning parameters for [`ConformalPredictor`].
#[derive(Debug, Clone)]
pub struct ConformalConfig {
    /// Miscoverage level. Predictions report a confidence of `1 - alpha`, and the
    /// quantile is read at 1-based rank `ceil((n + 1) * (1 - alpha))` of the sorted
    /// residuals.
    pub alpha: f64,

    /// Minimum number of residuals a candidate pool must hold before its quantile is
    /// used. Treated as at least 1 when predicting.
    pub min_samples: usize,

    /// Maximum number of residuals retained per bucket; the oldest are dropped first.
    /// Treated as at least 1 when an observation is recorded.
    pub window_size: usize,

    /// Quantile reported when no residuals have been recorded in any bucket.
    /// It is added to the point estimate, so it shares that estimate's unit.
    pub q_default: f64,
}
38
39impl Default for ConformalConfig {
40 fn default() -> Self {
41 Self {
42 alpha: 0.05,
43 min_samples: 20,
44 window_size: 256,
45 q_default: 10_000.0,
46 }
47 }
48}
49
/// Identifies one calibration bucket: residuals are stored and looked up per
/// (screen mode, diff strategy, size bucket) combination.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BucketKey {
    /// Screen-mode component of the key.
    pub mode: ModeBucket,
    /// Diff-strategy component of the key.
    pub diff: DiffBucket,
    /// Coarse terminal-size component; `from_context` fills it with
    /// `size_bucket(cols, rows)`, i.e. `floor(log2(cols * rows))` (0 for an empty area).
    pub size_bucket: u8,
}
57
58impl BucketKey {
59 pub fn from_context(
61 screen_mode: ScreenMode,
62 diff_strategy: DiffStrategy,
63 cols: u16,
64 rows: u16,
65 ) -> Self {
66 Self {
67 mode: ModeBucket::from_screen_mode(screen_mode),
68 diff: DiffBucket::from(diff_strategy),
69 size_bucket: size_bucket(cols, rows),
70 }
71 }
72}
73
/// Screen-mode component of a [`BucketKey`]; one variant per `ScreenMode` variant,
/// with any payload of the screen mode discarded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModeBucket {
    /// Corresponds to `ScreenMode::Inline { .. }`.
    Inline,
    /// Corresponds to `ScreenMode::InlineAuto { .. }`.
    InlineAuto,
    /// Corresponds to `ScreenMode::AltScreen`.
    AltScreen,
}
81
82impl ModeBucket {
83 pub fn as_str(self) -> &'static str {
84 match self {
85 Self::Inline => "inline",
86 Self::InlineAuto => "inline_auto",
87 Self::AltScreen => "altscreen",
88 }
89 }
90
91 pub fn from_screen_mode(mode: ScreenMode) -> Self {
92 match mode {
93 ScreenMode::Inline { .. } => Self::Inline,
94 ScreenMode::InlineAuto { .. } => Self::InlineAuto,
95 ScreenMode::AltScreen => Self::AltScreen,
96 }
97 }
98}
99
/// Diff-strategy component of a [`BucketKey`]; one variant per `DiffStrategy` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiffBucket {
    /// Corresponds to `DiffStrategy::Full`.
    Full,
    /// Corresponds to `DiffStrategy::DirtyRows`.
    DirtyRows,
    /// Corresponds to `DiffStrategy::FullRedraw`.
    FullRedraw,
}
107
108impl DiffBucket {
109 pub fn as_str(self) -> &'static str {
110 match self {
111 Self::Full => "full",
112 Self::DirtyRows => "dirty",
113 Self::FullRedraw => "redraw",
114 }
115 }
116}
117
118impl From<DiffStrategy> for DiffBucket {
119 fn from(strategy: DiffStrategy) -> Self {
120 match strategy {
121 DiffStrategy::Full => Self::Full,
122 DiffStrategy::DirtyRows => Self::DirtyRows,
123 DiffStrategy::FullRedraw => Self::FullRedraw,
124 }
125 }
126}
127
128impl fmt::Display for BucketKey {
129 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130 write!(
131 f,
132 "{}:{}:{}",
133 self.mode.as_str(),
134 self.diff.as_str(),
135 self.size_bucket
136 )
137 }
138}
139
/// Result of [`ConformalPredictor::predict`]: the upper bound together with the
/// inputs and bookkeeping that produced it.
///
/// The `_us` suffix presumably denotes microseconds; nothing in this file fixes the
/// unit, only that estimate, quantile and budget share one.
#[derive(Debug, Clone)]
pub struct ConformalPrediction {
    /// Upper bound: `y_hat + max(quantile, 0.0)`.
    pub upper_us: f64,
    /// `true` when `upper_us` is strictly greater than `budget_us`.
    pub risk: bool,
    /// Nominal coverage, `1 - alpha` from the predictor's configuration.
    pub confidence: f64,
    /// Bucket the prediction was requested for.
    pub bucket: BucketKey,
    /// Number of residuals in the pool the quantile came from (0 when the default was used).
    pub sample_count: usize,
    /// Raw quantile of the chosen pool (may be negative), or `q_default` when no
    /// residuals exist.
    pub quantile: f64,
    /// Which pool supplied the quantile: 0 = exact bucket, 1 = same mode and diff,
    /// 2 = same mode, 3 = all buckets or the configured default.
    pub fallback_level: u8,
    /// Window size copied from the configuration as given (not clamped).
    pub window_size: usize,
    /// Value of the predictor's reset counter when the prediction was made.
    pub reset_count: u64,
    /// Point estimate that was passed to `predict`.
    pub y_hat: f64,
    /// Budget that was passed to `predict`.
    pub budget_us: f64,
}
166
/// Result of [`ConformalPredictor::observe`].
#[derive(Debug, Clone)]
pub struct ConformalUpdate {
    /// `observed - y_hat` for the observation; may be non-finite, in which case it
    /// was not stored.
    pub residual: f64,
    /// Bucket the observation was reported for.
    pub bucket: BucketKey,
    /// Number of residuals held by that bucket after the call.
    pub sample_count: usize,
}
177
/// Residual history of a single bucket.
#[derive(Debug, Default)]
struct BucketState {
    /// Recorded residuals in arrival order: new values are pushed at the back and
    /// evicted from the front.
    residuals: VecDeque<f64>,
}
182
183impl BucketState {
184 fn push(&mut self, residual: f64, window_size: usize) {
185 self.residuals.push_back(residual);
186 while self.residuals.len() > window_size {
187 self.residuals.pop_front();
188 }
189 }
190}
191
/// Conformal predictor that keeps a sliding window of residuals per [`BucketKey`]
/// and turns them into an upper bound on a point estimate.
#[derive(Debug)]
pub struct ConformalPredictor {
    /// Tuning parameters supplied at construction.
    config: ConformalConfig,
    /// Residual history per bucket; entries are created on the first finite observation.
    buckets: HashMap<BucketKey, BucketState>,
    /// Incremented by `reset_all`, and by `reset_bucket` when the bucket exists.
    reset_count: u64,
}
199
200impl ConformalPredictor {
201 pub fn new(config: ConformalConfig) -> Self {
203 Self {
204 config,
205 buckets: HashMap::new(),
206 reset_count: 0,
207 }
208 }
209
210 pub fn config(&self) -> &ConformalConfig {
212 &self.config
213 }
214
215 pub fn bucket_samples(&self, key: BucketKey) -> usize {
217 self.buckets
218 .get(&key)
219 .map(|state| state.residuals.len())
220 .unwrap_or(0)
221 }
222
223 pub fn reset_all(&mut self) {
225 self.buckets.clear();
226 self.reset_count += 1;
227 }
228
229 pub fn reset_bucket(&mut self, key: BucketKey) {
231 if let Some(state) = self.buckets.get_mut(&key) {
232 state.residuals.clear();
233 self.reset_count += 1;
234 }
235 }
236
237 pub fn observe(&mut self, key: BucketKey, y_hat_us: f64, observed_us: f64) -> ConformalUpdate {
239 let residual = observed_us - y_hat_us;
240 if !residual.is_finite() {
241 return ConformalUpdate {
242 residual,
243 bucket: key,
244 sample_count: self.bucket_samples(key),
245 };
246 }
247
248 let window_size = self.config.window_size.max(1);
249 let state = self.buckets.entry(key).or_default();
250 state.push(residual, window_size);
251 ConformalUpdate {
252 residual,
253 bucket: key,
254 sample_count: state.residuals.len(),
255 }
256 }
257
258 pub fn predict(&self, key: BucketKey, y_hat_us: f64, budget_us: f64) -> ConformalPrediction {
260 let QuantileDecision {
261 quantile,
262 sample_count,
263 fallback_level,
264 } = self.quantile_for(key);
265
266 let upper_us = y_hat_us + quantile.max(0.0);
267 let risk = upper_us > budget_us;
268
269 ConformalPrediction {
270 upper_us,
271 risk,
272 confidence: 1.0 - self.config.alpha,
273 bucket: key,
274 sample_count,
275 quantile,
276 fallback_level,
277 window_size: self.config.window_size,
278 reset_count: self.reset_count,
279 y_hat: y_hat_us,
280 budget_us,
281 }
282 }
283
284 fn quantile_for(&self, key: BucketKey) -> QuantileDecision {
285 let min_samples = self.config.min_samples.max(1);
286
287 let exact = self.collect_exact(key);
288 if exact.len() >= min_samples {
289 return QuantileDecision::new(self.config.alpha, exact, 0);
290 }
291
292 let mode_diff = self.collect_mode_diff(key.mode, key.diff);
293 if mode_diff.len() >= min_samples {
294 return QuantileDecision::new(self.config.alpha, mode_diff, 1);
295 }
296
297 let mode_only = self.collect_mode(key.mode);
298 if mode_only.len() >= min_samples {
299 return QuantileDecision::new(self.config.alpha, mode_only, 2);
300 }
301
302 let global = self.collect_all();
303 if !global.is_empty() {
304 return QuantileDecision::new(self.config.alpha, global, 3);
305 }
306
307 QuantileDecision {
308 quantile: self.config.q_default,
309 sample_count: 0,
310 fallback_level: 3,
311 }
312 }
313
314 fn collect_exact(&self, key: BucketKey) -> Vec<f64> {
315 self.buckets
316 .get(&key)
317 .map(|state| state.residuals.iter().copied().collect())
318 .unwrap_or_default()
319 }
320
321 fn collect_mode_diff(&self, mode: ModeBucket, diff: DiffBucket) -> Vec<f64> {
322 let mut values = Vec::new();
323 for (key, state) in &self.buckets {
324 if key.mode == mode && key.diff == diff {
325 values.extend(state.residuals.iter().copied());
326 }
327 }
328 values
329 }
330
331 fn collect_mode(&self, mode: ModeBucket) -> Vec<f64> {
332 let mut values = Vec::new();
333 for (key, state) in &self.buckets {
334 if key.mode == mode {
335 values.extend(state.residuals.iter().copied());
336 }
337 }
338 values
339 }
340
341 fn collect_all(&self) -> Vec<f64> {
342 let mut values = Vec::new();
343 for state in self.buckets.values() {
344 values.extend(state.residuals.iter().copied());
345 }
346 values
347 }
348}
349
/// Internal result of choosing a residual pool and computing its quantile.
#[derive(Debug)]
struct QuantileDecision {
    /// Quantile of the chosen pool, or the configured default when no pool had data.
    quantile: f64,
    /// Number of residuals in the chosen pool.
    sample_count: usize,
    /// Fallback level of the chosen pool (0 = exact bucket ... 3 = global or default).
    fallback_level: u8,
}
356
357impl QuantileDecision {
358 fn new(alpha: f64, mut residuals: Vec<f64>, fallback_level: u8) -> Self {
359 let quantile = conformal_quantile(alpha, &mut residuals);
360 Self {
361 quantile,
362 sample_count: residuals.len(),
363 fallback_level,
364 }
365 }
366}
367
/// Returns the conformal quantile of `residuals` at miscoverage level `alpha`.
///
/// The slice is sorted ascending in place and the element at 1-based rank
/// `ceil((n + 1) * (1 - alpha))` is returned. A rank beyond `n` yields the largest
/// element and a rank of zero yields the smallest. An empty slice yields `0.0`.
///
/// Ordering uses [`f64::total_cmp`], which is a true total order: a NaN cannot make the
/// comparator inconsistent, which could otherwise cause the standard sort to panic or
/// return an arbitrary order. Positive NaNs sort after every finite value.
fn conformal_quantile(alpha: f64, residuals: &mut [f64]) -> f64 {
    if residuals.is_empty() {
        return 0.0;
    }

    // Values that compare equal under `total_cmp` are bit-identical, so the unstable
    // (allocation-free) sort produces the same sequence a stable sort would.
    residuals.sort_unstable_by(f64::total_cmp);

    let n = residuals.len();
    // Float-to-integer `as` casts saturate (NaN and negatives become 0), so an
    // out-of-range `alpha` cannot produce an out-of-bounds index below.
    let rank = ((n as f64 + 1.0) * (1.0 - alpha)).ceil() as usize;
    let idx = rank.saturating_sub(1).min(n - 1);
    residuals[idx]
}
378
/// Maps terminal dimensions to a coarse size bucket: `floor(log2(cols * rows))`.
///
/// An empty area (either dimension zero) maps to bucket 0, as does a 1x1 area. The
/// product of two `u16` values always fits in a `u32`, so the result is at most 31.
fn size_bucket(cols: u16, rows: u16) -> u8 {
    let area = u32::from(cols) * u32::from(rows);
    match area {
        0 => 0,
        // Position of the highest set bit, i.e. floor(log2(cells)).
        cells => (31 - cells.leading_zeros()) as u8,
    }
}
386
// Unit tests for the conformal predictor: quantile rank rule, fallback hierarchy,
// window eviction, default quantile, resets, and size bucketing.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an inline / full-diff key for the given terminal size, so tests can vary
    /// only the size bucket.
    fn test_key(cols: u16, rows: u16) -> BucketKey {
        BucketKey::from_context(
            ScreenMode::Inline { ui_height: 4 },
            DiffStrategy::Full,
            cols,
            rows,
        )
    }

    /// With n = 3 and alpha = 0.2 the rank ceil((n + 1) * (1 - alpha)) = 4 exceeds n,
    /// so the quantile is the largest residual.
    #[test]
    fn quantile_n_plus_1_rule() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.2,
            min_samples: 1,
            window_size: 10,
            q_default: 0.0,
        });

        let key = test_key(80, 24);
        predictor.observe(key, 0.0, 1.0);
        predictor.observe(key, 0.0, 2.0);
        predictor.observe(key, 0.0, 3.0);

        let decision = predictor.predict(key, 0.0, 1_000.0);
        assert_eq!(decision.quantile, 3.0);
    }

    /// A key with the same mode and diff but a different size bucket falls back to
    /// the mode+diff pool (level 1).
    #[test]
    fn fallback_hierarchy_mode_diff() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 4,
            window_size: 16,
            q_default: 0.0,
        });

        let key_a = test_key(80, 24);
        for value in [1.0, 2.0, 3.0, 4.0] {
            predictor.observe(key_a, 0.0, value);
        }

        // 120x40 lands in a different size bucket than 80x24, so the exact bucket is empty.
        let key_b = test_key(120, 40);
        let decision = predictor.predict(key_b, 0.0, 1_000.0);
        assert_eq!(decision.fallback_level, 1);
        assert_eq!(decision.sample_count, 4);
    }

    /// When only a bucket with the same mode but a different diff has data, the
    /// mode-only pool (level 2) is used.
    #[test]
    fn fallback_hierarchy_mode_only() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 3,
            window_size: 16,
            q_default: 0.0,
        });

        let key_dirty = BucketKey::from_context(
            ScreenMode::Inline { ui_height: 4 },
            DiffStrategy::DirtyRows,
            80,
            24,
        );
        for value in [10.0, 20.0, 30.0] {
            predictor.observe(key_dirty, 0.0, value);
        }

        let key_full = BucketKey::from_context(
            ScreenMode::Inline { ui_height: 4 },
            DiffStrategy::Full,
            120,
            40,
        );
        let decision = predictor.predict(key_full, 0.0, 1_000.0);
        assert_eq!(decision.fallback_level, 2);
        assert_eq!(decision.sample_count, 3);
    }

    /// Five observations into a window of three leave exactly three residuals.
    #[test]
    fn window_enforced() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 3,
            q_default: 0.0,
        });
        let key = test_key(80, 24);
        for value in [1.0, 2.0, 3.0, 4.0, 5.0] {
            predictor.observe(key, 0.0, value);
        }
        assert_eq!(predictor.bucket_samples(key), 3);
    }

    /// With no observations at all, the configured default quantile is reported at
    /// fallback level 3 with a sample count of zero.
    #[test]
    fn predict_uses_default_when_empty() {
        let predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 2,
            window_size: 4,
            q_default: 42.0,
        });
        let key = test_key(120, 40);
        let prediction = predictor.predict(key, 5.0, 10_000.0);
        assert_eq!(prediction.quantile, 42.0);
        assert_eq!(prediction.sample_count, 0);
        assert_eq!(prediction.fallback_level, 3);
    }

    /// Residuals from a differently sized bucket do not leak into an exact-bucket
    /// prediction that has enough samples of its own.
    #[test]
    fn bucket_isolation_by_size() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.2,
            min_samples: 2,
            window_size: 10,
            q_default: 0.0,
        });

        let small = test_key(40, 10);
        predictor.observe(small, 0.0, 1.0);
        predictor.observe(small, 0.0, 2.0);

        let large = test_key(200, 60);
        predictor.observe(large, 0.0, 10.0);
        predictor.observe(large, 0.0, 12.0);

        let prediction = predictor.predict(large, 0.0, 1_000.0);
        assert_eq!(prediction.fallback_level, 0);
        assert_eq!(prediction.sample_count, 2);
        assert_eq!(prediction.quantile, 12.0);
    }

    /// Resetting an existing bucket empties it, bumps the reset counter, and makes
    /// the next prediction use the default quantile.
    #[test]
    fn reset_clears_bucket_and_raises_reset_count() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 7.0,
        });
        let key = test_key(80, 24);
        predictor.observe(key, 0.0, 3.0);
        assert_eq!(predictor.bucket_samples(key), 1);

        predictor.reset_bucket(key);
        assert_eq!(predictor.bucket_samples(key), 0);

        let prediction = predictor.predict(key, 0.0, 1_000.0);
        assert_eq!(prediction.quantile, 7.0);
        assert_eq!(prediction.reset_count, 1);
    }

    /// After a global reset every prediction falls through to the default quantile.
    #[test]
    fn reset_all_forces_conservative_fallback() {
        let mut predictor = ConformalPredictor::new(ConformalConfig {
            alpha: 0.1,
            min_samples: 1,
            window_size: 8,
            q_default: 9.0,
        });
        let key = test_key(80, 24);
        predictor.observe(key, 0.0, 2.0);

        predictor.reset_all();
        let prediction = predictor.predict(key, 0.0, 1_000.0);
        assert_eq!(prediction.quantile, 9.0);
        assert_eq!(prediction.sample_count, 0);
        assert_eq!(prediction.fallback_level, 3);
        assert_eq!(prediction.reset_count, 1);
    }

    /// The size bucket is floor(log2(area)): doubling the area raises it by one.
    #[test]
    fn size_bucket_log2_area() {
        // Area 64 = 2^6.
        let a = size_bucket(8, 8);
        // Area 128 = 2^7.
        let b = size_bucket(8, 16);
        assert_eq!(a, 6);
        assert_eq!(b, 7);
    }
}