/// Relative conditioning factor for `burg_modified`: this fraction of the
/// zero-lag energy is added to the initial forward/backward error energies, and
/// (when the gain bound is not hit) its effect is subtracted again from the
/// returned residual energy.
const FIND_LPC_COND_FAC: f64 = 1.0e-5;
/// Returns the energy (sum of squares) of `x`, computed as the dot product of
/// `x` with itself by `crate::simd::dot_f64`.
fn energy(x: &[f32]) -> f64 {
    let samples: &[f32] = x;
    crate::simd::dot_f64(samples, samples)
}
/// Returns the dot product of the first `n` samples of `a` and `b`.
///
/// Both operands are trimmed to exactly `n` samples before delegating to
/// `crate::simd::dot_f64`, so the result never depends on how that kernel
/// treats slices of unequal length (callers routinely pass a `b` that extends
/// past `n`).
///
/// # Panics
///
/// Panics if `a` or `b` has fewer than `n` elements.
fn inner_product(a: &[f32], b: &[f32], n: usize) -> f64 {
    crate::simd::dot_f64(&a[..n], &b[..n])
}
/// Computes order-`d` LPC coefficients for the stacked subframes in `x` with a
/// modified Burg recursion and returns the residual (prediction-error) energy.
///
/// * `a` – receives the `d` predictor coefficients in `a[..d]`; elements past
///   `d` are left untouched. Sign convention: `x[i] ≈ Σ_k a[k]·x[i-k-1]`.
/// * `x` – `nb_subfr` consecutive subframes of `subfr_length` samples each.
///   Correlations are accumulated per subframe and summed, so no lag product
///   ever pairs samples from two different subframes.
/// * `min_inv_gain` – lower bound on the inverse prediction gain, expected in
///   `(0, 1]`. When the next reflection coefficient would drive the inverse
///   gain to or below this bound, it is shrunk so the bound is met exactly,
///   all higher-order coefficients are zeroed and the recursion stops. Values
///   above `1.0` (and NaN) are treated as `1.0`, i.e. no prediction gain is
///   allowed; without this cap the clamp would take the square root of a
///   negative number and produce NaN coefficients.
/// * `subfr_length`, `nb_subfr` – subframe geometry of `x`.
/// * `d` – prediction order, `1..=24`.
///
/// The returned energy is computed from the final error correlations (with the
/// conditioning term removed) when the gain bound was never hit. Otherwise it
/// is approximated as the energy of `x`, minus the first `d` samples of each
/// subframe, multiplied by the clamped inverse gain.
///
/// # Panics
///
/// Panics if `d` is 0 or greater than 24, `a.len() < d`, `subfr_length <= d`,
/// or `x.len() < nb_subfr * subfr_length`.
pub(crate) fn burg_modified(
    a: &mut [f32],
    x: &[f32],
    min_inv_gain: f32,
    subfr_length: usize,
    nb_subfr: usize,
    d: usize,
) -> f32 {
    // Largest supported order; sizes the fixed scratch arrays below.
    const MAXD: usize = 24;
    assert!(d > 0 && d <= MAXD && a.len() >= d);
    assert!(x.len() >= nb_subfr * subfr_length && subfr_length > d);
    // Cap the bound at 1 (`f64::min` also maps NaN to 1). With the bound <= 1,
    // `inv_gain` starts at 1 and only ever takes values above the bound until
    // the clamp fires, so the `sqrt` argument in the clamp is never negative.
    let min_inv_gain = f64::from(min_inv_gain).min(1.0);

    // Zero-lag energy and lag-1..=d autocorrelations, summed over subframes.
    let c0 = energy(&x[..nb_subfr * subfr_length]);
    let mut c_first_row = [0.0f64; MAXD];
    for s in 0..nb_subfr {
        let xs = &x[s * subfr_length..];
        for n in 1..=d {
            c_first_row[n - 1] += inner_product(xs, &xs[n..], subfr_length - n);
        }
    }
    let mut c_last_row = c_first_row;

    // Forward (`caf`) and backward (`cab`) error correlations. The zero-lag
    // entries start from the energy plus a small relative conditioning term and
    // a tiny absolute floor, keeping the energy denominators below positive.
    let mut caf = [0.0f64; MAXD + 1];
    let mut cab = [0.0f64; MAXD + 1];
    caf[0] = c0 + FIND_LPC_COND_FAC * c0 + 1e-9;
    cab[0] = caf[0];

    let mut af = [0.0f64; MAXD];
    let mut inv_gain = 1.0f64;
    let mut reached_max_gain = false;
    for n in 0..d {
        // Per subframe: drop the products involving the edge samples
        // `xs[n]` / `xs[subfr_length - n - 1]` from the running lag sums, and
        // remove their filtered contribution from `caf` / `cab`.
        for s in 0..nb_subfr {
            let xs = &x[s * subfr_length..];
            let mut tmp1 = f64::from(xs[n]);
            let mut tmp2 = f64::from(xs[subfr_length - n - 1]);
            for k in 0..n {
                c_first_row[k] -= f64::from(xs[n]) * f64::from(xs[n - k - 1]);
                c_last_row[k] -=
                    f64::from(xs[subfr_length - n - 1]) * f64::from(xs[subfr_length - n + k]);
                let atmp = af[k];
                tmp1 += f64::from(xs[n - k - 1]) * atmp;
                tmp2 += f64::from(xs[subfr_length - n + k]) * atmp;
            }
            for k in 0..=n {
                caf[k] -= tmp1 * f64::from(xs[n - k]);
                cab[k] -= tmp2 * f64::from(xs[subfr_length - n + k - 1]);
            }
        }

        // Highest-lag entries of `caf` / `cab` for the current order.
        let mut tmp1 = c_first_row[n];
        let mut tmp2 = c_last_row[n];
        for k in 0..n {
            let atmp = af[k];
            tmp1 += c_last_row[n - k - 1] * atmp;
            tmp2 += c_first_row[n - k - 1] * atmp;
        }
        caf[n + 1] = tmp1;
        cab[n + 1] = tmp2;

        // Numerator and forward/backward error energies for the next
        // reflection coefficient.
        let mut num = cab[n + 1];
        let mut nrg_b = cab[0];
        let mut nrg_f = caf[0];
        for k in 0..n {
            let atmp = af[k];
            num += cab[n - k] * atmp;
            nrg_b += cab[k + 1] * atmp;
            nrg_f += caf[k + 1] * atmp;
        }

        // Burg reflection coefficient: -2·num over the summed error energies.
        let mut rc = -2.0 * num / (nrg_f + nrg_b);
        let tmp = inv_gain * (1.0 - rc * rc);
        if tmp <= min_inv_gain {
            // Gain bound reached: pick |rc| so the inverse gain lands exactly on
            // the bound, with sign opposite to `num` (as the unclamped rc has).
            rc = (1.0 - min_inv_gain / inv_gain).sqrt();
            if num > 0.0 {
                rc = -rc;
            }
            inv_gain = min_inv_gain;
            reached_max_gain = true;
        } else {
            inv_gain = tmp;
        }

        // Order update of the predictor: pairs af[k] / af[n-k-1] are updated
        // together in place, then rc becomes the new highest coefficient.
        for k in 0..(n + 1) >> 1 {
            let t1 = af[k];
            let t2 = af[n - k - 1];
            af[k] = t1 + rc * t2;
            af[n - k - 1] = t2 + rc * t1;
        }
        af[n] = rc;

        if reached_max_gain {
            // Stop on the clamped step; higher orders contribute nothing.
            for af_k in af.iter_mut().take(d).skip(n + 1) {
                *af_k = 0.0;
            }
            break;
        }

        // Fold the new reflection coefficient into `caf` / `cab` for the next
        // order.
        for k in 0..=n + 1 {
            let t1 = caf[k];
            caf[k] += rc * cab[n + 1 - k];
            cab[n + 1 - k] += rc * t1;
        }
    }

    if reached_max_gain {
        for (dst, &coef) in a[..d].iter_mut().zip(&af[..d]) {
            *dst = -coef as f32;
        }
        // Approximate the residual: energy excluding the first `d` samples of
        // each subframe, scaled by the clamped inverse gain.
        let mut c0 = c0;
        for s in 0..nb_subfr {
            c0 -= energy(&x[s * subfr_length..s * subfr_length + d]);
        }
        (c0 * inv_gain) as f32
    } else {
        // Exact residual from the final forward correlations, minus the
        // conditioning energy that was injected into caf[0] (scaled by the
        // squared norm of the full error filter [1, af...]).
        let mut nrg_f = caf[0];
        let mut tmp1 = 1.0f64;
        for k in 0..d {
            let atmp = af[k];
            nrg_f += caf[k + 1] * atmp;
            tmp1 += atmp * atmp;
            a[k] = -atmp as f32;
        }
        nrg_f -= FIND_LPC_COND_FAC * c0 * tmp1;
        nrg_f as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    extern crate alloc;
    use alloc::vec::Vec;

    /// Synthesises `len` samples of the AR(2) process
    /// `y[i] = c1·y[i-1] + c2·y[i-2] + 0.1·noise[i]`, where the noise comes from
    /// a fixed-seed linear congruential generator so every run is identical.
    fn synth_ar2(c1: f32, c2: f32, len: usize) -> Vec<f32> {
        let mut state = 12_345u32;
        let (mut y1, mut y2) = (0.0f32, 0.0f32);
        (0..len)
            .map(|_| {
                state = state.wrapping_mul(1_103_515_245).wrapping_add(12_345);
                let noise = ((state >> 16) as f32 / 32_768.0) - 1.0;
                let y = c1 * y1 + c2 * y2 + 0.1 * noise;
                y2 = y1;
                y1 = y;
                y
            })
            .collect()
    }

    /// On a single-subframe AR(2) realisation, the order-2 estimate must be
    /// within 0.1 of the generating coefficients and the residual must be
    /// positive and below 30% of the input energy.
    #[test]
    fn recovers_ar2_coefficients() {
        const LEN: usize = 480;
        const ORDER: usize = 2;
        let (a1, a2) = (1.3f32, -0.6f32);
        let x = synth_ar2(a1, a2, LEN);
        let mut a = [0.0f32; 16];
        let resid = burg_modified(&mut a, &x, 1.0 / 1e4, LEN, 1, ORDER);
        assert!((a[0] - a1).abs() < 0.1, "a0={} expected {a1}", a[0]);
        assert!((a[1] - a2).abs() < 0.1, "a1={} expected {a2}", a[1]);
        let input_energy = energy(&x) as f32;
        assert!(
            resid > 0.0 && resid < 0.3 * input_energy,
            "resid {resid} vs energy {input_energy}"
        );
    }

    /// A pure sinusoid is highly predictable; with a gain bound in place the
    /// order-10 result must stay finite, with a non-negative residual and
    /// bounded coefficients.
    #[test]
    fn enforces_stability_on_a_pure_tone() {
        const LEN: usize = 320;
        let x: Vec<f32> = (0..LEN).map(|i| (i as f32 * 0.3).sin()).collect();
        let mut a = [0.0f32; 16];
        let resid = burg_modified(&mut a, &x, 1.0 / 1e4, LEN, 1, 10);
        assert!(resid.is_finite() && resid >= 0.0, "resid {resid}");
        let all_bounded = a.iter().all(|v| v.is_finite() && v.abs() < 8.0);
        assert!(all_bounded, "coefs {a:?}");
    }
}