pub mod audio;
mod decode;
mod features;
mod tokenizer;
#[cfg(feature = "diarization")]
pub mod diarization;
#[cfg(all(feature = "coreml", feature = "cuda"))]
compile_error!("Features `coreml` and `cuda` are mutually exclusive. Choose one.");
use anyhow::Context;
#[cfg(any(feature = "coreml", feature = "cuda"))]
use ort::ep;
use ort::session::Session;
use ort::value::TensorRef;
use serde::Serialize;
use std::ops::{Deref, DerefMut};
use std::path::Path;
use crate::error::GigasttError;
use features::MelSpectrogram;
use kaldi_native_fbank::fbank::FbankComputer;
use kaldi_native_fbank::online::{FeatureComputer, OnlineFeature};
use tokenizer::Tokenizer;
// Feature-extraction and model dimensions shared across the engine.
pub const N_MELS: usize = 80; // mel filterbank bins per feature frame
pub const N_FFT: usize = 400; // FFT window length in samples (25 ms at 16 kHz)
pub const HOP_LENGTH: usize = 160; // hop between frames in samples (10 ms at 16 kHz)
pub const ENCODER_OUT_DIM: usize = 512; // encoder output embedding size
pub const DECODER_OUT_DIM: usize = 512; // decoder output embedding size
pub const CONTEXT_SIZE: usize = 2; // previous tokens kept as decoder context
fn ort_err(e: impl std::fmt::Display) -> anyhow::Error {
anyhow::anyhow!("{e}")
}
/// Current wall-clock time as fractional seconds since the UNIX epoch.
/// Returns 0.0 in the (pathological) case where the clock is before the epoch.
pub(crate) fn now_timestamp() -> f64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs_f64())
        .unwrap_or(0.0)
}
// Streaming configuration. The encoder subsamples features 4x (see the `/ 4`
// conversions below), so one encoder frame spans HOP_LENGTH * 4 samples
// (0.04 s at 16 kHz).
const SECONDS_PER_FRAME: f64 = (HOP_LENGTH as f64) * 4.0 / 16000.0;
// Feature frames fed to the encoder per streaming inference window.
const STREAMING_WINDOW_FRAMES: usize = 400;
// Feature frames retained between consecutive windows for context.
const STREAMING_OVERLAP_FRAMES: usize = 100;
// Feature frames consumed (dropped from the buffer) per window step.
const STREAMING_SHIFT_FRAMES: usize = STREAMING_WINDOW_FRAMES - STREAMING_OVERLAP_FRAMES;
// The same shift expressed in encoder frames (4x subsampling).
const STREAMING_SHIFT_ENCODER_FRAMES: usize = STREAMING_SHIFT_FRAMES / 4;
// Default number of session triplets created for the pool.
const DEFAULT_POOL_SIZE: usize = 4;
/// One encoder/decoder/joiner ONNX session set; a single transcription uses
/// all three, so they are checked out of the pool together.
pub struct SessionTriplet {
    pub(crate) encoder: Session,
    pub(crate) decoder: Session,
    pub(crate) joiner: Session,
}
/// Errors produced by [`Pool`] operations.
#[derive(Debug)]
pub enum PoolError {
    /// The pool was closed while (or before) waiting for an item.
    Closed,
}

impl std::fmt::Display for PoolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Closed => f.write_str("session pool is closed"),
        }
    }
}

impl std::error::Error for PoolError {}
/// Fixed-size pool of reusable items backed by an async channel.
/// Items are checked out with [`Pool::checkout`] and returned automatically
/// when the [`PoolGuard`] drops.
pub struct Pool<T> {
    sender: async_channel::Sender<T>,
    receiver: async_channel::Receiver<T>,
    // Number of items the pool was created with, not current availability.
    total: usize,
}
/// Pool of ONNX session triplets shared across concurrent transcriptions.
pub type SessionPool = Pool<SessionTriplet>;
impl<T> Pool<T> {
    /// Builds a pool pre-filled with `items`. The channel capacity matches the
    /// item count (minimum 1 so an empty pool is still constructible).
    pub fn new(items: Vec<T>) -> Self {
        let total = items.len();
        let (sender, receiver) = async_channel::bounded(total.max(1));
        items.into_iter().for_each(|item| {
            sender
                .try_send(item)
                .expect("channel capacity matches item count");
        });
        Self {
            sender,
            receiver,
            total,
        }
    }

    /// Waits for a free item, returning it wrapped in a guard that checks it
    /// back in on drop. Fails with [`PoolError::Closed`] once the pool closes.
    pub async fn checkout(&self) -> Result<PoolGuard<'_, T>, PoolError> {
        let item = self.receiver.recv().await.map_err(|_| PoolError::Closed)?;
        Ok(PoolGuard {
            pool: self,
            item: Some(item),
        })
    }

    /// Closes both channel endpoints, waking every pending `checkout` caller.
    /// Safe to call more than once.
    pub fn close(&self) {
        self.sender.close();
        self.receiver.close();
    }

    /// Total number of items the pool was created with.
    pub fn total(&self) -> usize {
        self.total
    }

    /// Number of items currently checked in and available.
    pub fn available(&self) -> usize {
        self.receiver.len()
    }
}
/// RAII guard over a checked-out pool item; returns the item on drop.
pub struct PoolGuard<'a, T> {
    pool: &'a Pool<T>,
    // `None` only after `into_owned` has moved the item out.
    item: Option<T>,
}
impl<T> PoolGuard<'_, T> {
    /// Splits the guard into the raw item plus a reservation that can check it
    /// back in later — useful for moving the item into `spawn_blocking`.
    pub fn into_owned(mut self) -> (T, OwnedReservation<T>) {
        let reservation = OwnedReservation {
            sender: self.pool.sender.clone(),
        };
        // Taking the item leaves `None` behind, so the guard's Drop (which
        // runs when `self` goes out of scope here) becomes a no-op.
        let item = self
            .item
            .take()
            .expect("PoolGuard::into_owned called after drop");
        (item, reservation)
    }
}
impl<T> Deref for PoolGuard<'_, T> {
    type Target = T;

    /// Panics if the item was already moved out via `into_owned`.
    fn deref(&self) -> &T {
        match self.item {
            Some(ref inner) => inner,
            None => panic!("PoolGuard accessed after item taken"),
        }
    }
}
impl<T> DerefMut for PoolGuard<'_, T> {
    /// Panics if the item was already moved out via `into_owned`.
    fn deref_mut(&mut self) -> &mut T {
        match self.item {
            Some(ref mut inner) => inner,
            None => panic!("PoolGuard accessed after item taken"),
        }
    }
}
impl<T> Drop for PoolGuard<'_, T> {
    /// Returns the item to the pool. A failed send (pool already closed) is
    /// deliberately ignored: the item is simply dropped.
    fn drop(&mut self) {
        let Some(item) = self.item.take() else {
            return; // item was transferred out via `into_owned`
        };
        let _ = self.pool.sender.try_send(item);
    }
}
/// Detached check-in token produced by [`PoolGuard::into_owned`]; holds only a
/// sender so the item can be returned from another task or thread.
pub struct OwnedReservation<T> {
    sender: async_channel::Sender<T>,
}
impl<T> OwnedReservation<T> {
    /// Returns `item` to the pool. A send failure (pool closed) is
    /// intentionally ignored, matching the guard's drop behavior.
    pub fn checkin(self, item: T) {
        drop(self.sender.try_send(item));
    }
}
/// Sliding token context carried by the decoder across inference calls.
#[non_exhaustive]
pub struct DecoderState {
    /// Last `CONTEXT_SIZE` emitted token ids, blank-padded on creation.
    pub tokens: Vec<i64>,
    /// Token id representing "no output" (blank).
    pub blank_id: usize,
    /// Running count of consecutive blanks; reset when an endpoint is emitted.
    pub consecutive_blanks: usize,
}
impl DecoderState {
    /// Creates a fresh state whose context window is filled with `blank_id`.
    pub fn new(blank_id: usize) -> Self {
        Self {
            tokens: vec![blank_id as i64; CONTEXT_SIZE],
            blank_id,
            consecutive_blanks: 0,
        }
    }

    /// Slides the context window left by one and appends `token` at the end.
    /// The window length stays fixed at `CONTEXT_SIZE`.
    pub fn push_token(&mut self, token: i64) {
        self.tokens.remove(0);
        self.tokens.push(token);
    }
}
/// A single decoded word with timing and confidence metadata.
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
pub struct WordInfo {
    /// Word text (subword pieces merged, leading U+2581 markers stripped).
    pub word: String,
    /// Start time in seconds from the beginning of the stream.
    pub start: f64,
    /// End time in seconds from the beginning of the stream.
    pub end: f64,
    /// Mean per-token confidence for the word (1.0 when no scores collected).
    pub confidence: f32,
    /// Speaker label assigned by diarization, when enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speaker: Option<u32>,
}
/// Per-stream diarization bookkeeping (only with the `diarization` feature).
#[cfg(feature = "diarization")]
pub struct DiarizationStreamState {
    /// Raw samples buffered until a full embedding segment is available.
    pub audio_buffer: Vec<f32>,
    /// Online speaker clustering over extracted embeddings.
    pub cluster: diarization::SpeakerCluster,
    /// Most recently assigned speaker id, applied to newly emitted words.
    pub current_speaker: Option<u32>,
}
/// All mutable state for one streaming transcription session.
#[non_exhaustive]
pub struct StreamingState {
    /// Decoder token context carried across windows.
    pub decoder: DecoderState,
    /// Incremental FBANK feature extractor.
    pub online: OnlineFeature,
    /// Number of feature frames already pulled out of `online`.
    pub frames_seen: usize,
    /// Transcript text accumulated since the last endpoint.
    pub accumulated_text: String,
    /// Words accumulated since the last endpoint.
    pub accumulated_words: Vec<WordInfo>,
    /// Encoder-frame offset of the current window start (for timestamps).
    pub total_frames: usize,
    /// Flat buffer of pending feature frames (N_MELS floats per frame).
    pub feature_window: Vec<f32>,
    /// Words decoded from the previous window, used to compute deltas.
    pub prev_window_words: Vec<WordInfo>,
    /// Diarization state; `None` when disabled or unavailable.
    #[cfg(feature = "diarization")]
    pub diarization_state: Option<DiarizationStreamState>,
}
/// Top-level speech-to-text engine: ONNX session pool, tokenizer and
/// mel-spectrogram front end (plus an optional speaker encoder).
pub struct Engine {
    /// Pool of encoder/decoder/joiner sessions for concurrent inference.
    pub pool: SessionPool,
    tokenizer: Tokenizer,
    mel: MelSpectrogram,
    /// Speaker-embedding model; `None` when it failed to load.
    #[cfg(feature = "diarization")]
    pub speaker_encoder: Option<diarization::SpeakerEncoder>,
}
impl Engine {
    /// Number of entries in the tokenizer vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.tokenizer.vocab_size()
    }

    /// Loads the engine from `model_dir` with [`DEFAULT_POOL_SIZE`] sessions.
    pub fn load(model_dir: &str) -> Result<Self, GigasttError> {
        Self::load_with_pool_size(model_dir, DEFAULT_POOL_SIZE)
    }
/// Loads the engine from `model_dir` with an explicit session-pool size.
///
/// # Errors
/// Returns [`GigasttError::ModelLoad`] when the encoder model file is missing
/// or any part of initialization fails.
pub fn load_with_pool_size(model_dir: &str, pool_size: usize) -> Result<Self, GigasttError> {
    let dir = Path::new(model_dir);
    // Fail fast with a clear message before spinning up loader threads.
    if !dir.join("encoder.int8.onnx").exists() {
        let msg = format!("encoder.int8.onnx not found in {model_dir}");
        return Err(GigasttError::ModelLoad(msg));
    }
    Self::load_inner(dir, model_dir, pool_size)
        .map_err(|e| GigasttError::ModelLoad(format!("{e:#}")))
}
/// Builds one encoder/decoder/joiner session set, selecting the execution
/// provider at compile time (CoreML, CUDA, or plain CPU).
///
/// All three sessions share `prepacked` weights so multiple triplets do not
/// duplicate identical model tensors in memory.
fn load_sessions(
    dir: &Path,
    prepacked: &ort::session::builder::PrepackedWeights,
) -> anyhow::Result<(Session, Session, Session)> {
    let encoder_path = dir.join("encoder.int8.onnx");
    // CoreML build: target the Apple Neural Engine (plus CPU) and cache the
    // compiled model on disk to avoid recompilation on later startups.
    #[cfg(feature = "coreml")]
    let (encoder, decoder, joiner) = {
        let cache_dir = dir.join("coreml_cache");
        let coreml_ep = ep::CoreML::default()
            .with_compute_units(ep::coreml::ComputeUnits::CPUAndNeuralEngine)
            .with_specialization_strategy(ep::coreml::SpecializationStrategy::FastPrediction)
            .with_model_cache_dir(cache_dir.to_string_lossy())
            .build();
        let encoder = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_execution_providers([coreml_ep.clone()])
            .map_err(ort_err)?
            .commit_from_file(&encoder_path)
            .map_err(ort_err)?;
        let decoder = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_execution_providers([coreml_ep.clone()])
            .map_err(ort_err)?
            .commit_from_file(dir.join("decoder.onnx"))
            .map_err(ort_err)?;
        let joiner = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_execution_providers([coreml_ep])
            .map_err(ort_err)?
            .commit_from_file(dir.join("joiner.int8.onnx"))
            .map_err(ort_err)?;
        (encoder, decoder, joiner)
    };
    // CUDA build: ort falls back to CPU at runtime if no GPU is usable
    // (see the log message in `load_inner`).
    #[cfg(feature = "cuda")]
    let (encoder, decoder, joiner) = {
        let cuda_ep = ep::CUDA::default().build();
        let encoder = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_execution_providers([cuda_ep.clone()])
            .map_err(ort_err)?
            .commit_from_file(&encoder_path)
            .map_err(ort_err)?;
        let decoder = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_execution_providers([cuda_ep.clone()])
            .map_err(ort_err)?
            .commit_from_file(dir.join("decoder.onnx"))
            .map_err(ort_err)?;
        let joiner = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_execution_providers([cuda_ep])
            .map_err(ort_err)?
            .commit_from_file(dir.join("joiner.int8.onnx"))
            .map_err(ort_err)?;
        (encoder, decoder, joiner)
    };
    // CPU build: only the (large) encoder gets an on-disk optimized-model
    // cache; cache-dir creation failure is non-fatal (best effort).
    #[cfg(not(any(feature = "coreml", feature = "cuda")))]
    let (encoder, decoder, joiner) = {
        let cache_dir = dir.join("optimized_cache");
        std::fs::create_dir_all(&cache_dir).ok();
        let encoder = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_optimized_model_path(cache_dir.join("encoder_optimized.onnx"))
            .map_err(ort_err)?
            .commit_from_file(&encoder_path)
            .map_err(ort_err)?;
        let decoder = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .commit_from_file(dir.join("decoder.onnx"))
            .map_err(ort_err)?;
        let joiner = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .commit_from_file(dir.join("joiner.int8.onnx"))
            .map_err(ort_err)?;
        (encoder, decoder, joiner)
    };
    Ok((encoder, decoder, joiner))
}
/// Loads all session triplets (in parallel threads), the tokenizer, the mel
/// front end, and — when built with `diarization` — the speaker encoder.
fn load_inner(dir: &Path, model_dir: &str, pool_size: usize) -> anyhow::Result<Self> {
    tracing::info!("Loading Zipformer-vi INT8 ONNX models from {model_dir} (pool_size={pool_size})...");
    #[cfg(feature = "coreml")]
    tracing::info!("Using CoreML execution provider (Neural Engine + CPU)");
    #[cfg(feature = "cuda")]
    tracing::info!("Using CUDA execution provider (falls back to CPU if unavailable)");
    #[cfg(not(any(feature = "coreml", feature = "cuda")))]
    tracing::info!("Using CPU execution provider");
    // Shared prepacked weights: every triplet references the same tensor data.
    let prepacked = ort::session::builder::PrepackedWeights::new();
    // Load the triplets concurrently; scoped threads let us borrow `prepacked`
    // and `dir` without `'static` bounds or Arc.
    let triplets: Vec<SessionTriplet> = std::thread::scope(|s| {
        let handles: Vec<_> = (0..pool_size)
            .map(|i| {
                let pp = &prepacked;
                s.spawn(move || {
                    tracing::info!(
                        "Loading session triplet {}/{pool_size} (shared weights)",
                        i + 1
                    );
                    let (encoder, decoder, joiner) = Self::load_sessions(dir, pp)?;
                    Ok(SessionTriplet {
                        encoder,
                        decoder,
                        joiner,
                    })
                })
            })
            .collect();
        // Propagate the first load error; a panicked loader thread aborts.
        handles
            .into_iter()
            .map(|h| h.join().expect("Thread panicked during model loading"))
            .collect::<anyhow::Result<Vec<_>>>()
    })?;
    let tokenizer = Tokenizer::load(&dir.join("tokens.txt"))?;
    let mel = MelSpectrogram::new();
    tracing::info!(
        "Models loaded (vocab_size={}, pool_size={pool_size})",
        tokenizer.vocab_size()
    );
    // The speaker encoder is optional: a load failure only disables diarization.
    #[cfg(feature = "diarization")]
    let speaker_encoder = match diarization::SpeakerEncoder::load(dir) {
        Ok(enc) => {
            tracing::info!("Speaker encoder loaded (diarization available)");
            Some(enc)
        }
        Err(e) => {
            tracing::warn!("Speaker encoder not loaded, diarization unavailable: {e:#}");
            None
        }
    };
    Ok(Self {
        pool: SessionPool::new(triplets),
        tokenizer,
        mel,
        #[cfg(feature = "diarization")]
        speaker_encoder,
    })
}
/// True when the speaker-embedding model loaded, so diarization can run.
#[cfg(feature = "diarization")]
pub fn has_speaker_encoder(&self) -> bool {
    self.speaker_encoder.is_some()
}
/// Creates fresh per-stream state. `diarization_enabled` only takes effect
/// when the build has the `diarization` feature and the speaker encoder
/// actually loaded; otherwise it is ignored (with a warning on non-diarization
/// builds).
pub fn create_state(&self, diarization_enabled: bool) -> StreamingState {
    #[cfg(feature = "diarization")]
    let diarization_state = if diarization_enabled && self.speaker_encoder.is_some() {
        Some(DiarizationStreamState {
            audio_buffer: Vec::new(),
            cluster: diarization::SpeakerCluster::new(),
            current_speaker: None,
        })
    } else {
        None
    };
    #[cfg(not(feature = "diarization"))]
    if diarization_enabled {
        tracing::warn!(
            "diarization_enabled=true ignored: build lacks the `diarization` feature"
        );
    }
    // Online FBANK extractor (options come from the `features` module).
    let computer = FbankComputer::new(features::phostt_fbank_options())
        .expect("FBANK options valid");
    let online = OnlineFeature::new(FeatureComputer::Fbank(computer));
    StreamingState {
        decoder: DecoderState::new(self.tokenizer.blank_id()),
        online,
        frames_seen: 0,
        accumulated_text: String::new(),
        accumulated_words: Vec::new(),
        total_frames: 0,
        feature_window: Vec::new(),
        prev_window_words: Vec::new(),
        #[cfg(feature = "diarization")]
        diarization_state,
    }
}
/// Feeds a chunk of samples into the stream and returns zero or more
/// transcript segments: partial (non-final) updates, or a final segment when
/// an endpoint is detected. Audio is expected at 16 kHz mono (see the
/// `accept_waveform` call below).
pub fn process_chunk(
    &self,
    samples: &[f32],
    state: &mut StreamingState,
    triplet: &mut SessionTriplet,
) -> Result<Vec<TranscriptSegment>, GigasttError> {
    if samples.is_empty() {
        return Ok(vec![]);
    }
    // Keep a copy of the raw samples for speaker embedding when diarization is on.
    #[cfg(feature = "diarization")]
    let samples_16k_copy = if state.diarization_state.is_some() {
        Some(samples.to_vec())
    } else {
        None
    };
    state.online.accept_waveform(16000.0, samples);
    let ready = state.online.num_frames_ready();
    let new_frames = ready.saturating_sub(state.frames_seen);
    if new_frames == 0 {
        return Ok(vec![]);
    }
    // Pull only the not-yet-consumed frames and append them to the window buffer.
    let new_features = features::extract_online_frames(&state.online, state.frames_seen, new_frames);
    state.frames_seen = ready;
    state.feature_window.extend_from_slice(&new_features);
    let mut emitted_words: Vec<WordInfo> = Vec::new();
    let mut endpoint = false;
    // Run inference over fixed-size overlapping windows: each pass consumes
    // STREAMING_SHIFT_FRAMES and keeps STREAMING_OVERLAP_FRAMES for context.
    while state.feature_window.len() / N_MELS >= STREAMING_WINDOW_FRAMES {
        let num_frames = STREAMING_WINDOW_FRAMES;
        let features = &state.feature_window[..num_frames * N_MELS];
        let frame_offset = state.total_frames;
        let (window_words, window_endpoint, _enc_len) = self
            .run_inference(triplet, features, num_frames, &mut state.decoder, frame_offset)
            .map_err(|e| GigasttError::Inference(format!("{e:#}")))?;
        // Emit only words not already produced by the previous overlapping window.
        let delta = Self::delta_words(&window_words, &state.prev_window_words);
        emitted_words.extend(delta);
        state.prev_window_words = window_words;
        let shift = STREAMING_SHIFT_FRAMES * N_MELS;
        state.feature_window.drain(..shift);
        state.total_frames += STREAMING_SHIFT_ENCODER_FRAMES;
        if window_endpoint {
            endpoint = true;
            break;
        }
    }
    // Diarization: embed fixed-length audio segments and tag the newly
    // emitted words with the most recent speaker id.
    #[cfg(feature = "diarization")]
    if let (Some(dia), Some(copy), Some(enc)) = (
        state.diarization_state.as_mut(),
        samples_16k_copy.as_ref(),
        self.speaker_encoder.as_ref(),
    ) {
        dia.audio_buffer.extend_from_slice(copy);
        while dia.audio_buffer.len() >= diarization::SEGMENT_SAMPLES {
            let segment: Vec<f32> = dia
                .audio_buffer
                .drain(..diarization::SEGMENT_SAMPLES)
                .collect();
            match enc.extract_embedding(&segment) {
                Ok(embedding) => {
                    let speaker = dia.cluster.assign(&embedding);
                    dia.current_speaker = Some(speaker);
                }
                Err(e) => {
                    // Best effort: a failed embedding keeps the previous speaker.
                    tracing::warn!("Embedding extraction failed: {e:#}");
                }
            }
        }
        if let Some(speaker_id) = dia.current_speaker {
            for w in &mut emitted_words {
                w.speaker = Some(speaker_id);
            }
        }
    }
    if emitted_words.is_empty() && !endpoint {
        return Ok(vec![]);
    }
    // Append the new words to the running transcript, space-separated.
    for w in &emitted_words {
        if !state.accumulated_text.is_empty() {
            state.accumulated_text.push(' ');
        }
        state.accumulated_text.push_str(&w.word);
    }
    state.accumulated_words.extend(emitted_words);
    let text = state.accumulated_text.clone();
    let words = state.accumulated_words.clone();
    let ts = now_timestamp();
    if endpoint {
        // Endpoint: emit a final segment and reset the accumulators.
        state.accumulated_text.clear();
        state.accumulated_words.clear();
        state.decoder.consecutive_blanks = 0;
        state.prev_window_words.clear();
        Ok(vec![TranscriptSegment {
            text,
            words,
            is_final: true,
            timestamp: ts,
        }])
    } else {
        // No endpoint yet: emit a partial update carrying everything so far.
        Ok(vec![TranscriptSegment {
            text,
            words,
            is_final: false,
            timestamp: ts,
        }])
    }
}
/// Flushes any remaining buffered audio/features at end of stream and returns
/// a final segment when any text was accumulated, `None` otherwise.
pub fn flush_state(
    &self,
    state: &mut StreamingState,
    triplet: &mut SessionTriplet,
) -> Option<TranscriptSegment> {
    // Tell the online extractor no more audio is coming so it releases any
    // tail frames it was holding back.
    state.online.input_finished();
    let ready = state.online.num_frames_ready();
    let new_frames = ready.saturating_sub(state.frames_seen);
    if new_frames > 0 {
        let new_features =
            features::extract_online_frames(&state.online, state.frames_seen, new_frames);
        state.feature_window.extend_from_slice(&new_features);
        state.frames_seen = ready;
    }
    if !state.feature_window.is_empty() {
        // One last inference over the remaining partial window.
        // NOTE(review): assumes the buffer length is an exact multiple of
        // N_MELS (whole frames); otherwise the tensor reshape in
        // `run_inference` would fail — confirm the extractor guarantees this.
        let num_frames = state.feature_window.len() / N_MELS;
        let features = &state.feature_window[..];
        let frame_offset = state.total_frames;
        // NOTE(review): `.ok()?` silently discards inference errors, so a
        // failing flush is indistinguishable from "nothing to emit".
        let (window_words, _endpoint, _enc_len) = self
            .run_inference(triplet, features, num_frames, &mut state.decoder, frame_offset)
            .ok()?;
        let delta = Self::delta_words(&window_words, &state.prev_window_words);
        for w in &delta {
            if !state.accumulated_text.is_empty() {
                state.accumulated_text.push(' ');
            }
            state.accumulated_text.push_str(&w.word);
        }
        state.accumulated_words.extend(delta);
        state.prev_window_words = window_words;
        state.feature_window.clear();
        state.total_frames += num_frames / 4;
    }
    if state.accumulated_text.is_empty() {
        return None;
    }
    // Hand back the accumulated transcript, resetting the buffers in place.
    let seg = TranscriptSegment {
        text: std::mem::take(&mut state.accumulated_text),
        words: std::mem::take(&mut state.accumulated_words),
        is_final: true,
        timestamp: now_timestamp(),
    };
    Some(seg)
}
/// Decodes the audio file at `path` and transcribes it offline.
///
/// # Errors
/// [`GigasttError::InvalidAudio`] on decode failure; inference errors
/// propagate from [`Engine::transcribe_samples`].
pub fn transcribe_file(
    &self,
    path: &str,
    triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
    let float_samples = audio::decode_audio_file(path)
        .map_err(|e| GigasttError::InvalidAudio(format!("{e:#}")))?;
    self.transcribe_samples(&float_samples, triplet)
}
/// Transcribes encoded audio held in a byte slice (copies it into `Bytes`
/// first; prefer [`Engine::transcribe_bytes_shared`] to avoid the copy).
pub fn transcribe_bytes(
    &self,
    data: &[u8],
    triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
    self.transcribe_bytes_shared(bytes::Bytes::copy_from_slice(data), triplet)
}
/// Transcribes encoded audio from shared `Bytes` without copying the input.
pub fn transcribe_bytes_shared(
    &self,
    data: bytes::Bytes,
    triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
    let float_samples = audio::decode_audio_bytes_shared(data)
        .map_err(|e| GigasttError::InvalidAudio(format!("{e:#}")))?;
    self.transcribe_samples(&float_samples, triplet)
}
/// Runs a full offline transcription over already-decoded samples
/// (interpreted as 16 kHz mono for the duration computation).
fn transcribe_samples(
    &self,
    float_samples: &[f32],
    triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
    let duration_s = float_samples.len() as f64 / 16000.0;
    let (features, num_frames) = self.mel.compute(float_samples);
    tracing::info!("Extracted {} mel frames", num_frames);
    let mut decoder_state = DecoderState::new(self.tokenizer.blank_id());
    let inference = self.run_inference(triplet, &features, num_frames, &mut decoder_state, 0);
    let (words, _endpoint, _enc_len) =
        inference.map_err(|e| GigasttError::Inference(format!("{e:#}")))?;
    // Join the word texts with single spaces to form the transcript.
    let mut text = String::new();
    for w in &words {
        if !text.is_empty() {
            text.push(' ');
        }
        text.push_str(&w.word);
    }
    Ok(TranscribeResult {
        text,
        words,
        duration_s,
    })
}
/// Runs the encoder and then greedy decode over one feature window.
///
/// Returns the decoded words (frame indices offset by `frame_offset` for
/// timestamps), whether an endpoint was detected, and the encoder output
/// length in frames.
fn run_inference(
    &self,
    triplet: &mut SessionTriplet,
    features: &[f32],
    num_frames: usize,
    decoder_state: &mut DecoderState,
    frame_offset: usize,
) -> anyhow::Result<(Vec<WordInfo>, bool, usize)> {
    // Borrow the feature buffer as a [1, T, N_MELS] tensor (no copy).
    let features_tensor =
        TensorRef::from_array_view(([1_usize, num_frames, N_MELS], features))?;
    let length_data = [num_frames as i64];
    let length_tensor = TensorRef::from_array_view(([1_usize], length_data.as_slice()))?;
    let enc_start = std::time::Instant::now();
    let encoder_outputs = triplet
        .encoder
        .run(ort::inputs![features_tensor, length_tensor])
        .context("Encoder inference failed")?;
    tracing::info!(
        elapsed_ms = enc_start.elapsed().as_millis() as u64,
        "encoder_inference"
    );
    let (_enc_shape, enc_data) = encoder_outputs[0]
        .try_extract_tensor::<f32>()
        .context("Failed to extract encoder output")?;
    let (_len_shape, len_data) = encoder_outputs[1]
        .try_extract_tensor::<i64>()
        .context("Failed to extract encoder length")?;
    let enc_len = usize::try_from(len_data[0]).context("Negative encoder length")?;
    tracing::debug!("Encoder output: {} frames", enc_len);
    // Copy the encoder output to an owned buffer so `encoder_outputs` (which
    // borrows `triplet.encoder`) can be dropped before mutably borrowing the
    // decoder/joiner sessions of the same triplet below.
    let enc_data_owned: Vec<f32> = enc_data.to_vec();
    drop(encoder_outputs);
    let dec_start = std::time::Instant::now();
    let result = decode::greedy_decode(
        &mut triplet.decoder,
        &mut triplet.joiner,
        &enc_data_owned,
        enc_len,
        self.tokenizer.blank_id(),
        decoder_state,
    )?;
    tracing::info!(
        elapsed_ms = dec_start.elapsed().as_millis() as u64,
        "greedy_decode"
    );
    let words = self.tokens_to_words(&result.tokens, frame_offset);
    // Log a short preview (first 10 words) of the decoded text.
    let preview: String = words
        .iter()
        .take(10)
        .map(|w| w.word.as_str())
        .collect::<Vec<_>>()
        .join(" ");
    let ellipsis = if words.len() > 10 { "..." } else { "" };
    tracing::info!(
        "Decoded {} tokens → \"{preview}{ellipsis}\"",
        result.tokens.len()
    );
    Ok((words, result.endpoint_detected, enc_len))
}
/// Merges subword token pieces into words with timestamps and confidences.
///
/// Tokens whose text starts with "▁" (U+2581, the SentencePiece-style word
/// marker) begin a new word. Frame indices are offset by `frame_offset` and
/// scaled by `SECONDS_PER_FRAME` to produce times in seconds.
///
/// Fixes over the previous version: removes the duplicated `tokens.is_empty()`
/// guard and deduplicates the word-finalization logic (previously repeated
/// inline for the loop body and the trailing word).
fn tokens_to_words(&self, tokens: &[decode::TokenInfo], frame_offset: usize) -> Vec<WordInfo> {
    /// Builds one `WordInfo` from the accumulated pieces of a word.
    fn finish_word(
        word: String,
        start_frame: Option<usize>,
        end_frame: usize,
        confidences: &[f32],
        frame_offset: usize,
    ) -> WordInfo {
        // Mean per-token confidence; default to 1.0 when no scores were collected.
        let confidence = if confidences.is_empty() {
            1.0
        } else {
            confidences.iter().sum::<f32>() / confidences.len() as f32
        };
        WordInfo {
            word,
            start: (start_frame.unwrap_or(0) + frame_offset) as f64 * SECONDS_PER_FRAME,
            end: (end_frame + frame_offset) as f64 * SECONDS_PER_FRAME,
            confidence,
            speaker: None,
        }
    }

    if tokens.is_empty() {
        return Vec::new();
    }
    let mut words = Vec::new();
    let mut current_word = String::new();
    let mut word_start_frame: Option<usize> = None;
    let mut word_end_frame: usize = 0;
    let mut word_confidences: Vec<f32> = Vec::new();
    for token in tokens {
        let token_text = self.tokenizer.token_text(token.token_id);
        // "▁" marks the start of a new word; flush the word built so far.
        let is_word_boundary = token_text.starts_with('\u{2581}');
        if is_word_boundary && !current_word.is_empty() {
            // `mem::take`/`Option::take` reset the accumulators without cloning.
            words.push(finish_word(
                std::mem::take(&mut current_word),
                word_start_frame.take(),
                word_end_frame,
                &word_confidences,
                frame_offset,
            ));
            word_confidences.clear();
        }
        let clean = token_text.replace('\u{2581}', "");
        if !clean.is_empty() {
            current_word.push_str(&clean);
            if word_start_frame.is_none() {
                word_start_frame = Some(token.frame_index);
            }
            word_end_frame = token.frame_index;
            word_confidences.push(token.confidence);
        }
    }
    // Flush the trailing word, if any.
    if !current_word.is_empty() {
        words.push(finish_word(
            current_word,
            word_start_frame,
            word_end_frame,
            &word_confidences,
            frame_offset,
        ));
    }
    words
}
/// Returns the suffix of `new` that was not already emitted.
///
/// Computes, over every start offset into `prev`, the length of the longest
/// run of matching words at the head of `new`, and skips that many words.
/// With an empty `prev`, everything in `new` is new.
fn delta_words(new: &[WordInfo], prev: &[WordInfo]) -> Vec<WordInfo> {
    if prev.is_empty() {
        return new.to_vec();
    }
    let overlap = (0..prev.len())
        .map(|start| {
            new.iter()
                .zip(prev[start..].iter())
                .take_while(|(a, b)| a.word == b.word)
                .count()
        })
        .max()
        .unwrap_or(0);
    new[overlap..].to_vec()
}
}
/// Result of a full (offline) transcription.
#[derive(Debug, Clone, Serialize)]
pub struct TranscribeResult {
    /// Full transcript with words joined by single spaces.
    pub text: String,
    /// Per-word timing and confidence details.
    pub words: Vec<WordInfo>,
    /// Input audio duration in seconds (samples / 16000).
    pub duration_s: f64,
}
/// One streaming transcript update.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TranscriptSegment {
    /// Text accumulated since the last endpoint.
    pub text: String,
    /// Words accumulated since the last endpoint.
    pub words: Vec<WordInfo>,
    /// True when an endpoint was detected and the accumulators were reset.
    pub is_final: bool,
    /// Wall-clock UNIX timestamp (seconds) when the segment was produced.
    pub timestamp: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- DecoderState invariants -------------------------------------------

    #[test]
    fn test_decoder_state_new_left_pads_context_with_blank() {
        let blank_id = 0;
        let state = DecoderState::new(blank_id);
        assert_eq!(state.tokens.len(), CONTEXT_SIZE);
        assert!(state.tokens.iter().all(|&t| t == blank_id as i64));
        assert_eq!(state.blank_id, blank_id);
        assert_eq!(state.consecutive_blanks, 0);
    }

    #[test]
    fn test_decoder_state_push_token_slides_window() {
        let mut state = DecoderState::new(0);
        state.push_token(7);
        assert_eq!(state.tokens.last().copied(), Some(7));
        state.push_token(9);
        assert_eq!(state.tokens, vec![7, 9]);
    }

    #[test]
    fn test_decoder_state_custom_blank_id_seeds_context() {
        let state = DecoderState::new(42);
        assert!(state.tokens.iter().all(|&t| t == 42));
    }

    // ---- Pool behavior (uses u32/String items, no real sessions needed) ----

    #[tokio::test]
    async fn test_pool_guard_returns_triplet_on_normal_drop() {
        let pool = Pool::new(vec![1u32, 2, 3]);
        assert_eq!(pool.available(), 3);
        {
            let _guard = pool.checkout().await.expect("checkout");
            assert_eq!(pool.available(), 2);
        }
        // Guard drop must check the item back in.
        assert_eq!(pool.available(), 3);
    }

    #[tokio::test]
    async fn test_pool_guard_returns_triplet_on_panic_unwind() {
        let pool = std::sync::Arc::new(Pool::new(vec![1u32]));
        assert_eq!(pool.available(), 1);
        let pool_clone = pool.clone();
        let result = tokio::spawn(async move {
            let _guard = pool_clone.checkout().await.expect("checkout");
            assert_eq!(pool_clone.available(), 0);
            panic!("synthetic inference panic");
        })
        .await;
        assert!(result.is_err(), "spawned task must report the panic");
        // Even on unwind, the guard's Drop must return the item.
        assert_eq!(pool.available(), 1);
    }

    #[tokio::test]
    async fn test_pool_close_wakes_waiters_with_closed() {
        let pool = std::sync::Arc::new(Pool::<u32>::new(vec![]));
        let waiter = tokio::spawn({
            let pool = pool.clone();
            async move { pool.checkout().await.map(|_g| ()) }
        });
        // Give the waiter time to block on the empty pool before closing.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        pool.close();
        let res = waiter.await.expect("join");
        assert!(matches!(res, Err(PoolError::Closed)));
    }

    #[tokio::test]
    async fn test_pool_fifo_under_contention() {
        let pool = std::sync::Arc::new(Pool::new(vec![0u32]));
        let primary = pool.checkout().await.expect("primary checkout");
        assert_eq!(pool.available(), 0);
        let waker_log = std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let mut handles = Vec::new();
        for id in 0u32..3 {
            let pool = pool.clone();
            let log = waker_log.clone();
            handles.push(tokio::spawn(async move {
                let g = pool.checkout().await.expect("checkout");
                log.lock().await.push(id);
                drop(g);
            }));
            // Stagger spawns so waiters queue in a deterministic order.
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        }
        drop(primary);
        for h in handles {
            h.await.expect("join");
        }
        let log = waker_log.lock().await.clone();
        assert_eq!(log, vec![0, 1, 2], "waiters must wake in FIFO order");
    }

    #[tokio::test]
    async fn test_into_owned_for_spawn_blocking() {
        let pool = std::sync::Arc::new(Pool::new(vec![String::from("triplet")]));
        let guard = pool.checkout().await.expect("checkout");
        let (item, reservation) = guard.into_owned();
        let item = tokio::task::spawn_blocking(move || {
            assert_eq!(item, "triplet");
            reservation.checkin(item.clone());
            item
        })
        .await
        .expect("join");
        assert_eq!(pool.available(), 1);
        assert_eq!(item, "triplet");
    }

    #[tokio::test]
    async fn test_pool_close_is_idempotent() {
        let pool = Pool::<u32>::new(vec![]);
        pool.close();
        pool.close();
    }

    // End-to-end check: chunked streaming output must match offline decoding.
    // Skips (with a message) when the local model or test WAV is missing.
    #[tokio::test]
    async fn test_streaming_matches_offline() {
        let home = std::env::var_os("HOME").map(std::path::PathBuf::from);
        let model_dir = home.as_ref().map(|p| p.join(".phostt/models"));
        if model_dir.is_none() || !model_dir.as_ref().unwrap().join("encoder.int8.onnx").exists() {
            eprintln!("Skipping test_streaming_matches_offline: model not found");
            return;
        }
        let model_dir = model_dir.unwrap();
        let engine = Engine::load(model_dir.to_str().unwrap()).unwrap();
        let wav_path = model_dir.join("test_wavs").join("0.wav");
        if !wav_path.exists() {
            eprintln!("Skipping test_streaming_matches_offline: test WAV not found");
            return;
        }
        let samples = audio::decode_audio_file(wav_path.to_str().unwrap()).unwrap();
        let mut triplet = engine.pool.checkout().await.unwrap();
        let offline = engine.transcribe_samples(&samples, &mut triplet).unwrap();
        let offline_text = offline.text;
        let mut state = engine.create_state(false);
        let chunk_size = samples.len() / 3;
        let chunks = vec![
            &samples[..chunk_size],
            &samples[chunk_size..2 * chunk_size],
            &samples[2 * chunk_size..],
        ];
        let mut streaming_text = String::new();
        for chunk in chunks {
            let segs = engine.process_chunk(chunk, &mut state, &mut triplet).unwrap();
            for seg in segs {
                if seg.is_final {
                    if !streaming_text.is_empty() {
                        streaming_text.push(' ');
                    }
                    streaming_text.push_str(&seg.text);
                }
            }
        }
        if let Some(flush) = engine.flush_state(&mut state, &mut triplet) {
            if !streaming_text.is_empty() {
                streaming_text.push(' ');
            }
            streaming_text.push_str(&flush.text);
        }
        // Compare whitespace-normalized transcripts.
        let normalize = |s: &str| s.split_whitespace().collect::<Vec<_>>().join(" ");
        assert_eq!(
            normalize(&streaming_text),
            normalize(&offline_text),
            "streaming transcript should match offline transcript"
        );
    }
}