use std::collections::HashSet;
#[cfg(feature = "inference")]
use llguidance::Constraint;
use smol_str::SmolStr;
use crate::{
error::{Error, Result},
options::RequestOptions,
};
/// Outcome of a single sampling step.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
pub(crate) enum SampleResult {
    /// A token id was selected; generation continues.
    Token(u32),
    /// The constraint reported that generation should stop before a token
    /// was emitted this step (see `ConstrainedSampler::sample`).
    SchemaComplete,
    /// A token id was selected *and* the constraint has a pending stop,
    /// so this is the final token.
    TokenAndComplete(u32),
}
/// Strategy for picking the next token id from a logit vector.
#[allow(dead_code)]
pub(crate) trait Sampler {
    /// Picks the next token.
    ///
    /// `logits` may be mutated in place (masking, penalties, temperature
    /// scaling). `seen_tokens` holds ids already generated, used for the
    /// repetition penalty. `step` is the current generation step, used for
    /// error reporting in the constrained implementation.
    fn sample(
        &mut self,
        logits: &mut [f32],
        seen_tokens: &HashSet<u32>,
        step: usize,
    ) -> Result<SampleResult>;
}
/// Unconstrained sampler: repetition penalty, temperature scaling and min-p
/// filtering over the raw logits, with no grammar mask.
#[allow(dead_code)]
pub(crate) struct FreeSampler {
    // Per-request sampling knobs (temperature, min_p, repetition_penalty).
    opts: RequestOptions,
    // Deterministic PRNG seeded per request (see `SmallRng` below).
    rng: SmallRng,
    // Logits past this index are masked to -inf before sampling.
    vocab_size: u32,
}
impl FreeSampler {
    /// Builds a free-running sampler from request options, a PRNG seed and
    /// the model's vocabulary size.
    #[allow(dead_code)]
    pub(crate) fn new(opts: RequestOptions, seed: u64, vocab_size: u32) -> Self {
        let rng = SmallRng::seed_from_u64(seed);
        Self { opts, rng, vocab_size }
    }
}
impl Sampler for FreeSampler {
    /// Samples a token: mask out-of-vocab logits, apply the repetition
    /// penalty, validate the logits, then either take the argmax (greedy,
    /// when temperature <= 0) or draw from the min-p-filtered softmax.
    ///
    /// Returns `Error::SamplerNonFinite` when the in-vocab logits contain a
    /// NaN or no finite value at all.
    fn sample(
        &mut self,
        logits: &mut [f32],
        seen_tokens: &HashSet<u32>,
        _step: usize,
    ) -> Result<SampleResult> {
        // Everything at or past the real vocabulary size is never a valid
        // token; force it to -inf so it can never be picked.
        let cap = logits.len().min(self.vocab_size as usize);
        for slot in logits[cap..].iter_mut() {
            *slot = f32::NEG_INFINITY;
        }
        apply_repetition_penalty(logits, seen_tokens, self.opts.repetition_penalty());
        // Validate only the in-vocab prefix: a NaN anywhere, or a complete
        // absence of finite values, means the model output is unusable.
        let in_vocab = &logits[..cap];
        let has_nan = in_vocab.iter().any(|&v| v.is_nan());
        let none_finite = !in_vocab.iter().any(|&v| v.is_finite());
        if has_nan || none_finite {
            return Err(Error::SamplerNonFinite);
        }
        let temperature = self.opts.temperature();
        if temperature <= 0.0 {
            // Greedy decoding: temperature zero (or negative) means argmax.
            return Ok(SampleResult::Token(argmax(logits)));
        }
        apply_temperature(logits, temperature);
        let probs = softmax(logits);
        let picked = sample_min_p(&probs, self.opts.min_p(), &mut self.rng);
        Ok(SampleResult::Token(picked))
    }
}
/// Grammar-constrained sampler: applies an llguidance token mask before
/// delegating the actual draw to an inner `FreeSampler`.
#[cfg(feature = "inference")]
#[allow(dead_code)]
pub(crate) struct ConstrainedSampler {
    // Performs the actual (masked) sampling.
    inner: FreeSampler,
    // llguidance constraint tracking the grammar state across steps.
    constraint: Constraint,
}
#[cfg(feature = "inference")]
impl ConstrainedSampler {
    /// Wraps a fresh `FreeSampler` (built from `opts`, `seed` and
    /// `vocab_size`) with the given llguidance grammar constraint.
    #[allow(dead_code)]
    pub(crate) fn new(
        constraint: Constraint,
        opts: RequestOptions,
        seed: u64,
        vocab_size: u32,
    ) -> Self {
        let inner = FreeSampler::new(opts, seed, vocab_size);
        Self { inner, constraint }
    }
}
#[cfg(feature = "inference")]
impl Sampler for ConstrainedSampler {
    /// Samples one token while keeping output inside the llguidance grammar.
    ///
    /// Per step: compute the legal-token mask from the constraint, apply it
    /// to `logits`, delegate the pick to the inner `FreeSampler`, then commit
    /// the chosen token back to the constraint so its state advances.
    fn sample(
        &mut self,
        logits: &mut [f32],
        seen_tokens: &HashSet<u32>,
        step: usize,
    ) -> Result<SampleResult> {
        // Ask llguidance which tokens are grammar-legal at this position.
        let step_result = self.constraint.compute_mask().map_err(Error::llguidance)?;
        if step_result.is_stop() {
            // The grammar has terminated: nothing more to sample.
            return Ok(SampleResult::SchemaComplete);
        }
        let mask = match &step_result.sample_mask {
            Some(m) => m,
            None => {
                // Not a stop, but no mask either: commit "no token" and
                // report completion. NOTE(review): presumably this is the
                // llguidance no-sample/forced step case — confirm against
                // the Constraint API docs.
                self
                    .constraint
                    .commit_token(None)
                    .map_err(Error::llguidance)?;
                return Ok(SampleResult::SchemaComplete);
            }
        };
        // Disallowed tokens (and anything past the mask) become -inf.
        apply_mask(logits, mask);
        // If masking left no finite logit, the grammar has dead-ended here.
        if logits.iter().all(|&v| !v.is_finite()) {
            return Err(Error::LlGuidanceDeadEnd {
                step,
                state: SmolStr::new_inline("empty mask"),
            });
        }
        let inner_decision = self.inner.sample(logits, seen_tokens, step)?;
        let id = match inner_decision {
            SampleResult::Token(id) => id,
            // The free sampler never produces these today; forward them
            // unchanged if it ever does.
            SampleResult::SchemaComplete | SampleResult::TokenAndComplete(_) => {
                return Ok(inner_decision);
            }
        };
        // Advance the constraint state with the token we actually picked.
        let _commit = self
            .constraint
            .commit_token(Some(id))
            .map_err(Error::llguidance)?;
        if self.constraint.has_pending_stop() {
            // Token emitted and the grammar may now stop: signal both.
            Ok(SampleResult::TokenAndComplete(id))
        } else {
            Ok(SampleResult::Token(id))
        }
    }
}
/// Discourages tokens in `seen` by dividing their positive logits (or
/// multiplying their non-positive logits) by `penalty`. Ids outside the
/// logit range are ignored; `penalty == 1.0` is a no-op.
#[allow(dead_code)]
fn apply_repetition_penalty(logits: &mut [f32], seen: &HashSet<u32>, penalty: f32) {
    // A penalty of exactly 1.0 is the identity transform; skip the pass.
    if penalty == 1.0 {
        return;
    }
    for &token in seen {
        if let Some(slot) = logits.get_mut(token as usize) {
            // Positive logits shrink toward zero; non-positive logits are
            // pushed further down — both make a repeat less likely.
            *slot = if *slot > 0.0 { *slot / penalty } else { *slot * penalty };
        }
    }
}
/// Scales every logit by `1 / temp`. `temp == 1.0` is a no-op; callers
/// guarantee `temp > 0.0` (the greedy path handles `temp <= 0`).
#[allow(dead_code)]
fn apply_temperature(logits: &mut [f32], temp: f32) {
    if temp == 1.0 {
        return;
    }
    let scale = temp.recip();
    logits.iter_mut().for_each(|v| *v *= scale);
}
/// Numerically stable softmax: shifts by the max logit before exponentiating,
/// then normalizes. If the total mass is not positive (e.g. all inputs are
/// -inf), the unnormalized weights are returned as-is.
#[allow(dead_code)]
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut weights: Vec<f32> = logits.iter().map(|&v| (v - peak).exp()).collect();
    let total: f32 = weights.iter().sum();
    // Guard the division: a zero/NaN total would poison every entry.
    if total > 0.0 {
        weights.iter_mut().for_each(|w| *w /= total);
    }
    weights
}
/// Index of the largest logit. Ties keep the earliest index; NaN entries are
/// never selected (the `>` comparison is false for NaN); an empty slice or
/// an all-(-inf) slice yields 0.
#[allow(dead_code)]
fn argmax(logits: &[f32]) -> u32 {
    let mut winner: (u32, f32) = (0, f32::NEG_INFINITY);
    for (idx, &value) in logits.iter().enumerate() {
        if value > winner.1 {
            winner = (idx as u32, value);
        }
    }
    winner.0
}
/// Min-p sampling: keeps tokens whose probability is at least
/// `min_p * p_max` (and strictly positive), then draws one proportionally
/// to its probability. If nothing survives the filter, falls back to the
/// single most probable index (0 for an empty slice).
#[allow(dead_code)]
fn sample_min_p(probs: &[f32], min_p: f32, rng: &mut SmallRng) -> u32 {
    let top = probs.iter().copied().fold(0.0f32, f32::max);
    let cutoff = min_p * top;
    let candidates: Vec<(u32, f32)> = probs
        .iter()
        .enumerate()
        .filter(|&(_, &p)| p > 0.0 && p >= cutoff)
        .map(|(i, &p)| (i as u32, p))
        .collect();
    if candidates.is_empty() {
        // Degenerate distribution: pick the most probable entry outright.
        return probs
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map_or(0, |(i, _)| i as u32);
    }
    // Draw uniformly over the surviving mass and walk the CDF in index order.
    let mass: f32 = candidates.iter().map(|&(_, p)| p).sum();
    let draw = rng.gen_f32() * mass;
    let mut running = 0.0f32;
    for &(id, p) in &candidates {
        running += p;
        if draw <= running {
            return id;
        }
    }
    // Float round-off can leave `draw` fractionally above the final sum.
    candidates.last().unwrap().0
}
/// Forces every logit that is not allowed by `mask` — including all logits
/// past the mask's length — to -inf, so only grammar-legal tokens can win.
#[cfg(feature = "inference")]
#[allow(dead_code)]
fn apply_mask(logits: &mut [f32], mask: &llguidance::toktrie::SimpleVob) {
    let covered = mask.len();
    for (i, logit) in logits.iter_mut().enumerate() {
        // A position survives only if the mask covers it AND allows it.
        let keep = i < covered && mask.is_allowed(i as u32);
        if !keep {
            *logit = f32::NEG_INFINITY;
        }
    }
}
/// Minimal xorshift128+-style PRNG: two 64-bit state words, no external
/// dependency. Deterministic for a given seed; not cryptographically secure.
#[allow(dead_code)]
struct SmallRng {
    s1: u64,
    s0: u64,
}

#[allow(dead_code)]
impl SmallRng {
    /// Derives the two state words by multiplying the seed with two large
    /// odd constants; `| 1` guarantees neither word is ever zero (an
    /// all-zero state would freeze a xorshift generator).
    fn seed_from_u64(seed: u64) -> Self {
        let word_a = seed.wrapping_mul(0x9E3779B97F4A7C15);
        let word_b = word_a.wrapping_mul(0xBF58476D1CE4E5B9);
        Self {
            s1: word_a | 1,
            s0: word_b | 1,
        }
    }

    /// One xorshift128+ step: shuffle the state words and return their sum.
    fn next_u64(&mut self) -> u64 {
        let y = self.s1;
        let mut x = self.s0;
        self.s0 = y;
        x ^= x << 23;
        self.s1 = x ^ y ^ (x >> 17) ^ (y >> 26);
        self.s1.wrapping_add(y)
    }

    /// Uniform f32 in [0, 1): the top 24 bits of a draw scaled by 2^-24
    /// (24 bits is exactly the f32 mantissa width, so no rounding bias).
    fn gen_f32(&mut self) -> f32 {
        let hi24 = self.next_u64() >> 40;
        (hi24 as f32) / ((1u64 << 24) as f32)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the pure sampling helpers, the `FreeSampler`
    //! pipeline, and the local `SmallRng` PRNG.
    use super::*;
    use crate::options::RequestOptions;
    use std::collections::HashSet;

    #[test]
    fn argmax_picks_largest() {
        let logits = vec![0.1, 0.5, 0.2, 1.5, 0.0];
        assert_eq!(argmax(&logits), 3);
    }

    #[test]
    fn softmax_sums_to_one() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let probs = softmax(&logits);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    // Positive logits of seen tokens must be divided by the penalty.
    #[test]
    fn repetition_penalty_lowers_seen_positive_logits() {
        let mut logits = vec![1.0, 2.0, 3.0];
        let mut seen = HashSet::new();
        seen.insert(1u32);
        apply_repetition_penalty(&mut logits, &seen, 2.0);
        assert_eq!(logits, vec![1.0, 1.0, 3.0]);
    }

    // Negative logits of seen tokens must be multiplied by the penalty.
    #[test]
    fn repetition_penalty_amplifies_seen_negative_logits() {
        let mut logits = vec![-1.0, -2.0, -3.0];
        let mut seen = HashSet::new();
        seen.insert(1u32);
        apply_repetition_penalty(&mut logits, &seen, 2.0);
        assert_eq!(logits, vec![-1.0, -4.0, -3.0]);
    }

    // Temperature 0 takes the greedy path: result is always the argmax.
    #[test]
    fn free_sampler_greedy_picks_argmax() {
        let opts = RequestOptions::default()
            .with_temperature(0.0)
            .with_repetition_penalty(1.05);
        let mut sampler = FreeSampler::new(opts, 42, 65_536);
        let mut logits = vec![0.1f32, 0.5, 0.2, 1.5, 0.0];
        let result = sampler.sample(&mut logits, &HashSet::new(), 0).unwrap();
        assert!(matches!(result, SampleResult::Token(3)));
    }

    // Positions at/past vocab_size are forced to -inf and can never win,
    // even when they hold the largest raw logit (99.0 at index 7 here).
    #[test]
    fn free_sampler_masks_logits_beyond_vocab_size() {
        let opts = RequestOptions::default()
            .with_temperature(0.0)
            .with_repetition_penalty(1.05);
        let mut sampler = FreeSampler::new(opts, 42, 5);
        let mut logits = vec![0.1f32, 0.5, 0.2, 1.0, 0.0, 0.0, 0.0, 99.0];
        let result = sampler.sample(&mut logits, &HashSet::new(), 0).unwrap();
        let id = match result {
            SampleResult::Token(id) => id,
            _ => panic!("expected Token, got {result:?}"),
        };
        assert!(id < 5, "FreeSampler picked masked id {id} (vocab_size=5)");
        assert_eq!(
            id, 3,
            "expected id=3 (logit 1.0); masked logit at 7 (99.0) should be -inf"
        );
    }

    // An all-(-inf) logit vector has no sampleable token and must error.
    #[test]
    fn free_sampler_errors_on_all_non_finite_post_penalty() {
        let opts = RequestOptions::default()
            .with_temperature(0.0)
            .with_repetition_penalty(1.05);
        let mut sampler = FreeSampler::new(opts, 42, 65_536);
        let mut logits = vec![f32::NEG_INFINITY; 8];
        let seen: HashSet<u32> = (0..8).collect();
        let result = sampler.sample(&mut logits, &seen, 0);
        assert!(matches!(result, Err(Error::SamplerNonFinite)));
    }

    // Even one NaN in the in-vocab range must reject the whole step.
    #[test]
    fn free_sampler_errors_on_single_nan_logit_from_model() {
        let opts = RequestOptions::default()
            .with_temperature(0.0)
            .with_repetition_penalty(1.05);
        let mut sampler = FreeSampler::new(opts, 42, 65_536);
        let mut logits = vec![0.1f32, 0.5, 0.2, 1.0, 0.0, 0.3, 0.4, f32::NAN];
        let result = sampler.sample(&mut logits, &HashSet::new(), 0);
        assert!(
            matches!(result, Err(Error::SamplerNonFinite)),
            "single-NaN logit must reject (issue #2 C-001 regression)"
        );
    }

    // -inf entries are legal (masked tokens) as long as a finite one exists.
    #[test]
    fn free_sampler_allows_neg_inf_in_valid_range() {
        let opts = RequestOptions::default()
            .with_temperature(0.0)
            .with_repetition_penalty(1.05);
        let mut sampler = FreeSampler::new(opts, 42, 65_536);
        let mut logits = vec![
            f32::NEG_INFINITY,
            0.5,
            f32::NEG_INFINITY,
            1.5,
            f32::NEG_INFINITY,
        ];
        let result = sampler.sample(&mut logits, &HashSet::new(), 0).unwrap();
        assert!(
            matches!(result, SampleResult::Token(3)),
            "argmax picks the largest finite logit (1.5 at index 3)"
        );
    }

    // Both state words must change on every step, or the generator degrades.
    #[test]
    fn rng_both_state_words_advance() {
        let mut rng = SmallRng::seed_from_u64(0x1234_5678_9ABC_DEF0);
        let initial_s0 = rng.s0;
        let initial_s1 = rng.s1;
        let _ = rng.next_u64();
        assert_ne!(rng.s0, initial_s0, "s0 must advance after next_u64");
        assert_ne!(rng.s1, initial_s1, "s1 must advance after next_u64");
    }

    // min_p = 0 makes the threshold 0, but zero-probability tokens are still
    // excluded by the explicit p > 0 filter.
    #[test]
    fn sample_min_p_excludes_zero_prob_tokens_at_min_p_zero() {
        let probs = [0.0f32, 0.6, 0.4];
        let mut rng = SmallRng::seed_from_u64(0);
        for _ in 0..1000 {
            let id = sample_min_p(&probs, 0.0, &mut rng);
            assert_ne!(
                id, 0,
                "sample_min_p must never select a zero-probability token even when min_p=0"
            );
        }
    }

    // Smoke test that the RNG state is not stuck in a tiny cycle.
    #[test]
    fn rng_produces_non_constant_outputs() {
        let mut rng = SmallRng::seed_from_u64(42);
        let mut seen = HashSet::new();
        for _ in 0..1024 {
            seen.insert(rng.next_u64());
        }
        assert!(
            seen.len() > 1000,
            "RNG produced only {} unique values across 1024 draws — state likely frozen",
            seen.len()
        );
    }
}