#[cfg(test)]
mod tests {
use crate::arch::SmVersion;
use crate::builder::KernelBuilder;
use crate::ir::PtxType;
fn generate_vector_add_ptx_f32_raw(sm: SmVersion) -> Result<String, crate::error::PtxGenError> {
KernelBuilder::new("vector_add")
.target(sm)
.param("a_ptr", PtxType::U64)
.param("b_ptr", PtxType::U64)
.param("c_ptr", PtxType::U64)
.param("n", PtxType::U32)
.body(|b| {
let gid = b.global_thread_id_x();
let gid_name = gid.to_string();
let n_reg = b.load_param_u32("n");
b.if_lt_u32(gid, n_reg, move |b| {
let a_ptr = b.load_param_u64("a_ptr");
let b_ptr = b.load_param_u64("b_ptr");
let c_ptr = b.load_param_u64("c_ptr");
b.raw_ptx(&format!(
"cvt.u64.u32 %rd_off, {gid_name};\n \
mul.lo.u64 %rd_off, %rd_off, 4;\n \
add.u64 %rd_a, {a_ptr}, %rd_off;\n \
add.u64 %rd_b, {b_ptr}, %rd_off;\n \
add.u64 %rd_c, {c_ptr}, %rd_off;\n \
ld.global.f32 %f_a, [%rd_a];\n \
ld.global.f32 %f_b, [%rd_b];\n \
add.f32 %f_c, %f_a, %f_b;\n \
st.global.f32 [%rd_c], %f_c;"
));
});
b.ret();
})
.build()
}
fn generate_vector_add_ptx_f16(sm: SmVersion) -> Result<String, crate::error::PtxGenError> {
KernelBuilder::new("vector_add")
.target(sm)
.param("a_ptr", PtxType::U64)
.param("b_ptr", PtxType::U64)
.param("c_ptr", PtxType::U64)
.param("n", PtxType::U32)
.body(|b| {
let gid = b.global_thread_id_x();
let gid_name = gid.to_string();
let n_reg = b.load_param_u32("n");
b.if_lt_u32(gid, n_reg, move |b| {
let a_ptr = b.load_param_u64("a_ptr");
let b_ptr = b.load_param_u64("b_ptr");
let c_ptr = b.load_param_u64("c_ptr");
b.raw_ptx(&format!(
"cvt.u64.u32 %rd_off, {gid_name};\n \
mul.lo.u64 %rd_off, %rd_off, 2;\n \
add.u64 %rd_a, {a_ptr}, %rd_off;\n \
add.u64 %rd_b, {b_ptr}, %rd_off;\n \
add.u64 %rd_c, {c_ptr}, %rd_off;\n \
ld.global.f16 %f_a, [%rd_a];\n \
ld.global.f16 %f_b, [%rd_b];\n \
add.f16 %f_c, %f_a, %f_b;\n \
st.global.f16 [%rd_c], %f_c;"
));
});
b.ret();
})
.build()
}
fn generate_vector_add_ptx_f64(sm: SmVersion) -> Result<String, crate::error::PtxGenError> {
KernelBuilder::new("vector_add")
.target(sm)
.param("a_ptr", PtxType::U64)
.param("b_ptr", PtxType::U64)
.param("c_ptr", PtxType::U64)
.param("n", PtxType::U32)
.body(|b| {
let gid = b.global_thread_id_x();
let gid_name = gid.to_string();
let n_reg = b.load_param_u32("n");
b.if_lt_u32(gid, n_reg, move |b| {
let a_ptr = b.load_param_u64("a_ptr");
let b_ptr = b.load_param_u64("b_ptr");
let c_ptr = b.load_param_u64("c_ptr");
b.raw_ptx(&format!(
"cvt.u64.u32 %rd_off, {gid_name};\n \
mul.lo.u64 %rd_off, %rd_off, 8;\n \
add.u64 %rd_a, {a_ptr}, %rd_off;\n \
add.u64 %rd_b, {b_ptr}, %rd_off;\n \
add.u64 %rd_c, {c_ptr}, %rd_off;\n \
ld.global.f64 %f_a, [%rd_a];\n \
ld.global.f64 %f_b, [%rd_b];\n \
add.f64 %f_c, %f_a, %f_b;\n \
st.global.f64 [%rd_c], %f_c;"
));
});
b.ret();
})
.build()
}
/// Builds PTX for a kernel named `simple_kernel` that declares a single
/// `u32` parameter `n` and whose body consists of one `ret` call.
///
/// `sm` is passed to the builder's `target` call; the result of `build` is
/// returned as-is.
fn generate_simple_kernel_ptx(sm: SmVersion) -> Result<String, crate::error::PtxGenError> {
    KernelBuilder::new("simple_kernel")
        .target(sm)
        .param("n", PtxType::U32)
        .body(|body| {
            // Nothing but an immediate return.
            body.ret();
        })
        .build()
}
/// The f32 `vector_add` PTX generated for `Sm80` must contain the text
/// `.version`.
#[test]
fn p1_vector_add_f32_contains_version_header() {
    let generated = generate_vector_add_ptx_f32_raw(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    let has_version = ptx.contains(".version");
    assert!(
        has_version,
        "PTX must contain .version directive; got:\n{ptx}"
    );
}
/// The f32 `vector_add` PTX generated for `Sm80` must contain the text
/// `.target sm_80`.
#[test]
fn p1_vector_add_f32_contains_correct_target() {
    let generated = generate_vector_add_ptx_f32_raw(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    let names_sm80 = ptx.contains(".target sm_80");
    assert!(
        names_sm80,
        "PTX must contain '.target sm_80'; got:\n{ptx}"
    );
}
/// The kernel name `vector_add` must appear somewhere in the generated PTX.
#[test]
fn p1_vector_add_f32_contains_entry_name() {
    let generated = generate_vector_add_ptx_f32_raw(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    let mentions_kernel = ptx.contains("vector_add");
    assert!(
        mentions_kernel,
        "PTX must contain kernel name 'vector_add'; got:\n{ptx}"
    );
}
/// The generated PTX must contain at least one `ld.param` occurrence.
#[test]
fn p1_vector_add_f32_contains_parameter_loads() {
    let generated = generate_vector_add_ptx_f32_raw(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    let loads_params = ptx.contains("ld.param");
    assert!(
        loads_params,
        "PTX must contain parameter loading instructions; got:\n{ptx}"
    );
}
/// The generated PTX must contain `mad.lo.u32`, which this test treats as the
/// marker of the global thread ID computation.
#[test]
fn p1_vector_add_f32_contains_global_thread_id_computation() {
    let generated = generate_vector_add_ptx_f32_raw(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    let has_mad = ptx.contains("mad.lo.u32");
    assert!(
        has_mad,
        "PTX must contain global thread ID computation (mad.lo.u32); got:\n{ptx}"
    );
}
/// The generated PTX must contain at least one `ld.global` occurrence.
#[test]
fn p1_vector_add_f32_contains_global_load() {
    let generated = generate_vector_add_ptx_f32_raw(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    let has_global_load = ptx.contains("ld.global");
    assert!(
        has_global_load,
        "PTX must contain global memory load; got:\n{ptx}"
    );
}
/// The generated PTX must contain either `add.f32` or `fma.rn.f32`.
#[test]
fn p1_vector_add_f32_contains_add_operation() {
    let generated = generate_vector_add_ptx_f32_raw(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    // Either spelling is accepted; `add.f32` is tried first.
    let has_f32_add = ["add.f32", "fma.rn.f32"]
        .iter()
        .any(|&op| ptx.contains(op));
    assert!(
        has_f32_add,
        "PTX must contain add.f32 or fma.rn.f32 operation; got:\n{ptx}"
    );
}
/// The generated PTX must contain at least one `st.global` occurrence.
#[test]
fn p1_vector_add_f32_contains_global_store() {
    let generated = generate_vector_add_ptx_f32_raw(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    let has_global_store = ptx.contains("st.global");
    assert!(
        has_global_store,
        "PTX must contain global memory store; got:\n{ptx}"
    );
}
/// The generated PTX must contain `setp` or `bra`; either one is accepted as
/// evidence of the bounds check.
#[test]
fn p1_vector_add_f32_contains_bounds_check() {
    let generated = generate_vector_add_ptx_f32_raw(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    let has_guard = ["setp", "bra"].iter().any(|&op| ptx.contains(op));
    assert!(
        has_guard,
        "PTX must contain bounds check (setp/bra); got:\n{ptx}"
    );
}
/// The generated PTX must contain both the `.visible .entry vector_add`
/// header text and a `ret;` statement.
#[test]
fn p1_vector_add_f32_well_formed_structure() {
    let generated = generate_vector_add_ptx_f32_raw(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    let has_entry = ptx.contains(".visible .entry vector_add");
    assert!(
        has_entry,
        "PTX must contain .visible .entry vector_add; got:\n{ptx}"
    );

    let has_ret = ptx.contains("ret;");
    assert!(has_ret, "PTX must contain ret; got:\n{ptx}");
}
/// The f16 `vector_add` PTX for `Sm80` must contain `.version`, the kernel
/// name, a global load, a global store, and an f16 add (`add.f16` or
/// `fma.rn.f16`). The checks run in that order.
#[test]
fn p1_vector_add_f16_contains_required_elements() {
    let generated = generate_vector_add_ptx_f16(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    let has_version = ptx.contains(".version");
    assert!(has_version, "f16 PTX must contain .version; got:\n{ptx}");

    let has_name = ptx.contains("vector_add");
    assert!(has_name, "f16 PTX must contain kernel name; got:\n{ptx}");

    let has_load = ptx.contains("ld.global");
    assert!(has_load, "f16 PTX must contain global load; got:\n{ptx}");

    let has_store = ptx.contains("st.global");
    assert!(has_store, "f16 PTX must contain global store; got:\n{ptx}");

    let has_add = ["add.f16", "fma.rn.f16"]
        .iter()
        .any(|&op| ptx.contains(op));
    assert!(has_add, "f16 PTX must contain f16 add; got:\n{ptx}");
}
/// The f64 `vector_add` PTX for `Sm80` must contain `.version`, the kernel
/// name, a global load, a global store, and an f64 add (`add.f64` or
/// `fma.rn.f64`). The checks run in that order.
#[test]
fn p1_vector_add_f64_contains_required_elements() {
    let generated = generate_vector_add_ptx_f64(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    let has_version = ptx.contains(".version");
    assert!(has_version, "f64 PTX must contain .version; got:\n{ptx}");

    let has_name = ptx.contains("vector_add");
    assert!(has_name, "f64 PTX must contain kernel name; got:\n{ptx}");

    let has_load = ptx.contains("ld.global");
    assert!(has_load, "f64 PTX must contain global load; got:\n{ptx}");

    let has_store = ptx.contains("st.global");
    assert!(has_store, "f64 PTX must contain global store; got:\n{ptx}");

    let has_add = ["add.f64", "fma.rn.f64"]
        .iter()
        .any(|&op| ptx.contains(op));
    assert!(has_add, "f64 PTX must contain f64 add; got:\n{ptx}");
}
/// The generated PTX must contain the text `.address_size 64`.
#[test]
fn p1_vector_add_f32_address_size_64() {
    let generated = generate_vector_add_ptx_f32_raw(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    let is_64_bit = ptx.contains(".address_size 64");
    assert!(
        is_64_bit,
        "PTX must contain .address_size 64; got:\n{ptx}"
    );
}
/// The generated PTX must contain both a `.param .u64` declaration (pointer
/// arguments) and a `.param .u32` declaration (the `n` argument).
#[test]
fn p1_vector_add_f32_has_u64_params() {
    let generated = generate_vector_add_ptx_f32_raw(SmVersion::Sm80);
    let ptx = generated.expect("PTX generation must succeed");

    let has_u64_param = ptx.contains(".param .u64");
    assert!(
        has_u64_param,
        "PTX must contain .param .u64 for pointer args; got:\n{ptx}"
    );

    let has_u32_param = ptx.contains(".param .u32");
    assert!(
        has_u32_param,
        "PTX must contain .param .u32 for n; got:\n{ptx}"
    );
}
#[test]
fn p8_ptx_target_sm75() {
let ptx = generate_simple_kernel_ptx(SmVersion::Sm75).expect("PTX generation must succeed");
assert!(
ptx.contains(".target sm_75"),
"Expected .target sm_75 in PTX; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_target_sm80() {
let ptx = generate_simple_kernel_ptx(SmVersion::Sm80).expect("PTX generation must succeed");
assert!(
ptx.contains(".target sm_80"),
"Expected .target sm_80 in PTX; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_target_sm86() {
let ptx = generate_simple_kernel_ptx(SmVersion::Sm86).expect("PTX generation must succeed");
assert!(
ptx.contains(".target sm_86"),
"Expected .target sm_86 in PTX; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_target_sm89() {
let ptx = generate_simple_kernel_ptx(SmVersion::Sm89).expect("PTX generation must succeed");
assert!(
ptx.contains(".target sm_89"),
"Expected .target sm_89 in PTX; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_target_sm90() {
let ptx = generate_simple_kernel_ptx(SmVersion::Sm90).expect("PTX generation must succeed");
assert!(
ptx.contains(".target sm_90"),
"Expected .target sm_90 in PTX; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_target_sm90a() {
let ptx =
generate_simple_kernel_ptx(SmVersion::Sm90a).expect("PTX generation must succeed");
assert!(
ptx.contains(".target sm_90a"),
"Expected .target sm_90a in PTX; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_target_sm100() {
let ptx =
generate_simple_kernel_ptx(SmVersion::Sm100).expect("PTX generation must succeed");
assert!(
ptx.contains(".target sm_100"),
"Expected .target sm_100 in PTX; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_target_sm120() {
let ptx =
generate_simple_kernel_ptx(SmVersion::Sm120).expect("PTX generation must succeed");
assert!(
ptx.contains(".target sm_120"),
"Expected .target sm_120 in PTX; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_version_matches_sm75() {
let ptx = generate_simple_kernel_ptx(SmVersion::Sm75).expect("PTX generation must succeed");
let expected_version = SmVersion::Sm75.ptx_version();
assert!(
ptx.contains(&format!(".version {expected_version}")),
"Expected .version {expected_version} for sm_75; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_version_matches_sm80() {
let ptx = generate_simple_kernel_ptx(SmVersion::Sm80).expect("PTX generation must succeed");
let expected_version = SmVersion::Sm80.ptx_version();
assert!(
ptx.contains(&format!(".version {expected_version}")),
"Expected .version {expected_version} for sm_80; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_version_matches_sm86() {
let ptx = generate_simple_kernel_ptx(SmVersion::Sm86).expect("PTX generation must succeed");
let expected_version = SmVersion::Sm86.ptx_version();
assert!(
ptx.contains(&format!(".version {expected_version}")),
"Expected .version {expected_version} for sm_86; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_version_matches_sm89() {
let ptx = generate_simple_kernel_ptx(SmVersion::Sm89).expect("PTX generation must succeed");
let expected_version = SmVersion::Sm89.ptx_version();
assert!(
ptx.contains(&format!(".version {expected_version}")),
"Expected .version {expected_version} for sm_89; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_version_matches_sm90() {
let ptx = generate_simple_kernel_ptx(SmVersion::Sm90).expect("PTX generation must succeed");
let expected_version = SmVersion::Sm90.ptx_version();
assert!(
ptx.contains(&format!(".version {expected_version}")),
"Expected .version {expected_version} for sm_90; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_version_matches_sm90a() {
let ptx =
generate_simple_kernel_ptx(SmVersion::Sm90a).expect("PTX generation must succeed");
let expected_version = SmVersion::Sm90a.ptx_version();
assert!(
ptx.contains(&format!(".version {expected_version}")),
"Expected .version {expected_version} for sm_90a; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_version_matches_sm100() {
let ptx =
generate_simple_kernel_ptx(SmVersion::Sm100).expect("PTX generation must succeed");
let expected_version = SmVersion::Sm100.ptx_version();
assert!(
ptx.contains(&format!(".version {expected_version}")),
"Expected .version {expected_version} for sm_100; got:\n{ptx}"
);
}
#[test]
fn p8_ptx_version_matches_sm120() {
let ptx =
generate_simple_kernel_ptx(SmVersion::Sm120).expect("PTX generation must succeed");
let expected_version = SmVersion::Sm120.ptx_version();
assert!(
ptx.contains(&format!(".version {expected_version}")),
"Expected .version {expected_version} for sm_120; got:\n{ptx}"
);
}
/// Walks the SM versions in the order listed below and asserts that the
/// `(major, minor)` pair returned by `ptx_isa_version()` never decreases from
/// one entry to the next.
#[test]
fn p8_ptx_version_ordering_respects_sm_ordering() {
    let isa_versions: Vec<(u32, u32)> = [
        SmVersion::Sm75,
        SmVersion::Sm80,
        SmVersion::Sm86,
        SmVersion::Sm89,
        SmVersion::Sm90,
        SmVersion::Sm90a,
        SmVersion::Sm100,
        SmVersion::Sm120,
    ]
    .iter()
    .map(|sm| sm.ptx_isa_version())
    .collect();

    // Pair every entry with its successor.
    let neighbours = isa_versions.iter().zip(isa_versions.iter().skip(1));
    for (&(prev_major, prev_minor), &(curr_major, curr_minor)) in neighbours {
        let previous = (prev_major, prev_minor);
        let current = (curr_major, curr_minor);
        assert!(
            current >= previous,
            "PTX ISA version must be non-decreasing across SM versions: \
             {prev_major}.{prev_minor} -> {curr_major}.{curr_minor}"
        );
    }
}
/// `capabilities().has_fp8` must be false for Sm75/Sm80/Sm86 and true for
/// Sm89, Sm90, Sm90a, Sm100 and Sm120.
#[test]
fn p8_fp8_types_require_sm89_or_higher() {
    // Architectures expected to lack FP8.
    for sm in [SmVersion::Sm75, SmVersion::Sm80, SmVersion::Sm86] {
        let caps = sm.capabilities();
        assert!(!caps.has_fp8, "{sm} should not have FP8 support");
    }

    // Architectures expected to offer FP8.
    for sm in [
        SmVersion::Sm89,
        SmVersion::Sm90,
        SmVersion::Sm90a,
        SmVersion::Sm100,
        SmVersion::Sm120,
    ] {
        let caps = sm.capabilities();
        assert!(caps.has_fp8, "{sm} should have FP8 support");
    }
}
/// A kernel targeting `Sm89` with one `E4M3` parameter must produce PTX that
/// contains `.param .b8` and `.target sm_89`.
#[test]
fn p8_fp8_param_type_in_ptx_sm89() {
    let ptx = KernelBuilder::new("fp8_kernel")
        .target(SmVersion::Sm89)
        .param("scale", PtxType::E4M3)
        .body(|body| {
            body.ret();
        })
        .build()
        .expect("PTX generation must succeed");

    let has_b8_param = ptx.contains(".param .b8");
    assert!(
        has_b8_param,
        "SM 89 kernel with E4M3 param must have .param .b8; got:\n{ptx}"
    );

    let names_sm89 = ptx.contains(".target sm_89");
    assert!(names_sm89, "Must target sm_89; got:\n{ptx}");
}
/// `Sm80` capabilities must report neither `has_fp8` nor `has_fp6_fp4`.
#[test]
fn p8_sm80_has_no_fp8_capability() {
    let caps = SmVersion::Sm80.capabilities();

    let advertises_fp8 = caps.has_fp8;
    assert!(!advertises_fp8, "SM 80 must not advertise FP8 capability");

    let advertises_fp6_fp4 = caps.has_fp6_fp4;
    assert!(
        !advertises_fp6_fp4,
        "SM 80 must not advertise FP6/FP4 capability"
    );
}
/// For every SM version listed, the minimal kernel must generate successfully
/// and its PTX must contain `.address_size 64`.
#[test]
fn p8_all_sm_produce_address_size_64() {
    for sm in [
        SmVersion::Sm75,
        SmVersion::Sm80,
        SmVersion::Sm86,
        SmVersion::Sm89,
        SmVersion::Sm90,
        SmVersion::Sm90a,
        SmVersion::Sm100,
        SmVersion::Sm120,
    ] {
        let ptx = match generate_simple_kernel_ptx(sm) {
            Ok(text) => text,
            Err(e) => panic!("PTX generation failed for {sm}: {e}"),
        };
        let is_64_bit = ptx.contains(".address_size 64");
        assert!(
            is_64_bit,
            "{sm} PTX must contain .address_size 64; got:\n{ptx}"
        );
    }
}
/// For every SM version listed, the minimal kernel must generate successfully
/// and its PTX must be non-empty and longer than 50 bytes.
#[test]
fn p8_all_sm_produce_non_empty_ptx() {
    for sm in [
        SmVersion::Sm75,
        SmVersion::Sm80,
        SmVersion::Sm86,
        SmVersion::Sm89,
        SmVersion::Sm90,
        SmVersion::Sm90a,
        SmVersion::Sm100,
        SmVersion::Sm120,
    ] {
        let ptx = match generate_simple_kernel_ptx(sm) {
            Ok(text) => text,
            Err(e) => panic!("PTX generation failed for {sm}: {e}"),
        };

        let is_blank = ptx.is_empty();
        assert!(!is_blank, "PTX for {sm} must not be empty");

        let byte_len = ptx.len();
        assert!(
            byte_len > 50,
            "PTX for {sm} is suspiciously short: {ptx:?}"
        );
    }
}
}