// datacortex_core/model/dmc_model.rs

/// A single node of the DMC state table.
///
/// Both arrays are indexed by bit value: element 0 belongs to bit 0 and
/// element 1 to bit 1.
#[derive(Clone, Copy)]
struct DmcState {
    /// How often each bit value has been counted in this node.
    counts: [u32; 2],
    /// Index of the successor node for each bit value.
    next: [u32; 2],
}

impl DmcState {
    /// All-zero node used to fill table slots that hold no live state.
    const EMPTY: Self = Self {
        counts: [0, 0],
        next: [0, 0],
    };
}
35
/// Number of pre-seeded states: one per (previous byte value, bit position)
/// pair, i.e. 8 bit positions for each of the 256 byte values.
const INITIAL_STATES: usize = 8 * 256;
38
39struct DmcInstance {
41 states: Vec<DmcState>,
42 current_state: u32,
43 num_states: usize,
44 max_states: usize,
45 clone_threshold: u32,
46}
47
48impl DmcInstance {
49 fn new(max_states: usize, clone_threshold: u32) -> Self {
50 let mut inst = DmcInstance {
51 states: vec![DmcState::EMPTY; max_states],
52 current_state: 0,
53 num_states: INITIAL_STATES,
54 max_states,
55 clone_threshold,
56 };
57 inst.init_states();
58 inst
59 }
60
61 fn init_states(&mut self) {
66 for prev_byte in 0..256u32 {
67 for bpos in 0..8u32 {
68 let state_idx = prev_byte * 8 + bpos;
69 let s = &mut self.states[state_idx as usize];
70 s.counts = [1, 1]; if bpos < 7 {
73 s.next[0] = prev_byte * 8 + bpos + 1;
75 s.next[1] = prev_byte * 8 + bpos + 1;
76 } else {
77 let even_byte = prev_byte & 0xFE;
81 let odd_byte = prev_byte | 1;
82 s.next[0] = even_byte * 8; s.next[1] = odd_byte * 8; }
85 }
86 }
87 self.num_states = INITIAL_STATES;
88 self.current_state = 0;
89 }
90
91 #[inline]
93 fn predict(&self) -> u32 {
94 let s = &self.states[self.current_state as usize];
95 let n0 = s.counts[0] as u64;
96 let n1 = s.counts[1] as u64;
97 let total = n0 + n1;
98 if total == 0 {
99 return 2048;
100 }
101 let p = ((n1 << 12) / total) as u32;
102 p.clamp(1, 4095)
103 }
104
105 #[inline]
107 fn update(&mut self, bit: u8) {
108 let b = bit as usize;
109 let cur = self.current_state as usize;
110
111 self.states[cur].counts[b] = self.states[cur].counts[b].saturating_add(1);
113
114 let total = self.states[cur].counts[0] + self.states[cur].counts[1];
116 if total > 8192 {
117 self.states[cur].counts[0] = (self.states[cur].counts[0] >> 1).max(1);
118 self.states[cur].counts[1] = (self.states[cur].counts[1] >> 1).max(1);
119 }
120
121 let next_idx = self.states[cur].next[b] as usize;
123 let cur_count = self.states[cur].counts[b];
124
125 if cur_count >= self.clone_threshold && self.num_states < self.max_states {
126 let target_total = self.states[next_idx].counts[0] + self.states[next_idx].counts[1];
127
128 if target_total > cur_count + self.clone_threshold {
129 let new_idx = self.num_states;
131 self.num_states += 1;
132
133 self.states[new_idx].next = self.states[next_idx].next;
135
136 let t0 = self.states[next_idx].counts[0] as u64;
140 let t1 = self.states[next_idx].counts[1] as u64;
141 let cc = cur_count as u64;
142 let tt = target_total as u64;
143
144 let new_c0 = ((t0 * cc) / tt).max(1) as u32;
145 let new_c1 = ((t1 * cc) / tt).max(1) as u32;
146
147 self.states[new_idx].counts[0] = new_c0;
148 self.states[new_idx].counts[1] = new_c1;
149
150 self.states[next_idx].counts[0] =
152 self.states[next_idx].counts[0].saturating_sub(new_c0.saturating_sub(1));
153 self.states[next_idx].counts[1] =
154 self.states[next_idx].counts[1].saturating_sub(new_c1.saturating_sub(1));
155
156 self.states[cur].next[b] = new_idx as u32;
158
159 self.current_state = new_idx as u32;
161 } else {
162 self.current_state = next_idx as u32;
163 }
164 } else {
165 self.current_state = next_idx as u32;
166 }
167 }
168
169 #[inline]
176 fn on_byte_complete(&mut self, _byte: u8) {
177 }
179
180 fn reset(&mut self) {
182 for s in self.states[..self.max_states].iter_mut() {
184 *s = DmcState::EMPTY;
185 }
186 self.init_states();
187 }
188}
189
190pub struct DmcModel {
194 instances: Vec<DmcInstance>,
195}
196
197impl DmcModel {
198 pub fn new_single() -> Self {
200 DmcModel {
201 instances: vec![DmcInstance::new(4 * 1024 * 1024, 2)],
202 }
203 }
204
205 pub fn new_forest() -> Self {
208 DmcModel {
209 instances: vec![
210 DmcInstance::new(2 * 1024 * 1024, 2), DmcInstance::new(2 * 1024 * 1024, 4), DmcInstance::new(2 * 1024 * 1024, 8), ],
214 }
215 }
216
217 #[inline]
220 pub fn predict(&self) -> u32 {
221 if self.instances.len() == 1 {
222 return self.instances[0].predict();
223 }
224
225 let mut sum: u64 = 0;
226 for inst in &self.instances {
227 sum += inst.predict() as u64;
228 }
229 let p = (sum / self.instances.len() as u64) as u32;
230 p.clamp(1, 4095)
231 }
232
233 #[inline]
235 pub fn update(&mut self, bit: u8) {
236 for inst in &mut self.instances {
237 inst.update(bit);
238
239 if inst.num_states >= inst.max_states {
241 inst.reset();
242 }
243 }
244 }
245
246 #[inline]
248 pub fn on_byte_complete(&mut self, byte: u8) {
249 for inst in &mut self.instances {
250 inst.on_byte_complete(byte);
251 }
252 }
253}
254
255impl Default for DmcModel {
256 fn default() -> Self {
257 Self::new_single()
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Yields `(bpos, bit)` for the eight bits of `byte`, most significant
    /// bit first.
    fn bits_msb_first(byte: u8) -> impl Iterator<Item = (u8, u8)> {
        (0..8u8).map(move |bpos| (bpos, (byte >> (7 - bpos)) & 1))
    }

    /// A fresh single-instance model starts close to the midpoint.
    #[test]
    fn initial_prediction_balanced() {
        let p = DmcModel::new_single().predict();
        assert!(
            (1800..=2200).contains(&p),
            "initial prediction should be near 2048, got {p}"
        );
    }

    /// Every prediction made while coding a short text stays in `1..=4095`.
    #[test]
    fn prediction_always_in_range() {
        let mut model = DmcModel::new_single();
        let data = b"Hello, World! This is a test of the DMC model.";
        for &byte in data.iter() {
            for (bpos, bit) in bits_msb_first(byte) {
                let p = model.predict();
                assert!(
                    (1..=4095).contains(&p),
                    "prediction out of range at bpos {bpos}: {p}"
                );
                model.update(bit);
            }
            model.on_byte_complete(byte);
        }
    }

    /// After many repetitions of `'A'` the model expects the leading bit of
    /// the next byte to be 0.
    #[test]
    fn adapts_to_repeated_bytes() {
        let mut model = DmcModel::new_single();
        let byte = b'A';
        for _ in 0..200 {
            for (_, bit) in bits_msb_first(byte) {
                let _p = model.predict();
                model.update(bit);
            }
            model.on_byte_complete(byte);
        }

        let p = model.predict();
        assert!(
            p < 1500,
            "after 200 'A' bytes, P(bit7=1) should be low, got {p}"
        );
    }

    /// Two single-instance models fed the same bits predict identically.
    #[test]
    fn deterministic() {
        let data = b"test determinism of dmc model";
        let mut m1 = DmcModel::new_single();
        let mut m2 = DmcModel::new_single();

        for &byte in data.iter() {
            for (bpos, bit) in bits_msb_first(byte) {
                let p1 = m1.predict();
                let p2 = m2.predict();
                assert_eq!(p1, p2, "models diverged at bpos {bpos}");
                m1.update(bit);
                m2.update(bit);
            }
            m1.on_byte_complete(byte);
            m2.on_byte_complete(byte);
        }
    }

    /// A fresh forest model starts close to the midpoint.
    #[test]
    fn forest_prediction_balanced() {
        let p = DmcModel::new_forest().predict();
        assert!(
            (1800..=2200).contains(&p),
            "forest initial prediction should be near 2048, got {p}"
        );
    }

    /// Two forest models fed the same bits predict identically.
    #[test]
    fn forest_deterministic() {
        let data = b"test forest determinism with some longer context data here";
        let mut m1 = DmcModel::new_forest();
        let mut m2 = DmcModel::new_forest();

        for &byte in data.iter() {
            for (bpos, bit) in bits_msb_first(byte) {
                let p1 = m1.predict();
                let p2 = m2.predict();
                assert_eq!(p1, p2, "forest models diverged at bpos {bpos}");
                m1.update(bit);
                m2.update(bit);
            }
            m1.on_byte_complete(byte);
            m2.on_byte_complete(byte);
        }
    }

    /// Sums the ideal code length of the first 10 000 bytes of alice29.txt
    /// under the single-instance model and checks the bits-per-byte bound.
    #[test]
    fn solo_bpb_alice29_prefix() {
        let data = include_bytes!("../../../../corpus/alice29.txt");
        let prefix = &data[..10_000.min(data.len())];

        let mut model = DmcModel::new_single();
        let mut total_bits: f64 = 0.0;

        for &byte in prefix {
            for (_, bit) in bits_msb_first(byte) {
                let p_one = model.predict() as f64 / 4096.0;
                let prob_of_bit = if bit == 1 { p_one } else { 1.0 - p_one };
                // Cost of coding this bit is -log2 of its predicted probability.
                total_bits -= prob_of_bit.max(1e-9).log2();
                model.update(bit);
            }
            model.on_byte_complete(byte);
        }

        let bpb = total_bits / prefix.len() as f64;
        eprintln!("DMC solo bpb on 10KB alice29: {bpb:.3}");
        assert!(bpb < 9.0, "DMC solo bpb too high: {bpb:.3}");
    }
}