use super::{Complex64, DFT_SIZE, HarmonicBasis, PitchRefinement, packed_index};
/// Largest number of V/UV bands. `band_count_for` returns this value once
/// `l_hat` exceeds 36, and every per-band array in this module has
/// `K_HAT_MAX + 1` slots so that bands can be addressed as `1..=K_HAT_MAX`.
pub const K_HAT_MAX: u8 = 12;
/// Cold-start value of the tracked `xi_max`, and the lower clamp applied in
/// `determine_vuv` whenever `xi_max` is decaying.
pub const XI_MAX_FLOOR: f64 = 20000.0;
/// When `e_p_hat_i` is strictly greater than this value, `determine_vuv`
/// sets the threshold of every band `k >= 2` to zero.
const E_P_POOR_MATCH_THRESHOLD: f64 = 0.5;
/// Threshold base factor for a band whose previous decision was 1.
const THETA_BASE_VOICED_HISTORY: f64 = 0.5625;
/// Threshold base factor for a band whose previous decision was not 1.
const THETA_BASE_UNVOICED_HISTORY: f64 = 0.45;
/// Coefficient applied to `(k - 1) * omega_hat` in the threshold's
/// modulation term `1 - coef * (k - 1) * omega_hat`.
const THETA_PITCH_BAND_COEF: f64 = 0.3096;
/// State carried from one `determine_vuv` call to the next.
#[derive(Clone, Debug)]
pub struct VuvState {
// Tracked energy level; replaced on every `determine_vuv` call and never
// left below `XI_MAX_FLOOR` by that update.
xi_max: f64,
// Previous per-band decisions, addressed by band number `1..=K_HAT_MAX`.
// Slot 0 is never assigned a decision and stays 0.
vuv_prev: [u8; (K_HAT_MAX + 1) as usize],
// Band count recorded by the most recent update of `vuv_prev`.
k_prev: u8,
}
impl VuvState {
    /// State before any frame has been processed: `xi_max` sits at
    /// [`XI_MAX_FLOOR`], every band decision is 0 and the band count is 0.
    pub const fn cold_start() -> Self {
        Self {
            xi_max: XI_MAX_FLOOR,
            vuv_prev: [0; (K_HAT_MAX + 1) as usize],
            k_prev: 0,
        }
    }

    /// Current tracked `xi_max` value.
    pub fn xi_max(&self) -> f64 {
        self.xi_max
    }

    /// Previous decision stored for band `k`.
    ///
    /// Bands are numbered `1..=K_HAT_MAX`; any other `k` (including 0)
    /// yields 0 instead of panicking.
    pub fn vuv_prev(&self, k: u8) -> u8 {
        if (1..=K_HAT_MAX).contains(&k) {
            self.vuv_prev[usize::from(k)]
        } else {
            0
        }
    }

    /// Replaces the stored per-band decisions.
    ///
    /// `decisions` is a packed, zero-based slice: `decisions[0]` belongs to
    /// band 1, `decisions[1]` to band 2, and so on. At most
    /// `min(k_hat, decisions.len(), K_HAT_MAX)` entries are copied; every
    /// other band is reset to 0. `k_hat` is recorded as given.
    ///
    /// This never panics, whatever the slice length.
    pub fn override_vuv_prev(&mut self, decisions: &[u8], k_hat: u8) {
        self.vuv_prev.fill(0);
        let n = usize::from(k_hat)
            .min(decisions.len())
            .min(usize::from(K_HAT_MAX));
        // Band k reads decisions[k - 1]. Reading decisions[k] would skip the
        // first entry and run one past the end of a slice of exactly n items.
        self.vuv_prev[1..=n].copy_from_slice(&decisions[..n]);
        self.k_prev = k_hat;
    }

    /// Replaces the tracked `xi_max` value verbatim (no clamping is applied).
    pub fn override_xi_max(&mut self, xi_max: f64) {
        self.xi_max = xi_max;
    }
}
impl Default for VuvState {
fn default() -> Self {
Self::cold_start()
}
}
/// Everything `determine_vuv` computed for one call.
///
/// The three arrays are addressed by band number `1..=k_hat`; slot 0 and
/// slots above `k_hat` keep their zero initial value.
#[derive(Clone, Debug)]
pub struct VuvResult {
/// Number of bands evaluated, i.e. `band_count_for(l_hat)`.
pub k_hat: u8,
/// Per-band decision: 1 when the band's `d_k` is below its `theta_k`, else 0.
pub vuv: [u8; (K_HAT_MAX + 1) as usize],
/// Per-band ratio of squared model error to observed energy, or 1.0 when
/// the observed energy in the band is not positive.
pub d_k: [f64; (K_HAT_MAX + 1) as usize],
/// Per-band threshold that `d_k` was compared against.
pub theta_k: [f64; (K_HAT_MAX + 1) as usize],
/// Energy-dependent scale factor applied to every non-zero threshold.
pub m_xi: f64,
/// Sum of the low-band and high-band energies measured on this call.
pub xi_0: f64,
/// Value of `xi_max` written back into the state by this call.
pub xi_max_after: f64,
}
/// Number of V/UV bands for `l_hat` harmonics.
///
/// Harmonics are grouped three to a band, with a partial last band counting
/// as a whole one, up to `l_hat == 36`. Beyond that the count is pinned at
/// [`K_HAT_MAX`].
#[inline]
pub fn band_count_for(l_hat: u8) -> u8 {
    match l_hat {
        0..=36 => (l_hat + 2) / 3,
        _ => K_HAT_MAX,
    }
}
pub fn determine_vuv(
sw: &[Complex64; DFT_SIZE],
basis: &HarmonicBasis,
refinement: &PitchRefinement,
e_p_hat_i: f64,
state: &mut VuvState,
) -> VuvResult {
let l_hat = refinement.l_hat;
let omega_hat = refinement.omega_hat;
let k_hat = band_count_for(l_hat);
let wr_dc = basis.w_r(0);
let wr_dc_sq = wr_dc * wr_dc;
let inv_wr_dc_sq = if wr_dc_sq > 0.0 { 1.0 / wr_dc_sq } else { 0.0 };
let mut xi_lf = 0.0;
for m in 0..=63 {
xi_lf += sw[m].norm_sqr();
}
xi_lf *= inv_wr_dc_sq;
let mut xi_hf = 0.0;
for m in 64..=128 {
xi_hf += sw[m].norm_sqr();
}
xi_hf *= inv_wr_dc_sq;
let xi_0 = xi_lf + xi_hf;
let xi_max_new = if xi_0 > state.xi_max {
0.5 * state.xi_max + 0.5 * xi_0
} else {
let decay = 0.99 * state.xi_max + 0.01 * xi_0;
decay.max(XI_MAX_FLOOR)
};
let ratio_den = 0.01 * xi_max_new + xi_0;
let ratio = if ratio_den > 0.0 {
(0.0025 * xi_max_new + xi_0) / ratio_den
} else {
0.0
};
let m_xi = if xi_lf >= 5.0 * xi_hf {
ratio
} else if xi_hf > 0.0 {
ratio * (xi_lf / (5.0 * xi_hf)).sqrt()
} else {
ratio
};
let mut vuv = [0u8; (K_HAT_MAX + 1) as usize];
let mut d_k = [0f64; (K_HAT_MAX + 1) as usize];
let mut theta_k = [0f64; (K_HAT_MAX + 1) as usize];
for k in 1..=k_hat {
let l_lo = 3 * u32::from(k) - 2;
let l_hi = if k < k_hat {
3 * u32::from(k)
} else {
u32::from(l_hat)
};
let mut num = 0.0;
let mut den = 0.0;
for l in l_lo..=l_hi {
let a_l = basis.harmonic_amplitude(sw, l, omega_hat);
let (m_lo, m_hi) = HarmonicBasis::bin_endpoints(l, omega_hat);
for m in m_lo..m_hi {
let synth = basis.synthetic_bin(m, l, omega_hat, a_l);
let observed = sw[packed_index(m)];
let dre = observed.re - synth.re;
let dim = observed.im - synth.im;
num += dre * dre + dim * dim;
den += observed.norm_sqr();
}
}
let d = if den > 0.0 { num / den } else { 1.0 };
d_k[k as usize] = d;
let theta = if e_p_hat_i > E_P_POOR_MATCH_THRESHOLD && k >= 2 {
0.0
} else {
let base = if state.vuv_prev[k as usize] == 1 {
THETA_BASE_VOICED_HISTORY
} else {
THETA_BASE_UNVOICED_HISTORY
};
let modulation = 1.0 - THETA_PITCH_BAND_COEF * f64::from(k - 1) * omega_hat;
(base * modulation * m_xi).max(0.0)
};
theta_k[k as usize] = theta;
vuv[k as usize] = if d < theta { 1 } else { 0 };
}
state.xi_max = xi_max_new;
state.vuv_prev.fill(0);
for k in 1..=k_hat {
state.vuv_prev[k as usize] = vuv[k as usize];
}
state.k_prev = k_hat;
VuvResult {
k_hat,
vuv,
d_k,
theta_k,
m_xi,
xi_0,
xi_max_after: xi_max_new,
}
}
// Unit tests for the V/UV decision code. They reach into `VuvState`'s
// private fields directly, which is allowed because this is a child module.
#[cfg(test)]
mod tests {
use super::*;
use super::super::{L_HAT_MAX, L_HAT_MIN, W_R_HALF, refine_pitch, signal_spectrum};
// Builds a `2 * W_R_HALF + 1` sample signal centred on sample index
// `W_R_HALF`: the sum of cosines at harmonics `1..=max_h` of `2*pi/period`,
// with harmonic `h` weighted by `1/h`.
fn broadband_periodic(period: f64, max_h: u32) -> [f64; (2 * W_R_HALF + 1) as usize] {
let mut signal = [0.0f64; (2 * W_R_HALF + 1) as usize];
let omega = 2.0 * core::f64::consts::PI / period;
for (idx, slot) in signal.iter_mut().enumerate() {
// Shift the array index so that n = 0 falls on the middle sample.
let n = idx as i32 - W_R_HALF;
let nf = f64::from(n);
let mut s = 0.0;
for h in 1..=max_h {
s += (f64::from(h) * omega * nf).cos() / f64::from(h);
}
*slot = s;
}
signal
}
// Spot-checks `band_count_for` around the band boundaries and at the cap,
// then checks it is non-decreasing and never above `K_HAT_MAX` over
// `L_HAT_MIN..=L_HAT_MAX`.
#[test]
fn band_count_matches_eq34_table() {
assert_eq!(band_count_for(9), 3);
assert_eq!(band_count_for(10), 4);
assert_eq!(band_count_for(11), 4);
assert_eq!(band_count_for(12), 4);
assert_eq!(band_count_for(36), 12);
assert_eq!(band_count_for(37), 12);
assert_eq!(band_count_for(56), 12);
let mut prev = 0u8;
for l in L_HAT_MIN..=L_HAT_MAX {
let k = band_count_for(l);
assert!(k >= prev, "K̂({l}) = {k}, prev = {prev}");
assert!(k <= K_HAT_MAX);
prev = k;
}
}
// A cold-start state reports `XI_MAX_FLOOR` and 0 for every band index,
// including the out-of-range indices 0 and `K_HAT_MAX + 1`.
#[test]
fn vuv_state_cold_start_matches_spec() {
let s = VuvState::cold_start();
assert_eq!(s.xi_max(), XI_MAX_FLOOR);
for k in 0..=K_HAT_MAX {
assert_eq!(s.vuv_prev(k), 0);
}
assert_eq!(s.vuv_prev(K_HAT_MAX + 1), 0);
}
// Recomputes the frame energy independently and checks the 50/50 blend.
// NOTE: the assertion only runs when the frame energy exceeds the previous
// `xi_max`; otherwise this test passes without checking anything.
#[test]
fn vuv_xi_max_attack_blends_50_50() {
let basis = HarmonicBasis::new();
let signal = broadband_periodic(50.0, 20);
let sw = signal_spectrum(&signal);
let refinement = refine_pitch(&sw, &basis, 50.0);
let mut state = VuvState::cold_start();
let prev_xi_max = state.xi_max;
let res = determine_vuv(&sw, &basis, &refinement, 0.1, &mut state);
// Same energy computation as `determine_vuv`, repeated here as the oracle.
let wr_dc = basis.w_r(0);
let inv = 1.0 / (wr_dc * wr_dc);
let mut xi_lf = 0.0;
for m in 0..=63 {
xi_lf += sw[m].norm_sqr();
}
xi_lf *= inv;
let mut xi_hf = 0.0;
for m in 64..=128 {
xi_hf += sw[m].norm_sqr();
}
xi_hf *= inv;
let xi_0 = xi_lf + xi_hf;
if xi_0 > prev_xi_max {
let expected = 0.5 * prev_xi_max + 0.5 * xi_0;
assert!(
(res.xi_max_after - expected).abs() < 1e-6 * expected.abs().max(1.0),
"xi_max_after = {}, expected = {}",
res.xi_max_after,
expected
);
}
}
// With an all-zero input the decayed `xi_max` would drop below the floor,
// so the result must be exactly `XI_MAX_FLOOR`.
#[test]
fn vuv_xi_max_stays_at_floor_on_silence() {
let basis = HarmonicBasis::new();
let signal = [0.0f64; (2 * W_R_HALF + 1) as usize];
let sw = signal_spectrum(&signal);
// Hand-built refinement; `refine_pitch` is not run on the silent signal.
let refinement = PitchRefinement {
omega_hat: 2.0 * core::f64::consts::PI / 50.0,
l_hat: 23,
b0: 41,
e_r: 0.0,
};
let mut state = VuvState::cold_start();
let res = determine_vuv(&sw, &basis, &refinement, 0.1, &mut state);
assert_eq!(res.xi_max_after, XI_MAX_FLOOR);
}
// The reported band count equals `band_count_for(l_hat)` and respects the cap.
#[test]
fn vuv_result_k_hat_matches_band_count() {
let basis = HarmonicBasis::new();
let signal = broadband_periodic(50.0, 20);
let sw = signal_spectrum(&signal);
let refinement = refine_pitch(&sw, &basis, 50.0);
let mut state = VuvState::cold_start();
let res = determine_vuv(&sw, &basis, &refinement, 0.1, &mut state);
assert_eq!(res.k_hat, band_count_for(refinement.l_hat));
assert!(res.k_hat <= K_HAT_MAX);
}
// An all-zero input has no band energy, so every `d_k` takes the 1.0
// fallback and every band is declared unvoiced.
#[test]
fn vuv_silent_input_yields_all_unvoiced() {
let basis = HarmonicBasis::new();
let signal = [0.0f64; (2 * W_R_HALF + 1) as usize];
let sw = signal_spectrum(&signal);
let refinement = PitchRefinement {
omega_hat: 2.0 * core::f64::consts::PI / 50.0,
l_hat: 23,
b0: 41,
e_r: 0.0,
};
let mut state = VuvState::cold_start();
let res = determine_vuv(&sw, &basis, &refinement, 0.1, &mut state);
for k in 1..=res.k_hat {
assert_eq!(res.vuv[k as usize], 0, "band {k} voiced on silence");
assert_eq!(res.d_k[k as usize], 1.0);
}
}
// Passing 0.75 (above the 0.5 poor-match threshold) zeroes the threshold
// of bands 2 and up, so none of them can be voiced. Band 1 is not checked.
#[test]
fn vuv_poor_pitch_match_forces_bands_2_plus_unvoiced() {
let basis = HarmonicBasis::new();
let signal = broadband_periodic(50.0, 20);
let sw = signal_spectrum(&signal);
let refinement = refine_pitch(&sw, &basis, 50.0);
let mut state = VuvState::cold_start();
let res = determine_vuv(&sw, &basis, &refinement, 0.75, &mut state);
for k in 2..=res.k_hat {
assert_eq!(
res.vuv[k as usize], 0,
"band {k} voiced under poor pitch match"
);
}
}
// Runs the same frame from an all-unvoiced history and from an all-voiced
// history. The voiced history uses the larger threshold base, so it must
// yield at least as many voiced bands.
#[test]
fn vuv_hysteresis_threshold_depends_on_previous_decision() {
let basis = HarmonicBasis::new();
let signal = broadband_periodic(50.0, 20);
let sw = signal_spectrum(&signal);
let refinement = refine_pitch(&sw, &basis, 50.0);
let mut cold = VuvState::cold_start();
let res_cold = determine_vuv(&sw, &basis, &refinement, 0.1, &mut cold);
// Build the all-voiced history by writing the private fields directly.
let mut warm = VuvState::cold_start();
for k in 1..=K_HAT_MAX {
warm.vuv_prev[k as usize] = 1;
}
warm.k_prev = K_HAT_MAX;
let res_warm = determine_vuv(&sw, &basis, &refinement, 0.1, &mut warm);
let count_cold: u32 = (1..=res_cold.k_hat)
.map(|k| u32::from(res_cold.vuv[k as usize]))
.sum();
let count_warm: u32 = (1..=res_warm.k_hat)
.map(|k| u32::from(res_warm.vuv[k as usize]))
.sum();
assert!(
count_warm >= count_cold,
"warm voiced count {count_warm} < cold voiced count {count_cold}"
);
}
// After a call the state holds the returned `xi_max`, the returned
// decisions for bands `1..=k_hat`, and zeros for every band above `k_hat`.
#[test]
fn vuv_state_is_committed_after_call() {
let basis = HarmonicBasis::new();
let signal = broadband_periodic(50.0, 20);
let sw = signal_spectrum(&signal);
let refinement = refine_pitch(&sw, &basis, 50.0);
let mut state = VuvState::cold_start();
let res = determine_vuv(&sw, &basis, &refinement, 0.1, &mut state);
assert_eq!(state.xi_max, res.xi_max_after);
for k in 1..=res.k_hat {
assert_eq!(state.vuv_prev[k as usize], res.vuv[k as usize]);
}
for k in (res.k_hat + 1)..=K_HAT_MAX {
assert_eq!(state.vuv_prev[k as usize], 0);
}
}
// Every evaluated band's `d_k` is finite and non-negative. Despite the
// test name, no upper bound is asserted.
#[test]
fn vuv_d_k_values_are_bounded_and_finite() {
let basis = HarmonicBasis::new();
let signal = broadband_periodic(50.0, 20);
let sw = signal_spectrum(&signal);
let refinement = refine_pitch(&sw, &basis, 50.0);
let mut state = VuvState::cold_start();
let res = determine_vuv(&sw, &basis, &refinement, 0.1, &mut state);
for k in 1..=res.k_hat {
let d = res.d_k[k as usize];
assert!(d.is_finite() && d >= 0.0, "D_{k} = {d}");
}
}
}