//! transcribe-cli 0.0.6
//!
//! Native Rust CLI transcription pipeline with GigaAM v3 ONNX.
use std::path::Path;

use anyhow::{Context, Result, anyhow, bail};
use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView3, Axis, Ix1, Ix3};
use ort::value::TensorRef;

use crate::onnx_ctc::{CtcVocabulary, ExecutionMode, VocabularyOptions, build_session};

/// ONNX runtime for a transducer model: a standalone acoustic encoder plus a
/// fused prediction-network/joint session, decoded greedily (TDT-style) into
/// text via the shared CTC vocabulary type.
pub struct OnnxTransducerRuntime {
    // Acoustic encoder session: feature tensor -> per-frame encodings.
    encoder: ort::session::Session,
    // Fused decoder+joint session, stepped one encoder frame at a time.
    decoder_joint: ort::session::Session,
    // Maps token ids back to text; also supplies the blank id.
    vocabulary: CtcVocabulary,
    // Shapes of the decoder's two recurrent state tensors, read from the
    // decoder-joint model's declared inputs.
    state_spec: DecoderStateSpec,
    // Cap on non-blank tokens emitted without advancing the time index
    // (clamped to >= 1 in `new` so the decode loop always makes progress).
    max_tokens_per_step: usize,
}

impl OnnxTransducerRuntime {
    /// Builds the runtime: loads the vocabulary, creates ONNX sessions for the
    /// encoder and the fused decoder+joint network, and infers the recurrent
    /// state shapes from the decoder-joint model's declared inputs.
    pub fn new(
        encoder_path: &Path,
        decoder_joint_path: &Path,
        vocab_path: &Path,
        execution: &ExecutionMode,
        vocabulary_options: VocabularyOptions,
        max_tokens_per_step: usize,
    ) -> Result<Self> {
        let vocabulary = CtcVocabulary::from_text_file(vocab_path, vocabulary_options)?;
        let encoder = build_session(encoder_path, execution)?;
        let decoder_joint = build_session(decoder_joint_path, execution)?;
        let state_spec = DecoderStateSpec::from_session(&decoder_joint)?;

        Ok(Self {
            encoder,
            decoder_joint,
            vocabulary,
            state_spec,
            // Clamp to >= 1: the decode loop only forces a frame advance when
            // `emitted_tokens` reaches this cap, and a cap of 0 could never be
            // reached after a reset, risking an infinite zero-duration loop.
            max_tokens_per_step: max_tokens_per_step.max(1),
        })
    }

    /// Transcribes one utterance: runs the encoder once, then greedily decodes
    /// batch entry 0 of the encoder output.
    ///
    /// `features` is the batched feature tensor passed straight to the encoder;
    /// `lengths` holds the per-batch valid lengths. Only batch index 0 is
    /// decoded.
    pub fn transcribe_features(
        &mut self,
        features: ArrayView3<'_, f32>,
        lengths: ArrayView1<'_, i64>,
    ) -> Result<String> {
        let (encoder_out, encoder_lengths) = self.encode(features, lengths)?;
        // Only the first batch entry is decoded.
        let encodings = encoder_out.index_axis(Axis(0), 0);
        // Missing or negative lengths collapse to 0, which decodes to "".
        let encoded_len = encoder_lengths.get(0).copied().unwrap_or_default().max(0) as usize;

        self.decode_tdt(encodings, encoded_len)
    }

    /// Runs the encoder session, returning `(encodings, valid_lengths)`.
    ///
    /// The raw encoder output's last two axes are swapped before returning so
    /// the decode loop can index individual frames along axis 1.
    /// NOTE(review): this assumes the model emits (batch, features, time) and
    /// that the permuted result is (batch, time, features) — confirm against
    /// the exported encoder.
    fn encode(
        &mut self,
        features: ArrayView3<'_, f32>,
        lengths: ArrayView1<'_, i64>,
    ) -> Result<(Array3<f32>, Array1<i64>)> {
        let outputs = self.encoder.run(ort::inputs![
            TensorRef::from_array_view(features)?,
            TensorRef::from_array_view(lengths)?,
        ])?;
        let encoder_out = outputs[0]
            .try_extract_array::<f32>()
            .context("failed to extract transducer encoder output tensor")?;
        let encoder_out = encoder_out
            .into_dimensionality::<Ix3>()
            .map_err(|error| anyhow!("unexpected transducer encoder output shape: {error}"))?
            // Swap the last two axes; `to_owned` materializes the permuted
            // view into contiguous storage.
            .permuted_axes([0, 2, 1])
            .to_owned();

        let encoder_lengths = outputs[1]
            .try_extract_array::<i64>()
            .context("failed to extract transducer encoder lengths tensor")?;
        let encoder_lengths = encoder_lengths
            .into_dimensionality::<Ix1>()
            .map_err(|error| anyhow!("unexpected transducer length output shape: {error}"))?
            .to_owned();

        Ok((encoder_out, encoder_lengths))
    }

    /// Greedy token-and-duration (TDT) decoding over one utterance's frames.
    ///
    /// At each time index the fused decoder+joint predicts both a token and a
    /// frame advance ("duration"). Non-blank tokens are appended and advance
    /// the recurrent state; the duration argmax index is used directly as the
    /// number of frames to skip. NOTE(review): this assumes duration class `i`
    /// means "advance i frames" — confirm against the model's duration head.
    fn decode_tdt(
        &mut self,
        encodings: ndarray::ArrayView2<'_, f32>,
        encoded_len: usize,
    ) -> Result<String> {
        // Never trust the reported length beyond the tensor's actual extent.
        let valid_len = encoded_len.min(encodings.len_of(Axis(0)));
        if valid_len == 0 {
            return Ok(String::new());
        }

        let blank_id = self.vocabulary.blank_id();
        let vocab_size = self.vocabulary.len();
        let mut state = self.state_spec.initial_state();
        let mut tokens = Vec::new();
        let mut t = 0usize;
        // Non-blank emissions since the time index last advanced.
        let mut emitted_tokens = 0usize;

        while t < valid_len {
            let frame = encodings.index_axis(Axis(0), t);
            // Condition the prediction network on the last non-blank token;
            // blank doubles as the start-of-sequence symbol.
            let previous_token = tokens.last().copied().unwrap_or(blank_id) as i32;
            let (output, next_state) = self.decode_step(frame, previous_token, &state)?;

            // The joint output packs token logits followed by duration logits,
            // so it must be strictly wider than the vocabulary.
            if output.len() <= vocab_size {
                bail!(
                    "unexpected Parakeet TDT output width {}; expected more than vocabulary size {}",
                    output.len(),
                    vocab_size
                );
            }

            let (token_logits, duration_logits) = output.split_at(vocab_size);
            let token = argmax(token_logits).context("Parakeet TDT emitted empty token logits")?;
            let step =
                argmax(duration_logits).context("Parakeet TDT emitted empty duration logits")?;

            if token != blank_id {
                // Only non-blank emissions advance the prediction-network state.
                state = next_state;
                tokens.push(token);
                emitted_tokens += 1;
            }

            if step > 0 {
                // Positive duration: jump ahead that many frames.
                t += step;
                emitted_tokens = 0;
            } else if token == blank_id || emitted_tokens == self.max_tokens_per_step {
                // Zero duration keeps us on this frame so several tokens can be
                // emitted, but blank or hitting the per-frame cap forces
                // progress to guarantee termination.
                t += 1;
                emitted_tokens = 0;
            }
        }

        Ok(self.vocabulary.decode_ids(&tokens))
    }

    /// Runs one step of the fused decoder+joint network.
    ///
    /// Feeds a single encoder frame shaped `(1, features, 1)`, the previous
    /// token as a `(1, 1)` target with length 1, and the current recurrent
    /// state; returns the flattened joint logits plus the updated state.
    /// Output index 1 is intentionally skipped — NOTE(review): presumably an
    /// output-length tensor in this export, confirm against the model.
    fn decode_step(
        &mut self,
        encoder_frame: ndarray::ArrayView1<'_, f32>,
        previous_token: i32,
        state: &DecoderState,
    ) -> Result<(Vec<f32>, DecoderState)> {
        let frame_width = encoder_frame.len();
        // Lay the frame out as (batch=1, features, time=1) for the joint input.
        let mut encoder_outputs = Array3::<f32>::zeros((1, frame_width, 1));
        for (index, value) in encoder_frame.iter().copied().enumerate() {
            encoder_outputs[(0, index, 0)] = value;
        }
        let targets = Array2::<i32>::from_shape_vec((1, 1), vec![previous_token])
            .context("failed to build transducer target tensor")?;
        let target_length = Array1::<i32>::from_vec(vec![1]);

        let outputs = self.decoder_joint.run(ort::inputs![
            TensorRef::from_array_view(encoder_outputs.view())?,
            TensorRef::from_array_view(targets.view())?,
            TensorRef::from_array_view(target_length.view())?,
            TensorRef::from_array_view(state.state1.view())?,
            TensorRef::from_array_view(state.state2.view())?,
        ])?;

        // Flatten the joint logits regardless of the output tensor's rank.
        let logits = outputs[0]
            .try_extract_array::<f32>()
            .context("failed to extract transducer decoder output tensor")?;
        let logits = logits.iter().copied().collect::<Vec<_>>();

        let state1 = outputs[2]
            .try_extract_array::<f32>()
            .context("failed to extract transducer recurrent state 1")?;
        let state1 = state1
            .into_dimensionality::<Ix3>()
            .map_err(|error| anyhow!("unexpected transducer state_1 shape: {error}"))?
            .to_owned();

        let state2 = outputs[3]
            .try_extract_array::<f32>()
            .context("failed to extract transducer recurrent state 2")?;
        let state2 = state2
            .into_dimensionality::<Ix3>()
            .map_err(|error| anyhow!("unexpected transducer state_2 shape: {error}"))?
            .to_owned();

        Ok((logits, DecoderState { state1, state2 }))
    }
}

/// Layer and hidden dimensions of the decoder-joint model's two recurrent
/// state inputs (`input_states_1` / `input_states_2`), read once at load time.
#[derive(Clone, Debug)]
struct DecoderStateSpec {
    state1_layers: usize, // dim 0 of `input_states_1`
    state1_hidden: usize, // dim 2 of `input_states_1`
    state2_layers: usize, // dim 0 of `input_states_2`
    state2_hidden: usize, // dim 2 of `input_states_2`
}

impl DecoderStateSpec {
    /// Reads both recurrent-state shapes from the decoder-joint session's
    /// declared inputs.
    fn from_session(session: &ort::session::Session) -> Result<Self> {
        let state1 = infer_state_shape(session, "input_states_1")?;
        let state2 = infer_state_shape(session, "input_states_2")?;
        Ok(Self {
            state1_layers: state1.0,
            state1_hidden: state1.1,
            state2_layers: state2.0,
            state2_hidden: state2.1,
        })
    }

    /// Zero-initialized recurrent state for a batch of one.
    fn initial_state(&self) -> DecoderState {
        let state1 = Array3::<f32>::zeros((self.state1_layers, 1, self.state1_hidden));
        let state2 = Array3::<f32>::zeros((self.state2_layers, 1, self.state2_hidden));
        DecoderState { state1, state2 }
    }
}

/// Owned recurrent state tensors threaded through successive decode steps.
/// NOTE(review): presumably LSTM hidden/cell states, each shaped
/// (layers, batch=1, hidden) — confirm against the exported decoder.
#[derive(Clone, Debug)]
struct DecoderState {
    state1: Array3<f32>,
    state2: Array3<f32>,
}

/// Resolves `(layers, hidden)` for one recurrent-state input of the
/// decoder-joint model by inspecting its declared tensor shape.
///
/// The input must exist, be a rank-3 tensor, and have strictly positive
/// first (layers) and third (hidden) dimensions; the middle (batch) dimension
/// is ignored.
fn infer_state_shape(session: &ort::session::Session, input_name: &str) -> Result<(usize, usize)> {
    let input = session
        .inputs
        .iter()
        .find(|candidate| candidate.name == input_name)
        .with_context(|| format!("transducer model does not define `{input_name}`"))?;
    let shape = input
        .input_type
        .tensor_shape()
        .with_context(|| format!("transducer input `{input_name}` is not a tensor"))?;

    if shape.len() != 3 {
        bail!(
            "transducer input `{input_name}` has unsupported rank {}; expected 3",
            shape.len()
        );
    }

    // Shared rule for both interesting dimensions: convertible to usize and
    // strictly positive (dynamic/negative dims are unsupported).
    let positive_dim = |index: usize| usize::try_from(shape[index]).ok().filter(|dim| *dim > 0);

    let layers = positive_dim(0).with_context(|| {
        format!(
            "transducer input `{input_name}` has unsupported layer dimension {}",
            shape[0]
        )
    })?;
    let hidden = positive_dim(2).with_context(|| {
        format!(
            "transducer input `{input_name}` has unsupported hidden dimension {}",
            shape[2]
        )
    })?;
    Ok((layers, hidden))
}

/// Index of the largest value in `values`, or `None` for an empty slice.
///
/// Uses `f32::total_cmp`, so NaN compares greater than every finite value and
/// +inf. On ties the highest index wins (matching `Iterator::max_by`, which
/// keeps the last of equal elements).
fn argmax(values: &[f32]) -> Option<usize> {
    let mut best: Option<(usize, f32)> = None;
    for (index, &candidate) in values.iter().enumerate() {
        best = match best {
            // Keep the incumbent only when it is strictly greater; an equal
            // candidate replaces it so the last maximum's index is returned.
            Some((_, current)) if candidate.total_cmp(&current) == std::cmp::Ordering::Less => best,
            _ => Some((index, candidate)),
        };
    }
    best.map(|(index, _)| index)
}