whisper-cpp-plus 0.1.4

Safe Rust bindings for whisper.cpp with real-time PCM streaming and VAD support
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
//! Voice Activity Detection (VAD) support
//!
//! This module provides VAD capabilities for detecting speech segments
//! in audio before transcription, improving performance and accuracy.

use crate::error::{Result, WhisperError};
use std::path::Path;
use whisper_cpp_plus_sys as ffi;

/// Tuning parameters for turning VAD speech probabilities into speech segments.
///
/// [`Default`] takes the values from whisper.cpp's own defaults; use
/// [`VadParamsBuilder`] when you want each value clamped into its valid range.
#[derive(Debug, Clone, PartialEq)]
pub struct VadParams {
    /// Probability threshold to consider as speech (0.0 - 1.0)
    pub threshold: f32,
    /// Minimum duration for a valid speech segment (in milliseconds)
    pub min_speech_duration_ms: i32,
    /// Minimum duration of silence to split segments (in milliseconds)
    pub min_silence_duration_ms: i32,
    /// Maximum speech duration before forcing a segment break (in seconds)
    pub max_speech_duration_s: f32,
    /// Padding added before and after speech segments (in milliseconds)
    pub speech_pad_ms: i32,
    /// Overlap in seconds when copying audio samples from speech segment
    pub samples_overlap: f32,
}

impl Default for VadParams {
    fn default() -> Self {
        // Use whisper.cpp's default VAD parameters
        let default_params = unsafe { ffi::whisper_vad_default_params() };

        Self {
            threshold: default_params.threshold,
            min_speech_duration_ms: default_params.min_speech_duration_ms,
            min_silence_duration_ms: default_params.min_silence_duration_ms,
            max_speech_duration_s: default_params.max_speech_duration_s,
            speech_pad_ms: default_params.speech_pad_ms,
            samples_overlap: default_params.samples_overlap,
        }
    }
}

impl VadParams {
    /// Builds the C-side `whisper_vad_params` value carrying these settings.
    fn to_ffi(&self) -> ffi::whisper_vad_params {
        // All fields are `Copy`, so destructuring through `*self` copies them out.
        let Self {
            threshold,
            min_speech_duration_ms,
            min_silence_duration_ms,
            max_speech_duration_s,
            speech_pad_ms,
            samples_overlap,
        } = *self;

        ffi::whisper_vad_params {
            threshold,
            min_speech_duration_ms,
            min_silence_duration_ms,
            max_speech_duration_s,
            speech_pad_ms,
            samples_overlap,
        }
    }
}

/// Settings used when creating a VAD context from a model file.
///
/// [`Default`] takes the values from whisper.cpp's own defaults.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VadContextParams {
    /// Number of threads to use for processing
    pub n_threads: i32,
    /// Whether to use GPU acceleration
    pub use_gpu: bool,
    /// GPU device ID to use
    pub gpu_device: i32,
}

impl Default for VadContextParams {
    fn default() -> Self {
        let default_params = unsafe { ffi::whisper_vad_default_context_params() };

        Self {
            n_threads: default_params.n_threads,
            use_gpu: default_params.use_gpu,
            gpu_device: default_params.gpu_device,
        }
    }
}

impl VadContextParams {
    /// Builds the C-side `whisper_vad_context_params` value carrying these settings.
    fn to_ffi(&self) -> ffi::whisper_vad_context_params {
        let Self {
            n_threads,
            use_gpu,
            gpu_device,
        } = *self;

        ffi::whisper_vad_context_params {
            n_threads,
            use_gpu,
            gpu_device,
        }
    }
}

/// Voice Activity Detector
///
/// Owns a whisper.cpp VAD context (`whisper_vad_context`) loaded from a model
/// file. The context is released with `whisper_vad_free` when this value is
/// dropped.
pub struct WhisperVadProcessor {
    // Never null for a constructed value: `new_with_params` returns an error
    // instead of building `Self` around a null context.
    ctx: *mut ffi::whisper_vad_context,
}

// SAFETY: the processor is the sole owner of `ctx` (the field is private and the
// type is not `Clone`), so sending it to another thread transfers ownership of
// the context along with it.
// NOTE(review): assumes the whisper.cpp VAD context has no thread affinity — confirm upstream.
unsafe impl Send for WhisperVadProcessor {}
// SAFETY: every method that runs the model takes `&mut self`; the `&self`
// methods only call `whisper_vad_n_probs` / `whisper_vad_probs`.
// NOTE(review): this relies on those two accessors being read-only and safe to
// call concurrently — confirm against upstream whisper.cpp.
unsafe impl Sync for WhisperVadProcessor {}

impl Drop for WhisperVadProcessor {
    /// Returns the VAD context to whisper.cpp for deallocation.
    fn drop(&mut self) {
        if self.ctx.is_null() {
            return;
        }
        // SAFETY: `ctx` came from `whisper_vad_init_from_file_with_params` and is
        // owned exclusively by this value, so it is freed exactly once here.
        unsafe { ffi::whisper_vad_free(self.ctx) };
    }
}

impl WhisperVadProcessor {
    /// Largest sample count that can be passed to the C API, whose length
    /// parameter is an `int` (mapped to `i32` here).
    const MAX_FFI_SAMPLES: usize = i32::MAX as usize;

    /// Create a new VAD processor from a model file, using
    /// [`VadContextParams::default`].
    ///
    /// # Errors
    /// Same as [`WhisperVadProcessor::new_with_params`].
    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
        Self::new_with_params(model_path, VadContextParams::default())
    }

    /// Create a new VAD processor from a model file with explicit context parameters.
    ///
    /// # Errors
    /// Returns [`WhisperError::ModelLoadError`] if the path is not valid UTF-8 or
    /// whisper.cpp returns a null context. A path containing an interior NUL byte
    /// fails the `CString` conversion, which is propagated with `?`.
    pub fn new_with_params<P: AsRef<Path>>(
        model_path: P,
        params: VadContextParams,
    ) -> Result<Self> {
        let path_str = model_path
            .as_ref()
            .to_str()
            .ok_or_else(|| WhisperError::ModelLoadError("Invalid path".into()))?;

        let c_path = std::ffi::CString::new(path_str)?;

        // SAFETY: `c_path` is NUL-terminated and outlives the call; the params
        // struct is passed by value.
        let ctx = unsafe {
            ffi::whisper_vad_init_from_file_with_params(c_path.as_ptr(), params.to_ffi())
        };

        if ctx.is_null() {
            return Err(WhisperError::ModelLoadError(
                "Failed to load VAD model".into(),
            ));
        }

        Ok(Self { ctx })
    }

    /// Detect speech in audio samples, returning whisper.cpp's verdict.
    ///
    /// Returns `false` without calling into C when `samples` is empty or longer
    /// than `i32::MAX` samples (the C API takes an `int` length).
    pub fn detect_speech(&mut self, samples: &[f32]) -> bool {
        // A plain `as i32` on an over-long length would silently truncate or turn
        // negative before reaching C, so such buffers are rejected up front.
        if samples.is_empty() || samples.len() > Self::MAX_FFI_SAMPLES {
            return false;
        }

        // SAFETY: `ctx` is a live context owned by `self`; `samples` is valid for
        // `samples.len()` reads, and that length fits in `i32` (checked above).
        unsafe {
            ffi::whisper_vad_detect_speech(
                self.ctx,
                samples.as_ptr(),
                samples.len() as i32,
            )
        }
    }

    /// Get the number of probability values currently held by the context.
    pub fn n_probs(&self) -> i32 {
        // SAFETY: `ctx` is a live context owned by `self`.
        unsafe { ffi::whisper_vad_n_probs(self.ctx) }
    }

    /// Copy the context's probability values into a new `Vec`.
    ///
    /// Returns an empty `Vec` when the reported count is zero or negative, or
    /// when whisper.cpp returns a null buffer.
    pub fn get_probs(&self) -> Vec<f32> {
        let n = self.n_probs();
        // A negative count must never reach `from_raw_parts`: `as usize` would
        // turn it into an enormous length and read far out of bounds.
        if n <= 0 {
            return Vec::new();
        }

        // SAFETY: `ctx` is a live context owned by `self`.
        let probs_ptr = unsafe { ffi::whisper_vad_probs(self.ctx) };
        if probs_ptr.is_null() {
            return Vec::new();
        }

        // SAFETY: `probs_ptr` is non-null and, per `whisper_vad_n_probs`, points to
        // `n` (> 0) floats. The data is copied out while `self` is only shared-
        // borrowed, so no `&mut self` call can invalidate it mid-copy.
        let slice = unsafe { std::slice::from_raw_parts(probs_ptr, n as usize) };
        slice.to_vec()
    }

    /// Build speech segments from the probabilities already stored in the context.
    ///
    /// # Errors
    /// Returns [`WhisperError::InvalidContext`] if whisper.cpp returns null.
    pub fn segments_from_probs(&mut self, params: &VadParams) -> Result<VadSegments> {
        // SAFETY: `ctx` is a live context owned by `self`; params passed by value.
        let segments_ptr = unsafe {
            ffi::whisper_vad_segments_from_probs(self.ctx, params.to_ffi())
        };

        if segments_ptr.is_null() {
            return Err(WhisperError::InvalidContext);
        }

        Ok(VadSegments {
            ptr: segments_ptr,
        })
    }

    /// Get speech segments directly from audio samples.
    ///
    /// # Errors
    /// Returns [`WhisperError::InvalidAudioFormat`] if `samples` is empty or longer
    /// than `i32::MAX` samples. Returns [`WhisperError::InvalidContext`] if
    /// whisper.cpp returns null.
    pub fn segments_from_samples(
        &mut self,
        samples: &[f32],
        params: &VadParams,
    ) -> Result<VadSegments> {
        if samples.is_empty() || samples.len() > Self::MAX_FFI_SAMPLES {
            return Err(WhisperError::InvalidAudioFormat);
        }

        // SAFETY: `ctx` is a live context owned by `self`; `samples` is valid for
        // `samples.len()` reads, and that length fits in `i32` (checked above).
        let segments_ptr = unsafe {
            ffi::whisper_vad_segments_from_samples(
                self.ctx,
                params.to_ffi(),
                samples.as_ptr(),
                samples.len() as i32,
            )
        };

        if segments_ptr.is_null() {
            return Err(WhisperError::InvalidContext);
        }

        Ok(VadSegments {
            ptr: segments_ptr,
        })
    }
}

/// Speech segments detected by VAD
///
/// Wraps a `whisper_vad_segments` handle produced by
/// [`WhisperVadProcessor::segments_from_probs`] or
/// [`WhisperVadProcessor::segments_from_samples`]. The handle is released with
/// `whisper_vad_free_segments` on drop.
pub struct VadSegments {
    // Never null: both constructors return an error instead of wrapping a null pointer.
    ptr: *mut ffi::whisper_vad_segments,
}

impl Drop for VadSegments {
    /// Hands the segment list back to whisper.cpp for deallocation.
    fn drop(&mut self) {
        let segments = self.ptr;
        if !segments.is_null() {
            // SAFETY: `segments` came from a whisper.cpp segments constructor and is
            // owned exclusively by this value, so it is freed exactly once.
            unsafe { ffi::whisper_vad_free_segments(segments) };
        }
    }
}

impl VadSegments {
    /// Get the number of detected speech segments.
    pub fn n_segments(&self) -> i32 {
        // SAFETY: `ptr` is non-null (both constructors check) and stays valid
        // until `Drop` frees it.
        unsafe { ffi::whisper_vad_segments_n_segments(self.ptr) }
    }

    /// Panics unless `i_segment` lies in `0..self.n_segments()`.
    ///
    /// NOTE(review): the C accessors are not known to bounds-check their index
    /// (confirm against upstream `whisper_vad_segments_get_segment_t0`), so an
    /// out-of-range index is rejected here to keep these safe methods sound.
    fn check_index(&self, i_segment: i32) {
        let n = self.n_segments();
        assert!(
            (0..n).contains(&i_segment),
            "VAD segment index {} out of range (n_segments = {})",
            i_segment,
            n
        );
    }

    /// Get segment start time in seconds.
    ///
    /// # Panics
    /// Panics if `i_segment` is negative or `>= self.n_segments()`.
    pub fn get_segment_t0(&self, i_segment: i32) -> f32 {
        self.check_index(i_segment);
        // SAFETY: index validated above; `ptr` is valid for the lifetime of `self`.
        // The FFI returns time in centiseconds, so divide by 100 for seconds.
        unsafe { ffi::whisper_vad_segments_get_segment_t0(self.ptr, i_segment) / 100.0 }
    }

    /// Get segment end time in seconds.
    ///
    /// # Panics
    /// Panics if `i_segment` is negative or `>= self.n_segments()`.
    pub fn get_segment_t1(&self, i_segment: i32) -> f32 {
        self.check_index(i_segment);
        // SAFETY: index validated above; `ptr` is valid for the lifetime of `self`.
        // The FFI returns time in centiseconds, so divide by 100 for seconds.
        unsafe { ffi::whisper_vad_segments_get_segment_t1(self.ptr, i_segment) / 100.0 }
    }

    /// Get all segments as `(start, end)` pairs in seconds, in index order.
    pub fn get_all_segments(&self) -> Vec<(f32, f32)> {
        // `0..n` is empty for a non-positive count, so a bogus negative value can
        // neither cause a huge pre-allocation nor an out-of-range lookup.
        (0..self.n_segments())
            .map(|i| (self.get_segment_t0(i), self.get_segment_t1(i)))
            .collect()
    }

    /// Extract audio segments from the original audio based on VAD segments.
    ///
    /// `sample_rate` converts segment times (seconds) into sample indices. A
    /// segment's end is clamped to `audio.len()`, so a segment that runs past the
    /// buffer (e.g. from trailing padding or rounding) is trimmed, not dropped.
    /// Segments that are empty or inverted after clamping are skipped, so the
    /// result may hold fewer entries than [`get_all_segments`](Self::get_all_segments).
    pub fn extract_audio_segments(&self, audio: &[f32], sample_rate: f32) -> Vec<Vec<f32>> {
        self.get_all_segments()
            .into_iter()
            .filter_map(|(start, end)| {
                // Float-to-int `as` casts saturate: negatives and NaN become 0.
                let start_sample = (start * sample_rate) as usize;
                let end_sample = ((end * sample_rate) as usize).min(audio.len());

                // Also guards the slice below, which would panic if start > end.
                if start_sample < end_sample {
                    Some(audio[start_sample..end_sample].to_vec())
                } else {
                    None
                }
            })
            .collect()
    }
}

/// Builder for [`VadParams`] that clamps each value into its valid range.
///
/// Starts from [`VadParams::default`], i.e. whisper.cpp's default VAD parameters.
#[derive(Debug, Clone)]
pub struct VadParamsBuilder {
    params: VadParams,
}

impl VadParamsBuilder {
    /// Create a new builder with default values
    pub fn new() -> Self {
        Self {
            params: VadParams::default(),
        }
    }

    /// Set the probability threshold (0.0 - 1.0)
    pub fn threshold(mut self, threshold: f32) -> Self {
        self.params.threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Set minimum speech duration in milliseconds
    pub fn min_speech_duration_ms(mut self, ms: i32) -> Self {
        self.params.min_speech_duration_ms = ms.max(0);
        self
    }

    /// Set minimum silence duration in milliseconds
    pub fn min_silence_duration_ms(mut self, ms: i32) -> Self {
        self.params.min_silence_duration_ms = ms.max(0);
        self
    }

    /// Set maximum speech duration in seconds
    pub fn max_speech_duration_s(mut self, seconds: f32) -> Self {
        self.params.max_speech_duration_s = seconds.max(0.0);
        self
    }

    /// Set speech padding in milliseconds
    pub fn speech_pad_ms(mut self, ms: i32) -> Self {
        self.params.speech_pad_ms = ms.max(0);
        self
    }

    /// Set samples overlap
    pub fn samples_overlap(mut self, overlap: f32) -> Self {
        self.params.samples_overlap = overlap.max(0.0);
        self
    }

    /// Build the parameters
    pub fn build(self) -> VadParams {
        self.params
    }
}

impl Default for VadParamsBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// whisper.cpp's default VAD parameters fall inside sensible ranges.
    #[test]
    fn test_vad_params_default() {
        let VadParams {
            threshold,
            min_speech_duration_ms,
            max_speech_duration_s,
            ..
        } = VadParams::default();

        assert!(threshold > 0.0 && threshold < 1.0);
        assert!(min_speech_duration_ms >= 0);
        assert!(max_speech_duration_s > 0.0);
    }

    /// In-range values given to the builder are stored unchanged.
    #[test]
    fn test_vad_params_builder() {
        let built = VadParamsBuilder::new()
            .threshold(0.6)
            .min_speech_duration_ms(250)
            .min_silence_duration_ms(100)
            .max_speech_duration_s(30.0)
            .speech_pad_ms(100)
            .build();

        assert_eq!(built.threshold, 0.6);
        assert_eq!(built.max_speech_duration_s, 30.0);
        assert_eq!(
            (
                built.min_speech_duration_ms,
                built.min_silence_duration_ms,
                built.speech_pad_ms
            ),
            (250, 100, 100)
        );
    }

    /// Out-of-range values are clamped rather than stored verbatim.
    #[test]
    fn test_vad_params_builder_clamps() {
        let clamped = VadParamsBuilder::new()
            .threshold(1.5)
            .min_speech_duration_ms(-100)
            .build();

        assert_eq!(clamped.threshold, 1.0, "threshold should clamp to 1.0");
        assert_eq!(
            clamped.min_speech_duration_ms, 0,
            "negative duration should clamp to 0"
        );
    }

    /// Loading succeeds when the optional test model is present; skipped otherwise.
    #[test]
    fn test_vad_processor_creation() {
        let model_path = Path::new("tests/models/ggml-silero-vad.bin");
        if !model_path.exists() {
            eprintln!("Skipping VAD processor creation test: model not found");
            return;
        }
        assert!(WhisperVadProcessor::new(model_path).is_ok());
    }

    /// Default context params use at least one thread; custom fields round-trip.
    #[test]
    fn test_vad_context_params() {
        assert!(VadContextParams::default().n_threads > 0);

        let custom = VadContextParams {
            n_threads: 4,
            use_gpu: true,
            gpu_device: 0,
        };
        assert_eq!(custom.n_threads, 4);
        assert!(custom.use_gpu);
        assert_eq!(custom.gpu_device, 0);
    }
}