/// Deterministic pseudo-random generator driven by a 64-bit linear
/// congruential recurrence: `state = state * MULTIPLIER + INCREMENT`, using
/// wrapping arithmetic.
///
/// The same seed always produces the same sequence.
#[derive(Debug, Clone)]
pub struct LcgRng {
    /// Current value of the recurrence; `next_u64` returns it after each step.
    state: u64,
}

impl LcgRng {
    /// Multiplier of the recurrence.
    const MULTIPLIER: u64 = 6364136223846793005;
    /// Additive constant of the recurrence.
    const INCREMENT: u64 = 1442695040888963407;

    /// Creates a generator whose initial state is `seed` plus the increment,
    /// times the multiplier (both operations wrapping).
    pub fn new(seed: u64) -> Self {
        let state = seed
            .wrapping_add(Self::INCREMENT)
            .wrapping_mul(Self::MULTIPLIER);
        Self { state }
    }

    /// Advances the recurrence by one step and returns the new 64-bit state.
    pub fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.state
    }

    /// Returns a value in `[0, 1)` built from the top 24 bits of the next
    /// state.
    pub fn next_f32(&mut self) -> f32 {
        let bits = (self.next_u64() >> 40) as u32;
        bits as f32 / (1u32 << 24) as f32
    }

    /// Returns an index in `0..n`.
    ///
    /// The raw output is scaled into range with a widening multiply, so the
    /// result is driven by the high bits of the state. A plain `% n` would use
    /// the low bits, whose period is tiny for this kind of generator (the
    /// lowest bit simply flips on every step, so `n == 2` would alternate).
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero.
    pub fn next_usize_below(&mut self, n: usize) -> usize {
        assert!(n > 0, "n must be greater than zero");
        let scaled = u128::from(self.next_u64()) * (n as u128);
        // `scaled >> 64` is strictly less than `n`, so it always fits a usize.
        (scaled >> 64) as usize
    }
}
/// Replaces `logits` with their softmax, in place.
///
/// The largest logit is subtracted before exponentiating so the exponentials
/// cannot overflow. If the exponentials do not add up to a positive number the
/// values are left un-normalised. An empty slice is a no-op.
pub fn softmax_inplace(logits: &mut [f32]) {
    if logits.is_empty() {
        return;
    }

    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Exponentiate and accumulate the normaliser in a single pass.
    let mut total = 0.0_f32;
    logits.iter_mut().for_each(|x| {
        *x = (*x - peak).exp();
        total += *x;
    });

    if total > 0.0 {
        logits.iter_mut().for_each(|x| *x /= total);
    }
}
/// Returns the log-softmax of `logits` as a new vector.
///
/// Computed as `x - log_sum_exp(logits)`, where the log-sum-exp is evaluated
/// relative to the largest logit. An empty input yields an empty vector.
pub fn log_softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }

    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut exp_total = 0.0_f32;
    for &x in logits {
        exp_total += (x - peak).exp();
    }
    let normaliser = exp_total.ln() + peak;

    let mut out = Vec::with_capacity(logits.len());
    out.extend(logits.iter().map(|&x| x - normaliser));
    out
}
/// Returns the entropy of `probs` in nats: the sum of `-p * ln(p)` over every
/// strictly positive entry. Zero and negative entries contribute nothing.
pub fn entropy(probs: &[f32]) -> f32 {
    probs
        .iter()
        .filter_map(|&p| if p > 0.0 { Some(-p * p.ln()) } else { None })
        .sum::<f32>()
}
/// Returns `exp` of the mean negated log-probability in `log_probs`.
///
/// An empty slice yields `1.0`.
pub fn perplexity(log_probs: &[f32]) -> f32 {
    let count = log_probs.len();
    if count == 0 {
        return 1.0;
    }

    let negated_total = log_probs.iter().map(|lp| -*lp).sum::<f32>();
    (negated_total / count as f32).exp()
}
/// Returns the indices of the `k` largest values in `logits`, largest first.
///
/// * `k` is clamped to `logits.len()`; `k == 0` or an empty slice yields an
///   empty vector.
/// * Equal values keep their original relative order (lower index first).
/// * NaN entries are ranked after every real value, so they are returned only
///   when `k` exceeds the number of non-NaN entries. Treating NaN as "equal to
///   everything" instead would not be a consistent ordering, and the sort
///   could then rank a NaN above genuine logits.
pub fn top_k_indices(logits: &[f32], k: usize) -> Vec<usize> {
    use std::cmp::Ordering;

    /// Descending order on real values, with NaN placed last.
    fn descending_nan_last(a: f32, b: f32) -> Ordering {
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }

    if k == 0 || logits.is_empty() {
        return Vec::new();
    }

    let mut order: Vec<usize> = (0..logits.len()).collect();
    // Stable sort: ties stay in ascending index order.
    order.sort_by(|&a, &b| descending_nan_last(logits[a], logits[b]));
    order.truncate(k.min(logits.len()));
    order
}
/// Divides every logit by `temp`, in place.
///
/// Does nothing unless `temp` is strictly positive (zero, negative and NaN
/// temperatures leave `logits` untouched).
pub fn apply_temperature(logits: &mut [f32], temp: f32) {
    if temp > 0.0 {
        logits.iter_mut().for_each(|logit| *logit /= temp);
    }
}
/// Penalises, in place, the logits of tokens listed in `token_ids`.
///
/// A non-negative logit is divided by `penalty` and a negative logit is
/// multiplied by it, so a `penalty` above 1 makes the token less likely either
/// way. Ids outside `logits` are ignored. Nothing happens when `penalty` is
/// exactly `1.0` or `token_ids` is empty.
///
/// Each distinct id is penalised exactly once, no matter how often it occurs
/// in `token_ids`; otherwise a token seen `k` times would be scaled by
/// `penalty^k`.
pub fn apply_repetition_penalty(logits: &mut [f32], token_ids: &[u32], penalty: f32) {
    if penalty == 1.0 || token_ids.is_empty() {
        return;
    }

    // Work on a sorted, de-duplicated copy so every id is visited once.
    let mut unique_ids = token_ids.to_vec();
    unique_ids.sort_unstable();
    unique_ids.dedup();

    for id in unique_ids {
        if let Some(logit) = logits.get_mut(id as usize) {
            if *logit >= 0.0 {
                *logit /= penalty;
            } else {
                *logit *= penalty;
            }
        }
    }
}
/// Draws one index from `probs`, a list of `(index, probability)` pairs.
///
/// A uniform value `u` in `[0, 1)` is taken from `rng` and the pairs are
/// walked in order, returning the first index whose running total exceeds `u`.
///
/// If the walk finishes without a hit (the probabilities add up to slightly
/// less than one because of rounding, or were never normalised), the last
/// entry with a positive probability is returned, since that is where the
/// missing mass belongs. Falling back to the first entry could select an index
/// whose probability is zero. With no positive entry at all the first index is
/// returned, and `0` for an empty list.
fn categorical_sample(probs: &[(usize, f32)], rng: &mut LcgRng) -> usize {
    let u = rng.next_f32();
    let mut cumsum = 0.0_f32;
    for &(idx, p) in probs {
        cumsum += p;
        if u < cumsum {
            return idx;
        }
    }

    probs
        .iter()
        .rev()
        .find(|&&(_, p)| p > 0.0)
        .or_else(|| probs.first())
        .map_or(0, |&(idx, _)| idx)
}
/// Stateful sampler that only draws from tokens whose surprise (`-log2(p)`)
/// does not exceed a running budget `mu`, and nudges `mu` after every draw so
/// that the observed surprise tracks the target `tau`.
#[derive(Debug, Clone)]
pub struct MirostatV1Sampler {
    /// Target surprise.
    pub tau: f32,
    /// Step size applied when `mu` is adjusted.
    pub eta: f32,
    /// Maximum number of highest-logit tokens taken into account.
    pub m: usize,
    /// Current surprise budget; starts at `2 * tau`.
    mu: f32,
}

impl MirostatV1Sampler {
    /// Creates a sampler with the budget `mu` initialised to `2 * tau`.
    pub fn new(tau: f32, eta: f32, m: usize) -> Self {
        let mu = 2.0 * tau;
        Self { tau, eta, m, mu }
    }

    /// Samples a token index from `logits` and updates the internal budget.
    ///
    /// Returns `0` for empty `logits`.
    pub fn sample(&mut self, logits: &[f32], rng: &mut LcgRng) -> usize {
        if logits.is_empty() {
            return 0;
        }

        // Keep the `m` largest logits (at least one), best first.
        let mut ranked: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
        ranked.sort_by(|lhs, rhs| {
            rhs.1
                .partial_cmp(&lhs.1)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let keep = self.m.min(ranked.len()).max(1);
        ranked.truncate(keep);

        // Softmax over the survivors, relative to the best logit.
        let top = ranked[0].1;
        let mut mass = 0.0_f32;
        for entry in ranked.iter_mut() {
            entry.1 = (entry.1 - top).exp();
            mass += entry.1;
        }
        if mass > 0.0 {
            ranked.iter_mut().for_each(|entry| entry.1 /= mass);
        }

        // Drop tokens that are more surprising than the budget allows; if that
        // removes everything, fall back to the full ranked list.
        let budget = self.mu;
        let within_budget: Vec<(usize, f32)> = ranked
            .iter()
            .copied()
            .filter(|&(_, p)| p > 0.0 && -p.log2() <= budget)
            .collect();
        let pool: &[(usize, f32)] = if within_budget.is_empty() {
            ranked.as_slice()
        } else {
            within_budget.as_slice()
        };

        let pool_mass = pool.iter().map(|&(_, p)| p).sum::<f32>();
        let renormalised: Vec<(usize, f32)> = if pool_mass > 0.0 {
            pool.iter()
                .map(|&(token, p)| (token, p / pool_mass))
                .collect()
        } else {
            pool.to_vec()
        };

        let chosen = categorical_sample(&renormalised, rng);

        // Move the budget against the gap between observed and target surprise.
        let chosen_prob = renormalised
            .iter()
            .find(|entry| entry.0 == chosen)
            .map(|entry| entry.1);
        match chosen_prob {
            Some(p) if p > 0.0 => {
                let surprise = -p.log2();
                self.mu -= self.eta * (surprise - self.tau);
            }
            _ => {}
        }

        chosen
    }

    /// Restores the budget `mu` to its initial value, `2 * tau`.
    pub fn reset(&mut self) {
        self.mu = 2.0 * self.tau;
    }
}
/// Stateful sampler that keeps only tokens whose probability is at least
/// `2^-mu`, samples among them, and then adjusts `mu` so that the observed
/// surprise (`-log2(p)`) tracks the target `tau`.
#[derive(Debug, Clone)]
pub struct MirostatV2Sampler {
    /// Target surprise.
    pub tau: f32,
    /// Step size applied when `mu` is adjusted.
    pub eta: f32,
    /// Current surprise budget; starts at `2 * tau`.
    mu: f32,
}

impl MirostatV2Sampler {
    /// Creates a sampler with the budget `mu` initialised to `2 * tau`.
    pub fn new(tau: f32, eta: f32) -> Self {
        let mu = 2.0 * tau;
        Self { tau, eta, mu }
    }

    /// Samples a token index from `logits` and updates the internal budget.
    ///
    /// Returns `0` for empty `logits`.
    pub fn sample(&mut self, logits: &[f32], rng: &mut LcgRng) -> usize {
        if logits.is_empty() {
            return 0;
        }

        // Softmax over the whole vocabulary, keeping each token's index.
        let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut mass = 0.0_f32;
        let mut probs: Vec<(usize, f32)> = logits
            .iter()
            .enumerate()
            .map(|(token, &logit)| {
                let weight = (logit - peak).exp();
                mass += weight;
                (token, weight)
            })
            .collect();
        if mass > 0.0 {
            for entry in probs.iter_mut() {
                entry.1 /= mass;
            }
        }

        // `exp(-mu * ln 2)` is `2^-mu`: the probability whose surprise equals
        // the current budget.
        let cutoff = (-self.mu * std::f32::consts::LN_2).exp();
        let mut pool: Vec<(usize, f32)> = probs
            .iter()
            .copied()
            .filter(|entry| entry.1 >= cutoff)
            .collect();
        if pool.is_empty() {
            // Nothing fits the budget: keep the single most likely token.
            probs.sort_by(|lhs, rhs| {
                rhs.1
                    .partial_cmp(&lhs.1)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            pool.push(probs[0]);
        }

        let pool_mass = pool.iter().map(|entry| entry.1).sum::<f32>();
        if pool_mass > 0.0 {
            pool.iter_mut().for_each(|entry| entry.1 /= pool_mass);
        }

        let chosen = categorical_sample(&pool, rng);

        // Move the budget against the gap between observed and target surprise.
        let observed = pool
            .iter()
            .find(|entry| entry.0 == chosen)
            .map(|entry| entry.1);
        if let Some(p) = observed {
            if p > 0.0 {
                let surprise = -p.log2();
                self.mu -= self.eta * (surprise - self.tau);
            }
        }

        chosen
    }

    /// Restores the budget `mu` to its initial value, `2 * tau`.
    pub fn reset(&mut self) {
        self.mu = 2.0 * self.tau;
    }

    /// Returns the current surprise budget.
    pub fn mu(&self) -> f32 {
        self.mu
    }
}
/// Sampler that prefers tokens whose surprise (`-ln p`) is closest to the
/// entropy of the distribution, keeping the closest ones until their combined
/// probability reaches `p`.
#[derive(Debug, Clone)]
pub struct TypicalSampler {
    /// Probability mass to cover; clamped to `[0, 1]` by `new`.
    pub p: f32,
    /// Minimum number of tokens to keep; raised to at least 1 by `new`.
    pub min_keep: usize,
}

impl TypicalSampler {
    /// Creates a sampler, clamping `p` to `[0, 1]` and `min_keep` to at
    /// least 1.
    pub fn new(p: f32, min_keep: usize) -> Self {
        let p = p.clamp(0.0, 1.0);
        let min_keep = min_keep.max(1);
        Self { p, min_keep }
    }

    /// Samples a token index from `logits`. Returns `0` for empty `logits`.
    pub fn sample(&self, logits: &[f32], rng: &mut LcgRng) -> usize {
        if logits.is_empty() {
            return 0;
        }

        let log_probs = log_softmax(logits);
        let probs: Vec<f32> = log_probs.iter().map(|lp| lp.exp()).collect();
        let h = entropy(&probs);

        // (token, probability, distance between surprise and entropy),
        // ordered from most to least typical.
        let mut ranked: Vec<(usize, f32, f32)> = (0..probs.len())
            .map(|i| (i, probs[i], (-log_probs[i] - h).abs()))
            .collect();
        ranked.sort_by(|lhs, rhs| {
            lhs.2
                .partial_cmp(&rhs.2)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Shortest prefix that covers mass `p` and holds `min_keep` entries;
        // the whole list if no prefix qualifies.
        let mut covered = 0.0_f32;
        let reached = ranked.iter().enumerate().find_map(|(pos, entry)| {
            covered += entry.1;
            let kept = pos + 1;
            if covered >= self.p && kept >= self.min_keep {
                Some(kept)
            } else {
                None
            }
        });
        let keep = reached
            .unwrap_or(ranked.len())
            .max(self.min_keep)
            .min(ranked.len());
        ranked.truncate(keep);

        let total = ranked.iter().map(|entry| entry.1).sum::<f32>();
        let rescale = total > 0.0;
        let normalised: Vec<(usize, f32)> = ranked
            .iter()
            .map(|&(token, p, _)| (token, if rescale { p / total } else { p }))
            .collect();

        categorical_sample(&normalised, rng)
    }
}
/// Sampler that discards tokens whose probability is below an
/// entropy-dependent floor, `max(epsilon, sqrt(exp(-H)) * delta)`, where `H`
/// is the entropy of the softmax distribution.
#[derive(Debug, Clone)]
pub struct EtaSampler {
    /// Lower bound on the probability floor.
    pub epsilon: f32,
    /// Scale applied to the entropy-dependent part of the floor.
    pub delta: f32,
}

impl EtaSampler {
    /// Creates a sampler from the two floor parameters, stored as given.
    pub fn new(epsilon: f32, delta: f32) -> Self {
        Self { epsilon, delta }
    }

    /// Samples a token index from `logits`. Returns `0` for empty `logits`.
    ///
    /// If no token reaches the floor, the most likely token is returned
    /// without consulting `rng`.
    pub fn sample(&self, logits: &[f32], rng: &mut LcgRng) -> usize {
        if logits.is_empty() {
            return 0;
        }

        let mut probs = logits.to_vec();
        softmax_inplace(&mut probs);

        let h = entropy(&probs);
        let floor = self.epsilon.max((-h).exp().sqrt() * self.delta);

        let mut survivors: Vec<(usize, f32)> = Vec::new();
        for (token, &p) in probs.iter().enumerate() {
            if p >= floor {
                survivors.push((token, p));
            }
        }

        if survivors.is_empty() {
            return probs
                .iter()
                .copied()
                .enumerate()
                .max_by(|lhs, rhs| {
                    lhs.1
                        .partial_cmp(&rhs.1)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .map_or(0, |(token, _)| token);
        }

        let total = survivors.iter().map(|&(_, p)| p).sum::<f32>();
        if total > 0.0 {
            for entry in survivors.iter_mut() {
                entry.1 /= total;
            }
        }

        categorical_sample(&survivors, rng)
    }
}
/// Sampler that keeps tokens whose probability is at least `min_p` times the
/// probability of the most likely token.
#[derive(Debug, Clone)]
pub struct MinPSampler {
    /// Fraction of the top probability a token must reach; clamped to
    /// `[0, 1]` by `new`.
    pub min_p: f32,
    /// Minimum number of tokens to keep; raised to at least 1 by `new`.
    pub min_keep: usize,
}

impl MinPSampler {
    /// Creates a sampler, clamping `min_p` to `[0, 1]` and `min_keep` to at
    /// least 1.
    pub fn new(min_p: f32, min_keep: usize) -> Self {
        let min_p = min_p.clamp(0.0, 1.0);
        let min_keep = min_keep.max(1);
        Self { min_p, min_keep }
    }

    /// Samples a token index from `logits`. Returns `0` for empty `logits`.
    pub fn sample(&self, logits: &[f32], rng: &mut LcgRng) -> usize {
        if logits.is_empty() {
            return 0;
        }

        let mut probs = logits.to_vec();
        softmax_inplace(&mut probs);

        let top = probs.iter().copied().fold(0.0_f32, f32::max);
        let floor = self.min_p * top;

        let mut kept: Vec<(usize, f32)> = probs
            .iter()
            .copied()
            .enumerate()
            .filter(|entry| entry.1 >= floor)
            .collect();

        if kept.len() < self.min_keep {
            // Too few tokens passed: use the `min_keep` most likely instead.
            let mut ranked: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
            ranked.sort_by(|lhs, rhs| {
                rhs.1
                    .partial_cmp(&lhs.1)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            ranked.truncate(self.min_keep);
            kept = ranked;
        }

        let total = kept.iter().map(|entry| entry.1).sum::<f32>();
        if total > 0.0 {
            kept.iter_mut().for_each(|entry| entry.1 /= total);
        }

        categorical_sample(&kept, rng)
    }
}
/// One stage of a `SamplerChain`. Stages run in the order they were added;
/// some only transform the logits, others pick a token and end the chain.
#[derive(Debug, Clone)]
pub enum SamplerStep {
/// Divides the logits by the given temperature. A temperature below `1e-6`
/// makes the chain return the arg-max immediately instead.
Temperature(f32),
/// Penalises the logits of previously seen tokens.
RepetitionPenalty {
/// Penalty factor handed to `apply_repetition_penalty`.
penalty: f32,
/// Number of trailing entries of `tokens` to consider; `0` means all of them.
last_n: usize,
/// Token ids the penalty is applied to.
tokens: Vec<u32>,
},
/// Keeps the `k` largest logits and sets the rest to negative infinity.
/// Skipped when `k` is zero or not smaller than the number of logits.
TopK(usize),
/// Masks every token outside the top-p nucleus with negative infinity.
/// Skipped unless the value is below `1.0`.
TopP(f32),
/// Ends the chain by sampling with `MinPSampler` (with `min_keep` of 1).
MinP(f32),
/// Ends the chain by sampling with `TypicalSampler` (with `min_keep` of 1).
Typical(f32),
/// Ends the chain by sampling with the chain's `MirostatV2Sampler`, which is
/// created from these parameters when the step is added.
Mirostat2 {
/// Target surprise passed to `MirostatV2Sampler::new`.
tau: f32,
/// Step size passed to `MirostatV2Sampler::new`.
eta: f32,
},
/// Ends the chain by returning the arg-max of the current logits.
Greedy,
}
#[derive(Debug, Clone)]
pub struct SamplerChain {
steps: Vec<SamplerStep>,
rng: LcgRng,
mirostat2: Option<MirostatV2Sampler>,
}
impl SamplerChain {
pub fn new(seed: u64) -> Self {
Self {
steps: Vec::new(),
rng: LcgRng::new(seed),
mirostat2: None,
}
}
#[allow(clippy::should_implement_trait)]
pub fn add(mut self, step: SamplerStep) -> Self {
if let SamplerStep::Mirostat2 { tau, eta } = step {
self.mirostat2 = Some(MirostatV2Sampler::new(tau, eta));
}
self.steps.push(step);
self
}
pub fn sample(&mut self, logits: &mut Vec<f32>) -> usize {
if logits.is_empty() {
return 0;
}
for step in &self.steps {
match step {
SamplerStep::Temperature(temp) => {
if *temp < 1e-6 {
return argmax_slice(logits);
}
apply_temperature(logits, *temp);
}
SamplerStep::RepetitionPenalty {
penalty,
last_n,
tokens,
} => {
let window = if *last_n == 0 {
tokens.as_slice()
} else {
let start = tokens.len().saturating_sub(*last_n);
&tokens[start..]
};
apply_repetition_penalty(logits, window, *penalty);
}
SamplerStep::TopK(k) => {
if *k > 0 && *k < logits.len() {
let indices = top_k_indices(logits, *k);
let mut mask = vec![f32::NEG_INFINITY; logits.len()];
for i in indices {
mask[i] = logits[i];
}
*logits = mask;
}
}
SamplerStep::TopP(p) => {
if *p < 1.0 {
apply_top_p(logits, *p, &mut self.rng);
}
}
SamplerStep::MinP(min_p) => {
let sampler = MinPSampler::new(*min_p, 1);
return sampler.sample(logits, &mut self.rng);
}
SamplerStep::Typical(p) => {
let sampler = TypicalSampler::new(*p, 1);
return sampler.sample(logits, &mut self.rng);
}
SamplerStep::Mirostat2 { .. } => {
if let Some(ref mut ms) = self.mirostat2 {
return ms.sample(logits, &mut self.rng);
}
}
SamplerStep::Greedy => {
return argmax_slice(logits);
}
}
}
softmax_inplace(logits);
let probs: Vec<(usize, f32)> = logits.iter().cloned().enumerate().collect();
categorical_sample(&probs, &mut self.rng)
}
pub fn greedy() -> Self {
Self::new(0).add(SamplerStep::Greedy)
}
pub fn default_chat(seed: u64) -> Self {
Self::new(seed)
.add(SamplerStep::Temperature(0.7))
.add(SamplerStep::TopP(0.9))
.add(SamplerStep::MinP(0.05))
}
pub fn creative(seed: u64) -> Self {
Self::new(seed)
.add(SamplerStep::Temperature(1.0))
.add(SamplerStep::Mirostat2 { tau: 5.0, eta: 0.1 })
}
pub fn precise(seed: u64) -> Self {
Self::new(seed)
.add(SamplerStep::Temperature(0.3))
.add(SamplerStep::TopK(40))
.add(SamplerStep::TopP(0.9))
}
}
/// Returns the index of the largest value in `values`.
///
/// * NaN entries are skipped. Comparing them as "equal" would let a NaN
///   displace the running maximum, after which any later value could win.
/// * When several entries share the maximum, the last of them is returned.
/// * Returns `0` when `values` is empty or holds nothing but NaN.
fn argmax_slice(values: &[f32]) -> usize {
    values
        .iter()
        .copied()
        .enumerate()
        .filter(|(_, v)| !v.is_nan())
        // With NaN removed `partial_cmp` always succeeds; the fallback is
        // only there to avoid an unwrap.
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map_or(0, |(i, _)| i)
}
/// Masks, in place, every logit outside the top-p nucleus with negative
/// infinity.
///
/// The nucleus is the shortest prefix of the tokens, ordered by descending
/// softmax probability, whose cumulative probability reaches `p`. It always
/// contains at least the most likely token, and contains every token if `p`
/// is never reached. `_rng` is accepted but not used.
fn apply_top_p(logits: &mut [f32], p: f32, _rng: &mut LcgRng) {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Softmax probabilities paired with their token index.
    let mut ranked: Vec<(usize, f32)> = logits
        .iter()
        .map(|&logit| (logit - peak).exp())
        .enumerate()
        .collect();
    let mass = ranked.iter().map(|entry| entry.1).sum::<f32>();
    if mass > 0.0 {
        ranked.iter_mut().for_each(|entry| entry.1 /= mass);
    }
    ranked.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // Position of the entry at which the cumulative mass first reaches `p`.
    let mut covered = 0.0_f32;
    let last_kept = ranked
        .iter()
        .position(|entry| {
            covered += entry.1;
            covered >= p
        })
        .unwrap_or(ranked.len().saturating_sub(1));

    let mut in_nucleus = vec![false; logits.len()];
    for &(token, _) in &ranked[..=last_kept] {
        in_nucleus[token] = true;
    }
    for (logit, keep) in logits.iter_mut().zip(in_nucleus) {
        if !keep {
            *logit = f32::NEG_INFINITY;
        }
    }
}
// Unit tests for the random generator, softmax, Mirostat v2 and the greedy
// chain preset.
#[cfg(test)]
mod tests {
use super::*;
// A single `next_f32` draw must fall inside the half-open range [0, 1).
#[test]
fn lcg_rng_produces_values() {
let mut rng = LcgRng::new(1);
let v = rng.next_f32();
assert!((0.0..1.0).contains(&v), "f32 out of range: {v}");
}
// The softmax output must add up to one, within a small tolerance.
#[test]
fn softmax_sums_to_one() {
let mut logits = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
softmax_inplace(&mut logits);
let sum: f32 = logits.iter().sum();
assert!((sum - 1.0).abs() < 1e-5, "sum={sum}");
}
// Mirostat v2 must return an index that lies inside the logits slice.
#[test]
fn mirostat_v2_returns_valid_index() {
let logits = vec![1.0_f32, 5.0, 2.0, 3.0];
let mut sampler = MirostatV2Sampler::new(5.0, 0.1);
let mut rng = LcgRng::new(99);
let idx = sampler.sample(&logits, &mut rng);
assert!(idx < logits.len());
}
// The greedy preset must select the position of the largest logit.
#[test]
fn sampler_chain_greedy_preset() {
let mut chain = SamplerChain::greedy();
let mut logits = vec![0.1_f32, 5.0, 0.2, 0.3];
let tok = chain.sample(&mut logits);
assert_eq!(tok, 1); }
}