use crate::common::{
FFT_LENGTH_BY_2, FFT_LENGTH_BY_2_MINUS_1, FFT_LENGTH_BY_2_PLUS_1, X2_BAND_ENERGY_THRESHOLD,
};
/// Floor applied to an ERL estimate whenever it is lowered.
const MIN_ERL: f32 = 1e-2;
/// Ceiling for ERL estimates; estimates also start at this value.
const MAX_ERL: f32 = 1e3;
/// Tracks ERL estimates, computed as the ratio of capture power to render power, both per
/// frequency bin and over all bins summed. (ERL presumably stands for echo return loss —
/// confirm against the crate's documentation.)
#[derive(Debug)]
pub(crate) struct ErlEstimator {
    /// `update` calls whose count since construction or the last `reset` is below this value
    /// leave the estimates untouched.
    startup_phase_length_blocks: usize,
    /// Per-bin ERL estimates. Bins `1..FFT_LENGTH_BY_2` are estimated; bins 0 and
    /// `FFT_LENGTH_BY_2` are copied from their inner neighbours.
    erl: [f32; FFT_LENGTH_BY_2_PLUS_1],
    /// Hold counters for the estimated bins; entry `k - 1` belongs to bin `k`.
    hold_counters: [i32; FFT_LENGTH_BY_2_MINUS_1],
    /// ERL estimate computed from the powers summed over all bins.
    erl_time_domain: f32,
    /// Hold counter for `erl_time_domain`.
    hold_counter_time_domain: i32,
    /// Number of `update` calls since construction or the last `reset`.
    blocks_since_reset: usize,
}
impl ErlEstimator {
    /// Hold period, in updates: an estimate that was lowered may start rising again on the
    /// `HOLD_BLOCKS`-th update counted from (and including) the update that lowered it.
    const HOLD_BLOCKS: i32 = 1000;
    /// Fraction of the gap towards a new, lower ERL observation closed by a single update.
    const SMOOTHING: f32 = 0.1;

    /// Creates an estimator with every ERL estimate at `MAX_ERL` and no hold pending.
    ///
    /// Calls to [`Self::update`] whose count since construction or the last [`Self::reset`]
    /// is below `startup_phase_length_blocks` are ignored.
    pub(crate) fn new(startup_phase_length_blocks: usize) -> Self {
        Self {
            startup_phase_length_blocks,
            erl: [MAX_ERL; FFT_LENGTH_BY_2_PLUS_1],
            hold_counters: [0; FFT_LENGTH_BY_2_MINUS_1],
            erl_time_domain: MAX_ERL,
            hold_counter_time_domain: 0,
            blocks_since_reset: 0,
        }
    }

    /// Restarts the startup phase. Current estimates and hold counters are kept.
    pub(crate) fn reset(&mut self) {
        self.blocks_since_reset = 0;
    }

    /// Feeds one block of power spectra to the estimator.
    ///
    /// `converged_filters[ch]` says whether capture channel `ch` has a converged filter; only
    /// those channels contribute. The spectra used are the bin-wise maximum over all render
    /// channels and over the converged capture channels. Nothing is updated while in the
    /// startup phase or when no filter has converged.
    ///
    /// For each bin `k` in `1..FFT_LENGTH_BY_2` whose render power exceeds
    /// `X2_BAND_ENERGY_THRESHOLD`, the capture/render ratio is observed: a lower ratio pulls
    /// the estimate down and restarts its hold period. Every processed update advances the
    /// hold; once it has run out the estimate doubles, capped at `MAX_ERL`. The edge bins copy
    /// their inner neighbours. The time-domain estimate follows the same rule using powers
    /// summed over all bins, gated on the summed render power exceeding the threshold times
    /// the bin count.
    ///
    /// # Panics
    ///
    /// In debug builds, if `capture_spectra.len() != converged_filters.len()`. Once past the
    /// startup phase with a converged filter, if `render_spectra` is empty or
    /// `capture_spectra` has no entry for the first converged channel.
    pub(crate) fn update(
        &mut self,
        converged_filters: &[bool],
        render_spectra: &[[f32; FFT_LENGTH_BY_2_PLUS_1]],
        capture_spectra: &[[f32; FFT_LENGTH_BY_2_PLUS_1]],
    ) {
        debug_assert_eq!(capture_spectra.len(), converged_filters.len());

        // Saturating: on 32-bit targets a plain `+= 1` could overflow on a very long stream
        // (a panic in debug builds, a spurious return to the startup phase in release).
        self.blocks_since_reset = self.blocks_since_reset.saturating_add(1);
        if self.blocks_since_reset < self.startup_phase_length_blocks {
            return;
        }
        let first_converged = match converged_filters.iter().position(|&converged| converged) {
            Some(ch) => ch,
            None => return,
        };

        let x2 = Self::max_render_spectrum(render_spectra);
        let y2 =
            Self::max_converged_capture_spectrum(converged_filters, capture_spectra, first_converged);

        // Bins are independent, so observing, counting down and releasing can be done per bin
        // in a single pass. `hold_counters[k - 1]` belongs to bin `k`.
        for ((erl, hold_counter), (&x2_k, &y2_k)) in self.erl[1..FFT_LENGTH_BY_2]
            .iter_mut()
            .zip(self.hold_counters.iter_mut())
            .zip(x2[1..FFT_LENGTH_BY_2].iter().zip(y2[1..FFT_LENGTH_BY_2].iter()))
        {
            if x2_k > X2_BAND_ENERGY_THRESHOLD {
                Self::lower_estimate(erl, hold_counter, y2_k / x2_k);
            }
            Self::advance_hold(erl, hold_counter);
        }

        // The edge bins are never observed directly; mirror their inner neighbours.
        self.erl[0] = self.erl[1];
        self.erl[FFT_LENGTH_BY_2] = self.erl[FFT_LENGTH_BY_2 - 1];

        let x2_sum: f32 = x2.iter().sum();
        if x2_sum > X2_BAND_ENERGY_THRESHOLD * FFT_LENGTH_BY_2_PLUS_1 as f32 {
            let y2_sum: f32 = y2.iter().sum();
            Self::lower_estimate(
                &mut self.erl_time_domain,
                &mut self.hold_counter_time_domain,
                y2_sum / x2_sum,
            );
        }
        Self::advance_hold(&mut self.erl_time_domain, &mut self.hold_counter_time_domain);
    }

    /// Returns the ERL estimate computed from powers summed over all bins.
    pub(crate) fn erl_time_domain(&self) -> f32 {
        self.erl_time_domain
    }

    /// Bin-wise maximum over all render spectra. Panics if `render_spectra` is empty.
    fn max_render_spectrum(
        render_spectra: &[[f32; FFT_LENGTH_BY_2_PLUS_1]],
    ) -> [f32; FFT_LENGTH_BY_2_PLUS_1] {
        let mut max_spectrum = render_spectra[0];
        for spectrum in &render_spectra[1..] {
            Self::accumulate_max(&mut max_spectrum, spectrum);
        }
        max_spectrum
    }

    /// Bin-wise maximum over the capture spectra of the channels flagged in
    /// `converged_filters`, seeded from channel `first_converged` (which must be flagged).
    fn max_converged_capture_spectrum(
        converged_filters: &[bool],
        capture_spectra: &[[f32; FFT_LENGTH_BY_2_PLUS_1]],
        first_converged: usize,
    ) -> [f32; FFT_LENGTH_BY_2_PLUS_1] {
        let mut max_spectrum = capture_spectra[first_converged];
        let later_channels = capture_spectra
            .iter()
            .zip(converged_filters)
            .skip(first_converged + 1);
        for (spectrum, &converged) in later_channels {
            if converged {
                Self::accumulate_max(&mut max_spectrum, spectrum);
            }
        }
        max_spectrum
    }

    /// Raises each bin of `acc` to the matching bin of `spectrum` where that one is larger.
    fn accumulate_max(
        acc: &mut [f32; FFT_LENGTH_BY_2_PLUS_1],
        spectrum: &[f32; FFT_LENGTH_BY_2_PLUS_1],
    ) {
        for (acc_k, &value) in acc.iter_mut().zip(spectrum.iter()) {
            *acc_k = (*acc_k).max(value);
        }
    }

    /// If `new_erl` is below `erl`, moves `erl` a `SMOOTHING` fraction of the way towards it
    /// (never below `MIN_ERL`) and restarts the hold period.
    fn lower_estimate(erl: &mut f32, hold_counter: &mut i32, new_erl: f32) {
        if new_erl < *erl {
            *hold_counter = Self::HOLD_BLOCKS;
            *erl += Self::SMOOTHING * (new_erl - *erl);
            *erl = (*erl).max(MIN_ERL);
        }
    }

    /// Counts the hold period down by one update; once it has run out, doubles `erl`,
    /// capped at `MAX_ERL`.
    fn advance_hold(erl: &mut f32, hold_counter: &mut i32) {
        // Saturating: the counter keeps falling for as long as no lower ERL is observed, and a
        // plain decrement would eventually overflow (panic in debug builds; in release it would
        // wrap positive and freeze the estimate).
        *hold_counter = (*hold_counter).saturating_sub(1);
        if *hold_counter <= 0 {
            *erl = MAX_ERL.min(2.0 * *erl);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every per-bin ERL value and the time-domain ERL are within 0.001 of
    /// `reference`.
    fn verify_erl(erl: &[f32; FFT_LENGTH_BY_2_PLUS_1], erl_time_domain: f32, reference: f32) {
        let near = |value: f32| (value - reference).abs() < 0.001;
        erl.iter()
            .for_each(|&v| assert!(near(v), "ERL bin {v} != reference {reference}"));
        assert!(
            near(erl_time_domain),
            "ERL time domain {erl_time_domain} != reference {reference}"
        );
    }

    /// Calls `update` on `estimator` `times` times with the same inputs.
    fn feed(
        estimator: &mut ErlEstimator,
        times: usize,
        converged_filters: &[bool],
        render: &[[f32; FFT_LENGTH_BY_2_PLUS_1]],
        capture: &[[f32; FFT_LENGTH_BY_2_PLUS_1]],
    ) {
        for _ in 0..times {
            estimator.update(converged_filters, render, capture);
        }
    }

    #[test]
    fn estimates() {
        const CHANNEL_COUNTS: [usize; 3] = [1, 2, 8];
        for &num_render_channels in &CHANNEL_COUNTS {
            for &num_capture_channels in &CHANNEL_COUNTS {
                // Only the last capture channel reports a converged filter.
                let converged_idx = num_capture_channels - 1;
                let converged_filters: Vec<bool> = (0..num_capture_channels)
                    .map(|ch| ch == converged_idx)
                    .collect();
                let mut capture = vec![[0.0f32; FFT_LENGTH_BY_2_PLUS_1]; num_capture_channels];
                let mut estimator = ErlEstimator::new(0);

                // Loud render with a capture ten times louder: the estimate settles at 10.
                let loud_level = 500.0f32 * 1000.0 * 1000.0;
                let render = vec![[loud_level; FFT_LENGTH_BY_2_PLUS_1]; num_render_channels];
                capture[converged_idx] = [10.0 * loud_level; FFT_LENGTH_BY_2_PLUS_1];
                feed(&mut estimator, 200, &converged_filters, &render, &capture);
                verify_erl(&estimator.erl, estimator.erl_time_domain(), 10.0);

                // A much higher ratio does not raise the estimate while it is held...
                capture[converged_idx] = [10000.0 * loud_level; FFT_LENGTH_BY_2_PLUS_1];
                feed(&mut estimator, 998, &converged_filters, &render, &capture);
                verify_erl(&estimator.erl, estimator.erl_time_domain(), 10.0);
                // ...then it doubles on the next update...
                feed(&mut estimator, 1, &converged_filters, &render, &capture);
                verify_erl(&estimator.erl, estimator.erl_time_domain(), 20.0);
                // ...and keeps doubling until it reaches the 1000 cap.
                feed(&mut estimator, 1000, &converged_filters, &render, &capture);
                verify_erl(&estimator.erl, estimator.erl_time_domain(), 1000.0);

                // A render this quiet does not lower the capped estimate.
                let quiet_level = 1000.0f32 * 1000.0;
                let render = vec![[quiet_level; FFT_LENGTH_BY_2_PLUS_1]; num_render_channels];
                capture[converged_idx] = [10.0 * quiet_level; FFT_LENGTH_BY_2_PLUS_1];
                feed(&mut estimator, 200, &converged_filters, &render, &capture);
                verify_erl(&estimator.erl, estimator.erl_time_domain(), 1000.0);
            }
        }
    }
}