Skip to main content

whisper_cpp_plus/
stream.rs

1//! Streaming transcription — faithful port of stream.cpp
2//!
3//! Replaces SDL audio capture with a push-based `feed_audio()` API
4//! since we're a library, not a binary.
5
6use crate::context::WhisperContext;
7use crate::error::Result;
8use crate::params::FullParams;
9use crate::state::{Segment, WhisperState};
10use std::collections::VecDeque;
11
/// Sample rate (Hz) assumed for all audio fed to the stream; every ms → samples
/// conversion in this module is based on it.
const WHISPER_SAMPLE_RATE: i32 = 16_000;
13
14// ---------------------------------------------------------------------------
15// WhisperStreamConfig
16// ---------------------------------------------------------------------------
17
/// Tunables for [`WhisperStream`] — the streaming subset of stream.cpp's `whisper_params`.
///
/// `step_ms` selects the mode: a positive value runs the fixed-step sliding window,
/// a value `<= 0` runs the VAD-gated mode.
#[derive(Debug, Clone)]
pub struct WhisperStreamConfig {
    /// Step size in milliseconds; `<= 0` selects VAD mode.
    pub step_ms: i32,
    /// Milliseconds of audio per inference window.
    pub length_ms: i32,
    /// Milliseconds of audio carried over from the previous step.
    pub keep_ms: i32,
    /// Energy-ratio threshold used by the VAD check.
    pub vad_thold: f32,
    /// High-pass cutoff frequency applied before the VAD energy check (`<= 0` disables it).
    pub freq_thold: f32,
    /// When `true`, prompt tokens are not carried from one window into the next.
    pub no_context: bool,
}

impl Default for WhisperStreamConfig {
    /// stream.cpp defaults: 3 s step, 10 s window, 200 ms keep, VAD threshold 0.6,
    /// 100 Hz high-pass cutoff, no context carry-over.
    fn default() -> Self {
        WhisperStreamConfig {
            step_ms: 3_000,
            length_ms: 10_000,
            keep_ms: 200,
            vad_thold: 0.6,
            freq_thold: 100.0,
            no_context: true,
        }
    }
}
47
48// ---------------------------------------------------------------------------
49// WhisperStream
50// ---------------------------------------------------------------------------
51
52/// Streaming transcriber — faithful port of stream.cpp main loop.
53///
54/// Two modes:
55/// - **Fixed-step** (`step_ms > 0`): sliding window with overlap.
56/// - **VAD** (`step_ms <= 0`): transcribe on speech activity.
57pub struct WhisperStream {
58    state: WhisperState,
59    params: FullParams,
60    config: WhisperStreamConfig,
61    use_vad: bool,
62
63    // Pre-computed sample counts
64    n_samples_step: usize,
65    n_samples_len: usize,
66    n_samples_keep: usize,
67    n_new_line: i32,
68
69    // Overlap buffer from previous inference
70    pcmf32_old: Vec<f32>,
71    // Context propagation
72    prompt_tokens: Vec<i32>,
73
74    n_iter: i32,
75
76    // Internal audio buffer (replaces SDL capture)
77    audio_buf: VecDeque<f32>,
78
79    // Total samples consumed from audio_buf
80    total_samples_processed: i64,
81}
82
83impl WhisperStream {
84    /// Create with default config.
85    pub fn new(ctx: &WhisperContext, params: FullParams) -> Result<Self> {
86        Self::with_config(ctx, params, WhisperStreamConfig::default())
87    }
88
89    /// Create with custom config.
90    pub fn with_config(
91        ctx: &WhisperContext,
92        mut params: FullParams,
93        mut config: WhisperStreamConfig,
94    ) -> Result<Self> {
95        let state = WhisperState::new(ctx)?;
96
97        // --- Config normalization (stream.cpp main()) ---
98        config.keep_ms = config.keep_ms.min(config.step_ms);
99        config.length_ms = config.length_ms.max(config.step_ms);
100
101        // Sample counts
102        let n_samples_step = (1e-3 * config.step_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
103        let n_samples_len = (1e-3 * config.length_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
104        let n_samples_keep = (1e-3 * config.keep_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
105
106        // Mode detection
107        let use_vad = n_samples_step == 0; // step_ms <= 0 → VAD
108
109        // n_new_line: guard against division by zero when step_ms <= 0
110        let n_new_line = if !use_vad {
111            (config.length_ms / config.step_ms - 1).max(1)
112        } else {
113            1
114        };
115
116        // Auto-set mode-dependent FullParams (stream.cpp lines 141-143)
117        params = params
118            .no_timestamps(!use_vad)
119            .max_tokens(0)
120            .single_segment(!use_vad)
121            .print_progress(false)
122            .print_realtime(false);
123
124        // Force no_context in VAD mode: no_context |= use_vad
125        if use_vad {
126            config.no_context = true;
127            params = params.no_context(true);
128        }
129
130        Ok(Self {
131            state,
132            params,
133            config,
134            use_vad,
135            n_samples_step,
136            n_samples_len,
137            n_samples_keep,
138            n_new_line,
139            pcmf32_old: Vec::new(),
140            prompt_tokens: Vec::new(),
141            n_iter: 0,
142            audio_buf: VecDeque::new(),
143            total_samples_processed: 0,
144        })
145    }
146
147    // --- Audio input ---
148
149    /// Push samples into the internal buffer (replaces SDL capture).
150    pub fn feed_audio(&mut self, samples: &[f32]) {
151        self.audio_buf.extend(samples.iter());
152    }
153
154    // --- Processing ---
155
156    /// Dispatch to fixed-step or VAD mode.
157    pub fn process_step(&mut self) -> Result<Option<Vec<Segment>>> {
158        if !self.use_vad {
159            self.process_step_fixed()
160        } else {
161            self.process_step_vad()
162        }
163    }
164
165    /// Fixed-step (sliding window) mode — port of stream.cpp lines 253-428.
166    fn process_step_fixed(&mut self) -> Result<Option<Vec<Segment>>> {
167        // Need at least n_samples_step new samples
168        if self.audio_buf.len() < self.n_samples_step {
169            return Ok(None);
170        }
171
172        // Pop n_samples_step from front of audio_buf
173        let pcmf32_new: Vec<f32> = self.audio_buf.drain(..self.n_samples_step).collect();
174        self.total_samples_processed += pcmf32_new.len() as i64;
175
176        let n_samples_new = pcmf32_new.len();
177
178        // Exact formula from stream.cpp line 279:
179        // n_samples_take = min(pcmf32_old.size(), max(0, n_samples_keep + n_samples_len - n_samples_new))
180        let n_samples_take = self
181            .pcmf32_old
182            .len()
183            .min((self.n_samples_keep + self.n_samples_len).saturating_sub(n_samples_new));
184
185        // Build pcmf32: tail of pcmf32_old + pcmf32_new
186        let mut pcmf32 = Vec::with_capacity(n_samples_take + n_samples_new);
187        if n_samples_take > 0 && !self.pcmf32_old.is_empty() {
188            let start = self.pcmf32_old.len() - n_samples_take;
189            pcmf32.extend_from_slice(&self.pcmf32_old[start..]);
190        }
191        pcmf32.extend_from_slice(&pcmf32_new);
192
193        // Save for next iteration
194        self.pcmf32_old = pcmf32.clone();
195
196        // Run inference
197        let segments = self.run_inference(&pcmf32)?;
198
199        self.n_iter += 1;
200
201        // At n_new_line boundary (stream.cpp lines 408-425)
202        if self.n_iter % self.n_new_line == 0 {
203            // Keep only last n_samples_keep samples
204            if self.n_samples_keep > 0 && pcmf32.len() >= self.n_samples_keep {
205                self.pcmf32_old = pcmf32[pcmf32.len() - self.n_samples_keep..].to_vec();
206            } else {
207                self.pcmf32_old.clear();
208            }
209
210            // Collect prompt tokens if !no_context
211            if !self.config.no_context {
212                self.collect_prompt_tokens();
213            }
214        }
215
216        Ok(Some(segments))
217    }
218
219    /// VAD mode — port of stream.cpp lines 293-313.
220    fn process_step_vad(&mut self) -> Result<Option<Vec<Segment>>> {
221        // Need at least 2 seconds of audio (stream.cpp: t_diff < 2000 → continue)
222        let n_vad_samples = (WHISPER_SAMPLE_RATE * 2) as usize; // 32000 samples
223        if self.audio_buf.len() < n_vad_samples {
224            return Ok(None);
225        }
226
227        // Pop 2 seconds for VAD probe
228        let pcmf32_vad: Vec<f32> = self.audio_buf.drain(..n_vad_samples).collect();
229        self.total_samples_processed += pcmf32_vad.len() as i64;
230
231        // Check for speech
232        let is_silence = vad_simple(
233            &pcmf32_vad,
234            WHISPER_SAMPLE_RATE,
235            1000,
236            self.config.vad_thold,
237            self.config.freq_thold,
238        );
239
240        if is_silence {
241            return Ok(None);
242        }
243
244        // Speech detected — grab length_ms of audio total (stream.cpp line 305)
245        let n_samples_len = self.n_samples_len;
246        let additional = n_samples_len.saturating_sub(pcmf32_vad.len());
247        let mut pcmf32 = pcmf32_vad;
248
249        if additional > 0 {
250            let available = additional.min(self.audio_buf.len());
251            let extra: Vec<f32> = self.audio_buf.drain(..available).collect();
252            self.total_samples_processed += extra.len() as i64;
253            pcmf32.extend_from_slice(&extra);
254        }
255
256        let segments = self.run_inference(&pcmf32)?;
257        self.n_iter += 1;
258
259        Ok(Some(segments))
260    }
261
262    /// Run whisper inference on audio — port of stream.cpp lines 316-344.
263    fn run_inference(&mut self, audio: &[f32]) -> Result<Vec<Segment>> {
264        if audio.is_empty() {
265            return Ok(Vec::new());
266        }
267
268        // Clone params so we can set prompt_tokens pointer
269        let mut params = self.params.clone();
270
271        // Set prompt tokens on the clone, pointing to self.prompt_tokens.
272        // The prompt_tokens() method stores a raw pointer. self.prompt_tokens
273        // (Vec<i32>) lives on self and outlives the full() call, so this is safe.
274        if !self.config.no_context && !self.prompt_tokens.is_empty() {
275            params = params.prompt_tokens(&self.prompt_tokens);
276        }
277
278        self.state.full(params, audio)?;
279
280        // Extract segments
281        let n_segments = self.state.full_n_segments();
282        let mut segments = Vec::with_capacity(n_segments as usize);
283
284        for i in 0..n_segments {
285            let text = self.state.full_get_segment_text(i)?;
286            let (start_ms, end_ms) = self.state.full_get_segment_timestamps(i);
287            let speaker_turn_next = self.state.full_get_segment_speaker_turn_next(i);
288
289            segments.push(Segment {
290                start_ms,
291                end_ms,
292                text,
293                speaker_turn_next,
294            });
295        }
296
297        Ok(segments)
298    }
299
300    /// Collect prompt tokens from last inference — port of stream.cpp lines 416-425.
301    fn collect_prompt_tokens(&mut self) {
302        self.prompt_tokens.clear();
303
304        let n_segments = self.state.full_n_segments();
305        for i in 0..n_segments {
306            let token_count = self.state.full_n_tokens(i);
307            for j in 0..token_count {
308                self.prompt_tokens.push(self.state.full_get_token_id(i, j));
309            }
310        }
311    }
312
313    // --- Convenience methods ---
314
315    /// Process all remaining audio in buffer.
316    pub fn flush(&mut self) -> Result<Vec<Segment>> {
317        let mut all_segments = Vec::new();
318
319        loop {
320            match self.process_step()? {
321                Some(segments) => all_segments.extend(segments),
322                None => break,
323            }
324        }
325
326        // If there's leftover audio that's less than a full step, run inference on it
327        if !self.audio_buf.is_empty() {
328            let remaining: Vec<f32> = self.audio_buf.drain(..).collect();
329            self.total_samples_processed += remaining.len() as i64;
330
331            if !self.use_vad {
332                // Build final buffer with overlap
333                let n_samples_take = self.pcmf32_old.len().min(
334                    (self.n_samples_keep + self.n_samples_len).saturating_sub(remaining.len()),
335                );
336                let mut pcmf32 = Vec::with_capacity(n_samples_take + remaining.len());
337                if n_samples_take > 0 && !self.pcmf32_old.is_empty() {
338                    let start = self.pcmf32_old.len() - n_samples_take;
339                    pcmf32.extend_from_slice(&self.pcmf32_old[start..]);
340                }
341                pcmf32.extend_from_slice(&remaining);
342
343                let segments = self.run_inference(&pcmf32)?;
344                all_segments.extend(segments);
345            } else {
346                let segments = self.run_inference(&remaining)?;
347                all_segments.extend(segments);
348            }
349        }
350
351        Ok(all_segments)
352    }
353
354    /// Clear buffers, counters, prompt tokens.
355    pub fn reset(&mut self) {
356        self.audio_buf.clear();
357        self.pcmf32_old.clear();
358        self.prompt_tokens.clear();
359        self.n_iter = 0;
360        self.total_samples_processed = 0;
361    }
362
363    /// Samples currently in the internal buffer.
364    pub fn buffer_size(&self) -> usize {
365        self.audio_buf.len()
366    }
367
368    /// Total samples consumed from the buffer.
369    pub fn processed_samples(&self) -> i64 {
370        self.total_samples_processed
371    }
372}
373
374// ---------------------------------------------------------------------------
375// vad_simple + high_pass_filter — port from common.cpp
376// ---------------------------------------------------------------------------
377
378/// High-pass filter — port of common.cpp::high_pass_filter (lines 597-608).
379fn high_pass_filter(data: &mut [f32], cutoff: f32, sample_rate: f32) {
380    if data.is_empty() {
381        return;
382    }
383    let rc = 1.0 / (2.0 * std::f32::consts::PI * cutoff);
384    let dt = 1.0 / sample_rate;
385    let alpha = dt / (rc + dt);
386
387    let mut y = data[0];
388    for i in 1..data.len() {
389        y = alpha * (y + data[i] - data[i - 1]);
390        data[i] = y;
391    }
392}
393
394/// Energy-based VAD — port of common.cpp::vad_simple (lines 610-646).
395///
396/// Returns `true` if **silence** (no speech detected).
397fn vad_simple(
398    pcmf32: &[f32],
399    sample_rate: i32,
400    last_ms: i32,
401    vad_thold: f32,
402    freq_thold: f32,
403) -> bool {
404    let n_samples = pcmf32.len();
405    let n_samples_last = (sample_rate as usize * last_ms.max(0) as usize) / 1000;
406
407    if n_samples_last >= n_samples {
408        // not enough samples — assume no speech (C++ returns false here,
409        // but the sense in C++ is inverted: false = silence. We return true = silence.)
410        return true;
411    }
412
413    // Work on a copy so we can apply the high-pass filter
414    let mut data = pcmf32.to_vec();
415
416    if freq_thold > 0.0 {
417        high_pass_filter(&mut data, freq_thold, sample_rate as f32);
418    }
419
420    let mut energy_all: f32 = 0.0;
421    let mut energy_last: f32 = 0.0;
422
423    for (i, &s) in data.iter().enumerate() {
424        energy_all += s.abs();
425        if i >= n_samples - n_samples_last {
426            energy_last += s.abs();
427        }
428    }
429
430    energy_all /= n_samples as f32;
431    energy_last /= n_samples_last as f32;
432
433    // C++ returns false (speech) when energy_last > thold * energy_all.
434    // We return true for silence.
435    energy_last <= vad_thold * energy_all
436}
437
438// ---------------------------------------------------------------------------
439// Tests
440// ---------------------------------------------------------------------------
441
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SamplingStrategy;
    use std::path::Path;

    /// Model used by tests that need a real whisper context; those tests are
    /// skipped (with a note on stderr) when the file is absent.
    const MODEL_PATH: &str = "tests/models/ggml-tiny.en.bin";

    #[test]
    fn test_config_defaults() {
        let config = WhisperStreamConfig::default();
        assert_eq!(config.step_ms, 3000);
        assert_eq!(config.length_ms, 10000);
        assert_eq!(config.keep_ms, 200);
        assert!((config.vad_thold - 0.6).abs() < f32::EPSILON);
        assert!((config.freq_thold - 100.0).abs() < f32::EPSILON);
        assert!(config.no_context);
    }

    #[test]
    fn test_config_normalization() {
        // Pure arithmetic mirroring the normalization in `WhisperStream::with_config`
        // (stream.cpp main()). It needs no model, so it always runs.
        fn normalize(mut config: WhisperStreamConfig) -> WhisperStreamConfig {
            config.keep_ms = config.keep_ms.min(config.step_ms);
            config.length_ms = config.length_ms.max(config.step_ms);
            config
        }

        // keep_ms clamped down to step_ms
        let config = normalize(WhisperStreamConfig {
            step_ms: 2000,
            length_ms: 5000,
            keep_ms: 3000,
            ..Default::default()
        });
        assert_eq!(config.keep_ms, 2000);
        assert_eq!(config.length_ms, 5000);

        // length_ms raised up to step_ms
        let config2 = normalize(WhisperStreamConfig {
            step_ms: 8000,
            length_ms: 5000,
            keep_ms: 200,
            ..Default::default()
        });
        assert_eq!(config2.length_ms, 8000);
        assert_eq!(config2.keep_ms, 200);
    }

    #[test]
    fn test_n_new_line_calculation() {
        // n_new_line = max(1, length_ms / step_ms - 1) when !use_vad
        // Defaults: length_ms=10000, step_ms=3000 → 10000/3000 - 1 = 2
        let n = (10000i32 / 3000 - 1).max(1);
        assert_eq!(n, 2);

        // step_ms=5000, length_ms=10000 → 10000/5000 - 1 = 1
        let n = (10000i32 / 5000 - 1).max(1);
        assert_eq!(n, 1);

        // step_ms=10000, length_ms=10000 → 10000/10000 - 1 = 0 → clamped to 1
        let n = (10000i32 / 10000 - 1).max(1);
        assert_eq!(n, 1);

        // step_ms=2000, length_ms=10000 → 10000/2000 - 1 = 4
        let n = (10000i32 / 2000 - 1).max(1);
        assert_eq!(n, 4);
    }

    #[test]
    fn test_vad_mode_detection() {
        // step_ms <= 0 → use_vad
        let step_ms_values = [0, -1, -100];
        for step_ms in step_ms_values.iter() {
            let n_samples_step = (1e-3 * *step_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
            assert_eq!(
                n_samples_step, 0,
                "step_ms={} should yield 0 samples",
                step_ms
            );
        }

        // step_ms > 0 → fixed step
        let n = (1e-3 * 3000.0 * WHISPER_SAMPLE_RATE as f64) as usize;
        assert_eq!(n, 48000);
    }

    #[test]
    fn test_feed_and_buffer() {
        if !Path::new(MODEL_PATH).exists() {
            eprintln!("Skipping test_feed_and_buffer: model not found");
            return;
        }

        let ctx = WhisperContext::new(MODEL_PATH).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        let mut stream = WhisperStream::new(&ctx, params).unwrap();

        assert_eq!(stream.buffer_size(), 0);

        let samples = vec![0.0f32; 16000];
        stream.feed_audio(&samples);
        assert_eq!(stream.buffer_size(), 16000);

        stream.feed_audio(&samples);
        assert_eq!(stream.buffer_size(), 32000);
    }

    #[test]
    fn test_vad_simple_silence() {
        let silence = vec![0.0f32; 16000];
        assert!(vad_simple(&silence, 16000, 100, 0.6, 100.0));
    }

    #[test]
    fn test_vad_simple_too_few_samples() {
        let short = vec![0.1f32; 100];
        assert!(vad_simple(&short, 16000, 1000, 0.6, 100.0));
    }

    #[test]
    fn test_vad_simple_detects_trailing_drop() {
        // Activity followed by a quiet tail → true (stream.cpp transcribes here).
        let mut ended = vec![0.5f32; 1000];
        ended.extend(vec![0.0f32; 1000]);
        assert!(vad_simple(&ended, 1000, 500, 0.6, 0.0));

        // Quiet start with a loud tail → false (speech still ongoing).
        let mut ongoing = vec![0.0f32; 1000];
        ongoing.extend(vec![0.5f32; 1000]);
        assert!(!vad_simple(&ongoing, 1000, 500, 0.6, 0.0));
    }

    #[test]
    fn test_high_pass_filter_basic() {
        let mut data = vec![1.0, 0.0, 1.0, 0.0, 1.0];
        high_pass_filter(&mut data, 100.0, 16000.0);
        assert_ne!(data[2], 1.0);
    }

    #[test]
    fn test_reset() {
        if !Path::new(MODEL_PATH).exists() {
            eprintln!("Skipping test_reset: model not found");
            return;
        }

        let ctx = WhisperContext::new(MODEL_PATH).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        let mut stream = WhisperStream::new(&ctx, params).unwrap();

        stream.feed_audio(&vec![0.0f32; 16000]);
        assert_eq!(stream.buffer_size(), 16000);

        stream.reset();
        assert_eq!(stream.buffer_size(), 0);
        assert_eq!(stream.processed_samples(), 0);
    }

    // --- Integration tests (require model) ---

    #[test]
    fn test_fixed_step_basic() {
        if !Path::new(MODEL_PATH).exists() {
            eprintln!("Skipping test_fixed_step_basic: model not found");
            return;
        }

        let ctx = WhisperContext::new(MODEL_PATH).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }).language("en");

        let config = WhisperStreamConfig {
            step_ms: 3000,
            length_ms: 10000,
            keep_ms: 200,
            ..Default::default()
        };

        let mut stream = WhisperStream::with_config(&ctx, params, config).unwrap();

        // Feed exactly one step (3 seconds = 48000 samples)
        let audio = vec![0.0f32; 48000];
        stream.feed_audio(&audio);

        let result = stream.process_step().unwrap();
        assert!(
            result.is_some(),
            "Should produce segments with enough audio"
        );
        assert!(stream.processed_samples() > 0);
    }

    #[test]
    fn test_prompt_propagation() {
        if !Path::new(MODEL_PATH).exists() {
            eprintln!("Skipping test_prompt_propagation: model not found");
            return;
        }

        let ctx = WhisperContext::new(MODEL_PATH).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }).language("en");

        let config = WhisperStreamConfig {
            step_ms: 3000,
            length_ms: 6000,
            keep_ms: 200,
            no_context: false, // enable prompt propagation
            ..Default::default()
        };

        let mut stream = WhisperStream::with_config(&ctx, params, config).unwrap();

        // n_new_line = max(1, 6000/3000 - 1) = 1, so every iteration triggers
        // prompt collection when no_context=false.

        // Feed enough for one step
        let audio = vec![0.0f32; 48000];
        stream.feed_audio(&audio);

        let result = stream.process_step().unwrap();
        assert!(result.is_some());

        // With silence input whisper may or may not produce tokens, so this only
        // verifies that the prompt-collection path runs without panicking.
        assert!(stream.processed_samples() > 0);
    }
}