/// Errors that can be reported while sampling a token index from a logit vector.
///
/// The enum carries no data, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamplingError {
    /// The logits slice cannot be sampled from (for example, it is empty).
    InvalidLogits,
    /// The configured temperature is not strictly positive.
    InvalidTemperature,
    /// No token with non-zero probability was left to choose from.
    NoValidTokens,
}
20
21impl std::fmt::Display for SamplingError {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 SamplingError::InvalidLogits => write!(f, "Invalid logits array"),
25 SamplingError::InvalidTemperature => write!(f, "Temperature must be > 0"),
26 SamplingError::NoValidTokens => write!(f, "No valid tokens after filtering"),
27 }
28 }
29}
30
31impl std::error::Error for SamplingError {}
32
33pub type SamplingResult<T> = std::result::Result<T, SamplingError>;
34
/// Small deterministic pseudo-random generator whose whole state is one 64-bit word.
///
/// Cloning copies the state, so a clone continues with the same stream as the original.
#[derive(Clone, Debug)]
pub struct SeededRng {
    /// Current generator state.
    state: u64,
}
42
43impl SeededRng {
44 pub fn new(seed: u64) -> Self {
45 Self {
47 state: if seed == 0 { 1 } else { seed },
48 }
49 }
50
51 pub fn next_f32(&mut self) -> f32 {
53 self.state ^= self.state << 13;
55 self.state ^= self.state >> 7;
56 self.state ^= self.state << 17;
57 (self.state >> 40) as f32 / (1u64 << 24) as f32
58 }
59}
60
61#[derive(Debug, Clone)]
63pub struct Sampler {
64 pub temperature: f32,
66
67 pub top_k: Option<usize>,
69
70 pub top_p: Option<f32>,
72
73 pub repetition_penalty: Option<f32>,
75
76 rng: SeededRng,
78}
79
80impl Sampler {
81 pub fn new() -> Self {
83 Self {
84 temperature: 1.0,
85 top_k: None,
86 top_p: None,
87 repetition_penalty: None,
88 rng: SeededRng::new(42),
89 }
90 }
91
92 pub fn with_temperature(mut self, temp: f32) -> Self {
93 self.temperature = temp;
94 self
95 }
96
97 pub fn with_top_k(mut self, k: usize) -> Self {
98 self.top_k = Some(k);
99 self
100 }
101
102 pub fn with_top_p(mut self, p: f32) -> Self {
103 self.top_p = Some(p);
104 self
105 }
106
107 pub fn with_repetition_penalty(mut self, penalty: f32) -> Self {
108 self.repetition_penalty = Some(penalty);
109 self
110 }
111
112 pub fn with_seed(mut self, seed: u64) -> Self {
113 self.rng = SeededRng::new(seed);
114 self
115 }
116
117 pub fn sample(&mut self, logits: &[f32]) -> SamplingResult<usize> {
119 self.sample_inner(logits, &[])
120 }
121
122 pub fn sample_with_history(
124 &mut self,
125 logits: &[f32],
126 history: &[usize],
127 ) -> SamplingResult<usize> {
128 self.sample_inner(logits, history)
129 }
130
131 fn sample_inner(&mut self, logits: &[f32], history: &[usize]) -> SamplingResult<usize> {
132 if logits.is_empty() {
133 return Err(SamplingError::InvalidLogits);
134 }
135
136 if self.temperature <= 0.0 {
137 return Err(SamplingError::InvalidTemperature);
138 }
139
140 let mut work_logits = logits.to_vec();
141
142 if let Some(penalty) = self.repetition_penalty {
146 for &token_id in history {
147 if token_id < work_logits.len() {
148 if work_logits[token_id] > 0.0 {
149 work_logits[token_id] /= penalty;
150 } else {
151 work_logits[token_id] *= penalty;
152 }
153 }
154 }
155 }
156
157 if (self.temperature - 1.0).abs() > 1e-6 {
159 for logit in &mut work_logits {
160 *logit /= self.temperature;
161 }
162 }
163
164 if let Some(k) = self.top_k {
166 Self::apply_top_k(&mut work_logits, k);
167 }
168
169 let probs = Self::softmax(&work_logits);
171
172 if self.temperature < 1e-3 {
174 return Ok(Self::argmax(&probs));
175 }
176
177 let probs = if let Some(p) = self.top_p {
179 Self::apply_top_p(&probs, p)
180 } else {
181 probs
182 };
183
184 self.sample_from_distribution(&probs)
186 }
187
188 fn apply_top_k(logits: &mut [f32], k: usize) {
189 if k == 0 || k >= logits.len() {
190 return;
191 }
192
193 let mut indexed: Vec<(usize, f32)> =
194 logits.iter().enumerate().map(|(i, &l)| (i, l)).collect();
195 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
196
197 let threshold = indexed[k - 1].1;
198 for logit in logits.iter_mut() {
199 if *logit < threshold {
200 *logit = f32::NEG_INFINITY;
201 }
202 }
203 }
204
205 fn apply_top_p(probs: &[f32], p: f32) -> Vec<f32> {
206 let mut indexed: Vec<(usize, f32)> =
207 probs.iter().enumerate().map(|(i, &pr)| (i, pr)).collect();
208 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
209
210 let mut cumsum = 0.0;
211 let mut cutoff_idx = 0;
212 for (idx, (_, prob)) in indexed.iter().enumerate() {
213 cumsum += prob;
214 cutoff_idx = idx;
215 if cumsum >= p {
216 break;
217 }
218 }
219
220 let cutoff_prob = indexed[cutoff_idx].1;
221 let mut result = vec![0.0; probs.len()];
222 for (i, &pr) in probs.iter().enumerate() {
223 if pr >= cutoff_prob {
224 result[i] = pr;
225 }
226 }
227
228 let sum: f32 = result.iter().sum();
230 if sum > 0.0 {
231 for p in &mut result {
232 *p /= sum;
233 }
234 }
235
236 result
237 }
238
239 fn softmax(logits: &[f32]) -> Vec<f32> {
240 let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
241 let exps: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
242 let sum: f32 = exps.iter().sum();
243
244 if sum > 0.0 {
245 exps.iter().map(|&e| e / sum).collect()
246 } else {
247 vec![1.0 / logits.len() as f32; logits.len()]
248 }
249 }
250
251 fn argmax(probs: &[f32]) -> usize {
252 probs
253 .iter()
254 .enumerate()
255 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
256 .map(|(idx, _)| idx)
257 .unwrap_or(0)
258 }
259
260 fn sample_from_distribution(&mut self, probs: &[f32]) -> SamplingResult<usize> {
261 let r = self.rng.next_f32();
262 let mut cumsum = 0.0;
263
264 for (i, &prob) in probs.iter().enumerate() {
265 cumsum += prob;
266 if r < cumsum {
267 return Ok(i);
268 }
269 }
270
271 for (i, &prob) in probs.iter().enumerate().rev() {
273 if prob > 0.0 {
274 return Ok(i);
275 }
276 }
277
278 Err(SamplingError::NoValidTokens)
279 }
280}
281
282impl Default for Sampler {
283 fn default() -> Self {
284 Self::new()
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Two generators built from the same seed emit the same in-range stream.
    #[test]
    fn seeded_rng_reproducible() {
        let mut first = SeededRng::new(42);
        let mut second = SeededRng::new(42);

        for _ in 0..100 {
            let a = first.next_f32();
            let b = second.next_f32();
            assert!((a - b).abs() < 1e-6);
            assert!((0.0..1.0).contains(&a));
        }
    }

    /// A near-zero temperature always selects the highest logit.
    #[test]
    fn greedy_sampling() {
        let logits: [f32; 4] = [1.0, 10.0, 2.0, 0.5];
        let mut sampler = Sampler::new().with_temperature(0.0001);

        let chosen = sampler.sample(&logits).unwrap();
        assert_eq!(chosen, 1);
    }

    /// Equal logits yield equal probabilities that sum to one.
    #[test]
    fn softmax_uniform() {
        let probs = Sampler::softmax(&[1.0, 1.0, 1.0]);
        let total: f32 = probs.iter().sum();

        assert_eq!(probs.len(), 3);
        assert!((probs[0] - 1.0 / 3.0).abs() < 1e-5);
        assert!((total - 1.0).abs() < 1e-5);
    }

    /// Dividing by a high temperature flattens the distribution; a low one sharpens it.
    #[test]
    fn temperature_effect() {
        fn peak(probs: &[f32]) -> f32 {
            probs.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
        }

        let logits: [f32; 3] = [1.0, 2.0, 0.5];
        let flattened: Vec<f32> = logits.iter().map(|l| l / 10.0).collect();
        let sharpened: Vec<f32> = logits.iter().map(|l| l / 0.1).collect();

        let max_high = peak(&Sampler::softmax(&flattened));
        let max_low = peak(&Sampler::softmax(&sharpened));

        assert!(max_high < max_low);
    }

    /// With `k = 2` only the two largest logits stay finite.
    #[test]
    fn top_k_filtering() {
        let mut logits: [f32; 5] = [1.0, 10.0, 2.0, 0.5, 3.0];
        Sampler::apply_top_k(&mut logits, 2);

        assert!(logits[1].is_finite());
        assert!(logits[4].is_finite());
        assert!(!logits[0].is_finite());
    }

    /// With `p = 0.8` the two most probable entries survive and the tail is zeroed.
    #[test]
    fn top_p_filtering() {
        let filtered = Sampler::apply_top_p(&[0.5, 0.3, 0.15, 0.05], 0.8);

        assert!(filtered[0] > 0.0);
        assert!(filtered[1] > 0.0);
        assert_eq!(filtered[2], 0.0);
        assert_eq!(filtered[3], 0.0);
    }

    /// Halving a positive logit lowers its softmax probability, and sampling with a
    /// penalty and history succeeds.
    #[test]
    fn repetition_penalty_reduces_likelihood() {
        let logits: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let history = [3_usize];
        let baseline = Sampler::softmax(&logits);

        let mut penalized = logits;
        penalized[3] /= 2.0;
        let damped = Sampler::softmax(&penalized);

        assert!(damped[3] < baseline[3]);

        let mut sampler = Sampler::new().with_repetition_penalty(2.0);
        let outcome = sampler.sample_with_history(&logits, &history);
        assert!(outcome.is_ok());
    }

    /// Penalizing tokens whose logits are negative still produces a valid sample.
    #[test]
    fn repetition_penalty_handles_negative_logits() {
        let logits: [f32; 3] = [-1.0, -2.0, 3.0];
        let history = [0_usize, 1];
        let mut sampler = Sampler::new()
            .with_repetition_penalty(2.0)
            .with_seed(42);

        let outcome = sampler.sample_with_history(&logits, &history);
        assert!(outcome.is_ok());
    }

    /// Samplers sharing a seed make identical choices call after call.
    #[test]
    fn deterministic_across_calls() {
        let logits: [f32; 4] = [0.1, 0.2, 0.3, 0.4];

        let mut left = Sampler::new().with_seed(42);
        let mut right = Sampler::new().with_seed(42);

        for _ in 0..10 {
            let from_left = left.sample(&logits).unwrap();
            let from_right = right.sample(&logits).unwrap();
            assert_eq!(from_left, from_right);
        }
    }

    /// Repeated draws from a flat distribution must not all return the same token.
    #[test]
    fn rng_advances_between_calls() {
        let logits: [f32; 4] = [0.25, 0.25, 0.25, 0.25];
        let mut sampler = Sampler::new().with_seed(42);

        let mut seen = std::collections::HashSet::new();
        for _ in 0..100 {
            let token = sampler.sample(&logits).unwrap();
            seen.insert(token);
        }
        assert!(seen.len() > 1, "RNG should produce varied results");
    }

    /// Temperature, top-k and top-p can be combined and still return a valid index.
    #[test]
    fn combined_sampling() {
        let logits: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 0.5, 0.1];
        let mut sampler = Sampler::new()
            .with_temperature(0.8)
            .with_top_k(3)
            .with_top_p(0.9)
            .with_seed(42);

        let token = sampler.sample(&logits).unwrap();
        assert!(token < logits.len());
    }

    /// A temperature of zero is rejected.
    #[test]
    fn invalid_temperature() {
        let logits: [f32; 2] = [1.0, 2.0];
        let mut sampler = Sampler::new().with_temperature(0.0);

        let outcome = sampler.sample(&logits);
        assert_eq!(outcome, Err(SamplingError::InvalidTemperature));
    }

    /// An empty logits slice is rejected.
    #[test]
    fn empty_logits() {
        let mut sampler = Sampler::new();
        let outcome = sampler.sample(&[]);
        assert_eq!(outcome, Err(SamplingError::InvalidLogits));
    }
}