// provable_contracts/kernels/sampling.rs
/// Greedy decoding: returns the index of the largest logit.
///
/// A later element replaces the current best only when it is strictly
/// greater, so ties resolve to the lowest index. Because no comparison
/// involving `NaN` is ever "greater", a `NaN` can only be selected when it
/// is the very first element.
///
/// # Panics
///
/// Panics if `logits` is empty.
pub fn greedy_scalar(logits: &[f32]) -> usize {
    assert!(!logits.is_empty(), "logits must not be empty");
    let (winner, _) = logits
        .iter()
        .copied()
        .enumerate()
        .skip(1)
        .fold((0usize, logits[0]), |(top_idx, top_val), (idx, val)| {
            if val > top_val {
                (idx, val)
            } else {
                (top_idx, top_val)
            }
        });
    winner
}

/// Applies temperature scaling in place by dividing every logit by
/// `temperature`.
///
/// # Panics
///
/// Panics unless `temperature > 0.0`; zero, negative values and `NaN` are all
/// rejected by that comparison.
pub fn temperature_scalar(logits: &mut [f32], temperature: f32) {
    assert!(
        temperature > 0.0,
        "temperature must be positive, got {temperature}"
    );
    logits.iter_mut().for_each(|logit| *logit /= temperature);
}

/// Top-k filtering: keeps the `k` largest entries of `probs`, zeroes every
/// other entry, and rescales the survivors so they sum to 1.
///
/// * Ties are broken in favour of the lower index (the sort is stable).
/// * `NaN` entries rank below every real value, so they are the first to be
///   dropped. If fewer than `k` entries are real numbers a `NaN` survives,
///   the sum is then not positive, and renormalization is skipped.
/// * When `k == probs.len()` the slice is returned untouched; it is not
///   renormalized.
/// * Renormalization is also skipped when the surviving mass is not positive.
///
/// # Panics
///
/// Panics unless `1 <= k <= probs.len()` (so an empty slice always panics).
pub fn top_k_scalar(probs: &mut [f32], k: usize) {
    let n = probs.len();
    assert!(k > 0 && k <= n, "k={k} must be in [1, {n}]");

    if k == n {
        return;
    }

    // Map NaN to -inf for ranking purposes only. This makes the comparator a
    // total order: a raw `partial_cmp(..).unwrap_or(Equal)` treats NaN as
    // equal to everything, which lets a NaN shield smaller values from larger
    // ones and violates the ordering contract `sort_by` relies on.
    let rank_key = |idx: usize| -> f32 {
        let p = probs[idx];
        if p.is_nan() {
            f32::NEG_INFINITY
        } else {
            p
        }
    };

    // Indices ordered by descending probability.
    let mut indices: Vec<usize> = (0..n).collect();
    indices.sort_by(|&a, &b| {
        rank_key(b)
            .partial_cmp(&rank_key(a))
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // Everything ranked after the first `k` is discarded.
    for &idx in &indices[k..] {
        probs[idx] = 0.0;
    }

    let sum: f32 = probs.iter().sum();
    if sum > 0.0 {
        for v in probs.iter_mut() {
            *v /= sum;
        }
    }
}

/// Top-p (nucleus) filtering: keeps the shortest prefix of the
/// descending-sorted probabilities whose cumulative sum reaches `threshold`,
/// zeroes everything else, and rescales the survivors so they sum to 1.
///
/// * Ties are broken in favour of the lower index (the sort is stable).
/// * `NaN` entries rank below every real value, so they are accumulated last
///   and are dropped whenever the real entries alone reach `threshold`.
/// * If the cumulative sum never reaches `threshold`, nothing is zeroed.
/// * Renormalization is skipped when the surviving mass is not positive.
/// * An empty slice is accepted and left unchanged.
///
/// # Panics
///
/// Panics unless `0.0 < threshold <= 1.0`.
pub fn top_p_scalar(probs: &mut [f32], threshold: f32) {
    let n = probs.len();
    assert!(
        threshold > 0.0 && threshold <= 1.0,
        "threshold must be in (0, 1], got {threshold}"
    );

    // Map NaN to -inf for ranking purposes only. This makes the comparator a
    // total order: a raw `partial_cmp(..).unwrap_or(Equal)` treats NaN as
    // equal to everything, which both breaks the ordering contract `sort_by`
    // relies on and lets an early NaN poison the running sum below.
    let rank_key = |idx: usize| -> f32 {
        let p = probs[idx];
        if p.is_nan() {
            f32::NEG_INFINITY
        } else {
            p
        }
    };

    // Indices ordered by descending probability.
    let mut indices: Vec<usize> = (0..n).collect();
    indices.sort_by(|&a, &b| {
        rank_key(b)
            .partial_cmp(&rank_key(a))
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // Number of leading entries to keep: one past the first rank at which the
    // running sum reaches the threshold, or all of them if it never does.
    let mut cumsum = 0.0f32;
    let cutoff = indices
        .iter()
        .position(|&idx| {
            cumsum += probs[idx];
            cumsum >= threshold
        })
        .map_or(n, |rank| rank + 1);

    for &idx in &indices[cutoff..] {
        probs[idx] = 0.0;
    }

    let sum: f32 = probs.iter().sum();
    if sum > 0.0 {
        for v in probs.iter_mut() {
            *v /= sum;
        }
    }
}

129pub fn sample_scalar(logits: &[f32]) -> usize {
133 greedy_scalar(logits)
134}
135
136#[cfg(target_arch = "x86_64")]
145#[target_feature(enable = "avx2")]
146pub unsafe fn greedy_avx2(logits: &[f32]) -> usize {
147 greedy_scalar(logits)
148}
149
150#[cfg(target_arch = "x86_64")]
155#[target_feature(enable = "avx2")]
156pub unsafe fn temperature_avx2(logits: &mut [f32], temperature: f32) {
157 temperature_scalar(logits, temperature);
158}
159
/// Returns the PTX source for `greedy_kernel`, a serial argmax over logits.
///
/// Kernel parameters:
/// * `LOGITS` — 64-bit address of the `f32` logits, read with `ld.global`.
/// * `OUT_IDX` — 64-bit address that receives the winning index as a `u32`.
/// * `VOCAB_SIZE` — number of logits to scan.
///
/// Only threads whose `%tid.x` is 0 do any work; every other thread exits
/// immediately. The block index is never consulted, so thread 0 of each
/// launched block runs the full scan (presumably the kernel is launched with
/// a single block — confirm against the launch site).
///
/// The comparison is a strict `setp.gt.f32`, so ties resolve to the lowest
/// index. When `VOCAB_SIZE` is 0 the kernel exits without reading `LOGITS`
/// or writing `OUT_IDX`.
pub fn sampling_ptx() -> &'static str {
    // The thread-id scratch register is named `%thread_id` rather than `%tid`
    // so that it does not collide with the predefined `%tid` special register
    // read via `%tid.x`.
    r#".version 8.5
.target sm_90
.address_size 64
.visible .entry greedy_kernel(
    .param .u64 LOGITS,
    .param .u64 OUT_IDX,
    .param .u32 VOCAB_SIZE
) {
    .reg .u32 %thread_id, %vocab_size, %k, %best_idx, %cur_idx;
    .reg .u64 %logits_ptr, %out_ptr, %addr, %off64;
    .reg .f32 %best_val, %cur_val;
    .reg .pred %p_loop, %p_better;

    mov.u32 %thread_id, %tid.x;

    ld.param.u32 %vocab_size, [VOCAB_SIZE];
    ld.param.u64 %logits_ptr, [LOGITS];
    ld.param.u64 %out_ptr, [OUT_IDX];

    // Only thread 0 performs the scan (simple serial argmax)
    setp.ne.u32 %p_loop, %thread_id, 0;
    @%p_loop bra EXIT;

    // Empty vocabulary: nothing to read, leave OUT_IDX untouched
    setp.eq.u32 %p_loop, %vocab_size, 0;
    @%p_loop bra EXIT;

    // Load first element as initial best
    ld.global.f32 %best_val, [%logits_ptr];
    mov.u32 %best_idx, 0;
    mov.u32 %k, 1;

SCAN_LOOP:
    setp.ge.u32 %p_loop, %k, %vocab_size;
    @%p_loop bra STORE;

    mul.wide.u32 %off64, %k, 4;
    add.u64 %addr, %logits_ptr, %off64;
    ld.global.f32 %cur_val, [%addr];

    setp.gt.f32 %p_better, %cur_val, %best_val;
    @!%p_better bra NEXT;
    mov.f32 %best_val, %cur_val;
    mov.u32 %best_idx, %k;
NEXT:
    add.u32 %k, %k, 1;
    bra SCAN_LOOP;

STORE:
    st.global.u32 [%out_ptr], %best_idx;

EXIT:
    ret;
}
"#
}

/// Unit and property tests for the sampling kernels.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_greedy_basic() {
        assert_eq!(greedy_scalar(&[1.0, 3.0, 2.0]), 1);
        assert_eq!(greedy_scalar(&[5.0]), 0);
        assert_eq!(greedy_scalar(&[0.0, 0.0, 0.0, 1.0]), 3);
    }

    #[test]
    fn test_greedy_is_argmax() {
        // All values are distinct, so the reference argmax is unambiguous.
        let logits = [0.1, 0.5, -0.3, 0.8, 0.2];
        let result = greedy_scalar(&logits);
        let argmax = logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert_eq!(result, argmax);
    }

    #[test]
    fn test_temperature_identity() {
        let original = [1.0, 2.0, 3.0];
        let mut scaled = original;
        temperature_scalar(&mut scaled, 1.0);
        assert_eq!(scaled, original);
    }

    #[test]
    fn test_temperature_scaling() {
        let mut logits = [2.0, 4.0];
        temperature_scalar(&mut logits, 2.0);
        assert!((logits[0] - 1.0).abs() < 1e-6);
        assert!((logits[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_top_k_cardinality() {
        let mut probs = [0.1, 0.2, 0.3, 0.4];
        top_k_scalar(&mut probs, 2);
        let nonzero = probs.iter().filter(|&&p| p > 0.0).count();
        assert!(nonzero <= 2, "expected at most 2, got {nonzero}");
    }

    #[test]
    fn test_top_k_keeps_highest() {
        let mut probs = [0.1, 0.4, 0.2, 0.3];
        top_k_scalar(&mut probs, 2);
        assert_eq!(probs[0], 0.0);
        assert!(probs[1] > 0.0);
        assert_eq!(probs[2], 0.0);
        assert!(probs[3] > 0.0);
    }

    #[test]
    fn test_top_k_renormalizes() {
        let mut probs = [0.1, 0.2, 0.3, 0.4];
        top_k_scalar(&mut probs, 2);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "sum should be 1.0, got {sum}");
    }

    #[test]
    fn test_top_p_cumulative() {
        let mut probs = [0.1, 0.2, 0.3, 0.4];
        let threshold = 0.6;
        top_p_scalar(&mut probs, threshold);
        let sum: f32 = probs.iter().sum();
        assert!(sum >= threshold - 1e-5, "sum {sum} < threshold {threshold}");
    }

    #[test]
    fn test_top_p_minimal_set() {
        let mut probs = [0.1, 0.2, 0.3, 0.4];
        top_p_scalar(&mut probs, 0.5);
        let nonzero = probs.iter().filter(|&&p| p > 0.0).count();
        assert!(
            nonzero <= 2,
            "expected minimal set size <= 2, got {nonzero}"
        );
    }

    proptest! {
        #[test]
        fn prop_greedy_is_argmax(logits in proptest::collection::vec(-10.0f32..10.0, 1..16)) {
            let result = greedy_scalar(&logits);
            let best = logits[result];
            // Checked directly instead of against `Iterator::max_by`: `max_by`
            // returns the LAST maximum on ties while `greedy_scalar` returns the
            // FIRST, so an index comparison fails on any duplicated maximum.
            prop_assert!(logits.iter().all(|&v| v <= best));
            prop_assert!(logits[..result].iter().all(|&v| v < best));
        }

        #[test]
        fn prop_top_k_cardinality(
            k in 1usize..8,
            n in 8usize..16,
        ) {
            let mut probs: Vec<f32> = (0..n).map(|i| (i as f32 + 1.0) / (n as f32)).collect();
            let sum: f32 = probs.iter().sum();
            for v in probs.iter_mut() { *v /= sum; }

            top_k_scalar(&mut probs, k);
            let nonzero = probs.iter().filter(|&&p| p > 0.0).count();
            prop_assert!(nonzero <= k, "nonzero={nonzero} > k={k}");
        }

        #[test]
        fn prop_temperature_identity(logits in proptest::collection::vec(-10.0f32..10.0, 1..16)) {
            let original = logits.clone();
            let mut scaled = logits;
            temperature_scalar(&mut scaled, 1.0);
            for (a, b) in original.iter().zip(scaled.iter()) {
                prop_assert!((a - b).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_sample_scalar_delegates_to_greedy() {
        let logits = [0.1, 0.5, -0.3, 0.8, 0.2];
        assert_eq!(sample_scalar(&logits), greedy_scalar(&logits));
    }

    #[test]
    fn test_top_k_full_k_is_noop() {
        let mut probs = [0.25, 0.25, 0.25, 0.25];
        let original = probs;
        top_k_scalar(&mut probs, 4);
        assert_eq!(probs, original);
    }

    #[test]
    fn test_top_p_threshold_one() {
        let mut probs = [0.25, 0.25, 0.25, 0.25];
        top_p_scalar(&mut probs, 1.0);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_temperature_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let mut scalar = [2.0, 4.0, 6.0, 8.0];
        let mut avx2 = scalar;
        temperature_scalar(&mut scalar, 2.0);
        // SAFETY: AVX2 support was verified at runtime just above.
        unsafe { temperature_avx2(&mut avx2, 2.0) };
        assert_eq!(scalar, avx2);
    }

    #[test]
    fn test_sampling_ptx_structure() {
        let ptx = sampling_ptx();
        assert!(ptx.contains(".entry greedy_kernel"));
        assert!(ptx.contains("ret;"));
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_greedy_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let logits = [0.1, 0.5, -0.3, 0.8, 0.2];
        let scalar = greedy_scalar(&logits);
        // SAFETY: AVX2 support was verified at runtime just above.
        let avx2 = unsafe { greedy_avx2(&logits) };
        assert_eq!(scalar, avx2);
    }
}