use super::onnx;
use crate::error::VadError;
use crate::{ProcessTimings, VadCapabilities, VoiceActivityDetector};
use ndarray::{Array1, Array2, Array3};
use ort::{inputs, session::Session, value::Tensor};
use std::time::{Duration, Instant};
/// Model bytes embedded into the binary at compile time from
/// `$OUT_DIR/silero_vad.onnx`.
// NOTE(review): presumably the build script places the model in `OUT_DIR`;
// confirm against `build.rs`.
const MODEL_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/silero_vad.onnx"));
/// Length of the context buffer allocated by the constructors. Samples from
/// the end of one frame are kept there and prepended to the next frame's
/// model input.
const CONTEXT_SIZE: usize = 64;
/// Size of the last axis of the state array carried between frames; the array
/// has shape `(2, 1, STATE_DIM)`.
const STATE_DIM: usize = 128;
/// Voice activity detector that feeds fixed-size PCM frames to a Silero ONNX
/// model and returns a speech probability per frame.
///
/// The detector is stateful: the model state and a short tail of the previous
/// frame are carried from one `process` call to the next and cleared by
/// `reset`.
pub struct SileroVad {
/// ONNX Runtime session the model was loaded into; run once per frame.
session: Session,
/// Sample rate in Hz chosen at construction (validated to be 8000 or 16000).
sample_rate: u32,
/// Number of samples every frame must contain (256 at 8 kHz, 512 at 16 kHz).
chunk_size: usize,
/// Model state of shape `(2, 1, STATE_DIM)`: sent as the `state` input and
/// overwritten with the `stateN` output after each successful inference.
state: Array3<f32>,
/// Normalized samples kept from the end of the previous frame; they are
/// prepended to the next frame before inference.
context: Vec<f32>,
/// Accumulated time spent converting samples and assembling the model input.
normalize_time: Duration,
/// Accumulated time spent building tensors, running the model and reading
/// its outputs (only added for frames that complete successfully).
onnx_time: Duration,
/// Number of frames processed successfully since construction.
timing_frames: u64,
}
// SAFETY: NOTE(review) — this asserts to the compiler that a `SileroVad` may be
// moved to another thread. Nothing in this file establishes why that is sound
// for the `Session` field; confirm it against the thread-safety guarantees of
// the `ort` version in use. If `Session` is already `Send`, this impl is
// redundant and should be removed so the auto trait is derived instead.
unsafe impl Send for SileroVad {}
impl SileroVad {
pub fn new(sample_rate: u32) -> Result<Self, VadError> {
Self::from_memory(MODEL_BYTES, sample_rate)
}
pub fn from_file<P: AsRef<std::path::Path>>(
path: P,
sample_rate: u32,
) -> Result<Self, VadError> {
Self::validate_sample_rate(sample_rate)?;
let chunk_size = Self::chunk_size_for_rate(sample_rate);
let session = onnx::session_from_file(path)?;
let state = Array3::<f32>::zeros((2, 1, STATE_DIM));
let context = vec![0.0f32; CONTEXT_SIZE];
Ok(Self {
session,
sample_rate,
chunk_size,
state,
context,
normalize_time: Duration::ZERO,
onnx_time: Duration::ZERO,
timing_frames: 0,
})
}
pub fn from_memory(model_bytes: &[u8], sample_rate: u32) -> Result<Self, VadError> {
Self::validate_sample_rate(sample_rate)?;
let chunk_size = Self::chunk_size_for_rate(sample_rate);
let session = onnx::session_from_memory(model_bytes)?;
let state = Array3::<f32>::zeros((2, 1, STATE_DIM));
let context = vec![0.0f32; CONTEXT_SIZE];
Ok(Self {
session,
sample_rate,
chunk_size,
state,
context,
normalize_time: Duration::ZERO,
onnx_time: Duration::ZERO,
timing_frames: 0,
})
}
fn validate_sample_rate(sample_rate: u32) -> Result<(), VadError> {
match sample_rate {
8000 | 16000 => Ok(()),
_ => Err(VadError::InvalidSampleRate(sample_rate)),
}
}
fn chunk_size_for_rate(sample_rate: u32) -> usize {
match sample_rate {
8000 => 256,
16000 => 512,
_ => unreachable!("sample rate validated before calling chunk_size_for_rate"),
}
}
}
impl VoiceActivityDetector for SileroVad {
    /// Reports the sample rate and frame geometry this detector accepts.
    fn capabilities(&self) -> VadCapabilities {
        VadCapabilities {
            sample_rate: self.sample_rate,
            frame_size: self.chunk_size,
            // `chunk_size` is a few hundred samples, so the narrowing cast and
            // the multiplication cannot overflow `u32`.
            frame_duration_ms: (self.chunk_size as u32 * 1000) / self.sample_rate,
        }
    }

    /// Scores one frame of 16-bit PCM and returns the speech probability,
    /// clamped to `[0.0, 1.0]`.
    ///
    /// `samples` must contain exactly `chunk_size` samples recorded at the
    /// sample rate the detector was created with. The model state and the
    /// tail of this frame are stored for the next call.
    ///
    /// # Errors
    ///
    /// - [`VadError::InvalidSampleRate`] if `sample_rate` differs from the
    ///   detector's rate.
    /// - [`VadError::InvalidFrameSize`] if `samples` has the wrong length.
    /// - [`VadError::BackendError`] if a tensor cannot be built, inference
    ///   fails, or a model output is missing or has an unexpected size.
    ///
    /// When an error is returned the stored state and context are left as they
    /// were before the call.
    fn process(&mut self, samples: &[i16], sample_rate: u32) -> Result<f32, VadError> {
        if sample_rate != self.sample_rate {
            return Err(VadError::InvalidSampleRate(sample_rate));
        }
        if samples.len() != self.chunk_size {
            return Err(VadError::InvalidFrameSize {
                got: samples.len(),
                expected: self.chunk_size,
            });
        }

        // The reference Silero wrapper prepends 64 samples of context at
        // 16 kHz but only 32 at 8 kHz; feeding 64 at 8 kHz gives the model an
        // input layout it was not exported for.
        let context_len = if self.sample_rate == 8000 {
            CONTEXT_SIZE / 2
        } else {
            CONTEXT_SIZE
        };
        // i16 PCM -> f32 in [-1.0, 1.0).
        let to_unit = |&s: &i16| f32::from(s) / 32768.0;

        let t_preprocess = Instant::now();
        // The constructors allocate `CONTEXT_SIZE` samples; only the most
        // recent `context_len` of them belong in the model input.
        let carried_from = self.context.len().saturating_sub(context_len);
        let input_size = context_len + samples.len();
        let mut input_data = Vec::with_capacity(input_size);
        input_data.extend_from_slice(&self.context[carried_from..]);
        input_data.extend(samples.iter().map(to_unit));
        self.normalize_time += t_preprocess.elapsed();

        let t_inference = Instant::now();
        let input_array = Array2::from_shape_vec((1, input_size), input_data)
            .map_err(|e| VadError::BackendError(format!("failed to create input array: {e}")))?;
        let input_tensor = Tensor::from_array(input_array)
            .map_err(|e| VadError::BackendError(format!("failed to create input tensor: {e}")))?;
        // The tensor takes ownership of its data, so the stored state is
        // cloned and only replaced once the model has produced a valid one.
        let state_tensor = Tensor::from_array(self.state.clone())
            .map_err(|e| VadError::BackendError(format!("failed to create state tensor: {e}")))?;
        let sr_array = Array1::from_vec(vec![i64::from(self.sample_rate)]);
        let sr_tensor = Tensor::from_array(sr_array)
            .map_err(|e| VadError::BackendError(format!("failed to create sr tensor: {e}")))?;

        let outputs = self
            .session
            .run(inputs![
                "input" => input_tensor,
                "state" => state_tensor,
                "sr" => sr_tensor,
            ])
            .map_err(|e| VadError::BackendError(format!("inference failed: {e}")))?;

        let output = outputs
            .get("output")
            .ok_or_else(|| VadError::BackendError("missing 'output' tensor".into()))?;
        let (_, output_data): (_, &[f32]) = output
            .try_extract_tensor()
            .map_err(|e| VadError::BackendError(format!("failed to extract output: {e}")))?;
        let probability = *output_data
            .first()
            .ok_or_else(|| VadError::BackendError("empty output tensor".into()))?;

        let new_state = outputs
            .get("stateN")
            .ok_or_else(|| VadError::BackendError("missing 'stateN' tensor".into()))?;
        let (_, new_state_data): (_, &[f32]) = new_state
            .try_extract_tensor()
            .map_err(|e| VadError::BackendError(format!("failed to extract state: {e}")))?;
        if new_state_data.len() != 2 * STATE_DIM {
            return Err(VadError::BackendError(format!(
                "unexpected state size: expected {expected}, got {got}",
                expected = 2 * STATE_DIM,
                got = new_state_data.len()
            )));
        }
        self.state
            .as_slice_mut()
            .ok_or_else(|| VadError::BackendError("state buffer not contiguous".into()))?
            .copy_from_slice(new_state_data);

        // Nothing below can fail, so state and context are updated together.
        // The context keeps the last `context_len` samples of this frame.
        let tail = &samples[samples.len().saturating_sub(context_len)..];
        self.context.clear();
        self.context.extend(tail.iter().map(to_unit));

        self.onnx_time += t_inference.elapsed();
        self.timing_frames += 1;
        Ok(probability.clamp(0.0, 1.0))
    }

    /// Zeroes the model state and the carried-over context so the next frame
    /// is scored as the start of a new stream. Accumulated timings are kept.
    fn reset(&mut self) {
        self.state.fill(0.0);
        self.context.fill(0.0);
    }

    /// Returns the accumulated per-stage durations and the number of frames
    /// processed successfully.
    fn timings(&self) -> ProcessTimings {
        ProcessTimings {
            stages: vec![("normalize", self.normalize_time), ("onnx", self.onnx_time)],
            frames: self.timing_frames,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both supported sample rates construct successfully.
    #[test]
    fn create_with_valid_rates() {
        let vad_16k = SileroVad::new(16000);
        assert!(vad_16k.is_ok());
        let vad_8k = SileroVad::new(8000);
        assert!(vad_8k.is_ok());
    }

    /// Unsupported sample rates are rejected and echoed back in the error.
    #[test]
    fn create_with_invalid_rate() {
        let vad = SileroVad::new(44100);
        assert!(matches!(vad, Err(VadError::InvalidSampleRate(44100))));
        let vad = SileroVad::new(32000);
        assert!(matches!(vad, Err(VadError::InvalidSampleRate(32000))));
        let vad = SileroVad::new(48000);
        assert!(matches!(vad, Err(VadError::InvalidSampleRate(48000))));
    }

    /// A 512-sample frame of digital silence scores low at 16 kHz.
    #[test]
    fn process_silence_16k() {
        let mut vad = SileroVad::new(16000).unwrap();
        let silence = vec![0i16; 512];
        let result = vad.process(&silence, 16000).unwrap();
        assert!(
            result < 0.5,
            "Expected low probability for silence, got {result}"
        );
    }

    /// A 256-sample frame of digital silence scores low at 8 kHz.
    #[test]
    fn process_silence_8k() {
        let mut vad = SileroVad::new(8000).unwrap();
        let silence = vec![0i16; 256];
        let result = vad.process(&silence, 8000).unwrap();
        assert!(
            result < 0.5,
            "Expected low probability for silence, got {result}"
        );
    }

    /// Passing a rate other than the configured one is an error.
    #[test]
    fn process_wrong_sample_rate() {
        let mut vad = SileroVad::new(16000).unwrap();
        let samples = vec![0i16; 512];
        let result = vad.process(&samples, 8000);
        assert!(matches!(result, Err(VadError::InvalidSampleRate(8000))));
    }

    /// A frame of the wrong length is rejected with both sizes reported.
    #[test]
    fn process_invalid_frame_size() {
        let mut vad = SileroVad::new(16000).unwrap();
        let samples = vec![0i16; 100];
        let result = vad.process(&samples, 16000);
        assert!(matches!(
            result,
            Err(VadError::InvalidFrameSize {
                got: 100,
                expected: 512
            })
        ));
    }

    /// The returned probability always lies within `[0.0, 1.0]`.
    #[test]
    fn process_returns_continuous_probability() {
        let mut vad = SileroVad::new(16000).unwrap();
        let samples: Vec<i16> = (0..512).map(|i| (i % 100) as i16 * 50).collect();
        let result = vad.process(&samples, 16000).unwrap();
        assert!((0.0..=1.0).contains(&result));
    }

    /// After `reset`, the detector scores a frame exactly like a fresh one.
    #[test]
    fn reset_clears_state() {
        let mut vad = SileroVad::new(16000).unwrap();
        let samples: Vec<i16> = (0..512).map(|i| i as i16 * 10).collect();
        let _ = vad.process(&samples, 16000).unwrap();
        vad.reset();
        let silence = vec![0i16; 512];
        let result = vad.process(&silence, 16000).unwrap();
        assert!(result < 0.5);

        // A detector that never saw the first frame must agree with the reset
        // one; a small tolerance guards against floating-point noise.
        let mut fresh = SileroVad::new(16000).unwrap();
        let expected = fresh.process(&silence, 16000).unwrap();
        assert!(
            (result - expected).abs() < 1e-5,
            "reset detector returned {result}, fresh detector returned {expected}"
        );
    }

    /// Repeated silent frames keep scoring low across consecutive calls.
    #[test]
    fn state_persists_between_calls() {
        let mut vad = SileroVad::new(16000).unwrap();
        let silence = vec![0i16; 512];
        let prob1 = vad.process(&silence, 16000).unwrap();
        let prob2 = vad.process(&silence, 16000).unwrap();
        let prob3 = vad.process(&silence, 16000).unwrap();
        assert!(prob1 < 0.5);
        assert!(prob2 < 0.5);
        assert!(prob3 < 0.5);
    }

    /// The embedded model loads through the in-memory constructor.
    #[test]
    fn from_memory_with_embedded_model() {
        let vad = SileroVad::from_memory(MODEL_BYTES, 16000);
        assert!(vad.is_ok(), "from_memory failed: {:?}", vad.err());
    }

    /// Bytes that are not an ONNX model surface as a backend error.
    #[test]
    fn from_memory_invalid_bytes() {
        let result = SileroVad::from_memory(b"not a valid onnx model", 16000);
        assert!(result.is_err());
        assert!(matches!(result, Err(VadError::BackendError(_))));
    }

    /// The sample rate is validated even when the model bytes are fine.
    #[test]
    fn from_memory_invalid_sample_rate() {
        let result = SileroVad::from_memory(MODEL_BYTES, 44100);
        assert!(matches!(result, Err(VadError::InvalidSampleRate(44100))));
    }

    /// A missing model file surfaces as a backend error.
    #[test]
    fn from_file_nonexistent() {
        let result = SileroVad::from_file("/nonexistent/model.onnx", 16000);
        assert!(result.is_err());
        assert!(matches!(result, Err(VadError::BackendError(_))));
    }

    /// The embedded model, written to disk, loads through `from_file` and
    /// produces a valid probability.
    #[test]
    fn from_file_with_temp_model() {
        let dir = std::env::temp_dir().join("wavekat_vad_test");
        std::fs::create_dir_all(&dir).unwrap();
        // Include the process id so concurrent test runs do not overwrite or
        // delete each other's model file.
        let model_path = dir.join(format!("silero-vad-test-{}.onnx", std::process::id()));
        std::fs::write(&model_path, MODEL_BYTES).unwrap();
        let result = SileroVad::from_file(&model_path, 16000);
        // Best-effort cleanup as soon as the constructor has returned, so a
        // failing assertion below cannot leave the file behind.
        let _ = std::fs::remove_file(&model_path);
        assert!(result.is_ok(), "from_file failed: {:?}", result.err());
        let mut vad = result.unwrap();
        let silence = vec![0i16; 512];
        let prob = vad.process(&silence, 16000).unwrap();
        assert!((0.0..=1.0).contains(&prob));
    }
}