pub mod audio;
mod decode;
mod features;
mod tokenizer;
#[cfg(feature = "diarization")]
use polyvoice::{DiarizationConfig as DiaConfig, OfflineDiarizer, OnlineDiarizer, OnnxEmbeddingExtractor, SampleRate};
/// Embedding dimension passed to `OnnxEmbeddingExtractor::new` when the
/// speaker encoder is loaded.
#[cfg(feature = "diarization")]
const SPEAKER_EMBEDDING_DIM: usize = 256;
/// Segment length in samples passed to `OnnxEmbeddingExtractor::new`
/// (presumably 1.5 s of 16 kHz audio — TODO confirm against polyvoice).
#[cfg(feature = "diarization")]
const SPEAKER_SEGMENT_SAMPLES: usize = 24000;
/// Fourth argument to `OnnxEmbeddingExtractor::new`; judging by the name, the
/// number of pooled speaker-encoder sessions — TODO confirm against polyvoice.
#[cfg(feature = "diarization")]
const SPEAKER_POOL_SIZE: usize = 4;
#[cfg(all(feature = "coreml", feature = "cuda"))]
compile_error!("Features `coreml` and `cuda` are mutually exclusive. Choose one.");
use anyhow::Context;
#[cfg(any(feature = "coreml", feature = "cuda"))]
use ort::ep;
use ort::session::Session;
use ort::value::TensorRef;
use serde::Serialize;
use std::ops::{Deref, DerefMut};
use std::path::Path;
use crate::error::GigasttError;
use features::MelSpectrogram;
use tokenizer::Tokenizer;
/// Number of mel bins; used as the second dimension of the encoder input tensor.
pub const N_MELS: usize = 64;
/// FFT size constant. Not referenced in this file; presumably consumed by the
/// `features` module — TODO confirm.
pub const N_FFT: usize = 320;
/// Hop between feature frames in samples; feeds `SECONDS_PER_FRAME`.
pub const HOP_LENGTH: usize = 160;
/// Length of the `h` and `c` vectors allocated by `DecoderState::new`.
pub const PRED_HIDDEN: usize = 320;
fn ort_err(e: impl std::fmt::Display) -> anyhow::Error {
anyhow::anyhow!("{e}")
}
/// Returns the current wall-clock time as fractional seconds since the Unix
/// epoch, or `0.0` when the system clock reports a time before the epoch.
pub(crate) fn now_timestamp() -> f64 {
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        // Clock set before 1970: fall back to a zero offset instead of failing.
        Err(_) => Duration::default(),
    };
    since_epoch.as_secs_f64()
}
/// Multiplier applied to `HOP_LENGTH` when converting encoder frame indices to
/// time (presumably the encoder's time-axis downsampling factor — TODO confirm
/// against the model).
const ENCODER_SUBSAMPLING: usize = 4;
/// Seconds covered by one encoder output frame, computed for a 16 kHz sample rate.
const SECONDS_PER_FRAME: f64 = (HOP_LENGTH as f64 * ENCODER_SUBSAMPLING as f64) / 16000.0;
/// Pool size used by `Engine::load` on Android.
#[cfg(target_os = "android")]
const DEFAULT_POOL_SIZE: usize = 1;
/// Pool size used by `Engine::load` on every other target.
#[cfg(not(target_os = "android"))]
const DEFAULT_POOL_SIZE: usize = 4;
/// The three ONNX sessions needed for one inference: encoder, decoder and joiner.
///
/// Triplets are the items stored in [`SessionPool`].
pub struct SessionTriplet {
// Runs first in `Engine::run_inference`, on the mel features.
pub(crate) encoder: Session,
// Passed to `decode::greedy_decode` together with the joiner.
pub(crate) decoder: Session,
pub(crate) joiner: Session,
}
/// Failure modes when checking an item out of a [`Pool`].
#[derive(Debug)]
pub enum PoolError {
    /// The pool was closed before an item could be handed out.
    Closed,
}

impl std::fmt::Display for PoolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::Closed => "session pool is closed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PoolError {}
/// A fixed-size pool of reusable items handed out through [`PoolGuard`]s.
///
/// Supports both async (`checkout`) and blocking (`checkout_blocking`) callers.
pub struct Pool<T> {
// Shared state; guards and reservations hold their own `Arc` clone of it.
inner: std::sync::Arc<PoolInner<T>>,
}
/// State shared between a [`Pool`], its guards and its reservations.
struct PoolInner<T> {
// Idle items ready to be checked out (popped from the front).
items: std::sync::Mutex<std::collections::VecDeque<T>>,
// Callers waiting for an item, served front-first by `checkin`.
waiters: std::sync::Mutex<std::collections::VecDeque<Waiter<T>>>,
// Set by `Pool::close`; checkouts fail once it is true.
closed: std::sync::atomic::AtomicBool,
// Number of items the pool was constructed with.
total: usize,
}
/// Sending half of the channel a parked checkout is waiting on.
enum Waiter<T> {
// Registered by the async `Pool::checkout`.
Async(tokio::sync::oneshot::Sender<T>),
// Registered by `Pool::checkout_blocking`.
Blocking(std::sync::mpsc::Sender<T>),
}
/// Pool of [`SessionTriplet`]s, as stored in `Engine::pool`.
pub type SessionPool = Pool<SessionTriplet>;
impl<T: Send> Pool<T> {
pub fn new(items: Vec<T>) -> Self {
let total = items.len();
Self {
inner: std::sync::Arc::new(PoolInner {
items: std::sync::Mutex::new(std::collections::VecDeque::from(items)),
waiters: std::sync::Mutex::new(std::collections::VecDeque::new()),
closed: std::sync::atomic::AtomicBool::new(false),
total,
}),
}
}
pub async fn checkout(&self) -> Result<PoolGuard<T>, PoolError> {
{
let mut items = self.inner.items.lock().unwrap();
if self.inner.closed.load(std::sync::atomic::Ordering::SeqCst) {
return Err(PoolError::Closed);
}
if let Some(item) = items.pop_front() {
return Ok(PoolGuard::new(self.inner.clone(), item));
}
}
let (tx, rx) = tokio::sync::oneshot::channel();
{
let mut waiters = self.inner.waiters.lock().unwrap();
if self.inner.closed.load(std::sync::atomic::Ordering::SeqCst) {
return Err(PoolError::Closed);
}
waiters.push_back(Waiter::Async(tx));
}
match rx.await {
Ok(item) => Ok(PoolGuard::new(self.inner.clone(), item)),
Err(_) => Err(PoolError::Closed),
}
}
pub fn checkout_blocking(&self) -> Result<PoolGuard<T>, PoolError> {
{
let mut items = self.inner.items.lock().unwrap();
if self.inner.closed.load(std::sync::atomic::Ordering::SeqCst) {
return Err(PoolError::Closed);
}
if let Some(item) = items.pop_front() {
return Ok(PoolGuard::new(self.inner.clone(), item));
}
}
let (tx, rx) = std::sync::mpsc::channel();
{
let mut waiters = self.inner.waiters.lock().unwrap();
if self.inner.closed.load(std::sync::atomic::Ordering::SeqCst) {
return Err(PoolError::Closed);
}
waiters.push_back(Waiter::Blocking(tx));
}
match rx.recv() {
Ok(item) => Ok(PoolGuard::new(self.inner.clone(), item)),
Err(_) => Err(PoolError::Closed),
}
}
pub fn close(&self) {
self.inner
.closed
.store(true, std::sync::atomic::Ordering::SeqCst);
let mut waiters = self.inner.waiters.lock().unwrap();
waiters.clear();
}
pub fn total(&self) -> usize {
self.inner.total
}
pub fn available(&self) -> usize {
let items = self.inner.items.lock().unwrap();
items.len()
}
}
impl<T> PoolInner<T> {
fn checkin(&self, item: T) {
if self.closed.load(std::sync::atomic::Ordering::SeqCst) {
return;
}
let mut waiters = self.waiters.lock().unwrap();
if let Some(waiter) = waiters.pop_front() {
drop(waiters);
match waiter {
Waiter::Async(tx) => {
let _ = tx.send(item);
}
Waiter::Blocking(tx) => {
let _ = tx.send(item);
}
}
} else {
drop(waiters);
let mut items = self.items.lock().unwrap();
items.push_back(item);
}
}
}
/// RAII handle to a checked-out pool item; returns the item to its pool on drop.
///
/// Both fields are `Option`s so that `into_owned` and `Drop` can move them out.
pub struct PoolGuard<T> {
// Pool the item goes back to.
inner: Option<std::sync::Arc<PoolInner<T>>>,
// The checked-out item; `None` only after it has been taken.
item: Option<T>,
}
impl<T> PoolGuard<T> {
    /// Wraps a checked-out `item` together with the pool it must return to.
    fn new(inner: std::sync::Arc<PoolInner<T>>, item: T) -> Self {
        let (inner, item) = (Some(inner), Some(item));
        Self { inner, item }
    }

    /// Splits the guard into the bare item and an [`OwnedReservation`] that
    /// can return it later, e.g. after moving the item into another thread.
    ///
    /// The item is no longer returned automatically; the caller must hand it
    /// back through [`OwnedReservation::checkin`].
    pub fn into_owned(mut self) -> (T, OwnedReservation<T>) {
        let item = match self.item.take() {
            Some(item) => item,
            None => unreachable!("PoolGuard::into_owned called after drop"),
        };
        let inner = self.inner.take().unwrap();
        // Both slots are empty now, so skipping `Drop` loses nothing.
        std::mem::forget(self);
        (item, OwnedReservation { inner })
    }
}
impl<T> Deref for PoolGuard<T> {
    type Target = T;

    /// Borrows the checked-out item.
    fn deref(&self) -> &Self::Target {
        match self.item.as_ref() {
            Some(item) => item,
            None => unreachable!("PoolGuard accessed after item taken"),
        }
    }
}
impl<T> DerefMut for PoolGuard<T> {
    /// Mutably borrows the checked-out item.
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self.item.as_mut() {
            Some(item) => item,
            None => unreachable!("PoolGuard accessed after item taken"),
        }
    }
}
impl<T> Drop for PoolGuard<T> {
    /// Returns the item to its pool, unless it was already moved out.
    fn drop(&mut self) {
        let pool = self.inner.take();
        let item = self.item.take();
        match (pool, item) {
            (Some(pool), Some(item)) => pool.checkin(item),
            // Nothing left to hand back.
            _ => {}
        }
    }
}
/// Handle produced by [`PoolGuard::into_owned`] that can return an item to the
/// pool it was checked out from.
///
/// No `Drop` impl is visible for this type here, so an item is only returned
/// when [`OwnedReservation::checkin`] is called explicitly.
pub struct OwnedReservation<T> {
    inner: std::sync::Arc<PoolInner<T>>,
}

impl<T> OwnedReservation<T> {
    /// Consumes the reservation and hands `item` back to the pool.
    pub fn checkin(self, item: T) {
        let Self { inner: pool } = self;
        pool.checkin(item);
    }
}
#[non_exhaustive]
pub struct DecoderState {
pub h: Vec<f32>,
pub c: Vec<f32>,
pub prev_token: i64,
pub consecutive_blanks: usize,
}
impl DecoderState {
pub fn new(blank_id: usize) -> Self {
Self {
h: vec![0.0; PRED_HIDDEN],
c: vec![0.0; PRED_HIDDEN],
prev_token: blank_id as i64,
consecutive_blanks: 0,
}
}
}
/// One recognized word with timing, confidence and optional speaker label.
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
pub struct WordInfo {
/// The word text, without the leading word-boundary marker.
pub word: String,
/// Start time in seconds, derived from the frame of the word's first token.
pub start: f64,
/// End time in seconds, derived from the frame of the word's last token.
pub end: f64,
/// Mean confidence of the tokens that make up the word.
pub confidence: f32,
/// Speaker id assigned by diarization; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub speaker: Option<u32>,
}
/// Per-stream state threaded through successive `Engine::process_chunk` calls.
#[non_exhaustive]
pub struct StreamingState {
/// Decoder state carried from chunk to chunk.
pub decoder: DecoderState,
/// Buffer handed to `FeatureExtractor::prepare_buffer` on every chunk.
pub audio_buffer: Vec<f32>,
/// Accumulates words until a segment is finalized.
pub assembler: TranscriptAssembler,
/// Running count of mel frames processed so far; passed to `run_inference`
/// as the frame offset for word timestamps.
pub total_frames: usize,
/// Optional resampler; not touched by the code in this file (presumably
/// managed by the caller — TODO confirm).
pub resampler: Option<rubato::SincFixedIn<f32>>,
/// Scratch buffer reused by the mel computation.
pub mel_fft_input: Vec<rustfft::num_complex::Complex<f32>>,
/// Scratch buffer reused by the mel computation.
pub mel_power: Vec<f32>,
/// Online diarizer; `Some` only when diarization was requested and a speaker
/// encoder is loaded.
#[cfg(feature = "diarization")]
pub diarization_state: Option<OnlineDiarizer>,
}
/// Thin façade over the audio buffering helper and the mel-spectrogram front end.
pub struct FeatureExtractor {
mel: MelSpectrogram,
}
impl Default for FeatureExtractor {
fn default() -> Self {
Self::new()
}
}
impl FeatureExtractor {
/// Creates an extractor backed by `MelSpectrogram::new()`.
pub fn new() -> Self {
Self {
mel: MelSpectrogram::new(),
}
}
/// Delegates to `audio::prepare_audio_buffer`.
///
/// Returns `None` when there is nothing to process yet; `process_chunk`
/// treats that as "no output for this chunk".
pub fn prepare_buffer(&self, samples: &[f32], audio_buffer: &mut Vec<f32>) -> Option<Vec<f32>> {
audio::prepare_audio_buffer(samples, audio_buffer)
}
/// Computes mel features using caller-provided scratch buffers.
///
/// Returns the feature data and the number of frames.
pub fn compute_mel(
&self,
samples: &[f32],
fft_buf: &mut Vec<rustfft::num_complex::Complex<f32>>,
power_buf: &mut Vec<f32>,
) -> (Vec<f32>, usize) {
self.mel.compute_with_buffers(samples, fft_buf, power_buf)
}
/// Computes mel features without caller-provided scratch buffers.
///
/// Returns the feature data and the number of frames.
pub fn compute(&self, samples: &[f32]) -> (Vec<f32>, usize) {
self.mel.compute(samples)
}
}
/// Accumulates recognized words and their space-joined text until the segment
/// is finalized.
pub struct TranscriptAssembler {
    text: String,
    words: Vec<WordInfo>,
}

impl Default for TranscriptAssembler {
    fn default() -> Self {
        Self::new()
    }
}

impl TranscriptAssembler {
    /// Creates an empty assembler.
    pub fn new() -> Self {
        Self {
            text: String::new(),
            words: Vec::new(),
        }
    }

    /// Appends `new_words`, extending the running text with single-space separators.
    pub fn append(&mut self, new_words: Vec<WordInfo>) {
        for entry in new_words {
            if !self.text.is_empty() {
                self.text.push(' ');
            }
            self.text.push_str(&entry.word);
            self.words.push(entry);
        }
    }

    /// Emits everything accumulated so far as a final segment and resets the
    /// assembler to empty.
    pub fn finalize(&mut self, timestamp: f64) -> TranscriptSegment {
        let Self { text, words } = std::mem::take(self);
        TranscriptSegment {
            text,
            words,
            is_final: true,
            timestamp,
        }
    }

    /// Emits a non-final snapshot of the accumulated content without resetting it.
    pub fn partial(&self, timestamp: f64) -> TranscriptSegment {
        let text = self.text.clone();
        let words = self.words.clone();
        TranscriptSegment {
            text,
            words,
            is_final: false,
            timestamp,
        }
    }

    /// Returns `true` while no text has been accumulated.
    pub fn is_empty(&self) -> bool {
        self.text.is_empty()
    }
}
/// Speech-to-text engine: a pool of ONNX session triplets plus the tokenizer
/// and feature extractor shared by all of them.
pub struct Engine {
/// Session triplets; callers check one out and pass it to the processing methods.
pub pool: SessionPool,
tokenizer: Tokenizer,
features: FeatureExtractor,
// True when the INT8 encoder file was found at load time.
int8: bool,
/// Speaker embedding model; `None` when the model file is missing or failed to load.
#[cfg(feature = "diarization")]
pub speaker_encoder: Option<OnnxEmbeddingExtractor>,
}
impl Engine {
/// Returns `true` when the INT8 encoder file was present at load time.
pub fn is_int8(&self) -> bool {
self.int8
}
/// Returns the vocabulary size reported by the tokenizer.
pub fn vocab_size(&self) -> usize {
self.tokenizer.vocab_size()
}
/// Loads the engine from `model_dir` with `DEFAULT_POOL_SIZE` session triplets.
///
/// # Errors
/// Same as [`Engine::load_with_pool_size`].
pub fn load(model_dir: &str) -> Result<Self, GigasttError> {
Self::load_with_pool_size(model_dir, DEFAULT_POOL_SIZE)
}
/// Loads the engine from `model_dir`, creating `pool_size` session triplets.
///
/// A `pool_size` of 0 yields an engine whose pool never hands out a session.
///
/// # Errors
/// Returns [`GigasttError::ModelLoad`] with `source: None` when the directory
/// contains no encoder model at all, and with the underlying error as
/// `source` when loading fails.
pub fn load_with_pool_size(model_dir: &str, pool_size: usize) -> Result<Self, GigasttError> {
    let dir = Path::new(model_dir);
    // `load_sessions` prefers the INT8 encoder and falls back to the fp32
    // one, so either file is enough; requiring the fp32 file here would reject
    // directories that ship only the quantized encoder.
    let has_encoder = ["v3_e2e_rnnt_encoder_int8.onnx", "v3_e2e_rnnt_encoder.onnx"]
        .iter()
        .any(|file_name| dir.join(file_name).exists());
    if !has_encoder {
        return Err(GigasttError::ModelLoad {
            path: model_dir.to_string(),
            source: None,
        });
    }
    Self::load_inner(dir, model_dir, pool_size).map_err(|e| GigasttError::ModelLoad {
        path: model_dir.to_string(),
        source: Some(e.into()),
    })
}
/// Builds one encoder/decoder/joiner session set from the model files in `dir`.
///
/// All three builders receive `prepacked`. The execution provider is chosen
/// at compile time: CoreML (`coreml` feature), CUDA (`cuda` feature) or plain
/// CPU otherwise.
///
/// # Errors
/// Any `ort` builder or model-loading error, converted through `ort_err`.
fn load_sessions(
dir: &Path,
prepacked: &ort::session::builder::PrepackedWeights,
) -> anyhow::Result<(Session, Session, Session)> {
// Prefer the INT8 encoder when it is present.
let encoder_path = if dir.join("v3_e2e_rnnt_encoder_int8.onnx").exists() {
dir.join("v3_e2e_rnnt_encoder_int8.onnx")
} else {
dir.join("v3_e2e_rnnt_encoder.onnx")
};
// CoreML build: CoreML provider first, CPU provider listed after it.
#[cfg(feature = "coreml")]
let (encoder, decoder, joiner) = {
// NOTE(review): unlike `optimized_cache` in the CPU branch, this directory
// is not created here — confirm the provider creates it on demand.
let cache_dir = dir.join("coreml_cache");
let coreml_ep = ep::CoreML::default()
.with_compute_units(ep::coreml::ComputeUnits::CPUAndNeuralEngine)
.with_specialization_strategy(ep::coreml::SpecializationStrategy::FastPrediction)
.with_model_cache_dir(cache_dir.to_string_lossy())
.build();
let cpu_fallback = ort::execution_providers::CPUExecutionProvider::default();
let eps = [coreml_ep.clone(), cpu_fallback.into()];
let encoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(&encoder_path)
.map_err(ort_err)?;
let decoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(dir.join("v3_e2e_rnnt_decoder.onnx"))
.map_err(ort_err)?;
let joiner = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(dir.join("v3_e2e_rnnt_joint.onnx"))
.map_err(ort_err)?;
(encoder, decoder, joiner)
};
// CUDA build: CUDA provider first, CPU provider listed after it.
#[cfg(feature = "cuda")]
let (encoder, decoder, joiner) = {
let cuda_ep = ep::CUDA::default().build();
let cpu_fallback = ort::execution_providers::CPUExecutionProvider::default();
let eps = [cuda_ep.clone(), cpu_fallback.into()];
let encoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(&encoder_path)
.map_err(ort_err)?;
let decoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(dir.join("v3_e2e_rnnt_decoder.onnx"))
.map_err(ort_err)?;
let joiner = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(dir.join("v3_e2e_rnnt_joint.onnx"))
.map_err(ort_err)?;
(encoder, decoder, joiner)
};
// CPU-only build: one intra-op and one inter-op thread per session.
#[cfg(not(any(feature = "coreml", feature = "cuda")))]
let (encoder, decoder, joiner) = {
let cache_dir = dir.join("optimized_cache");
// Best effort: a failure to create the directory is ignored here.
std::fs::create_dir_all(&cache_dir).ok();
let cpu_fallback = ort::execution_providers::CPUExecutionProvider::default();
let eps = [cpu_fallback.into()];
// NOTE(review): `load_inner` calls this function from several threads at
// once, all targeting the same `encoder_optimized.onnx`, and this function
// always loads `encoder_path` rather than the optimized file — confirm the
// optimized-model dump is intended.
// Only the encoder registers `eps`; decoder and joiner use the builder defaults.
let encoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_intra_threads(1)
.map_err(ort_err)?
.with_inter_threads(1)
.map_err(ort_err)?
.with_optimized_model_path(cache_dir.join("encoder_optimized.onnx"))
.map_err(ort_err)?
.with_execution_providers(&eps)
.map_err(ort_err)?
.commit_from_file(&encoder_path)
.map_err(ort_err)?;
let decoder = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_intra_threads(1)
.map_err(ort_err)?
.with_inter_threads(1)
.map_err(ort_err)?
.commit_from_file(dir.join("v3_e2e_rnnt_decoder.onnx"))
.map_err(ort_err)?;
let joiner = Session::builder()
.map_err(ort_err)?
.with_prepacked_weights(prepacked)
.map_err(ort_err)?
.with_intra_threads(1)
.map_err(ort_err)?
.with_inter_threads(1)
.map_err(ort_err)?
.commit_from_file(dir.join("v3_e2e_rnnt_joint.onnx"))
.map_err(ort_err)?;
(encoder, decoder, joiner)
};
Ok((encoder, decoder, joiner))
}
/// Loads `pool_size` session triplets in parallel, then the tokenizer and
/// (with the `diarization` feature) the optional speaker encoder.
///
/// `model_dir` is only used for logging; `dir` is the same location as a `Path`.
///
/// # Errors
/// Fails if any session triplet or the vocabulary cannot be loaded, or if a
/// loader thread panics. A missing or broken speaker encoder is logged and
/// tolerated.
fn load_inner(dir: &Path, model_dir: &str, pool_size: usize) -> anyhow::Result<Self> {
// Same file test that `load_sessions` uses to pick the encoder.
let is_int8 = dir.join("v3_e2e_rnnt_encoder_int8.onnx").exists();
if is_int8 {
tracing::info!("Using INT8 quantized encoder");
}
tracing::info!("Loading ONNX models from {model_dir} (pool_size={pool_size})...");
#[cfg(feature = "coreml")]
tracing::info!("Using CoreML execution provider (Neural Engine + CPU)");
#[cfg(feature = "cuda")]
tracing::info!("Using CUDA execution provider (falls back to CPU if unavailable)");
#[cfg(not(any(feature = "coreml", feature = "cuda")))]
tracing::info!("Using CPU execution provider");
// One prepacked-weights container, borrowed by every loader thread.
let prepacked = ort::session::builder::PrepackedWeights::new();
// Scoped threads let the loaders borrow `dir` and `prepacked` directly.
let triplets: Vec<SessionTriplet> = std::thread::scope(|s| {
let handles: Vec<_> = (0..pool_size)
.map(|i| {
let pp = &prepacked;
s.spawn(move || {
// A panic inside the loader becomes an ordinary error.
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
tracing::info!(
"Loading session triplet {}/{pool_size} (shared weights)",
i + 1
);
let (encoder, decoder, joiner) = Self::load_sessions(dir, pp)?;
Ok(SessionTriplet {
encoder,
decoder,
joiner,
})
}))
.map_err(|_| anyhow::anyhow!("model loading thread panicked"))?
})
})
.collect();
// Join in spawn order; the first error aborts the whole load.
handles
.into_iter()
.map(|h| match h.join() {
Ok(r) => r,
Err(_) => Err(anyhow::anyhow!("model loading thread panicked")),
})
.collect::<anyhow::Result<Vec<_>>>()
})?;
let tokenizer = Tokenizer::load(&dir.join("v3_e2e_rnnt_vocab.txt"))?;
let features = FeatureExtractor::new();
tracing::info!(
"Models loaded (vocab_size={}, pool_size={pool_size})",
tokenizer.vocab_size()
);
// The speaker encoder is optional: any problem only disables diarization.
#[cfg(feature = "diarization")]
let speaker_encoder = {
let model_path = dir.join("wespeaker_resnet34.onnx");
if model_path.exists() {
match OnnxEmbeddingExtractor::new(&model_path, SPEAKER_EMBEDDING_DIM, SPEAKER_SEGMENT_SAMPLES, SPEAKER_POOL_SIZE) {
Ok(enc) => {
tracing::info!("Speaker encoder loaded (diarization available)");
Some(enc)
}
Err(e) => {
tracing::warn!("Speaker encoder not loaded, diarization unavailable: {e:#}");
None
}
}
} else {
tracing::warn!("wespeaker_resnet34.onnx not found, diarization unavailable");
None
}
};
Ok(Self {
pool: SessionPool::new(triplets),
tokenizer,
features,
int8: is_int8,
#[cfg(feature = "diarization")]
speaker_encoder,
})
}
/// Returns `true` when a speaker encoder was loaded, i.e. diarization can be used.
#[cfg(feature = "diarization")]
pub fn has_speaker_encoder(&self) -> bool {
self.speaker_encoder.is_some()
}
/// Creates a fresh per-stream state.
///
/// With the `diarization` feature, an online diarizer is attached only when
/// `diarization_enabled` is set and a speaker encoder is loaded. Without the
/// feature, a `true` flag is ignored with a warning.
pub fn create_state(&self, diarization_enabled: bool) -> StreamingState {
#[cfg(feature = "diarization")]
let diarization_state = if diarization_enabled && self.speaker_encoder.is_some() {
// Fixed online-diarization settings for 16 kHz audio.
Some(OnlineDiarizer::new(DiaConfig {
threshold: 0.5,
max_speakers: 64,
window_secs: 1.5,
hop_secs: 0.75,
min_speech_secs: 0.25,
max_gap_secs: 0.5,
sample_rate: SampleRate::new(16000).expect("16kHz is valid"),
}))
} else {
None
};
#[cfg(not(feature = "diarization"))]
if diarization_enabled {
tracing::warn!(
"diarization_enabled=true ignored: build lacks the `diarization` feature"
);
}
StreamingState {
// The decoder starts from the blank token.
decoder: DecoderState::new(self.tokenizer.blank_id()),
audio_buffer: Vec::new(),
assembler: TranscriptAssembler::new(),
total_frames: 0,
resampler: None,
mel_fft_input: Vec::new(),
mel_power: Vec::new(),
#[cfg(feature = "diarization")]
diarization_state,
}
}
/// Processes one chunk of streaming audio and returns zero or one segment.
///
/// Returns an empty vector when the chunk is empty, when the buffer helper has
/// nothing to process yet, when no mel frame was produced, or when no new word
/// was decoded and no endpoint was detected. Otherwise returns one segment:
/// final on an endpoint, partial otherwise.
///
/// # Errors
/// Returns [`GigasttError::Inference`] when encoder or decoder inference fails.
pub fn process_chunk(
&self,
samples: &[f32],
state: &mut StreamingState,
triplet: &mut SessionTriplet,
) -> Result<Vec<TranscriptSegment>, GigasttError> {
if samples.is_empty() {
return Ok(vec![]);
}
// Keep the raw chunk for the diarizer, because `samples` is shadowed below.
#[cfg(feature = "diarization")]
let samples_16k_copy = if state.diarization_state.is_some() {
Some(samples.to_vec())
} else {
None
};
let samples = match self
.features
.prepare_buffer(samples, &mut state.audio_buffer)
{
Some(s) => s,
None => return Ok(vec![]),
};
let samples = &samples[..];
let mel_start = std::time::Instant::now();
let (features, num_frames) =
self.features
.compute_mel(samples, &mut state.mel_fft_input, &mut state.mel_power);
tracing::debug!(
elapsed_us = mel_start.elapsed().as_micros() as u64,
"mel_compute"
);
if num_frames == 0 {
return Ok(vec![]);
}
#[cfg_attr(not(feature = "diarization"), allow(unused_mut))]
let (mut new_words, endpoint) = self
.run_inference(
triplet,
&features,
num_frames,
&mut state.decoder,
state.total_frames,
)
.map_err(|e| GigasttError::Inference { source: e.into() })?;
// Advance the offset used for the next chunk's word timestamps.
state.total_frames += num_frames;
// NOTE(review): on the early returns above (and on an inference error) the
// copied chunk is dropped without being fed to the diarizer — confirm that
// skipping those chunks is intended.
#[cfg(feature = "diarization")]
if let (Some(dia), Some(copy), Some(enc)) = (
state.diarization_state.as_mut(),
samples_16k_copy.as_ref(),
self.speaker_encoder.as_ref(),
) {
// A diarization failure is logged and does not fail the chunk.
if let Err(e) = dia.feed(copy, enc).map(|_segs| ()) {
tracing::warn!("Diarization feed failed: {e:#}");
}
// All words of this chunk get the diarizer's current speaker.
if let Some(speaker_id) = dia.current_speaker() {
for w in &mut new_words {
w.speaker = Some(speaker_id.0);
}
}
}
if new_words.is_empty() && !endpoint {
return Ok(vec![]);
}
state.assembler.append(new_words);
let ts = now_timestamp();
if endpoint {
// An endpoint closes the segment and restarts blank counting.
state.decoder.consecutive_blanks = 0;
Ok(vec![state.assembler.finalize(ts)])
} else {
Ok(vec![state.assembler.partial(ts)])
}
}
/// Finalizes whatever the stream's assembler still holds.
///
/// Returns `None` when nothing has been accumulated; otherwise a final
/// segment stamped with the current time.
pub fn flush_state(&self, state: &mut StreamingState) -> Option<TranscriptSegment> {
    let has_pending_text = !state.assembler.is_empty();
    has_pending_text.then(|| state.assembler.finalize(now_timestamp()))
}
/// Decodes the audio file at `path` and transcribes it in one pass.
///
/// # Errors
/// Returns [`GigasttError::InvalidAudio`] when the file cannot be decoded, and
/// whatever `transcribe_samples` returns when inference fails.
pub fn transcribe_file(
    &self,
    path: &str,
    triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
    let float_samples = match audio::decode_audio_file(path) {
        Ok(decoded) => decoded,
        Err(e) => {
            return Err(GigasttError::InvalidAudio {
                reason: format!("{e:#}"),
            })
        }
    };
    self.transcribe_samples(&float_samples, triplet)
}
/// Transcribes an in-memory encoded audio buffer.
///
/// Copies `data` into a `bytes::Bytes` and delegates to
/// [`Engine::transcribe_bytes_shared`]; callers that already hold a `Bytes`
/// can call that method directly and skip the copy.
///
/// # Errors
/// Same as [`Engine::transcribe_bytes_shared`].
pub fn transcribe_bytes(
&self,
data: &[u8],
triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
self.transcribe_bytes_shared(bytes::Bytes::copy_from_slice(data), triplet)
}
/// Decodes an encoded audio buffer held in `bytes::Bytes` and transcribes it.
///
/// # Errors
/// Returns [`GigasttError::InvalidAudio`] when the buffer cannot be decoded,
/// and whatever `transcribe_samples` returns when inference fails.
pub fn transcribe_bytes_shared(
    &self,
    data: bytes::Bytes,
    triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
    let float_samples = match audio::decode_audio_bytes_shared(data) {
        Ok(decoded) => decoded,
        Err(e) => {
            return Err(GigasttError::InvalidAudio {
                reason: format!("{e:#}"),
            })
        }
    };
    self.transcribe_samples(&float_samples, triplet)
}
/// Transcribes already-decoded samples in one pass.
///
/// `duration_s` is computed for a 16 kHz sample rate. With the `diarization`
/// feature and a loaded speaker encoder, each word is labelled with the
/// speaker of the first diarization turn that contains the word's midpoint; a
/// diarization failure is logged and leaves the words unlabelled.
///
/// Audio that yields no mel frame (empty or too short) produces an empty
/// result instead of being sent to the encoder.
///
/// # Errors
/// Returns [`GigasttError::Inference`] when encoder or decoder inference fails.
fn transcribe_samples(
    &self,
    float_samples: &[f32],
    triplet: &mut SessionTriplet,
) -> Result<TranscribeResult, GigasttError> {
    let duration_s = float_samples.len() as f64 / 16000.0;
    let (features, num_frames) = self.features.compute(float_samples);
    tracing::info!("Extracted {} mel frames", num_frames);
    // Same guard as `process_chunk`: a tensor with a zero-length time axis
    // must not reach the encoder.
    if num_frames == 0 {
        return Ok(TranscribeResult {
            text: String::new(),
            words: Vec::new(),
            duration_s,
        });
    }
    let mut decoder_state = DecoderState::new(self.tokenizer.blank_id());
    #[cfg_attr(not(feature = "diarization"), allow(unused_mut))]
    let (mut words, _endpoint) = self
        .run_inference(triplet, &features, num_frames, &mut decoder_state, 0)
        .map_err(|e| GigasttError::Inference { source: e.into() })?;
    #[cfg(feature = "diarization")]
    if let Some(ref enc) = self.speaker_encoder {
        let config = DiaConfig::default();
        let diarizer = OfflineDiarizer::new(config);
        match diarizer.run(float_samples, enc) {
            Ok(dia_result) => {
                for word in &mut words {
                    // Assign the speaker whose turn covers the word's midpoint.
                    let mid = (word.start + word.end) / 2.0;
                    if let Some(turn) = dia_result
                        .turns
                        .iter()
                        .find(|t| t.time.start <= mid && t.time.end >= mid)
                    {
                        word.speaker = Some(turn.speaker.0);
                    }
                }
            }
            Err(e) => {
                tracing::warn!("Offline diarization failed: {e:#}");
            }
        }
    }
    let text: String = words
        .iter()
        .map(|w| w.word.as_str())
        .collect::<Vec<_>>()
        .join(" ");
    Ok(TranscribeResult {
        text,
        words,
        duration_s,
    })
}
/// Runs the encoder on `features` and greedy-decodes its output into words.
///
/// `features` is passed to the encoder as a tensor of shape
/// `[1, N_MELS, num_frames]`. `frame_offset` is added to token frame indices
/// when word timestamps are computed. Returns the decoded words and whether
/// the decoder reported an endpoint.
///
/// # Errors
/// Fails when tensor construction, encoder inference, output extraction or
/// decoding fails, or when the encoder reports an empty or negative length.
fn run_inference(
    &self,
    triplet: &mut SessionTriplet,
    features: &[f32],
    num_frames: usize,
    decoder_state: &mut DecoderState,
    frame_offset: usize,
) -> anyhow::Result<(Vec<WordInfo>, bool)> {
    let signal_tensor = TensorRef::from_array_view(([1_usize, N_MELS, num_frames], features))?;
    let length_data = [num_frames as i64];
    let length_tensor = TensorRef::from_array_view(([1_usize], length_data.as_slice()))?;
    let enc_start = std::time::Instant::now();
    let encoder_outputs = triplet
        .encoder
        .run(ort::inputs![signal_tensor, length_tensor])
        .context("Encoder inference failed")?;
    tracing::info!(
        elapsed_ms = enc_start.elapsed().as_millis() as u64,
        "encoder_inference"
    );
    let (_enc_shape, enc_data) = encoder_outputs[0]
        .try_extract_tensor::<f32>()
        .context("Failed to extract encoder output")?;
    let (_len_shape, len_data) = encoder_outputs[1]
        .try_extract_tensor::<i32>()
        .context("Failed to extract encoder length")?;
    // Report an empty length tensor as an error rather than panicking on
    // an out-of-bounds index.
    let raw_len = len_data
        .first()
        .copied()
        .context("Encoder returned an empty length tensor")?;
    let enc_len = usize::try_from(raw_len).context("Negative encoder length")?;
    tracing::debug!("Encoder output: {} frames", enc_len);
    // Copy the output so the encoder's output borrow can end before decoding.
    let enc_data_owned: Vec<f32> = enc_data.to_vec();
    drop(encoder_outputs);
    let dec_start = std::time::Instant::now();
    let result = decode::greedy_decode(
        &mut triplet.decoder,
        &mut triplet.joiner,
        &enc_data_owned,
        enc_len,
        self.tokenizer.blank_id(),
        decoder_state,
    )?;
    tracing::info!(
        elapsed_ms = dec_start.elapsed().as_millis() as u64,
        "greedy_decode"
    );
    let words = self.tokens_to_words(&result.tokens, frame_offset);
    tracing::info!(
        tokens = result.tokens.len(),
        words = words.len(),
        duration_ms = dec_start.elapsed().as_millis() as u64,
        "Decoded tokens"
    );
    Ok((words, result.endpoint_detected))
}
/// Groups decoded tokens into words.
///
/// A token whose text starts with U+2581 begins a new word. Word times are
/// the frames of the first and last contributing tokens, shifted by
/// `frame_offset` and scaled by `SECONDS_PER_FRAME`. Confidence is the mean
/// of the contributing tokens' confidences.
fn tokens_to_words(&self, tokens: &[decode::TokenInfo], frame_offset: usize) -> Vec<WordInfo> {
    /// The word currently being assembled.
    struct Pending {
        text: String,
        first_frame: Option<usize>,
        last_frame: usize,
        confidences: Vec<f32>,
    }

    impl Pending {
        /// Emits the accumulated word and resets the accumulator for the next one.
        fn flush(&mut self, frame_offset: usize) -> WordInfo {
            let confidence: f32 = if self.confidences.is_empty() {
                1.0
            } else {
                self.confidences.iter().sum::<f32>() / self.confidences.len() as f32
            };
            let first_frame = self.first_frame.take().unwrap_or(0);
            self.confidences.clear();
            WordInfo {
                word: std::mem::take(&mut self.text),
                start: (first_frame + frame_offset) as f64 * SECONDS_PER_FRAME,
                end: (self.last_frame + frame_offset) as f64 * SECONDS_PER_FRAME,
                confidence,
                speaker: None,
            }
        }
    }

    if tokens.is_empty() {
        return Vec::new();
    }
    let mut words = Vec::new();
    let mut pending = Pending {
        text: String::new(),
        first_frame: None,
        last_frame: 0,
        confidences: Vec::new(),
    };
    for token in tokens {
        let piece = self.tokenizer.token_text(token.token_id);
        let (starts_word, body) = match piece.strip_prefix('\u{2581}') {
            Some(rest) => (true, rest),
            None => (false, piece),
        };
        // A boundary marker closes the word collected so far.
        if starts_word && !pending.text.is_empty() {
            words.push(pending.flush(frame_offset));
        }
        if body.is_empty() {
            continue;
        }
        pending.text.push_str(body);
        if pending.first_frame.is_none() {
            pending.first_frame = Some(token.frame_index);
        }
        pending.last_frame = token.frame_index;
        pending.confidences.push(token.confidence);
    }
    // Emit the trailing word, which no later boundary marker closes.
    if !pending.text.is_empty() {
        words.push(pending.flush(frame_offset));
    }
    words
}
}
/// Result of a one-shot (non-streaming) transcription.
#[derive(Debug, Clone, Serialize)]
pub struct TranscribeResult {
/// All recognized words joined with single spaces.
pub text: String,
/// The recognized words, in order.
pub words: Vec<WordInfo>,
/// Input duration in seconds, computed from the sample count at 16 kHz.
pub duration_s: f64,
}
/// A partial or final piece of a streaming transcript.
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
pub struct TranscriptSegment {
/// Words accumulated so far, joined with single spaces.
pub text: String,
/// The words behind `text`.
pub words: Vec<WordInfo>,
/// `true` for segments produced by `TranscriptAssembler::finalize`.
pub is_final: bool,
/// Timestamp supplied when the segment was emitted (seconds since the Unix
/// epoch for segments created in this file).
pub timestamp: f64,
}
// Unit tests for `DecoderState` construction and the generic `Pool`.
#[cfg(test)]
mod tests {
use super::*;
// A fresh decoder state is all zeros and starts from the blank token.
#[test]
fn test_decoder_state_new_zeros() {
let blank_id = 1024;
let state = DecoderState::new(blank_id);
assert!(state.h.iter().all(|&v| v == 0.0));
assert!(state.c.iter().all(|&v| v == 0.0));
assert_eq!(state.prev_token, blank_id as i64);
}
// Both state vectors have `PRED_HIDDEN` entries.
#[test]
fn test_decoder_state_dimensions() {
let state = DecoderState::new(1024);
assert_eq!(state.h.len(), PRED_HIDDEN);
assert_eq!(state.c.len(), PRED_HIDDEN);
}
// The blank id passed in becomes the initial previous token.
#[test]
fn test_decoder_state_custom_blank_id() {
let state = DecoderState::new(42);
assert_eq!(state.prev_token, 42);
}
// Dropping a guard puts its item back into the pool.
#[tokio::test]
async fn test_pool_guard_returns_triplet_on_normal_drop() {
let pool = Pool::new(vec![1u32, 2, 3]);
assert_eq!(pool.available(), 3);
{
let _guard = pool.checkout().await.expect("checkout");
assert_eq!(pool.available(), 2);
}
assert_eq!(pool.available(), 3);
}
// The item is returned even when the holding task panics.
#[tokio::test]
async fn test_pool_guard_returns_triplet_on_panic_unwind() {
let pool = std::sync::Arc::new(Pool::new(vec![1u32]));
assert_eq!(pool.available(), 1);
let pool_clone = pool.clone();
let result = tokio::spawn(async move {
let _guard = pool_clone.checkout().await.expect("checkout");
assert_eq!(pool_clone.available(), 0);
panic!("synthetic inference panic");
})
.await;
assert!(result.is_err(), "spawned task must report the panic");
assert_eq!(pool.available(), 1);
}
// Closing the pool wakes a parked checkout with `PoolError::Closed`.
#[tokio::test]
async fn test_pool_close_wakes_waiters_with_closed() {
let pool = std::sync::Arc::new(Pool::<u32>::new(vec![]));
let waiter = tokio::spawn({
let pool = pool.clone();
async move { pool.checkout().await.map(|_g| ()) }
});
// Give the spawned task time to park in the waiter queue.
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
pool.close();
let res = waiter.await.expect("join");
assert!(matches!(res, Err(PoolError::Closed)));
}
// Parked checkouts are served in the order they queued up.
#[tokio::test]
async fn test_pool_fifo_under_contention() {
let pool = std::sync::Arc::new(Pool::new(vec![0u32]));
let primary = pool.checkout().await.expect("primary checkout");
assert_eq!(pool.available(), 0);
let waker_log = std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));
let mut handles = Vec::new();
for id in 0u32..3 {
let pool = pool.clone();
let log = waker_log.clone();
handles.push(tokio::spawn(async move {
let g = pool.checkout().await.expect("checkout");
log.lock().await.push(id);
drop(g);
}));
// Stagger the spawns so the waiters enqueue in id order.
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
drop(primary);
for h in handles {
h.await.expect("join");
}
let log = waker_log.lock().await.clone();
assert_eq!(log, vec![0, 1, 2], "waiters must wake in FIFO order");
}
// `into_owned` lets the item cross into `spawn_blocking` and be checked in there.
#[tokio::test]
async fn test_into_owned_for_spawn_blocking() {
let pool = std::sync::Arc::new(Pool::new(vec![String::from("triplet")]));
let guard = pool.checkout().await.expect("checkout");
let (item, reservation) = guard.into_owned();
let item = tokio::task::spawn_blocking(move || {
assert_eq!(item, "triplet");
reservation.checkin(item.clone());
item
})
.await
.expect("join");
assert_eq!(pool.available(), 1);
assert_eq!(item, "triplet");
}
// Closing twice must not panic.
#[tokio::test]
async fn test_pool_close_is_idempotent() {
let pool = Pool::<u32>::new(vec![]);
pool.close();
pool.close();
}
}