#![allow(clippy::similar_names)]
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Launch configuration for the `squared_sum_reduce` kernel.
///
/// Holds the element count used to size the launch grid; the PTX itself is
/// produced through this type's `Kernel` implementation.
#[derive(Debug, Clone)]
pub struct SquaredSumKernel {
    /// Number of input elements to reduce.
    pub n: u32,
}

impl SquaredSumKernel {
    /// Threads per block reported by [`Self::block_size`].
    const BLOCK_SIZE: u32 = 256;
    /// Elements budgeted to each thread when sizing the grid.
    const ELEMENTS_PER_THREAD: u32 = 4;
    /// Upper bound on the number of blocks returned by [`Self::num_blocks`].
    const MAX_BLOCKS: u32 = 256;

    /// Creates a kernel description for an input of `n` elements.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        Self { n }
    }

    /// Returns the number of threads per block (always 256).
    #[must_use]
    pub const fn block_size(&self) -> u32 {
        Self::BLOCK_SIZE
    }

    /// Returns the number of blocks to launch for `n` elements.
    ///
    /// The result is `ceil(n / (block_size() * 4))`, capped at 256.
    /// It is 0 when `n == 0`.
    /// NOTE(review): confirm callers skip the launch for an empty input.
    #[must_use]
    pub fn num_blocks(&self) -> u32 {
        let elements_per_block = self.block_size() * Self::ELEMENTS_PER_THREAD;
        // Ceiling division built from quotient and remainder. The usual
        // `(n + d - 1) / d` form overflows `u32` when `n` is within `d - 1`
        // of `u32::MAX` (panic in debug builds, a wrapped result of 0 blocks
        // in release builds).
        let full_blocks = self.n / elements_per_block;
        let has_partial_block = self.n % elements_per_block != 0;
        (full_blocks + u32::from(has_partial_block)).min(Self::MAX_BLOCKS)
    }
}
impl Kernel for SquaredSumKernel {
    /// Entry-point name of the generated kernel.
    fn name(&self) -> &str {
        "squared_sum_reduce"
    }

    /// Builds the PTX for a per-block sum of squares.
    ///
    /// Kernel parameters are `input_ptr` (u64), `output_ptr` (u64) and `n`
    /// (u32). Each thread accumulates `x * x` over the elements it visits,
    /// the per-thread sums are combined within each warp via down-shuffles,
    /// one partial per warp goes through shared memory, and thread 0 stores
    /// the block total at `output_ptr + 4 * ctaid.x`.
    ///
    /// The final stage reads exactly eight shared-memory slots, which matches
    /// the 256-thread block size reported by `block_size()` (8 groups of 32).
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("squared_sum_reduce")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U32, "n")
            // Eight 4-byte slots: one partial sum per warp.
            .shared_memory(8 * 4)
            .build(|ctx| {
                // --- Thread / block identification -------------------------
                let thread_idx = ctx.special_reg(PtxReg::TidX);
                let block_idx = ctx.special_reg(PtxReg::CtaIdX);
                let threads_per_block = ctx.special_reg(PtxReg::NtidX);
                let blocks_in_grid = ctx.special_reg(PtxReg::NctaIdX);
                let warp_idx = ctx.div_u32(thread_idx, 32);
                let lane_idx = ctx.rem_u32(thread_idx, 32);

                // --- Kernel parameters -------------------------------------
                let src_base = ctx.load_param_u64("input_ptr");
                let dst_base = ctx.load_param_u64("output_ptr");
                let elem_count = ctx.load_param_u32("n");

                // Loop stride is the total number of launched threads; the
                // starting index is this thread's global id.
                let stride = ctx.mul_lo_u32(blocks_in_grid, threads_per_block);
                let first_index = ctx.mad_lo_u32(block_idx, threads_per_block, thread_idx);

                // --- Constants held in registers ---------------------------
                let bytes_per_f32 = ctx.mov_u32_imm(4);
                let const_zero = ctx.mov_u32_imm(0);
                let warps_per_block = ctx.mov_u32_imm(8);

                // --- Per-thread accumulation: acc += x * x -----------------
                let acc = ctx.mov_f32_imm(0.0);
                let cursor = ctx.mov_u32_imm(0);
                ctx.add_u32_reg_inplace(cursor, first_index);

                ctx.label("acc_loop");
                let has_work = ctx.setp_lt_u32(cursor, elem_count);
                ctx.branch_if_not(has_work, "acc_done");
                let src_offset = ctx.mul_wide_u32_reg(cursor, bytes_per_f32);
                let src_addr = ctx.add_u64(src_base, src_offset);
                let x = ctx.ld_global_f32(src_addr);
                ctx.fma_f32_inplace(acc, x, x);
                ctx.add_u32_reg_inplace(cursor, stride);
                ctx.branch("acc_loop");
                ctx.label("acc_done");

                // --- Intra-warp reduction: halve the shuffle distance ------
                for offset in [16, 8, 4, 2, 1] {
                    let shifted = ctx.shfl_down_f32(acc, offset, 0xFFFF_FFFF);
                    ctx.add_f32_inplace(acc, shifted);
                }

                // --- Lane 0 of each warp publishes its partial -------------
                let is_lane_zero = ctx.setp_eq_u32(lane_idx, const_zero);
                let publish_offset = ctx.mul_u32(warp_idx, 4);
                ctx.branch_if_not(is_lane_zero, "skip_smem_write");
                ctx.st_shared_f32(publish_offset, acc);
                ctx.label("skip_smem_write");

                // Every thread reaches the barrier: the label above sits
                // before it, so non-writers fall through to the same point.
                ctx.bar_sync(0);

                // --- Threads with tid < 32 combine the per-warp partials ---
                let warp_width = ctx.mov_u32_imm(32);
                let in_first_warp = ctx.setp_lt_u32(thread_idx, warp_width);
                ctx.branch_if_not(in_first_warp, "skip_final_reduce");

                // Only lanes 0..8 have a slot to read; the rest contribute 0.
                let has_slot = ctx.setp_lt_u32(lane_idx, warps_per_block);
                let fetch_offset = ctx.mul_u32(lane_idx, 4);
                let slot_value = ctx.mov_f32_imm(0.0);
                ctx.branch_if_not(has_slot, "skip_warp_load");
                let fetched = ctx.ld_shared_f32(fetch_offset);
                ctx.add_f32_inplace(slot_value, fetched);
                ctx.label("skip_warp_load");

                // Eight values need three shuffle steps (4, 2, 1); each step
                // produces a fresh register via `add_f32`.
                let mut block_total = slot_value;
                for offset in [4, 2, 1] {
                    let shifted = ctx.shfl_down_f32(block_total, offset, 0xFFFF_FFFF);
                    block_total = ctx.add_f32(block_total, shifted);
                }

                // --- Thread 0 stores the block's result --------------------
                let is_thread_zero = ctx.setp_eq_u32(thread_idx, const_zero);
                ctx.branch_if_not(is_thread_zero, "skip_final_reduce");
                let dst_offset = ctx.mul_wide_u32_reg(block_idx, bytes_per_f32);
                let dst_addr = ctx.add_u64(dst_base, dst_offset);
                ctx.st_global_f32(dst_addr, block_total);

                ctx.label("skip_final_reduce");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The kernel reports the entry-point name used in the generated PTX.
    #[test]
    fn test_squared_sum_kernel_name() {
        assert_eq!(SquaredSumKernel::new(1024).name(), "squared_sum_reduce");
    }

    /// Grid sizing: one block per 1024 elements (rounded up), capped at 256.
    #[test]
    fn test_squared_sum_num_blocks() {
        let cases = [(1024, 1), (100_000, 98), (32_000_000, 256)];
        for (n, expected_blocks) in cases {
            assert_eq!(SquaredSumKernel::new(n).num_blocks(), expected_blocks);
        }
    }

    /// The emitted PTX contains the entry point, all parameters, and the
    /// instructions the reduction relies on.
    #[test]
    fn test_squared_sum_ptx_generation() {
        let ptx = SquaredSumKernel::new(1024).emit_ptx();
        let required_fragments = [
            ".entry squared_sum_reduce",
            ".param .u64 input_ptr",
            ".param .u64 output_ptr",
            ".param .u32 n",
            "fma.rn.f32",
            "shfl.sync.down",
            ".shared",
        ];
        for fragment in required_fragments {
            assert!(
                ptx.contains(fragment),
                "generated PTX is missing `{}`",
                fragment
            );
        }
    }

    /// Emitting validated PTX must complete without panicking.
    #[test]
    fn test_squared_sum_barrier_safety() {
        let kernel = SquaredSumKernel::new(1024);
        let _validated = kernel.emit_ptx_validated();
    }
}