use crate::model::cm_model::ContextModel;
/// Number of `ContextModel` instances held by a `NeuralModel`; also the divisor
/// used when the individual predictions are averaged.
const N_MODELS: usize = 4;
/// A small ensemble of `ContextModel`s.
///
/// Each sub-model is queried with its own hashed context and the individual
/// predictions are combined by a plain arithmetic mean (see `predict`).
pub struct NeuralModel {
/// The sub-models, in the order they are queried by `predict`
/// (`N_MODELS` entries when built through `with_size`).
models: Vec<ContextModel>,
}
impl NeuralModel {
    /// Size argument passed to every `ContextModel` by [`NeuralModel::new`].
    const DEFAULT_SIZE: usize = 1 << 21;

    /// Builds a model whose sub-models all use the default size (`1 << 21`).
    pub fn new() -> Self {
        Self::with_size(Self::DEFAULT_SIZE)
    }

    /// Builds a model with `N_MODELS` sub-models, each created with
    /// `ContextModel::new(size)`.
    pub fn with_size(size: usize) -> Self {
        let models = (0..N_MODELS).map(|_| ContextModel::new(size)).collect();
        NeuralModel { models }
    }

    /// Returns the mean of the sub-model predictions, clamped to `1..=4095`.
    ///
    /// Every sub-model is queried with a different hash of the inputs:
    ///
    /// 0. `c0` and the high nibble of `c1`
    /// 1. `c0` and the character-class pair of (`c1`, `c2`)
    /// 2. `c0`, `c1`, the quantized `run_len` and `bpos`
    /// 3. `c0`, the low nibble of `c2`, `match_q` and `c3`
    ///
    /// NOTE(review): `c0` presumably holds the already-seen bits of the
    /// current byte and `c1`..`c3` the preceding bytes; the result looks like
    /// a 12-bit probability that the next bit is 1 — confirm against callers.
    #[inline]
    #[allow(clippy::too_many_arguments)]
    pub fn predict(
        &mut self,
        c0: u32,
        bpos: u8,
        c1: u8,
        c2: u8,
        c3: u8,
        run_len: u8,
        match_q: u8,
    ) -> u32 {
        // One hashed context per sub-model, listed in sub-model order. Each
        // hash is given its own seed constant.
        let contexts = [
            fhash3(c0, u32::from(c1 >> 4), 0xA1B2_C3D4, 0xDEAD_1001),
            fhash3(
                c0,
                u32::from(byte_class_pair(c1, c2)),
                0xE5F6_0718,
                0xBEEF_2002,
            ),
            fhash4(
                c0,
                u32::from(c1),
                u32::from(quantize_run(run_len)),
                u32::from(bpos),
                0xCAFE_3003,
            ),
            fhash4(
                c0,
                u32::from(c2 & 0x0F),
                u32::from(match_q),
                u32::from(c3),
                0xFACE_4004,
            ),
        ];

        // Query the sub-models in order and accumulate their predictions.
        let total: u32 = self
            .models
            .iter_mut()
            .zip(contexts.iter())
            .map(|(model, &ctx)| model.predict(ctx))
            .sum();

        // Plain average, kept away from the extremes 0 and 4096.
        (total / N_MODELS as u32).clamp(1, 4095)
    }

    /// Forwards the observed `bit` to every sub-model.
    #[inline]
    pub fn update(&mut self, bit: u8) {
        self.models.iter_mut().for_each(|model| {
            model.update(bit);
        });
    }
}
impl Default for NeuralModel {
fn default() -> Self {
Self::new()
}
}
/// Packs the 6-way classes of two bytes into one value: `class(c1) * 6 + class(c2)`,
/// which lies in `0..=35`.
#[inline]
fn byte_class_pair(c1: u8, c2: u8) -> u8 {
    let first = byte_class_6(c1);
    let second = byte_class_6(c2);
    6 * first + second
}
/// Maps a byte to one of six coarse character classes:
///
/// * `0` — ASCII lowercase letter
/// * `1` — ASCII uppercase letter
/// * `2` — ASCII digit
/// * `3` — space or horizontal tab
/// * `4` — line feed or carriage return
/// * `5` — anything else
#[inline]
fn byte_class_6(b: u8) -> u8 {
    if b.is_ascii_lowercase() {
        0
    } else if b.is_ascii_uppercase() {
        1
    } else if b.is_ascii_digit() {
        2
    } else if b == b' ' || b == b'\t' {
        3
    } else if b == b'\n' || b == b'\r' {
        4
    } else {
        5
    }
}
/// Buckets a run length into four levels:
/// `0..=1 -> 0`, `2..=3 -> 1`, `4..=8 -> 2`, and everything larger `-> 3`.
#[inline]
fn quantize_run(run_len: u8) -> u8 {
    if run_len < 2 {
        0
    } else if run_len < 4 {
        1
    } else if run_len < 9 {
        2
    } else {
        3
    }
}
/// Hashes three 32-bit words into one, starting from `seed`.
///
/// Each word is XOR-ed into the state and the state is then multiplied
/// (wrapping) by `0x0100_0193`, the 32-bit FNV prime. Words are mixed in the
/// order `a`, `b`, `c`.
#[inline]
fn fhash3(a: u32, b: u32, c: u32, seed: u32) -> u32 {
    [a, b, c]
        .iter()
        .fold(seed, |state, &word| (state ^ word).wrapping_mul(0x0100_0193))
}
/// Hashes four 32-bit words into one, starting from `seed`.
///
/// Each word is XOR-ed into the state and the state is then multiplied
/// (wrapping) by `0x0100_0193`, the 32-bit FNV prime. Words are mixed in the
/// order `a`, `b`, `c`, `d`.
#[inline]
fn fhash4(a: u32, b: u32, c: u32, d: u32, seed: u32) -> u32 {
    [a, b, c, d]
        .iter()
        .fold(seed, |state, &word| (state ^ word).wrapping_mul(0x0100_0193))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The very first prediction of a fresh model must land in `1800..=2200`.
    #[test]
    fn initial_prediction_near_half() {
        let p = NeuralModel::new().predict(1, 0, 0, 0, 0, 0, 0);
        assert!(
            (1800..=2200).contains(&p),
            "initial prediction should be near 2048, got {p}"
        );
    }

    /// Predictions stay inside `1..=4095` while the model is being trained
    /// with a stream of 1-bits across several `c1` / `bpos` combinations.
    #[test]
    fn prediction_always_in_range() {
        let mut model = NeuralModel::new();
        for &c1 in &[0u8, 65, 128, 255] {
            for bpos in 0u8..8 {
                let p = model.predict(1, bpos, c1, 0, 0, 0, 0);
                assert!((1..=4095).contains(&p), "prediction out of range: {p}");
                model.update(1);
            }
        }
    }

    /// Two models fed the same inputs and the same bits must agree on every
    /// prediction.
    #[test]
    fn deterministic() {
        let mut left = NeuralModel::new();
        let mut right = NeuralModel::new();
        for &byte in b"Hello World".iter() {
            for bpos in 0u8..8 {
                let from_left = left.predict(1, bpos, byte, 0, 0, 0, 0);
                let from_right = right.predict(1, bpos, byte, 0, 0, 0, 0);
                assert_eq!(from_left, from_right, "neural models diverged");
                // Feed the byte's bits most-significant first.
                let bit = (byte >> (7 - bpos)) & 1;
                left.update(bit);
                right.update(bit);
            }
        }
    }

    /// After 200 predict/update rounds with bit 1 in a fixed context, the
    /// prediction for that context must be higher than it was initially.
    #[test]
    fn adapts_to_data() {
        let mut model = NeuralModel::new();

        // Round 0: remember the untrained prediction.
        let first_p = model.predict(1, 0, b'A', b'B', b'C', 1, 0);
        model.update(1);

        // Rounds 1..200: keep training on the same context.
        for _ in 1..200 {
            let _ = model.predict(1, 0, b'A', b'B', b'C', 1, 0);
            model.update(1);
        }

        let final_p = model.predict(1, 0, b'A', b'B', b'C', 1, 0);
        assert!(
            final_p > first_p,
            "model should adapt: first={first_p}, final={final_p}"
        );
    }

    /// One representative byte per character class.
    #[test]
    fn byte_class_categories() {
        let samples: [(u8, u8); 6] = [
            (b'a', 0),
            (b'Z', 1),
            (b'5', 2),
            (b' ', 3),
            (b'\n', 4),
            (b'.', 5),
        ];
        for &(byte, class) in samples.iter() {
            assert_eq!(byte_class_6(byte), class);
        }
    }
}