mod decode;
mod features;
mod tokenizer;
pub mod audio;
use anyhow::{Context, Result};
use ort::ep;
use ort::session::Session;
use ort::value::TensorRef;
use serde::Serialize;
use std::path::Path;
use std::sync::Mutex;
use features::MelSpectrogram;
use tokenizer::Tokenizer;
/// Number of mel bins per frame; `Engine::run_inference` feeds the encoder a
/// `[1, N_MELS, num_frames]` tensor.
pub const N_MELS: usize = 64;
// NOTE(review): presumably the FFT window size (in samples) used by the mel
// front end in `features` — not referenced in this file, confirm there.
pub const N_FFT: usize = 320;
/// Hop length used to derive `SECONDS_PER_FRAME`.
// NOTE(review): presumably expressed in audio samples — confirm against `features`.
pub const HOP_LENGTH: usize = 160;
/// Length of the `h` and `c` vectors allocated by `DecoderState::new`.
pub const PRED_HIDDEN: usize = 320;
fn ort_err(e: impl std::fmt::Display) -> anyhow::Error {
anyhow::anyhow!("{e}")
}
/// Returns the current wall-clock time as fractional seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0.0` is returned
/// instead of failing.
pub(crate) fn now_timestamp() -> f64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs_f64(),
        // Clock is set before 1970: fall back to a zero duration.
        Err(_) => 0.0,
    }
}
// NOTE(review): `4.0` is presumably the encoder's time-subsampling factor and
// `16000.0` the audio sample rate in Hz — confirm against the model/front end.
/// Seconds represented by one frame index when `Engine::tokens_to_words`
/// converts word frame positions into start/end timestamps.
const SECONDS_PER_FRAME: f64 = (HOP_LENGTH as f64) * 4.0 / 16000.0;
/// Mutable state carried between decoding calls; handed to
/// `decode::greedy_decode` by `Engine::run_inference`.
pub struct DecoderState {
// `h` / `c`: vectors of `PRED_HIDDEN` floats, zero-initialised by `new`.
// NOTE(review): presumably the prediction network's LSTM hidden and cell state — confirm in `decode`.
pub h: Vec<f32>,
pub c: Vec<f32>,
// Starts as the tokenizer's blank id (see `new`).
// NOTE(review): presumably the last emitted token fed back to the decoder — confirm in `decode`.
pub prev_token: i64,
// Reset to 0 by `Engine::process_chunk` after an endpoint is reported.
// NOTE(review): presumably counts trailing blank frames for endpoint detection — confirm in `decode`.
pub consecutive_blanks: usize,
}
impl DecoderState {
pub fn new(blank_id: usize) -> Self {
Self {
h: vec![0.0; PRED_HIDDEN],
c: vec![0.0; PRED_HIDDEN],
prev_token: blank_id as i64,
consecutive_blanks: 0,
}
}
}
/// A single recognised word with timing and confidence, as produced by
/// `Engine::tokens_to_words`.
#[derive(Debug, Clone, Serialize)]
pub struct WordInfo {
// Word text assembled from token pieces with the U+2581 marker removed.
pub word: String,
// Frame index of the word's first token (plus the caller's frame offset) times `SECONDS_PER_FRAME`.
pub start: f64,
// Frame index of the word's last token (plus the caller's frame offset) times `SECONDS_PER_FRAME`.
pub end: f64,
// Arithmetic mean of the confidences of the tokens forming the word.
pub confidence: f32,
}
/// Per-stream state threaded through `Engine::process_chunk` and
/// `Engine::flush_state`; create one with `Engine::create_state`.
pub struct StreamingState {
// Decoder state carried from one chunk to the next.
pub decoder: DecoderState,
// Scratch buffer handed to `audio::prepare_audio_buffer` on every chunk.
// NOTE(review): presumably holds samples not yet consumed — confirm in `audio`.
pub audio_buffer: Vec<f32>,
// Space-joined words emitted since the last final segment.
pub accumulated_text: String,
// Word details matching `accumulated_text`.
pub accumulated_words: Vec<WordInfo>,
// Running sum of the frame counts returned by the mel front end; passed to
// inference as the frame offset for word timestamps.
pub total_frames: usize,
}
/// Speech-recognition engine holding the three ONNX sessions, the tokenizer
/// and the mel feature extractor. Built by `Engine::load`.
pub struct Engine {
// Each session sits behind its own mutex because inference locks it mutably
// while the engine itself is only borrowed as `&self`.
encoder: Mutex<Session>,
decoder: Mutex<Session>,
joiner: Mutex<Session>,
// Provides the blank id, vocabulary size and token-to-text lookups.
tokenizer: Tokenizer,
// Turns raw samples into the feature buffer fed to the encoder.
mel: MelSpectrogram,
}
impl Engine {
pub fn load(model_dir: &str) -> Result<Self> {
let dir = Path::new(model_dir);
anyhow::ensure!(
dir.join("v3_e2e_rnnt_encoder.onnx").exists(),
"v3_e2e_rnnt_encoder.onnx not found in {model_dir}"
);
let encoder_path = if dir.join("v3_e2e_rnnt_encoder_int8.onnx").exists() {
tracing::info!("Using INT8 quantized encoder");
dir.join("v3_e2e_rnnt_encoder_int8.onnx")
} else {
dir.join("v3_e2e_rnnt_encoder.onnx")
};
tracing::info!("Loading ONNX models from {model_dir}...");
let cache_dir = dir.join("coreml_cache");
let coreml_ep = ep::CoreML::default()
.with_compute_units(ep::coreml::ComputeUnits::CPUAndNeuralEngine)
.with_specialization_strategy(ep::coreml::SpecializationStrategy::FastPrediction)
.with_model_cache_dir(cache_dir.to_string_lossy())
.build();
let encoder = Session::builder()
.map_err(ort_err)?
.with_execution_providers([coreml_ep.clone()])
.map_err(ort_err)?
.commit_from_file(&encoder_path)
.map_err(ort_err)?;
let decoder = Session::builder()
.map_err(ort_err)?
.with_execution_providers([coreml_ep.clone()])
.map_err(ort_err)?
.commit_from_file(dir.join("v3_e2e_rnnt_decoder.onnx"))
.map_err(ort_err)?;
let joiner = Session::builder()
.map_err(ort_err)?
.with_execution_providers([coreml_ep])
.map_err(ort_err)?
.commit_from_file(dir.join("v3_e2e_rnnt_joint.onnx"))
.map_err(ort_err)?;
let tokenizer = Tokenizer::load(&dir.join("v3_e2e_rnnt_vocab.txt"))?;
let mel = MelSpectrogram::new();
tracing::info!("Models loaded (vocab_size={})", tokenizer.vocab_size());
Ok(Self {
encoder: Mutex::new(encoder),
decoder: Mutex::new(decoder),
joiner: Mutex::new(joiner),
tokenizer,
mel,
})
}
/// Creates an empty [`StreamingState`] for a new audio stream.
///
/// The decoder state is seeded with the tokenizer's blank id; all buffers
/// and accumulators start empty and the frame counter starts at zero.
pub fn create_state(&self) -> StreamingState {
    let blank = self.tokenizer.blank_id();
    let decoder = DecoderState::new(blank);
    StreamingState {
        decoder,
        total_frames: 0,
        audio_buffer: Vec::new(),
        accumulated_words: Vec::new(),
        accumulated_text: String::new(),
    }
}
/// Feeds one chunk of audio samples through the recogniser and returns the
/// transcript update it produced, if any.
///
/// Returns an empty vector when the chunk is empty, when
/// `audio::prepare_audio_buffer` yields nothing to process yet, when no mel
/// frames could be computed, or when decoding produced neither new words nor
/// an endpoint. Otherwise exactly one [`TranscriptSegment`] is returned with
/// the text and words accumulated so far; it is marked final when an endpoint
/// was detected, in which case the accumulators and the decoder's
/// consecutive-blank counter are reset.
///
/// # Errors
///
/// Propagates any inference error from the encoder or the greedy decoder.
pub fn process_chunk(
    &self,
    samples: &[f32],
    state: &mut StreamingState,
) -> Result<Vec<TranscriptSegment>> {
    if samples.is_empty() {
        return Ok(Vec::new());
    }
    let Some(prepared) = audio::prepare_audio_buffer(samples, &mut state.audio_buffer) else {
        return Ok(Vec::new());
    };
    let (features, num_frames) = self.mel.compute(&prepared[..]);
    if num_frames == 0 {
        return Ok(Vec::new());
    }

    let (new_words, endpoint) =
        self.run_inference(&features, num_frames, &mut state.decoder, state.total_frames)?;
    // NOTE(review): `total_frames` accumulates mel-frame counts, yet it is later
    // added to token frame indices and scaled by `SECONDS_PER_FRAME`, which
    // includes a x4 factor. If token frame indices are encoder frames, the
    // offset is in different units — confirm against `decode::greedy_decode`.
    state.total_frames += num_frames;

    if new_words.is_empty() && !endpoint {
        return Ok(Vec::new());
    }

    for word in &new_words {
        if !state.accumulated_text.is_empty() {
            state.accumulated_text.push(' ');
        }
        state.accumulated_text.push_str(&word.word);
    }
    state.accumulated_words.extend(new_words);

    // A final segment hands the accumulators over (no copy needed, they are
    // reset anyway); a partial segment must leave them in place, so it clones.
    let (text, words) = if endpoint {
        state.decoder.consecutive_blanks = 0;
        (
            std::mem::take(&mut state.accumulated_text),
            std::mem::take(&mut state.accumulated_words),
        )
    } else {
        (
            state.accumulated_text.clone(),
            state.accumulated_words.clone(),
        )
    };

    Ok(vec![TranscriptSegment {
        text,
        words,
        is_final: endpoint,
        timestamp: now_timestamp(),
    }])
}
/// Emits whatever transcript is still pending as a final segment.
///
/// Returns `None` when no text has accumulated. Otherwise the accumulated
/// text and words are moved out of `state` (leaving both empty) into a
/// segment marked final and stamped with the current time. The decoder state,
/// audio buffer and frame counter are left untouched.
pub fn flush_state(&self, state: &mut StreamingState) -> Option<TranscriptSegment> {
    let has_pending_text = !state.accumulated_text.is_empty();
    has_pending_text.then(|| TranscriptSegment {
        text: std::mem::take(&mut state.accumulated_text),
        words: std::mem::take(&mut state.accumulated_words),
        is_final: true,
        timestamp: now_timestamp(),
    })
}
/// Decodes the audio file at `path` and transcribes it in a single pass,
/// returning the recognised words joined by single spaces.
///
/// A file that yields no mel frames (empty or too short) produces an empty
/// string without running the models, matching how `process_chunk` treats a
/// chunk with no frames.
///
/// # Errors
///
/// Returns an error when the file cannot be decoded or when inference fails.
pub fn transcribe_file(&self, path: &str) -> Result<String> {
    let float_samples = audio::decode_audio_file(path)?;
    let (features, num_frames) = self.mel.compute(&float_samples);
    tracing::info!("Extracted {} mel frames", num_frames);

    // Nothing to recognise: do not hand the encoder a zero-length tensor.
    if num_frames == 0 {
        return Ok(String::new());
    }

    // Offline transcription starts from a fresh decoder state at frame 0.
    let mut decoder_state = DecoderState::new(self.tokenizer.blank_id());
    let (words, _endpoint) = self.run_inference(&features, num_frames, &mut decoder_state, 0)?;

    Ok(words
        .iter()
        .map(|w| w.word.as_str())
        .collect::<Vec<_>>()
        .join(" "))
}
/// Runs the encoder over `features` and then greedy decoding, returning the
/// decoded words together with the decoder's endpoint flag.
///
/// `features` is wrapped as a `[1, N_MELS, num_frames]` tensor without copying.
/// `decoder_state` is passed on to `decode::greedy_decode`, and `frame_offset`
/// is added to every token frame index when word timestamps are computed.
///
/// # Errors
///
/// Fails when the input tensors cannot be built, when the encoder run or
/// output extraction fails, when the reported encoder length is negative, or
/// when greedy decoding fails.
fn run_inference(
&self,
features: &[f32],
num_frames: usize,
decoder_state: &mut DecoderState,
frame_offset: usize,
) -> Result<(Vec<WordInfo>, bool)> {
// Encoder inputs: the feature tensor and a one-element length tensor.
let signal_tensor =
TensorRef::from_array_view(([1_usize, N_MELS, num_frames], features))?;
let length_data = [num_frames as i64];
let length_tensor =
TensorRef::from_array_view(([1_usize], length_data.as_slice()))?;
// A poisoned mutex is recovered rather than treated as fatal.
let mut encoder = self.encoder.lock().unwrap_or_else(|e| {
tracing::warn!("Encoder mutex was poisoned, recovering");
e.into_inner()
});
let encoder_outputs = encoder
.run(ort::inputs![signal_tensor, length_tensor])
.context("Encoder inference failed")?;
// Output 0 holds the encoded frames, output 1 the encoded length.
let (_enc_shape, enc_data) = encoder_outputs[0]
.try_extract_tensor::<f32>()
.context("Failed to extract encoder output")?;
let (_len_shape, len_data) = encoder_outputs[1]
.try_extract_tensor::<i32>()
.context("Failed to extract encoder length")?;
let enc_len = usize::try_from(len_data[0]).context("Negative encoder length")?;
tracing::debug!("Encoder output: {} frames", enc_len);
// Copy the encoder output into an owned buffer so the outputs and the
// encoder lock can be released before decoding starts.
let enc_data_owned: Vec<f32> = enc_data.to_vec();
drop(encoder_outputs);
drop(encoder);
// Decoder and joiner are both locked for the whole greedy decode.
let mut decoder = self.decoder.lock().unwrap_or_else(|e| {
tracing::warn!("Decoder mutex was poisoned, recovering");
e.into_inner()
});
let mut joiner = self.joiner.lock().unwrap_or_else(|e| {
tracing::warn!("Joiner mutex was poisoned, recovering");
e.into_inner()
});
let result = decode::greedy_decode(
&mut decoder,
&mut joiner,
&enc_data_owned,
enc_len,
self.tokenizer.blank_id(),
decoder_state,
)?;
let words = self.tokens_to_words(&result.tokens, frame_offset);
// Log at most the first ten words, with an ellipsis when there are more.
let preview: String = words.iter().take(10).map(|w| w.word.as_str()).collect::<Vec<_>>().join(" ");
let ellipsis = if words.len() > 10 { "..." } else { "" };
tracing::info!("Decoded {} tokens → \"{preview}{ellipsis}\"", result.tokens.len());
Ok((words, result.endpoint_detected))
}
/// Groups decoded tokens into words with timestamps and confidences.
///
/// A token whose text starts with U+2581 begins a new word; the marker itself
/// is stripped from the word text. A word's `start`/`end` are the frame
/// indices of its first/last contributing token plus `frame_offset`, scaled
/// by `SECONDS_PER_FRAME`; its confidence is the mean of its tokens'
/// confidences.
///
/// Returns an empty vector when `tokens` is empty or when the tokenizer
/// decodes them to an empty string.
fn tokens_to_words(&self, tokens: &[decode::TokenInfo], frame_offset: usize) -> Vec<WordInfo> {
    // Marker character treated as the start-of-word prefix on a token.
    const WORD_BOUNDARY: char = '\u{2581}';

    // Builds one finished word from the pieces collected so far.
    fn finish_word(
        text: String,
        start_frame: Option<usize>,
        end_frame: usize,
        confidences: &[f32],
        frame_offset: usize,
    ) -> WordInfo {
        let confidence = if confidences.is_empty() {
            1.0
        } else {
            confidences.iter().sum::<f32>() / confidences.len() as f32
        };
        WordInfo {
            word: text,
            start: (start_frame.unwrap_or(0) + frame_offset) as f64 * SECONDS_PER_FRAME,
            end: (end_frame + frame_offset) as f64 * SECONDS_PER_FRAME,
            confidence,
        }
    }

    if tokens.is_empty() {
        return Vec::new();
    }
    // Skip the per-token work when the whole sequence decodes to nothing.
    let token_ids: Vec<usize> = tokens.iter().map(|t| t.token_id).collect();
    if self.tokenizer.decode(&token_ids).is_empty() {
        return Vec::new();
    }

    let mut words = Vec::new();
    let mut current_word = String::new();
    let mut word_start_frame: Option<usize> = None;
    let mut word_end_frame: usize = 0;
    let mut word_confidences: Vec<f32> = Vec::new();

    for token in tokens {
        let token_text = self.tokenizer.token_text(token.token_id);

        // A boundary token closes the word in progress before starting the next.
        if token_text.starts_with(WORD_BOUNDARY) && !current_word.is_empty() {
            words.push(finish_word(
                std::mem::take(&mut current_word),
                word_start_frame.take(),
                word_end_frame,
                &word_confidences,
                frame_offset,
            ));
            word_confidences.clear();
        }

        let clean = token_text.replace(WORD_BOUNDARY, "");
        if !clean.is_empty() {
            current_word.push_str(&clean);
            // The first contributing token fixes the start; every one moves the end.
            word_start_frame = word_start_frame.or(Some(token.frame_index));
            word_end_frame = token.frame_index;
            word_confidences.push(token.confidence);
        }
    }

    // Flush the trailing word, which has no following boundary token.
    if !current_word.is_empty() {
        words.push(finish_word(
            current_word,
            word_start_frame,
            word_end_frame,
            &word_confidences,
            frame_offset,
        ));
    }
    words
}
}
/// One transcript update returned by `Engine::process_chunk` or
/// `Engine::flush_state`.
#[derive(Debug, Clone)]
pub struct TranscriptSegment {
// Space-joined text accumulated for the current utterance.
pub text: String,
// Per-word details matching `text`.
pub words: Vec<WordInfo>,
// `true` when produced by an endpoint or by `flush_state`; `false` for a partial update.
pub is_final: bool,
// Creation time from `now_timestamp()`: seconds since the Unix epoch.
pub timestamp: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh state holds only zeros in `h`/`c` and remembers the blank id.
    #[test]
    fn test_decoder_state_new_zeros() {
        let blank_id: usize = 1024;
        let fresh = DecoderState::new(blank_id);
        assert!(!fresh.h.iter().any(|&value| value != 0.0));
        assert!(!fresh.c.iter().any(|&value| value != 0.0));
        assert_eq!(fresh.prev_token, blank_id as i64);
    }

    /// Both state vectors are sized by `PRED_HIDDEN`.
    #[test]
    fn test_decoder_state_dimensions() {
        let fresh = DecoderState::new(1024);
        assert_eq!(fresh.h.len(), PRED_HIDDEN);
        assert_eq!(fresh.c.len(), PRED_HIDDEN);
    }

    /// The blank id passed to `new` is what ends up in `prev_token`.
    #[test]
    fn test_decoder_state_custom_blank_id() {
        let DecoderState { prev_token, .. } = DecoderState::new(42);
        assert_eq!(prev_token, 42);
    }
}