pub mod advanced;
pub mod chain;
pub mod grammar;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use grammar::{apply_grammar_mask, Grammar, GrammarState};
/// Configuration shared by [`Sampler`] and the free [`sample`] function.
///
/// Fields marked `#[serde(skip)]` are runtime-only and are never (de)serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplerConfig {
/// Logits are divided by this before softmax. In the standard pipeline a value
/// `<= 0.0` short-circuits to greedy argmax; mirostat v2 only rescales when the
/// value is `> 0.0` and not exactly `1.0`.
pub temperature: f32,
/// Keep only the `top_k` highest logits (`0` disables, `1` means greedy).
/// Not used by mirostat v2.
pub top_k: usize,
/// Nucleus filter: keep the shortest most-likely prefix whose cumulative
/// probability reaches `top_p` (`>= 1.0` disables). Not used by mirostat v2.
pub top_p: f32,
/// Drop tokens whose probability is below `min_p` times the most likely
/// token's probability (`0.0` disables). Not used by mirostat v2.
pub min_p: f32,
/// Penalty for recently generated tokens: positive logits are divided by it,
/// non-positive logits are multiplied by it (`1.0` disables).
pub repetition_penalty: f32,
/// How many of the most recent entries of `recent_tokens` the repetition
/// penalty inspects (`0` inspects none).
pub repetition_penalty_window: usize,
/// RNG seed. With `None`, `Sampler::new` chooses a seed itself, while the free
/// `sample` function substitutes a fixed constant.
pub seed: Option<u64>,
/// `2` selects mirostat v2 in `Sampler::sample`; any other value uses the
/// standard pipeline.
pub mirostat: u8,
/// Mirostat target surprise in bits (`-log2 p`); the adaptive threshold
/// starts at twice this value.
pub mirostat_tau: f32,
/// Mirostat learning rate for the per-token threshold update.
pub mirostat_eta: f32,
/// Optional grammar constraint. Masking and state advancement only happen
/// when `token_vocab` is also set.
#[serde(skip)]
pub grammar: Option<Arc<Grammar>>,
/// `(token id, token bytes)` table used for grammar masking and advancing.
/// Must be sorted by token id: `Sampler::sample` looks tokens up with a
/// binary search.
#[serde(skip)]
// The nested Arc/Vec/tuple type trips clippy's `type_complexity` lint.
#[allow(clippy::type_complexity)]
pub token_vocab: Option<Arc<Vec<(u32, Vec<u8>)>>>,
/// Additive per-token logit bias. Applied after bans and only to finite
/// logits, so it cannot resurrect a banned token.
#[serde(default)]
pub logit_bias: std::collections::HashMap<u32, f32>,
/// Token ids whose logits are forced to `-inf` before any other processing.
#[serde(default)]
pub banned_tokens: Vec<u32>,
/// NOTE(review): not read in this file; presumably the DRY penalty strength
/// used by the `advanced`/`chain` samplers. Confirm there. Serde default `0.0`.
#[serde(default)]
pub dry_multiplier: f32,
/// NOTE(review): not read in this file; presumably the DRY exponential base.
/// Serde default `1.75`.
#[serde(default = "dry_base_default")]
pub dry_base: f32,
/// NOTE(review): not read in this file; presumably the DRY minimum repeat
/// length before penalizing. Serde default `2`.
#[serde(default = "dry_allowed_length_default")]
pub dry_allowed_length: usize,
/// NOTE(review): not read in this file; presumably the XTC probability
/// threshold. Serde default `0.0`.
#[serde(default)]
pub xtc_threshold: f32,
/// NOTE(review): not read in this file; presumably the chance XTC triggers.
/// Serde default `0.5`.
#[serde(default = "xtc_probability_default")]
pub xtc_probability: f32,
/// NOTE(review): not read in this file; presumably the typical-p mass.
/// Serde default `1.0`.
#[serde(default = "typical_p_default")]
pub typical_p: f32,
/// NOTE(review): not read in this file; presumably the top-a coefficient.
/// Serde default `0.0`.
#[serde(default)]
pub top_a: f32,
/// NOTE(review): not read in this file; presumably the eta-sampling cutoff.
/// Serde default `0.0`.
#[serde(default)]
pub eta_cutoff: f32,
/// NOTE(review): not read in this file; presumably the epsilon-sampling
/// cutoff. Serde default `0.0`.
#[serde(default)]
pub epsilon_cutoff: f32,
}
/// Serde fallback for `SamplerConfig::dry_base` when the field is absent.
fn dry_base_default() -> f32 {
    const DRY_BASE: f32 = 1.75;
    DRY_BASE
}
/// Serde fallback for `SamplerConfig::dry_allowed_length` when the field is absent.
fn dry_allowed_length_default() -> usize {
    const DRY_ALLOWED_LENGTH: usize = 2;
    DRY_ALLOWED_LENGTH
}
/// Serde fallback for `SamplerConfig::xtc_probability` when the field is absent.
fn xtc_probability_default() -> f32 {
    const XTC_PROBABILITY: f32 = 0.5;
    XTC_PROBABILITY
}
/// Serde fallback for `SamplerConfig::typical_p` when the field is absent.
fn typical_p_default() -> f32 {
    const TYPICAL_P: f32 = 1.0;
    TYPICAL_P
}
impl Default for SamplerConfig {
fn default() -> Self {
Self {
temperature: 0.7,
top_k: 40,
top_p: 0.9,
min_p: 0.0,
repetition_penalty: 1.1,
repetition_penalty_window: 64,
seed: None,
mirostat: 0,
mirostat_tau: 5.0,
mirostat_eta: 0.1,
grammar: None,
token_vocab: None,
logit_bias: std::collections::HashMap::new(),
banned_tokens: Vec::new(),
dry_multiplier: 0.0,
dry_base: 1.75,
dry_allowed_length: 2,
xtc_threshold: 0.0,
xtc_probability: 0.5,
typical_p: 1.0,
top_a: 0.0,
eta_cutoff: 0.0,
epsilon_cutoff: 0.0,
}
}
}
impl SamplerConfig {
    /// Deterministic argmax decoding.
    ///
    /// Temperature 0 and top-k 1 each short-circuit the standard pipeline to
    /// argmax, and the repetition penalty is neutralized. Every field not
    /// listed here keeps its [`Default`] value.
    pub fn greedy() -> Self {
        Self {
            temperature: 0.0,
            top_k: 1,
            top_p: 1.0,
            min_p: 0.0,
            repetition_penalty: 1.0,
            repetition_penalty_window: 0,
            ..Self::default()
        }
    }

    /// Mirostat v2 preset with target surprise `tau` (bits) and learning rate
    /// `eta`.
    ///
    /// Temperature is 1.0, and top-k / top-p / min-p / repetition penalty are
    /// neutral. Every field not listed here keeps its [`Default`] value.
    pub fn mirostat_v2(tau: f32, eta: f32) -> Self {
        Self {
            mirostat: 2,
            mirostat_tau: tau,
            mirostat_eta: eta,
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            min_p: 0.0,
            repetition_penalty: 1.0,
            repetition_penalty_window: 0,
            ..Self::default()
        }
    }
}
/// Stateful sampler for one generation stream.
///
/// Owns the RNG, the adaptive mirostat v2 threshold and the incremental
/// grammar state, so successive [`Sampler::sample`] calls build on each other.
pub struct Sampler {
    config: SamplerConfig,
    rng: Xorshift64,
    /// Current mirostat v2 surprise threshold in bits; starts at `2 * mirostat_tau`.
    mirostat_mu: f32,
    /// Grammar parse state; `None` when `config.grammar` is `None`.
    grammar_state: Option<GrammarState>,
}

impl Sampler {
    /// Creates a sampler for `config`.
    ///
    /// When `config.seed` is `None`, a fresh non-reproducible seed is drawn
    /// (see `entropy_seed`), so separately constructed unseeded samplers get
    /// independent random streams. The grammar, if any, starts in its initial
    /// state.
    pub fn new(config: SamplerConfig) -> Self {
        let seed = config.seed.unwrap_or_else(Self::entropy_seed);
        let mirostat_mu = 2.0 * config.mirostat_tau;
        let grammar_state = config.grammar.as_ref().map(|g| g.initial_state());
        Self {
            config,
            rng: Xorshift64::new(seed),
            mirostat_mu,
            grammar_state,
        }
    }

    /// Produces a non-reproducible seed using only the standard library.
    ///
    /// Hashes the wall-clock time with a randomly keyed `RandomState`. Do not
    /// derive seeds from stack addresses: they repeat for every sampler built
    /// at the same call depth, which makes "random" samplers identical.
    fn entropy_seed() -> u64 {
        let keys = std::collections::hash_map::RandomState::new();
        std::hash::BuildHasher::hash_one(&keys, std::time::SystemTime::now())
    }

    /// Samples one token id from `logits`.
    ///
    /// Uses mirostat v2 when `config.mirostat == 2`, otherwise the standard
    /// pipeline. `recent_tokens` feeds the repetition penalty. If a grammar and
    /// `token_vocab` are configured, the chosen token's bytes are then fed to
    /// the grammar state.
    pub fn sample(&mut self, logits: &[f32], recent_tokens: &[u32]) -> u32 {
        let token = if self.config.mirostat == 2 {
            self.sample_mirostat_v2(logits, recent_tokens)
        } else {
            sample_with_rng(
                logits,
                &self.config,
                recent_tokens,
                &mut self.rng,
                self.grammar_state.as_ref(),
            )
        };
        if let (Some(state), Some(vocab)) = (&mut self.grammar_state, &self.config.token_vocab) {
            // `token_vocab` must be sorted by id for this binary search to work.
            if let Ok(idx) = vocab.binary_search_by_key(&token, |&(id, _)| id) {
                // Best-effort: a rejection from `advance` is deliberately ignored.
                let _ = state.advance(&vocab[idx].1);
            }
        }
        token
    }

    /// Restarts the grammar from its initial state (no-op without a grammar).
    pub fn reset_grammar(&mut self) {
        self.grammar_state = self.config.grammar.as_ref().map(|g| g.initial_state());
    }

    /// `true` when no grammar is configured or the grammar state reports completion.
    pub fn grammar_complete(&self) -> bool {
        self.grammar_state
            .as_ref()
            .is_none_or(GrammarState::is_complete)
    }

    /// Mirostat v2 step.
    ///
    /// Drops candidates whose surprise `-log2(p)` exceeds the running
    /// threshold `mu`, samples from the renormalized remainder, then moves
    /// `mu` by `-eta * (observed_surprise - tau)`. Observed surprise is
    /// measured against the pre-truncation probability.
    ///
    /// Bias/bans, the repetition penalty, the grammar mask, and temperature
    /// (only when `> 0` and `!= 1`) are applied first. Returns 0 for empty
    /// `logits`.
    fn sample_mirostat_v2(&mut self, logits: &[f32], recent_tokens: &[u32]) -> u32 {
        if logits.is_empty() {
            return 0;
        }
        let mut processed = logits.to_vec();
        apply_logit_bias_and_banned_tokens(&mut processed, &self.config);
        apply_repetition_penalty(&mut processed, &self.config, recent_tokens);
        if let (Some(state), Some(vocab)) = (&self.grammar_state, &self.config.token_vocab) {
            apply_grammar_mask(&mut processed, state, vocab.as_ref());
        }
        let temperature = self.config.temperature;
        if temperature > 0.0 && temperature != 1.0 {
            let inv_temp = 1.0 / temperature;
            for val in &mut processed {
                *val *= inv_temp;
            }
        }
        let mut candidates: Vec<(u32, f32)> = processed
            .iter()
            .enumerate()
            .map(|(i, &v)| (i as u32, v))
            .collect();
        candidates
            .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        softmax_candidates(&mut candidates);
        let mu = self.mirostat_mu;
        // NaN probabilities fail `p > 0.0` and are dropped as well.
        candidates.retain(|&(_, p)| p > 0.0 && -p.log2() <= mu);
        if candidates.is_empty() {
            // Nothing is within the threshold: fall back to the most likely token.
            let token = argmax(&processed);
            self.update_mirostat_mu(softmax_single_max(&processed));
            return token;
        }
        let total: f32 = candidates.iter().map(|&(_, p)| p).sum();
        if total > 0.0 && total != 1.0 {
            for (_, p) in &mut candidates {
                *p /= total;
            }
        }
        let r = self.rng.next_f32();
        let mut cumulative = 0.0f32;
        // If rounding leaves the final cumulative sum just below `r`, the draw
        // belongs to the last candidate (as in `sample_with_rng`).
        let mut selected = candidates[candidates.len() - 1];
        for &(idx, prob) in &candidates {
            cumulative += prob;
            if r < cumulative {
                selected = (idx, prob);
                break;
            }
        }
        let (selected_idx, normalized_prob) = selected;
        // Undo the renormalization so surprise uses the pre-truncation probability.
        self.update_mirostat_mu(normalized_prob * total);
        selected_idx
    }

    /// Mirostat feedback: `mu -= eta * (surprise - tau)`. A non-positive (or
    /// NaN) probability is treated as surprise `tau`, i.e. no change.
    fn update_mirostat_mu(&mut self, prob: f32) {
        let tau = self.config.mirostat_tau;
        let surprise = if prob > 0.0 { -prob.log2() } else { tau };
        self.mirostat_mu -= self.config.mirostat_eta * (surprise - tau);
    }

    /// The configuration this sampler was built with.
    pub fn config(&self) -> &SamplerConfig {
        &self.config
    }

    /// Raw RNG state, for checkpointing together with [`Sampler::mirostat_mu_value`].
    pub fn rng_state(&self) -> u64 {
        self.rng.state_value()
    }

    /// Current mirostat threshold `mu`.
    pub fn mirostat_mu_value(&self) -> f32 {
        self.mirostat_mu
    }

    /// Restores the RNG state and mirostat `mu` from a checkpoint. The grammar
    /// state is not touched.
    pub fn restore_rng_state(&mut self, state: u64, mu: f32) {
        self.rng = Xorshift64::from_state_value(state);
        self.mirostat_mu = mu;
    }
}
/// One-shot, stateless sampling through the standard pipeline.
///
/// Unlike [`Sampler::sample`], this never runs mirostat and never applies a
/// grammar mask, because no grammar state is passed. When `config.seed` is
/// `None`, a fixed seed is used, so identical inputs always yield the same
/// token. Returns 0 for empty `logits`.
pub fn sample(logits: &[f32], config: &SamplerConfig, recent_tokens: &[u32]) -> u32 {
    match logits {
        [] => 0,
        _ => {
            let mut rng = Xorshift64::new(config.seed.unwrap_or(0xDEADBEEF_CAFEBABE));
            sample_with_rng(logits, config, recent_tokens, &mut rng, None)
        }
    }
}
/// Standard (non-mirostat) sampling pipeline, driven by a caller-owned RNG.
///
/// Steps, in order:
/// 1. Bans and logit bias, then the repetition penalty.
/// 2. The grammar mask, when both `grammar_state` and `config.token_vocab` are present.
/// 3. A greedy short-circuit when `temperature <= 0` or `top_k == 1`.
/// 4. Temperature scaling, then top-k, then softmax.
/// 5. Min-p followed by renormalization, then top-p followed by a final renormalization.
/// 6. One inverse-CDF draw from `rng`.
///
/// Returns 0 for empty `logits`.
fn sample_with_rng(
    logits: &[f32],
    config: &SamplerConfig,
    recent_tokens: &[u32],
    rng: &mut Xorshift64,
    grammar_state: Option<&GrammarState>,
) -> u32 {
    if logits.is_empty() {
        return 0;
    }
    let mut processed = logits.to_vec();
    apply_logit_bias_and_banned_tokens(&mut processed, config);
    apply_repetition_penalty(&mut processed, config, recent_tokens);
    if let (Some(state), Some(vocab)) = (grammar_state, &config.token_vocab) {
        apply_grammar_mask(&mut processed, state, vocab.as_ref());
    }
    if config.temperature <= 0.0 || config.top_k == 1 {
        return argmax(&processed);
    }
    if config.temperature != 1.0 {
        let inv_temp = 1.0 / config.temperature;
        for val in &mut processed {
            *val *= inv_temp;
        }
    }
    let mut candidates: Vec<(u32, f32)> = processed
        .iter()
        .enumerate()
        .map(|(i, &v)| (i as u32, v))
        .collect();
    candidates.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    if config.top_k > 0 && config.top_k < candidates.len() {
        candidates.truncate(config.top_k);
    }
    // Softmax after truncation makes top-k's survivors a proper distribution.
    softmax_candidates(&mut candidates);
    if config.min_p > 0.0 && !candidates.is_empty() {
        // Candidates are sorted descending, so index 0 holds the top probability.
        let threshold = config.min_p * candidates[0].1;
        candidates.retain(|&(_, p)| p >= threshold);
        // Renormalize like the softmax after top-k; otherwise top-p below would
        // measure against mass that min-p already removed and under-cut.
        renormalize_candidates(&mut candidates);
    }
    if config.top_p < 1.0 && !candidates.is_empty() {
        let mut cumulative = 0.0f32;
        let cutoff = candidates
            .iter()
            .position(|&(_, prob)| {
                cumulative += prob;
                cumulative >= config.top_p
            })
            .map_or(candidates.len(), |i| i + 1);
        candidates.truncate(cutoff);
    }
    renormalize_candidates(&mut candidates);
    if candidates.is_empty() {
        return argmax(&processed);
    }
    if candidates.len() == 1 {
        return candidates[0].0;
    }
    let r = rng.next_f32();
    let mut cumulative = 0.0f32;
    for &(idx, prob) in &candidates {
        cumulative += prob;
        if r < cumulative {
            return idx;
        }
    }
    // Rounding can leave the cumulative sum just below `r`; that draw belongs
    // to the last candidate.
    candidates.last().map_or(0, |&(idx, _)| idx)
}

/// Rescales candidate probabilities in place so they sum to 1. Leaves them
/// untouched when the sum is already exactly 1 or is not positive (incl. NaN).
fn renormalize_candidates(candidates: &mut [(u32, f32)]) {
    let total: f32 = candidates.iter().map(|&(_, p)| p).sum();
    if total > 0.0 && total != 1.0 {
        for (_, p) in candidates.iter_mut() {
            *p /= total;
        }
    }
}
/// Applies hard bans and additive logit biases to `processed`, in place.
///
/// Banned ids are forced to `-inf` first. Biases are then added only to
/// finite logits, so a bias can never undo a ban. Ids outside `processed`
/// are ignored.
fn apply_logit_bias_and_banned_tokens(processed: &mut [f32], config: &SamplerConfig) {
    config.banned_tokens.iter().for_each(|&token| {
        if let Some(slot) = processed.get_mut(token as usize) {
            *slot = f32::NEG_INFINITY;
        }
    });
    for (&token, &bias) in &config.logit_bias {
        if let Some(slot) = processed.get_mut(token as usize).filter(|v| v.is_finite()) {
            *slot += bias;
        }
    }
}
/// Applies the CTRL-style repetition penalty to `processed`, in place.
///
/// Every *distinct* token id among the last `repetition_penalty_window`
/// entries of `recent_tokens` is penalized exactly once. Positive logits are
/// divided by `repetition_penalty` and non-positive logits are multiplied by
/// it. Ids outside `processed` are ignored. A penalty of exactly `1.0`, an
/// empty history, or a window of `0` leaves `processed` unchanged.
fn apply_repetition_penalty(processed: &mut [f32], config: &SamplerConfig, recent_tokens: &[u32]) {
    if config.repetition_penalty == 1.0 || recent_tokens.is_empty() {
        return;
    }
    let window_start = recent_tokens
        .len()
        .saturating_sub(config.repetition_penalty_window);
    // Deduplicate so a token seen k times is not penalized k times
    // (penalty^k), which would compound far beyond the configured strength.
    let mut distinct = recent_tokens[window_start..].to_vec();
    distinct.sort_unstable();
    distinct.dedup();
    let penalty = config.repetition_penalty;
    for token in distinct {
        if let Some(logit) = processed.get_mut(token as usize) {
            if *logit > 0.0 {
                *logit /= penalty;
            } else {
                *logit *= penalty;
            }
        }
    }
}
/// Replaces the logit in each candidate's second slot with its softmax
/// probability, in place; token ids are untouched.
///
/// The maximum logit is subtracted before exponentiating for numerical
/// stability. If the exponentials do not sum to a positive value (e.g. every
/// logit is `-inf`, which produces NaN), they are left unnormalized. Empty
/// input is a no-op.
fn softmax_candidates(candidates: &mut [(u32, f32)]) {
    if candidates.is_empty() {
        return;
    }
    let peak = candidates
        .iter()
        .fold(f32::NEG_INFINITY, |acc, &(_, logit)| acc.max(logit));
    let mut total = 0.0f32;
    candidates.iter_mut().for_each(|(_, value)| {
        *value = (*value - peak).exp();
        total += *value;
    });
    if total > 0.0 {
        candidates.iter_mut().for_each(|(_, value)| *value /= total);
    }
}
/// Softmax probability of the most likely entry of `logits`, computed as
/// `1 / sum(exp(x - max))`.
///
/// Returns 0.0 when that sum is not positive: an empty slice, or NaN when
/// every logit is `-inf`.
fn softmax_single_max(logits: &[f32]) -> f32 {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let denominator: f32 = logits.iter().map(|&x| (x - peak).exp()).sum();
    match denominator > 0.0 {
        true => denominator.recip(),
        false => 0.0,
    }
}
/// Index of the first strictly-greatest value in `values`.
///
/// NaN entries never win. Returns 0 for an empty slice, or when no value
/// exceeds `-inf` (e.g. every logit was masked).
fn argmax(values: &[f32]) -> u32 {
    let (winner, _) = values.iter().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(best_i, best_v), (i, &v)| if v > best_v { (i, v) } else { (best_i, best_v) },
    );
    winner as u32
}
/// Marsaglia xorshift64 generator (shift triple 13 / 7 / 17).
///
/// Small and fully determined by its 64-bit state; not cryptographically
/// secure. A zero state is never used because zero is a fixed point of the
/// recurrence.
struct Xorshift64 {
    state: u64,
}

impl Xorshift64 {
    /// State substituted when `new` is given a zero seed.
    const ZERO_SEED_REPLACEMENT: u64 = 0x517cc1b727220a95;

    /// Seeds the generator; a zero seed is replaced by a fixed non-zero constant.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => Self::ZERO_SEED_REPLACEMENT,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the state and returns it.
    fn next_u64(&mut self) -> u64 {
        let a = self.state ^ (self.state << 13);
        let b = a ^ (a >> 7);
        self.state = b ^ (b << 17);
        self.state
    }

    /// Uniform `f32` in `[0, 1)` built from the top 24 bits of the next output,
    /// which f32 represents exactly.
    fn next_f32(&mut self) -> f32 {
        let top_bits = (self.next_u64() >> 40) as f32;
        top_bits / 16_777_216.0
    }

    /// Raw state, for checkpointing.
    pub(crate) fn state_value(&self) -> u64 {
        self.state
    }

    /// Restores a checkpointed state; a zero state is replaced by 1.
    pub(crate) fn from_state_value(state: u64) -> Self {
        Self { state: state.max(1) }
    }
}
// Unit tests for the standard sampling pipeline, mirostat v2, logit bias /
// banned tokens, and grammar integration.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_greedy_sampling() {
let logits = vec![0.1, 0.5, 0.3, 0.8, 0.2];
let config = SamplerConfig::greedy();
let token = sample(&logits, &config, &[]);
assert_eq!(token, 3); }
#[test]
fn test_empty_logits() {
let logits: Vec<f32> = vec![];
let config = SamplerConfig::greedy();
let token = sample(&logits, &config, &[]);
assert_eq!(token, 0);
}
#[test]
fn test_temperature_zero_is_greedy() {
let logits = vec![1.0, 5.0, 3.0, 2.0];
let config = SamplerConfig {
temperature: 0.0,
..SamplerConfig::default()
};
let token = sample(&logits, &config, &[]);
assert_eq!(token, 1); }
#[test]
fn test_top_k_1_is_greedy() {
let logits = vec![1.0, 5.0, 3.0, 2.0];
let config = SamplerConfig {
temperature: 1.0,
top_k: 1,
..SamplerConfig::default()
};
let token = sample(&logits, &config, &[]);
assert_eq!(token, 1);
}
// Two samplers built from the same seeded config must draw identical streams.
#[test]
fn test_seeded_determinism() {
let logits = vec![1.0, 2.0, 3.0, 2.0, 1.0];
let config = SamplerConfig {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
min_p: 0.0,
seed: Some(42),
..SamplerConfig::default()
};
let mut sampler1 = Sampler::new(config.clone());
let mut sampler2 = Sampler::new(config);
for _ in 0..10 {
let t1 = sampler1.sample(&logits, &[]);
let t2 = sampler2.sample(&logits, &[]);
assert_eq!(t1, t2, "seeded samplers should produce identical results");
}
}
// Token 0 alone carries essentially all probability mass, so top_p = 0.5 keeps only it.
#[test]
fn test_top_p_filters_low_prob() {
let logits = vec![100.0, 0.0, 0.0, 0.0, 0.0];
let config = SamplerConfig {
temperature: 1.0,
top_k: 0,
top_p: 0.5,
min_p: 0.0,
seed: Some(123),
..SamplerConfig::default()
};
let token = sample(&logits, &config, &[]);
assert_eq!(token, 0);
}
// Penalty 100 turns token 1's logit 5.0 into 0.05, so greedy picks token 2 (4.9).
#[test]
fn test_repetition_penalty() {
let logits = vec![1.0, 5.0, 4.9, 1.0];
let config = SamplerConfig {
temperature: 0.0, repetition_penalty: 100.0, repetition_penalty_window: 64,
..SamplerConfig::greedy()
};
let token_no_penalty = sample(&logits, &SamplerConfig::greedy(), &[]);
assert_eq!(token_no_penalty, 1);
let token_with_penalty = sample(&logits, &config, &[1]);
assert_eq!(token_with_penalty, 2);
}
// Equal logits give a uniform distribution; the bounds are deliberately loose
// for 1000 draws.
#[test]
fn test_sampling_distribution() {
let logits = vec![2.0, 2.0, 2.0, 2.0]; let config = SamplerConfig {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
min_p: 0.0,
seed: Some(999),
..SamplerConfig::default()
};
let mut sampler = Sampler::new(config);
let mut counts = [0u32; 4];
for _ in 0..1000 {
let t = sampler.sample(&logits, &[]);
counts[t as usize] += 1;
}
for (i, &count) in counts.iter().enumerate() {
assert!(
count > 100 && count < 400,
"token {i} got {count} hits (expected ~250 for uniform distribution)"
);
}
}
// Tokens 1-3 have probability about e^-20 relative to token 0, far below min_p = 0.1.
#[test]
fn test_min_p_filtering() {
let logits = vec![10.0, -10.0, -10.0, -10.0];
let config = SamplerConfig {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
min_p: 0.1, seed: Some(42),
..SamplerConfig::default()
};
let mut sampler = Sampler::new(config);
for _ in 0..100 {
assert_eq!(sampler.sample(&logits, &[]), 0);
}
}
#[test]
fn test_xorshift_range() {
let mut rng = Xorshift64::new(12345);
for _ in 0..10000 {
let v = rng.next_f32();
assert!((0.0..1.0).contains(&v), "RNG produced {v} outside [0, 1)");
}
}
#[test]
fn test_mirostat_v2_basic() {
let logits = vec![3.0, 2.0, 1.0, 0.5, 0.1, -1.0, -2.0, -5.0];
let config = SamplerConfig {
seed: Some(42),
..SamplerConfig::mirostat_v2(5.0, 0.1)
};
let mut sampler = Sampler::new(config);
for _ in 0..50 {
let token = sampler.sample(&logits, &[]);
assert!((token as usize) < logits.len());
}
}
// Reads the private `mirostat_mu` field directly, which the child test module may do.
#[test]
fn test_mirostat_v2_adapts_mu() {
let logits = vec![5.0, 0.0, 0.0, 0.0];
let config = SamplerConfig {
seed: Some(123),
..SamplerConfig::mirostat_v2(3.0, 0.1)
};
let mut sampler = Sampler::new(config);
let initial_mu = sampler.mirostat_mu;
sampler.sample(&logits, &[]);
assert!(
(sampler.mirostat_mu - initial_mu).abs() > 1e-6,
"mu should adapt after sampling"
);
}
// tau = 0.5 gives an initial mu of 1 bit; only token 0 has surprise below that.
#[test]
fn test_mirostat_v2_low_tau_prefers_top() {
let logits = vec![10.0, 0.0, 0.0, 0.0, 0.0];
let config = SamplerConfig {
seed: Some(42),
..SamplerConfig::mirostat_v2(0.5, 0.1) };
let mut sampler = Sampler::new(config);
let mut top_count = 0;
for _ in 0..100 {
if sampler.sample(&logits, &[]) == 0 {
top_count += 1;
}
}
assert!(
top_count > 90,
"low tau should strongly prefer top token, got {top_count}/100"
);
}
#[test]
fn test_mirostat_v2_deterministic_with_seed() {
let logits = vec![2.0, 1.5, 1.0, 0.5];
let config = SamplerConfig {
seed: Some(777),
..SamplerConfig::mirostat_v2(5.0, 0.1)
};
let mut sampler1 = Sampler::new(config.clone());
let mut sampler2 = Sampler::new(config);
for _ in 0..20 {
assert_eq!(
sampler1.sample(&logits, &[]),
sampler2.sample(&logits, &[]),
"same seed should produce same sequence"
);
}
}
#[test]
fn test_softmax_candidates_basic() {
let mut candidates = vec![(0, 0.0f32), (1, 0.0), (2, 0.0)];
softmax_candidates(&mut candidates);
for &(_, p) in &candidates {
assert!((p - 1.0 / 3.0).abs() < 0.01, "expected ~0.333, got {p}");
}
}
// Every id except 3 is banned (forced to -inf), so token 3 is the only option.
#[test]
fn banned_tokens_never_sampled() {
let vocab_size = 5usize;
let logits: Vec<f32> = (0..vocab_size).map(|i| i as f32).collect();
let mut banned = Vec::new();
for i in 0u32..vocab_size as u32 {
if i != 3 {
banned.push(i);
}
}
let config = SamplerConfig {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
min_p: 0.0,
seed: Some(42),
banned_tokens: banned,
..SamplerConfig::default()
};
let mut sampler = Sampler::new(config);
for _ in 0..50 {
let tok = sampler.sample(&logits, &[]);
assert_eq!(
tok, 3,
"only token 3 should ever be sampled when all others are banned"
);
}
}
// A +100 bias lifts token 1 from -20 to 80, far above token 0's 10.
#[test]
fn positive_bias_increases_token_probability() {
let logits = vec![10.0f32, -20.0, -20.0, -20.0];
let mut bias = std::collections::HashMap::new();
bias.insert(1u32, 100.0f32);
let config = SamplerConfig {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
min_p: 0.0,
seed: Some(7),
logit_bias: bias,
..SamplerConfig::default()
};
let mut sampler = Sampler::new(config);
let tok = sampler.sample(&logits, &[]);
assert_eq!(tok, 1, "large positive bias should make token 1 dominate");
}
#[test]
fn negative_bias_decreases() {
let logits = vec![100.0f32, 1.0, 0.5, 0.1];
let mut bias = std::collections::HashMap::new();
bias.insert(0u32, -200.0f32);
let config = SamplerConfig {
temperature: 0.0, logit_bias: bias,
..SamplerConfig::greedy()
};
let tok = sample(&logits, &config, &[]);
assert_eq!(
tok, 1,
"after large negative bias on token 0, token 1 should win"
);
}
#[test]
fn logit_bias_empty_config_no_op() {
let logits = vec![1.0f32, 2.0, 3.0, 0.5];
let config_empty = SamplerConfig {
temperature: 0.0,
logit_bias: std::collections::HashMap::new(),
banned_tokens: Vec::new(),
..SamplerConfig::greedy()
};
let tok = sample(&logits, &config_empty, &[]);
assert_eq!(tok, 2, "empty logit_bias / banned_tokens should be a no-op");
}
#[test]
fn test_grammar_constrained_yes_no() {
let g = Grammar::parse(r#"root ::= "yes" | "no""#).unwrap();
let state = g.initial_state();
assert!(state.allows_token(b"yes"));
assert!(state.allows_token(b"no"));
assert!(!state.allows_token(b"maybe"));
}
// Token 0 ("maybe") has by far the highest raw logit, so this only passes if
// the grammar mask removes it before the greedy argmax.
#[test]
fn test_grammar_sampler_masks_logits() {
let vocab: Vec<(u32, Vec<u8>)> = vec![
(0, b"maybe".to_vec()),
(1, b"yes".to_vec()),
(2, b"no".to_vec()),
];
let g = Arc::new(Grammar::parse(r#"root ::= "yes" | "no""#).unwrap());
let config = SamplerConfig {
temperature: 0.0, grammar: Some(g),
token_vocab: Some(Arc::new(vocab)),
..SamplerConfig::default()
};
let logits = vec![100.0f32, 1.0, 1.0];
let mut sampler = Sampler::new(config);
let tok = sampler.sample(&logits, &[]);
assert!(tok == 1 || tok == 2, "expected yes(1) or no(2), got {tok}");
}
// 'a' keeps the highest logit on both calls, so the second call yields 'b'
// only if `Sampler::sample` advanced the grammar past 'a'.
#[test]
fn test_grammar_state_advances_through_sequence() {
let vocab: Vec<(u32, Vec<u8>)> =
vec![(0, b"a".to_vec()), (1, b"b".to_vec()), (2, b"c".to_vec())];
let g = Arc::new(Grammar::parse(r#"root ::= "a" "b""#).unwrap());
let config = SamplerConfig {
temperature: 0.0,
grammar: Some(g),
token_vocab: Some(Arc::new(vocab)),
..SamplerConfig::default()
};
let logits = vec![1.0f32, 0.5, 0.5];
let mut sampler = Sampler::new(config);
let tok1 = sampler.sample(&logits, &[]);
assert_eq!(tok1, 0, "first token must be 'a' (id=0)");
let tok2 = sampler.sample(&logits, &[0]);
assert_eq!(tok2, 1, "second token must be 'b' (id=1)");
assert!(
sampler.grammar_complete(),
"grammar should be complete after 'a' + 'b'"
);
}
#[test]
fn test_grammar_parse_roundtrip() {
let g = Grammar::parse("root ::= [a-z]+ \":\" [0-9]+").unwrap();
assert!(!g.rules.is_empty());
assert_eq!(g.root, "root");
}
// Despite its name, this only checks that `advance` rejects bytes the grammar
// does not allow; it does not exercise masking.
#[test]
fn test_grammar_stuck_state_masks_all() {
let g = Arc::new(Grammar::parse(r#"root ::= "x""#).unwrap());
let mut state = g.initial_state();
let result = state.advance(b"y");
assert!(result.is_err(), "advancing with wrong bytes should error");
}
}