use crate::config::TranscriptionConfig;
use crate::constants::SAMPLES_PER_ENCODER_FRAME;
use crate::decode::{ParakeetTdtDecoder, RawTranscription};
use crate::error::TranscriptionError;
use crate::frontend::ParakeetFeatureExtractor;
use crate::model::ParakeetModel;
use crate::models::ModelBundle;
use crate::types::TranscriptionResult;
use crate::vocab::Vocabulary;
use ndarray::Array2;
/// End-to-end transcription pipeline: validates/featurizes audio, runs the
/// Parakeet model, and decodes raw token output into a [`TranscriptionResult`].
#[derive(Debug, Clone)]
pub struct TranscriptionPipeline {
    // Model artifacts this pipeline was constructed from (see `from_bundle`).
    bundle: ModelBundle,
    // Default configuration used by `run`; `run_with_config` can override per call.
    config: TranscriptionConfig,
    // Converts raw f32 samples into the feature matrix fed to the model.
    extractor: ParakeetFeatureExtractor,
    // Owns the vocabulary; turns raw token/frame output into the final result.
    decoder: ParakeetTdtDecoder,
    // Inference model built from `bundle` + vocabulary + config.
    model: ParakeetModel,
}
impl TranscriptionPipeline {
pub fn from_dir(models_dir: impl Into<std::path::PathBuf>) -> Result<Self, TranscriptionError> {
let bundle = ModelBundle::from_dir(models_dir);
Self::from_bundle(bundle)
}
pub fn from_bundle(bundle: ModelBundle) -> Result<Self, TranscriptionError> {
bundle.validate_base()?;
let config = TranscriptionConfig::default();
let vocab = Vocabulary::from_file(bundle.vocab_path())?;
let model = ParakeetModel::from_bundle(&bundle, &vocab, &config)?;
Ok(Self {
extractor: ParakeetFeatureExtractor::new(&config),
decoder: ParakeetTdtDecoder::new(vocab),
config,
model,
bundle,
})
}
#[cfg(feature = "online")]
pub fn from_pretrained() -> Result<Self, TranscriptionError> {
let bundle = ModelBundle::from_pretrained().map_err(|error| {
TranscriptionError::CoreMl(format!("model download failed: {error}"))
})?;
Self::from_bundle(bundle)
}
pub fn run(&self, audio: &[f32]) -> Result<TranscriptionResult, TranscriptionError> {
self.run_with_config(audio, &self.config)
}
pub fn run_with_config(
&self,
audio: &[f32],
config: &TranscriptionConfig,
) -> Result<TranscriptionResult, TranscriptionError> {
let raw = self.transcribe_raw(audio, 0, 0, config)?;
let duration_seconds = audio.len() as f64 / config.sample_rate as f64;
Ok(self.decoder.decode(&raw, duration_seconds))
}
pub fn config(&self) -> &TranscriptionConfig {
&self.config
}
pub fn bundle(&self) -> &ModelBundle {
&self.bundle
}
#[cfg(feature = "long-form")]
pub(crate) fn decode_raw(
&self,
raw: &RawTranscription,
duration_seconds: f64,
) -> TranscriptionResult {
self.decoder.decode(raw, duration_seconds)
}
pub(crate) fn transcribe_raw(
&self,
audio: &[f32],
global_sample_offset: usize,
context_samples: usize,
config: &TranscriptionConfig,
) -> Result<RawTranscription, TranscriptionError> {
let prepared = self.prepare_chunk(audio, config)?;
self.transcribe_prepared_raw(prepared, global_sample_offset, context_samples)
}
pub(crate) fn transcribe_prepared_raw(
&self,
prepared: PreparedChunk,
global_sample_offset: usize,
context_samples: usize,
) -> Result<RawTranscription, TranscriptionError> {
let mut raw = self.model.transcribe(
&prepared.features,
prepared.feature_frames,
prepared.target_frames,
)?;
apply_time_offsets(
&mut raw,
global_sample_offset / SAMPLES_PER_ENCODER_FRAME,
context_samples / SAMPLES_PER_ENCODER_FRAME,
);
Ok(raw)
}
#[cfg(feature = "long-form")]
pub(crate) fn chunk_preparer(config: &TranscriptionConfig) -> ChunkPreparer {
ChunkPreparer::new(config)
}
fn prepare_chunk(
&self,
audio: &[f32],
config: &TranscriptionConfig,
) -> Result<PreparedChunk, TranscriptionError> {
if audio.is_empty() {
return Err(TranscriptionError::EmptyAudio);
}
if audio.len() > config.max_audio_samples {
return Err(TranscriptionError::AudioTooLong {
max_seconds: config.max_duration_seconds(),
actual_seconds: audio.len() as f64 / config.sample_rate as f64,
});
}
Ok(PreparedChunk::new(
self.extractor.extract(audio)?,
config.max_feature_frames(),
))
}
}
/// An extracted feature matrix plus the frame counts the model call needs,
/// produced by `TranscriptionPipeline::prepare_chunk` / `ChunkPreparer::prepare`.
#[derive(Debug)]
pub(crate) struct PreparedChunk {
    // Feature matrix; rows are frames (`new` reads `shape()[0]` as the frame count).
    features: Array2<f32>,
    // Actual number of frames present in `features` (its row count).
    feature_frames: usize,
    // Taken from `config.max_feature_frames()` — presumably the padded frame
    // length the model expects; confirm against `ParakeetModel::transcribe`.
    target_frames: usize,
}
impl PreparedChunk {
    /// Wraps an extracted feature matrix, recording the actual frame count
    /// (the matrix's row count) alongside the requested `target_frames`.
    fn new(features: Array2<f32>, target_frames: usize) -> Self {
        let frames = features.nrows();
        Self {
            features,
            feature_frames: frames,
            target_frames,
        }
    }
}
/// Audio-chunk validation + feature extraction usable without a full
/// `TranscriptionPipeline`; used by the long-form (chunked) transcription path.
#[cfg(feature = "long-form")]
#[derive(Debug)]
pub(crate) struct ChunkPreparer {
    // Owned copy of the transcription config (sample rate, length limits).
    config: TranscriptionConfig,
    // Feature extractor built from `config`.
    extractor: ParakeetFeatureExtractor,
}
#[cfg(feature = "long-form")]
impl ChunkPreparer {
    /// Captures an owned copy of `config` together with a matching extractor.
    fn new(config: &TranscriptionConfig) -> Self {
        let extractor = ParakeetFeatureExtractor::new(config);
        Self {
            config: config.clone(),
            extractor,
        }
    }

    /// Validates the audio length, then extracts features for one chunk.
    ///
    /// # Errors
    /// * [`TranscriptionError::EmptyAudio`] for zero samples.
    /// * [`TranscriptionError::AudioTooLong`] when the sample count exceeds
    ///   `config.max_audio_samples`.
    pub(crate) fn prepare(&self, audio: &[f32]) -> Result<PreparedChunk, TranscriptionError> {
        let sample_count = audio.len();
        if sample_count == 0 {
            return Err(TranscriptionError::EmptyAudio);
        }
        if sample_count > self.config.max_audio_samples {
            return Err(TranscriptionError::AudioTooLong {
                max_seconds: self.config.max_duration_seconds(),
                actual_seconds: sample_count as f64 / self.config.sample_rate as f64,
            });
        }
        let features = self.extractor.extract(audio)?;
        Ok(PreparedChunk::new(features, self.config.max_feature_frames()))
    }
}
/// Shifts decoded frame indices into the global timeline and drops tokens
/// emitted inside a chunk's leading context region.
///
/// `frame_offset` is the chunk's start expressed in encoder frames;
/// `context_frames` is the number of leading frames belonging to the context
/// overlap. An entry is kept only when `frame_idx >= context_frames`, and a
/// kept index is rebased to `frame_idx - context_frames + frame_offset`.
/// The four parallel vectors in `raw` are filtered consistently.
fn apply_time_offsets(raw: &mut RawTranscription, frame_offset: usize, context_frames: usize) {
    if context_frames == 0 {
        // Fast path: nothing to drop, just shift every index.
        for frame_idx in &mut raw.frame_indices {
            *frame_idx += frame_offset;
        }
        return;
    }
    // Compact all four parallel vectors in place with one pass: a read cursor
    // scans every entry, a write cursor receives the survivors. This avoids
    // building an index list and re-collecting each vector (four extra
    // allocations) while producing exactly the same kept entries in order.
    let mut write = 0;
    for read in 0..raw.frame_indices.len() {
        let frame_idx = raw.frame_indices[read];
        if frame_idx < context_frames {
            // Token emitted inside the context overlap — discard it.
            continue;
        }
        raw.frame_indices[write] = frame_idx - context_frames + frame_offset;
        raw.token_ids[write] = raw.token_ids[read];
        raw.durations[write] = raw.durations[read];
        raw.confidences[write] = raw.confidences[read];
        write += 1;
    }
    raw.token_ids.truncate(write);
    raw.frame_indices.truncate(write);
    raw.durations.truncate(write);
    raw.confidences.truncate(write);
}