#![allow(clippy::similar_names)]
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Configuration for the `batched_scale` kernel, which multiplies each element of an
/// `f32` buffer by a scalar.
#[derive(Debug, Clone)]
pub struct BatchedScaleKernel {
    /// Number of `f32` elements processed; embedded into the generated PTX as the
    /// bounds-check limit.
    pub n: u32,
}

impl BatchedScaleKernel {
    /// Creates a scale-kernel configuration covering `n` elements.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        BatchedScaleKernel { n }
    }
}
impl Kernel for BatchedScaleKernel {
    fn name(&self) -> &str {
        "batched_scale"
    }

    /// Emits an element-wise kernel computing `output[i] = input[i] * scale` over
    /// `self.n` `f32` values.
    ///
    /// Each thread handles one element at global index `ctaid.x * ntid.x + tid.x`.
    /// Threads whose index is not below `self.n` branch straight to `exit`.
    ///
    /// NOTE(review): the element count is embedded as an immediate taken from `self.n`.
    /// The declared `n` kernel parameter is not read by this body.
    fn build_ptx(&self) -> PtxKernel {
        let element_count = self.n;
        PtxKernel::new("batched_scale")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::F32, "scale")
            .param(PtxType::U32, "n")
            .build(move |ctx| {
                // Global thread index: block * block_dim + lane.
                let lane = ctx.special_reg(PtxReg::TidX);
                let block = ctx.special_reg(PtxReg::CtaIdX);
                let block_dim = ctx.special_reg(PtxReg::NtidX);
                let idx = ctx.mad_lo_u32(block, block_dim, lane);

                // Bounds guard: out-of-range threads do nothing.
                let limit = ctx.mov_u32_imm(element_count);
                let inside = ctx.setp_lt_u32(idx, limit);
                ctx.branch_if_not(inside, "exit");

                let src_base = ctx.load_param_u64("input_ptr");
                let dst_base = ctx.load_param_u64("output_ptr");
                let factor = ctx.load_param_f32("scale");

                // Byte offset = idx * 4 (size of an f32), widened to 64 bits for addressing.
                let elem_bytes = ctx.mov_u32_imm(4);
                let byte_offset = ctx.mul_wide_u32_reg(idx, elem_bytes);
                let src = ctx.add_u64(src_base, byte_offset);
                let dst = ctx.add_u64(dst_base, byte_offset);

                let x = ctx.ld_global_f32(src);
                let y = ctx.mul_f32(x, factor);
                ctx.st_global_f32(dst, y);

                ctx.label("exit");
                ctx.ret();
            })
    }
}
/// Configuration for the `batched_softmax` kernel, which applies a softmax
/// independently to each row of a row-major `f32` matrix.
#[derive(Debug, Clone)]
pub struct BatchedSoftmaxKernel {
    /// Number of rows. Blocks whose `ctaid.x` is not below this value exit immediately.
    pub total_rows: u32,
    /// Number of `f32` elements in each row.
    pub row_size: u32,
}

impl BatchedSoftmaxKernel {
    /// Creates a softmax configuration for `total_rows` rows of `row_size` elements each.
    #[must_use]
    pub const fn new(total_rows: u32, row_size: u32) -> Self {
        BatchedSoftmaxKernel {
            total_rows,
            row_size,
        }
    }
}
impl Kernel for BatchedSoftmaxKernel {
    fn name(&self) -> &str {
        "batched_softmax"
    }

    /// Emits a row-wise softmax kernel in which each CTA (indexed by `ctaid.x`) processes
    /// one row of `row_size` `f32` values.
    ///
    /// Each thread starts at element `tid.x` and strides by 32 through the row in three
    /// passes:
    /// 1. A per-thread running max. These are combined by a `shfl.down` tree reduction and
    ///    broadcast from lane 0 with `shfl.idx`.
    /// 2. A per-thread sum of `exp(x - max)`, reduced and broadcast the same way.
    /// 3. A write of `exp(x - max) / sum` to the output row.
    ///
    /// `exp` is evaluated as `ex2((x - max) * log2(e))`.
    ///
    /// NOTE(review): the 32-element stride and the single-warp shuffle reductions look
    /// like they assume a launch with exactly 32 threads per block. Confirm with the
    /// launcher.
    ///
    /// NOTE(review): `total_rows` and `row_size` are embedded as immediates from `self`.
    /// The same-named kernel parameters are declared but not read by this body.
    ///
    /// NOTE(review): 72 bytes of shared memory are requested, but this body issues no
    /// shared-memory operations. Confirm whether the request is still needed.
    ///
    /// # Panics
    ///
    /// Panics if `row_size * 4` (the row stride in bytes) does not fit in a `u32`.
    fn build_ptx(&self) -> PtxKernel {
        let total_rows = self.total_rows;
        let row_size = self.row_size;
        // Byte stride between consecutive rows. Checked so that an oversized row fails
        // loudly here instead of silently wrapping (in release builds) into a wrong
        // stride baked into the PTX.
        let row_stride_bytes = row_size
            .checked_mul(4)
            .expect("batched_softmax: row_size * 4 overflows u32 (row byte stride)");
        PtxKernel::new("batched_softmax")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U32, "total_rows")
            .param(PtxType::U32, "row_size")
            .shared_memory(72)
            .build(move |ctx| {
                let row_idx = ctx.special_reg(PtxReg::CtaIdX);
                let tid = ctx.special_reg(PtxReg::TidX);

                // Blocks beyond the last row have nothing to do.
                let total_rows_reg = ctx.mov_u32_imm(total_rows);
                let valid = ctx.setp_lt_u32(row_idx, total_rows_reg);
                ctx.branch_if_not(valid, "exit");

                let input_ptr = ctx.load_param_u64("input_ptr");
                let output_ptr = ctx.load_param_u64("output_ptr");
                let row_size_reg = ctx.mov_u32_imm(row_size);
                let row_offset = ctx.mul_wide_u32(row_idx, row_stride_bytes);
                let row_input_ptr = ctx.add_u64(input_ptr, row_offset);
                let row_output_ptr = ctx.add_u64(output_ptr, row_offset);
                let four = ctx.mov_u32_imm(4);
                let log2e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);

                // Pass 1: per-thread max over elements tid, tid+32, ...
                let local_max = ctx.mov_f32_imm(f32::NEG_INFINITY);
                let i_max = ctx.mov_u32_imm(0);
                ctx.add_u32_reg_inplace(i_max, tid);
                ctx.label("max_loop");
                let max_done = ctx.setp_ge_u32(i_max, row_size_reg);
                ctx.branch_if(max_done, "max_done");
                let offset = ctx.mul_wide_u32_reg(i_max, four);
                let addr = ctx.add_u64(row_input_ptr, offset);
                let val = ctx.ld_global_f32(addr);
                ctx.max_f32_inplace(local_max, val);
                ctx.add_u32_inplace(i_max, 32);
                ctx.branch("max_loop");
                ctx.label("max_done");
                // Tree reduction down to lane 0, then broadcast lane 0's value to all lanes.
                for lane_offset in [16, 8, 4, 2, 1] {
                    let other = ctx.shfl_down_f32(local_max, lane_offset, 0xFFFF_FFFF);
                    ctx.max_f32_inplace(local_max, other);
                }
                let row_max = ctx.shfl_idx_f32(local_max, 0, 0xFFFF_FFFF);

                // Pass 2: per-thread sum of exp(x - max), computed as ex2((x - max) * log2e).
                let local_sum = ctx.mov_f32_imm(0.0);
                let i_sum = ctx.mov_u32_imm(0);
                ctx.add_u32_reg_inplace(i_sum, tid);
                ctx.label("sum_loop");
                let sum_done = ctx.setp_ge_u32(i_sum, row_size_reg);
                ctx.branch_if(sum_done, "sum_done");
                let offset = ctx.mul_wide_u32_reg(i_sum, four);
                let addr = ctx.add_u64(row_input_ptr, offset);
                let val = ctx.ld_global_f32(addr);
                let diff = ctx.sub_f32(val, row_max);
                let exp_arg = ctx.mul_f32(diff, log2e);
                let exp_val = ctx.ex2_f32(exp_arg);
                ctx.add_f32_inplace(local_sum, exp_val);
                ctx.add_u32_inplace(i_sum, 32);
                ctx.branch("sum_loop");
                ctx.label("sum_done");
                for lane_offset in [16, 8, 4, 2, 1] {
                    let other = ctx.shfl_down_f32(local_sum, lane_offset, 0xFFFF_FFFF);
                    ctx.add_f32_inplace(local_sum, other);
                }
                let row_sum = ctx.shfl_idx_f32(local_sum, 0, 0xFFFF_FFFF);

                // Pass 3: recompute exp(x - max) and write the normalized value.
                let i_write = ctx.mov_u32_imm(0);
                ctx.add_u32_reg_inplace(i_write, tid);
                ctx.label("write_loop");
                let write_done = ctx.setp_ge_u32(i_write, row_size_reg);
                ctx.branch_if(write_done, "exit");
                let offset = ctx.mul_wide_u32_reg(i_write, four);
                let in_addr = ctx.add_u64(row_input_ptr, offset);
                let out_addr = ctx.add_u64(row_output_ptr, offset);
                let val = ctx.ld_global_f32(in_addr);
                let diff = ctx.sub_f32(val, row_max);
                let exp_arg = ctx.mul_f32(diff, log2e);
                let exp_val = ctx.ex2_f32(exp_arg);
                let normalized = ctx.div_f32(exp_val, row_sum);
                ctx.st_global_f32(out_addr, normalized);
                ctx.add_u32_inplace(i_write, 32);
                ctx.branch("write_loop");

                ctx.label("exit");
                ctx.ret();
            })
    }
}