#![allow(clippy::similar_names)]
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory};
use crate::ptx::{PtxKernel, PtxType};
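/// Dimensions for a batched GEMM: `batch` independent products `C = A * B`,
/// with row-major f32 matrices `A` (`m x k`), `B` (`k x n`), and `C`
/// (`m x n`). `tile_size` is the shared-memory tile edge used by the tiled
/// variants.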
#[derive(Debug, Clone)]
pub struct BatchedGemmConfig {
pub batch: u32,
pub m: u32,
pub n: u32,
pub k: u32,
pub tile_size: u32,
}
impl Default for BatchedGemmConfig {
fn default() -> Self {
Self {
batch: 1,
m: 1024,
n: 1024,
k: 1024,
tile_size: 16,
}
}
}
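/// Batched GEMM (`C[b] = A[b] * B[b]` for each `b`) lowered to PTX in one of
/// four variants; the batch index is carried on `ctaid.z`.
///
/// A minimal sketch of generating the PTX source (compiling and launching it
/// is up to the caller; `emit_ptx` is the `Kernel` trait helper the tests
/// below use):
///
/// ```ignore
/// let kernel = BatchedGemmKernel::tiled(8, 256, 256, 256, 16);
/// let ptx = kernel.emit_ptx();
/// assert!(ptx.contains(".entry batched_gemm_tiled"));
/// ```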
#[derive(Debug, Clone)]
pub struct BatchedGemmKernel {
pub config: BatchedGemmConfig,
variant: BatchedGemmVariant,
}
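/// Lowering strategy, selected by the constructor used.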
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BatchedGemmVariant {
Naive,
Tiled,
TiledUnrolled,
WmmaFp16,
}
impl BatchedGemmKernel {
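/// One thread per output element; each thread walks the full k dimension in
/// global memory. `tile_size` is left at its default and unused by this
/// variant.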
#[must_use]
pub fn naive(batch: u32, m: u32, n: u32, k: u32) -> Self {
Self {
config: BatchedGemmConfig {
batch,
m,
n,
k,
..Default::default()
},
variant: BatchedGemmVariant::Naive,
}
}
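/// Shared-memory tiled variant. The launch is expected to use
/// `tile_size` x `tile_size` thread blocks, one per output tile.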
#[must_use]
pub fn tiled(batch: u32, m: u32, n: u32, k: u32, tile_size: u32) -> Self {
Self {
config: BatchedGemmConfig {
batch,
m,
n,
k,
tile_size,
},
variant: BatchedGemmVariant::Tiled,
}
}
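/// Like `tiled`, with the per-tile inner product unrolled 4x. `tile_size`
/// should be a multiple of 4, since remainder iterations are not emitted.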
#[must_use]
pub fn tiled_unrolled(batch: u32, m: u32, n: u32, k: u32, tile_size: u32) -> Self {
Self {
config: BatchedGemmConfig {
batch,
m,
n,
k,
tile_size,
},
variant: BatchedGemmVariant::TiledUnrolled,
}
}
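/// Tensor-core variant on 16x16x16 `wmma` fragments: f32 inputs are
/// converted to f16 in shared memory and accumulated in f32. The tile size
/// is fixed at 16 by the fragment shape.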
#[must_use]
pub fn wmma_fp16(batch: u32, m: u32, n: u32, k: u32) -> Self {
Self {
config: BatchedGemmConfig {
batch,
m,
n,
k,
tile_size: 16,
},
variant: BatchedGemmVariant::WmmaFp16,
}
}
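/// Naive kernel: batch on `ctaid.z`, output row on the y dimension, output
/// column on x; one global-memory dot product per thread.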
fn build_naive(&self) -> PtxKernel {
let m_val = self.config.m;
let n_val = self.config.n;
let k_val = self.config.k;
PtxKernel::new("batched_gemm_naive")
.param(PtxType::U64, "a_ptr")
.param(PtxType::U64, "b_ptr")
.param(PtxType::U64, "c_ptr")
.param(PtxType::U32, "batch")
.param(PtxType::U32, "m")
.param(PtxType::U32, "n")
.param(PtxType::U32, "k")
.build(|ctx| {
let batch_idx = ctx.special_reg(crate::ptx::PtxReg::CtaIdZ);
let ctaid_y = ctx.special_reg(crate::ptx::PtxReg::CtaIdY);
let ntid_y = ctx.special_reg(crate::ptx::PtxReg::NtidY);
let tid_y = ctx.special_reg(crate::ptx::PtxReg::TidY);
let ctaid_x = ctx.special_reg(crate::ptx::PtxReg::CtaIdX);
let ntid_x = ctx.special_reg(crate::ptx::PtxReg::NtidX);
let tid_x = ctx.special_reg(crate::ptx::PtxReg::TidX);
let row = ctx.mad_lo_u32(ctaid_y, ntid_y, tid_y);
let col = ctx.mad_lo_u32(ctaid_x, ntid_x, tid_x);
let batch_param = ctx.load_param_u32("batch");
let m_param = ctx.load_param_u32("m");
let n_param = ctx.load_param_u32("n");
let k_param = ctx.load_param_u32("k");
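// Bounds guards: threads outside the (batch, m, n) range jump straight to exit.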
let pred_batch = ctx.setp_ge_u32(batch_idx, batch_param);
ctx.branch_if(pred_batch, "exit");
let pred_m = ctx.setp_ge_u32(row, m_param);
ctx.branch_if(pred_m, "exit");
let pred_n = ctx.setp_ge_u32(col, n_param);
ctx.branch_if(pred_n, "exit");
let a_ptr = ctx.load_param_u64("a_ptr");
let b_ptr = ctx.load_param_u64("b_ptr");
let c_ptr = ctx.load_param_u64("c_ptr");
let a_batch_offset = ctx.mul_wide_u32(batch_idx, m_val * k_val * 4);
let b_batch_offset = ctx.mul_wide_u32(batch_idx, k_val * n_val * 4);
let c_batch_offset = ctx.mul_wide_u32(batch_idx, m_val * n_val * 4);
let a_batch_ptr = ctx.add_u64(a_ptr, a_batch_offset);
let b_batch_ptr = ctx.add_u64(b_ptr, b_batch_offset);
let c_batch_ptr = ctx.add_u64(c_ptr, c_batch_offset);
let acc = ctx.mov_f32_imm(0.0);
let row_offset = ctx.mul_wide_u32(row, k_val * 4);
let a_row_ptr = ctx.add_u64(a_batch_ptr, row_offset);
let col_offset = ctx.mul_wide_u32(col, 4);
let b_col_base = ctx.add_u64(b_batch_ptr, col_offset);
let i = ctx.mov_u32_imm(0);
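// k-loop: acc += A[row][i] * B[i][col] for i in 0..k.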
ctx.label("loop_k");
let pred_k = ctx.setp_ge_u32(i, k_param);
ctx.branch_if(pred_k, "loop_end");
let i_offset = ctx.mul_wide_u32(i, 4);
let a_addr = ctx.add_u64(a_row_ptr, i_offset);
let a_val = ctx.ld_global_f32(a_addr);
let b_row_offset = ctx.mul_wide_u32(i, n_val * 4);
let b_addr = ctx.add_u64(b_col_base, b_row_offset);
let b_val = ctx.ld_global_f32(b_addr);
ctx.fma_f32_inplace(acc, a_val, b_val);
ctx.add_u32_inplace(i, 1);
ctx.branch("loop_k");
ctx.label("loop_end");
let c_row_offset = ctx.mul_wide_u32(row, n_val * 4);
let c_row_ptr = ctx.add_u64(c_batch_ptr, c_row_offset);
let c_col_offset = ctx.mul_wide_u32(col, 4);
let c_addr = ctx.add_u64(c_row_ptr, c_col_offset);
ctx.st_global_f32(c_addr, acc);
ctx.label("exit");
ctx.ret();
})
}
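/// Tiled kernel: per tile iteration the block stages a
/// `tile_size` x `tile_size` tile of A and of B into shared memory,
/// synchronizes, and accumulates the partial products from shared memory.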
fn build_tiled(&self) -> PtxKernel {
let tile_size = self.config.tile_size;
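// Shared memory holds two f32 tiles: A at offset 0, B right after it.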
let smem_size = tile_size * tile_size * 4 * 2;
let n_tiles = self.config.k.div_ceil(tile_size);
let m_val = self.config.m;
let n_val = self.config.n;
let k_val = self.config.k;
PtxKernel::new("batched_gemm_tiled")
.param(PtxType::U64, "a_ptr")
.param(PtxType::U64, "b_ptr")
.param(PtxType::U64, "c_ptr")
.param(PtxType::U32, "batch")
.param(PtxType::U32, "m")
.param(PtxType::U32, "n")
.param(PtxType::U32, "k")
.shared_memory(smem_size as usize)
.build(|ctx| {
let batch_idx = ctx.special_reg(crate::ptx::PtxReg::CtaIdZ);
let tid_x = ctx.special_reg(crate::ptx::PtxReg::TidX);
let tid_y = ctx.special_reg(crate::ptx::PtxReg::TidY);
let ctaid_x = ctx.special_reg(crate::ptx::PtxReg::CtaIdX);
let ctaid_y = ctx.special_reg(crate::ptx::PtxReg::CtaIdY);
let tile_size_reg = ctx.mov_u32_imm(tile_size);
let row = ctx.mad_lo_u32(ctaid_y, tile_size_reg, tid_y);
let col = ctx.mad_lo_u32(ctaid_x, tile_size_reg, tid_x);
let batch_param = ctx.load_param_u32("batch");
let m_param = ctx.load_param_u32("m");
let n_param = ctx.load_param_u32("n");
let k_param = ctx.load_param_u32("k");
let batch_valid = ctx.setp_lt_u32(batch_idx, batch_param);
let row_valid = ctx.setp_lt_u32(row, m_param);
let col_valid = ctx.setp_lt_u32(col, n_param);
let a_ptr = ctx.load_param_u64("a_ptr");
let b_ptr = ctx.load_param_u64("b_ptr");
let c_ptr = ctx.load_param_u64("c_ptr");
let a_batch_offset = ctx.mul_wide_u32(batch_idx, m_val * k_val * 4);
let b_batch_offset = ctx.mul_wide_u32(batch_idx, k_val * n_val * 4);
let c_batch_offset = ctx.mul_wide_u32(batch_idx, m_val * n_val * 4);
let a_batch_ptr = ctx.add_u64(a_ptr, a_batch_offset);
let b_batch_ptr = ctx.add_u64(b_ptr, b_batch_offset);
let c_batch_ptr = ctx.add_u64(c_ptr, c_batch_offset);
let acc = ctx.mov_f32_imm(0.0);
let tile_idx = ctx.mov_u32_imm(0);
let n_tiles_reg = ctx.mov_u32_imm(n_tiles);
ctx.label("tile_loop");
let tile_done = ctx.setp_ge_u32(tile_idx, n_tiles_reg);
ctx.branch_if(tile_done, "tile_loop_end");
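// Cooperative staging: each thread writes one element of the A tile and one
// of the B tile, zero-filling out-of-range elements so the inner loop needs
// no bounds checks.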
let smem_idx = ctx.mad_lo_u32(tid_y, tile_size_reg, tid_x);
let smem_a_offset = ctx.mul_u32(smem_idx, 4);
let smem_b_base = ctx.mov_u32_imm(tile_size * tile_size * 4);
let smem_b_offset = ctx.add_u32_reg(smem_b_base, smem_a_offset);
let tile_k_offset = ctx.mul_u32(tile_idx, tile_size);
let a_col = ctx.add_u32_reg(tile_k_offset, tid_x);
let a_col_valid = ctx.setp_lt_u32(a_col, k_param);
let zero_a = ctx.mov_f32_imm(0.0);
ctx.st_shared_f32(smem_a_offset, zero_a);
ctx.branch_if_not(batch_valid, "skip_a_load");
ctx.branch_if_not(row_valid, "skip_a_load");
ctx.branch_if_not(a_col_valid, "skip_a_load");
let row_offset_a = ctx.mul_wide_u32(row, k_val * 4);
let col_offset_a = ctx.mul_wide_u32(a_col, 4);
let a_row_base = ctx.add_u64(a_batch_ptr, row_offset_a);
let a_addr = ctx.add_u64(a_row_base, col_offset_a);
let a_val = ctx.ld_global_f32(a_addr);
ctx.st_shared_f32(smem_a_offset, a_val);
ctx.label("skip_a_load");
let b_row = ctx.add_u32_reg(tile_k_offset, tid_y);
let b_row_valid = ctx.setp_lt_u32(b_row, k_param);
let zero_b = ctx.mov_f32_imm(0.0);
ctx.st_shared_f32(smem_b_offset, zero_b);
ctx.branch_if_not(batch_valid, "skip_b_load");
ctx.branch_if_not(b_row_valid, "skip_b_load");
ctx.branch_if_not(col_valid, "skip_b_load");
let row_offset_b = ctx.mul_wide_u32(b_row, n_val * 4);
let col_offset_b = ctx.mul_wide_u32(col, 4);
let b_row_base = ctx.add_u64(b_batch_ptr, row_offset_b);
let b_addr = ctx.add_u64(b_row_base, col_offset_b);
let b_val = ctx.ld_global_f32(b_addr);
ctx.st_shared_f32(smem_b_offset, b_val);
ctx.label("skip_b_load");
ctx.bar_sync(0);
let inner_k = ctx.mov_u32_imm(0);
ctx.label("inner_k_loop");
let inner_done = ctx.setp_ge_u32(inner_k, tile_size_reg);
ctx.branch_if(inner_done, "inner_k_end");
let as_idx = ctx.mad_lo_u32(tid_y, tile_size_reg, inner_k);
let as_addr = ctx.mul_u32(as_idx, 4);
let a_shared = ctx.ld_shared_f32(as_addr);
let bs_idx = ctx.mad_lo_u32(inner_k, tile_size_reg, tid_x);
let bs_idx_bytes = ctx.mul_u32(bs_idx, 4);
let bs_addr = ctx.add_u32_reg(smem_b_base, bs_idx_bytes);
let b_shared = ctx.ld_shared_f32(bs_addr);
ctx.fma_f32_inplace(acc, a_shared, b_shared);
ctx.add_u32_inplace(inner_k, 1);
ctx.branch("inner_k_loop");
ctx.label("inner_k_end");
ctx.bar_sync(1);
ctx.add_u32_inplace(tile_idx, 1);
ctx.branch("tile_loop");
ctx.label("tile_loop_end");
ctx.branch_if_not(batch_valid, "exit");
ctx.branch_if_not(row_valid, "exit");
ctx.branch_if_not(col_valid, "exit");
let c_row_offset = ctx.mul_wide_u32(row, n_val * 4);
let c_col_offset = ctx.mul_wide_u32(col, 4);
let c_row_base = ctx.add_u64(c_batch_ptr, c_row_offset);
let c_addr = ctx.add_u64(c_row_base, c_col_offset);
ctx.st_global_f32(c_addr, acc);
ctx.label("exit");
ctx.ret();
})
}
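/// Same staging scheme as `build_tiled`, with the inner product over each
/// tile unrolled by a factor of 4 (assumes `tile_size % 4 == 0`).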
#[allow(clippy::too_many_lines)]
fn build_tiled_unrolled(&self) -> PtxKernel {
let tile_size = self.config.tile_size;
let smem_size = tile_size * tile_size * 4 * 2;
let n_tiles = self.config.k.div_ceil(tile_size);
let m_val = self.config.m;
let n_val = self.config.n;
let k_val = self.config.k;
let batch_stride_a = m_val * k_val;
let batch_stride_b = k_val * n_val;
let batch_stride_c = m_val * n_val;
let unroll_factor = 4u32;
let unrolled_iters = tile_size / unroll_factor;
PtxKernel::new("batched_gemm_tiled_unrolled")
.param(PtxType::U64, "a_ptr")
.param(PtxType::U64, "b_ptr")
.param(PtxType::U64, "c_ptr")
.param(PtxType::U32, "batch")
.param(PtxType::U32, "m")
.param(PtxType::U32, "n")
.param(PtxType::U32, "k")
.shared_memory(smem_size as usize)
.build(|ctx| {
let batch_idx = ctx.special_reg(crate::ptx::PtxReg::CtaIdZ);
let tid_x = ctx.special_reg(crate::ptx::PtxReg::TidX);
let tid_y = ctx.special_reg(crate::ptx::PtxReg::TidY);
let ctaid_x = ctx.special_reg(crate::ptx::PtxReg::CtaIdX);
let ctaid_y = ctx.special_reg(crate::ptx::PtxReg::CtaIdY);
let tile_size_reg = ctx.mov_u32_imm(tile_size);
let row = ctx.mad_lo_u32(ctaid_y, tile_size_reg, tid_y);
let col = ctx.mad_lo_u32(ctaid_x, tile_size_reg, tid_x);
let batch_param = ctx.load_param_u32("batch");
let m_param = ctx.load_param_u32("m");
let n_param = ctx.load_param_u32("n");
let k_param = ctx.load_param_u32("k");
let batch_valid = ctx.setp_lt_u32(batch_idx, batch_param);
let row_valid = ctx.setp_lt_u32(row, m_param);
let col_valid = ctx.setp_lt_u32(col, n_param);
let a_ptr = ctx.load_param_u64("a_ptr");
let b_ptr = ctx.load_param_u64("b_ptr");
let c_ptr = ctx.load_param_u64("c_ptr");
let batch_offset_a = ctx.mul_wide_u32(batch_idx, batch_stride_a * 4);
let batch_offset_b = ctx.mul_wide_u32(batch_idx, batch_stride_b * 4);
let batch_offset_c = ctx.mul_wide_u32(batch_idx, batch_stride_c * 4);
let a_batch_ptr = ctx.add_u64(a_ptr, batch_offset_a);
let b_batch_ptr = ctx.add_u64(b_ptr, batch_offset_b);
let c_batch_ptr = ctx.add_u64(c_ptr, batch_offset_c);
let acc = ctx.mov_f32_imm(0.0);
let tile_idx = ctx.mov_u32_imm(0);
let n_tiles_reg = ctx.mov_u32_imm(n_tiles);
ctx.label("tile_loop");
let tile_done = ctx.setp_ge_u32(tile_idx, n_tiles_reg);
ctx.branch_if(tile_done, "tile_loop_end");
let smem_idx = ctx.mad_lo_u32(tid_y, tile_size_reg, tid_x);
let smem_a_offset = ctx.mul_u32(smem_idx, 4);
let smem_b_base = ctx.mov_u32_imm(tile_size * tile_size * 4);
let smem_b_offset = ctx.add_u32_reg(smem_b_base, smem_a_offset);
let tile_k_offset = ctx.mul_u32(tile_idx, tile_size);
let a_col = ctx.add_u32_reg(tile_k_offset, tid_x);
let a_col_valid = ctx.setp_lt_u32(a_col, k_param);
let zero_a = ctx.mov_f32_imm(0.0);
ctx.st_shared_f32(smem_a_offset, zero_a);
ctx.branch_if_not(batch_valid, "skip_a_load");
ctx.branch_if_not(row_valid, "skip_a_load");
ctx.branch_if_not(a_col_valid, "skip_a_load");
let row_offset_a = ctx.mul_wide_u32(row, k_val * 4);
let col_offset_a = ctx.mul_wide_u32(a_col, 4);
let a_row_base = ctx.add_u64(a_batch_ptr, row_offset_a);
let a_addr = ctx.add_u64(a_row_base, col_offset_a);
let a_val = ctx.ld_global_f32(a_addr);
ctx.st_shared_f32(smem_a_offset, a_val);
ctx.label("skip_a_load");
let b_row = ctx.add_u32_reg(tile_k_offset, tid_y);
let b_row_valid = ctx.setp_lt_u32(b_row, k_param);
let zero_b = ctx.mov_f32_imm(0.0);
ctx.st_shared_f32(smem_b_offset, zero_b);
ctx.branch_if_not(batch_valid, "skip_b_load");
ctx.branch_if_not(b_row_valid, "skip_b_load");
ctx.branch_if_not(col_valid, "skip_b_load");
let row_offset_b = ctx.mul_wide_u32(b_row, n_val * 4);
let col_offset_b = ctx.mul_wide_u32(col, 4);
let b_row_base = ctx.add_u64(b_batch_ptr, row_offset_b);
let b_addr = ctx.add_u64(b_row_base, col_offset_b);
let b_val = ctx.ld_global_f32(b_addr);
ctx.st_shared_f32(smem_b_offset, b_val);
ctx.label("skip_b_load");
ctx.bar_sync(0);
let inner_k = ctx.mov_u32_imm(0);
let unrolled_iters_reg = ctx.mov_u32_imm(unrolled_iters);
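// Inner product over the tile, unrolled 4x: four fused multiply-adds per trip.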
ctx.label("inner_k_loop");
let inner_done = ctx.setp_ge_u32(inner_k, unrolled_iters_reg);
ctx.branch_if(inner_done, "inner_k_end");
let k_base = ctx.mul_u32(inner_k, unroll_factor);
let k0 = k_base;
let as_idx0 = ctx.mad_lo_u32(tid_y, tile_size_reg, k0);
let as_addr0 = ctx.mul_u32(as_idx0, 4);
let a_shared0 = ctx.ld_shared_f32(as_addr0);
let bs_idx0 = ctx.mad_lo_u32(k0, tile_size_reg, tid_x);
let bs_idx_bytes0 = ctx.mul_u32(bs_idx0, 4);
let bs_addr0 = ctx.add_u32_reg(smem_b_base, bs_idx_bytes0);
let b_shared0 = ctx.ld_shared_f32(bs_addr0);
ctx.fma_f32_inplace(acc, a_shared0, b_shared0);
let k1 = ctx.add_u32(k_base, 1);
let as_idx1 = ctx.mad_lo_u32(tid_y, tile_size_reg, k1);
let as_addr1 = ctx.mul_u32(as_idx1, 4);
let a_shared1 = ctx.ld_shared_f32(as_addr1);
let bs_idx1 = ctx.mad_lo_u32(k1, tile_size_reg, tid_x);
let bs_idx_bytes1 = ctx.mul_u32(bs_idx1, 4);
let bs_addr1 = ctx.add_u32_reg(smem_b_base, bs_idx_bytes1);
let b_shared1 = ctx.ld_shared_f32(bs_addr1);
ctx.fma_f32_inplace(acc, a_shared1, b_shared1);
let k2 = ctx.add_u32(k_base, 2);
let as_idx2 = ctx.mad_lo_u32(tid_y, tile_size_reg, k2);
let as_addr2 = ctx.mul_u32(as_idx2, 4);
let a_shared2 = ctx.ld_shared_f32(as_addr2);
let bs_idx2 = ctx.mad_lo_u32(k2, tile_size_reg, tid_x);
let bs_idx_bytes2 = ctx.mul_u32(bs_idx2, 4);
let bs_addr2 = ctx.add_u32_reg(smem_b_base, bs_idx_bytes2);
let b_shared2 = ctx.ld_shared_f32(bs_addr2);
ctx.fma_f32_inplace(acc, a_shared2, b_shared2);
let k3 = ctx.add_u32(k_base, 3);
let as_idx3 = ctx.mad_lo_u32(tid_y, tile_size_reg, k3);
let as_addr3 = ctx.mul_u32(as_idx3, 4);
let a_shared3 = ctx.ld_shared_f32(as_addr3);
let bs_idx3 = ctx.mad_lo_u32(k3, tile_size_reg, tid_x);
let bs_idx_bytes3 = ctx.mul_u32(bs_idx3, 4);
let bs_addr3 = ctx.add_u32_reg(smem_b_base, bs_idx_bytes3);
let b_shared3 = ctx.ld_shared_f32(bs_addr3);
ctx.fma_f32_inplace(acc, a_shared3, b_shared3);
ctx.add_u32_inplace(inner_k, 1);
ctx.branch("inner_k_loop");
ctx.label("inner_k_end");
ctx.bar_sync(1);
ctx.add_u32_inplace(tile_idx, 1);
ctx.branch("tile_loop");
ctx.label("tile_loop_end");
ctx.branch_if_not(batch_valid, "exit");
ctx.branch_if_not(row_valid, "exit");
ctx.branch_if_not(col_valid, "exit");
let c_row_offset = ctx.mul_wide_u32(row, n_val * 4);
let c_col_offset = ctx.mul_wide_u32(col, 4);
let c_row_base = ctx.add_u64(c_batch_ptr, c_row_offset);
let c_addr = ctx.add_u64(c_row_base, c_col_offset);
ctx.st_global_f32(c_addr, acc);
ctx.label("exit");
ctx.ret();
})
}
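/// WMMA kernel: one 16x16 output tile per block (one warp). Tiles of A and B
/// are converted f32 -> f16 into shared memory, multiplied with one
/// `wmma.mma` per k tile into an f32 accumulator fragment, and the result is
/// stored with `wmma.store`.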
#[allow(clippy::too_many_lines)]
fn build_wmma_fp16(&self) -> PtxKernel {
use crate::ptx::WmmaLayout;
let tile_size = 16_u32;
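// Shared memory holds two 16x16 f16 tiles (2 bytes per element): A, then B.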
let smem_size = tile_size * tile_size * 2 * 2;
let n_k_tiles = self.config.k.div_ceil(tile_size);
let m_val = self.config.m;
let n_val = self.config.n;
let k_val = self.config.k;
PtxKernel::new("batched_gemm_wmma_fp16")
.param(PtxType::U64, "a_ptr")
.param(PtxType::U64, "b_ptr")
.param(PtxType::U64, "c_ptr")
.param(PtxType::U32, "batch")
.param(PtxType::U32, "m")
.param(PtxType::U32, "n")
.param(PtxType::U32, "k")
.shared_memory(smem_size as usize)
.build(|ctx| {
let tid_x = ctx.special_reg(crate::ptx::PtxReg::TidX);
let ctaid_x = ctx.special_reg(crate::ptx::PtxReg::CtaIdX);
let ctaid_y = ctx.special_reg(crate::ptx::PtxReg::CtaIdY);
let batch_idx = ctx.special_reg(crate::ptx::PtxReg::CtaIdZ);
let tile_size_reg = ctx.mov_u32_imm(tile_size);
let tile_row = ctx.mul_u32(ctaid_y, tile_size);
let tile_col = ctx.mul_u32(ctaid_x, tile_size);
let batch_param = ctx.load_param_u32("batch");
let m_param = ctx.load_param_u32("m");
let n_param = ctx.load_param_u32("n");
let k_param = ctx.load_param_u32("k");
let batch_valid = ctx.setp_lt_u32(batch_idx, batch_param);
let tile_row_valid = ctx.setp_lt_u32(tile_row, m_param);
let tile_col_valid = ctx.setp_lt_u32(tile_col, n_param);
let a_ptr = ctx.load_param_u64("a_ptr");
let b_ptr = ctx.load_param_u64("b_ptr");
let c_ptr = ctx.load_param_u64("c_ptr");
let a_batch_offset = ctx.mul_wide_u32(batch_idx, m_val * k_val * 4);
let b_batch_offset = ctx.mul_wide_u32(batch_idx, k_val * n_val * 4);
let c_batch_offset = ctx.mul_wide_u32(batch_idx, m_val * n_val * 4);
let a_batch_ptr = ctx.add_u64(a_ptr, a_batch_offset);
let b_batch_ptr = ctx.add_u64(b_ptr, b_batch_offset);
let c_batch_ptr = ctx.add_u64(c_ptr, c_batch_offset);
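// Byte offsets of the A and B tiles within shared memory.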
let smem_a_base = ctx.mov_u32_imm(0);
let smem_b_base = ctx.mov_u32_imm(tile_size * tile_size * 2);
let frag_c = ctx.wmma_init_c_zero();
let k_tile_idx = ctx.mov_u32_imm(0);
let n_k_tiles_reg = ctx.mov_u32_imm(n_k_tiles);
ctx.label("k_tile_loop");
let k_done = ctx.setp_ge_u32(k_tile_idx, n_k_tiles_reg);
ctx.branch_if(k_done, "k_tile_end");
let k_offset = ctx.mul_u32_reg(k_tile_idx, tile_size_reg);
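// Stage the tile cooperatively; with a one-warp (32-thread) block, 8 elements
// per thread covers all 256 tile elements.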
let elements_per_thread = ctx.mov_u32_imm(8);
let my_start = ctx.mul_u32_reg(tid_x, elements_per_thread);
let load_idx = ctx.mov_u32_imm(0);
ctx.label("load_a_loop_batched");
let load_done = ctx.setp_ge_u32(load_idx, elements_per_thread);
ctx.branch_if(load_done, "load_a_end_batched");
let elem_idx = ctx.add_u32_reg(my_start, load_idx);
let row_in_tile = ctx.div_u32(elem_idx, 16);
let col_in_tile = ctx.rem_u32(elem_idx, 16);
let smem_a_offset = ctx.mul_u32(elem_idx, 2);
let smem_a_addr = ctx.add_u32_reg(smem_a_base, smem_a_offset);
let zero_f32 = ctx.mov_f32_imm(0.0);
let zero_f16 = ctx.cvt_f16_f32(zero_f32);
ctx.st_shared_f16(smem_a_addr, zero_f16);
let a_row = ctx.add_u32_reg(tile_row, row_in_tile);
let a_col = ctx.add_u32_reg(k_offset, col_in_tile);
let a_row_valid = ctx.setp_lt_u32(a_row, m_param);
let a_col_valid = ctx.setp_lt_u32(a_col, k_param);
ctx.branch_if_not(a_row_valid, "skip_a_load_batched");
ctx.branch_if_not(a_col_valid, "skip_a_load_batched");
ctx.branch_if_not(batch_valid, "skip_a_load_batched");
let k_reg = ctx.mov_u32_imm(k_val);
let a_idx = ctx.mad_lo_u32(a_row, k_reg, a_col);
let a_byte_offset = ctx.mul_wide_u32(a_idx, 4);
let a_addr = ctx.add_u64(a_batch_ptr, a_byte_offset);
let a_val_f32 = ctx.ld_global_f32(a_addr);
let a_val_f16 = ctx.cvt_f16_f32(a_val_f32);
ctx.st_shared_f16(smem_a_addr, a_val_f16);
ctx.label("skip_a_load_batched");
ctx.add_u32_inplace(load_idx, 1);
ctx.branch("load_a_loop_batched");
ctx.label("load_a_end_batched");
let load_idx_b = ctx.mov_u32_imm(0);
ctx.label("load_b_loop_batched");
let load_b_done = ctx.setp_ge_u32(load_idx_b, elements_per_thread);
ctx.branch_if(load_b_done, "load_b_end_batched");
let elem_idx_b = ctx.add_u32_reg(my_start, load_idx_b);
let row_in_tile_b = ctx.div_u32(elem_idx_b, 16);
let col_in_tile_b = ctx.rem_u32(elem_idx_b, 16);
let smem_b_offset = ctx.mul_u32(elem_idx_b, 2);
let smem_b_addr = ctx.add_u32_reg(smem_b_base, smem_b_offset);
let zero_b_f32 = ctx.mov_f32_imm(0.0);
let zero_b_f16 = ctx.cvt_f16_f32(zero_b_f32);
ctx.st_shared_f16(smem_b_addr, zero_b_f16);
let b_row = ctx.add_u32_reg(k_offset, row_in_tile_b);
let b_col = ctx.add_u32_reg(tile_col, col_in_tile_b);
let b_row_valid = ctx.setp_lt_u32(b_row, k_param);
let b_col_valid = ctx.setp_lt_u32(b_col, n_param);
ctx.branch_if_not(b_row_valid, "skip_b_load_batched");
ctx.branch_if_not(b_col_valid, "skip_b_load_batched");
ctx.branch_if_not(batch_valid, "skip_b_load_batched");
let n_reg = ctx.mov_u32_imm(n_val);
let b_idx = ctx.mad_lo_u32(b_row, n_reg, b_col);
let b_byte_offset = ctx.mul_wide_u32(b_idx, 4);
let b_addr = ctx.add_u64(b_batch_ptr, b_byte_offset);
let b_val_f32 = ctx.ld_global_f32(b_addr);
let b_val_f16 = ctx.cvt_f16_f32(b_val_f32);
ctx.st_shared_f16(smem_b_addr, b_val_f16);
ctx.label("skip_b_load_batched");
ctx.add_u32_inplace(load_idx_b, 1);
ctx.branch("load_b_loop_batched");
ctx.label("load_b_end_batched");
ctx.bar_sync(0);
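// Load the A/B fragments from shared memory and issue the tensor-core MMA.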
let smem_generic_base = ctx.shared_base_addr();
let frag_a = ctx.wmma_load_a_f16(smem_generic_base, 16, WmmaLayout::RowMajor);
let smem_b_offset_u64 = ctx.cvt_u64_u32(smem_b_base);
let smem_b_ptr = ctx.add_u64(smem_generic_base, smem_b_offset_u64);
let frag_b = ctx.wmma_load_b_f16(smem_b_ptr, 16, WmmaLayout::RowMajor);
let frag_d = ctx.wmma_mma_f16_f32(&frag_a, &frag_b, &frag_c);
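// Accumulate across k tiles: copy D back into the C fragment registers.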
for (c_reg, d_reg) in frag_c.iter().zip(frag_d.iter()) {
ctx.mov_f32_reg(*c_reg, *d_reg);
}
ctx.bar_sync(1);
ctx.add_u32_inplace(k_tile_idx, 1);
ctx.branch("k_tile_loop");
ctx.label("k_tile_end");
ctx.branch_if_not(batch_valid, "exit_batched");
ctx.branch_if_not(tile_row_valid, "exit_batched");
ctx.branch_if_not(tile_col_valid, "exit_batched");
let c_tile_row_offset = ctx.mul_wide_u32(tile_row, n_val * 4);
let c_tile_col_offset = ctx.mul_wide_u32(tile_col, 4);
let c_tile_base = ctx.add_u64(c_batch_ptr, c_tile_row_offset);
let c_tile_addr = ctx.add_u64(c_tile_base, c_tile_col_offset);
ctx.wmma_store_d_f32(c_tile_addr, &frag_c, n_val, WmmaLayout::RowMajor);
ctx.label("exit_batched");
ctx.ret();
})
}
}
impl Kernel for BatchedGemmKernel {
fn name(&self) -> &str {
match self.variant {
BatchedGemmVariant::Naive => "batched_gemm_naive",
BatchedGemmVariant::Tiled => "batched_gemm_tiled",
BatchedGemmVariant::TiledUnrolled => "batched_gemm_tiled_unrolled",
BatchedGemmVariant::WmmaFp16 => "batched_gemm_wmma_fp16",
}
}
fn build_ptx(&self) -> PtxKernel {
match self.variant {
BatchedGemmVariant::Naive => self.build_naive(),
BatchedGemmVariant::Tiled => self.build_tiled(),
BatchedGemmVariant::TiledUnrolled => self.build_tiled_unrolled(),
BatchedGemmVariant::WmmaFp16 => self.build_wmma_fp16(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_batched_gemm_config_default() {
let config = BatchedGemmConfig::default();
assert_eq!(config.batch, 1);
assert_eq!(config.m, 1024);
assert_eq!(config.n, 1024);
assert_eq!(config.k, 1024);
assert_eq!(config.tile_size, 16);
}
#[test]
fn test_batched_gemm_naive_ptx_gen() {
let kernel = BatchedGemmKernel::naive(4, 32, 32, 32);
let ptx = kernel.emit_ptx();
assert!(ptx.contains(".entry batched_gemm_naive"));
assert!(ptx.contains(".param .u64 a_ptr"));
assert!(ptx.contains(".param .u64 b_ptr"));
assert!(ptx.contains(".param .u64 c_ptr"));
assert!(ptx.contains(".param .u32 batch"));
assert!(ptx.contains(".param .u32 m"));
assert!(ptx.contains(".param .u32 n"));
assert!(ptx.contains(".param .u32 k"));
}
#[test]
fn test_batched_gemm_tiled_ptx_gen() {
let kernel = BatchedGemmKernel::tiled(4, 64, 64, 64, 16);
let ptx = kernel.emit_ptx();
assert!(ptx.contains(".entry batched_gemm_tiled"));
assert!(ptx.contains(".shared")); assert!(ptx.contains("bar.sync")); }
#[test]
fn test_batched_gemm_tiled_unrolled_ptx_gen() {
let kernel = BatchedGemmKernel::tiled_unrolled(4, 64, 64, 64, 16);
let ptx = kernel.emit_ptx();
assert!(ptx.contains(".entry batched_gemm_tiled_unrolled"));
assert!(ptx.contains(".shared"));
assert!(ptx.contains("fma"));
}
#[test]
fn test_batched_gemm_wmma_fp16_ptx_gen() {
let kernel = BatchedGemmKernel::wmma_fp16(4, 64, 64, 64);
let ptx = kernel.emit_ptx();
assert!(ptx.contains(".entry batched_gemm_wmma_fp16"));
assert!(ptx.contains("wmma.load")); assert!(ptx.contains("wmma.mma")); assert!(ptx.contains("wmma.store")); }
#[test]
fn test_batched_gemm_kernel_names() {
assert_eq!(
BatchedGemmKernel::naive(1, 32, 32, 32).name(),
"batched_gemm_naive"
);
assert_eq!(
BatchedGemmKernel::tiled(1, 32, 32, 32, 16).name(),
"batched_gemm_tiled"
);
assert_eq!(
BatchedGemmKernel::tiled_unrolled(1, 32, 32, 32, 16).name(),
"batched_gemm_tiled_unrolled"
);
assert_eq!(
BatchedGemmKernel::wmma_fp16(1, 32, 32, 32).name(),
"batched_gemm_wmma_fp16"
);
}
#[test]
fn test_batched_gemm_config_clone() {
let config = BatchedGemmConfig {
batch: 8,
m: 256,
n: 128,
k: 64,
tile_size: 32,
};
let cloned = config.clone();
assert_eq!(cloned.batch, 8);
assert_eq!(cloned.m, 256);
assert_eq!(cloned.n, 128);
assert_eq!(cloned.k, 64);
assert_eq!(cloned.tile_size, 32);
}
#[test]
fn test_batched_gemm_kernel_clone() {
let kernel = BatchedGemmKernel::naive(2, 16, 16, 16);
let cloned = kernel.clone();
assert_eq!(cloned.name(), "batched_gemm_naive");
assert_eq!(cloned.config.batch, 2);
}
#[test]
fn test_batched_gemm_debug_format() {
let config = BatchedGemmConfig::default();
let debug = format!("{config:?}");
assert!(debug.contains("BatchedGemmConfig"));
let kernel = BatchedGemmKernel::tiled(4, 64, 64, 64, 16);
let kernel_debug = format!("{kernel:?}");
assert!(kernel_debug.contains("BatchedGemmKernel"));
}
#[test]
fn test_batched_gemm_small_dimensions() {
let kernel = BatchedGemmKernel::naive(1, 1, 1, 1);
let ptx = kernel.emit_ptx();
assert!(ptx.contains(".entry batched_gemm_naive"));
}
#[test]
fn test_batched_gemm_large_batch() {
let kernel = BatchedGemmKernel::naive(128, 16, 16, 16);
let ptx = kernel.emit_ptx();
assert!(ptx.contains(".entry batched_gemm_naive"));
}
#[test]
fn test_batched_gemm_non_square_dims() {
let kernel = BatchedGemmKernel::tiled(4, 128, 64, 32, 16);
let ptx = kernel.emit_ptx();
assert!(ptx.contains(".entry batched_gemm_tiled"));
assert_eq!(kernel.config.m, 128);
assert_eq!(kernel.config.n, 64);
assert_eq!(kernel.config.k, 32);
}
}