Skip to main content

provable_contracts/kernels/
gqa.rs

//! Grouped Query Attention kernel.
//!
//! Matches `gqa-kernel-v1.yaml`.
//! KV head broadcasting: kv_head = query_head / (num_heads / num_kv_heads)
//!
//! The kernel is exposed through three backends:
//! - `fn gqa_scalar(...)` -- Pure Rust scalar reference (ground truth)
//! - `unsafe fn gqa_avx2(...)` -- AVX2 entry point (currently delegates to the scalar reference)
//! - `fn gqa_ptx() -> &'static str` -- PTX assembly source string
10
11use super::ops;
12
13/// Single-head attention helper: computes attention for one query sequence
14/// against one KV head.
15///
16/// Q_head is seq_len x d_k, K_head is seq_len x d_k, V_head is seq_len x d_v,
17/// output is seq_len x d_v.
18fn single_head_attention(
19    q_head: &[f32],
20    k_head: &[f32],
21    v_head: &[f32],
22    seq_len: usize,
23    d_k: usize,
24    d_v: usize,
25    output: &mut [f32],
26) {
27    // scores = Q_head * K_head^T / sqrt(d_k), shape seq_len x seq_len
28    let mut scores = vec![0.0f32; seq_len * seq_len];
29    ops::score_matrix(q_head, k_head, seq_len, seq_len, d_k, &mut scores);
30
31    // Softmax each row
32    ops::softmax_rows(&mut scores, seq_len, seq_len);
33
34    // output = scores * V_head, shape seq_len x d_v
35    ops::matmul_sv(&scores, v_head, seq_len, seq_len, d_v, output);
36}
37
38// ────────────────────────────────────────────────────────────────────────────
39// Scalar implementation
40// ────────────────────────────────────────────────────────────────────────────
41
42/// Grouped Query Attention (scalar reference).
43///
44/// For each query head h in 0..num_heads:
45///   kv_head = h / (num_heads / num_kv_heads)
46///   Compute attention(Q\[h\], K\[kv_head\], V\[kv_head\]) -> output\[h\]
47///
48/// Layout (all row-major):
49/// - Q: num_heads * seq_len * d_k
50/// - K: num_kv_heads * seq_len * d_k
51/// - V: num_kv_heads * seq_len * d_v
52/// - output: num_heads * seq_len * d_v
53///
54/// # Panics
55/// Panics if `num_heads % num_kv_heads != 0` or dimensions are inconsistent.
56pub fn gqa_scalar(
57    q: &[f32],
58    k: &[f32],
59    v: &[f32],
60    seq_len: usize,
61    d_k: usize,
62    d_v: usize,
63    num_heads: usize,
64    num_kv_heads: usize,
65    output: &mut [f32],
66) {
67    assert!(
68        num_kv_heads > 0 && num_heads % num_kv_heads == 0,
69        "num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
70    );
71    let q_total = num_heads * seq_len * d_k;
72    let k_total = num_kv_heads * seq_len * d_k;
73    let v_total = num_kv_heads * seq_len * d_v;
74    let o_total = num_heads * seq_len * d_v;
75    assert_eq!(
76        q.len(),
77        q_total,
78        "Q dimension mismatch: expected {q_total} got {}",
79        q.len()
80    );
81    assert_eq!(
82        k.len(),
83        k_total,
84        "K dimension mismatch: expected {k_total} got {}",
85        k.len()
86    );
87    assert_eq!(
88        v.len(),
89        v_total,
90        "V dimension mismatch: expected {v_total} got {}",
91        v.len()
92    );
93    assert_eq!(
94        output.len(),
95        o_total,
96        "output dimension mismatch: expected {o_total} got {}",
97        output.len()
98    );
99
100    let heads_per_kv = num_heads / num_kv_heads;
101    let q_head_stride = seq_len * d_k;
102    let k_head_stride = seq_len * d_k;
103    let v_head_stride = seq_len * d_v;
104    let o_head_stride = seq_len * d_v;
105
106    for h in 0..num_heads {
107        let kv_head = h / heads_per_kv;
108
109        let q_start = h * q_head_stride;
110        let k_start = kv_head * k_head_stride;
111        let v_start = kv_head * v_head_stride;
112        let o_start = h * o_head_stride;
113
114        let q_head = &q[q_start..q_start + q_head_stride];
115        let k_head = &k[k_start..k_start + k_head_stride];
116        let v_head = &v[v_start..v_start + v_head_stride];
117        let o_head = &mut output[o_start..o_start + o_head_stride];
118
119        single_head_attention(q_head, k_head, v_head, seq_len, d_k, d_v, o_head);
120    }
121}
122
123// ────────────────────────────────────────────────────────────────────────────
124// AVX2 implementation
125// ────────────────────────────────────────────────────────────────────────────
126
127/// AVX2 Grouped Query Attention -- delegates to scalar.
128///
129/// # Safety
130/// Requires AVX2 support. Caller must verify with `is_x86_feature_detected!("avx2")`.
131///
132/// # Panics
133/// Panics if dimensions are inconsistent.
134#[cfg(target_arch = "x86_64")]
135#[target_feature(enable = "avx2")]
136pub unsafe fn gqa_avx2(
137    q: &[f32],
138    k: &[f32],
139    v: &[f32],
140    seq_len: usize,
141    d_k: usize,
142    d_v: usize,
143    num_heads: usize,
144    num_kv_heads: usize,
145    output: &mut [f32],
146) {
147    gqa_scalar(q, k, v, seq_len, d_k, d_v, num_heads, num_kv_heads, output);
148}
149
150include!("gqa_ptx.rs");
151
152// ────────────────────────────────────────────────────────────────────────────
153// Tests
154// ────────────────────────────────────────────────────────────────────────────
155
#[cfg(test)]
mod tests {
    use super::super::ops::sequential_floats;
    use super::super::ulp::assert_ulp_eq;
    use super::*;
    use proptest::prelude::*;

    // ── Shared fixtures ─────────────────────────────────────────────────

    /// Dimensions of one GQA problem instance used by a test.
    #[derive(Clone, Copy)]
    struct Shape {
        seq_len: usize,
        d_k: usize,
        d_v: usize,
        num_heads: usize,
        num_kv_heads: usize,
    }

    impl Shape {
        /// Total element count of the Q buffer.
        fn q_len(self) -> usize {
            self.num_heads * self.seq_len * self.d_k
        }

        /// Total element count of the K buffer.
        fn k_len(self) -> usize {
            self.num_kv_heads * self.seq_len * self.d_k
        }

        /// Total element count of the V buffer.
        fn v_len(self) -> usize {
            self.num_kv_heads * self.seq_len * self.d_v
        }

        /// Total element count of the output buffer.
        fn out_len(self) -> usize {
            self.num_heads * self.seq_len * self.d_v
        }

        /// Elements per head in Q and K.
        fn qk_head_len(self) -> usize {
            self.seq_len * self.d_k
        }

        /// Elements per head in V and in the output.
        fn vo_head_len(self) -> usize {
            self.seq_len * self.d_v
        }
    }

    /// Runs `gqa_scalar` into a freshly zeroed output buffer and returns it.
    fn run_scalar(q: &[f32], k: &[f32], v: &[f32], shape: Shape) -> Vec<f32> {
        let mut out = vec![0.0f32; shape.out_len()];
        gqa_scalar(
            q,
            k,
            v,
            shape.seq_len,
            shape.d_k,
            shape.d_v,
            shape.num_heads,
            shape.num_kv_heads,
            &mut out,
        );
        out
    }

    /// Computes the expected output of query head `head` attending against
    /// KV head `kv_head` by calling `single_head_attention` directly.
    fn reference_head(
        q: &[f32],
        k: &[f32],
        v: &[f32],
        shape: Shape,
        head: usize,
        kv_head: usize,
    ) -> Vec<f32> {
        let qk = shape.qk_head_len();
        let vo = shape.vo_head_len();
        let mut expected = vec![0.0f32; vo];
        single_head_attention(
            &q[head * qk..(head + 1) * qk],
            &k[kv_head * qk..(kv_head + 1) * qk],
            &v[kv_head * vo..(kv_head + 1) * vo],
            shape.seq_len,
            shape.d_k,
            shape.d_v,
            &mut expected,
        );
        expected
    }

    // ── MHA equivalence (num_heads == num_kv_heads) ─────────────────────

    #[test]
    fn test_gqa_equals_mha_when_heads_match() {
        // With as many KV heads as query heads, GQA is plain multi-head
        // attention: query head h pairs with KV head h.
        let shape = Shape {
            seq_len: 2,
            d_k: 3,
            d_v: 2,
            num_heads: 2,
            num_kv_heads: 2,
        };

        let q = sequential_floats(shape.q_len(), 0.1);
        let k = sequential_floats(shape.k_len(), 0.15);
        let v = sequential_floats(shape.v_len(), 0.2);
        let output = run_scalar(&q, &k, &v, shape);

        // Each head must match an independent single-head computation.
        let vo = shape.vo_head_len();
        for h in 0..shape.num_heads {
            let expected = reference_head(&q, &k, &v, shape, h, h);
            assert_ulp_eq(&output[h * vo..(h + 1) * vo], &expected, 0);
        }
    }

    // ── KV broadcasting test ────────────────────────────────────────────

    #[test]
    fn test_gqa_kv_broadcasting() {
        // 4 query heads over 2 KV heads.
        let shape = Shape {
            seq_len: 2,
            d_k: 2,
            d_v: 2,
            num_heads: 4,
            num_kv_heads: 2,
        };

        let q = sequential_floats(shape.q_len(), 0.1);
        let k = sequential_floats(shape.k_len(), 0.2);
        let v = sequential_floats(shape.v_len(), 0.15);
        let output = run_scalar(&q, &k, &v, shape);

        // The expected mapping is spelled out by hand rather than derived
        // from the formula under test: heads 0 and 1 share KV head 0 (with
        // different Q), heads 2 and 3 share KV head 1.
        let head_to_kv = [(0usize, 0usize), (1, 0), (2, 1), (3, 1)];

        let vo = shape.vo_head_len();
        for (head, kv_head) in head_to_kv {
            let expected = reference_head(&q, &k, &v, shape, head, kv_head);
            assert_ulp_eq(&output[head * vo..(head + 1) * vo], &expected, 0);
        }
    }

    // ── Single head, single position ────────────────────────────────────

    #[test]
    fn test_gqa_single_head_single_pos() {
        // Smallest possible problem: one head, one KV head, one position.
        let shape = Shape {
            seq_len: 1,
            d_k: 2,
            d_v: 3,
            num_heads: 1,
            num_kv_heads: 1,
        };

        let q = vec![1.0, 0.5];
        let k = vec![0.5, 1.0];
        let v = vec![2.0, 3.0, 4.0];
        let output = run_scalar(&q, &k, &v, shape);

        // A softmax over a single score is 1.0, so the output is V itself.
        assert_ulp_eq(&output, &v, 0);
    }

    // ── Assertion tests ─────────────────────────────────────────────────

    #[test]
    #[should_panic(expected = "must be divisible")]
    fn test_gqa_bad_head_ratio() {
        // 3 query heads cannot be split evenly across 2 KV heads.
        let (seq_len, d_k, d_v, num_heads, num_kv_heads) = (1, 2, 2, 3, 2);
        let mut output = vec![0.0f32; 4];
        gqa_scalar(
            &[0.0; 6],
            &[0.0; 4],
            &[0.0; 4],
            seq_len,
            d_k,
            d_v,
            num_heads,
            num_kv_heads,
            &mut output,
        );
    }

    #[test]
    #[should_panic(expected = "Q dimension mismatch")]
    fn test_gqa_bad_q_dim() {
        // Q holds 3 elements where 2 * 1 * 2 = 4 are required.
        let (seq_len, d_k, d_v, num_heads, num_kv_heads) = (1, 2, 2, 2, 2);
        let mut output = vec![0.0f32; 4];
        gqa_scalar(
            &[0.0; 3],
            &[0.0; 2],
            &[0.0; 2],
            seq_len,
            d_k,
            d_v,
            num_heads,
            num_kv_heads,
            &mut output,
        );
    }

    // ── Property-based tests ────────────────────────────────────────────

    proptest! {
        #[test]
        fn prop_gqa_output_finite(
            seq_len in 1usize..3,
            d_k in 1usize..4,
            d_v in 1usize..4,
        ) {
            let shape = Shape {
                seq_len,
                d_k,
                d_v,
                num_heads: 4usize,
                num_kv_heads: 2usize,
            };

            let q = sequential_floats(shape.q_len(), 0.1);
            let k = sequential_floats(shape.k_len(), 0.1);
            let v = sequential_floats(shape.v_len(), 0.1);
            let output = run_scalar(&q, &k, &v, shape);

            for (idx, &val) in output.iter().enumerate() {
                prop_assert!(val.is_finite(), "output[{idx}] = {val} is not finite");
            }
        }

        #[test]
        fn prop_gqa_mha_equivalence(
            seq_len in 1usize..3,
            d_k in 1usize..3,
            d_v in 1usize..3,
            num_heads in 1usize..4,
        ) {
            // One KV head per query head, so every head is independent.
            let shape = Shape {
                seq_len,
                d_k,
                d_v,
                num_heads,
                num_kv_heads: num_heads,
            };

            let q = sequential_floats(shape.q_len(), 0.1);
            let k = sequential_floats(shape.k_len(), 0.15);
            let v = sequential_floats(shape.v_len(), 0.2);
            let output = run_scalar(&q, &k, &v, shape);

            // Compare each head against its own single-head reference.
            let o_len = shape.vo_head_len();
            for h in 0..shape.num_heads {
                let expected = reference_head(&q, &k, &v, shape, h, h);
                let actual = &output[h * o_len..(h + 1) * o_len];

                for (idx, (&got, &want)) in actual.iter().zip(expected.iter()).enumerate() {
                    let diff = (got - want).abs();
                    prop_assert!(
                        diff < 1e-5,
                        "head {h} idx {idx}: expected {} got {} (diff {diff})",
                        want, got
                    );
                }
            }
        }
    }

    // ── AVX2 parity test ────────────────────────────────────────────────

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_gqa_avx2_parity() {
        // Nothing to compare on machines without AVX2.
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let shape = Shape {
            seq_len: 3,
            d_k: 4,
            d_v: 2,
            num_heads: 4,
            num_kv_heads: 2,
        };

        let q = sequential_floats(shape.q_len(), 0.1);
        let k = sequential_floats(shape.k_len(), 0.2);
        let v = sequential_floats(shape.v_len(), 0.15);

        let scalar_out = run_scalar(&q, &k, &v, shape);

        let mut avx2_out = vec![0.0f32; shape.out_len()];
        // SAFETY: AVX2 availability was verified at the top of this test.
        unsafe {
            gqa_avx2(
                &q,
                &k,
                &v,
                shape.seq_len,
                shape.d_k,
                shape.d_v,
                shape.num_heads,
                shape.num_kv_heads,
                &mut avx2_out,
            );
        }

        assert_ulp_eq(&scalar_out, &avx2_out, 8);
    }

    // ── PTX structural tests ────────────────────────────────────────────

    #[test]
    fn test_gqa_ptx_structure() {
        let ptx = gqa_ptx();

        // Substrings the PTX source must contain, with the failure message
        // reported when one is absent. Checked in this order.
        let required = [
            (".version 8.5", "missing PTX version"),
            (".target sm_90", "missing PTX target"),
            (".entry gqa_kernel", "missing entry point"),
            ("ret;", "missing ret instruction"),
            (".shared", "missing shared memory declaration"),
            ("bar.sync", "missing barrier synchronization"),
            ("div.u32", "missing integer division for head mapping"),
            ("ex2.approx.f32", "missing exp approximation"),
        ];
        for (needle, why) in required {
            assert!(ptx.contains(needle), "{why}");
        }

        // Every opening brace needs a matching closing brace.
        let open = ptx.matches('{').count();
        let close = ptx.matches('}').count();
        assert_eq!(
            open, close,
            "unbalanced braces: {open} open vs {close} close"
        );
    }

    #[test]
    fn test_gqa_ptx_nonempty() {
        let ptx = gqa_ptx();
        assert!(!ptx.is_empty());
    }
}