Skip to main content

oxicuda_seq/decoders/
nucleus.rs

//! Nucleus (top-`p`) sampling for autoregressive sequence decoding.
//!
//! Reference: Holtzman, A., Buys, J., Du, L., Forbes, M. & Choi, Y. (2020).
//! *The Curious Case of Neural Text Degeneration*. ICLR 2020.
//! <https://arxiv.org/abs/1904.09751>.
//!
//! # Algorithm
//!
//! Given logits `z ∈ ℝᵛ`, temperature `T > 0`, nucleus mass `p ∈ (0, 1]`
//! and a lower bound `min_tokens ≥ 1`:
//!
//! ```text
//! q_i        = softmax(z_i / T)                   (numerically stable)
//! sort q in descending order, breaking ties by the lower index
//! find the smallest m such that  Σ_{j≤m} q_{(j)}  ≥  p
//! m'         = min(max(m, min_tokens), V)
//! keep the first m' indices, renormalise their probabilities, sample
//! ```
//!
//! Setting `p = 1.0` recovers full temperature-scaled softmax sampling, up to
//! floating-point rounding in the cumulative sum (tail tokens whose mass is
//! lost to rounding may be dropped). Setting `min_tokens = 1` and `p → 0⁺`
//! collapses to greedy decoding, because the argmax always survives.
//!
//! The softmax is computed using the standard max-subtraction trick so
//! that logits like `[1e6, 1, 1]` do not overflow.
26
27use crate::error::{SeqError, SeqResult};
28use crate::handle::LcgRng;
29
/// Configuration for [`nucleus_sample`] and [`nucleus_sample_batch`].
///
/// Fields are not checked at construction. Both samplers validate the
/// configuration first and return an error if any field is out of range.
/// [`Default`] yields `p = 0.9`, `temperature = 1.0`, `min_tokens = 1`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NucleusConfig {
    /// Cumulative-probability threshold defining the "nucleus", in `(0, 1]`.
    pub p: f64,
    /// Softmax temperature; must be finite and `> 0`.
    ///
    /// Logits are divided by this value before the softmax. Values below 1
    /// sharpen the distribution; values above 1 flatten it.
    pub temperature: f64,
    /// Minimum number of tokens to keep (`≥ 1`).
    ///
    /// This safeguards against pathologically tiny nuclei when `p` is very
    /// small. It is capped at the vocabulary size when sampling.
    pub min_tokens: usize,
}
41
42impl Default for NucleusConfig {
43    fn default() -> Self {
44        Self {
45            p: 0.9,
46            temperature: 1.0,
47            min_tokens: 1,
48        }
49    }
50}
51
52impl NucleusConfig {
53    fn validate(&self) -> SeqResult<()> {
54        if !self.p.is_finite() || self.p <= 0.0 || self.p > 1.0 {
55            return Err(SeqError::InvalidConfiguration(format!(
56                "nucleus: p must be in (0, 1], got {}",
57                self.p
58            )));
59        }
60        if !self.temperature.is_finite() || self.temperature <= 0.0 {
61            return Err(SeqError::InvalidParameter {
62                name: "temperature".to_string(),
63                value: self.temperature,
64            });
65        }
66        if self.min_tokens == 0 {
67            return Err(SeqError::InvalidConfiguration(
68                "nucleus: min_tokens must be >= 1".to_string(),
69            ));
70        }
71        Ok(())
72    }
73}
74
75/// Sample a single token id from `logits` using nucleus (top-`p`) sampling.
76///
77/// # Errors
78///
79/// * [`SeqError::EmptyInput`] if `logits` is empty.
80/// * [`SeqError::InvalidConfiguration`] / [`SeqError::InvalidParameter`]
81///   if `cfg` is malformed.
82pub fn nucleus_sample(logits: &[f64], cfg: &NucleusConfig, rng: &mut LcgRng) -> SeqResult<usize> {
83    cfg.validate()?;
84    if logits.is_empty() {
85        return Err(SeqError::EmptyInput);
86    }
87    let v = logits.len();
88
89    // Numerically-stable softmax with max-subtraction *after* temperature
90    // scaling.
91    let scaled: Vec<f64> = logits.iter().map(|&z| z / cfg.temperature).collect();
92    let max_z = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
93    if !max_z.is_finite() {
94        return Err(SeqError::NumericalInstability(
95            "nucleus: all logits non-finite".to_string(),
96        ));
97    }
98    let mut probs = vec![0.0_f64; v];
99    let mut sum = 0.0_f64;
100    for (i, &z) in scaled.iter().enumerate() {
101        let w = (z - max_z).exp();
102        probs[i] = w;
103        sum += w;
104    }
105    if !sum.is_finite() || sum <= 0.0 {
106        return Err(SeqError::NumericalInstability(
107            "nucleus: softmax denominator non-positive".to_string(),
108        ));
109    }
110    for q in probs.iter_mut() {
111        *q /= sum;
112    }
113
114    // Sort indices by descending probability.
115    let mut order: Vec<usize> = (0..v).collect();
116    order.sort_by(|&a, &b| {
117        probs[b]
118            .partial_cmp(&probs[a])
119            .unwrap_or(std::cmp::Ordering::Equal)
120            .then(a.cmp(&b))
121    });
122
123    // Find smallest m such that cumulative >= p.
124    let mut cum = 0.0_f64;
125    let mut m = 0usize;
126    for (rank, &idx) in order.iter().enumerate() {
127        cum += probs[idx];
128        if cum >= cfg.p {
129            m = rank + 1;
130            break;
131        }
132    }
133    if m == 0 {
134        m = order.len();
135    }
136    let m_eff = m.max(cfg.min_tokens).min(order.len());
137
138    // Renormalise over the kept set.
139    let mut kept_probs = vec![0.0_f64; m_eff];
140    let mut kept_sum = 0.0_f64;
141    for slot in 0..m_eff {
142        let q = probs[order[slot]];
143        kept_probs[slot] = q;
144        kept_sum += q;
145    }
146    if !kept_sum.is_finite() || kept_sum <= 0.0 {
147        return Err(SeqError::NumericalInstability(
148            "nucleus: kept mass zero".to_string(),
149        ));
150    }
151    for q in kept_probs.iter_mut() {
152        *q /= kept_sum;
153    }
154
155    let chosen_slot = rng.sample_categorical(&kept_probs);
156    Ok(order[chosen_slot])
157}
158
159/// Batch variant of [`nucleus_sample`] for `n` independent rows of length
160/// `vocab` in a flat `logits` buffer.
161///
162/// # Errors
163///
164/// * [`SeqError::EmptyInput`] if `logits` is empty, `n == 0`, or `vocab == 0`.
165/// * [`SeqError::ShapeMismatch`] if `logits.len() != n * vocab`.
166pub fn nucleus_sample_batch(
167    logits: &[f64],
168    n: usize,
169    vocab: usize,
170    cfg: &NucleusConfig,
171    rng: &mut LcgRng,
172) -> SeqResult<Vec<usize>> {
173    cfg.validate()?;
174    if logits.is_empty() || n == 0 || vocab == 0 {
175        return Err(SeqError::EmptyInput);
176    }
177    if logits.len() != n * vocab {
178        return Err(SeqError::ShapeMismatch {
179            expected: n * vocab,
180            got: logits.len(),
181        });
182    }
183    let mut out = Vec::with_capacity(n);
184    for b in 0..n {
185        let row = &logits[b * vocab..(b + 1) * vocab];
186        out.push(nucleus_sample(row, cfg, rng)?);
187    }
188    Ok(out)
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_is_valid() {
        assert!(NucleusConfig::default().validate().is_ok());
    }

    #[test]
    fn invalid_p_rejected() {
        let mut rng = LcgRng::new(0);
        for p in [0.0_f64, -0.1, 1.1, f64::NAN, f64::INFINITY] {
            let cfg = NucleusConfig {
                p,
                temperature: 1.0,
                min_tokens: 1,
            };
            let err = nucleus_sample(&[0.1, 0.2], &cfg, &mut rng).unwrap_err();
            assert!(matches!(err, SeqError::InvalidConfiguration(_)));
        }
    }

    #[test]
    fn nonpositive_temperature_rejected() {
        let mut rng = LcgRng::new(0);
        for t in [0.0_f64, -0.5, f64::NAN, f64::INFINITY] {
            let cfg = NucleusConfig {
                p: 0.9,
                temperature: t,
                min_tokens: 1,
            };
            let err = nucleus_sample(&[0.1, 0.2], &cfg, &mut rng).unwrap_err();
            assert!(matches!(err, SeqError::InvalidParameter { .. }));
        }
    }

    #[test]
    fn zero_min_tokens_rejected() {
        let cfg = NucleusConfig {
            p: 0.9,
            temperature: 1.0,
            min_tokens: 0,
        };
        let mut rng = LcgRng::new(0);
        let err = nucleus_sample(&[0.1, 0.2], &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));
    }

    #[test]
    fn empty_logits_rejected() {
        let cfg = NucleusConfig::default();
        let mut rng = LcgRng::new(0);
        let err = nucleus_sample(&[], &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    #[test]
    fn non_finite_logits_rejected() {
        // NaN anywhere, all -inf, or a +inf logit cannot form a valid softmax.
        let cfg = NucleusConfig::default();
        let mut rng = LcgRng::new(0);
        for logits in [
            vec![f64::NAN, 1.0],
            vec![f64::NEG_INFINITY; 3],
            vec![f64::INFINITY, 0.0],
        ] {
            let err = nucleus_sample(&logits, &cfg, &mut rng).unwrap_err();
            assert!(
                matches!(err, SeqError::NumericalInstability(_)),
                "logits = {logits:?}"
            );
        }
    }

    #[test]
    fn p_one_is_full_softmax() {
        // With p = 1.0, every token is kept; statistically all should
        // appear when the distribution is uniform.
        let logits = vec![0.0; 4];
        let cfg = NucleusConfig {
            p: 1.0,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng = LcgRng::new(0);
        let mut counts = [0usize; 4];
        for _ in 0..4000 {
            counts[nucleus_sample(&logits, &cfg, &mut rng).expect("ok")] += 1;
        }
        for c in counts {
            assert!(c > 800, "counts = {counts:?}");
        }
    }

    #[test]
    fn p_half_keeps_only_argmax() {
        // softmax([2, 1, 0, -1]) ≈ [0.644, 0.237, 0.087, 0.032]. The argmax
        // alone already covers p = 0.5, so the nucleus is exactly {0}.
        let logits = vec![2.0, 1.0, 0.0, -1.0];
        let cfg = NucleusConfig {
            p: 0.5,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng = LcgRng::new(123);
        for _ in 0..500 {
            let t = nucleus_sample(&logits, &cfg, &mut rng).expect("ok");
            assert_eq!(t, 0, "token {t} should not appear with p=0.5");
        }
    }

    #[test]
    fn p_half_keeps_two_token_nucleus() {
        // softmax([1, 1, -5, -5]) ≈ [0.4988, 0.4988, 0.0012, 0.0012]. One token
        // falls just short of p = 0.5 and two exceed it, so the nucleus is
        // {0, 1}; both must be sampled and the tail never.
        let logits = vec![1.0, 1.0, -5.0, -5.0];
        let cfg = NucleusConfig {
            p: 0.5,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng = LcgRng::new(123);
        let mut seen = [false; 4];
        for _ in 0..500 {
            seen[nucleus_sample(&logits, &cfg, &mut rng).expect("ok")] = true;
        }
        assert!(seen[0] && seen[1], "seen = {seen:?}");
        assert!(!seen[2] && !seen[3], "seen = {seen:?}");
    }

    #[test]
    fn min_tokens_lower_bound() {
        // p ≈ 0 ⇒ nucleus contains only the argmax — but min_tokens = 3
        // forces the top-3 to remain.  Verify all of the top-3 indices
        // can be sampled and none of the bottom-2.
        let logits = vec![5.0, 1.5, 1.0, -4.0, -10.0];
        let cfg = NucleusConfig {
            p: 0.001,
            temperature: 1.0,
            min_tokens: 3,
        };
        let mut rng = LcgRng::new(0);
        let mut seen = [false; 5];
        for _ in 0..1000 {
            let t = nucleus_sample(&logits, &cfg, &mut rng).expect("ok");
            seen[t] = true;
        }
        assert!(seen[0] && seen[1] && seen[2]);
        assert!(!seen[3] && !seen[4]);
    }

    #[test]
    fn min_tokens_collapses_to_argmax() {
        // min_tokens = 1 with tiny p must always pick the argmax.
        let logits = vec![5.0, 1.5, 1.0, -4.0, -10.0];
        let cfg = NucleusConfig {
            p: 0.001,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng = LcgRng::new(7);
        for _ in 0..200 {
            assert_eq!(nucleus_sample(&logits, &cfg, &mut rng).expect("ok"), 0);
        }
    }

    #[test]
    fn deterministic_with_seed() {
        let logits = vec![0.5, -1.0, 1.2, 0.0, 2.3];
        let cfg = NucleusConfig {
            p: 0.8,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng_a = LcgRng::new(2026);
        let mut rng_b = LcgRng::new(2026);
        for _ in 0..200 {
            let a = nucleus_sample(&logits, &cfg, &mut rng_a).expect("ok");
            let b = nucleus_sample(&logits, &cfg, &mut rng_b).expect("ok");
            assert_eq!(a, b);
        }
    }

    #[test]
    fn batch_correctness() {
        // Row 0 strongly favours index 0; row 1 strongly favours index 2.
        let logits = vec![10.0, -10.0, -10.0, -10.0, -10.0, 10.0];
        let cfg = NucleusConfig {
            p: 0.5,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng = LcgRng::new(0);
        let out = nucleus_sample_batch(&logits, 2, 3, &cfg, &mut rng).expect("ok");
        assert_eq!(out, vec![0, 2]);
    }

    #[test]
    fn batch_shape_mismatch() {
        let logits = vec![0.0_f64; 5];
        let cfg = NucleusConfig::default();
        let mut rng = LcgRng::new(0);
        let err = nucleus_sample_batch(&logits, 2, 3, &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::ShapeMismatch { .. }));
    }

    #[test]
    fn batch_empty_rejected() {
        let cfg = NucleusConfig::default();
        let mut rng = LcgRng::new(0);
        let logits = vec![0.0_f64; 6];
        for (n, vocab) in [(0usize, 3usize), (2, 0)] {
            let err = nucleus_sample_batch(&logits, n, vocab, &cfg, &mut rng).unwrap_err();
            assert!(matches!(err, SeqError::EmptyInput), "n = {n}, vocab = {vocab}");
        }
        let err = nucleus_sample_batch(&[], 2, 3, &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    #[test]
    fn numerically_stable_softmax() {
        // Without max-subtraction, exp(1e6) overflows.  After subtraction
        // we should still return a valid index and never panic.
        let logits = vec![1.0e6, 1.0, 1.0];
        let cfg = NucleusConfig {
            p: 0.9,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng = LcgRng::new(0);
        let t = nucleus_sample(&logits, &cfg, &mut rng).expect("ok");
        assert_eq!(t, 0);
    }

    #[test]
    fn single_element_vocab() {
        let logits = vec![2.71];
        let cfg = NucleusConfig {
            p: 0.5,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng = LcgRng::new(0);
        for _ in 0..10 {
            assert_eq!(nucleus_sample(&logits, &cfg, &mut rng).expect("ok"), 0);
        }
    }

    #[test]
    fn nucleus_never_picks_truncated_token() {
        // Bottom tokens have ~0 mass; nucleus at p=0.95 should still
        // exclude the very-bottom tokens.
        let logits = vec![3.0_f64, 2.5, 2.0, -10.0, -10.0, -10.0];
        let cfg = NucleusConfig {
            p: 0.95,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng = LcgRng::new(0);
        for _ in 0..500 {
            let t = nucleus_sample(&logits, &cfg, &mut rng).expect("ok");
            assert!(t < 3, "got truncated token {t}");
        }
    }
}