datacortex_core/model/
indirect_model.rs1use crate::state::context_map::ContextMap;
16use crate::state::state_map::StateMap;
17use crate::state::state_table::StateTable;
18
/// Number of one-byte slots in each of the prediction and count tables.
const PRED_TABLE_SIZE: usize = 1 << 20;
/// Bit mask that folds a 32-bit history hash into a table index
/// (valid because `PRED_TABLE_SIZE` is a power of two).
const PRED_TABLE_MASK: usize = PRED_TABLE_SIZE - 1;

/// Size argument handed to `ContextMap::new` for this model.
// NOTE(review): whether this counts entries or bytes depends on `ContextMap` — confirm there.
const INDIRECT_CM_SIZE: usize = 1 << 23;

/// 32-bit FNV offset basis; starting value of the byte-history hash.
const FNV_OFFSET: u32 = 0x811C_9DC5;
/// 32-bit FNV prime; multiplier used by both hash functions in this file.
const FNV_PRIME: u32 = 0x0100_0193;
32
/// Bit predictor that first guesses the upcoming byte from a hash of recent
/// bytes, then uses that guess (together with the bits of the current byte
/// seen so far) as the context for a bit-level prediction.
pub struct IndirectModel {
    /// Per history slot: the byte currently adopted as the guess for what
    /// follows that history. Indexed by `ctx_hash & PRED_TABLE_MASK`.
    prediction_table: Vec<u8>,
    /// Per history slot: counter raised when the guess matched the byte that
    /// actually occurred and lowered otherwise; the guess is replaced only
    /// while this counter is below 2.
    count_table: Vec<u8>,
    /// Maps the hash of (guessed byte, partial byte) to a state value.
    context_map: ContextMap,
    /// Turns a context-map state into the value returned by `predict`, and is
    /// updated with every observed bit.
    state_map: StateMap,
    /// History hash computed by the most recent `predict` call with `bpos == 0`.
    ctx_hash: u32,
    /// Guess read from `prediction_table` at the start of the current byte.
    predicted_byte: u8,
    /// Context-map hash used by the last `predict`; reused by `update`.
    last_cm_hash: u32,
    /// Bits of the current byte seen by `update`, behind a leading 1 bit.
    c0: u32,
    /// Most recently completed byte, as tracked by `update`.
    c1: u8,
    /// Completed byte before `c1`.
    c2: u8,
    /// Completed byte before `c2`.
    c3: u8,
    /// Number of bits of the current byte that `update` has consumed.
    bpos: u8,
}
61
62impl IndirectModel {
63 pub fn new() -> Self {
65 IndirectModel {
66 prediction_table: vec![0u8; PRED_TABLE_SIZE],
67 count_table: vec![0u8; PRED_TABLE_SIZE],
68 context_map: ContextMap::new(INDIRECT_CM_SIZE),
69 state_map: StateMap::new(),
70 ctx_hash: FNV_OFFSET,
71 predicted_byte: 0,
72 last_cm_hash: 0,
73 c0: 1,
74 c1: 0,
75 c2: 0,
76 c3: 0,
77 bpos: 0,
78 }
79 }
80
81 #[inline]
84 pub fn predict(&mut self, c0: u32, bpos: u8, c1: u8) -> u32 {
85 if bpos == 0 {
86 self.ctx_hash = indirect_hash(c1, self.c2, self.c3);
88 let idx = self.ctx_hash as usize & PRED_TABLE_MASK;
89 self.predicted_byte = self.prediction_table[idx];
90 }
91
92 let cm_hash = predicted_context_hash(self.predicted_byte, c0);
95 self.last_cm_hash = cm_hash;
96
97 let state = self.context_map.get(cm_hash);
99 self.state_map.predict(state)
100 }
101
102 #[inline]
104 pub fn update(&mut self, bit: u8) {
105 let state = self.context_map.get(self.last_cm_hash);
107 self.state_map.update(state, bit);
108 let new_state = StateTable::next(state, bit);
109 self.context_map.set(self.last_cm_hash, new_state);
110
111 self.c0 = (self.c0 << 1) | bit as u32;
113 self.bpos += 1;
114
115 if self.bpos >= 8 {
116 let byte = (self.c0 & 0xFF) as u8;
117
118 let idx = self.ctx_hash as usize & PRED_TABLE_MASK;
120 let current_pred = self.prediction_table[idx];
121 let current_count = self.count_table[idx];
122
123 if byte == current_pred {
124 self.count_table[idx] = current_count.saturating_add(1);
125 } else if current_count < 2 {
126 self.prediction_table[idx] = byte;
127 self.count_table[idx] = 1;
128 } else {
129 self.count_table[idx] = current_count.saturating_sub(1);
130 }
131
132 self.c3 = self.c2;
133 self.c2 = self.c1;
134 self.c1 = byte;
135 self.c0 = 1;
136 self.bpos = 0;
137 }
138 }
139}
140
141impl Default for IndirectModel {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147#[inline]
149fn indirect_hash(c1: u8, c2: u8, c3: u8) -> u32 {
150 let mut h = FNV_OFFSET;
151 h ^= c3 as u32;
152 h = h.wrapping_mul(FNV_PRIME);
153 h ^= c2 as u32;
154 h = h.wrapping_mul(FNV_PRIME);
155 h ^= c1 as u32;
156 h = h.wrapping_mul(FNV_PRIME);
157 h
158}
159
160#[inline]
163fn predicted_context_hash(predicted: u8, c0: u32) -> u32 {
164 let mut h = 0x9E3779B9u32; h ^= predicted as u32;
166 h = h.wrapping_mul(FNV_PRIME);
167 h ^= c0 & 0x1FF; h = h.wrapping_mul(FNV_PRIME);
169 h
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Bit of `byte` at position `bpos`, counted from the most significant bit.
    fn bit_at(byte: u8, bpos: u8) -> u8 {
        (byte >> (7 - bpos)) & 1
    }

    /// Partial-byte context: a leading 1 followed by the first `bpos` bits of
    /// `byte` (just `1` when `bpos` is 0).
    fn partial_byte(byte: u8, bpos: u8) -> u32 {
        (0..bpos).fold(1u32, |acc, seen| (acc << 1) | u32::from(bit_at(byte, seen)))
    }

    /// Runs `data` through `model` one bit at a time, passing the model's own
    /// `c1` as the previous byte.
    fn feed(model: &mut IndirectModel, data: &[u8]) {
        for &byte in data {
            for bpos in 0..8u8 {
                let prev_byte = model.c1;
                let _ = model.predict(partial_byte(byte, bpos), bpos, prev_byte);
                model.update(bit_at(byte, bpos));
            }
        }
    }

    #[test]
    fn initial_prediction_in_range() {
        let mut model = IndirectModel::new();
        let p = model.predict(1, 0, 0);
        assert!(
            (1..=4095).contains(&p),
            "initial prediction should be in valid range, got {p}"
        );
    }

    #[test]
    fn predictions_in_range() {
        let mut model = IndirectModel::new();
        for &byte in b"Hello, World! The quick brown fox." {
            for bpos in 0..8u8 {
                // The previous-byte argument is deliberately arbitrary here;
                // only the range of the output is being checked.
                let prev_byte = if bpos == 0 { byte.wrapping_sub(1) } else { byte };
                let p = model.predict(partial_byte(byte, bpos), bpos, prev_byte);
                assert!(
                    (1..=4095).contains(&p),
                    "prediction out of range at bpos {bpos}: {p}"
                );
                model.update(bit_at(byte, bpos));
            }
        }
    }

    #[test]
    fn prediction_table_updates() {
        let mut model = IndirectModel::new();
        feed(&mut model, b"abcdabcdabcd");

        // History "abc" (newest byte first in the argument list) must map to 'd'.
        let idx = indirect_hash(b'c', b'b', b'a') as usize & PRED_TABLE_MASK;
        assert_eq!(
            model.prediction_table[idx], b'd',
            "prediction table should predict 'd' after 'abc'"
        );
    }

    #[test]
    fn deterministic() {
        let mut m1 = IndirectModel::new();
        let mut m2 = IndirectModel::new();

        for &byte in b"test determinism of indirect model" {
            for bpos in 0..8u8 {
                let c0 = partial_byte(byte, bpos);
                let (prev1, prev2) = (m1.c1, m2.c1);
                let p1 = m1.predict(c0, bpos, prev1);
                let p2 = m2.predict(c0, bpos, prev2);
                assert_eq!(p1, p2, "models diverged at bpos {bpos}");

                let bit = bit_at(byte, bpos);
                m1.update(bit);
                m2.update(bit);
            }
        }
    }
}