Skip to main content

trueno_ptx_debug/analyzer/
address_space.rs

1//! Address Space Validator - validates correct address space usage
2
3use crate::bugs::Severity;
4use crate::parser::types::{Modifier, Opcode};
5use crate::parser::{Instruction, Operand, PtxModule, SourceLocation, Statement};
6use std::collections::HashSet;
7
8/// Bug: Generic addressing of shared memory
9#[derive(Debug, Clone)]
10pub struct GenericSharedBug {
11    /// Source location
12    pub location: SourceLocation,
13    /// Instruction that triggered the bug
14    pub instruction: Instruction,
15    /// Severity
16    pub severity: Severity,
17    /// Fix suggestion
18    pub fix: String,
19}
20
/// Address Space Validator
///
/// Scans parsed PTX kernels for misuse of the `.shared` state space. The only
/// state carried between instructions is the set of registers written by a
/// `cvta` that carries a `.shared` modifier.
pub struct AddressSpaceValidator {
    /// Destination registers of `cvta` instructions with a `.shared` modifier,
    /// i.e. generic-space pointers into shared memory for the kernel currently
    /// being scanned.
    shared_base_regs: HashSet<String>,
}
26
27impl AddressSpaceValidator {
28    /// Create a new address space validator
29    pub fn new() -> Self {
30        Self {
31            shared_base_regs: HashSet::new(),
32        }
33    }
34
35    /// Detect generic addressing of shared memory (F021)
36    ///
37    /// WRONG: cvta.shared.u64 %rd, smem; ld.u32 [%rd]
38    /// RIGHT: ld.shared.u32 [smem_offset]
39    pub fn detect_generic_shared_access(&mut self, module: &PtxModule) -> Vec<GenericSharedBug> {
40        let mut bugs = Vec::new();
41
42        for kernel in &module.kernels {
43            self.shared_base_regs.clear();
44
45            for stmt in &kernel.body {
46                if let Statement::Instruction(instr) = stmt {
47                    // Track cvta.shared destinations
48                    if instr.opcode == Opcode::Cvta && self.has_shared_modifier(instr) {
49                        if let Some(Operand::Register(dest)) = instr.operands.first() {
50                            self.shared_base_regs.insert(dest.clone());
51                        }
52                    }
53
54                    // Detect generic ld/st using tracked registers
55                    if (instr.opcode == Opcode::Ld || instr.opcode == Opcode::St)
56                        && !self.has_space_modifier(instr)
57                    {
58                        // Check if address operand uses a generic shared register
59                        let addr_operand = if instr.opcode == Opcode::Ld {
60                            instr.operands.get(1)
61                        } else {
62                            instr.operands.first()
63                        };
64
65                        if let Some(operand) = addr_operand {
66                            if self.uses_generic_shared_reg(operand) {
67                                bugs.push(GenericSharedBug {
68                                    location: instr.location.clone(),
69                                    instruction: instr.clone(),
70                                    severity: Severity::Critical,
71                                    fix: "Use ld.shared with 32-bit offset instead".into(),
72                                });
73                            }
74                        }
75                    }
76                }
77            }
78        }
79
80        bugs
81    }
82
83    /// Detect shared memory using 64-bit addresses (F022)
84    pub fn detect_shared_mem_u64(&self, module: &PtxModule) -> Vec<GenericSharedBug> {
85        let mut bugs = Vec::new();
86
87        for kernel in &module.kernels {
88            for stmt in &kernel.body {
89                if let Statement::Instruction(instr) = stmt {
90                    // Check for ld.shared.u64 or st.shared.u64 with 64-bit address
91                    if self.has_shared_modifier(instr) && self.has_u64_modifier(instr) {
92                        bugs.push(GenericSharedBug {
93                            location: instr.location.clone(),
94                            instruction: instr.clone(),
95                            severity: Severity::High,
96                            fix: "Use 32-bit offset for shared memory addressing".into(),
97                        });
98                    }
99                }
100            }
101        }
102
103        bugs
104    }
105
106    /// Detect cvta.shared inside loops (F083)
107    pub fn detect_loop_cvta_shared(&self, module: &PtxModule) -> Vec<GenericSharedBug> {
108        let mut bugs = Vec::new();
109
110        for kernel in &module.kernels {
111            let mut in_loop = false;
112            let mut loop_start_labels = HashSet::new();
113
114            for stmt in &kernel.body {
115                match stmt {
116                    Statement::Label(label) => {
117                        // Simple heuristic: label containing "loop" starts a loop
118                        if label.to_lowercase().contains("loop") {
119                            in_loop = true;
120                            loop_start_labels.insert(label.clone());
121                        }
122                    }
123                    Statement::Instruction(instr) => {
124                        // Check for backward branch (loop end)
125                        if instr.opcode == Opcode::Bra {
126                            for operand in &instr.operands {
127                                if let Operand::Label(target) = operand {
128                                    if loop_start_labels.contains(target) {
129                                        in_loop = false;
130                                    }
131                                }
132                            }
133                        }
134
135                        // Detect cvta.shared inside loop
136                        if in_loop
137                            && instr.opcode == Opcode::Cvta
138                            && self.has_shared_modifier(instr)
139                        {
140                            bugs.push(GenericSharedBug {
141                                location: instr.location.clone(),
142                                instruction: instr.clone(),
143                                severity: Severity::High,
144                                fix: "Move cvta.shared outside loop to reduce register pressure"
145                                    .into(),
146                            });
147                        }
148                    }
149                    _ => {}
150                }
151            }
152        }
153
154        bugs
155    }
156
157    fn has_shared_modifier(&self, instr: &Instruction) -> bool {
158        instr
159            .modifiers
160            .iter()
161            .any(|m| matches!(m, Modifier::Shared))
162    }
163
164    fn has_space_modifier(&self, instr: &Instruction) -> bool {
165        instr
166            .modifiers
167            .iter()
168            .any(|m| m.as_address_space().is_some())
169    }
170
171    fn has_u64_modifier(&self, instr: &Instruction) -> bool {
172        instr
173            .modifiers
174            .iter()
175            .any(|m| matches!(m, Modifier::U64 | Modifier::B64))
176    }
177
178    fn uses_generic_shared_reg(&self, operand: &Operand) -> bool {
179        match operand {
180            Operand::Register(reg) => self.shared_base_regs.contains(reg),
181            Operand::Memory(addr) => {
182                // Check if the memory address contains a generic shared register
183                self.shared_base_regs.iter().any(|reg| addr.contains(reg))
184            }
185            _ => false,
186        }
187    }
188}
189
190impl Default for AddressSpaceValidator {
191    fn default() -> Self {
192        Self::new()
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `ptx` and run the F021 generic-shared detector on the result.
    fn generic_shared_bugs(ptx: &str) -> Vec<GenericSharedBug> {
        let mut parser = Parser::new(ptx).expect("parser creation should succeed");
        let module = parser.parse().expect("parsing should succeed");
        AddressSpaceValidator::new().detect_generic_shared_access(&module)
    }

    // F021: No cvta.shared followed by generic ld/st
    #[test]
    fn f021_no_generic_shared_access() {
        let found = generic_shared_bugs(
            r#"
            .version 8.0
            .target sm_70
            .address_size 64

            .entry test()
            {
                .reg .u32 %r<10>;
                ld.shared.u32 %r0, [%r1];
                ret;
            }
        "#,
        );

        assert!(
            found.is_empty(),
            "F021: Should have no generic shared access bugs"
        );
    }

    // F023: Direct .shared addressing preferred
    #[test]
    fn f023_direct_shared_addressing() {
        let found = generic_shared_bugs(
            r#"
            .version 8.0
            .target sm_70
            .address_size 64

            .entry test()
            {
                .reg .u32 %r<10>;
                ld.shared.u32 %r0, [%r1];
                st.shared.u32 [%r2], %r0;
                ret;
            }
        "#,
        );

        assert!(
            found.is_empty(),
            "F023: Direct shared addressing should not trigger bugs"
        );
    }
}