use anyhow::{Context, Result, ensure};
use whisper_rs::WhisperContext;
use crate::audio_pipeline::TARGET_SAMPLE_RATE;
use crate::decoder::SamplesSink;
use crate::opts::Opts;
use crate::segment_encoder::SegmentEncoder;
use super::segments::{run_whisper_full, to_segment};
/// Largest window (in seconds) buffered before a transcription pass is forced.
const DEFAULT_MAX_BUFFER_SECONDS: usize = 30;
/// Upper bound on the exponent used to back off inference after passes that emit nothing.
const MAX_BACKOFF_SHIFT: u32 = 4;
/// Incremental transcriber: buffers mono samples at `TARGET_SAMPLE_RATE` and
/// periodically runs whisper over the not-yet-consumed tail of the buffer,
/// forwarding finished segments to a `SegmentEncoder`.
pub(crate) struct BufferedSegmentTranscriber<'a> {
    /// Raw sample buffer; `samples[head..]` is the window still to transcribe.
    samples: Vec<f32>,
    /// Index of the first unconsumed sample in `samples`.
    head: usize,
    /// Total samples consumed from the stream so far; used as the time offset
    /// added to segments whisper reports relative to the current window.
    advanced_samples: usize,
    /// Window length at which the next (non-forced) inference pass may run.
    next_infer_at_samples: usize,
    /// Consecutive passes that emitted nothing; drives the back-off.
    no_progress_runs: u32,
    /// Smallest window (in samples) worth running inference on mid-stream.
    min_window_samples: usize,
    /// Window length (in samples) at which a pass is forced.
    max_window_samples: usize,
    ctx: &'a WhisperContext,
    opts: &'a Opts,
    encoder: &'a mut dyn SegmentEncoder,
}
impl<'a> BufferedSegmentTranscriber<'a> {
    /// Creates a transcriber that writes finished segments to `encoder`.
    ///
    /// The minimum window is `opts.incremental_min_window_seconds`, clamped to
    /// at least one second; the maximum window is `DEFAULT_MAX_BUFFER_SECONDS`.
    /// Both are converted to sample counts at `TARGET_SAMPLE_RATE`.
    pub(crate) fn new(
        ctx: &'a WhisperContext,
        opts: &'a Opts,
        encoder: &'a mut dyn SegmentEncoder,
    ) -> Self {
        let min_window_seconds = opts.incremental_min_window_seconds.max(1);
        let min_window_samples = TARGET_SAMPLE_RATE as usize * min_window_seconds;
        let max_window_samples = TARGET_SAMPLE_RATE as usize * DEFAULT_MAX_BUFFER_SECONDS;
        Self {
            ctx,
            opts,
            encoder,
            min_window_samples,
            max_window_samples,
            // The first pass runs as soon as one minimum window is buffered.
            next_infer_at_samples: min_window_samples,
            no_progress_runs: 0,
            samples: Vec::new(),
            head: 0,
            advanced_samples: 0,
        }
    }

    /// Flushes the stream: transcribes whatever is still buffered (emitting
    /// every segment whisper returns), then empties the sample buffer.
    ///
    /// # Errors
    /// Propagates errors from whisper, segment conversion, and the encoder.
    pub(crate) fn finish(&mut self) -> Result<()> {
        let _ = self.process_available(true)?;
        self.samples.clear();
        self.head = 0;
        Ok(())
    }

    /// The buffered samples that have not been consumed yet.
    fn window(&self) -> &[f32] {
        &self.samples[self.head..]
    }

    /// Length of the unconsumed window.
    fn window_len(&self) -> usize {
        self.samples.len().saturating_sub(self.head)
    }

    /// Physically removes consumed samples once the dead prefix is at least one
    /// second long or at least half of the buffer, rather than on every advance.
    fn maybe_compact(&mut self) {
        if self.head == 0 {
            return;
        }
        let should_compact =
            self.head >= TARGET_SAMPLE_RATE as usize || self.head >= self.samples.len() / 2;
        if should_compact {
            self.samples.drain(..self.head);
            self.head = 0;
        }
    }

    /// Marks `consumed` samples at the start of the window as done, keeping the
    /// stream time offset (`advanced_samples`) in sync with the buffer head.
    fn advance_head(&mut self, consumed: usize) {
        self.head += consumed;
        self.advanced_samples += consumed;
        self.maybe_compact();
    }

    /// Drops the whole current window without emitting anything and resets the
    /// inference schedule to its initial state.
    fn discard_window(&mut self, win_len: usize) {
        self.advance_head(win_len);
        self.no_progress_runs = 0;
        self.next_infer_at_samples = self.min_window_samples;
    }

    /// Records a pass that emitted nothing and backs off the next inference
    /// threshold (exponentially, capped at the maximum window).
    fn record_no_progress(&mut self, win_len: usize) {
        self.no_progress_runs = self.no_progress_runs.saturating_add(1);
        self.next_infer_at_samples = next_infer_threshold(
            win_len,
            self.min_window_samples,
            self.max_window_samples,
            self.no_progress_runs,
        );
    }

    /// Runs whisper over the current window when the scheduling rules allow it
    /// and writes completed segments to the encoder.
    ///
    /// * Empty windows, and mid-stream windows shorter than the minimum, are skipped.
    /// * At end of stream, or once the window reaches the maximum size, the pass
    ///   is *forced*: every returned segment is emitted.
    /// * Otherwise the pass only runs once the window reaches
    ///   `next_infer_at_samples`, and the last segment is held back because
    ///   more audio may still extend it.
    ///
    /// Returns [`Progress::Advanced`] when the window start moved forward and
    /// [`Progress::NoOp`] otherwise.
    ///
    /// # Errors
    /// Propagates errors from whisper, segment conversion, and the encoder.
    fn process_available(&mut self, end_of_stream: bool) -> Result<Progress> {
        let win_len = self.window_len();
        if win_len == 0 {
            return Ok(Progress::NoOp);
        }
        if !end_of_stream && win_len < self.min_window_samples {
            return Ok(Progress::NoOp);
        }
        let force_flush = end_of_stream || win_len >= self.max_window_samples;
        if !force_flush && win_len < self.next_infer_at_samples {
            return Ok(Progress::NoOp);
        }

        let state = run_whisper_full(self.ctx, self.opts, self.window())?;
        // A negative count is treated the same as "no segments".
        let n_segments = usize::try_from(state.full_n_segments()).unwrap_or(0);
        if n_segments == 0 {
            if force_flush {
                // Whisper found nothing in a window it was forced to process.
                // Keeping it would let the buffer grow without bound and make every
                // later chunk re-run whisper over the whole backlog, so drop it.
                self.discard_window(win_len);
                return Ok(Progress::Advanced);
            }
            self.record_no_progress(win_len);
            return Ok(Progress::NoOp);
        }

        // Unless forced, hold back the last segment: it may still be growing.
        let emit_count = if force_flush { n_segments } else { n_segments - 1 };
        if emit_count == 0 {
            self.record_no_progress(win_len);
            return Ok(Progress::NoOp);
        }

        // Divide in f64 and round once; `usize as f32` alone loses integer
        // precision after ~2^24 samples.
        let offset_seconds = (self.advanced_samples as f64 / TARGET_SAMPLE_RATE as f64) as f32;
        for segment_idx in 0..emit_count {
            let whisper_segment = state
                .get_segment(segment_idx as i32)
                .with_context(|| format!("whisper segment {segment_idx} was missing"))?;
            let mut segment = to_segment(whisper_segment)?;
            apply_time_offset(&mut segment, offset_seconds);
            self.encoder
                .write_segment(&segment)
                .map_err(anyhow::Error::new)?;
        }

        let last_emitted_idx = emit_count - 1;
        let last_emitted = state
            .get_segment(last_emitted_idx as i32)
            .with_context(|| format!("whisper segment {last_emitted_idx} was missing"))?;
        let end_samples = segment_end_samples(last_emitted.end_timestamp(), win_len)?;
        // At end of stream everything has been emitted; otherwise keep the audio
        // after the last emitted segment so it is re-transcribed with more context.
        let consumed = if end_of_stream { win_len } else { end_samples };
        self.advance_head(consumed);

        self.no_progress_runs = 0;
        self.next_infer_at_samples = if force_flush {
            self.min_window_samples
        } else {
            self.window_len() + self.min_window_samples
        };
        Ok(Progress::Advanced)
    }
}
impl SamplesSink for BufferedSegmentTranscriber<'_> {
    /// Buffers the incoming chunk and gives the transcriber one chance to run
    /// inference on it. Always asks the producer to keep going.
    fn on_samples(&mut self, samples_16k_mono: &[f32]) -> Result<bool> {
        self.samples.extend(samples_16k_mono.iter().copied());
        self.process_available(false).map(|_progress| true)
    }
}
/// Outcome of one pass over the buffered window.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Progress {
    /// The window start did not move.
    NoOp,
    /// The window start moved forward.
    Advanced,
}
/// Window length at which the next inference pass should run after
/// `no_progress_runs` consecutive passes that emitted nothing.
///
/// The step past `current_len` is `min_window_samples`, doubled once per extra
/// failed pass (up to `MAX_BACKOFF_SHIFT` doublings). The result never exceeds
/// `max_window_samples`.
fn next_infer_threshold(
    current_len: usize,
    min_window_samples: usize,
    max_window_samples: usize,
    no_progress_runs: u32,
) -> usize {
    let doublings = match no_progress_runs {
        0 => 0,
        runs => (runs - 1).min(MAX_BACKOFF_SHIFT),
    };
    let growth = min_window_samples.saturating_mul(2usize.pow(doublings));
    max_window_samples.min(current_len.saturating_add(growth))
}
/// Converts a whisper end timestamp (centiseconds, relative to the window) into
/// a sample count within the window.
///
/// The result is at least 1 so callers always make progress, and at most
/// `available_samples`.
///
/// # Errors
/// Fails if the timestamp is negative or does not fit in `usize`.
fn segment_end_samples(end_timestamp_cs: i64, available_samples: usize) -> Result<usize> {
    ensure!(
        end_timestamp_cs >= 0,
        "whisper returned negative end timestamp: {end_timestamp_cs}"
    );
    let centiseconds = usize::try_from(end_timestamp_cs)
        .context("whisper end timestamp did not fit in usize")?;
    let samples = centiseconds.saturating_mul(TARGET_SAMPLE_RATE as usize) / 100;
    // Not `clamp`: that would panic when `available_samples == 0`.
    Ok(samples.max(1).min(available_samples))
}
/// Shifts a segment, and every token inside it, later in time by `offset_seconds`.
fn apply_time_offset(segment: &mut crate::segments::Segment, offset_seconds: f32) {
    segment.tokens.iter_mut().for_each(|token| {
        token.start_seconds += offset_seconds;
        token.end_seconds += offset_seconds;
    });
    segment.start_seconds += offset_seconds;
    segment.end_seconds += offset_seconds;
}