// datacortex_core/model/indirect_model.rs

//! IndirectModel — second-order context prediction.
//!
//! Instead of "what follows this context?", predicts "what follows the byte
//! that USUALLY follows this context?"
//!
//! Maintains two tables:
//! 1. prediction_table: context_hash -> most-likely-next-byte (updated byte-by-byte)
//! 2. ContextMap: uses (predicted_byte, c0_partial) as context for bit prediction;
//!    the bit position is carried implicitly by the leading 1 bit of c0.
//!
//! This captures second-order sequential patterns:
//! "after 'th' comes 'e', and after 'e' in this position comes ' '"
//!
//! Proven effective in PAQ8PX's IndirectModel: typically -0.03 to -0.08 bpb.

use crate::state::context_map::ContextMap;
use crate::state::state_map::StateMap;
use crate::state::state_table::StateTable;

/// Size of the prediction table (must be power of 2).
/// 1M entries = 2MB (1 byte prediction + 1 byte count per entry).
const PRED_TABLE_SIZE: usize = 1 << 20; // 1M entries
/// Mask that reduces a hash to a prediction-table index; only valid because
/// `PRED_TABLE_SIZE` is a power of two.
const PRED_TABLE_MASK: usize = PRED_TABLE_SIZE - 1;

/// Size of the indirect ContextMap (must be power of 2).
/// Larger = fewer collisions for the (predicted_byte, partial) contexts.
const INDIRECT_CM_SIZE: usize = 1 << 23; // 8MB

// Enforce the documented power-of-two requirements at compile time so that a
// future size change cannot silently break the mask-based indexing.
const _: () = assert!(PRED_TABLE_SIZE.is_power_of_two());
const _: () = assert!(INDIRECT_CM_SIZE.is_power_of_two());

/// FNV offset basis for hashing.
const FNV_OFFSET: u32 = 0x811C9DC5;
/// FNV prime for hashing.
const FNV_PRIME: u32 = 0x01000193;

33/// Indirect context model (model #16 in the mixer).
34pub struct IndirectModel {
35    /// Table 1: context_hash -> predicted next byte.
36    /// Each entry stores the most commonly seen next byte for that context.
37    prediction_table: Vec<u8>,
38    /// Count table: tracks confidence in the prediction.
39    count_table: Vec<u8>,
40    /// Table 2: maps (predicted_byte, c0) context hash -> state byte.
41    context_map: ContextMap,
42    /// StateMap: converts state byte to 12-bit probability.
43    state_map: StateMap,
44    /// Context hash used for prediction table lookup.
45    ctx_hash: u32,
46    /// Last predicted byte from the table.
47    predicted_byte: u8,
48    /// Last ContextMap hash used (for update).
49    last_cm_hash: u32,
50    /// Current partial byte (1-255).
51    c0: u32,
52    /// Last full byte.
53    c1: u8,
54    /// Second-to-last byte.
55    c2: u8,
56    /// Third-to-last byte.
57    c3: u8,
58    /// Bit position (0-7).
59    bpos: u8,
60}
61
62impl IndirectModel {
63    /// Create a new indirect model.
64    pub fn new() -> Self {
65        IndirectModel {
66            prediction_table: vec![0u8; PRED_TABLE_SIZE],
67            count_table: vec![0u8; PRED_TABLE_SIZE],
68            context_map: ContextMap::new(INDIRECT_CM_SIZE),
69            state_map: StateMap::new(),
70            ctx_hash: FNV_OFFSET,
71            predicted_byte: 0,
72            last_cm_hash: 0,
73            c0: 1,
74            c1: 0,
75            c2: 0,
76            c3: 0,
77            bpos: 0,
78        }
79    }
80
81    /// Predict the probability of the next bit being 1.
82    /// Returns 12-bit probability in [1, 4095].
83    #[inline]
84    pub fn predict(&mut self, c0: u32, bpos: u8, c1: u8) -> u32 {
85        if bpos == 0 {
86            // At byte boundary: look up prediction for this context.
87            self.ctx_hash = indirect_hash(c1, self.c2, self.c3);
88            let idx = self.ctx_hash as usize & PRED_TABLE_MASK;
89            self.predicted_byte = self.prediction_table[idx];
90        }
91
92        // Build context hash from predicted_byte + c0 partial byte.
93        // This gives the ContextMap a unique slot per (predicted_byte, partial_bit_pattern).
94        let cm_hash = predicted_context_hash(self.predicted_byte, c0);
95        self.last_cm_hash = cm_hash;
96
97        // Look up state in ContextMap, convert to probability via StateMap.
98        let state = self.context_map.get(cm_hash);
99        self.state_map.predict(state)
100    }
101
102    /// Update after observing `bit`.
103    #[inline]
104    pub fn update(&mut self, bit: u8) {
105        // Update ContextMap state for the context we just predicted with.
106        let state = self.context_map.get(self.last_cm_hash);
107        self.state_map.update(state, bit);
108        let new_state = StateTable::next(state, bit);
109        self.context_map.set(self.last_cm_hash, new_state);
110
111        // Track partial byte.
112        self.c0 = (self.c0 << 1) | bit as u32;
113        self.bpos += 1;
114
115        if self.bpos >= 8 {
116            let byte = (self.c0 & 0xFF) as u8;
117
118            // Update prediction table: for the PREVIOUS context, record what byte actually came.
119            let idx = self.ctx_hash as usize & PRED_TABLE_MASK;
120            let current_pred = self.prediction_table[idx];
121            let current_count = self.count_table[idx];
122
123            if byte == current_pred {
124                self.count_table[idx] = current_count.saturating_add(1);
125            } else if current_count < 2 {
126                self.prediction_table[idx] = byte;
127                self.count_table[idx] = 1;
128            } else {
129                self.count_table[idx] = current_count.saturating_sub(1);
130            }
131
132            self.c3 = self.c2;
133            self.c2 = self.c1;
134            self.c1 = byte;
135            self.c0 = 1;
136            self.bpos = 0;
137        }
138    }
139}
140
141impl Default for IndirectModel {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147/// Hash function for indirect model context (3 bytes).
148#[inline]
149fn indirect_hash(c1: u8, c2: u8, c3: u8) -> u32 {
150    let mut h = FNV_OFFSET;
151    h ^= c3 as u32;
152    h = h.wrapping_mul(FNV_PRIME);
153    h ^= c2 as u32;
154    h = h.wrapping_mul(FNV_PRIME);
155    h ^= c1 as u32;
156    h = h.wrapping_mul(FNV_PRIME);
157    h
158}
159
160/// Hash combining predicted byte with partial byte context (c0).
161/// This gives each (predicted_byte, bit_pattern) pair a distinct context slot.
162#[inline]
163fn predicted_context_hash(predicted: u8, c0: u32) -> u32 {
164    let mut h = 0x9E3779B9u32; // golden ratio seed for different hash space
165    h ^= predicted as u32;
166    h = h.wrapping_mul(FNV_PRIME);
167    h ^= c0 & 0x1FF; // c0 is 1-511 during a byte (9 bits)
168    h = h.wrapping_mul(FNV_PRIME);
169    h
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Bit of `byte` at position `bpos`, counting from the most significant bit.
    fn bit_at(byte: u8, bpos: u8) -> u8 {
        (byte >> (7 - bpos)) & 1
    }

    /// Partial-byte context before bit `bpos` of `byte` is coded: a leading 1
    /// marker followed by the `bpos` most significant bits of `byte`
    /// (so it is 1 at a byte boundary).
    fn partial_context(byte: u8, bpos: u8) -> u32 {
        (1u32 << bpos) | (u32::from(byte) >> (8 - u32::from(bpos)))
    }

    /// A fresh model must already produce a probability in the valid range.
    #[test]
    fn initial_prediction_in_range() {
        let mut model = IndirectModel::new();
        let p = model.predict(1, 0, 0);
        assert!(
            (1..=4095).contains(&p),
            "initial prediction should be in valid range, got {p}"
        );
    }

    /// Every prediction made while coding a short text stays in range.
    #[test]
    fn predictions_in_range() {
        let mut model = IndirectModel::new();
        let data = b"Hello, World! The quick brown fox.";
        for &byte in data {
            for bpos in 0..8u8 {
                // `c1` is only read at a byte boundary; this test feeds it an
                // arbitrary value there since it checks the range only.
                let c1 = if bpos == 0 {
                    byte.wrapping_sub(1)
                } else {
                    byte
                };
                let p = model.predict(partial_context(byte, bpos), bpos, c1);
                assert!(
                    (1..=4095).contains(&p),
                    "prediction out of range at bpos {bpos}: {p}"
                );
                model.update(bit_at(byte, bpos));
            }
        }
    }

    /// After a repeating pattern, the table remembers what follows "abc".
    #[test]
    fn prediction_table_updates() {
        let mut model = IndirectModel::new();
        let pattern = b"abcdabcdabcd";
        for &byte in pattern {
            for bpos in 0..8u8 {
                let _ = model.predict(partial_context(byte, bpos), bpos, model.c1);
                model.update(bit_at(byte, bpos));
            }
        }
        let idx = indirect_hash(b'c', b'b', b'a') as usize & PRED_TABLE_MASK;
        assert_eq!(
            model.prediction_table[idx], b'd',
            "prediction table should predict 'd' after 'abc'"
        );
    }

    /// Two models fed the same bits must produce identical predictions.
    #[test]
    fn deterministic() {
        let data = b"test determinism of indirect model";
        let mut m1 = IndirectModel::new();
        let mut m2 = IndirectModel::new();

        for &byte in data {
            for bpos in 0..8u8 {
                let c0 = partial_context(byte, bpos);
                let p1 = m1.predict(c0, bpos, m1.c1);
                let p2 = m2.predict(c0, bpos, m2.c1);
                assert_eq!(p1, p2, "models diverged at bpos {bpos}");
                let bit = bit_at(byte, bpos);
                m1.update(bit);
                m2.update(bit);
            }
        }
    }
}