Skip to main content

oxicuda_seq/decoders/
top_k.rs

1//! Top-`k` sampling for autoregressive sequence decoding.
2//!
3//! Reference: Fan, A., Lewis, M., & Dauphin, Y. (2018).
4//! *Hierarchical Neural Story Generation*. ACL 2018.
5//! <https://aclanthology.org/P18-1082/>.
6//!
7//! # Algorithm
8//!
9//! Given a vector of logits `z ∈ ℝᵛ` and a temperature `T > 0`:
10//!
11//! ```text
12//! z'_i  = z_i / T
13//! S     = arg-top-k by z'
14//! p_i   = softmax over S, zero elsewhere
15//! tok   = Categorical(p)
16//! ```
17//!
18//! The softmax is computed in a numerically-stable way by subtracting the
19//! maximum logit of the top-`k` set before exponentiation.
20//!
21//! Setting `k = 1` collapses to greedy / argmax decoding (deterministic
22//! regardless of the RNG state); setting `k ≥ vocab` recovers full
23//! temperature-scaled softmax sampling.
24
25use crate::error::{SeqError, SeqResult};
26use crate::handle::LcgRng;
27
/// Configuration for [`top_k_sample`] and [`top_k_sample_batch`].
///
/// Both fields are validated at the start of every sampling call.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TopKConfig {
    /// Number of highest-scoring tokens to keep (`k ≥ 1`).  A value larger
    /// than the vocabulary size is clamped to the vocabulary size, which
    /// amounts to full temperature-scaled softmax sampling.
    pub k: usize,
    /// Softmax temperature (finite and `> 0`).  Logits are divided by this
    /// value before the softmax: higher values flatten the distribution,
    /// lower values sharpen it.
    pub temperature: f64,
}
37
38impl Default for TopKConfig {
39    fn default() -> Self {
40        Self {
41            k: 50,
42            temperature: 1.0,
43        }
44    }
45}
46
47impl TopKConfig {
48    fn validate(&self) -> SeqResult<()> {
49        if self.k == 0 {
50            return Err(SeqError::InvalidConfiguration(
51                "top-k: k must be >= 1".to_string(),
52            ));
53        }
54        if !self.temperature.is_finite() || self.temperature <= 0.0 {
55            return Err(SeqError::InvalidParameter {
56                name: "temperature".to_string(),
57                value: self.temperature,
58            });
59        }
60        Ok(())
61    }
62}
63
64/// Sample a single token id from `logits` using top-`k` sampling.
65///
66/// # Errors
67///
68/// * [`SeqError::EmptyInput`] if `logits` is empty.
69/// * [`SeqError::InvalidConfiguration`] / [`SeqError::InvalidParameter`]
70///   if `cfg` is malformed.
71pub fn top_k_sample(logits: &[f64], cfg: &TopKConfig, rng: &mut LcgRng) -> SeqResult<usize> {
72    cfg.validate()?;
73    if logits.is_empty() {
74        return Err(SeqError::EmptyInput);
75    }
76    let v = logits.len();
77    let k_eff = cfg.k.min(v);
78
79    // Temperature-scale once.
80    let scaled: Vec<f64> = logits.iter().map(|&z| z / cfg.temperature).collect();
81
82    // Argmax fast-path when k == 1.
83    if k_eff == 1 {
84        return Ok(argmax(&scaled));
85    }
86
87    // Partial-sort indices by descending logit value, keeping the top k.
88    let indices = top_k_indices(&scaled, k_eff);
89
90    // Numerically-stable softmax over the surviving k.
91    let max_z = indices
92        .iter()
93        .map(|&i| scaled[i])
94        .fold(f64::NEG_INFINITY, f64::max);
95    let mut probs = vec![0.0_f64; k_eff];
96    let mut sum = 0.0_f64;
97    for (slot, &i) in indices.iter().enumerate() {
98        let w = (scaled[i] - max_z).exp();
99        probs[slot] = w;
100        sum += w;
101    }
102    if !sum.is_finite() || sum <= 0.0 {
103        return Err(SeqError::NumericalInstability(
104            "top-k: softmax denominator non-positive".to_string(),
105        ));
106    }
107    for p in probs.iter_mut() {
108        *p /= sum;
109    }
110
111    let chosen_slot = rng.sample_categorical(&probs);
112    Ok(indices[chosen_slot])
113}
114
115/// Batch variant of [`top_k_sample`] that processes `n` independent rows
116/// of length `vocab` from a flat `logits` buffer.
117///
118/// # Errors
119///
120/// * [`SeqError::EmptyInput`] if `logits` is empty or `n == 0` or `vocab == 0`.
121/// * [`SeqError::ShapeMismatch`] if `logits.len() != n * vocab`.
122pub fn top_k_sample_batch(
123    logits: &[f64],
124    n: usize,
125    vocab: usize,
126    cfg: &TopKConfig,
127    rng: &mut LcgRng,
128) -> SeqResult<Vec<usize>> {
129    cfg.validate()?;
130    if logits.is_empty() || n == 0 || vocab == 0 {
131        return Err(SeqError::EmptyInput);
132    }
133    if logits.len() != n * vocab {
134        return Err(SeqError::ShapeMismatch {
135            expected: n * vocab,
136            got: logits.len(),
137        });
138    }
139    let mut out = Vec::with_capacity(n);
140    for b in 0..n {
141        let row = &logits[b * vocab..(b + 1) * vocab];
142        out.push(top_k_sample(row, cfg, rng)?);
143    }
144    Ok(out)
145}
146
/// Index of the first largest element of `xs`.
///
/// Comparison uses strict `>`, so among equal maxima the earliest index
/// wins.  Nothing compares greater than NaN, so a NaN at position 0 makes
/// the result 0.  NaNs elsewhere are never selected.
/// `xs` must be non-empty (indexing `xs[0]` panics otherwise).
#[inline]
fn argmax(xs: &[f64]) -> usize {
    let seed = (0usize, xs[0]);
    xs.iter()
        .copied()
        .enumerate()
        .skip(1)
        .fold(seed, |best, (pos, val)| if val > best.1 { (pos, val) } else { best })
        .0
}
160
/// Return the indices of the top-`k` elements of `xs` in descending order
/// of value, breaking ties by ascending index.
///
/// `k` is expected to satisfy `1 <= k <= xs.len()`.  Larger values are
/// clamped to `xs.len()`, and `k == 0` yields an empty vector.  NaNs rank
/// above every number, so a NaN reaching this function lands in the kept set.
/// There the caller's softmax check can detect it, rather than the NaN being
/// silently dropped.
///
/// Runs in `O(v + k log k)` for `v = xs.len()`: a selection pass isolates
/// the top `k`, and only those are sorted.
fn top_k_indices(xs: &[f64], k: usize) -> Vec<usize> {
    // Strict total order: NaN first, then value descending, then index
    // ascending.  `partial_cmp().unwrap_or(Equal)` is not transitive once NaN
    // is present, and the standard sorts may panic on such comparators.  The
    // index tie-break reproduces exactly what a stable descending sort gives
    // (`-0.0` and `0.0` compare equal), so results stay deterministic.
    let rank = |&a: &usize, &b: &usize| -> std::cmp::Ordering {
        let (va, vb) = (xs[a], xs[b]);
        let by_value = match (va.is_nan(), vb.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Less,
            (false, true) => std::cmp::Ordering::Greater,
            // Neither is NaN, so `partial_cmp` is always `Some` here.
            (false, false) => vb.partial_cmp(&va).unwrap_or(std::cmp::Ordering::Equal),
        };
        by_value.then(a.cmp(&b))
    };

    let k = k.min(xs.len());
    let mut idx: Vec<usize> = (0..xs.len()).collect();
    if k < idx.len() {
        // Everything before position `k` now ranks ahead of everything after.
        idx.select_nth_unstable_by(k, rank);
        idx.truncate(k);
    }
    idx.sort_unstable_by(rank);
    idx
}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference temperature-scaled softmax over all of `logits`.
    fn full_softmax(logits: &[f64], t: f64) -> Vec<f64> {
        let scaled: Vec<f64> = logits.iter().map(|&z| z / t).collect();
        let m = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = scaled.iter().map(|&z| (z - m).exp()).collect();
        let s: f64 = exps.iter().sum();
        exps.iter().map(|&e| e / s).collect()
    }

    #[test]
    fn k_zero_rejected() {
        let cfg = TopKConfig {
            k: 0,
            temperature: 1.0,
        };
        let mut rng = LcgRng::new(0);
        let err = top_k_sample(&[0.1, 0.2], &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));
    }

    #[test]
    fn nonpositive_temperature_rejected() {
        let mut rng = LcgRng::new(0);
        for t in [0.0_f64, -1.0, f64::NAN] {
            let cfg = TopKConfig {
                k: 2,
                temperature: t,
            };
            let err = top_k_sample(&[0.1, 0.2], &cfg, &mut rng).unwrap_err();
            assert!(matches!(err, SeqError::InvalidParameter { .. }));
        }
    }

    #[test]
    fn empty_logits_rejected() {
        let cfg = TopKConfig::default();
        let mut rng = LcgRng::new(0);
        let err = top_k_sample(&[], &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    #[test]
    fn k_at_least_vocab_full_softmax() {
        // k >= vocab should behave like full softmax sampling.  With three
        // equal logits every token has probability 1/3 (≈ 1000 of 3000
        // draws), so each one must be sampled a substantial number of times.
        let logits = vec![0.0, 0.0, 0.0];
        let cfg = TopKConfig {
            k: 10,
            temperature: 1.0,
        };
        let mut rng = LcgRng::new(42);
        let mut counts = [0usize; 3];
        for _ in 0..3000 {
            let tok = top_k_sample(&logits, &cfg, &mut rng).expect("sample ok");
            counts[tok] += 1;
        }
        for c in counts {
            assert!(
                c > 700,
                "every token should be sampled: counts = {counts:?}"
            );
        }
    }

    #[test]
    fn k_one_is_argmax() {
        let logits = vec![-1.0, 4.5, 2.0, 4.5_f64.next_down()];
        let cfg = TopKConfig {
            k: 1,
            temperature: 0.7,
        };
        let mut rng_a = LcgRng::new(0);
        let mut rng_b = LcgRng::new(999_999);
        let tok_a = top_k_sample(&logits, &cfg, &mut rng_a).expect("sample ok");
        let tok_b = top_k_sample(&logits, &cfg, &mut rng_b).expect("sample ok");
        assert_eq!(tok_a, 1);
        assert_eq!(tok_b, 1, "k=1 must be deterministic regardless of rng");
    }

    #[test]
    fn deterministic_with_seed() {
        let logits = vec![0.5, 1.2, -0.3, 0.8, 2.1];
        let cfg = TopKConfig {
            k: 3,
            temperature: 1.0,
        };
        let mut rng_a = LcgRng::new(123);
        let mut rng_b = LcgRng::new(123);
        for _ in 0..200 {
            let a = top_k_sample(&logits, &cfg, &mut rng_a).expect("ok");
            let b = top_k_sample(&logits, &cfg, &mut rng_b).expect("ok");
            assert_eq!(a, b);
        }
    }

    #[test]
    fn distribution_matches_renormalised_softmax() {
        // Use logits such that top-3 of 5 are well-separated from the
        // bottom-2.  Assert chi-square goodness-of-fit at 1% significance
        // against the theoretical renormalised softmax over the kept set.
        let logits = vec![3.0_f64, 1.0, 0.0, -2.0, -5.0];
        let cfg = TopKConfig {
            k: 3,
            temperature: 1.0,
        };
        // Expected probabilities = softmax over indices {0, 1, 2}.
        let full = full_softmax(&logits[..3], 1.0);
        let n_samples = 6000usize;
        let mut rng = LcgRng::new(7);
        let mut counts = [0usize; 3];
        for _ in 0..n_samples {
            let t = top_k_sample(&logits, &cfg, &mut rng).expect("ok");
            assert!(t < 3, "top-k must never pick a truncated index");
            counts[t] += 1;
        }
        let mut chi2 = 0.0_f64;
        for i in 0..3 {
            let expected = full[i] * n_samples as f64;
            let diff = counts[i] as f64 - expected;
            chi2 += diff * diff / expected;
        }
        // df = 2; the 99th percentile (1% significance) is ≈ 9.21.
        assert!(chi2 < 9.21, "chi-square = {chi2}");
    }

    #[test]
    fn batch_correctness() {
        // Two rows, each strongly peaked on a different index.
        let logits = vec![10.0, -10.0, -10.0, -10.0, 10.0, -10.0];
        let cfg = TopKConfig {
            k: 2,
            temperature: 1.0,
        };
        let mut rng = LcgRng::new(0);
        let out = top_k_sample_batch(&logits, 2, 3, &cfg, &mut rng).expect("ok");
        assert_eq!(out, vec![0, 1]);
    }

    #[test]
    fn batch_empty_rejected() {
        let cfg = TopKConfig::default();
        let mut rng = LcgRng::new(0);
        assert!(matches!(
            top_k_sample_batch(&[], 0, 3, &cfg, &mut rng).unwrap_err(),
            SeqError::EmptyInput
        ));
        assert!(matches!(
            top_k_sample_batch(&[0.0, 0.0], 1, 0, &cfg, &mut rng).unwrap_err(),
            SeqError::EmptyInput
        ));
    }

    #[test]
    fn batch_shape_mismatch_rejected() {
        let logits = vec![0.0_f64; 5];
        let cfg = TopKConfig::default();
        let mut rng = LcgRng::new(0);
        let err = top_k_sample_batch(&logits, 2, 3, &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::ShapeMismatch { .. }));
    }

    #[test]
    fn batch_invalid_config_rejected() {
        // The configuration is validated before any shape checks.
        let cfg = TopKConfig {
            k: 0,
            temperature: 1.0,
        };
        let mut rng = LcgRng::new(0);
        let err = top_k_sample_batch(&[0.0, 1.0], 1, 2, &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));
    }

    #[test]
    fn high_temperature_flattens() {
        // With very high temperature and k=full, even strongly peaked
        // logits should produce a near-uniform sample distribution.
        let logits = vec![5.0, 0.0, 0.0, 0.0];
        let cfg = TopKConfig {
            k: 4,
            temperature: 50.0,
        };
        let mut rng = LcgRng::new(1);
        let mut counts = [0usize; 4];
        for _ in 0..4000 {
            counts[top_k_sample(&logits, &cfg, &mut rng).expect("ok")] += 1;
        }
        for c in counts {
            // Uniform would give 1000; allow generous slack.
            assert!(c > 700, "counts = {counts:?}");
        }
    }

    #[test]
    fn low_temperature_sharpens() {
        // With low temperature, sampling should overwhelmingly pick
        // the largest logit even at k = full.
        let logits = vec![3.0, 1.0, 0.0, -1.0];
        let cfg = TopKConfig {
            k: 4,
            temperature: 0.05,
        };
        let mut rng = LcgRng::new(0);
        let mut argmax_count = 0usize;
        for _ in 0..1000 {
            if top_k_sample(&logits, &cfg, &mut rng).expect("ok") == 0 {
                argmax_count += 1;
            }
        }
        assert!(argmax_count > 980, "argmax picked {argmax_count} / 1000");
    }

    #[test]
    fn top_k_never_picks_truncated_token() {
        // Logits with a clear top-2 vs the rest; with k=2, the chosen
        // token must always be one of the two largest.
        let logits = vec![5.0, 4.5, -3.0, -4.0, -10.0];
        let cfg = TopKConfig {
            k: 2,
            temperature: 1.0,
        };
        let mut rng = LcgRng::new(42);
        for _ in 0..500 {
            let t = top_k_sample(&logits, &cfg, &mut rng).expect("ok");
            assert!(t == 0 || t == 1, "got truncated token {t}");
        }
    }

    #[test]
    fn single_vocab_returns_zero() {
        let logits = vec![2.71_f64];
        let cfg = TopKConfig {
            k: 5,
            temperature: 1.0,
        };
        let mut rng = LcgRng::new(0);
        for _ in 0..10 {
            assert_eq!(top_k_sample(&logits, &cfg, &mut rng).expect("ok"), 0);
        }
    }

    #[test]
    fn argmax_prefers_first_of_equal_maxima() {
        assert_eq!(argmax(&[1.0, 3.0, 3.0, 2.0]), 1);
        assert_eq!(argmax(&[5.0]), 0);
    }

    #[test]
    fn top_k_indices_descending_with_index_tiebreak() {
        assert_eq!(top_k_indices(&[3.0, 1.0, 3.0, 1.0, 1.0, 0.0], 3), vec![0, 2, 1]);
        assert_eq!(top_k_indices(&[0.5, 2.0, -1.0], 3), vec![1, 0, 2]);
    }
}
409}