use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Descriptor for the `gradient_clip` PTX kernel, which multiplies a buffer of
/// `f32` values in place by a scalar factor.
#[derive(Clone, Debug)]
pub struct GradientClipKernel {
    /// Value supplied to [`GradientClipKernel::new`].
    ///
    /// NOTE(review): PTX generation in this file does not read this field; the
    /// generated kernel takes its bound from a launch parameter also named `n`.
    /// Presumably this is the element count the launcher passes for that
    /// parameter — confirm against the call site.
    pub n: u32,
}
impl GradientClipKernel {
#[must_use]
pub const fn new(n: u32) -> Self {
Self { n }
}
}
impl Kernel for GradientClipKernel {
    /// Returns the kernel's name, `"gradient_clip"` (the same string passed to
    /// `PtxKernel::new` in `build_ptx`).
    fn name(&self) -> &str {
        "gradient_clip"
    }

    /// Builds the PTX kernel that scales one `f32` element per thread in place.
    ///
    /// Kernel parameters, in declaration order:
    /// - `grads_ptr` (`u64`): base address the per-thread load/store is offset from.
    /// - `scale` (`f32`): factor the loaded value is multiplied by.
    /// - `n` (`u32`): bound; a thread whose global index is not below it
    ///   branches to the `exit` label without touching memory.
    ///
    /// `self.n` is not consulted here — the bound comes solely from the `n`
    /// kernel parameter.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("gradient_clip")
            .param(PtxType::U64, "grads_ptr")
            .param(PtxType::F32, "scale")
            .param(PtxType::U32, "n")
            .build(|b| {
                // Global thread index from the x-dimension special registers
                // (block index, block size, thread index within the block).
                let thread_in_block = b.special_reg(PtxReg::TidX);
                let block_idx = b.special_reg(PtxReg::CtaIdX);
                let block_dim = b.special_reg(PtxReg::NtidX);
                let global_idx = b.mad_lo_u32(block_idx, block_dim, thread_in_block);

                // Threads at or beyond the bound skip straight to the exit label.
                let bound = b.load_param_u32("n");
                let is_live = b.setp_lt_u32(global_idx, bound);
                b.branch_if_not(is_live, "exit");

                let base = b.load_param_u64("grads_ptr");
                let factor = b.load_param_f32("scale");

                // Element address: base + index * 4 (one f32 per index),
                // with the product widened before the 64-bit add.
                let elem_bytes = b.mov_u32_imm(4);
                let byte_off = b.mul_wide_u32_reg(global_idx, elem_bytes);
                let elem_addr = b.add_u64(base, byte_off);

                // Load, scale, and write back to the same address.
                let loaded = b.ld_global_f32(elem_addr);
                let scaled = b.mul_f32(loaded, factor);
                b.st_global_f32(elem_addr, scaled);

                b.label("exit");
                b.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `name()` returns the literal `"gradient_clip"`.
    #[test]
    fn test_gradient_clip_kernel_name() {
        let subject = GradientClipKernel::new(1024);
        let reported = subject.name();
        assert_eq!(reported, "gradient_clip");
    }

    /// The emitted PTX text contains the entry declaration, the pointer and
    /// scale parameter declarations, and an `f32` multiply.
    #[test]
    fn test_gradient_clip_ptx_generation() {
        let emitted = GradientClipKernel::new(1024).emit_ptx();
        let required = [
            ".entry gradient_clip",
            ".param .u64 grads_ptr",
            ".param .f32 scale",
            "mul.f32",
        ];
        for &fragment in &required {
            assert!(
                emitted.contains(fragment),
                "missing PTX fragment: {}",
                fragment
            );
        }
    }
}