use core::f64::consts::{LN_2, PI as PI64, SQRT_2};
use crate::imbe_wire::dequantize::DecodeError;
use crate::ambe_plus2_wire::frame::{
AMBE_BLOCK_LENGTHS, AMBE_GAIN_LEVELS, AMBE_HOC_B5, AMBE_HOC_B6, AMBE_HOC_B7,
AMBE_HOC_B8, AMBE_PITCH_TABLE, AMBE_PRBA24, AMBE_PRBA58, AMBE_VUV_CODEBOOK,
ANNEX_T, PitchEntry, ToneParams,
};
use crate::ambe_plus2_wire::priority::{deprioritize, prioritize};
use crate::mbe_params::{L_MAX, MbeParams};
/// Largest block length `J_i` handled by this module.
///
/// Sizes the per-block coefficient rows (`[f64; MAX_BLOCK_SIZE]`) and the number of
/// cached DCT cosine tables (lengths `0..=MAX_BLOCK_SIZE`).
pub const MAX_BLOCK_SIZE: usize = 17;
/// Returns the cached `j_i × j_i` DCT cosine table for block length `j_i`.
///
/// Entry `[k * j_i + n]` holds `cos(π·k·(n + ½) / j_i)`. This is the basis shared by
/// `forward_block_dct` and `inverse_block_dct`. Tables for every length
/// `0..=MAX_BLOCK_SIZE` are built once, on first use. Length 0 maps to an empty table.
///
/// # Panics
///
/// Panics if `j_i > MAX_BLOCK_SIZE`.
fn dct_cos(j_i: usize) -> &'static [f64] {
    use std::sync::OnceLock;

    static CACHE: OnceLock<[Vec<f64>; MAX_BLOCK_SIZE + 1]> = OnceLock::new();

    let all = CACHE.get_or_init(|| {
        core::array::from_fn(|len| {
            (0..len * len)
                .map(|flat| {
                    // Row-major layout: `flat = k * len + n`.
                    let (k, n) = (flat / len, flat % len);
                    (PI64 * (k as f64) * (n as f64 + 0.5) / len as f64).cos()
                })
                .collect()
        })
    });
    &all[j_i]
}
/// Harmonic count `L̃₋₁` that `DecoderState::new` assumes for the "previous frame"
/// before any frame has been processed.
pub const INIT_PREV_L: u8 = 15;
/// Largest `b̂₀` value that encodes a voice-frame pitch; `decode_pitch` rejects anything above it.
pub const PITCH_INDEX_MAX: u8 = 119;
/// First `b̂₀` value past the voice pitch range.
/// NOTE(review): not referenced in this module. It presumably marks the start of the
/// non-voice index range (the tests classify b̂₀ 120..=125 as erasures) — confirm.
pub const HALFRATE_TONE_FIRST: u8 = 120;
/// Predictor memory carried from one AMBE+2 frame to the next.
///
/// Stores the previous frame's log2 spectral magnitudes λ̃₋₁(l), its harmonic count
/// L̃₋₁ and its gain γ̃₋₁. `dequantize` reads this state and then overwrites it on
/// every voice frame. `quantize` maintains an encoder-side copy.
#[derive(Clone, Debug)]
pub struct DecoderState {
    /// λ̃₋₁ indexed by harmonic number. After an update, the slots above L̃₋₁ repeat λ̃₋₁(L̃₋₁).
    prev_lambda: [f64; L_MAX as usize + 2],
    /// L̃₋₁, the previous frame's harmonic count.
    prev_l: u8,
    /// γ̃₋₁, the previous frame's gain.
    prev_gamma: f64,
}

impl DecoderState {
    /// Creates the initial state: every λ̃₋₁ is 1.0, L̃₋₁ = `INIT_PREV_L`, γ̃₋₁ = 0.
    pub fn new() -> Self {
        Self {
            prev_lambda: [1.0; L_MAX as usize + 2],
            prev_l: INIT_PREV_L,
            prev_gamma: 0.0,
        }
    }

    /// Reads λ̃₋₁(l) using the predictor's boundary rules.
    ///
    /// `l = 0` reads λ̃₋₁(1), and any `l` above L̃₋₁ reads λ̃₋₁(L̃₋₁).
    fn prev_lambda_at(&self, l: u8) -> f64 {
        let idx = match l {
            0 => 1,
            _ => usize::from(l.min(self.prev_l)),
        };
        self.prev_lambda[idx]
    }

    /// Harmonic count L̃₋₁ of the last processed frame.
    pub fn previous_l(&self) -> u8 {
        self.prev_l
    }

    /// Gain γ̃₋₁ of the last processed frame.
    pub fn previous_gamma(&self) -> f64 {
        self.prev_gamma
    }

    /// Copies λ̃₋₁ into a vector of length L̃₋₁ + 1.
    ///
    /// Index 0 holds a 0.0 placeholder, so element `l` is λ̃₋₁(l).
    pub fn lambda_tilde_snapshot(&self) -> Vec<f64> {
        let active = &self.prev_lambda[1..=usize::from(self.prev_l)];
        let mut snapshot = Vec::with_capacity(active.len() + 1);
        snapshot.push(0.0);
        snapshot.extend_from_slice(active);
        snapshot
    }

    /// Rebuilds a state from an explicit (λ̃₋₁, L̃₋₁, γ̃₋₁) triple.
    ///
    /// `lambda_tilde_prev[l]` supplies λ̃₋₁(l) for `l` in `1..=l_tilde_prev`; index 0 is
    /// ignored. Slots above L̃₋₁ are filled with the last supplied value, or with 1.0
    /// when `l_tilde_prev` is 0.
    ///
    /// # Panics
    ///
    /// Panics if `lambda_tilde_prev` has no element at index `l_tilde_prev` while
    /// `l_tilde_prev > 0`. Debug builds also assert `l_tilde_prev <= L_MAX`.
    pub fn from_lambda_state(
        lambda_tilde_prev: &[f64],
        l_tilde_prev: u8,
        gamma_tilde_prev: f64,
    ) -> Self {
        debug_assert!(l_tilde_prev as usize <= L_MAX as usize);
        debug_assert!(lambda_tilde_prev.len() > l_tilde_prev as usize);
        let count = usize::from(l_tilde_prev);
        let mut state = Self::new();
        if count > 0 {
            state.prev_lambda[1..=count].copy_from_slice(&lambda_tilde_prev[1..=count]);
        }
        let tail = if count == 0 {
            1.0
        } else {
            lambda_tilde_prev[count]
        };
        state.prev_lambda[count + 1..].fill(tail);
        state.prev_l = l_tilde_prev;
        state.prev_gamma = gamma_tilde_prev;
        state
    }
}

impl Default for DecoderState {
    /// Same as [`DecoderState::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Pitch parameters decoded from a `b̂₀` index by `decode_pitch`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct PitchInfo {
    /// Fundamental frequency ω₀, as stored in `AMBE_PITCH_TABLE`.
    pub omega_0: f32,
    /// Number of harmonics `L` associated with this pitch.
    pub l: u8,
}
/// Looks up the fundamental frequency and harmonic count for pitch index `b̂₀`.
///
/// Returns `None` when `b0` exceeds `PITCH_INDEX_MAX`, because those indices do not
/// encode a voice pitch.
pub fn decode_pitch(b0: u8) -> Option<PitchInfo> {
    if b0 <= PITCH_INDEX_MAX {
        let entry = &AMBE_PITCH_TABLE[usize::from(b0)];
        Some(PitchInfo {
            omega_0: entry.omega_0,
            l: entry.l,
        })
    } else {
        None
    }
}
/// Expands V/UV codebook row `b̂₁` into per-harmonic voicing decisions.
///
/// Harmonic `l` (1-based) reads codebook column `j_l = ⌊16·l·ω₀ / 2π⌋`, clamped to `0..=7`.
/// Entry `l - 1` of the result holds harmonic `l`'s decision. Entries at `l` and above are `false`.
pub fn expand_vuv(b1: u8, omega_0: f32, l: u8) -> [bool; L_MAX as usize] {
    debug_assert!(b1 < 32, "b̂₁ is a 5-bit index");
    debug_assert!(l <= L_MAX);
    let row = &AMBE_VUV_CODEBOOK[usize::from(b1)];
    let w0 = f64::from(omega_0);
    let mut decisions = [false; L_MAX as usize];
    for harmonic in 1..=l {
        let band = (f64::from(harmonic) * 16.0 * w0 / (2.0 * PI64)).floor() as i32;
        decisions[usize::from(harmonic - 1)] = row[band.clamp(0, 7) as usize];
    }
    decisions
}
/// Reconstructs the frame gain γ̃ = Δγ(b̂₂) + ½·γ̃₋₁ from the differential gain level table.
pub fn decode_gain(b2: u8, prev_gamma: f64) -> f64 {
    debug_assert!(b2 < 32);
    let delta = f64::from(AMBE_GAIN_LEVELS[usize::from(b2)]);
    delta + prev_gamma * 0.5
}
/// Builds the 8-element PRBA vector G̃ from its two VQ indices.
///
/// Element 0 is forced to 0. Elements 1..=3 come from `AMBE_PRBA24[b3]` and elements
/// 4..=7 come from `AMBE_PRBA58[b4]`.
pub fn decode_prba_vector(b3: u16, b4: u8) -> [f64; 8] {
    debug_assert!(b3 < 512 && b4 < 128);
    let low = AMBE_PRBA24[usize::from(b3)];
    let high = AMBE_PRBA58[usize::from(b4)];
    let mut g = [0.0f64; 8];
    for k in 0..3 {
        g[1 + k] = f64::from(low[k]);
    }
    for k in 0..4 {
        g[4 + k] = f64::from(high[k]);
    }
    g
}
/// Inverse 8-point DCT from the PRBA vector G̃ to the residual vector R̃.
///
/// `R̃[i] = G̃[0] + 2·Σ_{m=1..7} G̃[m]·cos(π·m·(i + ½)/8)`.
pub fn prba_to_residuals(g: &[f64; 8]) -> [f64; 8] {
    core::array::from_fn(|i| {
        let centre = i as f64 + 0.5;
        g.iter().enumerate().fold(0f64, |sum, (m, &gm)| {
            let weight = if m == 0 { 1.0 } else { 2.0 };
            sum + weight * gm * (PI64 * (m as f64) * centre / 8.0).cos()
        })
    })
}
/// Folds 8 residuals into 4 `(C_{i,1}, C_{i,2})` pairs, one per block.
///
/// Each adjacent pair `(a, b)` becomes `((a + b) / 2, (a − b)·√2/4)`.
pub fn pair_split(r: &[f64; 8]) -> [(f64, f64); 4] {
    let scale = SQRT_2 / 4.0;
    core::array::from_fn(|i| {
        let (a, b) = (r[2 * i], r[2 * i + 1]);
        ((a + b) / 2.0, scale * (a - b))
    })
}
/// Assembles the per-block DCT coefficient matrix C.
///
/// Row `i` gets `C_{i,1}`, `C_{i,2}` from `pair[i]`. Its higher-order coefficients
/// `C_{i,3..=min(J_i, 6)}` come from the row of the corresponding HOC codebook
/// (`b5..=b8` → `AMBE_HOC_B5..=B8`). All other entries stay 0.
pub fn assemble_hoc_matrix(
    pair: &[(f64, f64); 4],
    b5: u8,
    b6: u8,
    b7: u8,
    b8: u8,
    blocks: &[u8; 4],
) -> [[f64; MAX_BLOCK_SIZE]; 4] {
    debug_assert!(b5 < 32 && b6 < 16 && b7 < 16 && b8 < 8);
    let hoc: [[f32; 4]; 4] = [
        AMBE_HOC_B5[b5 as usize],
        AMBE_HOC_B6[b6 as usize],
        AMBE_HOC_B7[b7 as usize],
        AMBE_HOC_B8[b8 as usize],
    ];
    let mut c = [[0f64; MAX_BLOCK_SIZE]; 4];
    for (i, row) in c.iter_mut().enumerate() {
        let (mean, slope) = pair[i];
        row[0] = mean;
        row[1] = slope;
        // Number of HOC entries the block can hold: min(J_i, 6) − 2, never negative.
        let used = usize::from(blocks[i]).min(6).saturating_sub(2);
        for (dst, &src) in row[2..2 + used].iter_mut().zip(hoc[i].iter()) {
            *dst = f64::from(src);
        }
    }
    c
}
/// Inverse per-block DCT from the coefficient matrix C to the residual T̃.
///
/// Blocks are laid out back to back in `T̃`. Within block `i` of length `J`,
/// `T̃[n] = Σ_k α(k)·C[i][k]·cos(π·k·(n + ½)/J)` with `α(0) = 1` and `α(k>0) = 2`.
/// Zero-length blocks are skipped.
pub fn inverse_block_dct(
    c: &[[f64; MAX_BLOCK_SIZE]; 4],
    blocks: &[u8; 4],
) -> [f64; L_MAX as usize] {
    let mut t = [0f64; L_MAX as usize];
    let mut start = 0usize;
    for (coeffs, &len) in c.iter().zip(blocks) {
        let len = usize::from(len);
        if len == 0 {
            continue;
        }
        let basis = dct_cos(len);
        for (n, out) in t[start..start + len].iter_mut().enumerate() {
            *out = (0..len).fold(0f64, |acc, k| {
                let alpha = if k == 0 { 1.0 } else { 2.0 };
                acc + alpha * coeffs[k] * basis[k * len + n]
            });
        }
        start += len;
    }
    t
}
/// Reconstructs the log2 magnitudes λ̃(1..=L) from the residual T̃ and gain γ̃.
///
/// Each harmonic is mapped onto the previous frame's harmonic axis at `k = L̃₋₁·l/L`, and
/// λ̃₋₁ is linearly interpolated there. The result is
/// `λ̃(l) = T̃(l) + 0.65·pred(l) − 0.65·mean(pred) + γ̃ − ½·log2 L − mean(T̃)`.
/// Index 0 and the indices above `L` are left at 0.0.
pub fn apply_log_prediction(
    t: &[f64; L_MAX as usize],
    l: u8,
    gamma: f64,
    state: &DecoderState,
) -> [f64; L_MAX as usize + 2] {
    let l_curr = f64::from(l);
    let l_prev = f64::from(state.prev_l);

    // For harmonic `h`, returns (δ, λ̃₋₁(⌊k⌋), λ̃₋₁(⌊k⌋ + 1)) with k = L̃₋₁·h/L.
    let neighbours = |h: u8| {
        let k = l_prev * f64::from(h) / l_curr;
        let base = k.floor();
        (
            k - base,
            state.prev_lambda_at(base as u8),
            state.prev_lambda_at(base as u8 + 1),
        )
    };

    let t_sum: f64 = t[..l as usize].iter().sum();
    let gamma_intercept = gamma - 0.5 * l_curr.log2() - t_sum / l_curr;

    let mut pred_sum = 0f64;
    for h in 1..=l {
        let (delta, lo, hi) = neighbours(h);
        pred_sum += (1.0 - delta) * lo + delta * hi;
    }
    let mean = pred_sum / l_curr;

    let mut lambda = [0f64; L_MAX as usize + 2];
    for h in 1..=l {
        let (delta, lo, hi) = neighbours(h);
        lambda[usize::from(h)] = t[usize::from(h - 1)]
            + 0.65 * (1.0 - delta) * lo
            + 0.65 * delta * hi
            - 0.65 * mean
            + gamma_intercept;
    }
    lambda
}
/// Converts log2 magnitudes into linear spectral amplitudes M̃.
///
/// Voiced harmonics get `2^λ̃(l)`. Unvoiced harmonics are additionally scaled by
/// `0.2046/√ω₀`. Only the first `voiced.len()` entries are written; the rest stay 0.
pub fn compute_m_tilde(
    lambda: &[f64; L_MAX as usize + 2],
    voiced: &[bool],
    omega_0: f32,
) -> [f32; L_MAX as usize] {
    debug_assert!(voiced.len() <= L_MAX as usize);
    let unvoiced_gain = 0.2046 / f64::from(omega_0).sqrt();
    let mut m = [0f32; L_MAX as usize];
    for (idx, &is_voiced) in voiced.iter().enumerate() {
        // 2^λ computed as e^(λ·ln 2).
        let linear = (LN_2 * lambda[idx + 1]).exp();
        let scaled = if is_voiced {
            linear
        } else {
            unvoiced_gain * linear
        };
        m[idx] = scaled as f32;
    }
    m
}
/// Dequantizes one voice frame from its four prioritized code vectors.
///
/// The vectors are de-prioritized into the indices `b̂₀..=b̂₈`, which are then decoded:
/// pitch (`b̂₀`), V/UV pattern (`b̂₁`), differential gain (`b̂₂`), PRBA vectors (`b̂₃`,
/// `b̂₄`) and higher-order DCT coefficients (`b̂₅..=b̂₈`). The log magnitudes are then
/// reconstructed with the inter-frame predictor held in `state`.
///
/// On success `state` is advanced to this frame. It stores λ̃ for harmonics `1..=L`
/// (with the last value repeated above `L`), `L` itself and γ̃.
///
/// # Errors
///
/// * `DecodeError::BadPitch` if `b̂₀` is not a voice pitch index.
/// * `DecodeError::InvalidParams` if `MbeParams::new` rejects the reconstructed frame.
///
/// `state` is left unchanged when an error is returned.
pub fn dequantize(
    u: &[u16; 4],
    state: &mut DecoderState,
) -> Result<MbeParams, DecodeError> {
    let b = deprioritize(u);

    let pitch = decode_pitch(b[0] as u8).ok_or(DecodeError::BadPitch)?;
    let (omega_0, l) = (pitch.omega_0, pitch.l);
    let n = usize::from(l);

    let voiced = expand_vuv(b[1] as u8, omega_0, l);
    let gamma = decode_gain(b[2] as u8, state.prev_gamma);

    let pair = pair_split(&prba_to_residuals(&decode_prba_vector(b[3], b[4] as u8)));
    let blocks = AMBE_BLOCK_LENGTHS[(l - 9) as usize];
    let hoc = assemble_hoc_matrix(
        &pair,
        b[5] as u8,
        b[6] as u8,
        b[7] as u8,
        b[8] as u8,
        &blocks,
    );
    let t = inverse_block_dct(&hoc, &blocks);
    let lambda = apply_log_prediction(&t, l, gamma, state);
    let m_tilde = compute_m_tilde(&lambda, &voiced[..n], omega_0);

    let params = MbeParams::new(omega_0, l, &voiced[..n], &m_tilde[..n])
        .map_err(DecodeError::InvalidParams)?;

    // Commit the predictor memory only after the frame was accepted.
    state.prev_lambda[1..=n].copy_from_slice(&lambda[1..=n]);
    state.prev_lambda[n + 1..].fill(lambda[n]);
    state.prev_l = l;
    state.prev_gamma = gamma;
    Ok(params)
}
/// Finds the pitch index `b̂₀` whose table ω₀ is nearest to `omega_0`.
///
/// The search treats `AMBE_PITCH_TABLE` as ordered by decreasing ω₀. On an exact tie
/// the lower index wins. Returns `None` in these cases:
/// * `omega_0` is non-finite or non-positive;
/// * `omega_0` lies outside the table range by more than a relative 1e-6.
pub fn encode_pitch(omega_0: f32) -> Option<u8> {
    if !omega_0.is_finite() || omega_0 <= 0.0 {
        return None;
    }
    let target = f64::from(omega_0);
    let table_w = |i: usize| f64::from(AMBE_PITCH_TABLE[i].omega_0);
    let last = AMBE_PITCH_TABLE.len() - 1;
    if target > table_w(0) * 1.000001 || target < table_w(last) * 0.999999 {
        return None;
    }

    // Bisection keeps `upper` on the high-ω side and `lower` on the low-ω side of `target`.
    let (mut upper, mut lower) = (0usize, last);
    while lower - upper > 1 {
        let mid = (upper + lower) / 2;
        if table_w(mid) >= target {
            upper = mid;
        } else {
            lower = mid;
        }
    }

    let dist_upper = (target - table_w(upper)).abs();
    let dist_lower = (target - table_w(lower)).abs();
    let chosen = if dist_lower < dist_upper { lower } else { upper };
    Some(chosen as u8)
}
/// Chooses the V/UV codebook index `b̂₁` that best matches per-harmonic voicing.
///
/// Each harmonic votes (+1 voiced, −1 unvoiced) into band `⌊16·l·ω₀/2π⌋` clamped to
/// `0..=7`. A band with a non-negative tally counts as voiced. The result is the codebook
/// row with the fewest mismatches over bands that received at least one vote. Ties go to
/// the lowest index, and an empty codebook yields 0.
pub fn encode_vuv(voiced: &[bool], omega_0: f32) -> u8 {
    let w0 = f64::from(omega_0);
    let mut tally = [0i32; 8];
    let mut touched = [false; 8];
    for (pos, &is_voiced) in voiced.iter().enumerate() {
        let harmonic = (pos + 1) as f64;
        let band = (harmonic * 16.0 * w0 / (2.0 * PI64)).floor() as i32;
        let band = band.clamp(0, 7) as usize;
        tally[band] += if is_voiced { 1 } else { -1 };
        touched[band] = true;
    }

    // `min_by_key` returns the first minimum, matching a strict `<` scan.
    AMBE_VUV_CODEBOOK
        .iter()
        .enumerate()
        .min_by_key(|(_, row)| {
            (0..8)
                .filter(|&k| touched[k] && row[k] != (tally[k] >= 0))
                .count()
        })
        .map_or(0, |(idx, _)| idx as u8)
}
/// Quantizes a differential gain to the nearest `AMBE_GAIN_LEVELS` index `b̂₂`.
///
/// Values at or beyond either end of the (ascending) table clamp to the first or last
/// index. On an exact tie the lower index wins.
pub fn encode_gain(delta_gamma: f64) -> u8 {
    let level = |i: usize| f64::from(AMBE_GAIN_LEVELS[i]);
    let top = AMBE_GAIN_LEVELS.len() - 1;
    if delta_gamma <= level(0) {
        return 0;
    }
    if delta_gamma >= level(top) {
        return top as u8;
    }

    // Invariant: level(below) <= delta_gamma < level(above).
    let (mut below, mut above) = (0usize, top);
    while above - below > 1 {
        let mid = below + (above - below) / 2;
        if level(mid) <= delta_gamma {
            below = mid;
        } else {
            above = mid;
        }
    }

    let dist_below = (delta_gamma - level(below)).abs();
    let dist_above = (level(above) - delta_gamma).abs();
    let chosen = if dist_above < dist_below { above } else { below };
    chosen as u8
}
/// Expands 4 `(mean, k2)` block pairs back into 8 residuals: `mean ± √2·k2`.
///
/// This is the encoder-side inverse of `pair_split`.
pub fn pair_join(pair: &[(f64, f64); 4]) -> [f64; 8] {
    core::array::from_fn(|n| {
        let (mean, k2) = pair[n / 2];
        let offset = SQRT_2 * k2;
        if n % 2 == 0 {
            mean + offset
        } else {
            mean - offset
        }
    })
}
/// Forward 8-point DCT from residuals R̃ to the PRBA vector G̃.
///
/// `G̃[m] = (1/8)·Σ_i R̃[i]·cos(π·m·(i + ½)/8)`. This is the inverse of `prba_to_residuals`.
pub fn residuals_to_prba(r: &[f64; 8]) -> [f64; 8] {
    core::array::from_fn(|m| {
        let projection = r.iter().enumerate().fold(0f64, |sum, (i, &ri)| {
            let arg = PI64 * (m as f64) * (i as f64 + 0.5) / 8.0;
            sum + ri * arg.cos()
        });
        projection / 8.0
    })
}
/// Returns the index of the codeword in `book` nearest to `target`.
///
/// Distance is squared Euclidean distance, computed in `f64`. Ties keep the lowest index,
/// and an empty codebook yields 0.
fn vq_nearest<const K: usize>(target: &[f64; K], book: &[[f32; K]]) -> usize {
    let mut winner = (0usize, f64::INFINITY);
    for (idx, codeword) in book.iter().enumerate() {
        let dist = target.iter().zip(codeword).fold(0f64, |acc, (&x, &q)| {
            let err = x - f64::from(q);
            acc + err * err
        });
        if dist < winner.1 {
            winner = (idx, dist);
        }
    }
    winner.0
}
/// Vector-quantizes the PRBA vector G̃ into its two indices `(b̂₃, b̂₄)`.
///
/// Elements 1..=3 are matched against `AMBE_PRBA24` and elements 4..=7 against
/// `AMBE_PRBA58`. Element 0 is not transmitted.
pub fn quantize_prba(g: &[f64; 8]) -> (u16, u8) {
    let low = [g[1], g[2], g[3]];
    let high = [g[4], g[5], g[6], g[7]];
    (
        vq_nearest(&low, &AMBE_PRBA24) as u16,
        vq_nearest(&high, &AMBE_PRBA58) as u8,
    )
}
fn quantize_hoc_block(c_block: &[f64; MAX_BLOCK_SIZE], j_i: u8, book: &[[f32; 4]]) -> u8 {
if j_i < 3 {
return 0;
}
let k_max = j_i.min(6) as usize;
let mut target = [0f64; 4];
for k in 3..=k_max {
target[k - 3] = c_block[k - 1];
}
vq_nearest(&target, book) as u8
}
/// Quantizes the HOC coefficients of all four blocks into `(b̂₅, b̂₆, b̂₇, b̂₈)`.
///
/// Block `i` uses codebook `AMBE_HOC_B5..=B8` respectively.
pub fn quantize_hoc_all(
    c: &[[f64; MAX_BLOCK_SIZE]; 4],
    blocks: &[u8; 4],
) -> (u8, u8, u8, u8) {
    let books: [&[[f32; 4]]; 4] = [&AMBE_HOC_B5, &AMBE_HOC_B6, &AMBE_HOC_B7, &AMBE_HOC_B8];
    let idx: [u8; 4] =
        core::array::from_fn(|i| quantize_hoc_block(&c[i], blocks[i], books[i]));
    (idx[0], idx[1], idx[2], idx[3])
}
/// Forward per-block DCT from the residual T̃ to the coefficient matrix C.
///
/// Block `i` of length `J` covers the next `J` entries of `T̃`, and
/// `C[i][k] = (1/J)·Σ_n T̃[n]·cos(π·k·(n + ½)/J)`. This is the inverse of
/// `inverse_block_dct`. Zero-length blocks leave their row at 0.
pub fn forward_block_dct(
    t: &[f64; L_MAX as usize],
    blocks: &[u8; 4],
) -> [[f64; MAX_BLOCK_SIZE]; 4] {
    let mut c = [[0f64; MAX_BLOCK_SIZE]; 4];
    let mut start = 0usize;
    for (row, &len) in c.iter_mut().zip(blocks) {
        let len = usize::from(len);
        if len == 0 {
            continue;
        }
        let basis = dct_cos(len);
        let scale = 1.0 / len as f64;
        let segment = &t[start..start + len];
        for (k, coeff) in row[..len].iter_mut().enumerate() {
            let dot = segment
                .iter()
                .enumerate()
                .fold(0f64, |acc, (n, &x)| acc + x * basis[k * len + n]);
            *coeff = dot * scale;
        }
        start += len;
    }
    c
}
/// Encoder-side inverse of `apply_log_prediction`.
///
/// Returns the residual T̃ (entries `0..L`) and the gain γ = mean(λ) + ½·log2 L.
/// Feeding both back into `apply_log_prediction` with the same `state` reproduces `lambda`.
/// The residual is built to have zero mean.
pub fn forward_log_prediction(
    lambda: &[f64; L_MAX as usize + 2],
    l: u8,
    state: &DecoderState,
) -> ([f64; L_MAX as usize], f64) {
    let n = usize::from(l);
    let l_curr = f64::from(l);
    let l_prev = f64::from(state.prev_l);

    // Interpolated λ̃₋₁ at k = L̃₋₁·h/L for each harmonic h.
    let mut pred = [0f64; L_MAX as usize];
    for h in 1..=l {
        let k = l_prev * f64::from(h) / l_curr;
        let base = k.floor();
        let frac = k - base;
        let below = state.prev_lambda_at(base as u8);
        let above = state.prev_lambda_at(base as u8 + 1);
        pred[usize::from(h - 1)] = (1.0 - frac) * below + frac * above;
    }
    let mean_pred = pred[..n].iter().fold(0f64, |acc, &p| acc + p) / l_curr;
    let mean_lambda = lambda[1..=n].iter().fold(0f64, |acc, &x| acc + x) / l_curr;
    let gamma = mean_lambda + 0.5 * l_curr.log2();

    let mut t = [0f64; L_MAX as usize];
    for (idx, slot) in t[..n].iter_mut().enumerate() {
        *slot = lambda[idx + 1] - 0.65 * pred[idx] + 0.65 * mean_pred - mean_lambda;
    }
    (t, gamma)
}
/// Converts linear amplitudes M̃ into log2 magnitudes λ (inverse of `compute_m_tilde`).
///
/// Unvoiced amplitudes are first divided by `0.2046/√ω₀`. Non-positive amplitudes map to
/// the floor value −1000. Entry `l` of the result holds harmonic `l`; the other entries stay 0.
pub fn m_tilde_to_lambda(
    m_tilde: &[f32],
    voiced: &[bool],
    omega_0: f32,
) -> [f64; L_MAX as usize + 2] {
    debug_assert_eq!(voiced.len(), m_tilde.len());
    let unvoiced_gain = 0.2046 / f64::from(omega_0).sqrt();
    let mut lambda = [0f64; L_MAX as usize + 2];
    for (idx, &mag) in m_tilde.iter().enumerate() {
        let m = f64::from(mag);
        lambda[idx + 1] = if m <= 0.0 {
            -1000.0
        } else if voiced[idx] {
            m.log2()
        } else {
            (m / unvoiced_gain).log2()
        };
    }
    lambda
}
pub fn quantize(
params: &MbeParams,
state: &mut DecoderState,
) -> Result<[u16; 4], DecodeError> {
let l = params.harmonic_count();
let omega_0 = params.omega_0();
let b0 = encode_pitch(omega_0).ok_or(DecodeError::BadPitch)? as u16;
let voiced_slice = params.voiced_slice();
let b1 = u16::from(encode_vuv(voiced_slice, omega_0));
let lambda = m_tilde_to_lambda(
params.amplitudes_slice(),
voiced_slice,
omega_0,
);
let (t, gamma) = forward_log_prediction(&lambda, l, state);
let delta_gamma = gamma - 0.5 * state.prev_gamma;
let b2 = u16::from(encode_gain(delta_gamma));
let blocks = AMBE_BLOCK_LENGTHS[(l - 9) as usize];
let c = forward_block_dct(&t, &blocks);
let pair: [(f64, f64); 4] = core::array::from_fn(|i| (c[i][0], c[i][1]));
let r = pair_join(&pair);
let g = residuals_to_prba(&r);
let (b3, b4) = quantize_prba(&g);
let (b5, b6, b7, b8) = quantize_hoc_all(&c, &blocks);
let mut b = [0u16; 9];
b[0] = b0;
b[1] = b1;
b[2] = b2;
b[3] = b3;
b[4] = u16::from(b4);
b[5] = u16::from(b5);
b[6] = u16::from(b6);
b[7] = u16::from(b7);
b[8] = u16::from(b8);
let u = prioritize(&b);
let delta_gamma_q = f64::from(AMBE_GAIN_LEVELS[b2 as usize]);
let gamma_q = delta_gamma_q + 0.5 * state.prev_gamma;
for l_h in 1..=l as usize {
state.prev_lambda[l_h] = lambda[l_h];
}
for l_h in (l as usize + 1)..=L_MAX as usize + 1 {
state.prev_lambda[l_h] = lambda[l as usize];
}
state.prev_l = l;
state.prev_gamma = gamma_q;
Ok(u)
}
/// NOTE(review): not referenced in this module. It is presumably the first `b̂₀` value of
/// the tone range (the tests describe b̂₀ 126 and 127 as tone indices) — confirm.
pub const TONE_B0_FIRST: u8 = 126;
/// Linear harmonic magnitude given to a tone at the maximum amplitude code (127).
pub const TONE_AMPLITUDE_PEAK: f64 = 16384.0;
/// Exponent step per amplitude code:
/// magnitude = `TONE_AMPLITUDE_PEAK · 10^(STEP · (amplitude − 127))`.
pub const TONE_AMPLITUDE_EXPONENT_STEP: f64 = 0.03555;
/// Frame category returned by `classify_ambe_plus2_frame`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum FrameKind {
    /// The de-prioritized `b̂₀` is a valid pitch index (`<= PITCH_INDEX_MAX`).
    Voice,
    /// The tone signature is present: bits 6..=11 of `û₀` are all set and the low nibble
    /// of `û₃` is clear.
    Tone,
    /// Neither a voice pitch nor a tone signature.
    Erasure,
}
/// Fields carried by a tone frame, as extracted by `parse_tone_frame`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ToneFrameFields {
    /// 8-bit tone identifier; `tone_to_mbe_params` looks it up in `ANNEX_T`.
    pub id: u8,
    /// 7-bit amplitude code (0..=127).
    pub amplitude: u8,
}
/// Result of `decode_to_params` for one frame.
#[derive(Clone, Debug)]
pub enum Decoded {
    /// Voice frame, dequantized with the running decoder state.
    Voice(MbeParams),
    /// Tone frame: the raw tone fields and the MBE parameters synthesized from them.
    Tone {
        fields: ToneFrameFields,
        params: MbeParams,
    },
    /// Erasure frame; no parameters are produced.
    Erasure,
}
/// Classifies a frame from its four prioritized code vectors.
///
/// The tone signature is checked first: bits 6..=11 of `û₀` all set and the low nibble of
/// `û₃` clear. Otherwise the frame is a voice frame when the de-prioritized `b̂₀` is a
/// valid pitch index, and an erasure when it is not.
pub fn classify_ambe_plus2_frame(u: &[u16; 4]) -> FrameKind {
    let has_tone_signature = (u[0] >> 6) & 0x3F == 0x3F;
    let trailer_clear = u[3] & 0x0F == 0;
    if has_tone_signature && trailer_clear {
        return FrameKind::Tone;
    }
    match deprioritize(u)[0] as u8 {
        b0 if b0 <= PITCH_INDEX_MAX => FrameKind::Voice,
        _ => FrameKind::Erasure,
    }
}
/// Extracts the tone id and amplitude from a tone frame's code vectors.
///
/// Returns `None` unless bits 6..=11 of `û₀` are all set and the low nibble of `û₃` is
/// clear. The id is bits 5..=12 of `û₃`. The 7-bit amplitude is bits 0..=5 of `û₀`
/// (high part) followed by bit 4 of `û₃` (lowest bit).
pub fn parse_tone_frame(u: &[u16; 4]) -> Option<ToneFrameFields> {
    let signature_ok = (u[0] >> 6) & 0x3F == 0x3F;
    let trailer_ok = u[3] & 0x0F == 0;
    (signature_ok && trailer_ok).then(|| {
        let high = (u[0] & 0x3F) as u8;
        let low = ((u[3] >> 4) & 1) as u8;
        ToneFrameFields {
            id: ((u[3] >> 5) & 0xFF) as u8,
            amplitude: (high << 1) | low,
        }
    })
}
/// Packs a tone id and amplitude into the four code vectors of a tone frame.
///
/// The amplitude is truncated to 7 bits. `û₀` carries the all-ones signature in bits
/// 6..=11 and the amplitude's upper 6 bits. `û₃` carries id bit 0 at bit 13, the id in
/// bits 5..=12 and the amplitude's lowest bit at bit 4. Its low nibble stays clear.
/// `û₁` and `û₂` hold further copies of id bits.
pub fn encode_tone_frame_info(id: u8, amplitude: u8) -> [u16; 4] {
    let amp = u16::from(amplitude & 0x7F);
    let id = u16::from(id);
    [
        (0x3F << 6) | (amp >> 1),
        (id << 4) | (id >> 4),
        ((id & 0x0F) << 7) | (id >> 1),
        ((id & 1) << 13) | (id << 5) | ((amp & 1) << 4),
    ]
}
/// Builds MBE synthesis parameters for tone `id` at amplitude code `amplitude`.
///
/// The tone's fundamental `f0` and its (up to two) harmonic numbers `l1`, `l2` come from
/// `ANNEX_T`. The parameters are then derived as follows:
/// * ω₀ = 2π·f0/8000;
/// * the harmonic count is ⌊3812.5/f0⌋, clamped to `L_MIN..=L_MAX`;
/// * only the listed harmonics that fall inside `1..=L` are voiced, each with magnitude
///   `TONE_AMPLITUDE_PEAK · 10^(TONE_AMPLITUDE_EXPONENT_STEP · (amplitude − 127))`.
///
/// Tone id 255 uses the same layout but with zero magnitude.
///
/// Returns `None` if `id` has no `ANNEX_T` entry, if `f0` is not positive, or if
/// `MbeParams::new` rejects the result.
pub fn tone_to_mbe_params(id: u8, amplitude: u8) -> Option<MbeParams> {
    let ToneParams { f0, l1, l2 } = ANNEX_T[usize::from(id)]?;
    let f0 = f64::from(f0);
    if f0 <= 0.0 {
        return None;
    }

    let omega_0 = (2.0 * PI64 / 8000.0) * f0;
    // Float-to-int `as` saturates, so a very low f0 cannot wrap before the clamp.
    let l = ((3812.5 / f0).floor() as u8).clamp(crate::mbe_params::L_MIN, L_MAX);
    let n = usize::from(l);

    let magnitude = if id == 255 {
        0.0
    } else {
        TONE_AMPLITUDE_PEAK
            * 10f64.powf(TONE_AMPLITUDE_EXPONENT_STEP * (f64::from(amplitude) - 127.0))
    };

    let mut voiced = [false; L_MAX as usize];
    let mut magnitudes = [0f32; L_MAX as usize];
    for &harmonic in &[l1, l2] {
        if (1..=l).contains(&harmonic) {
            let idx = usize::from(harmonic - 1);
            voiced[idx] = true;
            magnitudes[idx] = magnitude as f32;
        }
    }
    MbeParams::new(omega_0 as f32, l, &voiced[..n], &magnitudes[..n]).ok()
}
/// Decodes one frame of any kind into MBE parameters.
///
/// Voice frames go through `dequantize` and advance `state`. Tone frames are parsed and
/// synthesized with `tone_to_mbe_params`; they do not touch `state`. Erasures yield
/// `Decoded::Erasure` and also leave `state` alone.
///
/// # Errors
///
/// * Any error from `dequantize` for voice frames.
/// * `DecodeError::BadPitch` for a tone frame whose fields cannot be parsed or whose id
///   has no tone definition.
pub fn decode_to_params(
    u: &[u16; 4],
    state: &mut DecoderState,
) -> Result<Decoded, DecodeError> {
    match classify_ambe_plus2_frame(u) {
        FrameKind::Voice => dequantize(u, state).map(Decoded::Voice),
        FrameKind::Tone => {
            let fields = parse_tone_frame(u).ok_or(DecodeError::BadPitch)?;
            tone_to_mbe_params(fields.id, fields.amplitude)
                .map(|params| Decoded::Tone { fields, params })
                .ok_or(DecodeError::BadPitch)
        }
        FrameKind::Erasure => Ok(Decoded::Erasure),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the AMBE+2 quantizer/dequantizer and tone-frame handling.
    use super::*;

    #[test]
    fn pitch_decode_endpoints_match_annex_l() {
        use core::f32::consts::PI;
        let p0 = decode_pitch(0).unwrap();
        assert_eq!(p0.l, 9);
        assert!((p0.omega_0 - 0.049971 * 2.0 * PI).abs() < 1e-5);
        let p_last = decode_pitch(119).unwrap();
        assert_eq!(p_last.l, 56);
        assert!((p_last.omega_0 - 0.008125 * 2.0 * PI).abs() < 1e-5);
    }

    #[test]
    fn pitch_decode_rejects_tone_and_reserved() {
        for b0 in 120u8..=255 {
            assert!(decode_pitch(b0).is_none(), "b̂₀ = {b0}");
        }
    }

    #[test]
    fn expand_vuv_all_voiced_codebook_0_gives_voiced() {
        let v = expand_vuv(0, 0.2, 9);
        for l_h in 0..9 {
            assert!(v[l_h]);
        }
    }

    #[test]
    fn expand_vuv_all_unvoiced_codebook_16_gives_unvoiced() {
        let v = expand_vuv(16, 0.2, 9);
        for l_h in 0..9 {
            assert!(!v[l_h]);
        }
    }

    #[test]
    fn expand_vuv_j_l_clamps_to_seven() {
        // f₀ = ω₀/2π = 1/16 puts harmonic l at band j_l ≈ l. Harmonics 8 and 9 land at or
        // past the last band and must both read codebook column 7.
        let omega = 2.0 * core::f32::consts::PI / 16.0;
        let v = expand_vuv(0, omega, 9);
        let row = &AMBE_VUV_CODEBOOK[0];
        assert_eq!(v[7], row[7]);
        assert_eq!(v[8], row[7]);
    }

    #[test]
    fn decode_gain_first_frame_equals_table_level() {
        assert!((decode_gain(0, 0.0) - (-2.0)).abs() < 1e-6);
        assert!((decode_gain(31, 0.0) - 6.874496).abs() < 1e-6);
    }

    #[test]
    fn decode_gain_differential_recurrence() {
        let g0 = decode_gain(16, 10.0);
        let d16 = AMBE_GAIN_LEVELS[16] as f64;
        assert!((g0 - (d16 + 5.0)).abs() < 1e-6);
    }

    #[test]
    fn prba_vector_sources_from_correct_codebooks() {
        let g = decode_prba_vector(0, 0);
        assert_eq!(g[0], 0.0);
        assert!((g[1] - 0.526055).abs() < 1e-6);
        assert!((g[2] - (-0.328567)).abs() < 1e-6);
        assert!((g[3] - (-0.304727)).abs() < 1e-6);
        assert!((g[4] - (-0.103660)).abs() < 1e-6);
        assert!((g[5] - 0.094597).abs() < 1e-6);
        assert!((g[6] - (-0.013149)).abs() < 1e-6);
        assert!((g[7] - 0.081501).abs() < 1e-6);
    }

    #[test]
    fn prba_dc_only_propagates_to_all_residuals() {
        let g = [3.0f64, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let r = prba_to_residuals(&g);
        for i in 0..8 {
            assert!((r[i] - 3.0).abs() < 1e-9);
        }
    }

    #[test]
    fn pair_split_matches_encoder_inverse() {
        let c_in: [(f64, f64); 4] = [(1.5, 0.3), (2.0, -0.5), (-1.0, 0.7), (0.5, -0.2)];
        let mut r = [0f64; 8];
        for i in 0..4 {
            r[2 * i] = c_in[i].0 + core::f64::consts::SQRT_2 * c_in[i].1;
            r[2 * i + 1] = c_in[i].0 - core::f64::consts::SQRT_2 * c_in[i].1;
        }
        let c_out = pair_split(&r);
        for i in 0..4 {
            assert!((c_out[i].0 - c_in[i].0).abs() < 1e-12, "block {i} mean");
            assert!((c_out[i].1 - c_in[i].1).abs() < 1e-12, "block {i} k=2");
        }
    }

    #[test]
    fn hoc_matrix_respects_block_length_clamp() {
        let pair = [(1.0, 0.1), (2.0, 0.2), (3.0, 0.3), (4.0, 0.4)];
        let blocks = [2u8, 2, 2, 3];
        let c = assemble_hoc_matrix(&pair, 0, 0, 0, 0, &blocks);
        assert_eq!(c[3][0], 4.0);
        assert_eq!(c[3][1], 0.4);
        assert!((c[3][2] - f64::from(AMBE_HOC_B8[0][0])).abs() < 1e-6);
        for i in 0..3 {
            for k in 2..MAX_BLOCK_SIZE {
                assert_eq!(c[i][k], 0.0, "block {i}, k_idx {k}");
            }
        }
    }

    #[test]
    fn hoc_matrix_fills_up_to_k_six_when_block_is_large() {
        let pair = [(0.0, 0.0); 4];
        let blocks = [17u8, 17, 11, 11];
        let c = assemble_hoc_matrix(&pair, 0, 0, 0, 0, &blocks);
        for k in 3..=6 {
            let expected = f64::from(AMBE_HOC_B5[0][k - 3]);
            assert!((c[0][k - 1] - expected).abs() < 1e-6, "block 1, k={k}");
        }
        for k in 7..MAX_BLOCK_SIZE {
            assert_eq!(c[0][k - 1], 0.0, "block 1, k_idx {k}");
        }
    }

    #[test]
    fn inverse_block_dct_dc_only_propagates() {
        let mut c = [[0f64; MAX_BLOCK_SIZE]; 4];
        c[0][0] = 1.0;
        c[1][0] = 2.0;
        c[2][0] = 3.0;
        c[3][0] = 4.0;
        let blocks = AMBE_BLOCK_LENGTHS[0];
        let t = inverse_block_dct(&c, &blocks);
        let mut offset = 0;
        for (i, &j) in blocks.iter().enumerate() {
            for k in 0..j as usize {
                assert!(
                    (t[offset + k] - f64::from(i as u32 + 1)).abs() < 1e-9,
                    "block {i}, pos {k}"
                );
            }
            offset += j as usize;
        }
    }

    #[test]
    fn log_mag_prediction_with_unit_prev_and_zero_gamma() {
        let mut t = [0f64; L_MAX as usize];
        for i in 0..9 {
            t[i] = (i as f64) * 0.1;
        }
        let state = DecoderState::new();
        let lambda = apply_log_prediction(&t, 9, 0.0, &state);
        let l_curr = 9f64;
        let t_sum: f64 = t[..9].iter().sum();
        let gamma_intercept = 0.0 - 0.5 * l_curr.log2() - t_sum / l_curr;
        for l_h in 1..=9 {
            let expected = t[l_h - 1] + gamma_intercept;
            assert!(
                (lambda[l_h] - expected).abs() < 1e-9,
                "l={l_h}: expected {expected}, got {}",
                lambda[l_h]
            );
        }
    }

    #[test]
    fn dequantize_rejects_tone_frame() {
        use crate::ambe_plus2_wire::priority::prioritize;
        let mut b = [0u16; 9];
        b[0] = 120;
        let u = prioritize(&b);
        let mut state = DecoderState::new();
        assert_eq!(dequantize(&u, &mut state), Err(DecodeError::BadPitch));
    }

    #[test]
    fn dequantize_produces_finite_amplitudes_for_zero_b() {
        use crate::ambe_plus2_wire::priority::prioritize;
        let b = [0u16; 9];
        let u = prioritize(&b);
        let mut state = DecoderState::new();
        let p = dequantize(&u, &mut state).expect("decode");
        assert_eq!(p.harmonic_count(), 9);
        assert!((p.omega_0() - 0.049971 * 2.0 * core::f32::consts::PI).abs() < 1e-5);
        for l_h in 1..=9 {
            let a = p.amplitude(l_h);
            assert!(a.is_finite() && a >= 0.0, "l={l_h}: M̃ = {a}");
        }
        assert_eq!(state.previous_l(), 9);
    }

    /// Independent reference packer for tone frames, used to cross-check the parser.
    fn build_tone_frame_u(id: u8, amplitude: u8) -> [u16; 4] {
        let mut u = [0u16; 4];
        let ad_hi = (amplitude >> 1) & 0x3F;
        let ad_lo = amplitude & 1;
        u[0] = (0x3F << 6) | u16::from(ad_hi);
        u[1] = (u16::from(id) << 4) | u16::from(id >> 4);
        let id_lo_nibble = u16::from(id & 0x0F);
        u[2] = (id_lo_nibble << 7) | u16::from(id >> 1);
        u[3] = (u16::from(id & 1) << 13) | (u16::from(id) << 5) | (u16::from(ad_lo) << 4);
        u
    }

    #[test]
    fn classify_voice_frame_when_b0_valid() {
        use crate::ambe_plus2_wire::priority::prioritize;
        let b = [0u16; 9];
        let u = prioritize(&b);
        assert_eq!(classify_ambe_plus2_frame(&u), FrameKind::Voice);
    }

    #[test]
    fn classify_erasure_for_b0_range_120_125() {
        use crate::ambe_plus2_wire::priority::prioritize;
        for b0 in 120..=125u16 {
            let mut b = [0u16; 9];
            b[0] = b0;
            let u = prioritize(&b);
            assert_eq!(classify_ambe_plus2_frame(&u), FrameKind::Erasure, "b̂₀ = {b0}");
        }
    }

    #[test]
    fn classify_tone_for_b0_126_and_127_with_signature() {
        // Tone classification is driven by the signature bits, not by b̂₀ alone.
        let u = build_tone_frame_u(128, 0);
        assert_eq!(classify_ambe_plus2_frame(&u), FrameKind::Tone);
    }

    #[test]
    fn classify_erasure_when_tone_range_but_signature_missing() {
        use crate::ambe_plus2_wire::priority::prioritize;
        let mut b = [0u16; 9];
        b[0] = 126;
        let u = prioritize(&b);
        assert_eq!(classify_ambe_plus2_frame(&u), FrameKind::Erasure);
    }

    #[test]
    fn parse_tone_frame_recovers_id_and_amplitude() {
        for id in [5u8, 64, 128, 144, 162, 255] {
            for amp in [0u8, 1, 63, 64, 126, 127] {
                let u = build_tone_frame_u(id, amp);
                let fields = parse_tone_frame(&u)
                    .unwrap_or_else(|| panic!("parse failed for id={id}, amp={amp}"));
                assert_eq!(fields.id, id, "id={id} amp={amp}");
                assert_eq!(fields.amplitude, amp, "id={id} amp={amp}");
            }
        }
    }

    #[test]
    fn encode_tone_frame_info_matches_reference_and_parses_back() {
        for id in [5u8, 64, 128, 144, 162, 255] {
            for amp in [0u8, 1, 63, 64, 126, 127] {
                let u = encode_tone_frame_info(id, amp);
                assert_eq!(u, build_tone_frame_u(id, amp), "id={id} amp={amp}");
                assert_eq!(
                    parse_tone_frame(&u),
                    Some(ToneFrameFields { id, amplitude: amp }),
                    "id={id} amp={amp}"
                );
            }
        }
    }

    #[test]
    fn parse_tone_frame_rejects_bad_signature() {
        let mut u = build_tone_frame_u(128, 64);
        u[0] &= !(0x3F << 6);
        assert_eq!(parse_tone_frame(&u), None);
    }

    #[test]
    fn parse_tone_frame_rejects_bad_trailer() {
        let mut u = build_tone_frame_u(128, 64);
        u[3] |= 0x01;
        assert_eq!(parse_tone_frame(&u), None);
    }

    #[test]
    fn tone_to_mbe_single_freq_tone() {
        let tone = tone_to_mbe_params(5, 127).unwrap();
        let expected_omega = (2.0 * core::f64::consts::PI / 8000.0) * 156.25;
        assert!((f64::from(tone.omega_0()) - expected_omega).abs() < 1e-6);
        assert_eq!(tone.harmonic_count(), 24);
        assert!(tone.voiced(1));
        assert!((tone.amplitude(1) - 16384.0).abs() < 1.0);
        for l_h in 2..=24 {
            assert!(!tone.voiced(l_h));
            assert_eq!(tone.amplitude(l_h), 0.0);
        }
    }

    #[test]
    fn tone_to_mbe_silence_id_255_has_zero_amplitude() {
        let tone = tone_to_mbe_params(255, 127).unwrap();
        for l_h in 1..=tone.harmonic_count() {
            assert_eq!(tone.amplitude(l_h), 0.0, "l={l_h}");
        }
    }

    #[test]
    fn tone_to_mbe_reserved_id_returns_none() {
        for id in [0u8, 1, 4, 123, 164, 200, 254] {
            assert!(tone_to_mbe_params(id, 64).is_none(), "id={id}");
        }
    }

    #[test]
    fn tone_to_mbe_two_freq_dtmf() {
        let tone = tone_to_mbe_params(128, 100).unwrap();
        let l = tone.harmonic_count();
        let voiced_count: usize = (1..=l).filter(|&i| tone.voiced(i)).count();
        assert!(
            voiced_count <= 2,
            "DTMF tone should activate at most 2 harmonics, got {voiced_count}"
        );
    }

    #[test]
    fn tone_amplitude_scales_logarithmically() {
        let step = 10f64.powf(-0.03555);
        let p127 = tone_to_mbe_params(5, 127).unwrap();
        let p126 = tone_to_mbe_params(5, 126).unwrap();
        let a127 = f64::from(p127.amplitude(1));
        let a126 = f64::from(p126.amplitude(1));
        assert!((a126 / a127 - step).abs() < 1e-4, "a126/a127 = {}", a126 / a127);
    }

    #[test]
    fn decode_ambe_plus2_dispatches_voice() {
        use crate::ambe_plus2_wire::priority::prioritize;
        let b = [0u16; 9];
        let u = prioritize(&b);
        let mut state = DecoderState::new();
        match decode_to_params(&u, &mut state).unwrap() {
            Decoded::Voice(_) => {}
            other => panic!("expected Voice, got {other:?}"),
        }
    }

    #[test]
    fn decode_ambe_plus2_dispatches_tone() {
        let u = build_tone_frame_u(5, 100);
        let mut state = DecoderState::new();
        match decode_to_params(&u, &mut state).unwrap() {
            Decoded::Tone { fields, params: _ } => {
                assert_eq!(fields.id, 5);
                assert_eq!(fields.amplitude, 100);
            }
            other => panic!("expected Tone, got {other:?}"),
        }
        // Tone frames must not disturb the voice predictor state.
        assert_eq!(state.previous_l(), INIT_PREV_L);
        assert_eq!(state.previous_gamma(), 0.0);
    }

    #[test]
    fn decode_ambe_plus2_dispatches_erasure() {
        use crate::ambe_plus2_wire::priority::prioritize;
        let mut b = [0u16; 9];
        b[0] = 120;
        let u = prioritize(&b);
        let mut state = DecoderState::new();
        match decode_to_params(&u, &mut state).unwrap() {
            Decoded::Erasure => {}
            other => panic!("expected Erasure, got {other:?}"),
        }
    }

    #[test]
    fn pitch_encode_roundtrips_every_table_entry() {
        for b0 in 0..=119u8 {
            let w = AMBE_PITCH_TABLE[b0 as usize].omega_0;
            assert_eq!(encode_pitch(w), Some(b0), "b0 = {b0}");
        }
    }

    #[test]
    fn pitch_encode_rejects_out_of_range() {
        assert_eq!(encode_pitch(0.0), None);
        assert_eq!(encode_pitch(-0.1), None);
        assert_eq!(encode_pitch(1.0), None);
    }

    #[test]
    fn vuv_encode_all_voiced_selects_codebook_row_0() {
        let voiced = vec![true; 9];
        assert_eq!(encode_vuv(&voiced, 0.2), 0);
    }

    #[test]
    fn vuv_encode_all_unvoiced_picks_row_with_unvoiced_at_active_slots() {
        let voiced = vec![false; 9];
        let omega_0 = 0.2f32;
        let idx = encode_vuv(&voiced, omega_0);
        let expanded = expand_vuv(idx, omega_0, 9);
        for l_h in 0..9 {
            assert!(!expanded[l_h], "l={l_h} should decode unvoiced for encoded b̂₁={idx}");
        }
    }

    #[test]
    fn gain_encode_roundtrips_every_annex_o_level() {
        for b2 in 0..32u8 {
            let target = f64::from(AMBE_GAIN_LEVELS[b2 as usize]);
            assert_eq!(encode_gain(target), b2, "b2 = {b2}");
        }
    }

    #[test]
    fn gain_encode_clamps_out_of_range() {
        assert_eq!(encode_gain(-1000.0), 0);
        assert_eq!(encode_gain(1000.0), 31);
    }

    #[test]
    fn pair_join_inverts_pair_split() {
        let r_in: [f64; 8] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let pair = pair_split(&r_in);
        let r_out = pair_join(&pair);
        for i in 0..8 {
            assert!((r_out[i] - r_in[i]).abs() < 1e-9, "i={i}");
        }
    }

    #[test]
    fn residuals_to_prba_inverts_prba_to_residuals() {
        let r_in: [f64; 8] = [0.5, 1.5, -0.3, 0.7, -1.2, 0.8, 0.0, 2.1];
        let g = residuals_to_prba(&r_in);
        let r_out = prba_to_residuals(&g);
        for i in 0..8 {
            assert!(
                (r_out[i] - r_in[i]).abs() < 1e-9,
                "i={i}: {} vs {}",
                r_out[i],
                r_in[i]
            );
        }
    }

    #[test]
    fn forward_block_dct_inverts_inverse_block_dct() {
        let l = 30u8;
        let blocks = AMBE_BLOCK_LENGTHS[(l - 9) as usize];
        let mut t_in = [0f64; L_MAX as usize];
        for i in 0..l as usize {
            t_in[i] = (i as f64 * 0.3).sin() * 2.0 - 0.5;
        }
        let c = forward_block_dct(&t_in, &blocks);
        let t_out = inverse_block_dct(&c, &blocks);
        for i in 0..l as usize {
            assert!(
                (t_out[i] - t_in[i]).abs() < 1e-6,
                "i={i}: {} vs {}",
                t_out[i],
                t_in[i]
            );
        }
    }

    #[test]
    fn quantize_prba_produces_valid_indices() {
        let g: [f64; 8] = [0.0, 0.5, -0.3, 0.2, -0.1, 0.4, 0.0, -0.2];
        let (b3, b4) = quantize_prba(&g);
        assert!(b3 < 512);
        assert!(b4 < 128);
    }

    #[test]
    fn forward_log_prediction_inverts_apply_under_quantized_gamma() {
        let state = DecoderState::new();
        let l = 20u8;
        let mut lambda_in = [0f64; L_MAX as usize + 2];
        for i in 1..=l as usize {
            lambda_in[i] = (i as f64 * 0.15).sin() * 0.4 - 0.1;
        }
        let (t, gamma) = forward_log_prediction(&lambda_in, l, &state);
        let lambda_out = apply_log_prediction(&t, l, gamma, &state);
        for i in 1..=l as usize {
            assert!(
                (lambda_out[i] - lambda_in[i]).abs() < 1e-6,
                "i={i}: {} vs {}",
                lambda_out[i],
                lambda_in[i]
            );
        }
    }

    #[test]
    fn quantize_pitch_and_vuv_roundtrip_on_boundary_cases() {
        use crate::ambe_plus2_wire::priority::{deprioritize, prioritize};
        for &(b0_seed, b1_seed) in &[(50u16, 0u16), (0, 0), (119, 0), (60, 0)] {
            let mut b = [0u16; 9];
            b[0] = b0_seed;
            b[1] = b1_seed;
            let u = prioritize(&b);
            let mut state_dec = DecoderState::new();
            let decoded = dequantize(&u, &mut state_dec).unwrap();
            let mut state_enc = DecoderState::new();
            let u_back = quantize(&decoded, &mut state_enc).unwrap();
            let b_back = deprioritize(&u_back);
            assert_eq!(b_back[0], b[0], "b̂₀ pitch (seed b0={b0_seed}, b1={b1_seed})");
            assert_eq!(b_back[1], b[1], "b̂₁ V/UV (seed b0={b0_seed}, b1={b1_seed})");
        }
    }

    #[test]
    fn quantize_gain_stabilizes_under_iterated_roundtrip() {
        use crate::ambe_plus2_wire::priority::{deprioritize, prioritize};
        let mut b = [0u16; 9];
        b[0] = 0;
        b[1] = 0;
        b[2] = 10;
        let u0 = prioritize(&b);
        let mut state_a = DecoderState::new();
        let mut state_b = DecoderState::new();
        let first = dequantize(&u0, &mut state_a).unwrap();
        let u1 = quantize(&first, &mut state_b).unwrap();
        let b1 = deprioritize(&u1);
        assert_eq!(b1[0], b[0]);
        assert_eq!(b1[1], b[1]);
        let mut state_c = DecoderState::new();
        let second = dequantize(&u1, &mut state_c).unwrap();
        let mut state_d = DecoderState::new();
        let u2 = quantize(&second, &mut state_d).unwrap();
        let b2 = deprioritize(&u2);
        assert_eq!(b2[2], b1[2], "gain should be stable on iteration 2");
    }

    #[test]
    fn dequantize_advances_state_between_frames() {
        use crate::ambe_plus2_wire::priority::prioritize;
        let b = [0u16; 9];
        let u = prioritize(&b);
        let mut state = DecoderState::new();
        let g0 = state.previous_gamma();
        let _ = dequantize(&u, &mut state).unwrap();
        let g1 = state.previous_gamma();
        assert_ne!(g0, g1);
        let _ = dequantize(&u, &mut state).unwrap();
        let g2 = state.previous_gamma();
        // γ₀ = −2, γ₁ = −2 + ½·(−2) = −3.
        assert!((g2 - (-3.0)).abs() < 1e-6);
    }
}