use ort::session::{builder::GraphOptimizationLevel, Session};
use ort::value::Value;
use std::path::Path;
use thiserror::Error;
/// Errors produced while loading or running the voice-activity detector.
#[derive(Debug, Error)]
pub enum VadError {
    /// The ONNX model file was missing or the runtime session could not be built.
    #[error("Model load error: {0}")]
    ModelLoadError(String),
    /// Building input tensors, running the model, or reading its outputs failed.
    #[error("Inference error: {0}")]
    InferenceError(String),
    /// A configuration value was rejected.
    #[error("Invalid configuration: {0}")]
    ConfigError(String),
    /// The supplied audio did not have the shape the detector expects
    /// (for example, a frame with the wrong number of samples).
    #[error("Invalid audio: {0}")]
    AudioError(String),
}

/// Shorthand for results whose error type is [`VadError`].
pub type VadResult<T> = std::result::Result<T, VadError>;
/// Audio sample rates the model accepts. Defaults to 16 kHz.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum VadSampleRate {
    /// 8 000 samples per second.
    Rate8k,
    /// 16 000 samples per second (the default).
    #[default]
    Rate16k,
}

impl VadSampleRate {
    /// The sample rate in hertz, as passed to the model's `sr` input.
    pub fn as_hz(&self) -> i64 {
        match *self {
            Self::Rate8k => 8_000,
            Self::Rate16k => 16_000,
        }
    }

    /// Number of new audio samples consumed per model frame at this rate.
    pub fn frame_size(&self) -> usize {
        match *self {
            Self::Rate8k => 256,
            Self::Rate16k => 512,
        }
    }

    /// Number of trailing samples from the previous frame that are prepended
    /// to the next frame before it is fed to the model.
    pub fn context_size(&self) -> usize {
        match *self {
            Self::Rate8k => 32,
            Self::Rate16k => 64,
        }
    }

    /// Duration of one frame in milliseconds (`frame_size / as_hz * 1000`).
    pub fn frame_duration_ms(&self) -> f32 {
        let samples_per_frame = self.frame_size() as f32;
        let samples_per_second = self.as_hz() as f32;
        samples_per_frame / samples_per_second * 1000.0
    }
}
/// Tuning parameters for the model-based detector and its segmentation state machine.
#[derive(Debug, Clone)]
pub struct VadConfig {
    /// Input sample rate; determines frame and context sizes.
    pub sample_rate: VadSampleRate,
    /// A frame counts as speech when its model probability is strictly greater than this.
    pub threshold: f32,
    /// Consecutive speech frames required before a speech segment is opened.
    pub min_speech_frames: usize,
    /// Consecutive non-speech frames required before an open segment is closed.
    pub min_silence_frames: usize,
    /// Frames of padding added before a segment's start and after its end.
    pub padding_frames: usize,
}

impl Default for VadConfig {
    /// 16 kHz, threshold 0.5, 1 speech frame to open, 8 silence frames to close,
    /// 2 frames of padding.
    fn default() -> Self {
        Self {
            threshold: 0.5,
            sample_rate: VadSampleRate::Rate16k,
            padding_frames: 2,
            min_speech_frames: 1,
            min_silence_frames: 8,
        }
    }
}

impl VadConfig {
    /// Preset for low-latency streaming: closes segments after 4 silent frames and
    /// pads by 1 frame; everything else matches [`VadConfig::default`].
    pub fn streaming() -> Self {
        Self {
            min_silence_frames: 4,
            padding_frames: 1,
            ..Self::default()
        }
    }

    /// Preset for offline/batch use: needs 2 speech frames to open, 16 silent frames
    /// to close, and pads by 4 frames; everything else matches [`VadConfig::default`].
    pub fn batch() -> Self {
        Self {
            min_speech_frames: 2,
            min_silence_frames: 16,
            padding_frames: 4,
            ..Self::default()
        }
    }
}
/// Per-stream model memory: the recurrent state tensor and the audio context
/// carried from one frame to the next.
struct VadState {
    /// Flattened recurrent state; reshaped to `(2, 1, 128)` when fed to the model.
    state: Vec<f32>,
    /// Tail of the previous frame, prepended to the next frame's samples.
    context: Vec<f32>,
}

impl VadState {
    /// Number of values in the flattened recurrent state (2 × 128).
    const STATE_LEN: usize = 2 * 128;

    /// Creates zeroed state plus a zeroed context of `context_size` samples.
    fn new(context_size: usize) -> Self {
        let state = vec![0.0; Self::STATE_LEN];
        let context = vec![0.0; context_size];
        Self { state, context }
    }

    /// Zeroes both buffers in place, keeping their lengths.
    fn reset(&mut self) {
        for slot in self.state.iter_mut().chain(self.context.iter_mut()) {
            *slot = 0.0;
        }
    }
}
/// Result of running the model on a single audio frame.
///
/// This is plain data, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VadFrame {
    /// Speech probability the model produced for this frame.
    pub probability: f32,
    /// Raw per-frame decision: `probability > threshold`. This is not debounced by
    /// the segmentation state machine.
    pub is_speech: bool,
    /// Zero-based index of this frame since the session was created or last reset.
    pub frame_index: u64,
}
/// A contiguous region of detected speech, including configured padding.
///
/// This is plain data, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpeechSegment {
    /// `start_frame` converted to milliseconds using the frame duration.
    pub start_ms: f32,
    /// `end_frame` converted to milliseconds using the frame duration.
    pub end_ms: f32,
    /// First frame of the segment, including leading padding.
    pub start_frame: u64,
    /// Frame index one past the segment's last frame (exclusive), including trailing padding.
    pub end_frame: u64,
    /// Mean of the per-frame probabilities collected while the segment was being built.
    pub avg_probability: f32,
}
/// States of the hysteresis state machine that turns per-frame decisions into segments.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VadSessionState {
    /// Not in speech and no segment is open.
    Silence,
    /// Speech frames seen, but fewer than `min_speech_frames` in a row so far.
    SpeechPending,
    /// Inside a speech segment.
    Speech,
    /// Inside a segment; non-speech frames seen, but fewer than `min_silence_frames` in a row.
    SilencePending,
}
/// Voice-activity detector backed by an ONNX model (`model.onnx`).
///
/// It runs the model frame by frame and feeds each decision through a
/// [`VadSessionState`] state machine that groups speech frames into [`SpeechSegment`]s.
pub struct VadSession {
    // ONNX Runtime session loaded from `<model_dir>/model.onnx`.
session: Session,
    // Detector and state-machine parameters.
config: VadConfig,
    // Recurrent model state and inter-frame audio context.
state: VadState,
    // Current state of the segmentation state machine.
session_state: VadSessionState,
    // Number of frames processed since creation or the last `reset`.
frame_count: u64,
    // Length of the current run of consecutive speech frames.
speech_frames: usize,
    // Length of the current run of consecutive non-speech frames inside a segment.
silence_frames: usize,
    // Start frame (including padding) of the segment being built, if any.
segment_start: Option<u64>,
    // Per-frame probabilities collected for the segment being built; averaged into
    // `SpeechSegment::avg_probability`.
segment_probs: Vec<f32>,
}
impl VadSession {
    /// Loads `model.onnx` from `model_dir` and creates a session ready to process frames.
    ///
    /// The session starts in [`VadSessionState::Silence`] with zeroed recurrent state and
    /// audio context. ONNX Runtime is configured with `GraphOptimizationLevel::Level3` and a
    /// single intra-op thread.
    ///
    /// # Errors
    ///
    /// - [`VadError::ConfigError`] if `config.threshold` is NaN or outside `0.0..=1.0`.
    /// - [`VadError::ModelLoadError`] if `model.onnx` does not exist or ONNX Runtime fails
    ///   to build a session from it.
    pub fn new<P: AsRef<Path>>(model_dir: P, config: VadConfig) -> VadResult<Self> {
        // The threshold is compared against the model's per-frame probability output.
        // A value outside [0, 1] would make the detector fire always or never.
        // `contains` also rejects NaN.
        if !(0.0..=1.0).contains(&config.threshold) {
            return Err(VadError::ConfigError(format!(
                "threshold must be within 0.0..=1.0, got {}",
                config.threshold
            )));
        }
        let model_path = model_dir.as_ref().join("model.onnx");
        if !model_path.exists() {
            return Err(VadError::ModelLoadError(format!(
                "Model file not found: {:?}",
                model_path
            )));
        }
        let session = Session::builder()
            .map_err(|e| {
                VadError::ModelLoadError(format!("Failed to create session builder: {}", e))
            })?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| {
                VadError::ModelLoadError(format!("Failed to set optimization level: {}", e))
            })?
            .with_intra_threads(1)
            .map_err(|e| VadError::ModelLoadError(format!("Failed to set threads: {}", e)))?
            .commit_from_file(&model_path)
            .map_err(|e| VadError::ModelLoadError(format!("Failed to load model: {}", e)))?;
        let state = VadState::new(config.sample_rate.context_size());
        Ok(Self {
            session,
            config,
            state,
            session_state: VadSessionState::Silence,
            frame_count: 0,
            speech_frames: 0,
            silence_frames: 0,
            segment_start: None,
            segment_probs: Vec::new(),
        })
    }

    /// The configuration this session was created with.
    pub fn config(&self) -> &VadConfig {
        &self.config
    }

    /// Current state of the segmentation state machine.
    pub fn session_state(&self) -> VadSessionState {
        self.session_state
    }

    /// Number of frames processed since creation or the last [`VadSession::reset`].
    pub fn frame_count(&self) -> u64 {
        self.frame_count
    }

    /// Debounced speech decision: `true` while a segment is open
    /// (`Speech` or `SilencePending`).
    pub fn is_speech(&self) -> bool {
        matches!(
            self.session_state,
            VadSessionState::Speech | VadSessionState::SilencePending
        )
    }

    /// Returns the session to its initial state so the next frame starts a new stream.
    ///
    /// This clears the recurrent state, the audio context, the state machine, the frame
    /// counter and any partially built segment.
    pub fn reset(&mut self) {
        self.state.reset();
        self.session_state = VadSessionState::Silence;
        self.frame_count = 0;
        self.speech_frames = 0;
        self.silence_frames = 0;
        self.segment_start = None;
        self.segment_probs.clear();
    }

    /// Runs the model on exactly one frame of audio and advances the segmentation state.
    ///
    /// `samples` must contain exactly `config.sample_rate.frame_size()` samples. The returned
    /// [`VadFrame::is_speech`] is the raw per-frame decision (`probability > threshold`);
    /// the debounced decision is [`VadSession::is_speech`].
    ///
    /// # Errors
    ///
    /// - [`VadError::AudioError`] if `samples` has the wrong length.
    /// - [`VadError::InferenceError`] if the model cannot be run or returns malformed
    ///   outputs. In that case this session's context, recurrent state, state machine and
    ///   frame counter are left unchanged.
    pub fn process_frame(&mut self, samples: &[f32]) -> VadResult<VadFrame> {
        let frame_size = self.config.sample_rate.frame_size();
        let context_size = self.config.sample_rate.context_size();
        if samples.len() != frame_size {
            return Err(VadError::AudioError(format!(
                "Expected {} samples, got {}",
                frame_size,
                samples.len()
            )));
        }
        // The model input is [context carried from the previous frame | current frame].
        let mut input = Vec::with_capacity(context_size + frame_size);
        input.extend_from_slice(&self.state.context);
        input.extend_from_slice(samples);
        let probability = self.run_inference(input)?;
        // Advance the context only after inference succeeded, so a failed call does not
        // desynchronise the context from the recurrent state.
        self.state
            .context
            .copy_from_slice(&samples[frame_size - context_size..]);
        let is_speech = probability > self.config.threshold;
        self.update_state_machine(is_speech, probability);
        let frame = VadFrame {
            probability,
            is_speech,
            frame_index: self.frame_count,
        };
        self.frame_count += 1;
        Ok(frame)
    }

    /// Processes `samples` frame by frame and returns the segments that closed during
    /// this call.
    ///
    /// Samples are split into consecutive frames of `frame_size`. Trailing samples that do
    /// not fill a whole frame are ignored; they are not buffered for the next call. A segment
    /// still open at the end stays open; call [`VadSession::flush`] to close it.
    ///
    /// # Errors
    ///
    /// Propagates the first error from [`VadSession::process_frame`]. Segments collected
    /// before the error are discarded.
    pub fn process_audio(&mut self, samples: &[f32]) -> VadResult<Vec<SpeechSegment>> {
        let frame_size = self.config.sample_rate.frame_size();
        let mut segments = Vec::new();
        for frame in samples.chunks_exact(frame_size) {
            self.process_frame(frame)?;
            if let Some(segment) = self.check_segment_complete() {
                segments.push(segment);
            }
        }
        Ok(segments)
    }

    /// Force-closes the currently open segment (for example, at end of stream) and returns it.
    ///
    /// Returns `None` when no segment is open. In `SilencePending` the already-observed
    /// trailing silence is excluded, using the same end rule as a normally completed segment.
    pub fn flush(&mut self) -> Option<SpeechSegment> {
        if !self.is_speech() || self.segment_start.is_none() {
            return None;
        }
        let trailing_silence = if self.session_state == VadSessionState::SilencePending {
            self.silence_frames as u64
        } else {
            0
        };
        let segment = self.create_segment(self.padded_end_frame(trailing_silence));
        self.clear_segment();
        self.session_state = VadSessionState::Silence;
        Some(segment)
    }

    /// Runs one model step and returns the speech probability.
    ///
    /// `input` is the context followed by the current frame. On success, the model's updated
    /// recurrent state is stored. The model receives `input` shaped `(1, context + frame)`,
    /// the scalar sample rate `sr`, and `state` shaped `(2, 1, 128)`. Its `output` and
    /// `stateN` results are read back.
    fn run_inference(&mut self, input: Vec<f32>) -> VadResult<f32> {
        let input_size =
            self.config.sample_rate.frame_size() + self.config.sample_rate.context_size();
        // `input` is moved into the array; no extra copy of the audio is made.
        let input_tensor = Value::from_array(
            ndarray::Array2::from_shape_vec((1, input_size), input).map_err(|e| {
                VadError::InferenceError(format!("Failed to create input array: {}", e))
            })?,
        )
        .map_err(|e| VadError::InferenceError(format!("Failed to create input tensor: {}", e)))?;
        let sr_tensor = Value::from_array(ndarray::Array::from_elem(
            (),
            self.config.sample_rate.as_hz(),
        ))
        .map_err(|e| VadError::InferenceError(format!("Failed to create sr tensor: {}", e)))?;
        // The state is cloned rather than moved so it stays intact if inference fails.
        let state_tensor = Value::from_array(
            ndarray::Array3::from_shape_vec((2, 1, 128), self.state.state.clone()).map_err(
                |e| VadError::InferenceError(format!("Failed to create state array: {}", e)),
            )?,
        )
        .map_err(|e| VadError::InferenceError(format!("Failed to create state tensor: {}", e)))?;
        let outputs = self
            .session
            .run(ort::inputs![
                "input" => input_tensor,
                "sr" => sr_tensor,
                "state" => state_tensor,
            ])
            .map_err(|e| VadError::InferenceError(format!("Inference failed: {}", e)))?;
        let output = outputs
            .get("output")
            .ok_or_else(|| VadError::InferenceError("Missing 'output' in results".into()))?;
        let (_, output_data) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| VadError::InferenceError(format!("Failed to extract output: {}", e)))?;
        // An empty output is a model contract violation, not "probability 0".
        let probability = output_data
            .first()
            .copied()
            .ok_or_else(|| VadError::InferenceError("Empty 'output' tensor".into()))?;
        let state_out = outputs
            .get("stateN")
            .ok_or_else(|| VadError::InferenceError("Missing 'stateN' in results".into()))?;
        let (_, state_data) = state_out
            .try_extract_tensor::<f32>()
            .map_err(|e| VadError::InferenceError(format!("Failed to extract state: {}", e)))?;
        // Validate before writing so a malformed result cannot corrupt the stored state.
        if state_data.len() != self.state.state.len() {
            return Err(VadError::InferenceError(format!(
                "Unexpected 'stateN' length: expected {}, got {}",
                self.state.state.len(),
                state_data.len()
            )));
        }
        // Reuse the existing buffer instead of allocating a new Vec every frame.
        self.state.state.copy_from_slice(state_data);
        Ok(probability)
    }

    /// Advances the hysteresis state machine with the decision for the frame at index
    /// `self.frame_count`. The caller increments `frame_count` afterwards.
    fn update_state_machine(&mut self, is_speech: bool, probability: f32) {
        match self.session_state {
            VadSessionState::Silence => {
                if is_speech {
                    // Start a fresh candidate segment. Anything still buffered belongs to a
                    // segment that was never collected (e.g. when only `process_frame` is
                    // used), so it is dropped rather than merged into the new one.
                    self.segment_start = None;
                    self.segment_probs.clear();
                    self.segment_probs.push(probability);
                    self.speech_frames = 1;
                    self.silence_frames = 0;
                    if self.config.min_speech_frames <= 1 {
                        self.start_segment();
                    } else {
                        self.session_state = VadSessionState::SpeechPending;
                    }
                }
            }
            VadSessionState::SpeechPending => {
                if is_speech {
                    self.segment_probs.push(probability);
                    self.speech_frames += 1;
                    if self.speech_frames >= self.config.min_speech_frames {
                        self.start_segment();
                    }
                } else {
                    // The speech run was too short: discard it entirely.
                    self.session_state = VadSessionState::Silence;
                    self.speech_frames = 0;
                    self.segment_probs.clear();
                }
            }
            VadSessionState::Speech => {
                self.segment_probs.push(probability);
                if !is_speech {
                    // `silence_frames` counts trailing non-speech frames. It locates the
                    // segment end once the state returns to `Silence`.
                    self.silence_frames = 1;
                    self.session_state = if self.config.min_silence_frames <= 1 {
                        VadSessionState::Silence
                    } else {
                        VadSessionState::SilencePending
                    };
                }
            }
            VadSessionState::SilencePending => {
                self.segment_probs.push(probability);
                if is_speech {
                    self.session_state = VadSessionState::Speech;
                    self.silence_frames = 0;
                } else {
                    self.silence_frames += 1;
                    if self.silence_frames >= self.config.min_silence_frames {
                        // `silence_frames` is kept so the segment end can be computed.
                        self.session_state = VadSessionState::Silence;
                    }
                }
            }
        }
    }

    /// Enters `Speech`, anchoring the segment at the first frame of the current speech run
    /// minus `padding_frames`, clamped at frame 0.
    fn start_segment(&mut self) {
        // `frame_count` is the index of the frame being processed; the run began
        // `speech_frames - 1` frames earlier.
        let first_speech_frame = self
            .frame_count
            .saturating_add(1)
            .saturating_sub(self.speech_frames as u64);
        self.segment_start =
            Some(first_speech_frame.saturating_sub(self.config.padding_frames as u64));
        self.session_state = VadSessionState::Speech;
    }

    /// Computes the exclusive end frame of the open segment.
    ///
    /// The end is one past the last speech frame, i.e. `trailing_silence` frames before
    /// `frame_count`, extended by `padding_frames`. It never extends past the frames
    /// processed so far.
    fn padded_end_frame(&self, trailing_silence: u64) -> u64 {
        self.frame_count
            .saturating_sub(trailing_silence)
            .saturating_add(self.config.padding_frames as u64)
            .min(self.frame_count)
    }

    /// Forgets the segment being built and the run counters that belong to it.
    fn clear_segment(&mut self) {
        self.segment_start = None;
        self.segment_probs.clear();
        self.speech_frames = 0;
        self.silence_frames = 0;
    }

    /// Returns and clears the segment if the state machine just closed it.
    ///
    /// A closed segment means the state is `Silence` with an open segment recorded.
    /// Must be called after `process_frame` has incremented `frame_count`.
    fn check_segment_complete(&mut self) -> Option<SpeechSegment> {
        if self.session_state != VadSessionState::Silence
            || self.segment_start.is_none()
            || self.segment_probs.is_empty()
        {
            return None;
        }
        let segment = self.create_segment(self.padded_end_frame(self.silence_frames as u64));
        self.clear_segment();
        Some(segment)
    }

    /// Builds a [`SpeechSegment`] from the recorded start and the given exclusive `end_frame`.
    ///
    /// Frames are converted to milliseconds, and the collected probabilities are averaged
    /// (0.0 if none were collected).
    fn create_segment(&self, end_frame: u64) -> SpeechSegment {
        let start_frame = self.segment_start.unwrap_or(0);
        let frame_ms = self.config.sample_rate.frame_duration_ms();
        let avg_prob = if self.segment_probs.is_empty() {
            0.0
        } else {
            self.segment_probs.iter().sum::<f32>() / self.segment_probs.len() as f32
        };
        SpeechSegment {
            start_ms: start_frame as f32 * frame_ms,
            end_ms: end_frame as f32 * frame_ms,
            start_frame,
            end_frame,
            avg_probability: avg_prob,
        }
    }
}
/// Lightweight energy-based voice-activity detector that needs no model.
///
/// It keeps an exponentially smoothed RMS energy of the incoming sample blocks and reports
/// speech while that smoothed energy is above `threshold`.
#[derive(Debug, Clone)]
pub struct SimpleVad {
    /// Smoothed-energy level above which a block counts as speech.
    threshold: f32,
    /// Weight of the newest block's RMS in the moving average.
    smoothing: f32,
    /// Current exponentially smoothed RMS energy.
    current_energy: f32,
}

impl SimpleVad {
    /// Creates a detector with the given energy `threshold`, a smoothing factor of 0.1
    /// and zero initial energy.
    pub fn new(threshold: f32) -> Self {
        Self {
            threshold,
            smoothing: 0.1,
            current_energy: 0.0,
        }
    }

    /// Feeds one block of samples and returns whether the smoothed energy now exceeds
    /// the threshold.
    ///
    /// An empty block carries no energy information, so it leaves the smoothed energy
    /// unchanged. Previously the 0 / 0 RMS produced NaN, which poisoned the average forever.
    pub fn is_speech(&mut self, samples: &[f32]) -> bool {
        if !samples.is_empty() {
            let mean_square = samples.iter().map(|s| s * s).sum::<f32>() / samples.len() as f32;
            let rms = mean_square.sqrt();
            self.current_energy =
                self.current_energy * (1.0 - self.smoothing) + rms * self.smoothing;
        }
        self.current_energy > self.threshold
    }

    /// Current smoothed RMS energy.
    pub fn energy(&self) -> f32 {
        self.current_energy
    }

    /// Resets the smoothed energy to zero.
    pub fn reset(&mut self) {
        self.current_energy = 0.0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vad_sample_rate() {
        let expected = [
            (VadSampleRate::Rate16k, 512, 64),
            (VadSampleRate::Rate8k, 256, 32),
        ];
        for (rate, frame, context) in expected {
            assert_eq!(rate.frame_size(), frame);
            assert_eq!(rate.context_size(), context);
        }
    }

    #[test]
    fn test_vad_config_defaults() {
        let VadConfig {
            sample_rate,
            threshold,
            ..
        } = VadConfig::default();
        assert_eq!(sample_rate, VadSampleRate::Rate16k);
        assert_eq!(threshold, 0.5);
    }

    #[test]
    fn test_simple_vad() {
        let mut vad = SimpleVad::new(0.01);
        let (silence, loud) = ([0.0f32; 512], [0.5f32; 512]);
        assert!(!vad.is_speech(&silence));
        assert!(vad.is_speech(&loud));
    }

    #[test]
    fn test_vad_state() {
        let VadState { state, context } = VadState::new(64);
        assert_eq!(state.len(), 2 * 128);
        assert_eq!(context.len(), 64);
    }
}