1use crate::context::WhisperContext;
7use crate::error::Result;
8use crate::params::FullParams;
9use crate::state::{Segment, WhisperState};
10use std::collections::VecDeque;
11
/// Sample rate (samples per second) used by this module for every
/// milliseconds-to-samples conversion and passed to the VAD helper.
const WHISPER_SAMPLE_RATE: i32 = 16000;
13
/// Tuning knobs for [`WhisperStream`].
///
/// All durations are in milliseconds and are converted to sample counts with
/// the module's fixed sample rate when the stream is constructed.
#[derive(Debug, Clone, PartialEq)]
pub struct WhisperStreamConfig {
    /// Amount of new audio consumed per processing step.
    ///
    /// A value that converts to zero samples (i.e. `<= 0`) switches the
    /// stream to VAD-driven mode.
    pub step_ms: i32,
    /// Target length of the audio window handed to inference.
    ///
    /// Raised to at least `step_ms` when the stream is constructed.
    pub length_ms: i32,
    /// Audio retained from the previous window when a new line is started.
    ///
    /// Lowered to at most `step_ms` when the stream is constructed.
    pub keep_ms: i32,
    /// Ratio threshold forwarded to the energy-based voice activity check.
    pub vad_thold: f32,
    /// High-pass cutoff forwarded to the voice activity check; values
    /// `<= 0.0` disable the filter.
    pub freq_thold: f32,
    /// When `false`, tokens from the previous result are fed to the next
    /// inference run as a prompt. Forced to `true` in VAD-driven mode.
    pub no_context: bool,
}
34
35impl Default for WhisperStreamConfig {
36 fn default() -> Self {
37 Self {
38 step_ms: 3000,
39 length_ms: 10000,
40 keep_ms: 200,
41 vad_thold: 0.6,
42 freq_thold: 100.0,
43 no_context: true,
44 }
45 }
46}
47
/// Incremental transcriber: callers push PCM samples in with `feed_audio`
/// and pull transcribed segments out with `process_step` / `flush`.
pub struct WhisperStream {
    /// Inference state created from the context passed to `with_config`.
    state: WhisperState,
    /// Base inference parameters; cloned for every inference run.
    params: FullParams,
    /// Configuration after the clamping done in `with_config`.
    config: WhisperStreamConfig,
    /// `true` when the step size converts to zero samples (VAD-driven mode).
    use_vad: bool,

    /// `step_ms` converted to a sample count.
    n_samples_step: usize,
    /// `length_ms` converted to a sample count.
    n_samples_len: usize,
    /// `keep_ms` converted to a sample count.
    n_samples_keep: usize,
    /// Number of fixed-step iterations between window resets.
    n_new_line: i32,

    /// Audio window carried over from the previous fixed-step iteration.
    pcmf32_old: Vec<f32>,
    /// Token ids collected from the last result, used as the next prompt
    /// when context carry-over is enabled.
    prompt_tokens: Vec<i32>,

    /// Number of inference steps completed by `process_step`.
    n_iter: i32,

    /// Samples fed by the caller that have not been consumed yet.
    audio_buf: VecDeque<f32>,

    /// Running count of samples drained from `audio_buf`.
    total_samples_processed: i64,
}
82
83impl WhisperStream {
84 pub fn new(ctx: &WhisperContext, params: FullParams) -> Result<Self> {
86 Self::with_config(ctx, params, WhisperStreamConfig::default())
87 }
88
89 pub fn with_config(
91 ctx: &WhisperContext,
92 mut params: FullParams,
93 mut config: WhisperStreamConfig,
94 ) -> Result<Self> {
95 let state = WhisperState::new(ctx)?;
96
97 config.keep_ms = config.keep_ms.min(config.step_ms);
99 config.length_ms = config.length_ms.max(config.step_ms);
100
101 let n_samples_step = (1e-3 * config.step_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
103 let n_samples_len = (1e-3 * config.length_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
104 let n_samples_keep = (1e-3 * config.keep_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
105
106 let use_vad = n_samples_step == 0; let n_new_line = if !use_vad {
111 (config.length_ms / config.step_ms - 1).max(1)
112 } else {
113 1
114 };
115
116 params = params
118 .no_timestamps(!use_vad)
119 .max_tokens(0)
120 .single_segment(!use_vad)
121 .print_progress(false)
122 .print_realtime(false);
123
124 if use_vad {
126 config.no_context = true;
127 params = params.no_context(true);
128 }
129
130 Ok(Self {
131 state,
132 params,
133 config,
134 use_vad,
135 n_samples_step,
136 n_samples_len,
137 n_samples_keep,
138 n_new_line,
139 pcmf32_old: Vec::new(),
140 prompt_tokens: Vec::new(),
141 n_iter: 0,
142 audio_buf: VecDeque::new(),
143 total_samples_processed: 0,
144 })
145 }
146
147 pub fn feed_audio(&mut self, samples: &[f32]) {
151 self.audio_buf.extend(samples.iter());
152 }
153
154 pub fn process_step(&mut self) -> Result<Option<Vec<Segment>>> {
158 if !self.use_vad {
159 self.process_step_fixed()
160 } else {
161 self.process_step_vad()
162 }
163 }
164
165 fn process_step_fixed(&mut self) -> Result<Option<Vec<Segment>>> {
167 if self.audio_buf.len() < self.n_samples_step {
169 return Ok(None);
170 }
171
172 let pcmf32_new: Vec<f32> = self.audio_buf.drain(..self.n_samples_step).collect();
174 self.total_samples_processed += pcmf32_new.len() as i64;
175
176 let n_samples_new = pcmf32_new.len();
177
178 let n_samples_take = self
181 .pcmf32_old
182 .len()
183 .min((self.n_samples_keep + self.n_samples_len).saturating_sub(n_samples_new));
184
185 let mut pcmf32 = Vec::with_capacity(n_samples_take + n_samples_new);
187 if n_samples_take > 0 && !self.pcmf32_old.is_empty() {
188 let start = self.pcmf32_old.len() - n_samples_take;
189 pcmf32.extend_from_slice(&self.pcmf32_old[start..]);
190 }
191 pcmf32.extend_from_slice(&pcmf32_new);
192
193 self.pcmf32_old = pcmf32.clone();
195
196 let segments = self.run_inference(&pcmf32)?;
198
199 self.n_iter += 1;
200
201 if self.n_iter % self.n_new_line == 0 {
203 if self.n_samples_keep > 0 && pcmf32.len() >= self.n_samples_keep {
205 self.pcmf32_old = pcmf32[pcmf32.len() - self.n_samples_keep..].to_vec();
206 } else {
207 self.pcmf32_old.clear();
208 }
209
210 if !self.config.no_context {
212 self.collect_prompt_tokens();
213 }
214 }
215
216 Ok(Some(segments))
217 }
218
219 fn process_step_vad(&mut self) -> Result<Option<Vec<Segment>>> {
221 let n_vad_samples = (WHISPER_SAMPLE_RATE * 2) as usize; if self.audio_buf.len() < n_vad_samples {
224 return Ok(None);
225 }
226
227 let pcmf32_vad: Vec<f32> = self.audio_buf.drain(..n_vad_samples).collect();
229 self.total_samples_processed += pcmf32_vad.len() as i64;
230
231 let is_silence = vad_simple(
233 &pcmf32_vad,
234 WHISPER_SAMPLE_RATE,
235 1000,
236 self.config.vad_thold,
237 self.config.freq_thold,
238 );
239
240 if is_silence {
241 return Ok(None);
242 }
243
244 let n_samples_len = self.n_samples_len;
246 let additional = n_samples_len.saturating_sub(pcmf32_vad.len());
247 let mut pcmf32 = pcmf32_vad;
248
249 if additional > 0 {
250 let available = additional.min(self.audio_buf.len());
251 let extra: Vec<f32> = self.audio_buf.drain(..available).collect();
252 self.total_samples_processed += extra.len() as i64;
253 pcmf32.extend_from_slice(&extra);
254 }
255
256 let segments = self.run_inference(&pcmf32)?;
257 self.n_iter += 1;
258
259 Ok(Some(segments))
260 }
261
262 fn run_inference(&mut self, audio: &[f32]) -> Result<Vec<Segment>> {
264 if audio.is_empty() {
265 return Ok(Vec::new());
266 }
267
268 let mut params = self.params.clone();
270
271 if !self.config.no_context && !self.prompt_tokens.is_empty() {
275 params = params.prompt_tokens(&self.prompt_tokens);
276 }
277
278 self.state.full(params, audio)?;
279
280 let n_segments = self.state.full_n_segments();
282 let mut segments = Vec::with_capacity(n_segments as usize);
283
284 for i in 0..n_segments {
285 let text = self.state.full_get_segment_text(i)?;
286 let (start_ms, end_ms) = self.state.full_get_segment_timestamps(i);
287 let speaker_turn_next = self.state.full_get_segment_speaker_turn_next(i);
288
289 segments.push(Segment {
290 start_ms,
291 end_ms,
292 text,
293 speaker_turn_next,
294 });
295 }
296
297 Ok(segments)
298 }
299
300 fn collect_prompt_tokens(&mut self) {
302 self.prompt_tokens.clear();
303
304 let n_segments = self.state.full_n_segments();
305 for i in 0..n_segments {
306 let token_count = self.state.full_n_tokens(i);
307 for j in 0..token_count {
308 self.prompt_tokens.push(self.state.full_get_token_id(i, j));
309 }
310 }
311 }
312
313 pub fn flush(&mut self) -> Result<Vec<Segment>> {
317 let mut all_segments = Vec::new();
318
319 loop {
320 match self.process_step()? {
321 Some(segments) => all_segments.extend(segments),
322 None => break,
323 }
324 }
325
326 if !self.audio_buf.is_empty() {
328 let remaining: Vec<f32> = self.audio_buf.drain(..).collect();
329 self.total_samples_processed += remaining.len() as i64;
330
331 if !self.use_vad {
332 let n_samples_take = self.pcmf32_old.len().min(
334 (self.n_samples_keep + self.n_samples_len).saturating_sub(remaining.len()),
335 );
336 let mut pcmf32 = Vec::with_capacity(n_samples_take + remaining.len());
337 if n_samples_take > 0 && !self.pcmf32_old.is_empty() {
338 let start = self.pcmf32_old.len() - n_samples_take;
339 pcmf32.extend_from_slice(&self.pcmf32_old[start..]);
340 }
341 pcmf32.extend_from_slice(&remaining);
342
343 let segments = self.run_inference(&pcmf32)?;
344 all_segments.extend(segments);
345 } else {
346 let segments = self.run_inference(&remaining)?;
347 all_segments.extend(segments);
348 }
349 }
350
351 Ok(all_segments)
352 }
353
354 pub fn reset(&mut self) {
356 self.audio_buf.clear();
357 self.pcmf32_old.clear();
358 self.prompt_tokens.clear();
359 self.n_iter = 0;
360 self.total_samples_processed = 0;
361 }
362
    /// Returns the number of samples fed in but not yet consumed.
    pub fn buffer_size(&self) -> usize {
        self.audio_buf.len()
    }
367
    /// Returns the running count of samples drained from the buffer since
    /// construction or the last `reset`.
    pub fn processed_samples(&self) -> i64 {
        self.total_samples_processed
    }
372}
373
/// Applies a first-order RC high-pass filter to `data` in place.
///
/// Implements `y[i] = alpha * (y[i-1] + x[i] - x[i-1])` with
/// `alpha = RC / (RC + dt)`, `RC = 1 / (2*pi*cutoff)` and
/// `dt = 1 / sample_rate`. The first sample is left unchanged and seeds the
/// recurrence. Empty input is a no-op.
///
/// * `cutoff` - cutoff frequency in Hz.
/// * `sample_rate` - sample rate of `data` in Hz.
fn high_pass_filter(data: &mut [f32], cutoff: f32, sample_rate: f32) {
    if data.is_empty() {
        return;
    }

    // RC / (RC + dt), rearranged so a cutoff of 0 gives alpha = 1
    // (pass-through) instead of inf / inf.
    let dt = 1.0 / sample_rate;
    let alpha = 1.0 / (1.0 + 2.0 * std::f32::consts::PI * cutoff * dt);

    // The previous *input* must be remembered separately: once a slot has
    // been overwritten, reading it back yields the filtered value and the
    // recurrence degenerates into a plain scaling of the signal.
    let mut prev_in = data[0];
    let mut prev_out = data[0];
    for sample in data.iter_mut().skip(1) {
        let input = *sample;
        prev_out = alpha * (prev_out + input - prev_in);
        prev_in = input;
        *sample = prev_out;
    }
}
393
394fn vad_simple(
398 pcmf32: &[f32],
399 sample_rate: i32,
400 last_ms: i32,
401 vad_thold: f32,
402 freq_thold: f32,
403) -> bool {
404 let n_samples = pcmf32.len();
405 let n_samples_last = (sample_rate as usize * last_ms.max(0) as usize) / 1000;
406
407 if n_samples_last >= n_samples {
408 return true;
411 }
412
413 let mut data = pcmf32.to_vec();
415
416 if freq_thold > 0.0 {
417 high_pass_filter(&mut data, freq_thold, sample_rate as f32);
418 }
419
420 let mut energy_all: f32 = 0.0;
421 let mut energy_last: f32 = 0.0;
422
423 for (i, &s) in data.iter().enumerate() {
424 energy_all += s.abs();
425 if i >= n_samples - n_samples_last {
426 energy_last += s.abs();
427 }
428 }
429
430 energy_all /= n_samples as f32;
431 energy_last /= n_samples_last as f32;
432
433 energy_last <= vad_thold * energy_all
436}
437
438#[cfg(test)]
443mod tests {
444 use super::*;
445 use crate::SamplingStrategy;
446 use std::path::Path;
447
448 #[test]
449 fn test_config_defaults() {
450 let config = WhisperStreamConfig::default();
451 assert_eq!(config.step_ms, 3000);
452 assert_eq!(config.length_ms, 10000);
453 assert_eq!(config.keep_ms, 200);
454 assert!((config.vad_thold - 0.6).abs() < f32::EPSILON);
455 assert!((config.freq_thold - 100.0).abs() < f32::EPSILON);
456 assert!(config.no_context);
457 }
458
459 #[test]
460 fn test_config_normalization() {
461 let model_path = "tests/models/ggml-tiny.en.bin";
463 if !Path::new(model_path).exists() {
464 let mut config = WhisperStreamConfig {
467 step_ms: 2000,
468 length_ms: 5000,
469 keep_ms: 3000, ..Default::default()
471 };
472 config.keep_ms = config.keep_ms.min(config.step_ms);
473 config.length_ms = config.length_ms.max(config.step_ms);
474 assert_eq!(config.keep_ms, 2000);
475 assert_eq!(config.length_ms, 5000);
476
477 let mut config2 = WhisperStreamConfig {
479 step_ms: 8000,
480 length_ms: 5000, keep_ms: 200,
482 ..Default::default()
483 };
484 config2.keep_ms = config2.keep_ms.min(config2.step_ms);
485 config2.length_ms = config2.length_ms.max(config2.step_ms);
486 assert_eq!(config2.length_ms, 8000);
487 assert_eq!(config2.keep_ms, 200);
488 }
489 }
490
491 #[test]
492 fn test_n_new_line_calculation() {
493 let n = (10000i32 / 3000 - 1).max(1);
496 assert_eq!(n, 2);
497
498 let n = (10000i32 / 5000 - 1).max(1);
500 assert_eq!(n, 1);
501
502 let n = (10000i32 / 10000 - 1).max(1);
504 assert_eq!(n, 1);
505
506 let n = (10000i32 / 2000 - 1).max(1);
508 assert_eq!(n, 4);
509
510 let n_vad = 1i32;
512 assert_eq!(n_vad, 1);
513 }
514
515 #[test]
516 fn test_vad_mode_detection() {
517 let step_ms_values = [0, -1, -100];
519 for step_ms in step_ms_values {
520 let n_samples_step = (1e-3 * step_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
521 assert_eq!(
522 n_samples_step, 0,
523 "step_ms={} should yield 0 samples",
524 step_ms
525 );
526 }
527
528 let n = (1e-3 * 3000.0 * WHISPER_SAMPLE_RATE as f64) as usize;
530 assert_eq!(n, 48000);
531 }
532
533 #[test]
534 fn test_feed_and_buffer() {
535 let model_path = "tests/models/ggml-tiny.en.bin";
536 if !Path::new(model_path).exists() {
537 eprintln!("Skipping test_feed_and_buffer: model not found");
538 return;
539 }
540
541 let ctx = WhisperContext::new(model_path).unwrap();
542 let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
543 let mut stream = WhisperStream::new(&ctx, params).unwrap();
544
545 assert_eq!(stream.buffer_size(), 0);
546
547 let samples = vec![0.0f32; 16000];
548 stream.feed_audio(&samples);
549 assert_eq!(stream.buffer_size(), 16000);
550
551 stream.feed_audio(&samples);
552 assert_eq!(stream.buffer_size(), 32000);
553 }
554
555 #[test]
556 fn test_vad_simple_silence() {
557 let silence = vec![0.0f32; 16000];
558 assert!(vad_simple(&silence, 16000, 100, 0.6, 100.0));
559 }
560
561 #[test]
562 fn test_vad_simple_too_few_samples() {
563 let short = vec![0.1f32; 100];
564 assert!(vad_simple(&short, 16000, 1000, 0.6, 100.0));
565 }
566
567 #[test]
568 fn test_high_pass_filter_basic() {
569 let mut data = vec![1.0, 0.0, 1.0, 0.0, 1.0];
570 high_pass_filter(&mut data, 100.0, 16000.0);
571 assert_ne!(data[2], 1.0);
572 }
573
574 #[test]
575 fn test_reset() {
576 let model_path = "tests/models/ggml-tiny.en.bin";
577 if !Path::new(model_path).exists() {
578 eprintln!("Skipping test_reset: model not found");
579 return;
580 }
581
582 let ctx = WhisperContext::new(model_path).unwrap();
583 let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
584 let mut stream = WhisperStream::new(&ctx, params).unwrap();
585
586 stream.feed_audio(&vec![0.0f32; 16000]);
587 assert_eq!(stream.buffer_size(), 16000);
588
589 stream.reset();
590 assert_eq!(stream.buffer_size(), 0);
591 assert_eq!(stream.processed_samples(), 0);
592 }
593
594 #[test]
597 fn test_fixed_step_basic() {
598 let model_path = "tests/models/ggml-tiny.en.bin";
599 if !Path::new(model_path).exists() {
600 eprintln!("Skipping test_fixed_step_basic: model not found");
601 return;
602 }
603
604 let ctx = WhisperContext::new(model_path).unwrap();
605 let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }).language("en");
606
607 let config = WhisperStreamConfig {
609 step_ms: 3000,
610 length_ms: 10000,
611 keep_ms: 200,
612 ..Default::default()
613 };
614
615 let mut stream = WhisperStream::with_config(&ctx, params, config).unwrap();
616
617 let audio = vec![0.0f32; 48000];
619 stream.feed_audio(&audio);
620
621 let result = stream.process_step().unwrap();
622 assert!(
623 result.is_some(),
624 "Should produce segments with enough audio"
625 );
626 assert!(stream.processed_samples() > 0);
627 }
628
629 #[test]
630 fn test_prompt_propagation() {
631 let model_path = "tests/models/ggml-tiny.en.bin";
632 if !Path::new(model_path).exists() {
633 eprintln!("Skipping test_prompt_propagation: model not found");
634 return;
635 }
636
637 let ctx = WhisperContext::new(model_path).unwrap();
638 let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }).language("en");
639
640 let config = WhisperStreamConfig {
641 step_ms: 3000,
642 length_ms: 6000,
643 keep_ms: 200,
644 no_context: false, ..Default::default()
646 };
647
648 let mut stream = WhisperStream::with_config(&ctx, params, config).unwrap();
649
650 let audio = vec![0.0f32; 48000];
655 stream.feed_audio(&audio);
656
657 let result = stream.process_step().unwrap();
658 assert!(result.is_some());
659
660 assert!(stream.processed_samples() > 0);
665 }
666}