use anyhow::{Result, anyhow};
use flate2::{Compression, write::GzEncoder};
use rand::{RngExt, SeedableRng};
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::io::Write;
use crate::language::Task;
/// Tunable parameters shared by the decoders in this module.
#[derive(Debug, Clone)]
pub struct DecodingConfig {
/// Number of hypotheses `BeamSearchDecoder` keeps per step and the number of
/// top tokens it expands each hypothesis with.
pub beam_size: usize,
/// Sampling temperatures that `HybridDecoder::decode_with_fallback` tries in order.
pub temperatures: Vec<f32>,
/// Exponent applied to the token count when a beam candidate's cumulative
/// log-probability is length-normalised.
pub length_penalty: f32,
/// `decode_with_fallback` returns an empty sequence when the first-step
/// probability of the no-speech token exceeds this value.
pub no_speech_threshold: f32,
/// Upper bound on the number of decoding steps consumed from the score matrix.
pub max_length: usize,
/// Language identifier. Not read by any decoder in this file — presumably
/// consumed by callers; verify.
pub language: String,
/// Task selector. Not read by any decoder in this file — presumably consumed
/// by callers; verify.
pub task: Task,
/// A `decode_with_fallback` attempt is accepted only if the compression ratio
/// of its decoded text is below this value.
pub compression_ratio_threshold: f32,
/// A `decode_with_fallback` attempt is accepted only if its average token
/// log-probability is above this value.
pub log_prob_threshold: f32,
}
impl Default for DecodingConfig {
    /// Builds the baseline configuration: beam width 5, the six-step
    /// temperature ladder 0.0..=1.0, English transcription, and a 448-step cap.
    fn default() -> Self {
        let temperatures = vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
        Self {
            task: Task::Transcribe,
            language: String::from("en"),
            max_length: 448,
            beam_size: 5,
            temperatures,
            length_penalty: 1.0,
            no_speech_threshold: 0.6,
            compression_ratio_threshold: 2.4,
            log_prob_threshold: -1.0,
        }
    }
}
impl DecodingConfig {
    /// Preset with beam width 1, a single temperature of 0.0 and no length
    /// penalty; every other field comes from `Default`.
    pub fn fast() -> Self {
        Self {
            length_penalty: 0.0,
            temperatures: vec![0.0],
            beam_size: 1,
            ..Self::default()
        }
    }

    /// Preset with beam width 5, the full temperature ladder and a length
    /// penalty of 1.0; every other field comes from `Default`.
    pub fn balanced() -> Self {
        Self {
            length_penalty: 1.0,
            temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
            beam_size: 5,
            ..Self::default()
        }
    }

    /// Preset with beam width 10, the full temperature ladder and a length
    /// penalty of 1.0; every other field comes from `Default`.
    pub fn accurate() -> Self {
        Self {
            length_penalty: 1.0,
            temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
            beam_size: 10,
            ..Self::default()
        }
    }

    /// Sets the beam width, raising it to at least 1.
    pub fn with_beam_size(self, beam_size: usize) -> Self {
        Self {
            beam_size: beam_size.max(1),
            ..self
        }
    }

    /// Replaces the whole temperature ladder with one temperature, floored at 0.0.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperatures: vec![temperature.max(0.0)],
            ..self
        }
    }

    /// Sets the length-penalty exponent, floored at 0.0.
    pub fn with_length_penalty(self, penalty: f32) -> Self {
        Self {
            length_penalty: penalty.max(0.0),
            ..self
        }
    }

    /// Sets the no-speech threshold, clamped into `[0.0, 1.0]`.
    pub fn with_no_speech_threshold(self, threshold: f32) -> Self {
        Self {
            no_speech_threshold: threshold.clamp(0.0, 1.0),
            ..self
        }
    }

    /// Sets the language identifier.
    pub fn with_language(self, language: String) -> Self {
        Self { language, ..self }
    }

    /// Sets the task.
    pub fn with_task(self, task: Task) -> Self {
        Self { task, ..self }
    }
}
/// Ratio of the UTF-8 byte length of `text` to its gzip-compressed length.
///
/// Returns 0.0 for empty input. Compression runs in memory; if the encoder
/// reports an error the compressed length is treated as 1 byte, so the call
/// never fails.
pub fn compression_ratio(text: &str) -> f32 {
    if text.is_empty() {
        return 0.0;
    }
    let raw = text.as_bytes();
    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
    // Best effort: a write error is ignored and whatever was encoded is finished.
    let _ = encoder.write_all(raw);
    let compressed_len = match encoder.finish() {
        Ok(buffer) => buffer.len(),
        Err(_) => 0,
    };
    // Floor the divisor at 1 so the ratio is always finite.
    raw.len() as f32 / compressed_len.max(1) as f32
}
/// Returns the softmax probability of `token` within `logits`.
///
/// The result is always in `[0.0, 1.0]`:
/// * an out-of-range `token` yields 0.0;
/// * NaN logits carry no probability mass (a NaN target yields 0.0 and NaN
///   entries are left out of the normaliser);
/// * if no finite maximum exists (all entries are NaN or `-inf`, or some entry
///   is `+inf`) the distribution is undefined and 0.0 is returned.
fn softmax_at(logits: &[f32], token: u32) -> f32 {
    let Some(&target) = logits.get(token as usize) else {
        return 0.0;
    };
    // `f32::max` returns the non-NaN operand, so NaN entries never become the peak.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if target.is_nan() || !max.is_finite() {
        return 0.0;
    }
    // Subtracting the peak keeps every exponent <= 0. The peak itself
    // contributes exp(0) = 1, so the sum is >= 1 and needs no epsilon guard.
    let exp_sum: f32 = logits
        .iter()
        .filter(|l| !l.is_nan())
        .map(|&l| (l - max).exp())
        .sum();
    (target - max).exp() / exp_sum
}
/// Returns the log-softmax of `token` within `logits`.
///
/// When `temp > 0.0` every logit is divided by `temp` first; otherwise the raw
/// logits are used. An out-of-range `token` yields `f32::NEG_INFINITY`.
fn log_softmax_at(logits: &[f32], token: u32, temp: f32) -> f32 {
    let Some(&raw_target) = logits.get(token as usize) else {
        return f32::NEG_INFINITY;
    };
    let tempered = temp > 0.0;
    let scale = |l: f32| if tempered { l / temp } else { l };
    // Log-sum-exp with the peak factored out for numerical stability.
    let peak = logits
        .iter()
        .map(|&l| scale(l))
        .fold(f32::NEG_INFINITY, f32::max);
    let exp_total: f32 = logits.iter().map(|&l| (scale(l) - peak).exp()).sum();
    let log_normaliser = peak + exp_total.ln();
    scale(raw_target) - log_normaliser
}
/// Draws a token index from the softmax of `logits / temp`.
///
/// * `temp <= 0.0` or empty `logits` falls back to `argmax_logits`.
/// * Tokens whose weight is zero or NaN are never sampled: NaN weights are
///   replaced by 0.0, and a token is picked only when the running sum strictly
///   exceeds the drawn threshold.
/// * If no weight is positive (for example every logit is NaN or infinite, or
///   `temp` is NaN) the result is `argmax_logits(logits)` and no random number
///   is drawn.
fn sample_from_logits(logits: &[f32], temp: f32, rng: &mut impl rand::Rng) -> u32 {
    if temp <= 0.0 || logits.is_empty() {
        return argmax_logits(logits);
    }
    // `f32::max` skips NaN operands, so `max` is the largest non-NaN logit.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = logits
        .iter()
        .map(|&l| {
            let w = ((l - max) / temp).exp();
            if w.is_nan() { 0.0 } else { w }
        })
        .collect();
    let total: f32 = weights.iter().sum();
    if total <= 0.0 {
        return argmax_logits(logits);
    }
    let threshold: f32 = rng.random::<f32>() * total;
    let mut cumsum = 0.0_f32;
    for (i, &w) in weights.iter().enumerate() {
        cumsum += w;
        // Strict comparison: a zero-weight token does not move `cumsum`, so it
        // cannot be the first index to cross the threshold (even when it is 0).
        if cumsum > threshold {
            return i as u32;
        }
    }
    // Reaching this point would need rounding to push the threshold up to the total.
    argmax_logits(logits)
}
/// Returns the index of the largest non-NaN logit.
///
/// NaN entries are skipped so they cannot displace the true maximum. When
/// several entries tie for the maximum the last one wins. Returns 0 when
/// `logits` is empty or contains only NaN.
fn argmax_logits(logits: &[f32]) -> u32 {
    logits
        .iter()
        .enumerate()
        .filter(|(_, l)| !l.is_nan())
        // With NaN removed, `total_cmp` is a genuine total order.
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(i, _)| i as u32)
        .unwrap_or(0)
}
/// One hypothesis tracked by the beam search.
#[derive(Debug, Clone)]
struct BeamCandidate {
    /// Token ids emitted so far, starting with the initial token.
    tokens: Vec<u32>,
    /// Sum of the per-step scores of the chosen tokens.
    log_prob: f32,
    /// Set once the hypothesis has emitted EOS or reached the last step.
    finished: bool,
    /// Number of tokens counted for length normalisation.
    token_count: usize,
}

impl BeamCandidate {
    /// Starts a hypothesis containing only `token`, with score 0.0.
    fn new(token: u32) -> Self {
        Self {
            tokens: vec![token],
            log_prob: 0.0,
            finished: false,
            token_count: 1,
        }
    }

    /// Returns `log_prob / token_count^length_penalty`, or the raw `log_prob`
    /// when `token_count` is zero.
    fn normalized_score(&self, length_penalty: f32) -> f32 {
        if self.token_count == 0 {
            return self.log_prob;
        }
        self.log_prob / ((self.token_count as f32).powf(length_penalty))
    }
}

// The comparison traits all derive from `Ord::cmp`, so `==`, `<` and `cmp`
// always agree, as `BinaryHeap` and the std trait contracts require. Only the
// score takes part in the comparison; `tokens` and `finished` do not.
impl PartialEq for BeamCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for BeamCandidate {}

impl PartialOrd for BeamCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for BeamCandidate {
    /// Orders by the score normalised with a fixed length penalty of 1.0,
    /// reversed: the candidate with the higher score compares as `Less`.
    /// `total_cmp` keeps the order total even when a score is NaN.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .normalized_score(1.0)
            .total_cmp(&self.normalized_score(1.0))
    }
}
pub struct BeamSearchDecoder {
config: DecodingConfig,
}
impl BeamSearchDecoder {
pub fn new(config: DecodingConfig) -> Self {
Self { config }
}
pub fn decode(
&self,
token_probs: &[Vec<f32>],
initial_token: u32,
vocab_size: usize,
eos_token: u32,
_pad_token: u32,
) -> Result<Vec<u32>> {
if token_probs.is_empty() {
return Ok(vec![initial_token]);
}
if token_probs.iter().any(|probs| probs.len() != vocab_size) {
return Err(anyhow!("Invalid token probabilities shape"));
}
let mut candidates = BinaryHeap::new();
candidates.push(BeamCandidate::new(initial_token));
for step in 0..token_probs.len().min(self.config.max_length) {
let probs = &token_probs[step];
let mut next_candidates = Vec::new();
for candidate in candidates.iter().take(self.config.beam_size) {
if candidate.finished {
next_candidates.push(candidate.clone());
continue;
}
let top_k = self.get_top_k_tokens(probs, self.config.beam_size);
for (token, log_prob) in top_k {
let mut new_candidate = candidate.clone();
new_candidate.tokens.push(token);
new_candidate.log_prob += log_prob;
new_candidate.token_count += 1;
if token == eos_token || step == token_probs.len() - 1 {
new_candidate.finished = true;
}
next_candidates.push(new_candidate);
}
}
next_candidates.sort_by(|a, b| {
b.normalized_score(self.config.length_penalty)
.partial_cmp(&a.normalized_score(self.config.length_penalty))
.unwrap_or(Ordering::Equal)
});
candidates = next_candidates
.into_iter()
.take(self.config.beam_size)
.collect::<BinaryHeap<_>>();
if candidates.iter().all(|c| c.finished) {
break;
}
}
candidates
.iter()
.max_by(|a, b| {
a.normalized_score(self.config.length_penalty)
.partial_cmp(&b.normalized_score(self.config.length_penalty))
.unwrap_or(Ordering::Equal)
})
.map(|c| c.tokens.clone())
.ok_or_else(|| anyhow!("No valid candidates found"))
}
fn get_top_k_tokens(&self, log_probs: &[f32], k: usize) -> Vec<(u32, f32)> {
let mut indexed: Vec<(u32, f32)> = log_probs
.iter()
.enumerate()
.map(|(i, &prob)| (i as u32, prob))
.collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
indexed.into_iter().take(k).collect()
}
}
pub struct GreedyDecoder;
impl GreedyDecoder {
pub fn decode(
token_probs: &[Vec<f32>],
initial_token: u32,
_vocab_size: usize,
eos_token: u32,
_pad_token: u32,
) -> Result<Vec<u32>> {
let mut tokens = vec![initial_token];
for probs in token_probs {
if probs.is_empty() {
break;
}
let (token, _) = probs
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal))
.unwrap_or((eos_token as usize, &f32::NEG_INFINITY));
let token = token as u32;
tokens.push(token);
if token == eos_token {
break;
}
}
Ok(tokens)
}
}
/// Decoder that combines beam search, greedy decoding and a
/// temperature-fallback sampling loop.
pub struct HybridDecoder {
    config: DecodingConfig,
    beam_decoder: BeamSearchDecoder,
}

impl HybridDecoder {
    /// Creates the decoder; the beam decoder receives its own copy of `config`.
    pub fn new(config: DecodingConfig) -> Self {
        let beam_decoder = BeamSearchDecoder::new(config.clone());
        Self {
            config,
            beam_decoder,
        }
    }

    /// Runs beam search and returns its result when it produced more than one
    /// token. If beam search fails or yields at most one token, the result of
    /// `GreedyDecoder::decode` on the same inputs is returned instead.
    pub fn decode(
        &self,
        token_probs: &[Vec<f32>],
        initial_token: u32,
        vocab_size: usize,
        eos_token: u32,
        pad_token: u32,
    ) -> Result<Vec<u32>> {
        let beam_tokens = self
            .beam_decoder
            .decode(token_probs, initial_token, vocab_size, eos_token, pad_token)
            .ok()
            .filter(|tokens| tokens.len() > 1);
        match beam_tokens {
            Some(tokens) => Ok(tokens),
            None => {
                GreedyDecoder::decode(token_probs, initial_token, vocab_size, eos_token, pad_token)
            }
        }
    }

    /// Decodes by sampling once per configured temperature until an attempt
    /// passes the quality checks.
    ///
    /// * Empty `token_probs` yields `[initial_token]`.
    /// * If `no_speech_token` is below `vocab_size` and its softmax probability
    ///   in the first row exceeds `config.no_speech_threshold`, an empty
    ///   sequence is returned.
    /// * Each attempt reseeds the RNG with the constant 42, consumes at most
    ///   `config.max_length` rows (treated as logits) and stops after
    ///   `eos_token`.
    /// * An attempt is accepted when its mean token log-probability is above
    ///   `config.log_prob_threshold` and the compression ratio of
    ///   `decode_text(tokens)` is below `config.compression_ratio_threshold`.
    /// * If no attempt is accepted, the first attempt's tokens are returned.
    ///
    /// # Errors
    ///
    /// Fails when `config.temperatures` is empty.
    pub fn decode_with_fallback(
        &self,
        token_probs: &[Vec<f32>],
        initial_token: u32,
        vocab_size: usize,
        eos_token: u32,
        no_speech_token: u32,
        decode_text: impl Fn(&[u32]) -> String,
    ) -> Result<Vec<u32>> {
        let Some(first_step) = token_probs.first() else {
            return Ok(vec![initial_token]);
        };

        let no_speech_in_vocab = (no_speech_token as usize) < vocab_size;
        if no_speech_in_vocab
            && softmax_at(first_step, no_speech_token) > self.config.no_speech_threshold
        {
            return Ok(Vec::new());
        }

        let mut first_attempt: Option<Vec<u32>> = None;
        for &temp in &self.config.temperatures {
            // Fixed seed: every temperature attempt is reproducible.
            let mut rng = rand::rngs::StdRng::seed_from_u64(42);
            let mut tokens = vec![initial_token];
            let mut step_log_probs: Vec<f32> = Vec::new();

            for step_logits in token_probs.iter().take(self.config.max_length) {
                let picked = sample_from_logits(step_logits, temp, &mut rng);
                step_log_probs.push(log_softmax_at(step_logits, picked, temp));
                tokens.push(picked);
                if picked == eos_token {
                    break;
                }
            }

            let mean_log_prob = match step_log_probs.len() {
                0 => 0.0,
                count => step_log_probs.iter().sum::<f32>() / count as f32,
            };
            // The text is decoded for every attempt, whether or not the
            // log-probability check passes.
            let ratio = compression_ratio(&decode_text(&tokens));

            let confident = mean_log_prob > self.config.log_prob_threshold;
            let not_repetitive = ratio < self.config.compression_ratio_threshold;
            if confident && not_repetitive {
                return Ok(tokens);
            }
            // Remember only the first attempt as the last-resort result.
            if first_attempt.is_none() {
                first_attempt = Some(tokens);
            }
        }

        first_attempt.ok_or_else(|| anyhow!("decode_with_fallback: no temperatures configured"))
    }
}
// Unit tests for the configuration presets, the scoring helpers and the three decoders.
#[cfg(test)]
mod tests {
use super::*;
// The default configuration exposes beam width 5, the full temperature ladder and "en".
#[test]
fn test_decoding_config_defaults() {
let config = DecodingConfig::default();
assert_eq!(config.beam_size, 5);
assert_eq!(config.temperatures, vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0]);
assert_eq!(config.language, "en");
}
// The `fast` preset uses beam width 1 and a single temperature of 0.0.
#[test]
fn test_decoding_config_fast() {
let config = DecodingConfig::fast();
assert_eq!(config.beam_size, 1);
assert_eq!(config.temperatures, vec![0.0]);
}
// The `accurate` preset widens the beam to 10 and keeps the full temperature ladder.
#[test]
fn test_decoding_config_accurate() {
let config = DecodingConfig::accurate();
assert_eq!(config.beam_size, 10);
assert_eq!(config.temperatures, vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0]);
}
// `with_temperature` replaces the whole ladder with the single given value.
#[test]
fn test_with_temperature_overrides_sequence() {
let config = DecodingConfig::default().with_temperature(0.7);
assert_eq!(config.temperatures, vec![0.7]);
}
// Length normalisation: -4.0 over 2 tokens (-2.0) scores below -1.0 over 1 token.
#[test]
fn test_beam_candidate_scoring() {
let mut c1 = BeamCandidate::new(1);
c1.log_prob = -4.0;
c1.token_count = 2;
let mut c2 = BeamCandidate::new(2);
c2.log_prob = -1.0;
c2.token_count = 1;
assert!(c2.normalized_score(1.0) > c1.normalized_score(1.0));
}
// Greedy decoding picks the arg-max of each row and stops after emitting EOS (token 0).
#[test]
fn test_greedy_decoder() -> Result<()> {
let token_probs = vec![
vec![-10.0, -5.0, -0.5, -10.0], vec![-5.0, -10.0, -0.1, -10.0], vec![-0.5, -5.0, -10.0, -10.0], ];
let tokens = GreedyDecoder::decode(&token_probs, 50256, 4, 0, 50257)?;
assert_eq!(tokens.len(), 4); assert_eq!(tokens[0], 50256); assert_eq!(tokens[1], 2); assert_eq!(tokens[2], 2); assert_eq!(tokens[3], 0);
Ok(())
}
// Beam search with width 2 returns a sequence that starts with the initial token.
#[test]
fn test_beam_search_decoder() -> Result<()> {
let config = DecodingConfig {
beam_size: 2,
..Default::default()
};
let decoder = BeamSearchDecoder::new(config);
let token_probs = vec![
vec![-5.0, -0.5, -10.0], vec![-0.1, -5.0, -10.0], ];
let tokens = decoder.decode(&token_probs, 100, 3, 0, 99)?;
assert!(tokens.len() >= 2);
assert_eq!(tokens[0], 100);
Ok(())
}
// `HybridDecoder::decode` on a single-row input yields a non-empty sequence
// that starts with the initial token.
#[test]
fn test_hybrid_decoder_fallback() -> Result<()> {
let config = DecodingConfig::default();
let decoder = HybridDecoder::new(config);
let token_probs = vec![vec![-0.5, -10.0, -10.0]];
let tokens = decoder.decode(&token_probs, 100, 3, 0, 99)?;
assert!(!tokens.is_empty());
assert_eq!(tokens[0], 100);
Ok(())
}
// Ordinary prose stays below the default compression-ratio threshold of 2.4.
#[test]
fn test_compression_ratio_normal_text() {
let text = "The quick brown fox jumps over the lazy dog.";
let cr = compression_ratio(text);
assert!(cr < 2.4, "normal text compression ratio was {cr}");
}
// Highly repetitive text compresses well enough to exceed the 2.4 threshold.
#[test]
fn test_compression_ratio_repetitive_text() {
let phrase = "the quick brown fox ";
let text = phrase.repeat(100); let cr = compression_ratio(&text);
assert!(cr > 2.4, "repetitive text compression ratio was {cr}");
}
// Empty input is defined to have a ratio of exactly 0.0.
#[test]
fn test_compression_ratio_empty() {
assert_eq!(compression_ratio(""), 0.0);
}
// Softmax ranks the largest logit highest and the probabilities sum to 1.
#[test]
fn test_softmax_at_picks_max() {
let logits = vec![-10.0, -0.1, -5.0];
let p_max = softmax_at(&logits, 1);
let p_min = softmax_at(&logits, 0);
assert!(p_max > p_min, "softmax of max logit should be highest");
let total: f32 = (0..3).map(|i| softmax_at(&logits, i)).sum();
assert!((total - 1.0).abs() < 1e-4, "softmax probs must sum to 1");
}
// With thresholds loosened so every attempt passes, the single temperature-0.0
// attempt is returned and begins with the initial token.
#[test]
fn test_decode_with_fallback_passes_quality() -> Result<()> {
let config = DecodingConfig {
temperatures: vec![0.0],
log_prob_threshold: -100.0,
compression_ratio_threshold: 100.0,
no_speech_threshold: 1.0,
max_length: 5,
..Default::default()
};
let decoder = HybridDecoder::new(config);
let token_probs = vec![vec![-0.01, -10.0, -10.0], vec![-0.01, -10.0, -10.0]];
let tokens = decoder.decode_with_fallback(
&token_probs,
99,
3,
0, 2, |ids| {
ids.iter()
.map(|i| i.to_string())
.collect::<Vec<_>>()
.join(" ")
},
)?;
assert!(!tokens.is_empty());
assert_eq!(tokens[0], 99);
Ok(())
}
// A dominant no-speech logit in the first row (token 1) short-circuits to an
// empty result.
#[test]
fn test_decode_with_fallback_no_speech() -> Result<()> {
let config = DecodingConfig {
temperatures: vec![0.0],
no_speech_threshold: 0.5,
log_prob_threshold: -100.0,
compression_ratio_threshold: 100.0,
max_length: 5,
..Default::default()
};
let decoder = HybridDecoder::new(config);
let token_probs = vec![vec![-10.0, 100.0, -10.0]];
let tokens = decoder.decode_with_fallback(
&token_probs,
99,
3,
0, 1, |_| String::new(),
)?;
assert!(
tokens.is_empty(),
"should return empty when no-speech detected"
);
Ok(())
}
}