/// Number of regular-mode contexts.
///
/// Each quantized gradient takes one of nine levels (`-4..=4`), giving `9^3`
/// triples. A triple and its negation are merged by sign normalisation in
/// `context_index`, which leaves `(9^3 + 1) / 2` distinct contexts.
pub const NUM_REGULAR_CONTEXTS: usize = (9 * 9 * 9 + 1) / 2;

/// Number of additional contexts reserved for run termination.
pub const NUM_RUN_TERMINATION_CONTEXTS: usize = 2;

/// Total number of contexts: the regular ones plus the run-termination ones.
pub const NUM_TOTAL_CONTEXTS: usize = NUM_RUN_TERMINATION_CONTEXTS + NUM_REGULAR_CONTEXTS;
/// Adaptive statistics for a single coding context, maintained by `update_context`.
#[derive(Clone, Debug)]
pub struct ContextState {
    /// Bias-correction value; `update_context` keeps it within `-128..=127`.
    pub cx: i32,
    /// Parameter that `update_context` raises until `n << k` reaches its `reset`
    /// argument (presumably the Golomb coding parameter — TODO confirm).
    pub k: i32,
    /// Accumulator of the `err - near` values passed to `update_context`;
    /// kept within `(-n, 0]` after each update.
    pub b: i32,
    /// Count of samples seen in this context, periodically halved.
    pub n: i32,
}

impl Default for ContextState {
    /// Returns a fresh context.
    ///
    /// Every statistic starts at zero except the sample count `n`, which starts at 1.
    fn default() -> Self {
        let (cx, k, b) = (0, 0, 0);
        Self { cx, k, b, n: 1 }
    }
}
/// Maps a triple of quantized local gradients to a regular-mode context.
///
/// The triple is first sign-normalised. If its first non-zero component is
/// negative, all three components are negated and the returned sign is `-1`.
/// Otherwise the triple is used as-is and the sign is `1`. Callers presumably
/// use the sign to flip the prediction error for the merged context (TODO
/// confirm against the caller).
///
/// Each component is expected to lie in `-4..=4`. After normalisation, the
/// base-9 encoding `q1*81 + (q2+4)*9 + (q3+4)` takes exactly the 365
/// consecutive values `40..=404`, one per distinct normalised triple.
/// Reducing it modulo [`NUM_REGULAR_CONTEXTS`] is therefore a bijection onto
/// `0..NUM_REGULAR_CONTEXTS`:
/// - values up to 364 keep their position, so the all-zero triple stays at index 40;
/// - the 40 values above 364 wrap into the otherwise unused slots `0..40`.
///
/// # Returns
/// `(index, sign)` with `index < NUM_REGULAR_CONTEXTS` and `sign` either `1` or `-1`.
///
/// # Panics
/// In debug builds, panics if any component lies outside `-4..=4`.
#[inline]
pub fn context_index(q1: i8, q2: i8, q3: i8) -> (usize, i32) {
    debug_assert!(
        [q1, q2, q3].iter().all(|q| (-4..=4).contains(q)),
        "quantized gradients must lie in -4..=4, got ({q1},{q2},{q3})"
    );
    let (q1n, q2n, q3n, sign) =
        if q1 < 0 || (q1 == 0 && q2 < 0) || (q1 == 0 && q2 == 0 && q3 < 0) {
            (-q1, -q2, -q3, -1i32)
        } else {
            (q1, q2, q3, 1i32)
        };
    // Base-9 encoding of the normalised triple; lies in 40..=404 (see above).
    let raw = (q1n as usize) * 81 + ((q2n + 4) as usize) * 9 + (q3n + 4) as usize;
    // Clamping with `min` would merge the 41 triples with raw >= 364 into one
    // shared context. The modulo keeps every normalised triple distinct.
    (raw % NUM_REGULAR_CONTEXTS, sign)
}
/// Updates a context's adaptive statistics after coding one sample.
///
/// Steps, in order:
/// 1. Adds `err - near` to the bias accumulator `b` and increments the sample
///    count `n`.
/// 2. Raises `k` until `n << k` reaches `reset`. `k` is never lowered here.
/// 3. Halves `b` (rounding half away from zero) and `n` (rounding up) once `n`
///    has reached `reset` or `|b|` exceeds `reset`. This keeps the statistics
///    weighted toward recent samples and keeps `n` bounded.
/// 4. Applies bias correction. If the average bias `b / n` has left `(-1, 0]`,
///    it moves `cx` one step toward it and shifts `b` back by `n`, leaving `b`
///    within `(-n, 0]`.
/// 5. Clamps `cx` to `-128..=127`.
///
/// # Parameters
/// - `state`: the context to update. `n` must be non-negative
///   (`ContextState::default()` starts it at 1).
/// - `err`: the error value for the sample just coded.
/// - `near`: subtracted from `err` before accumulation. NOTE(review): the
///   JPEG-LS reference (ITU-T T.87) scales the error by `2*near + 1` instead;
///   confirm `err - near` is intended.
/// - `reset`: threshold for halving, and the target that `n << k` must reach.
/// - `_max_val`: unused; kept for interface compatibility.
///
/// NOTE(review): `k` is driven by `reset` rather than by an accumulated error
/// magnitude (this struct has no such field), and it only ever grows. Confirm
/// this is intended.
///
/// # Panics
/// In debug builds, panics if `state.n` is negative on entry.
pub fn update_context(state: &mut ContextState, err: i32, near: i32, reset: i32, _max_val: i32) {
    debug_assert!(
        state.n >= 0,
        "context sample count must be non-negative, got {}",
        state.n
    );
    state.b += err - near;
    state.n += 1;
    // Compare in i64 so `n << k` cannot overflow. With n >= 1 and
    // reset <= i32::MAX the loop stops by k = 31, and n << 31 fits in i64.
    while (i64::from(state.n) << state.k) < i64::from(reset) {
        state.k += 1;
    }
    // Halve once n reaches reset, not only when |b| grows large. Otherwise a
    // context that keeps seeing small errors never halves: n grows without
    // bound and its statistics stop adapting.
    if state.n >= reset || state.b.abs() > reset {
        state.b = (state.b + if state.b < 0 { -1 } else { 1 }) / 2;
        state.n = (state.n + 1) / 2;
    }
    if state.b <= -state.n {
        // Average bias <= -1: lower the correction and pull b back up.
        state.cx -= 1;
        state.b += state.n;
        if state.b <= -state.n {
            state.b = -state.n + 1;
        }
    } else if state.b > 0 {
        // Average bias > 0: raise the correction and pull b back down.
        state.cx += 1;
        state.b -= state.n;
        if state.b > 0 {
            state.b = 0;
        }
    }
    state.cx = state.cx.clamp(-128, 127);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Yields every quantized-gradient triple with components in `-4..=4`,
    /// with `q1` in the outermost position and `q3` varying fastest.
    fn every_triple() -> impl Iterator<Item = (i8, i8, i8)> {
        (-4i8..=4).flat_map(|q1| {
            (-4i8..=4).flat_map(move |q2| (-4i8..=4).map(move |q3| (q1, q2, q3)))
        })
    }

    #[test]
    fn all_zero_triple_maps_to_centre() {
        assert_eq!(context_index(0, 0, 0), (40, 1));
    }

    #[test]
    fn negative_triple_is_sign_normalised() {
        let positive_idx = context_index(1, 2, 3).0;
        let negated = context_index(-1, -2, -3);
        assert_eq!(negated.0, positive_idx);
        assert_eq!(negated.1, -1);
    }

    #[test]
    fn index_within_bounds() {
        for (q1, q2, q3) in every_triple() {
            let idx = context_index(q1, q2, q3).0;
            assert!(
                idx < NUM_REGULAR_CONTEXTS,
                "idx={idx} out of bounds for ({q1},{q2},{q3})"
            );
        }
    }

    #[test]
    fn context_update_does_not_overflow() {
        let mut state = ContextState::default();
        for &err in &[-5i32, 3, 0, -1, 7, -10] {
            update_context(&mut state, err, 0, 64, 255);
            assert!((-128..=127).contains(&state.cx));
        }
    }

    #[test]
    fn total_contexts_is_regular_plus_run_termination() {
        assert_eq!(NUM_RUN_TERMINATION_CONTEXTS, 2);
        assert_eq!(NUM_TOTAL_CONTEXTS, 367);
        assert_eq!(NUM_TOTAL_CONTEXTS - NUM_REGULAR_CONTEXTS, 2);
    }
}