#![allow(clippy::similar_names)]
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Configuration for the `grad_A = grad_C · Bᵀ` backward GEMM kernel.
///
/// `m`, `n`, `k` are the forward GEMM dimensions: `A` is `m × k`,
/// `B` is `k × n`, `C` is `m × n` — TODO confirm against the launcher.
/// Plain-old-data of three `u32`s, so it derives `Copy` and the full
/// comparison/hash set in addition to `Debug`/`Clone`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GemmBackwardAKernel {
    /// Rows of A / grad_C.
    pub m: u32,
    /// Columns of grad_C / B.
    pub n: u32,
    /// Columns of A (rows of B).
    pub k: u32,
}

impl GemmBackwardAKernel {
    /// Creates a kernel configuration for the given GEMM dimensions.
    #[must_use]
    pub const fn new(m: u32, n: u32, k: u32) -> Self {
        Self { m, n, k }
    }
}
impl Kernel for GemmBackwardAKernel {
    fn name(&self) -> &str {
        "gemm_backward_a"
    }

    /// Emits PTX computing `grad_A = grad_C · Bᵀ`.
    ///
    /// One thread produces one element of `grad_A` (an `m × k` matrix):
    /// `grad_A[row, col] = Σ_i grad_C[row, i] * B[col, i]` for `i in 0..n`.
    /// Threads whose `(row, col)` fall outside the output bounds exit early.
    /// All matrices are assumed row-major f32 — TODO confirm with callers.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("gemm_backward_a")
            .param(PtxType::U64, "grad_c_ptr")
            .param(PtxType::U64, "b_ptr")
            .param(PtxType::U64, "grad_a_ptr")
            .param(PtxType::U32, "m")
            .param(PtxType::U32, "n")
            .param(PtxType::U32, "k")
            .build(|ctx| {
                // Global thread coordinates: y selects the output row
                // (0..m), x selects the output column (0..k).
                let tid_x = ctx.special_reg(PtxReg::TidX);
                let tid_y = ctx.special_reg(PtxReg::TidY);
                let ctaid_x = ctx.special_reg(PtxReg::CtaIdX);
                let ctaid_y = ctx.special_reg(PtxReg::CtaIdY);
                let ntid_x = ctx.special_reg(PtxReg::NtidX);
                let ntid_y = ctx.special_reg(PtxReg::NtidY);
                let row = ctx.mad_lo_u32(ctaid_y, ntid_y, tid_y);
                let col = ctx.mad_lo_u32(ctaid_x, ntid_x, tid_x);
                let m = ctx.load_param_u32("m");
                let n = ctx.load_param_u32("n");
                let k = ctx.load_param_u32("k");
                let grad_c_ptr = ctx.load_param_u64("grad_c_ptr");
                let b_ptr = ctx.load_param_u64("b_ptr");
                let grad_a_ptr = ctx.load_param_u64("grad_a_ptr");
                // Guard against partial blocks at the matrix edges.
                let valid_row = ctx.setp_lt_u32(row, m);
                ctx.branch_if_not(valid_row, "exit");
                let valid_col = ctx.setp_lt_u32(col, k);
                ctx.branch_if_not(valid_col, "exit");
                let acc = ctx.mov_f32_imm(0.0);
                let four = ctx.mov_u32_imm(4); // sizeof(f32) for byte offsets
                let i = ctx.mov_u32_imm(0);
                // Hoisted loop invariants: `row` and `col` are fixed per
                // thread, so the row-base offsets of grad_C and B are
                // computed once here instead of being re-emitted (and
                // re-executed) on every loop iteration.
                let grad_c_row_offset = ctx.mul_lo_u32(row, n);
                let b_row_offset = ctx.mul_lo_u32(col, n);
                // Reduction over the shared dimension n.
                ctx.label("loop_start");
                let loop_cond = ctx.setp_lt_u32(i, n);
                ctx.branch_if_not(loop_cond, "loop_end");
                // grad_C[row, i] — row-major with n columns.
                let grad_c_elem_idx = ctx.add_u32_reg(grad_c_row_offset, i);
                let grad_c_byte_offset = ctx.mul_wide_u32_reg(grad_c_elem_idx, four);
                let grad_c_addr = ctx.add_u64(grad_c_ptr, grad_c_byte_offset);
                let grad_c_val = ctx.ld_global_f32(grad_c_addr);
                // B[col, i] — reading B's row `col` realizes the Bᵀ access.
                let b_elem_idx = ctx.add_u32_reg(b_row_offset, i);
                let b_byte_offset = ctx.mul_wide_u32_reg(b_elem_idx, four);
                let b_addr = ctx.add_u64(b_ptr, b_byte_offset);
                let b_val = ctx.ld_global_f32(b_addr);
                let prod = ctx.mul_f32(grad_c_val, b_val);
                ctx.add_f32_inplace(acc, prod);
                ctx.add_u32_inplace(i, 1);
                ctx.branch("loop_start");
                ctx.label("loop_end");
                // Store grad_A[row, col] — row-major with k columns.
                let grad_a_row_offset = ctx.mul_lo_u32(row, k);
                let grad_a_elem_idx = ctx.add_u32_reg(grad_a_row_offset, col);
                let grad_a_byte_offset = ctx.mul_wide_u32_reg(grad_a_elem_idx, four);
                let grad_a_addr = ctx.add_u64(grad_a_ptr, grad_a_byte_offset);
                ctx.st_global_f32(grad_a_addr, acc);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
/// Configuration for the `grad_B = Aᵀ · grad_C` backward GEMM kernel.
///
/// `m`, `n`, `k` are the forward GEMM dimensions: `A` is `m × k`,
/// `B` is `k × n`, `C` is `m × n` — TODO confirm against the launcher.
/// Plain-old-data of three `u32`s, so it derives `Copy` and the full
/// comparison/hash set in addition to `Debug`/`Clone`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GemmBackwardBKernel {
    /// Rows of A / grad_C.
    pub m: u32,
    /// Columns of grad_C / B.
    pub n: u32,
    /// Columns of A (rows of B).
    pub k: u32,
}

impl GemmBackwardBKernel {
    /// Creates a kernel configuration for the given GEMM dimensions.
    #[must_use]
    pub const fn new(m: u32, n: u32, k: u32) -> Self {
        Self { m, n, k }
    }
}
impl Kernel for GemmBackwardBKernel {
    fn name(&self) -> &str {
        "gemm_backward_b"
    }

    /// Emits PTX computing `grad_B = Aᵀ · grad_C`.
    ///
    /// One thread produces one element of `grad_B` (a `k × n` matrix):
    /// `grad_B[out_row, out_col] = Σ_i A[i, out_row] * grad_C[i, out_col]`
    /// with `i` running over `0..m`. Out-of-bounds threads exit early.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("gemm_backward_b")
            .param(PtxType::U64, "a_ptr")
            .param(PtxType::U64, "grad_c_ptr")
            .param(PtxType::U64, "grad_b_ptr")
            .param(PtxType::U32, "m")
            .param(PtxType::U32, "n")
            .param(PtxType::U32, "k")
            .build(|ctx| {
                // Thread coordinates: y picks the grad_B row, x the column.
                let tx = ctx.special_reg(PtxReg::TidX);
                let ty = ctx.special_reg(PtxReg::TidY);
                let bx = ctx.special_reg(PtxReg::CtaIdX);
                let by = ctx.special_reg(PtxReg::CtaIdY);
                let bdx = ctx.special_reg(PtxReg::NtidX);
                let bdy = ctx.special_reg(PtxReg::NtidY);
                let out_row = ctx.mad_lo_u32(by, bdy, ty);
                let out_col = ctx.mad_lo_u32(bx, bdx, tx);
                let dim_m = ctx.load_param_u32("m");
                let dim_n = ctx.load_param_u32("n");
                let dim_k = ctx.load_param_u32("k");
                let src_a = ctx.load_param_u64("a_ptr");
                let src_grad_c = ctx.load_param_u64("grad_c_ptr");
                let dst_grad_b = ctx.load_param_u64("grad_b_ptr");
                // Bounds guard: grad_B has k rows and n columns.
                let row_ok = ctx.setp_lt_u32(out_row, dim_k);
                ctx.branch_if_not(row_ok, "exit");
                let col_ok = ctx.setp_lt_u32(out_col, dim_n);
                ctx.branch_if_not(col_ok, "exit");
                let sum = ctx.mov_f32_imm(0.0);
                let elem_size = ctx.mov_u32_imm(4); // sizeof(f32)
                let idx = ctx.mov_u32_imm(0);
                // Reduce over the m rows of A / grad_C; both base offsets
                // below depend on `idx`, so they stay inside the loop.
                ctx.label("loop_start");
                let more = ctx.setp_lt_u32(idx, dim_m);
                ctx.branch_if_not(more, "loop_end");
                // A[idx, out_row] — reading down A's column realizes Aᵀ.
                let a_base = ctx.mul_lo_u32(idx, dim_k);
                let a_idx = ctx.add_u32_reg(a_base, out_row);
                let a_bytes = ctx.mul_wide_u32_reg(a_idx, elem_size);
                let a_addr = ctx.add_u64(src_a, a_bytes);
                let a_val = ctx.ld_global_f32(a_addr);
                // grad_C[idx, out_col] — row-major with n columns.
                let gc_base = ctx.mul_lo_u32(idx, dim_n);
                let gc_idx = ctx.add_u32_reg(gc_base, out_col);
                let gc_bytes = ctx.mul_wide_u32_reg(gc_idx, elem_size);
                let gc_addr = ctx.add_u64(src_grad_c, gc_bytes);
                let gc_val = ctx.ld_global_f32(gc_addr);
                let term = ctx.mul_f32(a_val, gc_val);
                ctx.add_f32_inplace(sum, term);
                ctx.add_u32_inplace(idx, 1);
                ctx.branch("loop_start");
                ctx.label("loop_end");
                // Store grad_B[out_row, out_col] — row-major, n columns.
                let gb_base = ctx.mul_lo_u32(out_row, dim_n);
                let gb_idx = ctx.add_u32_reg(gb_base, out_col);
                let gb_bytes = ctx.mul_wide_u32_reg(gb_idx, elem_size);
                let gb_addr = ctx.add_u64(dst_grad_b, gb_bytes);
                ctx.st_global_f32(gb_addr, sum);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gemm_backward_a_name() {
        // Kernel name must match the PTX entry symbol.
        let kernel = GemmBackwardAKernel::new(64, 64, 64);
        assert_eq!(kernel.name(), "gemm_backward_a");
    }

    #[test]
    fn test_gemm_backward_a_ptx_generation() {
        // The emitted PTX must contain the entry point, all three
        // pointer parameters, and the reduction-loop labels.
        let ptx = GemmBackwardAKernel::new(64, 64, 64).emit_ptx();
        assert!(ptx.contains(".entry gemm_backward_a"));
        assert!(ptx.contains(".param .u64 grad_c_ptr"));
        assert!(ptx.contains(".param .u64 b_ptr"));
        assert!(ptx.contains(".param .u64 grad_a_ptr"));
        assert!(ptx.contains("loop_start"));
        assert!(ptx.contains("loop_end"));
    }

    #[test]
    fn test_gemm_backward_a_barrier_safety() {
        // Kernel has no barriers, so the analysis must report safe.
        let result = GemmBackwardAKernel::new(32, 32, 32).analyze_barrier_safety();
        assert!(
            result.is_safe,
            "GEMM backward A should be barrier-safe: {:?}",
            result.violations
        );
    }

    #[test]
    fn test_gemm_backward_b_name() {
        // Kernel name must match the PTX entry symbol.
        let kernel = GemmBackwardBKernel::new(64, 64, 64);
        assert_eq!(kernel.name(), "gemm_backward_b");
    }

    #[test]
    fn test_gemm_backward_b_ptx_generation() {
        // The emitted PTX must contain the entry point, all three
        // pointer parameters, and the loop entry label.
        let ptx = GemmBackwardBKernel::new(64, 64, 64).emit_ptx();
        assert!(ptx.contains(".entry gemm_backward_b"));
        assert!(ptx.contains(".param .u64 a_ptr"));
        assert!(ptx.contains(".param .u64 grad_c_ptr"));
        assert!(ptx.contains(".param .u64 grad_b_ptr"));
        assert!(ptx.contains("loop_start"));
    }

    #[test]
    fn test_gemm_backward_b_barrier_safety() {
        // Kernel has no barriers, so the analysis must report safe.
        let result = GemmBackwardBKernel::new(32, 32, 32).analyze_barrier_safety();
        assert!(
            result.is_safe,
            "GEMM backward B should be barrier-safe: {:?}",
            result.violations
        );
    }
}