pub mod audio;
mod bias;
mod decode;
mod features;
#[cfg(not(feature = "__internals"))]
mod tokenizer;
#[cfg(feature = "__internals")]
pub mod tokenizer;
#[cfg(feature = "diarization")]
use polyvoice::streaming::StreamingPipeline;
#[cfg(feature = "diarization")]
#[allow(deprecated)]
use polyvoice::{
ClusterConfig, DiarizationConfig as DiaConfig, EmbeddingError, EmbeddingExtractor, EnergyVad,
OnnxEmbeddingExtractor, Pipeline, VadConfig,
};
/// Embedding dimension passed to `OnnxEmbeddingExtractor::new` when the
/// speaker encoder is loaded.
#[cfg(feature = "diarization")]
const SPEAKER_EMBEDDING_DIM: usize = 256;
/// Segment length, in samples, passed to `OnnxEmbeddingExtractor::new`.
/// NOTE(review): presumably 1.5 s at the 16 kHz rate used elsewhere in this file — confirm.
#[cfg(feature = "diarization")]
const SPEAKER_SEGMENT_SAMPLES: usize = 24000;
/// Fourth argument passed to `OnnxEmbeddingExtractor::new`.
/// NOTE(review): presumably the extractor's internal session pool size — confirm against polyvoice.
#[cfg(feature = "diarization")]
const SPEAKER_POOL_SIZE: usize = 4;
/// Newtype over a reference-counted [`OnnxEmbeddingExtractor`] so one loaded
/// speaker encoder can be handed to several streaming pipelines.
#[cfg(feature = "diarization")]
#[allow(deprecated)]
pub struct SharedExtractor(std::sync::Arc<OnnxEmbeddingExtractor>);

/// Forwards every trait method to the wrapped extractor.
#[cfg(feature = "diarization")]
#[allow(deprecated)]
impl EmbeddingExtractor for SharedExtractor {
    /// Delegates embedding extraction to the shared extractor.
    fn extract(&self, samples: &[f32], config: &DiaConfig) -> Result<Vec<f32>, EmbeddingError> {
        let Self(shared) = self;
        shared.extract(samples, config)
    }

    /// Reports the embedding dimension of the shared extractor.
    fn embedding_dim(&self) -> usize {
        let Self(shared) = self;
        shared.embedding_dim()
    }
}
#[cfg(all(feature = "coreml", feature = "cuda"))]
compile_error!("Features `coreml` and `cuda` are mutually exclusive. Choose one.");
use anyhow::Context;
#[cfg(any(feature = "coreml", feature = "cuda"))]
use ort::ep;
use ort::session::Session;
use ort::value::TensorRef;
use serde::Serialize;
use std::ops::{Deref, DerefMut};
use std::path::Path;
use crate::error::GigasttError;
use crate::model::ModelVariant;
use features::MelSpectrogram;
use tokenizer::Tokenizer;
/// NOTE(review): presumably the number of mel filterbank channels produced by
/// `MelSpectrogram` — confirm in the `features` module.
pub const N_MELS: usize = 64;
/// FFT window length in samples; also used in this file as the minimum number
/// of buffered samples before a streaming window is decoded.
pub const N_FFT: usize = 320;
/// Hop between successive feature frames, in samples; feeds `SECONDS_PER_FRAME`
/// and the streaming frame-offset computation.
pub const HOP_LENGTH: usize = 160;
/// Length of the `h` and `c` vectors allocated by `DecoderState::new`.
pub const PRED_HIDDEN: usize = 320;
fn ort_err(e: impl std::fmt::Display) -> anyhow::Error {
anyhow::anyhow!("{e}")
}
/// Returns the current time as fractional seconds since the Unix epoch.
///
/// The wall clock is sampled once, on the first call, together with an
/// [`std::time::Instant`]; every result is that first wall-clock reading plus
/// the elapsed time measured on the `Instant`. If the system clock reports a
/// time before the epoch, a warning is logged and the base is taken as `0.0`.
pub fn now_timestamp() -> f64 {
    use std::sync::OnceLock;
    use std::time::{Instant, SystemTime, UNIX_EPOCH};

    static ANCHOR: OnceLock<(SystemTime, Instant)> = OnceLock::new();

    let anchor = ANCHOR.get_or_init(|| (SystemTime::now(), Instant::now()));
    let wall_at_anchor = anchor
        .0
        .duration_since(UNIX_EPOCH)
        .map(|since_epoch| since_epoch.as_secs_f64())
        .unwrap_or_else(|e| {
            tracing::warn!("System clock is before Unix epoch: {e}");
            0.0
        });
    wall_at_anchor + anchor.1.elapsed().as_secs_f64()
}
/// Returns the total physical RAM in bytes, or `0` when it cannot be determined.
///
/// Exactly one of the three `cfg`-gated blocks below is compiled in and acts
/// as the function's tail expression.
fn total_ram_bytes() -> u64 {
// macOS: read `CTL_HW` / `HW_MEMSIZE` through `sysctl`.
#[cfg(target_os = "macos")]
{
let mut mem: u64 = 0;
let mut len = std::mem::size_of::<u64>();
let mib = [libc::CTL_HW, libc::HW_MEMSIZE];
// SAFETY: `mib`, `mem` and `len` are live locals for the whole call, `len`
// holds the byte size of `mem`, and no new value is supplied (null, 0).
// NOTE(review): `mib` is immutable but is cast to `*mut`; this relies on
// `sysctl` not writing through the name pointer — confirm.
let rc = unsafe {
libc::sysctl(
mib.as_ptr() as *mut libc::c_int,
mib.len() as libc::c_uint,
&mut mem as *mut u64 as *mut libc::c_void,
&mut len,
std::ptr::null_mut(),
0,
)
};
// A non-zero return code is reported as "unknown" (0).
if rc == 0 { mem } else { 0 }
}
// Other Unix systems: physical page count times page size.
#[cfg(all(unix, not(target_os = "macos")))]
{
// SAFETY: `sysconf` is called with libc-provided constants and takes no pointers.
let pages = unsafe { libc::sysconf(libc::_SC_PHYS_PAGES) };
let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
// Non-positive results are treated as "unknown".
if pages > 0 && page_size > 0 {
(pages as u64).saturating_mul(page_size as u64)
} else {
0
}
}
// Non-Unix targets: no probe is implemented.
#[cfg(not(unix))]
{
0
}
}
/// Number of feature hops per encoder output frame, as used by
/// `SECONDS_PER_FRAME` and the streaming frame-offset computation.
const ENCODER_SUBSAMPLING: usize = 4;
/// Duration of one encoder output frame in seconds: 160 * 4 / 16000 = 0.04.
const SECONDS_PER_FRAME: f64 = (HOP_LENGTH as f64 * ENCODER_SUBSAMPLING as f64) / 16000.0;
/// Streaming buffer length (40000 samples, 2.5 s at 16 kHz) at or above which
/// `process_chunk` forces the current segment to be finalized.
const STREAM_MAX_WINDOW_SAMPLES: usize = 16000 * 5 / 2;
/// Most audio kept in the streaming buffer after a segment is finalized
/// (24000 samples, 1.5 s at 16 kHz).
const STREAM_LEFT_CONTEXT_SAMPLES: usize = 16000 * 3 / 2;
/// New audio that must accumulate before `process_chunk` decodes again
/// (12800 samples, 0.8 s at 16 kHz), unless an endpoint or the cap fires first.
const STREAM_DECODE_STRIDE_SAMPLES: usize = 16000 * 4 / 5;
/// 480000 samples. NOTE(review): not referenced in the visible part of this file;
/// presumably the length above which offline audio is split into chunks — confirm.
const CHUNK_THRESHOLD_SAMPLES: usize = 16000 * 30;
/// 384000 samples. NOTE(review): presumably the per-chunk window for long audio — confirm.
const CHUNK_WINDOW_SAMPLES: usize = 16000 * 24;
/// 32000 samples. NOTE(review): presumably the overlap between adjacent chunks — confirm.
const CHUNK_OVERLAP_SAMPLES: usize = 16000 * 2;
/// Pool size used by `Engine::load` on Android.
#[cfg(target_os = "android")]
const DEFAULT_POOL_SIZE: usize = 1;
/// Pool size used by `Engine::load` on every other target.
#[cfg(not(target_os = "android"))]
const DEFAULT_POOL_SIZE: usize = 2;
/// Factor applied to the encoder file size to estimate the memory of one
/// pooled triplet in `cap_pool_size_for_ram`.
const ENCODER_RESIDENT_MULTIPLIER: u64 = 2;
/// Divisor of total RAM giving the memory budget for the pool (2 means half).
const POOL_RAM_FRACTION_DENOM: u64 = 2;
/// The three ONNX sessions (encoder, decoder, joiner) that are loaded, pooled
/// and checked out together as one unit.
pub struct SessionTriplet {
// Session committed from the encoder model file.
pub(crate) encoder: Session,
// Session committed from the decoder model file.
pub(crate) decoder: Session,
// Session committed from the joint-network model file.
pub(crate) joiner: Session,
}
/// Errors returned when checking an item out of a [`Pool`].
#[derive(Debug)]
pub enum PoolError {
    /// The pool was closed before an item could be handed out.
    Closed,
}

impl std::fmt::Display for PoolError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Closed => "session pool is closed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PoolError {}
/// Handle to a fixed-size pool of `T` values that are checked out and later
/// returned; all state lives in a reference-counted `PoolInner`.
pub struct Pool<T> {
// Shared state, also referenced by every outstanding guard.
inner: std::sync::Arc<PoolInner<T>>,
}
/// Shared state behind a `Pool`.
struct PoolInner<T> {
// Idle items: checkout pops from the front, checkin pushes to the back.
items: parking_lot::Mutex<std::collections::VecDeque<T>>,
// Callers waiting for an item, served in FIFO order by checkin.
waiters: parking_lot::Mutex<std::collections::VecDeque<Waiter<T>>>,
// Set by `Pool::close`; checkouts fail and checkins discard the item once true.
closed: std::sync::atomic::AtomicBool,
// Number of items the pool was created with.
total: usize,
}
/// Channel on which a parked checkout caller receives its item.
enum Waiter<T> {
// Registered by the async `checkout`.
Async(tokio::sync::oneshot::Sender<T>),
// Registered by `checkout_blocking`.
Blocking(std::sync::mpsc::Sender<T>),
}
/// Pool specialised to [`SessionTriplet`] items.
pub type SessionPool = Pool<SessionTriplet>;
impl<T: Send> Pool<T> {
    /// Creates a pool whose idle queue initially holds every element of `items`.
    pub fn new(items: Vec<T>) -> Self {
        use std::collections::VecDeque;
        use std::sync::atomic::AtomicBool;

        let total = items.len();
        let shared = PoolInner {
            items: parking_lot::Mutex::new(VecDeque::from(items)),
            waiters: parking_lot::Mutex::new(VecDeque::new()),
            closed: AtomicBool::new(false),
            total,
        };
        Self {
            inner: std::sync::Arc::new(shared),
        }
    }

    /// Asynchronously checks an item out of the pool.
    ///
    /// Returns a guard that hands the item back when dropped. When no item is
    /// idle the caller is queued and resumed by a later check-in.
    ///
    /// # Errors
    /// [`PoolError::Closed`] if the pool is closed, or is closed while waiting.
    pub async fn checkout(&self) -> Result<PoolGuard<T>, PoolError> {
        use std::sync::atomic::Ordering;

        let pool = &self.inner;

        // Fast path: take an idle item if there is one.
        {
            let mut idle = pool.items.lock();
            if pool.closed.load(Ordering::SeqCst) {
                return Err(PoolError::Closed);
            }
            if let Some(item) = idle.pop_front() {
                return Ok(PoolGuard::new(std::sync::Arc::clone(pool), item));
            }
        }

        // Slow path: re-check under the waiters lock, then enqueue. Both lock
        // guards are scoped so that none is alive across the `.await` below.
        let (tx, rx) = tokio::sync::oneshot::channel();
        {
            let mut queue = pool.waiters.lock();
            if pool.closed.load(Ordering::SeqCst) {
                return Err(PoolError::Closed);
            }
            let mut idle = pool.items.lock();
            match idle.pop_front() {
                Some(item) => {
                    drop(idle);
                    drop(queue);
                    return Ok(PoolGuard::new(std::sync::Arc::clone(pool), item));
                }
                None => queue.push_back(Waiter::Async(tx)),
            }
        }

        // A dropped sender (e.g. after `close`) surfaces as `Closed`.
        rx.await
            .map(|item| PoolGuard::new(std::sync::Arc::clone(pool), item))
            .map_err(|_| PoolError::Closed)
    }

    /// Blocking counterpart of [`Pool::checkout`]: parks the calling thread
    /// until an item is available.
    ///
    /// # Errors
    /// [`PoolError::Closed`] if the pool is closed, or is closed while waiting.
    pub fn checkout_blocking(&self) -> Result<PoolGuard<T>, PoolError> {
        use std::sync::atomic::Ordering;

        let pool = &self.inner;

        // Fast path: take an idle item if there is one.
        {
            let mut idle = pool.items.lock();
            if pool.closed.load(Ordering::SeqCst) {
                return Err(PoolError::Closed);
            }
            if let Some(item) = idle.pop_front() {
                return Ok(PoolGuard::new(std::sync::Arc::clone(pool), item));
            }
        }

        // Slow path: re-check under the waiters lock, then enqueue.
        let (tx, rx) = std::sync::mpsc::channel();
        {
            let mut queue = pool.waiters.lock();
            if pool.closed.load(Ordering::SeqCst) {
                return Err(PoolError::Closed);
            }
            let mut idle = pool.items.lock();
            match idle.pop_front() {
                Some(item) => {
                    drop(idle);
                    drop(queue);
                    return Ok(PoolGuard::new(std::sync::Arc::clone(pool), item));
                }
                None => queue.push_back(Waiter::Blocking(tx)),
            }
        }

        // A dropped sender (e.g. after `close`) surfaces as `Closed`.
        rx.recv()
            .map(|item| PoolGuard::new(std::sync::Arc::clone(pool), item))
            .map_err(|_| PoolError::Closed)
    }

    /// Marks the pool closed and drops every queued waiter, so pending
    /// checkouts fail with [`PoolError::Closed`].
    pub fn close(&self) {
        use std::sync::atomic::Ordering;

        let pool = &*self.inner;
        pool.closed.store(true, Ordering::SeqCst);
        pool.waiters.lock().clear();
    }

    /// Number of items the pool was created with.
    pub fn total(&self) -> usize {
        self.inner.total
    }

    /// Number of items currently idle in the pool.
    pub fn available(&self) -> usize {
        self.inner.items.lock().len()
    }

    /// Number of callers currently queued for an item.
    pub fn waiters(&self) -> usize {
        self.inner.waiters.lock().len()
    }
}
impl<T> PoolInner<T> {
    /// Returns `item` to the pool.
    ///
    /// If the pool is closed the item is dropped. Otherwise it goes to the
    /// oldest waiter that is still listening, or onto the idle queue when
    /// nobody is waiting.
    ///
    /// The waiters lock stays held while the item is pushed onto the idle
    /// queue. The checkout slow path takes the same locks in the same order
    /// (waiters, then items) and re-checks the idle queue before enqueueing.
    /// Releasing the waiters lock early would let a checkout enqueue itself
    /// between "no waiter found" and "item pushed", stranding that waiter
    /// next to an idle item (a lost wakeup).
    fn checkin(&self, mut item: T) {
        if self.closed.load(std::sync::atomic::Ordering::SeqCst) {
            return;
        }
        loop {
            let mut waiters = self.waiters.lock();
            let Some(waiter) = waiters.pop_front() else {
                // Nobody is waiting: publish the item while still holding the
                // waiters lock so no checkout can slip in between.
                self.items.lock().push_back(item);
                return;
            };
            // Do not hold the lock while talking to the waiter's channel.
            drop(waiters);

            // A failed send means the waiter gave up (receiver dropped); take
            // the item back and try the next waiter.
            let rejected = match waiter {
                Waiter::Async(tx) => tx.send(item).err(),
                Waiter::Blocking(tx) => tx
                    .send(item)
                    .err()
                    .map(|std::sync::mpsc::SendError(returned)| returned),
            };
            match rejected {
                Some(returned) => item = returned,
                None => return,
            }
        }
    }
}
/// RAII handle for an item checked out of a [`Pool`]; dropping it returns the
/// item to the pool.
pub struct PoolGuard<T> {
    // `None` once ownership has moved into an `OwnedReservation`.
    inner: Option<std::sync::Arc<PoolInner<T>>>,
    // `None` once the item has been handed back or moved out.
    item: Option<T>,
}

impl<T> PoolGuard<T> {
    /// Wraps a freshly checked-out `item` together with its owning pool.
    fn new(inner: std::sync::Arc<PoolInner<T>>, item: T) -> Self {
        let (inner, item) = (Some(inner), Some(item));
        Self { inner, item }
    }

    /// Converts the guard into an [`OwnedReservation`] that carries the same
    /// item and pool reference.
    pub fn into_owned(mut self) -> OwnedReservation<T> {
        let item = match self.item.take() {
            Some(value) => value,
            None => unreachable!("PoolGuard::into_owned called after drop"),
        };
        let inner = self.inner.take().unwrap();
        // Both fields are now `None`, so this guard's `Drop` is a no-op.
        OwnedReservation {
            inner,
            item: Some(item),
        }
    }
}

impl<T> Deref for PoolGuard<T> {
    type Target = T;

    /// Borrows the checked-out item.
    fn deref(&self) -> &Self::Target {
        match self.item.as_ref() {
            Some(value) => value,
            None => unreachable!("PoolGuard accessed after item taken"),
        }
    }
}

impl<T> DerefMut for PoolGuard<T> {
    /// Mutably borrows the checked-out item.
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self.item.as_mut() {
            Some(value) => value,
            None => unreachable!("PoolGuard accessed after item taken"),
        }
    }
}

impl<T> Drop for PoolGuard<T> {
    /// Returns the item to the pool if this guard still owns both the item
    /// and the pool reference.
    fn drop(&mut self) {
        if let Some((pool, item)) = self.inner.take().zip(self.item.take()) {
            pool.checkin(item);
        }
    }
}
/// A checked-out pool item that keeps a non-optional reference to its pool.
/// The item is returned on explicit [`OwnedReservation::checkin`] or on drop.
pub struct OwnedReservation<T> {
    inner: std::sync::Arc<PoolInner<T>>,
    // `None` once the item has been checked back in.
    item: Option<T>,
}

impl<T> std::ops::Deref for OwnedReservation<T> {
    type Target = T;

    /// Borrows the reserved item.
    fn deref(&self) -> &Self::Target {
        match self.item.as_ref() {
            Some(value) => value,
            None => unreachable!("OwnedReservation accessed after checkin"),
        }
    }
}

impl<T> std::ops::DerefMut for OwnedReservation<T> {
    /// Mutably borrows the reserved item.
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self.item.as_mut() {
            Some(value) => value,
            None => unreachable!("OwnedReservation accessed after checkin"),
        }
    }
}

impl<T> OwnedReservation<T> {
    /// Returns the item to the pool now, consuming the reservation.
    pub fn checkin(mut self) {
        let Some(item) = self.item.take() else {
            return;
        };
        self.inner.checkin(item);
    }
}

impl<T> Drop for OwnedReservation<T> {
    /// Returns the item to the pool unless it was already checked in.
    fn drop(&mut self) {
        let Some(item) = self.item.take() else {
            return;
        };
        self.inner.checkin(item);
    }
}
#[non_exhaustive]
pub struct DecoderState {
pub h: Vec<f32>,
pub c: Vec<f32>,
pub prev_token: i64,
pub consecutive_blanks: usize,
}
impl DecoderState {
pub fn new(blank_id: usize) -> Self {
Self {
h: vec![0.0; PRED_HIDDEN],
c: vec![0.0; PRED_HIDDEN],
prev_token: blank_id as i64,
consecutive_blanks: 0,
}
}
}
/// One recognised word with its timing, confidence and optional speaker id.
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
pub struct WordInfo {
    /// The word text.
    pub word: String,
    /// Start time of the word (seconds, as used by the streaming decoder).
    pub start: f64,
    /// End time of the word.
    pub end: f64,
    /// Confidence score attached to the word.
    pub confidence: f32,
    /// Speaker id, omitted from serialised output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speaker: Option<u32>,
}

impl WordInfo {
    /// Builds a `WordInfo`, converting `word` into an owned `String`.
    pub fn new(
        word: impl Into<String>,
        start: f64,
        end: f64,
        confidence: f32,
        speaker: Option<u32>,
    ) -> Self {
        let word: String = word.into();
        Self {
            word,
            start,
            end,
            confidence,
            speaker,
        }
    }
}
/// Per-stream mutable state created by `Engine::create_state` and threaded
/// through `process_chunk` / `finish_stream`.
#[non_exhaustive]
pub struct StreamingState {
// Decoder state initialised with the tokenizer's blank id.
// NOTE(review): `decode_window` builds a fresh `DecoderState` per window and does
// not read this field in the visible code — confirm where it is used.
pub decoder: DecoderState,
// Audio of the current window; new chunks are appended and old audio is
// trimmed when a segment is finalized.
pub audio_buffer: Vec<f32>,
// Accumulates the words/text of the segment in progress.
pub assembler: TranscriptAssembler,
// Total number of samples already slid off the front of `audio_buffer`.
pub window_start_samples: usize,
// Number of samples at the front of `audio_buffer` kept as left context after
// the last finalized segment; words starting inside it are filtered out.
pub context_samples: usize,
// Samples received since the last decode; reset to 0 after each decode.
pub pending_samples: usize,
// NOTE(review): not used in the visible code; presumably resamples incoming
// audio to the model rate — confirm.
pub resampler: Option<rubato::Async<f32>>,
// Scratch buffers handed to `FeatureExtractor::compute_mel`; `mel_output`
// holds the features that are then passed to inference.
pub mel_fft_input: Vec<rustfft::num_complex::Complex<f32>>,
pub mel_power: Vec<f32>,
pub mel_output: Vec<f32>,
// NOTE(review): not used in the visible code; presumably scratch space for the resampler.
pub resample_output_buf: Vec<f32>,
// Present only when the engine was configured with a VAD.
pub vad_endpointer: Option<crate::vad::VadEndpointer>,
// Streaming diarization pipeline; `None` when disabled or unavailable.
#[cfg(feature = "diarization")]
pub diarization_state: Option<StreamingPipeline<EnergyVad, SharedExtractor>>,
}
/// Thin wrapper around `MelSpectrogram` plus the audio-buffer helper from the
/// `audio` module.
pub struct FeatureExtractor {
mel: MelSpectrogram,
}
impl Default for FeatureExtractor {
/// Same as [`FeatureExtractor::new`].
fn default() -> Self {
Self::new()
}
}
impl FeatureExtractor {
/// Creates an extractor backed by a new `MelSpectrogram`.
pub fn new() -> Self {
Self {
mel: MelSpectrogram::new(),
}
}
/// Forwards to `audio::prepare_audio_buffer` and returns its result unchanged.
/// NOTE(review): the meaning of the returned `Option<usize>` is defined in the
/// `audio` module — confirm there.
pub fn prepare_buffer(&self, samples: &[f32], audio_buffer: &mut Vec<f32>) -> Option<usize> {
audio::prepare_audio_buffer(samples, audio_buffer)
}
/// Computes features for `samples` via `MelSpectrogram::compute_with_buffers`
/// using the caller-provided buffers. The return value is treated by
/// `decode_window` as the number of frames written to `output_buf`.
pub fn compute_mel(
&self,
samples: &[f32],
fft_buf: &mut Vec<rustfft::num_complex::Complex<f32>>,
power_buf: &mut Vec<f32>,
output_buf: &mut Vec<f32>,
) -> usize {
self.mel
.compute_with_buffers(samples, fft_buf, power_buf, output_buf)
}
/// Computes features for `samples` via `MelSpectrogram::compute`.
/// NOTE(review): presumably returns (features, frame count) — confirm in `features`.
pub fn compute(&self, samples: &[f32]) -> (Vec<f32>, usize) {
self.mel.compute(samples)
}
}
/// Accumulates the words of the segment in progress together with their
/// space-joined text.
pub struct TranscriptAssembler {
    text: String,
    words: Vec<WordInfo>,
}

impl Default for TranscriptAssembler {
    /// Same as [`TranscriptAssembler::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl TranscriptAssembler {
    /// Creates an empty assembler.
    pub fn new() -> Self {
        let (text, words) = (String::new(), Vec::new());
        Self { text, words }
    }

    /// Appends `new_words`, extending the text with each word and inserting a
    /// single space before a word whenever the text is already non-empty.
    pub fn append(&mut self, new_words: Vec<WordInfo>) {
        let text = &mut self.text;
        new_words.iter().map(|w| w.word.as_str()).for_each(|token| {
            if !text.is_empty() {
                text.push(' ');
            }
            text.push_str(token);
        });
        self.words.extend(new_words);
    }

    /// Replaces the accumulated words with `words` and rebuilds the text as
    /// the words joined by single spaces.
    pub fn set_words(&mut self, words: Vec<WordInfo>) {
        let mut joined = String::new();
        for (idx, w) in words.iter().enumerate() {
            if idx > 0 {
                joined.push(' ');
            }
            joined.push_str(&w.word);
        }
        self.text = joined;
        self.words = words;
    }

    /// Moves the accumulated text and words into a final segment, leaving the
    /// assembler empty.
    pub fn finalize(&mut self, timestamp: f64) -> TranscriptSegment {
        let text = std::mem::take(&mut self.text);
        let words = std::mem::take(&mut self.words);
        TranscriptSegment {
            text,
            words,
            is_final: true,
            timestamp,
        }
    }

    /// Produces a non-final snapshot of the current text and words without
    /// clearing them.
    pub fn partial(&self, timestamp: f64) -> TranscriptSegment {
        let text = self.text.clone();
        let words = self.words.clone();
        TranscriptSegment {
            text,
            words,
            is_final: false,
            timestamp,
        }
    }

    /// True when no text has been accumulated.
    pub fn is_empty(&self) -> bool {
        self.text.is_empty()
    }
}
/// Runs `probe` on `state`; if it fails, hands the state and the probe error
/// to `rebuild` and probes the rebuilt state once more.
///
/// # Errors
/// Returns the error from `rebuild`, or the second probe's error wrapped with
/// context when the rebuilt state still fails.
#[cfg_attr(not(feature = "coreml"), allow(dead_code))]
fn probe_or_rebuild<S>(
    state: S,
    probe: impl Fn(&S) -> anyhow::Result<()>,
    rebuild: impl FnOnce(S, anyhow::Error) -> anyhow::Result<S>,
) -> anyhow::Result<S> {
    // Healthy state: nothing to do.
    let Err(probe_err) = probe(&state) else {
        return Ok(state);
    };
    let rebuilt = rebuild(state, probe_err)?;
    probe(&rebuilt).context("state failed probe even after rebuild")?;
    Ok(rebuilt)
}
/// Speech-to-text engine: pooled ONNX session triplets plus the tokenizer,
/// feature extractor and optional post-processing components.
pub struct Engine {
// Primary session pool.
pub pool: SessionPool,
// Optional separate pool returned by `pool_for_batch`; `None` means batch
// work shares `pool`.
pub batch_pool: Option<SessionPool>,
tokenizer: Tokenizer,
features: FeatureExtractor,
// Model variant detected in the model directory at load time.
variant: ModelVariant,
// Optional punctuator, installed with `with_punctuator`.
punctuator: Option<crate::punctuation::Punctuator>,
// Flag set by `with_itn`.
itn: bool,
// Hotword biaser, installed with `with_hotwords`.
biaser: Option<bias::Biaser>,
// Optional VAD model and its configuration, installed with `with_vad`.
vad: Option<crate::vad::SileroVad>,
vad_config: crate::vad::VadConfig,
// True when the INT8 encoder file was present in the model directory.
int8: bool,
// Shared speaker encoder; `None` when the model file is missing or failed to load.
#[cfg(feature = "diarization")]
#[allow(deprecated)] pub speaker_encoder: Option<std::sync::Arc<OnnxEmbeddingExtractor>>,
}
impl Engine {
/// Returns true when the INT8 encoder file was found at load time.
pub fn is_int8(&self) -> bool {
self.int8
}
/// Returns the model variant detected when the engine was loaded.
pub fn variant(&self) -> ModelVariant {
self.variant
}
/// Builder-style setter: installs (or, with `None`, removes) the punctuator
/// and returns the engine.
pub fn with_punctuator(mut self, punctuator: Option<crate::punctuation::Punctuator>) -> Self {
self.punctuator = punctuator;
self
}
/// Returns true when a punctuator is installed.
pub fn has_punctuator(&self) -> bool {
self.punctuator.is_some()
}
/// Builder-style setter for the ITN flag; returns the engine.
/// NOTE(review): "ITN" presumably stands for inverse text normalization — confirm.
pub fn with_itn(mut self, enabled: bool) -> Self {
self.itn = enabled;
self
}
/// Returns the ITN flag set by `with_itn`.
pub fn has_itn(&self) -> bool {
self.itn
}
/// Builder-style setter for hotword biasing.
///
/// An empty `phrases` slice disables biasing. Otherwise the biaser is built
/// from the phrases and the engine's tokenizer with the given `boost`; that
/// construction may itself yield no biaser. Logs when biasing ends up enabled.
pub fn with_hotwords(mut self, phrases: &[(String, f32)], boost: f32) -> Self {
    self.biaser = match phrases {
        [] => None,
        _ => bias::Biaser::from_phrases(&self.tokenizer, phrases, boost),
    };
    if let Some(active) = self.biaser.as_ref() {
        tracing::info!(
            "Hotword biasing enabled ({} phrase(s), boost {boost})",
            active.phrase_count()
        );
    }
    self
}
/// Returns true when a hotword biaser is installed.
pub fn has_hotwords(&self) -> bool {
self.biaser.is_some()
}
/// Builder-style setter: installs the optional VAD model and its
/// configuration, logging the threshold and minimum silence when a VAD is
/// present. The configuration is stored even when `vad` is `None`.
pub fn with_vad(
    mut self,
    vad: Option<crate::vad::SileroVad>,
    config: crate::vad::VadConfig,
) -> Self {
    let enabled = vad.is_some();
    self.vad = vad;
    self.vad_config = config;
    if enabled {
        let cfg = &self.vad_config;
        tracing::info!(
            "VAD enabled (threshold {}, min_silence {}ms)",
            cfg.threshold,
            cfg.min_silence_ms
        );
    }
    self
}
/// Returns true when a VAD model is installed.
pub fn has_vad(&self) -> bool {
self.vad.is_some()
}
/// Returns the tokenizer's vocabulary size.
pub fn vocab_size(&self) -> usize {
self.tokenizer.vocab_size()
}
/// Loads an engine from `model_dir` with `DEFAULT_POOL_SIZE` pooled triplets.
///
/// # Errors
/// See [`Engine::load_with_pools_threads`].
pub fn load(model_dir: &str) -> Result<Self, GigasttError> {
Self::load_with_pool_size(model_dir, DEFAULT_POOL_SIZE)
}
/// Loads an engine with `pool_size` triplets, passing `pool_size` as the
/// minimum too.
///
/// # Errors
/// See [`Engine::load_with_pools_threads`].
pub fn load_with_pool_size(model_dir: &str, pool_size: usize) -> Result<Self, GigasttError> {
Self::load_with_pool_size_min(model_dir, pool_size, pool_size)
}
/// Loads an engine that tolerates a degraded pool: loading succeeds as long
/// as at least `min_size` of the `pool_size` triplets load. No batch pool is
/// requested.
///
/// # Errors
/// See [`Engine::load_with_pools_threads`].
pub fn load_with_pool_size_min(
model_dir: &str,
pool_size: usize,
min_size: usize,
) -> Result<Self, GigasttError> {
Self::load_with_pools(model_dir, pool_size, min_size, 0)
}
/// Like [`Engine::load_with_pool_size_min`], additionally reserving up to
/// `batch_pool_size` triplets for a separate batch pool. The encoder is
/// requested with 1 intra-op thread.
///
/// # Errors
/// See [`Engine::load_with_pools_threads`].
pub fn load_with_pools(
model_dir: &str,
pool_size: usize,
min_size: usize,
batch_pool_size: usize,
) -> Result<Self, GigasttError> {
Self::load_with_pools_threads(model_dir, pool_size, min_size, batch_pool_size, 1)
}
/// Most general loader: detects the model variant in `model_dir`, caps the
/// pool size by available RAM, clamps the encoder intra-op thread count to
/// the available CPUs, then loads the engine.
///
/// # Errors
/// `GigasttError::ModelLoad` with `source: None` when no model variant is
/// detected, or wrapping the underlying error when loading fails.
pub fn load_with_pools_threads(
    model_dir: &str,
    pool_size: usize,
    min_size: usize,
    batch_pool_size: usize,
    encoder_intra_threads: usize,
) -> Result<Self, GigasttError> {
    let dir = Path::new(model_dir);
    let variant = match ModelVariant::detect_in_dir(dir) {
        Some(found) => found,
        None => {
            return Err(GigasttError::ModelLoad {
                path: model_dir.to_string(),
                source: None,
            });
        }
    };

    // An unreadable encoder file counts as size 0, which disables the RAM cap.
    let encoder_bytes = match std::fs::metadata(Self::encoder_model_path(dir, variant)) {
        Ok(meta) => meta.len(),
        Err(_) => 0,
    };
    let pool_size = Self::cap_pool_size_for_ram(pool_size, encoder_bytes, total_ram_bytes());

    let logical_cpus = std::thread::available_parallelism().map_or(1, |n| n.get());
    let encoder_intra_threads =
        Self::clamp_encoder_intra_threads(pool_size, encoder_intra_threads, logical_cpus);

    let loaded = Self::load_inner(
        dir,
        variant,
        model_dir,
        pool_size,
        min_size,
        batch_pool_size,
        encoder_intra_threads,
    );
    loaded.map_err(|e| GigasttError::ModelLoad {
        path: model_dir.to_string(),
        source: Some(e.into()),
    })
}
/// Returns the encoder model path inside `dir`: the INT8 file when it exists,
/// otherwise the regular encoder file.
fn encoder_model_path(dir: &Path, variant: ModelVariant) -> std::path::PathBuf {
    let quantized = dir.join(variant.encoder_int8_file());
    if !quantized.exists() {
        return dir.join(variant.encoder_file());
    }
    quantized
}
/// Loads one (encoder, decoder, joiner) session set using the execution
/// provider selected at compile time: CoreML, CUDA, or plain CPU.
///
/// `encoder_intra_threads` is only forwarded to the CPU loader; the CoreML
/// and CUDA branches ignore it.
fn load_sessions(
dir: &Path,
variant: ModelVariant,
prepacked: &ort::session::builder::PrepackedWeights,
encoder_intra_threads: usize,
) -> anyhow::Result<(Session, Session, Session)> {
// Exactly one of the blocks below is compiled in (`coreml` and `cuda` are
// mutually exclusive) and forms the tail expression.
#[cfg(feature = "coreml")]
{
let _ = encoder_intra_threads;
Self::load_sessions_coreml(dir, variant, prepacked)
}
#[cfg(feature = "cuda")]
{
let _ = encoder_intra_threads;
Self::load_sessions_cuda(dir, variant, prepacked)
}
#[cfg(not(any(feature = "coreml", feature = "cuda")))]
{
Self::load_sessions_cpu(dir, variant, prepacked, encoder_intra_threads)
}
}
/// Loads the three sessions with the CoreML execution provider first and the
/// CPU provider second in the provider list.
///
/// # Errors
/// Any `ort` builder or model-commit error, converted through `ort_err`.
#[cfg(feature = "coreml")]
fn load_sessions_coreml(
dir: &Path,
variant: ModelVariant,
prepacked: &ort::session::builder::PrepackedWeights,
) -> anyhow::Result<(Session, Session, Session)> {
// CoreML's model cache lives next to the model files.
let cache_dir = dir.join("coreml_cache");
let coreml_ep = ep::CoreML::default()
.with_model_format(ep::coreml::ModelFormat::MLProgram)
.with_static_input_shapes(true)
.with_compute_units(ep::coreml::ComputeUnits::CPUAndNeuralEngine)
.with_specialization_strategy(ep::coreml::SpecializationStrategy::FastPrediction)
.with_model_cache_dir(cache_dir.to_string_lossy())
.build();
let cpu_fallback = ort::execution_providers::CPUExecutionProvider::default();
// Provider order: CoreML, then CPU.
let eps = [coreml_ep.clone(), cpu_fallback.into()];
// Prefers the INT8 encoder when present.
let encoder_path = Self::encoder_model_path(dir, variant);
let encoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(&encoder_path)
.map_err(ort_err)?;
let decoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(dir.join(variant.decoder_file()))
.map_err(ort_err)?;
let joiner = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(dir.join(variant.joint_file()))
.map_err(ort_err)?;
Ok((encoder, decoder, joiner))
}
/// Loads the three sessions with the CUDA execution provider first and the
/// CPU provider second in the provider list.
///
/// # Errors
/// Any `ort` builder or model-commit error, converted through `ort_err`.
#[cfg(feature = "cuda")]
fn load_sessions_cuda(
dir: &Path,
variant: ModelVariant,
prepacked: &ort::session::builder::PrepackedWeights,
) -> anyhow::Result<(Session, Session, Session)> {
let cuda_ep = ep::CUDA::default().build();
let cpu_fallback = ort::execution_providers::CPUExecutionProvider::default();
// Provider order: CUDA, then CPU.
let eps = [cuda_ep.clone(), cpu_fallback.into()];
// Prefers the INT8 encoder when present.
let encoder_path = Self::encoder_model_path(dir, variant);
let encoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(&encoder_path)
.map_err(ort_err)?;
let decoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(dir.join(variant.decoder_file()))
.map_err(ort_err)?;
let joiner = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(dir.join(variant.joint_file()))
.map_err(ort_err)?;
Ok((encoder, decoder, joiner))
}
/// Loads the three sessions for CPU execution. Compiled whenever the `cuda`
/// feature is off, so it also serves as the fallback for CoreML builds.
///
/// The encoder gets `encoder_intra_threads` (at least 1) intra-op threads and
/// an on-disk optimized-model path; decoder and joiner run with one intra-op
/// and one inter-op thread each.
///
/// # Errors
/// Failure to create the cache directory, or any `ort` builder or
/// model-commit error.
#[cfg(not(feature = "cuda"))]
fn load_sessions_cpu(
dir: &Path,
variant: ModelVariant,
prepacked: &ort::session::builder::PrepackedWeights,
encoder_intra_threads: usize,
) -> anyhow::Result<(Session, Session, Session)> {
let cache_dir = dir.join("optimized_cache");
std::fs::create_dir_all(&cache_dir)
.with_context(|| format!("Failed to create ONNX cache dir: {}", cache_dir.display()))?;
let cpu_fallback = ort::execution_providers::CPUExecutionProvider::default();
let eps = [cpu_fallback.into()];
// Prefers the INT8 encoder when present.
let encoder_path = Self::encoder_model_path(dir, variant);
// NOTE(review): every pooled triplet writes the same `encoder_optimized.onnx`,
// and `load_triplets` runs these loads on concurrent threads — confirm that
// ONNX Runtime tolerates concurrent writers to one optimized-model path.
// NOTE(review): the same cache file name is used for the INT8 and the regular
// encoder — confirm a stale file cannot be picked up.
let encoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_intra_threads(encoder_intra_threads.max(1))
.map_err(ort_err)?
.with_inter_threads(1)
.map_err(ort_err)?
.with_optimized_model_path(cache_dir.join("encoder_optimized.onnx"))
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(&encoder_path)
.map_err(ort_err)?;
// Decoder and joiner register no explicit execution provider.
let decoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_intra_threads(1)
.map_err(ort_err)?
.with_inter_threads(1)
.map_err(ort_err)?
.commit_from_file(dir.join(variant.decoder_file()))
.map_err(ort_err)?;
let joiner = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_intra_threads(1)
.map_err(ort_err)?
.with_inter_threads(1)
.map_err(ort_err)?
.commit_from_file(dir.join(variant.joint_file()))
.map_err(ort_err)?;
Ok((encoder, decoder, joiner))
}
/// Splits loaded triplets into the primary pool and an optional batch pool;
/// a `SessionTriplet`-typed wrapper over `split_pool`.
fn split_triplets(
triplets: Vec<SessionTriplet>,
batch_pool_size: usize,
) -> (SessionPool, Option<SessionPool>) {
Self::split_pool(triplets, batch_pool_size)
}
/// Splits `items` into a primary pool and an optional batch pool.
///
/// The batch pool takes the last `batch_split_count(len, batch_pool_size)`
/// items; when that count is zero all items stay in the primary pool and no
/// batch pool is created.
fn split_pool<T: Send>(
    mut items: Vec<T>,
    batch_pool_size: usize,
) -> (Pool<T>, Option<Pool<T>>) {
    let n = items.len();
    match Self::batch_split_count(n, batch_pool_size) {
        0 => (Pool::new(items), None),
        batch => {
            let batch_items = items.split_off(n - batch);
            (Pool::new(items), Some(Pool::new(batch_items)))
        }
    }
}
/// Number of items, out of `n`, to reserve for the batch pool: the requested
/// `batch_pool_size`, limited so at least one item stays in the primary pool.
fn batch_split_count(n: usize, batch_pool_size: usize) -> usize {
    let most_allowed = n.saturating_sub(1);
    std::cmp::min(batch_pool_size, most_allowed)
}
/// Limits the requested pool size so the estimated encoder memory fits in the
/// RAM budget.
///
/// Each slot is estimated at `encoder_bytes * ENCODER_RESIDENT_MULTIPLIER`
/// and the budget is `total_ram / POOL_RAM_FRACTION_DENOM`. When the request
/// is at most 1, or either size is unknown (0), the request is returned
/// unchanged except that 0 becomes 1. The result is never below 1; a warning
/// is logged when the request is reduced.
fn cap_pool_size_for_ram(requested: usize, encoder_bytes: u64, total_ram: u64) -> usize {
    let sizes_unknown = encoder_bytes == 0 || total_ram == 0;
    if requested <= 1 || sizes_unknown {
        return requested.max(1);
    }

    let per_triplet = encoder_bytes.saturating_mul(ENCODER_RESIDENT_MULTIPLIER);
    let budget = total_ram / POOL_RAM_FRACTION_DENOM;
    let max_slots = (budget / per_triplet.max(1)).max(1) as usize;
    if max_slots >= requested {
        return requested;
    }

    tracing::warn!(
        "Capping pool size {requested} -> {max_slots}: \
         {requested} encoder slots (~{} MiB each) would exceed half of \
         {} MiB total RAM. Concurrency is reduced; add RAM or lower \
         --pool-size to silence this.",
        per_triplet / (1024 * 1024),
        total_ram / (1024 * 1024),
    );
    max_slots
}
/// Limits the encoder intra-op thread count so that `pool_size` encoders do
/// not together exceed the logical CPU count.
///
/// All three inputs are treated as at least 1. The per-encoder ceiling is
/// `logical_cpus / pool_size` (at least 1); a warning is logged when the
/// request is reduced.
fn clamp_encoder_intra_threads(
    pool_size: usize,
    requested: usize,
    logical_cpus: usize,
) -> usize {
    let requested = requested.max(1);
    let pool_size = pool_size.max(1);
    let logical_cpus = logical_cpus.max(1);

    let max_per_encoder = std::cmp::max(logical_cpus / pool_size, 1);
    if requested <= max_per_encoder {
        return requested;
    }

    tracing::warn!(
        "Capping encoder intra-op threads {requested} -> {max_per_encoder}: \
         {pool_size} pooled encoder(s) x {requested} threads would exceed \
         the {logical_cpus} logical CPU(s) available. Lower --pool-size or \
         --encoder-intra-threads to silence this."
    );
    max_per_encoder
}
/// Collapses per-slot load results into the list of successfully loaded items.
///
/// `min_size` is clamped to `1..=max(pool_size, 1)`. With at least `min_size`
/// successes the loaded items are returned, and a "degraded pool" warning is
/// logged if fewer than `pool_size` loaded.
///
/// # Errors
/// When fewer than `min_size` items loaded; the message includes the first
/// failure encountered, if any.
fn finalize_pool_load<T>(
    results: Vec<anyhow::Result<T>>,
    pool_size: usize,
    min_size: usize,
) -> anyhow::Result<Vec<T>> {
    let min_size = min_size.clamp(1, pool_size.max(1));

    let mut loaded = Vec::with_capacity(results.len());
    let mut first_err: Option<anyhow::Error> = None;
    for outcome in results {
        match outcome {
            Ok(value) => loaded.push(value),
            Err(e) => {
                // Only the first failure is kept for reporting.
                first_err.get_or_insert(e);
            }
        }
    }

    let n = loaded.len();
    if n < min_size {
        let detail = first_err.map(|e| format!(": {e:#}")).unwrap_or_default();
        return Err(anyhow::anyhow!(
            "loaded only {n}/{pool_size} session triplets, need at least {min_size}{detail}"
        ));
    }

    if n < pool_size {
        let detail = first_err
            .map(|e| format!("; first error: {e:#}"))
            .unwrap_or_default();
        tracing::warn!(
            "degraded pool: loaded {n}/{pool_size} session triplets ({} failed){detail}",
            pool_size - n
        );
    }
    Ok(loaded)
}
/// Loads `pool_size` session triplets concurrently, one scoped thread per
/// triplet, all sharing the same `prepacked` weights container.
///
/// `load_one` performs the actual load of one (encoder, decoder, joiner) set.
/// A panic inside a loader is reported as a failed result instead of
/// unwinding. The per-slot results are then reduced by `finalize_pool_load`,
/// which enforces `min_size`.
fn load_triplets(
dir: &Path,
variant: ModelVariant,
pool_size: usize,
min_size: usize,
prepacked: &ort::session::builder::PrepackedWeights,
load_one: impl Fn(
&Path,
ModelVariant,
&ort::session::builder::PrepackedWeights,
) -> anyhow::Result<(Session, Session, Session)>
+ Sync,
) -> anyhow::Result<Vec<SessionTriplet>> {
// Scoped threads allow borrowing `dir`, `prepacked` and `load_one` directly.
let results: Vec<anyhow::Result<SessionTriplet>> = std::thread::scope(|s| {
let handles: Vec<_> = (0..pool_size)
.map(|i| {
let pp = prepacked;
let load_one = &load_one;
s.spawn(move || {
// Catch panics inside the loader so they become an `Err` result.
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
tracing::info!(
"Loading session triplet {}/{pool_size} (shared weights)",
i + 1
);
let (encoder, decoder, joiner) = load_one(dir, variant, pp)?;
Ok(SessionTriplet {
encoder,
decoder,
joiner,
})
}))
.map_err(|_| anyhow::anyhow!("model loading thread panicked"))?
})
})
.collect();
// Join in spawn order; a join failure is also mapped to an error result.
handles
.into_iter()
.map(|h| match h.join() {
Ok(r) => r,
Err(_) => Err(anyhow::anyhow!("model loading thread panicked")),
})
.collect()
});
Self::finalize_pool_load(results, pool_size, min_size)
}
/// Performs the actual engine load: session triplets, tokenizer, feature
/// extractor and (with the `diarization` feature) the optional speaker encoder.
///
/// With the `coreml` feature there are two CPU fallbacks: one if the sessions
/// fail to load, and one if a warm-up inference fails on the freshly built
/// engine.
///
/// # Errors
/// Propagates session-loading and tokenizer-loading errors.
fn load_inner(
dir: &Path,
variant: ModelVariant,
model_dir: &str,
pool_size: usize,
min_size: usize,
batch_pool_size: usize,
encoder_intra_threads: usize,
) -> anyhow::Result<Self> {
tracing::info!("Detected model variant: {variant:?}");
// Same existence test that `encoder_model_path` uses to choose the file.
let is_int8 = dir.join(variant.encoder_int8_file()).exists();
if is_int8 {
tracing::info!("Using INT8 quantized encoder");
}
tracing::info!("Loading ONNX models from {model_dir} (pool_size={pool_size})...");
#[cfg(feature = "coreml")]
tracing::info!("Using CoreML execution provider (Neural Engine + CPU)");
#[cfg(feature = "cuda")]
tracing::info!("Using CUDA execution provider (falls back to CPU if unavailable)");
#[cfg(not(any(feature = "coreml", feature = "cuda")))]
tracing::info!("Using CPU execution provider");
// One prepacked-weights container is passed to every session of the pool.
let prepacked = ort::session::builder::PrepackedWeights::new();
// CoreML build: try CoreML first; on load failure retry on CPU with a fresh
// prepacked-weights container.
#[cfg(feature = "coreml")]
let triplets = match Self::load_triplets(
dir,
variant,
pool_size,
min_size,
&prepacked,
|d, v, pp| Self::load_sessions(d, v, pp, encoder_intra_threads),
) {
Ok(triplets) => triplets,
Err(load_err) => {
tracing::warn!(
"CoreML EP failed to load sessions ({load_err:#}); falling back to CPU execution provider"
);
let prepacked = ort::session::builder::PrepackedWeights::new();
Self::load_triplets(dir, variant, pool_size, min_size, &prepacked, |d, v, pp| {
Self::load_sessions_cpu(d, v, pp, encoder_intra_threads)
})?
}
};
// Non-CoreML build: a load failure is returned to the caller.
#[cfg(not(feature = "coreml"))]
let triplets =
Self::load_triplets(dir, variant, pool_size, min_size, &prepacked, |d, v, pp| {
Self::load_sessions(d, v, pp, encoder_intra_threads)
})?;
let tokenizer = Tokenizer::load(&dir.join(variant.vocab_file()))?;
let features = FeatureExtractor::new();
// NOTE(review): `pool_size` here is the requested size; a degraded load may
// have produced fewer triplets.
tracing::info!(
"Models loaded (vocab_size={}, pool_size={pool_size})",
tokenizer.vocab_size()
);
// The speaker encoder is optional: a missing or unloadable model only
// disables diarization and never fails the engine load.
#[cfg(feature = "diarization")]
#[allow(deprecated)] let speaker_encoder = {
let model_path = dir.join("wespeaker_resnet34.onnx");
if model_path.exists() {
match OnnxEmbeddingExtractor::new(
&model_path,
SPEAKER_EMBEDDING_DIM,
SPEAKER_SEGMENT_SAMPLES,
SPEAKER_POOL_SIZE,
) {
Ok(enc) => {
tracing::info!("Speaker encoder loaded (diarization available)");
Some(std::sync::Arc::new(enc))
}
Err(e) => {
tracing::warn!(
"Speaker encoder not loaded, diarization unavailable: {e:#}"
);
None
}
}
} else {
tracing::warn!("wespeaker_resnet34.onnx not found, diarization unavailable");
None
}
};
// Reserve part of the loaded triplets for the batch pool, if requested.
let (pool, batch_pool) = Self::split_triplets(triplets, batch_pool_size);
let engine = Self {
pool,
batch_pool,
tokenizer,
features,
variant,
punctuator: None,
itn: false,
biaser: None,
vad: None,
vad_config: crate::vad::VadConfig::default(),
int8: is_int8,
#[cfg(feature = "diarization")]
speaker_encoder,
};
// CoreML build: run one warm-up inference; if it fails, rebuild both pools
// on CPU and probe again (see `probe_or_rebuild`).
#[cfg(feature = "coreml")]
let engine = probe_or_rebuild(
engine,
|e: &Self| e.warmup_one().map_err(anyhow::Error::from),
|mut e, probe_err| {
tracing::warn!(
"CoreML EP failed at runtime ({probe_err:#}); falling back to CPU execution provider"
);
let prepacked = ort::session::builder::PrepackedWeights::new();
let triplets = Self::load_triplets(
dir,
variant,
pool_size,
min_size,
&prepacked,
|d, v, pp| Self::load_sessions_cpu(d, v, pp, encoder_intra_threads),
)?;
let (pool, batch_pool) = Self::split_triplets(triplets, batch_pool_size);
e.pool = pool;
e.batch_pool = batch_pool;
Ok(e)
},
)?;
Ok(engine)
}
/// Runs one warm-up inference on a triplet from the primary pool.
fn warmup_one(&self) -> Result<(), GigasttError> {
self.warmup_one_on(&self.pool)
}
/// Checks a triplet out of `pool` (blocking) and transcribes 16000 samples of
/// silence on it, discarding the result.
///
/// # Errors
/// `GigasttError::Inference` when the checkout fails; otherwise any error
/// from `transcribe_samples`.
fn warmup_one_on(&self, pool: &SessionPool) -> Result<(), GigasttError> {
    let checkout = pool.checkout_blocking();
    let mut guard = checkout.map_err(|e| GigasttError::Inference {
        source: Box::new(e),
    })?;
    let silence = vec![0.0f32; 16000];
    self.transcribe_samples(&silence, &mut guard)?;
    Ok(())
}
pub fn warmup(&self) -> Result<(), GigasttError> {
for _ in 0..self.pool.total() {
self.warmup_one()?;
}
if let Some(ref batch) = self.batch_pool {
for _ in 0..batch.total() {
self.warmup_one_on(batch)?;
}
}
Ok(())
}
/// Returns the pool to use for batch work: the dedicated batch pool when one
/// exists, otherwise the primary pool.
pub fn pool_for_batch(&self) -> &SessionPool {
    match &self.batch_pool {
        Some(batch) => batch,
        None => &self.pool,
    }
}
pub fn close_pools(&self) {
self.pool.close();
if let Some(ref batch) = self.batch_pool {
batch.close();
}
}
/// Returns true when the speaker encoder was loaded.
#[cfg(feature = "diarization")]
pub fn has_speaker_encoder(&self) -> bool {
self.speaker_encoder.is_some()
}
/// Creates the per-stream state for a new streaming session.
///
/// With the `diarization` feature, a streaming diarization pipeline is set up
/// when `diarization_enabled` is true and a speaker encoder is loaded; a
/// failure to build it is logged and diarization is left off. Without the
/// feature, requesting diarization only logs a warning.
pub fn create_state(&self, diarization_enabled: bool) -> StreamingState {
#[cfg(feature = "diarization")]
let diarization_state = match (diarization_enabled, &self.speaker_encoder) {
(true, Some(enc)) => {
// Default diarization config with the cluster threshold overridden to 0.5.
let config = DiaConfig {
cluster: ClusterConfig {
threshold: 0.5,
..ClusterConfig::default()
},
..DiaConfig::default()
};
let vad_config = VadConfig::default();
// NOTE(review): presumably a -40 dB energy threshold at 16 kHz — confirm
// against the `EnergyVad::new` parameter order.
let vad = EnergyVad::new(-40.0, 16000, vad_config.frame_size);
// The pipeline shares the engine's speaker encoder through an `Arc`.
let extractor = SharedExtractor(std::sync::Arc::clone(enc));
match StreamingPipeline::new(vad, extractor, config, vad_config) {
Ok(pipeline) => Some(pipeline),
Err(e) => {
tracing::warn!("Failed to initialize streaming diarization: {e:#}");
None
}
}
}
_ => None,
};
#[cfg(not(feature = "diarization"))]
if diarization_enabled {
tracing::warn!(
"diarization_enabled=true ignored: build lacks the `diarization` feature"
);
}
StreamingState {
decoder: DecoderState::new(self.tokenizer.blank_id()),
audio_buffer: Vec::new(),
assembler: TranscriptAssembler::new(),
window_start_samples: 0,
context_samples: 0,
pending_samples: 0,
resampler: None,
mel_fft_input: Vec::new(),
mel_power: Vec::new(),
mel_output: Vec::new(),
resample_output_buf: Vec::new(),
// An endpointer is created only when the engine has a VAD model.
vad_endpointer: self
.vad
.as_ref()
.map(|_| crate::vad::VadEndpointer::new(&self.vad_config)),
#[cfg(feature = "diarization")]
diarization_state,
}
}
/// Feeds one chunk of streaming audio and returns zero or one transcript
/// segments.
///
/// Audio is buffered until a decode is due: enough new samples arrived, the
/// buffer reached its cap, or the VAD reported an endpoint. After decoding,
/// the segment is finalized when the decoder, the cap, or the VAD signals an
/// endpoint; otherwise a partial segment is returned. Segments with no text
/// are not returned.
///
/// # Errors
/// `GigasttError::Inference` when decoding the window fails.
pub fn process_chunk(
&self,
samples: &[f32],
state: &mut StreamingState,
triplet: &mut SessionTriplet,
) -> Result<Vec<TranscriptSegment>, GigasttError> {
if samples.is_empty() {
return Ok(vec![]);
}
// Diarization is best-effort: a feed failure is logged, not returned.
#[cfg(feature = "diarization")]
if let Some(dia) = state.diarization_state.as_mut()
&& let Err(e) = dia.feed(samples)
{
tracing::warn!("Diarization feed failed: {e:#}");
}
state.audio_buffer.extend_from_slice(samples);
state.pending_samples += samples.len();
// VAD endpointing is also best-effort; on error no endpoint is assumed.
let mut vad_endpoint = false;
if let (Some(vad), Some(ep)) = (self.vad.as_ref(), state.vad_endpointer.as_mut()) {
match ep.push(vad, samples) {
Ok(fired) => vad_endpoint = fired,
Err(e) => tracing::warn!("VAD endpoint detection failed: {e:#}"),
}
}
let over_cap = state.audio_buffer.len() >= STREAM_MAX_WINDOW_SAMPLES;
// Not enough new audio yet and nothing forces a decode.
if state.pending_samples < STREAM_DECODE_STRIDE_SAMPLES && !over_cap && !vad_endpoint {
return Ok(vec![]);
}
// Buffer shorter than `N_FFT` samples: skip unless an endpoint forces it.
if state.audio_buffer.len() < N_FFT && !vad_endpoint && !over_cap {
return Ok(vec![]);
}
let endpoint = self
.decode_window(state, triplet)
.map_err(|e| GigasttError::Inference { source: e.into() })?;
state.pending_samples = 0;
let ts = now_timestamp();
if endpoint || over_cap || vad_endpoint {
let seg = state.assembler.finalize(ts);
// Keep at most `STREAM_LEFT_CONTEXT_SAMPLES` of trailing audio as left
// context and slide the rest off the front of the buffer.
let keep = STREAM_LEFT_CONTEXT_SAMPLES.min(state.audio_buffer.len());
let slide_off = state.audio_buffer.len() - keep;
if slide_off > 0 {
audio::consume_audio_buffer(&mut state.audio_buffer, slide_off);
state.window_start_samples += slide_off;
}
// Words starting inside the kept context are filtered out by later decodes.
state.context_samples = keep;
if seg.text.trim().is_empty() {
return Ok(vec![]);
}
return Ok(vec![seg]);
}
if state.assembler.is_empty() {
return Ok(vec![]);
}
Ok(vec![state.assembler.partial(ts)])
}
/// Re-decodes the whole current audio window and replaces the assembler's
/// words with the words that start at or after the left-context boundary.
///
/// Returns the endpoint flag reported by `run_inference`, or `Ok(false)` when
/// the window yields no feature frames.
fn decode_window(
&self,
state: &mut StreamingState,
triplet: &mut SessionTriplet,
) -> anyhow::Result<bool> {
let mel_start = std::time::Instant::now();
// Features are written into the reusable buffers held in `state`.
let num_frames = self.features.compute_mel(
&state.audio_buffer,
&mut state.mel_fft_input,
&mut state.mel_power,
&mut state.mel_output,
);
tracing::debug!(
elapsed_us = mel_start.elapsed().as_micros() as u64,
"mel_compute"
);
if num_frames == 0 {
return Ok(false);
}
// Encoder-frame index of the first sample in the window (integer division).
let frame_offset = state.window_start_samples / (HOP_LENGTH * ENCODER_SUBSAMPLING);
// A fresh decoder state is used for every window; `state.decoder` is not read here.
let mut decoder_state = DecoderState::new(self.tokenizer.blank_id());
let (all_words, endpoint) = self.run_inference(
triplet,
&state.mel_output[..],
num_frames,
&mut decoder_state,
frame_offset,
)?;
let window_start_s = frame_offset as f64 * SECONDS_PER_FRAME;
let context_boundary_s = window_start_s + state.context_samples as f64 / 16000.0;
// Drop words that start inside the left context; `f64::EPSILON` is added as
// a tolerance for words starting exactly on the boundary.
#[cfg_attr(not(feature = "diarization"), allow(unused_mut))]
let mut tail: Vec<WordInfo> = all_words
.into_iter()
.filter(|w| w.start + f64::EPSILON >= context_boundary_s)
.collect();
// Every remaining word is labelled with the speaker of the most recent
// diarization turn, when there is one.
#[cfg(feature = "diarization")]
if let Some(dia) = state.diarization_state.as_mut()
&& let Some(turn) = dia.turns().last()
{
let speaker = turn.speaker.0;
for w in &mut tail {
w.speaker = Some(speaker);
}
}
state.assembler.set_words(tail);
Ok(endpoint)
}
/// Ends a stream: decodes any still-pending audio, then flushes the assembler.
///
/// The final decode only runs when there are undecoded samples and at least
/// `N_FFT` samples are buffered. A decode failure is logged as a warning and
/// does not prevent the flush.
///
/// Returns the final segment, or `None` when the assembler holds no words.
pub fn finish_stream(
    &self,
    state: &mut StreamingState,
    triplet: &mut SessionTriplet,
) -> Option<TranscriptSegment> {
    let worth_decoding = state.pending_samples > 0 && state.audio_buffer.len() >= N_FFT;
    let last_pass = if worth_decoding {
        self.decode_window(state, triplet)
    } else {
        Ok(false)
    };
    if let Err(e) = last_pass {
        tracing::warn!("finish_stream decode failed: {e:#}");
    }
    self.flush_state(state)
}
/// Finalizes whatever the assembler currently holds, stamped with the
/// current time; yields `None` when the assembler is empty.
pub fn flush_state(&self, state: &mut StreamingState) -> Option<TranscriptSegment> {
    if state.assembler.is_empty() {
        None
    } else {
        Some(state.assembler.finalize(now_timestamp()))
    }
}
/// Decodes the audio file at `path` and transcribes it.
///
/// # Errors
/// Returns `GigasttError::InvalidAudio` when the file cannot be decoded;
/// otherwise propagates whatever `transcribe_samples` returns.
pub fn transcribe_file(
    &self,
    path: &str,
    triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
    let pcm = match audio::decode_audio_file(path) {
        Ok(pcm) => pcm,
        Err(e) => {
            return Err(GigasttError::InvalidAudio {
                reason: format!("{e:#}"),
            });
        }
    };
    self.transcribe_samples(&pcm, triplet)
}
/// Transcribes encoded audio held in a borrowed byte slice.
///
/// The bytes are copied into a `bytes::Bytes` buffer and handed to
/// `transcribe_bytes_shared`.
pub fn transcribe_bytes(
    &self,
    data: &[u8],
    triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
    let owned = bytes::Bytes::copy_from_slice(data);
    self.transcribe_bytes_shared(owned, triplet)
}
/// Transcribes encoded audio held in a `bytes::Bytes` buffer.
///
/// # Errors
/// Returns `GigasttError::InvalidAudio` when the bytes cannot be decoded;
/// otherwise propagates whatever `transcribe_samples` returns.
pub fn transcribe_bytes_shared(
    &self,
    data: bytes::Bytes,
    triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
    let decoded = audio::decode_audio_bytes_shared(data);
    let pcm = decoded.map_err(|e| GigasttError::InvalidAudio {
        reason: format!("{e:#}"),
    })?;
    self.transcribe_samples(&pcm, triplet)
}
/// Transcribes a complete buffer of decoded samples.
///
/// Steps, in order:
/// 1. If a VAD is configured, decode only the detected speech regions; when
///    the VAD itself fails, fall back to decoding the full audio.
/// 2. With the `diarization` feature and a speaker encoder present, run the
///    offline diarization pipeline and tag each word with the speaker whose
///    turn contains the word's midpoint. A diarization failure is logged and
///    leaves the words untagged.
/// 3. Join the words with single spaces, then apply ITN (when enabled) and
///    punctuation restoration (when a punctuator is configured).
///
/// `duration_s` is the input length divided by 16000, i.e. the audio is
/// assumed to be sampled at 16 kHz.
///
/// # Errors
/// Propagates decoding errors from `decode_speech_regions` / `decode_words`.
fn transcribe_samples(
    &self,
    float_samples: &[f32],
    triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
    let duration_s = float_samples.len() as f64 / 16000.0;
    #[cfg_attr(not(feature = "diarization"), allow(unused_mut))]
    let mut words = match &self.vad {
        Some(vad) => match vad.speech_regions(float_samples, &self.vad_config) {
            // Pass the detected regions by reference (the original token here
            // was a garbled `&regions`).
            Ok(regions) => self.decode_speech_regions(float_samples, &regions, triplet)?,
            Err(e) => {
                // VAD is an optimisation only: losing it must not lose the transcript.
                tracing::warn!("VAD failed, decoding full audio: {e:#}");
                self.decode_words(float_samples, triplet)?
            }
        },
        None => self.decode_words(float_samples, triplet)?,
    };
    #[cfg(feature = "diarization")]
    if let Some(ref enc) = self.speaker_encoder {
        let config = DiaConfig::default();
        let vad_config = VadConfig::default();
        let pipeline = Pipeline::new(config, vad_config);
        let mut vad = EnergyVad::new(-40.0, 16000, vad_config.frame_size);
        match pipeline.run(float_samples, enc.as_ref(), &mut vad) {
            Ok(dia_result) => {
                for word in &mut words {
                    // Attribute the word to the first turn covering its midpoint.
                    let mid = (word.start + word.end) / 2.0;
                    if let Some(turn) = dia_result
                        .turns
                        .iter()
                        .find(|t| t.time.start <= mid && t.time.end >= mid)
                    {
                        word.speaker = Some(turn.speaker.0);
                    }
                }
            }
            Err(e) => {
                tracing::warn!("Offline diarization failed: {e:#}");
            }
        }
    }
    let text: String = words
        .iter()
        .map(|w| w.word.as_str())
        .collect::<Vec<_>>()
        .join(" ");
    let text = if self.itn {
        crate::itn::apply_itn(&text)
    } else {
        text
    };
    let text = match &self.punctuator {
        Some(p) => p.restore(&text),
        None => text,
    };
    Ok(TranscribeResult {
        text,
        words,
        duration_s,
    })
}
/// Decodes `samples` into words.
///
/// Inputs longer than `CHUNK_THRESHOLD_SAMPLES` are routed to the chunked
/// long-form decoder; shorter inputs are decoded in a single inference pass
/// with a frame offset of zero.
///
/// # Errors
/// Inference failures are wrapped in `GigasttError::Inference`.
fn decode_words(
    &self,
    samples: &[f32],
    triplet: &mut SessionTriplet,
) -> Result<Vec<WordInfo>, GigasttError> {
    if samples.len() > CHUNK_THRESHOLD_SAMPLES {
        return self.transcribe_samples_chunked(samples, triplet);
    }

    let (mel, frame_count) = self.features.compute(samples);
    tracing::info!("Extracted {} mel frames", frame_count);

    let mut fresh_state = DecoderState::new(self.tokenizer.blank_id());
    let (words, _endpoint) = self
        .run_inference(triplet, &mel, frame_count, &mut fresh_state, 0)
        .map_err(|e| GigasttError::Inference { source: e.into() })?;
    Ok(words)
}
/// Decodes only the speech regions of `float_samples`.
///
/// The `(start, end)` sample ranges in `regions` are concatenated into one
/// compact buffer, decoded, and the resulting word times are mapped back
/// onto the original timeline with `remap_compressed_seconds`.
/// An empty `regions` list short-circuits to an empty word list.
///
/// # Errors
/// Propagates errors from `decode_words`.
fn decode_speech_regions(
    &self,
    float_samples: &[f32],
    regions: &[(usize, usize)],
    triplet: &mut SessionTriplet,
) -> Result<Vec<WordInfo>, GigasttError> {
    if regions.is_empty() {
        tracing::info!("VAD found no speech; skipping decode");
        return Ok(Vec::new());
    }

    let kept_samples: usize = regions.iter().map(|&(from, to)| to - from).sum();
    let mut compact = Vec::with_capacity(kept_samples);
    regions
        .iter()
        .for_each(|&(from, to)| compact.extend_from_slice(&float_samples[from..to]));
    tracing::info!(
        "VAD kept {}/{} samples ({} speech region(s))",
        kept_samples,
        float_samples.len(),
        regions.len()
    );

    let mut words = self.decode_words(&compact, triplet)?;
    // Times so far refer to the compacted buffer; translate them back.
    words.iter_mut().for_each(|w| {
        w.start = crate::vad::remap_compressed_seconds(w.start, regions, 16000.0);
        w.end = crate::vad::remap_compressed_seconds(w.end, regions, 16000.0);
    });
    Ok(words)
}
/// Long-form decode: walks the audio in overlapping windows and stitches the
/// per-window words together.
///
/// Windows are `CHUNK_WINDOW_SAMPLES` long. Consecutive windows start one
/// stride apart, where the stride is window minus overlap rounded down to a
/// whole number of encoder frames. Each window is decoded with a fresh
/// decoder state, and its words are merged at a seam placed half an overlap
/// after the window start.
///
/// # Errors
/// Inference failures are wrapped in `GigasttError::Inference`.
fn transcribe_samples_chunked(
    &self,
    float_samples: &[f32],
    triplet: &mut SessionTriplet,
) -> Result<Vec<WordInfo>, GigasttError> {
    let total_len = float_samples.len();
    let samples_per_frame = HOP_LENGTH * ENCODER_SUBSAMPLING;
    // Frame-aligned stride between window starts.
    let hop =
        ((CHUNK_WINDOW_SAMPLES - CHUNK_OVERLAP_SAMPLES) / samples_per_frame) * samples_per_frame;
    tracing::info!(
        "Long-form chunked decode: {:.1}s in ~{}s windows ({}s overlap)",
        total_len as f64 / 16000.0,
        CHUNK_WINDOW_SAMPLES / 16000,
        CHUNK_OVERLAP_SAMPLES / 16000,
    );

    let mut stitched: Vec<WordInfo> = Vec::new();
    let mut offset = 0usize;
    while offset < total_len {
        let window_end = total_len.min(offset + CHUNK_WINDOW_SAMPLES);
        let (mel, frame_count) = self.features.compute(&float_samples[offset..window_end]);

        let mut fresh_state = DecoderState::new(self.tokenizer.blank_id());
        let inference = self.run_inference(
            triplet,
            &mel,
            frame_count,
            &mut fresh_state,
            offset / samples_per_frame,
        );
        let (window_words, _endpoint) =
            inference.map_err(|e| GigasttError::Inference { source: e.into() })?;

        // Seam sits half an overlap after the start of the current window.
        let seam_s = (offset as f64 + CHUNK_OVERLAP_SAMPLES as f64 / 2.0) / 16000.0;
        stitched = stitch_chunk_words(stitched, window_words, seam_s);

        if window_end == total_len {
            break;
        }
        offset += hop;
    }
    Ok(stitched)
}
/// Runs one encoder pass followed by greedy decoding and converts the
/// resulting tokens into timed words.
///
/// * `features` - mel features, viewed as a `[1, N_MELS, num_frames]` tensor.
/// * `num_frames` - number of mel frames in `features`.
/// * `decoder_state` - decoder state handed to `greedy_decode`.
/// * `frame_offset` - encoder-frame offset added to every word timestamp.
///
/// Returns the decoded words and whether the decoder reported an endpoint.
///
/// # Errors
/// Fails when the input tensors cannot be built, when encoder inference or
/// output extraction fails, when the encoder length output is empty or
/// negative, or when greedy decoding fails.
fn run_inference(
    &self,
    triplet: &mut SessionTriplet,
    features: &[f32],
    num_frames: usize,
    decoder_state: &mut DecoderState,
    frame_offset: usize,
) -> anyhow::Result<(Vec<WordInfo>, bool)> {
    let signal_tensor = TensorRef::from_array_view(([1_usize, N_MELS, num_frames], features))?;
    let length_data = [num_frames as i64];
    let length_tensor = TensorRef::from_array_view(([1_usize], length_data.as_slice()))?;
    let enc_start = std::time::Instant::now();
    let encoder_outputs = triplet
        .encoder
        .run(ort::inputs![signal_tensor, length_tensor])
        .context("Encoder inference failed")?;
    tracing::info!(
        elapsed_ms = enc_start.elapsed().as_millis() as u64,
        "encoder_inference"
    );
    let (_enc_shape, enc_data) = encoder_outputs[0]
        .try_extract_tensor::<f32>()
        .context("Failed to extract encoder output")?;
    let (_len_shape, len_data) = encoder_outputs[1]
        .try_extract_tensor::<i32>()
        .context("Failed to extract encoder length")?;
    // Checked access: an empty length tensor becomes an error instead of an
    // index-out-of-bounds panic.
    let raw_len = len_data
        .first()
        .copied()
        .context("Encoder returned an empty length tensor")?;
    let enc_len = usize::try_from(raw_len).context("Negative encoder length")?;
    tracing::debug!("Encoder output: {} frames", enc_len);
    let dec_start = std::time::Instant::now();
    let result = decode::greedy_decode(
        &mut triplet.decoder,
        &mut triplet.joiner,
        enc_data,
        enc_len,
        self.tokenizer.blank_id(),
        decoder_state,
        self.biaser.as_ref(),
    )?;
    tracing::info!(
        elapsed_ms = dec_start.elapsed().as_millis() as u64,
        "greedy_decode"
    );
    let words = self.tokens_to_words(&result.tokens, frame_offset);
    tracing::info!(
        tokens = result.tokens.len(),
        words = words.len(),
        duration_ms = dec_start.elapsed().as_millis() as u64,
        "Decoded tokens"
    );
    Ok((words, result.endpoint_detected))
}
/// Groups decoded tokens into words using this engine's tokenizer.
///
/// Thin wrapper that delegates to `TokenFormatter::tokens_to_words`;
/// `frame_offset` is passed through unchanged.
fn tokens_to_words(&self, tokens: &[decode::TokenInfo], frame_offset: usize) -> Vec<WordInfo> {
TokenFormatter::tokens_to_words(&self.tokenizer, tokens, frame_offset)
}
}
/// Merges the words of the next chunk into the words accumulated so far.
///
/// When nothing has been accumulated yet, `next` is returned untouched.
/// Otherwise accumulated words starting after `seam_s` are discarded and
/// only words of `next` starting strictly after `seam_s` are appended, so a
/// word in the overlap is taken from exactly one side of the seam.
pub(crate) fn stitch_chunk_words(
    mut merged: Vec<WordInfo>,
    next: Vec<WordInfo>,
    seam_s: f64,
) -> Vec<WordInfo> {
    if merged.is_empty() {
        next
    } else {
        // Earlier side owns everything up to and including the seam.
        merged.retain(|earlier| earlier.start <= seam_s);
        // Later side owns everything strictly past it.
        let past_seam = next.into_iter().filter(|later| later.start > seam_s);
        merged.extend(past_seam);
        merged
    }
}
/// Stateless helper that turns decoder tokens into timed words.
pub(crate) struct TokenFormatter;

impl TokenFormatter {
    /// Groups `tokens` into words.
    ///
    /// A token whose text starts with `tokenizer::WORD_BOUNDARY` closes the
    /// word in progress (if any) and starts a new one; the boundary marker
    /// itself is stripped from the word text. Word times are the frame index
    /// of the first / last contributing token plus `frame_offset`, multiplied
    /// by `SECONDS_PER_FRAME`. Word confidence is the mean of the
    /// contributing tokens' confidences.
    pub(crate) fn tokens_to_words(
        tokenizer: &Tokenizer,
        tokens: &[decode::TokenInfo],
        frame_offset: usize,
    ) -> Vec<WordInfo> {
        let mut finished = Vec::new();
        let mut text = String::new();
        let mut first_frame: Option<usize> = None;
        let mut last_frame: usize = 0;
        let mut confidences: Vec<f32> = Vec::new();

        for tok in tokens {
            let piece = tokenizer.token_text(tok.token_id);
            let (opens_word, body) = match piece.strip_prefix(tokenizer::WORD_BOUNDARY) {
                Some(rest) => (true, rest),
                None => (false, piece),
            };

            // A boundary token terminates whatever word was being built.
            if opens_word && !text.is_empty() {
                finished.push(Self::build_word(
                    std::mem::take(&mut text),
                    first_frame.take(),
                    last_frame,
                    frame_offset,
                    &confidences,
                ));
                confidences.clear();
            }

            // Tokens with no text besides the marker contribute nothing.
            if body.is_empty() {
                continue;
            }
            text.push_str(body);
            first_frame = first_frame.or(Some(tok.frame_index));
            last_frame = tok.frame_index;
            confidences.push(tok.confidence);
        }

        // Emit the trailing word that no later boundary closed.
        if !text.is_empty() {
            finished.push(Self::build_word(
                text,
                first_frame,
                last_frame,
                frame_offset,
                &confidences,
            ));
        }
        finished
    }

    /// Builds one `WordInfo` from accumulated text, frame span and token
    /// confidences (mean confidence, or 1.0 when there are none).
    fn build_word(
        word: String,
        first_frame: Option<usize>,
        last_frame: usize,
        frame_offset: usize,
        confidences: &[f32],
    ) -> WordInfo {
        let confidence: f32 = if confidences.is_empty() {
            1.0
        } else {
            confidences.iter().sum::<f32>() / confidences.len() as f32
        };
        WordInfo {
            word,
            start: (first_frame.unwrap_or(0) + frame_offset) as f64 * SECONDS_PER_FRAME,
            end: (last_frame + frame_offset) as f64 * SECONDS_PER_FRAME,
            confidence,
            speaker: None,
        }
    }
}
/// Result of transcribing a complete audio input (file, bytes or samples).
#[derive(Debug, Clone, Serialize)]
pub struct TranscribeResult {
/// Transcript text: the words joined by single spaces, then passed through
/// ITN and punctuation restoration when those are enabled.
pub text: String,
/// Per-word details (text, start/end time, confidence, optional speaker).
pub words: Vec<WordInfo>,
/// Length of the input in seconds, computed as sample count / 16000.
pub duration_s: f64,
}
/// One streaming transcript update, either a partial hypothesis or a
/// finalized segment.
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
pub struct TranscriptSegment {
/// Text of the segment.
pub text: String,
/// Words that make up the segment.
pub words: Vec<WordInfo>,
/// `true` for a finalized segment, `false` for a partial one.
pub is_final: bool,
/// Time attached to the segment when it was produced; the engine passes
/// the value of `now_timestamp()`.
pub timestamp: f64,
}
impl TranscriptSegment {
    /// Creates a final segment with no text and no words, stamped with the
    /// current time.
    pub fn empty_final() -> Self {
        let timestamp = now_timestamp();
        Self {
            timestamp,
            is_final: true,
            words: Vec::new(),
            text: String::default(),
        }
    }
}
// Unit tests for the helpers in this file: pool sizing and thread clamping,
// token-to-word formatting, chunk stitching, pool load finalization, decoder
// state, the session pool, the transcript assembler, timestamps,
// probe-or-rebuild and encoder model path selection.
#[cfg(test)]
mod tests {
use super::*;
// --- cap_pool_size_for_ram ---
// A 4-slot request with a 225 MiB encoder and 2 GiB of RAM is reduced to 2.
#[test]
fn test_cap_pool_size_for_ram_clamps_on_low_memory() {
let enc = 225 * 1024 * 1024;
let two_gib = 2 * 1024 * 1024 * 1024;
assert_eq!(Engine::cap_pool_size_for_ram(4, enc, two_gib), 2);
}
// With 64 GiB of RAM the requested size is kept.
#[test]
fn test_cap_pool_size_for_ram_no_clamp_with_ample_ram() {
let enc = 225 * 1024 * 1024;
let sixty_four_gib = 64u64 * 1024 * 1024 * 1024;
assert_eq!(Engine::cap_pool_size_for_ram(4, enc, sixty_four_gib), 4);
}
// Even when the encoder is larger than RAM the result is 1, never 0.
#[test]
fn test_cap_pool_size_for_ram_never_below_one() {
let huge_enc = 8 * 1024 * 1024 * 1024;
let small_ram = 1024 * 1024 * 1024;
assert_eq!(Engine::cap_pool_size_for_ram(4, huge_enc, small_ram), 1);
}
// A zero encoder size or zero RAM leaves the request unchanged; a request
// of 0 comes back as 1.
#[test]
fn test_cap_pool_size_for_ram_noop_on_unknown_inputs() {
assert_eq!(Engine::cap_pool_size_for_ram(4, 0, 8 << 30), 4);
assert_eq!(Engine::cap_pool_size_for_ram(4, 200 << 20, 0), 4);
assert_eq!(Engine::cap_pool_size_for_ram(1, 200 << 20, 1 << 30), 1);
assert_eq!(Engine::cap_pool_size_for_ram(0, 200 << 20, 1 << 30), 1);
}
// --- clamp_encoder_intra_threads ---
// Table of expected results for various argument combinations, including
// zero-valued arguments.
#[test]
fn test_clamp_encoder_intra_threads() {
assert_eq!(Engine::clamp_encoder_intra_threads(2, 1, 10), 1);
assert_eq!(Engine::clamp_encoder_intra_threads(4, 1, 4), 1);
assert_eq!(Engine::clamp_encoder_intra_threads(2, 4, 16), 4);
assert_eq!(Engine::clamp_encoder_intra_threads(4, 4, 10), 2);
assert_eq!(Engine::clamp_encoder_intra_threads(8, 4, 4), 1);
assert_eq!(Engine::clamp_encoder_intra_threads(0, 4, 8), 4);
assert_eq!(Engine::clamp_encoder_intra_threads(2, 0, 8), 1);
assert_eq!(Engine::clamp_encoder_intra_threads(2, 4, 0), 1);
}
// --- batch_split_count / split_pool ---
// The split count never takes the whole pool: 10 requested out of 4 gives 3,
// and pools of size 0 or 1 give 0.
#[test]
fn test_batch_split_count_clamps() {
assert_eq!(Engine::batch_split_count(4, 1), 1); assert_eq!(Engine::batch_split_count(4, 0), 0); assert_eq!(Engine::batch_split_count(4, 10), 3); assert_eq!(Engine::batch_split_count(1, 1), 0); assert_eq!(Engine::batch_split_count(0, 1), 0); assert_eq!(Engine::batch_split_count(2, 1), 1);
}
// split_pool distributes the items over a main pool and an optional batch
// pool; the main pool always keeps at least one item.
#[test]
fn test_split_pool_routes_items_to_two_pools() {
let (pool, batch) = Engine::split_pool(vec![1u32, 2, 3, 4], 1);
assert_eq!(pool.total(), 3);
assert_eq!(batch.as_ref().map(|b| b.total()), Some(1));
let (pool, batch) = Engine::split_pool(vec![1u32, 2, 3, 4], 0);
assert_eq!(pool.total(), 4);
assert!(batch.is_none());
let (pool, batch) = Engine::split_pool(vec![1u32, 2, 3], 9);
assert_eq!(pool.total(), 1);
assert_eq!(batch.as_ref().map(|b| b.total()), Some(2));
let (pool, batch) = Engine::split_pool(vec![1u32], 1);
assert_eq!(pool.total(), 1);
assert!(batch.is_none());
}
// --- TokenFormatter ---
// Four sub-word tokens form "hello" and "world"; confidence is the mean of
// the token confidences and times are frame index * 0.04 s.
#[test]
fn test_token_formatter_groups_words() {
let tok = Tokenizer::from_tokens(vec![
"\u{2581}hel".into(), "lo".into(), "\u{2581}wor".into(), "ld".into(), ]);
let tokens = vec![
decode::TokenInfo {
token_id: 0,
frame_index: 0,
confidence: 0.9,
},
decode::TokenInfo {
token_id: 1,
frame_index: 1,
confidence: 0.8,
},
decode::TokenInfo {
token_id: 2,
frame_index: 2,
confidence: 0.95,
},
decode::TokenInfo {
token_id: 3,
frame_index: 3,
confidence: 0.85,
},
];
let words = TokenFormatter::tokens_to_words(&tok, &tokens, 0);
assert_eq!(words.len(), 2);
assert_eq!(words[0].word, "hello");
assert_eq!(words[1].word, "world");
assert!((words[0].confidence - 0.85).abs() < 1e-6);
assert!((words[1].confidence - 0.90).abs() < 1e-6);
assert!((words[0].start - 0.0).abs() < 1e-9);
assert!((words[0].end - 0.04).abs() < 1e-9);
assert!((words[1].start - 0.08).abs() < 1e-9);
}
// No tokens in, no words out.
#[test]
fn test_token_formatter_empty_tokens() {
let tok = Tokenizer::from_tokens(vec!["\u{2581}a".into()]);
assert!(TokenFormatter::tokens_to_words(&tok, &[], 0).is_empty());
}
// A frame offset of 10 moves a frame-0 word to 0.4 s.
#[test]
fn test_token_formatter_frame_offset_shifts_time() {
let tok = Tokenizer::from_tokens(vec!["\u{2581}x".into()]);
let tokens = vec![decode::TokenInfo {
token_id: 0,
frame_index: 0,
confidence: 1.0,
}];
let words = TokenFormatter::tokens_to_words(&tok, &tokens, 10);
assert_eq!(words.len(), 1);
assert!((words[0].start - 0.4).abs() < 1e-9);
}
// --- stitch_chunk_words ---
// Test helper: a word with confidence 1.0 and no speaker.
fn word(text: &str, start: f64, end: f64) -> WordInfo {
WordInfo::new(text, start, end, 1.0, None)
}
// With nothing merged yet, the next chunk is returned as-is.
#[test]
fn test_stitch_first_chunk_passes_through() {
let next = vec![word("a", 0.0, 0.5), word("b", 0.6, 1.0)];
let out = stitch_chunk_words(Vec::new(), next.clone(), 11.0);
assert_eq!(out.len(), 2);
assert_eq!(out[0].word, "a");
assert_eq!(out[1].word, "b");
}
// A word decoded by both chunks inside the overlap appears exactly once and
// the merged start times stay non-decreasing.
#[test]
fn test_stitch_dedups_overlap_no_drop_no_dup() {
let chunk_a = vec![
word("first", 1.0, 1.4), word("middle", 21.0, 21.4), word("dup", 22.4, 22.8), ];
let chunk_b = vec![
word("dup", 22.5, 22.9), word("later", 25.0, 25.4), word("end", 40.0, 40.4), ];
let seam_s = 22.0 + CHUNK_OVERLAP_SAMPLES as f64 / 2.0 / 16000.0; assert!((seam_s - 23.0).abs() < 1e-9);
let out = stitch_chunk_words(chunk_a, chunk_b, seam_s);
let texts: Vec<&str> = out.iter().map(|w| w.word.as_str()).collect();
assert_eq!(texts, vec!["first", "middle", "dup", "later", "end"]);
for w in out.windows(2) {
assert!(w[0].start <= w[1].start, "not monotonic: {:?}", out);
}
}
// Words of the earlier chunk that start after the seam are dropped.
#[test]
fn test_stitch_drops_a_tail_past_seam() {
let chunk_a = vec![word("keep", 22.0, 22.4), word("a_tail", 23.5, 23.9)];
let chunk_b = vec![word("b_seam", 23.2, 23.6), word("b_late", 30.0, 30.4)];
let out = stitch_chunk_words(chunk_a, chunk_b, 23.0);
let texts: Vec<&str> = out.iter().map(|w| w.word.as_str()).collect();
assert_eq!(texts, vec!["keep", "b_seam", "b_late"]);
}
// A chunk starting 22 s into the audio yields a word starting at 22.0 s.
#[test]
fn test_stitch_timestamp_offset_math() {
let tok = Tokenizer::from_tokens(vec!["\u{2581}w".into()]);
let tokens = vec![decode::TokenInfo {
token_id: 0,
frame_index: 0,
confidence: 1.0,
}];
let start_samples = 16000 * 22; let frame_offset = start_samples / (HOP_LENGTH * ENCODER_SUBSAMPLING);
let words = TokenFormatter::tokens_to_words(&tok, &tokens, frame_offset);
assert_eq!(words.len(), 1);
assert!(
(words[0].start - 22.0).abs() < 1e-9,
"got {}",
words[0].start
);
}
// Sanity relations between the chunking constants.
#[test]
#[allow(clippy::assertions_on_constants)] fn test_chunk_constants_sane() {
assert!(CHUNK_WINDOW_SAMPLES > CHUNK_OVERLAP_SAMPLES);
assert!(CHUNK_THRESHOLD_SAMPLES >= CHUNK_WINDOW_SAMPLES);
}
// --- finalize_pool_load ---
// All loads succeeded: every item is returned in order.
#[test]
fn test_finalize_pool_load_full() {
let r: Vec<anyhow::Result<u32>> = vec![Ok(1), Ok(2), Ok(3)];
assert_eq!(Engine::finalize_pool_load(r, 3, 3).unwrap(), vec![1, 2, 3]);
}
// Some loads failed but the minimum is met: the successes are returned.
#[test]
fn test_finalize_pool_load_degraded_boots() {
let r: Vec<anyhow::Result<u32>> = vec![
Ok(1),
Err(anyhow::anyhow!("boom")),
Ok(3),
Err(anyhow::anyhow!("boom2")),
];
assert_eq!(Engine::finalize_pool_load(r, 4, 1).unwrap(), vec![1, 3]);
}
// Fewer successes than the minimum is an error that reports both counts.
#[test]
fn test_finalize_pool_load_below_min_errors() {
let r: Vec<anyhow::Result<u32>> = vec![
Ok(1),
Err(anyhow::anyhow!("boom")),
Err(anyhow::anyhow!("boom2")),
];
let err = Engine::finalize_pool_load(r, 3, 2).unwrap_err().to_string();
assert!(err.contains("loaded only 1/3"), "got: {err}");
assert!(err.contains("need at least 2"), "got: {err}");
}
// No successful load at all is an error.
#[test]
fn test_finalize_pool_load_all_fail_errors() {
let r: Vec<anyhow::Result<u32>> =
vec![Err(anyhow::anyhow!("a")), Err(anyhow::anyhow!("b"))];
assert!(Engine::finalize_pool_load(r, 2, 1).is_err());
}
// A minimum larger than the pool does not fail a fully loaded pool.
#[test]
fn test_finalize_pool_load_min_clamped_to_pool() {
let r: Vec<anyhow::Result<u32>> = vec![Ok(1), Ok(2)];
assert_eq!(Engine::finalize_pool_load(r, 2, 99).unwrap(), vec![1, 2]);
}
// --- DecoderState ---
// A new state has all-zero h/c and prev_token equal to the blank id.
#[test]
fn test_decoder_state_new_zeros() {
let blank_id = 1024;
let state = DecoderState::new(blank_id);
assert!(state.h.iter().all(|&v| v == 0.0));
assert!(state.c.iter().all(|&v| v == 0.0));
assert_eq!(state.prev_token, blank_id as i64);
}
// h and c each hold PRED_HIDDEN values.
#[test]
fn test_decoder_state_dimensions() {
let state = DecoderState::new(1024);
assert_eq!(state.h.len(), PRED_HIDDEN);
assert_eq!(state.c.len(), PRED_HIDDEN);
}
#[test]
fn test_decoder_state_custom_blank_id() {
let state = DecoderState::new(42);
assert_eq!(state.prev_token, 42);
}
// --- Default impls ---
#[test]
fn test_feature_extractor_default() {
let _fe = FeatureExtractor::default();
}
#[test]
fn test_transcript_assembler_default() {
let ta = TranscriptAssembler::default();
assert!(ta.text.is_empty());
assert!(ta.words.is_empty());
}
// --- Pool: blocking checkout ---
// An available item is handed out and returns to the pool when the guard drops.
#[test]
fn test_pool_checkout_blocking_fast_path() {
let pool = Pool::new(vec![42u32]);
let guard = pool.checkout_blocking().expect("checkout_blocking");
assert_eq!(*guard, 42);
drop(guard);
assert_eq!(pool.available(), 1);
}
// Checking out from a closed pool reports PoolError::Closed.
#[test]
fn test_pool_checkout_blocking_closed() {
let pool = Pool::<u32>::new(vec![]);
pool.close();
assert!(matches!(pool.checkout_blocking(), Err(PoolError::Closed)));
}
// A second thread blocks until the first guard is dropped, then gets the item.
#[test]
fn test_pool_checkout_blocking_slow_path() {
let pool = std::sync::Arc::new(Pool::new(vec![42u32]));
let primary = pool.checkout_blocking().unwrap();
let handle = std::thread::spawn({
let pool = pool.clone();
move || pool.checkout_blocking()
});
std::thread::sleep(std::time::Duration::from_millis(50));
drop(primary);
let guard = handle.join().expect("join").expect("checkout");
assert_eq!(*guard, 42);
drop(guard);
assert_eq!(pool.available(), 1);
}
#[test]
fn test_pool_error_display() {
assert_eq!(format!("{}", PoolError::Closed), "session pool is closed");
}
// ort_err keeps the message of the value it wraps.
#[test]
fn test_ort_err() {
let e = ort_err("test ort error");
assert_eq!(format!("{e}"), "test ort error");
}
// --- Engine loading failures ---
// A missing or empty model directory yields GigasttError::ModelLoad.
#[test]
fn test_engine_load_missing_dir() {
let result = Engine::load_with_pool_size("/nonexistent/path/for/tests", 1);
assert!(matches!(result, Err(GigasttError::ModelLoad { .. })));
}
#[test]
fn test_engine_load_empty_dir() {
let dir = tempfile::tempdir().unwrap();
let result = Engine::load_with_pool_size(dir.path().to_str().unwrap(), 1);
assert!(matches!(result, Err(GigasttError::ModelLoad { .. })));
}
// --- Pool: async checkout ---
// Dropping the guard at the end of a scope returns the item.
#[tokio::test]
async fn test_pool_guard_returns_triplet_on_normal_drop() {
let pool = Pool::new(vec![1u32, 2, 3]);
assert_eq!(pool.available(), 3);
{
let _guard = pool.checkout().await.expect("checkout");
assert_eq!(pool.available(), 2);
}
assert_eq!(pool.available(), 3);
}
// The item is also returned when the task holding the guard panics.
#[tokio::test]
async fn test_pool_guard_returns_triplet_on_panic_unwind() {
let pool = std::sync::Arc::new(Pool::new(vec![1u32]));
assert_eq!(pool.available(), 1);
let pool_clone = pool.clone();
let result = tokio::spawn(async move {
let _guard = pool_clone.checkout().await.expect("checkout");
assert_eq!(pool_clone.available(), 0);
panic!("synthetic inference panic");
})
.await;
assert!(result.is_err(), "spawned task must report the panic");
assert_eq!(pool.available(), 1);
}
// Closing the pool wakes a blocked waiter with PoolError::Closed.
#[tokio::test]
async fn test_pool_close_wakes_waiters_with_closed() {
let pool = std::sync::Arc::new(Pool::<u32>::new(vec![]));
let waiter = tokio::spawn({
let pool = pool.clone();
async move { pool.checkout().await.map(|_g| ()) }
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
pool.close();
let res = waiter.await.expect("join");
assert!(matches!(res, Err(PoolError::Closed)));
}
// Three waiters queued 20 ms apart are served in the order they queued.
#[tokio::test]
async fn test_pool_fifo_under_contention() {
let pool = std::sync::Arc::new(Pool::new(vec![0u32]));
let primary = pool.checkout().await.expect("primary checkout");
assert_eq!(pool.available(), 0);
let waker_log = std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));
let mut handles = Vec::new();
for id in 0u32..3 {
let pool = pool.clone();
let log = waker_log.clone();
handles.push(tokio::spawn(async move {
let g = pool.checkout().await.expect("checkout");
log.lock().await.push(id);
drop(g);
}));
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
drop(primary);
for h in handles {
h.await.expect("join");
}
let log = waker_log.lock().await.clone();
assert_eq!(log, vec![0, 1, 2], "waiters must wake in FIFO order");
}
// --- Pool: owned reservations ---
// An owned reservation can move into spawn_blocking and be checked in there.
#[tokio::test]
async fn test_into_owned_for_spawn_blocking() {
let pool = std::sync::Arc::new(Pool::new(vec![String::from("triplet")]));
let guard = pool.checkout().await.expect("checkout");
let reservation = guard.into_owned();
let result = tokio::task::spawn_blocking(move || {
assert_eq!(*reservation, "triplet");
reservation.checkin();
"done"
})
.await
.expect("join");
assert_eq!(pool.available(), 1);
assert_eq!(result, "done");
}
// A panic inside spawn_blocking still returns the reserved item.
#[tokio::test]
async fn test_owned_reservation_returns_on_spawn_blocking_panic() {
let pool = std::sync::Arc::new(Pool::new(vec![String::from("triplet")]));
let guard = pool.checkout().await.expect("checkout");
let reservation = guard.into_owned();
let result = tokio::task::spawn_blocking(move || {
let _reservation = reservation;
panic!("simulated inference panic");
})
.await;
assert!(result.is_err(), "spawn_blocking must report the panic");
assert_eq!(
pool.available(),
1,
"reservation must be returned after panic"
);
}
// Dropping the reservation without an explicit checkin returns the item.
#[tokio::test]
async fn test_owned_reservation_drop_returns_item() {
let pool = std::sync::Arc::new(Pool::new(vec![String::from("triplet")]));
let guard = pool.checkout().await.expect("checkout");
let reservation = guard.into_owned();
tokio::task::spawn_blocking(move || {
let _reservation = reservation;
})
.await
.expect("join");
assert_eq!(pool.available(), 1);
}
// Calling close twice must not panic.
#[tokio::test]
async fn test_pool_close_is_idempotent() {
let pool = Pool::<u32>::new(vec![]);
pool.close();
pool.close();
}
// Two tasks blocked on an empty pool are both counted as waiters.
#[tokio::test]
async fn test_pool_waiters_count() {
let pool = std::sync::Arc::new(Pool::<u32>::new(vec![]));
let w1 = tokio::spawn({
let p = pool.clone();
async move { p.checkout().await.map(|_| ()) }
});
let w2 = tokio::spawn({
let p = pool.clone();
async move { p.checkout().await.map(|_| ()) }
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
assert_eq!(pool.waiters(), 2, "both blocked tasks must be waiters");
pool.close();
let _ = w1.await;
let _ = w2.await;
}
// A reservation can be mutated in a blocking task, handed back, and then
// dropped to return the item.
#[tokio::test]
async fn test_owned_reservation_round_trip_through_option() {
let pool = std::sync::Arc::new(Pool::new(vec![42u32]));
let guard = pool.checkout().await.expect("checkout");
let mut reservation: Option<OwnedReservation<u32>> = Some(guard.into_owned());
let (res_back, val) = tokio::task::spawn_blocking(move || {
let mut r = reservation.take().unwrap();
*r += 1;
let v = *r;
(r, v)
})
.await
.expect("join");
reservation = Some(res_back);
assert_eq!(val, 43);
drop(reservation);
assert_eq!(pool.available(), 1);
}
// --- Pool: cancelled / timed-out waiters ---
// An aborted waiter stays queued until the next return, which skips it and
// puts the item back in the pool.
#[tokio::test]
async fn test_pool_slot_not_leaked_on_cancelled_checkout() {
let pool = std::sync::Arc::new(Pool::new(vec![42u32]));
let primary = pool.checkout().await.expect("checkout");
let aborted = tokio::spawn({
let pool = pool.clone();
async move { pool.checkout().await }
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
aborted.abort();
let _ = aborted.await;
assert_eq!(pool.waiters(), 1);
drop(primary);
assert_eq!(pool.available(), 1, "item must return to pool, not leak");
assert_eq!(pool.waiters(), 0, "dead waiter must be removed");
}
// Same as above, but the waiter is abandoned by a timeout.
#[tokio::test]
async fn test_pool_slot_not_leaked_on_timeout_checkout() {
let pool = std::sync::Arc::new(Pool::new(vec![42u32]));
let primary = pool.checkout().await.expect("checkout");
let result =
tokio::time::timeout(std::time::Duration::from_millis(10), pool.checkout()).await;
assert!(result.is_err(), "checkout must time out");
assert_eq!(pool.waiters(), 1);
drop(primary);
assert_eq!(
pool.available(),
1,
"item must return to pool after timeout"
);
assert_eq!(pool.waiters(), 0, "dead waiter must be removed");
}
// Several dead waiters in a row are all skipped on a single return.
#[tokio::test]
async fn test_pool_multiple_dead_waiters_are_skipped() {
let pool = std::sync::Arc::new(Pool::new(vec![0u32]));
let primary = pool.checkout().await.expect("checkout");
let mut handles = Vec::new();
for _ in 0..3 {
handles.push(tokio::spawn({
let pool = pool.clone();
async move { pool.checkout().await }
}));
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
for h in handles {
h.abort();
let _ = h.await;
}
assert_eq!(pool.waiters(), 3);
drop(primary);
assert_eq!(
pool.available(),
1,
"item returned after skipping 3 dead waiters"
);
assert_eq!(pool.waiters(), 0);
}
// --- TranscriptAssembler ---
// finalize joins the words with a space, marks the segment final and
// leaves the assembler empty.
#[test]
fn test_transcript_assembler_append_and_finalize() {
let mut asm = TranscriptAssembler::new();
assert!(asm.is_empty());
asm.append(vec![
WordInfo {
word: "hello".into(),
start: 0.0,
end: 0.5,
confidence: 0.9,
speaker: None,
},
WordInfo {
word: "world".into(),
start: 0.6,
end: 1.0,
confidence: 0.85,
speaker: None,
},
]);
assert!(!asm.is_empty());
let seg = asm.finalize(1.0);
assert_eq!(seg.text, "hello world");
assert_eq!(seg.words.len(), 2);
assert!(seg.is_final);
assert_eq!(seg.timestamp, 1.0);
assert!(asm.is_empty());
}
// partial produces a non-final segment and keeps the assembler's words.
#[test]
fn test_transcript_assembler_partial() {
let mut asm = TranscriptAssembler::new();
asm.append(vec![WordInfo {
word: "partial".into(),
start: 0.0,
end: 0.3,
confidence: 0.8,
speaker: None,
}]);
let seg = asm.partial(0.3);
assert_eq!(seg.text, "partial");
assert!(!seg.is_final);
assert!(!asm.is_empty());
}
// Empty input yields a single all-zero frame of N_MELS values.
#[test]
fn test_feature_extractor_compute_empty() {
let fe = FeatureExtractor::new();
let (mel, frames) = fe.compute(&[]);
assert_eq!(mel.len(), N_MELS);
assert_eq!(frames, 1);
assert!(mel.iter().all(|&v| v == 0.0));
}
// --- now_timestamp ---
#[test]
fn test_now_timestamp_non_negative() {
let ts = now_timestamp();
assert!(ts >= 0.0, "timestamp must be non-negative");
}
// Consecutive calls never go backwards and the value is in Unix-epoch range.
#[test]
fn test_now_timestamp_monotonic_and_epoch_aligned() {
let a = now_timestamp();
let b = now_timestamp();
assert!(
b >= a,
"now_timestamp must be non-decreasing (monotonic anchor)"
);
assert!(
a > 1_700_000_000.0,
"timestamp must stay Unix-epoch aligned"
);
assert!(a < 4_000_000_000.0, "timestamp exceeds a sane upper bound");
}
// --- probe_or_rebuild ---
// A passing probe returns the original state and never calls rebuild.
#[test]
fn test_probe_or_rebuild_keeps_state_when_probe_passes() {
let rebuilt = std::cell::Cell::new(false);
let result = probe_or_rebuild(
7u32,
|v| {
assert_eq!(*v, 7);
Ok(())
},
|_, _| {
rebuilt.set(true);
Ok(99)
},
)
.expect("healthy state must survive unchanged");
assert_eq!(result, 7);
assert!(!rebuilt.get(), "rebuild must not run when the probe passes");
}
// A failing probe hands the old state and the probe error to rebuild; the
// rebuilt state is returned once it passes the probe.
#[test]
fn test_probe_or_rebuild_rebuilds_when_probe_fails() {
let result = probe_or_rebuild(
1u32,
|v| {
if *v == 1 {
Err(anyhow::anyhow!("first probe failed"))
} else {
Ok(())
}
},
|old, probe_err| {
assert_eq!(old, 1, "rebuild receives the failed state");
assert!(
probe_err.to_string().contains("first probe failed"),
"rebuild receives the probe error for logging"
);
Ok(2)
},
)
.expect("rebuilt state passing the probe is OK");
assert_eq!(result, 2);
}
// An error from rebuild is returned to the caller.
#[test]
fn test_probe_or_rebuild_propagates_rebuild_error() {
let result = probe_or_rebuild(
1u32,
|_| Err(anyhow::anyhow!("probe failed")),
|_, _| Err(anyhow::anyhow!("rebuild failed")),
);
let err = result.expect_err("rebuild failure must be fatal");
assert!(err.to_string().contains("rebuild failed"));
}
// A rebuilt state that still fails the probe is an error.
#[test]
fn test_probe_or_rebuild_fails_when_rebuilt_state_fails_probe() {
let result = probe_or_rebuild(
1u32,
|_| Err(anyhow::anyhow!("always fails")),
|_, _| Ok(2u32),
);
assert!(
result.is_err(),
"a rebuilt state that still fails the probe must be a hard error"
);
}
// --- encoder_model_path ---
// For both model variants the `_int8` encoder file is chosen when it exists,
// otherwise the plain encoder file is used.
#[test]
fn test_encoder_model_path_prefers_int8_when_present() {
let dir = tempfile::tempdir().expect("tempdir");
std::fs::write(dir.path().join("v3_e2e_rnnt_encoder.onnx"), b"fp32").unwrap();
std::fs::write(dir.path().join("v3_e2e_rnnt_encoder_int8.onnx"), b"int8").unwrap();
let path = Engine::encoder_model_path(dir.path(), ModelVariant::E2eRnnt);
assert_eq!(
path.file_name().unwrap(),
"v3_e2e_rnnt_encoder_int8.onnx",
"INT8 encoder must win when both files exist"
);
}
#[test]
fn test_encoder_model_path_falls_back_to_fp32() {
let dir = tempfile::tempdir().expect("tempdir");
std::fs::write(dir.path().join("v3_e2e_rnnt_encoder.onnx"), b"fp32").unwrap();
let path = Engine::encoder_model_path(dir.path(), ModelVariant::E2eRnnt);
assert_eq!(path.file_name().unwrap(), "v3_e2e_rnnt_encoder.onnx");
}
#[test]
fn test_encoder_model_path_rnnt_prefers_int8() {
let dir = tempfile::tempdir().expect("tempdir");
std::fs::write(dir.path().join("v3_rnnt_encoder.onnx"), b"fp32").unwrap();
std::fs::write(dir.path().join("v3_rnnt_encoder_int8.onnx"), b"int8").unwrap();
let path = Engine::encoder_model_path(dir.path(), ModelVariant::Rnnt);
assert_eq!(
path.file_name().unwrap(),
"v3_rnnt_encoder_int8.onnx",
"INT8 rnnt encoder must win when both files exist"
);
}
#[test]
fn test_encoder_model_path_rnnt_falls_back_to_fp32() {
let dir = tempfile::tempdir().expect("tempdir");
std::fs::write(dir.path().join("v3_rnnt_encoder.onnx"), b"fp32").unwrap();
let path = Engine::encoder_model_path(dir.path(), ModelVariant::Rnnt);
assert_eq!(path.file_name().unwrap(), "v3_rnnt_encoder.onnx");
}
// Checking out every item while holding none back visits each item once.
#[test]
fn test_pool_sequential_checkouts_visit_every_item() {
let pool = Pool::new(vec![1u32, 2, 3]);
let mut seen = Vec::new();
for _ in 0..pool.total() {
let guard = pool.checkout_blocking().expect("checkout");
seen.push(*guard);
}
seen.sort_unstable();
assert_eq!(seen, vec![1, 2, 3]);
}
// Ignored by default because it needs the model files on disk; checks that
// warmup succeeds and leaves every triplet back in the pool.
#[test]
#[ignore = "requires model"]
fn test_warmup_runs_silent_inference_on_every_triplet() {
let engine = Engine::load_with_pool_size(&crate::model::default_model_dir(), 2)
.expect("engine should load");
engine
.warmup()
.expect("warmup must succeed on a working engine");
assert_eq!(
engine.pool.available(),
engine.pool.total(),
"every triplet must be returned to the pool after warmup"
);
}
}