use crate::config::{FFT_SIZE_BY_2_PLUS_1, LONG_STARTUP_PHASE_BLOCKS};
use crate::fast_math::exp_approximation_sign_flip;
use crate::signal_model_estimator::SignalModelEstimator;
/// Borrowed per-frame inputs consumed by `SpeechProbabilityEstimator::update`.
///
/// The array fields are references to fixed-size arrays of
/// `FFT_SIZE_BY_2_PLUS_1` `f32` values that must outlive `'a`; nothing is
/// copied or owned by this struct.
pub(crate) struct SignalAnalysis<'a> {
/// Frame count (per the field name, the number of frames analyzed so far).
/// `update` compares it against `LONG_STARTUP_PHASE_BLOCKS` and, while it is
/// below that limit, forwards it to `SignalModelEstimator::adjust_normalization`.
pub num_analyzed_frames: i32,
/// Forwarded as the first argument of `SignalModelEstimator::update`.
/// NOTE(review): presumably the a priori SNR estimate per array entry — confirm with the producer.
pub prior_snr: &'a [f32; FFT_SIZE_BY_2_PLUS_1],
/// Forwarded as the second argument of `SignalModelEstimator::update`.
/// NOTE(review): presumably the a posteriori SNR estimate per array entry — confirm with the producer.
pub post_snr: &'a [f32; FFT_SIZE_BY_2_PLUS_1],
/// Noise spectrum forwarded to `SignalModelEstimator::update`.
/// NOTE(review): what makes it "conservative" is not visible in this file.
pub conservative_noise_spectrum: &'a [f32; FFT_SIZE_BY_2_PLUS_1],
/// Signal spectrum forwarded to `SignalModelEstimator::update`.
pub signal_spectrum: &'a [f32; FFT_SIZE_BY_2_PLUS_1],
/// Forwarded to `SignalModelEstimator::update`; the tests in this file pass
/// the sum of `signal_spectrum` here.
pub signal_spectral_sum: f32,
/// Forwarded to both `SignalModelEstimator::adjust_normalization` (during the
/// start-up phase) and `SignalModelEstimator::update`.
/// NOTE(review): exact energy definition is not visible in this file.
pub signal_energy: f32,
}
/// Holds a smoothed prior speech probability and an array of speech
/// probabilities, both recomputed on every call to `update`.
#[derive(Debug)]
pub(crate) struct SpeechProbabilityEstimator {
/// Signal-model tracker; `update` feeds it each frame and then reads its
/// current model and prior model.
signal_model_estimator: SignalModelEstimator,
/// Smoothed prior speech probability. Starts at 0.5 and is clamped to
/// `[0.01, 1.0]` after every update.
prior_speech_prob: f32,
/// Speech probabilities written by the most recent `update`, one per entry of
/// the `FFT_SIZE_BY_2_PLUS_1`-sized array; all zeros until the first update.
speech_probability: [f32; FFT_SIZE_BY_2_PLUS_1],
}
impl Default for SpeechProbabilityEstimator {
    /// Creates an estimator in its initial state: a default signal-model
    /// estimator, a prior speech probability of 0.5, and an all-zero
    /// probability array.
    fn default() -> Self {
        let signal_model_estimator = SignalModelEstimator::default();
        let speech_probability = [0.0_f32; FFT_SIZE_BY_2_PLUS_1];

        Self {
            signal_model_estimator,
            prior_speech_prob: 0.5,
            speech_probability,
        }
    }
}
impl SpeechProbabilityEstimator {
pub(crate) fn update(&mut self, analysis: &SignalAnalysis<'_>) {
if analysis.num_analyzed_frames < LONG_STARTUP_PHASE_BLOCKS {
self.signal_model_estimator
.adjust_normalization(analysis.num_analyzed_frames, analysis.signal_energy);
}
self.signal_model_estimator.update(
analysis.prior_snr,
analysis.post_snr,
analysis.conservative_noise_spectrum,
analysis.signal_spectrum,
analysis.signal_spectral_sum,
analysis.signal_energy,
);
let model = self.signal_model_estimator.model();
let prior_model = self.signal_model_estimator.prior_model();
const WIDTH_PRIOR_0: f32 = 4.0;
const WIDTH_PRIOR_1: f32 = 2.0 * WIDTH_PRIOR_0;
let width_prior = if model.lrt < prior_model.lrt {
WIDTH_PRIOR_1
} else {
WIDTH_PRIOR_0
};
let indicator0 = 0.5 * ((width_prior * (model.lrt - prior_model.lrt)).tanh() + 1.0);
let width_prior = if model.spectral_flatness > prior_model.flatness_threshold {
WIDTH_PRIOR_1
} else {
WIDTH_PRIOR_0
};
let indicator1 = 0.5
* ((width_prior * (prior_model.flatness_threshold - model.spectral_flatness)).tanh()
+ 1.0);
let width_prior = if model.spectral_diff < prior_model.template_diff_threshold {
WIDTH_PRIOR_1
} else {
WIDTH_PRIOR_0
};
let indicator2 = 0.5
* ((width_prior * (model.spectral_diff - prior_model.template_diff_threshold)).tanh()
+ 1.0);
let ind_prior = prior_model.lrt_weighting * indicator0
+ prior_model.flatness_weighting * indicator1
+ prior_model.difference_weighting * indicator2;
self.prior_speech_prob += 0.1 * (ind_prior - self.prior_speech_prob);
self.prior_speech_prob = self.prior_speech_prob.clamp(0.01, 1.0);
let gain_prior = (1.0 - self.prior_speech_prob) / (self.prior_speech_prob + 0.0001);
let mut inv_lrt = [0.0f32; FFT_SIZE_BY_2_PLUS_1];
exp_approximation_sign_flip(&model.avg_log_lrt, &mut inv_lrt);
for (sp, &il) in self.speech_probability.iter_mut().zip(inv_lrt.iter()) {
*sp = 1.0 / (1.0 + gain_prior * il);
}
}
pub(crate) fn prior_probability(&self) -> f32 {
self.prior_speech_prob
}
pub(crate) fn probability(&self) -> &[f32; FFT_SIZE_BY_2_PLUS_1] {
&self.speech_probability
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `frame_count` updates (frame indices `0..frame_count`) with flat
    /// inputs: every SNR entry is `snr`, every signal entry is `signal_level`,
    /// every noise entry is 1.0, and both the spectral sum and the energy are
    /// the sum of the signal array.
    fn feed_flat_frames(
        est: &mut SpeechProbabilityEstimator,
        frame_count: i32,
        snr: f32,
        signal_level: f32,
    ) {
        let prior_snr = [snr; FFT_SIZE_BY_2_PLUS_1];
        let post_snr = [snr; FFT_SIZE_BY_2_PLUS_1];
        let noise = [1.0f32; FFT_SIZE_BY_2_PLUS_1];
        let signal = [signal_level; FFT_SIZE_BY_2_PLUS_1];
        let sum: f32 = signal.iter().sum();

        for frame in 0..frame_count {
            let analysis = SignalAnalysis {
                num_analyzed_frames: frame,
                prior_snr: &prior_snr,
                post_snr: &post_snr,
                conservative_noise_spectrum: &noise,
                signal_spectrum: &signal,
                signal_spectral_sum: sum,
                signal_energy: sum,
            };
            est.update(&analysis);
        }
    }

    /// Arithmetic mean of the estimator's current probability array.
    fn mean_probability(est: &SpeechProbabilityEstimator) -> f32 {
        est.probability().iter().sum::<f32>() / FFT_SIZE_BY_2_PLUS_1 as f32
    }

    /// A fresh estimator reports a 0.5 prior and all-zero probabilities.
    #[test]
    fn initial_state() {
        let est = SpeechProbabilityEstimator::default();

        assert_eq!(est.prior_probability(), 0.5);
        assert_eq!(est.probability(), &[0.0; FFT_SIZE_BY_2_PLUS_1]);
    }

    /// After a single update every probability and the prior stay in range.
    #[test]
    fn update_produces_valid_probabilities() {
        let mut est = SpeechProbabilityEstimator::default();
        feed_flat_frames(&mut est, 1, 1.0, 10.0);

        for p in est.probability().iter().copied() {
            assert!(
                (0.0..=1.0).contains(&p),
                "probability {p} out of range [0, 1]"
            );
        }

        let prior = est.prior_probability();
        assert!(prior >= 0.01);
        assert!(prior <= 1.0);
    }

    /// 100 frames of strong signal over unit noise push the mean above 0.5.
    #[test]
    fn high_snr_gives_high_speech_probability() {
        let mut est = SpeechProbabilityEstimator::default();
        feed_flat_frames(&mut est, 100, 10.0, 100.0);

        let avg_prob = mean_probability(&est);
        assert!(
            avg_prob > 0.5,
            "avg speech prob {avg_prob} should be > 0.5 with high SNR"
        );
    }

    /// 100 frames of signal equal to the noise keep the mean below 0.5.
    #[test]
    fn low_snr_gives_low_speech_probability() {
        let mut est = SpeechProbabilityEstimator::default();
        feed_flat_frames(&mut est, 100, 0.01, 1.0);

        let avg_prob = mean_probability(&est);
        assert!(
            avg_prob < 0.5,
            "avg speech prob {avg_prob} should be < 0.5 with low SNR"
        );
    }
}