use candle_core::{DType, Result, Tensor};
/// Parameters that steer how the next token is picked from a logits vector.
///
/// The `subtalker_*` fields carry a second, independent set of sampling
/// parameters; `for_subtalker` copies them into the primary fields of a new
/// configuration.
#[derive(Clone, Debug)]
pub struct SamplingConfig {
    /// Divisor applied to the logits before filtering; `1.0` is a no-op.
    pub temperature: f64,
    /// Number of highest logits to keep; `0` disables top-k filtering.
    pub top_k: usize,
    /// Cumulative-probability budget for nucleus filtering; values `>= 1.0`
    /// disable it.
    pub top_p: f64,
    /// When `false` the highest-scoring token is returned instead of drawing
    /// a random sample.
    pub do_sample: bool,
    /// Penalty applied to logits of already generated tokens; `1.0` is a
    /// no-op.
    pub repetition_penalty: f64,
    /// Token ids whose logits are forced to negative infinity.
    pub suppress_tokens: Vec<usize>,
    /// While the current step is below this value, `eos_token_id` (if set) is
    /// suppressed as well.
    pub min_new_tokens: usize,
    /// End-of-sequence token id used by the `min_new_tokens` rule.
    pub eos_token_id: Option<usize>,
    /// `do_sample` value for the configuration built by `for_subtalker`.
    pub subtalker_do_sample: bool,
    /// `temperature` value for the configuration built by `for_subtalker`.
    pub subtalker_temperature: f64,
    /// `top_k` value for the configuration built by `for_subtalker`.
    pub subtalker_top_k: usize,
    /// `top_p` value for the configuration built by `for_subtalker`.
    pub subtalker_top_p: f64,
}
impl Default for SamplingConfig {
    /// Stochastic sampling defaults: temperature `0.9`, top-k `50`, nucleus
    /// filtering disabled (`top_p = 1.0`), repetition penalty `1.05`, and the
    /// same temperature / top-k / top-p for the `subtalker_*` fields.
    fn default() -> Self {
        Self {
            // Primary sampling parameters.
            do_sample: true,
            temperature: 0.9,
            top_k: 50,
            top_p: 1.0,
            repetition_penalty: 1.05,

            // Token suppression and end-of-sequence handling.
            suppress_tokens: Vec::new(),
            eos_token_id: None,
            min_new_tokens: 2,

            // Secondary ("subtalker") sampling parameters.
            subtalker_do_sample: true,
            subtalker_temperature: 0.9,
            subtalker_top_k: 50,
            subtalker_top_p: 1.0,
        }
    }
}
impl SamplingConfig {
    /// Deterministic configuration: sampling is switched off for both the
    /// primary and the `subtalker_*` parameters, and temperature, top-k and
    /// repetition penalty are set to their no-op values.
    pub fn greedy() -> Self {
        Self {
            do_sample: false,
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            repetition_penalty: 1.0,
            suppress_tokens: Vec::new(),
            eos_token_id: None,
            min_new_tokens: 2,
            subtalker_do_sample: false,
            subtalker_temperature: 0.9,
            subtalker_top_k: 50,
            subtalker_top_p: 1.0,
        }
    }

    /// Returns the configuration with `eos_token_id` set to the given id.
    pub fn with_eos_token_id(self, eos_token_id: usize) -> Self {
        Self {
            eos_token_id: Some(eos_token_id),
            ..self
        }
    }

    /// Returns the configuration with `min_new_tokens` replaced.
    pub fn with_min_new_tokens(self, min_new_tokens: usize) -> Self {
        Self {
            min_new_tokens,
            ..self
        }
    }

    /// Returns the configuration with `suppress_tokens` replaced by `tokens`.
    pub fn with_suppress_tokens(self, tokens: Vec<usize>) -> Self {
        Self {
            suppress_tokens: tokens,
            ..self
        }
    }

    /// Builds a new configuration whose primary sampling fields are taken
    /// from this configuration's `subtalker_*` fields.
    ///
    /// Repetition penalty, token suppression, `min_new_tokens` and
    /// `eos_token_id` are all reset to their inactive values; the
    /// `subtalker_*` fields themselves are carried over unchanged.
    pub fn for_subtalker(&self) -> Self {
        let &Self {
            subtalker_do_sample,
            subtalker_temperature,
            subtalker_top_k,
            subtalker_top_p,
            ..
        } = self;

        Self {
            temperature: subtalker_temperature,
            top_k: subtalker_top_k,
            top_p: subtalker_top_p,
            do_sample: subtalker_do_sample,
            repetition_penalty: 1.0,
            suppress_tokens: Vec::new(),
            min_new_tokens: 0,
            eos_token_id: None,
            subtalker_do_sample,
            subtalker_temperature,
            subtalker_top_k,
            subtalker_top_p,
        }
    }

    /// Returns the configuration with all four `subtalker_*` fields replaced.
    pub fn with_subtalker_params(
        self,
        do_sample: bool,
        temperature: f64,
        top_k: usize,
        top_p: f64,
    ) -> Self {
        Self {
            subtalker_do_sample: do_sample,
            subtalker_temperature: temperature,
            subtalker_top_k: top_k,
            subtalker_top_p: top_p,
            ..self
        }
    }
}
/// Forces the logits of the listed token ids to negative infinity.
///
/// `logits` is read with `to_vec1::<f32>()`, so it must be a rank-1 `f32`
/// tensor. Ids that fall outside the vector are ignored. When
/// `suppress_tokens` is empty the input tensor is returned as a clone.
pub fn apply_suppress_tokens(logits: &Tensor, suppress_tokens: &[usize]) -> Result<Tensor> {
    if suppress_tokens.is_empty() {
        return Ok(logits.clone());
    }

    let mut values = logits.to_vec1::<f32>()?;
    suppress_tokens.iter().for_each(|&token_id| {
        // `get_mut` returns `None` for ids beyond the vocabulary.
        if let Some(slot) = values.get_mut(token_id) {
            *slot = f32::NEG_INFINITY;
        }
    });

    Tensor::from_vec(values, logits.shape(), logits.device())
}
/// Penalises the logits of tokens that have already been generated.
///
/// For every distinct id in `generated_tokens` that indexes into the logits
/// vector, a positive logit is divided by `penalty` and any other logit is
/// multiplied by it. Each token is penalised exactly once, no matter how
/// often it occurs in `generated_tokens`, so the penalty does not compound
/// with the repeat count.
///
/// `logits` is read with `to_vec1::<f32>()`, so it must be a rank-1 `f32`
/// tensor. Negative or out-of-range ids are ignored. When `penalty` is
/// exactly `1.0` or no tokens have been generated, the input is returned as
/// a clone.
pub fn apply_repetition_penalty(
    logits: &Tensor,
    generated_tokens: &[i64],
    penalty: f64,
) -> Result<Tensor> {
    if penalty == 1.0 || generated_tokens.is_empty() {
        return Ok(logits.clone());
    }

    let factor = penalty as f32;
    let mut values = logits.to_vec1::<f32>()?;
    // Tracks which vocabulary entries were already penalised so that a token
    // appearing several times in the history is only scaled once.
    let mut penalised = vec![false; values.len()];

    for &token_id in generated_tokens {
        // A checked conversion rejects negative ids instead of letting them
        // wrap (or truncate on 32-bit targets) into a valid-looking index.
        let Ok(idx) = usize::try_from(token_id) else {
            continue;
        };
        let Some(value) = values.get_mut(idx) else {
            continue;
        };
        if std::mem::replace(&mut penalised[idx], true) {
            continue;
        }
        if *value > 0.0 {
            *value /= factor;
        } else {
            *value *= factor;
        }
    }

    Tensor::from_vec(values, logits.shape(), logits.device())
}
pub fn apply_temperature(logits: &Tensor, temperature: f64) -> Result<Tensor> {
if temperature == 1.0 {
return Ok(logits.clone());
}
logits.affine(1.0 / temperature, 0.0)
}
/// Keeps the `k` largest logits and sets every smaller one to negative
/// infinity.
///
/// The cut-off is the value of the `k`-th largest logit; entries equal to
/// that value are kept, so ties at the threshold can leave more than `k`
/// finite logits. `k == 0` or `k >= logits.dim(0)` disables the filter and
/// returns the input as a clone.
///
/// `logits` is read with `to_vec1::<f32>()`, so it must be a rank-1 `f32`
/// tensor.
pub fn apply_top_k(logits: &Tensor, k: usize) -> Result<Tensor> {
    if k == 0 || k >= logits.dim(0)? {
        return Ok(logits.clone());
    }

    let values = logits.to_vec1::<f32>()?;

    // Only the k-th largest value is needed, so a selection replaces the full
    // sort. `total_cmp` is a total order even when NaNs are present, which
    // the selection/sort routines require to be guaranteed not to panic.
    // `k - 1` is in bounds because `k < len` was checked above.
    let mut scratch = values.clone();
    let (_, kth_largest, _) = scratch.select_nth_unstable_by(k - 1, |a, b| b.total_cmp(a));
    let threshold = *kth_largest;

    let filtered: Vec<f32> = values
        .into_iter()
        .map(|value| {
            if value < threshold {
                f32::NEG_INFINITY
            } else {
                value
            }
        })
        .collect();

    Tensor::from_vec(filtered, logits.shape(), logits.device())
}
/// Nucleus filter: keeps the most probable tokens until their cumulative
/// softmax probability first exceeds `p`, and sets all remaining logits to
/// negative infinity.
///
/// The token that pushes the cumulative probability past `p` is itself kept.
/// If the running total never exceeds `p`, every token is kept. `p >= 1.0`
/// disables the filter and returns the input as a clone.
///
/// `logits` is read with `to_vec1::<f32>()`, so it must be a rank-1 `f32`
/// tensor.
pub fn apply_top_p(logits: &Tensor, p: f64) -> Result<Tensor> {
    if p >= 1.0 {
        return Ok(logits.clone());
    }

    let raw = logits.to_vec1::<f32>()?;
    let vocab_size = raw.len();

    // Softmax computed on the host, shifted by the maximum logit.
    let peak = raw.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = raw.iter().map(|&logit| (logit - peak).exp()).collect();
    let total: f32 = weights.iter().sum();
    let probs: Vec<f32> = weights.iter().map(|&weight| weight / total).collect();

    // Token ids ordered from most to least probable (stable, so equal
    // probabilities keep their original relative order).
    let mut order: Vec<usize> = (0..vocab_size).collect();
    order.sort_by(|&lhs, &rhs| {
        probs[rhs]
            .partial_cmp(&probs[lhs])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // Number of leading entries of `order` to keep: one past the position at
    // which the running probability mass first exceeds the budget.
    let budget = p as f32;
    let mut mass = 0.0f32;
    let keep = order
        .iter()
        .position(|&token| {
            mass += probs[token];
            mass > budget
        })
        .map_or(vocab_size, |pos| pos + 1);

    let mut filtered = vec![f32::NEG_INFINITY; vocab_size];
    for &token in &order[..keep] {
        filtered[token] = raw[token];
    }

    Tensor::from_vec(filtered, logits.shape(), logits.device())
}
/// Picks the next token, using the number of already generated tokens as the
/// current step.
///
/// This is a convenience wrapper around `sample_token_with_step`.
pub fn sample_token(
    logits: &Tensor,
    config: &SamplingConfig,
    generated_tokens: &[i64],
) -> Result<i64> {
    let current_step = generated_tokens.len();
    sample_token_with_step(logits, config, generated_tokens, current_step)
}
/// Picks the next token from `logits` according to `config`.
///
/// The logits are converted to `f32` and then passed through, in order:
/// repetition penalty, token suppression, temperature, top-k and top-p.
/// While `current_step` is below `config.min_new_tokens`, the configured
/// end-of-sequence token (if any) is added to the suppressed set.
///
/// With `config.do_sample == false` the arg-max of the processed logits is
/// returned; otherwise the logits are turned into probabilities with a
/// softmax and a token is drawn via `sample_from_probs`.
pub fn sample_token_with_step(
    logits: &Tensor,
    config: &SamplingConfig,
    generated_tokens: &[i64],
    current_step: usize,
) -> Result<i64> {
    let scores = logits.to_dtype(DType::F32)?;

    // Effective suppression list: the configured ids, plus EOS while the
    // minimum length has not been reached (unless it is already listed).
    let mut banned = config.suppress_tokens.clone();
    let eos_to_ban = config
        .eos_token_id
        .filter(|eos_id| current_step < config.min_new_tokens && !banned.contains(eos_id));
    banned.extend(eos_to_ban);

    let scores = apply_repetition_penalty(&scores, generated_tokens, config.repetition_penalty)?;
    let scores = apply_suppress_tokens(&scores, &banned)?;
    let scores = apply_temperature(&scores, config.temperature)?;
    let scores = apply_top_k(&scores, config.top_k)?;
    let scores = apply_top_p(&scores, config.top_p)?;

    if config.do_sample {
        let probs = candle_nn::ops::softmax(&scores, 0)?;
        sample_from_probs(&probs)
    } else {
        let best = scores.argmax(0)?.to_scalar::<u32>()?;
        Ok(i64::from(best))
    }
}
/// Draws a token index from a rank-1 `f32` probability tensor.
///
/// One uniform random number is taken from `super::mt_rng::global_uniform`
/// and compared against the running sum of the probabilities (normalised by
/// their total). The first index whose running sum exceeds the random number
/// is returned.
///
/// If the total is zero, negative or not finite, no random number is drawn
/// and the index of the largest entry is returned instead. The same fallback
/// is used when the running sum never exceeds the random number.
fn sample_from_probs(probs: &Tensor) -> Result<i64> {
    /// Index of the largest weight, or `0` for an empty slice. Incomparable
    /// pairs (NaN) are treated as equal.
    fn most_likely(weights: &[f32]) -> i64 {
        weights
            .iter()
            .enumerate()
            .max_by(|(_, lhs), (_, rhs)| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(index, _)| index) as i64
    }

    let weights = probs.to_vec1::<f32>()?;
    let total: f32 = weights.iter().sum();

    let degenerate = !total.is_finite() || total <= 0.0;
    if degenerate {
        return Ok(most_likely(&weights));
    }

    let draw: f32 = super::mt_rng::global_uniform();
    let scale = 1.0 / total;
    let mut running = 0.0f32;
    let hit = weights.iter().position(|&weight| {
        running += weight * scale;
        draw < running
    });

    match hit {
        Some(index) => Ok(index as i64),
        None => Ok(most_likely(&weights)),
    }
}
/// Samples one token per row of a batched logits tensor.
///
/// Row `b` of `logits` (taken along dimension 0) is sampled with
/// `sample_token`, using `generated_tokens[b]` as its history, or an empty
/// history when `generated_tokens` has no entry for that row. The chosen
/// tokens are returned as a tensor with `batch_size` elements on the same
/// device as `logits`. The first error encountered aborts the whole batch.
pub fn sample_token_batch(
    logits: &Tensor,
    config: &SamplingConfig,
    generated_tokens: &[Vec<i64>],
) -> Result<Tensor> {
    let batch_size = logits.dim(0)?;

    let picks = (0..batch_size)
        .map(|row| {
            let row_logits = logits.get(row)?;
            let history: &[i64] = generated_tokens
                .get(row)
                .map(Vec::as_slice)
                .unwrap_or_default();
            sample_token(&row_logits, config, history)
        })
        .collect::<Result<Vec<i64>>>()?;

    Tensor::from_vec(picks, batch_size, logits.device())
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// Builds a rank-1 `f32` tensor on the CPU from `values`.
    fn cpu_logits(values: &[f32]) -> Result<Tensor> {
        Tensor::from_vec(values.to_vec(), values.len(), &Device::Cpu)
    }

    /// True when `value` is negative infinity.
    fn is_neg_inf(value: f32) -> bool {
        value.is_infinite() && value.is_sign_negative()
    }

    /// Temperature `1.0` is a no-op; `0.5` doubles every logit.
    #[test]
    fn test_temperature() -> Result<()> {
        let logits = cpu_logits(&[1.0, 2.0, 3.0])?;

        let unchanged = apply_temperature(&logits, 1.0)?;
        assert_eq!(unchanged.to_vec1::<f32>()?, vec![1.0, 2.0, 3.0]);

        let doubled = apply_temperature(&logits, 0.5)?;
        assert_eq!(doubled.to_vec1::<f32>()?, vec![2.0, 4.0, 6.0]);
        Ok(())
    }

    /// With `k = 2` only the two largest logits survive.
    #[test]
    fn test_top_k() -> Result<()> {
        let logits = cpu_logits(&[1.0, 5.0, 3.0, 4.0, 2.0])?;
        let kept = apply_top_k(&logits, 2)?.to_vec1::<f32>()?;

        assert!(is_neg_inf(kept[0]));
        assert_eq!(kept[1], 5.0);
        assert!(is_neg_inf(kept[2]));
        assert_eq!(kept[3], 4.0);
        assert!(is_neg_inf(kept[4]));
        Ok(())
    }

    /// Positive logits are divided by the penalty, negative ones multiplied,
    /// and tokens that were never generated are left alone.
    #[test]
    fn test_repetition_penalty() -> Result<()> {
        let logits = cpu_logits(&[2.0, -1.0, 3.0])?;
        let penalized = apply_repetition_penalty(&logits, &[0, 1], 2.0)?.to_vec1::<f32>()?;

        assert_eq!(penalized[0], 1.0);
        assert_eq!(penalized[1], -2.0);
        assert_eq!(penalized[2], 3.0);
        Ok(())
    }

    /// The greedy configuration returns the arg-max.
    #[test]
    fn test_greedy_sampling() -> Result<()> {
        let logits = cpu_logits(&[1.0, 5.0, 3.0])?;
        let picked = sample_token(&logits, &SamplingConfig::greedy(), &[])?;
        assert_eq!(picked, 1);
        Ok(())
    }

    /// Default configuration with sampling disabled and unit temperature
    /// still returns the arg-max.
    #[test]
    fn test_sampling_with_temperature() -> Result<()> {
        let logits = cpu_logits(&[1.0, 5.0, 3.0])?;
        let config = SamplingConfig {
            do_sample: false,
            temperature: 1.0,
            ..Default::default()
        };
        let picked = sample_token(&logits, &config, &[])?;
        assert_eq!(picked, 1);
        Ok(())
    }

    /// Only the suppressed entry is replaced by negative infinity.
    #[test]
    fn test_suppress_tokens() -> Result<()> {
        let logits = cpu_logits(&[1.0, 5.0, 3.0, 4.0])?;
        let masked = apply_suppress_tokens(&logits, &[1])?.to_vec1::<f32>()?;

        assert_eq!(masked[0], 1.0);
        assert!(is_neg_inf(masked[1]));
        assert_eq!(masked[2], 3.0);
        assert_eq!(masked[3], 4.0);
        Ok(())
    }

    /// Suppressing the best token makes greedy decoding pick the runner-up.
    #[test]
    fn test_suppress_tokens_in_sampling() -> Result<()> {
        let logits = cpu_logits(&[1.0, 5.0, 3.0])?;
        let config = SamplingConfig::greedy().with_suppress_tokens(vec![1]);
        let picked = sample_token(&logits, &config, &[])?;
        assert_eq!(picked, 2);
        Ok(())
    }

    /// EOS (token 1) is blocked for steps 0 and 1 and allowed from step 2.
    #[test]
    fn test_min_new_tokens() -> Result<()> {
        let logits = cpu_logits(&[1.0, 5.0, 3.0])?;
        let config = SamplingConfig::greedy()
            .with_eos_token_id(1)
            .with_min_new_tokens(2);
        let pick = |history: &[i64], step: usize| {
            sample_token_with_step(&logits, &config, history, step)
        };

        assert_eq!(pick(&[], 0)?, 2);
        assert_eq!(pick(&[2], 1)?, 2);
        assert_eq!(pick(&[2, 2], 2)?, 1);
        Ok(())
    }

    /// Same as `test_min_new_tokens`, but the step is derived from the
    /// history length by `sample_token`.
    #[test]
    fn test_min_new_tokens_with_sample_token() -> Result<()> {
        let logits = cpu_logits(&[1.0, 5.0, 3.0])?;
        let config = SamplingConfig::greedy()
            .with_eos_token_id(1)
            .with_min_new_tokens(2);
        let pick = |history: &[i64]| sample_token(&logits, &config, history);

        assert_eq!(pick(&[])?, 2);
        assert_eq!(pick(&[2])?, 2);
        assert_eq!(pick(&[2, 2])?, 1);
        Ok(())
    }
}