Skip to main content

oxibonsai_runtime/
sampling.rs

//! Sampling strategies for text generation.
//!
//! Supports temperature scaling, top-k filtering and top-p (nucleus) filtering.
//! [`SamplingParams`] also carries a repetition penalty, but [`Sampler::sample`]
//! only sees the logit vector and does not apply it. The [`Sampler`] converts a
//! logit vector into a single token ID using these strategies in order:
//!
//! 1. **Temperature scaling** — divide logits by the temperature (a temperature
//!    below `1e-6`, including 0, selects the greedy argmax instead)
//! 2. **Top-k** — keep only the k highest-scoring candidates (0 = disabled)
//! 3. **Softmax** — convert scaled logits to probabilities
//! 4. **Top-p** — keep the smallest set of most-probable tokens whose cumulative
//!    probability exceeds p
//! 5. **Weighted random selection** — sample from the filtered distribution
12
13use std::cmp::Ordering;
14
15use crate::error::RuntimeResult;
16
/// Knobs that control how [`Sampler`] turns a logit vector into a token.
///
/// All fields are public; start from `SamplingParams::default()` and use
/// struct-update syntax to override only the fields you care about.
#[derive(Debug, Clone)]
pub struct SamplingParams {
    /// Divisor applied to every logit before the softmax. Values below `1e-6`
    /// (including `0.0`) make [`Sampler::sample`] return the greedy argmax.
    pub temperature: f32,
    /// Number of highest-scoring candidates kept before the softmax; `0` keeps all.
    pub top_k: usize,
    /// Nucleus threshold: keep the most probable tokens until their cumulative
    /// probability exceeds this value. Any value `>= 1.0` disables the filter.
    pub top_p: f32,
    /// Repetition penalty factor; `1.0` means no penalty.
    ///
    /// NOTE(review): [`Sampler::sample`] never reads this field, so the penalty
    /// must be applied to the logits by the caller — confirm where that happens.
    pub repetition_penalty: f32,
    /// Maximum number of new tokens to generate per request.
    pub max_tokens: usize,
}
31
32impl Default for SamplingParams {
33    fn default() -> Self {
34        Self {
35            temperature: 0.7,
36            top_k: 40,
37            top_p: 0.9,
38            repetition_penalty: 1.1,
39            max_tokens: 128,
40        }
41    }
42}
43
44/// Token sampler.
45///
46/// Owns a reusable `probs_buf` that is grown on first use and then reused across
47/// all subsequent `sample()` calls, eliminating the ~1.8 MB per-call heap
48/// allocation that a fresh `Vec` would require for a 151 936-token vocabulary.
49#[derive(Debug)]
50pub struct Sampler {
51    params: SamplingParams,
52    rng_state: u64,
53    /// Reusable working buffer for `(token_index, scaled_logit)` pairs.
54    ///
55    /// After `select_nth_unstable_by` + `drain` the buffer holds only the top-k
56    /// candidates (capacity stays at `vocab_size`).  `clear()` on the next call
57    /// resets length to zero without freeing the backing store, so subsequent
58    /// `extend()` calls never reallocate.
59    probs_buf: Vec<(usize, f32)>,
60}
61
62impl Sampler {
63    /// Create a new sampler with the given parameters and seed.
64    pub fn new(params: SamplingParams, seed: u64) -> Self {
65        Self {
66            params,
67            rng_state: seed,
68            probs_buf: Vec::new(),
69        }
70    }
71
72    /// Simple xorshift64 PRNG — no external dependency needed.
73    fn next_u64(&mut self) -> u64 {
74        let mut x = self.rng_state;
75        x ^= x << 13;
76        x ^= x >> 7;
77        x ^= x << 17;
78        self.rng_state = x;
79        x
80    }
81
82    /// Sample a token index from logits.
83    #[tracing::instrument(skip(self, logits), fields(vocab_size = logits.len()), level = "debug")]
84    pub fn sample(&mut self, logits: &[f32]) -> RuntimeResult<u32> {
85        if logits.is_empty() {
86            return Ok(0);
87        }
88
89        // Greedy if temperature is ~0
90        if self.params.temperature < 1e-6 {
91            return Ok(argmax(logits) as u32);
92        }
93
94        // Populate the reusable buffer with temperature-scaled logits.
95        // On the first call this allocates `vocab_size × 12` bytes; every
96        // subsequent call reuses the existing backing store (len is reset to 0
97        // by `clear()`, capacity is preserved from the previous call).
98        self.probs_buf.clear();
99        self.probs_buf.extend(
100            logits
101                .iter()
102                .enumerate()
103                .map(|(i, &v)| (i, v / self.params.temperature)),
104        );
105
106        // Top-k filtering — O(n) average via partial selection rather than O(n log n) full sort.
107        // `select_nth_unstable_by` rearranges `probs_buf` so that element at index `cutoff` is in
108        // its fully-sorted position, all elements before it are ≤ it (lower scaled logits), and all
109        // elements after it are ≥ it (higher scaled logits).  Draining the prefix leaves exactly
110        // the top-k elements in arbitrary order, which is sufficient for softmax + sampling.
111        if self.params.top_k > 0 && self.params.top_k < self.probs_buf.len() {
112            let k = self.params.top_k;
113            let cutoff = self.probs_buf.len() - k;
114            self.probs_buf.select_nth_unstable_by(cutoff, |a, b| {
115                a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)
116            });
117            self.probs_buf.drain(..cutoff);
118        }
119
120        // Softmax
121        let max_val = self
122            .probs_buf
123            .iter()
124            .map(|(_, v)| *v)
125            .fold(f32::NEG_INFINITY, f32::max);
126        let mut sum = 0.0f32;
127        for (_, v) in self.probs_buf.iter_mut() {
128            *v = (*v - max_val).exp();
129            sum += *v;
130        }
131        for (_, v) in self.probs_buf.iter_mut() {
132            *v /= sum;
133        }
134
135        // Top-p filtering
136        if self.params.top_p < 1.0 {
137            self.probs_buf
138                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
139            let mut cum = 0.0f32;
140            let cutoff = self
141                .probs_buf
142                .iter()
143                .position(|&(_, p)| {
144                    cum += p;
145                    cum > self.params.top_p
146                })
147                .unwrap_or(self.probs_buf.len().saturating_sub(1));
148            self.probs_buf.truncate(cutoff + 1);
149
150            // Re-normalize
151            let sum: f32 = self.probs_buf.iter().map(|(_, p)| p).sum();
152            for (_, p) in self.probs_buf.iter_mut() {
153                *p /= sum;
154            }
155        }
156
157        // Pre-compute random value before the immutable borrow of `probs_buf`
158        // to satisfy the borrow checker: `next_u64` takes `&mut self` which
159        // would conflict with an active `&self.probs_buf` borrow.
160        let rand_val = (self.next_u64() as f64 / u64::MAX as f64) as f32;
161
162        // Weighted random selection
163        let mut cum = 0.0f32;
164        for &(idx, p) in &self.probs_buf {
165            cum += p;
166            if rand_val <= cum {
167                return Ok(idx as u32);
168            }
169        }
170
171        // Fallback: return the highest probability token
172        Ok(self.probs_buf[0].0 as u32)
173    }
174
175    /// Get current parameters.
176    pub fn params(&self) -> &SamplingParams {
177        &self.params
178    }
179}
180
/// Return the index of the maximum element of `values`.
///
/// NaN entries are ignored. Without that, a NaN compares "equal" to every
/// neighbour and corrupts the running maximum (`[2.0, NaN, 1.0]` used to yield
/// index 2). When several entries share the maximum, the last one wins.
/// Returns `0` for an empty slice or one made entirely of NaNs.
fn argmax(values: &[f32]) -> usize {
    values
        .iter()
        .enumerate()
        .filter(|(_, v)| !v.is_nan())
        // NaNs are filtered out above, so `partial_cmp` always yields `Some`.
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
        .map(|(i, _)| i)
        .unwrap_or(0)
}
190
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn greedy_sampling() {
        let greedy = SamplingParams {
            temperature: 0.0,
            ..SamplingParams::default()
        };
        let mut sampler = Sampler::new(greedy, 42);
        let logits = [0.1f32, 0.5, 0.3, 0.9, 0.2];
        let picked = sampler.sample(&logits).expect("sampling should succeed");
        // 0.9 is the largest logit and sits at index 3.
        assert_eq!(picked, 3);
    }

    #[test]
    fn sampling_returns_valid_index() {
        let mut sampler = Sampler::new(SamplingParams::default(), 12345);
        let flat = [0.0f32; 100];
        (0..50).for_each(|_| {
            let token = sampler.sample(&flat).expect("sampling should succeed");
            assert!(token < 100);
        });
    }

    #[test]
    fn argmax_basic() {
        let three = [1.0f32, 3.0, 2.0];
        let single = [5.0f32];
        assert_eq!(argmax(&three), 1);
        assert_eq!(argmax(&single), 0);
    }

    #[test]
    fn buffer_reuse_across_calls() {
        // Repeated calls must reuse `probs_buf` without stale candidates
        // leaking from one call into the next.
        let mut sampler = Sampler::new(
            SamplingParams {
                temperature: 0.7,
                top_k: 5,
                top_p: 1.0, // top-p off, so top-k alone decides the candidate set
                repetition_penalty: 1.0,
                max_tokens: 128,
            },
            99,
        );
        let ramp: Vec<f32> = (0..200).map(|i| i as f32 * 0.01).collect();
        for _ in 0..20 {
            let token = sampler.sample(&ramp).expect("sampling should succeed");
            // Ascending logits with top_k = 5 leave only indices 195..=199.
            assert!(token >= 195, "expected token ≥ 195, got {token}");
        }
    }
}