// datacortex_core/state/state_map.rs
//! StateMap -- adaptive state-to-probability mapper.
//!
//! Maps each of 256 states to a 12-bit probability (1-4095) that adapts
//! based on observed bits. Uses variable learning rate:
//! - Fast learning for low-count states (early observations matter more)
//! - Slow learning for high-count states (well-established statistics)
//!
//! Initialization uses the StateTable's static probability as the prior.

use super::state_table::StateTable;

/// Maximum learning count. Higher = better convergence on stable statistics.
/// 192 balances adaptation speed against convergence quality.
const MAX_COUNT: u16 = 192;

/// Minimum learning denominator. Prevents divide-by-zero and initial overshoot.
const MIN_DENOM: i32 = 4;

/// A single adaptive entry: the current probability estimate plus how many
/// observations back it up.
#[derive(Clone, Copy)]
struct Entry {
    /// 12-bit probability that the next bit is 1, kept in [1, 4095].
    prob: u16,
    /// Number of updates applied so far (saturates at MAX_COUNT).
    count: u16,
}

28/// Maps 256 states to adaptive 12-bit probabilities.
29pub struct StateMap {
30    entries: [Entry; 256],
31}
32
33impl StateMap {
34    /// Create a new StateMap initialized from the StateTable's static probabilities.
35    pub fn new() -> Self {
36        let mut entries = [Entry {
37            prob: 2048,
38            count: 0,
39        }; 256];
40
41        for (i, entry) in entries.iter_mut().enumerate() {
42            let base_prob = StateTable::prob(i as u8);
43            entry.prob = base_prob;
44            // Give initial confidence proportional to how extreme the state is.
45            // States near 50/50 get low initial count (easy to update).
46            // Extreme states get higher initial count (harder to shift).
47            let dist_from_center = base_prob.abs_diff(2048);
48            // Scale: 0 at center, up to ~4 at extremes
49            entry.count = ((dist_from_center as u32) / 600).min(4) as u16;
50        }
51
52        StateMap { entries }
53    }
54
55    /// Get the predicted probability of bit=1 for the given state.
56    /// Returns a 12-bit value in [1, 4095].
57    #[inline(always)]
58    pub fn predict(&self, state: u8) -> u32 {
59        self.entries[state as usize].prob as u32
60    }
61
62    /// Update the probability for the given state after observing `bit`.
63    /// Uses adaptive learning: p += (target - p) / (count + MIN_DENOM).
64    /// Learning rate decreases as count increases (1/n convergence).
65    #[inline(always)]
66    pub fn update(&mut self, state: u8, bit: u8) {
67        let e = &mut self.entries[state as usize];
68        let target = if bit != 0 { 4095i32 } else { 0i32 };
69        let p = e.prob as i32;
70        let count = e.count as i32 + MIN_DENOM;
71        let delta = (target - p) / count;
72        let new_p = (p + delta).clamp(1, 4095);
73        e.prob = new_p as u16;
74        if e.count < MAX_COUNT {
75            e.count += 1;
76        }
77    }
78}
79
80impl Default for StateMap {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn initial_state_0_is_balanced() {
        assert_eq!(StateMap::new().predict(0), 2048);
    }

    #[test]
    fn predictions_in_range() {
        let sm = StateMap::new();
        for s in u8::MIN..=u8::MAX {
            let p = sm.predict(s);
            assert!((1..=4095).contains(&p), "state {s}: pred {p} out of range");
        }
    }

    #[test]
    fn update_toward_one() {
        let mut sm = StateMap::new();
        let before = sm.predict(0);
        sm.update(0, 1);
        let after = sm.predict(0);
        assert!(
            after >= before,
            "after seeing 1, prob should increase: {before} -> {after}"
        );
    }

    #[test]
    fn update_toward_zero() {
        let mut sm = StateMap::new();
        let before = sm.predict(0);
        sm.update(0, 0);
        let after = sm.predict(0);
        assert!(
            after <= before,
            "after seeing 0, prob should decrease: {before} -> {after}"
        );
    }

    #[test]
    fn many_ones_converge_high() {
        let mut sm = StateMap::new();
        (0..200).for_each(|_| sm.update(0, 1));
        let p = sm.predict(0);
        assert!(p > 3500, "many 1s should push probability high: {p}");
    }

    #[test]
    fn many_zeros_converge_low() {
        let mut sm = StateMap::new();
        (0..200).for_each(|_| sm.update(0, 0));
        let p = sm.predict(0);
        assert!(p < 500, "many 0s should push probability low: {p}");
    }

    #[test]
    fn probability_stays_in_bounds() {
        let mut sm = StateMap::new();

        (0..1000).for_each(|_| sm.update(128, 1));
        assert!((1..=4095).contains(&sm.predict(128)));

        (0..1000).for_each(|_| sm.update(128, 0));
        assert!((1..=4095).contains(&sm.predict(128)));
    }

    #[test]
    fn different_states_independent() {
        let mut sm = StateMap::new();
        let before_10 = sm.predict(10);
        for _ in 0..3 {
            sm.update(20, 1);
        }
        let after_10 = sm.predict(10);
        assert_eq!(
            before_10, after_10,
            "updating state 20 should not affect state 10"
        );
    }
}