use crate::state::context_map::ContextMap;
use crate::state::state_map::StateMap;
use crate::state::state_table::StateTable;
/// Word-level context model.
///
/// Tracks the current (ASCII case-folded) word, hashes it together with the partial byte,
/// a token carried over from an earlier word boundary and the word length, and turns the
/// resulting context into a bit prediction via a [`ContextMap`] / [`StateMap`] pair.
pub struct WordModel {
    // ---- model components ----
    /// Context states, read and written by context hash.
    cmap: ContextMap,
    /// Maps a context state to a prediction; trained on every coded bit.
    smap: StateMap,

    // ---- word tracker ----
    /// Case-folded hash of the current word, or the last separator byte while outside a word.
    word_hash: u32,
    /// Number of word bytes seen in the current word (0 while outside a word).
    word_len: u32,
    /// Whether the most recently tracked byte was a word byte (alphanumeric or `_`).
    in_word: bool,
    /// Value of `word_hash` saved by `update_word_state` at a word boundary; mixed into every
    /// prediction hash.
    prev_word_hash: u32,
    /// The value `prev_word_hash` held before it was last replaced. Not read by `predict` or
    /// `update`.
    prev2_word_hash: u32,

    // ---- pending-update state ----
    /// Context hash selected by the most recent `predict`, consumed by `update`.
    last_hash: u32,
    /// Context state selected by the most recent `predict`, consumed by `update`.
    last_state: u8,
}
impl WordModel {
    /// Multiplier for the multiply-then-xor (FNV-1 style) hashing used throughout this model.
    /// `0x0100_0193` (16777619) is the 32-bit FNV prime.
    const HASH_MUL: u32 = 0x0100_0193;

    /// Creates a word model whose context map is constructed with size `1 << 24`.
    pub fn new() -> Self {
        Self::with_size(1 << 24)
    }

    /// Creates a word model whose context map is constructed with size `cmap_size`.
    ///
    /// All hashes and counters start at zero, and the tracker starts outside of a word.
    pub fn with_size(cmap_size: usize) -> Self {
        WordModel {
            cmap: ContextMap::new(cmap_size),
            smap: StateMap::new(),
            word_hash: 0,
            prev_word_hash: 0,
            prev2_word_hash: 0,
            in_word: false,
            word_len: 0,
            last_state: 0,
            last_hash: 0,
        }
    }

    /// Predicts the next bit from the current word context.
    ///
    /// * `c0` - partial-byte context; only its low 8 bits enter the hash.
    /// * `bpos` - bit position within the current byte. At `0` the word tracker is first
    ///   advanced with `c1`.
    /// * `c1` - byte fed to the word tracker at `bpos == 0` (presumably the byte that was just
    ///   completed - confirm against the caller).
    ///
    /// The context hash combines, in order:
    /// * the current word hash,
    /// * `c0 & 0xFF`,
    /// * the previous-word hash,
    /// * the word length capped at 7.
    ///
    /// The context-map state stored under that hash is remembered for [`Self::update`].
    ///
    /// Returns the value produced by `StateMap::predict` for that state.
    #[inline]
    pub fn predict(&mut self, c0: u32, bpos: u8, c1: u8) -> u32 {
        if bpos == 0 {
            self.update_word_state(c1);
        }
        let mut h = self.word_hash;
        h = h.wrapping_mul(Self::HASH_MUL) ^ (c0 & 0xFF);
        h = h.wrapping_mul(Self::HASH_MUL) ^ self.prev_word_hash;
        // Cap the length so long words share a context instead of each length being distinct.
        let len_q = self.word_len.min(7);
        h = h.wrapping_mul(Self::HASH_MUL) ^ len_q;

        let state = self.cmap.get(h);
        self.last_state = state;
        self.last_hash = h;
        self.smap.predict(state)
    }

    /// Trains the model on the actual `bit`.
    ///
    /// Uses the state and hash recorded by the most recent [`Self::predict`] call:
    /// * the `StateMap` is updated for that state;
    /// * the successor state from `StateTable::next` is written back to the context map under
    ///   the same hash.
    #[inline]
    pub fn update(&mut self, bit: u8) {
        self.smap.update(self.last_state, bit);
        let new_state = StateTable::next(self.last_state, bit);
        self.cmap.set(self.last_hash, new_state);
    }

    /// Advances the word tracker by one byte.
    ///
    /// Alphanumeric bytes and `_` extend the current word (ASCII case-folded). Any other byte
    /// ends the word and becomes the context itself.
    ///
    /// When a word ends, its hash is shifted into `prev_word_hash` (and the older value into
    /// `prev2_word_hash`) *before* `word_hash` is replaced by the separator byte. This way the
    /// previous-word context holds the previous word, not the last separator seen.
    fn update_word_state(&mut self, c1: u8) {
        let is_word_char = c1.is_ascii_alphanumeric() || c1 == b'_';
        if is_word_char {
            if !self.in_word {
                // First byte of a new word: start its hash from scratch.
                self.word_hash = 0;
                self.word_len = 0;
                self.in_word = true;
            }
            let ch = c1.to_ascii_lowercase();
            self.word_hash = self.word_hash.wrapping_mul(Self::HASH_MUL) ^ u32::from(ch);
            // Saturate so an absurdly long run of word bytes cannot overflow (a panic in debug).
            self.word_len = self.word_len.saturating_add(1);
        } else {
            if self.in_word {
                // Word just ended: remember it before `word_hash` is overwritten below.
                self.prev2_word_hash = self.prev_word_hash;
                self.prev_word_hash = self.word_hash;
                self.in_word = false;
            }
            // Between words, the separator byte itself serves as the context.
            self.word_hash = u32::from(c1);
            self.word_len = 0;
        }
    }

    /// Returns whether the most recently tracked byte was a word byte (alphanumeric or `_`).
    #[inline]
    pub fn in_word(&self) -> bool {
        self.in_word
    }

    /// Buckets the current word length: `0 -> 0`, `1..=2 -> 1`, `3..=6 -> 2`, `7+ -> 3`.
    #[inline]
    pub fn word_len_quantized(&self) -> u8 {
        match self.word_len {
            0 => 0,
            1..=2 => 1,
            3..=6 => 2,
            _ => 3,
        }
    }
}
impl Default for WordModel {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives `wm` over `bytes` bit by bit, MSB first.
    ///
    /// Each byte is passed as `c1` at bit position 0 (and `0` elsewhere), with `c0` fixed at 1.
    /// Every prediction is followed by an update with the actual bit.
    fn feed(wm: &mut WordModel, bytes: &[u8]) {
        for &ch in bytes {
            for bpos in 0..8u8 {
                let bit = (ch >> (7 - bpos)) & 1;
                wm.predict(1, bpos, if bpos == 0 { ch } else { 0 });
                wm.update(bit);
            }
        }
    }

    #[test]
    fn initial_prediction_balanced() {
        let mut wm = WordModel::new();
        let p = wm.predict(1, 0, 0);
        assert_eq!(p, 2048);
    }

    #[test]
    fn predictions_in_range() {
        let mut wm = WordModel::new();
        for c in 0..=255u8 {
            let p = wm.predict(1, 0, c);
            assert!((1..=4095).contains(&p));
            wm.update(0);
        }
    }

    #[test]
    fn word_context_changes() {
        let mut wm = WordModel::new();
        feed(&mut wm, b"hello");
        let p1 = wm.predict(1, 0, b'o');

        let mut wm2 = WordModel::new();
        feed(&mut wm2, b"world");
        let p2 = wm2.predict(1, 0, b'd');

        assert!((1..=4095).contains(&p1));
        assert!((1..=4095).contains(&p2));
    }

    #[test]
    fn tracks_word_boundaries_and_length() {
        let mut wm = WordModel::new();
        assert!(!wm.in_word());

        feed(&mut wm, b"hello");
        assert!(wm.in_word());
        assert_eq!(wm.word_len_quantized(), 2);

        feed(&mut wm, b" ");
        assert!(!wm.in_word());
        assert_eq!(wm.word_len_quantized(), 0);
    }

    #[test]
    fn underscore_and_digits_are_word_chars() {
        let mut wm = WordModel::new();
        feed(&mut wm, b"a_1");
        assert!(wm.in_word());
        assert_eq!(wm.word_len_quantized(), 2);
    }

    #[test]
    fn word_len_quantization_buckets() {
        let mut wm = WordModel::new();
        for (len, bucket) in [(0, 0), (1, 1), (2, 1), (3, 2), (6, 2), (7, 3), (100, 3)] {
            wm.word_len = len;
            assert_eq!(wm.word_len_quantized(), bucket, "word_len = {}", len);
        }
    }

    #[test]
    fn word_hash_is_case_insensitive() {
        let mut lower = WordModel::new();
        let mut mixed = WordModel::new();
        feed(&mut lower, b"hello");
        feed(&mut mixed, b"HeLLo");
        assert_eq!(lower.word_hash, mixed.word_hash);
    }
}