use crate::error::{AudioError, AudioResult};
use crate::traits::{SttOptions, Transcript};
use ort::session::Session;
use ort::value::Value;
/// Token id that terminates greedy decoding; it is never added to the output tokens.
const EOT_TOKEN: u32 = 2;
/// Upper bound on the number of decoder steps per transcription. Also used as the
/// denominator of the length-based confidence heuristic.
const MAX_DECODE_TOKENS: usize = 448;
/// Speech-to-text decoder that runs caller-supplied ONNX Runtime encoder and decoder
/// sessions and converts the generated token ids into text.
pub struct MoonshineDecoder {
/// Tokenizer used to turn generated token ids back into text.
tokenizer: tokenizers::Tokenizer,
}
impl MoonshineDecoder {
pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
Self { tokenizer }
}
pub fn transcribe(
&self,
encoder: &mut Session,
decoder: &mut Session,
samples: &[f32],
_opts: &SttOptions,
) -> AudioResult<Transcript> {
if samples.is_empty() {
return Ok(Transcript {
text: String::new(),
confidence: 0.0,
language_detected: None,
words: Vec::new(),
..Default::default()
});
}
let encoder_output = self.run_encoder(encoder, samples)?;
let decoded_tokens = self.greedy_decode(decoder, &encoder_output)?;
self.build_transcript(&decoded_tokens)
}
fn run_encoder(&self, encoder: &mut Session, samples: &[f32]) -> AudioResult<EncoderOutput> {
let n_samples = samples.len() as i64;
let audio_tensor =
Value::from_array(([1i64, n_samples], samples.to_vec())).map_err(|e| {
AudioError::Stt {
provider: "ONNX/Moonshine".into(),
message: format!("failed to create audio tensor [1, {n_samples}]: {e}"),
}
})?;
let outputs =
encoder.run(ort::inputs!["audio" => audio_tensor]).map_err(|e| AudioError::Stt {
provider: "ONNX/Moonshine".into(),
message: format!("encoder inference failed: {e}"),
})?;
let (shape, data) =
outputs[0].try_extract_tensor::<f32>().map_err(|e| AudioError::Stt {
provider: "ONNX/Moonshine".into(),
message: format!("failed to extract encoder hidden states: {e}"),
})?;
let shape_vec: Vec<i64> = shape.iter().copied().collect();
Ok(EncoderOutput { data: data.to_vec(), shape: shape_vec })
}
fn greedy_decode(
&self,
decoder: &mut Session,
encoder_hidden: &EncoderOutput,
) -> AudioResult<Vec<u32>> {
let mut tokens: Vec<u32> = vec![1];
let mut output_tokens: Vec<u32> = Vec::new();
for _ in 0..MAX_DECODE_TOKENS {
let logits = self.run_decoder_step(decoder, encoder_hidden, &tokens)?;
let next_token = argmax(&logits);
if next_token == EOT_TOKEN {
break;
}
output_tokens.push(next_token);
tokens.push(next_token);
}
Ok(output_tokens)
}
fn run_decoder_step(
&self,
decoder: &mut Session,
encoder_hidden: &EncoderOutput,
tokens: &[u32],
) -> AudioResult<Vec<f32>> {
let input_ids: Vec<i64> = tokens.iter().map(|&t| t as i64).collect();
let seq_len = input_ids.len() as i64;
let ids_tensor =
Value::from_array(([1i64, seq_len], input_ids)).map_err(|e| AudioError::Stt {
provider: "ONNX/Moonshine".into(),
message: format!("failed to create input_ids tensor: {e}"),
})?;
let encoder_tensor =
Value::from_array((encoder_hidden.shape.clone(), encoder_hidden.data.clone()))
.map_err(|e| AudioError::Stt {
provider: "ONNX/Moonshine".into(),
message: format!("failed to create encoder_output tensor: {e}"),
})?;
let outputs = decoder
.run(ort::inputs!["input_ids" => ids_tensor, "encoder_output" => encoder_tensor])
.map_err(|e| AudioError::Stt {
provider: "ONNX/Moonshine".into(),
message: format!("decoder inference failed: {e}"),
})?;
let (shape, logits_data) =
outputs[0].try_extract_tensor::<f32>().map_err(|e| AudioError::Stt {
provider: "ONNX/Moonshine".into(),
message: format!("failed to extract decoder logits: {e}"),
})?;
let vocab_size = if shape.len() == 3 { shape[2] as usize } else { logits_data.len() };
let total = logits_data.len();
let start = total.saturating_sub(vocab_size);
Ok(logits_data[start..].to_vec())
}
fn build_transcript(&self, decoded_tokens: &[u32]) -> AudioResult<Transcript> {
if decoded_tokens.is_empty() {
return Ok(Transcript {
text: String::new(),
confidence: 0.0,
language_detected: None,
words: Vec::new(),
..Default::default()
});
}
let text = self
.tokenizer
.decode(decoded_tokens, true)
.map_err(|e| AudioError::Stt {
provider: "ONNX/Moonshine".into(),
message: format!("tokenizer decode failed: {e}"),
})?
.trim()
.to_string();
let confidence = if decoded_tokens.is_empty() {
0.0
} else {
(1.0 - (decoded_tokens.len() as f32 / MAX_DECODE_TOKENS as f32)).max(0.1)
};
Ok(Transcript {
text,
confidence,
language_detected: None,
words: Vec::new(),
..Default::default()
})
}
}
/// Owned copy of the encoder's first output tensor, kept so it can be fed to the
/// decoder on every decoding step.
struct EncoderOutput {
/// Flattened `f32` tensor values as extracted from the encoder output.
data: Vec<f32>,
/// Tensor dimensions as reported for the encoder output.
shape: Vec<i64>,
}
/// Returns the index of the largest non-NaN value in `values`.
///
/// NaN entries are skipped. Treating a failed comparison as "equal" would let a NaN
/// replace the running maximum and then lose to any later value, so the result would
/// depend on where the NaN sits rather than on the actual maximum.
///
/// When several entries share the maximum, the last of them is returned. Returns 0
/// when `values` is empty or contains only NaN.
fn argmax(values: &[f32]) -> u32 {
    values
        .iter()
        .enumerate()
        .filter(|(_, v)| !v.is_nan())
        // With NaN filtered out `partial_cmp` always succeeds; the fallback is inert.
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i as u32)
        .unwrap_or(0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `argmax` picks the index of the largest element.
    #[test]
    fn test_argmax_basic() {
        let cases: [(&[f32], u32); 2] = [(&[0.1, 0.9, 0.5], 1), (&[3.0, 1.0, 2.0], 0)];
        for (values, expected) in cases.iter() {
            assert_eq!(argmax(values), *expected);
        }
    }

    /// An empty slice falls back to index 0.
    #[test]
    fn test_argmax_empty() {
        let no_values: [f32; 0] = [];
        assert_eq!(argmax(&no_values), 0);
    }

    /// Pins the end-of-text token id.
    #[test]
    fn test_eot_token_value() {
        let expected: u32 = 2;
        assert_eq!(EOT_TOKEN, expected);
    }

    /// Pins the decode step limit.
    #[test]
    fn test_max_decode_tokens() {
        let expected: usize = 448;
        assert_eq!(MAX_DECODE_TOKENS, expected);
    }
}