#![cfg(feature = "cuda")]
use trueno_gpu::kernels::{
AttentionKernel, Batched4DGemmKernel, BatchedGemmKernel, BatchedIncrementalAttentionKernel,
BatchedQ4KGemvKernel, BatchedQ6KGemvKernel, ChunkedTiledQ4KGemvKernel, CoalescedGemvKernel,
CoalescedQ4KGemvKernel, CoalescedQ6KGemvKernel, Dp4aQ4KGemvKernel, FlashDecodingChunkKernel,
FlashDecodingReduceKernel, Fp16Q4KGemvKernel, FusedGateUpQ4KGemvKernel,
FusedRmsNormQ4KGemvKernel, GemmKernel, GemvKernel, IncrementalAttentionKernel, Kernel,
LongRowSoftmaxKernel, Lz4WarpCompressKernel, MultiWarpIncrementalAttentionKernel,
PackedDp4aQ4KQ8Kernel, Q4KGemvKernel, Q4KQ8DotKernel, Q5KGemvKernel, Q5KKernel, Q6KGemvKernel,
Q6KKernel, TiledQ4KGemvKernel, VectorizedQ4KGemvKernel,
};
/// Asserts that `ptx` contains the directives expected of an emitted PTX module.
///
/// Checks, in order, for the substrings `.version`, `.entry` and `.target`.
/// `kernel_name` is used only to label the failure message.
///
/// # Panics
///
/// Panics on the first missing directive. The `.version` failure additionally
/// prints a preview of at most the first 200 bytes of `ptx`, shortened to the
/// nearest UTF-8 character boundary so that building the preview can never
/// panic itself and hide the real diagnostic.
fn assert_valid_ptx(ptx: &str, kernel_name: &str) {
    // Longest prefix of `text` that is at most `max_bytes` long and ends on a
    // UTF-8 character boundary.
    fn preview(text: &str, max_bytes: usize) -> &str {
        let mut end = text.len().min(max_bytes);
        // Index 0 is always a char boundary, so this loop terminates.
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        &text[..end]
    }

    assert!(
        ptx.contains(".version"),
        "{}: Missing .version directive\nPTX (first 200):\n{}",
        kernel_name,
        preview(ptx, 200)
    );
    assert!(
        ptx.contains(".entry"),
        "{}: Missing .entry directive",
        kernel_name
    );
    assert!(
        ptx.contains(".target"),
        "{}: Missing .target directive",
        kernel_name
    );
}
/// Emits PTX from `Lz4WarpCompressKernel::new(4)` and checks that it passes
/// `assert_valid_ptx` and contains both `ld.global` and `st.global`.
#[test]
fn test_lz4_warp_compress_build_ptx() {
    let source = Lz4WarpCompressKernel::new(4).emit_ptx();
    assert_valid_ptx(&source, "lz4_warp_compress");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
    let has_stores = source.contains("st.global");
    assert!(has_stores, "Missing global stores");
}
/// Emits PTX from `CoalescedQ4KGemvKernel::new(1536, 8960)` and checks that it
/// passes `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_coalesced_q4k_gemv_build_ptx() {
    let source = CoalescedQ4KGemvKernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "coalesced_q4k_gemv");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `VectorizedQ4KGemvKernel::new(1536, 8960)` and checks that it
/// passes `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_vectorized_q4k_gemv_build_ptx() {
    let source = VectorizedQ4KGemvKernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "vectorized_q4k_gemv");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `IncrementalAttentionKernel::new(2048, 128, 32)` and checks
/// that it passes `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_incremental_attention_build_ptx() {
    let source = IncrementalAttentionKernel::new(2048, 128, 32).emit_ptx();
    assert_valid_ptx(&source, "incremental_attention");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `MultiWarpIncrementalAttentionKernel::new(2048, 128, 32, 8, 4)`
/// and checks that it passes `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_multi_warp_incremental_attention_build_ptx() {
    let source = MultiWarpIncrementalAttentionKernel::new(2048, 128, 32, 8, 4).emit_ptx();
    assert_valid_ptx(&source, "multi_warp_incremental_attention");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `BatchedIncrementalAttentionKernel::new(2048, 128, 32, 8, 4)`
/// and checks that it passes `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_batched_incremental_attention_build_ptx() {
    let source = BatchedIncrementalAttentionKernel::new(2048, 128, 32, 8, 4).emit_ptx();
    assert_valid_ptx(&source, "batched_incremental_attention");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `FlashDecodingChunkKernel::new(2048, 128, 32, 8, 4)` and
/// checks that it passes `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_flash_decoding_chunk_build_ptx() {
    let source = FlashDecodingChunkKernel::new(2048, 128, 32, 8, 4).emit_ptx();
    assert_valid_ptx(&source, "flash_decoding_chunk");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `FlashDecodingReduceKernel::new(128, 32, 4)` and checks that
/// it passes `assert_valid_ptx`.
#[test]
fn test_flash_decoding_reduce_build_ptx() {
    let source = FlashDecodingReduceKernel::new(128, 32, 4).emit_ptx();
    assert_valid_ptx(&source, "flash_decoding_reduce");
}
/// Emits PTX from `FusedRmsNormQ4KGemvKernel::new(1536, 8960)` and checks that
/// it passes `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_fused_rmsnorm_q4k_gemv_build_ptx() {
    let source = FusedRmsNormQ4KGemvKernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "fused_rmsnorm_q4k_gemv");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `FusedGateUpQ4KGemvKernel::new(1536, 8960)` and checks that
/// it passes `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_fused_gate_up_q4k_gemv_build_ptx() {
    let source = FusedGateUpQ4KGemvKernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "fused_gate_up_q4k_gemv");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `Q6KGemvKernel::new(1536, 8960)` and checks that it passes
/// `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_q6k_gemv_build_ptx() {
    let source = Q6KGemvKernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "q6k_gemv");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `CoalescedQ6KGemvKernel::new(1536, 8960)` and checks that it
/// passes `assert_valid_ptx`.
#[test]
fn test_coalesced_q6k_gemv_build_ptx() {
    let source = CoalescedQ6KGemvKernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "coalesced_q6k_gemv");
}
/// Emits PTX from `BatchedQ6KGemvKernel::new(4, 1536, 8960)` and checks that it
/// passes `assert_valid_ptx`.
#[test]
fn test_batched_q6k_gemv_build_ptx() {
    let source = BatchedQ6KGemvKernel::new(4, 1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "batched_q6k_gemv");
}
/// Emits PTX from `Q6KKernel::new(64, 64, 64)` and checks that it passes
/// `assert_valid_ptx`.
#[test]
fn test_q6k_kernel_build_ptx() {
    let source = Q6KKernel::new(64, 64, 64).emit_ptx();
    assert_valid_ptx(&source, "q6k_kernel");
}
/// Emits PTX from `GemmKernel::tensor_core(64, 64, 64)` and checks that it
/// passes `assert_valid_ptx` and contains the substring `fma`.
#[test]
fn test_gemm_tensor_core_build_ptx() {
    let source = GemmKernel::tensor_core(64, 64, 64).emit_ptx();
    assert_valid_ptx(&source, "gemm_tensor_core");

    let has_fma = source.contains("fma");
    assert!(has_fma, "Missing fma instructions");
}
/// Emits PTX from `GemmKernel::tiled_unrolled(64, 64, 64, 16)` and checks that
/// it passes `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_gemm_tiled_unrolled_build_ptx() {
    let source = GemmKernel::tiled_unrolled(64, 64, 64, 16).emit_ptx();
    assert_valid_ptx(&source, "gemm_tiled_unrolled");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `BatchedQ4KGemvKernel::new(4, 1536, 8960)` and checks that it
/// passes `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_batched_q4k_gemv_build_ptx() {
    let source = BatchedQ4KGemvKernel::new(4, 1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "batched_q4k_gemv");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `Fp16Q4KGemvKernel::new(1536, 8960)` and checks that it
/// passes `assert_valid_ptx`.
#[test]
fn test_fp16_q4k_gemv_build_ptx() {
    let source = Fp16Q4KGemvKernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "fp16_q4k_gemv");
}
/// Emits PTX from `Q4KGemvKernel::new(1536, 8960)` and checks that it passes
/// `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_q4k_gemv_build_ptx() {
    let source = Q4KGemvKernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "q4k_gemv");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `Dp4aQ4KGemvKernel::new(1536, 8960)` and checks that it
/// passes `assert_valid_ptx`.
#[test]
fn test_dp4a_q4k_gemv_build_ptx() {
    let source = Dp4aQ4KGemvKernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "dp4a_q4k_gemv");
}
/// Emits PTX from `TiledQ4KGemvKernel::new(1536, 4)` and checks that it passes
/// `assert_valid_ptx`.
#[test]
fn test_tiled_q4k_gemv_build_ptx() {
    let source = TiledQ4KGemvKernel::new(1536, 4).emit_ptx();
    assert_valid_ptx(&source, "tiled_q4k_gemv");
}
/// Emits PTX from `ChunkedTiledQ4KGemvKernel::new(1536, 4)` and checks that it
/// passes `assert_valid_ptx`.
#[test]
fn test_chunked_tiled_q4k_gemv_build_ptx() {
    let source = ChunkedTiledQ4KGemvKernel::new(1536, 4).emit_ptx();
    assert_valid_ptx(&source, "chunked_tiled_q4k_gemv");
}
/// Emits PTX from `AttentionKernel::new(128, 64)` with its fields left as
/// constructed, and checks that it passes `assert_valid_ptx` and contains
/// `ld.global`.
#[test]
fn test_flash_attention_build_ptx() {
    let source = AttentionKernel::new(128, 64).emit_ptx();
    assert_valid_ptx(&source, "flash_attention");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `AttentionKernel::new(128, 64)` after setting its `causal`
/// field to `true`, and checks that it passes `assert_valid_ptx`.
#[test]
fn test_flash_attention_causal_build_ptx() {
    let mut attention = AttentionKernel::new(128, 64);
    attention.causal = true;

    let source = attention.emit_ptx();
    assert_valid_ptx(&source, "flash_attention_causal");
}
/// Emits PTX from `AttentionKernel::new(128, 64)` after setting its
/// `use_tensor_cores` field to `true`, and checks that it passes
/// `assert_valid_ptx`.
#[test]
fn test_flash_attention_tensor_core_build_ptx() {
    let mut attention = AttentionKernel::new(128, 64);
    attention.use_tensor_cores = true;

    let source = attention.emit_ptx();
    assert_valid_ptx(&source, "flash_attention_tensor_core");
}
/// Emits PTX from `Q4KQ8DotKernel::new(1536, 8960)` and checks that it passes
/// `assert_valid_ptx`.
#[test]
fn test_q4k_q8_dot_build_ptx() {
    let source = Q4KQ8DotKernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "q4k_q8_dot");
}
/// Emits PTX from `PackedDp4aQ4KQ8Kernel::new(1536, 8960)` and checks that it
/// passes `assert_valid_ptx`.
#[test]
fn test_packed_dp4a_q4k_q8_build_ptx() {
    let source = PackedDp4aQ4KQ8Kernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "packed_dp4a_q4k_q8");
}
/// Emits PTX from `Q5KKernel::new(64, 64, 64)` and checks that it passes
/// `assert_valid_ptx`.
#[test]
fn test_q5k_kernel_build_ptx() {
    let source = Q5KKernel::new(64, 64, 64).emit_ptx();
    assert_valid_ptx(&source, "q5k_kernel");
}
/// Emits PTX from `Q5KGemvKernel::new(1536, 8960)` and checks that it passes
/// `assert_valid_ptx`.
#[test]
fn test_q5k_gemv_build_ptx() {
    let source = Q5KGemvKernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "q5k_gemv");
}
/// Emits PTX from `LongRowSoftmaxKernel::new(32768)` and checks that it passes
/// `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_long_row_softmax_build_ptx() {
    let source = LongRowSoftmaxKernel::new(32768).emit_ptx();
    assert_valid_ptx(&source, "long_row_softmax");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `CoalescedGemvKernel::new(1536, 8960)` and checks that it
/// passes `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_coalesced_gemv_build_ptx() {
    let source = CoalescedGemvKernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "coalesced_gemv");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `BatchedGemmKernel::tiled_unrolled(4, 64, 64, 64, 16)` and
/// checks that it passes `assert_valid_ptx`.
#[test]
fn test_batched_gemm_tiled_unrolled_build_ptx() {
    let source = BatchedGemmKernel::tiled_unrolled(4, 64, 64, 64, 16).emit_ptx();
    assert_valid_ptx(&source, "batched_gemm_tiled_unrolled");
}
/// Emits PTX from `Batched4DGemmKernel::new(2, 8, 64, 64, 64)` and checks that
/// it passes `assert_valid_ptx` and contains `ld.global`.
#[test]
fn test_batched_4d_gemm_build_ptx() {
    let source = Batched4DGemmKernel::new(2, 8, 64, 64, 64).emit_ptx();
    assert_valid_ptx(&source, "batched_4d_gemm");

    let has_loads = source.contains("ld.global");
    assert!(has_loads, "Missing global loads");
}
/// Emits PTX from `GemvKernel::new(1536, 8960)` and checks that it passes
/// `assert_valid_ptx`.
#[test]
fn test_gemv_basic_build_ptx() {
    let source = GemvKernel::new(1536, 8960).emit_ptx();
    assert_valid_ptx(&source, "gemv_basic");
}