use crate::error::VadError;
use crate::frame::{frame_samples, validate_sample_rate};
use crate::{ProcessTimings, VadCapabilities, VoiceActivityDetector};
use std::time::{Duration, Instant};
/// Operating mode requested from the WebRTC voice-activity detector.
///
/// Each variant converts one-to-one into the `webrtc_vad::VadMode` variant of
/// the same name.
///
/// NOTE(review): the detection trade-off behind each mode is defined by the
/// backend, not by this file — consult the `webrtc-vad` documentation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebRtcVadMode {
    /// Converts to `webrtc_vad::VadMode::Quality`.
    Quality,
    /// Converts to `webrtc_vad::VadMode::LowBitrate`.
    LowBitrate,
    /// Converts to `webrtc_vad::VadMode::Aggressive`.
    Aggressive,
    /// Converts to `webrtc_vad::VadMode::VeryAggressive`.
    VeryAggressive,
}
/// Translates this crate's mode enum into the backend's mode enum.
impl From<WebRtcVadMode> for webrtc_vad::VadMode {
    /// Returns the backend variant carrying the same name as `mode`.
    ///
    /// The match is exhaustive with no wildcard arm, so adding a variant to
    /// [`WebRtcVadMode`] fails to compile until it is mapped here.
    fn from(mode: WebRtcVadMode) -> Self {
        match mode {
            WebRtcVadMode::Quality => Self::Quality,
            WebRtcVadMode::LowBitrate => Self::LowBitrate,
            WebRtcVadMode::Aggressive => Self::Aggressive,
            WebRtcVadMode::VeryAggressive => Self::VeryAggressive,
        }
    }
}
/// Frame duration, in milliseconds, that `WebRtcVad::new` passes to
/// `WebRtcVad::with_frame_duration`.
const DEFAULT_FRAME_DURATION_MS: u32 = 30;
/// Voice-activity detector backed by the `webrtc_vad` crate.
///
/// Besides the backend handle, the struct remembers the configuration it was
/// built with and accumulates simple timing statistics for `process` calls.
pub struct WebRtcVad {
    /// Backend detector instance; replaced wholesale by `reset`.
    vad: webrtc_vad::Vad,
    /// Sample rate in Hz the detector was constructed with. `process` rejects
    /// calls made with any other rate.
    sample_rate: u32,
    /// Mode applied to the backend; kept so `reset` can re-apply it.
    mode: WebRtcVadMode,
    /// Configured frame duration in milliseconds, reported by `capabilities`.
    frame_duration_ms: u32,
    /// Total wall-clock time spent inside successful backend calls.
    inference_time: Duration,
    /// Number of frames the backend has processed successfully.
    timing_frames: u64,
}
// SAFETY: NOTE(review): this impl asserts that a `WebRtcVad` may be moved to
// another thread. Nothing in this file proves that on its own. Within this file
// the backend handle is only touched through `&mut self` (`process`, `reset`),
// so access is exclusive; it presumably also requires that `webrtc_vad::Vad`
// owns its native state and has no thread affinity — TODO confirm against the
// `webrtc-vad` crate before relying on it.
unsafe impl Send for WebRtcVad {}
impl WebRtcVad {
    /// Creates a detector for `sample_rate` and `mode` using the default frame
    /// duration (`DEFAULT_FRAME_DURATION_MS`).
    ///
    /// # Errors
    ///
    /// Returns whatever [`WebRtcVad::with_frame_duration`] returns for the same
    /// arguments.
    pub fn new(sample_rate: u32, mode: WebRtcVadMode) -> Result<Self, VadError> {
        Self::with_frame_duration(sample_rate, mode, DEFAULT_FRAME_DURATION_MS)
    }

    /// Creates a detector with an explicit frame duration.
    ///
    /// `frame_duration_ms` must be 10, 20 or 30. The timing counters start at
    /// zero.
    ///
    /// # Errors
    ///
    /// * Propagates the error from `validate_sample_rate` when it rejects
    ///   `sample_rate`.
    /// * Returns `VadError::InvalidFrameSize` when `frame_duration_ms` is not
    ///   10, 20 or 30. `got` is the sample count for the requested duration and
    ///   `expected` is the sample count for 30 ms, both at `sample_rate`.
    pub fn with_frame_duration(
        sample_rate: u32,
        mode: WebRtcVadMode,
        frame_duration_ms: u32,
    ) -> Result<Self, VadError> {
        // The rate is checked first so that `to_sample_rate` below is only
        // reached with a rate that passed validation.
        validate_sample_rate(sample_rate)?;

        match frame_duration_ms {
            10 | 20 | 30 => {}
            unsupported_ms => {
                return Err(VadError::InvalidFrameSize {
                    got: frame_samples(sample_rate, unsupported_ms),
                    expected: frame_samples(sample_rate, 30),
                });
            }
        }

        let mut backend = webrtc_vad::Vad::new_with_rate(to_sample_rate(sample_rate));
        backend.set_mode(mode.into());

        Ok(Self {
            vad: backend,
            sample_rate,
            mode,
            frame_duration_ms,
            inference_time: Duration::ZERO,
            timing_frames: 0,
        })
    }
}
impl VoiceActivityDetector for WebRtcVad {
    /// Reports the sample rate and frame duration this detector was built
    /// with, plus the frame length in samples derived from the two.
    fn capabilities(&self) -> VadCapabilities {
        let sample_rate = self.sample_rate;
        let frame_duration_ms = self.frame_duration_ms;
        VadCapabilities {
            sample_rate,
            frame_size: frame_samples(sample_rate, frame_duration_ms),
            frame_duration_ms,
        }
    }

    /// Classifies one frame of 16-bit samples.
    ///
    /// Returns `1.0` when the backend reports voice and `0.0` otherwise. A
    /// frame is accepted when its length equals the 10, 20 or 30 ms sample
    /// count at `sample_rate`, whatever frame duration the detector was
    /// configured with.
    ///
    /// # Errors
    ///
    /// * `VadError::InvalidSampleRate` when `sample_rate` differs from the rate
    ///   the detector was constructed with.
    /// * `VadError::InvalidFrameSize` when `samples.len()` matches none of the
    ///   three accepted lengths; `expected` carries the 10 ms length.
    /// * `VadError::BackendError` when the backend call fails.
    fn process(&mut self, samples: &[i16], sample_rate: u32) -> Result<f32, VadError> {
        if self.sample_rate != sample_rate {
            return Err(VadError::InvalidSampleRate(sample_rate));
        }

        // Sample counts for 10, 20 and 30 ms frames, in that order.
        let accepted_lens = [10, 20, 30].map(|ms| frame_samples(sample_rate, ms));
        let got = samples.len();
        if !accepted_lens.iter().any(|&len| len == got) {
            return Err(VadError::InvalidFrameSize {
                got,
                expected: accepted_lens[0],
            });
        }

        let started = Instant::now();
        let speech = match self.vad.is_voice_segment(samples) {
            Ok(flag) => flag,
            Err(()) => {
                return Err(VadError::BackendError(
                    "webrtc-vad processing error".into(),
                ));
            }
        };
        // Only successful backend calls contribute to the timing statistics.
        self.inference_time += started.elapsed();
        self.timing_frames += 1;

        Ok(if speech { 1.0 } else { 0.0 })
    }

    /// Replaces the backend with a freshly constructed one using the stored
    /// sample rate and mode. The timing counters are left untouched.
    fn reset(&mut self) {
        self.vad = {
            let mut fresh = webrtc_vad::Vad::new_with_rate(to_sample_rate(self.sample_rate));
            fresh.set_mode(self.mode.into());
            fresh
        };
    }

    /// Returns the accumulated backend time as a single `"inference"` stage,
    /// together with the number of frames it covers.
    fn timings(&self) -> ProcessTimings {
        let stages = vec![("inference", self.inference_time)];
        ProcessTimings {
            stages,
            frames: self.timing_frames,
        }
    }
}
/// Maps a sample rate in Hz onto the backend's `SampleRate` enum.
///
/// # Panics
///
/// Panics for any rate other than 8000, 16000, 32000 or 48000. Callers in this
/// file only pass a rate that went through `validate_sample_rate` first;
/// presumably that helper admits exactly these four rates — TODO confirm.
fn to_sample_rate(rate: u32) -> webrtc_vad::SampleRate {
    type Backend = webrtc_vad::SampleRate;

    match rate {
        8_000 => Backend::Rate8kHz,
        16_000 => Backend::Rate16kHz,
        32_000 => Backend::Rate32kHz,
        48_000 => Backend::Rate48kHz,
        _ => unreachable!("sample rate validated before calling to_sample_rate"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample rate used by every processing test.
    const RATE_HZ: u32 = 16000;
    /// 160 samples: 10 ms of audio at `RATE_HZ`.
    const FRAME_LEN: usize = 160;

    /// Builds a detector at `RATE_HZ` with the default frame duration.
    fn detector(mode: WebRtcVadMode) -> WebRtcVad {
        WebRtcVad::new(RATE_HZ, mode).unwrap()
    }

    /// Every rate the backend mapping knows about must construct.
    #[test]
    fn create_with_valid_rates() {
        for rate in [8000, 16000, 32000, 48000] {
            assert!(WebRtcVad::new(rate, WebRtcVadMode::Quality).is_ok());
        }
    }

    /// A rate outside the supported set is rejected at construction.
    #[test]
    fn create_with_invalid_rate() {
        assert!(WebRtcVad::new(44100, WebRtcVadMode::Quality).is_err());
    }

    /// An all-zero frame is classified as non-voice.
    #[test]
    fn process_silence() {
        let mut vad = detector(WebRtcVadMode::Quality);
        let silence = [0i16; FRAME_LEN];
        let score = vad.process(&silence, RATE_HZ).unwrap();
        assert_eq!(score, 0.0);
    }

    /// Passing a rate other than the configured one is an error.
    #[test]
    fn process_wrong_sample_rate() {
        let mut vad = detector(WebRtcVadMode::Quality);
        let frame = [0i16; FRAME_LEN];
        assert!(vad.process(&frame, 8000).is_err());
    }

    /// A 100-sample frame is rejected.
    #[test]
    fn process_invalid_frame_size() {
        let mut vad = detector(WebRtcVadMode::Quality);
        let frame = [0i16; 100];
        assert!(vad.process(&frame, RATE_HZ).is_err());
    }

    /// The detector keeps working after `reset`.
    #[test]
    fn reset_works() {
        let mut vad = detector(WebRtcVadMode::Aggressive);
        let silence = [0i16; FRAME_LEN];
        let _ = vad.process(&silence, RATE_HZ).unwrap();
        vad.reset();
        let score = vad.process(&silence, RATE_HZ).unwrap();
        assert_eq!(score, 0.0);
    }
}