Skip to main content

oxicuda_seq/decoders/
typical.rs

1//! Typical (locally typical) decoding for autoregressive sequence
2//! generation.
3//!
4//! Reference: Meister, C., Pimentel, T., Wiher, G. & Cotterell, R. (2022).
5//! *Typical Decoding for Natural Language Generation*. TACL 2023 (arXiv 2202.00666).
6//! <https://arxiv.org/abs/2202.00666>.
7//!
8//! # Algorithm
9//!
10//! Given logits `z ∈ ℝᵛ`, temperature `T > 0`, mass threshold `τ ∈ (0, 1]`
11//! and a lower bound `min_tokens ≥ 1`:
12//!
13//! ```text
14//! 1. p_i  = softmax(z_i / T)                       (numerically-stable)
15//! 2. H    = -Σ_i p_i log p_i                       (conditional entropy)
16//! 3. c_i  = |-log p_i - H|                         (information-content gap)
17//! 4. sort indices by c_i ascending
18//! 5. cumulate p along this order until ≥ τ
19//! 6. enforce min_tokens; renormalise the kept set
20//! 7. sample
21//! ```
22//!
23//! Locally-typical decoding selects tokens whose surprisal is closest to
24//! the expected information content of the next-token distribution.
25//! Setting `τ = 1.0` retains every token (full softmax sampling);
26//! peaked distributions select the argmax (its surprisal matches the
27//! near-zero entropy); uniform distributions retain every token (every
28//! surprisal equals the entropy `log V`).
29//!
//! Probabilities below `LOG_FLOOR` (`1e-300`) are raised to that floor
//! before taking the log, so the surprisal stays finite.  The floor is
//! applied only when computing the information-content gap, so it does not
//! distort the probabilities handed to the categorical sampler.
33
34use crate::error::{SeqError, SeqResult};
35use crate::handle::LcgRng;
36
/// Tunable parameters shared by [`typical_sample`] and
/// [`typical_sample_batch`].
#[derive(Clone, Copy, Debug)]
pub struct TypicalConfig {
    /// Cumulative probability mass `τ` the typical set has to reach;
    /// valid values lie in `(0, 1]`.
    pub tau: f64,
    /// Temperature the logits are divided by before the softmax; must be
    /// strictly positive.
    pub temperature: f64,
    /// Smallest number of tokens that is always kept; must be at least 1.
    pub min_tokens: usize,
}
47
48impl Default for TypicalConfig {
49    fn default() -> Self {
50        Self {
51            tau: 0.95,
52            temperature: 1.0,
53            min_tokens: 1,
54        }
55    }
56}
57
58impl TypicalConfig {
59    fn validate(&self) -> SeqResult<()> {
60        if !self.tau.is_finite() || self.tau <= 0.0 || self.tau > 1.0 {
61            return Err(SeqError::InvalidConfiguration(format!(
62                "typical: tau must be in (0, 1], got {}",
63                self.tau
64            )));
65        }
66        if !self.temperature.is_finite() || self.temperature <= 0.0 {
67            return Err(SeqError::InvalidParameter {
68                name: "temperature".to_string(),
69                value: self.temperature,
70            });
71        }
72        if self.min_tokens == 0 {
73            return Err(SeqError::InvalidConfiguration(
74                "typical: min_tokens must be >= 1".to_string(),
75            ));
76        }
77        Ok(())
78    }
79}
80
81/// Numerically-stable softmax of `logits / temperature`.
82fn softmax_scaled(logits: &[f64], temperature: f64) -> SeqResult<Vec<f64>> {
83    let scaled: Vec<f64> = logits.iter().map(|&z| z / temperature).collect();
84    let max_z = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
85    if !max_z.is_finite() {
86        return Err(SeqError::NumericalInstability(
87            "typical: all logits non-finite".to_string(),
88        ));
89    }
90    let mut probs = vec![0.0_f64; scaled.len()];
91    let mut sum = 0.0_f64;
92    for (i, &z) in scaled.iter().enumerate() {
93        let w = (z - max_z).exp();
94        probs[i] = w;
95        sum += w;
96    }
97    if !sum.is_finite() || sum <= 0.0 {
98        return Err(SeqError::NumericalInstability(
99            "typical: softmax denominator non-positive".to_string(),
100        ));
101    }
102    for q in probs.iter_mut() {
103        *q /= sum;
104    }
105    Ok(probs)
106}
107
/// Shannon entropy `H = -Σ p_i ln p_i` (natural logarithm) of a
/// probability vector.
///
/// Entries that are not strictly positive are skipped, which realises the
/// usual `0 · ln 0 = 0` convention; an empty slice yields `0.0`.  Public
/// for testing and downstream use.
pub fn entropy(probs: &[f64]) -> f64 {
    probs
        .iter()
        .filter(|&&p| p > 0.0)
        .fold(0.0_f64, |acc, &p| acc - p * p.ln())
}
119
/// Smallest probability passed to the logarithm when computing surprisal;
/// anything below it is raised to this floor so that `-ln p` stays finite.
const LOG_FLOOR: f64 = 1e-300;
121
122/// Sample a single token id from `logits` using typical decoding.
123///
124/// # Errors
125///
126/// * [`SeqError::EmptyInput`] if `logits` is empty.
127/// * [`SeqError::InvalidConfiguration`] / [`SeqError::InvalidParameter`]
128///   if `cfg` is malformed.
129pub fn typical_sample(logits: &[f64], cfg: &TypicalConfig, rng: &mut LcgRng) -> SeqResult<usize> {
130    cfg.validate()?;
131    if logits.is_empty() {
132        return Err(SeqError::EmptyInput);
133    }
134    let probs = softmax_scaled(logits, cfg.temperature)?;
135    let h = entropy(&probs);
136
137    // |−log p_i − H|.  Floor p before log to avoid −∞.
138    let mut gaps: Vec<(usize, f64)> = probs
139        .iter()
140        .enumerate()
141        .map(|(i, &p)| {
142            let surprisal = -(p.max(LOG_FLOOR)).ln();
143            (i, (surprisal - h).abs())
144        })
145        .collect();
146
147    // Sort indices by ascending information-content gap, breaking ties
148    // by index for determinism.
149    gaps.sort_by(|&(ia, ga), &(ib, gb)| {
150        ga.partial_cmp(&gb)
151            .unwrap_or(std::cmp::Ordering::Equal)
152            .then(ia.cmp(&ib))
153    });
154
155    // Cumulate p along the typical-set order until >= tau.
156    let mut cum = 0.0_f64;
157    let mut m = 0usize;
158    for (rank, &(idx, _)) in gaps.iter().enumerate() {
159        cum += probs[idx];
160        if cum >= cfg.tau {
161            m = rank + 1;
162            break;
163        }
164    }
165    if m == 0 {
166        m = gaps.len();
167    }
168    let m_eff = m.max(cfg.min_tokens).min(gaps.len());
169
170    // Renormalise the kept set.
171    let mut kept_probs = vec![0.0_f64; m_eff];
172    let mut kept_sum = 0.0_f64;
173    for slot in 0..m_eff {
174        let q = probs[gaps[slot].0];
175        kept_probs[slot] = q;
176        kept_sum += q;
177    }
178    if !kept_sum.is_finite() || kept_sum <= 0.0 {
179        return Err(SeqError::NumericalInstability(
180            "typical: kept mass zero".to_string(),
181        ));
182    }
183    for q in kept_probs.iter_mut() {
184        *q /= kept_sum;
185    }
186
187    let chosen_slot = rng.sample_categorical(&kept_probs);
188    Ok(gaps[chosen_slot].0)
189}
190
191/// Batch variant of [`typical_sample`] for `n` independent rows of length
192/// `vocab` in a flat `logits` buffer.
193///
194/// # Errors
195///
196/// * [`SeqError::EmptyInput`] if `logits` is empty, `n == 0`, or `vocab == 0`.
197/// * [`SeqError::ShapeMismatch`] if `logits.len() != n * vocab`.
198pub fn typical_sample_batch(
199    logits: &[f64],
200    n: usize,
201    vocab: usize,
202    cfg: &TypicalConfig,
203    rng: &mut LcgRng,
204) -> SeqResult<Vec<usize>> {
205    cfg.validate()?;
206    if logits.is_empty() || n == 0 || vocab == 0 {
207        return Err(SeqError::EmptyInput);
208    }
209    if logits.len() != n * vocab {
210        return Err(SeqError::ShapeMismatch {
211            expected: n * vocab,
212            got: logits.len(),
213        });
214    }
215    let mut out = Vec::with_capacity(n);
216    for b in 0..n {
217        let row = &logits[b * vocab..(b + 1) * vocab];
218        out.push(typical_sample(row, cfg, rng)?);
219    }
220    Ok(out)
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Unit-temperature configuration with the given threshold and
    /// minimum kept-set size.
    fn unit_temp(tau: f64, min_tokens: usize) -> TypicalConfig {
        TypicalConfig {
            tau,
            temperature: 1.0,
            min_tokens,
        }
    }

    #[test]
    fn invalid_tau_rejected() {
        let mut rng = LcgRng::new(0);
        let bad_taus = [0.0_f64, -0.1, 1.1, f64::NAN];
        for &tau in bad_taus.iter() {
            let outcome = typical_sample(&[0.1, 0.2], &unit_temp(tau, 1), &mut rng);
            assert!(matches!(outcome, Err(SeqError::InvalidConfiguration(_))));
        }
    }

    #[test]
    fn nonpositive_temperature_rejected() {
        let mut rng = LcgRng::new(0);
        let bad_temperatures = [0.0_f64, -0.5, f64::NAN];
        for &temperature in bad_temperatures.iter() {
            let cfg = TypicalConfig {
                temperature,
                ..unit_temp(0.9, 1)
            };
            let outcome = typical_sample(&[0.1, 0.2], &cfg, &mut rng);
            assert!(matches!(outcome, Err(SeqError::InvalidParameter { .. })));
        }
    }

    #[test]
    fn zero_min_tokens_rejected() {
        let mut rng = LcgRng::new(0);
        let outcome = typical_sample(&[0.1, 0.2], &unit_temp(0.9, 0), &mut rng);
        assert!(matches!(outcome, Err(SeqError::InvalidConfiguration(_))));
    }

    #[test]
    fn empty_logits_rejected() {
        let mut rng = LcgRng::new(0);
        let outcome = typical_sample(&[], &TypicalConfig::default(), &mut rng);
        assert!(matches!(outcome, Err(SeqError::EmptyInput)));
    }

    #[test]
    fn tau_one_keeps_everything() {
        // Five equally likely tokens and tau = 1: each token must show up
        // in a healthy share of 5000 draws.
        let logits = [0.0_f64; 5];
        let cfg = unit_temp(1.0, 1);
        let mut rng = LcgRng::new(0);
        let mut counts = [0usize; 5];
        (0..5000).for_each(|_| {
            let token = typical_sample(&logits, &cfg, &mut rng).expect("ok");
            counts[token] += 1;
        });
        assert!(counts.iter().all(|&c| c > 700), "counts = {counts:?}");
    }

    #[test]
    fn uniform_logits_all_typical() {
        // Uniform over v tokens: H = ln v and every surprisal is ln v too,
        // so every information-content gap vanishes.  Only the entropy and
        // the gaps are checked here; no sampling is performed.
        let probs = softmax_scaled(&[0.0_f64; 4], 1.0).expect("ok");
        let h = entropy(&probs);
        assert!((h - (4.0_f64).ln()).abs() < 1e-12);
        for gap in probs.iter().map(|&p| (-p.ln() - h).abs()) {
            assert!(gap < 1e-12, "gap = {gap}");
        }
    }

    #[test]
    fn peaked_logits_picks_peak() {
        // Entropy is close to zero for a sharply peaked distribution and so
        // is the surprisal of the argmax, which makes the argmax the most
        // typical token.  It must win every draw.
        let logits = [20.0, 0.0, 0.0, 0.0, 0.0];
        let cfg = unit_temp(0.9, 1);
        let mut rng = LcgRng::new(0);
        for _ in 0..200 {
            let token = typical_sample(&logits, &cfg, &mut rng).expect("ok");
            assert_eq!(token, 0);
        }
    }

    #[test]
    fn deterministic_with_seed() {
        // Two generators with the same seed must produce the same stream.
        let logits = [0.5, -1.0, 1.2, 0.0, 2.3];
        let cfg = unit_temp(0.9, 1);
        let (mut left, mut right) = (LcgRng::new(0xC0DE), LcgRng::new(0xC0DE));
        for _ in 0..200 {
            let from_left = typical_sample(&logits, &cfg, &mut left).expect("ok");
            let from_right = typical_sample(&logits, &cfg, &mut right).expect("ok");
            assert_eq!(from_left, from_right);
        }
    }

    #[test]
    fn min_tokens_lower_bound() {
        // A tiny tau alone would keep just the single most-typical token;
        // min_tokens = 3 has to widen the kept set to three.  The logits are
        // spread so that tokens 0..=2 are clearly the three most typical
        // ones while tokens 3 and 4 sit far out in the tail.
        let logits = [3.0_f64, 0.0, -1.0, -20.0, -25.0];
        let cfg = unit_temp(1.0e-6, 3);
        let mut rng = LcgRng::new(0);
        let (mut hit_head, mut hit_tail) = (false, false);
        for _ in 0..1000 {
            match typical_sample(&logits, &cfg, &mut rng).expect("ok") {
                0..=2 => hit_head = true,
                3 | 4 => hit_tail = true,
                _ => {}
            }
        }
        assert!(hit_head);
        assert!(!hit_tail, "tail tokens leaked into the typical set");
    }

    #[test]
    fn batch_correctness() {
        // Two rows of three logits: the first peaks at index 0, the second
        // at index 2.
        let logits = [20.0, -5.0, -5.0, -5.0, -5.0, 20.0];
        let mut rng = LcgRng::new(0);
        let picked = typical_sample_batch(&logits, 2, 3, &unit_temp(0.9, 1), &mut rng).expect("ok");
        assert_eq!(picked, vec![0, 2]);
    }

    #[test]
    fn batch_shape_mismatch() {
        // Five values cannot form a 2 x 3 batch.
        let logits = [0.0_f64; 5];
        let mut rng = LcgRng::new(0);
        let outcome = typical_sample_batch(&logits, 2, 3, &TypicalConfig::default(), &mut rng);
        assert!(matches!(outcome, Err(SeqError::ShapeMismatch { .. })));
    }

    #[test]
    fn numerically_stable_softmax() {
        // exp(1e6) overflows unless the maximum is subtracted first.
        let logits = [1.0e6_f64, 1.0, 1.0];
        let mut rng = LcgRng::new(0);
        let token = typical_sample(&logits, &unit_temp(0.9, 1), &mut rng).expect("ok");
        assert_eq!(token, 0);
    }

    #[test]
    fn symmetric_distribution_sanity() {
        // Mirror-symmetric logits: tokens 0 and 4 share the highest
        // probability, tokens 1 and 3 a lower one, and token 2 is the least
        // likely.  The two modes must be drawn at comparable rates.
        let logits = [2.0_f64, 0.5, -1.0, 0.5, 2.0];
        let cfg = unit_temp(0.9, 1);
        let mut rng = LcgRng::new(0);
        let mut counts = [0usize; 5];
        for _ in 0..6000 {
            let token = typical_sample(&logits, &cfg, &mut rng).expect("ok");
            counts[token] += 1;
        }
        let (first, last) = (counts[0] as f64, counts[4] as f64);
        let (lo, hi) = (first.min(last), first.max(last));
        assert!(hi / lo < 1.5, "asymmetric: counts = {counts:?}");
    }

    #[test]
    fn entropy_known_values() {
        // Uniform over five outcomes: H = ln 5 ≈ 1.6094.
        let h_u = entropy(&[0.2_f64; 5]);
        assert!((h_u - (5.0_f64).ln()).abs() < 1e-12);

        // Almost all mass on one outcome (ε = 1e-12 on each of the other
        // four): H ≈ 0.
        let eps = 1.0e-12;
        let h_p = entropy(&[1.0 - 4.0 * eps, eps, eps, eps, eps]);
        assert!(h_p.abs() < 1.0e-9, "peaked entropy = {h_p}");
    }

    #[test]
    fn single_element_vocab() {
        // With a single token there is nothing else to pick.
        let cfg = TypicalConfig::default();
        let mut rng = LcgRng::new(0);
        for _ in 0..10 {
            let token = typical_sample(&[0.5], &cfg, &mut rng).expect("ok");
            assert_eq!(token, 0);
        }
    }
}