use crate::error::{Error, Result};
use crate::time::{AudioDuration, AudioTimestamp};
/// Number of nanoseconds in one second, typed as `u128` so it can be used
/// directly in the `u128` arithmetic of `VadConfig::frame_length_samples`.
const NANOS_PER_SECOND: u128 = 1_000_000_000;
/// Plain-data bundle of VAD settings; every field is public and the type is `Copy`.
///
/// Nothing is checked at construction time. Call `validate` to enforce the
/// range constraints noted on the individual fields below.
#[derive(Debug, Clone, Copy)]
pub struct VadConfig {
/// Samples per second (it is multiplied by the frame duration to obtain a
/// sample count). `validate` rejects zero.
pub sample_rate: u32,
/// Length of one frame. `validate` rejects a zero duration.
pub frame_duration: AudioDuration,
/// Overlap fraction used by `hop_length_samples`, which computes the hop as
/// frame length × (1 − overlap). `validate` requires `[0.0, 1.0)`.
pub frame_overlap: f32,
/// `validate` requires `[0.0, 1.0)`.
/// NOTE(review): presumably a smoothing coefficient for the energy metric — confirm.
pub energy_smoothing: f32,
/// `validate` requires `[0.0, 1.0)`.
/// NOTE(review): presumably a smoothing coefficient for the flux metric — confirm.
pub flux_smoothing: f32,
/// Not checked by `validate`.
/// NOTE(review): presumably a lower bound applied to the energy metric — confirm.
pub energy_floor: f32,
/// Not checked by `validate`.
/// NOTE(review): presumably a lower bound applied to the flux metric — confirm.
pub flux_floor: f32,
/// `validate` requires `[0.0, 1.0)`.
/// NOTE(review): presumably a smoothing coefficient for the decision threshold — confirm.
pub threshold_smoothing: f32,
/// `validate` requires a value of at least `1.0`.
pub activation_margin: f32,
/// `validate` requires a positive value no greater than `activation_margin`.
pub release_margin: f32,
/// `validate` requires a positive value.
pub base_threshold: f32,
/// Weight of the energy metric. `validate` requires it to be non-negative,
/// and requires that the two weights are not both zero.
pub energy_weight: f32,
/// Weight of the flux metric. `validate` requires it to be non-negative,
/// and requires that the two weights are not both zero.
pub flux_weight: f32,
/// Not checked by `validate`.
/// NOTE(review): presumably the number of frames speech is held after the
/// signal drops — confirm against the detector.
pub hangover_frames: usize,
/// `validate` rejects zero.
/// NOTE(review): presumably the number of frames needed before speech is
/// reported — confirm against the detector.
pub min_speech_frames: usize,
/// Not checked by `validate`.
/// NOTE(review): presumably the timestamp assigned to the first sample — confirm.
pub stream_start_time: AudioTimestamp,
/// Optional pre-emphasis coefficient; when present `validate` requires
/// `[0.0, 1.0)`. NOTE(review): `None` presumably disables pre-emphasis — confirm.
pub pre_emphasis: Option<f32>,
}
impl Default for VadConfig {
fn default() -> Self {
Self {
sample_rate: 16_000,
frame_duration: AudioDuration::from_millis(20),
frame_overlap: 0.5,
energy_smoothing: 0.85,
flux_smoothing: 0.8,
energy_floor: 1e-4,
flux_floor: 1e-4,
threshold_smoothing: 0.9,
activation_margin: 1.1,
release_margin: 0.9,
base_threshold: 0.4,
energy_weight: 0.6,
flux_weight: 0.4,
hangover_frames: 3,
min_speech_frames: 3,
stream_start_time: AudioTimestamp::EPOCH,
pre_emphasis: Some(0.97),
}
}
}
impl VadConfig {
    /// Checks the configuration against the range constraints of its fields.
    ///
    /// The checks run in a fixed order and stop at the first violation.
    /// `energy_floor`, `flux_floor`, `hangover_frames` and `stream_start_time`
    /// are not inspected.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidInput`] with a message naming the first field
    /// (or field pair) that is out of range. NaN is rejected for every
    /// floating-point field that is checked.
    pub fn validate(&self) -> Result<()> {
        const EPSILON: f32 = 1e-6;

        if self.sample_rate == 0 {
            return Err(invalid_input("sample_rate must be greater than zero"));
        }
        // Compare the full-width nanosecond count. Narrowing it to `u64`
        // first would make a non-zero duration whose low 64 bits are all
        // zero look like an empty one.
        if self.frame_duration.as_nanos() == 0 {
            return Err(invalid_input("frame_duration must be non-zero"));
        }

        // `Range::contains` is false for NaN, so these checks reject it too.
        if !(0.0..1.0).contains(&self.frame_overlap) {
            return Err(invalid_input("frame_overlap must be within [0.0, 1.0)"));
        }
        if !(0.0..1.0).contains(&self.energy_smoothing) {
            return Err(invalid_input("energy_smoothing must be within [0.0, 1.0)"));
        }
        if !(0.0..1.0).contains(&self.flux_smoothing) {
            return Err(invalid_input("flux_smoothing must be within [0.0, 1.0)"));
        }
        if !(0.0..1.0).contains(&self.threshold_smoothing) {
            return Err(invalid_input(
                "threshold_smoothing must be within [0.0, 1.0)",
            ));
        }

        // Ordered comparisons against NaN are always false, so a bare `<` or
        // `<=` would let NaN through; test for it explicitly.
        if self.activation_margin.is_nan() || self.activation_margin < 1.0 {
            return Err(invalid_input("activation_margin must be >= 1.0"));
        }
        if self.release_margin.is_nan() || self.release_margin <= 0.0 {
            return Err(invalid_input("release_margin must be positive"));
        }
        if self.release_margin > self.activation_margin {
            return Err(invalid_input("release_margin must be <= activation_margin"));
        }
        if self.base_threshold.is_nan() || self.base_threshold <= 0.0 {
            return Err(invalid_input("base_threshold must be positive"));
        }
        if self.energy_weight.is_nan()
            || self.flux_weight.is_nan()
            || self.energy_weight < 0.0
            || self.flux_weight < 0.0
        {
            return Err(invalid_input("metric weights must be non-negative"));
        }
        let weight_sum = self.energy_weight + self.flux_weight;
        if weight_sum.abs() < EPSILON {
            return Err(invalid_input("metric weights must not both be zero"));
        }

        if self.min_speech_frames == 0 {
            return Err(invalid_input("min_speech_frames must be greater than zero"));
        }
        if let Some(coeff) = self.pre_emphasis {
            if !(0.0..1.0).contains(&coeff) {
                return Err(invalid_input(
                    "pre_emphasis coefficient must be in [0.0, 1.0)",
                ));
            }
        }
        Ok(())
    }

    /// Number of samples in one frame: `frame_duration × sample_rate`,
    /// rounded to the nearest whole sample and never less than 1.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidInput`] when the sample count does not fit in
    /// `usize`.
    pub fn frame_length_samples(&self) -> Result<usize> {
        let sr = u128::from(self.sample_rate);
        let nanos = self.frame_duration.as_nanos();
        // Adding half a second's worth of nanoseconds before the integer
        // division turns truncation into round-half-up.
        let numerator = nanos
            .saturating_mul(sr)
            .saturating_add(NANOS_PER_SECOND / 2);
        let samples = usize::try_from(numerator / NANOS_PER_SECOND)
            .map_err(|_| invalid_input("frame duration too large for platform"))?;
        Ok(samples.max(1))
    }

    /// Number of samples between the starts of consecutive frames:
    /// `frame_length × (1 − frame_overlap)`, rounded, and never less than 1.
    ///
    /// `frame_overlap` is used as-is; this method does not call `validate`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::frame_length_samples`].
    pub fn hop_length_samples(&self) -> Result<usize> {
        let frame_length = self.frame_length_samples()?;
        // Float-to-integer `as` saturates (and maps NaN to 0), so the cast
        // itself cannot panic; `max(1)` then keeps the hop usable.
        let hop = (frame_length as f32 * (1.0 - self.frame_overlap)).round() as usize;
        Ok(hop.max(1))
    }

    /// Smallest power of two that is at least the frame length in samples.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::frame_length_samples`], and returns
    /// [`Error::InvalidInput`] when the next power of two is not
    /// representable in `usize` (the unchecked variant would panic in debug
    /// builds and yield 0 in release builds).
    pub fn fft_size(&self) -> Result<usize> {
        self.frame_length_samples()?
            .checked_next_power_of_two()
            .ok_or_else(|| invalid_input("frame length too large for a power-of-two FFT size"))
    }
}
fn invalid_input(message: impl Into<String>) -> Error {
Error::InvalidInput(message.into())
}