rten_generate/sampler.rs

1//! Samplers which select a token from model outputs.
2
3use std::cell::RefCell;
4
5use rten_simd::SimdOp;
6use rten_vecmath::Softmax;
7
8use crate::Logits;
9use crate::generator::TokenId;
10
11/// Samplers take the output logits from a model and select a token ID.
12pub trait Sampler {
13    /// Sample a token ID from the output logits of a model.
14    ///
15    /// # Panics
16    ///
17    /// `sample` will panic if `logits` is empty.
18    fn sample(&self, logits: &Logits) -> TokenId;
19}
20
/// A [`Sampler`] that deterministically returns the token ID whose logit is
/// largest, i.e. the most probable token.
#[derive(Clone, Default)]
pub struct ArgMax {
    // Private field so that callers construct this via `new` or `Default`
    // rather than a struct literal.
    _private: (),
}

impl ArgMax {
    /// Create a new arg-max sampler.
    pub fn new() -> ArgMax {
        Self::default()
    }
}
32
33impl Sampler for ArgMax {
34    fn sample(&self, logits: &Logits) -> TokenId {
35        let next_id = logits
36            .enumerate()
37            .reduce(|(max_i, max_val), (i, val)| {
38                if val > max_val {
39                    (i, val)
40                } else {
41                    (max_i, max_val)
42                }
43            })
44            .expect("logits should be non-empty")
45            .0;
46        next_id as TokenId
47    }
48}
49
50/// A [`Sampler`] which chooses a token ID according to the probability of each
51/// logit.
52///
53/// Input logits are first normalized using a softmax operation before a token
54/// ID is sampled according to the probability of each logit.
55///
56/// By default sampling uses a random seed so results will vary for each run.
57/// To get repeatable sampling, use [`with_seed`](Multinomial::with_seed).
58#[derive(Clone, Default)]
59pub struct Multinomial {
60    rng: RefCell<fastrand::Rng>,
61
62    // Scratch space for normalized logits.
63    scratch: RefCell<Vec<f32>>,
64}
65
66impl Multinomial {
67    /// Create a sampler with a random seed.
68    pub fn new() -> Self {
69        Self {
70            rng: RefCell::new(fastrand::Rng::default()),
71            scratch: RefCell::new(Vec::new()),
72        }
73    }
74
75    /// Create a sampler with a fixed seed.
76    ///
77    /// This guarantees repeatable sampling.
78    pub fn with_seed(seed: u64) -> Self {
79        let rng = fastrand::Rng::with_seed(seed);
80        Self {
81            rng: RefCell::new(rng),
82            scratch: RefCell::new(Vec::new()),
83        }
84    }
85}
86
87impl Sampler for Multinomial {
88    fn sample(&self, logits: &Logits) -> TokenId {
89        assert!(!logits.is_empty());
90
91        // Normalize logits to probabilities.
92        let mut scratch = self.scratch.borrow_mut();
93        scratch.clear();
94        scratch.reserve(logits.len());
95        let scratch = &mut scratch.spare_capacity_mut()[..logits.len()];
96        let probs = Softmax::new(logits.logits(), scratch).dispatch();
97
98        let mut rng = self.rng.borrow_mut();
99
100        // Sample ID according to probabilities.
101        //
102        // `multinomial` may return None if the input contains a NaN or
103        // infinity. In that case we fall back to the ID zero.
104        let idx = multinomial(&mut rng, probs).unwrap_or(0);
105
106        logits.indices()[idx]
107    }
108}
109
110/// Sample an item from a vector of probabilities.
111///
112/// Returns the index of the selected item, or `None` if the vector is empty
113/// or sums to less than 1.
114fn multinomial(rng: &mut fastrand::Rng, probs: &[f32]) -> Option<usize> {
115    let target = rng.f32();
116
117    let mut cum_prob = 0.;
118    for (idx, &prob) in probs.iter().enumerate() {
119        cum_prob += prob;
120        if target <= cum_prob {
121            return Some(idx);
122        }
123    }
124
125    None
126}
127
#[cfg(test)]
mod tests {
    use rten_simd::SimdOp;
    use rten_testing::TestCases;
    use rten_vecmath::Softmax;

    use super::{ArgMax, Multinomial, Sampler};
    use crate::Logits;
    use crate::generator::TokenId;

    #[test]
    fn test_argmax() {
        let sampler = ArgMax::new();
        let logits = Logits::dense(vec![0.1, 0.2, 0.8, 0.7]);

        // ArgMax is deterministic, so every call must pick the largest logit.
        (0..5).for_each(|_| assert_eq!(sampler.sample(&logits), 2));
    }

    #[test]
    fn test_multinomial() {
        const N_ITERS: u32 = 512;
        let logits = Logits::dense(vec![0.25, 0.25, 0.5]);
        let sampler = Multinomial::with_seed(1234);

        // Tally how often each token ID is drawn.
        let mut hits = vec![0u32; logits.len()];
        for _ in 0..N_ITERS {
            hits[sampler.sample(&logits) as usize] += 1;
        }

        // Reference distribution: the softmax of the input logits.
        let mut expected_probs = logits.logits().to_vec();
        Softmax::new_mut(&mut expected_probs).dispatch();

        // Each observed count must lie within `threshold` (as a fraction) of
        // the expected count. More samples would tighten the agreement.
        let threshold = 0.12;
        for (hit_count, prob) in hits.into_iter().zip(expected_probs) {
            let expected = (prob * N_ITERS as f32).round() as i32;
            let delta_frac = (hit_count as i32 - expected).abs() as f32 / expected as f32;

            assert!(
                delta_frac <= threshold,
                "sample count differs from expectation by {:.1}%, above threshold {}%",
                delta_frac * 100.0,
                threshold * 100.0
            );
        }
    }

    #[test]
    fn test_multinomial_nan_infinity() {
        #[derive(Debug)]
        struct Case {
            logits: Vec<f32>,
            expected: TokenId,
        }

        let cases = [
            // Softmax normalization spreads NaNs and positive infinities.
            Case {
                logits: vec![0.1, f32::NAN, 0.5],
                expected: 0,
            },
            Case {
                logits: vec![0.1, f32::INFINITY, 0.5],
                expected: 0,
            },
            // Negative infinity shrinks to zero after softmax.
            Case {
                logits: vec![0., f32::NEG_INFINITY, 100.0],
                expected: 2,
            },
        ];

        cases.test_each(|case| {
            let sampler = Multinomial::with_seed(1234);
            let logits = Logits::dense(case.logits.clone());
            for _ in 0..10 {
                assert_eq!(sampler.sample(&logits), case.expected);
            }
        });
    }
}