scriptrs 0.2.0

Rust transcription with native CoreML Parakeet v2 inference
Documentation
use std::cell::RefCell;

use ndarray::{Array1, Array2, Array3};

use crate::config::TranscriptionConfig;
use crate::constants::{
    DECODER_HIDDEN_SIZE, DECODER_LAYERS, ENCODER_HIDDEN_SIZE, MAX_TOKENS_PER_STEP,
};
use crate::decode::RawTranscription;
use crate::error::TranscriptionError;
use crate::models::ModelBundle;
use crate::vocab::Vocabulary;

/// Crate-internal handle that runs the encoder and a greedy decode loop over
/// a platform-specific model backend.
#[derive(Debug, Clone)]
pub(crate) struct ParakeetModel {
    /// Platform-specific backend that executes the encoder, decoder and joint steps.
    inner: ParakeetModelInner,
    /// Id of the blank token, taken from the vocabulary at construction time.
    blank_id: usize,
    /// Reusable encoder input buffer; wrapped in `RefCell` so it can be
    /// refilled through `&self` when the encoder is run.
    encoder_input: RefCell<EncoderInputBuffer>,
}

impl ParakeetModel {
    pub(crate) fn from_bundle(
        bundle: &ModelBundle,
        vocab: &Vocabulary,
        config: &TranscriptionConfig,
    ) -> Result<Self, TranscriptionError> {
        Ok(Self {
            inner: ParakeetModelInner::from_bundle(bundle)?,
            blank_id: vocab.blank_id(),
            encoder_input: RefCell::new(EncoderInputBuffer::new(
                config.feature_size,
                config.max_feature_frames(),
            )),
        })
    }

    pub(crate) fn transcribe(
        &self,
        features: &Array2<f32>,
        feature_frames: usize,
        target_frames: usize,
    ) -> Result<RawTranscription, TranscriptionError> {
        let (encoder_output, time_steps) =
            self.run_encoder(features, feature_frames, target_frames)?;
        self.greedy_decode(&encoder_output, time_steps)
    }

    fn run_encoder(
        &self,
        features: &Array2<f32>,
        feature_frames: usize,
        target_frames: usize,
    ) -> Result<(Array3<f32>, usize), TranscriptionError> {
        let mut encoder_input = self.encoder_input.borrow_mut();
        encoder_input.copy_from_features(features, feature_frames, target_frames)?;
        self.inner.run_encoder(
            encoder_input.values(),
            encoder_input.feature_size(),
            encoder_input.target_frames(),
            feature_frames,
        )
    }

    fn greedy_decode(
        &self,
        encoder_output: &Array3<f32>,
        time_steps: usize,
    ) -> Result<RawTranscription, TranscriptionError> {
        let mut state = GreedyDecodeState::new(self.blank_id);

        while state.frame_idx < time_steps {
            state.ensure_decoder_step(&self.inner)?;
            state.copy_encoder_frame(encoder_output)?;
            let cached_decoder = state.cached_decoder()?;
            let decision = self
                .inner
                .run_joint(state.encoder_step(), &cached_decoder.decoder_step)?;

            if decision.token_id != self.blank_id {
                let cached_decoder = state.take_cached_decoder()?;
                state.record_emission(
                    decision.token_id,
                    decision.duration_step,
                    decision.confidence,
                    cached_decoder.hidden_state,
                    cached_decoder.cell_state,
                );
            }

            state.advance(decision.token_id, decision.duration_step, self.blank_id);
        }

        Ok(state.into_raw())
    }
}

/// Mutable working state of one greedy decoding pass.
#[derive(Debug, Clone)]
struct GreedyDecodeState {
    /// Decoder hidden state fed to the next decoder run; replaced on each emission.
    hidden_state: Array3<f32>,
    /// Decoder cell state fed to the next decoder run; replaced on each emission.
    cell_state: Array3<f32>,
    /// `(1, 1)` decoder input; its single entry is set to `last_token` before each decoder run.
    decoder_targets: Array2<i32>,
    /// Single-element array holding `1`, passed alongside `decoder_targets`.
    decoder_target_length: Array1<i32>,
    /// `(1, ENCODER_HIDDEN_SIZE, 1)` scratch array holding the current encoder frame.
    encoder_step: Array3<f32>,
    /// Decoder output kept until a non-blank token consumes it.
    cached_decoder: Option<CachedDecoderStep>,
    /// Accumulated tokens, frame indices, durations and confidences.
    raw: RawTranscription,
    /// Index of the encoder frame currently being decoded.
    frame_idx: usize,
    /// Number of tokens emitted since `frame_idx` last advanced.
    emitted_tokens: usize,
    /// Most recently emitted token id; starts as the blank id.
    last_token: i32,
}

impl GreedyDecodeState {
    /// Creates the initial state: zeroed decoder state, an empty transcription,
    /// frame index 0, and `blank_id` as the "previous" token.
    fn new(blank_id: usize) -> Self {
        Self {
            hidden_state: Array3::<f32>::zeros((DECODER_LAYERS, 1, DECODER_HIDDEN_SIZE)),
            cell_state: Array3::<f32>::zeros((DECODER_LAYERS, 1, DECODER_HIDDEN_SIZE)),
            decoder_targets: Array2::<i32>::zeros((1, 1)),
            decoder_target_length: Array1::from_elem(1, 1i32),
            encoder_step: Array3::<f32>::zeros((1, ENCODER_HIDDEN_SIZE, 1)),
            cached_decoder: None,
            raw: RawTranscription::empty(),
            frame_idx: 0,
            emitted_tokens: 0,
            last_token: blank_id as i32,
        }
    }

    /// Copies encoder frame `frame_idx` (all `ENCODER_HIDDEN_SIZE` channels of
    /// batch 0) into the `encoder_step` scratch array.
    ///
    /// # Errors
    ///
    /// Returns `InvalidModelOutput` when `frame_idx` lies outside the time
    /// axis, or when the encoder output has an empty batch axis or fewer than
    /// `ENCODER_HIDDEN_SIZE` channels. Without the second check the indexing
    /// below would panic on a malformed output instead of reporting it.
    fn copy_encoder_frame(
        &mut self,
        encoder_output: &Array3<f32>,
    ) -> Result<(), TranscriptionError> {
        let (batch, hidden, time_steps) = encoder_output.dim();

        if self.frame_idx >= time_steps {
            return Err(TranscriptionError::InvalidModelOutput(format!(
                "encoder frame index {} exceeded time steps {}",
                self.frame_idx, time_steps
            )));
        }
        if batch == 0 || hidden < ENCODER_HIDDEN_SIZE {
            return Err(TranscriptionError::InvalidModelOutput(format!(
                "encoder output shape {:?} is smaller than the expected [1, {}, _]",
                encoder_output.shape(),
                ENCODER_HIDDEN_SIZE
            )));
        }

        for hidden_idx in 0..ENCODER_HIDDEN_SIZE {
            self.encoder_step[[0, hidden_idx, 0]] = encoder_output[[0, hidden_idx, self.frame_idx]];
        }

        Ok(())
    }

    /// Runs the decoder on `last_token` and caches its output, unless a cached
    /// step is already available.
    ///
    /// # Errors
    ///
    /// Propagates any error from `model.run_decoder`.
    fn ensure_decoder_step(
        &mut self,
        model: &ParakeetModelInner,
    ) -> Result<(), TranscriptionError> {
        if self.cached_decoder.is_some() {
            return Ok(());
        }

        self.decoder_targets[[0, 0]] = self.last_token;
        self.cached_decoder = Some(model.run_decoder(
            &self.decoder_targets,
            &self.decoder_target_length,
            &self.hidden_state,
            &self.cell_state,
        )?);
        Ok(())
    }

    /// Borrows the cached decoder step.
    ///
    /// # Errors
    ///
    /// Returns `InvalidModelOutput` if no decoder step has been cached.
    fn cached_decoder(&self) -> Result<&CachedDecoderStep, TranscriptionError> {
        self.cached_decoder.as_ref().ok_or_else(|| {
            TranscriptionError::InvalidModelOutput("decoder cache was not primed".to_owned())
        })
    }

    /// Moves the cached decoder step out, leaving the cache empty so the next
    /// `ensure_decoder_step` call re-runs the decoder.
    ///
    /// # Errors
    ///
    /// Returns `InvalidModelOutput` if no decoder step has been cached.
    fn take_cached_decoder(&mut self) -> Result<CachedDecoderStep, TranscriptionError> {
        self.cached_decoder.take().ok_or_else(|| {
            TranscriptionError::InvalidModelOutput("decoder cache was not primed".to_owned())
        })
    }

    /// Returns the scratch array holding the current encoder frame.
    fn encoder_step(&self) -> &Array3<f32> {
        &self.encoder_step
    }

    /// Records an emitted token at the current frame and adopts the decoder
    /// state that produced it.
    ///
    /// NOTE(review): `token_id` is narrowed with `as` to `u32` / `i32`; this
    /// assumes vocabulary ids fit in those types.
    fn record_emission(
        &mut self,
        token_id: usize,
        duration_step: usize,
        confidence: f32,
        hidden_state: Array3<f32>,
        cell_state: Array3<f32>,
    ) {
        self.hidden_state = hidden_state;
        self.cell_state = cell_state;
        self.raw.token_ids.push(token_id as u32);
        self.raw.frame_indices.push(self.frame_idx);
        self.raw.durations.push(duration_step);
        self.raw.confidences.push(confidence);
        self.last_token = token_id as i32;
        self.emitted_tokens += 1;
    }

    /// Moves the frame cursor after a joint decision.
    ///
    /// * A positive `duration_step` skips that many frames.
    /// * Otherwise the cursor moves one frame when the token was blank or when
    ///   `MAX_TOKENS_PER_STEP` tokens were already emitted on this frame.
    /// * Otherwise the cursor stays on the same frame.
    ///
    /// Whenever the cursor moves, the per-frame emission counter is reset.
    fn advance(&mut self, token_id: usize, duration_step: usize, blank_id: usize) {
        if duration_step > 0 {
            self.frame_idx += duration_step;
            self.emitted_tokens = 0;
            return;
        }

        if token_id == blank_id || self.emitted_tokens >= MAX_TOKENS_PER_STEP {
            self.frame_idx += 1;
            self.emitted_tokens = 0;
        }
    }

    /// Consumes the state and returns the accumulated transcription.
    fn into_raw(self) -> RawTranscription {
        self.raw
    }
}

/// Output of a single decoder run, kept until a non-blank token consumes it.
#[derive(Debug, Clone)]
struct CachedDecoderStep {
    /// Decoder output passed to the joint step.
    decoder_step: Array3<f32>,
    /// Hidden state produced by the decoder run; adopted when a token is emitted.
    hidden_state: Array3<f32>,
    /// Cell state produced by the decoder run; adopted when a token is emitted.
    cell_state: Array3<f32>,
}

/// Reusable flat buffer holding the encoder input.
#[derive(Debug, Clone)]
struct EncoderInputBuffer {
    /// Flat storage; the value for `(feature_idx, frame_idx)` lives at
    /// `feature_idx * target_frames + frame_idx`.
    values: Vec<f32>,
    /// Number of features per frame.
    feature_size: usize,
    /// Number of frame slots reserved per feature.
    target_frames: usize,
    /// Frame count written by the previous copy; used to zero the stale tail
    /// when a shorter input follows.
    last_feature_frames: usize,
}

impl EncoderInputBuffer {
    /// Allocates a zero-filled buffer with `feature_size * target_frames` slots.
    fn new(feature_size: usize, target_frames: usize) -> Self {
        Self {
            values: vec![0.0; feature_size * target_frames],
            feature_size,
            target_frames,
            last_feature_frames: 0,
        }
    }

    /// Copies the first `feature_frames` rows of `features` (indexed as
    /// `[frame, feature]`) into the buffer, transposed so each feature
    /// occupies a contiguous run of `target_frames` slots. Slots past
    /// `feature_frames` are left at zero.
    ///
    /// All validation happens before the buffer is touched, so a rejected
    /// call leaves the previous contents intact.
    ///
    /// # Errors
    ///
    /// Returns `InvalidModelOutput` when
    /// * the feature dimension of `features` differs from the buffer's,
    /// * `feature_frames` exceeds `target_frames`, or
    /// * `feature_frames` exceeds the number of rows actually present in
    ///   `features` (which would otherwise panic while indexing).
    fn copy_from_features(
        &mut self,
        features: &Array2<f32>,
        feature_frames: usize,
        target_frames: usize,
    ) -> Result<(), TranscriptionError> {
        let available_frames = features.shape()[0];
        let incoming_feature_size = features.shape()[1];

        if incoming_feature_size != self.feature_size {
            return Err(TranscriptionError::InvalidModelOutput(format!(
                "feature size {} did not match encoder input {}",
                incoming_feature_size, self.feature_size
            )));
        }
        if feature_frames > target_frames {
            return Err(TranscriptionError::InvalidModelOutput(format!(
                "feature frame count {feature_frames} exceeded encoder target {target_frames}"
            )));
        }
        if feature_frames > available_frames {
            return Err(TranscriptionError::InvalidModelOutput(format!(
                "feature frame count {feature_frames} exceeded available feature rows {available_frames}"
            )));
        }

        if self.target_frames != target_frames {
            // The per-feature stride changes with `target_frames`, so values
            // written under the old stride would land in the wrong slots and
            // leak into the padding. Start again from an all-zero buffer.
            self.values.clear();
            self.values.resize(self.feature_size * target_frames, 0.0);
            self.target_frames = target_frames;
            self.last_feature_frames = 0;
        }

        let stale_end = self.last_feature_frames;
        for feature_idx in 0..self.feature_size {
            let base = feature_idx * self.target_frames;
            let row = &mut self.values[base..base + self.target_frames];

            for (frame_idx, slot) in row[..feature_frames].iter_mut().enumerate() {
                *slot = features[[frame_idx, feature_idx]];
            }
            // A shorter input than last time leaves old values behind it; zero them.
            if feature_frames < stale_end {
                row[feature_frames..stale_end].fill(0.0);
            }
        }

        self.last_feature_frames = feature_frames;
        Ok(())
    }

    /// Returns the flat buffer contents.
    fn values(&self) -> &[f32] {
        &self.values
    }

    /// Returns the number of features per frame.
    fn feature_size(&self) -> usize {
        self.feature_size
    }

    /// Returns the number of frame slots currently reserved per feature.
    fn target_frames(&self) -> usize {
        self.target_frames
    }
}

/// Result of one joint step, as consumed by the greedy decode loop.
#[derive(Debug, Clone)]
struct JointDecision {
    /// Predicted token id; compared against the blank id by the decode loop.
    token_id: usize,
    /// Predicted duration; when positive, the decode loop skips this many frames.
    duration_step: usize,
    /// Confidence recorded for the token (the backend's `token_prob`).
    confidence: f32,
}

/// Platform-selected model backend. Exactly one variant exists per build,
/// chosen by `target_os`.
#[derive(Debug, Clone)]
enum ParakeetModelInner {
    /// macOS build: the split CoreML model from `crate::coreml`.
    #[cfg(target_os = "macos")]
    SplitCoreMl(crate::coreml::ParakeetSplitCoreMlModel),
    /// Non-macOS build: placeholder whose operations all fail with
    /// `UnsupportedPlatform`.
    #[cfg(not(target_os = "macos"))]
    Unsupported,
}

impl ParakeetModelInner {
    /// Constructs the backend for the current platform.
    ///
    /// On macOS this builds a `ParakeetSplitCoreMlModel` from the bundle's
    /// encoder, decoder and joint-decision directories. On every other
    /// target it returns `TranscriptionError::UnsupportedPlatform`.
    fn from_bundle(bundle: &ModelBundle) -> Result<Self, TranscriptionError> {
        #[cfg(target_os = "macos")]
        {
            Ok(Self::SplitCoreMl(
                crate::coreml::ParakeetSplitCoreMlModel::new(
                    bundle.encoder_dir(),
                    bundle.decoder_dir(),
                    bundle.joint_decision_dir(),
                )?,
            ))
        }
        #[cfg(not(target_os = "macos"))]
        {
            // Explicitly discard the argument so it counts as used on this target.
            let _ = bundle;
            Err(TranscriptionError::UnsupportedPlatform)
        }
    }

    /// Forwards the flat encoder input to the backend.
    ///
    /// `feature_frames` is passed as a one-element `i32` slice (narrowed with
    /// `as`). The returned `usize` is used by the caller as the number of
    /// steps to decode.
    fn run_encoder(
        &self,
        input: &[f32],
        feature_size: usize,
        target_frames: usize,
        feature_frames: usize,
    ) -> Result<(Array3<f32>, usize), TranscriptionError> {
        match self {
            #[cfg(target_os = "macos")]
            Self::SplitCoreMl(model) => {
                model.run_encoder(input, feature_size, target_frames, &[feature_frames as i32])
            }
            #[cfg(not(target_os = "macos"))]
            Self::Unsupported => Err(TranscriptionError::UnsupportedPlatform),
        }
    }

    /// Runs one decoder step on the backend and repackages its output as a
    /// `CachedDecoderStep`.
    fn run_decoder(
        &self,
        targets: &Array2<i32>,
        target_length: &Array1<i32>,
        hidden_state: &Array3<f32>,
        cell_state: &Array3<f32>,
    ) -> Result<CachedDecoderStep, TranscriptionError> {
        match self {
            #[cfg(target_os = "macos")]
            Self::SplitCoreMl(model) => {
                let output = model.run_decoder(targets, target_length, hidden_state, cell_state)?;
                Ok(CachedDecoderStep {
                    decoder_step: output.decoder_step,
                    hidden_state: output.hidden_state,
                    cell_state: output.cell_state,
                })
            }
            #[cfg(not(target_os = "macos"))]
            Self::Unsupported => Err(TranscriptionError::UnsupportedPlatform),
        }
    }

    /// Runs the joint step on the backend for one encoder frame and one
    /// decoder step, mapping the backend's `token_prob` to `confidence`.
    fn run_joint(
        &self,
        encoder_step: &Array3<f32>,
        decoder_step: &Array3<f32>,
    ) -> Result<JointDecision, TranscriptionError> {
        match self {
            #[cfg(target_os = "macos")]
            Self::SplitCoreMl(model) => {
                let output = model.run_joint(encoder_step, decoder_step)?;
                Ok(JointDecision {
                    token_id: output.token_id,
                    duration_step: output.duration_step,
                    confidence: output.token_prob,
                })
            }
            #[cfg(not(target_os = "macos"))]
            Self::Unsupported => Err(TranscriptionError::UnsupportedPlatform),
        }
    }
}