1use std::collections::VecDeque;
58
/// One independently timed phase of rendering a frame.
///
/// The discriminants are written out explicitly because they double as
/// indices into the per-stage arrays used elsewhere (`stage as usize`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RenderStage {
    /// The layout stage (`"layout"`).
    Layout = 0,
    /// The diff stage (`"diff"`).
    Diff = 1,
    /// The present stage (`"present"`).
    Present = 2,
}

impl RenderStage {
    /// Every stage, listed in discriminant (index) order.
    pub const ALL: [RenderStage; 3] = [RenderStage::Layout, RenderStage::Diff, RenderStage::Present];

    /// Lowercase identifier for this stage.
    pub fn name(self) -> &'static str {
        match self {
            RenderStage::Layout => "layout",
            RenderStage::Diff => "diff",
            RenderStage::Present => "present",
        }
    }
}
83
84#[derive(Debug, Clone, Copy)]
86pub struct StageObservation {
87 pub layout_us: f64,
89 pub diff_us: f64,
91 pub present_us: f64,
93}
94
95impl StageObservation {
96 pub fn get(&self, stage: RenderStage) -> f64 {
98 match stage {
99 RenderStage::Layout => self.layout_us,
100 RenderStage::Diff => self.diff_us,
101 RenderStage::Present => self.present_us,
102 }
103 }
104
105 pub fn total_us(&self) -> f64 {
107 self.layout_us + self.diff_us + self.present_us
108 }
109}
110
111#[derive(Debug, Clone)]
113pub struct StageAlert {
114 pub stage: RenderStage,
115 pub is_alert: bool,
117 pub observed: f64,
119 pub threshold: f64,
121 pub e_value: f64,
123 pub calibration_count: usize,
125}
126
127#[derive(Debug, Clone)]
129pub struct FrameResult {
130 pub stages: [StageAlert; 3],
132}
133
134impl FrameResult {
135 pub fn any_alert(&self) -> bool {
137 self.stages.iter().any(|s| s.is_alert)
138 }
139
140 pub fn alerting_stages(&self) -> Vec<RenderStage> {
142 self.stages
143 .iter()
144 .filter(|s| s.is_alert)
145 .map(|s| s.stage)
146 .collect()
147 }
148
149 pub fn stage(&self, stage: RenderStage) -> &StageAlert {
151 &self.stages[stage as usize]
152 }
153}
154
/// Tuning parameters for [`StagedConformalPredictor`].
#[derive(Debug, Clone)]
pub struct StagedConfig {
    /// Miscoverage level: sets the conformal quantile `1 - alpha` and the
    /// e-value alert bar `1 / alpha`.
    pub alpha: f64,
    /// Maximum calibration samples kept per stage; the oldest are dropped first.
    pub max_calibration: usize,
    /// Calibration samples a stage must hold before it may alert.
    pub min_calibration: usize,
    /// Multiplier on the standardized score in the e-process update
    /// `lambda * z - lambda^2 / 2`.
    pub lambda: f64,
}

impl Default for StagedConfig {
    /// `alpha = 0.05`, a 500-sample window, 10 samples before alerting, `lambda = 0.5`.
    fn default() -> Self {
        StagedConfig {
            alpha: 0.05,
            lambda: 0.5,
            min_calibration: 10,
            max_calibration: 500,
        }
    }
}
178
/// Lower clamp for a stage's e-value. The e-value is updated
/// multiplicatively, so reaching exactly 0 would leave it stuck at 0.
const E_MIN: f64 = 1e-12;
/// Upper clamp for a stage's e-value, keeping it finite after extreme frames.
const E_MAX: f64 = 1e12;

/// Calibration window, running moments, and e-process state for one stage.
#[derive(Debug, Clone)]
struct StageState {
    /// Most recent finite calibration samples (at most `max_samples` of them),
    /// used for the conformal quantile.
    calibration: VecDeque<f64>,
    /// Running (Welford) mean over every accepted calibration sample.
    mean: f64,
    /// Running (Welford) sum of squared deviations from `mean`.
    m2: f64,
    /// Number of samples folded into `mean` / `m2`. Unlike
    /// `calibration.len()`, this is not capped by the window size.
    n: u64,
    /// Current e-value, kept within `[E_MIN, E_MAX]`.
    e_value: f64,
}

impl StageState {
    /// Empty state with a neutral e-value of 1.
    fn new() -> Self {
        Self {
            calibration: VecDeque::new(),
            mean: 0.0,
            m2: 0.0,
            n: 0,
            e_value: 1.0,
        }
    }

    /// Adds one calibration sample, keeping at most `max_samples` in the window.
    ///
    /// Non-finite samples (NaN / ±inf) are ignored. A single one would turn
    /// the Welford moments into NaN permanently, which breaks the e-process
    /// for this stage.
    ///
    /// NOTE(review): `mean`/`m2` cover *every* accepted sample, while
    /// `calibration` keeps only the latest `max_samples`. The e-process and
    /// the conformal threshold therefore use different reference sets.
    /// Confirm this is intended.
    fn calibrate(&mut self, value: f64, max_samples: usize) {
        if !value.is_finite() {
            return;
        }

        self.n += 1;
        let delta = value - self.mean;
        self.mean += delta / self.n as f64;
        let delta2 = value - self.mean;
        self.m2 += delta * delta2;

        self.calibration.push_back(value);
        while self.calibration.len() > max_samples {
            self.calibration.pop_front();
        }
    }

    /// Unbiased sample variance of the accepted calibration samples.
    ///
    /// Returns `1.0` with fewer than two samples and is floored at `1e-10`.
    fn variance(&self) -> f64 {
        if self.n < 2 {
            return 1.0;
        }
        (self.m2 / (self.n - 1) as f64).max(1e-10)
    }

    /// Split-conformal upper threshold: the `ceil((1 - alpha) * (n + 1))`-th
    /// smallest sample of the calibration window, with `n` the window size.
    ///
    /// Returns `f64::MAX` when the window is empty. The rank is clamped into
    /// `1..=n`, so `alpha >= 1` (or NaN) no longer underflows the index. When
    /// the rank exceeds `n` (fewer than about `1 / alpha` samples), the window
    /// maximum is used rather than an infinite threshold.
    fn conformal_threshold(&self, alpha: f64) -> f64 {
        let n = self.calibration.len();
        if n == 0 {
            return f64::MAX;
        }

        // `as usize` saturates: negative or NaN products become 0 and are
        // lifted to rank 1 by the clamp.
        let rank = (((1.0 - alpha) * (n + 1) as f64).ceil() as usize).clamp(1, n);

        let mut scratch: Vec<f64> = self.calibration.iter().copied().collect();
        // `total_cmp` is a genuine total order. `partial_cmp` with an
        // `Equal` fallback is not, once a NaN is present, and recent std
        // sort/select implementations may panic on such comparators.
        // Selection finds the order statistic in O(n) without a full sort.
        let (_, nth, _) = scratch.select_nth_unstable_by(rank - 1, f64::total_cmp);
        *nth
    }

    /// Multiplies the e-value by `exp(lambda * z - lambda^2 / 2)`, then clamps
    /// it into `[E_MIN, E_MAX]`.
    ///
    /// `z` standardizes `value` against the running calibration mean and
    /// standard deviation. An update whose log-factor is NaN (e.g. a NaN
    /// `value`) is skipped. Applied, it would make the e-value NaN forever,
    /// since NaN survives both the multiplication and `clamp`.
    fn update_e_process(&mut self, value: f64, lambda: f64) {
        let std = self.variance().sqrt();
        // Defensive: `variance()` is floored at 1e-10, so `std >= 1e-5` and
        // the zero branch is not reached in practice.
        let z = if std > 1e-10 {
            (value - self.mean) / std
        } else {
            0.0
        };
        let log_e = lambda * z - lambda * lambda / 2.0;
        if log_e.is_nan() {
            return;
        }
        self.e_value = (self.e_value * log_e.exp()).clamp(E_MIN, E_MAX);
    }
}
258
259#[derive(Debug, Clone)]
265pub struct StagedConformalPredictor {
266 config: StagedConfig,
267 states: [StageState; 3],
268}
269
270impl Default for StagedConformalPredictor {
271 fn default() -> Self {
272 Self::new(StagedConfig::default())
273 }
274}
275
276impl StagedConformalPredictor {
277 pub fn new(config: StagedConfig) -> Self {
279 Self {
280 config,
281 states: [StageState::new(), StageState::new(), StageState::new()],
282 }
283 }
284
285 pub fn calibrate(&mut self, stage: RenderStage, value: f64) {
287 self.states[stage as usize].calibrate(value, self.config.max_calibration);
288 }
289
290 pub fn calibrate_frame(&mut self, obs: &StageObservation) {
292 for stage in RenderStage::ALL {
293 self.calibrate(stage, obs.get(stage));
294 }
295 }
296
297 pub fn observe_frame(&mut self, obs: StageObservation) -> FrameResult {
299 let mut alerts = [
300 self.observe_stage(RenderStage::Layout, obs.layout_us),
301 self.observe_stage(RenderStage::Diff, obs.diff_us),
302 self.observe_stage(RenderStage::Present, obs.present_us),
303 ];
304 alerts[0].stage = RenderStage::Layout;
306 alerts[1].stage = RenderStage::Diff;
307 alerts[2].stage = RenderStage::Present;
308 FrameResult { stages: alerts }
309 }
310
311 fn observe_stage(&mut self, stage: RenderStage, value: f64) -> StageAlert {
312 let state = &mut self.states[stage as usize];
313 let threshold = state.conformal_threshold(self.config.alpha);
314 let calibration_count = state.calibration.len();
315
316 state.update_e_process(value, self.config.lambda);
318
319 let is_alert = calibration_count >= self.config.min_calibration
320 && value > threshold
321 && state.e_value > 1.0 / self.config.alpha;
322
323 StageAlert {
324 stage,
325 is_alert,
326 observed: value,
327 threshold,
328 e_value: state.e_value,
329 calibration_count,
330 }
331 }
332
333 pub fn calibration_count(&self, stage: RenderStage) -> usize {
335 self.states[stage as usize].calibration.len()
336 }
337
338 pub fn reset(&mut self) {
340 self.states = [StageState::new(), StageState::new(), StageState::new()];
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame observation from its three stage timings.
    fn frame(layout_us: f64, diff_us: f64, present_us: f64) -> StageObservation {
        StageObservation {
            layout_us,
            diff_us,
            present_us,
        }
    }

    /// The steady-state frame used as the calibration baseline.
    fn baseline() -> StageObservation {
        frame(100.0, 50.0, 200.0)
    }

    /// A default predictor with `frames` baseline frames in every stage's window.
    fn calibrated_predictor(frames: usize) -> StagedConformalPredictor {
        let mut pred = StagedConformalPredictor::default();
        for _ in 0..frames {
            pred.calibrate_frame(&baseline());
        }
        pred
    }

    #[test]
    fn default_config() {
        let StagedConfig {
            alpha,
            max_calibration,
            min_calibration,
            ..
        } = StagedConfig::default();
        assert_eq!(alpha, 0.05);
        assert_eq!(max_calibration, 500);
        assert_eq!(min_calibration, 10);
    }

    #[test]
    fn render_stage_names() {
        let expected = [
            (RenderStage::Layout, "layout"),
            (RenderStage::Diff, "diff"),
            (RenderStage::Present, "present"),
        ];
        for (stage, name) in expected {
            assert_eq!(stage.name(), name);
        }
    }

    #[test]
    fn stage_observation_total() {
        assert!((baseline().total_us() - 350.0).abs() < 1e-10);
    }

    #[test]
    fn stage_observation_get() {
        let obs = baseline();
        let expected = [
            (RenderStage::Layout, 100.0),
            (RenderStage::Diff, 50.0),
            (RenderStage::Present, 200.0),
        ];
        for (stage, want) in expected {
            assert!((obs.get(stage) - want).abs() < 1e-10);
        }
    }

    #[test]
    fn no_alert_during_calibration() {
        let mut pred = StagedConformalPredictor::default();
        // Only 5 samples: below the default `min_calibration` of 10.
        (0..5).for_each(|_| pred.calibrate(RenderStage::Layout, 100.0));
        let result = pred.observe_frame(frame(999.0, 0.0, 0.0));
        assert!(!result.stage(RenderStage::Layout).is_alert);
    }

    #[test]
    fn alert_on_regression() {
        let mut pred = calibrated_predictor(50);
        let first_alert = (0..20)
            .map(|_| pred.observe_frame(frame(500.0, 50.0, 200.0)))
            .find(FrameResult::any_alert);
        let result = first_alert.expect("Should have alerted on 5x layout regression");
        assert!(result.stage(RenderStage::Layout).is_alert);
        assert!(!result.stage(RenderStage::Diff).is_alert);
        assert!(!result.stage(RenderStage::Present).is_alert);
    }

    #[test]
    fn no_alert_on_normal() {
        let mut pred = calibrated_predictor(50);
        for _ in 0..20 {
            let result = pred.observe_frame(baseline());
            assert!(!result.any_alert(), "Should not alert on normal frames");
        }
    }

    #[test]
    fn independent_stage_tracking() {
        let mut pred = StagedConformalPredictor::default();
        (0..50).for_each(|_| pred.calibrate(RenderStage::Layout, 100.0));
        let counts = RenderStage::ALL.map(|stage| pred.calibration_count(stage));
        assert_eq!(counts, [50, 0, 0]);
    }

    #[test]
    fn reset_clears_state() {
        let mut pred = StagedConformalPredictor::default();
        (0..20).for_each(|_| pred.calibrate(RenderStage::Layout, 100.0));
        assert_eq!(pred.calibration_count(RenderStage::Layout), 20);
        pred.reset();
        assert_eq!(pred.calibration_count(RenderStage::Layout), 0);
    }

    #[test]
    fn alerting_stages_list() {
        let verdict = |stage: RenderStage,
                       is_alert: bool,
                       observed: f64,
                       threshold: f64,
                       e_value: f64| StageAlert {
            stage,
            is_alert,
            observed,
            threshold,
            e_value,
            calibration_count: 50,
        };
        let result = FrameResult {
            stages: [
                verdict(RenderStage::Layout, true, 500.0, 120.0, 100.0),
                verdict(RenderStage::Diff, false, 50.0, 80.0, 0.5),
                verdict(RenderStage::Present, true, 800.0, 250.0, 200.0),
            ],
        };
        let alerting = result.alerting_stages();
        assert_eq!(alerting.len(), 2);
        assert!(alerting.contains(&RenderStage::Layout));
        assert!(alerting.contains(&RenderStage::Present));
    }

    #[test]
    fn calibration_window_bounded() {
        let mut pred = StagedConformalPredictor::new(StagedConfig {
            max_calibration: 20,
            ..Default::default()
        });
        (0..100)
            .map(f64::from)
            .for_each(|v| pred.calibrate(RenderStage::Layout, v));
        assert_eq!(pred.calibration_count(RenderStage::Layout), 20);
    }
}