Skip to main content

forgellm_runtime/
sampling.rs

//! Token sampling strategies for autoregressive generation.
//!
//! Supports greedy, top-k, top-p (nucleus), and temperature-scaled sampling,
//! plus a repetition-penalty helper that adjusts logits before sampling.

/// Sampling configuration consumed by [`sample`].
#[derive(Debug, Clone, PartialEq)]
pub struct SamplingConfig {
    /// Temperature for scaling logits (1.0 = no change, <1 = sharper, >1 = flatter).
    ///
    /// A value of `0.0` makes [`sample`] skip sampling and return the argmax.
    pub temperature: f32,
    /// Top-k: only consider the k highest-scoring tokens (0 = disabled, 1 = greedy).
    pub top_k: usize,
    /// Top-p (nucleus): keep the smallest set of highest-probability tokens whose
    /// cumulative probability reaches `p`, then renormalize (1.0 = disabled).
    pub top_p: f32,
    /// Repetition penalty (1.0 = no penalty).
    ///
    /// Not read by [`sample`]; apply it to the logits with
    /// [`apply_repetition_penalty`] before sampling.
    pub repetition_penalty: f32,
}

impl Default for SamplingConfig {
    /// Plain temperature-1.0 sampling over the full vocabulary with no penalty.
    fn default() -> Self {
        Self {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            repetition_penalty: 1.0,
        }
    }
}

impl SamplingConfig {
    /// Greedy sampling (always pick the highest-scoring token).
    ///
    /// Same as [`SamplingConfig::default`] except `temperature = 0.0` and
    /// `top_k = 1`; either setting alone already makes [`sample`] use argmax.
    pub fn greedy() -> Self {
        Self {
            temperature: 0.0,
            top_k: 1,
            ..Self::default()
        }
    }
}
40
41/// Sample a token ID from logits using the given config.
42///
43/// `logits` is a slice of length `vocab_size` with raw (unnormalized) scores.
44/// Returns the selected token ID.
45pub fn sample(logits: &[f32], config: &SamplingConfig, rng_seed: u64) -> u32 {
46    let mut scores: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
47
48    // Greedy: just return argmax
49    if config.temperature == 0.0 || config.top_k == 1 {
50        return argmax(logits) as u32;
51    }
52
53    // Apply temperature
54    if config.temperature != 1.0 {
55        let inv_temp = 1.0 / config.temperature;
56        for (_, score) in &mut scores {
57            *score *= inv_temp;
58        }
59    }
60
61    // Sort by score descending
62    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
63
64    // Top-k filtering
65    if config.top_k > 0 && config.top_k < scores.len() {
66        scores.truncate(config.top_k);
67    }
68
69    // Softmax over remaining candidates
70    let max_score = scores[0].1;
71    let mut sum = 0.0f32;
72    for (_, score) in &mut scores {
73        *score = (*score - max_score).exp();
74        sum += *score;
75    }
76    for (_, score) in &mut scores {
77        *score /= sum;
78    }
79
80    // Top-p (nucleus) filtering
81    if config.top_p < 1.0 {
82        let mut cumulative = 0.0f32;
83        let mut cutoff = scores.len();
84        for (i, (_, prob)) in scores.iter().enumerate() {
85            cumulative += prob;
86            if cumulative >= config.top_p {
87                cutoff = i + 1;
88                break;
89            }
90        }
91        scores.truncate(cutoff);
92
93        // Renormalize
94        let sum: f32 = scores.iter().map(|(_, p)| p).sum();
95        for (_, prob) in &mut scores {
96            *prob /= sum;
97        }
98    }
99
100    // Sample from the distribution using a simple PRNG
101    let r = simple_rng(rng_seed);
102    let mut cumulative = 0.0f32;
103    for (token_id, prob) in &scores {
104        cumulative += prob;
105        if r < cumulative {
106            return *token_id as u32;
107        }
108    }
109
110    // Fallback: return the last candidate
111    scores.last().map(|(id, _)| *id as u32).unwrap_or(0)
112}
113
/// Greedy sampling: return the index of the highest logit.
///
/// NaN entries are skipped, so a stray NaN cannot displace the true maximum.
/// Ties resolve to the *last* maximal index (the documented behavior of
/// [`Iterator::max_by`]). Returns 0 when `logits` is empty or contains only NaN.
pub fn argmax(logits: &[f32]) -> usize {
    logits
        .iter()
        .enumerate()
        // Without this filter, `partial_cmp` returning `None` is mapped to
        // `Equal`, and the result would depend on where the NaN sits.
        .filter(|(_, value)| !value.is_nan())
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i)
        .unwrap_or(0)
}
123
/// Apply a CTRL-style repetition penalty to logits for previously generated tokens.
///
/// Each *distinct* token id in `generated_tokens` is penalized exactly once,
/// however many times it occurs. Positive logits are divided by `penalty` and
/// non-positive logits are multiplied by it, so with `penalty > 1.0` every
/// non-zero penalized logit decreases. Token ids outside `0..logits.len()` are
/// ignored. A `penalty` of exactly `1.0` is a no-op.
pub fn apply_repetition_penalty(logits: &mut [f32], generated_tokens: &[u32], penalty: f32) {
    if penalty == 1.0 {
        return;
    }
    // Applying the penalty once per occurrence would compound it (penalty^count)
    // and crush frequently repeated tokens; remember which ids were handled.
    let mut seen = std::collections::HashSet::with_capacity(generated_tokens.len());
    for &token in generated_tokens {
        if !seen.insert(token) {
            continue;
        }
        if let Some(logit) = logits.get_mut(token as usize) {
            if *logit > 0.0 {
                *logit /= penalty;
            } else {
                *logit *= penalty;
            }
        }
    }
}
140
/// Deterministic one-shot PRNG used for reproducible sampling draws.
///
/// Runs a single xorshift64 step (shift triple 13 / 7 / 17) over `seed` and
/// maps the low 24 bits of the result to an `f32` in `[0, 1)`. Twenty-four
/// bits is exactly what an `f32` significand holds, so the mapping is exact.
/// A seed of 0 always yields `0.0`, because zero is a fixed point of xorshift.
fn simple_rng(seed: u64) -> f32 {
    // (shift amount, shift left?) for each xor-shift stage, applied in order.
    const STEPS: [(u32, bool); 3] = [(13, true), (7, false), (17, true)];
    const LOW_24_BITS: u64 = (1 << 24) - 1;

    let mixed = STEPS.iter().fold(seed, |state, &(shift, to_left)| {
        let shifted = if to_left { state << shift } else { state >> shift };
        state ^ shifted
    });
    let fraction = (mixed & LOW_24_BITS) as f32;
    fraction / (LOW_24_BITS + 1) as f32
}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn greedy_sampling() {
        let logits = vec![0.1, 0.5, 0.3, 0.9, 0.2];
        let config = SamplingConfig::greedy();
        let token = sample(&logits, &config, 42);
        assert_eq!(token, 3); // index of max value (0.9)
    }

    #[test]
    fn argmax_basic() {
        assert_eq!(argmax(&[1.0, 3.0, 2.0]), 1);
        assert_eq!(argmax(&[5.0, 1.0, 2.0]), 0);
        assert_eq!(argmax(&[-1.0, -2.0, -0.5]), 2);
    }

    #[test]
    fn temperature_zero_is_greedy() {
        let logits = vec![0.1, 0.9, 0.5];
        let config = SamplingConfig {
            temperature: 0.0,
            ..Default::default()
        };
        let token = sample(&logits, &config, 123);
        assert_eq!(token, 1);
    }

    #[test]
    fn top_k_limits_candidates() {
        // With top_k=2, only the top 2 logits should be considered
        let logits = vec![0.1, 0.9, 0.8, 0.05, 0.01];
        let config = SamplingConfig {
            temperature: 1.0,
            top_k: 2,
            top_p: 1.0,
            repetition_penalty: 1.0,
        };

        // Run many samples — should only ever pick index 1 or 2
        for seed in 0..100 {
            let token = sample(&logits, &config, seed);
            assert!(
                token == 1 || token == 2,
                "top_k=2 sampled token {token}, expected 1 or 2"
            );
        }
    }

    #[test]
    fn top_p_nucleus_sampling() {
        // Token 0 holds ~99.98% of the probability mass, so with top_p=0.5 the
        // nucleus contains only token 0 and every seed must select it.
        let logits = vec![10.0, 1.0, 0.1, 0.01];
        let config = SamplingConfig {
            temperature: 1.0,
            top_k: 0,
            top_p: 0.5,
            repetition_penalty: 1.0,
        };

        for seed in 0..100 {
            let token = sample(&logits, &config, seed);
            assert_eq!(
                token, 0,
                "nucleus sampling should pick dominant token (seed {seed})"
            );
        }
    }

    #[test]
    fn repetition_penalty() {
        let mut logits = vec![0.5, 0.9, 0.3];
        apply_repetition_penalty(&mut logits, &[1], 2.0);

        // Token 1 (positive logit 0.9) should be divided by 2.0
        assert!((logits[1] - 0.45).abs() < 1e-6);
        // Other tokens unchanged
        assert!((logits[0] - 0.5).abs() < 1e-6);
        assert!((logits[2] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn repetition_penalty_negative_logits() {
        let mut logits = vec![-0.5, 0.9, -0.3];
        apply_repetition_penalty(&mut logits, &[0, 2], 2.0);

        // Negative logits should be multiplied by penalty (making them more negative)
        assert!((logits[0] - (-1.0)).abs() < 1e-6);
        assert!((logits[2] - (-0.6)).abs() < 1e-6);
    }

    #[test]
    fn repetition_penalty_of_one_is_noop() {
        let mut logits = vec![0.5, -0.9, 0.3];
        apply_repetition_penalty(&mut logits, &[0, 1, 2], 1.0);
        assert_eq!(logits, vec![0.5, -0.9, 0.3]);
    }

    #[test]
    fn repetition_penalty_ignores_out_of_range_tokens() {
        let mut logits = vec![0.5, -0.9, 0.3];
        apply_repetition_penalty(&mut logits, &[3, 100, u32::MAX], 2.0);
        assert_eq!(logits, vec![0.5, -0.9, 0.3]);
    }

    #[test]
    fn default_config() {
        let config = SamplingConfig::default();
        assert_eq!(config.temperature, 1.0);
        assert_eq!(config.top_k, 0);
        assert_eq!(config.top_p, 1.0);
        assert_eq!(config.repetition_penalty, 1.0);
    }

    #[test]
    fn simple_rng_in_range() {
        for seed in 0..1000 {
            let val = simple_rng(seed);
            assert!(
                (0.0..1.0).contains(&val),
                "rng({seed}) = {val} out of range"
            );
        }
    }

    #[test]
    fn sampling_is_deterministic_for_a_seed() {
        let logits = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let config = SamplingConfig::default();
        for seed in 0..50 {
            assert_eq!(
                sample(&logits, &config, seed),
                sample(&logits, &config, seed),
                "same seed {seed} must give the same token"
            );
        }
    }

    // ── Real-world validation tests ──────────────────────────────────────

    #[test]
    fn temperature_zero_always_picks_argmax() {
        // temp=0 should always select the highest-logit token regardless of seed.
        let logits = vec![0.1, 0.3, 0.9, 0.5, 0.2, 0.8, 0.7, 0.4];
        let config = SamplingConfig {
            temperature: 0.0,
            top_k: 0,
            top_p: 1.0,
            repetition_penalty: 1.0,
        };

        for seed in 0..200 {
            let token = sample(&logits, &config, seed);
            assert_eq!(
                token, 2,
                "temp=0 should always pick argmax (token 2), got {token} at seed {seed}"
            );
        }
    }

    #[test]
    fn high_temperature_distributes_samples() {
        // Very high temperature should flatten the distribution, making it
        // nearly uniform, so many different tokens should be sampled.
        let logits = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let config = SamplingConfig {
            temperature: 100.0,
            top_k: 0,
            top_p: 1.0,
            repetition_penalty: 1.0,
        };

        let mut seen = [false; 5];
        for seed in 0..500 {
            let token = sample(&logits, &config, seed) as usize;
            assert!(token < 5, "token {token} out of range");
            seen[token] = true;
        }

        // With temp=100 each token has roughly 20% probability. The bound is
        // deliberately loose (at least 3 of 5 distinct tokens) so the test
        // does not depend on the exact output sequence of the simple PRNG.
        let seen_count = seen.iter().filter(|&&s| s).count();
        assert!(
            seen_count >= 3,
            "high temperature should sample diverse tokens, but only {seen_count}/5 seen"
        );
    }

    #[test]
    fn repetition_penalty_reduces_repeated_token_probability() {
        // After applying repetition penalty to the dominant token, argmax
        // should shift to a different token.
        let mut logits = vec![0.1, 10.0, 0.2, 9.5];
        assert_eq!(argmax(&logits), 1, "pre-penalty argmax should be token 1");

        // Penalize token 1 heavily
        apply_repetition_penalty(&mut logits, &[1], 20.0);
        assert_ne!(
            argmax(&logits),
            1,
            "after heavy repetition penalty, argmax should shift away from token 1"
        );
        assert_eq!(
            argmax(&logits),
            3,
            "after penalizing token 1, token 3 (9.5) should become argmax"
        );
    }

    #[test]
    fn softmax_all_negative_logits_produces_valid_distribution() {
        // When all logits are very negative, softmax via sample() should
        // still produce a valid token (no NaN, no panic).
        let logits = vec![-100.0, -200.0, -150.0, -300.0];
        let config = SamplingConfig {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            repetition_penalty: 1.0,
        };

        let token = sample(&logits, &config, 42);
        assert!(
            (token as usize) < logits.len(),
            "sampled token {token} should be in valid range"
        );

        // Greedy should pick the least-negative logit
        assert_eq!(
            argmax(&logits),
            0,
            "argmax of all-negative logits should be index 0 (-100.0)"
        );
    }

    #[test]
    fn sample_with_single_token_vocab() {
        // Edge case: vocab_size = 1.  Should always return token 0.
        let logits = vec![0.5];
        let config = SamplingConfig::greedy();
        assert_eq!(sample(&logits, &config, 0), 0);

        let config_temp = SamplingConfig {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            repetition_penalty: 1.0,
        };
        assert_eq!(sample(&logits, &config_temp, 42), 0);
    }
}