use super::{analysis_window, lpf_tap};
use crate::mbe_params::SAMPLES_PER_FRAME;
/// Half-width of the pitch analysis window `w_I`: it spans `j ∈ [-W_I_HALF, W_I_HALF]`.
pub const W_I_HALF: i32 = 150;
/// Half-length of the pitch low-pass filter: taps `lpf_tap(j)` for `j ∈ [-H_LPF_HALF, H_LPF_HALF]`.
pub const H_LPF_HALF: i32 = 10;
/// Length of the lookahead buffer consumed by `compute_s_lpf_shared`: three frames of samples.
pub const LOOKAHEAD_LEN: usize = 3 * SAMPLES_PER_FRAME as usize;
/// Number of shared-LPF entries that correspond to positions before lookahead sample 0
/// (their filter input is zero padding). Chosen so that slot `s` of the shared buffer is
/// centred on lookahead sample `s·SAMPLES_PER_FRAME + SAMPLES_PER_FRAME / 2`.
const SHARED_LPF_TAIL: usize = (W_I_HALF as usize) - (SAMPLES_PER_FRAME as usize / 2);
/// Length of the shared LPF buffer: room for `2·W_I_HALF + 1`-sample windows starting at
/// offsets `0`, `SAMPLES_PER_FRAME` and `2·SAMPLES_PER_FRAME` (slots 0..=2 of `slot_s_lpf`).
pub const SHARED_LPF_LEN: usize =
2 * SAMPLES_PER_FRAME as usize + 2 * W_I_HALF as usize + 1;
/// Half-width of the raw input window needed by `compute_s_lpf`: the analysis window plus
/// the filter half-length, so every filtered sample is computed from real input only.
pub const PITCH_INPUT_HALF: i32 = W_I_HALF + H_LPF_HALF;
/// Length of the raw input window accepted by `compute_s_lpf` and `PitchSearch::new`.
pub const PITCH_INPUT_LEN: usize = (2 * PITCH_INPUT_HALF + 1) as usize;
/// Smallest candidate pitch period on the search grid, in samples (used as a lag into the
/// autocorrelation).
pub const PITCH_GRID_MIN: f64 = 21.0;
/// Largest candidate pitch period on the search grid, in samples.
pub const PITCH_GRID_MAX: f64 = 122.0;
/// Spacing between candidate pitch periods (half-sample resolution).
pub const PITCH_GRID_STEP: f64 = 0.5;
/// Number of grid points, `(PITCH_GRID_MAX - PITCH_GRID_MIN) / PITCH_GRID_STEP + 1`.
/// Hand-computed: keep it in sync with the three constants above.
pub const PITCH_GRID_LEN: usize = 203;
/// Lower edge of the pitch-continuity window, as a fraction of the reference period
/// (used by `look_back` and `look_ahead`).
const PITCH_CONTINUITY_LO: f64 = 0.8;
/// Upper edge of the pitch-continuity window, as a multiple of the reference period.
const PITCH_CONTINUITY_HI: f64 = 1.2;
/// Windowed-energy threshold below which `PitchSearch::e_of_p` reports "no fit" (1.0).
const ENERGY_FLOOR: f64 = 1e-12;
/// `prev_pitch` used by `LookBackContext::cold_start`, before any pitch history exists.
pub const PITCH_COLD_START: f64 = 100.0;
/// Low-pass filters the raw pitch-analysis input with the `2·H_LPF_HALF + 1`-tap FIR
/// `lpf_tap(-H_LPF_HALF..=H_LPF_HALF)`, producing `2·W_I_HALF + 1` filtered samples.
///
/// Output index `i` represents sample `n = i - W_I_HALF` relative to the centre of
/// `s_input` and equals `Σ_j lpf_tap(j) · x(n - j)`, where `x` is `s_input` indexed from its
/// centre. Because the input is `H_LPF_HALF` samples wider on each side, no padding is
/// needed.
pub fn compute_s_lpf(s_input: &[f64; PITCH_INPUT_LEN]) -> [f64; (2 * W_I_HALF + 1) as usize] {
    const TAPS: usize = (2 * H_LPF_HALF + 1) as usize;
    // taps[k] holds lpf_tap(k - H_LPF_HALF).
    let taps: [f64; TAPS] =
        core::array::from_fn(|k| f64::from(lpf_tap(k as i32 - H_LPF_HALF)));
    let mut filtered = [0.0f64; (2 * W_I_HALF + 1) as usize];
    // Each TAPS-long input window yields one output; walking the window backwards pairs
    // its newest sample with taps[0], i.e. a true convolution.
    for (dst, window) in filtered.iter_mut().zip(s_input.windows(TAPS)) {
        *dst = window
            .iter()
            .rev()
            .zip(taps.iter())
            .fold(0.0, |acc, (&x, &h)| acc + x * h);
    }
    filtered
}
/// Low-pass filters a three-frame lookahead buffer once, so that every analysis slot can
/// take its filtered window from the shared result via `slot_s_lpf`.
///
/// Entry `i` of the result corresponds to lookahead sample `i - SHARED_LPF_TAIL`; filter
/// inputs that fall outside the lookahead are treated as zero.
pub fn compute_s_lpf_shared(lookahead: &[f64; LOOKAHEAD_LEN]) -> [f64; SHARED_LPF_LEN] {
    const TAPS: usize = (2 * H_LPF_HALF + 1) as usize;
    // Zero padding in front: the SHARED_LPF_TAIL leading outputs plus the filter's half-length.
    const PAD_LEFT: usize = SHARED_LPF_TAIL + H_LPF_HALF as usize;
    const PAD_LEN: usize = SHARED_LPF_LEN + 2 * H_LPF_HALF as usize;
    // taps[k] holds lpf_tap(k - H_LPF_HALF).
    let taps: [f64; TAPS] =
        core::array::from_fn(|k| f64::from(lpf_tap(k as i32 - H_LPF_HALF)));
    let mut padded = [0.0f64; PAD_LEN];
    padded[PAD_LEFT..PAD_LEFT + LOOKAHEAD_LEN].copy_from_slice(lookahead);
    let mut filtered = [0.0f64; SHARED_LPF_LEN];
    // PAD_LEN - TAPS + 1 == SHARED_LPF_LEN windows, one per output.
    for (dst, window) in filtered.iter_mut().zip(padded.windows(TAPS)) {
        *dst = window
            .iter()
            .rev()
            .zip(taps.iter())
            .fold(0.0, |acc, (&x, &h)| acc + x * h);
    }
    filtered
}
/// Returns the `2·W_I_HALF + 1`-sample low-pass-filtered window for `slot` from a buffer
/// produced by `compute_s_lpf_shared`; consecutive slots start `SAMPLES_PER_FRAME`
/// entries apart.
///
/// # Panics
///
/// Panics if the window would run past the end of `shared`, i.e. for `slot > 2` given
/// `SHARED_LPF_LEN = 2·SAMPLES_PER_FRAME + 2·W_I_HALF + 1`.
#[inline]
pub fn slot_s_lpf(shared: &[f64; SHARED_LPF_LEN], slot: usize) -> &[f64] {
    const WINDOW: usize = (2 * W_I_HALF + 1) as usize;
    let offset = slot * SAMPLES_PER_FRAME as usize;
    let end = offset + WINDOW;
    &shared[offset..end]
}
/// Per-frame pitch-search state: the windowed energy terms and autocorrelation of the
/// low-pass-filtered input, plus the pitch error `E(P)` precomputed on every grid period.
#[derive(Clone, Debug)]
pub struct PitchSearch {
/// `Σ_j s_LPF(j)² · w_I(j)²` over the analysis window.
energy: f64,
/// `Σ_j w_I(j)⁴` over the analysis window.
w_i_fourth_moment: f64,
/// Autocorrelation of `s_LPF · w_I²` at integer lags `0..=R_MAX_LAG`.
r: [f64; R_MAX_LAG + 1],
/// `e_of_p` evaluated at grid period `PITCH_GRID_MIN + i · PITCH_GRID_STEP` for each `i`.
e_grid: [f64; PITCH_GRID_LEN],
/// Sparse range-minimum table over `e_grid`, built on demand by `enable_argmin_fast`;
/// when present, `argmin_in_range` answers with two table lookups instead of a scan.
sparse: Option<Box<SparseRmq>>,
}
const SPARSE_LEVELS: usize = 8;
type SparseRow = [(f64, u16); PITCH_GRID_LEN];
type SparseRmq = [SparseRow; SPARSE_LEVELS];
/// Returns whichever of `a` and `b` is smaller by value, breaking ties — and incomparable
/// (NaN) values — in favour of the smaller grid index. On fully equal keys `a` is returned.
#[inline]
fn lex_min(a: (f64, u16), b: (f64, u16)) -> (f64, u16) {
    use core::cmp::Ordering;
    let order = a
        .0
        .partial_cmp(&b.0)
        .unwrap_or(Ordering::Equal)
        .then(a.1.cmp(&b.1));
    if order.is_gt() {
        b
    } else {
        a
    }
}
const R_MAX_LAG: usize = W_I_HALF as usize;
const AUTOCORR_FFT_LEN: usize = 512;
/// Autocorrelation `r(t) = Σ_n sw2[n] · sw2[n + t]` for lags `t = 0..=R_MAX_LAG`, computed
/// as the inverse FFT of the power spectrum of the zero-padded input.
///
/// The FFT yields a circular correlation; it equals the linear one for every returned lag
/// only while `AUTOCORR_FFT_LEN ≥ sw2.len() + R_MAX_LAG`, which is checked at compile time.
fn autocorr_via_fft(sw2: &[f64; (2 * W_I_HALF + 1) as usize]) -> [f64; R_MAX_LAG + 1] {
    use num_complex::Complex;
    use rustfft::{Fft, FftPlanner};
    use std::sync::{Arc, OnceLock};

    // Wrap-around only contaminates lags t > AUTOCORR_FFT_LEN - sw2.len().
    const _: () = assert!(
        AUTOCORR_FFT_LEN >= (2 * W_I_HALF as usize + 1) + R_MAX_LAG,
        "AUTOCORR_FFT_LEN too short: circular autocorrelation would alias"
    );

    type Plan = Arc<dyn Fft<f64>>;
    // Forward and inverse plans are built once, from a single planner, and reused.
    static PLANS: OnceLock<(Plan, Plan)> = OnceLock::new();
    let (fft_fwd, fft_inv) = PLANS.get_or_init(|| {
        let mut planner = FftPlanner::<f64>::new();
        (
            planner.plan_fft_forward(AUTOCORR_FFT_LEN),
            planner.plan_fft_inverse(AUTOCORR_FFT_LEN),
        )
    });

    let mut buf = [Complex::<f64>::new(0.0, 0.0); AUTOCORR_FFT_LEN];
    // The const assertion above guarantees sw2 fits; the tail stays zero (padding).
    for (dst, &x) in buf.iter_mut().zip(sw2.iter()) {
        dst.re = x;
    }
    fft_fwd.process(&mut buf);
    // |X(k)|² is the transform of the autocorrelation.
    for c in buf.iter_mut() {
        *c = Complex::new(c.norm_sqr(), 0.0);
    }
    fft_inv.process(&mut buf);
    // rustfft's inverse transform is unnormalised.
    let scale = 1.0 / AUTOCORR_FFT_LEN as f64;
    core::array::from_fn(|t| buf[t].re * scale)
}
impl PitchSearch {
pub fn new(s_input: &[f64; PITCH_INPUT_LEN]) -> Self {
let s_lpf = compute_s_lpf(s_input);
Self::from_lpf(&s_lpf)
}
pub fn from_lpf_slice(s_lpf: &[f64]) -> Self {
let arr: &[f64; (2 * W_I_HALF + 1) as usize] = s_lpf
.try_into()
.expect("from_lpf_slice expects exactly 2·W_I_HALF + 1 entries");
Self::from_lpf(arr)
}
pub fn from_lpf(s_lpf: &[f64; (2 * W_I_HALF + 1) as usize]) -> Self {
let mut sw2 = [0.0f64; (2 * W_I_HALF + 1) as usize];
let mut energy = 0.0;
let mut w4 = 0.0;
for j in -W_I_HALF..=W_I_HALF {
let idx = (j + W_I_HALF) as usize;
let s = s_lpf[idx];
let w = f64::from(analysis_window(j));
let w2 = w * w;
sw2[idx] = s * w2;
energy += s * s * w2;
w4 += w2 * w2;
}
let r = autocorr_via_fft(&sw2);
let mut e_grid = [0.0f64; PITCH_GRID_LEN];
let mut tmp = Self {
energy,
w_i_fourth_moment: w4,
r,
e_grid: [0.0; PITCH_GRID_LEN],
sparse: None,
};
for (i, slot) in e_grid.iter_mut().enumerate() {
let p = PITCH_GRID_MIN + (i as f64) * PITCH_GRID_STEP;
*slot = tmp.e_of_p(p);
}
tmp.e_grid = e_grid;
tmp
}
pub fn enable_argmin_fast(&mut self) {
if self.sparse.is_some() {
return;
}
let mut sparse: Box<SparseRmq> = Box::new(core::array::from_fn(|_| {
[(f64::INFINITY, 0u16); PITCH_GRID_LEN]
}));
for i in 0..PITCH_GRID_LEN {
sparse[0][i] = (self.e_grid[i], i as u16);
}
for k in 1..SPARSE_LEVELS {
let half = 1usize << (k - 1);
for i in 0..PITCH_GRID_LEN {
let left = sparse[k - 1][i];
let right = if i + half < PITCH_GRID_LEN {
sparse[k - 1][i + half]
} else {
left
};
sparse[k][i] = lex_min(left, right);
}
}
self.sparse = Some(sparse);
}
fn r_at(&self, t: f64) -> f64 {
let t_floor = t.floor();
let frac = t - t_floor;
let ti = t_floor as usize;
if ti + 1 >= self.r.len() {
return *self.r.last().unwrap_or(&0.0);
}
(1.0 - frac) * self.r[ti] + frac * self.r[ti + 1]
}
pub fn e_of_p(&self, p: f64) -> f64 {
if self.energy < ENERGY_FLOOR {
return 1.0;
}
let n_max = (f64::from(W_I_HALF) / p).floor() as i32;
let mut inner = 0.0;
for n in -n_max..=n_max {
inner += self.r_at((f64::from(n) * p).abs());
}
let num = self.energy - p * inner;
let denom = self.energy * (1.0 - p * self.w_i_fourth_moment);
if denom <= 0.0 {
return 1.0;
}
num / denom
}
#[inline]
pub fn e_of_p_grid(&self, i: usize) -> f64 {
self.e_grid.get(i).copied().unwrap_or(1.0)
}
pub fn argmin_in_range(&self, p_lo: f64, p_hi: f64) -> (f64, f64) {
let p_lo = p_lo.max(PITCH_GRID_MIN);
let p_hi = p_hi.min(PITCH_GRID_MAX);
let i_lo = ((p_lo - PITCH_GRID_MIN) / PITCH_GRID_STEP).ceil() as i32;
let i_hi = ((p_hi - PITCH_GRID_MIN) / PITCH_GRID_STEP).floor() as i32;
if i_lo > i_hi {
let p = PITCH_GRID_MIN + f64::from(i_lo.max(0)) * PITCH_GRID_STEP;
return (p, 1.0);
}
let i_lo_u = i_lo as usize;
let i_hi_u = (i_hi as usize).min(PITCH_GRID_LEN - 1);
if let Some(sparse) = self.sparse.as_deref() {
let len = i_hi_u - i_lo_u + 1;
let k = (len as u32).ilog2() as usize;
let window = 1usize << k;
let left = sparse[k][i_lo_u];
let right = sparse[k][i_hi_u + 1 - window];
let (best_e, best_i_u16) = lex_min(left, right);
let best_p = PITCH_GRID_MIN + f64::from(best_i_u16) * PITCH_GRID_STEP;
return (best_p, best_e);
}
let slice = &self.e_grid[i_lo_u..=i_hi_u];
let (rel_idx, &best_e) = slice
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal))
.expect("non-empty slice");
let best_i = i_lo_u + rel_idx;
let best_p = PITCH_GRID_MIN + (best_i as f64) * PITCH_GRID_STEP;
(best_p, best_e)
}
}
/// Rounds `p` to the nearest multiple of `PITCH_GRID_STEP`, with ties going away from
/// zero as in `f64::round`. Returns the rounded value if it lies within
/// `[PITCH_GRID_MIN, PITCH_GRID_MAX]`, otherwise `None`.
#[inline]
pub fn snap_to_pitch_grid(p: f64) -> Option<f64> {
    let steps = (p / PITCH_GRID_STEP).round();
    let candidate = steps * PITCH_GRID_STEP;
    let on_grid = PITCH_GRID_MIN <= candidate && candidate <= PITCH_GRID_MAX;
    on_grid.then_some(candidate)
}
/// Pitch history carried from earlier frames into `look_back`.
#[derive(Clone, Copy, Debug)]
pub struct LookBackContext {
/// Reference period for the look-back search: `look_back` searches
/// `[PITCH_CONTINUITY_LO, PITCH_CONTINUITY_HI] × prev_pitch`.
pub prev_pitch: f64,
/// Error term added to `E(P̂_B)` in `look_back`'s cumulative error
/// (NOTE(review): by naming, the error of the frame one back — confirm with the caller).
pub prev_err_1: f64,
/// Error term added to `E(P̂_B)` in `look_back`'s cumulative error
/// (NOTE(review): by naming, the error of the frame two back — confirm with the caller).
pub prev_err_2: f64,
}
impl LookBackContext {
    /// Context for the very first frame. No pitch history exists yet, so the look-back
    /// search is centred on `PITCH_COLD_START` and both past-error terms add nothing.
    pub const fn cold_start() -> Self {
        let prev_pitch = PITCH_COLD_START;
        Self {
            prev_err_2: 0.0,
            prev_err_1: 0.0,
            prev_pitch,
        }
    }
}
pub fn look_back(current: &PitchSearch, ctx: LookBackContext) -> (f64, f64) {
let p_lo = PITCH_CONTINUITY_LO * ctx.prev_pitch;
let p_hi = PITCH_CONTINUITY_HI * ctx.prev_pitch;
let (p_b, e_b) = current.argmin_in_range(p_lo, p_hi);
let ce_b = e_b + ctx.prev_err_1 + ctx.prev_err_2;
(p_b, ce_b)
}
/// Look-ahead pitch tracking over the current frame and the next two frames.
///
/// For every grid period `P0` of the current frame it forms the cumulative error
/// `CE_F(P0) = E_0(P0) + E_1(P1) + E_2(P2)`. Here `P1` minimises `E_1` within
/// `[PITCH_CONTINUITY_LO, PITCH_CONTINUITY_HI] × P0`, and `P2` minimises `E_2` within the
/// same window around `P1`. The minimiser `P̂0` (smallest period on ties) is then compared
/// with its sub-multiples `P̂0 / n` for `n ≥ 2`, each snapped to the grid and tried from the
/// smallest period upwards. The first sub-multiple that passes one of the acceptance tests
/// replaces `P̂0`.
///
/// Returns `(P̂_F, CE_F(P̂_F))`.
pub fn look_ahead(
    current: &PitchSearch,
    next1: &PitchSearch,
    next2: &PitchSearch,
) -> (f64, f64) {
    // CE_F for a current-frame period `p0` whose own error `e0` is already known.
    let trajectory_error = |p0: f64, e0: f64| -> f64 {
        let (p1, e1) =
            next1.argmin_in_range(PITCH_CONTINUITY_LO * p0, PITCH_CONTINUITY_HI * p0);
        let (_p2, e2) =
            next2.argmin_in_range(PITCH_CONTINUITY_LO * p1, PITCH_CONTINUITY_HI * p1);
        e0 + e1 + e2
    };

    // Best grid period; strict `<` keeps the smallest period on ties.
    let mut best_p0 = PITCH_GRID_MIN;
    let mut best_ce = f64::INFINITY;
    for i in 0..PITCH_GRID_LEN {
        let p0 = PITCH_GRID_MIN + (i as f64) * PITCH_GRID_STEP;
        let ce = trajectory_error(p0, current.e_of_p_grid(i));
        if ce < best_ce {
            best_ce = ce;
            best_p0 = p0;
        }
    }
    let ce_p0 = best_ce;

    // Sub-multiples P̂0/n for n = n_max down to 2, i.e. smallest period first. The range is
    // empty when n_max < 2.
    let n_max = (best_p0 / PITCH_GRID_MIN).floor() as u32;
    for n in (2..=n_max).rev() {
        let raw = best_p0 / f64::from(n);
        if raw < PITCH_GRID_MIN {
            continue;
        }
        let Some(sub) = snap_to_pitch_grid(raw) else {
            continue;
        };
        let ce_sub = trajectory_error(sub, current.e_of_p(sub));
        // A non-positive CE_F(P̂0) makes the ratio meaningless; only the absolute test
        // can then accept a sub-multiple.
        let ratio = if ce_p0 > 0.0 {
            ce_sub / ce_p0
        } else {
            f64::INFINITY
        };
        // Acceptance tests (eqs. 18–20): small enough outright, or small enough relative
        // to CE_F(P̂0).
        let eq18 = ce_sub <= 0.85 && ratio <= 1.7;
        let eq19 = ce_sub <= 0.40 && ratio <= 3.5;
        let eq20 = ce_sub <= 0.05;
        if eq18 || eq19 || eq20 {
            return (sub, ce_sub);
        }
    }
    (best_p0, ce_p0)
}
/// Chooses the initial pitch estimate `P̂_I` from the look-back result `(p_b, ce_b)` and the
/// look-ahead result `(p_f, ce_f)` (eqs. 21–23).
///
/// The look-back period is kept when its cumulative error is at most `0.48` or no worse
/// than the look-ahead error. Otherwise the look-ahead period is used. A NaN `ce_b` fails
/// both tests and therefore selects `p_f`.
#[inline]
pub fn decide_initial_pitch(p_b: f64, ce_b: f64, p_f: f64, ce_f: f64) -> f64 {
    /// `ce_b` at or below which the look-back estimate is accepted unconditionally.
    const CE_B_ACCEPT: f64 = 0.48;
    if ce_b <= CE_B_ACCEPT || ce_b <= ce_f {
        p_b
    } else {
        p_f
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn snap_to_pitch_grid_rounds_to_half_sample() {
        assert_eq!(snap_to_pitch_grid(50.0), Some(50.0));
        assert_eq!(snap_to_pitch_grid(50.1), Some(50.0));
        assert_eq!(snap_to_pitch_grid(50.25), Some(50.5));
        assert_eq!(snap_to_pitch_grid(50.3), Some(50.5));
        assert_eq!(snap_to_pitch_grid(50.7), Some(50.5));
        assert_eq!(snap_to_pitch_grid(50.8), Some(51.0));
    }

    #[test]
    fn snap_to_pitch_grid_rejects_out_of_range() {
        // 20.9 rounds up onto the grid; 20.6 rounds down to 20.5, just below it.
        assert_eq!(snap_to_pitch_grid(20.9), Some(21.0));
        assert_eq!(snap_to_pitch_grid(20.6), None);
        assert_eq!(snap_to_pitch_grid(10.0), None);
        assert_eq!(snap_to_pitch_grid(200.0), None);
        assert_eq!(snap_to_pitch_grid(21.0), Some(21.0));
        assert_eq!(snap_to_pitch_grid(122.0), Some(122.0));
    }

    #[test]
    fn e_of_p_is_one_on_silence() {
        let s = [0.0f64; PITCH_INPUT_LEN];
        let search = PitchSearch::new(&s);
        for p in [21.0, 50.0, 100.0, 122.0] {
            assert_eq!(search.e_of_p(p), 1.0, "P = {p}");
        }
    }

    #[test]
    fn argmin_in_range_clamps_and_returns_grid_value() {
        let s = [0.1f64; PITCH_INPUT_LEN];
        let search = PitchSearch::new(&s);
        let (p, _) = search.argmin_in_range(40.0, 60.0);
        assert!((40.0..=60.0).contains(&p), "P = {p} out of [40, 60]");
        let doubled = (p * 2.0).round();
        assert!((doubled / 2.0 - p).abs() < 1e-9, "P = {p} not on grid");
    }

    /// Three-harmonic test signal with fundamental period `period` samples, sampled over
    /// the full raw input window.
    fn periodic_input(period: f64) -> [f64; PITCH_INPUT_LEN] {
        let mut s = [0.0f64; PITCH_INPUT_LEN];
        let omega = 2.0 * core::f64::consts::PI / period;
        for (idx, slot) in s.iter_mut().enumerate() {
            let n = idx as i32 - PITCH_INPUT_HALF;
            let nf = f64::from(n);
            *slot = (omega * nf).cos()
                + 0.5 * (2.0 * omega * nf).cos()
                + 0.25 * (3.0 * omega * nf).cos();
        }
        s
    }

    #[test]
    fn argmin_in_range_matches_brute_force_sweep() {
        let s = periodic_input(50.0);
        let search = PitchSearch::new(&s);
        for (p_lo, p_hi) in [(21.0, 122.0), (40.0, 60.0), (80.0, 100.0), (45.0, 55.0)] {
            let (p_hat, e_hat) = search.argmin_in_range(p_lo, p_hi);
            let mut brute_min = f64::INFINITY;
            let mut brute_p = p_hat;
            let mut p = PITCH_GRID_MIN;
            while p <= PITCH_GRID_MAX {
                if p >= p_lo && p <= p_hi {
                    let e = search.e_of_p(p);
                    if e < brute_min {
                        brute_min = e;
                        brute_p = p;
                    }
                }
                p += PITCH_GRID_STEP;
            }
            assert!(
                (e_hat - brute_min).abs() < 1e-12 && (p_hat - brute_p).abs() < 1e-9,
                "[{p_lo}, {p_hi}]: argmin_in_range returned ({p_hat}, {e_hat}), \
                brute = ({brute_p}, {brute_min})"
            );
        }
    }

    /// The sparse-table fast path must return exactly what the linear scan returns,
    /// including for empty, single-point and out-of-grid ranges.
    #[test]
    fn argmin_fast_path_matches_linear_scan() {
        let s = periodic_input(65.0);
        let slow = PitchSearch::new(&s);
        let mut fast = slow.clone();
        fast.enable_argmin_fast();
        let mut p_lo = PITCH_GRID_MIN - 5.0;
        while p_lo <= PITCH_GRID_MAX + 5.0 {
            for width in [0.0, 0.3, 0.5, 3.0, 17.5, 60.0, 110.0] {
                let p_hi = p_lo + width;
                assert_eq!(
                    fast.argmin_in_range(p_lo, p_hi),
                    slow.argmin_in_range(p_lo, p_hi),
                    "[{p_lo}, {p_hi}]"
                );
            }
            p_lo += 1.25;
        }
    }

    #[test]
    #[ignore]
    fn probe_dump_e_of_p() {
        for true_p in [50.0, 65.0, 80.0, 100.0] {
            let s = periodic_input(true_p);
            let search = PitchSearch::new(&s);
            let (p_hat, e_hat) = search.argmin_in_range(PITCH_GRID_MIN, PITCH_GRID_MAX);
            println!("\n=== true_p = {true_p} ===");
            println!("argmin: P̂ = {p_hat}, E = {e_hat:.4}");
            println!("energy = {:.4e}", search.energy);
            println!("w_i_fourth_moment = {:.4e}", search.w_i_fourth_moment);
            let mut p = PITCH_GRID_MIN;
            while p <= PITCH_GRID_MAX {
                let e = search.e_of_p(p);
                if p == true_p || p == p_hat || (p - PITCH_GRID_MIN) % 10.0 == 0.0 {
                    println!(" E({p:6.1}) = {e:+.4}");
                }
                p += PITCH_GRID_STEP;
            }
        }
    }

    #[test]
    fn look_back_cold_start_has_zero_history_contribution() {
        let s = periodic_input(50.0);
        let search = PitchSearch::new(&s);
        let ctx = LookBackContext::cold_start();
        let (p_b, ce_b) = look_back(&search, ctx);
        assert!(
            (80.0..=120.0).contains(&p_b),
            "P̂_B = {p_b} out of [80, 120]"
        );
        let e_b = search.e_of_p(p_b);
        assert!((ce_b - e_b).abs() < 1e-12, "CE_B = {ce_b}, E(P̂_B) = {e_b}");
    }

    #[test]
    fn decide_initial_pitch_eq21_wins_when_ce_b_is_small() {
        assert_eq!(decide_initial_pitch(50.0, 0.3, 70.0, 0.2), 50.0);
        assert_eq!(decide_initial_pitch(50.0, 0.48, 70.0, 0.1), 50.0);
    }

    #[test]
    fn decide_initial_pitch_eq22_wins_when_ce_b_leq_ce_f() {
        assert_eq!(decide_initial_pitch(50.0, 0.6, 70.0, 0.65), 50.0);
        assert_eq!(decide_initial_pitch(50.0, 0.6, 70.0, 0.6), 50.0);
    }

    #[test]
    fn decide_initial_pitch_eq23_picks_forward_when_ce_b_is_larger() {
        assert_eq!(decide_initial_pitch(50.0, 0.7, 70.0, 0.5), 70.0);
    }

    #[test]
    fn look_ahead_returns_grid_value_and_finite_ce() {
        let s = periodic_input(50.0);
        let cur = PitchSearch::new(&s);
        let n1 = PitchSearch::new(&s);
        let n2 = PitchSearch::new(&s);
        let (p_f, ce_f) = look_ahead(&cur, &n1, &n2);
        assert!(
            (PITCH_GRID_MIN..=PITCH_GRID_MAX).contains(&p_f),
            "P̂_F = {p_f} out of grid"
        );
        assert!(ce_f.is_finite(), "CE_F = {ce_f} not finite");
        let steps = (p_f / PITCH_GRID_STEP).round();
        assert!(
            (steps * PITCH_GRID_STEP - p_f).abs() < 1e-9,
            "P̂_F = {p_f} not on grid"
        );
    }

    #[test]
    fn compute_s_lpf_of_zero_is_zero() {
        let s = [0.0f64; PITCH_INPUT_LEN];
        let s_lpf = compute_s_lpf(&s);
        for (i, &v) in s_lpf.iter().enumerate() {
            assert!(v.abs() < 1e-12, "s_LPF[{i}] = {v}");
        }
    }

    #[test]
    fn compute_s_lpf_interior_matches_filter_dc_gain() {
        let s = [1.0f64; PITCH_INPUT_LEN];
        let s_lpf = compute_s_lpf(&s);
        let dc_gain: f64 = (-H_LPF_HALF..=H_LPF_HALF)
            .map(|j| f64::from(lpf_tap(j)))
            .sum();
        let center = s_lpf[W_I_HALF as usize];
        assert!(
            (center - dc_gain).abs() < 1e-9,
            "s_LPF(0) = {center}, expected DC gain {dc_gain}"
        );
    }

    /// A slot of the shared LPF buffer must equal `compute_s_lpf` run on the raw window
    /// centred at the same lookahead sample, wherever that raw window fits in the lookahead.
    #[test]
    fn shared_lpf_slot_matches_standalone_filter() {
        let mut lookahead = [0.0f64; LOOKAHEAD_LEN];
        for (i, x) in lookahead.iter_mut().enumerate() {
            let t = i as f64;
            *x = (0.37 * t).sin() + 0.25 * (0.05 * t).cos();
        }
        let shared = compute_s_lpf_shared(&lookahead);
        let spf = SAMPLES_PER_FRAME as usize;
        let half = PITCH_INPUT_HALF as usize;
        let mut compared = 0;
        for slot in 0..3 {
            // Slot `slot` is centred on lookahead sample slot·SPF + SPF/2.
            let centre = slot * spf + spf / 2;
            if centre < half || centre + half >= LOOKAHEAD_LEN {
                continue;
            }
            let mut input = [0.0f64; PITCH_INPUT_LEN];
            input.copy_from_slice(&lookahead[centre - half..=centre + half]);
            let expected = compute_s_lpf(&input);
            let got = slot_s_lpf(&shared, slot);
            assert_eq!(got.len(), expected.len());
            for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() {
                assert!((g - e).abs() < 1e-12, "slot {slot}, index {i}: {g} vs {e}");
            }
            compared += 1;
        }
        assert!(compared > 0, "no slot's raw window fits inside the lookahead");
    }
}