use super::{Q4K_SUPER_BLOCK_BYTES, Q4K_SUPER_BLOCK_SIZE};
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Descriptor for a fused RMSNorm + Q4_K-quantized GEMV PTX kernel.
///
/// The generated kernel normalizes the input vector with RMSNorm (scaled by a
/// `gamma` vector) and immediately computes `y = W * rmsnorm(x)` against a
/// Q4_K-quantized weight matrix, avoiding a round trip through global memory.
#[derive(Debug, Clone)]
pub struct FusedRmsNormQ4KGemvKernel {
    /// Reduction (input) dimension — number of elements in `x`.
    pub k: u32,
    /// Output dimension — number of weight rows / elements in `y`.
    pub n: u32,
    /// Numerical-stability term added to the mean square before `rsqrt`.
    pub epsilon: f32,
}
impl FusedRmsNormQ4KGemvKernel {
    /// Creates a descriptor for a `k`-element input and `n` outputs, with the
    /// default epsilon of `1e-5`.
    ///
    /// Made `const` for consistency with [`Self::with_epsilon`]; the body is
    /// a plain struct literal, so const-evaluation is free.
    #[must_use]
    pub const fn new(k: u32, n: u32) -> Self {
        Self {
            k,
            n,
            epsilon: 1e-5,
        }
    }

    /// Returns the descriptor with a caller-supplied epsilon (builder style).
    #[must_use]
    pub const fn with_epsilon(mut self, epsilon: f32) -> Self {
        self.epsilon = epsilon;
        self
    }
}
impl Kernel for FusedRmsNormQ4KGemvKernel {
    fn name(&self) -> &str {
        "fused_rmsnorm_q4k_gemv"
    }

    /// Emits PTX for the fused RMSNorm + Q4_K GEMV.
    ///
    /// One thread block computes one output row (the stride constants assume a
    /// 256-thread block — TODO confirm against the launch code):
    /// 1. `load_loop`: stage `x` into shared memory while accumulating each
    ///    thread's partial `sum(x^2)`.
    /// 2. Warp-shuffle reduction, per-warp partials in shared memory, then
    ///    thread 0 folds the 8 partials and computes
    ///    `rms_inv = rsqrt(mean_sq + epsilon)`, broadcast via shared memory.
    /// 3. `norm_loop`: rewrite shared memory in place as `x * rms_inv * gamma`.
    /// 4. `sb_loop`: decode one Q4_K super-block per iteration (one value per
    ///    thread), accumulate the dot product, reduce block-wide, and store
    ///    the single result to `y[block_id]`.
    ///
    /// Shared memory layout: `k * 4` bytes of staged input followed by
    /// `8 * 4` bytes of per-warp reduction scratch.
    fn build_ptx(&self) -> PtxKernel {
        let k = self.k;
        let epsilon = self.epsilon;
        // k f32 values + 8 per-warp f32 partial-sum slots.
        let smem_size = (k * 4 + 32) as usize;
        PtxKernel::new("fused_rmsnorm_q4k_gemv")
            .param(PtxType::U64, "y_ptr")
            .param(PtxType::U64, "w_ptr")
            .param(PtxType::U64, "x_ptr")
            .param(PtxType::U64, "gamma_ptr")
            .param(PtxType::U32, "k_dim")
            .param(PtxType::U32, "n_dim")
            .shared_memory(smem_size)
            .build(move |ctx| {
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let thread_id = ctx.special_reg(PtxReg::TidX);
                let n_dim = ctx.load_param_u32("n_dim");
                // One block per output row; surplus blocks exit immediately.
                let oob = ctx.setp_ge_u32(block_id, n_dim);
                ctx.branch_if(oob, "exit");
                let k_dim = ctx.load_param_u32("k_dim");
                let y_ptr = ctx.load_param_u64("y_ptr");
                let w_ptr = ctx.load_param_u64("w_ptr");
                let x_ptr = ctx.load_param_u64("x_ptr");
                let gamma_ptr = ctx.load_param_u64("gamma_ptr");
                let four = ctx.mov_u32_imm(4);
                let one = ctx.mov_u32_imm(1);
                let sq_sum = ctx.mov_f32_imm(0.0);
                let idx = ctx.mov_u32_imm(0);
                // Stage x into shared memory, block-striding by 256 threads,
                // while accumulating sum(x^2) per thread.
                ctx.label("load_loop");
                let loop_idx = ctx.add_u32_reg(idx, thread_id);
                let in_bounds = ctx.setp_lt_u32(loop_idx, k_dim);
                ctx.branch_if_not(in_bounds, "load_loop_end");
                let elem_offset = ctx.mul_wide_u32_reg(loop_idx, four);
                let x_addr = ctx.add_u64(x_ptr, elem_offset);
                let x_val = ctx.ld_global_f32(x_addr);
                ctx.st_shared_f32(elem_offset, x_val);
                ctx.fma_f32_inplace(sq_sum, x_val, x_val);
                ctx.add_u32_inplace(idx, 256);
                ctx.branch("load_loop");
                ctx.label("load_loop_end");
                // Intra-warp tree reduction of sq_sum via shuffles.
                let shfl16 = ctx.shfl_down_f32(sq_sum, 16, 0xFFFF_FFFF);
                ctx.add_f32_inplace(sq_sum, shfl16);
                let shfl8 = ctx.shfl_down_f32(sq_sum, 8, 0xFFFF_FFFF);
                ctx.add_f32_inplace(sq_sum, shfl8);
                let shfl4 = ctx.shfl_down_f32(sq_sum, 4, 0xFFFF_FFFF);
                ctx.add_f32_inplace(sq_sum, shfl4);
                let shfl2 = ctx.shfl_down_f32(sq_sum, 2, 0xFFFF_FFFF);
                ctx.add_f32_inplace(sq_sum, shfl2);
                let shfl1 = ctx.shfl_down_f32(sq_sum, 1, 0xFFFF_FFFF);
                ctx.add_f32_inplace(sq_sum, shfl1);
                let lane_id = ctx.rem_u32(thread_id, 32);
                let warp_id = ctx.div_u32(thread_id, 32);
                let is_lane0 = ctx.setp_lt_u32(lane_id, one);
                // Per-warp partials live right after the staged input (k * 4).
                let k_bytes = ctx.mul_u32_reg(k_dim, four);
                let warp_sum_offset = ctx.mul_wide_u32_reg(warp_id, four);
                let k_bytes_64 = ctx.cvt_u64_u32(k_bytes);
                let warp_sum_addr = ctx.add_u64(k_bytes_64, warp_sum_offset);
                ctx.branch_if_not(is_lane0, "skip_warp_write");
                ctx.st_shared_f32(warp_sum_addr, sq_sum);
                ctx.label("skip_warp_write");
                ctx.bar_sync(0);
                let is_thread0 = ctx.setp_lt_u32(thread_id, one);
                let total_sq_sum = ctx.mov_f32_imm(0.0);
                ctx.branch_if_not(is_thread0, "skip_final_reduce");
                // Thread 0 folds the 8 warp partials and derives the RMS scale.
                for warp in 0..8u32 {
                    let warp_offset = ctx.mov_u64_imm((warp * 4) as u64);
                    let addr = ctx.add_u64(k_bytes_64, warp_offset);
                    let warp_sum = ctx.ld_shared_f32(addr);
                    ctx.add_f32_inplace(total_sq_sum, warp_sum);
                }
                let k_f32 = ctx.cvt_f32_u32(k_dim);
                let mean_sq = ctx.div_f32(total_sq_sum, k_f32);
                let eps = ctx.mov_f32_imm(epsilon);
                let mean_sq_eps = ctx.add_f32(mean_sq, eps);
                let rms_inv = ctx.rsqrt_f32(mean_sq_eps);
                // BUGFIX: broadcast rms_inv through the scratch slot at byte
                // offset k*4 instead of offset 0. Offset 0 holds the staged
                // x[0]; the previous store clobbered element 0 before
                // norm_loop read it. The warp partials at k*4 have already
                // been consumed above, so the slot is free.
                ctx.st_shared_f32(k_bytes_64, rms_inv);
                ctx.label("skip_final_reduce");
                ctx.bar_sync(1);
                let rms_inv_val = ctx.ld_shared_f32(k_bytes_64);
                // Rewrite shared memory in place: x[i] * rms_inv * gamma[i].
                let idx2 = ctx.mov_u32_imm(0);
                ctx.label("norm_loop");
                let loop_idx2 = ctx.add_u32_reg(idx2, thread_id);
                let in_bounds2 = ctx.setp_lt_u32(loop_idx2, k_dim);
                ctx.branch_if_not(in_bounds2, "norm_loop_end");
                let elem_offset2 = ctx.mul_wide_u32_reg(loop_idx2, four);
                let x_smem = ctx.ld_shared_f32(elem_offset2);
                let gamma_addr = ctx.add_u64(gamma_ptr, elem_offset2);
                let gamma = ctx.ld_global_f32(gamma_addr);
                let normalized = ctx.mul_f32(x_smem, rms_inv_val);
                let scaled = ctx.mul_f32(normalized, gamma);
                ctx.st_shared_f32(elem_offset2, scaled);
                ctx.add_u32_inplace(idx2, 256);
                ctx.branch("norm_loop");
                ctx.label("norm_loop_end");
                ctx.bar_sync(2);
                // GEMV over Q4_K super-blocks: each thread handles one of the
                // 256 values of the current super-block.
                let acc = ctx.mov_f32_imm(0.0);
                let k_rounded = ctx.add_u32(k_dim, Q4K_SUPER_BLOCK_SIZE - 1);
                let num_super_blocks = ctx.div_u32(k_rounded, Q4K_SUPER_BLOCK_SIZE);
                let sb_bytes = ctx.mov_u32_imm(Q4K_SUPER_BLOCK_BYTES);
                let row_bytes = ctx.mul_u32_reg(num_super_blocks, sb_bytes);
                let row_offset = ctx.mul_wide_u32_reg(block_id, row_bytes);
                let row_base = ctx.add_u64(w_ptr, row_offset);
                let sb_idx = ctx.mov_u32_imm(0);
                ctx.label("sb_loop");
                let sb_done = ctx.setp_ge_u32(sb_idx, num_super_blocks);
                ctx.branch_if(sb_done, "sb_loop_end");
                let sb_offset = ctx.mul_wide_u32(sb_idx, Q4K_SUPER_BLOCK_BYTES);
                let sb_addr = ctx.add_u64(row_base, sb_offset);
                // Super-block header: f16 d at +0, f16 dmin at +2.
                let d_f16 = ctx.ld_global_f16(sb_addr);
                let d = ctx.cvt_f32_f16(d_f16);
                let two = ctx.mov_u64_imm(2);
                let dmin_addr = ctx.add_u64(sb_addr, two);
                let dmin_f16 = ctx.ld_global_f16(dmin_addr);
                let dmin = ctx.cvt_f32_f16(dmin_f16);
                // 12 packed scale/min bytes at +4; decode per llama.cpp
                // get_scale_min_k4: sub-blocks 0-3 use 6 low bits of bytes
                // j / j+4, sub-blocks 4-7 combine nibbles of byte j+4 with the
                // top 2 bits of bytes j-4 / j.
                let four_64 = ctx.mov_u64_imm(4);
                let scales_base = ctx.add_u64(sb_addr, four_64);
                let sub_block = ctx.div_u32(thread_id, 32);
                let four_cmp = ctx.mov_u32_imm(4);
                let sub_block_lt_4 = ctx.setp_lt_u32(sub_block, four_cmp);
                let sub_block_64 = ctx.cvt_u64_u32(sub_block);
                let scale_byte_addr = ctx.add_u64(scales_base, sub_block_64);
                let scale_byte = ctx.ld_global_u8(scale_byte_addr);
                let scale_byte_32 = ctx.cvt_u32_u8(scale_byte);
                let four_reg = ctx.mov_u32_imm(4);
                let sub_block_plus_4 = ctx.add_u32_reg(sub_block, four_reg);
                let sub_block_plus_4_64 = ctx.cvt_u64_u32(sub_block_plus_4);
                let min_byte_addr = ctx.add_u64(scales_base, sub_block_plus_4_64);
                let min_byte = ctx.ld_global_u8(min_byte_addr);
                let min_byte_32 = ctx.cvt_u32_u8(min_byte);
                let mask_6bit = ctx.mov_u32_imm(0x3F);
                let mask_4bit = ctx.mov_u32_imm(0x0F);
                let six = ctx.mov_u32_imm(6);
                // Simple path (sub-blocks 0-3): low 6 bits of the two bytes.
                let scale_simple = ctx.and_u32(scale_byte_32, mask_6bit);
                let min_simple = ctx.and_u32(min_byte_32, mask_6bit);
                // Complex path (sub-blocks 4-7); index is clamped to 0 on the
                // simple path so the speculative loads stay in bounds.
                let eight_64 = ctx.mov_u64_imm(8);
                let scales_8_base = ctx.add_u64(scales_base, eight_64);
                let sub_block_minus_4_raw = ctx.sub_u32_reg(sub_block, four_reg);
                let zero_safe_fused = ctx.mov_u32_imm(0);
                let sub_block_minus_4 =
                    ctx.selp_u32(sub_block_lt_4, zero_safe_fused, sub_block_minus_4_raw);
                let sub_block_minus_4_64 = ctx.cvt_u64_u32(sub_block_minus_4);
                let scales_8_addr = ctx.add_u64(scales_8_base, sub_block_minus_4_64);
                let s8_byte = ctx.ld_global_u8(scales_8_addr);
                let s8_byte_32 = ctx.cvt_u32_u8(s8_byte);
                let scale_hi_src_addr = ctx.add_u64(scales_base, sub_block_minus_4_64);
                let scale_hi_src_byte = ctx.ld_global_u8(scale_hi_src_addr);
                let scale_hi_src_32 = ctx.cvt_u32_u8(scale_hi_src_byte);
                let s8_lo = ctx.and_u32(s8_byte_32, mask_4bit);
                let s0_hi = ctx.shr_u32(scale_hi_src_32, six);
                let s0_hi_shifted = ctx.shl_u32(s0_hi, four_reg);
                let scale_complex = ctx.or_u32(s8_lo, s0_hi_shifted);
                let s8_hi = ctx.shr_u32(s8_byte_32, four_reg);
                let s4_hi = ctx.shr_u32(scale_byte_32, six);
                let s4_hi_shifted = ctx.shl_u32(s4_hi, four_reg);
                let min_complex = ctx.or_u32(s8_hi, s4_hi_shifted);
                let scale = ctx.selp_u32(sub_block_lt_4, scale_simple, scale_complex);
                let min = ctx.selp_u32(sub_block_lt_4, min_simple, min_complex);
                let scale_f = ctx.cvt_f32_u32(scale);
                let min_f = ctx.cvt_f32_u32(min);
                let ds = ctx.mul_f32(d, scale_f);
                let dm = ctx.mul_f32(dmin, min_f);
                // Quantized nibbles start at +16. Each 32-byte chunk packs 64
                // values: low nibbles are values 0-31, high nibbles 32-63.
                let sixteen_64 = ctx.mov_u64_imm(16);
                let qs_base = ctx.add_u64(sb_addr, sixteen_64);
                let chunk_idx = ctx.div_u32(thread_id, 64);
                let val_in_chunk = ctx.rem_u32(thread_id, 64);
                let byte_in_chunk = ctx.rem_u32(val_in_chunk, 32);
                let chunk_offset = ctx.mul_u32(chunk_idx, 32);
                let qs_byte_offset = ctx.add_u32_reg(chunk_offset, byte_in_chunk);
                let qs_byte_offset_64 = ctx.cvt_u64_u32(qs_byte_offset);
                let qs_addr = ctx.add_u64(qs_base, qs_byte_offset_64);
                let packed = ctx.ld_global_u8(qs_addr);
                let packed_32 = ctx.cvt_u32_u8(packed);
                let val_in_chunk_div_32 = ctx.div_u32(val_in_chunk, 32);
                let shift_amount = ctx.mul_u32_reg(val_in_chunk_div_32, four_reg);
                let shifted = ctx.shr_u32(packed_32, shift_amount);
                let quant = ctx.and_u32(shifted, mask_4bit);
                let quant_f32 = ctx.cvt_f32_u32(quant);
                // Dequantize: q * d * scale - dmin * min.
                let scaled_q = ctx.mul_f32(ds, quant_f32);
                let dequant = ctx.sub_f32(scaled_q, dm);
                // NOTE(review): for the last super-block when k % 256 != 0,
                // x_idx can exceed k and this shared read goes past the staged
                // input region — confirm callers always pass k as a multiple
                // of the super-block size.
                let sb_k_base = ctx.mul_u32(sb_idx, Q4K_SUPER_BLOCK_SIZE);
                let x_idx = ctx.add_u32_reg(sb_k_base, thread_id);
                let x_smem_offset = ctx.mul_wide_u32_reg(x_idx, four);
                let x_norm_val = ctx.ld_shared_f32(x_smem_offset);
                ctx.fma_f32_inplace(acc, x_norm_val, dequant);
                ctx.add_u32_inplace(sb_idx, 1);
                ctx.branch("sb_loop");
                ctx.label("sb_loop_end");
                // Block-wide reduction of acc: warp shuffles, then per-warp
                // partials in shared memory, folded by thread 0.
                let shfl16_acc = ctx.shfl_down_f32(acc, 16, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl16_acc);
                let shfl8_acc = ctx.shfl_down_f32(acc, 8, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl8_acc);
                let shfl4_acc = ctx.shfl_down_f32(acc, 4, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl4_acc);
                let shfl2_acc = ctx.shfl_down_f32(acc, 2, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl2_acc);
                let shfl1_acc = ctx.shfl_down_f32(acc, 1, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, shfl1_acc);
                let warp_acc_offset = ctx.mul_wide_u32_reg(warp_id, four);
                let warp_acc_addr = ctx.add_u64(k_bytes_64, warp_acc_offset);
                ctx.branch_if_not(is_lane0, "skip_warp_acc_write");
                ctx.st_shared_f32(warp_acc_addr, acc);
                ctx.label("skip_warp_acc_write");
                ctx.bar_sync(3);
                ctx.branch_if_not(is_thread0, "exit");
                let final_acc = ctx.mov_f32_imm(0.0);
                for warp in 0..8u32 {
                    let warp_offset = ctx.mov_u64_imm((warp * 4) as u64);
                    let addr = ctx.add_u64(k_bytes_64, warp_offset);
                    let warp_acc = ctx.ld_shared_f32(addr);
                    ctx.add_f32_inplace(final_acc, warp_acc);
                }
                let y_offset = ctx.mul_wide_u32(block_id, 4);
                let y_addr = ctx.add_u64(y_ptr, y_offset);
                ctx.st_global_f32(y_addr, final_acc);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
/// Descriptor for a fused gate + up Q4_K GEMV PTX kernel.
///
/// The generated kernel computes both FFN projections (`gate_out = Wg * x`
/// and `up_out = Wu * x`) in a single pass, sharing one staging of `x` into
/// shared memory.
#[derive(Debug, Clone)]
pub struct FusedGateUpQ4KGemvKernel {
    /// Reduction (input) dimension — number of elements in `x`.
    pub k: u32,
    /// Output dimension — rows per projection / elements in each output.
    pub n: u32,
}
impl FusedGateUpQ4KGemvKernel {
    /// Creates a descriptor for a `k`-element input and `n` outputs per
    /// projection.
    ///
    /// Made `const` for consistency with the sibling kernel constructors;
    /// the body is a plain struct literal.
    #[must_use]
    pub const fn new(k: u32, n: u32) -> Self {
        Self { k, n }
    }
}
impl Kernel for FusedGateUpQ4KGemvKernel {
fn name(&self) -> &str {
"fused_gate_up_q4k_gemv"
}
fn build_ptx(&self) -> PtxKernel {
let k = self.k;
let smem_size = (k * 4 + 64) as usize;
PtxKernel::new("fused_gate_up_q4k_gemv")
.param(PtxType::U64, "gate_out_ptr") .param(PtxType::U64, "up_out_ptr") .param(PtxType::U64, "wg_ptr") .param(PtxType::U64, "wu_ptr") .param(PtxType::U64, "x_ptr") .param(PtxType::U32, "k_dim") .param(PtxType::U32, "n_dim") .shared_memory(smem_size)
.build(move |ctx| {
let block_id = ctx.special_reg(PtxReg::CtaIdX);
let thread_id = ctx.special_reg(PtxReg::TidX);
let n_dim = ctx.load_param_u32("n_dim");
let oob = ctx.setp_ge_u32(block_id, n_dim);
ctx.branch_if(oob, "exit");
let k_dim = ctx.load_param_u32("k_dim");
let gate_out_ptr = ctx.load_param_u64("gate_out_ptr");
let up_out_ptr = ctx.load_param_u64("up_out_ptr");
let wg_ptr = ctx.load_param_u64("wg_ptr");
let wu_ptr = ctx.load_param_u64("wu_ptr");
let x_ptr = ctx.load_param_u64("x_ptr");
let four = ctx.mov_u32_imm(4);
let one = ctx.mov_u32_imm(1);
let lane_id = ctx.rem_u32(thread_id, 32);
let warp_id = ctx.div_u32(thread_id, 32);
let idx = ctx.mov_u32_imm(0);
ctx.label("load_loop");
let loop_idx = ctx.add_u32_reg(idx, thread_id);
let in_bounds = ctx.setp_lt_u32(loop_idx, k_dim);
ctx.branch_if_not(in_bounds, "load_loop_end");
let elem_offset = ctx.mul_wide_u32_reg(loop_idx, four);
let x_addr = ctx.add_u64(x_ptr, elem_offset);
let x_val = ctx.ld_global_f32(x_addr);
ctx.st_shared_f32(elem_offset, x_val);
ctx.add_u32_inplace(idx, 256);
ctx.branch("load_loop");
ctx.label("load_loop_end");
ctx.bar_sync(0);
let k_rounded = ctx.add_u32(k_dim, 255);
let num_sb = ctx.div_u32(k_rounded, 256);
let sb_bytes = ctx.mov_u32_imm(Q4K_SUPER_BLOCK_BYTES);
let row_bytes = ctx.mul_u32_reg(num_sb, sb_bytes);
let row_offset = ctx.mul_wide_u32_reg(block_id, row_bytes);
let wg_row_base = ctx.add_u64(wg_ptr, row_offset);
let wu_row_base = ctx.add_u64(wu_ptr, row_offset);
let acc_gate = ctx.mov_f32_imm(0.0);
let acc_up = ctx.mov_f32_imm(0.0);
let sb_idx = ctx.mov_u32_imm(0);
ctx.label("sb_loop");
let sb_done = ctx.setp_ge_u32(sb_idx, num_sb);
ctx.branch_if(sb_done, "sb_loop_end");
let sb_offset = ctx.mul_wide_u32(sb_idx, Q4K_SUPER_BLOCK_BYTES);
let wg_sb_addr = ctx.add_u64(wg_row_base, sb_offset);
let wu_sb_addr = ctx.add_u64(wu_row_base, sb_offset);
let d_gate_f16 = ctx.ld_global_f16(wg_sb_addr);
let d_gate = ctx.cvt_f32_f16(d_gate_f16);
let two = ctx.mov_u64_imm(2);
let dmin_gate_addr = ctx.add_u64(wg_sb_addr, two);
let dmin_gate_f16 = ctx.ld_global_f16(dmin_gate_addr);
let dmin_gate = ctx.cvt_f32_f16(dmin_gate_f16);
let d_up_f16 = ctx.ld_global_f16(wu_sb_addr);
let d_up = ctx.cvt_f32_f16(d_up_f16);
let dmin_up_addr = ctx.add_u64(wu_sb_addr, two);
let dmin_up_f16 = ctx.ld_global_f16(dmin_up_addr);
let dmin_up = ctx.cvt_f32_f16(dmin_up_f16);
let four_64 = ctx.mov_u64_imm(4);
let scales_gate_base = ctx.add_u64(wg_sb_addr, four_64);
let scales_up_base = ctx.add_u64(wu_sb_addr, four_64);
let is_lane0 = ctx.setp_lt_u32(lane_id, one);
let scales_gate_0_3 = ctx.mov_u32_imm(0);
let scales_gate_4_7 = ctx.mov_u32_imm(0);
let scales_gate_8_11 = ctx.mov_u32_imm(0);
let scales_up_0_3 = ctx.mov_u32_imm(0);
let scales_up_4_7 = ctx.mov_u32_imm(0);
let scales_up_8_11 = ctx.mov_u32_imm(0);
ctx.branch_if_not(is_lane0, "skip_scale_load");
ctx.ld_global_u32_into(scales_gate_0_3, scales_gate_base);
let four_64b = ctx.mov_u64_imm(4);
let scales_gate_4_addr = ctx.add_u64(scales_gate_base, four_64b);
ctx.ld_global_u32_into(scales_gate_4_7, scales_gate_4_addr);
let eight_64 = ctx.mov_u64_imm(8);
let scales_gate_8_addr = ctx.add_u64(scales_gate_base, eight_64);
ctx.ld_global_u32_into(scales_gate_8_11, scales_gate_8_addr);
ctx.ld_global_u32_into(scales_up_0_3, scales_up_base);
let scales_up_4_addr = ctx.add_u64(scales_up_base, four_64b);
ctx.ld_global_u32_into(scales_up_4_7, scales_up_4_addr);
let scales_up_8_addr = ctx.add_u64(scales_up_base, eight_64);
ctx.ld_global_u32_into(scales_up_8_11, scales_up_8_addr);
ctx.label("skip_scale_load");
let _scales_gate_0_3_bcast = ctx.shfl_idx_u32(scales_gate_0_3, 0, 0xFFFF_FFFF);
let _scales_gate_4_7_bcast = ctx.shfl_idx_u32(scales_gate_4_7, 0, 0xFFFF_FFFF);
let _scales_gate_8_11_bcast = ctx.shfl_idx_u32(scales_gate_8_11, 0, 0xFFFF_FFFF);
let _scales_up_0_3_bcast = ctx.shfl_idx_u32(scales_up_0_3, 0, 0xFFFF_FFFF);
let _scales_up_4_7_bcast = ctx.shfl_idx_u32(scales_up_4_7, 0, 0xFFFF_FFFF);
let _scales_up_8_11_bcast = ctx.shfl_idx_u32(scales_up_8_11, 0, 0xFFFF_FFFF);
let quant_offset = ctx.mov_u64_imm(16);
let wg_quant_base = ctx.add_u64(wg_sb_addr, quant_offset);
let wu_quant_base = ctx.add_u64(wu_sb_addr, quant_offset);
let two_const = ctx.mov_u32_imm(2);
let _block_idx = ctx.shr_u32(lane_id, two_const);
let mask_4bit = ctx.mov_u32_imm(0x0F);
let _mask_8bit = ctx.mov_u32_imm(0xFF);
let eight_shift = ctx.mov_u32_imm(8);
let sixteen_shift = ctx.mov_u32_imm(16);
let twenty_four = ctx.mov_u32_imm(24);
let eff_scale_gate = d_gate;
let eff_min_gate = dmin_gate;
let eff_scale_up = d_up;
let eff_min_up = dmin_up;
let quant_byte_offset = ctx.mul_wide_u32_reg(lane_id, four);
let wg_quant_addr = ctx.add_u64(wg_quant_base, quant_byte_offset);
let wu_quant_addr = ctx.add_u64(wu_quant_base, quant_byte_offset);
let wg_packed = ctx.ld_global_u32(wg_quant_addr);
let wu_packed = ctx.ld_global_u32(wu_quant_addr);
let sb_base_u32 = ctx.mov_u32_imm(256);
let sb_base = ctx.mul_u32_reg(sb_idx, sb_base_u32);
let eight_const = ctx.mov_u32_imm(8);
let lane_base = ctx.mul_u32_reg(lane_id, eight_const);
let input_base_idx = ctx.add_u32_reg(sb_base, lane_base);
let nib0_g = ctx.and_u32(wg_packed, mask_4bit);
let nib0_u = ctx.and_u32(wu_packed, mask_4bit);
let shift4 = ctx.mov_u32_imm(4);
let tmp1_g = ctx.shr_u32(wg_packed, shift4);
let nib1_g = ctx.and_u32(tmp1_g, mask_4bit);
let tmp1_u = ctx.shr_u32(wu_packed, shift4);
let nib1_u = ctx.and_u32(tmp1_u, mask_4bit);
let tmp2_g = ctx.shr_u32(wg_packed, eight_shift);
let nib2_g = ctx.and_u32(tmp2_g, mask_4bit);
let tmp2_u = ctx.shr_u32(wu_packed, eight_shift);
let nib2_u = ctx.and_u32(tmp2_u, mask_4bit);
let shift12 = ctx.mov_u32_imm(12);
let tmp3_g = ctx.shr_u32(wg_packed, shift12);
let nib3_g = ctx.and_u32(tmp3_g, mask_4bit);
let tmp3_u = ctx.shr_u32(wu_packed, shift12);
let nib3_u = ctx.and_u32(tmp3_u, mask_4bit);
let tmp4_g = ctx.shr_u32(wg_packed, sixteen_shift);
let nib4_g = ctx.and_u32(tmp4_g, mask_4bit);
let tmp4_u = ctx.shr_u32(wu_packed, sixteen_shift);
let nib4_u = ctx.and_u32(tmp4_u, mask_4bit);
let shift20 = ctx.mov_u32_imm(20);
let tmp5_g = ctx.shr_u32(wg_packed, shift20);
let nib5_g = ctx.and_u32(tmp5_g, mask_4bit);
let tmp5_u = ctx.shr_u32(wu_packed, shift20);
let nib5_u = ctx.and_u32(tmp5_u, mask_4bit);
let tmp6_g = ctx.shr_u32(wg_packed, twenty_four);
let nib6_g = ctx.and_u32(tmp6_g, mask_4bit);
let tmp6_u = ctx.shr_u32(wu_packed, twenty_four);
let nib6_u = ctx.and_u32(tmp6_u, mask_4bit);
let shift28 = ctx.mov_u32_imm(28);
let nib7_g = ctx.shr_u32(wg_packed, shift28);
let nib7_u = ctx.shr_u32(wu_packed, shift28);
let nib0_g_f = ctx.cvt_f32_u32(nib0_g);
let nib0_u_f = ctx.cvt_f32_u32(nib0_u);
let nib1_g_f = ctx.cvt_f32_u32(nib1_g);
let nib1_u_f = ctx.cvt_f32_u32(nib1_u);
let nib2_g_f = ctx.cvt_f32_u32(nib2_g);
let nib2_u_f = ctx.cvt_f32_u32(nib2_u);
let nib3_g_f = ctx.cvt_f32_u32(nib3_g);
let nib3_u_f = ctx.cvt_f32_u32(nib3_u);
let nib4_g_f = ctx.cvt_f32_u32(nib4_g);
let nib4_u_f = ctx.cvt_f32_u32(nib4_u);
let nib5_g_f = ctx.cvt_f32_u32(nib5_g);
let nib5_u_f = ctx.cvt_f32_u32(nib5_u);
let nib6_g_f = ctx.cvt_f32_u32(nib6_g);
let nib6_u_f = ctx.cvt_f32_u32(nib6_u);
let nib7_g_f = ctx.cvt_f32_u32(nib7_g);
let nib7_u_f = ctx.cvt_f32_u32(nib7_u);
let neg_min_g = ctx.neg_f32(eff_min_gate);
let neg_min_u = ctx.neg_f32(eff_min_up);
let dq0_g = ctx.fma_f32(eff_scale_gate, nib0_g_f, neg_min_g);
let dq0_u = ctx.fma_f32(eff_scale_up, nib0_u_f, neg_min_u);
let dq1_g = ctx.fma_f32(eff_scale_gate, nib1_g_f, neg_min_g);
let dq1_u = ctx.fma_f32(eff_scale_up, nib1_u_f, neg_min_u);
let dq2_g = ctx.fma_f32(eff_scale_gate, nib2_g_f, neg_min_g);
let dq2_u = ctx.fma_f32(eff_scale_up, nib2_u_f, neg_min_u);
let dq3_g = ctx.fma_f32(eff_scale_gate, nib3_g_f, neg_min_g);
let dq3_u = ctx.fma_f32(eff_scale_up, nib3_u_f, neg_min_u);
let dq4_g = ctx.fma_f32(eff_scale_gate, nib4_g_f, neg_min_g);
let dq4_u = ctx.fma_f32(eff_scale_up, nib4_u_f, neg_min_u);
let dq5_g = ctx.fma_f32(eff_scale_gate, nib5_g_f, neg_min_g);
let dq5_u = ctx.fma_f32(eff_scale_up, nib5_u_f, neg_min_u);
let dq6_g = ctx.fma_f32(eff_scale_gate, nib6_g_f, neg_min_g);
let dq6_u = ctx.fma_f32(eff_scale_up, nib6_u_f, neg_min_u);
let dq7_g = ctx.fma_f32(eff_scale_gate, nib7_g_f, neg_min_g);
let dq7_u = ctx.fma_f32(eff_scale_up, nib7_u_f, neg_min_u);
let zero_imm = ctx.mov_u32_imm(0);
let one_imm = ctx.mov_u32_imm(1);
let two_imm = ctx.mov_u32_imm(2);
let three_imm = ctx.mov_u32_imm(3);
let four_imm = ctx.mov_u32_imm(4);
let five_imm = ctx.mov_u32_imm(5);
let six_imm = ctx.mov_u32_imm(6);
let seven_imm = ctx.mov_u32_imm(7);
let idx0 = ctx.add_u32_reg(input_base_idx, zero_imm);
let off0 = ctx.mul_wide_u32_reg(idx0, four);
let x0 = ctx.ld_shared_f32(off0);
ctx.fma_f32_inplace(acc_gate, dq0_g, x0);
ctx.fma_f32_inplace(acc_up, dq0_u, x0);
let idx1 = ctx.add_u32_reg(input_base_idx, one_imm);
let off1 = ctx.mul_wide_u32_reg(idx1, four);
let x1 = ctx.ld_shared_f32(off1);
ctx.fma_f32_inplace(acc_gate, dq1_g, x1);
ctx.fma_f32_inplace(acc_up, dq1_u, x1);
let idx2 = ctx.add_u32_reg(input_base_idx, two_imm);
let off2 = ctx.mul_wide_u32_reg(idx2, four);
let x2 = ctx.ld_shared_f32(off2);
ctx.fma_f32_inplace(acc_gate, dq2_g, x2);
ctx.fma_f32_inplace(acc_up, dq2_u, x2);
let idx3 = ctx.add_u32_reg(input_base_idx, three_imm);
let off3 = ctx.mul_wide_u32_reg(idx3, four);
let x3 = ctx.ld_shared_f32(off3);
ctx.fma_f32_inplace(acc_gate, dq3_g, x3);
ctx.fma_f32_inplace(acc_up, dq3_u, x3);
let idx4 = ctx.add_u32_reg(input_base_idx, four_imm);
let off4 = ctx.mul_wide_u32_reg(idx4, four);
let x4 = ctx.ld_shared_f32(off4);
ctx.fma_f32_inplace(acc_gate, dq4_g, x4);
ctx.fma_f32_inplace(acc_up, dq4_u, x4);
let idx5 = ctx.add_u32_reg(input_base_idx, five_imm);
let off5 = ctx.mul_wide_u32_reg(idx5, four);
let x5 = ctx.ld_shared_f32(off5);
ctx.fma_f32_inplace(acc_gate, dq5_g, x5);
ctx.fma_f32_inplace(acc_up, dq5_u, x5);
let idx6 = ctx.add_u32_reg(input_base_idx, six_imm);
let off6 = ctx.mul_wide_u32_reg(idx6, four);
let x6 = ctx.ld_shared_f32(off6);
ctx.fma_f32_inplace(acc_gate, dq6_g, x6);
ctx.fma_f32_inplace(acc_up, dq6_u, x6);
let idx7 = ctx.add_u32_reg(input_base_idx, seven_imm);
let off7 = ctx.mul_wide_u32_reg(idx7, four);
let x7 = ctx.ld_shared_f32(off7);
ctx.fma_f32_inplace(acc_gate, dq7_g, x7);
ctx.fma_f32_inplace(acc_up, dq7_u, x7);
ctx.add_u32_inplace(sb_idx, 1);
ctx.branch("sb_loop");
ctx.label("sb_loop_end");
let shfl16_gate = ctx.shfl_down_f32(acc_gate, 16, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_gate, shfl16_gate);
let shfl8_gate = ctx.shfl_down_f32(acc_gate, 8, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_gate, shfl8_gate);
let shfl4_gate = ctx.shfl_down_f32(acc_gate, 4, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_gate, shfl4_gate);
let shfl2_gate = ctx.shfl_down_f32(acc_gate, 2, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_gate, shfl2_gate);
let shfl1_gate = ctx.shfl_down_f32(acc_gate, 1, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_gate, shfl1_gate);
let shfl16_up = ctx.shfl_down_f32(acc_up, 16, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_up, shfl16_up);
let shfl8_up = ctx.shfl_down_f32(acc_up, 8, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_up, shfl8_up);
let shfl4_up = ctx.shfl_down_f32(acc_up, 4, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_up, shfl4_up);
let shfl2_up = ctx.shfl_down_f32(acc_up, 2, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_up, shfl2_up);
let shfl1_up = ctx.shfl_down_f32(acc_up, 1, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_up, shfl1_up);
let k_bytes = ctx.mul_u32_reg(k_dim, four);
let k_bytes_64 = ctx.cvt_u64_u32(k_bytes);
let warp_gate_offset = ctx.mul_wide_u32_reg(warp_id, four);
let warp_gate_addr = ctx.add_u64(k_bytes_64, warp_gate_offset);
let thirty_two = ctx.mov_u64_imm(32);
let warp_up_addr_base = ctx.add_u64(k_bytes_64, thirty_two);
let warp_up_addr = ctx.add_u64(warp_up_addr_base, warp_gate_offset);
ctx.branch_if_not(is_lane0, "skip_warp_write");
ctx.st_shared_f32(warp_gate_addr, acc_gate);
ctx.st_shared_f32(warp_up_addr, acc_up);
ctx.label("skip_warp_write");
ctx.bar_sync(1);
let is_thread0 = ctx.setp_lt_u32(thread_id, one);
ctx.branch_if_not(is_thread0, "exit");
let final_gate = ctx.mov_f32_imm(0.0);
let final_up = ctx.mov_f32_imm(0.0);
for warp in 0..8u32 {
let warp_offset = ctx.mov_u64_imm((warp * 4) as u64);
let gate_addr = ctx.add_u64(k_bytes_64, warp_offset);
let up_addr_base = ctx.add_u64(k_bytes_64, thirty_two);
let up_addr = ctx.add_u64(up_addr_base, warp_offset);
let warp_gate_sum = ctx.ld_shared_f32(gate_addr);
let warp_up_sum = ctx.ld_shared_f32(up_addr);
ctx.add_f32_inplace(final_gate, warp_gate_sum);
ctx.add_f32_inplace(final_up, warp_up_sum);
}
let out_offset = ctx.mul_wide_u32(block_id, 4);
let gate_addr = ctx.add_u64(gate_out_ptr, out_offset);
let up_addr = ctx.add_u64(up_out_ptr, out_offset);
ctx.st_global_f32(gate_addr, final_gate);
ctx.st_global_f32(up_addr, final_up);
ctx.label("exit");
ctx.ret();
})
}
}
/// Descriptor for a fused RMSNorm + gate/up Q4_K GEMV + SwiGLU PTX kernel.
///
/// Combines RMSNorm of the input, both Q4_K-quantized FFN projections, and
/// the SwiGLU activation into a single kernel launch.
#[derive(Debug, Clone)]
pub struct FusedRmsNormGateUpSwigluQ4KKernel {
    /// Reduction (input) dimension — number of elements in `x`.
    pub k: u32,
    /// Output dimension — rows per projection.
    pub n: u32,
    /// Numerical-stability term added to the mean square before `rsqrt`.
    /// Note: the default here is `1e-6`, while `FusedRmsNormQ4KGemvKernel`
    /// defaults to `1e-5`.
    pub epsilon: f32,
}
impl FusedRmsNormGateUpSwigluQ4KKernel {
    /// Creates a descriptor for a `k`-element input and `n` outputs per
    /// projection, with the default epsilon of `1e-6`.
    ///
    /// Made `const` for consistency with [`Self::with_epsilon`]; the body is
    /// a plain struct literal, so const-evaluation is free.
    #[must_use]
    pub const fn new(k: u32, n: u32) -> Self {
        Self {
            k,
            n,
            epsilon: 1e-6,
        }
    }

    /// Returns the descriptor with a caller-supplied epsilon (builder style).
    #[must_use]
    pub const fn with_epsilon(mut self, epsilon: f32) -> Self {
        self.epsilon = epsilon;
        self
    }
}
impl Kernel for FusedRmsNormGateUpSwigluQ4KKernel {
/// Kernel identifier; matches the PTX entry-point name emitted by `build_ptx`.
fn name(&self) -> &str {
    "fused_rmsnorm_gate_up_swiglu_q4k"
}
fn build_ptx(&self) -> PtxKernel {
let k = self.k;
let epsilon = self.epsilon;
let smem_size = (k * 4 + 96) as usize;
PtxKernel::new("fused_rmsnorm_gate_up_swiglu_q4k")
.param(PtxType::U64, "out_ptr") .param(PtxType::U64, "wg_ptr") .param(PtxType::U64, "wu_ptr") .param(PtxType::U64, "x_ptr") .param(PtxType::U64, "gamma_ptr") .param(PtxType::U32, "k_dim") .param(PtxType::U32, "n_dim") .shared_memory(smem_size)
.build(move |ctx| {
let block_id = ctx.special_reg(PtxReg::CtaIdX);
let thread_id = ctx.special_reg(PtxReg::TidX);
let n_dim = ctx.load_param_u32("n_dim");
let oob = ctx.setp_ge_u32(block_id, n_dim);
ctx.branch_if(oob, "exit");
let k_dim = ctx.load_param_u32("k_dim");
let out_ptr = ctx.load_param_u64("out_ptr");
let wg_ptr = ctx.load_param_u64("wg_ptr");
let wu_ptr = ctx.load_param_u64("wu_ptr");
let x_ptr = ctx.load_param_u64("x_ptr");
let gamma_ptr = ctx.load_param_u64("gamma_ptr");
let four = ctx.mov_u32_imm(4);
let one = ctx.mov_u32_imm(1);
let lane_id = ctx.rem_u32(thread_id, 32);
let warp_id = ctx.div_u32(thread_id, 32);
let is_lane0 = ctx.setp_lt_u32(lane_id, one);
let is_thread0 = ctx.setp_lt_u32(thread_id, one);
let sq_sum = ctx.mov_f32_imm(0.0);
let idx = ctx.mov_u32_imm(0);
ctx.label("load_loop");
let loop_idx = ctx.add_u32_reg(idx, thread_id);
let in_bounds = ctx.setp_lt_u32(loop_idx, k_dim);
ctx.branch_if_not(in_bounds, "load_loop_end");
let elem_offset = ctx.mul_wide_u32_reg(loop_idx, four);
let x_addr = ctx.add_u64(x_ptr, elem_offset);
let x_val = ctx.ld_global_f32(x_addr);
ctx.st_shared_f32(elem_offset, x_val);
ctx.fma_f32_inplace(sq_sum, x_val, x_val);
ctx.add_u32_inplace(idx, 256);
ctx.branch("load_loop");
ctx.label("load_loop_end");
let shfl16 = ctx.shfl_down_f32(sq_sum, 16, 0xFFFF_FFFF);
ctx.add_f32_inplace(sq_sum, shfl16);
let shfl8 = ctx.shfl_down_f32(sq_sum, 8, 0xFFFF_FFFF);
ctx.add_f32_inplace(sq_sum, shfl8);
let shfl4 = ctx.shfl_down_f32(sq_sum, 4, 0xFFFF_FFFF);
ctx.add_f32_inplace(sq_sum, shfl4);
let shfl2 = ctx.shfl_down_f32(sq_sum, 2, 0xFFFF_FFFF);
ctx.add_f32_inplace(sq_sum, shfl2);
let shfl1 = ctx.shfl_down_f32(sq_sum, 1, 0xFFFF_FFFF);
ctx.add_f32_inplace(sq_sum, shfl1);
let k_bytes = ctx.mul_u32_reg(k_dim, four);
let k_bytes_64 = ctx.cvt_u64_u32(k_bytes);
let warp_sum_offset = ctx.mul_wide_u32_reg(warp_id, four);
let warp_sum_addr = ctx.add_u64(k_bytes_64, warp_sum_offset);
ctx.branch_if_not(is_lane0, "skip_warp_write");
ctx.st_shared_f32(warp_sum_addr, sq_sum);
ctx.label("skip_warp_write");
ctx.bar_sync(0);
let total_sq_sum = ctx.mov_f32_imm(0.0);
ctx.branch_if_not(is_thread0, "skip_final_reduce");
for warp in 0..8u32 {
let warp_offset = ctx.mov_u64_imm((warp * 4) as u64);
let addr = ctx.add_u64(k_bytes_64, warp_offset);
let warp_sum = ctx.ld_shared_f32(addr);
ctx.add_f32_inplace(total_sq_sum, warp_sum);
}
let k_f32 = ctx.cvt_f32_u32(k_dim);
let mean_sq = ctx.div_f32(total_sq_sum, k_f32);
let eps = ctx.mov_f32_imm(epsilon);
let mean_sq_eps = ctx.add_f32(mean_sq, eps);
let rms_inv = ctx.rsqrt_f32(mean_sq_eps);
let rms_inv_offset = ctx.mov_u64_imm(0);
ctx.st_shared_f32(rms_inv_offset, rms_inv);
ctx.label("skip_final_reduce");
ctx.bar_sync(1);
let rms_inv_broadcast_offset = ctx.mov_u64_imm(0);
let rms_inv_val = ctx.ld_shared_f32(rms_inv_broadcast_offset);
let idx2 = ctx.mov_u32_imm(0);
ctx.label("norm_loop");
let loop_idx2 = ctx.add_u32_reg(idx2, thread_id);
let in_bounds2 = ctx.setp_lt_u32(loop_idx2, k_dim);
ctx.branch_if_not(in_bounds2, "norm_loop_end");
let elem_offset2 = ctx.mul_wide_u32_reg(loop_idx2, four);
let x_smem = ctx.ld_shared_f32(elem_offset2);
let gamma_addr = ctx.add_u64(gamma_ptr, elem_offset2);
let gamma = ctx.ld_global_f32(gamma_addr);
let normalized = ctx.mul_f32(x_smem, rms_inv_val);
let scaled = ctx.mul_f32(normalized, gamma);
ctx.st_shared_f32(elem_offset2, scaled);
ctx.add_u32_inplace(idx2, 256);
ctx.branch("norm_loop");
ctx.label("norm_loop_end");
ctx.bar_sync(2);
let acc_gate = ctx.mov_f32_imm(0.0);
let acc_up = ctx.mov_f32_imm(0.0);
let k_rounded = ctx.add_u32(k_dim, Q4K_SUPER_BLOCK_SIZE - 1);
let num_sb = ctx.div_u32(k_rounded, Q4K_SUPER_BLOCK_SIZE);
let sb_bytes = ctx.mov_u32_imm(Q4K_SUPER_BLOCK_BYTES);
let row_bytes = ctx.mul_u32_reg(num_sb, sb_bytes);
let row_offset = ctx.mul_wide_u32_reg(block_id, row_bytes);
let wg_row_base = ctx.add_u64(wg_ptr, row_offset);
let wu_row_base = ctx.add_u64(wu_ptr, row_offset);
let sb_idx = ctx.mov_u32_imm(0);
ctx.label("sb_loop");
let sb_done = ctx.setp_ge_u32(sb_idx, num_sb);
ctx.branch_if(sb_done, "sb_loop_end");
let sb_offset = ctx.mul_wide_u32(sb_idx, Q4K_SUPER_BLOCK_BYTES);
let wg_sb_addr = ctx.add_u64(wg_row_base, sb_offset);
let wu_sb_addr = ctx.add_u64(wu_row_base, sb_offset);
let d_gate_f16 = ctx.ld_global_f16(wg_sb_addr);
let d_gate = ctx.cvt_f32_f16(d_gate_f16);
let two_64 = ctx.mov_u64_imm(2);
let dmin_gate_addr = ctx.add_u64(wg_sb_addr, two_64);
let dmin_gate_f16 = ctx.ld_global_f16(dmin_gate_addr);
let dmin_gate = ctx.cvt_f32_f16(dmin_gate_f16);
let d_up_f16 = ctx.ld_global_f16(wu_sb_addr);
let d_up = ctx.cvt_f32_f16(d_up_f16);
let dmin_up_addr = ctx.add_u64(wu_sb_addr, two_64);
let dmin_up_f16 = ctx.ld_global_f16(dmin_up_addr);
let dmin_up = ctx.cvt_f32_f16(dmin_up_f16);
let eff_scale_gate = d_gate;
let eff_min_gate = dmin_gate;
let eff_scale_up = d_up;
let eff_min_up = dmin_up;
let quant_offset = ctx.mov_u64_imm(16);
let wg_quant_base = ctx.add_u64(wg_sb_addr, quant_offset);
let wu_quant_base = ctx.add_u64(wu_sb_addr, quant_offset);
let mask_4bit = ctx.mov_u32_imm(0x0F);
let eight_shift = ctx.mov_u32_imm(8);
let sixteen_shift = ctx.mov_u32_imm(16);
let twenty_four = ctx.mov_u32_imm(24);
let quant_byte_offset = ctx.mul_wide_u32_reg(lane_id, four);
let wg_quant_addr = ctx.add_u64(wg_quant_base, quant_byte_offset);
let wu_quant_addr = ctx.add_u64(wu_quant_base, quant_byte_offset);
let wg_packed = ctx.ld_global_u32(wg_quant_addr);
let wu_packed = ctx.ld_global_u32(wu_quant_addr);
let sb_base_u32 = ctx.mov_u32_imm(256);
let sb_base = ctx.mul_u32_reg(sb_idx, sb_base_u32);
let eight_const = ctx.mov_u32_imm(8);
// --- Super-block inner body: dequantize 8 Q4_K weight pairs and accumulate ---
// NOTE(review): lane_id, sb_base, sb_idx, wg_packed, wu_packed, mask_4bit,
// eight_const, eight_shift, sixteen_shift, twenty_four, eff_scale_gate/up,
// eff_min_gate/up, acc_gate, acc_up, four, warp_id, is_lane0, is_thread0,
// k_bytes_64 and out_ptr are all defined earlier in this builder, outside
// this excerpt — their exact semantics should be confirmed against that code.
// Each lane owns 8 consecutive input elements: base = sb_base + lane_id * 8.
let lane_base = ctx.mul_u32_reg(lane_id, eight_const);
let input_base_idx = ctx.add_u32_reg(sb_base, lane_base);
// Unpack eight 4-bit quantized values (nibbles) from each packed 32-bit word,
// once for the gate-projection weights (_g) and once for the up-projection
// weights (_u). Nibble i lives at bit offset 4*i.
let nib0_g = ctx.and_u32(wg_packed, mask_4bit);
let nib0_u = ctx.and_u32(wu_packed, mask_4bit);
let shift4 = ctx.mov_u32_imm(4);
let tmp1_g = ctx.shr_u32(wg_packed, shift4);
let nib1_g = ctx.and_u32(tmp1_g, mask_4bit);
let tmp1_u = ctx.shr_u32(wu_packed, shift4);
let nib1_u = ctx.and_u32(tmp1_u, mask_4bit);
let tmp2_g = ctx.shr_u32(wg_packed, eight_shift);
let nib2_g = ctx.and_u32(tmp2_g, mask_4bit);
let tmp2_u = ctx.shr_u32(wu_packed, eight_shift);
let nib2_u = ctx.and_u32(tmp2_u, mask_4bit);
let shift12 = ctx.mov_u32_imm(12);
let tmp3_g = ctx.shr_u32(wg_packed, shift12);
let nib3_g = ctx.and_u32(tmp3_g, mask_4bit);
let tmp3_u = ctx.shr_u32(wu_packed, shift12);
let nib3_u = ctx.and_u32(tmp3_u, mask_4bit);
let tmp4_g = ctx.shr_u32(wg_packed, sixteen_shift);
let nib4_g = ctx.and_u32(tmp4_g, mask_4bit);
let tmp4_u = ctx.shr_u32(wu_packed, sixteen_shift);
let nib4_u = ctx.and_u32(tmp4_u, mask_4bit);
let shift20 = ctx.mov_u32_imm(20);
let tmp5_g = ctx.shr_u32(wg_packed, shift20);
let nib5_g = ctx.and_u32(tmp5_g, mask_4bit);
let tmp5_u = ctx.shr_u32(wu_packed, shift20);
let nib5_u = ctx.and_u32(tmp5_u, mask_4bit);
let tmp6_g = ctx.shr_u32(wg_packed, twenty_four);
let nib6_g = ctx.and_u32(tmp6_g, mask_4bit);
let tmp6_u = ctx.shr_u32(wu_packed, twenty_four);
let nib6_u = ctx.and_u32(tmp6_u, mask_4bit);
let shift28 = ctx.mov_u32_imm(28);
// Top nibble: the shift already discards the low bits, so no mask is needed.
let nib7_g = ctx.shr_u32(wg_packed, shift28);
let nib7_u = ctx.shr_u32(wu_packed, shift28);
// Convert each unsigned nibble to f32 for the dequantization FMAs below.
let nib0_g_f = ctx.cvt_f32_u32(nib0_g);
let nib0_u_f = ctx.cvt_f32_u32(nib0_u);
let nib1_g_f = ctx.cvt_f32_u32(nib1_g);
let nib1_u_f = ctx.cvt_f32_u32(nib1_u);
let nib2_g_f = ctx.cvt_f32_u32(nib2_g);
let nib2_u_f = ctx.cvt_f32_u32(nib2_u);
let nib3_g_f = ctx.cvt_f32_u32(nib3_g);
let nib3_u_f = ctx.cvt_f32_u32(nib3_u);
let nib4_g_f = ctx.cvt_f32_u32(nib4_g);
let nib4_u_f = ctx.cvt_f32_u32(nib4_u);
let nib5_g_f = ctx.cvt_f32_u32(nib5_g);
let nib5_u_f = ctx.cvt_f32_u32(nib5_u);
let nib6_g_f = ctx.cvt_f32_u32(nib6_g);
let nib6_u_f = ctx.cvt_f32_u32(nib6_u);
let nib7_g_f = ctx.cvt_f32_u32(nib7_g);
let nib7_u_f = ctx.cvt_f32_u32(nib7_u);
// Dequantize as w = scale * q - min, implemented as fma(scale, q, -min).
// The min is negated once here so each of the 16 FMAs below can use it
// directly as the addend.
let neg_min_g = ctx.neg_f32(eff_min_gate);
let neg_min_u = ctx.neg_f32(eff_min_up);
let dq0_g = ctx.fma_f32(eff_scale_gate, nib0_g_f, neg_min_g);
let dq0_u = ctx.fma_f32(eff_scale_up, nib0_u_f, neg_min_u);
let dq1_g = ctx.fma_f32(eff_scale_gate, nib1_g_f, neg_min_g);
let dq1_u = ctx.fma_f32(eff_scale_up, nib1_u_f, neg_min_u);
let dq2_g = ctx.fma_f32(eff_scale_gate, nib2_g_f, neg_min_g);
let dq2_u = ctx.fma_f32(eff_scale_up, nib2_u_f, neg_min_u);
let dq3_g = ctx.fma_f32(eff_scale_gate, nib3_g_f, neg_min_g);
let dq3_u = ctx.fma_f32(eff_scale_up, nib3_u_f, neg_min_u);
let dq4_g = ctx.fma_f32(eff_scale_gate, nib4_g_f, neg_min_g);
let dq4_u = ctx.fma_f32(eff_scale_up, nib4_u_f, neg_min_u);
let dq5_g = ctx.fma_f32(eff_scale_gate, nib5_g_f, neg_min_g);
let dq5_u = ctx.fma_f32(eff_scale_up, nib5_u_f, neg_min_u);
let dq6_g = ctx.fma_f32(eff_scale_gate, nib6_g_f, neg_min_g);
let dq6_u = ctx.fma_f32(eff_scale_up, nib6_u_f, neg_min_u);
let dq7_g = ctx.fma_f32(eff_scale_gate, nib7_g_f, neg_min_g);
let dq7_u = ctx.fma_f32(eff_scale_up, nib7_u_f, neg_min_u);
// Multiply-accumulate the 8 dequantized weight pairs against the
// (RMS-normalized) activations cached in shared memory. Element i lives at
// shared byte offset (input_base_idx + i) * 4 (f32 elements).
let zero_imm = ctx.mov_u32_imm(0);
let one_imm = ctx.mov_u32_imm(1);
let two_imm = ctx.mov_u32_imm(2);
let three_imm = ctx.mov_u32_imm(3);
let four_imm = ctx.mov_u32_imm(4);
let five_imm = ctx.mov_u32_imm(5);
let six_imm = ctx.mov_u32_imm(6);
let seven_imm = ctx.mov_u32_imm(7);
let idx0 = ctx.add_u32_reg(input_base_idx, zero_imm);
let off0 = ctx.mul_wide_u32_reg(idx0, four);
let x0 = ctx.ld_shared_f32(off0);
ctx.fma_f32_inplace(acc_gate, dq0_g, x0);
ctx.fma_f32_inplace(acc_up, dq0_u, x0);
let idx1 = ctx.add_u32_reg(input_base_idx, one_imm);
let off1 = ctx.mul_wide_u32_reg(idx1, four);
let x1 = ctx.ld_shared_f32(off1);
ctx.fma_f32_inplace(acc_gate, dq1_g, x1);
ctx.fma_f32_inplace(acc_up, dq1_u, x1);
let idx2 = ctx.add_u32_reg(input_base_idx, two_imm);
let off2 = ctx.mul_wide_u32_reg(idx2, four);
let x2 = ctx.ld_shared_f32(off2);
ctx.fma_f32_inplace(acc_gate, dq2_g, x2);
ctx.fma_f32_inplace(acc_up, dq2_u, x2);
let idx3 = ctx.add_u32_reg(input_base_idx, three_imm);
let off3 = ctx.mul_wide_u32_reg(idx3, four);
let x3 = ctx.ld_shared_f32(off3);
ctx.fma_f32_inplace(acc_gate, dq3_g, x3);
ctx.fma_f32_inplace(acc_up, dq3_u, x3);
let idx4 = ctx.add_u32_reg(input_base_idx, four_imm);
let off4 = ctx.mul_wide_u32_reg(idx4, four);
let x4 = ctx.ld_shared_f32(off4);
ctx.fma_f32_inplace(acc_gate, dq4_g, x4);
ctx.fma_f32_inplace(acc_up, dq4_u, x4);
let idx5 = ctx.add_u32_reg(input_base_idx, five_imm);
let off5 = ctx.mul_wide_u32_reg(idx5, four);
let x5 = ctx.ld_shared_f32(off5);
ctx.fma_f32_inplace(acc_gate, dq5_g, x5);
ctx.fma_f32_inplace(acc_up, dq5_u, x5);
let idx6 = ctx.add_u32_reg(input_base_idx, six_imm);
let off6 = ctx.mul_wide_u32_reg(idx6, four);
let x6 = ctx.ld_shared_f32(off6);
ctx.fma_f32_inplace(acc_gate, dq6_g, x6);
ctx.fma_f32_inplace(acc_up, dq6_u, x6);
let idx7 = ctx.add_u32_reg(input_base_idx, seven_imm);
let off7 = ctx.mul_wide_u32_reg(idx7, four);
let x7 = ctx.ld_shared_f32(off7);
ctx.fma_f32_inplace(acc_gate, dq7_g, x7);
ctx.fma_f32_inplace(acc_up, dq7_u, x7);
// Advance to the next super-block and loop back.
ctx.add_u32_inplace(sb_idx, 1);
ctx.branch("sb_loop");
ctx.label("sb_loop_end");
// Warp-level tree reduction of both accumulators using shuffle-down with
// offsets 16, 8, 4, 2, 1 — lane 0 ends up holding the warp's partial sum.
let shfl16_gate = ctx.shfl_down_f32(acc_gate, 16, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_gate, shfl16_gate);
let shfl8_gate = ctx.shfl_down_f32(acc_gate, 8, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_gate, shfl8_gate);
let shfl4_gate = ctx.shfl_down_f32(acc_gate, 4, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_gate, shfl4_gate);
let shfl2_gate = ctx.shfl_down_f32(acc_gate, 2, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_gate, shfl2_gate);
let shfl1_gate = ctx.shfl_down_f32(acc_gate, 1, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_gate, shfl1_gate);
let shfl16_up = ctx.shfl_down_f32(acc_up, 16, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_up, shfl16_up);
let shfl8_up = ctx.shfl_down_f32(acc_up, 8, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_up, shfl8_up);
let shfl4_up = ctx.shfl_down_f32(acc_up, 4, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_up, shfl4_up);
let shfl2_up = ctx.shfl_down_f32(acc_up, 2, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_up, shfl2_up);
let shfl1_up = ctx.shfl_down_f32(acc_up, 1, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_up, shfl1_up);
// Cross-warp stage: lane 0 of each warp publishes its partial sums to two
// small shared-memory scratch arrays placed after the activation cache —
// gate sums at byte offset k_bytes + 32, up sums at k_bytes + 64
// (one f32 slot per warp, indexed by warp_id).
let thirty_two_64 = ctx.mov_u64_imm(32);
let sixty_four_64 = ctx.mov_u64_imm(64);
let warp_gate_offset = ctx.mul_wide_u32_reg(warp_id, four);
let warp_gate_base = ctx.add_u64(k_bytes_64, thirty_two_64);
let warp_gate_addr = ctx.add_u64(warp_gate_base, warp_gate_offset);
let warp_up_base = ctx.add_u64(k_bytes_64, sixty_four_64);
let warp_up_addr = ctx.add_u64(warp_up_base, warp_gate_offset);
ctx.branch_if_not(is_lane0, "skip_warp_acc_write");
ctx.st_shared_f32(warp_gate_addr, acc_gate);
ctx.st_shared_f32(warp_up_addr, acc_up);
ctx.label("skip_warp_acc_write");
// Barrier 3 makes all warp partial sums visible before the final reduction.
ctx.bar_sync(3);
// Only thread 0 performs the final reduction and writes the output.
ctx.branch_if_not(is_thread0, "exit");
let final_gate = ctx.mov_f32_imm(0.0);
let final_up = ctx.mov_f32_imm(0.0);
// Unrolled at build time: sum the 8 per-warp partials.
// NOTE(review): the hard-coded 8 assumes a 256-thread block (8 warps of 32);
// this matches the +256 stride used by the shared-memory load loop earlier
// in this builder — confirm against the launch configuration.
for warp in 0..8u32 {
let warp_offset = ctx.mov_u64_imm((warp * 4) as u64);
let gate_addr = ctx.add_u64(warp_gate_base, warp_offset);
let up_addr = ctx.add_u64(warp_up_base, warp_offset);
let warp_gate_sum = ctx.ld_shared_f32(gate_addr);
let warp_up_sum = ctx.ld_shared_f32(up_addr);
ctx.add_f32_inplace(final_gate, warp_gate_sum);
ctx.add_f32_inplace(final_up, warp_up_sum);
}
// SwiGLU epilogue: silu(gate) * up, with
// silu(x) = x * sigmoid(x) and sigmoid(x) = 1 / (1 + exp(-x)).
// exp(-x) is computed as ex2(-x * log2(e)), since PTX provides a fast
// base-2 exponential.
let neg_gate = ctx.neg_f32(final_gate);
let log2_e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
let scaled_exp = ctx.mul_f32(neg_gate, log2_e);
let exp_val = ctx.ex2_f32(scaled_exp);
let one_f = ctx.mov_f32_imm(1.0);
let one_plus_exp = ctx.add_f32(one_f, exp_val);
let sigmoid = ctx.rcp_f32(one_plus_exp);
let silu = ctx.mul_f32(final_gate, sigmoid);
let output = ctx.mul_f32(silu, final_up);
// One f32 result per block: out[block_id] = silu(gate_dot) * up_dot.
let out_offset = ctx.mul_wide_u32(block_id, 4);
let out_addr = ctx.add_u64(out_ptr, out_offset);
ctx.st_global_f32(out_addr, output);
ctx.label("exit");
ctx.ret();
})
}
}
#[cfg(test)]
mod tests_3way_fusion {
    use super::*;

    /// The generated PTX must contain the kernel's entry-point name.
    #[test]
    fn test_fused_rmsnorm_gate_up_swiglu_q4k_kernel_builds() {
        let kernel = FusedRmsNormGateUpSwigluQ4KKernel::new(3584, 18944);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains("fused_rmsnorm_gate_up_swiglu_q4k"));
        assert!(ptx.contains(".entry"));
    }

    /// `Kernel::name` must match the PTX entry-point name.
    #[test]
    fn test_fused_rmsnorm_gate_up_swiglu_q4k_kernel_name() {
        let kernel = FusedRmsNormGateUpSwigluQ4KKernel::new(1024, 4096);
        assert_eq!(kernel.name(), "fused_rmsnorm_gate_up_swiglu_q4k");
    }

    /// Cloning preserves every configuration field.
    #[test]
    fn test_fused_rmsnorm_gate_up_swiglu_q4k_kernel_clone() {
        let kernel = FusedRmsNormGateUpSwigluQ4KKernel::new(2048, 8192);
        let cloned = kernel.clone();
        assert_eq!(cloned.k, kernel.k);
        assert_eq!(cloned.n, kernel.n);
        assert_eq!(cloned.epsilon, kernel.epsilon);
    }

    /// `with_epsilon` must override the constructor default.
    #[test]
    fn test_fused_rmsnorm_gate_up_swiglu_q4k_with_epsilon() {
        // Bug fix: the previous test set 1e-5, which is also the default
        // epsilon, so it passed even if `with_epsilon` were a no-op.
        // A distinct value actually exercises the setter.
        let kernel = FusedRmsNormGateUpSwigluQ4KKernel::new(2048, 8192).with_epsilon(1e-6);
        assert_eq!(kernel.epsilon, 1e-6);
    }

    /// Debug output should name the type and show the dimensions.
    #[test]
    fn test_fused_rmsnorm_gate_up_swiglu_q4k_kernel_debug() {
        let kernel = FusedRmsNormGateUpSwigluQ4KKernel::new(2048, 8192);
        let debug = format!("{:?}", kernel);
        assert!(debug.contains("FusedRmsNormGateUpSwigluQ4KKernel"));
        assert!(debug.contains("2048"));
        assert!(debug.contains("8192"));
    }
}