oxibonsai-kernels 0.1.4

1-bit Q1_0_g128 compute kernels (dequant, GEMV, GEMM) for OxiBonsai
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
//! CUDA C kernel source strings for OxiBonsai FP8 E4M3/E5M2 batch GEMM (prefill) operations.
//!
//! # Prefill kernel catalogue
//!
//! | Kernel                               | Description                                                              |
//! |--------------------------------------|--------------------------------------------------------------------------|
//! | `gemm_fp8_e4m3`                      | Batch GEMM: FP8 E4M3 AoS, col-major I/O, 8-column chunks, output `+=`    |
//! | `gemm_fp8_e4m3_residual`             | FP8 E4M3 GEMM + fused residual add: `out = residual + W·x`               |
//! | `fused_gate_up_swiglu_gemm_fp8_e4m3` | Fused gate+up FP8 E4M3 GEMM with SwiGLU epilogue: `out = SiLU(g) · u`    |
//! | `gemv_fp8_e4m3_pf`                   | Single-token FP8 E4M3 GEMV (for sequential attention pass)               |
//! | `gemm_fp8_e5m2`                      | Batch GEMM: FP8 E5M2 AoS, col-major I/O, 8-column chunks, output `+=`    |
//! | `gemm_fp8_e5m2_residual`             | FP8 E5M2 GEMM + fused residual add: `out = residual + W·x`               |
//! | `fused_gate_up_swiglu_gemm_fp8_e5m2` | Fused gate+up FP8 E5M2 GEMM with SwiGLU epilogue: `out = SiLU(g) · u`    |
//! | `gemv_fp8_e5m2_pf`                   | Single-token FP8 E5M2 GEMV (for sequential attention pass)               |
//!
//! # Block layout (AoS, 34 bytes/block — matches `BlockFP8E4M3` / `BlockFP8E5M2`)
//!
//! ```text
//! bytes  0-31: 32 FP8 quantized weights   (E4M3 or E5M2)
//! bytes 32-33: FP16 LE block scale
//! ```
//!
//! This differs from Q8_0 (scale at bytes 0-1, weights at 2-33).
//! Scale access: `bptr[32] | ((unsigned short)bptr[33] << 8u)`.
//!
//! # Batch tensor layout
//!
//! All batch inputs/outputs use **column-major** layout: `buf[col * dim + element]`
//! where `col` is the batch/token index.
//!
//! # Grid / block config
//!
//! All kernels:
//! - Grid:  `(ceil(n_rows / 8), 1, 1)` — 8 warps per CTA, one warp per output row
//!   (`n_ffn_rows` in place of `n_rows` for the fused gate+up kernels)
//! - Block: `(256, 1, 1)` — 8 warps × 32 lanes

// This module is compiled only when the `native-cuda` feature is enabled and
// the target OS is Linux or Windows.
#![cfg(all(
    feature = "native-cuda",
    any(target_os = "linux", target_os = "windows")
))]

/// CUDA C source for FP8 E4M3/E5M2 batch GEMM (prefill) kernels.
///
/// The string defines eight `extern "C"` kernels:
/// `gemm_fp8_{e4m3,e5m2}`, `gemm_fp8_{e4m3,e5m2}_residual`,
/// `fused_gate_up_swiglu_gemm_fp8_{e4m3,e5m2}` and `gemv_fp8_{e4m3,e5m2}_pf`.
/// It also defines `static __device__` helpers whose names start with `fp8_pf_`.
///
/// All kernels use AoS weight layout (blocks stored contiguously as-is from GGUF).
/// Batch tensors use column-major layout: `buf[col * dim + element]`.
/// Every kernel expects 256-thread blocks (8 warps) and one warp per output row.
pub const CUDA_FP8_PREFILL_KERNELS_SRC: &str = r#"
/* =========================================================================
   OxiBonsai CUDA FP8 E4M3 / E5M2 prefill (batch GEMM) kernels.

   FP8 AoS block (34 bytes): [q0..q31, scale_lo, scale_hi]
     bytes  0-31: 32 FP8 quantized weights (E4M3 or E5M2)
     bytes 32-33: FP16 LE block scale
   Scale access: bptr[32] | ((unsigned short)bptr[33] << 8u)
   Weight at index w (w in [0, 32)): fp8_pf_e4m3_to_float(bptr[w]) for E4M3,
                                     fp8_pf_e5m2_to_float(bptr[w]) for E5M2

   Batch tensors: column-major  buf[col * dim + element]
   Grid:  (ceil(n_rows/8), 1, 1)  — 8 warps per CTA
   Block: (256, 1, 1)             — 8 warps × 32 lanes
   ========================================================================= */

/* ── Raw IEEE half bits → float, using one PTX cvt.f32.f16 instruction ──── */
static __device__ __forceinline__ float fp8_pf_fast_fp16_to_float(unsigned short h) {
    float out;
    asm("cvt.f32.f16 %0, %1;"
        : "=f"(out)   /* 32-bit float destination register       */
        : "h"(h));    /* 16-bit register holding the raw half bits */
    return out;
}

/* ── SiLU (swish): x * sigmoid(x), evaluated as x / (1 + e^-x) ─────────── */
static __device__ __forceinline__ float fp8_pf_silu(float x) {
    const float denom = 1.0f + expf(-x);
    return x / denom;
}

/* ── FP8 E4M3FN → FP32 (OFP8: 1 sign, 4 exponent, 3 mantissa bits, bias 7) ─
   Bit layout: s[7] e[6:3] m[2:0]
     e != 0 : (-1)^s * 2^(e-7) * (1 + m/8)    (normal)
     e == 0 : (-1)^s * 2^(-6)  * (m/8)        (subnormal)
   The NaN encodings 0x7F and 0xFF (e == 0b1111, m == 0b111) decode to 0.0f,
   so they cannot propagate NaN through a dot product.
   ─────────────────────────────────────────────────────────────────────────── */
static __device__ __forceinline__ float fp8_pf_e4m3_to_float(unsigned char b) {
    /* 0x7F and 0xFF are exactly the bytes whose low seven bits are all set. */
    if ((b & 0x7Fu) == 0x7Fu) return 0.0f;

    const unsigned int e        = ((unsigned int)b >> 3u) & 0xFu;
    const unsigned int m        = (unsigned int)b & 0x7u;
    const bool         negative = (b & 0x80u) != 0u;

    float magnitude;
    if (e != 0u) {
        /* Rebias the exponent (-7 + 127 = +120) and left-align the 3-bit
           mantissa inside the 23-bit FP32 fraction field. */
        magnitude = __int_as_float(((e + 120u) << 23u) | (m << 20u));
    } else {
        /* Subnormal: m * 2^-3 * 2^-6 */
        magnitude = (float)m * (1.0f / 8.0f) * (1.0f / 64.0f);
    }
    return negative ? -magnitude : magnitude;
}

/* ── FP8 E5M2 → FP32 (1 sign, 5 exponent, 2 mantissa bits, bias 15) ────────
   Bit layout: s[7] e[6:2] m[1:0]
     e != 0  : (-1)^s * 2^(e-15) * (1 + m/4)   (normal)
     e == 0  : (-1)^s * 2^(-14)  * (m/4)       (subnormal)
     e == 31 : Inf / NaN, decoded as 0.0f
   ─────────────────────────────────────────────────────────────────────────── */
static __device__ __forceinline__ float fp8_pf_e5m2_to_float(unsigned char b) {
    const unsigned int e = ((unsigned int)b >> 2u) & 0x1Fu;
    if (e == 0x1Fu) return 0.0f;                 /* all-ones exponent: Inf / NaN */

    const unsigned int m        = (unsigned int)b & 0x3u;
    const bool         negative = (b & 0x80u) != 0u;

    float magnitude;
    if (e != 0u) {
        /* Rebias the exponent (-15 + 127 = +112) and left-align the 2-bit
           mantissa inside the 23-bit FP32 fraction field. */
        magnitude = __int_as_float(((e + 112u) << 23u) | (m << 21u));
    } else {
        /* Subnormal: m * 2^-2 * 2^-14 */
        magnitude = (float)m * (1.0f / 4.0f) * (1.0f / 16384.0f);
    }
    return negative ? -magnitude : magnitude;
}

/* =========================================================================
   Kernel 1 — gemm_fp8_e4m3
   Batch FP8 E4M3 GEMM. Accumulates into outputs with +=.

   blocks:  AoS, 34 bytes/block — bytes 0-31 hold 32 FP8 E4M3 weights,
            bytes 32-33 the FP16 LE block scale. Row r owns blocks
            [r * (k/32), (r+1) * (k/32)).
   inputs:  col-major [batch_size * k]:       inputs[col * k + i]
   outputs: col-major [batch_size * n_rows]:  outputs[col * n_rows + row],
            accumulated with += (the caller supplies the initial contents).
   k must be a positive multiple of 32; a tail of k % 32 elements is ignored.

   Launch: grid (ceil(n_rows / 8), 1, 1), block (256, 1, 1).
   One warp per output row. Lanes stride over the row's blocks. Each block's
   32 weights are decoded once and reused for up to 8 batch columns. The
   per-lane partial sums are combined with a warp-shuffle reduction.
   ========================================================================= */
extern "C" __global__ void gemm_fp8_e4m3(
    const unsigned char* __restrict__ blocks,
    const float*         __restrict__ inputs,
    float*               __restrict__ outputs,
    unsigned int n_rows,
    unsigned int k,
    unsigned int batch_size
) {
    const unsigned int warp_id = threadIdx.x >> 5u;
    const unsigned int lane    = threadIdx.x & 31u;
    const unsigned int row     = blockIdx.x * 8u + warp_id;
    /* Warp-uniform exit: row depends only on warp_id, so whole warps leave
       together and the full-mask shuffles below remain valid. */
    if (row >= n_rows) return;

    const unsigned int blocks_per_row = k >> 5u;  /* k / 32 */
    /* Index of this row's first block. It is widened before the multiply so
       that row * blocks_per_row cannot wrap in 32-bit arithmetic. */
    const unsigned long long row_block0 = (unsigned long long)row * blocks_per_row;

    /* Process batch columns in chunks of at most 8. */
    for (unsigned int col_base = 0u; col_base < batch_size; col_base += 8u) {
        const unsigned int cols_remaining = batch_size - col_base;
        const unsigned int cols = cols_remaining < 8u ? cols_remaining : 8u;

        /* Only indexed by fully unrolled loop counters. This lets the compiler
           keep the array in registers; a runtime index would push it to local memory. */
        float col_sums[8];
        #pragma unroll
        for (unsigned int c = 0u; c < 8u; ++c) col_sums[c] = 0.0f;

        for (unsigned int b = lane; b < blocks_per_row; b += 32u) {
            /* AoS block: bytes 0-31 FP8 weights, bytes 32-33 FP16 LE scale. */
            const unsigned char* bptr = blocks + (row_block0 + b) * 34ull;
            const unsigned short d_raw = (unsigned short)bptr[32u] | ((unsigned short)bptr[33u] << 8u);
            const float scale = fp8_pf_fast_fp16_to_float(d_raw);
            const unsigned int base = b << 5u;  /* b * 32 */

            /* Decode the block's weights once; every column reuses them. */
            float wf[32];
            #pragma unroll
            for (unsigned int w = 0u; w < 32u; ++w) wf[w] = fp8_pf_e4m3_to_float(bptr[w]);

            #pragma unroll
            for (unsigned int c = 0u; c < 8u; ++c) {
                if (c < cols) {  /* cols is warp-uniform, so this does not diverge */
                    const float* xbase = inputs + (unsigned long long)(col_base + c) * k + base;
                    float bsum = 0.0f;
                    #pragma unroll
                    for (unsigned int w = 0u; w < 32u; ++w) bsum += wf[w] * xbase[w];
                    col_sums[c] += scale * bsum;
                }
            }
        }

        /* Warp-shuffle reduction; lane 0 accumulates into the column-major output. */
        #pragma unroll
        for (unsigned int c = 0u; c < 8u; ++c) {
            if (c < cols) {
                float s = col_sums[c];
                s += __shfl_down_sync(0xffffffffu, s, 16u);
                s += __shfl_down_sync(0xffffffffu, s,  8u);
                s += __shfl_down_sync(0xffffffffu, s,  4u);
                s += __shfl_down_sync(0xffffffffu, s,  2u);
                s += __shfl_down_sync(0xffffffffu, s,  1u);
                if (lane == 0u)
                    outputs[(unsigned long long)(col_base + c) * n_rows + row] += s;
            }
        }
    }
}

/* =========================================================================
   Kernel 2 — gemm_fp8_e4m3_residual
   Batch FP8 E4M3 GEMM + fused residual add.
   For each (row, col): outputs[col*n_rows+row] = residual[col*n_rows+row] + sum
   where sum is the FP8 E4M3 row · input-column dot product. Weight/input
   layout and launch configuration are the same as gemm_fp8_e4m3:
   grid (ceil(n_rows / 8), 1, 1), block (256, 1, 1); k a multiple of 32.

   outputs and residual are intentionally NOT __restrict__, so callers may pass
   the same buffer for both to perform the add in place. Each element is read
   from residual and then written to outputs by the same lane (lane 0 of the
   row's warp).
   ========================================================================= */
extern "C" __global__ void gemm_fp8_e4m3_residual(
    const unsigned char* __restrict__ blocks,
    const float*         __restrict__ inputs,
    float*                            outputs,   /* may alias residual */
    unsigned int n_rows,
    unsigned int k,
    unsigned int batch_size,
    const float*                      residual   /* may alias outputs  */
) {
    const unsigned int warp_id = threadIdx.x >> 5u;
    const unsigned int lane    = threadIdx.x & 31u;
    const unsigned int row     = blockIdx.x * 8u + warp_id;
    /* Warp-uniform exit (row depends only on warp_id): the full-mask shuffles stay valid. */
    if (row >= n_rows) return;

    const unsigned int blocks_per_row = k >> 5u;  /* k / 32 */
    /* First block of this row, computed in 64-bit so the multiply cannot wrap. */
    const unsigned long long row_block0 = (unsigned long long)row * blocks_per_row;

    for (unsigned int col_base = 0u; col_base < batch_size; col_base += 8u) {
        const unsigned int cols_remaining = batch_size - col_base;
        const unsigned int cols = cols_remaining < 8u ? cols_remaining : 8u;

        /* Indexed only by unrolled counters, so the compiler can keep it in registers. */
        float col_sums[8];
        #pragma unroll
        for (unsigned int c = 0u; c < 8u; ++c) col_sums[c] = 0.0f;

        for (unsigned int b = lane; b < blocks_per_row; b += 32u) {
            /* AoS block: bytes 0-31 FP8 weights, bytes 32-33 FP16 LE scale. */
            const unsigned char* bptr = blocks + (row_block0 + b) * 34ull;
            const unsigned short d_raw = (unsigned short)bptr[32u] | ((unsigned short)bptr[33u] << 8u);
            const float scale = fp8_pf_fast_fp16_to_float(d_raw);
            const unsigned int base = b << 5u;  /* b * 32 */

            /* Decode once per block; reused by every column in the chunk. */
            float wf[32];
            #pragma unroll
            for (unsigned int w = 0u; w < 32u; ++w) wf[w] = fp8_pf_e4m3_to_float(bptr[w]);

            #pragma unroll
            for (unsigned int c = 0u; c < 8u; ++c) {
                if (c < cols) {  /* warp-uniform guard */
                    const float* xbase = inputs + (unsigned long long)(col_base + c) * k + base;
                    float bsum = 0.0f;
                    #pragma unroll
                    for (unsigned int w = 0u; w < 32u; ++w) bsum += wf[w] * xbase[w];
                    col_sums[c] += scale * bsum;
                }
            }
        }

        /* Warp-shuffle reduction; lane 0 writes residual + sum. */
        #pragma unroll
        for (unsigned int c = 0u; c < 8u; ++c) {
            if (c < cols) {
                float s = col_sums[c];
                s += __shfl_down_sync(0xffffffffu, s, 16u);
                s += __shfl_down_sync(0xffffffffu, s,  8u);
                s += __shfl_down_sync(0xffffffffu, s,  4u);
                s += __shfl_down_sync(0xffffffffu, s,  2u);
                s += __shfl_down_sync(0xffffffffu, s,  1u);
                if (lane == 0u) {
                    const unsigned long long idx = (unsigned long long)(col_base + c) * n_rows + row;
                    outputs[idx] = residual[idx] + s;
                }
            }
        }
    }
}

/* =========================================================================
   Kernel 3 — fused_gate_up_swiglu_gemm_fp8_e4m3
   Batch fused gate+up FP8 E4M3 GEMM with SwiGLU epilogue.

   The concatenated gate+up weight matrix has 2*n_ffn_rows rows total:
     gate rows:  0   .. n_ffn_rows-1
     up   rows:  n_ffn_rows .. 2*n_ffn_rows-1
   blocks pointer covers all 2*n_ffn_rows rows in AoS layout
   (34 bytes/block: 32 FP8 E4M3 weights, then the FP16 LE scale).

   For each (row r, col c):
     outputs[c * n_ffn_rows + r] = SiLU(gate_sum(r,c)) * up_sum(r,c)

   Every output element is written with = (not +=), so the output buffer does
   not need to be zeroed beforehand.
   Launch: grid (ceil(n_ffn_rows / 8), 1, 1), block (256, 1, 1);
   k must be a positive multiple of 32.
   ========================================================================= */
extern "C" __global__ void fused_gate_up_swiglu_gemm_fp8_e4m3(
    const unsigned char* __restrict__ blocks,
    const float*         __restrict__ inputs,
    float*               __restrict__ outputs,
    unsigned int n_ffn_rows,
    unsigned int k,
    unsigned int batch_size
) {
    const unsigned int warp_id = threadIdx.x >> 5u;
    const unsigned int lane    = threadIdx.x & 31u;
    const unsigned int row     = blockIdx.x * 8u + warp_id;
    /* Warp-uniform exit (row depends only on warp_id): the full-mask shuffles stay valid. */
    if (row >= n_ffn_rows) return;

    const unsigned int blocks_per_row = k >> 5u;  /* k / 32 */
    /* First block of gate row `row` and of up row `row + n_ffn_rows`,
       computed in 64-bit so the products cannot wrap. */
    const unsigned long long gate_block0 = (unsigned long long)row * blocks_per_row;
    const unsigned long long up_block0   = ((unsigned long long)n_ffn_rows + row) * blocks_per_row;

    for (unsigned int col_base = 0u; col_base < batch_size; col_base += 8u) {
        const unsigned int cols_remaining = batch_size - col_base;
        const unsigned int cols = cols_remaining < 8u ? cols_remaining : 8u;

        /* Indexed only by unrolled counters, so the compiler can keep them in registers. */
        float gate_sums[8];
        float up_sums[8];
        #pragma unroll
        for (unsigned int c = 0u; c < 8u; ++c) { gate_sums[c] = 0.0f; up_sums[c] = 0.0f; }

        for (unsigned int b = lane; b < blocks_per_row; b += 32u) {
            /* Gate and up blocks — weights-first AoS layout */
            const unsigned char* gbptr = blocks + (gate_block0 + b) * 34ull;
            const unsigned char* ubptr = blocks + (up_block0   + b) * 34ull;
            const unsigned short gd_raw = (unsigned short)gbptr[32u] | ((unsigned short)gbptr[33u] << 8u);
            const unsigned short ud_raw = (unsigned short)ubptr[32u] | ((unsigned short)ubptr[33u] << 8u);
            const float gscale = fp8_pf_fast_fp16_to_float(gd_raw);
            const float uscale = fp8_pf_fast_fp16_to_float(ud_raw);
            const unsigned int base = b << 5u;  /* b * 32 */

            /* Decode both blocks once; every column in the chunk reuses them. */
            float gw[32];
            float uw[32];
            #pragma unroll
            for (unsigned int w = 0u; w < 32u; ++w) {
                gw[w] = fp8_pf_e4m3_to_float(gbptr[w]);
                uw[w] = fp8_pf_e4m3_to_float(ubptr[w]);
            }

            #pragma unroll
            for (unsigned int c = 0u; c < 8u; ++c) {
                if (c < cols) {  /* warp-uniform guard */
                    const float* xbase = inputs + (unsigned long long)(col_base + c) * k + base;
                    float gsum = 0.0f;
                    float usum = 0.0f;
                    #pragma unroll
                    for (unsigned int w = 0u; w < 32u; ++w) {
                        const float x = xbase[w];
                        gsum += gw[w] * x;
                        usum += uw[w] * x;
                    }
                    gate_sums[c] += gscale * gsum;
                    up_sums[c]   += uscale * usum;
                }
            }
        }

        /* Reduce both partials across the warp; lane 0 applies the SwiGLU epilogue. */
        #pragma unroll
        for (unsigned int c = 0u; c < 8u; ++c) {
            if (c < cols) {
                float gs = gate_sums[c];
                float us = up_sums[c];
                gs += __shfl_down_sync(0xffffffffu, gs, 16u);
                gs += __shfl_down_sync(0xffffffffu, gs,  8u);
                gs += __shfl_down_sync(0xffffffffu, gs,  4u);
                gs += __shfl_down_sync(0xffffffffu, gs,  2u);
                gs += __shfl_down_sync(0xffffffffu, gs,  1u);
                us += __shfl_down_sync(0xffffffffu, us, 16u);
                us += __shfl_down_sync(0xffffffffu, us,  8u);
                us += __shfl_down_sync(0xffffffffu, us,  4u);
                us += __shfl_down_sync(0xffffffffu, us,  2u);
                us += __shfl_down_sync(0xffffffffu, us,  1u);
                if (lane == 0u) {
                    outputs[(unsigned long long)(col_base + c) * n_ffn_rows + row] = fp8_pf_silu(gs) * us;
                }
            }
        }
    }
}

/* =========================================================================
   Kernel 4 — gemv_fp8_e4m3_pf
   Single-token FP8 E4M3 GEMV (for attention inner loop / sequential pass).
   output[row] = sum over k of weight_row * input   (written with =, not +=)

   blocks: AoS, 34 bytes/block (32 FP8 E4M3 weights, then FP16 LE scale).
   input:  k floats; output: n_rows floats.
   Launch: grid (ceil(n_rows / 8), 1, 1), block (256, 1, 1);
   k must be a positive multiple of 32.
   ========================================================================= */
extern "C" __global__ void gemv_fp8_e4m3_pf(
    const unsigned char* __restrict__ blocks,
    const float*         __restrict__ input,
    float*               __restrict__ output,
    unsigned int n_rows,
    unsigned int k
) {
    const unsigned int warp_id = threadIdx.x >> 5u;
    const unsigned int lane    = threadIdx.x & 31u;
    const unsigned int row     = blockIdx.x * 8u + warp_id;
    /* Warp-uniform exit (row depends only on warp_id): the full-mask shuffles stay valid. */
    if (row >= n_rows) return;

    const unsigned int blocks_per_row = k >> 5u;  /* k / 32 */
    /* First block of this row, computed in 64-bit so the multiply cannot wrap. */
    const unsigned long long row_block0 = (unsigned long long)row * blocks_per_row;

    float acc = 0.0f;
    for (unsigned int b = lane; b < blocks_per_row; b += 32u) {
        const unsigned char* bptr = blocks + (row_block0 + b) * 34ull;
        /* Scale at bytes 32-33 */
        const unsigned short d_raw = (unsigned short)bptr[32u] | ((unsigned short)bptr[33u] << 8u);
        const float scale = fp8_pf_fast_fp16_to_float(d_raw);
        const float* xbase = input + (b << 5u);
        float bsum = 0.0f;
        #pragma unroll 8
        for (unsigned int w = 0u; w < 32u; ++w) {
            bsum += fp8_pf_e4m3_to_float(bptr[w]) * xbase[w];
        }
        acc += scale * bsum;
    }

    /* Warp-shuffle reduction; lane 0 holds the row total. */
    acc += __shfl_down_sync(0xffffffffu, acc, 16u);
    acc += __shfl_down_sync(0xffffffffu, acc,  8u);
    acc += __shfl_down_sync(0xffffffffu, acc,  4u);
    acc += __shfl_down_sync(0xffffffffu, acc,  2u);
    acc += __shfl_down_sync(0xffffffffu, acc,  1u);
    if (lane == 0u) output[row] = acc;
}

/* =========================================================================
   Kernel 5 — gemm_fp8_e5m2
   Batch FP8 E5M2 GEMM. Accumulates into outputs with +=.

   blocks:  AoS, 34 bytes/block — bytes 0-31 hold 32 FP8 E5M2 weights,
            bytes 32-33 the FP16 LE block scale.
   inputs:  col-major [batch_size * k]:       inputs[col * k + i]
   outputs: col-major [batch_size * n_rows]:  outputs[col * n_rows + row],
            accumulated with += (the caller supplies the initial contents).
   k must be a positive multiple of 32; a tail of k % 32 elements is ignored.

   Launch: grid (ceil(n_rows / 8), 1, 1), block (256, 1, 1).
   One warp per output row. Each block's weights are decoded once and reused
   for up to 8 batch columns.
   ========================================================================= */
extern "C" __global__ void gemm_fp8_e5m2(
    const unsigned char* __restrict__ blocks,
    const float*         __restrict__ inputs,
    float*               __restrict__ outputs,
    unsigned int n_rows,
    unsigned int k,
    unsigned int batch_size
) {
    const unsigned int warp_id = threadIdx.x >> 5u;
    const unsigned int lane    = threadIdx.x & 31u;
    const unsigned int row     = blockIdx.x * 8u + warp_id;
    /* Warp-uniform exit (row depends only on warp_id): the full-mask shuffles stay valid. */
    if (row >= n_rows) return;

    const unsigned int blocks_per_row = k >> 5u;  /* k / 32 */
    /* First block of this row, computed in 64-bit so the multiply cannot wrap. */
    const unsigned long long row_block0 = (unsigned long long)row * blocks_per_row;

    for (unsigned int col_base = 0u; col_base < batch_size; col_base += 8u) {
        const unsigned int cols_remaining = batch_size - col_base;
        const unsigned int cols = cols_remaining < 8u ? cols_remaining : 8u;

        /* Indexed only by unrolled counters, so the compiler can keep it in registers. */
        float col_sums[8];
        #pragma unroll
        for (unsigned int c = 0u; c < 8u; ++c) col_sums[c] = 0.0f;

        for (unsigned int b = lane; b < blocks_per_row; b += 32u) {
            /* AoS block: bytes 0-31 FP8 weights, bytes 32-33 FP16 LE scale. */
            const unsigned char* bptr = blocks + (row_block0 + b) * 34ull;
            const unsigned short d_raw = (unsigned short)bptr[32u] | ((unsigned short)bptr[33u] << 8u);
            const float scale = fp8_pf_fast_fp16_to_float(d_raw);
            const unsigned int base = b << 5u;  /* b * 32 */

            /* Decode once per block; reused by every column in the chunk. */
            float wf[32];
            #pragma unroll
            for (unsigned int w = 0u; w < 32u; ++w) wf[w] = fp8_pf_e5m2_to_float(bptr[w]);

            #pragma unroll
            for (unsigned int c = 0u; c < 8u; ++c) {
                if (c < cols) {  /* warp-uniform guard */
                    const float* xbase = inputs + (unsigned long long)(col_base + c) * k + base;
                    float bsum = 0.0f;
                    #pragma unroll
                    for (unsigned int w = 0u; w < 32u; ++w) bsum += wf[w] * xbase[w];
                    col_sums[c] += scale * bsum;
                }
            }
        }

        /* Warp-shuffle reduction; lane 0 accumulates into the column-major output. */
        #pragma unroll
        for (unsigned int c = 0u; c < 8u; ++c) {
            if (c < cols) {
                float s = col_sums[c];
                s += __shfl_down_sync(0xffffffffu, s, 16u);
                s += __shfl_down_sync(0xffffffffu, s,  8u);
                s += __shfl_down_sync(0xffffffffu, s,  4u);
                s += __shfl_down_sync(0xffffffffu, s,  2u);
                s += __shfl_down_sync(0xffffffffu, s,  1u);
                if (lane == 0u)
                    outputs[(unsigned long long)(col_base + c) * n_rows + row] += s;
            }
        }
    }
}

/* =========================================================================
   Kernel 6 — gemm_fp8_e5m2_residual
   Batch FP8 E5M2 GEMM + fused residual add.
   For each (row, col): outputs[col*n_rows+row] = residual[col*n_rows+row] + sum
   where sum is the FP8 E5M2 row · input-column dot product. Layout and launch
   configuration are the same as gemm_fp8_e5m2:
   grid (ceil(n_rows / 8), 1, 1), block (256, 1, 1); k a multiple of 32.

   outputs and residual are intentionally NOT __restrict__, so callers may pass
   the same buffer for both to perform the add in place. Each element is read
   from residual and then written to outputs by the same lane (lane 0 of the
   row's warp).
   ========================================================================= */
extern "C" __global__ void gemm_fp8_e5m2_residual(
    const unsigned char* __restrict__ blocks,
    const float*         __restrict__ inputs,
    float*                            outputs,   /* may alias residual */
    unsigned int n_rows,
    unsigned int k,
    unsigned int batch_size,
    const float*                      residual   /* may alias outputs  */
) {
    const unsigned int warp_id = threadIdx.x >> 5u;
    const unsigned int lane    = threadIdx.x & 31u;
    const unsigned int row     = blockIdx.x * 8u + warp_id;
    /* Warp-uniform exit (row depends only on warp_id): the full-mask shuffles stay valid. */
    if (row >= n_rows) return;

    const unsigned int blocks_per_row = k >> 5u;  /* k / 32 */
    /* First block of this row, computed in 64-bit so the multiply cannot wrap. */
    const unsigned long long row_block0 = (unsigned long long)row * blocks_per_row;

    for (unsigned int col_base = 0u; col_base < batch_size; col_base += 8u) {
        const unsigned int cols_remaining = batch_size - col_base;
        const unsigned int cols = cols_remaining < 8u ? cols_remaining : 8u;

        /* Indexed only by unrolled counters, so the compiler can keep it in registers. */
        float col_sums[8];
        #pragma unroll
        for (unsigned int c = 0u; c < 8u; ++c) col_sums[c] = 0.0f;

        for (unsigned int b = lane; b < blocks_per_row; b += 32u) {
            /* AoS block: bytes 0-31 FP8 weights, bytes 32-33 FP16 LE scale. */
            const unsigned char* bptr = blocks + (row_block0 + b) * 34ull;
            const unsigned short d_raw = (unsigned short)bptr[32u] | ((unsigned short)bptr[33u] << 8u);
            const float scale = fp8_pf_fast_fp16_to_float(d_raw);
            const unsigned int base = b << 5u;  /* b * 32 */

            /* Decode once per block; reused by every column in the chunk. */
            float wf[32];
            #pragma unroll
            for (unsigned int w = 0u; w < 32u; ++w) wf[w] = fp8_pf_e5m2_to_float(bptr[w]);

            #pragma unroll
            for (unsigned int c = 0u; c < 8u; ++c) {
                if (c < cols) {  /* warp-uniform guard */
                    const float* xbase = inputs + (unsigned long long)(col_base + c) * k + base;
                    float bsum = 0.0f;
                    #pragma unroll
                    for (unsigned int w = 0u; w < 32u; ++w) bsum += wf[w] * xbase[w];
                    col_sums[c] += scale * bsum;
                }
            }
        }

        /* Warp-shuffle reduction; lane 0 writes residual + sum. */
        #pragma unroll
        for (unsigned int c = 0u; c < 8u; ++c) {
            if (c < cols) {
                float s = col_sums[c];
                s += __shfl_down_sync(0xffffffffu, s, 16u);
                s += __shfl_down_sync(0xffffffffu, s,  8u);
                s += __shfl_down_sync(0xffffffffu, s,  4u);
                s += __shfl_down_sync(0xffffffffu, s,  2u);
                s += __shfl_down_sync(0xffffffffu, s,  1u);
                if (lane == 0u) {
                    const unsigned long long idx = (unsigned long long)(col_base + c) * n_rows + row;
                    outputs[idx] = residual[idx] + s;
                }
            }
        }
    }
}

/* =========================================================================
   Kernel 7 — fused_gate_up_swiglu_gemm_fp8_e5m2
   Batch fused gate+up FP8 E5M2 GEMM with SwiGLU epilogue.

   Concatenated gate+up weight matrix: 2*n_ffn_rows rows total.
     gate rows 0..n_ffn_rows-1, up rows n_ffn_rows..2*n_ffn_rows-1.
   blocks pointer covers all 2*n_ffn_rows rows in FP8 E5M2 AoS layout
   (34 bytes/block: 32 FP8 E5M2 weights, then the FP16 LE scale).

   For each (row r, col c):
     outputs[c * n_ffn_rows + r] = SiLU(gate_sum(r,c)) * up_sum(r,c)

   Every output element is written with = (not +=), so the output buffer does
   not need to be zeroed beforehand.
   Launch: grid (ceil(n_ffn_rows / 8), 1, 1), block (256, 1, 1);
   k must be a positive multiple of 32.
   ========================================================================= */
extern "C" __global__ void fused_gate_up_swiglu_gemm_fp8_e5m2(
    const unsigned char* __restrict__ blocks,
    const float*         __restrict__ inputs,
    float*               __restrict__ outputs,
    unsigned int n_ffn_rows,
    unsigned int k,
    unsigned int batch_size
) {
    const unsigned int warp_id = threadIdx.x >> 5u;
    const unsigned int lane    = threadIdx.x & 31u;
    const unsigned int row     = blockIdx.x * 8u + warp_id;
    /* Warp-uniform exit (row depends only on warp_id): the full-mask shuffles stay valid. */
    if (row >= n_ffn_rows) return;

    const unsigned int blocks_per_row = k >> 5u;  /* k / 32 */
    /* First block of gate row `row` and of up row `row + n_ffn_rows`,
       computed in 64-bit so the products cannot wrap. */
    const unsigned long long gate_block0 = (unsigned long long)row * blocks_per_row;
    const unsigned long long up_block0   = ((unsigned long long)n_ffn_rows + row) * blocks_per_row;

    for (unsigned int col_base = 0u; col_base < batch_size; col_base += 8u) {
        const unsigned int cols_remaining = batch_size - col_base;
        const unsigned int cols = cols_remaining < 8u ? cols_remaining : 8u;

        /* Indexed only by unrolled counters, so the compiler can keep them in registers. */
        float gate_sums[8];
        float up_sums[8];
        #pragma unroll
        for (unsigned int c = 0u; c < 8u; ++c) { gate_sums[c] = 0.0f; up_sums[c] = 0.0f; }

        for (unsigned int b = lane; b < blocks_per_row; b += 32u) {
            /* Gate and up blocks — weights-first AoS layout */
            const unsigned char* gbptr = blocks + (gate_block0 + b) * 34ull;
            const unsigned char* ubptr = blocks + (up_block0   + b) * 34ull;
            const unsigned short gd_raw = (unsigned short)gbptr[32u] | ((unsigned short)gbptr[33u] << 8u);
            const unsigned short ud_raw = (unsigned short)ubptr[32u] | ((unsigned short)ubptr[33u] << 8u);
            const float gscale = fp8_pf_fast_fp16_to_float(gd_raw);
            const float uscale = fp8_pf_fast_fp16_to_float(ud_raw);
            const unsigned int base = b << 5u;  /* b * 32 */

            /* Decode both blocks once; every column in the chunk reuses them. */
            float gw[32];
            float uw[32];
            #pragma unroll
            for (unsigned int w = 0u; w < 32u; ++w) {
                gw[w] = fp8_pf_e5m2_to_float(gbptr[w]);
                uw[w] = fp8_pf_e5m2_to_float(ubptr[w]);
            }

            #pragma unroll
            for (unsigned int c = 0u; c < 8u; ++c) {
                if (c < cols) {  /* warp-uniform guard */
                    const float* xbase = inputs + (unsigned long long)(col_base + c) * k + base;
                    float gsum = 0.0f;
                    float usum = 0.0f;
                    #pragma unroll
                    for (unsigned int w = 0u; w < 32u; ++w) {
                        const float x = xbase[w];
                        gsum += gw[w] * x;
                        usum += uw[w] * x;
                    }
                    gate_sums[c] += gscale * gsum;
                    up_sums[c]   += uscale * usum;
                }
            }
        }

        /* Reduce both partials across the warp; lane 0 applies the SwiGLU epilogue. */
        #pragma unroll
        for (unsigned int c = 0u; c < 8u; ++c) {
            if (c < cols) {
                float gs = gate_sums[c];
                float us = up_sums[c];
                gs += __shfl_down_sync(0xffffffffu, gs, 16u);
                gs += __shfl_down_sync(0xffffffffu, gs,  8u);
                gs += __shfl_down_sync(0xffffffffu, gs,  4u);
                gs += __shfl_down_sync(0xffffffffu, gs,  2u);
                gs += __shfl_down_sync(0xffffffffu, gs,  1u);
                us += __shfl_down_sync(0xffffffffu, us, 16u);
                us += __shfl_down_sync(0xffffffffu, us,  8u);
                us += __shfl_down_sync(0xffffffffu, us,  4u);
                us += __shfl_down_sync(0xffffffffu, us,  2u);
                us += __shfl_down_sync(0xffffffffu, us,  1u);
                if (lane == 0u) {
                    outputs[(unsigned long long)(col_base + c) * n_ffn_rows + row] = fp8_pf_silu(gs) * us;
                }
            }
        }
    }
}

/* =========================================================================
   Kernel 8 — gemv_fp8_e5m2_pf
   Single-token FP8 E5M2 GEMV (for attention inner loop / sequential pass).
   output[row] = sum over k of weight_row * input   (written with =, not +=)

   blocks: AoS, 34 bytes/block (32 FP8 E5M2 weights, then FP16 LE scale).
   input:  k floats; output: n_rows floats.
   Launch: grid (ceil(n_rows / 8), 1, 1), block (256, 1, 1);
   k must be a positive multiple of 32.
   ========================================================================= */
extern "C" __global__ void gemv_fp8_e5m2_pf(
    const unsigned char* __restrict__ blocks,
    const float*         __restrict__ input,
    float*               __restrict__ output,
    unsigned int n_rows,
    unsigned int k
) {
    const unsigned int warp_id = threadIdx.x >> 5u;
    const unsigned int lane    = threadIdx.x & 31u;
    const unsigned int row     = blockIdx.x * 8u + warp_id;
    /* Warp-uniform exit (row depends only on warp_id): the full-mask shuffles stay valid. */
    if (row >= n_rows) return;

    const unsigned int blocks_per_row = k >> 5u;  /* k / 32 */
    /* First block of this row, computed in 64-bit so the multiply cannot wrap. */
    const unsigned long long row_block0 = (unsigned long long)row * blocks_per_row;

    float acc = 0.0f;
    for (unsigned int b = lane; b < blocks_per_row; b += 32u) {
        const unsigned char* bptr = blocks + (row_block0 + b) * 34ull;
        /* Scale at bytes 32-33 */
        const unsigned short d_raw = (unsigned short)bptr[32u] | ((unsigned short)bptr[33u] << 8u);
        const float scale = fp8_pf_fast_fp16_to_float(d_raw);
        const float* xbase = input + (b << 5u);
        float bsum = 0.0f;
        #pragma unroll 8
        for (unsigned int w = 0u; w < 32u; ++w) {
            bsum += fp8_pf_e5m2_to_float(bptr[w]) * xbase[w];
        }
        acc += scale * bsum;
    }

    /* Warp-shuffle reduction; lane 0 holds the row total. */
    acc += __shfl_down_sync(0xffffffffu, acc, 16u);
    acc += __shfl_down_sync(0xffffffffu, acc,  8u);
    acc += __shfl_down_sync(0xffffffffu, acc,  4u);
    acc += __shfl_down_sync(0xffffffffu, acc,  2u);
    acc += __shfl_down_sync(0xffffffffu, acc,  1u);
    if (lane == 0u) output[row] = acc;
}
"#;