#![cfg(feature = "onnx")]
use super::engine::{AsrEngine, StreamingState, TranscriptionResult};
use anyhow::{Context, Result};
use ort::session::Session;
use std::path::Path;
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex};
/// ASR engine that runs a Parakeet ONNX model through an `ort` session and
/// decodes the model output with a `tokenizers` tokenizer.
///
/// Constructed with `ParakeetEngine::new` from a model directory.
pub struct ParakeetEngine {
/// ONNX Runtime session; locked for each inference call.
session: Mutex<Session>,
/// Tokenizer loaded from `tokenizer.json`; turns model output ids into text.
tokenizer: tokenizers::Tokenizer,
/// Last path component of the model directory, or `"parakeet"` when that is
/// unavailable; included in the string returned by `engine_name`.
model_name: String,
}
// SAFETY / NOTE(review): these impls promise the compiler that a
// `ParakeetEngine` may be moved to, and shared between, threads. The struct's
// fields are a `Mutex<Session>`, a `tokenizers::Tokenizer`, and a `String`, so
// the promise is only sound if `Session` and `Tokenizer` are themselves safe
// to use from any thread — nothing in this file demonstrates that.
// If both types already implement `Send`/`Sync` in the pinned crate versions,
// these impls are redundant and should be removed so the auto traits are
// derived (and checked) by the compiler — TODO confirm.
unsafe impl Send for ParakeetEngine {}
unsafe impl Sync for ParakeetEngine {}
/// Streaming state handed out by `ParakeetEngine::create_streaming_state`.
///
/// A unit struct: it carries no data between streaming calls.
struct ParakeetStreamingState;
impl StreamingState for ParakeetStreamingState {
    /// Exposes this state as `&mut dyn Any` so a caller holding only a
    /// `dyn StreamingState` can attempt a downcast to the concrete type.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self as &mut dyn std::any::Any
    }
}
impl ParakeetEngine {
pub fn new(model_dir: &Path, quantization: crate::config::AsrQuantization) -> Result<Self> {
use crate::config::AsrQuantization;
let model_file = match quantization {
AsrQuantization::Int8 => "model.int8.onnx",
_ => "model.onnx",
};
let model_path = model_dir.join(model_file);
let tokenizer_path = model_dir.join("tokenizer.json");
let session = Session::builder()?
.with_intra_threads(num_inference_threads())?
.commit_from_file(&model_path)
.context("Failed to load Parakeet ONNX model")?;
let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {e}"))?;
let model_name = model_dir
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("parakeet")
.to_string();
Ok(Self {
session: Mutex::new(session),
tokenizer,
model_name,
})
}
fn run_inference(&self, samples: &[f32]) -> Result<String> {
let num_samples = samples.len();
let audio_tensor = ndarray::Array2::from_shape_vec((1, num_samples), samples.to_vec())?;
let length_tensor = ndarray::Array1::from_vec(vec![num_samples as i64]);
let audio_input = ort::value::Tensor::from_array(audio_tensor)?;
let length_input = ort::value::Tensor::from_array(length_tensor)?;
let mut sess = self.session.lock().unwrap();
let outputs = sess.run(ort::inputs![audio_input, length_input])?;
let token_ids = outputs[0].try_extract_array::<i64>()?;
let ids: Vec<u32> = token_ids.iter().map(|&id| id as u32).collect();
let text = self
.tokenizer
.decode(&ids, true)
.map_err(|e| anyhow::anyhow!("Tokenizer decode failed: {e}"))?;
Ok(text.trim().to_string())
}
}
impl AsrEngine for ParakeetEngine {
    /// Transcribes `samples` in one inference pass.
    ///
    /// `_translate` is ignored. The result is always marked `pre_formatted`.
    fn transcribe(&self, samples: &[f32], _translate: bool) -> Result<TranscriptionResult> {
        self.run_inference(samples).map(|text| TranscriptionResult {
            text,
            pre_formatted: true,
        })
    }

    /// Returns a fresh, data-less streaming state.
    fn create_streaming_state(&self) -> Result<Box<dyn StreamingState>> {
        let state: Box<dyn StreamingState> = Box::new(ParakeetStreamingState);
        Ok(state)
    }

    /// Streaming entry point; delegates to `transcribe` on the given samples.
    ///
    /// Neither `_state` nor `_abort_flag` is consulted.
    fn streaming_transcribe(
        &self,
        _state: &mut dyn StreamingState,
        samples: &[f32],
        translate: bool,
        _abort_flag: &Arc<AtomicBool>,
    ) -> Result<TranscriptionResult> {
        self.transcribe(samples, translate)
    }

    /// Human-readable engine label that includes the model name.
    fn engine_name(&self) -> String {
        format!("Parakeet-TDT ({})", self.model_name)
    }

    /// This engine reports that it does not support translation.
    fn supports_translation(&self) -> bool {
        false
    }
}
/// Picks the intra-op thread count for the ONNX session.
///
/// Uses half of the parallelism reported by the standard library (assuming 4
/// when it cannot be queried), and never returns less than 1.
fn num_inference_threads() -> usize {
    let logical_cpus = match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 4,
    };
    std::cmp::max(logical_cpus / 2, 1)
}