pub mod audio;
mod decode;
mod features;
mod tokenizer;
#[cfg(feature = "diarization")]
pub mod diarization;
#[cfg(all(feature = "coreml", feature = "cuda"))]
compile_error!("Features `coreml` and `cuda` are mutually exclusive. Choose one.");
use anyhow::Context;
#[cfg(any(feature = "coreml", feature = "cuda"))]
use ort::ep;
use ort::session::Session;
use ort::value::TensorRef;
use serde::Serialize;
use std::ops::{Deref, DerefMut};
use std::path::Path;
use std::sync::Arc;
use crate::error::GigasttError;
use features::MelSpectrogram;
use kaldi_native_fbank::fbank::FbankComputer;
use kaldi_native_fbank::online::{FeatureComputer, OnlineFeature};
use tokenizer::Tokenizer;
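/// Number of mel filterbank bins per feature frame.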
pub const N_MELS: usize = 80;
pub const N_FFT: usize = 400;
pub const HOP_LENGTH: usize = 160;
pub const ENCODER_OUT_DIM: usize = 512;
pub const DECODER_OUT_DIM: usize = 512;
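/// Number of previous tokens fed to the transducer decoder as context.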
pub const CONTEXT_SIZE: usize = 2;
fn ort_err(e: impl std::fmt::Display) -> anyhow::Error {
anyhow::anyhow!("{e}")
}
pub(crate) fn now_timestamp() -> f64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs_f64()
}
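/// Duration of one encoder output frame in seconds: the 160-sample feature hop
/// times the encoder's 4x temporal subsampling, at 16 kHz (0.04 s).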
const SECONDS_PER_FRAME: f64 = (HOP_LENGTH as f64) * 4.0 / 16000.0;
const DEFAULT_POOL_SIZE: usize = 4;
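/// Configuration for the overlapping-window streaming decoder.
///
/// `window_frames` and `overlap_frames` are counted in feature frames (10 ms
/// each at 16 kHz) and must both be multiples of 4, matching the encoder's
/// temporal subsampling. `fuzzy_match_threshold` controls how words are
/// de-duplicated across overlapping windows; `1.0` requires exact matches.
///
/// Illustrative sketch (marked `ignore`, so it is not compiled as a doc-test):
///
/// ```ignore
/// let cfg = StreamingConfig { window_frames: 200, overlap_frames: 40, fuzzy_match_threshold: 0.8 };
/// assert!(cfg.validate().is_ok());
/// assert_eq!(cfg.shift_frames(), 160);
/// ```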
#[derive(Debug, Clone)]
pub struct StreamingConfig {
pub window_frames: usize,
pub overlap_frames: usize,
pub fuzzy_match_threshold: f32,
}
impl Default for StreamingConfig {
fn default() -> Self {
Self {
window_frames: 400,
overlap_frames: 100,
fuzzy_match_threshold: 1.0,
}
}
}
impl StreamingConfig {
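/// Number of feature frames the window advances after each inference pass
/// (window minus overlap).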
pub fn shift_frames(&self) -> usize {
self.window_frames.saturating_sub(self.overlap_frames)
}
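/// The same shift expressed in encoder output frames (feature frames divided
/// by the encoder's 4x subsampling).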
pub fn shift_encoder_frames(&self) -> usize {
self.shift_frames() / 4
}
pub fn validate(&self) -> Result<(), String> {
if self.window_frames == 0 {
return Err("streaming window must be > 0 frames".into());
}
if self.overlap_frames >= self.window_frames {
return Err("streaming overlap must be smaller than window".into());
}
if !self.window_frames.is_multiple_of(4) {
return Err("streaming window must be a multiple of 4 (encoder subsampling)".into());
}
if !self.overlap_frames.is_multiple_of(4) {
return Err("streaming overlap must be a multiple of 4 (encoder subsampling)".into());
}
if !(0.0..=1.0).contains(&self.fuzzy_match_threshold) {
return Err("fuzzy_match_threshold must be in [0.0, 1.0]".into());
}
Ok(())
}
}
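/// The encoder, decoder, and joiner ONNX sessions that make up one transducer
/// instance; they are checked out of the pool together.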
pub struct SessionTriplet {
pub(crate) encoder: Session,
pub(crate) decoder: Session,
pub(crate) joiner: Session,
}
#[derive(Debug)]
pub enum PoolError {
Closed,
}
impl std::fmt::Display for PoolError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PoolError::Closed => write!(f, "session pool is closed"),
}
}
}
impl std::error::Error for PoolError {}
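/// Fixed-capacity checkout/check-in pool backed by a bounded `async_channel`.
///
/// Items are handed out via [`Pool::checkout`] (async) or
/// [`Pool::checkout_blocking`], and the [`PoolGuard`] returns them to the
/// pool when it drops, including during panic unwind.
///
/// Illustrative sketch (marked `ignore`, so it is not compiled as a doc-test;
/// `u32` items stand in for [`SessionTriplet`]s):
///
/// ```ignore
/// let pool = Pool::new(vec![1u32, 2, 3]);
/// let guard = pool.checkout().await.expect("pool closed");
/// // use *guard ...
/// drop(guard); // the slot becomes available again
/// ```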
pub struct Pool<T> {
sender: async_channel::Sender<T>,
receiver: async_channel::Receiver<T>,
total: usize,
}
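/// Pool of [`SessionTriplet`]s shared across concurrent transcription requests.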
pub type SessionPool = Pool<SessionTriplet>;
impl<T> Pool<T> {
pub fn new(items: Vec<T>) -> Self {
let total = items.len();
let (sender, receiver) = async_channel::bounded(total.max(1));
for item in items {
sender
.try_send(item)
.expect("channel capacity matches item count");
}
Self {
sender,
receiver,
total,
}
}
pub async fn checkout(&self) -> Result<PoolGuard<'_, T>, PoolError> {
match self.receiver.recv().await {
Ok(item) => Ok(PoolGuard {
pool: self,
item: Some(item),
}),
Err(_) => Err(PoolError::Closed),
}
}
pub fn checkout_blocking(&self) -> Result<PoolGuard<'_, T>, PoolError> {
match self.receiver.recv_blocking() {
Ok(item) => Ok(PoolGuard {
pool: self,
item: Some(item),
}),
Err(_) => Err(PoolError::Closed),
}
}
pub fn close(&self) {
self.sender.close();
self.receiver.close();
}
pub fn total(&self) -> usize {
self.total
}
pub fn available(&self) -> usize {
self.receiver.len()
}
}
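/// RAII guard for a pooled item; dereferences to the item and returns it to
/// the pool on drop.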
pub struct PoolGuard<'a, T> {
pool: &'a Pool<T>,
item: Option<T>,
}
impl<T> PoolGuard<'_, T> {
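/// Takes ownership of the pooled item together with an [`OwnedReservation`],
/// so the item can be moved into a blocking task (e.g. `tokio::task::spawn_blocking`)
/// and checked back in from there.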
pub fn into_owned(mut self) -> (T, OwnedReservation<T>) {
let item = self
.item
.take()
.expect("PoolGuard::into_owned called after drop");
let reservation = OwnedReservation {
sender: self.pool.sender.clone(),
};
(item, reservation)
}
}
impl<T> Deref for PoolGuard<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.item
.as_ref()
.expect("PoolGuard accessed after item taken")
}
}
impl<T> DerefMut for PoolGuard<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.item
.as_mut()
.expect("PoolGuard accessed after item taken")
}
}
impl<T> Drop for PoolGuard<'_, T> {
fn drop(&mut self) {
if let Some(item) = self.item.take() {
let _ = self.pool.sender.try_send(item);
}
}
}
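/// A detached pool slot: check the item back in explicitly with
/// [`OwnedReservation::checkin`], or wrap it in a [`PoolItemGuard`] to return
/// it automatically on drop.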
pub struct OwnedReservation<T> {
sender: async_channel::Sender<T>,
}
impl<T> OwnedReservation<T> {
pub fn checkin(self, item: T) {
let _ = self.sender.try_send(item);
}
pub fn guard(self, item: T) -> PoolItemGuard<T> {
PoolItemGuard {
reservation: self,
item: Some(item),
}
}
}
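/// Owned RAII guard that returns its item to the originating pool on drop,
/// including during panic unwind.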
pub struct PoolItemGuard<T> {
reservation: OwnedReservation<T>,
item: Option<T>,
}
impl<T> PoolItemGuard<T> {
pub fn item_mut(&mut self) -> &mut T {
self.item
.as_mut()
.expect("PoolItemGuard item already taken")
}
pub fn item(&self) -> &T {
self.item
.as_ref()
.expect("PoolItemGuard item already taken")
}
pub fn into_inner(mut self) -> T {
self.item.take().expect("PoolItemGuard item already taken")
}
}
impl<T> Deref for PoolItemGuard<T> {
type Target = T;
fn deref(&self) -> &T {
self.item()
}
}
impl<T> DerefMut for PoolItemGuard<T> {
fn deref_mut(&mut self) -> &mut T {
self.item_mut()
}
}
impl<T> Drop for PoolItemGuard<T> {
fn drop(&mut self) {
if let Some(item) = self.item.take() {
let _ = self.reservation.sender.try_send(item);
}
}
}
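/// Rolling decoder context for greedy transducer decoding: the last
/// [`CONTEXT_SIZE`] emitted tokens (seeded with the blank token) plus a count
/// of consecutive blank emissions carried across windows.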
#[non_exhaustive]
pub struct DecoderState {
pub tokens: Vec<i64>,
pub blank_id: usize,
pub consecutive_blanks: usize,
}
impl StreamingState {
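/// Clears per-utterance state (decoder context, accumulated text and words,
/// buffered features, and the frame offset) after an endpoint or a completed
/// VAD segment. The online feature extractor itself keeps running.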
pub fn reset_utterance_state(&mut self) {
self.decoder = DecoderState::new(self.blank_id);
self.accumulated_text = Arc::new(String::new());
self.accumulated_words = Arc::new(Vec::new());
self.feature_window.clear();
self.prev_window_words.clear();
self.total_frames = 0;
}
}
impl DecoderState {
pub fn new(blank_id: usize) -> Self {
Self {
tokens: vec![blank_id as i64; CONTEXT_SIZE],
blank_id,
consecutive_blanks: 0,
}
}
pub fn push_token(&mut self, token: i64) {
self.tokens.rotate_left(1);
let last = self.tokens.last_mut().expect("CONTEXT_SIZE > 0");
*last = token;
}
}
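/// A single decoded word with start/end times in seconds, average token
/// confidence, and an optional speaker label (set when diarization is active).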
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
pub struct WordInfo {
pub word: String,
pub start: f64,
pub end: f64,
pub confidence: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub speaker: Option<u32>,
}
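/// Per-stream diarization state: audio buffered until a full embedding
/// segment is available, the online speaker cluster, and the most recently
/// assigned speaker.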
#[cfg(feature = "diarization")]
pub struct DiarizationStreamState {
pub audio_buffer: Vec<f32>,
pub cluster: diarization::SpeakerCluster,
pub current_speaker: Option<u32>,
}
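/// Mutable per-stream state: decoder context, online FBANK extractor, sliding
/// feature window, accumulated transcript, streaming configuration, and
/// optional VAD / diarization state. Create one with [`Engine::create_state`].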
#[non_exhaustive]
pub struct StreamingState {
pub decoder: DecoderState,
pub online: OnlineFeature,
pub frames_seen: usize,
pub accumulated_text: Arc<String>,
pub accumulated_words: Arc<Vec<WordInfo>>,
pub total_frames: usize,
pub feature_window: Vec<f32>,
pub prev_window_words: Vec<WordInfo>,
pub config: StreamingConfig,
pub blank_id: usize,
pub vad_session: Option<silero::Session>,
pub vad_stream_state: Option<silero::StreamState>,
pub vad_segmenter: Option<silero::SpeechSegmenter>,
pub vad_audio_buffer: Vec<f32>,
pub vad_sample_offset: u64,
pub vad_pending_asr: Vec<Vec<f32>>,
#[cfg(feature = "diarization")]
pub diarization_state: Option<DiarizationStreamState>,
}
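/// Offline and streaming transcription engine: a pool of ONNX session
/// triplets, the tokenizer, the mel/FBANK feature extractors, and the
/// streaming configuration.
///
/// Illustrative streaming sketch (marked `ignore`, so it is not compiled as a
/// doc-test; `audio_chunks` stands in for any source of 16 kHz mono `f32`
/// buffers, and error handling is elided):
///
/// ```ignore
/// let engine = Engine::load("/path/to/models")?;
/// let mut triplet = engine.pool.checkout_blocking()?;
/// let mut state = engine.create_state(false)?;
/// for chunk in audio_chunks {
///     for seg in engine.process_chunk(&chunk, &mut state, &mut triplet)? {
///         println!("[{}] {}", if seg.is_final { "final" } else { "partial" }, seg.text);
///     }
/// }
/// if let Some(seg) = engine.flush_state(&mut state, &mut triplet) {
///     println!("[final] {}", seg.text);
/// }
/// ```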
pub struct Engine {
pub pool: SessionPool,
tokenizer: Tokenizer,
mel: MelSpectrogram,
streaming_config: StreamingConfig,
vad_enabled: bool,
#[cfg(feature = "diarization")]
pub speaker_encoder: Option<diarization::SpeakerEncoder>,
}
impl Engine {
pub fn vocab_size(&self) -> usize {
self.tokenizer.vocab_size()
}
#[cfg(test)]
pub fn test_stub() -> Self {
Self {
pool: Pool::new(vec![]),
tokenizer: Tokenizer::from_tokens(vec!["<blk>".into(), "a".into()], 0),
mel: MelSpectrogram::new(),
streaming_config: StreamingConfig::default(),
vad_enabled: false,
#[cfg(feature = "diarization")]
speaker_encoder: None,
}
}
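/// Loads the engine from `model_dir` with the default pool size and streaming
/// configuration, without VAD.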
pub fn load(model_dir: &str) -> Result<Self, GigasttError> {
Self::load_with_pool_size(model_dir, DEFAULT_POOL_SIZE)
}
pub fn load_with_pool_size(model_dir: &str, pool_size: usize) -> Result<Self, GigasttError> {
Self::load_with_pool_size_and_config(model_dir, pool_size, StreamingConfig::default())
}
pub fn load_with_pool_size_and_config(
model_dir: &str,
pool_size: usize,
streaming_config: StreamingConfig,
) -> Result<Self, GigasttError> {
Self::load_with_pool_size_and_config_and_vad(model_dir, pool_size, streaming_config, false)
}
pub fn load_with_pool_size_and_config_and_vad(
model_dir: &str,
pool_size: usize,
streaming_config: StreamingConfig,
vad_enabled: bool,
) -> Result<Self, GigasttError> {
streaming_config
.validate()
.map_err(|e| GigasttError::ModelLoad(format!("invalid streaming config: {e}")))?;
let dir = Path::new(model_dir);
if !dir.join("encoder.int8.onnx").exists() {
return Err(GigasttError::ModelLoad(format!(
"encoder.int8.onnx not found in {model_dir}"
)));
}
Self::load_inner(dir, model_dir, pool_size, streaming_config, vad_enabled)
.map_err(|e| GigasttError::ModelLoad(format!("{e:#}")))
}
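/// Builds one encoder/decoder/joiner session triplet from `dir`, sharing
/// prepacked weights across triplets and selecting the execution provider at
/// compile time (CoreML, CUDA, or plain CPU with an on-disk optimized-model
/// cache).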
fn load_sessions(
dir: &Path,
prepacked: &ort::session::builder::PrepackedWeights,
) -> anyhow::Result<(Session, Session, Session)> {
let encoder_path = dir.join("encoder.int8.onnx");
#[cfg(feature = "coreml")]
let (encoder, decoder, joiner) = {
let cache_dir = dir.join("coreml_cache");
let coreml_ep = ep::CoreML::default()
.with_compute_units(ep::coreml::ComputeUnits::CPUAndNeuralEngine)
.with_specialization_strategy(ep::coreml::SpecializationStrategy::FastPrediction)
.with_model_cache_dir(cache_dir.to_string_lossy())
.build();
let encoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers([coreml_ep.clone()])
.map_err(ort_err)?
.commit_from_file(&encoder_path)
.map_err(ort_err)?;
let decoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers([coreml_ep.clone()])
.map_err(ort_err)?
.commit_from_file(dir.join("decoder.onnx"))
.map_err(ort_err)?;
let joiner = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers([coreml_ep])
.map_err(ort_err)?
.commit_from_file(dir.join("joiner.int8.onnx"))
.map_err(ort_err)?;
(encoder, decoder, joiner)
};
#[cfg(feature = "cuda")]
let (encoder, decoder, joiner) = {
let cuda_ep = ep::CUDA::default().build();
let encoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers([cuda_ep.clone()])
.map_err(ort_err)?
.commit_from_file(&encoder_path)
.map_err(ort_err)?;
let decoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers([cuda_ep.clone()])
.map_err(ort_err)?
.commit_from_file(dir.join("decoder.onnx"))
.map_err(ort_err)?;
let joiner = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers([cuda_ep])
.map_err(ort_err)?
.commit_from_file(dir.join("joiner.int8.onnx"))
.map_err(ort_err)?;
(encoder, decoder, joiner)
};
#[cfg(not(any(feature = "coreml", feature = "cuda")))]
let (encoder, decoder, joiner) = {
let cache_dir = dir.join("optimized_cache");
std::fs::create_dir_all(&cache_dir).ok();
let encoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_optimized_model_path(cache_dir.join("encoder_optimized.onnx"))
.map_err(ort_err)?
.commit_from_file(&encoder_path)
.map_err(ort_err)?;
let decoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.commit_from_file(dir.join("decoder.onnx"))
.map_err(ort_err)?;
let joiner = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.commit_from_file(dir.join("joiner.int8.onnx"))
.map_err(ort_err)?;
(encoder, decoder, joiner)
};
Ok((encoder, decoder, joiner))
}
fn load_inner(
dir: &Path,
model_dir: &str,
pool_size: usize,
streaming_config: StreamingConfig,
vad_enabled: bool,
) -> anyhow::Result<Self> {
tracing::info!(
"Loading Zipformer-vi INT8 ONNX models from {model_dir} (pool_size={pool_size})..."
);
#[cfg(feature = "coreml")]
tracing::info!("Using CoreML execution provider (Neural Engine + CPU)");
#[cfg(feature = "cuda")]
tracing::info!("Using CUDA execution provider (falls back to CPU if unavailable)");
#[cfg(not(any(feature = "coreml", feature = "cuda")))]
tracing::info!("Using CPU execution provider");
let prepacked = ort::session::builder::PrepackedWeights::new();
let triplets: Vec<SessionTriplet> = std::thread::scope(|s| {
let handles: Vec<_> = (0..pool_size)
.map(|i| {
let pp = &prepacked;
s.spawn(move || {
tracing::info!(
"Loading session triplet {}/{pool_size} (shared weights)",
i + 1
);
let (encoder, decoder, joiner) = Self::load_sessions(dir, pp)?;
Ok(SessionTriplet {
encoder,
decoder,
joiner,
})
})
})
.collect();
handles
.into_iter()
.map(|h| {
h.join()
.map_err(|_| anyhow::anyhow!("Thread panicked during model loading"))?
})
.collect::<anyhow::Result<Vec<_>>>()
})?;
let tokenizer = Tokenizer::load(&dir.join("tokens.txt"))?;
let mel = MelSpectrogram::new();
tracing::info!(
"Models loaded (vocab_size={}, pool_size={pool_size})",
tokenizer.vocab_size()
);
#[cfg(feature = "diarization")]
let speaker_encoder = match diarization::SpeakerEncoder::load(dir) {
Ok(enc) => {
tracing::info!("Speaker encoder loaded (diarization available)");
Some(enc)
}
Err(e) => {
tracing::warn!("Speaker encoder not loaded, diarization unavailable: {e:#}");
None
}
};
Ok(Self {
pool: SessionPool::new(triplets),
tokenizer,
mel,
streaming_config,
vad_enabled,
#[cfg(feature = "diarization")]
speaker_encoder,
})
}
#[cfg(feature = "diarization")]
pub fn has_speaker_encoder(&self) -> bool {
self.speaker_encoder.is_some()
}
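/// Creates fresh per-stream state: an online FBANK extractor, empty decoder
/// and transcript state, a bundled Silero VAD session plus segmenter when VAD
/// was enabled at load time, and diarization state when requested and a
/// speaker encoder is available.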
pub fn create_state(&self, diarization_enabled: bool) -> Result<StreamingState, GigasttError> {
#[cfg(feature = "diarization")]
let diarization_state = if diarization_enabled && self.speaker_encoder.is_some() {
Some(DiarizationStreamState {
audio_buffer: Vec::new(),
cluster: diarization::SpeakerCluster::new(),
current_speaker: None,
})
} else {
None
};
#[cfg(not(feature = "diarization"))]
if diarization_enabled {
tracing::warn!(
"diarization_enabled=true ignored: build lacks the `diarization` feature"
);
}
let computer = FbankComputer::new(features::phostt_fbank_options())
.map_err(|e| GigasttError::Inference(format!("FBANK init failed: {e}")))?;
let online = OnlineFeature::new(FeatureComputer::Fbank(computer));
let (vad_session, vad_stream_state, vad_segmenter) = if self.vad_enabled {
let session = silero::Session::bundled()
.map_err(|e| GigasttError::ModelLoad(format!("silero VAD load failed: {e}")))?;
let stream = silero::StreamState::new(silero::SampleRate::Rate16k);
let segmenter = silero::SpeechSegmenter::new(silero::SpeechOptions::default());
(Some(session), Some(stream), Some(segmenter))
} else {
(None, None, None)
};
let blank_id = self.tokenizer.blank_id();
Ok(StreamingState {
decoder: DecoderState::new(blank_id),
online,
frames_seen: 0,
accumulated_text: Arc::new(String::new()),
accumulated_words: Arc::new(Vec::new()),
total_frames: 0,
feature_window: Vec::new(),
prev_window_words: Vec::new(),
config: self.streaming_config.clone(),
blank_id,
vad_session,
vad_stream_state,
vad_segmenter,
vad_audio_buffer: Vec::new(),
vad_sample_offset: 0,
vad_pending_asr: Vec::new(),
#[cfg(feature = "diarization")]
diarization_state,
})
}
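/// Feeds a chunk of 16 kHz mono `f32` samples and returns zero or more
/// transcript segments. Routes to the VAD-gated path when VAD is enabled,
/// otherwise to the overlapping-window path.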
pub fn process_chunk(
&self,
samples: &[f32],
state: &mut StreamingState,
triplet: &mut SessionTriplet,
) -> Result<Vec<TranscriptSegment>, GigasttError> {
if samples.is_empty() {
return Ok(vec![]);
}
if state.vad_session.is_some() {
return self.process_chunk_vad(samples, state, triplet);
}
self.process_chunk_overlap(samples, state, triplet)
}
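/// Sliding-window streaming: accumulates feature frames, runs inference on
/// each full window, de-duplicates words shared with the previous window via
/// [`Engine::delta_words`], advances by the configured shift, and emits a
/// final segment when an endpoint is detected (tagging words with the current
/// speaker when diarization is active).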
fn process_chunk_overlap(
&self,
samples: &[f32],
state: &mut StreamingState,
triplet: &mut SessionTriplet,
) -> Result<Vec<TranscriptSegment>, GigasttError> {
#[cfg(feature = "diarization")]
let samples_16k_copy = if state.diarization_state.is_some() {
Some(samples.to_vec())
} else {
None
};
state.online.accept_waveform(16000.0, samples);
let ready = state.online.num_frames_ready();
let new_frames = ready.saturating_sub(state.frames_seen);
if new_frames == 0 {
return Ok(vec![]);
}
let new_features =
features::extract_online_frames(&state.online, state.frames_seen, new_frames);
state.frames_seen = ready;
state.feature_window.extend_from_slice(&new_features);
let mut emitted_words: Vec<WordInfo> = Vec::new();
let mut endpoint = false;
while state.feature_window.len() / N_MELS >= state.config.window_frames {
let num_frames = state.config.window_frames;
let features = &state.feature_window[..num_frames * N_MELS];
let frame_offset = state.total_frames;
let (window_words, window_endpoint, _enc_len) = self
.run_inference(
triplet,
features,
num_frames,
&mut state.decoder,
frame_offset,
)
.map_err(|e| GigasttError::Inference(format!("{e:#}")))?;
let delta = Self::delta_words(
&window_words,
&state.prev_window_words,
state.config.fuzzy_match_threshold,
);
emitted_words.extend(delta);
state.prev_window_words = window_words;
let shift = state.config.shift_frames() * N_MELS;
state.feature_window.drain(..shift);
state.total_frames += state.config.shift_encoder_frames();
if window_endpoint {
endpoint = true;
break;
}
}
#[cfg(feature = "diarization")]
if let (Some(dia), Some(copy), Some(enc)) = (
state.diarization_state.as_mut(),
samples_16k_copy.as_ref(),
self.speaker_encoder.as_ref(),
) {
dia.audio_buffer.extend_from_slice(copy);
while dia.audio_buffer.len() >= diarization::SEGMENT_SAMPLES {
let segment: Vec<f32> = dia
.audio_buffer
.drain(..diarization::SEGMENT_SAMPLES)
.collect();
match enc.extract_embedding(&segment) {
Ok(embedding) => {
let speaker = dia.cluster.assign(&embedding);
dia.current_speaker = Some(speaker);
}
Err(e) => {
tracing::warn!("Embedding extraction failed: {e:#}");
}
}
}
if let Some(speaker_id) = dia.current_speaker {
for w in &mut emitted_words {
w.speaker = Some(speaker_id);
}
}
}
if emitted_words.is_empty() && !endpoint {
return Ok(vec![]);
}
let acc_text = Arc::make_mut(&mut state.accumulated_text);
let acc_words = Arc::make_mut(&mut state.accumulated_words);
for w in &emitted_words {
if !acc_text.is_empty() {
acc_text.push(' ');
}
acc_text.push_str(&w.word);
}
acc_words.extend(emitted_words);
let text = Arc::clone(&state.accumulated_text);
let words = Arc::clone(&state.accumulated_words);
let ts = now_timestamp();
if endpoint {
state.accumulated_text = Arc::new(String::new());
state.accumulated_words = Arc::new(Vec::new());
state.decoder.consecutive_blanks = 0;
state.prev_window_words.clear();
Ok(vec![TranscriptSegment {
text,
words,
is_final: true,
timestamp: ts,
}])
} else {
Ok(vec![TranscriptSegment {
text,
words,
is_final: false,
timestamp: ts,
}])
}
}
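/// VAD-gated streaming: buffers raw audio, runs Silero VAD over the chunk,
/// queues each completed speech segment's samples in `vad_pending_asr` for
/// the caller to transcribe offline (resetting the utterance state), trims
/// the consumed buffer, and emits overlap-path partials while speech is
/// active.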
fn process_chunk_vad(
&self,
samples: &[f32],
state: &mut StreamingState,
triplet: &mut SessionTriplet,
) -> Result<Vec<TranscriptSegment>, GigasttError> {
state.vad_audio_buffer.extend_from_slice(samples);
let (speech_segments, is_active) = {
let session = state.vad_session.as_mut().unwrap();
let stream = state.vad_stream_state.as_mut().unwrap();
let segmenter = state.vad_segmenter.as_mut().unwrap();
let mut segments: Vec<silero::SpeechSegment> = Vec::new();
session
.process_stream(stream, samples, |probability| {
if let Some(segment) = segmenter.push_probability(probability) {
segments.push(segment);
}
})
.map_err(|e| GigasttError::Inference(format!("VAD inference failed: {e}")))?;
let active = segmenter.is_active();
(segments, active)
};
let mut emitted_segments: Vec<TranscriptSegment> = Vec::new();
let buffer_start = state.vad_sample_offset;
for segment in &speech_segments {
let buf_start = segment.start_sample().saturating_sub(buffer_start) as usize;
let buf_end = segment.end_sample().saturating_sub(buffer_start) as usize;
if buf_end > state.vad_audio_buffer.len() {
tracing::warn!("VAD segment extends beyond audio buffer, skipping");
continue;
}
let speech_samples = &state.vad_audio_buffer[buf_start..buf_end];
if speech_samples.is_empty() {
continue;
}
state.vad_pending_asr.push(speech_samples.to_vec());
state.reset_utterance_state();
}
if let Some(last_seg) = speech_segments.last() {
let remove_up_to = (last_seg.end_sample().saturating_sub(buffer_start)) as usize;
if remove_up_to <= state.vad_audio_buffer.len() {
state.vad_audio_buffer.drain(..remove_up_to);
state.vad_sample_offset += remove_up_to as u64;
}
}
if is_active {
let partials = self.process_chunk_overlap(samples, state, triplet)?;
emitted_segments.extend(partials);
}
Ok(emitted_segments)
}
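/// End-of-stream flush for the VAD path: flushes the VAD stream and
/// segmenter, transcribes any final speech segment offline, and returns it as
/// a final transcript segment.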
fn flush_state_vad(
&self,
state: &mut StreamingState,
triplet: &mut SessionTriplet,
) -> Option<TranscriptSegment> {
let session = state.vad_session.as_mut()?;
let stream = state.vad_stream_state.as_mut()?;
let segmenter = state.vad_segmenter.as_mut()?;
if let Ok(Some(probability)) = session.flush_stream(stream)
&& let Some(segment) = segmenter.push_probability(probability)
{
let buffer_start = state.vad_sample_offset;
let buf_start = segment.start_sample().saturating_sub(buffer_start) as usize;
let buf_end = (segment.end_sample().saturating_sub(buffer_start) as usize)
.min(state.vad_audio_buffer.len());
if buf_start < buf_end {
let speech_samples = &state.vad_audio_buffer[buf_start..buf_end];
if let Ok(result) = self.transcribe_samples(speech_samples, triplet)
&& !result.text.is_empty()
{
state.reset_utterance_state();
return Some(TranscriptSegment {
text: Arc::new(result.text),
words: Arc::new(result.words),
is_final: true,
timestamp: now_timestamp(),
});
}
}
}
if let Some(segment) = segmenter.finish() {
let buffer_start = state.vad_sample_offset;
let buf_start = segment.start_sample().saturating_sub(buffer_start) as usize;
let buf_end = (segment.end_sample().saturating_sub(buffer_start) as usize)
.min(state.vad_audio_buffer.len());
if buf_start < buf_end {
let speech_samples = &state.vad_audio_buffer[buf_start..buf_end];
if let Ok(result) = self.transcribe_samples(speech_samples, triplet)
&& !result.text.is_empty()
{
state.reset_utterance_state();
return Some(TranscriptSegment {
text: Arc::new(result.text),
words: Arc::new(result.words),
is_final: true,
timestamp: now_timestamp(),
});
}
}
}
None
}
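/// Flushes remaining audio at end of stream. Delegates to the VAD flush when
/// VAD is enabled; otherwise finishes feature extraction, decodes whatever
/// frames remain, and returns the accumulated transcript as a final segment
/// (or `None` if nothing was decoded).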
pub fn flush_state(
&self,
state: &mut StreamingState,
triplet: &mut SessionTriplet,
) -> Option<TranscriptSegment> {
if state.vad_session.is_some() {
return self.flush_state_vad(state, triplet);
}
state.online.input_finished();
let ready = state.online.num_frames_ready();
let new_frames = ready.saturating_sub(state.frames_seen);
if new_frames > 0 {
let new_features =
features::extract_online_frames(&state.online, state.frames_seen, new_frames);
state.feature_window.extend_from_slice(&new_features);
state.frames_seen = ready;
}
if !state.feature_window.is_empty() {
let num_frames = state.feature_window.len() / N_MELS;
let features = &state.feature_window[..];
let frame_offset = state.total_frames;
let (window_words, _endpoint, _enc_len) = self
.run_inference(
triplet,
features,
num_frames,
&mut state.decoder,
frame_offset,
)
.ok()?;
let delta = Self::delta_words(
&window_words,
&state.prev_window_words,
state.config.fuzzy_match_threshold,
);
let acc_text = Arc::make_mut(&mut state.accumulated_text);
let acc_words = Arc::make_mut(&mut state.accumulated_words);
for w in &delta {
if !acc_text.is_empty() {
acc_text.push(' ');
}
acc_text.push_str(&w.word);
}
acc_words.extend(delta);
state.prev_window_words = window_words;
state.feature_window.clear();
state.total_frames += num_frames / 4;
}
if state.accumulated_text.is_empty() {
return None;
}
let seg = TranscriptSegment {
text: Arc::clone(&state.accumulated_text),
words: Arc::clone(&state.accumulated_words),
is_final: true,
timestamp: now_timestamp(),
};
Some(seg)
}
pub fn transcribe_file(
&self,
path: &str,
triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
let float_samples = audio::decode_audio_file(path)
.map_err(|e| GigasttError::InvalidAudio(format!("{e:#}")))?;
self.transcribe_samples(&float_samples, triplet)
}
pub fn transcribe_bytes(
&self,
data: &[u8],
triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
self.transcribe_bytes_shared(bytes::Bytes::copy_from_slice(data), triplet)
}
pub fn transcribe_bytes_shared(
&self,
data: bytes::Bytes,
triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
let float_samples = audio::decode_audio_bytes_shared(data)
.map_err(|e| GigasttError::InvalidAudio(format!("{e:#}")))?;
self.transcribe_samples(&float_samples, triplet)
}
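/// Offline transcription of 16 kHz mono `f32` samples: computes the full mel
/// spectrogram, runs a single encoder pass and greedy decode, and joins the
/// decoded words into text.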
pub fn transcribe_samples(
&self,
float_samples: &[f32],
triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
let duration_s = float_samples.len() as f64 / 16000.0;
let (features, num_frames) = self.mel.compute(float_samples);
tracing::info!("Extracted {} mel frames", num_frames);
let mut decoder_state = DecoderState::new(self.tokenizer.blank_id());
let (words, _endpoint, _enc_len) = self
.run_inference(triplet, &features, num_frames, &mut decoder_state, 0)
.map_err(|e| GigasttError::Inference(format!("{e:#}")))?;
let text: String = words
.iter()
.map(|w| w.word.as_str())
.collect::<Vec<_>>()
.join(" ");
Ok(TranscribeResult {
text,
words,
duration_s,
})
}
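/// Runs the encoder over `num_frames` feature frames, then greedy transducer
/// decoding with the decoder and joiner. Returns the decoded words (with
/// timestamps shifted by `frame_offset` encoder frames), whether an endpoint
/// was detected, and the encoder output length.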
fn run_inference(
&self,
triplet: &mut SessionTriplet,
features: &[f32],
num_frames: usize,
decoder_state: &mut DecoderState,
frame_offset: usize,
) -> anyhow::Result<(Vec<WordInfo>, bool, usize)> {
let features_tensor =
TensorRef::from_array_view(([1_usize, num_frames, N_MELS], features))?;
let length_data = [num_frames as i64];
let length_tensor = TensorRef::from_array_view(([1_usize], length_data.as_slice()))?;
let enc_start = std::time::Instant::now();
let encoder_outputs = triplet
.encoder
.run(ort::inputs![features_tensor, length_tensor])
.context("Encoder inference failed")?;
tracing::info!(
elapsed_ms = enc_start.elapsed().as_millis() as u64,
"encoder_inference"
);
let (_enc_shape, enc_data) = encoder_outputs[0]
.try_extract_tensor::<f32>()
.context("Failed to extract encoder output")?;
let (_len_shape, len_data) = encoder_outputs[1]
.try_extract_tensor::<i64>()
.context("Failed to extract encoder length")?;
let enc_len = usize::try_from(len_data[0]).context("Negative encoder length")?;
tracing::debug!("Encoder output: {} frames", enc_len);
let enc_data_owned: Vec<f32> = enc_data.to_vec();
drop(encoder_outputs);
let dec_start = std::time::Instant::now();
let result = decode::greedy_decode(
&mut triplet.decoder,
&mut triplet.joiner,
&enc_data_owned,
enc_len,
self.tokenizer.blank_id(),
self.tokenizer.vocab_size(),
decoder_state,
)?;
tracing::info!(
elapsed_ms = dec_start.elapsed().as_millis() as u64,
"greedy_decode"
);
let words = self.tokens_to_words(&result.tokens, frame_offset);
let preview: String = words
.iter()
.take(10)
.map(|w| w.word.as_str())
.collect::<Vec<_>>()
.join(" ");
let ellipsis = if words.len() > 10 { "..." } else { "" };
tracing::info!(
"Decoded {} tokens → \"{preview}{ellipsis}\"",
result.tokens.len()
);
Ok((words, result.endpoint_detected, enc_len))
}
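/// Groups subword tokens into words, treating the `▁` (U+2581) prefix as a
/// word boundary, averaging per-token confidences and converting encoder
/// frame indices to seconds via `SECONDS_PER_FRAME`.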
fn tokens_to_words(&self, tokens: &[decode::TokenInfo], frame_offset: usize) -> Vec<WordInfo> {
if tokens.is_empty() {
return Vec::new();
}
let mut words = Vec::new();
let mut current_word = String::new();
let mut word_start_frame: Option<usize> = None;
let mut word_end_frame: usize = 0;
let mut word_confidences: Vec<f32> = Vec::new();
for token in tokens {
let token_text = self.tokenizer.token_text(token.token_id);
let is_word_boundary = token_text.starts_with('\u{2581}');
if is_word_boundary && !current_word.is_empty() {
let avg_conf: f32 = if word_confidences.is_empty() {
1.0
} else {
word_confidences.iter().sum::<f32>() / word_confidences.len() as f32
};
words.push(WordInfo {
word: current_word.clone(),
start: (word_start_frame.unwrap_or(0) + frame_offset) as f64
* SECONDS_PER_FRAME,
end: (word_end_frame + frame_offset) as f64 * SECONDS_PER_FRAME,
confidence: avg_conf,
speaker: None,
});
current_word.clear();
word_confidences.clear();
word_start_frame = None;
}
let clean = token_text.replace('\u{2581}', "");
if !clean.is_empty() {
current_word.push_str(&clean);
if word_start_frame.is_none() {
word_start_frame = Some(token.frame_index);
}
word_end_frame = token.frame_index;
word_confidences.push(token.confidence);
}
}
if !current_word.is_empty() {
let avg_conf: f32 = if word_confidences.is_empty() {
1.0
} else {
word_confidences.iter().sum::<f32>() / word_confidences.len() as f32
};
words.push(WordInfo {
word: current_word,
start: (word_start_frame.unwrap_or(0) + frame_offset) as f64 * SECONDS_PER_FRAME,
end: (word_end_frame + frame_offset) as f64 * SECONDS_PER_FRAME,
confidence: avg_conf,
speaker: None,
});
}
words
}
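/// Returns the words in `new` that were not already emitted: finds the
/// longest prefix of `new` matching a suffix of `prev` (exactly, or fuzzily
/// when the threshold is below 1.0) and skips it.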
fn delta_words(new: &[WordInfo], prev: &[WordInfo], fuzzy_threshold: f32) -> Vec<WordInfo> {
if prev.is_empty() {
return new.to_vec();
}
let mut best = 0;
for start in 0..prev.len() {
let mut matched = 0;
for (a, b) in new.iter().zip(prev[start..].iter()) {
if words_match(&a.word, &b.word, fuzzy_threshold) {
matched += 1;
} else {
break;
}
}
if matched > best {
best = matched;
}
}
new[best..].to_vec()
}
}
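/// Character-level Levenshtein edit distance (two-row dynamic programming).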
fn levenshtein(a: &str, b: &str) -> usize {
let a_len = a.chars().count();
let b_len = b.chars().count();
if a_len == 0 {
return b_len;
}
if b_len == 0 {
return a_len;
}
let mut prev = vec![0; b_len + 1];
let mut curr = vec![0; b_len + 1];
for (j, item) in prev.iter_mut().enumerate().take(b_len + 1) {
*item = j;
}
for (i, ca) in a.chars().enumerate() {
curr[0] = i + 1;
for (j, cb) in b.chars().enumerate() {
let cost = if ca == cb { 0 } else { 1 };
curr[j + 1] = (curr[j] + 1).min(prev[j + 1] + 1).min(prev[j] + cost);
}
std::mem::swap(&mut prev, &mut curr);
}
prev[b_len]
}
fn word_similarity(a: &str, b: &str) -> f32 {
if a == b {
return 1.0;
}
let dist = levenshtein(a, b);
let max_len = a.chars().count().max(b.chars().count());
if max_len == 0 {
return 1.0;
}
1.0 - (dist as f32 / max_len as f32)
}
fn words_match(a: &str, b: &str, threshold: f32) -> bool {
if threshold >= 1.0 {
return a == b;
}
word_similarity(a, b) >= threshold
}
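/// Result of offline transcription: joined text, per-word details, and audio
/// duration in seconds.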
#[derive(Debug, Clone, Serialize)]
pub struct TranscribeResult {
pub text: String,
pub words: Vec<WordInfo>,
pub duration_s: f64,
}
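/// A streaming transcript update: accumulated text and word details, whether
/// the segment closes the utterance (`is_final`), and the wall-clock emission
/// timestamp in Unix seconds.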
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
pub struct TranscriptSegment {
pub text: Arc<String>,
pub words: Arc<Vec<WordInfo>>,
pub is_final: bool,
pub timestamp: f64,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_decoder_state_new_left_pads_context_with_blank() {
let blank_id = 0;
let state = DecoderState::new(blank_id);
assert_eq!(state.tokens.len(), CONTEXT_SIZE);
assert!(state.tokens.iter().all(|&t| t == blank_id as i64));
assert_eq!(state.blank_id, blank_id);
assert_eq!(state.consecutive_blanks, 0);
}
#[test]
fn test_decoder_state_push_token_slides_window() {
let mut state = DecoderState::new(0);
state.push_token(7);
assert_eq!(state.tokens.last().copied(), Some(7));
state.push_token(9);
assert_eq!(state.tokens, vec![7, 9]);
}
#[test]
fn test_decoder_state_custom_blank_id_seeds_context() {
let state = DecoderState::new(42);
assert!(state.tokens.iter().all(|&t| t == 42));
}
#[tokio::test]
async fn test_pool_guard_returns_triplet_on_normal_drop() {
let pool = Pool::new(vec![1u32, 2, 3]);
assert_eq!(pool.available(), 3);
{
let _guard = pool.checkout().await.expect("checkout");
assert_eq!(pool.available(), 2);
}
assert_eq!(pool.available(), 3);
}
#[tokio::test]
async fn test_pool_guard_returns_triplet_on_panic_unwind() {
let pool = std::sync::Arc::new(Pool::new(vec![1u32]));
assert_eq!(pool.available(), 1);
let pool_clone = pool.clone();
let result = tokio::spawn(async move {
let _guard = pool_clone.checkout().await.expect("checkout");
assert_eq!(pool_clone.available(), 0);
panic!("synthetic inference panic");
})
.await;
assert!(result.is_err(), "spawned task must report the panic");
assert_eq!(pool.available(), 1);
}
#[tokio::test]
async fn test_pool_close_wakes_waiters_with_closed() {
let pool = std::sync::Arc::new(Pool::<u32>::new(vec![]));
let waiter = tokio::spawn({
let pool = pool.clone();
async move { pool.checkout().await.map(|_g| ()) }
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
pool.close();
let res = waiter.await.expect("join");
assert!(matches!(res, Err(PoolError::Closed)));
}
#[tokio::test]
async fn test_pool_fifo_under_contention() {
let pool = std::sync::Arc::new(Pool::new(vec![0u32]));
let primary = pool.checkout().await.expect("primary checkout");
assert_eq!(pool.available(), 0);
let waker_log = std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));
let mut handles = Vec::new();
for id in 0u32..3 {
let pool = pool.clone();
let log = waker_log.clone();
handles.push(tokio::spawn(async move {
let g = pool.checkout().await.expect("checkout");
log.lock().await.push(id);
drop(g);
}));
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
drop(primary);
for h in handles {
h.await.expect("join");
}
let log = waker_log.lock().await.clone();
assert_eq!(log, vec![0, 1, 2], "waiters must wake in FIFO order");
}
#[tokio::test]
async fn test_into_owned_for_spawn_blocking() {
let pool = std::sync::Arc::new(Pool::new(vec![String::from("triplet")]));
let guard = pool.checkout().await.expect("checkout");
let (item, reservation) = guard.into_owned();
let item = tokio::task::spawn_blocking(move || {
assert_eq!(item, "triplet");
reservation.checkin(item.clone());
item
})
.await
.expect("join");
assert_eq!(pool.available(), 1);
assert_eq!(item, "triplet");
}
#[tokio::test]
async fn test_pool_close_is_idempotent() {
let pool = Pool::<u32>::new(vec![]);
pool.close();
pool.close();
}
#[tokio::test]
async fn test_pool_triplet_survives_panic() {
let pool = std::sync::Arc::new(Pool::new(vec![String::from("triplet")]));
let guard = pool.checkout().await.expect("checkout");
let (item, reservation) = guard.into_owned();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
let mut guard = reservation.guard(item);
guard.push_str(" mutated");
panic!("simulated inference panic");
}));
assert!(result.is_err(), "panic should have occurred");
assert_eq!(
pool.available(),
1,
"pool slot must be recovered after panic"
);
let g = pool.checkout().await.expect("checkout after panic");
assert_eq!(g.as_str(), "triplet mutated");
}
#[tokio::test]
async fn test_pool_all_slots_recovered_after_panic_storm() {
let pool = std::sync::Arc::new(Pool::new(vec![1u32, 2, 3, 4]));
assert_eq!(pool.available(), 4);
let mut handles = Vec::new();
for _ in 0..4 {
let pool = pool.clone();
handles.push(tokio::spawn(async move {
let guard = pool.checkout().await.expect("checkout");
let (item, reservation) = guard.into_owned();
let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
let guard = reservation.guard(item);
let _ = *guard;
panic!("simulated inference panic");
}));
}));
}
for h in handles {
let _ = h.await;
}
assert_eq!(
pool.available(),
4,
"all pool slots must be recovered after panic storm"
);
}
#[test]
fn test_streaming_config_defaults() {
let cfg = StreamingConfig::default();
assert_eq!(cfg.window_frames, 400);
assert_eq!(cfg.overlap_frames, 100);
assert_eq!(cfg.fuzzy_match_threshold, 1.0);
assert_eq!(cfg.shift_frames(), 300);
assert_eq!(cfg.shift_encoder_frames(), 75);
assert!(cfg.validate().is_ok());
}
#[test]
fn test_streaming_config_validation() {
assert!(
StreamingConfig {
window_frames: 0,
..Default::default()
}
.validate()
.is_err()
);
assert!(
StreamingConfig {
window_frames: 100,
overlap_frames: 100,
..Default::default()
}
.validate()
.is_err()
);
assert!(
StreamingConfig {
window_frames: 100,
overlap_frames: 50,
..Default::default()
}
.validate()
.is_err()
);
assert!(
StreamingConfig {
window_frames: 200,
overlap_frames: 40,
fuzzy_match_threshold: 1.5,
}
.validate()
.is_err()
);
assert!(
StreamingConfig {
window_frames: 200,
overlap_frames: 40,
fuzzy_match_threshold: 0.8,
}
.validate()
.is_ok()
);
}
#[test]
fn test_levenshtein_basic() {
assert_eq!(levenshtein("kitten", "sitting"), 3);
assert_eq!(levenshtein("", "abc"), 3);
assert_eq!(levenshtein("abc", ""), 3);
assert_eq!(levenshtein("abc", "abc"), 0);
assert_eq!(levenshtein("ab", "abc"), 1);
}
#[test]
fn test_word_similarity() {
assert!((word_similarity("hello", "hello") - 1.0).abs() < f32::EPSILON);
assert!(word_similarity("hello", "hallo") > 0.7);
assert!(word_similarity("hello", "world") < 0.5);
}
#[test]
fn test_words_match_exact_threshold() {
assert!(words_match("hello", "hello", 1.0));
assert!(!words_match("hello", "hallo", 1.0));
}
#[test]
fn test_words_match_fuzzy() {
assert!(words_match("hello", "hallo", 0.7));
assert!(!words_match("hello", "world", 0.7));
}
#[test]
fn test_delta_words_exact() {
let mk = |w: &str| WordInfo {
word: w.into(),
start: 0.0,
end: 0.0,
confidence: 1.0,
speaker: None,
};
let prev = vec![mk("a"), mk("b"), mk("c")];
let new = vec![mk("b"), mk("c"), mk("d")];
let delta = Engine::delta_words(&new, &prev, 1.0);
assert_eq!(delta.len(), 1);
assert_eq!(delta[0].word, "d");
}
#[test]
fn test_delta_words_fuzzy() {
let mk = |w: &str| WordInfo {
word: w.into(),
start: 0.0,
end: 0.0,
confidence: 1.0,
speaker: None,
};
let prev = vec![mk("a"), mk("hello")];
let new = vec![mk("helio"), mk("b")];
let delta_exact = Engine::delta_words(&new, &prev, 1.0);
assert_eq!(delta_exact.len(), 2);
let delta_fuzzy = Engine::delta_words(&new, &prev, 0.7);
assert_eq!(delta_fuzzy.len(), 1);
assert_eq!(delta_fuzzy[0].word, "b");
}
#[tokio::test]
async fn test_vad_streaming_produces_same_text_as_offline() {
let home = std::env::var_os("HOME").map(std::path::PathBuf::from);
let model_dir = home.as_ref().map(|p| p.join(".phostt/models"));
if model_dir.is_none()
|| !model_dir
.as_ref()
.unwrap()
.join("encoder.int8.onnx")
.exists()
{
eprintln!("Skipping test_vad_streaming: model not found");
return;
}
let model_dir = model_dir.unwrap();
let wav_path = model_dir.join("test_wavs").join("0.wav");
if !wav_path.exists() {
eprintln!("Skipping test_vad_streaming: test WAV not found");
return;
}
let samples = audio::decode_audio_file(wav_path.to_str().unwrap()).unwrap();
let engine = Engine::load_with_pool_size_and_config_and_vad(
model_dir.to_str().unwrap(),
1,
StreamingConfig::default(),
true,
)
.unwrap();
let mut triplet = engine.pool.checkout().await.unwrap();
let offline = engine.transcribe_samples(&samples, &mut triplet).unwrap();
let offline_text = offline.text;
let mut state = engine.create_state(false).unwrap();
let chunk_size = samples.len() / 3;
let chunks = vec![
&samples[..chunk_size],
&samples[chunk_size..2 * chunk_size],
&samples[2 * chunk_size..],
];
let mut vad_text = String::new();
for chunk in chunks {
let segs = engine
.process_chunk(chunk, &mut state, &mut triplet)
.unwrap();
for audio in state.vad_pending_asr.drain(..) {
let result = engine.transcribe_samples(&audio, &mut triplet).unwrap();
if !result.text.is_empty() {
if !vad_text.is_empty() {
vad_text.push(' ');
}
vad_text.push_str(&result.text);
}
}
let _ = segs;
}
if let Some(flush) = engine.flush_state(&mut state, &mut triplet) {
if !vad_text.is_empty() {
vad_text.push(' ');
}
vad_text.push_str(&flush.text);
}
let normalize = |s: &str| s.split_whitespace().collect::<Vec<_>>().join(" ");
assert_eq!(
normalize(&vad_text),
normalize(&offline_text),
"VAD streaming transcript should match offline transcript"
);
}
#[tokio::test]
async fn test_vad_hybrid_emits_partials() {
let home = std::env::var_os("HOME").map(std::path::PathBuf::from);
let model_dir = home.as_ref().map(|p| p.join(".phostt/models"));
if model_dir.is_none()
|| !model_dir
.as_ref()
.unwrap()
.join("encoder.int8.onnx")
.exists()
{
eprintln!("Skipping test_vad_hybrid_emits_partials: model not found");
return;
}
let model_dir = model_dir.unwrap();
let wav_path = model_dir.join("test_wavs").join("0.wav");
if !wav_path.exists() {
eprintln!("Skipping test_vad_hybrid_emits_partials: test WAV not found");
return;
}
let samples = audio::decode_audio_file(wav_path.to_str().unwrap()).unwrap();
let config = StreamingConfig {
window_frames: 100,
overlap_frames: 20,
fuzzy_match_threshold: 1.0,
};
let engine = Engine::load_with_pool_size_and_config_and_vad(
model_dir.to_str().unwrap(),
1,
config,
true,
)
.unwrap();
let mut triplet = engine.pool.checkout().await.unwrap();
let mut state = engine.create_state(false).unwrap();
let chunk_size = samples.len() / 10;
let mut partial_count = 0usize;
let mut final_count = 0usize;
for i in 0..10 {
let end = ((i + 1) * chunk_size).min(samples.len());
let chunk = &samples[i * chunk_size..end];
let segs = engine
.process_chunk(chunk, &mut state, &mut triplet)
.unwrap();
for seg in &segs {
if seg.is_final {
final_count += 1;
} else {
partial_count += 1;
}
}
for audio in state.vad_pending_asr.drain(..) {
let result = engine.transcribe_samples(&audio, &mut triplet).unwrap();
if !result.text.is_empty() {
final_count += 1;
}
}
}
if let Some(flush) = engine.flush_state(&mut state, &mut triplet)
&& flush.is_final
{
final_count += 1;
}
assert!(
partial_count > 0,
"VAD hybrid should emit Partial segments during active speech, got {partial_count}"
);
assert!(
final_count > 0,
"VAD hybrid should emit at least one Final segment, got {final_count}"
);
}
#[tokio::test]
async fn test_vad_hybrid_matches_offline() {
let home = std::env::var_os("HOME").map(std::path::PathBuf::from);
let model_dir = home.as_ref().map(|p| p.join(".phostt/models"));
if model_dir.is_none()
|| !model_dir
.as_ref()
.unwrap()
.join("encoder.int8.onnx")
.exists()
{
eprintln!("Skipping test_vad_hybrid_matches_offline: model not found");
return;
}
let model_dir = model_dir.unwrap();
let wav_path = model_dir.join("test_wavs").join("0.wav");
if !wav_path.exists() {
eprintln!("Skipping test_vad_hybrid_matches_offline: test WAV not found");
return;
}
let samples = audio::decode_audio_file(wav_path.to_str().unwrap()).unwrap();
let engine = Engine::load_with_pool_size_and_config_and_vad(
model_dir.to_str().unwrap(),
1,
StreamingConfig::default(),
true,
)
.unwrap();
let mut triplet = engine.pool.checkout().await.unwrap();
let offline = engine.transcribe_samples(&samples, &mut triplet).unwrap();
let offline_text = offline.text;
let mut state = engine.create_state(false).unwrap();
let chunk_size = samples.len() / 4;
let chunks = vec![
&samples[..chunk_size],
&samples[chunk_size..2 * chunk_size],
&samples[2 * chunk_size..3 * chunk_size],
&samples[3 * chunk_size..],
];
let mut hybrid_text = String::new();
for chunk in chunks {
let _segs = engine
.process_chunk(chunk, &mut state, &mut triplet)
.unwrap();
for audio in state.vad_pending_asr.drain(..) {
let result = engine.transcribe_samples(&audio, &mut triplet).unwrap();
if !result.text.is_empty() {
if !hybrid_text.is_empty() {
hybrid_text.push(' ');
}
hybrid_text.push_str(&result.text);
}
}
}
if let Some(flush) = engine.flush_state(&mut state, &mut triplet) {
if !hybrid_text.is_empty() {
hybrid_text.push(' ');
}
hybrid_text.push_str(&flush.text);
}
let normalize = |s: &str| s.split_whitespace().collect::<Vec<_>>().join(" ");
assert_eq!(
normalize(&hybrid_text),
normalize(&offline_text),
"VAD hybrid final transcript should match offline transcript"
);
}
#[tokio::test]
async fn test_streaming_matches_offline() {
let home = std::env::var_os("HOME").map(std::path::PathBuf::from);
let model_dir = home.as_ref().map(|p| p.join(".phostt/models"));
if model_dir.is_none()
|| !model_dir
.as_ref()
.unwrap()
.join("encoder.int8.onnx")
.exists()
{
eprintln!("Skipping test_streaming_matches_offline: model not found");
return;
}
let model_dir = model_dir.unwrap();
let engine = Engine::load(model_dir.to_str().unwrap()).unwrap();
let wav_path = model_dir.join("test_wavs").join("0.wav");
if !wav_path.exists() {
eprintln!("Skipping test_streaming_matches_offline: test WAV not found");
return;
}
let samples = audio::decode_audio_file(wav_path.to_str().unwrap()).unwrap();
let mut triplet = engine.pool.checkout().await.unwrap();
let offline = engine.transcribe_samples(&samples, &mut triplet).unwrap();
let offline_text = offline.text;
let mut state = engine.create_state(false).unwrap();
let chunk_size = samples.len() / 3;
let chunks = vec![
&samples[..chunk_size],
&samples[chunk_size..2 * chunk_size],
&samples[2 * chunk_size..],
];
let mut streaming_text = String::new();
for chunk in chunks {
let segs = engine
.process_chunk(chunk, &mut state, &mut triplet)
.unwrap();
for seg in segs {
if seg.is_final {
if !streaming_text.is_empty() {
streaming_text.push(' ');
}
streaming_text.push_str(&seg.text);
}
}
}
if let Some(flush) = engine.flush_state(&mut state, &mut triplet) {
if !streaming_text.is_empty() {
streaming_text.push(' ');
}
streaming_text.push_str(&flush.text);
}
let normalize = |s: &str| s.split_whitespace().collect::<Vec<_>>().join(" ");
assert_eq!(
normalize(&streaming_text),
normalize(&offline_text),
"streaming transcript should match offline transcript"
);
}
}