use crate::state::state_map::StateMap;
use crate::state::state_table::StateTable;
/// Order-0 bit predictor.
///
/// Keeps one state byte per context slot. A shared [`StateMap`] turns that
/// state into a prediction. [`StateTable::next`] advances the state after each
/// coded bit.
pub struct Order0Model {
    /// One state byte per slot. Contexts are folded onto these 256 slots by
    /// keeping only their low 8 bits.
    states: [u8; 256],
    /// Maps a state byte to a prediction and is updated with every coded bit.
    state_map: StateMap,
}

impl Order0Model {
    /// Folds an arbitrary context value onto one of the 256 state slots.
    const SLOT_MASK: usize = 0xFF;

    /// Creates a model with every slot in state 0 and a freshly constructed
    /// [`StateMap`].
    pub fn new() -> Self {
        Self {
            states: [0; 256],
            state_map: StateMap::new(),
        }
    }

    /// Returns the slot index for `context`. Only the low 8 bits are kept, so
    /// contexts that differ only above bit 7 share a slot.
    #[inline]
    fn slot(context: usize) -> usize {
        context & Self::SLOT_MASK
    }

    /// Returns the prediction for the next bit in `context`.
    ///
    /// The value comes from the [`StateMap`] entry for the slot's current state.
    #[inline]
    pub fn predict(&self, context: usize) -> u32 {
        let current = self.states[Self::slot(context)];
        self.state_map.predict(current)
    }

    /// Records that `bit` was coded in `context`.
    ///
    /// The [`StateMap`] is trained on the slot's *old* state first. The slot is
    /// then advanced to `StateTable::next(old_state, bit)`.
    #[inline]
    pub fn update(&mut self, context: usize, bit: u8) {
        let slot = Self::slot(context);
        let previous = self.states[slot];
        // Train on the state that produced the prediction, then transition.
        self.state_map.update(previous, bit);
        self.states[slot] = StateTable::next(previous, bit);
    }
}
impl Default for Order0Model {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Codes the eight bits of `byte` MSB-first through `model`.
    ///
    /// Uses the partial-byte context that starts at 1 and shifts in each coded
    /// bit. `on_bit` receives the prediction made before each bit, together
    /// with that bit. Returns the final context, which equals `byte + 256`.
    fn code_byte(model: &mut Order0Model, byte: u8, mut on_bit: impl FnMut(u32, u8)) -> usize {
        let mut c: usize = 1;
        for bpos in 0..8 {
            let bit = (byte >> (7 - bpos)) & 1;
            let p = model.predict(c);
            on_bit(p, bit);
            model.update(c, bit);
            c = (c << 1) | bit as usize;
        }
        c
    }

    #[test]
    fn initial_prediction_is_balanced() {
        let model = Order0Model::new();
        let p = model.predict(1);
        assert_eq!(p, 2048, "initial prediction should be 2048");
    }

    #[test]
    fn prediction_in_range() {
        let model = Order0Model::new();
        for c in 1..=255 {
            let p = model.predict(c);
            assert!(
                (1..=4095).contains(&p),
                "context {c}: pred {p} out of range"
            );
        }
    }

    #[test]
    fn update_adapts_prediction() {
        let mut model = Order0Model::new();
        let before = model.predict(1);
        model.update(1, 1);
        let after = model.predict(1);
        assert_ne!(before, after, "prediction should change after update");
    }

    #[test]
    fn different_contexts_have_separate_states() {
        let mut model = Order0Model::new();
        for _ in 0..50 {
            model.update(10, 1);
        }
        for _ in 0..50 {
            model.update(5, 0);
        }
        let p5 = model.predict(5);
        let p10 = model.predict(10);
        assert!(
            p10 > p5,
            "context 10 (all 1s) should predict higher than context 5 (all 0s): p10={p10}, p5={p5}"
        );
    }

    #[test]
    fn contexts_alias_on_low_byte() {
        // Contexts are masked with 0xFF, so 0x101 must land in slot 0x01
        // instead of indexing out of bounds.
        let mut model = Order0Model::new();
        for _ in 0..10 {
            model.update(0x101, 1);
        }
        assert_eq!(model.predict(0x101), model.predict(0x01));
        assert_ne!(model.predict(0x01), Order0Model::new().predict(0x01));
    }

    #[test]
    fn simulate_byte_encoding() {
        let mut model = Order0Model::new();
        let c = code_byte(&mut model, 0x42, |_, _| {});
        assert_eq!(c, 0x42 + 256);
    }

    #[test]
    fn repeated_pattern_adapts() {
        const ITERATIONS: usize = 20;
        let mut model = Order0Model::new();
        let byte: u8 = 0x41;
        // Surprise (in bits) of coding `byte` on each successive repetition.
        let surprises: Vec<f64> = (0..ITERATIONS)
            .map(|_| {
                let mut byte_surprise: f64 = 0.0;
                code_byte(&mut model, byte, |p, bit| {
                    let p1 = p as f64 / 4096.0;
                    let prob_of_bit = if bit == 1 { p1 } else { 1.0 - p1 };
                    byte_surprise += -prob_of_bit.log2();
                });
                byte_surprise
            })
            .collect();
        let first_byte_surprise = surprises[0];
        // Mean over *all* repetitions (including the first), not just the last one.
        let avg_surprise = surprises.iter().sum::<f64>() / ITERATIONS as f64;
        assert!(
            avg_surprise < first_byte_surprise,
            "model should improve: first byte = {first_byte_surprise:.2} bits, avg = {avg_surprise:.2} bits"
        );
    }
}