use std::borrow::Cow;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use ndarray::{Array2, ArrayD, Ix3, IxDyn};
use ort::session::Session;
use ort::session::SessionInputValue;
use ort::value::DynValue;
use super::{session, Quantization};
use crate::decode::{load_vocab, sentencepiece_to_text, GreedyDecoder};
use crate::{
ModelCapabilities, SpeechModel, TranscribeError, TranscribeOptions, TranscriptionResult,
};
/// Sample rate reported through `CAPABILITIES.sample_rate` (presumably in Hz — TODO confirm).
const SAMPLE_RATE: u32 = 16000;
/// Leading dimension of the zero-filled self-attention cache tensors built in
/// `transcribe_chunk` (presumably one slice per decoder layer — confirm against the export).
const NUM_DECODER_LAYERS: usize = 8;
/// Third dimension of those cache tensors (presumably attention heads — confirm against the export).
const NUM_HEADS: usize = 8;
/// Last dimension of those cache tensors (presumably the per-head feature size — confirm against the export).
const HEAD_DIM: usize = 128;
/// Fourth dimension of those cache tensors; `transcribe_with` also clamps the generation
/// budget to this value minus the prompt length.
const MAX_SEQ_LEN: usize = 1024;
/// Generation budget used when `CohereParams::max_new_tokens` is `None`.
const DEFAULT_MAX_NEW_TOKENS: usize = 512;
/// Static capability descriptor returned by `SpeechModel::capabilities` for this engine.
///
/// Timestamps, translation and streaming are all advertised as unsupported; the
/// `languages` list holds the language codes this engine advertises.
const CAPABILITIES: ModelCapabilities = ModelCapabilities {
name: "Cohere",
engine_id: "cohere",
sample_rate: SAMPLE_RATE,
languages: &[
"en", "de", "fr", "it", "es", "pt", "el", "nl", "pl", "ar", "vi", "zh", "ja", "ko",
],
supports_timestamps: false,
supports_translation: false,
supports_streaming: false,
};
/// Per-call options accepted by `CohereModel::transcribe_with`.
#[derive(Debug, Clone, Default)]
pub struct CohereParams {
/// Language code used to pick the prompt's language token. `None` and `"auto"` are
/// treated as `"en"`, `"zh-Hans"`/`"zh-Hant"` as `"zh"`, and a code whose token is not
/// in the vocabulary falls back to `"en"`.
pub language: Option<String>,
/// Translation request. `transcribe_with` only logs a warning and otherwise ignores it.
pub translate: bool,
/// Upper bound on generated tokens. `None` means `DEFAULT_MAX_NEW_TOKENS`; the value is
/// additionally clamped to `MAX_SEQ_LEN` minus the prompt length.
pub max_new_tokens: Option<usize>,
}
/// Encoder/decoder ONNX session pair plus the vocabulary needed to build prompts and
/// turn generated token ids back into text.
pub struct CohereModel {
/// Session run once per call on the raw audio samples.
encoder: Session,
/// Session run once per decoding step.
decoder: Session,
/// Token text indexed by token id (as returned by `load_vocab`).
vocab: Vec<String>,
/// Reverse lookup from token text to id; empty vocabulary entries are skipped.
token_to_id: HashMap<String, i64>,
/// Id of `<|endoftext|>`, or `3` when the vocabulary has no such token.
eos_id: i64,
/// Name of the encoder's first declared input, or `"audio"` when it declares none.
encoder_input_name: String,
/// Names of every input the decoder declares, used to pick between naming schemes.
decoder_input_names: Vec<String>,
}
impl CohereModel {
/// Loads the Cohere encoder and decoder ONNX sessions and the vocabulary from `model_dir`.
///
/// Model files are located with `resolve_model_file`, using the candidate file names
/// for the requested `quantization`. The vocabulary is read from `tokens.txt` or
/// `vocabulary.txt`.
///
/// # Errors
///
/// Returns `TranscribeError::ModelNotFound` when a required file cannot be found. The
/// reported path names the primary candidate for the *requested* quantization, so the
/// message matches what the caller asked for. Session-creation and vocabulary-loading
/// errors are propagated unchanged.
pub fn load(model_dir: &Path, quantization: &Quantization) -> Result<Self, TranscribeError> {
    let encoder_names = encoder_candidates(quantization);
    let decoder_names = decoder_candidates(quantization);

    // Name the first candidate of the requested quantization in the "not found" error
    // rather than always pointing at the int4 file.
    let encoder_path = resolve_model_file(
        model_dir,
        encoder_names,
        encoder_names
            .first()
            .copied()
            .unwrap_or("cohere-encoder.int4.onnx"),
    )?;
    let decoder_path = resolve_model_file(
        model_dir,
        decoder_names,
        decoder_names
            .first()
            .copied()
            .unwrap_or("cohere-decoder.int4.onnx"),
    )?;
    let vocab_path =
        resolve_model_file(model_dir, &["tokens.txt", "vocabulary.txt"], "tokens.txt")?;

    log::info!("Loading Cohere encoder from {:?}...", encoder_path);
    let encoder = session::create_session(&encoder_path)?;
    log::info!("Loading Cohere decoder from {:?}...", decoder_path);
    let decoder = session::create_session(&decoder_path)?;

    // Build the text -> id lookup, skipping empty vocabulary slots.
    let (vocab, _) = load_vocab(&vocab_path)?;
    let token_to_id = vocab
        .iter()
        .enumerate()
        .filter(|(_, token)| !token.is_empty())
        .map(|(id, token)| (token.clone(), id as i64))
        .collect::<HashMap<_, _>>();

    // The exports differ in input naming, so record what this model actually declares.
    let encoder_input_name = encoder
        .inputs()
        .first()
        .map(|input| input.name().to_string())
        .unwrap_or_else(|| "audio".to_string());
    let decoder_input_names = decoder
        .inputs()
        .iter()
        .map(|input| input.name().to_string())
        .collect::<Vec<_>>();

    // Fall back to id 3 when the vocabulary has no explicit end-of-text token.
    let eos_id = token_to_id.get("<|endoftext|>").copied().unwrap_or(3);

    Ok(Self {
        encoder,
        decoder,
        vocab,
        token_to_id,
        eos_id,
        encoder_input_name,
        decoder_input_names,
    })
}
pub fn transcribe_with(
&mut self,
samples: &[f32],
params: &CohereParams,
) -> Result<TranscriptionResult, TranscribeError> {
if params.translate {
log::warn!(
"Cohere ONNX export does not support local translation; ignoring translate=true"
);
}
if samples.is_empty() {
return Ok(TranscriptionResult {
text: String::new(),
segments: None,
});
}
let prompt_ids = self.build_prompt_ids(params.language.as_deref());
let max_new_tokens = params
.max_new_tokens
.unwrap_or(DEFAULT_MAX_NEW_TOKENS)
.min(MAX_SEQ_LEN.saturating_sub(prompt_ids.len()));
let text = self.transcribe_chunk(samples, &prompt_ids, max_new_tokens)?;
Ok(TranscriptionResult {
text,
segments: None,
})
}
/// Runs the encoder once on `samples`, then decodes greedily for at most
/// `max_new_tokens` steps starting from `prompt_ids`, and returns the decoded text.
///
/// The first decoder step is fed the whole prompt; every later step is fed only the
/// token produced by the previous step, together with the self-attention caches the
/// decoder returned.
///
/// # Errors
///
/// Propagates array-construction and `ort` errors, and returns
/// `TranscribeError::Inference` when an expected encoder/decoder output is missing or
/// the logits are not 3-dimensional.
fn transcribe_chunk(
&mut self,
samples: &[f32],
prompt_ids: &[i64],
max_new_tokens: usize,
) -> Result<String, TranscribeError> {
// Audio is passed to the encoder as a single-row (1, samples.len()) f32 array.
let audio = Array2::from_shape_vec((1, samples.len()), samples.to_vec())?.into_dyn();
// Scoped so `encoder_outputs` is dropped as soon as the two cross-attention tensors
// have been taken out of it.
let (cross_k, cross_v) = {
let mut encoder_outputs = self.encoder.run(vec![(
Cow::Owned(self.encoder_input_name.clone()),
ort::value::Value::from_array(audio)?.into_dyn(),
)])?;
let cross_k = remove_output(&mut encoder_outputs, "n_layer_cross_k")?;
let cross_v = remove_output(&mut encoder_outputs, "n_layer_cross_v")?;
(cross_k, cross_v)
};
// Resolve each decoder input name once, accepting the alternative names listed as fallbacks.
let token_name = self.decoder_input_name("tokens", &["input_ids"]);
let self_k_name = self.decoder_input_name(
"in_n_layer_self_k_cache",
&["past_key_values", "past_key_values.key"],
);
let self_v_name =
self.decoder_input_name("in_n_layer_self_v_cache", &["past_key_values.value"]);
let cross_k_name = self.decoder_input_name("n_layer_cross_k", &["encoder_kv_cache.key"]);
let cross_v_name = self.decoder_input_name("n_layer_cross_v", &["encoder_kv_cache.value"]);
let offset_name = self.decoder_input_name("offset", &["cache_position"]);
let mut greedy = GreedyDecoder::new(self.eos_id);
let mut generated_ids: Vec<i64> = Vec::new();
// Tokens fed on the next step: the full prompt first, then one token at a time.
let mut current_tokens = prompt_ids.to_vec();
// Running count of tokens already fed to the decoder, passed as the offset input.
let mut offset = 0_i64;
// Both self-attention caches start as zero-filled f32 tensors of shape
// [NUM_DECODER_LAYERS, 1, NUM_HEADS, MAX_SEQ_LEN, HEAD_DIM].
let mut self_k_cache: DynValue =
ort::value::Value::from_array(ArrayD::<f32>::zeros(IxDyn(&[
NUM_DECODER_LAYERS,
1,
NUM_HEADS,
MAX_SEQ_LEN,
HEAD_DIM,
])))?
.into_dyn();
let mut self_v_cache: DynValue =
ort::value::Value::from_array(ArrayD::<f32>::zeros(IxDyn(&[
NUM_DECODER_LAYERS,
1,
NUM_HEADS,
MAX_SEQ_LEN,
HEAD_DIM,
])))?
.into_dyn();
for _ in 0..max_new_tokens {
let n_tokens = current_tokens.len();
let tokens = Array2::from_shape_vec((1, n_tokens), current_tokens.clone())?.into_dyn();
let offset_tensor = ndarray::arr0(offset).into_dyn();
// The self caches are moved into the input list (and replaced at the end of the
// step); the cross-attention tensors are only borrowed, so they are reused as-is.
let inputs: Vec<(Cow<str>, SessionInputValue)> = vec![
(
Cow::Borrowed(token_name.as_str()),
SessionInputValue::from(ort::value::Value::from_array(tokens)?),
),
(
Cow::Borrowed(self_k_name.as_str()),
SessionInputValue::from(self_k_cache),
),
(
Cow::Borrowed(self_v_name.as_str()),
SessionInputValue::from(self_v_cache),
),
(
Cow::Borrowed(cross_k_name.as_str()),
SessionInputValue::from(&cross_k),
),
(
Cow::Borrowed(cross_v_name.as_str()),
SessionInputValue::from(&cross_v),
),
(
Cow::Borrowed(offset_name.as_str()),
SessionInputValue::from(ort::value::Value::from_array(offset_tensor)?),
),
];
let mut decoder_outputs = self.decoder.run(inputs)?;
// Copy out the logits of the last position along axis 1 for the first row.
// NOTE(review): if axis 1 were ever empty, `last_pos` would be 0 and the slice below
// would index out of bounds — confirm the decoder never returns zero positions.
let last_logits = {
let logits = decoder_outputs
.get("logits")
.ok_or_else(|| TranscribeError::Inference("Missing logits output".into()))?
.try_extract_array::<f32>()?;
let logits = logits
.into_dimensionality::<Ix3>()
.map_err(|e| TranscribeError::Inference(e.to_string()))?;
let last_pos = logits.shape()[1].saturating_sub(1);
logits.slice(ndarray::s![0, last_pos, ..]).to_vec()
};
// `None` from the greedy decoder ends generation (presumably on the EOS id it was
// constructed with — verify against `GreedyDecoder`).
let next_token = match greedy.next_token(&last_logits) {
Some(t) => t,
None => break,
};
generated_ids.push(next_token);
current_tokens = vec![next_token];
offset += n_tokens as i64;
// Carry the updated self-attention caches into the next step.
self_k_cache = remove_output(&mut decoder_outputs, "out_n_layer_self_k_cache")?;
self_v_cache = remove_output(&mut decoder_outputs, "out_n_layer_self_v_cache")?;
}
Ok(self.decode_ids(&generated_ids))
}
/// Builds the decoder prompt as a list of token ids for the given language hint.
///
/// `None` and `"auto"` select English and both Chinese script variants map to `"zh"`.
/// If the resulting language token is not in the vocabulary, English is used instead.
/// Any prompt token missing from the vocabulary is skipped after logging a warning.
fn build_prompt_ids(&self, language: Option<&str>) -> Vec<i64> {
    let requested = match language {
        None | Some("auto") => "en",
        Some("zh-Hans") | Some("zh-Hant") => "zh",
        Some(other) => other,
    };

    // Keep the requested language only when the vocabulary actually knows its token.
    let requested_tag = format!("<|{}|>", requested);
    let language_tag = if self.token_to_id.contains_key(&requested_tag) {
        requested_tag
    } else {
        "<|en|>".to_string()
    };

    // The language tag appears twice in the prompt layout.
    let prompt: [&str; 9] = [
        "<|startofcontext|>",
        "<|startoftranscript|>",
        "<|emo:undefined|>",
        language_tag.as_str(),
        language_tag.as_str(),
        "<|pnc|>",
        "<|noitn|>",
        "<|notimestamp|>",
        "<|nodiarize|>",
    ];

    let mut ids = Vec::new();
    for token in prompt.iter() {
        match self.token_to_id.get(*token) {
            Some(&id) => ids.push(id),
            None => {
                log::warn!("Prompt token not found in vocab: {}", token);
            }
        }
    }
    ids
}
fn decode_ids(&self, token_ids: &[i64]) -> String {
let pieces = token_ids
.iter()
.filter_map(|&id| self.vocab.get(id as usize))
.filter(|token| {
!token.trim().is_empty()
&& !token.starts_with("<|")
&& token.as_str() != "<unk>"
&& token.as_str() != "<pad>"
})
.map(|token| token.as_str())
.collect::<Vec<_>>();
sentencepiece_to_text(&pieces)
}
/// Picks the decoder input name to use for a logical input.
///
/// Returns `preferred` if the decoder declares it, otherwise the first entry of
/// `fallbacks` that the decoder declares, otherwise `preferred` unchanged.
fn decoder_input_name(&self, preferred: &str, fallbacks: &[&str]) -> String {
    let declared_inputs = &self.decoder_input_names;
    std::iter::once(preferred)
        .chain(fallbacks.iter().copied())
        .find(|candidate| declared_inputs.iter().any(|declared| declared == candidate))
        .unwrap_or(preferred)
        .to_string()
}
}
impl SpeechModel for CohereModel {
fn capabilities(&self) -> ModelCapabilities {
CAPABILITIES
}
fn transcribe_raw(
&mut self,
samples: &[f32],
options: &TranscribeOptions,
) -> Result<TranscriptionResult, TranscribeError> {
self.transcribe_with(
samples,
&CohereParams {
language: options.language.clone(),
translate: options.translate,
max_new_tokens: None,
},
)
}
}
/// Finds the first existing file among `candidates`.
///
/// `model_dir` is searched first, then its `onnx` subdirectory; within each directory
/// the candidates are tried in the order given. When nothing exists the error carries
/// `model_dir/missing_name`.
fn resolve_model_file(
    model_dir: &Path,
    candidates: &[&str],
    missing_name: &str,
) -> Result<PathBuf, TranscribeError> {
    let search_dirs = [model_dir.to_path_buf(), model_dir.join("onnx")];
    let located = search_dirs
        .iter()
        .flat_map(|dir| candidates.iter().map(move |name| dir.join(name)))
        .find(|path| path.exists());
    match located {
        Some(path) => Ok(path),
        None => Err(TranscribeError::ModelNotFound(model_dir.join(missing_name))),
    }
}
/// Encoder file names to probe for `quantization`, in priority order.
fn encoder_candidates(quantization: &Quantization) -> &'static [&'static str] {
    const INT4: &[&str] = &["cohere-encoder.int4.onnx", "encoder_model.int4.onnx"];
    const INT8: &[&str] = &["cohere-encoder.int8.onnx", "encoder_model.int8.onnx"];
    const FP16: &[&str] = &["cohere-encoder.fp16.onnx", "encoder_model_fp16.onnx"];
    const FP32: &[&str] = &["cohere-encoder.onnx", "encoder_model.onnx"];

    match quantization {
        Quantization::Int4 => INT4,
        Quantization::Int8 => INT8,
        Quantization::FP16 => FP16,
        Quantization::FP32 => FP32,
    }
}
/// Decoder file names to probe for `quantization`, in priority order.
fn decoder_candidates(quantization: &Quantization) -> &'static [&'static str] {
    const INT4: &[&str] = &["cohere-decoder.int4.onnx", "decoder_model_merged.int4.onnx"];
    const INT8: &[&str] = &["cohere-decoder.int8.onnx", "decoder_model_merged.int8.onnx"];
    const FP16: &[&str] = &["cohere-decoder.fp16.onnx", "decoder_model_merged_fp16.onnx"];
    const FP32: &[&str] = &["cohere-decoder.onnx", "decoder_model_merged.onnx"];

    match quantization {
        Quantization::Int4 => INT4,
        Quantization::Int8 => INT8,
        Quantization::FP16 => FP16,
        Quantization::FP32 => FP32,
    }
}
/// Takes the output called `name` out of `outputs`, transferring ownership to the caller.
///
/// Returns `TranscribeError::Inference` naming the output when it is not present.
fn remove_output(
    outputs: &mut ort::session::SessionOutputs,
    name: &str,
) -> Result<DynValue, TranscribeError> {
    match outputs.remove(name) {
        Some(value) => Ok(value),
        None => Err(TranscribeError::Inference(format!(
            "Missing expected output: {name}"
        ))),
    }
}