Skip to main content

llama_sampling/
lib.rs

//! # llama-sampling
//!
//! Sampling and decoding strategies for llama.rs.
//! Implements greedy, temperature, top-k, top-p, and repetition-penalty sampling
//! with a deterministic, seeded RNG for reproducible test runs.
6
7use std::borrow::Cow;
8use std::cmp::Ordering;
9
10/// Error type for sampling operations.
11#[derive(Debug, Clone, PartialEq, thiserror::Error)]
12pub enum SamplingError {
13    #[error("logits cannot be empty")]
14    EmptyLogits,
15    #[error("invalid token id in history: {0}")]
16    InvalidHistoryToken(i32),
17    #[error("temperature must be > 0, got {0}")]
18    InvalidTemperature(f32),
19    #[error("top_p must be in (0, 1], got {0}")]
20    InvalidTopP(f32),
21    #[error("repetition_penalty must be >= 1.0, got {0}")]
22    InvalidRepetitionPenalty(f32),
23}
24
/// How the next token is chosen from the logits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SamplingStrategy {
    /// Deterministically take the highest-scoring token (argmax).
    Greedy,
    /// Draw a token at random from the filtered probability distribution.
    Stochastic,
}
33
34/// Configuration for token sampling.
35#[derive(Debug, Clone)]
36pub struct SamplingConfig {
37    pub strategy: SamplingStrategy,
38    pub temperature: f32,
39    pub top_k: Option<usize>,
40    pub top_p: Option<f32>,
41    pub repetition_penalty: Option<f32>,
42    pub seed: u64,
43}
44
45impl Default for SamplingConfig {
46    fn default() -> Self {
47        Self {
48            strategy: SamplingStrategy::Stochastic,
49            temperature: 1.0,
50            top_k: None,
51            top_p: None,
52            repetition_penalty: None,
53            seed: 0,
54        }
55    }
56}
57
58impl SamplingConfig {
59    fn validate(&self) -> Result<(), SamplingError> {
60        if self.temperature <= 0.0 {
61            return Err(SamplingError::InvalidTemperature(self.temperature));
62        }
63        if let Some(top_p) = self.top_p {
64            if !(top_p > 0.0 && top_p <= 1.0) {
65                return Err(SamplingError::InvalidTopP(top_p));
66            }
67        }
68        if let Some(penalty) = self.repetition_penalty {
69            if penalty < 1.0 {
70                return Err(SamplingError::InvalidRepetitionPenalty(penalty));
71            }
72        }
73        Ok(())
74    }
75}
76
77/// Stateful sampler using deterministic RNG.
78pub struct Sampler {
79    cfg: SamplingConfig,
80    rng: XorShift64,
81}
82
83impl Sampler {
84    pub fn new(cfg: SamplingConfig) -> Result<Self, SamplingError> {
85        cfg.validate()?;
86        Ok(Self {
87            rng: XorShift64::seeded(cfg.seed),
88            cfg,
89        })
90    }
91
92    /// Sample a token from logits with optional repetition penalty against `history`.
93    pub fn sample(&mut self, logits: &[f32], history: &[i32]) -> Result<i32, SamplingError> {
94        if logits.is_empty() {
95            return Err(SamplingError::EmptyLogits);
96        }
97
98        let adjusted: Cow<'_, [f32]> = if let Some(penalty) = self.cfg.repetition_penalty {
99            let mut buf = logits.to_vec();
100            apply_repetition_penalty(&mut buf, history, penalty)?;
101            Cow::Owned(buf)
102        } else {
103            Cow::Borrowed(logits)
104        };
105
106        if self.cfg.strategy == SamplingStrategy::Greedy {
107            return greedy_sample(&adjusted);
108        }
109
110        let mut probs = softmax_with_temperature(&adjusted, self.cfg.temperature)?;
111
112        if let Some(top_k) = self.cfg.top_k {
113            apply_top_k(&mut probs, top_k);
114            normalize_probs(&mut probs);
115        }
116
117        if let Some(top_p) = self.cfg.top_p {
118            apply_top_p(&mut probs, top_p)?;
119        }
120
121        normalize_probs(&mut probs);
122        Ok(sample_from_probs(&probs, &mut self.rng))
123    }
124}
125
126/// Greedy decoding (argmax).
127pub fn greedy_sample(logits: &[f32]) -> Result<i32, SamplingError> {
128    if logits.is_empty() {
129        return Err(SamplingError::EmptyLogits);
130    }
131
132    let mut best_idx = 0usize;
133    let mut best_val = logits[0];
134    for (idx, &val) in logits.iter().enumerate().skip(1) {
135        if val > best_val {
136            best_idx = idx;
137            best_val = val;
138        }
139    }
140    Ok(best_idx as i32)
141}
142
143/// Apply repetition penalty to logits for tokens present in `history`.
144///
145/// Rule follows common practice:
146/// - positive logit: divide by penalty
147/// - negative logit: multiply by penalty
148pub fn apply_repetition_penalty(
149    logits: &mut [f32],
150    history: &[i32],
151    penalty: f32,
152) -> Result<(), SamplingError> {
153    if penalty < 1.0 {
154        return Err(SamplingError::InvalidRepetitionPenalty(penalty));
155    }
156
157    let mut seen = vec![false; logits.len()];
158    for &token in history {
159        if token < 0 {
160            return Err(SamplingError::InvalidHistoryToken(token));
161        }
162        let idx = token as usize;
163        if idx >= logits.len() {
164            return Err(SamplingError::InvalidHistoryToken(token));
165        }
166        if !seen[idx] {
167            seen[idx] = true;
168            if logits[idx] > 0.0 {
169                logits[idx] /= penalty;
170            } else {
171                logits[idx] *= penalty;
172            }
173        }
174    }
175    Ok(())
176}
177
178fn softmax_with_temperature(logits: &[f32], temperature: f32) -> Result<Vec<f32>, SamplingError> {
179    if logits.is_empty() {
180        return Err(SamplingError::EmptyLogits);
181    }
182    if temperature <= 0.0 {
183        return Err(SamplingError::InvalidTemperature(temperature));
184    }
185
186    let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
187    let max_val = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
188    let mut exps: Vec<f32> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
189    normalize_probs(&mut exps);
190    Ok(exps)
191}
192
/// Rescales `probs` in place so that its entries sum to 1.
///
/// When the entries sum to zero or less there is nothing meaningful to scale
/// by, and the slice is left exactly as it was.
fn normalize_probs(probs: &mut [f32]) {
    let total: f32 = probs.iter().sum();
    let degenerate = total <= 0.0;
    if !degenerate {
        probs.iter_mut().for_each(|p| *p /= total);
    }
}
202
/// Zeroes every probability outside the `top_k` largest entries, in place.
///
/// Equal probabilities are ranked by ascending index, so the earlier token
/// survives a tie. A `top_k` of 0, or one that covers the whole slice, leaves
/// `probs` unchanged. The surviving entries are not renormalized here.
fn apply_top_k(probs: &mut [f32], top_k: usize) {
    let n = probs.len();
    if top_k == 0 || top_k >= n {
        return;
    }

    // Rank token indices by descending probability, index as the tie-breaker.
    let mut ranked: Vec<usize> = (0..n).collect();
    ranked.sort_by(|&lhs, &rhs| match probs[rhs].partial_cmp(&probs[lhs]) {
        Some(Ordering::Equal) | None => lhs.cmp(&rhs),
        Some(decided) => decided,
    });

    for &loser in &ranked[top_k..] {
        probs[loser] = 0.0;
    }
}
219
220fn apply_top_p(probs: &mut [f32], top_p: f32) -> Result<(), SamplingError> {
221    if !(top_p > 0.0 && top_p <= 1.0) {
222        return Err(SamplingError::InvalidTopP(top_p));
223    }
224
225    let mut order: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
226    order.sort_by(|a, b| {
227        b.1.partial_cmp(&a.1)
228            .unwrap_or(Ordering::Equal)
229            .then_with(|| a.0.cmp(&b.0))
230    });
231
232    let mut cumulative = 0.0f32;
233    let mut keep = vec![false; probs.len()];
234    for &(idx, p) in &order {
235        cumulative += p;
236        keep[idx] = true;
237        if cumulative >= top_p {
238            break;
239        }
240    }
241
242    for (idx, p) in probs.iter_mut().enumerate() {
243        if !keep[idx] {
244            *p = 0.0;
245        }
246    }
247    Ok(())
248}
249
250fn sample_from_probs(probs: &[f32], rng: &mut XorShift64) -> i32 {
251    let r = rng.next_f32();
252    let mut cumulative = 0.0f32;
253    for (idx, &p) in probs.iter().enumerate() {
254        if p <= 0.0 {
255            continue;
256        }
257        cumulative += p;
258        if r < cumulative {
259            return idx as i32;
260        }
261    }
262
263    // If all probs are zero after filtering, fall back to argmax.
264    probs
265        .iter()
266        .enumerate()
267        .max_by(|a, b| {
268            a.1.partial_cmp(b.1)
269                .unwrap_or(Ordering::Equal)
270                .then_with(|| b.0.cmp(&a.0))
271        })
272        .map(|(i, _)| i as i32)
273        .unwrap_or(0)
274}
275
/// Small deterministic pseudo-random generator built on 64-bit xorshift steps
/// (shifts of 13, 7 and 17).
#[derive(Debug, Clone)]
struct XorShift64 {
    /// Current generator state; `seeded` never leaves it at zero.
    state: u64,
}

impl XorShift64 {
    /// Creates a generator from `seed`.
    ///
    /// A zero state would make every xorshift step yield zero again, so a seed
    /// of 0 is swapped for a fixed non-zero constant.
    fn seeded(seed: u64) -> Self {
        const FALLBACK_STATE: u64 = 0x9E37_79B9_7F4A_7C15;
        match seed {
            0 => Self {
                state: FALLBACK_STATE,
            },
            nonzero => Self { state: nonzero },
        }
    }

    /// Advances the generator and returns the new 64-bit state.
    fn next_u64(&mut self) -> u64 {
        let step1 = self.state ^ (self.state << 13);
        let step2 = step1 ^ (step1 >> 7);
        let step3 = step2 ^ (step2 << 17);
        self.state = step3;
        step3
    }

    /// Returns a value in `[0, 1)` built from the top 24 bits of the next
    /// 64-bit output.
    fn next_f32(&mut self) -> f32 {
        const KEPT_BITS: u32 = 24;
        let top = self.next_u64() >> (64 - KEPT_BITS);
        top as f32 / (1u32 << KEPT_BITS) as f32
    }
}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sampler from `cfg`, panicking if the configuration is rejected.
    fn sampler_for(cfg: SamplingConfig) -> Sampler {
        Sampler::new(cfg).unwrap()
    }

    #[test]
    fn greedy_selects_max_logit() {
        let picked = greedy_sample(&[0.1, 2.0, 1.5]).unwrap();
        assert_eq!(picked, 1);
    }

    #[test]
    fn top_k_limits_candidates() {
        let mut sampler = sampler_for(SamplingConfig {
            top_k: Some(1),
            seed: 42,
            ..SamplingConfig::default()
        });
        // With k = 1 only the largest logit survives, whatever the RNG draws.
        let picked = sampler.sample(&[0.1, 5.0, 4.0], &[]).unwrap();
        assert_eq!(picked, 1);
    }

    #[test]
    fn top_p_limits_tail() {
        let mut sampler = sampler_for(SamplingConfig {
            top_p: Some(0.55),
            seed: 42,
            ..SamplingConfig::default()
        });
        // top_p should keep only token 0 for this distribution.
        let picked = sampler.sample(&[4.0, 2.0, 1.0], &[]).unwrap();
        assert_eq!(picked, 0);
    }

    #[test]
    fn seeded_rng_is_deterministic() {
        let cfg = SamplingConfig {
            top_k: Some(3),
            top_p: Some(0.95),
            temperature: 0.9,
            seed: 12345,
            ..SamplingConfig::default()
        };
        let mut first = sampler_for(cfg.clone());
        let mut second = sampler_for(cfg);

        let logits = [1.0, 1.1, 1.2, 1.3];
        let mut seq_first: Vec<i32> = Vec::new();
        let mut seq_second: Vec<i32> = Vec::new();
        for _ in 0..20 {
            let next_first = first.sample(&logits, &seq_first).unwrap();
            seq_first.push(next_first);
            let next_second = second.sample(&logits, &seq_second).unwrap();
            seq_second.push(next_second);
        }

        assert_eq!(seq_first, seq_second);
    }

    #[test]
    fn repetition_penalty_masks_logits() {
        let mut sampler = sampler_for(SamplingConfig {
            strategy: SamplingStrategy::Greedy,
            repetition_penalty: Some(2.0),
            ..SamplingConfig::default()
        });

        // token 1 becomes 0.5 after penalty, so token 0 should win.
        let picked = sampler.sample(&[0.9, 1.0], &[1]).unwrap();
        assert_eq!(picked, 0);
    }

    #[test]
    fn invalid_config_is_rejected() {
        let outcome = Sampler::new(SamplingConfig {
            temperature: 0.0,
            ..SamplingConfig::default()
        });
        assert!(matches!(
            outcome,
            Err(SamplingError::InvalidTemperature(t)) if t == 0.0
        ));
    }
}