pub(crate) mod cmvn;
pub(crate) mod fbank;
use super::onnx;
use crate::error::VadError;
use crate::{ProcessTimings, VadCapabilities, VoiceActivityDetector};
use cmvn::CmvnStats;
use fbank::FbankExtractor;
use ndarray::Array4;
use ort::{inputs, session::Session, value::TensorRef};
use std::time::{Duration, Instant};
// ONNX model bytes embedded at compile time; the file is generated into
// OUT_DIR by the build script.
const MODEL_BYTES: &[u8] = include_bytes!(concat!(
    env!("OUT_DIR"),
    "/fireredvad_stream_vad_with_cache.onnx"
));
// Kaldi-binary CMVN (cepstral mean/variance normalization) statistics,
// also produced into OUT_DIR by the build script.
const CMVN_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/firered_cmvn.ark"));
// The only input sample rate this detector accepts (see `process`).
const SAMPLE_RATE: u32 = 16000;
// Hop size in samples between consecutive frames; re-exported from the
// fbank module so feature extraction and buffering stay in sync.
const FRAME_SHIFT: usize = fbank::FRAME_SHIFT;
// Analysis window length in samples (400 @ 16 kHz = 25 ms).
const FRAME_LENGTH: usize = 400;
// Number of mel filterbank bins produced per frame.
const N_MEL: usize = 80;
// Dimensions of the model's recurrent cache tensor, in the order used by
// `Array4::zeros((layers, batch, proj, lookback))` below.
// NOTE(review): the semantic names are inferred from the identifiers —
// confirm against the exported model's `caches_in`/`caches_out` spec.
const CACHE_LAYERS: usize = 8;
const CACHE_BATCH: usize = 1;
const CACHE_PROJ: usize = 128;
const CACHE_LOOKBACK: usize = 19;
/// Streaming FireRedVAD voice-activity detector backed by an ONNX model.
///
/// Holds the inference session plus all per-stream state (feature extractor,
/// recurrent model cache, buffered samples) and cumulative stage timings.
pub struct FireRedVad {
    /// ONNX Runtime session running the streaming VAD model.
    session: Session,
    /// Incremental mel-filterbank feature extractor.
    fbank: FbankExtractor,
    /// CMVN statistics applied to each feature frame before inference.
    cmvn: CmvnStats,
    /// Recurrent cache tensor fed back into the model every frame
    /// (shape CACHE_LAYERS x CACHE_BATCH x CACHE_PROJ x CACHE_LOOKBACK).
    caches: Array4<f32>,
    /// Raw input samples (as f32) not yet consumed by feature extraction.
    sample_buffer: Vec<f32>,
    /// Frames processed since construction or the last `reset()`.
    frame_count: usize,
    /// Cumulative time spent in fbank extraction.
    fbank_time: Duration,
    /// Cumulative time spent in CMVN normalization.
    cmvn_time: Duration,
    /// Cumulative time spent in ONNX inference.
    onnx_time: Duration,
    /// Number of frames contributing to the timing accumulators.
    timing_frames: u64,
}
// SAFETY: `FireRedVad` is moved between threads but never shared (`Send`
// only, not `Sync`). All fields except `session` are plain owned data
// (arrays, Vec, counters) and are `Send` by construction.
// NOTE(review): recent `ort` releases mark `Session` as `Send + Sync`, which
// would make this unsafe impl redundant (and if the pinned `ort` version's
// `Session` is *not* `Send`, this impl needs a concrete justification) —
// confirm against the ort version in Cargo.lock and remove if possible.
unsafe impl Send for FireRedVad {}
impl FireRedVad {
pub fn new() -> Result<Self, VadError> {
let cmvn = CmvnStats::from_kaldi_binary(CMVN_BYTES)?;
Self::from_session(onnx::session_from_memory(MODEL_BYTES)?, cmvn)
}
pub fn from_file<P: AsRef<std::path::Path>>(
model_path: P,
cmvn_path: P,
) -> Result<Self, VadError> {
let cmvn_data = std::fs::read(cmvn_path.as_ref()).map_err(|e| {
VadError::BackendError(format!(
"failed to read CMVN file '{}': {e}",
cmvn_path.as_ref().display()
))
})?;
let cmvn = CmvnStats::from_kaldi_binary(&cmvn_data)?;
Self::from_session(onnx::session_from_file(model_path)?, cmvn)
}
pub fn from_memory(model_bytes: &[u8], cmvn_bytes: &[u8]) -> Result<Self, VadError> {
let cmvn = CmvnStats::from_kaldi_binary(cmvn_bytes)?;
Self::from_session(onnx::session_from_memory(model_bytes)?, cmvn)
}
fn from_session(session: Session, cmvn: CmvnStats) -> Result<Self, VadError> {
Ok(Self {
session,
fbank: FbankExtractor::new(),
cmvn,
caches: Array4::<f32>::zeros((CACHE_LAYERS, CACHE_BATCH, CACHE_PROJ, CACHE_LOOKBACK)),
sample_buffer: Vec::with_capacity(FRAME_LENGTH),
frame_count: 0,
fbank_time: Duration::ZERO,
cmvn_time: Duration::ZERO,
onnx_time: Duration::ZERO,
timing_frames: 0,
})
}
fn run_inference(&mut self, features: &[f32; N_MEL]) -> Result<f32, VadError> {
let feat_tensor = TensorRef::from_array_view(([1i64, 1, N_MEL as i64], &features[..]))
.map_err(|e| VadError::BackendError(format!("failed to create feature tensor: {e}")))?;
let cache_tensor = TensorRef::from_array_view(self.caches.view())
.map_err(|e| VadError::BackendError(format!("failed to create cache tensor: {e}")))?;
let outputs = self
.session
.run(inputs![
"feat" => feat_tensor,
"caches_in" => cache_tensor,
])
.map_err(|e| VadError::BackendError(format!("inference failed: {e}")))?;
let probs = outputs
.get("probs")
.ok_or_else(|| VadError::BackendError("missing 'probs' tensor".into()))?;
let (_, probs_data): (_, &[f32]) = probs
.try_extract_tensor()
.map_err(|e| VadError::BackendError(format!("failed to extract probs: {e}")))?;
let probability = *probs_data
.first()
.ok_or_else(|| VadError::BackendError("empty probs tensor".into()))?;
let new_caches = outputs
.get("caches_out")
.ok_or_else(|| VadError::BackendError("missing 'caches_out' tensor".into()))?;
let (_, cache_data): (_, &[f32]) = new_caches
.try_extract_tensor()
.map_err(|e| VadError::BackendError(format!("failed to extract caches: {e}")))?;
let expected_cache_size = CACHE_LAYERS * CACHE_BATCH * CACHE_PROJ * CACHE_LOOKBACK;
if cache_data.len() == expected_cache_size {
self.caches
.as_slice_mut()
.ok_or_else(|| VadError::BackendError("cache buffer not contiguous".into()))?
.copy_from_slice(cache_data);
} else {
return Err(VadError::BackendError(format!(
"unexpected cache size: expected {expected_cache_size}, got {}",
cache_data.len()
)));
}
Ok(probability.clamp(0.0, 1.0))
}
}
impl VoiceActivityDetector for FireRedVad {
    /// Fixed stream parameters: 16 kHz input, `FRAME_SHIFT`-sample hops
    /// (10 ms per frame at this rate).
    fn capabilities(&self) -> VadCapabilities {
        VadCapabilities {
            sample_rate: SAMPLE_RATE,
            frame_size: FRAME_SHIFT,
            frame_duration_ms: (FRAME_SHIFT as u32 * 1000) / SAMPLE_RATE,
        }
    }

    /// Consumes exactly one `FRAME_SHIFT`-sample chunk at `SAMPLE_RATE` and
    /// returns a speech probability in `[0.0, 1.0]`.
    ///
    /// Returns `Ok(0.0)` while the internal buffer is still filling the
    /// initial `FRAME_LENGTH`-sample analysis window; afterwards every call
    /// yields one model inference.
    fn process(&mut self, samples: &[i16], sample_rate: u32) -> Result<f32, VadError> {
        // Guard clauses: reject anything other than the fixed rate/chunk size.
        if sample_rate != SAMPLE_RATE {
            return Err(VadError::InvalidSampleRate(sample_rate));
        }
        if samples.len() != FRAME_SHIFT {
            return Err(VadError::InvalidFrameSize {
                got: samples.len(),
                expected: FRAME_SHIFT,
            });
        }

        self.sample_buffer.extend(samples.iter().map(|&s| s as f32));

        // The very first frame needs a full analysis window; later frames
        // only need one hop of fresh samples.
        let required = if self.frame_count == 0 {
            FRAME_LENGTH
        } else {
            FRAME_SHIFT
        };
        if self.sample_buffer.len() < required {
            return Ok(0.0);
        }

        let fbank_start = Instant::now();
        let mut mel = [0.0f32; N_MEL];
        if self.frame_count == 0 {
            let window: &[f32; FRAME_LENGTH] = self.sample_buffer[..FRAME_LENGTH]
                .try_into()
                .map_err(|_| VadError::BackendError("buffer size mismatch".into()))?;
            self.fbank.extract_frame_full(window, &mut mel);
        } else {
            // NOTE(review): after the initial full window, only the oldest hop
            // is handed to the extractor — this relies on FbankExtractor's
            // internal overlap handling; confirm its streaming contract.
            self.fbank
                .extract_frame(&self.sample_buffer[..FRAME_SHIFT], &mut mel);
        }
        // Both paths advance the buffer by exactly one hop.
        self.sample_buffer.drain(..FRAME_SHIFT);
        self.frame_count += 1;
        self.fbank_time += fbank_start.elapsed();

        let cmvn_start = Instant::now();
        self.cmvn.normalize(&mut mel);
        self.cmvn_time += cmvn_start.elapsed();

        let onnx_start = Instant::now();
        let prob = self.run_inference(&mel);
        self.onnx_time += onnx_start.elapsed();
        self.timing_frames += 1;
        prob
    }

    /// Returns to the pre-first-frame state: fresh extractor state, zeroed
    /// model cache, empty sample buffer. Timing accumulators are not cleared
    /// here — they keep aggregating across resets.
    fn reset(&mut self) {
        self.fbank.reset();
        self.caches.fill(0.0);
        self.sample_buffer.clear();
        self.frame_count = 0;
    }

    /// Cumulative per-stage wall-clock timings plus the frame count they
    /// cover.
    fn timings(&self) -> ProcessTimings {
        let stages = vec![
            ("fbank", self.fbank_time),
            ("cmvn", self.cmvn_time),
            ("onnx", self.onnx_time),
        ];
        ProcessTimings {
            stages,
            frames: self.timing_frames,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the named field of a JSON document as `Vec<i16>`.
    fn json_array_i16(json: &str, key: &str) -> Vec<i16> {
        let value: serde_json::Value = serde_json::from_str(json).unwrap();
        serde_json::from_value(value[key].clone()).unwrap()
    }

    /// Parses the named field of a JSON document as `Vec<f64>`.
    fn json_array_f64(json: &str, key: &str) -> Vec<f64> {
        let value: serde_json::Value = serde_json::from_str(json).unwrap();
        serde_json::from_value(value[key].clone()).unwrap()
    }

    /// Runs full-window (non-streaming) inference over `samples` and returns
    /// the maximum absolute probability difference against `reference`.
    ///
    /// Shared by the Python- and upstream-reference tests, which previously
    /// duplicated this entire loop. Prints the first few frames tagged with
    /// `label` for debugging.
    fn max_diff_vs_reference(samples: &[i16], reference: &[f64], label: &str) -> f64 {
        let cmvn = CmvnStats::from_kaldi_binary(CMVN_BYTES).unwrap();
        let mut session = onnx::session_from_memory(MODEL_BYTES).unwrap();
        let mut fbank = FbankExtractor::new();
        let mut caches =
            Array4::<f32>::zeros((CACHE_LAYERS, CACHE_BATCH, CACHE_PROJ, CACHE_LOOKBACK));
        let num_frames = (samples.len() - FRAME_LENGTH) / FRAME_SHIFT + 1;
        assert_eq!(num_frames, reference.len());
        let mut max_diff: f64 = 0.0;
        for frame_idx in 0..num_frames {
            let start = frame_idx * FRAME_SHIFT;
            let frame_samples: Vec<f32> = samples[start..start + FRAME_LENGTH]
                .iter()
                .map(|&s| s as f32)
                .collect();
            let frame_arr: &[f32; FRAME_LENGTH] = frame_samples.as_slice().try_into().unwrap();
            let mut features = [0.0f32; N_MEL];
            fbank.extract_frame_full(frame_arr, &mut features);
            cmvn.normalize(&mut features);
            let feat_tensor =
                TensorRef::from_array_view(([1i64, 1, N_MEL as i64], &features[..])).unwrap();
            let cache_tensor = TensorRef::from_array_view(caches.view()).unwrap();
            let outputs = session
                .run(inputs![
                    "feat" => feat_tensor,
                    "caches_in" => cache_tensor,
                ])
                .unwrap();
            let probs = outputs.get("probs").unwrap();
            let (_, probs_data): (_, &[f32]) = probs.try_extract_tensor().unwrap();
            let probability = probs_data[0];
            let new_caches = outputs.get("caches_out").unwrap();
            let (_, cache_data): (_, &[f32]) = new_caches.try_extract_tensor().unwrap();
            caches.as_slice_mut().unwrap().copy_from_slice(cache_data);
            let diff = (probability as f64 - reference[frame_idx]).abs();
            max_diff = max_diff.max(diff);
            if frame_idx < 5 {
                eprintln!(
                    "  frame {frame_idx}: rust={probability:.6}, {label}={:.6}, diff={diff:.8}",
                    reference[frame_idx]
                );
            }
        }
        max_diff
    }

    #[test]
    fn create_succeeds() {
        let vad = FireRedVad::new();
        assert!(vad.is_ok(), "Failed to create FireRedVad: {:?}", vad.err());
    }

    #[test]
    fn capabilities() {
        let vad = FireRedVad::new().unwrap();
        let caps = vad.capabilities();
        // Literal expectations on purpose: these are the published contract.
        assert_eq!(caps.sample_rate, 16000);
        assert_eq!(caps.frame_size, 160);
        assert_eq!(caps.frame_duration_ms, 10);
    }

    #[test]
    fn process_silence() {
        let mut vad = FireRedVad::new().unwrap();
        let silence = vec![0i16; 160];
        // The first two calls only fill the 400-sample window; the third
        // produces a real inference result.
        let _ = vad.process(&silence, 16000).unwrap();
        let _ = vad.process(&silence, 16000).unwrap();
        let prob = vad.process(&silence, 16000).unwrap();
        assert!(
            (0.0..=1.0).contains(&prob),
            "Probability out of range: {prob}"
        );
    }

    #[test]
    fn process_wrong_sample_rate() {
        let mut vad = FireRedVad::new().unwrap();
        let samples = vec![0i16; 160];
        let result = vad.process(&samples, 8000);
        assert!(matches!(result, Err(VadError::InvalidSampleRate(8000))));
    }

    #[test]
    fn process_wrong_frame_size() {
        let mut vad = FireRedVad::new().unwrap();
        let samples = vec![0i16; 100];
        let result = vad.process(&samples, 16000);
        assert!(matches!(
            result,
            Err(VadError::InvalidFrameSize {
                got: 100,
                expected: 160
            })
        ));
    }

    #[test]
    fn reset_works() {
        let mut vad = FireRedVad::new().unwrap();
        let samples: Vec<i16> = (0..160).map(|i| (i * 10) as i16).collect();
        let _ = vad.process(&samples, 16000).unwrap();
        let _ = vad.process(&samples, 16000).unwrap();
        let _ = vad.process(&samples, 16000).unwrap();
        vad.reset();
        // After reset the detector must accept input again from scratch.
        let silence = vec![0i16; 160];
        let result = vad.process(&silence, 16000);
        assert!(result.is_ok());
    }

    #[test]
    fn multiple_frames() {
        let mut vad = FireRedVad::new().unwrap();
        let silence = vec![0i16; 160];
        for _ in 0..10 {
            let result = vad.process(&silence, 16000);
            assert!(result.is_ok());
            let prob = result.unwrap();
            assert!((0.0..=1.0).contains(&prob));
        }
    }

    #[test]
    fn from_memory_with_embedded_model() {
        let vad = FireRedVad::from_memory(MODEL_BYTES, CMVN_BYTES);
        assert!(vad.is_ok(), "from_memory failed: {:?}", vad.err());
    }

    #[test]
    fn from_memory_invalid_bytes() {
        let result = FireRedVad::from_memory(b"not a valid onnx model", CMVN_BYTES);
        assert!(result.is_err());
        assert!(matches!(result, Err(VadError::BackendError(_))));
    }

    #[test]
    fn from_file_nonexistent() {
        let result = FireRedVad::from_file("/nonexistent/model.onnx", "/nonexistent/cmvn.ark");
        assert!(result.is_err());
        assert!(matches!(result, Err(VadError::BackendError(_))));
    }

    #[test]
    fn probabilities_match_python_reference() {
        let samples = json_array_i16(
            include_str!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/firered_reference/ref_samples.json"
            )),
            "samples",
        );
        let ref_probs = json_array_f64(
            include_str!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/firered_reference/ref_probs.json"
            )),
            "probs",
        );
        let max_diff = max_diff_vs_reference(&samples, &ref_probs, "python");
        eprintln!("Max probability diff vs Python: {max_diff:.8}");
        assert!(
            max_diff < 0.02,
            "Probability max diff vs Python: {max_diff:.8} (tolerance: 0.02)"
        );
    }

    #[test]
    fn probabilities_match_upstream_fireredvad() {
        let samples = json_array_i16(
            include_str!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/firered_reference/ref_samples.json"
            )),
            "samples",
        );
        let upstream_probs = json_array_f64(
            include_str!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/firered_reference/ref_upstream_probs.json"
            )),
            "probs",
        );
        let max_diff = max_diff_vs_reference(&samples, &upstream_probs, "upstream");
        eprintln!("Max probability diff vs upstream FireRedVAD: {max_diff:.8}");
        assert!(
            max_diff < 0.02,
            "Probability max diff vs upstream: {max_diff:.8} (tolerance: 0.02)"
        );
    }
}