#![allow(clippy::similar_names)]
use super::BatchedGemmKernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxType};
impl BatchedGemmKernel {
/// Builds the naive batched GEMM kernel in PTX: for every batch `b`,
/// `C[b] = A[b] * B[b]` over f32 matrices stored row-major and densely
/// packed back-to-back per batch (element size 4 bytes throughout the
/// address arithmetic below).
///
/// Thread mapping: one thread per output element —
/// `row = ctaid.y * ntid.y + tid.y`, `col = ctaid.x * ntid.x + tid.x`,
/// and `ctaid.z` selects the batch. Threads outside `(batch, m, n)`
/// branch straight to `exit`.
///
/// The inner loop walks `k` scalar FMAs per thread (no tiling or shared
/// memory — this is the naive reference variant). The result is stored
/// unconditionally, overwriting C (no accumulate/beta term).
pub(super) fn build_naive(&self) -> PtxKernel {
// Compile-time dimensions baked into the address arithmetic below.
// The runtime `m`/`n`/`k` params are used only for the bounds checks
// and the loop trip count.
// NOTE(review): if a launcher ever passes m/n/k values that differ
// from self.config, the baked-in offsets and the runtime guards will
// disagree — confirm the launcher always forwards the config dims.
let m_val = self.config.m;
let n_val = self.config.n;
let k_val = self.config.k;
PtxKernel::new("batched_gemm_naive")
.param(PtxType::U64, "a_ptr")
.param(PtxType::U64, "b_ptr")
.param(PtxType::U64, "c_ptr")
.param(PtxType::U32, "batch")
.param(PtxType::U32, "m")
.param(PtxType::U32, "n")
.param(PtxType::U32, "k")
.build(|ctx| {
// Note: each ctx call emits a PTX instruction in order, so the
// statement sequence here defines the emitted program.
let batch_idx = ctx.special_reg(crate::ptx::PtxReg::CtaIdZ);
let ctaid_y = ctx.special_reg(crate::ptx::PtxReg::CtaIdY);
let ntid_y = ctx.special_reg(crate::ptx::PtxReg::NtidY);
let tid_y = ctx.special_reg(crate::ptx::PtxReg::TidY);
let ctaid_x = ctx.special_reg(crate::ptx::PtxReg::CtaIdX);
let ntid_x = ctx.special_reg(crate::ptx::PtxReg::NtidX);
let tid_x = ctx.special_reg(crate::ptx::PtxReg::TidX);
// Global output coordinates: y-dimension indexes rows, x-dimension
// indexes columns.
let row = ctx.mad_lo_u32(ctaid_y, ntid_y, tid_y);
let col = ctx.mad_lo_u32(ctaid_x, ntid_x, tid_x);
let batch_param = ctx.load_param_u32("batch");
let m_param = ctx.load_param_u32("m");
let n_param = ctx.load_param_u32("n");
let k_param = ctx.load_param_u32("k");
// Guard against out-of-range threads; each failed check jumps to
// the shared `exit` label at the bottom.
let pred_batch = ctx.setp_ge_u32(batch_idx, batch_param);
ctx.branch_if(pred_batch, "exit");
let pred_m = ctx.setp_ge_u32(row, m_param);
ctx.branch_if(pred_m, "exit");
let pred_n = ctx.setp_ge_u32(col, n_param);
ctx.branch_if(pred_n, "exit");
let a_ptr = ctx.load_param_u64("a_ptr");
let b_ptr = ctx.load_param_u64("b_ptr");
let c_ptr = ctx.load_param_u64("c_ptr");
// Per-batch base pointers: matrices are packed contiguously, so
// batch `b` starts at b * (rows * cols * 4) bytes.
// NOTE(review): the multiplier operand (e.g. m_val * k_val * 4) is
// computed in u32 on the host before mul.wide — confirm it cannot
// overflow u32 for the largest supported config dims.
let a_batch_offset = ctx.mul_wide_u32(batch_idx, m_val * k_val * 4);
let b_batch_offset = ctx.mul_wide_u32(batch_idx, k_val * n_val * 4);
let c_batch_offset = ctx.mul_wide_u32(batch_idx, m_val * n_val * 4);
let a_batch_ptr = ctx.add_u64(a_ptr, a_batch_offset);
let b_batch_ptr = ctx.add_u64(b_ptr, b_batch_offset);
let c_batch_ptr = ctx.add_u64(c_ptr, c_batch_offset);
// Scalar f32 accumulator for this thread's dot product.
let acc = ctx.mov_f32_imm(0.0);
// A is row-major with k-element rows: &A[row][0] = base + row*k*4.
let row_offset = ctx.mul_wide_u32(row, k_val * 4);
let a_row_ptr = ctx.add_u64(a_batch_ptr, row_offset);
// B is row-major with n-element rows: column base = base + col*4;
// the i-th row element is reached by adding i*n*4 inside the loop.
let col_offset = ctx.mul_wide_u32(col, 4);
let b_col_base = ctx.add_u64(b_batch_ptr, col_offset);
// Dot-product loop: for i in 0..k { acc += A[row][i] * B[i][col] }.
let i = ctx.mov_u32_imm(0);
ctx.label("loop_k");
let pred_k = ctx.setp_ge_u32(i, k_param);
ctx.branch_if(pred_k, "loop_end");
let i_offset = ctx.mul_wide_u32(i, 4);
let a_addr = ctx.add_u64(a_row_ptr, i_offset);
let a_val = ctx.ld_global_f32(a_addr);
// Stride down B's column: row i of B lives n*4 bytes further on.
let b_row_offset = ctx.mul_wide_u32(i, n_val * 4);
let b_addr = ctx.add_u64(b_col_base, b_row_offset);
let b_val = ctx.ld_global_f32(b_addr);
// acc = fma(a_val, b_val, acc), updating `acc` in place.
ctx.fma_f32_inplace(acc, a_val, b_val);
ctx.add_u32_inplace(i, 1);
ctx.branch("loop_k");
ctx.label("loop_end");
// C is row-major with n-element rows: &C[row][col]
// = base + row*n*4 + col*4. Store overwrites (no accumulation).
let c_row_offset = ctx.mul_wide_u32(row, n_val * 4);
let c_row_ptr = ctx.add_u64(c_batch_ptr, c_row_offset);
let c_col_offset = ctx.mul_wide_u32(col, 4);
let c_addr = ctx.add_u64(c_row_ptr, c_col_offset);
ctx.st_global_f32(c_addr, acc);
// Shared exit for both out-of-range threads and the normal path.
ctx.label("exit");
ctx.ret();
})
}
}