/// Settings that control how [`sample`] picks the next token from a logit vector.
#[derive(Debug, Clone, PartialEq)]
pub struct SamplingConfig {
    /// Logits are scaled by `1 / temperature` before the softmax.
    /// `0.0` selects greedy decoding; `1.0` leaves the logits unscaled.
    pub temperature: f32,
    /// Keep only the `top_k` highest-scoring tokens.
    /// `0` disables the limit and `1` selects greedy decoding.
    pub top_k: usize,
    /// Nucleus threshold: keep the smallest set of most-likely tokens whose
    /// cumulative probability reaches `top_p`. Values `>= 1.0` disable it.
    pub top_p: f32,
    /// Penalty factor for previously generated tokens; `1.0` means no penalty.
    /// [`sample`] does not read this field — pass it to
    /// [`apply_repetition_penalty`] before sampling.
    pub repetition_penalty: f32,
}

impl Default for SamplingConfig {
    /// Neutral settings: unscaled logits, no top-k or top-p truncation and
    /// no repetition penalty.
    fn default() -> Self {
        Self {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            repetition_penalty: 1.0,
        }
    }
}

impl SamplingConfig {
    /// Configuration for deterministic greedy decoding: [`sample`]
    /// short-circuits to [`argmax`] and ignores the RNG seed.
    #[must_use]
    pub fn greedy() -> Self {
        Self {
            temperature: 0.0,
            top_k: 1,
            ..Self::default()
        }
    }
}
40
41pub fn sample(logits: &[f32], config: &SamplingConfig, rng_seed: u64) -> u32 {
46 let mut scores: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
47
48 if config.temperature == 0.0 || config.top_k == 1 {
50 return argmax(logits) as u32;
51 }
52
53 if config.temperature != 1.0 {
55 let inv_temp = 1.0 / config.temperature;
56 for (_, score) in &mut scores {
57 *score *= inv_temp;
58 }
59 }
60
61 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
63
64 if config.top_k > 0 && config.top_k < scores.len() {
66 scores.truncate(config.top_k);
67 }
68
69 let max_score = scores[0].1;
71 let mut sum = 0.0f32;
72 for (_, score) in &mut scores {
73 *score = (*score - max_score).exp();
74 sum += *score;
75 }
76 for (_, score) in &mut scores {
77 *score /= sum;
78 }
79
80 if config.top_p < 1.0 {
82 let mut cumulative = 0.0f32;
83 let mut cutoff = scores.len();
84 for (i, (_, prob)) in scores.iter().enumerate() {
85 cumulative += prob;
86 if cumulative >= config.top_p {
87 cutoff = i + 1;
88 break;
89 }
90 }
91 scores.truncate(cutoff);
92
93 let sum: f32 = scores.iter().map(|(_, p)| p).sum();
95 for (_, prob) in &mut scores {
96 *prob /= sum;
97 }
98 }
99
100 let r = simple_rng(rng_seed);
102 let mut cumulative = 0.0f32;
103 for (token_id, prob) in &scores {
104 cumulative += prob;
105 if r < cumulative {
106 return *token_id as u32;
107 }
108 }
109
110 scores.last().map(|(id, _)| *id as u32).unwrap_or(0)
112}
113
/// Returns the index of the largest value in `logits`.
///
/// NaN entries are skipped, so they can neither win nor disturb the
/// comparison. Ties resolve to the lowest index, as NumPy's `argmax` does.
/// Returns `0` when `logits` is empty or contains only NaN.
pub fn argmax(logits: &[f32]) -> usize {
    logits
        .iter()
        .copied()
        .enumerate()
        .filter(|&(_, value)| !value.is_nan())
        .fold(None, |best: Option<(usize, f32)>, (index, value)| match best {
            // Replace only on a strictly larger value so ties keep the earlier index.
            Some((_, best_value)) if best_value >= value => best,
            _ => Some((index, value)),
        })
        .map_or(0, |(index, _)| index)
}
123
/// Penalises previously generated tokens by adjusting their logits in place.
///
/// Each distinct token id in `generated_tokens` that indexes into `logits` is
/// penalised exactly once:
/// - positive logits are divided by `penalty`;
/// - zero and negative logits are multiplied by it.
///
/// With `penalty > 1` this pushes the logit toward lower preference either way.
/// Repeated occurrences of the same id do not compound the penalty, and
/// out-of-range ids are ignored. A `penalty` of exactly `1.0` is a no-op.
pub fn apply_repetition_penalty(logits: &mut [f32], generated_tokens: &[u32], penalty: f32) {
    if penalty == 1.0 {
        return;
    }

    // Deduplicate first: without this, a token generated n times would have its
    // logit scaled by `penalty^n` instead of once.
    let len = logits.len();
    let mut targets: Vec<usize> = generated_tokens
        .iter()
        .map(|&token| token as usize)
        .filter(|&idx| idx < len)
        .collect();
    targets.sort_unstable();
    targets.dedup();

    for idx in targets {
        let logit = &mut logits[idx];
        if *logit > 0.0 {
            *logit /= penalty;
        } else {
            *logit *= penalty;
        }
    }
}
140
/// Maps `seed` to an `f32` in `[0, 1)` using a single xorshift64 step
/// (shift triple 13, 7, 17). Only the low 24 bits of the scrambled state are
/// kept. A seed of `0` maps to `0.0`.
fn simple_rng(seed: u64) -> f32 {
    // 24 bits is the width of an `f32` significand, so the kept value and the
    // scale both convert to `f32` exactly.
    const KEPT_BITS: u32 = 24;

    let left = seed ^ (seed << 13);
    let right = left ^ (left >> 7);
    let state = right ^ (right << 17);

    let kept = state & ((1u64 << KEPT_BITS) - 1);
    kept as f32 / (1u64 << KEPT_BITS) as f32
}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with the given temperature / top-k / top-p and no
    /// repetition penalty.
    fn make_config(temperature: f32, top_k: usize, top_p: f32) -> SamplingConfig {
        SamplingConfig {
            temperature,
            top_k,
            top_p,
            repetition_penalty: 1.0,
        }
    }

    #[test]
    fn greedy_sampling() {
        let picked = sample(&[0.1_f32, 0.5, 0.3, 0.9, 0.2], &SamplingConfig::greedy(), 42);
        assert_eq!(picked, 3);
    }

    #[test]
    fn argmax_basic() {
        let cases: [([f32; 3], usize); 3] = [
            ([1.0, 3.0, 2.0], 1),
            ([5.0, 1.0, 2.0], 0),
            ([-1.0, -2.0, -0.5], 2),
        ];
        for (input, expected) in cases {
            assert_eq!(argmax(&input), expected);
        }
    }

    #[test]
    fn temperature_zero_is_greedy() {
        let cfg = SamplingConfig {
            temperature: 0.0,
            ..SamplingConfig::default()
        };
        assert_eq!(sample(&[0.1_f32, 0.9, 0.5], &cfg, 123), 1);
    }

    #[test]
    fn top_k_limits_candidates() {
        // Tokens 1 and 2 carry the two highest logits.
        let logits = [0.1_f32, 0.9, 0.8, 0.05, 0.01];
        let cfg = make_config(1.0, 2, 1.0);

        for seed in 0..100 {
            let token = sample(&logits, &cfg, seed);
            assert!(
                matches!(token, 1 | 2),
                "top_k=2 sampled token {token}, expected 1 or 2"
            );
        }
    }

    #[test]
    fn top_p_nucleus_sampling() {
        // Token 0 alone holds far more than half of the probability mass.
        let cfg = make_config(1.0, 0, 0.5);
        let token = sample(&[10.0_f32, 1.0, 0.1, 0.01], &cfg, 42);
        assert_eq!(token, 0, "nucleus sampling should pick dominant token");
    }

    #[test]
    fn repetition_penalty() {
        let mut logits = [0.5_f32, 0.9, 0.3];
        apply_repetition_penalty(&mut logits, &[1], 2.0);

        // Only the penalised (positive) token 1 is halved; the rest are untouched.
        let expected: [f32; 3] = [0.5, 0.45, 0.3];
        for (got, want) in logits.iter().zip(expected) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn repetition_penalty_negative_logits() {
        let mut logits = [-0.5_f32, 0.9, -0.3];
        apply_repetition_penalty(&mut logits, &[0, 2], 2.0);

        // Negative logits are multiplied, pushing them further down.
        assert!((logits[0] + 1.0).abs() < 1e-6);
        assert!((logits[2] + 0.6).abs() < 1e-6);
    }

    #[test]
    fn default_config() {
        let SamplingConfig {
            temperature,
            top_k,
            top_p,
            repetition_penalty,
        } = SamplingConfig::default();
        assert_eq!(temperature, 1.0);
        assert_eq!(top_k, 0);
        assert_eq!(top_p, 1.0);
        assert_eq!(repetition_penalty, 1.0);
    }

    #[test]
    fn simple_rng_in_range() {
        for (seed, val) in (0..1000).map(|s| (s, simple_rng(s))) {
            assert!(
                (0.0..1.0).contains(&val),
                "rng({seed}) = {val} out of range"
            );
        }
    }

    #[test]
    fn temperature_zero_always_picks_argmax() {
        let logits = [0.1_f32, 0.3, 0.9, 0.5, 0.2, 0.8, 0.7, 0.4];
        let cfg = make_config(0.0, 0, 1.0);

        for seed in 0..200 {
            let token = sample(&logits, &cfg, seed);
            assert_eq!(
                token, 2,
                "temp=0 should always pick argmax (token 2), got {token} at seed {seed}"
            );
        }
    }

    #[test]
    fn high_temperature_distributes_samples() {
        // A temperature of 100 flattens these logits to a near-uniform distribution.
        let logits = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let cfg = make_config(100.0, 0, 1.0);

        // Bit `i` is set once token `i` has been drawn.
        let mut seen_mask = 0u8;
        for seed in 0..500 {
            let token = sample(&logits, &cfg, seed) as usize;
            assert!(token < 5, "token {token} out of range");
            seen_mask |= 1u8 << token;
        }

        let seen_count = seen_mask.count_ones();
        assert!(
            seen_count >= 3,
            "high temperature should sample diverse tokens, but only {seen_count}/5 seen"
        );
    }

    #[test]
    fn repetition_penalty_reduces_repeated_token_probability() {
        let mut logits = [0.1_f32, 10.0, 0.2, 9.5];
        assert_eq!(argmax(&logits), 1, "pre-penalty argmax should be token 1");

        apply_repetition_penalty(&mut logits, &[1], 20.0);
        let winner = argmax(&logits);
        assert_ne!(
            winner,
            1,
            "after heavy repetition penalty, argmax should shift away from token 1"
        );
        assert_eq!(
            winner,
            3,
            "after penalizing token 1, token 3 (9.5) should become argmax"
        );
    }

    #[test]
    fn softmax_all_negative_logits_produces_valid_distribution() {
        let logits = [-100.0_f32, -200.0, -150.0, -300.0];
        let token = sample(&logits, &make_config(1.0, 0, 1.0), 42);

        assert!(
            (token as usize) < logits.len(),
            "sampled token {token} should be in valid range"
        );
        assert_eq!(
            argmax(&logits),
            0,
            "argmax of all-negative logits should be index 0 (-100.0)"
        );
    }

    #[test]
    fn sample_with_single_token_vocab() {
        let logits = [0.5_f32];
        let cases = [
            (SamplingConfig::greedy(), 0),
            (make_config(1.0, 0, 1.0), 42),
        ];
        for (cfg, seed) in cases {
            assert_eq!(sample(&logits, &cfg, seed), 0);
        }
    }
}