use ndarray::{Array, Array1, Array2, Array3, ArrayD, ArrayViewD, IxDyn};
use once_cell::sync::Lazy;
use ort::execution_providers::CPUExecutionProvider;
use ort::inputs;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::TensorRef;
use regex::Regex;
use std::fs;
use std::path::Path;
/// Recurrent decoder state carried between decoding steps: the two 3-D tensors
/// that are fed to the decoder's `input_states_1` and `input_states_2` inputs.
pub type DecoderState = (Array3<f32>, Array3<f32>);
/// Factor multiplied into a frame index when converting it to a timestamp.
/// NOTE(review): presumably the encoder's time-subsampling factor — confirm against the model.
const SUBSAMPLING_FACTOR: usize = 8;
/// Per-frame step multiplied into a frame index when converting it to a timestamp.
/// NOTE(review): presumably seconds per feature frame — confirm against the preprocessor.
const WINDOW_SIZE: f32 = 0.01;
/// Maximum number of non-blank tokens emitted on a single encoder frame before
/// decoding is forced to advance to the next frame.
const MAX_TOKENS_PER_STEP: usize = 10;
/// Whitespace clean-up pattern applied to the concatenated token strings.
///
/// It matches a whitespace character at the very start of the text, a whitespace
/// character followed by a non-word-boundary, or (capture group 1) a whitespace
/// character followed by a word boundary. The decoder replaces group-1 matches
/// with a single space and removes the other matches.
///
/// The compilation `Result` is stored as-is so that a pattern error can be
/// handled by the caller instead of panicking.
static DECODE_SPACE_RE: Lazy<Result<Regex, regex::Error>> =
Lazy::new(|| Regex::new(r"\A\s|\s\B|(\s)\b"));
/// Transcription output for one input sequence: the decoded text together with
/// per-token detail.
#[derive(Debug, Clone)]
pub struct TimestampedResult {
/// Decoded text: the token strings joined together and passed through the
/// whitespace clean-up pattern.
pub text: String,
/// Timestamps derived from the encoder frame index at which each token was
/// emitted, scaled by `WINDOW_SIZE * SUBSAMPLING_FACTOR`.
pub timestamps: Vec<f32>,
/// Vocabulary strings of the emitted tokens, in emission order.
pub tokens: Vec<String>,
}
/// Errors returned by [`ParakeetModel`] operations.
#[derive(thiserror::Error, Debug)]
pub enum ParakeetError {
/// Failure reported by the `ort` crate (converted automatically via `From`).
#[error("ORT error")]
Ort(#[from] ort::Error),
/// Wrapped `std::io::Error`. Besides file-system failures, this variant is
/// also built with `ErrorKind::InvalidData` when the vocabulary has no
/// `<blk>` entry or when no transcription result is produced.
#[error("I/O error")]
Io(#[from] std::io::Error),
/// Shape or dimensionality mismatch reported by `ndarray`.
#[error("ndarray shape error")]
Shape(#[from] ndarray::ShapeError),
/// A model session does not declare the named input.
#[error("Model input not found: {0}")]
InputNotFound(String),
/// A session run did not produce the named output.
#[error("Model output not found: {0}")]
OutputNotFound(String),
/// The tensor shape of the named input could not be obtained.
#[error("Failed to get tensor shape for input: {0}")]
TensorShape(String),
}
/// Bundle of the three ONNX Runtime sessions and the vocabulary needed to turn
/// raw audio samples into text.
pub struct ParakeetModel {
/// Session loaded from the `encoder-model` file.
encoder: Session,
/// Session loaded from the `decoder_joint-model` file; run once per decoding step.
decoder_joint: Session,
/// Session loaded from the `nemo128` file; turns waveforms into features.
preprocessor: Session,
/// Token strings indexed by token id (U+2581 already replaced by a space).
vocab: Vec<String>,
/// Token id of the `<blk>` (blank) entry in the vocabulary.
blank_idx: i32,
/// Cached `vocab.len()`.
vocab_size: usize,
}
impl Drop for ParakeetModel {
    /// Emits a debug log line with the vocabulary size when the model is dropped.
    fn drop(&mut self) {
        let token_count = self.vocab.len();
        log::debug!("Dropping ParakeetModel with {} vocab tokens", token_count);
    }
}
impl ParakeetModel {
/// Loads the encoder, decoder/joint and preprocessor sessions plus the
/// vocabulary from `model_dir`.
///
/// `quantized` asks for the int8 variants of the encoder and decoder/joint
/// models (falling back to the regular files when they are absent); the
/// preprocessor is always loaded from its regular file.
///
/// # Errors
/// Returns an error if any session cannot be created or the vocabulary cannot
/// be read or lacks a `<blk>` entry.
pub fn new<P: AsRef<Path>>(model_dir: P, quantized: bool) -> Result<Self, ParakeetError> {
    let dir = model_dir.as_ref();

    let encoder = Self::init_session(dir, "encoder-model", None, quantized)?;
    let decoder_joint = Self::init_session(dir, "decoder_joint-model", None, quantized)?;
    let preprocessor = Self::init_session(dir, "nemo128", None, false)?;

    let (vocab, blank_idx) = Self::load_vocab(dir)?;
    let vocab_size = vocab.len();
    log::info!("Loaded vocabulary with {} tokens, blank_idx={}", vocab_size, blank_idx);

    let model = Self {
        encoder,
        decoder_joint,
        preprocessor,
        vocab,
        blank_idx,
        vocab_size,
    };
    Ok(model)
}
/// Creates an ONNX Runtime session for `model_name` inside `model_dir`.
///
/// With `try_quantized` set, `<model_name>.int8.onnx` is used when that file
/// exists; in every other case `<model_name>.onnx` is loaded. When
/// `intra_threads` is `Some`, the same value is applied to both the intra-op
/// and inter-op thread settings. Every declared model input is logged at info
/// level after loading.
///
/// # Errors
/// Propagates any error raised while configuring the builder or loading the
/// model file.
fn init_session<P: AsRef<Path>>(
    model_dir: P,
    model_name: &str,
    intra_threads: Option<usize>,
    try_quantized: bool,
) -> Result<Session, ParakeetError> {
    let dir = model_dir.as_ref();
    let providers = vec![CPUExecutionProvider::default().build()];

    // Decide which file to load, logging the choice.
    let regular_name = format!("{}.onnx", model_name);
    let model_filename = if !try_quantized {
        log::info!("Loading model from {}...", regular_name);
        regular_name
    } else {
        let quantized_name = format!("{}.int8.onnx", model_name);
        if dir.join(&quantized_name).exists() {
            log::info!("Loading quantized model from {}...", quantized_name);
            quantized_name
        } else {
            log::info!(
                "Quantized model not found, loading regular model from {}...",
                regular_name
            );
            regular_name
        }
    };

    let mut builder = Session::builder()?
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .with_execution_providers(providers)?
        .with_parallel_execution(true)?;
    if let Some(threads) = intra_threads {
        builder = builder
            .with_intra_threads(threads)?
            .with_inter_threads(threads)?;
    }

    let model_path = dir.join(&model_filename);
    let session = builder.commit_from_file(model_path)?;

    session.inputs.iter().for_each(|input| {
        log::info!(
            "Model '{}' input: name={}, type={:?}",
            model_filename,
            input.name,
            input.input_type
        );
    });
    Ok(session)
}
fn load_vocab<P: AsRef<Path>>(model_dir: P) -> Result<(Vec<String>, i32), ParakeetError> {
let vocab_path = model_dir.as_ref().join("vocab.txt");
let content = fs::read_to_string(vocab_path)?;
let mut max_id = 0;
let mut tokens_with_ids: Vec<(String, usize)> = Vec::new();
let mut blank_idx: Option<usize> = None;
for line in content.lines() {
let parts: Vec<&str> = line.trim_end().split(' ').collect();
if parts.len() >= 2 {
let token = parts[0].to_string();
if let Ok(id) = parts[1].parse::<usize>() {
if token == "<blk>" {
blank_idx = Some(id);
}
tokens_with_ids.push((token, id));
max_id = max_id.max(id);
}
}
}
let mut vocab = vec![String::new(); max_id + 1];
for (token, id) in tokens_with_ids {
vocab[id] = token.replace('\u{2581}', " ");
}
let blank_idx = blank_idx.ok_or_else(|| {
ParakeetError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Missing <blk> token in vocabulary",
))
})? as i32;
Ok((vocab, blank_idx))
}
pub fn preprocess(
&mut self,
waveforms: &ArrayViewD<f32>,
waveforms_lens: &ArrayViewD<i64>,
) -> Result<(ArrayD<f32>, ArrayD<i64>), ParakeetError> {
log::trace!("Running preprocessor inference...");
let inputs = inputs![
"waveforms" => TensorRef::from_array_view(waveforms.view())?,
"waveforms_lens" => TensorRef::from_array_view(waveforms_lens.view())?,
];
let outputs = self.preprocessor.run(inputs)?;
let features = outputs
.get("features")
.ok_or_else(|| ParakeetError::OutputNotFound("features".to_string()))?
.try_extract_array()?;
let features_lens = outputs
.get("features_lens")
.ok_or_else(|| ParakeetError::OutputNotFound("features_lens".to_string()))?
.try_extract_array()?;
Ok((features.to_owned(), features_lens.to_owned()))
}
pub fn encode(
&mut self,
audio_signal: &ArrayViewD<f32>,
length: &ArrayViewD<i64>,
) -> Result<(ArrayD<f32>, ArrayD<i64>), ParakeetError> {
log::trace!("Running encoder inference...");
let inputs = inputs![
"audio_signal" => TensorRef::from_array_view(audio_signal.view())?,
"length" => TensorRef::from_array_view(length.view())?,
];
let outputs = self.encoder.run(inputs)?;
let encoder_output = outputs
.get("outputs")
.ok_or_else(|| ParakeetError::OutputNotFound("outputs".to_string()))?
.try_extract_array()?;
let encoded_lengths = outputs
.get("encoded_lengths")
.ok_or_else(|| ParakeetError::OutputNotFound("encoded_lengths".to_string()))?
.try_extract_array()?;
let encoder_output = encoder_output.permuted_axes(IxDyn(&[0, 2, 1]));
Ok((encoder_output.to_owned(), encoded_lengths.to_owned()))
}
/// Builds a zero-filled initial [`DecoderState`].
///
/// The extents of axes 0 and 2 of each state tensor are read from the declared
/// shapes of the decoder's `input_states_1` and `input_states_2` inputs; the
/// middle axis is fixed to 1.
///
/// # Errors
/// * `InputNotFound` if the decoder does not declare one of the state inputs.
/// * `TensorShape` if the input exposes no tensor shape, declares fewer than
///   three axes, or declares a negative (dynamic) extent on axis 0 or 2.
pub fn create_decoder_state(&self) -> Result<DecoderState, ParakeetError> {
    // Looks up one state input and returns the extents of its axes 0 and 2.
    let state_extents = |name: &str| -> Result<(usize, usize), ParakeetError> {
        let shape = self
            .decoder_joint
            .inputs
            .iter()
            .find(|input| input.name == name)
            .ok_or_else(|| ParakeetError::InputNotFound(name.to_string()))?
            .input_type
            .tensor_shape()
            .ok_or_else(|| ParakeetError::TensorShape(name.to_string()))?;

        // A missing axis or a negative extent cannot be turned into an
        // allocation size, so report it instead of panicking or wrapping.
        let extent = |axis: usize| match shape.get(axis) {
            Some(&dim) if dim >= 0 => Ok(dim as usize),
            _ => Err(ParakeetError::TensorShape(name.to_string())),
        };
        Ok((extent(0)?, extent(2)?))
    };

    let (outer_1, inner_1) = state_extents("input_states_1")?;
    let (outer_2, inner_2) = state_extents("input_states_2")?;

    let state1 = Array::zeros((outer_1, 1, inner_1));
    let state2 = Array::zeros((outer_2, 1, inner_2));
    Ok((state1, state2))
}
/// Runs one decoder/joint inference step.
///
/// The last entry of `prev_tokens` (or the blank id when the slice is empty)
/// is fed as the target token, `prev_state` supplies the two state inputs, and
/// `encoder_out` is copied and given two extra unit axes (at positions 0 and
/// 2) before being passed as `encoder_outputs`.
///
/// Returns the `outputs` tensor with its leading axis removed, together with
/// the new decoder state taken from `output_states_1` / `output_states_2`.
///
/// # Errors
/// Returns an ORT error if tensor creation, inference or extraction fails,
/// `OutputNotFound` if an expected output is missing, or a shape error if a
/// returned state is not three-dimensional.
pub fn decode_step(
    &mut self,
    prev_tokens: &[i32],
    prev_state: &DecoderState,
    encoder_out: &ArrayViewD<f32>,
) -> Result<(ArrayD<f32>, DecoderState), ParakeetError> {
    log::trace!("Running decoder inference...");

    let last_token = match prev_tokens.last() {
        Some(&token) => token,
        None => self.blank_idx,
    };
    let (state_in_1, state_in_2) = prev_state;

    let frame = encoder_out
        .to_owned()
        .insert_axis(ndarray::Axis(0))
        .insert_axis(ndarray::Axis(2));
    let targets = Array2::from_shape_vec((1, 1), vec![last_token])?;
    let target_length = Array1::from_vec(vec![1]);

    let model_inputs = inputs![
        "encoder_outputs" => TensorRef::from_array_view(frame.view())?,
        "targets" => TensorRef::from_array_view(targets.view())?,
        "target_length" => TensorRef::from_array_view(target_length.view())?,
        "input_states_1" => TensorRef::from_array_view(state_in_1.view())?,
        "input_states_2" => TensorRef::from_array_view(state_in_2.view())?,
    ];
    let outputs = self.decoder_joint.run(model_inputs)?;

    let raw_logits = outputs
        .get("outputs")
        .ok_or_else(|| ParakeetError::OutputNotFound("outputs".to_string()))?
        .try_extract_array()?;
    log::trace!(
        "Logits shape: {:?}, vocab_size: {}",
        raw_logits.shape(),
        self.vocab_size
    );

    let state_out_1 = outputs
        .get("output_states_1")
        .ok_or_else(|| ParakeetError::OutputNotFound("output_states_1".to_string()))?
        .try_extract_array()?;
    let state_out_2 = outputs
        .get("output_states_2")
        .ok_or_else(|| ParakeetError::OutputNotFound("output_states_2".to_string()))?
        .try_extract_array()?;

    // Drop the leading axis of the logits and fix the states to three dimensions.
    let step_logits = raw_logits.remove_axis(ndarray::Axis(0));
    let next_state_1 = state_out_1.to_owned().into_dimensionality::<ndarray::Ix3>()?;
    let next_state_2 = state_out_2.to_owned().into_dimensionality::<ndarray::Ix3>()?;

    Ok((step_logits.to_owned(), (next_state_1, next_state_2)))
}
/// Transcribes a batch of waveforms.
///
/// Runs the preprocessor and the encoder on the whole batch, then decodes each
/// item along the outermost axis of the encoder output separately, using the
/// matching entry of the encoded lengths as its frame count.
///
/// # Errors
/// Propagates the first error raised by preprocessing, encoding or decoding.
pub fn recognize_batch(
    &mut self,
    waveforms: &ArrayViewD<f32>,
    waveforms_len: &ArrayViewD<i64>,
) -> Result<Vec<TimestampedResult>, ParakeetError> {
    let (features, features_lens) = self.preprocess(waveforms, waveforms_len)?;
    let (encoder_out, encoder_out_lens) =
        self.encode(&features.view(), &features_lens.view())?;

    encoder_out
        .outer_iter()
        .zip(encoder_out_lens.iter())
        .map(
            |(frames, &frame_count)| -> Result<TimestampedResult, ParakeetError> {
                let (ids, frame_indices) =
                    self.decode_sequence(&frames.view(), frame_count as usize)?;
                Ok(self.decode_tokens(ids, frame_indices))
            },
        )
        .collect()
}
/// Greedily decodes one encoded sequence into token ids and the frame index at
/// which each token was emitted.
///
/// `encodings` is indexed as `[frame, ..]`; `encodings_len` is the number of
/// leading frames to decode. If it exceeds the number of frames actually
/// present it is clamped (with a warning) so that an inconsistent length
/// cannot cause an out-of-bounds panic.
///
/// For every frame the decoder is run repeatedly: a non-blank token is
/// recorded and the decoder state advanced; a blank token, or reaching
/// `MAX_TOKENS_PER_STEP` tokens on the same frame, moves on to the next frame.
///
/// # Errors
/// Propagates errors from state creation and decoder inference, and returns a
/// shape error if the logits are not contiguous in memory.
fn decode_sequence(
    &mut self,
    encodings: &ArrayViewD<f32>,
    encodings_len: usize,
) -> Result<(Vec<i32>, Vec<usize>), ParakeetError> {
    // Never step past the frames that are really there.
    let available_frames = encodings.shape().first().copied().unwrap_or(0);
    let frame_count = encodings_len.min(available_frames);
    if frame_count < encodings_len {
        log::warn!(
            "Requested {} encoder frames but only {} are available; clamping",
            encodings_len,
            available_frames
        );
    }

    let mut prev_state = self.create_decoder_state()?;
    let mut tokens: Vec<i32> = Vec::new();
    let mut timestamps: Vec<usize> = Vec::new();
    let mut t = 0;
    let mut emitted_tokens = 0;

    while t < frame_count {
        let encoder_step = encodings.slice(ndarray::s![t, ..]).to_owned().into_dyn();
        let (probs, new_state) =
            self.decode_step(&tokens, &prev_state, &encoder_step.view())?;

        let vocab_logits_slice = probs.as_slice().ok_or_else(|| {
            ParakeetError::Shape(ndarray::ShapeError::from_kind(
                ndarray::ErrorKind::IncompatibleShape,
            ))
        })?;

        // Only the first `vocab_size` logits take part in token selection.
        // NOTE(review): any trailing logits (durations, per the log message)
        // are discarded, so every frame is visited one by one.
        let vocab_logits = if probs.len() > self.vocab_size {
            log::trace!(
                "TDT model detected: splitting {} logits into vocab({}) + duration",
                probs.len(),
                self.vocab_size
            );
            &vocab_logits_slice[..self.vocab_size]
        } else {
            vocab_logits_slice
        };

        // Arg-max over the vocabulary logits; incomparable values (NaN) are
        // treated as equal, and an empty slice falls back to the blank token.
        let token = vocab_logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx as i32)
            .unwrap_or(self.blank_idx);

        if token != self.blank_idx {
            // The decoder state only advances when a real token is emitted.
            prev_state = new_state;
            tokens.push(token);
            timestamps.push(t);
            emitted_tokens += 1;
        }
        if token == self.blank_idx || emitted_tokens == MAX_TOKENS_PER_STEP {
            t += 1;
            emitted_tokens = 0;
        }
    }

    Ok((tokens, timestamps))
}
fn decode_tokens(&self, ids: Vec<i32>, timestamps: Vec<usize>) -> TimestampedResult {
let tokens: Vec<String> = ids
.iter()
.filter_map(|&id| {
let idx = id as usize;
if idx < self.vocab.len() {
Some(self.vocab[idx].clone())
} else {
None
}
})
.collect();
let text = match &*DECODE_SPACE_RE {
Ok(regex) => regex
.replace_all(&tokens.join(""), |caps: ®ex::Captures| {
if caps.get(1).is_some() {
" "
} else {
""
}
})
.to_string(),
Err(_) => tokens.join(""), };
let float_timestamps: Vec<f32> = timestamps
.iter()
.map(|&t| WINDOW_SIZE * SUBSAMPLING_FACTOR as f32 * t as f32)
.collect();
TimestampedResult {
text,
timestamps: float_timestamps,
tokens,
}
}
pub fn transcribe_samples(
&mut self,
samples: Vec<f32>,
) -> Result<TimestampedResult, ParakeetError> {
let batch_size = 1;
let samples_len = samples.len();
let waveforms = Array2::from_shape_vec((batch_size, samples_len), samples)?.into_dyn();
let waveforms_lens = Array1::from_vec(vec![samples_len as i64]).into_dyn();
let results = self.recognize_batch(&waveforms.view(), &waveforms_lens.view())?;
let timestamped_result = results.into_iter().next().ok_or_else(|| {
ParakeetError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"No transcription result returned",
))
})?;
Ok(timestamped_result)
}
}