Skip to main content

ftui_widgets/
height_predictor.rs

1//! Bayesian height prediction with conformal bounds for virtualized lists.
2//!
3//! Predicts unseen row heights to pre-allocate scroll space and avoid
4//! scroll jumps when actual heights are measured lazily.
5//!
6//! # Mathematical Model
7//!
8//! ## Bayesian Online Estimation
9//!
10//! Maintains a Normal-Normal conjugate model per item category:
11//!
12//! ```text
13//! Prior:     μ ~ N(μ₀, σ₀²/κ₀)
14//! Likelihood: h_i ~ N(μ, σ²)
15//! Posterior:  μ | data ~ N(μ_n, σ²/κ_n)
16//!
17//! where:
18//!   κ_n = κ₀ + n
19//!   μ_n = (κ₀·μ₀ + n·x̄) / κ_n
20//!   σ²  estimated via running variance (Welford's algorithm)
21//! ```
22//!
23//! ## Conformal Prediction Bounds
24//!
25//! Given a calibration set of (predicted, actual) residuals, the conformal
26//! interval is:
27//!
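//! ```text
//! [μ_n - q_{1-α}, μ_n + q_{1-α}]
//! ```
//!
//! where `q_{1-α}` is the empirical (1 - α) quantile of the absolute residuals
//! |predicted - actual|. Because the residuals are already absolute values, a
//! symmetric interval needs the (1 - α) quantile rather than the (1 - α/2) one.
//! For exchangeable residuals this gives distribution-free coverage
//! P(h ∈ interval) ≥ 1 - α. Here the residuals are measured against a posterior
//! mean that keeps moving as data arrives, so the guarantee is approximate.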
34//!
35//! # Failure Modes
36//!
37//! | Condition | Behavior | Rationale |
38//! |-----------|----------|-----------|
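//! | No measurements | Return `default_height` ± ⌈2σ₀⌉ | Cold start fallback |
//! | n = 1 | Interval from the single calibration residual | Insufficient data |
//! | `calibration_window = 0` | Interval uses prior ± ⌈2σ₀⌉ | No residuals retained |
//! | All same height | σ → 0, residuals → 0, interval collapses | Homogeneous data |
//! | Actual outside bounds | Record violation, then update the model | Expected at rate ≈ α |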
43
44use std::collections::VecDeque;
45
/// Tuning knobs for the height predictor.
///
/// Start from `PredictorConfig::default()` and override individual fields with
/// struct-update syntax.
#[derive(Debug, Clone)]
pub struct PredictorConfig {
    /// Height (in rows) reported before any measurement exists. Default: `1`.
    pub default_height: u16,
    /// Prior strength κ₀: how many pseudo-observations back the prior mean.
    /// Larger values make the posterior slower to leave μ₀. Default: `2.0`.
    pub prior_strength: f64,
    /// Prior mean μ₀, normally equal to `default_height`. Default: `1.0`.
    pub prior_mean: f64,
    /// Prior variance σ₀², used whenever no data-driven spread is available.
    /// Default: `4.0`.
    pub prior_variance: f64,
    /// Target conformal coverage level, i.e. 1 - α. Default: `0.90`.
    pub coverage: f64,
    /// Maximum number of calibration residuals retained per category.
    /// Default: `200`.
    pub calibration_window: usize,
}

impl Default for PredictorConfig {
    /// The defaults target single-row TUI lists: height 1, weak prior, 90 % coverage.
    fn default() -> Self {
        Self {
            coverage: 0.90,
            calibration_window: 200,
            default_height: 1,
            prior_mean: 1.0,
            prior_strength: 2.0,
            prior_variance: 4.0,
        }
    }
}
75
/// Online count / mean / variance accumulator (Welford's update rule).
#[derive(Debug, Clone)]
struct WelfordStats {
    /// Number of samples folded in so far.
    n: u64,
    /// Running mean of the samples.
    mean: f64,
    /// Running sum of squared deviations from the mean (M₂).
    m2: f64,
}

impl WelfordStats {
    /// An accumulator that has seen no samples.
    fn new() -> Self {
        WelfordStats {
            n: 0,
            mean: 0.0,
            m2: 0.0,
        }
    }

    /// Fold one sample `x` into the running statistics.
    fn update(&mut self, x: f64) {
        self.n += 1;
        // Deviation from the mean *before* this sample is applied.
        let before = x - self.mean;
        self.mean += before / self.n as f64;
        // Deviation from the mean *after* this sample is applied.
        let after = x - self.mean;
        self.m2 += before * after;
    }

    /// Unbiased sample variance `M₂ / (n - 1)`.
    ///
    /// Returns `f64::MAX` as a "not enough data" sentinel when fewer than two
    /// samples have been seen.
    fn variance(&self) -> f64 {
        match self.n {
            0 | 1 => f64::MAX,
            count => self.m2 / (count - 1) as f64,
        }
    }
}
108
109/// Per-category prediction state.
110#[derive(Debug, Clone)]
111struct CategoryState {
112    /// Welford running stats for observed heights.
113    welford: WelfordStats,
114    /// Posterior mean μ_n.
115    posterior_mean: f64,
116    /// Posterior κ_n.
117    posterior_kappa: f64,
118    /// Calibration residuals |predicted - actual|.
119    residuals: VecDeque<f64>,
120}
121
/// A point prediction of a row height together with its conformal interval.
///
/// All values are measured in rows. Equality compares every field, which lets
/// callers cheaply detect whether a prediction changed between frames.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeightPrediction {
    /// Point prediction: the posterior mean rounded to whole rows, or the
    /// configured default height on cold start.
    pub predicted: u16,
    /// Lower bound of the interval (inclusive).
    pub lower: u16,
    /// Upper bound of the interval (inclusive).
    pub upper: u16,
    /// Number of measurements behind this prediction (`0` means cold start).
    pub observations: u64,
}
134
135/// Bayesian height predictor with conformal bounds.
136#[derive(Debug, Clone)]
137pub struct HeightPredictor {
138    config: PredictorConfig,
139    /// Per-category states. Key is category index (0 = default).
140    categories: Vec<CategoryState>,
141    /// Total measurements across all categories.
142    total_measurements: u64,
143    /// Total bound violations.
144    total_violations: u64,
145}
146
147impl HeightPredictor {
148    /// Create a new predictor with default config.
149    pub fn new(config: PredictorConfig) -> Self {
150        // Start with one default category.
151        let default_cat = CategoryState {
152            welford: WelfordStats::new(),
153            posterior_mean: config.prior_mean,
154            posterior_kappa: config.prior_strength,
155            residuals: VecDeque::new(),
156        };
157        Self {
158            config,
159            categories: vec![default_cat],
160            total_measurements: 0,
161            total_violations: 0,
162        }
163    }
164
165    /// Register a new category. Returns the category id.
166    pub fn register_category(&mut self) -> usize {
167        let id = self.categories.len();
168        self.categories.push(CategoryState {
169            welford: WelfordStats::new(),
170            posterior_mean: self.config.prior_mean,
171            posterior_kappa: self.config.prior_strength,
172            residuals: VecDeque::new(),
173        });
174        id
175    }
176
177    /// Predict height for an item in the given category.
178    pub fn predict(&self, category: usize) -> HeightPrediction {
179        let cat = match self.categories.get(category) {
180            Some(c) => c,
181            None => return self.cold_prediction(),
182        };
183
184        if cat.welford.n == 0 {
185            return self.cold_prediction();
186        }
187
188        let mu = cat.posterior_mean;
189        let predicted = mu.round().max(1.0) as u16;
190
191        // Conformal bounds from calibration residuals.
192        let (lower, upper) = self.conformal_bounds(cat, mu);
193
194        HeightPrediction {
195            predicted,
196            lower,
197            upper,
198            observations: cat.welford.n,
199        }
200    }
201
202    /// Record an actual measured height, updating the model.
203    /// Returns whether the measurement was within the predicted bounds.
204    pub fn observe(&mut self, category: usize, actual_height: u16) -> bool {
205        // Ensure category exists.
206        while self.categories.len() <= category {
207            self.register_category();
208        }
209
210        let prediction = self.predict(category);
211        let within_bounds = actual_height >= prediction.lower && actual_height <= prediction.upper;
212
213        self.total_measurements += 1;
214        if !within_bounds && prediction.observations > 0 {
215            self.total_violations += 1;
216        }
217
218        let cat = &mut self.categories[category];
219        let h = actual_height as f64;
220
221        // Record calibration residual.
222        let residual = (cat.posterior_mean - h).abs();
223        cat.residuals.push_back(residual);
224        if cat.residuals.len() > self.config.calibration_window {
225            cat.residuals.pop_front();
226        }
227
228        // Update Welford stats.
229        cat.welford.update(h);
230
231        // Update posterior: μ_n = (κ₀·μ₀ + n·x̄) / κ_n
232        let n = cat.welford.n as f64;
233        let kappa_0 = self.config.prior_strength;
234        let mu_0 = self.config.prior_mean;
235        cat.posterior_kappa = kappa_0 + n;
236        cat.posterior_mean = (kappa_0 * mu_0 + n * cat.welford.mean) / cat.posterior_kappa;
237
238        within_bounds
239    }
240
241    /// Cold-start prediction when no data is available.
242    fn cold_prediction(&self) -> HeightPrediction {
243        let d = self.config.default_height;
244        let margin = (self.config.prior_variance.sqrt() * 2.0).ceil() as u16;
245        HeightPrediction {
246            predicted: d,
247            lower: d.saturating_sub(margin),
248            upper: d.saturating_add(margin),
249            observations: 0,
250        }
251    }
252
253    /// Compute conformal bounds from calibration residuals.
254    fn conformal_bounds(&self, cat: &CategoryState, mu: f64) -> (u16, u16) {
255        if cat.residuals.is_empty() {
256            // Fallback: use prior variance.
257            let margin = (self.config.prior_variance.sqrt() * 2.0).ceil() as u16;
258            let predicted = mu.round().max(1.0) as u16;
259            return (
260                predicted.saturating_sub(margin),
261                predicted.saturating_add(margin),
262            );
263        }
264
265        // Sort residuals to find quantile.
266        let mut sorted: Vec<f64> = cat.residuals.iter().copied().collect();
267        sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
268
269        let alpha = 1.0 - self.config.coverage;
270        let n = sorted.len() as f64;
271        let quantile_idx = ((1.0 - alpha) * (n + 1.0)).ceil() as usize;
272        let quantile_idx = quantile_idx.min(sorted.len()).saturating_sub(1);
273        let q = sorted[quantile_idx];
274
275        let lower = (mu - q).max(1.0).floor() as u16;
276        let upper = (mu + q).ceil().max(1.0) as u16;
277
278        (lower, upper)
279    }
280
281    /// Get the posterior mean for a category.
282    pub fn posterior_mean(&self, category: usize) -> f64 {
283        self.categories
284            .get(category)
285            .map(|c| c.posterior_mean)
286            .unwrap_or(self.config.prior_mean)
287    }
288
289    /// Get the posterior variance for a category.
290    pub fn posterior_variance(&self, category: usize) -> f64 {
291        self.categories
292            .get(category)
293            .map(|c| {
294                let sigma_sq = if c.welford.n < 2 {
295                    self.config.prior_variance
296                } else {
297                    c.welford.variance()
298                };
299                sigma_sq / c.posterior_kappa
300            })
301            .unwrap_or(self.config.prior_variance)
302    }
303
304    /// Total measurements observed.
305    pub fn total_measurements(&self) -> u64 {
306        self.total_measurements
307    }
308
309    /// Total bound violations.
310    pub fn total_violations(&self) -> u64 {
311        self.total_violations
312    }
313
314    /// Empirical violation rate.
315    pub fn violation_rate(&self) -> f64 {
316        if self.total_measurements == 0 {
317            return 0.0;
318        }
319        self.total_violations as f64 / self.total_measurements as f64
320    }
321
322    /// Number of categories.
323    pub fn category_count(&self) -> usize {
324        self.categories.len()
325    }
326
327    /// Number of observations for a category.
328    pub fn category_observations(&self, category: usize) -> u64 {
329        self.categories
330            .get(category)
331            .map(|c| c.welford.n)
332            .unwrap_or(0)
333    }
334}
335
336impl Default for HeightPredictor {
337    fn default() -> Self {
338        Self::new(PredictorConfig::default())
339    }
340}
341
342#[cfg(test)]
343mod tests {
344    use super::*;
345
346    // ─── Posterior update tests ────────────────────────────────────
347
348    #[test]
349    fn unit_posterior_update() {
350        let config = PredictorConfig {
351            prior_mean: 2.0,
352            prior_strength: 1.0,
353            prior_variance: 4.0,
354            ..Default::default()
355        };
356        let mut pred = HeightPredictor::new(config);
357
358        // Prior: μ=2.0, κ=1.
359        assert!((pred.posterior_mean(0) - 2.0).abs() < 1e-10);
360
361        // Observe height 4.
362        pred.observe(0, 4);
363        // κ_1 = 1 + 1 = 2, μ_1 = (1*2 + 1*4) / 2 = 3.0
364        assert!((pred.posterior_mean(0) - 3.0).abs() < 1e-10);
365
366        // Observe another height 4.
367        pred.observe(0, 4);
368        // κ_2 = 1 + 2 = 3, x̄ = 4, μ_2 = (1*2 + 2*4) / 3 = 10/3 ≈ 3.333
369        assert!((pred.posterior_mean(0) - 10.0 / 3.0).abs() < 1e-10);
370    }
371
372    #[test]
373    fn unit_posterior_variance_decreases() {
374        let mut pred = HeightPredictor::new(PredictorConfig {
375            prior_variance: 4.0,
376            ..Default::default()
377        });
378
379        let var_0 = pred.posterior_variance(0);
380        assert!(var_0 > 0.0, "prior variance should be positive");
381
382        // Feed noisy data so Welford variance is non-zero.
383        for i in 0..10 {
384            pred.observe(0, if i % 2 == 0 { 2 } else { 4 });
385        }
386        let var_10 = pred.posterior_variance(0);
387
388        for i in 0..90 {
389            pred.observe(0, if i % 2 == 0 { 2 } else { 4 });
390        }
391        let var_100 = pred.posterior_variance(0);
392
393        // With noisy data, posterior variance σ²/κ_n decreases as κ_n grows.
394        assert!(
395            var_10 < var_0,
396            "variance should decrease: {var_10} >= {var_0}"
397        );
398        assert!(
399            var_100 < var_10,
400            "variance should decrease: {var_100} >= {var_10}"
401        );
402    }
403
404    // ─── Conformal bounds tests ───────────────────────────────────
405
406    #[test]
407    fn unit_conformal_bounds() {
408        let config = PredictorConfig {
409            coverage: 0.90,
410            prior_mean: 3.0,
411            prior_strength: 1.0,
412            ..Default::default()
413        };
414        let mut pred = HeightPredictor::new(config);
415
416        // Feed consistent data.
417        for _ in 0..50 {
418            pred.observe(0, 3);
419        }
420
421        let p = pred.predict(0);
422        // With all observations at 3, residuals should be near 0.
423        // Bounds should be tight around 3.
424        assert_eq!(p.predicted, 3);
425        assert!(p.lower <= 3);
426        assert!(p.upper >= 3);
427    }
428
429    #[test]
430    fn conformal_bounds_widen_with_noise() {
431        let config = PredictorConfig {
432            coverage: 0.90,
433            prior_mean: 5.0,
434            prior_strength: 1.0,
435            ..Default::default()
436        };
437        let mut pred = HeightPredictor::new(config);
438
439        // Consistent data → tight bounds.
440        for _ in 0..50 {
441            pred.observe(0, 5);
442        }
443        let tight = pred.predict(0);
444
445        // Reset with noisy data.
446        let mut pred2 = HeightPredictor::new(PredictorConfig {
447            coverage: 0.90,
448            prior_mean: 5.0,
449            prior_strength: 1.0,
450            ..Default::default()
451        });
452        let mut seed: u64 = 0xABCD_1234_5678_9ABC;
453        for _ in 0..50 {
454            seed = seed
455                .wrapping_mul(6364136223846793005)
456                .wrapping_add(1442695040888963407);
457            let h = 3 + (seed >> 62) as u16; // heights 3..6
458            pred2.observe(0, h);
459        }
460        let wide = pred2.predict(0);
461
462        assert!(
463            (wide.upper - wide.lower) >= (tight.upper - tight.lower),
464            "noisy data should produce wider bounds"
465        );
466    }
467
468    // ─── Coverage property test ───────────────────────────────────
469
470    #[test]
471    fn property_coverage() {
472        let alpha = 0.10;
473        let config = PredictorConfig {
474            coverage: 1.0 - alpha,
475            prior_mean: 3.0,
476            prior_strength: 2.0,
477            prior_variance: 4.0,
478            calibration_window: 100,
479            ..Default::default()
480        };
481        let mut pred = HeightPredictor::new(config);
482
483        // Warm up with calibration data.
484        let mut seed: u64 = 0xDEAD_BEEF_CAFE_0001;
485        for _ in 0..100 {
486            seed = seed
487                .wrapping_mul(6364136223846793005)
488                .wrapping_add(1442695040888963407);
489            let h = 2 + (seed >> 62) as u16; // heights 2..5
490            pred.observe(0, h);
491        }
492
493        // Now check coverage on new data.
494        let mut violations = 0u32;
495        let test_n = 200;
496        for _ in 0..test_n {
497            seed = seed
498                .wrapping_mul(6364136223846793005)
499                .wrapping_add(1442695040888963407);
500            let h = 2 + (seed >> 62) as u16;
501            let within = pred.observe(0, h);
502            if !within {
503                violations += 1;
504            }
505        }
506
507        let viol_rate = violations as f64 / test_n as f64;
508        // Empirical violation rate should be approximately ≤ α.
509        // Allow generous tolerance for finite sample + discrete heights.
510        assert!(
511            viol_rate <= alpha + 0.15,
512            "violation rate {viol_rate} exceeds α + tolerance ({alpha} + 0.15)"
513        );
514    }
515
516    // ─── Scroll stability test ────────────────────────────────────
517
518    #[test]
519    fn e2e_scroll_stability() {
520        let mut pred = HeightPredictor::new(PredictorConfig {
521            prior_mean: 1.0,
522            prior_strength: 2.0,
523            default_height: 1,
524            coverage: 0.90,
525            ..Default::default()
526        });
527
528        // All items are height 1 (most common TUI case).
529        let mut corrections = 0u32;
530        for _ in 0..500 {
531            let within = pred.observe(0, 1);
532            if !within {
533                corrections += 1;
534            }
535        }
536
537        // With homogeneous heights, should converge quickly with zero corrections
538        // after warmup.
539        let p = pred.predict(0);
540        assert_eq!(p.predicted, 1);
541        assert!(corrections < 10, "too many corrections: {corrections}");
542    }
543
544    // ─── Multiple categories ──────────────────────────────────────
545
546    #[test]
547    fn categories_are_independent() {
548        let mut pred = HeightPredictor::default();
549        let cat_a = 0;
550        let cat_b = pred.register_category();
551
552        // Feed different data to each.
553        for _ in 0..20 {
554            pred.observe(cat_a, 1);
555            pred.observe(cat_b, 5);
556        }
557
558        let pa = pred.predict(cat_a);
559        let pb = pred.predict(cat_b);
560
561        assert_eq!(pa.predicted, 1);
562        assert!(pb.predicted >= 4 && pb.predicted <= 5);
563    }
564
565    // ─── Cold start ───────────────────────────────────────────────
566
567    #[test]
568    fn cold_prediction_uses_default() {
569        let pred = HeightPredictor::new(PredictorConfig {
570            default_height: 2,
571            prior_variance: 1.0,
572            ..Default::default()
573        });
574        let p = pred.predict(0);
575        assert_eq!(p.predicted, 2);
576        assert_eq!(p.observations, 0);
577    }
578
579    // ─── Determinism ──────────────────────────────────────────────
580
581    #[test]
582    fn deterministic_under_same_observations() {
583        let run = || {
584            let mut pred = HeightPredictor::default();
585            let observations = [1, 2, 1, 3, 1, 2, 1, 1, 4, 1];
586            for &h in &observations {
587                pred.observe(0, h);
588            }
589            (pred.predict(0).predicted, pred.posterior_mean(0))
590        };
591
592        let (p1, m1) = run();
593        let (p2, m2) = run();
594        assert_eq!(p1, p2);
595        assert!((m1 - m2).abs() < 1e-15);
596    }
597
598    // ─── Performance ──────────────────────────────────────────────
599
600    #[test]
601    fn perf_prediction_overhead() {
602        let mut pred = HeightPredictor::default();
603
604        // Warm up.
605        for _ in 0..100 {
606            pred.observe(0, 2);
607        }
608
609        let start = std::time::Instant::now();
610        let mut _sink = 0u16;
611        for _ in 0..100_000 {
612            _sink = _sink.wrapping_add(pred.predict(0).predicted);
613        }
614        let elapsed = start.elapsed();
615        let per_prediction = elapsed / 100_000;
616
617        // Must be < 5μs per prediction (generous for debug builds).
618        assert!(
619            per_prediction < std::time::Duration::from_micros(5),
620            "prediction too slow: {per_prediction:?}"
621        );
622    }
623
624    // ─── Violation tracking ───────────────────────────────────────
625
626    #[test]
627    fn violation_tracking() {
628        let mut pred = HeightPredictor::new(PredictorConfig {
629            prior_mean: 5.0,
630            prior_strength: 100.0, // strong prior
631            default_height: 5,
632            coverage: 0.95,
633            ..Default::default()
634        });
635
636        // Warm up with height=5.
637        for _ in 0..50 {
638            pred.observe(0, 5);
639        }
640
641        // Sudden jump to height=20 should violate bounds.
642        let within = pred.observe(0, 20);
643        assert!(!within, "extreme outlier should violate bounds");
644        assert!(pred.total_violations() > 0);
645    }
646
647    // ── PredictorConfig defaults ─────────────────────────────────
648
649    #[test]
650    fn config_default_values() {
651        let config = PredictorConfig::default();
652        assert_eq!(config.default_height, 1);
653        assert!((config.prior_strength - 2.0).abs() < f64::EPSILON);
654        assert!((config.prior_mean - 1.0).abs() < f64::EPSILON);
655        assert!((config.prior_variance - 4.0).abs() < f64::EPSILON);
656        assert!((config.coverage - 0.90).abs() < f64::EPSILON);
657        assert_eq!(config.calibration_window, 200);
658    }
659
660    // ── HeightPredictor::default ─────────────────────────────────
661
662    #[test]
663    fn default_predictor_has_one_category() {
664        let pred = HeightPredictor::default();
665        assert_eq!(pred.category_count(), 1);
666        assert_eq!(pred.total_measurements(), 0);
667        assert_eq!(pred.total_violations(), 0);
668        assert!((pred.violation_rate() - 0.0).abs() < f64::EPSILON);
669    }
670
671    // ── Predict unknown category ─────────────────────────────────
672
673    #[test]
674    fn predict_unknown_category_returns_cold() {
675        let pred = HeightPredictor::default();
676        let p = pred.predict(999);
677        assert_eq!(p.predicted, pred.config.default_height);
678        assert_eq!(p.observations, 0);
679    }
680
681    // ── Observe auto-creates categories ──────────────────────────
682
683    #[test]
684    fn observe_auto_creates_categories() {
685        let mut pred = HeightPredictor::default();
686        assert_eq!(pred.category_count(), 1);
687        pred.observe(3, 5);
688        // Should auto-create categories 1, 2, 3
689        assert_eq!(pred.category_count(), 4);
690        assert_eq!(pred.category_observations(3), 1);
691    }
692
693    // ── Violation rate ───────────────────────────────────────────
694
695    #[test]
696    fn violation_rate_empty() {
697        let pred = HeightPredictor::default();
698        assert!((pred.violation_rate() - 0.0).abs() < f64::EPSILON);
699    }
700
701    #[test]
702    fn violation_rate_computation() {
703        let mut pred = HeightPredictor::new(PredictorConfig {
704            prior_mean: 5.0,
705            prior_strength: 100.0,
706            default_height: 5,
707            coverage: 0.95,
708            ..Default::default()
709        });
710        // Warm up so bounds are tight
711        for _ in 0..50 {
712            pred.observe(0, 5);
713        }
714        // 10 normal observations
715        for _ in 0..10 {
716            pred.observe(0, 5);
717        }
718        let before_violations = pred.total_violations();
719        // 1 extreme outlier
720        pred.observe(0, 100);
721        let after_violations = pred.total_violations();
722        assert!(after_violations > before_violations);
723        assert!(pred.violation_rate() > 0.0);
724    }
725
726    // ── Category accessors ───────────────────────────────────────
727
728    #[test]
729    fn category_observations_returns_zero_for_unknown() {
730        let pred = HeightPredictor::default();
731        assert_eq!(pred.category_observations(999), 0);
732    }
733
734    #[test]
735    fn category_observations_tracks_counts() {
736        let mut pred = HeightPredictor::default();
737        pred.observe(0, 3);
738        pred.observe(0, 4);
739        pred.observe(0, 5);
740        assert_eq!(pred.category_observations(0), 3);
741    }
742
743    // ── Posterior accessors with unknown category ─────────────────
744
745    #[test]
746    fn posterior_mean_unknown_returns_prior() {
747        let pred = HeightPredictor::default();
748        assert!((pred.posterior_mean(999) - pred.config.prior_mean).abs() < f64::EPSILON);
749    }
750
751    #[test]
752    fn posterior_variance_unknown_returns_prior() {
753        let pred = HeightPredictor::default();
754        assert!((pred.posterior_variance(999) - pred.config.prior_variance).abs() < f64::EPSILON);
755    }
756
757    // ── Register category ────────────────────────────────────────
758
759    #[test]
760    fn register_category_returns_sequential_ids() {
761        let mut pred = HeightPredictor::default();
762        let id1 = pred.register_category();
763        let id2 = pred.register_category();
764        assert_eq!(id1, 1);
765        assert_eq!(id2, 2);
766        assert_eq!(pred.category_count(), 3);
767    }
768
769    // ── Observe returns within_bounds ─────────────────────────────
770
771    #[test]
772    fn observe_returns_true_for_consistent_data() {
773        let mut pred = HeightPredictor::new(PredictorConfig {
774            prior_mean: 3.0,
775            prior_strength: 1.0,
776            ..Default::default()
777        });
778        // Warm up
779        for _ in 0..20 {
780            pred.observe(0, 3);
781        }
782        // Same value should be within bounds
783        assert!(pred.observe(0, 3));
784    }
785
786    // ── Total measurements ───────────────────────────────────────
787
788    #[test]
789    fn total_measurements_increments() {
790        let mut pred = HeightPredictor::default();
791        for i in 0..7 {
792            pred.observe(0, (i + 1) as u16);
793        }
794        assert_eq!(pred.total_measurements(), 7);
795    }
796
797    // ── HeightPrediction bounds ordering ─────────────────────────
798
799    #[test]
800    fn prediction_lower_le_predicted_le_upper() {
801        let mut pred = HeightPredictor::default();
802        for _ in 0..30 {
803            pred.observe(0, 3);
804        }
805        let p = pred.predict(0);
806        assert!(p.lower <= p.predicted);
807        assert!(p.predicted <= p.upper);
808    }
809
810    // ── Edge-case tests (bd-l9r1a) ──────────────────────────
811
812    #[test]
813    fn observe_height_zero() {
814        let mut pred = HeightPredictor::default();
815        pred.observe(0, 0);
816        let p = pred.predict(0);
817        // predicted is max(mu.round(), 1.0) so at least 1
818        assert!(p.predicted >= 1);
819    }
820
821    #[test]
822    fn observe_height_max_u16() {
823        let mut pred = HeightPredictor::default();
824        pred.observe(0, u16::MAX);
825        let p = pred.predict(0);
826        assert!(p.predicted > 0);
827        assert!(p.observations == 1);
828    }
829
830    #[test]
831    fn cold_prediction_zero_variance() {
832        let pred = HeightPredictor::new(PredictorConfig {
833            default_height: 5,
834            prior_variance: 0.0,
835            ..Default::default()
836        });
837        let p = pred.predict(0);
838        assert_eq!(p.predicted, 5);
839        // margin = ceil(sqrt(0.0) * 2.0) = 0
840        assert_eq!(p.lower, 5);
841        assert_eq!(p.upper, 5);
842    }
843
844    #[test]
845    fn cold_prediction_large_variance() {
846        let pred = HeightPredictor::new(PredictorConfig {
847            default_height: 1,
848            prior_variance: 10000.0,
849            ..Default::default()
850        });
851        let p = pred.predict(0);
852        assert_eq!(p.predicted, 1);
853        // margin = ceil(sqrt(10000) * 2) = ceil(200) = 200
854        assert_eq!(p.lower, 0); // 1.saturating_sub(200) = 0
855    }
856
857    #[test]
858    fn coverage_zero() {
859        let mut pred = HeightPredictor::new(PredictorConfig {
860            coverage: 0.0,
861            prior_mean: 3.0,
862            prior_strength: 1.0,
863            ..Default::default()
864        });
865        for _ in 0..20 {
866            pred.observe(0, 3);
867        }
868        // alpha = 1.0, quantile_idx → 0
869        let p = pred.predict(0);
870        assert!(p.predicted > 0);
871    }
872
873    #[test]
874    fn coverage_one() {
875        let mut pred = HeightPredictor::new(PredictorConfig {
876            coverage: 1.0,
877            prior_mean: 3.0,
878            prior_strength: 1.0,
879            ..Default::default()
880        });
881        for _ in 0..20 {
882            pred.observe(0, 3);
883        }
884        for _ in 0..5 {
885            pred.observe(0, 10);
886        }
887        // alpha = 0.0, quantile_idx → max residual
888        let p = pred.predict(0);
889        assert!(p.lower <= p.predicted);
890        assert!(p.predicted <= p.upper);
891    }
892
893    #[test]
894    fn calibration_window_one() {
895        let mut pred = HeightPredictor::new(PredictorConfig {
896            calibration_window: 1,
897            prior_mean: 3.0,
898            prior_strength: 1.0,
899            ..Default::default()
900        });
901        for _ in 0..10 {
902            pred.observe(0, 3);
903        }
904        let p = pred.predict(0);
905        assert!(p.predicted > 0);
906        assert!(p.lower <= p.predicted);
907    }
908
909    #[test]
910    fn single_observation_uses_wide_bounds() {
911        let mut pred = HeightPredictor::new(PredictorConfig {
912            prior_mean: 5.0,
913            prior_strength: 1.0,
914            prior_variance: 4.0,
915            ..Default::default()
916        });
917        pred.observe(0, 5);
918        let p = pred.predict(0);
919        assert_eq!(p.observations, 1);
920        // With only 1 residual, bounds come from that single residual
921        assert!(p.lower <= p.predicted);
922        assert!(p.predicted <= p.upper);
923    }
924
925    #[test]
926    fn predictor_config_clone_and_debug() {
927        let config = PredictorConfig::default();
928        let cloned = config.clone();
929        assert_eq!(cloned.default_height, config.default_height);
930        let dbg = format!("{:?}", config);
931        assert!(dbg.contains("PredictorConfig"));
932    }
933
934    #[test]
935    fn height_prediction_copy_and_debug() {
936        let p = HeightPrediction {
937            predicted: 3,
938            lower: 1,
939            upper: 5,
940            observations: 10,
941        };
942        let p2 = p; // Copy
943        assert_eq!(p.predicted, p2.predicted);
944        assert_eq!(p.lower, p2.lower);
945        assert_eq!(p.upper, p2.upper);
946        assert_eq!(p.observations, p2.observations);
947        let dbg = format!("{:?}", p);
948        assert!(dbg.contains("HeightPrediction"));
949    }
950
951    #[test]
952    fn height_prediction_clone() {
953        fn assert_clone<T: Clone>() {}
954        assert_clone::<HeightPrediction>();
955        let p = HeightPrediction {
956            predicted: 2,
957            lower: 1,
958            upper: 4,
959            observations: 5,
960        };
961        let cloned = p; // Copy implies Clone; clippy forbids clone_on_copy
962        assert_eq!(cloned.predicted, 2);
963    }
964
965    #[test]
966    fn predictor_clone_independence() {
967        let mut pred = HeightPredictor::default();
968        pred.observe(0, 5);
969        pred.observe(0, 5);
970        let mut cloned = pred.clone();
971        cloned.observe(0, 100);
972        // Original should be unaffected
973        assert_eq!(pred.total_measurements(), 2);
974        assert_eq!(cloned.total_measurements(), 3);
975    }
976
977    #[test]
978    fn predictor_debug() {
979        let pred = HeightPredictor::default();
980        let dbg = format!("{:?}", pred);
981        assert!(dbg.contains("HeightPredictor"));
982    }
983
984    #[test]
985    fn posterior_variance_with_two_identical_observations() {
986        let mut pred = HeightPredictor::new(PredictorConfig {
987            prior_variance: 4.0,
988            prior_strength: 1.0,
989            ..Default::default()
990        });
991        pred.observe(0, 3);
992        pred.observe(0, 3);
993        // Welford variance with identical values = 0, κ_n = 3
994        // posterior_variance = 0 / 3 = 0
995        let var = pred.posterior_variance(0);
996        assert!(var.abs() < 1e-10, "identical obs should give ~0 variance");
997    }
998
999    #[test]
1000    fn posterior_variance_with_one_observation_uses_prior() {
1001        let mut pred = HeightPredictor::new(PredictorConfig {
1002            prior_variance: 4.0,
1003            prior_strength: 2.0,
1004            ..Default::default()
1005        });
1006        pred.observe(0, 3);
1007        // n=1, so welford.variance() returns f64::MAX → uses prior_variance
1008        // But wait: code checks n < 2, uses prior_variance = 4.0
1009        // posterior_variance = 4.0 / (2.0 + 1) = 4/3
1010        let var = pred.posterior_variance(0);
1011        assert!((var - 4.0 / 3.0).abs() < 1e-10);
1012    }
1013
1014    #[test]
1015    fn observe_returns_false_for_first_cold_outlier() {
1016        let mut pred = HeightPredictor::new(PredictorConfig {
1017            default_height: 1,
1018            prior_mean: 1.0,
1019            prior_strength: 2.0,
1020            prior_variance: 0.25,
1021            ..Default::default()
1022        });
1023        // Cold prediction: predicted=1, margin=ceil(sqrt(0.25)*2)=ceil(1.0)=1
1024        // bounds: [0, 2]
1025        // First observation is cold (observations=0), so violation not counted
1026        let within = pred.observe(0, 100);
1027        // Cold start: prediction.observations == 0, so violation is NOT counted
1028        assert!(within || pred.total_violations() == 0);
1029    }
1030
1031    #[test]
1032    fn all_same_height_converges_exactly() {
1033        let mut pred = HeightPredictor::new(PredictorConfig {
1034            prior_mean: 3.0,
1035            prior_strength: 1.0,
1036            ..Default::default()
1037        });
1038        for _ in 0..100 {
1039            pred.observe(0, 3);
1040        }
1041        let p = pred.predict(0);
1042        assert_eq!(p.predicted, 3);
1043        // With all identical observations, bounds should collapse
1044        assert_eq!(p.lower, 3);
1045        assert_eq!(p.upper, 3);
1046    }
1047
1048    #[test]
1049    fn many_categories_auto_created() {
1050        let mut pred = HeightPredictor::default();
1051        pred.observe(10, 5);
1052        // Categories 0..=10 should exist now
1053        assert_eq!(pred.category_count(), 11);
1054        // Intermediate categories have no observations
1055        assert_eq!(pred.category_observations(5), 0);
1056        assert_eq!(pred.category_observations(10), 1);
1057    }
1058
1059    #[test]
1060    fn prediction_bounds_ordering_after_mixed_data() {
1061        let mut pred = HeightPredictor::default();
1062        for h in [1, 2, 5, 10, 1, 3, 7, 2, 4, 6] {
1063            pred.observe(0, h);
1064        }
1065        let p = pred.predict(0);
1066        assert!(
1067            p.lower <= p.predicted,
1068            "lower={} > predicted={}",
1069            p.lower,
1070            p.predicted
1071        );
1072        assert!(
1073            p.predicted <= p.upper,
1074            "predicted={} > upper={}",
1075            p.predicted,
1076            p.upper
1077        );
1078    }
1079}