Skip to main content

datacortex_core/model/
cm_model.rs

//! ContextModel -- flexible context model using ContextMap + StateMap.
//!
//! Phase 4+: Three CM model variants (`ContextModel`, `ChecksumContextModel`,
//! `AssociativeContextModel`). As implemented in this file, each variant produces
//! two predictions via `predict_multi`:
//! 1. **StateMap prediction** (p1): existing state->probability mapping
//! 2. **Run-count prediction** (p2): consecutive same-bit count in this context
//!
//! A third, **byte-history prediction** (p3): last byte seen in this context predicts
//! the current bit, is described by the design but is NOT produced by the code in this
//! file: `predict_multi` returns only `(p1, p2)` and `on_byte_complete` is a no-op.
//!
//! The run-count uses the same bit-level hash as the StateMap.
//! The byte-history design uses a separate byte-level hash (independent of partial byte c0)
//! so that the stored byte is accessible at any bpos within the same byte context.
11
12use crate::state::context_map::{AssociativeContextMap, ChecksumContextMap, ContextMap};
13use crate::state::state_map::StateMap;
14use crate::state::state_table::StateTable;
15
16// --- Run Map: tracks consecutive same-bit runs per context ---
17
/// Per-context tracker of consecutive identical bits ("runs").
///
/// Each slot is a single packed byte: bit 7 holds the bit value of the current run
/// and bits 6..=0 hold the run length (0-127, where 0 means "no history yet").
/// The table is lossy: a hash is reduced to a slot index by masking, so contexts
/// that collide share a slot and overwrite each other's run info.
struct RunMap {
    /// Packed run bytes, one per slot. The length is always a power of two.
    table: Vec<u8>,
    /// `table.len() - 1`; ANDed with a hash to obtain the slot index.
    mask: usize,
}

impl RunMap {
    /// Low seven bits of a packed slot: the run length.
    const COUNT_MASK: u8 = 0x7F;
    /// Longest run length a slot can record.
    const MAX_RUN: u8 = 127;

    /// Create a zeroed run table with at least `size` slots.
    ///
    /// The slot count is rounded up to the next power of two (minimum 1) so that
    /// `mask` is always a valid index mask. Power-of-two sizes are used unchanged;
    /// `size == 0` no longer underflows, and other sizes no longer produce a mask
    /// that leaves part of the table unreachable.
    fn new(size: usize) -> Self {
        let slots = size.max(1).next_power_of_two();
        RunMap {
            table: vec![0u8; slots],
            mask: slots - 1,
        }
    }

    /// Slot index for a context hash.
    #[inline(always)]
    fn index(&self, hash: u32) -> usize {
        hash as usize & self.mask
    }

    /// Get run info for context. Returns `(run_count, run_bit)`.
    /// `run_count == 0` means no history.
    #[inline(always)]
    fn get(&self, hash: u32) -> (u8, u8) {
        let packed = self.table[self.index(hash)];
        (packed & Self::COUNT_MASK, packed >> 7)
    }

    /// Update run info after observing `bit` (expected to be 0 or 1).
    ///
    /// If the slot already holds a run of the same bit, the run length grows by one,
    /// capped at 127; otherwise a new run of length 1 starts.
    #[inline(always)]
    fn update(&mut self, hash: u32, bit: u8) {
        let idx = self.index(hash);
        let packed = self.table[idx];
        let run_bit = packed >> 7;
        let run_count = packed & Self::COUNT_MASK;

        // `run_count` is at most 127, so `run_count + 1` cannot overflow a u8.
        let new_count = if run_count > 0 && bit == run_bit {
            (run_count + 1).min(Self::MAX_RUN)
        } else {
            1
        };
        self.table[idx] = (bit << 7) | new_count;
    }

    /// Convert run info to a 12-bit prediction that the next bit is 1.
    ///
    /// Returns 2048 when the slot has no history. Otherwise the prediction moves
    /// away from 2048 towards the run's bit by `min(128 * run_count, 1800)`, so the
    /// result always lies in `[248, 3848]`.
    #[inline(always)]
    fn predict_p(&self, hash: u32) -> u32 {
        let (run_count, run_bit) = self.get(hash);
        if run_count == 0 {
            return 2048; // no history
        }
        // Strength ramps linearly with run length, capped. Because of the cap the
        // additions/subtractions below cannot leave the 12-bit range.
        let strength = (u32::from(run_count) * 128).min(1800);
        if run_bit == 1 {
            2048 + strength
        } else {
            2048 - strength
        }
    }
}
78
/// Pair of predictions returned by a context model's `predict_multi`: `(state_p, run_p)`.
///
/// - `state_p`: the StateMap prediction for the looked-up context state.
/// - `run_p`: the run-count continuation prediction from the `RunMap`.
///
/// Both are intended to be 12-bit probabilities that the next bit is 1, in [1, 4095].
pub type DualPrediction = (u32, u32);
84
85/// A context model backed by a ContextMap (hash->state) + StateMap (state->prob).
86/// Now also produces run-count and byte-history predictions.
87pub struct ContextModel {
88    /// Hash table mapping context hashes to states.
89    cmap: ContextMap,
90    /// Adaptive state -> probability mapper.
91    smap: StateMap,
92    /// Run-count tracker per context.
93    run_map: RunMap,
94    /// Last looked-up state (for update after predict).
95    last_state: u8,
96    /// Last looked-up hash (for update after predict).
97    last_hash: u32,
98}
99
100impl ContextModel {
101    /// Create a new context model with the given ContextMap size.
102    pub fn new(cmap_size: usize) -> Self {
103        let aux_size = (cmap_size / 4).next_power_of_two().max(1024);
104        ContextModel {
105            cmap: ContextMap::new(cmap_size),
106            smap: StateMap::new(),
107            run_map: RunMap::new(aux_size),
108            last_state: 0,
109            last_hash: 0,
110        }
111    }
112
113    /// Predict probability of bit=1 for the given context hash.
114    /// Returns 12-bit probability in [1, 4095].
115    #[inline(always)]
116    pub fn predict(&mut self, hash: u32) -> u32 {
117        let state = self.cmap.get(hash);
118        self.last_state = state;
119        self.last_hash = hash;
120        self.smap.predict(state)
121    }
122
123    /// Predict dual: (state_p, run_p).
124    #[inline(always)]
125    pub fn predict_multi(&mut self, hash: u32) -> DualPrediction {
126        let state = self.cmap.get(hash);
127        self.last_state = state;
128        self.last_hash = hash;
129        let state_p = self.smap.predict(state);
130        let run_p = self.run_map.predict_p(hash);
131        (state_p, run_p)
132    }
133
134    /// Update the model after observing `bit`.
135    /// Must be called after predict() with the same context.
136    #[inline(always)]
137    pub fn update(&mut self, bit: u8) {
138        self.smap.update(self.last_state, bit);
139        let new_state = StateTable::next(self.last_state, bit);
140        self.cmap.set(self.last_hash, new_state);
141        self.run_map.update(self.last_hash, bit);
142    }
143
144    /// Notify model that a byte is complete (no-op).
145    #[inline(always)]
146    pub fn on_byte_complete(&mut self, _byte: u8) {}
147}
148
149/// A context model using ChecksumContextMap for reduced collision damage.
150pub struct ChecksumContextModel {
151    cmap: ChecksumContextMap,
152    smap: StateMap,
153    run_map: RunMap,
154    last_state: u8,
155    last_hash: u32,
156}
157
158impl ChecksumContextModel {
159    pub fn new(byte_size: usize) -> Self {
160        let aux_size = (byte_size / 4).next_power_of_two().max(1024);
161        ChecksumContextModel {
162            cmap: ChecksumContextMap::new(byte_size),
163            smap: StateMap::new(),
164            run_map: RunMap::new(aux_size),
165            last_state: 0,
166            last_hash: 0,
167        }
168    }
169
170    #[inline(always)]
171    pub fn predict(&mut self, hash: u32) -> u32 {
172        let state = self.cmap.get(hash);
173        self.last_state = state;
174        self.last_hash = hash;
175        self.smap.predict(state)
176    }
177
178    #[inline(always)]
179    pub fn predict_multi(&mut self, hash: u32) -> DualPrediction {
180        let state = self.cmap.get(hash);
181        self.last_state = state;
182        self.last_hash = hash;
183        let state_p = self.smap.predict(state);
184        let run_p = self.run_map.predict_p(hash);
185        (state_p, run_p)
186    }
187
188    #[inline(always)]
189    pub fn update(&mut self, bit: u8) {
190        self.smap.update(self.last_state, bit);
191        let new_state = StateTable::next(self.last_state, bit);
192        self.cmap.set(self.last_hash, new_state);
193        self.run_map.update(self.last_hash, bit);
194    }
195
196    #[inline(always)]
197    pub fn on_byte_complete(&mut self, _byte: u8) {}
198}
199
200/// A context model using 2-way set-associative ContextMap.
201/// Best for order-5+ where collision rates are highest.
202pub struct AssociativeContextModel {
203    cmap: AssociativeContextMap,
204    smap: StateMap,
205    run_map: RunMap,
206    last_state: u8,
207    last_hash: u32,
208}
209
210impl AssociativeContextModel {
211    pub fn new(byte_size: usize) -> Self {
212        let aux_size = (byte_size / 4).next_power_of_two().max(1024);
213        AssociativeContextModel {
214            cmap: AssociativeContextMap::new(byte_size),
215            smap: StateMap::new(),
216            run_map: RunMap::new(aux_size),
217            last_state: 0,
218            last_hash: 0,
219        }
220    }
221
222    #[inline(always)]
223    pub fn predict(&mut self, hash: u32) -> u32 {
224        let state = self.cmap.get(hash);
225        self.last_state = state;
226        self.last_hash = hash;
227        self.smap.predict(state)
228    }
229
230    #[inline(always)]
231    pub fn predict_multi(&mut self, hash: u32) -> DualPrediction {
232        let state = self.cmap.get(hash);
233        self.last_state = state;
234        self.last_hash = hash;
235        let state_p = self.smap.predict(state);
236        let run_p = self.run_map.predict_p(hash);
237        (state_p, run_p)
238    }
239
240    #[inline(always)]
241    pub fn update(&mut self, bit: u8) {
242        self.smap.update(self.last_state, bit);
243        let new_state = StateTable::next(self.last_state, bit);
244        self.cmap.set(self.last_hash, new_state);
245        self.run_map.update(self.last_hash, bit);
246    }
247
248    #[inline(always)]
249    pub fn on_byte_complete(&mut self, _byte: u8) {}
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    // --- ContextModel ---

    /// A fresh model has seen nothing, so the state prediction is the midpoint.
    #[test]
    fn initial_prediction_balanced() {
        let mut model = ContextModel::new(1024);
        let p = model.predict(0);
        assert_eq!(p, 2048);
    }

    /// One predict/update round on a context must move that context's prediction.
    #[test]
    fn predict_update_changes_probability() {
        let mut model = ContextModel::new(1024);
        let p1 = model.predict(42);
        model.update(1);
        let p2 = model.predict(42);
        assert_ne!(p1, p2, "update should change prediction");
    }

    /// Contexts trained on opposite bits end up with opposite-leaning predictions.
    #[test]
    fn different_contexts_diverge() {
        let mut model = ContextModel::new(1024);
        // Context 10 sees twenty 1s, then context 20 sees twenty 0s.
        for (ctx, bit) in [(10u32, 1u8), (20u32, 0u8)] {
            for _ in 0..20 {
                model.predict(ctx);
                model.update(bit);
            }
        }
        let p10 = model.predict(10);
        let p20 = model.predict(20);
        assert!(
            p10 > p20,
            "ctx 10 (all 1s) should predict higher than ctx 20 (all 0s): p10={p10}, p20={p20}"
        );
    }

    /// State predictions stay inside the 12-bit range while training on mixed bits.
    #[test]
    fn predictions_in_range() {
        let mut model = ContextModel::new(1024);
        for ctx in 0..100u32 {
            let p = model.predict(ctx);
            assert!((1..=4095).contains(&p));
            model.update((ctx & 1) as u8);
        }
    }

    /// With no history both halves of the dual prediction are the midpoint.
    #[test]
    fn multi_predict_returns_pair() {
        let mut model = ContextModel::new(1024);
        let (sp, rp) = model.predict_multi(42);
        assert_eq!(sp, 2048);
        assert_eq!(rp, 2048);
    }

    /// After a run of 1s the run-count prediction leans towards 1.
    #[test]
    fn run_prediction_adapts() {
        let mut model = ContextModel::new(1024);
        for _ in 0..10 {
            model.predict_multi(42);
            model.update(1);
        }
        let (_, rp) = model.predict_multi(42);
        assert!(
            rp > 2048,
            "run prediction should favor 1 after many 1s: {rp}"
        );
    }

    /// Both halves of the initial dual prediction are valid 12-bit probabilities.
    #[test]
    fn dual_predictions_in_range() {
        let mut model = ContextModel::new(1024);
        let (sp, rp) = model.predict_multi(42);
        assert!((1..=4095).contains(&sp));
        assert!((1..=4095).contains(&rp));
    }

    // --- ChecksumContextModel ---

    /// A fresh checksummed model predicts the midpoint.
    #[test]
    fn checksum_initial_prediction_balanced() {
        let mut model = ChecksumContextModel::new(2048);
        let p = model.predict(0);
        assert_eq!(p, 2048);
    }

    /// One predict/update round must move the checksummed model's prediction.
    #[test]
    fn checksum_predict_update() {
        let mut model = ChecksumContextModel::new(2048);
        let p1 = model.predict(42);
        model.update(1);
        let p2 = model.predict(42);
        assert_ne!(p1, p2, "update should change prediction");
    }

    /// Checksummed state predictions stay inside the 12-bit range.
    #[test]
    fn checksum_predictions_in_range() {
        let mut model = ChecksumContextModel::new(2048);
        for ctx in 0..100u32 {
            let p = model.predict(ctx);
            assert!((1..=4095).contains(&p));
            model.update((ctx & 1) as u8);
        }
    }

    /// With no history the checksummed dual prediction is the midpoint pair.
    #[test]
    fn checksum_multi_predict() {
        let mut model = ChecksumContextModel::new(2048);
        let (sp, rp) = model.predict_multi(42);
        assert_eq!(sp, 2048);
        assert_eq!(rp, 2048);
    }

    // --- AssociativeContextModel ---

    /// With no history the associative dual prediction is the midpoint pair.
    #[test]
    fn assoc_multi_predict() {
        let mut model = AssociativeContextModel::new(4096);
        let (sp, rp) = model.predict_multi(42);
        assert_eq!(sp, 2048);
        assert_eq!(rp, 2048);
    }
}