use super::types::{Category, FalsificationReport, FalsificationTest, TestResult};
use crate::analyzer::{AddressSpaceValidator, ControlFlowAnalyzer, DataFlowAnalyzer, TypeChecker};
use crate::parser::types::SmTarget;
use crate::parser::PtxModule;
/// Ordered collection of falsification tests that can be evaluated against a
/// parsed [`PtxModule`].
pub struct FalsificationRegistry {
    /// Registered tests in registration order; `evaluate` runs them and reports
    /// their results in this same order.
    tests: Vec<FalsificationTest>,
}
impl FalsificationRegistry {
    /// Creates a registry pre-populated with the full `F001`–`F100` test suite.
    pub fn new() -> Self {
        let mut registry = Self { tests: Vec::new() };
        registry.register_all_tests();
        registry
    }

    /// Registers every test group.
    ///
    /// The call order here fixes the order in which tests are evaluated and
    /// reported, so groups are registered in ascending ID order.
    fn register_all_tests(&mut self) {
        self.register_syntax_tests();
        self.register_type_safety_tests();
        self.register_address_space_tests();
        self.register_barrier_tests();
        self.register_stub_range(51..=60, Category::MemoryModel, "Memory model check");
        self.register_control_flow_tests();
        self.register_data_flow_tests();
        self.register_known_bug_tests();
        self.register_stub_range(91..=95, Category::Performance, "Performance check");
        self.register_instrumentation_tests();
    }

    /// F001–F010: module-level syntax checks (`.version`, `.target`,
    /// `address_size`, label uniqueness) plus placeholders.
    fn register_syntax_tests(&mut self) {
        self.add(FalsificationTest::new(
            "F001",
            Category::SyntaxValidity,
            "PTX contains .version directive",
            1,
            |m| {
                if m.version.0 > 0 {
                    TestResult::Pass
                } else {
                    TestResult::Fail {
                        evidence: "Missing .version directive".into(),
                        location: None,
                    }
                }
            },
        ));
        self.add(FalsificationTest::new(
            "F002",
            Category::SyntaxValidity,
            "PTX contains .target directive",
            1,
            |m| {
                if m.target != SmTarget::Unknown {
                    TestResult::Pass
                } else {
                    TestResult::Fail {
                        evidence: "Missing .target directive".into(),
                        location: None,
                    }
                }
            },
        ));
        self.add(FalsificationTest::new(
            "F003",
            Category::SyntaxValidity,
            "address_size is 32 or 64",
            1,
            |m| {
                if m.address_size == 32 || m.address_size == 64 {
                    TestResult::Pass
                } else {
                    TestResult::Fail {
                        evidence: format!("Invalid address_size: {}", m.address_size),
                        location: None,
                    }
                }
            },
        ));
        self.add(FalsificationTest::new(
            "F004",
            Category::SyntaxValidity,
            "All labels are unique",
            1,
            |m| {
                // NOTE(review): a single set spans every kernel, so the same
                // label used in two different kernels is reported as a
                // duplicate. Confirm whether per-kernel scoping is intended.
                let mut labels = std::collections::HashSet::new();
                for kernel in &m.kernels {
                    for stmt in &kernel.body {
                        if let crate::parser::Statement::Label(label) = stmt {
                            if !labels.insert(label.clone()) {
                                return TestResult::Fail {
                                    evidence: format!("Duplicate label: {}", label),
                                    location: None,
                                };
                            }
                        }
                    }
                }
                TestResult::Pass
            },
        ));
        self.register_stub_range(5..=10, Category::SyntaxValidity, "Syntax validity check");
    }

    /// F011–F020: type checking via [`TypeChecker`] plus placeholders.
    fn register_type_safety_tests(&mut self) {
        self.add(FalsificationTest::new(
            "F011",
            Category::TypeSafety,
            "Load dest type matches instruction type",
            1,
            |m| {
                let mut checker = TypeChecker::new();
                let errors = checker.analyze(m);
                if errors.is_empty() {
                    TestResult::Pass
                } else {
                    TestResult::Fail {
                        evidence: format!("{} type errors found", errors.len()),
                        location: errors.first().map(|e| e.location.clone()),
                    }
                }
            },
        ));
        self.register_stub_range(12..=20, Category::TypeSafety, "Type safety check");
    }

    /// F021–F035: address-space validation via [`AddressSpaceValidator`] plus
    /// placeholders.
    fn register_address_space_tests(&mut self) {
        self.add(FalsificationTest::new(
            "F021",
            Category::AddressSpace,
            "No cvta.shared followed by generic ld/st",
            2,
            |m| {
                let mut validator = AddressSpaceValidator::new();
                let bugs = validator.detect_generic_shared_access(m);
                if bugs.is_empty() {
                    TestResult::Pass
                } else {
                    TestResult::Fail {
                        evidence: format!("{} generic shared access patterns found", bugs.len()),
                        location: bugs.first().map(|b| b.location.clone()),
                    }
                }
            },
        ));
        self.register_stub_range(22..=35, Category::AddressSpace, "Address space check");
    }

    /// F036–F050: barrier placement checks plus placeholders.
    fn register_barrier_tests(&mut self) {
        self.add(FalsificationTest::new(
            "F036",
            Category::BarrierSafety,
            "bar.sync after shared write, before read",
            3,
            |m| {
                let mut analyzer = ControlFlowAnalyzer::new();
                // The CFG value is discarded. This presumably primes analyzer
                // state used by `analyze_barriers`; TODO confirm. Only the
                // first kernel's CFG is built.
                if let Some(kernel) = m.kernels.first() {
                    let _ = analyzer.build_cfg(kernel);
                }
                let violations = analyzer.analyze_barriers(m);
                if violations.is_empty() {
                    TestResult::Pass
                } else {
                    TestResult::Fail {
                        evidence: format!("{} barrier violations found", violations.len()),
                        location: violations.first().map(|v| v.write_loc.clone()),
                    }
                }
            },
        ));
        self.register_stub_range(37..=50, Category::BarrierSafety, "Barrier safety check");
    }

    /// F061–F070: CFG-based checks (exit reachability, dead code) plus
    /// placeholders.
    ///
    /// NOTE(review): F061 and F062 inspect only `m.kernels.first()`; later
    /// kernels are not analysed.
    fn register_control_flow_tests(&mut self) {
        self.add(FalsificationTest::new(
            "F061",
            Category::ControlFlow,
            "All code paths reach ret or exit",
            2,
            |m| {
                let mut analyzer = ControlFlowAnalyzer::new();
                if let Some(kernel) = m.kernels.first() {
                    let cfg = analyzer.build_cfg(kernel);
                    // An empty CFG has nothing to exit from, so it is not a failure.
                    if cfg.exits.is_empty() && !cfg.nodes.is_empty() {
                        return TestResult::Fail {
                            evidence: "No exit nodes found in CFG".into(),
                            location: None,
                        };
                    }
                }
                TestResult::Pass
            },
        ));
        self.add(FalsificationTest::new(
            "F062",
            Category::ControlFlow,
            "No unreachable code",
            1,
            |m| {
                let mut analyzer = ControlFlowAnalyzer::new();
                if let Some(kernel) = m.kernels.first() {
                    let cfg = analyzer.build_cfg(kernel);
                    let unreachable = cfg.find_unreachable();
                    if !unreachable.is_empty() {
                        return TestResult::Fail {
                            evidence: format!("{} unreachable nodes found", unreachable.len()),
                            location: None,
                        };
                    }
                }
                TestResult::Pass
            },
        ));
        self.register_stub_range(63..=70, Category::ControlFlow, "Control flow check");
    }

    /// F071–F080: data-flow checks.
    ///
    /// F071 is currently a placeholder that always passes.
    fn register_data_flow_tests(&mut self) {
        self.add(FalsificationTest::new(
            "F071",
            Category::DataFlow,
            "No use before def",
            2,
            |_m| TestResult::Pass,
        ));
        self.register_stub_range(72..=80, Category::DataFlow, "Data flow check");
    }

    /// F081–F090: detectors for specific known-bug patterns plus placeholders.
    fn register_known_bug_tests(&mut self) {
        // F081 is worth 0 points and always passes. Its detector still runs,
        // but the findings are ignored because F082 supersedes it.
        self.add(FalsificationTest::new(
            "F081",
            Category::KnownBugs,
            "No 'loaded value' bug pattern (FALSIFIED - See F082)",
            0,
            |m| {
                let analyzer = DataFlowAnalyzer::from_module(m);
                let _bugs = analyzer.detect_loaded_value_bug();
                TestResult::Pass
            },
        ));
        self.add(FalsificationTest::new(
            "F082",
            Category::KnownBugs,
            "No computed-address-from-loaded-value pattern (ptxas JIT bug)",
            2,
            |m| {
                let analyzer = DataFlowAnalyzer::from_module(m);
                let bugs = analyzer.detect_computed_addr_from_loaded();
                if bugs.is_empty() {
                    TestResult::Pass
                } else {
                    TestResult::Fail {
                        evidence: format!(
                            "{} computed-addr-from-loaded patterns: address computed from ld.shared used in store. \
                             Workarounds: membar.cta (simple kernels) or Kernel Fission (complex kernels)",
                            bugs.len()
                        ),
                        location: bugs.first().map(|b| b.load_location.clone()),
                    }
                }
            },
        ));
        self.add(FalsificationTest::new(
            "F083",
            Category::KnownBugs,
            "No cvta.shared in loop",
            1,
            |m| {
                let validator = AddressSpaceValidator::new();
                let bugs = validator.detect_loop_cvta_shared(m);
                if bugs.is_empty() {
                    TestResult::Pass
                } else {
                    TestResult::Fail {
                        evidence: format!("{} cvta.shared in loop patterns found", bugs.len()),
                        location: bugs.first().map(|b| b.location.clone()),
                    }
                }
            },
        ));
        self.register_stub_range(84..=90, Category::KnownBugs, "Known bug check");
    }

    /// F096–F100: instrumentation placeholders.
    ///
    /// These use the shared stub helper so their IDs get the same three-digit
    /// padding (`F096`) as every other test.
    fn register_instrumentation_tests(&mut self) {
        self.register_stub_range(96..=100, Category::Instrumentation, "Instrumentation check");
    }

    /// Registers one placeholder test for every number in `range`.
    ///
    /// Each placeholder is worth 1 point and always passes. IDs are
    /// zero-padded to at least three digits (`F005`, `F042`, `F100`), matching
    /// explicitly registered IDs such as `F001`.
    fn register_stub_range(
        &mut self,
        range: std::ops::RangeInclusive<u32>,
        category: Category,
        description: &'static str,
    ) {
        for i in range {
            // Width 3 with zero fill. A literal "F0" prefix would give "F05"
            // for single-digit numbers.
            let id = format!("F{:03}", i);
            self.add(FalsificationTest::new(
                &id,
                category,
                description,
                1,
                |_| TestResult::Pass,
            ));
        }
    }

    /// Appends `test` to the registry; it runs after all earlier registrations.
    pub fn add(&mut self, test: FalsificationTest) {
        self.tests.push(test);
    }

    /// Returns the registered tests in registration order.
    pub fn tests(&self) -> &[FalsificationTest] {
        &self.tests
    }

    /// Runs every registered test against `module` and builds a report.
    ///
    /// Scoring rules:
    /// - A passing test adds its points to both earned and total.
    /// - A failing test adds its points to the total only.
    /// - A `NotApplicable` test is excluded from both.
    ///
    /// `score` is `earned / total * 100`. It is `100.0` when no applicable
    /// points were scored.
    pub fn evaluate(&self, module: &PtxModule) -> FalsificationReport {
        let mut results = Vec::with_capacity(self.tests.len());
        let mut total_points: u32 = 0;
        let mut earned_points: u32 = 0;
        for test in &self.tests {
            let result = test.run(module);
            let points = test.points as u32;
            match &result {
                TestResult::Pass => {
                    total_points += points;
                    earned_points += points;
                }
                TestResult::Fail { .. } => total_points += points,
                TestResult::NotApplicable => {}
            }
            results.push((
                test.id.clone(),
                test.category,
                test.description.clone(),
                result,
            ));
        }
        let score = if total_points > 0 {
            (earned_points as f64 / total_points as f64) * 100.0
        } else {
            100.0
        };
        let confidence = calculate_confidence(earned_points, total_points, &results);
        FalsificationReport {
            results,
            score,
            earned_points,
            total_points,
            confidence,
        }
    }
}
impl Default for FalsificationRegistry {
fn default() -> Self {
Self::new()
}
}
/// Computes a confidence value in `[0.0, 0.99]` for a falsification run.
///
/// The value is the sum of three parts, capped at `0.99`:
/// - the raw pass ratio `earned / total`;
/// - a category bonus of up to `0.1`, proportional to the fraction of
///   [`Category::all`] whose results are all `Pass` or `NotApplicable`.
///   A category with no results counts as passed (vacuous `all`);
/// - a flat `0.1` bonus when every `F082` result passes. This bonus is also
///   granted when no `F082` result is present.
///
/// When `total` is zero (no applicable points), the function returns `0.99`
/// directly.
pub(super) fn calculate_confidence(
    earned: u32,
    total: u32,
    results: &[(String, Category, String, TestResult)],
) -> f64 {
    if total == 0 {
        return 0.99;
    }
    let base_score = f64::from(earned) / f64::from(total);

    let all_categories = Category::all();
    let categories_passed = all_categories
        .iter()
        .filter(|&cat| {
            results
                .iter()
                .filter(|(_, c, _, _)| c == cat)
                .all(|(_, _, _, r)| r.is_pass() || matches!(r, TestResult::NotApplicable))
        })
        .count();
    // Normalise by the real number of categories, not a hard-coded 10, so the
    // bonus stays within 0.1 if categories are added or removed. `max(1)`
    // guards against division by zero.
    let category_count = all_categories.len().max(1);
    let category_bonus = (categories_passed as f64 / category_count as f64) * 0.1;

    // F082 is the critical known-bug detector and earns its own bonus.
    let critical_bonus = if results
        .iter()
        .filter(|(id, _, _, _)| id == "F082")
        .all(|(_, _, _, r)| r.is_pass())
    {
        0.1
    } else {
        0.0
    };

    (base_score + category_bonus + critical_bonus).min(0.99)
}