/// Default limit on how many times in a row the same token may be emitted
/// before `GreedyDecoder::next_token` stops decoding.
const DEFAULT_MAX_CONSECUTIVE_REPEATS: usize = 8;
/// Stateful greedy token selector that stops on an end-of-sequence token or
/// when the same token keeps being selected too many times in a row.
#[derive(Debug, Clone)]
pub struct GreedyDecoder {
    /// Token id that marks end of sequence; selecting it ends decoding.
    eos_id: i64,
    /// Longest run of the same token that is still emitted.
    max_consecutive_repeats: usize,
    /// Most recently emitted token; initialised to `-1` before any token is emitted.
    last_token: i64,
    /// Number of consecutive selections of `last_token` seen so far.
    consecutive_count: usize,
}
impl GreedyDecoder {
    /// Creates a decoder that stops when `eos_id` is selected and that uses
    /// `DEFAULT_MAX_CONSECUTIVE_REPEATS` as its repetition limit.
    #[must_use]
    pub fn new(eos_id: i64) -> Self {
        Self {
            eos_id,
            max_consecutive_repeats: DEFAULT_MAX_CONSECUTIVE_REPEATS,
            // Selected tokens are slice indices (never negative), so -1
            // cannot compare equal to the first real token.
            last_token: -1,
            consecutive_count: 0,
        }
    }

    /// Builder-style setter for the repetition limit.
    ///
    /// A token selected more than `n` times in a row ends decoding. The first
    /// occurrence of a token is never checked against the limit, so `0` and
    /// `1` both allow exactly one occurrence.
    #[must_use]
    pub fn with_max_repeats(mut self, n: usize) -> Self {
        self.max_consecutive_repeats = n;
        self
    }

    /// Selects the highest-scoring token from `logits`.
    ///
    /// Returns `Some(token)` with the index of the largest logit, or `None`
    /// when decoding should stop:
    ///
    /// * `logits` is empty (there is no valid index to emit),
    /// * the selected token equals the end-of-sequence id, or
    /// * the selected token has now been chosen more than the configured
    ///   number of consecutive times.
    ///
    /// The decoder is not reset after returning `None`.
    pub fn next_token(&mut self, logits: &[f32]) -> Option<i64> {
        // Without this guard `argmax` would report index 0 for an empty
        // slice and a token that does not exist would be emitted.
        if logits.is_empty() {
            log::warn!("Greedy decode: received empty logits, stopping");
            return None;
        }

        let token = argmax(logits) as i64;
        if token == self.eos_id {
            return None;
        }

        if token == self.last_token {
            self.consecutive_count += 1;
            if self.consecutive_count > self.max_consecutive_repeats {
                log::warn!(
                    "Greedy decode: token {} repeated {} consecutive times, stopping",
                    token,
                    self.consecutive_count
                );
                return None;
            }
        } else {
            // A different token starts a fresh run of length one.
            self.consecutive_count = 1;
        }

        self.last_token = token;
        Some(token)
    }
}
/// Returns the index of the largest value in `logits`.
///
/// Ties resolve to the earliest index. NaN entries never win because every
/// comparison against NaN is false. The result is `0` when the slice is empty
/// or when no entry is greater than negative infinity (for example all NaN).
fn argmax(logits: &[f32]) -> usize {
    let (best_idx, _best_val) = logits.iter().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(lead_idx, lead_val), (idx, &candidate)| {
            // Strict comparison keeps the first maximum and skips NaN.
            if candidate > lead_val {
                (idx, candidate)
            } else {
                (lead_idx, lead_val)
            }
        },
    );
    best_idx
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `argmax` returns the position of the largest element.
    #[test]
    fn test_argmax() {
        let cases: [(&[f32], usize); 3] = [
            (&[1.0, 3.0, 2.0], 1),
            (&[-1.0, -3.0, -0.5], 2),
            (&[5.0], 0),
        ];
        for (logits, expected) in cases {
            assert_eq!(argmax(logits), expected);
        }
    }

    /// Selecting the end-of-sequence token yields `None`.
    #[test]
    fn test_eos_stops() {
        let mut decoder = GreedyDecoder::new(2);
        let outcome = decoder.next_token(&[0.0, 0.0, 10.0, 0.0]);
        assert_eq!(outcome, None);
    }

    /// A non-EOS token is returned as-is.
    #[test]
    fn test_normal_token() {
        let mut decoder = GreedyDecoder::new(2);
        let outcome = decoder.next_token(&[0.0, 10.0, 0.0, 0.0]);
        assert_eq!(outcome, Some(1));
    }

    /// With a limit of 3, the fourth identical token in a row stops decoding.
    #[test]
    fn test_repeat_limit() {
        let mut decoder = GreedyDecoder::new(99).with_max_repeats(3);
        let logits = [0.0, 10.0, 0.0];
        for _ in 0..3 {
            assert_eq!(decoder.next_token(&logits), Some(1));
        }
        assert_eq!(decoder.next_token(&logits), None);
    }

    /// A different token in between restarts the repetition count.
    #[test]
    fn test_repeat_resets_on_different_token() {
        let mut decoder = GreedyDecoder::new(99).with_max_repeats(3);
        let pick_one = [0.0, 10.0, 0.0];
        let pick_zero = [10.0, 0.0, 0.0];

        // Two repeats, then a different token interrupts the run.
        for _ in 0..2 {
            assert_eq!(decoder.next_token(&pick_one), Some(1));
        }
        assert_eq!(decoder.next_token(&pick_zero), Some(0));

        // The run starts over: three more are allowed, the fourth is not.
        for _ in 0..3 {
            assert_eq!(decoder.next_token(&pick_one), Some(1));
        }
        assert_eq!(decoder.next_token(&pick_one), None);
    }

    /// All-NaN logits fall back to index 0.
    #[test]
    fn test_nan_handling() {
        let mut decoder = GreedyDecoder::new(99);
        let outcome = decoder.next_token(&[f32::NAN, f32::NAN, f32::NAN]);
        assert_eq!(outcome, Some(0));
    }
}