Skip to main content

datacortex_core/model/
ppm_model.rs

1//! PPM (Prediction by Partial Matching) byte-level predictor — PPMd variant.
2//!
3//! A fundamentally DIFFERENT prediction paradigm from CM:
4//! - CM: hash-based, lossy collisions, bit-level, fixed context orders
5//! - PPM: byte-level, adaptive order with escape/exclusion, checksum-validated
6//!
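//! Full-range PPM with orders 0-12. Lower orders (0-3) provide a good fallback
//! on small files. Higher orders (10-12) go BEYOND what any CM model covers,
//! providing unique prediction signal on large files. Every order tags its
//! hash-table slots with a 16-bit context checksum, so most hash collisions are
//! detected (the mismatching slot is skipped) instead of silently mixing the
//! statistics of unrelated contexts.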
11//!
//! The Method-D-style escape estimate (for `d` distinct symbols with summed
//! count `n`, the escape weight is `ceil(d / 2)` against `n`) combined with full
//! exclusion gives a different error profile from CM's bit-level hash models.
14//!
15//! CRITICAL: PPM updates at BYTE level, not bit level. Only update after all
16//! 8 bits decoded. Byte probabilities are cached and converted to bit
17//! predictions on each bit.
18
/// Highest context order modelled: predictions use up to 12 preceding bytes.
const MAX_ORDER: usize = 12;

/// Maximum number of distinct symbols one context slot can track.
///
/// `PpmEntry::len` is a `u8`, so this must stay at or below 255; the
/// compile-time check below enforces that.
const MAX_SYMS: usize = 48;

/// 32-bit FNV-1a offset basis (2166136261).
const FNV_OFFSET: u32 = 0x811C_9DC5;
/// 32-bit FNV-1a prime (16777619).
const FNV_PRIME: u32 = 0x0100_0193;

/// Number of context orders, 0 through `MAX_ORDER` inclusive (13).
const NUM_ORDERS: usize = MAX_ORDER + 1;

// Compile-time guard: per-slot symbol counts are stored in a `u8` length field,
// so a larger MAX_SYMS would overflow `len` instead of failing loudly.
const _: () = assert!(
    MAX_SYMS <= u8::MAX as usize,
    "MAX_SYMS must fit in PpmEntry::len (u8)"
);
32
33/// A flat PPM entry with checksum validation.
34#[derive(Clone, Copy)]
35struct PpmEntry {
36    /// Context checksum (upper 16 bits of hash). 0 = empty slot.
37    checksum: u16,
38    /// Symbols observed in this context.
39    syms: [u8; MAX_SYMS],
40    /// Counts for each symbol.
41    counts: [u16; MAX_SYMS],
42    /// Number of distinct symbols stored.
43    len: u8,
44    /// Sum of all counts.
45    total: u16,
46}
47
48impl PpmEntry {
49    const EMPTY: Self = PpmEntry {
50        checksum: 0,
51        syms: [0; MAX_SYMS],
52        counts: [0; MAX_SYMS],
53        len: 0,
54        total: 0,
55    };
56
57    #[inline]
58    fn increment(&mut self, symbol: u8) {
59        let n = self.len as usize;
60        for i in 0..n {
61            if self.syms[i] == symbol {
62                self.counts[i] = self.counts[i].saturating_add(1);
63                self.total = self.total.saturating_add(1);
64                return;
65            }
66        }
67        if n < MAX_SYMS {
68            self.syms[n] = symbol;
69            self.counts[n] = 1;
70            self.len += 1;
71            self.total = self.total.saturating_add(1);
72        }
73    }
74
75    fn halve(&mut self) {
76        let mut write = 0usize;
77        let mut new_total: u16 = 0;
78        for read in 0..self.len as usize {
79            let c = self.counts[read] >> 1;
80            if c > 0 {
81                self.syms[write] = self.syms[read];
82                self.counts[write] = c;
83                new_total = new_total.saturating_add(c);
84                write += 1;
85            }
86        }
87        self.len = write as u8;
88        self.total = new_total;
89    }
90}
91
92/// PPM table sizes configuration.
93#[derive(Debug, Clone)]
94pub struct PpmConfig {
95    /// Table sizes per order (0..=MAX_ORDER). Each must be a power of 2.
96    pub sizes: [usize; NUM_ORDERS],
97}
98
99impl PpmConfig {
100    /// Default (~90MB): original sizes.
101    pub fn default_sizes() -> Self {
102        PpmConfig {
103            sizes: [
104                1,       // order 0:  1 entry (unigram)
105                1 << 8,  // order 1:  256 entries
106                1 << 16, // order 2:  64K entries
107                1 << 18, // order 3:  256K entries
108                1 << 19, // order 4:  512K entries
109                1 << 19, // order 5:  512K entries
110                1 << 19, // order 6:  512K entries
111                1 << 18, // order 7:  256K entries
112                1 << 18, // order 8:  256K entries
113                1 << 17, // order 9:  128K entries
114                1 << 17, // order 10: 128K entries
115                1 << 16, // order 11: 64K entries
116                1 << 16, // order 12: 64K entries
117            ],
118        }
119    }
120
121    /// Scaled 4x (~360MB): 4x entries at orders 3-12 for fewer collisions.
122    pub fn scaled_4x() -> Self {
123        PpmConfig {
124            sizes: [
125                1,       // order 0:  1 entry (unigram)
126                1 << 8,  // order 1:  256 entries
127                1 << 16, // order 2:  64K entries
128                1 << 20, // order 3:  1M entries (was 256K)
129                1 << 21, // order 4:  2M entries (was 512K)
130                1 << 21, // order 5:  2M entries (was 512K)
131                1 << 21, // order 6:  2M entries (was 512K)
132                1 << 20, // order 7:  1M entries (was 256K)
133                1 << 20, // order 8:  1M entries (was 256K)
134                1 << 19, // order 9:  512K entries (was 128K)
135                1 << 19, // order 10: 512K entries (was 128K)
136                1 << 18, // order 11: 256K entries (was 64K)
137                1 << 18, // order 12: 256K entries (was 64K)
138            ],
139        }
140    }
141}
142
143/// PPM model with checksum-validated hash tables at orders 0-12.
144///
145/// Memory budget depends on config:
146/// - Default: ~90MB total
147/// - Scaled 4x: ~360MB total
148/// - Order 0: 1 entry (global unigram)
149/// - Order 1: 256 entries
150/// - Order 2: 64K entries
151/// - Orders 3+: configurable
152pub struct PpmModel {
153    /// Hash tables for orders 0..=MAX_ORDER.
154    tables: Vec<Box<[PpmEntry]>>,
155    /// Table masks (size - 1) per order.
156    masks: [usize; NUM_ORDERS],
157
158    /// Cached byte probability distribution (256 entries, scaled to sum ~2^20).
159    byte_probs: [u32; 256],
160    /// Whether byte_probs has been computed.
161    probs_valid: bool,
162    /// Context bytes: last MAX_ORDER bytes. [0] = most recent.
163    context: [u8; MAX_ORDER],
164    /// Number of bytes seen so far.
165    bytes_seen: usize,
166}
167
168fn make_table(size: usize) -> Box<[PpmEntry]> {
169    vec![PpmEntry::EMPTY; size].into_boxed_slice()
170}
171
172impl PpmModel {
173    /// Create a new PPM model with default sizes (~90MB).
174    pub fn new() -> Self {
175        Self::with_config(PpmConfig::default_sizes())
176    }
177
178    /// Create a PPM model with the given configuration.
179    pub fn with_config(config: PpmConfig) -> Self {
180        let mut tables = Vec::with_capacity(NUM_ORDERS);
181        let mut masks = [0usize; NUM_ORDERS];
182        for (i, &size) in config.sizes.iter().enumerate() {
183            tables.push(make_table(size));
184            masks[i] = size - 1;
185        }
186
187        PpmModel {
188            tables,
189            masks,
190            byte_probs: [0u32; 256],
191            probs_valid: false,
192            context: [0u8; MAX_ORDER],
193            bytes_seen: 0,
194        }
195    }
196
197    /// Predict bit probability. Returns 12-bit probability [1, 4095].
198    #[inline]
199    pub fn predict_bit(&mut self, bpos: u8, c0: u32) -> u32 {
200        if !self.probs_valid {
201            self.compute_byte_probs();
202            self.probs_valid = true;
203        }
204
205        let bit_pos = 7 - bpos;
206        let mask = 1u8 << bit_pos;
207
208        let mut sum_one: u64 = 0;
209        let mut sum_zero: u64 = 0;
210
211        if bpos == 0 {
212            for b in 0..256usize {
213                let p = self.byte_probs[b] as u64;
214                if (b as u8) & mask != 0 {
215                    sum_one += p;
216                } else {
217                    sum_zero += p;
218                }
219            }
220        } else {
221            let partial = (c0 & ((1u32 << bpos) - 1)) as u8;
222            let shift = 8 - bpos;
223            let base = (partial as usize) << shift;
224            let count = 1usize << shift;
225
226            for i in 0..count {
227                let b = base | i;
228                let p = self.byte_probs[b] as u64;
229                if (b as u8) & mask != 0 {
230                    sum_one += p;
231                } else {
232                    sum_zero += p;
233                }
234            }
235        }
236
237        let total = sum_one + sum_zero;
238        if total == 0 {
239            return 2048;
240        }
241
242        let p = ((sum_one << 12) / total) as u32;
243        p.clamp(1, 4095)
244    }
245
246    /// Update PPM model after a full byte has been decoded.
247    #[inline]
248    pub fn update_byte(&mut self, byte: u8) {
249        let max_usable_order = self.bytes_seen.min(MAX_ORDER);
250
251        for order in 0..=max_usable_order {
252            let (hash, chk) = self.context_hash_and_checksum(order);
253            let idx = hash as usize & self.masks[order];
254            let entry = &mut self.tables[order][idx];
255
256            if entry.checksum == 0 || entry.checksum == chk {
257                entry.checksum = chk;
258                entry.increment(byte);
259                if entry.total > 4000 {
260                    entry.halve();
261                }
262            } else {
263                // Hash collision. Replace weak entries.
264                if entry.total < 4 {
265                    *entry = PpmEntry::EMPTY;
266                    entry.checksum = chk;
267                    entry.increment(byte);
268                }
269            }
270        }
271
272        // Shift context ring.
273        for i in (1..MAX_ORDER).rev() {
274            self.context[i] = self.context[i - 1];
275        }
276        self.context[0] = byte;
277        self.bytes_seen += 1;
278        self.probs_valid = false;
279    }
280
281    /// Compute byte probability distribution using PPMd Method D with exclusion.
282    fn compute_byte_probs(&mut self) {
283        let max_usable_order = self.bytes_seen.min(MAX_ORDER);
284
285        let mut excluded = [false; 256];
286        let mut probs = [0u64; 256];
287        let mut remaining_mass: u64 = 1 << 20;
288
289        // Scan from highest order down to 0.
290        for order in (0..=max_usable_order).rev() {
291            let (hash, chk) = self.context_hash_and_checksum(order);
292            let idx = hash as usize & self.masks[order];
293            let entry = &self.tables[order][idx];
294
295            // Skip if empty or checksum mismatch (hash collision).
296            if entry.checksum != chk || entry.total == 0 || entry.len == 0 {
297                continue;
298            }
299
300            let mut effective_total: u32 = 0;
301            let mut effective_distinct: u32 = 0;
302
303            let n = entry.len as usize;
304            for i in 0..n {
305                if !excluded[entry.syms[i] as usize] {
306                    effective_total += entry.counts[i] as u32;
307                    effective_distinct += 1;
308                }
309            }
310
311            if effective_total == 0 || effective_distinct == 0 {
312                continue;
313            }
314
315            // PPMd Method D escape.
316            let escape_d = effective_distinct.div_ceil(2);
317            let denominator = effective_total + escape_d;
318
319            let symbol_mass = (remaining_mass * effective_total as u64) / denominator as u64;
320            let escape_frac = remaining_mass - symbol_mass;
321
322            for i in 0..n {
323                let sym = entry.syms[i];
324                if !excluded[sym as usize] {
325                    let sym_prob = (symbol_mass * entry.counts[i] as u64) / effective_total as u64;
326                    probs[sym as usize] += sym_prob;
327                    excluded[sym as usize] = true;
328                }
329            }
330
331            remaining_mass = escape_frac;
332            if remaining_mass == 0 {
333                break;
334            }
335        }
336
337        // Order -1: uniform for remaining unseen symbols.
338        if remaining_mass > 0 {
339            let mut unseen: u32 = 0;
340            for e in &excluded {
341                if !e {
342                    unseen += 1;
343                }
344            }
345            if unseen > 0 {
346                let per_sym = remaining_mass / unseen as u64;
347                let mut leftover = remaining_mass - per_sym * unseen as u64;
348                for i in 0..256 {
349                    if !excluded[i] {
350                        probs[i] += per_sym;
351                        if leftover > 0 {
352                            probs[i] += 1;
353                            leftover -= 1;
354                        }
355                    }
356                }
357            }
358        }
359
360        for (i, &p) in probs.iter().enumerate() {
361            self.byte_probs[i] = p as u32;
362        }
363    }
364
365    /// Compute context hash and 16-bit checksum for a given order.
366    #[inline]
367    fn context_hash_and_checksum(&self, order: usize) -> (u32, u16) {
368        if order == 0 {
369            // Order 0: single context, fixed checksum.
370            return (0, 1);
371        }
372        let mut h = FNV_OFFSET;
373        for i in 0..order {
374            h ^= self.context[i] as u32;
375            h = h.wrapping_mul(FNV_PRIME);
376        }
377        let chk = ((h >> 16) as u16) | 1; // ensure non-zero
378        (h, chk)
379    }
380}
381
382impl Default for PpmModel {
383    fn default() -> Self {
384        Self::new()
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// The partial-byte context the model expects at `bpos`: a leading 1 bit
    /// followed by the first `bpos` bits of `byte`, most significant first.
    fn c0_before(byte: u8, bpos: u8) -> u32 {
        (0..bpos).fold(1u32, |acc, prev| {
            (acc << 1) | ((byte >> (7 - prev)) & 1) as u32
        })
    }

    #[test]
    fn initial_prediction_balanced() {
        let p = PpmModel::new().predict_bit(0, 1);
        assert!(
            (1900..=2100).contains(&p),
            "initial prediction should be near 2048, got {p}"
        );
    }

    #[test]
    fn prediction_always_in_range() {
        let data = b"Hello, World! This is a test of the PPM model for prediction.";
        let mut model = PpmModel::new();
        for &byte in data.iter() {
            for bpos in 0..8u8 {
                let p = model.predict_bit(bpos, c0_before(byte, bpos));
                assert!(
                    (1..=4095).contains(&p),
                    "prediction out of range at bpos {bpos}: {p}"
                );
            }
            model.update_byte(byte);
        }
    }

    #[test]
    fn adapts_to_repeated_bytes() {
        let mut model = PpmModel::new();
        let byte = b'A';
        (0..100).for_each(|_| model.update_byte(byte));
        // 'A' is 0x41: its top bit is 0, so P(bit7 = 1) should be small.
        let p = model.predict_bit(0, 1);
        assert!(
            p < 1500,
            "after 100 'A' bytes, P(bit7=1) should be low, got {p}"
        );
    }

    #[test]
    fn adapts_to_repeated_pattern() {
        let pattern = b"abcdefgh";
        let history = pattern
            .iter()
            .cycle()
            .take(200 * pattern.len())
            .chain(b"abcdefg".iter());
        let mut model = PpmModel::new();
        history.for_each(|&byte| model.update_byte(byte));
        model.compute_byte_probs();
        let p_h = model.byte_probs[b'h' as usize];
        assert!(
            p_h > 100_000,
            "after 'abcdefg', P('h') should be significant, got {p_h} / 1048576"
        );
    }

    #[test]
    fn byte_probs_sum_correctly() {
        let data = b"the quick brown fox jumps over the lazy dog the cat sat on the mat";
        let mut model = PpmModel::new();
        data.iter().for_each(|&byte| model.update_byte(byte));
        model.compute_byte_probs();
        let total = model
            .byte_probs
            .iter()
            .fold(0u64, |acc, &p| acc + p as u64);
        assert!(
            (1_000_000..=1_100_000).contains(&total),
            "byte_probs should sum to ~1M, got {total}"
        );
    }

    #[test]
    fn exclusion_works() {
        let mut model = PpmModel::new();
        for &byte in b"ab".iter().cycle().take(200) {
            model.update_byte(byte);
        }
        model.update_byte(b'a');
        model.compute_byte_probs();
        let p_a = model.byte_probs[b'a' as usize];
        let p_b = model.byte_probs[b'b' as usize];
        assert!(
            p_b > p_a * 2,
            "after 'a', P('b')={p_b} should be >> P('a')={p_a}"
        );
    }

    #[test]
    fn deterministic() {
        let data = b"test determinism of ppm model with enough context abcabc";
        let (mut m1, mut m2) = (PpmModel::new(), PpmModel::new());

        for &byte in data.iter() {
            for bpos in 0..8u8 {
                let c0 = c0_before(byte, bpos);
                let p1 = m1.predict_bit(bpos, c0);
                let p2 = m2.predict_bit(bpos, c0);
                assert_eq!(p1, p2, "models diverged at bpos {bpos}");
            }
            m1.update_byte(byte);
            m2.update_byte(byte);
        }
    }

    #[test]
    fn solo_bpb_alice29_prefix() {
        let data = include_bytes!("../../../../corpus/alice29.txt");
        let prefix = &data[..10_000.min(data.len())];

        let mut model = PpmModel::new();
        let mut total_bits: f64 = 0.0;

        for &byte in prefix {
            let mut c0 = 1u32;
            for bpos in 0..8u8 {
                let p = model.predict_bit(bpos, c0);
                let bit = (byte >> (7 - bpos)) & 1;
                let p_one = p as f64 / 4096.0;
                let prob_of_bit = if bit == 1 { p_one } else { 1.0 - p_one };
                total_bits -= prob_of_bit.max(1e-9).log2();
                c0 = (c0 << 1) | bit as u32;
            }
            model.update_byte(byte);
        }

        let bpb = total_bits / prefix.len() as f64;
        eprintln!("PPM solo bpb on 10KB alice29 (orders 0-{MAX_ORDER}): {bpb:.3}");
        assert!(bpb < 6.0, "PPM solo bpb too high: {bpb:.3}");
    }

    #[test]
    fn ppm_entry_increment_and_halve() {
        let mut entry = PpmEntry::EMPTY;
        entry.checksum = 1;
        for &sym in b"aab" {
            entry.increment(sym);
        }
        assert_eq!(entry.len, 2);
        assert_eq!(entry.total, 3);

        entry.halve();
        assert_eq!(entry.len, 1);
        assert_eq!(entry.counts[0], 1);
        assert_eq!(entry.total, 1);
    }
}