Skip to main content

rlx_qwen3/
sampling.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Host-side logits sampler.
//!
//! Operates on a single `[vocab]` slice — the caller is responsible for
//! pulling the last position's row out of `[B, S, vocab]` logits.
//!
//! Sampling is a host-side step (not a graph op) for now because:
//!   - The decision tree (temperature → top-k → top-p → multinomial)
//!     is branchy and cheap; there is no win from baking it into the graph.
//!   - Keeping it out of the graph lets a downstream `Speculator`
//!     impl call the same sampler for both the draft and the
//!     verifier without graph surgery.
//!
//! Determinism: backed by `rlx_ir::Philox4x32`, the same RNG already used
//! by `rlx-runtime/src/spec_decode.rs`. The generator is re-seeded from
//! `SampleOpts::seed` on every call, so the same seed and logits always
//! yield the same token; vary the seed per step to get independent draws.
30
31use rlx_ir::Philox4x32;
32
/// Sampling configuration. Construct via [`SampleOpts::greedy`] /
/// [`SampleOpts::temperature`] or build manually.
///
/// Order of operations matches HF `transformers` defaults:
///   1. `temperature` divides logits (skipped if `<= 0` or `1.0`).
///   2. `top_k` truncates to the K highest-logit tokens (0 = disabled).
///   3. `top_p` truncates by nucleus cumulative-mass cutoff (1.0 = disabled).
///   4. Softmax + multinomial sample (or argmax when greedy).
///
/// The type is `Copy` and compares field-by-field through `PartialEq`, so
/// configurations can be passed by value and asserted on in tests.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SampleOpts {
    /// Divisor applied to every logit. [`sample_token`] only applies it when
    /// it is `> 0` and `!= 1.0`; any other value leaves the logits unscaled
    /// (a non-positive temperature does NOT switch to greedy decoding).
    pub temperature: f32,
    /// Number of highest-logit tokens to keep. `0`, or a value `>=` the
    /// vocabulary size, disables the filter.
    pub top_k: usize,
    /// Nucleus cumulative-probability threshold. Only values strictly
    /// between `0.0` and `1.0` enable the filter.
    pub top_p: f32,
    /// Seed handed to the RNG on every [`sample_token`] call.
    pub seed: u64,
    /// When `true`, [`sample_token`] returns the argmax of the raw logits and
    /// ignores every other field.
    pub greedy: bool,
}
49
50impl SampleOpts {
51    pub fn greedy() -> Self {
52        Self {
53            temperature: 1.0,
54            top_k: 0,
55            top_p: 1.0,
56            seed: 0,
57            greedy: true,
58        }
59    }
60
61    pub fn temperature(temp: f32, seed: u64) -> Self {
62        Self {
63            temperature: temp,
64            top_k: 0,
65            top_p: 1.0,
66            seed,
67            greedy: false,
68        }
69    }
70
71    pub fn with_top_k(mut self, k: usize) -> Self {
72        self.top_k = k;
73        self
74    }
75
76    pub fn with_top_p(mut self, p: f32) -> Self {
77        self.top_p = p;
78        self
79    }
80}
81
82/// Sample one token id from a `[vocab]` logits slice. Returns the
83/// chosen index. Stateless w.r.t. prior calls — the RNG is seeded
84/// per-call from `opts.seed` so repeated calls with the same seed
85/// and logits yield the same token.
86pub fn sample_token(logits: &[f32], opts: SampleOpts) -> usize {
87    assert!(!logits.is_empty(), "sample_token: empty logits");
88
89    if opts.greedy {
90        return argmax(logits);
91    }
92
93    // 1. temperature: divide logits, in place on a working copy.
94    let mut work: Vec<f32> = if opts.temperature > 0.0 && opts.temperature != 1.0 {
95        logits.iter().map(|&l| l / opts.temperature).collect()
96    } else {
97        logits.to_vec()
98    };
99
100    // 2. top_k: mask everything outside the K highest logits.
101    if opts.top_k > 0 && opts.top_k < work.len() {
102        let mut indexed: Vec<(usize, f32)> =
103            work.iter().enumerate().map(|(i, &v)| (i, v)).collect();
104        // Partial sort: nth_element-style, descending.
105        indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
106        let cutoff = indexed[opts.top_k - 1].1;
107        for v in work.iter_mut() {
108            if *v < cutoff {
109                *v = f32::NEG_INFINITY;
110            }
111        }
112    }
113
114    // 3. softmax (numerically stable).
115    let max = work.iter().copied().fold(f32::NEG_INFINITY, f32::max);
116    let mut probs: Vec<f32> = work.iter().map(|&l| (l - max).exp()).collect();
117    let sum: f32 = probs.iter().sum();
118    if sum > 0.0 {
119        for p in probs.iter_mut() {
120            *p /= sum;
121        }
122    } else {
123        // All -inf (shouldn't happen post-softmax): fall back to greedy.
124        return argmax(logits);
125    }
126
127    // 4. top_p: nucleus cutoff over sorted-descending probability.
128    if opts.top_p < 1.0 && opts.top_p > 0.0 {
129        let mut order: Vec<usize> = (0..probs.len()).collect();
130        order.sort_unstable_by(|&a, &b| {
131            probs[b]
132                .partial_cmp(&probs[a])
133                .unwrap_or(std::cmp::Ordering::Equal)
134        });
135        let mut cum = 0.0f32;
136        let mut keep = vec![false; probs.len()];
137        for &i in &order {
138            cum += probs[i];
139            keep[i] = true;
140            if cum >= opts.top_p {
141                break;
142            }
143        }
144        let mut renorm = 0.0f32;
145        for (i, p) in probs.iter_mut().enumerate() {
146            if !keep[i] {
147                *p = 0.0;
148            } else {
149                renorm += *p;
150            }
151        }
152        if renorm > 0.0 {
153            for p in probs.iter_mut() {
154                *p /= renorm;
155            }
156        }
157    }
158
159    // 5. multinomial sample.
160    let mut rng = Philox4x32::new(opts.seed);
161    let u = rng.next_f32();
162    let mut acc = 0.0f32;
163    for (i, &p) in probs.iter().enumerate() {
164        acc += p;
165        if u < acc {
166            return i;
167        }
168    }
169    probs.len() - 1
170}
171
/// Index of the largest value in `xs`.
///
/// The scan starts from `(index 0, -inf)` and only moves on a strictly
/// greater value, so ties resolve to the lowest index, NaNs are never
/// selected, and an empty or all-`-inf` slice yields `0`.
fn argmax(xs: &[f32]) -> usize {
    let (winner, _) = xs.iter().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(top_idx, top_val), (idx, &val)| {
            if val > top_val {
                (idx, val)
            } else {
                (top_idx, top_val)
            }
        },
    );
    winner
}
183
/// Numerically-stable softmax over a logits row. Exposed so
/// `Speculator` implementations can hand the resulting probability
/// vector to `rlx-runtime::spec_decode` without re-implementing it.
///
/// Each logit is shifted by the row maximum before exponentiation, then the
/// result is divided by its sum.
///
/// Degenerate rows:
///   - an empty slice returns an empty vector;
///   - if the maximum is `+inf`, the mass is split evenly among the `+inf`
///     entries and all finite entries get `0`;
///   - if every entry is `-inf`, the result is uniform;
///   - NaN logits are not sanitised: they make the sum NaN, and the
///     unnormalised, NaN-bearing vector is returned as-is.
pub fn softmax_logits(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut p: Vec<f32> = logits
        .iter()
        .map(|&l| {
            // An entry equal to the maximum contributes exp(0) = 1. Spelling
            // this out avoids `inf - inf = NaN` when the maximum is infinite;
            // for a finite maximum it is identical to `(l - max).exp()`.
            if l == max {
                1.0
            } else {
                (l - max).exp()
            }
        })
        .collect();
    let sum: f32 = p.iter().sum();
    if sum > 0.0 {
        for v in p.iter_mut() {
            *v /= sum;
        }
    }
    p
}
198
// Unit tests for the sampler: greedy/argmax agreement, top-k and top-p
// truncation, seed determinism, and basic `softmax_logits` properties.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn greedy_matches_argmax() {
        let logits = vec![0.1, 0.5, 0.2, -1.0, 0.49];
        let t = sample_token(&logits, SampleOpts::greedy());
        assert_eq!(t, 1);
    }

    #[test]
    fn greedy_breaks_ties_toward_lowest_index() {
        let logits = vec![1.0, 3.0, 3.0];
        assert_eq!(sample_token(&logits, SampleOpts::greedy()), 1);
    }

    #[test]
    fn top_k_one_equals_greedy() {
        let logits = vec![0.1, 0.5, 0.2, -1.0, 0.49];
        let opts = SampleOpts::temperature(1.0, 42).with_top_k(1);
        assert_eq!(sample_token(&logits, opts), 1);
    }

    #[test]
    fn top_p_full_equals_unrestricted_multinomial() {
        // With top_p=1.0 the nucleus mask is a no-op, so the draw must be
        // a valid id and match the one made without any top-p setting.
        let logits = vec![1.0, 2.0, 0.5, 0.0];
        let base = SampleOpts::temperature(1.0, 7);
        let t = sample_token(&logits, base.with_top_p(1.0));
        assert!(t < logits.len());
        assert_eq!(t, sample_token(&logits, base));
    }

    #[test]
    fn deterministic_for_same_seed() {
        let logits: Vec<f32> = (0..32).map(|i| (i as f32) * 0.01).collect();
        let opts = SampleOpts::temperature(0.7, 123).with_top_k(4);
        let a = sample_token(&logits, opts);
        let b = sample_token(&logits, opts);
        assert_eq!(a, b);
    }

    #[test]
    fn top_p_truncates_low_mass() {
        // One token has nearly all the mass; top_p=0.5 should keep
        // only that token and pick it regardless of RNG.
        let mut logits = vec![-10.0f32; 16];
        logits[7] = 10.0;
        let opts = SampleOpts::temperature(1.0, 999).with_top_p(0.5);
        assert_eq!(sample_token(&logits, opts), 7);
    }

    #[test]
    fn high_temperature_still_returns_valid_id() {
        let logits = vec![0.0; 10];
        let opts = SampleOpts::temperature(100.0, 1);
        let t = sample_token(&logits, opts);
        assert!(t < 10);
    }

    #[test]
    fn softmax_logits_is_normalised_and_order_preserving() {
        let p = softmax_logits(&[1.0, 2.0, 3.0]);
        let total: f32 = p.iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
        assert!(p[0] < p[1] && p[1] < p[2]);
    }

    #[test]
    fn softmax_logits_is_shift_invariant() {
        // Adding a constant to every logit must not change the distribution.
        let a = softmax_logits(&[1.0, 2.0, 3.0]);
        let b = softmax_logits(&[101.0, 102.0, 103.0]);
        for (x, y) in a.iter().zip(&b) {
            assert!((x - y).abs() < 1e-6);
        }
    }
}
253}