use crate::vad::VadState;
use crate::VoiceConfig;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::Instant;
/// Common interface over the voice-activity detectors in this module.
///
/// Implementors consume raw audio through [`VadBackend::process_samples`] and
/// expose the resulting speech / end-of-turn state through the query methods.
pub trait VadBackend: Send {
    /// Feeds a buffer of audio samples to the detector and updates its state.
    fn process_samples(&mut self, samples: &[f32]);

    /// Reports whether the detector currently considers speech to be present.
    fn is_speech_active(&self) -> bool;

    /// Reports whether the detector considers the current turn to be over.
    fn turn_ended(&self) -> bool;

    /// Discards the detector's speech tracking state.
    fn reset_speech_state(&mut self);

    /// Requests a detection-threshold boost; the value is expressed in dB.
    fn set_threshold_boost(&mut self, db: f32);

    /// Estimated noise floor in dB.
    ///
    /// Backends that do not track a noise floor inherit this fixed `-96.0`.
    fn noise_floor_db(&self) -> f32 {
        -96.0_f32
    }

    /// Whether the backend has finished calibrating.
    ///
    /// Backends without a calibration phase inherit this constant `true`.
    fn is_calibrated(&self) -> bool {
        true
    }
}
/// [`VadBackend`] implementation that forwards every call to a wrapped [`VadState`].
pub struct EnergyVadBackend {
/// The detector that does the actual work.
inner: VadState,
}
impl EnergyVadBackend {
    /// Creates a backend whose inner [`VadState`] is built from `sample_rate`
    /// and `config` via `VadState::from_config`.
    pub fn from_config(sample_rate: u32, config: &VoiceConfig) -> Self {
        let inner = VadState::from_config(sample_rate, config);
        Self { inner }
    }
}
/// Pure delegation: every trait method forwards to the same-named method on
/// the wrapped `VadState`, including the two methods that have trait defaults.
impl VadBackend for EnergyVadBackend {
    // --- mutating operations -------------------------------------------------

    /// Forwards the audio buffer to the inner detector.
    fn process_samples(&mut self, samples: &[f32]) {
        self.inner.process_samples(samples);
    }

    /// Forwards the reset request to the inner detector.
    fn reset_speech_state(&mut self) {
        self.inner.reset_speech_state();
    }

    /// Forwards the boost (in dB) to the inner detector.
    fn set_threshold_boost(&mut self, db: f32) {
        self.inner.set_threshold_boost(db);
    }

    // --- queries -------------------------------------------------------------

    /// Returns the inner detector's speech flag.
    fn is_speech_active(&self) -> bool {
        self.inner.is_speech_active()
    }

    /// Returns the inner detector's end-of-turn flag.
    fn turn_ended(&self) -> bool {
        self.inner.turn_ended()
    }

    /// Returns the inner detector's noise floor instead of the trait default.
    fn noise_floor_db(&self) -> f32 {
        self.inner.noise_floor_db()
    }

    /// Returns the inner detector's calibration flag instead of the trait default.
    fn is_calibrated(&self) -> bool {
        self.inner.is_calibrated()
    }
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
use mlx_rs::ops;
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
use mlx_rs::ops::indexing::IndexOp;
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
use mlx_rs::Array;
/// Hugging Face repository id of the MLX Silero VAD v6 weights; used to locate the local hub cache.
const SILERO_V6_REPO: &str = "mlx-community/silero-vad-v6";
/// Number of trailing samples of the previously processed chunk that are prepended to the next model input.
const SILERO_CONTEXT_SIZE: usize = 64;
/// Number of mirrored samples appended to the right of each model input.
const SILERO_STFT_PAD_RIGHT: usize = 64;
/// Width of the recurrent `h` / `c` state arrays (each is created with shape `[1, SILERO_LSTM_HIDDEN]`).
const SILERO_LSTM_HIDDEN: i32 = 128;
/// Names of the tensors that must be present in the safetensors file.
/// `MlxSileroV6::load` copies exactly these entries and fails with
/// `SileroError::ModelLoad` if any of them is missing.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
const SILERO_V6_KEYS: &[&str] = &[
"vad_16k.stft_conv.weight",
"vad_16k.conv1.weight",
"vad_16k.conv1.bias",
"vad_16k.conv2.weight",
"vad_16k.conv2.bias",
"vad_16k.conv3.weight",
"vad_16k.conv3.bias",
"vad_16k.conv4.weight",
"vad_16k.conv4.bias",
"vad_16k.lstm.Wx",
"vad_16k.lstm.Wh",
"vad_16k.lstm.bias",
"vad_16k.final_conv.weight",
"vad_16k.final_conv.bias",
];
/// [`VadBackend`] that runs the Silero VAD v6 network through MLX.
///
/// Incoming audio is resampled to `SILERO_SAMPLE_RATE`, buffered into
/// `SILERO_CHUNK_SIZE`-sample chunks and scored chunk by chunk.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
pub struct MlxSileroVadBackend {
/// The loaded network together with its recurrent state.
model: MlxSileroV6,
/// Sample rate of the audio handed to `process_samples`; divided by
/// `SILERO_SAMPLE_RATE` to obtain the resampling step.
source_rate: u32,
/// Fractional read position carried over from the end of one input buffer
/// into the next one.
pending_resample_pos: f64,
/// Resampled samples that have not yet formed a complete chunk.
window: Vec<f32>,
/// Last `SILERO_CONTEXT_SIZE` samples of the most recently scored chunk.
context: [f32; SILERO_CONTEXT_SIZE],
/// Probability returned by the most recent successful inference.
last_probability: f32,
/// Whether the most recently scored chunk reached the threshold.
is_speaking: bool,
/// When the latest speech-to-silence transition was observed; `None` while
/// speaking or before any such transition.
silence_onset_at: Option<Instant>,
/// Milliseconds of silence after which `turn_ended` reports `true`.
turn_end_ms: u64,
/// Amount added to `DEFAULT_THRESHOLD` to form the active threshold.
boost: f32,
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
impl MlxSileroVadBackend {
    /// Locates (downloading if necessary) the Silero v6 MLX weights, loads
    /// them and returns a backend in its initial, non-speaking state.
    ///
    /// `source_rate` is the sample rate of the audio that will be passed to
    /// `process_samples`; `config.turn_end_ms` sets the end-of-turn delay.
    ///
    /// # Errors
    ///
    /// Propagates any [`SileroError`] from locating or loading the weights.
    pub fn new(source_rate: u32, config: &VoiceConfig) -> Result<Self, SileroError> {
        let model = find_or_download_mlx_silero_v6()
            .and_then(|weights_path| MlxSileroV6::load(&weights_path))?;
        let backend = Self {
            model,
            source_rate,
            pending_resample_pos: 0.0,
            // Room for two chunks so a partially filled window can absorb
            // another buffer without reallocating right away.
            window: Vec::with_capacity(SILERO_CHUNK_SIZE * 2),
            context: [0.0; SILERO_CONTEXT_SIZE],
            last_probability: 0.0,
            is_speaking: false,
            silence_onset_at: None,
            turn_end_ms: config.turn_end_ms as u64,
            boost: 0.0,
        };
        Ok(backend)
    }

    /// Speech probability a chunk has to reach to count as speech: the
    /// default threshold shifted by the currently configured boost.
    fn current_threshold(&self) -> f32 {
        self.boost + DEFAULT_THRESHOLD
    }
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
impl VadBackend for MlxSileroVadBackend {
    /// Resamples `samples` to `SILERO_SAMPLE_RATE`, scores every complete
    /// `SILERO_CHUNK_SIZE` chunk and updates the speech / silence-onset state
    /// after each chunk.
    ///
    /// Tracking the transition per chunk (rather than once per call) means a
    /// buffer that contains a whole speech burst followed by silence still
    /// records the silence onset, and a stale onset from before the burst is
    /// replaced.
    ///
    /// Empty input and a zero `source_rate` are ignored.
    fn process_samples(&mut self, samples: &[f32]) {
        // A zero source rate yields a resampling step of 0, which would make
        // the loop below spin forever while growing `window`.
        if samples.is_empty() || self.source_rate == 0 {
            return;
        }

        // Linear-interpolation resampler; `src_pos` walks the input in steps
        // of `ratio` source samples per output sample.
        let ratio = self.source_rate as f64 / SILERO_SAMPLE_RATE as f64;
        let mut src_pos = self.pending_resample_pos;
        while src_pos < samples.len() as f64 {
            let lo = src_pos.floor() as usize;
            // Clamp so the last sample is interpolated with itself instead of
            // reading past the end of this buffer.
            let hi = (lo + 1).min(samples.len() - 1);
            let t = (src_pos - lo as f64) as f32;
            self.window.push(samples[lo] * (1.0 - t) + samples[hi] * t);
            src_pos += ratio;
        }
        // Keep the overshoot so the next buffer continues at the right phase.
        self.pending_resample_pos = src_pos - samples.len() as f64;

        let threshold = self.current_threshold();
        while self.window.len() >= SILERO_CHUNK_SIZE {
            let chunk: Vec<f32> = self.window.drain(..SILERO_CHUNK_SIZE).collect();
            let (speaking, failed) = match self.model.predict(&self.context, &chunk) {
                Ok(probability) => {
                    self.last_probability = probability;
                    // The tail of this chunk becomes the context of the next.
                    self.context
                        .copy_from_slice(&chunk[SILERO_CHUNK_SIZE - SILERO_CONTEXT_SIZE..]);
                    (probability >= threshold, false)
                }
                Err(e) => {
                    tracing::warn!("[voice] Silero VAD v6 MLX inference failed: {e}");
                    // A failed inference is treated as "not speaking".
                    (false, true)
                }
            };

            if speaking {
                self.silence_onset_at = None;
            } else if self.is_speaking {
                // Speech -> silence edge: start the end-of-turn timer.
                self.silence_onset_at = Some(Instant::now());
            }
            self.is_speaking = speaking;

            if failed {
                // Leave any remaining buffered audio for the next call.
                break;
            }
        }
    }

    /// Whether the most recently scored chunk was classified as speech.
    fn is_speech_active(&self) -> bool {
        self.is_speaking
    }

    /// `true` once at least `turn_end_ms` milliseconds have passed since the
    /// last recorded speech-to-silence transition; `false` if none is recorded.
    fn turn_ended(&self) -> bool {
        match self.silence_onset_at {
            Some(at) => at.elapsed().as_millis() as u64 >= self.turn_end_ms,
            None => false,
        }
    }

    /// Clears the model's recurrent state, the buffered audio, the carried
    /// context and the resampler phase, and marks the backend as not speaking.
    fn reset_speech_state(&mut self) {
        self.model.reset_state();
        self.is_speaking = false;
        self.silence_onset_at = None;
        self.window.clear();
        self.context = [0.0; SILERO_CONTEXT_SIZE];
        self.pending_resample_pos = 0.0;
    }

    /// Any positive `db` raises the probability threshold by the fixed
    /// `BARGE_IN_PROBABILITY_OFFSET`; zero or negative values remove the
    /// boost. The magnitude of `db` is otherwise ignored.
    fn set_threshold_boost(&mut self, db: f32) {
        self.boost = if db > 0.0 {
            BARGE_IN_PROBABILITY_OFFSET
        } else {
            0.0
        };
    }
}
/// The Silero v6 network as loaded tensors plus its recurrent state.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
struct MlxSileroV6 {
/// Tensors copied from the safetensors file, keyed by the names in `SILERO_V6_KEYS`.
weights: HashMap<String, Array>,
/// Recurrent hidden state, created with shape `[1, SILERO_LSTM_HIDDEN]`.
h: Array,
/// Recurrent cell state, created with shape `[1, SILERO_LSTM_HIDDEN]`.
c: Array,
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
impl MlxSileroV6 {
/// Loads the safetensors file at `path`, keeps the tensors listed in
/// `SILERO_V6_KEYS` and starts with zeroed `h` / `c` state.
///
/// # Errors
///
/// Returns `SileroError::ModelLoad` if the file cannot be loaded, a listed
/// tensor is missing, the state arrays cannot be created, or evaluating the
/// arrays fails.
fn load(path: &Path) -> Result<Self, SileroError> {
let loaded = Array::load_safetensors(path)
.map_err(|e| SileroError::ModelLoad(format!("load {}: {e}", path.display())))?;
let mut weights = HashMap::new();
// Copy only the tensors the forward pass uses; anything else in the file is dropped.
for key in SILERO_V6_KEYS {
let value = loaded
.get(*key)
.ok_or_else(|| SileroError::ModelLoad(format!("missing tensor {key}")))?
.clone();
weights.insert((*key).to_string(), value);
}
let model = Self {
weights,
h: Array::zeros::<f32>(&[1, SILERO_LSTM_HIDDEN])
.map_err(|e| SileroError::ModelLoad(format!("init h: {e}")))?,
c: Array::zeros::<f32>(&[1, SILERO_LSTM_HIDDEN])
.map_err(|e| SileroError::ModelLoad(format!("init c: {e}")))?,
};
// Evaluate weights and state here so a failure is reported as a load error.
mlx_rs::transforms::eval(model.weights.values().chain([&model.h, &model.c]))
.map_err(|e| SileroError::ModelLoad(format!("eval weights: {e}")))?;
Ok(model)
}
/// Replaces `h` and `c` with fresh zero arrays.
///
/// If creating either zero array fails, that state array silently keeps its
/// previous value.
fn reset_state(&mut self) {
if let Ok(h) = Array::zeros::<f32>(&[1, SILERO_LSTM_HIDDEN]) {
self.h = h;
}
if let Ok(c) = Array::zeros::<f32>(&[1, SILERO_LSTM_HIDDEN]) {
self.c = c;
}
}
/// Scores one chunk and returns the first element of the network output as
/// the speech probability.
///
/// The model input is `context` followed by `chunk` followed by
/// `SILERO_STFT_PAD_RIGHT` mirrored samples. The recurrent state is replaced
/// only after the outputs have been evaluated successfully.
///
/// # Errors
///
/// Returns a `SileroError` if the forward pass or the evaluation fails.
///
/// # Panics
///
/// Panics if `chunk` is empty, because the mirror index below then runs past
/// the start of the buffer.
fn predict(
&mut self,
context: &[f32; SILERO_CONTEXT_SIZE],
chunk: &[f32],
) -> Result<f32, SileroError> {
let mut merged =
Vec::with_capacity(SILERO_CONTEXT_SIZE + SILERO_CHUNK_SIZE + SILERO_STFT_PAD_RIGHT);
merged.extend_from_slice(context);
merged.extend_from_slice(chunk);
// `len` is fixed before padding: the mirror reads only original samples,
// walking backwards from the second-to-last one.
let len = merged.len();
for i in 0..SILERO_STFT_PAD_RIGHT {
merged.push(merged[len - 2 - i]);
}
// Input is shaped [1, samples, 1].
let audio = Array::from_slice(&merged, &[1, merged.len() as i32, 1]);
let (prob, h, c) = self.forward(&audio)?;
mlx_rs::transforms::eval([&prob, &h, &c]).map_err(map_mlx)?;
let probability = prob.as_slice::<f32>()[0];
self.h = h;
self.c = c;
Ok(probability)
}
/// Builds the network computation for one padded input and returns
/// `(probability, new_h, new_c)` without modifying `self`.
///
/// Stages, in order: the `stft_conv` convolution, a magnitude computed from
/// its real/imaginary halves, four convolution + ReLU layers, one LSTM step
/// using the stored `h` / `c`, and a final convolution followed by a sigmoid.
///
/// NOTE(review): the parameter name suggests a 640-sample input (context +
/// chunk + right padding) — confirm against the caller.
fn forward(&self, audio_640: &Array) -> Result<(Array, Array, Array), SileroError> {
// Weight lookup; `load` already guarantees every key exists, so the error
// branch is a safeguard.
let w = |key: &str| -> Result<&Array, SileroError> {
self.weights
.get(key)
.ok_or_else(|| SileroError::ModelLoad(format!("missing tensor {key}")))
};
// Positional arguments after the weight follow the same order as the conv
// stack below (stride, padding, ...): here stride 128 and padding 0.
let z = ops::conv1d(
audio_640,
w("vad_16k.stft_conv.weight")?,
128,
0,
1,
None::<i32>,
)
.map_err(map_mlx)?;
// The last axis holds 129 "real" channels followed by 129 "imaginary" ones.
let real = z.index((.., .., 0..129));
let imag = z.index((.., .., 129..258));
// Magnitude: sqrt(real^2 + imag^2 + 1e-12); the small constant keeps the
// argument of the square root strictly positive.
let x2 = ops::add(
&ops::add(
&ops::multiply(&real, &real).map_err(map_mlx)?,
&ops::multiply(&imag, &imag).map_err(map_mlx)?,
)
.map_err(map_mlx)?,
&Array::from_f32(1e-12),
)
.map_err(map_mlx)?;
let mut x = ops::sqrt(&x2).map_err(map_mlx)?;
// Four conv layers (conv1..conv4), each followed by bias add and ReLU.
for (idx, stride, padding) in [(1, 1, 1), (2, 2, 1), (3, 2, 1), (4, 1, 1)] {
x = ops::conv1d(
&x,
w(&format!("vad_16k.conv{idx}.weight"))?,
stride,
padding,
1,
None::<i32>,
)
.map_err(map_mlx)?;
x = ops::add(&x, w(&format!("vad_16k.conv{idx}.bias"))?).map_err(map_mlx)?;
// ReLU expressed as max(x, 0).
x = ops::maximum(&x, &Array::from_f32(0.0)).map_err(map_mlx)?;
}
// Take index 0 along the middle axis as the LSTM input.
// NOTE(review): presumably the conv stack leaves a single time step here — confirm.
let feat = x.index((.., 0, ..));
let wx_t = ops::transpose_axes(w("vad_16k.lstm.Wx")?, &[1, 0]).map_err(map_mlx)?;
let wh_t = ops::transpose_axes(w("vad_16k.lstm.Wh")?, &[1, 0]).map_err(map_mlx)?;
// gates = feat * Wx^T + h * Wh^T + bias
let gates = ops::add(
&ops::add(
&ops::matmul(&feat, &wx_t).map_err(map_mlx)?,
&ops::matmul(&self.h, &wh_t).map_err(map_mlx)?,
)
.map_err(map_mlx)?,
w("vad_16k.lstm.bias")?,
)
.map_err(map_mlx)?;
// Split the last axis into four equal parts, used as input, forget,
// candidate and output gates in that order.
let pieces = ops::split(&gates, 4, -1).map_err(map_mlx)?;
let i_g = ops::sigmoid(&pieces[0]).map_err(map_mlx)?;
let f_g = ops::sigmoid(&pieces[1]).map_err(map_mlx)?;
let g_g = ops::tanh(&pieces[2]).map_err(map_mlx)?;
let o_g = ops::sigmoid(&pieces[3]).map_err(map_mlx)?;
// c_new = f * c + i * g
let c_new = ops::add(
&ops::multiply(&f_g, &self.c).map_err(map_mlx)?,
&ops::multiply(&i_g, &g_g).map_err(map_mlx)?,
)
.map_err(map_mlx)?;
// h_new = o * tanh(c_new)
let h_new = ops::multiply(&o_g, &ops::tanh(&c_new).map_err(map_mlx)?).map_err(map_mlx)?;
// Decoder: ReLU on the new hidden state, reshaped to [1, 1, hidden].
let dec_in = ops::reshape(
&ops::maximum(&h_new, &Array::from_f32(0.0)).map_err(map_mlx)?,
&[1, 1, SILERO_LSTM_HIDDEN],
)
.map_err(map_mlx)?;
let mut dec = ops::conv1d(
&dec_in,
w("vad_16k.final_conv.weight")?,
1,
0,
1,
None::<i32>,
)
.map_err(map_mlx)?;
dec = ops::add(&dec, w("vad_16k.final_conv.bias")?).map_err(map_mlx)?;
// Sigmoid turns the decoder output into a probability.
let prob = ops::sigmoid(&dec).map_err(map_mlx)?;
Ok((prob, h_new, c_new))
}
}
/// Wraps an MLX exception's display text in [`SileroError::Inference`].
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn map_mlx(e: mlx_rs::error::Exception) -> SileroError {
    let message = e.to_string();
    SileroError::Inference(message)
}
/// Returns the path of the Silero VAD v6 MLX weights, downloading them if no
/// local copy exists.
///
/// Lookup order:
/// 1. the newest `model.safetensors` in the Hugging Face hub cache,
/// 2. `model.safetensors` in the manual download directory
///    (`$HF_HOME/hub/models--…/snapshots/manual` when `HF_HOME` is set,
///    otherwise `~/.car/models/silero-vad-v6-mlx`),
/// 3. a fresh download into that directory.
///
/// The download is written to a temporary file and renamed into place, so an
/// interrupted write cannot leave a truncated `model.safetensors` that later
/// runs would accept as valid.
///
/// # Errors
///
/// Returns `SileroError::ModelLoad` if the directory cannot be created, the
/// download fails, or the file cannot be written or renamed.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn find_or_download_mlx_silero_v6() -> Result<PathBuf, SileroError> {
    if let Some(path) = latest_huggingface_repo_file(SILERO_V6_REPO, "model.safetensors") {
        return Ok(path);
    }
    let model_dir = std::env::var("HF_HOME")
        .map(PathBuf::from)
        .map(|hf_home| {
            hf_home
                .join("hub")
                .join(format!("models--{}", SILERO_V6_REPO.replace('/', "--")))
                .join("snapshots")
                .join("manual")
        })
        .unwrap_or_else(|_| {
            dirs::home_dir()
                .unwrap_or_else(|| PathBuf::from("."))
                .join(".car")
                .join("models")
                .join("silero-vad-v6-mlx")
        });
    let model_path = model_dir.join("model.safetensors");
    if model_path.exists() {
        return Ok(model_path);
    }
    std::fs::create_dir_all(&model_dir)
        .map_err(|e| SileroError::ModelLoad(format!("create {}: {e}", model_dir.display())))?;
    let url = "https://huggingface.co/mlx-community/silero-vad-v6/resolve/main/model.safetensors";
    tracing::info!(%url, path = %model_path.display(), "downloading Silero VAD v6 MLX weights");
    let bytes = reqwest::blocking::get(url)
        .and_then(|response| response.error_for_status())
        .map_err(|e| SileroError::ModelLoad(format!("download {url}: {e}")))?
        .bytes()
        .map_err(|e| SileroError::ModelLoad(format!("read {url}: {e}")))?;

    // Write next to the destination (same filesystem, so the rename is a
    // plain move) under a per-process name so concurrent downloads do not
    // clobber each other's partial file.
    let partial_path =
        model_dir.join(format!("model.safetensors.{}.partial", std::process::id()));
    if let Err(e) = std::fs::write(&partial_path, &bytes) {
        // Best effort: do not leave a truncated temp file behind.
        let _ = std::fs::remove_file(&partial_path);
        return Err(SileroError::ModelLoad(format!(
            "write {}: {e}",
            partial_path.display()
        )));
    }
    if let Err(e) = std::fs::rename(&partial_path, &model_path) {
        let _ = std::fs::remove_file(&partial_path);
        return Err(SileroError::ModelLoad(format!(
            "rename {} to {}: {e}",
            partial_path.display(),
            model_path.display()
        )));
    }
    Ok(model_path)
}
/// Looks through the Hugging Face hub cache for `filename` inside any snapshot
/// of `repo_id` and returns the copy with the newest modification time.
///
/// The cache root is `$HF_HOME` when set, otherwise `~/.cache/huggingface`
/// (falling back to `./.cache/huggingface` when no home directory is known).
/// Files whose modification time cannot be read are ranked as oldest. Ties on
/// the timestamp are broken by path ordering.
///
/// Returns `None` when the snapshots directory cannot be read or no snapshot
/// contains the file.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn latest_huggingface_repo_file(repo_id: &str, filename: &str) -> Option<PathBuf> {
    let cache_root = match std::env::var("HF_HOME") {
        Ok(hf_home) => PathBuf::from(hf_home),
        Err(_) => dirs::home_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join(".cache")
            .join("huggingface"),
    };
    let snapshots_dir = cache_root
        .join("hub")
        .join(format!("models--{}", repo_id.replace('/', "--")))
        .join("snapshots");

    std::fs::read_dir(snapshots_dir)
        .ok()?
        .filter_map(Result::ok)
        .filter_map(|snapshot| {
            let candidate = snapshot.path().join(filename);
            if !candidate.exists() {
                return None;
            }
            let modified = candidate
                .metadata()
                .and_then(|meta| meta.modified())
                .unwrap_or(std::time::SystemTime::UNIX_EPOCH);
            Some((modified, candidate))
        })
        // Largest (mtime, path) tuple, i.e. the entry a sort would place last.
        .max()
        .map(|(_, newest)| newest)
}
use voice_activity_detector::VoiceActivityDetector;
/// Sample rate the Silero backends resample incoming audio to before inference.
const SILERO_SAMPLE_RATE: usize = 16_000;
/// Number of resampled samples scored per inference call.
const SILERO_CHUNK_SIZE: usize = 512;
/// Probability at or above which a chunk counts as speech when no boost is active.
const DEFAULT_THRESHOLD: f32 = 0.50;
/// Amount added to `DEFAULT_THRESHOLD` while a positive threshold boost is requested.
const BARGE_IN_PROBABILITY_OFFSET: f32 = 0.30;
/// [`VadBackend`] built on the `voice_activity_detector` crate.
///
/// Incoming audio is resampled to `SILERO_SAMPLE_RATE`, buffered into
/// `SILERO_CHUNK_SIZE`-sample chunks and scored chunk by chunk.
pub struct SileroVadBackend {
/// The detector that produces a speech probability per chunk.
detector: VoiceActivityDetector,
/// Sample rate of the audio handed to `process_samples`; divided by
/// `SILERO_SAMPLE_RATE` to obtain the resampling step.
source_rate: u32,
/// Fractional read position carried over from the end of one input buffer
/// into the next one.
pending_resample_pos: f64,
/// Resampled samples that have not yet formed a complete chunk.
window: Vec<f32>,
/// Probability returned by the most recent inference.
last_probability: f32,
/// Whether the most recently scored chunk reached the threshold.
is_speaking: bool,
/// When the latest speech-to-silence transition was observed; `None` while
/// speaking or before any such transition.
silence_onset_at: Option<Instant>,
/// Milliseconds of silence after which `turn_ended` reports `true`.
turn_end_ms: u64,
/// Amount added to `DEFAULT_THRESHOLD` to form the active threshold.
boost: f32,
}
impl SileroVadBackend {
    /// Builds a detector for `SILERO_SAMPLE_RATE` / `SILERO_CHUNK_SIZE` and
    /// returns a backend in its initial, non-speaking state.
    ///
    /// `source_rate` is the sample rate of the audio that will be passed to
    /// `process_samples`; `config.turn_end_ms` sets the end-of-turn delay.
    ///
    /// # Errors
    ///
    /// Returns [`SileroError::ModelLoad`] if the detector cannot be built.
    pub fn new(source_rate: u32, config: &VoiceConfig) -> Result<Self, SileroError> {
        let built = VoiceActivityDetector::builder()
            .sample_rate(SILERO_SAMPLE_RATE as i64)
            .chunk_size(SILERO_CHUNK_SIZE)
            .build();
        let detector = match built {
            Ok(detector) => detector,
            Err(e) => return Err(SileroError::ModelLoad(e.to_string())),
        };
        let backend = Self {
            detector,
            source_rate,
            pending_resample_pos: 0.0,
            // Room for two chunks so a partially filled window can absorb
            // another buffer without reallocating right away.
            window: Vec::with_capacity(SILERO_CHUNK_SIZE * 2),
            last_probability: 0.0,
            is_speaking: false,
            silence_onset_at: None,
            turn_end_ms: config.turn_end_ms as u64,
            boost: 0.0,
        };
        Ok(backend)
    }

    /// Speech probability a chunk has to reach to count as speech: the
    /// default threshold shifted by the currently configured boost.
    fn current_threshold(&self) -> f32 {
        self.boost + DEFAULT_THRESHOLD
    }
}
impl VadBackend for SileroVadBackend {
    /// Resamples `samples` to `SILERO_SAMPLE_RATE`, scores every complete
    /// `SILERO_CHUNK_SIZE` chunk and updates the speech / silence-onset state
    /// after each chunk.
    ///
    /// Tracking the transition per chunk (rather than once per call) means a
    /// buffer that contains a whole speech burst followed by silence still
    /// records the silence onset, and a stale onset from before the burst is
    /// replaced.
    ///
    /// Empty input and a zero `source_rate` are ignored.
    fn process_samples(&mut self, samples: &[f32]) {
        // A zero source rate yields a resampling step of 0, which would make
        // the loop below spin forever while growing `window`.
        if samples.is_empty() || self.source_rate == 0 {
            return;
        }

        // Linear-interpolation resampler; `src_pos` walks the input in steps
        // of `ratio` source samples per output sample.
        let ratio = self.source_rate as f64 / SILERO_SAMPLE_RATE as f64;
        let mut src_pos = self.pending_resample_pos;
        while src_pos < samples.len() as f64 {
            let lo = src_pos.floor() as usize;
            // Clamp so the last sample is interpolated with itself instead of
            // reading past the end of this buffer.
            let hi = (lo + 1).min(samples.len() - 1);
            let t = (src_pos - lo as f64) as f32;
            self.window.push(samples[lo] * (1.0 - t) + samples[hi] * t);
            src_pos += ratio;
        }
        // Keep the overshoot so the next buffer continues at the right phase.
        self.pending_resample_pos = src_pos - samples.len() as f64;

        let threshold = self.current_threshold();
        while self.window.len() >= SILERO_CHUNK_SIZE {
            let chunk: Vec<f32> = self.window.drain(..SILERO_CHUNK_SIZE).collect();
            self.last_probability = self.detector.predict(chunk);
            let speaking = self.last_probability >= threshold;

            if speaking {
                self.silence_onset_at = None;
            } else if self.is_speaking {
                // Speech -> silence edge: start the end-of-turn timer.
                self.silence_onset_at = Some(Instant::now());
            }
            self.is_speaking = speaking;
        }
    }

    /// Whether the most recently scored chunk was classified as speech.
    fn is_speech_active(&self) -> bool {
        self.is_speaking
    }

    /// `true` once at least `turn_end_ms` milliseconds have passed since the
    /// last recorded speech-to-silence transition; `false` if none is recorded.
    fn turn_ended(&self) -> bool {
        match self.silence_onset_at {
            Some(at) => at.elapsed().as_millis() as u64 >= self.turn_end_ms,
            None => false,
        }
    }

    /// Clears the buffered audio and the resampler phase and marks the backend
    /// as not speaking.
    ///
    /// NOTE(review): any internal state held by `detector` is left untouched
    /// here — confirm whether it should be reset as well.
    fn reset_speech_state(&mut self) {
        self.is_speaking = false;
        self.silence_onset_at = None;
        self.window.clear();
        self.pending_resample_pos = 0.0;
    }

    /// Any positive `db` raises the probability threshold by the fixed
    /// `BARGE_IN_PROBABILITY_OFFSET`; zero or negative values remove the
    /// boost. The magnitude of `db` is otherwise ignored.
    fn set_threshold_boost(&mut self, db: f32) {
        self.boost = if db > 0.0 {
            BARGE_IN_PROBABILITY_OFFSET
        } else {
            0.0
        };
    }
}
/// Errors reported by the Silero-based VAD backends.
#[derive(Debug, thiserror::Error)]
pub enum SileroError {
/// The model could not be obtained or initialised (locating, downloading or
/// loading weights, or building the detector); the payload describes the step that failed.
#[error("failed to load Silero model: {0}")]
ModelLoad(String),
/// Running the model failed; the payload carries the underlying error text.
#[error("Silero inference failed: {0}")]
Inference(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The energy backend can be driven purely through the `VadBackend` API.
    #[test]
    fn energy_backend_round_trips_through_trait() {
        let cfg = VoiceConfig::default();
        let mut be = EnergyVadBackend::from_config(16_000, &cfg);
        assert!(!be.is_speech_active());
        be.set_threshold_boost(18.0);
        be.reset_speech_state();
    }

    /// Construction succeeds for a source rate that needs resampling.
    #[test]
    fn silero_backend_constructs() {
        let cfg = VoiceConfig::default();
        let result = SileroVadBackend::new(44_100, &cfg);
        assert!(result.is_ok(), "silero backend should build");
    }

    /// One second of digital silence must not be classified as speech.
    #[test]
    fn silero_silent_input_is_not_speech() {
        let cfg = VoiceConfig::default();
        let mut be = SileroVadBackend::new(16_000, &cfg).unwrap();
        let silence = vec![0.0_f32; 16_000];
        be.process_samples(&silence);
        assert!(!be.is_speech_active());
    }

    /// A positive boost raises the threshold by the barge-in offset and a
    /// zero boost restores the default threshold.
    #[test]
    fn silero_threshold_boost_clears_to_zero() {
        let cfg = VoiceConfig::default();
        let mut be = SileroVadBackend::new(16_000, &cfg).unwrap();
        assert_eq!(be.current_threshold(), DEFAULT_THRESHOLD);
        be.set_threshold_boost(18.0);
        assert_eq!(
            be.current_threshold(),
            DEFAULT_THRESHOLD + BARGE_IN_PROBABILITY_OFFSET
        );
        be.set_threshold_boost(0.0);
        assert_eq!(be.current_threshold(), DEFAULT_THRESHOLD);
    }

    /// Same silence check for the MLX backend; ignored by default because it
    /// may need network access.
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    #[test]
    #[ignore = "downloads mlx-community/silero-vad-v6 when it is not already cached"]
    fn mlx_silero_v6_silent_input_is_not_speech() {
        let cfg = VoiceConfig::default();
        let mut be = MlxSileroVadBackend::new(16_000, &cfg).unwrap();
        let silence = vec![0.0_f32; 16_000];
        be.process_samples(&silence);
        assert!(!be.is_speech_active());
    }
}