// sapient_generate/sampler.rs

//! Token sampling strategies for text generation.
//!
//! - **Greedy**: always pick the highest-probability token (deterministic).
//! - **Top-K**: sample from the K highest-scoring tokens.
//! - **Top-P (nucleus)**: sample from the smallest set of most-probable tokens
//!   whose cumulative probability reaches P.
//! - **Temperature**: divide logits by a temperature before softmax.
//! - **Combined**: repetition penalty, temperature, top-K and top-P together.

use sapient_core::error::{Result, SapientError};

// ── SamplingStrategy ──────────────────────────────────────────────────────────

/// How a [`Sampler`] turns a logits vector into the next token id.
#[derive(Debug, Clone, PartialEq)]
pub enum SamplingStrategy {
    /// Always pick the argmax token — fastest, deterministic.
    Greedy,
    /// Sample from the full softmax distribution after dividing the logits by
    /// this temperature; a value `<= 0.0` falls back to greedy decoding.
    Temperature(f32),
    /// Sample among the `k` highest-scoring tokens (`k == 0` disables the cut).
    TopK { k: usize, temperature: f32 },
    /// Nucleus sampling — sample from the smallest set of most-probable tokens
    /// whose cumulative probability reaches `p`.
    TopP { p: f32, temperature: f32 },
    /// Repetition penalty, then temperature, top-k and top-p, applied in that
    /// order.
    Combined {
        /// Number of highest-scoring tokens to keep; `0` disables the cut.
        top_k: usize,
        /// Cumulative-probability cutoff for the nucleus filter.
        top_p: f32,
        /// Logit divisor; `<= 0.0` selects greedy over the penalized logits.
        temperature: f32,
        /// Penalty for tokens already in the history; `1.0` leaves logits as-is.
        repetition_penalty: f32,
    },
}

32impl Default for SamplingStrategy {
33    fn default() -> Self {
34        Self::Greedy
35    }
36}
37
// ── Sampler ───────────────────────────────────────────────────────────────────

40pub struct Sampler {
41    pub strategy: SamplingStrategy,
42    rng_seed: u64,
43    counter: u64,
44}
45
46impl Sampler {
47    pub fn new(strategy: SamplingStrategy) -> Self {
48        let seed = std::time::SystemTime::now()
49            .duration_since(std::time::UNIX_EPOCH)
50            .map(|d| d.as_nanos() as u64)
51            .unwrap_or(42);
52        Self {
53            strategy,
54            rng_seed: seed,
55            counter: 0,
56        }
57    }
58
59    pub fn with_seed(strategy: SamplingStrategy, seed: u64) -> Self {
60        Self {
61            strategy,
62            rng_seed: seed,
63            counter: 0,
64        }
65    }
66
67    /// Sample the next token from `logits` — shape: (vocab_size,).
68    /// Optionally pass previously generated `token_ids` for repetition penalty.
69    pub fn sample(&mut self, logits: &[f32], prev_tokens: &[u32]) -> Result<u32> {
70        match &self.strategy {
71            SamplingStrategy::Greedy => Ok(argmax(logits)),
72
73            SamplingStrategy::Temperature(t) => {
74                let t = *t;
75                if t <= 0.0 {
76                    return Ok(argmax(logits));
77                }
78                let scaled = scale_logits(logits, t);
79                let probs = softmax(&scaled);
80                Ok(self.random_sample(&probs))
81            }
82
83            SamplingStrategy::TopK { k, temperature } => {
84                let (k, t) = (*k, *temperature);
85                if t <= 0.0 {
86                    return Ok(argmax(logits));
87                }
88                let scaled = scale_logits(logits, t);
89                let filtered = top_k_filter(&scaled, k);
90                let probs = softmax(&filtered);
91                Ok(self.random_sample(&probs))
92            }
93
94            SamplingStrategy::TopP { p, temperature } => {
95                let (p, t) = (*p, *temperature);
96                if t <= 0.0 {
97                    return Ok(argmax(logits));
98                }
99                let scaled = scale_logits(logits, t);
100                let filtered = top_p_filter(&scaled, p);
101                let probs = softmax(&filtered);
102                Ok(self.random_sample(&probs))
103            }
104
105            SamplingStrategy::Combined {
106                top_k,
107                top_p,
108                temperature,
109                repetition_penalty,
110            } => {
111                let (k, p, t, rp) = (*top_k, *top_p, *temperature, *repetition_penalty);
112                let mut penalized = apply_repetition_penalty(logits, prev_tokens, rp);
113                if t <= 0.0 {
114                    return Ok(argmax(&penalized));
115                }
116                penalized = scale_logits(&penalized, t);
117                penalized = top_k_filter(&penalized, k);
118                penalized = top_p_filter(&penalized, p);
119                let probs = softmax(&penalized);
120                Ok(self.random_sample(&probs))
121            }
122        }
123    }
124
125    /// Simple xorshift RNG for sampling without an external rand crate.
126    fn random_u64(&mut self) -> u64 {
127        self.counter += 1;
128        let mut x = self
129            .rng_seed
130            .wrapping_add(self.counter.wrapping_mul(6364136223846793005));
131        x ^= x >> 30;
132        x = x.wrapping_mul(0xbf58476d1ce4e5b9);
133        x ^= x >> 27;
134        x = x.wrapping_mul(0x94d049bb133111eb);
135        x ^= x >> 31;
136        x
137    }
138
139    fn random_f32(&mut self) -> f32 {
140        (self.random_u64() >> 11) as f32 / (1u64 << 53) as f32
141    }
142
143    fn random_sample(&mut self, probs: &[f32]) -> u32 {
144        let r = self.random_f32();
145        let mut cum = 0.0f32;
146        for (i, &p) in probs.iter().enumerate() {
147            cum += p;
148            if r < cum {
149                return i as u32;
150            }
151        }
152        (probs.len() - 1) as u32
153    }
154}
155
// ── Helpers ───────────────────────────────────────────────────────────────────

/// Returns the index of the largest logit, or `0` for an empty slice.
///
/// When several entries share the maximum, the *last* of them is returned.
/// NOTE(review): NaN entries are never "greater" than anything, so they make
/// the result position-dependent. Do not rely on it for NaN input.
pub fn argmax(logits: &[f32]) -> u32 {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &value) in logits.iter().enumerate() {
        match best {
            // Only a strictly greater incumbent survives; ties and
            // incomparable (NaN) pairs hand the lead to the later element.
            Some((_, incumbent)) if incumbent > value => {}
            _ => best = Some((idx, value)),
        }
    }
    best.map_or(0, |(idx, _)| idx as u32)
}

/// Numerically stable softmax.
///
/// Logits are shifted by their maximum before exponentiation, so the largest
/// term is `exp(0) = 1`. `-inf` logits map to probability `0`.
/// NOTE(review): if every logit is `-inf`, or any is `+inf`, the shift yields
/// NaN and so does the whole output.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - peak).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / total).collect()
}

/// Divides every logit by `temperature`.
///
/// Returns an unmodified copy when `temperature` is exactly `1.0` (a no-op)
/// or non-positive. The strategies in this file short-circuit to greedy for
/// non-positive temperatures before reaching this function.
fn scale_logits(logits: &[f32], temperature: f32) -> Vec<f32> {
    let passthrough = temperature <= 0.0 || temperature == 1.0;
    if passthrough {
        logits.to_vec()
    } else {
        logits.iter().map(|&logit| logit / temperature).collect()
    }
}

/// Keeps the `k` largest logits and masks every other entry to `-inf`.
///
/// Every logit tied with the k-th largest value is kept, so more than `k`
/// entries can survive. `k == 0` or `k >= logits.len()` returns an unmodified
/// copy.
/// NOTE(review): NaN logits no longer risk a panic, but they are ordered by
/// `f32::total_cmp` and can distort the cut; callers should not pass them.
fn top_k_filter(logits: &[f32], k: usize) -> Vec<f32> {
    if k == 0 || k >= logits.len() {
        return logits.to_vec();
    }
    // Only the k-th largest value is needed, so an O(n) selection replaces a
    // full sort. `total_cmp` is a genuine total order, which the selection
    // requires; `partial_cmp().unwrap_or(Equal)` is not one when NaN is present.
    let mut values = logits.to_vec();
    let threshold = *values.select_nth_unstable_by(k - 1, |a, b| b.total_cmp(a)).1;
    logits
        .iter()
        .map(|&x| if x >= threshold { x } else { f32::NEG_INFINITY })
        .collect()
}

195fn top_p_filter(logits: &[f32], p: f32) -> Vec<f32> {
196    let mut indexed: Vec<(usize, f32)> = logits.iter().cloned().enumerate().collect();
197    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
198
199    let probs = softmax(logits);
200    let mut sorted_probs: Vec<(usize, f32)> = probs.iter().cloned().enumerate().collect();
201    sorted_probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
202
203    let mut cum = 0.0f32;
204    let mut cutoff_idx = sorted_probs.len();
205    for (i, (_, prob)) in sorted_probs.iter().enumerate() {
206        cum += prob;
207        if cum >= p {
208            cutoff_idx = i + 1;
209            break;
210        }
211    }
212
213    let keep: std::collections::HashSet<usize> =
214        sorted_probs[..cutoff_idx].iter().map(|(i, _)| *i).collect();
215    logits
216        .iter()
217        .enumerate()
218        .map(|(i, &x)| {
219            if keep.contains(&i) {
220                x
221            } else {
222                f32::NEG_INFINITY
223            }
224        })
225        .collect()
226}
227
/// Applies a repetition penalty to every token id that appears in `prev_tokens`.
///
/// - Non-negative logits are divided by `penalty` and negative ones are
///   multiplied by it. So `penalty > 1.0` makes previously generated tokens
///   less likely and `0.0 < penalty < 1.0` makes them more likely.
/// - Each distinct id is penalised once, however often it occurs in
///   `prev_tokens`.
/// - Ids outside the vocabulary are ignored.
/// - A penalty within `1e-6` of `1.0`, a non-positive penalty, or NaN leaves
///   the logits unchanged.
fn apply_repetition_penalty(logits: &[f32], prev_tokens: &[u32], penalty: f32) -> Vec<f32> {
    // Non-positive or NaN penalties would yield inf/NaN logits; treat them as
    // "no penalty" rather than corrupting the distribution.
    let disabled = penalty.is_nan() || penalty <= 0.0 || (penalty - 1.0).abs() < 1e-6;
    if disabled {
        return logits.to_vec();
    }
    let mut out = logits.to_vec();
    // Track ids already penalised so repeats in the history do not compound
    // the penalty to penalty^count.
    let mut seen = vec![false; out.len()];
    for &tok in prev_tokens {
        let idx = tok as usize;
        if idx >= out.len() || seen[idx] {
            continue;
        }
        seen[idx] = true;
        let logit = out[idx];
        out[idx] = if logit >= 0.0 { logit / penalty } else { logit * penalty };
    }
    out
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn greedy_picks_argmax() {
        // Index 1 holds the largest logit, so greedy decoding must choose it.
        let mut sampler = Sampler::with_seed(SamplingStrategy::Greedy, 42);
        let picked = sampler.sample(&[0.1, 0.9, 0.3, 0.5], &[]).unwrap();
        assert_eq!(picked, 1);
    }

    #[test]
    fn top_k_removes_low_prob() {
        // With k = 1 only the dominant logit survives; the others become -inf.
        let kept = top_k_filter(&[10.0, 1.0, 1.0, 1.0], 1);
        assert_eq!(kept[0], 10.0);
        assert_eq!(kept[1], f32::NEG_INFINITY);
    }

    #[test]
    fn repetition_penalty_reduces_score() {
        // Token 2 was already generated, so its positive logit must shrink.
        let original = [1.0, 2.0, 3.0];
        let penalized = apply_repetition_penalty(&original, &[2], 1.3);
        assert!(penalized[2] < original[2]);
    }
}