use crate::audio::extract_features_raw;
use crate::config::PreprocessorConfig;
use crate::error::{Error, Result};
use crate::execution::ModelConfig as ExecutionConfig;
use crate::model_cohere::{CohereEncoderOutput, CohereModel, CoherePastKv, N_MELS};
use ndarray::{Array2, Axis};
use std::collections::HashMap;
use std::path::Path;
use tokenizers::Tokenizer;
// Special-token literals; each is resolved to an id from `tokenizer.json` at load time and a
// missing one aborts loading.

/// SentencePiece word-boundary piece (U+2581); `DecoderTokens::resolve` maps it to the first
/// token of the decoder prompt.
const TOKEN_WORD_BOUNDARY: &str = "\u{2581}";
const TOKEN_STARTOFCONTEXT: &str = "<|startofcontext|>";
const TOKEN_STARTOFTRANSCRIPT: &str = "<|startoftranscript|>";
const TOKEN_EMO_UNDEFINED: &str = "<|emo:undefined|>";
/// Generation stops as soon as the decoder predicts this token.
const TOKEN_ENDOFTEXT: &str = "<|endoftext|>";
// Mutually exclusive prompt flags: punctuation on/off.
const TOKEN_PNC: &str = "<|pnc|>";
const TOKEN_NOPNC: &str = "<|nopnc|>";
const TOKEN_NOTIMESTAMP: &str = "<|notimestamp|>";
const TOKEN_NODIARIZE: &str = "<|nodiarize|>";
// Mutually exclusive prompt flags: inverse text normalization on/off.
const TOKEN_ITN: &str = "<|itn|>";
const TOKEN_NOITN: &str = "<|noitn|>";

/// Hard ceiling applied by `CohereASR::set_max_decode_tokens`.
const MAX_DECODE_TOKENS_LIMIT: usize = 1024;
/// Decode budget of a freshly loaded `CohereASR`.
const DEFAULT_MAX_DECODE_TOKENS: usize = 512;
/// Chunk length in seconds reported by `CohereASR::training_chunk_secs`.
/// NOTE(review): presumably the audio length the model was trained on — confirm.
const TRAINING_CHUNK_SECS: f32 = 35.0;

/// Language codes probed as `<|code|>` tokens when a model is loaded.
const SUPPORTED_LANGUAGES: &[&str] = &[
    "ar", "de", "el", "en", "es", "fr", "it",
    "ja", "ko", "nl", "pl", "pt", "vi", "zh",
];
/// Vocabulary ids of the special tokens used to build the decoder prompt and detect end of text.
///
/// All fields are plain `i64` ids, so the table is `Copy`; `Debug` makes it printable when
/// diagnosing prompt problems.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct DecoderTokens {
    /// Id of `TOKEN_WORD_BOUNDARY` ("\u{2581}"); first token of every prompt.
    decoder_start: i64,
    /// Id of `<|startofcontext|>`.
    startofcontext: i64,
    /// Id of `<|startoftranscript|>`.
    sot: i64,
    /// Id of `<|emo:undefined|>`.
    emo_undefined: i64,
    /// Id of `<|endoftext|>`; decoding stops when it is predicted.
    eos: i64,
    /// Id of `<|pnc|>` (punctuation requested).
    pnc: i64,
    /// Id of `<|nopnc|>` (punctuation not requested).
    nopnc: i64,
    /// Id of `<|notimestamp|>`.
    notimestamp: i64,
    /// Id of `<|nodiarize|>`.
    nodiarize: i64,
    /// Id of `<|itn|>` (inverse text normalization requested).
    itn: i64,
    /// Id of `<|noitn|>` (inverse text normalization not requested).
    noitn: i64,
}
impl DecoderTokens {
    /// Looks up the id of every prompt/control special token in `tokenizer`.
    ///
    /// # Errors
    ///
    /// Returns the `require_token` error for the first token, in field order, that is absent
    /// from the tokenizer's vocabulary.
    fn resolve(tokenizer: &Tokenizer) -> Result<Self> {
        let id_of = |literal: &str| require_token(tokenizer, literal);

        // Resolved one by one so the first missing token is the one reported.
        let decoder_start = id_of(TOKEN_WORD_BOUNDARY)?;
        let startofcontext = id_of(TOKEN_STARTOFCONTEXT)?;
        let sot = id_of(TOKEN_STARTOFTRANSCRIPT)?;
        let emo_undefined = id_of(TOKEN_EMO_UNDEFINED)?;
        let eos = id_of(TOKEN_ENDOFTEXT)?;
        let pnc = id_of(TOKEN_PNC)?;
        let nopnc = id_of(TOKEN_NOPNC)?;
        let notimestamp = id_of(TOKEN_NOTIMESTAMP)?;
        let nodiarize = id_of(TOKEN_NODIARIZE)?;
        let itn = id_of(TOKEN_ITN)?;
        let noitn = id_of(TOKEN_NOITN)?;

        Ok(Self {
            decoder_start,
            startofcontext,
            sot,
            emo_undefined,
            eos,
            pnc,
            nopnc,
            notimestamp,
            nodiarize,
            itn,
            noitn,
        })
    }
}
/// Returns the fixed feature-extraction settings used for every Cohere ASR model.
///
/// The values are hard-coded rather than loaded from the model directory.
/// NOTE(review): presumably they mirror the reference `CohereAsrFeatureExtractor` defaults —
/// confirm when adding support for new checkpoints. At 16 000 Hz, a `win_length` of 400 samples
/// and a `hop_length` of 160 samples are 25 ms windows advanced by 10 ms.
fn cohere_preprocessor_config() -> PreprocessorConfig {
    let sampling_rate = 16000;
    PreprocessorConfig {
        feature_extractor_type: String::from("CohereAsrFeatureExtractor"),
        processor_class: String::from("CohereAsrProcessor"),
        feature_size: N_MELS,
        sampling_rate,
        n_fft: 512,
        win_length: 400,
        hop_length: 160,
        preemphasis: 0.97,
        padding_side: String::from("right"),
        padding_value: 0.0,
        return_attention_mask: true,
    }
}
/// Speech-to-text pipeline around a Cohere ASR encoder/decoder model.
///
/// Bundles the model, its tokenizer, a fixed feature-extraction configuration, and the special
/// token ids needed to prompt the decoder. Build it with `CohereASR::from_pretrained`, then call
/// `CohereASR::transcribe_audio`.
pub struct CohereASR {
    // Model whose encoder runs once per call and whose decoder is stepped token by token.
    model: CohereModel,
    // Tokenizer loaded from `tokenizer.json`; turns generated ids back into text.
    tokenizer: Tokenizer,
    // Feature-extraction settings handed to `extract_features_raw`.
    preprocessor: PreprocessorConfig,
    // Language code (e.g. "en") -> id of its `<|code|>` token, only for codes the tokenizer has.
    lang_tokens: HashMap<String, i64>,
    // Ids of the fixed special tokens used in the decoder prompt and for end-of-text detection.
    tokens: DecoderTokens,
    // Maximum number of tokens generated per call; the setter keeps it in 1..=MAX_DECODE_TOKENS_LIMIT.
    max_decode_tokens: usize,
}
impl CohereASR {
pub fn from_pretrained<P: AsRef<Path>>(
model_dir: P,
exec_config: Option<ExecutionConfig>,
) -> Result<Self> {
let model_dir = model_dir.as_ref();
let exec = exec_config.unwrap_or_default();
let model = CohereModel::from_pretrained(model_dir, exec)?;
let tokenizer_path = model_dir.join("tokenizer.json");
if !tokenizer_path.exists() {
return Err(Error::Config(format!(
"Missing tokenizer.json in {}",
model_dir.display()
)));
}
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| Error::Tokenizer(format!("Failed to load tokenizer.json: {e}")))?;
let preprocessor = cohere_preprocessor_config();
let tokens = DecoderTokens::resolve(&tokenizer)?;
let mut lang_tokens = HashMap::with_capacity(SUPPORTED_LANGUAGES.len());
for code in SUPPORTED_LANGUAGES {
let lit = format!("<|{code}|>");
if let Some(id) = tokenizer.token_to_id(&lit) {
lang_tokens.insert((*code).to_string(), id as i64);
}
}
if lang_tokens.is_empty() {
return Err(Error::Tokenizer(
"No supported language tokens found in tokenizer.json".into(),
));
}
Ok(Self {
model,
tokenizer,
preprocessor,
lang_tokens,
tokens,
max_decode_tokens: DEFAULT_MAX_DECODE_TOKENS,
})
}
pub fn training_chunk_secs(&self) -> f32 {
TRAINING_CHUNK_SECS
}
pub fn max_decode_tokens(&self) -> usize {
self.max_decode_tokens
}
pub fn set_max_decode_tokens(&mut self, max: usize) {
self.max_decode_tokens = max.clamp(1, MAX_DECODE_TOKENS_LIMIT);
}
pub fn transcribe_audio(
&mut self,
audio: &[f32],
language: &str,
punctuation: bool,
itn: bool,
) -> Result<String> {
if audio.is_empty() {
return Ok(String::new());
}
let lang_token = self.lang_tokens.get(language).copied().ok_or_else(|| {
Error::Config(format!(
"Unsupported language '{}'. Supported: {:?}",
language,
self.supported_languages()
))
})?;
let mel_2d = extract_features_raw(
audio.to_vec(),
self.preprocessor.sampling_rate as u32,
1,
&self.preprocessor,
)?;
let mel_3d = mel_2d.insert_axis(Axis(0)).as_standard_layout().to_owned();
let encoder_out = self.model.run_encoder(&mel_3d)?;
let t = &self.tokens;
let pnc_token = if punctuation { t.pnc } else { t.nopnc };
let itn_token = if itn { t.itn } else { t.noitn };
let prompt = vec![
t.decoder_start,
t.startofcontext,
t.sot,
t.emo_undefined,
lang_token,
lang_token,
pnc_token,
itn_token,
t.notimestamp,
t.nodiarize,
];
let token_ids = self.decode_greedy(&prompt, &encoder_out)?;
let text = self
.tokenizer
.decode(
&token_ids.iter().map(|&i| i as u32).collect::<Vec<_>>(),
true,
)
.map_err(|e| Error::Tokenizer(format!("Failed to decode tokens: {e}")))?;
let cleaned = text
.trim()
.trim_start_matches(['.', '?', '!', ','])
.trim()
.to_string();
Ok(cleaned)
}
fn decode_greedy(
&mut self,
prompt: &[i64],
encoder_out: &CohereEncoderOutput,
) -> Result<Vec<i64>> {
let mut past_kv = CoherePastKv::empty();
let mut output_tokens: Vec<i64> = Vec::new();
let prompt_tensor = Array2::from_shape_vec((1, prompt.len()), prompt.to_vec())
.map_err(|e| Error::Model(format!("Prompt tensor shape error: {e}")))?;
let (logits, new_past) =
self.model
.run_decoder_step(&prompt_tensor, &past_kv, encoder_out)?;
past_kv = new_past;
let mut next_token = argmax(logits.as_slice().unwrap());
if next_token == self.tokens.eos {
return Ok(output_tokens);
}
output_tokens.push(next_token);
for _ in 1..self.max_decode_tokens {
let token_tensor = Array2::from_shape_vec((1, 1), vec![next_token])
.map_err(|e| Error::Model(format!("Token tensor shape error: {e}")))?;
let (logits, new_past) =
self.model
.run_decoder_step(&token_tensor, &past_kv, encoder_out)?;
past_kv = new_past;
next_token = argmax(logits.as_slice().unwrap());
if next_token == self.tokens.eos {
break;
}
output_tokens.push(next_token);
if let Some(repeat_len) = find_ngram_repetition(&output_tokens, 8) {
output_tokens.truncate(output_tokens.len() - repeat_len);
break;
}
}
Ok(output_tokens)
}
pub fn supported_languages(&self) -> Vec<String> {
let mut langs: Vec<String> = self.lang_tokens.keys().cloned().collect();
langs.sort();
langs
}
}
fn require_token(tokenizer: &Tokenizer, literal: &str) -> Result<i64> {
tokenizer
.token_to_id(literal)
.map(|id| id as i64)
.ok_or_else(|| Error::Tokenizer(format!("Tokenizer is missing required token {literal}")))
}
/// Detects a block of tokens repeated back-to-back at the end of `tokens`.
///
/// Searches for the shortest `len` with `min_len <= len <= tokens.len() / 2` such that the last
/// `len` tokens equal the `len` tokens immediately before them, and returns that `len`.
/// A `min_len` of 0 is treated as 1, because an empty "repetition" is meaningless and would
/// otherwise be reported for every input.
///
/// Returns `None` when no such repetition exists.
fn find_ngram_repetition(tokens: &[i64], min_len: usize) -> Option<usize> {
    let min_len = min_len.max(1);
    let n = tokens.len();
    // Equivalent to `n < 2 * min_len`, but cannot overflow for very large `min_len`.
    if n / 2 < min_len {
        return None;
    }
    // `2 * len <= n` holds for every candidate, so both slices are in bounds.
    (min_len..=n / 2).find(|&len| tokens[n - len..] == tokens[n - 2 * len..n - len])
}
/// Returns the index of the largest logit, used for greedy token selection.
///
/// NaN entries are ignored. With `partial_cmp(..).unwrap_or(Equal)` a single NaN used to reset
/// the running maximum and select the wrong token. Ties resolve to the first index, matching
/// `numpy.argmax`/`torch.argmax`. Returns 0 for an empty or all-NaN slice.
fn argmax(logits: &[f32]) -> i64 {
    logits
        .iter()
        .enumerate()
        .filter(|(_, value)| !value.is_nan())
        .fold(None::<(usize, f32)>, |best, (idx, &value)| match best {
            // Only a strictly greater value replaces the current best, so the first max wins.
            Some((_, best_value)) if value <= best_value => best,
            _ => Some((idx, value)),
        })
        .map_or(0, |(idx, _)| idx as i64)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_argmax() {
        let cases: [(&[f32], i64); 2] = [(&[0.1, 0.5, 0.3, 0.9, 0.2], 3), (&[1.0, 0.0, 0.0], 0)];
        for &(logits, expected) in cases.iter() {
            assert_eq!(argmax(logits), expected, "logits = {logits:?}");
        }
    }

    #[test]
    fn test_supported_languages_count() {
        assert_eq!(SUPPORTED_LANGUAGES.len(), 14);
    }

    #[test]
    fn test_ngram_repetition_detection() {
        // 10, 20, ..., 80 written out twice in a row.
        let doubled: Vec<i64> = (1..=8).map(|k| k * 10).cycle().take(16).collect();
        let cases: [(&[i64], usize, Option<usize>); 4] = [
            (&[1, 2, 3, 4, 5, 6, 7, 8], 4, None),
            (&[1, 2, 3, 4, 1, 2, 3, 4], 4, Some(4)),
            (doubled.as_slice(), 8, Some(8)),
            (&[1, 2, 1, 2], 4, None),
        ];
        for &(tokens, min_len, expected) in cases.iter() {
            assert_eq!(
                find_ngram_repetition(tokens, min_len),
                expected,
                "tokens = {tokens:?}, min_len = {min_len}"
            );
        }
    }
}