/// Samples a token index from `logits` using temperature scaling and
/// nucleus (top-p) filtering.
///
/// * `logits` — unnormalized scores, one per vocabulary entry.
/// * `temperature` — values at or below ~0.001 degenerate to greedy
///   argmax; otherwise logits are divided by it before the softmax.
/// * `top_p` — when < 1.0, sampling is restricted to the smallest set of
///   highest-probability tokens whose cumulative mass reaches `top_p`.
///
/// Returns `0` for an empty `logits` slice. The final draw uses
/// `rand::random`, so results are nondeterministic except on the greedy
/// path.
pub fn sample_token(logits: &[f32], temperature: f32, top_p: f32) -> usize {
    let vocab_size = logits.len();
    if vocab_size == 0 {
        return 0;
    }
    // Near-zero temperature is treated as greedy argmax: dividing by a
    // tiny temperature would blow up the scaled logits anyway.
    if temperature <= 0.001 {
        return logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);
    }
    let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
    // Numerically stable softmax: subtract the max before exponentiating
    // so exp() cannot overflow.
    let max_val = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
    let sum: f32 = probs.iter().sum();
    if sum > 0.0 {
        for p in probs.iter_mut() {
            *p /= sum;
        }
    } else {
        // Every exp() underflowed (e.g. all logits are -inf/NaN): fall
        // back to a uniform distribution instead of dividing by zero.
        probs.fill(1.0 / vocab_size as f32);
    }
    if top_p < 1.0 {
        // Nucleus filtering: sort descending by probability and keep the
        // smallest prefix whose cumulative mass first reaches `top_p`.
        let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
        indexed
            .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let mut cumsum = 0.0f32;
        let mut cutoff_idx = indexed.len();
        for (i, &(_, p)) in indexed.iter().enumerate() {
            cumsum += p;
            if cumsum >= top_p {
                // `i + 1` guarantees at least one token survives even for
                // very small `top_p`.
                cutoff_idx = i + 1;
                break;
            }
        }
        // Scatter the surviving probabilities back to vocab order and
        // renormalize so they sum to 1 again.
        let mut new_probs = vec![0.0f32; vocab_size];
        let new_sum: f32 = indexed[..cutoff_idx]
            .iter()
            .map(|&(i, p)| {
                new_probs[i] = p;
                p
            })
            .sum();
        if new_sum > 0.0 {
            for p in new_probs.iter_mut() {
                *p /= new_sum;
            }
        }
        probs = new_probs;
    }
    // Inverse-CDF draw: walk the distribution until the running mass
    // passes a uniform random threshold in [0, 1).
    let r: f32 = rand::random::<f32>();
    let mut cumsum = 0.0f32;
    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        if cumsum >= r {
            return i;
        }
    }
    // Float rounding can leave the total mass slightly below 1.0; fall
    // back to the last token.
    vocab_size - 1
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero temperature must always select the highest-logit token.
    #[test]
    fn test_greedy_sampling() {
        let logits = vec![0.1, 0.1, 5.0, 0.1, 0.1];
        assert_eq!(sample_token(&logits, 0.0, 1.0), 2);
    }

    /// With uniform logits at temperature 1.0, every token should be
    /// drawn a non-trivial share of 1000 samples.
    #[test]
    fn test_temperature_sampling() {
        let logits = vec![1.0, 1.0, 1.0, 1.0];
        let mut counts = [0usize; 4];
        for _ in 0..1000 {
            let token = sample_token(&logits, 1.0, 1.0);
            assert!(token < 4);
            counts[token] += 1;
        }
        for &c in counts.iter() {
            assert!(c > 100, "Token count {} too low", c);
        }
    }

    /// A dominant logit with a tight nucleus must always win the draw.
    #[test]
    fn test_top_p_sampling() {
        let logits = vec![10.0, 0.0, 0.0, 0.0, 0.0];
        assert_eq!(sample_token(&logits, 1.0, 0.5), 0);
    }
}