Skip to main content

trueno_ptx_debug/parser/
ast.rs

1//! PTX Abstract Syntax Tree definitions
2
3use super::types::{AddressSpace, Modifier, Opcode, PtxType, SmTarget};
4
5/// PTX Module - top-level AST node
6#[derive(Debug, Clone, Default)]
7pub struct PtxModule {
8    /// PTX version (major, minor)
9    pub version: (u8, u8),
10    /// Target SM architecture
11    pub target: SmTarget,
12    /// Address size (32 or 64)
13    pub address_size: u8,
14    /// Global declarations
15    pub globals: Vec<GlobalDecl>,
16    /// Kernel definitions
17    pub kernels: Vec<KernelDef>,
18    /// Function definitions
19    pub functions: Vec<FunctionDef>,
20}
21
22/// Kernel definition
23#[derive(Debug, Clone)]
24pub struct KernelDef {
25    /// Kernel name
26    pub name: String,
27    /// Is this an entry point (.entry vs .func)
28    pub is_entry: bool,
29    /// Parameters
30    pub params: Vec<Param>,
31    /// Register declarations
32    pub registers: Vec<RegisterDecl>,
33    /// Shared memory declarations
34    pub shared_mem: Vec<SharedMemDecl>,
35    /// Body statements
36    pub body: Vec<Statement>,
37}
38
39/// Function definition (non-entry)
40#[derive(Debug, Clone)]
41pub struct FunctionDef {
42    /// Function name
43    pub name: String,
44    /// Return type
45    pub return_type: Option<PtxType>,
46    /// Parameters
47    pub params: Vec<Param>,
48    /// Register declarations
49    pub registers: Vec<RegisterDecl>,
50    /// Body statements
51    pub body: Vec<Statement>,
52}
53
54/// Global declaration
55#[derive(Debug, Clone)]
56pub struct GlobalDecl {
57    /// Name
58    pub name: String,
59    /// Address space
60    pub space: AddressSpace,
61    /// Type
62    pub ty: PtxType,
63    /// Size in bytes
64    pub size: usize,
65    /// Initial value (if any)
66    pub init: Option<Vec<u8>>,
67}
68
69/// Parameter declaration
70#[derive(Debug, Clone)]
71pub struct Param {
72    /// Parameter name
73    pub name: String,
74    /// Parameter type
75    pub ty: PtxType,
76}
77
78/// Register declaration
79#[derive(Debug, Clone)]
80pub struct RegisterDecl {
81    /// Register name (e.g., "%r0" or "%r<10>")
82    pub name: String,
83    /// Register type
84    pub ty: PtxType,
85}
86
87/// Shared memory declaration
88#[derive(Debug, Clone)]
89pub struct SharedMemDecl {
90    /// Name
91    pub name: String,
92    /// Size in bytes
93    pub size: usize,
94    /// Element type
95    pub ty: PtxType,
96}
97
98/// Statement in a kernel/function body
99#[derive(Debug, Clone)]
100pub enum Statement {
101    /// Label definition
102    Label(String),
103    /// Instruction
104    Instruction(Instruction),
105    /// Directive within body
106    Directive(Directive),
107    /// Comment (preserved for debugging)
108    Comment(String),
109}
110
111/// PTX Instruction
112#[derive(Debug, Clone)]
113pub struct Instruction {
114    /// Opcode
115    pub opcode: Opcode,
116    /// Modifiers (.shared, .u32, etc.)
117    pub modifiers: Vec<Modifier>,
118    /// Operands
119    pub operands: Vec<Operand>,
120    /// Predicate (if any)
121    pub predicate: Option<Predicate>,
122    /// Source location
123    pub location: SourceLocation,
124}
125
/// A directive that appears inside a kernel or function body.
#[derive(Clone, Debug)]
pub enum Directive {
    /// `.loc` debug-location directive: file index, line, and column.
    Loc { file: u32, line: u32, column: u32 },
    /// `.pragma` directive, holding its text.
    Pragma(String),
    /// Any other directive, holding its text.
    Other(String),
}
136
/// An operand of a PTX instruction.
#[derive(Debug, Clone)]
pub enum Operand {
    /// Register (`%r0`, `%rd1`, etc.).
    Register(String),
    /// Memory reference (`[%r0]`, `[%r0+4]`, etc.). `Display` writes the stored
    /// text verbatim, so it is presumably stored with its brackets.
    Memory(String),
    /// Integer immediate.
    Immediate(i64),
    /// Floating-point immediate.
    ImmediateFloat(f64),
    /// Label reference (branch target).
    Label(String),
    /// Vector operand such as `{%r0, %r1, %r2, %r3}`.
    Vector(Vec<Operand>),
}

impl Operand {
    /// Returns `true` if this operand is a register.
    pub fn is_register(&self) -> bool {
        matches!(self, Operand::Register(_))
    }

    /// Returns `true` if this operand is a memory reference.
    pub fn is_memory(&self) -> bool {
        matches!(self, Operand::Memory(_))
    }

    /// Returns the register name if this is a register operand, otherwise `None`.
    pub fn as_register(&self) -> Option<&str> {
        match self {
            Operand::Register(name) => Some(name.as_str()),
            _ => None,
        }
    }
}

impl std::fmt::Display for Operand {
    /// Formats the operand in PTX source syntax.
    ///
    /// Floating-point immediates never print like integers:
    /// - integral finite values get a trailing `.0` (`1.0`, not `1`);
    /// - NaN and infinities are written in PTX's exact double-precision hex form,
    ///   `0d` followed by 16 hex digits, because they have no decimal spelling.
    ///
    /// All other values print with plain `{}` formatting.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Operand::Register(name) => write!(f, "{}", name),
            Operand::Memory(addr) => write!(f, "{}", addr),
            Operand::Immediate(val) => write!(f, "{}", val),
            Operand::ImmediateFloat(val) => {
                if !val.is_finite() {
                    // `{}` would print "NaN"/"inf", which are not PTX tokens.
                    write!(f, "0d{:016X}", val.to_bits())
                } else if val.fract() == 0.0 {
                    // `{}` prints 1.0_f64 as "1" (never with '.' or an exponent),
                    // which would read back as an integer immediate.
                    write!(f, "{}.0", val)
                } else {
                    write!(f, "{}", val)
                }
            }
            Operand::Label(name) => write!(f, "{}", name),
            Operand::Vector(ops) => {
                write!(f, "{{")?;
                for (i, op) in ops.iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{}", op)?;
                }
                write!(f, "}}")
            }
        }
    }
}
195
/// Guard predicate that makes an instruction's execution conditional.
#[derive(Clone, Debug)]
pub struct Predicate {
    /// Name of the guarding predicate register.
    pub register: String,
    /// `true` when the guard is negated (`@!p`).
    pub negated: bool,
}
204
/// Position in the PTX source, used when reporting errors.
#[derive(Clone, Debug, Default)]
pub struct SourceLocation {
    /// 1-based line number.
    pub line: usize,
    /// 1-based column number.
    pub column: usize,
    /// Source file name, when known.
    pub file: Option<String>,
}

impl std::fmt::Display for SourceLocation {
    /// Renders `file:line:column` when the file is known, otherwise `line:column`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.file.as_deref() {
            Some(path) => write!(f, "{}:{}:{}", path, self.line, self.column),
            None => write!(f, "{}:{}", self.line, self.column),
        }
    }
}
225
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_operand_display() {
        assert_eq!(format!("{}", Operand::Register("%r0".into())), "%r0");
        assert_eq!(format!("{}", Operand::Memory("[%r0]".into())), "[%r0]");
        assert_eq!(format!("{}", Operand::Immediate(42)), "42");
        assert_eq!(format!("{}", Operand::Label("$L__BB0_1".into())), "$L__BB0_1");
        // A value with a fractional part prints as plain decimal.
        assert_eq!(format!("{}", Operand::ImmediateFloat(0.5)), "0.5");
    }

    #[test]
    fn test_vector_operand_display() {
        assert_eq!(format!("{}", Operand::Vector(Vec::new())), "{}");

        let single = Operand::Vector(vec![Operand::Register("%r0".into())]);
        assert_eq!(format!("{}", single), "{%r0}");

        let quad = Operand::Vector(
            ["%f0", "%f1", "%f2", "%f3"]
                .iter()
                .map(|r| Operand::Register((*r).into()))
                .collect(),
        );
        assert_eq!(format!("{}", quad), "{%f0, %f1, %f2, %f3}");
    }

    #[test]
    fn test_source_location_display() {
        let loc = SourceLocation {
            line: 10,
            column: 5,
            file: Some("test.ptx".into()),
        };
        assert_eq!(format!("{}", loc), "test.ptx:10:5");

        let loc = SourceLocation {
            line: 10,
            column: 5,
            file: None,
        };
        assert_eq!(format!("{}", loc), "10:5");

        // The derived Default has no file and zeroed positions.
        assert_eq!(format!("{}", SourceLocation::default()), "0:0");
    }

    #[test]
    fn test_operand_type_checks() {
        let reg = Operand::Register("%r0".into());
        assert!(reg.is_register());
        assert!(!reg.is_memory());

        let mem = Operand::Memory("[%r0]".into());
        assert!(!mem.is_register());
        assert!(mem.is_memory());

        // A vector of registers is itself neither a register nor a memory operand.
        let vec = Operand::Vector(vec![Operand::Register("%r0".into())]);
        assert!(!vec.is_register());
        assert!(!vec.is_memory());
    }

    #[test]
    fn test_operand_as_register() {
        assert_eq!(Operand::Register("%rd1".into()).as_register(), Some("%rd1"));
        assert_eq!(Operand::Memory("[%rd1]".into()).as_register(), None);
        assert_eq!(Operand::Immediate(7).as_register(), None);
        assert_eq!(Operand::Label("done".into()).as_register(), None);
    }
}