use super::*;
/// Loads a `u32` through a generic address and stores it back to the same
/// address, then checks the emitted PTX for both generic mnemonics.
#[test]
fn test_ld_generic_u32() {
    let k = PtxKernel::new("test_generic_u32")
        .param(PtxType::U64, "ptr")
        .build(|b| {
            let addr = b.load_param_u64("ptr");
            let word = b.ld_generic_u32(addr);
            b.st_generic_u32(addr, word);
            b.ret();
        });
    let text = k.emit();
    let has_load = text.contains("ld.u32");
    assert!(has_load, "Expected ld.u32 in: {}", text);
    let has_store = text.contains("st.u32");
    assert!(has_store, "Expected st.u32 in: {}", text);
}
/// Round-trips a `u64` through a generic-address load/store.
///
/// This asserts on the full `ld.u64` / `st.u64` mnemonics, mirroring
/// `test_ld_generic_u32`. The kernel's own `u64` parameter is presumably
/// declared as `.param .u64` and loaded with `ld.param.u64`. A bare `.u64`
/// substring match would therefore pass even if the generic load/store were
/// never emitted. `ld.u64` is not a substring of `ld.param.u64`, so the
/// assertions below only match the generic instructions.
#[test]
fn test_ld_generic_u64() {
    let kernel = PtxKernel::new("test_generic_u64")
        .param(PtxType::U64, "ptr")
        .build(|ctx| {
            let ptr = ctx.load_param_u64("ptr");
            let val = ctx.ld_generic_u64(ptr);
            ctx.st_generic_u64(ptr, val);
            ctx.ret();
        });
    let ptx = kernel.emit();
    assert!(ptx.contains("ld.u64"), "Expected ld.u64 in: {}", ptx);
    assert!(ptx.contains("st.u64"), "Expected st.u64 in: {}", ptx);
}
/// Round-trips a `u8` through a generic-address load/store.
///
/// The previous check (`contains(".u8") || contains("u8")`) could never fail.
/// The kernel is named `test_generic_u8`, and that name presumably appears in
/// the emitted entry header. Assert on the concrete mnemonics instead, matching
/// the `ld.u32` / `st.u32` form that `test_ld_generic_u32` relies on.
#[test]
fn test_ld_generic_u8() {
    let kernel = PtxKernel::new("test_generic_u8")
        .param(PtxType::U64, "ptr")
        .build(|ctx| {
            let ptr = ctx.load_param_u64("ptr");
            let val = ctx.ld_generic_u8(ptr);
            ctx.st_generic_u8(ptr, val);
            ctx.ret();
        });
    let ptx = kernel.emit();
    assert!(ptx.contains("ld.u8"), "Expected ld.u8 in: {}", ptx);
    assert!(ptx.contains("st.u8"), "Expected st.u8 in: {}", ptx);
}
/// Round-trips a `u16` through a generic-address load/store.
///
/// The previous check (`contains(".u16") || contains("u16")`) could never fail.
/// The kernel is named `test_generic_u16`, and that name presumably appears in
/// the emitted entry header. Assert on the concrete mnemonics instead, matching
/// the `ld.u32` / `st.u32` form that `test_ld_generic_u32` relies on.
#[test]
fn test_ld_generic_u16() {
    let kernel = PtxKernel::new("test_generic_u16")
        .param(PtxType::U64, "ptr")
        .build(|ctx| {
            let ptr = ctx.load_param_u64("ptr");
            let val = ctx.ld_generic_u16(ptr);
            ctx.st_generic_u16(ptr, val);
            ctx.ret();
        });
    let ptx = kernel.emit();
    assert!(ptx.contains("ld.u16"), "Expected ld.u16 in: {}", ptx);
    assert!(ptx.contains("st.u16"), "Expected st.u16 in: {}", ptx);
}
/// Round-trips an `f32` through a generic-address load/store.
///
/// The previous check (`contains(".f32") || contains("f32")`) could never fail.
/// The kernel is named `test_generic_f32`, and that name presumably appears in
/// the emitted entry header. Assert on the concrete mnemonics instead, matching
/// the `ld.u32` / `st.u32` form that `test_ld_generic_u32` relies on.
#[test]
fn test_ld_generic_f32() {
    let kernel = PtxKernel::new("test_generic_f32")
        .param(PtxType::U64, "ptr")
        .build(|ctx| {
            let ptr = ctx.load_param_u64("ptr");
            let val = ctx.ld_generic_f32(ptr);
            ctx.st_generic_f32(ptr, val);
            ctx.ret();
        });
    let ptx = kernel.emit();
    assert!(ptx.contains("ld.f32"), "Expected ld.f32 in: {}", ptx);
    assert!(ptx.contains("st.f32"), "Expected st.f32 in: {}", ptx);
}
/// Checks that `ld_generic_u32_into` emits a generic `u32` load into an
/// existing register.
///
/// A bare `contains("ld")` is already satisfied by the parameter load that
/// `load_param_u64` emits. Require the same `ld.u32` mnemonic that
/// `test_ld_generic_u32` checks for `ld_generic_u32`.
#[test]
fn test_ld_generic_u32_into() {
    let kernel = PtxKernel::new("test_generic_into")
        .param(PtxType::U64, "ptr")
        .build(|ctx| {
            let ptr = ctx.load_param_u64("ptr");
            let dest = ctx.mov_u32_imm(0);
            ctx.ld_generic_u32_into(ptr, dest);
            ctx.ret();
        });
    let ptx = kernel.emit();
    assert!(ptx.contains("ld.u32"), "Expected ld.u32 in: {}", ptx);
}
/// Requests 4096 bytes of shared memory and takes its base address via
/// `shared_base_addr`. It then reads an `f32` through that address and checks
/// that the PTX contains a `cvta` and a reference to `smem`.
#[test]
fn test_shared_base_addr() {
    let k = PtxKernel::new("test_shared_base")
        .shared_memory(4096)
        .build(|b| {
            let shared_ptr = b.shared_base_addr();
            let _loaded = b.ld_generic_f32(shared_ptr);
            b.ret();
        });
    let text = k.emit();
    let saw_cvta = text.contains("cvta");
    assert!(
        saw_cvta,
        "Expected cvta instruction for shared base addr in: {}",
        text
    );
    let saw_smem = text.contains("smem");
    assert!(
        saw_smem,
        "Expected smem label reference in: {}",
        text
    );
}
/// Builds a kernel that performs a predicated F16 global load (guarded by
/// `3 < n`) and stores the widened `f32` result. It then checks the PTX for
/// several features:
/// - the zero default (`mov.f32` with the `0F00000000` literal)
/// - the `ld.global` / `.b16` load
/// - the `cvt` widening
/// - an `@%p` predicate guard
#[test]
fn test_ld_global_f16_to_f32_predicated() {
    let k = PtxKernel::new("test_f16_pred_load")
        .param(PtxType::U64, "ptr")
        .param(PtxType::U32, "n")
        .build(|b| {
            let base = b.load_param_u64("ptr");
            let limit = b.load_param_u32("n");
            let three = b.mov_u32_imm(3);
            let in_bounds = b.setp_lt_u32(three, limit);
            let widened = b.ld_global_f16_to_f32_predicated(base, in_bounds);
            b.st_global_f32(base, widened);
            b.ret();
        });
    let text = k.emit();

    // Default value written before the guarded load.
    let saw_mov = text.contains("mov.f32");
    assert!(
        saw_mov,
        "Expected mov.f32 for default initialization in: {}",
        text
    );
    let saw_zero = text.contains("0F00000000");
    assert!(
        saw_zero,
        "Expected 0.0 float literal (0F00000000) for default in: {}",
        text
    );

    // The half-precision load itself.
    let saw_global_ld = text.contains("ld.global");
    assert!(
        saw_global_ld,
        "Expected ld.global for F16 load in: {}",
        text
    );
    let saw_b16 = text.contains(".b16");
    assert!(
        saw_b16,
        "Expected .b16 type for F16 load in: {}",
        text
    );

    // Widening conversion and predication.
    let saw_cvt = text.contains("cvt");
    assert!(
        saw_cvt,
        "Expected cvt instruction for F16->F32 conversion in: {}",
        text
    );
    let saw_guard = text.contains("@%p");
    assert!(
        saw_guard,
        "Expected predicate guard @%p in: {}",
        text
    );
}
/// Models a KV-cache style access. Each thread predicated-loads one F16
/// element at `kv_ptr + tid * 2` (guarded by `tid < head_dim`), scales it by
/// `0.125` in `f32`, and stores the result at `kv_ptr + tid * 4`.
///
/// The scale check used to be a bare `contains("mul")`, which the two
/// `mul_wide_u32` offset computations already satisfy. It now requires a line
/// containing both `mul` and `.f32`. That matches the floating-point multiply
/// whether or not a rounding modifier (e.g. `.rn`) is emitted, and it cannot
/// be satisfied by the integer wide multiplies.
#[test]
fn test_ld_global_f16_to_f32_predicated_with_store() {
    let kernel = PtxKernel::new("test_f16_pred_accum")
        .param(PtxType::U64, "kv_ptr")
        .param(PtxType::U32, "head_dim")
        .build(|ctx| {
            let kv_ptr = ctx.load_param_u64("kv_ptr");
            let head_dim = ctx.load_param_u32("head_dim");
            let tid = ctx.special_reg(crate::ptx::registers::PtxReg::TidX);
            let valid = ctx.setp_lt_u32(tid, head_dim);
            let offset = ctx.mul_wide_u32(tid, 2);
            let addr = ctx.add_u64(kv_ptr, offset);
            let k_val = ctx.ld_global_f16_to_f32_predicated(addr, valid);
            let scale = ctx.mov_f32_imm(0.125);
            let scaled = ctx.mul_f32(k_val, scale);
            let out_offset = ctx.mul_wide_u32(tid, 4);
            let out_addr = ctx.add_u64(kv_ptr, out_offset);
            ctx.st_global_f32(out_addr, scaled);
            ctx.ret();
        });
    let ptx = kernel.emit();
    assert!(ptx.contains("@%p"), "Expected predicate guard in: {}", ptx);
    assert!(
        ptx.contains("ld.global"),
        "Expected global load in: {}",
        ptx
    );
    assert!(
        ptx.contains("cvt"),
        "Expected F16->F32 conversion in: {}",
        ptx
    );
    // Only the f32 multiply counts; `mul.wide.u32` lines must not satisfy this.
    let has_f32_mul = ptx
        .lines()
        .any(|line| line.contains("mul") && line.contains(".f32"));
    assert!(
        has_f32_mul,
        "Expected f32 multiply for scale in: {}",
        ptx
    );
    assert!(
        ptx.contains("st.global"),
        "Expected global store in: {}",
        ptx
    );
}
/// Checks that a predicated F16->F32 load emits exactly two predicated
/// instructions (the load and the conversion). It counts lines that contain
/// an `@%p` guard.
#[test]
fn test_ld_global_f16_to_f32_predicated_instruction_count() {
    let k = PtxKernel::new("test_f16_instr_count")
        .param(PtxType::U64, "ptr")
        .param(PtxType::U32, "n")
        .build(|b| {
            let base = b.load_param_u64("ptr");
            let limit = b.load_param_u32("n");
            let zero = b.mov_u32_imm(0);
            let in_bounds = b.setp_lt_u32(zero, limit);
            let _widened = b.ld_global_f16_to_f32_predicated(base, in_bounds);
            b.ret();
        });
    let text = k.emit();
    let mut guarded = 0;
    for line in text.lines() {
        if line.contains("@%p") {
            guarded += 1;
        }
    }
    assert_eq!(
        guarded, 2,
        "Expected exactly 2 predicated instructions (ld + cvt), got {} in:\n{}",
        guarded, text
    );
}