1use std::collections::VecDeque;
45
/// Tunable parameters of the height predictor.
#[derive(Debug, Clone)]
pub struct PredictorConfig {
    /// Height reported for a category that has no observations yet.
    pub default_height: u16,
    /// Weight given to `prior_mean` when it is blended with the observed
    /// sample mean; it acts like a count of pseudo-observations.
    pub prior_strength: f64,
    /// Mean height assumed before any data is seen.
    pub prior_mean: f64,
    /// Variance assumed before enough data is seen; also sizes the
    /// cold-start interval (two standard deviations on each side).
    pub prior_variance: f64,
    /// Fraction of the stored residuals the prediction interval is sized to
    /// cover (e.g. `0.90`).
    pub coverage: f64,
    /// Maximum number of recent residuals kept per category.
    pub calibration_window: usize,
}
62
63impl Default for PredictorConfig {
64 fn default() -> Self {
65 Self {
66 default_height: 1,
67 prior_strength: 2.0,
68 prior_mean: 1.0,
69 prior_variance: 4.0,
70 coverage: 0.90,
71 calibration_window: 200,
72 }
73 }
74}
75
/// Running sample statistics, updated one value at a time.
#[derive(Debug, Clone)]
struct WelfordStats {
    /// Number of samples folded in so far.
    n: u64,
    /// Running mean of those samples.
    mean: f64,
    /// Running sum of squared deviations from the mean.
    m2: f64,
}
83
84impl WelfordStats {
85 fn new() -> Self {
86 Self {
87 n: 0,
88 mean: 0.0,
89 m2: 0.0,
90 }
91 }
92
93 fn update(&mut self, x: f64) {
94 self.n += 1;
95 let delta = x - self.mean;
96 self.mean += delta / self.n as f64;
97 let delta2 = x - self.mean;
98 self.m2 += delta * delta2;
99 }
100
101 fn variance(&self) -> f64 {
102 if self.n < 2 {
103 return f64::MAX;
104 }
105 self.m2 / (self.n - 1) as f64
106 }
107}
108
109#[derive(Debug, Clone)]
111struct CategoryState {
112 welford: WelfordStats,
114 posterior_mean: f64,
116 posterior_kappa: f64,
118 residuals: VecDeque<f64>,
120}
121
/// Result of a height prediction: a point estimate plus an interval.
///
/// Implements `PartialEq`/`Eq` so predictions can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeightPrediction {
    /// Point estimate of the height.
    pub predicted: u16,
    /// Inclusive lower end of the interval the height is expected to fall in.
    pub lower: u16,
    /// Inclusive upper end of the interval the height is expected to fall in.
    pub upper: u16,
    /// Number of observations the category had when the prediction was made;
    /// `0` marks a cold-start prediction.
    pub observations: u64,
}
134
135#[derive(Debug, Clone)]
137pub struct HeightPredictor {
138 config: PredictorConfig,
139 categories: Vec<CategoryState>,
141 total_measurements: u64,
143 total_violations: u64,
145}
146
147impl HeightPredictor {
148 pub fn new(config: PredictorConfig) -> Self {
150 let default_cat = CategoryState {
152 welford: WelfordStats::new(),
153 posterior_mean: config.prior_mean,
154 posterior_kappa: config.prior_strength,
155 residuals: VecDeque::new(),
156 };
157 Self {
158 config,
159 categories: vec![default_cat],
160 total_measurements: 0,
161 total_violations: 0,
162 }
163 }
164
165 pub fn register_category(&mut self) -> usize {
167 let id = self.categories.len();
168 self.categories.push(CategoryState {
169 welford: WelfordStats::new(),
170 posterior_mean: self.config.prior_mean,
171 posterior_kappa: self.config.prior_strength,
172 residuals: VecDeque::new(),
173 });
174 id
175 }
176
177 pub fn predict(&self, category: usize) -> HeightPrediction {
179 let cat = match self.categories.get(category) {
180 Some(c) => c,
181 None => return self.cold_prediction(),
182 };
183
184 if cat.welford.n == 0 {
185 return self.cold_prediction();
186 }
187
188 let mu = cat.posterior_mean;
189 let predicted = mu.round().max(1.0) as u16;
190
191 let (lower, upper) = self.conformal_bounds(cat, mu);
193
194 HeightPrediction {
195 predicted,
196 lower,
197 upper,
198 observations: cat.welford.n,
199 }
200 }
201
202 pub fn observe(&mut self, category: usize, actual_height: u16) -> bool {
205 while self.categories.len() <= category {
207 self.register_category();
208 }
209
210 let prediction = self.predict(category);
211 let within_bounds = actual_height >= prediction.lower && actual_height <= prediction.upper;
212
213 self.total_measurements += 1;
214 if !within_bounds && prediction.observations > 0 {
215 self.total_violations += 1;
216 }
217
218 let cat = &mut self.categories[category];
219 let h = actual_height as f64;
220
221 let residual = (cat.posterior_mean - h).abs();
223 cat.residuals.push_back(residual);
224 if cat.residuals.len() > self.config.calibration_window {
225 cat.residuals.pop_front();
226 }
227
228 cat.welford.update(h);
230
231 let n = cat.welford.n as f64;
233 let kappa_0 = self.config.prior_strength;
234 let mu_0 = self.config.prior_mean;
235 cat.posterior_kappa = kappa_0 + n;
236 cat.posterior_mean = (kappa_0 * mu_0 + n * cat.welford.mean) / cat.posterior_kappa;
237
238 within_bounds
239 }
240
241 fn cold_prediction(&self) -> HeightPrediction {
243 let d = self.config.default_height;
244 let margin = (self.config.prior_variance.sqrt() * 2.0).ceil() as u16;
245 HeightPrediction {
246 predicted: d,
247 lower: d.saturating_sub(margin),
248 upper: d.saturating_add(margin),
249 observations: 0,
250 }
251 }
252
253 fn conformal_bounds(&self, cat: &CategoryState, mu: f64) -> (u16, u16) {
255 if cat.residuals.is_empty() {
256 let margin = (self.config.prior_variance.sqrt() * 2.0).ceil() as u16;
258 let predicted = mu.round().max(1.0) as u16;
259 return (
260 predicted.saturating_sub(margin),
261 predicted.saturating_add(margin),
262 );
263 }
264
265 let mut sorted: Vec<f64> = cat.residuals.iter().copied().collect();
267 sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
268
269 let alpha = 1.0 - self.config.coverage;
270 let quantile_idx = ((1.0 - alpha) * sorted.len() as f64).ceil() as usize;
271 let quantile_idx = quantile_idx.min(sorted.len()).saturating_sub(1);
272 let q = sorted[quantile_idx];
273
274 let lower = (mu - q).max(1.0).floor() as u16;
275 let upper = (mu + q).ceil().max(1.0) as u16;
276
277 (lower, upper)
278 }
279
280 pub fn posterior_mean(&self, category: usize) -> f64 {
282 self.categories
283 .get(category)
284 .map(|c| c.posterior_mean)
285 .unwrap_or(self.config.prior_mean)
286 }
287
288 pub fn posterior_variance(&self, category: usize) -> f64 {
290 self.categories
291 .get(category)
292 .map(|c| {
293 let sigma_sq = if c.welford.n < 2 {
294 self.config.prior_variance
295 } else {
296 c.welford.variance()
297 };
298 sigma_sq / c.posterior_kappa
299 })
300 .unwrap_or(self.config.prior_variance)
301 }
302
303 pub fn total_measurements(&self) -> u64 {
305 self.total_measurements
306 }
307
308 pub fn total_violations(&self) -> u64 {
310 self.total_violations
311 }
312
313 pub fn violation_rate(&self) -> f64 {
315 if self.total_measurements == 0 {
316 return 0.0;
317 }
318 self.total_violations as f64 / self.total_measurements as f64
319 }
320
321 pub fn category_count(&self) -> usize {
323 self.categories.len()
324 }
325
326 pub fn category_observations(&self, category: usize) -> u64 {
328 self.categories
329 .get(category)
330 .map(|c| c.welford.n)
331 .unwrap_or(0)
332 }
333}
334
335impl Default for HeightPredictor {
336 fn default() -> Self {
337 Self::new(PredictorConfig::default())
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Multiplier of the 64-bit linear congruential generator that produces
    /// the reproducible noisy heights used below.
    const LCG_MULTIPLIER: u64 = 6364136223846793005;
    /// Increment of the same generator.
    const LCG_INCREMENT: u64 = 1442695040888963407;

    /// Steps the generator in place and returns the top two bits of the new
    /// state, i.e. a value in `0..=3`.
    fn next_noise(state: &mut u64) -> u16 {
        *state = state
            .wrapping_mul(LCG_MULTIPLIER)
            .wrapping_add(LCG_INCREMENT);
        (*state >> 62) as u16
    }

    /// Records `times` observations of the same `height` for `category`.
    fn observe_repeated(pred: &mut HeightPredictor, category: usize, height: u16, times: usize) {
        for _ in 0..times {
            pred.observe(category, height);
        }
    }

    /// Records `times` observations for category 0, alternating 2, 4, 2, 4, …
    fn observe_alternating(pred: &mut HeightPredictor, times: usize) {
        for height in [2u16, 4].iter().copied().cycle().take(times) {
            pred.observe(0, height);
        }
    }

    /// The posterior mean starts at the prior and moves towards the data.
    #[test]
    fn unit_posterior_update() {
        let mut pred = HeightPredictor::new(PredictorConfig {
            prior_variance: 4.0,
            prior_strength: 1.0,
            prior_mean: 2.0,
            ..Default::default()
        });

        // No data yet: the posterior mean is the prior mean.
        assert!((pred.posterior_mean(0) - 2.0).abs() < 1e-10);

        // (1*2 + 1*4) / 2 = 3
        pred.observe(0, 4);
        assert!((pred.posterior_mean(0) - 3.0).abs() < 1e-10);

        // (1*2 + 2*4) / 3 = 10/3
        pred.observe(0, 4);
        assert!((pred.posterior_mean(0) - 10.0 / 3.0).abs() < 1e-10);
    }

    /// More observations shrink the variance of the mean estimate.
    #[test]
    fn unit_posterior_variance_decreases() {
        let mut pred = HeightPredictor::new(PredictorConfig {
            prior_variance: 4.0,
            ..Default::default()
        });

        let var_0 = pred.posterior_variance(0);
        assert!(var_0 > 0.0, "prior variance should be positive");

        observe_alternating(&mut pred, 10);
        let var_10 = pred.posterior_variance(0);

        observe_alternating(&mut pred, 90);
        let var_100 = pred.posterior_variance(0);

        assert!(
            var_10 < var_0,
            "variance should decrease: {var_10} >= {var_0}"
        );
        assert!(
            var_100 < var_10,
            "variance should decrease: {var_100} >= {var_10}"
        );
    }

    /// Constant data yields an interval that contains the constant.
    #[test]
    fn unit_conformal_bounds() {
        let mut pred = HeightPredictor::new(PredictorConfig {
            prior_strength: 1.0,
            prior_mean: 3.0,
            coverage: 0.90,
            ..Default::default()
        });

        observe_repeated(&mut pred, 0, 3, 50);

        let p = pred.predict(0);
        assert_eq!(p.predicted, 3);
        assert!(p.lower <= 3);
        assert!(p.upper >= 3);
    }

    /// Noisy data must not give a narrower interval than constant data.
    #[test]
    fn conformal_bounds_widen_with_noise() {
        let fresh = || {
            HeightPredictor::new(PredictorConfig {
                coverage: 0.90,
                prior_mean: 5.0,
                prior_strength: 1.0,
                ..Default::default()
            })
        };

        let mut pred = fresh();
        observe_repeated(&mut pred, 0, 5, 50);
        let tight = pred.predict(0);

        let mut pred2 = fresh();
        let mut seed: u64 = 0xABCD_1234_5678_9ABC;
        for _ in 0..50 {
            let h = 3 + next_noise(&mut seed);
            pred2.observe(0, h);
        }
        let wide = pred2.predict(0);

        assert!(
            (wide.upper - wide.lower) >= (tight.upper - tight.lower),
            "noisy data should produce wider bounds"
        );
    }

    /// After a warm-up, the empirical violation rate stays near the target.
    #[test]
    fn property_coverage() {
        let alpha = 0.10;
        let mut pred = HeightPredictor::new(PredictorConfig {
            coverage: 1.0 - alpha,
            prior_mean: 3.0,
            prior_strength: 2.0,
            prior_variance: 4.0,
            calibration_window: 100,
            ..Default::default()
        });

        let mut seed: u64 = 0xDEAD_BEEF_CAFE_0001;

        // Warm-up: fill the calibration window.
        for _ in 0..100 {
            let h = 2 + next_noise(&mut seed);
            pred.observe(0, h);
        }

        // Evaluation: count how often the interval misses.
        let test_n = 200;
        let violations = (0..test_n)
            .filter(|_| {
                let h = 2 + next_noise(&mut seed);
                !pred.observe(0, h)
            })
            .count() as u32;

        let viol_rate = violations as f64 / test_n as f64;
        assert!(
            viol_rate <= alpha + 0.15,
            "violation rate {viol_rate} exceeds α + tolerance ({alpha} + 0.15)"
        );
    }

    /// A stream of identical heights causes almost no out-of-bounds reports.
    #[test]
    fn e2e_scroll_stability() {
        let mut pred = HeightPredictor::new(PredictorConfig {
            coverage: 0.90,
            default_height: 1,
            prior_strength: 2.0,
            prior_mean: 1.0,
            ..Default::default()
        });

        let corrections = (0..500).filter(|_| !pred.observe(0, 1)).count() as u32;

        let p = pred.predict(0);
        assert_eq!(p.predicted, 1);
        assert!(corrections < 10, "too many corrections: {corrections}");
    }

    /// Observations for one category do not leak into another.
    #[test]
    fn categories_are_independent() {
        let mut pred = HeightPredictor::default();
        let cat_a = 0;
        let cat_b = pred.register_category();

        for _ in 0..20 {
            pred.observe(cat_a, 1);
            pred.observe(cat_b, 5);
        }

        let pa = pred.predict(cat_a);
        let pb = pred.predict(cat_b);

        assert_eq!(pa.predicted, 1);
        assert!(pb.predicted >= 4 && pb.predicted <= 5);
    }

    /// Without data the prediction is the configured default height.
    #[test]
    fn cold_prediction_uses_default() {
        let config = PredictorConfig {
            prior_variance: 1.0,
            default_height: 2,
            ..Default::default()
        };
        let pred = HeightPredictor::new(config);

        let p = pred.predict(0);
        assert_eq!(p.predicted, 2);
        assert_eq!(p.observations, 0);
    }

    /// The same observation sequence always produces the same estimate.
    #[test]
    fn deterministic_under_same_observations() {
        fn run() -> (u16, f64) {
            let mut pred = HeightPredictor::default();
            [1u16, 2, 1, 3, 1, 2, 1, 1, 4, 1].iter().for_each(|&h| {
                pred.observe(0, h);
            });
            (pred.predict(0).predicted, pred.posterior_mean(0))
        }

        let (p1, m1) = run();
        let (p2, m2) = run();
        assert_eq!(p1, p2);
        assert!((m1 - m2).abs() < 1e-15);
    }

    /// Budget check: one prediction must stay under 5 microseconds.
    // NOTE(review): this is a wall-clock assertion; it may be sensitive to
    // unoptimised builds or loaded machines - confirm it is stable in CI.
    #[test]
    fn perf_prediction_overhead() {
        let mut pred = HeightPredictor::default();
        observe_repeated(&mut pred, 0, 2, 100);

        let start = std::time::Instant::now();
        let mut _sink = 0u16;
        for _ in 0..100_000 {
            _sink = _sink.wrapping_add(pred.predict(0).predicted);
        }
        let elapsed = start.elapsed();
        let per_prediction = elapsed / 100_000;

        assert!(
            per_prediction < std::time::Duration::from_micros(5),
            "prediction too slow: {per_prediction:?}"
        );
    }

    /// An extreme outlier is reported as out of bounds and counted.
    #[test]
    fn violation_tracking() {
        let mut pred = HeightPredictor::new(PredictorConfig {
            coverage: 0.95,
            default_height: 5,
            prior_strength: 100.0,
            prior_mean: 5.0,
            ..Default::default()
        });

        observe_repeated(&mut pred, 0, 5, 50);

        let within = pred.observe(0, 20);
        assert!(!within, "extreme outlier should violate bounds");
        assert!(pred.total_violations() > 0);
    }
}