#![allow(clippy::similar_names)]
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Reduction kernel that folds per-block partial sums of squared gradients
/// into a final L2 norm and a clipping scale factor.
///
/// Reads `total_n` f32 partial sums from `partials_ptr`, then writes
/// `scale = min(max_norm / norm, 1.0)` to `output_ptr[0]` and the norm
/// itself to `output_ptr[1]` (see `build_ptx` below).
#[derive(Debug, Clone)]
pub struct ClipScaleReduceKernel;
impl Kernel for ClipScaleReduceKernel {
    fn name(&self) -> &str {
        "clip_scale_reduce"
    }

    /// Builds a single-thread PTX kernel that:
    /// 1. serially sums `total_n` f32 partial squared-norms from `partials_ptr`,
    /// 2. takes the square root to obtain the gradient L2 norm,
    /// 3. computes `scale = min(max_norm / norm, 1.0)`,
    /// 4. stores `scale` at `output_ptr[0]` and `norm` at `output_ptr[1]`.
    ///
    /// NOTE(review): when every partial is zero, `norm` is 0 and the division
    /// yields +inf (or NaN if `max_norm` is also 0); `min` against 1.0 then
    /// selects 1.0, i.e. "no clipping" — confirm this degenerate-case
    /// behavior is intended.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("clip_scale_reduce")
            .param(PtxType::U64, "partials_ptr")
            .param(PtxType::U32, "total_n")
            .param(PtxType::F32, "max_norm")
            .param(PtxType::U64, "output_ptr")
            .build(|ctx| {
                let partials_ptr = ctx.load_param_u64("partials_ptr");
                let total_n = ctx.load_param_u32("total_n");
                let max_norm = ctx.load_param_f32("max_norm");
                let output_ptr = ctx.load_param_u64("output_ptr");

                // Loop-invariant constants, hoisted above the loop: the f32
                // element stride (4 bytes) and the index increment. The
                // original emitted the `mov` for the increment inside the
                // loop body, re-executing it every iteration for no effect.
                let four = ctx.mov_u32_imm(4);
                let one = ctx.mov_u32_imm(1);
                let total_sq = ctx.mov_f32_imm(0.0);
                let i = ctx.mov_u32_imm(0);

                // Serial accumulation: total_sq += partials[i] for i in 0..total_n.
                ctx.label("sum_loop");
                let in_bounds = ctx.setp_lt_u32(i, total_n);
                ctx.branch_if_not(in_bounds, "sum_done");
                let byte_offset = ctx.mul_wide_u32_reg(i, four);
                let addr = ctx.add_u64(partials_ptr, byte_offset);
                let val = ctx.ld_global_f32(addr);
                ctx.add_f32_inplace(total_sq, val);
                ctx.add_u32_reg_inplace(i, one);
                ctx.branch("sum_loop");
                ctx.label("sum_done");

                // norm = sqrt(sum of squares); clamp the scale so gradients
                // are only ever shrunk, never amplified.
                let norm = ctx.sqrt_f32(total_sq);
                let raw_scale = ctx.div_f32(max_norm, norm);
                let one_f32 = ctx.mov_f32_imm(1.0);
                let scale = ctx.min_f32(raw_scale, one_f32);

                // output[0] = scale, output[1] = norm (4-byte f32 stride).
                ctx.st_global_f32(output_ptr, scale);
                let four_u64 = ctx.mov_u64_imm(4);
                let norm_addr = ctx.add_u64(output_ptr, four_u64);
                ctx.st_global_f32(norm_addr, norm);
                ctx.ret();
            })
    }
}
/// Element-wise kernel that multiplies each gradient in place by a clip
/// scale read from device memory, skipping the multiply when the scale is
/// within 1e-7 of 1.0 (see `build_ptx` below).
#[derive(Debug, Clone)]
pub struct GradientClipGpuScaleKernel {
    // Element count. NOTE(review): the generated PTX takes its length from
    // the `n` kernel *parameter* at launch time and `build_ptx` never reads
    // this field — confirm it is consumed by launch-side code.
    pub n: u32,
}
impl GradientClipGpuScaleKernel {
#[must_use]
pub const fn new(n: u32) -> Self {
Self { n }
}
}
impl Kernel for GradientClipGpuScaleKernel {
    fn name(&self) -> &str {
        "gradient_clip_gpu_scale"
    }

    /// Builds a one-thread-per-element PTX kernel that scales gradients in
    /// place: `grads[gid] *= *scale_ptr` for every `gid < n`.
    ///
    /// The scale arrives via a device pointer (`scale_ptr`) rather than an
    /// f32 kernel parameter, so it can be produced on-device (e.g. by
    /// `clip_scale_reduce`) without a device-to-host round trip. When the
    /// loaded scale is within 1e-7 of 1.0 the store is skipped and the
    /// gradient is left untouched.
    ///
    /// NOTE(review): the struct's `n` field is not referenced here — the
    /// element count comes from the `n` kernel parameter at launch.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("gradient_clip_gpu_scale")
            .param(PtxType::U64, "grads_ptr")
            .param(PtxType::U64, "scale_ptr")
            .param(PtxType::U32, "n")
            .build(|ctx| {
                // Global thread index: gid = ctaid.x * ntid.x + tid.x.
                let tid = ctx.special_reg(PtxReg::TidX);
                let ctaid = ctx.special_reg(PtxReg::CtaIdX);
                let ntid = ctx.special_reg(PtxReg::NtidX);
                let gid = ctx.mad_lo_u32(ctaid, ntid, tid);

                // Bounds check: threads past the end exit immediately.
                let n = ctx.load_param_u32("n");
                let in_bounds = ctx.setp_lt_u32(gid, n);
                ctx.branch_if_not(in_bounds, "exit");

                // Early-out when |scale - 1.0| < 1e-7: multiplying by ~1
                // would be a no-op, so skip the read-modify-write entirely.
                let scale_ptr = ctx.load_param_u64("scale_ptr");
                let scale = ctx.ld_global_f32(scale_ptr);
                let one = ctx.mov_f32_imm(1.0);
                let diff = ctx.sub_f32(scale, one);
                let abs_diff = ctx.abs_f32(diff);
                let threshold = ctx.mov_f32_imm(1e-7);
                let no_clip = ctx.setp_lt_f32(abs_diff, threshold);
                ctx.branch_if(no_clip, "exit");

                // In-place scale: grads[gid] *= scale (4-byte f32 addressing).
                let grads_ptr = ctx.load_param_u64("grads_ptr");
                let four = ctx.mov_u32_imm(4);
                let offset = ctx.mul_wide_u32_reg(gid, four);
                let grad_addr = ctx.add_u64(grads_ptr, offset);
                let grad = ctx.ld_global_f32(grad_addr);
                let grad_new = ctx.mul_f32(grad, scale);
                ctx.st_global_f32(grad_addr, grad_new);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless every expected fragment occurs in the generated PTX.
    fn assert_ptx_has(ptx: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(ptx.contains(fragment), "PTX missing `{}`", fragment);
        }
    }

    #[test]
    fn test_clip_scale_reduce_kernel_name() {
        assert_eq!(ClipScaleReduceKernel.name(), "clip_scale_reduce");
    }

    #[test]
    fn test_clip_scale_reduce_ptx_generation() {
        let ptx = ClipScaleReduceKernel.emit_ptx();
        assert_ptx_has(
            &ptx,
            &[
                ".entry clip_scale_reduce",
                ".param .u64 partials_ptr",
                ".param .u32 total_n",
                ".param .f32 max_norm",
                ".param .u64 output_ptr",
                "sqrt",
                "min",
                "div",
            ],
        );
    }

    #[test]
    fn test_gradient_clip_gpu_scale_kernel_name() {
        let kernel = GradientClipGpuScaleKernel::new(1024);
        assert_eq!(kernel.name(), "gradient_clip_gpu_scale");
    }

    #[test]
    fn test_gradient_clip_gpu_scale_ptx_generation() {
        let ptx = GradientClipGpuScaleKernel::new(1024).emit_ptx();
        assert_ptx_has(
            &ptx,
            &[
                ".entry gradient_clip_gpu_scale",
                ".param .u64 grads_ptr",
                ".param .u64 scale_ptr",
                ".param .u32 n",
                "mul.f32",
            ],
        );
        // The scale must arrive via a device pointer, never as an f32 param.
        assert!(!ptx.contains(".param .f32 scale"));
    }

    #[test]
    fn test_clip_scale_reduce_barrier_safety() {
        let _ptx = ClipScaleReduceKernel.emit_ptx_validated();
    }

    #[test]
    fn test_gradient_clip_gpu_scale_barrier_safety() {
        let _ptx = GradientClipGpuScaleKernel::new(1024).emit_ptx_validated();
    }
}