Skip to main content

ftui_widgets/
height_predictor.rs

1//! Bayesian height prediction with conformal bounds for virtualized lists.
2//!
3//! Predicts unseen row heights to pre-allocate scroll space and avoid
4//! scroll jumps when actual heights are measured lazily.
5//!
6//! # Mathematical Model
7//!
8//! ## Bayesian Online Estimation
9//!
10//! Maintains a Normal-Normal conjugate model per item category:
11//!
12//! ```text
13//! Prior:     μ ~ N(μ₀, σ₀²/κ₀)
14//! Likelihood: h_i ~ N(μ, σ²)
15//! Posterior:  μ | data ~ N(μ_n, σ²/κ_n)
16//!
17//! where:
18//!   κ_n = κ₀ + n
19//!   μ_n = (κ₀·μ₀ + n·x̄) / κ_n
20//!   σ²  estimated via running variance (Welford's algorithm)
21//! ```
22//!
23//! ## Conformal Prediction Bounds
24//!
25//! Given a calibration set of (predicted, actual) residuals, the conformal
26//! interval is:
27//!
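//! ```text
//! [μ_n - q_{1-α}, μ_n + q_{1-α}]
//! ```
//!
//! where `q_{1-α}` is the empirical (1-α)-quantile of the absolute residuals
//! |predicted - actual|. Because the residuals are already absolute values, the
//! two-sided interval uses the (1-α) level rather than (1-α/2). This targets
//! distribution-free coverage: P(h ∈ interval) ≥ 1 - α.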
34//!
35//! # Failure Modes
36//!
37//! | Condition | Behavior | Rationale |
38//! |-----------|----------|-----------|
39//! | No measurements | Return default height | Cold start fallback |
//! | n = 1 | Interval from the single calibration residual (prior σ only if no residual is stored) | Insufficient data |
41//! | All same height | σ → 0, interval collapses | Homogeneous data |
42//! | Actual > bound | Adjust + record violation | Expected at rate α |
43
44use std::collections::VecDeque;
45
/// Tunable parameters for [`HeightPredictor`].
///
/// `PredictorConfig::default()` yields the values listed as "Default" on the
/// fields that document one.
#[derive(Clone, Debug)]
pub struct PredictorConfig {
    /// Height reported for a category that has no observations yet.
    pub default_height: u16,
    /// κ₀, the weight of the prior expressed in pseudo-observations; larger
    /// values keep the posterior mean closer to `prior_mean`. Default: 2.0.
    pub prior_strength: f64,
    /// μ₀, the prior mean height (typically equal to `default_height`).
    pub prior_mean: f64,
    /// Prior estimate of the height variance. Default: 4.0.
    pub prior_variance: f64,
    /// Target conformal coverage, i.e. 1 - α. Default: 0.90.
    pub coverage: f64,
    /// Upper limit on stored calibration residuals per category. Default: 200.
    pub calibration_window: usize,
}
62
63impl Default for PredictorConfig {
64    fn default() -> Self {
65        Self {
66            default_height: 1,
67            prior_strength: 2.0,
68            prior_mean: 1.0,
69            prior_variance: 4.0,
70            coverage: 0.90,
71            calibration_window: 200,
72        }
73    }
74}
75
/// Accumulator for a running mean and variance (Welford's online algorithm).
#[derive(Clone, Debug)]
struct WelfordStats {
    /// Number of samples folded in so far.
    n: u64,
    /// Running mean of the samples.
    mean: f64,
    /// Running sum of squared deviations from the mean.
    m2: f64,
}
83
84impl WelfordStats {
85    fn new() -> Self {
86        Self {
87            n: 0,
88            mean: 0.0,
89            m2: 0.0,
90        }
91    }
92
93    fn update(&mut self, x: f64) {
94        self.n += 1;
95        let delta = x - self.mean;
96        self.mean += delta / self.n as f64;
97        let delta2 = x - self.mean;
98        self.m2 += delta * delta2;
99    }
100
101    fn variance(&self) -> f64 {
102        if self.n < 2 {
103            return f64::MAX;
104        }
105        self.m2 / (self.n - 1) as f64
106    }
107}
108
109/// Per-category prediction state.
110#[derive(Debug, Clone)]
111struct CategoryState {
112    /// Welford running stats for observed heights.
113    welford: WelfordStats,
114    /// Posterior mean μ_n.
115    posterior_mean: f64,
116    /// Posterior κ_n.
117    posterior_kappa: f64,
118    /// Calibration residuals |predicted - actual|.
119    residuals: VecDeque<f64>,
120}
121
/// Result of a height query: a point estimate plus its conformal interval.
#[derive(Clone, Copy, Debug)]
pub struct HeightPrediction {
    /// Point estimate: the rounded posterior mean.
    pub predicted: u16,
    /// Lower end of the conformal interval.
    pub lower: u16,
    /// Upper end of the conformal interval.
    pub upper: u16,
    /// How many heights have been observed for the queried category.
    pub observations: u64,
}
134
135/// Bayesian height predictor with conformal bounds.
136#[derive(Debug, Clone)]
137pub struct HeightPredictor {
138    config: PredictorConfig,
139    /// Per-category states. Key is category index (0 = default).
140    categories: Vec<CategoryState>,
141    /// Total measurements across all categories.
142    total_measurements: u64,
143    /// Total bound violations.
144    total_violations: u64,
145}
146
147impl HeightPredictor {
148    /// Create a new predictor with default config.
149    pub fn new(config: PredictorConfig) -> Self {
150        // Start with one default category.
151        let default_cat = CategoryState {
152            welford: WelfordStats::new(),
153            posterior_mean: config.prior_mean,
154            posterior_kappa: config.prior_strength,
155            residuals: VecDeque::new(),
156        };
157        Self {
158            config,
159            categories: vec![default_cat],
160            total_measurements: 0,
161            total_violations: 0,
162        }
163    }
164
165    /// Register a new category. Returns the category id.
166    pub fn register_category(&mut self) -> usize {
167        let id = self.categories.len();
168        self.categories.push(CategoryState {
169            welford: WelfordStats::new(),
170            posterior_mean: self.config.prior_mean,
171            posterior_kappa: self.config.prior_strength,
172            residuals: VecDeque::new(),
173        });
174        id
175    }
176
177    /// Predict height for an item in the given category.
178    pub fn predict(&self, category: usize) -> HeightPrediction {
179        let cat = match self.categories.get(category) {
180            Some(c) => c,
181            None => return self.cold_prediction(),
182        };
183
184        if cat.welford.n == 0 {
185            return self.cold_prediction();
186        }
187
188        let mu = cat.posterior_mean;
189        let predicted = mu.round().max(1.0) as u16;
190
191        // Conformal bounds from calibration residuals.
192        let (lower, upper) = self.conformal_bounds(cat, mu);
193
194        HeightPrediction {
195            predicted,
196            lower,
197            upper,
198            observations: cat.welford.n,
199        }
200    }
201
202    /// Record an actual measured height, updating the model.
203    /// Returns whether the measurement was within the predicted bounds.
204    pub fn observe(&mut self, category: usize, actual_height: u16) -> bool {
205        // Ensure category exists.
206        while self.categories.len() <= category {
207            self.register_category();
208        }
209
210        let prediction = self.predict(category);
211        let within_bounds = actual_height >= prediction.lower && actual_height <= prediction.upper;
212
213        self.total_measurements += 1;
214        if !within_bounds && prediction.observations > 0 {
215            self.total_violations += 1;
216        }
217
218        let cat = &mut self.categories[category];
219        let h = actual_height as f64;
220
221        // Record calibration residual.
222        let residual = (cat.posterior_mean - h).abs();
223        cat.residuals.push_back(residual);
224        if cat.residuals.len() > self.config.calibration_window {
225            cat.residuals.pop_front();
226        }
227
228        // Update Welford stats.
229        cat.welford.update(h);
230
231        // Update posterior: μ_n = (κ₀·μ₀ + n·x̄) / κ_n
232        let n = cat.welford.n as f64;
233        let kappa_0 = self.config.prior_strength;
234        let mu_0 = self.config.prior_mean;
235        cat.posterior_kappa = kappa_0 + n;
236        cat.posterior_mean = (kappa_0 * mu_0 + n * cat.welford.mean) / cat.posterior_kappa;
237
238        within_bounds
239    }
240
241    /// Cold-start prediction when no data is available.
242    fn cold_prediction(&self) -> HeightPrediction {
243        let d = self.config.default_height;
244        let margin = (self.config.prior_variance.sqrt() * 2.0).ceil() as u16;
245        HeightPrediction {
246            predicted: d,
247            lower: d.saturating_sub(margin),
248            upper: d.saturating_add(margin),
249            observations: 0,
250        }
251    }
252
253    /// Compute conformal bounds from calibration residuals.
254    fn conformal_bounds(&self, cat: &CategoryState, mu: f64) -> (u16, u16) {
255        if cat.residuals.is_empty() {
256            // Fallback: use prior variance.
257            let margin = (self.config.prior_variance.sqrt() * 2.0).ceil() as u16;
258            let predicted = mu.round().max(1.0) as u16;
259            return (
260                predicted.saturating_sub(margin),
261                predicted.saturating_add(margin),
262            );
263        }
264
265        // Sort residuals to find quantile.
266        let mut sorted: Vec<f64> = cat.residuals.iter().copied().collect();
267        sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
268
269        let alpha = 1.0 - self.config.coverage;
270        let quantile_idx = ((1.0 - alpha) * sorted.len() as f64).ceil() as usize;
271        let quantile_idx = quantile_idx.min(sorted.len()).saturating_sub(1);
272        let q = sorted[quantile_idx];
273
274        let lower = (mu - q).max(1.0).floor() as u16;
275        let upper = (mu + q).ceil().max(1.0) as u16;
276
277        (lower, upper)
278    }
279
280    /// Get the posterior mean for a category.
281    pub fn posterior_mean(&self, category: usize) -> f64 {
282        self.categories
283            .get(category)
284            .map(|c| c.posterior_mean)
285            .unwrap_or(self.config.prior_mean)
286    }
287
288    /// Get the posterior variance for a category.
289    pub fn posterior_variance(&self, category: usize) -> f64 {
290        self.categories
291            .get(category)
292            .map(|c| {
293                let sigma_sq = if c.welford.n < 2 {
294                    self.config.prior_variance
295                } else {
296                    c.welford.variance()
297                };
298                sigma_sq / c.posterior_kappa
299            })
300            .unwrap_or(self.config.prior_variance)
301    }
302
303    /// Total measurements observed.
304    pub fn total_measurements(&self) -> u64 {
305        self.total_measurements
306    }
307
308    /// Total bound violations.
309    pub fn total_violations(&self) -> u64 {
310        self.total_violations
311    }
312
313    /// Empirical violation rate.
314    pub fn violation_rate(&self) -> f64 {
315        if self.total_measurements == 0 {
316            return 0.0;
317        }
318        self.total_violations as f64 / self.total_measurements as f64
319    }
320
321    /// Number of categories.
322    pub fn category_count(&self) -> usize {
323        self.categories.len()
324    }
325
326    /// Number of observations for a category.
327    pub fn category_observations(&self, category: usize) -> u64 {
328        self.categories
329            .get(category)
330            .map(|c| c.welford.n)
331            .unwrap_or(0)
332    }
333}
334
335impl Default for HeightPredictor {
336    fn default() -> Self {
337        Self::new(PredictorConfig::default())
338    }
339}
340
341#[cfg(test)]
342mod tests {
343    use super::*;
344
345    // ─── Posterior update tests ────────────────────────────────────
346
347    #[test]
348    fn unit_posterior_update() {
349        let config = PredictorConfig {
350            prior_mean: 2.0,
351            prior_strength: 1.0,
352            prior_variance: 4.0,
353            ..Default::default()
354        };
355        let mut pred = HeightPredictor::new(config);
356
357        // Prior: μ=2.0, κ=1.
358        assert!((pred.posterior_mean(0) - 2.0).abs() < 1e-10);
359
360        // Observe height 4.
361        pred.observe(0, 4);
362        // κ_1 = 1 + 1 = 2, μ_1 = (1*2 + 1*4) / 2 = 3.0
363        assert!((pred.posterior_mean(0) - 3.0).abs() < 1e-10);
364
365        // Observe another height 4.
366        pred.observe(0, 4);
367        // κ_2 = 1 + 2 = 3, x̄ = 4, μ_2 = (1*2 + 2*4) / 3 = 10/3 ≈ 3.333
368        assert!((pred.posterior_mean(0) - 10.0 / 3.0).abs() < 1e-10);
369    }
370
371    #[test]
372    fn unit_posterior_variance_decreases() {
373        let mut pred = HeightPredictor::new(PredictorConfig {
374            prior_variance: 4.0,
375            ..Default::default()
376        });
377
378        let var_0 = pred.posterior_variance(0);
379        assert!(var_0 > 0.0, "prior variance should be positive");
380
381        // Feed noisy data so Welford variance is non-zero.
382        for i in 0..10 {
383            pred.observe(0, if i % 2 == 0 { 2 } else { 4 });
384        }
385        let var_10 = pred.posterior_variance(0);
386
387        for i in 0..90 {
388            pred.observe(0, if i % 2 == 0 { 2 } else { 4 });
389        }
390        let var_100 = pred.posterior_variance(0);
391
392        // With noisy data, posterior variance σ²/κ_n decreases as κ_n grows.
393        assert!(
394            var_10 < var_0,
395            "variance should decrease: {var_10} >= {var_0}"
396        );
397        assert!(
398            var_100 < var_10,
399            "variance should decrease: {var_100} >= {var_10}"
400        );
401    }
402
403    // ─── Conformal bounds tests ───────────────────────────────────
404
405    #[test]
406    fn unit_conformal_bounds() {
407        let config = PredictorConfig {
408            coverage: 0.90,
409            prior_mean: 3.0,
410            prior_strength: 1.0,
411            ..Default::default()
412        };
413        let mut pred = HeightPredictor::new(config);
414
415        // Feed consistent data.
416        for _ in 0..50 {
417            pred.observe(0, 3);
418        }
419
420        let p = pred.predict(0);
421        // With all observations at 3, residuals should be near 0.
422        // Bounds should be tight around 3.
423        assert_eq!(p.predicted, 3);
424        assert!(p.lower <= 3);
425        assert!(p.upper >= 3);
426    }
427
428    #[test]
429    fn conformal_bounds_widen_with_noise() {
430        let config = PredictorConfig {
431            coverage: 0.90,
432            prior_mean: 5.0,
433            prior_strength: 1.0,
434            ..Default::default()
435        };
436        let mut pred = HeightPredictor::new(config);
437
438        // Consistent data → tight bounds.
439        for _ in 0..50 {
440            pred.observe(0, 5);
441        }
442        let tight = pred.predict(0);
443
444        // Reset with noisy data.
445        let mut pred2 = HeightPredictor::new(PredictorConfig {
446            coverage: 0.90,
447            prior_mean: 5.0,
448            prior_strength: 1.0,
449            ..Default::default()
450        });
451        let mut seed: u64 = 0xABCD_1234_5678_9ABC;
452        for _ in 0..50 {
453            seed = seed
454                .wrapping_mul(6364136223846793005)
455                .wrapping_add(1442695040888963407);
456            let h = 3 + (seed >> 62) as u16; // heights 3..6
457            pred2.observe(0, h);
458        }
459        let wide = pred2.predict(0);
460
461        assert!(
462            (wide.upper - wide.lower) >= (tight.upper - tight.lower),
463            "noisy data should produce wider bounds"
464        );
465    }
466
467    // ─── Coverage property test ───────────────────────────────────
468
469    #[test]
470    fn property_coverage() {
471        let alpha = 0.10;
472        let config = PredictorConfig {
473            coverage: 1.0 - alpha,
474            prior_mean: 3.0,
475            prior_strength: 2.0,
476            prior_variance: 4.0,
477            calibration_window: 100,
478            ..Default::default()
479        };
480        let mut pred = HeightPredictor::new(config);
481
482        // Warm up with calibration data.
483        let mut seed: u64 = 0xDEAD_BEEF_CAFE_0001;
484        for _ in 0..100 {
485            seed = seed
486                .wrapping_mul(6364136223846793005)
487                .wrapping_add(1442695040888963407);
488            let h = 2 + (seed >> 62) as u16; // heights 2..5
489            pred.observe(0, h);
490        }
491
492        // Now check coverage on new data.
493        let mut violations = 0u32;
494        let test_n = 200;
495        for _ in 0..test_n {
496            seed = seed
497                .wrapping_mul(6364136223846793005)
498                .wrapping_add(1442695040888963407);
499            let h = 2 + (seed >> 62) as u16;
500            let within = pred.observe(0, h);
501            if !within {
502                violations += 1;
503            }
504        }
505
506        let viol_rate = violations as f64 / test_n as f64;
507        // Empirical violation rate should be approximately ≤ α.
508        // Allow generous tolerance for finite sample + discrete heights.
509        assert!(
510            viol_rate <= alpha + 0.15,
511            "violation rate {viol_rate} exceeds α + tolerance ({alpha} + 0.15)"
512        );
513    }
514
515    // ─── Scroll stability test ────────────────────────────────────
516
517    #[test]
518    fn e2e_scroll_stability() {
519        let mut pred = HeightPredictor::new(PredictorConfig {
520            prior_mean: 1.0,
521            prior_strength: 2.0,
522            default_height: 1,
523            coverage: 0.90,
524            ..Default::default()
525        });
526
527        // All items are height 1 (most common TUI case).
528        let mut corrections = 0u32;
529        for _ in 0..500 {
530            let within = pred.observe(0, 1);
531            if !within {
532                corrections += 1;
533            }
534        }
535
536        // With homogeneous heights, should converge quickly with zero corrections
537        // after warmup.
538        let p = pred.predict(0);
539        assert_eq!(p.predicted, 1);
540        assert!(corrections < 10, "too many corrections: {corrections}");
541    }
542
543    // ─── Multiple categories ──────────────────────────────────────
544
545    #[test]
546    fn categories_are_independent() {
547        let mut pred = HeightPredictor::default();
548        let cat_a = 0;
549        let cat_b = pred.register_category();
550
551        // Feed different data to each.
552        for _ in 0..20 {
553            pred.observe(cat_a, 1);
554            pred.observe(cat_b, 5);
555        }
556
557        let pa = pred.predict(cat_a);
558        let pb = pred.predict(cat_b);
559
560        assert_eq!(pa.predicted, 1);
561        assert!(pb.predicted >= 4 && pb.predicted <= 5);
562    }
563
564    // ─── Cold start ───────────────────────────────────────────────
565
566    #[test]
567    fn cold_prediction_uses_default() {
568        let pred = HeightPredictor::new(PredictorConfig {
569            default_height: 2,
570            prior_variance: 1.0,
571            ..Default::default()
572        });
573        let p = pred.predict(0);
574        assert_eq!(p.predicted, 2);
575        assert_eq!(p.observations, 0);
576    }
577
578    // ─── Determinism ──────────────────────────────────────────────
579
580    #[test]
581    fn deterministic_under_same_observations() {
582        let run = || {
583            let mut pred = HeightPredictor::default();
584            let observations = [1, 2, 1, 3, 1, 2, 1, 1, 4, 1];
585            for &h in &observations {
586                pred.observe(0, h);
587            }
588            (pred.predict(0).predicted, pred.posterior_mean(0))
589        };
590
591        let (p1, m1) = run();
592        let (p2, m2) = run();
593        assert_eq!(p1, p2);
594        assert!((m1 - m2).abs() < 1e-15);
595    }
596
597    // ─── Performance ──────────────────────────────────────────────
598
599    #[test]
600    fn perf_prediction_overhead() {
601        let mut pred = HeightPredictor::default();
602
603        // Warm up.
604        for _ in 0..100 {
605            pred.observe(0, 2);
606        }
607
608        let start = std::time::Instant::now();
609        let mut _sink = 0u16;
610        for _ in 0..100_000 {
611            _sink = _sink.wrapping_add(pred.predict(0).predicted);
612        }
613        let elapsed = start.elapsed();
614        let per_prediction = elapsed / 100_000;
615
616        // Must be < 5μs per prediction (generous for debug builds).
617        assert!(
618            per_prediction < std::time::Duration::from_micros(5),
619            "prediction too slow: {per_prediction:?}"
620        );
621    }
622
623    // ─── Violation tracking ───────────────────────────────────────
624
625    #[test]
626    fn violation_tracking() {
627        let mut pred = HeightPredictor::new(PredictorConfig {
628            prior_mean: 5.0,
629            prior_strength: 100.0, // strong prior
630            default_height: 5,
631            coverage: 0.95,
632            ..Default::default()
633        });
634
635        // Warm up with height=5.
636        for _ in 0..50 {
637            pred.observe(0, 5);
638        }
639
640        // Sudden jump to height=20 should violate bounds.
641        let within = pred.observe(0, 20);
642        assert!(!within, "extreme outlier should violate bounds");
643        assert!(pred.total_violations() > 0);
644    }
645
646    // ── PredictorConfig defaults ─────────────────────────────────
647
648    #[test]
649    fn config_default_values() {
650        let config = PredictorConfig::default();
651        assert_eq!(config.default_height, 1);
652        assert!((config.prior_strength - 2.0).abs() < f64::EPSILON);
653        assert!((config.prior_mean - 1.0).abs() < f64::EPSILON);
654        assert!((config.prior_variance - 4.0).abs() < f64::EPSILON);
655        assert!((config.coverage - 0.90).abs() < f64::EPSILON);
656        assert_eq!(config.calibration_window, 200);
657    }
658
659    // ── HeightPredictor::default ─────────────────────────────────
660
661    #[test]
662    fn default_predictor_has_one_category() {
663        let pred = HeightPredictor::default();
664        assert_eq!(pred.category_count(), 1);
665        assert_eq!(pred.total_measurements(), 0);
666        assert_eq!(pred.total_violations(), 0);
667        assert!((pred.violation_rate() - 0.0).abs() < f64::EPSILON);
668    }
669
670    // ── Predict unknown category ─────────────────────────────────
671
672    #[test]
673    fn predict_unknown_category_returns_cold() {
674        let pred = HeightPredictor::default();
675        let p = pred.predict(999);
676        assert_eq!(p.predicted, pred.config.default_height);
677        assert_eq!(p.observations, 0);
678    }
679
680    // ── Observe auto-creates categories ──────────────────────────
681
682    #[test]
683    fn observe_auto_creates_categories() {
684        let mut pred = HeightPredictor::default();
685        assert_eq!(pred.category_count(), 1);
686        pred.observe(3, 5);
687        // Should auto-create categories 1, 2, 3
688        assert_eq!(pred.category_count(), 4);
689        assert_eq!(pred.category_observations(3), 1);
690    }
691
692    // ── Violation rate ───────────────────────────────────────────
693
694    #[test]
695    fn violation_rate_empty() {
696        let pred = HeightPredictor::default();
697        assert!((pred.violation_rate() - 0.0).abs() < f64::EPSILON);
698    }
699
700    #[test]
701    fn violation_rate_computation() {
702        let mut pred = HeightPredictor::new(PredictorConfig {
703            prior_mean: 5.0,
704            prior_strength: 100.0,
705            default_height: 5,
706            coverage: 0.95,
707            ..Default::default()
708        });
709        // Warm up so bounds are tight
710        for _ in 0..50 {
711            pred.observe(0, 5);
712        }
713        // 10 normal observations
714        for _ in 0..10 {
715            pred.observe(0, 5);
716        }
717        let before_violations = pred.total_violations();
718        // 1 extreme outlier
719        pred.observe(0, 100);
720        let after_violations = pred.total_violations();
721        assert!(after_violations > before_violations);
722        assert!(pred.violation_rate() > 0.0);
723    }
724
725    // ── Category accessors ───────────────────────────────────────
726
727    #[test]
728    fn category_observations_returns_zero_for_unknown() {
729        let pred = HeightPredictor::default();
730        assert_eq!(pred.category_observations(999), 0);
731    }
732
733    #[test]
734    fn category_observations_tracks_counts() {
735        let mut pred = HeightPredictor::default();
736        pred.observe(0, 3);
737        pred.observe(0, 4);
738        pred.observe(0, 5);
739        assert_eq!(pred.category_observations(0), 3);
740    }
741
742    // ── Posterior accessors with unknown category ─────────────────
743
744    #[test]
745    fn posterior_mean_unknown_returns_prior() {
746        let pred = HeightPredictor::default();
747        assert!((pred.posterior_mean(999) - pred.config.prior_mean).abs() < f64::EPSILON);
748    }
749
750    #[test]
751    fn posterior_variance_unknown_returns_prior() {
752        let pred = HeightPredictor::default();
753        assert!((pred.posterior_variance(999) - pred.config.prior_variance).abs() < f64::EPSILON);
754    }
755
756    // ── Register category ────────────────────────────────────────
757
758    #[test]
759    fn register_category_returns_sequential_ids() {
760        let mut pred = HeightPredictor::default();
761        let id1 = pred.register_category();
762        let id2 = pred.register_category();
763        assert_eq!(id1, 1);
764        assert_eq!(id2, 2);
765        assert_eq!(pred.category_count(), 3);
766    }
767
768    // ── Observe returns within_bounds ─────────────────────────────
769
770    #[test]
771    fn observe_returns_true_for_consistent_data() {
772        let mut pred = HeightPredictor::new(PredictorConfig {
773            prior_mean: 3.0,
774            prior_strength: 1.0,
775            ..Default::default()
776        });
777        // Warm up
778        for _ in 0..20 {
779            pred.observe(0, 3);
780        }
781        // Same value should be within bounds
782        assert!(pred.observe(0, 3));
783    }
784
785    // ── Total measurements ───────────────────────────────────────
786
787    #[test]
788    fn total_measurements_increments() {
789        let mut pred = HeightPredictor::default();
790        for i in 0..7 {
791            pred.observe(0, (i + 1) as u16);
792        }
793        assert_eq!(pred.total_measurements(), 7);
794    }
795
796    // ── HeightPrediction bounds ordering ─────────────────────────
797
798    #[test]
799    fn prediction_lower_le_predicted_le_upper() {
800        let mut pred = HeightPredictor::default();
801        for _ in 0..30 {
802            pred.observe(0, 3);
803        }
804        let p = pred.predict(0);
805        assert!(p.lower <= p.predicted);
806        assert!(p.predicted <= p.upper);
807    }
808
809    // ── Edge-case tests (bd-l9r1a) ──────────────────────────
810
811    #[test]
812    fn observe_height_zero() {
813        let mut pred = HeightPredictor::default();
814        pred.observe(0, 0);
815        let p = pred.predict(0);
816        // predicted is max(mu.round(), 1.0) so at least 1
817        assert!(p.predicted >= 1);
818    }
819
820    #[test]
821    fn observe_height_max_u16() {
822        let mut pred = HeightPredictor::default();
823        pred.observe(0, u16::MAX);
824        let p = pred.predict(0);
825        assert!(p.predicted > 0);
826        assert!(p.observations == 1);
827    }
828
829    #[test]
830    fn cold_prediction_zero_variance() {
831        let pred = HeightPredictor::new(PredictorConfig {
832            default_height: 5,
833            prior_variance: 0.0,
834            ..Default::default()
835        });
836        let p = pred.predict(0);
837        assert_eq!(p.predicted, 5);
838        // margin = ceil(sqrt(0.0) * 2.0) = 0
839        assert_eq!(p.lower, 5);
840        assert_eq!(p.upper, 5);
841    }
842
843    #[test]
844    fn cold_prediction_large_variance() {
845        let pred = HeightPredictor::new(PredictorConfig {
846            default_height: 1,
847            prior_variance: 10000.0,
848            ..Default::default()
849        });
850        let p = pred.predict(0);
851        assert_eq!(p.predicted, 1);
852        // margin = ceil(sqrt(10000) * 2) = ceil(200) = 200
853        assert_eq!(p.lower, 0); // 1.saturating_sub(200) = 0
854    }
855
856    #[test]
857    fn coverage_zero() {
858        let mut pred = HeightPredictor::new(PredictorConfig {
859            coverage: 0.0,
860            prior_mean: 3.0,
861            prior_strength: 1.0,
862            ..Default::default()
863        });
864        for _ in 0..20 {
865            pred.observe(0, 3);
866        }
867        // alpha = 1.0, quantile_idx → 0
868        let p = pred.predict(0);
869        assert!(p.predicted > 0);
870    }
871
872    #[test]
873    fn coverage_one() {
874        let mut pred = HeightPredictor::new(PredictorConfig {
875            coverage: 1.0,
876            prior_mean: 3.0,
877            prior_strength: 1.0,
878            ..Default::default()
879        });
880        for _ in 0..20 {
881            pred.observe(0, 3);
882        }
883        for _ in 0..5 {
884            pred.observe(0, 10);
885        }
886        // alpha = 0.0, quantile_idx → max residual
887        let p = pred.predict(0);
888        assert!(p.lower <= p.predicted);
889        assert!(p.predicted <= p.upper);
890    }
891
892    #[test]
893    fn calibration_window_one() {
894        let mut pred = HeightPredictor::new(PredictorConfig {
895            calibration_window: 1,
896            prior_mean: 3.0,
897            prior_strength: 1.0,
898            ..Default::default()
899        });
900        for _ in 0..10 {
901            pred.observe(0, 3);
902        }
903        let p = pred.predict(0);
904        assert!(p.predicted > 0);
905        assert!(p.lower <= p.predicted);
906    }
907
908    #[test]
909    fn single_observation_uses_wide_bounds() {
910        let mut pred = HeightPredictor::new(PredictorConfig {
911            prior_mean: 5.0,
912            prior_strength: 1.0,
913            prior_variance: 4.0,
914            ..Default::default()
915        });
916        pred.observe(0, 5);
917        let p = pred.predict(0);
918        assert_eq!(p.observations, 1);
919        // With only 1 residual, bounds come from that single residual
920        assert!(p.lower <= p.predicted);
921        assert!(p.predicted <= p.upper);
922    }
923
924    #[test]
925    fn predictor_config_clone_and_debug() {
926        let config = PredictorConfig::default();
927        let cloned = config.clone();
928        assert_eq!(cloned.default_height, config.default_height);
929        let dbg = format!("{:?}", config);
930        assert!(dbg.contains("PredictorConfig"));
931    }
932
933    #[test]
934    fn height_prediction_copy_and_debug() {
935        let p = HeightPrediction {
936            predicted: 3,
937            lower: 1,
938            upper: 5,
939            observations: 10,
940        };
941        let p2 = p; // Copy
942        assert_eq!(p.predicted, p2.predicted);
943        assert_eq!(p.lower, p2.lower);
944        assert_eq!(p.upper, p2.upper);
945        assert_eq!(p.observations, p2.observations);
946        let dbg = format!("{:?}", p);
947        assert!(dbg.contains("HeightPrediction"));
948    }
949
950    #[test]
951    fn height_prediction_clone() {
952        fn assert_clone<T: Clone>() {}
953        assert_clone::<HeightPrediction>();
954        let p = HeightPrediction {
955            predicted: 2,
956            lower: 1,
957            upper: 4,
958            observations: 5,
959        };
960        let cloned = p; // Copy implies Clone; clippy forbids clone_on_copy
961        assert_eq!(cloned.predicted, 2);
962    }
963
964    #[test]
965    fn predictor_clone_independence() {
966        let mut pred = HeightPredictor::default();
967        pred.observe(0, 5);
968        pred.observe(0, 5);
969        let mut cloned = pred.clone();
970        cloned.observe(0, 100);
971        // Original should be unaffected
972        assert_eq!(pred.total_measurements(), 2);
973        assert_eq!(cloned.total_measurements(), 3);
974    }
975
976    #[test]
977    fn predictor_debug() {
978        let pred = HeightPredictor::default();
979        let dbg = format!("{:?}", pred);
980        assert!(dbg.contains("HeightPredictor"));
981    }
982
983    #[test]
984    fn posterior_variance_with_two_identical_observations() {
985        let mut pred = HeightPredictor::new(PredictorConfig {
986            prior_variance: 4.0,
987            prior_strength: 1.0,
988            ..Default::default()
989        });
990        pred.observe(0, 3);
991        pred.observe(0, 3);
992        // Welford variance with identical values = 0, κ_n = 3
993        // posterior_variance = 0 / 3 = 0
994        let var = pred.posterior_variance(0);
995        assert!(var.abs() < 1e-10, "identical obs should give ~0 variance");
996    }
997
998    #[test]
999    fn posterior_variance_with_one_observation_uses_prior() {
1000        let mut pred = HeightPredictor::new(PredictorConfig {
1001            prior_variance: 4.0,
1002            prior_strength: 2.0,
1003            ..Default::default()
1004        });
1005        pred.observe(0, 3);
1006        // n=1, so welford.variance() returns f64::MAX → uses prior_variance
1007        // But wait: code checks n < 2, uses prior_variance = 4.0
1008        // posterior_variance = 4.0 / (2.0 + 1) = 4/3
1009        let var = pred.posterior_variance(0);
1010        assert!((var - 4.0 / 3.0).abs() < 1e-10);
1011    }
1012
1013    #[test]
1014    fn observe_returns_false_for_first_cold_outlier() {
1015        let mut pred = HeightPredictor::new(PredictorConfig {
1016            default_height: 1,
1017            prior_mean: 1.0,
1018            prior_strength: 2.0,
1019            prior_variance: 0.25,
1020            ..Default::default()
1021        });
1022        // Cold prediction: predicted=1, margin=ceil(sqrt(0.25)*2)=ceil(1.0)=1
1023        // bounds: [0, 2]
1024        // First observation is cold (observations=0), so violation not counted
1025        let within = pred.observe(0, 100);
1026        // Cold start: prediction.observations == 0, so violation is NOT counted
1027        assert!(within || pred.total_violations() == 0);
1028    }
1029
1030    #[test]
1031    fn all_same_height_converges_exactly() {
1032        let mut pred = HeightPredictor::new(PredictorConfig {
1033            prior_mean: 3.0,
1034            prior_strength: 1.0,
1035            ..Default::default()
1036        });
1037        for _ in 0..100 {
1038            pred.observe(0, 3);
1039        }
1040        let p = pred.predict(0);
1041        assert_eq!(p.predicted, 3);
1042        // With all identical observations, bounds should collapse
1043        assert_eq!(p.lower, 3);
1044        assert_eq!(p.upper, 3);
1045    }
1046
1047    #[test]
1048    fn many_categories_auto_created() {
1049        let mut pred = HeightPredictor::default();
1050        pred.observe(10, 5);
1051        // Categories 0..=10 should exist now
1052        assert_eq!(pred.category_count(), 11);
1053        // Intermediate categories have no observations
1054        assert_eq!(pred.category_observations(5), 0);
1055        assert_eq!(pred.category_observations(10), 1);
1056    }
1057
1058    #[test]
1059    fn prediction_bounds_ordering_after_mixed_data() {
1060        let mut pred = HeightPredictor::default();
1061        for h in [1, 2, 5, 10, 1, 3, 7, 2, 4, 6] {
1062            pred.observe(0, h);
1063        }
1064        let p = pred.predict(0);
1065        assert!(
1066            p.lower <= p.predicted,
1067            "lower={} > predicted={}",
1068            p.lower,
1069            p.predicted
1070        );
1071        assert!(
1072            p.predicted <= p.upper,
1073            "predicted={} > upper={}",
1074            p.predicted,
1075            p.upper
1076        );
1077    }
1078}