use crate::error::{SeqError, SeqResult};
use crate::handle::LcgRng;
/// Settings consumed by the nucleus (top-p) sampling functions in this file.
///
/// The private `validate` method rejects values outside the ranges noted on
/// each field before any sampling work is done.
#[derive(Debug, Clone, Copy)]
pub struct NucleusConfig {
/// Cumulative-probability threshold: ranked tokens are kept until their
/// running probability mass reaches this value. Validation requires a finite
/// value in `(0, 1]`.
pub p: f64,
/// Divisor applied to every logit before the softmax. Validation requires a
/// finite value strictly greater than zero.
pub temperature: f64,
/// Lower bound on the number of top-ranked tokens kept (capped at the number
/// of logits supplied). Validation requires a value of at least 1.
pub min_tokens: usize,
}
impl Default for NucleusConfig {
fn default() -> Self {
Self {
p: 0.9,
temperature: 1.0,
min_tokens: 1,
}
}
}
impl NucleusConfig {
    /// Verifies that every field lies inside its permitted range.
    ///
    /// Checks run in field order (`p`, `temperature`, `min_tokens`) and the
    /// first violation is returned.
    ///
    /// # Errors
    ///
    /// * [`SeqError::InvalidConfiguration`] if `p` is NaN, infinite, or outside `(0, 1]`.
    /// * [`SeqError::InvalidParameter`] if `temperature` is NaN, infinite, or not
    ///   strictly positive.
    /// * [`SeqError::InvalidConfiguration`] if `min_tokens` is zero.
    fn validate(&self) -> SeqResult<()> {
        // The struct is `Copy`, so pull the fields out once and work on locals.
        let Self {
            p,
            temperature,
            min_tokens,
        } = *self;

        let p_in_range = p.is_finite() && p > 0.0 && p <= 1.0;
        if !p_in_range {
            let message = format!("nucleus: p must be in (0, 1], got {}", p);
            return Err(SeqError::InvalidConfiguration(message));
        }

        let temperature_usable = temperature.is_finite() && temperature > 0.0;
        if !temperature_usable {
            return Err(SeqError::InvalidParameter {
                name: String::from("temperature"),
                value: temperature,
            });
        }

        if min_tokens == 0 {
            Err(SeqError::InvalidConfiguration(String::from(
                "nucleus: min_tokens must be >= 1",
            )))
        } else {
            Ok(())
        }
    }
}
/// Draws a single token index from `logits` using nucleus (top-p) sampling.
///
/// The procedure is:
/// 1. validate `cfg` and reject an empty `logits` slice;
/// 2. divide every logit by `cfg.temperature` and apply a max-shifted softmax;
/// 3. rank tokens by descending probability, breaking ties by ascending index;
/// 4. keep the shortest ranked prefix whose cumulative probability reaches
///    `cfg.p` (the whole vocabulary if it never does), then raise that count to
///    at least `cfg.min_tokens` and cap it at the vocabulary size;
/// 5. renormalise the kept probabilities and let `rng.sample_categorical` pick
///    a slot among them.
///
/// The returned value is an index into `logits`.
///
/// # Errors
///
/// * Whatever `cfg.validate()` reports for an out-of-range configuration.
/// * [`SeqError::EmptyInput`] if `logits` is empty.
/// * [`SeqError::NumericalInstability`] if the largest scaled logit is not
///   finite, if the softmax denominator is non-finite or non-positive, or if
///   the kept probability mass is non-finite or non-positive.
pub fn nucleus_sample(logits: &[f64], cfg: &NucleusConfig, rng: &mut LcgRng) -> SeqResult<usize> {
    cfg.validate()?;
    if logits.is_empty() {
        return Err(SeqError::EmptyInput);
    }

    // Temperature scaling, then locate the peak so the exponentials cannot overflow.
    let tempered: Vec<f64> = logits
        .iter()
        .map(|&logit| logit / cfg.temperature)
        .collect();
    let peak = tempered.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !peak.is_finite() {
        return Err(SeqError::NumericalInstability(
            "nucleus: all logits non-finite".to_string(),
        ));
    }

    // Unnormalised softmax weights, summed left-to-right in index order.
    let mut probs: Vec<f64> = tempered.iter().map(|&z| (z - peak).exp()).collect();
    let total = probs.iter().fold(0.0_f64, |acc, &w| acc + w);
    if !(total.is_finite() && total > 0.0) {
        return Err(SeqError::NumericalInstability(
            "nucleus: softmax denominator non-positive".to_string(),
        ));
    }
    probs.iter_mut().for_each(|q| *q /= total);

    // Token indices from most to least probable; equal (or incomparable)
    // probabilities fall back to ascending index so the ranking is deterministic.
    let mut ranked: Vec<usize> = (0..probs.len()).collect();
    ranked.sort_by(|&lhs, &rhs| match probs[rhs].partial_cmp(&probs[lhs]) {
        Some(std::cmp::Ordering::Equal) | None => lhs.cmp(&rhs),
        Some(decided) => decided,
    });

    // Length of the shortest prefix whose running mass reaches `p`; if rounding
    // keeps the running mass below `p`, the whole vocabulary is the nucleus.
    let mut running = 0.0_f64;
    let nucleus_len = ranked
        .iter()
        .position(|&token| {
            running += probs[token];
            running >= cfg.p
        })
        .map_or(ranked.len(), |rank| rank + 1);

    // Enforce the lower bound first, then cap at the vocabulary size.
    let keep = nucleus_len.max(cfg.min_tokens).min(ranked.len());

    let mut kept: Vec<f64> = ranked[..keep].iter().map(|&token| probs[token]).collect();
    let kept_mass = kept.iter().fold(0.0_f64, |acc, &q| acc + q);
    if !(kept_mass.is_finite() && kept_mass > 0.0) {
        return Err(SeqError::NumericalInstability(
            "nucleus: kept mass zero".to_string(),
        ));
    }
    for q in &mut kept {
        *q /= kept_mass;
    }

    // The sampler returns a slot within the kept prefix; map it back to a token index.
    let slot = rng.sample_categorical(&kept);
    Ok(ranked[slot])
}
/// Samples one token per row from a row-major `n x vocab` matrix of logits.
///
/// `logits` must contain exactly `n * vocab` values; row `b` occupies
/// `logits[b * vocab..(b + 1) * vocab]`. Rows are processed in order with the
/// same `rng`, and each row is sampled by [`nucleus_sample`].
///
/// Returns the `n` sampled token indices, each in `0..vocab`.
///
/// # Errors
///
/// * Whatever `cfg.validate()` reports; the configuration is checked before
///   the shape, so configuration errors take precedence.
/// * [`SeqError::EmptyInput`] if `logits` is empty or either `n` or `vocab` is zero.
/// * [`SeqError::ShapeMismatch`] if `logits.len() != n * vocab`. When `n * vocab`
///   does not fit in `usize`, `expected` is reported as `usize::MAX`.
/// * The first error returned by [`nucleus_sample`] for any row; later rows
///   are not sampled.
pub fn nucleus_sample_batch(
    logits: &[f64],
    n: usize,
    vocab: usize,
    cfg: &NucleusConfig,
    rng: &mut LcgRng,
) -> SeqResult<Vec<usize>> {
    cfg.validate()?;
    if logits.is_empty() || n == 0 || vocab == 0 {
        return Err(SeqError::EmptyInput);
    }

    // A plain `n * vocab` panics on overflow in debug builds and wraps in
    // release builds, where the wrapped product could spuriously equal
    // `logits.len()`. An overflowing product can never match a real slice
    // length, so treat it as a shape mismatch.
    let expected = n.checked_mul(vocab);
    if expected != Some(logits.len()) {
        return Err(SeqError::ShapeMismatch {
            expected: expected.unwrap_or(usize::MAX),
            got: logits.len(),
        });
    }

    // `vocab` is non-zero and divides `logits.len()` exactly, so this yields
    // precisely `n` rows with no remainder.
    let mut out = Vec::with_capacity(n);
    for row in logits.chunks_exact(vocab) {
        out.push(nucleus_sample(row, cfg, rng)?);
    }
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a configuration with unit temperature, which most tests use.
    fn unit_temp(p: f64, min_tokens: usize) -> NucleusConfig {
        NucleusConfig {
            p,
            temperature: 1.0,
            min_tokens,
        }
    }

    /// `p` outside `(0, 1]` or NaN must be reported as a configuration error.
    #[test]
    fn invalid_p_rejected() {
        let mut rng = LcgRng::new(0);
        for bad_p in [0.0_f64, -0.1, 1.1, f64::NAN] {
            let cfg = unit_temp(bad_p, 1);
            let outcome = nucleus_sample(&[0.1, 0.2], &cfg, &mut rng);
            assert!(matches!(
                outcome.unwrap_err(),
                SeqError::InvalidConfiguration(_)
            ));
        }
    }

    /// Zero, negative, and NaN temperatures must be reported as invalid parameters.
    #[test]
    fn nonpositive_temperature_rejected() {
        let mut rng = LcgRng::new(0);
        for bad_t in [0.0_f64, -0.5, f64::NAN] {
            let cfg = NucleusConfig {
                temperature: bad_t,
                ..unit_temp(0.9, 1)
            };
            let outcome = nucleus_sample(&[0.1, 0.2], &cfg, &mut rng);
            assert!(matches!(
                outcome.unwrap_err(),
                SeqError::InvalidParameter { .. }
            ));
        }
    }

    /// `min_tokens == 0` must be reported as a configuration error.
    #[test]
    fn zero_min_tokens_rejected() {
        let mut rng = LcgRng::new(0);
        let cfg = unit_temp(0.9, 0);
        let outcome = nucleus_sample(&[0.1, 0.2], &cfg, &mut rng);
        assert!(matches!(
            outcome.unwrap_err(),
            SeqError::InvalidConfiguration(_)
        ));
    }

    /// An empty logits slice is rejected with `EmptyInput`.
    #[test]
    fn empty_logits_rejected() {
        let mut rng = LcgRng::new(0);
        let cfg = NucleusConfig::default();
        let outcome = nucleus_sample(&[], &cfg, &mut rng);
        assert!(matches!(outcome.unwrap_err(), SeqError::EmptyInput));
    }

    /// With `p = 1` and equal logits, every token should be drawn roughly equally often.
    #[test]
    fn p_one_is_full_softmax() {
        let logits = vec![0.0; 4];
        let cfg = unit_temp(1.0, 1);
        let mut rng = LcgRng::new(0);
        let mut counts = [0usize; 4];
        for _ in 0..4000 {
            let token = nucleus_sample(&logits, &cfg, &mut rng).expect("ok");
            counts[token] += 1;
        }
        for c in counts {
            assert!(c > 800, "counts = {counts:?}");
        }
    }

    /// With `p = 0.5`, only the first two tokens may ever be returned.
    #[test]
    fn p_half_truncates_to_top_half() {
        let logits = vec![2.0, 1.0, 0.0, -1.0];
        let cfg = unit_temp(0.5, 1);
        let mut rng = LcgRng::new(123);
        for _ in 0..500 {
            let t = nucleus_sample(&logits, &cfg, &mut rng).expect("ok");
            assert!(t < 2, "token {t} should not appear with p=0.5");
        }
    }

    /// A tiny `p` combined with `min_tokens = 3` keeps exactly the top three tokens.
    #[test]
    fn min_tokens_lower_bound() {
        let logits = vec![5.0, 1.5, 1.0, -4.0, -10.0];
        let cfg = unit_temp(0.001, 3);
        let mut rng = LcgRng::new(0);
        let mut seen = [false; 5];
        for _ in 0..1000 {
            let token = nucleus_sample(&logits, &cfg, &mut rng).expect("ok");
            seen[token] = true;
        }
        assert!(seen[..3].iter().all(|&hit| hit));
        assert!(seen[3..].iter().all(|&hit| !hit));
    }

    /// A tiny `p` combined with `min_tokens = 1` always returns the highest-logit token.
    #[test]
    fn min_tokens_collapses_to_argmax() {
        let logits = vec![5.0, 1.5, 1.0, -4.0, -10.0];
        let cfg = unit_temp(0.001, 1);
        let mut rng = LcgRng::new(7);
        for _ in 0..200 {
            let token = nucleus_sample(&logits, &cfg, &mut rng).expect("ok");
            assert_eq!(token, 0);
        }
    }

    /// Two generators created with the same seed must yield identical sample streams.
    #[test]
    fn deterministic_with_seed() {
        let logits = vec![0.5, -1.0, 1.2, 0.0, 2.3];
        let cfg = unit_temp(0.8, 1);
        let mut first_rng = LcgRng::new(2026);
        let mut second_rng = LcgRng::new(2026);
        for _ in 0..200 {
            let from_first = nucleus_sample(&logits, &cfg, &mut first_rng).expect("ok");
            let from_second = nucleus_sample(&logits, &cfg, &mut second_rng).expect("ok");
            assert_eq!(from_first, from_second);
        }
    }

    /// Each row of a batch is sampled independently from its own slice of logits.
    #[test]
    fn batch_correctness() {
        let logits = vec![10.0, -10.0, -10.0, -10.0, -10.0, 10.0];
        let cfg = unit_temp(0.5, 1);
        let mut rng = LcgRng::new(0);
        let picked = nucleus_sample_batch(&logits, 2, 3, &cfg, &mut rng).expect("ok");
        assert_eq!(picked, vec![0, 2]);
    }

    /// A logits buffer whose length differs from `n * vocab` is rejected.
    #[test]
    fn batch_shape_mismatch() {
        let logits = vec![0.0_f64; 5];
        let cfg = NucleusConfig::default();
        let mut rng = LcgRng::new(0);
        let outcome = nucleus_sample_batch(&logits, 2, 3, &cfg, &mut rng);
        assert!(matches!(
            outcome.unwrap_err(),
            SeqError::ShapeMismatch { .. }
        ));
    }

    /// A very large logit must not overflow the softmax.
    #[test]
    fn numerically_stable_softmax() {
        let logits = vec![1.0e6, 1.0, 1.0];
        let cfg = unit_temp(0.9, 1);
        let mut rng = LcgRng::new(0);
        let token = nucleus_sample(&logits, &cfg, &mut rng).expect("ok");
        assert_eq!(token, 0);
    }

    /// A one-token vocabulary always yields index 0.
    #[test]
    fn single_element_vocab() {
        let logits = vec![2.71];
        let cfg = unit_temp(0.5, 1);
        let mut rng = LcgRng::new(0);
        for _ in 0..10 {
            let token = nucleus_sample(&logits, &cfg, &mut rng).expect("ok");
            assert_eq!(token, 0);
        }
    }

    /// Tokens outside the nucleus must never be returned.
    #[test]
    fn nucleus_never_picks_truncated_token() {
        let logits = vec![3.0_f64, 2.5, 2.0, -10.0, -10.0, -10.0];
        let cfg = unit_temp(0.95, 1);
        let mut rng = LcgRng::new(0);
        for _ in 0..500 {
            let t = nucleus_sample(&logits, &cfg, &mut rng).expect("ok");
            assert!(t < 3, "got truncated token {t}");
        }
    }
}