use anyhow::{Context, Result};
use ort::session::Session;
use ort::value::TensorRef;
use super::bias::Biaser;
use super::{DecoderState, PRED_HIDDEN};
/// Upper bound on the number of non-blank tokens the decode loop emits for a single encoder frame.
const MAX_TOKENS_PER_STEP: usize = 10;
/// Number of channels in one encoder frame; also the encoder dimension of the joiner input tensor.
const ENC_DIM: usize = 768;
/// Value of `DecoderState::consecutive_blanks` at which the decode loop reports an endpoint
/// (only once the current call has emitted at least one token).
pub(crate) const ENDPOINT_BLANK_THRESHOLD: usize = 15;
/// One non-blank token emitted by the greedy decode loop.
#[derive(Debug, Clone)]
pub(crate) struct TokenInfo {
/// Index of the winning logit, i.e. the emitted token id.
pub token_id: usize,
/// Encoder frame index at which the token was emitted.
pub frame_index: usize,
/// Softmax probability of the winning logit, computed after any bias boost was applied.
pub confidence: f32,
}
/// Outcome of one greedy decode call.
#[derive(Debug)]
pub(crate) struct DecodeResult {
/// Emitted tokens in the order they were produced.
pub tokens: Vec<TokenInfo>,
/// `true` when the consecutive-blank counter reached `ENDPOINT_BLANK_THRESHOLD`
/// after this call had already emitted at least one token.
pub endpoint_detected: bool,
}
/// Gathers the values of a single encoder frame into `enc_frame`.
///
/// `encoded` is read as a channel-major buffer: the value for channel `ch` at
/// frame `t` lives at index `ch * encoded_len + t`. One value is written per
/// slot of `enc_frame`, so the length of `enc_frame` decides how many channels
/// are read.
///
/// # Panics
///
/// Panics if any computed index lies outside `encoded`.
pub(crate) fn extract_encoder_frame(
    encoded: &[f32],
    encoded_len: usize,
    t: usize,
    enc_frame: &mut [f32],
) {
    for (channel, slot) in enc_frame.iter_mut().enumerate() {
        // Consecutive channels are `encoded_len` elements apart.
        *slot = encoded[channel * encoded_len + t];
    }
}
/// Returns the index of the largest logit, or `blank_id` when `logits` is empty.
///
/// Values are ordered with `f32::total_cmp`. When several entries share the
/// maximum, the highest index among them is returned.
pub(crate) fn argmax(logits: &[f32], blank_id: usize) -> usize {
    let mut best: Option<(usize, &f32)> = None;
    for candidate in logits.iter().enumerate() {
        best = match best {
            // The incumbent survives only when it is strictly greater, so a
            // later entry wins every tie.
            Some(top) if top.1.total_cmp(candidate.1).is_gt() => Some(top),
            _ => Some(candidate),
        };
    }
    best.map_or(blank_id, |(idx, _)| idx)
}
/// Returns the winning token index together with its softmax probability.
///
/// The token is chosen by `argmax`. The probability is computed with every
/// logit shifted by the largest logit before exponentiation.
/// For empty `logits` the result is `(blank_id, 0.0)`.
pub(crate) fn argmax_with_confidence(logits: &[f32], blank_id: usize) -> (usize, f32) {
    if logits.is_empty() {
        return (blank_id, 0.0);
    }

    let token = argmax(logits, blank_id);
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // Shared by the numerator and every term of the denominator.
    let shifted_exp = |logit: f32| (logit - peak).exp();
    let denominator: f32 = logits.iter().copied().map(&shifted_exp).sum();

    (token, shifted_exp(logits[token]) / denominator)
}
/// Reusable buffers holding the result of the most recent decoder step.
///
/// Keeping one instance alive across steps lets each step overwrite the
/// vectors in place rather than building new ones.
#[derive(Default)]
pub(crate) struct DecoderOutput {
    /// First decoder output; passed to the joiner as its decoder-side input.
    dec_data: Vec<f32>,
    /// Second decoder output; copied into the decoder state's `h` when a token is emitted.
    new_h: Vec<f32>,
    /// Third decoder output; copied into the decoder state's `c` when a token is emitted.
    new_c: Vec<f32>,
}

impl DecoderOutput {
    /// Replaces the contents of `dst` with a copy of `src`, so that `dst` ends
    /// up with exactly `src.len()` elements.
    fn fill(dst: &mut Vec<f32>, src: &[f32]) {
        dst.clear();
        dst.extend_from_slice(src);
    }
}
/// Runs the decoder session once for the token and `h`/`c` vectors held in `state`.
///
/// The previous token is passed as a `[1, 1]` tensor and `h`/`c` as
/// `[1, 1, PRED_HIDDEN]` tensors. The first three session outputs are copied,
/// in order, into `out.dec_data`, `out.new_h` and `out.new_c`. `state` is only
/// read; committing the new `h`/`c` is left to the caller.
///
/// # Errors
///
/// Fails if an input tensor view cannot be built, if inference fails, or if
/// an output cannot be extracted as `f32`. `out` is written only after all
/// three outputs were extracted successfully.
fn run_decoder(decoder: &mut Session, state: &DecoderState, out: &mut DecoderOutput) -> Result<()> {
    let prev_token = [state.prev_token];
    let token_input = TensorRef::from_array_view(([1_usize, 1], prev_token.as_slice()))?;
    let h_input = TensorRef::from_array_view(([1_usize, 1, PRED_HIDDEN], state.h.as_slice()))?;
    let c_input = TensorRef::from_array_view(([1_usize, 1, PRED_HIDDEN], state.c.as_slice()))?;

    let outputs = decoder
        .run(ort::inputs![token_input, h_input, c_input])
        .context("Decoder inference failed")?;

    let (_, prediction) = outputs[0]
        .try_extract_tensor::<f32>()
        .context("Failed to extract decoder output")?;
    let (_, next_h) = outputs[1]
        .try_extract_tensor::<f32>()
        .context("Failed to extract decoder h state")?;
    let (_, next_c) = outputs[2]
        .try_extract_tensor::<f32>()
        .context("Failed to extract decoder c state")?;

    DecoderOutput::fill(&mut out.dec_data, prediction);
    DecoderOutput::fill(&mut out.new_h, next_h);
    DecoderOutput::fill(&mut out.new_c, next_c);
    Ok(())
}
/// Runs the joiner session for one encoder frame and one decoder output.
///
/// `enc_frame` is passed as a `[1, ENC_DIM, 1]` tensor and `dec_data` as a
/// `[1, PRED_HIDDEN, 1]` tensor. The first session output replaces the
/// contents of `logits_buf`.
///
/// # Errors
///
/// Fails if an input tensor view cannot be built, if inference fails, or if
/// the output cannot be extracted as `f32`.
fn run_joiner_single(
    joiner: &mut Session,
    enc_frame: &[f32],
    dec_data: &[f32],
    logits_buf: &mut Vec<f32>,
) -> Result<()> {
    let encoder_input = TensorRef::from_array_view(([1_usize, ENC_DIM, 1], enc_frame))?;
    let decoder_input = TensorRef::from_array_view(([1_usize, PRED_HIDDEN, 1], dec_data))?;

    let outputs = joiner
        .run(ort::inputs![encoder_input, decoder_input])
        .context("Joiner inference failed")?;

    let (_, joint_logits) = outputs[0]
        .try_extract_tensor::<f32>()
        .context("Failed to extract joiner output")?;

    DecoderOutput::fill(logits_buf, joint_logits);
    Ok(())
}
/// Abstraction over the two inference steps used by the greedy decode loop,
/// so the loop can run against real sessions or a scripted test double.
pub(crate) trait DecodeBackend {
/// Runs one decoder step for the token and `h`/`c` held in `state`, writing the results into `out`.
fn decode_step(&mut self, state: &DecoderState, out: &mut DecoderOutput) -> Result<()>;
/// Runs one joiner step for `enc_frame` and `dec_data`, replacing the contents of
/// `logits_buf` with the resulting logits.
fn joiner_step(
&mut self,
enc_frame: &[f32],
dec_data: &[f32],
logits_buf: &mut Vec<f32>,
) -> Result<()>;
}
/// `DecodeBackend` backed by two mutably borrowed `ort` sessions.
struct OrtBackend<'a> {
/// Session used for decoder steps.
decoder: &'a mut Session,
/// Session used for joiner steps.
joiner: &'a mut Session,
}
impl DecodeBackend for OrtBackend<'_> {
/// Delegates to `run_decoder` on the borrowed decoder session.
fn decode_step(&mut self, state: &DecoderState, out: &mut DecoderOutput) -> Result<()> {
run_decoder(self.decoder, state, out)
}
/// Delegates to `run_joiner_single` on the borrowed joiner session.
fn joiner_step(
&mut self,
enc_frame: &[f32],
dec_data: &[f32],
logits_buf: &mut Vec<f32>,
) -> Result<()> {
run_joiner_single(self.joiner, enc_frame, dec_data, logits_buf)
}
}
/// Greedily decodes `encoded_len` encoder frames using the given decoder and joiner sessions.
///
/// Wraps both sessions in an `OrtBackend` and delegates to `greedy_decode_impl`.
///
/// * `encoded` - encoder output, read channel-major with `encoded_len` values per channel;
///   must hold at least `ENC_DIM * encoded_len` values.
/// * `blank_id` - token id treated as blank.
/// * `state` - decoder state read at the start and updated in place as tokens are emitted.
/// * `biaser` - optional logit biaser applied before each argmax.
///
/// # Errors
///
/// Returns an error if the encoder buffer is too small, if a decoder or joiner step fails,
/// or if the decoder returns `h`/`c` vectors whose length is not `PRED_HIDDEN`.
pub fn greedy_decode(
decoder: &mut Session,
joiner: &mut Session,
encoded: &[f32], encoded_len: usize,
blank_id: usize,
state: &mut DecoderState,
biaser: Option<&Biaser>,
) -> Result<DecodeResult> {
let mut backend = OrtBackend { decoder, joiner };
greedy_decode_impl(&mut backend, encoded, encoded_len, blank_id, state, biaser)
}
/// Greedy decode loop, generic over the inference backend.
///
/// For every encoder frame `t` in `0..encoded_len` the loop repeats a decoder
/// step and a joiner step until the frame is finished:
///
/// * A blank result ends the frame and increments `state.consecutive_blanks`.
///   While blanks keep coming the decoder inputs in `state` do not change, so
///   later iterations reuse the cached decoder output instead of re-running
///   the decoder.
/// * A non-blank result is emitted as a [`TokenInfo`]; `state.prev_token`,
///   `state.h` and `state.c` are updated and the blank counter is reset.
/// * After `MAX_TOKENS_PER_STEP` tokens were emitted for a frame, one further
///   non-blank result is discarded and decoding moves on to the next frame.
///
/// When `biaser` is supplied, logits are boosted before the argmax and the
/// bias state is advanced with every emitted token.
///
/// `endpoint_detected` is set once `state.consecutive_blanks` reaches
/// `ENDPOINT_BLANK_THRESHOLD` while this call has already emitted a token.
///
/// # Errors
///
/// Returns an error if `encoded` holds fewer than `ENC_DIM * encoded_len`
/// values, if a backend step fails, or if the decoder produced `h`/`c` vectors
/// whose length is not `PRED_HIDDEN`. In the last case `state` is left exactly
/// as it was before the offending token.
///
/// # Panics
///
/// `copy_from_slice` panics if `state.h` or `state.c` is not `PRED_HIDDEN`
/// elements long.
fn greedy_decode_impl<B: DecodeBackend>(
    backend: &mut B,
    encoded: &[f32],
    encoded_len: usize,
    blank_id: usize,
    state: &mut DecoderState,
    biaser: Option<&Biaser>,
) -> Result<DecodeResult> {
    // Validate the input before allocating any working buffers.
    // NOTE(review): a buffer longer than `ENC_DIM * encoded_len` is accepted,
    // but frames are still read with a channel stride of `encoded_len` —
    // confirm callers never pass a time-padded buffer.
    anyhow::ensure!(
        encoded.len() >= ENC_DIM * encoded_len,
        "Encoder output size mismatch: got {}, expected >= {}",
        encoded.len(),
        ENC_DIM * encoded_len
    );

    let mut tokens = Vec::new();
    let mut endpoint_detected = false;
    let mut enc_frame = vec![0.0_f32; ENC_DIM];
    let mut logits_buf = Vec::new();
    // Counters are only reported through the debug event at the end.
    let mut decoder_calls: u32 = 0;
    let mut joiner_calls: u32 = 0;
    let mut skipped_decoder_calls: u32 = 0;
    let mut decoder_out = DecoderOutput::default();
    // `cache_valid`: `decoder_out` matches the current `state`.
    let mut cache_valid = false;
    // `in_blank_run`: the last joiner result was blank, so `state` is unchanged
    // since the decoder last ran. The flag carries over from frame to frame.
    let mut in_blank_run = false;
    let mut bias_state = biaser.map(|b| b.new_state());

    for t in 0..encoded_len {
        let mut tokens_this_step = 0;
        extract_encoder_frame(encoded, encoded_len, t, &mut enc_frame);
        loop {
            if in_blank_run {
                skipped_decoder_calls += 1;
                if !cache_valid {
                    anyhow::bail!("blank run invariant violated: decoder output cache is stale");
                }
            } else {
                decoder_calls += 1;
                backend.decode_step(state, &mut decoder_out)?;
                cache_valid = true;
            }

            joiner_calls += 1;
            backend.joiner_step(&enc_frame, &decoder_out.dec_data, &mut logits_buf)?;
            if let (Some(b), Some(bs)) = (biaser, bias_state.as_ref()) {
                b.boost_logits(bs, &mut logits_buf);
            }

            let (token, confidence) = argmax_with_confidence(&logits_buf, blank_id);
            if token == blank_id {
                in_blank_run = true;
                state.consecutive_blanks += 1;
                if state.consecutive_blanks >= ENDPOINT_BLANK_THRESHOLD && !tokens.is_empty() {
                    endpoint_detected = true;
                }
                break;
            }

            if tokens_this_step >= MAX_TOKENS_PER_STEP {
                // Per-frame cap reached: drop this token without touching the
                // decoder state and force a fresh decoder run on the next frame.
                in_blank_run = false;
                cache_valid = false;
                state.consecutive_blanks = 0;
                break;
            }

            // Check the decoder's new h/c before writing anything to `state`,
            // so an error cannot leave a new `prev_token` paired with stale
            // `h`/`c` vectors.
            if decoder_out.new_h.len() != PRED_HIDDEN || decoder_out.new_c.len() != PRED_HIDDEN {
                anyhow::bail!(
                    "Unexpected decoder state shape: h={}, c={}, expected {}",
                    decoder_out.new_h.len(),
                    decoder_out.new_c.len(),
                    PRED_HIDDEN
                );
            }

            // Commit the emitted token and the matching decoder state together.
            in_blank_run = false;
            state.consecutive_blanks = 0;
            state.prev_token = token as i64;
            state.h.copy_from_slice(&decoder_out.new_h);
            state.c.copy_from_slice(&decoder_out.new_c);
            if let (Some(b), Some(bs)) = (biaser, bias_state.as_mut()) {
                b.advance(bs, token);
            }
            tokens.push(TokenInfo {
                token_id: token,
                frame_index: t,
                confidence,
            });
            tokens_this_step += 1;
        }
    }

    tracing::debug!(
        decoder_calls,
        joiner_calls,
        skipped_decoder_calls,
        encoded_len,
        "decode_loop_stats"
    );

    Ok(DecodeResult {
        tokens,
        endpoint_detected,
    })
}
// Unit tests for the frame extraction and argmax helpers, plus decode-loop tests
// that drive `greedy_decode_impl` through scripted `DecodeBackend` doubles.
#[cfg(test)]
mod tests {
use super::*;
/// Frame 0 of a 2-channel, 3-frame buffer is the first value of each channel.
#[test]
fn test_extract_encoder_frame_first() {
let encoded = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let mut frame = vec![0.0; 2];
extract_encoder_frame(&encoded, 3, 0, &mut frame);
assert_eq!(frame, vec![1.0, 4.0]);
}
/// Frame 2 (the last one) is the last value of each channel.
#[test]
fn test_extract_encoder_frame_last() {
let encoded = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let mut frame = vec![0.0; 2];
extract_encoder_frame(&encoded, 3, 2, &mut frame);
assert_eq!(frame, vec![3.0, 6.0]);
}
/// Frame 1 is the middle value of each channel.
#[test]
fn test_extract_encoder_frame_middle() {
let encoded = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let mut frame = vec![0.0; 2];
extract_encoder_frame(&encoded, 3, 1, &mut frame);
assert_eq!(frame, vec![2.0, 5.0]);
}
/// A unique maximum is found regardless of the fallback id.
#[test]
fn test_argmax_clear_winner() {
let logits = vec![0.1, 0.5, 0.9, 0.2];
assert_eq!(argmax(&logits, 999), 2);
}
/// With equal maxima at indices 0 and 1, the later index is returned.
#[test]
fn test_argmax_tie_returns_last() {
let logits = vec![1.0, 1.0, 0.5];
assert_eq!(argmax(&logits, 999), 1);
}
/// A single logit is always the winner.
#[test]
fn test_argmax_single_element() {
let logits = vec![42.0];
assert_eq!(argmax(&logits, 999), 0);
}
/// All-negative logits still yield the largest one.
#[test]
fn test_argmax_negative_values() {
let logits = vec![-3.0, -1.0, -2.0];
assert_eq!(argmax(&logits, 999), 1);
}
/// Empty input falls back to the supplied blank id.
#[test]
fn test_argmax_empty_returns_blank() {
let logits: Vec<f32> = vec![];
assert_eq!(argmax(&logits, 1024), 1024);
}
/// The blank id itself can be the argmax when its logit is the largest.
#[test]
fn test_argmax_blank_id_selected() {
let logits = vec![0.1, 0.2, 0.9]; assert_eq!(argmax(&logits, 2), 2); }
/// Scripted backend: each joiner call makes the next scripted token id the argmax,
/// and both kinds of call are counted.
struct FakeBackend {
// Token ids to produce, one per joiner call, consumed front to back.
script: std::collections::VecDeque<usize>,
// Length of the logits vector written by each joiner call.
vocab: usize,
// Token produced once the script is exhausted.
blank_id: usize,
decoder_calls: u32,
joiner_calls: u32,
}
impl FakeBackend {
/// Builds a backend with the given script and zeroed call counters.
fn new(script: Vec<usize>, vocab: usize, blank_id: usize) -> Self {
Self {
script: script.into(),
vocab,
blank_id,
decoder_calls: 0,
joiner_calls: 0,
}
}
}
impl DecodeBackend for FakeBackend {
/// Counts the call and fills all three outputs with `PRED_HIDDEN` zeros.
fn decode_step(&mut self, _state: &DecoderState, out: &mut DecoderOutput) -> Result<()> {
self.decoder_calls += 1;
DecoderOutput::fill(&mut out.dec_data, &[0.0; PRED_HIDDEN]);
DecoderOutput::fill(&mut out.new_h, &[0.0; PRED_HIDDEN]);
DecoderOutput::fill(&mut out.new_c, &[0.0; PRED_HIDDEN]);
Ok(())
}
/// Counts the call and writes all-zero logits with 10.0 at the next scripted token
/// (or at `blank_id` when the script has run out).
fn joiner_step(
&mut self,
_enc_frame: &[f32],
_dec_data: &[f32],
logits_buf: &mut Vec<f32>,
) -> Result<()> {
self.joiner_calls += 1;
let tok = self.script.pop_front().unwrap_or(self.blank_id);
logits_buf.clear();
logits_buf.resize(self.vocab, 0.0);
logits_buf[tok] = 10.0; Ok(())
}
}
/// Zero-filled encoder buffer large enough for `frames` frames.
fn fake_enc(frames: usize) -> Vec<f32> {
vec![0.0_f32; ENC_DIM * frames]
}
/// Token, blank, token, blank over two frames: one token per frame, state follows the last token.
#[test]
fn test_greedy_decode_happy_path() {
let mut backend = FakeBackend::new(vec![1, 4, 2, 4], 5, 4);
let mut state = DecoderState::new(4);
let result =
greedy_decode_impl(&mut backend, &fake_enc(2), 2, 4, &mut state, None).unwrap();
assert_eq!(result.tokens.len(), 2);
assert_eq!(result.tokens[0].token_id, 1);
assert_eq!(result.tokens[0].frame_index, 0);
assert_eq!(result.tokens[1].token_id, 2);
assert_eq!(result.tokens[1].frame_index, 1);
assert_eq!(state.prev_token, 2);
assert_eq!(state.h.len(), PRED_HIDDEN);
assert!(!result.endpoint_detected);
}
/// After the first blank the decoder output is reused: the decoder runs only for the
/// initial step and once more after the emitted token.
#[test]
fn test_greedy_decode_blank_run_skips_decoder() {
let mut backend = FakeBackend::new(vec![1, 4, 4, 4, 4], 5, 4);
let mut state = DecoderState::new(4);
let result =
greedy_decode_impl(&mut backend, &fake_enc(4), 4, 4, &mut state, None).unwrap();
assert_eq!(result.tokens.len(), 1);
assert_eq!(
backend.decoder_calls, 2,
"decoder must not run during the blank run"
);
assert!(backend.joiner_calls >= 5);
}
/// One token followed by more than `ENDPOINT_BLANK_THRESHOLD` blanks sets the endpoint flag.
#[test]
fn test_greedy_decode_endpoint_after_threshold_blanks() {
let mut script = vec![1usize];
script.extend(std::iter::repeat_n(4usize, ENDPOINT_BLANK_THRESHOLD + 1));
let frames = ENDPOINT_BLANK_THRESHOLD + 2;
let mut backend = FakeBackend::new(script, 5, 4);
let mut state = DecoderState::new(4);
let result =
greedy_decode_impl(&mut backend, &fake_enc(frames), frames, 4, &mut state, None)
.unwrap();
assert!(
result.endpoint_detected,
"{ENDPOINT_BLANK_THRESHOLD}+ blanks after a token must endpoint"
);
}
/// Blanks alone, however many, never set the endpoint flag when no token was emitted.
#[test]
fn test_greedy_decode_no_endpoint_before_first_token() {
let frames = ENDPOINT_BLANK_THRESHOLD + 5;
let mut backend = FakeBackend::new(vec![4usize; frames], 5, 4);
let mut state = DecoderState::new(4);
let result =
greedy_decode_impl(&mut backend, &fake_enc(frames), frames, 4, &mut state, None)
.unwrap();
assert!(result.tokens.is_empty());
assert!(
!result.endpoint_detected,
"blanks before any token must not endpoint"
);
}
/// More non-blank results than the per-frame cap: exactly `MAX_TOKENS_PER_STEP` tokens
/// are kept and the blank counter stays at zero.
#[test]
fn test_greedy_decode_token_cap_does_not_inflate_blanks() {
let mut backend = FakeBackend::new(vec![1usize; MAX_TOKENS_PER_STEP + 1], 5, 4);
let mut state = DecoderState::new(4);
let result =
greedy_decode_impl(&mut backend, &fake_enc(1), 1, 4, &mut state, None).unwrap();
assert_eq!(result.tokens.len(), MAX_TOKENS_PER_STEP);
assert_eq!(
state.consecutive_blanks, 0,
"token cap must not inflate the blank counter"
);
assert!(!result.endpoint_detected);
}
/// A dominant logit wins and its confidence is a probability above one half.
#[test]
fn test_argmax_with_confidence_clear_winner() {
let (tok, conf) = argmax_with_confidence(&[0.1, 5.0, 0.2], 99);
assert_eq!(tok, 1);
assert!(
conf > 0.5 && conf <= 1.0,
"confidence should be a softmax prob in (0.5, 1], got {conf}"
);
}
/// Empty logits yield the blank id with zero confidence.
#[test]
fn test_argmax_with_confidence_empty_returns_blank_zero() {
let (tok, conf) = argmax_with_confidence(&[], 1024);
assert_eq!(tok, 1024);
assert_eq!(conf, 0.0);
}
/// Scripted backend that returns whole logit vectors, so tests can set up close races
/// between tokens.
struct LogitBackend {
// Logit vectors to return, one per joiner call, consumed front to back.
script: std::collections::VecDeque<Vec<f32>>,
// Length of the fallback logits vector used once the script is exhausted.
vocab: usize,
// Token favoured by the fallback logits.
blank_id: usize,
}
impl LogitBackend {
/// Builds a backend with the given logit script.
fn new(script: Vec<Vec<f32>>, vocab: usize, blank_id: usize) -> Self {
Self {
script: script.into(),
vocab,
blank_id,
}
}
}
impl DecodeBackend for LogitBackend {
/// Fills all three outputs with `PRED_HIDDEN` zeros.
fn decode_step(&mut self, _state: &DecoderState, out: &mut DecoderOutput) -> Result<()> {
DecoderOutput::fill(&mut out.dec_data, &[0.0; PRED_HIDDEN]);
DecoderOutput::fill(&mut out.new_h, &[0.0; PRED_HIDDEN]);
DecoderOutput::fill(&mut out.new_c, &[0.0; PRED_HIDDEN]);
Ok(())
}
/// Returns the next scripted logit vector, or all zeros with 10.0 at `blank_id`
/// once the script has run out.
fn joiner_step(
&mut self,
_enc_frame: &[f32],
_dec_data: &[f32],
logits_buf: &mut Vec<f32>,
) -> Result<()> {
logits_buf.clear();
match self.script.pop_front() {
Some(v) => logits_buf.extend_from_slice(&v),
None => {
logits_buf.resize(self.vocab, 0.0);
logits_buf[self.blank_id] = 10.0;
}
}
Ok(())
}
}
/// Two joiner results: token 1 narrowly ahead of token 2, then a decisive blank (id 3).
fn ab_script() -> Vec<Vec<f32>> {
vec![
vec![0.0, 2.0, 1.0, 0.0],
vec![0.0, 0.0, 0.0, 100.0],
]
}
/// Without a biaser token 1 wins; with a biaser built from sequence `[2]` and value 5.0
/// the same logits decode to token 2.
#[test]
fn test_bias_steers_argmax_to_boosted_token() {
let mut backend = LogitBackend::new(ab_script(), 4, 3);
let mut state = DecoderState::new(3);
let unbiased =
greedy_decode_impl(&mut backend, &fake_enc(2), 2, 3, &mut state, None).unwrap();
assert_eq!(unbiased.tokens.len(), 1);
assert_eq!(unbiased.tokens[0].token_id, 1, "no bias → model picks A");
let biaser = Biaser::from_sequences(vec![vec![2]], 5.0).unwrap();
let mut backend = LogitBackend::new(ab_script(), 4, 3);
let mut state = DecoderState::new(3);
let biased =
greedy_decode_impl(&mut backend, &fake_enc(2), 2, 3, &mut state, Some(&biaser))
.unwrap();
assert_eq!(biased.tokens.len(), 1);
assert_eq!(
biased.tokens[0].token_id, 2,
"boost must steer the argmax from A to the hotword token B"
);
}
/// With the two-token sequence `[5, 2]`, emitting 5 on frame 0 must advance the bias state
/// so that token 2 overtakes token 1 on frame 1.
#[test]
fn test_bias_prefix_advances_then_boosts_continuation() {
let script = vec![
vec![0.0, 0.0, 0.0, 0.0, 0.0, 3.0],
vec![0.0, 0.0, 0.0, 100.0, 0.0, 0.0],
vec![0.0, 2.0, 1.0, 0.0, 0.0, 0.0],
vec![0.0, 0.0, 0.0, 100.0, 0.0, 0.0],
];
let biaser = Biaser::from_sequences(vec![vec![5, 2]], 5.0).unwrap();
let mut backend = LogitBackend::new(script, 6, 3);
let mut state = DecoderState::new(3);
let result =
greedy_decode_impl(&mut backend, &fake_enc(2), 2, 3, &mut state, Some(&biaser))
.unwrap();
assert_eq!(
result.tokens.iter().map(|t| t.token_id).collect::<Vec<_>>(),
vec![5, 2],
"prefix [5] must advance so the boost on the continuation 2 steers frame 1"
);
}
/// A biaser whose sequence never wins must produce the same token ids as decoding with
/// no biaser at all.
// NOTE(review): presumably sequence `[0]` with value 0.5 is too small a boost to beat the
// scripted winners — confirm against `Biaser`.
#[test]
fn test_bias_none_is_byte_for_byte_unchanged() {
let base_script = || {
vec![
vec![0.0, 2.0, 1.0, 0.0],
vec![0.0, 0.0, 0.0, 100.0],
vec![0.0, 1.5, 2.5, 0.0],
vec![0.0, 0.0, 0.0, 100.0],
]
};
let mut b_none = LogitBackend::new(base_script(), 4, 3);
let mut s_none = DecoderState::new(3);
let none = greedy_decode_impl(&mut b_none, &fake_enc(2), 2, 3, &mut s_none, None).unwrap();
let biaser = Biaser::from_sequences(vec![vec![0]], 0.5).unwrap();
let mut b_some = LogitBackend::new(base_script(), 4, 3);
let mut s_some = DecoderState::new(3);
let some = greedy_decode_impl(&mut b_some, &fake_enc(2), 2, 3, &mut s_some, Some(&biaser))
.unwrap();
assert_eq!(
none.tokens.iter().map(|t| t.token_id).collect::<Vec<_>>(),
some.tokens.iter().map(|t| t.token_id).collect::<Vec<_>>(),
"a non-winning hotword must not change the decoded tokens"
);
assert_eq!(none.tokens.len(), 2);
}
}