1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
//! FlashAttention-Style Tiled Attention Kernel
//!
//! Implements IO-aware attention per Dao et al. [16]
//! Never materializes the full N×N attention matrix.
//!
//! Standard Attention: O(N²) memory for S = Q × K^T
//! FlashAttention: O(N × d) memory using online softmax
//!
//! ## Variants
//!
//! - **Standard**: FP32 serial dot product (baseline, ~79ms/token)
//! - **Tensor Core**: FP16 WMMA for Q×K^T (target: <2ms/token, ~40x speedup)
//!
//! ## Performance (RTX 4090, seq_len=2048, head_dim=128)
//!
//! | Variant | Time/token | Throughput |
//! |-------------|------------|------------|
//! | Standard | 79ms | 12.6 tok/s |
//! | Tensor Core | ~2ms | ~500 tok/s |
//! FlashAttention kernel for SRAM-bound tiled attention.
//!
//! This module implements the original FlashAttention algorithm (Dao et al.)
//! which tiles computation to fit in GPU SRAM, avoiding materialization
//! of the full N×N attention matrix.
#![allow(clippy::similar_names)]
#![allow(clippy::too_many_lines)]
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory};
use crate::ptx::{PtxKernel, PtxReg, PtxType, WmmaLayout};
/// FlashAttention-style kernel configuration
#[derive(Debug, Clone)]
pub struct AttentionKernel {
/// Sequence length (N)
pub seq_len: u32,
/// Head dimension (d)
pub head_dim: u32,
/// Tile size for Q (B_r)
pub tile_q: u32,
/// Tile size for KV (B_c)
pub tile_kv: u32,
/// Scaling factor for attention scores (1/sqrt(d))
pub scale: f32,
/// Use causal masking (for autoregressive models)
pub causal: bool,
/// Use Tensor Cores for Q×K^T (FP16 WMMA, requires sm_70+)
pub use_tensor_cores: bool,
}
impl AttentionKernel {
/// Create a new attention kernel
///
/// Tile sizes are auto-clamped to not exceed seq_len and head_dim
/// to handle small inputs gracefully.
#[must_use]
pub fn new(seq_len: u32, head_dim: u32) -> Self {
// Auto-clamp tile sizes to input dimensions
// Default tiles: 64, but reduce if inputs are smaller
let tile_q = seq_len.min(64);
// GH-5 FIX: Ensure tile_kv >= head_dim to prevent shared memory overflow
// The K dot product loop accesses K[local_col * head_dim + d_idx], which requires
// head_dim rows in K tile. Without this, we overflow when tile_kv < head_dim.
let tile_kv = seq_len.min(64).max(head_dim);
Self {
seq_len,
head_dim,
tile_q,
tile_kv,
scale: 1.0 / (head_dim as f32).sqrt(),
causal: false,
use_tensor_cores: false,
}
}
/// Create Tensor Core attention kernel (highest performance)
///
/// Uses FP16 WMMA for Q×K^T computation, achieving ~40x speedup over FP32.
/// Requires sm_70+ (Volta or later). Dimensions should be multiples of 16.
///
/// # Performance
///
/// - Standard FP32: ~79ms/token (12.6 tok/s)
/// - Tensor Core FP16: ~2ms/token (500+ tok/s)
#[must_use]
pub fn tensor_core(seq_len: u32, head_dim: u32) -> Self {
// For Tensor Cores, use tile sizes that are multiples of 16
let tile_q = seq_len.clamp(16, 64);
// GH-5 FIX: Ensure tile_kv >= head_dim to prevent shared memory overflow
// Same issue as new() - K dot product requires head_dim rows in K tile.
let tile_kv = seq_len.clamp(16, 64).max(head_dim);
Self {
seq_len,
head_dim,
tile_q,
tile_kv,
scale: 1.0 / (head_dim as f32).sqrt(),
causal: false,
use_tensor_cores: true,
}
}
/// Set tile sizes for Q and KV
#[must_use]
pub const fn with_tiles(mut self, tile_q: u32, tile_kv: u32) -> Self {
self.tile_q = tile_q;
self.tile_kv = tile_kv;
self
}
/// Enable causal masking for autoregressive attention
#[must_use]
pub const fn with_causal(mut self) -> Self {
self.causal = true;
self
}
/// Set custom scale factor
#[must_use]
pub const fn with_scale(mut self, scale: f32) -> Self {
self.scale = scale;
self
}
/// Enable Tensor Core acceleration for Q×K^T computation
///
/// Uses FP16 WMMA instructions for ~40x speedup on sm_70+ GPUs.
#[must_use]
pub const fn with_tensor_cores(mut self) -> Self {
self.use_tensor_cores = true;
self
}
}
impl Kernel for AttentionKernel {
fn name(&self) -> &str {
match (self.use_tensor_cores, self.causal) {
(true, true) => "flash_attention_tensor_core_causal",
(true, false) => "flash_attention_tensor_core",
(false, true) => "flash_attention_causal",
(false, false) => "flash_attention",
}
}
fn build_ptx(&self) -> PtxKernel {
if self.use_tensor_cores {
self.build_tensor_core_attention()
} else {
self.build_flash_attention()
}
}
}
impl AttentionKernel {
fn build_flash_attention(&self) -> PtxKernel {
// FlashAttention-style tiled attention
// Per Dao et al. - never materialize full N×N matrix
let head_dim = self.head_dim;
let tile_q = self.tile_q;
let tile_kv = self.tile_kv;
let scale = self.scale;
let causal = self.causal;
// Shared memory for Q, K, V tiles
let smem_size = (tile_q * head_dim + tile_kv * head_dim * 2) * 4;
let kernel_name = if causal {
"flash_attention_causal"
} else {
"flash_attention"
};
PtxKernel::new(kernel_name)
.param(PtxType::U64, "q_ptr")
.param(PtxType::U64, "k_ptr")
.param(PtxType::U64, "v_ptr")
.param(PtxType::U64, "o_ptr")
.param(PtxType::U32, "seq_len")
.param(PtxType::U32, "head_dim")
.param(PtxType::U32, "num_heads")
.shared_memory(smem_size as usize)
.build(|ctx| {
// Thread and block indices
let tid = ctx.special_reg(PtxReg::TidX);
let ctaid_x = ctx.special_reg(PtxReg::CtaIdX);
let ctaid_y = ctx.special_reg(PtxReg::CtaIdY);
let _ntid = ctx.special_reg(PtxReg::NtidX);
// Load parameters
let seq_len_param = ctx.load_param_u32("seq_len");
let head_dim_param = ctx.load_param_u32("head_dim");
let num_heads = ctx.load_param_u32("num_heads");
let q_ptr = ctx.load_param_u64("q_ptr");
let k_ptr = ctx.load_param_u64("k_ptr");
let v_ptr = ctx.load_param_u64("v_ptr");
let o_ptr = ctx.load_param_u64("o_ptr");
// ctaid_x = Q block index, ctaid_y = head index
let q_block = ctaid_x;
let head_idx = ctaid_y;
// PARITY-114 FIX: Compute predicate but DON'T exit early
// All threads must participate in barriers
let head_valid = ctx.setp_lt_u32(head_idx, num_heads);
// Calculate head offset (head_idx * seq_len * head_dim)
let head_stride = ctx.mul_u32_reg(seq_len_param, head_dim_param);
let head_offset = ctx.mul_wide_u32_reg(head_idx, head_stride);
let head_offset_bytes = ctx.mul_u64(head_offset, 4);
// Q tile base address for this block
let tile_q_imm = ctx.mov_u32_imm(tile_q);
let q_row_start = ctx.mul_u32_reg(q_block, tile_q_imm);
let q_tile_offset = ctx.mul_wide_u32_reg(q_row_start, head_dim_param);
let q_tile_offset_bytes = ctx.mul_u64(q_tile_offset, 4);
let q_base = ctx.add_u64(q_ptr, head_offset_bytes);
let q_tile_base = ctx.add_u64(q_base, q_tile_offset_bytes);
// ===== Initialize output accumulator and softmax stats =====
// Each thread handles one position in the output
let local_row = ctx.div_u32(tid, head_dim);
let local_col = ctx.rem_u32(tid, head_dim);
// PARITY-114 FIX: Compute predicate but DON'T exit early
// This handles launch configs with more threads than tile_q * head_dim
let tile_q_check = ctx.mov_u32_imm(tile_q);
let thread_valid = ctx.setp_lt_u32(local_row, tile_q_check);
// Initialize output accumulator to 0
let o_acc = ctx.mov_f32_imm(0.0);
// Running max for online softmax
let m_prev = ctx.mov_f32_imm(f32::NEG_INFINITY);
// Running sum of exp
let l_prev = ctx.mov_f32_imm(0.0);
// Calculate number of KV blocks
let tile_kv_imm = ctx.mov_u32_imm(tile_kv);
let num_kv_blocks = ctx.div_u32(seq_len_param, tile_kv);
// ===== Pre-compute element offset (needed for output store after loop) =====
// This must be computed BEFORE the loop, not inside it
let local_row_64 = ctx.cvt_u64_u32(local_row);
let local_col_64 = ctx.cvt_u64_u32(local_col);
let head_dim_64 = ctx.cvt_u64_u32(head_dim_param);
let q_elem_offset = ctx.mul_u64_reg(local_row_64, head_dim_64);
let q_elem_offset_full = ctx.add_u64(q_elem_offset, local_col_64);
let q_elem_offset_bytes = ctx.mul_u64(q_elem_offset_full, 4);
// Shared memory base addresses (32-bit constants for shared memory addressing)
let k_smem_base = tile_q * head_dim * 4;
let v_smem_base = (tile_q * head_dim + tile_kv * head_dim) * 4;
// Loop counter
let kv_block = ctx.mov_u32_imm(0);
ctx.label("kv_loop_start");
// Check if we've processed all KV blocks
let kv_done = ctx.setp_ge_u32(kv_block, num_kv_blocks);
ctx.branch_if(kv_done, "kv_loop_end");
// Causal masking: skip KV blocks that are entirely after current Q block
// Use setp_lt and flip: if q_block < kv_block, skip
if causal {
let causal_skip = ctx.setp_lt_u32(q_block, kv_block);
ctx.branch_if(causal_skip, "kv_loop_end");
}
// Calculate K, V tile base addresses
let kv_row_start = ctx.mul_u32_reg(kv_block, tile_kv_imm);
let kv_tile_offset = ctx.mul_wide_u32_reg(kv_row_start, head_dim_param);
let kv_tile_offset_bytes = ctx.mul_u64(kv_tile_offset, 4);
let k_base = ctx.add_u64(k_ptr, head_offset_bytes);
let k_tile_base = ctx.add_u64(k_base, kv_tile_offset_bytes);
let v_base = ctx.add_u64(v_ptr, head_offset_bytes);
let v_tile_base = ctx.add_u64(v_base, kv_tile_offset_bytes);
// ===== Load Q tile to shared memory =====
// Each thread loads one element (using pre-computed q_elem_offset_bytes)
let q_addr = ctx.add_u64(q_tile_base, q_elem_offset_bytes);
let q_val = ctx.ld_global_f32(q_addr);
// Use 32-bit addressing for shared memory (not 64-bit)
let q_smem_offset = ctx.mul_u32(tid, 4);
ctx.st_shared_f32(q_smem_offset, q_val);
// ===== Load K tile to shared memory =====
// (k_smem_base computed before loop, use 32-bit)
let k_addr = ctx.add_u64(k_tile_base, q_elem_offset_bytes);
let k_val = ctx.ld_global_f32(k_addr);
let k_smem_base_u32 = ctx.mov_u32_imm(k_smem_base);
let k_smem_offset_local = ctx.mul_u32(tid, 4);
let k_smem_offset = ctx.add_u32_reg(k_smem_base_u32, k_smem_offset_local);
ctx.st_shared_f32(k_smem_offset, k_val);
// ===== Load V tile to shared memory =====
// (v_smem_base computed before loop, use 32-bit)
let v_addr = ctx.add_u64(v_tile_base, q_elem_offset_bytes);
let v_val = ctx.ld_global_f32(v_addr);
let v_smem_base_u32 = ctx.mov_u32_imm(v_smem_base);
let v_smem_offset_local = ctx.mul_u32(tid, 4);
let v_smem_offset = ctx.add_u32_reg(v_smem_base_u32, v_smem_offset_local);
ctx.st_shared_f32(v_smem_offset, v_val);
ctx.bar_sync(0);
// ===== Compute S = Q × K^T (dot product for this thread's row) =====
// Each thread computes attention score for its Q row
let s_acc = ctx.mov_f32_imm(0.0);
// Inner loop over head_dim elements
let d_idx = ctx.mov_u32_imm(0);
ctx.label("dot_loop_start");
let d_done = ctx.setp_ge_u32(d_idx, head_dim_param);
ctx.branch_if(d_done, "dot_loop_end");
// Load Q[local_row, d_idx] from shared memory (32-bit addressing)
let head_dim_u32 = ctx.mov_u32_imm(head_dim);
let q_row_offset = ctx.mul_u32_reg(local_row, head_dim_u32);
let q_elem_smem = ctx.add_u32_reg(q_row_offset, d_idx);
let q_elem_smem_bytes = ctx.mul_u32(q_elem_smem, 4);
let q_dot_val = ctx.ld_shared_f32(q_elem_smem_bytes);
// Load K[local_col, d_idx] from shared memory (32-bit addressing, K is transposed conceptually)
let k_col_offset = ctx.mul_u32_reg(local_col, head_dim_u32);
let k_elem_smem = ctx.add_u32_reg(k_col_offset, d_idx);
let k_elem_smem_bytes = ctx.mul_u32(k_elem_smem, 4);
let k_smem_base_loop = ctx.mov_u32_imm(k_smem_base);
let k_elem_smem_full = ctx.add_u32_reg(k_smem_base_loop, k_elem_smem_bytes);
let k_dot_val = ctx.ld_shared_f32(k_elem_smem_full);
// Accumulate Q[i,d] * K[j,d] - IN-PLACE UPDATE
ctx.fma_f32_inplace(s_acc, q_dot_val, k_dot_val);
// Increment and loop back - IN-PLACE UPDATE
ctx.add_u32_inplace(d_idx, 1);
ctx.branch("dot_loop_start");
ctx.label("dot_loop_end");
// ===== Apply scale factor =====
let scale_reg = ctx.mov_f32_imm(scale);
let s_scaled = ctx.mul_f32(s_acc, scale_reg);
// ===== Online softmax update =====
// m_new = max(m_prev, s_scaled)
let m_new = ctx.max_f32(m_prev, s_scaled);
// scale_factor = exp(m_prev - m_new)
let m_diff = ctx.sub_f32(m_prev, m_new);
let log2_e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
let m_diff_scaled = ctx.mul_f32(m_diff, log2_e);
let scale_factor = ctx.ex2_f32(m_diff_scaled);
// p = exp(s_scaled - m_new)
let s_shifted = ctx.sub_f32(s_scaled, m_new);
let s_shifted_scaled = ctx.mul_f32(s_shifted, log2_e);
let p_val = ctx.ex2_f32(s_shifted_scaled);
// l_new = scale_factor * l_prev + p
let l_scaled = ctx.mul_f32(scale_factor, l_prev);
let l_new = ctx.add_f32(l_scaled, p_val);
// ===== Update output accumulator =====
// o_new = (scale_factor * l_prev * o_prev + p * v) / l_new
let o_scaled = ctx.mul_f32(o_acc, scale_factor);
let o_weighted = ctx.mul_f32(o_scaled, l_prev);
// Load V value from shared memory (32-bit addressing)
let v_elem_smem_bytes = ctx.mul_u32(k_elem_smem, 4); // Reuse k offset calculation
let v_smem_base_loop = ctx.mov_u32_imm(v_smem_base);
let v_elem_smem_full = ctx.add_u32_reg(v_smem_base_loop, v_elem_smem_bytes);
let v_out_val = ctx.ld_shared_f32(v_elem_smem_full);
let pv = ctx.mul_f32(p_val, v_out_val);
let o_sum = ctx.add_f32(o_weighted, pv);
let o_new = ctx.div_f32(o_sum, l_new);
// Update running stats for next iteration - COPY TO ACCUMULATORS
ctx.mov_f32_reg(m_prev, m_new);
ctx.mov_f32_reg(l_prev, l_new);
ctx.mov_f32_reg(o_acc, o_new);
ctx.bar_sync(1);
// Increment KV block counter and loop back - IN-PLACE UPDATE
ctx.add_u32_inplace(kv_block, 1);
ctx.branch("kv_loop_start");
ctx.label("kv_loop_end");
// PARITY-114 FIX: Bounds check HERE (after all threads finished barriers)
// Only threads with valid output coordinates store to O
ctx.branch_if_not(head_valid, "exit");
ctx.branch_if_not(thread_valid, "exit");
// ===== Store output =====
// Calculate output address
let o_base = ctx.add_u64(o_ptr, head_offset_bytes);
let o_tile_offset = ctx.mul_wide_u32_reg(q_row_start, head_dim_param);
let o_tile_offset_bytes = ctx.mul_u64(o_tile_offset, 4);
let o_tile_base = ctx.add_u64(o_base, o_tile_offset_bytes);
let o_addr = ctx.add_u64(o_tile_base, q_elem_offset_bytes);
// Store accumulated output (o_acc is always valid, even if loop never ran)
ctx.st_global_f32(o_addr, o_acc);
ctx.label("exit");
ctx.ret();
})
}
/// Build Tensor Core FlashAttention using WMMA for Q×K^T
///
/// Key optimization: Replace serial FP32 dot product with FP16 WMMA 16×16×16 tiles.
/// For head_dim=128, we need 8 WMMA operations per S[i,j] tile (128/16=8).
///
/// Algorithm:
/// 1. Load Q tile [16×head_dim] to shared memory (FP32→FP16)
/// 2. Load K tile [16×head_dim] to shared memory (FP32→FP16)
/// 3. For each 16-element chunk along head_dim:
/// - WMMA: S_acc[16×16] += Q_frag[16×16] × K_frag^T[16×16]
/// 4. Apply scale, online softmax, V multiplication
///
/// Launch config: grid_2d(seq_len/16, num_heads, 32, 1) - one warp per 16×16 Q×K tile
#[allow(clippy::too_many_lines)]
fn build_tensor_core_attention(&self) -> PtxKernel {
let head_dim = self.head_dim;
let tile_q = 16_u32; // Fixed for WMMA
let tile_kv = 16_u32; // Fixed for WMMA
let scale = self.scale;
let causal = self.causal;
// Number of WMMA steps to accumulate the full dot product
// For head_dim=128, n_k_steps = 8
let n_k_steps = (head_dim + 15) / 16;
// Shared memory layout:
// Q tile: 16 × head_dim × 2 bytes (FP16)
// K tile: 16 × head_dim × 2 bytes (FP16)
// V tile: 16 × head_dim × 4 bytes (FP32)
// S tile: 16 × 16 × 4 bytes (FP32 attention scores)
let q_smem_size = tile_q * head_dim * 2;
let k_smem_size = tile_kv * head_dim * 2;
let v_smem_size = tile_kv * head_dim * 4;
let s_smem_size = tile_q * tile_kv * 4;
let smem_size = q_smem_size + k_smem_size + v_smem_size + s_smem_size;
let kernel_name = if causal {
"flash_attention_tensor_core_causal"
} else {
"flash_attention_tensor_core"
};
PtxKernel::new(kernel_name)
.param(PtxType::U64, "q_ptr")
.param(PtxType::U64, "k_ptr")
.param(PtxType::U64, "v_ptr")
.param(PtxType::U64, "o_ptr")
.param(PtxType::U32, "seq_len")
.param(PtxType::U32, "head_dim")
.param(PtxType::U32, "num_heads")
.shared_memory(smem_size as usize)
.build(|ctx| {
// WMMA operates at warp level (32 threads cooperatively)
// Grid: (seq_len/16) x num_heads
// Block: 32 threads (1 warp)
let tid = ctx.special_reg(PtxReg::TidX);
let ctaid_x = ctx.special_reg(PtxReg::CtaIdX);
let ctaid_y = ctx.special_reg(PtxReg::CtaIdY);
// Load parameters
let seq_len_param = ctx.load_param_u32("seq_len");
let head_dim_param = ctx.load_param_u32("head_dim");
let num_heads = ctx.load_param_u32("num_heads");
let q_ptr = ctx.load_param_u64("q_ptr");
let k_ptr = ctx.load_param_u64("k_ptr");
let v_ptr = ctx.load_param_u64("v_ptr");
let o_ptr = ctx.load_param_u64("o_ptr");
// Block index determines which Q tile we're computing
let q_block = ctaid_x;
let head_idx = ctaid_y;
// PARITY-114 FIX: Compute predicate but DON'T exit early
// All threads must participate in barriers (WMMA requires full warp)
let head_valid = ctx.setp_lt_u32(head_idx, num_heads);
// Calculate head offset
let head_stride = ctx.mul_u32_reg(seq_len_param, head_dim_param);
let head_offset = ctx.mul_wide_u32_reg(head_idx, head_stride);
let head_offset_bytes = ctx.mul_u64(head_offset, 4);
// Shared memory base addresses (need actual smem pointer, not just offset)
// For regular loads/stores, u32 offset from smem[0] works
// For WMMA, we need the actual shared memory address
let smem_ptr = ctx.shared_base_addr(); // u64 pointer to smem
let q_smem_base = ctx.mov_u32_imm(0);
let k_smem_base = ctx.mov_u32_imm(q_smem_size);
let v_smem_base = ctx.mov_u32_imm(q_smem_size + k_smem_size);
let s_smem_base = ctx.mov_u32_imm(q_smem_size + k_smem_size + v_smem_size);
// Pre-compute u64 pointers for WMMA operations
let q_smem_base_64 = ctx.cvt_u64_u32(q_smem_base);
let q_smem_ptr = ctx.add_u64(smem_ptr, q_smem_base_64);
let k_smem_base_64 = ctx.cvt_u64_u32(k_smem_base);
let k_smem_ptr = ctx.add_u64(smem_ptr, k_smem_base_64);
let s_smem_base_64 = ctx.cvt_u64_u32(s_smem_base);
let s_smem_ptr = ctx.add_u64(smem_ptr, s_smem_base_64);
// Q tile base address
let tile_16 = ctx.mov_u32_imm(16);
let q_row_start = ctx.mul_u32_reg(q_block, tile_16);
let q_tile_offset = ctx.mul_wide_u32_reg(q_row_start, head_dim_param);
let q_tile_offset_bytes = ctx.mul_u64(q_tile_offset, 4);
let q_base = ctx.add_u64(q_ptr, head_offset_bytes);
let q_tile_base = ctx.add_u64(q_base, q_tile_offset_bytes);
// ===== Load Q tile to shared memory (FP32 → FP16) =====
// Each of 32 threads loads multiple elements
// Total: 16 × head_dim elements
let q_total_elems = ctx.mov_u32_imm(16 * head_dim);
let elems_per_thread = ctx.div_u32(q_total_elems, 32);
let my_start = ctx.mul_u32_reg(tid, elems_per_thread);
let load_idx = ctx.mov_u32_imm(0);
ctx.label("load_q_loop");
let load_done = ctx.setp_ge_u32(load_idx, elems_per_thread);
ctx.branch_if(load_done, "load_q_end");
let elem_idx = ctx.add_u32_reg(my_start, load_idx);
let elem_check = ctx.setp_ge_u32(elem_idx, q_total_elems);
ctx.branch_if(elem_check, "load_q_end");
// Load from global (FP32)
let q_global_offset = ctx.mul_wide_u32(elem_idx, 4);
let q_addr = ctx.add_u64(q_tile_base, q_global_offset);
let q_val_f32 = ctx.ld_global_f32(q_addr);
// Convert to FP16 and store to shared
let q_val_f16 = ctx.cvt_f16_f32(q_val_f32);
let q_smem_offset = ctx.mul_u32(elem_idx, 2);
let q_smem_addr = ctx.add_u32_reg(q_smem_base, q_smem_offset);
ctx.st_shared_f16(q_smem_addr, q_val_f16);
ctx.add_u32_inplace(load_idx, 1);
ctx.branch("load_q_loop");
ctx.label("load_q_end");
// Initialize output accumulators (16 values per thread for the output row)
// Each thread in the warp contributes to 16×16 output tile
// For simplicity, we'll use per-row accumulators
let o_acc = ctx.mov_f32_imm(0.0);
let m_prev = ctx.mov_f32_imm(f32::NEG_INFINITY);
let l_prev = ctx.mov_f32_imm(0.0);
// Number of KV blocks
let num_kv_blocks = ctx.div_u32(seq_len_param, 16);
let kv_block = ctx.mov_u32_imm(0);
ctx.label("kv_loop_start");
let kv_done = ctx.setp_ge_u32(kv_block, num_kv_blocks);
ctx.branch_if(kv_done, "kv_loop_end");
// Causal masking
if causal {
let causal_skip = ctx.setp_lt_u32(q_block, kv_block);
ctx.branch_if(causal_skip, "kv_loop_end");
}
// Calculate K, V tile base addresses
let kv_row_start = ctx.mul_u32_reg(kv_block, tile_16);
let kv_tile_offset = ctx.mul_wide_u32_reg(kv_row_start, head_dim_param);
let kv_tile_offset_bytes = ctx.mul_u64(kv_tile_offset, 4);
let k_base = ctx.add_u64(k_ptr, head_offset_bytes);
let k_tile_base = ctx.add_u64(k_base, kv_tile_offset_bytes);
let v_base = ctx.add_u64(v_ptr, head_offset_bytes);
let v_tile_base = ctx.add_u64(v_base, kv_tile_offset_bytes);
// ===== Load K tile to shared memory (FP32 → FP16) =====
let load_k_idx = ctx.mov_u32_imm(0);
ctx.label("load_k_loop");
let k_load_done = ctx.setp_ge_u32(load_k_idx, elems_per_thread);
ctx.branch_if(k_load_done, "load_k_end");
let k_elem_idx = ctx.add_u32_reg(my_start, load_k_idx);
let k_elem_check = ctx.setp_ge_u32(k_elem_idx, q_total_elems);
ctx.branch_if(k_elem_check, "load_k_end");
let k_global_offset = ctx.mul_wide_u32(k_elem_idx, 4);
let k_addr = ctx.add_u64(k_tile_base, k_global_offset);
let k_val_f32 = ctx.ld_global_f32(k_addr);
let k_val_f16 = ctx.cvt_f16_f32(k_val_f32);
let k_smem_offset = ctx.mul_u32(k_elem_idx, 2);
let k_smem_addr = ctx.add_u32_reg(k_smem_base, k_smem_offset);
ctx.st_shared_f16(k_smem_addr, k_val_f16);
ctx.add_u32_inplace(load_k_idx, 1);
ctx.branch("load_k_loop");
ctx.label("load_k_end");
// ===== Load V tile to shared memory (FP32) =====
let load_v_idx = ctx.mov_u32_imm(0);
ctx.label("load_v_loop");
let v_load_done = ctx.setp_ge_u32(load_v_idx, elems_per_thread);
ctx.branch_if(v_load_done, "load_v_end");
let v_elem_idx = ctx.add_u32_reg(my_start, load_v_idx);
let v_elem_check = ctx.setp_ge_u32(v_elem_idx, q_total_elems);
ctx.branch_if(v_elem_check, "load_v_end");
let v_global_offset = ctx.mul_wide_u32(v_elem_idx, 4);
let v_addr = ctx.add_u64(v_tile_base, v_global_offset);
let v_val = ctx.ld_global_f32(v_addr);
let v_smem_offset = ctx.mul_u32(v_elem_idx, 4);
let v_smem_addr = ctx.add_u32_reg(v_smem_base, v_smem_offset);
ctx.st_shared_f32(v_smem_addr, v_val);
ctx.add_u32_inplace(load_v_idx, 1);
ctx.branch("load_v_loop");
ctx.label("load_v_end");
ctx.bar_sync(0);
// ===== Compute S = Q × K^T using WMMA =====
// Initialize S tile accumulator to zero (8 f32 registers)
let mut frag_c = Vec::with_capacity(8);
for _ in 0..8 {
frag_c.push(ctx.mov_f32_imm(0.0));
}
// Loop over head_dim in steps of 16
let k_step = ctx.mov_u32_imm(0);
let n_k_steps_reg = ctx.mov_u32_imm(n_k_steps);
ctx.label("wmma_loop_start");
let wmma_done = ctx.setp_ge_u32(k_step, n_k_steps_reg);
ctx.branch_if(wmma_done, "wmma_loop_end");
// Q fragment address: q_smem_ptr + k_step * 16 * 2 bytes
let q_frag_offset = ctx.mul_u32(k_step, 32); // 16 elements × 2 bytes
let q_frag_offset_64 = ctx.cvt_u64_u32(q_frag_offset);
let q_frag_addr = ctx.add_u64(q_smem_ptr, q_frag_offset_64);
let frag_a = ctx.wmma_load_a_f16(q_frag_addr, head_dim, WmmaLayout::RowMajor);
// K fragment address: k_smem_ptr + k_step * 16 * 2 - needs col-major for K^T
let k_frag_offset = ctx.mul_u32(k_step, 32);
let k_frag_offset_64 = ctx.cvt_u64_u32(k_frag_offset);
let k_frag_addr = ctx.add_u64(k_smem_ptr, k_frag_offset_64);
let frag_b = ctx.wmma_load_b_f16(k_frag_addr, head_dim, WmmaLayout::ColMajor);
// WMMA MMA: D = A × B + C (accumulates into D fragment)
let frag_d = ctx.wmma_mma_f16_f32(&frag_a, &frag_b, &frag_c);
// Copy D -> C for next iteration's accumulation (8 f32 registers)
for i in 0..8 {
ctx.mov_f32_reg(frag_c[i], frag_d[i]);
}
ctx.add_u32_inplace(k_step, 1);
ctx.branch("wmma_loop_start");
ctx.label("wmma_loop_end");
// Store S tile (accumulated result in D) to shared memory for softmax
ctx.wmma_store_d_f32(s_smem_ptr, &frag_d, 16, WmmaLayout::RowMajor);
ctx.bar_sync(1);
// ===== Apply scale and online softmax =====
// Each thread handles one element of the 16×16 S tile
// Thread tid handles element (tid/16, tid%16) for tid < 256
// We have 32 threads, so each thread handles 8 elements
let s_idx = ctx.mov_u32_imm(0);
let loop_limit_8 = ctx.mov_u32_imm(8);
let elems_256 = ctx.mov_u32_imm(256);
let step_8 = ctx.mov_u32_imm(8);
ctx.label("softmax_loop_start");
let s_idx_check = ctx.setp_ge_u32(s_idx, loop_limit_8);
ctx.branch_if(s_idx_check, "softmax_loop_end");
// Calculate which S element this iteration handles
let s_elem = ctx.mad_lo_u32(tid, step_8, s_idx);
let s_elem_check = ctx.setp_ge_u32(s_elem, elems_256);
ctx.branch_if(s_elem_check, "softmax_next");
// Load S[i,j] from shared memory
let s_offset = ctx.mul_u32(s_elem, 4);
let s_addr = ctx.add_u32_reg(s_smem_base, s_offset);
let s_val = ctx.ld_shared_f32(s_addr);
// Apply scale
let scale_reg = ctx.mov_f32_imm(scale);
let s_scaled = ctx.mul_f32(s_val, scale_reg);
// Online softmax update (simplified - each thread maintains local max/sum)
let m_new = ctx.max_f32(m_prev, s_scaled);
let m_diff = ctx.sub_f32(m_prev, m_new);
let log2_e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
let m_diff_scaled = ctx.mul_f32(m_diff, log2_e);
let scale_factor = ctx.ex2_f32(m_diff_scaled);
let s_shifted = ctx.sub_f32(s_scaled, m_new);
let s_shifted_scaled = ctx.mul_f32(s_shifted, log2_e);
let p_val = ctx.ex2_f32(s_shifted_scaled);
let l_scaled = ctx.mul_f32(scale_factor, l_prev);
let l_new = ctx.add_f32(l_scaled, p_val);
// Store scaled attention weight back
ctx.st_shared_f32(s_addr, p_val);
ctx.mov_f32_reg(m_prev, m_new);
ctx.mov_f32_reg(l_prev, l_new);
ctx.label("softmax_next");
ctx.add_u32_inplace(s_idx, 1);
ctx.branch("softmax_loop_start");
ctx.label("softmax_loop_end");
ctx.bar_sync(2);
// ===== Compute O += softmax(S) × V =====
// Load attention weights and multiply with V
// This is another GEMM: P[16×16] × V[16×head_dim] = O_update[16×head_dim]
// For simplicity, we'll use the scalar path for V multiplication
// (Tensor Core V multiplication would require additional WMMA calls)
let v_col = ctx.rem_u32(tid, head_dim);
let v_row_idx = ctx.mov_u32_imm(0);
ctx.label("v_loop_start");
let v_loop_done = ctx.setp_ge_u32(v_row_idx, tile_16);
ctx.branch_if(v_loop_done, "v_loop_end");
// Load attention weight P[row, v_row_idx]
let p_idx_base = ctx.mul_u32_reg(v_row_idx, tile_16);
let p_offset = ctx.mul_u32(p_idx_base, 4);
let p_addr = ctx.add_u32_reg(s_smem_base, p_offset);
let p_weight = ctx.ld_shared_f32(p_addr);
// Load V[v_row_idx, v_col]
let v_idx = ctx.mad_lo_u32(v_row_idx, head_dim_param, v_col);
let v_offset = ctx.mul_u32(v_idx, 4);
let v_elem_addr = ctx.add_u32_reg(v_smem_base, v_offset);
let v_elem = ctx.ld_shared_f32(v_elem_addr);
// Accumulate: o_acc += p_weight * v_elem
ctx.fma_f32_inplace(o_acc, p_weight, v_elem);
ctx.add_u32_inplace(v_row_idx, 1);
ctx.branch("v_loop_start");
ctx.label("v_loop_end");
ctx.bar_sync(3);
ctx.add_u32_inplace(kv_block, 1);
ctx.branch("kv_loop_start");
ctx.label("kv_loop_end");
// PARITY-114 FIX: Bounds check HERE (after all threads finished barriers)
// Only threads with valid heads store to O
ctx.branch_if_not(head_valid, "exit");
// ===== Normalize and store output =====
let o_normalized = ctx.div_f32(o_acc, l_prev);
// Calculate output address
let o_base = ctx.add_u64(o_ptr, head_offset_bytes);
let tid_div_hd = ctx.div_u32(tid, head_dim);
let o_row = ctx.mad_lo_u32(q_block, tile_16, tid_div_hd);
let o_col = ctx.rem_u32(tid, head_dim);
let head_dim_reg = ctx.mov_u32_imm(head_dim);
let o_idx = ctx.mad_lo_u32(o_row, head_dim_reg, o_col);
let o_offset = ctx.mul_wide_u32(o_idx, 4);
let o_addr = ctx.add_u64(o_base, o_offset);
ctx.st_global_f32(o_addr, o_normalized);
ctx.label("exit");
ctx.ret();
})
}
}