Skip to main content

llama_gguf/sampling/
mod.rs

//! Token sampling strategies for text generation.
//!
//! This module provides the strategies used to select the next token during
//! text generation: temperature scaling, top-k, top-p (nucleus) and min-p
//! filtering, repetition / frequency / presence penalties, and Mirostat.
5
6pub mod grammar;
7
8use rand::prelude::*;
9use rand::rngs::StdRng;
10
11use crate::tensor::Tensor;
12
13pub use grammar::{GbnfGrammar, Grammar, GrammarSampler, JsonGrammar, RegexGrammar};
14
/// Mirostat sampling configuration.
///
/// Mirostat adapts its truncation threshold while generating so that the
/// surprise (`-log2 p`) of the sampled tokens stays close to `tau`.
#[derive(Debug, Clone, PartialEq)]
pub struct MirostatConfig {
    /// Target surprise value (tau) - higher = more random
    pub tau: f32,
    /// Learning rate (eta): step size used when the adaptive threshold is
    /// corrected after each sampled token
    pub eta: f32,
    /// Mirostat version. `1` selects v1; any other value is handled as v2.
    pub version: u8,
}

impl Default for MirostatConfig {
    /// Mirostat v2 with `tau = 5.0` and `eta = 0.1`.
    fn default() -> Self {
        Self {
            tau: 5.0,
            eta: 0.1,
            version: 2,
        }
    }
}
35
36/// Sampler configuration
37#[derive(Debug, Clone)]
38pub struct SamplerConfig {
39    /// Temperature for softmax scaling (higher = more random)
40    pub temperature: f32,
41    /// Top-K: only consider the K most likely tokens
42    pub top_k: usize,
43    /// Top-P (nucleus): only consider tokens with cumulative probability <= p
44    pub top_p: f32,
45    /// Min-P: only consider tokens with probability >= min_p * max_prob
46    pub min_p: f32,
47    /// Typical-P sampling
48    pub typical_p: f32,
49    /// Repetition penalty
50    pub repeat_penalty: f32,
51    /// Window size for repetition penalty
52    pub repeat_window: usize,
53    /// Frequency penalty
54    pub frequency_penalty: f32,
55    /// Presence penalty
56    pub presence_penalty: f32,
57    /// Random seed (None for random)
58    pub seed: Option<u64>,
59    /// Mirostat sampling (overrides other sampling methods if set)
60    pub mirostat: Option<MirostatConfig>,
61}
62
63impl Default for SamplerConfig {
64    fn default() -> Self {
65        Self {
66            temperature: 0.8,
67            top_k: 40,
68            top_p: 0.95,
69            min_p: 0.0,
70            typical_p: 1.0,
71            repeat_penalty: 1.1,
72            repeat_window: 64,
73            frequency_penalty: 0.0,
74            presence_penalty: 0.0,
75            seed: None,
76            mirostat: None,
77        }
78    }
79}
80
81impl SamplerConfig {
82    /// Create a greedy sampling config (always picks most likely token)
83    pub fn greedy() -> Self {
84        Self {
85            temperature: 0.0,
86            top_k: 1,
87            top_p: 1.0,
88            min_p: 0.0,
89            typical_p: 1.0,
90            repeat_penalty: 1.0,
91            repeat_window: 0,
92            frequency_penalty: 0.0,
93            presence_penalty: 0.0,
94            seed: None,
95            mirostat: None,
96        }
97    }
98
99    /// Create a creative sampling config
100    pub fn creative() -> Self {
101        Self {
102            temperature: 1.0,
103            top_k: 0, // Disabled
104            top_p: 0.9,
105            min_p: 0.05,
106            typical_p: 1.0,
107            repeat_penalty: 1.2,
108            repeat_window: 64,
109            frequency_penalty: 0.0,
110            presence_penalty: 0.0,
111            seed: None,
112            mirostat: None,
113        }
114    }
115
116    /// Create a Mirostat v2 sampling config
117    pub fn mirostat_v2(tau: f32, eta: f32) -> Self {
118        Self {
119            temperature: 1.0,
120            top_k: 0,
121            top_p: 1.0,
122            min_p: 0.0,
123            typical_p: 1.0,
124            repeat_penalty: 1.0,
125            repeat_window: 0,
126            frequency_penalty: 0.0,
127            presence_penalty: 0.0,
128            seed: None,
129            mirostat: Some(MirostatConfig {
130                tau,
131                eta,
132                version: 2,
133            }),
134        }
135    }
136}
137
138/// Token sampler for text generation
139pub struct Sampler {
140    config: SamplerConfig,
141    rng: StdRng,
142    /// Token frequency counts for repetition penalty
143    token_counts: Vec<u32>,
144    /// Mirostat mu (adaptive parameter)
145    mirostat_mu: f32,
146}
147
148impl Sampler {
149    /// Create a new sampler
150    pub fn new(config: SamplerConfig, vocab_size: usize) -> Self {
151        let rng = match config.seed {
152            Some(seed) => StdRng::seed_from_u64(seed),
153            None => StdRng::from_entropy(),
154        };
155
156        // Initialize mirostat mu based on tau
157        let mirostat_mu = config
158            .mirostat
159            .as_ref()
160            .map(|m| m.tau * 2.0)
161            .unwrap_or(10.0);
162
163        Self {
164            config,
165            rng,
166            token_counts: vec![0; vocab_size],
167            mirostat_mu,
168        }
169    }
170
171    /// Reset the sampler state
172    pub fn reset(&mut self) {
173        self.token_counts.fill(0);
174        // Reset mirostat mu
175        if let Some(ref mirostat) = self.config.mirostat {
176            self.mirostat_mu = mirostat.tau * 2.0;
177        }
178    }
179
180    /// Sample next token from logits
181    ///
182    /// # Arguments
183    /// * `logits` - Logits tensor [vocab_size]
184    /// * `recent_tokens` - Recently generated tokens for repetition penalty
185    ///
186    /// # Returns
187    /// Selected token ID
188    pub fn sample(&mut self, logits: &Tensor, recent_tokens: &[u32]) -> u32 {
189        let logits_data = logits.as_f32().expect("Logits must be F32");
190        let vocab_size = logits_data.len();
191
192        // Debug: show top 5 tokens before sampling
193        // let mut top5: Vec<(usize, f32)> = logits_data.iter().cloned().enumerate().collect();
194        // top5.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
195        // eprintln!("Top 5 logits: {:?}", &top5[..5.min(top5.len())]);
196
197        // Copy logits to work with
198        let mut probs: Vec<f32> = logits_data.to_vec();
199
200        // Apply repetition penalty
201        if self.config.repeat_penalty != 1.0 {
202            self.apply_repetition_penalty(&mut probs, recent_tokens);
203        }
204
205        // Apply frequency and presence penalties
206        if self.config.frequency_penalty != 0.0 || self.config.presence_penalty != 0.0 {
207            self.apply_frequency_presence_penalty(&mut probs);
208        }
209
210        // Check if Mirostat is enabled
211        if let Some(ref mirostat) = self.config.mirostat {
212            return self.sample_mirostat(&mut probs, mirostat.clone());
213        }
214
215        // Apply temperature
216        if self.config.temperature > 0.0 && self.config.temperature != 1.0 {
217            let inv_temp = 1.0 / self.config.temperature;
218            for p in &mut probs {
219                *p *= inv_temp;
220            }
221        }
222
223        // Convert to probabilities with softmax
224        let max_logit = probs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
225        let mut sum = 0.0f32;
226        for p in &mut probs {
227            *p = (*p - max_logit).exp();
228            sum += *p;
229        }
230        for p in &mut probs {
231            *p /= sum;
232        }
233
234        // Greedy decoding
235        if self.config.temperature == 0.0 || self.config.top_k == 1 {
236            return probs
237                .iter()
238                .enumerate()
239                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
240                .map(|(i, _)| i as u32)
241                .unwrap_or(0);
242        }
243
244        // Create sorted indices by probability
245        let mut indices: Vec<usize> = (0..vocab_size).collect();
246        indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap_or(std::cmp::Ordering::Equal));
247
248        // Apply min-p filtering
249        if self.config.min_p > 0.0 {
250            let threshold = probs[indices[0]] * self.config.min_p;
251            let cutoff = indices
252                .iter()
253                .position(|&i| probs[i] < threshold)
254                .unwrap_or(vocab_size);
255            if cutoff > 0 {
256                indices.truncate(cutoff);
257            }
258        }
259
260        // Apply top-k filtering
261        if self.config.top_k > 0 && self.config.top_k < indices.len() {
262            indices.truncate(self.config.top_k);
263        }
264
265        // Apply top-p (nucleus) filtering
266        if self.config.top_p < 1.0 {
267            let mut cumsum = 0.0f32;
268            let cutoff = indices
269                .iter()
270                .position(|&i| {
271                    cumsum += probs[i];
272                    cumsum > self.config.top_p
273                })
274                .unwrap_or(indices.len());
275            if cutoff > 0 {
276                indices.truncate(cutoff + 1); // Include the token that crossed threshold
277            }
278        }
279
280        // Renormalize probabilities over filtered tokens
281        let filtered_sum: f32 = indices.iter().map(|&i| probs[i]).sum();
282        for &i in &indices {
283            probs[i] /= filtered_sum;
284        }
285
286        // Sample from filtered distribution
287        let r: f32 = self.rng.r#gen();
288        let mut cumsum = 0.0f32;
289        for &i in &indices {
290            cumsum += probs[i];
291            if r < cumsum {
292                let token_id = i as u32;
293                self.token_counts[i] += 1;
294                return token_id;
295            }
296        }
297
298        // Fallback to last token in filtered set
299        let token_id = *indices.last().unwrap() as u32;
300        self.token_counts[token_id as usize] += 1;
301        token_id
302    }
303
304    /// Mirostat sampling algorithm
305    ///
306    /// Mirostat dynamically adjusts the sampling to target a specific "surprise" level.
307    fn sample_mirostat(&mut self, logits: &mut [f32], config: MirostatConfig) -> u32 {
308        let vocab_size = logits.len();
309
310        // Convert logits to probabilities
311        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
312        let mut sum = 0.0f32;
313        for p in logits.iter_mut() {
314            *p = (*p - max_logit).exp();
315            sum += *p;
316        }
317        for p in logits.iter_mut() {
318            *p /= sum;
319        }
320
321        // Sort tokens by probability (descending)
322        let mut sorted_indices: Vec<usize> = (0..vocab_size).collect();
323        sorted_indices.sort_by(|&a, &b| logits[b].partial_cmp(&logits[a]).unwrap());
324
325        let token_id = if config.version == 1 {
326            // Mirostat v1: uses a fixed number of candidates based on mu
327            let n = ((2.0f32.powf(self.mirostat_mu) * vocab_size as f32) as usize)
328                .max(1)
329                .min(vocab_size);
330
331            // Truncate to top n candidates
332            let candidates = &sorted_indices[..n];
333
334            // Renormalize and sample
335            let filtered_sum: f32 = candidates.iter().map(|&i| logits[i]).sum();
336            let r: f32 = self.rng.r#gen::<f32>() * filtered_sum;
337            let mut cumsum = 0.0f32;
338            let mut selected = candidates[0];
339            for &i in candidates {
340                cumsum += logits[i];
341                if cumsum > r {
342                    selected = i;
343                    break;
344                }
345            }
346            selected
347        } else {
348            // Mirostat v2: uses mu to truncate based on surprise
349            // Find the truncation point where -log2(p) > mu
350            let mu = self.mirostat_mu;
351
352            let mut truncation_idx = vocab_size;
353            for (rank, &i) in sorted_indices.iter().enumerate() {
354                let surprise = -logits[i].log2();
355                if surprise > mu {
356                    truncation_idx = rank.max(1);
357                    break;
358                }
359            }
360
361            // Sample from truncated distribution
362            let candidates = &sorted_indices[..truncation_idx];
363            let filtered_sum: f32 = candidates.iter().map(|&i| logits[i]).sum();
364            let r: f32 = self.rng.r#gen::<f32>() * filtered_sum;
365            let mut cumsum = 0.0f32;
366            let mut selected = candidates[0];
367            for &i in candidates {
368                cumsum += logits[i];
369                if cumsum > r {
370                    selected = i;
371                    break;
372                }
373            }
374            selected
375        };
376
377        // Update mu based on the surprise of the selected token
378        let selected_prob = logits[token_id];
379        let surprise = -selected_prob.log2();
380        self.mirostat_mu -= config.eta * (surprise - config.tau);
381
382        // Clamp mu to reasonable bounds
383        self.mirostat_mu = self.mirostat_mu.clamp(0.0, 20.0);
384
385        self.token_counts[token_id] += 1;
386        token_id as u32
387    }
388
389    /// Apply repetition penalty to logits
390    fn apply_repetition_penalty(&self, logits: &mut [f32], recent_tokens: &[u32]) {
391        let window = if self.config.repeat_window > 0 {
392            recent_tokens.len().min(self.config.repeat_window)
393        } else {
394            recent_tokens.len()
395        };
396
397        let start = recent_tokens.len().saturating_sub(window);
398        for &token_id in &recent_tokens[start..] {
399            let idx = token_id as usize;
400            if idx < logits.len() {
401                if logits[idx] > 0.0 {
402                    logits[idx] /= self.config.repeat_penalty;
403                } else {
404                    logits[idx] *= self.config.repeat_penalty;
405                }
406            }
407        }
408    }
409
410    /// Apply frequency and presence penalties
411    fn apply_frequency_presence_penalty(&self, logits: &mut [f32]) {
412        for (i, &count) in self.token_counts.iter().enumerate() {
413            if count > 0 {
414                // Frequency penalty: scales with count
415                logits[i] -= self.config.frequency_penalty * count as f32;
416                // Presence penalty: constant if present
417                logits[i] -= self.config.presence_penalty;
418            }
419        }
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented headline values.
    #[test]
    fn test_default_config() {
        let SamplerConfig {
            temperature,
            top_k,
            top_p,
            ..
        } = SamplerConfig::default();

        assert_eq!(temperature, 0.8);
        assert_eq!(top_k, 40);
        assert!((top_p - 0.95).abs() < 0.001);
    }

    /// The greedy preset uses zero temperature and a single candidate.
    #[test]
    fn test_greedy_config() {
        let SamplerConfig {
            temperature, top_k, ..
        } = SamplerConfig::greedy();

        assert_eq!(temperature, 0.0);
        assert_eq!(top_k, 1);
    }

    /// Greedy sampling returns the index of the largest logit.
    #[test]
    fn test_greedy_sampling() {
        let mut sampler = Sampler::new(SamplerConfig::greedy(), 10);

        // Token 5 carries the largest logit.
        let logits_data = vec![0.0, 0.1, 0.2, 0.3, 0.4, 1.0, 0.2, 0.1, 0.0, -0.1];
        let logits = Tensor::from_f32(&logits_data, vec![10]).unwrap();

        assert_eq!(sampler.sample(&logits, &[]), 5);
    }

    /// `reset` clears previously accumulated token counts.
    #[test]
    fn test_sampler_reset() {
        let mut sampler = Sampler::new(SamplerConfig::default(), 10);
        sampler.token_counts[5] = 10;

        sampler.reset();

        assert_eq!(sampler.token_counts[5], 0);
    }
}
465}