1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
//! Q5_0 GEMV Kernel
//!
//! 5-bit quantization with high bits: (nibble | (high_bit << 4)) - 16.
use super::{Q5_0_BLOCK_BYTES, Q5_0_BLOCK_SIZE};
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Q5_0 GEMV kernel - handles Qwen 0.5B and similar models
///
/// Q5_0 format (per block of 32 elements):
/// - d: fp16 scale (2 bytes, offset 0)
/// - qh: u32 with 32 high bits (4 bytes, offset 2)
/// - qs: packed 4-bit nibbles (16 bytes, offset 6)
///
/// Dequantization: val = d * ((nibble | (high_bit << 4)) - 16)
// Plain pair of u32s: cheap to copy and compare, so derive the full set
// of value-type traits (Copy/PartialEq/Eq/Hash) in addition to Debug/Clone.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Q5_0GemvKernel {
    /// K dimension (input dimension)
    pub k: u32,
    /// N dimension (output dimension)
    pub n: u32,
}
impl Q5_0GemvKernel {
    /// Create a new Q5_0 GEMV kernel for a `k`-input, `n`-output matvec.
    #[must_use]
    pub fn new(k: u32, n: u32) -> Self {
        Self { k, n }
    }

    /// Get number of blocks per row: `ceil(k / 32)`.
    ///
    /// Uses `u32::div_ceil` rather than the manual `(k + 31) / 32`,
    /// which would overflow for `k > u32::MAX - 31`.
    #[must_use]
    pub const fn num_blocks_per_row(&self) -> u32 {
        self.k.div_ceil(Q5_0_BLOCK_SIZE)
    }
}
impl Kernel for Q5_0GemvKernel {
    /// Kernel entry-point name as it appears in the emitted PTX module.
    fn name(&self) -> &str {
        "q5_0_gemv_warp_reduce"
    }

    /// Emit the PTX for `y = W * x` with Q5_0-quantized weights `W`.
    ///
    /// Expected launch shape: block = 32 threads (one warp), grid covers
    /// `n` blocks; block `b` computes the single output element `y[b]`.
    /// Each lane dequantizes one of the 32 values in every Q5_0 block of
    /// row `b` and accumulates a private f32 partial sum; the warp then
    /// tree-reduces via `shfl.down` and lane 0 stores the result.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("q5_0_gemv_warp_reduce")
            .param(PtxType::U64, "y_ptr") // Output vector (N)
            .param(PtxType::U64, "w_ptr") // Q5_0 weights (N x K/32 blocks)
            .param(PtxType::U64, "x_ptr") // Input vector (K)
            .param(PtxType::U32, "k_dim") // K dimension
            .param(PtxType::U32, "n_dim") // N dimension
            .build(|ctx| {
                // Block = 32 threads (one warp), grid = N blocks
                // Each block computes one output element y[block_id]
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let thread_id = ctx.special_reg(PtxReg::TidX);
                // Bounds check. block_id is identical across the warp, so
                // this branch is warp-uniform: either all 32 lanes exit or
                // none do, keeping the full-mask shfl reduction below safe.
                let n_dim = ctx.load_param_u32("n_dim");
                let oob = ctx.setp_ge_u32(block_id, n_dim);
                ctx.branch_if(oob, "exit");
                let k_dim = ctx.load_param_u32("k_dim");
                let y_ptr = ctx.load_param_u64("y_ptr");
                let w_ptr = ctx.load_param_u64("w_ptr");
                let x_ptr = ctx.load_param_u64("x_ptr");
                // Per-lane partial dot-product accumulator.
                let acc = ctx.mov_f32_imm(0.0);
                // Number of blocks per row: ceil(K / 32)
                let k_rounded = ctx.add_u32(k_dim, Q5_0_BLOCK_SIZE - 1);
                let num_blocks = ctx.div_u32(k_rounded, Q5_0_BLOCK_SIZE);
                // Row base address: w_ptr + block_id * num_blocks * 22
                let block_bytes = ctx.mov_u32_imm(Q5_0_BLOCK_BYTES);
                let row_bytes = ctx.mul_u32_reg(num_blocks, block_bytes);
                // Widening multiply: row offset can exceed 32 bits for
                // large weight matrices.
                let row_offset = ctx.mul_wide_u32_reg(block_id, row_bytes);
                let row_base = ctx.add_u64(w_ptr, row_offset);
                // Loop over blocks (each thread handles one value per block)
                let blk_idx = ctx.mov_u32_imm(0);
                ctx.label("blk_loop");
                let blk_done = ctx.setp_ge_u32(blk_idx, num_blocks);
                ctx.branch_if(blk_done, "blk_loop_end");
                // Block address = row_base + blk_idx * 22
                let blk_offset = ctx.mul_wide_u32(blk_idx, Q5_0_BLOCK_BYTES);
                let blk_addr = ctx.add_u64(row_base, blk_offset);
                // Load scale d (fp16 at offset 0)
                let d_f16 = ctx.ld_global_f16(blk_addr);
                let d = ctx.cvt_f32_f16(d_f16);
                // Load qh (u32 at offset 2) - contains high bits for all 32 values
                // PAR-061-FIX: Use byte loads to avoid misaligned u32 access
                // Q5_0 blocks are 22 bytes, so offset 2 is not guaranteed 4-byte aligned
                // NOTE(review): every lane reassembles the same uniform qh from
                // four byte loads; a shared/broadcast load could do this once
                // per warp — left as-is pending a profile.
                let two_64 = ctx.mov_u64_imm(2);
                let qh_addr = ctx.add_u64(blk_addr, two_64);
                let qh_b0 = ctx.ld_global_u8(qh_addr);
                let three_64 = ctx.mov_u64_imm(3);
                let qh_addr1 = ctx.add_u64(blk_addr, three_64);
                let qh_b1 = ctx.ld_global_u8(qh_addr1);
                let four_64 = ctx.mov_u64_imm(4);
                let qh_addr2 = ctx.add_u64(blk_addr, four_64);
                let qh_b2 = ctx.ld_global_u8(qh_addr2);
                let five_64 = ctx.mov_u64_imm(5);
                let qh_addr3 = ctx.add_u64(blk_addr, five_64);
                let qh_b3 = ctx.ld_global_u8(qh_addr3);
                // Combine bytes (little-endian):
                // qh = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)
                let qh_b0_u32 = ctx.cvt_u32_u8(qh_b0);
                let qh_b1_u32 = ctx.cvt_u32_u8(qh_b1);
                let qh_b2_u32 = ctx.cvt_u32_u8(qh_b2);
                let qh_b3_u32 = ctx.cvt_u32_u8(qh_b3);
                let qh_b1_shifted = ctx.shl_u32_imm(qh_b1_u32, 8);
                let qh_b2_shifted = ctx.shl_u32_imm(qh_b2_u32, 16);
                let qh_b3_shifted = ctx.shl_u32_imm(qh_b3_u32, 24);
                let qh_01 = ctx.or_u32(qh_b0_u32, qh_b1_shifted);
                let qh_012 = ctx.or_u32(qh_01, qh_b2_shifted);
                let qh = ctx.or_u32(qh_012, qh_b3_shifted);
                // Extract high bit for this thread: (qh >> thread_id) & 1
                let high_bit = ctx.shr_u32(qh, thread_id);
                let one_u32 = ctx.mov_u32_imm(1);
                let high_bit_masked = ctx.and_u32(high_bit, one_u32);
                // Load nibble for this thread from qs (offset 6)
                // qs layout: 32 4-bit values packed into 16 bytes
                // Nibble index = thread_id, byte index = thread_id / 2
                // Low/high nibble = thread_id % 2
                // NOTE(review): this is an INTERLEAVED nibble layout (element i
                // lives in byte i/2). Reference GGML Q5_0 uses a SPLIT layout
                // (element i < 16 -> low nibble of byte i, element i >= 16 ->
                // high nibble of byte i-16). Confirm the packer that produced
                // these weights matches the interleaved convention used here.
                let six_64 = ctx.mov_u64_imm(6);
                let qs_base = ctx.add_u64(blk_addr, six_64);
                // byte_idx = thread_id / 2
                let byte_idx = ctx.div_u32(thread_id, 2);
                let byte_idx_64 = ctx.cvt_u64_u32(byte_idx);
                let qs_addr = ctx.add_u64(qs_base, byte_idx_64);
                // Load the byte containing our nibble
                let qs_byte = ctx.ld_global_u8(qs_addr);
                let qs_byte_u32 = ctx.cvt_u32_u8(qs_byte);
                // Extract nibble: if thread_id is odd, use high nibble (>> 4)
                // nibble_select = (thread_id % 2) * 4 = (thread_id & 1) << 2
                let nibble_select = ctx.and_u32(thread_id, one_u32);
                let shift_amount = ctx.mul_u32(nibble_select, 4);
                let shifted = ctx.shr_u32(qs_byte_u32, shift_amount);
                let fifteen_u32 = ctx.mov_u32_imm(15);
                let nibble = ctx.and_u32(shifted, fifteen_u32);
                // Combine nibble with high bit: q = nibble | (high_bit << 4)
                let high_shifted = ctx.shl_u32_imm(high_bit_masked, 4);
                let q_5bit = ctx.or_u32(nibble, high_shifted);
                // Center: q_centered = q - 16 (result may be negative, -16 to +15)
                // Wraps in u32 for q < 16; the bit pattern is the correct
                // two's-complement value for the signed convert below.
                let sixteen_u32 = ctx.mov_u32_imm(16);
                let q_centered = ctx.sub_u32_reg(q_5bit, sixteen_u32);
                // Convert to float and dequantize
                // cvt_f32_s32 interprets the bits as signed, so negative values work correctly
                let q_f32 = ctx.cvt_f32_s32(q_centered);
                let dequant = ctx.mul_f32(d, q_f32);
                // Load activation x[blk_idx * 32 + thread_id]
                let blk_k_base = ctx.mul_u32(blk_idx, Q5_0_BLOCK_SIZE);
                let x_idx = ctx.add_u32_reg(blk_k_base, thread_id);
                // Bounds check for last block (K may not be multiple of 32);
                // lanes past K skip the FMA and reconverge at skip_mul.
                let x_oob = ctx.setp_ge_u32(x_idx, k_dim);
                ctx.branch_if(x_oob, "skip_mul");
                let x_idx_64 = ctx.cvt_u64_u32(x_idx);
                // f32 elements: byte offset = index * 4
                let x_bytes = ctx.mul_u64(x_idx_64, 4);
                let x_addr = ctx.add_u64(x_ptr, x_bytes);
                let x_val = ctx.ld_global_f32(x_addr);
                // acc += x_val * dequant
                ctx.fma_f32_inplace(acc, x_val, dequant);
                ctx.label("skip_mul");
                ctx.add_u32_inplace(blk_idx, 1);
                ctx.branch("blk_loop");
                ctx.label("blk_loop_end");
                // Warp reduce: butterfly-free tree sum over 32 lanes with
                // full mask 0xFFFFFFFF; after 5 steps lane 0 holds the total.
                let tmp16 = ctx.shfl_down_f32(acc, 16, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, tmp16);
                let tmp8 = ctx.shfl_down_f32(acc, 8, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, tmp8);
                let tmp4 = ctx.shfl_down_f32(acc, 4, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, tmp4);
                let tmp2 = ctx.shfl_down_f32(acc, 2, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, tmp2);
                let tmp1 = ctx.shfl_down_f32(acc, 1, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, tmp1);
                // Thread 0 writes result y[block_id] (f32, 4 bytes/elem)
                let is_thread0 = ctx.setp_lt_u32(thread_id, one_u32);
                ctx.branch_if_not(is_thread0, "exit");
                let y_offset = ctx.mul_wide_u32(block_id, 4);
                let y_addr = ctx.add_u64(y_ptr, y_offset);
                ctx.st_global_f32(y_addr, acc);
                ctx.label("exit");
                ctx.ret();
            })
    }
}