1use crate::model::cm_model::ContextModel;
18
/// Number of context models a `NeuralModel` holds; also the divisor used when
/// their individual predictions are averaged.
const N_MODELS: usize = 4;
21
/// Ensemble of `N_MODELS` [`ContextModel`]s, each fed a different hashed
/// context, whose predictions are averaged into a single value.
pub struct NeuralModel {
    /// The individual context models; `predict` addresses them by index 0 through 3.
    models: Vec<ContextModel>,
}
27
28impl NeuralModel {
29 pub fn new() -> Self {
31 Self::with_size(1 << 21) }
33
34 pub fn with_size(size: usize) -> Self {
36 let mut models = Vec::with_capacity(N_MODELS);
37 for _ in 0..N_MODELS {
38 models.push(ContextModel::new(size));
39 }
40 NeuralModel { models }
41 }
42
43 #[inline]
45 #[allow(clippy::too_many_arguments)]
46 pub fn predict(
47 &mut self,
48 c0: u32,
49 bpos: u8,
50 c1: u8,
51 c2: u8,
52 c3: u8,
53 run_len: u8,
54 match_q: u8,
55 ) -> u32 {
56 let c0_full = c0; let c1_hi = (c1 >> 4) as u32;
65 let h0 = fhash3(c0_full, c1_hi, 0xA1B2_C3D4, 0xDEAD_1001);
66 let p0 = self.models[0].predict(h0);
67
68 let class_pair = byte_class_pair(c1, c2) as u32;
71 let h1 = fhash3(c0_full, class_pair, 0xE5F6_0718, 0xBEEF_2002);
72 let p1 = self.models[1].predict(h1);
73
74 let rq = quantize_run(run_len) as u32;
76 let h2 = fhash4(c0_full, c1 as u32, rq, bpos as u32, 0xCAFE_3003);
77 let p2 = self.models[2].predict(h2);
78
79 let c2_lo = (c2 & 0x0F) as u32;
81 let h3 = fhash4(c0_full, c2_lo, match_q as u32, c3 as u32, 0xFACE_4004);
82 let p3 = self.models[3].predict(h3);
83
84 let sum = p0 + p1 + p2 + p3;
86 let avg = sum / N_MODELS as u32;
87 avg.clamp(1, 4095)
88 }
89
90 #[inline]
92 pub fn update(&mut self, bit: u8) {
93 for model in &mut self.models {
94 model.update(bit);
95 }
96 }
97}
98
99impl Default for NeuralModel {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105#[inline]
107fn byte_class_pair(c1: u8, c2: u8) -> u8 {
108 byte_class_6(c1) * 6 + byte_class_6(c2)
109}
110
/// Classifies a byte into one of six coarse categories:
///
/// * `0` — ASCII lowercase letter
/// * `1` — ASCII uppercase letter
/// * `2` — ASCII digit
/// * `3` — space or tab
/// * `4` — line feed or carriage return
/// * `5` — anything else
#[inline]
fn byte_class_6(b: u8) -> u8 {
    if b.is_ascii_lowercase() {
        0
    } else if b.is_ascii_uppercase() {
        1
    } else if b.is_ascii_digit() {
        2
    } else if matches!(b, b' ' | b'\t') {
        3
    } else if matches!(b, b'\n' | b'\r') {
        4
    } else {
        5
    }
}
123
/// Buckets a run length into four levels: `0..=1` → 0, `2..=3` → 1,
/// `4..=8` → 2, and everything larger → 3.
#[inline]
fn quantize_run(run_len: u8) -> u8 {
    if run_len <= 1 {
        0
    } else if run_len <= 3 {
        1
    } else if run_len <= 8 {
        2
    } else {
        3
    }
}
134
/// Hashes three 32-bit words into one, starting from `seed`.
///
/// Each word is xor-ed into the running state, which is then multiplied
/// (wrapping) by `0x0100_0193`, the 32-bit FNV prime. Words are absorbed in
/// the order `a`, `b`, `c`.
#[inline]
fn fhash3(a: u32, b: u32, c: u32, seed: u32) -> u32 {
    const MULTIPLIER: u32 = 0x0100_0193;
    [a, b, c]
        .iter()
        .fold(seed, |state, &word| (state ^ word).wrapping_mul(MULTIPLIER))
}
147
/// Hashes four 32-bit words into one, starting from `seed`.
///
/// Uses the same step as the three-word variant: xor the word into the state,
/// then wrapping-multiply by `0x0100_0193`. Words are absorbed in the order
/// `a`, `b`, `c`, `d`.
#[inline]
fn fhash4(a: u32, b: u32, c: u32, d: u32, seed: u32) -> u32 {
    let mix = |state: u32, word: u32| (state ^ word).wrapping_mul(0x0100_0193);
    mix(mix(mix(mix(seed, a), b), c), d)
}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built model should predict close to the midpoint 2048.
    #[test]
    fn initial_prediction_near_half() {
        let p = NeuralModel::new().predict(1, 0, 0, 0, 0, 0, 0);
        assert!(
            (1800..=2200).contains(&p),
            "initial prediction should be near 2048, got {p}"
        );
    }

    /// Predictions stay inside `1..=4095` while the model is being updated.
    #[test]
    fn prediction_always_in_range() {
        let mut model = NeuralModel::new();
        let contexts = [0u8, 65, 128, 255]
            .iter()
            .flat_map(|&c1| (0..8u8).map(move |bpos| (c1, bpos)));
        for (c1, bpos) in contexts {
            let p = model.predict(1, bpos, c1, 0, 0, 0, 0);
            assert!((1..=4095).contains(&p), "prediction out of range: {p}");
            model.update(1);
        }
    }

    /// Two models fed the same inputs and bits must produce identical predictions.
    #[test]
    fn deterministic() {
        let (mut left, mut right) = (NeuralModel::new(), NeuralModel::new());

        for &byte in b"Hello World".iter() {
            for bpos in 0..8u8 {
                assert_eq!(
                    left.predict(1, bpos, byte, 0, 0, 0, 0),
                    right.predict(1, bpos, byte, 0, 0, 0, 0),
                    "neural models diverged"
                );
                // Feed bits most-significant first.
                let bit = (byte >> (7 - bpos)) & 1;
                left.update(bit);
                right.update(bit);
            }
        }
    }

    /// After 200 updates with bit 1 in a fixed context, the prediction must rise.
    #[test]
    fn adapts_to_data() {
        let mut model = NeuralModel::new();

        // First round: remember the untrained prediction.
        let first_p = model.predict(1, 0, b'A', b'B', b'C', 1, 0);
        model.update(1);

        // Remaining 199 rounds: predict before each update, as in the first round.
        for _ in 1..200 {
            let _ = model.predict(1, 0, b'A', b'B', b'C', 1, 0);
            model.update(1);
        }

        let final_p = model.predict(1, 0, b'A', b'B', b'C', 1, 0);
        assert!(
            final_p > first_p,
            "model should adapt: first={first_p}, final={final_p}"
        );
    }

    /// One representative byte per class maps to the expected class index.
    #[test]
    fn byte_class_categories() {
        let expectations = [
            (b'a', 0u8),
            (b'Z', 1),
            (b'5', 2),
            (b' ', 3),
            (b'\n', 4),
            (b'.', 5),
        ];
        for (byte, class) in expectations {
            assert_eq!(byte_class_6(byte), class);
        }
    }
}