//! scriptrs 0.2.0
//!
//! Rust transcription with native CoreML Parakeet v2 inference.
use crate::config::TranscriptionConfig;
use crate::constants::SAMPLES_PER_ENCODER_FRAME;
use crate::decode::{ParakeetTdtDecoder, RawTranscription};
use crate::error::TranscriptionError;
use crate::frontend::ParakeetFeatureExtractor;
use crate::model::ParakeetModel;
use crate::models::ModelBundle;
use crate::types::TranscriptionResult;
use crate::vocab::Vocabulary;
use ndarray::Array2;

/// Single-chunk Parakeet v2 transcription pipeline
///
/// This is the base `scriptrs` entry point for audio that already fits inside a
/// single Parakeet window. It expects mono 16kHz audio samples as `&[f32]`.
///
/// If the input is longer than [`TranscriptionConfig::max_audio_samples`], this
/// pipeline returns [`TranscriptionError::AudioTooLong`] instead of splitting it.
/// Use `LongFormTranscriptionPipeline` with the `long-form` feature if you want
/// `scriptrs` to own chunk planning internally.
#[derive(Debug, Clone)]
pub struct TranscriptionPipeline {
    // Resolved model artifacts (built via `from_dir` / `from_pretrained`).
    bundle: ModelBundle,
    // Default frontend/limit settings; `run` uses this, `run_with_config`
    // lets the caller override it per call.
    config: TranscriptionConfig,
    // Audio feature frontend, constructed from `config`.
    extractor: ParakeetFeatureExtractor,
    // Turns raw model output into a `TranscriptionResult`; owns the vocabulary.
    decoder: ParakeetTdtDecoder,
    // Native inference model (presumably CoreML-backed — see crate docs).
    model: ParakeetModel,
}

impl TranscriptionPipeline {
    /// Build a transcription pipeline from a local model directory
    ///
    /// The directory must contain the Parakeet runtime bundle expected by
    /// [`ModelBundle::from_dir`].
    pub fn from_dir(models_dir: impl Into<std::path::PathBuf>) -> Result<Self, TranscriptionError> {
        let bundle = ModelBundle::from_dir(models_dir);
        Self::from_bundle(bundle)
    }

    /// Build a transcription pipeline from a resolved model bundle
    pub fn from_bundle(bundle: ModelBundle) -> Result<Self, TranscriptionError> {
        bundle.validate_base()?;
        let config = TranscriptionConfig::default();
        let vocab = Vocabulary::from_file(bundle.vocab_path())?;
        let model = ParakeetModel::from_bundle(&bundle, &vocab, &config)?;
        Ok(Self {
            extractor: ParakeetFeatureExtractor::new(&config),
            decoder: ParakeetTdtDecoder::new(vocab),
            config,
            model,
            bundle,
        })
    }

    #[cfg(feature = "online")]
    /// Download models and build a transcription pipeline
    ///
    /// With the default configuration this resolves models from
    /// `avencera/scriptrs-models` on Hugging Face. Set `SCRIPTRS_MODELS_DIR` to
    /// force a local bundle or `SCRIPTRS_MODELS_REPO` to override the repo.
    pub fn from_pretrained() -> Result<Self, TranscriptionError> {
        let bundle = ModelBundle::from_pretrained().map_err(|error| {
            TranscriptionError::CoreMl(format!("model download failed: {error}"))
        })?;
        Self::from_bundle(bundle)
    }

    /// Transcribe a single chunk of audio
    ///
    /// `audio` must be mono 16kHz samples. Empty input returns
    /// [`TranscriptionError::EmptyAudio`]. Oversized input returns
    /// [`TranscriptionError::AudioTooLong`].
    pub fn run(&self, audio: &[f32]) -> Result<TranscriptionResult, TranscriptionError> {
        self.run_with_config(audio, &self.config)
    }

    /// Transcribe a single chunk of audio with an explicit config
    ///
    /// This is mainly useful if you want to reuse the same pipeline with a
    /// tweaked [`TranscriptionConfig`] instead of the default frontend settings.
    pub fn run_with_config(
        &self,
        audio: &[f32],
        config: &TranscriptionConfig,
    ) -> Result<TranscriptionResult, TranscriptionError> {
        let raw = self.transcribe_raw(audio, 0, 0, config)?;
        let duration_seconds = audio.len() as f64 / config.sample_rate as f64;
        Ok(self.decoder.decode(&raw, duration_seconds))
    }

    /// Return the default pipeline config
    pub fn config(&self) -> &TranscriptionConfig {
        &self.config
    }

    /// Return the resolved model bundle
    pub fn bundle(&self) -> &ModelBundle {
        &self.bundle
    }

    #[cfg(feature = "long-form")]
    pub(crate) fn decode_raw(
        &self,
        raw: &RawTranscription,
        duration_seconds: f64,
    ) -> TranscriptionResult {
        self.decoder.decode(raw, duration_seconds)
    }

    pub(crate) fn transcribe_raw(
        &self,
        audio: &[f32],
        global_sample_offset: usize,
        context_samples: usize,
        config: &TranscriptionConfig,
    ) -> Result<RawTranscription, TranscriptionError> {
        let prepared = self.prepare_chunk(audio, config)?;
        self.transcribe_prepared_raw(prepared, global_sample_offset, context_samples)
    }

    pub(crate) fn transcribe_prepared_raw(
        &self,
        prepared: PreparedChunk,
        global_sample_offset: usize,
        context_samples: usize,
    ) -> Result<RawTranscription, TranscriptionError> {
        let mut raw = self.model.transcribe(
            &prepared.features,
            prepared.feature_frames,
            prepared.target_frames,
        )?;
        apply_time_offsets(
            &mut raw,
            global_sample_offset / SAMPLES_PER_ENCODER_FRAME,
            context_samples / SAMPLES_PER_ENCODER_FRAME,
        );
        Ok(raw)
    }

    #[cfg(feature = "long-form")]
    pub(crate) fn chunk_preparer(config: &TranscriptionConfig) -> ChunkPreparer {
        ChunkPreparer::new(config)
    }

    fn prepare_chunk(
        &self,
        audio: &[f32],
        config: &TranscriptionConfig,
    ) -> Result<PreparedChunk, TranscriptionError> {
        if audio.is_empty() {
            return Err(TranscriptionError::EmptyAudio);
        }
        if audio.len() > config.max_audio_samples {
            return Err(TranscriptionError::AudioTooLong {
                max_seconds: config.max_duration_seconds(),
                actual_seconds: audio.len() as f64 / config.sample_rate as f64,
            });
        }

        Ok(PreparedChunk::new(
            self.extractor.extract(audio)?,
            config.max_feature_frames(),
        ))
    }
}

/// Extracted features for one audio chunk, plus the frame bookkeeping
/// model inference needs.
#[derive(Debug)]
pub(crate) struct PreparedChunk {
    // Feature matrix from the frontend; axis 0 is the frame axis
    // (its length becomes `feature_frames`).
    features: Array2<f32>,
    // Number of actual feature frames extracted from the audio.
    feature_frames: usize,
    // Frame count the model expects — `config.max_feature_frames()`;
    // presumably the padded window size (TODO confirm in model code).
    target_frames: usize,
}

impl PreparedChunk {
    /// Wrap an extracted feature matrix together with the model's expected
    /// frame count.
    ///
    /// `feature_frames` is derived from the matrix itself (its row count),
    /// so callers cannot pass an inconsistent value.
    fn new(features: Array2<f32>, target_frames: usize) -> Self {
        // `nrows()` is the idiomatic Array2 accessor for `shape()[0]`.
        let feature_frames = features.nrows();
        Self {
            features,
            feature_frames,
            target_frames,
        }
    }
}

/// Standalone chunk validator/featurizer for the long-form pipeline,
/// detached from any `TranscriptionPipeline` instance.
#[cfg(feature = "long-form")]
#[derive(Debug)]
pub(crate) struct ChunkPreparer {
    // Owned copy of the config used for limits and frame sizing.
    config: TranscriptionConfig,
    // Feature frontend built from the same config.
    extractor: ParakeetFeatureExtractor,
}

#[cfg(feature = "long-form")]
impl ChunkPreparer {
    /// Capture an owned copy of the config and a matching feature extractor.
    fn new(config: &TranscriptionConfig) -> Self {
        let extractor = ParakeetFeatureExtractor::new(config);
        Self {
            config: config.clone(),
            extractor,
        }
    }

    /// Validate a chunk of audio and extract its features.
    ///
    /// Rejects empty input and input longer than `config.max_audio_samples`,
    /// matching the single-chunk pipeline's checks.
    pub(crate) fn prepare(&self, audio: &[f32]) -> Result<PreparedChunk, TranscriptionError> {
        if audio.is_empty() {
            return Err(TranscriptionError::EmptyAudio);
        }
        let sample_count = audio.len();
        if sample_count > self.config.max_audio_samples {
            return Err(TranscriptionError::AudioTooLong {
                max_seconds: self.config.max_duration_seconds(),
                actual_seconds: sample_count as f64 / self.config.sample_rate as f64,
            });
        }
        let features = self.extractor.extract(audio)?;
        Ok(PreparedChunk::new(features, self.config.max_feature_frames()))
    }
}

/// Rebase per-token frame indices onto the global timeline, dropping tokens
/// that fall inside a chunk's context region.
///
/// With `context_frames == 0`, every frame index is simply advanced by
/// `frame_offset`. Otherwise tokens whose frame index is below
/// `context_frames` are removed from all four parallel vectors, and the
/// survivors are shifted by `frame_offset - context_frames`.
fn apply_time_offsets(raw: &mut RawTranscription, frame_offset: usize, context_frames: usize) {
    if context_frames == 0 {
        raw.frame_indices
            .iter_mut()
            .for_each(|frame_idx| *frame_idx += frame_offset);
        return;
    }

    // Rebuild the four parallel vectors in one forward pass.
    let mut token_ids = Vec::new();
    let mut frame_indices = Vec::new();
    let mut durations = Vec::new();
    let mut confidences = Vec::new();
    for (position, &frame_idx) in raw.frame_indices.iter().enumerate() {
        // Tokens emitted inside the context window are discarded; they
        // belong to audio the caller only supplied for context.
        if frame_idx < context_frames {
            continue;
        }
        token_ids.push(raw.token_ids[position]);
        frame_indices.push(frame_idx - context_frames + frame_offset);
        durations.push(raw.durations[position]);
        confidences.push(raw.confidences[position]);
    }
    raw.token_ids = token_ids;
    raw.frame_indices = frame_indices;
    raw.durations = durations;
    raw.confidences = confidences;
}