Skip to main content

datacortex_core/model/
order0.rs

1//! Order-0 Context Model — predicts next bit from partial byte context.
2//!
3//! Context = the partial byte being decoded. `c` starts at 1, and after each
4//! bit, c = (c << 1) | bit. After 8 bits, c holds the full byte value + 256
5//! and gets reset to 1 for the next byte.
6//!
7//! Each context (c in range 1..=255) maps to a state in the StateTable,
8//! which is then mapped to a probability via StateMap.
9
10use crate::state::state_map::StateMap;
11use crate::state::state_table::StateTable;
12
13/// Order-0 bit prediction model.
14///
15/// Uses 256 contexts (indexed by partial byte value c, range 1-255).
16/// Index 0 is unused (c starts at 1).
17pub struct Order0Model {
18    /// State per context (256 entries, index 0 unused).
19    states: [u8; 256],
20    /// Adaptive state → probability map.
21    state_map: StateMap,
22}
23
24impl Order0Model {
25    /// Create a new Order-0 model with all states initialized to 0.
26    pub fn new() -> Self {
27        Order0Model {
28            states: [0u8; 256],
29            state_map: StateMap::new(),
30        }
31    }
32
33    /// Predict the probability of bit=1 given the current partial byte context.
34    ///
35    /// `context`: partial byte value (c, range 1-255).
36    /// Returns: 12-bit probability in [1, 4095].
37    #[inline]
38    pub fn predict(&self, context: usize) -> u32 {
39        let state = self.states[context & 0xFF];
40        self.state_map.predict(state)
41    }
42
43    /// Update the model after observing `bit` in the given context.
44    ///
45    /// `context`: partial byte value (c, range 1-255).
46    /// `bit`: the observed bit (0 or 1).
47    #[inline]
48    pub fn update(&mut self, context: usize, bit: u8) {
49        let ctx = context & 0xFF;
50        let state = self.states[ctx];
51
52        // Update the state map (adaptive probability).
53        self.state_map.update(state, bit);
54
55        // Transition to next state.
56        self.states[ctx] = StateTable::next(state, bit);
57    }
58}
59
60impl Default for Order0Model {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh model queried at the start-of-byte context (c = 1, state 0)
    /// must report an even chance.
    #[test]
    fn initial_prediction_is_balanced() {
        let p = Order0Model::new().predict(1);
        assert_eq!(p, 2048, "initial prediction should be 2048");
    }

    /// Every valid context of an untrained model yields a 12-bit probability
    /// inside [1, 4095].
    #[test]
    fn prediction_in_range() {
        let model = Order0Model::new();
        (1usize..=255).for_each(|c| {
            let p = model.predict(c);
            assert!(
                (1..=4095).contains(&p),
                "context {c}: pred {p} out of range"
            );
        });
    }

    /// A single update must move the prediction for that context: the
    /// context transitions to a new state, and the state map entry for the
    /// previous state is adjusted.
    #[test]
    fn update_adapts_prediction() {
        let mut model = Order0Model::new();
        let p_initial = model.predict(1);
        model.update(1, 1);
        let p_updated = model.predict(1);
        assert_ne!(p_initial, p_updated, "prediction should change after update");
    }

    /// Two contexts trained on opposite bits must end up with different,
    /// correctly ordered predictions.
    #[test]
    fn different_contexts_have_separate_states() {
        let mut model = Order0Model::new();
        // Context 10 sees only 1s.
        (0..50).for_each(|_| model.update(10, 1));
        // Context 5 sees only 0s.
        (0..50).for_each(|_| model.update(5, 0));

        let p5 = model.predict(5);
        let p10 = model.predict(10);
        // The context trained on 1s has to predict a higher P(bit = 1).
        assert!(
            p10 > p5,
            "context 10 (all 1s) should predict higher than context 5 (all 0s): p10={p10}, p5={p5}"
        );
    }

    /// Walk one byte through the model MSB-first and check how the context
    /// evolves.
    #[test]
    fn simulate_byte_encoding() {
        let mut model = Order0Model::new();
        let byte: u8 = 0x42; // 01000010

        let mut c = 1usize;
        for shift in (0..8).rev() {
            let bit = (byte >> shift) & 1;
            let _p = model.predict(c);
            model.update(c, bit);
            c = (c << 1) | usize::from(bit);
        }
        // Eight shifts push the leading 1 up to bit 8, so c == byte + 256.
        assert_eq!(c, usize::from(byte) + 256);
    }

    /// Coding the same byte repeatedly must get cheaper: the average cost
    /// over all rounds has to be below the cost of the very first round.
    #[test]
    fn repeated_pattern_adapts() {
        const ROUNDS: u32 = 20;

        let mut model = Order0Model::new();
        // 'A' (0x41) is fed to the model over and over.
        let byte: u8 = 0x41;
        let mut total_surprise = 0.0_f64;
        let mut first_byte_surprise = 0.0_f64;

        for iteration in 0..ROUNDS {
            let mut c = 1usize;
            let mut byte_surprise = 0.0_f64;
            for shift in (0..8).rev() {
                let bit = (byte >> shift) & 1;
                let p_one = model.predict(c) as f64 / 4096.0;
                // Surprise of a bit is -log2 of the probability assigned to it.
                let prob_of_bit = if bit == 1 { p_one } else { 1.0 - p_one };
                byte_surprise -= prob_of_bit.log2();
                model.update(c, bit);
                c = (c << 1) | usize::from(bit);
            }
            if iteration == 0 {
                first_byte_surprise = byte_surprise;
            }
            total_surprise += byte_surprise;
        }

        // Mean cost per byte across all rounds (the first round included).
        let last_avg = total_surprise / f64::from(ROUNDS);
        assert!(
            last_avg < first_byte_surprise,
            "model should improve: first byte = {first_byte_surprise:.2} bits, avg = {last_avg:.2} bits"
        );
    }
}