Skip to main content

datacortex_core/model/
dmc_model.rs

1//! DMC (Dynamic Markov Compression) — bit-level automaton predictor.
2//!
3//! A prediction model based on state cloning: starts with a small initial automaton
4//! and adaptively clones states when a transition is used frequently enough, creating
5//! context-specific states that capture sub-byte and cross-byte patterns.
6//!
7//! Key properties:
8//! - Bit-level: predict() returns 12-bit probability, update(bit) transitions automaton
9//! - State cloning: when a transition count exceeds clone_threshold and the target state
10//!   has accumulated significantly more counts, clone the target into a new state specific
11//!   to this transition path
//! - Deterministic: count splitting uses integer arithmetic only (no floats)
13//! - Self-resetting: when max_states reached, reinitialize to starting automaton
14//!
15//! Memory: ~64MB with 4M states at 16 bytes/state.
16//!
17//! Reference: Cormack & Horspool, "Data Compression using Dynamic Markov Modelling" (1987).
18//! PAQ8PX uses a DmcForest with multiple clone thresholds.
19
/// One node of the automaton: per-bit observation counts plus per-bit successor links.
///
/// Four `u32` fields, i.e. 16 bytes per state.
#[derive(Copy, Clone)]
struct DmcState {
    /// How often each bit value was observed in this state: `[zeros, ones]`.
    counts: [u32; 2],
    /// Index of the successor state for each bit value: `[after_0, after_1]`.
    next: [u32; 2],
}

impl DmcState {
    /// All-zero placeholder used to fill the state table before it is initialised.
    const EMPTY: Self = Self {
        counts: [0, 0],
        next: [0, 0],
    };
}
35
/// Size of the starting automaton: 8 bit positions for each of the 256 byte contexts
/// (2048 states in total).
const INITIAL_STATES: usize = 8 * 256;
38
39/// Single DMC automaton instance.
40struct DmcInstance {
41    states: Vec<DmcState>,
42    current_state: u32,
43    num_states: usize,
44    max_states: usize,
45    clone_threshold: u32,
46}
47
48impl DmcInstance {
49    fn new(max_states: usize, clone_threshold: u32) -> Self {
50        let mut inst = DmcInstance {
51            states: vec![DmcState::EMPTY; max_states],
52            current_state: 0,
53            num_states: INITIAL_STATES,
54            max_states,
55            clone_threshold,
56        };
57        inst.init_states();
58        inst
59    }
60
61    /// Initialize the automaton: create INITIAL_STATES states.
62    /// State index = prev_byte * 8 + bit_position.
63    /// At bpos 7, transitions go to the completed byte's bpos=0 state based on
64    /// the LSB (even/odd byte). Cloning will refine cross-byte paths over time.
65    fn init_states(&mut self) {
66        for prev_byte in 0..256u32 {
67            for bpos in 0..8u32 {
68                let state_idx = prev_byte * 8 + bpos;
69                let s = &mut self.states[state_idx as usize];
70                s.counts = [1, 1]; // Laplace prior
71
72                if bpos < 7 {
73                    // Next bit position in same byte context.
74                    s.next[0] = prev_byte * 8 + bpos + 1;
75                    s.next[1] = prev_byte * 8 + bpos + 1;
76                } else {
77                    // bpos == 7: byte complete. The last bit determines LSB of new byte.
78                    // Approximate: transition to the byte context based on prev_byte's
79                    // parity. Cloning creates refined cross-byte paths over time.
80                    let even_byte = prev_byte & 0xFE;
81                    let odd_byte = prev_byte | 1;
82                    s.next[0] = even_byte * 8; // bpos 0 of even byte context
83                    s.next[1] = odd_byte * 8; // bpos 0 of odd byte context
84                }
85            }
86        }
87        self.num_states = INITIAL_STATES;
88        self.current_state = 0;
89    }
90
91    /// Predict probability of bit=1. Returns 12-bit probability [1, 4095].
92    #[inline]
93    fn predict(&self) -> u32 {
94        let s = &self.states[self.current_state as usize];
95        let n0 = s.counts[0] as u64;
96        let n1 = s.counts[1] as u64;
97        let total = n0 + n1;
98        if total == 0 {
99            return 2048;
100        }
101        let p = ((n1 << 12) / total) as u32;
102        p.clamp(1, 4095)
103    }
104
105    /// Update after observing bit b. Transitions automaton and optionally clones.
106    #[inline]
107    fn update(&mut self, bit: u8) {
108        let b = bit as usize;
109        let cur = self.current_state as usize;
110
111        // Increment count for observed bit.
112        self.states[cur].counts[b] = self.states[cur].counts[b].saturating_add(1);
113
114        // Periodic count halving to prevent overflow and keep adaptation.
115        let total = self.states[cur].counts[0] + self.states[cur].counts[1];
116        if total > 8192 {
117            self.states[cur].counts[0] = (self.states[cur].counts[0] >> 1).max(1);
118            self.states[cur].counts[1] = (self.states[cur].counts[1] >> 1).max(1);
119        }
120
121        // State cloning logic.
122        let next_idx = self.states[cur].next[b] as usize;
123        let cur_count = self.states[cur].counts[b];
124
125        if cur_count >= self.clone_threshold && self.num_states < self.max_states {
126            let target_total = self.states[next_idx].counts[0] + self.states[next_idx].counts[1];
127
128            if target_total > cur_count + self.clone_threshold {
129                // Clone target into new state.
130                let new_idx = self.num_states;
131                self.num_states += 1;
132
133                // Copy target's transitions.
134                self.states[new_idx].next = self.states[next_idx].next;
135
136                // Split counts using integer arithmetic (deterministic).
137                // new_state gets cur_count / target_total proportion of each count.
138                // Use u64 to avoid overflow.
139                let t0 = self.states[next_idx].counts[0] as u64;
140                let t1 = self.states[next_idx].counts[1] as u64;
141                let cc = cur_count as u64;
142                let tt = target_total as u64;
143
144                let new_c0 = ((t0 * cc) / tt).max(1) as u32;
145                let new_c1 = ((t1 * cc) / tt).max(1) as u32;
146
147                self.states[new_idx].counts[0] = new_c0;
148                self.states[new_idx].counts[1] = new_c1;
149
150                // Reduce original target counts.
151                self.states[next_idx].counts[0] =
152                    self.states[next_idx].counts[0].saturating_sub(new_c0.saturating_sub(1));
153                self.states[next_idx].counts[1] =
154                    self.states[next_idx].counts[1].saturating_sub(new_c1.saturating_sub(1));
155
156                // Redirect this transition to the clone.
157                self.states[cur].next[b] = new_idx as u32;
158
159                // Transition to clone.
160                self.current_state = new_idx as u32;
161            } else {
162                self.current_state = next_idx as u32;
163            }
164        } else {
165            self.current_state = next_idx as u32;
166        }
167    }
168
169    /// Notify that a full byte has been completed.
170    /// In the mixer context, NOT resetting gives better results because:
171    /// - Reset predictions at byte start are noisy (back to initial state)
172    /// - Natural transitions let cloned states capture cross-byte patterns
173    ///
174    /// The solo test is worse without reset, but mixer integration is better.
175    #[inline]
176    fn on_byte_complete(&mut self, _byte: u8) {
177        // No-op: let the automaton flow naturally across byte boundaries.
178    }
179
180    /// Full reset when max_states is reached.
181    fn reset(&mut self) {
182        // Clear all states.
183        for s in self.states[..self.max_states].iter_mut() {
184            *s = DmcState::EMPTY;
185        }
186        self.init_states();
187    }
188}
189
190/// DmcForest: multiple DMC instances with different clone thresholds.
191/// Each instance captures patterns at different granularities.
192/// Predictions are averaged in probability space.
193pub struct DmcModel {
194    instances: Vec<DmcInstance>,
195}
196
197impl DmcModel {
198    /// Create a DmcModel with a single instance (threshold=2, 4M states = ~64MB).
199    pub fn new_single() -> Self {
200        DmcModel {
201            instances: vec![DmcInstance::new(4 * 1024 * 1024, 2)],
202        }
203    }
204
205    /// Create a DmcForest with 3 instances at different thresholds.
206    /// Total memory: ~48MB (3 × 1M states) or ~96MB (3 × 2M states).
207    pub fn new_forest() -> Self {
208        DmcModel {
209            instances: vec![
210                DmcInstance::new(2 * 1024 * 1024, 2), // aggressive cloning (~32MB)
211                DmcInstance::new(2 * 1024 * 1024, 4), // moderate cloning (~32MB)
212                DmcInstance::new(2 * 1024 * 1024, 8), // conservative cloning (~32MB)
213            ],
214        }
215    }
216
217    /// Predict probability of bit=1. Returns 12-bit probability [1, 4095].
218    /// Averages predictions from all instances.
219    #[inline]
220    pub fn predict(&self) -> u32 {
221        if self.instances.len() == 1 {
222            return self.instances[0].predict();
223        }
224
225        let mut sum: u64 = 0;
226        for inst in &self.instances {
227            sum += inst.predict() as u64;
228        }
229        let p = (sum / self.instances.len() as u64) as u32;
230        p.clamp(1, 4095)
231    }
232
233    /// Update all instances after observing bit.
234    #[inline]
235    pub fn update(&mut self, bit: u8) {
236        for inst in &mut self.instances {
237            inst.update(bit);
238
239            // Check if we need to reset (max_states reached).
240            if inst.num_states >= inst.max_states {
241                inst.reset();
242            }
243        }
244    }
245
246    /// Notify all instances that a full byte has been completed.
247    #[inline]
248    pub fn on_byte_complete(&mut self, byte: u8) {
249        for inst in &mut self.instances {
250            inst.on_byte_complete(byte);
251        }
252    }
253}
254
255impl Default for DmcModel {
256    fn default() -> Self {
257        Self::new_single()
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Yields `(bpos, bit)` for the eight bits of `byte`, most significant bit first.
    fn msb_first(byte: u8) -> impl Iterator<Item = (u8, u8)> {
        (0..8u8).map(move |bpos| (bpos, (byte >> (7 - bpos)) & 1))
    }

    /// A fresh single-automaton model must start out (almost) undecided.
    #[test]
    fn initial_prediction_balanced() {
        let p = DmcModel::new_single().predict();
        assert!(
            (1800..=2200).contains(&p),
            "initial prediction should be near 2048, got {p}"
        );
    }

    /// Every prediction made while coding a short text stays inside `[1, 4095]`.
    #[test]
    fn prediction_always_in_range() {
        let mut model = DmcModel::new_single();
        let data = b"Hello, World! This is a test of the DMC model.";
        for &byte in data {
            for (bpos, bit) in msb_first(byte) {
                let p = model.predict();
                assert!(
                    (1..=4095).contains(&p),
                    "prediction out of range at bpos {bpos}: {p}"
                );
                model.update(bit);
            }
            model.on_byte_complete(byte);
        }
    }

    /// After a long run of one byte value the model should expect that byte's first bit.
    #[test]
    fn adapts_to_repeated_bytes() {
        let mut model = DmcModel::new_single();
        let byte = b'A'; // 0x41 = 01000001
        for _ in 0..200 {
            for (_, bit) in msb_first(byte) {
                let _p = model.predict();
                model.update(bit);
            }
            model.on_byte_complete(byte);
        }
        // The most significant bit of 'A' is 0, so P(1) at the start of a byte should
        // be low.
        let p = model.predict();
        assert!(
            p < 1500,
            "after 200 'A' bytes, P(bit7=1) should be low, got {p}"
        );
    }

    /// Two single-automaton models fed the same bits must agree on every prediction.
    #[test]
    fn deterministic() {
        let data = b"test determinism of dmc model";
        let mut m1 = DmcModel::new_single();
        let mut m2 = DmcModel::new_single();

        for &byte in data.iter() {
            for (bpos, bit) in msb_first(byte) {
                let (p1, p2) = (m1.predict(), m2.predict());
                assert_eq!(p1, p2, "models diverged at bpos {bpos}");
                m1.update(bit);
                m2.update(bit);
            }
            m1.on_byte_complete(byte);
            m2.on_byte_complete(byte);
        }
    }

    /// A fresh forest must start out (almost) undecided as well.
    #[test]
    fn forest_prediction_balanced() {
        let p = DmcModel::new_forest().predict();
        assert!(
            (1800..=2200).contains(&p),
            "forest initial prediction should be near 2048, got {p}"
        );
    }

    /// Two forests fed the same bits must agree on every prediction.
    #[test]
    fn forest_deterministic() {
        let data = b"test forest determinism with some longer context data here";
        let mut m1 = DmcModel::new_forest();
        let mut m2 = DmcModel::new_forest();

        for &byte in data.iter() {
            for (bpos, bit) in msb_first(byte) {
                let (p1, p2) = (m1.predict(), m2.predict());
                assert_eq!(p1, p2, "forest models diverged at bpos {bpos}");
                m1.update(bit);
                m2.update(bit);
            }
            m1.on_byte_complete(byte);
            m2.on_byte_complete(byte);
        }
    }

    /// Measures the ideal code length (bits per byte) of the model on its own over the
    /// first 10,000 bytes of alice29.txt.
    ///
    /// The original author reports ~8 bpb solo without a byte-boundary reset (~3.9 bpb
    /// with one), and that the no-reset variant nevertheless works better inside the
    /// mixer; hence the deliberately lenient bound below.
    #[test]
    fn solo_bpb_alice29_prefix() {
        let data = include_bytes!("../../../../corpus/alice29.txt");
        let prefix = &data[..10_000.min(data.len())];

        let mut model = DmcModel::new_single();
        let mut total_bits: f64 = 0.0;

        for &byte in prefix {
            for (_, bit) in msb_first(byte) {
                let p_one = f64::from(model.predict()) / 4096.0;
                let prob_of_bit = if bit == 1 { p_one } else { 1.0 - p_one };
                // Cost of coding this bit; the floor guards against log2(0).
                total_bits -= prob_of_bit.max(1e-9).log2();
                model.update(bit);
            }
            model.on_byte_complete(byte);
        }

        let bpb = total_bits / prefix.len() as f64;
        eprintln!("DMC solo bpb on 10KB alice29: {bpb:.3}");
        assert!(bpb < 9.0, "DMC solo bpb too high: {bpb:.3}");
    }
}
395}