use super::super::types::{ExecutorResult, RawOutputs};
use crate::runtime_adapter::AdapterError;
use ndarray::IxDyn;
/// Greedy CTC decoding of a logits tensor into text.
///
/// Takes the first tensor yielded by the tensor map, requires it to be 3-D
/// (`[batch, time, vocab]`), and for every time step of batch entry 0 picks the
/// vocabulary index with the highest score. Consecutive repeats are collapsed and
/// occurrences of `blank_index` are dropped; the surviving IDs are turned into
/// text with the vocabulary at `vocab_path`.
///
/// Only batch entry 0 is decoded; any further batch entries are ignored.
///
/// # Errors
/// Returns `AdapterError::InvalidInput` when `data` is not a tensor map, the map
/// is empty, the tensor is not 3-D, the batch or vocab dimension is empty, or the
/// vocabulary cannot be loaded.
pub fn ctc_decode_step(
    data: RawOutputs,
    vocab_path: &str,
    blank_index: usize,
) -> ExecutorResult<RawOutputs> {
    let tensor_map = match data {
        RawOutputs::TensorMap(map) => map,
        _ => {
            return Err(AdapterError::InvalidInput(
                "CTCDecode requires tensor map".to_string(),
            ))
        }
    };

    // NOTE(review): this takes whichever tensor the map yields first; if the map
    // is unordered and holds several outputs, the choice is unspecified — confirm.
    let logits = tensor_map
        .values()
        .next()
        .ok_or_else(|| AdapterError::InvalidInput("No outputs for CTCDecode".to_string()))?;

    let shape = logits.shape();
    if shape.len() != 3 {
        return Err(AdapterError::InvalidInput(format!(
            "CTCDecode expects 3D tensor [batch, time, vocab], got {:?}",
            shape
        )));
    }
    let (batch_size, time_steps, vocab_size) = (shape[0], shape[1], shape[2]);

    // Indexing batch entry 0 would panic on an empty batch, and an empty vocab
    // axis would make the argmax silently fall back to index 0. Reject both.
    if batch_size == 0 || vocab_size == 0 {
        return Err(AdapterError::InvalidInput(format!(
            "CTCDecode requires non-empty batch and vocab dimensions, got {:?}",
            shape
        )));
    }

    let mut token_ids = Vec::new();
    let mut prev_id: Option<usize> = None;
    for t in 0..time_steps {
        // Argmax over the vocab axis. The strict `>` keeps the first maximum on
        // ties and never selects NaN; if nothing beats -inf, index 0 is used.
        let (best_idx, _) = (0..vocab_size)
            .map(|v| (v, logits[IxDyn(&[0, t, v])]))
            .fold((0usize, f32::NEG_INFINITY), |best, candidate| {
                if candidate.1 > best.1 {
                    candidate
                } else {
                    best
                }
            });

        // CTC collapse: emit a token only when it is not the blank and differs
        // from the previous frame's prediction.
        if best_idx != blank_index && Some(best_idx) != prev_id {
            token_ids.push(best_idx);
        }
        prev_id = Some(best_idx);
    }

    let text = decode_ctc_tokens(&token_ids, vocab_path)?;
    Ok(RawOutputs::Text(text))
}
/// Decodes BPE token IDs into text.
///
/// Expects `data` to be `RawOutputs::TokenIds`; the IDs are decoded with the
/// vocabulary file at `vocab_path` and returned as `RawOutputs::Text`.
///
/// # Errors
/// Returns `AdapterError::InvalidInput` when `data` is any other variant, or
/// forwards the error produced while decoding.
pub fn bpe_decode_step(data: RawOutputs, vocab_path: &str) -> ExecutorResult<RawOutputs> {
    if let RawOutputs::TokenIds(ids) = data {
        decode_bpe_tokens(&ids, vocab_path).map(RawOutputs::Text)
    } else {
        Err(AdapterError::InvalidInput(
            "BPEDecode requires token IDs".to_string(),
        ))
    }
}
/// Decodes Whisper token IDs into text.
///
/// Expects `data` to be `RawOutputs::TokenIds`; the IDs are decoded with the
/// tokenizer file at `tokenizer_path` and returned as `RawOutputs::Text`.
///
/// # Errors
/// Returns `AdapterError::InvalidInput` when `data` is any other variant, or
/// forwards the error produced while decoding.
pub fn whisper_decode_step(data: RawOutputs, tokenizer_path: &str) -> ExecutorResult<RawOutputs> {
    if let RawOutputs::TokenIds(ids) = data {
        decode_whisper_tokens(&ids, tokenizer_path).map(RawOutputs::Text)
    } else {
        Err(AdapterError::InvalidInput(
            "WhisperDecode requires token IDs".to_string(),
        ))
    }
}
/// Converts CTC token IDs into text using the vocabulary stored at `vocab_path`.
///
/// Two vocabulary layouts are accepted:
/// * a JSON object mapping token string to ID (detected by a leading `{`);
/// * a plain-text file with one token per line, where the zero-based line
///   number is the token ID.
///
/// Decoding rules:
/// * the token `|` is rendered as a space;
/// * tokens wrapped in angle brackets (e.g. `<pad>`) are skipped;
/// * IDs that are not present in the vocabulary are ignored;
/// * runs of whitespace are collapsed to a single space and the result is
///   trimmed.
///
/// # Errors
/// Returns `AdapterError::InvalidInput` when the file cannot be read or the JSON
/// vocabulary cannot be parsed.
fn decode_ctc_tokens(token_ids: &[usize], vocab_path: &str) -> ExecutorResult<String> {
    let content = std::fs::read_to_string(vocab_path)
        .map_err(|e| AdapterError::InvalidInput(format!("Failed to read vocab file: {}", e)))?;

    let vocab: Vec<String> = if content.trim().starts_with('{') {
        let json_vocab = serde_json::from_str::<std::collections::HashMap<String, usize>>(&content)
            .map_err(|e| {
                AdapterError::InvalidInput(format!("Failed to parse vocab JSON: {}", e))
            })?;
        // Invert token -> id into a dense id -> token table; ids that no token
        // maps to stay as empty strings and therefore decode to nothing.
        let max_id = json_vocab.values().max().copied().unwrap_or(0);
        let mut id_to_token = vec![String::new(); max_id + 1];
        for (token, id) in json_vocab {
            if let Some(slot) = id_to_token.get_mut(id) {
                *slot = token;
            }
        }
        id_to_token
    } else {
        content
            .lines()
            .map(|line| line.trim().to_string())
            .collect()
    };

    let mut text = String::new();
    for token in token_ids.iter().filter_map(|&id| vocab.get(id)) {
        if token == "|" {
            text.push(' ');
        } else if !(token.starts_with('<') && token.ends_with('>')) {
            // Only tokens enclosed in BOTH brackets count as special; a literal
            // "<" or ">" in the vocabulary is ordinary text and must be kept.
            text.push_str(token);
        }
    }

    Ok(text.split_whitespace().collect::<Vec<_>>().join(" "))
}
/// Decodes BPE token IDs into text using a line-oriented vocabulary file.
///
/// Each token ID selects the line with that zero-based index. Lines of the form
/// `<|...|>` are skipped. For every other line, the part before the first space
/// (or the whole line when there is none) is base64-decoded and its bytes are
/// appended to the output. IDs past the end of the file and entries that are not
/// valid base64 are silently ignored. The collected bytes are converted to a
/// string lossily, so invalid UTF-8 sequences become U+FFFD.
///
/// # Errors
/// Returns `AdapterError::InvalidInput` when the vocabulary file cannot be read.
fn decode_bpe_tokens(token_ids: &[usize], vocab_path: &str) -> ExecutorResult<String> {
    use base64::{engine::general_purpose, Engine as _};

    let raw = std::fs::read_to_string(vocab_path)
        .map_err(|e| AdapterError::InvalidInput(format!("Failed to read vocab file: {}", e)))?;
    let entries: Vec<&str> = raw.lines().map(str::trim).collect();

    let mut bytes = Vec::new();
    for entry in token_ids.iter().filter_map(|&id| entries.get(id).copied()) {
        let is_special = entry.starts_with("<|") && entry.ends_with("|>");
        if is_special {
            continue;
        }
        // NOTE(review): presumably the text after the space is a rank/ID column;
        // only the leading base64 field is used here.
        let encoded = entry.split(' ').next().unwrap_or(entry);
        if let Ok(chunk) = general_purpose::STANDARD.decode(encoded) {
            bytes.extend_from_slice(&chunk);
        }
    }

    Ok(String::from_utf8_lossy(&bytes).into_owned())
}
/// Decodes Whisper token IDs into text with the tokenizer file at
/// `tokenizer_path`.
///
/// IDs at or above 50257 are dropped before decoding, the tokenizer is asked to
/// skip special tokens, and the resulting text is trimmed.
///
/// # Errors
/// Returns `AdapterError::InvalidInput` when the tokenizer cannot be loaded or
/// decoding fails.
fn decode_whisper_tokens(token_ids: &[usize], tokenizer_path: &str) -> ExecutorResult<String> {
    use tokenizers::Tokenizer;

    // IDs at or above this bound are discarded.
    // NOTE(review): presumably the start of Whisper's special/timestamp token
    // range — confirm against the tokenizer actually shipped with the model.
    const TOKEN_ID_LIMIT: usize = 50257;

    let tokenizer = Tokenizer::from_file(tokenizer_path)
        .map_err(|e| AdapterError::InvalidInput(format!("Failed to load tokenizer: {}", e)))?;

    // Filter BEFORE narrowing to u32: casting first would wrap oversized ids
    // (e.g. 2^32 + 5 -> 5) and let them slip through as valid tokens.
    let filtered_ids: Vec<u32> = token_ids
        .iter()
        .filter(|&&id| id < TOKEN_ID_LIMIT)
        .map(|&id| id as u32) // lossless: id < 50257 fits in u32
        .collect();

    let text = tokenizer
        .decode(&filtered_ids, true)
        .map_err(|e| AdapterError::InvalidInput(format!("Failed to decode tokens: {}", e)))?;
    Ok(text.trim().to_string())
}