1use rlx_ir::Philox4x32;
32
/// Sampling configuration consumed by [`sample_token`].
///
/// Build one with [`SampleOpts::greedy`] or [`SampleOpts::temperature`] and
/// refine it with the `with_*` builders.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SampleOpts {
    /// Divisor applied to every logit before softmax. `sample_token` only
    /// applies it when it is positive and not exactly `1.0`.
    pub temperature: f32,
    /// Top-k cutoff. `0` (or any value not smaller than the number of
    /// logits) leaves the distribution untouched.
    pub top_k: usize,
    /// Nucleus (top-p) threshold. Only applied when strictly between `0.0`
    /// and `1.0`.
    pub top_p: f32,
    /// Seed handed to the `Philox4x32` generator for the random draw.
    pub seed: u64,
    /// When `true`, `sample_token` returns the arg-max of the raw logits and
    /// ignores every other field.
    pub greedy: bool,
}

impl SampleOpts {
    /// Options for deterministic arg-max decoding.
    ///
    /// The remaining fields are set to neutral values (`temperature = 1.0`,
    /// `top_k = 0`, `top_p = 1.0`, `seed = 0`).
    #[must_use]
    pub fn greedy() -> Self {
        Self {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            seed: 0,
            greedy: true,
        }
    }

    /// Options for stochastic sampling at temperature `temp` with the given
    /// RNG `seed`; top-k and top-p filtering start out disabled.
    #[must_use]
    pub fn temperature(temp: f32, seed: u64) -> Self {
        Self {
            temperature: temp,
            top_k: 0,
            top_p: 1.0,
            seed,
            greedy: false,
        }
    }

    /// Returns a copy of `self` with the top-k cutoff set to `k`.
    #[must_use]
    pub fn with_top_k(mut self, k: usize) -> Self {
        self.top_k = k;
        self
    }

    /// Returns a copy of `self` with the nucleus threshold set to `p`.
    #[must_use]
    pub fn with_top_p(mut self, p: f32) -> Self {
        self.top_p = p;
        self
    }
}
81
82pub fn sample_token(logits: &[f32], opts: SampleOpts) -> usize {
87 assert!(!logits.is_empty(), "sample_token: empty logits");
88
89 if opts.greedy {
90 return argmax(logits);
91 }
92
93 let mut work: Vec<f32> = if opts.temperature > 0.0 && opts.temperature != 1.0 {
95 logits.iter().map(|&l| l / opts.temperature).collect()
96 } else {
97 logits.to_vec()
98 };
99
100 if opts.top_k > 0 && opts.top_k < work.len() {
102 let mut indexed: Vec<(usize, f32)> =
103 work.iter().enumerate().map(|(i, &v)| (i, v)).collect();
104 indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
106 let cutoff = indexed[opts.top_k - 1].1;
107 for v in work.iter_mut() {
108 if *v < cutoff {
109 *v = f32::NEG_INFINITY;
110 }
111 }
112 }
113
114 let max = work.iter().copied().fold(f32::NEG_INFINITY, f32::max);
116 let mut probs: Vec<f32> = work.iter().map(|&l| (l - max).exp()).collect();
117 let sum: f32 = probs.iter().sum();
118 if sum > 0.0 {
119 for p in probs.iter_mut() {
120 *p /= sum;
121 }
122 } else {
123 return argmax(logits);
125 }
126
127 if opts.top_p < 1.0 && opts.top_p > 0.0 {
129 let mut order: Vec<usize> = (0..probs.len()).collect();
130 order.sort_unstable_by(|&a, &b| {
131 probs[b]
132 .partial_cmp(&probs[a])
133 .unwrap_or(std::cmp::Ordering::Equal)
134 });
135 let mut cum = 0.0f32;
136 let mut keep = vec![false; probs.len()];
137 for &i in &order {
138 cum += probs[i];
139 keep[i] = true;
140 if cum >= opts.top_p {
141 break;
142 }
143 }
144 let mut renorm = 0.0f32;
145 for (i, p) in probs.iter_mut().enumerate() {
146 if !keep[i] {
147 *p = 0.0;
148 } else {
149 renorm += *p;
150 }
151 }
152 if renorm > 0.0 {
153 for p in probs.iter_mut() {
154 *p /= renorm;
155 }
156 }
157 }
158
159 let mut rng = Philox4x32::new(opts.seed);
161 let u = rng.next_f32();
162 let mut acc = 0.0f32;
163 for (i, &p) in probs.iter().enumerate() {
164 acc += p;
165 if u < acc {
166 return i;
167 }
168 }
169 probs.len() - 1
170}
171
/// Returns the index of the largest element of `xs`.
///
/// The first occurrence wins on ties. Elements that do not compare strictly
/// greater than the running best (which starts at `-inf`) are skipped, so NaN
/// entries are ignored and an empty, all-NaN or all-`-inf` slice yields `0`.
fn argmax(xs: &[f32]) -> usize {
    xs.iter()
        .enumerate()
        .fold(
            (0usize, f32::NEG_INFINITY),
            |(top_idx, top_val), (idx, &val)| {
                if val > top_val {
                    (idx, val)
                } else {
                    (top_idx, top_val)
                }
            },
        )
        .0
}
183
/// Converts `logits` into a probability vector with a max-shifted softmax.
///
/// Each logit has the maximum subtracted before exponentiation. The
/// exponentials are divided by their sum only when that sum is strictly
/// positive; otherwise (for example when it is NaN) they are returned as
/// computed. An empty input produces an empty vector.
pub fn softmax_logits(logits: &[f32]) -> Vec<f32> {
    let peak = logits
        .iter()
        .fold(f32::NEG_INFINITY, |hi, &logit| hi.max(logit));

    let mut weights = Vec::with_capacity(logits.len());
    weights.extend(logits.iter().map(|&logit| (logit - peak).exp()));

    let total = weights.iter().sum::<f32>();
    if total > 0.0 {
        weights.iter_mut().for_each(|w| *w /= total);
    }
    weights
}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Logits whose unique maximum sits at index 1.
    const PEAKED: [f32; 5] = [0.1, 0.5, 0.2, -1.0, 0.49];

    /// Greedy decoding must pick the largest logit.
    #[test]
    fn greedy_matches_argmax() {
        assert_eq!(sample_token(&PEAKED, SampleOpts::greedy()), 1);
    }

    /// With `top_k == 1` only the best token survives, whatever the seed.
    #[test]
    fn top_k_one_equals_greedy() {
        let only_best = SampleOpts::temperature(1.0, 42).with_top_k(1);
        assert_eq!(sample_token(&PEAKED, only_best), 1);
    }

    /// `top_p == 1.0` must still yield an in-range token id.
    #[test]
    fn top_p_full_equals_unrestricted_multinomial() {
        let logits = [1.0_f32, 2.0, 0.5, 0.0];
        let full_nucleus = SampleOpts::temperature(1.0, 7).with_top_p(1.0);
        let picked = sample_token(&logits, full_nucleus);
        assert!(picked < logits.len());
    }

    /// Identical options (including the seed) must give identical draws.
    #[test]
    fn deterministic_for_same_seed() {
        let logits: Vec<f32> = (0u8..32).map(|i| f32::from(i) * 0.01).collect();
        let opts = SampleOpts::temperature(0.7, 123).with_top_k(4);
        let first = sample_token(&logits, opts);
        let second = sample_token(&logits, opts);
        assert_eq!(first, second);
    }

    /// One dominant token carries almost all mass, so top-p keeps only it.
    #[test]
    fn top_p_truncates_low_mass() {
        let mut logits = [-10.0_f32; 16];
        logits[7] = 10.0;
        let nucleus = SampleOpts::temperature(1.0, 999).with_top_p(0.5);
        assert_eq!(sample_token(&logits, nucleus), 7);
    }

    /// A very high temperature must still return an in-range token id.
    #[test]
    fn high_temperature_still_returns_valid_id() {
        let logits = [0.0_f32; 10];
        let hot = SampleOpts::temperature(100.0, 1);
        let picked = sample_token(&logits, hot);
        assert!(picked < 10);
    }
}