#[allow(clippy::wildcard_imports)]
use super::*;
impl NucleusSampler {
    /// Creates a sampler that keeps the smallest set of tokens whose
    /// cumulative probability reaches `top_p` (nucleus sampling).
    ///
    /// Defaults: temperature 1.0, at least 1 token kept.
    ///
    /// # Panics
    /// Panics unless `top_p` lies in `(0.0, 1.0]`.
    #[must_use]
    pub fn new(top_p: f32) -> Self {
        assert!(top_p > 0.0 && top_p <= 1.0, "top_p must be in (0.0, 1.0]");
        Self {
            top_p,
            temperature: 1.0,
            min_tokens_to_keep: 1,
        }
    }

    /// Sets the softmax temperature (builder style).
    ///
    /// # Panics
    /// Panics unless `temperature` is strictly positive — `filter` divides
    /// logits by it, so 0 or a negative value would yield inf/NaN logits.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        // Consistent with the precondition `apply_temperature` enforces.
        assert!(temperature > 0.0, "Temperature must be positive");
        self.temperature = temperature;
        self
    }

    /// Sets the minimum number of tokens always kept, regardless of `top_p`.
    #[must_use]
    pub fn with_min_tokens_to_keep(mut self, min_tokens: usize) -> Self {
        self.min_tokens_to_keep = min_tokens;
        self
    }

    /// Applies temperature scaling followed by nucleus (top-p) filtering.
    ///
    /// Returns a tensor of the same length in which logits outside the
    /// nucleus are replaced with `f32::NEG_INFINITY`; surviving entries hold
    /// the temperature-scaled logits (not probabilities).
    #[must_use]
    pub fn filter(&self, logits: &Tensor) -> Tensor {
        let vocab_size = logits.data().len();
        let scaled_logits: Vec<f32> = logits
            .data()
            .iter()
            .map(|&x| x / self.temperature)
            .collect();
        // Numerically stable softmax: shift by the max before exponentiating.
        let max_logit = scaled_logits
            .iter()
            .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let exp_logits: Vec<f32> = scaled_logits
            .iter()
            .map(|&x| (x - max_logit).exp())
            .collect();
        let sum: f32 = exp_logits.iter().sum();
        let probs: Vec<f32> = exp_logits.iter().map(|&x| x / sum).collect();
        // Token indices ordered by descending probability.
        let mut indices: Vec<usize> = (0..vocab_size).collect();
        indices.sort_unstable_by(|&a, &b| {
            probs[b]
                .partial_cmp(&probs[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let mut cumsum = 0.0;
        // If the loop never breaks (e.g. float round-off keeps cumsum below
        // top_p), every token survives.
        let mut cutoff_idx = vocab_size;
        for (i, &idx) in indices.iter().enumerate() {
            cumsum += probs[idx];
            // Written as `i + 1 >= min` rather than `i >= min - 1`: the
            // latter underflows usize when min_tokens_to_keep is 0.
            if cumsum >= self.top_p && i + 1 >= self.min_tokens_to_keep {
                cutoff_idx = i + 1;
                break;
            }
        }
        let mut filtered_logits = vec![f32::NEG_INFINITY; vocab_size];
        for &idx in &indices[..cutoff_idx] {
            filtered_logits[idx] = scaled_logits[idx];
        }
        Tensor::new(&filtered_logits, &[vocab_size])
    }

    /// Filters the logits, then samples one token id from the survivors.
    #[must_use]
    pub fn sample(&self, logits: &Tensor) -> usize {
        let filtered = self.filter(logits);
        sample_from_logits(&filtered)
    }

    /// The configured nucleus mass threshold.
    #[must_use]
    pub fn top_p(&self) -> f32 {
        self.top_p
    }

    /// The configured softmax temperature.
    #[must_use]
    pub fn temperature(&self) -> f32 {
        self.temperature
    }
}
// Manual `Debug` listing the three tuning knobs in declaration order.
// NOTE(review): this matches what `#[derive(Debug)]` would emit; the derive
// would have to live on the struct definition (not visible in this chunk).
impl std::fmt::Debug for NucleusSampler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `debug_struct` keeps `{:#?}` alternate formatting working.
        f.debug_struct("NucleusSampler")
            .field("top_p", &self.top_p)
            .field("temperature", &self.temperature)
            .field("min_tokens_to_keep", &self.min_tokens_to_keep)
            .finish()
    }
}
/// Top-k filtering sampler: keeps the `top_k` highest logits and masks the
/// rest with `f32::NEG_INFINITY` before sampling.
///
/// Derives added for ergonomics: both fields are small `Copy` scalars, so
/// the sampler is cheap to copy and compare. `Debug` stays a manual impl
/// (see the `impl std::fmt::Debug for TopKSampler` below).
#[derive(Clone, Copy, PartialEq)]
pub struct TopKSampler {
    // Number of highest-logit tokens to keep.
    top_k: usize,
    // Softmax temperature applied before filtering; 1.0 means no rescaling.
    temperature: f32,
}
impl TopKSampler {
    /// Creates a sampler that keeps only the `top_k` highest logits.
    ///
    /// Default temperature is 1.0 (no rescaling).
    ///
    /// # Panics
    /// Panics if `top_k` is zero.
    #[must_use]
    pub fn new(top_k: usize) -> Self {
        assert!(top_k > 0, "top_k must be > 0");
        Self {
            top_k,
            temperature: 1.0,
        }
    }

    /// Sets the softmax temperature (builder style).
    ///
    /// # Panics
    /// Panics unless `temperature` is strictly positive — `filter` divides
    /// logits by it, so 0 or a negative value would yield inf/NaN logits.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        // Consistent with `apply_temperature` and the other constructors'
        // argument validation.
        assert!(temperature > 0.0, "Temperature must be positive");
        self.temperature = temperature;
        self
    }

    /// Applies temperature scaling, then masks every logit outside the top-k
    /// with `f32::NEG_INFINITY`.
    ///
    /// Returns a tensor of the same length; surviving entries hold the
    /// temperature-scaled logits (not probabilities).
    #[must_use]
    pub fn filter(&self, logits: &Tensor) -> Tensor {
        let vocab_size = logits.data().len();
        // Never keep more tokens than the vocabulary holds.
        let k = self.top_k.min(vocab_size);
        let scaled_logits: Vec<f32> = logits
            .data()
            .iter()
            .map(|&x| x / self.temperature)
            .collect();
        // Token indices ordered by descending scaled logit. `sort_unstable_by`
        // avoids the stable sort's allocation; tie order only matters between
        // equally-scored tokens at the k boundary.
        let mut indices: Vec<usize> = (0..vocab_size).collect();
        indices.sort_unstable_by(|&a, &b| {
            scaled_logits[b]
                .partial_cmp(&scaled_logits[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let mut filtered_logits = vec![f32::NEG_INFINITY; vocab_size];
        for &idx in &indices[..k] {
            filtered_logits[idx] = scaled_logits[idx];
        }
        Tensor::new(&filtered_logits, &[vocab_size])
    }

    /// Filters the logits, then samples one token id from the survivors.
    #[must_use]
    pub fn sample(&self, logits: &Tensor) -> usize {
        let filtered = self.filter(logits);
        sample_from_logits(&filtered)
    }

    /// The configured `k`.
    #[must_use]
    pub fn top_k(&self) -> usize {
        self.top_k
    }
}
// Manual `Debug` listing both tuning knobs in declaration order.
// NOTE(review): output matches what `#[derive(Debug)]` on the struct above
// would emit; the manual impl could be replaced by the derive.
impl std::fmt::Debug for TopKSampler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `debug_struct` keeps `{:#?}` alternate formatting working.
        f.debug_struct("TopKSampler")
            .field("top_k", &self.top_k)
            .field("temperature", &self.temperature)
            .finish()
    }
}
/// Greedy (argmax) decoder: always picks the single most likely token.
///
/// Stateless unit struct — no temperature, no randomness.
#[derive(Debug, Default)]
pub struct GreedyDecoder;

impl GreedyDecoder {
    /// Creates a new greedy decoder (equivalent to `GreedyDecoder::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self
    }

    /// Returns the index of the largest logit.
    ///
    /// Ties and empty input follow `argmax`'s conventions (last maximal
    /// index on ties, 0 for an empty tensor).
    #[must_use]
    pub fn decode(&self, logits: &Tensor) -> usize {
        argmax(logits.data())
    }
}
/// Applies a CTRL-style repetition penalty to logits of already-generated
/// tokens (Keskar et al., 2019).
///
/// Each *distinct* token id in `generated_tokens` is penalized exactly once:
/// positive logits are divided by `penalty`, non-positive logits are
/// multiplied by it — both push probability mass away from repeats when
/// `penalty > 1.0`. Token ids outside the vocabulary are ignored.
///
/// Returns a new tensor with the same shape as `logits`.
#[must_use]
pub fn apply_repetition_penalty(
    logits: &Tensor,
    generated_tokens: &[usize],
    penalty: f32,
) -> Tensor {
    let mut data = logits.data().to_vec();
    // Deduplicate token ids: the previous version applied the penalty once
    // per occurrence, so a token repeated k times was scaled by penalty^k —
    // diverging from the standard formulation, which penalizes each distinct
    // token once.
    let mut seen = std::collections::HashSet::with_capacity(generated_tokens.len());
    for &token_id in generated_tokens {
        // `insert` returns false for ids we have already penalized.
        if token_id < data.len() && seen.insert(token_id) {
            if data[token_id] > 0.0 {
                data[token_id] /= penalty;
            } else {
                data[token_id] *= penalty;
            }
        }
    }
    Tensor::new(&data, logits.shape())
}
/// Returns a copy of `logits` with every value divided by `temperature`.
///
/// Values below 1.0 sharpen the distribution; values above 1.0 flatten it.
///
/// # Panics
/// Panics unless `temperature` is strictly positive.
#[must_use]
pub fn apply_temperature(logits: &Tensor, temperature: f32) -> Tensor {
    assert!(temperature > 0.0, "Temperature must be positive");
    let mut scaled = logits.data().to_vec();
    for value in &mut scaled {
        *value /= temperature;
    }
    Tensor::new(&scaled, logits.shape())
}
/// Draws one token index from the categorical distribution defined by
/// `logits` (softmax sampling via inverse-CDF).
///
/// Returns 0 when the normalizer is not positive (e.g. an empty tensor or
/// all logits at `NEG_INFINITY`), and falls back to the last index if
/// floating-point round-off keeps the running sum below the random draw.
pub(super) fn sample_from_logits(logits: &Tensor) -> usize {
    use rand::Rng;
    let mut rng = rand::rng();
    let values = logits.data();
    // Stable softmax: shift by the maximum before exponentiating.
    let peak = values.iter().fold(f32::NEG_INFINITY, |m, &v| m.max(v));
    let weights: Vec<f32> = values.iter().map(|&v| (v - peak).exp()).collect();
    let norm: f32 = weights.iter().sum();
    if norm <= 0.0 {
        return 0;
    }
    let probs: Vec<f32> = weights.iter().map(|&w| w / norm).collect();
    // Walk the CDF until it passes a uniform draw in [0, 1).
    let draw: f32 = rng.random();
    let mut running = 0.0;
    for (token, &p) in probs.iter().enumerate() {
        running += p;
        if draw < running {
            return token;
        }
    }
    probs.len() - 1
}
/// Index of the maximum value in `data`.
///
/// Ties resolve to the *last* maximal index; an empty slice yields 0.
/// Incomparable values (NaN) are treated as equal, matching the previous
/// `partial_cmp(..).unwrap_or(Equal)` comparator exactly.
pub(super) fn argmax(data: &[f32]) -> usize {
    let mut best: Option<(usize, f32)> = None;
    for (i, &v) in data.iter().enumerate() {
        best = Some(match best {
            None => (i, v),
            Some((bi, bv)) => {
                // Keep the candidate unless the current best is strictly
                // greater — this reproduces `Iterator::max_by`, which returns
                // the last of several equal maxima.
                let keep_old = bv.partial_cmp(&v).unwrap_or(std::cmp::Ordering::Equal)
                    == std::cmp::Ordering::Greater;
                if keep_old { (bi, bv) } else { (i, v) }
            }
        });
    }
    best.map_or(0, |(i, _)| i)
}
/// Configuration for a teacher-forcing ratio schedule during training.
///
/// NOTE(review): the schedule math lives in an impl outside this chunk;
/// field meanings below are inferred from the names — confirm against it.
#[derive(Debug, Clone)]
pub struct TeacherForcing {
    // Decay curve selecting how the ratio moves between the two bounds.
    pub(crate) schedule: TeacherForcingSchedule,
    // Ratio at step 0 — presumably in [0.0, 1.0]; TODO confirm.
    pub(crate) initial_ratio: f32,
    // Ratio the schedule decays toward.
    pub(crate) final_ratio: f32,
    // Number of steps over which the decay runs.
    pub(crate) num_steps: usize,
}
/// Decay curve for the teacher-forcing ratio.
///
/// Variants name the shape of the decay; the exact formulas are implemented
/// alongside `TeacherForcing` (outside this block).
///
/// `Eq` and `Hash` are derived in addition to `PartialEq`: the enum is
/// fieldless, so equality is total, and the derives let the schedule be used
/// as a `HashMap`/`HashSet` key (clippy: `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TeacherForcingSchedule {
    /// Ratio held fixed.
    Constant,
    /// Linear decay curve.
    Linear,
    /// Exponential decay curve.
    Exponential,
    /// Inverse-square-root decay curve.
    InverseSquareRoot,
}
#[cfg(test)]
#[path = "tests_sampling_contract.rs"]
mod tests_sampling_contract;