use std::sync::Arc;
use tokio::task::JoinHandle;
use crate::error::{AudioError, AudioResult};
use crate::traits::VadProcessor;
use super::capture::AudioStream;
/// How a [`VadTurnManager`] decides where conversational turns begin and end.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VadMode {
    /// The VAD classifies each frame and the manager emits
    /// [`VoiceActivityEvent`]s based on the configured thresholds.
    HandsFree,
    /// No VAD events are emitted; the manager only drains the audio stream.
    PushToTalk,
}
/// Tuning parameters for a [`VadTurnManager`].
///
/// Both thresholds must be non-zero; see [`VadConfig::validate`].
#[derive(Debug, Clone)]
pub struct VadConfig {
    /// Turn-detection strategy.
    pub mode: VadMode,
    /// Consecutive non-speech audio, in milliseconds, needed to end an
    /// active turn in [`VadMode::HandsFree`].
    pub silence_threshold_ms: u32,
    /// Consecutive speech audio, in milliseconds, needed before a turn is
    /// reported as started in [`VadMode::HandsFree`].
    pub speech_threshold_ms: u32,
}
impl VadConfig {
    /// Checks that both thresholds are non-zero.
    ///
    /// The silence threshold is checked first, so a config with both
    /// thresholds at zero reports the silence-threshold error.
    ///
    /// # Errors
    ///
    /// Returns [`AudioError::Vad`] describing the first zero threshold found.
    pub fn validate(&self) -> AudioResult<()> {
        let failure = match (self.silence_threshold_ms, self.speech_threshold_ms) {
            (0, _) => Some("invalid silence threshold: 0 ms. Threshold must be a positive integer."),
            (_, 0) => Some("invalid speech threshold: 0 ms. Threshold must be a positive integer."),
            _ => None,
        };
        failure.map_or(Ok(()), |message| Err(AudioError::Vad(message.into())))
    }
}
/// Turn-boundary notification delivered to the callback passed to
/// [`VadTurnManager::start`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VoiceActivityEvent {
    /// Consecutive speech reached the configured speech threshold.
    SpeechStarted,
    /// An active turn ended.
    SpeechEnded {
        /// Speech time, in milliseconds, summed over the speech frames from
        /// the one that triggered `SpeechStarted` onward. Silence frames
        /// inside the turn are not counted.
        duration_ms: u32,
    },
}
/// Runs voice-activity detection over an [`AudioStream`] on a background
/// Tokio task and reports turn boundaries through a callback.
pub struct VadTurnManager {
    /// Detector used to classify each frame as speech or non-speech.
    vad: Arc<dyn VadProcessor>,
    /// Configuration, validated by [`VadTurnManager::new`].
    config: VadConfig,
    /// Handle of the running background task; `None` when idle.
    task_handle: Option<JoinHandle<()>>,
}
impl VadTurnManager {
    /// Creates a manager for `vad` after validating `config`.
    ///
    /// No background work starts until [`start`](Self::start) is called.
    ///
    /// # Errors
    ///
    /// Returns [`AudioError::Vad`] if either threshold in `config` is zero.
    pub fn new(vad: Arc<dyn VadProcessor>, config: VadConfig) -> AudioResult<Self> {
        config.validate()?;
        Ok(Self {
            vad,
            config,
            task_handle: None,
        })
    }

    /// Spawns a background task that consumes `stream` and reports voice
    /// activity through `callback`.
    ///
    /// Any task started by an earlier call is aborted first, so a manager
    /// drives at most one stream at a time.
    ///
    /// `callback` is invoked directly on the background task, in event order.
    /// It runs between frames, so it should return quickly. A panic inside it
    /// ends the background task.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (it uses `tokio::spawn`).
    pub fn start(
        &mut self,
        stream: AudioStream,
        callback: impl Fn(VoiceActivityEvent) + Send + Sync + 'static,
    ) {
        // Without this, the previous task would keep running detached and
        // keep calling its callback, and `stop` could no longer cancel it.
        self.stop();

        let vad = Arc::clone(&self.vad);
        let config = self.config.clone();
        let handle = tokio::spawn(async move {
            Self::run_loop(stream, vad, config, &callback).await;
        });
        self.task_handle = Some(handle);
    }

    /// Alias for [`start`](Self::start).
    pub fn on_activity(
        &mut self,
        stream: AudioStream,
        callback: impl Fn(VoiceActivityEvent) + Send + Sync + 'static,
    ) {
        self.start(stream, callback);
    }

    /// Aborts the background task, if one is running.
    ///
    /// Calling this when nothing is running is a no-op.
    pub fn stop(&mut self) {
        if let Some(handle) = self.task_handle.take() {
            handle.abort();
        }
    }

    /// Body of the background task: dispatches on the configured mode.
    async fn run_loop(
        mut stream: AudioStream,
        vad: Arc<dyn VadProcessor>,
        config: VadConfig,
        callback: &(dyn Fn(VoiceActivityEvent) + Send + Sync),
    ) {
        match config.mode {
            VadMode::HandsFree => {
                Self::run_hands_free(
                    &mut stream,
                    &*vad,
                    config.speech_threshold_ms,
                    config.silence_threshold_ms,
                    callback,
                )
                .await;
            }
            VadMode::PushToTalk => {
                // No VAD events in this mode. Frames are still drained until
                // the stream closes (presumably so the producer is never left
                // blocked on a full channel; confirm against `AudioStream`).
                while stream.recv().await.is_some() {}
            }
        }
    }

    /// Hands-free turn detection.
    ///
    /// A turn starts once consecutive speech reaches `speech_threshold_ms`.
    /// It ends once consecutive non-speech reaches `silence_threshold_ms`, or
    /// when the stream closes while a turn is still open. Every emitted
    /// `SpeechStarted` is therefore paired with exactly one `SpeechEnded`.
    async fn run_hands_free(
        stream: &mut AudioStream,
        vad: &dyn VadProcessor,
        speech_threshold_ms: u32,
        silence_threshold_ms: u32,
        callback: &(dyn Fn(VoiceActivityEvent) + Send + Sync),
    ) {
        let mut is_speaking = false;
        let mut consecutive_speech_ms: u32 = 0;
        let mut consecutive_silence_ms: u32 = 0;
        // Speech-frame time accumulated since `SpeechStarted` was emitted.
        // Invariant: both this and `consecutive_silence_ms` are 0 whenever
        // `is_speaking` is false.
        let mut turn_speech_ms: u32 = 0;

        while let Some(frame) = stream.recv().await {
            let frame_ms = frame.duration_ms;
            if vad.is_speech(&frame) {
                consecutive_silence_ms = 0;
                consecutive_speech_ms = consecutive_speech_ms.saturating_add(frame_ms);
                if !is_speaking && consecutive_speech_ms >= speech_threshold_ms {
                    is_speaking = true;
                    turn_speech_ms = 0;
                    // Delivered inline so listeners always see
                    // SpeechStarted before the matching SpeechEnded.
                    callback(VoiceActivityEvent::SpeechStarted);
                }
                if is_speaking {
                    turn_speech_ms = turn_speech_ms.saturating_add(frame_ms);
                }
            } else {
                consecutive_speech_ms = 0;
                if is_speaking {
                    consecutive_silence_ms = consecutive_silence_ms.saturating_add(frame_ms);
                    if consecutive_silence_ms >= silence_threshold_ms {
                        is_speaking = false;
                        callback(VoiceActivityEvent::SpeechEnded {
                            duration_ms: turn_speech_ms,
                        });
                        consecutive_silence_ms = 0;
                        turn_speech_ms = 0;
                    }
                }
            }
        }

        // The stream closed mid-turn: close the turn so the listener is not
        // left with an unpaired SpeechStarted.
        if is_speaking {
            callback(VoiceActivityEvent::SpeechEnded {
                duration_ms: turn_speech_ms,
            });
        }
    }
}
// Compile-time guarantee that `VadTurnManager` can be sent to and shared
// across threads: adding a non-`Send`/`Sync` field makes this fail to build.
// The closure is never called; it exists only to be type-checked, hence the
// `dead_code` allowance.
#[allow(dead_code)]
const _: fn() = || {
    fn require_thread_safe<T: Send + Sync>() {}
    require_thread_safe::<VadTurnManager>();
};
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frame::AudioFrame;
    use crate::traits::SpeechSegment;

    /// VAD stub that classifies every frame as speech.
    struct AlwaysSpeechVad;

    impl VadProcessor for AlwaysSpeechVad {
        fn is_speech(&self, _frame: &AudioFrame) -> bool {
            true
        }

        fn segment(&self, _frame: &AudioFrame) -> Vec<SpeechSegment> {
            Vec::new()
        }
    }

    /// Builds a config from its three fields.
    fn config(mode: VadMode, silence_threshold_ms: u32, speech_threshold_ms: u32) -> VadConfig {
        VadConfig {
            mode,
            silence_threshold_ms,
            speech_threshold_ms,
        }
    }

    /// Returns the stub detector as the trait object the manager expects.
    fn stub_vad() -> Arc<dyn VadProcessor> {
        Arc::new(AlwaysSpeechVad)
    }

    #[test]
    fn test_vad_config_validate_ok() {
        assert!(config(VadMode::HandsFree, 500, 200).validate().is_ok());
    }

    #[test]
    fn test_vad_config_validate_zero_silence() {
        let err = config(VadMode::HandsFree, 0, 200).validate().unwrap_err();
        assert!(matches!(err, AudioError::Vad(msg) if msg.contains("silence threshold")));
    }

    #[test]
    fn test_vad_config_validate_zero_speech() {
        let err = config(VadMode::HandsFree, 500, 0).validate().unwrap_err();
        assert!(matches!(err, AudioError::Vad(msg) if msg.contains("speech threshold")));
    }

    #[test]
    fn test_vad_config_validate_both_zero() {
        let err = config(VadMode::PushToTalk, 0, 0).validate().unwrap_err();
        assert!(matches!(err, AudioError::Vad(_)));
    }

    #[test]
    fn test_vad_turn_manager_new_valid() {
        let result = VadTurnManager::new(stub_vad(), config(VadMode::HandsFree, 500, 200));
        assert!(result.is_ok());
    }

    #[test]
    fn test_vad_turn_manager_new_invalid() {
        let result = VadTurnManager::new(stub_vad(), config(VadMode::HandsFree, 0, 200));
        assert!(result.is_err());
    }

    #[test]
    fn test_vad_turn_manager_stop_idempotent() {
        let mut manager =
            VadTurnManager::new(stub_vad(), config(VadMode::HandsFree, 500, 200)).unwrap();
        for _ in 0..2 {
            manager.stop();
            assert!(manager.task_handle.is_none());
        }
    }

    #[test]
    fn test_vad_turn_manager_is_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<VadTurnManager>();
        assert_sync::<VadTurnManager>();
    }

    #[test]
    fn test_vad_mode_clone_copy() {
        for mode in [VadMode::HandsFree, VadMode::PushToTalk] {
            let copied = mode;
            assert_eq!(mode, copied);
        }
    }

    #[test]
    fn test_voice_activity_event_clone_eq() {
        let started = VoiceActivityEvent::SpeechStarted;
        assert_eq!(started, started.clone());

        let ended = VoiceActivityEvent::SpeechEnded { duration_ms: 1500 };
        assert_eq!(ended, ended.clone());

        assert_ne!(started, ended);
    }
}