mod decode;
mod features;
mod tokenizer;
pub mod audio;
#[cfg(feature = "diarization")]
pub mod diarization;
#[cfg(all(feature = "coreml", feature = "cuda"))]
compile_error!("Features `coreml` and `cuda` are mutually exclusive. Choose one.");
use anyhow::Context;
#[cfg(any(feature = "coreml", feature = "cuda"))]
use ort::ep;
use ort::session::Session;
use ort::value::TensorRef;
use serde::Serialize;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::error::GigasttError;
use features::MelSpectrogram;
use tokenizer::Tokenizer;
/// Number of mel channels per feature frame; the encoder input is shaped
/// `[1, N_MELS, num_frames]`.
pub const N_MELS: usize = 64;
/// FFT window size in samples (presumably consumed by the `features` module — not used
/// directly in this file).
pub const N_FFT: usize = 320;
/// Samples between consecutive mel frames.
pub const HOP_LENGTH: usize = 160;
/// Length of each prediction-network state vector (`DecoderState::h` / `DecoderState::c`).
pub const PRED_HIDDEN: usize = 320;
fn ort_err(e: impl std::fmt::Display) -> anyhow::Error {
anyhow::anyhow!("{e}")
}
/// Current wall-clock time as fractional seconds since the UNIX epoch.
///
/// Returns `0.0` if the system clock reports a time before the epoch.
pub(crate) fn now_timestamp() -> f64 {
    std::time::UNIX_EPOCH
        .elapsed()
        .map_or(0.0, |since_epoch| since_epoch.as_secs_f64())
}
/// Duration in seconds of one frame index as used for word timestamps: four hops at
/// 16 kHz (0.04 s). The factor 4 presumably reflects the encoder's time subsampling —
/// TODO confirm against the model.
const SECONDS_PER_FRAME: f64 = 4.0 * HOP_LENGTH as f64 / 16000.0;
/// Number of session triplets loaded by [`Engine::load`].
const DEFAULT_POOL_SIZE: usize = 4;
/// One independent set of ONNX sessions (encoder, prediction network, joint network)
/// needed to run a full RNN-T inference. Instances are handed out by [`SessionPool`].
pub struct SessionTriplet {
    /// Acoustic encoder session.
    pub(crate) encoder: Session,
    /// Decoder (prediction network) session used during greedy decoding.
    pub(crate) decoder: Session,
    /// Joint network session used during greedy decoding.
    pub(crate) joiner: Session,
}
/// A fixed-size pool of [`SessionTriplet`]s shared between concurrent callers.
///
/// Triplets circulate through a bounded `tokio` mpsc channel whose capacity equals the
/// number of triplets, so checking a triplet back in never has to wait for space as long
/// as only triplets taken from this pool are returned to it.
pub struct SessionPool {
    /// Producer side used to return triplets to the pool.
    sender: tokio::sync::mpsc::Sender<SessionTriplet>,
    /// The single consumer, behind an async mutex so `checkout` can work through `&self`.
    receiver: tokio::sync::Mutex<tokio::sync::mpsc::Receiver<SessionTriplet>>,
    /// Number of triplets the pool was created with.
    total: usize,
    /// Informational count of idle triplets; may be momentarily stale.
    available: AtomicUsize,
}
impl SessionPool {
    /// Builds a pool that initially holds every triplet in `triplets`.
    ///
    /// # Panics
    /// Panics if `triplets` is empty (a bounded tokio channel cannot have capacity 0,
    /// and an empty pool would make every `checkout` wait forever).
    pub fn new(triplets: Vec<SessionTriplet>) -> Self {
        let total = triplets.len();
        assert!(total > 0, "SessionPool requires at least one session triplet");
        let (sender, receiver) = tokio::sync::mpsc::channel(total);
        for triplet in triplets {
            sender.try_send(triplet).expect("channel capacity matches triplet count");
        }
        Self {
            sender,
            receiver: tokio::sync::Mutex::new(receiver),
            total,
            available: AtomicUsize::new(total),
        }
    }

    /// Waits until a triplet is idle and takes it out of the pool.
    ///
    /// The caller must hand it back with [`checkin`](Self::checkin) or
    /// [`blocking_checkin`](Self::blocking_checkin), otherwise the pool shrinks permanently.
    pub async fn checkout(&self) -> SessionTriplet {
        let triplet = self
            .receiver
            .lock()
            .await
            .recv()
            .await
            .expect("Pool sender dropped — this is a bug");
        self.available.fetch_sub(1, Ordering::Relaxed);
        triplet
    }

    /// Returns a triplet to the pool from async code.
    pub async fn checkin(&self, triplet: SessionTriplet) {
        // Count the triplet as available *before* it becomes receivable: otherwise a
        // waiting `checkout` could decrement first and wrap the counter below zero.
        self.available.fetch_add(1, Ordering::Relaxed);
        self.sender
            .send(triplet)
            .await
            .expect("Pool receiver dropped — this is a bug");
    }

    /// Returns a triplet to the pool from synchronous (non-async) code.
    ///
    /// # Panics
    /// Per tokio's documentation, `blocking_send` panics when called from within an
    /// asynchronous execution context; call this from plain threads or `spawn_blocking`.
    pub fn blocking_checkin(&self, triplet: SessionTriplet) {
        // Same ordering rationale as `checkin`: increment before the triplet is visible.
        self.available.fetch_add(1, Ordering::Relaxed);
        self.sender
            .blocking_send(triplet)
            .expect("Pool receiver dropped — this is a bug");
    }

    /// Number of triplets the pool was created with.
    pub fn total(&self) -> usize {
        self.total
    }

    /// Snapshot of how many triplets are currently idle (informational only).
    pub fn available(&self) -> usize {
        self.available.load(Ordering::Relaxed)
    }
}
#[non_exhaustive]
pub struct DecoderState {
pub h: Vec<f32>,
pub c: Vec<f32>,
pub prev_token: i64,
pub consecutive_blanks: usize,
}
impl DecoderState {
pub fn new(blank_id: usize) -> Self {
Self {
h: vec![0.0; PRED_HIDDEN],
c: vec![0.0; PRED_HIDDEN],
prev_token: blank_id as i64,
consecutive_blanks: 0,
}
}
}
/// A recognized word with timing, confidence and optional speaker label.
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
pub struct WordInfo {
    /// Word text with sentencepiece boundary markers removed.
    pub word: String,
    /// Start time in seconds.
    pub start: f64,
    /// End time in seconds (start of the word's last token frame).
    pub end: f64,
    /// Mean confidence of the tokens forming the word.
    pub confidence: f32,
    /// Speaker id assigned by diarization; omitted from JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speaker: Option<u32>,
}
/// Streaming speaker-diarization state kept alongside a [`StreamingState`].
#[cfg(feature = "diarization")]
pub struct DiarizationStreamState {
    /// Audio waiting to be cut into `diarization::SEGMENT_SAMPLES`-sized segments.
    pub audio_buffer: Vec<f32>,
    /// Clusters segment embeddings into speaker ids.
    pub cluster: diarization::SpeakerCluster,
    /// Speaker id assigned to the most recently embedded segment, if any.
    pub current_speaker: Option<u32>,
}
/// All per-stream state needed by `Engine::process_chunk`; create it with
/// `Engine::create_state`.
#[non_exhaustive]
pub struct StreamingState {
    /// Prediction-network state carried across chunks.
    pub decoder: DecoderState,
    /// Buffer handed to `audio::prepare_audio_buffer` (presumably holds samples not yet
    /// decoded — TODO confirm against the `audio` module).
    pub audio_buffer: Vec<f32>,
    /// Space-joined words recognized since the last final segment.
    pub accumulated_text: String,
    /// Word details matching `accumulated_text`.
    pub accumulated_words: Vec<WordInfo>,
    /// Running total of mel frames decoded so far in this stream.
    pub total_frames: usize,
    /// Present only when diarization was requested and a speaker encoder is loaded.
    #[cfg(feature = "diarization")]
    pub diarization_state: Option<DiarizationStreamState>,
}
/// Speech-recognition engine: pooled ONNX sessions plus tokenizer and feature extractor.
pub struct Engine {
    /// Pool of session triplets; callers check one out per inference.
    pub pool: SessionPool,
    /// Maps token ids to text.
    tokenizer: Tokenizer,
    /// Mel-spectrogram feature extractor.
    mel: MelSpectrogram,
    /// Speaker-embedding model; `None` when it failed to load at startup.
    #[cfg(feature = "diarization")]
    pub speaker_encoder: Option<diarization::SpeakerEncoder>,
}
impl Engine {
/// Loads the models from `model_dir` with the default pool size (`DEFAULT_POOL_SIZE`).
///
/// # Errors
/// See [`Engine::load_with_pool_size`].
pub fn load(model_dir: &str) -> Result<Self, GigasttError> {
    Self::load_with_pool_size(model_dir, DEFAULT_POOL_SIZE)
}

/// Loads the models from `model_dir`, creating `pool_size` independent session triplets.
///
/// # Errors
/// Returns [`GigasttError::ModelLoad`] if `pool_size` is zero, if
/// `v3_e2e_rnnt_encoder.onnx` is missing, or if any model/vocabulary fails to load.
pub fn load_with_pool_size(model_dir: &str, pool_size: usize) -> Result<Self, GigasttError> {
    // A zero-sized pool would otherwise load nothing and then panic while building the
    // pool's channel; reject it up front as a load error.
    if pool_size == 0 {
        return Err(GigasttError::ModelLoad(
            "pool_size must be at least 1".to_string(),
        ));
    }
    let dir = Path::new(model_dir);
    if !dir.join("v3_e2e_rnnt_encoder.onnx").exists() {
        return Err(GigasttError::ModelLoad(format!(
            "v3_e2e_rnnt_encoder.onnx not found in {model_dir}"
        )));
    }
    Self::load_inner(dir, model_dir, pool_size)
        .map_err(|e| GigasttError::ModelLoad(format!("{e:#}")))
}
/// Opens one encoder/decoder/joiner session triplet from `dir`.
///
/// The INT8 encoder (`v3_e2e_rnnt_encoder_int8.onnx`) is preferred when present. All
/// three sessions share the execution provider selected by the `coreml` / `cuda`
/// feature (CPU when neither is enabled).
///
/// # Errors
/// Returns an error if any session builder step or model load fails.
fn load_sessions(dir: &Path) -> anyhow::Result<(Session, Session, Session)> {
    let int8_encoder = dir.join("v3_e2e_rnnt_encoder_int8.onnx");
    let encoder_path = if int8_encoder.exists() {
        int8_encoder
    } else {
        dir.join("v3_e2e_rnnt_encoder.onnx")
    };

    #[cfg(feature = "coreml")]
    let cache_dir = dir.join("coreml_cache");
    #[cfg(feature = "coreml")]
    let execution_provider = ep::CoreML::default()
        .with_compute_units(ep::coreml::ComputeUnits::CPUAndNeuralEngine)
        .with_specialization_strategy(ep::coreml::SpecializationStrategy::FastPrediction)
        .with_model_cache_dir(cache_dir.to_string_lossy())
        .build();
    #[cfg(feature = "cuda")]
    let execution_provider = ep::CUDA::default().build();

    // One builder recipe per configuration, applied identically to every model so the
    // three sessions can never drift apart in provider settings.
    #[cfg(any(feature = "coreml", feature = "cuda"))]
    let open = |model: &Path| -> anyhow::Result<Session> {
        Session::builder()
            .map_err(ort_err)?
            .with_execution_providers([execution_provider.clone()])
            .map_err(ort_err)?
            .commit_from_file(model)
            .map_err(ort_err)
    };
    #[cfg(not(any(feature = "coreml", feature = "cuda")))]
    let open = |model: &Path| -> anyhow::Result<Session> {
        Session::builder()
            .map_err(ort_err)?
            .commit_from_file(model)
            .map_err(ort_err)
    };

    let encoder = open(encoder_path.as_path())?;
    let decoder = open(dir.join("v3_e2e_rnnt_decoder.onnx").as_path())?;
    let joiner = open(dir.join("v3_e2e_rnnt_joint.onnx").as_path())?;
    Ok((encoder, decoder, joiner))
}
/// Loads the session pool, tokenizer, mel extractor and (optionally) speaker encoder.
///
/// # Errors
/// Returns an error if any session triplet fails to load (including a loader thread
/// panicking) or if the vocabulary cannot be read.
fn load_inner(dir: &Path, model_dir: &str, pool_size: usize) -> anyhow::Result<Self> {
    if dir.join("v3_e2e_rnnt_encoder_int8.onnx").exists() {
        tracing::info!("Using INT8 quantized encoder");
    }
    tracing::info!("Loading ONNX models from {model_dir} (pool_size={pool_size})...");
    #[cfg(feature = "coreml")]
    tracing::info!("Using CoreML execution provider (Neural Engine + CPU)");
    #[cfg(feature = "cuda")]
    tracing::info!("Using CUDA execution provider (falls back to CPU if unavailable)");
    #[cfg(not(any(feature = "coreml", feature = "cuda")))]
    tracing::info!("Using CPU execution provider");
    let triplets = Self::load_triplets(dir, pool_size)?;
    let tokenizer = Tokenizer::load(&dir.join("v3_e2e_rnnt_vocab.txt"))?;
    let mel = MelSpectrogram::new();
    tracing::info!("Models loaded (vocab_size={}, pool_size={pool_size})", tokenizer.vocab_size());
    // A missing/broken speaker model only disables diarization; it is not a load error.
    #[cfg(feature = "diarization")]
    let speaker_encoder = match diarization::SpeakerEncoder::load(dir) {
        Ok(enc) => {
            tracing::info!("Speaker encoder loaded (diarization available)");
            Some(enc)
        }
        Err(e) => {
            tracing::warn!("Speaker encoder not loaded, diarization unavailable: {e:#}");
            None
        }
    };
    Ok(Self {
        pool: SessionPool::new(triplets),
        tokenizer,
        mel,
        #[cfg(feature = "diarization")]
        speaker_encoder,
    })
}

/// Loads `pool_size` session triplets in parallel, one scoped thread per triplet.
///
/// Every thread is joined before any result is inspected, so one failure never leaves
/// other threads to be auto-joined (and possibly re-panic) at scope exit. A panicking
/// loader thread is reported as an error instead of aborting the caller.
fn load_triplets(dir: &Path, pool_size: usize) -> anyhow::Result<Vec<SessionTriplet>> {
    std::thread::scope(|s| {
        let handles: Vec<_> = (0..pool_size)
            .map(|i| {
                s.spawn(move || -> anyhow::Result<SessionTriplet> {
                    tracing::info!("Loading session triplet {}/{pool_size}", i + 1);
                    let (encoder, decoder, joiner) = Self::load_sessions(dir)?;
                    Ok(SessionTriplet { encoder, decoder, joiner })
                })
            })
            .collect();
        let outcomes: Vec<_> = handles.into_iter().map(|h| h.join()).collect();
        outcomes
            .into_iter()
            .enumerate()
            .map(|(i, outcome)| {
                outcome.unwrap_or_else(|payload| {
                    let reason = payload
                        .downcast_ref::<&str>()
                        .map(|msg| (*msg).to_owned())
                        .or_else(|| payload.downcast_ref::<String>().cloned())
                        .unwrap_or_else(|| "non-string panic payload".to_owned());
                    Err(anyhow::anyhow!(
                        "Thread panicked while loading session triplet {}: {reason}",
                        i + 1
                    ))
                })
            })
            .collect()
    })
}
/// Whether a speaker encoder was loaded, i.e. whether diarization can be enabled.
#[cfg(feature = "diarization")]
pub fn has_speaker_encoder(&self) -> bool {
    matches!(self.speaker_encoder, Some(_))
}
/// Creates a fresh [`StreamingState`] for a new audio stream.
///
/// With the `diarization` feature, diarization state is attached only when
/// `diarization_enabled` is true *and* a speaker encoder is loaded.
pub fn create_state(&self, #[cfg(feature = "diarization")] diarization_enabled: bool) -> StreamingState {
    #[cfg(feature = "diarization")]
    let diarization_state = (diarization_enabled && self.speaker_encoder.is_some()).then(|| {
        DiarizationStreamState {
            audio_buffer: Vec::new(),
            cluster: diarization::SpeakerCluster::new(),
            current_speaker: None,
        }
    });
    let decoder = DecoderState::new(self.tokenizer.blank_id());
    StreamingState {
        decoder,
        audio_buffer: Vec::new(),
        accumulated_text: String::new(),
        accumulated_words: Vec::new(),
        total_frames: 0,
        #[cfg(feature = "diarization")]
        diarization_state,
    }
}
/// Feeds one chunk of streaming audio through the recognizer.
///
/// `samples` are passed to [`audio::prepare_audio_buffer`] together with
/// `state.audio_buffer`. When that helper yields nothing, or the audio produces no mel
/// frames, nothing is decoded and an empty vector is returned. Otherwise the new words
/// are appended to the accumulated transcript, and one segment is returned:
/// - a partial segment (`is_final == false`) with everything since the last endpoint, or
/// - a final segment when the decoder reports an endpoint, after which the accumulated
///   text and words are reset.
///
/// With the `diarization` feature, every incoming sample is appended to the diarization
/// buffer *before* any early return, so audio the ASR path buffers for a later call is
/// not lost to speaker embedding.
///
/// # Errors
/// Returns [`GigasttError::Inference`] when model inference fails.
pub fn process_chunk(
    &self,
    samples: &[f32],
    state: &mut StreamingState,
    triplet: &mut SessionTriplet,
) -> Result<Vec<TranscriptSegment>, GigasttError> {
    if samples.is_empty() {
        return Ok(vec![]);
    }
    #[cfg(feature = "diarization")]
    if self.speaker_encoder.is_some() {
        if let Some(dia) = state.diarization_state.as_mut() {
            dia.audio_buffer.extend_from_slice(samples);
        }
    }
    let samples = match audio::prepare_audio_buffer(samples, &mut state.audio_buffer) {
        Some(s) => s,
        None => return Ok(vec![]),
    };
    let samples = &samples[..];
    let (features, num_frames) = self.mel.compute(samples);
    if num_frames == 0 {
        return Ok(vec![]);
    }
    #[cfg_attr(not(feature = "diarization"), allow(unused_mut))]
    let (mut new_words, endpoint) = self
        .run_inference(triplet, &features, num_frames, &mut state.decoder, state.total_frames)
        .map_err(|e| GigasttError::Inference(format!("{e:#}")))?;
    // Offset for the next chunk, counted in mel frames.
    state.total_frames += num_frames;
    #[cfg(feature = "diarization")]
    if let Some(dia) = state.diarization_state.as_mut() {
        self.label_speakers(dia, &mut new_words);
    }
    if new_words.is_empty() && !endpoint {
        return Ok(vec![]);
    }
    for w in &new_words {
        if !state.accumulated_text.is_empty() {
            state.accumulated_text.push(' ');
        }
        state.accumulated_text.push_str(&w.word);
    }
    state.accumulated_words.extend(new_words);
    let timestamp = now_timestamp();
    let segment = if endpoint {
        // The utterance is complete: hand over the accumulated transcript and start anew.
        state.decoder.consecutive_blanks = 0;
        TranscriptSegment {
            text: std::mem::take(&mut state.accumulated_text),
            words: std::mem::take(&mut state.accumulated_words),
            is_final: true,
            timestamp,
        }
    } else {
        TranscriptSegment {
            text: state.accumulated_text.clone(),
            words: state.accumulated_words.clone(),
            is_final: false,
            timestamp,
        }
    };
    Ok(vec![segment])
}

/// Embeds every complete `diarization::SEGMENT_SAMPLES` segment in the diarization buffer
/// and tags `words` with the most recently assigned speaker id (if any).
///
/// Embedding failures are logged and skipped. Does nothing when no speaker encoder is
/// loaded.
#[cfg(feature = "diarization")]
fn label_speakers(&self, dia: &mut DiarizationStreamState, words: &mut [WordInfo]) {
    let enc = match self.speaker_encoder.as_ref() {
        Some(enc) => enc,
        None => return,
    };
    while dia.audio_buffer.len() >= diarization::SEGMENT_SAMPLES {
        let segment: Vec<f32> =
            dia.audio_buffer.drain(..diarization::SEGMENT_SAMPLES).collect();
        match enc.extract_embedding(&segment) {
            Ok(embedding) => {
                let speaker = dia.cluster.assign(&embedding);
                dia.current_speaker = Some(speaker);
            }
            Err(e) => {
                tracing::warn!("Embedding extraction failed: {e:#}");
            }
        }
    }
    if let Some(speaker_id) = dia.current_speaker {
        for w in words.iter_mut() {
            w.speaker = Some(speaker_id);
        }
    }
}
/// Emits whatever text has accumulated since the last final segment as a final segment.
///
/// Returns `None` when nothing is pending. Only the accumulated text and words are taken;
/// decoder and diarization state are left untouched.
pub fn flush_state(&self, state: &mut StreamingState) -> Option<TranscriptSegment> {
    if state.accumulated_text.is_empty() {
        None
    } else {
        let text = std::mem::take(&mut state.accumulated_text);
        let words = std::mem::take(&mut state.accumulated_words);
        Some(TranscriptSegment {
            text,
            words,
            is_final: true,
            timestamp: now_timestamp(),
        })
    }
}
/// Decodes the audio file at `path` and transcribes it in one pass.
///
/// # Errors
/// [`GigasttError::InvalidAudio`] if the file cannot be decoded;
/// [`GigasttError::Inference`] if model inference fails.
pub fn transcribe_file(&self, path: &str, triplet: &mut SessionTriplet) -> Result<TranscribeResult, GigasttError> {
    let float_samples = audio::decode_audio_file(path)
        .map_err(|e| GigasttError::InvalidAudio(format!("{e:#}")))?;
    self.transcribe_samples(&float_samples, triplet)
}

/// Decodes in-memory encoded audio (`data`) and transcribes it in one pass.
///
/// # Errors
/// [`GigasttError::InvalidAudio`] if the bytes cannot be decoded;
/// [`GigasttError::Inference`] if model inference fails.
pub fn transcribe_bytes(&self, data: &[u8], triplet: &mut SessionTriplet) -> Result<TranscribeResult, GigasttError> {
    let float_samples = audio::decode_audio_bytes(data)
        .map_err(|e| GigasttError::InvalidAudio(format!("{e:#}")))?;
    self.transcribe_samples(&float_samples, triplet)
}

/// Shared offline path: mel features → one inference with a fresh decoder state.
///
/// `duration_s` assumes 16 kHz samples. Audio too short to yield any mel frame returns
/// an empty transcript instead of running the encoder on a zero-length input.
fn transcribe_samples(&self, samples: &[f32], triplet: &mut SessionTriplet) -> Result<TranscribeResult, GigasttError> {
    let duration_s = samples.len() as f64 / 16000.0;
    let (features, num_frames) = self.mel.compute(samples);
    tracing::info!("Extracted {} mel frames", num_frames);
    if num_frames == 0 {
        return Ok(TranscribeResult {
            text: String::new(),
            words: Vec::new(),
            duration_s,
        });
    }
    let mut decoder_state = DecoderState::new(self.tokenizer.blank_id());
    let (words, _endpoint) = self
        .run_inference(triplet, &features, num_frames, &mut decoder_state, 0)
        .map_err(|e| GigasttError::Inference(format!("{e:#}")))?;
    let text: String = words.iter().map(|w| w.word.as_str()).collect::<Vec<_>>().join(" ");
    Ok(TranscribeResult { text, words, duration_s })
}
/// Runs the encoder on `features` and greedily decodes the result into words.
///
/// - `features`: mel features for a `[1, N_MELS, num_frames]` input tensor.
/// - `decoder_state`: updated in place by greedy decoding.
/// - `frame_offset`: mel frames already consumed earlier in the stream; forwarded to
///   word timing.
///
/// Returns the decoded words and whether the decoder detected an endpoint. Zero frames
/// yield no words and no endpoint without touching the model.
///
/// # Errors
/// Fails if tensor creation, encoder inference, output extraction or decoding fails, or
/// if the encoder reports an empty or negative output length.
fn run_inference(
    &self,
    triplet: &mut SessionTriplet,
    features: &[f32],
    num_frames: usize,
    decoder_state: &mut DecoderState,
    frame_offset: usize,
) -> anyhow::Result<(Vec<WordInfo>, bool)> {
    if num_frames == 0 {
        return Ok((Vec::new(), false));
    }
    let signal_tensor =
        TensorRef::from_array_view(([1_usize, N_MELS, num_frames], features))?;
    let length_data = [num_frames as i64];
    let length_tensor =
        TensorRef::from_array_view(([1_usize], length_data.as_slice()))?;
    let encoder_outputs = triplet
        .encoder
        .run(ort::inputs![signal_tensor, length_tensor])
        .context("Encoder inference failed")?;
    let (_enc_shape, enc_data) = encoder_outputs[0]
        .try_extract_tensor::<f32>()
        .context("Failed to extract encoder output")?;
    let (_len_shape, len_data) = encoder_outputs[1]
        .try_extract_tensor::<i32>()
        .context("Failed to extract encoder length")?;
    let raw_len = *len_data
        .first()
        .context("Encoder returned an empty length tensor")?;
    let enc_len = usize::try_from(raw_len).context("Negative encoder length")?;
    tracing::debug!("Encoder output: {} frames", enc_len);
    // Copy the encoder output so the session outputs can be released before decoding.
    let enc_data_owned: Vec<f32> = enc_data.to_vec();
    drop(encoder_outputs);
    let result = decode::greedy_decode(
        &mut triplet.decoder,
        &mut triplet.joiner,
        &enc_data_owned,
        enc_len,
        self.tokenizer.blank_id(),
        decoder_state,
    )?;
    let words = self.tokens_to_words(&result.tokens, frame_offset);
    let preview: String = words.iter().take(10).map(|w| w.word.as_str()).collect::<Vec<_>>().join(" ");
    let ellipsis = if words.len() > 10 { "..." } else { "" };
    tracing::info!("Decoded {} tokens → \"{preview}{ellipsis}\"", result.tokens.len());
    Ok((words, result.endpoint_detected))
}
fn tokens_to_words(&self, tokens: &[decode::TokenInfo], frame_offset: usize) -> Vec<WordInfo> {
if tokens.is_empty() {
return Vec::new();
}
let token_ids: Vec<usize> = tokens.iter().map(|t| t.token_id).collect();
let raw_text = self.tokenizer.decode(&token_ids);
if raw_text.is_empty() {
return Vec::new();
}
let mut words = Vec::new();
let mut current_word = String::new();
let mut word_start_frame: Option<usize> = None;
let mut word_end_frame: usize = 0;
let mut word_confidences: Vec<f32> = Vec::new();
for token in tokens {
let token_text = self.tokenizer.token_text(token.token_id);
let is_word_boundary = token_text.starts_with('\u{2581}');
if is_word_boundary && !current_word.is_empty() {
let avg_conf: f32 = if word_confidences.is_empty() {
1.0
} else {
word_confidences.iter().sum::<f32>() / word_confidences.len() as f32
};
words.push(WordInfo {
word: current_word.clone(),
start: (word_start_frame.unwrap_or(0) + frame_offset) as f64 * SECONDS_PER_FRAME,
end: (word_end_frame + frame_offset) as f64 * SECONDS_PER_FRAME,
confidence: avg_conf,
speaker: None,
});
current_word.clear();
word_confidences.clear();
word_start_frame = None;
}
let clean = token_text.replace('\u{2581}', "");
if !clean.is_empty() {
current_word.push_str(&clean);
if word_start_frame.is_none() {
word_start_frame = Some(token.frame_index);
}
word_end_frame = token.frame_index;
word_confidences.push(token.confidence);
}
}
if !current_word.is_empty() {
let avg_conf: f32 = if word_confidences.is_empty() {
1.0
} else {
word_confidences.iter().sum::<f32>() / word_confidences.len() as f32
};
words.push(WordInfo {
word: current_word,
start: (word_start_frame.unwrap_or(0) + frame_offset) as f64 * SECONDS_PER_FRAME,
end: (word_end_frame + frame_offset) as f64 * SECONDS_PER_FRAME,
confidence: avg_conf,
speaker: None,
});
}
words
}
}
/// Result of a one-shot (non-streaming) transcription.
#[derive(Debug, Clone, Serialize)]
pub struct TranscribeResult {
    /// All recognized words joined by single spaces.
    pub text: String,
    /// Per-word details.
    pub words: Vec<WordInfo>,
    /// Input duration in seconds (sample count divided by 16000).
    pub duration_s: f64,
}
/// A streaming transcript update produced while processing audio.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TranscriptSegment {
    /// Text accumulated since the last final segment.
    pub text: String,
    /// Word details matching `text`.
    pub words: Vec<WordInfo>,
    /// True when an endpoint was detected or the stream was flushed.
    pub is_final: bool,
    /// Wall-clock creation time in seconds since the UNIX epoch.
    pub timestamp: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when every element of `values` is exactly zero.
    fn all_zero(values: &[f32]) -> bool {
        values.iter().all(|&v| v == 0.0)
    }

    #[test]
    fn test_decoder_state_new_zeros() {
        let blank_id = 1024;
        let state = DecoderState::new(blank_id);
        assert!(all_zero(&state.h));
        assert!(all_zero(&state.c));
        assert_eq!(state.prev_token, blank_id as i64);
    }

    #[test]
    fn test_decoder_state_dimensions() {
        let DecoderState { h, c, .. } = DecoderState::new(1024);
        assert_eq!((h.len(), c.len()), (PRED_HIDDEN, PRED_HIDDEN));
    }

    #[test]
    fn test_decoder_state_custom_blank_id() {
        assert_eq!(DecoderState::new(42).prev_token, 42);
    }
}