Skip to main content

forgellm_runtime/
sampling.rs

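//! Token sampling strategies for autoregressive generation.
//!
//! Supports greedy, top-k, top-p (nucleus), and temperature-scaled sampling, plus a
//! repetition penalty that is applied to the logits separately via `apply_repetition_penalty`.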
4
/// Sampling configuration consumed by [`sample`].
///
/// Build one with [`SamplingConfig::default`] (plain softmax sampling over the
/// full vocabulary), [`SamplingConfig::greedy`], or struct-update syntax, e.g.
/// `SamplingConfig { top_k: 40, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct SamplingConfig {
    /// Temperature for scaling logits (1.0 = no change, <1 = sharper, >1 = flatter).
    ///
    /// `0.0` selects greedy decoding in [`sample`].
    pub temperature: f32,
    /// Top-k: only consider the k highest-scoring tokens (0 = disabled).
    ///
    /// `1` is equivalent to greedy decoding in [`sample`].
    pub top_k: usize,
    /// Top-p (nucleus): keep the smallest set of most-probable tokens whose
    /// cumulative probability reaches `p`. The token that crosses the threshold
    /// is kept (values >= 1.0 disable the filter).
    pub top_p: f32,
    /// Repetition penalty (1.0 = no penalty).
    ///
    /// Not applied by [`sample`]; pass it to [`apply_repetition_penalty`] on the
    /// logits before sampling.
    pub repetition_penalty: f32,
}

impl Default for SamplingConfig {
    /// Neutral settings: temperature 1.0, no top-k/top-p truncation and no
    /// repetition penalty.
    fn default() -> Self {
        Self {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            repetition_penalty: 1.0,
        }
    }
}

impl SamplingConfig {
    /// Greedy sampling (always pick the highest probability token).
    ///
    /// Sets both `temperature = 0.0` and `top_k = 1`; either one alone makes
    /// [`sample`] return the argmax.
    pub fn greedy() -> Self {
        Self {
            temperature: 0.0,
            top_k: 1,
            top_p: 1.0,
            repetition_penalty: 1.0,
        }
    }
}
40
41/// Sample a token ID from logits using the given config.
42///
43/// `logits` is a slice of length `vocab_size` with raw (unnormalized) scores.
44/// Returns the selected token ID.
45pub fn sample(logits: &[f32], config: &SamplingConfig, rng_seed: u64) -> u32 {
46    let mut scores: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
47
48    // Greedy: just return argmax
49    if config.temperature == 0.0 || config.top_k == 1 {
50        return argmax(logits) as u32;
51    }
52
53    // Apply temperature
54    if config.temperature != 1.0 {
55        let inv_temp = 1.0 / config.temperature;
56        for (_, score) in &mut scores {
57            *score *= inv_temp;
58        }
59    }
60
61    // Sort by score descending
62    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
63
64    // Top-k filtering
65    if config.top_k > 0 && config.top_k < scores.len() {
66        scores.truncate(config.top_k);
67    }
68
69    // Softmax over remaining candidates
70    let max_score = scores[0].1;
71    let mut sum = 0.0f32;
72    for (_, score) in &mut scores {
73        *score = (*score - max_score).exp();
74        sum += *score;
75    }
76    for (_, score) in &mut scores {
77        *score /= sum;
78    }
79
80    // Top-p (nucleus) filtering
81    if config.top_p < 1.0 {
82        let mut cumulative = 0.0f32;
83        let mut cutoff = scores.len();
84        for (i, (_, prob)) in scores.iter().enumerate() {
85            cumulative += prob;
86            if cumulative >= config.top_p {
87                cutoff = i + 1;
88                break;
89            }
90        }
91        scores.truncate(cutoff);
92
93        // Renormalize
94        let sum: f32 = scores.iter().map(|(_, p)| p).sum();
95        for (_, prob) in &mut scores {
96            *prob /= sum;
97        }
98    }
99
100    // Sample from the distribution using a simple PRNG
101    let r = simple_rng(rng_seed);
102    let mut cumulative = 0.0f32;
103    for (token_id, prob) in &scores {
104        cumulative += prob;
105        if r < cumulative {
106            return *token_id as u32;
107        }
108    }
109
110    // Fallback: return the last candidate
111    scores.last().map(|(id, _)| *id as u32).unwrap_or(0)
112}
113
/// Greedy sampling: return the index of the highest logit.
///
/// NaN entries are ignored. Among equal maxima the last index wins (the
/// `Iterator::max_by` tie rule). Returns `0` if `logits` is empty or every
/// entry is NaN.
pub fn argmax(logits: &[f32]) -> usize {
    logits
        .iter()
        .enumerate()
        // Under the `Equal` fallback below, a NaN would compare "equal" to
        // everything and, since `max_by` keeps the later element on ties, take
        // over the running maximum. Skip NaNs entirely.
        .filter(|(_, v)| !v.is_nan())
        // With NaNs removed `partial_cmp` always succeeds; the fallback is unreachable.
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i)
        .unwrap_or(0)
}
123
/// Apply repetition penalty to logits for previously generated tokens.
///
/// Each distinct token id in `generated_tokens` is penalized exactly once, no
/// matter how often it occurs. Positive logits are divided by `penalty`; zero
/// and negative logits are multiplied by it. So with `penalty > 1.0` a
/// penalized logit never increases (zero stays zero). Ids outside `logits`
/// are ignored, and `penalty == 1.0` is a no-op.
pub fn apply_repetition_penalty(logits: &mut [f32], generated_tokens: &[u32], penalty: f32) {
    if penalty == 1.0 {
        return;
    }
    // Deduplicate first: a token generated n times must be penalized once, not
    // `penalty^n` times.
    let mut ids: Vec<usize> = generated_tokens
        .iter()
        .map(|&token| token as usize)
        .filter(|&idx| idx < logits.len())
        .collect();
    ids.sort_unstable();
    ids.dedup();

    for idx in ids {
        let logit = &mut logits[idx];
        if *logit > 0.0 {
            *logit /= penalty;
        } else {
            *logit *= penalty;
        }
    }
}
140
/// Simple deterministic PRNG for reproducible sampling.
/// Returns a value in [0, 1).
///
/// `seed` is passed through the SplitMix64 output mixer, and the top 24 bits
/// are scaled into `[0, 1)`. Twenty-four bits fit an `f32` significand exactly,
/// so the division is exact and can never round up to 1.0.
///
/// A bare xorshift64 step is unsuitable here. It maps seed 0 to 0, and it is
/// linear over GF(2), so neighbouring seeds (e.g. per-step counters) give
/// correlated draws.
fn simple_rng(seed: u64) -> f32 {
    // SplitMix64 output function (constants from the reference implementation).
    let mut z = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^= z >> 31;
    (z >> 40) as f32 / (1u64 << 24) as f32
}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing penalized logits.
    const TOL: f32 = 1e-6;

    /// True when `got` lies within [`TOL`] of `want`.
    fn approx(got: f32, want: f32) -> bool {
        (got - want).abs() < TOL
    }

    #[test]
    fn greedy_sampling() {
        // The largest logit (0.9) sits at index 3.
        let logits: [f32; 5] = [0.1, 0.5, 0.3, 0.9, 0.2];
        assert_eq!(sample(&logits, &SamplingConfig::greedy(), 42), 3);
    }

    #[test]
    fn argmax_basic() {
        let cases = [
            (vec![1.0f32, 3.0, 2.0], 1usize),
            (vec![5.0, 1.0, 2.0], 0),
            (vec![-1.0, -2.0, -0.5], 2),
        ];
        for (logits, want) in &cases {
            assert_eq!(argmax(logits), *want);
        }
    }

    #[test]
    fn temperature_zero_is_greedy() {
        let config = SamplingConfig {
            temperature: 0.0,
            ..SamplingConfig::default()
        };
        assert_eq!(sample(&[0.1, 0.9, 0.5], &config, 123), 1);
    }

    #[test]
    fn top_k_limits_candidates() {
        // A top-2 cut keeps only index 1 (0.9) and index 2 (0.8).
        let logits: [f32; 5] = [0.1, 0.9, 0.8, 0.05, 0.01];
        let config = SamplingConfig {
            top_k: 2,
            ..SamplingConfig::default()
        };

        (0..100).for_each(|seed| {
            let token = sample(&logits, &config, seed);
            assert!(
                matches!(token, 1 | 2),
                "top_k=2 sampled token {token}, expected 1 or 2"
            );
        });
    }

    #[test]
    fn top_p_nucleus_sampling() {
        // Softmax puts ~0.9998 of the mass on token 0, so a 0.5 nucleus holds only it.
        let config = SamplingConfig {
            top_p: 0.5,
            ..SamplingConfig::default()
        };
        let token = sample(&[10.0, 1.0, 0.1, 0.01], &config, 42);
        assert_eq!(token, 0, "nucleus sampling should pick dominant token");
    }

    #[test]
    fn repetition_penalty() {
        let mut logits = vec![0.5, 0.9, 0.3];
        apply_repetition_penalty(&mut logits, &[1], 2.0);

        // Only token 1 was generated, so only its (positive) logit is halved.
        let expected = [0.5, 0.45, 0.3];
        for (got, want) in logits.iter().zip(expected.iter()) {
            assert!(approx(*got, *want));
        }
    }

    #[test]
    fn repetition_penalty_negative_logits() {
        let mut logits = vec![-0.5, 0.9, -0.3];
        apply_repetition_penalty(&mut logits, &[0, 2], 2.0);

        // Negative logits are scaled up by the penalty, pushing them further down.
        assert!(approx(logits[0], -1.0));
        assert!(approx(logits[2], -0.6));
    }

    #[test]
    fn default_config() {
        let SamplingConfig {
            temperature,
            top_k,
            top_p,
            repetition_penalty,
        } = SamplingConfig::default();
        assert_eq!(temperature, 1.0);
        assert_eq!(top_k, 0);
        assert_eq!(top_p, 1.0);
        assert_eq!(repetition_penalty, 1.0);
    }

    #[test]
    fn simple_rng_in_range() {
        (0..1000)
            .map(|seed| (seed, simple_rng(seed)))
            .for_each(|(seed, val)| {
                assert!(
                    (0.0..1.0).contains(&val),
                    "rng({seed}) = {val} out of range"
                );
            });
    }
}