use std::path::Path;
use anyhow::{Context, Result};
use ort::session::Session;
use ort::value::TensorRef;
use parking_lot::Mutex;
fn ort_err(e: impl std::fmt::Display) -> anyhow::Error {
anyhow::anyhow!("{e}")
}
/// File name of the Silero VAD ONNX model.
pub const VAD_MODEL_FILE: &str = "silero_vad.onnx";
/// Sample rate (Hz) passed to the model's `sr` input and used by millisecond-to-sample
/// conversions in this module.
pub const VAD_SAMPLE_RATE: i64 = 16_000;
/// Number of audio samples scored per model invocation.
pub const VAD_FRAME_SAMPLES: usize = 512;
/// Flattened length of the recurrent state, which is handed to the model as a
/// `[2, 1, 128]` tensor.
const VAD_STATE_LEN: usize = 2 * 128;
#[derive(Debug, Clone, Copy)]
pub struct VadConfig {
pub threshold: f32,
pub min_silence_ms: u32,
pub min_speech_ms: u32,
pub speech_pad_ms: u32,
}
impl Default for VadConfig {
fn default() -> Self {
Self {
threshold: 0.5,
min_silence_ms: 500,
min_speech_ms: 250,
speech_pad_ms: 100,
}
}
}
impl VadConfig {
fn ms_to_samples(ms: u32) -> usize {
(VAD_SAMPLE_RATE as usize * ms as usize) / 1000
}
}
/// Silero VAD ONNX model wrapper.
///
/// The ONNX session sits behind a [`Mutex`] so inference can run through `&self`; the lock
/// is held only while a single frame is run and its outputs are read.
pub struct SileroVad {
    session: Mutex<Session>,
}

impl SileroVad {
    /// Loads the Silero VAD ONNX model from `model_path`.
    ///
    /// # Errors
    /// Returns an error if the ONNX Runtime session builder cannot be created or the model
    /// file cannot be loaded.
    pub fn load(model_path: &Path) -> Result<Self> {
        let session = Session::builder()
            .map_err(ort_err)?
            .commit_from_file(model_path)
            .map_err(ort_err)
            .with_context(|| format!("Failed to load VAD model {}", model_path.display()))?;
        tracing::info!("VAD model loaded from {}", model_path.display());
        Ok(Self {
            session: Mutex::new(session),
        })
    }

    /// Runs one inference step and returns the model's speech probability for `frame`.
    ///
    /// `frame` is copied into a zeroed buffer of `VAD_FRAME_SAMPLES` samples, so a shorter
    /// (final) frame is zero-padded and samples beyond `VAD_FRAME_SAMPLES` are ignored.
    /// `state` is fed to the model's `"state"` input and overwritten with its `"stateN"`
    /// output.
    ///
    /// NOTE(review): the upstream Silero VAD v5 Python wrapper prepends the last 64
    /// samples of the previous chunk (at 16 kHz) to each 512-sample frame; this code feeds
    /// the bare frame. Confirm which input layout the bundled model expects.
    ///
    /// # Errors
    /// Fails if a tensor cannot be built, inference fails, or the model's outputs are
    /// missing or have an unexpected size.
    fn run_frame(&self, frame: &[f32], state: &mut [f32; VAD_STATE_LEN]) -> Result<f32> {
        let mut input = [0.0f32; VAD_FRAME_SAMPLES];
        let n = frame.len().min(VAD_FRAME_SAMPLES);
        input[..n].copy_from_slice(&frame[..n]);
        let input_t = TensorRef::from_array_view(([1_usize, VAD_FRAME_SAMPLES], input.as_slice()))?;
        let state_t = TensorRef::from_array_view(([2_usize, 1, 128], state.as_slice()))?;
        let sr_t = TensorRef::from_array_view(([1_usize], [VAD_SAMPLE_RATE].as_slice()))?;
        let prob = {
            let mut session = self.session.lock();
            let outputs = session
                .run(ort::inputs![
                    "input" => input_t,
                    "state" => state_t,
                    "sr" => sr_t,
                ])
                .context("VAD model inference failed")?;
            let (_, new_state) = outputs["stateN"]
                .try_extract_tensor::<f32>()
                .context("failed to extract VAD state")?;
            if new_state.len() != VAD_STATE_LEN {
                anyhow::bail!(
                    "unexpected VAD state length {} (expected {VAD_STATE_LEN})",
                    new_state.len()
                );
            }
            state.copy_from_slice(new_state);
            let (_, prob) = outputs["output"]
                .try_extract_tensor::<f32>()
                .context("failed to extract VAD probability")?;
            // An empty probability tensor means the model does not match what this code
            // expects; report it like the state-length mismatch above instead of guessing
            // "no speech".
            prob.first()
                .copied()
                .context("VAD model returned an empty probability tensor")?
        };
        Ok(prob)
    }

    /// Scores `samples` in consecutive `VAD_FRAME_SAMPLES`-sample frames and returns one
    /// probability per frame (a trailing partial frame is zero-padded and still scored).
    ///
    /// The recurrent state starts zeroed and is carried across frames of this call only.
    ///
    /// # Errors
    /// Propagates the first inference error.
    pub fn frame_probs(&self, samples: &[f32]) -> Result<Vec<f32>> {
        let mut state = [0.0f32; VAD_STATE_LEN];
        samples
            .chunks(VAD_FRAME_SAMPLES)
            .map(|frame| self.run_frame(frame, &mut state))
            .collect()
    }

    /// Detects speech in `samples` and returns half-open `(start, end)` sample ranges,
    /// post-processed with `cfg` via [`regions_from_probs`].
    ///
    /// # Errors
    /// Propagates inference errors from [`SileroVad::frame_probs`].
    pub fn speech_regions(&self, samples: &[f32], cfg: &VadConfig) -> Result<Vec<(usize, usize)>> {
        let probs = self.frame_probs(samples)?;
        Ok(regions_from_probs(
            &probs,
            VAD_FRAME_SAMPLES,
            samples.len(),
            cfg,
        ))
    }
}
/// Converts per-frame speech probabilities into padded, half-open `(start, end)` speech
/// regions measured in samples.
///
/// Frame `i` is taken to start at sample `i * frame_samples`. Every boundary is clamped to
/// `total_samples`, so a `probs` slice covering more audio than `total_samples` cannot
/// yield a region past the end or with `start > end`.
///
/// Processing steps:
/// 1. Collect maximal runs of frames with `prob >= cfg.threshold`; a run still open after
///    the last frame extends to `total_samples`.
/// 2. Merge consecutive runs separated by less than `cfg.min_silence_ms` of silence.
/// 3. Drop regions shorter than `cfg.min_speech_ms`.
/// 4. Widen each region by `cfg.speech_pad_ms` on both sides (clamped to
///    `0..=total_samples`) and merge regions that now touch or overlap.
///
/// Returns an empty vector when `probs` is empty or `total_samples` is 0.
pub fn regions_from_probs(
    probs: &[f32],
    frame_samples: usize,
    total_samples: usize,
    cfg: &VadConfig,
) -> Vec<(usize, usize)> {
    if probs.is_empty() || total_samples == 0 {
        return Vec::new();
    }
    let min_silence = VadConfig::ms_to_samples(cfg.min_silence_ms);
    let min_speech = VadConfig::ms_to_samples(cfg.min_speech_ms);
    let pad = VadConfig::ms_to_samples(cfg.speech_pad_ms);
    // Start sample of frame `i`, never past the end of the audio.
    let frame_start = |i: usize| i.saturating_mul(frame_samples).min(total_samples);

    // 1. Raw runs of speech frames; empty runs (possible only after clamping) are skipped.
    let mut runs: Vec<(usize, usize)> = Vec::new();
    let mut run_start: Option<usize> = None;
    for (i, &p) in probs.iter().enumerate() {
        let speech = p >= cfg.threshold;
        match (speech, run_start) {
            (true, None) => run_start = Some(frame_start(i)),
            (false, Some(s)) => {
                run_start = None;
                let e = frame_start(i);
                if s < e {
                    runs.push((s, e));
                }
            }
            _ => {}
        }
    }
    if let Some(s) = run_start {
        if s < total_samples {
            runs.push((s, total_samples));
        }
    }

    // 2. Bridge silences shorter than `min_silence`.
    let mut merged: Vec<(usize, usize)> = Vec::with_capacity(runs.len());
    for (s, e) in runs {
        match merged.last_mut() {
            Some(last) if s.saturating_sub(last.1) < min_silence => last.1 = e,
            _ => merged.push((s, e)),
        }
    }

    // 3. Discard short blips; every region has s < e here, so the subtraction is safe.
    merged.retain(|&(s, e)| e - s >= min_speech);

    // 4. Pad, clamp, and fuse regions that the padding made overlap.
    let mut padded: Vec<(usize, usize)> = Vec::with_capacity(merged.len());
    for (s, e) in merged {
        let ps = s.saturating_sub(pad);
        let pe = e.saturating_add(pad).min(total_samples);
        match padded.last_mut() {
            Some(last) if ps <= last.1 => last.1 = last.1.max(pe),
            _ => padded.push((ps, pe)),
        }
    }
    padded
}
/// Maps a time on the "compressed" timeline (all `regions` laid end to end) back to the
/// corresponding time in the original audio.
///
/// `regions` are `(start, end)` sample ranges in the original audio (each with
/// `end >= start`), walked in slice order; `t_compressed_s` is in seconds and is converted
/// to samples with `sample_rate`.
///
/// - No regions: the input time is returned unchanged.
/// - Negative times are treated as 0.
/// - A time landing exactly on a region boundary maps to the end of the earlier region.
/// - Times past the compressed length clamp to the end of the last region.
pub fn remap_compressed_seconds(
    t_compressed_s: f64,
    regions: &[(usize, usize)],
    sample_rate: f64,
) -> f64 {
    let Some(&(_, last_end)) = regions.last() else {
        return t_compressed_s;
    };
    let wanted = (t_compressed_s * sample_rate).max(0.0);
    let mut consumed = 0.0f64;
    for &(start, end) in regions {
        let region_limit = consumed + (end - start) as f64;
        if wanted <= region_limit {
            let offset = (wanted - consumed).max(0.0);
            return (start as f64 + offset) / sample_rate;
        }
        consumed = region_limit;
    }
    last_end as f64 / sample_rate
}
/// Streaming end-of-utterance detector built on [`SileroVad`].
///
/// Audio arrives in arbitrarily sized chunks through [`VadEndpointer::push`]. It is buffered
/// until whole `VAD_FRAME_SAMPLES` frames are available, each frame is scored by the model,
/// and the probabilities drive a [`Hangover`] that decides when an utterance has ended.
pub struct VadEndpointer {
    /// Recurrent model state carried across frames and across calls to `push`.
    state: [f32; VAD_STATE_LEN],
    /// Samples received but not yet scored.
    leftover: Vec<f32>,
    hangover: Hangover,
}

impl VadEndpointer {
    /// Creates an endpointer with zeroed model state and an empty buffer.
    pub fn new(cfg: &VadConfig) -> Self {
        Self {
            state: [0.0f32; VAD_STATE_LEN],
            leftover: Vec::with_capacity(VAD_FRAME_SAMPLES),
            hangover: Hangover::new(cfg),
        }
    }

    /// Appends `samples`, scores every complete frame now buffered, and keeps the remainder
    /// for the next call.
    ///
    /// Returns `true` if an endpoint fired on any frame scored during this call.
    ///
    /// # Errors
    /// Propagates the first inference error. Frames already scored in this call are still
    /// removed from the buffer, so they are not re-scored against an already-advanced model
    /// state on the next call. The failing frame and anything after it stay buffered.
    pub fn push(&mut self, vad: &SileroVad, samples: &[f32]) -> Result<bool> {
        self.leftover.extend_from_slice(samples);
        let mut endpoint = false;
        let mut consumed = 0;
        let mut outcome = Ok(());
        for frame in self.leftover.chunks_exact(VAD_FRAME_SAMPLES) {
            match vad.run_frame(frame, &mut self.state) {
                Ok(prob) => {
                    consumed += VAD_FRAME_SAMPLES;
                    // Non-short-circuiting: the hangover must observe every frame.
                    endpoint |= self.hangover.update(prob, VAD_FRAME_SAMPLES);
                }
                Err(err) => {
                    outcome = Err(err);
                    break;
                }
            }
        }
        self.leftover.drain(..consumed);
        outcome.map(|()| endpoint)
    }
}
/// Silence-hangover state machine that decides when an utterance has ended.
///
/// It reports an endpoint once, after at least `min_silence_ms` of consecutive
/// non-speech frames that follow speech, and re-arms when speech resumes. Silence before
/// the first speech frame never triggers an endpoint.
#[derive(Debug)]
pub struct Hangover {
    /// Probability at or above which a frame counts as speech.
    threshold: f32,
    /// Consecutive silence, in samples, required before an endpoint fires.
    min_silence_samples: usize,
    /// Whether any speech frame has been seen yet.
    seen_speech: bool,
    /// Samples of silence since the last speech frame.
    trailing_silence: usize,
    /// Whether an endpoint may still fire for the current utterance.
    armed: bool,
}

impl Hangover {
    /// Builds the state machine from `cfg.threshold` and `cfg.min_silence_ms`.
    fn new(cfg: &VadConfig) -> Self {
        Self {
            threshold: cfg.threshold,
            min_silence_samples: VadConfig::ms_to_samples(cfg.min_silence_ms),
            seen_speech: false,
            trailing_silence: 0,
            armed: false,
        }
    }

    /// Feeds one frame's speech probability covering `frame_samples` samples.
    ///
    /// Returns `true` exactly on the frame where trailing silence first reaches the
    /// configured minimum after speech; `false` otherwise.
    fn update(&mut self, prob: f32, frame_samples: usize) -> bool {
        if prob >= self.threshold {
            self.seen_speech = true;
            self.armed = true;
            self.trailing_silence = 0;
            return false;
        }
        if !self.seen_speech {
            return false;
        }
        // Silence keeps accumulating after the endpoint fires (while disarmed), so saturate
        // rather than risk an overflow panic in debug builds during very long silences.
        self.trailing_silence = self.trailing_silence.saturating_add(frame_samples);
        if self.armed && self.trailing_silence >= self.min_silence_samples {
            self.armed = false;
            return true;
        }
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a [`VadConfig`] with every field set explicitly.
    fn cfg(
        threshold: f32,
        min_silence_ms: u32,
        min_speech_ms: u32,
        speech_pad_ms: u32,
    ) -> VadConfig {
        VadConfig {
            threshold,
            min_silence_ms,
            min_speech_ms,
            speech_pad_ms,
        }
    }

    #[test]
    fn test_ms_to_samples_16khz() {
        assert_eq!(VadConfig::ms_to_samples(1000), 16000);
        assert_eq!(VadConfig::ms_to_samples(500), 8000);
        assert_eq!(VadConfig::ms_to_samples(0), 0);
    }

    #[test]
    fn test_regions_empty_probs_is_empty() {
        let c = VadConfig::default();
        assert!(regions_from_probs(&[], 512, 0, &c).is_empty());
        assert!(regions_from_probs(&[0.9, 0.9], 512, 0, &c).is_empty());
    }

    #[test]
    fn test_regions_all_silence_is_empty() {
        let c = cfg(0.5, 0, 0, 0);
        let probs = [0.1f32; 10];
        assert!(regions_from_probs(&probs, 512, 10 * 512, &c).is_empty());
    }

    #[test]
    fn test_regions_single_block_no_pad_no_mins() {
        let c = cfg(0.5, 0, 0, 0);
        let probs = [0.1, 0.9, 0.9, 0.1];
        let r = regions_from_probs(&probs, 100, 400, &c);
        assert_eq!(r, vec![(100, 300)]);
    }

    #[test]
    fn test_regions_trailing_speech_clamps_to_total() {
        let c = cfg(0.5, 0, 0, 0);
        let probs = [0.1, 0.9, 0.9];
        let r = regions_from_probs(&probs, 100, 250, &c);
        assert_eq!(r, vec![(100, 250)]);
    }

    #[test]
    fn test_regions_min_silence_merges_short_gap() {
        // 100 ms = 1600 samples, so the single 100-sample silent frame is bridged.
        let c = cfg(0.5, 100, 0, 0);
        let probs = [0.9, 0.1, 0.9];
        let r = regions_from_probs(&probs, 100, 300, &c);
        assert_eq!(r, vec![(0, 300)]);
    }

    #[test]
    fn test_regions_long_gap_keeps_two_regions() {
        let c = cfg(0.5, 0, 0, 0);
        let probs = [0.9, 0.1, 0.1, 0.9];
        let r = regions_from_probs(&probs, 100, 400, &c);
        assert_eq!(r, vec![(0, 100), (300, 400)]);
    }

    #[test]
    fn test_regions_min_speech_drops_short_blip() {
        // 100 ms = 1600 samples, far longer than the 100-sample blip.
        let c = cfg(0.5, 0, 100, 0);
        let probs = [0.1, 0.9, 0.1];
        assert!(regions_from_probs(&probs, 100, 300, &c).is_empty());
    }

    #[test]
    fn test_regions_padding_extends_and_clamps() {
        // 10 ms = 160 samples: (100, 200) grows to (0, 360), with the start clamped at 0.
        let c = cfg(0.5, 0, 0, 10);
        let probs = [0.1, 0.9, 0.1];
        let r = regions_from_probs(&probs, 100, 1000, &c);
        assert_eq!(r, vec![(0, 360)]);
    }

    #[test]
    fn test_regions_padding_merges_overlapping_neighbours() {
        // 50 ms = 800 samples: (0, 100) -> (0, 900) and (300, 400) -> (0, 1200) overlap.
        let c = cfg(0.5, 0, 0, 50);
        let probs = [0.9, 0.1, 0.1, 0.9, 0.1];
        let r = regions_from_probs(&probs, 100, 2000, &c);
        assert_eq!(r, vec![(0, 1200)]);
    }

    #[test]
    fn test_hangover_fires_once_after_min_silence() {
        // 100 ms = 1600 samples: the 4th 512-sample silent frame (2048) crosses it.
        let c = cfg(0.5, 100, 0, 0);
        let mut h = Hangover::new(&c);
        assert!(!h.update(0.9, 512));
        assert!(!h.update(0.1, 512));
        assert!(!h.update(0.1, 512));
        assert!(!h.update(0.1, 512));
        assert!(h.update(0.1, 512));
        // Disarmed after firing: further silence does not fire again.
        assert!(!h.update(0.1, 512));
    }

    #[test]
    fn test_hangover_no_fire_before_any_speech() {
        let c = cfg(0.5, 0, 0, 0);
        let mut h = Hangover::new(&c);
        for _ in 0..10 {
            assert!(!h.update(0.1, 512));
        }
    }

    #[test]
    fn test_hangover_rearms_for_next_utterance() {
        // 50 ms = 800 samples: the 2nd silent frame (1024) crosses it.
        let c = cfg(0.5, 50, 0, 0);
        let mut h = Hangover::new(&c);
        h.update(0.9, 512);
        assert!(!h.update(0.1, 512));
        assert!(h.update(0.1, 512));
        // New speech re-arms the endpoint for the next utterance.
        assert!(!h.update(0.9, 512));
        assert!(!h.update(0.1, 512));
        assert!(h.update(0.1, 512));
    }

    #[test]
    fn test_remap_no_regions_is_identity() {
        assert_eq!(remap_compressed_seconds(1.5, &[], 16000.0), 1.5);
    }

    #[test]
    fn test_remap_single_region_offsets_by_start() {
        let regions = [(16000usize, 32000usize)];
        assert_eq!(remap_compressed_seconds(0.0, &regions, 16000.0), 1.0);
        assert_eq!(remap_compressed_seconds(0.5, &regions, 16000.0), 1.5);
    }

    #[test]
    fn test_remap_second_region_skips_silence_gap() {
        let regions = [(0usize, 16000usize), (48000usize, 64000usize)];
        assert_eq!(remap_compressed_seconds(0.5, &regions, 16000.0), 0.5);
        assert_eq!(remap_compressed_seconds(1.5, &regions, 16000.0), 3.5);
    }

    #[test]
    fn test_remap_past_end_clamps_to_last_region_end() {
        let regions = [(0usize, 16000usize), (48000usize, 64000usize)];
        assert_eq!(remap_compressed_seconds(10.0, &regions, 16000.0), 4.0);
    }

    #[test]
    #[ignore = "requires the Silero VAD model at ~/.gigastt/models/vad/silero_vad.onnx"]
    fn test_silero_silence_low_prob_and_runs() {
        let home = std::env::var("HOME").expect("HOME");
        let path = std::path::PathBuf::from(home)
            .join(".gigastt/models/vad")
            .join(VAD_MODEL_FILE);
        let vad = SileroVad::load(&path).expect("load silero");

        // One second of digital silence: every probability must be valid and low.
        let silence = vec![0.0f32; 16000];
        let probs = vad.frame_probs(&silence).expect("frame_probs");
        assert!(!probs.is_empty(), "expected at least one frame");
        for p in &probs {
            assert!((0.0..=1.0).contains(p), "prob {p} out of range");
        }
        let max_silence = probs.iter().copied().fold(0.0f32, f32::max);
        assert!(
            max_silence < 0.5,
            "silence should be below threshold, got {max_silence}"
        );

        // A pure tone only needs to produce in-range probabilities.
        let tone: Vec<f32> = (0..16000)
            .map(|i| 0.5 * (2.0 * std::f32::consts::PI * 200.0 * i as f32 / 16000.0).sin())
            .collect();
        let probs2 = vad.frame_probs(&tone).expect("frame_probs tone");
        for p in &probs2 {
            assert!((0.0..=1.0).contains(p), "tone prob {p} out of range");
        }

        assert!(
            vad.speech_regions(&silence, &VadConfig::default())
                .expect("regions")
                .is_empty()
        );
    }
}