// datacortex_core/state/state_table.rs

/// Number of states in the bit-history state machine.
///
/// States are addressed by a `u8`, so this is exactly the number of distinct
/// `u8` values.
pub const NUM_STATES: usize = 1 << u8::BITS;
18const fn build_table() -> ([u8; 512], [u16; 256]) {
37 let mut next = [0u8; 512];
40 let mut prob = [2048u16; 256];
41
42 let mut n0 = [0u16; 256];
45 let mut n1 = [0u16; 256];
46
47 n0[0] = 0;
49 n1[0] = 0;
50
51 let mut state_idx: usize = 1;
54
55 let mut total: u16 = 1;
57 while total <= 8 {
58 let mut b: u16 = 0;
59 while b <= total {
60 let a = total - b;
61 if state_idx < 256 {
62 n0[state_idx] = a;
63 n1[state_idx] = b;
64 state_idx += 1;
65 }
66 b += 1;
67 }
68 total += 1;
69 }
70
71 total = 9;
73 while total <= 20 {
74 let mut b: u16 = 0;
75 while b <= total {
76 let a = total - b;
77 let ratio_ok = {
79 let min_val = if a < b { a } else { b };
80 let max_val = if a > b { a } else { b };
81 min_val == 0 || max_val <= min_val * 4
82 };
83 if ratio_ok && state_idx < 256 {
84 n0[state_idx] = a;
85 n1[state_idx] = b;
86 state_idx += 1;
87 }
88 b += 1;
89 }
90 total += 1;
91 }
92
93 let mut run: u16 = 21;
96 while run <= 60 && state_idx < 256 {
97 n0[state_idx] = run;
98 n1[state_idx] = 0;
99 state_idx += 1;
100 n0[state_idx] = 0;
101 n1[state_idx] = run;
102 state_idx += 1;
103 run += 3;
104 }
105
106 while state_idx < 256 {
108 n0[state_idx] = 1;
109 n1[state_idx] = 1;
110 state_idx += 1;
111 }
112
113 let mut s: usize = 0;
116 while s < 256 {
117 let t = n0[s] as u32 + n1[s] as u32 + 2;
118 let p = ((n1[s] as u32 + 1) * 4095 + t / 2) / t;
119 let p = if p < 1 {
120 1
121 } else if p > 4095 {
122 4095
123 } else {
124 p
125 };
126 prob[s] = p as u16;
127 s += 1;
128 }
129
130 s = 0;
135 while s < 256 {
136 let s_n0 = n0[s];
137 let s_n1 = n1[s];
138
139 let (target_n0_0, target_n1_0) = scale_counts(s_n0 + 1, s_n1);
141 next[s * 2] = find_closest_state(&n0, &n1, target_n0_0, target_n1_0, s);
142
143 let (target_n0_1, target_n1_1) = scale_counts(s_n0, s_n1 + 1);
145 next[s * 2 + 1] = find_closest_state(&n0, &n1, target_n0_1, target_n1_1, s);
146
147 s += 1;
148 }
149
150 (next, prob)
151}
152
/// Shrinks a `(zeros, ones)` count pair so histories stay in a bounded range.
///
/// * total `<= 20`: returned unchanged;
/// * total `21..=40`: the smaller count (`b` when `a >= b`) is scaled by 3/4,
///   and a nonzero count is never reduced to zero;
/// * total `> 40`: the smaller count is halved and the larger one is capped
///   at 60.
///
/// When `a == b`, `b` is treated as the smaller count.
const fn scale_counts(a: u16, b: u16) -> (u16, u16) {
    /// `x * 3 / 4`, bumped up to 1 if that would erase a nonzero count.
    const fn three_quarters(x: u16) -> u16 {
        let scaled = x * 3 / 4;
        if scaled == 0 && x > 0 {
            1
        } else {
            scaled
        }
    }

    /// Caps a dominant count at 60.
    const fn cap_at_60(x: u16) -> u16 {
        if x > 60 {
            60
        } else {
            x
        }
    }

    let a_dominates = a >= b;
    match a + b {
        0..=20 => (a, b),
        21..=40 => {
            if a_dominates {
                (a, three_quarters(b))
            } else {
                (three_quarters(a), b)
            }
        }
        _ => {
            if a_dominates {
                (cap_at_60(a), b / 2)
            } else {
                (a / 2, cap_at_60(b))
            }
        }
    }
}

/// Returns the index of the state whose counts are nearest to
/// `(target_n0, target_n1)`.
///
/// The distance is
/// `3 * (|n0 - target_n0| + |n1 - target_n1|) + |(n0 + n1) - (target_n0 + target_n1)|`.
/// The state `current` is never returned unless it is state 0, and ties
/// resolve to the lowest index.
///
/// NOTE(review): because `current` is excluded and ties favour low indices, a
/// state whose exact successor is missing can step *against* the observed
/// bit. With the table built by `build_table`, `(21, 0)` on a 0 bit appears
/// to move back to `(20, 0)` (tied with `(24, 0)`), which would leave the run
/// states above 21 unreachable. Changing this alters every transition (and
/// presumably any data produced with the current table); confirm before
/// touching.
const fn find_closest_state(
    n0: &[u16; 256],
    n1: &[u16; 256],
    target_n0: u16,
    target_n1: u16,
    current: usize,
) -> u8 {
    let want0 = target_n0 as u32;
    let want1 = target_n1 as u32;
    let want_total = want0 + want1;

    let mut winner: usize = 0;
    let mut winner_dist = u32::MAX;

    let mut idx: usize = 0;
    while idx < 256 {
        let skip_self = idx == current && current != 0;
        if !skip_self {
            let have0 = n0[idx] as u32;
            let have1 = n1[idx] as u32;
            let dist = 3 * (have0.abs_diff(want0) + have1.abs_diff(want1))
                + (have0 + have1).abs_diff(want_total);
            // Strict comparison keeps the earliest of equally distant states.
            if dist < winner_dist {
                winner = idx;
                winner_dist = dist;
            }
        }
        idx += 1;
    }

    winner as u8
}

237static TABLE: ([u8; 512], [u16; 256]) = build_table();
238
239pub struct StateTable;
241
242impl StateTable {
243 #[inline(always)]
245 pub fn next(s: u8, bit: u8) -> u8 {
246 TABLE.0[s as usize * 2 + (bit as usize & 1)]
247 }
248
249 #[inline(always)]
252 pub fn prob(s: u8) -> u16 {
253 TABLE.1[s as usize]
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn initial_state_is_balanced() {
        assert_eq!(StateTable::prob(0), 2048);
    }

    #[test]
    fn state_0_transitions_are_distinct() {
        let next0 = StateTable::next(0, 0);
        let next1 = StateTable::next(0, 1);
        assert_ne!(next0, next1, "state 0 transitions should differ");
        // Seeing a 0 from the empty history must push P(1) below one half...
        assert!(
            StateTable::prob(next0) < 2048,
            "bit=0 should go to low-probability state, got {}",
            StateTable::prob(next0)
        );
        // ...and seeing a 1 must push it above one half.
        assert!(
            StateTable::prob(next1) > 2048,
            "bit=1 should go to high-probability state, got {}",
            StateTable::prob(next1)
        );
    }

    #[test]
    fn probabilities_are_in_range() {
        for s in 0..NUM_STATES {
            let p = StateTable::prob(s as u8);
            assert!((1..=4095).contains(&p), "state {s}: prob {p} out of range");
        }
    }

    #[test]
    fn transitions_stay_in_range() {
        for s in 0..NUM_STATES {
            for bit in 0..2u8 {
                let next = StateTable::next(s as u8, bit);
                assert!(
                    (next as usize) < NUM_STATES,
                    "state {s}, bit {bit}: next={next} out of range"
                );
            }
        }
    }

    #[test]
    fn repeated_zeros_decrease_probability() {
        let mut s = 0u8;
        s = StateTable::next(s, 0);
        let p1 = StateTable::prob(s);
        s = StateTable::next(s, 0);
        let p2 = StateTable::prob(s);
        assert!(p2 <= p1, "more 0s should decrease P(1): p1={p1}, p2={p2}");
    }

    #[test]
    fn repeated_ones_increase_probability() {
        let mut s = 0u8;
        s = StateTable::next(s, 1);
        let p1 = StateTable::prob(s);
        s = StateTable::next(s, 1);
        let p2 = StateTable::prob(s);
        assert!(p2 >= p1, "more 1s should increase P(1): p1={p1}, p2={p2}");
    }

    #[test]
    fn no_state_maps_to_zero_from_nonzero() {
        // State 0 is the empty history: once any bit has been observed the
        // machine must never return to it. This covers every nonzero state,
        // including states 1 and 2 (the single-bit histories).
        for s in 1..NUM_STATES {
            for bit in 0..2u8 {
                let next = StateTable::next(s as u8, bit);
                assert!(
                    next != 0,
                    "state {s}, bit {bit}: should not transition back to state 0"
                );
            }
        }
    }

    #[test]
    fn convergence_all_zeros() {
        let mut s = 0u8;
        for _ in 0..50 {
            s = StateTable::next(s, 0);
        }
        let p = StateTable::prob(s);
        assert!(p < 200, "50 zeros should give very low P(1): {p}");
    }

    #[test]
    fn convergence_all_ones() {
        let mut s = 0u8;
        for _ in 0..50 {
            s = StateTable::next(s, 1);
        }
        let p = StateTable::prob(s);
        assert!(p > 3900, "50 ones should give very high P(1): {p}");
    }

    #[test]
    fn mixed_sequence_stays_moderate() {
        let mut s = 0u8;
        for i in 0..100 {
            // Alternate 0, 1, 0, 1, ...
            s = StateTable::next(s, (i & 1) as u8);
        }
        let p = StateTable::prob(s);
        assert!(
            (500..=3500).contains(&p),
            "alternating bits should give moderate P(1): {p}"
        );
    }
}