1use crate::error::{Result, WhisperError};
7use std::path::Path;
8use whisper_cpp_plus_sys as ffi;
9
/// Tunable parameters for turning VAD speech probabilities into segments.
///
/// The fields mirror `whisper_vad_params` from the native library one-to-one
/// (see `to_ffi`). Units follow the field-name suffixes (`_ms`, `_s`).
#[derive(Debug, Clone, PartialEq)]
pub struct VadParams {
    /// Detection threshold. [`VadParamsBuilder::threshold`] clamps this to
    /// `0.0..=1.0`; setting the field directly performs no validation.
    pub threshold: f32,
    /// Minimum speech duration, in milliseconds (per the field name).
    pub min_speech_duration_ms: i32,
    /// Minimum silence duration, in milliseconds (per the field name).
    pub min_silence_duration_ms: i32,
    /// Maximum speech duration, in seconds (per the field name).
    pub max_speech_duration_s: f32,
    /// Padding applied around speech, in milliseconds (per the field name).
    pub speech_pad_ms: i32,
    /// Overlap between samples, forwarded unchanged to the native library.
    // NOTE(review): the unit (seconds vs. fraction) is not visible here — confirm upstream.
    pub samples_overlap: f32,
}
26
27impl Default for VadParams {
28 fn default() -> Self {
29 let default_params = unsafe { ffi::whisper_vad_default_params() };
31
32 Self {
33 threshold: default_params.threshold,
34 min_speech_duration_ms: default_params.min_speech_duration_ms,
35 min_silence_duration_ms: default_params.min_silence_duration_ms,
36 max_speech_duration_s: default_params.max_speech_duration_s,
37 speech_pad_ms: default_params.speech_pad_ms,
38 samples_overlap: default_params.samples_overlap,
39 }
40 }
41}
42
43impl VadParams {
44 fn to_ffi(&self) -> ffi::whisper_vad_params {
46 ffi::whisper_vad_params {
47 threshold: self.threshold,
48 min_speech_duration_ms: self.min_speech_duration_ms,
49 min_silence_duration_ms: self.min_silence_duration_ms,
50 max_speech_duration_s: self.max_speech_duration_s,
51 speech_pad_ms: self.speech_pad_ms,
52 samples_overlap: self.samples_overlap,
53 }
54 }
55}
56
/// Parameters used when creating the native VAD context.
///
/// The fields mirror `whisper_vad_context_params` from the native library
/// one-to-one (see `to_ffi`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VadContextParams {
    /// Number of threads handed to the native library.
    pub n_threads: i32,
    /// Whether the native library should use a GPU.
    pub use_gpu: bool,
    /// GPU device index handed to the native library.
    pub gpu_device: i32,
}
67
68impl Default for VadContextParams {
69 fn default() -> Self {
70 let default_params = unsafe { ffi::whisper_vad_default_context_params() };
71
72 Self {
73 n_threads: default_params.n_threads,
74 use_gpu: default_params.use_gpu,
75 gpu_device: default_params.gpu_device,
76 }
77 }
78}
79
80impl VadContextParams {
81 fn to_ffi(&self) -> ffi::whisper_vad_context_params {
83 ffi::whisper_vad_context_params {
84 n_threads: self.n_threads,
85 use_gpu: self.use_gpu,
86 gpu_device: self.gpu_device,
87 }
88 }
89}
90
91pub struct WhisperVadProcessor {
93 ctx: *mut ffi::whisper_vad_context,
94}
95
96unsafe impl Send for WhisperVadProcessor {}
97unsafe impl Sync for WhisperVadProcessor {}
98
99impl Drop for WhisperVadProcessor {
100 fn drop(&mut self) {
101 unsafe {
102 if !self.ctx.is_null() {
103 ffi::whisper_vad_free(self.ctx);
104 }
105 }
106 }
107}
108
109impl WhisperVadProcessor {
110 pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
112 Self::new_with_params(model_path, VadContextParams::default())
113 }
114
115 pub fn new_with_params<P: AsRef<Path>>(
117 model_path: P,
118 params: VadContextParams,
119 ) -> Result<Self> {
120 let path_str = model_path
121 .as_ref()
122 .to_str()
123 .ok_or_else(|| WhisperError::ModelLoadError("Invalid path".into()))?;
124
125 let c_path = std::ffi::CString::new(path_str)?;
126
127 let ctx = unsafe {
128 ffi::whisper_vad_init_from_file_with_params(c_path.as_ptr(), params.to_ffi())
129 };
130
131 if ctx.is_null() {
132 return Err(WhisperError::ModelLoadError(
133 "Failed to load VAD model".into(),
134 ));
135 }
136
137 Ok(Self { ctx })
138 }
139
140 pub fn detect_speech(&mut self, samples: &[f32]) -> bool {
142 if samples.is_empty() {
143 return false;
144 }
145
146 unsafe { ffi::whisper_vad_detect_speech(self.ctx, samples.as_ptr(), samples.len() as i32) }
147 }
148
149 pub fn n_probs(&self) -> i32 {
151 unsafe { ffi::whisper_vad_n_probs(self.ctx) }
152 }
153
154 pub fn get_probs(&self) -> Vec<f32> {
156 let n = self.n_probs();
157 if n == 0 {
158 return Vec::new();
159 }
160
161 let probs_ptr = unsafe { ffi::whisper_vad_probs(self.ctx) };
162 if probs_ptr.is_null() {
163 return Vec::new();
164 }
165
166 let slice = unsafe { std::slice::from_raw_parts(probs_ptr, n as usize) };
167 slice.to_vec()
168 }
169
170 pub fn segments_from_probs(&mut self, params: &VadParams) -> Result<VadSegments> {
172 let segments_ptr =
173 unsafe { ffi::whisper_vad_segments_from_probs(self.ctx, params.to_ffi()) };
174
175 if segments_ptr.is_null() {
176 return Err(WhisperError::InvalidContext);
177 }
178
179 Ok(VadSegments { ptr: segments_ptr })
180 }
181
182 pub fn segments_from_samples(
184 &mut self,
185 samples: &[f32],
186 params: &VadParams,
187 ) -> Result<VadSegments> {
188 if samples.is_empty() {
189 return Err(WhisperError::InvalidAudioFormat);
190 }
191
192 let segments_ptr = unsafe {
193 ffi::whisper_vad_segments_from_samples(
194 self.ctx,
195 params.to_ffi(),
196 samples.as_ptr(),
197 samples.len() as i32,
198 )
199 };
200
201 if segments_ptr.is_null() {
202 return Err(WhisperError::InvalidContext);
203 }
204
205 Ok(VadSegments { ptr: segments_ptr })
206 }
207}
208
209pub struct VadSegments {
211 ptr: *mut ffi::whisper_vad_segments,
212}
213
214impl Drop for VadSegments {
215 fn drop(&mut self) {
216 unsafe {
217 if !self.ptr.is_null() {
218 ffi::whisper_vad_free_segments(self.ptr);
219 }
220 }
221 }
222}
223
224impl VadSegments {
225 pub fn n_segments(&self) -> i32 {
227 unsafe { ffi::whisper_vad_segments_n_segments(self.ptr) }
228 }
229
230 pub fn get_segment_t0(&self, i_segment: i32) -> f32 {
232 unsafe { ffi::whisper_vad_segments_get_segment_t0(self.ptr, i_segment) / 100.0 }
234 }
235
236 pub fn get_segment_t1(&self, i_segment: i32) -> f32 {
238 unsafe { ffi::whisper_vad_segments_get_segment_t1(self.ptr, i_segment) / 100.0 }
240 }
241
242 pub fn get_all_segments(&self) -> Vec<(f32, f32)> {
244 let n = self.n_segments();
245 let mut segments = Vec::with_capacity(n as usize);
246
247 for i in 0..n {
248 segments.push((self.get_segment_t0(i), self.get_segment_t1(i)));
249 }
250
251 segments
252 }
253
254 pub fn extract_audio_segments(&self, audio: &[f32], sample_rate: f32) -> Vec<Vec<f32>> {
256 let segments = self.get_all_segments();
257 let mut audio_segments = Vec::with_capacity(segments.len());
258
259 for (start, end) in segments {
260 let start_sample = (start * sample_rate) as usize;
261 let end_sample = (end * sample_rate) as usize;
262
263 if start_sample < audio.len() && end_sample <= audio.len() {
264 audio_segments.push(audio[start_sample..end_sample].to_vec());
265 }
266 }
267
268 audio_segments
269 }
270}
271
272pub struct VadParamsBuilder {
274 params: VadParams,
275}
276
277impl VadParamsBuilder {
278 pub fn new() -> Self {
280 Self {
281 params: VadParams::default(),
282 }
283 }
284
285 pub fn threshold(mut self, threshold: f32) -> Self {
287 self.params.threshold = threshold.clamp(0.0, 1.0);
288 self
289 }
290
291 pub fn min_speech_duration_ms(mut self, ms: i32) -> Self {
293 self.params.min_speech_duration_ms = ms.max(0);
294 self
295 }
296
297 pub fn min_silence_duration_ms(mut self, ms: i32) -> Self {
299 self.params.min_silence_duration_ms = ms.max(0);
300 self
301 }
302
303 pub fn max_speech_duration_s(mut self, seconds: f32) -> Self {
305 self.params.max_speech_duration_s = seconds.max(0.0);
306 self
307 }
308
309 pub fn speech_pad_ms(mut self, ms: i32) -> Self {
311 self.params.speech_pad_ms = ms.max(0);
312 self
313 }
314
315 pub fn samples_overlap(mut self, overlap: f32) -> Self {
317 self.params.samples_overlap = overlap.max(0.0);
318 self
319 }
320
321 pub fn build(self) -> VadParams {
323 self.params
324 }
325}
326
327impl Default for VadParamsBuilder {
328 fn default() -> Self {
329 Self::new()
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// The native defaults should be within sane ranges.
    #[test]
    fn test_vad_params_default() {
        let defaults = VadParams::default();

        assert!(defaults.threshold > 0.0 && defaults.threshold < 1.0);
        assert!(defaults.min_speech_duration_ms >= 0);
        assert!(defaults.max_speech_duration_s > 0.0);
    }

    /// In-range values pass through the builder unchanged.
    #[test]
    fn test_vad_params_builder() {
        let builder = VadParamsBuilder::new()
            .threshold(0.6)
            .min_speech_duration_ms(250)
            .min_silence_duration_ms(100)
            .max_speech_duration_s(30.0)
            .speech_pad_ms(100);
        let built = builder.build();

        assert_eq!(built.threshold, 0.6);
        assert_eq!(built.min_speech_duration_ms, 250);
        assert_eq!(built.min_silence_duration_ms, 100);
        assert_eq!(built.max_speech_duration_s, 30.0);
        assert_eq!(built.speech_pad_ms, 100);
    }

    /// Out-of-range values are clamped by the builder.
    #[test]
    fn test_vad_params_builder_clamps() {
        let built = VadParamsBuilder::new()
            .threshold(1.5)
            .min_speech_duration_ms(-100)
            .build();

        assert_eq!(built.threshold, 1.0);
        assert_eq!(built.min_speech_duration_ms, 0);
    }

    /// Loading succeeds when the model file is present; otherwise the test
    /// only logs that it was skipped.
    #[test]
    fn test_vad_processor_creation() {
        let model_path = "tests/models/ggml-silero-vad.bin";

        if !Path::new(model_path).exists() {
            eprintln!("Skipping VAD processor creation test: model not found");
            return;
        }

        let processor = WhisperVadProcessor::new(model_path);
        assert!(processor.is_ok());
    }

    /// Default context parameters are usable and custom ones keep their values.
    #[test]
    fn test_vad_context_params() {
        let defaults = VadContextParams::default();
        assert!(defaults.n_threads > 0);

        let VadContextParams {
            n_threads,
            use_gpu,
            gpu_device,
        } = VadContextParams {
            n_threads: 4,
            use_gpu: true,
            gpu_device: 0,
        };
        assert_eq!(n_threads, 4);
        assert!(use_gpu);
        assert_eq!(gpu_device, 0);
    }
}