transcribe-rs 0.3.7

A simple library to help you transcribe audio
Documentation
//! Cohere ONNX encoder/decoder transcription engine.

use std::borrow::Cow;
use std::collections::HashMap;
use std::path::{Path, PathBuf};

use ndarray::{Array2, ArrayD, Ix3, IxDyn};
use ort::session::Session;
use ort::session::SessionInputValue;
use ort::value::DynValue;

use super::{session, Quantization};
use crate::decode::{load_vocab, sentencepiece_to_text, GreedyDecoder};
use crate::{
    ModelCapabilities, SpeechModel, TranscribeError, TranscribeOptions, TranscriptionResult,
};

const SAMPLE_RATE: u32 = 16000;
const NUM_DECODER_LAYERS: usize = 8;
const NUM_HEADS: usize = 8;
const HEAD_DIM: usize = 128;
const MAX_SEQ_LEN: usize = 1024;
const DEFAULT_MAX_NEW_TOKENS: usize = 512;

const CAPABILITIES: ModelCapabilities = ModelCapabilities {
    name: "Cohere",
    engine_id: "cohere",
    sample_rate: SAMPLE_RATE,
    languages: &[
        "en", "de", "fr", "it", "es", "pt", "el", "nl", "pl", "ar", "vi", "zh", "ja", "ko",
    ],
    supports_timestamps: false,
    supports_translation: false,
    supports_streaming: false,
};

#[derive(Debug, Clone, Default)]
pub struct CohereParams {
    /// Language hint for prompt tokens. Defaults to English.
    pub language: Option<String>,
    /// Translation is currently ignored by the local ONNX export.
    pub translate: bool,
    /// Maximum number of autoregressive tokens to emit per chunk.
    pub max_new_tokens: Option<usize>,
}

pub struct CohereModel {
    encoder: Session,
    decoder: Session,
    vocab: Vec<String>,
    token_to_id: HashMap<String, i64>,
    eos_id: i64,
    encoder_input_name: String,
    decoder_input_names: Vec<String>,
}

impl CohereModel {
    pub fn load(model_dir: &Path, quantization: &Quantization) -> Result<Self, TranscribeError> {
        let encoder_path = resolve_model_file(
            model_dir,
            encoder_candidates(quantization),
            "cohere-encoder.int4.onnx",
        )?;
        let decoder_path = resolve_model_file(
            model_dir,
            decoder_candidates(quantization),
            "cohere-decoder.int4.onnx",
        )?;
        let vocab_path =
            resolve_model_file(model_dir, &["tokens.txt", "vocabulary.txt"], "tokens.txt")?;

        log::info!("Loading Cohere encoder from {:?}...", encoder_path);
        let encoder = session::create_session(&encoder_path)?;

        log::info!("Loading Cohere decoder from {:?}...", decoder_path);
        let decoder = session::create_session(&decoder_path)?;

        let (vocab, _) = load_vocab(&vocab_path)?;
        let token_to_id = vocab
            .iter()
            .enumerate()
            .filter(|(_, token)| !token.is_empty())
            .map(|(id, token)| (token.clone(), id as i64))
            .collect::<HashMap<_, _>>();

        let encoder_input_name = encoder
            .inputs()
            .first()
            .map(|input| input.name().to_string())
            .unwrap_or_else(|| "audio".to_string());
        let decoder_input_names = decoder
            .inputs()
            .iter()
            .map(|input| input.name().to_string())
            .collect::<Vec<_>>();
        let eos_id = token_to_id.get("<|endoftext|>").copied().unwrap_or(3);

        Ok(Self {
            encoder,
            decoder,
            vocab,
            token_to_id,
            eos_id,
            encoder_input_name,
            decoder_input_names,
        })
    }

    pub fn transcribe_with(
        &mut self,
        samples: &[f32],
        params: &CohereParams,
    ) -> Result<TranscriptionResult, TranscribeError> {
        if params.translate {
            log::warn!(
                "Cohere ONNX export does not support local translation; ignoring translate=true"
            );
        }

        if samples.is_empty() {
            return Ok(TranscriptionResult {
                text: String::new(),
                segments: None,
            });
        }

        let prompt_ids = self.build_prompt_ids(params.language.as_deref());
        let max_new_tokens = params
            .max_new_tokens
            .unwrap_or(DEFAULT_MAX_NEW_TOKENS)
            .min(MAX_SEQ_LEN.saturating_sub(prompt_ids.len()));

        let text = self.transcribe_chunk(samples, &prompt_ids, max_new_tokens)?;

        Ok(TranscriptionResult {
            text,
            segments: None,
        })
    }

    fn transcribe_chunk(
        &mut self,
        samples: &[f32],
        prompt_ids: &[i64],
        max_new_tokens: usize,
    ) -> Result<String, TranscribeError> {
        let audio = Array2::from_shape_vec((1, samples.len()), samples.to_vec())?.into_dyn();
        let (cross_k, cross_v) = {
            let mut encoder_outputs = self.encoder.run(vec![(
                Cow::Owned(self.encoder_input_name.clone()),
                ort::value::Value::from_array(audio)?.into_dyn(),
            )])?;
            let cross_k = remove_output(&mut encoder_outputs, "n_layer_cross_k")?;
            let cross_v = remove_output(&mut encoder_outputs, "n_layer_cross_v")?;
            (cross_k, cross_v)
        };

        // Resolve decoder input names once before the loop.
        let token_name = self.decoder_input_name("tokens", &["input_ids"]);
        let self_k_name = self.decoder_input_name(
            "in_n_layer_self_k_cache",
            &["past_key_values", "past_key_values.key"],
        );
        let self_v_name =
            self.decoder_input_name("in_n_layer_self_v_cache", &["past_key_values.value"]);
        let cross_k_name = self.decoder_input_name("n_layer_cross_k", &["encoder_kv_cache.key"]);
        let cross_v_name = self.decoder_input_name("n_layer_cross_v", &["encoder_kv_cache.value"]);
        let offset_name = self.decoder_input_name("offset", &["cache_position"]);

        let mut greedy = GreedyDecoder::new(self.eos_id);
        let mut generated_ids: Vec<i64> = Vec::new();
        let mut current_tokens = prompt_ids.to_vec();
        let mut offset = 0_i64;

        let mut self_k_cache: DynValue =
            ort::value::Value::from_array(ArrayD::<f32>::zeros(IxDyn(&[
                NUM_DECODER_LAYERS,
                1,
                NUM_HEADS,
                MAX_SEQ_LEN,
                HEAD_DIM,
            ])))?
            .into_dyn();
        let mut self_v_cache: DynValue =
            ort::value::Value::from_array(ArrayD::<f32>::zeros(IxDyn(&[
                NUM_DECODER_LAYERS,
                1,
                NUM_HEADS,
                MAX_SEQ_LEN,
                HEAD_DIM,
            ])))?
            .into_dyn();

        for _ in 0..max_new_tokens {
            let n_tokens = current_tokens.len();
            let tokens = Array2::from_shape_vec((1, n_tokens), current_tokens.clone())?.into_dyn();
            let offset_tensor = ndarray::arr0(offset).into_dyn();

            // Build inputs: move self caches (replaced from output below),
            // borrow cross caches (constant across the loop).
            let inputs: Vec<(Cow<str>, SessionInputValue)> = vec![
                (
                    Cow::Borrowed(token_name.as_str()),
                    SessionInputValue::from(ort::value::Value::from_array(tokens)?),
                ),
                (
                    Cow::Borrowed(self_k_name.as_str()),
                    SessionInputValue::from(self_k_cache),
                ),
                (
                    Cow::Borrowed(self_v_name.as_str()),
                    SessionInputValue::from(self_v_cache),
                ),
                (
                    Cow::Borrowed(cross_k_name.as_str()),
                    SessionInputValue::from(&cross_k),
                ),
                (
                    Cow::Borrowed(cross_v_name.as_str()),
                    SessionInputValue::from(&cross_v),
                ),
                (
                    Cow::Borrowed(offset_name.as_str()),
                    SessionInputValue::from(ort::value::Value::from_array(offset_tensor)?),
                ),
            ];

            let mut decoder_outputs = self.decoder.run(inputs)?;

            // Extract last-position logits in a scoped borrow, then release before remove().
            let last_logits = {
                let logits = decoder_outputs
                    .get("logits")
                    .ok_or_else(|| TranscribeError::Inference("Missing logits output".into()))?
                    .try_extract_array::<f32>()?;
                let logits = logits
                    .into_dimensionality::<Ix3>()
                    .map_err(|e| TranscribeError::Inference(e.to_string()))?;
                let last_pos = logits.shape()[1].saturating_sub(1);
                logits.slice(ndarray::s![0, last_pos, ..]).to_vec()
            };

            let next_token = match greedy.next_token(&last_logits) {
                Some(t) => t,
                None => break,
            };

            generated_ids.push(next_token);
            current_tokens = vec![next_token];
            offset += n_tokens as i64;

            // Take KV caches directly from outputs (no data copy).
            self_k_cache = remove_output(&mut decoder_outputs, "out_n_layer_self_k_cache")?;
            self_v_cache = remove_output(&mut decoder_outputs, "out_n_layer_self_v_cache")?;
        }

        Ok(self.decode_ids(&generated_ids))
    }

    fn build_prompt_ids(&self, language: Option<&str>) -> Vec<i64> {
        let requested = match language.unwrap_or("en") {
            "auto" => "en",
            "zh-Hans" | "zh-Hant" => "zh",
            other => other,
        };

        let language_token = format!("<|{}|>", requested);
        let chosen_language = if self.token_to_id.contains_key(&language_token) {
            requested
        } else {
            "en"
        };

        [
            "<|startofcontext|>".to_string(),
            "<|startoftranscript|>".to_string(),
            "<|emo:undefined|>".to_string(),
            format!("<|{}|>", chosen_language),
            format!("<|{}|>", chosen_language),
            "<|pnc|>".to_string(),
            "<|noitn|>".to_string(),
            "<|notimestamp|>".to_string(),
            "<|nodiarize|>".to_string(),
        ]
        .iter()
        .filter_map(|token| {
            let id = self.token_to_id.get(token).copied();
            if id.is_none() {
                log::warn!("Prompt token not found in vocab: {}", token);
            }
            id
        })
        .collect()
    }

    fn decode_ids(&self, token_ids: &[i64]) -> String {
        let pieces = token_ids
            .iter()
            .filter_map(|&id| self.vocab.get(id as usize))
            .filter(|token| {
                !token.trim().is_empty()
                    && !token.starts_with("<|")
                    && token.as_str() != "<unk>"
                    && token.as_str() != "<pad>"
            })
            .map(|token| token.as_str())
            .collect::<Vec<_>>();

        sentencepiece_to_text(&pieces)
    }

    fn decoder_input_name(&self, preferred: &str, fallbacks: &[&str]) -> String {
        if self
            .decoder_input_names
            .iter()
            .any(|name| name == preferred)
        {
            return preferred.to_string();
        }

        for fallback in fallbacks {
            if self.decoder_input_names.iter().any(|name| name == fallback) {
                return (*fallback).to_string();
            }
        }

        preferred.to_string()
    }
}

impl SpeechModel for CohereModel {
    fn capabilities(&self) -> ModelCapabilities {
        CAPABILITIES
    }

    fn transcribe_raw(
        &mut self,
        samples: &[f32],
        options: &TranscribeOptions,
    ) -> Result<TranscriptionResult, TranscribeError> {
        self.transcribe_with(
            samples,
            &CohereParams {
                language: options.language.clone(),
                translate: options.translate,
                max_new_tokens: None,
            },
        )
    }
}

fn resolve_model_file(
    model_dir: &Path,
    candidates: &[&str],
    missing_name: &str,
) -> Result<PathBuf, TranscribeError> {
    for base_dir in [model_dir.to_path_buf(), model_dir.join("onnx")] {
        for candidate in candidates {
            let path = base_dir.join(candidate);
            if path.exists() {
                return Ok(path);
            }
        }
    }

    Err(TranscribeError::ModelNotFound(model_dir.join(missing_name)))
}

fn encoder_candidates(quantization: &Quantization) -> &'static [&'static str] {
    match quantization {
        Quantization::Int4 => &["cohere-encoder.int4.onnx", "encoder_model.int4.onnx"],
        Quantization::Int8 => &["cohere-encoder.int8.onnx", "encoder_model.int8.onnx"],
        Quantization::FP16 => &["cohere-encoder.fp16.onnx", "encoder_model_fp16.onnx"],
        Quantization::FP32 => &["cohere-encoder.onnx", "encoder_model.onnx"],
    }
}

fn decoder_candidates(quantization: &Quantization) -> &'static [&'static str] {
    match quantization {
        Quantization::Int4 => &["cohere-decoder.int4.onnx", "decoder_model_merged.int4.onnx"],
        Quantization::Int8 => &["cohere-decoder.int8.onnx", "decoder_model_merged.int8.onnx"],
        Quantization::FP16 => &["cohere-decoder.fp16.onnx", "decoder_model_merged_fp16.onnx"],
        Quantization::FP32 => &["cohere-decoder.onnx", "decoder_model_merged.onnx"],
    }
}

fn remove_output(
    outputs: &mut ort::session::SessionOutputs,
    name: &str,
) -> Result<DynValue, TranscribeError> {
    outputs
        .remove(name)
        .ok_or_else(|| TranscribeError::Inference(format!("Missing expected output: {name}")))
}