// trueno_ptx_debug/parser/mod.rs

1//! PTX Parser Module
2//!
3//! Provides lexing and parsing of PTX source code into a typed AST.
4
5mod ast;
6mod error;
7mod lexer;
8pub mod types;
9
10pub use ast::{
11    Directive, FunctionDef, GlobalDecl, Instruction, KernelDef, Operand, Param, Predicate,
12    PtxModule, RegisterDecl, SharedMemDecl, SourceLocation, Statement,
13};
14pub use error::ParseError;
15pub use lexer::{Lexer, Token, TokenKind};
16pub use types::{AddressSpace, Modifier, Opcode, PtxType, SmTarget};
17
/// Match a text string against `contains` patterns and return the first matching value.
///
/// Usage: `match_contains!(text, default, [p1, p2, ...] => value, ...)`.
///
/// Each arm is `[pattern1, pattern2, ...] => value` where patterns are checked via
/// `str::contains`. Any matching pattern in the bracket group triggers the arm
/// (logical OR). Arms are tried in the order written and the first hit wins, so list
/// more specific patterns before less specific ones.
///
/// The `default` expression (the *second* macro argument) is returned when no arm
/// matches. `text` is evaluated exactly once. Only the selected `value` (or `default`)
/// expression is evaluated, because the expansion is a plain `if … else if … else` chain.
///
/// Single-pattern arms use `["pattern"] => value`.
///
/// Returns the value directly (not wrapped in `Result`).
macro_rules! match_contains {
    ($text:expr, $default:expr, $( [ $( $pattern:expr ),+ ] => $value:expr ),+ $(,)?) => {{
        let __text = $text;
        $(
            if $( __text.contains($pattern) )||+ {
                $value
            } else
        )+
        { $default }
    }};
}
38
/// Match an exact string against a lookup table and return the corresponding value.
///
/// Usage: `match_str_lookup!(text, default, "lit" => value, ...)`.
///
/// Each arm is `literal => value`. The literals are spliced into `match` *pattern*
/// position, so they must be literals; the `literal` fragment enforces that at the
/// call site instead of failing later with a pattern error.
///
/// The `default` expression (the *second* macro argument) is returned when no literal
/// matches. It sits in the `_` arm, so it is only evaluated on a miss.
macro_rules! match_str_lookup {
    ($text:expr, $default:expr, $( $lit:literal => $value:expr ),+ $(,)?) => {
        match $text {
            $( $lit => $value, )+
            _ => $default,
        }
    };
}
51
52/// PTX Parser - constructs AST from token stream
53pub struct Parser<'a> {
54    lexer: Lexer<'a>,
55    current: Token,
56    peek: Token,
57}
58
59impl<'a> Parser<'a> {
60    /// Create a new parser from PTX source
61    pub fn new(source: &'a str) -> Result<Self, ParseError> {
62        let mut lexer = Lexer::new(source);
63        let current = lexer.next_token()?;
64        let peek = lexer.next_token()?;
65        Ok(Self {
66            lexer,
67            current,
68            peek,
69        })
70    }
71
72    /// Parse the PTX source into a module
73    pub fn parse(&mut self) -> Result<PtxModule, ParseError> {
74        let mut module = PtxModule::default();
75
76        while self.current.kind != TokenKind::Eof {
77            match self.current.kind {
78                TokenKind::Directive => {
79                    self.parse_directive(&mut module)?;
80                }
81                TokenKind::Entry | TokenKind::Func => {
82                    let kernel = self.parse_kernel()?;
83                    module.kernels.push(kernel);
84                }
85                _ => {
86                    self.advance()?;
87                }
88            }
89        }
90
91        Ok(module)
92    }
93
94    fn advance(&mut self) -> Result<(), ParseError> {
95        self.current = std::mem::replace(&mut self.peek, self.lexer.next_token()?);
96        Ok(())
97    }
98
99    fn parse_directive(&mut self, module: &mut PtxModule) -> Result<(), ParseError> {
100        let directive_text = self.current.text.clone();
101
102        if directive_text.starts_with(".version") {
103            module.version = self.parse_version_directive()?;
104        } else if directive_text.starts_with(".target") {
105            module.target = self.parse_target_directive()?;
106        } else if directive_text.starts_with(".address_size") {
107            module.address_size = self.parse_address_size_directive()?;
108        }
109
110        self.advance()?;
111        Ok(())
112    }
113
114    fn parse_version_directive(&self) -> Result<(u8, u8), ParseError> {
115        // Parse ".version X.Y"
116        let text = &self.current.text;
117        if let Some(rest) = text.strip_prefix(".version") {
118            let version_str = rest.trim();
119            let parts: Vec<&str> = version_str.split('.').collect();
120            if parts.len() >= 2 {
121                let major = parts[0].parse().unwrap_or(0);
122                let minor = parts[1].parse().unwrap_or(0);
123                return Ok((major, minor));
124            }
125        }
126        Ok((0, 0))
127    }
128
129    fn parse_target_directive(&self) -> Result<SmTarget, ParseError> {
130        let text = &self.current.text;
131        Ok(match_contains!(text, SmTarget::Unknown,
132            ["sm_70"] => SmTarget::Sm70,
133            ["sm_75"] => SmTarget::Sm75,
134            ["sm_80"] => SmTarget::Sm80,
135            ["sm_86"] => SmTarget::Sm86,
136            ["sm_89"] => SmTarget::Sm89,
137            ["sm_90"] => SmTarget::Sm90,
138        ))
139    }
140
141    fn parse_address_size_directive(&self) -> Result<u8, ParseError> {
142        let text = &self.current.text;
143        Ok(match_contains!(text, 0,
144            ["64"] => 64,
145            ["32"] => 32,
146        ))
147    }
148
149    fn parse_kernel(&mut self) -> Result<KernelDef, ParseError> {
150        let is_entry = self.current.kind == TokenKind::Entry;
151        self.advance()?;
152
153        // Parse kernel name
154        let name = if self.current.kind == TokenKind::Identifier {
155            let n = self.current.text.clone();
156            self.advance()?;
157            n
158        } else {
159            return Err(ParseError::UnexpectedToken {
160                expected: "kernel name".into(),
161                found: format!("{:?}", self.current.kind),
162                location: self.current.location.clone(),
163            });
164        };
165
166        let mut kernel = KernelDef {
167            name,
168            is_entry,
169            params: Vec::new(),
170            registers: Vec::new(),
171            shared_mem: Vec::new(),
172            body: Vec::new(),
173        };
174
175        // Parse parameter list and body (simplified)
176        while self.current.kind != TokenKind::Eof {
177            if self.current.kind == TokenKind::Entry || self.current.kind == TokenKind::Func {
178                break;
179            }
180
181            match self.current.kind {
182                TokenKind::Reg => {
183                    let reg = self.parse_register_decl()?;
184                    kernel.registers.push(reg);
185                }
186                TokenKind::Shared => {
187                    let shared = self.parse_shared_mem_decl()?;
188                    kernel.shared_mem.push(shared);
189                }
190                TokenKind::Instruction => {
191                    let instr = self.parse_instruction()?;
192                    kernel.body.push(Statement::Instruction(instr));
193                }
194                TokenKind::Label => {
195                    let label = self.current.text.trim_end_matches(':').to_string();
196                    kernel.body.push(Statement::Label(label));
197                    self.advance()?;
198                }
199                _ => {
200                    self.advance()?;
201                }
202            }
203        }
204
205        Ok(kernel)
206    }
207
208    fn parse_register_decl(&mut self) -> Result<RegisterDecl, ParseError> {
209        // Parse ".reg .TYPE %name"
210        let text = self.current.text.clone();
211        self.advance()?;
212
213        // Extract type and name from the declaration
214        let (ty, name) = self.extract_reg_type_and_name(&text);
215
216        Ok(RegisterDecl { name, ty })
217    }
218
219    fn extract_reg_type_and_name(&self, text: &str) -> (PtxType, String) {
220        let ty = match_contains!(text, PtxType::B32,
221            [".b64", ".u64"] => PtxType::B64,
222            [".b32", ".u32"] => PtxType::U32,
223            [".f32"]         => PtxType::F32,
224            [".f64"]         => PtxType::F64,
225            [".pred"]        => PtxType::Pred,
226        );
227
228        // Extract register name (starts with %)
229        let name = text
230            .split_whitespace()
231            .find(|s| s.starts_with('%'))
232            .map(|s| s.trim_end_matches(|c: char| !c.is_alphanumeric() && c != '%' && c != '_'))
233            .unwrap_or("%unknown")
234            .to_string();
235
236        (ty, name)
237    }
238
239    fn parse_shared_mem_decl(&mut self) -> Result<SharedMemDecl, ParseError> {
240        let text = self.current.text.clone();
241        self.advance()?;
242
243        // Parse shared memory declaration
244        let name = text
245            .split_whitespace()
246            .find(|s| !s.starts_with('.'))
247            .unwrap_or("unknown")
248            .to_string();
249
250        let size = text
251            .split('[')
252            .nth(1)
253            .and_then(|s| s.split(']').next())
254            .and_then(|s| s.parse().ok())
255            .unwrap_or(0);
256
257        Ok(SharedMemDecl {
258            name,
259            size,
260            ty: PtxType::B8,
261        })
262    }
263
264    fn parse_instruction(&mut self) -> Result<Instruction, ParseError> {
265        let text = self.current.text.clone();
266        let location = self.current.location.clone();
267        self.advance()?;
268
269        let (opcode, modifiers) = self.parse_opcode(&text);
270        let operands = self.parse_operands(&text);
271
272        Ok(Instruction {
273            opcode,
274            modifiers,
275            operands,
276            predicate: None,
277            location,
278        })
279    }
280
281    fn parse_opcode(&self, text: &str) -> (Opcode, Vec<Modifier>) {
282        let parts: Vec<&str> = text
283            .split_whitespace()
284            .next()
285            .unwrap_or("")
286            .split('.')
287            .collect();
288
289        let opcode = match parts.first().copied() {
290            Some(s) => match_str_lookup!(s, Opcode::Unknown,
291                "ld"     => Opcode::Ld,
292                "st"     => Opcode::St,
293                "mov"    => Opcode::Mov,
294                "add"    => Opcode::Add,
295                "sub"    => Opcode::Sub,
296                "mul"    => Opcode::Mul,
297                "mad"    => Opcode::Mad,
298                "fma"    => Opcode::Fma,
299                "cvta"   => Opcode::Cvta,
300                "cvt"    => Opcode::Cvt,
301                "setp"   => Opcode::Setp,
302                "bra"    => Opcode::Bra,
303                "bar"    => Opcode::Bar,
304                "atom"   => Opcode::Atom,
305                "ret"    => Opcode::Ret,
306                "exit"   => Opcode::Exit,
307                "and"    => Opcode::And,
308                "or"     => Opcode::Or,
309                "xor"    => Opcode::Xor,
310                "shl"    => Opcode::Shl,
311                "shr"    => Opcode::Shr,
312                "membar" => Opcode::MemBar,
313            ),
314            None => Opcode::Unknown,
315        };
316
317        let modifiers = parts
318            .iter()
319            .skip(1)
320            .map(|&m| self.parse_modifier(m))
321            .collect();
322
323        (opcode, modifiers)
324    }
325
326    fn parse_modifier(&self, s: &str) -> Modifier {
327        match_str_lookup!(s, Modifier::Other(s.to_string()),
328            "shared" => Modifier::Shared,
329            "global" => Modifier::Global,
330            "local"  => Modifier::Local,
331            "const"  => Modifier::Const,
332            "param"  => Modifier::Param,
333            "u32"    => Modifier::U32,
334            "u64"    => Modifier::U64,
335            "s32"    => Modifier::S32,
336            "s64"    => Modifier::S64,
337            "f32"    => Modifier::F32,
338            "f64"    => Modifier::F64,
339            "b32"    => Modifier::B32,
340            "b64"    => Modifier::B64,
341            "sync"   => Modifier::Sync,
342            "cta"    => Modifier::Cta,
343            "gl"     => Modifier::Gl,
344            "add"    => Modifier::AtomicAdd,
345            "cas"    => Modifier::AtomicCas,
346        )
347    }
348
349    fn parse_operands(&self, text: &str) -> Vec<Operand> {
350        // Simple operand parsing - extract tokens after the opcode
351        let mut operands = Vec::new();
352
353        // Skip the opcode part
354        let operand_part = text
355            .split_whitespace()
356            .skip(1)
357            .collect::<Vec<_>>()
358            .join(" ");
359
360        for part in operand_part.split(',') {
361            let trimmed = part.trim();
362            if trimmed.is_empty() {
363                continue;
364            }
365
366            let operand = if trimmed.starts_with('%') {
367                Operand::Register(trimmed.to_string())
368            } else if trimmed.starts_with('[') {
369                Operand::Memory(trimmed.to_string())
370            } else if trimmed.parse::<i64>().is_ok() {
371                Operand::Immediate(trimmed.parse().unwrap_or(0))
372            } else {
373                Operand::Label(trimmed.to_string())
374            };
375            operands.push(operand);
376        }
377
378        operands
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal well-formed header used by the F001/F002/F003 checks.
    const HEADER_ONLY: &str = r#"
            .version 8.0
            .target sm_70
            .address_size 64
        "#;

    /// Build a parser over `ptx` and run it, failing the test on any error.
    fn parse_module(ptx: &str) -> PtxModule {
        Parser::new(ptx)
            .expect("parser creation should succeed")
            .parse()
            .expect("parsing should succeed")
    }

    // F001: PTX contains .version directive
    #[test]
    fn f001_version_directive_present() {
        let parsed = parse_module(HEADER_ONLY);
        let (major, _minor) = parsed.version;
        assert!(major > 0, "F001: Missing .version directive");
    }

    // F001: Negative test — no .version line at all.
    #[test]
    fn f001_version_directive_missing() {
        let source = r#"
            .target sm_70
            .address_size 64
        "#;
        let parsed = parse_module(source);
        assert_eq!(parsed.version, (0, 0), "Should detect missing version");
    }

    // F002: PTX contains .target directive
    #[test]
    fn f002_target_directive_present() {
        let parsed = parse_module(HEADER_ONLY);
        assert_ne!(
            parsed.target,
            SmTarget::Unknown,
            "F002: Missing .target directive"
        );
    }

    // F003: address_size is 32 or 64
    #[test]
    fn f003_address_size_valid() {
        let parsed = parse_module(HEADER_ONLY);
        assert!(
            matches!(parsed.address_size, 32 | 64),
            "F003: address_size must be 32 or 64"
        );
    }

    // A header plus one kernel with register declarations and two instructions.
    #[test]
    fn parse_simple_kernel() {
        let source = r#"
            .version 8.0
            .target sm_70
            .address_size 64

            .entry test_kernel(
                .param .u64 param0
            )
            {
                .reg .u32 %r<10>;
                .reg .u64 %rd<10>;

                mov.u32 %r0, 0;
                ret;
            }
        "#;
        let parsed = parse_module(source);

        assert_eq!(parsed.version, (8, 0));
        assert_eq!(parsed.target, SmTarget::Sm70);
        assert_eq!(parsed.address_size, 64);
        assert!(
            !parsed.kernels.is_empty(),
            "Should have at least one kernel"
        );
    }

    // Load and store instructions must both be recognised by opcode.
    #[test]
    fn parse_ld_st_instructions() {
        let source = r#"
            .version 8.0
            .target sm_70
            .address_size 64

            .entry test(
            )
            {
                .reg .u32 %r<10>;

                ld.shared.u32 %r0, [%r1];
                st.global.u32 [%r2], %r0;
                ret;
            }
        "#;
        let parsed = parse_module(source);

        let first_kernel = &parsed.kernels[0];
        let opcodes: Vec<&Opcode> = first_kernel
            .body
            .iter()
            .filter_map(|stmt| {
                if let Statement::Instruction(instr) = stmt {
                    Some(&instr.opcode)
                } else {
                    None
                }
            })
            .collect();

        assert!(opcodes.iter().any(|op| matches!(**op, Opcode::Ld)));
        assert!(opcodes.iter().any(|op| matches!(**op, Opcode::St)));
    }
}