aprender-gpu 0.32.0

Pure Rust PTX generation for NVIDIA CUDA - no LLVM, no nvcc
//! PMAT-054/293: Fused FP32-input Q4K GEMV kernel.
//!
//! Eliminates the separate Q8_1 quantize kernel launch by reading FP32
//! activations directly within the GEMV inner loop. Each thread:
//!
//! 1. Loads 4 FP32 activation values per packed nibble word
//! 2. Dequantizes its Q4K weight nibbles in registers
//! 3. Accumulates the dot product via FP32 FMA
//!
//! This fuses 2 kernel launches into 1 per projection, saving ~12us launch
//! overhead per projection. At 7 projections x 28 layers = 196 projections
//! saved per decode step = ~2.4ms savings on RTX 4060L.
//!
//! Key differences from BatchedHwDp4aQ4KGemvKernel:
//! - Input is `x_ptr` (FP32 activation buffer), NOT `q8_ptr` (Q8_1 buffer)
//! - No Q8 quantization anywhere: Q4K weights are dequantized to FP32 in
//!   registers and accumulated with FP32 FMA instead of DP4A
//!
//! # Threading Model
//!
//! Same half-warp structure as BatchedHwDp4aQ4KGemvKernel:
//! - 16 threads per half-warp, 3 warps = 6 half-warps per CTA
//! - Each half-warp processes 1 super-block (256 Q4K values = 8 Q8 blocks)
//! - Grid-stride over N output rows
//!
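//! Illustrative index math for the default 3-warp CTA (tid in 0..96),
//! matching the register computations in `build_ptx`:
//!
//! ```text
//! warp_id       = tid / 32                     // 0, 1, 2
//! lane_id       = tid % 32
//! half_warp_id  = warp_id * 2 + lane_id / 16   // 0..6
//! half_lane     = lane_id % 16
//! bq8_group     = half_lane / 4                // one of four 64-value groups
//! lane_in_group = half_lane % 4                // each thread covers 16 values
//! ```
//!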
//! # Why FP32 FMA Instead of DP4A
//!
//! An earlier design quantized activations to Q8 in registers: find the
//! per-block absmax via a full-warp shuffle reduction, quantize to INT8,
//! pack 4 INT8 values into a u32, and feed DP4A. Quantizing a 32-value Q8
//! block across 16 half-warp threads needs cooperative loading with the
//! full warp and does not fit the half-warp data mapping cleanly, so the
//! shipped kernel skips Q8 entirely and uses FP32 multiply-accumulate. On
//! sm_89 (Ada), FP32 throughput is 128 ops/cycle vs DP4A's 256 ops/cycle,
//! but we eliminate both the Q8 kernel launch and the extra activation
//! read, and the kernel is bandwidth-bound anyway (see the performance
//! model on the struct below).

use crate::kernels::quantize::{Kernel, Q4K_SUPER_BLOCK_BYTES, Q4K_SUPER_BLOCK_SIZE};
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory, PtxSync};
use crate::ptx::{PtxKernel, PtxReg, PtxType};

/// PMAT-054/293: Fused FP32-input Q4K GEMV kernel (no Q8 pre-quantization).
///
/// Reads FP32 activations directly, dequants Q4K weights to FP32, and
/// accumulates via FP32 FMA. Eliminates the separate Q8_1 quantize launch.
///
/// # Performance Model
///
/// - Weight read: N × K/256 × 144 bytes (Q4K super-blocks, from DRAM)
/// - Activation read: M × K × 4 bytes (FP32, from L2 cache — hot)
/// - Compute: N × K × M FP32 MADs (128 ops/cycle on sm_89)
/// - vs DP4A: N × K × M DP4A (256 ops/cycle) + Q8 quantize overhead
///
/// The kernel is bandwidth-bound (weight reads dominate), so FP32 vs DP4A
/// compute throughput is irrelevant. What matters: 1 launch vs 2.
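///
/// For example, at K = N = 1536 each projection reads 1536 x 6 x 144 bytes
/// of weights (~1.3 MB) vs M x 1536 x 4 bytes of activations (24 KB at
/// M = 4), so weight traffic dominates by roughly 50x.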
pub struct FusedFp32Q4KGemvKernel {
    /// K dimension (input dimension, must be multiple of 256)
    pub k: u32,
    /// N dimension (output dimension)
    pub n: u32,
    /// Batch size M
    pub m: u32,
    /// Number of warps per block (default: 3)
    pub num_warps: u32,
}

impl FusedFp32Q4KGemvKernel {
    /// Creates a new fused FP32 Q4K GEMV kernel with the given dimensions and default warp count.
    pub fn new(k: u32, n: u32, m: u32) -> Self {
        Self {
            k,
            n,
            m,
            num_warps: 3,
        }
    }
}

impl Kernel for FusedFp32Q4KGemvKernel {
    fn name(&self) -> &str {
        "fused_fp32_q4k_gemv"
    }

    fn build_ptx(&self) -> PtxKernel {
        let num_warps = self.num_warps;
        let num_half_warps = num_warps * 2;
        let m = self.m;
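        // One f32 partial per (half-warp, batch element), consumed by the
        // cross-half-warp reduction at the end of each row.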
        let smem_size = (num_half_warps * m * 4) as usize;

        PtxKernel::new("fused_fp32_q4k_gemv")
            .param(PtxType::U64, "y_ptr") // Output: [M * N] f32
            .param(PtxType::U64, "w_ptr") // Q4K weights: [N * ceil(K/256) * 144] bytes
            .param(PtxType::U64, "x_ptr") // FP32 input: [M * K] f32 (NOT Q8!)
            .param(PtxType::U32, "k_dim")
            .param(PtxType::U32, "n_dim")
            .param(PtxType::U32, "m_dim")
            .shared_memory(smem_size)
            .max_regs(255)
            .build(move |ctx| {
                // ===== Thread identity =====
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let thread_id = ctx.special_reg(PtxReg::TidX);
                let lane_id = ctx.rem_u32(thread_id, 32);
                let warp_id = ctx.div_u32(thread_id, 32);
                let grid_dim = ctx.special_reg(PtxReg::NctaIdX);

                // ===== Parameters =====
                let n_dim = ctx.load_param_u32("n_dim");
                let k_dim = ctx.load_param_u32("k_dim");
                let y_ptr = ctx.load_param_u64("y_ptr");
                let w_ptr = ctx.load_param_u64("w_ptr");
                let x_ptr = ctx.load_param_u64("x_ptr");

                let k_rounded = ctx.add_u32(k_dim, Q4K_SUPER_BLOCK_SIZE - 1);
                let num_sb = ctx.div_u32(k_rounded, Q4K_SUPER_BLOCK_SIZE);
                let sb_bytes_reg = ctx.mov_u32_imm(Q4K_SUPER_BLOCK_BYTES);
                let row_bytes = ctx.mul_u32_reg(num_sb, sb_bytes_reg);

                // ===== Half-warp mapping (same as HwDp4a) =====
                let half_lane = ctx.and_u32_imm(lane_id, 15);
                let half_warp_in_warp = ctx.shr_u32_imm(lane_id, 4);
                let warp_x2 = ctx.shl_u32_imm(warp_id, 1);
                let half_warp_id = ctx.add_u32_reg(warp_x2, half_warp_in_warp);
                let num_hw = ctx.mov_u32_imm(num_half_warps);

                // ===== Per-thread data mapping =====
                // Same as HwDp4a: 4 groups of 4 threads; each thread loads two
                // 4-byte q4 words (v0, v1), i.e. 16 nibble values per super-block
                let bq8_group = ctx.shr_u32_imm(half_lane, 2);
                let lane_in_group = ctx.and_u32_imm(half_lane, 3);
                let bq8_offset = ctx.shl_u32_imm(bq8_group, 1);

                // Q4K qs offset: 16 (header) + 16 * bq8_offset + 4 * lane_in_group
                let t1 = ctx.shl_u32_imm(bq8_offset, 4);
                let t2 = ctx.shl_u32_imm(lane_in_group, 2);
                let q4_local = ctx.add_u32_reg(t1, t2);
                let q4_off = ctx.add_u32(q4_local, 16);
                let q4_off_64 = ctx.cvt_u64_u32(q4_off);

                // FP32 activation offsets (replaces Q8 offsets)
                // Each thread reads 4 values at: sb_base_k + bq8_offset*32 + lane_in_group*4
                // where sb_base_k = sb_idx * 256 (super-block K offset in elements)
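                // e.g. sb_idx = 1, bq8_offset = 2, lane_in_group = 3 gives
                // k = 256 + 64 + 12 = 332, so this thread reads x[332..336].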
                let c_256_u32 = ctx.mov_u32_imm(256);
                let c_4_u64 = ctx.mov_u64_imm(4);

                // Activation stride per batch element: K * 4 bytes
                let c_4_stride = ctx.mov_u32_imm(4);
                let x_vec_stride = ctx.mul_u32_reg(k_dim, c_4_stride);

                // Hoisted constants
                let c_2_64 = ctx.mov_u64_imm(2);
                let c_4_64 = ctx.mov_u64_imm(4);
                let c_8_64 = ctx.mov_u64_imm(8);
                let c_12_64 = ctx.mov_u64_imm(12);
                let c_16_64 = ctx.mov_u64_imm(16);
                let _c_32_64 = ctx.mov_u64_imm(32);

                // Scale extraction invariants
                let ci_mod2 = ctx.and_u32_imm(bq8_group, 1);
                let c_16_u32 = ctx.mov_u32_imm(16);
                let byte_shift = ctx.mul_u32_reg(ci_mod2, c_16_u32);
                let c_8_u32 = ctx.mov_u32_imm(8);
                let byte_shift_hi = ctx.add_u32_reg(byte_shift, c_8_u32);
                let c_2_u32 = ctx.mov_u32_imm(2);
                let p_hi = ctx.setp_ge_u32(bq8_group, c_2_u32);

                // Bitmask constants
                let c_mask_6bit = ctx.mov_u32_imm(0x3F3F_3F3F);
                let c_mask_4bit = ctx.mov_u32_imm(0x0F0F_0F0F);
                let c_mask_2bit = ctx.mov_u32_imm(0x0303_0303);

                // M accumulators
                let f32_zero = ctx.mov_f32_imm(0.0);
                let mut accs = Vec::with_capacity(m as usize);
                for _ in 0..m {
                    accs.push(ctx.mov_f32_imm(0.0));
                }

                // ===== Grid-stride row loop =====
                let row_idx = ctx.mov_u32_imm(0);
                ctx.add_u32_reg_inplace(row_idx, block_id);

                ctx.label("fused_row_loop");
                let row_oob = ctx.setp_ge_u32(row_idx, n_dim);
                ctx.branch_if(row_oob, "fused_exit");

                let row_off = ctx.mul_wide_u32_reg(row_idx, row_bytes);
                let row_base = ctx.add_u64(w_ptr, row_off);

                for acc in &accs {
                    ctx.mov_f32_reg(*acc, f32_zero);
                }

                // SB loop
                let sb_idx = ctx.mov_u32_imm(0);
                ctx.add_u32_reg_inplace(sb_idx, half_warp_id);

                ctx.label("fused_sb_loop");
                let sb_done = ctx.setp_ge_u32(sb_idx, num_sb);
                ctx.branch_if(sb_done, "fused_sb_end");

                let sb_off = ctx.mul_wide_u32(sb_idx, Q4K_SUPER_BLOCK_BYTES);
                let sb_addr = ctx.add_u64(row_base, sb_off);

                // Load d, dmin (shared across M)
                let d_f16 = ctx.ld_global_f16(sb_addr);
                let d = ctx.cvt_f32_f16(d_f16);
                let dmin_addr = ctx.add_u64(sb_addr, c_2_64);
                let dmin_f16 = ctx.ld_global_f16(dmin_addr);
                let dmin = ctx.cvt_f32_f16(dmin_f16);

                // Scale loading (shared)
                let sc_base = ctx.add_u64(sb_addr, c_4_64);
                let sc03 = ctx.ld_global_u32(sc_base);
                let sc47_addr = ctx.add_u64(sc_base, c_4_64);
                let sc47 = ctx.ld_global_u32(sc47_addr);
                let sc811_addr = ctx.add_u64(sc_base, c_8_64);
                let sc811 = ctx.ld_global_u32(sc811_addr);

                // Scale extraction (same as HwDp4a)
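                // Q4K packs 8 (scale, min) pairs as 6-bit values into 12 bytes:
                // bytes 0-3 carry scales 0-3 (low 6 bits), bytes 4-7 carry mins
                // 0-3, and bytes 8-11 carry the low 4 bits of scales/mins 4-7;
                // their top 2 bits come from the high bits of bytes 0-7.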
                let sc_lo4 = ctx.and_u32(sc03, c_mask_6bit);
                let mn_lo4 = ctx.and_u32(sc47, c_mask_6bit);
                let sc_hi_low = ctx.and_u32(sc811, c_mask_4bit);
                let t = ctx.shr_u32_imm(sc03, 6);
                let t = ctx.and_u32(t, c_mask_2bit);
                let sc_hi_top = ctx.shl_u32_imm(t, 4);
                let sc_hi4 = ctx.or_u32(sc_hi_low, sc_hi_top);

                let mn_hi_raw = ctx.shr_u32_imm(sc811, 4);
                let mn_hi_low = ctx.and_u32(mn_hi_raw, c_mask_4bit);
                let t = ctx.shr_u32_imm(sc47, 6);
                let t = ctx.and_u32(t, c_mask_2bit);
                let mn_hi_top = ctx.shl_u32_imm(t, 4);
                let mn_hi4 = ctx.or_u32(mn_hi_low, mn_hi_top);

                let sc_src = ctx.selp_u32(p_hi, sc_hi4, sc_lo4);
                let mn_src = ctx.selp_u32(p_hi, mn_hi4, mn_lo4);

                let sc0 = ctx.bfe_u32_reg(sc_src, byte_shift, 8);
                let sc1 = ctx.bfe_u32_reg(sc_src, byte_shift_hi, 8);
                let mn0 = ctx.bfe_u32_reg(mn_src, byte_shift, 8);
                let mn1 = ctx.bfe_u32_reg(mn_src, byte_shift_hi, 8);

                // Load Q4K data (shared across M)
                let q4_addr = ctx.add_u64(sb_addr, q4_off_64);
                let v0 = ctx.ld_global_u32(q4_addr);
                let v1_addr = ctx.add_u64(q4_addr, c_16_64);
                let v1 = ctx.ld_global_u32(v1_addr);

                // Extract nibbles
                let v0_lo = ctx.and_u32(v0, c_mask_4bit);
                let v1_lo = ctx.and_u32(v1, c_mask_4bit);
                let v0_hi = ctx.shr_u32_imm(v0, 4);
                let v0_hi = ctx.and_u32(v0_hi, c_mask_4bit);
                let v1_hi = ctx.shr_u32_imm(v1, 4);
                let v1_hi = ctx.and_u32(v1_hi, c_mask_4bit);

                // Dequant Q4K to FP32 (d * sc * nibble - dmin * mn)
                // Each thread has 4 low nibbles (v0_lo packed) and 4 high nibbles (v0_hi packed)
                // We need to extract individual nibble values and compute:
                //   w_dequant = d * sc_i * nibble_val - dmin * mn_i
                //
                // For FP32 dot product: acc += w_dequant * x_val
                //
                // Optimized: factor out constants across 4 values:
                //   acc += d * sc_i * sum(nibble_j * x_j) - dmin * mn_i * sum(x_j)
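                //
                // (Algebraically: sum_j (d*sc_i*nib_j - dmin*mn_i) * x_j
                //  = d*sc_i * sum_j nib_j*x_j - dmin*mn_i * sum_j x_j; the
                //  scalar reference test at the bottom of this file checks it.)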

                // Compute K offset for this thread's activation values
                // sb_idx * 256 + bq8_offset * 32 + lane_in_group * 4
                let sb_k_base = ctx.mul_u32_reg(sb_idx, c_256_u32);
                let bq8_k_off = ctx.shl_u32_imm(bq8_offset, 5); // * 32
                let lig_k_off = ctx.shl_u32_imm(lane_in_group, 2); // * 4
                let k_off_base = ctx.add_u32_reg(sb_k_base, bq8_k_off);
                let k_off_lo = ctx.add_u32_reg(k_off_base, lig_k_off);
                // High nibbles of the same qs bytes map to the next 32 elements
                // of this 64-element group (Q4K dequant emits a 32-byte chunk's
                // low nibbles, then its high nibbles), which also matches the
                // scale pairing: sc1/mn1 cover the block 32 elements after
                // sc0/mn0's block.
                let k_off_hi = ctx.add_u32(k_off_lo, 32);

                // Per-batch-element FP32 activation loading + dequant dot product
                for mi in 0..m {
                    // X base for batch element mi
                    let x_m_off = if mi == 0 {
                        ctx.mov_u64_imm(0)
                    } else {
                        let mi_reg = ctx.mov_u32_imm(mi);
                        ctx.mul_wide_u32_reg(mi_reg, x_vec_stride)
                    };
                    let x_m_base = ctx.add_u64(x_ptr, x_m_off);

                    // Load 4 FP32 activation values for low nibbles
                    let lo_byte_off = ctx.mul_wide_u32(k_off_lo, 4); // * sizeof(f32)
                    let x_lo_addr = ctx.add_u64(x_m_base, lo_byte_off);
                    let x0 = ctx.ld_global_f32(x_lo_addr);
                    let x1_addr = ctx.add_u64(x_lo_addr, c_4_u64);
                    let x1 = ctx.ld_global_f32(x1_addr);
                    let x2_addr = ctx.add_u64(x_lo_addr, c_8_64);
                    let x2 = ctx.ld_global_f32(x2_addr);
                    let x3_addr = ctx.add_u64(x_lo_addr, c_12_64); // 4th f32 at +12 bytes
                    let x3 = ctx.ld_global_f32(x3_addr);

                    // Extract individual nibbles from packed u32
                    // v0_lo contains 4 packed 4-bit values in bytes: [b0, b1, b2, b3]
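                    // e.g. v0_lo = 0x0F02_0A03 unpacks to nib0 = 0x3,
                    // nib1 = 0xA, nib2 = 0x2, nib3 = 0xF.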
                    let nib0 = ctx.and_u32_imm(v0_lo, 0x0F);
                    let nib1_sh = ctx.shr_u32_imm(v0_lo, 8);
                    let nib1 = ctx.and_u32_imm(nib1_sh, 0x0F);
                    let nib2_sh = ctx.shr_u32_imm(v0_lo, 16);
                    let nib2 = ctx.and_u32_imm(nib2_sh, 0x0F);
                    let nib3_sh = ctx.shr_u32_imm(v0_lo, 24);
                    let nib3 = ctx.and_u32_imm(nib3_sh, 0x0F);

                    let nib0_f = ctx.cvt_f32_u32(nib0);
                    let nib1_f = ctx.cvt_f32_u32(nib1);
                    let nib2_f = ctx.cvt_f32_u32(nib2);
                    let nib3_f = ctx.cvt_f32_u32(nib3);

                    // Weighted dot: sum(nibble * x)
                    let wx_lo = ctx.mul_f32(nib0_f, x0);
                    ctx.fma_f32_inplace(wx_lo, nib1_f, x1);
                    ctx.fma_f32_inplace(wx_lo, nib2_f, x2);
                    ctx.fma_f32_inplace(wx_lo, nib3_f, x3);

                    // Sum of activations for min contribution
                    let sx_lo = ctx.add_f32(x0, x1);
                    let t = ctx.add_f32(x2, x3);
                    ctx.add_f32_inplace(sx_lo, t);

                    // Same for v1_lo (the second u32 of nibbles). v1 sits 16
                    // bytes after v0 within qs; each byte's low nibble maps to
                    // the K element at the same in-chunk position, so v1's low
                    // nibbles cover elements k_off_lo + 16 .. + 20.
                    let v1_k_off = ctx.add_u32(k_off_lo, 16);
                    let v1_byte_off = ctx.mul_wide_u32(v1_k_off, 4);
                    let x_v1_addr = ctx.add_u64(x_m_base, v1_byte_off);
                    let x4 = ctx.ld_global_f32(x_v1_addr);
                    let x5_addr = ctx.add_u64(x_v1_addr, c_4_u64);
                    let x5 = ctx.ld_global_f32(x5_addr);
                    let x6_addr = ctx.add_u64(x_v1_addr, c_8_64);
                    let x6 = ctx.ld_global_f32(x6_addr);
                    let x7_addr = ctx.add_u64(x_v1_addr, c_12_64);
                    let x7 = ctx.ld_global_f32(x7_addr);

                    let v1_nib0 = ctx.and_u32_imm(v1_lo, 0x0F);
                    let v1_nib1_sh = ctx.shr_u32_imm(v1_lo, 8);
                    let v1_nib1 = ctx.and_u32_imm(v1_nib1_sh, 0x0F);
                    let v1_nib2_sh = ctx.shr_u32_imm(v1_lo, 16);
                    let v1_nib2 = ctx.and_u32_imm(v1_nib2_sh, 0x0F);
                    let v1_nib3_sh = ctx.shr_u32_imm(v1_lo, 24);
                    let v1_nib3 = ctx.and_u32_imm(v1_nib3_sh, 0x0F);

                    let v1_nib0_f = ctx.cvt_f32_u32(v1_nib0);
                    let v1_nib1_f = ctx.cvt_f32_u32(v1_nib1);
                    let v1_nib2_f = ctx.cvt_f32_u32(v1_nib2);
                    let v1_nib3_f = ctx.cvt_f32_u32(v1_nib3);

                    let wx_lo2 = ctx.mul_f32(v1_nib0_f, x4);
                    ctx.fma_f32_inplace(wx_lo2, v1_nib1_f, x5);
                    ctx.fma_f32_inplace(wx_lo2, v1_nib2_f, x6);
                    ctx.fma_f32_inplace(wx_lo2, v1_nib3_f, x7);

                    ctx.add_f32_inplace(wx_lo, wx_lo2);

                    let sx_lo2 = ctx.add_f32(x4, x5);
                    let t = ctx.add_f32(x6, x7);
                    ctx.add_f32_inplace(sx_lo2, t);
                    ctx.add_f32_inplace(sx_lo, sx_lo2);

                    // Accumulate: d * sc0 * wx_lo - dmin * mn0 * sx_lo
                    let sc0_f = ctx.cvt_f32_u32(sc0);
                    let mn0_f = ctx.cvt_f32_u32(mn0);
                    let d_sc0 = ctx.mul_f32(d, sc0_f);
                    let dmin_mn0 = ctx.mul_f32(dmin, mn0_f);
                    let contrib_lo = ctx.mul_f32(d_sc0, wx_lo);
                    let neg_min_lo = ctx.mul_f32(dmin_mn0, sx_lo);
                    let result_lo = ctx.sub_f32(contrib_lo, neg_min_lo);
                    ctx.add_f32_inplace(accs[mi as usize], result_lo);

                    // ===== High nibbles (second half of super-block) =====
                    let hi_byte_off = ctx.mul_wide_u32(k_off_hi, 4);
                    let x_hi_addr = ctx.add_u64(x_m_base, hi_byte_off);
                    let xh0 = ctx.ld_global_f32(x_hi_addr);
                    let xh1_addr = ctx.add_u64(x_hi_addr, c_4_u64);
                    let xh1 = ctx.ld_global_f32(xh1_addr);
                    let xh2_addr = ctx.add_u64(x_hi_addr, c_8_64);
                    let xh2 = ctx.ld_global_f32(xh2_addr);
                    let xh3_addr = ctx.add_u64(x_hi_addr, c_12_64);
                    let xh3 = ctx.ld_global_f32(xh3_addr);

                    let hnib0 = ctx.and_u32_imm(v0_hi, 0x0F);
                    let hnib1_sh = ctx.shr_u32_imm(v0_hi, 8);
                    let hnib1 = ctx.and_u32_imm(hnib1_sh, 0x0F);
                    let hnib2_sh = ctx.shr_u32_imm(v0_hi, 16);
                    let hnib2 = ctx.and_u32_imm(hnib2_sh, 0x0F);
                    let hnib3_sh = ctx.shr_u32_imm(v0_hi, 24);
                    let hnib3 = ctx.and_u32_imm(hnib3_sh, 0x0F);

                    let hnib0_f = ctx.cvt_f32_u32(hnib0);
                    let hnib1_f = ctx.cvt_f32_u32(hnib1);
                    let hnib2_f = ctx.cvt_f32_u32(hnib2);
                    let hnib3_f = ctx.cvt_f32_u32(hnib3);

                    let wx_hi = ctx.mul_f32(hnib0_f, xh0);
                    ctx.fma_f32_inplace(wx_hi, hnib1_f, xh1);
                    ctx.fma_f32_inplace(wx_hi, hnib2_f, xh2);
                    ctx.fma_f32_inplace(wx_hi, hnib3_f, xh3);

                    let sx_hi = ctx.add_f32(xh0, xh1);
                    let t = ctx.add_f32(xh2, xh3);
                    ctx.add_f32_inplace(sx_hi, t);

                    // Also v1_hi
                    let v1_hi_k_off = ctx.add_u32(k_off_hi, 16);
                    let v1_hi_byte_off = ctx.mul_wide_u32(v1_hi_k_off, 4);
                    let x_v1hi_addr = ctx.add_u64(x_m_base, v1_hi_byte_off);
                    let xh4 = ctx.ld_global_f32(x_v1hi_addr);
                    let xh5_addr = ctx.add_u64(x_v1hi_addr, c_4_u64);
                    let xh5 = ctx.ld_global_f32(xh5_addr);
                    let xh6_addr = ctx.add_u64(x_v1hi_addr, c_8_64);
                    let xh6 = ctx.ld_global_f32(xh6_addr);
                    let xh7_addr = ctx.add_u64(x_v1hi_addr, c_12_64);
                    let xh7 = ctx.ld_global_f32(xh7_addr);

                    let v1_hnib0 = ctx.and_u32_imm(v1_hi, 0x0F);
                    let v1_hnib1_sh = ctx.shr_u32_imm(v1_hi, 8);
                    let v1_hnib1 = ctx.and_u32_imm(v1_hnib1_sh, 0x0F);
                    let v1_hnib2_sh = ctx.shr_u32_imm(v1_hi, 16);
                    let v1_hnib2 = ctx.and_u32_imm(v1_hnib2_sh, 0x0F);
                    let v1_hnib3_sh = ctx.shr_u32_imm(v1_hi, 24);
                    let v1_hnib3 = ctx.and_u32_imm(v1_hnib3_sh, 0x0F);

                    let v1_hnib0_f = ctx.cvt_f32_u32(v1_hnib0);
                    let v1_hnib1_f = ctx.cvt_f32_u32(v1_hnib1);
                    let v1_hnib2_f = ctx.cvt_f32_u32(v1_hnib2);
                    let v1_hnib3_f = ctx.cvt_f32_u32(v1_hnib3);

                    let wx_hi2 = ctx.mul_f32(v1_hnib0_f, xh4);
                    ctx.fma_f32_inplace(wx_hi2, v1_hnib1_f, xh5);
                    ctx.fma_f32_inplace(wx_hi2, v1_hnib2_f, xh6);
                    ctx.fma_f32_inplace(wx_hi2, v1_hnib3_f, xh7);

                    ctx.add_f32_inplace(wx_hi, wx_hi2);

                    let sx_hi2 = ctx.add_f32(xh4, xh5);
                    let t = ctx.add_f32(xh6, xh7);
                    ctx.add_f32_inplace(sx_hi2, t);
                    ctx.add_f32_inplace(sx_hi, sx_hi2);

                    let sc1_f = ctx.cvt_f32_u32(sc1);
                    let mn1_f = ctx.cvt_f32_u32(mn1);
                    let d_sc1 = ctx.mul_f32(d, sc1_f);
                    let dmin_mn1 = ctx.mul_f32(dmin, mn1_f);
                    let contrib_hi = ctx.mul_f32(d_sc1, wx_hi);
                    let neg_min_hi = ctx.mul_f32(dmin_mn1, sx_hi);
                    let result_hi = ctx.sub_f32(contrib_hi, neg_min_hi);
                    ctx.add_f32_inplace(accs[mi as usize], result_hi);
                }

                // Stride
                ctx.add_u32_reg_inplace(sb_idx, num_hw);
                ctx.branch("fused_sb_loop");

                ctx.label("fused_sb_end");

                // ===== Half-warp reduction =====
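                // shfl.down offsets 8/4/2/1 tree-reduce the 16 lanes of each
                // half-warp into its lane 0. Upper lanes read across the
                // half-warp boundary, but only lane 0's result is consumed.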
                for acc in &accs {
                    let t = ctx.shfl_down_f32(*acc, 8, 0xFFFF_FFFF);
                    ctx.add_f32_inplace(*acc, t);
                    let t = ctx.shfl_down_f32(*acc, 4, 0xFFFF_FFFF);
                    ctx.add_f32_inplace(*acc, t);
                    let t = ctx.shfl_down_f32(*acc, 2, 0xFFFF_FFFF);
                    ctx.add_f32_inplace(*acc, t);
                    let t = ctx.shfl_down_f32(*acc, 1, 0xFFFF_FFFF);
                    ctx.add_f32_inplace(*acc, t);
                }

                // Cross-half-warp reduction via shared memory
                let c_zero_u32 = ctx.mov_u32_imm(0);
                let is_lane0 = ctx.setp_eq_u32(half_lane, c_zero_u32);
                let smem_base = ctx.mov_u64_imm(0);
                ctx.branch_if_not(is_lane0, "fused_skip_smem_write");

                let c_m_u32 = ctx.mov_u32_imm(m);
                for (mi, acc) in accs.iter().enumerate() {
                    let mi_reg = ctx.mov_u32_imm(mi as u32);
                    let hw_m = ctx.mul_u32_reg(half_warp_id, c_m_u32);
                    let smem_idx = ctx.add_u32_reg(hw_m, mi_reg);
                    let smem_off = ctx.mul_wide_u32(smem_idx, 4);
                    let smem_addr = ctx.add_u64(smem_base, smem_off);
                    ctx.st_shared_f32(smem_addr, *acc);
                }

                ctx.label("fused_skip_smem_write");
                ctx.bar_sync(0);

                // Only first half-warp reduces and writes output
                let c_zero2 = ctx.mov_u32_imm(0);
                let is_first_hw = ctx.setp_eq_u32(half_warp_id, c_zero2);
                ctx.branch_if_not(is_first_hw, "fused_skip_output");

                let c_zero3 = ctx.mov_u32_imm(0);
                let is_lane0_first = ctx.setp_eq_u32(half_lane, c_zero3);
                ctx.branch_if_not(is_lane0_first, "fused_skip_output");

                let c_m_red = ctx.mov_u32_imm(m);
                for mi in 0..m {
                    let total = ctx.mov_f32_imm(0.0);
                    for hw in 0..num_half_warps {
                        let mi_reg = ctx.mov_u32_imm(mi);
                        let hw_reg = ctx.mov_u32_imm(hw);
                        let hw_m = ctx.mul_u32_reg(hw_reg, c_m_red);
                        let idx = ctx.add_u32_reg(hw_m, mi_reg);
                        let off = ctx.mul_wide_u32(idx, 4);
                        let addr = ctx.add_u64(smem_base, off);
                        let val = ctx.ld_shared_f32(addr);
                        ctx.add_f32_inplace(total, val);
                    }

                    let mi_reg = ctx.mov_u32_imm(mi);
                    let row_m = ctx.mul_u32_reg(row_idx, c_m_red);
                    let out_idx = ctx.add_u32_reg(row_m, mi_reg);
                    let out_off = ctx.mul_wide_u32(out_idx, 4);
                    let out_addr = ctx.add_u64(y_ptr, out_off);
                    ctx.st_global_f32(out_addr, total);
                }

                ctx.label("fused_skip_output");
                ctx.bar_sync(1);

                // Grid stride
                ctx.add_u32_reg_inplace(row_idx, grid_dim);
                ctx.branch("fused_row_loop");

                ctx.label("fused_exit");
                ctx.ret();
            })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fused_kernel_builds_ptx() {
        let kernel = FusedFp32Q4KGemvKernel::new(1536, 1536, 4);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains("fused_fp32_q4k_gemv"));
        assert!(ptx.contains(".entry"));
        // Verify pure ASCII (PMAT-286 lesson)
        assert!(ptx.is_ascii(), "PTX must be pure ASCII");
    }

    #[test]
    fn test_fused_kernel_m1() {
        let kernel = FusedFp32Q4KGemvKernel::new(1536, 1536, 1);
        let ptx = kernel.emit_ptx();
        assert!(ptx.is_ascii());
        assert!(ptx.contains(".entry"));
    }
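
    /// Scalar sanity check for the factored accumulation used in the inner
    /// loop: d*sc * sum(nib_j*x_j) - dmin*mn * sum(x_j) must equal the
    /// per-element form sum((d*sc*nib_j - dmin*mn) * x_j). Inputs here are
    /// arbitrary illustrative values, not a real Q4K super-block.
    #[test]
    fn test_factored_accumulation_matches_per_element() {
        let (d, dmin, sc, mn) = (0.5f32, 0.25f32, 3.0f32, 2.0f32);
        let nibs = [1.0f32, 7.0, 0.0, 15.0];
        let xs = [0.5f32, -1.0, 2.0, 0.25];

        // Per-element form, as a plain dequant-then-dot loop would compute it.
        let per_element: f32 = nibs
            .iter()
            .zip(&xs)
            .map(|(n, x)| (d * sc * n - dmin * mn) * x)
            .sum();

        // Factored form, as the kernel computes it (wx = weighted dot,
        // sx = activation sum).
        let wx: f32 = nibs.iter().zip(&xs).map(|(n, x)| n * x).sum();
        let sx: f32 = xs.iter().sum();
        let factored = d * sc * wx - dmin * mn * sx;

        assert!((per_element - factored).abs() < 1e-5);
    }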
}