// datacortex_core/model/neural_model.rs
//! NeuralModel -- bit-level cross-context model.
//!
//! Uses context hashes that combine bit-level information from the current
//! byte being decoded with byte-level context. This captures patterns that
//! traditional byte-context CM models miss because they don't condition on
//! the bits already decoded in the current byte.
//!
//! Key contexts:
//!   - c0_full (all decoded bits so far, not just the low 8) x c1 nibbles
//!   - Byte boundary contexts (position in line, word boundary)
//!   - Bit-pattern contexts (repeated bit runs, alternating patterns)
//!
//! Uses the same ContextModel (ContextMap + StateMap) machinery as other models.
//!
//! CRITICAL: Encoder and decoder must produce IDENTICAL neural model state.

use crate::model::cm_model::ContextModel;

/// Number of internal context models combined by `NeuralModel::predict`.
const N_MODELS: usize = 4;

/// Neural model using bit-level cross-contexts.
///
/// Wraps `N_MODELS` independent `ContextModel` instances, one per
/// context hash; see `predict` for how each model is keyed.
pub struct NeuralModel {
    /// Internal context models (always exactly `N_MODELS` entries,
    /// queried in a fixed order so encoder/decoder state stays identical).
    models: Vec<ContextModel>,
}

28impl NeuralModel {
29    /// Create a new neural model with default sizes.
30    pub fn new() -> Self {
31        Self::with_size(1 << 21) // 2MB per model, 8MB total
32    }
33
34    /// Create with custom size per internal model.
35    pub fn with_size(size: usize) -> Self {
36        let mut models = Vec::with_capacity(N_MODELS);
37        for _ in 0..N_MODELS {
38            models.push(ContextModel::new(size));
39        }
40        NeuralModel { models }
41    }
42
43    /// Predict using bit-level cross-contexts. Returns single 12-bit probability.
44    #[inline]
45    #[allow(clippy::too_many_arguments)]
46    pub fn predict(
47        &mut self,
48        c0: u32,
49        bpos: u8,
50        c1: u8,
51        c2: u8,
52        c3: u8,
53        run_len: u8,
54        match_q: u8,
55    ) -> u32 {
56        // c0 contains the FULL partial byte (1-255 range, MSB first).
57        // At bpos=0, c0=1. At bpos=3, c0 has 1 + 3 decoded bits.
58        // This is a unique signal that no other model uses directly.
59        let c0_full = c0; // up to 9 bits of context
60
61        // Context 0: c0_full (partial byte with all bits) x c1_hi_nibble
62        // This is UNIQUE because other models use c0 & 0xFF (losing the leading 1-bit).
63        // c0_full encodes exactly which bits have been decoded so far.
64        let c1_hi = (c1 >> 4) as u32;
65        let h0 = fhash3(c0_full, c1_hi, 0xA1B2_C3D4, 0xDEAD_1001);
66        let p0 = self.models[0].predict(h0);
67
68        // Context 1: c0_full x byte_class_pair(c1, c2)
69        // Combines partial byte with a transition context
70        let class_pair = byte_class_pair(c1, c2) as u32;
71        let h1 = fhash3(c0_full, class_pair, 0xE5F6_0718, 0xBEEF_2002);
72        let p1 = self.models[1].predict(h1);
73
74        // Context 2: c0_full x c1 x run_q x bpos -- conditioning on whether we're in a run
75        let rq = quantize_run(run_len) as u32;
76        let h2 = fhash4(c0_full, c1 as u32, rq, bpos as u32, 0xCAFE_3003);
77        let p2 = self.models[2].predict(h2);
78
79        // Context 3: c0_full x c2_lo x match_q -- skip-1 with match awareness
80        let c2_lo = (c2 & 0x0F) as u32;
81        let h3 = fhash4(c0_full, c2_lo, match_q as u32, c3 as u32, 0xFACE_4004);
82        let p3 = self.models[3].predict(h3);
83
84        // Average predictions
85        let sum = p0 + p1 + p2 + p3;
86        let avg = sum / N_MODELS as u32;
87        avg.clamp(1, 4095)
88    }
89
90    /// Update all internal models after observing bit.
91    #[inline]
92    pub fn update(&mut self, bit: u8) {
93        for model in &mut self.models {
94            model.update(bit);
95        }
96    }
97}
98
99impl Default for NeuralModel {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105/// Encode byte class pair as a single value (6x6 = 36 values).
106#[inline]
107fn byte_class_pair(c1: u8, c2: u8) -> u8 {
108    byte_class_6(c1) * 6 + byte_class_6(c2)
109}
110
/// 6-class byte classification: lowercase, uppercase, digit,
/// horizontal whitespace, line break, and everything else.
#[inline]
fn byte_class_6(b: u8) -> u8 {
    if b.is_ascii_lowercase() {
        0
    } else if b.is_ascii_uppercase() {
        1
    } else if b.is_ascii_digit() {
        2
    } else if b == b' ' || b == b'\t' {
        3
    } else if b == b'\n' || b == b'\r' {
        4
    } else {
        5
    }
}

/// Quantize run length to a 2-bit bucket (0-3):
/// 0..=1 -> 0, 2..=3 -> 1, 4..=8 -> 2, longer -> 3.
#[inline]
fn quantize_run(run_len: u8) -> u8 {
    if run_len <= 1 {
        0
    } else if run_len <= 3 {
        1
    } else if run_len <= 8 {
        2
    } else {
        3
    }
}

/// Fast hash of 3 values with seed. FNV-1a style: each value is xored in,
/// then the state is multiplied by the 32-bit FNV prime (wrapping).
#[inline]
fn fhash3(a: u32, b: u32, c: u32, seed: u32) -> u32 {
    const FNV_PRIME: u32 = 0x0100_0193;
    [a, b, c]
        .into_iter()
        .fold(seed, |h, v| (h ^ v).wrapping_mul(FNV_PRIME))
}

/// Fast hash of 4 values with seed, same FNV-1a-style mix as `fhash3`.
#[inline]
fn fhash4(a: u32, b: u32, c: u32, d: u32, seed: u32) -> u32 {
    const FNV_PRIME: u32 = 0x0100_0193;
    [a, b, c, d]
        .into_iter()
        .fold(seed, |h, v| (h ^ v).wrapping_mul(FNV_PRIME))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh model with no observations should predict close to 1/2
    /// (2048 on the 12-bit scale).
    #[test]
    fn initial_prediction_near_half() {
        let mut model = NeuralModel::new();
        let p = model.predict(1, 0, 0, 0, 0, 0, 0);
        assert!(
            (1800..=2200).contains(&p),
            "initial prediction should be near 2048, got {p}"
        );
    }

    /// Every prediction must stay within the open (0, 4096) range.
    #[test]
    fn prediction_always_in_range() {
        let mut model = NeuralModel::new();
        for c1 in [0u8, 65, 128, 255] {
            for bpos in 0..8u8 {
                let p = model.predict(1, bpos, c1, 0, 0, 0, 0);
                assert!((1..=4095).contains(&p), "prediction out of range: {p}");
                model.update(1);
            }
        }
    }

    /// Two independently constructed models fed identical inputs must
    /// never diverge (encoder/decoder symmetry requirement).
    #[test]
    fn deterministic() {
        let mut m1 = NeuralModel::new();
        let mut m2 = NeuralModel::new();

        for &byte in b"Hello World" {
            for bpos in 0..8u8 {
                let p1 = m1.predict(1, bpos, byte, 0, 0, 0, 0);
                let p2 = m2.predict(1, bpos, byte, 0, 0, 0, 0);
                assert_eq!(p1, p2, "neural models diverged");
                let bit = (byte >> (7 - bpos)) & 1;
                m1.update(bit);
                m2.update(bit);
            }
        }
    }

    /// Repeatedly observing a 1-bit in a fixed context should push the
    /// prediction upward relative to the very first estimate.
    #[test]
    fn adapts_to_data() {
        let mut model = NeuralModel::new();
        let mut first_p = None;
        for _ in 0..200 {
            let p = model.predict(1, 0, b'A', b'B', b'C', 1, 0);
            first_p.get_or_insert(p);
            model.update(1);
        }
        let first_p = first_p.unwrap();
        let final_p = model.predict(1, 0, b'A', b'B', b'C', 1, 0);
        assert!(
            final_p > first_p,
            "model should adapt: first={first_p}, final={final_p}"
        );
    }

    /// One representative byte per class bucket.
    #[test]
    fn byte_class_categories() {
        let cases = [
            (b'a', 0u8),
            (b'Z', 1),
            (b'5', 2),
            (b' ', 3),
            (b'\n', 4),
            (b'.', 5),
        ];
        for (byte, class) in cases {
            assert_eq!(byte_class_6(byte), class);
        }
    }
}