use rand::Rng;
/// Tunable parameters read by [`sample`] when turning logits into a token index.
#[derive(Clone, Debug)]
pub struct SamplerConfig {
    /// Divisor applied to every score before the softmax.
    ///
    /// Only used when it is positive and different from `1.0`; any other value
    /// leaves the scores unscaled.
    pub temperature: f32,
    /// Number of highest-scoring candidates kept before the softmax.
    ///
    /// `0`, or a value greater than or equal to the number of logits, disables
    /// the filter.
    pub top_k: usize,
    /// Penalty applied to the scores of previously generated tokens.
    ///
    /// Positive scores are divided by it and all other scores are multiplied by
    /// it; a value of exactly `1.0` skips the step.
    pub repetition_penalty: f32,
}
/// Draws one token index from `logits` after applying the configured sampling transforms.
///
/// The transforms run in this order:
///
/// 1. Every index for which `mask` returns `true` has its score set to negative infinity.
/// 2. If `config.repetition_penalty` is not `1.0`, each entry of `past_tokens` that is a
///    valid index has its score divided by the penalty when the score is positive and
///    multiplied by it otherwise. A token that occurs several times in `past_tokens` is
///    penalised once per occurrence. Negative or out-of-range entries are ignored.
/// 3. If `config.temperature` is positive and not `1.0`, every score is divided by it.
///    Non-positive temperatures leave the scores unscaled.
/// 4. If `config.top_k` is non-zero and smaller than `logits.len()`, all but the `top_k`
///    highest scores are set to negative infinity. Which of several equal scores survive
///    at the cut-off is unspecified.
/// 5. A softmax converts the scores to probabilities and one index is drawn from them.
///
/// # Returns
///
/// The sampled index. A candidate whose probability is zero is never chosen by the draw.
/// If the cumulative scan selects nothing (rounding left the running total at or below the
/// drawn value, or the probabilities are NaN because no candidate kept a finite score),
/// the function falls back to the entry with the largest probability; in the NaN case that
/// fallback index is not meaningful. For empty `logits` the result is `0`.
pub fn sample(
    logits: &[f32],
    config: &SamplerConfig,
    past_tokens: &[i64],
    mask: impl Fn(usize) -> bool,
) -> usize {
    let mut scores: Vec<f32> = logits.to_vec();

    // Masked candidates get -inf so that the softmax assigns them probability zero.
    for (i, s) in scores.iter_mut().enumerate() {
        if mask(i) {
            *s = f32::NEG_INFINITY;
        }
    }

    if config.repetition_penalty != 1.0 {
        for &tok in past_tokens {
            // Range-check in u64 before casting: a plain `as usize` would wrap negative ids
            // and, on 32-bit targets, truncate wide ids onto an unrelated valid index.
            if tok < 0 || (tok as u64) >= (scores.len() as u64) {
                continue;
            }
            let idx = tok as usize;
            // Dividing positive scores and multiplying the rest both move the score down
            // for penalties greater than one.
            if scores[idx] > 0.0 {
                scores[idx] /= config.repetition_penalty;
            } else {
                scores[idx] *= config.repetition_penalty;
            }
        }
    }

    if config.temperature != 1.0 && config.temperature > 0.0 {
        for s in &mut scores {
            *s /= config.temperature;
        }
    }

    if config.top_k > 0 && config.top_k < scores.len() {
        let mut ranked: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
        // Partition in descending order so that positions `top_k..` hold everything outside
        // the top-k. `total_cmp` is a total order even when NaN is present, which the
        // partitioning routine requires; a full sort is not needed.
        ranked.select_nth_unstable_by(config.top_k, |a, b| b.1.total_cmp(&a.1));
        for &(i, _) in &ranked[config.top_k..] {
            scores[i] = f32::NEG_INFINITY;
        }
    }

    // Softmax, shifted by the maximum score for numerical stability.
    let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = scores.iter().map(|&s| (s - max).exp()).collect();
    let sum: f32 = probs.iter().sum();
    if sum > 0.0 {
        for p in &mut probs {
            *p /= sum;
        }
    }

    let mut rng = rand::rng();
    let r: f32 = rng.random::<f32>();
    let mut accum = 0.0_f32;
    for (i, &p) in probs.iter().enumerate() {
        accum += p;
        // Strict comparison: the draw may be exactly 0.0, and `accum >= r` would then
        // accept a leading candidate whose probability is zero (masked or filtered out).
        if r < accum {
            return i;
        }
    }

    // Reached only when the scan above selected nothing; see the function docs.
    probs
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i)
        .unwrap_or(0)
}
/// Mask predicate that returns `true` for every token id in `2048..3072` except `2150`.
///
/// Ids below 2048 and ids from 3072 upwards are never masked. The tests in this file
/// refer to ids below 2048 as codebook tokens and to 2150 as EOS.
pub fn talker_mask(token: usize) -> bool {
    // The two inclusive ranges cover 2048..3072 with the single id 2150 left out.
    matches!(token, 2048..=2149 | 2151..=3071)
}
/// Mask predicate that never masks: returns `false` for every token id.
///
/// Intended to be passed as the `mask` argument of [`sample`] when no candidate
/// should be excluded.
pub fn no_mask(_token: usize) -> bool {
    false
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a configuration with unit temperature, no repetition penalty and the given
    /// `top_k`, so that only masking and top-k filtering influence the draw.
    fn neutral_config(top_k: usize) -> SamplerConfig {
        SamplerConfig {
            temperature: 1.0,
            top_k,
            repetition_penalty: 1.0,
        }
    }

    /// No id below 2048 may be masked.
    #[test]
    fn talker_mask_allows_codebook_tokens() {
        (0..2048).for_each(|i| {
            assert!(!talker_mask(i), "token {i} should not be masked");
        });
    }

    /// Ids inside the masked range are blocked, with 2150 as the only exception.
    #[test]
    fn talker_mask_blocks_control_tokens_except_eos() {
        assert!(!talker_mask(2150), "EOS should not be masked");
        assert!(talker_mask(2148), "PAD should be masked");
        assert!(talker_mask(2149), "BOS should be masked");
        assert!(talker_mask(2154), "THINK should be masked");
    }

    /// With uniform logits and no mask the result must still be a valid index.
    #[test]
    fn sample_returns_valid_index() {
        let uniform = [0.0_f32; 100];
        let picked = sample(&uniform, &neutral_config(50), &[], no_mask);
        assert!(picked < 100);
    }

    /// When every index except 5 is masked, 5 is the only index that can be drawn.
    #[test]
    fn sample_respects_mask() {
        let uniform = [0.0_f32; 10];
        let everything_but_five = |token: usize| token != 5;
        let picked = sample(&uniform, &neutral_config(0), &[], everything_but_five);
        assert_eq!(picked, 5);
    }
}