use crate::error::{CognitionError, Result};
use crate::tensor::Tensor;
/// Knobs controlling how [`sample_token`] / [`sample_token_with_context`] choose a token.
///
/// Filters are applied in the order: repetition penalty, temperature, top-k, top-p.
#[derive(Debug, Clone, PartialEq)]
pub struct SamplingConfig {
    /// Softmax temperature. Values `<= 0.0` switch to greedy (argmax) decoding.
    pub temperature: f64,
    /// Keep only the `top_k` highest logits. `0` disables the filter; `1` is greedy.
    pub top_k: usize,
    /// Nucleus threshold: keep the smallest high-probability prefix whose cumulative mass
    /// reaches `top_p`. Values `>= 1.0` disable the filter.
    pub top_p: f64,
    /// CTRL-style penalty for tokens already present in the context. `1.0` disables it.
    pub repetition_penalty: f64,
}

impl Default for SamplingConfig {
    /// Plain softmax sampling: temperature 1.0, with top-k, top-p and the repetition
    /// penalty all disabled.
    fn default() -> Self {
        Self {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            repetition_penalty: 1.0,
        }
    }
}

impl SamplingConfig {
    /// Deterministic argmax decoding (no randomness is consumed by the sampler).
    pub fn greedy() -> Self {
        Self {
            temperature: 0.0,
            top_k: 1,
            ..Self::default()
        }
    }

    /// A preset for more varied output: slightly cooled temperature, top-40 / top-0.95
    /// filtering and a mild repetition penalty.
    pub fn creative() -> Self {
        Self {
            temperature: 0.8,
            top_k: 40,
            top_p: 0.95,
            repetition_penalty: 1.2,
        }
    }
}
/// Samples a token index from a 1-D `logits` tensor without any repetition history.
///
/// This is [`sample_token_with_context`] with an empty context, so the repetition penalty
/// never applies. It returns the same errors as that function.
pub fn sample_token(
    logits: &Tensor,
    config: &SamplingConfig,
    rng: &mut impl rand::Rng,
) -> Result<usize> {
    let no_history: &[usize] = &[];
    sample_token_with_context(logits, config, no_history, rng)
}
/// Samples a token index from a 1-D `logits` tensor, penalising tokens seen in `context`.
///
/// Pipeline, in order:
/// 1. **Repetition penalty**: each *distinct* in-range token id in `context` has a positive
///    logit divided by `repetition_penalty`, or a non-positive logit multiplied by it.
///    Out-of-range ids are ignored. The penalty value itself is not validated here.
/// 2. **Greedy shortcut**: if `temperature <= 0.0` or `top_k == 1`, returns the argmax of
///    the penalised logits. `top_p` is ignored and `rng` is not used.
/// 3. **Temperature** scaling, then ranking by descending logit (ties: lower index first).
/// 4. **Top-k** truncation when `0 < top_k < len`.
/// 5. **Softmax** over the survivors, then **top-p** truncation when `top_p < 1.0`.
/// 6. **Inverse-CDF** sampling with a single uniform draw from `rng`.
///
/// # Errors
/// - [`CognitionError::DimensionOutOfRange`] if `logits` is not 1-D.
/// - [`CognitionError::EmptyTensor`] if `logits` has no elements.
/// - [`CognitionError::NumericalInstability`] if the softmax normaliser is zero or
///   non-finite. This happens, for example, when NaN/infinite logits survive filtering
///   or `temperature` is NaN.
/// - Any error propagated from `Tensor::new` or `Tensor::argmax`.
pub fn sample_token_with_context(
    logits: &Tensor,
    config: &SamplingConfig,
    context: &[usize],
    rng: &mut impl rand::Rng,
) -> Result<usize> {
    if logits.ndim() != 1 {
        return Err(CognitionError::DimensionOutOfRange {
            dim: 1,
            ndim: logits.ndim(),
            operation: "sample_token",
        });
    }
    let n = logits.numel();
    if n == 0 {
        return Err(CognitionError::EmptyTensor {
            operation: "sample_token",
            reason: "empty logits vector",
        });
    }

    let penalized = if config.repetition_penalty != 1.0 && !context.is_empty() {
        let data = penalize_repeats(logits.data(), context, config.repetition_penalty);
        Tensor::new(data, vec![n])?
    } else {
        logits.clone()
    };

    if config.temperature <= 0.0 || config.top_k == 1 {
        return penalized.argmax();
    }

    let scaled = penalized.scale(1.0 / config.temperature);
    let candidates = ranked_candidates(scaled.data(), config.top_k);

    // Max-subtracted softmax. `candidates` is non-empty because n >= 1 and top-k truncation
    // keeps at least one entry.
    let max_logit = candidates[0].1;
    let exps: Vec<f64> = candidates
        .iter()
        .map(|&(_, logit)| (logit - max_logit).exp())
        .collect();
    let sum: f64 = exps.iter().sum();
    // For a finite maximum the leading term is exp(0) = 1, so the sum can only be degenerate
    // when non-finite values (NaN / ±inf logits, NaN temperature) reach this point.
    if !sum.is_finite() || sum <= 0.0 {
        return Err(CognitionError::NumericalInstability {
            operation: "sample_token",
            detail: "softmax sum is zero or non-finite after filtering".into(),
        });
    }
    let probs: Vec<f64> = exps.iter().map(|e| e / sum).collect();

    let cutoff = nucleus_len(&probs, config.top_p);
    let nucleus = &probs[..cutoff];
    let nucleus_sum: f64 = nucleus.iter().sum();

    let u: f64 = rng.random();
    let mut acc = 0.0;
    for (&(token, _), &p) in candidates.iter().zip(nucleus) {
        acc += p / nucleus_sum;
        if u < acc {
            return Ok(token);
        }
    }
    // Rounding can leave `acc` marginally below 1.0; fall back to the last nucleus member.
    Ok(candidates[cutoff - 1].0)
}

/// Returns a copy of `logits` with the repetition penalty applied once per distinct
/// in-range token id in `context`.
fn penalize_repeats(logits: &[f64], context: &[usize], penalty: f64) -> Vec<f64> {
    let mut data = logits.to_vec();
    // A token repeated k times in the context must be penalised once, not penalty^k times.
    let mut already_penalized = vec![false; data.len()];
    for &tok in context {
        // Ids outside the vocabulary are skipped rather than reported.
        if tok >= data.len() || already_penalized[tok] {
            continue;
        }
        already_penalized[tok] = true;
        let logit = &mut data[tok];
        if *logit > 0.0 {
            *logit /= penalty;
        } else {
            *logit *= penalty;
        }
    }
    data
}

/// Returns `(index, logit)` pairs sorted by descending logit (ties: lower index first),
/// keeping only the best `top_k` when `0 < top_k < scaled.len()`.
fn ranked_candidates(scaled: &[f64], top_k: usize) -> Vec<(usize, f64)> {
    let mut ranked: Vec<(usize, f64)> = scaled.iter().copied().enumerate().collect();
    // `total_cmp` is a true total order even with NaN present (a `partial_cmp` fallback is
    // not, and the std sorts may panic on such comparators). The index tie-break
    // reproduces a stable sort's ordering of equal logits.
    let by_logit_desc =
        |a: &(usize, f64), b: &(usize, f64)| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0));
    if top_k > 0 && top_k < ranked.len() {
        // Partition the best `top_k` to the front in O(n), then only sort those.
        ranked.select_nth_unstable_by(top_k - 1, by_logit_desc);
        ranked.truncate(top_k);
    }
    ranked.sort_unstable_by(by_logit_desc);
    ranked
}

/// Length of the shortest prefix of `probs` (sorted descending) whose cumulative mass
/// reaches `top_p`. Returns all of `probs` when `top_p >= 1.0` or the threshold is never
/// reached.
fn nucleus_len(probs: &[f64], top_p: f64) -> usize {
    if top_p >= 1.0 {
        return probs.len();
    }
    let mut cumulative = 0.0;
    probs
        .iter()
        .position(|&p| {
            cumulative += p;
            cumulative >= top_p
        })
        .map_or(probs.len(), |i| i + 1)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 1-D logits tensor from a slice.
    fn logits(values: &[f64]) -> Tensor {
        Tensor::new(values.to_vec(), vec![values.len()]).unwrap()
    }

    #[test]
    fn test_greedy_picks_max() {
        let mut rng = rand::rng();
        let config = SamplingConfig::greedy();
        let token = sample_token(&logits(&[1.0, 5.0, 2.0, 0.5]), &config, &mut rng).unwrap();
        assert_eq!(token, 1);
    }

    #[test]
    fn test_sampling_returns_valid_index() {
        let mut rng = rand::rng();
        let input = logits(&[1.0, 2.0, 3.0, 4.0]);
        let config = SamplingConfig::creative();
        for _ in 0..100 {
            let token = sample_token(&input, &config, &mut rng).unwrap();
            assert!(token < 4);
        }
    }

    #[test]
    fn test_repetition_penalty() {
        let mut rng = rand::rng();
        let config = SamplingConfig {
            temperature: 0.01,
            top_k: 0,
            top_p: 1.0,
            repetition_penalty: 100.0,
        };
        let token =
            sample_token_with_context(&logits(&[1.0, 10.0, 2.0]), &config, &[1], &mut rng)
                .unwrap();
        assert_eq!(token, 2, "repetition penalty should suppress token 1");
    }

    #[test]
    fn test_out_of_range_context_is_ignored() {
        let mut rng = rand::rng();
        let config = SamplingConfig {
            repetition_penalty: 100.0,
            ..SamplingConfig::greedy()
        };
        let token =
            sample_token_with_context(&logits(&[1.0, 5.0, 2.0]), &config, &[7], &mut rng)
                .unwrap();
        assert_eq!(token, 1);
    }

    #[test]
    fn test_top_k_restricts_candidates() {
        let mut rng = rand::rng();
        let input = logits(&[0.0, 10.0, 9.0, 1.0]);
        let config = SamplingConfig {
            top_k: 2,
            ..SamplingConfig::default()
        };
        for _ in 0..200 {
            let token = sample_token(&input, &config, &mut rng).unwrap();
            assert!(token == 1 || token == 2, "top-2 sampling returned {token}");
        }
    }

    #[test]
    fn test_top_p_keeps_only_dominant_token() {
        let mut rng = rand::rng();
        let input = logits(&[10.0, 0.0, 0.0]);
        let config = SamplingConfig {
            top_p: 0.5,
            ..SamplingConfig::default()
        };
        for _ in 0..50 {
            assert_eq!(sample_token(&input, &config, &mut rng).unwrap(), 0);
        }
    }

    #[test]
    fn test_rejects_non_vector_logits() {
        let mut rng = rand::rng();
        let matrix = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let result = sample_token(&matrix, &SamplingConfig::default(), &mut rng);
        assert!(matches!(
            result,
            Err(CognitionError::DimensionOutOfRange { .. })
        ));
    }

    #[test]
    fn test_low_temperature_is_deterministic() {
        let mut rng = rand::rng();
        let input = logits(&[1.0, 10.0, 2.0]);
        let config = SamplingConfig {
            temperature: 0.01,
            top_k: 0,
            top_p: 1.0,
            repetition_penalty: 1.0,
        };
        let count_max = (0..100)
            .filter(|_| sample_token(&input, &config, &mut rng).unwrap() == 1)
            .count();
        assert!(
            count_max > 95,
            "low temperature picked max {count_max}/100 times"
        );
    }
}