use crate::state::context_map::ContextMap;
use crate::state::state_map::StateMap;
use crate::state::state_table::StateTable;
/// Number of one-byte slots in both the prediction table and the count table.
const PRED_TABLE_SIZE: usize = 1 << 20;
/// Bit mask that reduces a hash to a valid index into a `PRED_TABLE_SIZE`-slot table.
/// Valid only because `PRED_TABLE_SIZE` is a power of two.
const PRED_TABLE_MASK: usize = PRED_TABLE_SIZE - 1;
/// Size argument handed to `ContextMap::new` (its unit is defined by `ContextMap`).
const INDIRECT_CM_SIZE: usize = 1 << 23;
/// 32-bit FNV offset basis: starting value of the byte-history hash.
const FNV_OFFSET: u32 = 0x811C_9DC5;
/// 32-bit FNV prime: multiplier applied after every xor step of the hashes.
const FNV_PRIME: u32 = 0x0100_0193;
/// Two-stage bit predictor.
///
/// Stage one hashes recent whole bytes into a table slot that remembers which
/// byte followed that history last time. Stage two hashes that remembered byte
/// together with the bits of the byte currently being coded, and looks the
/// result up in a context map whose state is turned into a prediction by a
/// state map.
pub struct IndirectModel {
    /// Per-slot remembered byte, indexed by `ctx_hash & PRED_TABLE_MASK`.
    prediction_table: Vec<u8>,
    /// Per-slot confidence counter paired with `prediction_table`: raised on a
    /// hit, lowered on a miss, and reset to 1 when the slot's byte is replaced.
    count_table: Vec<u8>,
    /// Maps the stage-two hash to a bit-history state.
    context_map: ContextMap,
    /// Converts a bit-history state into a prediction and learns from the coded bit.
    state_map: StateMap,
    /// Stage-one hash computed at the start of the current byte; reused at the
    /// end of the byte to update the same table slot.
    ctx_hash: u32,
    /// Byte fetched from `prediction_table` at the start of the current byte.
    predicted_byte: u8,
    /// Stage-two hash used by the most recent `predict`, so that `update`
    /// touches the same context-map entry.
    last_cm_hash: u32,
    /// Bits of the current byte seen so far, preceded by a leading 1 bit.
    c0: u32,
    /// Most recently completed byte.
    c1: u8,
    /// Byte completed before `c1`.
    c2: u8,
    /// Byte completed before `c2`.
    c3: u8,
    /// Number of bits of the current byte already passed to `update` (0..=7
    /// between calls).
    bpos: u8,
}
impl IndirectModel {
    /// Creates a model with zero-filled tables, an empty byte history and the
    /// partial-byte context reset to its leading-1 start value.
    pub fn new() -> Self {
        Self {
            prediction_table: vec![0; PRED_TABLE_SIZE],
            count_table: vec![0; PRED_TABLE_SIZE],
            context_map: ContextMap::new(INDIRECT_CM_SIZE),
            state_map: StateMap::new(),
            ctx_hash: FNV_OFFSET,
            predicted_byte: 0,
            last_cm_hash: 0,
            c0: 1,
            c1: 0,
            c2: 0,
            c3: 0,
            bpos: 0,
        }
    }

    /// Returns the prediction for the next bit.
    ///
    /// * `c0` - partial byte: bits seen so far with a leading 1 bit.
    /// * `bpos` - bit position inside the current byte; `0` marks a byte boundary.
    /// * `c1` - the previous whole byte as known to the caller.
    ///
    /// At a byte boundary the remembered byte for the current history is
    /// refreshed from the prediction table. The result is whatever
    /// `StateMap::predict` yields for the selected context-map state (the
    /// tests in this file expect it to lie in `1..=4095`).
    #[inline]
    pub fn predict(&mut self, c0: u32, bpos: u8, c1: u8) -> u32 {
        if bpos == 0 {
            // New byte: hash the caller's c1 with our own c2/c3 and remember
            // the hash so `update` can train the same slot later.
            let history_hash = indirect_hash(c1, self.c2, self.c3);
            self.ctx_hash = history_hash;
            self.predicted_byte = self.prediction_table[(history_hash as usize) & PRED_TABLE_MASK];
        }

        // Remember the stage-two hash; `update` must hit the same entry.
        self.last_cm_hash = predicted_context_hash(self.predicted_byte, c0);
        let state = self.context_map.get(self.last_cm_hash);
        self.state_map.predict(state)
    }

    /// Trains the model on the bit that was actually coded.
    ///
    /// Must follow the `predict` call for that bit. Once eight bits have been
    /// accumulated, the prediction-table slot chosen at the start of the byte
    /// is trained on the completed byte and the byte history is shifted.
    #[inline]
    pub fn update(&mut self, bit: u8) {
        // Bit-level learning on the context-map entry used by `predict`.
        let cm_slot = self.last_cm_hash;
        let old_state = self.context_map.get(cm_slot);
        self.state_map.update(old_state, bit);
        let next_state = StateTable::next(old_state, bit);
        self.context_map.set(cm_slot, next_state);

        self.c0 = (self.c0 << 1) | u32::from(bit);
        self.bpos += 1;
        if self.bpos < 8 {
            return;
        }

        // Byte complete: drop the leading 1 bit to recover the byte value.
        let finished = (self.c0 & 0xFF) as u8;
        let slot = (self.ctx_hash as usize) & PRED_TABLE_MASK;
        let confidence = self.count_table[slot];

        if self.prediction_table[slot] == finished {
            // Hit: strengthen the remembered byte.
            self.count_table[slot] = confidence.saturating_add(1);
        } else if confidence >= 2 {
            // Miss on a trusted entry: weaken it but keep the byte.
            self.count_table[slot] = confidence.saturating_sub(1);
        } else {
            // Miss on a weak entry: replace it with the byte just seen.
            self.prediction_table[slot] = finished;
            self.count_table[slot] = 1;
        }

        // Shift the byte history and restart the partial-byte context.
        self.c3 = self.c2;
        self.c2 = self.c1;
        self.c1 = finished;
        self.c0 = 1;
        self.bpos = 0;
    }
}
impl Default for IndirectModel {
fn default() -> Self {
Self::new()
}
}
/// Hashes the three most recent bytes into a 32-bit value.
///
/// Starting from `FNV_OFFSET`, the bytes are mixed in oldest first (`c3`, then
/// `c2`, then `c1`), each step being an xor followed by a wrapping multiply by
/// `FNV_PRIME`.
#[inline]
fn indirect_hash(c1: u8, c2: u8, c3: u8) -> u32 {
    [c3, c2, c1]
        .iter()
        .fold(FNV_OFFSET, |acc, &byte| (acc ^ u32::from(byte)).wrapping_mul(FNV_PRIME))
}
/// Combines the remembered byte with the partial-byte context into the hash
/// used to address the context map.
///
/// Only the low nine bits of `c0` take part. Both mixing steps are an xor
/// followed by a wrapping multiply by `FNV_PRIME`.
#[inline]
fn predicted_context_hash(predicted: u8, c0: u32) -> u32 {
    const SEED: u32 = 0x9E37_79B9;
    let after_byte = (SEED ^ u32::from(predicted)).wrapping_mul(FNV_PRIME);
    (after_byte ^ (c0 & 0x1FF)).wrapping_mul(FNV_PRIME)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks the eight bits of `byte`, most significant first, yielding
    /// `(bpos, c0, bit)`: the bit position, the partial-byte context (a
    /// leading 1 followed by the bits already consumed) and the bit itself.
    fn bit_steps(byte: u8) -> impl Iterator<Item = (u8, u32, u8)> {
        (0..8u8).map(move |bpos| {
            let c0 = (0..bpos).fold(1u32, |acc, prev| {
                (acc << 1) | u32::from((byte >> (7 - prev)) & 1)
            });
            let bit = (byte >> (7 - bpos)) & 1;
            (bpos, c0, bit)
        })
    }

    /// A brand-new model must already produce an in-range prediction.
    #[test]
    fn initial_prediction_in_range() {
        let mut model = IndirectModel::new();
        let p = model.predict(1, 0, 0);
        assert!(
            (1..=4095).contains(&p),
            "initial prediction should be in valid range, got {p}"
        );
    }

    /// Every prediction made while coding a short text stays in range.
    #[test]
    fn predictions_in_range() {
        let mut model = IndirectModel::new();
        for &byte in b"Hello, World! The quick brown fox." {
            for (bpos, c0, bit) in bit_steps(byte) {
                let c1 = if bpos == 0 {
                    byte.wrapping_sub(1)
                } else {
                    byte
                };
                let p = model.predict(c0, bpos, c1);
                assert!(
                    (1..=4095).contains(&p),
                    "prediction out of range at bpos {bpos}: {p}"
                );
                model.update(bit);
            }
        }
    }

    /// After a repeating pattern, the slot for history "abc" remembers 'd'.
    #[test]
    fn prediction_table_updates() {
        let mut model = IndirectModel::new();
        for &byte in b"abcdabcdabcd" {
            for (bpos, c0, bit) in bit_steps(byte) {
                let previous = model.c1;
                let _ = model.predict(c0, bpos, previous);
                model.update(bit);
            }
        }
        let idx = indirect_hash(b'c', b'b', b'a') as usize & PRED_TABLE_MASK;
        assert_eq!(
            model.prediction_table[idx], b'd',
            "prediction table should predict 'd' after 'abc'"
        );
    }

    /// Two models fed the same input must agree on every prediction.
    #[test]
    fn deterministic() {
        let mut m1 = IndirectModel::new();
        let mut m2 = IndirectModel::new();
        for &byte in b"test determinism of indirect model" {
            for (bpos, c0, bit) in bit_steps(byte) {
                let (prev1, prev2) = (m1.c1, m2.c1);
                let p1 = m1.predict(c0, bpos, prev1);
                let p2 = m2.predict(c0, bpos, prev2);
                assert_eq!(p1, p2, "models diverged at bpos {bpos}");
                m1.update(bit);
                m2.update(bit);
            }
        }
    }
}