use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Descriptor for the `adamw_step` PTX kernel.
///
/// Holds the element count the kernel was constructed for. The `Kernel`
/// implementation in this file does not read the field while generating PTX;
/// the generated code takes its bound from the `n` launch parameter instead.
#[derive(Clone, Debug)]
pub struct AdamWStepKernel {
    /// Number of elements given to [`AdamWStepKernel::new`].
    pub n: u32,
}

impl AdamWStepKernel {
    /// Creates a descriptor for `n` elements.
    ///
    /// This is a `const fn`, so it can be used in constant contexts.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        Self { n }
    }
}
impl Kernel for AdamWStepKernel {
    /// Returns the kernel name, the same string passed to `PtxKernel::new`
    /// in [`Self::build_ptx`].
    fn name(&self) -> &str {
        "adamw_step"
    }

    /// Builds the PTX for one element-wise AdamW step.
    ///
    /// Launch parameters, in declaration order: `params_ptr`, `grads_ptr`,
    /// `m_ptr`, `v_ptr` (u64), then `lr`, `beta1`, `beta2`, `eps`,
    /// `weight_decay`, `bias_correction1`, `bias_correction2` (f32), then
    /// `n` (u32).
    ///
    /// Reading the builder calls by their names, each thread whose index is
    /// below `n` computes:
    ///
    /// ```text
    /// m_new = beta1 * m + (1 - beta1) * grad
    /// v_new = beta2 * v + (1 - beta2) * grad * grad
    /// m_hat = m_new * bias_correction1
    /// v_hat = v_new * bias_correction2
    /// param_new = param - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * param)
    /// ```
    ///
    /// and stores `param_new`, `m_new` and `v_new` back to the addresses they
    /// were loaded from. `self.n` is not consulted here.
    ///
    /// NOTE(review): the bias corrections are *multiplied* in, so the launcher
    /// presumably passes reciprocals such as `1 / (1 - beta1^t)` — confirm
    /// against the caller.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("adamw_step")
            .param(PtxType::U64, "params_ptr")
            .param(PtxType::U64, "grads_ptr")
            .param(PtxType::U64, "m_ptr")
            .param(PtxType::U64, "v_ptr")
            .param(PtxType::F32, "lr")
            .param(PtxType::F32, "beta1")
            .param(PtxType::F32, "beta2")
            .param(PtxType::F32, "eps")
            .param(PtxType::F32, "weight_decay")
            .param(PtxType::F32, "bias_correction1")
            .param(PtxType::F32, "bias_correction2")
            .param(PtxType::U32, "n")
            .build(|b| {
                // Combine block index, block size and thread index into one
                // element index.
                let thread_in_block = b.special_reg(PtxReg::TidX);
                let block_idx = b.special_reg(PtxReg::CtaIdX);
                let block_dim = b.special_reg(PtxReg::NtidX);
                let idx = b.mad_lo_u32(block_idx, block_dim, thread_in_block);

                // Threads at or beyond `n` jump straight to the exit label.
                let len = b.load_param_u32("n");
                let is_active = b.setp_lt_u32(idx, len);
                b.branch_if_not(is_active, "exit");

                // Buffer base addresses.
                let p_base = b.load_param_u64("params_ptr");
                let g_base = b.load_param_u64("grads_ptr");
                let m_base = b.load_param_u64("m_ptr");
                let v_base = b.load_param_u64("v_ptr");

                // Scalar hyper-parameters.
                let step_size = b.load_param_f32("lr");
                let b1 = b.load_param_f32("beta1");
                let b2 = b.load_param_f32("beta2");
                let epsilon = b.load_param_f32("eps");
                let wd = b.load_param_f32("weight_decay");
                let bc1 = b.load_param_f32("bias_correction1");
                let bc2 = b.load_param_f32("bias_correction2");

                // Byte offset of this element: index times 4 (one f32),
                // added to every 64-bit base address.
                let elem_bytes = b.mov_u32_imm(4);
                let byte_off = b.mul_wide_u32_reg(idx, elem_bytes);
                let p_addr = b.add_u64(p_base, byte_off);
                let g_addr = b.add_u64(g_base, byte_off);
                let m_slot = b.add_u64(m_base, byte_off);
                let v_slot = b.add_u64(v_base, byte_off);

                // Current parameter, gradient and both moment estimates.
                let p_old = b.ld_global_f32(p_addr);
                let g = b.ld_global_f32(g_addr);
                let m_old = b.ld_global_f32(m_slot);
                let v_old = b.ld_global_f32(v_slot);

                let unit = b.mov_f32_imm(1.0);
                let b1_complement = b.sub_f32(unit, b1);
                let b2_complement = b.sub_f32(unit, b2);

                // First moment: beta1 * m + (1 - beta1) * grad.
                let m_decayed = b.mul_f32(b1, m_old);
                let g_weighted = b.mul_f32(b1_complement, g);
                let m_next = b.add_f32(m_decayed, g_weighted);

                // Second moment: beta2 * v + (1 - beta2) * grad^2.
                let v_decayed = b.mul_f32(b2, v_old);
                let g_squared = b.mul_f32(g, g);
                let g_squared_weighted = b.mul_f32(b2_complement, g_squared);
                let v_next = b.add_f32(v_decayed, g_squared_weighted);

                // Bias correction is applied by multiplication.
                let m_corrected = b.mul_f32(m_next, bc1);
                let v_corrected = b.mul_f32(v_next, bc2);

                // Adaptive step: m_hat / (sqrt(v_hat) + eps).
                let v_root = b.sqrt_f32(v_corrected);
                let divisor = b.add_f32(v_root, epsilon);
                let adaptive_step = b.div_f32(m_corrected, divisor);

                // Weight decay is taken on the parameter itself and added to
                // the adaptive step before scaling by the learning rate.
                let decay = b.mul_f32(wd, p_old);
                let combined = b.add_f32(adaptive_step, decay);
                let scaled = b.mul_f32(step_size, combined);
                let p_next = b.sub_f32(p_old, scaled);

                // Write the parameter and both raw (uncorrected) moments back.
                b.st_global_f32(p_addr, p_next);
                b.st_global_f32(m_slot, m_next);
                b.st_global_f32(v_slot, v_next);

                b.label("exit");
                b.ret();
            })
    }
}
/// Descriptor for the `adam_step` PTX kernel.
///
/// Holds the element count the kernel was constructed for. The `Kernel`
/// implementation in this file does not read the field while generating PTX;
/// the generated code takes its bound from the `n` launch parameter instead.
#[derive(Clone, Debug)]
pub struct AdamStepKernel {
    /// Number of elements given to [`AdamStepKernel::new`].
    pub n: u32,
}

impl AdamStepKernel {
    /// Creates a descriptor for `n` elements.
    ///
    /// This is a `const fn`, so it can be used in constant contexts.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        Self { n }
    }
}
impl Kernel for AdamStepKernel {
    /// Returns the kernel name, the same string passed to `PtxKernel::new`
    /// in [`Self::build_ptx`].
    fn name(&self) -> &str {
        "adam_step"
    }

    /// Builds the PTX for one element-wise Adam step (no weight decay term).
    ///
    /// Launch parameters, in declaration order: `params_ptr`, `grads_ptr`,
    /// `m_ptr`, `v_ptr` (u64), then `lr`, `beta1`, `beta2`, `eps`,
    /// `bias_correction1`, `bias_correction2` (f32), then `n` (u32).
    ///
    /// Reading the builder calls by their names, each thread whose index is
    /// below `n` computes:
    ///
    /// ```text
    /// m_new = beta1 * m + (1 - beta1) * grad
    /// v_new = beta2 * v + (1 - beta2) * grad * grad
    /// m_hat = m_new * bias_correction1
    /// v_hat = v_new * bias_correction2
    /// param_new = param - lr * m_hat / (sqrt(v_hat) + eps)
    /// ```
    ///
    /// and stores `param_new`, `m_new` and `v_new` back to the addresses they
    /// were loaded from. `self.n` is not consulted here.
    ///
    /// NOTE(review): the bias corrections are *multiplied* in, so the launcher
    /// presumably passes reciprocals such as `1 / (1 - beta1^t)` — confirm
    /// against the caller.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("adam_step")
            .param(PtxType::U64, "params_ptr")
            .param(PtxType::U64, "grads_ptr")
            .param(PtxType::U64, "m_ptr")
            .param(PtxType::U64, "v_ptr")
            .param(PtxType::F32, "lr")
            .param(PtxType::F32, "beta1")
            .param(PtxType::F32, "beta2")
            .param(PtxType::F32, "eps")
            .param(PtxType::F32, "bias_correction1")
            .param(PtxType::F32, "bias_correction2")
            .param(PtxType::U32, "n")
            .build(|b| {
                // Combine block index, block size and thread index into one
                // element index.
                let thread_in_block = b.special_reg(PtxReg::TidX);
                let block_idx = b.special_reg(PtxReg::CtaIdX);
                let block_dim = b.special_reg(PtxReg::NtidX);
                let idx = b.mad_lo_u32(block_idx, block_dim, thread_in_block);

                // Threads at or beyond `n` jump straight to the exit label.
                let len = b.load_param_u32("n");
                let is_active = b.setp_lt_u32(idx, len);
                b.branch_if_not(is_active, "exit");

                // Buffer base addresses.
                let p_base = b.load_param_u64("params_ptr");
                let g_base = b.load_param_u64("grads_ptr");
                let m_base = b.load_param_u64("m_ptr");
                let v_base = b.load_param_u64("v_ptr");

                // Scalar hyper-parameters.
                let step_size = b.load_param_f32("lr");
                let b1 = b.load_param_f32("beta1");
                let b2 = b.load_param_f32("beta2");
                let epsilon = b.load_param_f32("eps");
                let bc1 = b.load_param_f32("bias_correction1");
                let bc2 = b.load_param_f32("bias_correction2");

                // Byte offset of this element: index times 4 (one f32),
                // added to every 64-bit base address.
                let elem_bytes = b.mov_u32_imm(4);
                let byte_off = b.mul_wide_u32_reg(idx, elem_bytes);
                let p_addr = b.add_u64(p_base, byte_off);
                let g_addr = b.add_u64(g_base, byte_off);
                let m_slot = b.add_u64(m_base, byte_off);
                let v_slot = b.add_u64(v_base, byte_off);

                // Current parameter, gradient and both moment estimates.
                let p_old = b.ld_global_f32(p_addr);
                let g = b.ld_global_f32(g_addr);
                let m_old = b.ld_global_f32(m_slot);
                let v_old = b.ld_global_f32(v_slot);

                let unit = b.mov_f32_imm(1.0);
                let b1_complement = b.sub_f32(unit, b1);
                let b2_complement = b.sub_f32(unit, b2);

                // First moment: beta1 * m + (1 - beta1) * grad.
                let m_decayed = b.mul_f32(b1, m_old);
                let g_weighted = b.mul_f32(b1_complement, g);
                let m_next = b.add_f32(m_decayed, g_weighted);

                // Second moment: beta2 * v + (1 - beta2) * grad^2.
                let v_decayed = b.mul_f32(b2, v_old);
                let g_squared = b.mul_f32(g, g);
                let g_squared_weighted = b.mul_f32(b2_complement, g_squared);
                let v_next = b.add_f32(v_decayed, g_squared_weighted);

                // Bias correction is applied by multiplication.
                let m_corrected = b.mul_f32(m_next, bc1);
                let v_corrected = b.mul_f32(v_next, bc2);

                // Adaptive step: m_hat / (sqrt(v_hat) + eps), scaled by lr.
                let v_root = b.sqrt_f32(v_corrected);
                let divisor = b.add_f32(v_root, epsilon);
                let adaptive_step = b.div_f32(m_corrected, divisor);
                let scaled = b.mul_f32(step_size, adaptive_step);
                let p_next = b.sub_f32(p_old, scaled);

                // Write the parameter and both raw (uncorrected) moments back.
                b.st_global_f32(p_addr, p_next);
                b.st_global_f32(m_slot, m_next);
                b.st_global_f32(v_slot, v_next);

                b.label("exit");
                b.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The AdamW descriptor reports the expected kernel name.
    #[test]
    fn test_adamw_step_kernel_name() {
        assert_eq!(AdamWStepKernel::new(1024).name(), "adamw_step");
    }

    /// Emitted AdamW PTX declares the entry point, a sample of its
    /// parameters, and contains the round-to-nearest sqrt and div instructions.
    #[test]
    fn test_adamw_ptx_generation() {
        let adamw = AdamWStepKernel::new(1024);
        let ptx = adamw.emit_ptx();

        // Entry point and parameter declarations.
        assert!(ptx.contains(".entry adamw_step"));
        assert!(ptx.contains(".param .u64 params_ptr"));
        assert!(ptx.contains(".param .u64 grads_ptr"));
        assert!(ptx.contains(".param .f32 lr"));
        assert!(ptx.contains(".param .f32 weight_decay"));

        // Instructions the update relies on.
        assert!(ptx.contains("sqrt.rn.f32"));
        assert!(ptx.contains("div.rn.f32"));
    }

    /// The Adam descriptor reports the expected kernel name.
    #[test]
    fn test_adam_step_kernel_name() {
        assert_eq!(AdamStepKernel::new(1024).name(), "adam_step");
    }

    /// Emitted Adam PTX declares the entry point, never mentions
    /// `weight_decay`, and contains the round-to-nearest sqrt and div
    /// instructions.
    #[test]
    fn test_adam_ptx_generation() {
        let adam = AdamStepKernel::new(1024);
        let ptx = adam.emit_ptx();

        assert!(ptx.contains(".entry adam_step"));
        // Plain Adam must not carry the AdamW-only parameter.
        assert!(!ptx.contains("weight_decay"));
        assert!(ptx.contains("sqrt.rn.f32"));
        assert!(ptx.contains("div.rn.f32"));
    }
}