#[allow(dead_code)]
const _PROBE_MARKERS: &str = "cumoduleloaddata";
use cudarc::driver::sys::CUresult;
use smallvec::SmallVec;
use super::module_cache::{load_cuda_module_data, unload_cuda_module};
pub(crate) fn select_loadable_ptx_target_sm(native_sm: u32) -> Result<u32, String> {
let candidates = ptx_target_candidates(native_sm);
let mut failures = SmallVec::<[(u32, CUresult); 10]>::new();
for candidate in candidates {
match probe_ptx_target_sm(candidate) {
Ok(()) => return Ok(candidate),
Err(result) => failures.push((candidate, result)),
}
}
let mut message =
format!("CUDA driver rejected every PTX target candidate for native sm_{native_sm}: ");
for (index, (candidate, result)) in failures.iter().enumerate() {
if index > 0 {
message.push_str(", ");
}
use std::fmt::Write as _;
let _ = write!(message, "sm_{candidate}: {result:?}");
}
message.push_str(
". Fix: update the CUDA driver/PTX emitter pair so at least one modern PTX target can be JIT-loaded.",
);
Err(message)
}
fn ptx_target_candidates(native_sm: u32) -> SmallVec<[u32; 10]> {
let mut candidates = SmallVec::<[u32; 10]>::new();
push_candidate(&mut candidates, native_sm, native_sm);
for candidate in [89, 86, 80, 75, 70] {
push_candidate(&mut candidates, candidate, native_sm);
}
candidates
}
fn push_candidate(candidates: &mut SmallVec<[u32; 10]>, candidate: u32, native_sm: u32) {
if candidate == 0 || candidate > native_sm || candidates.contains(&candidate) {
return;
}
candidates.push(candidate);
}
fn probe_ptx_target_sm(target_sm: u32) -> Result<(), CUresult> {
let ptx_version = match target_sm {
120..=u32::MAX => "8.7",
100..=119 => "8.6",
90..=99 => "8.0",
_ => "8.5",
};
let ptx = format!(
".version {ptx_version}\n.target sm_{target_sm}\n.address_size 64\n\n.visible .entry main(.param .u64 buf) {{\n\t.reg .b64 %rd<3>;\n\t.reg .b32 %r<3>;\n\tld.param.u64 %rd1, [buf];\n\tcvta.to.global.u64 %rd2, %rd1;\n\tmov.u32 %r1, 1;\n\tatom.global.add.u32 %r2, [%rd2], %r1;\n\tbar.sync 0;\n\tret;\n}}\n"
);
let cstring = std::ffi::CString::new(ptx).map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?;
let module = load_cuda_module_data(cstring.as_bytes_with_nul())?;
unload_cuda_module(module)
}
#[cfg(test)]
mod tests {
use super::ptx_target_candidates;
#[test]
fn ptx_target_candidates_preserve_preferred_order_without_sort_or_dedup() {
assert_eq!(
ptx_target_candidates(120).as_slice(),
&[120, 89, 86, 80, 75, 70]
);
assert_eq!(ptx_target_candidates(89).as_slice(), &[89, 86, 80, 75, 70]);
assert_eq!(ptx_target_candidates(70).as_slice(), &[70]);
}
#[test]
fn ptx_target_selection_source_avoids_heap_staged_failure_strings_and_sorting() {
let source = include_str!("ptx_target.rs");
assert!(
!source.contains(concat!("Vec::with_capacity", "(candidates.len())"))
&& !source.contains(concat!("failures", ".join"))
&& !source.contains(concat!("format!(\"", "sm_{candidate}")),
"Fix: CUDA PTX target probing must format one final diagnostic instead of allocating one String per failed candidate."
);
assert!(
!source.contains(concat!(".", "sort_unstable_by"))
&& !source.contains(concat!(".", "dedup()")),
"Fix: CUDA PTX target candidates are a fixed preference list; backend acquisition must not sort/dedup them on the hot acquisition path."
);
}
}