/// Highest context order modelled: predictions use between 0 and `MAX_ORDER`
/// preceding bytes.
const MAX_ORDER: usize = 12;

/// Maximum number of distinct symbols a single context entry can track.
const MAX_SYMS: usize = 48;

/// 32-bit FNV offset basis; the seed of the context hash.
const FNV_OFFSET: u32 = 0x811c_9dc5;

/// 32-bit FNV prime; the per-byte multiplier of the context hash.
const FNV_PRIME: u32 = 0x0100_0193;

/// Number of hash tables: one for every order in `0..=MAX_ORDER`.
const NUM_ORDERS: usize = 1 + MAX_ORDER;
32
33#[derive(Clone, Copy)]
35struct PpmEntry {
36 checksum: u16,
38 syms: [u8; MAX_SYMS],
40 counts: [u16; MAX_SYMS],
42 len: u8,
44 total: u16,
46}
47
48impl PpmEntry {
49 const EMPTY: Self = PpmEntry {
50 checksum: 0,
51 syms: [0; MAX_SYMS],
52 counts: [0; MAX_SYMS],
53 len: 0,
54 total: 0,
55 };
56
57 #[inline]
58 fn increment(&mut self, symbol: u8) {
59 let n = self.len as usize;
60 for i in 0..n {
61 if self.syms[i] == symbol {
62 self.counts[i] = self.counts[i].saturating_add(1);
63 self.total = self.total.saturating_add(1);
64 return;
65 }
66 }
67 if n < MAX_SYMS {
68 self.syms[n] = symbol;
69 self.counts[n] = 1;
70 self.len += 1;
71 self.total = self.total.saturating_add(1);
72 }
73 }
74
75 fn halve(&mut self) {
76 let mut write = 0usize;
77 let mut new_total: u16 = 0;
78 for read in 0..self.len as usize {
79 let c = self.counts[read] >> 1;
80 if c > 0 {
81 self.syms[write] = self.syms[read];
82 self.counts[write] = c;
83 new_total = new_total.saturating_add(c);
84 write += 1;
85 }
86 }
87 self.len = write as u8;
88 self.total = new_total;
89 }
90}
91
92#[derive(Debug, Clone)]
94pub struct PpmConfig {
95 pub sizes: [usize; NUM_ORDERS],
97}
98
99impl PpmConfig {
100 pub fn default_sizes() -> Self {
102 PpmConfig {
103 sizes: [
104 1, 1 << 8, 1 << 16, 1 << 18, 1 << 19, 1 << 19, 1 << 19, 1 << 18, 1 << 18, 1 << 17, 1 << 17, 1 << 16, 1 << 16, ],
118 }
119 }
120
121 pub fn scaled_4x() -> Self {
123 PpmConfig {
124 sizes: [
125 1, 1 << 8, 1 << 16, 1 << 20, 1 << 21, 1 << 21, 1 << 21, 1 << 20, 1 << 20, 1 << 19, 1 << 19, 1 << 18, 1 << 18, ],
139 }
140 }
141}
142
143pub struct PpmModel {
153 tables: Vec<Box<[PpmEntry]>>,
155 masks: [usize; NUM_ORDERS],
157
158 byte_probs: [u32; 256],
160 probs_valid: bool,
162 context: [u8; MAX_ORDER],
164 bytes_seen: usize,
166}
167
168fn make_table(size: usize) -> Box<[PpmEntry]> {
169 vec![PpmEntry::EMPTY; size].into_boxed_slice()
170}
171
172impl PpmModel {
173 pub fn new() -> Self {
175 Self::with_config(PpmConfig::default_sizes())
176 }
177
178 pub fn with_config(config: PpmConfig) -> Self {
180 let mut tables = Vec::with_capacity(NUM_ORDERS);
181 let mut masks = [0usize; NUM_ORDERS];
182 for (i, &size) in config.sizes.iter().enumerate() {
183 tables.push(make_table(size));
184 masks[i] = size - 1;
185 }
186
187 PpmModel {
188 tables,
189 masks,
190 byte_probs: [0u32; 256],
191 probs_valid: false,
192 context: [0u8; MAX_ORDER],
193 bytes_seen: 0,
194 }
195 }
196
197 #[inline]
199 pub fn predict_bit(&mut self, bpos: u8, c0: u32) -> u32 {
200 if !self.probs_valid {
201 self.compute_byte_probs();
202 self.probs_valid = true;
203 }
204
205 let bit_pos = 7 - bpos;
206 let mask = 1u8 << bit_pos;
207
208 let mut sum_one: u64 = 0;
209 let mut sum_zero: u64 = 0;
210
211 if bpos == 0 {
212 for b in 0..256usize {
213 let p = self.byte_probs[b] as u64;
214 if (b as u8) & mask != 0 {
215 sum_one += p;
216 } else {
217 sum_zero += p;
218 }
219 }
220 } else {
221 let partial = (c0 & ((1u32 << bpos) - 1)) as u8;
222 let shift = 8 - bpos;
223 let base = (partial as usize) << shift;
224 let count = 1usize << shift;
225
226 for i in 0..count {
227 let b = base | i;
228 let p = self.byte_probs[b] as u64;
229 if (b as u8) & mask != 0 {
230 sum_one += p;
231 } else {
232 sum_zero += p;
233 }
234 }
235 }
236
237 let total = sum_one + sum_zero;
238 if total == 0 {
239 return 2048;
240 }
241
242 let p = ((sum_one << 12) / total) as u32;
243 p.clamp(1, 4095)
244 }
245
246 #[inline]
248 pub fn update_byte(&mut self, byte: u8) {
249 let max_usable_order = self.bytes_seen.min(MAX_ORDER);
250
251 for order in 0..=max_usable_order {
252 let (hash, chk) = self.context_hash_and_checksum(order);
253 let idx = hash as usize & self.masks[order];
254 let entry = &mut self.tables[order][idx];
255
256 if entry.checksum == 0 || entry.checksum == chk {
257 entry.checksum = chk;
258 entry.increment(byte);
259 if entry.total > 4000 {
260 entry.halve();
261 }
262 } else {
263 if entry.total < 4 {
265 *entry = PpmEntry::EMPTY;
266 entry.checksum = chk;
267 entry.increment(byte);
268 }
269 }
270 }
271
272 for i in (1..MAX_ORDER).rev() {
274 self.context[i] = self.context[i - 1];
275 }
276 self.context[0] = byte;
277 self.bytes_seen += 1;
278 self.probs_valid = false;
279 }
280
281 fn compute_byte_probs(&mut self) {
283 let max_usable_order = self.bytes_seen.min(MAX_ORDER);
284
285 let mut excluded = [false; 256];
286 let mut probs = [0u64; 256];
287 let mut remaining_mass: u64 = 1 << 20;
288
289 for order in (0..=max_usable_order).rev() {
291 let (hash, chk) = self.context_hash_and_checksum(order);
292 let idx = hash as usize & self.masks[order];
293 let entry = &self.tables[order][idx];
294
295 if entry.checksum != chk || entry.total == 0 || entry.len == 0 {
297 continue;
298 }
299
300 let mut effective_total: u32 = 0;
301 let mut effective_distinct: u32 = 0;
302
303 let n = entry.len as usize;
304 for i in 0..n {
305 if !excluded[entry.syms[i] as usize] {
306 effective_total += entry.counts[i] as u32;
307 effective_distinct += 1;
308 }
309 }
310
311 if effective_total == 0 || effective_distinct == 0 {
312 continue;
313 }
314
315 let escape_d = effective_distinct.div_ceil(2);
317 let denominator = effective_total + escape_d;
318
319 let symbol_mass = (remaining_mass * effective_total as u64) / denominator as u64;
320 let escape_frac = remaining_mass - symbol_mass;
321
322 for i in 0..n {
323 let sym = entry.syms[i];
324 if !excluded[sym as usize] {
325 let sym_prob = (symbol_mass * entry.counts[i] as u64) / effective_total as u64;
326 probs[sym as usize] += sym_prob;
327 excluded[sym as usize] = true;
328 }
329 }
330
331 remaining_mass = escape_frac;
332 if remaining_mass == 0 {
333 break;
334 }
335 }
336
337 if remaining_mass > 0 {
339 let mut unseen: u32 = 0;
340 for e in &excluded {
341 if !e {
342 unseen += 1;
343 }
344 }
345 if unseen > 0 {
346 let per_sym = remaining_mass / unseen as u64;
347 let mut leftover = remaining_mass - per_sym * unseen as u64;
348 for i in 0..256 {
349 if !excluded[i] {
350 probs[i] += per_sym;
351 if leftover > 0 {
352 probs[i] += 1;
353 leftover -= 1;
354 }
355 }
356 }
357 }
358 }
359
360 for (i, &p) in probs.iter().enumerate() {
361 self.byte_probs[i] = p as u32;
362 }
363 }
364
365 #[inline]
367 fn context_hash_and_checksum(&self, order: usize) -> (u32, u16) {
368 if order == 0 {
369 return (0, 1);
371 }
372 let mut h = FNV_OFFSET;
373 for i in 0..order {
374 h ^= self.context[i] as u32;
375 h = h.wrapping_mul(FNV_PRIME);
376 }
377 let chk = ((h >> 16) as u16) | 1; (h, chk)
379 }
380}
381
382impl Default for PpmModel {
383 fn default() -> Self {
384 Self::new()
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `c0` value passed to `predict_bit`: a leading 1 marker
    /// followed by the `bpos` most significant bits of `byte`.
    fn partial_byte(byte: u8, bpos: u8) -> u32 {
        (0..bpos).fold(1u32, |acc, k| (acc << 1) | u32::from((byte >> (7 - k)) & 1))
    }

    /// Feeds every byte of `bytes` to `model`, in order.
    fn feed(model: &mut PpmModel, bytes: &[u8]) {
        for &b in bytes {
            model.update_byte(b);
        }
    }

    /// A fresh model has no statistics, so the first bit is close to 50/50.
    #[test]
    fn initial_prediction_balanced() {
        let mut model = PpmModel::new();
        let p = model.predict_bit(0, 1);
        assert!(
            (1900..=2100).contains(&p),
            "initial prediction should be near 2048, got {p}"
        );
    }

    /// Every bit prediction stays inside the 12-bit open range.
    #[test]
    fn prediction_always_in_range() {
        let mut model = PpmModel::new();
        let data = b"Hello, World! This is a test of the PPM model for prediction.";
        for &byte in data {
            for bpos in 0..8u8 {
                let p = model.predict_bit(bpos, partial_byte(byte, bpos));
                assert!(
                    (1..=4095).contains(&p),
                    "prediction out of range at bpos {bpos}: {p}"
                );
            }
            model.update_byte(byte);
        }
    }

    /// After a long run of 'A' (top bit clear) the top bit is predicted as 0.
    #[test]
    fn adapts_to_repeated_bytes() {
        let mut model = PpmModel::new();
        feed(&mut model, &[b'A'; 100]);
        let p = model.predict_bit(0, 1);
        assert!(
            p < 1500,
            "after 100 'A' bytes, P(bit7=1) should be low, got {p}"
        );
    }

    /// A repeating 8-byte pattern makes the next byte of the cycle likely.
    #[test]
    fn adapts_to_repeated_pattern() {
        let mut model = PpmModel::new();
        for _ in 0..200 {
            feed(&mut model, b"abcdefgh");
        }
        feed(&mut model, b"abcdefg");
        model.compute_byte_probs();
        let p_h = model.byte_probs[usize::from(b'h')];
        assert!(
            p_h > 100_000,
            "after 'abcdefg', P('h') should be significant, got {p_h} / 1048576"
        );
    }

    /// The byte distribution adds up to roughly the 2^20 scale.
    #[test]
    fn byte_probs_sum_correctly() {
        let mut model = PpmModel::new();
        feed(
            &mut model,
            b"the quick brown fox jumps over the lazy dog the cat sat on the mat",
        );
        model.compute_byte_probs();
        let total: u64 = model.byte_probs.iter().copied().map(u64::from).sum();
        assert!(
            (1_000_000..=1_100_000).contains(&total),
            "byte_probs should sum to ~1M, got {total}"
        );
    }

    /// In an alternating "abab..." stream, 'b' dominates right after 'a'.
    #[test]
    fn exclusion_works() {
        let mut model = PpmModel::new();
        for _ in 0..100 {
            feed(&mut model, b"ab");
        }
        model.update_byte(b'a');
        model.compute_byte_probs();
        let p_b = model.byte_probs[usize::from(b'b')];
        let p_a = model.byte_probs[usize::from(b'a')];
        assert!(
            p_b > p_a * 2,
            "after 'a', P('b')={p_b} should be >> P('a')={p_a}"
        );
    }

    /// Two models fed the same input give the same predictions throughout.
    #[test]
    fn deterministic() {
        let data = b"test determinism of ppm model with enough context abcabc";
        let mut m1 = PpmModel::new();
        let mut m2 = PpmModel::new();

        for &byte in data.iter() {
            for bpos in 0..8u8 {
                let c0 = partial_byte(byte, bpos);
                let p1 = m1.predict_bit(bpos, c0);
                let p2 = m2.predict_bit(bpos, c0);
                assert_eq!(p1, p2, "models diverged at bpos {bpos}");
            }
            m1.update_byte(byte);
            m2.update_byte(byte);
        }
    }

    /// Measures the model's own coding cost on the first 10 KB of alice29.
    #[test]
    fn solo_bpb_alice29_prefix() {
        let data = include_bytes!("../../../../corpus/alice29.txt");
        let prefix = &data[..data.len().min(10_000)];

        let mut model = PpmModel::new();
        let mut total_bits: f64 = 0.0;

        for &byte in prefix {
            let mut c0 = 1u32;
            for bpos in 0..8u8 {
                let p = model.predict_bit(bpos, c0);
                let bit = (byte >> (7 - bpos)) & 1;
                let p_one = p as f64 / 4096.0;
                let prob_of_bit = if bit == 1 { p_one } else { 1.0 - p_one };
                // Ideal code length of this bit under the prediction.
                total_bits -= prob_of_bit.max(1e-9).log2();
                c0 = (c0 << 1) | u32::from(bit);
            }
            model.update_byte(byte);
        }

        let bpb = total_bits / prefix.len() as f64;
        eprintln!("PPM solo bpb on 10KB alice29 (orders 0-{MAX_ORDER}): {bpb:.3}");
        assert!(bpb < 6.0, "PPM solo bpb too high: {bpb:.3}");
    }

    /// Counting then halving: 'a' x2 survives with count 1, 'b' x1 is dropped.
    #[test]
    fn ppm_entry_increment_and_halve() {
        let mut entry = PpmEntry::EMPTY;
        entry.checksum = 1;
        for &sym in b"aab" {
            entry.increment(sym);
        }
        assert_eq!(entry.len, 2);
        assert_eq!(entry.total, 3);

        entry.halve();
        assert_eq!(entry.len, 1);
        assert_eq!(entry.counts[0], 1);
        assert_eq!(entry.total, 1);
    }
}