use std::collections::VecDeque;
use anyhow::{Context, Result};
use voice_activity_detector::VoiceActivityDetector;
use crate::resample::WHISPER_SAMPLE_RATE;
/// Number of samples scored by the voice-activity detector per prediction.
/// Audio is gated at this granularity: output is always emitted in whole chunks.
pub const VAD_CHUNK_SIZE: usize = 512;
/// Default number of consecutive voiced chunks required before speech is
/// considered to have started.
const DEFAULT_ONSET_FRAMES: usize = 2;
/// Default look-behind, in chunks, kept so the start of speech is not clipped.
/// The retained window holds this many chunks plus the current one.
const DEFAULT_PREFILL_FRAMES: usize = 15;
/// Default number of unvoiced chunks still emitted after speech before the
/// gate closes.
const DEFAULT_HANGOVER_FRAMES: usize = 15;
/// Streaming voice-activity gate.
///
/// Incoming samples are collected into fixed-size chunks, each chunk is scored
/// by a [`VoiceActivityDetector`], and only audio that belongs to a detected
/// speech segment is let through.
pub struct VadProcessor {
    // ---- configuration ----
    /// When `false`, input is passed through unchanged.
    is_enabled: bool,
    /// Minimum detector probability for a chunk to count as voiced.
    threshold: f32,
    /// Chunks of look-behind kept so the start of speech is not clipped.
    prefill_frames: usize,
    /// Consecutive voiced chunks needed to open the gate.
    onset_frames: usize,
    /// Unvoiced chunks still emitted after speech before the gate closes.
    hangover_frames: usize,

    // ---- runtime state ----
    /// Scores each chunk with a speech probability.
    detector: VoiceActivityDetector,
    /// Samples that do not yet form a complete chunk.
    buffer: Vec<f32>,
    /// Recent chunks, emitted as pre-roll when a speech onset is confirmed.
    frame_buffer: VecDeque<Vec<f32>>,
    /// Whether the gate is currently open.
    is_speaking: bool,
    /// Consecutive voiced chunks seen while the gate is closed.
    onset_counter: usize,
    /// Remaining unvoiced chunks to emit before the gate closes.
    hangover_counter: usize,
}
impl VadProcessor {
pub fn new(enabled: bool, threshold: f32) -> Result<Self> {
let detector = VoiceActivityDetector::builder()
.sample_rate(WHISPER_SAMPLE_RATE as i64)
.chunk_size(VAD_CHUNK_SIZE)
.build()
.context("Failed to create VAD detector")?;
Ok(Self {
detector,
threshold: threshold.clamp(0.0, 1.0),
is_enabled: enabled,
buffer: Vec::with_capacity(VAD_CHUNK_SIZE * 2),
is_speaking: false,
frame_buffer: VecDeque::with_capacity(DEFAULT_PREFILL_FRAMES + 1),
prefill_frames: DEFAULT_PREFILL_FRAMES,
onset_frames: DEFAULT_ONSET_FRAMES,
hangover_frames: DEFAULT_HANGOVER_FRAMES,
onset_counter: 0,
hangover_counter: 0,
})
}
pub fn disabled() -> Result<Self> {
Self::new(false, 0.5)
}
pub fn is_enabled(&self) -> bool {
self.is_enabled
}
pub fn process(&mut self, samples: &[f32]) -> Vec<f32> {
if !self.is_enabled {
return samples.to_vec();
}
let mut output = Vec::new();
self.buffer.extend_from_slice(samples);
while self.buffer.len() >= VAD_CHUNK_SIZE {
let chunk: Vec<f32> = self.buffer.drain(..VAD_CHUNK_SIZE).collect();
self.frame_buffer.push_back(chunk.clone());
while self.frame_buffer.len() > self.prefill_frames + 1 {
self.frame_buffer.pop_front();
}
let probability = self.detector.predict(chunk.iter().copied());
let is_voice = probability >= self.threshold;
match (self.is_speaking, is_voice) {
(false, true) => {
self.onset_counter += 1;
if self.onset_counter >= self.onset_frames {
self.is_speaking = true;
self.hangover_counter = self.hangover_frames;
self.onset_counter = 0;
for frame in &self.frame_buffer {
output.extend_from_slice(frame);
}
}
}
(true, true) => {
self.hangover_counter = self.hangover_frames;
output.extend_from_slice(&chunk);
}
(true, false) => {
if self.hangover_counter > 0 {
self.hangover_counter -= 1;
output.extend_from_slice(&chunk);
} else {
self.is_speaking = false;
}
}
(false, false) => {
self.onset_counter = 0;
}
}
}
output
}
pub fn reset(&mut self) {
self.frame_buffer.clear();
self.onset_counter = 0;
self.hangover_counter = 0;
self.is_speaking = false;
self.buffer.clear();
}
pub fn flush(&mut self) -> Vec<f32> {
if !self.is_enabled {
return std::mem::take(&mut self.buffer);
}
let mut output = Vec::new();
if self.is_speaking {
output.extend(std::mem::take(&mut self.buffer));
}
self.reset();
output
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A disabled processor must hand back exactly the samples it was given.
    #[test]
    fn test_vad_disabled_passthrough() {
        let input: [f32; 3] = [0.1, 0.2, 0.3];
        let mut gate = VadProcessor::disabled().expect("disabled VAD should construct");
        let passed = gate.process(&input);
        assert_eq!(passed, input.to_vec());
    }
}