Skip to main content

ftui_widgets/
height_predictor.rs

//! Bayesian height prediction with conformal bounds for virtualized lists.
//!
//! Predicts unseen row heights to pre-allocate scroll space and avoid
//! scroll jumps when actual heights are measured lazily.
//!
//! # Mathematical Model
//!
//! ## Bayesian Online Estimation
//!
//! Maintains a Normal-Normal conjugate model per item category:
//!
//! ```text
//! Prior:      μ ~ N(μ₀, σ₀²/κ₀)
//! Likelihood: h_i ~ N(μ, σ²)
//! Posterior:  μ | data ~ N(μ_n, σ²/κ_n)
//!
//! where:
//!   κ_n = κ₀ + n
//!   μ_n = (κ₀·μ₀ + n·x̄) / κ_n
//!   σ²  estimated via running variance (Welford's algorithm)
//! ```
//!
//! ## Conformal Prediction Bounds
//!
//! Given a calibration set of absolute residuals |predicted - actual|, the
//! conformal interval is:
//!
//! ```text
//! [μ_n - q_{1-α}, μ_n + q_{1-α}]
//! ```
//!
//! where `q_{1-α}` is the empirical (1 - α) quantile of |residuals|. Because
//! the residuals are absolute values, the symmetric interval uses the (1 - α)
//! quantile (not 1 - α/2). This provides distribution-free coverage:
//! P(h ∈ interval) ≥ 1 - α.
//!
//! # Failure Modes
//!
//! | Condition | Behavior | Rationale |
//! |-----------|----------|-----------|
//! | No measurements | Return default height | Cold start fallback |
//! | n = 1 | Wide interval (use prior σ) | Insufficient data |
//! | All same height | σ → 0, interval collapses | Homogeneous data |
//! | Actual > bound | Adjust + record violation | Expected at rate α |
43
use std::collections::VecDeque;
45
/// Configuration for the height predictor.
///
/// Every field is public, so callers typically override a few values with
/// struct-update syntax: `PredictorConfig { coverage: 0.95, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct PredictorConfig {
    /// Default height when no data is available. Default: 1.
    pub default_height: u16,
    /// Prior strength κ₀ (higher = more trust in the prior mean). Default: 2.0.
    pub prior_strength: f64,
    /// Prior mean μ₀ (usually same as `default_height`). Default: 1.0.
    pub prior_mean: f64,
    /// Prior variance estimate σ₀². Default: 4.0.
    pub prior_variance: f64,
    /// Conformal coverage level (1 - α). Default: 0.90.
    pub coverage: f64,
    /// Max calibration residuals to keep. Default: 200.
    pub calibration_window: usize,
}

impl Default for PredictorConfig {
    /// Single-row default height, a weak prior (κ₀ = 2, σ₀² = 4) centred on
    /// that height, 90% coverage, and a 200-entry calibration window.
    fn default() -> Self {
        Self {
            default_height: 1,
            prior_strength: 2.0,
            prior_mean: 1.0,
            prior_variance: 4.0,
            coverage: 0.90,
            calibration_window: 200,
        }
    }
}
75
/// Running statistics using Welford's online algorithm.
///
/// Tracks count, mean and the sum of squared deviations so the sample
/// variance can be read at any time without storing the samples.
#[derive(Debug, Clone)]
struct WelfordStats {
    /// Number of samples folded in so far.
    n: u64,
    /// Running mean of the samples.
    mean: f64,
    /// Sum of squared deviations from the running mean (Welford's M2).
    m2: f64,
}

impl WelfordStats {
    /// Create an empty accumulator (no samples, zero mean, zero M2).
    fn new() -> Self {
        Self {
            n: 0,
            mean: 0.0,
            m2: 0.0,
        }
    }

    /// Fold one sample `x` into the running mean and M2.
    fn update(&mut self, x: f64) {
        self.n += 1;
        // Deviation from the mean *before* this sample is absorbed.
        let delta_before = x - self.mean;
        self.mean += delta_before / self.n as f64;
        // Deviation from the mean *after* the update; the product of the two
        // deviations is Welford's numerically stable M2 increment.
        let delta_after = x - self.mean;
        self.m2 += delta_before * delta_after;
    }

    /// Unbiased sample variance, `m2 / (n - 1)`.
    ///
    /// With fewer than two samples the variance is undefined; this returns
    /// `f64::MAX` as a sentinel in that case, so callers that need a usable
    /// value must check `n` first.
    fn variance(&self) -> f64 {
        if self.n < 2 {
            return f64::MAX;
        }
        self.m2 / (self.n - 1) as f64
    }
}
108
109/// Per-category prediction state.
110#[derive(Debug, Clone)]
111struct CategoryState {
112    /// Welford running stats for observed heights.
113    welford: WelfordStats,
114    /// Posterior mean μ_n.
115    posterior_mean: f64,
116    /// Posterior κ_n.
117    posterior_kappa: f64,
118    /// Calibration residuals |predicted - actual|.
119    residuals: VecDeque<f64>,
120}
121
/// A prediction with conformal bounds.
///
/// All heights are in the same unit as the observed heights (`u16`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeightPrediction {
    /// Point prediction (posterior mean, rounded).
    pub predicted: u16,
    /// Lower conformal bound.
    pub lower: u16,
    /// Upper conformal bound.
    pub upper: u16,
    /// Number of observations for this category.
    pub observations: u64,
}
134
135/// Bayesian height predictor with conformal bounds.
136#[derive(Debug, Clone)]
137pub struct HeightPredictor {
138    config: PredictorConfig,
139    /// Per-category states. Key is category index (0 = default).
140    categories: Vec<CategoryState>,
141    /// Total measurements across all categories.
142    total_measurements: u64,
143    /// Total bound violations.
144    total_violations: u64,
145}
146
147impl HeightPredictor {
148    /// Create a new predictor with default config.
149    pub fn new(config: PredictorConfig) -> Self {
150        // Start with one default category.
151        let default_cat = CategoryState {
152            welford: WelfordStats::new(),
153            posterior_mean: config.prior_mean,
154            posterior_kappa: config.prior_strength,
155            residuals: VecDeque::new(),
156        };
157        Self {
158            config,
159            categories: vec![default_cat],
160            total_measurements: 0,
161            total_violations: 0,
162        }
163    }
164
165    /// Register a new category. Returns the category id.
166    pub fn register_category(&mut self) -> usize {
167        let id = self.categories.len();
168        self.categories.push(CategoryState {
169            welford: WelfordStats::new(),
170            posterior_mean: self.config.prior_mean,
171            posterior_kappa: self.config.prior_strength,
172            residuals: VecDeque::new(),
173        });
174        id
175    }
176
177    /// Predict height for an item in the given category.
178    pub fn predict(&self, category: usize) -> HeightPrediction {
179        let cat = match self.categories.get(category) {
180            Some(c) => c,
181            None => return self.cold_prediction(),
182        };
183
184        if cat.welford.n == 0 {
185            return self.cold_prediction();
186        }
187
188        let mu = cat.posterior_mean;
189        let predicted = mu.round().max(1.0) as u16;
190
191        // Conformal bounds from calibration residuals.
192        let (lower, upper) = self.conformal_bounds(cat, mu);
193
194        HeightPrediction {
195            predicted,
196            lower,
197            upper,
198            observations: cat.welford.n,
199        }
200    }
201
202    /// Record an actual measured height, updating the model.
203    /// Returns whether the measurement was within the predicted bounds.
204    pub fn observe(&mut self, category: usize, actual_height: u16) -> bool {
205        // Ensure category exists.
206        while self.categories.len() <= category {
207            self.register_category();
208        }
209
210        let prediction = self.predict(category);
211        let within_bounds = actual_height >= prediction.lower && actual_height <= prediction.upper;
212
213        self.total_measurements += 1;
214        if !within_bounds && prediction.observations > 0 {
215            self.total_violations += 1;
216        }
217
218        let cat = &mut self.categories[category];
219        let h = actual_height as f64;
220
221        // Record calibration residual.
222        let residual = (cat.posterior_mean - h).abs();
223        cat.residuals.push_back(residual);
224        if cat.residuals.len() > self.config.calibration_window {
225            cat.residuals.pop_front();
226        }
227
228        // Update Welford stats.
229        cat.welford.update(h);
230
231        // Update posterior: μ_n = (κ₀·μ₀ + n·x̄) / κ_n
232        let n = cat.welford.n as f64;
233        let kappa_0 = self.config.prior_strength;
234        let mu_0 = self.config.prior_mean;
235        cat.posterior_kappa = kappa_0 + n;
236        cat.posterior_mean = (kappa_0 * mu_0 + n * cat.welford.mean) / cat.posterior_kappa;
237
238        within_bounds
239    }
240
241    /// Cold-start prediction when no data is available.
242    fn cold_prediction(&self) -> HeightPrediction {
243        let d = self.config.default_height;
244        let margin = (self.config.prior_variance.sqrt() * 2.0).ceil() as u16;
245        HeightPrediction {
246            predicted: d,
247            lower: d.saturating_sub(margin),
248            upper: d.saturating_add(margin),
249            observations: 0,
250        }
251    }
252
253    /// Compute conformal bounds from calibration residuals.
254    fn conformal_bounds(&self, cat: &CategoryState, mu: f64) -> (u16, u16) {
255        if cat.residuals.is_empty() {
256            // Fallback: use prior variance.
257            let margin = (self.config.prior_variance.sqrt() * 2.0).ceil() as u16;
258            let predicted = mu.round().max(1.0) as u16;
259            return (
260                predicted.saturating_sub(margin),
261                predicted.saturating_add(margin),
262            );
263        }
264
265        // Sort residuals to find quantile.
266        let mut sorted: Vec<f64> = cat.residuals.iter().copied().collect();
267        sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
268
269        let alpha = 1.0 - self.config.coverage;
270        let quantile_idx = ((1.0 - alpha) * sorted.len() as f64).ceil() as usize;
271        let quantile_idx = quantile_idx.min(sorted.len()).saturating_sub(1);
272        let q = sorted[quantile_idx];
273
274        let lower = (mu - q).max(1.0).floor() as u16;
275        let upper = (mu + q).ceil().max(1.0) as u16;
276
277        (lower, upper)
278    }
279
280    /// Get the posterior mean for a category.
281    pub fn posterior_mean(&self, category: usize) -> f64 {
282        self.categories
283            .get(category)
284            .map(|c| c.posterior_mean)
285            .unwrap_or(self.config.prior_mean)
286    }
287
288    /// Get the posterior variance for a category.
289    pub fn posterior_variance(&self, category: usize) -> f64 {
290        self.categories
291            .get(category)
292            .map(|c| {
293                let sigma_sq = if c.welford.n < 2 {
294                    self.config.prior_variance
295                } else {
296                    c.welford.variance()
297                };
298                sigma_sq / c.posterior_kappa
299            })
300            .unwrap_or(self.config.prior_variance)
301    }
302
303    /// Total measurements observed.
304    pub fn total_measurements(&self) -> u64 {
305        self.total_measurements
306    }
307
308    /// Total bound violations.
309    pub fn total_violations(&self) -> u64 {
310        self.total_violations
311    }
312
313    /// Empirical violation rate.
314    pub fn violation_rate(&self) -> f64 {
315        if self.total_measurements == 0 {
316            return 0.0;
317        }
318        self.total_violations as f64 / self.total_measurements as f64
319    }
320
321    /// Number of categories.
322    pub fn category_count(&self) -> usize {
323        self.categories.len()
324    }
325
326    /// Number of observations for a category.
327    pub fn category_observations(&self, category: usize) -> u64 {
328        self.categories
329            .get(category)
330            .map(|c| c.welford.n)
331            .unwrap_or(0)
332    }
333}
334
335impl Default for HeightPredictor {
336    fn default() -> Self {
337        Self::new(PredictorConfig::default())
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Advance the test LCG in place and return its top two bits (0..=3).
    ///
    /// Gives the noisy-data tests a deterministic pseudo-random height offset.
    fn lcg_offset(seed: &mut u64) -> u16 {
        *seed = seed
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        (*seed >> 62) as u16
    }

    // ─── Posterior update tests ────────────────────────────────────

    #[test]
    fn unit_posterior_update() {
        let config = PredictorConfig {
            prior_mean: 2.0,
            prior_strength: 1.0,
            prior_variance: 4.0,
            ..Default::default()
        };
        let mut pred = HeightPredictor::new(config);

        // Prior: μ=2.0, κ=1.
        assert!((pred.posterior_mean(0) - 2.0).abs() < 1e-10);

        // Observe height 4.
        pred.observe(0, 4);
        // κ_1 = 1 + 1 = 2, μ_1 = (1*2 + 1*4) / 2 = 3.0
        assert!((pred.posterior_mean(0) - 3.0).abs() < 1e-10);

        // Observe another height 4.
        pred.observe(0, 4);
        // κ_2 = 1 + 2 = 3, x̄ = 4, μ_2 = (1*2 + 2*4) / 3 = 10/3 ≈ 3.333
        assert!((pred.posterior_mean(0) - 10.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn unit_posterior_variance_decreases() {
        let mut pred = HeightPredictor::new(PredictorConfig {
            prior_variance: 4.0,
            ..Default::default()
        });

        let var_0 = pred.posterior_variance(0);
        assert!(var_0 > 0.0, "prior variance should be positive");

        // Feed noisy data so Welford variance is non-zero.
        for i in 0..10 {
            pred.observe(0, if i % 2 == 0 { 2 } else { 4 });
        }
        let var_10 = pred.posterior_variance(0);

        for i in 0..90 {
            pred.observe(0, if i % 2 == 0 { 2 } else { 4 });
        }
        let var_100 = pred.posterior_variance(0);

        // With noisy data, posterior variance σ²/κ_n decreases as κ_n grows.
        assert!(
            var_10 < var_0,
            "variance should decrease: {var_10} >= {var_0}"
        );
        assert!(
            var_100 < var_10,
            "variance should decrease: {var_100} >= {var_10}"
        );
    }

    // ─── Conformal bounds tests ───────────────────────────────────

    #[test]
    fn unit_conformal_bounds() {
        let config = PredictorConfig {
            coverage: 0.90,
            prior_mean: 3.0,
            prior_strength: 1.0,
            ..Default::default()
        };
        let mut pred = HeightPredictor::new(config);

        // Feed consistent data.
        for _ in 0..50 {
            pred.observe(0, 3);
        }

        let p = pred.predict(0);
        // With all observations at 3, residuals should be near 0.
        // Bounds should be tight around 3.
        assert_eq!(p.predicted, 3);
        assert!(p.lower <= 3);
        assert!(p.upper >= 3);
    }

    #[test]
    fn conformal_bounds_widen_with_noise() {
        // Both predictors start from the same configuration.
        let fresh = || {
            HeightPredictor::new(PredictorConfig {
                coverage: 0.90,
                prior_mean: 5.0,
                prior_strength: 1.0,
                ..Default::default()
            })
        };

        // Consistent data → tight bounds.
        let mut steady = fresh();
        for _ in 0..50 {
            steady.observe(0, 5);
        }
        let tight = steady.predict(0);

        // Same starting point, noisy data.
        let mut noisy = fresh();
        let mut seed: u64 = 0xABCD_1234_5678_9ABC;
        for _ in 0..50 {
            let h = 3 + lcg_offset(&mut seed); // heights 3..=6
            noisy.observe(0, h);
        }
        let wide = noisy.predict(0);

        assert!(
            (wide.upper - wide.lower) >= (tight.upper - tight.lower),
            "noisy data should produce wider bounds"
        );
    }

    // ─── Coverage property test ───────────────────────────────────

    #[test]
    fn property_coverage() {
        let alpha = 0.10;
        let config = PredictorConfig {
            coverage: 1.0 - alpha,
            prior_mean: 3.0,
            prior_strength: 2.0,
            prior_variance: 4.0,
            calibration_window: 100,
            ..Default::default()
        };
        let mut pred = HeightPredictor::new(config);

        // Warm up with calibration data.
        let mut seed: u64 = 0xDEAD_BEEF_CAFE_0001;
        for _ in 0..100 {
            let h = 2 + lcg_offset(&mut seed); // heights 2..=5
            pred.observe(0, h);
        }

        // Now check coverage on new data.
        let mut violations = 0u32;
        let test_n = 200;
        for _ in 0..test_n {
            let h = 2 + lcg_offset(&mut seed);
            let within = pred.observe(0, h);
            if !within {
                violations += 1;
            }
        }

        let viol_rate = violations as f64 / test_n as f64;
        // Empirical violation rate should be approximately ≤ α.
        // Allow generous tolerance for finite sample + discrete heights.
        assert!(
            viol_rate <= alpha + 0.15,
            "violation rate {viol_rate} exceeds α + tolerance ({alpha} + 0.15)"
        );
    }

    // ─── Scroll stability test ────────────────────────────────────

    #[test]
    fn e2e_scroll_stability() {
        let mut pred = HeightPredictor::new(PredictorConfig {
            prior_mean: 1.0,
            prior_strength: 2.0,
            default_height: 1,
            coverage: 0.90,
            ..Default::default()
        });

        // All items are height 1 (most common TUI case).
        let mut corrections = 0u32;
        for _ in 0..500 {
            let within = pred.observe(0, 1);
            if !within {
                corrections += 1;
            }
        }

        // With homogeneous heights, should converge quickly with zero corrections
        // after warmup.
        let p = pred.predict(0);
        assert_eq!(p.predicted, 1);
        assert!(corrections < 10, "too many corrections: {corrections}");
    }

    // ─── Multiple categories ──────────────────────────────────────

    #[test]
    fn categories_are_independent() {
        let mut pred = HeightPredictor::default();
        let cat_a = 0;
        let cat_b = pred.register_category();

        // Feed different data to each.
        for _ in 0..20 {
            pred.observe(cat_a, 1);
            pred.observe(cat_b, 5);
        }

        let pa = pred.predict(cat_a);
        let pb = pred.predict(cat_b);

        assert_eq!(pa.predicted, 1);
        assert!(pb.predicted >= 4 && pb.predicted <= 5);
    }

    // ─── Cold start ───────────────────────────────────────────────

    #[test]
    fn cold_prediction_uses_default() {
        let pred = HeightPredictor::new(PredictorConfig {
            default_height: 2,
            prior_variance: 1.0,
            ..Default::default()
        });
        let p = pred.predict(0);
        assert_eq!(p.predicted, 2);
        assert_eq!(p.observations, 0);
    }

    // ─── Determinism ──────────────────────────────────────────────

    #[test]
    fn deterministic_under_same_observations() {
        let run = || {
            let mut pred = HeightPredictor::default();
            let observations = [1, 2, 1, 3, 1, 2, 1, 1, 4, 1];
            for &h in &observations {
                pred.observe(0, h);
            }
            (pred.predict(0).predicted, pred.posterior_mean(0))
        };

        let (p1, m1) = run();
        let (p2, m2) = run();
        assert_eq!(p1, p2);
        assert!((m1 - m2).abs() < 1e-15);
    }

    // ─── Performance ──────────────────────────────────────────────

    #[test]
    fn perf_prediction_overhead() {
        let mut pred = HeightPredictor::default();

        // Warm up.
        for _ in 0..100 {
            pred.observe(0, 2);
        }

        let start = std::time::Instant::now();
        let mut sink = 0u16;
        for _ in 0..100_000 {
            // black_box keeps the optimizer from hoisting or eliding the call.
            let p = std::hint::black_box(&pred).predict(0);
            sink = sink.wrapping_add(p.predicted);
        }
        let elapsed = start.elapsed();
        std::hint::black_box(sink);
        let per_prediction = elapsed / 100_000;

        // Must be < 5μs per prediction (generous for debug builds).
        assert!(
            per_prediction < std::time::Duration::from_micros(5),
            "prediction too slow: {per_prediction:?}"
        );
    }

    // ─── Violation tracking ───────────────────────────────────────

    #[test]
    fn violation_tracking() {
        let mut pred = HeightPredictor::new(PredictorConfig {
            prior_mean: 5.0,
            prior_strength: 100.0, // strong prior
            default_height: 5,
            coverage: 0.95,
            ..Default::default()
        });

        // Warm up with height=5.
        for _ in 0..50 {
            pred.observe(0, 5);
        }

        // Sudden jump to height=20 should violate bounds.
        let within = pred.observe(0, 20);
        assert!(!within, "extreme outlier should violate bounds");
        assert!(pred.total_violations() > 0);
    }
}