1use crate::state::context_map::{AssociativeContextMap, ChecksumContextMap, ContextMap};
13use crate::state::state_map::StateMap;
14use crate::state::state_table::StateTable;
15
/// Per-context run tracker: remembers the last bit seen under each hash and
/// how many consecutive times that bit has repeated.
///
/// Each slot is a single byte: bit 7 holds the last bit observed and bits
/// 0..=6 hold the run length (saturating at 127). A run length of 0 marks a
/// slot that has never been updated. Slots are selected by `hash & mask` with
/// no tag check, so hashes that agree in their low bits share a slot.
struct RunMap {
    /// Packed `(run_bit << 7) | run_count` bytes; length is always a power of two.
    table: Vec<u8>,
    /// `table.len() - 1`; reduces a hash to a slot index.
    mask: usize,
}

impl RunMap {
    /// Largest run length the 7-bit counter can hold.
    const MAX_RUN: u8 = 0x7F;
    /// Prediction shift contributed by each step of a run.
    const STEP: u32 = 128;
    /// Upper bound on the total shift away from the neutral prediction.
    const MAX_STRENGTH: u32 = 1800;
    /// Neutral prediction (no preference for a 0 or a 1 bit).
    const NEUTRAL: u32 = 2048;

    /// Creates a map with at least `size` slots, all empty.
    ///
    /// `size` is rounded up to a power of two (minimum 1) so that the
    /// `hash & mask` indexing reaches every slot. Without this, a `size` of 0
    /// would underflow when computing the mask and a non-power-of-two size
    /// would leave some slots unreachable.
    fn new(size: usize) -> Self {
        let size = size.max(1).next_power_of_two();
        RunMap {
            table: vec![0u8; size],
            mask: size - 1,
        }
    }

    /// Returns `(run_count, run_bit)` for the slot that `hash` maps to.
    #[inline(always)]
    fn get(&self, hash: u32) -> (u8, u8) {
        let packed = self.table[hash as usize & self.mask];
        (packed & Self::MAX_RUN, packed >> 7)
    }

    /// Records `bit` (expected to be 0 or 1) in the slot that `hash` maps to.
    ///
    /// Repeating the slot's current bit extends the run, capped at
    /// `MAX_RUN`. A different bit, or a first write to an empty slot, starts
    /// a new run of length 1.
    #[inline(always)]
    fn update(&mut self, hash: u32, bit: u8) {
        let idx = hash as usize & self.mask;
        let packed = self.table[idx];
        let run_bit = packed >> 7;
        let run_count = packed & Self::MAX_RUN;

        let new_count = if bit == run_bit && run_count > 0 {
            // run_count <= 127, so the increment cannot overflow a u8.
            (run_count + 1).min(Self::MAX_RUN)
        } else {
            1
        };
        self.table[idx] = (bit << 7) | new_count;
    }

    /// Run-based prediction for the next bit under `hash`, in `1..=4095`.
    ///
    /// An empty slot yields the neutral value 2048. Otherwise each step of the
    /// current run moves the prediction `STEP` toward the run's bit, with the
    /// total shift capped at `MAX_STRENGTH`. A run of 1s raises the value
    /// above 2048; a run of 0s lowers it.
    #[inline(always)]
    fn predict_p(&self, hash: u32) -> u32 {
        let (run_count, run_bit) = self.get(hash);
        if run_count == 0 {
            return Self::NEUTRAL;
        }
        let strength = (u32::from(run_count) * Self::STEP).min(Self::MAX_STRENGTH);
        if run_bit == 1 {
            (Self::NEUTRAL + strength).min(4095)
        } else {
            Self::NEUTRAL.saturating_sub(strength).max(1)
        }
    }
}
78
79pub type DualPrediction = (u32, u32);
84
85pub struct ContextModel {
88 cmap: ContextMap,
90 smap: StateMap,
92 run_map: RunMap,
94 last_state: u8,
96 last_hash: u32,
98}
99
100impl ContextModel {
101 pub fn new(cmap_size: usize) -> Self {
103 let aux_size = (cmap_size / 4).next_power_of_two().max(1024);
104 ContextModel {
105 cmap: ContextMap::new(cmap_size),
106 smap: StateMap::new(),
107 run_map: RunMap::new(aux_size),
108 last_state: 0,
109 last_hash: 0,
110 }
111 }
112
113 #[inline(always)]
116 pub fn predict(&mut self, hash: u32) -> u32 {
117 let state = self.cmap.get(hash);
118 self.last_state = state;
119 self.last_hash = hash;
120 self.smap.predict(state)
121 }
122
123 #[inline(always)]
125 pub fn predict_multi(&mut self, hash: u32) -> DualPrediction {
126 let state = self.cmap.get(hash);
127 self.last_state = state;
128 self.last_hash = hash;
129 let state_p = self.smap.predict(state);
130 let run_p = self.run_map.predict_p(hash);
131 (state_p, run_p)
132 }
133
134 #[inline(always)]
137 pub fn update(&mut self, bit: u8) {
138 self.smap.update(self.last_state, bit);
139 let new_state = StateTable::next(self.last_state, bit);
140 self.cmap.set(self.last_hash, new_state);
141 self.run_map.update(self.last_hash, bit);
142 }
143
144 #[inline(always)]
146 pub fn on_byte_complete(&mut self, _byte: u8) {}
147}
148
149pub struct ChecksumContextModel {
151 cmap: ChecksumContextMap,
152 smap: StateMap,
153 run_map: RunMap,
154 last_state: u8,
155 last_hash: u32,
156}
157
158impl ChecksumContextModel {
159 pub fn new(byte_size: usize) -> Self {
160 let aux_size = (byte_size / 4).next_power_of_two().max(1024);
161 ChecksumContextModel {
162 cmap: ChecksumContextMap::new(byte_size),
163 smap: StateMap::new(),
164 run_map: RunMap::new(aux_size),
165 last_state: 0,
166 last_hash: 0,
167 }
168 }
169
170 #[inline(always)]
171 pub fn predict(&mut self, hash: u32) -> u32 {
172 let state = self.cmap.get(hash);
173 self.last_state = state;
174 self.last_hash = hash;
175 self.smap.predict(state)
176 }
177
178 #[inline(always)]
179 pub fn predict_multi(&mut self, hash: u32) -> DualPrediction {
180 let state = self.cmap.get(hash);
181 self.last_state = state;
182 self.last_hash = hash;
183 let state_p = self.smap.predict(state);
184 let run_p = self.run_map.predict_p(hash);
185 (state_p, run_p)
186 }
187
188 #[inline(always)]
189 pub fn update(&mut self, bit: u8) {
190 self.smap.update(self.last_state, bit);
191 let new_state = StateTable::next(self.last_state, bit);
192 self.cmap.set(self.last_hash, new_state);
193 self.run_map.update(self.last_hash, bit);
194 }
195
196 #[inline(always)]
197 pub fn on_byte_complete(&mut self, _byte: u8) {}
198}
199
200pub struct AssociativeContextModel {
203 cmap: AssociativeContextMap,
204 smap: StateMap,
205 run_map: RunMap,
206 last_state: u8,
207 last_hash: u32,
208}
209
210impl AssociativeContextModel {
211 pub fn new(byte_size: usize) -> Self {
212 let aux_size = (byte_size / 4).next_power_of_two().max(1024);
213 AssociativeContextModel {
214 cmap: AssociativeContextMap::new(byte_size),
215 smap: StateMap::new(),
216 run_map: RunMap::new(aux_size),
217 last_state: 0,
218 last_hash: 0,
219 }
220 }
221
222 #[inline(always)]
223 pub fn predict(&mut self, hash: u32) -> u32 {
224 let state = self.cmap.get(hash);
225 self.last_state = state;
226 self.last_hash = hash;
227 self.smap.predict(state)
228 }
229
230 #[inline(always)]
231 pub fn predict_multi(&mut self, hash: u32) -> DualPrediction {
232 let state = self.cmap.get(hash);
233 self.last_state = state;
234 self.last_hash = hash;
235 let state_p = self.smap.predict(state);
236 let run_p = self.run_map.predict_p(hash);
237 (state_p, run_p)
238 }
239
240 #[inline(always)]
241 pub fn update(&mut self, bit: u8) {
242 self.smap.update(self.last_state, bit);
243 let new_state = StateTable::next(self.last_state, bit);
244 self.cmap.set(self.last_hash, new_state);
245 self.run_map.update(self.last_hash, bit);
246 }
247
248 #[inline(always)]
249 pub fn on_byte_complete(&mut self, _byte: u8) {}
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------
    // RunMap
    // ---------------------------------------------------------------

    #[test]
    fn run_map_empty_slot_is_neutral() {
        let rm = RunMap::new(1024);
        assert_eq!(rm.get(7), (0, 0));
        assert_eq!(rm.predict_p(7), 2048);
    }

    #[test]
    fn run_map_run_grows_and_saturates() {
        let mut rm = RunMap::new(1024);
        rm.update(7, 1);
        assert_eq!(rm.get(7), (1, 1));
        assert_eq!(rm.predict_p(7), 2048 + 128);
        for _ in 0..300 {
            rm.update(7, 1);
        }
        // The 7-bit counter saturates at 127, and the prediction shift is
        // capped at 1800.
        assert_eq!(rm.get(7), (127, 1));
        assert_eq!(rm.predict_p(7), 2048 + 1800);
    }

    #[test]
    fn run_map_flip_resets_run() {
        let mut rm = RunMap::new(1024);
        for _ in 0..20 {
            rm.update(7, 1);
        }
        rm.update(7, 0);
        assert_eq!(rm.get(7), (1, 0));
        assert_eq!(rm.predict_p(7), 2048 - 128);
        for _ in 0..50 {
            rm.update(7, 0);
        }
        assert_eq!(rm.predict_p(7), 2048 - 1800);
    }

    #[test]
    fn run_map_aliases_by_low_bits() {
        let mut rm = RunMap::new(1024);
        rm.update(5, 1);
        // 5 and 5 + 1024 share a slot in a 1024-slot table; 6 does not.
        assert_eq!(rm.get(5 + 1024), (1, 1));
        assert_eq!(rm.get(6), (0, 0));
    }

    // ---------------------------------------------------------------
    // ContextModel
    // ---------------------------------------------------------------

    #[test]
    fn initial_prediction_balanced() {
        let mut cm = ContextModel::new(1024);
        let p = cm.predict(0);
        assert_eq!(p, 2048);
    }

    #[test]
    fn predict_update_changes_probability() {
        let mut cm = ContextModel::new(1024);
        let p1 = cm.predict(42);
        cm.update(1);
        let p2 = cm.predict(42);
        assert_ne!(p1, p2, "update should change prediction");
    }

    #[test]
    fn different_contexts_diverge() {
        let mut cm = ContextModel::new(1024);
        for _ in 0..20 {
            cm.predict(10);
            cm.update(1);
        }
        for _ in 0..20 {
            cm.predict(20);
            cm.update(0);
        }
        let p10 = cm.predict(10);
        let p20 = cm.predict(20);
        assert!(
            p10 > p20,
            "ctx 10 (all 1s) should predict higher than ctx 20 (all 0s): p10={p10}, p20={p20}"
        );
    }

    #[test]
    fn predictions_in_range() {
        let mut cm = ContextModel::new(1024);
        for i in 0..100u32 {
            let p = cm.predict(i);
            assert!((1..=4095).contains(&p));
            cm.update((i & 1) as u8);
        }
    }

    #[test]
    fn multi_predict_returns_pair() {
        let mut cm = ContextModel::new(1024);
        let (sp, rp) = cm.predict_multi(42);
        assert_eq!(sp, 2048);
        assert_eq!(rp, 2048);
    }

    #[test]
    fn run_prediction_adapts() {
        let mut cm = ContextModel::new(1024);
        for _ in 0..10 {
            cm.predict_multi(42);
            cm.update(1);
        }
        let (_, rp) = cm.predict_multi(42);
        assert!(
            rp > 2048,
            "run prediction should favor 1 after many 1s: {rp}"
        );
    }

    #[test]
    fn run_prediction_favors_zero_after_zeros() {
        let mut cm = ContextModel::new(1024);
        for _ in 0..10 {
            cm.predict_multi(42);
            cm.update(0);
        }
        let (_, rp) = cm.predict_multi(42);
        assert!(
            rp < 2048,
            "run prediction should favor 0 after many 0s: {rp}"
        );
    }

    #[test]
    fn dual_predictions_in_range() {
        let mut cm = ContextModel::new(1024);
        let (sp, rp) = cm.predict_multi(42);
        assert!((1..=4095).contains(&sp));
        assert!((1..=4095).contains(&rp));
    }

    // ---------------------------------------------------------------
    // ChecksumContextModel
    // ---------------------------------------------------------------

    #[test]
    fn checksum_initial_prediction_balanced() {
        let mut cm = ChecksumContextModel::new(2048);
        let p = cm.predict(0);
        assert_eq!(p, 2048);
    }

    #[test]
    fn checksum_predict_update() {
        let mut cm = ChecksumContextModel::new(2048);
        let p1 = cm.predict(42);
        cm.update(1);
        let p2 = cm.predict(42);
        assert_ne!(p1, p2, "update should change prediction");
    }

    #[test]
    fn checksum_predictions_in_range() {
        let mut cm = ChecksumContextModel::new(2048);
        for i in 0..100u32 {
            let p = cm.predict(i);
            assert!((1..=4095).contains(&p));
            cm.update((i & 1) as u8);
        }
    }

    #[test]
    fn checksum_multi_predict() {
        let mut cm = ChecksumContextModel::new(2048);
        let (sp, rp) = cm.predict_multi(42);
        assert_eq!(sp, 2048);
        assert_eq!(rp, 2048);
    }

    #[test]
    fn checksum_run_prediction_adapts() {
        let mut cm = ChecksumContextModel::new(2048);
        for _ in 0..10 {
            cm.predict_multi(42);
            cm.update(1);
        }
        let (_, rp) = cm.predict_multi(42);
        assert!(rp > 2048, "run prediction should favor 1 after many 1s: {rp}");
    }

    // ---------------------------------------------------------------
    // AssociativeContextModel
    // ---------------------------------------------------------------

    #[test]
    fn assoc_multi_predict() {
        let mut cm = AssociativeContextModel::new(4096);
        let (sp, rp) = cm.predict_multi(42);
        assert_eq!(sp, 2048);
        assert_eq!(rp, 2048);
    }

    #[test]
    fn assoc_run_prediction_adapts() {
        let mut cm = AssociativeContextModel::new(4096);
        for _ in 0..10 {
            cm.predict_multi(42);
            cm.update(1);
        }
        let (_, rp) = cm.predict_multi(42);
        assert!(rp > 2048, "run prediction should favor 1 after many 1s: {rp}");
    }
}