use rand::{Rng, SeedableRng};
use rand::rngs::StdRng;
use serde::{Deserialize, Serialize};
use tracing::debug;
/// Token-selection strategy applied by [`Sampler`].
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Default)]
pub enum SamplingStrategy {
/// Always pick the candidate with the highest probability.
Greedy,
/// Temperature-scaled probabilistic sampling (the default).
#[default]
Temperature,
/// Keep only the `top_k` most probable candidates, then sample.
TopK,
/// Nucleus sampling: keep the smallest prefix of candidates whose
/// cumulative probability exceeds `top_p`, then sample.
TopP,
/// Temperature scaling combined with both top-k and top-p filtering.
TopKP,
}
/// Tunable parameters controlling how tokens are selected.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SamplingConfig {
/// Which sampling strategy to use.
pub strategy: SamplingStrategy,
/// Softmax temperature; applied only for `Temperature` and `TopKP`,
/// and ignored when <= 0.0 (see `Sampler::apply_temperature`).
pub temperature: f32,
/// Number of candidates kept by top-k filtering; 0 disables the filter.
pub top_k: u32,
/// Cumulative-probability cutoff for top-p filtering; only applied when
/// strictly between 0.0 and 1.0.
pub top_p: f32,
/// NOTE(review): declared but never applied anywhere in this file —
/// confirm repetition penalties are handled by the caller.
pub repeat_penalty: f32,
/// Optional RNG seed for reproducible sampling; `None` seeds from OS entropy.
pub seed: Option<u64>,
}
impl Default for SamplingConfig {
fn default() -> Self {
Self {
strategy: SamplingStrategy::Temperature,
temperature: 0.7,
top_k: 40,
top_p: 0.9,
repeat_penalty: 1.1,
seed: None,
}
}
}
/// A single token under consideration, carrying both its raw logit and a
/// caller-supplied probability.
#[derive(Clone, Debug)]
pub struct TokenCandidate {
/// Token id (vocabulary index).
pub id: i32,
/// Raw model logit for this token.
pub logit: f32,
/// Probability for this token; presumably normalized over the candidate
/// set by the caller — TODO confirm against callers.
pub p: f32,
}
/// Stateful token sampler: applies the configured strategy to candidate
/// lists and keeps a bounded history of recently emitted tokens.
pub struct Sampler {
/// Sampling parameters, fixed at construction time.
config: SamplingConfig,
/// Recently sampled token ids (bounded window; see `sample_internal`).
recent_tokens: Vec<i32>,
/// RNG used for probabilistic sampling; seeded from `config.seed` when set.
rng: StdRng,
}
impl Sampler {
    /// Maximum number of recently emitted tokens retained for history queries.
    const HISTORY_LIMIT: usize = 50;

    /// Creates a sampler from `config`, seeding the RNG from `config.seed`
    /// when present and from OS entropy otherwise.
    pub fn new(config: SamplingConfig) -> Self {
        let rng = match config.seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_entropy(),
        };
        Self {
            config,
            recent_tokens: Vec::new(),
            rng,
        }
    }

    /// Convenience entry point accepting `(id, logit, p)` tuples in any
    /// slice-like container.
    ///
    /// Returns the sampled token id, or `None` when the input is empty or
    /// sampling fails (see `sample_internal`).
    pub fn sample_from_candidates<T: AsRef<[(i32, f32, f32)]>>(
        &mut self,
        candidates_data: T,
    ) -> Option<i32> {
        let candidates_ref = candidates_data.as_ref();
        if candidates_ref.is_empty() {
            return None;
        }
        let mut candidates: Vec<TokenCandidate> = candidates_ref
            .iter()
            .map(|&(id, logit, p)| TokenCandidate { id, logit, p })
            .collect();
        self.sample_internal(&mut candidates)
    }

    /// Core sampling pipeline: temperature scaling, optional top-k / top-p
    /// filtering, then greedy or probabilistic selection. Records the chosen
    /// token in the recent-token history.
    fn sample_internal(&mut self, candidates: &mut [TokenCandidate]) -> Option<i32> {
        if candidates.is_empty() {
            return None;
        }
        // Temperature only affects strategies that sample from the softmaxed
        // logits; Greedy/TopK/TopP use the logits as-is.
        if matches!(
            self.config.strategy,
            SamplingStrategy::Temperature | SamplingStrategy::TopKP
        ) {
            Self::apply_temperature(candidates, self.config.temperature);
        }
        // Filters shrink the candidate set, so work on an owned copy.
        let mut adjusted = candidates.to_vec();
        if matches!(
            self.config.strategy,
            SamplingStrategy::TopK | SamplingStrategy::TopKP
        ) && self.config.top_k > 0
        {
            Self::apply_top_k(&mut adjusted, self.config.top_k as usize);
        }
        if matches!(
            self.config.strategy,
            SamplingStrategy::TopP | SamplingStrategy::TopKP
        ) && self.config.top_p > 0.0
            && self.config.top_p < 1.0
        {
            Self::apply_top_p(&mut adjusted, self.config.top_p);
        }
        // NOTE(review): `self.config.repeat_penalty` is never applied in this
        // pipeline — confirm whether repetition penalties are the caller's job.
        let token = match self.config.strategy {
            SamplingStrategy::Greedy => Self::greedy_sample(&adjusted),
            _ => self.probabilistic_sample(&adjusted),
        };
        if let Some(t) = token {
            self.recent_tokens.push(t);
            // Bounded history; the O(n) front removal is fine for a
            // HISTORY_LIMIT-sized window.
            if self.recent_tokens.len() > Self::HISTORY_LIMIT {
                self.recent_tokens.remove(0);
            }
        }
        token
    }

    /// Samples from a borrowed candidate slice without mutating it.
    pub fn sample(&mut self, candidates: &[TokenCandidate]) -> Option<i32> {
        let mut candidates_vec = candidates.to_vec();
        self.sample_internal(&mut candidates_vec)
    }

    /// Divides every logit by `temperature` in place. A non-positive
    /// temperature is treated as "disabled" and leaves logits untouched.
    fn apply_temperature(candidates: &mut [TokenCandidate], temperature: f32) {
        if temperature <= 0.0 {
            return;
        }
        for token in candidates.iter_mut() {
            token.logit /= temperature;
        }
        debug!(
            "Applied temperature scaling: {} (adjusted logits for {} candidates)",
            temperature,
            candidates.len()
        );
    }

    /// Keeps only the `k` candidates with the highest `p`, sorted descending.
    /// No-op when the list already has `k` or fewer entries.
    fn apply_top_k(candidates: &mut Vec<TokenCandidate>, k: usize) {
        if candidates.len() <= k {
            return;
        }
        // Stable sort so ties keep their original relative order before truncation.
        candidates.sort_by(|a, b| b.p.partial_cmp(&a.p).unwrap_or(std::cmp::Ordering::Equal));
        candidates.truncate(k);
        debug!("Applied top-k filtering: kept {} tokens out of original", k);
    }

    /// Nucleus filtering: sorts by `p` descending and keeps the smallest
    /// prefix whose cumulative probability exceeds `p` (always at least one
    /// candidate when the input is non-empty).
    fn apply_top_p(candidates: &mut Vec<TokenCandidate>, p: f32) {
        if candidates.is_empty() {
            return;
        }
        candidates.sort_by(|a, b| b.p.partial_cmp(&a.p).unwrap_or(std::cmp::Ordering::Equal));
        let mut cumsum = 0.0;
        let mut cutoff_idx = candidates.len();
        for (i, token) in candidates.iter().enumerate() {
            cumsum += token.p;
            if cumsum > p {
                // Include the candidate that crossed the threshold.
                cutoff_idx = i + 1;
                break;
            }
        }
        candidates.truncate(cutoff_idx);
        debug!(
            "Applied top-p filtering: kept {} tokens for p={}",
            cutoff_idx, p
        );
    }

    /// Returns the id of the candidate with the highest `p`, or `None` for
    /// an empty slice. NaN probabilities compare as equal.
    fn greedy_sample(candidates: &[TokenCandidate]) -> Option<i32> {
        candidates
            .iter()
            .max_by(|a, b| a.p.partial_cmp(&b.p).unwrap_or(std::cmp::Ordering::Equal))
            .map(|t| t.id)
    }

    /// Softmaxes the candidates' logits (with max-subtraction for numerical
    /// stability) and draws one token via inverse-CDF sampling. Returns
    /// `None` when the slice is empty or the probability mass degenerates.
    fn probabilistic_sample(&mut self, candidates: &[TokenCandidate]) -> Option<i32> {
        if candidates.is_empty() {
            return None;
        }
        let max_logit = candidates
            .iter()
            .map(|c| c.logit)
            .fold(f32::NEG_INFINITY, f32::max);
        let scores: Vec<f32> = candidates
            .iter()
            .map(|c| (c.logit - max_logit).exp())
            .collect();
        let sum: f32 = scores.iter().sum();
        if sum <= 0.0 {
            // All scores underflowed to zero (or NaN poisoned the sum).
            return None;
        }
        let probs: Vec<f32> = scores.iter().map(|s| s / sum).collect();
        let threshold: f32 = self.rng.gen_range(0.0..1.0);
        let mut cumulative = 0.0;
        for (i, prob) in probs.iter().enumerate() {
            cumulative += *prob;
            if cumulative >= threshold {
                return Some(candidates[i].id);
            }
        }
        // Floating-point rounding can leave cumulative < threshold; fall back
        // to the last candidate.
        candidates.last().map(|t| t.id)
    }

    /// Returns the bounded window of recently sampled token ids.
    pub fn get_recent_tokens(&self) -> &[i32] {
        &self.recent_tokens
    }

    /// Clears the recent-token history.
    pub fn clear_history(&mut self) {
        self.recent_tokens.clear();
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Greedy sampling must pick the candidate with the highest probability.
#[test]
fn test_greedy_sampling() {
let candidates = vec![
TokenCandidate {
id: 1,
logit: 0.1,
p: 0.1,
},
TokenCandidate {
id: 2,
logit: 0.5,
p: 0.5,
},
TokenCandidate {
id: 3,
logit: 0.3,
p: 0.3,
},
];
let token = Sampler::greedy_sample(&candidates);
assert_eq!(token, Some(2)); }
// Top-k must keep the k most probable candidates in descending-p order.
#[test]
fn test_top_k_filtering() {
let mut candidates = vec![
TokenCandidate {
id: 1,
logit: 0.1,
p: 0.1,
},
TokenCandidate {
id: 2,
logit: 0.5,
p: 0.5,
},
TokenCandidate {
id: 3,
logit: 0.3,
p: 0.3,
},
TokenCandidate {
id: 4,
logit: 0.05,
p: 0.05,
},
];
Sampler::apply_top_k(&mut candidates, 2);
assert_eq!(candidates.len(), 2);
assert_eq!(candidates[0].id, 2); assert_eq!(candidates[1].id, 3); }
// Top-p keeps the smallest prefix whose cumulative p exceeds the cutoff:
// 0.5 + 0.3 = 0.8 does not exceed 0.8, so the third candidate is included.
#[test]
fn test_top_p_filtering() {
let mut candidates = vec![
TokenCandidate {
id: 1,
logit: 0.0,
p: 0.5,
},
TokenCandidate {
id: 2,
logit: 0.0,
p: 0.3,
},
TokenCandidate {
id: 3,
logit: 0.0,
p: 0.15,
},
TokenCandidate {
id: 4,
logit: 0.0,
p: 0.05,
},
];
Sampler::apply_top_p(&mut candidates, 0.8);
assert_eq!(candidates.len(), 3);
}
// End-to-end: a greedy sampler over a single candidate returns it and
// records it in the recent-token history.
#[test]
fn test_sampler_with_config() {
let config = SamplingConfig {
strategy: SamplingStrategy::Greedy,
temperature: 1.0,
top_k: 40,
top_p: 0.9,
repeat_penalty: 1.1,
seed: None,
};
let mut sampler = Sampler::new(config);
let candidates = vec![TokenCandidate {
id: 5,
logit: 0.8,
p: 0.8,
}];
let token = sampler.sample(&candidates);
assert_eq!(token, Some(5));
assert_eq!(sampler.get_recent_tokens(), &[5]);
}
// Dividing positive logits by a temperature > 1 must shrink every logit.
#[test]
fn test_temperature_scaling() {
let mut candidates = vec![
TokenCandidate {
id: 1,
logit: 2.0,
p: 0.1,
},
TokenCandidate {
id: 2,
logit: 1.0,
p: 0.5,
},
];
let original_logits = candidates.iter().map(|c| c.logit).collect::<Vec<_>>();
Sampler::apply_temperature(&mut candidates, 2.0);
let scaled_logits = candidates.iter().map(|c| c.logit).collect::<Vec<_>>();
for (original, scaled) in original_logits.iter().zip(scaled_logits.iter()) {
assert!(scaled < original);
}
}
}