pub mod audio;
mod decode;
mod features;
#[cfg(not(feature = "__internals"))]
mod tokenizer;
#[cfg(feature = "__internals")]
pub mod tokenizer;
#[cfg(feature = "diarization")]
use polyvoice::streaming::StreamingPipeline;
#[cfg(feature = "diarization")]
use polyvoice::{
ClusterConfig, DiarizationConfig as DiaConfig, EmbeddingError, EmbeddingExtractor, EnergyVad,
OnnxEmbeddingExtractor, Pipeline, VadConfig,
};
/// Embedding length passed to `OnnxEmbeddingExtractor::new` for the speaker
/// encoder (presumably the output dimension of `wespeaker_resnet34.onnx` — confirm).
#[cfg(feature = "diarization")]
const SPEAKER_EMBEDDING_DIM: usize = 256;
/// Segment length (in samples) passed to `OnnxEmbeddingExtractor::new`;
/// 24000 samples is 1.5 s assuming the 16 kHz rate used throughout this file.
#[cfg(feature = "diarization")]
const SPEAKER_SEGMENT_SAMPLES: usize = 24000;
/// Pool-size argument passed to `OnnxEmbeddingExtractor::new` (presumably the
/// number of concurrent ONNX sessions it keeps — verify against polyvoice).
#[cfg(feature = "diarization")]
const SPEAKER_POOL_SIZE: usize = 4;
/// Newtype that lets an `Arc`-shared [`OnnxEmbeddingExtractor`] be handed to
/// APIs expecting an owned `EmbeddingExtractor` value (e.g. one per streaming
/// diarization pipeline) without reloading the model.
#[cfg(feature = "diarization")]
pub struct SharedExtractor(std::sync::Arc<OnnxEmbeddingExtractor>);
#[cfg(feature = "diarization")]
impl EmbeddingExtractor for SharedExtractor {
    /// Forwards to the wrapped extractor.
    fn extract(&self, samples: &[f32], config: &DiaConfig) -> Result<Vec<f32>, EmbeddingError> {
        let Self(shared) = self;
        shared.extract(samples, config)
    }

    /// Forwards to the wrapped extractor.
    fn embedding_dim(&self) -> usize {
        let Self(shared) = self;
        shared.embedding_dim()
    }
}
// `Engine::load_sessions` picks exactly one execution-provider branch per build
// (CoreML, CUDA, or CPU); with both features enabled the cfg'd branches would
// conflict, so the build is rejected up front with a clear message.
#[cfg(all(feature = "coreml", feature = "cuda"))]
compile_error!("Features `coreml` and `cuda` are mutually exclusive. Choose one.");
use anyhow::Context;
#[cfg(any(feature = "coreml", feature = "cuda"))]
use ort::ep;
use ort::session::Session;
use ort::value::TensorRef;
use serde::Serialize;
use std::ops::{Deref, DerefMut};
use std::path::Path;
use crate::error::GigasttError;
use features::MelSpectrogram;
use tokenizer::Tokenizer;
/// Number of mel bins per feature frame; the encoder input tensor is shaped
/// `[1, N_MELS, num_frames]`.
pub const N_MELS: usize = 64;
/// FFT window length in samples; streaming decodes are skipped while the
/// buffered audio is shorter than this.
pub const N_FFT: usize = 320;
/// Hop between consecutive mel frames, in samples (also used to derive
/// encoder-frame timing together with `ENCODER_SUBSAMPLING`).
pub const HOP_LENGTH: usize = 160;
/// Length of the prediction-network hidden (`h`) and cell (`c`) state vectors
/// held in [`DecoderState`].
pub const PRED_HIDDEN: usize = 320;
fn ort_err(e: impl std::fmt::Display) -> anyhow::Error {
anyhow::anyhow!("{e}")
}
/// Current wall-clock time as fractional seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, a warning is logged
/// and `0.0` is returned rather than failing.
pub fn now_timestamp() -> f64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|since_epoch| since_epoch.as_secs_f64())
        .unwrap_or_else(|e| {
            tracing::warn!("System clock is before Unix epoch: {e}");
            0.0
        })
}
/// Time-reduction factor between mel frames and encoder output frames.
const ENCODER_SUBSAMPLING: usize = 4;
/// Duration of one encoder output frame in seconds (160 * 4 / 16000 = 0.04 s),
/// assuming 16 kHz audio.
const SECONDS_PER_FRAME: f64 = (HOP_LENGTH as f64 * ENCODER_SUBSAMPLING as f64) / 16000.0;
/// Streaming window cap (5 s at 16 kHz): reaching it forces a decode and
/// finalizes the current segment.
const STREAM_MAX_WINDOW_SAMPLES: usize = 16000 * 5;
/// Audio kept (1.5 s at 16 kHz) as left context after a segment is finalized;
/// words starting inside it are filtered out of the next window.
const STREAM_LEFT_CONTEXT_SAMPLES: usize = 16000 * 3 / 2;
/// Number of newly received samples (0.8 s at 16 kHz) that triggers a re-decode
/// of the streaming window.
const STREAM_DECODE_STRIDE_SAMPLES: usize = 16000 * 4 / 5;
/// Session triplets created by [`Engine::load`]: 1 on Android.
#[cfg(target_os = "android")]
const DEFAULT_POOL_SIZE: usize = 1;
/// Session triplets created by [`Engine::load`]: 4 on non-Android targets.
#[cfg(not(target_os = "android"))]
const DEFAULT_POOL_SIZE: usize = 4;
/// One set of RNN-T ONNX sessions (encoder, prediction network, joint network)
/// that a single transcription uses exclusively while it is checked out of the
/// [`SessionPool`].
pub struct SessionTriplet {
    /// Acoustic encoder run over the mel features.
    pub(crate) encoder: Session,
    /// Prediction network (`v3_e2e_rnnt_decoder.onnx`) driven by greedy decoding.
    pub(crate) decoder: Session,
    /// Joint network (`v3_e2e_rnnt_joint.onnx`) driven by greedy decoding.
    pub(crate) joiner: Session,
}
/// Error returned when an item cannot be checked out of a [`Pool`].
#[derive(Debug)]
pub enum PoolError {
    /// The pool has been closed with [`Pool::close`].
    Closed,
}
impl std::fmt::Display for PoolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            Self::Closed => "session pool is closed",
        };
        f.write_str(description)
    }
}
impl std::error::Error for PoolError {}
/// Pool of reusable items that can be checked out from async code
/// ([`Pool::checkout`]) or from blocking code ([`Pool::checkout_blocking`]);
/// each checkout yields a [`PoolGuard`] that returns the item when dropped.
pub struct Pool<T> {
    inner: std::sync::Arc<PoolInner<T>>,
}
/// State shared between a [`Pool`] and the guards/reservations it hands out.
struct PoolInner<T> {
    // Idle items: checkouts pop from the front, check-ins push to the back.
    items: parking_lot::Mutex<std::collections::VecDeque<T>>,
    // Callers parked waiting for an item; check-ins serve them front-first.
    waiters: parking_lot::Mutex<std::collections::VecDeque<Waiter<T>>>,
    // Set by `Pool::close`; afterwards checkouts fail and returned items are dropped.
    closed: std::sync::atomic::AtomicBool,
    // Number of items the pool was created with (reported by `Pool::total`).
    total: usize,
}
/// Sending half used to hand an item directly to a parked checkout.
enum Waiter<T> {
    /// Registered by [`Pool::checkout`].
    Async(tokio::sync::oneshot::Sender<T>),
    /// Registered by [`Pool::checkout_blocking`].
    Blocking(std::sync::mpsc::Sender<T>),
}
/// Pool of encoder/decoder/joiner session triplets owned by [`Engine`].
pub type SessionPool = Pool<SessionTriplet>;
impl<T: Send> Pool<T> {
pub fn new(items: Vec<T>) -> Self {
let total = items.len();
Self {
inner: std::sync::Arc::new(PoolInner {
items: parking_lot::Mutex::new(std::collections::VecDeque::from(items)),
waiters: parking_lot::Mutex::new(std::collections::VecDeque::new()),
closed: std::sync::atomic::AtomicBool::new(false),
total,
}),
}
}
pub async fn checkout(&self) -> Result<PoolGuard<T>, PoolError> {
{
let mut items = self.inner.items.lock();
if self.inner.closed.load(std::sync::atomic::Ordering::SeqCst) {
return Err(PoolError::Closed);
}
if let Some(item) = items.pop_front() {
return Ok(PoolGuard::new(self.inner.clone(), item));
}
}
let (tx, rx) = tokio::sync::oneshot::channel();
{
let mut waiters = self.inner.waiters.lock();
if self.inner.closed.load(std::sync::atomic::Ordering::SeqCst) {
return Err(PoolError::Closed);
}
let mut items = self.inner.items.lock();
if let Some(item) = items.pop_front() {
drop(items);
drop(waiters);
return Ok(PoolGuard::new(self.inner.clone(), item));
}
waiters.push_back(Waiter::Async(tx));
}
match rx.await {
Ok(item) => Ok(PoolGuard::new(self.inner.clone(), item)),
Err(_) => Err(PoolError::Closed),
}
}
pub fn checkout_blocking(&self) -> Result<PoolGuard<T>, PoolError> {
{
let mut items = self.inner.items.lock();
if self.inner.closed.load(std::sync::atomic::Ordering::SeqCst) {
return Err(PoolError::Closed);
}
if let Some(item) = items.pop_front() {
return Ok(PoolGuard::new(self.inner.clone(), item));
}
}
let (tx, rx) = std::sync::mpsc::channel();
{
let mut waiters = self.inner.waiters.lock();
if self.inner.closed.load(std::sync::atomic::Ordering::SeqCst) {
return Err(PoolError::Closed);
}
let mut items = self.inner.items.lock();
if let Some(item) = items.pop_front() {
drop(items);
drop(waiters);
return Ok(PoolGuard::new(self.inner.clone(), item));
}
waiters.push_back(Waiter::Blocking(tx));
}
match rx.recv() {
Ok(item) => Ok(PoolGuard::new(self.inner.clone(), item)),
Err(_) => Err(PoolError::Closed),
}
}
pub fn close(&self) {
self.inner
.closed
.store(true, std::sync::atomic::Ordering::SeqCst);
let mut waiters = self.inner.waiters.lock();
waiters.clear();
}
pub fn total(&self) -> usize {
self.inner.total
}
pub fn available(&self) -> usize {
let items = self.inner.items.lock();
items.len()
}
pub fn waiters(&self) -> usize {
let waiters = self.inner.waiters.lock();
waiters.len()
}
}
impl<T> PoolInner<T> {
    /// Returns `item` to the pool. It is handed directly to the oldest parked
    /// waiter whose receiver is still alive, or queued as idle when nobody is
    /// waiting. After the pool is closed, returned items are dropped.
    fn checkin(&self, mut item: T) {
        if self.closed.load(std::sync::atomic::Ordering::SeqCst) {
            return;
        }
        loop {
            let mut waiters = self.waiters.lock();
            let Some(waiter) = waiters.pop_front() else {
                // Queue the item while still holding `waiters` (lock order `waiters` ->
                // `items`, as in the checkout slow path). Releasing `waiters` first lets
                // a checkout see an empty `items`, park itself, and then wait forever
                // beside the idle item pushed here.
                self.items.lock().push_back(item);
                return;
            };
            // Deliver outside the waiter lock.
            drop(waiters);
            let undelivered = match waiter {
                Waiter::Async(tx) => tx.send(item).err(),
                Waiter::Blocking(tx) => tx
                    .send(item)
                    .err()
                    .map(|std::sync::mpsc::SendError(returned)| returned),
            };
            match undelivered {
                // That waiter gave up (its receiver was dropped); try the next one.
                Some(returned) => item = returned,
                None => return,
            }
        }
    }
}
/// Guard for an item checked out of a [`Pool`]: dereferences to the item and
/// checks it back in when dropped (unless converted with `into_owned`).
pub struct PoolGuard<T> {
    // Pool the item is returned to; `None` only after the guard was consumed.
    inner: Option<std::sync::Arc<PoolInner<T>>>,
    // The checked-out item; `None` only after it was moved out.
    item: Option<T>,
}
impl<T> PoolGuard<T> {
    /// Wraps `item`, which will be returned to `inner` when the guard drops.
    fn new(inner: std::sync::Arc<PoolInner<T>>, item: T) -> Self {
        let (inner, item) = (Some(inner), Some(item));
        Self { inner, item }
    }

    /// Converts the guard into an [`OwnedReservation`] that still returns the
    /// item to the pool, either via `OwnedReservation::checkin` or on drop.
    pub fn into_owned(mut self) -> OwnedReservation<T> {
        let Some(item) = self.item.take() else {
            unreachable!("PoolGuard::into_owned called after drop")
        };
        let inner = self.inner.take().unwrap();
        // Both fields are now `None`, so dropping `self` checks nothing in.
        OwnedReservation {
            item: Some(item),
            inner,
        }
    }
}
impl<T> Deref for PoolGuard<T> {
    type Target = T;
    /// Borrows the checked-out item.
    fn deref(&self) -> &Self::Target {
        match &self.item {
            Some(held) => held,
            None => unreachable!("PoolGuard accessed after item taken"),
        }
    }
}
impl<T> DerefMut for PoolGuard<T> {
    /// Mutably borrows the checked-out item.
    fn deref_mut(&mut self) -> &mut Self::Target {
        match &mut self.item {
            Some(held) => held,
            None => unreachable!("PoolGuard accessed after item taken"),
        }
    }
}
impl<T> Drop for PoolGuard<T> {
    /// Returns the item to its pool, unless it was already moved out.
    fn drop(&mut self) {
        let parts = (self.inner.take(), self.item.take());
        if let (Some(pool), Some(held)) = parts {
            pool.checkin(held);
        }
    }
}
/// Pool item obtained through [`PoolGuard::into_owned`]; dereferences to the
/// item and returns it to the pool on `checkin` or when dropped.
pub struct OwnedReservation<T> {
    // Pool the item is returned to.
    inner: std::sync::Arc<PoolInner<T>>,
    // The reserved item; `None` only after it was checked in.
    item: Option<T>,
}
impl<T> std::ops::Deref for OwnedReservation<T> {
    type Target = T;
    /// Borrows the reserved item.
    fn deref(&self) -> &Self::Target {
        match &self.item {
            Some(reserved) => reserved,
            None => unreachable!("OwnedReservation accessed after checkin"),
        }
    }
}
impl<T> std::ops::DerefMut for OwnedReservation<T> {
    /// Mutably borrows the reserved item.
    fn deref_mut(&mut self) -> &mut Self::Target {
        match &mut self.item {
            Some(reserved) => reserved,
            None => unreachable!("OwnedReservation accessed after checkin"),
        }
    }
}
impl<T> OwnedReservation<T> {
    /// Returns the item to the pool now. Equivalent to dropping the
    /// reservation: the `Drop` impl performs the check-in.
    pub fn checkin(self) {
        drop(self);
    }
}
impl<T> Drop for OwnedReservation<T> {
    /// Returns the item to its pool if it has not been checked in already.
    fn drop(&mut self) {
        let Some(reserved) = self.item.take() else {
            return;
        };
        self.inner.checkin(reserved);
    }
}
#[non_exhaustive]
pub struct DecoderState {
pub h: Vec<f32>,
pub c: Vec<f32>,
pub prev_token: i64,
pub consecutive_blanks: usize,
}
impl DecoderState {
pub fn new(blank_id: usize) -> Self {
Self {
h: vec![0.0; PRED_HIDDEN],
c: vec![0.0; PRED_HIDDEN],
prev_token: blank_id as i64,
consecutive_blanks: 0,
}
}
}
/// One recognized word with timing, confidence and optional speaker label.
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
pub struct WordInfo {
    /// Word text with the SentencePiece word-boundary marker stripped.
    pub word: String,
    /// Start time in seconds: (first token frame + frame offset) * `SECONDS_PER_FRAME`.
    pub start: f64,
    /// End time in seconds: (last token frame + frame offset) * `SECONDS_PER_FRAME`.
    pub end: f64,
    /// Mean confidence of the word's tokens (1.0 if none were recorded).
    pub confidence: f32,
    /// Speaker id from diarization, if any; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speaker: Option<u32>,
}
/// Per-stream state for incremental transcription via `Engine::process_chunk`.
#[non_exhaustive]
pub struct StreamingState {
    /// Prediction-network state created by `Engine::create_state`.
    /// NOTE(review): `decode_window` builds a fresh `DecoderState` per window and
    /// does not read this field in the code shown — confirm whether it is used.
    pub decoder: DecoderState,
    /// Audio of the current decode window (left context plus new samples).
    pub audio_buffer: Vec<f32>,
    /// Words/text of the current not-yet-finalized segment.
    pub assembler: TranscriptAssembler,
    /// Absolute stream position (in samples) of `audio_buffer[0]`.
    pub window_start_samples: usize,
    /// Leading samples of `audio_buffer` that are left context from the previous
    /// finalized segment; words starting inside them are discarded.
    pub context_samples: usize,
    /// Samples received since the last decode.
    pub pending_samples: usize,
    /// Optional resampler (not used by the code shown; presumably driven by
    /// callers to convert input to 16 kHz — TODO confirm).
    pub resampler: Option<rubato::Async<f32>>,
    /// Reusable FFT scratch buffer for `FeatureExtractor::compute_mel`.
    pub mel_fft_input: Vec<rustfft::num_complex::Complex<f32>>,
    /// Reusable power-spectrum scratch buffer for `FeatureExtractor::compute_mel`.
    pub mel_power: Vec<f32>,
    /// Mel features of the most recent decode window.
    pub mel_output: Vec<f32>,
    /// Reusable resampler output buffer (not used by the code shown).
    pub resample_output_buf: Vec<f32>,
    /// Streaming diarization pipeline fed every chunk, when enabled and available.
    #[cfg(feature = "diarization")]
    pub diarization_state: Option<StreamingPipeline<EnergyVad, SharedExtractor>>,
}
/// Thin wrapper around the mel-spectrogram front end and audio-buffer helpers.
pub struct FeatureExtractor {
    mel: MelSpectrogram,
}
impl Default for FeatureExtractor {
    fn default() -> Self {
        FeatureExtractor::new()
    }
}
impl FeatureExtractor {
    /// Creates a feature extractor with a default-configured mel spectrogram.
    pub fn new() -> Self {
        let mel = MelSpectrogram::new();
        Self { mel }
    }

    /// Delegates to `audio::prepare_audio_buffer` with the given samples and buffer.
    pub fn prepare_buffer(&self, samples: &[f32], audio_buffer: &mut Vec<f32>) -> Option<usize> {
        audio::prepare_audio_buffer(samples, audio_buffer)
    }

    /// Computes mel features of `samples` into `output_buf`, reusing the caller's
    /// scratch buffers, and returns the number of frames produced.
    pub fn compute_mel(
        &self,
        samples: &[f32],
        fft_buf: &mut Vec<rustfft::num_complex::Complex<f32>>,
        power_buf: &mut Vec<f32>,
        output_buf: &mut Vec<f32>,
    ) -> usize {
        let mel = &self.mel;
        mel.compute_with_buffers(samples, fft_buf, power_buf, output_buf)
    }

    /// Computes mel features of `samples`, returning the feature buffer and the
    /// number of frames.
    pub fn compute(&self, samples: &[f32]) -> (Vec<f32>, usize) {
        let mel = &self.mel;
        mel.compute(samples)
    }
}
/// Accumulates the words of the current (not yet finalized) segment together
/// with their space-separated text.
pub struct TranscriptAssembler {
    text: String,
    words: Vec<WordInfo>,
}
impl Default for TranscriptAssembler {
    fn default() -> Self {
        Self::new()
    }
}
impl TranscriptAssembler {
    /// Creates an empty assembler.
    pub fn new() -> Self {
        Self {
            words: Vec::new(),
            text: String::new(),
        }
    }

    /// Appends `new_words`; a single space is inserted before each word whenever
    /// the accumulated text is already non-empty.
    pub fn append(&mut self, new_words: Vec<WordInfo>) {
        for piece in new_words.iter().map(|w| w.word.as_str()) {
            if !self.text.is_empty() {
                self.text.push(' ');
            }
            self.text.push_str(piece);
        }
        self.words.extend(new_words);
    }

    /// Replaces the accumulated words; the text becomes the words joined by
    /// single spaces.
    pub fn set_words(&mut self, words: Vec<WordInfo>) {
        let mut joined = String::new();
        for (index, w) in words.iter().enumerate() {
            if index > 0 {
                joined.push(' ');
            }
            joined.push_str(&w.word);
        }
        self.text = joined;
        self.words = words;
    }

    /// Moves the accumulated text and words into a final segment, leaving the
    /// assembler empty.
    pub fn finalize(&mut self, timestamp: f64) -> TranscriptSegment {
        let text = std::mem::take(&mut self.text);
        let words = std::mem::take(&mut self.words);
        TranscriptSegment {
            text,
            words,
            is_final: true,
            timestamp,
        }
    }

    /// Snapshot of the accumulated text and words as a non-final segment.
    pub fn partial(&self, timestamp: f64) -> TranscriptSegment {
        TranscriptSegment {
            is_final: false,
            timestamp,
            text: self.text.clone(),
            words: self.words.clone(),
        }
    }

    /// `true` when no text has been accumulated.
    pub fn is_empty(&self) -> bool {
        self.text.is_empty()
    }
}
/// Runs `probe` on `state`; if it fails, replaces the state with
/// `rebuild(state, probe_error)` and probes once more.
///
/// # Errors
/// Propagates a `rebuild` error, or the second probe's error wrapped with
/// "state failed probe even after rebuild".
// Only called from the CoreML load path, hence dead code in other builds.
#[cfg_attr(not(feature = "coreml"), allow(dead_code))]
fn probe_or_rebuild<S>(
    state: S,
    probe: impl Fn(&S) -> anyhow::Result<()>,
    rebuild: impl FnOnce(S, anyhow::Error) -> anyhow::Result<S>,
) -> anyhow::Result<S> {
    let Err(first_failure) = probe(&state) else {
        return Ok(state);
    };
    let replacement = rebuild(state, first_failure)?;
    probe(&replacement).context("state failed probe even after rebuild")?;
    Ok(replacement)
}
/// Speech-to-text engine: a pool of ONNX session triplets plus the tokenizer,
/// feature extractor and (optionally) a speaker encoder for diarization.
pub struct Engine {
    /// Session triplets; each transcription checks one out exclusively.
    pub pool: SessionPool,
    // Vocabulary loaded from `v3_e2e_rnnt_vocab.txt`.
    tokenizer: Tokenizer,
    // Mel-spectrogram front end.
    features: FeatureExtractor,
    // Whether the INT8 encoder file was found and loaded.
    int8: bool,
    /// Shared speaker-embedding model; `None` if `wespeaker_resnet34.onnx` is
    /// missing or failed to load.
    #[cfg(feature = "diarization")]
    pub speaker_encoder: Option<std::sync::Arc<OnnxEmbeddingExtractor>>,
}
impl Engine {
    /// Returns `true` when the INT8-quantized encoder
    /// (`v3_e2e_rnnt_encoder_int8.onnx`) was present at load time and is in use.
    pub fn is_int8(&self) -> bool {
        self.int8
    }

    /// Number of tokens in the loaded vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.tokenizer.vocab_size()
    }

    /// Loads all models from `model_dir` with the platform default pool size
    /// (1 session triplet on Android, 4 elsewhere).
    ///
    /// # Errors
    /// Same as [`Engine::load_with_pool_size`].
    pub fn load(model_dir: &str) -> Result<Self, GigasttError> {
        Self::load_with_pool_size(model_dir, DEFAULT_POOL_SIZE)
    }

    /// Loads the encoder/decoder/joiner sessions, tokenizer and (with the
    /// `diarization` feature) the optional speaker encoder from `model_dir`.
    ///
    /// `pool_size` session triplets are created. A `pool_size` of 0 is treated
    /// as 1, because an empty pool would make every checkout wait forever.
    ///
    /// # Errors
    /// [`GigasttError::ModelLoad`] with `source: None` when neither the INT8 nor
    /// the FP32 encoder file exists, or with the underlying error when loading fails.
    pub fn load_with_pool_size(model_dir: &str, pool_size: usize) -> Result<Self, GigasttError> {
        let dir = Path::new(model_dir);
        // `encoder_model_path` prefers the INT8 file, so either encoder variant suffices.
        if !Self::encoder_model_path(dir).exists() {
            return Err(GigasttError::ModelLoad {
                path: model_dir.to_string(),
                source: None,
            });
        }
        let pool_size = if pool_size == 0 {
            tracing::warn!("pool_size=0 requested; using a single session triplet instead");
            1
        } else {
            pool_size
        };
        Self::load_inner(dir, model_dir, pool_size).map_err(|e| GigasttError::ModelLoad {
            path: model_dir.to_string(),
            source: Some(e.into()),
        })
    }

    /// Encoder model to load: the INT8 file when present, otherwise the FP32 one.
    fn encoder_model_path(dir: &Path) -> std::path::PathBuf {
        if dir.join("v3_e2e_rnnt_encoder_int8.onnx").exists() {
            dir.join("v3_e2e_rnnt_encoder_int8.onnx")
        } else {
            dir.join("v3_e2e_rnnt_encoder.onnx")
        }
    }

    /// Creates one (encoder, decoder, joiner) session triple with the execution
    /// provider selected at compile time: CoreML, CUDA, or CPU.
    fn load_sessions(
        dir: &Path,
        prepacked: &ort::session::builder::PrepackedWeights,
    ) -> anyhow::Result<(Session, Session, Session)> {
        #[cfg(feature = "coreml")]
        {
            Self::load_sessions_coreml(dir, prepacked)
        }
        #[cfg(feature = "cuda")]
        {
            Self::load_sessions_cuda(dir, prepacked)
        }
        #[cfg(not(any(feature = "coreml", feature = "cuda")))]
        {
            Self::load_sessions_cpu(dir, prepacked)
        }
    }

    /// Creates the session triple on the CoreML execution provider (ML Program
    /// format, CPU + Neural Engine, model cache in `<dir>/coreml_cache`), with
    /// the CPU provider registered as fallback.
    #[cfg(feature = "coreml")]
    fn load_sessions_coreml(
        dir: &Path,
        prepacked: &ort::session::builder::PrepackedWeights,
    ) -> anyhow::Result<(Session, Session, Session)> {
        let cache_dir = dir.join("coreml_cache");
        let coreml_ep = ep::CoreML::default()
            .with_model_format(ep::coreml::ModelFormat::MLProgram)
            .with_static_input_shapes(true)
            .with_compute_units(ep::coreml::ComputeUnits::CPUAndNeuralEngine)
            .with_specialization_strategy(ep::coreml::SpecializationStrategy::FastPrediction)
            .with_model_cache_dir(cache_dir.to_string_lossy())
            .build();
        let cpu_fallback = ort::execution_providers::CPUExecutionProvider::default();
        let eps = [coreml_ep.clone(), cpu_fallback.into()];
        let encoder_path = Self::encoder_model_path(dir);
        let encoder = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_execution_providers(&eps)
            .map_err(ort_err)?
            .commit_from_file(&encoder_path)
            .map_err(ort_err)?;
        let decoder = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_execution_providers(&eps)
            .map_err(ort_err)?
            .commit_from_file(dir.join("v3_e2e_rnnt_decoder.onnx"))
            .map_err(ort_err)?;
        let joiner = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_execution_providers(&eps)
            .map_err(ort_err)?
            .commit_from_file(dir.join("v3_e2e_rnnt_joint.onnx"))
            .map_err(ort_err)?;
        Ok((encoder, decoder, joiner))
    }

    /// Creates the session triple on the CUDA execution provider, with the CPU
    /// provider registered as fallback.
    #[cfg(feature = "cuda")]
    fn load_sessions_cuda(
        dir: &Path,
        prepacked: &ort::session::builder::PrepackedWeights,
    ) -> anyhow::Result<(Session, Session, Session)> {
        let cuda_ep = ep::CUDA::default().build();
        let cpu_fallback = ort::execution_providers::CPUExecutionProvider::default();
        let eps = [cuda_ep.clone(), cpu_fallback.into()];
        let encoder_path = Self::encoder_model_path(dir);
        let encoder = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_execution_providers(&eps)
            .map_err(ort_err)?
            .commit_from_file(&encoder_path)
            .map_err(ort_err)?;
        let decoder = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_execution_providers(&eps)
            .map_err(ort_err)?
            .commit_from_file(dir.join("v3_e2e_rnnt_decoder.onnx"))
            .map_err(ort_err)?;
        let joiner = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_execution_providers(&eps)
            .map_err(ort_err)?
            .commit_from_file(dir.join("v3_e2e_rnnt_joint.onnx"))
            .map_err(ort_err)?;
        Ok((encoder, decoder, joiner))
    }

    /// Creates the session triple on the CPU execution provider with one
    /// intra-op and one inter-op thread per session (presumably so concurrency
    /// comes from the session pool rather than from inside a session).
    ///
    /// NOTE(review): `with_optimized_model_path` is documented by ONNX Runtime as
    /// the path where the optimized graph is *written*; nothing visible here reads
    /// `optimized_cache/encoder_optimized.onnx` back, and with `pool_size > 1` each
    /// loader thread writes the same file. Confirm the intended use of this "cache".
    #[cfg(not(feature = "cuda"))]
    fn load_sessions_cpu(
        dir: &Path,
        prepacked: &ort::session::builder::PrepackedWeights,
    ) -> anyhow::Result<(Session, Session, Session)> {
        let cache_dir = dir.join("optimized_cache");
        std::fs::create_dir_all(&cache_dir)
            .with_context(|| format!("Failed to create ONNX cache dir: {}", cache_dir.display()))?;
        let cpu_fallback = ort::execution_providers::CPUExecutionProvider::default();
        let eps = [cpu_fallback.into()];
        let encoder_path = Self::encoder_model_path(dir);
        let encoder = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_intra_threads(1)
            .map_err(ort_err)?
            .with_inter_threads(1)
            .map_err(ort_err)?
            .with_optimized_model_path(cache_dir.join("encoder_optimized.onnx"))
            .map_err(ort_err)?
            .with_execution_providers(&eps)
            .map_err(ort_err)?
            .commit_from_file(&encoder_path)
            .map_err(ort_err)?;
        let decoder = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_intra_threads(1)
            .map_err(ort_err)?
            .with_inter_threads(1)
            .map_err(ort_err)?
            .commit_from_file(dir.join("v3_e2e_rnnt_decoder.onnx"))
            .map_err(ort_err)?;
        let joiner = Session::builder()
            .map_err(ort_err)?
            .with_prepacked_weights(prepacked)
            .map_err(ort_err)?
            .with_intra_threads(1)
            .map_err(ort_err)?
            .with_inter_threads(1)
            .map_err(ort_err)?
            .commit_from_file(dir.join("v3_e2e_rnnt_joint.onnx"))
            .map_err(ort_err)?;
        Ok((encoder, decoder, joiner))
    }

    /// Builds `pool_size` session triplets in parallel on scoped threads, all
    /// sharing `prepacked` weights, using `load_one` for each triple.
    ///
    /// Panics inside a loader thread are converted into errors. If any triplet
    /// fails, the first error in thread-index order is returned.
    fn load_triplets(
        dir: &Path,
        pool_size: usize,
        prepacked: &ort::session::builder::PrepackedWeights,
        load_one: impl Fn(
            &Path,
            &ort::session::builder::PrepackedWeights,
        ) -> anyhow::Result<(Session, Session, Session)>
        + Sync,
    ) -> anyhow::Result<Vec<SessionTriplet>> {
        std::thread::scope(|s| {
            let handles: Vec<_> = (0..pool_size)
                .map(|i| {
                    let pp = prepacked;
                    let load_one = &load_one;
                    s.spawn(move || {
                        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                            tracing::info!(
                                "Loading session triplet {}/{pool_size} (shared weights)",
                                i + 1
                            );
                            let (encoder, decoder, joiner) = load_one(dir, pp)?;
                            Ok(SessionTriplet {
                                encoder,
                                decoder,
                                joiner,
                            })
                        }))
                        .map_err(|_| anyhow::anyhow!("model loading thread panicked"))?
                    })
                })
                .collect();
            handles
                .into_iter()
                .map(|h| match h.join() {
                    Ok(r) => r,
                    Err(_) => Err(anyhow::anyhow!("model loading thread panicked")),
                })
                .collect::<anyhow::Result<Vec<_>>>()
        })
    }

    /// Loads every model and assembles the engine.
    ///
    /// With the `coreml` feature, a CoreML load failure falls back to CPU
    /// sessions, and the built engine is probed with one warmup transcription;
    /// if that probe fails the pool is rebuilt on the CPU provider.
    fn load_inner(dir: &Path, model_dir: &str, pool_size: usize) -> anyhow::Result<Self> {
        let is_int8 = dir.join("v3_e2e_rnnt_encoder_int8.onnx").exists();
        if is_int8 {
            tracing::info!("Using INT8 quantized encoder");
        }
        tracing::info!("Loading ONNX models from {model_dir} (pool_size={pool_size})...");
        #[cfg(feature = "coreml")]
        tracing::info!("Using CoreML execution provider (Neural Engine + CPU)");
        #[cfg(feature = "cuda")]
        tracing::info!("Using CUDA execution provider (falls back to CPU if unavailable)");
        #[cfg(not(any(feature = "coreml", feature = "cuda")))]
        tracing::info!("Using CPU execution provider");
        let prepacked = ort::session::builder::PrepackedWeights::new();
        #[cfg(feature = "coreml")]
        let triplets = match Self::load_triplets(dir, pool_size, &prepacked, Self::load_sessions) {
            Ok(triplets) => triplets,
            Err(load_err) => {
                tracing::warn!(
                    "CoreML EP failed to load sessions ({load_err:#}); falling back to CPU execution provider"
                );
                let prepacked = ort::session::builder::PrepackedWeights::new();
                Self::load_triplets(dir, pool_size, &prepacked, Self::load_sessions_cpu)?
            }
        };
        #[cfg(not(feature = "coreml"))]
        let triplets = Self::load_triplets(dir, pool_size, &prepacked, Self::load_sessions)?;
        let tokenizer = Tokenizer::load(&dir.join("v3_e2e_rnnt_vocab.txt"))?;
        let features = FeatureExtractor::new();
        tracing::info!(
            "Models loaded (vocab_size={}, pool_size={pool_size})",
            tokenizer.vocab_size()
        );
        // The speaker encoder is optional: a missing or broken model only disables
        // diarization instead of failing the whole load.
        #[cfg(feature = "diarization")]
        let speaker_encoder = {
            let model_path = dir.join("wespeaker_resnet34.onnx");
            if model_path.exists() {
                match OnnxEmbeddingExtractor::new(
                    &model_path,
                    SPEAKER_EMBEDDING_DIM,
                    SPEAKER_SEGMENT_SAMPLES,
                    SPEAKER_POOL_SIZE,
                ) {
                    Ok(enc) => {
                        tracing::info!("Speaker encoder loaded (diarization available)");
                        Some(std::sync::Arc::new(enc))
                    }
                    Err(e) => {
                        tracing::warn!(
                            "Speaker encoder not loaded, diarization unavailable: {e:#}"
                        );
                        None
                    }
                }
            } else {
                tracing::warn!("wespeaker_resnet34.onnx not found, diarization unavailable");
                None
            }
        };
        let engine = Self {
            pool: SessionPool::new(triplets),
            tokenizer,
            features,
            int8: is_int8,
            #[cfg(feature = "diarization")]
            speaker_encoder,
        };
        #[cfg(feature = "coreml")]
        let engine = probe_or_rebuild(
            engine,
            |e: &Self| e.warmup_one().map_err(anyhow::Error::from),
            |mut e, probe_err| {
                tracing::warn!(
                    "CoreML EP failed at runtime ({probe_err:#}); falling back to CPU execution provider"
                );
                let prepacked = ort::session::builder::PrepackedWeights::new();
                let triplets =
                    Self::load_triplets(dir, pool_size, &prepacked, Self::load_sessions_cpu)?;
                e.pool = SessionPool::new(triplets);
                Ok(e)
            },
        )?;
        Ok(engine)
    }

    /// Transcribes one second of silence (16000 zero samples) on one
    /// checked-out session triplet; used for warmup and as the CoreML probe.
    fn warmup_one(&self) -> Result<(), GigasttError> {
        let silence = vec![0.0f32; 16000];
        let mut guard = self
            .pool
            .checkout_blocking()
            .map_err(|e| GigasttError::Inference {
                source: Box::new(e),
            })?;
        self.transcribe_samples(&silence, &mut guard)?;
        Ok(())
    }

    /// Runs `warmup_one` once per triplet in the pool. Because checkout pops the
    /// front of the idle queue and check-in pushes to the back, sequential calls
    /// visit each idle triplet once when no other checkouts are in flight.
    ///
    /// # Errors
    /// The first warmup failure.
    pub fn warmup(&self) -> Result<(), GigasttError> {
        for _ in 0..self.pool.total() {
            self.warmup_one()?;
        }
        Ok(())
    }

    /// Whether a speaker encoder was loaded (diarization available).
    #[cfg(feature = "diarization")]
    pub fn has_speaker_encoder(&self) -> bool {
        self.speaker_encoder.is_some()
    }

    /// Creates fresh per-stream state. When `diarization_enabled` is set and a
    /// speaker encoder is loaded, a streaming diarization pipeline is attached;
    /// otherwise (or when the build lacks the `diarization` feature, which is
    /// logged) the stream runs without speaker labels.
    pub fn create_state(&self, diarization_enabled: bool) -> StreamingState {
        #[cfg(feature = "diarization")]
        let diarization_state = match (diarization_enabled, &self.speaker_encoder) {
            (true, Some(enc)) => {
                let config = DiaConfig {
                    cluster: ClusterConfig {
                        threshold: 0.5,
                        ..ClusterConfig::default()
                    },
                    ..DiaConfig::default()
                };
                let vad_config = VadConfig::default();
                let vad = EnergyVad::new(-40.0, 16000, vad_config.frame_size);
                let extractor = SharedExtractor(std::sync::Arc::clone(enc));
                match StreamingPipeline::new(vad, extractor, config, vad_config) {
                    Ok(pipeline) => Some(pipeline),
                    Err(e) => {
                        tracing::warn!("Failed to initialize streaming diarization: {e:#}");
                        None
                    }
                }
            }
            _ => None,
        };
        #[cfg(not(feature = "diarization"))]
        if diarization_enabled {
            tracing::warn!(
                "diarization_enabled=true ignored: build lacks the `diarization` feature"
            );
        }
        StreamingState {
            decoder: DecoderState::new(self.tokenizer.blank_id()),
            audio_buffer: Vec::new(),
            assembler: TranscriptAssembler::new(),
            window_start_samples: 0,
            context_samples: 0,
            pending_samples: 0,
            resampler: None,
            mel_fft_input: Vec::new(),
            mel_power: Vec::new(),
            mel_output: Vec::new(),
            resample_output_buf: Vec::new(),
            #[cfg(feature = "diarization")]
            diarization_state,
        }
    }

    /// Feeds a chunk of audio into a streaming session and returns transcript
    /// updates (the constants assume 16 kHz input — confirm at call sites).
    ///
    /// Samples are appended to the window (and fed to diarization when enabled).
    /// A decode runs once `STREAM_DECODE_STRIDE_SAMPLES` new samples have
    /// accumulated or the window reaches `STREAM_MAX_WINDOW_SAMPLES`. Then:
    /// - on a decoder endpoint or a full window, the segment is finalized, the
    ///   window is trimmed to its last `STREAM_LEFT_CONTEXT_SAMPLES` samples (kept
    ///   as left context), and the final segment is returned unless it is blank;
    /// - otherwise the current partial segment is returned, if non-empty.
    ///
    /// # Errors
    /// [`GigasttError::Inference`] if decoding the window fails.
    pub fn process_chunk(
        &self,
        samples: &[f32],
        state: &mut StreamingState,
        triplet: &mut SessionTriplet,
    ) -> Result<Vec<TranscriptSegment>, GigasttError> {
        if samples.is_empty() {
            return Ok(vec![]);
        }
        #[cfg(feature = "diarization")]
        if let Some(dia) = state.diarization_state.as_mut()
            && let Err(e) = dia.feed(samples)
        {
            tracing::warn!("Diarization feed failed: {e:#}");
        }
        state.audio_buffer.extend_from_slice(samples);
        state.pending_samples += samples.len();
        let over_cap = state.audio_buffer.len() >= STREAM_MAX_WINDOW_SAMPLES;
        if state.pending_samples < STREAM_DECODE_STRIDE_SAMPLES && !over_cap {
            return Ok(vec![]);
        }
        if state.audio_buffer.len() < N_FFT {
            return Ok(vec![]);
        }
        let endpoint = self
            .decode_window(state, triplet)
            .map_err(|e| GigasttError::Inference { source: e.into() })?;
        state.pending_samples = 0;
        let ts = now_timestamp();
        if endpoint || over_cap {
            let seg = state.assembler.finalize(ts);
            let keep = STREAM_LEFT_CONTEXT_SAMPLES.min(state.audio_buffer.len());
            let slide_off = state.audio_buffer.len() - keep;
            if slide_off > 0 {
                audio::consume_audio_buffer(&mut state.audio_buffer, slide_off);
                state.window_start_samples += slide_off;
            }
            state.context_samples = keep;
            if seg.text.trim().is_empty() {
                return Ok(vec![]);
            }
            return Ok(vec![seg]);
        }
        if state.assembler.is_empty() {
            return Ok(vec![]);
        }
        Ok(vec![state.assembler.partial(ts)])
    }

    /// Re-decodes the whole buffered window from a fresh decoder state, drops
    /// words that start inside the left-context region, labels the remaining
    /// words with the latest diarization speaker (if any), and replaces the
    /// assembler's words with them.
    ///
    /// Returns whether the decoder detected an endpoint (`false` when the window
    /// yields no mel frames).
    fn decode_window(
        &self,
        state: &mut StreamingState,
        triplet: &mut SessionTriplet,
    ) -> anyhow::Result<bool> {
        let mel_start = std::time::Instant::now();
        let num_frames = self.features.compute_mel(
            &state.audio_buffer,
            &mut state.mel_fft_input,
            &mut state.mel_power,
            &mut state.mel_output,
        );
        tracing::debug!(
            elapsed_us = mel_start.elapsed().as_micros() as u64,
            "mel_compute"
        );
        if num_frames == 0 {
            return Ok(false);
        }
        // Window start expressed in encoder frames (truncated to whole frames).
        let frame_offset = state.window_start_samples / (HOP_LENGTH * ENCODER_SUBSAMPLING);
        let mut decoder_state = DecoderState::new(self.tokenizer.blank_id());
        let (all_words, endpoint) = self.run_inference(
            triplet,
            &state.mel_output[..],
            num_frames,
            &mut decoder_state,
            frame_offset,
        )?;
        let window_start_s = frame_offset as f64 * SECONDS_PER_FRAME;
        let context_boundary_s = window_start_s + state.context_samples as f64 / 16000.0;
        #[cfg_attr(not(feature = "diarization"), allow(unused_mut))]
        let mut tail: Vec<WordInfo> = all_words
            .into_iter()
            .filter(|w| w.start + f64::EPSILON >= context_boundary_s)
            .collect();
        #[cfg(feature = "diarization")]
        if let Some(dia) = state.diarization_state.as_mut()
            && let Some(turn) = dia.turns().last()
        {
            let speaker = turn.speaker.0;
            for w in &mut tail {
                w.speaker = Some(speaker);
            }
        }
        state.assembler.set_words(tail);
        Ok(endpoint)
    }

    /// Ends a stream: decodes any audio received since the last decode (a decode
    /// failure is only logged), then finalizes and returns the remaining segment,
    /// if any.
    pub fn finish_stream(
        &self,
        state: &mut StreamingState,
        triplet: &mut SessionTriplet,
    ) -> Option<TranscriptSegment> {
        let has_pending = state.pending_samples > 0 && state.audio_buffer.len() >= N_FFT;
        if has_pending && let Err(e) = self.decode_window(state, triplet) {
            tracing::warn!("finish_stream decode failed: {e:#}");
        }
        self.flush_state(state)
    }

    /// Finalizes and returns the accumulated segment, or `None` if it is empty.
    pub fn flush_state(&self, state: &mut StreamingState) -> Option<TranscriptSegment> {
        if state.assembler.is_empty() {
            return None;
        }
        Some(state.assembler.finalize(now_timestamp()))
    }

    /// Decodes the audio file at `path` and transcribes it.
    ///
    /// # Errors
    /// [`GigasttError::InvalidAudio`] if decoding fails; otherwise as
    /// `transcribe_samples`.
    pub fn transcribe_file(
        &self,
        path: &str,
        triplet: &mut SessionTriplet,
    ) -> Result<TranscribeResult, GigasttError> {
        let float_samples =
            audio::decode_audio_file(path).map_err(|e| GigasttError::InvalidAudio {
                reason: format!("{e:#}"),
            })?;
        self.transcribe_samples(&float_samples, triplet)
    }

    /// Transcribes an encoded audio file held in memory (copies `data` once).
    pub fn transcribe_bytes(
        &self,
        data: &[u8],
        triplet: &mut SessionTriplet,
    ) -> Result<TranscribeResult, GigasttError> {
        self.transcribe_bytes_shared(bytes::Bytes::copy_from_slice(data), triplet)
    }

    /// Transcribes an encoded audio file held in a shared `Bytes` buffer.
    ///
    /// # Errors
    /// [`GigasttError::InvalidAudio`] if decoding fails; otherwise as
    /// `transcribe_samples`.
    pub fn transcribe_bytes_shared(
        &self,
        data: bytes::Bytes,
        triplet: &mut SessionTriplet,
    ) -> Result<TranscribeResult, GigasttError> {
        let float_samples =
            audio::decode_audio_bytes_shared(data).map_err(|e| GigasttError::InvalidAudio {
                reason: format!("{e:#}"),
            })?;
        self.transcribe_samples(&float_samples, triplet)
    }

    /// Offline transcription of already-decoded samples (duration computed
    /// assuming 16 kHz). With a speaker encoder loaded, each word is labelled with
    /// the diarization turn containing its midpoint; diarization failures are
    /// only logged.
    ///
    /// # Errors
    /// [`GigasttError::Inference`] if encoder or decoder inference fails.
    fn transcribe_samples(
        &self,
        float_samples: &[f32],
        triplet: &mut SessionTriplet,
    ) -> Result<TranscribeResult, GigasttError> {
        let duration_s = float_samples.len() as f64 / 16000.0;
        let (features, num_frames) = self.features.compute(float_samples);
        tracing::info!("Extracted {} mel frames", num_frames);
        if num_frames == 0 {
            // Same early-out as `decode_window`: too little audio for a single mel
            // frame, so return an empty transcript instead of running the encoder
            // on a zero-length input.
            return Ok(TranscribeResult {
                text: String::new(),
                words: Vec::new(),
                duration_s,
            });
        }
        let mut decoder_state = DecoderState::new(self.tokenizer.blank_id());
        #[cfg_attr(not(feature = "diarization"), allow(unused_mut))]
        let (mut words, _endpoint) = self
            .run_inference(triplet, &features, num_frames, &mut decoder_state, 0)
            .map_err(|e| GigasttError::Inference { source: e.into() })?;
        #[cfg(feature = "diarization")]
        if let Some(ref enc) = self.speaker_encoder {
            let config = DiaConfig::default();
            let vad_config = VadConfig::default();
            let pipeline = Pipeline::new(config, vad_config);
            let mut vad = EnergyVad::new(-40.0, 16000, vad_config.frame_size);
            match pipeline.run(float_samples, enc.as_ref(), &mut vad) {
                Ok(dia_result) => {
                    for word in &mut words {
                        let mid = (word.start + word.end) / 2.0;
                        if let Some(turn) = dia_result
                            .turns
                            .iter()
                            .find(|t| t.time.start <= mid && t.time.end >= mid)
                        {
                            word.speaker = Some(turn.speaker.0);
                        }
                    }
                }
                Err(e) => {
                    tracing::warn!("Offline diarization failed: {e:#}");
                }
            }
        }
        let text: String = words
            .iter()
            .map(|w| w.word.as_str())
            .collect::<Vec<_>>()
            .join(" ");
        Ok(TranscribeResult {
            text,
            words,
            duration_s,
        })
    }

    /// Runs the encoder over `num_frames` mel frames (`[1, N_MELS, num_frames]`),
    /// greedy-decodes its output with the decoder/joiner, and converts the tokens
    /// to words whose times are shifted by `frame_offset` encoder frames.
    ///
    /// Returns the words and whether the decoder detected an endpoint.
    fn run_inference(
        &self,
        triplet: &mut SessionTriplet,
        features: &[f32],
        num_frames: usize,
        decoder_state: &mut DecoderState,
        frame_offset: usize,
    ) -> anyhow::Result<(Vec<WordInfo>, bool)> {
        let signal_tensor = TensorRef::from_array_view(([1_usize, N_MELS, num_frames], features))?;
        let length_data = [num_frames as i64];
        let length_tensor = TensorRef::from_array_view(([1_usize], length_data.as_slice()))?;
        let enc_start = std::time::Instant::now();
        let encoder_outputs = triplet
            .encoder
            .run(ort::inputs![signal_tensor, length_tensor])
            .context("Encoder inference failed")?;
        tracing::info!(
            elapsed_ms = enc_start.elapsed().as_millis() as u64,
            "encoder_inference"
        );
        let (_enc_shape, enc_data) = encoder_outputs[0]
            .try_extract_tensor::<f32>()
            .context("Failed to extract encoder output")?;
        let (_len_shape, len_data) = encoder_outputs[1]
            .try_extract_tensor::<i32>()
            .context("Failed to extract encoder length")?;
        let enc_len = usize::try_from(len_data[0]).context("Negative encoder length")?;
        tracing::debug!("Encoder output: {} frames", enc_len);
        let dec_start = std::time::Instant::now();
        let result = decode::greedy_decode(
            &mut triplet.decoder,
            &mut triplet.joiner,
            enc_data,
            enc_len,
            self.tokenizer.blank_id(),
            decoder_state,
        )?;
        tracing::info!(
            elapsed_ms = dec_start.elapsed().as_millis() as u64,
            "greedy_decode"
        );
        let words = self.tokens_to_words(&result.tokens, frame_offset);
        tracing::info!(
            tokens = result.tokens.len(),
            words = words.len(),
            duration_ms = dec_start.elapsed().as_millis() as u64,
            "Decoded tokens"
        );
        Ok((words, result.endpoint_detected))
    }

    /// Groups decoded tokens into words. A token whose text starts with U+2581
    /// begins a new word; the marker is stripped, and tokens with no remaining
    /// text only act as boundaries.
    fn tokens_to_words(&self, tokens: &[decode::TokenInfo], frame_offset: usize) -> Vec<WordInfo> {
        let mut words = Vec::new();
        let mut current_word = String::new();
        let mut word_start_frame: Option<usize> = None;
        let mut word_end_frame: usize = 0;
        let mut word_confidences: Vec<f32> = Vec::new();
        for token in tokens {
            let token_text = self.tokenizer.token_text(token.token_id);
            let stripped = token_text.strip_prefix('\u{2581}');
            if stripped.is_some() && !current_word.is_empty() {
                words.push(Self::build_word_info(
                    std::mem::take(&mut current_word),
                    word_start_frame.take().unwrap_or(0),
                    word_end_frame,
                    &word_confidences,
                    frame_offset,
                ));
                word_confidences.clear();
            }
            let clean = stripped.unwrap_or(token_text);
            if !clean.is_empty() {
                current_word.push_str(clean);
                if word_start_frame.is_none() {
                    word_start_frame = Some(token.frame_index);
                }
                word_end_frame = token.frame_index;
                word_confidences.push(token.confidence);
            }
        }
        if !current_word.is_empty() {
            words.push(Self::build_word_info(
                current_word,
                word_start_frame.unwrap_or(0),
                word_end_frame,
                &word_confidences,
                frame_offset,
            ));
        }
        words
    }

    /// Builds a [`WordInfo`] from its text, first/last encoder frame, and token
    /// confidences (mean confidence, 1.0 when no confidences were recorded).
    fn build_word_info(
        word: String,
        start_frame: usize,
        end_frame: usize,
        confidences: &[f32],
        frame_offset: usize,
    ) -> WordInfo {
        let confidence: f32 = if confidences.is_empty() {
            1.0
        } else {
            confidences.iter().sum::<f32>() / confidences.len() as f32
        };
        WordInfo {
            word,
            start: (start_frame + frame_offset) as f64 * SECONDS_PER_FRAME,
            end: (end_frame + frame_offset) as f64 * SECONDS_PER_FRAME,
            confidence,
            speaker: None,
        }
    }
}
/// Result of an offline (whole-file) transcription.
#[derive(Debug, Clone, Serialize)]
pub struct TranscribeResult {
    /// All word texts joined by single spaces.
    pub text: String,
    /// Recognized words with timing, confidence and optional speaker.
    pub words: Vec<WordInfo>,
    /// Audio duration in seconds (sample count / 16000).
    pub duration_s: f64,
}
/// A streaming transcript update: either a partial (still changing) or a final
/// segment.
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
pub struct TranscriptSegment {
    /// Word texts joined by single spaces.
    pub text: String,
    /// Words making up the segment.
    pub words: Vec<WordInfo>,
    /// `true` once the segment will no longer change.
    pub is_final: bool,
    /// Wall-clock time (Unix seconds, from `now_timestamp`) when the segment was emitted.
    pub timestamp: f64,
}
impl TranscriptSegment {
    /// A final segment with no text and no words, stamped with the current
    /// wall-clock time.
    pub fn empty_final() -> Self {
        let timestamp = now_timestamp();
        Self {
            timestamp,
            is_final: true,
            words: Vec::new(),
            text: String::new(),
        }
    }
}
// Unit tests covering several areas: decoder state initialisation, the session
// `Pool` (blocking/async checkout, close, waiter bookkeeping, owned
// reservations), transcript assembly, feature extraction, `probe_or_rebuild`,
// and encoder model-path selection.
#[cfg(test)]
mod tests {
use super::*;
// `DecoderState::new` zero-fills `h` and `c` and seeds `prev_token` with the blank id.
#[test]
fn test_decoder_state_new_zeros() {
let blank_id = 1024;
let state = DecoderState::new(blank_id);
assert!(state.h.iter().all(|&v| v == 0.0));
assert!(state.c.iter().all(|&v| v == 0.0));
assert_eq!(state.prev_token, blank_id as i64);
}
// Both state vectors are sized to `PRED_HIDDEN`.
#[test]
fn test_decoder_state_dimensions() {
let state = DecoderState::new(1024);
assert_eq!(state.h.len(), PRED_HIDDEN);
assert_eq!(state.c.len(), PRED_HIDDEN);
}
// `prev_token` reflects whatever blank id the caller passes, not a hard-coded value.
#[test]
fn test_decoder_state_custom_blank_id() {
let state = DecoderState::new(42);
assert_eq!(state.prev_token, 42);
}
// Smoke test: the `Default` constructor must not panic.
#[test]
fn test_feature_extractor_default() {
let _fe = FeatureExtractor::default();
}
// A default assembler starts with no text and no words.
#[test]
fn test_transcript_assembler_default() {
let ta = TranscriptAssembler::default();
assert!(ta.text.is_empty());
assert!(ta.words.is_empty());
}
// Fast path: an available item is handed out immediately, and dropping the
// guard puts it back.
#[test]
fn test_pool_checkout_blocking_fast_path() {
let pool = Pool::new(vec![42u32]);
let guard = pool.checkout_blocking().expect("checkout_blocking");
assert_eq!(*guard, 42);
drop(guard);
assert_eq!(pool.available(), 1);
}
// Blocking checkout on a closed, empty pool fails with `Closed` instead of hanging.
#[test]
fn test_pool_checkout_blocking_closed() {
let pool = Pool::<u32>::new(vec![]);
pool.close();
assert!(matches!(pool.checkout_blocking(), Err(PoolError::Closed)));
}
// Slow path: a second thread blocks in `checkout_blocking` until the
// primary guard is dropped, then receives the item.
#[test]
fn test_pool_checkout_blocking_slow_path() {
let pool = std::sync::Arc::new(Pool::new(vec![42u32]));
let primary = pool.checkout_blocking().unwrap();
let handle = std::thread::spawn({
let pool = pool.clone();
move || pool.checkout_blocking()
});
// Give the spawned thread time to start waiting before the item is released.
// This relies on timing, not a hard synchronisation guarantee.
std::thread::sleep(std::time::Duration::from_millis(50));
drop(primary);
let guard = handle.join().expect("join").expect("checkout");
assert_eq!(*guard, 42);
drop(guard);
assert_eq!(pool.available(), 1);
}
// `PoolError::Closed` has a stable user-facing message.
#[test]
fn test_pool_error_display() {
assert_eq!(format!("{}", PoolError::Closed), "session pool is closed");
}
// `ort_err` keeps the source error's Display text verbatim.
#[test]
fn test_ort_err() {
let e = ort_err("test ort error");
assert_eq!(format!("{e}"), "test ort error");
}
// Loading from a path that does not exist is reported as `ModelLoad`.
#[test]
fn test_engine_load_missing_dir() {
let result = Engine::load_with_pool_size("/nonexistent/path/for/tests", 1);
assert!(matches!(result, Err(GigasttError::ModelLoad { .. })));
}
// Loading from an existing but empty directory is also reported as `ModelLoad`.
#[test]
fn test_engine_load_empty_dir() {
let dir = tempfile::tempdir().unwrap();
let result = Engine::load_with_pool_size(dir.path().to_str().unwrap(), 1);
assert!(matches!(result, Err(GigasttError::ModelLoad { .. })));
}
// Async checkout decrements `available()`, and the guard restores it at end of scope.
#[tokio::test]
async fn test_pool_guard_returns_triplet_on_normal_drop() {
let pool = Pool::new(vec![1u32, 2, 3]);
assert_eq!(pool.available(), 3);
{
let _guard = pool.checkout().await.expect("checkout");
assert_eq!(pool.available(), 2);
}
assert_eq!(pool.available(), 3);
}
// A guard held across a panic in a spawned task is dropped during unwinding,
// which returns the item. The panic surfaces as an `Err` from the JoinHandle.
#[tokio::test]
async fn test_pool_guard_returns_triplet_on_panic_unwind() {
let pool = std::sync::Arc::new(Pool::new(vec![1u32]));
assert_eq!(pool.available(), 1);
let pool_clone = pool.clone();
let result = tokio::spawn(async move {
let _guard = pool_clone.checkout().await.expect("checkout");
assert_eq!(pool_clone.available(), 0);
panic!("synthetic inference panic");
})
.await;
assert!(result.is_err(), "spawned task must report the panic");
assert_eq!(pool.available(), 1);
}
// A task awaiting checkout on an empty pool is woken with `Closed` when the
// pool is closed.
#[tokio::test]
async fn test_pool_close_wakes_waiters_with_closed() {
let pool = std::sync::Arc::new(Pool::<u32>::new(vec![]));
let waiter = tokio::spawn({
let pool = pool.clone();
async move { pool.checkout().await.map(|_g| ()) }
});
// Let the waiter register before closing (timing-based).
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
pool.close();
let res = waiter.await.expect("join");
assert!(matches!(res, Err(PoolError::Closed)));
}
// Waiters are served in the order they started waiting. Each waiter logs its
// id while holding the item, then releases it to the next waiter.
#[tokio::test]
async fn test_pool_fifo_under_contention() {
let pool = std::sync::Arc::new(Pool::new(vec![0u32]));
let primary = pool.checkout().await.expect("primary checkout");
assert_eq!(pool.available(), 0);
let waker_log = std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));
let mut handles = Vec::new();
for id in 0u32..3 {
let pool = pool.clone();
let log = waker_log.clone();
handles.push(tokio::spawn(async move {
let g = pool.checkout().await.expect("checkout");
log.lock().await.push(id);
drop(g);
}));
// Stagger the spawns so waiters enqueue in id order (timing-based).
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
drop(primary);
for h in handles {
h.await.expect("join");
}
let log = waker_log.lock().await.clone();
assert_eq!(log, vec![0, 1, 2], "waiters must wake in FIFO order");
}
// A guard converted with `into_owned` can move into `spawn_blocking`, and an
// explicit `checkin` returns the item to the pool.
#[tokio::test]
async fn test_into_owned_for_spawn_blocking() {
let pool = std::sync::Arc::new(Pool::new(vec![String::from("triplet")]));
let guard = pool.checkout().await.expect("checkout");
let reservation = guard.into_owned();
let result = tokio::task::spawn_blocking(move || {
assert_eq!(*reservation, "triplet");
reservation.checkin();
"done"
})
.await
.expect("join");
assert_eq!(pool.available(), 1);
assert_eq!(result, "done");
}
// An owned reservation dropped during a panic in `spawn_blocking` still
// returns its item.
#[tokio::test]
async fn test_owned_reservation_returns_on_spawn_blocking_panic() {
let pool = std::sync::Arc::new(Pool::new(vec![String::from("triplet")]));
let guard = pool.checkout().await.expect("checkout");
let reservation = guard.into_owned();
let result = tokio::task::spawn_blocking(move || {
let _reservation = reservation;
panic!("simulated inference panic");
})
.await;
assert!(result.is_err(), "spawn_blocking must report the panic");
assert_eq!(
pool.available(),
1,
"reservation must be returned after panic"
);
}
// Dropping an owned reservation without calling `checkin` also returns the item.
#[tokio::test]
async fn test_owned_reservation_drop_returns_item() {
let pool = std::sync::Arc::new(Pool::new(vec![String::from("triplet")]));
let guard = pool.checkout().await.expect("checkout");
let reservation = guard.into_owned();
tokio::task::spawn_blocking(move || {
let _reservation = reservation;
})
.await
.expect("join");
assert_eq!(pool.available(), 1);
}
// Calling `close` twice must not panic.
#[tokio::test]
async fn test_pool_close_is_idempotent() {
let pool = Pool::<u32>::new(vec![]);
pool.close();
pool.close();
}
// `waiters()` counts tasks currently blocked in `checkout`. Closing the pool
// releases them so the test can join both tasks.
#[tokio::test]
async fn test_pool_waiters_count() {
let pool = std::sync::Arc::new(Pool::<u32>::new(vec![]));
let w1 = tokio::spawn({
let p = pool.clone();
async move { p.checkout().await.map(|_| ()) }
});
let w2 = tokio::spawn({
let p = pool.clone();
async move { p.checkout().await.map(|_| ()) }
});
// Let both tasks register as waiters (timing-based).
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
assert_eq!(pool.waiters(), 2, "both blocked tasks must be waiters");
pool.close();
let _ = w1.await;
let _ = w2.await;
}
// An owned reservation can travel into `spawn_blocking` and back. Mutations
// made through `DerefMut` are visible, and the final drop returns the item.
#[tokio::test]
async fn test_owned_reservation_round_trip_through_option() {
let pool = std::sync::Arc::new(Pool::new(vec![42u32]));
let guard = pool.checkout().await.expect("checkout");
let mut reservation: Option<OwnedReservation<u32>> = Some(guard.into_owned());
let (res_back, val) = tokio::task::spawn_blocking(move || {
let mut r = reservation.take().unwrap();
*r += 1;
let v = *r;
(r, v)
})
.await
.expect("join");
reservation = Some(res_back);
assert_eq!(val, 43);
drop(reservation);
assert_eq!(pool.available(), 1);
}
// After a waiting checkout is aborted, its entry is still counted by
// `waiters()`. When the item is released, that dead waiter is skipped and
// removed, and the item goes back to the pool.
#[tokio::test]
async fn test_pool_slot_not_leaked_on_cancelled_checkout() {
let pool = std::sync::Arc::new(Pool::new(vec![42u32]));
let primary = pool.checkout().await.expect("checkout");
let aborted = tokio::spawn({
let pool = pool.clone();
async move { pool.checkout().await }
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
aborted.abort();
let _ = aborted.await;
assert_eq!(pool.waiters(), 1);
drop(primary);
assert_eq!(pool.available(), 1, "item must return to pool, not leak");
assert_eq!(pool.waiters(), 0, "dead waiter must be removed");
}
// Same as the cancelled case, but the checkout future is dropped by
// `tokio::time::timeout` rather than by aborting a task.
#[tokio::test]
async fn test_pool_slot_not_leaked_on_timeout_checkout() {
let pool = std::sync::Arc::new(Pool::new(vec![42u32]));
let primary = pool.checkout().await.expect("checkout");
let result =
tokio::time::timeout(std::time::Duration::from_millis(10), pool.checkout()).await;
assert!(result.is_err(), "checkout must time out");
assert_eq!(pool.waiters(), 1);
drop(primary);
assert_eq!(
pool.available(),
1,
"item must return to pool after timeout"
);
assert_eq!(pool.waiters(), 0, "dead waiter must be removed");
}
// Several consecutive dead waiters are all skipped on release, and the item
// ends up back in the pool.
#[tokio::test]
async fn test_pool_multiple_dead_waiters_are_skipped() {
let pool = std::sync::Arc::new(Pool::new(vec![0u32]));
let primary = pool.checkout().await.expect("checkout");
let mut handles = Vec::new();
for _ in 0..3 {
handles.push(tokio::spawn({
let pool = pool.clone();
async move { pool.checkout().await }
}));
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
for h in handles {
h.abort();
let _ = h.await;
}
assert_eq!(pool.waiters(), 3);
drop(primary);
assert_eq!(
pool.available(),
1,
"item returned after skipping 3 dead waiters"
);
assert_eq!(pool.waiters(), 0);
}
// `finalize` joins the appended words with single spaces, marks the segment
// final with the given timestamp, and clears the assembler.
#[test]
fn test_transcript_assembler_append_and_finalize() {
let mut asm = TranscriptAssembler::new();
assert!(asm.is_empty());
asm.append(vec![
WordInfo {
word: "hello".into(),
start: 0.0,
end: 0.5,
confidence: 0.9,
speaker: None,
},
WordInfo {
word: "world".into(),
start: 0.6,
end: 1.0,
confidence: 0.85,
speaker: None,
},
]);
assert!(!asm.is_empty());
let seg = asm.finalize(1.0);
assert_eq!(seg.text, "hello world");
assert_eq!(seg.words.len(), 2);
assert!(seg.is_final);
assert_eq!(seg.timestamp, 1.0);
assert!(asm.is_empty());
}
// `partial` returns a non-final segment and leaves the accumulated words in place.
#[test]
fn test_transcript_assembler_partial() {
let mut asm = TranscriptAssembler::new();
asm.append(vec![WordInfo {
word: "partial".into(),
start: 0.0,
end: 0.3,
confidence: 0.8,
speaker: None,
}]);
let seg = asm.partial(0.3);
assert_eq!(seg.text, "partial");
assert!(!seg.is_final);
assert!(!asm.is_empty());
}
// Empty audio yields one frame of `N_MELS` zeros rather than an empty buffer.
#[test]
fn test_feature_extractor_compute_empty() {
let fe = FeatureExtractor::new();
let (mel, frames) = fe.compute(&[]);
assert_eq!(mel.len(), N_MELS);
assert_eq!(frames, 1);
assert!(mel.iter().all(|&v| v == 0.0));
}
// `now_timestamp` never goes negative. It falls back to 0.0 if the clock
// reads before the Unix epoch.
#[test]
fn test_now_timestamp_non_negative() {
let ts = now_timestamp();
assert!(ts >= 0.0, "timestamp must be non-negative");
}
// When the probe passes, the original state is returned unchanged and
// `rebuild` is never invoked.
#[test]
fn test_probe_or_rebuild_keeps_state_when_probe_passes() {
let rebuilt = std::cell::Cell::new(false);
let result = probe_or_rebuild(
7u32,
|v| {
assert_eq!(*v, 7);
Ok(())
},
|_, _| {
rebuilt.set(true);
Ok(99)
},
)
.expect("healthy state must survive unchanged");
assert_eq!(result, 7);
assert!(!rebuilt.get(), "rebuild must not run when the probe passes");
}
// When the probe fails, `rebuild` receives the failed state and the probe
// error. The rebuilt state, which passes the probe, is returned.
#[test]
fn test_probe_or_rebuild_rebuilds_when_probe_fails() {
let result = probe_or_rebuild(
1u32,
|v| {
if *v == 1 {
Err(anyhow::anyhow!("first probe failed"))
} else {
Ok(())
}
},
|old, probe_err| {
assert_eq!(old, 1, "rebuild receives the failed state");
assert!(
probe_err.to_string().contains("first probe failed"),
"rebuild receives the probe error for logging"
);
Ok(2)
},
)
.expect("rebuilt state passing the probe is OK");
assert_eq!(result, 2);
}
// An error from `rebuild` is propagated to the caller.
#[test]
fn test_probe_or_rebuild_propagates_rebuild_error() {
let result = probe_or_rebuild(
1u32,
|_| Err(anyhow::anyhow!("probe failed")),
|_, _| Err(anyhow::anyhow!("rebuild failed")),
);
let err = result.expect_err("rebuild failure must be fatal");
assert!(err.to_string().contains("rebuild failed"));
}
// A rebuilt state that still fails the probe is an error, not a silent success.
#[test]
fn test_probe_or_rebuild_fails_when_rebuilt_state_fails_probe() {
let result = probe_or_rebuild(
1u32,
|_| Err(anyhow::anyhow!("always fails")),
|_, _| Ok(2u32),
);
assert!(
result.is_err(),
"a rebuilt state that still fails the probe must be a hard error"
);
}
// With both encoder files present, the INT8 encoder path is chosen.
#[test]
fn test_encoder_model_path_prefers_int8_when_present() {
let dir = tempfile::tempdir().expect("tempdir");
std::fs::write(dir.path().join("v3_e2e_rnnt_encoder.onnx"), b"fp32").unwrap();
std::fs::write(dir.path().join("v3_e2e_rnnt_encoder_int8.onnx"), b"int8").unwrap();
let path = Engine::encoder_model_path(dir.path());
assert_eq!(
path.file_name().unwrap(),
"v3_e2e_rnnt_encoder_int8.onnx",
"INT8 encoder must win when both files exist"
);
}
// Without the INT8 file, the FP32 encoder path is chosen.
#[test]
fn test_encoder_model_path_falls_back_to_fp32() {
let dir = tempfile::tempdir().expect("tempdir");
std::fs::write(dir.path().join("v3_e2e_rnnt_encoder.onnx"), b"fp32").unwrap();
let path = Engine::encoder_model_path(dir.path());
assert_eq!(path.file_name().unwrap(), "v3_e2e_rnnt_encoder.onnx");
}
// Each guard is dropped at the end of its loop iteration. `total()` such
// one-at-a-time checkouts must still visit every item, so the pool must not
// keep re-serving the item that was just returned.
#[test]
fn test_pool_sequential_checkouts_visit_every_item() {
let pool = Pool::new(vec![1u32, 2, 3]);
let mut seen = Vec::new();
for _ in 0..pool.total() {
let guard = pool.checkout_blocking().expect("checkout");
seen.push(*guard);
}
seen.sort_unstable();
assert_eq!(seen, vec![1, 2, 3]);
}
// Ignored by default because it needs model files in
// `crate::model::default_model_dir()`. After `warmup`, every pooled item must
// be back in the pool.
#[test]
#[ignore = "requires model"]
fn test_warmup_runs_silent_inference_on_every_triplet() {
let engine = Engine::load_with_pool_size(&crate::model::default_model_dir(), 2)
.expect("engine should load");
engine
.warmup()
.expect("warmup must succeed on a working engine");
assert_eq!(
engine.pool.available(),
engine.pool.total(),
"every triplet must be returned to the pool after warmup"
);
}
}