use std::time::Instant;
use super::{
encoder::{StreamingEncoder, StreamingEncoderBackend},
mel_spectrogram::IncrementalMelSpectrogram,
retry_state::SessionRetryState,
types::{StreamingConfig, StreamingStats, TranscriptionEvent},
};
use crate::{
Array,
error::{LayerKeyedPayload, Result},
};
/// Decoder backend used by `StreamingInferenceSession` to turn encoder output into token ids.
pub trait StreamingDecoderBackend {
/// Decodes `audio_features`, continuing from `confirmed_token_ids`.
///
/// The session treats the returned ids as beginning with the confirmed prefix: it skips
/// `confirmed_token_ids.len()` entries to find the newly generated tokens
/// (NOTE(review): contract inferred from the caller — confirm against implementations).
/// Finalization of completed windows calls this with an empty prefix. `max_tokens` is the
/// token limit the session computes for the call (presumably an upper bound — confirm).
fn decode_all_tokens(
&self,
audio_features: &Array,
confirmed_token_ids: &[u32],
config: &StreamingConfig,
max_tokens: usize,
) -> Result<Vec<u32>>;
}
/// Tokenizer used by the session to render token ids as text.
pub trait StreamingTokenizer {
/// Converts a sequence of token ids into text.
fn decode_ids(&self, ids: &[u32]) -> String;
}
/// Transcript state shared by the decode, promotion and finalization paths of a session.
#[derive(Debug, Default)]
struct SessionSharedState {
// Text of windows already frozen/finalized; new window text is appended to it.
completed_text: String,
// Tokens accepted for the current (not yet completed) window.
confirmed_token_ids: Vec<u32>,
// Tokens from the latest decode pass that have not yet been promoted to confirmed.
provisional_token_ids: Vec<u32>,
// Parallel to `provisional_token_ids`: when each token was first seen.
provisional_first_seen: Vec<Instant>,
// Parallel to `provisional_token_ids`: how many passes agreed on each token.
provisional_agreement_counts: Vec<usize>,
// Decoded text of `confirmed_token_ids`.
confirmed_text: String,
}
/// Snapshot of session state taken before a decode pass, consumed by `promote_tokens`.
struct DecodePassParams<'a> {
// Encoder output for the pending (incomplete) window being decoded.
audio_features: &'a Array,
// Confirmed tokens at the start of the pass; only its length is used during promotion.
confirmed_token_ids: Vec<u32>,
// Display prefix at the start of the pass. NOTE(review): promotion recomputes the prefix and
// never reads this value.
display_prefix: String,
// Provisional tokens, first-seen times and agreement counts from the previous pass.
prev_provisional: Vec<u32>,
prev_first_seen: Vec<Instant>,
prev_agreement_counts: Vec<usize>,
// Number of agreeing passes a token needs before it can be promoted.
min_agreement_passes: usize,
}
/// Incremental transcription session: audio -> mel frames -> encoder windows -> decoded text.
///
/// `B` is the encoder backend, `D` the decoder backend and `T` the tokenizer.
pub struct StreamingInferenceSession<B, D, T> {
decoder: D,
tokenizer: T,
config: StreamingConfig,
mel_processor: IncrementalMelSpectrogram,
encoder: StreamingEncoder<B>,
// Transcript state (completed / confirmed / provisional).
shared: SessionSharedState,
// False after `stop` completes or `cancel`; `feed_audio` is then a no-op.
is_active: bool,
// Running count of samples passed to `feed_audio` since creation/reset.
total_samples_fed: usize,
// Time of the last interval-gated decode pass (boundary finalize passes do not update it).
last_decode_time: Option<Instant>,
// End of the faster-decode window opened when a new encoder window completes.
boundary_fast_decode_until: Option<Instant>,
// Set when the encoder has produced windows or holds pending frames since the last decode.
has_new_encoder_content: bool,
// Number of encoder windows whose text has already been folded into `completed_text`.
frozen_window_count: usize,
// Bookkeeping for steps that failed and must be retried on the next call.
retry_state: SessionRetryState,
}
impl<B, D, T> StreamingInferenceSession<B, D, T>
where
B: StreamingEncoderBackend,
D: StreamingDecoderBackend,
T: StreamingTokenizer,
{
/// Builds an active session around the given decoder, tokenizer and encoder backend.
///
/// The mel front end is created from `sample_rate` and `n_mels` together with the fixed
/// arguments `400` and `160` (presumably window length and hop length in samples — confirm
/// against `IncrementalMelSpectrogram::new`). The encoder cache size comes from
/// `config.max_cached_windows()`, and `overlap_frames` is passed through to the encoder.
///
/// # Errors
/// Fails if the mel spectrogram processor cannot be constructed.
pub fn new(
    decoder: D,
    tokenizer: T,
    config: StreamingConfig,
    encoder_backend: B,
    sample_rate: u32,
    n_mels: usize,
    overlap_frames: usize,
) -> Result<Self> {
    let mel_processor = IncrementalMelSpectrogram::new(sample_rate, 400, 160, n_mels)?;
    let encoder =
        StreamingEncoder::new(encoder_backend, config.max_cached_windows(), overlap_frames);

    let session = Self {
        decoder,
        tokenizer,
        config,
        mel_processor,
        encoder,
        shared: SessionSharedState::default(),
        is_active: true,
        total_samples_fed: 0,
        last_decode_time: None,
        boundary_fast_decode_until: None,
        has_new_encoder_content: false,
        frozen_window_count: 0,
        retry_state: SessionRetryState::new(),
    };
    Ok(session)
}
/// Returns the streaming configuration supplied to `new`.
#[inline]
pub fn config(&self) -> &StreamingConfig {
    &self.config
}
/// Total number of samples passed to `feed_audio` since creation or the last `reset`.
///
/// Samples are counted on entry to an active `feed_audio`, even if that call later returns
/// early or fails.
#[inline]
pub fn total_samples_fed(&self) -> usize {
    self.total_samples_fed
}
/// Number of encoder windows reported by the underlying `StreamingEncoder`.
#[inline]
pub fn encoded_window_count(&self) -> usize {
    self.encoder.encoded_window_count()
}
/// Whether the session still accepts audio.
///
/// True after `new`/`reset`; false once `stop` completes or after `cancel`.
#[inline]
pub fn is_active(&self) -> bool {
    self.is_active
}
/// Test-only read access to the retry bookkeeping.
#[cfg(test)]
pub(super) fn retry_state(&self) -> &SessionRetryState {
&self.retry_state
}
/// Test-only mutable access to the retry bookkeeping (e.g. to inject obligations).
#[cfg(test)]
pub(super) fn retry_state_mut(&mut self) -> &mut SessionRetryState {
&mut self.retry_state
}
/// Feeds a chunk of audio samples into the session and returns any events produced.
///
/// Does nothing (empty result) when the session is inactive. Otherwise it first discharges
/// any retry obligation left by an earlier failed call. If an obligation is still
/// outstanding afterwards, it returns the events gathered so far without processing
/// `samples`. The samples then go through the mel processor and encoder, and a decode pass
/// runs when the decode gating below allows it.
///
/// # Errors
/// Propagates failures from retry discharge, mel processing, encoding and decoding. A failed
/// decode leaves the decode-owed flag armed (when it was armed) so the next call retries it.
pub fn feed_audio(&mut self, samples: &[f32]) -> Result<Vec<TranscriptionEvent>> {
if !self.is_active {
return Ok(Vec::new());
}
// Counted before any fallible step or early return below.
self.total_samples_fed = self.total_samples_fed.saturating_add(samples.len());
let mut events: Vec<TranscriptionEvent> = Vec::new();
let (discharge_events, discharge_ran_decode) = self.discharge_retry_obligation()?;
events.extend(discharge_events);
// NOTE(review): this early return does not pass `samples` to the mel processor, although they
// were already counted above. Confirm that callers resubmit them; otherwise this audio is lost.
if self.retry_state.has_obligation() {
return Ok(events);
}
let mel_opt = self.mel_processor.process(samples)?;
let new_windows = if let Some(mel_frames) = mel_opt.as_ref() {
self.encoder.feed(mel_frames)?
} else {
0
};
if new_windows > 0 || self.encoder.has_pending_frames() {
self.has_new_encoder_content = true;
}
let now = Instant::now();
// A newly completed window (re)opens the boundary boost window, which permits faster decodes.
// NOTE(review): `Duration::from_secs_f64` panics on a non-finite boost value.
if new_windows > 0 {
let boost = self.config.boundary_boost_seconds().max(0.0);
if boost > 0.0 {
self.boundary_fast_decode_until = Some(now + std::time::Duration::from_secs_f64(boost));
} else {
self.boundary_fast_decode_until = None;
}
}
// Inside the boost window use the shorter of the boundary and normal intervals (each floored at
// 0.05 s). Otherwise clear the expired boost and use the normal interval.
let effective_decode_interval_seconds = if let Some(until) = self.boundary_fast_decode_until
&& now < until
{
let fast = self.config.boundary_decode_interval_seconds().max(0.05);
let normal = self.config.decode_interval_seconds().max(0.05);
fast.min(normal)
} else {
self.boundary_fast_decode_until = None;
self.config.decode_interval_seconds().max(0.05)
};
let has_pending_retries =
self.config.finalize_completed_windows() && !self.retry_state.finalize_queue().is_empty();
// Boundary finalization and queued finalize retries always decode. Otherwise decode once the
// interval has elapsed, or, on the very first decode, as soon as there is encoder content.
let should_decode =
if (self.config.finalize_completed_windows() && new_windows > 0) || has_pending_retries {
true
} else if let Some(last) = self.last_decode_time {
now.duration_since(last).as_secs_f64() >= effective_decode_interval_seconds
} else {
self.has_new_encoder_content
};
// The discharge above already ran a decode for this state; only decode again for new windows.
let skip_normal_decode = discharge_ran_decode && new_windows == 0;
if should_decode && (self.has_new_encoder_content || has_pending_retries) && !skip_normal_decode
{
self.has_new_encoder_content = false;
// Boundary finalize passes do not restart the regular decode interval.
let is_boundary_finalize_pass = self.config.finalize_completed_windows() && new_windows > 0;
if !is_boundary_finalize_pass {
self.last_decode_time = Some(now);
}
// Armed before the fallible decode so that an error leaves a decode-owed obligation that
// `discharge_retry_obligation` runs on the next call.
if new_windows > 0 || has_pending_retries {
self.retry_state.arm_decode_owed();
}
let decode_events = self.run_decode_pass()?;
events.extend(decode_events);
self.retry_state.clear_decode_owed();
}
Ok(events)
}
/// Flushes all remaining audio, finalizes the transcript and deactivates the session.
///
/// On success the returned events end with an `ended` event carrying the full text. The
/// encoder, mel processor, boost window and retry state are then reset, but the transcript
/// state and `total_samples_fed` are kept.
///
/// If a step fails, the error is returned and enough retry state is left armed that calling
/// `stop` again resumes from the failed step. Returns an empty list when the session is
/// already inactive and nothing is owed.
///
/// # Errors
/// Propagates mel, encoder, decoder and payload-clone failures.
pub fn stop(&mut self) -> Result<Vec<TranscriptionEvent>> {
if !self.is_active && !self.retry_state.has_obligation() {
return Ok(Vec::new());
}
let mut events: Vec<TranscriptionEvent> = Vec::new();
// Fast path: a previous `stop` failed only in the final partial-window decode. Retry just that
// step, provided nothing else is outstanding.
if self.retry_state.has_pending_stop_partial_decode()
&& !self.retry_state.has_pending_stop_encoder_feed()
&& !self.retry_state.has_decode_owed()
&& self.retry_state.finalize_queue().is_empty()
{
let audio_features = self
.retry_state
.take_stop_partial_decode_features()
.expect("guard above asserted has_pending_stop_partial_decode");
// Re-arm a copy before the fallible decode so that another failure can be retried again.
let reinstate = match clone_partial_decode_payload(audio_features.as_ref()) {
Ok(p) => p,
Err(e) => {
self.retry_state.arm_stop_partial_decode(audio_features);
return Err(e);
}
};
self.retry_state.arm_stop_partial_decode(reinstate);
self.finalize_partial_window_and_emit_ended(audio_features, &mut events)?;
let _ = self.retry_state.take_stop_partial_decode_features();
self.is_active = false;
self.encoder.reset();
self.mel_processor.reset();
self.boundary_fast_decode_until = None;
self.retry_state.clear_all();
return Ok(events);
}
let (discharge_events, _ran_decode) = self.discharge_retry_obligation()?;
events.extend(discharge_events);
// Still owing work (e.g. deferred until `resume_at`): report progress and stay active.
if self.retry_state.has_obligation() {
return Ok(events);
}
// Staged so that a failing flush leaves an obligation that `discharge_retry_obligation`
// retries.
self.retry_state.stage_stop_mel_flush();
let mel_opt = self.mel_processor.flush()?;
self.retry_state.clear_stop_mel_flush();
if let Some(mel_frames) = mel_opt {
// Fed via retry_state, presumably so that the frames survive a failed encoder feed.
self.retry_state.stage_stop_encoder_feed(mel_frames);
let _drain_window_count = self
.retry_state
.discharge_stop_encoder_feed(&mut self.encoder)?;
// NOTE(review): clears any decode-owed flag at this point; intent not visible in this file.
self.retry_state.clear_decode_owed();
}
if self.config.finalize_completed_windows() {
let drained = self.encoder.drain_newly_encoded_windows();
for window in drained {
self.retry_state.enqueue_finalize(window);
}
if !self.retry_state.finalize_queue().is_empty() {
let finalize_events = self.finalize_completed_windows()?;
events.extend(finalize_events);
}
} else {
self.freeze_completed_windows();
}
let audio_features = self.encoder.encode_pending()?;
// Arm a copy of the partial-window features so a decode failure can be retried via the fast
// path.
let reinstate = clone_partial_decode_payload(audio_features.as_ref())?;
self.retry_state.arm_stop_partial_decode(reinstate);
self.finalize_partial_window_and_emit_ended(audio_features, &mut events)?;
let _ = self.retry_state.take_stop_partial_decode_features();
self.is_active = false;
self.encoder.reset();
self.mel_processor.reset();
self.boundary_fast_decode_until = None;
self.retry_state.clear_all();
Ok(events)
}
pub fn cancel(&mut self) {
self.is_active = false;
self.encoder.reset();
self.mel_processor.reset();
self.boundary_fast_decode_until = None;
self.shared = SessionSharedState::default();
self.retry_state.clear_all();
}
pub fn reset(&mut self) {
self.is_active = true;
self.total_samples_fed = 0;
self.last_decode_time = None;
self.boundary_fast_decode_until = None;
self.has_new_encoder_content = false;
self.frozen_window_count = 0;
self.encoder.reset();
self.mel_processor.reset();
self.shared = SessionSharedState::default();
self.retry_state.clear_all();
}
/// Re-runs steps that failed on an earlier call, in pipeline order.
///
/// Order: pending stop mel flush, then pending stop encoder feed, then an owed decode pass,
/// then queued window finalizations. Finalizations run only when `resume_at` is unset and
/// finalization is enabled in the config.
///
/// Returns the produced events and whether a decode or finalization actually ran.
///
/// # Errors
/// Propagates the first failure. Obligations not yet discharged stay armed for the next call.
fn discharge_retry_obligation(&mut self) -> Result<(Vec<TranscriptionEvent>, bool)> {
let mut events: Vec<TranscriptionEvent> = Vec::new();
let mut ran_decode = false;
// NOTE(review): the flushed mel frames are discarded here; confirm that
// `discharge_stop_mel_flush` forwards them itself (e.g. by staging an encoder feed).
if self.retry_state.has_pending_stop_mel_flush() {
let _mel_opt = self
.retry_state
.discharge_stop_mel_flush(&mut self.mel_processor)?;
}
if self.retry_state.has_pending_stop_encoder_feed() {
let drain_window_count = self
.retry_state
.discharge_stop_encoder_feed(&mut self.encoder)?;
if drain_window_count > 0 || self.encoder.has_pending_frames() {
self.has_new_encoder_content = true;
}
}
// The flag is cleared only after a successful pass, so a failure keeps the decode owed.
if self.retry_state.has_decode_owed() {
self.has_new_encoder_content = false;
let decode_events = self.run_decode_pass()?;
self.retry_state.clear_decode_owed();
events.extend(decode_events);
ran_decode = true;
self.last_decode_time = Some(Instant::now());
}
// A set `resume_at` defers queued finalizations; with finalization disabled the queue is left
// as-is.
if !self.retry_state.finalize_queue().is_empty() && self.retry_state.resume_at().is_none() {
if self.config.finalize_completed_windows() {
let finalize_events = self.finalize_completed_windows()?;
events.extend(finalize_events);
ran_decode = true;
self.last_decode_time = Some(Instant::now());
}
}
Ok((events, ran_decode))
}
/// Settles the still-pending (incomplete) window at stop time and pushes the `ended` event.
///
/// * If `audio_features` holds at least one row, the pending window is decoded once more. The
///   confirmed tokens are passed as the decoder's prefix and the returned ids become the new
///   confirmed tokens. Provisional tokens are discarded in favour of that decode.
/// * Otherwise (no features, or an empty feature array) the current provisional tokens are
///   promoted to confirmed as-is, so the tail of the transcript is not dropped.
///
/// Either way the `ended` event carries `completed_text` joined with the confirmed text.
///
/// # Errors
/// Propagates decoder failures. In that case no event is pushed and the shared state is
/// unchanged.
fn finalize_partial_window_and_emit_ended(
    &mut self,
    audio_features: Option<Array>,
    events: &mut Vec<TranscriptionEvent>,
) -> Result<()> {
    // An empty feature array has nothing to re-decode. Treat it like "no features" so the
    // provisional tail is still promoted (previously it was silently lost).
    let audio_features =
        audio_features.filter(|features| features.shape().first().copied().unwrap_or(0) > 0);

    if let Some(audio_features) = audio_features {
        let confirmed_count = self.shared.confirmed_token_ids.len();
        // Token budget: the confirmed prefix plus a small allowance for the remaining tail,
        // capped by the per-pass limit.
        let estimated_tokens = self
            .config
            .max_tokens_per_pass()
            .min(confirmed_count.saturating_add(24));
        let token_ids = self.decoder.decode_all_tokens(
            &audio_features,
            &self.shared.confirmed_token_ids,
            &self.config,
            estimated_tokens,
        )?;
        self.shared.confirmed_token_ids = token_ids;
        self.shared.provisional_token_ids.clear();
        self.shared.provisional_first_seen.clear();
        self.shared.provisional_agreement_counts.clear();
        self.shared.confirmed_text = self.tokenizer.decode_ids(&self.shared.confirmed_token_ids);
    } else {
        if !self.shared.provisional_token_ids.is_empty() {
            let promoted = std::mem::take(&mut self.shared.provisional_token_ids);
            self.shared.confirmed_token_ids.extend(promoted);
            self.shared.provisional_first_seen.clear();
            self.shared.provisional_agreement_counts.clear();
        }
        if !self.shared.confirmed_token_ids.is_empty() {
            self.shared.confirmed_text =
                self.tokenizer.decode_ids(&self.shared.confirmed_token_ids);
        }
    }

    let final_text = concat_text(&self.shared.completed_text, &self.shared.confirmed_text);
    events.push(TranscriptionEvent::ended(final_text));
    Ok(())
}
/// Runs one incremental decode pass over the encoder's current state.
///
/// If window finalization is enabled, newly completed encoder windows are queued. If the
/// queue is then non-empty, they are finalized and that result is returned. When
/// finalization is disabled, completed windows are frozen into `completed_text` instead.
/// The still-pending audio is then encoded and decoded with the confirmed tokens as prefix,
/// and the output goes through `promote_tokens`.
///
/// Returns an empty list when there is no pending audio to decode.
///
/// # Errors
/// Propagates encoder, decoder and finalization failures.
fn run_decode_pass(&mut self) -> Result<Vec<TranscriptionEvent>> {
    // Rates implied by the local naming (`windowed_seconds`, estimated tokens per second).
    // NOTE(review): confirm against the encoder/decoder models in use.
    const AUDIO_TOKENS_PER_SECOND: f64 = 13.0;
    const TEXT_TOKENS_PER_SECOND: f64 = 10.0;

    if self.config.finalize_completed_windows() {
        let drained = self.encoder.drain_newly_encoded_windows();
        for window in drained {
            self.retry_state.enqueue_finalize(window);
        }
        if !self.retry_state.finalize_queue().is_empty() {
            return self.finalize_completed_windows();
        }
    } else {
        self.freeze_completed_windows();
    }

    let Some(audio_features) = self.encoder.encode_pending()? else {
        return Ok(Vec::new());
    };
    let num_audio_tokens = audio_features.shape().first().copied().unwrap_or(0);
    if num_audio_tokens == 0 {
        return Ok(Vec::new());
    }

    // Budget: scaled with the amount of pending audio (at least 24), never below the confirmed
    // prefix plus 24, and capped by the per-pass limit.
    let confirmed_count = self.shared.confirmed_token_ids.len();
    let windowed_seconds = num_audio_tokens as f64 / AUDIO_TOKENS_PER_SECOND;
    let estimated_total_tokens =
        ((windowed_seconds * TEXT_TOKENS_PER_SECOND).ceil() as usize).max(24);
    let max_tokens = self
        .config
        .max_tokens_per_pass()
        .min(estimated_total_tokens.max(confirmed_count.saturating_add(24)));

    let display_prefix = concat_text(&self.shared.completed_text, &self.shared.confirmed_text);
    // During the boundary boost window the stricter of the two agreement thresholds applies.
    let min_agreement_passes = if let Some(until) = self.boundary_fast_decode_until
        && Instant::now() < until
    {
        self.config
            .min_agreement_passes()
            .max(self.config.boundary_min_agreement_passes())
            .max(1)
    } else {
        self.config.min_agreement_passes().max(1)
    };

    // Snapshot taken before the fallible decode, so an error leaves the shared state intact.
    let params = DecodePassParams {
        audio_features: &audio_features,
        confirmed_token_ids: self.shared.confirmed_token_ids.clone(),
        display_prefix,
        prev_provisional: self.shared.provisional_token_ids.clone(),
        prev_first_seen: self.shared.provisional_first_seen.clone(),
        prev_agreement_counts: self.shared.provisional_agreement_counts.clone(),
        min_agreement_passes,
    };

    let start = Instant::now();
    let all_token_ids = self.decoder.decode_all_tokens(
        params.audio_features,
        &params.confirmed_token_ids,
        &self.config,
        max_tokens,
    )?;
    let decode_time = start.elapsed().as_secs_f64();
    Ok(self.promote_tokens(&all_token_ids, &params, decode_time))
}
/// Turns one decoder pass into events by comparing it with the previous provisional tail.
///
/// Tokens past the confirmed prefix form the new provisional hypothesis. The longest common
/// prefix with the previous hypothesis keeps its first-seen times and gains one agreement;
/// later tokens start fresh. Tokens are promoted to confirmed as a leading run: each must
/// have been visible for at least the configured delay and agreed on in at least
/// `min_agreement_passes` passes (minimum 1).
///
/// Emits, in order:
/// 1. a `confirmed` event (only when tokens were promoted);
/// 2. a `display_update` with the confirmed prefix and the provisional tail;
/// 3. a `Stats` event.
fn promote_tokens(
    &mut self,
    all_token_ids: &[u32],
    params: &DecodePassParams<'_>,
    decode_time: f64,
) -> Vec<TranscriptionEvent> {
    let confirmed_count = params.confirmed_token_ids.len();
    let mut new_provisional: Vec<u32> =
        all_token_ids.iter().skip(confirmed_count).copied().collect();
    let gen_token_count = all_token_ids.len();
    let now = Instant::now();
    let delay =
        std::time::Duration::from_millis(u64::from(self.config.delay_preset().delay_ms()));

    // Length of the prefix on which this pass agrees with the previous provisional tail.
    let match_len = params
        .prev_provisional
        .iter()
        .zip(&new_provisional)
        .take_while(|&(prev, new)| prev == new)
        .count();

    // Agreed tokens inherit their first-seen time and gain one agreement. Missing history
    // defaults to "seen now, agreed once". All other tokens start a fresh history.
    let (mut next_first_seen, mut next_agreement_counts): (Vec<Instant>, Vec<usize>) =
        (0..new_provisional.len())
            .map(|i| {
                if i < match_len {
                    let seen = params.prev_first_seen.get(i).copied().unwrap_or(now);
                    let agreed = params.prev_agreement_counts.get(i).copied().unwrap_or(1);
                    (seen, agreed.saturating_add(1))
                } else {
                    (now, 1)
                }
            })
            .unzip();

    let required_agreement_passes = params.min_agreement_passes.max(1);
    let promotion_count = next_first_seen
        .iter()
        .zip(&next_agreement_counts)
        .take_while(|&(&seen, &agreed)| {
            now.duration_since(seen) >= delay && agreed >= required_agreement_passes
        })
        .count();

    // Split every per-token vector at the promotion point. The heads are promoted; the tails
    // stay provisional.
    let final_provisional = new_provisional.split_off(promotion_count);
    let final_first_seen = next_first_seen.split_off(promotion_count);
    let final_agreement_counts = next_agreement_counts.split_off(promotion_count);
    let promoted = new_provisional;

    let mut events: Vec<TranscriptionEvent> = Vec::new();
    if !promoted.is_empty() {
        self.shared.confirmed_token_ids.extend(promoted);
        self.shared.confirmed_text = self.tokenizer.decode_ids(&self.shared.confirmed_token_ids);
        events.push(TranscriptionEvent::confirmed(concat_text(
            &self.shared.completed_text,
            &self.shared.confirmed_text,
        )));
    }

    let final_prov_text = self.tokenizer.decode_ids(&final_provisional);
    self.shared.provisional_token_ids = final_provisional;
    self.shared.provisional_first_seen = final_first_seen;
    self.shared.provisional_agreement_counts = final_agreement_counts;

    let display_prefix = concat_text(&self.shared.completed_text, &self.shared.confirmed_text);
    events.push(TranscriptionEvent::display_update(
        display_prefix,
        final_prov_text,
    ));
    // `params.display_prefix` predates promotion, so it is stale here; the prefix is recomputed
    // above. This no-op read presumably keeps the field from being reported as unread.
    let _ = params.display_prefix;

    // NOTE(review): assumes 16 kHz input, although `new()` accepts an arbitrary `sample_rate`.
    let total_audio_seconds = self.total_samples_fed as f64 / 16_000.0;
    let tps = if decode_time > 0.0 {
        gen_token_count as f64 / decode_time
    } else {
        0.0
    };
    events.push(TranscriptionEvent::Stats(StreamingStats {
        encoded_window_count: self.encoder.encoded_window_count(),
        total_audio_seconds,
        tokens_per_second: tps,
        real_time_factor: 0.0,
        peak_memory_gb: peak_memory_gb_or_zero(),
    }));
    events
}
/// Decodes every queued completed window on its own and appends its text to `completed_text`.
///
/// Windows are processed front to back. Each window is decoded with an empty prefix. If that
/// decode yields only whitespace (or the window has no rows), the text of the current
/// confirmed+provisional stream tokens is used as a fallback. After each window the
/// per-window token state is cleared, the window is popped from the queue and
/// `frozen_window_count` is incremented. Only a `Stats` event is returned (none if the queue
/// was empty).
///
/// # Errors
/// Propagates decoder failures. The failing window stays at the front of the queue for a
/// retry; earlier windows in the same call have already been committed.
fn finalize_completed_windows(&mut self) -> Result<Vec<TranscriptionEvent>> {
if self.retry_state.finalize_queue().is_empty() {
return Ok(Vec::new());
}
let mut total_decode_time: f64 = 0.0;
let mut total_generated_tokens: usize = 0;
let mut events: Vec<TranscriptionEvent> = Vec::new();
while let Some(pending) = self.retry_state.finalize_queue_mut().front_mut() {
// The fallback is offered at most once per window, and it is marked consumed BEFORE the fallible
// decode below. NOTE(review): a retry after a decode error therefore has no fallback text.
let candidate_fallback = if !pending.fallback_consumed {
pending.fallback_consumed = true;
let mut stream_tokens: Vec<u32> = self.shared.confirmed_token_ids.clone();
stream_tokens.extend(self.shared.provisional_token_ids.iter().copied());
if stream_tokens.is_empty() {
None
} else {
Some(self.tokenizer.decode_ids(&stream_tokens))
}
} else {
None
};
let num_audio_tokens = pending.encoder_output.shape().first().copied().unwrap_or(0);
let selected_window_text = if num_audio_tokens == 0 {
candidate_fallback.unwrap_or_default()
} else {
let start = Instant::now();
// Each completed window is decoded independently (empty prefix).
let token_ids = self.decoder.decode_all_tokens(
&pending.encoder_output,
&[],
&self.config,
self.config.max_tokens_per_pass(),
)?;
let decode_time = start.elapsed().as_secs_f64();
total_decode_time += decode_time;
total_generated_tokens = total_generated_tokens.saturating_add(token_ids.len());
let full_text = self.tokenizer.decode_ids(&token_ids);
if full_text.trim().is_empty()
&& let Some(fallback) = candidate_fallback
{
fallback
} else {
full_text
}
};
if !selected_window_text.trim().is_empty() {
append_text(&selected_window_text, &mut self.shared.completed_text);
}
// The window's text is now in `completed_text`; start the next window with empty token state.
self.shared.confirmed_token_ids.clear();
self.shared.provisional_token_ids.clear();
self.shared.provisional_first_seen.clear();
self.shared.provisional_agreement_counts.clear();
self.shared.confirmed_text.clear();
self.retry_state.finalize_queue_mut().pop_front();
self.frozen_window_count = self.frozen_window_count.saturating_add(1);
}
// NOTE(review): assumes 16 kHz input, although `new()` accepts an arbitrary `sample_rate`.
let total_audio_seconds = self.total_samples_fed as f64 / 16_000.0;
let tps = if total_decode_time > 0.0 {
total_generated_tokens as f64 / total_decode_time
} else {
0.0
};
events.push(TranscriptionEvent::Stats(StreamingStats {
encoded_window_count: self.encoder.encoded_window_count(),
total_audio_seconds,
tokens_per_second: tps,
real_time_factor: 0.0,
peak_memory_gb: peak_memory_gb_or_zero(),
}));
Ok(events)
}
/// Folds the current stream tokens into `completed_text` whenever the encoder has completed
/// windows that were not yet frozen.
///
/// Used when per-window finalization is disabled. The confirmed and provisional tokens are
/// decoded together and appended, the per-window token state is cleared, and
/// `frozen_window_count` catches up with the encoder's window count.
fn freeze_completed_windows(&mut self) {
    let encoded = self.encoder.encoded_window_count();
    if encoded > self.frozen_window_count {
        let stream: Vec<u32> = self
            .shared
            .confirmed_token_ids
            .iter()
            .chain(&self.shared.provisional_token_ids)
            .copied()
            .collect();
        if !stream.is_empty() {
            let text = self.tokenizer.decode_ids(&stream);
            append_text(&text, &mut self.shared.completed_text);
        }

        let shared = &mut self.shared;
        shared.confirmed_token_ids.clear();
        shared.provisional_token_ids.clear();
        shared.provisional_first_seen.clear();
        shared.provisional_agreement_counts.clear();
        shared.confirmed_text.clear();

        self.frozen_window_count = encoded;
    }
}
}
/// Appends `segment`, trimmed, to `base`, inserting a single space separator when needed.
///
/// Leading and trailing whitespace of `segment` is discarded, and whitespace-only segments
/// are ignored. A space is inserted only when `base` is non-empty and does not already end
/// in whitespace, so repeated appends never produce doubled separators.
fn append_text(segment: &str, base: &mut String) {
    let trimmed = segment.trim();
    if trimmed.is_empty() {
        return;
    }
    // `trimmed` can never start with whitespace, so only the end of `base` decides whether a
    // separator is required.
    if !base.is_empty() && !base.ends_with(char::is_whitespace) {
        base.push(' ');
    }
    base.push_str(trimmed);
}
/// Returns `a` followed by `b`, where `b` is joined using `append_text` rules.
///
/// `b` is trimmed, skipped when blank, and separated from `a` by one space when needed.
/// `a` is copied verbatim.
fn concat_text(a: &str, b: &str) -> String {
    let mut joined = String::with_capacity(a.len() + b.len() + 1);
    joined += a;
    append_text(b, &mut joined);
    joined
}
/// Deep-copies the optional partial-window features so that a copy can stay armed for retry.
///
/// `None` maps to `Ok(None)`. A clone failure is wrapped in a `LayerKeyed` error that
/// identifies the stop partial-decode step.
fn clone_partial_decode_payload(features: Option<&Array>) -> Result<Option<Array>> {
    features
        .map(|array| {
            array.try_clone().map_err(|e| {
                crate::Error::LayerKeyed(LayerKeyedPayload::new(
                    "StopPartialDecode: failed to clone audio_features for retry",
                    e,
                ))
            })
        })
        .transpose()
}
/// Peak memory reported by `crate::memory::peak_memory`, in gigabytes (1e9 bytes), or `0.0`
/// when no value is available.
fn peak_memory_gb_or_zero() -> f64 {
    crate::memory::peak_memory().map_or(0.0, |bytes| bytes as f64 / 1e9)
}
#[cfg(test)]
mod tests;