use anyhow::{Context, Result};
use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
use ort::{session::Session, value::TensorRef};
use std::path::{Path, PathBuf};
/// Number of `f32` values in the recurrent state buffer; matches the
/// `[2, 1, 128]` shape used for the `state` tensor.
const STATE_ELEMS: usize = 2 * 1 * 128;

/// Hugging Face repository the model is downloaded from.
const SILERO_REPO: &str = "onnx-community/silero-vad";

/// Path of the ONNX file inside [`SILERO_REPO`].
const SILERO_HF_FILE: &str = "onnx/model.onnx";

/// File name the model is stored under in the local models directory.
const SILERO_LOCAL_FILE: &str = "silero_vad.onnx";
/// Stateful wrapper around a Silero VAD ONNX model.
///
/// The model is recurrent: each inference consumes the previous state and
/// produces a new one, so frames must be fed in order.
pub struct SileroVad {
    /// ONNX Runtime session holding the loaded model.
    session: Session,
    /// Recurrent state passed as the `state` input and replaced by the
    /// `stateN` output after every inference.
    state: Vec<f32>,
    /// Sample rate in Hz, passed to the model as the `sr` input.
    sample_rate: i64,
}
impl SileroVad {
pub fn open(model_path: &Path) -> Result<Self> {
let session = Session::builder()
.context("Failed to create ONNX session builder")?
.with_intra_threads(1)
.context("Failed to set intra-op threads")?
.commit_from_file(model_path)
.with_context(|| {
format!(
"Failed to load Silero VAD model from {}",
model_path.display()
)
})?;
Ok(Self {
session,
state: vec![0.0f32; STATE_ELEMS],
sample_rate: 16000,
})
}
pub fn reset_state(&mut self) {
self.state.fill(0.0);
}
pub fn probability(&mut self, frame: &[f32; 512]) -> Result<f32> {
let sr_buf = [self.sample_rate];
let input_t = TensorRef::from_array_view(([1_usize, 512], frame.as_slice()))
.context("input tensor")?;
let state_t = TensorRef::from_array_view(([2_usize, 1, 128], self.state.as_slice()))
.context("state tensor")?;
let sr_t =
TensorRef::from_array_view(([1_usize], sr_buf.as_slice())).context("sr tensor")?;
let (prob, new_state) = {
let outputs = self
.session
.run(ort::inputs![
"input" => input_t,
"state" => state_t,
"sr" => sr_t
])
.context("VAD inference")?;
let (_, prob_data) = outputs["output"]
.try_extract_tensor::<f32>()
.context("extract output")?;
let prob = prob_data[0];
let (_, state_data) = outputs["stateN"]
.try_extract_tensor::<f32>()
.context("extract stateN")?;
let new_state = state_data.to_vec();
(prob, new_state)
};
self.state = new_state;
Ok(prob)
}
}
pub fn ensure_silero_model(models_dir: &Path) -> Result<PathBuf> {
let target = models_dir.join(SILERO_LOCAL_FILE);
if target.exists() {
return Ok(target);
}
std::fs::create_dir_all(models_dir)
.with_context(|| format!("create models dir {}", models_dir.display()))?;
let api = ApiBuilder::from_env()
.build()
.context("HuggingFace API init")?;
let cached = api
.repo(Repo::new(SILERO_REPO.to_string(), RepoType::Model))
.get(SILERO_HF_FILE)
.with_context(|| format!("download {SILERO_HF_FILE} from {SILERO_REPO}"))?;
std::fs::copy(&cached, &target).with_context(|| format!("copy to {}", target.display()))?;
Ok(target)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::TAU;

    /// Samples per VAD frame; matches the `[f32; 512]` parameter of
    /// `SileroVad::probability`.
    const FRAME_LEN: usize = 512;

    /// Parent directory of this crate's manifest directory, where the tests
    /// look for `models/` and `test_data/`.
    fn repo_root() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .expect("workspace root")
            .to_path_buf()
    }

    /// Directory the tests use to store the downloaded model.
    fn models_dir() -> PathBuf {
        repo_root().join("models")
    }

    /// Ensures the model is present (downloading it if needed) and opens it.
    fn open_vad() -> Result<SileroVad> {
        let model_path = ensure_silero_model(&models_dir())?;
        SileroVad::open(&model_path)
    }

    /// Feeds `samples` to `vad` in consecutive full frames and returns the
    /// mean of the per-frame scores.
    ///
    /// A trailing partial frame is ignored; fewer than one full frame gives `0.0`.
    fn mean_prob(vad: &mut SileroVad, samples: &[f32]) -> Result<f32> {
        let mut frames = 0usize;
        let mut sum = 0.0f32;
        for chunk in samples.chunks_exact(FRAME_LEN) {
            let frame: &[f32; FRAME_LEN] = chunk.try_into().expect("slice length");
            sum += vad.probability(frame)?;
            frames += 1;
        }
        Ok(if frames > 0 { sum / frames as f32 } else { 0.0 })
    }

    #[test]
    #[ignore = "requires silero_vad.onnx in ./models — download via ensure_silero_model or wforge convert"]
    fn test_vad_silence() -> Result<()> {
        let mut vad = open_vad()?;
        // One second of digital silence at 16 kHz.
        let silence = vec![0.0f32; 16000];
        let p = mean_prob(&mut vad, &silence)?;
        assert!(p < 0.1, "silence should score < 0.1, got {p:.4}");
        Ok(())
    }

    #[test]
    #[ignore = "requires silero_vad.onnx in ./models — download via ensure_silero_model or wforge convert"]
    fn test_vad_sine_not_speech() -> Result<()> {
        let mut vad = open_vad()?;
        // One second of a 1 kHz tone at half amplitude, sampled at 16 kHz.
        let sine: Vec<f32> = (0..16000)
            .map(|i| (TAU * 1000.0 * i as f32 / 16000.0).sin() * 0.5)
            .collect();
        let p = mean_prob(&mut vad, &sine)?;
        assert!(p < 0.3, "1 kHz sine should score < 0.3, got {p:.4}");
        Ok(())
    }

    /// Reads the `data` chunk of a 32-bit IEEE-float WAV file as little-endian
    /// `f32` samples.
    ///
    /// Channel count and sample rate are not inspected; samples are returned
    /// in file order. A `data` chunk whose declared size runs past the end of
    /// the file is truncated to the bytes actually present.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read, does not start with a
    /// RIFF/WAVE header, has a `fmt ` chunk describing anything other than
    /// 32-bit float samples, or contains no `data` chunk.
    fn read_f32_wav(path: &std::path::Path) -> Result<Vec<f32>> {
        // WAVE format tags: plain IEEE float, and the "extensible" wrapper
        // whose real format code lives in the SubFormat GUID.
        const FORMAT_IEEE_FLOAT: u16 = 3;
        const FORMAT_EXTENSIBLE: u16 = 0xFFFE;

        let bytes = std::fs::read(path).with_context(|| format!("read {}", path.display()))?;
        anyhow::ensure!(
            bytes.len() >= 12 && &bytes[..4] == b"RIFF" && &bytes[8..12] == b"WAVE",
            "{} is not a RIFF/WAVE file",
            path.display()
        );

        // Walk the chunk list that follows the 12-byte RIFF header.
        let mut pos = 12usize;
        let mut payload: Option<&[u8]> = None;
        while let Some(header) = bytes.get(pos..pos.saturating_add(8)) {
            let id = &header[..4];
            let size =
                u32::from_le_bytes([header[4], header[5], header[6], header[7]]) as usize;
            let body_start = pos + 8;
            let body_end = body_start.saturating_add(size);

            if id == b"fmt " {
                let fmt = bytes.get(body_start..body_end).unwrap_or(&[]);
                anyhow::ensure!(
                    fmt.len() >= 16,
                    "truncated 'fmt ' chunk in {}",
                    path.display()
                );
                let mut tag = u16::from_le_bytes([fmt[0], fmt[1]]);
                let bits = u16::from_le_bytes([fmt[14], fmt[15]]);
                // For WAVE_FORMAT_EXTENSIBLE the first two bytes of the
                // SubFormat GUID (offset 24) carry the actual format code.
                if tag == FORMAT_EXTENSIBLE && fmt.len() >= 26 {
                    tag = u16::from_le_bytes([fmt[24], fmt[25]]);
                }
                anyhow::ensure!(
                    tag == FORMAT_IEEE_FLOAT && bits == 32,
                    "expected 32-bit IEEE-float WAV, got format tag {tag} with {bits} bits per sample"
                );
            } else if id == b"data" {
                // Tolerate a declared size that exceeds the file length.
                payload = Some(&bytes[body_start..body_end.min(bytes.len())]);
                break;
            }

            // Chunk bodies are padded to an even number of bytes.
            pos = body_end.saturating_add(size & 1);
        }

        let payload = payload.context("no 'data' chunk in WAV")?;
        Ok(payload
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect())
    }

    #[test]
    #[ignore = "requires silero_vad.onnx in ./models AND test_data/LJ001-0001_16k.wav at repo root"]
    fn test_vad_speech() -> Result<()> {
        let wav = repo_root().join("test_data").join("LJ001-0001_16k.wav");
        assert!(wav.exists(), "speech WAV not found: {}", wav.display());
        let samples = read_f32_wav(&wav)?;
        assert!(!samples.is_empty(), "no samples in {}", wav.display());
        let mut vad = open_vad()?;
        let p = mean_prob(&mut vad, &samples)?;
        assert!(p > 0.7, "speech should score > 0.7, got {p:.4}");
        Ok(())
    }
}