use std::time::Duration;
use tracing::{debug, warn};
use voice_activity_detector::VoiceActivityDetector as SileroVoiceActivityDetector;
use crate::{VoiceInputResult, SAMPLE_RATE};
/// Samples per inference chunk fed to the Silero VAD model; 512 is the
/// minimum chunk size enforced for 16 kHz input in `VoiceDetection::new`.
pub const SILERO_VAD_CHUNK_SIZE: usize = 512;
/// Default confidence threshold above which a chunk is classified as voice.
pub const SILERO_VAD_VOICE_THRESHOLD: f32 = 0.5;
/// Tunable parameters for speech segmentation. All duration fields are in
/// milliseconds.
#[derive(Debug, Clone)]
pub struct VoiceDetectionConfig {
    /// Trailing silence (ms) required before an active segment is closed.
    pub silence_duration_threshold: u64,
    /// Audio (ms) to keep from before speech onset.
    /// NOTE(review): the implementation buffers only one VAD chunk of
    /// history, which may be shorter than this value — confirm intent.
    pub pre_speech_padding: u64,
    /// Audio (ms) to keep after speech ends.
    pub post_speech_padding: u64,
    /// Silero confidence above which a chunk counts as voiced (0.0–1.0).
    pub voice_threshold: f32,
    /// Hard cap (ms) on a single segment before it is force-emitted.
    pub max_speech_duration: u64,
}
impl Default for VoiceDetectionConfig {
fn default() -> Self {
Self {
silence_duration_threshold: 800, pre_speech_padding: 150, post_speech_padding: 200, voice_threshold: SILERO_VAD_VOICE_THRESHOLD,
max_speech_duration: 30_000, }
}
}
/// Streaming voice-activity detector that carves speech segments out of a
/// chunked audio stream using the Silero VAD model.
pub struct VoiceDetection {
    // Samples of the speech segment currently being accumulated.
    active_speech: Vec<f32>,
    // The most recently seen chunk; prepended as pre-speech padding when a
    // new segment starts.
    padding_buffer: Vec<f32>,
    // Consecutive non-voiced chunks observed since the last voiced chunk.
    silence_frames: usize,
    // Underlying Silero model. NOTE(review): the model appears to keep
    // internal state across `predict` calls — treat each chunk as fed once.
    silero_vad: SileroVoiceActivityDetector,
    // Active configuration (millisecond-based thresholds).
    config: VoiceDetectionConfig,
    // Input sample rate in Hz (8000 or 16000).
    sample_rate: usize,
    // Samples per VAD chunk.
    frames_per_chunk: usize,
    // Number of consecutive silent chunks that closes a segment
    // (derived from config.silence_duration_threshold).
    silence_frames_threshold: usize,
    // Maximum segment length, measured in chunks
    // (derived from config.max_speech_duration).
    max_speech_frames: usize,
    // Padding window measured in raw samples
    // (derived from the larger of the two padding settings).
    padding_frames: usize,
}
impl Default for VoiceDetection {
    /// Builds a detector with the crate-wide `SAMPLE_RATE`, the standard
    /// Silero chunk size, and the default config. Panics only if those
    /// compile-time defaults are themselves invalid.
    fn default() -> Self {
        Self::new(SAMPLE_RATE, SILERO_VAD_CHUNK_SIZE, None).expect("Valid default")
    }
}
impl VoiceDetection {
    /// Creates a detector for the given sample rate and Silero chunk size.
    ///
    /// `sample_rate` must be 8000 or 16000 Hz; `chunk_size` must meet the
    /// Silero minimum for that rate (256 samples at 8 kHz, 512 at 16 kHz).
    ///
    /// # Panics
    /// Panics on an unsupported sample rate or a too-small chunk size.
    /// NOTE(review): these could arguably be reported through the returned
    /// `VoiceInputResult` instead, but its error variants are not visible
    /// from this file — confirm before changing.
    pub fn new(
        sample_rate: usize,
        chunk_size: usize,
        config: Option<VoiceDetectionConfig>,
    ) -> VoiceInputResult<Self> {
        match sample_rate {
            8000 => assert!(chunk_size >= 256, "8kHz requires chunk size >= 256"),
            16000 => assert!(chunk_size >= 512, "16kHz requires chunk size >= 512"),
            _ => panic!("Sample rate must be 8000 or 16000 Hz"),
        }
        let config = config.unwrap_or_default();
        let silero_vad = SileroVoiceActivityDetector::builder()
            .sample_rate(sample_rate as i64)
            .chunk_size(chunk_size)
            .build()
            .expect("Valid Silero VAD configuration");
        // Chunks arriving per second; converts the millisecond thresholds
        // from the config into chunk counts.
        let frames_per_second = sample_rate as f32 / chunk_size as f32;
        let silence_frames_threshold =
            ((config.silence_duration_threshold as f32 / 1000.0) * frames_per_second) as usize;
        let max_speech_frames =
            ((config.max_speech_duration as f32 / 1000.0) * frames_per_second) as usize;
        // Padding is measured in raw samples (not chunks): the larger of the
        // pre/post padding windows, converted from milliseconds.
        let padding_frames = ((config.pre_speech_padding.max(config.post_speech_padding) as f32
            / 1000.0)
            * sample_rate as f32) as usize;
        Ok(Self {
            active_speech: Vec::with_capacity(sample_rate * 2),
            padding_buffer: Vec::with_capacity(padding_frames),
            silence_frames: 0,
            silero_vad,
            config,
            sample_rate,
            frames_per_chunk: chunk_size,
            silence_frames_threshold,
            max_speech_frames,
            padding_frames,
        })
    }

    /// Feeds one chunk of audio (length equal to the chunk size passed to
    /// `new`) and returns a finished speech segment when one ends, or `None`
    /// while a segment is still open (or no speech is active).
    pub fn add_samples(&mut self, samples: &[f32]) -> Option<Vec<f32>> {
        // Run inference exactly once per chunk. The previous code called
        // predict() a second time just to build the debug message; besides
        // doubling inference cost, the Silero model keeps internal state
        // across predict calls, so the extra call skewed every subsequent
        // confidence value.
        let confidence = self.silero_vad.predict(samples.to_vec());
        let is_voice = confidence > self.config.voice_threshold;
        debug!(
            "process_samples[{}] confidence={:.3}, silence_frames={}, buffer_size={}",
            if is_voice { "VOICE+" } else { "VOICE-" },
            confidence,
            self.silence_frames,
            self.active_speech.len()
        );
        if is_voice {
            self.handle_voice_detected(samples)
        } else {
            self.handle_silence_detected(samples)
        }
    }

    /// Accumulates a voiced chunk; returns a segment only when the maximum
    /// speech duration is exceeded and the segment must be force-emitted.
    fn handle_voice_detected(&mut self, samples: &[f32]) -> Option<Vec<f32>> {
        self.silence_frames = 0;
        // Segment onset: prepend the buffered previous chunk as pre-speech
        // padding so the attack of the first word is not clipped.
        if self.active_speech.is_empty() && !self.padding_buffer.is_empty() {
            let padding_samples = self.padding_frames.min(self.padding_buffer.len());
            self.active_speech
                .extend(&self.padding_buffer[..padding_samples]);
        }
        self.active_speech.extend(samples);
        self.padding_buffer = samples.to_vec();
        // max_speech_frames is in chunks; multiply by chunk size to compare
        // against the sample count in the buffer.
        if self.active_speech.len() >= self.max_speech_frames * self.frames_per_chunk {
            warn!("Maximum speech duration exceeded, forcing segment break");
            let speech = std::mem::take(&mut self.active_speech);
            return Some(speech);
        }
        None
    }

    /// Accumulates a silent chunk; once enough consecutive silence has been
    /// seen, closes and returns the active segment.
    fn handle_silence_detected(&mut self, samples: &[f32]) -> Option<Vec<f32>> {
        self.silence_frames += 1;
        if !self.active_speech.is_empty() {
            if self.silence_frames >= self.silence_frames_threshold {
                let mut speech = std::mem::take(&mut self.active_speech);
                // Append the CURRENT silent chunk as post-speech padding.
                // The previous code appended `padding_buffer`, but that
                // chunk had already been pushed into `active_speech` on the
                // prior call, duplicating ~one chunk of audio at every
                // segment boundary.
                speech.extend(samples);
                self.padding_buffer = samples.to_vec();
                self.silence_frames = 0;
                return Some(speech);
            }
            // Still inside the trailing-silence window: keep accumulating so
            // the finished segment retains its natural decay.
            self.active_speech.extend(samples);
        }
        self.padding_buffer = samples.to_vec();
        None
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &VoiceDetectionConfig {
        &self.config
    }

    /// Replaces the configuration and recomputes every derived threshold.
    pub fn update_config(&mut self, config: VoiceDetectionConfig) {
        let frames_per_second = self.sample_rate as f32 / self.frames_per_chunk as f32;
        self.silence_frames_threshold =
            ((config.silence_duration_threshold as f32 / 1000.0) * frames_per_second) as usize;
        self.max_speech_frames =
            ((config.max_speech_duration as f32 / 1000.0) * frames_per_second) as usize;
        // Bug fix: `padding_frames` was previously left stale here, so
        // changing the padding settings at runtime silently had no effect.
        self.padding_frames = ((config.pre_speech_padding.max(config.post_speech_padding) as f32
            / 1000.0)
            * self.sample_rate as f32) as usize;
        self.config = config;
    }

    /// Duration of the speech currently buffered (not yet emitted).
    pub fn active_speech_duration(&self) -> Duration {
        Duration::from_secs_f64(self.active_speech.len() as f64 / self.sample_rate as f64)
    }

    /// Raw Silero confidence for `input`. Note: advances the model's
    /// internal state, like any predict call.
    pub fn silero_vad_prediction(&mut self, input: Vec<f32>) -> f32 {
        self.silero_vad.predict(input)
    }

    /// True when the Silero confidence for `input` exceeds the configured
    /// voice threshold. Advances model state (see `silero_vad_prediction`).
    pub fn silero_vad_is_voice(&mut self, input: Vec<f32>) -> bool {
        self.silero_vad_prediction(input) > self.config.voice_threshold
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A custom config handed to the constructor must be stored verbatim.
    #[test]
    fn test_voice_detection_config() {
        let custom = VoiceDetectionConfig {
            silence_duration_threshold: 1000,
            pre_speech_padding: 200,
            post_speech_padding: 300,
            voice_threshold: 0.6,
            max_speech_duration: 20_000,
        };
        let vad = VoiceDetection::new(16000, 512, Some(custom)).unwrap();
        let stored = vad.config();
        assert_eq!(stored.silence_duration_threshold, 1000);
        assert_eq!(stored.voice_threshold, 0.6);
    }

    /// One second of buffered audio at 16 kHz reports a one-second duration.
    #[test]
    fn test_speech_duration_tracking() {
        let mut vad = VoiceDetection::new(16000, 512, None).unwrap();
        vad.active_speech
            .extend(std::iter::repeat(0.1).take(16000));
        assert_eq!(vad.active_speech_duration().as_secs(), 1);
    }
}