Skip to main content

datacortex_core/state/
state_table.rs

//! StateTable — PAQ8-style 256-state bit history machine.
//!
//! Each state encodes a compact (n0, n1) pair of bit counts.
//! State 0 is the initial state (no history observed).
//!
//! The table maps (state, bit) -> next_state and provides an approximate
//! probability of bit=1 for each state, derived from its counts as
//! (n1 + 1) / (n0 + n1 + 2) scaled to 12 bits.
//!
//! Design based on the PAQ8 state table:
//! - States encode (n0, n1) count pairs compactly
//! - Low counts have fine granularity; high counts are coarsened
//! - At high counts the smaller count is scaled down, keeping the state space bounded
//! - State 0 = initial (equal probability)

/// Number of states in the table: one per possible `u8` state index.
pub const NUM_STATES: usize = u8::MAX as usize + 1;

// State-to-(n0, n1) mapping. Generated to cover the useful range of
// (n0, n1) combinations, with fine granularity near (0, 0) and coarser
// granularity at higher counts.

// The core state table: 256 states, each representing a compact (n0, n1)
// bit-count history. Only derived data is stored: the next-state table
// (two entries per state) and a 12-bit P(bit=1) per state.
// Transitions:
//   On bit=0: move to the state nearest (n0+1, n1), after scaling
//   On bit=1: move to the state nearest (n0, n1+1), after scaling
// Once the total count exceeds 20, the smaller of the two counts is scaled
// down (and above a total of 40 the larger count is capped at 60), which
// keeps the state space bounded.
//
// The table is generated programmatically at compile time.
// Strategy: enumerate (n0, n1) pairs in a priority order that covers
// the most useful probability ranges, then create transition links.

// === COMPILE-TIME TABLE GENERATION ===

/// Build the transition and probability tables at compile time.
///
/// Returns `(next, prob)`:
/// - `next[s * 2 + bit]` is the state entered after observing `bit` in state `s`;
/// - `prob[s]` is the 12-bit probability of bit=1 for state `s`, in `[1, 4095]`.
///
/// Each state is first assigned an `(n0, n1)` count pair (see `assign_counts`).
/// Probabilities and transitions are then derived from those counts.
/// Transitions are chosen so that observing a 0 does not raise `prob` and
/// observing a 1 does not lower it (see `transition` and `find_closest_state`).
const fn build_table() -> ([u8; 512], [u16; 256]) {
    let (n0, n1) = assign_counts();
    let prob = compute_probs(&n0, &n1);

    let mut next = [0u8; 512];
    let mut s: usize = 0;
    while s < 256 {
        next[s * 2] = transition(&n0, &n1, &prob, s, 0);
        next[s * 2 + 1] = transition(&n0, &n1, &prob, s, 1);
        s += 1;
    }

    (next, prob)
}

/// Assign an `(n0, n1)` count pair to every state.
///
/// Order of assignment:
/// 1. state 0 = `(0, 0)`, the initial state;
/// 2. every pair with `n0 + n1` in `1..=8` (fine granularity);
/// 3. pairs with `n0 + n1` in `9..=20` that are one-sided or whose larger
///    count is at most 4x the smaller one;
/// 4. long-run pairs `(r, 0)` and `(0, r)` for `r = 21, 24, ..., 60`;
/// 5. any slots still free are padded with `(1, 1)`.
const fn assign_counts() -> ([u16; 256], [u16; 256]) {
    let mut n0 = [0u16; 256];
    let mut n1 = [0u16; 256];
    // Next free state; state 0 keeps the initial (0, 0) pair.
    let mut idx: usize = 1;

    // Fine granularity: every (a, b) with a + b = 1..=8.
    let mut total: u16 = 1;
    while total <= 8 {
        let mut b: u16 = 0;
        while b <= total {
            if idx < 256 {
                n0[idx] = total - b;
                n1[idx] = b;
                idx += 1;
            }
            b += 1;
        }
        total += 1;
    }

    // Medium granularity: totals 9..=20. Mixed pairs more lopsided than 4:1
    // are skipped to save states; one-sided pairs are always kept.
    total = 9;
    while total <= 20 {
        let mut b: u16 = 0;
        while b <= total {
            let a = total - b;
            let (lo, hi) = if a < b { (a, b) } else { (b, a) };
            if (lo == 0 || hi <= lo * 4) && idx < 256 {
                n0[idx] = a;
                n1[idx] = b;
                idx += 1;
            }
            b += 1;
        }
        total += 1;
    }

    // Coarse: long runs of one bit value, stride 3. Each iteration fills two
    // slots, so two must be free; with only one free slot the second write
    // would index past the array and fail const evaluation.
    let mut run: u16 = 21;
    while run <= 60 && idx + 1 < 256 {
        n0[idx] = run;
        n1[idx] = 0;
        n0[idx + 1] = 0;
        n1[idx + 1] = run;
        idx += 2;
        run += 3;
    }

    // Pad any remaining slots with (1, 1). These repeat an earlier state's
    // counts and only fill out the table.
    while idx < 256 {
        n0[idx] = 1;
        n1[idx] = 1;
        idx += 1;
    }

    (n0, n1)
}

/// Static P(bit=1) per state: `(n1 + 1) / (n0 + n1 + 2)` scaled to 12 bits,
/// rounded to nearest and clamped to `[1, 4095]`.
const fn compute_probs(n0: &[u16; 256], n1: &[u16; 256]) -> [u16; 256] {
    let mut prob = [0u16; 256];
    let mut s: usize = 0;
    while s < 256 {
        let t = n0[s] as u32 + n1[s] as u32 + 2;
        let p = ((n1[s] as u32 + 1) * 4095 + t / 2) / t;
        let p = if p < 1 {
            1
        } else if p > 4095 {
            4095
        } else {
            p
        };
        prob[s] = p as u16;
        s += 1;
    }
    prob
}

/// Successor of state `s` after observing `bit` (0 or 1).
///
/// The ideal successor has the counts of `s` with the observed bit's count
/// incremented, then passed through `scale_counts`. A state already at the
/// count cap that sees its majority bit again maps to itself (saturation).
/// In every other case the nearest state moving in the observed direction is
/// chosen via `find_closest_state`.
const fn transition(
    n0: &[u16; 256],
    n1: &[u16; 256],
    prob: &[u16; 256],
    s: usize,
    bit: u8,
) -> u8 {
    let (cur0, cur1) = (n0[s], n1[s]);
    let (want0, want1) = if bit == 0 {
        scale_counts(cur0 + 1, cur1)
    } else {
        scale_counts(cur0, cur1 + 1)
    };
    let confirms = if bit == 0 { cur0 > cur1 } else { cur1 > cur0 };
    if confirms && want0 == cur0 && want1 == cur1 {
        // Saturated: e.g. (60, 0) seeing another 0. No other state has a
        // lower P(1), so any move would step back to a less confident state.
        // `s < 256`, so the cast is lossless.
        s as u8
    } else {
        find_closest_state(n0, n1, prob, want0, want1, s, bit)
    }
}

/// Scale down counts to keep the state space bounded.
///
/// - total <= 20: unchanged;
/// - total 21..=40: the smaller count (`b` on ties) is reduced by 25%,
///   rounding down but never from 1 to 0;
/// - total > 40: the smaller count (`b` on ties) is halved and the larger
///   one is capped at 60.
///
/// NOTE(review): the *smaller* count is reduced even when it belongs to the
/// bit just observed, e.g. (12, 8) on bit=1 targets (12, 6). The direction
/// check in `find_closest_state` keeps such a target from lowering P(1).
const fn scale_counts(a: u16, b: u16) -> (u16, u16) {
    let total = a + b;
    if total <= 20 {
        (a, b)
    } else if total <= 40 {
        if a >= b {
            (a, shrink_quarter(b))
        } else {
            (shrink_quarter(a), b)
        }
    } else if a >= b {
        (if a > 60 { 60 } else { a }, b / 2)
    } else {
        (a / 2, if b > 60 { 60 } else { b })
    }
}

/// `x * 3 / 4`, except that a count of 1 stays 1 (a nonzero count is kept nonzero).
const fn shrink_quarter(x: u16) -> u16 {
    let y = x * 3 / 4;
    if y == 0 && x > 0 {
        1
    } else {
        y
    }
}

/// Find the state whose `(n0, n1)` is closest to `(target_n0, target_n1)`.
///
/// Distance is `3 * (|dn0| + |dn1|) + |dtotal|`. `current` itself is never
/// returned unless it is state 0.
///
/// Eligibility: a state qualifies only if its `prob` moves in the direction of
/// the observed `bit` relative to `current` (not higher after a 0, not lower
/// after a 1).
///
/// Ties: among equally distant eligible states, the one further in the
/// observed direction wins; remaining ties go to the lowest index.
///
/// Fallback: if no state qualifies, the nearest state regardless of direction
/// is returned (lowest index on ties).
const fn find_closest_state(
    n0: &[u16; 256],
    n1: &[u16; 256],
    prob: &[u16; 256],
    target_n0: u16,
    target_n1: u16,
    current: usize,
    bit: u8,
) -> u8 {
    let target_total = target_n0 as u32 + target_n1 as u32;
    let here = prob[current];

    // Best direction-consistent candidate.
    let mut best: usize = 0;
    let mut best_dist: u32 = u32::MAX;
    // Best candidate ignoring direction, used only as a fallback.
    let mut any: usize = 0;
    let mut any_dist: u32 = u32::MAX;

    let mut i: usize = 0;
    while i < 256 {
        if i != current || current == 0 {
            let d0 = (n0[i] as u32).abs_diff(target_n0 as u32);
            let d1 = (n1[i] as u32).abs_diff(target_n1 as u32);
            let total_diff = (n0[i] as u32 + n1[i] as u32).abs_diff(target_total);
            let dist = 3 * (d0 + d1) + total_diff;

            if dist < any_dist {
                any = i;
                any_dist = dist;
            }

            let consistent = if bit == 0 {
                prob[i] <= here
            } else {
                prob[i] >= here
            };
            if consistent {
                let better = if dist != best_dist {
                    dist < best_dist
                } else if bit == 0 {
                    prob[i] < prob[best]
                } else {
                    prob[i] > prob[best]
                };
                if better {
                    best = i;
                    best_dist = dist;
                }
            }
        }
        i += 1;
    }

    if best_dist == u32::MAX {
        any as u8
    } else {
        best as u8
    }
}

237static TABLE: ([u8; 512], [u16; 256]) = build_table();
238
239/// The state table -- provides transitions and initial probabilities.
240pub struct StateTable;
241
242impl StateTable {
243    /// Get the next state after observing `bit` in state `s`.
244    #[inline(always)]
245    pub fn next(s: u8, bit: u8) -> u8 {
246        TABLE.0[s as usize * 2 + (bit as usize & 1)]
247    }
248
249    /// Get the initial (static) probability of bit=1 for state `s`.
250    /// Returns a 12-bit value in [1, 4095].
251    #[inline(always)]
252    pub fn prob(s: u8) -> u16 {
253        TABLE.1[s as usize]
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Walk the table from the initial state, feeding it `bits` in order.
    fn walk(bits: impl IntoIterator<Item = u8>) -> u8 {
        bits.into_iter().fold(0, StateTable::next)
    }

    #[test]
    fn initial_state_is_balanced() {
        assert_eq!(StateTable::prob(0), 2048);
    }

    #[test]
    fn state_0_transitions_are_distinct() {
        let (next0, next1) = (walk([0]), walk([1]));
        assert_ne!(next0, next1, "state 0 transitions should differ");
        // A 0 must pull P(1) below the initial 2048...
        assert!(
            StateTable::prob(next0) < 2048,
            "bit=0 should go to low-probability state, got {}",
            StateTable::prob(next0)
        );
        // ...and a 1 must push it above.
        assert!(
            StateTable::prob(next1) > 2048,
            "bit=1 should go to high-probability state, got {}",
            StateTable::prob(next1)
        );
    }

    #[test]
    fn probabilities_are_in_range() {
        (0..NUM_STATES).for_each(|s| {
            let p = StateTable::prob(s as u8);
            assert!((1..=4095).contains(&p), "state {s}: prob {p} out of range");
        });
    }

    #[test]
    fn transitions_stay_in_range() {
        let cases = (0..NUM_STATES).flat_map(|s| (0..2u8).map(move |bit| (s, bit)));
        for (s, bit) in cases {
            let next = StateTable::next(s as u8, bit);
            assert!(
                (next as usize) < NUM_STATES,
                "state {s}, bit {bit}: next={next} out of range"
            );
        }
    }

    #[test]
    fn repeated_zeros_decrease_probability() {
        let p1 = StateTable::prob(walk([0]));
        let p2 = StateTable::prob(walk([0, 0]));
        assert!(p2 <= p1, "more 0s should decrease P(1): p1={p1}, p2={p2}");
    }

    #[test]
    fn repeated_ones_increase_probability() {
        let p1 = StateTable::prob(walk([1]));
        let p2 = StateTable::prob(walk([1, 1]));
        assert!(p2 >= p1, "more 1s should increase P(1): p1={p1}, p2={p2}");
    }

    #[test]
    fn no_state_maps_to_zero_from_nonzero() {
        // After the first observation the walk should never return to state 0.
        // States 1 and 2 (the first-observation states) are exempt.
        for s in 3..NUM_STATES {
            for bit in 0..2u8 {
                let next = StateTable::next(s as u8, bit);
                assert!(
                    next != 0,
                    "state {s}, bit {bit}: should not transition back to state 0"
                );
            }
        }
    }

    #[test]
    fn convergence_all_zeros() {
        let p = StateTable::prob(walk([0; 50]));
        assert!(p < 200, "50 zeros should give very low P(1): {p}");
    }

    #[test]
    fn convergence_all_ones() {
        let p = StateTable::prob(walk([1; 50]));
        assert!(p > 3900, "50 ones should give very high P(1): {p}");
    }

    #[test]
    fn mixed_sequence_stays_moderate() {
        // Alternating bits (0, 1, 0, 1, ...) should keep probability moderate.
        let p = StateTable::prob(walk((0..100).map(|i| (i & 1) as u8)));
        assert!(
            (500..=3500).contains(&p),
            "alternating bits should give moderate P(1): {p}"
        );
    }
}