1pub mod grammar;
7
8use rand::prelude::*;
9use rand::rngs::StdRng;
10
11use crate::tensor::Tensor;
12
13pub use grammar::{GbnfGrammar, Grammar, GrammarSampler, JsonGrammar, RegexGrammar};
14
/// Settings for Mirostat adaptive sampling, which steers generation toward a
/// target surprise value (in bits) instead of a fixed truncation rule.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MirostatConfig {
    /// Target surprise `tau`, in bits per token.
    pub tau: f32,
    /// Learning rate `eta` applied to the surprise error when updating the
    /// running estimate `mu`.
    pub eta: f32,
    /// Algorithm variant: `1` selects Mirostat v1, any other value selects v2.
    pub version: u8,
}

impl Default for MirostatConfig {
    /// Mirostat v2 with `tau = 5.0` and `eta = 0.1`.
    fn default() -> Self {
        Self {
            tau: 5.0,
            eta: 0.1,
            version: 2,
        }
    }
}

/// Tunable parameters of the token sampler.
#[derive(Debug, Clone, PartialEq)]
pub struct SamplerConfig {
    /// Softmax temperature; logits are divided by it. `0.0` selects greedy
    /// (arg-max) decoding.
    pub temperature: f32,
    /// Keep only the `top_k` most probable tokens; `0` disables the filter and
    /// `1` forces greedy decoding.
    pub top_k: usize,
    /// Nucleus sampling: keep the smallest prefix of tokens whose cumulative
    /// probability exceeds `top_p`; `1.0` disables the filter.
    pub top_p: f32,
    /// Drop tokens whose probability is below `min_p` times that of the most
    /// likely token; `0.0` disables the filter.
    pub min_p: f32,
    /// Probability mass for locally typical sampling; values `>= 1.0` leave the
    /// distribution untouched.
    pub typical_p: f32,
    /// Penalty for tokens seen in the recent window: positive logits are
    /// divided by it, non-positive logits multiplied. `1.0` disables it.
    pub repeat_penalty: f32,
    /// How many of the most recent tokens the repetition penalty inspects;
    /// `0` means the whole history passed to the sampler.
    pub repeat_window: usize,
    /// Subtracted from a token's logit once per previous occurrence.
    pub frequency_penalty: f32,
    /// Subtracted once from the logit of any token that already occurred.
    pub presence_penalty: f32,
    /// RNG seed for reproducible sampling; `None` seeds from OS entropy.
    pub seed: Option<u64>,
    /// When set, Mirostat replaces the temperature / top-k / top-p / min-p
    /// pipeline.
    pub mirostat: Option<MirostatConfig>,
}

impl Default for SamplerConfig {
    fn default() -> Self {
        Self {
            temperature: 0.8,
            top_k: 40,
            top_p: 0.95,
            min_p: 0.0,
            typical_p: 1.0,
            repeat_penalty: 1.1,
            repeat_window: 64,
            frequency_penalty: 0.0,
            presence_penalty: 0.0,
            seed: None,
            mirostat: None,
        }
    }
}

impl SamplerConfig {
    /// Base for the presets: temperature 1.0 and every filter and penalty
    /// switched off.
    fn neutral() -> Self {
        Self {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            min_p: 0.0,
            typical_p: 1.0,
            repeat_penalty: 1.0,
            repeat_window: 0,
            frequency_penalty: 0.0,
            presence_penalty: 0.0,
            seed: None,
            mirostat: None,
        }
    }

    /// Deterministic arg-max decoding with no penalties.
    pub fn greedy() -> Self {
        Self {
            temperature: 0.0,
            top_k: 1,
            ..Self::neutral()
        }
    }

    /// Higher-diversity preset: temperature 1.0, nucleus 0.9, min-p 0.05 and a
    /// stronger repetition penalty over the last 64 tokens.
    pub fn creative() -> Self {
        Self {
            top_p: 0.9,
            min_p: 0.05,
            repeat_penalty: 1.2,
            repeat_window: 64,
            ..Self::neutral()
        }
    }

    /// Mirostat v2 with the given target surprise `tau` and learning rate `eta`;
    /// all other filters and penalties are disabled.
    pub fn mirostat_v2(tau: f32, eta: f32) -> Self {
        Self {
            mirostat: Some(MirostatConfig {
                tau,
                eta,
                version: 2,
            }),
            ..Self::neutral()
        }
    }
}
137
138pub struct Sampler {
140 config: SamplerConfig,
141 rng: StdRng,
142 token_counts: Vec<u32>,
144 mirostat_mu: f32,
146}
147
148impl Sampler {
149 pub fn new(config: SamplerConfig, vocab_size: usize) -> Self {
151 let rng = match config.seed {
152 Some(seed) => StdRng::seed_from_u64(seed),
153 None => StdRng::from_entropy(),
154 };
155
156 let mirostat_mu = config
158 .mirostat
159 .as_ref()
160 .map(|m| m.tau * 2.0)
161 .unwrap_or(10.0);
162
163 Self {
164 config,
165 rng,
166 token_counts: vec![0; vocab_size],
167 mirostat_mu,
168 }
169 }
170
171 pub fn reset(&mut self) {
173 self.token_counts.fill(0);
174 if let Some(ref mirostat) = self.config.mirostat {
176 self.mirostat_mu = mirostat.tau * 2.0;
177 }
178 }
179
180 pub fn sample(&mut self, logits: &Tensor, recent_tokens: &[u32]) -> u32 {
189 let logits_data = logits.as_f32().expect("Logits must be F32");
190 let vocab_size = logits_data.len();
191
192 let mut probs: Vec<f32> = logits_data.to_vec();
199
200 if self.config.repeat_penalty != 1.0 {
202 self.apply_repetition_penalty(&mut probs, recent_tokens);
203 }
204
205 if self.config.frequency_penalty != 0.0 || self.config.presence_penalty != 0.0 {
207 self.apply_frequency_presence_penalty(&mut probs);
208 }
209
210 if let Some(ref mirostat) = self.config.mirostat {
212 return self.sample_mirostat(&mut probs, mirostat.clone());
213 }
214
215 if self.config.temperature > 0.0 && self.config.temperature != 1.0 {
217 let inv_temp = 1.0 / self.config.temperature;
218 for p in &mut probs {
219 *p *= inv_temp;
220 }
221 }
222
223 let max_logit = probs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
225 let mut sum = 0.0f32;
226 for p in &mut probs {
227 *p = (*p - max_logit).exp();
228 sum += *p;
229 }
230 for p in &mut probs {
231 *p /= sum;
232 }
233
234 if self.config.temperature == 0.0 || self.config.top_k == 1 {
236 return probs
237 .iter()
238 .enumerate()
239 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
240 .map(|(i, _)| i as u32)
241 .unwrap_or(0);
242 }
243
244 let mut indices: Vec<usize> = (0..vocab_size).collect();
246 indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap_or(std::cmp::Ordering::Equal));
247
248 if self.config.min_p > 0.0 {
250 let threshold = probs[indices[0]] * self.config.min_p;
251 let cutoff = indices
252 .iter()
253 .position(|&i| probs[i] < threshold)
254 .unwrap_or(vocab_size);
255 if cutoff > 0 {
256 indices.truncate(cutoff);
257 }
258 }
259
260 if self.config.top_k > 0 && self.config.top_k < indices.len() {
262 indices.truncate(self.config.top_k);
263 }
264
265 if self.config.top_p < 1.0 {
267 let mut cumsum = 0.0f32;
268 let cutoff = indices
269 .iter()
270 .position(|&i| {
271 cumsum += probs[i];
272 cumsum > self.config.top_p
273 })
274 .unwrap_or(indices.len());
275 if cutoff > 0 {
276 indices.truncate(cutoff + 1); }
278 }
279
280 let filtered_sum: f32 = indices.iter().map(|&i| probs[i]).sum();
282 for &i in &indices {
283 probs[i] /= filtered_sum;
284 }
285
286 let r: f32 = self.rng.r#gen();
288 let mut cumsum = 0.0f32;
289 for &i in &indices {
290 cumsum += probs[i];
291 if r < cumsum {
292 let token_id = i as u32;
293 self.token_counts[i] += 1;
294 return token_id;
295 }
296 }
297
298 let token_id = *indices.last().unwrap() as u32;
300 self.token_counts[token_id as usize] += 1;
301 token_id
302 }
303
304 fn sample_mirostat(&mut self, logits: &mut [f32], config: MirostatConfig) -> u32 {
308 let vocab_size = logits.len();
309
310 let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
312 let mut sum = 0.0f32;
313 for p in logits.iter_mut() {
314 *p = (*p - max_logit).exp();
315 sum += *p;
316 }
317 for p in logits.iter_mut() {
318 *p /= sum;
319 }
320
321 let mut sorted_indices: Vec<usize> = (0..vocab_size).collect();
323 sorted_indices.sort_by(|&a, &b| logits[b].partial_cmp(&logits[a]).unwrap());
324
325 let token_id = if config.version == 1 {
326 let n = ((2.0f32.powf(self.mirostat_mu) * vocab_size as f32) as usize)
328 .max(1)
329 .min(vocab_size);
330
331 let candidates = &sorted_indices[..n];
333
334 let filtered_sum: f32 = candidates.iter().map(|&i| logits[i]).sum();
336 let r: f32 = self.rng.r#gen::<f32>() * filtered_sum;
337 let mut cumsum = 0.0f32;
338 let mut selected = candidates[0];
339 for &i in candidates {
340 cumsum += logits[i];
341 if cumsum > r {
342 selected = i;
343 break;
344 }
345 }
346 selected
347 } else {
348 let mu = self.mirostat_mu;
351
352 let mut truncation_idx = vocab_size;
353 for (rank, &i) in sorted_indices.iter().enumerate() {
354 let surprise = -logits[i].log2();
355 if surprise > mu {
356 truncation_idx = rank.max(1);
357 break;
358 }
359 }
360
361 let candidates = &sorted_indices[..truncation_idx];
363 let filtered_sum: f32 = candidates.iter().map(|&i| logits[i]).sum();
364 let r: f32 = self.rng.r#gen::<f32>() * filtered_sum;
365 let mut cumsum = 0.0f32;
366 let mut selected = candidates[0];
367 for &i in candidates {
368 cumsum += logits[i];
369 if cumsum > r {
370 selected = i;
371 break;
372 }
373 }
374 selected
375 };
376
377 let selected_prob = logits[token_id];
379 let surprise = -selected_prob.log2();
380 self.mirostat_mu -= config.eta * (surprise - config.tau);
381
382 self.mirostat_mu = self.mirostat_mu.clamp(0.0, 20.0);
384
385 self.token_counts[token_id] += 1;
386 token_id as u32
387 }
388
389 fn apply_repetition_penalty(&self, logits: &mut [f32], recent_tokens: &[u32]) {
391 let window = if self.config.repeat_window > 0 {
392 recent_tokens.len().min(self.config.repeat_window)
393 } else {
394 recent_tokens.len()
395 };
396
397 let start = recent_tokens.len().saturating_sub(window);
398 for &token_id in &recent_tokens[start..] {
399 let idx = token_id as usize;
400 if idx < logits.len() {
401 if logits[idx] > 0.0 {
402 logits[idx] /= self.config.repeat_penalty;
403 } else {
404 logits[idx] *= self.config.repeat_penalty;
405 }
406 }
407 }
408 }
409
410 fn apply_frequency_presence_penalty(&self, logits: &mut [f32]) {
412 for (i, &count) in self.token_counts.iter().enumerate() {
413 if count > 0 {
414 logits[i] -= self.config.frequency_penalty * count as f32;
416 logits[i] -= self.config.presence_penalty;
418 }
419 }
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sampler over a ten-token vocabulary.
    fn ten_token_sampler(config: SamplerConfig) -> Sampler {
        Sampler::new(config, 10)
    }

    #[test]
    fn test_default_config() {
        let SamplerConfig {
            temperature,
            top_k,
            top_p,
            ..
        } = SamplerConfig::default();
        assert_eq!(temperature, 0.8);
        assert_eq!(top_k, 40);
        assert!((top_p - 0.95).abs() < 0.001);
    }

    #[test]
    fn test_greedy_config() {
        let SamplerConfig {
            temperature, top_k, ..
        } = SamplerConfig::greedy();
        assert_eq!(temperature, 0.0);
        assert_eq!(top_k, 1);
    }

    #[test]
    fn test_greedy_sampling() {
        let mut sampler = ten_token_sampler(SamplerConfig::greedy());

        // Position 5 holds the unique maximum logit.
        let values = vec![0.0, 0.1, 0.2, 0.3, 0.4, 1.0, 0.2, 0.1, 0.0, -0.1];
        let logits = Tensor::from_f32(&values, vec![10]).unwrap();

        assert_eq!(sampler.sample(&logits, &[]), 5);
    }

    #[test]
    fn test_sampler_reset() {
        let mut sampler = ten_token_sampler(SamplerConfig::default());

        sampler.token_counts[5] = 10;
        sampler.reset();

        assert_eq!(sampler.token_counts[5], 0);
    }
}