use crate::state::context_map::ContextMap;
use crate::state::state_map::StateMap;
use crate::state::state_table::StateTable;
/// 32-bit FNV prime; `gap2_hash` / `gap3_hash` multiply by it after each XOR (FNV-1a order).
const FNV_PRIME: u32 = 0x0100_0193;
/// 32-bit FNV offset basis; each gap hash salts it with its own constant before mixing.
const FNV_OFFSET: u32 = 0x811C_9DC5;
/// Two sparse context models whose predictions are averaged.
///
/// The "gap 2" context hashes the byte two positions back (`c2`) and skips `c1`.
/// The "gap 3" context hashes `c3` and `c1` and skips `c2`.
/// Each context keeps the state and hash looked up by the last `predict` so that
/// `update` can train exactly the slots that produced the prediction.
pub struct SparseModel {
    // --- gap-2 context -------------------------------------------------
    /// Hash-indexed states for the gap-2 context.
    cmap_gap2: ContextMap,
    /// Maps a gap-2 state to a prediction.
    smap_gap2: StateMap,
    /// Hash used by the most recent `predict` for the gap-2 context.
    last_hash_gap2: u32,
    /// State read by the most recent `predict` for the gap-2 context.
    last_state_gap2: u8,
    // --- gap-3 context -------------------------------------------------
    /// Hash-indexed states for the gap-3 context.
    cmap_gap3: ContextMap,
    /// Maps a gap-3 state to a prediction.
    smap_gap3: StateMap,
    /// Hash used by the most recent `predict` for the gap-3 context.
    last_hash_gap3: u32,
    /// State read by the most recent `predict` for the gap-3 context.
    last_state_gap3: u8,
}
impl SparseModel {
    /// Size argument handed to `ContextMap::new` for both maps by [`SparseModel::new`].
    const DEFAULT_CMAP_SIZE: usize = 1 << 23;

    /// Creates a model using the default context-map size (`1 << 23`).
    pub fn new() -> Self {
        Self::with_size(Self::DEFAULT_CMAP_SIZE)
    }

    /// Creates a model whose two context maps are each built with
    /// `ContextMap::new(cmap_size)`.
    ///
    /// The recorded hashes and states start at zero, so an `update` issued
    /// before any `predict` trains the zero state at hash zero.
    pub fn with_size(cmap_size: usize) -> Self {
        Self {
            last_hash_gap2: 0,
            last_state_gap2: 0,
            last_hash_gap3: 0,
            last_state_gap3: 0,
            // Constructors are invoked in the same order as before: cmap2, smap2, cmap3, smap3.
            cmap_gap2: ContextMap::new(cmap_size),
            smap_gap2: StateMap::new(),
            cmap_gap3: ContextMap::new(cmap_size),
            smap_gap3: StateMap::new(),
        }
    }

    /// Looks up both sparse contexts and returns the integer mean of their
    /// `StateMap` predictions, clamped to `1..=4095`.
    ///
    /// Only the low 8 bits of `c0` are hashed (NOTE(review): `c0` is presumably
    /// the partially coded current byte — confirm with the caller). The looked-up
    /// hashes and states are remembered for the following [`update`](Self::update).
    #[inline]
    pub fn predict(&mut self, c0: u32, c1: u8, c2: u8, c3: u8) -> u32 {
        // Gap-2 context: c2 plus the low byte of c0; c1 is deliberately skipped.
        self.last_hash_gap2 = gap2_hash(c2, c0);
        self.last_state_gap2 = self.cmap_gap2.get(self.last_hash_gap2);
        let from_gap2 = self.smap_gap2.predict(self.last_state_gap2);

        // Gap-3 context: c3 and c1 plus the low byte of c0; c2 is deliberately skipped.
        self.last_hash_gap3 = gap3_hash(c3, c1, c0);
        self.last_state_gap3 = self.cmap_gap3.get(self.last_hash_gap3);
        let from_gap3 = self.smap_gap3.predict(self.last_state_gap3);

        let mean = (from_gap2 + from_gap3) / 2;
        mean.clamp(1, 4095)
    }

    /// Trains both sparse contexts on the coded `bit`.
    ///
    /// For each context, the `StateMap` is updated on the state recorded by the
    /// last `predict`. Then `StateTable::next(state, bit)` is stored back at the
    /// recorded hash via `ContextMap::set`.
    #[inline]
    pub fn update(&mut self, bit: u8) {
        let state2 = self.last_state_gap2;
        self.smap_gap2.update(state2, bit);
        self.cmap_gap2.set(self.last_hash_gap2, StateTable::next(state2, bit));

        let state3 = self.last_state_gap3;
        self.smap_gap3.update(state3, bit);
        self.cmap_gap3.set(self.last_hash_gap3, StateTable::next(state3, bit));
    }
}
impl Default for SparseModel {
fn default() -> Self {
Self::new()
}
}
/// FNV-1a hash of the gap-2 context: `c2` followed by the low byte of `c0`.
///
/// The offset basis is salted with `0xDEAD` so this context's hashes differ
/// from the gap-3 context's even for coinciding byte values.
#[inline]
fn gap2_hash(c2: u8, c0: u32) -> u32 {
    let seed = FNV_OFFSET ^ 0xDEAD;
    [u32::from(c2), c0 & 0xFF]
        .iter()
        .fold(seed, |hash, &byte| (hash ^ byte).wrapping_mul(FNV_PRIME))
}
/// FNV-1a hash of the gap-3 context: `c3`, then `c1`, then the low byte of `c0`.
///
/// The offset basis is salted with `0xBEEF` so this context's hashes differ
/// from the gap-2 context's even for coinciding byte values.
#[inline]
fn gap3_hash(c3: u8, c1: u8, c0: u32) -> u32 {
    let mut hash = FNV_OFFSET ^ 0xBEEF;
    for &byte in &[u32::from(c3), u32::from(c1), c0 & 0xFF] {
        hash = (hash ^ byte).wrapping_mul(FNV_PRIME);
    }
    hash
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built model must predict exactly 2048 before any training.
    #[test]
    fn initial_prediction_balanced() {
        let mut model = SparseModel::new();
        assert_eq!(model.predict(1, 0, 0, 0), 2048);
    }

    /// Predictions stay inside `1..=4095` while training on an alternating bit stream.
    #[test]
    fn predictions_in_range() {
        let mut model = SparseModel::new();
        for step in 0..50u32 {
            let (c1, c2, c3) = (step as u8, (step + 1) as u8, (step + 2) as u8);
            let p = model.predict(1, c1, c2, c3);
            assert!((1..=4095).contains(&p));
            let bit = (step % 2) as u8;
            model.update(bit);
        }
    }
}