1pub use rlx_runtime::SampleOpts;
23
/// Returns the index of the largest value in `logits`.
///
/// When several entries share the maximum, the earliest one wins. NaN entries
/// never compare greater than anything, so they are skipped. If no entry
/// exceeds `f32::NEG_INFINITY` (empty slice, all `-inf`, all NaN) the result
/// is `0`.
pub fn argmax(logits: &[f32]) -> u32 {
    let (winner, _) = logits.iter().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(cur_idx, cur_val), (idx, &val)| {
            if val > cur_val {
                (idx, val)
            } else {
                (cur_idx, cur_val)
            }
        },
    );
    winner as u32
}
36
37pub fn sample_next(logits: &[f32], history: &[u32], opts: &SampleOpts, rng: &mut u32) -> u32 {
46 if opts.is_greedy() {
47 return argmax(logits);
48 }
49 let mut probs: Vec<f32> = logits.to_vec();
50
51 if (opts.repetition_penalty - 1.0).abs() > 1e-6 {
53 for &id in history {
54 let i = id as usize;
55 if i < probs.len() {
56 let p = probs[i];
57 probs[i] = if p > 0.0 {
58 p / opts.repetition_penalty
59 } else {
60 p * opts.repetition_penalty
61 };
62 }
63 }
64 }
65
66 if opts.temperature > 0.0 && (opts.temperature - 1.0).abs() > 1e-6 {
68 let inv = 1.0 / opts.temperature;
69 for v in probs.iter_mut() {
70 *v *= inv;
71 }
72 }
73
74 let max = probs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
76 let mut sum = 0f32;
77 for v in probs.iter_mut() {
78 *v = (*v - max).exp();
79 sum += *v;
80 }
81 if sum <= 0.0 {
82 return argmax(logits);
83 }
84 for v in probs.iter_mut() {
85 *v /= sum;
86 }
87
88 let mut order: Vec<usize> = (0..probs.len()).collect();
90 order.sort_by(|a, b| {
91 probs[*b]
92 .partial_cmp(&probs[*a])
93 .unwrap_or(std::cmp::Ordering::Equal)
94 });
95
96 let k = opts.top_k.unwrap_or(order.len() as u32) as usize;
98 let k = k.min(order.len());
99
100 let mut acc = 0f32;
102 let mut keep = 0usize;
103 for &idx in order.iter().take(k) {
104 acc += probs[idx];
105 keep += 1;
106 if acc >= opts.top_p {
107 break;
108 }
109 }
110 let keep = keep.max(1);
111
112 let mut total = 0f32;
114 for &idx in order.iter().take(keep) {
115 total += probs[idx];
116 }
117 if total <= 0.0 {
118 return order[0] as u32;
119 }
120
121 *rng = ((*rng as u64 * 48271) % 0x7FFFFFFF) as u32;
123 let r = (*rng as f32) / (0x7FFFFFFF as f32);
124 let target = r * total;
125 let mut acc = 0f32;
126 for &idx in order.iter().take(keep) {
127 acc += probs[idx];
128 if acc >= target {
129 return idx as u32;
130 }
131 }
132 order[keep - 1] as u32
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn argmax_picks_largest() {
        assert_eq!(argmax(&[1.0, 5.0, 3.0, -1.0]), 1);
    }

    #[test]
    fn argmax_prefers_first_of_tied_maxima() {
        assert_eq!(argmax(&[2.0, 7.0, 7.0, 1.0]), 1);
    }

    #[test]
    fn argmax_skips_nan() {
        assert_eq!(argmax(&[f32::NAN, 0.5, f32::NAN]), 1);
    }

    #[test]
    fn greedy_short_circuit() {
        let mut rng = 1u32;
        let logits = vec![1.0, 9.0, 3.0];
        let opts = SampleOpts::greedy();
        assert_eq!(sample_next(&logits, &[], &opts, &mut rng), 1);
    }

    #[test]
    fn high_temperature_widens_distribution() {
        let logits = vec![0.0, 1.0, 0.0];
        let opts = SampleOpts::nucleus(2.0, 1.0);
        let mut rng = 42u32;
        let mut seen = std::collections::HashSet::new();
        for _ in 0..50 {
            seen.insert(sample_next(&logits, &[], &opts, &mut rng));
        }
        assert!(seen.len() >= 2);
    }

    #[test]
    fn same_seed_reproduces_the_same_samples() {
        let logits = vec![0.5, 1.0, 0.25, 0.75];
        let opts = SampleOpts::nucleus(1.0, 0.9);
        let draw = |seed: u32| -> Vec<u32> {
            let mut rng = seed;
            (0..20)
                .map(|_| sample_next(&logits, &[], &opts, &mut rng))
                .collect()
        };
        assert_eq!(draw(7), draw(7));
    }

    #[test]
    fn samples_are_valid_token_ids() {
        let logits = vec![0.1, -2.0, 3.0, 0.0, 1.5];
        let opts = SampleOpts::nucleus(1.5, 0.95);
        let mut rng = 12_345u32;
        for _ in 0..50 {
            let id = sample_next(&logits, &[], &opts, &mut rng);
            assert!((id as usize) < logits.len());
        }
    }
}