use crate::transcribe::streaming_whisper_params;
use whisper_rs::WhisperContext;
/// Number of newly fed samples that must accumulate before `StreamingWhisper::feed`
/// attempts another partial transcription (two seconds at 16,000 samples per second).
const PARTIAL_INTERVAL_SAMPLES: usize = 2 * 16_000;
/// Minimum number of buffered samples required before any transcription is attempted
/// (one second at 16,000 samples per second).
const MIN_TRANSCRIBE_SAMPLES: usize = 16_000;
/// One transcription produced by `StreamingWhisper`, either partial or final.
#[derive(Debug, Clone)]
pub struct StreamingResult {
/// Transcribed text: the trimmed, non-empty segments joined by single spaces.
pub text: String,
/// `true` when the result came from `finalize`, `false` for a partial emitted by `feed`.
pub is_final: bool,
/// Length of the audio that was transcribed, in seconds (sample count / 16000).
pub duration_secs: f64,
}
/// Accumulates audio samples and periodically re-transcribes the whole buffer.
pub struct StreamingWhisper {
// Every sample fed since construction or the last `reset`.
audio_buffer: Vec<f32>,
// Samples fed since the last partial transcription attempt.
samples_since_partial: usize,
// Text of the most recently emitted result; used to skip duplicate partials.
last_partial: String,
// Thread count passed to the whisper parameters on each transcription.
n_threads: i32,
// Language passed to the whisper parameters.
// NOTE(review): `None` presumably means auto-detect — confirm against whisper-rs.
language: Option<String>,
// Set on the first transcription attempt; later state creations run with
// stderr suppressed.
has_created_state: bool,
}
impl StreamingWhisper {
    /// Creates an empty streaming transcriber.
    ///
    /// `language` is forwarded to the whisper parameters on every transcription.
    /// The buffer is pre-sized for 30 seconds of audio at 16,000 samples per second.
    pub fn new(language: Option<String>) -> Self {
        Self {
            audio_buffer: Vec::with_capacity(16000 * 30),
            samples_since_partial: 0,
            last_partial: String::new(),
            n_threads: num_cpus(),
            language,
            has_created_state: false,
        }
    }

    /// Appends `samples` to the buffer and, when enough new audio has arrived,
    /// transcribes the whole buffer as a partial result.
    ///
    /// A transcription is attempted once at least `PARTIAL_INTERVAL_SAMPLES` samples
    /// have been fed since the previous attempt and the buffer holds at least
    /// `MIN_TRANSCRIBE_SAMPLES` samples.
    ///
    /// Returns `None` when no transcription was attempted, when it failed, when it
    /// produced no text, or when the text equals the previously emitted result.
    pub fn feed(&mut self, samples: &[f32], ctx: &WhisperContext) -> Option<StreamingResult> {
        self.audio_buffer.extend_from_slice(samples);
        self.samples_since_partial += samples.len();

        let interval_reached = self.samples_since_partial >= PARTIAL_INTERVAL_SAMPLES;
        let enough_audio = self.audio_buffer.len() >= MIN_TRANSCRIBE_SAMPLES;
        if interval_reached && enough_audio {
            // Reset before transcribing so a failed attempt still waits a full interval.
            self.samples_since_partial = 0;
            return self.transcribe(ctx, false);
        }
        None
    }

    /// Transcribes everything buffered so far and marks the result as final.
    ///
    /// Returns `None` when fewer than `MIN_TRANSCRIBE_SAMPLES` samples are buffered,
    /// when transcription fails, or when it yields no text. The buffer is left
    /// untouched; call [`reset`](Self::reset) to start a new utterance.
    pub fn finalize(&mut self, ctx: &WhisperContext) -> Option<StreamingResult> {
        if self.audio_buffer.len() < MIN_TRANSCRIBE_SAMPLES {
            return None;
        }
        self.transcribe(ctx, true)
    }

    /// Discards the buffered audio, the partial-interval counter and the last emitted text.
    ///
    /// `has_created_state` is intentionally left as is.
    pub fn reset(&mut self) {
        self.audio_buffer.clear();
        self.samples_since_partial = 0;
        self.last_partial.clear();
    }

    /// Length of the buffered audio in seconds, assuming 16,000 samples per second.
    pub fn duration_secs(&self) -> f64 {
        self.audio_buffer.len() as f64 / 16000.0
    }

    /// Runs whisper over the entire buffer and assembles the segment text.
    ///
    /// The first state creation runs with stderr untouched; later ones run inside
    /// `suppress_stderr` (presumably so whisper's initialisation output appears only once).
    ///
    /// Returns `None` when the state cannot be created, inference fails, the text is
    /// empty, or (for partials) the text is unchanged since the last emitted result.
    fn transcribe(&mut self, ctx: &WhisperContext, is_final: bool) -> Option<StreamingResult> {
        let created = if self.has_created_state {
            suppress_stderr(|| ctx.create_state())
        } else {
            self.has_created_state = true;
            ctx.create_state()
        };
        let mut state = match created {
            Ok(state) => state,
            Err(e) => {
                // Log instead of silently dropping the error, matching the handling
                // of inference failures below.
                tracing::warn!("streaming whisper state creation failed: {}", e);
                return None;
            }
        };

        let mut params = streaming_whisper_params();
        params.set_n_threads(self.n_threads);
        params.set_language(self.language.as_deref());

        let start = std::time::Instant::now();
        if let Err(e) = state.full(params, &self.audio_buffer) {
            tracing::warn!("streaming whisper failed: {}", e);
            return None;
        }
        let elapsed_ms = start.elapsed().as_millis();
        let duration_secs = self.duration_secs();

        // Join the trimmed, non-empty segments with single spaces. Segments that are
        // missing or cannot be converted to text are skipped.
        let num_segments = state.full_n_segments();
        let mut text = String::new();
        for i in 0..num_segments {
            if let Some(seg) = state.get_segment(i) {
                if let Ok(t) = seg.to_str_lossy() {
                    let t = t.trim();
                    if !t.is_empty() {
                        if !text.is_empty() {
                            text.push(' ');
                        }
                        text.push_str(t);
                    }
                }
            }
        }

        // `text` never carries leading or trailing whitespace here, because every
        // piece was trimmed before being appended.
        if text.is_empty() {
            return None;
        }
        // Do not re-emit a partial that is identical to the last emitted result.
        if !is_final && text == self.last_partial {
            return None;
        }

        tracing::debug!(
            partial = !is_final,
            words = text.split_whitespace().count(),
            audio_secs = format!("{:.1}", duration_secs),
            whisper_ms = elapsed_ms,
            "streaming transcription"
        );

        self.last_partial.clone_from(&text);
        Some(StreamingResult {
            text,
            is_final,
            duration_secs,
        })
    }
}
/// Runs `f` with the process's stderr temporarily redirected to `/dev/null`.
///
/// On Unix the stderr descriptor is duplicated, pointed at `/dev/null` for the
/// duration of `f`, and then restored. The restore happens in a `Drop` guard, so
/// stderr is put back (and the duplicate closed) even if `f` panics.
///
/// If the descriptor cannot be duplicated, or on non-Unix targets, `f` runs with
/// stderr untouched. If `/dev/null` cannot be opened, `f` also runs with stderr
/// untouched, and the guard's restore is then a no-op.
///
/// The redirection acts on the file descriptor itself, so output written to
/// stderr by other threads during the call is discarded as well.
fn suppress_stderr<T>(f: impl FnOnce() -> T) -> T {
    #[cfg(unix)]
    {
        use std::os::unix::io::{AsRawFd, RawFd};

        /// Puts the saved descriptor back onto `target` and closes it when dropped.
        struct RestoreStderr {
            saved: RawFd,
            target: RawFd,
        }

        impl Drop for RestoreStderr {
            fn drop(&mut self) {
                // SAFETY: `saved` came from a successful `dup` and is owned solely by
                // this guard; `target` is the process's stderr descriptor. Neither call
                // touches Rust-managed memory.
                unsafe {
                    libc::dup2(self.saved, self.target);
                    libc::close(self.saved);
                }
            }
        }

        let stderr_fd = std::io::stderr().as_raw_fd();
        // SAFETY: `dup` only duplicates a descriptor number; failure is reported as a
        // negative value, which is checked below.
        let saved = unsafe { libc::dup(stderr_fd) };
        if saved >= 0 {
            let _restore = RestoreStderr {
                saved,
                target: stderr_fd,
            };
            if let Ok(devnull) = std::fs::OpenOptions::new().write(true).open("/dev/null") {
                // SAFETY: both descriptors are open. `dup2` gives stderr its own
                // reference to `/dev/null`, so `devnull` may be closed right after.
                unsafe { libc::dup2(devnull.as_raw_fd(), stderr_fd) };
            }
            // `_restore` is dropped after `f` returns, or while unwinding if it panics.
            return f();
        }
    }
    f()
}
/// Returns the value used as the whisper thread count for new `StreamingWhisper` instances.
///
/// Thin wrapper that delegates to `whisper_guard::params::num_cpus`.
/// NOTE(review): how that helper derives the count is not visible from this file — confirm there.
fn num_cpus() -> i32 {
whisper_guard::params::num_cpus()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed transcriber holds no audio.
    #[test]
    fn new_streaming_whisper_has_empty_buffer() {
        let sw = StreamingWhisper::new(None);
        assert_eq!(sw.duration_secs(), 0.0);
        assert!(sw.audio_buffer.is_empty());
    }

    /// One second of audio stays below the partial-transcription interval.
    ///
    /// `feed` is not called directly because it needs a `WhisperContext`, which
    /// these unit tests do not construct. The test mirrors `feed`'s bookkeeping by
    /// hand and checks that the interval condition guarding transcription is not met.
    #[test]
    fn feed_below_interval_returns_none() {
        let mut sw = StreamingWhisper::new(None);
        let silence = vec![0.0f32; 16000];
        sw.audio_buffer.extend_from_slice(&silence);
        sw.samples_since_partial += silence.len();
        assert_eq!(sw.duration_secs(), 1.0);
        assert_eq!(sw.samples_since_partial, 16000);
        // This is the condition under which `feed` returns `None` without transcribing.
        assert!(sw.samples_since_partial < PARTIAL_INTERVAL_SAMPLES);
    }

    /// `reset` clears the buffer, the interval counter and the last emitted text.
    #[test]
    fn reset_clears_state() {
        let mut sw = StreamingWhisper::new(Some("en".into()));
        sw.audio_buffer.extend_from_slice(&[0.0; 16000]);
        sw.samples_since_partial = 16000;
        sw.last_partial = "hello".into();
        sw.reset();
        assert!(sw.audio_buffer.is_empty());
        assert_eq!(sw.samples_since_partial, 0);
        assert!(sw.last_partial.is_empty());
        assert_eq!(sw.duration_secs(), 0.0);
    }
}