use super::super::types::ExecutorResult;
use super::{parse_kv_cache_name, parse_present_name_full};
use crate::execution::template::PipelineStage;
use crate::runtime_adapter::onnx::ONNXSession;
use crate::runtime_adapter::AdapterError;
use ndarray::{Array2, ArrayD, IxDyn};
use ort::value::Value;
use std::collections::HashMap;
/// Runs greedy autoregressive decoding for a Whisper decoder ONNX session.
///
/// The decoder is first fed four forced prompt tokens (`start_token_id`,
/// `language_token_id`, `task_token_id`, `no_timestamps_token_id`), one per step.
/// From the last forced token onwards the highest-scoring token is appended to the
/// result and fed back in as the next input. Decoding stops when `max_tokens`
/// steps have run (forced steps included), when `end_token_id` is selected, or
/// when the selected token equals each of the last five tokens already produced.
///
/// # Arguments
///
/// * `_stage` - Unused.
/// * `stage_outputs` - Outputs of earlier stages; must contain an `"encoder"` entry.
///   Its `"last_hidden_state"` tensor is used, or any tensor of that entry when
///   that key is absent.
/// * `config` - Optional `num_layers`, `num_heads`, `head_dim` and
///   `encoder_seq_len` overrides used to size the initial KV caches
///   (defaults: 4, 6, 64 and 1500).
/// * `max_tokens` - Maximum number of decoder steps, forced steps included.
/// * `suppress_tokens` - Token ids whose score is forced to negative infinity.
/// * `repetition_penalty` - Penalty applied to the scores of the ten most recent
///   tokens; `1.0` disables it.
/// * `session` - Decoder session that is run once per step.
///
/// # Returns
///
/// The four forced tokens followed by every generated token. The end token is
/// not included. Token ids are converted with `as usize`, so negative ids wrap.
///
/// # Errors
///
/// Returns `AdapterError::InvalidInput` when the encoder output is missing or has
/// no batch dimension, when a tensor cannot be converted into a session input,
/// or when the logits are missing or not contiguous, and
/// `AdapterError::InferenceFailed` when the session run fails.
// The argument list mirrors the decoding settings passed in by the caller, so the
// lint is silenced rather than changing the public signature.
#[allow(clippy::too_many_arguments)]
pub fn execute_whisper_decoder_stage(
    _stage: &PipelineStage,
    stage_outputs: &HashMap<String, HashMap<String, ArrayD<f32>>>,
    config: &HashMap<String, serde_json::Value>,
    max_tokens: usize,
    start_token_id: i64,
    end_token_id: i64,
    language_token_id: i64,
    task_token_id: i64,
    no_timestamps_token_id: i64,
    suppress_tokens: &[i64],
    repetition_penalty: f32,
    session: &ONNXSession,
) -> ExecutorResult<Vec<usize>> {
    let num_layers = whisper_config_usize(config, "num_layers", 4);
    let num_heads = whisper_config_usize(config, "num_heads", 6);
    let head_dim = whisper_config_usize(config, "head_dim", 64);
    let encoder_seq_len = whisper_config_usize(config, "encoder_seq_len", 1500);

    let encoder_outputs = stage_outputs
        .get("encoder")
        .ok_or_else(|| AdapterError::InvalidInput("No encoder outputs found".to_string()))?;
    let encoder_hidden_states = encoder_outputs
        .get("last_hidden_state")
        .or_else(|| encoder_outputs.values().next())
        .ok_or_else(|| AdapterError::InvalidInput("No encoder hidden states".to_string()))?;
    // A zero-dimensional tensor has no leading axis; report it instead of panicking.
    let batch_size = encoder_hidden_states
        .shape()
        .first()
        .copied()
        .ok_or_else(|| {
            AdapterError::InvalidInput(
                "Encoder hidden states have no batch dimension".to_string(),
            )
        })?;

    // The session's input names do not change between steps, so resolve every
    // name-based decision once, before the decoding loop.
    let input_names = session.input_names();
    let ids_input = input_names
        .iter()
        .find(|n| n.contains("input_ids"))
        .or_else(|| input_names.iter().next());
    let encoder_input = input_names
        .iter()
        .find(|n| n.contains("encoder_hidden_states"));
    // (input name, is encoder cache, cache slot) for every recognised KV input.
    let mut kv_inputs = Vec::new();
    for name in input_names.iter() {
        if !name.contains("past_key_values") {
            continue;
        }
        if let Some((layer, is_encoder, is_key)) = parse_kv_cache_name(name) {
            // Each layer owns two consecutive slots: key first, then value.
            let kv_idx = layer * 2 + if is_key { 0 } else { 1 };
            kv_inputs.push((name, is_encoder, kv_idx));
        }
    }

    let forced_tokens = [
        start_token_id,
        language_token_id,
        task_token_id,
        no_timestamps_token_id,
    ];
    let num_forced = forced_tokens.len();

    // The decoder self-attention cache starts with a zero-length sequence axis.
    let mut decoder_kv_cache: Vec<ArrayD<f32>> = (0..(num_layers * 2))
        .map(|_| ArrayD::<f32>::zeros(IxDyn(&[batch_size, num_heads, 0, head_dim])))
        .collect();
    // The encoder cache starts zero-filled and is replaced by the `present.*`
    // encoder outputs of the first step.
    let mut encoder_kv_cache: Vec<ArrayD<f32>> = (0..(num_layers * 2))
        .map(|_| {
            ArrayD::<f32>::zeros(IxDyn(&[batch_size, num_heads, encoder_seq_len, head_dim]))
        })
        .collect();

    let mut generated_tokens: Vec<usize> = forced_tokens.iter().map(|&t| t as usize).collect();
    // Most recently generated token; only read once the forced prompt is exhausted.
    let mut previous_token = start_token_id;

    for step in 0..max_tokens {
        let current_token = forced_tokens.get(step).copied().unwrap_or(previous_token);

        let mut inputs: HashMap<String, Value> = HashMap::new();
        if let Some(name) = ids_input {
            // Every batch row is fed the same token.
            let input_ids =
                Array2::<i64>::from_shape_vec((batch_size, 1), vec![current_token; batch_size])
                    .map_err(|e| {
                        AdapterError::InvalidInput(format!("Failed to create input_ids: {}", e))
                    })?;
            let input_ids_value: Value = Value::from_array(input_ids)
                .map_err(|e| {
                    AdapterError::InvalidInput(format!("Failed to convert input_ids: {}", e))
                })?
                .into();
            inputs.insert(name.clone(), input_ids_value);
        }
        for &(name, is_encoder, kv_idx) in &kv_inputs {
            let (cache, what) = if is_encoder {
                (&encoder_kv_cache, "encoder KV")
            } else {
                (&decoder_kv_cache, "decoder KV")
            };
            // Slots beyond the configured layer count are left unset.
            if let Some(kv) = cache.get(kv_idx) {
                inputs.insert(name.clone(), whisper_f32_value(kv, what)?);
            }
        }
        if let Some(name) = encoder_input {
            inputs.insert(
                name.clone(),
                whisper_f32_value(encoder_hidden_states, "encoder hidden states")?,
            );
        }

        let outputs = session
            .run_with_values(inputs)
            .map_err(|e| AdapterError::InferenceFailed(format!("Whisper decoder failed: {}", e)))?;
        let logits = outputs
            .get("logits")
            .or_else(|| outputs.values().next())
            .ok_or_else(|| AdapterError::InvalidInput("No logits output".to_string()))?;

        for (name, tensor) in &outputs {
            if !name.starts_with("present.") {
                continue;
            }
            if let Some((layer, is_encoder, is_key)) = parse_present_name_full(name) {
                let kv_idx = layer * 2 + if is_key { 0 } else { 1 };
                let slot = if is_encoder {
                    // The encoder cache is captured on the first step only.
                    if step == 0 {
                        encoder_kv_cache.get_mut(kv_idx)
                    } else {
                        None
                    }
                } else {
                    decoder_kv_cache.get_mut(kv_idx)
                };
                if let Some(slot) = slot {
                    *slot = tensor.clone();
                }
            }
        }

        // While the prompt is still being fed only the caches matter; the first
        // prediction is taken after the last forced token.
        if step + 1 < num_forced {
            continue;
        }

        let flat = logits
            .as_slice()
            .ok_or_else(|| AdapterError::InvalidInput("Logits not contiguous".to_string()))?;
        // `as_slice` only succeeds for standard (row-major) layout, so the scores
        // of batch row 0 at its last position are the final vocabulary-sized chunk
        // of that row's data. Searching the whole flattened tensor could return an
        // index outside the vocabulary when there is more than one row or position.
        let logits_shape = logits.shape();
        let vocab_size = logits_shape.last().copied().unwrap_or(flat.len());
        let per_batch = match logits_shape.first() {
            Some(&rows) if logits_shape.len() >= 2 && rows > 0 => flat.len() / rows,
            _ => flat.len(),
        };
        let row_start = per_batch.saturating_sub(vocab_size);
        let mut scores = flat
            .get(row_start..per_batch)
            .map_or_else(Vec::new, |row| row.to_vec());

        let next_token = whisper_pick_token(
            &mut scores,
            suppress_tokens,
            &generated_tokens,
            repetition_penalty,
        )
        .unwrap_or(end_token_id as usize);
        if next_token == end_token_id as usize {
            break;
        }
        // Stop when the model is stuck emitting the same token over and over.
        if generated_tokens.len() >= 5
            && generated_tokens
                .iter()
                .rev()
                .take(5)
                .all(|&id| id == next_token)
        {
            break;
        }
        generated_tokens.push(next_token);
        previous_token = next_token as i64;
    }
    Ok(generated_tokens)
}

/// Reads an unsigned integer setting from `config`, returning `default` when the
/// key is missing or its value is not an unsigned integer.
fn whisper_config_usize(
    config: &HashMap<String, serde_json::Value>,
    key: &str,
    default: usize,
) -> usize {
    config
        .get(key)
        .and_then(|v| v.as_u64())
        .map_or(default, |v| v as usize)
}

/// Copies `array` into a session input value; `what` names the tensor in the
/// error message produced when the conversion fails.
fn whisper_f32_value(array: &ArrayD<f32>, what: &str) -> ExecutorResult<Value> {
    let value: Value = Value::from_array(array.clone())
        .map_err(|e| AdapterError::InvalidInput(format!("Failed to convert {}: {}", what, e)))?
        .into();
    Ok(value)
}

/// Applies token suppression and the repetition penalty to `scores` in place and
/// returns the index of the highest score, or `None` when `scores` is empty.
fn whisper_pick_token(
    scores: &mut [f32],
    suppress_tokens: &[i64],
    history: &[usize],
    repetition_penalty: f32,
) -> Option<usize> {
    for &token in suppress_tokens {
        // Negative or out-of-range ids cannot address a score and are skipped.
        if let Ok(idx) = usize::try_from(token) {
            if let Some(score) = scores.get_mut(idx) {
                *score = f32::NEG_INFINITY;
            }
        }
    }
    if repetition_penalty != 1.0 && history.len() > 4 {
        // De-duplicate so each recent token is penalised exactly once.
        let recent: std::collections::HashSet<usize> =
            history.iter().rev().take(10).copied().collect();
        for token in recent {
            if let Some(score) = scores.get_mut(token) {
                // Dividing positive scores and multiplying the rest applies the
                // same factor to both, whatever the sign of the score.
                *score = if *score > 0.0 {
                    *score / repetition_penalty
                } else {
                    *score * repetition_penalty
                };
            }
        }
    }
    scores
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(idx, _)| idx)
}