use crate::context::WhisperContext;
use crate::error::Result;
use crate::params::FullParams;
use crate::state::{Segment, WhisperState};
use std::collections::VecDeque;
/// Audio sample rate expected by whisper.cpp: 16 kHz PCM.
const WHISPER_SAMPLE_RATE: i32 = 16000;
/// Configuration for [`WhisperStream`], mirroring the knobs of the
/// whisper.cpp `stream` example.
#[derive(Debug, Clone)]
pub struct WhisperStreamConfig {
    /// Audio step size in milliseconds. A value <= 0 yields zero samples per
    /// step and switches the stream into VAD-gated mode.
    pub step_ms: i32,
    /// Full analysis window length in milliseconds (clamped up to at least
    /// `step_ms` during construction).
    pub length_ms: i32,
    /// Audio carried over across "new line" boundaries, in milliseconds
    /// (clamped down to at most `step_ms` during construction).
    pub keep_ms: i32,
    /// VAD threshold: the probe is silent when the trailing-window mean
    /// energy is <= `vad_thold` times the whole-buffer mean energy.
    pub vad_thold: f32,
    /// High-pass cutoff (Hz) applied before the VAD energy check;
    /// a value <= 0 disables the filter.
    pub freq_thold: f32,
    /// When true, previously decoded tokens are NOT fed back as a prompt for
    /// the next step. Forced to true in VAD mode.
    pub no_context: bool,
}
impl Default for WhisperStreamConfig {
fn default() -> Self {
Self {
step_ms: 3000,
length_ms: 10000,
keep_ms: 200,
vad_thold: 0.6,
freq_thold: 100.0,
no_context: true,
}
}
}
/// Incremental transcription driver: callers push PCM audio with
/// `feed_audio` and repeatedly call `process_step` to obtain new segments,
/// in either fixed-interval or VAD-gated mode.
pub struct WhisperStream {
    /// Reusable inference state created from a `WhisperContext`.
    state: WhisperState,
    /// Base decoding parameters; cloned (and possibly given prompt tokens)
    /// for each inference run.
    params: FullParams,
    /// Normalized configuration (see `with_config`).
    config: WhisperStreamConfig,
    /// True when `step_ms <= 0`: steps are gated by energy-based VAD.
    use_vad: bool,
    /// Samples consumed per fixed step.
    n_samples_step: usize,
    /// Samples in the full analysis window.
    n_samples_len: usize,
    /// Samples carried over across "new line" boundaries.
    n_samples_keep: usize,
    /// Fixed-mode iterations between carry-over trims / prompt refreshes.
    n_new_line: i32,
    /// Tail audio from previous steps, prepended to the next window.
    pcmf32_old: Vec<f32>,
    /// Token ids collected from prior output; used as a prompt when
    /// `config.no_context` is false.
    prompt_tokens: Vec<i32>,
    /// Number of completed inference steps.
    n_iter: i32,
    /// Unprocessed audio fed by the caller.
    audio_buf: VecDeque<f32>,
    /// Total number of samples drained from `audio_buf` so far.
    total_samples_processed: i64,
}
impl WhisperStream {
    /// Creates a stream with the default [`WhisperStreamConfig`].
    pub fn new(ctx: &WhisperContext, params: FullParams) -> Result<Self> {
        Self::with_config(ctx, params, WhisperStreamConfig::default())
    }

    /// Creates a stream with an explicit configuration.
    ///
    /// The config is normalized (`keep_ms <= step_ms`, `length_ms >= step_ms`)
    /// and the decoding params are adjusted for streaming use. A
    /// `step_ms <= 0` selects VAD-gated stepping, which also forces
    /// `no_context`.
    ///
    /// # Errors
    /// Propagates any failure from `WhisperState::new`.
    pub fn with_config(
        ctx: &WhisperContext,
        mut params: FullParams,
        mut config: WhisperStreamConfig,
    ) -> Result<Self> {
        let state = WhisperState::new(ctx)?;
        // Normalize: carry-over cannot exceed one step, and the analysis
        // window must cover at least one step.
        config.keep_ms = config.keep_ms.min(config.step_ms);
        config.length_ms = config.length_ms.max(config.step_ms);
        let n_samples_step =
            (1e-3 * config.step_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
        let n_samples_len =
            (1e-3 * config.length_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
        let n_samples_keep =
            (1e-3 * config.keep_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
        // step_ms <= 0 yields zero samples per step => VAD mode.
        let use_vad = n_samples_step == 0;
        // In fixed mode, a "new line" boundary (trim carry-over, refresh the
        // prompt) occurs every `length/step - 1` iterations, at least every 1.
        let n_new_line = if !use_vad {
            (config.length_ms / config.step_ms - 1).max(1)
        } else {
            1
        };
        // Streaming-friendly decode settings: single segment without
        // timestamps in fixed mode; quiet output in both modes.
        params = params
            .no_timestamps(!use_vad)
            .max_tokens(0)
            .single_segment(!use_vad)
            .print_progress(false)
            .print_realtime(false);
        if use_vad {
            // VAD chunks are treated as independent utterances, so prompt
            // context carry-over is disabled.
            config.no_context = true;
            params = params.no_context(true);
        }
        Ok(Self {
            state,
            params,
            config,
            use_vad,
            n_samples_step,
            n_samples_len,
            n_samples_keep,
            n_new_line,
            pcmf32_old: Vec::new(),
            prompt_tokens: Vec::new(),
            n_iter: 0,
            audio_buf: VecDeque::new(),
            total_samples_processed: 0,
        })
    }

    /// Appends PCM samples to the internal buffer. Samples are interpreted
    /// at `WHISPER_SAMPLE_RATE` (16 kHz); presumably mono f32 in [-1, 1] —
    /// confirm at the call site.
    pub fn feed_audio(&mut self, samples: &[f32]) {
        self.audio_buf.extend(samples.iter());
    }

    /// Attempts one streaming step, dispatching on the configured mode.
    ///
    /// Returns `Ok(None)` when not enough audio is buffered — or, in VAD
    /// mode, when the probe window is silent.
    pub fn process_step(&mut self) -> Result<Option<Vec<Segment>>> {
        if !self.use_vad {
            self.process_step_fixed()
        } else {
            self.process_step_vad()
        }
    }

    /// Fixed-interval step: drains exactly `n_samples_step` samples,
    /// prepends carried-over tail audio, and runs inference on the combined
    /// window. Returns `Ok(None)` when less than one step is buffered.
    fn process_step_fixed(&mut self) -> Result<Option<Vec<Segment>>> {
        if self.audio_buf.len() < self.n_samples_step {
            return Ok(None);
        }
        let pcmf32_new: Vec<f32> = self.audio_buf.drain(..self.n_samples_step).collect();
        self.total_samples_processed += pcmf32_new.len() as i64;
        let n_samples_new = pcmf32_new.len();
        // Cap the carried-over tail so the combined window never exceeds
        // keep + len samples.
        let n_samples_take = self.pcmf32_old.len().min(
            (self.n_samples_keep + self.n_samples_len).saturating_sub(n_samples_new),
        );
        let mut pcmf32 = Vec::with_capacity(n_samples_take + n_samples_new);
        if n_samples_take > 0 && !self.pcmf32_old.is_empty() {
            let start = self.pcmf32_old.len() - n_samples_take;
            pcmf32.extend_from_slice(&self.pcmf32_old[start..]);
        }
        pcmf32.extend_from_slice(&pcmf32_new);
        // By default the whole window is re-fed as the tail next step...
        self.pcmf32_old = pcmf32.clone();
        let segments = self.run_inference(&pcmf32)?;
        self.n_iter += 1;
        // ...except on "new line" boundaries, where only the last
        // `n_samples_keep` samples survive and the prompt is refreshed.
        if self.n_iter % self.n_new_line == 0 {
            if self.n_samples_keep > 0 && pcmf32.len() >= self.n_samples_keep {
                self.pcmf32_old =
                    pcmf32[pcmf32.len() - self.n_samples_keep..].to_vec();
            } else {
                self.pcmf32_old.clear();
            }
            if !self.config.no_context {
                self.collect_prompt_tokens();
            }
        }
        Ok(Some(segments))
    }

    /// VAD-gated step: drains a 2-second probe window and transcribes only
    /// when its trailing second contains speech, extending the window with
    /// buffered audio up to `n_samples_len` samples.
    ///
    /// Returns `Ok(None)` on a short buffer or a silent probe; note that a
    /// silent probe's samples are still consumed (and counted as processed).
    fn process_step_vad(&mut self) -> Result<Option<Vec<Segment>>> {
        // Probe window: 2 seconds of audio at the whisper sample rate.
        let n_vad_samples = (WHISPER_SAMPLE_RATE * 2) as usize;
        if self.audio_buf.len() < n_vad_samples {
            return Ok(None);
        }
        let pcmf32_vad: Vec<f32> = self.audio_buf.drain(..n_vad_samples).collect();
        self.total_samples_processed += pcmf32_vad.len() as i64;
        // Energy-based silence check over the trailing 1000 ms of the probe.
        let is_silence = vad_simple(
            &pcmf32_vad,
            WHISPER_SAMPLE_RATE,
            1000,
            self.config.vad_thold,
            self.config.freq_thold,
        );
        if is_silence {
            return Ok(None);
        }
        // Speech detected: pull in whatever additional buffered audio is
        // available, up to the configured window length.
        let n_samples_len = self.n_samples_len;
        let additional = n_samples_len.saturating_sub(pcmf32_vad.len());
        let mut pcmf32 = pcmf32_vad;
        if additional > 0 {
            let available = additional.min(self.audio_buf.len());
            let extra: Vec<f32> = self.audio_buf.drain(..available).collect();
            self.total_samples_processed += extra.len() as i64;
            pcmf32.extend_from_slice(&extra);
        }
        let segments = self.run_inference(&pcmf32)?;
        self.n_iter += 1;
        Ok(Some(segments))
    }

    /// Runs a full decode over `audio` and collects the resulting segments.
    /// Previously collected tokens are passed as a prompt when context
    /// carry-over is enabled. Empty input yields an empty segment list.
    ///
    /// # Errors
    /// Propagates failures from `full` and `full_get_segment_text`.
    fn run_inference(&mut self, audio: &[f32]) -> Result<Vec<Segment>> {
        if audio.is_empty() {
            return Ok(Vec::new());
        }
        // Clone so the stored base params stay pristine across steps.
        let mut params = self.params.clone();
        if !self.config.no_context && !self.prompt_tokens.is_empty() {
            params = params.prompt_tokens(&self.prompt_tokens);
        }
        self.state.full(params, audio)?;
        let n_segments = self.state.full_n_segments();
        let mut segments = Vec::with_capacity(n_segments as usize);
        for i in 0..n_segments {
            let text = self.state.full_get_segment_text(i)?;
            let (start_ms, end_ms) = self.state.full_get_segment_timestamps(i);
            let speaker_turn_next = self.state.full_get_segment_speaker_turn_next(i);
            segments.push(Segment {
                start_ms,
                end_ms,
                text,
                speaker_turn_next,
            });
        }
        Ok(segments)
    }

    /// Replaces `prompt_tokens` with every token id from the most recent
    /// inference result, to be used as the next step's prompt.
    fn collect_prompt_tokens(&mut self) {
        self.prompt_tokens.clear();
        let n_segments = self.state.full_n_segments();
        for i in 0..n_segments {
            let token_count = self.state.full_n_tokens(i);
            for j in 0..token_count {
                self.prompt_tokens
                    .push(self.state.full_get_token_id(i, j));
            }
        }
    }

    /// Drains all remaining buffered audio: runs full steps while they
    /// produce output, then transcribes the leftover tail (with carry-over
    /// prepended in fixed mode, raw in VAD mode).
    ///
    /// NOTE(review): in VAD mode the loop stops at the first silent probe,
    /// so subsequent buffered audio is handled by the single tail inference
    /// below rather than probe-by-probe.
    pub fn flush(&mut self) -> Result<Vec<Segment>> {
        let mut all_segments = Vec::new();
        loop {
            match self.process_step()? {
                Some(segments) => all_segments.extend(segments),
                None => break,
            }
        }
        if !self.audio_buf.is_empty() {
            let remaining: Vec<f32> = self.audio_buf.drain(..).collect();
            self.total_samples_processed += remaining.len() as i64;
            if !self.use_vad {
                // Same windowing as process_step_fixed: prepend at most
                // keep + len - remaining samples of carried-over tail.
                let n_samples_take = self.pcmf32_old.len().min(
                    (self.n_samples_keep + self.n_samples_len)
                        .saturating_sub(remaining.len()),
                );
                let mut pcmf32 = Vec::with_capacity(n_samples_take + remaining.len());
                if n_samples_take > 0 && !self.pcmf32_old.is_empty() {
                    let start = self.pcmf32_old.len() - n_samples_take;
                    pcmf32.extend_from_slice(&self.pcmf32_old[start..]);
                }
                pcmf32.extend_from_slice(&remaining);
                let segments = self.run_inference(&pcmf32)?;
                all_segments.extend(segments);
            } else {
                let segments = self.run_inference(&remaining)?;
                all_segments.extend(segments);
            }
        }
        Ok(all_segments)
    }

    /// Clears all buffered audio and accumulated streaming state;
    /// configuration and decoding params are kept.
    pub fn reset(&mut self) {
        self.audio_buf.clear();
        self.pcmf32_old.clear();
        self.prompt_tokens.clear();
        self.n_iter = 0;
        self.total_samples_processed = 0;
    }

    /// Number of samples currently buffered and not yet processed.
    pub fn buffer_size(&self) -> usize {
        self.audio_buf.len()
    }

    /// Total number of samples drained from the buffer since the last reset.
    pub fn processed_samples(&self) -> i64 {
        self.total_samples_processed
    }
}
/// First-order high-pass filter applied in place.
///
/// Discretized RC filter with `alpha = dt / (RC + dt)`; `data[0]` seeds the
/// filter state and is left unmodified, matching the whisper.cpp `stream`
/// example.
fn high_pass_filter(data: &mut [f32], cutoff: f32, sample_rate: f32) {
    if data.is_empty() {
        return;
    }
    let rc = 1.0 / (2.0 * std::f32::consts::PI * cutoff);
    let dt = 1.0 / sample_rate;
    let alpha = dt / (rc + dt);
    let mut y = data[0];
    for i in 1..data.len() {
        y = alpha * (y + data[i] - data[i - 1]);
        data[i] = y;
    }
}

/// Energy-based voice-activity detection.
///
/// Returns `true` ("silence") when the mean absolute amplitude of the
/// trailing `last_ms` of `pcmf32` is at most `vad_thold` times the mean
/// absolute amplitude of the whole buffer. A high-pass filter at
/// `freq_thold` Hz is applied first when `freq_thold > 0`.
///
/// Buffers too short to contain a full `last_ms` window — and an empty
/// trailing window (`last_ms <= 0`) — are reported as silence.
fn vad_simple(
    pcmf32: &[f32],
    sample_rate: i32,
    last_ms: i32,
    vad_thold: f32,
    freq_thold: f32,
) -> bool {
    let n_samples = pcmf32.len();
    let n_samples_last = (sample_rate as usize * last_ms.max(0) as usize) / 1000;
    // Guard both "window longer than buffer" and "empty window". The latter
    // previously divided by zero, producing NaN and misreporting even pure
    // silence as speech (NaN <= x is always false).
    if n_samples_last == 0 || n_samples_last >= n_samples {
        return true;
    }
    let mut data = pcmf32.to_vec();
    if freq_thold > 0.0 {
        high_pass_filter(&mut data, freq_thold, sample_rate as f32);
    }
    // Single pass: accumulate total energy and trailing-window energy.
    let mut energy_all: f32 = 0.0;
    let mut energy_last: f32 = 0.0;
    for (i, &s) in data.iter().enumerate() {
        energy_all += s.abs();
        if i >= n_samples - n_samples_last {
            energy_last += s.abs();
        }
    }
    energy_all /= n_samples as f32;
    energy_last /= n_samples_last as f32;
    energy_last <= vad_thold * energy_all
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SamplingStrategy;
    use std::path::Path;

    /// Default config values must match the whisper.cpp `stream` example.
    #[test]
    fn test_config_defaults() {
        let config = WhisperStreamConfig::default();
        assert_eq!(config.step_ms, 3000);
        assert_eq!(config.length_ms, 10000);
        assert_eq!(config.keep_ms, 200);
        assert!((config.vad_thold - 0.6).abs() < f32::EPSILON);
        assert!((config.freq_thold - 100.0).abs() < f32::EPSILON);
        assert!(config.no_context);
    }

    /// Mirrors the normalization done in `WhisperStream::with_config`.
    /// This is pure arithmetic and needs no model file; it was previously
    /// (and wrongly) gated to run only when the model file was *missing*,
    /// so it silently did nothing on machines that had the model.
    #[test]
    fn test_config_normalization() {
        // keep_ms is clamped down to step_ms.
        let mut config = WhisperStreamConfig {
            step_ms: 2000,
            length_ms: 5000,
            keep_ms: 3000,
            ..Default::default()
        };
        config.keep_ms = config.keep_ms.min(config.step_ms);
        config.length_ms = config.length_ms.max(config.step_ms);
        assert_eq!(config.keep_ms, 2000);
        assert_eq!(config.length_ms, 5000);

        // length_ms is clamped up to step_ms.
        let mut config2 = WhisperStreamConfig {
            step_ms: 8000,
            length_ms: 5000,
            keep_ms: 200,
            ..Default::default()
        };
        config2.keep_ms = config2.keep_ms.min(config2.step_ms);
        config2.length_ms = config2.length_ms.max(config2.step_ms);
        assert_eq!(config2.length_ms, 8000);
        assert_eq!(config2.keep_ms, 200);
    }

    /// Spot-checks the `length/step - 1` new-line cadence formula.
    #[test]
    fn test_n_new_line_calculation() {
        let n = (10000i32 / 3000 - 1).max(1);
        assert_eq!(n, 2);
        let n = (10000i32 / 5000 - 1).max(1);
        assert_eq!(n, 1);
        let n = (10000i32 / 10000 - 1).max(1);
        assert_eq!(n, 1);
        let n = (10000i32 / 2000 - 1).max(1);
        assert_eq!(n, 4);
        let n_vad = 1i32;
        assert_eq!(n_vad, 1);
    }

    /// step_ms <= 0 must yield zero samples per step (i.e. VAD mode).
    #[test]
    fn test_vad_mode_detection() {
        let step_ms_values = [0, -1, -100];
        for step_ms in step_ms_values {
            let n_samples_step =
                (1e-3 * step_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
            assert_eq!(n_samples_step, 0, "step_ms={} should yield 0 samples", step_ms);
        }
        let n = (1e-3 * 3000.0 * WHISPER_SAMPLE_RATE as f64) as usize;
        assert_eq!(n, 48000);
    }

    /// Feeding audio grows the internal buffer sample-for-sample.
    /// Requires the tiny model on disk; skipped otherwise.
    #[test]
    fn test_feed_and_buffer() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if !Path::new(model_path).exists() {
            eprintln!("Skipping test_feed_and_buffer: model not found");
            return;
        }
        let ctx = WhisperContext::new(model_path).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        let mut stream = WhisperStream::new(&ctx, params).unwrap();
        assert_eq!(stream.buffer_size(), 0);
        let samples = vec![0.0f32; 16000];
        stream.feed_audio(&samples);
        assert_eq!(stream.buffer_size(), 16000);
        stream.feed_audio(&samples);
        assert_eq!(stream.buffer_size(), 32000);
    }

    /// All-zero audio must be classified as silence.
    #[test]
    fn test_vad_simple_silence() {
        let silence = vec![0.0f32; 16000];
        assert!(vad_simple(&silence, 16000, 100, 0.6, 100.0));
    }

    /// A buffer shorter than the trailing window is reported as silence.
    #[test]
    fn test_vad_simple_too_few_samples() {
        let short = vec![0.1f32; 100];
        assert!(vad_simple(&short, 16000, 1000, 0.6, 100.0));
    }

    /// The high-pass filter must attenuate a constant/low-frequency component.
    #[test]
    fn test_high_pass_filter_basic() {
        let mut data = vec![1.0, 0.0, 1.0, 0.0, 1.0];
        high_pass_filter(&mut data, 100.0, 16000.0);
        assert_ne!(data[2], 1.0);
    }

    /// reset() clears buffered audio and counters.
    /// Requires the tiny model on disk; skipped otherwise.
    #[test]
    fn test_reset() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if !Path::new(model_path).exists() {
            eprintln!("Skipping test_reset: model not found");
            return;
        }
        let ctx = WhisperContext::new(model_path).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        let mut stream = WhisperStream::new(&ctx, params).unwrap();
        stream.feed_audio(&vec![0.0f32; 16000]);
        assert_eq!(stream.buffer_size(), 16000);
        stream.reset();
        assert_eq!(stream.buffer_size(), 0);
        assert_eq!(stream.processed_samples(), 0);
    }

    /// One fixed step over 3 s of audio should produce a (possibly empty)
    /// segment list and advance the processed-sample counter.
    /// Requires the tiny model on disk; skipped otherwise.
    #[test]
    fn test_fixed_step_basic() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if !Path::new(model_path).exists() {
            eprintln!("Skipping test_fixed_step_basic: model not found");
            return;
        }
        let ctx = WhisperContext::new(model_path).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 })
            .language("en");
        let config = WhisperStreamConfig {
            step_ms: 3000,
            length_ms: 10000,
            keep_ms: 200,
            ..Default::default()
        };
        let mut stream = WhisperStream::with_config(&ctx, params, config).unwrap();
        let audio = vec![0.0f32; 48000];
        stream.feed_audio(&audio);
        let result = stream.process_step().unwrap();
        assert!(result.is_some(), "Should produce segments with enough audio");
        assert!(stream.processed_samples() > 0);
    }

    /// With no_context = false the stream should still step normally
    /// (prompt tokens are collected internally at new-line boundaries).
    /// Requires the tiny model on disk; skipped otherwise.
    #[test]
    fn test_prompt_propagation() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if !Path::new(model_path).exists() {
            eprintln!("Skipping test_prompt_propagation: model not found");
            return;
        }
        let ctx = WhisperContext::new(model_path).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 })
            .language("en");
        let config = WhisperStreamConfig {
            step_ms: 3000,
            length_ms: 6000,
            keep_ms: 200,
            no_context: false,
            ..Default::default()
        };
        let mut stream = WhisperStream::with_config(&ctx, params, config).unwrap();
        let audio = vec![0.0f32; 48000];
        stream.feed_audio(&audio);
        let result = stream.process_step().unwrap();
        assert!(result.is_some());
        assert!(stream.processed_samples() > 0);
    }
}