use std::hash::Hash;
use super::TransitionCounter;
/// Observation threshold installed by `MarkovPredictor::new`: a screen with
/// fewer recorded outgoing transitions than this is reported as cold-start and
/// gets a confidence below 1.0.
const DEFAULT_MIN_OBSERVATIONS: u64 = 20;
/// Settings that control automatic decay of recorded transition counts.
///
/// Hand an instance to `MarkovPredictor::enable_auto_decay` to switch the
/// feature on.
#[derive(Clone, Debug)]
pub struct DecayConfig {
    /// Value passed to the counter's `decay` call every time a decay fires.
    pub factor: f64,
    /// How many recorded transitions must accumulate before a decay fires.
    pub interval: u64,
}
impl Default for DecayConfig {
fn default() -> Self {
Self {
factor: 0.85,
interval: 500,
}
}
}
/// Internal bookkeeping for auto-decay: the optional settings plus a tally of
/// transitions counted since the last decay.
#[derive(Clone, Debug)]
struct DecayState {
    /// `None` while auto-decay is switched off.
    config: Option<DecayConfig>,
    /// Transitions counted since the most recent decay; reset to zero each
    /// time a decay is applied.
    transitions_since_last_decay: u64,
}
impl DecayState {
fn disabled() -> Self {
Self {
config: None,
transitions_since_last_decay: 0,
}
}
fn with_config(config: DecayConfig) -> Self {
Self {
config: Some(config),
transitions_since_last_decay: 0,
}
}
fn maybe_decay<S: Eq + Hash + Clone>(&mut self, counter: &mut TransitionCounter<S>) -> bool {
let config = match &self.config {
Some(c) => c,
None => return false,
};
self.transitions_since_last_decay += 1;
if self.transitions_since_last_decay >= config.interval {
counter.decay(config.factor);
self.transitions_since_last_decay = 0;
true
} else {
false
}
}
}
/// One candidate next screen together with its blended probability.
#[derive(Clone, Debug)]
pub struct ScreenPrediction<S> {
    /// The predicted target screen.
    pub screen: S,
    /// Probability assigned to `screen` after blending with a uniform prior.
    pub probability: f64,
    /// Confidence of the source screen the prediction was made from.
    pub confidence: f64,
}
/// Next-screen predictor backed by a [`TransitionCounter`].
///
/// Holds the transition counts, the observation threshold used to derive
/// confidence, and the auto-decay bookkeeping.
#[derive(Clone, Debug)]
pub struct MarkovPredictor<S>
where
    S: Eq + Hash + Clone,
{
    /// Recorded transition counts.
    counter: TransitionCounter<S>,
    /// Observations from a screen needed for full confidence; always >= 1
    /// when built through the provided constructors.
    min_observations: u64,
    /// Auto-decay settings and progress.
    decay_state: DecayState,
}
impl<S: Eq + Hash + Clone> MarkovPredictor<S> {
    /// Creates an empty predictor with the default observation threshold
    /// (`DEFAULT_MIN_OBSERVATIONS`) and auto-decay disabled.
    #[must_use]
    pub fn new() -> Self {
        Self::with_counter(TransitionCounter::new(), DEFAULT_MIN_OBSERVATIONS)
    }

    /// Creates an empty predictor that needs `n` observations from a screen
    /// before it is fully confident about it.
    ///
    /// A value of `0` is raised to `1` so the confidence ratio stays
    /// well-defined.
    #[must_use]
    pub fn with_min_observations(n: u64) -> Self {
        Self::with_counter(TransitionCounter::new(), n)
    }

    /// Creates a predictor around an existing `counter`.
    ///
    /// `min_observations` is raised to `1` when `0` is passed. Auto-decay
    /// starts disabled.
    #[must_use]
    pub fn with_counter(counter: TransitionCounter<S>, min_observations: u64) -> Self {
        Self {
            counter,
            min_observations: min_observations.max(1),
            decay_state: DecayState::disabled(),
        }
    }

    /// Turns on auto-decay with the given settings.
    ///
    /// Any previous decay state is replaced, so the tally of transitions
    /// towards the next decay restarts from zero.
    pub fn enable_auto_decay(&mut self, config: DecayConfig) {
        self.decay_state = DecayState::with_config(config);
    }

    /// Records a `from` -> `to` transition, then lets auto-decay run if it is
    /// enabled and its interval has been reached.
    pub fn record_transition(&mut self, from: S, to: S) {
        self.counter.record(from, to);
        self.decay_state.maybe_decay(&mut self.counter);
    }

    /// Predicts the screens that may follow `current_screen`.
    ///
    /// Every target the counter reports for `current_screen` is returned. Its
    /// probability is the counter's value blended with a uniform distribution
    /// over those targets: `confidence * raw + (1 - confidence) / n_targets`.
    /// The result is ordered by descending probability; an empty vector is
    /// returned when the counter reports no targets.
    #[must_use]
    pub fn predict(&self, current_screen: &S) -> Vec<ScreenPrediction<S>> {
        let ranked = self.counter.all_targets_ranked(current_screen);
        // Nothing to blend, so skip the confidence lookup entirely.
        if ranked.is_empty() {
            return Vec::new();
        }
        let confidence = self.confidence(current_screen);
        let uniform_prob = 1.0 / ranked.len() as f64;
        let mut predictions: Vec<ScreenPrediction<S>> = ranked
            .into_iter()
            .map(|(screen, raw_prob)| ScreenPrediction {
                screen,
                probability: confidence * raw_prob + (1.0 - confidence) * uniform_prob,
                confidence,
            })
            .collect();
        predictions.sort_by(|a, b| {
            b.probability
                .partial_cmp(&a.probability)
                // `partial_cmp` is `None` only when a NaN is involved. Treating
                // that as `Equal` breaks the total order `sort_by` requires, so
                // rank NaN after every real probability instead.
                .unwrap_or_else(|| a.probability.is_nan().cmp(&b.probability.is_nan()))
        });
        predictions
    }

    /// Returns `true` while fewer than `min_observations` transitions have
    /// been counted from `screen`.
    #[must_use]
    pub fn is_cold_start(&self, screen: &S) -> bool {
        // The cast drops the fractional part; against an integer threshold
        // that gives the same answer as comparing the unrounded total.
        (self.counter.total_from(screen) as u64) < self.min_observations
    }

    /// Returns how much the counts for `screen` are trusted: the ratio of
    /// observations from `screen` to `min_observations`, capped at `1.0`.
    #[must_use]
    pub fn confidence(&self, screen: &S) -> f64 {
        let observations = self.counter.total_from(screen);
        (observations / self.min_observations as f64).min(1.0)
    }

    /// Shared access to the underlying transition counter.
    #[must_use]
    pub fn counter(&self) -> &TransitionCounter<S> {
        &self.counter
    }

    /// Mutable access to the underlying transition counter.
    ///
    /// Changes made through this handle bypass the auto-decay bookkeeping.
    pub fn counter_mut(&mut self) -> &mut TransitionCounter<S> {
        &mut self.counter
    }

    /// The observation threshold in effect (never below `1` when the
    /// predictor was built through the provided constructors).
    #[must_use]
    pub fn min_observations(&self) -> u64 {
        self.min_observations
    }
}
impl<S: Eq + Hash + Clone> Default for MarkovPredictor<S> {
fn default() -> Self {
Self::new()
}
}
// Unit tests for `MarkovPredictor`: cold-start blending, confidence scaling,
// counter access, manual decay through the counter, and auto-decay.
#[cfg(test)]
mod tests {
use super::*;
// Two observations against a threshold of 20: the blended probabilities of
// the two targets must stay close to each other.
#[test]
fn cold_start_returns_uniform_distribution() {
let mut mp = MarkovPredictor::with_min_observations(20);
mp.record_transition("a", "b");
mp.record_transition("a", "c");
let preds = mp.predict(&"a");
assert_eq!(preds.len(), 2);
let diff = (preds[0].probability - preds[1].probability).abs();
assert!(
diff < 0.15,
"cold start should be near-uniform, diff={diff}"
);
}
// With 30 observations against a threshold of 10 confidence is 1.0, so the
// predictions carry the counter's probabilities unblended (21/32 and 11/32).
#[test]
fn warm_predictions_match_observed() {
let mut mp = MarkovPredictor::with_min_observations(10);
for _ in 0..20 {
mp.record_transition("a", "b");
}
for _ in 0..10 {
mp.record_transition("a", "c");
}
let preds = mp.predict(&"a");
assert_eq!(preds.len(), 2);
assert!((preds[0].confidence - 1.0).abs() < 1e-10);
assert_eq!(preds[0].screen, "b");
assert!(preds[0].probability > preds[1].probability);
assert!((preds[0].probability - 21.0 / 32.0).abs() < 1e-10);
assert!((preds[1].probability - 11.0 / 32.0).abs() < 1e-10);
}
// Confidence scales linearly with observations: 0 -> 0.1 -> 0.5 -> 1.0 for a
// threshold of 10.
#[test]
fn confidence_increases_with_observations() {
let mut mp = MarkovPredictor::with_min_observations(10);
assert_eq!(mp.confidence(&"x"), 0.0);
mp.record_transition("x", "y");
assert!((mp.confidence(&"x") - 0.1).abs() < 1e-10);
for _ in 0..4 {
mp.record_transition("x", "y");
}
assert!((mp.confidence(&"x") - 0.5).abs() < 1e-10);
for _ in 0..5 {
mp.record_transition("x", "y");
}
assert!((mp.confidence(&"x") - 1.0).abs() < 1e-10); }
// Far more observations than the threshold must not push confidence above 1.0.
#[test]
fn confidence_caps_at_one() {
let mut mp = MarkovPredictor::with_min_observations(5);
for _ in 0..100 {
mp.record_transition("a", "b");
}
assert!((mp.confidence(&"a") - 1.0).abs() < 1e-10);
}
// A screen stays cold until exactly `min_observations` transitions have been
// recorded from it.
#[test]
fn is_cold_start_reflects_threshold() {
let mut mp = MarkovPredictor::with_min_observations(5);
assert!(mp.is_cold_start(&"x"));
for _ in 0..4 {
mp.record_transition("x", "y");
}
assert!(mp.is_cold_start(&"x"));
mp.record_transition("x", "y");
assert!(!mp.is_cold_start(&"x")); }
// A predictor without any recorded transitions predicts nothing.
#[test]
fn empty_predictor_returns_no_predictions() {
let mp: MarkovPredictor<&str> = MarkovPredictor::new();
let preds = mp.predict(&"x");
assert!(preds.is_empty());
}
// Predictions come back ordered by non-increasing probability.
#[test]
fn predictions_sorted_by_probability() {
let mut mp = MarkovPredictor::with_min_observations(5);
for _ in 0..10 {
mp.record_transition("a", "x");
}
for _ in 0..5 {
mp.record_transition("a", "y");
}
for _ in 0..1 {
mp.record_transition("a", "z");
}
let preds = mp.predict(&"a");
assert_eq!(preds.len(), 3);
assert!(preds[0].probability >= preds[1].probability);
assert!(preds[1].probability >= preds[2].probability);
}
// The blended probabilities of all targets of one screen add up to 1.0.
#[test]
fn probabilities_sum_to_approximately_one() {
let mut mp = MarkovPredictor::with_min_observations(10);
for _ in 0..15 {
mp.record_transition("a", "b");
}
for _ in 0..8 {
mp.record_transition("a", "c");
}
for _ in 0..3 {
mp.record_transition("a", "d");
}
let preds = mp.predict(&"a");
let sum: f64 = preds.iter().map(|p| p.probability).sum();
assert!(
(sum - 1.0).abs() < 1e-10,
"probabilities should sum to 1.0, got {sum}"
);
}
// `counter()` exposes the counts written by `record_transition`.
#[test]
fn counter_access() {
let mut mp = MarkovPredictor::<&str>::new();
mp.record_transition("a", "b");
assert_eq!(mp.counter().total(), 1.0);
assert_eq!(mp.counter().count(&"a", &"b"), 1.0);
}
// Counts merged in through `counter_mut()` are visible through `counter()`.
#[test]
fn counter_mut_access() {
let mut mp = MarkovPredictor::<&str>::new();
mp.record_transition("a", "b");
let mut other = TransitionCounter::new();
other.record("a", "c");
mp.counter_mut().merge(&other);
assert_eq!(mp.counter().total(), 2.0);
}
// A predictor built from a pre-filled counter keeps its counts and threshold.
#[test]
fn with_counter_constructor() {
let mut counter = TransitionCounter::new();
for _ in 0..50 {
counter.record("a", "b");
}
let mp = MarkovPredictor::with_counter(counter, 10);
assert!(!mp.is_cold_start(&"a"));
assert_eq!(mp.min_observations(), 10);
}
// `Default` yields an empty predictor with the default threshold.
#[test]
fn default_impl() {
let mp: MarkovPredictor<String> = MarkovPredictor::default();
assert_eq!(mp.min_observations(), DEFAULT_MIN_OBSERVATIONS);
assert_eq!(mp.counter().total(), 0.0);
}
// Every target ever recorded from a screen appears in its predictions.
#[test]
fn predict_returns_all_known_targets() {
let mut mp = MarkovPredictor::with_min_observations(5);
mp.record_transition("a", "b");
mp.record_transition("a", "c");
mp.record_transition("a", "d");
let preds = mp.predict(&"a");
let screens: Vec<_> = preds.iter().map(|p| p.screen).collect();
eprintln!("predicted screens: {screens:?}");
assert_eq!(preds.len(), 3);
assert!(screens.contains(&"b"));
assert!(screens.contains(&"c"));
assert!(screens.contains(&"d"));
}
// A screen that has only been seen as a target has no outgoing transitions,
// so predicting from it yields nothing.
#[test]
fn predict_zero_outgoing_returns_empty() {
let mut mp = MarkovPredictor::with_min_observations(5);
mp.record_transition("a", "x");
let preds = mp.predict(&"x");
eprintln!("predictions from unseen source: len={}", preds.len());
assert!(preds.is_empty());
}
// Recording a transition to a new target adds that target to the predictions.
#[test]
fn record_transition_updates_predictions() {
let mut mp = MarkovPredictor::with_min_observations(5);
mp.record_transition("a", "b");
let preds_before = mp.predict(&"a");
assert_eq!(preds_before.len(), 1);
assert_eq!(preds_before[0].screen, "b");
mp.record_transition("a", "c");
let preds_after = mp.predict(&"a");
eprintln!(
"before: {} predictions, after: {} predictions",
preds_before.len(),
preds_after.len()
);
assert_eq!(preds_after.len(), 2);
let screens: Vec<_> = preds_after.iter().map(|p| p.screen).collect();
assert!(screens.contains(&"b"));
assert!(screens.contains(&"c"));
}
// Piling up transitions to "c" raises P(c) and lowers P(b).
#[test]
fn predictions_change_with_new_transitions() {
let mut mp = MarkovPredictor::with_min_observations(5);
for _ in 0..10 {
mp.record_transition("a", "b");
}
mp.record_transition("a", "c");
let preds1 = mp.predict(&"a");
let prob_b1 = preds1.iter().find(|p| p.screen == "b").unwrap().probability;
let prob_c1 = preds1.iter().find(|p| p.screen == "c").unwrap().probability;
for _ in 0..50 {
mp.record_transition("a", "c");
}
let preds2 = mp.predict(&"a");
let prob_b2 = preds2.iter().find(|p| p.screen == "b").unwrap().probability;
let prob_c2 = preds2.iter().find(|p| p.screen == "c").unwrap().probability;
eprintln!("before: P(b)={prob_b1:.4}, P(c)={prob_c1:.4}");
eprintln!("after: P(b)={prob_b2:.4}, P(c)={prob_c2:.4}");
assert!(
prob_c2 > prob_c1,
"P(c) should increase with more transitions"
);
assert!(prob_b2 < prob_b1, "P(b) should decrease as c dominates");
}
// After a manual decay through `counter_mut()`, a handful of fresh "c"
// transitions carries weight comparable to the old "b" history.
#[test]
fn decay_via_counter_reduces_old_influence() {
let mut mp = MarkovPredictor::with_min_observations(5);
for _ in 0..20 {
mp.record_transition("a", "b");
}
mp.record_transition("a", "c");
let preds_before = mp.predict(&"a");
let prob_b_before = preds_before
.iter()
.find(|p| p.screen == "b")
.unwrap()
.probability;
mp.counter_mut().decay(0.1);
for _ in 0..5 {
mp.record_transition("a", "c");
}
let preds_after = mp.predict(&"a");
let prob_c_after = preds_after
.iter()
.find(|p| p.screen == "c")
.unwrap()
.probability;
eprintln!("before decay: P(b)={prob_b_before:.4}, after fresh c: P(c)={prob_c_after:.4}");
assert!(
prob_c_after > prob_b_before * 0.5,
"fresh transitions after decay should be influential"
);
}
// Phase 1 favours "b"; after a manual decay, phase 2 favours "c" and must
// dominate the prediction.
#[test]
fn decay_shifts_predictions_toward_recent() {
let mut mp = MarkovPredictor::with_min_observations(5);
for _ in 0..20 {
mp.record_transition("a", "b");
}
for _ in 0..5 {
mp.record_transition("a", "c");
}
let p1 = mp.predict(&"a");
let p1_b = p1.iter().find(|p| p.screen == "b").unwrap().probability;
mp.counter_mut().decay(0.1);
for _ in 0..20 {
mp.record_transition("a", "c");
}
for _ in 0..5 {
mp.record_transition("a", "b");
}
let p2 = mp.predict(&"a");
let p2_c = p2.iter().find(|p| p.screen == "c").unwrap().probability;
eprintln!("phase1 P(b)={p1_b:.4}, phase2 P(c)={p2_c:.4}");
assert!(
p2_c > 0.5,
"recent pattern should dominate after decay, got P(c)={p2_c}"
);
}
// Every prediction carries a probability in (0, 1] and a confidence in [0, 1].
#[test]
fn screen_prediction_fields_are_populated() {
let mut mp = MarkovPredictor::with_min_observations(10);
for _ in 0..5 {
mp.record_transition("a", "b");
}
mp.record_transition("a", "c");
let preds = mp.predict(&"a");
for pred in &preds {
eprintln!(
"screen={}, prob={:.4}, conf={:.4}",
pred.screen, pred.probability, pred.confidence
);
assert!(pred.probability > 0.0, "probability should be > 0");
assert!(pred.probability <= 1.0, "probability should be <= 1.0");
assert!(pred.confidence >= 0.0, "confidence should be >= 0");
assert!(pred.confidence <= 1.0, "confidence should be <= 1.0");
}
}
// Confidence stays inside [0, 1] before, at, and beyond the threshold.
#[test]
fn confidence_always_in_unit_range() {
let mut mp = MarkovPredictor::with_min_observations(10);
let c0 = mp.confidence(&"x");
assert!((0.0..=1.0).contains(&c0), "confidence={c0}");
for i in 1..=20 {
mp.record_transition("x", "y");
let c = mp.confidence(&"x");
eprintln!("obs={i}, confidence={c:.4}");
assert!((0.0..=1.0).contains(&c), "confidence out of range: {c}");
}
}
// Even a target seen once against 100 competing observations keeps a
// strictly positive probability.
#[test]
fn probability_always_positive_with_smoothing() {
let mut mp = MarkovPredictor::with_min_observations(5);
for _ in 0..100 {
mp.record_transition("a", "b");
}
mp.record_transition("a", "c");
let preds = mp.predict(&"a");
for pred in &preds {
eprintln!("screen={}, prob={:.6}", pred.screen, pred.probability);
assert!(
pred.probability > 0.0,
"all probabilities should be > 0 due to smoothing"
);
}
}
// At confidence 0.5 each probability is the midpoint between the counter's
// value (5/7, 2/7) and the uniform 1/2.
#[test]
fn blending_transitions_smoothly() {
let mut mp = MarkovPredictor::with_min_observations(10);
for _ in 0..4 {
mp.record_transition("a", "b");
}
mp.record_transition("a", "c");
let preds = mp.predict(&"a");
assert_eq!(preds.len(), 2);
let conf = mp.confidence(&"a");
assert!((conf - 0.5).abs() < 1e-10);
let expected_b = 0.5 * (5.0 / 7.0) + 0.5 * 0.5;
let expected_c = 0.5 * (2.0 / 7.0) + 0.5 * 0.5;
assert_eq!(preds[0].screen, "b");
assert!(
(preds[0].probability - expected_b).abs() < 1e-10,
"expected {expected_b}, got {}",
preds[0].probability
);
assert!(
(preds[1].probability - expected_c).abs() < 1e-10,
"expected {expected_c}, got {}",
preds[1].probability
);
let sum: f64 = preds.iter().map(|p| p.probability).sum();
assert!((sum - 1.0).abs() < 1e-10);
}
// Nine transitions leave the total untouched; the tenth triggers a decay by
// 0.5, taking the total from 10.0 to 5.0.
#[test]
fn auto_decay_triggers_at_interval() {
let mut mp = MarkovPredictor::with_min_observations(5);
mp.enable_auto_decay(DecayConfig {
factor: 0.5,
interval: 10,
});
for _ in 0..9 {
mp.record_transition("a", "b");
}
assert_eq!(mp.counter().total(), 9.0);
mp.record_transition("a", "b");
let total = mp.counter().total();
eprintln!("after 10 transitions with decay(0.5): total={total}");
assert!(
(total - 5.0).abs() < 1e-9,
"expected ~5.0 after decay, got {total}"
);
}
// The interval tally restarts after each decay: 5 -> 2.5, then
// 2.5 + 5 = 7.5 -> 3.75.
#[test]
fn auto_decay_interval_resets_after_each_cycle() {
let mut mp = MarkovPredictor::with_min_observations(5);
mp.enable_auto_decay(DecayConfig {
factor: 0.5,
interval: 5,
});
for _ in 0..5 {
mp.record_transition("a", "b");
}
let after_first = mp.counter().total();
eprintln!("after first decay: {after_first}");
assert!((after_first - 2.5).abs() < 1e-9);
for _ in 0..5 {
mp.record_transition("a", "b");
}
let after_second = mp.counter().total();
eprintln!("after second decay: {after_second}");
assert!(
(after_second - 3.75).abs() < 1e-9,
"expected ~3.75, got {after_second}"
);
}
// Without `enable_auto_decay` the counts are never reduced.
#[test]
fn auto_decay_disabled_by_default() {
let mut mp = MarkovPredictor::with_min_observations(5);
for _ in 0..100 {
mp.record_transition("a", "b");
}
assert_eq!(mp.counter().total(), 100.0);
}
// After an aggressive decay of 20 "b" transitions, 15 fresh "c" transitions
// (not enough to trigger the next decay) outweigh the decayed "b" count.
#[test]
fn auto_decay_recent_transitions_dominate() {
let mut mp = MarkovPredictor::with_min_observations(5);
mp.enable_auto_decay(DecayConfig {
factor: 0.1, interval: 20,
});
for _ in 0..20 {
mp.record_transition("a", "b");
}
let b_after_decay = mp.counter().count(&"a", &"b");
eprintln!("b after first decay: {b_after_decay}");
for _ in 0..15 {
mp.record_transition("a", "c");
}
let b_count = mp.counter().count(&"a", &"b");
let c_count = mp.counter().count(&"a", &"c");
eprintln!("b_count={b_count}, c_count={c_count}");
assert!(
c_count > b_count,
"recent 'c' transitions ({c_count}) should exceed decayed 'b' ({b_count})"
);
}
// After several auto-decays the counter's reported total still equals the sum
// of its individual per-transition counts.
#[test]
fn auto_decay_counter_consistency() {
let mut mp = MarkovPredictor::with_min_observations(5);
mp.enable_auto_decay(DecayConfig {
factor: 0.8,
interval: 10,
});
for _ in 0..30 {
mp.record_transition("a", "b");
mp.record_transition("a", "c");
mp.record_transition("x", "y");
}
let total = mp.counter().total();
let mut sum = 0.0;
for from in mp.counter().state_ids() {
for (to, _) in mp.counter().all_targets_ranked(&from) {
sum += mp.counter().count(&from, &to);
}
}
eprintln!("total={total}, sum={sum}");
assert!(
(total - sum).abs() < 1e-6,
"total({total}) should match sum of counts({sum})"
);
}
// `DecayConfig::default()` yields factor 0.85 and interval 500.
#[test]
fn decay_config_default() {
let config = DecayConfig::default();
assert!((config.factor - 0.85).abs() < 1e-10);
assert_eq!(config.interval, 500);
}
}