use crate::common::{
FRAME_DURATION_MS, LEVEL_ESTIMATOR_LEAK_FACTOR, LEVEL_ESTIMATOR_TIME_TO_CONFIDENCE_MS,
SATURATION_PROTECTOR_INITIAL_HEADROOM_DB, VAD_CONFIDENCE_THRESHOLD,
};
/// Limits a speech level estimate to the inclusive range [-90, 30] dBFS.
///
/// Values inside the range are returned unchanged. A NaN input is returned as
/// NaN (this is the documented behavior of `f32::clamp`).
fn clamp_level_estimate_dbfs(level_estimate_dbfs: f32) -> f32 {
    const MIN_LEVEL_DBFS: f32 = -90.0;
    const MAX_LEVEL_DBFS: f32 = 30.0;
    f32::clamp(level_estimate_dbfs, MIN_LEVEL_DBFS, MAX_LEVEL_DBFS)
}
/// Tuning parameters for the adaptive digital gain stage.
///
/// In this file only `headroom_db` and `initial_gain_db` are read: they are
/// subtracted (together with the saturation protector's initial headroom)
/// to derive the speech level estimator's initial estimate. The remaining
/// fields are presumably consumed by other gain-control components — TODO
/// confirm against their users.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AdaptiveDigitalConfig {
    /// Headroom in dB; subtracted when deriving the initial speech level.
    pub headroom_db: f32,
    /// Maximum gain in dB (not read in this file).
    pub max_gain_db: f32,
    /// Initial gain in dB; subtracted when deriving the initial speech level.
    pub initial_gain_db: f32,
    /// Maximum gain change rate in dB per second (not read in this file).
    pub max_gain_change_db_per_second: f32,
    /// Maximum output noise level in dBFS (not read in this file).
    pub max_output_noise_level_dbfs: f32,
}

impl Default for AdaptiveDigitalConfig {
    /// Default tuning: 5 dB headroom, 50 dB maximum gain, 15 dB initial gain,
    /// 6 dB/s maximum gain change and a -50 dBFS maximum output noise level.
    fn default() -> Self {
        Self {
            headroom_db: 5.0,
            max_gain_db: 50.0,
            initial_gain_db: 15.0,
            max_gain_change_db_per_second: 6.0,
            max_output_noise_level_dbfs: -50.0,
        }
    }
}
fn get_initial_speech_level_estimate_dbfs(config: &AdaptiveDigitalConfig) -> f32 {
clamp_level_estimate_dbfs(
-SATURATION_PROTECTOR_INITIAL_HEADROOM_DB - config.initial_gain_db - config.headroom_db,
)
}
/// Accumulated data of one speech level estimator state.
#[derive(Debug, Clone, Copy)]
struct LevelEstimatorState {
    /// Remaining time (ms) until the estimate is considered confident. While
    /// non-zero, speech frames count it down and are averaged without leak;
    /// once it reaches 0, the average becomes leaky.
    time_to_confidence_ms: i32,
    /// Speech-probability-weighted average of frame RMS levels, in dBFS.
    level_dbfs: Ratio,
}
/// A weighted average kept as a separate numerator (weighted sum) and
/// denominator (sum of weights), so samples can be accumulated incrementally.
#[derive(Debug, Clone, Copy)]
struct Ratio {
    numerator: f32,
    denominator: f32,
}

impl Ratio {
    /// Returns `numerator / denominator`.
    ///
    /// The denominator must be non-zero; this is only checked in debug builds.
    fn get_ratio(&self) -> f32 {
        let Self {
            numerator,
            denominator,
        } = *self;
        debug_assert!(denominator != 0.0);
        numerator / denominator
    }
}
/// Speech level estimator driven by per-frame RMS levels and speech
/// probabilities.
///
/// The estimate is a speech-probability-weighted average of frame RMS levels.
/// Two copies of the averaging state are kept: the *preliminary* state is
/// updated on every speech frame, while the *reliable* state is committed from
/// it only when a run of adjacent speech frames reaches
/// `adjacent_speech_frames_threshold`. A shorter run is discarded by rolling
/// the preliminary state back to the reliable one. With a threshold of 1 the
/// reliable state is never consulted.
#[derive(Debug)]
pub struct SpeechLevelEstimator {
    /// Clamped starting estimate derived from the configuration.
    initial_speech_level_dbfs: f32,
    /// Number of adjacent speech frames required before updates are trusted.
    adjacent_speech_frames_threshold: i32,
    /// Averaging state updated on every speech frame.
    preliminary_state: LevelEstimatorState,
    /// Last averaging state confirmed by a long enough speech run.
    reliable_state: LevelEstimatorState,
    /// Current estimate returned by `level_dbfs()`; always clamped.
    level_dbfs: f32,
    /// Confidence flag returned by `is_confident()`.
    is_confident: bool,
    /// Length of the current run of speech frames; zeroed on non-speech.
    num_adjacent_speech_frames: i32,
}

impl SpeechLevelEstimator {
    /// Creates an estimator in its reset state (see [`Self::reset`]).
    ///
    /// Level updates are only reflected in [`Self::level_dbfs`] once at least
    /// `adjacent_speech_frames_threshold` adjacent speech frames are observed.
    ///
    /// # Panics
    ///
    /// In debug builds, if `adjacent_speech_frames_threshold < 1`.
    pub fn new(config: &AdaptiveDigitalConfig, adjacent_speech_frames_threshold: i32) -> Self {
        debug_assert!(adjacent_speech_frames_threshold >= 1);
        let initial_speech_level_dbfs = get_initial_speech_level_estimate_dbfs(config);
        // Build the reset state directly rather than filling placeholder
        // states and overwriting them via `reset()`.
        let initial_state = Self::initial_state(initial_speech_level_dbfs);
        Self {
            initial_speech_level_dbfs,
            adjacent_speech_frames_threshold,
            preliminary_state: initial_state,
            reliable_state: initial_state,
            level_dbfs: initial_speech_level_dbfs,
            is_confident: false,
            num_adjacent_speech_frames: 0,
        }
    }

    /// Feeds one frame with RMS level `rms_dbfs` and speech probability
    /// `speech_probability`.
    ///
    /// Frames whose probability is below `VAD_CONFIDENCE_THRESHOLD` are
    /// treated as non-speech: they do not change the average but terminate
    /// the current speech run, committing or discarding it.
    ///
    /// # Panics
    ///
    /// In debug builds, if `rms_dbfs` is outside (-150, 50) or
    /// `speech_probability` is outside [0, 1].
    pub fn update(&mut self, rms_dbfs: f32, speech_probability: f32) {
        debug_assert!(rms_dbfs > -150.0);
        debug_assert!(rms_dbfs < 50.0);
        debug_assert!(speech_probability >= 0.0);
        debug_assert!(speech_probability <= 1.0);
        if speech_probability < VAD_CONFIDENCE_THRESHOLD {
            // Non-speech frame: it terminates the current speech run, if any.
            if self.adjacent_speech_frames_threshold > 1 {
                if self.num_adjacent_speech_frames >= self.adjacent_speech_frames_threshold {
                    // The run was long enough: confirm the preliminary updates.
                    self.reliable_state = self.preliminary_state;
                } else if self.num_adjacent_speech_frames > 0 {
                    // The run was too short: discard its updates.
                    self.preliminary_state = self.reliable_state;
                }
            }
            self.num_adjacent_speech_frames = 0;
        } else {
            // Speech frame. Saturate instead of overflowing on an extremely
            // long run: once past the threshold the exact count is irrelevant.
            self.num_adjacent_speech_frames = self.num_adjacent_speech_frames.saturating_add(1);

            let state = &mut self.preliminary_state;
            debug_assert!(state.time_to_confidence_ms >= 0);
            let buffer_is_full = state.time_to_confidence_ms == 0;
            if !buffer_is_full {
                // Clamp at zero so the countdown still terminates even if the
                // confidence time is not a multiple of the frame duration.
                state.time_to_confidence_ms =
                    (state.time_to_confidence_ms - FRAME_DURATION_MS).max(0);
            }
            // Weighted average of levels with the speech probability as weight.
            // Until confidence is reached all frames are kept (leak factor 1);
            // afterwards past contributions are scaled by the leak factor
            // (presumably < 1 — see LEVEL_ESTIMATOR_LEAK_FACTOR).
            debug_assert!(speech_probability > 0.0);
            let leak_factor = if buffer_is_full {
                LEVEL_ESTIMATOR_LEAK_FACTOR
            } else {
                1.0
            };
            state.level_dbfs.numerator =
                state.level_dbfs.numerator * leak_factor + rms_dbfs * speech_probability;
            state.level_dbfs.denominator =
                state.level_dbfs.denominator * leak_factor + speech_probability;
            let level_dbfs = state.level_dbfs.get_ratio();

            // Only expose the new estimate once the run is long enough.
            if self.num_adjacent_speech_frames >= self.adjacent_speech_frames_threshold {
                self.level_dbfs = clamp_level_estimate_dbfs(level_dbfs);
            }
        }
        self.update_is_confident();
    }

    /// Returns the current speech level estimate in dBFS, clamped to
    /// [-90, 30].
    pub fn level_dbfs(&self) -> f32 {
        self.level_dbfs
    }

    /// Returns whether enough speech has been accumulated for the estimate
    /// to be considered confident.
    pub fn is_confident(&self) -> bool {
        self.is_confident
    }

    /// Restores the state produced by [`Self::new`]: both averaging states
    /// are re-seeded with the initial level, the level estimate returns to
    /// the initial value, and confidence and the speech run are cleared.
    pub fn reset(&mut self) {
        let initial_state = Self::initial_state(self.initial_speech_level_dbfs);
        self.preliminary_state = initial_state;
        self.reliable_state = initial_state;
        self.level_dbfs = self.initial_speech_level_dbfs;
        self.is_confident = false;
        self.num_adjacent_speech_frames = 0;
    }

    /// Recomputes `is_confident` from the current states.
    fn update_is_confident(&mut self) {
        if self.adjacent_speech_frames_threshold == 1 {
            // The reliable state is unused when a single frame suffices.
            self.is_confident = self.preliminary_state.time_to_confidence_ms == 0;
            return;
        }
        // Once the reliable state is confident the preliminary one is too, so
        // confidence, once reached, is never lost (until `reset`).
        debug_assert!(
            self.reliable_state.time_to_confidence_ms != 0
                || self.preliminary_state.time_to_confidence_ms == 0
        );
        // During the first long enough speech run the reliable state has not
        // been committed yet, so the preliminary state is consulted as well.
        self.is_confident = self.reliable_state.time_to_confidence_ms == 0
            || (self.num_adjacent_speech_frames >= self.adjacent_speech_frames_threshold
                && self.preliminary_state.time_to_confidence_ms == 0);
    }

    /// Returns a fresh averaging state: the confidence countdown restarted
    /// and the average seeded with `initial_speech_level_dbfs` at unit weight.
    fn initial_state(initial_speech_level_dbfs: f32) -> LevelEstimatorState {
        LevelEstimatorState {
            time_to_confidence_ms: LEVEL_ESTIMATOR_TIME_TO_CONFIDENCE_MS as i32,
            level_dbfs: Ratio {
                numerator: initial_speech_level_dbfs,
                denominator: 1.0,
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const NUM_FRAMES_TO_CONFIDENCE: i32 =
        LEVEL_ESTIMATOR_TIME_TO_CONFIDENCE_MS as i32 / FRAME_DURATION_MS;
    const CONVERGENCE_SPEED_TESTS_LEVEL_TOLERANCE: f32 = 0.5;
    const NO_SPEECH_PROBABILITY: f32 = 0.0;
    const LOW_SPEECH_PROBABILITY: f32 = VAD_CONFIDENCE_THRESHOLD / 2.0;
    const MAX_SPEECH_PROBABILITY: f32 = 1.0;

    /// Feeds `num_iterations` identical frames to `level_estimator`.
    fn run_on_constant_level(
        num_iterations: i32,
        rms_dbfs: f32,
        speech_probability: f32,
        level_estimator: &mut SpeechLevelEstimator,
    ) {
        for _ in 0..num_iterations {
            level_estimator.update(rms_dbfs, speech_probability);
        }
    }

    /// Test fixture: an estimator built from the default config plus test
    /// levels placed well above its initial estimate.
    struct TestLevelEstimator {
        estimator: SpeechLevelEstimator,
        initial_speech_level_dbfs: f32,
        level_rms_dbfs: f32,
        level_peak_dbfs: f32,
    }

    impl TestLevelEstimator {
        fn new(adjacent_speech_frames_threshold: i32) -> Self {
            let config = AdaptiveDigitalConfig::default();
            let estimator = SpeechLevelEstimator::new(&config, adjacent_speech_frames_threshold);
            let initial_speech_level_dbfs = estimator.level_dbfs();
            let level_rms_dbfs = initial_speech_level_dbfs / 2.0;
            let level_peak_dbfs = initial_speech_level_dbfs / 3.0;
            // Plain `assert!` (not `debug_assert!`) so the fixture's
            // preconditions are also checked in `--release` test runs.
            assert!(level_rms_dbfs < level_peak_dbfs);
            assert!(initial_speech_level_dbfs < level_rms_dbfs);
            assert!(
                level_rms_dbfs - initial_speech_level_dbfs > 5.0,
                "Adjust `level_rms_dbfs` so that the difference from the initial \
                 level is wide enough for the tests"
            );
            Self {
                estimator,
                initial_speech_level_dbfs,
                level_rms_dbfs,
                level_peak_dbfs,
            }
        }
    }

    #[test]
    fn level_stabilizes() {
        let mut t = TestLevelEstimator::new(1);
        run_on_constant_level(
            NUM_FRAMES_TO_CONFIDENCE,
            t.level_rms_dbfs,
            MAX_SPEECH_PROBABILITY,
            &mut t.estimator,
        );
        let estimated_level_dbfs = t.estimator.level_dbfs();
        run_on_constant_level(
            1,
            t.level_rms_dbfs,
            MAX_SPEECH_PROBABILITY,
            &mut t.estimator,
        );
        assert!(
            (t.estimator.level_dbfs() - estimated_level_dbfs).abs() < 0.1,
            "level {} should be near {}",
            t.estimator.level_dbfs(),
            estimated_level_dbfs
        );
    }

    #[test]
    fn is_not_confident() {
        let mut t = TestLevelEstimator::new(1);
        run_on_constant_level(
            NUM_FRAMES_TO_CONFIDENCE / 2,
            t.level_rms_dbfs,
            MAX_SPEECH_PROBABILITY,
            &mut t.estimator,
        );
        assert!(!t.estimator.is_confident());
    }

    #[test]
    fn is_confident() {
        let mut t = TestLevelEstimator::new(1);
        run_on_constant_level(
            NUM_FRAMES_TO_CONFIDENCE,
            t.level_rms_dbfs,
            MAX_SPEECH_PROBABILITY,
            &mut t.estimator,
        );
        assert!(t.estimator.is_confident());
    }

    #[test]
    fn estimator_ignores_non_speech_frames() {
        let mut t = TestLevelEstimator::new(1);
        run_on_constant_level(
            NUM_FRAMES_TO_CONFIDENCE,
            t.level_rms_dbfs,
            MAX_SPEECH_PROBABILITY,
            &mut t.estimator,
        );
        let estimated_level_dbfs = t.estimator.level_dbfs();
        run_on_constant_level(
            NUM_FRAMES_TO_CONFIDENCE,
            0.0,
            NO_SPEECH_PROBABILITY,
            &mut t.estimator,
        );
        assert_eq!(t.estimator.level_dbfs(), estimated_level_dbfs);
    }

    #[test]
    fn convergence_speed_before_confidence() {
        let mut t = TestLevelEstimator::new(1);
        run_on_constant_level(
            NUM_FRAMES_TO_CONFIDENCE,
            t.level_rms_dbfs,
            MAX_SPEECH_PROBABILITY,
            &mut t.estimator,
        );
        assert!(
            (t.estimator.level_dbfs() - t.level_rms_dbfs).abs()
                <= CONVERGENCE_SPEED_TESTS_LEVEL_TOLERANCE,
            "level {} should be near {}",
            t.estimator.level_dbfs(),
            t.level_rms_dbfs
        );
    }

    #[test]
    fn convergence_speed_after_confidence() {
        let mut t = TestLevelEstimator::new(1);
        run_on_constant_level(
            NUM_FRAMES_TO_CONFIDENCE,
            t.initial_speech_level_dbfs,
            MAX_SPEECH_PROBABILITY,
            &mut t.estimator,
        );
        assert_eq!(t.estimator.level_dbfs(), t.initial_speech_level_dbfs);
        assert!(t.estimator.is_confident());
        let convergence_time_after_confidence_num_frames = 700;
        assert!(convergence_time_after_confidence_num_frames > NUM_FRAMES_TO_CONFIDENCE);
        run_on_constant_level(
            convergence_time_after_confidence_num_frames,
            t.level_rms_dbfs,
            MAX_SPEECH_PROBABILITY,
            &mut t.estimator,
        );
        assert!(
            (t.estimator.level_dbfs() - t.level_rms_dbfs).abs()
                <= CONVERGENCE_SPEED_TESTS_LEVEL_TOLERANCE,
            "level {} should be near {}",
            t.estimator.level_dbfs(),
            t.level_rms_dbfs
        );
    }

    #[test]
    fn reset_restores_initial_state() {
        let mut t = TestLevelEstimator::new(1);
        run_on_constant_level(
            NUM_FRAMES_TO_CONFIDENCE,
            t.level_rms_dbfs,
            MAX_SPEECH_PROBABILITY,
            &mut t.estimator,
        );
        assert!(t.estimator.is_confident());
        assert_ne!(t.estimator.level_dbfs(), t.initial_speech_level_dbfs);
        t.estimator.reset();
        assert!(!t.estimator.is_confident());
        assert_eq!(t.estimator.level_dbfs(), t.initial_speech_level_dbfs);
    }

    #[test]
    fn confidence_persists_after_long_speech_segment() {
        let threshold = 9;
        let mut t = TestLevelEstimator::new(threshold);
        run_on_constant_level(
            NUM_FRAMES_TO_CONFIDENCE.max(threshold),
            t.level_rms_dbfs,
            MAX_SPEECH_PROBABILITY,
            &mut t.estimator,
        );
        assert!(t.estimator.is_confident());
        // A non-speech frame commits the long segment; confidence must remain.
        run_on_constant_level(1, 0.0, NO_SPEECH_PROBABILITY, &mut t.estimator);
        assert!(t.estimator.is_confident());
    }

    #[test]
    fn do_not_adapt_to_short_speech_segments_threshold_1() {
        do_not_adapt_to_short_speech_segments(1);
    }

    #[test]
    fn do_not_adapt_to_short_speech_segments_threshold_9() {
        do_not_adapt_to_short_speech_segments(9);
    }

    #[test]
    fn do_not_adapt_to_short_speech_segments_threshold_17() {
        do_not_adapt_to_short_speech_segments(17);
    }

    fn do_not_adapt_to_short_speech_segments(threshold: i32) {
        let mut t = TestLevelEstimator::new(threshold);
        let initial_level = t.estimator.level_dbfs();
        assert!(initial_level < t.level_peak_dbfs);
        for _ in 0..threshold - 1 {
            t.estimator.update(t.level_rms_dbfs, MAX_SPEECH_PROBABILITY);
            assert_eq!(
                initial_level,
                t.estimator.level_dbfs(),
                "level should not change before threshold"
            );
        }
        t.estimator.update(t.level_rms_dbfs, LOW_SPEECH_PROBABILITY);
        assert_eq!(
            initial_level,
            t.estimator.level_dbfs(),
            "level should not change after low-probability frame"
        );
    }

    #[test]
    fn adapt_to_enough_speech_segments_threshold_1() {
        adapt_to_enough_speech_segments(1);
    }

    #[test]
    fn adapt_to_enough_speech_segments_threshold_9() {
        adapt_to_enough_speech_segments(9);
    }

    #[test]
    fn adapt_to_enough_speech_segments_threshold_17() {
        adapt_to_enough_speech_segments(17);
    }

    fn adapt_to_enough_speech_segments(threshold: i32) {
        let mut t = TestLevelEstimator::new(threshold);
        let initial_level = t.estimator.level_dbfs();
        assert!(initial_level < t.level_peak_dbfs);
        for _ in 0..threshold {
            t.estimator.update(t.level_rms_dbfs, MAX_SPEECH_PROBABILITY);
        }
        assert!(
            initial_level < t.estimator.level_dbfs(),
            "level should increase after enough speech frames"
        );
    }
}