use super::*;
use crate::arch::SmVersion;
use crate::builder::KernelBuilder;
fn build_with_body<F>(f: F) -> String
where
F: FnOnce(&mut BodyBuilder<'_>) + 'static,
{
KernelBuilder::new("test")
.target(SmVersion::Sm80)
.param("a", PtxType::U64)
.param("n", PtxType::U32)
.body(f)
.build()
.expect("build should succeed")
}
#[test]
fn global_thread_id_x_emits_mad() {
let ptx = build_with_body(|b| {
let _gid = b.global_thread_id_x();
b.ret();
});
assert!(ptx.contains("mov.u32"));
assert!(ptx.contains("%tid.x"));
assert!(ptx.contains("%ntid.x"));
assert!(ptx.contains("%ctaid.x"));
assert!(ptx.contains("mad.lo.u32"));
assert!(ptx.contains("ret;"));
}
#[test]
fn load_param_u32_emits_ld_param() {
let ptx = build_with_body(|b| {
let _n = b.load_param_u32("n");
b.ret();
});
assert!(ptx.contains("ld.param.u32"));
assert!(ptx.contains("[%param_n]"));
}
#[test]
fn if_lt_u32_emits_setp_and_branch() {
let ptx = build_with_body(|b| {
let tid = b.global_thread_id_x();
let n = b.load_param_u32("n");
b.if_lt_u32(tid, n, |b| {
b.comment("inside conditional");
});
b.ret();
});
assert!(ptx.contains("setp.lo.u32"));
assert!(ptx.contains("@!%p"));
assert!(ptx.contains("bra $L__skip_"));
assert!(ptx.contains("// inside conditional"));
}
#[test]
fn fma_f32_emits_correct_instruction() {
let ptx = build_with_body(|b| {
let a = b.load_param_f32("a");
let c = b.fma_f32(a.clone(), a.clone(), a);
let _ = c;
b.ret();
});
assert!(ptx.contains("fma.rn.f32"));
}
#[test]
fn bar_sync_emits_correctly() {
let ptx = build_with_body(|b| {
b.bar_sync(0);
b.ret();
});
assert!(ptx.contains("bar.sync 0;"));
}
#[test]
fn cvt_u32_to_u64_emits_correctly() {
let ptx = build_with_body(|b| {
let n = b.load_param_u32("n");
let _n64 = b.cvt_u32_to_u64(n);
b.ret();
});
assert!(ptx.contains("cvt.u64.u32"));
}
#[test]
fn unroll_emits_multiple_iterations() {
let ptx = build_with_body(|b| {
b.unroll(3, |b, i| {
b.comment(&format!("body {i}"));
});
b.ret();
});
assert!(ptx.contains("unroll iteration 0/3"));
assert!(ptx.contains("unroll iteration 1/3"));
assert!(ptx.contains("unroll iteration 2/3"));
assert!(ptx.contains("// body 0"));
assert!(ptx.contains("// body 1"));
assert!(ptx.contains("// body 2"));
}
#[test]
fn fresh_label_generates_unique_names() {
let mut regs = RegisterAllocator::new();
let mut instructions = Vec::new();
let params = vec![];
let mut bb = BodyBuilder::new(&mut regs, &mut instructions, ¶ms, SmVersion::Sm80);
let l1 = bb.fresh_label("loop");
let l2 = bb.fresh_label("loop");
let l3 = bb.fresh_label("exit");
assert_eq!(l1, "L__loop_0");
assert_eq!(l2, "L__loop_1");
assert_eq!(l3, "L__exit_2");
}
#[test]
fn comment_and_raw_ptx() {
let ptx = build_with_body(|b| {
b.comment("test comment");
b.raw_ptx("custom.instruction;");
b.ret();
});
assert!(ptx.contains("// test comment"));
assert!(ptx.contains("custom.instruction;"));
}
#[test]
fn load_param_f32_param_name() {
let ptx = KernelBuilder::new("f32_kernel")
.target(SmVersion::Sm80)
.param("alpha", PtxType::F32)
.body(|b| {
let _alpha = b.load_param_f32("alpha");
b.ret();
})
.build()
.expect("build should succeed");
assert!(ptx.contains("ld.param.f32"));
assert!(ptx.contains("[%param_alpha]"));
assert!(ptx.contains(".param .f32 %param_alpha"));
}
#[test]
fn store_global_f32_emits_st() {
let ptx = build_with_body(|b| {
let ptr = b.load_param_u64("a");
let val = b.load_param_f32("n");
b.store_global_f32(ptr, val);
b.ret();
});
assert!(ptx.contains("st.global.f32"));
}
#[test]
fn byte_offset_addr_computation() {
let ptx = build_with_body(|b| {
let base = b.load_param_u64("a");
let idx = b.load_param_u32("n");
let _addr = b.f32_elem_addr(base, idx);
b.ret();
});
assert!(ptx.contains("cvt.u64.u32"));
assert!(ptx.contains("mul.lo.u64"));
assert!(ptx.contains("add.u64"));
}
#[test]
fn brev_b32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let val = b.load_param_u32("n");
let _rev = b.brev_b32(val);
b.ret();
});
assert!(ptx.contains("brev.b32"), "expected brev.b32 in:\n{ptx}");
}
#[test]
fn clz_b32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let val = b.load_param_u32("n");
let _lz = b.clz_b32(val);
b.ret();
});
assert!(ptx.contains("clz.b32"), "expected clz.b32 in:\n{ptx}");
}
#[test]
fn popc_b32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let val = b.load_param_u32("n");
let _pc = b.popc_b32(val);
b.ret();
});
assert!(ptx.contains("popc.b32"), "expected popc.b32 in:\n{ptx}");
}
#[test]
fn bfind_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let val = b.load_param_u32("n");
let _msb = b.bfind_u32(val);
b.ret();
});
assert!(ptx.contains("bfind.u32"), "expected bfind.u32 in:\n{ptx}");
}
#[test]
fn bfind_s32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let val = b.load_param_u32("n");
let _msb = b.bfind_s32(val);
b.ret();
});
assert!(ptx.contains("bfind.s32"), "expected bfind.s32 in:\n{ptx}");
}
#[test]
fn bfe_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let val = b.load_param_u32("n");
let start = b.load_param_u32("n");
let len = b.load_param_u32("n");
let _field = b.bfe_u32(val, start, len);
b.ret();
});
assert!(ptx.contains("bfe.u32"), "expected bfe.u32 in:\n{ptx}");
}
#[test]
fn bfe_s32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let val = b.load_param_u32("n");
let start = b.load_param_u32("n");
let len = b.load_param_u32("n");
let _field = b.bfe_s32(val, start, len);
b.ret();
});
assert!(ptx.contains("bfe.s32"), "expected bfe.s32 in:\n{ptx}");
}
#[test]
fn bfi_b32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let insert = b.load_param_u32("n");
let base = b.load_param_u32("n");
let start = b.load_param_u32("n");
let len = b.load_param_u32("n");
let _result = b.bfi_b32(insert, base, start, len);
b.ret();
});
assert!(ptx.contains("bfi.b32"), "expected bfi.b32 in:\n{ptx}");
}
#[test]
fn rcp_f32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_f32("n");
let _r = b.rcp_f32(a);
b.ret();
});
assert!(ptx.contains("rcp.rn.f32"), "expected rcp.rn.f32 in:\n{ptx}");
}
#[test]
fn rcp_f64_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_f64("n");
let _r = b.rcp_f64(a);
b.ret();
});
assert!(ptx.contains("rcp.rn.f64"), "expected rcp.rn.f64 in:\n{ptx}");
}
#[test]
fn rcp_approx_f32_emits_no_rounding() {
let ptx = build_with_body(|b| {
let a = b.load_param_f32("n");
let _r = b.rcp_approx_f32(a);
b.ret();
});
assert!(ptx.contains("rcp.f32"), "expected rcp.f32 in:\n{ptx}");
assert!(
!ptx.contains("rcp.rn.f32"),
"should NOT have rounding mode in approx:\n{ptx}"
);
}
#[test]
fn rsqrt_approx_f32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_f32("n");
let _r = b.rsqrt_approx_f32(a);
b.ret();
});
assert!(
ptx.contains("rsqrt.approx.f32"),
"expected rsqrt.approx.f32 in:\n{ptx}"
);
}
#[test]
fn sqrt_rn_f32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_f32("n");
let _r = b.sqrt_rn_f32(a);
b.ret();
});
assert!(
ptx.contains("sqrt.rn.f32"),
"expected sqrt.rn.f32 in:\n{ptx}"
);
}
#[test]
fn sqrt_rn_f64_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_f64("n");
let _r = b.sqrt_rn_f64(a);
b.ret();
});
assert!(
ptx.contains("sqrt.rn.f64"),
"expected sqrt.rn.f64 in:\n{ptx}"
);
}
#[test]
fn ex2_approx_f32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_f32("n");
let _r = b.ex2_approx_f32(a);
b.ret();
});
assert!(
ptx.contains("ex2.approx.f32"),
"expected ex2.approx.f32 in:\n{ptx}"
);
}
#[test]
fn lg2_approx_f32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_f32("n");
let _r = b.lg2_approx_f32(a);
b.ret();
});
assert!(
ptx.contains("lg2.approx.f32"),
"expected lg2.approx.f32 in:\n{ptx}"
);
}
#[test]
fn sin_approx_f32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_f32("n");
let _r = b.sin_approx_f32(a);
b.ret();
});
assert!(
ptx.contains("sin.approx.f32"),
"expected sin.approx.f32 in:\n{ptx}"
);
}
#[test]
fn cos_approx_f32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_f32("n");
let _r = b.cos_approx_f32(a);
b.ret();
});
assert!(
ptx.contains("cos.approx.f32"),
"expected cos.approx.f32 in:\n{ptx}"
);
}
#[test]
fn rsqrt_approx_f64_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_f64("n");
let _r = b.rsqrt_approx_f64(a);
b.ret();
});
assert!(
ptx.contains("rsqrt.approx.f64"),
"expected rsqrt.approx.f64 in:\n{ptx}"
);
}
#[test]
fn brev_b64_via_builder() {
let ptx = build_with_body(|b| {
let val = b.load_param_u64("a");
let _rev = b.brev_b64(val);
b.ret();
});
assert!(ptx.contains("brev.b64"), "expected brev.b64 in:\n{ptx}");
}
#[test]
fn popc_b64_via_builder() {
let ptx = build_with_body(|b| {
let val = b.load_param_u64("a");
let _pc = b.popc_b64(val);
b.ret();
});
assert!(ptx.contains("popc.b64"), "expected popc.b64 in:\n{ptx}");
}
#[test]
fn atom_global_add_f32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u64("a");
let val = b.load_param_f32("n");
let _old = b.atom_global_add_f32(addr, val);
b.ret();
});
assert!(
ptx.contains("atom.global.add.f32"),
"expected atom.global.add.f32 in:\n{ptx}"
);
}
#[test]
fn atom_global_add_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u64("a");
let val = b.load_param_u32("n");
let _old = b.atom_global_add_u32(addr, val);
b.ret();
});
assert!(
ptx.contains("atom.global.add.u32"),
"expected atom.global.add.u32 in:\n{ptx}"
);
}
#[test]
fn atom_global_add_u64_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u64("a");
let val = b.load_param_u64("n");
let _old = b.atom_global_add_u64(addr, val);
b.ret();
});
assert!(
ptx.contains("atom.global.add.u64"),
"expected atom.global.add.u64 in:\n{ptx}"
);
}
#[test]
fn atom_global_add_f64_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u64("a");
let val = b.load_param_f64("n");
let _old = b.atom_global_add_f64(addr, val);
b.ret();
});
assert!(
ptx.contains("atom.global.add.f64"),
"expected atom.global.add.f64 in:\n{ptx}"
);
}
#[test]
fn atom_global_cas_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u64("a");
let cmp_val = b.load_param_u32("n");
let new_val = b.load_param_u32("n");
let _old = b.atom_global_cas_u32(addr, cmp_val, new_val);
b.ret();
});
assert!(
ptx.contains("atom.global.cas.u32"),
"expected atom.global.cas.u32 in:\n{ptx}"
);
}
#[test]
fn atom_global_cas_u64_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u64("a");
let cmp_val = b.load_param_u64("n");
let new_val = b.load_param_u64("n");
let _old = b.atom_global_cas_u64(addr, cmp_val, new_val);
b.ret();
});
assert!(
ptx.contains("atom.global.cas.u64"),
"expected atom.global.cas.u64 in:\n{ptx}"
);
}
#[test]
fn atom_global_exch_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u64("a");
let val = b.load_param_u32("n");
let _old = b.atom_global_exch_u32(addr, val);
b.ret();
});
assert!(
ptx.contains("atom.global.exch.u32"),
"expected atom.global.exch.u32 in:\n{ptx}"
);
}
#[test]
fn atom_global_min_max_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u64("a");
let val = b.load_param_u32("n");
let _min = b.atom_global_min_u32(addr.clone(), val.clone());
let _max = b.atom_global_max_u32(addr, val);
b.ret();
});
assert!(
ptx.contains("atom.global.min.u32"),
"expected atom.global.min.u32 in:\n{ptx}"
);
assert!(
ptx.contains("atom.global.max.u32"),
"expected atom.global.max.u32 in:\n{ptx}"
);
}
#[test]
fn atom_global_min_max_s32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u64("a");
let val = b.load_param_u32("n");
let _min = b.atom_global_min_s32(addr.clone(), val.clone());
let _max = b.atom_global_max_s32(addr, val);
b.ret();
});
assert!(
ptx.contains("atom.global.min.s32"),
"expected atom.global.min.s32 in:\n{ptx}"
);
assert!(
ptx.contains("atom.global.max.s32"),
"expected atom.global.max.s32 in:\n{ptx}"
);
}
#[test]
fn atom_global_bitwise_ops_emit_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u64("a");
let val = b.load_param_u32("n");
let _and = b.atom_global_and_b32(addr.clone(), val.clone());
let _or = b.atom_global_or_b32(addr.clone(), val.clone());
let _xor = b.atom_global_xor_b32(addr, val);
b.ret();
});
assert!(
ptx.contains("atom.global.and.b32"),
"expected atom.global.and.b32 in:\n{ptx}"
);
assert!(
ptx.contains("atom.global.or.b32"),
"expected atom.global.or.b32 in:\n{ptx}"
);
assert!(
ptx.contains("atom.global.xor.b32"),
"expected atom.global.xor.b32 in:\n{ptx}"
);
}
#[test]
fn atom_shared_add_f32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u32("n");
let val = b.load_param_f32("n");
let _old = b.atom_shared_add_f32(addr, val);
b.ret();
});
assert!(
ptx.contains("atom.shared.add.f32"),
"expected atom.shared.add.f32 in:\n{ptx}"
);
}
#[test]
fn atom_shared_add_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u32("n");
let val = b.load_param_u32("n");
let _old = b.atom_shared_add_u32(addr, val);
b.ret();
});
assert!(
ptx.contains("atom.shared.add.u32"),
"expected atom.shared.add.u32 in:\n{ptx}"
);
}
#[test]
fn red_global_add_f32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u64("a");
let val = b.load_param_f32("n");
b.red_global_add_f32(addr, val);
b.ret();
});
assert!(
ptx.contains("red.global.add.f32"),
"expected red.global.add.f32 in:\n{ptx}"
);
}
#[test]
fn red_global_add_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let addr = b.load_param_u64("a");
let val = b.load_param_u32("n");
b.red_global_add_u32(addr, val);
b.ret();
});
assert!(
ptx.contains("red.global.add.u32"),
"expected red.global.add.u32 in:\n{ptx}"
);
}
#[test]
fn mad_lo_s32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let c = b.mad_lo_s32(a.clone(), a.clone(), a);
let _ = c;
b.ret();
});
assert!(ptx.contains("mad.lo.s32"), "expected mad.lo.s32 in:\n{ptx}");
}
#[test]
fn mad_lo_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let c = b.mad_lo_u32(a.clone(), a.clone(), a);
let _ = c;
b.ret();
});
assert!(ptx.contains("mad.lo.u32"), "expected mad.lo.u32 in:\n{ptx}");
}
#[test]
fn mad_lo_s64_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u64("a");
let c = b.mad_lo_s64(a.clone(), a.clone(), a);
let _ = c;
b.ret();
});
assert!(ptx.contains("mad.lo.s64"), "expected mad.lo.s64 in:\n{ptx}");
}
#[test]
fn mad_lo_u64_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u64("a");
let c = b.mad_lo_u64(a.clone(), a.clone(), a);
let _ = c;
b.ret();
});
assert!(ptx.contains("mad.lo.u64"), "expected mad.lo.u64 in:\n{ptx}");
}
#[test]
fn mad_hi_s32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let c = b.mad_hi_s32(a.clone(), a.clone(), a);
let _ = c;
b.ret();
});
assert!(ptx.contains("mad.hi.s32"), "expected mad.hi.s32 in:\n{ptx}");
}
#[test]
fn mad_hi_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let c = b.mad_hi_u32(a.clone(), a.clone(), a);
let _ = c;
b.ret();
});
assert!(ptx.contains("mad.hi.u32"), "expected mad.hi.u32 in:\n{ptx}");
}
#[test]
fn mad_hi_s64_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u64("a");
let c = b.mad_hi_s64(a.clone(), a.clone(), a);
let _ = c;
b.ret();
});
assert!(ptx.contains("mad.hi.s64"), "expected mad.hi.s64 in:\n{ptx}");
}
#[test]
fn mad_hi_u64_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u64("a");
let c = b.mad_hi_u64(a.clone(), a.clone(), a);
let _ = c;
b.ret();
});
assert!(ptx.contains("mad.hi.u64"), "expected mad.hi.u64 in:\n{ptx}");
}
#[test]
fn mad_wide_s32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let c = b.load_param_u64("a");
let r = b.mad_wide_s32(a.clone(), a, c);
let _ = r;
b.ret();
});
assert!(
ptx.contains("mad.wide.s32"),
"expected mad.wide.s32 in:\n{ptx}"
);
}
#[test]
fn mad_wide_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let c = b.load_param_u64("a");
let r = b.mad_wide_u32(a.clone(), a, c);
let _ = r;
b.ret();
});
assert!(
ptx.contains("mad.wide.u32"),
"expected mad.wide.u32 in:\n{ptx}"
);
}
#[test]
fn mad_wide_s16_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let r = b.mad_wide_s16(a.clone(), a.clone(), a);
let _ = r;
b.ret();
});
assert!(
ptx.contains("mad.wide.s16"),
"expected mad.wide.s16 in:\n{ptx}"
);
}
#[test]
fn mad_wide_u16_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let r = b.mad_wide_u16(a.clone(), a.clone(), a);
let _ = r;
b.ret();
});
assert!(
ptx.contains("mad.wide.u16"),
"expected mad.wide.u16 in:\n{ptx}"
);
}
#[test]
fn dp4a_u32_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let c = b.load_param_u32("n");
let _r = b.dp4a_u32_u32(a.clone(), a, c);
b.ret();
});
assert!(
ptx.contains("dp4a.u32.u32"),
"expected dp4a.u32.u32 in:\n{ptx}"
);
}
#[test]
fn dp4a_s32_s32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let _r = b.dp4a_s32_s32(a.clone(), a.clone(), a);
b.ret();
});
assert!(
ptx.contains("dp4a.s32.s32"),
"expected dp4a.s32.s32 in:\n{ptx}"
);
}
#[test]
fn dp4a_s32_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let _r = b.dp4a_s32_u32(a.clone(), a.clone(), a);
b.ret();
});
assert!(
ptx.contains("dp4a.s32.u32"),
"expected dp4a.s32.u32 in:\n{ptx}"
);
}
#[test]
fn dp4a_u32_s32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let _r = b.dp4a_u32_s32(a.clone(), a.clone(), a);
b.ret();
});
assert!(
ptx.contains("dp4a.u32.s32"),
"expected dp4a.u32.s32 in:\n{ptx}"
);
}
#[test]
fn dp4a_result_is_s32_register() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let r = b.dp4a_u32_u32(a.clone(), a.clone(), a);
assert_eq!(r.ty, PtxType::S32);
b.ret();
});
assert!(ptx.contains("dp4a"));
}
#[test]
fn dp4a_chained_accumulate() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let acc = b.dp4a_u32_u32(a.clone(), a.clone(), a.clone());
let _acc2 = b.dp4a_u32_u32(a.clone(), a, acc);
b.ret();
});
let count = ptx.matches("dp4a.u32.u32").count();
assert_eq!(
count, 2,
"expected 2 dp4a instructions, got {count} in:\n{ptx}"
);
}
#[test]
fn dp4a_with_immediate_not_supported_uses_registers() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let c = b.load_param_u32("n");
let r = b.dp4a_s32_s32(a.clone(), a, c);
let _ = r;
b.ret();
});
assert!(ptx.contains("dp4a.s32.s32"), "expected dp4a in:\n{ptx}");
assert!(
ptx.contains("%r"),
"expected register references in:\n{ptx}"
);
}
#[test]
fn dp4a_mixed_sign_all_variants() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let _r1 = b.dp4a_u32_u32(a.clone(), a.clone(), a.clone());
let _r2 = b.dp4a_s32_s32(a.clone(), a.clone(), a.clone());
let _r3 = b.dp4a_s32_u32(a.clone(), a.clone(), a.clone());
let _r4 = b.dp4a_u32_s32(a.clone(), a.clone(), a);
b.ret();
});
assert!(ptx.contains("dp4a.u32.u32"), "missing u32.u32 in:\n{ptx}");
assert!(ptx.contains("dp4a.s32.s32"), "missing s32.s32 in:\n{ptx}");
assert!(ptx.contains("dp4a.s32.u32"), "missing s32.u32 in:\n{ptx}");
assert!(ptx.contains("dp4a.u32.s32"), "missing u32.s32 in:\n{ptx}");
}
#[test]
fn dp2a_lo_u32_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let _r = b.dp2a_lo_u32_u32(a.clone(), a.clone(), a);
b.ret();
});
assert!(
ptx.contains("dp2a.lo.u32.u32"),
"expected dp2a.lo.u32.u32 in:\n{ptx}"
);
}
#[test]
fn dp2a_hi_u32_u32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let _r = b.dp2a_hi_u32_u32(a.clone(), a.clone(), a);
b.ret();
});
assert!(
ptx.contains("dp2a.hi.u32.u32"),
"expected dp2a.hi.u32.u32 in:\n{ptx}"
);
}
#[test]
fn dp2a_lo_s32_s32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let _r = b.dp2a_lo_s32_s32(a.clone(), a.clone(), a);
b.ret();
});
assert!(
ptx.contains("dp2a.lo.s32.s32"),
"expected dp2a.lo.s32.s32 in:\n{ptx}"
);
}
#[test]
fn dp2a_hi_s32_s32_emits_correct_ptx() {
let ptx = build_with_body(|b| {
let a = b.load_param_u32("n");
let _r = b.dp2a_hi_s32_s32(a.clone(), a.clone(), a);
b.ret();
});
assert!(
ptx.contains("dp2a.hi.s32.s32"),
"expected dp2a.hi.s32.s32 in:\n{ptx}"
);
}
#[test]
fn unroll_zero_emits_nothing() {
let ptx = build_with_body(|b| {
b.unroll(0, |_b, _i| {
panic!("should not be called");
});
b.ret();
});
assert!(
!ptx.contains("unroll iteration"),
"no iterations expected in:\n{ptx}"
);
}
#[test]
fn unroll_single_iteration() {
let ptx = build_with_body(|b| {
b.unroll(1, |b, i| {
b.comment(&format!("iter {i}"));
});
b.ret();
});
assert!(
ptx.contains("unroll iteration 0/1"),
"expected 0/1 in:\n{ptx}"
);
assert!(ptx.contains("// iter 0"), "expected iter 0 in:\n{ptx}");
assert!(
!ptx.contains("unroll iteration 1/1"),
"no 1/1 expected in:\n{ptx}"
);
}
#[test]
fn unroll_iteration_index_passed_correctly() {
let ptx = build_with_body(|b| {
b.unroll(4, |b, i| {
b.comment(&format!("idx={i}"));
});
b.ret();
});
for i in 0..4 {
assert!(
ptx.contains(&format!("// idx={i}")),
"expected idx={i} in:\n{ptx}"
);
}
}
#[test]
fn unroll_generates_distinct_registers() {
let ptx = build_with_body(|b| {
let base = b.load_param_u32("n");
b.unroll(3, |b, _i| {
let _r = b.add_u32(base.clone(), base.clone());
});
b.ret();
});
let add_count = ptx.matches("add.u32").count();
assert_eq!(
add_count, 3,
"expected 3 add.u32 instructions, got {add_count} in:\n{ptx}"
);
}
#[test]
fn pragma_unroll_with_factor_emits_directive() {
let ptx = build_with_body(|b| {
b.pragma_unroll(Some(8));
b.ret();
});
assert!(
ptx.contains(".pragma \"unroll 8\";"),
"expected .pragma \"unroll 8\"; in:\n{ptx}"
);
}
#[test]
fn pragma_unroll_none_emits_nounroll() {
let ptx = build_with_body(|b| {
b.pragma_unroll(None);
b.ret();
});
assert!(
ptx.contains(".pragma \"nounroll\";"),
"expected .pragma \"nounroll\"; in:\n{ptx}"
);
}
#[test]
fn pragma_unroll_factor_one() {
let ptx = build_with_body(|b| {
b.pragma_unroll(Some(1));
b.ret();
});
assert!(
ptx.contains(".pragma \"unroll 1\";"),
"expected .pragma \"unroll 1\"; in:\n{ptx}"
);
}
#[test]
fn pragma_unroll_before_loop_pattern() {
let ptx = build_with_body(|b| {
b.pragma_unroll(Some(4));
let start = b.fresh_label("loop");
let end = b.fresh_label("end");
b.label(&start);
b.comment("loop body");
b.branch(&end);
b.branch(&start);
b.label(&end);
b.ret();
});
assert!(
ptx.contains(".pragma \"unroll 4\";"),
"expected pragma before loop in:\n{ptx}"
);
assert!(
ptx.contains("bra $L__loop_0;"),
"expected loop branch in:\n{ptx}"
);
}
fn build_with_body_sm90<F>(f: F) -> String
where
F: FnOnce(&mut BodyBuilder<'_>) + 'static,
{
KernelBuilder::new("test")
.target(SmVersion::Sm90)
.param("a", PtxType::U64)
.param("n", PtxType::U32)
.body(f)
.build()
.expect("build should succeed")
}
#[test]
fn redux_add_emits_on_sm80() {
let ptx = build_with_body(|b| {
let n = b.load_param_u32("n");
let _result = b.redux_add_u32(&n.name);
b.ret();
});
assert!(
ptx.contains("redux.sync.add.u32"),
"expected redux instruction in:\n{ptx}"
);
}
#[test]
fn redux_fails_on_sm75() {
use crate::error::PtxGenError;
let mut regs = crate::ir::RegisterAllocator::new();
let mut instructions = Vec::new();
let params = vec!["a".to_string()];
let mut bb = BodyBuilder::new(&mut regs, &mut instructions, ¶ms, SmVersion::Sm75);
let result = bb.redux_add_u32("%r0");
assert!(result.is_err());
if let Err(PtxGenError::GenerationFailed(msg)) = result {
assert!(
msg.contains("SM >= 80"),
"error should mention SM >= 80: {msg}"
);
}
}
#[test]
fn stmatrix_emits_on_sm90() {
let ptx = build_with_body_sm90(|b| {
let _r = b.stmatrix_m8n8x4("%r0", "%r1");
b.ret();
});
assert!(
ptx.contains("stmatrix.sync.aligned.m8n8.x4"),
"expected stmatrix in:\n{ptx}"
);
}
#[test]
fn elect_sync_emits_on_sm90() {
let ptx = build_with_body_sm90(|b| {
let _pred = b.elect_sync();
b.ret();
});
assert!(ptx.contains("elect.sync"), "expected elect.sync in:\n{ptx}");
}
#[test]
fn setmaxnreg_emits_on_sm90() {
let ptx = build_with_body_sm90(|b| {
let _r = b.setmaxnreg_inc(128);
let _r2 = b.setmaxnreg_dec(64);
b.ret();
});
assert!(
ptx.contains("setmaxnreg.inc 128;"),
"expected setmaxnreg.inc in:\n{ptx}"
);
assert!(
ptx.contains("setmaxnreg.dec 64;"),
"expected setmaxnreg.dec in:\n{ptx}"
);
}
#[test]
fn griddepcontrol_emits_on_sm90() {
let ptx = build_with_body_sm90(|b| {
let _r = b.griddepcontrol_launch_dependents();
let _r2 = b.griddepcontrol_wait();
b.ret();
});
assert!(
ptx.contains("griddepcontrol.launch_dependents;"),
"expected griddepcontrol in:\n{ptx}"
);
assert!(
ptx.contains("griddepcontrol.wait;"),
"expected griddepcontrol.wait in:\n{ptx}"
);
}
#[test]
fn fence_proxy_emits() {
let ptx = build_with_body_sm90(|b| {
let _r = b.fence_proxy_async("gpu");
b.ret();
});
assert!(
ptx.contains("fence.proxy.async.gpu.shared;"),
"expected fence.proxy in:\n{ptx}"
);
}
#[test]
fn mbarrier_lifecycle_emits_on_sm90() {
let ptx = build_with_body_sm90(|b| {
let _r = b.mbarrier_init("%r0", "%r1");
let _r2 = b.mbarrier_arrive("%r0");
let _r3 = b.mbarrier_wait("%r0", "%r2");
b.ret();
});
assert!(
ptx.contains("mbarrier.init.shared.b64"),
"expected mbarrier.init in:\n{ptx}"
);
assert!(
ptx.contains("mbarrier.arrive.shared.b64"),
"expected mbarrier.arrive in:\n{ptx}"
);
assert!(
ptx.contains("mbarrier.try_wait.parity.shared.b64"),
"expected mbarrier.wait in:\n{ptx}"
);
}
#[test]
fn round_trip_kernel_with_ptx8_instructions() {
let ptx = build_with_body_sm90(|b| {
let n = b.load_param_u32("n");
let _redux_result = b.redux_add_u32(&n.name);
let _pred = b.elect_sync();
let _r = b.griddepcontrol_launch_dependents();
let _r2 = b.fence_proxy_async("cta");
let _r3 = b.setmaxnreg_inc(96);
b.ret();
});
assert!(ptx.contains(".version"), "should have PTX version header");
assert!(ptx.contains(".target"), "should have target directive");
assert!(ptx.contains("redux.sync.add.u32"), "should have redux");
assert!(ptx.contains("elect.sync"), "should have elect.sync");
assert!(
ptx.contains("griddepcontrol.launch_dependents"),
"should have griddepcontrol"
);
assert!(
ptx.contains("fence.proxy.async.cta.shared"),
"should have fence.proxy"
);
assert!(ptx.contains("setmaxnreg.inc 96"), "should have setmaxnreg");
assert!(ptx.contains("ret;"), "should have ret");
}
#[test]
fn cvt_f32_to_e4m3_emits_correct_type() {
let ptx = build_with_body(|b| {
let src = b.load_param_f32("a");
let _dst = b.cvt_f32_to_e4m3(src);
b.ret();
});
assert!(
ptx.contains(".e4m3"),
"expected '.e4m3' type in PTX:\n{ptx}"
);
assert!(
ptx.contains("cvt.rn"),
"expected round-to-nearest mode in FP8 cvt:\n{ptx}"
);
}
#[test]
fn cvt_e4m3_to_f32_emits_correct_instruction() {
let ptx = build_with_body(|b| {
let e4m3_reg = b.alloc_reg(PtxType::E4M3);
let _f32_reg = b.cvt_e4m3_to_f32(e4m3_reg);
b.ret();
});
assert!(
ptx.contains("cvt"),
"expected 'cvt' instruction in PTX:\n{ptx}"
);
assert!(
ptx.contains(".e4m3"),
"expected '.e4m3' source type in PTX:\n{ptx}"
);
assert!(
ptx.contains(".f32"),
"expected '.f32' destination type in PTX:\n{ptx}"
);
}
#[test]
fn cvt_f32_to_e5m2_emits_correct_type() {
let ptx = build_with_body(|b| {
let src = b.load_param_f32("a");
let _dst = b.cvt_f32_to_e5m2(src);
b.ret();
});
assert!(
ptx.contains(".e5m2"),
"expected '.e5m2' type in PTX:\n{ptx}"
);
assert!(
ptx.contains("cvt.rn"),
"expected round-to-nearest mode in FP8 cvt:\n{ptx}"
);
}
#[test]
fn cvt_e5m2_to_f32_emits_round_trip_types() {
let ptx = build_with_body(|b| {
let src_f32 = b.load_param_f32("a");
let e5m2_reg = b.cvt_f32_to_e5m2(src_f32);
let _back = b.cvt_e5m2_to_f32(e5m2_reg);
b.ret();
});
assert!(
ptx.contains(".e5m2"),
"expected '.e5m2' in round-trip PTX:\n{ptx}"
);
assert!(
ptx.contains(".f32"),
"expected '.f32' in round-trip PTX:\n{ptx}"
);
}
#[test]
fn cvt_bf16_to_f32_emits_correctly() {
let ptx = build_with_body(|b| {
let bf16_reg = b.alloc_reg(PtxType::BF16);
let _f32_reg = b.cvt_bf16_to_f32(bf16_reg);
b.ret();
});
assert!(
ptx.contains(".bf16"),
"expected '.bf16' source type in PTX:\n{ptx}"
);
assert!(
ptx.contains(".f32"),
"expected '.f32' destination in PTX:\n{ptx}"
);
}
#[test]
fn cvt_f32_to_bf16_emits_rn_mode() {
let ptx = build_with_body(|b| {
let src = b.load_param_f32("a");
let _dst = b.cvt_f32_to_bf16(src);
b.ret();
});
assert!(
ptx.contains("cvt.rn"),
"expected round-to-nearest in bf16 cvt:\n{ptx}"
);
assert!(
ptx.contains(".bf16"),
"expected '.bf16' destination type in PTX:\n{ptx}"
);
}
#[test]
fn wgmma_mma_async_emits_on_sm90() {
let ptx = build_with_body_sm90(|b| {
let _r = b.wgmma_mma_async_m64n128k16_f16("%desc_a", "%desc_b");
b.ret();
});
assert!(
ptx.contains("wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16"),
"expected wgmma instruction in PTX:\n{ptx}"
);
}
#[test]
fn wgmma_mma_async_rejected_below_sm90() {
let result = KernelBuilder::new("test_wgmma_sm80")
.target(SmVersion::Sm80)
.param("a", PtxType::U64)
.param("n", PtxType::U32)
.body(|b| {
let r = b.wgmma_mma_async_m64n128k16_f16("%desc_a", "%desc_b");
assert!(r.is_err(), "wgmma should error on sm_80");
b.ret();
})
.build();
let _ = result;
}
fn build_with_body_sm100<F>(f: F) -> String
where
F: FnOnce(&mut BodyBuilder<'_>) + 'static,
{
KernelBuilder::new("test_sm100")
.target(SmVersion::Sm100)
.param("a", PtxType::U64)
.param("n", PtxType::U32)
.body(f)
.build()
.expect("build should succeed on sm_100")
}
#[test]
fn test_fp4_e2m1_type() {
assert_eq!(PtxType::E2M1.bit_width(), 4);
assert!(PtxType::E2M1.is_float());
assert_eq!(format!("{}", PtxType::E2M1), "e2m1");
assert_eq!(PtxType::E2M1.as_ptx_str(), ".e2m1");
}
#[test]
fn test_tcgen05_requires_sm100() {
let caps_sm90 = SmVersion::Sm90.capabilities();
assert!(!caps_sm90.has_fp6_fp4, "SM 90 should not have FP4 support");
let caps_sm100 = SmVersion::Sm100.capabilities();
assert!(
caps_sm100.has_fp6_fp4,
"SM 100 should have FP4/tcgen05 support"
);
assert!(SmVersion::Sm100 >= SmVersion::Sm100, "SM version ordering");
}
#[test]
fn test_cvt_f32_to_e2m1_emits_on_sm100() {
let ptx = build_with_body_sm100(|b| {
let src = b.alloc_reg(PtxType::F32);
let _dst = b
.cvt_f32_to_e2m1(src)
.expect("cvt_f32_to_e2m1 should succeed on sm_100");
b.ret();
});
assert!(
ptx.contains("cvt.rn.e2m1.f32"),
"expected cvt.rn.e2m1.f32 in PTX:\n{ptx}"
);
}
#[test]
fn test_cvt_e2m1_to_f32_emits_on_sm100() {
let ptx = build_with_body_sm100(|b| {
let src = b.alloc_reg(PtxType::E2M1);
let _dst = b
.cvt_e2m1_to_f32(src)
.expect("cvt_e2m1_to_f32 should succeed on sm_100");
b.ret();
});
assert!(
ptx.contains("cvt.f32.e2m1"),
"expected cvt.f32.e2m1 in PTX:\n{ptx}"
);
}
#[test]
fn test_cvt_fp4_rejected_below_sm100() {
let result = KernelBuilder::new("test_fp4_sm90")
.target(SmVersion::Sm90)
.param("a", PtxType::U64)
.param("n", PtxType::U32)
.body(|b| {
let src = b.alloc_reg(PtxType::F32);
let r = b.cvt_f32_to_e2m1(src);
assert!(r.is_err(), "cvt_f32_to_e2m1 should fail on sm_90");
b.ret();
})
.build();
let _ = result;
}
#[test]
fn test_tcgen05_mma_emits_on_sm100() {
let ptx = build_with_body_sm100(|b| {
let a_desc = b.alloc_reg(PtxType::U64);
let b_desc = b.alloc_reg(PtxType::U64);
b.tcgen05_mma_m128n256k256_e2m1(a_desc, b_desc)
.expect("tcgen05.mma should succeed on sm_100");
b.ret();
});
assert!(
ptx.contains("tcgen05.mma.cta_group::1.kind::f32"),
"expected tcgen05.mma instruction in PTX:\n{ptx}"
);
}
#[test]
fn test_cluster_barrier_ptx_emission() {
let ptx = build_with_body_sm90(|b| {
b.barrier_cluster()
.expect("barrier_cluster should succeed on sm_90");
b.fence_cluster()
.expect("fence_cluster should succeed on sm_90");
b.ret();
});
assert!(
ptx.contains("barrier.cluster.arrive;"),
"expected barrier.cluster.arrive in PTX:\n{ptx}"
);
assert!(
ptx.contains("fence.mbarrier_init.release.cluster;"),
"expected fence.mbarrier_init.release.cluster in PTX:\n{ptx}"
);
}
#[test]
fn test_cluster_barrier_rejected_below_sm90() {
let result = KernelBuilder::new("test_barrier_cluster_sm80")
.target(SmVersion::Sm80)
.param("a", PtxType::U64)
.param("n", PtxType::U32)
.body(|b| {
let r = b.barrier_cluster();
assert!(r.is_err(), "barrier_cluster should fail on sm_80");
b.ret();
})
.build();
let _ = result;
}
#[test]
fn test_tma_cp_async_bulk_ptx() {
let ptx = build_with_body_sm90(|b| {
let dst = b.alloc_reg(PtxType::U64);
let src = b.alloc_reg(PtxType::U64);
let desc = b.alloc_reg(PtxType::U64);
b.cp_async_bulk_tensor_1d(dst, src, desc)
.expect("cp_async_bulk_tensor_1d should succeed on sm_90");
b.ret();
});
assert!(
ptx.contains("cp.async.bulk.tensor.1d.shared::cluster.global.tile.bulk_group"),
"expected cp.async.bulk.tensor.1d in PTX:\n{ptx}"
);
}
#[test]
fn test_tma_bulk_rejected_below_sm90() {
let result = KernelBuilder::new("test_cp_bulk_sm80")
.target(SmVersion::Sm80)
.param("a", PtxType::U64)
.param("n", PtxType::U32)
.body(|b| {
let dst = b.alloc_reg(PtxType::U64);
let src = b.alloc_reg(PtxType::U64);
let desc = b.alloc_reg(PtxType::U64);
let r = b.cp_async_bulk_tensor_1d(dst, src, desc);
assert!(r.is_err(), "cp_async_bulk should fail on sm_80");
b.ret();
})
.build();
let _ = result;
}
#[test]
fn test_griddepcontrol_emission() {
let ptx = build_with_body_sm90(|b| {
b.griddepcontrol_launch_dependents()
.expect("griddepcontrol_launch_dependents should succeed on sm_90");
b.griddepcontrol_wait()
.expect("griddepcontrol_wait should succeed on sm_90");
b.ret();
});
assert!(
ptx.contains("griddepcontrol.launch_dependents;"),
"expected griddepcontrol.launch_dependents in PTX:\n{ptx}"
);
assert!(
ptx.contains("griddepcontrol.wait;"),
"expected griddepcontrol.wait in PTX:\n{ptx}"
);
}
#[test]
fn test_cp_async_4byte_ptx() {
let ptx = build_with_body(|b| {
let dst = b.alloc_reg(PtxType::U64);
let src = b.alloc_reg(PtxType::U64);
b.cp_async_32bit(dst, src);
b.ret();
});
assert!(ptx.contains("cp.async"), "expected cp.async in PTX:\n{ptx}");
assert!(
ptx.contains(", 4;") || ptx.contains(", 4 ") || ptx.ends_with(", 4"),
"expected size 4 in cp.async PTX:\n{ptx}"
);
}
#[test]
fn test_cp_async_8byte_ptx() {
let ptx = build_with_body(|b| {
let dst = b.alloc_reg(PtxType::U64);
let src = b.alloc_reg(PtxType::U64);
b.cp_async_64bit(dst, src);
b.ret();
});
assert!(ptx.contains("cp.async"), "expected cp.async in PTX:\n{ptx}");
assert!(
ptx.contains(", 8;") || ptx.contains(", 8 ") || ptx.ends_with(", 8"),
"expected size 8 in cp.async PTX:\n{ptx}"
);
}
#[test]
fn test_cp_async_16byte_ptx() {
let ptx = build_with_body(|b| {
let dst = b.alloc_reg(PtxType::U64);
let src = b.alloc_reg(PtxType::U64);
b.cp_async_128bit(dst, src);
b.ret();
});
assert!(ptx.contains("cp.async"), "expected cp.async in PTX:\n{ptx}");
assert!(
ptx.contains(", 16;") || ptx.contains(", 16 ") || ptx.ends_with(", 16"),
"expected size 16 in cp.async PTX:\n{ptx}"
);
}
#[test]
fn test_cp_async_all_three_sizes_distinct_ptx() {
let ptx = build_with_body(|b| {
let dst = b.alloc_reg(PtxType::U64);
let src = b.alloc_reg(PtxType::U64);
b.cp_async_32bit(dst.clone(), src.clone());
b.cp_async_64bit(dst.clone(), src.clone());
b.cp_async_128bit(dst, src);
b.cp_async_commit();
b.cp_async_wait(0);
b.ret();
});
let count_4 = ptx.matches("cp.async.ca.shared.global").count();
assert_eq!(
count_4, 3,
"expected exactly 3 cp.async.ca.shared.global instructions, got {count_4}:\n{ptx}"
);
assert!(
ptx.contains("cp.async.commit_group;"),
"expected commit_group:\n{ptx}"
);
assert!(
ptx.contains("cp.async.wait_group 0;"),
"expected wait_group 0:\n{ptx}"
);
}
#[test]
fn test_ldmatrix_sm75_ptx() {
let ptx = KernelBuilder::new("test_ldmatrix_sm75")
.target(SmVersion::Sm75)
.param("a", PtxType::U64)
.param("n", PtxType::U32)
.body(|b| {
let addr = b.alloc_reg(PtxType::U64);
let _regs = b
.ldmatrix_x4(addr)
.expect("ldmatrix_x4 should succeed on sm_75");
b.ret();
})
.build()
.expect("kernel build should succeed");
assert!(
ptx.contains("ldmatrix.sync.aligned.m8n8.x4.shared.b16"),
"expected ldmatrix.sync.aligned.m8n8.x4.shared.b16 in PTX:\n{ptx}"
);
}
#[test]
fn test_ldmatrix_sm80_ptx() {
let ptx = build_with_body(|b| {
let addr = b.alloc_reg(PtxType::U64);
let _regs = b
.ldmatrix_x4(addr)
.expect("ldmatrix_x4 should succeed on sm_80");
b.ret();
});
assert!(
ptx.contains("ldmatrix.sync.aligned.m8n8.x4.shared.b16"),
"expected ldmatrix instruction in SM_80 PTX:\n{ptx}"
);
}
#[test]
fn test_ldmatrix_rejects_pre_sm75() {
let result = KernelBuilder::new("test_ldmatrix_sm75_succeeds")
.target(SmVersion::Sm75)
.param("a", PtxType::U64)
.param("n", PtxType::U32)
.body(|b| {
let addr = b.alloc_reg(PtxType::U64);
let r = b.ldmatrix_x4(addr);
assert!(r.is_ok(), "ldmatrix_x4 should succeed on SM 75");
b.ret();
})
.build();
assert!(
result.is_ok(),
"SM75 ldmatrix kernel should build successfully"
);
}
#[test]
fn test_ldmatrix_returns_four_registers() {
let ptx = build_with_body(|b| {
let addr = b.alloc_reg(PtxType::U64);
let regs = b
.ldmatrix_x4(addr)
.expect("ldmatrix_x4 should succeed on sm_80");
b.comment(&format!(
"ldmatrix dst: {}, {}, {}, {}",
regs[0].name, regs[1].name, regs[2].name, regs[3].name
));
b.ret();
});
assert!(
ptx.contains("ldmatrix"),
"expected ldmatrix instruction:\n{ptx}"
);
}