Skip to main content

llama_sampling/
lib.rs

1//! # llama-sampling
2//!
3//! Sampling and decoding strategies for llama.rs.
4//!
5//! Supports:
6//! - Greedy (argmax)
7//! - Temperature scaling
8//! - Top-k filtering
9//! - Top-p (nucleus) filtering
10//! - Repetition penalty
11//! - Deterministic seeded RNG for reproducible generation
12
/// Errors produced while sampling a token from logits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SamplingError {
    /// The logits slice cannot be sampled from (for example, it is empty).
    InvalidLogits,
    /// The configured temperature was rejected; it must be greater than zero.
    InvalidTemperature,
    /// No token was left with a non-zero probability after filtering.
    NoValidTokens,
}

impl std::fmt::Display for SamplingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            SamplingError::InvalidLogits => "Invalid logits array",
            SamplingError::InvalidTemperature => "Temperature must be > 0",
            SamplingError::NoValidTokens => "No valid tokens after filtering",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SamplingError {}

/// Result alias used by every fallible sampling operation in this crate.
pub type SamplingResult<T> = std::result::Result<T, SamplingError>;
34
/// Deterministic pseudo-random number generator for reproducible sampling.
///
/// Implements xorshift64 with the shift triple 13 / 7 / 17. It is fast and
/// fully determined by its seed, which lets a sampling run be replayed exactly.
/// It is not suitable for cryptographic use.
#[derive(Debug, Clone)]
pub struct SeededRng {
    state: u64,
}

impl SeededRng {
    /// Create a generator from `seed`.
    ///
    /// An all-zero xorshift state would only ever produce zeros, so a seed of
    /// `0` is replaced by `1`; as a consequence seeds `0` and `1` yield the
    /// same stream.
    pub fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }

    /// Advance the generator and return a float in `[0, 1)`.
    pub fn next_f32(&mut self) -> f32 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;

        // The top 24 bits fit an f32 significand exactly; dividing by 2^24
        // maps them onto evenly spaced values in [0, 1).
        let top_bits = x >> 40;
        top_bits as f32 / 16_777_216.0
    }
}
60
61/// Sampling configuration and strategy.
62#[derive(Debug, Clone)]
63pub struct Sampler {
64    /// Temperature for softmax scaling. > 1.0 = more random, < 1.0 = more deterministic.
65    pub temperature: f32,
66
67    /// Top-k: only sample from top k logits.
68    pub top_k: Option<usize>,
69
70    /// Top-p (nucleus sampling): sample from smallest set of tokens with cumulative prob >= p.
71    pub top_p: Option<f32>,
72
73    /// Repetition penalty: penalize tokens that appear in history.
74    pub repetition_penalty: Option<f32>,
75
76    /// RNG state for reproducible sampling. Mutated on each call.
77    rng: SeededRng,
78}
79
80impl Sampler {
81    /// Create a sampler with default settings (greedy).
82    pub fn new() -> Self {
83        Self {
84            temperature: 1.0,
85            top_k: None,
86            top_p: None,
87            repetition_penalty: None,
88            rng: SeededRng::new(42),
89        }
90    }
91
92    pub fn with_temperature(mut self, temp: f32) -> Self {
93        self.temperature = temp;
94        self
95    }
96
97    pub fn with_top_k(mut self, k: usize) -> Self {
98        self.top_k = Some(k);
99        self
100    }
101
102    pub fn with_top_p(mut self, p: f32) -> Self {
103        self.top_p = Some(p);
104        self
105    }
106
107    pub fn with_repetition_penalty(mut self, penalty: f32) -> Self {
108        self.repetition_penalty = Some(penalty);
109        self
110    }
111
112    pub fn with_seed(mut self, seed: u64) -> Self {
113        self.rng = SeededRng::new(seed);
114        self
115    }
116
117    /// Sample a token index from logits using configured strategy.
118    pub fn sample(&mut self, logits: &[f32]) -> SamplingResult<usize> {
119        self.sample_inner(logits, &[])
120    }
121
122    /// Sample with history for repetition penalty.
123    pub fn sample_with_history(
124        &mut self,
125        logits: &[f32],
126        history: &[usize],
127    ) -> SamplingResult<usize> {
128        self.sample_inner(logits, history)
129    }
130
131    fn sample_inner(&mut self, logits: &[f32], history: &[usize]) -> SamplingResult<usize> {
132        if logits.is_empty() {
133            return Err(SamplingError::InvalidLogits);
134        }
135
136        if self.temperature <= 0.0 {
137            return Err(SamplingError::InvalidTemperature);
138        }
139
140        let mut work_logits = logits.to_vec();
141
142        // Apply repetition penalty: for tokens in history, divide positive
143        // logits by penalty and multiply negative logits by penalty.
144        // This always makes repeated tokens less likely regardless of sign.
145        if let Some(penalty) = self.repetition_penalty {
146            for &token_id in history {
147                if token_id < work_logits.len() {
148                    if work_logits[token_id] > 0.0 {
149                        work_logits[token_id] /= penalty;
150                    } else {
151                        work_logits[token_id] *= penalty;
152                    }
153                }
154            }
155        }
156
157        // Apply temperature scaling
158        if (self.temperature - 1.0).abs() > 1e-6 {
159            for logit in &mut work_logits {
160                *logit /= self.temperature;
161            }
162        }
163
164        // Apply top-k filtering
165        if let Some(k) = self.top_k {
166            Self::apply_top_k(&mut work_logits, k);
167        }
168
169        // Convert to probabilities
170        let probs = Self::softmax(&work_logits);
171
172        // If temperature is very low (near-greedy), just argmax
173        if self.temperature < 1e-3 {
174            return Ok(Self::argmax(&probs));
175        }
176
177        // Apply top-p (nucleus) filtering
178        let probs = if let Some(p) = self.top_p {
179            Self::apply_top_p(&probs, p)
180        } else {
181            probs
182        };
183
184        // Sample from distribution
185        self.sample_from_distribution(&probs)
186    }
187
188    fn apply_top_k(logits: &mut [f32], k: usize) {
189        if k == 0 || k >= logits.len() {
190            return;
191        }
192
193        let mut indexed: Vec<(usize, f32)> =
194            logits.iter().enumerate().map(|(i, &l)| (i, l)).collect();
195        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
196
197        let threshold = indexed[k - 1].1;
198        for logit in logits.iter_mut() {
199            if *logit < threshold {
200                *logit = f32::NEG_INFINITY;
201            }
202        }
203    }
204
205    fn apply_top_p(probs: &[f32], p: f32) -> Vec<f32> {
206        let mut indexed: Vec<(usize, f32)> =
207            probs.iter().enumerate().map(|(i, &pr)| (i, pr)).collect();
208        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
209
210        let mut cumsum = 0.0;
211        let mut cutoff_idx = 0;
212        for (idx, (_, prob)) in indexed.iter().enumerate() {
213            cumsum += prob;
214            cutoff_idx = idx;
215            if cumsum >= p {
216                break;
217            }
218        }
219
220        let cutoff_prob = indexed[cutoff_idx].1;
221        let mut result = vec![0.0; probs.len()];
222        for (i, &pr) in probs.iter().enumerate() {
223            if pr >= cutoff_prob {
224                result[i] = pr;
225            }
226        }
227
228        // Renormalize
229        let sum: f32 = result.iter().sum();
230        if sum > 0.0 {
231            for p in &mut result {
232                *p /= sum;
233            }
234        }
235
236        result
237    }
238
239    fn softmax(logits: &[f32]) -> Vec<f32> {
240        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
241        let exps: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
242        let sum: f32 = exps.iter().sum();
243
244        if sum > 0.0 {
245            exps.iter().map(|&e| e / sum).collect()
246        } else {
247            vec![1.0 / logits.len() as f32; logits.len()]
248        }
249    }
250
251    fn argmax(probs: &[f32]) -> usize {
252        probs
253            .iter()
254            .enumerate()
255            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
256            .map(|(idx, _)| idx)
257            .unwrap_or(0)
258    }
259
260    fn sample_from_distribution(&mut self, probs: &[f32]) -> SamplingResult<usize> {
261        let r = self.rng.next_f32();
262        let mut cumsum = 0.0;
263
264        for (i, &prob) in probs.iter().enumerate() {
265            cumsum += prob;
266            if r < cumsum {
267                return Ok(i);
268            }
269        }
270
271        // Fallback to last token with nonzero probability
272        for (i, &prob) in probs.iter().enumerate().rev() {
273            if prob > 0.0 {
274                return Ok(i);
275            }
276        }
277
278        Err(SamplingError::NoValidTokens)
279    }
280}
281
282impl Default for Sampler {
283    fn default() -> Self {
284        Self::new()
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Greedy sampler (temperature below the greedy threshold) with a repetition penalty.
    fn greedy_with_penalty(penalty: f32) -> Sampler {
        Sampler::new()
            .with_temperature(0.0001)
            .with_repetition_penalty(penalty)
    }

    #[test]
    fn seeded_rng_reproducible() {
        let mut rng1 = SeededRng::new(42);
        let mut rng2 = SeededRng::new(42);

        for _ in 0..100 {
            let v1 = rng1.next_f32();
            let v2 = rng2.next_f32();
            assert!((v1 - v2).abs() < 1e-6);
            assert!((0.0..1.0).contains(&v1));
        }
    }

    #[test]
    fn seeded_rng_zero_seed_does_not_get_stuck() {
        let mut rng = SeededRng::new(0);
        let values: Vec<f32> = (0..10).map(|_| rng.next_f32()).collect();
        assert!(values.iter().any(|&v| v > 0.0));
    }

    #[test]
    fn greedy_sampling() {
        let logits = vec![1.0, 10.0, 2.0, 0.5];
        let mut sampler = Sampler::new().with_temperature(0.0001);
        let token = sampler.sample(&logits).unwrap();
        assert_eq!(token, 1);
    }

    #[test]
    fn softmax_uniform() {
        let logits = vec![1.0, 1.0, 1.0];
        let probs = Sampler::softmax(&logits);
        assert_eq!(probs.len(), 3);
        assert!((probs[0] - 1.0 / 3.0).abs() < 1e-5);
        assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-5);
    }

    #[test]
    fn temperature_effect() {
        let logits = vec![1.0, 2.0, 0.5];

        let high_temp: Vec<f32> = logits.iter().map(|l| l / 10.0).collect();
        let low_temp: Vec<f32> = logits.iter().map(|l| l / 0.1).collect();

        let high_probs = Sampler::softmax(&high_temp);
        let low_probs = Sampler::softmax(&low_temp);

        let max_high = high_probs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let max_low = low_probs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        // Higher temperature = more uniform = lower peak
        assert!(max_high < max_low);
    }

    #[test]
    fn top_k_filtering() {
        let mut logits = vec![1.0, 10.0, 2.0, 0.5, 3.0];
        Sampler::apply_top_k(&mut logits, 2);
        assert!(logits[1].is_finite()); // Top token
        assert!(logits[4].is_finite()); // 2nd top token
        for &i in &[0, 2, 3] {
            assert_eq!(logits[i], f32::NEG_INFINITY, "index {} should be masked", i);
        }
    }

    #[test]
    fn top_k_keeps_ties_and_ignores_degenerate_k() {
        // Both tokens tied at the k-th largest value survive.
        let mut tied = vec![5.0, 5.0, 1.0];
        Sampler::apply_top_k(&mut tied, 1);
        assert_eq!(tied, vec![5.0, 5.0, f32::NEG_INFINITY]);

        // k == 0 and k >= len are no-ops.
        let original = vec![1.0, 2.0, 3.0];
        for k in [0, 3, 10] {
            let mut logits = original.clone();
            Sampler::apply_top_k(&mut logits, k);
            assert_eq!(logits, original);
        }
    }

    #[test]
    fn top_k_one_always_picks_best_token() {
        let logits = vec![1.0, 3.0, 2.0];
        let mut sampler = Sampler::new().with_top_k(1).with_seed(7);
        for _ in 0..20 {
            assert_eq!(sampler.sample(&logits), Ok(1));
        }
    }

    #[test]
    fn top_p_filtering() {
        let probs = vec![0.5, 0.3, 0.15, 0.05];
        let filtered = Sampler::apply_top_p(&probs, 0.8);
        assert!(filtered[0] > 0.0);
        assert!(filtered[1] > 0.0);
        assert_eq!(filtered[2], 0.0);
        assert_eq!(filtered[3], 0.0);
        // Survivors are renormalized: 0.5 / 0.8 and 0.3 / 0.8.
        assert!((filtered.iter().sum::<f32>() - 1.0).abs() < 1e-5);
        assert!((filtered[0] - 0.625).abs() < 1e-5);
        assert!((filtered[1] - 0.375).abs() < 1e-5);
    }

    #[test]
    fn repetition_penalty_reduces_likelihood() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let history = vec![3]; // Token 3 in history

        // Without penalty
        let probs_no_penalty = Sampler::softmax(&logits);

        // With penalty applied
        let mut penalized = logits.clone();
        penalized[3] /= 2.0; // Positive logit divided by penalty
        let probs_with_penalty = Sampler::softmax(&penalized);

        // Token 3 should have lower probability after penalty
        assert!(probs_with_penalty[3] < probs_no_penalty[3]);

        // Observable through the public API with greedy decoding: token 3's
        // logit drops from 4.0 to 2.0, so token 2 (logit 3.0) wins.
        let mut sampler = greedy_with_penalty(2.0);
        assert_eq!(sampler.sample_with_history(&logits, &history), Ok(2));
    }

    #[test]
    fn repetition_penalty_handles_negative_logits() {
        // A negative logit is multiplied by the penalty: -1.0 becomes -2.0 and
        // now loses to token 1 (-1.5).
        let logits = vec![-1.0, -1.5];
        let mut sampler = greedy_with_penalty(2.0);
        assert_eq!(sampler.sample_with_history(&logits, &[0]), Ok(1));
        // Without history token 0 still wins.
        assert_eq!(sampler.sample_with_history(&logits, &[]), Ok(0));
    }

    #[test]
    fn repetition_penalty_ignores_out_of_range_history() {
        let logits = vec![1.0, 2.0];
        let mut sampler = greedy_with_penalty(2.0);
        assert_eq!(sampler.sample_with_history(&logits, &[5]), Ok(1));
    }

    #[test]
    fn deterministic_across_calls() {
        let logits = vec![0.1, 0.2, 0.3, 0.4];

        let mut sampler1 = Sampler::new().with_seed(42);
        let mut sampler2 = Sampler::new().with_seed(42);

        // Multiple calls should produce same sequence
        for _ in 0..10 {
            let t1 = sampler1.sample(&logits).unwrap();
            let t2 = sampler2.sample(&logits).unwrap();
            assert_eq!(t1, t2);
        }
    }

    #[test]
    fn rng_advances_between_calls() {
        let logits = vec![0.25, 0.25, 0.25, 0.25];
        let mut sampler = Sampler::new().with_seed(42);

        // With uniform distribution, we should eventually see different tokens
        let mut seen = std::collections::HashSet::new();
        for _ in 0..100 {
            seen.insert(sampler.sample(&logits).unwrap());
        }
        assert!(seen.len() > 1, "RNG should produce varied results");
    }

    #[test]
    fn combined_sampling() {
        let logits = vec![1.0, 2.0, 3.0, 4.0, 0.5, 0.1];
        let mut sampler = Sampler::new()
            .with_temperature(0.8)
            .with_top_k(3)
            .with_top_p(0.9)
            .with_seed(42);

        // Top-k keeps tokens 1..=3. Their probabilities are roughly
        // [0.06, 0.21, 0.73]. Top-p 0.9 is first reached at token 2, so only
        // tokens 2 and 3 can be drawn.
        let token = sampler.sample(&logits).unwrap();
        assert!(matches!(token, 2 | 3), "unexpected token {}", token);
    }

    #[test]
    fn invalid_temperature() {
        let logits = vec![1.0, 2.0];
        let mut sampler = Sampler::new().with_temperature(0.0);
        assert_eq!(
            sampler.sample(&logits),
            Err(SamplingError::InvalidTemperature)
        );

        let mut negative = Sampler::new().with_temperature(-1.0);
        assert_eq!(
            negative.sample(&logits),
            Err(SamplingError::InvalidTemperature)
        );
    }

    #[test]
    fn empty_logits() {
        let mut sampler = Sampler::new();
        assert_eq!(sampler.sample(&[]), Err(SamplingError::InvalidLogits));
    }

    #[test]
    fn all_zero_distribution_has_no_valid_tokens() {
        let mut sampler = Sampler::new();
        assert_eq!(
            sampler.sample_from_distribution(&[0.0, 0.0]),
            Err(SamplingError::NoValidTokens)
        );
    }
}