use crate::error::{Result, WhisperError};
use std::path::Path;
use whisper_cpp_plus_sys as ffi;
/// Tuning parameters for speech-segment detection.
///
/// These values are copied field-for-field into `ffi::whisper_vad_params`
/// whenever they are handed to the native library. `Default` takes its
/// values from `whisper_vad_default_params()`. Use [`VadParamsBuilder`] to
/// override individual fields with range clamping.
#[derive(Debug, Clone, PartialEq)]
pub struct VadParams {
    /// Detection threshold; [`VadParamsBuilder::threshold`] clamps it to `[0.0, 1.0]`.
    pub threshold: f32,
    /// Minimum speech duration, in milliseconds.
    pub min_speech_duration_ms: i32,
    /// Minimum silence duration, in milliseconds.
    pub min_silence_duration_ms: i32,
    /// Maximum speech duration, in seconds.
    pub max_speech_duration_s: f32,
    /// Padding around detected speech, in milliseconds.
    pub speech_pad_ms: i32,
    /// Sample overlap passed straight through to whisper.cpp.
    /// NOTE(review): unit not visible here (presumably seconds) — confirm in `whisper.h`.
    pub samples_overlap: f32,
}
impl Default for VadParams {
    /// Returns the defaults reported by whisper.cpp's `whisper_vad_default_params`.
    fn default() -> Self {
        // SAFETY: `whisper_vad_default_params` takes no arguments. It returns a
        // struct by value whose fields (see `to_ffi`) are plain numbers.
        let ffi::whisper_vad_params {
            threshold,
            min_speech_duration_ms,
            min_silence_duration_ms,
            max_speech_duration_s,
            speech_pad_ms,
            samples_overlap,
        } = unsafe { ffi::whisper_vad_default_params() };

        Self {
            threshold,
            min_speech_duration_ms,
            min_silence_duration_ms,
            max_speech_duration_s,
            speech_pad_ms,
            samples_overlap,
        }
    }
}
impl VadParams {
    /// Copies these parameters into the C representation expected by whisper.cpp.
    fn to_ffi(&self) -> ffi::whisper_vad_params {
        // Every field is `Copy`, so a reference pattern binds them by value.
        let &Self {
            threshold,
            min_speech_duration_ms,
            min_silence_duration_ms,
            max_speech_duration_s,
            speech_pad_ms,
            samples_overlap,
        } = self;

        ffi::whisper_vad_params {
            threshold,
            min_speech_duration_ms,
            min_silence_duration_ms,
            max_speech_duration_s,
            speech_pad_ms,
            samples_overlap,
        }
    }
}
/// Options used when loading a VAD model into a native context.
///
/// These values are copied field-for-field into
/// `ffi::whisper_vad_context_params`. `Default` takes its values from
/// `whisper_vad_default_context_params()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VadContextParams {
    /// Number of threads the native context may use.
    pub n_threads: i32,
    /// Whether the native context should try to use a GPU backend.
    pub use_gpu: bool,
    /// GPU device index; presumably only meaningful when `use_gpu` is true.
    pub gpu_device: i32,
}
impl Default for VadContextParams {
    /// Returns the defaults reported by `whisper_vad_default_context_params`.
    fn default() -> Self {
        // SAFETY: the call takes no arguments. It returns by value a struct
        // whose three fields (see `to_ffi`) are plain values.
        let ffi::whisper_vad_context_params {
            n_threads,
            use_gpu,
            gpu_device,
        } = unsafe { ffi::whisper_vad_default_context_params() };

        Self {
            n_threads,
            use_gpu,
            gpu_device,
        }
    }
}
impl VadContextParams {
    /// Copies these options into the C representation used by
    /// `whisper_vad_init_from_file_with_params`.
    fn to_ffi(&self) -> ffi::whisper_vad_context_params {
        let &Self {
            n_threads,
            use_gpu,
            gpu_device,
        } = self;
        ffi::whisper_vad_context_params {
            n_threads,
            use_gpu,
            gpu_device,
        }
    }
}
/// Owning handle to a loaded whisper.cpp VAD model.
///
/// Create one with [`WhisperVadProcessor::new`] or
/// [`WhisperVadProcessor::new_with_params`]. The native context is freed
/// when the handle is dropped.
pub struct WhisperVadProcessor {
    /// Native context; the constructors reject null, so this is non-null for
    /// every constructed value.
    ctx: *mut ffi::whisper_vad_context,
}
// SAFETY: `WhisperVadProcessor` is the sole owner of `ctx`. The pointer is
// created only in `new_with_params` and freed only in `Drop`, so moving the
// wrapper to another thread transfers exclusive ownership.
// NOTE(review): this assumes whisper.cpp does not bind a `whisper_vad_context`
// to the thread that created it — confirm.
unsafe impl Send for WhisperVadProcessor {}
// SAFETY: the methods that ask the native side to compute something
// (`detect_speech`, `segments_from_probs`, `segments_from_samples`) take
// `&mut self`. Shared references can therefore only reach `n_probs` and
// `get_probs`.
// NOTE(review): this relies on `whisper_vad_n_probs` / `whisper_vad_probs`
// being read-only on the C side — confirm against whisper.cpp.
unsafe impl Sync for WhisperVadProcessor {}
impl Drop for WhisperVadProcessor {
    /// Releases the native VAD context, if one is held.
    fn drop(&mut self) {
        let ctx = self.ctx;
        if ctx.is_null() {
            return;
        }
        // SAFETY: `ctx` is non-null and came from
        // `whisper_vad_init_from_file_with_params`. `drop` runs at most once
        // per value, so the context is freed exactly once.
        unsafe { ffi::whisper_vad_free(ctx) }
    }
}
impl WhisperVadProcessor {
    /// Loads a VAD model from `model_path` using [`VadContextParams::default`].
    ///
    /// # Errors
    ///
    /// Same as [`WhisperVadProcessor::new_with_params`].
    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
        Self::new_with_params(model_path, VadContextParams::default())
    }

    /// Loads a VAD model from `model_path` with explicit context options.
    ///
    /// # Errors
    ///
    /// - `WhisperError::ModelLoadError("Invalid path")` if the path is not valid UTF-8.
    /// - The crate's `From<NulError>` conversion if the path contains an interior NUL byte.
    /// - `WhisperError::ModelLoadError` if the native loader returns a null context.
    pub fn new_with_params<P: AsRef<Path>>(
        model_path: P,
        params: VadContextParams,
    ) -> Result<Self> {
        let path_str = model_path
            .as_ref()
            .to_str()
            .ok_or_else(|| WhisperError::ModelLoadError("Invalid path".into()))?;
        let c_path = std::ffi::CString::new(path_str)?;
        // SAFETY: `c_path` is a valid NUL-terminated string that outlives the call.
        let ctx = unsafe {
            ffi::whisper_vad_init_from_file_with_params(c_path.as_ptr(), params.to_ffi())
        };
        if ctx.is_null() {
            return Err(WhisperError::ModelLoadError(
                "Failed to load VAD model".into(),
            ));
        }
        Ok(Self { ctx })
    }

    /// Converts a slice length into the C `int` sample count.
    ///
    /// Returns `None` when the length does not fit; a plain `as` cast would
    /// wrap and hand the native code a bogus (possibly negative) count.
    fn c_sample_count(samples: &[f32]) -> Option<i32> {
        if samples.len() > i32::MAX as usize {
            None
        } else {
            Some(samples.len() as i32)
        }
    }

    /// Runs the VAD model over `samples` and returns the native detection result.
    ///
    /// Returns `false` without calling into C when `samples` is empty or has
    /// more than `i32::MAX` elements. After a call, [`Self::get_probs`]
    /// exposes whatever probabilities the native side recorded.
    pub fn detect_speech(&mut self, samples: &[f32]) -> bool {
        if samples.is_empty() {
            return false;
        }
        let n_samples = match Self::c_sample_count(samples) {
            Some(n) => n,
            None => return false,
        };
        // SAFETY: `self.ctx` is non-null (checked in `new_with_params`) and
        // `&mut self` gives exclusive access. `samples` is live and holds
        // exactly `n_samples` elements for the duration of the call.
        unsafe { ffi::whisper_vad_detect_speech(self.ctx, samples.as_ptr(), n_samples) }
    }

    /// Number of probabilities currently held by the native context.
    pub fn n_probs(&self) -> i32 {
        // SAFETY: `self.ctx` is non-null for every constructed value.
        unsafe { ffi::whisper_vad_n_probs(self.ctx) }
    }

    /// Copies the native probability buffer into a new `Vec`.
    ///
    /// Returns an empty vector when the native count is zero or negative, or
    /// when the buffer pointer is null.
    pub fn get_probs(&self) -> Vec<f32> {
        let n = self.n_probs();
        // A negative count must not reach `from_raw_parts`: `as usize` would
        // turn it into an enormous length.
        if n <= 0 {
            return Vec::new();
        }
        // SAFETY: `self.ctx` is non-null for every constructed value.
        let probs_ptr = unsafe { ffi::whisper_vad_probs(self.ctx) };
        if probs_ptr.is_null() {
            return Vec::new();
        }
        // SAFETY: the pointer is non-null and `n > 0`.
        // NOTE(review): this assumes whisper.cpp keeps `n_probs()` contiguous
        // floats alive until the next `&mut self` call — confirm.
        let slice = unsafe { std::slice::from_raw_parts(probs_ptr, n as usize) };
        slice.to_vec()
    }

    /// Builds speech segments from the probabilities left by the last
    /// [`Self::detect_speech`] call.
    ///
    /// # Errors
    ///
    /// `WhisperError::InvalidContext` if the native call returns null.
    pub fn segments_from_probs(&mut self, params: &VadParams) -> Result<VadSegments> {
        // SAFETY: `self.ctx` is non-null and exclusively borrowed.
        let segments_ptr = unsafe {
            ffi::whisper_vad_segments_from_probs(self.ctx, params.to_ffi())
        };
        if segments_ptr.is_null() {
            return Err(WhisperError::InvalidContext);
        }
        Ok(VadSegments {
            ptr: segments_ptr,
        })
    }

    /// Runs detection over `samples` and builds speech segments in one native call.
    ///
    /// # Errors
    ///
    /// - `WhisperError::InvalidAudioFormat` if `samples` is empty or longer than `i32::MAX`.
    /// - `WhisperError::InvalidContext` if the native call returns null.
    pub fn segments_from_samples(
        &mut self,
        samples: &[f32],
        params: &VadParams,
    ) -> Result<VadSegments> {
        if samples.is_empty() {
            return Err(WhisperError::InvalidAudioFormat);
        }
        let n_samples = Self::c_sample_count(samples).ok_or(WhisperError::InvalidAudioFormat)?;
        // SAFETY: `self.ctx` is non-null and exclusively borrowed. `samples`
        // is live and holds exactly `n_samples` elements during the call.
        let segments_ptr = unsafe {
            ffi::whisper_vad_segments_from_samples(
                self.ctx,
                params.to_ffi(),
                samples.as_ptr(),
                n_samples,
            )
        };
        if segments_ptr.is_null() {
            return Err(WhisperError::InvalidContext);
        }
        Ok(VadSegments {
            ptr: segments_ptr,
        })
    }
}
/// Owned set of speech segments produced by [`WhisperVadProcessor`].
///
/// The native allocation is released when this value is dropped.
pub struct VadSegments {
    /// Pointer returned by a `whisper_vad_segments_from_*` call; the
    /// constructors reject null, so it is always non-null.
    ptr: *mut ffi::whisper_vad_segments,
}
impl Drop for VadSegments {
    /// Frees the native segment list, if one is held.
    fn drop(&mut self) {
        let segments = self.ptr;
        if segments.is_null() {
            return;
        }
        // SAFETY: `segments` is non-null and came from a
        // `whisper_vad_segments_from_*` call. `drop` runs at most once.
        unsafe { ffi::whisper_vad_free_segments(segments) }
    }
}
impl VadSegments {
    /// Number of detected speech segments.
    pub fn n_segments(&self) -> i32 {
        // SAFETY: `self.ptr` is non-null for every constructed value.
        unsafe { ffi::whisper_vad_segments_n_segments(self.ptr) }
    }

    /// Panics unless `0 <= i_segment < n_segments()`.
    ///
    /// NOTE(review): the native getters are not known to bounds-check, so a
    /// safe method must not forward an arbitrary index.
    fn check_index(&self, i_segment: i32) {
        let n = self.n_segments();
        assert!(
            (0..n).contains(&i_segment),
            "segment index {} out of range (n_segments = {})",
            i_segment,
            n
        );
    }

    /// Start time of segment `i_segment`, in seconds.
    ///
    /// The native value is divided by 100, i.e. it is treated as centiseconds.
    ///
    /// # Panics
    ///
    /// If `i_segment` is not in `0..n_segments()`.
    pub fn get_segment_t0(&self, i_segment: i32) -> f32 {
        self.check_index(i_segment);
        // SAFETY: `self.ptr` is non-null and the index was just range-checked.
        unsafe { ffi::whisper_vad_segments_get_segment_t0(self.ptr, i_segment) / 100.0 }
    }

    /// End time of segment `i_segment`, in seconds (native value / 100).
    ///
    /// # Panics
    ///
    /// If `i_segment` is not in `0..n_segments()`.
    pub fn get_segment_t1(&self, i_segment: i32) -> f32 {
        self.check_index(i_segment);
        // SAFETY: `self.ptr` is non-null and the index was just range-checked.
        unsafe { ffi::whisper_vad_segments_get_segment_t1(self.ptr, i_segment) / 100.0 }
    }

    /// All segments as `(start_seconds, end_seconds)` pairs, in native order.
    pub fn get_all_segments(&self) -> Vec<(f32, f32)> {
        // A range iterator handles a (theoretical) negative count as empty,
        // unlike `Vec::with_capacity(n as usize)`.
        (0..self.n_segments())
            .map(|i| (self.get_segment_t0(i), self.get_segment_t1(i)))
            .collect()
    }

    /// Copies the audio covered by each segment out of `audio`.
    ///
    /// Segment times are converted to sample indices as `seconds * sample_rate`,
    /// truncated. The following rules apply:
    /// - A segment whose end falls past the end of `audio` is clamped to it.
    /// - A segment that starts at or beyond the end of `audio` is skipped.
    /// - A segment whose end precedes its start is skipped.
    ///
    /// The result may therefore contain fewer entries than `n_segments()`.
    pub fn extract_audio_segments(&self, audio: &[f32], sample_rate: f32) -> Vec<Vec<f32>> {
        self.get_all_segments()
            .into_iter()
            .filter_map(|(start, end)| {
                // Float → usize `as` casts saturate: negative/NaN become 0.
                let start_sample = (start * sample_rate) as usize;
                if start_sample >= audio.len() {
                    return None;
                }
                let end_sample = ((end * sample_rate) as usize).min(audio.len());
                // `get` yields None for an inverted range instead of panicking.
                audio.get(start_sample..end_sample).map(<[f32]>::to_vec)
            })
            .collect()
    }
}
/// Fluent builder for [`VadParams`] that clamps each value into a valid range.
///
/// Starts from [`VadParams::default`]; finish with [`VadParamsBuilder::build`].
#[derive(Debug, Clone)]
#[must_use = "a VadParamsBuilder does nothing until `build` is called"]
pub struct VadParamsBuilder {
    params: VadParams,
}
impl VadParamsBuilder {
pub fn new() -> Self {
Self {
params: VadParams::default(),
}
}
pub fn threshold(mut self, threshold: f32) -> Self {
self.params.threshold = threshold.clamp(0.0, 1.0);
self
}
pub fn min_speech_duration_ms(mut self, ms: i32) -> Self {
self.params.min_speech_duration_ms = ms.max(0);
self
}
pub fn min_silence_duration_ms(mut self, ms: i32) -> Self {
self.params.min_silence_duration_ms = ms.max(0);
self
}
pub fn max_speech_duration_s(mut self, seconds: f32) -> Self {
self.params.max_speech_duration_s = seconds.max(0.0);
self
}
pub fn speech_pad_ms(mut self, ms: i32) -> Self {
self.params.speech_pad_ms = ms.max(0);
self
}
pub fn samples_overlap(mut self, overlap: f32) -> Self {
self.params.samples_overlap = overlap.max(0.0);
self
}
pub fn build(self) -> VadParams {
self.params
}
}
impl Default for VadParamsBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Location of the optional Silero VAD model used by model-gated tests.
    const MODEL_PATH: &str = "tests/models/ggml-silero-vad.bin";

    #[test]
    fn test_vad_params_default() {
        let params = VadParams::default();
        assert!(params.threshold > 0.0 && params.threshold < 1.0);
        assert!(params.min_speech_duration_ms >= 0);
        assert!(params.max_speech_duration_s > 0.0);
    }

    #[test]
    fn test_vad_params_builder() {
        let params = VadParamsBuilder::new()
            .threshold(0.6)
            .min_speech_duration_ms(250)
            .min_silence_duration_ms(100)
            .max_speech_duration_s(30.0)
            .speech_pad_ms(100)
            .build();
        assert_eq!(params.threshold, 0.6);
        assert_eq!(params.min_speech_duration_ms, 250);
        assert_eq!(params.min_silence_duration_ms, 100);
        assert_eq!(params.max_speech_duration_s, 30.0);
        assert_eq!(params.speech_pad_ms, 100);
    }

    #[test]
    fn test_vad_params_builder_clamps() {
        let params = VadParamsBuilder::new()
            .threshold(1.5)
            .min_speech_duration_ms(-100)
            .build();
        assert_eq!(params.threshold, 1.0);
        assert_eq!(params.min_speech_duration_ms, 0);
    }

    #[test]
    fn test_vad_params_builder_clamps_lower_bounds() {
        let params = VadParamsBuilder::new()
            .threshold(-0.5)
            .min_silence_duration_ms(-1)
            .max_speech_duration_s(-3.0)
            .speech_pad_ms(-10)
            .samples_overlap(-0.25)
            .build();
        assert_eq!(params.threshold, 0.0);
        assert_eq!(params.min_silence_duration_ms, 0);
        assert_eq!(params.max_speech_duration_s, 0.0);
        assert_eq!(params.speech_pad_ms, 0);
        assert_eq!(params.samples_overlap, 0.0);
    }

    #[test]
    fn test_vad_params_builder_default_matches_params_default() {
        let built = VadParamsBuilder::default().build();
        let direct = VadParams::default();
        assert_eq!(built.threshold, direct.threshold);
        assert_eq!(built.min_speech_duration_ms, direct.min_speech_duration_ms);
        assert_eq!(built.min_silence_duration_ms, direct.min_silence_duration_ms);
        assert_eq!(built.max_speech_duration_s, direct.max_speech_duration_s);
        assert_eq!(built.speech_pad_ms, direct.speech_pad_ms);
        assert_eq!(built.samples_overlap, direct.samples_overlap);
    }

    #[test]
    fn test_vad_processor_creation() {
        if Path::new(MODEL_PATH).exists() {
            let processor = WhisperVadProcessor::new(MODEL_PATH);
            assert!(processor.is_ok());
        } else {
            eprintln!("Skipping VAD processor creation test: model not found");
        }
    }

    #[test]
    fn test_vad_processor_probs_and_segments_on_silence() {
        if !Path::new(MODEL_PATH).exists() {
            eprintln!("Skipping VAD probs/segments test: model not found");
            return;
        }
        let mut processor = WhisperVadProcessor::new(MODEL_PATH).expect("model file exists");

        // Empty input is rejected before reaching native code.
        assert!(!processor.detect_speech(&[]));
        assert!(processor
            .segments_from_samples(&[], &VadParams::default())
            .is_err());

        // One second of silence at 16 kHz: whatever the detection result, the
        // copied probabilities must match the native count.
        let silence = vec![0.0f32; 16_000];
        let _ = processor.detect_speech(&silence);
        let probs = processor.get_probs();
        assert_eq!(probs.len() as i32, processor.n_probs().max(0));

        let segments = processor
            .segments_from_samples(&silence, &VadParams::default())
            .expect("segmentation succeeds on valid input");
        let extracted = segments.extract_audio_segments(&silence, 16_000.0);
        assert!(extracted.len() as i32 <= segments.n_segments());
        assert!(extracted.iter().all(|chunk| chunk.len() <= silence.len()));
    }

    #[test]
    fn test_vad_context_params() {
        let params = VadContextParams::default();
        assert!(params.n_threads > 0);
        let custom_params = VadContextParams {
            n_threads: 4,
            use_gpu: true,
            gpu_device: 0,
        };
        assert_eq!(custom_params.n_threads, 4);
        assert!(custom_params.use_gpu);
        assert_eq!(custom_params.gpu_device, 0);
    }
}