use super::{L_HAT_MAX, SpectralAmplitudes, VuvResult, band_for_harmonic};
/// Previous-frame harmonic count `L̃(−1)` installed by [`PredictorState::cold_start`].
pub const L_TILDE_COLD_START: u8 = 30;
/// Previous-frame harmonic count `L̃(−1)` installed by
/// [`HalfratePredictorState::cold_start`].
pub const L_TILDE_COLD_START_HALFRATE: u8 = 15;
/// Fixed prediction weight used by [`compute_prediction_residual_ambe_plus2`]; the
/// full-rate path instead takes an `L̂`-dependent weight from [`imbe_rho_f64`].
pub const RHO_HALFRATE: f64 = 0.65;
/// Constant added to the log2 amplitude of unvoiced harmonics by [`lambda_hat_from_m_hat`].
pub const LAMBDA_HAT_UNVOICED_BIAS: f64 = 2.289;
/// Per-harmonic prediction residual `T̂_l`, indexed `1..=L̂`. The residual functions in
/// this module leave index 0 and every entry past `L̂` at 0.0.
pub type PredictionResidual = [f64; L_HAT_MAX as usize + 1];
/// Full-rate (IMBE) predictor memory: the previous frame's reconstructed amplitudes
/// `M̃_l(−1)` together with its harmonic count `L̃(−1)`.
#[derive(Clone, Debug)]
pub struct PredictorState {
    // `M̃_l(−1)` for `l` in `1..=l_tilde_prev`; index 0 is held at 1.0.
    m_tilde_prev: [f64; L_HAT_MAX as usize + 1],
    // `L̃(−1)`: number of meaningful entries in `m_tilde_prev`.
    l_tilde_prev: u8,
}

impl PredictorState {
    /// Initial memory: every stored amplitude is 1.0 and `L̃(−1)` is
    /// [`L_TILDE_COLD_START`].
    pub const fn cold_start() -> Self {
        PredictorState {
            l_tilde_prev: L_TILDE_COLD_START,
            m_tilde_prev: [1.0; L_HAT_MAX as usize + 1],
        }
    }

    /// Reads `M̃_l(−1)` with the boundary rules applied: index 0 always yields 1.0,
    /// and any index beyond `L̃(−1)` yields `M̃_{L̃(−1)}(−1)`.
    pub(super) fn read(&self, l: u32) -> f64 {
        match l {
            0 => 1.0,
            _ => self.m_tilde_prev[(l as usize).min(usize::from(self.l_tilde_prev))],
        }
    }

    /// Returns a freshly allocated `f32` copy of `M̃_1(−1) ..= M̃_{L̃(−1)}(−1)`
    /// (index 0 is not included).
    pub fn m_tilde_prev_slice(&self) -> Vec<f32> {
        let valid = &self.m_tilde_prev[1..=usize::from(self.l_tilde_prev)];
        valid.iter().map(|&amp| amp as f32).collect()
    }

    /// Previous-frame harmonic count `L̃(−1)`.
    pub fn l_tilde_prev(&self) -> u8 {
        self.l_tilde_prev
    }

    /// Installs the current frame's amplitudes as the memory for the next frame.
    ///
    /// Entries `1..=l_tilde_curr` of `m_tilde_curr` are copied and index 0 is reset to 1.0.
    /// Entries past `l_tilde_curr` keep their old values, but `read` and
    /// `m_tilde_prev_slice` never return them because both stop at `L̃(−1)`.
    ///
    /// # Panics
    ///
    /// Panics if `l_tilde_curr` exceeds `L_HAT_MAX`, or if `m_tilde_curr` has no element
    /// at index `l_tilde_curr`.
    pub fn commit(&mut self, m_tilde_curr: &[f64], l_tilde_curr: u8) {
        assert!(l_tilde_curr as usize <= L_HAT_MAX as usize);
        assert!(m_tilde_curr.len() > l_tilde_curr as usize);
        let count = usize::from(l_tilde_curr);
        self.m_tilde_prev[0] = 1.0;
        self.m_tilde_prev[1..=count].copy_from_slice(&m_tilde_curr[1..=count]);
        self.l_tilde_prev = l_tilde_curr;
    }
}

impl Default for PredictorState {
    /// Equivalent to [`PredictorState::cold_start`].
    fn default() -> Self {
        PredictorState::cold_start()
    }
}
/// Full-rate prediction weight `ρ(L̂)` as an `f64`.
///
/// The value is taken from `crate::imbe_wire::dequantize::imbe_rho` and only widened here.
#[inline]
pub fn imbe_rho_f64(l_hat: u8) -> f64 {
    let rho = crate::imbe_wire::dequantize::imbe_rho(l_hat);
    rho.into()
}
/// Computes the full-rate (IMBE) spectral-amplitude prediction residual `T̂_l`.
///
/// The previous frame's amplitudes `M̃_l(−1)` held in `state` are resampled onto the
/// current harmonic grid by linear interpolation in the log2 domain at
/// `k̂_l = L̃(−1)/L̂ · l`:
///
/// `P_l = (1 − δ_l)·log2 M̃_{⌊k̂_l⌋}(−1) + δ_l·log2 M̃_{⌊k̂_l⌋+1}(−1)`, `δ_l = k̂_l − ⌊k̂_l⌋`.
///
/// The residual is `T̂_l = log2 M̂_l − ρ·(P_l − mean(P))`. Here `ρ = imbe_rho_f64(L̂)` and
/// the mean runs over `1..=L̂`. Because the prediction enters with its mean removed,
/// `mean(T̂)` equals `mean(log2 M̂)` up to rounding.
///
/// Non-positive amplitudes, current or stored, contribute 0.0 instead of `log2`'s −∞/NaN.
///
/// Returns an array whose entries `1..=l_hat` hold `T̂_l`. Every other entry is 0.0, and
/// an `l_hat` of 0 yields an all-zero array.
///
/// # Panics
///
/// Panics if `l_hat` exceeds `L_HAT_MAX`.
pub fn compute_prediction_residual(
    m_hat: &SpectralAmplitudes,
    l_hat: u8,
    state: &PredictorState,
) -> PredictionResidual {
    let mut t_hat = [0.0f64; L_HAT_MAX as usize + 1];
    if l_hat == 0 {
        return t_hat;
    }
    // `PredictorState::commit` enforces the same bound. Fail with a clear message
    // instead of an anonymous out-of-bounds index inside the loops below.
    assert!(
        usize::from(l_hat) <= L_HAT_MAX as usize,
        "l_hat ({l_hat}) exceeds L_HAT_MAX"
    );
    // log2 that maps non-positive inputs to 0.0 rather than -inf/NaN.
    let safe_log2 = |x: f64| if x > 0.0 { x.log2() } else { 0.0 };
    let n = usize::from(l_hat);
    let ratio = f64::from(state.l_tilde_prev) / f64::from(l_hat);
    let rho = imbe_rho_f64(l_hat);

    // Interpolated log2 prediction P_l for l in 1..=L̂, and its running sum.
    let mut p = [0.0f64; L_HAT_MAX as usize + 1];
    let mut p_sum = 0.0;
    for (p_l, l) in p[1..=n].iter_mut().zip(1u32..) {
        let k_hat = ratio * f64::from(l);
        let k_floor = k_hat.floor();
        let delta = k_hat - k_floor;
        // `ratio` and `l` are non-negative, so `k` can never be negative. `read` supplies
        // the boundary values itself: index 0 reads 1.0 and anything beyond L̃(−1) reads
        // M̃_{L̃(−1)}(−1), so no additional clamping is needed here.
        let k = k_floor as u32;
        let log_lo = safe_log2(state.read(k));
        let log_hi = safe_log2(state.read(k + 1));
        *p_l = (1.0 - delta) * log_lo + delta * log_hi;
        p_sum += *p_l;
    }
    let p_mean = p_sum / f64::from(l_hat);

    let current = &m_hat[1..=n];
    for ((t_l, &m_l), &p_l) in t_hat[1..=n].iter_mut().zip(current).zip(&p[1..=n]) {
        *t_l = safe_log2(m_l) - rho * (p_l - p_mean);
    }
    t_hat
}
// Unit tests for the full-rate (IMBE) path: `PredictorState` cold start and boundary
// reads, `compute_prediction_residual` under cold start / mean preservation / perfect
// prediction / harmonic-count changes, and spot checks of the `imbe_rho_f64` schedule.
#[cfg(test)]
mod imbe_tests {
    use super::*;
    // A cold-started state reports `L_TILDE_COLD_START` and reads 1.0 at every index.
    #[test]
    fn predictor_state_cold_start_matches_spec() {
        let st = PredictorState::cold_start();
        assert_eq!(st.l_tilde_prev(), L_TILDE_COLD_START);
        for l in 0..=L_HAT_MAX as u32 {
            assert_eq!(st.read(l), 1.0, "M̃_{l}(−1) = {}", st.read(l));
        }
    }
    // After committing three amplitudes, index 0 reads 1.0 and every index past 3
    // reads the last committed amplitude.
    #[test]
    fn predictor_state_read_applies_eq56_eq57_boundaries() {
        let mut st = PredictorState::cold_start();
        let mut m_tilde = [0.0f64; L_HAT_MAX as usize + 1];
        m_tilde[1] = 2.0;
        m_tilde[2] = 4.0;
        m_tilde[3] = 8.0;
        st.commit(&m_tilde, 3);
        assert_eq!(st.read(0), 1.0);
        assert_eq!(st.read(1), 2.0);
        assert_eq!(st.read(3), 8.0);
        assert_eq!(st.read(4), 8.0);
        assert_eq!(st.read(30), 8.0);
    }
    // Cold-start memory is all 1.0, so every interpolated log2 prediction is 0 and the
    // residual reduces to log2 of the current amplitudes.
    #[test]
    fn compute_prediction_residual_cold_start_equals_log2_amplitude() {
        let st = PredictorState::cold_start();
        let mut m_hat = [0.0f64; L_HAT_MAX as usize + 1];
        for l in 1..=12u32 {
            m_hat[l as usize] = f64::from(l) * 0.5 + 1.0;
        }
        let t_hat = compute_prediction_residual(&m_hat, 12, &st);
        for l in 1..=12u32 {
            let expected = m_hat[l as usize].log2();
            assert!(
                (t_hat[l as usize] - expected).abs() < 1e-12,
                "T̂_{l} = {}, expected log₂({}) = {}",
                t_hat[l as usize],
                m_hat[l as usize],
                expected
            );
        }
    }
    // The prediction is subtracted with its mean removed, so mean(T̂) over 1..=L̂ must
    // equal mean(log2 M̂) even when L̃(−1) (20) differs from L̂ (18).
    #[test]
    fn compute_prediction_residual_preserves_mean_of_log2_amplitude() {
        let mut st = PredictorState::cold_start();
        let mut past = [0.0f64; L_HAT_MAX as usize + 1];
        for l in 1..=20u32 {
            past[l as usize] = 1.0 + f64::from(l) * 0.1;
        }
        st.commit(&past, 20);
        let mut m_hat = [0.0f64; L_HAT_MAX as usize + 1];
        for l in 1..=18u32 {
            m_hat[l as usize] = 0.5 + f64::from(l) * 0.2;
        }
        let l_hat = 18u8;
        let t_hat = compute_prediction_residual(&m_hat, l_hat, &st);
        let t_mean: f64 = (1..=u32::from(l_hat))
            .map(|l| t_hat[l as usize])
            .sum::<f64>()
            / f64::from(l_hat);
        let log_m_mean: f64 = (1..=u32::from(l_hat))
            .map(|l| m_hat[l as usize].log2())
            .sum::<f64>()
            / f64::from(l_hat);
        assert!(
            (t_mean - log_m_mean).abs() < 1e-12,
            "mean(T̂) = {t_mean}, mean(log₂ M̂) = {log_m_mean}"
        );
    }
    // With identical previous and current frames (same L̂, so k̂_l = l and δ_l = 0),
    // the residual is (1 − ρ)·log2 M̂_l + ρ·mean(log2 M̂).
    #[test]
    fn compute_prediction_residual_partial_blend_when_prediction_perfect() {
        let mut st = PredictorState::cold_start();
        let mut m_same = [0.0f64; L_HAT_MAX as usize + 1];
        for l in 1..=16u32 {
            m_same[l as usize] = 1.0 + f64::from(l) * 0.3;
        }
        st.commit(&m_same, 16);
        let t_hat = compute_prediction_residual(&m_same, 16, &st);
        let log_mean: f64 = (1..=16u32)
            .map(|l| m_same[l as usize].log2())
            .sum::<f64>()
            / 16.0;
        let rho = imbe_rho_f64(16);
        for l in 1..=16u32 {
            let log_m = m_same[l as usize].log2();
            let expected = (1.0 - rho) * log_m + rho * log_mean;
            assert!(
                (t_hat[l as usize] - expected).abs() < 1e-9,
                "T̂_{l} = {}, expected (1-ρ)·log + ρ·mean = {}",
                t_hat[l as usize],
                expected
            );
        }
    }
    // Growing (10 → 25) and shrinking (10 → 5) the harmonic count must still produce
    // finite residuals, exercising the clamp past L̃(−1).
    #[test]
    fn compute_prediction_residual_handles_l_hat_mismatch() {
        let mut st = PredictorState::cold_start();
        let mut past = [0.0f64; L_HAT_MAX as usize + 1];
        for l in 1..=10u32 {
            past[l as usize] = 2.0_f64.powf(f64::from(l) * 0.1);
        }
        st.commit(&past, 10);
        let mut m_hat = [0.0f64; L_HAT_MAX as usize + 1];
        for l in 1..=25u32 {
            m_hat[l as usize] = 2.0_f64.powf(f64::from(l) * 0.05);
        }
        let t_hat = compute_prediction_residual(&m_hat, 25, &st);
        for l in 1..=25u32 {
            assert!(t_hat[l as usize].is_finite(), "T̂_{l} = {}", t_hat[l as usize]);
        }
        let t_hat_shrink = compute_prediction_residual(&m_hat, 5, &st);
        for l in 1..=5u32 {
            assert!(
                t_hat_shrink[l as usize].is_finite(),
                "T̂_{l} (shrink) = {}",
                t_hat_shrink[l as usize]
            );
        }
    }
    // Spot-checks ρ(L̂): flat 0.40 up to 15, a linear ramp through 16..=24, flat 0.70 above.
    #[test]
    fn imbe_rho_matches_eq55_schedule_points() {
        let tol = 1e-5;
        assert!((imbe_rho_f64(9) - 0.40).abs() < tol);
        assert!((imbe_rho_f64(15) - 0.40).abs() < tol);
        assert!((imbe_rho_f64(16) - 0.43).abs() < tol);
        assert!((imbe_rho_f64(20) - 0.55).abs() < tol);
        assert!((imbe_rho_f64(24) - 0.67).abs() < tol);
        assert!((imbe_rho_f64(25) - 0.70).abs() < tol);
        assert!((imbe_rho_f64(56) - 0.70).abs() < tol);
    }
}
/// Half-rate (AMBE+2) predictor memory: the previous frame's log-domain values
/// `Λ̃_l(−1)`, its harmonic count `L̃(−1)`, and the stored `γ̃(−1)`.
#[derive(Clone, Debug)]
pub struct HalfratePredictorState {
    // `Λ̃_l(−1)` for `l` in `1..=l_tilde_prev`. `commit` never writes index 0, and
    // `read` never returns it.
    lambda_tilde_prev: [f64; L_HAT_MAX as usize + 1],
    // `L̃(−1)`.
    l_tilde_prev: u8,
    // `γ̃(−1)` as last passed to `commit`.
    gamma_tilde_prev: f64,
}

impl HalfratePredictorState {
    /// Initial memory: all `Λ̃_l(−1)` are 1.0, `L̃(−1)` is
    /// [`L_TILDE_COLD_START_HALFRATE`], and `γ̃(−1)` is 0.0.
    pub const fn cold_start() -> Self {
        HalfratePredictorState {
            gamma_tilde_prev: 0.0,
            l_tilde_prev: L_TILDE_COLD_START_HALFRATE,
            lambda_tilde_prev: [1.0; L_HAT_MAX as usize + 1],
        }
    }

    /// Reads `Λ̃_l(−1)` with the boundary rules applied.
    ///
    /// Index 0 aliases index 1, and any index beyond `L̃(−1)` yields `Λ̃_{L̃(−1)}(−1)`.
    /// A stored `L̃(−1)` of 0 is treated as 1.
    pub(super) fn read(&self, l: u32) -> f64 {
        let last = usize::from(self.l_tilde_prev).max(1);
        let index = if l == 0 { 1 } else { (l as usize).min(last) };
        self.lambda_tilde_prev[index]
    }

    /// Previous-frame harmonic count `L̃(−1)`.
    pub fn l_tilde_prev(&self) -> u8 {
        self.l_tilde_prev
    }

    /// Stored `γ̃(−1)`.
    pub fn gamma_tilde_prev(&self) -> f64 {
        self.gamma_tilde_prev
    }

    /// Whole backing array of `Λ̃_l(−1)`. Only entries `1..=l_tilde_prev()` come from the
    /// last commit; the rest may be stale or cold-start values.
    pub fn lambda_tilde_prev(&self) -> &[f64; L_HAT_MAX as usize + 1] {
        &self.lambda_tilde_prev
    }

    /// Installs the current frame's values as the memory for the next frame.
    ///
    /// Entries `1..=l_tilde_curr` of `lambda_tilde_curr` are copied; index 0 and entries
    /// past `l_tilde_curr` are left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `l_tilde_curr` exceeds `L_HAT_MAX`, or if `lambda_tilde_curr` has no
    /// element at index `l_tilde_curr`.
    pub fn commit(
        &mut self,
        lambda_tilde_curr: &[f64],
        l_tilde_curr: u8,
        gamma_tilde_curr: f64,
    ) {
        assert!(l_tilde_curr as usize <= L_HAT_MAX as usize);
        assert!(lambda_tilde_curr.len() > l_tilde_curr as usize);
        let count = usize::from(l_tilde_curr);
        self.lambda_tilde_prev[1..=count].copy_from_slice(&lambda_tilde_curr[1..=count]);
        self.l_tilde_prev = l_tilde_curr;
        self.gamma_tilde_prev = gamma_tilde_curr;
    }
}

impl Default for HalfratePredictorState {
    /// Equivalent to [`HalfratePredictorState::cold_start`].
    fn default() -> Self {
        HalfratePredictorState::cold_start()
    }
}
/// Converts spectral amplitudes `M̂_l` into the log-domain values `Λ̂_l` used by the
/// half-rate predictor.
///
/// Harmonic `l` counts as voiced when `band_for_harmonic(l, k̂)` returns a band `k ≥ 1`
/// whose V/UV decision `vuv.vuv[k]` is 1. The values are:
///
/// - voiced: `Λ̂_l = log2 M̂_l + ½·log2 L̂`
/// - unvoiced: `Λ̂_l = log2 M̂_l + ½·log2(ω̂₀·L̂) + LAMBDA_HAT_UNVOICED_BIAS`
///
/// Non-positive amplitudes use 0.0 in place of `log2 M̂_l`. Unvoiced entries become
/// non-finite if `omega_hat_0 * l_hat` is not positive.
///
/// Entries outside `1..=l_hat` are 0.0, and an `l_hat` of 0 yields an all-zero array.
pub fn lambda_hat_from_m_hat(
    m_hat: &SpectralAmplitudes,
    vuv: &VuvResult,
    omega_hat_0: f64,
    l_hat: u8,
) -> [f64; L_HAT_MAX as usize + 1] {
    let mut out = [0.0f64; L_HAT_MAX as usize + 1];
    if l_hat == 0 {
        return out;
    }
    let harmonics = f64::from(l_hat);
    let voiced_offset = 0.5 * harmonics.log2();
    let unvoiced_offset = 0.5 * (omega_hat_0 * harmonics).log2();
    for (slot, l) in out[1..=usize::from(l_hat)].iter_mut().zip(1u32..) {
        let amplitude = m_hat[l as usize];
        let log_amplitude = if amplitude > 0.0 { amplitude.log2() } else { 0.0 };
        let band = band_for_harmonic(l, vuv.k_hat);
        let voiced = band >= 1 && vuv.vuv[band as usize] == 1;
        // Additions stay left-to-right so rounding matches term-by-term evaluation.
        *slot = if voiced {
            log_amplitude + voiced_offset
        } else {
            log_amplitude + unvoiced_offset + LAMBDA_HAT_UNVOICED_BIAS
        };
    }
    out
}
/// Computes the half-rate (AMBE+2) prediction residual `T̂_l` from log-domain `Λ̂_l`.
///
/// The previous frame's `Λ̃_l(−1)` held in `state` are resampled onto the current harmonic
/// grid by linear interpolation at `k̂_l = L̃(−1)/L̂ · l`:
///
/// `P_l = (1 − δ_l)·Λ̃_{⌊k̂_l⌋}(−1) + δ_l·Λ̃_{⌊k̂_l⌋+1}(−1)`, `δ_l = k̂_l − ⌊k̂_l⌋`.
///
/// The residual is `T̂_l = Λ̂_l − RHO_HALFRATE·(P_l − mean(P))`, with the mean over
/// `1..=L̂`. The inputs are combined linearly; no logarithm is taken, matching the
/// log-domain output of `lambda_hat_from_m_hat`.
///
/// Returns an array whose entries `1..=l_hat` hold `T̂_l`. Every other entry is 0.0, and
/// an `l_hat` of 0 yields an all-zero array.
///
/// # Panics
///
/// Panics if `l_hat` exceeds `L_HAT_MAX`.
pub fn compute_prediction_residual_ambe_plus2(
    lambda_hat: &[f64; L_HAT_MAX as usize + 1],
    l_hat: u8,
    state: &HalfratePredictorState,
) -> PredictionResidual {
    let mut t_hat = [0.0f64; L_HAT_MAX as usize + 1];
    if l_hat == 0 {
        return t_hat;
    }
    // `HalfratePredictorState::commit` enforces the same bound. Fail with a clear message
    // instead of an anonymous out-of-bounds index inside the loops below.
    assert!(
        usize::from(l_hat) <= L_HAT_MAX as usize,
        "l_hat ({l_hat}) exceeds L_HAT_MAX"
    );
    let n = usize::from(l_hat);
    let ratio = f64::from(state.l_tilde_prev) / f64::from(l_hat);

    // Interpolated prediction P_l for l in 1..=L̂, and its running sum.
    let mut p = [0.0f64; L_HAT_MAX as usize + 1];
    let mut p_sum = 0.0;
    for (p_l, l) in p[1..=n].iter_mut().zip(1u32..) {
        let k_hat = ratio * f64::from(l);
        let k_floor = k_hat.floor();
        let delta = k_hat - k_floor;
        // `ratio` and `l` are non-negative, so `k` can never be negative. `read` already
        // maps index 0 to index 1 (Eq. 156) and clamps indices beyond L̃(−1) (Eq. 157),
        // so no additional clamping is needed here.
        let k = k_floor as u32;
        let lo = state.read(k);
        let hi = state.read(k + 1);
        *p_l = (1.0 - delta) * lo + delta * hi;
        p_sum += *p_l;
    }
    let p_mean = p_sum / f64::from(l_hat);

    let current = &lambda_hat[1..=n];
    for ((t_l, &lambda_l), &p_l) in t_hat[1..=n].iter_mut().zip(current).zip(&p[1..=n]) {
        *t_l = lambda_l - RHO_HALFRATE * (p_l - p_mean);
    }
    t_hat
}
// Unit tests for the half-rate (AMBE+2) path: `HalfratePredictorState` cold start,
// boundary reads and commit, the `Λ̂` conversion in `lambda_hat_from_m_hat`, and
// `compute_prediction_residual_ambe_plus2`.
#[cfg(test)]
mod ambe_plus2_tests {
    use super::*;
    // Builds a V/UV result with `k_hat` bands, every band marked unvoiced (0).
    fn zero_vuv(k_hat: u8) -> VuvResult {
        VuvResult {
            k_hat,
            vuv: [0; super::super::K_HAT_MAX as usize + 1],
            d_k: [0.0; super::super::K_HAT_MAX as usize + 1],
            theta_k: [0.0; super::super::K_HAT_MAX as usize + 1],
            m_xi: 0.0,
            xi_0: 0.0,
            xi_max_after: 0.0,
        }
    }
    // Builds a V/UV result whose bands `1..=k_hat` are all voiced (1).
    fn all_voiced(k_hat: u8) -> VuvResult {
        let mut v = [0u8; super::super::K_HAT_MAX as usize + 1];
        for k in 1..=k_hat as usize {
            v[k] = 1;
        }
        VuvResult {
            k_hat,
            vuv: v,
            d_k: [0.0; super::super::K_HAT_MAX as usize + 1],
            theta_k: [0.0; super::super::K_HAT_MAX as usize + 1],
            m_xi: 0.0,
            xi_0: 0.0,
            xi_max_after: 0.0,
        }
    }
    // Cold start: L̃(−1) = 15, γ̃(−1) = 0.0, and every stored Λ̃ entry is 1.0.
    #[test]
    fn ambe_plus2_predictor_cold_start_values() {
        let st = HalfratePredictorState::cold_start();
        assert_eq!(st.l_tilde_prev(), 15);
        assert_eq!(st.gamma_tilde_prev(), 0.0);
        for l in 0..=L_HAT_MAX as usize {
            assert_eq!(st.lambda_tilde_prev[l], 1.0);
        }
    }
    // With L̃(−1) = 5: index 0 aliases index 1, and indices past 5 read entry 5.
    #[test]
    fn ambe_plus2_predictor_read_applies_eq156_eq157() {
        let mut st = HalfratePredictorState::cold_start();
        st.lambda_tilde_prev[1] = 0.1;
        st.lambda_tilde_prev[2] = 0.2;
        st.lambda_tilde_prev[3] = 0.3;
        st.lambda_tilde_prev[4] = 0.4;
        st.lambda_tilde_prev[5] = 0.5;
        st.l_tilde_prev = 5;
        assert_eq!(st.read(0), 0.1);
        assert_eq!(st.read(1), 0.1);
        assert_eq!(st.read(5), 0.5);
        assert_eq!(st.read(6), 0.5);
        assert_eq!(st.read(999), 0.5);
    }
    // Voiced harmonic: Λ̂_1 = log2 M̂_1 + ½·log2 L̂.
    // NOTE(review): assumes band_for_harmonic(1, 4) lands in a band 1..=4 (all voiced).
    #[test]
    fn lambda_hat_voiced_adjustment_matches_eq150() {
        let l_hat = 10u8;
        let mut m_hat = [0.0f64; L_HAT_MAX as usize + 1];
        m_hat[1] = 16.0;
        let vuv = all_voiced(4);
        let lambda_hat = lambda_hat_from_m_hat(&m_hat, &vuv, 0.3, l_hat);
        let expected = (16.0_f64).log2() + 0.5 * (10.0_f64).log2();
        assert!(
            (lambda_hat[1] - expected).abs() < 1e-12,
            "lambda_hat[1] = {} vs expected {}",
            lambda_hat[1],
            expected
        );
    }
    // Unvoiced harmonic: Λ̂_1 = log2 M̂_1 + ½·log2(ω̂₀·L̂) + LAMBDA_HAT_UNVOICED_BIAS.
    #[test]
    fn lambda_hat_unvoiced_adds_2289_plus_omega_l_term() {
        let l_hat = 10u8;
        let mut m_hat = [0.0f64; L_HAT_MAX as usize + 1];
        m_hat[1] = 16.0;
        let vuv = zero_vuv(4);
        let lambda_hat = lambda_hat_from_m_hat(&m_hat, &vuv, 0.3, l_hat);
        let expected =
            (16.0_f64).log2() + 0.5 * (0.3_f64 * 10.0).log2() + LAMBDA_HAT_UNVOICED_BIAS;
        assert!(
            (lambda_hat[1] - expected).abs() < 1e-12,
            "lambda_hat[1] = {} vs expected {}",
            lambda_hat[1],
            expected
        );
    }
    // Zero amplitudes contribute log2 = 0.0 rather than −∞, leaving only the voiced offset.
    // NOTE(review): assumes band_for_harmonic maps harmonics 1..=5 with k_hat = 2 to a
    // voiced band 1..=2.
    #[test]
    fn lambda_hat_zero_magnitude_uses_zero_log_not_neg_infinity() {
        let l_hat = 5u8;
        let m_hat = [0.0f64; L_HAT_MAX as usize + 1];
        let vuv = all_voiced(2);
        let lambda_hat = lambda_hat_from_m_hat(&m_hat, &vuv, 0.2, l_hat);
        for l in 1..=l_hat as usize {
            let expected = 0.5 * (5.0_f64).log2();
            assert!(
                (lambda_hat[l] - expected).abs() < 1e-12,
                "l={l}: {} vs {}",
                lambda_hat[l],
                expected
            );
        }
    }
    // Cold-start memory is constant (1.0), so P_l − mean(P) is 0 and T̂ equals Λ̂.
    #[test]
    fn prediction_residual_ambe_plus2_cold_start_equals_lambda_hat_minus_rho_residual() {
        let l_hat = 12u8;
        let mut lambda_hat = [0.0f64; L_HAT_MAX as usize + 1];
        for l in 1..=l_hat as usize {
            lambda_hat[l] = 3.0;
        }
        let st = HalfratePredictorState::cold_start();
        let t_hat = compute_prediction_residual_ambe_plus2(&lambda_hat, l_hat, &st);
        for l in 1..=l_hat as usize {
            assert!(
                (t_hat[l] - 3.0).abs() < 1e-12,
                "l={l}: {} vs 3.0",
                t_hat[l]
            );
        }
    }
    // Despite the name, this checks that Σ T̂ equals Σ Λ̂: the mean-removed prediction
    // term sums to zero over 1..=L̂, so subtracting it preserves the total.
    #[test]
    fn prediction_residual_ambe_plus2_perfect_prediction_gives_zero_mean() {
        let l_hat = 20u8;
        let mut lambda_hat = [0.0f64; L_HAT_MAX as usize + 1];
        for l in 1..=l_hat as usize {
            lambda_hat[l] = (l as f64) * 0.1;
        }
        let mut st = HalfratePredictorState::cold_start();
        st.l_tilde_prev = 20;
        for l in 1..=20 {
            st.lambda_tilde_prev[l] = (l as f64) * 0.1;
        }
        let t_hat = compute_prediction_residual_ambe_plus2(&lambda_hat, l_hat, &st);
        let sum_t: f64 = t_hat[1..=l_hat as usize].iter().sum();
        let sum_lambda: f64 = lambda_hat[1..=l_hat as usize].iter().sum();
        assert!(
            (sum_t - sum_lambda).abs() < 1e-10,
            "Σ T̂ = {} vs Σ Λ̂ = {} (mean-removal should cancel on perfect prediction)",
            sum_t,
            sum_lambda,
        );
    }
    // `commit` updates L̃(−1), γ̃(−1), and entries 1..=12 of the stored Λ̃.
    #[test]
    fn ambe_plus2_predictor_commit_advances_state() {
        let mut st = HalfratePredictorState::cold_start();
        let mut lambda_curr = [0.0f64; L_HAT_MAX as usize + 1];
        for l in 1..=12 {
            lambda_curr[l] = l as f64 * 0.25;
        }
        st.commit(&lambda_curr, 12, 5.0);
        assert_eq!(st.l_tilde_prev(), 12);
        assert_eq!(st.gamma_tilde_prev(), 5.0);
        for l in 1..=12 {
            assert_eq!(st.lambda_tilde_prev[l], l as f64 * 0.25);
        }
    }
}