Skip to main content

trueno_ptx_debug/falsification/framework/
registry.rs

1//! Falsification test registry and confidence calculation
2
3use super::types::{Category, FalsificationReport, FalsificationTest, TestResult};
4use crate::analyzer::{AddressSpaceValidator, ControlFlowAnalyzer, DataFlowAnalyzer, TypeChecker};
5use crate::parser::types::SmTarget;
6use crate::parser::PtxModule;
7
/// Falsification test registry.
///
/// Holds an ordered list of [`FalsificationTest`]s. [`FalsificationRegistry::new`]
/// populates it with the built-in suite (tests numbered 1 through 100), and
/// [`FalsificationRegistry::evaluate`] runs the tests in registration order
/// against a parsed [`PtxModule`].
pub struct FalsificationRegistry {
    /// Registered tests, in the order they were added.
    tests: Vec<FalsificationTest>,
}
12
13impl FalsificationRegistry {
14    /// Create a new registry with all tests
15    pub fn new() -> Self {
16        let mut registry = Self { tests: Vec::new() };
17        registry.register_all_tests();
18        registry
19    }
20
21    fn register_all_tests(&mut self) {
22        self.register_syntax_tests();
23        self.register_type_safety_tests();
24        self.register_address_space_tests();
25        self.register_barrier_tests();
26        self.register_stub_range(51..=60, Category::MemoryModel, "Memory model check");
27        self.register_control_flow_tests();
28        self.register_data_flow_tests();
29        self.register_known_bug_tests();
30        self.register_stub_range(91..=95, Category::Performance, "Performance check");
31        self.register_instrumentation_tests();
32    }
33
34    fn register_syntax_tests(&mut self) {
35        self.add(FalsificationTest::new(
36            "F001",
37            Category::SyntaxValidity,
38            "PTX contains .version directive",
39            1,
40            |m| {
41                if m.version.0 > 0 {
42                    TestResult::Pass
43                } else {
44                    TestResult::Fail {
45                        evidence: "Missing .version directive".into(),
46                        location: None,
47                    }
48                }
49            },
50        ));
51
52        self.add(FalsificationTest::new(
53            "F002",
54            Category::SyntaxValidity,
55            "PTX contains .target directive",
56            1,
57            |m| {
58                if m.target != SmTarget::Unknown {
59                    TestResult::Pass
60                } else {
61                    TestResult::Fail {
62                        evidence: "Missing .target directive".into(),
63                        location: None,
64                    }
65                }
66            },
67        ));
68
69        self.add(FalsificationTest::new(
70            "F003",
71            Category::SyntaxValidity,
72            "address_size is 32 or 64",
73            1,
74            |m| {
75                if m.address_size == 32 || m.address_size == 64 {
76                    TestResult::Pass
77                } else {
78                    TestResult::Fail {
79                        evidence: format!("Invalid address_size: {}", m.address_size),
80                        location: None,
81                    }
82                }
83            },
84        ));
85
86        self.add(FalsificationTest::new(
87            "F004",
88            Category::SyntaxValidity,
89            "All labels are unique",
90            1,
91            |m| {
92                let mut labels = std::collections::HashSet::new();
93                for kernel in &m.kernels {
94                    for stmt in &kernel.body {
95                        if let crate::parser::Statement::Label(label) = stmt {
96                            if !labels.insert(label.clone()) {
97                                return TestResult::Fail {
98                                    evidence: format!("Duplicate label: {}", label),
99                                    location: None,
100                                };
101                            }
102                        }
103                    }
104                }
105                TestResult::Pass
106            },
107        ));
108
109        self.register_stub_range(5..=10, Category::SyntaxValidity, "Syntax validity check");
110    }
111
112    fn register_type_safety_tests(&mut self) {
113        self.add(FalsificationTest::new(
114            "F011",
115            Category::TypeSafety,
116            "Load dest type matches instruction type",
117            1,
118            |m| {
119                let mut checker = TypeChecker::new();
120                let errors = checker.analyze(m);
121                if errors.is_empty() {
122                    TestResult::Pass
123                } else {
124                    TestResult::Fail {
125                        evidence: format!("{} type errors found", errors.len()),
126                        location: errors.first().map(|e| e.location.clone()),
127                    }
128                }
129            },
130        ));
131
132        self.register_stub_range(12..=20, Category::TypeSafety, "Type safety check");
133    }
134
135    fn register_address_space_tests(&mut self) {
136        self.add(FalsificationTest::new(
137            "F021",
138            Category::AddressSpace,
139            "No cvta.shared followed by generic ld/st",
140            2,
141            |m| {
142                let mut validator = AddressSpaceValidator::new();
143                let bugs = validator.detect_generic_shared_access(m);
144                if bugs.is_empty() {
145                    TestResult::Pass
146                } else {
147                    TestResult::Fail {
148                        evidence: format!("{} generic shared access patterns found", bugs.len()),
149                        location: bugs.first().map(|b| b.location.clone()),
150                    }
151                }
152            },
153        ));
154
155        self.register_stub_range(22..=35, Category::AddressSpace, "Address space check");
156    }
157
158    fn register_barrier_tests(&mut self) {
159        self.add(FalsificationTest::new(
160            "F036",
161            Category::BarrierSafety,
162            "bar.sync after shared write, before read",
163            3,
164            |m| {
165                let mut analyzer = ControlFlowAnalyzer::new();
166                if let Some(kernel) = m.kernels.first() {
167                    let _ = analyzer.build_cfg(kernel);
168                }
169                let violations = analyzer.analyze_barriers(m);
170                if violations.is_empty() {
171                    TestResult::Pass
172                } else {
173                    TestResult::Fail {
174                        evidence: format!("{} barrier violations found", violations.len()),
175                        location: violations.first().map(|v| v.write_loc.clone()),
176                    }
177                }
178            },
179        ));
180
181        self.register_stub_range(37..=50, Category::BarrierSafety, "Barrier safety check");
182    }
183
184    fn register_control_flow_tests(&mut self) {
185        self.add(FalsificationTest::new(
186            "F061",
187            Category::ControlFlow,
188            "All code paths reach ret or exit",
189            2,
190            |m| {
191                let mut analyzer = ControlFlowAnalyzer::new();
192                if let Some(kernel) = m.kernels.first() {
193                    let cfg = analyzer.build_cfg(kernel);
194                    if cfg.exits.is_empty() && !cfg.nodes.is_empty() {
195                        return TestResult::Fail {
196                            evidence: "No exit nodes found in CFG".into(),
197                            location: None,
198                        };
199                    }
200                }
201                TestResult::Pass
202            },
203        ));
204
205        self.add(FalsificationTest::new(
206            "F062",
207            Category::ControlFlow,
208            "No unreachable code",
209            1,
210            |m| {
211                let mut analyzer = ControlFlowAnalyzer::new();
212                if let Some(kernel) = m.kernels.first() {
213                    let cfg = analyzer.build_cfg(kernel);
214                    let unreachable = cfg.find_unreachable();
215                    if !unreachable.is_empty() {
216                        return TestResult::Fail {
217                            evidence: format!("{} unreachable nodes found", unreachable.len()),
218                            location: None,
219                        };
220                    }
221                }
222                TestResult::Pass
223            },
224        ));
225
226        self.register_stub_range(63..=70, Category::ControlFlow, "Control flow check");
227    }
228
229    fn register_data_flow_tests(&mut self) {
230        self.add(FalsificationTest::new(
231            "F071",
232            Category::DataFlow,
233            "No use before def",
234            2,
235            |_m| TestResult::Pass,
236        ));
237
238        self.register_stub_range(72..=80, Category::DataFlow, "Data flow check");
239    }
240
241    fn register_known_bug_tests(&mut self) {
242        self.add(FalsificationTest::new(
243            "F081",
244            Category::KnownBugs,
245            "No 'loaded value' bug pattern (FALSIFIED - See F082)",
246            0,
247            |m| {
248                let analyzer = DataFlowAnalyzer::from_module(m);
249                let _bugs = analyzer.detect_loaded_value_bug();
250                // Pattern detected but harmless on sm_89
251                TestResult::Pass
252            },
253        ));
254
255        self.add(FalsificationTest::new(
256            "F082", Category::KnownBugs,
257            "No computed-address-from-loaded-value pattern (ptxas JIT bug)", 2,
258            |m| {
259                let analyzer = DataFlowAnalyzer::from_module(m);
260                let bugs = analyzer.detect_computed_addr_from_loaded();
261                if bugs.is_empty() {
262                    TestResult::Pass
263                } else {
264                    TestResult::Fail {
265                        evidence: format!(
266                            "{} computed-addr-from-loaded patterns: address computed from ld.shared used in store. \
267                            Workarounds: membar.cta (simple kernels) or Kernel Fission (complex kernels)",
268                            bugs.len()
269                        ),
270                        location: bugs.first().map(|b| b.load_location.clone()),
271                    }
272                }
273            },
274        ));
275
276        self.add(FalsificationTest::new(
277            "F083",
278            Category::KnownBugs,
279            "No cvta.shared in loop",
280            1,
281            |m| {
282                let validator = AddressSpaceValidator::new();
283                let bugs = validator.detect_loop_cvta_shared(m);
284                if bugs.is_empty() {
285                    TestResult::Pass
286                } else {
287                    TestResult::Fail {
288                        evidence: format!("{} cvta.shared in loop patterns found", bugs.len()),
289                        location: bugs.first().map(|b| b.location.clone()),
290                    }
291                }
292            },
293        ));
294
295        self.register_stub_range(84..=90, Category::KnownBugs, "Known bug check");
296    }
297
298    fn register_instrumentation_tests(&mut self) {
299        for i in 96..=100 {
300            self.add(FalsificationTest::new(
301                &format!("F{}", i),
302                Category::Instrumentation,
303                "Instrumentation check",
304                1,
305                |_| TestResult::Pass,
306            ));
307        }
308    }
309
310    /// Register a range of stub tests with the same category and description.
311    fn register_stub_range(
312        &mut self,
313        range: std::ops::RangeInclusive<u32>,
314        category: Category,
315        description: &'static str,
316    ) {
317        for i in range {
318            let id = if i < 100 {
319                format!("F0{}", i)
320            } else {
321                format!("F{}", i)
322            };
323            self.add(FalsificationTest::new(
324                &id,
325                category,
326                description,
327                1,
328                |_| TestResult::Pass,
329            ));
330        }
331    }
332
333    /// Add a test to the registry
334    pub fn add(&mut self, test: FalsificationTest) {
335        self.tests.push(test);
336    }
337
338    /// Get all tests
339    pub fn tests(&self) -> &[FalsificationTest] {
340        &self.tests
341    }
342
343    /// Run all falsification tests
344    pub fn evaluate(&self, module: &PtxModule) -> FalsificationReport {
345        let mut results = Vec::new();
346        let mut total_points: u32 = 0;
347        let mut earned_points: u32 = 0;
348
349        for test in &self.tests {
350            let result = test.run(module);
351            total_points += test.points as u32;
352
353            match &result {
354                TestResult::Pass => earned_points += test.points as u32,
355                TestResult::NotApplicable => total_points -= test.points as u32,
356                TestResult::Fail { .. } => {}
357            }
358
359            results.push((
360                test.id.clone(),
361                test.category,
362                test.description.clone(),
363                result,
364            ));
365        }
366
367        let score = if total_points > 0 {
368            (earned_points as f64 / total_points as f64) * 100.0
369        } else {
370            100.0
371        };
372
373        let confidence = calculate_confidence(earned_points, total_points, &results);
374
375        FalsificationReport {
376            results,
377            score,
378            earned_points,
379            total_points,
380            confidence,
381        }
382    }
383}
384
385impl Default for FalsificationRegistry {
386    fn default() -> Self {
387        Self::new()
388    }
389}
390
391/// Calculate confidence based on falsification survival
392///
393/// Based on Popper's degree of corroboration - more severe tests
394/// survived = higher confidence
395pub(super) fn calculate_confidence(
396    earned: u32,
397    total: u32,
398    results: &[(String, Category, String, TestResult)],
399) -> f64 {
400    if total == 0 {
401        return 0.99;
402    }
403
404    let base_score = earned as f64 / total as f64;
405
406    // Category coverage bonus
407    let categories_passed = Category::all()
408        .iter()
409        .filter(|&cat| {
410            results
411                .iter()
412                .filter(|(_, c, _, _)| c == cat)
413                .all(|(_, _, _, r)| r.is_pass() || matches!(r, TestResult::NotApplicable))
414        })
415        .count();
416    let category_bonus = (categories_passed as f64 / 10.0) * 0.1;
417
418    // Critical correctness absence bonus (F082 only)
419    let critical_bonus = if results
420        .iter()
421        .filter(|(id, _, _, _)| id == "F082")
422        .all(|(_, _, _, r)| r.is_pass())
423    {
424        0.1
425    } else {
426        0.0
427    };
428
429    // Combined confidence (capped at 0.99 - never certain)
430    (base_score + category_bonus + critical_bonus).min(0.99)
431}