use sapient_core::error::{Result, SapientError};
/// How `Sampler::sample` turns a logit vector into a single token id.
///
/// For every variant that carries a `temperature`, a value `<= 0.0` makes
/// `sample` fall back to greedy (argmax) selection.
#[derive(Debug, Clone)]
pub enum SamplingStrategy {
/// Always pick the index with the highest logit.
Greedy,
/// Divide the logits by the given temperature, apply softmax and draw
/// from the resulting distribution.
Temperature(f32),
/// Temperature scaling followed by top-k filtering: logits below the
/// k-th largest value are masked to negative infinity before softmax.
/// `k == 0` or `k >=` the number of logits disables the filter.
TopK { k: usize, temperature: f32 },
/// Temperature scaling followed by nucleus filtering: tokens are kept in
/// descending probability order until their cumulative probability
/// reaches `p`; the rest are masked to negative infinity.
TopP { p: f32, temperature: f32 },
/// Repetition penalty, temperature scaling, top-k and top-p applied in
/// that order.
Combined {
/// Same meaning as `TopK::k`.
top_k: usize,
/// Same meaning as `TopP::p`.
top_p: f32,
/// Divisor applied to the penalised logits; `<= 0.0` selects argmax of
/// the penalised logits instead of sampling.
temperature: f32,
/// Penalty applied to the logits of tokens listed in `prev_tokens`:
/// non-negative logits are divided by it, negative ones multiplied.
/// A value within `1e-6` of `1.0` disables the penalty.
repetition_penalty: f32,
},
}
impl Default for SamplingStrategy {
fn default() -> Self {
Self::Greedy
}
}
/// Draws token ids from logit vectors according to a `SamplingStrategy`,
/// using a small built-in pseudo-random generator.
pub struct Sampler {
/// Strategy consulted on every call to `sample`.
pub strategy: SamplingStrategy,
/// Base seed combined with `counter` to produce each pseudo-random value.
rng_seed: u64,
/// Number of pseudo-random values drawn so far; bumped before each draw.
counter: u64,
}
impl Sampler {
    /// Creates a sampler seeded from the current wall-clock time.
    ///
    /// The seed is the number of nanoseconds since the Unix epoch truncated
    /// to 64 bits; if the system clock reports a time before the epoch the
    /// fixed seed `42` is used instead.
    pub fn new(strategy: SamplingStrategy) -> Self {
        let seed = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            // Truncation is intentional: only the low bits are needed as entropy.
            .map(|d| d.as_nanos() as u64)
            .unwrap_or(42);
        Self::with_seed(strategy, seed)
    }

    /// Creates a sampler with an explicit seed, giving a reproducible
    /// sequence of draws for identical inputs.
    pub fn with_seed(strategy: SamplingStrategy, seed: u64) -> Self {
        Self {
            strategy,
            rng_seed: seed,
            counter: 0,
        }
    }

    /// Selects the next token id from `logits` according to `self.strategy`.
    ///
    /// `prev_tokens` is only consulted by [`SamplingStrategy::Combined`], where
    /// it feeds the repetition penalty. For every strategy a temperature
    /// `<= 0.0` short-circuits to greedy selection.
    ///
    /// An empty `logits` slice yields token `0` for all strategies.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`; the `Result` is part of
    /// the public signature.
    pub fn sample(&mut self, logits: &[f32], prev_tokens: &[u32]) -> Result<u32> {
        // Each arm either returns the greedy choice directly or produces the
        // filtered logits that the shared softmax + draw tail consumes.
        let candidates = match self.strategy {
            SamplingStrategy::Greedy => return Ok(argmax(logits)),
            SamplingStrategy::Temperature(temperature) => {
                if temperature <= 0.0 {
                    return Ok(argmax(logits));
                }
                scale_logits(logits, temperature)
            }
            SamplingStrategy::TopK { k, temperature } => {
                if temperature <= 0.0 {
                    return Ok(argmax(logits));
                }
                top_k_filter(&scale_logits(logits, temperature), k)
            }
            SamplingStrategy::TopP { p, temperature } => {
                if temperature <= 0.0 {
                    return Ok(argmax(logits));
                }
                top_p_filter(&scale_logits(logits, temperature), p)
            }
            SamplingStrategy::Combined {
                top_k,
                top_p,
                temperature,
                repetition_penalty,
            } => {
                let penalized = apply_repetition_penalty(logits, prev_tokens, repetition_penalty);
                if temperature <= 0.0 {
                    return Ok(argmax(&penalized));
                }
                let scaled = scale_logits(&penalized, temperature);
                let narrowed = top_k_filter(&scaled, top_k);
                top_p_filter(&narrowed, top_p)
            }
        };
        let probs = softmax(&candidates);
        Ok(self.random_sample(&probs))
    }

    /// Returns the next pseudo-random 64-bit value.
    ///
    /// The state is `rng_seed + counter * 6364136223846793005`, passed through
    /// the SplitMix64 output mixing steps.
    fn random_u64(&mut self) -> u64 {
        self.counter = self.counter.wrapping_add(1);
        let mut x = self
            .rng_seed
            .wrapping_add(self.counter.wrapping_mul(6364136223846793005));
        x ^= x >> 30;
        x = x.wrapping_mul(0xbf58476d1ce4e5b9);
        x ^= x >> 27;
        x = x.wrapping_mul(0x94d049bb133111eb);
        x ^= x >> 31;
        x
    }

    /// Returns a uniformly distributed value in `[0, 1)`.
    fn random_f32(&mut self) -> f32 {
        // An f32 has a 24-bit significand. Using exactly the top 24 random
        // bits keeps the integer and the division exact, so the result can
        // never round up to 1.0 (which a 53-bit numerator could).
        (self.random_u64() >> 40) as f32 / (1u32 << 24) as f32
    }

    /// Draws an index from the discrete distribution `probs` by inverting the
    /// cumulative sum.
    ///
    /// Entries that are zero, negative or NaN are never selected while at
    /// least one positive entry exists. If rounding leaves the cumulative sum
    /// below the drawn value, the last positive-probability index is returned
    /// rather than the last index of the slice, which may belong to a token
    /// that was filtered out. With no positive entry the last index is
    /// returned, or `0` for an empty slice.
    fn random_sample(&mut self, probs: &[f32]) -> u32 {
        // Draw first so the random stream advances even for degenerate input.
        let r = self.random_f32();
        let mut cum = 0.0f32;
        let mut last_positive: Option<usize> = None;
        for (i, &p) in probs.iter().enumerate() {
            if p > 0.0 {
                cum += p;
                if r < cum {
                    return i as u32;
                }
                last_positive = Some(i);
            }
        }
        last_positive.unwrap_or_else(|| probs.len().saturating_sub(1)) as u32
    }
}
/// Returns the index of the largest logit.
///
/// NaN entries are ignored, so a stray NaN cannot hide the true maximum.
/// When several entries share the maximum value the last of them wins.
/// Returns `0` when `logits` is empty or contains only NaN.
pub fn argmax(logits: &[f32]) -> u32 {
    logits
        .iter()
        .copied()
        .enumerate()
        // Drop NaN up front: treating it as "equal" to everything would let it
        // replace the running maximum and then lose to any later value.
        .filter(|(_, value)| !value.is_nan())
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(index, _)| index as u32)
        .unwrap_or(0)
}
/// Converts logits into probabilities with the max-subtraction form of
/// softmax: every logit is shifted by the largest one before exponentiation,
/// then the exponentials are normalised by their sum.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&value| (value - peak).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / total).collect()
}
/// Returns a copy of `logits` divided element-wise by `temperature`.
///
/// A temperature that is `<= 0.0` or exactly `1.0` leaves the values
/// untouched.
fn scale_logits(logits: &[f32], temperature: f32) -> Vec<f32> {
    let passthrough = temperature <= 0.0 || temperature == 1.0;
    let mut scaled = logits.to_vec();
    if !passthrough {
        for value in scaled.iter_mut() {
            *value /= temperature;
        }
    }
    scaled
}
/// Masks every logit below the `k`-th largest value to negative infinity.
///
/// * `k == 0` or `k >= logits.len()` disables the filter and returns a copy.
/// * Entries tied with the `k`-th largest value are all kept, so more than
///   `k` entries may survive.
/// * NaN logits rank below every number and are always masked, so a NaN can
///   never become the threshold and wipe out the whole distribution.
fn top_k_filter(logits: &[f32], k: usize) -> Vec<f32> {
    if k == 0 || k >= logits.len() {
        return logits.to_vec();
    }
    // Work on a NaN-free copy so the comparator below is a valid ordering.
    let mut ranked: Vec<f32> = logits
        .iter()
        .map(|&x| if x.is_nan() { f32::NEG_INFINITY } else { x })
        .collect();
    // Only the k-th largest value is needed, so an O(n) selection replaces a
    // full sort. `k - 1` is in bounds because `k < logits.len()`.
    let (_, kth, _) = ranked.select_nth_unstable_by(k - 1, |a, b| {
        b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)
    });
    let threshold = *kth;
    logits
        .iter()
        .map(|&x| if x >= threshold { x } else { f32::NEG_INFINITY })
        .collect()
}
/// Nucleus (top-p) filter: keeps the most probable tokens until their
/// cumulative softmax probability reaches `p` and masks the rest to negative
/// infinity.
///
/// * The token that pushes the cumulative probability to `p` is itself kept,
///   so at least one token survives for non-empty, NaN-free input.
/// * If the cumulative probability never reaches `p` (for example `p > 1.0`
///   or NaN), every token is kept.
/// * Tokens with equal probability are visited in index order.
/// * If the probabilities cannot be computed (softmax yields NaN), the logits
///   are returned unchanged instead of panicking.
fn top_p_filter(logits: &[f32], p: f32) -> Vec<f32> {
    let probs = softmax(logits);
    if probs.iter().any(|q| q.is_nan()) {
        return logits.to_vec();
    }

    // Indices ordered by descending probability; the stable sort keeps index
    // order among ties. No NaN remains, so the comparator is well-defined.
    let mut order: Vec<usize> = (0..probs.len()).collect();
    order.sort_by(|&a, &b| {
        probs[b]
            .partial_cmp(&probs[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // Start fully masked and restore the logits of the nucleus members.
    let mut filtered = vec![f32::NEG_INFINITY; logits.len()];
    let mut cum = 0.0f32;
    for &idx in &order {
        filtered[idx] = logits[idx];
        cum += probs[idx];
        if cum >= p {
            break;
        }
    }
    filtered
}
/// Returns a copy of `logits` with a repetition penalty applied to every
/// token id that occurs in `prev_tokens`.
///
/// * Non-negative logits are divided by `penalty`, negative ones multiplied,
///   so a penalty above `1.0` always lowers the score.
/// * Each distinct token is penalised exactly once, no matter how often it
///   appears in `prev_tokens`; otherwise the penalty would compound with
///   every repetition.
/// * Token ids outside the logit range are ignored.
/// * A penalty within `1e-6` of `1.0`, or one that is non-positive or NaN
///   (which would produce infinities or flip signs), leaves the logits
///   unchanged.
fn apply_repetition_penalty(logits: &[f32], prev_tokens: &[u32], penalty: f32) -> Vec<f32> {
    let mut out = logits.to_vec();
    let is_noop = (penalty - 1.0).abs() < 1e-6 || penalty.is_nan() || penalty <= 0.0;
    if is_noop {
        return out;
    }
    // Tracks which token ids were already penalised.
    let mut seen = vec![false; out.len()];
    for &tok in prev_tokens {
        let idx = tok as usize;
        if idx >= out.len() || seen[idx] {
            continue;
        }
        seen[idx] = true;
        if out[idx] >= 0.0 {
            out[idx] /= penalty;
        } else {
            out[idx] *= penalty;
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Greedy decoding returns the index of the largest logit.
    #[test]
    fn greedy_picks_argmax() {
        let mut sampler = Sampler::with_seed(SamplingStrategy::Greedy, 42);
        let token = sampler.sample(&[0.1, 0.9, 0.3, 0.5], &[]).unwrap();
        assert_eq!(token, 1);
    }

    /// With `k = 1` only the top logit survives; the others are masked.
    #[test]
    fn top_k_removes_low_prob() {
        let filtered = top_k_filter(&[10.0, 1.0, 1.0, 1.0], 1);
        assert_eq!(filtered[0], 10.0);
        assert_eq!(filtered[1], f32::NEG_INFINITY);
    }

    /// A penalty above 1.0 lowers the score of a previously seen token.
    #[test]
    fn repetition_penalty_reduces_score() {
        let logits = [1.0f32, 2.0, 3.0];
        let penalized = apply_repetition_penalty(&logits, &[2], 1.3);
        assert!(logits[2] > penalized[2]);
    }
}