#![cfg(feature = "cuda")]
use trueno_gpu::kernels::{
BatchedResidualAddKernel, BatchedScaleKernel, BatchedSwigluKernel, BatchedTransposeKernel,
ElementwiseMulKernel, FusedResidualRmsNormKernel, FusedSwigluKernel, GeluKernel, Kernel,
ResidualAddKernel, RopeKernel, ScaleKernel, SiluKernel, TransposeKernel,
};
/// Golden check for the PTX emitted by `ResidualAddKernel::new(1024)`.
///
/// The emitted text must contain, in the order checked: the `residual_add`
/// entry point, an `add.f32` instruction, a `ld.global` and a `st.global`.
/// Every failure message carries the full PTX dump.
#[test]
fn golden_residual_add_kernel() {
    let residual_add = ResidualAddKernel::new(1024);
    let emitted = residual_add.emit_ptx();

    let has_entry = emitted.contains(".entry residual_add");
    assert!(
        has_entry,
        "GOLDEN FAIL: Missing residual_add entry\nPTX:\n{}",
        emitted
    );

    let has_f32_add = emitted.contains("add.f32");
    assert!(
        has_f32_add,
        "GOLDEN FAIL: Missing f32 add in residual_add\nPTX:\n{}",
        emitted
    );

    let has_global_load = emitted.contains("ld.global");
    assert!(
        has_global_load,
        "GOLDEN FAIL: Missing global load in residual_add\nPTX:\n{}",
        emitted
    );

    let has_global_store = emitted.contains("st.global");
    assert!(
        has_global_store,
        "GOLDEN FAIL: Missing global store in residual_add\nPTX:\n{}",
        emitted
    );
}
/// Golden check for the PTX emitted by `BatchedResidualAddKernel::new(768, 8)`.
///
/// Requires the `batched_residual_add` entry point and a reference to
/// `%ctaid.y`, which the failure message identifies as the batch index.
// NOTE(review): the meaning of the two constructor arguments (presumably
// element count and batch size) is not visible here; confirm against the kernel.
#[test]
fn golden_batched_residual_add_kernel() {
    let batched_add = BatchedResidualAddKernel::new(768, 8);
    let emitted = batched_add.emit_ptx();
    let has = |needle: &str| emitted.contains(needle);

    assert!(
        has(".entry batched_residual_add"),
        "GOLDEN FAIL: Missing batched_residual_add entry\nPTX:\n{}",
        emitted
    );
    assert!(
        has("%ctaid.y"),
        "GOLDEN FAIL: Missing batch index (ctaid.y) in batched_residual_add\nPTX:\n{}",
        emitted
    );
}
/// Golden check for the PTX emitted by `SiluKernel::new(1024)`.
///
/// Requires the `silu` entry point plus the substrings `ex2` (reported as the
/// exp instruction) and `div` (reported as the division).
#[test]
fn golden_silu_kernel() {
    let silu = SiluKernel::new(1024);
    let emitted = silu.emit_ptx();

    let has_entry = emitted.contains(".entry silu");
    assert!(
        has_entry,
        "GOLDEN FAIL: Missing silu entry\nPTX:\n{}",
        emitted
    );

    let has_exp = emitted.contains("ex2");
    assert!(
        has_exp,
        "GOLDEN FAIL: Missing exp (ex2) in silu\nPTX:\n{}",
        emitted
    );

    let has_div = emitted.contains("div");
    assert!(
        has_div,
        "GOLDEN FAIL: Missing division in silu\nPTX:\n{}",
        emitted
    );
}
/// Golden check for the PTX emitted by `GeluKernel::new(1024)`.
///
/// Requires the `gelu` entry point, a `mul.f32` instruction and the `ex2`
/// substring (reported as the exp instruction).
#[test]
fn golden_gelu_kernel() {
    let gelu = GeluKernel::new(1024);
    let emitted = gelu.emit_ptx();
    let has = |needle: &str| emitted.contains(needle);

    assert!(
        has(".entry gelu"),
        "GOLDEN FAIL: Missing gelu entry\nPTX:\n{}",
        emitted
    );
    assert!(
        has("mul.f32"),
        "GOLDEN FAIL: Missing mul.f32 in gelu\nPTX:\n{}",
        emitted
    );
    assert!(
        has("ex2"),
        "GOLDEN FAIL: Missing exp (ex2) in gelu\nPTX:\n{}",
        emitted
    );
}
/// Golden check for the PTX emitted by `ElementwiseMulKernel::new(1024)`.
///
/// Requires the `elementwise_mul` entry point and a `mul.f32` instruction.
#[test]
fn golden_elementwise_mul_kernel() {
    let elementwise_mul = ElementwiseMulKernel::new(1024);
    let emitted = elementwise_mul.emit_ptx();

    let has_entry = emitted.contains(".entry elementwise_mul");
    assert!(
        has_entry,
        "GOLDEN FAIL: Missing elementwise_mul entry\nPTX:\n{}",
        emitted
    );

    let has_f32_mul = emitted.contains("mul.f32");
    assert!(
        has_f32_mul,
        "GOLDEN FAIL: Missing mul.f32 in elementwise_mul\nPTX:\n{}",
        emitted
    );
}
/// Golden check for the PTX emitted by `ScaleKernel::new(1024)`.
///
/// Requires the `scale` entry point and a `mul.f32` instruction.
#[test]
fn golden_scale_kernel() {
    let scale = ScaleKernel::new(1024);
    let emitted = scale.emit_ptx();
    let has = |needle: &str| emitted.contains(needle);

    assert!(
        has(".entry scale"),
        "GOLDEN FAIL: Missing scale entry\nPTX:\n{}",
        emitted
    );
    assert!(
        has("mul.f32"),
        "GOLDEN FAIL: Missing mul.f32 in scale\nPTX:\n{}",
        emitted
    );
}
/// Golden check for the PTX emitted by `FusedResidualRmsNormKernel::new(768)`.
///
/// Requires the `fused_residual_rmsnorm` entry point, the `rsqrt` substring
/// and the `shfl` substring (reported as the warp shuffle).
#[test]
fn golden_fused_residual_rmsnorm_kernel() {
    let fused_norm = FusedResidualRmsNormKernel::new(768);
    let emitted = fused_norm.emit_ptx();

    let has_entry = emitted.contains(".entry fused_residual_rmsnorm");
    assert!(
        has_entry,
        "GOLDEN FAIL: Missing fused_residual_rmsnorm entry\nPTX:\n{}",
        emitted
    );

    let has_rsqrt = emitted.contains("rsqrt");
    assert!(
        has_rsqrt,
        "GOLDEN FAIL: Missing rsqrt in fused_residual_rmsnorm\nPTX:\n{}",
        emitted
    );

    let has_shuffle = emitted.contains("shfl");
    assert!(
        has_shuffle,
        "GOLDEN FAIL: Missing warp shuffle in fused_residual_rmsnorm\nPTX:\n{}",
        emitted
    );
}
/// Golden check that `FusedResidualRmsNormKernel::new(768)` configured with
/// `with_epsilon(1e-6)` still emits PTX containing the
/// `fused_residual_rmsnorm` entry point.
///
/// The failure message includes the emitted PTX, matching the other
/// single-kernel golden tests in this file, so a regression can be diagnosed
/// straight from the test output.
#[test]
fn golden_fused_residual_rmsnorm_custom_epsilon() {
    let kernel = FusedResidualRmsNormKernel::new(768).with_epsilon(1e-6);
    let ptx = kernel.emit_ptx();
    assert!(
        ptx.contains(".entry fused_residual_rmsnorm"),
        "Custom epsilon should still produce valid kernel\nPTX:\n{}",
        ptx
    );
}
/// Golden check for the PTX emitted by `FusedSwigluKernel::new(1024)`.
///
/// Requires the `fused_swiglu` entry point, the `ex2` substring (reported as
/// the exp instruction) and a `mul.f32` instruction.
#[test]
fn golden_fused_swiglu_kernel() {
    let fused_swiglu = FusedSwigluKernel::new(1024);
    let emitted = fused_swiglu.emit_ptx();
    let has = |needle: &str| emitted.contains(needle);

    assert!(
        has(".entry fused_swiglu"),
        "GOLDEN FAIL: Missing fused_swiglu entry\nPTX:\n{}",
        emitted
    );
    assert!(
        has("ex2"),
        "GOLDEN FAIL: Missing exp (ex2) in fused_swiglu\nPTX:\n{}",
        emitted
    );
    assert!(
        has("mul.f32"),
        "GOLDEN FAIL: Missing mul.f32 in fused_swiglu\nPTX:\n{}",
        emitted
    );
}
/// Golden check for the PTX emitted by `BatchedSwigluKernel::new(2048, 8)`.
///
/// Requires the `batched_swiglu` entry point and any `%ctaid` reference
/// (reported as the block index).
// NOTE(review): the meaning of the two constructor arguments (presumably
// element count and batch size) is not visible here; confirm against the kernel.
#[test]
fn golden_batched_swiglu_kernel() {
    let batched_swiglu = BatchedSwigluKernel::new(2048, 8);
    let emitted = batched_swiglu.emit_ptx();

    let has_entry = emitted.contains(".entry batched_swiglu");
    assert!(
        has_entry,
        "GOLDEN FAIL: Missing batched_swiglu entry\nPTX:\n{}",
        emitted
    );

    let has_block_index = emitted.contains("%ctaid");
    assert!(
        has_block_index,
        "GOLDEN FAIL: Missing block index in batched_swiglu\nPTX:\n{}",
        emitted
    );
}
/// Golden check for the PTX emitted by `RopeKernel::new(8, 64, 10000.0)`.
///
/// Requires the `rope` entry point and at least one of the substrings `sin`
/// or `ex2`.
// NOTE(review): the second argument is what the head-dim sweep test in this
// file varies; the first and third are presumably head count and rotary
// base; confirm against the kernel.
#[test]
fn golden_rope_kernel() {
    let rope = RopeKernel::new(8, 64, 10000.0);
    let emitted = rope.emit_ptx();

    if !emitted.contains(".entry rope") {
        panic!("GOLDEN FAIL: Missing rope entry\nPTX:\n{}", emitted);
    }

    // NOTE(review): the message says "trig functions" but `ex2` alone also
    // satisfies this check; confirm that is the intended acceptance rule.
    if !emitted.contains("sin") && !emitted.contains("ex2") {
        panic!(
            "GOLDEN FAIL: Missing trig functions in rope\nPTX:\n{}",
            emitted
        );
    }
}
/// Golden check for the PTX emitted by `TransposeKernel::new(64, 64)`.
///
/// Requires the `transpose` entry point and both `ld.global` and `st.global`.
#[test]
fn golden_transpose_kernel() {
    let transpose = TransposeKernel::new(64, 64);
    let emitted = transpose.emit_ptx();

    if !emitted.contains(".entry transpose") {
        panic!("GOLDEN FAIL: Missing transpose entry\nPTX:\n{}", emitted);
    }

    // Fails when either the global load or the global store is absent.
    if !emitted.contains("ld.global") || !emitted.contains("st.global") {
        panic!(
            "GOLDEN FAIL: Missing global memory ops in transpose\nPTX:\n{}",
            emitted
        );
    }
}
/// Golden check for the PTX emitted by `BatchedTransposeKernel::new(8, 64, 64)`.
///
/// Only the `batched_transpose` entry point is required.
// NOTE(review): the arguments are presumably (batch, rows, cols); the order
// is not visible here; confirm against the kernel.
#[test]
fn golden_batched_transpose_kernel() {
    let batched_transpose = BatchedTransposeKernel::new(8, 64, 64);
    let emitted = batched_transpose.emit_ptx();

    let has_entry = emitted.contains(".entry batched_transpose");
    assert!(
        has_entry,
        "GOLDEN FAIL: Missing batched_transpose entry\nPTX:\n{}",
        emitted
    );
}
/// Golden check for the PTX emitted by `BatchedScaleKernel::new(1024)`.
///
/// Requires the `batched_scale` entry point and a `mul.f32` instruction.
#[test]
fn golden_batched_scale_kernel() {
    let batched_scale = BatchedScaleKernel::new(1024);
    let emitted = batched_scale.emit_ptx();
    let has = |needle: &str| emitted.contains(needle);

    assert!(
        has(".entry batched_scale"),
        "GOLDEN FAIL: Missing batched_scale entry\nPTX:\n{}",
        emitted
    );
    assert!(
        has("mul.f32"),
        "GOLDEN FAIL: Missing mul.f32 in batched_scale\nPTX:\n{}",
        emitted
    );
}
/// Sweeps `SiluKernel::new(n)` over 256, 512, 1024, 2048 and 4096 and checks
/// that each emitted PTX contains the `silu` entry point.
///
/// The failure message reports the offending `n`; it does not include the PTX.
#[test]
fn golden_silu_various_sizes() {
    for &n in &[256, 512, 1024, 2048, 4096] {
        let emitted = {
            let silu = SiluKernel::new(n);
            silu.emit_ptx()
        };
        let has_entry = emitted.contains(".entry silu");
        assert!(has_entry, "SiLU n={} should generate valid kernel", n);
    }
}
/// Sweeps `GeluKernel::new(n)` over 256, 512, 1024 and 2048 and checks that
/// each emitted PTX contains the `gelu` entry point.
///
/// The failure message reports the offending `n`; it does not include the PTX.
#[test]
fn golden_gelu_various_sizes() {
    for &n in &[256, 512, 1024, 2048] {
        let emitted = {
            let gelu = GeluKernel::new(n);
            gelu.emit_ptx()
        };
        let has_entry = emitted.contains(".entry gelu");
        assert!(has_entry, "GELU n={} should generate valid kernel", n);
    }
}
/// Sweeps `FusedResidualRmsNormKernel::new(hidden)` over 256, 512, 768, 1024,
/// 2048 and 4096 and checks that each emitted PTX contains the
/// `fused_residual_rmsnorm` entry point.
///
/// The failure message reports the offending `hidden`; it does not include the PTX.
#[test]
fn golden_fused_residual_rmsnorm_various_hidden() {
    for &hidden in &[256, 512, 768, 1024, 2048, 4096] {
        let emitted = {
            let fused_norm = FusedResidualRmsNormKernel::new(hidden);
            fused_norm.emit_ptx()
        };
        let has_entry = emitted.contains(".entry fused_residual_rmsnorm");
        assert!(
            has_entry,
            "FusedResidualRmsNorm hidden={} should generate valid kernel",
            hidden
        );
    }
}
/// Sweeps `RopeKernel::new(8, head_dim, 10000.0)` over head dims 32, 64 and
/// 128 and checks that each emitted PTX contains the `rope` entry point.
///
/// The failure message reports the offending `head_dim`; it does not include the PTX.
#[test]
fn golden_rope_various_head_dims() {
    for &head_dim in &[32, 64, 128] {
        let emitted = {
            let rope = RopeKernel::new(8, head_dim, 10000.0);
            rope.emit_ptx()
        };
        let has_entry = emitted.contains(".entry rope");
        assert!(
            has_entry,
            "RoPE head_dim={} should generate valid kernel",
            head_dim
        );
    }
}
/// Pins the string returned by `name()` for each kernel listed below.
// NOTE(review): `BatchedTransposeKernel` and `BatchedScaleKernel` are imported
// by this file but have no name assertion here; confirm whether that is
// intentional.
#[test]
fn golden_elementwise_kernel_names() {
    // Compares `<kernel>.name()` (left) against the expected literal (right).
    macro_rules! expect_name {
        ($kernel:expr, $expected:expr) => {
            assert_eq!($kernel.name(), $expected)
        };
    }

    expect_name!(ResidualAddKernel::new(1024), "residual_add");
    expect_name!(BatchedResidualAddKernel::new(768, 8), "batched_residual_add");
    expect_name!(SiluKernel::new(1024), "silu");
    expect_name!(GeluKernel::new(1024), "gelu");
    expect_name!(ElementwiseMulKernel::new(1024), "elementwise_mul");
    expect_name!(ScaleKernel::new(1024), "scale");
    expect_name!(FusedResidualRmsNormKernel::new(768), "fused_residual_rmsnorm");
    expect_name!(FusedSwigluKernel::new(1024), "fused_swiglu");
    expect_name!(BatchedSwigluKernel::new(1024, 8), "batched_swiglu");
    expect_name!(RopeKernel::new(8, 64, 10000.0), "rope");
    expect_name!(TransposeKernel::new(64, 64), "transpose");
}