use crate::mixer::logistic::{squash, stretch};
/// Number of model predictions fed into the mixer for every bit.
pub const NUM_MODELS: usize = 28;

/// Number of weight sets in each context tier.
///
/// Every context function reduces its hash with `& (SETS - 1)`, which only selects all
/// sets (and only stays in bounds as intended) when the count is a power of two; the
/// `const` assertions below turn a bad edit into a compile error.
const FINE_SETS: usize = 1 << 16;
const MEDIUM_SETS: usize = 1 << 14;
const COARSE_SETS: usize = 1 << 12;

const _: () = assert!(FINE_SETS.is_power_of_two());
const _: () = assert!(MEDIUM_SETS.is_power_of_two());
const _: () = assert!(COARSE_SETS.is_power_of_two());

/// Fixed-point scale of the weights: each weight/input dot product is divided by this
/// before the tiers are blended.
const W_SCALE: i32 = 4096;

/// Starting weight for each model input; every weight set of every tier begins here.
const INITIAL_WEIGHTS: [i32; NUM_MODELS] = [
    200, 300, 60, 350, 60, 450, 60, 450, 60, 450,
    60, 300, 60, 250, 60, 200, 60, 180, 60, 300,
    250, 250, 200, 250, 200, 50, 30, 150,
];

/// Per-tier learning-rate numerators; `DualMixer::update` multiplies them by
/// `input * error` and shifts the product right by 16.
const FINE_LR: i32 = 2;
const MEDIUM_LR: i32 = 3;
const COARSE_LR: i32 = 4;
/// Context-selected linear mixer over `NUM_MODELS` model inputs.
///
/// Despite the name it keeps three tiers of weight sets (fine, medium, coarse). A
/// prediction picks one set from each tier and blends the three results; the fields
/// below remember what that prediction used so the following update can train it.
pub struct DualMixer {
    /// Fine-tier weight sets, one `[i32; NUM_MODELS]` per context slot.
    fine_weights: Vec<[i32; NUM_MODELS]>,
    /// Fine-tier slot chosen by the most recent prediction.
    last_fine_ctx: usize,
    /// Medium-tier weight sets.
    medium_weights: Vec<[i32; NUM_MODELS]>,
    /// Medium-tier slot chosen by the most recent prediction.
    last_medium_ctx: usize,
    /// Coarse-tier weight sets.
    coarse_weights: Vec<[i32; NUM_MODELS]>,
    /// Coarse-tier slot chosen by the most recent prediction.
    last_coarse_ctx: usize,
    /// Inputs of the most recent prediction after `stretch`, reused when training.
    last_d: [i32; NUM_MODELS],
    /// Value returned by the most recent prediction; the training error is measured
    /// against it.
    last_p: u32,
}
impl DualMixer {
pub fn new() -> Self {
DualMixer {
fine_weights: vec![INITIAL_WEIGHTS; FINE_SETS],
medium_weights: vec![INITIAL_WEIGHTS; MEDIUM_SETS],
coarse_weights: vec![INITIAL_WEIGHTS; COARSE_SETS],
last_d: [0; NUM_MODELS],
last_fine_ctx: 0,
last_medium_ctx: 0,
last_coarse_ctx: 0,
last_p: 2048,
}
}
#[inline(always)]
#[allow(clippy::needless_range_loop)]
#[allow(clippy::too_many_arguments)]
pub fn predict(
&mut self,
predictions: &[u32; NUM_MODELS],
c0: u32,
c1: u8,
bpos: u8,
byte_class: u8,
match_len_q: u8,
run_q: u8,
_xml_state: u8,
) -> u32 {
for i in 0..NUM_MODELS {
self.last_d[i] = stretch(predictions[i]);
}
self.last_fine_ctx = fine_context(c0, c1, bpos, byte_class, match_len_q, run_q);
self.last_medium_ctx = medium_context(c0, c1, bpos, run_q, match_len_q);
self.last_coarse_ctx = coarse_context(c0, bpos);
let fw = &self.fine_weights[self.last_fine_ctx];
let mw = &self.medium_weights[self.last_medium_ctx];
let cw = &self.coarse_weights[self.last_coarse_ctx];
let mut fine_sum: i64 = 0;
let mut medium_sum: i64 = 0;
let mut coarse_sum: i64 = 0;
for i in 0..NUM_MODELS {
let d = self.last_d[i] as i64;
fine_sum += fw[i] as i64 * d;
medium_sum += mw[i] as i64 * d;
coarse_sum += cw[i] as i64 * d;
}
let fine_d = (fine_sum / W_SCALE as i64) as i32;
let medium_d = (medium_sum / W_SCALE as i64) as i32;
let coarse_d = (coarse_sum / W_SCALE as i64) as i32;
let blended_d = (fine_d as i64 * 5 + medium_d as i64 * 3 + coarse_d as i64 * 2) / 10;
let p = squash(blended_d as i32).clamp(1, 4095);
self.last_p = p;
p
}
#[inline(always)]
#[allow(clippy::needless_range_loop)]
pub fn update(&mut self, bit: u8) {
let error = (bit as i32) * 4096 - self.last_p as i32;
let fw = &mut self.fine_weights[self.last_fine_ctx];
for i in 0..NUM_MODELS {
let delta = (FINE_LR as i64 * self.last_d[i] as i64 * error as i64) >> 16;
fw[i] = (fw[i] as i64 + delta).clamp(-32768, 32767) as i32;
}
let mw = &mut self.medium_weights[self.last_medium_ctx];
for i in 0..NUM_MODELS {
let delta = (MEDIUM_LR as i64 * self.last_d[i] as i64 * error as i64) >> 16;
mw[i] = (mw[i] as i64 + delta).clamp(-32768, 32767) as i32;
}
let cw = &mut self.coarse_weights[self.last_coarse_ctx];
for i in 0..NUM_MODELS {
let delta = (COARSE_LR as i64 * self.last_d[i] as i64 * error as i64) >> 16;
cw[i] = (cw[i] as i64 + delta).clamp(-32768, 32767) as i32;
}
}
}
impl Default for DualMixer {
fn default() -> Self {
Self::new()
}
}
/// Buckets a byte into a coarse character class.
///
/// | class | bytes                                   |
/// |-------|-----------------------------------------|
/// | 0     | control bytes `0x00..=0x1F`             |
/// | 1     | space                                   |
/// | 2     | ASCII digits                            |
/// | 3     | ASCII uppercase                         |
/// | 4     | ASCII lowercase                         |
/// | 5     | ASCII punctuation                       |
/// | 6     | `0x80..=0x9F`                           |
/// | 7     | `0xA0..=0xBF`                           |
/// | 8     | `0xC0..=0xDF`                           |
/// | 9     | `0xE0..=0xFE`                           |
/// | 10    | `0xFF`                                  |
/// | 11    | anything else (only `0x7F`, DEL)        |
#[inline]
pub fn byte_class(b: u8) -> u8 {
    match b {
        0x00..=0x1F => 0,
        b' ' => 1,
        _ if b.is_ascii_digit() => 2,
        _ if b.is_ascii_uppercase() => 3,
        _ if b.is_ascii_lowercase() => 4,
        // `is_ascii_punctuation` is exactly `!..=/`, `:..=@`, `[..=`` ` ``, `{..=~`.
        _ if b.is_ascii_punctuation() => 5,
        0x80..=0x9F => 6,
        0xA0..=0xBF => 7,
        0xC0..=0xDF => 8,
        0xE0..=0xFE => 9,
        0xFF => 10,
        // Only 0x7F (DEL) is left by this point.
        _ => 11,
    }
}
/// Hashes the fine-tier mixer context into an index below `FINE_SETS`.
///
/// Starting from the low 8 bits of `c0`, it folds in, in order: the high nibble of `c1`,
/// `bpos`, the low 3 bits of `bclass`, the low 2 bits of `match_q` and the low 2 bits of
/// `run_q`, using `h = h * 97 + field`, and finally keeps the bits below `FINE_SETS`.
///
/// Both the multiply and the add wrap. The previous checked `+` after `wrapping_mul` could
/// panic in debug builds on 32-bit targets, where the product wraps. Only the low bits
/// survive the mask, so the index is the same on 32- and 64-bit targets while
/// `FINE_SETS` fits in 32 bits.
#[inline]
fn fine_context(c0: u32, c1: u8, bpos: u8, bclass: u8, match_q: u8, run_q: u8) -> usize {
    // NOTE(review): `byte_class` can return up to 11, so if callers pass its result here
    // `& 0x7` folds classes 8..=11 onto 0..=3. Kept to preserve the existing mapping.
    let fields = [
        c1 as usize >> 4,
        bpos as usize,
        bclass as usize & 0x7,
        match_q as usize & 0x3,
        run_q as usize & 0x3,
    ];
    let h = fields
        .iter()
        .fold(c0 as usize & 0xFF, |h, &f| h.wrapping_mul(97).wrapping_add(f));
    h & (FINE_SETS - 1)
}
/// Hashes the medium-tier mixer context into an index below `MEDIUM_SETS`.
///
/// Starting from the low 8 bits of `c0`, it folds in, in order: the high nibble of `c1`,
/// `bpos`, `byte_class(c1)` (unmasked), the low 2 bits of `run_q` and the low 2 bits of
/// `match_q`, using `h = h * 67 + field`, and finally keeps the bits below `MEDIUM_SETS`.
///
/// Both the multiply and the add wrap. The intermediate hash exceeds 2^32, and the
/// previous checked `+` after `wrapping_mul` could panic in debug builds on 32-bit
/// targets. Only the low bits survive the mask, so the index is unchanged wherever the
/// old code did not panic.
#[inline]
fn medium_context(c0: u32, c1: u8, bpos: u8, run_q: u8, match_q: u8) -> usize {
    let fields = [
        c1 as usize >> 4,
        bpos as usize,
        byte_class(c1) as usize,
        run_q as usize & 0x3,
        match_q as usize & 0x3,
    ];
    let h = fields
        .iter()
        .fold(c0 as usize & 0xFF, |h, &f| h.wrapping_mul(67).wrapping_add(f));
    h & (MEDIUM_SETS - 1)
}
/// Builds the coarse-tier mixer context: the low 8 bits of `c0` with `bpos` placed in
/// the bits directly above them, masked to fewer than `COARSE_SETS` entries.
#[inline]
fn coarse_context(c0: u32, bpos: u8) -> usize {
    let partial_byte = (c0 & 0xFF) as usize;
    let bit_slot = usize::from(bpos) << 8;
    (bit_slot | partial_byte) & (COARSE_SETS - 1)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Weight of model 0 in the fine, medium and coarse sets selected by the last predict.
    fn model0_weights(m: &DualMixer) -> [i32; 3] {
        [
            m.fine_weights[m.last_fine_ctx][0],
            m.medium_weights[m.last_medium_ctx][0],
            m.coarse_weights[m.last_coarse_ctx][0],
        ]
    }

    #[test]
    fn initial_prediction_near_balanced() {
        let mut mixer = DualMixer::new();
        let preds = [2048u32; NUM_MODELS];
        let p = mixer.predict(&preds, 1, 0, 0, 0, 0, 0, 0);
        assert!(
            (1900..=2100).contains(&p),
            "initial prediction should be near 2048, got {p}"
        );
    }

    #[test]
    fn prediction_in_range() {
        let mut mixer = DualMixer::new();
        let mut preds = [2048u32; NUM_MODELS];
        preds[0] = 100;
        preds[1] = 4000;
        preds[4] = 3000;
        preds[7] = 500;
        let p = mixer.predict(&preds, 128, b'a', 3, 4, 1, 0, 0);
        assert!((1..=4095).contains(&p), "prediction out of range: {p}");
    }

    #[test]
    fn update_changes_weights() {
        // Model 0 gets a confident input so its stretched value is clearly non-zero;
        // with every input at 2048 the inputs may stretch to ~0 and training could
        // legitimately leave the weights untouched.
        let mut preds = [2048u32; NUM_MODELS];
        preds[0] = 4000;
        let tiers = ["fine", "medium", "coarse"];

        // bit = 1: positive error, so model 0's weight should grow in every tier.
        let mut mixer = DualMixer::new();
        mixer.predict(&preds, 1, 0, 0, 0, 0, 0, 0);
        let before = model0_weights(&mixer);
        mixer.update(1);
        let after = model0_weights(&mixer);
        for ((tier, b), a) in tiers.iter().zip(before).zip(after) {
            assert!(a > b, "{tier} weight should grow after bit 1: {b} -> {a}");
        }

        // bit = 0: negative error, so model 0's weight should shrink in every tier.
        let mut mixer = DualMixer::new();
        mixer.predict(&preds, 1, 0, 0, 0, 0, 0, 0);
        let before = model0_weights(&mixer);
        mixer.update(0);
        let after = model0_weights(&mixer);
        for ((tier, b), a) in tiers.iter().zip(before).zip(after) {
            assert!(a < b, "{tier} weight should shrink after bit 0: {b} -> {a}");
        }
    }

    #[test]
    fn mixer_adapts_to_biased_input() {
        let mut mixer = DualMixer::new();
        let mut preds = [2048u32; NUM_MODELS];
        preds[0] = 3500;
        for _ in 0..100 {
            mixer.predict(&preds, 1, 0, 0, 0, 0, 0, 0);
            mixer.update(1);
        }
        let p = mixer.predict(&preds, 1, 0, 0, 0, 0, 0, 0);
        assert!(p > 2500, "mixer should have learned to trust model 0: {p}");
    }

    #[test]
    fn byte_class_categories() {
        assert_eq!(byte_class(0), 0);
        assert_eq!(byte_class(b' '), 1);
        assert_eq!(byte_class(b'5'), 2);
        assert_eq!(byte_class(b'A'), 3);
        assert_eq!(byte_class(b'z'), 4);
        assert_eq!(byte_class(b'.'), 5);
        assert_eq!(byte_class(0x80), 6);
        assert_eq!(byte_class(0x90), 6);
        assert_eq!(byte_class(0xA0), 7);
        assert_eq!(byte_class(0xC0), 8);
        assert_eq!(byte_class(0xE0), 9);
        assert_eq!(byte_class(0xFF), 10);
    }

    #[test]
    fn fine_context_in_range() {
        for c0 in [1u32, 128, 255] {
            for bpos in 0..8u8 {
                let ctx = fine_context(c0, 0xFF, bpos, 7, 3, 3);
                assert!(ctx < FINE_SETS, "fine context out of range: {ctx}");
            }
        }
    }

    #[test]
    fn medium_context_in_range() {
        for c0 in [1u32, 128, 255] {
            for bpos in 0..8u8 {
                let ctx = medium_context(c0, 0xFF, bpos, 3, 3);
                assert!(ctx < MEDIUM_SETS, "medium context out of range: {ctx}");
            }
        }
    }

    #[test]
    fn coarse_context_in_range() {
        for c0 in [1u32, 128, 255] {
            for bpos in 0..8u8 {
                let ctx = coarse_context(c0, bpos);
                assert!(ctx < COARSE_SETS, "coarse context out of range: {ctx}");
            }
        }
    }
}