use crate::config::TranscriptionConfig;
use crate::constants::SAMPLES_PER_ENCODER_FRAME;
use crate::decode::{ParakeetTdtDecoder, RawTranscription};
use crate::error::TranscriptionError;
use crate::frontend::ParakeetFeatureExtractor;
use crate::model::ParakeetModel;
use crate::models::ModelBundle;
use crate::types::TranscriptionResult;
use crate::vocab::Vocabulary;
use ndarray::{Array2, s};
/// End-to-end transcription pipeline: feature extraction, model inference, and
/// TDT decoding, all built from a single [`ModelBundle`].
#[derive(Debug, Clone)]
pub struct TranscriptionPipeline {
    /// Model files this pipeline was built from (validated during construction).
    bundle: ModelBundle,
    /// Configuration captured at construction time; used by `run` and to build `extractor`.
    config: TranscriptionConfig,
    /// Converts raw audio samples into the feature matrix fed to `model`.
    extractor: ParakeetFeatureExtractor,
    /// Turns a `RawTranscription` into the user-facing `TranscriptionResult`.
    decoder: ParakeetTdtDecoder,
    /// Loaded model that produces a `RawTranscription` from padded features.
    model: ParakeetModel,
}
impl TranscriptionPipeline {
    /// Builds a pipeline from a directory that holds the model files.
    ///
    /// The directory is wrapped in a [`ModelBundle`] and handed to
    /// [`TranscriptionPipeline::from_bundle`], so the same validation applies.
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`TranscriptionPipeline::from_bundle`].
    pub fn from_dir(models_dir: impl Into<std::path::PathBuf>) -> Result<Self, TranscriptionError> {
        Self::from_bundle(ModelBundle::from_dir(models_dir))
    }

    /// Builds a pipeline from an already-located [`ModelBundle`].
    ///
    /// Validates the bundle, loads the vocabulary and the model, and then creates the
    /// feature extractor and decoder using [`TranscriptionConfig::default`].
    ///
    /// # Errors
    ///
    /// Propagates failures from bundle validation, vocabulary loading and model loading.
    pub fn from_bundle(bundle: ModelBundle) -> Result<Self, TranscriptionError> {
        bundle.validate_base()?;
        let config = TranscriptionConfig::default();
        let vocabulary = Vocabulary::from_file(bundle.vocab_path())?;
        let model = ParakeetModel::from_bundle(&bundle, &vocabulary)?;

        // The decoder takes ownership of the vocabulary, so it is built last.
        let extractor = ParakeetFeatureExtractor::new(&config);
        let decoder = ParakeetTdtDecoder::new(vocabulary);

        Ok(Self {
            bundle,
            config,
            extractor,
            decoder,
            model,
        })
    }

    /// Downloads the pretrained model bundle and builds a pipeline from it.
    ///
    /// # Errors
    ///
    /// A download failure is reported as `TranscriptionError::CoreMl` with a
    /// `"model download failed: ..."` message. Any other error comes from
    /// [`TranscriptionPipeline::from_bundle`].
    // NOTE(review): a download failure is reported through the `CoreMl` variant;
    // confirm whether a dedicated variant exists or should be added.
    #[cfg(feature = "online")]
    pub fn from_pretrained() -> Result<Self, TranscriptionError> {
        match ModelBundle::from_pretrained() {
            Ok(bundle) => Self::from_bundle(bundle),
            Err(error) => Err(TranscriptionError::CoreMl(format!(
                "model download failed: {error}"
            ))),
        }
    }

    /// Transcribes `audio` using the configuration captured at construction time.
    ///
    /// # Errors
    ///
    /// Same as [`TranscriptionPipeline::run_with_config`].
    pub fn run(&self, audio: &[f32]) -> Result<TranscriptionResult, TranscriptionError> {
        let config = &self.config;
        self.run_with_config(audio, config)
    }

    /// Transcribes `audio` using an explicit `config` for length limits, padding and
    /// duration computation.
    ///
    /// # Errors
    ///
    /// Returns `EmptyAudio` or `AudioTooLong` for out-of-range input. Feature-extraction
    /// and model errors are propagated unchanged.
    // NOTE(review): the feature extractor was created from the construction-time config,
    // so a `config` with different front-end settings only affects the limits/padding here.
    pub fn run_with_config(
        &self,
        audio: &[f32],
        config: &TranscriptionConfig,
    ) -> Result<TranscriptionResult, TranscriptionError> {
        self.transcribe_raw(audio, 0, 0, config).map(|raw| {
            let duration_seconds = audio.len() as f64 / config.sample_rate as f64;
            self.decoder.decode(&raw, duration_seconds)
        })
    }

    /// The configuration this pipeline was constructed with.
    pub fn config(&self) -> &TranscriptionConfig {
        &self.config
    }

    /// The model bundle this pipeline was constructed from.
    pub fn bundle(&self) -> &ModelBundle {
        &self.bundle
    }

    /// Decodes an already-computed raw transcription covering `duration_seconds` of audio.
    #[cfg(feature = "long-form")]
    pub(crate) fn decode_raw(
        &self,
        raw: &RawTranscription,
        duration_seconds: f64,
    ) -> TranscriptionResult {
        self.decoder.decode(raw, duration_seconds)
    }

    /// Runs feature extraction and the model on one chunk of audio and returns token
    /// frames shifted into global encoder-frame time.
    ///
    /// * `global_sample_offset` is converted to an encoder-frame offset and added to
    ///   every kept token frame.
    /// * Tokens inside the first `context_samples` (converted to encoder frames) are
    ///   dropped.
    ///
    /// # Errors
    ///
    /// Returns `EmptyAudio` for an empty slice and `AudioTooLong` when `audio` exceeds
    /// `config.max_audio_samples`. Extractor and model errors are propagated.
    pub(crate) fn transcribe_raw(
        &self,
        audio: &[f32],
        global_sample_offset: usize,
        context_samples: usize,
        config: &TranscriptionConfig,
    ) -> Result<RawTranscription, TranscriptionError> {
        let sample_count = audio.len();
        if sample_count == 0 {
            return Err(TranscriptionError::EmptyAudio);
        }
        if sample_count > config.max_audio_samples {
            let max_seconds = config.max_duration_seconds();
            let actual_seconds = sample_count as f64 / config.sample_rate as f64;
            return Err(TranscriptionError::AudioTooLong {
                max_seconds,
                actual_seconds,
            });
        }

        let features = self.extractor.extract(audio)?;
        // Remember the real frame count before padding so the model ignores the filler.
        let valid_frames = features.shape()[0];
        let model_input = pad_features(features, config.max_feature_frames());
        let mut raw = self.model.transcribe(&model_input, valid_frames)?;

        let frame_offset = global_sample_offset / SAMPLES_PER_ENCODER_FRAME;
        let context_frames = context_samples / SAMPLES_PER_ENCODER_FRAME;
        apply_time_offsets(&mut raw, frame_offset, context_frames);
        Ok(raw)
    }
}
/// Zero-pads `features` along axis 0 (frames) up to `target_frames` rows.
///
/// A matrix that already has at least `target_frames` rows is returned unchanged.
/// It is never truncated. Added rows are filled with `0.0`, and the column count is
/// preserved.
fn pad_features(features: Array2<f32>, target_frames: usize) -> Array2<f32> {
    let (frames, width) = features.dim();
    if frames < target_frames {
        let mut padded = Array2::<f32>::zeros((target_frames, width));
        padded.slice_mut(s![0..frames, ..]).assign(&features);
        padded
    } else {
        features
    }
}
/// Converts chunk-local token frames to global frames and drops tokens that fall
/// inside the leading context window.
///
/// Tokens whose frame index is below `context_frames` are removed. The removal applies
/// to all four parallel per-token vectors (`token_ids`, `frame_indices`, `durations`,
/// `confidences`). Every remaining frame index becomes
/// `frame_idx - context_frames + frame_offset`. With `context_frames == 0`, nothing is
/// dropped and every index is shifted by `frame_offset`.
///
/// In debug builds this panics if the per-token vectors have different lengths, because
/// filtering them would otherwise misalign tokens and their metadata.
fn apply_time_offsets(raw: &mut RawTranscription, frame_offset: usize, context_frames: usize) {
    if context_frames > 0 {
        let token_count = raw.frame_indices.len();
        debug_assert!(
            raw.token_ids.len() == token_count
                && raw.durations.len() == token_count
                && raw.confidences.len() == token_count,
            "RawTranscription per-token vectors must have equal lengths"
        );

        // One mask drives an in-place, order-preserving filter of every parallel vector.
        let keep: Vec<bool> = raw
            .frame_indices
            .iter()
            .map(|&frame_idx| frame_idx >= context_frames)
            .collect();
        retain_where(&mut raw.token_ids, &keep);
        retain_where(&mut raw.durations, &keep);
        retain_where(&mut raw.confidences, &keep);
        retain_where(&mut raw.frame_indices, &keep);
    }

    // Every surviving index is >= context_frames (trivially so when it is 0), so the
    // subtraction cannot underflow.
    for frame_idx in &mut raw.frame_indices {
        *frame_idx = *frame_idx - context_frames + frame_offset;
    }
}

/// Keeps `values[i]` exactly when `keep[i]` is `true`; elements beyond `keep` are dropped.
///
/// Relies on `Vec::retain` visiting elements once each, in their original order.
fn retain_where<T>(values: &mut Vec<T>, keep: &[bool]) {
    let mut flags = keep.iter().copied();
    values.retain(|_| flags.next().unwrap_or(false));
}