use serde::Serialize;
use std::sync::Arc;
use crate::error::PhosttError;
use super::engine::Engine;
use super::features;
use super::{CONTEXT_SIZE, N_MELS, SessionTriplet, TARGET_SAMPLE_RATE};
use kaldi_native_fbank::fbank::FbankComputer;
use kaldi_native_fbank::online::{FeatureComputer, OnlineFeature};
/// Current wall-clock time as fractional seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, returns `0.0`
/// instead of failing.
pub fn now_timestamp() -> f64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs_f64())
        .unwrap_or(0.0)
}
/// Tuning knobs for the overlapping-window streaming decoder.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Length of each inference window, in feature frames.
    pub window_frames: usize,
    /// Feature frames shared between consecutive windows.
    pub overlap_frames: usize,
    /// Similarity threshold passed to `delta_words` when merging words
    /// from overlapping windows; validated to lie in [0.0, 1.0].
    /// NOTE(review): presumably 1.0 means exact-match only — confirm
    /// against `delta_words`.
    pub fuzzy_match_threshold: f32,
}
impl Default for StreamingConfig {
fn default() -> Self {
Self {
window_frames: 400,
overlap_frames: 100,
fuzzy_match_threshold: 1.0, }
}
}
impl StreamingConfig {
    /// Number of feature frames the window advances per inference step
    /// (window minus overlap; saturates at zero).
    pub fn shift_frames(&self) -> usize {
        self.window_frames.saturating_sub(self.overlap_frames)
    }

    /// Window advance expressed in encoder frames (the encoder
    /// subsamples features by a factor of 4).
    pub fn shift_encoder_frames(&self) -> usize {
        self.shift_frames() / 4
    }

    /// Checks the configuration invariants, returning a human-readable
    /// message describing the first violation found.
    pub fn validate(&self) -> Result<(), String> {
        if self.window_frames == 0 {
            Err("streaming window must be > 0 frames".to_string())
        } else if self.overlap_frames >= self.window_frames {
            Err("streaming overlap must be smaller than window".to_string())
        } else if self.window_frames % 4 != 0 {
            Err("streaming window must be a multiple of 4 (encoder subsampling)".to_string())
        } else if self.overlap_frames % 4 != 0 {
            Err("streaming overlap must be a multiple of 4 (encoder subsampling)".to_string())
        } else if !(0.0..=1.0).contains(&self.fuzzy_match_threshold) {
            Err("fuzzy_match_threshold must be in [0.0, 1.0]".to_string())
        } else {
            Ok(())
        }
    }
}
/// Rolling decoder context carried across inference windows.
#[non_exhaustive]
pub struct DecoderState {
    /// Fixed-length history of the most recent token ids
    /// (CONTEXT_SIZE entries, initialized to the blank id).
    pub tokens: Vec<i64>,
    /// Id of the blank symbol in the tokenizer vocabulary.
    pub blank_id: usize,
    /// Count of consecutive blank emissions.
    /// NOTE(review): appears to feed endpoint detection (it is zeroed
    /// when an endpoint fires in `process_chunk_overlap`) — confirm.
    pub consecutive_blanks: usize,
}
impl StreamingState {
    /// Clears per-utterance decoding state so the next utterance starts
    /// from a fresh decoder context and empty transcript.
    ///
    /// NOTE(review): `frames_seen` and the online feature extractor are
    /// deliberately left untouched here — the feature stream keeps
    /// running across utterances. Confirm this is intended.
    pub fn reset_utterance_state(&mut self) {
        self.total_frames = 0;
        self.feature_window.clear();
        self.prev_window_words.clear();
        self.accumulated_text = Arc::new(String::new());
        self.accumulated_words = Arc::new(Vec::new());
        self.decoder = DecoderState::new(self.blank_id);
    }
}
impl DecoderState {
pub fn new(blank_id: usize) -> Self {
Self {
tokens: vec![blank_id as i64; CONTEXT_SIZE],
blank_id,
consecutive_blanks: 0,
}
}
pub fn push_token(&mut self, token: i64) {
self.tokens.rotate_left(1);
let last = self.tokens.last_mut().expect("CONTEXT_SIZE > 0");
*last = token;
}
}
/// A single recognized word with timing, confidence, and optional
/// speaker attribution.
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct WordInfo {
    /// The word text.
    pub word: String,
    /// Word start time. NOTE(review): presumably seconds from the start
    /// of the stream — confirm against `run_inference`.
    pub start: f64,
    /// Word end time (same unit as `start`).
    pub end: f64,
    /// Recognition confidence score.
    pub confidence: f32,
    /// Speaker id assigned by the diarizer; omitted from JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speaker: Option<u32>,
}
/// Per-session diarization state (only on builds with the
/// `diarization` feature).
#[cfg(feature = "diarization")]
pub struct DiarizationStreamState {
    /// Online diarizer fed with raw samples alongside ASR inference.
    pub diarizer: polyvoice::OnlineDiarizer,
}
/// All mutable per-session state for streaming recognition.
#[non_exhaustive]
pub struct StreamingState {
    /// Rolling decoder token context.
    pub decoder: DecoderState,
    /// Online FBANK feature extractor fed with raw samples.
    pub online: OnlineFeature,
    /// Number of feature frames already pulled out of `online`.
    pub frames_seen: usize,
    /// Utterance transcript accumulated so far (Arc so emitted segments
    /// can share it without copying).
    pub accumulated_text: Arc<String>,
    /// Words accumulated for the current utterance.
    pub accumulated_words: Arc<Vec<WordInfo>>,
    /// Encoder-frame offset of the current window within the stream.
    pub total_frames: usize,
    /// Flat buffer of pending features (N_MELS values per frame).
    pub feature_window: Vec<f32>,
    /// Words from the previous inference window, used to deduplicate
    /// overlapping output via `delta_words`.
    pub prev_window_words: Vec<WordInfo>,
    /// Streaming tuning copied from the engine at state creation.
    pub config: StreamingConfig,
    /// Blank token id cached from the tokenizer.
    pub blank_id: usize,
    /// VAD components; all Some when VAD is enabled, all None otherwise.
    pub vad_session: Option<silero::Session>,
    pub vad_stream_state: Option<silero::StreamState>,
    pub vad_segmenter: Option<silero::SpeechSegmenter>,
    /// Raw samples buffered while VAD decides on segment boundaries.
    pub vad_audio_buffer: Vec<f32>,
    /// Absolute sample index of `vad_audio_buffer[0]` within the stream.
    pub vad_sample_offset: u64,
    /// Completed speech segments queued for ASR.
    /// NOTE(review): filled in `process_chunk_vad` but not drained in
    /// this file — presumably consumed elsewhere; confirm.
    pub vad_pending_asr: Vec<Vec<f32>>,
    /// Optional diarization state (feature-gated builds only).
    #[cfg(feature = "diarization")]
    pub diarization_state: Option<DiarizationStreamState>,
}
impl Engine {
    /// Builds a fresh [`StreamingState`] for one streaming session.
    ///
    /// Initializes the online FBANK extractor, the optional silero VAD
    /// pipeline (when `self.vad_enabled`), and — on builds with the
    /// `diarization` feature — an online diarizer when
    /// `diarization_enabled` is set and a speaker encoder is loaded.
    /// On builds without the feature, `diarization_enabled = true` is
    /// ignored with a warning.
    ///
    /// # Errors
    /// [`PhosttError::Inference`] if FBANK initialization fails;
    /// [`PhosttError::ModelLoad`] if the bundled VAD model cannot load.
    pub fn create_state(&self, diarization_enabled: bool) -> Result<StreamingState, PhosttError> {
        #[cfg(feature = "diarization")]
        let diarization_state = if diarization_enabled && self.speaker_encoder.is_some() {
            Some(DiarizationStreamState {
                // Fixed diarization tuning; window_secs == hop_secs, so
                // embedding windows are back-to-back, not overlapping.
                diarizer: polyvoice::OnlineDiarizer::new(polyvoice::DiarizationConfig {
                    window_secs: 1.5,
                    hop_secs: 1.5, threshold: 0.5,
                    max_speakers: 64,
                    min_speech_secs: 0.25,
                    max_gap_secs: 0.5,
                    sample_rate: polyvoice::SampleRate::new(16000).expect("valid sample rate"),
                }),
            })
        } else {
            None
        };
        #[cfg(not(feature = "diarization"))]
        if diarization_enabled {
            tracing::warn!(
                "diarization_enabled=true ignored: build lacks the `diarization` feature"
            );
        }
        let computer = FbankComputer::new(features::phostt_fbank_options())
            .map_err(|e| PhosttError::Inference(format!("FBANK init failed: {e}")))?;
        let online = OnlineFeature::new(FeatureComputer::Fbank(computer));
        // VAD components are created together or not at all; the rest of
        // this impl treats `vad_session.is_some()` as "VAD enabled".
        let (vad_session, vad_stream_state, vad_segmenter) = if self.vad_enabled {
            let session = silero::Session::bundled()
                .map_err(|e| PhosttError::ModelLoad(format!("silero VAD load failed: {e}")))?;
            let stream = silero::StreamState::new(silero::SampleRate::Rate16k);
            let segmenter = silero::SpeechSegmenter::new(silero::SpeechOptions::default());
            (Some(session), Some(stream), Some(segmenter))
        } else {
            (None, None, None)
        };
        let blank_id = self.tokenizer.blank_id();
        Ok(StreamingState {
            decoder: DecoderState::new(blank_id),
            online,
            frames_seen: 0,
            accumulated_text: Arc::new(String::new()),
            accumulated_words: Arc::new(Vec::new()),
            total_frames: 0,
            feature_window: Vec::new(),
            prev_window_words: Vec::new(),
            config: self.streaming_config.clone(),
            blank_id,
            vad_session,
            vad_stream_state,
            vad_segmenter,
            vad_audio_buffer: Vec::new(),
            vad_sample_offset: 0,
            vad_pending_asr: Vec::new(),
            #[cfg(feature = "diarization")]
            diarization_state,
        })
    }

    /// Feeds a chunk of audio samples (at TARGET_SAMPLE_RATE) into the
    /// session and returns any transcript segments produced.
    ///
    /// Empty chunks are a no-op. Dispatches to the VAD-gated path when
    /// VAD was enabled at state creation, otherwise to the plain
    /// overlapping-window path.
    pub fn process_chunk(
        &self,
        samples: &[f32],
        state: &mut StreamingState,
        triplet: &mut SessionTriplet,
    ) -> Result<Vec<TranscriptSegment>, PhosttError> {
        if samples.is_empty() {
            return Ok(vec![]);
        }
        if state.vad_session.is_some() {
            return self.process_chunk_vad(samples, state, triplet);
        }
        self.process_chunk_overlap(samples, state, triplet)
    }

    /// Overlapping-window streaming path: buffers FBANK features and
    /// runs inference each time a full `config.window_frames` window is
    /// available, sliding by `config.shift_frames()` afterwards. Words
    /// repeated across overlapping windows are deduplicated with
    /// `delta_words`. Returns at most one segment per call: final when
    /// the decoder signaled an endpoint, partial otherwise.
    fn process_chunk_overlap(
        &self,
        samples: &[f32],
        state: &mut StreamingState,
        triplet: &mut SessionTriplet,
    ) -> Result<Vec<TranscriptSegment>, PhosttError> {
        state
            .online
            .accept_waveform(TARGET_SAMPLE_RATE as f32, samples);
        // Pull only the feature frames produced since the last call.
        let ready = state.online.num_frames_ready();
        let new_frames = ready.saturating_sub(state.frames_seen);
        if new_frames == 0 {
            return Ok(vec![]);
        }
        let new_features =
            features::extract_online_frames(&state.online, state.frames_seen, new_frames);
        state.frames_seen = ready;
        state.feature_window.extend_from_slice(&new_features);
        let mut emitted_words: Vec<WordInfo> = Vec::new();
        let mut endpoint = false;
        // Consume every complete window currently buffered. The buffer
        // holds N_MELS values per frame, hence the division.
        while state.feature_window.len() / N_MELS >= state.config.window_frames {
            let num_frames = state.config.window_frames;
            let features = &state.feature_window[..num_frames * N_MELS];
            let frame_offset = state.total_frames;
            let (window_words, window_endpoint, _enc_len) = self
                .run_inference(
                    triplet,
                    features,
                    num_frames,
                    &mut state.decoder,
                    frame_offset,
                )
                .map_err(|e| PhosttError::Inference(format!("{e:#}")))?;
            // Keep only words not already emitted by the previous
            // (overlapping) window.
            let delta = super::delta_words(
                &window_words,
                &state.prev_window_words,
                state.config.fuzzy_match_threshold,
            );
            emitted_words.extend(delta);
            state.prev_window_words = window_words;
            // Slide the window: drop shift_frames worth of features and
            // advance the encoder-frame offset accordingly.
            let shift = state.config.shift_frames() * N_MELS;
            state.feature_window.drain(..shift);
            state.total_frames += state.config.shift_encoder_frames();
            if window_endpoint {
                endpoint = true;
                break;
            }
        }
        // Feed the raw samples to the diarizer (best-effort) and stamp
        // every word emitted this call with the current speaker.
        #[cfg(feature = "diarization")]
        if let (Some(dia), Some(enc)) = (
            state.diarization_state.as_mut(),
            self.speaker_encoder.as_ref(),
        ) {
            if let Err(e) = dia.diarizer.feed(samples, enc) {
                tracing::warn!("Diarizer feed failed: {e:#}");
            }
            let speaker_id = dia.diarizer.current_speaker().map(|s| s.0);
            for w in &mut emitted_words {
                w.speaker = speaker_id;
            }
        }
        // An endpoint with no new words still emits a (final) segment.
        if emitted_words.is_empty() && !endpoint {
            return Ok(vec![]);
        }
        // Copy-on-write: clone the accumulators only if a previously
        // emitted segment still holds a reference.
        let acc_text = Arc::make_mut(&mut state.accumulated_text);
        let acc_words = Arc::make_mut(&mut state.accumulated_words);
        for w in &emitted_words {
            if !acc_text.is_empty() {
                acc_text.push(' ');
            }
            acc_text.push_str(&w.word);
        }
        acc_words.extend(emitted_words);
        let text = Arc::clone(&state.accumulated_text);
        let words = Arc::clone(&state.accumulated_words);
        let ts = now_timestamp();
        if endpoint {
            // Utterance finished: reset the transcript accumulators.
            // NOTE(review): decoder token history is kept across the
            // endpoint (only consecutive_blanks is zeroed) — confirm
            // this is intended.
            state.accumulated_text = Arc::new(String::new());
            state.accumulated_words = Arc::new(Vec::new());
            state.decoder.consecutive_blanks = 0;
            state.prev_window_words.clear();
            Ok(vec![TranscriptSegment {
                text,
                words,
                is_final: true,
                timestamp: ts,
            }])
        } else {
            Ok(vec![TranscriptSegment {
                text,
                words,
                is_final: false,
                timestamp: ts,
            }])
        }
    }

    /// VAD-gated streaming path: buffers raw audio, runs silero over the
    /// chunk, queues any completed speech segments for ASR, and — while
    /// speech is active — also runs the overlap path to produce partial
    /// segments.
    ///
    /// NOTE(review): completed segments are pushed to
    /// `state.vad_pending_asr` but not transcribed here — presumably a
    /// caller drains that queue; confirm.
    fn process_chunk_vad(
        &self,
        samples: &[f32],
        state: &mut StreamingState,
        triplet: &mut SessionTriplet,
    ) -> Result<Vec<TranscriptSegment>, PhosttError> {
        state.vad_audio_buffer.extend_from_slice(samples);
        // Run VAD over this chunk, collecting any segments the
        // segmenter closes, plus whether speech is currently active.
        let (speech_segments, is_active) = {
            // Invariant: all three components are Some when
            // vad_session is Some (see create_state).
            let session = state.vad_session.as_mut().unwrap();
            let stream = state.vad_stream_state.as_mut().unwrap();
            let segmenter = state.vad_segmenter.as_mut().unwrap();
            let mut segments: Vec<silero::SpeechSegment> = Vec::new();
            session
                .process_stream(stream, samples, |probability| {
                    if let Some(segment) = segmenter.push_probability(probability) {
                        segments.push(segment);
                    }
                })
                .map_err(|e| PhosttError::Inference(format!("VAD inference failed: {e}")))?;
            let active = segmenter.is_active();
            (segments, active)
        };
        let mut emitted_segments: Vec<TranscriptSegment> = Vec::new();
        // Segment sample indices are absolute stream positions; convert
        // to offsets into the local buffer.
        let buffer_start = state.vad_sample_offset;
        for segment in &speech_segments {
            let buf_start = segment.start_sample().saturating_sub(buffer_start) as usize;
            let buf_end = segment.end_sample().saturating_sub(buffer_start) as usize;
            if buf_end > state.vad_audio_buffer.len() {
                tracing::warn!("VAD segment extends beyond audio buffer, skipping");
                continue;
            }
            let speech_samples = &state.vad_audio_buffer[buf_start..buf_end];
            if speech_samples.is_empty() {
                continue;
            }
            // Queue the utterance audio for ASR and start the next
            // utterance from a clean decoding state.
            state.vad_pending_asr.push(speech_samples.to_vec());
            state.reset_utterance_state();
        }
        // Drop buffered audio up to the end of the last closed segment
        // and advance the absolute offset to match.
        if let Some(last_seg) = speech_segments.last() {
            let remove_up_to = (last_seg.end_sample().saturating_sub(buffer_start)) as usize;
            if remove_up_to <= state.vad_audio_buffer.len() {
                state.vad_audio_buffer.drain(..remove_up_to);
                state.vad_sample_offset += remove_up_to as u64;
            }
        }
        // While speech is ongoing, produce low-latency partials via the
        // overlap path on the raw chunk.
        if is_active {
            let partials = self.process_chunk_overlap(samples, state, triplet)?;
            emitted_segments.extend(partials);
        }
        Ok(emitted_segments)
    }

    /// Flushes the VAD pipeline at end of stream: pushes the final
    /// probability into the segmenter, force-closes any open segment,
    /// and transcribes whichever segment audio is still buffered.
    /// Returns the final transcript segment, if any.
    ///
    /// NOTE(review): the two branches below (flush_stream result vs
    /// segmenter.finish) are near-duplicates and could share a helper.
    fn flush_state_vad(
        &self,
        state: &mut StreamingState,
        triplet: &mut SessionTriplet,
    ) -> Option<TranscriptSegment> {
        let session = state.vad_session.as_mut()?;
        let stream = state.vad_stream_state.as_mut()?;
        let segmenter = state.vad_segmenter.as_mut()?;
        // Final partial-frame probability may close a segment.
        if let Ok(Some(probability)) = session.flush_stream(stream)
            && let Some(segment) = segmenter.push_probability(probability)
        {
            let buffer_start = state.vad_sample_offset;
            let buf_start = segment.start_sample().saturating_sub(buffer_start) as usize;
            // Clamp to the buffer in case the segment end runs past it.
            let buf_end = (segment.end_sample().saturating_sub(buffer_start) as usize)
                .min(state.vad_audio_buffer.len());
            if buf_start < buf_end {
                let speech_samples = &state.vad_audio_buffer[buf_start..buf_end];
                // Best-effort: transcription errors fall through to the
                // finish() path below rather than aborting the flush.
                if let Ok(result) = self.transcribe_samples(speech_samples, triplet)
                    && !result.text.is_empty()
                {
                    state.reset_utterance_state();
                    return Some(TranscriptSegment {
                        text: Arc::new(result.text),
                        words: Arc::new(result.words),
                        is_final: true,
                        timestamp: now_timestamp(),
                    });
                }
            }
        }
        // Force-close a still-open segment, if the segmenter has one.
        if let Some(segment) = segmenter.finish() {
            let buffer_start = state.vad_sample_offset;
            let buf_start = segment.start_sample().saturating_sub(buffer_start) as usize;
            let buf_end = (segment.end_sample().saturating_sub(buffer_start) as usize)
                .min(state.vad_audio_buffer.len());
            if buf_start < buf_end {
                let speech_samples = &state.vad_audio_buffer[buf_start..buf_end];
                if let Ok(result) = self.transcribe_samples(speech_samples, triplet)
                    && !result.text.is_empty()
                {
                    state.reset_utterance_state();
                    return Some(TranscriptSegment {
                        text: Arc::new(result.text),
                        words: Arc::new(result.words),
                        is_final: true,
                        timestamp: now_timestamp(),
                    });
                }
            }
        }
        None
    }

    /// Flushes the session at end of stream and returns the final
    /// transcript segment, if any text was produced.
    ///
    /// On the non-VAD path this drains the feature extractor, runs one
    /// last inference over the remaining (possibly short) window, and
    /// emits the accumulated transcript as a final segment.
    pub fn flush_state(
        &self,
        state: &mut StreamingState,
        triplet: &mut SessionTriplet,
    ) -> Option<TranscriptSegment> {
        if state.vad_session.is_some() {
            return self.flush_state_vad(state, triplet);
        }
        // Tell the extractor no more audio is coming so it emits any
        // tail frames it was holding back.
        state.online.input_finished();
        let ready = state.online.num_frames_ready();
        let new_frames = ready.saturating_sub(state.frames_seen);
        if new_frames > 0 {
            let new_features =
                features::extract_online_frames(&state.online, state.frames_seen, new_frames);
            state.feature_window.extend_from_slice(&new_features);
            state.frames_seen = ready;
        }
        if !state.feature_window.is_empty() {
            // Final window may be shorter than config.window_frames.
            let num_frames = state.feature_window.len() / N_MELS;
            let features = &state.feature_window[..];
            let frame_offset = state.total_frames;
            let (window_words, _endpoint, _enc_len) = match self.run_inference(
                triplet,
                features,
                num_frames,
                &mut state.decoder,
                frame_offset,
            ) {
                Ok(r) => r,
                Err(e) => {
                    // Log and fall through: emit whatever was already
                    // accumulated rather than losing it.
                    tracing::error!("flush_state inference failed: {e:#}");
                    return None;
                }
            };
            let delta = super::delta_words(
                &window_words,
                &state.prev_window_words,
                state.config.fuzzy_match_threshold,
            );
            let acc_text = Arc::make_mut(&mut state.accumulated_text);
            let acc_words = Arc::make_mut(&mut state.accumulated_words);
            for w in &delta {
                if !acc_text.is_empty() {
                    acc_text.push(' ');
                }
                acc_text.push_str(&w.word);
            }
            acc_words.extend(delta);
            state.prev_window_words = window_words;
            state.feature_window.clear();
            // Encoder subsamples by 4, matching shift_encoder_frames.
            state.total_frames += num_frames / 4;
        }
        if state.accumulated_text.is_empty() {
            return None;
        }
        // NOTE(review): accumulators are not cleared after emitting the
        // final segment — fine if the state is discarded after flush;
        // confirm callers do not reuse it.
        let seg = TranscriptSegment {
            text: Arc::clone(&state.accumulated_text),
            words: Arc::clone(&state.accumulated_words),
            is_final: true,
            timestamp: now_timestamp(),
        };
        Some(seg)
    }
}
/// Result of a one-shot (non-streaming) transcription.
#[derive(Debug, Clone, Serialize)]
pub struct TranscribeResult {
    /// Full transcript text.
    pub text: String,
    /// Per-word details for the transcript.
    pub words: Vec<WordInfo>,
    /// Duration of the transcribed audio, in seconds.
    pub duration_s: f64,
}
/// A transcript update emitted by the streaming engine.
///
/// `text`/`words` are Arc-shared with the session's accumulators so
/// emitting a partial segment does not copy the transcript.
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct TranscriptSegment {
    /// Accumulated transcript text for the current utterance.
    #[cfg_attr(feature = "openapi", schema(value_type = String))]
    pub text: Arc<String>,
    /// Accumulated word details for the current utterance.
    #[cfg_attr(feature = "openapi", schema(value_type = Vec<WordInfo>))]
    pub words: Arc<Vec<WordInfo>>,
    /// True when this segment closes an utterance (endpoint or flush).
    pub is_final: bool,
    /// Wall-clock emission time (Unix seconds, see `now_timestamp`).
    pub timestamp: f64,
}