use anyhow::{Context, Result};
use ort::session::Session;
use ort::value::TensorRef;
use super::{CONTEXT_SIZE, DECODER_OUT_DIM, DecoderState, ENCODER_OUT_DIM};
// Upper bound on non-blank tokens emitted for a single encoder frame;
// guards against a runaway inner symbol loop in greedy decoding.
const MAX_TOKENS_PER_STEP: usize = 10;
// Number of consecutive blank steps after which an endpoint is flagged
// (only once at least one token has been emitted — see `greedy_decode`).
pub(crate) const ENDPOINT_BLANK_THRESHOLD: usize = 15;
/// A single decoded token plus where and how confidently it was emitted.
#[derive(Debug, Clone)]
pub(crate) struct TokenInfo {
    /// Vocabulary index of the emitted token.
    pub token_id: usize,
    /// Encoder frame (within the decoded chunk) at which it was emitted.
    pub frame_index: usize,
    /// Softmax probability assigned to this token at emission time.
    pub confidence: f32,
}
/// Outcome of one `greedy_decode` call over a chunk of encoder frames.
#[derive(Debug)]
pub(crate) struct DecodeResult {
    /// Tokens emitted for this chunk, in emission order.
    pub tokens: Vec<TokenInfo>,
    /// True when `ENDPOINT_BLANK_THRESHOLD` consecutive blank steps were
    /// observed after at least one token had been emitted.
    pub endpoint_detected: bool,
}
/// Copies frame `t` of a flattened `[encoded_len, dim]` encoder output
/// into `enc_frame`, where `dim == enc_frame.len()`.
///
/// # Panics
/// Panics if `t >= encoded_len`, or if `encoded` is too short to contain
/// frame `t` at the requested dimensionality.
pub(crate) fn extract_encoder_frame(
    encoded: &[f32],
    encoded_len: usize,
    t: usize,
    enc_frame: &mut [f32],
) {
    let dim = enc_frame.len();
    assert!(
        t < encoded_len,
        "frame index {t} out of range {encoded_len}"
    );
    let start = t * dim;
    let end = start + dim;
    // Explicit check: a short buffer would otherwise panic inside
    // `copy_from_slice` with an opaque slice-index message.
    assert!(
        end <= encoded.len(),
        "encoder buffer too short: need {end} floats, have {}",
        encoded.len()
    );
    enc_frame.copy_from_slice(&encoded[start..end]);
}
/// Index of the largest logit, falling back to `blank_id` when `logits`
/// is empty. Ties resolve to the LAST maximal index (matching
/// `Iterator::max_by` semantics); NaNs order via `total_cmp`.
pub(crate) fn argmax(logits: &[f32], blank_id: usize) -> usize {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &val) in logits.iter().enumerate() {
        // Update on >= so equal maxima keep the later index.
        let replace = match best {
            Some((_, top)) => val.total_cmp(&top).is_ge(),
            None => true,
        };
        if replace {
            best = Some((idx, val));
        }
    }
    best.map_or(blank_id, |(idx, _)| idx)
}
/// Greedy token choice plus its softmax probability.
///
/// Returns `(blank_id, 0.0)` for empty logits. The confidence is the
/// softmax mass on the winning logit, computed with the max-subtraction
/// trick for numerical stability. Ties resolve to the last maximal
/// index, matching `argmax`.
pub(crate) fn argmax_with_confidence(logits: &[f32], blank_id: usize) -> (usize, f32) {
    if logits.is_empty() {
        return (blank_id, 0.0);
    }
    // Single pass: track the winning index and the running maximum.
    let mut token = 0usize;
    let mut max_logit = f32::NEG_INFINITY;
    for (idx, &l) in logits.iter().enumerate() {
        if l.total_cmp(&max_logit).is_ge() {
            token = idx;
            max_logit = l;
        }
    }
    let sum_exp: f32 = logits.iter().map(|&l| (l - max_logit).exp()).sum();
    let confidence = (logits[token] - max_logit).exp() / sum_exp;
    (token, confidence)
}
/// Runs the prediction network on the current token context, copying the
/// flattened output into `out` (the buffer's allocation is reused).
fn run_decoder(decoder: &mut Session, context: &[i64], out: &mut Vec<f32>) -> Result<()> {
    debug_assert_eq!(context.len(), CONTEXT_SIZE);
    let context_tensor = TensorRef::from_array_view(([1_usize, CONTEXT_SIZE], context))?;
    let results = decoder
        .run(ort::inputs![context_tensor])
        .context("Decoder inference failed")?;
    let (_shape, values) = results[0]
        .try_extract_tensor::<f32>()
        .context("Failed to extract decoder output")?;
    out.clear();
    out.extend_from_slice(values);
    Ok(())
}
/// Runs the joiner on one encoder frame and one decoder output, copying
/// the resulting logits into `out` (the buffer's allocation is reused).
fn run_joiner_single(
    joiner: &mut Session,
    enc_frame: &[f32],
    dec_data: &[f32],
    out: &mut Vec<f32>,
) -> Result<()> {
    let enc_input = TensorRef::from_array_view(([1_usize, ENCODER_OUT_DIM], enc_frame))?;
    let dec_input = TensorRef::from_array_view(([1_usize, DECODER_OUT_DIM], dec_data))?;
    let results = joiner
        .run(ort::inputs![enc_input, dec_input])
        .context("Joiner inference failed")?;
    let (_shape, values) = results[0]
        .try_extract_tensor::<f32>()
        .context("Failed to extract joiner output")?;
    out.clear();
    out.extend_from_slice(values);
    Ok(())
}
pub fn greedy_decode(
decoder: &mut Session,
joiner: &mut Session,
encoded: &[f32], encoded_len: usize,
blank_id: usize,
vocab_size: usize,
state: &mut DecoderState,
) -> Result<DecodeResult> {
anyhow::ensure!(
encoded.len() >= ENCODER_OUT_DIM * encoded_len,
"Encoder output size mismatch: got {}, expected >= {}",
encoded.len(),
ENCODER_OUT_DIM * encoded_len
);
let mut tokens = Vec::new();
let mut endpoint_detected = false;
let mut enc_frame = vec![0.0_f32; ENCODER_OUT_DIM];
let mut joiner_buf = Vec::with_capacity(vocab_size);
let mut decoder_buf_a = Vec::with_capacity(DECODER_OUT_DIM);
let mut decoder_buf_b = Vec::with_capacity(DECODER_OUT_DIM);
let mut decoder_calls: u32 = 0;
let mut joiner_calls: u32 = 0;
run_decoder(decoder, &state.tokens, &mut decoder_buf_a)?;
decoder_calls += 1;
for t in 0..encoded_len {
extract_encoder_frame(encoded, encoded_len, t, &mut enc_frame);
let mut tokens_this_step = 0;
loop {
joiner_calls += 1;
run_joiner_single(joiner, &enc_frame, &decoder_buf_a, &mut joiner_buf)?;
let (token, confidence) = argmax_with_confidence(&joiner_buf, blank_id);
if token == blank_id {
state.consecutive_blanks += 1;
if state.consecutive_blanks >= ENDPOINT_BLANK_THRESHOLD && !tokens.is_empty() {
endpoint_detected = true;
}
break;
}
if tokens_this_step >= MAX_TOKENS_PER_STEP {
state.consecutive_blanks += 1;
if state.consecutive_blanks >= ENDPOINT_BLANK_THRESHOLD && !tokens.is_empty() {
endpoint_detected = true;
}
break;
}
state.consecutive_blanks = 0;
state.push_token(token as i64);
run_decoder(decoder, &state.tokens, &mut decoder_buf_b)?;
decoder_calls += 1;
std::mem::swap(&mut decoder_buf_a, &mut decoder_buf_b);
tokens.push(TokenInfo {
token_id: token,
frame_index: t,
confidence,
});
tokens_this_step += 1;
}
}
tracing::debug!(
decoder_calls,
joiner_calls,
encoded_len,
"decode_loop_stats"
);
Ok(DecodeResult {
tokens,
endpoint_detected,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    // Flattened 3x2 "encoder output" shared by the frame-extraction tests.
    const ENCODED: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

    #[test]
    fn test_extract_encoder_frame_first() {
        let mut frame = [0.0; 2];
        extract_encoder_frame(&ENCODED, 3, 0, &mut frame);
        assert_eq!(frame, [1.0, 2.0]);
    }

    #[test]
    fn test_extract_encoder_frame_last() {
        let mut frame = [0.0; 2];
        extract_encoder_frame(&ENCODED, 3, 2, &mut frame);
        assert_eq!(frame, [5.0, 6.0]);
    }

    #[test]
    fn test_extract_encoder_frame_middle() {
        let mut frame = [0.0; 2];
        extract_encoder_frame(&ENCODED, 3, 1, &mut frame);
        assert_eq!(frame, [3.0, 4.0]);
    }

    #[test]
    fn test_argmax_clear_winner() {
        assert_eq!(argmax(&[0.1, 0.5, 0.9, 0.2], 999), 2);
    }

    #[test]
    fn test_argmax_tie_returns_last() {
        assert_eq!(argmax(&[1.0, 1.0, 0.5], 999), 1);
    }

    #[test]
    fn test_argmax_negative_values() {
        assert_eq!(argmax(&[-3.0, -1.0, -2.0], 999), 1);
    }

    #[test]
    fn test_argmax_empty_returns_blank() {
        let none: [f32; 0] = [];
        assert_eq!(argmax(&none, 1024), 1024);
    }

    #[test]
    fn test_argmax_blank_id_selected() {
        assert_eq!(argmax(&[0.1, 0.2, 0.9], 2), 2);
    }

    #[test]
    fn test_confidence_picks_top_softmax_value() {
        let (id, conf) = argmax_with_confidence(&[0.0, 0.0, 100.0], 0);
        assert_eq!(id, 2);
        assert!(conf > 0.999, "expected near-1 confidence, got {conf}");
    }
}