1use super::ops;
12
13fn single_head_attention(
19 q_head: &[f32],
20 k_head: &[f32],
21 v_head: &[f32],
22 seq_len: usize,
23 d_k: usize,
24 d_v: usize,
25 output: &mut [f32],
26) {
27 let mut scores = vec![0.0f32; seq_len * seq_len];
29 ops::score_matrix(q_head, k_head, seq_len, seq_len, d_k, &mut scores);
30
31 ops::softmax_rows(&mut scores, seq_len, seq_len);
33
34 ops::matmul_sv(&scores, v_head, seq_len, seq_len, d_v, output);
36}
37
38pub fn gqa_scalar(
57 q: &[f32],
58 k: &[f32],
59 v: &[f32],
60 seq_len: usize,
61 d_k: usize,
62 d_v: usize,
63 num_heads: usize,
64 num_kv_heads: usize,
65 output: &mut [f32],
66) {
67 assert!(
68 num_kv_heads > 0 && num_heads % num_kv_heads == 0,
69 "num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
70 );
71 let q_total = num_heads * seq_len * d_k;
72 let k_total = num_kv_heads * seq_len * d_k;
73 let v_total = num_kv_heads * seq_len * d_v;
74 let o_total = num_heads * seq_len * d_v;
75 assert_eq!(
76 q.len(),
77 q_total,
78 "Q dimension mismatch: expected {q_total} got {}",
79 q.len()
80 );
81 assert_eq!(
82 k.len(),
83 k_total,
84 "K dimension mismatch: expected {k_total} got {}",
85 k.len()
86 );
87 assert_eq!(
88 v.len(),
89 v_total,
90 "V dimension mismatch: expected {v_total} got {}",
91 v.len()
92 );
93 assert_eq!(
94 output.len(),
95 o_total,
96 "output dimension mismatch: expected {o_total} got {}",
97 output.len()
98 );
99
100 let heads_per_kv = num_heads / num_kv_heads;
101 let q_head_stride = seq_len * d_k;
102 let k_head_stride = seq_len * d_k;
103 let v_head_stride = seq_len * d_v;
104 let o_head_stride = seq_len * d_v;
105
106 for h in 0..num_heads {
107 let kv_head = h / heads_per_kv;
108
109 let q_start = h * q_head_stride;
110 let k_start = kv_head * k_head_stride;
111 let v_start = kv_head * v_head_stride;
112 let o_start = h * o_head_stride;
113
114 let q_head = &q[q_start..q_start + q_head_stride];
115 let k_head = &k[k_start..k_start + k_head_stride];
116 let v_head = &v[v_start..v_start + v_head_stride];
117 let o_head = &mut output[o_start..o_start + o_head_stride];
118
119 single_head_attention(q_head, k_head, v_head, seq_len, d_k, d_v, o_head);
120 }
121}
122
123#[cfg(target_arch = "x86_64")]
135#[target_feature(enable = "avx2")]
136pub unsafe fn gqa_avx2(
137 q: &[f32],
138 k: &[f32],
139 v: &[f32],
140 seq_len: usize,
141 d_k: usize,
142 d_v: usize,
143 num_heads: usize,
144 num_kv_heads: usize,
145 output: &mut [f32],
146) {
147 gqa_scalar(q, k, v, seq_len, d_k, d_v, num_heads, num_kv_heads, output);
148}
149
// Textually includes the PTX (NVIDIA GPU assembly) source for this kernel from a
// sibling file. NOTE(review): `gqa_ptx()`, which the tests call, is presumably
// defined in that file — confirm.
include!("gqa_ptx.rs");
151
#[cfg(test)]
mod tests {
    use super::super::ops::sequential_floats;
    use super::super::ulp::assert_ulp_eq;
    use super::*;
    use proptest::prelude::*;

    /// With as many K/V heads as query heads, GQA must reduce to independent
    /// per-head attention, bit for bit.
    #[test]
    fn test_gqa_equals_mha_when_heads_match() {
        let seq_len = 2;
        let d_k = 3;
        let d_v = 2;
        let num_heads = 2;
        let num_kv_heads = 2;

        let q = sequential_floats(num_heads * seq_len * d_k, 0.1);
        let k = sequential_floats(num_kv_heads * seq_len * d_k, 0.15);
        let v = sequential_floats(num_kv_heads * seq_len * d_v, 0.2);
        let mut output = vec![0.0f32; num_heads * seq_len * d_v];

        gqa_scalar(
            &q,
            &k,
            &v,
            seq_len,
            d_k,
            d_v,
            num_heads,
            num_kv_heads,
            &mut output,
        );

        for h in 0..num_heads {
            // One-to-one mapping: query head h uses K/V head h.
            let q_start = h * seq_len * d_k;
            let k_start = h * seq_len * d_k;
            let v_start = h * seq_len * d_v;
            let o_start = h * seq_len * d_v;

            let mut expected = vec![0.0f32; seq_len * d_v];
            single_head_attention(
                &q[q_start..q_start + seq_len * d_k],
                &k[k_start..k_start + seq_len * d_k],
                &v[v_start..v_start + seq_len * d_v],
                seq_len,
                d_k,
                d_v,
                &mut expected,
            );

            assert_ulp_eq(&output[o_start..o_start + seq_len * d_v], &expected, 0);
        }
    }

    /// Four query heads over two K/V heads: heads 0-1 must use K/V head 0 and
    /// heads 2-3 must use K/V head 1.
    #[test]
    fn test_gqa_kv_broadcasting() {
        let seq_len = 2;
        let d_k = 2;
        let d_v = 2;
        let num_heads = 4;
        let num_kv_heads = 2;

        let q = sequential_floats(num_heads * seq_len * d_k, 0.1);
        let k = sequential_floats(num_kv_heads * seq_len * d_k, 0.2);
        let v = sequential_floats(num_kv_heads * seq_len * d_v, 0.15);
        let mut output = vec![0.0f32; num_heads * seq_len * d_v];

        gqa_scalar(
            &q,
            &k,
            &v,
            seq_len,
            d_k,
            d_v,
            num_heads,
            num_kv_heads,
            &mut output,
        );

        let head_stride_o = seq_len * d_v;

        // Group 0: query heads 0 and 1 share K/V head 0.
        let mut head0_ref = vec![0.0f32; seq_len * d_v];
        let mut head1_ref = vec![0.0f32; seq_len * d_v];
        single_head_attention(
            &q[0..seq_len * d_k],
            &k[0..seq_len * d_k],
            &v[0..seq_len * d_v],
            seq_len,
            d_k,
            d_v,
            &mut head0_ref,
        );
        single_head_attention(
            &q[seq_len * d_k..2 * seq_len * d_k],
            &k[0..seq_len * d_k],
            &v[0..seq_len * d_v],
            seq_len,
            d_k,
            d_v,
            &mut head1_ref,
        );

        assert_ulp_eq(&output[0..head_stride_o], &head0_ref, 0);
        assert_ulp_eq(&output[head_stride_o..2 * head_stride_o], &head1_ref, 0);

        // Group 1: query heads 2 and 3 share K/V head 1.
        let mut head2_ref = vec![0.0f32; seq_len * d_v];
        let mut head3_ref = vec![0.0f32; seq_len * d_v];
        single_head_attention(
            &q[2 * seq_len * d_k..3 * seq_len * d_k],
            &k[seq_len * d_k..2 * seq_len * d_k],
            &v[seq_len * d_v..2 * seq_len * d_v],
            seq_len,
            d_k,
            d_v,
            &mut head2_ref,
        );
        single_head_attention(
            &q[3 * seq_len * d_k..4 * seq_len * d_k],
            &k[seq_len * d_k..2 * seq_len * d_k],
            &v[seq_len * d_v..2 * seq_len * d_v],
            seq_len,
            d_k,
            d_v,
            &mut head3_ref,
        );

        assert_ulp_eq(&output[2 * head_stride_o..3 * head_stride_o], &head2_ref, 0);
        assert_ulp_eq(&output[3 * head_stride_o..4 * head_stride_o], &head3_ref, 0);
    }

    /// With a single position the softmax weight is 1, so the output is `v`.
    #[test]
    fn test_gqa_single_head_single_pos() {
        let seq_len = 1;
        let d_k = 2;
        let d_v = 3;
        let num_heads = 1;
        let num_kv_heads = 1;

        let q = vec![1.0, 0.5];
        let k = vec![0.5, 1.0];
        let v = vec![2.0, 3.0, 4.0];
        let mut output = vec![0.0f32; d_v];

        gqa_scalar(
            &q,
            &k,
            &v,
            seq_len,
            d_k,
            d_v,
            num_heads,
            num_kv_heads,
            &mut output,
        );

        assert_ulp_eq(&output, &v, 0);
    }

    /// `num_heads` not a multiple of `num_kv_heads` is rejected.
    #[test]
    #[should_panic(expected = "must be divisible")]
    fn test_gqa_bad_head_ratio() {
        let mut output = vec![0.0f32; 4];
        gqa_scalar(&[0.0; 6], &[0.0; 4], &[0.0; 4], 1, 2, 2, 3, 2, &mut output);
    }

    /// `num_kv_heads == 0` is rejected by the same assertion instead of
    /// dividing by zero.
    #[test]
    #[should_panic(expected = "must be divisible")]
    fn test_gqa_zero_kv_heads() {
        let mut output = vec![0.0f32; 4];
        gqa_scalar(&[0.0; 4], &[], &[], 1, 2, 2, 2, 0, &mut output);
    }

    #[test]
    #[should_panic(expected = "Q dimension mismatch")]
    fn test_gqa_bad_q_dim() {
        let mut output = vec![0.0f32; 4];
        gqa_scalar(&[0.0; 3], &[0.0; 2], &[0.0; 2], 1, 2, 2, 2, 2, &mut output);
    }

    #[test]
    #[should_panic(expected = "K dimension mismatch")]
    fn test_gqa_bad_k_dim() {
        let mut output = vec![0.0f32; 4];
        gqa_scalar(&[0.0; 4], &[0.0; 3], &[0.0; 4], 1, 2, 2, 2, 2, &mut output);
    }

    #[test]
    #[should_panic(expected = "V dimension mismatch")]
    fn test_gqa_bad_v_dim() {
        let mut output = vec![0.0f32; 4];
        gqa_scalar(&[0.0; 4], &[0.0; 4], &[0.0; 3], 1, 2, 2, 2, 2, &mut output);
    }

    #[test]
    #[should_panic(expected = "output dimension mismatch")]
    fn test_gqa_bad_output_dim() {
        let mut output = vec![0.0f32; 3];
        gqa_scalar(&[0.0; 4], &[0.0; 4], &[0.0; 4], 1, 2, 2, 2, 2, &mut output);
    }

    proptest! {
        // Small shapes with a 4:2 head ratio must never produce NaN/inf.
        #[test]
        fn prop_gqa_output_finite(
            seq_len in 1usize..3,
            d_k in 1usize..4,
            d_v in 1usize..4,
        ) {
            let num_heads = 4usize;
            let num_kv_heads = 2usize;

            let q = sequential_floats(num_heads * seq_len * d_k, 0.1);
            let k = sequential_floats(num_kv_heads * seq_len * d_k, 0.1);
            let v = sequential_floats(num_kv_heads * seq_len * d_v, 0.1);
            let mut output = vec![0.0f32; num_heads * seq_len * d_v];

            gqa_scalar(&q, &k, &v, seq_len, d_k, d_v, num_heads, num_kv_heads, &mut output);

            for (idx, &val) in output.iter().enumerate() {
                prop_assert!(val.is_finite(), "output[{idx}] = {val} is not finite");
            }
        }

        // num_kv_heads == num_heads: each head matches independent attention.
        #[test]
        fn prop_gqa_mha_equivalence(
            seq_len in 1usize..3,
            d_k in 1usize..3,
            d_v in 1usize..3,
            num_heads in 1usize..4,
        ) {
            let num_kv_heads = num_heads;
            let q = sequential_floats(num_heads * seq_len * d_k, 0.1);
            let k = sequential_floats(num_kv_heads * seq_len * d_k, 0.15);
            let v = sequential_floats(num_kv_heads * seq_len * d_v, 0.2);
            let mut output = vec![0.0f32; num_heads * seq_len * d_v];

            gqa_scalar(&q, &k, &v, seq_len, d_k, d_v, num_heads, num_kv_heads, &mut output);

            for h in 0..num_heads {
                let q_start = h * seq_len * d_k;
                let k_start = h * seq_len * d_k;
                let v_start = h * seq_len * d_v;
                let o_start = h * seq_len * d_v;
                let o_len = seq_len * d_v;

                let mut expected = vec![0.0f32; o_len];
                single_head_attention(
                    &q[q_start..q_start + seq_len * d_k],
                    &k[k_start..k_start + seq_len * d_k],
                    &v[v_start..v_start + seq_len * d_v],
                    seq_len, d_k, d_v, &mut expected,
                );

                for idx in 0..o_len {
                    let diff = (output[o_start + idx] - expected[idx]).abs();
                    prop_assert!(
                        diff < 1e-5,
                        "head {h} idx {idx}: expected {} got {} (diff {diff})",
                        expected[idx], output[o_start + idx]
                    );
                }
            }
        }

        // Arbitrary group sizes: query head h must use K/V head h / group.
        #[test]
        fn prop_gqa_group_broadcasting(
            seq_len in 1usize..3,
            d_k in 1usize..3,
            d_v in 1usize..3,
            num_kv_heads in 1usize..3,
            group in 1usize..4,
        ) {
            let num_heads = num_kv_heads * group;
            let q = sequential_floats(num_heads * seq_len * d_k, 0.1);
            let k = sequential_floats(num_kv_heads * seq_len * d_k, 0.15);
            let v = sequential_floats(num_kv_heads * seq_len * d_v, 0.2);
            let mut output = vec![0.0f32; num_heads * seq_len * d_v];

            gqa_scalar(&q, &k, &v, seq_len, d_k, d_v, num_heads, num_kv_heads, &mut output);

            let o_len = seq_len * d_v;
            for h in 0..num_heads {
                let kv = h / group;
                let mut expected = vec![0.0f32; o_len];
                single_head_attention(
                    &q[h * seq_len * d_k..(h + 1) * seq_len * d_k],
                    &k[kv * seq_len * d_k..(kv + 1) * seq_len * d_k],
                    &v[kv * seq_len * d_v..(kv + 1) * seq_len * d_v],
                    seq_len, d_k, d_v, &mut expected,
                );

                for idx in 0..o_len {
                    let got = output[h * o_len + idx];
                    let diff = (got - expected[idx]).abs();
                    prop_assert!(
                        diff < 1e-5,
                        "head {h} (kv {kv}) idx {idx}: expected {} got {got} (diff {diff})",
                        expected[idx]
                    );
                }
            }
        }
    }

    /// The AVX2 entry point must agree with the scalar path (skipped when the
    /// host CPU lacks AVX2).
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_gqa_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let seq_len = 3;
        let d_k = 4;
        let d_v = 2;
        let num_heads = 4;
        let num_kv_heads = 2;

        let q = sequential_floats(num_heads * seq_len * d_k, 0.1);
        let k = sequential_floats(num_kv_heads * seq_len * d_k, 0.2);
        let v = sequential_floats(num_kv_heads * seq_len * d_v, 0.15);

        let mut scalar_out = vec![0.0f32; num_heads * seq_len * d_v];
        let mut avx2_out = vec![0.0f32; num_heads * seq_len * d_v];

        gqa_scalar(
            &q,
            &k,
            &v,
            seq_len,
            d_k,
            d_v,
            num_heads,
            num_kv_heads,
            &mut scalar_out,
        );
        // SAFETY: AVX2 support was verified by `is_x86_feature_detected!` above.
        unsafe {
            gqa_avx2(
                &q,
                &k,
                &v,
                seq_len,
                d_k,
                d_v,
                num_heads,
                num_kv_heads,
                &mut avx2_out,
            );
        }

        assert_ulp_eq(&scalar_out, &avx2_out, 8);
    }

    /// Structural smoke test of the generated PTX text.
    #[test]
    fn test_gqa_ptx_structure() {
        let ptx = gqa_ptx();
        assert!(ptx.contains(".version 8.5"), "missing PTX version");
        assert!(ptx.contains(".target sm_90"), "missing PTX target");
        assert!(ptx.contains(".entry gqa_kernel"), "missing entry point");
        assert!(ptx.contains("ret;"), "missing ret instruction");
        assert!(ptx.contains(".shared"), "missing shared memory declaration");
        assert!(ptx.contains("bar.sync"), "missing barrier synchronization");
        assert!(
            ptx.contains("div.u32"),
            "missing integer division for head mapping"
        );
        assert!(ptx.contains("ex2.approx.f32"), "missing exp approximation");
        let open = ptx.matches('{').count();
        let close = ptx.matches('}').count();
        assert_eq!(
            open, close,
            "unbalanced braces: {open} open vs {close} close"
        );
    }

    #[test]
    fn test_gqa_ptx_nonempty() {
        assert!(!gqa_ptx().is_empty());
    }
}