datacortex_core/state/
state_map.rs1use super::state_table::StateTable;
11
12const MAX_COUNT: u16 = 192;
15
16const MIN_DENOM: i32 = 4;
18
19#[derive(Clone, Copy)]
21struct Entry {
22 prob: u16,
24 count: u16,
26}
27
28pub struct StateMap {
30 entries: [Entry; 256],
31}
32
33impl StateMap {
34 pub fn new() -> Self {
36 let mut entries = [Entry {
37 prob: 2048,
38 count: 0,
39 }; 256];
40
41 for (i, entry) in entries.iter_mut().enumerate() {
42 let base_prob = StateTable::prob(i as u8);
43 entry.prob = base_prob;
44 let dist_from_center = base_prob.abs_diff(2048);
48 entry.count = ((dist_from_center as u32) / 600).min(4) as u16;
50 }
51
52 StateMap { entries }
53 }
54
55 #[inline(always)]
58 pub fn predict(&self, state: u8) -> u32 {
59 self.entries[state as usize].prob as u32
60 }
61
62 #[inline(always)]
66 pub fn update(&mut self, state: u8, bit: u8) {
67 let e = &mut self.entries[state as usize];
68 let target = if bit != 0 { 4095i32 } else { 0i32 };
69 let p = e.prob as i32;
70 let count = e.count as i32 + MIN_DENOM;
71 let delta = (target - p) / count;
72 let new_p = (p + delta).clamp(1, 4095);
73 e.prob = new_p as u16;
74 if e.count < MAX_COUNT {
75 e.count += 1;
76 }
77 }
78}
79
80impl Default for StateMap {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn initial_state_0_is_balanced() {
        let sm = StateMap::new();
        assert_eq!(sm.predict(0), 2048);
    }

    #[test]
    fn predictions_in_range() {
        let sm = StateMap::new();
        for s in 0..=255u8 {
            let p = sm.predict(s);
            assert!((1..=4095).contains(&p), "state {s}: pred {p} out of range");
        }
    }

    // State 0 starts at exactly 2048 (pinned by `initial_state_0_is_balanced`),
    // so `new` seeds it with count 0 and a single update moves it by
    // (target - 2048) / 4, which is never zero. The comparison must therefore be
    // strict; `>=` would also accept an `update` that does nothing.
    #[test]
    fn update_toward_one() {
        let mut sm = StateMap::new();
        let before = sm.predict(0);
        sm.update(0, 1);
        let after = sm.predict(0);
        assert!(
            after > before,
            "after seeing 1, prob should increase: {before} -> {after}"
        );
    }

    // Strict for the same reason as `update_toward_one`.
    #[test]
    fn update_toward_zero() {
        let mut sm = StateMap::new();
        let before = sm.predict(0);
        sm.update(0, 0);
        let after = sm.predict(0);
        assert!(
            after < before,
            "after seeing 0, prob should decrease: {before} -> {after}"
        );
    }

    // `update` treats any non-zero `bit` as a 1.
    #[test]
    fn nonzero_bit_is_treated_as_one() {
        let mut with_one = StateMap::new();
        let mut with_other = StateMap::new();
        with_one.update(0, 1);
        with_other.update(0, 0xFF);
        assert_eq!(with_one.predict(0), with_other.predict(0));
    }

    #[test]
    fn many_ones_converge_high() {
        let mut sm = StateMap::new();
        for _ in 0..200 {
            sm.update(0, 1);
        }
        let p = sm.predict(0);
        assert!(p > 3500, "many 1s should push probability high: {p}");
    }

    #[test]
    fn many_zeros_converge_low() {
        let mut sm = StateMap::new();
        for _ in 0..200 {
            sm.update(0, 0);
        }
        let p = sm.predict(0);
        assert!(p < 500, "many 0s should push probability low: {p}");
    }

    #[test]
    fn probability_stays_in_bounds() {
        let mut sm = StateMap::new();
        for _ in 0..1000 {
            sm.update(128, 1);
        }
        let p = sm.predict(128);
        assert!((1..=4095).contains(&p), "pred {p} out of range");

        for _ in 0..1000 {
            sm.update(128, 0);
        }
        let p = sm.predict(128);
        assert!((1..=4095).contains(&p), "pred {p} out of range");
    }

    #[test]
    fn different_states_independent() {
        let mut sm = StateMap::new();
        let before_10 = sm.predict(10);
        sm.update(20, 1);
        sm.update(20, 1);
        sm.update(20, 1);
        let after_10 = sm.predict(10);
        assert_eq!(
            before_10, after_10,
            "updating state 20 should not affect state 10"
        );
    }
}