Skip to main content

ftui_runtime/tick_strategy/
transition_counter.rs

1//! In-memory frequency matrix for screen transition counting.
2//!
3//! [`TransitionCounter<S>`] tracks how many times each `(from, to)` screen
4//! transition has occurred. It provides probability queries with Laplace
5//! smoothing, temporal decay, and merge support for persistence.
6//!
7//! Counts are stored as `f64` rather than `u64` because temporal decay
8//! multiplies all counts by a factor like 0.85. With integer counts,
9//! entries with count=1 would truncate to 0 on the first decay cycle,
//! causing premature pruning. With f64, count 1.0 survives four decay cycles
//! at factor=0.85 (0.85^4 ≈ 0.52) and is pruned on the fifth (0.85^5 ≈ 0.44).
12
13use std::collections::{HashMap, HashSet};
14use std::hash::Hash;
15
/// Prune threshold: entries below this after decay are removed.
const DEFAULT_PRUNE_THRESHOLD: f64 = 0.5;

/// Default Laplace smoothing alpha.
const DEFAULT_SMOOTHING_ALPHA: f64 = 1.0;

/// In-memory frequency matrix for `(from, to)` transition counting.
///
/// Generic over state type `S` (typically `String` for screen IDs).
///
/// Invariants maintained by every mutating method:
/// - every stored count is finite and strictly positive;
/// - `total_from[from]` equals the sum of `counts[(from, *)]`;
/// - `total_transitions` equals the sum of all stored counts.
#[derive(Debug, Clone)]
pub struct TransitionCounter<S: Eq + Hash + Clone> {
    /// Raw transition counts: `(from, to) → count`.
    counts: HashMap<(S, S), f64>,
    /// Row totals for fast probability computation: `from → total`.
    total_from: HashMap<S, f64>,
    /// Sum of all counts.
    total_transitions: f64,
    /// Smoothing parameter for probability queries (never negative).
    smoothing_alpha: f64,
    /// Entries below this value after decay are pruned (never negative).
    prune_threshold: f64,
}

impl<S: Eq + Hash + Clone> TransitionCounter<S> {
    /// Create a new empty counter with the default smoothing alpha and
    /// prune threshold.
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(DEFAULT_SMOOTHING_ALPHA, DEFAULT_PRUNE_THRESHOLD)
    }

    /// Create a counter with custom smoothing alpha and prune threshold.
    ///
    /// Negative values are raised to `0.0`. A `NaN` argument also becomes
    /// `0.0`, because `f64::max` returns the non-NaN operand.
    #[must_use]
    pub fn with_config(smoothing_alpha: f64, prune_threshold: f64) -> Self {
        Self {
            counts: HashMap::new(),
            total_from: HashMap::new(),
            total_transitions: 0.0,
            smoothing_alpha: smoothing_alpha.max(0.0),
            prune_threshold: prune_threshold.max(0.0),
        }
    }

    /// Record a single transition from `from` to `to` (increments by 1.0).
    pub fn record(&mut self, from: S, to: S) {
        self.record_with_count(from, to, 1.0);
    }

    /// Record a transition with an explicit count (used by persistence layer).
    ///
    /// Counts that are zero, negative, `NaN` or infinite are ignored. A single
    /// non-finite value would otherwise poison the row total and the grand
    /// total permanently, since every later sum and probability would be
    /// `NaN`/`inf`.
    pub fn record_with_count(&mut self, from: S, to: S, count: f64) {
        // `NaN <= 0.0` is false, so the finiteness check must be explicit.
        if !count.is_finite() || count <= 0.0 {
            return;
        }
        *self.counts.entry((from.clone(), to)).or_insert(0.0) += count;
        *self.total_from.entry(from).or_insert(0.0) += count;
        self.total_transitions += count;
    }

    /// Get the raw count for a specific transition (`0.0` if never recorded).
    #[must_use]
    pub fn count(&self, from: &S, to: &S) -> f64 {
        // The map is keyed by an owned `(S, S)` tuple, so the lookup key has
        // to be built from clones of both states.
        self.counts
            .get(&(from.clone(), to.clone()))
            .copied()
            .unwrap_or(0.0)
    }

    /// Get the total transitions originating from `from`.
    #[must_use]
    pub fn total_from(&self, from: &S) -> f64 {
        self.total_from.get(from).copied().unwrap_or(0.0)
    }

    /// Get the total number of recorded transitions.
    #[must_use]
    pub fn total(&self) -> f64 {
        self.total_transitions
    }

    /// Compute the probability of transitioning from `from` to `to`.
    ///
    /// Uses Laplace (additive) smoothing:
    /// `P(to|from) = (count(from,to) + alpha) / (total_from(from) + alpha * N)`
    /// where `N` is the number of distinct targets already observed from
    /// `from` (at least 1).
    ///
    /// `N` does not grow when `to` is an unobserved target, so such a target
    /// receives `alpha / (total + alpha * N)` on top of the observed targets'
    /// probabilities, which by themselves already sum to 1.
    ///
    /// Returns `1 / N` (i.e. `1.0`) when the denominator is not positive,
    /// which happens only when `from` has no recorded transitions and
    /// `alpha` is `0.0`.
    #[must_use]
    pub fn probability(&self, from: &S, to: &S) -> f64 {
        self.smoothed(
            self.count(from, to),
            self.total_from(from),
            self.targets_from(from),
        )
    }

    /// Return all known targets from `from`, ranked by probability (descending).
    ///
    /// The probabilities are the same values [`Self::probability`] returns.
    /// The order of targets with equal probability is unspecified (it follows
    /// hash-map iteration order).
    #[must_use]
    pub fn all_targets_ranked(&self, from: &S) -> Vec<(S, f64)> {
        // Gather the row once; the row total and the number of targets are
        // the same for every entry, so they are computed outside the loop
        // instead of rescanning the whole map per target.
        let row: Vec<(&S, f64)> = self
            .counts
            .iter()
            .filter(|((f, _), _)| f == from)
            .map(|((_, t), c)| (t, *c))
            .collect();
        let row_total = self.total_from(from);
        let n_targets = row.len();

        let mut targets: Vec<(S, f64)> = row
            .into_iter()
            .map(|(t, c)| (t.clone(), self.smoothed(c, row_total, n_targets)))
            .collect();

        targets.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        targets
    }

    /// Merge another counter into this one (additive).
    ///
    /// Only finite, strictly positive counts are taken from `other`, matching
    /// what [`Self::record_with_count`] accepts. Merging an empty counter
    /// leaves `self` untouched.
    pub fn merge(&mut self, other: &TransitionCounter<S>) {
        if other.counts.is_empty() {
            return;
        }
        for ((from, to), count) in &other.counts {
            if count.is_finite() && *count > 0.0 {
                *self.counts.entry((from.clone(), to.clone())).or_insert(0.0) += count;
            }
        }
        // Recompute total_from from scratch for correctness
        self.recompute_totals();
    }

    /// Apply temporal decay: multiply all counts by `factor`.
    ///
    /// `factor` should be in (0.0, 1.0); values outside that range are clamped
    /// to `[0.0, 1.0]`. A `NaN` factor is ignored: `clamp` would propagate
    /// it, turning every count into `NaN` and pruning all data.
    ///
    /// Entries that fall below the prune threshold, or that reach zero, are
    /// removed to prevent unbounded map growth. Zero-count entries are dropped
    /// even when the threshold is `0.0`, so that they no longer count as known
    /// targets in probability queries or appear in [`Self::state_ids`].
    pub fn decay(&mut self, factor: f64) {
        if factor.is_nan() {
            return;
        }
        let factor = factor.clamp(0.0, 1.0);
        let threshold = self.prune_threshold;

        self.counts.retain(|_, count| {
            *count *= factor;
            *count > 0.0 && *count >= threshold
        });

        self.recompute_totals();
    }

    /// Return the set of all known state IDs (both sources and targets).
    #[must_use]
    pub fn state_ids(&self) -> HashSet<S> {
        let mut ids = HashSet::new();
        for (from, to) in self.counts.keys() {
            ids.insert(from.clone());
            ids.insert(to.clone());
        }
        ids
    }

    /// Return the number of distinct targets reachable from `from`.
    ///
    /// This is a linear scan over all stored transitions.
    fn targets_from(&self, from: &S) -> usize {
        self.counts.keys().filter(|(f, _)| f == from).count()
    }

    /// Smoothed probability for one entry, given its raw count, the total of
    /// its source row and the number of distinct targets in that row.
    fn smoothed(&self, raw_count: f64, row_total: f64, n_targets: usize) -> f64 {
        // An empty row is treated as having a single (hypothetical) target.
        // The usize → f64 cast is exact for any realistic number of targets.
        let n = n_targets.max(1) as f64;
        let alpha = self.smoothing_alpha;
        let denominator = row_total + alpha * n;

        if denominator <= 0.0 {
            // No data and no smoothing mass: fall back to uniform.
            1.0 / n
        } else {
            (raw_count + alpha) / denominator
        }
    }

    /// Recompute `total_from` and `total_transitions` from the counts map.
    fn recompute_totals(&mut self) {
        self.total_from.clear();
        self.total_transitions = 0.0;
        for ((from, _), count) in &self.counts {
            *self.total_from.entry(from.clone()).or_insert(0.0) += count;
            self.total_transitions += count;
        }
    }
}

impl<S: Eq + Hash + Clone> Default for TransitionCounter<S> {
    /// Equivalent to [`TransitionCounter::new`].
    fn default() -> Self {
        Self::new()
    }
}
201
202// =============================================================================
203// Tests
204// =============================================================================
205
/// Unit tests for [`TransitionCounter`]: recording, row totals, smoothed
/// probabilities, decay/pruning, merging, ranking and state enumeration.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn record_increments_counts() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        assert_eq!(tc.count(&"a", &"b"), 1.0);

        tc.record("a", "b");
        assert_eq!(tc.count(&"a", &"b"), 2.0);

        tc.record("a", "c");
        assert_eq!(tc.count(&"a", &"c"), 1.0);
        assert_eq!(tc.total(), 3.0);
    }

    #[test]
    fn total_from_tracks_row_sums() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "b");
        tc.record("a", "c");
        tc.record("x", "y");

        assert_eq!(tc.total_from(&"a"), 3.0);
        assert_eq!(tc.total_from(&"x"), 1.0);
        assert_eq!(tc.total_from(&"z"), 0.0); // unknown
    }

    #[test]
    fn probability_with_smoothing() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "b");
        tc.record("a", "c");

        // With alpha=1.0, 2 targets from "a":
        // P(b|a) = (2 + 1) / (3 + 1*2) = 3/5 = 0.6
        // P(c|a) = (1 + 1) / (3 + 1*2) = 2/5 = 0.4
        let p_b = tc.probability(&"a", &"b");
        let p_c = tc.probability(&"a", &"c");
        assert!((p_b - 0.6).abs() < 1e-10, "p_b = {p_b}");
        assert!((p_c - 0.4).abs() < 1e-10, "p_c = {p_c}");
    }

    #[test]
    fn probability_unseen_target() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");

        // "a" → "c" was never recorded; only "b" is a known target of "a",
        // so n_targets = 1 and the unseen target is not added to N:
        // P(c|a) = (0 + 1) / (1 + 1*1) = 0.5
        let p = tc.probability(&"a", &"c");
        assert!((p - 0.5).abs() < 1e-10, "p = {p}");
    }

    #[test]
    fn probability_unknown_source() {
        let tc: TransitionCounter<&str> = TransitionCounter::new();
        // No data at all
        let p = tc.probability(&"x", &"y");
        assert!((p - 1.0).abs() < 1e-10, "p = {p}"); // uniform: 1/1
    }

    #[test]
    fn decay_reduces_counts() {
        let mut tc = TransitionCounter::new();
        for _ in 0..10 {
            tc.record("a", "b");
        }
        assert_eq!(tc.total(), 10.0);

        tc.decay(0.5);
        assert_eq!(tc.total(), 5.0);
        assert_eq!(tc.count(&"a", &"b"), 5.0);
    }

    #[test]
    fn decay_prunes_below_threshold() {
        let mut tc = TransitionCounter::with_config(1.0, 0.5);
        tc.record("a", "b"); // count = 1.0

        tc.decay(0.85); // → 0.85
        assert!(tc.count(&"a", &"b") > 0.0);

        tc.decay(0.85); // → 0.7225
        assert!(tc.count(&"a", &"b") > 0.0);

        tc.decay(0.85); // → 0.614
        assert!(tc.count(&"a", &"b") > 0.0);

        // Keep decaying until below threshold
        tc.decay(0.85); // → 0.522
        assert!(tc.count(&"a", &"b") > 0.0);

        tc.decay(0.85); // → 0.443 — below 0.5, should be pruned
        assert_eq!(tc.count(&"a", &"b"), 0.0);
        assert_eq!(tc.total(), 0.0);
    }

    #[test]
    fn decay_f64_survives_multiple_cycles() {
        // Verify the requirement that count=1.0 survives at least 3 decay
        // cycles at 0.85 (an integer count would truncate to 0 on the first).
        let mut tc = TransitionCounter::with_config(1.0, 0.5);
        tc.record("a", "b");

        tc.decay(0.85); // 0.85
        assert!(tc.count(&"a", &"b") >= 0.5, "should survive cycle 1");

        tc.decay(0.85); // 0.7225
        assert!(tc.count(&"a", &"b") >= 0.5, "should survive cycle 2");

        tc.decay(0.85); // 0.614
        assert!(tc.count(&"a", &"b") >= 0.5, "should survive cycle 3");
    }

    #[test]
    fn merge_combines_counters() {
        let mut tc1 = TransitionCounter::new();
        tc1.record("a", "b");
        tc1.record("a", "b");

        let mut tc2 = TransitionCounter::new();
        tc2.record("a", "b");
        tc2.record("a", "c");

        tc1.merge(&tc2);
        assert_eq!(tc1.count(&"a", &"b"), 3.0);
        assert_eq!(tc1.count(&"a", &"c"), 1.0);
        assert_eq!(tc1.total(), 4.0);
        assert_eq!(tc1.total_from(&"a"), 4.0);
    }

    #[test]
    fn all_targets_ranked_sorted_desc() {
        let mut tc = TransitionCounter::new();
        for _ in 0..10 {
            tc.record("a", "b");
        }
        for _ in 0..3 {
            tc.record("a", "c");
        }
        tc.record("a", "d");

        let ranked = tc.all_targets_ranked(&"a");
        assert_eq!(ranked.len(), 3);
        assert_eq!(ranked[0].0, "b"); // highest probability
        assert_eq!(ranked[1].0, "c");
        assert_eq!(ranked[2].0, "d"); // lowest probability

        // Probabilities should be descending
        assert!(ranked[0].1 >= ranked[1].1);
        assert!(ranked[1].1 >= ranked[2].1);
    }

    #[test]
    fn all_targets_ranked_empty_counter_is_empty() {
        // An empty counter knows no targets, so the ranking is empty.
        let tc: TransitionCounter<&str> = TransitionCounter::new();
        let ranked = tc.all_targets_ranked(&"a");
        assert!(ranked.is_empty());
    }

    #[test]
    fn state_ids_collects_all() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("c", "d");

        let ids = tc.state_ids();
        assert_eq!(ids.len(), 4);
        assert!(ids.contains(&"a"));
        assert!(ids.contains(&"b"));
        assert!(ids.contains(&"c"));
        assert!(ids.contains(&"d"));
    }

    #[test]
    fn default_impl() {
        let tc: TransitionCounter<String> = TransitionCounter::default();
        assert_eq!(tc.total(), 0.0);
    }

    #[test]
    fn total_from_cache_consistent_through_record_merge_decay() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "c");
        assert_eq!(tc.total_from(&"a"), 2.0);

        let mut tc2 = TransitionCounter::new();
        tc2.record("a", "b");
        tc.merge(&tc2);
        assert_eq!(tc.total_from(&"a"), 3.0);

        tc.decay(0.5);
        assert!((tc.total_from(&"a") - 1.5).abs() < 1e-10);
        assert!((tc.total() - 1.5).abs() < 1e-10);
    }

    #[test]
    fn single_transition_high_probability() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");

        // P(b|a) with 1 target, alpha=1: (1+1)/(1+1) = 1.0
        let p = tc.probability(&"a", &"b");
        assert!((p - 1.0).abs() < 1e-10);
    }

    #[test]
    fn record_with_count_adds_exact_amount() {
        let mut tc = TransitionCounter::new();
        tc.record_with_count("a", "b", 7.5);
        assert_eq!(tc.count(&"a", &"b"), 7.5);
        assert_eq!(tc.total_from(&"a"), 7.5);
        assert_eq!(tc.total(), 7.5);

        tc.record_with_count("a", "b", 2.5);
        assert_eq!(tc.count(&"a", &"b"), 10.0);
        assert_eq!(tc.total(), 10.0);
    }

    #[test]
    fn record_with_count_ignores_zero_and_negative() {
        let mut tc = TransitionCounter::new();
        tc.record_with_count("a", "b", 0.0);
        assert_eq!(tc.total(), 0.0);

        tc.record_with_count("a", "b", -5.0);
        assert_eq!(tc.total(), 0.0);
    }

    // ========================================================================
    // Additional tests (I.1 coverage)
    // ========================================================================

    #[test]
    fn count_unrecorded_pair_returns_zero() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        let c = tc.count(&"a", &"c");
        eprintln!("count(a→c) = {c}");
        assert_eq!(c, 0.0);

        let c2 = tc.count(&"z", &"q");
        eprintln!("count(z→q) = {c2}");
        assert_eq!(c2, 0.0);
    }

    #[test]
    fn probability_sums_to_one() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "b");
        tc.record("a", "c");
        tc.record("a", "d");

        // Summed over the observed targets only, the smoothed probabilities
        // form a proper distribution.
        let targets = tc.all_targets_ranked(&"a");
        let sum: f64 = targets.iter().map(|(_, p)| p).sum();
        eprintln!("targets: {targets:?}, sum: {sum}");
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "probabilities must sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn decay_factor_one_is_identity() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "b");
        tc.record("a", "c");
        let total_before = tc.total();
        let count_ab_before = tc.count(&"a", &"b");
        let count_ac_before = tc.count(&"a", &"c");

        tc.decay(1.0);

        eprintln!("before: total={total_before}, ab={count_ab_before}, ac={count_ac_before}");
        eprintln!(
            "after:  total={}, ab={}, ac={}",
            tc.total(),
            tc.count(&"a", &"b"),
            tc.count(&"a", &"c")
        );
        assert_eq!(tc.total(), total_before);
        assert_eq!(tc.count(&"a", &"b"), count_ab_before);
        assert_eq!(tc.count(&"a", &"c"), count_ac_before);
    }

    #[test]
    fn decay_factor_zero_removes_all() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "c");
        tc.record("x", "y");
        let total_before = tc.total();
        eprintln!("before decay(0): total={total_before}");

        tc.decay(0.0);

        eprintln!("after decay(0): total={}", tc.total());
        assert_eq!(tc.total(), 0.0);
        assert_eq!(tc.count(&"a", &"b"), 0.0);
        assert!(tc.state_ids().is_empty());
    }

    #[test]
    fn merge_disjoint_screens_produces_union() {
        let mut tc1 = TransitionCounter::new();
        tc1.record("a", "b");

        let mut tc2 = TransitionCounter::new();
        tc2.record("x", "y");

        tc1.merge(&tc2);

        let ids = tc1.state_ids();
        eprintln!("merged state_ids: {ids:?}");
        assert_eq!(ids.len(), 4);
        assert!(ids.contains(&"a"));
        assert!(ids.contains(&"b"));
        assert!(ids.contains(&"x"));
        assert!(ids.contains(&"y"));
        assert_eq!(tc1.count(&"a", &"b"), 1.0);
        assert_eq!(tc1.count(&"x", &"y"), 1.0);
        assert_eq!(tc1.total(), 2.0);
    }

    #[test]
    fn merge_is_commutative() {
        let mut tc_a = TransitionCounter::new();
        tc_a.record("a", "b");
        tc_a.record("a", "b");
        tc_a.record("a", "c");

        let mut tc_b = TransitionCounter::new();
        tc_b.record("a", "b");
        tc_b.record("a", "c");
        tc_b.record("a", "c");

        // Merge A+B
        let mut ab = tc_a.clone();
        ab.merge(&tc_b);

        // Merge B+A
        let mut ba = tc_b.clone();
        ba.merge(&tc_a);

        eprintln!(
            "A+B: ab={}, ac={}",
            ab.count(&"a", &"b"),
            ab.count(&"a", &"c")
        );
        eprintln!(
            "B+A: ab={}, ac={}",
            ba.count(&"a", &"b"),
            ba.count(&"a", &"c")
        );
        assert_eq!(ab.count(&"a", &"b"), ba.count(&"a", &"b"));
        assert_eq!(ab.count(&"a", &"c"), ba.count(&"a", &"c"));
        assert_eq!(ab.total(), ba.total());
    }

    #[test]
    fn merge_with_empty_counter_is_identity() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "c");
        let total_before = tc.total();
        let count_ab_before = tc.count(&"a", &"b");

        let empty = TransitionCounter::<&str>::new();
        tc.merge(&empty);

        eprintln!(
            "after merge(empty): total={}, ab={}",
            tc.total(),
            tc.count(&"a", &"b")
        );
        assert_eq!(tc.total(), total_before);
        assert_eq!(tc.count(&"a", &"b"), count_ab_before);
    }

    #[test]
    fn self_loop_transition_counted_correctly() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "a");
        tc.record("a", "a");
        tc.record("a", "b");

        let count_aa = tc.count(&"a", &"a");
        let count_ab = tc.count(&"a", &"b");
        eprintln!(
            "self-loop: a→a={count_aa}, a→b={count_ab}, total_from(a)={}",
            tc.total_from(&"a")
        );
        assert_eq!(count_aa, 2.0);
        assert_eq!(count_ab, 1.0);
        assert_eq!(tc.total_from(&"a"), 3.0);

        // Self-loop appears in state_ids
        assert!(tc.state_ids().contains(&"a"));
    }

    #[test]
    fn state_ids_empty_counter() {
        let tc: TransitionCounter<&str> = TransitionCounter::new();
        let ids = tc.state_ids();
        eprintln!("empty counter state_ids: {ids:?}");
        assert!(ids.is_empty());
    }

    #[test]
    fn probability_unseen_target_gets_smoothed_value() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "c");

        // "d" was never a target from "a", but smoothing gives it a non-zero prob
        // n_targets from "a" = 2 (b, c)
        // P(d|a) = (0 + 1) / (2 + 1*2) = 1/4 = 0.25
        let p = tc.probability(&"a", &"d");
        eprintln!("P(a→d) with smoothing = {p}");
        assert!(
            p > 0.0,
            "unseen target should get non-zero probability via smoothing"
        );
        assert!((p - 0.25).abs() < 1e-10, "expected 0.25, got {p}");
    }
}