baracuda-kernels-sys 0.0.1-alpha.68

Compiled bespoke .cu kernel template instantiations for the baracuda ML kernel facade plus C-ABI FFI facades for the library-backed plans (cuDNN conv/pool, cuSOLVER linalg, cuFFT/cuRAND, CUTLASS GEMM re-export). Hosts curated CUDA kernel sources (int8/FP8/int4/bin GEMM RRR, elementwise, reduce, norm, attention, …), builds them via baracuda-forge, exposes extern "C" entry points for the safe baracuda-kernels crate. CUTLASS template kernels live in the sibling baracuda-cutlass-kernels-sys crate and are re-exported here under the unified baracuda_kernels_gemm_* namespace.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
// SPDX-FileCopyrightText: 2026 Eric Evans and the baracuda contributors
// SPDX-License-Identifier: MIT OR Apache-2.0
//
// FlashDecoding — split-K parallel attention decode for seq_q = 1.
//
// The Phase 10 trailblazer `flash_sdpa_fw_kernel` is tuned for prefill
// (seq_q comparable to seq_k); its Br=64 q-tile shape leaves the
// seq_q=1 decode regime under-parallelized. FA2 hits the same wall —
// FA2's tile structure assumes seq_q ≥ block-rows.
//
// FlashDecoding (Dao 2023, "Flash-Decoding for Long-Context Inference")
// flips the parallelism axis: split K into S chunks of size CHUNK_K
// each, launch one block per (b, h, k_split), and combine per-split
// online-softmax partials in a small reduction kernel.
//
// Pipeline:
//
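//   Kernel 1: split_kernel<T>  (flash_decoding_split_kernel)
//     gridDim  = (S, H, B)
//     blockDim = 128 (4 warps)
//
//     Each block:
//       1. Load Q[b, h, 0, :D] into SMEM as f32 (cooperative, one element
//          per thread per step).
//       2. For each k in [k_split * CHUNK_K, min((k_split+1) * CHUNK_K, seq_k)):
//            - Score s_k = Q · K[k] * scale. For D >= 32 one warp owns a
//              K row and its lanes stride over D, followed by a
//              warp-shuffle reduce (smaller D uses one thread per row).
//              K is read directly from global memory, not staged in SMEM.
//       3. Chunk-local softmax over the scores: m = max_k s_k,
//          p_k = exp(s_k - m), l = Σ_k p_k; then the un-normalized
//          o_d = Σ_k p_k * V[k, d].
//       4. Write partial (m, l, o) for this split to workspace.
//
//     (The GQA-batched tensor-core variant, flash_decoding_split_kernel_tc,
//     instead launches one block per (k_split, h_kv, b).)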
//
//   Kernel 2: combine_kernel<T>
//     gridDim  = (1, H, B)
//     blockDim = 128
//
//     Each block reads the S partials for its (b, h) and merges them
//     via the standard online-softmax associative merge:
//       global_m   = max over splits of partial_m[s]
//       alpha_s    = exp(partial_m[s] - global_m)
//       global_l   = Σ_s alpha_s * partial_l[s]
//       global_o_d = Σ_s alpha_s * partial_o_d[s]
//     Final: y[b, h, 0, d] = global_o_d / global_l.
//
// Per-block partial storage in workspace:
//   partial_m: [B, H, S]    × f32     (4 bytes per split)
//   partial_l: [B, H, S]    × f32     (4 bytes per split)
//   partial_o: [B, H, S, D] × f32     (4·D bytes per split)
//
// For (B=1, H=32, S=64, D=128) workspace ≈ 1 MB. The same workspace is
// reused across launches (caller passes it in via the Workspace::Borrowed
// path — same contract as FA2).
//
// Tier-1 scope (Phase 73 follow-up):
//   - dtypes: f16, bf16
//   - head_dim ∈ [1, 128]
//   - GQA: Q head h reads K/V head h / group_size (group_size = H_q / H_kv,
//     1 for pure MHA). K/V carry their own b/h strides, so a host-side
//     broadcast through the K/V head stride also works.
//   - seq_q = 1 (decode); arbitrary seq_k.
//   - is_causal: ignored (decode is always non-causal vs the full KV
//     history — caller is responsible for slicing the cache).
//
// Out of scope (deferred):
//   - f32 / f64 (decode is half-precision in practice).
//   - sliding window, ALiBi, soft-cap (caller masks beforehand).
//   - backward (decode is FW-only).
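//   - tensor-core MMA in this SIMT kernel's Q·K dot product — it uses a
//     warp-shuffle reduce. Tensor-core retune is a follow-up phase once
//     perf bench numbers are in (the GQA-batched WMMA kernel below,
//     flash_decoding_split_kernel_tc, is the first tensor-core path).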

#pragma once

#include <cstddef>
#include <cstdint>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <mma.h>

#include "baracuda_smem_reduce.cuh"

namespace baracuda { namespace flash_decoding {

// Upper bound on head_dim: sizes the static per-block SMEM buffers.
constexpr int kMaxD            = 128;
// K rows covered by one split (CHUNK_K): split s owns
// [s * kChunkK, (s + 1) * kChunkK) clipped to seq_k.
constexpr int kChunkK          = 256;
// Expected launch block size: 4 warps × 32 lanes.
constexpr int kThreadsPerBlock = 4 * 32;

// =============================================================================
// Type helpers — half/bf16 → f32 accumulator.
// =============================================================================

// LoadAcc<T>: conversions between the 16-bit storage type T and the f32
// accumulator the kernels compute in. Only __half and __nv_bfloat16 are
// specialized; the primary template is declared but never defined, so any
// other T fails to compile.
template <typename T>
struct LoadAcc;

// f16 storage.
template <>
struct LoadAcc<__half> {
    // Widen one stored element to f32 for accumulation.
    static __device__ __forceinline__ float load(__half x)
    {
        return __half2float(x);
    }

    // Narrow an f32 value back to the storage type.
    static __device__ __forceinline__ __half store(float x)
    {
        return __float2half(x);
    }
};

// bf16 storage.
template <>
struct LoadAcc<__nv_bfloat16> {
    // Widen one stored element to f32 for accumulation.
    static __device__ __forceinline__ float load(__nv_bfloat16 x)
    {
        return __bfloat162float(x);
    }

    // Narrow an f32 value back to the storage type.
    static __device__ __forceinline__ __nv_bfloat16 store(float x)
    {
        return __float2bfloat16(x);
    }
};

// =============================================================================
// Split kernel — one block per (b, h, k_split). Produces a partial
// (m, l, o[D]) for its K chunk.
// =============================================================================
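//
// Strides are in element units, matching the rest of baracuda's strided
// FFI convention. The kernel derives each (b, h) base offset from the
// *_b_stride / *_h_stride arguments; the innermost (head_dim) axis is
// read contiguously. GQA is mapped in-kernel: Q head h reads K/V head
// h / group_size. A zero K/V head stride with group_size = 1 is an
// alternative host-side broadcast.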

template <typename T>
__global__ void flash_decoding_split_kernel(
    const T* __restrict__ q,
    const T* __restrict__ k,
    const T* __restrict__ v,
    float* __restrict__ partial_m,
    float* __restrict__ partial_l,
    float* __restrict__ partial_o,
    int32_t batch, int32_t heads, int32_t k_len,
    int32_t head_dim,
    int32_t num_splits,
    int32_t group_size,           // H_q / H_kv (1 for pure MHA)
    int64_t q_b_stride, int64_t q_h_stride,
    int64_t k_b_stride, int64_t k_h_stride, int64_t k_seq_stride,
    int64_t v_b_stride, int64_t v_h_stride, int64_t v_seq_stride,
    float scale)
{
    // Preconditions (not checked in-kernel; presumably enforced by the
    // launcher): 1 <= head_dim <= kMaxD, group_size >= 1, and blockDim.x a
    // multiple of 32 (the warp shuffles below use a full-warp mask).
    const int s = blockIdx.x;   // split idx
    const int h = blockIdx.y;   // Q-head index (in [0, H_q))
    const int b = blockIdx.z;
    // Block-uniform exit, taken before any __syncthreads().
    if (s >= num_splits || h >= heads || b >= batch) return;
    // For GQA: every `group_size` Q heads share one K/V head.
    const int h_kv = h / group_size;

    const int tid = threadIdx.x;
    const int nthreads = blockDim.x;

    // Workspace slot for this (b, h, split): a scalar index into
    // partial_m / partial_l, and the head_dim-wide row of partial_o at the
    // same index.
    const int64_t pidx = ((int64_t)b * heads + h) * num_splits + s;
    float* const o_out = partial_o + pidx * (int64_t)head_dim;

    const int k_start = s * kChunkK;
    const int k_end   = min(k_start + kChunkK, k_len);
    if (k_start >= k_end) {
        // No work in this split — write neutral partials (m = -inf, l = 0,
        // o = 0) so the combine kernel doesn't propagate NaNs.
        if (tid == 0) {
            partial_m[pidx] = -INFINITY;
            partial_l[pidx] = 0.0f;
        }
        for (int d = tid; d < head_dim; d += nthreads) {
            o_out[d] = 0.0f;
        }
        return;
    }

    // Q row for this (b, h) in f32, and the chunk's scores (later
    // overwritten in place by their softmax weights).
    __shared__ float sQ[kMaxD];
    __shared__ float sS[kChunkK];
    __shared__ float warp_buf[32];    // scratch for block_reduce_*

    // Load Q[b, h, 0, :D] once. q_b/q_h strides give the BH base; seq=0
    // so no seq-stride contribution.
    const T* q_bh = q + (int64_t)b * q_b_stride + (int64_t)h * q_h_stride;
    for (int d = tid; d < head_dim; d += nthreads) {
        sQ[d] = LoadAcc<T>::load(q_bh[d]);
    }
    __syncthreads();   // every warp reads all of sQ below

    // K/V indexed by the KV-head id (collapses group_size Q heads onto
    // the same K/V slice — the standard GQA broadcast).
    const T* k_bh = k + (int64_t)b * k_b_stride + (int64_t)h_kv * k_h_stride;
    const T* v_bh = v + (int64_t)b * v_b_stride + (int64_t)h_kv * v_h_stride;

    const int chunk_len = k_end - k_start;
    const int warp_id = tid >> 5;
    const int lane    = tid & 31;
    const int num_warps = nthreads >> 5;   // 4 for kThreadsPerBlock = 128

    // Pass 1 — scores. Warp-cooperative dot product: one warp owns one
    // K-row at a time, all 32 lanes cooperate along the D axis, so the 32
    // lanes read contiguous elements of the SAME row (coalesced) instead
    // of 32 different rows per step. 4 warps × 64 rows = 256 rows for
    // kChunkK = 256; for D=128 each lane handles 4 elements per row.
    if (head_dim >= 32) {
        for (int k_off = warp_id; k_off < chunk_len; k_off += num_warps) {
            const int k_abs = k_start + k_off;
            const T* k_row = k_bh + (int64_t)k_abs * k_seq_stride;
            float acc = 0.0f;
            for (int d = lane; d < head_dim; d += 32) {
                acc += sQ[d] * LoadAcc<T>::load(k_row[d]);
            }
            // Warp-reduce sum across the 32 lanes.
            #pragma unroll
            for (int delta = 16; delta > 0; delta >>= 1) {
                acc += __shfl_xor_sync(0xffffffff, acc, delta, 32);
            }
            if (lane == 0) {
                sS[k_off] = acc * scale;
            }
        }
    } else {
        // Tiny-D fallback — one thread per K-row, serial-D. Perf is
        // irrelevant here; the bandwidth math is dominated by long-D shapes.
        for (int ki = tid; ki < chunk_len; ki += nthreads) {
            const T* k_row = k_bh + (int64_t)(k_start + ki) * k_seq_stride;
            float acc = 0.0f;
            for (int d = 0; d < head_dim; ++d) {
                acc += sQ[d] * LoadAcc<T>::load(k_row[d]);
            }
            sS[ki] = acc * scale;
        }
    }
    // Only sS[0, chunk_len) is read from here on, so the tail of a short
    // chunk needs no masking.
    __syncthreads();

    // Chunk-local max over scores.
    float local_max = -INFINITY;
    for (int ki = tid; ki < chunk_len; ki += nthreads) {
        if (sS[ki] > local_max) local_max = sS[ki];
    }
    const float chunk_max = block_reduce_max_f32(local_max, warp_buf);
    // warp_buf is reused by the sum reduction next. If block_reduce_*
    // broadcasts its result through warp_buf, a slow thread could still be
    // reading the max when the sum reduction overwrites it; this barrier
    // rules that out regardless of the helper's internals.
    __syncthreads();

    // Chunk-local sum of exp(s - chunk_max); sS becomes the softmax weights.
    float local_sum = 0.0f;
    for (int ki = tid; ki < chunk_len; ki += nthreads) {
        const float p = expf(sS[ki] - chunk_max);
        sS[ki] = p;
        local_sum += p;
    }
    const float chunk_sum = block_reduce_sum_f32(local_sum, warp_buf);
    // Pass 2 reads every thread's weights; make the publishing barrier
    // explicit rather than relying on the reduction helper's own syncs.
    __syncthreads();

    // Pass 2 — V accumulation, one thread per d walking all K-rows. Across
    // a warp the 32 lanes share `ki` and read 32 consecutive elements of
    // one V row, so the loads are coalesced. Each thread owns its d, so the
    // un-normalized result goes straight to the workspace.
    for (int d = tid; d < head_dim; d += nthreads) {
        float acc = 0.0f;
        for (int ki = 0; ki < chunk_len; ++ki) {
            const T* v_row = v_bh + (int64_t)(k_start + ki) * v_seq_stride;
            acc += sS[ki] * LoadAcc<T>::load(v_row[d]);
        }
        o_out[d] = acc;
    }

    // m and l describe the chunk-local softmax; the combine kernel performs
    // the global merge.
    if (tid == 0) {
        partial_m[pidx] = chunk_max;
        partial_l[pidx] = chunk_sum;
    }
}

// =============================================================================
// GQA-batched WMMA split kernel — Tier-2 (Phase 73 follow-up #3).
//
// One block per (k_split, h_kv, b). The block computes attention for
// ALL `group_size` Q heads in this KV group at once, batching them
// in the WMMA M-tile. For Llama-3-class GQA (group_size=4 or 8) this
// fills 25-50% of WMMA's 16-row M-tile; group_size=16 (e.g. MQA with
// 16 Q heads per KV head) fills it completely.
//
// Why this beats the SIMT kernel for GQA:
//   - 1 block does the work of `group_size` SIMT blocks (4-8× fewer
//     blocks in the launch grid).
//   - K/V loaded ONCE per block (vs once per Q head in the SIMT path)
//     — eliminates redundant L2 traffic across Q heads in a group.
//   - QK^T and PV both run on tensor cores at fp16/bf16 → fp32 MMA.
//
// Constraints:
//   - group_size ∈ [1, kWmmaM] (M-tile width = 16).
//   - head_dim must be a multiple of kWmmaK (= 16).
//   - chunk_len rounded up to kWmmaN multiples for the N-tile loop.
//   - dtype: __half or __nv_bfloat16.
//   - blockDim.x = kThreadsPerBlock = 128 (4 warps).
//
// SMEM layout (per block, static, sized for D = kMaxD = 128 and 4 warps):
//   sQ               [kWmmaM × kMaxD]               half/bf16    4 KB
//   sKV_tile         [4 × kWmmaN × kMaxD]           half/bf16   16 KB
//                    (per-warp K sub-tiles during QK^T; its first
//                     kWmmaK rows are reused as the shared V sub-tile)
//   sScores          [kWmmaM × kChunkK]             float       16 KB
//   sO               [kWmmaM × kMaxD]               float        8 KB
//   sP_warp_scratch  [4 × kWmmaM × kWmmaK]          half/bf16    2 KB
//   sMaxRow, sSumRow [kWmmaM] each                  float      128 B
//
// Total = 4K + 16K + 16K + 8K + 2K + 128 B ≈ 46.1 KB — fits under the
// 48 KB static SMEM limit.
// =============================================================================

// WMMA fragment shape (M × N × K) for flash_decoding_split_kernel_tc.
constexpr int kWmmaM = 16;   // M: Q heads of one GQA group, zero-padded
constexpr int kWmmaN = 16;   // N: K rows in QK^T, head_dim columns in PV
constexpr int kWmmaK = 16;   // K: reduction depth of one mma_sync

namespace tc {
using namespace nvcuda;

// ToHalf<T>::cvt narrows an f32 value to the 16-bit WMMA operand type T.
// The primary template is declared but never defined; only the __half and
// __nv_bfloat16 specializations below exist.
template <typename T>
struct ToHalf;

// f32 -> f16.
template <>
struct ToHalf<__half> {
    static __device__ __forceinline__ __half cvt(float x)
    {
        return __float2half(x);
    }
};

// f32 -> bf16.
template <>
struct ToHalf<__nv_bfloat16> {
    static __device__ __forceinline__ __nv_bfloat16 cvt(float x)
    {
        return __float2bfloat16(x);
    }
};

}  // namespace tc

template <typename T>
__global__ void flash_decoding_split_kernel_tc(
    const T* __restrict__ q,
    const T* __restrict__ k,
    const T* __restrict__ v,
    float* __restrict__ partial_m,
    float* __restrict__ partial_l,
    float* __restrict__ partial_o,
    int32_t batch, int32_t heads, int32_t k_len,
    int32_t head_dim,
    int32_t num_splits,
    int32_t num_kv_heads,
    int32_t group_size,
    int64_t q_b_stride, int64_t q_h_stride,
    int64_t k_b_stride, int64_t k_h_stride, int64_t k_seq_stride,
    int64_t v_b_stride, int64_t v_h_stride, int64_t v_seq_stride,
    float scale)
{
    using namespace nvcuda;

    // Preconditions (not checked in-kernel; presumably enforced by the
    // launcher): 1 <= group_size <= kWmmaM, head_dim a multiple of kWmmaK
    // and <= kMaxD, blockDim.x == kThreadsPerBlock (the per-warp SMEM
    // buffers below are sized for exactly 4 warps).
    const int s     = blockIdx.x;
    const int h_kv  = blockIdx.y;
    const int b     = blockIdx.z;
    // Block-uniform exit, taken before any __syncthreads().
    if (s >= num_splits || h_kv >= num_kv_heads || b >= batch) return;

    const int tid       = threadIdx.x;
    const int warp_id   = tid >> 5;
    const int lane      = tid & 31;
    const int num_warps = blockDim.x >> 5;   // 4

    const int k_start = s * kChunkK;
    const int k_end   = min(k_start + kChunkK, k_len);
    const int chunk_len = k_end - k_start;

    // Q-head range that maps to this KV head: [h_kv*group_size,
    // (h_kv+1)*group_size).
    const int q_head_base = h_kv * group_size;
    // M-rows of the WMMA tile backed by a real Q head; rows
    // [live_rows, kWmmaM) are zero padding.
    const int live_rows = min(group_size, kWmmaM);

    // SMEM allocations — Tier-2 v2 (warp-parallel + smaller round-trip).
    //
    //   sQ                : kWmmaM × kMaxD × sizeof(T)     = 4 KB (f16)
    //   sKV_tile          : num_warps × kWmmaN × kMaxD × sizeof(T)
    //                                                       = 16 KB (f16, partitioned per warp during QK^T)
    //   sScores           : kWmmaM × kChunkK × sizeof(float) = 16 KB
    //   sO                : kWmmaM × kMaxD × sizeof(float)  = 8 KB
    //   sP_warp_scratch   : num_warps × kWmmaM × kWmmaK × sizeof(T) = 2 KB
    //   sMaxRow / sSumRow : trivial
    //   ---
    //   Total ≈ 47 KB — fits in 48 KB default SMEM.
    //
    // sKV_tile alternates roles:
    //   - During QK^T pass: partitioned across warps. Each of 4 warps
    //     owns its own 16-K-row slot at [warp_id * kWmmaN * head_dim].
    //     All 4 warps load + mma in parallel.
    //   - During PV pass: only the first kWmmaK rows (one shared V
    //     sub-tile) are used at a time, loaded cooperatively by all warps
    //     into the same buffer base.
    //
    // sP_warp_scratch: each warp converts ONLY the 16-K-col slice of
    // sScores it needs for its current mma into a tiny per-warp slot,
    // instead of materializing a full shared P buffer.
    //
    // wmma::load_matrix_sync / store_matrix_sync require 256-bit (32-byte)
    // aligned pointers, while a __shared__ array is otherwise only
    // guaranteed the alignment of its element type — so request it. Every
    // tile offset used below is a multiple of 16 elements, which keeps the
    // derived pointers 32-byte aligned as well.
    constexpr int kSmemPerWarp = kWmmaN * kMaxD;  // 16 × 128 = 2048 T-elts
    __shared__ __align__(32) T     sQ[kWmmaM * kMaxD];
    __shared__ __align__(32) T     sKV_tile[4 * kSmemPerWarp]; // num_warps == 4 hard-coded
    __shared__ __align__(32) float sScores[kWmmaM * kChunkK];
    __shared__ __align__(32) float sO[kWmmaM * kMaxD];
    __shared__ __align__(32) T     sP_warp_scratch[4 * kWmmaM * kWmmaK];
    __shared__ float sMaxRow[kWmmaM];
    __shared__ float sSumRow[kWmmaM];

    // Empty-chunk → write neutral partials for every Q head in the group.
    if (k_start >= k_end) {
        for (int g = 0; g < group_size; ++g) {
            const int h_q = q_head_base + g;
            if (tid == 0) {
                int64_t pidx = ((int64_t)b * heads + h_q) * num_splits + s;
                partial_m[pidx] = -INFINITY;
                partial_l[pidx] = 0.0f;
            }
            for (int d = tid; d < head_dim; d += blockDim.x) {
                int64_t poff = (((int64_t)b * heads + h_q) * num_splits + s)
                              * (int64_t)head_dim + d;
                partial_o[poff] = 0.0f;
            }
        }
        return;
    }

    // Load Q for all `group_size` heads in this KV group. Unused M-rows are
    // zero-filled; their scores are overwritten with -inf by the mask below
    // and they are excluded from the softmax and from P.
    for (int m = 0; m < kWmmaM; ++m) {
        if (m < group_size) {
            const int h_q = q_head_base + m;
            const T* q_row = q + (int64_t)b * q_b_stride
                               + (int64_t)h_q * q_h_stride;
            for (int d = tid; d < head_dim; d += blockDim.x) {
                sQ[m * head_dim + d] = q_row[d];
            }
        } else {
            for (int d = tid; d < head_dim; d += blockDim.x) {
                sQ[m * head_dim + d] = tc::ToHalf<T>::cvt(0.0f);
            }
        }
    }

    // Initialize sO to zero (we'll accumulate across K sub-tiles).
    for (int i = tid; i < kWmmaM * head_dim; i += blockDim.x) {
        sO[i] = 0.0f;
    }
    // Initialize row stats.
    if (tid < kWmmaM) {
        sMaxRow[tid] = -INFINITY;
        sSumRow[tid] = 0.0f;
    }
    __syncthreads();

    const T* k_bh = k + (int64_t)b * k_b_stride + (int64_t)h_kv * k_h_stride;
    const T* v_bh = v + (int64_t)b * v_b_stride + (int64_t)h_kv * v_h_stride;

    // ==========================================================================
    // Pass 1 — compute all chunk_len scores into sScores via WMMA mma.
    //
    // Each of the 4 warps owns its own 16-K-row slot in sKV_tile and runs
    // its own mma chain. Per outer iteration:
    //   1. Each warp loads its 16 K-rows into its slot (zero past k_end).
    //   2. __syncthreads.
    //   3. Warp warp_id computes N-tile (n_base + warp_id * kWmmaN),
    //      walking head_dim / kWmmaK sub-mmas along the reduction axis.
    //   4. Each warp stores its 16×16 result into sScores at its column.
    // Outer-iter count: chunk_len / (num_warps * kWmmaN) = 256 / 64 = 4
    // for the max chunk.
    // ==========================================================================

    for (int n_base = 0; n_base < chunk_len; n_base += num_warps * kWmmaN) {
        const int n_warp     = n_base + warp_id * kWmmaN;
        const bool warp_active = (n_warp < chunk_len);
        T* const my_k_slot   = &sKV_tile[warp_id * kSmemPerWarp];

        for (int i = lane; i < kWmmaN * head_dim; i += 32) {
            const int row = i / head_dim;
            const int d   = i % head_dim;
            const int k_abs = k_start + n_warp + row;
            if (warp_active && k_abs < k_end) {
                my_k_slot[row * head_dim + d] =
                    k_bh[(int64_t)k_abs * k_seq_stride + d];
            } else {
                my_k_slot[row * head_dim + d] = tc::ToHalf<T>::cvt(0.0f);
            }
        }
        __syncthreads();

        if (warp_active) {
            // A = Q [M × D] row-major; B = K^T [D × N], i.e. the K slot read
            // column-major with ld = head_dim.
            wmma::fragment<wmma::matrix_a, kWmmaM, kWmmaN, kWmmaK, T,
                           wmma::row_major> q_frag;
            wmma::fragment<wmma::matrix_b, kWmmaM, kWmmaN, kWmmaK, T,
                           wmma::col_major> k_frag;
            wmma::fragment<wmma::accumulator, kWmmaM, kWmmaN, kWmmaK, float> c_frag;
            wmma::fill_fragment(c_frag, 0.0f);

            for (int kk = 0; kk < head_dim; kk += kWmmaK) {
                wmma::load_matrix_sync(q_frag, sQ + kk, head_dim);
                wmma::load_matrix_sync(k_frag, my_k_slot + kk, head_dim);
                wmma::mma_sync(c_frag, q_frag, k_frag, c_frag);
            }
            wmma::store_matrix_sync(
                &sScores[0 * kChunkK + n_warp],
                c_frag, kChunkK, wmma::mem_row_major);
        }
        __syncthreads();
    }

    // Apply scale; everything outside [group_size rows] × [chunk_len cols]
    // becomes -inf.
    for (int i = tid; i < kWmmaM * kChunkK; i += blockDim.x) {
        int m = i / kChunkK;
        int n = i % kChunkK;
        if (n < chunk_len && m < group_size) {
            sScores[i] *= scale;
        } else {
            sScores[i] = -INFINITY;
        }
    }
    __syncthreads();

    // Per-row softmax over the live rows only. Padded rows are all -inf,
    // so exp(-inf - (-inf)) would only manufacture NaN; they are skipped
    // and excluded from P below. Each warp handles rows warp_id,
    // warp_id + num_warps, ...; the bound is warp-uniform, so every lane
    // reaches the full-mask shuffles.
    for (int m_local = warp_id; m_local < live_rows; m_local += num_warps) {
        float* const score_row = &sScores[m_local * kChunkK];

        // Row max via warp-shuffle reduce.
        float row_max = -INFINITY;
        for (int n = lane; n < chunk_len; n += 32) {
            const float score = score_row[n];
            if (score > row_max) row_max = score;
        }
        #pragma unroll
        for (int delta = 16; delta > 0; delta >>= 1) {
            float other = __shfl_xor_sync(0xffffffff, row_max, delta, 32);
            if (other > row_max) row_max = other;
        }
        // Row sum of exp(s - row_max); scores become softmax weights.
        float row_sum = 0.0f;
        for (int n = lane; n < chunk_len; n += 32) {
            const float p = expf(score_row[n] - row_max);
            score_row[n] = p;
            row_sum += p;
        }
        #pragma unroll
        for (int delta = 16; delta > 0; delta >>= 1) {
            row_sum += __shfl_xor_sync(0xffffffff, row_sum, delta, 32);
        }
        if (lane == 0) {
            sMaxRow[m_local] = row_max;
            sSumRow[m_local] = row_sum;
        }
    }
    __syncthreads();

    // ==========================================================================
    // Pass 2 — accumulate sO = P @ V via WMMA mma.
    //
    //   - V reuses sKV_tile's first kWmmaK × head_dim slot (loaded
    //     cooperatively across all 4 warps, no partitioning).
    //   - P lives in a tiny per-warp 16×16 scratch; each warp converts
    //     ONLY the sScores slice it mma's against.
    //
    // For head_dim = 128: 8 N-tiles, 2 per warp, so every warp works in
    // each k_sub iteration.
    // ==========================================================================
    T* const sV_tile = &sKV_tile[0];  // reuse, only first kWmmaK rows needed
    const int n_tiles_per_d = head_dim / kWmmaN;     // 8 for D=128
    const int n_tiles_per_warp = (n_tiles_per_d + num_warps - 1) / num_warps;

    for (int k_sub = 0; k_sub < chunk_len; k_sub += kWmmaK) {
        const int rows_to_load = min(kWmmaK, chunk_len - k_sub);

        // Coop-load V sub-tile [kWmmaK, head_dim]; rows past the chunk are 0.
        for (int i = tid; i < kWmmaK * head_dim; i += blockDim.x) {
            int row = i / head_dim;
            int d   = i % head_dim;
            int k_abs = k_start + k_sub + row;
            if (row < rows_to_load && k_abs < k_end) {
                sV_tile[row * head_dim + d] =
                    v_bh[(int64_t)k_abs * v_seq_stride + d];
            } else {
                sV_tile[row * head_dim + d] = tc::ToHalf<T>::cvt(0.0f);
            }
        }

        // Each warp converts its P slice — sScores[:, k_sub:k_sub+16] →
        // its per-warp slot (32 lanes × 8 cells).
        //
        // Only [0, live_rows) × [0, chunk_len) hold softmax weights. The
        // tail columns of the last (partial) sub-tile still hold the -inf
        // mask; multiplied by the zero-padded V rows they give
        // -inf * 0 = NaN inside the MMA, which would poison every live
        // output row. Padded rows were never softmaxed. Feed an exact 0
        // for both so they contribute nothing.
        T* const my_p_slot = &sP_warp_scratch[warp_id * kWmmaM * kWmmaK];
        for (int i = lane; i < kWmmaM * kWmmaK; i += 32) {
            const int m = i / kWmmaK;
            const int kcol = k_sub + i % kWmmaK;
            const float p = (m < live_rows && kcol < chunk_len)
                                ? sScores[m * kChunkK + kcol]
                                : 0.0f;
            my_p_slot[i] = tc::ToHalf<T>::cvt(p);
        }
        __syncthreads();

        // Each warp processes its assigned N-tile(s) of head_dim.
        for (int n_idx = 0; n_idx < n_tiles_per_warp; ++n_idx) {
            const int n_tile = warp_id + n_idx * num_warps;
            if (n_tile >= n_tiles_per_d) break;
            const int d_base = n_tile * kWmmaN;

            wmma::fragment<wmma::matrix_a, kWmmaM, kWmmaN, kWmmaK, T,
                           wmma::row_major> p_frag;
            wmma::fragment<wmma::matrix_b, kWmmaM, kWmmaN, kWmmaK, T,
                           wmma::row_major> v_frag;
            wmma::fragment<wmma::accumulator, kWmmaM, kWmmaN, kWmmaK, float> o_frag;

            // Load existing sO accumulator for this [M, n_tile] block.
            wmma::load_matrix_sync(
                o_frag, &sO[0 * head_dim + d_base], head_dim,
                wmma::mem_row_major);
            // P fragment from per-warp scratch (ld = kWmmaK, tight).
            wmma::load_matrix_sync(p_frag, my_p_slot, kWmmaK);
            // V fragment from shared sV_tile.
            wmma::load_matrix_sync(v_frag, sV_tile + d_base, head_dim);
            wmma::mma_sync(o_frag, p_frag, v_frag, o_frag);

            // Store back to sO.
            wmma::store_matrix_sync(
                &sO[0 * head_dim + d_base], o_frag, head_dim,
                wmma::mem_row_major);
        }
        __syncthreads();
    }

    // ==========================================================================
    // Pass 3 — write partials. Each of the `group_size` Q heads gets
    // its own (m, l, o[D]) tuple in workspace, indexed by the Q-head
    // ID. Padded M-rows (m >= group_size) are not written.
    // ==========================================================================
    for (int g = 0; g < group_size; ++g) {
        const int h_q = q_head_base + g;
        if (tid == 0) {
            int64_t pidx = ((int64_t)b * heads + h_q) * num_splits + s;
            partial_m[pidx] = sMaxRow[g];
            partial_l[pidx] = sSumRow[g];
        }
        for (int d = tid; d < head_dim; d += blockDim.x) {
            int64_t poff = (((int64_t)b * heads + h_q) * num_splits + s)
                          * (int64_t)head_dim + d;
            partial_o[poff] = sO[g * head_dim + d];
        }
    }
}

// =============================================================================
// Combine kernel — one block per (b, h). Reads `num_splits` partial
// (m, l, o[D]) triples for its BH and emits the final y[b, h, 0, :D].
// =============================================================================

// Merges the per-split (m, l, o[D]) partials of one (b, h) pair into y.
//
// Launch layout: blockIdx.y selects the head and blockIdx.z the batch entry;
// blockIdx.x is not read. Threads stride over splits (phases 1-2) and over
// the head dimension (phase 3). Uses 32 floats of static shared memory.
//
// With w_s = exp(m_s - max_s m_s), or 0 when m_s == -inf:
//   y[b, h, d] = (sum_s w_s * o_s[d]) / (sum_s w_s * l_s)
// and y is written as 0 when the denominator is not positive.
template <typename T>
__global__ void flash_decoding_combine_kernel(
    const float* __restrict__ partial_m,
    const float* __restrict__ partial_l,
    const float* __restrict__ partial_o,
    T* __restrict__ y,
    int32_t batch, int32_t heads,
    int32_t head_dim,
    int32_t num_splits,
    int64_t y_b_stride, int64_t y_h_stride)
{
    const int head_idx  = blockIdx.y;
    const int batch_idx = blockIdx.z;
    // Depends only on blockIdx, so the whole block exits together and no
    // thread is left waiting inside the block reductions below.
    if (batch_idx >= batch || head_idx >= heads) return;

    const int first = threadIdx.x;
    const int step  = blockDim.x;

    __shared__ float warp_buf[32];

    // Views of this (b, h)'s slice of the workspace.
    const int64_t bh = (int64_t)batch_idx * heads + head_idx;
    const float* m_of_split = partial_m + bh * num_splits;
    const float* l_of_split = partial_l + bh * num_splits;
    const float* o_of_split = partial_o + bh * (int64_t)num_splits * (int64_t)head_dim;

    // Phase 1: block-wide maximum of the per-split running maxima.
    float thread_max = -INFINITY;
    for (int s = first; s < num_splits; s += step) {
        const float m = m_of_split[s];
        thread_max = (m > thread_max) ? m : thread_max;
    }
    const float global_max = block_reduce_max_f32(thread_max, warp_buf);

    // Rescale factor of split s relative to the global max. A split whose max
    // is -inf contributes nothing; forcing 0 also avoids exp(-inf - (-inf)).
    auto split_weight = [global_max](float m) -> float {
        return (m == -INFINITY) ? 0.0f : expf(m - global_max);
    };

    // Phase 2: normaliser L = sum_s w_s * l_s.
    // NOTE(review): both reductions reuse `warp_buf`. This relies on
    // block_reduce_*_f32 being safe to call back-to-back on the same buffer;
    // confirm in the helper's implementation.
    float thread_l = 0.0f;
    for (int s = first; s < num_splits; s += step) {
        thread_l += split_weight(m_of_split[s]) * l_of_split[s];
    }
    const float global_l = block_reduce_sum_f32(thread_l, warp_buf);
    // Every split masked -> L == 0; emit zeros rather than inf/NaN.
    const float inv_l = (global_l > 0.0f) ? (1.0f / global_l) : 0.0f;

    // Phase 3: y[b, h, d] = inv_l * sum_s w_s * o_s[d].
    T* y_row = y + (int64_t)batch_idx * y_b_stride + (int64_t)head_idx * y_h_stride;
    for (int d = first; d < head_dim; d += step) {
        float acc = 0.0f;
        const float* o_col = o_of_split + d;
        for (int s = 0; s < num_splits; ++s, o_col += head_dim) {
            acc += split_weight(m_of_split[s]) * (*o_col);
        }
        y_row[d] = LoadAcc<T>::store(acc * inv_l);
    }
}

// =============================================================================
// Host launcher — workspace contract + 2-kernel dispatch.
// =============================================================================
//
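// Workspace bytes, with S = num_splits = ceil(k_len / kChunkK). The three
// regions are laid out back-to-back in this order:
//   partial_m:           sizeof(float) * B * H * S
//   partial_l:           sizeof(float) * B * H * S
//   partial_o:           sizeof(float) * B * H * S * D
//   total = B * H * S * (2 + D) * sizeof(float)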

// Number of kChunkK-sized key chunks (splits) needed to cover `k_len` keys,
// i.e. ceil(k_len / kChunkK). Returns 0 when there are no keys.
// The ceil-div is done in int64: in int32, `k_len + kChunkK - 1` overflows
// (undefined behaviour) when k_len is within kChunkK of INT32_MAX.
__host__ inline int64_t flash_decoding_num_splits(int32_t k_len) {
    if (k_len <= 0) return 0;
    const int64_t chunk = (int64_t)kChunkK;
    return ((int64_t)k_len + chunk - 1) / chunk;
}

// Bytes of float scratch that launch_flash_decoding needs for these dims:
// batch * heads * num_splits * (2 + head_dim) floats, covering the
// partial_m, partial_l and partial_o regions.
// Returns 0 when there is no work (k_len <= 0) and for non-positive
// batch/heads/head_dim. Without that check, those values would wrap to
// enormous sizes when widened to size_t.
__host__ inline size_t flash_decoding_workspace_bytes(
    int32_t batch, int32_t heads, int32_t k_len, int32_t head_dim)
{
    if (batch <= 0 || heads <= 0 || head_dim <= 0) return 0;
    const int64_t s = flash_decoding_num_splits(k_len);
    if (s == 0) return 0;
    // Widen before adding so `head_dim + 2` cannot overflow int32.
    return (size_t)batch * (size_t)heads * (size_t)s
         * ((size_t)head_dim + 2) * sizeof(float);
}

// TC (tensor-core / WMMA) dispatch — DISABLED at single-batch decode.
//
// Two iterations of the WMMA kernel were benchmarked vs SIMT-GQA on
// RTX 4070 (full results in commit-message tables):
//
//   v1 (warp-sequentialized QK^T, 8 KB shared sP):
//     llama3-70b K=4096: 231µs  vs SIMT 132µs  (SIMT 1.75× faster)
//     llama3-70b K=8192: 448µs  vs SIMT 252µs  (SIMT 1.78× faster)
//     qwen2-14b  K=8192: 231µs  vs SIMT 134µs  (SIMT 1.73× faster)
//
//   v2 (warp-parallel QK^T, per-warp 2 KB sP scratch):
//     llama3-70b K=4096: 165µs  vs SIMT 132µs  (SIMT 1.25× faster)
//     llama3-70b K=8192: 442µs  vs SIMT 252µs  (SIMT 1.75× faster)
//     qwen2-14b  K=8192: 166µs  vs SIMT 134µs  (SIMT 1.24× faster)
//
// v2 closed real per-kernel gaps (1.13-1.41× faster than v1), but it still
// did not beat SIMT-GQA on any tested GQA shape. The structural ceiling holds:
//
//   1. Single-batch decode at GQA group_size=4-8 fills only 4-8 of
//      WMMA's 16 M-tile rows, wasting (16 - group_size)/16 = 50-75% of
//      the tile's throughput.
//
//   2. Decode is bandwidth/L2-bound at the tested shapes. Tensor cores
//      attack compute — useless when compute isn't the bottleneck.
//
//   3. TC grid is (num_splits, H_kv, B); SIMT is (num_splits, H_q, B).
//      With B=1 and H_kv ≤ H_q the TC grid has group_size× fewer
//      blocks. SIMT's higher block count + L2 reuse wins.
//
// The v2 kernel (`flash_decoding_split_kernel_tc` below) and dispatch
// helper stay in the tree as documented reference. Re-enabling is one
// line. Likely worth re-evaluating only at multi-batch decode (B ≥ 8)
// where the M-tile fills with B × group_size rows — but that workload
// is owned by `BatchPagedDecodePlan` (Phase 46 FlashInfer vendored).
//
// constexpr int kMinTcBlocks = 72;  // RTX 4070-tuned block-count gate
__host__ inline bool flash_decoding_should_use_tc(
    int32_t /*group_size*/, int32_t /*head_dim*/,
    int32_t /*batch*/, int32_t /*num_kv_heads*/, int32_t /*num_splits*/)
{
    // The tensor-core path is switched off for every shape, so all launches
    // go to the SIMT split kernel. The (unused) shape arguments keep the call
    // site stable if a real gate is reinstated.
    constexpr bool kTensorCorePathEnabled = false;
    return kTensorCorePathEnabled;
}

// Enqueues the two-kernel flash-decoding pipeline on `stream`:
//   1. a split kernel that writes per-split (m, l, o[D]) partials into
//      `workspace`, laid out as [partial_m | partial_l | partial_o];
//   2. the combine kernel, which merges those partials into y.
// Both launches are asynchronous. Only errors that cudaGetLastError()
// reports right after each launch are surfaced here.
//
// Returns:
//   0           success. This includes k_len == 0, where nothing is
//               launched and y is left untouched.
//   2           invalid shape: non-positive batch/heads/num_kv_heads/
//               head_dim, heads % num_kv_heads != 0, or k_len < 0
//   3           head_dim > kMaxD
//   4           workspace is null or smaller than
//               flash_decoding_workspace_bytes(batch, heads, k_len, head_dim)
//   1000 + err  a launch left cudaError_t `err` pending
template <typename T>
__host__ inline int32_t launch_flash_decoding(
    const T* q, const T* k, const T* v, T* y,
    void* workspace, size_t workspace_bytes,
    int32_t batch, int32_t heads, int32_t num_kv_heads,
    int32_t k_len, int32_t head_dim,
    int64_t q_b_stride, int64_t q_h_stride,
    int64_t k_b_stride, int64_t k_h_stride, int64_t k_seq_stride,
    int64_t v_b_stride, int64_t v_h_stride, int64_t v_seq_stride,
    int64_t y_b_stride, int64_t y_h_stride,
    float scale,
    cudaStream_t stream)
{
    if (batch <= 0 || heads <= 0 || num_kv_heads <= 0 || head_dim <= 0) return 2;
    if (heads % num_kv_heads != 0) return 2;
    if (head_dim > kMaxD) return 3;
    // A negative key length is a malformed shape. Report it with the same
    // code `_can_implement` uses instead of silently claiming success.
    if (k_len < 0) return 2;
    if (k_len == 0) {
        // No keys: there is nothing to attend over, so nothing is launched
        // and y is not written. Callers that need a defined result must
        // zero-initialise y themselves.
        return 0;
    }

    const int32_t group_size = heads / num_kv_heads;
    const int32_t num_splits = (int32_t)flash_decoding_num_splits(k_len);

    // Reuse the published workspace contract instead of re-deriving it here.
    const size_t need = flash_decoding_workspace_bytes(batch, heads, k_len, head_dim);
    if (workspace == nullptr || workspace_bytes < need) return 4;

    // Carve the workspace into [partial_m | partial_l | partial_o].
    unsigned char* wp = static_cast<unsigned char*>(workspace);
    const size_t per_ml = (size_t)batch * (size_t)heads * (size_t)num_splits * sizeof(float);
    float* partial_m = reinterpret_cast<float*>(wp);
    float* partial_l = reinterpret_cast<float*>(wp + per_ml);
    float* partial_o = reinterpret_cast<float*>(wp + 2 * per_ml);

    dim3 block(kThreadsPerBlock);

    if (flash_decoding_should_use_tc(
            group_size, head_dim, batch, num_kv_heads, num_splits))
    {
        // TC path: one block per (split, h_kv, b). Each block batches all
        // group_size Q heads into the WMMA M-tile.
        dim3 grid_split((unsigned)num_splits, (unsigned)num_kv_heads, (unsigned)batch);
        flash_decoding_split_kernel_tc<T><<<grid_split, block, 0, stream>>>(
            q, k, v, partial_m, partial_l, partial_o,
            batch, heads, k_len, head_dim, num_splits,
            num_kv_heads, group_size,
            q_b_stride, q_h_stride,
            k_b_stride, k_h_stride, k_seq_stride,
            v_b_stride, v_h_stride, v_seq_stride,
            scale);
    } else {
        // SIMT path: one block per (split, h_q, b). Each block handles a
        // single Q head. The kernel receives group_size to map Q heads onto
        // their shared KV head.
        dim3 grid_split((unsigned)num_splits, (unsigned)heads, (unsigned)batch);
        flash_decoding_split_kernel<T><<<grid_split, block, 0, stream>>>(
            q, k, v, partial_m, partial_l, partial_o,
            batch, heads, k_len, head_dim, num_splits, group_size,
            q_b_stride, q_h_stride,
            k_b_stride, k_h_stride, k_seq_stride,
            v_b_stride, v_h_stride, v_seq_stride,
            scale);
    }
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) return 1000 + (int32_t)err;

    // Combine kernel: one block per Q head, whichever split kernel ran.
    dim3 grid_comb(1, (unsigned)heads, (unsigned)batch);
    flash_decoding_combine_kernel<T><<<grid_comb, block, 0, stream>>>(
        partial_m, partial_l, partial_o, y,
        batch, heads, head_dim, num_splits,
        y_b_stride, y_h_stride);
    err = cudaGetLastError();
    if (err != cudaSuccess) return 1000 + (int32_t)err;
    return 0;
}

} } // namespace baracuda::flash_decoding

// =============================================================================
// FFI macro — one symbol pair per dtype.
// =============================================================================

// BARACUDA_KERNELS_FLASH_DECODING_INSTANTIATE(NAME, T) defines three
// extern "C" entry points for element type T:
//
//   baracuda_kernels_<NAME>_run
//       Casts the void* tensors to T*, reinterprets `stream_ptr` as a
//       cudaStream_t and forwards every argument to
//       baracuda::flash_decoding::launch_flash_decoding<T>. It returns the
//       launcher's status code unchanged.
//   baracuda_kernels_<NAME>_can_implement
//       Host-only shape check that launches nothing. Returns:
//         0  ok
//         2  non-positive batch/heads/num_kv_heads/head_dim,
//            heads % num_kv_heads != 0, or k_len < 0
//         3  head_dim > kMaxD
//       The checks run in that order, so the order decides which code wins
//       when several apply.
//   baracuda_kernels_<NAME>_workspace_bytes
//       Forwards to flash_decoding_workspace_bytes(batch, heads, k_len,
//       head_dim).
//
// Only /* */ comments may be placed inside the macro body. A // comment on a
// line ending in '\' would swallow the following continued line.
#define BARACUDA_KERNELS_FLASH_DECODING_INSTANTIATE(NAME, T)                                       \
    extern "C" int32_t baracuda_kernels_ ## NAME ## _run(                                           \
        const void* q, const void* k, const void* v, void* y,                                       \
        void* workspace, size_t workspace_bytes,                                                    \
        int32_t batch, int32_t heads, int32_t num_kv_heads,                                         \
        int32_t k_len, int32_t head_dim,                                                            \
        int64_t q_b_stride, int64_t q_h_stride,                                                     \
        int64_t k_b_stride, int64_t k_h_stride, int64_t k_seq_stride,                               \
        int64_t v_b_stride, int64_t v_h_stride, int64_t v_seq_stride,                               \
        int64_t y_b_stride, int64_t y_h_stride,                                                     \
        float scale,                                                                                \
        void* stream_ptr)                                                                           \
    {                                                                                               \
        cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);                                \
        return baracuda::flash_decoding::launch_flash_decoding<T>(                                  \
            (const T*)q, (const T*)k, (const T*)v, (T*)y,                                           \
            workspace, workspace_bytes,                                                             \
            batch, heads, num_kv_heads, k_len, head_dim,                                            \
            q_b_stride, q_h_stride,                                                                 \
            k_b_stride, k_h_stride, k_seq_stride,                                                   \
            v_b_stride, v_h_stride, v_seq_stride,                                                   \
            y_b_stride, y_h_stride,                                                                 \
            scale, stream);                                                                         \
    }                                                                                               \
    extern "C" int32_t baracuda_kernels_ ## NAME ## _can_implement(                                 \
        int32_t batch, int32_t heads, int32_t num_kv_heads,                                         \
        int32_t k_len, int32_t head_dim)                                                            \
    {                                                                                               \
        if (batch <= 0 || heads <= 0 || num_kv_heads <= 0 || head_dim <= 0) return 2;               \
        if (heads % num_kv_heads != 0) return 2;                                                    \
        if (head_dim > baracuda::flash_decoding::kMaxD) return 3;                                   \
        if (k_len < 0) return 2;                                                                    \
        return 0;                                                                                   \
    }                                                                                               \
    extern "C" size_t baracuda_kernels_ ## NAME ## _workspace_bytes(                                \
        int32_t batch, int32_t heads, int32_t k_len, int32_t head_dim)                              \
    {                                                                                               \
        return baracuda::flash_decoding::flash_decoding_workspace_bytes(                            \
            batch, heads, k_len, head_dim);                                                         \
    }