1use std::cell::RefCell;
4
5use rten_simd::SimdOp;
6use rten_vecmath::Softmax;
7
8use crate::Logits;
9use crate::generator::TokenId;
10
/// Strategy for choosing a single token from a set of logits.
pub trait Sampler {
    /// Selects a token ID given `logits`.
    ///
    /// Takes `&self`, so implementations that need mutable state (e.g. a
    /// random number generator) must use interior mutability.
    fn sample(&self, logits: &Logits) -> TokenId;
}
20
/// Deterministic sampler that always selects the entry with the largest logit.
///
/// The private unit field prevents code outside this module from building the
/// struct with a literal; construct it with [`ArgMax::new`] or `Default`.
#[derive(Clone, Default)]
pub struct ArgMax {
    _private: (),
}

impl ArgMax {
    /// Creates an `ArgMax` sampler. Equivalent to `ArgMax::default()`.
    pub fn new() -> ArgMax {
        Self::default()
    }
}
32
33impl Sampler for ArgMax {
34 fn sample(&self, logits: &Logits) -> TokenId {
35 let next_id = logits
36 .enumerate()
37 .reduce(|(max_i, max_val), (i, val)| {
38 if val > max_val {
39 (i, val)
40 } else {
41 (max_i, max_val)
42 }
43 })
44 .expect("logits should be non-empty")
45 .0;
46 next_id as TokenId
47 }
48}
49
50#[derive(Clone, Default)]
59pub struct Multinomial {
60 rng: RefCell<fastrand::Rng>,
61
62 scratch: RefCell<Vec<f32>>,
64}
65
66impl Multinomial {
67 pub fn new() -> Self {
69 Self {
70 rng: RefCell::new(fastrand::Rng::default()),
71 scratch: RefCell::new(Vec::new()),
72 }
73 }
74
75 pub fn with_seed(seed: u64) -> Self {
79 let rng = fastrand::Rng::with_seed(seed);
80 Self {
81 rng: RefCell::new(rng),
82 scratch: RefCell::new(Vec::new()),
83 }
84 }
85}
86
87impl Sampler for Multinomial {
88 fn sample(&self, logits: &Logits) -> TokenId {
89 assert!(!logits.is_empty());
90
91 let mut scratch = self.scratch.borrow_mut();
93 scratch.clear();
94 scratch.reserve(logits.len());
95 let scratch = &mut scratch.spare_capacity_mut()[..logits.len()];
96 let probs = Softmax::new(logits.logits(), scratch).dispatch();
97
98 let mut rng = self.rng.borrow_mut();
99
100 let idx = multinomial(&mut rng, probs).unwrap_or(0);
105
106 logits.indices()[idx]
107 }
108}
109
110fn multinomial(rng: &mut fastrand::Rng, probs: &[f32]) -> Option<usize> {
115 let target = rng.f32();
116
117 let mut cum_prob = 0.;
118 for (idx, &prob) in probs.iter().enumerate() {
119 cum_prob += prob;
120 if target <= cum_prob {
121 return Some(idx);
122 }
123 }
124
125 None
126}
127
#[cfg(test)]
mod tests {
    use rten_simd::SimdOp;
    use rten_testing::TestCases;
    use rten_vecmath::Softmax;

    use super::{ArgMax, Multinomial, Sampler};
    use crate::Logits;
    use crate::generator::TokenId;

    /// `ArgMax` is deterministic: every draw must pick the largest logit.
    #[test]
    fn test_argmax() {
        let sampler = ArgMax::new();
        let logits = Logits::dense(vec![0.1, 0.2, 0.8, 0.7]);

        for _ in 0..5 {
            assert_eq!(sampler.sample(&logits), 2);
        }
    }

    /// The empirical token frequencies of a seeded `Multinomial` sampler
    /// should be close to the softmax of the logits.
    #[test]
    fn test_multinomial() {
        let logits = Logits::dense(vec![0.25, 0.25, 0.5]);
        let sampler = Multinomial::with_seed(1234);
        let n_iters = 512;

        // Tally how often each token is drawn.
        let mut counts = vec![0u32; logits.len()];
        for _ in 0..n_iters {
            counts[sampler.sample(&logits) as usize] += 1;
        }

        // Reference distribution: the softmax of the raw logits.
        let mut expected_probs = logits.logits().to_vec();
        Softmax::new_mut(&mut expected_probs).dispatch();

        // Maximum tolerated relative deviation between observed and expected
        // counts for each token.
        let threshold = 0.12;
        for (count, prob) in counts.into_iter().zip(expected_probs) {
            let expected = (prob * n_iters as f32).round() as i32;
            let delta_frac = (count as i32 - expected).abs() as f32 / expected as f32;

            assert!(
                delta_frac <= threshold,
                "sample count differs from expectation by {:.1}%, above threshold {}%",
                delta_frac * 100.0,
                threshold * 100.0
            );
        }
    }

    /// Non-finite logits must not make sampling panic or pick a surprising token.
    #[test]
    fn test_multinomial_nan_infinity() {
        #[derive(Debug)]
        struct Case {
            logits: Vec<f32>,
            expected: TokenId,
        }

        let cases = [
            // A NaN logit is expected to trigger the fallback to token 0.
            Case {
                logits: vec![0.1, f32::NAN, 0.5],
                expected: 0,
            },
            // An infinite logit is also expected to trigger the fallback to token 0.
            Case {
                logits: vec![0.1, f32::INFINITY, 0.5],
                expected: 0,
            },
            // A -inf logit gets no probability mass, and token 2's large logit
            // dominates the rest.
            Case {
                logits: vec![0., f32::NEG_INFINITY, 100.0],
                expected: 2,
            },
        ];

        cases.test_each(|case| {
            let sampler = Multinomial::with_seed(1234);
            let logits = Logits::dense(case.logits.clone());
            for _ in 0..10 {
                assert_eq!(sampler.sample(&logits), case.expected);
            }
        });
    }
}