/// Number of interpolation knots stored per context row.
///
/// The input probability range is divided into `NUM_BINS - 1` equal segments,
/// with one knot at each segment boundary.
const NUM_BINS: usize = 65;

/// Context-indexed, piecewise-linear probability refinement stage ("APM").
///
/// Each context owns a row of `NUM_BINS` knots. [`APMStage::predict`] maps an
/// input probability in `1..=4095` through the row by linear interpolation and
/// blends the result with the input; [`APMStage::update`] then moves the two
/// knots used by that prediction toward the observed bit.
#[derive(Debug, Clone)]
pub struct APMStage {
    /// One row of knots per context; every entry stays within `1..=4095`.
    table: Vec<[u32; NUM_BINS]>,
    /// Number of rows in `table` (always at least 1).
    num_contexts: usize,
    /// Weight of the interpolated value in the output, in 1/256 units (`0..=256`).
    blend: u32,
    /// Context row selected by the most recent `predict` call.
    last_ctx: usize,
    /// Lower knot index used by the most recent `predict` call (`<= NUM_BINS - 2`).
    last_bin: usize,
    /// Interpolation weight of the most recent `predict` call (`0..=4095`).
    /// Recorded alongside `last_bin`; `update` does not currently consume it.
    last_weight: u32,
}

impl APMStage {
    /// Largest probability value; inputs are saturated to it and outputs never exceed it.
    const PROB_MAX: u32 = 4095;
    /// Number of interpolation segments between the `NUM_BINS` knots.
    const SEGMENTS: u64 = NUM_BINS as u64 - 1;
    /// Right shift applied to the error when adapting the lower knot.
    /// The upper knot uses one extra bit of shift (i.e. adapts half as fast).
    const UPDATE_SHIFT: u32 = 4;

    /// Creates a stage with `num_contexts` rows, each initialised to the
    /// identity mapping (knot `i` is `i * 4095 / 64`, rounded, clamped to `1..=4095`).
    ///
    /// * `num_contexts` - number of independent rows; `0` is treated as `1` so
    ///   that `predict` never has to reduce a context modulo zero.
    /// * `blend_pct` - percentage (saturated at 100) of the output taken from
    ///   the interpolated table value; the remainder comes from the raw input.
    pub fn new(num_contexts: usize, blend_pct: u32) -> Self {
        let num_contexts = num_contexts.max(1);
        let max = u64::from(Self::PROB_MAX);

        let mut identity_row = [0u32; NUM_BINS];
        for (i, entry) in identity_row.iter_mut().enumerate() {
            // Rounded division spreads the knots evenly over 0..=4095.
            let knot = (i as u64 * max + Self::SEGMENTS / 2) / Self::SEGMENTS;
            *entry = knot.clamp(1, max) as u32;
        }

        APMStage {
            table: vec![identity_row; num_contexts],
            num_contexts,
            // Saturate before scaling so a large percentage cannot overflow u32.
            blend: blend_pct.min(100) * 256 / 100,
            last_ctx: 0,
            last_bin: 0,
            last_weight: 0,
        }
    }

    /// Returns the refined probability for `prob` under `context`, in `1..=4095`.
    ///
    /// `context` is reduced modulo the number of rows. `prob` is saturated to
    /// 4095 for the table lookup. The selected row, knot index and weight are
    /// remembered so that a following [`APMStage::update`] adapts the same knots.
    #[inline(always)]
    pub fn predict(&mut self, prob: u32, context: usize) -> u32 {
        let max = u64::from(Self::PROB_MAX);
        let ctx = context % self.num_contexts;

        let scaled = u64::from(prob.min(Self::PROB_MAX)) * Self::SEGMENTS;
        // The lower knot must leave room for `bin + 1`, so the top input value
        // lands in the last segment instead of one past it.
        let bin = ((scaled / max) as usize).min(NUM_BINS - 2);
        // Weight is measured from the (possibly clamped) lower knot. For
        // `prob == 4095` this yields a full weight of 4095 on the upper knot,
        // rather than the zero weight a plain remainder would give.
        let weight = (scaled - bin as u64 * max) as u32;

        self.last_ctx = ctx;
        self.last_bin = bin;
        self.last_weight = weight;

        let row = &self.table[ctx];
        let lo = i64::from(row[bin]);
        let hi = i64::from(row[bin + 1]);
        let interp = lo + (hi - lo) * i64::from(weight) / i64::from(Self::PROB_MAX);
        let apm_p = interp.clamp(1, i64::from(Self::PROB_MAX)) as u64;

        let blend = u64::from(self.blend);
        let blended = (apm_p * blend + u64::from(prob) * (256 - blend)) / 256;
        blended.clamp(1, max) as u32
    }

    /// Adapts the two knots used by the most recent `predict` call toward `bit`.
    ///
    /// A non-zero `bit` pulls the knots toward 4095, a zero `bit` toward 1.
    /// If `predict` has not been called yet, knots 0 and 1 of row 0 are adapted.
    #[inline(always)]
    pub fn update(&mut self, bit: u8) {
        let target: i32 = if bit != 0 { Self::PROB_MAX as i32 } else { 1 };
        let row = &mut self.table[self.last_ctx];

        Self::nudge(&mut row[self.last_bin], target, Self::UPDATE_SHIFT);
        if let Some(upper) = row.get_mut(self.last_bin + 1) {
            Self::nudge(upper, target, Self::UPDATE_SHIFT + 1);
        }
    }

    /// Moves `entry` toward `target` by `(target - entry) >> shift`, keeping it in `1..=4095`.
    #[inline(always)]
    fn nudge(entry: &mut u32, target: i32, shift: u32) {
        // Entries never exceed 4095, so the conversion to i32 is lossless.
        let old = *entry as i32;
        let moved = old + ((target - old) >> shift);
        *entry = moved.clamp(1, Self::PROB_MAX as i32) as u32;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `rounds` predict/update cycles at the midpoint probability in
    /// context 0, always reporting a 1 bit.
    fn train_with_ones(stage: &mut APMStage, rounds: usize) {
        for _ in 0..rounds {
            stage.predict(2048, 0);
            stage.update(1);
        }
    }

    /// A 0% blend must hand the input back unchanged.
    #[test]
    fn initial_passthrough() {
        let mut stage = APMStage::new(1, 0);
        let p = stage.predict(2048, 0);
        assert_eq!(p, 2048);
    }

    /// A freshly built table is the identity, so a 50% blend stays near the input.
    #[test]
    fn initial_50_blend_near_identity() {
        let mut stage = APMStage::new(1, 50);
        let p = stage.predict(2048, 0);
        assert!(
            (2000..=2096).contains(&p),
            "50% blend of identity should be near input: {p}"
        );
    }

    /// Every prediction must fall inside the valid probability range.
    #[test]
    fn prediction_in_range() {
        let probs = [1u32, 100, 1000, 2048, 3000, 4000, 4095];
        let contexts = [0usize, 100, 511];
        let mut stage = APMStage::new(512, 50);

        for &prob in &probs {
            for &ctx in &contexts {
                let p = stage.predict(prob, ctx);
                assert!(
                    (1..=4095).contains(&p),
                    "out of range: prob={prob}, ctx={ctx}, got {p}"
                );
            }
        }
    }

    /// Repeated 1 bits must raise the prediction for the trained input.
    #[test]
    fn update_adapts() {
        let mut stage = APMStage::new(1, 100);
        train_with_ones(&mut stage, 100);

        let p = stage.predict(2048, 0);
        assert!(p > 2048, "after many 1s, APM should predict higher: {p}");
    }

    /// Training one context must leave the other context's row untouched.
    #[test]
    fn different_contexts_independent() {
        let mut stage = APMStage::new(2, 100);
        train_with_ones(&mut stage, 50);

        let p = stage.predict(2048, 1);
        assert!(
            (2000..=2096).contains(&p),
            "untrained context should be near 2048: {p}"
        );
    }

    /// Inputs at both ends of the range must map close to those ends.
    #[test]
    fn extreme_inputs() {
        let mut stage = APMStage::new(1, 50);

        let p_low = stage.predict(1, 0);
        assert!((1..=100).contains(&p_low), "low input: {p_low}");

        let p_high = stage.predict(4095, 0);
        assert!((3995..=4095).contains(&p_high), "high input: {p_high}");
    }
}