use crate::common::{FRAME_DURATION_MS, VAD_RESET_PERIOD_MS};
use crate::rnn_vad::common::{FRAME_SIZE_10MS_24K_HZ, FeatureVector, SAMPLE_RATE_24K_HZ};
use crate::rnn_vad::features_extraction::FeaturesExtractor;
use crate::rnn_vad::rnn::RnnVad;
use sonora_common_audio::push_resampler::PushResampler;
use sonora_simd::SimdBackend;
/// Number of frames analyzed per second of audio (i.e. each frame spans 10 ms).
///
/// Dividing a sample rate in Hz by this value yields the number of samples in
/// one frame at that rate.
const NUM_FRAMES_PER_SECOND: i32 = 100;
/// Single-channel voice activity detector that can be plugged into
/// [`VoiceActivityDetectorWrapper`].
pub trait MonoVad: Send + Sync {
/// Sample rate, in Hz, of the frames this detector expects in [`MonoVad::analyze`].
///
/// The wrapper uses this value to size the resampled frames it passes in.
fn sample_rate_hz(&self) -> i32;
/// Resets the detector.
///
/// The wrapper calls this once at construction and then periodically while
/// analyzing.
fn reset(&mut self);
/// Analyzes one frame of mono audio and returns the detector's speech
/// probability estimate for it.
fn analyze(&mut self, frame: &[f32]) -> f32;
}
/// [`MonoVad`] implementation that feeds features computed by a
/// [`FeaturesExtractor`] into an [`RnnVad`].
struct RnnMonoVad {
/// Computes the feature vector (and the silence flag) for each frame.
features_extractor: FeaturesExtractor,
/// Turns a feature vector into a voice activity probability.
rnn_vad: RnnVad,
/// Scratch storage for the features of the frame being analyzed; it is
/// overwritten on every call to `analyze`.
feature_vector: FeatureVector,
}
impl RnnMonoVad {
fn new(backend: SimdBackend) -> Self {
Self {
features_extractor: FeaturesExtractor::new(backend),
rnn_vad: RnnVad::new(backend),
feature_vector: bytemuck::Zeroable::zeroed(),
}
}
}
impl MonoVad for RnnMonoVad {
    /// The RNN detector consumes audio sampled at `SAMPLE_RATE_24K_HZ`.
    fn sample_rate_hz(&self) -> i32 {
        SAMPLE_RATE_24K_HZ
    }

    /// Resets the RNN; the feature extractor and the cached feature vector are
    /// left untouched.
    fn reset(&mut self) {
        self.rnn_vad.reset();
    }

    /// Extracts the features of `frame` and runs them through the RNN.
    ///
    /// `frame` is expected to hold exactly `FRAME_SIZE_10MS_24K_HZ` samples;
    /// this is only checked in debug builds.
    fn analyze(&mut self, frame: &[f32]) -> f32 {
        debug_assert_eq!(frame.len(), FRAME_SIZE_10MS_24K_HZ);

        // Step 1: refresh the cached features and learn whether the extractor
        // classified the frame as silence.
        let silence_flag = self
            .features_extractor
            .check_silence_compute_features(frame, &mut self.feature_vector);

        // Step 2: let the network score the freshly computed features.
        let features = &self.feature_vector;
        self.rnn_vad.compute_vad_probability(features, silence_flag)
    }
}
/// Wraps a [`MonoVad`] so that it can be fed frames at an arbitrary input
/// sample rate: each frame is resampled to the detector's own rate before
/// analysis, and the detector is reset periodically.
#[derive(derive_more::Debug)]
pub struct VoiceActivityDetectorWrapper {
/// Number of analyzed frames between two periodic resets of the detector.
vad_reset_period_frames: i32,
/// Number of samples in one input frame (input sample rate divided by
/// `NUM_FRAMES_PER_SECOND`).
frame_size: usize,
/// Countdown, in frames, until the next periodic reset.
time_to_vad_reset: i32,
/// The wrapped detector.
#[debug(skip)]
vad: Box<dyn MonoVad>,
/// Holds the current frame after resampling to the detector's sample rate.
#[debug(skip)]
resampled_buffer: Vec<f32>,
/// Converts input frames of `frame_size` samples into the detector's frame size.
#[debug(skip)]
resampler: PushResampler<f32>,
}
impl VoiceActivityDetectorWrapper {
    /// Creates a wrapper around the RNN-based detector that uses the default
    /// reset period, `VAD_RESET_PERIOD_MS`.
    ///
    /// `sample_rate_hz` is the sample rate of the frames later passed to
    /// [`Self::analyze`].
    pub fn new(backend: SimdBackend, sample_rate_hz: i32) -> Self {
        Self::with_reset_period(VAD_RESET_PERIOD_MS, backend, sample_rate_hz)
    }

    /// Same as [`Self::new`], but with a caller-chosen reset period given in
    /// milliseconds.
    pub fn with_reset_period(
        vad_reset_period_ms: i32,
        backend: SimdBackend,
        sample_rate_hz: i32,
    ) -> Self {
        let rnn_vad: Box<dyn MonoVad> = Box::new(RnnMonoVad::new(backend));
        Self::with_vad(vad_reset_period_ms, rnn_vad, sample_rate_hz)
    }

    /// Creates a wrapper around an arbitrary detector.
    ///
    /// * `vad_reset_period_ms` - reset period in milliseconds; it is converted
    ///   to a whole number of frames by dividing by `FRAME_DURATION_MS`, and
    ///   the result is expected to be greater than one (checked in debug
    ///   builds only).
    /// * `vad` - detector to wrap; it is reset once before this function
    ///   returns.
    /// * `sample_rate_hz` - sample rate of the frames later passed to
    ///   [`Self::analyze`].
    pub fn with_vad(
        vad_reset_period_ms: i32,
        mut vad: Box<dyn MonoVad>,
        sample_rate_hz: i32,
    ) -> Self {
        let vad_reset_period_frames = vad_reset_period_ms / FRAME_DURATION_MS;
        debug_assert!(vad_reset_period_frames > 1);

        // Samples contained in one frame at the given rate.
        let samples_per_frame = |rate_hz: i32| (rate_hz / NUM_FRAMES_PER_SECOND) as usize;
        let frame_size = samples_per_frame(sample_rate_hz);
        let resampled_size = samples_per_frame(vad.sample_rate_hz());

        let resampler = PushResampler::new(frame_size, resampled_size, 1);

        // Hand out the detector in a freshly reset state.
        vad.reset();

        let resampled_buffer = vec![0.0; resampled_size];
        Self {
            vad_reset_period_frames,
            frame_size,
            time_to_vad_reset: vad_reset_period_frames,
            vad,
            resampled_buffer,
            resampler,
        }
    }

    /// Analyzes one frame of mono audio and returns the value reported by the
    /// wrapped detector.
    ///
    /// Only the first `frame_size` samples of `frame` are used. Before the
    /// analysis the reset countdown is advanced, and the detector is reset
    /// whenever the countdown expires.
    ///
    /// # Panics
    ///
    /// Panics if `frame` holds fewer samples than one input frame.
    pub fn analyze(&mut self, frame: &[f32]) -> f32 {
        // Advance the countdown; once it runs out, reset and start over.
        self.time_to_vad_reset -= 1;
        let reset_due = self.time_to_vad_reset <= 0;
        if reset_due {
            self.vad.reset();
            self.time_to_vad_reset = self.vad_reset_period_frames;
        }

        let input_len = self.frame_size;
        debug_assert!(frame.len() >= input_len);
        let input = &frame[..input_len];

        // Bring the frame to the detector's sample rate, then analyze it.
        self.resampler
            .resample_mono(input, &mut self.resampled_buffer);
        let resampled = &self.resampled_buffer;
        self.vad.analyze(resampled)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Input sample rate used by the tests that do not sweep over rates.
    const SAMPLE_RATE_8K_HZ: i32 = 8000;

    /// Largest reset period, in milliseconds, that is a whole number of
    /// frames; long enough that no periodic reset happens during a test.
    const NO_VAD_PERIODIC_RESET: i32 = FRAME_DURATION_MS * (i32::MAX / FRAME_DURATION_MS);

    /// Handle through which a test observes what happened to its mock.
    type SharedState = Arc<Mutex<MockVadState>>;

    /// Calls recorded by [`MockVad`].
    #[derive(Default)]
    struct MockVadState {
        /// How many times `reset` was invoked.
        reset_count: usize,
        /// Length of every frame handed to `analyze`, in call order.
        analyze_frames: Vec<usize>,
    }

    /// Detector double that replays a fixed list of probabilities (cycling
    /// when exhausted) and records the calls it receives.
    struct MockVad {
        sample_rate: i32,
        probabilities: Vec<f32>,
        prob_index: usize,
        state: SharedState,
    }

    impl MockVad {
        fn new(sample_rate: i32, probabilities: Vec<f32>, state: SharedState) -> Self {
            Self {
                sample_rate,
                probabilities,
                prob_index: 0,
                state,
            }
        }
    }

    impl MonoVad for MockVad {
        fn sample_rate_hz(&self) -> i32 {
            self.sample_rate
        }

        fn reset(&mut self) {
            let mut shared = self.state.lock().unwrap();
            shared.reset_count += 1;
        }

        fn analyze(&mut self, frame: &[f32]) -> f32 {
            self.state.lock().unwrap().analyze_frames.push(frame.len());
            let slot = self.prob_index % self.probabilities.len();
            self.prob_index += 1;
            self.probabilities[slot]
        }
    }

    /// Fresh, empty call record.
    fn new_state() -> SharedState {
        Arc::new(Mutex::new(MockVadState::default()))
    }

    /// Number of resets recorded so far.
    fn recorded_resets(state: &SharedState) -> usize {
        state.lock().unwrap().reset_count
    }

    /// Wraps a [`MockVad`] whose own sample rate equals the input sample rate.
    fn create_mock_vad_wrapper(
        vad_reset_period_ms: i32,
        sample_rate_hz: i32,
        probabilities: Vec<f32>,
        state: SharedState,
    ) -> VoiceActivityDetectorWrapper {
        let mock = Box::new(MockVad::new(sample_rate_hz, probabilities, state));
        VoiceActivityDetectorWrapper::with_vad(vad_reset_period_ms, mock, sample_rate_hz)
    }

    /// One all-zero frame at the given sample rate.
    fn make_frame(sample_rate_hz: i32) -> Vec<f32> {
        let samples = (sample_rate_hz / NUM_FRAMES_PER_SECOND) as usize;
        vec![0.0; samples]
    }

    /// The wrapper must forward the detector's output unchanged.
    #[test]
    fn check_speech_probabilities() {
        let probabilities: Vec<f32> = vec![
            0.709, 0.484, 0.882, 0.167, 0.44, 0.525, 0.858, 0.314, 0.653, 0.965, 0.413, 0.0,
        ];
        let state = new_state();
        let mut wrapper = create_mock_vad_wrapper(
            NO_VAD_PERIODIC_RESET,
            SAMPLE_RATE_8K_HZ,
            probabilities.clone(),
            Arc::clone(&state),
        );
        let frame = make_frame(SAMPLE_RATE_8K_HZ);

        for (i, expected) in probabilities.iter().copied().enumerate() {
            let actual = wrapper.analyze(&frame);
            assert_eq!(
                expected, actual,
                "mismatch at frame {i}: expected {expected}, got {actual}"
            );
        }
    }

    /// With an effectively infinite period only the construction-time reset
    /// is observed.
    #[test]
    fn vad_no_periodic_reset() {
        let state = new_state();
        let mut wrapper = create_mock_vad_wrapper(
            NO_VAD_PERIODIC_RESET,
            SAMPLE_RATE_8K_HZ,
            vec![1.0],
            Arc::clone(&state),
        );
        let frame = make_frame(SAMPLE_RATE_8K_HZ);

        (0..19).for_each(|_| {
            wrapper.analyze(&frame);
        });

        assert_eq!(recorded_resets(&state), 1);
    }

    /// One reset at construction plus one for every full period of frames.
    #[test]
    fn vad_periodic_reset() {
        for num_frames in [1_i32, 19, 123] {
            for vad_reset_period_frames in [2_i32, 5, 20, 53] {
                let state = new_state();
                let mut wrapper = create_mock_vad_wrapper(
                    vad_reset_period_frames * FRAME_DURATION_MS,
                    SAMPLE_RATE_8K_HZ,
                    vec![1.0],
                    Arc::clone(&state),
                );
                let frame = make_frame(SAMPLE_RATE_8K_HZ);

                for _ in 0..num_frames {
                    wrapper.analyze(&frame);
                }

                let expected_resets = 1 + num_frames / vad_reset_period_frames;
                assert_eq!(
                    recorded_resets(&state),
                    expected_resets as usize,
                    "num_frames={num_frames}, period={vad_reset_period_frames}"
                );
            }
        }
    }

    /// The detector must receive frames sized for its own sample rate,
    /// whatever the input rate is.
    #[test]
    fn check_resampled_frame_size() {
        for input_rate in [8000, 16000, 44100, 48000] {
            for vad_rate in [6000, 8000, 12000, 16000, 24000] {
                let state = new_state();
                let mock = MockVad::new(vad_rate, vec![1.0], Arc::clone(&state));
                let mut wrapper = VoiceActivityDetectorWrapper::with_vad(
                    NO_VAD_PERIODIC_RESET,
                    Box::new(mock),
                    input_rate,
                );

                wrapper.analyze(&make_frame(input_rate));

                let first_frame_len = state.lock().unwrap().analyze_frames[0];
                let expected_frame_size = (vad_rate / NUM_FRAMES_PER_SECOND) as usize;
                assert_eq!(
                    first_frame_len, expected_frame_size,
                    "input_rate={input_rate}, vad_rate={vad_rate}"
                );
            }
        }
    }
}