// oxibonsai_runtime/sampling.rs
//! Sampling strategies for text generation.
//!
//! Supports temperature scaling, top-k filtering, top-p (nucleus) filtering,
//! and repetition penalty. The [`Sampler`] converts a logit vector into a
//! single token ID using these strategies in order:
//!
//! 1. **Temperature scaling** — divide logits by temperature (0 = greedy argmax)
//! 2. **Top-k** — keep only the k highest-probability candidates
//! 3. **Softmax** — convert scaled logits to probabilities
//! 4. **Top-p** — keep the smallest set of tokens whose cumulative probability exceeds p
//! 5. **Weighted random selection** — sample from the filtered distribution

use std::cmp::Ordering;

use crate::error::RuntimeResult;

/// Sampling parameters.
#[derive(Debug, Clone)]
pub struct SamplingParams {
    /// Temperature for softmax scaling. 0.0 = greedy.
    pub temperature: f32,
    /// Top-k filtering (0 = disabled).
    pub top_k: usize,
    /// Top-p (nucleus) threshold (1.0 = disabled).
    pub top_p: f32,
    /// Repetition penalty (1.0 = disabled).
    pub repetition_penalty: f32,
    /// Maximum number of new tokens to generate per request.
    pub max_tokens: usize,
}

32impl Default for SamplingParams {
33 fn default() -> Self {
34 Self {
35 temperature: 0.7,
36 top_k: 40,
37 top_p: 0.9,
38 repetition_penalty: 1.1,
39 max_tokens: 128,
40 }
41 }
42}
43
44/// Token sampler.
45///
46/// Owns a reusable `probs_buf` that is grown on first use and then reused across
47/// all subsequent `sample()` calls, eliminating the ~1.8 MB per-call heap
48/// allocation that a fresh `Vec` would require for a 151 936-token vocabulary.
49#[derive(Debug)]
50pub struct Sampler {
51 params: SamplingParams,
52 rng_state: u64,
53 /// Reusable working buffer for `(token_index, scaled_logit)` pairs.
54 ///
55 /// After `select_nth_unstable_by` + `drain` the buffer holds only the top-k
56 /// candidates (capacity stays at `vocab_size`). `clear()` on the next call
57 /// resets length to zero without freeing the backing store, so subsequent
58 /// `extend()` calls never reallocate.
59 probs_buf: Vec<(usize, f32)>,
60}
61
62impl Sampler {
63 /// Create a new sampler with the given parameters and seed.
64 pub fn new(params: SamplingParams, seed: u64) -> Self {
65 Self {
66 params,
67 rng_state: seed,
68 probs_buf: Vec::new(),
69 }
70 }
71
72 /// Simple xorshift64 PRNG — no external dependency needed.
73 fn next_u64(&mut self) -> u64 {
74 let mut x = self.rng_state;
75 x ^= x << 13;
76 x ^= x >> 7;
77 x ^= x << 17;
78 self.rng_state = x;
79 x
80 }
81
82 /// Sample a token index from logits.
83 #[tracing::instrument(skip(self, logits), fields(vocab_size = logits.len()), level = "debug")]
84 pub fn sample(&mut self, logits: &[f32]) -> RuntimeResult<u32> {
85 if logits.is_empty() {
86 return Ok(0);
87 }
88
89 // Greedy if temperature is ~0
90 if self.params.temperature < 1e-6 {
91 return Ok(argmax(logits) as u32);
92 }
93
94 // Populate the reusable buffer with temperature-scaled logits.
95 // On the first call this allocates `vocab_size × 12` bytes; every
96 // subsequent call reuses the existing backing store (len is reset to 0
97 // by `clear()`, capacity is preserved from the previous call).
98 self.probs_buf.clear();
99 self.probs_buf.extend(
100 logits
101 .iter()
102 .enumerate()
103 .map(|(i, &v)| (i, v / self.params.temperature)),
104 );
105
106 // Top-k filtering — O(n) average via partial selection rather than O(n log n) full sort.
107 // `select_nth_unstable_by` rearranges `probs_buf` so that element at index `cutoff` is in
108 // its fully-sorted position, all elements before it are ≤ it (lower scaled logits), and all
109 // elements after it are ≥ it (higher scaled logits). Draining the prefix leaves exactly
110 // the top-k elements in arbitrary order, which is sufficient for softmax + sampling.
111 if self.params.top_k > 0 && self.params.top_k < self.probs_buf.len() {
112 let k = self.params.top_k;
113 let cutoff = self.probs_buf.len() - k;
114 self.probs_buf.select_nth_unstable_by(cutoff, |a, b| {
115 a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)
116 });
117 self.probs_buf.drain(..cutoff);
118 }
119
120 // Softmax
121 let max_val = self
122 .probs_buf
123 .iter()
124 .map(|(_, v)| *v)
125 .fold(f32::NEG_INFINITY, f32::max);
126 let mut sum = 0.0f32;
127 for (_, v) in self.probs_buf.iter_mut() {
128 *v = (*v - max_val).exp();
129 sum += *v;
130 }
131 for (_, v) in self.probs_buf.iter_mut() {
132 *v /= sum;
133 }
134
135 // Top-p filtering
136 if self.params.top_p < 1.0 {
137 self.probs_buf
138 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
139 let mut cum = 0.0f32;
140 let cutoff = self
141 .probs_buf
142 .iter()
143 .position(|&(_, p)| {
144 cum += p;
145 cum > self.params.top_p
146 })
147 .unwrap_or(self.probs_buf.len().saturating_sub(1));
148 self.probs_buf.truncate(cutoff + 1);
149
150 // Re-normalize
151 let sum: f32 = self.probs_buf.iter().map(|(_, p)| p).sum();
152 for (_, p) in self.probs_buf.iter_mut() {
153 *p /= sum;
154 }
155 }
156
157 // Pre-compute random value before the immutable borrow of `probs_buf`
158 // to satisfy the borrow checker: `next_u64` takes `&mut self` which
159 // would conflict with an active `&self.probs_buf` borrow.
160 let rand_val = (self.next_u64() as f64 / u64::MAX as f64) as f32;
161
162 // Weighted random selection
163 let mut cum = 0.0f32;
164 for &(idx, p) in &self.probs_buf {
165 cum += p;
166 if rand_val <= cum {
167 return Ok(idx as u32);
168 }
169 }
170
171 // Fallback: return the highest probability token
172 Ok(self.probs_buf[0].0 as u32)
173 }
174
175 /// Get current parameters.
176 pub fn params(&self) -> &SamplingParams {
177 &self.params
178 }
179}
180
/// Return the index of the maximum element (0 for an empty slice).
///
/// Ties and incomparable values (NaN compares as `Equal`) resolve toward the
/// later index, mirroring `Iterator::max_by`, which returns the last of
/// several equal maxima.
fn argmax(values: &[f32]) -> usize {
    let mut best: Option<(usize, f32)> = None;
    for (i, &v) in values.iter().enumerate() {
        best = match best {
            // Keep the current best only when it is strictly greater; on
            // Less/Equal the newer index wins, matching `max_by` tie rules.
            Some((_, bv)) if bv.partial_cmp(&v).unwrap_or(Ordering::Equal) == Ordering::Greater => {
                best
            }
            _ => Some((i, v)),
        };
    }
    best.map(|(i, _)| i).unwrap_or(0)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn greedy_sampling() {
        // Temperature 0 must short-circuit to a pure argmax.
        let mut sampler = Sampler::new(
            SamplingParams {
                temperature: 0.0,
                ..SamplingParams::default()
            },
            42,
        );
        let logits = vec![0.1, 0.5, 0.3, 0.9, 0.2];
        let token = sampler.sample(&logits).expect("sampling should succeed");
        assert_eq!(token, 3); // index of 0.9
    }

    #[test]
    fn sampling_returns_valid_index() {
        let mut sampler = Sampler::new(SamplingParams::default(), 12345);
        let uniform_logits = vec![0.0f32; 100];
        for _ in 0..50 {
            let token = sampler
                .sample(&uniform_logits)
                .expect("sampling should succeed");
            assert!(token < 100);
        }
    }

    #[test]
    fn argmax_basic() {
        assert_eq!(argmax(&[1.0, 3.0, 2.0]), 1);
        assert_eq!(argmax(&[5.0]), 0);
    }

    #[test]
    fn buffer_reuse_across_calls() {
        // probs_buf must not leak stale candidates between calls.
        let mut sampler = Sampler::new(
            SamplingParams {
                temperature: 0.7,
                top_k: 5,
                top_p: 1.0, // disable top-p so we control exactly
                repetition_penalty: 1.0,
                max_tokens: 128,
            },
            99,
        );
        let ascending: Vec<f32> = (0..200).map(|i| i as f32 * 0.01).collect();
        for _ in 0..20 {
            let token = sampler.sample(&ascending).expect("sampling should succeed");
            // With top_k = 5 on ascending logits, only indices 195..=199 survive.
            assert!(token >= 195, "expected token ≥ 195, got {token}");
        }
    }
}