use super::*;
/// `SoftmaxKernel::new` reports the warp-shuffle variant name by default, and
/// `without_warp_shuffle()` switches the reported name to the shared-memory variant.
#[test]
fn test_softmax_kernel_name() {
    let warp_variant = SoftmaxKernel::new(4096);
    let shared_variant = SoftmaxKernel::new(4096).without_warp_shuffle();

    assert_eq!(warp_variant.name(), "softmax_warp_shuffle");
    assert_eq!(shared_variant.name(), "softmax_shared");
}
/// Checks the PTX emitted by `LongRowSoftmaxKernel` for a 1500-element row.
///
/// The text must contain:
/// - the kernel name;
/// - the three entry parameters;
/// - the max/sum/write loop labels;
/// - a `bar.sync` barrier;
/// - a warp shuffle.
#[test]
fn test_long_row_softmax_ptx_generation() {
    let kernel = LongRowSoftmaxKernel::new(1500);
    let ptx = kernel.emit_ptx();

    // Dump the listing BEFORE asserting. libtest only shows captured stdout
    // for failing tests, so a dump placed after the assertions is never seen.
    for (i, line) in ptx.lines().enumerate().take(300) {
        println!("{:4}: {}", i + 1, line);
    }

    let required = [
        ("softmax_long_row", "Missing kernel name"),
        (".param .u64 input_ptr", "Missing input_ptr param"),
        (".param .u64 output_ptr", "Missing output_ptr param"),
        (".param .u32 row_size", "Missing row_size param"),
        ("max_loop:", "Missing max_loop label"),
        ("max_loop_done:", "Missing max_loop_done label"),
        ("sum_loop:", "Missing sum_loop label"),
        ("write_loop:", "Missing write_loop label"),
        ("bar.sync", "Missing barrier sync"),
        // "shfl" is a substring of every shfl.* form (shfl.down, shfl.sync.down, ...),
        // so no separate alternates are needed.
        ("shfl", "Missing warp shuffle"),
    ];
    for (needle, message) in required {
        assert!(
            ptx.contains(needle),
            "{} (expected substring {:?})",
            message,
            needle
        );
    }
}
/// The default `SoftmaxKernel` PTX declares its input/output pointer parameters
/// and a `.u32 length` parameter.
#[test]
fn test_softmax_ptx_generation() {
    let kernel = SoftmaxKernel::new(4096);
    let ptx = kernel.emit_ptx();

    let expected_params = [
        ".param .u64 input_ptr",
        ".param .u64 output_ptr",
        ".param .u32 length",
    ];
    for param in expected_params {
        assert!(ptx.contains(param));
    }
}
/// The shared-memory variant (warp shuffle disabled) must report a non-zero
/// shared-memory footprint from its built PTX kernel.
#[test]
fn test_softmax_shared_memory() {
    let shared_variant = SoftmaxKernel::new(4096).without_warp_shuffle();
    let built = shared_variant.build_ptx();
    let footprint = built.shared_memory_bytes();
    assert!(footprint > 0);
}
/// Checks that the warp-shuffle softmax PTX (32 elements) contains every
/// instruction family this test expects:
/// - a warp shuffle;
/// - `max.f32`;
/// - an `ex2` exponent;
/// - `div.rn.f32`;
/// - f32 global load and store.
#[test]
fn test_softmax_warp_shuffle_ptx() {
    let kernel = SoftmaxKernel::new(32);
    let ptx = kernel.emit_ptx();

    // These are plain substring checks. "shfl" already matches every shfl.*
    // spelling and "ex2" every ex2.* spelling, so no alternates are listed.
    let required = [
        "shfl",
        "max.f32",
        "ex2",
        "div.rn.f32",
        "ld.global.f32",
        "st.global.f32",
    ];
    for needle in required {
        assert!(
            ptx.contains(needle),
            "warp-shuffle softmax PTX is missing {:?}",
            needle
        );
    }
}
/// Checks the PTX of the shared-memory softmax variant (256 elements,
/// warp shuffle disabled). It must contain:
/// - f32 loads and stores (shared-qualified or generic);
/// - a barrier instruction;
/// - an `ex2` exponent;
/// - `div.rn.f32`.
#[test]
fn test_softmax_shared_memory_ptx() {
    let kernel = SoftmaxKernel::new(256).without_warp_shuffle();
    let ptx = kernel.emit_ptx();

    // Generic `ld.f32` / `st.f32` are accepted too: generic addressing can
    // target shared memory.
    assert!(
        ptx.contains("ld.shared.f32") || ptx.contains("ld.f32"),
        "shared-memory softmax PTX has no f32 load (ld.shared.f32 / ld.f32)"
    );
    assert!(
        ptx.contains("st.shared.f32") || ptx.contains("st.f32"),
        "shared-memory softmax PTX has no f32 store (st.shared.f32 / st.f32)"
    );
    // A bare "bar" matches any text containing those letters. Require an actual
    // barrier instruction instead (`bar.sync`, `bar.cta.sync`, `barrier.*`, ...).
    // PTX identifiers cannot contain '.', so "bar." only matches an instruction.
    assert!(
        ptx.contains("bar.") || ptx.contains("barrier."),
        "shared-memory softmax PTX has no barrier instruction"
    );
    // "ex2" already covers "ex2.f32" and any other ex2.* spelling.
    assert!(
        ptx.contains("ex2"),
        "shared-memory softmax PTX has no ex2 instruction"
    );
    assert!(
        ptx.contains("div.rn.f32"),
        "shared-memory softmax PTX has no div.rn.f32 instruction"
    );
}
/// Both softmax variants emit non-empty PTX that contains their own kernel name:
/// - the warp-shuffle variant with 32 elements;
/// - the shared-memory variant with 256 elements.
#[test]
fn test_softmax_kernel_variants() {
    let warp_kernel = SoftmaxKernel::new(32);
    let warp_ptx = warp_kernel.emit_ptx();
    assert!(!warp_ptx.is_empty());
    assert!(warp_ptx.contains("softmax_warp_shuffle"));

    let shared_kernel = SoftmaxKernel::new(256).without_warp_shuffle();
    let shared_ptx = shared_kernel.emit_ptx();
    assert!(!shared_ptx.is_empty());
    assert!(shared_ptx.contains("softmax_shared"));
}
/// The warp-shuffle softmax PTX (32 elements) must contain both `sub.f32` and `mul.f32`.
///
/// NOTE(review): presumably these are the max-subtraction and the scaling applied
/// before `ex2`. This test only checks that the instructions are present somewhere.
#[test]
fn test_softmax_numerical_stability() {
    let kernel = SoftmaxKernel::new(32);
    let ptx = kernel.emit_ptx();

    let has_sub = ptx.contains("sub.f32");
    let has_mul = ptx.contains("mul.f32");
    assert!(has_sub);
    assert!(has_mul);
}
/// The shared-memory variant's max reduction must branch back to `max_reduce_loop`.
///
/// The branch may be written with a space or a tab after `bra`. Without that
/// back-edge the reduction body would execute only once.
#[test]
fn test_shared_max_reduce_loop_iterates() {
    let kernel = SoftmaxKernel::new(256).without_warp_shuffle();
    let ptx = kernel.emit_ptx();

    let has_loop_back =
        ptx.contains("bra max_reduce_loop") || ptx.contains("bra\tmax_reduce_loop");
    // The message reports only what was checked. The test does not know what
    // the emitter produced instead of the back-edge.
    assert!(
        has_loop_back,
        "FALSIFIED: Max-reduce loop does not branch back to loop start. \
         No 'bra max_reduce_loop' instruction was found in the emitted PTX. \
         Without a back-edge the max reduction runs only once, producing a wrong max."
    );
}
/// The shared-memory variant's PTX must contain an instruction that can halve
/// the reduction stride: `shr.b32`, `shr.u32` or `div.u32`.
///
/// The search covers the whole PTX text, not only the max-reduce loop body.
#[test]
fn test_shared_max_reduce_stride_halves() {
    let kernel = SoftmaxKernel::new(256).without_warp_shuffle();
    let ptx = kernel.emit_ptx();

    let halving_ops = ["shr.b32", "shr.u32", "div.u32"];
    let has_stride_update = halving_ops.iter().any(|&op| ptx.contains(op));
    assert!(
        has_stride_update,
        "FALSIFIED: Max-reduce stride is not halved. \
         Expected shr.b32, shr.u32 or div.u32 for stride = stride / 2. \
         Without this, tree reduction cannot work correctly."
    );
}
/// The shared-memory variant's sum reduction must form a complete loop. The PTX needs:
/// - the `sum_reduce_loop:` label;
/// - the `sum_reduce_done:` label;
/// - a `bra sum_reduce_loop` back-edge (space- or tab-separated).
///
/// The label checks include the trailing ':'. A plain "sum_reduce_loop" substring
/// is already satisfied by the branch text itself, so without the ':' it would
/// never prove the label is defined.
#[test]
fn test_shared_sum_reduce_implemented() {
    let kernel = SoftmaxKernel::new(256).without_warp_shuffle();
    let ptx = kernel.emit_ptx();

    let has_sum_loop = ptx.contains("sum_reduce_loop:");
    let has_sum_done = ptx.contains("sum_reduce_done:");
    let has_loop_back =
        ptx.contains("bra sum_reduce_loop") || ptx.contains("bra\tsum_reduce_loop");
    assert!(
        has_sum_loop && has_sum_done && has_loop_back,
        "FALSIFIED: Sum reduction loop structure is incomplete. \
         has_sum_loop={}, has_sum_done={}, has_loop_back={}. \
         A proper tree reduction needs a complete loop structure.",
        has_sum_loop,
        has_sum_done,
        has_loop_back
    );
}