mod common;
/// Checks the PTX text emitted for the `vector_add` kernel.
///
/// All checks are plain substring matches over the string returned by
/// `common::build_vector_add_ptx()`. They cover:
/// - the module header;
/// - the `vector_add` entry signature (`a_ptr`, `b_ptr`, `c_ptr`, `n`);
/// - the register-class declarations;
/// - the parameter loads;
/// - the thread-index computation and its bounds guard;
/// - the global load/add/store sequence;
/// - the closing brace.
#[test]
fn emit_full_vector_add() {
    let ptx = common::build_vector_add_ptx();

    // Dump the kernel before asserting: the test harness only shows captured
    // stderr for failing tests, so printing first makes the full PTX visible
    // exactly when one of the checks below fails.
    eprintln!("=== KAIO vector_add PTX ===\n{ptx}");

    assert!(
        ptx.starts_with(".version 8.7\n"),
        "emitted PTX does not start with the `.version 8.7` header"
    );

    // Fragments that must appear somewhere in the output. Several of them name
    // exact registers (%rd0, %r0..%r4, %p0), so these checks also pin the
    // emitter's register numbering, not just the instruction selection.
    const EXPECTED: &[&str] = &[
        // Module directives.
        ".target sm_",
        ".address_size 64\n",
        // Entry point and its parameters.
        ".visible .entry vector_add(",
        ".param .u64 a_ptr,",
        ".param .u64 b_ptr,",
        ".param .u64 c_ptr,",
        ".param .u32 n",
        // Register-class declarations.
        ".reg .b32 %r<",
        ".reg .b64 %rd<",
        ".reg .f32 %f<",
        ".reg .pred %p<",
        // Parameter loads.
        "ld.param.u64 %rd0, [a_ptr];",
        "ld.param.u32 %r0, [n];",
        // Index = ctaid.x * ntid.x + tid.x.
        "mov.u32 %r1, %ctaid.x;",
        "mov.u32 %r2, %ntid.x;",
        "mov.u32 %r3, %tid.x;",
        "mad.lo.s32 %r4, %r1, %r2, %r3;",
        // Threads whose index is >= n branch straight to EXIT.
        "setp.ge.u32 %p0, %r4, %r0;",
        "@%p0 bra EXIT;",
        // Address computation and the element-wise add.
        "cvta.to.global.u64",
        "mul.wide.u32",
        "add.s64",
        "ld.global.f32",
        "add.f32",
        "st.global.f32",
        // Epilogue.
        "EXIT:",
        "ret;",
    ];
    for &fragment in EXPECTED {
        assert!(
            ptx.contains(fragment),
            "emitted vector_add PTX is missing `{fragment}`"
        );
    }

    assert!(
        ptx.trim_end().ends_with('}'),
        "emitted vector_add PTX does not end with a closing brace"
    );
}
/// Checks the PTX text emitted for the `shared_mem_test` kernel.
///
/// All checks are plain substring matches over the string returned by
/// `common::build_shared_mem_ptx()`. They cover:
/// - the `.shared` declaration of `sdata`;
/// - the shared load/store instructions;
/// - the `bar.sync 0` barrier;
/// - the `shfl.sync.down.b32` shuffle with a `0xFFFFFFFF` mask literal;
/// - the parameterless entry signature and the closing brace.
#[test]
fn emit_shared_mem_kernel() {
    let ptx = common::build_shared_mem_ptx();

    // Print first so the full PTX shows up in the harness output when any
    // check below fails; captured stderr of passing tests is discarded.
    eprintln!("=== KAIO shared_mem_test PTX ===\n{ptx}");

    const EXPECTED: &[&str] = &[
        ".shared .align 4 .b8 sdata[1024];",
        "st.shared.f32",
        "ld.shared.f32",
        "bar.sync 0;",
        "shfl.sync.down.b32",
        "0xFFFFFFFF",
        ".visible .entry shared_mem_test()",
    ];
    for &fragment in EXPECTED {
        assert!(
            ptx.contains(fragment),
            "emitted shared_mem_test PTX is missing `{fragment}`"
        );
    }

    assert!(
        ptx.trim_end().ends_with('}'),
        "emitted shared_mem_test PTX does not end with a closing brace"
    );
}