1use std::collections::{HashMap, HashSet};
14use std::hash::Hash;
15
/// Default minimum count an entry must keep after [`TransitionCounter::decay`];
/// entries whose decayed count falls below this value are removed.
const DEFAULT_PRUNE_THRESHOLD: f64 = 0.5;

/// Default additive-smoothing pseudo-count used by [`TransitionCounter::probability`].
const DEFAULT_SMOOTHING_ALPHA: f64 = 1.0;

/// Weighted counter of observed `from → to` transitions between states of type `S`.
///
/// Stores an accumulated weight per ordered `(from, to)` pair together with cached
/// per-source and overall totals, and derives additively smoothed transition
/// probabilities from them. Every mutating method keeps the cached totals in sync
/// with the stored counts.
#[derive(Debug, Clone)]
pub struct TransitionCounter<S: Eq + Hash + Clone> {
    /// Accumulated weight for each observed `(from, to)` pair. Every stored value is
    /// strictly positive: non-positive/non-finite inputs are rejected and entries
    /// that decay to zero are pruned.
    counts: HashMap<(S, S), f64>,
    /// Cached row sums: for each `from`, the sum of `counts` over all of its targets.
    total_from: HashMap<S, f64>,
    /// Cached sum of every value in `counts`.
    total_transitions: f64,
    /// Pseudo-count added to every target when computing probabilities; never negative.
    smoothing_alpha: f64,
    /// Entries whose count drops below this value are removed by `decay`; never negative.
    prune_threshold: f64,
}

impl<S: Eq + Hash + Clone> TransitionCounter<S> {
    /// Creates an empty counter using [`DEFAULT_SMOOTHING_ALPHA`] and
    /// [`DEFAULT_PRUNE_THRESHOLD`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(DEFAULT_SMOOTHING_ALPHA, DEFAULT_PRUNE_THRESHOLD)
    }

    /// Creates an empty counter with explicit smoothing and pruning settings.
    ///
    /// Negative (or NaN) values for either argument are replaced by `0.0`.
    #[must_use]
    pub fn with_config(smoothing_alpha: f64, prune_threshold: f64) -> Self {
        Self {
            counts: HashMap::new(),
            total_from: HashMap::new(),
            total_transitions: 0.0,
            // `f64::max` returns the non-NaN operand, so NaN also maps to 0.0.
            smoothing_alpha: smoothing_alpha.max(0.0),
            prune_threshold: prune_threshold.max(0.0),
        }
    }

    /// Records a single `from → to` transition with weight `1.0`.
    pub fn record(&mut self, from: S, to: S) {
        self.record_with_count(from, to, 1.0);
    }

    /// Adds `count` to the `from → to` transition and updates the cached totals.
    ///
    /// Counts that are zero, negative, NaN or infinite are ignored, so a single bad
    /// sample cannot poison the cached totals.
    pub fn record_with_count(&mut self, from: S, to: S, count: f64) {
        // `count <= 0.0` on its own lets NaN through (every comparison with NaN is
        // false), so non-finite values are rejected explicitly.
        if !count.is_finite() || count <= 0.0 {
            return;
        }
        *self.counts.entry((from.clone(), to)).or_insert(0.0) += count;
        *self.total_from.entry(from).or_insert(0.0) += count;
        self.total_transitions += count;
    }

    /// Returns the accumulated count for `from → to`, or `0.0` if it was never recorded.
    #[must_use]
    pub fn count(&self, from: &S, to: &S) -> f64 {
        // The map is keyed by an owned tuple, so the lookup key has to be cloned.
        self.counts
            .get(&(from.clone(), to.clone()))
            .copied()
            .unwrap_or(0.0)
    }

    /// Returns the total weight of all transitions leaving `from` (`0.0` if none).
    #[must_use]
    pub fn total_from(&self, from: &S) -> f64 {
        self.total_from.get(from).copied().unwrap_or(0.0)
    }

    /// Returns the total weight of all recorded transitions.
    #[must_use]
    pub fn total(&self) -> f64 {
        self.total_transitions
    }

    /// Returns the smoothed probability of moving from `from` to `to`.
    ///
    /// Uses additive smoothing over the targets observed from `from`:
    /// `(count + α) / (total_from + α · n)`, where `n` is the number of distinct
    /// observed targets (treated as `1` when there are none). Over the observed
    /// targets these values sum to `1`; an unobserved `to` still receives
    /// `α / (total_from + α · n)`, which is not part of that normalised set.
    ///
    /// If the denominator is not positive (no observations and `α == 0`), the
    /// uniform value `1 / n` is returned instead.
    #[must_use]
    pub fn probability(&self, from: &S, to: &S) -> f64 {
        self.smoothed_probability(
            self.count(from, to),
            self.total_from(from),
            self.targets_from(from),
        )
    }

    /// Returns every observed target of `from` with its smoothed probability,
    /// sorted from most to least probable. Ties are returned in unspecified order.
    #[must_use]
    pub fn all_targets_ranked(&self, from: &S) -> Vec<(S, f64)> {
        // Gather raw counts in a single pass. The number of gathered entries is
        // exactly the number of distinct observed targets, so the map does not have
        // to be rescanned once per target (which made this O(k · N)).
        let mut ranked: Vec<(S, f64)> = self
            .counts
            .iter()
            .filter(|((f, _), _)| f == from)
            .map(|((_, to), &raw)| (to.clone(), raw))
            .collect();

        let total = self.total_from(from);
        let n_targets = ranked.len();
        for entry in &mut ranked {
            entry.1 = self.smoothed_probability(entry.1, total, n_targets);
        }

        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked
    }

    /// Adds every count from `other` into `self`, then rebuilds the cached totals.
    pub fn merge(&mut self, other: &TransitionCounter<S>) {
        for (key, &count) in &other.counts {
            *self.counts.entry(key.clone()).or_insert(0.0) += count;
        }
        self.recompute_totals();
    }

    /// Multiplies every count by `factor` (clamped to `[0, 1]`), then removes entries
    /// that reach zero or fall below the prune threshold and rebuilds the totals.
    ///
    /// Pruning runs even for `factor == 1.0`, so entries already below the threshold
    /// are dropped. A NaN `factor` is ignored and leaves the counter untouched,
    /// rather than wiping every entry.
    pub fn decay(&mut self, factor: f64) {
        // `f64::clamp` propagates NaN, which would turn every count into NaN and
        // then prune all of them.
        if factor.is_nan() {
            return;
        }
        let factor = factor.clamp(0.0, 1.0);
        let threshold = self.prune_threshold;

        self.counts.retain(|_, count| {
            *count *= factor;
            // With a zero threshold, `>= threshold` alone would keep zero-count
            // entries, which carry no information but still show up as targets.
            *count > 0.0 && *count >= threshold
        });

        self.recompute_totals();
    }

    /// Returns every state that appears as a source or target of some transition.
    #[must_use]
    pub fn state_ids(&self) -> HashSet<S> {
        let mut ids = HashSet::new();
        for (from, to) in self.counts.keys() {
            ids.insert(from.clone());
            ids.insert(to.clone());
        }
        ids
    }

    /// Number of distinct targets observed from `from` (linear in the map size).
    fn targets_from(&self, from: &S) -> usize {
        self.counts.keys().filter(|(f, _)| f == from).count()
    }

    /// Additive-smoothing formula shared by `probability` and `all_targets_ranked`.
    fn smoothed_probability(&self, raw_count: f64, total: f64, n_targets: usize) -> f64 {
        // An unseen source still gets a one-slot vocabulary so the formula is defined.
        let n = n_targets.max(1) as f64;
        let alpha = self.smoothing_alpha;
        let denominator = total + alpha * n;

        if denominator <= 0.0 {
            1.0 / n
        } else {
            (raw_count + alpha) / denominator
        }
    }

    /// Rebuilds `total_from` and `total_transitions` from scratch out of `counts`.
    fn recompute_totals(&mut self) {
        self.total_from.clear();
        self.total_transitions = 0.0;
        for ((from, _), count) in &self.counts {
            *self.total_from.entry(from.clone()).or_insert(0.0) += count;
            self.total_transitions += count;
        }
    }
}
195
196impl<S: Eq + Hash + Clone> Default for TransitionCounter<S> {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing smoothed probabilities.
    const EPS: f64 = 1e-10;

    #[test]
    fn record_increments_counts() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        assert_eq!(tc.count(&"a", &"b"), 1.0);

        tc.record("a", "b");
        assert_eq!(tc.count(&"a", &"b"), 2.0);

        tc.record("a", "c");
        assert_eq!(tc.count(&"a", &"c"), 1.0);
        assert_eq!(tc.total(), 3.0);
    }

    #[test]
    fn total_from_tracks_row_sums() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "b");
        tc.record("a", "c");
        tc.record("x", "y");

        assert_eq!(tc.total_from(&"a"), 3.0);
        assert_eq!(tc.total_from(&"x"), 1.0);
        assert_eq!(tc.total_from(&"z"), 0.0);
    }

    #[test]
    fn probability_with_smoothing() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "b");
        tc.record("a", "c");

        // total_from(a) = 3, two observed targets, alpha = 1:
        // P(a→b) = (2 + 1) / (3 + 2) = 0.6 and P(a→c) = (1 + 1) / (3 + 2) = 0.4.
        let p_b = tc.probability(&"a", &"b");
        let p_c = tc.probability(&"a", &"c");
        assert!((p_b - 0.6).abs() < EPS, "p_b = {p_b}");
        assert!((p_c - 0.4).abs() < EPS, "p_c = {p_c}");
    }

    #[test]
    fn probability_unseen_target() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");

        // One observed target: P(a→c) = (0 + 1) / (1 + 1) = 0.5.
        let p = tc.probability(&"a", &"c");
        assert!((p - 0.5).abs() < EPS, "p = {p}");
    }

    #[test]
    fn probability_unknown_source() {
        let tc: TransitionCounter<&str> = TransitionCounter::new();

        // No observations: n is treated as 1, so (0 + 1) / (0 + 1) = 1.0.
        let p = tc.probability(&"x", &"y");
        assert!((p - 1.0).abs() < EPS, "p = {p}");
    }

    #[test]
    fn decay_reduces_counts() {
        let mut tc = TransitionCounter::new();
        for _ in 0..10 {
            tc.record("a", "b");
        }
        assert_eq!(tc.total(), 10.0);

        tc.decay(0.5);
        assert_eq!(tc.total(), 5.0);
        assert_eq!(tc.count(&"a", &"b"), 5.0);
    }

    #[test]
    fn decay_prunes_below_threshold() {
        let mut tc = TransitionCounter::with_config(1.0, 0.5);
        tc.record("a", "b");

        // 0.85^k stays >= 0.5 for k = 1..=4 (0.85, 0.7225, 0.614..., 0.522...).
        for cycle in 1..=4 {
            tc.decay(0.85);
            assert!(tc.count(&"a", &"b") > 0.0, "should survive cycle {cycle}");
        }

        // 0.85^5 ≈ 0.4437 < 0.5, so the entry is pruned.
        tc.decay(0.85);
        assert_eq!(tc.count(&"a", &"b"), 0.0);
        assert_eq!(tc.total(), 0.0);
    }

    #[test]
    fn decay_f64_survives_multiple_cycles() {
        let mut tc = TransitionCounter::with_config(1.0, 0.5);
        tc.record("a", "b");

        for cycle in 1..=3 {
            tc.decay(0.85);
            assert!(
                tc.count(&"a", &"b") >= 0.5,
                "should survive cycle {cycle}"
            );
        }
    }

    #[test]
    fn merge_combines_counters() {
        let mut tc1 = TransitionCounter::new();
        tc1.record("a", "b");
        tc1.record("a", "b");

        let mut tc2 = TransitionCounter::new();
        tc2.record("a", "b");
        tc2.record("a", "c");

        tc1.merge(&tc2);
        assert_eq!(tc1.count(&"a", &"b"), 3.0);
        assert_eq!(tc1.count(&"a", &"c"), 1.0);
        assert_eq!(tc1.total(), 4.0);
        assert_eq!(tc1.total_from(&"a"), 4.0);
    }

    #[test]
    fn all_targets_ranked_sorted_desc() {
        let mut tc = TransitionCounter::new();
        for _ in 0..10 {
            tc.record("a", "b");
        }
        for _ in 0..3 {
            tc.record("a", "c");
        }
        tc.record("a", "d");

        let ranked = tc.all_targets_ranked(&"a");
        assert_eq!(ranked.len(), 3);
        assert_eq!(ranked[0].0, "b");
        assert_eq!(ranked[1].0, "c");
        assert_eq!(ranked[2].0, "d");
        assert!(ranked[0].1 >= ranked[1].1);
        assert!(ranked[1].1 >= ranked[2].1);
    }

    #[test]
    fn empty_counter_ranks_nothing() {
        let tc: TransitionCounter<&str> = TransitionCounter::new();
        let ranked = tc.all_targets_ranked(&"a");
        assert!(ranked.is_empty());
    }

    #[test]
    fn state_ids_collects_all() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("c", "d");

        let ids = tc.state_ids();
        assert_eq!(ids.len(), 4);
        for id in ["a", "b", "c", "d"] {
            assert!(ids.contains(&id), "missing state {id}");
        }
    }

    #[test]
    fn default_impl() {
        let tc: TransitionCounter<String> = TransitionCounter::default();
        assert_eq!(tc.total(), 0.0);
    }

    #[test]
    fn total_from_cache_consistent_through_record_merge_decay() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "c");
        assert_eq!(tc.total_from(&"a"), 2.0);

        let mut tc2 = TransitionCounter::new();
        tc2.record("a", "b");
        tc.merge(&tc2);
        assert_eq!(tc.total_from(&"a"), 3.0);

        // a→b: 2 → 1.0, a→c: 1 → 0.5 (exactly at the default threshold, so kept).
        tc.decay(0.5);
        assert!((tc.total_from(&"a") - 1.5).abs() < EPS);
        assert!((tc.total() - 1.5).abs() < EPS);
    }

    #[test]
    fn single_transition_high_probability() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");

        // (1 + 1) / (1 + 1) = 1.0.
        let p = tc.probability(&"a", &"b");
        assert!((p - 1.0).abs() < EPS, "p = {p}");
    }

    #[test]
    fn record_with_count_adds_exact_amount() {
        let mut tc = TransitionCounter::new();
        tc.record_with_count("a", "b", 7.5);
        assert_eq!(tc.count(&"a", &"b"), 7.5);
        assert_eq!(tc.total_from(&"a"), 7.5);
        assert_eq!(tc.total(), 7.5);

        tc.record_with_count("a", "b", 2.5);
        assert_eq!(tc.count(&"a", &"b"), 10.0);
        assert_eq!(tc.total(), 10.0);
    }

    #[test]
    fn record_with_count_ignores_zero_and_negative() {
        let mut tc = TransitionCounter::new();
        tc.record_with_count("a", "b", 0.0);
        assert_eq!(tc.total(), 0.0);

        tc.record_with_count("a", "b", -5.0);
        assert_eq!(tc.total(), 0.0);
    }

    #[test]
    fn count_unrecorded_pair_returns_zero() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");

        let c = tc.count(&"a", &"c");
        assert_eq!(c, 0.0, "count(a→c) = {c}");

        let c2 = tc.count(&"z", &"q");
        assert_eq!(c2, 0.0, "count(z→q) = {c2}");
    }

    #[test]
    fn probability_sums_to_one() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "b");
        tc.record("a", "c");
        tc.record("a", "d");

        // Over the observed targets: 3/7 + 2/7 + 2/7 = 1.
        let targets = tc.all_targets_ranked(&"a");
        let sum: f64 = targets.iter().map(|(_, p)| p).sum();
        assert!(
            (sum - 1.0).abs() < EPS,
            "probabilities must sum to 1.0, got {sum} from {targets:?}"
        );
    }

    #[test]
    fn decay_factor_one_is_identity() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "b");
        tc.record("a", "c");
        let total_before = tc.total();
        let count_ab_before = tc.count(&"a", &"b");
        let count_ac_before = tc.count(&"a", &"c");

        tc.decay(1.0);

        assert_eq!(tc.total(), total_before);
        assert_eq!(tc.count(&"a", &"b"), count_ab_before);
        assert_eq!(tc.count(&"a", &"c"), count_ac_before);
    }

    #[test]
    fn decay_factor_zero_removes_all() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "c");
        tc.record("x", "y");
        assert_eq!(tc.total(), 3.0);

        tc.decay(0.0);

        assert_eq!(tc.total(), 0.0, "total after decay(0)");
        assert_eq!(tc.count(&"a", &"b"), 0.0);
        assert!(tc.state_ids().is_empty());
    }

    #[test]
    fn merge_disjoint_screens_produces_union() {
        let mut tc1 = TransitionCounter::new();
        tc1.record("a", "b");

        let mut tc2 = TransitionCounter::new();
        tc2.record("x", "y");

        tc1.merge(&tc2);

        let ids = tc1.state_ids();
        assert_eq!(ids.len(), 4, "merged state_ids: {ids:?}");
        for id in ["a", "b", "x", "y"] {
            assert!(ids.contains(&id), "missing state {id} in {ids:?}");
        }
        assert_eq!(tc1.count(&"a", &"b"), 1.0);
        assert_eq!(tc1.count(&"x", &"y"), 1.0);
        assert_eq!(tc1.total(), 2.0);
    }

    #[test]
    fn merge_is_commutative() {
        let mut tc_a = TransitionCounter::new();
        tc_a.record("a", "b");
        tc_a.record("a", "b");
        tc_a.record("a", "c");

        let mut tc_b = TransitionCounter::new();
        tc_b.record("a", "b");
        tc_b.record("a", "c");
        tc_b.record("a", "c");

        let mut ab = tc_a.clone();
        ab.merge(&tc_b);

        let mut ba = tc_b.clone();
        ba.merge(&tc_a);

        assert_eq!(ab.count(&"a", &"b"), ba.count(&"a", &"b"));
        assert_eq!(ab.count(&"a", &"c"), ba.count(&"a", &"c"));
        assert_eq!(ab.total(), ba.total());
    }

    #[test]
    fn merge_with_empty_counter_is_identity() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "c");
        let total_before = tc.total();
        let count_ab_before = tc.count(&"a", &"b");

        let empty = TransitionCounter::<&str>::new();
        tc.merge(&empty);

        assert_eq!(tc.total(), total_before);
        assert_eq!(tc.count(&"a", &"b"), count_ab_before);
    }

    #[test]
    fn self_loop_transition_counted_correctly() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "a");
        tc.record("a", "a");
        tc.record("a", "b");

        let count_aa = tc.count(&"a", &"a");
        let count_ab = tc.count(&"a", &"b");
        assert_eq!(count_aa, 2.0, "self-loop a→a");
        assert_eq!(count_ab, 1.0, "a→b");
        assert_eq!(tc.total_from(&"a"), 3.0);

        assert!(tc.state_ids().contains(&"a"));
    }

    #[test]
    fn state_ids_empty_counter() {
        let tc: TransitionCounter<&str> = TransitionCounter::new();
        let ids = tc.state_ids();
        assert!(ids.is_empty(), "empty counter state_ids: {ids:?}");
    }

    #[test]
    fn probability_unseen_target_gets_smoothed_value() {
        let mut tc = TransitionCounter::new();
        tc.record("a", "b");
        tc.record("a", "c");

        // Two observed targets, total 2: P(a→d) = (0 + 1) / (2 + 2) = 0.25.
        let p = tc.probability(&"a", &"d");
        assert!(
            p > 0.0,
            "unseen target should get non-zero probability via smoothing"
        );
        assert!((p - 0.25).abs() < EPS, "expected 0.25, got {p}");
    }
}