/// Minimum number of per-stage entries that `verdict_from_qualifier_preservation`
/// requires in its `qualifier_at_stage` slice; a shorter slice yields
/// `CudaSafetyVerdict::Fail` regardless of its contents.
pub const AC_CUDASAFETY_MIN_TRANSFORM_STAGES: usize = 5;
/// Binary outcome produced by every `verdict_from_*` function in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CudaSafetyVerdict {
    /// Every condition checked by the verdict function held.
    Pass,
    /// At least one checked condition did not hold.
    Fail,
}
/// Kernel-FFI verdict (exercised by the `fcuda001` tests).
///
/// Returns [`CudaSafetyVerdict::Pass`] only when all four observations are `true`:
/// the input carries the global qualifier, the output contains `extern "C"`, the
/// output function name matches, and pointer parameters are raw `*mut`.
/// Any single `false` yields [`CudaSafetyVerdict::Fail`].
#[must_use]
// Four bool parameters exceed clippy's default `max-fn-params-bools` (3); they are
// kept as plain bools so the public signature stays unchanged for callers.
#[allow(clippy::fn_params_excessive_bools)]
pub fn verdict_from_kernel_ffi(
    has_global_qualifier: bool,
    output_has_extern_c: bool,
    output_has_matching_fn_name: bool,
    pointer_params_are_raw_mut: bool,
) -> CudaSafetyVerdict {
    let observations = (
        has_global_qualifier,
        output_has_extern_c,
        output_has_matching_fn_name,
        pointer_params_are_raw_mut,
    );
    match observations {
        (true, true, true, true) => CudaSafetyVerdict::Pass,
        _ => CudaSafetyVerdict::Fail,
    }
}
/// Host-transpilation verdict (exercised by the `fcuda002` tests).
///
/// Returns [`CudaSafetyVerdict::Pass`] only for the combination
/// `has_global_or_device == false`, `output_is_extern_c == false`,
/// `output_is_regular_fn == true`; every other combination is
/// [`CudaSafetyVerdict::Fail`].
///
/// NOTE(review): `has_global_or_device` presumably reports whether the input carries a
/// CUDA `__global__`/`__device__` qualifier — confirm against the caller.
#[must_use]
pub fn verdict_from_host_transpilation(
    has_global_or_device: bool,
    output_is_extern_c: bool,
    output_is_regular_fn: bool,
) -> CudaSafetyVerdict {
    match (has_global_or_device, output_is_extern_c, output_is_regular_fn) {
        (false, false, true) => CudaSafetyVerdict::Pass,
        _ => CudaSafetyVerdict::Fail,
    }
}
/// Qualifier-preservation verdict (exercised by the `fcuda003` tests).
///
/// `qualifier_at_stage[i]` records whether the qualifier was still present after
/// transform stage `i`. Returns [`CudaSafetyVerdict::Pass`] only when:
/// - `input_has_qualifier` is `true`,
/// - the slice has at least [`AC_CUDASAFETY_MIN_TRANSFORM_STAGES`] entries, and
/// - no entry is `false` (the qualifier never dropped out).
///
/// An empty slice therefore always fails on the length requirement.
#[must_use]
pub fn verdict_from_qualifier_preservation(
    input_has_qualifier: bool,
    qualifier_at_stage: &[bool],
) -> CudaSafetyVerdict {
    let enough_stages = qualifier_at_stage.len() >= AC_CUDASAFETY_MIN_TRANSFORM_STAGES;
    let never_dropped = !qualifier_at_stage.contains(&false);
    if input_has_qualifier && enough_stages && never_dropped {
        CudaSafetyVerdict::Pass
    } else {
        CudaSafetyVerdict::Fail
    }
}
/// Keyword-detection verdict (exercised by the `fcuda004` tests).
///
/// Returns [`CudaSafetyVerdict::Fail`] as soon as any of the three observations is
/// `false` — the source lacks the global token, the parser is not in C++ mode, or the
/// empty macros were not injected. Otherwise returns [`CudaSafetyVerdict::Pass`].
#[must_use]
pub fn verdict_from_keyword_detection(
    source_has_global_token: bool,
    parser_in_cpp_mode: bool,
    empty_macros_injected: bool,
) -> CudaSafetyVerdict {
    let any_missing = !source_has_global_token || !parser_in_cpp_mode || !empty_macros_injected;
    if any_missing {
        return CudaSafetyVerdict::Fail;
    }
    CudaSafetyVerdict::Pass
}
// Unit tests for the CUDA-safety verdict functions. The `fcuda00N_*` tests pin named
// scenarios; the `mutation_survey_*` tests enumerate input combinations so that flipping
// any single guard in a verdict function is caught.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn provenance_min_transform_stages_is_5() {
        assert_eq!(AC_CUDASAFETY_MIN_TRANSFORM_STAGES, 5);
    }

    #[test]
    fn fcuda001_pass_kernel_with_extern_c() {
        let v = verdict_from_kernel_ffi(true, true, true, true);
        assert_eq!(v, CudaSafetyVerdict::Pass);
    }

    #[test]
    fn fcuda001_fail_no_global_qualifier() {
        let v = verdict_from_kernel_ffi(false, true, true, true);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda001_fail_missing_extern_c() {
        let v = verdict_from_kernel_ffi(true, false, true, true);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda001_fail_fn_name_mismatch() {
        let v = verdict_from_kernel_ffi(true, true, false, true);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda001_fail_safe_pointers_not_raw() {
        let v = verdict_from_kernel_ffi(true, true, true, false);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda002_pass_host_function_normal_rust() {
        let v = verdict_from_host_transpilation(false, false, true);
        assert_eq!(v, CudaSafetyVerdict::Pass);
    }

    #[test]
    fn fcuda002_fail_global_treated_as_host() {
        let v = verdict_from_host_transpilation(true, false, true);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda002_fail_host_wrapped_extern_c() {
        let v = verdict_from_host_transpilation(false, true, false);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda002_fail_extern_c_even_with_regular_fn() {
        // Isolates the `output_is_extern_c` guard: every other input is at its passing
        // value, so only that guard can produce the Fail.
        let v = verdict_from_host_transpilation(false, true, true);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda002_fail_no_regular_fn() {
        let v = verdict_from_host_transpilation(false, false, false);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda003_pass_qualifier_survives_5_stages() {
        let v = verdict_from_qualifier_preservation(true, &[true; 5]);
        assert_eq!(v, CudaSafetyVerdict::Pass);
    }

    #[test]
    fn fcuda003_pass_qualifier_survives_more_stages() {
        let v = verdict_from_qualifier_preservation(true, &[true; 8]);
        assert_eq!(v, CudaSafetyVerdict::Pass);
    }

    #[test]
    fn fcuda003_fail_no_input_qualifier() {
        let v = verdict_from_qualifier_preservation(false, &[true; 5]);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda003_fail_qualifier_dropped_mid_chain() {
        let v = verdict_from_qualifier_preservation(true, &[true, true, false, true, true]);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda003_fail_too_few_stages() {
        let v = verdict_from_qualifier_preservation(true, &[true; 3]);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda003_fail_one_below_min_stages() {
        // Pins the exact threshold: one stage short of the minimum must still fail.
        let v = verdict_from_qualifier_preservation(
            true,
            &[true; AC_CUDASAFETY_MIN_TRANSFORM_STAGES - 1],
        );
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda003_fail_zero_stages() {
        let v = verdict_from_qualifier_preservation(true, &[]);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda004_pass_global_detected_cpp_mode_macros() {
        let v = verdict_from_keyword_detection(true, true, true);
        assert_eq!(v, CudaSafetyVerdict::Pass);
    }

    #[test]
    fn fcuda004_fail_no_global_token() {
        let v = verdict_from_keyword_detection(false, true, true);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda004_fail_parser_not_cpp_mode() {
        let v = verdict_from_keyword_detection(true, false, true);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn fcuda004_fail_empty_macros_not_injected() {
        let v = verdict_from_keyword_detection(true, true, false);
        assert_eq!(v, CudaSafetyVerdict::Fail);
    }

    #[test]
    fn mutation_survey_001_only_when_all_4_inputs_true() {
        for mask in 0_u8..16 {
            let q = mask & 1 != 0;
            let e = mask & 2 != 0;
            let f = mask & 4 != 0;
            let p = mask & 8 != 0;
            let v = verdict_from_kernel_ffi(q, e, f, p);
            let expected = if q && e && f && p {
                CudaSafetyVerdict::Pass
            } else {
                CudaSafetyVerdict::Fail
            };
            assert_eq!(v, expected, "mask={mask:04b}");
        }
    }

    #[test]
    fn mutation_survey_002_pass_only_when_no_qual_no_extern_yes_fn() {
        for mask in 0_u8..8 {
            let g = mask & 1 != 0;
            let e = mask & 2 != 0;
            let f = mask & 4 != 0;
            let v = verdict_from_host_transpilation(g, e, f);
            let expected = if !g && !e && f {
                CudaSafetyVerdict::Pass
            } else {
                CudaSafetyVerdict::Fail
            };
            assert_eq!(v, expected, "mask={mask:03b}");
        }
    }

    #[test]
    fn mutation_survey_003_any_single_dropped_stage_fails() {
        // For several chain lengths at or above the minimum, dropping the qualifier at
        // any single position must fail, including the first and last stage.
        for len in AC_CUDASAFETY_MIN_TRANSFORM_STAGES..=AC_CUDASAFETY_MIN_TRANSFORM_STAGES + 3 {
            for dropped in 0..len {
                let mut stages = vec![true; len];
                stages[dropped] = false;
                let v = verdict_from_qualifier_preservation(true, &stages);
                assert_eq!(v, CudaSafetyVerdict::Fail, "len={len} dropped={dropped}");
            }
        }
    }

    #[test]
    fn mutation_survey_004_pass_only_when_all_3_inputs_true() {
        for mask in 0_u8..8 {
            let g = mask & 1 != 0;
            let c = mask & 2 != 0;
            let m = mask & 4 != 0;
            let v = verdict_from_keyword_detection(g, c, m);
            let expected = if g && c && m {
                CudaSafetyVerdict::Pass
            } else {
                CudaSafetyVerdict::Fail
            };
            assert_eq!(v, expected, "mask={mask:03b}");
        }
    }

    #[test]
    fn realistic_healthy_decy_transpilation_passes_all_4() {
        let v1 = verdict_from_kernel_ffi(true, true, true, true);
        let v2 = verdict_from_host_transpilation(false, false, true);
        let v3 = verdict_from_qualifier_preservation(true, &[true; 5]);
        let v4 = verdict_from_keyword_detection(true, true, true);
        assert_eq!(v1, CudaSafetyVerdict::Pass);
        assert_eq!(v2, CudaSafetyVerdict::Pass);
        assert_eq!(v3, CudaSafetyVerdict::Pass);
        assert_eq!(v4, CudaSafetyVerdict::Pass);
    }

    #[test]
    fn realistic_pre_fix_all_4_failures() {
        let v1 = verdict_from_kernel_ffi(true, false, true, true);
        let v2 = verdict_from_host_transpilation(true, false, true);
        let v3 = verdict_from_qualifier_preservation(true, &[true, true, false, true, true]);
        let v4 = verdict_from_keyword_detection(true, false, true);
        assert_eq!(v1, CudaSafetyVerdict::Fail);
        assert_eq!(v2, CudaSafetyVerdict::Fail);
        assert_eq!(v3, CudaSafetyVerdict::Fail);
        assert_eq!(v4, CudaSafetyVerdict::Fail);
    }
}