trueno_ptx_debug/analyzer/
data_flow.rs1use crate::bugs::Severity;
4use crate::parser::types::{AddressSpace, Opcode};
5use crate::parser::{Instruction, KernelDef, Operand, PtxModule, SourceLocation, Statement};
6use std::collections::{HashMap, HashSet};
7
8#[derive(Debug, Clone)]
10pub enum ValueSource {
11 Load {
13 space: AddressSpace,
15 location: SourceLocation,
17 },
18 Constant(i64),
20 Computed {
22 inputs: Vec<String>,
24 },
25 Parameter(String),
27 Unknown,
29}
30
31#[derive(Debug, Clone)]
33pub struct UsePoint {
34 pub instruction: Instruction,
36 pub operand_index: usize,
38 pub location: SourceLocation,
40 is_store_data: bool,
42 is_store_addr: bool,
44}
45
46impl UsePoint {
47 pub fn is_store_data_operand(&self) -> bool {
49 self.is_store_data
50 }
51
52 pub fn is_store_address_operand(&self) -> bool {
54 self.is_store_addr
55 }
56}
57
58#[derive(Debug, Clone)]
60pub struct LoadedValueBug {
61 pub load_location: SourceLocation,
63 pub store_location: SourceLocation,
65 pub register: String,
67 pub severity: Severity,
69 pub mitigation: String,
71}
72
73#[derive(Debug, Clone)]
75pub struct ComputedAddrFromLoadedBug {
76 pub load_location: SourceLocation,
78 pub addr_computation_location: SourceLocation,
80 pub tainted_register: String,
82 pub severity: Severity,
84 pub mitigation: String,
86}
87
88pub struct DataFlowAnalyzer {
90 def_use_chains: HashMap<String, Vec<UsePoint>>,
92 value_sources: HashMap<String, ValueSource>,
94}
95
96impl DataFlowAnalyzer {
97 pub fn new() -> Self {
99 Self {
100 def_use_chains: HashMap::new(),
101 value_sources: HashMap::new(),
102 }
103 }
104
105 pub fn from_module(module: &PtxModule) -> Self {
107 let mut analyzer = Self::new();
108 if let Some(kernel) = module.kernels.first() {
109 analyzer.analyze_kernel(kernel);
110 }
111 analyzer
112 }
113
114 pub fn analyze_kernel(&mut self, kernel: &KernelDef) {
116 self.def_use_chains.clear();
117 self.value_sources.clear();
118
119 for stmt in &kernel.body {
120 if let Statement::Instruction(instr) = stmt {
121 self.analyze_instruction(instr);
122 }
123 }
124 }
125
126 fn analyze_instruction(&mut self, instr: &Instruction) {
127 match instr.opcode {
128 Opcode::Ld => {
129 if let Some(Operand::Register(dest)) = instr.operands.first() {
131 let space = self.get_address_space(instr);
132 self.value_sources.insert(
133 dest.clone(),
134 ValueSource::Load {
135 space,
136 location: instr.location.clone(),
137 },
138 );
139 }
140 }
141 Opcode::Mov => {
142 if let (Some(Operand::Register(dest)), Some(src)) =
144 (instr.operands.first(), instr.operands.get(1))
145 {
146 let source = match src {
147 Operand::Register(src_reg) => self
148 .value_sources
149 .get(src_reg)
150 .cloned()
151 .unwrap_or(ValueSource::Unknown),
152 Operand::Immediate(val) => ValueSource::Constant(*val),
153 _ => ValueSource::Unknown,
154 };
155 self.value_sources.insert(dest.clone(), source);
156 }
157 }
158 Opcode::Add
159 | Opcode::Sub
160 | Opcode::Mul
161 | Opcode::And
162 | Opcode::Or
163 | Opcode::Shl
164 | Opcode::Shr => {
165 if let Some(Operand::Register(dest)) = instr.operands.first() {
167 let inputs: Vec<String> = instr
168 .operands
169 .iter()
170 .skip(1)
171 .filter_map(|op| {
172 if let Operand::Register(reg) = op {
173 Some(reg.clone())
174 } else {
175 None
176 }
177 })
178 .collect();
179
180 self.value_sources
181 .insert(dest.clone(), ValueSource::Computed { inputs });
182 }
183 }
184 Opcode::St => {
185 if let Some(Operand::Memory(addr_str)) = instr.operands.first() {
190 let addr_reg = self.extract_register_from_memory(addr_str);
192 if let Some(reg) = addr_reg {
193 self.def_use_chains
194 .entry(reg.clone())
195 .or_default()
196 .push(UsePoint {
197 instruction: instr.clone(),
198 operand_index: 0,
199 location: instr.location.clone(),
200 is_store_data: false,
201 is_store_addr: true,
202 });
203 }
204 }
205
206 if let Some(Operand::Register(val_reg)) = instr.operands.get(1) {
207 self.def_use_chains
208 .entry(val_reg.clone())
209 .or_default()
210 .push(UsePoint {
211 instruction: instr.clone(),
212 operand_index: 1,
213 location: instr.location.clone(),
214 is_store_data: true,
215 is_store_addr: false,
216 });
217 }
218 }
219 _ => {}
220 }
221 }
222
223 fn get_address_space(&self, instr: &Instruction) -> AddressSpace {
224 for modifier in &instr.modifiers {
225 if let Some(space) = modifier.as_address_space() {
226 return space;
227 }
228 }
229 AddressSpace::Generic
230 }
231
232 fn extract_register_from_memory(&self, addr_str: &str) -> Option<String> {
233 let trimmed = addr_str.trim_matches(|c| c == '[' || c == ']');
235 if let Some(plus_pos) = trimmed.find('+') {
236 Some(trimmed[..plus_pos].trim().to_string())
237 } else {
238 Some(trimmed.trim().to_string())
239 }
240 }
241
242 pub fn detect_loaded_value_bug(&self) -> Vec<LoadedValueBug> {
246 let mut bugs = Vec::new();
247
248 for (reg, source) in &self.value_sources {
249 if let ValueSource::Load {
250 space: AddressSpace::Shared,
251 location,
252 } = source
253 {
254 for use_point in self.def_use_chains.get(reg).unwrap_or(&vec![]) {
256 if use_point.is_store_data_operand() {
257 bugs.push(LoadedValueBug {
258 load_location: location.clone(),
259 store_location: use_point.location.clone(),
260 register: reg.clone(),
261 severity: Severity::Low,
262 mitigation: "Hypothesis F081 falsified on sm_89. This pattern is safe."
263 .into(),
264 });
265 }
266 }
267 }
268 }
269
270 bugs
271 }
272
273 pub fn detect_computed_addr_from_loaded(&self) -> Vec<ComputedAddrFromLoadedBug> {
278 let mut bugs = Vec::new();
279
280 let mut shared_loaded_regs: HashSet<String> = HashSet::new();
282 for (reg, source) in &self.value_sources {
283 if matches!(
284 source,
285 ValueSource::Load {
286 space: AddressSpace::Shared,
287 ..
288 }
289 ) {
290 shared_loaded_regs.insert(reg.clone());
291 }
292 }
293
294 let mut tainted_regs: HashSet<String> = shared_loaded_regs.clone();
296 let mut changed = true;
297 while changed {
298 changed = false;
299 for (reg, source) in &self.value_sources {
300 if let ValueSource::Computed { inputs } = source {
301 if !tainted_regs.contains(reg)
302 && inputs.iter().any(|i| tainted_regs.contains(i))
303 {
304 tainted_regs.insert(reg.clone());
305 changed = true;
306 }
307 }
308 }
309 }
310
311 for (reg, _source) in &self.value_sources {
313 if tainted_regs.contains(reg) {
314 for use_point in self.def_use_chains.get(reg).unwrap_or(&vec![]) {
315 if use_point.is_store_address_operand() {
316 let load_loc = self.find_original_load_location(reg, &shared_loaded_regs);
317 bugs.push(ComputedAddrFromLoadedBug {
318 load_location: load_loc.unwrap_or_default(),
319 addr_computation_location: use_point.location.clone(),
320 tainted_register: reg.clone(),
321 severity: Severity::Critical,
322 mitigation: "Use constant-only address computation, try membar.cta (partial), or use Kernel Fission (split kernel)".into(),
323 });
324 }
325 }
326 }
327 }
328
329 bugs
330 }
331
332 fn find_original_load_location(
333 &self,
334 reg: &str,
335 shared_loaded_regs: &HashSet<String>,
336 ) -> Option<SourceLocation> {
337 if let Some(ValueSource::Load { location, .. }) = self.value_sources.get(reg) {
339 return Some(location.clone());
340 }
341
342 if let Some(ValueSource::Computed { inputs }) = self.value_sources.get(reg) {
344 for input in inputs {
345 if shared_loaded_regs.contains(input) {
346 return self.find_original_load_location(input, shared_loaded_regs);
347 }
348 }
349 }
350
351 None
352 }
353}
354
355impl Default for DataFlowAnalyzer {
356 fn default() -> Self {
357 Self::new()
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses PTX text into a module, failing the test on any parser error.
    fn parse_ptx(ptx: &str) -> PtxModule {
        let mut parser = Parser::new(ptx).expect("parser creation should succeed");
        parser.parse().expect("parsing should succeed")
    }

    /// F081: storing a constant (not a shared-loaded value) yields no findings.
    #[test]
    fn f081_no_loaded_value_bug() {
        let module = parse_ptx(
            r#"
        .version 8.0
        .target sm_70
        .address_size 64

        .entry test()
        {
            .reg .u32 %r<10>;
            mov.u32 %r0, 0;
            st.shared.u32 [%r1], %r0;
            ret;
        }
        "#,
        );

        let findings = DataFlowAnalyzer::from_module(&module).detect_loaded_value_bug();
        assert!(
            findings.is_empty(),
            "F081: Should have no loaded value bugs when using constant"
        );
    }

    /// F082: an address computed only from non-loaded registers yields no findings.
    #[test]
    fn f082_no_computed_addr_from_loaded_bug() {
        let module = parse_ptx(
            r#"
        .version 8.0
        .target sm_70
        .address_size 64

        .entry test()
        {
            .reg .u32 %r<10>;
            mov.u32 %r0, 100;
            add.u32 %r1, %r2, %r0;
            mov.u32 %r3, 0xCAFE;
            st.shared.u32 [%r1], %r3;
            ret;
        }
        "#,
        );

        let findings = DataFlowAnalyzer::from_module(&module).detect_computed_addr_from_loaded();
        assert!(
            findings.is_empty(),
            "F082: Should have no computed-addr bugs when using constant"
        );
    }

    /// F071: a kernel that defines before use parses and can be analyzed.
    #[test]
    fn f071_no_use_before_def() {
        let module = parse_ptx(
            r#"
        .version 8.0
        .target sm_70
        .address_size 64

        .entry test()
        {
            .reg .u32 %r<10>;
            mov.u32 %r0, 0;
            add.u32 %r1, %r0, 1;
            ret;
        }
        "#,
        );

        let _analyzer = DataFlowAnalyzer::from_module(&module);
        assert!(!module.kernels.is_empty());
    }
}