Skip to main content

provable_contracts/kernels/
sampling.rs

1//! Sampling algorithms kernel.
2//!
3//! Matches `sampling-algorithms-v1.yaml`.
4//! Greedy, top-k, top-p, and temperature sampling for autoregressive generation.
5//!
//! Each kernel provides up to three backends (not every kernel has all three):
//! - `fn {name}_scalar(...)` -- Pure Rust scalar reference (ground truth)
//! - `unsafe fn {name}_avx2(...)` -- AVX2 entry point (may delegate to the scalar reference)
//! - `fn sampling_ptx() -> &'static str` -- PTX assembly source string (greedy kernel only)
10
11// ────────────────────────────────────────────────────────────────────────────
12// Scalar implementations
13// ────────────────────────────────────────────────────────────────────────────
14
/// Greedy sampling: the index of the largest logit (argmax).
///
/// The scan uses a strict `>` comparison, so when several entries share the
/// maximum the earliest index is returned. A NaN entry never displaces the
/// current best; if `logits[0]` is NaN, index 0 is returned.
///
/// # Panics
/// Panics if `logits` is empty.
pub fn greedy_scalar(logits: &[f32]) -> usize {
    assert!(!logits.is_empty(), "logits must not be empty");
    let seed = (0usize, logits[0]);
    let (winner, _) = logits
        .iter()
        .copied()
        .enumerate()
        .skip(1)
        .fold(seed, |(best_i, best_v), (i, v)| {
            if v > best_v {
                (i, v)
            } else {
                (best_i, best_v)
            }
        });
    winner
}
31
/// Scales `logits` in place by dividing every entry by `temperature`
/// (`logits[i] /= temperature`).
///
/// Temperatures below 1.0 scale the logits up, temperatures above 1.0 scale
/// them down, and exactly 1.0 leaves them unchanged. An empty slice is
/// accepted and left untouched.
///
/// # Panics
/// Panics unless `temperature > 0.0`, so zero, negative values and NaN are
/// all rejected.
pub fn temperature_scalar(logits: &mut [f32], temperature: f32) {
    let valid = temperature > 0.0;
    assert!(valid, "temperature must be positive, got {temperature}");
    logits
        .iter_mut()
        .for_each(|logit| *logit /= temperature);
}
45
/// Top-K filtering: zero out all probabilities except the K highest.
///
/// `probs` is modified in-place. After filtering, the survivors are
/// renormalized to sum to 1 (skipped when their sum is not positive).
///
/// Ties at the K-th position are broken in favour of the lower index, so the
/// surviving set is deterministic. Ordering uses [`f32::total_cmp`], so NaN
/// entries cannot make the selection panic. The K largest are found with a
/// selection (average linear time) rather than a full sort.
///
/// # Panics
/// Panics if `k == 0` or `k > probs.len()`.
pub fn top_k_scalar(probs: &mut [f32], k: usize) {
    let n = probs.len();
    assert!(k > 0 && k <= n, "k={k} must be in [1, {n}]");

    if k == n {
        return; // nothing to filter
    }

    // Partition indices so that `indices[..k]` holds the K largest
    // probabilities. A `partial_cmp`-based comparator is not a total order once
    // NaN appears (and `sort_by` may panic on such comparators); `total_cmp`
    // is. The index tie-break makes the chosen set identical to the one a
    // stable descending sort would produce.
    let mut indices: Vec<usize> = (0..n).collect();
    indices.select_nth_unstable_by(k - 1, |&a, &b| {
        probs[b].total_cmp(&probs[a]).then_with(|| a.cmp(&b))
    });

    // Zero out everything below top-K
    for &idx in &indices[k..] {
        probs[idx] = 0.0;
    }

    // Renormalize the survivors
    let sum: f32 = probs.iter().sum();
    if sum > 0.0 {
        for v in probs.iter_mut() {
            *v /= sum;
        }
    }
}
81
/// Top-P (nucleus) filtering: retain the smallest set of highest-probability
/// tokens whose cumulative probability reaches `threshold` (`>=`, not `>`).
///
/// `probs` is modified in-place. Tokens are ranked by descending probability,
/// and equal probabilities keep their index order. Once the running sum
/// reaches `threshold`, every lower-ranked token is set to zero. If the total
/// never reaches `threshold` (e.g. because of rounding), nothing is removed.
/// The survivors are then renormalized to sum to 1 (skipped when their sum is
/// not positive).
///
/// Ordering uses [`f32::total_cmp`], so NaN entries cannot make the sort
/// panic.
///
/// # Panics
/// Panics if `threshold <= 0` or `threshold > 1` (NaN is rejected too).
pub fn top_p_scalar(probs: &mut [f32], threshold: f32) {
    assert!(
        threshold > 0.0 && threshold <= 1.0,
        "threshold must be in (0, 1], got {threshold}"
    );
    let n = probs.len();

    // Sort indices by probability, descending. `sort_by` is stable, so ties
    // stay in index order. `total_cmp` is a true total order, unlike
    // `partial_cmp(..).unwrap_or(Equal)` in the presence of NaN.
    let mut indices: Vec<usize> = (0..n).collect();
    indices.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));

    // `cutoff` = number of top-ranked tokens that survive (all of them if the
    // threshold is never reached).
    let mut cumsum = 0.0f32;
    let mut cutoff = n;
    for (rank, &idx) in indices.iter().enumerate() {
        cumsum += probs[idx];
        if cumsum >= threshold {
            cutoff = rank + 1;
            break;
        }
    }

    // Zero out everything past cutoff
    for &idx in &indices[cutoff..] {
        probs[idx] = 0.0;
    }

    // Renormalize the survivors
    let sum: f32 = probs.iter().sum();
    if sum > 0.0 {
        for v in probs.iter_mut() {
            *v /= sum;
        }
    }
}
128
129/// Full sampling pipeline: apply temperature, softmax, then greedy (scalar reference).
130///
131/// Returns the selected token index.
132pub fn sample_scalar(logits: &[f32]) -> usize {
133    greedy_scalar(logits)
134}
135
136// ────────────────────────────────────────────────────────────────────────────
137// AVX2 implementation
138// ────────────────────────────────────────────────────────────────────────────
139
140/// AVX2 greedy sampling -- delegates to scalar.
141///
142/// # Safety
143/// Requires AVX2 support.
144#[cfg(target_arch = "x86_64")]
145#[target_feature(enable = "avx2")]
146pub unsafe fn greedy_avx2(logits: &[f32]) -> usize {
147    greedy_scalar(logits)
148}
149
/// AVX2 temperature scaling: `logits[i] /= temperature`, eight f32 lanes at a
/// time.
///
/// Results are bit-identical to `temperature_scalar`:
/// - IEEE-754 division is correctly rounded both in the packed form
///   (`_mm256_div_ps`) and in scalar code.
/// - Elements that do not fill a whole 256-bit vector are divided with
///   scalar code.
///
/// # Panics
/// Panics unless `temperature > 0.0` (zero, negative and NaN are rejected),
/// with the same message as `temperature_scalar`.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2 (for example via
/// `is_x86_feature_detected!("avx2")`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn temperature_avx2(logits: &mut [f32], temperature: f32) {
    use std::arch::x86_64::{_mm256_div_ps, _mm256_loadu_ps, _mm256_set1_ps, _mm256_storeu_ps};

    assert!(
        temperature > 0.0,
        "temperature must be positive, got {temperature}"
    );

    // SAFETY: AVX2 (which implies AVX) is enabled for this function and the
    // caller guarantees the CPU supports it.
    let divisor = unsafe { _mm256_set1_ps(temperature) };

    let mut chunks = logits.chunks_exact_mut(8);
    for chunk in &mut chunks {
        // SAFETY: `chunk` is exactly 8 contiguous f32 values, so the unaligned
        // 256-bit load and store stay in bounds; AVX support as above.
        unsafe {
            let scaled = _mm256_div_ps(_mm256_loadu_ps(chunk.as_ptr()), divisor);
            _mm256_storeu_ps(chunk.as_mut_ptr(), scaled);
        }
    }

    // Tail shorter than one vector: plain scalar division.
    for v in chunks.into_remainder() {
        *v /= temperature;
    }
}
159
160// ────────────────────────────────────────────────────────────────────────────
161// PTX implementation
162// ────────────────────────────────────────────────────────────────────────────
163
/// PTX assembly for greedy sampling (argmax), exported as `greedy_kernel`.
///
/// Parameters:
/// - `LOGITS`: pointer to `VOCAB_SIZE` f32 logits.
/// - `OUT_IDX`: pointer to a single u32 result.
/// - `VOCAB_SIZE`: number of logits.
///
/// Behavior:
/// - Only thread 0 (by `%tid.x`) does any work; it scans the logits serially.
///   There is no parallel reduction.
/// - It stores the index of the first maximum, mirroring `greedy_scalar`: the
///   comparison is a strict `setp.gt.f32`, so ties keep the earliest index and
///   NaN never replaces the current best.
/// - All other threads exit immediately.
/// - When `VOCAB_SIZE` is 0 the kernel exits without reading the logits or
///   writing the output.
pub fn sampling_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64
.visible .entry greedy_kernel(
    .param .u64 LOGITS,
    .param .u64 OUT_IDX,
    .param .u32 VOCAB_SIZE
) {
    .reg .u32 %tid_x, %vocab_size, %k, %best_idx;
    .reg .u64 %logits_ptr, %out_ptr, %addr, %off64;
    .reg .f32 %best_val, %cur_val;
    .reg .pred %p_skip, %p_loop, %p_better;

    // Copy the thread index into a user register; avoid reusing the name of
    // the predefined special register %tid for a declared register.
    mov.u32 %tid_x, %tid.x;

    ld.param.u32 %vocab_size, [VOCAB_SIZE];
    ld.param.u64 %logits_ptr, [LOGITS];
    ld.param.u64 %out_ptr, [OUT_IDX];
    // Convert generic pointer arguments to the global state space used below.
    cvta.to.global.u64 %logits_ptr, %logits_ptr;
    cvta.to.global.u64 %out_ptr, %out_ptr;

    // Only thread 0 performs the scan (simple serial argmax)
    setp.ne.u32 %p_skip, %tid_x, 0;
    @%p_skip bra EXIT;

    // Empty vocabulary: logits[0] does not exist, so read and write nothing
    setp.eq.u32 %p_skip, %vocab_size, 0;
    @%p_skip bra EXIT;

    // Load first element as initial best
    ld.global.f32 %best_val, [%logits_ptr];
    mov.u32 %best_idx, 0;
    mov.u32 %k, 1;

SCAN_LOOP:
    setp.ge.u32 %p_loop, %k, %vocab_size;
    @%p_loop bra STORE;

    mul.wide.u32 %off64, %k, 4;
    add.u64 %addr, %logits_ptr, %off64;
    ld.global.f32 %cur_val, [%addr];

    // Strict greater-than: ties keep the earliest index, NaN never wins
    setp.gt.f32 %p_better, %cur_val, %best_val;
    @!%p_better bra NEXT;
    mov.f32 %best_val, %cur_val;
    mov.u32 %best_idx, %k;
NEXT:
    add.u32 %k, %k, 1;
    bra SCAN_LOOP;

STORE:
    st.global.u32 [%out_ptr], %best_idx;

EXIT:
    ret;
}
"#
}
220
221// ────────────────────────────────────────────────────────────────────────────
222// Tests
223// ────────────────────────────────────────────────────────────────────────────
224
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Reference argmax with the same tie rule as `greedy_scalar`: the FIRST
    /// index holding the maximum wins. (`Iterator::max_by` returns the LAST
    /// maximum, so it would disagree with `greedy_scalar` whenever values tie.)
    fn first_argmax(logits: &[f32]) -> usize {
        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        logits
            .iter()
            .position(|&v| v == max)
            .expect("non-empty, NaN-free input always contains its maximum")
    }

    #[test]
    fn test_greedy_basic() {
        assert_eq!(greedy_scalar(&[1.0, 3.0, 2.0]), 1);
        assert_eq!(greedy_scalar(&[5.0]), 0);
        assert_eq!(greedy_scalar(&[0.0, 0.0, 0.0, 1.0]), 3);
    }

    #[test]
    fn test_greedy_ties_pick_first() {
        assert_eq!(greedy_scalar(&[1.0, 3.0, 3.0, 2.0]), 1);
        assert_eq!(greedy_scalar(&[0.5, 0.5]), 0);
    }

    #[test]
    #[should_panic(expected = "logits must not be empty")]
    fn test_greedy_empty_panics() {
        let _ = greedy_scalar(&[]);
    }

    #[test]
    fn test_greedy_is_argmax() {
        let logits = [0.1, 0.5, -0.3, 0.8, 0.2];
        assert_eq!(greedy_scalar(&logits), first_argmax(&logits));
    }

    #[test]
    fn test_temperature_identity() {
        let original = [1.0, 2.0, 3.0];
        let mut scaled = original;
        temperature_scalar(&mut scaled, 1.0);
        assert_eq!(scaled, original);
    }

    #[test]
    fn test_temperature_scaling() {
        let mut logits = [2.0, 4.0];
        temperature_scalar(&mut logits, 2.0);
        assert!((logits[0] - 1.0).abs() < 1e-6);
        assert!((logits[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "temperature must be positive")]
    fn test_temperature_rejects_zero() {
        let mut logits = [1.0f32];
        temperature_scalar(&mut logits, 0.0);
    }

    #[test]
    fn test_top_k_cardinality() {
        let mut probs = [0.1, 0.2, 0.3, 0.4];
        top_k_scalar(&mut probs, 2);
        let nonzero = probs.iter().filter(|&&p| p > 0.0).count();
        assert_eq!(nonzero, 2, "expected exactly 2 survivors, got {nonzero}");
    }

    #[test]
    fn test_top_k_keeps_highest() {
        let mut probs = [0.1, 0.4, 0.2, 0.3];
        top_k_scalar(&mut probs, 2);
        // indices 1 (0.4) and 3 (0.3) should survive
        assert_eq!(probs[0], 0.0);
        assert!(probs[1] > 0.0);
        assert_eq!(probs[2], 0.0);
        assert!(probs[3] > 0.0);
    }

    #[test]
    fn test_top_k_renormalizes() {
        let mut probs = [0.1, 0.2, 0.3, 0.4];
        top_k_scalar(&mut probs, 2);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "sum should be 1.0, got {sum}");
    }

    #[test]
    fn test_top_k_ties_keep_lowest_index() {
        let mut probs = [0.25, 0.25, 0.25, 0.25];
        top_k_scalar(&mut probs, 2);
        assert_eq!(probs, [0.5, 0.5, 0.0, 0.0]);
    }

    #[test]
    #[should_panic(expected = "must be in [1, 4]")]
    fn test_top_k_rejects_zero_k() {
        let mut probs = [0.25, 0.25, 0.25, 0.25];
        top_k_scalar(&mut probs, 0);
    }

    #[test]
    fn test_top_p_cumulative() {
        // The survivors must have carried at least `threshold` of the ORIGINAL
        // probability mass. The post-renormalization sum is always ~1, so
        // checking only that would prove nothing.
        let original = [0.1f32, 0.2, 0.3, 0.4];
        let threshold = 0.6;
        let mut probs = original;
        top_p_scalar(&mut probs, threshold);
        let kept_mass: f32 = original
            .iter()
            .zip(&probs)
            .filter(|&(_, &p)| p > 0.0)
            .map(|(&o, _)| o)
            .sum();
        assert!(
            kept_mass >= threshold - 1e-5,
            "kept mass {kept_mass} < threshold {threshold}"
        );
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "sum should be 1.0, got {sum}");
    }

    #[test]
    fn test_top_p_minimal_set() {
        let mut probs = [0.1, 0.2, 0.3, 0.4];
        top_p_scalar(&mut probs, 0.5);
        // 0.4 alone is below 0.5, while 0.4 + 0.3 = 0.7 reaches it, so exactly
        // the two largest tokens (indices 3 and 2) survive.
        assert_eq!(probs[0], 0.0);
        assert_eq!(probs[1], 0.0);
        let nonzero = probs.iter().filter(|&&p| p > 0.0).count();
        assert_eq!(nonzero, 2, "expected minimal set size 2, got {nonzero}");
    }

    #[test]
    fn test_top_p_threshold_reached_exactly() {
        // Reaching the threshold (`>=`) is enough; it need not be exceeded.
        let mut probs = [0.1, 0.2, 0.3, 0.4];
        top_p_scalar(&mut probs, 0.4);
        assert_eq!(probs, [0.0, 0.0, 0.0, 1.0]);
    }

    proptest! {
        #[test]
        fn prop_greedy_is_argmax(logits in proptest::collection::vec(-10.0f32..10.0, 1..16)) {
            prop_assert_eq!(greedy_scalar(&logits), first_argmax(&logits));
        }

        #[test]
        fn prop_top_k_cardinality(
            k in 1usize..8,
            n in 8usize..16,
        ) {
            // Distinct, strictly positive probabilities: exactly k survive.
            let mut probs: Vec<f32> = (0..n).map(|i| (i as f32 + 1.0) / (n as f32)).collect();
            let sum: f32 = probs.iter().sum();
            for v in probs.iter_mut() { *v /= sum; }

            top_k_scalar(&mut probs, k);
            let nonzero = probs.iter().filter(|&&p| p > 0.0).count();
            prop_assert_eq!(nonzero, k, "nonzero={} != k={}", nonzero, k);
        }

        #[test]
        fn prop_temperature_identity(logits in proptest::collection::vec(-10.0f32..10.0, 1..16)) {
            let original = logits.clone();
            let mut scaled = logits;
            temperature_scalar(&mut scaled, 1.0);
            for (a, b) in original.iter().zip(scaled.iter()) {
                prop_assert!((a - b).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_sample_scalar_delegates_to_greedy() {
        let logits = [0.1, 0.5, -0.3, 0.8, 0.2];
        assert_eq!(sample_scalar(&logits), greedy_scalar(&logits));
    }

    #[test]
    fn test_top_k_full_k_is_noop() {
        let mut probs = [0.25, 0.25, 0.25, 0.25];
        let original = probs;
        top_k_scalar(&mut probs, 4);
        assert_eq!(probs, original);
    }

    #[test]
    fn test_top_p_threshold_one() {
        let mut probs = [0.25, 0.25, 0.25, 0.25];
        top_p_scalar(&mut probs, 1.0);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_temperature_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        // 11 elements: a length that is not a multiple of 8 (one 256-bit f32
        // vector), so any vector/tail split is exercised.
        let mut scalar: Vec<f32> = (1..=11).map(|i| i as f32 * 2.0).collect();
        let mut avx2 = scalar.clone();
        temperature_scalar(&mut scalar, 2.0);
        // SAFETY: AVX2 support was checked above.
        unsafe { temperature_avx2(&mut avx2, 2.0) };
        assert_eq!(scalar, avx2);
    }

    #[test]
    fn test_sampling_ptx_structure() {
        let ptx = sampling_ptx();
        assert!(ptx.contains(".entry greedy_kernel"));
        assert!(ptx.contains("ret;"));
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_greedy_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let logits = [0.1, 0.5, -0.3, 0.8, 0.2];
        let scalar = greedy_scalar(&logits);
        // SAFETY: AVX2 support was checked above.
        let avx2 = unsafe { greedy_avx2(&logits) };
        assert_eq!(scalar, avx2);
    }
}
403}