1use std::borrow::Cow;
8use std::cmp::Ordering;
9
10#[derive(Debug, Clone, PartialEq, thiserror::Error)]
12pub enum SamplingError {
13 #[error("logits cannot be empty")]
14 EmptyLogits,
15 #[error("invalid token id in history: {0}")]
16 InvalidHistoryToken(i32),
17 #[error("temperature must be > 0, got {0}")]
18 InvalidTemperature(f32),
19 #[error("top_p must be in (0, 1], got {0}")]
20 InvalidTopP(f32),
21 #[error("repetition_penalty must be >= 1.0, got {0}")]
22 InvalidRepetitionPenalty(f32),
23}
24
/// How [`Sampler::sample`] turns a logits vector into a token id.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SamplingStrategy {
    /// Deterministically return the index of the highest (possibly
    /// repetition-penalized) logit. Temperature, `top_k` and `top_p` are not
    /// consulted.
    Greedy,
    /// Draw a token at random from the temperature-scaled softmax
    /// distribution after the optional `top_k` / `top_p` filters.
    Stochastic,
}
33
34#[derive(Debug, Clone)]
36pub struct SamplingConfig {
37 pub strategy: SamplingStrategy,
38 pub temperature: f32,
39 pub top_k: Option<usize>,
40 pub top_p: Option<f32>,
41 pub repetition_penalty: Option<f32>,
42 pub seed: u64,
43}
44
45impl Default for SamplingConfig {
46 fn default() -> Self {
47 Self {
48 strategy: SamplingStrategy::Stochastic,
49 temperature: 1.0,
50 top_k: None,
51 top_p: None,
52 repetition_penalty: None,
53 seed: 0,
54 }
55 }
56}
57
58impl SamplingConfig {
59 fn validate(&self) -> Result<(), SamplingError> {
60 if self.temperature <= 0.0 {
61 return Err(SamplingError::InvalidTemperature(self.temperature));
62 }
63 if let Some(top_p) = self.top_p {
64 if !(top_p > 0.0 && top_p <= 1.0) {
65 return Err(SamplingError::InvalidTopP(top_p));
66 }
67 }
68 if let Some(penalty) = self.repetition_penalty {
69 if penalty < 1.0 {
70 return Err(SamplingError::InvalidRepetitionPenalty(penalty));
71 }
72 }
73 Ok(())
74 }
75}
76
77pub struct Sampler {
79 cfg: SamplingConfig,
80 rng: XorShift64,
81}
82
83impl Sampler {
84 pub fn new(cfg: SamplingConfig) -> Result<Self, SamplingError> {
85 cfg.validate()?;
86 Ok(Self {
87 rng: XorShift64::seeded(cfg.seed),
88 cfg,
89 })
90 }
91
92 pub fn sample(&mut self, logits: &[f32], history: &[i32]) -> Result<i32, SamplingError> {
94 if logits.is_empty() {
95 return Err(SamplingError::EmptyLogits);
96 }
97
98 let adjusted: Cow<'_, [f32]> = if let Some(penalty) = self.cfg.repetition_penalty {
99 let mut buf = logits.to_vec();
100 apply_repetition_penalty(&mut buf, history, penalty)?;
101 Cow::Owned(buf)
102 } else {
103 Cow::Borrowed(logits)
104 };
105
106 if self.cfg.strategy == SamplingStrategy::Greedy {
107 return greedy_sample(&adjusted);
108 }
109
110 let mut probs = softmax_with_temperature(&adjusted, self.cfg.temperature)?;
111
112 if let Some(top_k) = self.cfg.top_k {
113 apply_top_k(&mut probs, top_k);
114 normalize_probs(&mut probs);
115 }
116
117 if let Some(top_p) = self.cfg.top_p {
118 apply_top_p(&mut probs, top_p)?;
119 }
120
121 normalize_probs(&mut probs);
122 Ok(sample_from_probs(&probs, &mut self.rng))
123 }
124}
125
126pub fn greedy_sample(logits: &[f32]) -> Result<i32, SamplingError> {
128 if logits.is_empty() {
129 return Err(SamplingError::EmptyLogits);
130 }
131
132 let mut best_idx = 0usize;
133 let mut best_val = logits[0];
134 for (idx, &val) in logits.iter().enumerate().skip(1) {
135 if val > best_val {
136 best_idx = idx;
137 best_val = val;
138 }
139 }
140 Ok(best_idx as i32)
141}
142
143pub fn apply_repetition_penalty(
149 logits: &mut [f32],
150 history: &[i32],
151 penalty: f32,
152) -> Result<(), SamplingError> {
153 if penalty < 1.0 {
154 return Err(SamplingError::InvalidRepetitionPenalty(penalty));
155 }
156
157 let mut seen = vec![false; logits.len()];
158 for &token in history {
159 if token < 0 {
160 return Err(SamplingError::InvalidHistoryToken(token));
161 }
162 let idx = token as usize;
163 if idx >= logits.len() {
164 return Err(SamplingError::InvalidHistoryToken(token));
165 }
166 if !seen[idx] {
167 seen[idx] = true;
168 if logits[idx] > 0.0 {
169 logits[idx] /= penalty;
170 } else {
171 logits[idx] *= penalty;
172 }
173 }
174 }
175 Ok(())
176}
177
178fn softmax_with_temperature(logits: &[f32], temperature: f32) -> Result<Vec<f32>, SamplingError> {
179 if logits.is_empty() {
180 return Err(SamplingError::EmptyLogits);
181 }
182 if temperature <= 0.0 {
183 return Err(SamplingError::InvalidTemperature(temperature));
184 }
185
186 let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
187 let max_val = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
188 let mut exps: Vec<f32> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
189 normalize_probs(&mut exps);
190 Ok(exps)
191}
192
/// Rescales `probs` in place so that its entries sum to one.
///
/// The slice is left untouched when its total is zero or negative (e.g. empty
/// or all-zero input). A NaN total is not special-cased and propagates into
/// every entry.
fn normalize_probs(probs: &mut [f32]) {
    let total = probs.iter().sum::<f32>();
    // Equivalent to "not (total <= 0)": positive totals and NaN both divide.
    if total > 0.0 || total.is_nan() {
        for p in probs {
            *p /= total;
        }
    }
}
202
/// Zeroes, in place, every probability outside the `top_k` largest.
///
/// Ties are broken by index (lower index wins), so the selection is
/// deterministic. NaN entries rank below every number and are therefore the
/// first to be dropped. `top_k == 0` or `top_k >= probs.len()` leaves the
/// slice unchanged. The result is not renormalized.
fn apply_top_k(probs: &mut [f32], top_k: usize) {
    if top_k == 0 || top_k >= probs.len() {
        return;
    }

    let mut order: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
    // Sorting requires a total order and may panic otherwise, so NaN cannot
    // simply compare `Equal` to everything. The order used is: NaNs last, then
    // descending probability, then ascending index. Indices are unique, so an
    // unstable sort is fully deterministic.
    order.sort_unstable_by(|a, b| {
        a.1.is_nan()
            .cmp(&b.1.is_nan())
            .then_with(|| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal))
            .then_with(|| a.0.cmp(&b.0))
    });

    for &(idx, _) in &order[top_k..] {
        probs[idx] = 0.0;
    }
}
219
220fn apply_top_p(probs: &mut [f32], top_p: f32) -> Result<(), SamplingError> {
221 if !(top_p > 0.0 && top_p <= 1.0) {
222 return Err(SamplingError::InvalidTopP(top_p));
223 }
224
225 let mut order: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
226 order.sort_by(|a, b| {
227 b.1.partial_cmp(&a.1)
228 .unwrap_or(Ordering::Equal)
229 .then_with(|| a.0.cmp(&b.0))
230 });
231
232 let mut cumulative = 0.0f32;
233 let mut keep = vec![false; probs.len()];
234 for &(idx, p) in &order {
235 cumulative += p;
236 keep[idx] = true;
237 if cumulative >= top_p {
238 break;
239 }
240 }
241
242 for (idx, p) in probs.iter_mut().enumerate() {
243 if !keep[idx] {
244 *p = 0.0;
245 }
246 }
247 Ok(())
248}
249
250fn sample_from_probs(probs: &[f32], rng: &mut XorShift64) -> i32 {
251 let r = rng.next_f32();
252 let mut cumulative = 0.0f32;
253 for (idx, &p) in probs.iter().enumerate() {
254 if p <= 0.0 {
255 continue;
256 }
257 cumulative += p;
258 if r < cumulative {
259 return idx as i32;
260 }
261 }
262
263 probs
265 .iter()
266 .enumerate()
267 .max_by(|a, b| {
268 a.1.partial_cmp(b.1)
269 .unwrap_or(Ordering::Equal)
270 .then_with(|| b.0.cmp(&a.0))
271 })
272 .map(|(i, _)| i as i32)
273 .unwrap_or(0)
274}
275
/// State of the small, deterministic, non-cryptographic xorshift generator
/// used for stochastic sampling.
///
/// The state must never be zero: xorshift maps `0` to `0` forever.
#[derive(Debug, Clone)]
struct XorShift64 {
    /// Current 64-bit generator state; also the most recent output.
    state: u64,
}
280
281impl XorShift64 {
282 fn seeded(seed: u64) -> Self {
283 let state = if seed == 0 {
285 0x9E37_79B9_7F4A_7C15
286 } else {
287 seed
288 };
289 Self { state }
290 }
291
292 fn next_u64(&mut self) -> u64 {
293 let mut x = self.state;
294 x ^= x << 13;
295 x ^= x >> 7;
296 x ^= x << 17;
297 self.state = x;
298 x
299 }
300
301 fn next_f32(&mut self) -> f32 {
302 let v = self.next_u64() >> 40; (v as f32) / ((1u32 << 24) as f32)
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn greedy_selects_max_logit() {
        let logits = vec![0.1, 2.0, 1.5];
        assert_eq!(greedy_sample(&logits).unwrap(), 1);
    }

    #[test]
    fn top_k_limits_candidates() {
        // With k = 1 only the highest logit (index 1) survives, so the draw is forced.
        let cfg = SamplingConfig {
            top_k: Some(1),
            seed: 42,
            ..SamplingConfig::default()
        };
        let mut sampler = Sampler::new(cfg).unwrap();
        let logits = vec![0.1, 5.0, 4.0];
        assert_eq!(sampler.sample(&logits, &[]).unwrap(), 1);
    }

    #[test]
    fn top_p_limits_tail() {
        // softmax([4, 2, 1]) ≈ [0.844, 0.114, 0.042]. The first token alone
        // already exceeds 0.55, so the nucleus is {0}.
        let cfg = SamplingConfig {
            top_p: Some(0.55),
            seed: 42,
            ..SamplingConfig::default()
        };
        let mut sampler = Sampler::new(cfg).unwrap();
        let logits = vec![4.0, 2.0, 1.0];
        assert_eq!(sampler.sample(&logits, &[]).unwrap(), 0);
    }

    #[test]
    fn seeded_rng_is_deterministic() {
        let cfg = SamplingConfig {
            top_k: Some(3),
            top_p: Some(0.95),
            temperature: 0.9,
            seed: 12345,
            ..SamplingConfig::default()
        };
        let mut a = Sampler::new(cfg.clone()).unwrap();
        let mut b = Sampler::new(cfg).unwrap();

        let logits = vec![1.0, 1.1, 1.2, 1.3];
        let mut seq_a = Vec::new();
        let mut seq_b = Vec::new();
        for _ in 0..20 {
            seq_a.push(a.sample(&logits, &seq_a).unwrap());
            seq_b.push(b.sample(&logits, &seq_b).unwrap());
        }

        assert_eq!(seq_a, seq_b);
    }

    #[test]
    fn repetition_penalty_demotes_repeated_token() {
        // Token 1 is in the history, so its logit drops from 1.0 to 0.5 and
        // token 0 (0.9) becomes the arg-max.
        let cfg = SamplingConfig {
            strategy: SamplingStrategy::Greedy,
            repetition_penalty: Some(2.0),
            ..SamplingConfig::default()
        };
        let mut sampler = Sampler::new(cfg).unwrap();

        let logits = vec![0.9, 1.0];
        let history = vec![1];
        assert_eq!(sampler.sample(&logits, &history).unwrap(), 0);
    }

    #[test]
    fn repetition_penalty_scales_by_sign() {
        // Positive logits are divided and negative ones multiplied; token 0
        // appears twice but is penalized once.
        let mut logits = vec![2.0, -1.0, 3.0];
        apply_repetition_penalty(&mut logits, &[0, 1, 0], 2.0).unwrap();
        assert_eq!(logits, vec![1.0, -2.0, 3.0]);
    }

    #[test]
    fn empty_logits_are_rejected() {
        let mut sampler = Sampler::new(SamplingConfig::default()).unwrap();
        assert_eq!(sampler.sample(&[], &[]), Err(SamplingError::EmptyLogits));
        assert_eq!(greedy_sample(&[]), Err(SamplingError::EmptyLogits));
    }

    #[test]
    fn out_of_range_history_is_rejected() {
        let cfg = SamplingConfig {
            repetition_penalty: Some(1.5),
            ..SamplingConfig::default()
        };
        let mut sampler = Sampler::new(cfg).unwrap();
        let logits = vec![1.0, 2.0];
        assert_eq!(
            sampler.sample(&logits, &[2]),
            Err(SamplingError::InvalidHistoryToken(2))
        );
        assert_eq!(
            sampler.sample(&logits, &[-1]),
            Err(SamplingError::InvalidHistoryToken(-1))
        );
    }

    #[test]
    fn invalid_config_is_rejected() {
        let cfg = SamplingConfig {
            temperature: 0.0,
            ..SamplingConfig::default()
        };
        assert!(matches!(
            Sampler::new(cfg),
            Err(SamplingError::InvalidTemperature(0.0))
        ));
    }

    #[test]
    fn out_of_range_top_p_and_penalty_are_rejected() {
        let cfg = SamplingConfig {
            top_p: Some(1.5),
            ..SamplingConfig::default()
        };
        assert!(matches!(
            Sampler::new(cfg),
            Err(SamplingError::InvalidTopP(p)) if p == 1.5
        ));

        let cfg = SamplingConfig {
            repetition_penalty: Some(0.5),
            ..SamplingConfig::default()
        };
        assert!(matches!(
            Sampler::new(cfg),
            Err(SamplingError::InvalidRepetitionPenalty(p)) if p == 0.5
        ));
    }
}