use anyhow::{Context, Result};
use burn::tensor::{Tensor, backend::Backend};
use tokenizers::Tokenizer;
use crate::kv_cache::{KvCache, forward_decoder_cached};
use crate::model::Whisper;
/// Language codes considered during automatic language detection.
///
/// Each code `xx` is looked up in the tokenizer as the special token `<|xx|>`
/// (see `language_token_id`); codes the tokenizer does not know are skipped by
/// `detect_language`. Order matters: `detect_language` keeps the earliest code
/// in this list when two language logits are equal.
///
/// NOTE(review): presumably mirrors the order of openai/whisper's `LANGUAGES`
/// table (100 entries here, ending with `yue`) — confirm against the tokenizer
/// in use before reordering or extending.
pub const LANGUAGE_CODES: &[&str] = &[
"en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr", "pl", "ca", "nl", "ar", "sv", "it",
"id", "hi", "fi", "vi", "he", "uk", "el", "ms", "cs", "ro", "da", "hu", "ta", "no", "th", "ur",
"hr", "bg", "lt", "la", "mi", "ml", "cy", "sk", "te", "fa", "lv", "bn", "sr", "az", "sl", "kn",
"et", "mk", "br", "eu", "is", "hy", "ne", "mn", "bs", "kk", "sq", "sw", "gl", "mr", "pa", "si",
"km", "sn", "yo", "so", "af", "oc", "ka", "be", "tg", "sd", "gu", "am", "yi", "lo", "uz", "fo",
"ht", "ps", "tk", "nn", "mt", "sa", "lb", "my", "bo", "tl", "mg", "as", "tt", "haw", "ln",
"ha", "ba", "jw", "su", "yue",
];
/// The decoding task requested from the model.
///
/// Each variant corresponds to one special task token (see `task_token_id`).
/// The default is [`Task::Transcribe`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Task {
    /// Selects the `<|transcribe|>` task token.
    #[default]
    Transcribe,
    /// Selects the `<|translate|>` task token.
    Translate,
}
/// Looks up the token id of the language special token `<|code|>`.
///
/// Returns `None` when the tokenizer has no such token (for example, an
/// unknown code, or a tokenizer without language tokens).
pub fn language_token_id(tokenizer: &Tokenizer, code: &str) -> Option<u32> {
    let special = ["<|", code, "|>"].concat();
    tokenizer.token_to_id(&special)
}
/// Looks up the token id of the special token selecting `task`.
///
/// [`Task::Transcribe`] maps to `<|transcribe|>` and [`Task::Translate`] to
/// `<|translate|>`. Returns `None` if the tokenizer lacks that token.
pub fn task_token_id(tokenizer: &Tokenizer, task: Task) -> Option<u32> {
    tokenizer.token_to_id(match task {
        Task::Translate => "<|translate|>",
        Task::Transcribe => "<|transcribe|>",
    })
}
pub fn detect_language<B: Backend>(
model: &Whisper<B>,
encoder_out: Tensor<B, 3>,
tokenizer: &Tokenizer,
sot_token: u32,
device: &B::Device,
) -> Result<(String, u32)> {
let mut cache = KvCache::new(model, encoder_out);
let logits = forward_decoder_cached(model, sot_token, &mut cache, device)
.context("language-detection forward pass")?;
let mut best: Option<(f32, u32, &str)> = None;
for &code in LANGUAGE_CODES {
let Some(id) = language_token_id(tokenizer, code) else {
continue;
};
let Some(&logit) = logits.get(id as usize) else {
continue;
};
if best.is_none_or(|(b, _, _)| logit > b) {
best = Some((logit, id, code));
}
}
let (_, id, code) = best.context(
"no language tokens found in tokenizer — language auto-detection requires a \
multilingual model (English-only .en models cannot detect language)",
)?;
Ok((code.to_string(), id))
}