/// Number of quantisation bins along each input-probability axis.
const N_BINS: usize = 64;
/// Total number of cells in the two-dimensional blend table (`N_BINS * N_BINS`).
const TABLE_SIZE: usize = N_BINS * N_BINS;

/// Adaptive mixer that combines two probability estimates, `p_cm` and `p_gru`.
///
/// Probabilities live on a `0..=4095` scale; everything stored or returned is
/// kept inside `1..=4095`. Each prediction is the midpoint of
///
/// * a learned value looked up in a `N_BINS x N_BINS` table indexed by the
///   quantised inputs, and
/// * a fixed linear blend of the raw inputs weighted by `gru_weight / 256`.
///
/// [`MetaMixer::update`] moves the table cell used by the latest
/// [`MetaMixer::blend`] call toward the observed bit.
#[derive(Debug, Clone)]
pub struct MetaMixer {
    /// Row-major table; the cell for a bin pair is `cm_bin * N_BINS + gru_bin`.
    table: Vec<u32>,
    /// Weight of the GRU input in 1/256 units (`0..=256`).
    gru_weight: u32,
    /// Table index chosen by the most recent `blend` call (0 before the first).
    last_idx: usize,
    /// Result of the most recent `blend` call (2048 before the first).
    last_p: u32,
    /// Right shift applied to the prediction error in `update`.
    lr_shift: u32,
}

impl MetaMixer {
    /// Smallest probability ever stored in the table or returned.
    const P_MIN: u32 = 1;
    /// Largest probability on the scale used by this mixer.
    const P_MAX: u32 = 4095;
    /// Fixed-point representation of a weight of 1.0.
    const WEIGHT_ONE: u32 = 256;
    /// Index of the last bin on either axis.
    const LAST_BIN: u32 = N_BINS as u32 - 1;

    /// Creates a mixer whose GRU input carries `gru_weight_pct` percent of the
    /// linear blend; values above 100 are treated as 100.
    ///
    /// Every table cell starts as the same weighted average applied to the
    /// representative values of its two bins, clamped to `1..=4095`.
    #[must_use]
    pub fn new(gru_weight_pct: u32) -> Self {
        // Saturate at 100 % before scaling so the multiplication cannot
        // overflow `u32` for large arguments. For every argument where
        // `pct * 256 / 100` is computable the capped form yields the same value.
        let gru_weight = gru_weight_pct.min(100) * Self::WEIGHT_ONE / 100;
        let cm_weight = Self::WEIGHT_ONE - gru_weight;

        // Representative probability of each bin, rounded to nearest.
        let centers: Vec<u32> = (0..=Self::LAST_BIN)
            .map(|bin| (bin * Self::P_MAX + Self::LAST_BIN / 2) / Self::LAST_BIN)
            .collect();

        // Outer loop is the CM axis so the layout matches `cm_bin * N_BINS + gru_bin`.
        // The products stay below 4095 * 256, well inside `u32`.
        let mut table = Vec::with_capacity(TABLE_SIZE);
        for &cm_center in &centers {
            table.extend(centers.iter().map(|&gru_center| {
                let avg = (cm_center * cm_weight + gru_center * gru_weight) / Self::WEIGHT_ONE;
                avg.clamp(Self::P_MIN, Self::P_MAX)
            }));
        }

        MetaMixer {
            table,
            gru_weight,
            last_idx: 0,
            last_p: 2048,
            lr_shift: 5,
        }
    }

    /// Maps a probability to its bin index; inputs above 4095 land in the last bin.
    #[inline(always)]
    fn bin_of(p: u32) -> usize {
        // After the cap the product is at most 4095 * 63 and the quotient at
        // most 63, so neither the multiplication nor the cast can lose data.
        (p.min(Self::P_MAX) * Self::LAST_BIN / Self::P_MAX) as usize
    }

    /// Blends the two input probabilities and returns the result in `1..=4095`.
    ///
    /// The chosen table cell and the result are remembered for
    /// [`MetaMixer::update`] and [`MetaMixer::last_prediction`].
    ///
    /// Inputs above 4095 are capped for the table lookup only; the linear term
    /// uses them as given, and the final result is clamped.
    #[inline(always)]
    pub fn blend(&mut self, p_cm: u32, p_gru: u32) -> u32 {
        let idx = Self::bin_of(p_cm) * N_BINS + Self::bin_of(p_gru);
        self.last_idx = idx;

        let table_p = u64::from(self.table[idx]);
        let one = u64::from(Self::WEIGHT_ONE);
        let w = u64::from(self.gru_weight);
        // 64-bit arithmetic: a full-range `u32` input times 256 does not fit in 32 bits.
        let direct = (u64::from(p_cm) * (one - w) + u64::from(p_gru) * w) / one;

        let blended = u64::midpoint(table_p, direct);
        // Clamped to 1..=4095 first, so the narrowing cast is lossless.
        self.last_p = blended.clamp(u64::from(Self::P_MIN), u64::from(Self::P_MAX)) as u32;
        self.last_p
    }

    /// Moves the table cell used by the latest `blend` toward the observed bit.
    ///
    /// The target is 4095 for a non-zero `bit` and 1 otherwise; the cell moves
    /// by `(target - old) >> lr_shift`. The shift is arithmetic, so the step
    /// rounds toward negative infinity: a cell within `2^lr_shift - 1` of 4095
    /// stops rising, while any cell above 1 still falls by at least one.
    #[inline(always)]
    pub fn update(&mut self, bit: u8) {
        let target: i32 = if bit != 0 { 4095 } else { 1 };
        let cell = &mut self.table[self.last_idx];
        // Table entries never exceed 4095, so they fit in `i32`.
        let old = *cell as i32;
        let delta = (target - old) >> self.lr_shift;
        *cell = (old + delta).clamp(1, 4095) as u32;
    }

    /// Returns the value produced by the most recent `blend` call, or 2048 if
    /// `blend` has not been called yet.
    #[must_use]
    pub fn last_prediction(&self) -> u32 {
        self.last_p
    }
}
impl Default for MetaMixer {
fn default() -> Self {
Self::new(5) }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Probability values swept along both axes by the range test.
    const SWEEP: [u32; 7] = [1, 100, 1000, 2048, 3000, 4000, 4095];

    /// Two identical mid-scale inputs must blend to roughly the same value.
    #[test]
    fn initial_blend_biased_to_cm() {
        let p = MetaMixer::new(5).blend(2048, 2048);
        let near_half = (1900..=2200).contains(&p);
        assert!(near_half, "equal inputs should give ~2048, got {p}");
    }

    /// Every pair of swept inputs must produce an output inside `1..=4095`.
    #[test]
    fn blend_always_in_range() {
        let mut mixer = MetaMixer::new(5);
        for cm in SWEEP {
            for gru in SWEEP {
                let p = mixer.blend(cm, gru);
                let in_range = (1..=4095).contains(&p);
                assert!(in_range, "out of range: cm={cm}, gru={gru}, got {p}");
            }
        }
    }

    /// With a 5 % GRU weight a high CM input must outweigh a low GRU input.
    #[test]
    fn cm_dominates_at_low_weight() {
        let p = MetaMixer::new(5).blend(3500, 500);
        assert!(p > 2500, "5% GRU should let CM dominate: got {p}");
    }

    /// Repeatedly observing a 1 bit must raise the prediction for that cell.
    #[test]
    fn update_adapts() {
        let mut mixer = MetaMixer::new(5);
        (0..200).for_each(|_| {
            mixer.blend(2048, 2048);
            mixer.update(1);
        });
        let p = mixer.blend(2048, 2048);
        assert!(p > 2048, "after many 1s, should predict higher: {p}");
    }
}