use crate::mbe_params::{L_MAX, MbeParams};
use crate::ambe_plus2_wire::frame::AMBE_PITCH_TABLE;
/// Weight ρ given to the resampled previous-frame magnitude when blending;
/// the current frame's magnitude receives the complementary weight `1 - ρ`.
pub const RHO_CROSS_RATE: f64 = 0.65_f64;
/// Largest relative pitch change `|r - 1|` for which blending is skipped and
/// the current amplitudes are passed through untouched.
pub const R_FAST_PATH_THRESHOLD: f64 = 0.01_f64;
/// Lower clamp applied to magnitudes before `log2`, so the log domain never
/// sees zero; its `log2` also serves as the prior for out-of-range positions.
const LOG2_FLOOR_INPUT: f64 = 1.0e-10_f64;
/// Per-stream history carried between frames by the cross-rate predictor.
///
/// A fresh state is seeded from row 30 of `AMBE_PITCH_TABLE` with unit
/// magnitudes; `blend` overwrites every field on each call.
#[derive(Clone, Debug)]
pub struct CrossRatePredictorState {
    /// Previous frame's magnitudes, indexed by harmonic number (slot `l` holds
    /// harmonic `l`). Slot 0 is kept at 1.0.
    m_tilde_prev: [f32; L_MAX as usize + 1],
    /// Number of harmonics in the previous frame.
    l_prev: u8,
    /// Fundamental frequency of the previous frame, in the same units as
    /// `AMBE_PITCH_TABLE[..].omega_0`.
    omega_0_prev: f32,
}

impl CrossRatePredictorState {
    /// Builds the cold-start state: every magnitude slot set to 1.0,
    /// `l_prev = 30`, and the fundamental copied from `AMBE_PITCH_TABLE[30]`.
    pub fn new() -> Self {
        let cold_omega = AMBE_PITCH_TABLE[30].omega_0;
        Self {
            m_tilde_prev: [1.0; L_MAX as usize + 1],
            l_prev: 30,
            omega_0_prev: cold_omega,
        }
    }

    /// Test-only view of the stored magnitude history.
    #[cfg(test)]
    pub(crate) fn m_tilde_prev(&self) -> &[f32; L_MAX as usize + 1] {
        &self.m_tilde_prev
    }

    /// Test-only view of the stored harmonic count.
    #[cfg(test)]
    pub(crate) fn l_prev(&self) -> u8 {
        self.l_prev
    }

    /// Test-only view of the stored fundamental frequency.
    #[cfg(test)]
    pub(crate) fn omega_0_prev(&self) -> f32 {
        self.omega_0_prev
    }

    /// Discards all history and returns to the cold-start state.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}

impl Default for CrossRatePredictorState {
    /// Same as [`CrossRatePredictorState::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Blends the current frame's magnitudes with the previous frame's history,
/// resampled onto the current harmonic grid, and advances `state`.
///
/// With `r = target_omega_0 / omega_0_prev`, harmonic `l` of the current frame
/// maps to (fractional) position `l * r` in the previous frame. When
/// `|r - 1| <= R_FAST_PATH_THRESHOLD` the current amplitudes pass through
/// unchanged. Otherwise each harmonic becomes
/// `RHO_CROSS_RATE * prior + (1 - RHO_CROSS_RATE) * current`, where `prior` is
/// the previous magnitude, linearly interpolated in the log2 domain and scaled
/// by `sqrt(r)`.
///
/// Amplitudes and voicing flags not supplied by `params` (when it carries fewer
/// than `target_l` entries) are filled with `0.0` / `false`. Voicing is copied
/// through unchanged. On every path, the blended magnitudes, `target_l` and
/// `target_omega_0` become the new state.
///
/// # Panics
///
/// Panics if `target_l > L_MAX` (a `debug_assert!` in debug builds, slice
/// indexing in release), or if `MbeParams::new` rejects the blended result.
pub fn blend(
    params: &MbeParams,
    target_omega_0: f32,
    target_l: u8,
    state: &mut CrossRatePredictorState,
) -> MbeParams {
    debug_assert!(target_l as usize <= L_MAX as usize);
    let l_curr = target_l;
    let n_curr = usize::from(l_curr);

    let omega_0_curr = f64::from(target_omega_0);
    // Keep the ratio finite even if the stored pitch is zero.
    let omega_0_prev = f64::from(state.omega_0_prev).max(1e-20);
    let r_prev = omega_0_curr / omega_0_prev;

    // Start from the current amplitudes; harmonics `params` does not cover stay 0.0.
    let amps_curr = params.amplitudes_slice();
    let l_source = amps_curr.len().min(n_curr);
    let mut blended = [0f32; L_MAX as usize];
    blended[..l_source].copy_from_slice(&amps_curr[..l_source]);

    if (r_prev - 1.0).abs() > R_FAST_PATH_THRESHOLD {
        blend_with_resampled_prior(&mut blended[..n_curr], r_prev, state);
    }

    // Voicing is carried through as-is; missing flags default to unvoiced.
    let voiced_src = params.voiced_slice();
    let l_voiced = voiced_src.len().min(n_curr);
    let mut voiced_out = [false; L_MAX as usize];
    voiced_out[..l_voiced].copy_from_slice(&voiced_src[..l_voiced]);

    let result = MbeParams::new(
        target_omega_0,
        l_curr,
        &voiced_out[..n_curr],
        &blended[..n_curr],
    )
    .expect("blend result has valid MbeParams constraints");

    // Advance history only after the result is built: slot 0 = 1.0,
    // slots 1..=l_curr = blended magnitudes, everything above zeroed.
    state.m_tilde_prev.fill(0.0);
    state.m_tilde_prev[0] = 1.0;
    state.m_tilde_prev[1..=n_curr].copy_from_slice(&blended[..n_curr]);
    state.l_prev = l_curr;
    state.omega_0_prev = target_omega_0;
    result
}

/// Replaces each `blended[l - 1]` (harmonic `l`) with the ρ-weighted mix of the
/// previous frame's magnitude at position `l * r_prev` and its current value.
///
/// NOTE(review): positions below 1 or above `state.l_prev` get a near-zero prior
/// (`LOG2_FLOOR_INPUT`) rather than edge extension, and history slot 0 (kept at
/// 1.0) is never read — confirm this matches the intended predictor spec.
fn blend_with_resampled_prior(
    blended: &mut [f32],
    r_prev: f64,
    state: &CrossRatePredictorState,
) {
    let l_prev = usize::from(state.l_prev);
    // Moving between pitch grids rescales each harmonic by sqrt(r).
    let energy_offset = 0.5 * r_prev.log2();
    let floor_log = LOG2_FLOOR_INPUT.log2();
    let rho = RHO_CROSS_RATE as f32;

    // Previous magnitudes in the log2 domain, clamped away from zero.
    let mut log_m_prev = [0.0_f64; L_MAX as usize + 1];
    for (dst, &m) in log_m_prev[1..=l_prev]
        .iter_mut()
        .zip(&state.m_tilde_prev[1..=l_prev])
    {
        *dst = f64::from(m).max(LOG2_FLOOR_INPUT).log2();
    }

    for (idx, curr) in blended.iter_mut().enumerate() {
        let src_pos = (idx + 1) as f64 * r_prev;
        let prev_log = if src_pos < 1.0 || src_pos > l_prev as f64 {
            floor_log
        } else {
            let pos_lo = src_pos.floor() as usize;
            let pos_hi = (pos_lo + 1).min(l_prev);
            let frac = src_pos - pos_lo as f64;
            let (v_lo, v_hi) = (log_m_prev[pos_lo], log_m_prev[pos_hi]);
            v_lo + frac * (v_hi - v_lo)
        };
        let prior_linear = 2.0_f64.powf(prev_log + energy_offset) as f32;
        *curr = rho * prior_linear + (1.0 - rho) * *curr;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an all-voiced frame with `l` harmonics of identical amplitude.
    fn make_params(omega_0: f32, l: u8, amps_fill: f32) -> MbeParams {
        let voiced = vec![true; l as usize];
        let amps = vec![amps_fill; l as usize];
        MbeParams::new(omega_0, l, &voiced, &amps).expect("test params valid")
    }

    #[test]
    fn cold_start_uses_annex_l_row_30() {
        let state = CrossRatePredictorState::new();
        assert_eq!(state.l_prev(), 30);
        let expected_omega = AMBE_PITCH_TABLE[30].omega_0;
        assert_eq!(state.omega_0_prev(), expected_omega);
        for l in 0..=L_MAX as usize {
            assert_eq!(state.m_tilde_prev()[l], 1.0);
        }
    }

    #[test]
    fn fast_path_passes_current_unchanged_when_r_within_one_percent() {
        let target_omega = 0.15_f32;
        let mut state = CrossRatePredictorState::new();
        state.omega_0_prev = target_omega;
        state.l_prev = 20;
        for l in 1..=20 {
            state.m_tilde_prev[l] = 42.0;
        }
        let params = make_params(target_omega, 20, 100.0);
        let out = blend(&params, target_omega, 20, &mut state);
        for (i, &a) in out.amplitudes_slice().iter().enumerate().take(20) {
            assert_eq!(a, 100.0, "harmonic {i}: fast-path should pass current through");
        }
        assert_eq!(state.l_prev, 20);
        assert_eq!(state.omega_0_prev, target_omega);
        for l in 1..=20 {
            assert_eq!(state.m_tilde_prev[l], 100.0);
        }
    }

    #[test]
    fn slow_path_blends_sixty_five_thirty_five_at_matching_grids() {
        let mut state = CrossRatePredictorState::new();
        state.omega_0_prev = 0.14;
        state.l_prev = 20;
        for l in 1..=20 {
            state.m_tilde_prev[l] = 100.0;
        }
        let params = make_params(0.15, 20, 400.0);
        let out = blend(&params, 0.15, 20, &mut state);
        let r_prev: f64 = 0.15_f64 / 0.14_f64;
        let expected_prior_linear = 100.0_f64 * r_prev.sqrt();
        let expected_blend =
            RHO_CROSS_RATE * expected_prior_linear + (1.0 - RHO_CROSS_RATE) * 400.0;
        for l in 5..=15 {
            let observed = out.amplitudes_slice()[l - 1] as f64;
            let rel_err = (observed - expected_blend).abs() / expected_blend;
            assert!(
                rel_err < 0.01,
                "l={l}: observed {observed}, expected {expected_blend} (rel err {rel_err})"
            );
        }
    }

    #[test]
    fn state_advances_every_frame_regardless_of_path() {
        let mut state = CrossRatePredictorState::new();
        let omega_init = state.omega_0_prev;
        let p1 = make_params(omega_init, state.l_prev, 50.0);
        let _ = blend(&p1, omega_init, state.l_prev, &mut state);
        assert_eq!(state.omega_0_prev, omega_init);
        let p2 = make_params(omega_init * 1.2, 25, 80.0);
        let _ = blend(&p2, omega_init * 1.2, 25, &mut state);
        assert!((state.omega_0_prev - omega_init * 1.2).abs() < 1e-6);
        assert_eq!(state.l_prev, 25);
    }

    #[test]
    fn history_is_one_indexed_and_zero_padded_after_blend() {
        let mut state = CrossRatePredictorState::new();
        state.omega_0_prev = 0.14;
        state.l_prev = 20;
        let params = make_params(0.15, 10, 300.0);
        let out = blend(&params, 0.15, 10, &mut state);
        assert_eq!(state.m_tilde_prev[0], 1.0);
        for l in 1..=10 {
            assert_eq!(state.m_tilde_prev[l], out.amplitudes_slice()[l - 1]);
        }
        for l in 11..=L_MAX as usize {
            assert_eq!(state.m_tilde_prev[l], 0.0, "slot {l} should be cleared");
        }
    }

    #[test]
    fn voicing_is_preserved_across_blend() {
        let mut state = CrossRatePredictorState::new();
        state.omega_0_prev = 0.14;
        state.l_prev = 10;
        let l: u8 = 10;
        let voiced = vec![true, false, true, false, true, false, true, false, true, false];
        let amps = vec![100.0_f32; l as usize];
        let params = MbeParams::new(0.15, l, &voiced, &amps).unwrap();
        let out = blend(&params, 0.15, l, &mut state);
        assert_eq!(&out.voiced_slice()[..l as usize], &voiced[..]);
    }

    #[test]
    fn reset_returns_to_cold_start() {
        let mut state = CrossRatePredictorState::new();
        state.omega_0_prev = 0.25;
        state.l_prev = 40;
        state.m_tilde_prev[1] = 999.0;
        state.reset();
        let fresh = CrossRatePredictorState::new();
        assert_eq!(state.l_prev, fresh.l_prev);
        assert_eq!(state.omega_0_prev, fresh.omega_0_prev);
        assert_eq!(state.m_tilde_prev[1], 1.0);
    }

    #[test]
    fn sustained_input_converges_toward_stable_output() {
        let mut state = CrossRatePredictorState::new();
        state.omega_0_prev = 0.20;
        state.l_prev = 15;
        for l in 1..=15 {
            state.m_tilde_prev[l] = 0.0;
        }
        let mut last_l1 = 0.0_f64;
        for frame in 0..30 {
            let p = make_params(0.15, 20, 500.0);
            let out = blend(&p, 0.15, 20, &mut state);
            let this_l1 = out.amplitudes_slice()[0] as f64;
            if frame > 5 {
                assert!(
                    (this_l1 - 500.0).abs() < 1.0,
                    "frame {frame}: blended l=1 = {this_l1}, expected ~500"
                );
            }
            last_l1 = this_l1;
        }
        assert!((last_l1 - 500.0).abs() < 1.0);
    }
}