/// Knobs controlling how [`sample`] turns a vector of logits into a token id.
///
/// Build one with [`SamplingConfig::default`] (plain sampling from the full
/// softmax), [`SamplingConfig::greedy`], or struct-update syntax, e.g.
/// `SamplingConfig { top_k: 40, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct SamplingConfig {
    /// Logits are divided by this before the softmax. Values below `1.0`
    /// sharpen the distribution and values above flatten it; `0.0` makes
    /// `sample` fall back to greedy `argmax` decoding.
    pub temperature: f32,
    /// Keep only the `top_k` highest-scoring tokens. `0` disables the filter;
    /// `1` makes `sample` use greedy decoding.
    pub top_k: usize,
    /// Nucleus threshold: keep the smallest set of most likely tokens whose
    /// cumulative probability reaches `top_p`. `1.0` disables the filter.
    pub top_p: f32,
    /// Repetition penalty factor; `1.0` means no penalty. Not read by
    /// `sample` itself. Presumably callers pass it to
    /// `apply_repetition_penalty` (NOTE(review): confirm at call sites).
    pub repetition_penalty: f32,
}

impl Default for SamplingConfig {
    /// Returns temperature `1.0`, no top-k or top-p filtering, and no
    /// repetition penalty.
    fn default() -> Self {
        Self {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            repetition_penalty: 1.0,
        }
    }
}

impl SamplingConfig {
    /// Returns a configuration for deterministic greedy decoding.
    ///
    /// With this configuration, `sample` returns the `argmax` of the logits
    /// before any random draw, so the seed has no effect.
    #[must_use]
    pub fn greedy() -> Self {
        Self {
            temperature: 0.0,
            top_k: 1,
            top_p: 1.0,
            repetition_penalty: 1.0,
        }
    }
}
40
41pub fn sample(logits: &[f32], config: &SamplingConfig, rng_seed: u64) -> u32 {
46 let mut scores: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
47
48 if config.temperature == 0.0 || config.top_k == 1 {
50 return argmax(logits) as u32;
51 }
52
53 if config.temperature != 1.0 {
55 let inv_temp = 1.0 / config.temperature;
56 for (_, score) in &mut scores {
57 *score *= inv_temp;
58 }
59 }
60
61 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
63
64 if config.top_k > 0 && config.top_k < scores.len() {
66 scores.truncate(config.top_k);
67 }
68
69 let max_score = scores[0].1;
71 let mut sum = 0.0f32;
72 for (_, score) in &mut scores {
73 *score = (*score - max_score).exp();
74 sum += *score;
75 }
76 for (_, score) in &mut scores {
77 *score /= sum;
78 }
79
80 if config.top_p < 1.0 {
82 let mut cumulative = 0.0f32;
83 let mut cutoff = scores.len();
84 for (i, (_, prob)) in scores.iter().enumerate() {
85 cumulative += prob;
86 if cumulative >= config.top_p {
87 cutoff = i + 1;
88 break;
89 }
90 }
91 scores.truncate(cutoff);
92
93 let sum: f32 = scores.iter().map(|(_, p)| p).sum();
95 for (_, prob) in &mut scores {
96 *prob /= sum;
97 }
98 }
99
100 let r = simple_rng(rng_seed);
102 let mut cumulative = 0.0f32;
103 for (token_id, prob) in &scores {
104 cumulative += prob;
105 if r < cumulative {
106 return *token_id as u32;
107 }
108 }
109
110 scores.last().map(|(id, _)| *id as u32).unwrap_or(0)
112}
113
/// Returns the index of the largest value in `logits`.
///
/// NaN entries are skipped, so a stray NaN can neither win nor hide the real
/// maximum. When several entries share the maximum, the lowest index wins.
/// Returns `0` when `logits` is empty or contains only NaNs.
pub fn argmax(logits: &[f32]) -> usize {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &value) in logits.iter().enumerate() {
        if value.is_nan() {
            continue;
        }
        // Strictly greater, so on ties the earlier index is kept.
        let better = match best {
            None => true,
            Some((_, top)) => value > top,
        };
        if better {
            best = Some((idx, value));
        }
    }
    best.map_or(0, |(idx, _)| idx)
}
123
/// Scales, in place, the logit of every token id that appears in
/// `generated_tokens`, to discourage repeats.
///
/// Positive logits are divided by `penalty` and non-positive logits are
/// multiplied by it. A `penalty > 1.0` therefore never raises a logit.
///
/// Each distinct token is penalised once, however often it was generated,
/// so repeats do not compound to `penalty^n`. Token ids outside
/// `0..logits.len()` are ignored. A `penalty` of `1.0` is a no-op.
pub fn apply_repetition_penalty(logits: &mut [f32], generated_tokens: &[u32], penalty: f32) {
    if penalty == 1.0 || generated_tokens.is_empty() {
        return;
    }

    // Deduplicate so a token generated n times is not penalised n times.
    let mut unique = generated_tokens.to_vec();
    unique.sort_unstable();
    unique.dedup();

    for token in unique {
        if let Some(logit) = logits.get_mut(token as usize) {
            if *logit > 0.0 {
                *logit /= penalty;
            } else {
                *logit *= penalty;
            }
        }
    }
}
140
/// Maps `seed` to a pseudo-random `f32` in `[0.0, 1.0)`.
///
/// Applies one xorshift64 step (shifts 13, 7, 17) and scales the low 24 bits
/// by `2^-24`. Every result is therefore exactly representable and strictly
/// below `1.0`. The mapping is a pure function of `seed`.
///
/// Zero is a fixed point of xorshift, so a zero seed is replaced by a fixed
/// non-zero constant. Otherwise seed `0` would always yield `0.0` and make
/// `sample` silently pick the most likely candidate. Non-zero seeds are used
/// unchanged.
fn simple_rng(seed: u64) -> f32 {
    // Any non-zero constant works; this is the 64-bit golden-ratio constant.
    const ZERO_SEED_SUBSTITUTE: u64 = 0x9E37_79B9_7F4A_7C15;

    let mut x = if seed == 0 { ZERO_SEED_SUBSTITUTE } else { seed };
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    (x & 0x00FF_FFFF) as f32 / 0x0100_0000 as f32
}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing penalised logits.
    const TOLERANCE: f32 = 1e-6;

    /// Returns `true` when `actual` lies within `TOLERANCE` of `expected`.
    fn close_to(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    #[test]
    fn greedy_sampling() {
        // Index 3 holds the largest logit (0.9).
        let token = sample(&[0.1, 0.5, 0.3, 0.9, 0.2], &SamplingConfig::greedy(), 42);
        assert_eq!(token, 3);
    }

    #[test]
    fn argmax_basic() {
        let cases: [(&[f32], usize); 3] = [
            (&[1.0, 3.0, 2.0], 1),
            (&[5.0, 1.0, 2.0], 0),
            (&[-1.0, -2.0, -0.5], 2),
        ];
        for &(input, expected) in &cases {
            assert_eq!(argmax(input), expected);
        }
    }

    #[test]
    fn temperature_zero_is_greedy() {
        let config = SamplingConfig {
            temperature: 0.0,
            ..Default::default()
        };
        assert_eq!(sample(&[0.1, 0.9, 0.5], &config, 123), 1);
    }

    #[test]
    fn top_k_limits_candidates() {
        // Tokens 1 (0.9) and 2 (0.8) are the two best; nothing else may be drawn.
        let logits: [f32; 5] = [0.1, 0.9, 0.8, 0.05, 0.01];
        let config = SamplingConfig {
            temperature: 1.0,
            top_k: 2,
            top_p: 1.0,
            repetition_penalty: 1.0,
        };
        (0..100).for_each(|seed| {
            let token = sample(&logits, &config, seed);
            assert!(
                matches!(token, 1 | 2),
                "top_k=2 sampled token {token}, expected 1 or 2"
            );
        });
    }

    #[test]
    fn top_p_nucleus_sampling() {
        // Token 0 carries almost all of the probability mass, so a 0.5
        // nucleus contains only that token.
        let config = SamplingConfig {
            temperature: 1.0,
            top_k: 0,
            top_p: 0.5,
            repetition_penalty: 1.0,
        };
        let token = sample(&[10.0, 1.0, 0.1, 0.01], &config, 42);
        assert_eq!(token, 0, "nucleus sampling should pick dominant token");
    }

    #[test]
    fn repetition_penalty() {
        let mut logits: [f32; 3] = [0.5, 0.9, 0.3];
        apply_repetition_penalty(&mut logits, &[1], 2.0);
        // The repeated token's positive logit is halved; the others are untouched.
        assert!(close_to(logits[1], 0.45));
        assert!(close_to(logits[0], 0.5));
        assert!(close_to(logits[2], 0.3));
    }

    #[test]
    fn repetition_penalty_negative_logits() {
        let mut logits: [f32; 3] = [-0.5, 0.9, -0.3];
        apply_repetition_penalty(&mut logits, &[0, 2], 2.0);
        // Negative logits are multiplied, pushing them further down.
        assert!(close_to(logits[0], -1.0));
        assert!(close_to(logits[2], -0.6));
    }

    #[test]
    fn default_config() {
        let SamplingConfig {
            temperature,
            top_k,
            top_p,
            repetition_penalty,
        } = SamplingConfig::default();
        assert_eq!(temperature, 1.0);
        assert_eq!(top_k, 0);
        assert_eq!(top_p, 1.0);
        assert_eq!(repetition_penalty, 1.0);
    }

    #[test]
    fn simple_rng_in_range() {
        for (seed, val) in (0..1000).map(|s| (s, simple_rng(s))) {
            assert!(
                (0.0..1.0).contains(&val),
                "rng({seed}) = {val} out of range"
            );
        }
    }
}