trueno-gpu 0.4.11

Pure Rust PTX generation for NVIDIA CUDA — no LLVM and no nvcc required
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
//! Tiled Q4_K GEMV Kernels with Shared Memory Input Caching
//!
//! - `TiledQ4KGemvKernel`: Input vector cached in shared memory
//! - `ChunkedTiledQ4KGemvKernel`: Handles K > 8K with fixed 32KB chunks

use crate::kernels::quantize::{Kernel, Q4K_SUPER_BLOCK_BYTES, Q4K_SUPER_BLOCK_SIZE};
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory};
use crate::ptx::{PtxKernel, PtxReg, PtxType};

/// Tiled Q4_K GEMV kernel with shared memory input caching
///
/// Addresses inefficiency in `Q4KGemvKernel`:
/// - Original: Each warp loads entire input vector from global memory
/// - Tiled: Input vector cached in shared memory, shared by multiple outputs
///
/// Shared memory use is `k * 4` bytes, so `k` must stay within the device's
/// shared memory limit; see `ChunkedTiledQ4KGemvKernel` for larger K.
#[derive(Debug, Clone)]
pub struct TiledQ4KGemvKernel {
    /// K dimension (input dimension, must be multiple of 256)
    pub k: u32,
    /// N dimension (output dimension)
    pub n: u32,
    /// Number of outputs per block (default: 4)
    pub outputs_per_block: u32,
}

impl TiledQ4KGemvKernel {
    /// Create a new tiled Q4_K GEMV kernel for a `k x n` weight matrix.
    ///
    /// Defaults to 4 outputs per block (128 threads = 4 warps); override with
    /// [`Self::with_outputs_per_block`]. `const` for consistency with the
    /// other constructor-style methods on this type.
    #[must_use]
    pub const fn new(k: u32, n: u32) -> Self {
        Self {
            k,
            n,
            outputs_per_block: 4, // Default: 4 outputs per block (128 threads = 4 warps)
        }
    }

    /// Set number of outputs computed per block (builder style, consumes `self`).
    #[must_use]
    pub const fn with_outputs_per_block(mut self, outputs_per_block: u32) -> Self {
        self.outputs_per_block = outputs_per_block;
        self
    }
}

impl Kernel for TiledQ4KGemvKernel {
    /// PTX entry-point name for this kernel.
    fn name(&self) -> &str {
        "tiled_q4k_gemv"
    }

    /// Emit the PTX for the tiled GEMV.
    ///
    /// Q4_K super-block layout as read below: bytes 0..2 = `d` (f16),
    /// bytes 2..4 = `dmin` (f16), bytes 4..16 = packed scale/min bytes,
    /// bytes 16.. = 4-bit quantized values (`Q4K_SUPER_BLOCK_BYTES` total
    /// per `Q4K_SUPER_BLOCK_SIZE` = 256 elements).
    fn build_ptx(&self) -> PtxKernel {
        let k = self.k;
        let outputs_per_block = self.outputs_per_block;

        // Shared memory for input vector: K floats (4 bytes each)
        let smem_size = (k * 4) as usize;

        PtxKernel::new("tiled_q4k_gemv")
            .param(PtxType::U64, "y_ptr")     // Output vector (N)
            .param(PtxType::U64, "w_ptr")     // Q4_K weights (N × K/256 super-blocks)
            .param(PtxType::U64, "x_ptr")     // Input vector (K)
            .param(PtxType::U32, "k_dim")     // K dimension
            .param(PtxType::U32, "n_dim")     // N dimension
            .shared_memory(smem_size)
            .build(move |ctx| {
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let thread_id = ctx.special_reg(PtxReg::TidX);

                // Load kernel parameters into registers
                let n_dim = ctx.load_param_u32("n_dim");
                let k_dim = ctx.load_param_u32("k_dim");
                let y_ptr = ctx.load_param_u64("y_ptr");
                let w_ptr = ctx.load_param_u64("w_ptr");
                let x_ptr = ctx.load_param_u64("x_ptr");

                // Constants reused throughout the body
                let four = ctx.mov_u32_imm(4);
                let outputs_per_block_reg = ctx.mov_u32_imm(outputs_per_block);

                // Get shared memory base address (FIX: needed for correct addressing)
                let smem_base = ctx.shared_base_addr();

                // ================================================================
                // PHASE 1: Cooperatively load input vector into shared memory
                // ================================================================
                // All threads of the block stride through x[0..k_dim] together.
                let idx = ctx.mov_u32_imm(0);

                ctx.label("load_loop");
                let loop_idx = ctx.add_u32_reg(idx, thread_id);
                let in_bounds = ctx.setp_lt_u32(loop_idx, k_dim);
                ctx.branch_if_not(in_bounds, "load_loop_end");

                // Load x[loop_idx] from global memory
                let elem_offset = ctx.mul_wide_u32_reg(loop_idx, four);
                let x_addr = ctx.add_u64(x_ptr, elem_offset);
                let x_val = ctx.ld_global_f32(x_addr);

                // Store to shared memory using generic addressing
                // smem_base is a generic address from cvta.shared, so use generic st/ld
                let smem_addr = ctx.add_u64(smem_base, elem_offset);
                ctx.st_generic_f32(smem_addr, x_val);

                // Stride by block size: the block launches 32 * outputs_per_block threads
                ctx.add_u32_inplace(idx, 32 * outputs_per_block); // stride by block size
                ctx.branch("load_loop");

                ctx.label("load_loop_end");

                // Synchronize: ensure input is fully loaded before any warp reads it
                ctx.bar_sync(0);

                // ================================================================
                // PHASE 2: Compute multiple outputs using cached input
                // ================================================================
                // Each warp computes one output element. Only the first
                // `outputs_per_block` warps do work (default 4); warps beyond
                // that, or mapped past n_dim, exit via the bounds checks below.
                let warp_id = ctx.div_u32(thread_id, 32);
                let lane_id = ctx.rem_u32(thread_id, 32);

                // Calculate which output this warp is computing
                let base_output = ctx.mul_u32_reg(block_id, outputs_per_block_reg);
                let output_idx = ctx.add_u32_reg(base_output, warp_id);

                // Check if this warp has work to do
                let warp_oob = ctx.setp_ge_u32(output_idx, n_dim);
                ctx.branch_if(warp_oob, "exit");

                // Also check if warp_id < outputs_per_block
                let warp_beyond_block = ctx.setp_ge_u32(warp_id, outputs_per_block_reg);
                ctx.branch_if(warp_beyond_block, "exit");

                // Initialize per-warp accumulator (one partial sum per lane)
                let acc = ctx.mov_f32_imm(0.0);

                // Calculate number of super-blocks: ceil(K / 256) for GGUF
                let k_rounded = ctx.add_u32(k_dim, Q4K_SUPER_BLOCK_SIZE - 1);
                let num_super_blocks = ctx.div_u32(k_rounded, Q4K_SUPER_BLOCK_SIZE);

                // Calculate base address for this row's weights
                let sb_bytes = ctx.mov_u32_imm(Q4K_SUPER_BLOCK_BYTES);
                let row_bytes = ctx.mul_u32_reg(num_super_blocks, sb_bytes);
                let row_offset = ctx.mul_wide_u32_reg(output_idx, row_bytes);
                let row_base = ctx.add_u64(w_ptr, row_offset);

                // Super-block loop: one iteration per 256 input elements
                let sb_idx = ctx.mov_u32_imm(0);

                ctx.label("sb_loop");
                let sb_done = ctx.setp_ge_u32(sb_idx, num_super_blocks);
                ctx.branch_if(sb_done, "sb_loop_end");

                let sb_offset = ctx.mul_wide_u32(sb_idx, Q4K_SUPER_BLOCK_BYTES);
                let sb_addr = ctx.add_u64(row_base, sb_offset);

                // Load d (byte 0) and dmin (byte 2), both f16, widen to f32
                let d_f16 = ctx.ld_global_f16(sb_addr);
                let d = ctx.cvt_f32_f16(d_f16);
                let two = ctx.mov_u64_imm(2);
                let dmin_addr = ctx.add_u64(sb_addr, two);
                let dmin_f16 = ctx.ld_global_f16(dmin_addr);
                let dmin = ctx.cvt_f32_f16(dmin_f16);

                // Each thread in warp processes 8 elements (256 per super-block / 32 threads)
                let thread_partial = ctx.mov_f32_imm(0.0);

                // Rust-side unroll: each iteration emits PTX for one lane-strided element
                for offset in [0u32, 32, 64, 96, 128, 160, 192, 224] {
                    let offset_reg = ctx.mov_u32_imm(offset);
                    let val_idx = ctx.add_u32_reg(lane_id, offset_reg);

                    // Determine sub-block (0-7); each sub-block covers 32 values
                    let sub_block = ctx.div_u32(val_idx, 32);

                    // Load scale bytes (simplified - could be optimized further)
                    let four_64 = ctx.mov_u64_imm(4);
                    let scales_base = ctx.add_u64(sb_addr, four_64);

                    // Simple scale/min extraction for sub-blocks 0-3:
                    //   scale = scales[j] & 0x3F, min = scales[j + 4] & 0x3F
                    let sub_block_lt_4 = ctx.mov_u32_imm(4);
                    let is_simple = ctx.setp_lt_u32(sub_block, sub_block_lt_4);

                    let sub_block_64 = ctx.cvt_u64_u32(sub_block);
                    let scale_byte_addr = ctx.add_u64(scales_base, sub_block_64);
                    let scale_byte = ctx.ld_global_u8(scale_byte_addr);
                    let scale_byte_32 = ctx.cvt_u32_u8(scale_byte);

                    let four_reg = ctx.mov_u32_imm(4);
                    let sub_block_plus_4 = ctx.add_u32_reg(sub_block, four_reg);
                    let sub_block_plus_4_64 = ctx.cvt_u64_u32(sub_block_plus_4);
                    let min_byte_addr = ctx.add_u64(scales_base, sub_block_plus_4_64);
                    let min_byte = ctx.ld_global_u8(min_byte_addr);
                    let min_byte_32 = ctx.cvt_u32_u8(min_byte);

                    let mask_6bit = ctx.mov_u32_imm(0x3F);
                    let mask_4bit = ctx.mov_u32_imm(0x0F);
                    let six = ctx.mov_u32_imm(6);

                    let scale_simple = ctx.and_u32(scale_byte_32, mask_6bit);
                    let min_simple = ctx.and_u32(min_byte_32, mask_6bit);

                    // Complex path for blocks 4-7
                    // CORRECTNESS-001: Fixed scale/min extraction per GGML Q4_K spec
                    // CPU reference (extract_scale_min at realizar/quantize.rs:6589):
                    //   scale = (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4)
                    //   min   = (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4)
                    let eight_64 = ctx.mov_u64_imm(8);
                    let scales_8_base = ctx.add_u64(scales_base, eight_64);
                    // Safe subtraction: for sub_block < 4, use 0 to avoid underflow
                    // (the loaded value won't be used anyway due to selp)
                    let sub_block_minus_4_raw = ctx.sub_u32_reg(sub_block, four_reg);
                    let zero_safe = ctx.mov_u32_imm(0);
                    let sub_block_minus_4 = ctx.selp_u32(is_simple, zero_safe, sub_block_minus_4_raw);
                    let sub_block_minus_4_64 = ctx.cvt_u64_u32(sub_block_minus_4);
                    let scales_8_addr = ctx.add_u64(scales_8_base, sub_block_minus_4_64);
                    let s8_byte = ctx.ld_global_u8(scales_8_addr);
                    let s8_byte_32 = ctx.cvt_u32_u8(s8_byte);

                    // Load scales[sub_block - 4] for scale high bits (not scales[sub_block]!)
                    let scale_hi_src_addr = ctx.add_u64(scales_base, sub_block_minus_4_64);
                    let scale_hi_src_byte = ctx.ld_global_u8(scale_hi_src_addr);
                    let scale_hi_src_32 = ctx.cvt_u32_u8(scale_hi_src_byte);

                    // scale = (scales[sub_block + 4] & 0x0F) | ((scales[sub_block - 4] >> 6) << 4)
                    let s8_lo = ctx.and_u32(s8_byte_32, mask_4bit);
                    let s0_hi = ctx.shr_u32(scale_hi_src_32, six);
                    let s0_hi_shifted = ctx.shl_u32(s0_hi, four_reg);
                    let scale_complex = ctx.or_u32(s8_lo, s0_hi_shifted);

                    // min = (scales[sub_block + 4] >> 4) | ((scales[sub_block] >> 6) << 4)
                    // Note: use scale_byte_32 (scales[sub_block]) NOT min_byte_32 (scales[sub_block + 4])
                    let s8_hi = ctx.shr_u32(s8_byte_32, four_reg);
                    let s4_hi = ctx.shr_u32(scale_byte_32, six);
                    let s4_hi_shifted = ctx.shl_u32(s4_hi, four_reg);
                    let min_complex = ctx.or_u32(s8_hi, s4_hi_shifted);

                    // Select simple (sub-blocks 0-3) or complex (4-7) result
                    let scale = ctx.selp_u32(is_simple, scale_simple, scale_complex);
                    let min = ctx.selp_u32(is_simple, min_simple, min_complex);

                    let scale_f = ctx.cvt_f32_u32(scale);
                    let min_f = ctx.cvt_f32_u32(min);
                    let ds = ctx.mul_f32(d, scale_f);
                    let dm = ctx.mul_f32(dmin, min_f);

                    // Load quantized value: qs region starts at byte 16 of the super-block
                    let sixteen_64 = ctx.mov_u64_imm(16);
                    let qs_base = ctx.add_u64(sb_addr, sixteen_64);

                    // 32 packed bytes per 64-value chunk; low nibble holds values
                    // 0-31 of the chunk, high nibble holds values 32-63
                    let chunk_idx = ctx.div_u32(val_idx, 64);
                    let val_in_chunk = ctx.rem_u32(val_idx, 64);
                    let byte_in_chunk = ctx.rem_u32(val_in_chunk, 32);

                    let chunk_offset = ctx.mul_u32(chunk_idx, 32);
                    let qs_byte_offset = ctx.add_u32_reg(chunk_offset, byte_in_chunk);
                    let qs_byte_offset_64 = ctx.cvt_u64_u32(qs_byte_offset);
                    let qs_addr = ctx.add_u64(qs_base, qs_byte_offset_64);
                    let packed = ctx.ld_global_u8(qs_addr);
                    let packed_32 = ctx.cvt_u32_u8(packed);

                    // shift = 0 (low nibble) or 4 (high nibble)
                    let val_in_chunk_div_32 = ctx.div_u32(val_in_chunk, 32);
                    let shift_amount = ctx.mul_u32_reg(val_in_chunk_div_32, four_reg);
                    let shifted = ctx.shr_u32(packed_32, shift_amount);
                    let quant = ctx.and_u32(shifted, mask_4bit);

                    // Dequantize: value = d * scale * q - dmin * min
                    let quant_f32 = ctx.cvt_f32_u32(quant);
                    let scaled = ctx.mul_f32(ds, quant_f32);
                    let dequant = ctx.sub_f32(scaled, dm);

                    // Load activation from SHARED MEMORY (the key optimization!)
                    // Using generic addressing (smem_base from cvta.shared)
                    let sb_k_base = ctx.mul_u32(sb_idx, Q4K_SUPER_BLOCK_SIZE);
                    let x_idx = ctx.add_u32_reg(sb_k_base, val_idx);
                    let x_smem_offset = ctx.mul_wide_u32_reg(x_idx, four);
                    let x_smem_addr = ctx.add_u64(smem_base, x_smem_offset);
                    let x_cached = ctx.ld_generic_f32(x_smem_addr);

                    // thread_partial += x_cached * dequant
                    ctx.fma_f32_inplace(thread_partial, x_cached, dequant);
                }

                // Fold this super-block's partial into the running accumulator
                ctx.add_f32_inplace(acc, thread_partial);
                ctx.add_u32_inplace(sb_idx, 1);
                ctx.branch("sb_loop");

                ctx.label("sb_loop_end");

                // Warp shuffle reduction: after the 16/8/4/2/1 cascade,
                // lane 0 holds the sum over all 32 lanes
                let shfl16 = ctx.shfl_down_f32(acc, 16, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl16);
                let shfl8 = ctx.shfl_down_f32(acc, 8, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl8);
                let shfl4 = ctx.shfl_down_f32(acc, 4, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl4);
                let shfl2 = ctx.shfl_down_f32(acc, 2, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl2);
                let shfl1 = ctx.shfl_down_f32(acc, 1, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl1);

                // Only lane 0 of each warp writes its output element
                let one = ctx.mov_u32_imm(1);
                let is_lane0 = ctx.setp_lt_u32(lane_id, one);
                ctx.branch_if_not(is_lane0, "exit");

                // Store y[output_idx]
                let y_offset = ctx.mul_wide_u32_reg(output_idx, four);
                let y_addr = ctx.add_u64(y_ptr, y_offset);
                ctx.st_global_f32(y_addr, acc);

                ctx.label("exit");
                ctx.ret();
            })
    }
}

// =============================================================================
// PAR-056: CHUNKED TILED Q4K GEMV FOR LARGE K DIMENSIONS
// =============================================================================

/// Chunked Tiled Q4_K GEMV kernel for large input dimensions.
///
/// This kernel extends `TiledQ4KGemvKernel` to handle K dimensions that exceed
/// CUDA shared memory limits (48KB default, 96KB max). It processes the input
/// vector in chunks that fit within shared memory.
///
/// # Problem Solved
///
/// The original `TiledQ4KGemvKernel` allocates K × 4 bytes of shared memory:
/// - 7B FFN down (K=18944): 75KB needed > 48KB default
/// - 32B FFN down (K=27648): 107KB needed > 96KB max
///
/// This kernel uses a fixed 8K element (32KB) chunk size, safe for all GPUs.
///
/// # Algorithm
///
/// 1. For each chunk of 8K elements:
///    a. Cooperatively load chunk into shared memory
///    b. Process super-blocks that use elements from this chunk
///    c. Accumulate partial dot products
/// 2. Final warp reduction and global memory store
///
/// # Performance
///
/// - Memory reads: K × 4 bytes (same as TiledQ4KGemvKernel)
/// - Shared memory: 32KB fixed (vs K × 4 which can exceed limits)
/// - Extra overhead: One barrier per chunk (negligible for large K)
///
/// # Grid Configuration
///
/// - Block: 256 threads (8 warps)
/// - Grid: ceil(N / outputs_per_block) blocks
/// - Shared memory: 32KB fixed (8K floats)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChunkedTiledQ4KGemvKernel {
    /// K dimension (input dimension, must be multiple of 256)
    pub k: u32,
    /// N dimension (output dimension)
    pub n: u32,
    /// Number of outputs per block (default: 4)
    pub outputs_per_block: u32,
}

/// Chunk size in elements (8K floats = 32KB, safe for 48KB limit)
const CHUNK_SIZE: u32 = 8192;
/// Chunk size in bytes (4 bytes per f32 element)
const CHUNK_BYTES: u32 = CHUNK_SIZE * 4;

impl ChunkedTiledQ4KGemvKernel {
    /// Create a new chunked tiled Q4_K GEMV kernel for a `k x n` weight matrix.
    ///
    /// Defaults to 4 outputs per block; override with
    /// [`Self::with_outputs_per_block`]. `const` for consistency with the
    /// other `const fn`s in this impl.
    #[must_use]
    pub const fn new(k: u32, n: u32) -> Self {
        Self {
            k,
            n,
            outputs_per_block: 4,
        }
    }

    /// Set number of outputs computed per block (builder style, consumes `self`).
    #[must_use]
    pub const fn with_outputs_per_block(mut self, outputs_per_block: u32) -> Self {
        self.outputs_per_block = outputs_per_block;
        self
    }

    /// Check if chunking is needed, i.e. the input vector (K elements)
    /// exceeds one `CHUNK_SIZE` (8K-element / 32KB) shared-memory chunk.
    #[must_use]
    pub const fn needs_chunking(&self) -> bool {
        self.k > CHUNK_SIZE
    }
}

impl Kernel for ChunkedTiledQ4KGemvKernel {
    /// Entry-point name emitted into the PTX module.
    fn name(&self) -> &str {
        "chunked_tiled_q4k_gemv"
    }

    /// Emit PTX computing `y = W · x` where `W` is a Q4_K-quantized matrix
    /// (`w_ptr`), `x` is an f32 activation vector (`x_ptr`), and `y` is the
    /// f32 output (`y_ptr`). `k_dim`/`n_dim` are runtime dimensions, so one
    /// compiled kernel serves multiple matrix sizes.
    fn build_ptx(&self) -> PtxKernel {
        // K/N are read from the `k_dim`/`n_dim` runtime parameters inside the
        // kernel body; the compile-time copy of k is intentionally unused.
        let _k = self.k;
        let outputs_per_block = self.outputs_per_block;

        // Fixed 32KB shared memory (8K floats)
        let smem_size = CHUNK_BYTES as usize;

        PtxKernel::new("chunked_tiled_q4k_gemv")
            .param(PtxType::U64, "y_ptr")
            .param(PtxType::U64, "w_ptr")
            .param(PtxType::U64, "x_ptr")
            .param(PtxType::U32, "k_dim")
            .param(PtxType::U32, "n_dim")
            .shared_memory(smem_size)
            .build(move |ctx| {
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let thread_id = ctx.special_reg(PtxReg::TidX);

                // Load parameters
                let n_dim = ctx.load_param_u32("n_dim");
                let k_dim = ctx.load_param_u32("k_dim");
                let y_ptr = ctx.load_param_u64("y_ptr");
                let w_ptr = ctx.load_param_u64("w_ptr");
                let x_ptr = ctx.load_param_u64("x_ptr");

                // Constants: 4 = sizeof(f32), reused for byte-offset math
                let four = ctx.mov_u32_imm(4);
                let outputs_per_block_reg = ctx.mov_u32_imm(outputs_per_block);

                // Calculate warp and lane IDs (warp size 32)
                let warp_id = ctx.div_u32(thread_id, 32);
                let lane_id = ctx.rem_u32(thread_id, 32);

                // Calculate which output this warp is computing:
                // one warp per output row
                let base_output = ctx.mul_u32_reg(block_id, outputs_per_block_reg);
                let output_idx = ctx.add_u32_reg(base_output, warp_id);

                // Check bounds
                // NOTE(review): warps that branch to "exit" here never reach the
                // bar.sync calls below. That is only safe if the barrier counts
                // exited threads as arrived (or if launches never produce partial
                // blocks) — divergent exit before a barrier is a classic hang
                // hazard; confirm against the targeted PTX bar.sync semantics.
                let warp_oob = ctx.setp_ge_u32(output_idx, n_dim);
                ctx.branch_if(warp_oob, "exit");
                let warp_beyond_block = ctx.setp_ge_u32(warp_id, outputs_per_block_reg);
                ctx.branch_if(warp_beyond_block, "exit");

                // Initialize global accumulator (per-thread partial dot product,
                // reduced across the warp in PHASE 3)
                let global_acc = ctx.mov_f32_imm(0.0);

                // Calculate number of super-blocks: ceil(K / 256) for GGUF
                let k_rounded = ctx.add_u32(k_dim, Q4K_SUPER_BLOCK_SIZE - 1);
                let num_super_blocks = ctx.div_u32(k_rounded, Q4K_SUPER_BLOCK_SIZE);
                let sb_bytes = ctx.mov_u32_imm(Q4K_SUPER_BLOCK_BYTES);
                let row_bytes = ctx.mul_u32_reg(num_super_blocks, sb_bytes);
                // Wide multiply: row byte offset can exceed u32 for large weights
                let row_offset = ctx.mul_wide_u32_reg(output_idx, row_bytes);
                let row_base = ctx.add_u64(w_ptr, row_offset);

                // Calculate number of chunks using bit operations (CHUNK_SIZE = 8192 = 2^13)
                // num_chunks = k_dim >> 13
                let num_chunks = ctx.shr_u32_imm(k_dim, 13);
                // k_remainder = k_dim & 0x1FFF (8191 = 0x1FFF)
                let remainder_mask = ctx.mov_u32_imm(0x1FFF);
                let k_remainder = ctx.and_u32(k_dim, remainder_mask);
                // has_remainder = k_remainder >= 1 (equivalent to > 0)
                let one = ctx.mov_u32_imm(1);
                let has_remainder = ctx.setp_ge_u32(k_remainder, one);
                let zero_reg = ctx.mov_u32_imm(0);
                // total_chunks = ceil(k_dim / CHUNK_SIZE), built branch-free via selp
                let extra_chunk = ctx.selp_u32(has_remainder, one, zero_reg);
                let total_chunks = ctx.add_u32_reg(num_chunks, extra_chunk);

                // ================================================================
                // OUTER LOOP: Process input in chunks
                // ================================================================
                let chunk_idx = ctx.mov_u32_imm(0);

                ctx.label("chunk_loop");
                let chunk_done = ctx.setp_ge_u32(chunk_idx, total_chunks);
                ctx.branch_if(chunk_done, "chunk_loop_end");

                // Calculate chunk start position: chunk_idx << 13
                let chunk_start = ctx.shl_u32_imm(chunk_idx, 13);

                // Calculate elements in this chunk (may be less for last chunk)
                let chunk_end = ctx.add_u32(chunk_start, CHUNK_SIZE);
                // clamp_to_k = chunk_end > k_dim, i.e., k_dim < chunk_end
                let clamp_to_k = ctx.setp_lt_u32(k_dim, chunk_end);
                let actual_chunk_end = ctx.selp_u32(clamp_to_k, k_dim, chunk_end);
                let chunk_elements = ctx.sub_u32_reg(actual_chunk_end, chunk_start);

                // ================================================================
                // PHASE 1: Cooperatively load chunk into shared memory
                // All surviving threads stride through the chunk together.
                // ================================================================
                let load_idx = ctx.mov_u32_imm(0);

                ctx.label("load_loop");
                let loop_load_idx = ctx.add_u32_reg(load_idx, thread_id);
                let in_bounds = ctx.setp_lt_u32(loop_load_idx, chunk_elements);
                ctx.branch_if_not(in_bounds, "load_loop_end");

                // Global index = chunk_start + loop_load_idx
                let global_idx = ctx.add_u32_reg(chunk_start, loop_load_idx);
                let global_offset = ctx.mul_wide_u32_reg(global_idx, four);
                let x_addr = ctx.add_u64(x_ptr, global_offset);
                let x_val = ctx.ld_global_f32(x_addr);

                // Store to shared memory at local offset
                // FIX: Use u32 offset for .shared state space (smem is < 48KB)
                let smem_offset = ctx.mul_u32_reg(loop_load_idx, four);
                ctx.st_shared_f32(smem_offset, x_val);

                // Advance by the 256-thread block stride (see Grid Configuration)
                ctx.add_u32_inplace(load_idx, 256);
                ctx.branch("load_loop");

                ctx.label("load_loop_end");

                // Barrier: ensure chunk is fully loaded
                ctx.bar_sync(0);

                // ================================================================
                // PHASE 2: Process super-blocks in this chunk's range
                // ================================================================
                // Super-block range: [chunk_start/256, chunk_end/256)
                // Division by 256 = right shift by 8
                let sb_start = ctx.shr_u32_imm(chunk_start, 8);
                let sb_end_candidate = ctx.shr_u32_imm(actual_chunk_end, 8);
                // Clamp to actual super-block count: if sb_end_candidate > num_super_blocks
                // i.e., num_super_blocks < sb_end_candidate
                let sb_oob = ctx.setp_lt_u32(num_super_blocks, sb_end_candidate);
                let sb_end = ctx.selp_u32(sb_oob, num_super_blocks, sb_end_candidate);

                // Copy sb_start to sb_idx for loop (add 0 = register copy)
                let sb_idx = ctx.add_u32_reg(sb_start, zero_reg);

                ctx.label("sb_loop");
                let sb_done = ctx.setp_ge_u32(sb_idx, sb_end);
                ctx.branch_if(sb_done, "sb_loop_end");

                let sb_offset = ctx.mul_wide_u32(sb_idx, Q4K_SUPER_BLOCK_BYTES);
                let sb_addr = ctx.add_u64(row_base, sb_offset);

                // Load d and dmin: the two f16 scale factors at the head of
                // each Q4_K super-block (bytes 0-1 and 2-3)
                let d_f16 = ctx.ld_global_f16(sb_addr);
                let d = ctx.cvt_f32_f16(d_f16);
                let two = ctx.mov_u64_imm(2);
                let dmin_addr = ctx.add_u64(sb_addr, two);
                let dmin_f16 = ctx.ld_global_f16(dmin_addr);
                let dmin = ctx.cvt_f32_f16(dmin_f16);

                // Each thread in warp processes 8 elements (256 per super-block / 32 threads)
                let thread_partial = ctx.mov_f32_imm(0.0);

                // Unrolled at Rust level: each iteration emits PTX for one of
                // the thread's 8 strided elements (lane_id + offset)
                for offset in [0u32, 32, 64, 96, 128, 160, 192, 224] {
                    let offset_reg = ctx.mov_u32_imm(offset);
                    let val_idx = ctx.add_u32_reg(lane_id, offset_reg);

                    // Determine sub-block (0-7)
                    let sub_block = ctx.div_u32(val_idx, 32);

                    // Load scale bytes (scales array starts at byte 4 of the
                    // super-block, after d/dmin)
                    let four_64 = ctx.mov_u64_imm(4);
                    let scales_base = ctx.add_u64(sb_addr, four_64);

                    // Sub-blocks 0-3 use the "simple" 6-bit encoding;
                    // 4-7 use the packed "complex" encoding below
                    let sub_block_lt_4 = ctx.mov_u32_imm(4);
                    let is_simple = ctx.setp_lt_u32(sub_block, sub_block_lt_4);

                    let sub_block_64 = ctx.cvt_u64_u32(sub_block);
                    let scale_byte_addr = ctx.add_u64(scales_base, sub_block_64);
                    let scale_byte = ctx.ld_global_u8(scale_byte_addr);
                    let scale_byte_32 = ctx.cvt_u32_u8(scale_byte);

                    let four_reg = ctx.mov_u32_imm(4);
                    let sub_block_plus_4 = ctx.add_u32_reg(sub_block, four_reg);
                    let sub_block_plus_4_64 = ctx.cvt_u64_u32(sub_block_plus_4);
                    let min_byte_addr = ctx.add_u64(scales_base, sub_block_plus_4_64);
                    let min_byte = ctx.ld_global_u8(min_byte_addr);
                    let min_byte_32 = ctx.cvt_u32_u8(min_byte);

                    let mask_6bit = ctx.mov_u32_imm(0x3F);
                    let mask_4bit = ctx.mov_u32_imm(0x0F);
                    let six = ctx.mov_u32_imm(6);

                    // Simple path (sub-blocks 0-3): low 6 bits are the value
                    let scale_simple = ctx.and_u32(scale_byte_32, mask_6bit);
                    let min_simple = ctx.and_u32(min_byte_32, mask_6bit);

                    // Complex path for blocks 4-7
                    // CORRECTNESS-001: Fixed scale/min extraction per GGML Q4_K spec
                    // CPU reference (extract_scale_min at realizar/quantize.rs:6589):
                    //   scale = (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4)
                    //   min   = (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4)
                    let eight_64 = ctx.mov_u64_imm(8);
                    let scales_8_base = ctx.add_u64(scales_base, eight_64);
                    // Safe subtraction: for sub_block < 4, use 0 to avoid underflow
                    // (the loaded value won't be used anyway due to selp)
                    let sub_block_minus_4_raw = ctx.sub_u32_reg(sub_block, four_reg);
                    let zero_safe = ctx.mov_u32_imm(0);
                    let sub_block_minus_4 = ctx.selp_u32(is_simple, zero_safe, sub_block_minus_4_raw);
                    let sub_block_minus_4_64 = ctx.cvt_u64_u32(sub_block_minus_4);
                    let scales_8_addr = ctx.add_u64(scales_8_base, sub_block_minus_4_64);
                    let s8_byte = ctx.ld_global_u8(scales_8_addr);
                    let s8_byte_32 = ctx.cvt_u32_u8(s8_byte);

                    // Load scales[sub_block - 4] for scale high bits (not scales[sub_block]!)
                    let scale_hi_src_addr = ctx.add_u64(scales_base, sub_block_minus_4_64);
                    let scale_hi_src_byte = ctx.ld_global_u8(scale_hi_src_addr);
                    let scale_hi_src_32 = ctx.cvt_u32_u8(scale_hi_src_byte);

                    // scale = (scales[sub_block + 4] & 0x0F) | ((scales[sub_block - 4] >> 6) << 4)
                    let s8_lo = ctx.and_u32(s8_byte_32, mask_4bit);
                    let s0_hi = ctx.shr_u32(scale_hi_src_32, six);
                    let s0_hi_shifted = ctx.shl_u32(s0_hi, four_reg);
                    let scale_complex = ctx.or_u32(s8_lo, s0_hi_shifted);

                    // min = (scales[sub_block + 4] >> 4) | ((scales[sub_block] >> 6) << 4)
                    // Note: use scale_byte_32 (scales[sub_block]) NOT min_byte_32 (scales[sub_block + 4])
                    let s8_hi = ctx.shr_u32(s8_byte_32, four_reg);
                    let s4_hi = ctx.shr_u32(scale_byte_32, six);
                    let s4_hi_shifted = ctx.shl_u32(s4_hi, four_reg);
                    let min_complex = ctx.or_u32(s8_hi, s4_hi_shifted);

                    // Branch-free select between the two encodings
                    let scale = ctx.selp_u32(is_simple, scale_simple, scale_complex);
                    let min = ctx.selp_u32(is_simple, min_simple, min_complex);

                    let scale_f = ctx.cvt_f32_u32(scale);
                    let min_f = ctx.cvt_f32_u32(min);
                    let ds = ctx.mul_f32(d, scale_f);
                    let dm = ctx.mul_f32(dmin, min_f);

                    // Load quantized value (qs nibble array starts at byte 16
                    // of the super-block, after d/dmin + 12 scale bytes)
                    let sixteen_64 = ctx.mov_u64_imm(16);
                    let qs_base = ctx.add_u64(sb_addr, sixteen_64);

                    // 64-element groups share 32 packed bytes: low nibble for
                    // the first 32 values, high nibble for the next 32
                    let chunk_idx_inner = ctx.div_u32(val_idx, 64);
                    let val_in_chunk = ctx.rem_u32(val_idx, 64);
                    let byte_in_chunk = ctx.rem_u32(val_in_chunk, 32);

                    let chunk_offset_inner = ctx.mul_u32(chunk_idx_inner, 32);
                    let qs_byte_offset = ctx.add_u32_reg(chunk_offset_inner, byte_in_chunk);
                    let qs_byte_offset_64 = ctx.cvt_u64_u32(qs_byte_offset);
                    let qs_addr = ctx.add_u64(qs_base, qs_byte_offset_64);
                    let packed = ctx.ld_global_u8(qs_addr);
                    let packed_32 = ctx.cvt_u32_u8(packed);

                    // Shift 0 (low nibble) or 4 (high nibble), then mask
                    let val_in_chunk_div_32 = ctx.div_u32(val_in_chunk, 32);
                    let shift_amount = ctx.mul_u32_reg(val_in_chunk_div_32, four_reg);
                    let shifted = ctx.shr_u32(packed_32, shift_amount);
                    let quant = ctx.and_u32(shifted, mask_4bit);

                    // Dequantize: value = d*scale*q - dmin*min
                    let quant_f32 = ctx.cvt_f32_u32(quant);
                    let scaled = ctx.mul_f32(ds, quant_f32);
                    let dequant = ctx.sub_f32(scaled, dm);

                    // Load activation from SHARED MEMORY
                    // Local index = (sb_idx * 256 + val_idx) - chunk_start
                    let sb_k_base = ctx.mul_u32(sb_idx, Q4K_SUPER_BLOCK_SIZE);
                    let global_x_idx = ctx.add_u32_reg(sb_k_base, val_idx);
                    let local_x_idx = ctx.sub_u32_reg(global_x_idx, chunk_start);
                    // FIX: Use u32 offset for .shared state space (smem is < 48KB)
                    let x_smem_offset = ctx.mul_u32_reg(local_x_idx, four);
                    let x_cached = ctx.ld_shared_f32(x_smem_offset);

                    ctx.fma_f32_inplace(thread_partial, x_cached, dequant);
                }

                ctx.add_f32_inplace(global_acc, thread_partial);
                ctx.add_u32_inplace(sb_idx, 1);
                ctx.branch("sb_loop");

                ctx.label("sb_loop_end");

                // Barrier before next chunk load (so no thread overwrites
                // shared memory while another is still reading this chunk)
                ctx.bar_sync(1);

                ctx.add_u32_inplace(chunk_idx, 1);
                ctx.branch("chunk_loop");

                ctx.label("chunk_loop_end");

                // ================================================================
                // PHASE 3: Final warp reduction and store
                // shfl.down tree reduction over the 32-lane warp; after the
                // five steps (16,8,4,2,1), lane 0 holds the warp's full sum
                // ================================================================
                let shfl16 = ctx.shfl_down_f32(global_acc, 16, 0xFFFF_FFFF);
                ctx.add_f32_inplace(global_acc, shfl16);
                let shfl8 = ctx.shfl_down_f32(global_acc, 8, 0xFFFF_FFFF);
                ctx.add_f32_inplace(global_acc, shfl8);
                let shfl4 = ctx.shfl_down_f32(global_acc, 4, 0xFFFF_FFFF);
                ctx.add_f32_inplace(global_acc, shfl4);
                let shfl2 = ctx.shfl_down_f32(global_acc, 2, 0xFFFF_FFFF);
                ctx.add_f32_inplace(global_acc, shfl2);
                let shfl1 = ctx.shfl_down_f32(global_acc, 1, 0xFFFF_FFFF);
                ctx.add_f32_inplace(global_acc, shfl1);

                // Only lane 0 of each warp writes
                // (shadows the earlier `one`; both hold the constant 1)
                let one = ctx.mov_u32_imm(1);
                let is_lane0 = ctx.setp_lt_u32(lane_id, one);
                ctx.branch_if_not(is_lane0, "exit");

                // Store y[output_idx]
                let y_offset = ctx.mul_wide_u32_reg(output_idx, four);
                let y_addr = ctx.add_u64(y_ptr, y_offset);
                ctx.st_global_f32(y_addr, global_acc);

                ctx.label("exit");
                ctx.ret();
            })
    }
}