Skip to main content

rlx_text/
sampling.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Sampling helpers for LM runners.
17//!
18//! `SampleOpts` itself lives in `rlx-runtime::lm` so backend code can
19//! reference it without depending on `rlx-text`. This module adds the
20//! actual sampling routines that consume those options.
21
22pub use rlx_runtime::SampleOpts;
23
/// Index of the largest value in `logits` (greedy decoding).
///
/// The scan starts from index 0 with a running maximum of negative infinity
/// and only moves on a strictly greater value. Consequently:
/// - ties resolve to the earliest index,
/// - NaN entries never win (every comparison against NaN is false),
/// - an empty slice, or one holding only negative infinity / NaN, yields 0.
///
/// The winning index is narrowed to `u32` with an `as` cast.
pub fn argmax(logits: &[f32]) -> u32 {
    let (winner, _) = logits.iter().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(top_idx, top_val), (idx, &val)| {
            if val > top_val {
                (idx, val)
            } else {
                (top_idx, top_val)
            }
        },
    );
    winner as u32
}
36
37/// Sample one token from `logits` according to `opts`.
38///
39/// Greedy when `opts.temperature == 0`. Otherwise applies temperature,
40/// optional top-k, top-p (nucleus), and a repetition penalty against the
41/// `history` ids (set empty to disable).
42///
43/// `rng` is a 32-bit Lehmer LCG seed; pass any non-zero value. Returns
44/// the new seed so callers can chain.
45pub fn sample_next(logits: &[f32], history: &[u32], opts: &SampleOpts, rng: &mut u32) -> u32 {
46    if opts.is_greedy() {
47        return argmax(logits);
48    }
49    let mut probs: Vec<f32> = logits.to_vec();
50
51    // Repetition penalty (greater than 1 discourages repeats).
52    if (opts.repetition_penalty - 1.0).abs() > 1e-6 {
53        for &id in history {
54            let i = id as usize;
55            if i < probs.len() {
56                let p = probs[i];
57                probs[i] = if p > 0.0 {
58                    p / opts.repetition_penalty
59                } else {
60                    p * opts.repetition_penalty
61                };
62            }
63        }
64    }
65
66    // Temperature.
67    if opts.temperature > 0.0 && (opts.temperature - 1.0).abs() > 1e-6 {
68        let inv = 1.0 / opts.temperature;
69        for v in probs.iter_mut() {
70            *v *= inv;
71        }
72    }
73
74    // Softmax in-place.
75    let max = probs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
76    let mut sum = 0f32;
77    for v in probs.iter_mut() {
78        *v = (*v - max).exp();
79        sum += *v;
80    }
81    if sum <= 0.0 {
82        return argmax(logits);
83    }
84    for v in probs.iter_mut() {
85        *v /= sum;
86    }
87
88    // Sort indices by probability desc.
89    let mut order: Vec<usize> = (0..probs.len()).collect();
90    order.sort_by(|a, b| {
91        probs[*b]
92            .partial_cmp(&probs[*a])
93            .unwrap_or(std::cmp::Ordering::Equal)
94    });
95
96    // Top-k.
97    let k = opts.top_k.unwrap_or(order.len() as u32) as usize;
98    let k = k.min(order.len());
99
100    // Top-p.
101    let mut acc = 0f32;
102    let mut keep = 0usize;
103    for &idx in order.iter().take(k) {
104        acc += probs[idx];
105        keep += 1;
106        if acc >= opts.top_p {
107            break;
108        }
109    }
110    let keep = keep.max(1);
111
112    // Renormalise.
113    let mut total = 0f32;
114    for &idx in order.iter().take(keep) {
115        total += probs[idx];
116    }
117    if total <= 0.0 {
118        return order[0] as u32;
119    }
120
121    // Lehmer LCG: state * 48271 mod 2^31-1.
122    *rng = ((*rng as u64 * 48271) % 0x7FFFFFFF) as u32;
123    let r = (*rng as f32) / (0x7FFFFFFF as f32);
124    let target = r * total;
125    let mut acc = 0f32;
126    for &idx in order.iter().take(keep) {
127        acc += probs[idx];
128        if acc >= target {
129            return idx as u32;
130        }
131    }
132    order[keep - 1] as u32
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// The largest logit sits at index 1, so `argmax` must report 1.
    #[test]
    fn argmax_picks_largest() {
        let logits: [f32; 4] = [1.0, 5.0, 3.0, -1.0];
        assert_eq!(argmax(&logits), 1);
    }

    /// Greedy options must return the argmax regardless of the seed.
    #[test]
    fn greedy_short_circuit() {
        let opts = SampleOpts::greedy();
        let logits = vec![1.0, 9.0, 3.0];
        let mut rng = 1u32;
        let picked = sample_next(&logits, &[], &opts, &mut rng);
        assert_eq!(picked, 1);
    }

    /// With temperature 2.0 and no nucleus cut-off, fifty chained draws
    /// should land on at least two different token ids.
    #[test]
    fn high_temperature_widens_distribution() {
        let opts = SampleOpts::nucleus(2.0, 1.0);
        let logits = vec![0.0, 1.0, 0.0];
        let mut rng = 42u32;
        let distinct: HashSet<u32> = (0..50)
            .map(|_| sample_next(&logits, &[], &opts, &mut rng))
            .collect();
        assert!(distinct.len() >= 2);
    }
}