use super::*;
use proptest::prelude::*;
use std::path::Path;
/// Filesystem location of the `kernel-fusion-v1.yaml` contract.
///
/// Assembled at compile time from `CARGO_MANIFEST_DIR` plus a relative hop two
/// directories up and into `aprender/contracts/`.
const CONTRACT_PATH: &str =
concat!(env!("CARGO_MANIFEST_DIR"), "/../../aprender/contracts/kernel-fusion-v1.yaml");
/// Loads the kernel-fusion contract YAML into a `String`.
///
/// # Panics
///
/// Panics when nothing exists at [`CONTRACT_PATH`], or when the file exists but
/// cannot be read as a string.
fn read_contract() -> String {
    let contract = Path::new(CONTRACT_PATH);
    if !contract.exists() {
        panic!(
            "kernel-fusion-v1.yaml contract not found at {CONTRACT_PATH}. \
             Ensure aprender is checked out as a sibling of trueno."
        );
    }
    match std::fs::read_to_string(contract) {
        Ok(text) => text,
        Err(e) => panic!("Failed to read {CONTRACT_PATH}: {e}"),
    }
}
/// Builds the registry of fused kernels covered by the contract tests.
///
/// Each element pairs the Rust struct name (as a literal) with the string the
/// kernel reports from `name()`. Every kernel is constructed with fixed
/// constructor arguments; only its name is retained.
fn all_fused_kernels() -> Vec<(&'static str, String)> {
    let swiglu = FusedSwigluKernel::new(4096).name().to_string();
    let batched_swiglu = BatchedSwigluKernel::new(4096, 4).name().to_string();
    let qkv = FusedQKVKernel::new(3584, 512).name().to_string();
    let gate_up = FusedGateUpKernel::new(3584, 18944).name().to_string();
    let gemm_bias_gelu = FusedGemmBiasGeluKernel::new(512, 2048, 512).name().to_string();
    let rmsnorm_q4k_gemv = FusedRmsNormQ4KGemvKernel::new(3584, 3584).name().to_string();
    let gate_up_q4k_gemv = FusedGateUpQ4KGemvKernel::new(3584, 18944).name().to_string();
    let rmsnorm_gate_up_swiglu_q4k =
        FusedRmsNormGateUpSwigluQ4KKernel::new(3584, 18944).name().to_string();

    vec![
        ("FusedSwigluKernel", swiglu),
        ("BatchedSwigluKernel", batched_swiglu),
        ("FusedQKVKernel", qkv),
        ("FusedGateUpKernel", gate_up),
        ("FusedGemmBiasGeluKernel", gemm_bias_gelu),
        ("FusedRmsNormQ4KGemvKernel", rmsnorm_q4k_gemv),
        ("FusedGateUpQ4KGemvKernel", gate_up_q4k_gemv),
        ("FusedRmsNormGateUpSwigluQ4KKernel", rmsnorm_gate_up_swiglu_q4k),
    ]
}
/// FALSIFY-FUSION-001: every kernel in [`all_fused_kernels`] must be mentioned in
/// the contract, either by its struct name or by the name it reports at runtime.
///
/// The check is a plain substring search over the whole YAML text.
#[test]
fn falsify_fusion_001_every_fused_kernel_has_yaml_entry() {
    let yaml = read_contract();
    for (struct_name, kernel_name) in all_fused_kernels() {
        if yaml.contains(struct_name) || yaml.contains(kernel_name.as_str()) {
            continue;
        }
        panic!(
            "Fused kernel '{struct_name}' (name='{kernel_name}') is NOT documented \
             in kernel-fusion-v1.yaml. Every fused kernel MUST have a contract entry. \
             Add a fusion_decisions entry for this kernel."
        );
    }
}
/// Decides whether `line` opens a new fusion entry in the contract YAML.
///
/// `line` is the raw line and `trimmed` is the same line with surrounding
/// whitespace removed. A line qualifies when it:
///
/// * is indented by two or three spaces, i.e. exactly one nesting level,
/// * is neither blank, a `#` comment, nor a `-` list item, and
/// * ends with `:` (a mapping key with no inline value).
///
/// The previous implementation tested `starts_with` and `!starts_with` against
/// the same literal, which can never both hold, so no line was ever recognised
/// as an entry key. The indentation is now measured explicitly.
///
/// NOTE(review): the 2–3 space window assumes entry keys sit one two-space level
/// below a top-level key and their fields at four or more spaces — confirm
/// against the actual contract layout.
fn is_fusion_entry_key(line: &str, trimmed: &str) -> bool {
    // Count leading spaces only; tabs are not valid YAML indentation.
    let indent = line.len() - line.trim_start_matches(' ').len();
    (2..4).contains(&indent)
        && !trimmed.is_empty()
        && !trimmed.starts_with('#')
        && !trimmed.starts_with('-')
        && trimmed.ends_with(':')
}
/// Returns `true` when `trimmed` is a `prefix` line that carries a real value.
///
/// `trimmed` is a whitespace-trimmed YAML line and `prefix` is the key including
/// its colon (for example `"fused_tok_s:"`). The function returns `false` when:
///
/// * the line does not start with `prefix`,
/// * nothing follows the key,
/// * the value is a YAML null (`null`, `Null`, `NULL` or `~`), or
/// * everything after the key is an inline `#` comment.
///
/// Inline comments are stripped before the null check, so
/// `fused_tok_s: null  # not measured` is correctly treated as missing instead
/// of being mistaken for a measurement.
fn has_valid_tok_s_value(trimmed: &str, prefix: &str) -> bool {
    let Some(rest) = trimmed.strip_prefix(prefix) else {
        return false;
    };

    // In YAML a `#` begins a comment when it is the first character of the value
    // or is preceded by whitespace; cut the scalar off at that point.
    let mut prev_is_space = true;
    let mut end = rest.len();
    for (idx, ch) in rest.char_indices() {
        if ch == '#' && prev_is_space {
            end = idx;
            break;
        }
        prev_is_space = ch.is_whitespace();
    }

    let value = rest[..end].trim();
    !value.is_empty() && !matches!(value, "null" | "Null" | "NULL" | "~")
}
/// FALSIFY-FUSION-002: every contract entry whose status is `BLOCKED` must carry
/// non-null `unfused_tok_s` and `fused_tok_s` values.
///
/// The contract is scanned line by line. An entry-key line (see
/// [`is_fusion_entry_key`]) closes the previous entry — asserting on it if it was
/// BLOCKED — and resets the per-entry state. The last entry is checked after the
/// loop, and at least one BLOCKED entry must have been seen overall.
#[test]
fn falsify_fusion_002_blocked_entries_have_benchmarks() {
    let yaml = read_contract();

    let mut current_entry_name = String::new();
    let mut in_blocked_section = false;
    let (mut found_unfused, mut found_fused) = (false, false);
    let mut checked_entries = 0;

    for line in yaml.lines() {
        let trimmed = line.trim();

        if is_fusion_entry_key(line, trimmed) {
            // A new entry starts here: settle the one that just ended.
            if in_blocked_section {
                assert!(
                    found_unfused && found_fused,
                    "BLOCKED entry '{current_entry_name}' is missing benchmark data. \
                     unfused_tok_s present: {found_unfused}, fused_tok_s present: {found_fused}. \
                     BLOCKED fusions MUST have measured tok/s for both fused and unfused paths."
                );
                checked_entries += 1;
            }
            current_entry_name = trimmed.trim_end_matches(':').to_string();
            in_blocked_section = false;
            found_unfused = false;
            found_fused = false;
        }

        // Once a BLOCKED status line is seen the flag stays set until the next entry key.
        in_blocked_section |= trimmed.contains("status:") && trimmed.contains("BLOCKED");
        if !in_blocked_section {
            continue;
        }

        found_unfused |= has_valid_tok_s_value(trimmed, "unfused_tok_s:");
        found_fused |= has_valid_tok_s_value(trimmed, "fused_tok_s:");
    }

    // The final entry has no following key line to trigger its check.
    if in_blocked_section {
        assert!(
            found_unfused && found_fused,
            "BLOCKED entry '{current_entry_name}' is missing benchmark data. \
             unfused_tok_s present: {found_unfused}, fused_tok_s present: {found_fused}."
        );
        checked_entries += 1;
    }

    assert!(
        checked_entries > 0,
        "No BLOCKED entries found in contract — expected at least FUSION-003 \
         (rmsnorm_gate_up_swiglu_fused_q4k). Is the contract format changed?"
    );
}
/// Classifies a `call_site:` line from the contract.
///
/// Returns `None` when `trimmed` is not a `call_site:` line at all. Otherwise
/// returns `Some(true)` when the value names a call site and `Some(false)` when
/// it is effectively absent, meaning any of:
///
/// * empty, or only whitespace once surrounding `"`/`'` quotes are removed,
/// * a YAML null (`null`, `Null`, `NULL` or `~`),
/// * text containing the `NOT WIRED` sentinel.
///
/// Null and blank-quoted values used to be reported as wired; they are now
/// rejected so an entry cannot pass with a placeholder call site.
fn is_valid_call_site(trimmed: &str) -> Option<bool> {
    let raw = trimmed.strip_prefix("call_site:")?;
    let value = raw
        .trim()
        .trim_matches(|c: char| c == '"' || c == '\'')
        .trim();
    let is_null = matches!(value, "null" | "Null" | "NULL" | "~");
    Some(!value.is_empty() && !is_null && !value.contains("NOT WIRED"))
}
/// FALSIFY-FUSION-003: every contract entry whose status is `ACTIVE` must have a
/// `call_site` that [`is_valid_call_site`] accepts.
///
/// Entries are delimited by [`is_fusion_entry_key`] lines. When several
/// `call_site:` lines appear inside one ACTIVE entry, the last one decides the
/// wired flag. At least five ACTIVE entries must be checked in total.
#[test]
fn falsify_fusion_003_active_entries_have_call_site() {
    let yaml = read_contract();

    let mut current_entry_name = String::new();
    let mut in_active_section = false;
    let (mut found_call_site, mut call_site_is_wired) = (false, false);
    let mut checked_entries = 0;

    for line in yaml.lines() {
        let trimmed = line.trim();

        if is_fusion_entry_key(line, trimmed) {
            // Entry boundary: validate the entry that just finished, then reset.
            if in_active_section {
                assert!(
                    found_call_site && call_site_is_wired,
                    "ACTIVE entry '{current_entry_name}' has no valid call_site. \
                     found_call_site: {found_call_site}, is_wired: {call_site_is_wired}. \
                     ACTIVE fusions MUST have a call_site that actually dispatches the kernel."
                );
                checked_entries += 1;
            }
            current_entry_name = trimmed.trim_end_matches(':').to_string();
            in_active_section = false;
            found_call_site = false;
            call_site_is_wired = false;
        }

        in_active_section |= trimmed.contains("status:") && trimmed.contains("ACTIVE");

        if let (true, Some(is_wired)) = (in_active_section, is_valid_call_site(trimmed)) {
            found_call_site = true;
            call_site_is_wired = is_wired;
        }
    }

    // The trailing entry is not followed by another key, so check it here.
    if in_active_section {
        assert!(
            found_call_site && call_site_is_wired,
            "ACTIVE entry '{current_entry_name}' has no valid call_site. \
             found_call_site: {found_call_site}, is_wired: {call_site_is_wired}."
        );
        checked_entries += 1;
    }

    assert!(
        checked_entries >= 5,
        "Only found {checked_entries} ACTIVE entries with valid call_sites — \
         expected at least 5. Has the contract format changed?"
    );
}
/// FALSIFY-FUSION-004: fusion decisions must not live only in code comments.
///
/// Recursively scans `src/kernels` under the crate manifest directory for
/// comment lines that match one of the suspect patterns and carry none of the
/// contract reference markers. This test file itself is skipped by file name.
#[test]
fn falsify_fusion_004_no_comment_only_decisions() {
    let kernels_dir = concat!(env!("CARGO_MANIFEST_DIR"), "/src/kernels");
    let kernels_path = Path::new(kernels_dir);
    assert!(kernels_path.is_dir(), "Kernels directory not found at {kernels_dir}");

    // Markers that tie a comment back to the contract.
    let contract_references: &[&str] =
        &["kernel-fusion-v1", "FUSION-0", "fusion_decisions", "F-FUSION-001"];

    // Lower-case phrase pairs that suggest a fusion decision was recorded in prose.
    let suspect_patterns: &[&str] = &[
        "fused.*blocked",
        "fused.*disabled",
        "fused.*slower",
        "fusion.*blocked",
        "fusion.*disabled",
        "don't use.*fused",
        "do not use.*fused",
    ];

    // Name of this file, with a fixed fallback when it cannot be expressed as UTF-8.
    let skip_filename = Path::new(file!())
        .file_name()
        .unwrap_or_default()
        .to_str()
        .unwrap_or("fusion_contract_falsify.rs");

    let mut violations = Vec::new();
    scan_directory_for_comment_violations(
        kernels_path,
        suspect_patterns,
        contract_references,
        skip_filename,
        &mut violations,
    );

    assert!(
        violations.is_empty(),
        "Found fusion decisions in code comments WITHOUT contract references:\n{}",
        violations.join("\n")
    );
}
/// Tests whether `lower` matches a minimal wildcard `pattern`.
///
/// `pattern` is a sequence of literal fragments separated by `.*`, each `.*`
/// standing for "any text, possibly empty". The fragments must occur in
/// `lower` in order and without overlapping. No other regex syntax is
/// interpreted. Callers are expected to pass an already lower-cased haystack.
///
/// The pattern is split on the two-character token `.*`. Splitting on `*` alone
/// left a literal `.` glued to the preceding fragment, so `"fused.*blocked"`
/// only matched text containing `"fused."` and missed ordinary sentences such
/// as "fused kernel is blocked".
fn matches_suspect_pattern(lower: &str, pattern: &str) -> bool {
    let mut rest = lower;
    for fragment in pattern.split(".*") {
        match rest.find(fragment) {
            // Continue searching strictly after this fragment's match.
            Some(at) => rest = &rest[at + fragment.len()..],
            None => return false,
        }
    }
    true
}
/// Scans one file for comment lines that look like undocumented fusion decisions.
///
/// Only lines whose trimmed text begins with `//` are examined (this includes
/// `///` doc comments). A line is a violation when its lower-cased form matches
/// any of `suspect_patterns` and the line contains none of
/// `contract_references`. Each violation is appended to `violations` as
/// ` <path>:<1-based line>: <trimmed line>`.
///
/// Files that cannot be read as a string are skipped silently.
fn check_file_for_comment_violations(
    path: &Path,
    suspect_patterns: &[&str],
    contract_references: &[&str],
    violations: &mut Vec<String>,
) {
    let content = match std::fs::read_to_string(path) {
        Ok(text) => text,
        Err(_) => return,
    };

    let flagged = content.lines().enumerate().filter_map(|(idx, raw)| {
        let comment = raw.trim();
        if !comment.starts_with("//") {
            return None;
        }

        let lowered = comment.to_lowercase();
        if !suspect_patterns.iter().any(|pat| matches_suspect_pattern(&lowered, pat)) {
            return None;
        }

        // Reference markers are matched case-sensitively against the original text.
        if contract_references.iter().any(|marker| comment.contains(marker)) {
            return None;
        }

        Some(format!(" {}:{}: {}", path.display(), idx + 1, comment))
    });

    violations.extend(flagged);
}
/// Walks `dir` recursively and checks every `.rs` file for comment violations.
///
/// Sub-directories are descended into. Files whose name equals `skip_filename`
/// are ignored, and every other `.rs` file is handed to
/// [`check_file_for_comment_violations`]. Unreadable directories and directory
/// entries that fail to load are skipped without error.
fn scan_directory_for_comment_violations(
    dir: &Path,
    suspect_patterns: &[&str],
    contract_references: &[&str],
    skip_filename: &str,
    violations: &mut Vec<String>,
) {
    let entries = match std::fs::read_dir(dir) {
        Ok(listing) => listing,
        Err(_) => return,
    };

    let rust_ext = std::ffi::OsStr::new("rs");
    let skipped = std::ffi::OsStr::new(skip_filename);

    for path in entries.flatten().map(|entry| entry.path()) {
        if path.is_dir() {
            scan_directory_for_comment_violations(
                &path,
                suspect_patterns,
                contract_references,
                skip_filename,
                violations,
            );
            continue;
        }

        let is_rust_source = path.extension() == Some(rust_ext);
        if !is_rust_source || path.file_name() == Some(skipped) {
            continue;
        }

        check_file_for_comment_violations(&path, suspect_patterns, contract_references, violations);
    }
}
/// FALSIFY-FUSION-005: every kernel struct named in a contract `fused:` field
/// must correspond to a struct listed in [`all_fused_kernels`].
///
/// For each line starting with `fused:`, the first whitespace-separated word of
/// the (unquoted) value is taken as the struct name. The test fails if no such
/// names are found, or if any of them is unknown.
#[test]
fn falsify_fusion_005_orphan_detection() {
    let yaml = read_contract();

    let yaml_kernel_names: Vec<String> = yaml
        .lines()
        .map(str::trim)
        .filter(|entry| entry.starts_with("fused:"))
        .filter_map(|entry| {
            entry
                .trim_start_matches("fused:")
                .trim()
                .trim_matches('"')
                .split_whitespace()
                .next()
        })
        .map(str::to_string)
        .collect();

    assert!(
        !yaml_kernel_names.is_empty(),
        "No kernel names found in YAML `fused:` fields — contract parsing may be broken."
    );

    let known_kernel_structs: Vec<&str> =
        all_fused_kernels().into_iter().map(|(struct_name, _)| struct_name).collect();

    let orphans: Vec<String> = yaml_kernel_names
        .iter()
        .filter(|candidate| !known_kernel_structs.contains(&candidate.as_str()))
        .cloned()
        .collect();

    assert!(
        orphans.is_empty(),
        "Orphaned contract entries — these kernel struct names appear in the YAML \
         but cannot be instantiated in trueno-gpu: {:?}. \
         Either the kernel was deleted (remove the contract entry) or renamed \
         (update the contract entry).",
        orphans
    );
}
proptest! {
    // Property for FALSIFY-FUSION-001: a random lowercase identifier must neither
    // equal a real kernel's reported name nor appear as a ` <name>:` key in the
    // contract text.
    #[test]
    fn falsify_fusion_001_prop(random_name in "[a-z]{5,20}") {
        let yaml = read_contract();
        let is_known = all_fused_kernels().iter().any(|(_, name)| *name == random_name);
        prop_assert!(
            !is_known,
            "Random name '{random_name}' collided with a real kernel name — \
             kernel naming scheme may be too generic."
        );
        let as_key = format!(" {}:", random_name);
        prop_assert!(
            !yaml.contains(&as_key),
            "Random name '{random_name}' matched a YAML fusion_decisions key — \
             contract keys should use specific, descriptive names."
        );
    }

    // Property for FALSIFY-FUSION-002: a synthetic BLOCKED entry is valid exactly
    // when both tok/s fields hold a number. The lines are classified with
    // `has_valid_tok_s_value` — the same helper the contract test uses — so a
    // regression in that helper is caught here instead of being masked by a
    // private copy of the parsing logic.
    #[test]
    fn falsify_fusion_002_prop(
        unfused_present in proptest::bool::ANY,
        fused_present in proptest::bool::ANY,
        unfused_val in 1.0f64..500.0,
        fused_val in 1.0f64..500.0,
    ) {
        let unfused_line = if unfused_present {
            format!(" unfused_tok_s: {unfused_val:.1}")
        } else {
            " unfused_tok_s: null".to_string()
        };
        let fused_line = if fused_present {
            format!(" fused_tok_s: {fused_val:.1}")
        } else {
            " fused_tok_s: null".to_string()
        };
        let synthetic_entry = format!(
            " test_blocked_entry:\n\
             \x20 status: \"BLOCKED\"\n\
             \x20 benchmark:\n\
             {unfused_line}\n\
             {fused_line}\n"
        );

        let is_valid = unfused_present && fused_present;

        let parsed_unfused = synthetic_entry
            .lines()
            .any(|line| has_valid_tok_s_value(line.trim(), "unfused_tok_s:"));
        let parsed_fused = synthetic_entry
            .lines()
            .any(|line| has_valid_tok_s_value(line.trim(), "fused_tok_s:"));
        let parsed_valid = parsed_unfused && parsed_fused;

        prop_assert_eq!(
            is_valid,
            parsed_valid,
            "Benchmark validation mismatch: expected valid={}, \
             parsed valid={} (unfused={}, fused={})",
            is_valid,
            parsed_valid,
            unfused_present,
            fused_present
        );
    }

    // Property for FALSIFY-FUSION-003: `is_valid_call_site` must accept any
    // `file:line`-shaped path and must reject the `NOT WIRED` sentinel and empty
    // values. The assertions call the real helper rather than restating facts
    // about string literals.
    #[test]
    fn falsify_fusion_003_prop(
        typo_segment in "(src|lib|mod|main|test|bench)",
        extension in "(rs|py|toml|yaml)",
        line_num in 1u32..9999,
    ) {
        let typo_path = format!("realizar/{typo_segment}/nonexistent.{extension}:{line_num}");
        prop_assert!(
            typo_path.contains(':'),
            "Call site should have file:line format, got: {typo_path}"
        );

        let wired_line = format!("call_site: \"{typo_path}\"");
        prop_assert_eq!(
            is_valid_call_site(&wired_line),
            Some(true),
            "A file:line call_site must be accepted as wired, got line: {}",
            wired_line
        );

        prop_assert_eq!(
            is_valid_call_site("call_site: \"NOT WIRED -- see PAR-077\""),
            Some(false),
            "'NOT WIRED' sentinel must be detectable in call_site values"
        );
        prop_assert_eq!(
            is_valid_call_site("call_site: \"\""),
            Some(false),
            "Empty call_site must be detectable"
        );

        // A bare path without the `call_site:` key is not a call_site line at all.
        prop_assert_eq!(
            is_valid_call_site(&typo_path),
            None,
            "Lines without a call_site key must not be classified, got: {}",
            typo_path
        );
    }
}