mod activations;
mod gemm;
mod layer_norm;
mod rms_norm;
mod softmax;
pub use activations::{GeluBackwardKernel, ReluBackwardKernel, SiluBackwardKernel};
pub use gemm::{GemmBackwardAKernel, GemmBackwardBKernel};
pub use layer_norm::LayerNormBackwardKernel;
pub use rms_norm::RmsNormBackwardKernel;
pub use softmax::SoftmaxBackwardKernel;
#[cfg(test)]
mod tests {
use super::*;
use crate::kernels::Kernel;
/// Smoke check: every re-exported backward kernel emits PTX that contains
/// a `.entry` directive.
#[test]
fn test_all_backward_kernels_compile() {
    // Shared predicate so each assertion reads as "this PTX has an entry".
    let has_entry = |ptx: String| ptx.contains(".entry");

    // Activation kernels are all built first, then checked.
    let (relu, gelu, silu) = (
        ReluBackwardKernel::new(1024),
        GeluBackwardKernel::new(1024),
        SiluBackwardKernel::new(1024),
    );
    assert!(has_entry(relu.emit_ptx()));
    assert!(has_entry(gelu.emit_ptx()));
    assert!(has_entry(silu.emit_ptx()));

    // Softmax and normalization kernels.
    assert!(has_entry(SoftmaxBackwardKernel::new(64, 32).emit_ptx()));
    assert!(has_entry(RmsNormBackwardKernel::new(64, 32, 1e-6).emit_ptx()));
    assert!(has_entry(LayerNormBackwardKernel::new(64, 32).emit_ptx()));

    // GEMM gradient kernels.
    assert!(has_entry(GemmBackwardAKernel::new(64, 64, 64).emit_ptx()));
    assert!(has_entry(GemmBackwardBKernel::new(64, 64, 64).emit_ptx()));
}
/// Runs `analyze_barrier_safety` on every re-exported backward kernel and
/// requires each result to be flagged safe with an empty violation list.
#[test]
fn test_all_backward_kernels_barrier_safe() {
    // Trait objects let the differently-typed kernels share one loop.
    let kernels: Vec<Box<dyn Kernel>> = vec![
        Box::new(ReluBackwardKernel::new(1024)),
        Box::new(GeluBackwardKernel::new(1024)),
        Box::new(SiluBackwardKernel::new(1024)),
        Box::new(SoftmaxBackwardKernel::new(64, 32)),
        Box::new(RmsNormBackwardKernel::new(64, 32, 1e-6)),
        Box::new(LayerNormBackwardKernel::new(64, 32)),
        Box::new(GemmBackwardAKernel::new(64, 64, 64)),
        Box::new(GemmBackwardBKernel::new(64, 64, 64)),
    ];
    for kernel in kernels {
        let result = kernel.analyze_barrier_safety();
        let name = kernel.name();
        assert!(result.is_safe, "Kernel {name} should be barrier-safe");
        // Name the kernel here as well: inside a loop the bare assertion
        // text would not reveal which kernel reported violations.
        assert!(
            result.violations.is_empty(),
            "Kernel {name} should report no barrier violations"
        );
    }
}
/// Verifies the string returned by `name()` for each re-exported kernel.
#[test]
fn test_reexported_kernel_names() {
    // Activation kernels.
    let relu = ReluBackwardKernel::new(128);
    assert_eq!(relu.name(), "relu_backward");
    let gelu = GeluBackwardKernel::new(128);
    assert_eq!(gelu.name(), "gelu_backward");
    let silu = SiluBackwardKernel::new(128);
    assert_eq!(silu.name(), "silu_backward");

    // Softmax and normalization kernels.
    let softmax = SoftmaxBackwardKernel::new(16, 16);
    assert_eq!(softmax.name(), "softmax_backward");
    let rms_norm = RmsNormBackwardKernel::new(16, 16, 1e-5);
    assert_eq!(rms_norm.name(), "rms_norm_backward");
    let layer_norm = LayerNormBackwardKernel::new(16, 16);
    assert_eq!(layer_norm.name(), "layer_norm_backward");

    // GEMM gradient kernels.
    let gemm_a = GemmBackwardAKernel::new(32, 32, 32);
    assert_eq!(gemm_a.name(), "gemm_backward_a");
    let gemm_b = GemmBackwardBKernel::new(32, 32, 32);
    assert_eq!(gemm_b.name(), "gemm_backward_b");
}
/// Checks that cloning an activation kernel keeps its `n` field and that
/// the `Debug` output of the ReLU kernel shows its type name and size.
#[test]
fn test_activation_kernels_clone_and_debug() {
    let relu = ReluBackwardKernel::new(512);
    let relu_copy = relu.clone();
    assert_eq!(relu.n, relu_copy.n);

    let gelu = GeluBackwardKernel::new(512);
    let gelu_copy = gelu.clone();
    assert_eq!(gelu.n, gelu_copy.n);

    let silu = SiluBackwardKernel::new(512);
    let silu_copy = silu.clone();
    assert_eq!(silu.n, silu_copy.n);

    // Only the ReLU kernel's Debug rendering is inspected.
    let rendered = format!("{relu:?}");
    assert!(rendered.contains("ReluBackwardKernel"));
    assert!(rendered.contains("512"));
}
/// Checks that a cloned softmax backward kernel keeps `num_rows` and
/// `row_size`, and that its `Debug` output shows the type name and both
/// constructor arguments.
#[test]
fn test_softmax_kernel_clone_and_debug() {
    let original = SoftmaxBackwardKernel::new(128, 32);
    let duplicate = original.clone();

    assert_eq!(original.num_rows, duplicate.num_rows);
    assert_eq!(original.row_size, duplicate.row_size);

    let rendered = format!("{original:?}");
    assert!(rendered.contains("SoftmaxBackwardKernel"));
    assert!(rendered.contains("128"));
    assert!(rendered.contains("32"));
}
/// Checks that a cloned RMS-norm backward kernel keeps `num_rows`,
/// `hidden_dim` and `eps`, and that its `Debug` output shows the type name
/// and both integer constructor arguments.
#[test]
fn test_rms_norm_kernel_clone_and_debug() {
    let original = RmsNormBackwardKernel::new(64, 16, 1e-6);
    let duplicate = original.clone();

    assert_eq!(original.num_rows, duplicate.num_rows);
    assert_eq!(original.hidden_dim, duplicate.hidden_dim);
    // Floating-point field: compared with a small tolerance, not equality.
    let eps_gap = (original.eps - duplicate.eps).abs();
    assert!(eps_gap < 1e-10);

    let rendered = format!("{original:?}");
    assert!(rendered.contains("RmsNormBackwardKernel"));
    assert!(rendered.contains("64"));
    assert!(rendered.contains("16"));
}
/// Checks that a cloned layer-norm backward kernel keeps `num_rows` and
/// `hidden_dim`, and that its `Debug` output shows the type name and both
/// constructor arguments.
#[test]
fn test_layer_norm_kernel_clone_and_debug() {
    let original = LayerNormBackwardKernel::new(256, 8);
    let duplicate = original.clone();

    assert_eq!(original.num_rows, duplicate.num_rows);
    assert_eq!(original.hidden_dim, duplicate.hidden_dim);

    let rendered = format!("{original:?}");
    assert!(rendered.contains("LayerNormBackwardKernel"));
    assert!(rendered.contains("256"));
    assert!(rendered.contains("8"));
}
/// Checks that cloned GEMM backward kernels keep their `m`, `n` and `k`
/// fields, and that each kernel's `Debug` output shows its type name and
/// one of its dimensions.
#[test]
fn test_gemm_kernels_clone_and_debug() {
    // Gradient-with-respect-to-A kernel.
    let grad_a = GemmBackwardAKernel::new(128, 256, 64);
    let grad_a_copy = grad_a.clone();
    assert_eq!(grad_a.m, grad_a_copy.m);
    assert_eq!(grad_a.n, grad_a_copy.n);
    assert_eq!(grad_a.k, grad_a_copy.k);

    // Gradient-with-respect-to-B kernel.
    let grad_b = GemmBackwardBKernel::new(128, 256, 64);
    let grad_b_copy = grad_b.clone();
    assert_eq!(grad_b.m, grad_b_copy.m);
    assert_eq!(grad_b.n, grad_b_copy.n);
    assert_eq!(grad_b.k, grad_b_copy.k);

    let rendered_a = format!("{grad_a:?}");
    assert!(rendered_a.contains("GemmBackwardAKernel"));
    assert!(rendered_a.contains("128"));

    let rendered_b = format!("{grad_b:?}");
    assert!(rendered_b.contains("GemmBackwardBKernel"));
    assert!(rendered_b.contains("256"));
}
/// Checks that each kernel's emitted PTX mentions the parameter names
/// expected for that kernel.
#[test]
fn test_ptx_contains_correct_params() {
    // ReLU backward.
    {
        let ptx = ReluBackwardKernel::new(64).emit_ptx();
        assert!(ptx.contains("input_ptr"));
        assert!(ptx.contains("grad_output_ptr"));
        assert!(ptx.contains("grad_input_ptr"));
    }

    // Softmax backward.
    {
        let ptx = SoftmaxBackwardKernel::new(32, 16).emit_ptx();
        assert!(ptx.contains("output_ptr"));
        assert!(ptx.contains("num_rows"));
        assert!(ptx.contains("row_size"));
    }

    // RMS-norm backward.
    {
        let ptx = RmsNormBackwardKernel::new(32, 16, 1e-5).emit_ptx();
        assert!(ptx.contains("gamma_ptr"));
        assert!(ptx.contains("rms_ptr"));
    }

    // Layer-norm backward.
    {
        let ptx = LayerNormBackwardKernel::new(32, 16).emit_ptx();
        assert!(ptx.contains("mean_ptr"));
        assert!(ptx.contains("rstd_ptr"));
    }

    // GEMM backward with respect to A.
    {
        let ptx = GemmBackwardAKernel::new(16, 16, 16).emit_ptx();
        assert!(ptx.contains("grad_c_ptr"));
        assert!(ptx.contains("b_ptr"));
        assert!(ptx.contains("grad_a_ptr"));
    }

    // GEMM backward with respect to B.
    {
        let ptx = GemmBackwardBKernel::new(16, 16, 16).emit_ptx();
        assert!(ptx.contains("a_ptr"));
        assert!(ptx.contains("grad_b_ptr"));
    }
}
/// Checks that every backward kernel still emits PTX with a `.entry`
/// directive when constructed with size arguments of 1.
#[test]
fn test_backward_kernels_with_various_sizes() {
    // Activation kernels with a single element.
    assert!(ReluBackwardKernel::new(1).emit_ptx().contains(".entry"));
    assert!(GeluBackwardKernel::new(1).emit_ptx().contains(".entry"));
    assert!(SiluBackwardKernel::new(1).emit_ptx().contains(".entry"));

    // Softmax and normalization kernels with 1x1 arguments.
    assert!(SoftmaxBackwardKernel::new(1, 1).emit_ptx().contains(".entry"));
    assert!(RmsNormBackwardKernel::new(1, 1, 1e-8)
        .emit_ptx()
        .contains(".entry"));
    assert!(LayerNormBackwardKernel::new(1, 1)
        .emit_ptx()
        .contains(".entry"));

    // GEMM gradient kernels with 1x1x1 arguments.
    assert!(GemmBackwardAKernel::new(1, 1, 1)
        .emit_ptx()
        .contains(".entry"));
    assert!(GemmBackwardBKernel::new(1, 1, 1)
        .emit_ptx()
        .contains(".entry"));
}
/// Checks that activation kernels built with one million elements emit PTX
/// containing `.entry` and store the requested size in `n`.
#[test]
fn test_backward_kernels_large_sizes() {
    let relu = ReluBackwardKernel::new(1_000_000);
    let relu_ptx = relu.emit_ptx();
    assert!(relu_ptx.contains(".entry"));
    assert_eq!(relu.n, 1_000_000);

    let gelu = GeluBackwardKernel::new(1_000_000);
    let gelu_ptx = gelu.emit_ptx();
    assert!(gelu_ptx.contains(".entry"));
    assert_eq!(gelu.n, 1_000_000);

    let silu = SiluBackwardKernel::new(1_000_000);
    let silu_ptx = silu.emit_ptx();
    assert!(silu_ptx.contains(".entry"));
    assert_eq!(silu.n, 1_000_000);
}
/// Checks that the PTX emitted by every backward kernel contains a
/// `.version` directive, a `.target` directive and a `ret;` instruction.
#[test]
fn test_backward_ptx_version_and_target() {
    // Each PTX string is paired with a label so that a failure inside the
    // loop identifies the kernel that produced the offending output. A
    // fixed-size array is enough here; no heap-allocated Vec is needed.
    let labelled_ptx = [
        ("relu_backward", ReluBackwardKernel::new(64).emit_ptx()),
        ("gelu_backward", GeluBackwardKernel::new(64).emit_ptx()),
        ("silu_backward", SiluBackwardKernel::new(64).emit_ptx()),
        (
            "softmax_backward",
            SoftmaxBackwardKernel::new(16, 16).emit_ptx(),
        ),
        (
            "rms_norm_backward",
            RmsNormBackwardKernel::new(16, 16, 1e-6).emit_ptx(),
        ),
        (
            "layer_norm_backward",
            LayerNormBackwardKernel::new(16, 16).emit_ptx(),
        ),
        (
            "gemm_backward_a",
            GemmBackwardAKernel::new(16, 16, 16).emit_ptx(),
        ),
        (
            "gemm_backward_b",
            GemmBackwardBKernel::new(16, 16, 16).emit_ptx(),
        ),
    ];
    for (label, ptx) in &labelled_ptx {
        assert!(
            ptx.contains(".version"),
            "{label}: PTX should contain version"
        );
        assert!(
            ptx.contains(".target"),
            "{label}: PTX should contain target"
        );
        assert!(
            ptx.contains("ret;"),
            "{label}: PTX should contain return instruction"
        );
    }
}
}