Skip to main content

trueno_ptx_debug/analyzer/
data_flow.rs

1//! Data Flow Analyzer - value propagation and loaded value bug detection
2
3use crate::bugs::Severity;
4use crate::parser::types::{AddressSpace, Opcode};
5use crate::parser::{Instruction, KernelDef, Operand, PtxModule, SourceLocation, Statement};
6use std::collections::{HashMap, HashSet};
7
8/// Source of a value
9#[derive(Debug, Clone)]
10pub enum ValueSource {
11    /// Value came from a load instruction
12    Load {
13        /// Address space
14        space: AddressSpace,
15        /// Source location
16        location: SourceLocation,
17    },
18    /// Value came from a constant/immediate
19    Constant(i64),
20    /// Value came from computation
21    Computed {
22        /// Input registers
23        inputs: Vec<String>,
24    },
25    /// Value came from parameter
26    Parameter(String),
27    /// Unknown source
28    Unknown,
29}
30
31/// Use point of a value
32#[derive(Debug, Clone)]
33pub struct UsePoint {
34    /// Instruction where the value is used
35    pub instruction: Instruction,
36    /// Operand index (0-based)
37    pub operand_index: usize,
38    /// Source location
39    pub location: SourceLocation,
40    /// Is this the data operand of a store?
41    is_store_data: bool,
42    /// Is this the address operand of a store?
43    is_store_addr: bool,
44}
45
46impl UsePoint {
47    /// Is this use point a store data operand?
48    pub fn is_store_data_operand(&self) -> bool {
49        self.is_store_data
50    }
51
52    /// Is this use point a store address operand?
53    pub fn is_store_address_operand(&self) -> bool {
54        self.is_store_addr
55    }
56}
57
58/// Bug: Store using value derived from ld.shared
59#[derive(Debug, Clone)]
60pub struct LoadedValueBug {
61    /// Load location
62    pub load_location: SourceLocation,
63    /// Store location
64    pub store_location: SourceLocation,
65    /// Register containing loaded value
66    pub register: String,
67    /// Severity
68    pub severity: Severity,
69    /// Mitigation advice
70    pub mitigation: String,
71}
72
73/// Bug: Address computed from ld.shared value causes store crash
74#[derive(Debug, Clone)]
75pub struct ComputedAddrFromLoadedBug {
76    /// Load location
77    pub load_location: SourceLocation,
78    /// Address computation location
79    pub addr_computation_location: SourceLocation,
80    /// Tainted register
81    pub tainted_register: String,
82    /// Severity
83    pub severity: Severity,
84    /// Mitigation advice
85    pub mitigation: String,
86}
87
88/// Data Flow Analyzer
89pub struct DataFlowAnalyzer {
90    /// Def-use chains: register -> use points
91    def_use_chains: HashMap<String, Vec<UsePoint>>,
92    /// Value sources: register -> source
93    value_sources: HashMap<String, ValueSource>,
94}
95
96impl DataFlowAnalyzer {
97    /// Create a new data flow analyzer
98    pub fn new() -> Self {
99        Self {
100            def_use_chains: HashMap::new(),
101            value_sources: HashMap::new(),
102        }
103    }
104
105    /// Create from a PTX module (analyzes first kernel)
106    pub fn from_module(module: &PtxModule) -> Self {
107        let mut analyzer = Self::new();
108        if let Some(kernel) = module.kernels.first() {
109            analyzer.analyze_kernel(kernel);
110        }
111        analyzer
112    }
113
114    /// Analyze a kernel for data flow
115    pub fn analyze_kernel(&mut self, kernel: &KernelDef) {
116        self.def_use_chains.clear();
117        self.value_sources.clear();
118
119        for stmt in &kernel.body {
120            if let Statement::Instruction(instr) = stmt {
121                self.analyze_instruction(instr);
122            }
123        }
124    }
125
126    fn analyze_instruction(&mut self, instr: &Instruction) {
127        match instr.opcode {
128            Opcode::Ld => {
129                // Load defines a register with value from memory
130                if let Some(Operand::Register(dest)) = instr.operands.first() {
131                    let space = self.get_address_space(instr);
132                    self.value_sources.insert(
133                        dest.clone(),
134                        ValueSource::Load {
135                            space,
136                            location: instr.location.clone(),
137                        },
138                    );
139                }
140            }
141            Opcode::Mov => {
142                // Move copies value source
143                if let (Some(Operand::Register(dest)), Some(src)) =
144                    (instr.operands.first(), instr.operands.get(1))
145                {
146                    let source = match src {
147                        Operand::Register(src_reg) => self
148                            .value_sources
149                            .get(src_reg)
150                            .cloned()
151                            .unwrap_or(ValueSource::Unknown),
152                        Operand::Immediate(val) => ValueSource::Constant(*val),
153                        _ => ValueSource::Unknown,
154                    };
155                    self.value_sources.insert(dest.clone(), source);
156                }
157            }
158            Opcode::Add
159            | Opcode::Sub
160            | Opcode::Mul
161            | Opcode::And
162            | Opcode::Or
163            | Opcode::Shl
164            | Opcode::Shr => {
165                // Computation defines register with computed value
166                if let Some(Operand::Register(dest)) = instr.operands.first() {
167                    let inputs: Vec<String> = instr
168                        .operands
169                        .iter()
170                        .skip(1)
171                        .filter_map(|op| {
172                            if let Operand::Register(reg) = op {
173                                Some(reg.clone())
174                            } else {
175                                None
176                            }
177                        })
178                        .collect();
179
180                    self.value_sources
181                        .insert(dest.clone(), ValueSource::Computed { inputs });
182                }
183            }
184            Opcode::St => {
185                // Store uses values - track the use points
186                // For st.<type> [addr], value:
187                // - operand 0 is the memory address
188                // - operand 1 is the value to store
189                if let Some(Operand::Memory(addr_str)) = instr.operands.first() {
190                    // Extract register from memory operand like [%r0] or [%r0+offset]
191                    let addr_reg = self.extract_register_from_memory(addr_str);
192                    if let Some(reg) = addr_reg {
193                        self.def_use_chains
194                            .entry(reg.clone())
195                            .or_default()
196                            .push(UsePoint {
197                                instruction: instr.clone(),
198                                operand_index: 0,
199                                location: instr.location.clone(),
200                                is_store_data: false,
201                                is_store_addr: true,
202                            });
203                    }
204                }
205
206                if let Some(Operand::Register(val_reg)) = instr.operands.get(1) {
207                    self.def_use_chains
208                        .entry(val_reg.clone())
209                        .or_default()
210                        .push(UsePoint {
211                            instruction: instr.clone(),
212                            operand_index: 1,
213                            location: instr.location.clone(),
214                            is_store_data: true,
215                            is_store_addr: false,
216                        });
217                }
218            }
219            _ => {}
220        }
221    }
222
223    fn get_address_space(&self, instr: &Instruction) -> AddressSpace {
224        for modifier in &instr.modifiers {
225            if let Some(space) = modifier.as_address_space() {
226                return space;
227            }
228        }
229        AddressSpace::Generic
230    }
231
232    fn extract_register_from_memory(&self, addr_str: &str) -> Option<String> {
233        // Extract register name from patterns like [%r0], [%r0+4], etc.
234        let trimmed = addr_str.trim_matches(|c| c == '[' || c == ']');
235        if let Some(plus_pos) = trimmed.find('+') {
236            Some(trimmed[..plus_pos].trim().to_string())
237        } else {
238            Some(trimmed.trim().to_string())
239        }
240    }
241
242    /// Detect the "loaded value" bug pattern (F081)
243    ///
244    /// Pattern: ld.shared -> computation -> st.XXX crashes
245    pub fn detect_loaded_value_bug(&self) -> Vec<LoadedValueBug> {
246        let mut bugs = Vec::new();
247
248        for (reg, source) in &self.value_sources {
249            if let ValueSource::Load {
250                space: AddressSpace::Shared,
251                location,
252            } = source
253            {
254                // Find all stores that use this register as data operand
255                for use_point in self.def_use_chains.get(reg).unwrap_or(&vec![]) {
256                    if use_point.is_store_data_operand() {
257                        bugs.push(LoadedValueBug {
258                            load_location: location.clone(),
259                            store_location: use_point.location.clone(),
260                            register: reg.clone(),
261                            severity: Severity::Low,
262                            mitigation: "Hypothesis F081 falsified on sm_89. This pattern is safe."
263                                .into(),
264                        });
265                    }
266                }
267            }
268        }
269
270        bugs
271    }
272
273    /// Detect "computed address from loaded value" bug pattern (F082)
274    ///
275    /// Pattern: ld.shared %r_val -> add %r_addr, base, %r_val -> st.XXX [%r_addr]
276    /// Even storing a CONSTANT to an address computed from a loaded value crashes.
277    pub fn detect_computed_addr_from_loaded(&self) -> Vec<ComputedAddrFromLoadedBug> {
278        let mut bugs = Vec::new();
279
280        // Track which registers come from ld.shared
281        let mut shared_loaded_regs: HashSet<String> = HashSet::new();
282        for (reg, source) in &self.value_sources {
283            if matches!(
284                source,
285                ValueSource::Load {
286                    space: AddressSpace::Shared,
287                    ..
288                }
289            ) {
290                shared_loaded_regs.insert(reg.clone());
291            }
292        }
293
294        // Track registers computed from shared-loaded registers (taint propagation)
295        let mut tainted_regs: HashSet<String> = shared_loaded_regs.clone();
296        let mut changed = true;
297        while changed {
298            changed = false;
299            for (reg, source) in &self.value_sources {
300                if let ValueSource::Computed { inputs } = source {
301                    if !tainted_regs.contains(reg)
302                        && inputs.iter().any(|i| tainted_regs.contains(i))
303                    {
304                        tainted_regs.insert(reg.clone());
305                        changed = true;
306                    }
307                }
308            }
309        }
310
311        // Find stores where ADDRESS is computed from tainted register
312        for (reg, _source) in &self.value_sources {
313            if tainted_regs.contains(reg) {
314                for use_point in self.def_use_chains.get(reg).unwrap_or(&vec![]) {
315                    if use_point.is_store_address_operand() {
316                        let load_loc = self.find_original_load_location(reg, &shared_loaded_regs);
317                        bugs.push(ComputedAddrFromLoadedBug {
318                            load_location: load_loc.unwrap_or_default(),
319                            addr_computation_location: use_point.location.clone(),
320                            tainted_register: reg.clone(),
321                            severity: Severity::Critical,
322                            mitigation: "Use constant-only address computation, try membar.cta (partial), or use Kernel Fission (split kernel)".into(),
323                        });
324                    }
325                }
326            }
327        }
328
329        bugs
330    }
331
332    fn find_original_load_location(
333        &self,
334        reg: &str,
335        shared_loaded_regs: &HashSet<String>,
336    ) -> Option<SourceLocation> {
337        // If this register is directly from a load, return that location
338        if let Some(ValueSource::Load { location, .. }) = self.value_sources.get(reg) {
339            return Some(location.clone());
340        }
341
342        // Otherwise, trace back through computations
343        if let Some(ValueSource::Computed { inputs }) = self.value_sources.get(reg) {
344            for input in inputs {
345                if shared_loaded_regs.contains(input) {
346                    return self.find_original_load_location(input, shared_loaded_regs);
347                }
348            }
349        }
350
351        None
352    }
353}
354
355impl Default for DataFlowAnalyzer {
356    fn default() -> Self {
357        Self::new()
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Kernel that stores a constant; no shared-memory load is involved.
    const F081_PTX: &str = r#"
            .version 8.0
            .target sm_70
            .address_size 64

            .entry test()
            {
                .reg .u32 %r<10>;
                mov.u32 %r0, 0;
                st.shared.u32 [%r1], %r0;
                ret;
            }
        "#;

    /// Kernel whose store address is computed from a constant only.
    const F082_PTX: &str = r#"
            .version 8.0
            .target sm_70
            .address_size 64

            .entry test()
            {
                .reg .u32 %r<10>;
                mov.u32 %r0, 100;
                add.u32 %r1, %r2, %r0;
                mov.u32 %r3, 0xCAFE;
                st.shared.u32 [%r1], %r3;
                ret;
            }
        "#;

    /// Kernel that defines a register before using it.
    const F071_PTX: &str = r#"
            .version 8.0
            .target sm_70
            .address_size 64

            .entry test()
            {
                .reg .u32 %r<10>;
                mov.u32 %r0, 0;
                add.u32 %r1, %r0, 1;
                ret;
            }
        "#;

    // F081: storing a constant must not be reported as a loaded-value store.
    #[test]
    fn f081_no_loaded_value_bug() {
        let mut ptx_parser = Parser::new(F081_PTX).expect("parser creation should succeed");
        let parsed = ptx_parser.parse().expect("parsing should succeed");

        let findings = DataFlowAnalyzer::from_module(&parsed).detect_loaded_value_bug();

        assert!(
            findings.is_empty(),
            "F081: Should have no loaded value bugs when using constant"
        );
    }

    // F082: an address built from constants must not be reported as tainted.
    #[test]
    fn f082_no_computed_addr_from_loaded_bug() {
        let mut ptx_parser = Parser::new(F082_PTX).expect("parser creation should succeed");
        let parsed = ptx_parser.parse().expect("parsing should succeed");

        let findings = DataFlowAnalyzer::from_module(&parsed).detect_computed_addr_from_loaded();

        assert!(
            findings.is_empty(),
            "F082: Should have no computed-addr bugs when using constant"
        );
    }

    // F071: the analyzer runs over a well-formed kernel without failing.
    #[test]
    fn f071_no_use_before_def() {
        let mut ptx_parser = Parser::new(F071_PTX).expect("parser creation should succeed");
        let parsed = ptx_parser.parse().expect("parsing should succeed");

        // Register names depend on parser output, so this only checks that
        // building the analyzer succeeds and that a kernel was parsed.
        let _analyzer = DataFlowAnalyzer::from_module(&parsed);
        assert!(!parsed.kernels.is_empty());
    }
}