1use super::types::{Category, FalsificationReport, FalsificationTest, TestResult};
4use crate::analyzer::{AddressSpaceValidator, ControlFlowAnalyzer, DataFlowAnalyzer, TypeChecker};
5use crate::parser::types::SmTarget;
6use crate::parser::PtxModule;
7
8pub struct FalsificationRegistry {
10 tests: Vec<FalsificationTest>,
11}
12
13impl FalsificationRegistry {
14 pub fn new() -> Self {
16 let mut registry = Self { tests: Vec::new() };
17 registry.register_all_tests();
18 registry
19 }
20
21 fn register_all_tests(&mut self) {
22 self.register_syntax_tests();
23 self.register_type_safety_tests();
24 self.register_address_space_tests();
25 self.register_barrier_tests();
26 self.register_stub_range(51..=60, Category::MemoryModel, "Memory model check");
27 self.register_control_flow_tests();
28 self.register_data_flow_tests();
29 self.register_known_bug_tests();
30 self.register_stub_range(91..=95, Category::Performance, "Performance check");
31 self.register_instrumentation_tests();
32 }
33
34 fn register_syntax_tests(&mut self) {
35 self.add(FalsificationTest::new(
36 "F001",
37 Category::SyntaxValidity,
38 "PTX contains .version directive",
39 1,
40 |m| {
41 if m.version.0 > 0 {
42 TestResult::Pass
43 } else {
44 TestResult::Fail {
45 evidence: "Missing .version directive".into(),
46 location: None,
47 }
48 }
49 },
50 ));
51
52 self.add(FalsificationTest::new(
53 "F002",
54 Category::SyntaxValidity,
55 "PTX contains .target directive",
56 1,
57 |m| {
58 if m.target != SmTarget::Unknown {
59 TestResult::Pass
60 } else {
61 TestResult::Fail {
62 evidence: "Missing .target directive".into(),
63 location: None,
64 }
65 }
66 },
67 ));
68
69 self.add(FalsificationTest::new(
70 "F003",
71 Category::SyntaxValidity,
72 "address_size is 32 or 64",
73 1,
74 |m| {
75 if m.address_size == 32 || m.address_size == 64 {
76 TestResult::Pass
77 } else {
78 TestResult::Fail {
79 evidence: format!("Invalid address_size: {}", m.address_size),
80 location: None,
81 }
82 }
83 },
84 ));
85
86 self.add(FalsificationTest::new(
87 "F004",
88 Category::SyntaxValidity,
89 "All labels are unique",
90 1,
91 |m| {
92 let mut labels = std::collections::HashSet::new();
93 for kernel in &m.kernels {
94 for stmt in &kernel.body {
95 if let crate::parser::Statement::Label(label) = stmt {
96 if !labels.insert(label.clone()) {
97 return TestResult::Fail {
98 evidence: format!("Duplicate label: {}", label),
99 location: None,
100 };
101 }
102 }
103 }
104 }
105 TestResult::Pass
106 },
107 ));
108
109 self.register_stub_range(5..=10, Category::SyntaxValidity, "Syntax validity check");
110 }
111
112 fn register_type_safety_tests(&mut self) {
113 self.add(FalsificationTest::new(
114 "F011",
115 Category::TypeSafety,
116 "Load dest type matches instruction type",
117 1,
118 |m| {
119 let mut checker = TypeChecker::new();
120 let errors = checker.analyze(m);
121 if errors.is_empty() {
122 TestResult::Pass
123 } else {
124 TestResult::Fail {
125 evidence: format!("{} type errors found", errors.len()),
126 location: errors.first().map(|e| e.location.clone()),
127 }
128 }
129 },
130 ));
131
132 self.register_stub_range(12..=20, Category::TypeSafety, "Type safety check");
133 }
134
135 fn register_address_space_tests(&mut self) {
136 self.add(FalsificationTest::new(
137 "F021",
138 Category::AddressSpace,
139 "No cvta.shared followed by generic ld/st",
140 2,
141 |m| {
142 let mut validator = AddressSpaceValidator::new();
143 let bugs = validator.detect_generic_shared_access(m);
144 if bugs.is_empty() {
145 TestResult::Pass
146 } else {
147 TestResult::Fail {
148 evidence: format!("{} generic shared access patterns found", bugs.len()),
149 location: bugs.first().map(|b| b.location.clone()),
150 }
151 }
152 },
153 ));
154
155 self.register_stub_range(22..=35, Category::AddressSpace, "Address space check");
156 }
157
158 fn register_barrier_tests(&mut self) {
159 self.add(FalsificationTest::new(
160 "F036",
161 Category::BarrierSafety,
162 "bar.sync after shared write, before read",
163 3,
164 |m| {
165 let mut analyzer = ControlFlowAnalyzer::new();
166 if let Some(kernel) = m.kernels.first() {
167 let _ = analyzer.build_cfg(kernel);
168 }
169 let violations = analyzer.analyze_barriers(m);
170 if violations.is_empty() {
171 TestResult::Pass
172 } else {
173 TestResult::Fail {
174 evidence: format!("{} barrier violations found", violations.len()),
175 location: violations.first().map(|v| v.write_loc.clone()),
176 }
177 }
178 },
179 ));
180
181 self.register_stub_range(37..=50, Category::BarrierSafety, "Barrier safety check");
182 }
183
184 fn register_control_flow_tests(&mut self) {
185 self.add(FalsificationTest::new(
186 "F061",
187 Category::ControlFlow,
188 "All code paths reach ret or exit",
189 2,
190 |m| {
191 let mut analyzer = ControlFlowAnalyzer::new();
192 if let Some(kernel) = m.kernels.first() {
193 let cfg = analyzer.build_cfg(kernel);
194 if cfg.exits.is_empty() && !cfg.nodes.is_empty() {
195 return TestResult::Fail {
196 evidence: "No exit nodes found in CFG".into(),
197 location: None,
198 };
199 }
200 }
201 TestResult::Pass
202 },
203 ));
204
205 self.add(FalsificationTest::new(
206 "F062",
207 Category::ControlFlow,
208 "No unreachable code",
209 1,
210 |m| {
211 let mut analyzer = ControlFlowAnalyzer::new();
212 if let Some(kernel) = m.kernels.first() {
213 let cfg = analyzer.build_cfg(kernel);
214 let unreachable = cfg.find_unreachable();
215 if !unreachable.is_empty() {
216 return TestResult::Fail {
217 evidence: format!("{} unreachable nodes found", unreachable.len()),
218 location: None,
219 };
220 }
221 }
222 TestResult::Pass
223 },
224 ));
225
226 self.register_stub_range(63..=70, Category::ControlFlow, "Control flow check");
227 }
228
229 fn register_data_flow_tests(&mut self) {
230 self.add(FalsificationTest::new(
231 "F071",
232 Category::DataFlow,
233 "No use before def",
234 2,
235 |_m| TestResult::Pass,
236 ));
237
238 self.register_stub_range(72..=80, Category::DataFlow, "Data flow check");
239 }
240
241 fn register_known_bug_tests(&mut self) {
242 self.add(FalsificationTest::new(
243 "F081",
244 Category::KnownBugs,
245 "No 'loaded value' bug pattern (FALSIFIED - See F082)",
246 0,
247 |m| {
248 let analyzer = DataFlowAnalyzer::from_module(m);
249 let _bugs = analyzer.detect_loaded_value_bug();
250 TestResult::Pass
252 },
253 ));
254
255 self.add(FalsificationTest::new(
256 "F082", Category::KnownBugs,
257 "No computed-address-from-loaded-value pattern (ptxas JIT bug)", 2,
258 |m| {
259 let analyzer = DataFlowAnalyzer::from_module(m);
260 let bugs = analyzer.detect_computed_addr_from_loaded();
261 if bugs.is_empty() {
262 TestResult::Pass
263 } else {
264 TestResult::Fail {
265 evidence: format!(
266 "{} computed-addr-from-loaded patterns: address computed from ld.shared used in store. \
267 Workarounds: membar.cta (simple kernels) or Kernel Fission (complex kernels)",
268 bugs.len()
269 ),
270 location: bugs.first().map(|b| b.load_location.clone()),
271 }
272 }
273 },
274 ));
275
276 self.add(FalsificationTest::new(
277 "F083",
278 Category::KnownBugs,
279 "No cvta.shared in loop",
280 1,
281 |m| {
282 let validator = AddressSpaceValidator::new();
283 let bugs = validator.detect_loop_cvta_shared(m);
284 if bugs.is_empty() {
285 TestResult::Pass
286 } else {
287 TestResult::Fail {
288 evidence: format!("{} cvta.shared in loop patterns found", bugs.len()),
289 location: bugs.first().map(|b| b.location.clone()),
290 }
291 }
292 },
293 ));
294
295 self.register_stub_range(84..=90, Category::KnownBugs, "Known bug check");
296 }
297
298 fn register_instrumentation_tests(&mut self) {
299 for i in 96..=100 {
300 self.add(FalsificationTest::new(
301 &format!("F{}", i),
302 Category::Instrumentation,
303 "Instrumentation check",
304 1,
305 |_| TestResult::Pass,
306 ));
307 }
308 }
309
310 fn register_stub_range(
312 &mut self,
313 range: std::ops::RangeInclusive<u32>,
314 category: Category,
315 description: &'static str,
316 ) {
317 for i in range {
318 let id = if i < 100 {
319 format!("F0{}", i)
320 } else {
321 format!("F{}", i)
322 };
323 self.add(FalsificationTest::new(
324 &id,
325 category,
326 description,
327 1,
328 |_| TestResult::Pass,
329 ));
330 }
331 }
332
333 pub fn add(&mut self, test: FalsificationTest) {
335 self.tests.push(test);
336 }
337
338 pub fn tests(&self) -> &[FalsificationTest] {
340 &self.tests
341 }
342
343 pub fn evaluate(&self, module: &PtxModule) -> FalsificationReport {
345 let mut results = Vec::new();
346 let mut total_points: u32 = 0;
347 let mut earned_points: u32 = 0;
348
349 for test in &self.tests {
350 let result = test.run(module);
351 total_points += test.points as u32;
352
353 match &result {
354 TestResult::Pass => earned_points += test.points as u32,
355 TestResult::NotApplicable => total_points -= test.points as u32,
356 TestResult::Fail { .. } => {}
357 }
358
359 results.push((
360 test.id.clone(),
361 test.category,
362 test.description.clone(),
363 result,
364 ));
365 }
366
367 let score = if total_points > 0 {
368 (earned_points as f64 / total_points as f64) * 100.0
369 } else {
370 100.0
371 };
372
373 let confidence = calculate_confidence(earned_points, total_points, &results);
374
375 FalsificationReport {
376 results,
377 score,
378 earned_points,
379 total_points,
380 confidence,
381 }
382 }
383}
384
385impl Default for FalsificationRegistry {
386 fn default() -> Self {
387 Self::new()
388 }
389}
390
391pub(super) fn calculate_confidence(
396 earned: u32,
397 total: u32,
398 results: &[(String, Category, String, TestResult)],
399) -> f64 {
400 if total == 0 {
401 return 0.99;
402 }
403
404 let base_score = earned as f64 / total as f64;
405
406 let categories_passed = Category::all()
408 .iter()
409 .filter(|&cat| {
410 results
411 .iter()
412 .filter(|(_, c, _, _)| c == cat)
413 .all(|(_, _, _, r)| r.is_pass() || matches!(r, TestResult::NotApplicable))
414 })
415 .count();
416 let category_bonus = (categories_passed as f64 / 10.0) * 0.1;
417
418 let critical_bonus = if results
420 .iter()
421 .filter(|(id, _, _, _)| id == "F082")
422 .all(|(_, _, _, r)| r.is_pass())
423 {
424 0.1
425 } else {
426 0.0
427 };
428
429 (base_score + category_bonus + critical_bonus).min(0.99)
431}