1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
//! Tiled Q4_K GEMV kernel with shared memory input caching
//!
//! Input vector is cached in shared memory and shared by multiple outputs.
//! Addresses inefficiency in `Q4KGemvKernel` where each warp loads entire
//! input vector from global memory.
use crate::kernels::quantize::{Kernel, Q4K_SUPER_BLOCK_BYTES, Q4K_SUPER_BLOCK_SIZE};
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Tiled Q4_K GEMV kernel with shared memory input caching
///
/// Addresses inefficiency in `Q4KGemvKernel`:
/// - Original: Each warp loads entire input vector from global memory
/// - Tiled: Input vector cached in shared memory, shared by multiple outputs
pub struct TiledQ4KGemvKernel {
    /// K dimension (input dimension, must be multiple of 256 — one Q4_K
    /// super-block covers 256 values)
    pub k: u32,
    /// N dimension (output dimension, i.e. number of weight rows)
    pub n: u32,
    /// Number of output elements computed per thread block; one warp per
    /// output (default: 4)
    pub outputs_per_block: u32,
}
impl TiledQ4KGemvKernel {
    /// Create a new tiled Q4_K GEMV kernel for a `k`-wide input and
    /// `n`-wide output, with the default of 4 outputs per block
    /// (one warp per output, i.e. 128 threads = 4 warps).
    ///
    /// Made `const` for consistency with [`Self::with_outputs_per_block`],
    /// allowing compile-time construction; this is backward-compatible.
    #[must_use]
    pub const fn new(k: u32, n: u32) -> Self {
        Self {
            k,
            n,
            outputs_per_block: 4, // Default: 4 outputs per block (128 threads = 4 warps)
        }
    }

    /// Set number of outputs computed per block (builder-style setter).
    ///
    /// Each output is computed by one warp, so the launch configuration is
    /// expected to use `32 * outputs_per_block` threads per block.
    #[must_use]
    pub const fn with_outputs_per_block(mut self, outputs_per_block: u32) -> Self {
        self.outputs_per_block = outputs_per_block;
        self
    }
}
impl Kernel for TiledQ4KGemvKernel {
    /// Kernel entry-point name as it appears in the generated PTX.
    fn name(&self) -> &str {
        "tiled_q4k_gemv"
    }

    /// Emit the PTX for the tiled GEMV.
    ///
    /// Structure of the generated kernel:
    /// - Phase 1: all threads of the block cooperatively copy the input
    ///   vector `x` (K f32 values) from global memory into shared memory.
    /// - Phase 2: each warp dequantizes one Q4_K weight row and dots it
    ///   with the shared-memory copy of `x`; a warp shuffle reduction then
    ///   folds the 32 partial sums and lane 0 stores `y[output_idx]`.
    fn build_ptx(&self) -> PtxKernel {
        let k = self.k;
        let outputs_per_block = self.outputs_per_block;
        // Shared memory for input vector: K floats, 4 bytes each.
        // NOTE(review): assumes k * 4 fits in the device's per-block shared
        // memory limit (commonly 48 KiB by default) — confirm callers
        // validate `k` before launching.
        let smem_size = (k * 4) as usize;
        PtxKernel::new("tiled_q4k_gemv")
            .param(PtxType::U64, "y_ptr") // Output vector (N)
            .param(PtxType::U64, "w_ptr") // Q4_K weights (N × K/256 super-blocks)
            .param(PtxType::U64, "x_ptr") // Input vector (K)
            .param(PtxType::U32, "k_dim") // K dimension
            .param(PtxType::U32, "n_dim") // N dimension
            .shared_memory(smem_size)
            .build(move |ctx| {
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let thread_id = ctx.special_reg(PtxReg::TidX);
                // Load parameters
                let n_dim = ctx.load_param_u32("n_dim");
                let k_dim = ctx.load_param_u32("k_dim");
                let y_ptr = ctx.load_param_u64("y_ptr");
                let w_ptr = ctx.load_param_u64("w_ptr");
                let x_ptr = ctx.load_param_u64("x_ptr");
                // Constants
                let four = ctx.mov_u32_imm(4);
                let outputs_per_block_reg = ctx.mov_u32_imm(outputs_per_block);
                // ================================================================
                // PHASE 1: Cooperatively load input vector into shared memory
                // ================================================================
                // Each thread copies x[tid], x[tid + blockDim], ... until k_dim.
                let idx = ctx.mov_u32_imm(0);
                ctx.label("load_loop");
                let loop_idx = ctx.add_u32_reg(idx, thread_id);
                let in_bounds = ctx.setp_lt_u32(loop_idx, k_dim);
                ctx.branch_if_not(in_bounds, "load_loop_end");
                // Load x[loop_idx] from global memory
                let elem_offset = ctx.mul_wide_u32_reg(loop_idx, four);
                let x_addr = ctx.add_u64(x_ptr, elem_offset);
                let x_val = ctx.ld_global_f32(x_addr);
                // GH-37 FIX: Use direct .shared addressing (u32 offset) instead of
                // generic addressing (cvta.shared.u64 + ld/st). Direct .shared is
                // more efficient: fewer registers, no cvta instruction, hardware
                // knows the address space.
                let smem_offset = ctx.mul_u32_reg(loop_idx, four);
                ctx.st_shared_f32(smem_offset, x_val);
                // Stride by the block size. NOTE(review): assumes the launch
                // uses exactly 32 * outputs_per_block threads per block —
                // confirm against the host-side launch configuration.
                ctx.add_u32_inplace(idx, 32 * outputs_per_block); // stride by block size
                ctx.branch("load_loop");
                ctx.label("load_loop_end");
                // Synchronize: ensure input is fully loaded before any warp reads it
                ctx.bar_sync(0);
                // ================================================================
                // PHASE 2: Compute multiple outputs using cached input
                // ================================================================
                // Each warp computes one output element; with
                // `outputs_per_block` warps per block, a block computes up to
                // `outputs_per_block` outputs.
                let warp_id = ctx.div_u32(thread_id, 32);
                let lane_id = ctx.rem_u32(thread_id, 32);
                // Calculate which output this warp is computing
                let base_output = ctx.mul_u32_reg(block_id, outputs_per_block_reg);
                let output_idx = ctx.add_u32_reg(base_output, warp_id);
                // Check if this warp has work to do (past end of output vector).
                // The predicate is uniform per warp, so the whole warp exits
                // together and the shuffle reduction below stays full-warp.
                let warp_oob = ctx.setp_ge_u32(output_idx, n_dim);
                ctx.branch_if(warp_oob, "exit");
                // Also check if warp_id < outputs_per_block
                let warp_beyond_block = ctx.setp_ge_u32(warp_id, outputs_per_block_reg);
                ctx.branch_if(warp_beyond_block, "exit");
                // Initialize accumulator for this warp's dot product
                let acc = ctx.mov_f32_imm(0.0);
                // Calculate number of super-blocks: ceil(K / 256) for GGUF
                let k_rounded = ctx.add_u32(k_dim, Q4K_SUPER_BLOCK_SIZE - 1);
                let num_super_blocks = ctx.div_u32(k_rounded, Q4K_SUPER_BLOCK_SIZE);
                // Calculate base address for this row's weights
                let sb_bytes = ctx.mov_u32_imm(Q4K_SUPER_BLOCK_BYTES);
                let row_bytes = ctx.mul_u32_reg(num_super_blocks, sb_bytes);
                let row_offset = ctx.mul_wide_u32_reg(output_idx, row_bytes);
                let row_base = ctx.add_u64(w_ptr, row_offset);
                // Super-block loop: one iteration dequantizes 256 weights
                let sb_idx = ctx.mov_u32_imm(0);
                ctx.label("sb_loop");
                let sb_done = ctx.setp_ge_u32(sb_idx, num_super_blocks);
                ctx.branch_if(sb_done, "sb_loop_end");
                let sb_offset = ctx.mul_wide_u32(sb_idx, Q4K_SUPER_BLOCK_BYTES);
                let sb_addr = ctx.add_u64(row_base, sb_offset);
                // Load d (scale) and dmin (min) — two f16 values at the head
                // of the super-block (bytes 0-1 and 2-3 respectively)
                let d_f16 = ctx.ld_global_f16(sb_addr);
                let d = ctx.cvt_f32_f16(d_f16);
                let two = ctx.mov_u64_imm(2);
                let dmin_addr = ctx.add_u64(sb_addr, two);
                let dmin_f16 = ctx.ld_global_f16(dmin_addr);
                let dmin = ctx.cvt_f32_f16(dmin_f16);
                // Each thread in warp processes 8 elements (256 per super-block / 32 threads),
                // at lane_id + {0, 32, ..., 224}. thread_partial is re-zeroed
                // every super-block iteration and folded into acc at the end.
                let thread_partial = ctx.mov_f32_imm(0.0);
                for offset in [0u32, 32, 64, 96, 128, 160, 192, 224] {
                    let offset_reg = ctx.mov_u32_imm(offset);
                    let val_idx = ctx.add_u32_reg(lane_id, offset_reg);
                    // Determine sub-block (0-7): each sub-block covers 32 values
                    let sub_block = ctx.div_u32(val_idx, 32);
                    // Scales/mins region starts at byte 4 of the super-block
                    // (simplified byte loads — could be optimized further)
                    let four_64 = ctx.mov_u64_imm(4);
                    let scales_base = ctx.add_u64(sb_addr, four_64);
                    // Simple scale/min extraction for sub-blocks 0-3:
                    // low 6 bits of scales[j] and scales[j + 4]
                    let sub_block_lt_4 = ctx.mov_u32_imm(4);
                    let is_simple = ctx.setp_lt_u32(sub_block, sub_block_lt_4);
                    let sub_block_64 = ctx.cvt_u64_u32(sub_block);
                    let scale_byte_addr = ctx.add_u64(scales_base, sub_block_64);
                    let scale_byte = ctx.ld_global_u8(scale_byte_addr);
                    let scale_byte_32 = ctx.cvt_u32_u8(scale_byte);
                    let four_reg = ctx.mov_u32_imm(4);
                    let sub_block_plus_4 = ctx.add_u32_reg(sub_block, four_reg);
                    let sub_block_plus_4_64 = ctx.cvt_u64_u32(sub_block_plus_4);
                    let min_byte_addr = ctx.add_u64(scales_base, sub_block_plus_4_64);
                    let min_byte = ctx.ld_global_u8(min_byte_addr);
                    let min_byte_32 = ctx.cvt_u32_u8(min_byte);
                    let mask_6bit = ctx.mov_u32_imm(0x3F);
                    let mask_4bit = ctx.mov_u32_imm(0x0F);
                    let six = ctx.mov_u32_imm(6);
                    let scale_simple = ctx.and_u32(scale_byte_32, mask_6bit);
                    let min_simple = ctx.and_u32(min_byte_32, mask_6bit);
                    // Complex path for sub-blocks 4-7 (6-bit values split
                    // across the 12-byte scales array).
                    // CORRECTNESS-001: Fixed scale/min extraction per GGML Q4_K spec
                    // CPU reference (extract_scale_min at realizar/quantize.rs:6589):
                    // scale = (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4)
                    // min = (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4)
                    let eight_64 = ctx.mov_u64_imm(8);
                    let scales_8_base = ctx.add_u64(scales_base, eight_64);
                    // Safe subtraction: for sub_block < 4, use 0 to avoid underflow
                    // (the loaded value won't be used anyway due to selp)
                    let sub_block_minus_4_raw = ctx.sub_u32_reg(sub_block, four_reg);
                    let zero_safe = ctx.mov_u32_imm(0);
                    let sub_block_minus_4 =
                        ctx.selp_u32(is_simple, zero_safe, sub_block_minus_4_raw);
                    let sub_block_minus_4_64 = ctx.cvt_u64_u32(sub_block_minus_4);
                    let scales_8_addr = ctx.add_u64(scales_8_base, sub_block_minus_4_64);
                    let s8_byte = ctx.ld_global_u8(scales_8_addr);
                    let s8_byte_32 = ctx.cvt_u32_u8(s8_byte);
                    // Load scales[sub_block - 4] for scale high bits (not scales[sub_block]!)
                    let scale_hi_src_addr = ctx.add_u64(scales_base, sub_block_minus_4_64);
                    let scale_hi_src_byte = ctx.ld_global_u8(scale_hi_src_addr);
                    let scale_hi_src_32 = ctx.cvt_u32_u8(scale_hi_src_byte);
                    // scale = (scales[sub_block + 4] & 0x0F) | ((scales[sub_block - 4] >> 6) << 4)
                    let s8_lo = ctx.and_u32(s8_byte_32, mask_4bit);
                    let s0_hi = ctx.shr_u32(scale_hi_src_32, six);
                    let s0_hi_shifted = ctx.shl_u32(s0_hi, four_reg);
                    let scale_complex = ctx.or_u32(s8_lo, s0_hi_shifted);
                    // min = (scales[sub_block + 4] >> 4) | ((scales[sub_block] >> 6) << 4)
                    // Note: use scale_byte_32 (scales[sub_block]) NOT min_byte_32 (scales[sub_block + 4])
                    let s8_hi = ctx.shr_u32(s8_byte_32, four_reg);
                    let s4_hi = ctx.shr_u32(scale_byte_32, six);
                    let s4_hi_shifted = ctx.shl_u32(s4_hi, four_reg);
                    let min_complex = ctx.or_u32(s8_hi, s4_hi_shifted);
                    // Select simple vs. complex extraction by sub-block index
                    let scale = ctx.selp_u32(is_simple, scale_simple, scale_complex);
                    let min = ctx.selp_u32(is_simple, min_simple, min_complex);
                    let scale_f = ctx.cvt_f32_u32(scale);
                    let min_f = ctx.cvt_f32_u32(min);
                    // Effective per-sub-block scale/min: d*scale, dmin*min
                    let ds = ctx.mul_f32(d, scale_f);
                    let dm = ctx.mul_f32(dmin, min_f);
                    // Load quantized value: qs nibbles start at byte 16.
                    // Each 32-byte chunk packs 64 values — low nibbles hold
                    // values 0-31, high nibbles values 32-63 of the chunk.
                    let sixteen_64 = ctx.mov_u64_imm(16);
                    let qs_base = ctx.add_u64(sb_addr, sixteen_64);
                    let chunk_idx = ctx.div_u32(val_idx, 64);
                    let val_in_chunk = ctx.rem_u32(val_idx, 64);
                    let byte_in_chunk = ctx.rem_u32(val_in_chunk, 32);
                    let chunk_offset = ctx.mul_u32(chunk_idx, 32);
                    let qs_byte_offset = ctx.add_u32_reg(chunk_offset, byte_in_chunk);
                    let qs_byte_offset_64 = ctx.cvt_u64_u32(qs_byte_offset);
                    let qs_addr = ctx.add_u64(qs_base, qs_byte_offset_64);
                    let packed = ctx.ld_global_u8(qs_addr);
                    let packed_32 = ctx.cvt_u32_u8(packed);
                    // Shift 0 for low nibble (values 0-31), 4 for high (32-63)
                    let val_in_chunk_div_32 = ctx.div_u32(val_in_chunk, 32);
                    let shift_amount = ctx.mul_u32_reg(val_in_chunk_div_32, four_reg);
                    let shifted = ctx.shr_u32(packed_32, shift_amount);
                    let quant = ctx.and_u32(shifted, mask_4bit);
                    // Dequantize: w = d*scale*q - dmin*min
                    let quant_f32 = ctx.cvt_f32_u32(quant);
                    let scaled = ctx.mul_f32(ds, quant_f32);
                    let dequant = ctx.sub_f32(scaled, dm);
                    // Load activation from SHARED MEMORY (the key optimization!)
                    // GH-37 FIX: Use direct .shared addressing (u32 offset, ld.shared.f32)
                    let sb_k_base = ctx.mul_u32(sb_idx, Q4K_SUPER_BLOCK_SIZE);
                    let x_idx = ctx.add_u32_reg(sb_k_base, val_idx);
                    let x_smem_offset = ctx.mul_u32_reg(x_idx, four);
                    let x_cached = ctx.ld_shared_f32(x_smem_offset);
                    // thread_partial += x_cached * dequant
                    ctx.fma_f32_inplace(thread_partial, x_cached, dequant);
                }
                ctx.add_f32_inplace(acc, thread_partial);
                ctx.add_u32_inplace(sb_idx, 1);
                ctx.branch("sb_loop");
                ctx.label("sb_loop_end");
                // Warp shuffle reduction: fold 32 lane partials into lane 0
                // (full-warp mask; all lanes of a live warp reach this point)
                let shfl16 = ctx.shfl_down_f32(acc, 16, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl16);
                let shfl8 = ctx.shfl_down_f32(acc, 8, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl8);
                let shfl4 = ctx.shfl_down_f32(acc, 4, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl4);
                let shfl2 = ctx.shfl_down_f32(acc, 2, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl2);
                let shfl1 = ctx.shfl_down_f32(acc, 1, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl1);
                // Only lane 0 of each warp writes the reduced result
                let one = ctx.mov_u32_imm(1);
                let is_lane0 = ctx.setp_lt_u32(lane_id, one);
                ctx.branch_if_not(is_lane0, "exit");
                // Store y[output_idx]
                let y_offset = ctx.mul_wide_u32_reg(output_idx, four);
                let y_addr = ctx.add_u64(y_ptr, y_offset);
                ctx.st_global_f32(y_addr, acc);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
}