#[derive(Clone, Copy)]
struct DmcState {
counts: [u32; 2],
next: [u32; 2],
}
impl DmcState {
const EMPTY: Self = DmcState {
counts: [0; 2],
next: [0; 2],
};
}
const INITIAL_STATES: usize = 256 * 8;
struct DmcInstance {
states: Vec<DmcState>,
current_state: u32,
num_states: usize,
max_states: usize,
clone_threshold: u32,
}
impl DmcInstance {
fn new(max_states: usize, clone_threshold: u32) -> Self {
let mut inst = DmcInstance {
states: vec![DmcState::EMPTY; max_states],
current_state: 0,
num_states: INITIAL_STATES,
max_states,
clone_threshold,
};
inst.init_states();
inst
}
fn init_states(&mut self) {
for prev_byte in 0..256u32 {
for bpos in 0..8u32 {
let state_idx = prev_byte * 8 + bpos;
let s = &mut self.states[state_idx as usize];
s.counts = [1, 1];
if bpos < 7 {
s.next[0] = prev_byte * 8 + bpos + 1;
s.next[1] = prev_byte * 8 + bpos + 1;
} else {
let even_byte = prev_byte & 0xFE;
let odd_byte = prev_byte | 1;
s.next[0] = even_byte * 8; s.next[1] = odd_byte * 8; }
}
}
self.num_states = INITIAL_STATES;
self.current_state = 0;
}
#[inline]
fn predict(&self) -> u32 {
let s = &self.states[self.current_state as usize];
let n0 = s.counts[0] as u64;
let n1 = s.counts[1] as u64;
let total = n0 + n1;
if total == 0 {
return 2048;
}
let p = ((n1 << 12) / total) as u32;
p.clamp(1, 4095)
}
#[inline]
fn update(&mut self, bit: u8) {
let b = bit as usize;
let cur = self.current_state as usize;
self.states[cur].counts[b] = self.states[cur].counts[b].saturating_add(1);
let total = self.states[cur].counts[0] + self.states[cur].counts[1];
if total > 8192 {
self.states[cur].counts[0] = (self.states[cur].counts[0] >> 1).max(1);
self.states[cur].counts[1] = (self.states[cur].counts[1] >> 1).max(1);
}
let next_idx = self.states[cur].next[b] as usize;
let cur_count = self.states[cur].counts[b];
if cur_count >= self.clone_threshold && self.num_states < self.max_states {
let target_total = self.states[next_idx].counts[0] + self.states[next_idx].counts[1];
if target_total > cur_count + self.clone_threshold {
let new_idx = self.num_states;
self.num_states += 1;
self.states[new_idx].next = self.states[next_idx].next;
let t0 = self.states[next_idx].counts[0] as u64;
let t1 = self.states[next_idx].counts[1] as u64;
let cc = cur_count as u64;
let tt = target_total as u64;
let new_c0 = ((t0 * cc) / tt).max(1) as u32;
let new_c1 = ((t1 * cc) / tt).max(1) as u32;
self.states[new_idx].counts[0] = new_c0;
self.states[new_idx].counts[1] = new_c1;
self.states[next_idx].counts[0] =
self.states[next_idx].counts[0].saturating_sub(new_c0.saturating_sub(1));
self.states[next_idx].counts[1] =
self.states[next_idx].counts[1].saturating_sub(new_c1.saturating_sub(1));
self.states[cur].next[b] = new_idx as u32;
self.current_state = new_idx as u32;
} else {
self.current_state = next_idx as u32;
}
} else {
self.current_state = next_idx as u32;
}
}
#[inline]
fn on_byte_complete(&mut self, _byte: u8) {
}
fn reset(&mut self) {
for s in self.states[..self.max_states].iter_mut() {
*s = DmcState::EMPTY;
}
self.init_states();
}
}
pub struct DmcModel {
instances: Vec<DmcInstance>,
}
impl DmcModel {
pub fn new_single() -> Self {
DmcModel {
instances: vec![DmcInstance::new(4 * 1024 * 1024, 2)],
}
}
pub fn new_forest() -> Self {
DmcModel {
instances: vec![
DmcInstance::new(2 * 1024 * 1024, 2), DmcInstance::new(2 * 1024 * 1024, 4), DmcInstance::new(2 * 1024 * 1024, 8), ],
}
}
#[inline]
pub fn predict(&self) -> u32 {
if self.instances.len() == 1 {
return self.instances[0].predict();
}
let mut sum: u64 = 0;
for inst in &self.instances {
sum += inst.predict() as u64;
}
let p = (sum / self.instances.len() as u64) as u32;
p.clamp(1, 4095)
}
#[inline]
pub fn update(&mut self, bit: u8) {
for inst in &mut self.instances {
inst.update(bit);
if inst.num_states >= inst.max_states {
inst.reset();
}
}
}
#[inline]
pub fn on_byte_complete(&mut self, byte: u8) {
for inst in &mut self.instances {
inst.on_byte_complete(byte);
}
}
}
impl Default for DmcModel {
fn default() -> Self {
Self::new_single()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn initial_prediction_balanced() {
let model = DmcModel::new_single();
let p = model.predict();
assert!(
(1800..=2200).contains(&p),
"initial prediction should be near 2048, got {p}"
);
}
#[test]
fn prediction_always_in_range() {
let mut model = DmcModel::new_single();
let data = b"Hello, World! This is a test of the DMC model.";
for &byte in data {
for bpos in 0..8u8 {
let p = model.predict();
assert!(
(1..=4095).contains(&p),
"prediction out of range at bpos {bpos}: {p}"
);
let bit = (byte >> (7 - bpos)) & 1;
model.update(bit);
}
model.on_byte_complete(byte);
}
}
#[test]
fn adapts_to_repeated_bytes() {
let mut model = DmcModel::new_single();
let byte = b'A'; for _ in 0..200 {
for bpos in 0..8u8 {
let bit = (byte >> (7 - bpos)) & 1;
let _p = model.predict();
model.update(bit);
}
model.on_byte_complete(byte);
}
let p = model.predict();
assert!(
p < 1500,
"after 200 'A' bytes, P(bit7=1) should be low, got {p}"
);
}
#[test]
fn deterministic() {
let data = b"test determinism of dmc model";
let mut m1 = DmcModel::new_single();
let mut m2 = DmcModel::new_single();
for &byte in data.iter() {
for bpos in 0..8u8 {
let p1 = m1.predict();
let p2 = m2.predict();
assert_eq!(p1, p2, "models diverged at bpos {bpos}");
let bit = (byte >> (7 - bpos)) & 1;
m1.update(bit);
m2.update(bit);
}
m1.on_byte_complete(byte);
m2.on_byte_complete(byte);
}
}
#[test]
fn forest_prediction_balanced() {
let model = DmcModel::new_forest();
let p = model.predict();
assert!(
(1800..=2200).contains(&p),
"forest initial prediction should be near 2048, got {p}"
);
}
#[test]
fn forest_deterministic() {
let data = b"test forest determinism with some longer context data here";
let mut m1 = DmcModel::new_forest();
let mut m2 = DmcModel::new_forest();
for &byte in data.iter() {
for bpos in 0..8u8 {
let p1 = m1.predict();
let p2 = m2.predict();
assert_eq!(p1, p2, "forest models diverged at bpos {bpos}");
let bit = (byte >> (7 - bpos)) & 1;
m1.update(bit);
m2.update(bit);
}
m1.on_byte_complete(byte);
m2.on_byte_complete(byte);
}
}
#[test]
fn solo_bpb_alice29_prefix() {
let data = include_bytes!("../../../../corpus/alice29.txt");
let prefix = &data[..10_000.min(data.len())];
let mut model = DmcModel::new_single();
let mut total_bits: f64 = 0.0;
for &byte in prefix {
for bpos in 0..8u8 {
let p = model.predict();
let bit = (byte >> (7 - bpos)) & 1;
let prob_of_bit = if bit == 1 {
p as f64 / 4096.0
} else {
1.0 - p as f64 / 4096.0
};
total_bits += -prob_of_bit.max(1e-9).log2();
model.update(bit);
}
model.on_byte_complete(byte);
}
let bpb = total_bits / prefix.len() as f64;
eprintln!("DMC solo bpb on 10KB alice29: {bpb:.3}");
assert!(bpb < 9.0, "DMC solo bpb too high: {bpb:.3}");
}
}