1use std::hash::Hash;
20
21use super::TransitionCounter;
22
/// Number of outgoing observations a screen needs, by default, before the
/// predictor fully trusts its observed transition frequencies.
const DEFAULT_MIN_OBSERVATIONS: u64 = 20;
25
/// Parameters for automatic, periodic decay of transition counts.
///
/// Every `interval` recorded transitions, the predictor asks its counter to
/// decay all counts by `factor`. With a factor below 1 this makes older
/// behaviour carry progressively less weight than recent behaviour.
#[derive(Debug, Clone)]
pub struct DecayConfig {
    /// Multiplier handed to the counter's `decay` each time the interval
    /// elapses (for example `0.5` halves the accumulated counts).
    pub factor: f64,
    /// Number of recorded transitions between consecutive decays.
    pub interval: u64,
}

impl Default for DecayConfig {
    /// Mild decay: counts are multiplied by `0.85` every 500 transitions.
    fn default() -> Self {
        let interval = 500;
        let factor = 0.85;
        Self { factor, interval }
    }
}
47
48#[derive(Debug, Clone)]
50struct DecayState {
51 config: Option<DecayConfig>,
52 transitions_since_last_decay: u64,
53}
54
55impl DecayState {
56 fn disabled() -> Self {
57 Self {
58 config: None,
59 transitions_since_last_decay: 0,
60 }
61 }
62
63 fn with_config(config: DecayConfig) -> Self {
64 Self {
65 config: Some(config),
66 transitions_since_last_decay: 0,
67 }
68 }
69
70 fn maybe_decay<S: Eq + Hash + Clone>(&mut self, counter: &mut TransitionCounter<S>) -> bool {
73 let config = match &self.config {
74 Some(c) => c,
75 None => return false,
76 };
77
78 self.transitions_since_last_decay += 1;
79 if self.transitions_since_last_decay >= config.interval {
80 counter.decay(config.factor);
81 self.transitions_since_last_decay = 0;
82 true
83 } else {
84 false
85 }
86 }
87}
88
/// One candidate next screen together with its blended probability.
#[derive(Debug, Clone)]
pub struct ScreenPrediction<S> {
    /// The predicted next screen.
    pub screen: S,
    /// Probability of moving to `screen`: the observed estimate blended with
    /// a uniform prior, weighted by `confidence`.
    pub probability: f64,
    /// How much the observed estimate was trusted, in `[0, 1]`.
    pub confidence: f64,
}
99
100#[derive(Debug, Clone)]
105pub struct MarkovPredictor<S: Eq + Hash + Clone> {
106 counter: TransitionCounter<S>,
107 min_observations: u64,
108 decay_state: DecayState,
109}
110
111impl<S: Eq + Hash + Clone> MarkovPredictor<S> {
112 #[must_use]
114 pub fn new() -> Self {
115 Self {
116 counter: TransitionCounter::new(),
117 min_observations: DEFAULT_MIN_OBSERVATIONS,
118 decay_state: DecayState::disabled(),
119 }
120 }
121
122 #[must_use]
124 pub fn with_min_observations(n: u64) -> Self {
125 Self {
126 counter: TransitionCounter::new(),
127 min_observations: n.max(1),
128 decay_state: DecayState::disabled(),
129 }
130 }
131
132 #[must_use]
134 pub fn with_counter(counter: TransitionCounter<S>, min_observations: u64) -> Self {
135 Self {
136 counter,
137 min_observations: min_observations.max(1),
138 decay_state: DecayState::disabled(),
139 }
140 }
141
142 pub fn enable_auto_decay(&mut self, config: DecayConfig) {
147 self.decay_state = DecayState::with_config(config);
148 }
149
150 pub fn record_transition(&mut self, from: S, to: S) {
155 self.counter.record(from, to);
156 self.decay_state.maybe_decay(&mut self.counter);
157 }
158
159 #[must_use]
165 pub fn predict(&self, current_screen: &S) -> Vec<ScreenPrediction<S>> {
166 let confidence = self.confidence(current_screen);
167 let ranked = self.counter.all_targets_ranked(current_screen);
168
169 if ranked.is_empty() {
170 return Vec::new();
171 }
172
173 let n_targets = ranked.len() as f64;
174 let uniform_prob = 1.0 / n_targets;
175
176 let mut predictions: Vec<ScreenPrediction<S>> = ranked
177 .into_iter()
178 .map(|(screen, raw_prob)| {
179 let effective = confidence * raw_prob + (1.0 - confidence) * uniform_prob;
180 ScreenPrediction {
181 screen,
182 probability: effective,
183 confidence,
184 }
185 })
186 .collect();
187
188 predictions.sort_by(|a, b| {
190 b.probability
191 .partial_cmp(&a.probability)
192 .unwrap_or(std::cmp::Ordering::Equal)
193 });
194
195 predictions
196 }
197
198 #[must_use]
203 pub fn is_cold_start(&self, screen: &S) -> bool {
204 (self.counter.total_from(screen) as u64) < self.min_observations
205 }
206
207 #[must_use]
211 pub fn confidence(&self, screen: &S) -> f64 {
212 let observations = self.counter.total_from(screen);
213 (observations / self.min_observations as f64).min(1.0)
214 }
215
216 #[must_use]
218 pub fn counter(&self) -> &TransitionCounter<S> {
219 &self.counter
220 }
221
222 pub fn counter_mut(&mut self) -> &mut TransitionCounter<S> {
224 &mut self.counter
225 }
226
227 #[must_use]
229 pub fn min_observations(&self) -> u64 {
230 self.min_observations
231 }
232}
233
234impl<S: Eq + Hash + Clone> Default for MarkovPredictor<S> {
235 fn default() -> Self {
236 Self::new()
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cold_start_returns_uniform_distribution() {
        let mut mp = MarkovPredictor::with_min_observations(20);
        mp.record_transition("a", "b");
        mp.record_transition("a", "c");

        let preds = mp.predict(&"a");
        assert_eq!(preds.len(), 2);

        // Only 2 of the 20 required observations: confidence is 0.1, so the
        // uniform prior dominates the blend.
        let diff = (preds[0].probability - preds[1].probability).abs();
        assert!(diff < 0.15, "cold start should be near-uniform, diff={diff}");
    }

    #[test]
    fn warm_predictions_match_observed() {
        let mut mp = MarkovPredictor::with_min_observations(10);

        for _ in 0..20 {
            mp.record_transition("a", "b");
        }
        for _ in 0..10 {
            mp.record_transition("a", "c");
        }

        let preds = mp.predict(&"a");
        assert_eq!(preds.len(), 2);

        // 30 observations exceed the threshold of 10, so confidence saturates.
        assert!((preds[0].confidence - 1.0).abs() < 1e-10);

        assert_eq!(preds[0].screen, "b");
        assert!(preds[0].probability > preds[1].probability);

        // At full confidence the blend reduces to the counter's own estimate.
        // These expected values are consistent with add-one smoothing in the
        // counter: (20 + 1) / (30 + 2) and (10 + 1) / (30 + 2).
        assert!((preds[0].probability - 21.0 / 32.0).abs() < 1e-10);
        assert!((preds[1].probability - 11.0 / 32.0).abs() < 1e-10);
    }

    #[test]
    fn confidence_increases_with_observations() {
        let mut mp = MarkovPredictor::with_min_observations(10);

        assert_eq!(mp.confidence(&"x"), 0.0);

        mp.record_transition("x", "y");
        assert!((mp.confidence(&"x") - 0.1).abs() < 1e-10);

        for _ in 0..4 {
            mp.record_transition("x", "y");
        }
        assert!((mp.confidence(&"x") - 0.5).abs() < 1e-10);

        for _ in 0..5 {
            mp.record_transition("x", "y");
        }
        assert!((mp.confidence(&"x") - 1.0).abs() < 1e-10);
    }

    #[test]
    fn confidence_caps_at_one() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        for _ in 0..100 {
            mp.record_transition("a", "b");
        }
        assert!((mp.confidence(&"a") - 1.0).abs() < 1e-10);
    }

    #[test]
    fn is_cold_start_reflects_threshold() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        assert!(mp.is_cold_start(&"x"));

        for _ in 0..4 {
            mp.record_transition("x", "y");
        }
        assert!(mp.is_cold_start(&"x"));

        // The fifth observation reaches the threshold exactly.
        mp.record_transition("x", "y");
        assert!(!mp.is_cold_start(&"x"));
    }

    #[test]
    fn min_observations_is_clamped_to_at_least_one() {
        let mp = MarkovPredictor::<&str>::with_min_observations(0);
        assert_eq!(mp.min_observations(), 1);

        let mp = MarkovPredictor::with_counter(TransitionCounter::<&str>::new(), 0);
        assert_eq!(mp.min_observations(), 1);
        // With a zero threshold this would be 0/0; the clamp keeps it at 0.
        assert_eq!(mp.confidence(&"a"), 0.0);
    }

    #[test]
    fn empty_predictor_returns_no_predictions() {
        let mp: MarkovPredictor<&str> = MarkovPredictor::new();
        let preds = mp.predict(&"x");
        assert!(preds.is_empty());
    }

    #[test]
    fn predictions_sorted_by_probability() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        for _ in 0..10 {
            mp.record_transition("a", "x");
        }
        for _ in 0..5 {
            mp.record_transition("a", "y");
        }
        mp.record_transition("a", "z");

        let preds = mp.predict(&"a");
        assert_eq!(preds.len(), 3);
        assert!(preds[0].probability >= preds[1].probability);
        assert!(preds[1].probability >= preds[2].probability);
    }

    #[test]
    fn probabilities_sum_to_approximately_one() {
        let mut mp = MarkovPredictor::with_min_observations(10);
        for _ in 0..15 {
            mp.record_transition("a", "b");
        }
        for _ in 0..8 {
            mp.record_transition("a", "c");
        }
        for _ in 0..3 {
            mp.record_transition("a", "d");
        }

        // A convex combination of two distributions is itself a distribution.
        let preds = mp.predict(&"a");
        let sum: f64 = preds.iter().map(|p| p.probability).sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "probabilities should sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn counter_access() {
        let mut mp = MarkovPredictor::<&str>::new();
        mp.record_transition("a", "b");

        assert_eq!(mp.counter().total(), 1.0);
        assert_eq!(mp.counter().count(&"a", &"b"), 1.0);
    }

    #[test]
    fn counter_mut_access() {
        let mut mp = MarkovPredictor::<&str>::new();
        mp.record_transition("a", "b");

        let mut other = TransitionCounter::new();
        other.record("a", "c");
        mp.counter_mut().merge(&other);

        assert_eq!(mp.counter().total(), 2.0);
    }

    #[test]
    fn with_counter_constructor() {
        let mut counter = TransitionCounter::new();
        for _ in 0..50 {
            counter.record("a", "b");
        }

        let mp = MarkovPredictor::with_counter(counter, 10);
        assert!(!mp.is_cold_start(&"a"));
        assert_eq!(mp.min_observations(), 10);
    }

    #[test]
    fn default_impl() {
        let mp: MarkovPredictor<String> = MarkovPredictor::default();
        assert_eq!(mp.min_observations(), DEFAULT_MIN_OBSERVATIONS);
        assert_eq!(mp.counter().total(), 0.0);
    }

    #[test]
    fn predict_returns_all_known_targets() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        mp.record_transition("a", "b");
        mp.record_transition("a", "c");
        mp.record_transition("a", "d");

        let preds = mp.predict(&"a");
        let screens: Vec<_> = preds.iter().map(|p| p.screen).collect();
        eprintln!("predicted screens: {screens:?}");
        assert_eq!(preds.len(), 3);
        assert!(screens.contains(&"b"));
        assert!(screens.contains(&"c"));
        assert!(screens.contains(&"d"));
    }

    #[test]
    fn predict_zero_outgoing_returns_empty() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        // "x" only ever appears as a destination, never as a source.
        mp.record_transition("a", "x");

        let preds = mp.predict(&"x");
        eprintln!("predictions from unseen source: len={}", preds.len());
        assert!(preds.is_empty());
    }

    #[test]
    fn record_transition_updates_predictions() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        mp.record_transition("a", "b");

        let preds_before = mp.predict(&"a");
        assert_eq!(preds_before.len(), 1);
        assert_eq!(preds_before[0].screen, "b");

        mp.record_transition("a", "c");
        let preds_after = mp.predict(&"a");
        eprintln!(
            "before: {} predictions, after: {} predictions",
            preds_before.len(),
            preds_after.len()
        );
        assert_eq!(preds_after.len(), 2);
        let screens: Vec<_> = preds_after.iter().map(|p| p.screen).collect();
        assert!(screens.contains(&"b"));
        assert!(screens.contains(&"c"));
    }

    #[test]
    fn predictions_change_with_new_transitions() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        for _ in 0..10 {
            mp.record_transition("a", "b");
        }
        mp.record_transition("a", "c");

        let preds1 = mp.predict(&"a");
        let prob_b1 = preds1.iter().find(|p| p.screen == "b").unwrap().probability;
        let prob_c1 = preds1.iter().find(|p| p.screen == "c").unwrap().probability;

        for _ in 0..50 {
            mp.record_transition("a", "c");
        }

        let preds2 = mp.predict(&"a");
        let prob_b2 = preds2.iter().find(|p| p.screen == "b").unwrap().probability;
        let prob_c2 = preds2.iter().find(|p| p.screen == "c").unwrap().probability;

        eprintln!("before: P(b)={prob_b1:.4}, P(c)={prob_c1:.4}");
        eprintln!("after: P(b)={prob_b2:.4}, P(c)={prob_c2:.4}");

        assert!(prob_c2 > prob_c1, "P(c) should increase with more transitions");
        assert!(prob_b2 < prob_b1, "P(b) should decrease as c dominates");
    }

    #[test]
    fn decay_via_counter_reduces_old_influence() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        for _ in 0..20 {
            mp.record_transition("a", "b");
        }
        mp.record_transition("a", "c");

        let preds_before = mp.predict(&"a");
        let prob_b_before = preds_before
            .iter()
            .find(|p| p.screen == "b")
            .unwrap()
            .probability;

        // Manually shrink all existing counts to 10% of their value.
        mp.counter_mut().decay(0.1);

        for _ in 0..5 {
            mp.record_transition("a", "c");
        }

        let preds_after = mp.predict(&"a");
        let prob_c_after = preds_after
            .iter()
            .find(|p| p.screen == "c")
            .unwrap()
            .probability;

        eprintln!(
            "before decay: P(b)={prob_b_before:.4}, after fresh c: P(c)={prob_c_after:.4}"
        );
        assert!(
            prob_c_after > prob_b_before * 0.5,
            "fresh transitions after decay should be influential"
        );
    }

    #[test]
    fn decay_shifts_predictions_toward_recent() {
        let mut mp = MarkovPredictor::with_min_observations(5);

        // Phase 1: "b" dominates.
        for _ in 0..20 {
            mp.record_transition("a", "b");
        }
        for _ in 0..5 {
            mp.record_transition("a", "c");
        }

        let p1 = mp.predict(&"a");
        let p1_b = p1.iter().find(|p| p.screen == "b").unwrap().probability;

        mp.counter_mut().decay(0.1);

        // Phase 2: the pattern flips and "c" dominates.
        for _ in 0..20 {
            mp.record_transition("a", "c");
        }
        for _ in 0..5 {
            mp.record_transition("a", "b");
        }

        let p2 = mp.predict(&"a");
        let p2_c = p2.iter().find(|p| p.screen == "c").unwrap().probability;

        eprintln!("phase1 P(b)={p1_b:.4}, phase2 P(c)={p2_c:.4}");
        assert!(
            p2_c > 0.5,
            "recent pattern should dominate after decay, got P(c)={p2_c}"
        );
    }

    #[test]
    fn screen_prediction_fields_are_populated() {
        let mut mp = MarkovPredictor::with_min_observations(10);
        for _ in 0..5 {
            mp.record_transition("a", "b");
        }
        mp.record_transition("a", "c");

        let preds = mp.predict(&"a");
        for pred in &preds {
            eprintln!(
                "screen={}, prob={:.4}, conf={:.4}",
                pred.screen, pred.probability, pred.confidence
            );
            assert!(pred.probability > 0.0, "probability should be > 0");
            assert!(pred.probability <= 1.0, "probability should be <= 1.0");
            assert!(pred.confidence >= 0.0, "confidence should be >= 0");
            assert!(pred.confidence <= 1.0, "confidence should be <= 1.0");
        }
    }

    #[test]
    fn confidence_always_in_unit_range() {
        let mut mp = MarkovPredictor::with_min_observations(10);

        let c0 = mp.confidence(&"x");
        assert!((0.0..=1.0).contains(&c0), "confidence={c0}");

        for i in 1..=20 {
            mp.record_transition("x", "y");
            let c = mp.confidence(&"x");
            eprintln!("obs={i}, confidence={c:.4}");
            assert!((0.0..=1.0).contains(&c), "confidence out of range: {c}");
        }
    }

    #[test]
    fn probability_always_positive_with_smoothing() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        for _ in 0..100 {
            mp.record_transition("a", "b");
        }
        mp.record_transition("a", "c");

        let preds = mp.predict(&"a");
        for pred in &preds {
            eprintln!("screen={}, prob={:.6}", pred.screen, pred.probability);
            assert!(
                pred.probability > 0.0,
                "all probabilities should be > 0 due to smoothing"
            );
        }
    }

    #[test]
    fn blending_transitions_smoothly() {
        let mut mp = MarkovPredictor::with_min_observations(10);

        for _ in 0..4 {
            mp.record_transition("a", "b");
        }
        mp.record_transition("a", "c");

        let preds = mp.predict(&"a");
        assert_eq!(preds.len(), 2);

        // 5 of the 10 required observations: confidence is exactly 0.5.
        let conf = mp.confidence(&"a");
        assert!((conf - 0.5).abs() < 1e-10);

        // Blend: 0.5 * raw + 0.5 * uniform, where the uniform prior is 1/2 over
        // the two targets. The raw estimates are consistent with add-one
        // smoothing: (4 + 1) / (5 + 2) and (1 + 1) / (5 + 2).
        let expected_b = 0.5 * (5.0 / 7.0) + 0.5 * 0.5;
        let expected_c = 0.5 * (2.0 / 7.0) + 0.5 * 0.5;

        assert_eq!(preds[0].screen, "b");
        assert!(
            (preds[0].probability - expected_b).abs() < 1e-10,
            "expected {expected_b}, got {}",
            preds[0].probability
        );
        assert!(
            (preds[1].probability - expected_c).abs() < 1e-10,
            "expected {expected_c}, got {}",
            preds[1].probability
        );

        let sum: f64 = preds.iter().map(|p| p.probability).sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn auto_decay_triggers_at_interval() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        mp.enable_auto_decay(DecayConfig {
            factor: 0.5,
            interval: 10,
        });

        for _ in 0..9 {
            mp.record_transition("a", "b");
        }
        assert_eq!(mp.counter().total(), 9.0);

        // The 10th transition completes the interval: 10 * 0.5 = 5.
        mp.record_transition("a", "b");
        let total = mp.counter().total();
        eprintln!("after 10 transitions with decay(0.5): total={total}");
        assert!(
            (total - 5.0).abs() < 1e-9,
            "expected ~5.0 after decay, got {total}"
        );
    }

    #[test]
    fn auto_decay_interval_resets_after_each_cycle() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        mp.enable_auto_decay(DecayConfig {
            factor: 0.5,
            interval: 5,
        });

        // First cycle: 5 * 0.5 = 2.5.
        for _ in 0..5 {
            mp.record_transition("a", "b");
        }
        let after_first = mp.counter().total();
        eprintln!("after first decay: {after_first}");
        assert!((after_first - 2.5).abs() < 1e-9);

        // Second cycle: (2.5 + 5) * 0.5 = 3.75.
        for _ in 0..5 {
            mp.record_transition("a", "b");
        }
        let after_second = mp.counter().total();
        eprintln!("after second decay: {after_second}");
        assert!(
            (after_second - 3.75).abs() < 1e-9,
            "expected ~3.75, got {after_second}"
        );
    }

    #[test]
    fn re_enabling_auto_decay_restarts_interval() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        let config = DecayConfig {
            factor: 0.5,
            interval: 3,
        };
        mp.enable_auto_decay(config.clone());
        mp.record_transition("a", "b");
        mp.record_transition("a", "b");

        // Re-enabling discards the two transitions already counted.
        mp.enable_auto_decay(config);
        mp.record_transition("a", "b");
        mp.record_transition("a", "b");
        assert_eq!(mp.counter().total(), 4.0);

        // The third transition since re-enabling triggers decay: 5 * 0.5.
        mp.record_transition("a", "b");
        let total = mp.counter().total();
        assert!(
            (total - 2.5).abs() < 1e-9,
            "expected ~2.5 after decay, got {total}"
        );
    }

    #[test]
    fn auto_decay_disabled_by_default() {
        let mut mp = MarkovPredictor::with_min_observations(5);

        for _ in 0..100 {
            mp.record_transition("a", "b");
        }
        assert_eq!(mp.counter().total(), 100.0);
    }

    #[test]
    fn auto_decay_recent_transitions_dominate() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        mp.enable_auto_decay(DecayConfig {
            factor: 0.1,
            interval: 20,
        });

        // The 20th "b" transition triggers a decay to 10% of the count.
        for _ in 0..20 {
            mp.record_transition("a", "b");
        }

        let b_after_decay = mp.counter().count(&"a", &"b");
        eprintln!("b after first decay: {b_after_decay}");

        // 15 < 20, so no further decay happens during these.
        for _ in 0..15 {
            mp.record_transition("a", "c");
        }

        let b_count = mp.counter().count(&"a", &"b");
        let c_count = mp.counter().count(&"a", &"c");
        eprintln!("b_count={b_count}, c_count={c_count}");
        assert!(
            c_count > b_count,
            "recent 'c' transitions ({c_count}) should exceed decayed 'b' ({b_count})"
        );
    }

    #[test]
    fn auto_decay_counter_consistency() {
        let mut mp = MarkovPredictor::with_min_observations(5);
        mp.enable_auto_decay(DecayConfig {
            factor: 0.8,
            interval: 10,
        });

        for _ in 0..30 {
            mp.record_transition("a", "b");
            mp.record_transition("a", "c");
            mp.record_transition("x", "y");
        }

        // The cached total must stay in sync with the per-pair counts after
        // repeated decays.
        let total = mp.counter().total();
        let mut sum = 0.0;
        for from in mp.counter().state_ids() {
            for (to, _) in mp.counter().all_targets_ranked(&from) {
                sum += mp.counter().count(&from, &to);
            }
        }
        eprintln!("total={total}, sum={sum}");
        assert!(
            (total - sum).abs() < 1e-6,
            "total({total}) should match sum of counts({sum})"
        );
    }

    #[test]
    fn decay_config_default() {
        let config = DecayConfig::default();
        assert!((config.factor - 0.85).abs() < 1e-10);
        assert_eq!(config.interval, 500);
    }
}