1use crate::error::{SeqError, SeqResult};
26use crate::handle::LcgRng;
27
/// Settings for top-k sampling.
///
/// Both fields are checked before sampling: `k` must be at least 1 and
/// `temperature` must be a finite, strictly positive number.
#[derive(Clone, Copy, Debug)]
pub struct TopKConfig {
    /// Number of highest-scoring candidates kept before sampling. Values
    /// larger than the number of logits are capped to that number.
    pub k: usize,
    /// Divisor applied to every logit before the softmax is taken.
    pub temperature: f64,
}
37
38impl Default for TopKConfig {
39 fn default() -> Self {
40 Self {
41 k: 50,
42 temperature: 1.0,
43 }
44 }
45}
46
47impl TopKConfig {
48 fn validate(&self) -> SeqResult<()> {
49 if self.k == 0 {
50 return Err(SeqError::InvalidConfiguration(
51 "top-k: k must be >= 1".to_string(),
52 ));
53 }
54 if !self.temperature.is_finite() || self.temperature <= 0.0 {
55 return Err(SeqError::InvalidParameter {
56 name: "temperature".to_string(),
57 value: self.temperature,
58 });
59 }
60 Ok(())
61 }
62}
63
64pub fn top_k_sample(logits: &[f64], cfg: &TopKConfig, rng: &mut LcgRng) -> SeqResult<usize> {
72 cfg.validate()?;
73 if logits.is_empty() {
74 return Err(SeqError::EmptyInput);
75 }
76 let v = logits.len();
77 let k_eff = cfg.k.min(v);
78
79 let scaled: Vec<f64> = logits.iter().map(|&z| z / cfg.temperature).collect();
81
82 if k_eff == 1 {
84 return Ok(argmax(&scaled));
85 }
86
87 let indices = top_k_indices(&scaled, k_eff);
89
90 let max_z = indices
92 .iter()
93 .map(|&i| scaled[i])
94 .fold(f64::NEG_INFINITY, f64::max);
95 let mut probs = vec![0.0_f64; k_eff];
96 let mut sum = 0.0_f64;
97 for (slot, &i) in indices.iter().enumerate() {
98 let w = (scaled[i] - max_z).exp();
99 probs[slot] = w;
100 sum += w;
101 }
102 if !sum.is_finite() || sum <= 0.0 {
103 return Err(SeqError::NumericalInstability(
104 "top-k: softmax denominator non-positive".to_string(),
105 ));
106 }
107 for p in probs.iter_mut() {
108 *p /= sum;
109 }
110
111 let chosen_slot = rng.sample_categorical(&probs);
112 Ok(indices[chosen_slot])
113}
114
115pub fn top_k_sample_batch(
123 logits: &[f64],
124 n: usize,
125 vocab: usize,
126 cfg: &TopKConfig,
127 rng: &mut LcgRng,
128) -> SeqResult<Vec<usize>> {
129 cfg.validate()?;
130 if logits.is_empty() || n == 0 || vocab == 0 {
131 return Err(SeqError::EmptyInput);
132 }
133 if logits.len() != n * vocab {
134 return Err(SeqError::ShapeMismatch {
135 expected: n * vocab,
136 got: logits.len(),
137 });
138 }
139 let mut out = Vec::with_capacity(n);
140 for b in 0..n {
141 let row = &logits[b * vocab..(b + 1) * vocab];
142 out.push(top_k_sample(row, cfg, rng)?);
143 }
144 Ok(out)
145}
146
/// Returns the index of the largest value in `xs`.
///
/// Ties go to the lowest index. NaN is treated as smaller than every number,
/// including negative infinity, so a NaN is only selected when every element
/// is NaN, in which case index 0 is returned.
///
/// # Panics
///
/// Panics if `xs` is empty.
#[inline]
fn argmax(xs: &[f64]) -> usize {
    let mut best = 0usize;
    let mut best_v = xs[0];
    for (i, &v) in xs.iter().enumerate().skip(1) {
        // `v > best_v` is false whenever either side is NaN, so a NaN sitting
        // in `best_v` (only possible while every element so far was NaN) must
        // be displaced explicitly by the first real number.
        if v > best_v || (best_v.is_nan() && !v.is_nan()) {
            best_v = v;
            best = i;
        }
    }
    best
}
160
/// Returns the positions of the `k` largest values of `xs`, largest first.
///
/// The sort is stable, so equal values keep ascending index order. NaN ranks
/// below every number (including negative infinity) and NaNs compare equal to
/// each other, so they only appear in the result once all real values have
/// been taken. If `k` exceeds `xs.len()`, every index is returned.
fn top_k_indices(xs: &[f64], k: usize) -> Vec<usize> {
    use std::cmp::Ordering;

    let mut idx: Vec<usize> = (0..xs.len()).collect();
    idx.sort_by(|&a, &b| {
        let (x, y) = (xs[a], xs[b]);
        // Treating NaN as `Equal` to everything is not a total order: the
        // sort may then return a wrong top-k set or panic. Ranking NaN
        // strictly last keeps the comparator consistent.
        match (x.is_nan(), y.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Descending order: the larger value sorts first. Neither side is
            // NaN here, so `partial_cmp` always yields `Some`.
            (false, false) => y.partial_cmp(&x).unwrap_or(Ordering::Equal),
        }
    });
    idx.truncate(k);
    idx
}
177
// Unit tests for the top-k sampler: configuration and input validation,
// deterministic behaviour under a fixed seed, and statistical checks on the
// sampled distribution.
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference softmax over all of `logits` at temperature `t`, shifted by
    /// the maximum scaled value before exponentiating.
    fn full_softmax(logits: &[f64], t: f64) -> Vec<f64> {
        let scaled: Vec<f64> = logits.iter().map(|&z| z / t).collect();
        let m = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = scaled.iter().map(|&z| (z - m).exp()).collect();
        let s: f64 = exps.iter().sum();
        exps.iter().map(|&e| e / s).collect()
    }

    /// `k == 0` must be reported as an invalid configuration.
    #[test]
    fn k_zero_rejected() {
        let cfg = TopKConfig {
            k: 0,
            temperature: 1.0,
        };
        let mut rng = LcgRng::new(0);
        let err = top_k_sample(&[0.1, 0.2], &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));
    }

    /// Zero, negative and NaN temperatures must all be rejected.
    #[test]
    fn nonpositive_temperature_rejected() {
        let mut rng = LcgRng::new(0);
        for t in [0.0_f64, -1.0, f64::NAN] {
            let cfg = TopKConfig {
                k: 2,
                temperature: t,
            };
            let err = top_k_sample(&[0.1, 0.2], &cfg, &mut rng).unwrap_err();
            assert!(matches!(err, SeqError::InvalidParameter { .. }));
        }
    }

    /// An empty logit slice is an error, not a panic.
    #[test]
    fn empty_logits_rejected() {
        let cfg = TopKConfig::default();
        let mut rng = LcgRng::new(0);
        let err = top_k_sample(&[], &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    /// With `k` larger than the vocabulary nothing is truncated, so three
    /// equal logits should be sampled roughly uniformly.
    #[test]
    fn k_at_least_vocab_full_softmax() {
        let logits = vec![0.0, 0.0, 0.0];
        let cfg = TopKConfig {
            k: 10,
            temperature: 1.0,
        };
        let mut rng = LcgRng::new(42);
        let mut counts = [0usize; 3];
        for _ in 0..3000 {
            let tok = top_k_sample(&logits, &cfg, &mut rng).expect("sample ok");
            counts[tok] += 1;
        }
        // Uniform expectation is 1000 per token; 700 leaves a wide margin.
        for c in counts {
            assert!(
                c > 700,
                "every token should be sampled: counts = {counts:?}"
            );
        }
    }

    /// `k == 1` must return the arg-max no matter how the rng is seeded.
    #[test]
    fn k_one_is_argmax() {
        // Index 3 holds the next representable value below 4.5, a near-tie
        // with the true maximum at index 1.
        let logits = vec![-1.0, 4.5, 2.0, 4.5_f64.next_down()];
        let cfg = TopKConfig {
            k: 1,
            temperature: 0.7,
        };
        let mut rng_a = LcgRng::new(0);
        let mut rng_b = LcgRng::new(999_999);
        let tok_a = top_k_sample(&logits, &cfg, &mut rng_a).expect("sample ok");
        let tok_b = top_k_sample(&logits, &cfg, &mut rng_b).expect("sample ok");
        assert_eq!(tok_a, 1);
        assert_eq!(tok_b, 1, "k=1 must be deterministic regardless of rng");
    }

    /// Two generators created with the same seed must produce identical
    /// sample sequences.
    #[test]
    fn deterministic_with_seed() {
        let logits = vec![0.5, 1.2, -0.3, 0.8, 2.1];
        let cfg = TopKConfig {
            k: 3,
            temperature: 1.0,
        };
        let mut rng_a = LcgRng::new(123);
        let mut rng_b = LcgRng::new(123);
        for _ in 0..200 {
            let a = top_k_sample(&logits, &cfg, &mut rng_a).expect("ok");
            let b = top_k_sample(&logits, &cfg, &mut rng_b).expect("ok");
            assert_eq!(a, b);
        }
    }

    /// Sampled frequencies should match a softmax renormalised over the kept
    /// tokens only.
    #[test]
    fn distribution_matches_renormalised_softmax() {
        // Logits are in descending order, so the top 3 are indices 0..3.
        let logits = vec![3.0_f64, 1.0, 0.0, -2.0, -5.0];
        let cfg = TopKConfig {
            k: 3,
            temperature: 1.0,
        };
        // Expected distribution: softmax restricted to the three kept logits.
        let full = full_softmax(&logits[..3], 1.0);
        let n_samples = 6000usize;
        let mut rng = LcgRng::new(7);
        let mut counts = [0usize; 3];
        for _ in 0..n_samples {
            let t = top_k_sample(&logits, &cfg, &mut rng).expect("ok");
            assert!(t < 3, "top-k must never pick a truncated index");
            counts[t] += 1;
        }
        // Pearson chi-square statistic of observed against expected counts.
        let mut chi2 = 0.0_f64;
        for i in 0..3 {
            let expected = full[i] * n_samples as f64;
            let diff = counts[i] as f64 - expected;
            chi2 += diff * diff / expected;
        }
        // 9.21 matches the chi-square critical value for 2 degrees of freedom
        // (3 categories) at the 1% significance level.
        assert!(chi2 < 9.21, "chi-square = {chi2}");
    }

    /// Each batch row is sampled independently from its own slice.
    #[test]
    fn batch_correctness() {
        // Two rows of three logits. Each row has one logit 20 above the
        // rest, so that token dominates the row's distribution.
        let logits = vec![10.0, -10.0, -10.0, -10.0, 10.0, -10.0];
        let cfg = TopKConfig {
            k: 2,
            temperature: 1.0,
        };
        let mut rng = LcgRng::new(0);
        let out = top_k_sample_batch(&logits, 2, 3, &cfg, &mut rng).expect("ok");
        assert_eq!(out, vec![0, 1]);
    }

    /// Zero rows or a zero-sized vocabulary is reported as empty input.
    #[test]
    fn batch_empty_rejected() {
        let cfg = TopKConfig::default();
        let mut rng = LcgRng::new(0);
        assert!(matches!(
            top_k_sample_batch(&[], 0, 3, &cfg, &mut rng).unwrap_err(),
            SeqError::EmptyInput
        ));
        assert!(matches!(
            top_k_sample_batch(&[0.0, 0.0], 1, 0, &cfg, &mut rng).unwrap_err(),
            SeqError::EmptyInput
        ));
    }

    /// A buffer whose length is not `n * vocab` (5 instead of 6) is rejected.
    #[test]
    fn batch_shape_mismatch_rejected() {
        let logits = vec![0.0_f64; 5];
        let cfg = TopKConfig::default();
        let mut rng = LcgRng::new(0);
        let err = top_k_sample_batch(&logits, 2, 3, &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::ShapeMismatch { .. }));
    }

    /// A large temperature shrinks logit differences, so every token should
    /// still be sampled often.
    #[test]
    fn high_temperature_flattens() {
        // After dividing by 50 the scaled logits are 0.1, 0, 0, 0.
        let logits = vec![5.0, 0.0, 0.0, 0.0];
        let cfg = TopKConfig {
            k: 4,
            temperature: 50.0,
        };
        let mut rng = LcgRng::new(1);
        let mut counts = [0usize; 4];
        for _ in 0..4000 {
            counts[top_k_sample(&logits, &cfg, &mut rng).expect("ok")] += 1;
        }
        // A perfectly uniform draw would give 1000 per token.
        for c in counts {
            assert!(c > 700);
        }
    }

    /// A small temperature widens logit differences, so the arg-max should
    /// be chosen almost every time.
    #[test]
    fn low_temperature_sharpens() {
        // After dividing by 0.05 the scaled logits are 60, 20, 0, -20.
        let logits = vec![3.0, 1.0, 0.0, -1.0];
        let cfg = TopKConfig {
            k: 4,
            temperature: 0.05,
        };
        let mut rng = LcgRng::new(0);
        let mut argmax_count = 0usize;
        for _ in 0..1000 {
            if top_k_sample(&logits, &cfg, &mut rng).expect("ok") == 0 {
                argmax_count += 1;
            }
        }
        assert!(argmax_count > 980);
    }

    /// Tokens outside the top `k` must never be returned.
    #[test]
    fn top_k_never_picks_truncated_token() {
        // With k = 2 only indices 0 and 1 survive truncation.
        let logits = vec![5.0, 4.5, -3.0, -4.0, -10.0];
        let cfg = TopKConfig {
            k: 2,
            temperature: 1.0,
        };
        let mut rng = LcgRng::new(42);
        for _ in 0..500 {
            let t = top_k_sample(&logits, &cfg, &mut rng).expect("ok");
            assert!(t == 0 || t == 1, "got truncated token {t}");
        }
    }

    /// A one-token vocabulary always yields index 0, even when `k` is larger.
    #[test]
    fn single_vocab_returns_zero() {
        let logits = vec![2.71_f64];
        let cfg = TopKConfig {
            k: 5,
            temperature: 1.0,
        };
        let mut rng = LcgRng::new(0);
        for _ in 0..10 {
            assert_eq!(top_k_sample(&logits, &cfg, &mut rng).expect("ok"), 0);
        }
    }
}