ptx_parser/parser/
mod.rs

1use crate::lexer::{PtxToken, tokenize};
2use thiserror::Error;
3
4pub(crate) mod common;
5pub(crate) mod function;
6pub(crate) mod instruction;
7pub(crate) mod module;
8pub(crate) mod variable;
9
/// Half-open range locating a token or error in the PTX source, as produced by the lexer.
pub type Span = core::ops::Range<usize>;
11
12/// Kinds of parse errors that can occur during PTX parsing.
13#[derive(Debug, Clone, PartialEq, Eq, Error)]
14pub enum ParseErrorKind {
15    #[error("unexpected token: expected one of {expected:?}, found {found}")]
16    UnexpectedToken {
17        expected: Vec<String>,
18        found: String,
19    },
20    #[error("unexpected end of input")]
21    UnexpectedEof,
22    #[error("invalid literal: {0}")]
23    InvalidLiteral(String),
24}
25
26/// PTX parsing error with location information.
27#[derive(Debug, Clone, PartialEq, Eq, Error)]
28#[error("parsing error at {span:?}: {kind}")]
29pub struct PtxParseError {
30    pub kind: ParseErrorKind,
31    pub span: Span,
32}
33
/// Snapshot of a token-stream cursor, used to save and restore the position when backtracking.
///
/// It records both which token is current and how far character-level
/// parsing has progressed inside that token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamPosition {
    /// Index of the current token.
    pub index: usize,
    /// Progress within the current token's text; `0` means the token is untouched.
    pub char_offset: usize,
}
40
41/// Token stream wrapper for parsing PTX tokens.
42///
43/// This struct provides methods for consuming and inspecting tokens during parsing.
44pub struct PtxTokenStream<'a> {
45    tokens: &'a [(PtxToken, Span)],
46    /// Current position (index) in the tokens list
47    index: usize,
48    /// Position within the current token's string content (for parsing multi-char identifiers/numbers)
49    pub(crate) char_offset: usize,
50}
51
52impl<'a> PtxTokenStream<'a> {
53    pub fn new(tokens: &'a [(PtxToken, Span)]) -> Self {
54        Self {
55            tokens,
56            index: 0,
57            char_offset: 0,
58        }
59    }
60
61    /// Peek at the next token without consuming it.
62    pub fn peek(&self) -> Result<&'a (PtxToken, Span), PtxParseError> {
63        self.tokens.get(self.index).ok_or_else(|| {
64            // If the stream is empty, return an EOF error
65            let span = self.tokens.last().map_or(0..0, |(_, s)| s.clone());
66            PtxParseError {
67                kind: ParseErrorKind::UnexpectedEof,
68                span,
69            }
70        })
71    }
72
73    /// Consume and return the next token.
74    pub fn consume(&mut self) -> Result<&'a (PtxToken, Span), PtxParseError> {
75        let token = self.peek()?;
76        self.index += 1;
77        Ok(token)
78    }
79
80    /// Check if the next token is the expected type, and if so, consume it.
81    /// Otherwise, return an error and do NOT consume the token.
82    pub fn expect(&mut self, expected: &PtxToken) -> Result<&'a (PtxToken, Span), PtxParseError> {
83        let token_pair = self.peek()?;
84        let (token, span) = token_pair;
85        if std::mem::discriminant(token) == std::mem::discriminant(expected) {
86            self.index += 1;
87            Ok(token_pair)
88        } else {
89            Err(PtxParseError {
90                kind: ParseErrorKind::UnexpectedToken {
91                    expected: vec![format!("{:?}", expected)],
92                    found: format!("{:?}", token),
93                },
94                span: span.clone(),
95            })
96        }
97    }
98
99    /// Generic helper to extract a String value from a token variant.
100    /// Returns the extracted string and span if the pattern matches, otherwise returns an error.
101    fn expect_token_with_string<F>(
102        &mut self,
103        expected_name: &str,
104        extractor: F,
105    ) -> Result<(String, Span), PtxParseError>
106    where
107        F: FnOnce(&PtxToken) -> Option<String>,
108    {
109        let (token, span) = self.peek()?;
110        if let Some(value) = extractor(token) {
111            let span = span.clone();
112            self.index += 1;
113            Ok((value, span))
114        } else {
115            Err(PtxParseError {
116                kind: ParseErrorKind::UnexpectedToken {
117                    expected: vec![expected_name.to_string()],
118                    found: format!("{:?}", token),
119                },
120                span: span.clone(),
121            })
122        }
123    }
124
125    /// Check if the next token is an identifier, and if so, consume it and return the String.
126    pub fn expect_identifier(&mut self) -> Result<(String, Span), PtxParseError> {
127        self.expect_token_with_string("Identifier", |token| {
128            if let PtxToken::Identifier(name) = token {
129                Some(name.clone())
130            } else {
131                None
132            }
133        })
134    }
135
136    /// Check if the next token is a register, and if so, consume it and return the String.
137    pub fn expect_register(&mut self) -> Result<(String, Span), PtxParseError> {
138        self.expect_token_with_string("Register", |token| {
139            if let PtxToken::Register(name) = token {
140                Some(name.clone())
141            } else {
142                None
143            }
144        })
145    }
146
147    /// Check if the next token is a directive (Dot + Identifier), and if so, consume them and return the String.
148    pub fn expect_directive(&mut self) -> Result<(String, Span), PtxParseError> {
149        let (_, dot_span) = self.expect(&PtxToken::Dot)?;
150        let (name, id_span) = self.expect_identifier()?;
151        let span = dot_span.start..id_span.end;
152        Ok((name, span))
153    }
154
155    /// Check if the next token is a directive that represents a modifier (type, state space, etc.).
156    /// This is an alias for expect_directive for semantic clarity when parsing modifiers.
157    pub fn expect_modifier(&mut self) -> Result<(String, Span), PtxParseError> {
158        self.expect_directive()
159    }
160
161    /// Expect and consume a double colon (::) token sequence.
162    pub fn expect_double_colon(&mut self) -> Result<(), PtxParseError> {
163        self.expect(&PtxToken::Colon)?;
164        self.expect(&PtxToken::Colon)?;
165        Ok(())
166    }
167
168    /// Try to match and consume a sequence of tokens that matches one of the candidate strings.
169    /// Returns the index of the matched candidate.
170    ///
171    /// This is used for parsing modifiers that may contain :: sequences like ".to::cluster"
172    /// The candidates should include the leading dot (e.g., [".to::cluster", ".to::cta"])
173    pub fn expect_strings(&mut self, candidates: &[&str]) -> Result<usize, PtxParseError> {
174        let start_pos = self.position();
175
176        for (idx, candidate) in candidates.iter().enumerate() {
177            self.set_position(start_pos);
178
179            // Try to match this candidate
180            if self.try_match_string(candidate) {
181                return Ok(idx);
182            }
183        }
184
185        // None matched, create error
186        let (token, span) = self.peek()?;
187        Err(PtxParseError {
188            kind: ParseErrorKind::UnexpectedToken {
189                expected: candidates.iter().map(|s| s.to_string()).collect(),
190                found: format!("{:?}", token),
191            },
192            span: span.clone(),
193        })
194    }
195
196    pub fn expect_string(&mut self, expected: &str) -> Result<(), PtxParseError> {
197        if self.try_match_string(expected) {
198            Ok(())
199        } else {
200            let (token, span) = self.peek()?;
201            Err(PtxParseError {
202                kind: ParseErrorKind::UnexpectedToken {
203                    expected: vec![expected.to_string()],
204                    found: format!("{:?}", token),
205                },
206                span: span.clone(),
207            })
208        }
209    }
210
211    /// Try to match a string pattern by consuming characters from the stream.
212    ///
213    /// # Behavior
214    /// Matches the pattern character-by-character against the token stream.
215    /// Tokens are converted to their string representation and matched from char_offset.
216    /// If all characters match, the stream is advanced and returns true.
217    /// If any character fails to match, the stream is reset and returns false.
218    ///
219    /// # Returns
220    /// - `true` if the entire pattern was successfully matched (chars consumed)
221    /// - `false` if matching failed at any point (stream position restored)
222    pub fn try_match_string(&mut self, pattern: &str) -> bool {
223        let start_pos = self.position();
224
225        // Tokenize the pattern to get expected tokens
226        let expected_tokens = match tokenize(pattern) {
227            Ok(tokens) => tokens,
228            Err(_) => {
229                // If pattern can't be tokenized, it can't match
230                return false;
231            }
232        };
233
234        // Try to match each expected token
235        for (expected_token, _) in expected_tokens {
236            match self.peek() {
237                Ok((actual_token, _)) => {
238                    // Check if we can do a partial match for Identifier tokens
239                    // This handles cases like matching ".b3210" as ".b" + "3" + "2" + "1" + "0"
240                    if let (PtxToken::Identifier(actual_id), expected_str) =
241                        (actual_token, expected_token.as_str())
242                    {
243                        // Check if the expected string matches from the current char_offset
244                        let remaining = &actual_id[self.char_offset..];
245                        if remaining.starts_with(expected_str) {
246                            let new_offset = self.char_offset + expected_str.len();
247                            if new_offset == actual_id.len() {
248                                // Exactly consumed the entire identifier - advance to next token
249                                self.index += 1;
250                                self.char_offset = 0;
251                            } else {
252                                // Partial match! Advance char_offset but DON'T advance index
253                                self.char_offset = new_offset;
254                            }
255                            continue;
256                        }
257                    }
258
259                    // Normal exact match
260                    if actual_token != &expected_token {
261                        self.set_position(start_pos);
262                        return false;
263                    }
264                    // Token matches, consume it
265                    self.index += 1;
266                    self.char_offset = 0;
267                }
268                Err(_) => {
269                    // Unexpected EOF
270                    self.set_position(start_pos);
271                    return false;
272                }
273            }
274        }
275
276        // Successfully matched all tokens
277        true
278    }
279
280    /// Check if the next token matches a specific pattern.
281    pub fn check<F>(&self, predicate: F) -> bool
282    where
283        F: FnOnce(&PtxToken) -> bool,
284    {
285        self.tokens
286            .get(self.index)
287            .map_or(false, |(token, _)| predicate(token))
288    }
289
290    /// Expect that we've consumed a complete token (not stopped in the middle).
291    /// This should be called at the end of each struct parser to verify that
292    /// character-level parsing has consumed all characters from the current token.
293    ///
294    /// # Returns
295    /// - `Ok(())` if `char_offset == 0` (no partial token consumption)
296    /// - `Err(PtxParseError)` if `char_offset > 0` (stopped in middle of token)
297    pub fn expect_complete(&self) -> Result<(), PtxParseError> {
298        if self.char_offset > 0 {
299            // We're in the middle of a token - this is an error
300            let span = self
301                .peek()
302                .map(|(_, s)| s.clone())
303                .unwrap_or(Span { start: 0, end: 0 });
304            Err(unexpected_value(
305                span,
306                &["complete token"],
307                format!("partial token at char offset {}", self.char_offset),
308            ))
309        } else {
310            Ok(())
311        }
312    }
313
314    /// Consume the next token if it matches the predicate.
315    pub fn consume_if<F>(&mut self, predicate: F) -> Option<&'a (PtxToken, Span)>
316    where
317        F: FnOnce(&PtxToken) -> bool,
318    {
319        if self.check(predicate) {
320            self.index += 1;
321            self.tokens.get(self.index - 1)
322        } else {
323            None
324        }
325    }
326
327    /// Get the current position in the stream, for backtracking.
328    pub fn position(&self) -> StreamPosition {
329        StreamPosition {
330            index: self.index,
331            char_offset: self.char_offset,
332        }
333    }
334
335    /// Reset the stream to an old position, for backtracking.
336    pub fn set_position(&mut self, pos: StreamPosition) {
337        self.index = pos.index;
338        self.char_offset = pos.char_offset;
339    }
340
341    /// Check if we've reached the end of the token stream.
342    pub fn is_at_end(&self) -> bool {
343        self.index >= self.tokens.len()
344    }
345
346    /// Get the remaining tokens.
347    pub fn remaining(&self) -> &'a [(PtxToken, Span)] {
348        &self.tokens[self.index..]
349    }
350
351    /// Peek at the character at the current char_offset within the current token's string.
352    /// Returns None if we're at the end of the current token's string or if the token has no string content.
353    pub fn peek_char_in_token(&self) -> Option<char> {
354        if self.index >= self.tokens.len() {
355            return None;
356        }
357
358        let (token, _) = &self.tokens[self.index];
359        let string = match token {
360            PtxToken::Identifier(s)
361            | PtxToken::DecimalInteger(s)
362            | PtxToken::HexInteger(s)
363            | PtxToken::BinaryInteger(s)
364            | PtxToken::OctalInteger(s) => s,
365            _ => return None,
366        };
367
368        string.chars().nth(self.char_offset)
369    }
370
371    /// Consume one character from the current token by advancing char_offset.
372    /// If we reach the end of the token's string, advance to the next token and reset char_offset.
373    /// Returns the consumed character.
374    pub fn consume_char_in_token(&mut self) -> Option<char> {
375        let ch = self.peek_char_in_token()?;
376        self.char_offset += 1;
377
378        // Check if we've consumed the entire string of this token
379        if self.index < self.tokens.len() {
380            let (token, _) = &self.tokens[self.index];
381            let string = match token {
382                PtxToken::Identifier(s)
383                | PtxToken::DecimalInteger(s)
384                | PtxToken::HexInteger(s)
385                | PtxToken::BinaryInteger(s)
386                | PtxToken::OctalInteger(s) => s,
387                _ => "",
388            };
389
390            if self.char_offset >= string.len() {
391                // Move to next token and reset char_offset
392                self.index += 1;
393                self.char_offset = 0;
394            }
395        }
396
397        Some(ch)
398    }
399
400    /// Match a specific character at the current position within the token.
401    /// Consumes the character if it matches.
402    pub fn expect_char_in_token(&mut self, expected: char) -> Result<char, PtxParseError> {
403        match self.peek_char_in_token() {
404            Some(ch) if ch == expected => {
405                self.consume_char_in_token();
406                Ok(ch)
407            }
408            Some(ch) => {
409                let span = if self.index < self.tokens.len() {
410                    self.tokens[self.index].1.clone()
411                } else {
412                    0..0
413                };
414                Err(PtxParseError {
415                    kind: ParseErrorKind::UnexpectedToken {
416                        expected: vec![format!("'{}'", expected)],
417                        found: format!("'{}'", ch),
418                    },
419                    span,
420                })
421            }
422            None => {
423                let span = if self.index < self.tokens.len() {
424                    self.tokens[self.index].1.clone()
425                } else {
426                    0..0
427                };
428                Err(PtxParseError {
429                    kind: ParseErrorKind::UnexpectedEof,
430                    span,
431                })
432            }
433        }
434    }
435}
436
437/// Trait for types that can be parsed from a PTX token stream.
438///
439/// This trait is implemented for all PTX AST node types to enable
440/// recursive descent parsing.
441pub trait PtxParser
442where
443    Self: Sized,
444{
445    /// Parse an instance of `Self` from the token stream.
446    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError>;
447}
448
449/// Parse PTX source code into a structured Module representation.
450///
451/// This is the main entry point for parsing PTX code. It performs lexical
452/// analysis followed by syntactic parsing.
453///
454/// # Arguments
455///
456/// * `source` - The PTX source code as a string slice
457///
458/// # Returns
459///
460/// Returns a parsed `Module` AST node, or a `PtxParseError` if parsing fails.
461///
462/// # Example
463///
464/// ```no_run
465/// use ptx_parser::parse_ptx;
466///
467/// let source = r#"
468///     .version 8.5
469///     .target sm_90
470///     .address_size 64
471///     
472///     .entry kernel() {
473///         ret;
474///     }
475/// "#;
476///
477/// let module = parse_ptx(source).expect("Failed to parse PTX");
478/// println!("Parsed {} directives", module.directives.len());
479/// ```
480pub fn parse_ptx(source: &str) -> Result<crate::r#type::module::Module, PtxParseError> {
481    let tokens = crate::lexer::tokenize(source).map_err(|err| PtxParseError {
482        kind: ParseErrorKind::InvalidLiteral("lexical error".into()),
483        span: err.span,
484    })?;
485    let mut stream = PtxTokenStream::new(&tokens);
486    let module = crate::r#type::module::Module::parse(&mut stream)?;
487    if !stream.is_at_end() {
488        let (token, span) = stream.peek()?;
489        return Err(unexpected_value(
490            span.clone(),
491            &["end of input"],
492            format!("{token:?}"),
493        ));
494    }
495    Ok(module)
496}
497
498pub fn unexpected_value(span: Span, expected: &[&str], found: impl Into<String>) -> PtxParseError {
499    PtxParseError {
500        kind: ParseErrorKind::UnexpectedToken {
501            expected: expected.iter().map(|s| s.to_string()).collect(),
502            found: found.into(),
503        },
504        span,
505    }
506}
507
508pub(crate) fn invalid_literal(span: Span, message: impl Into<String>) -> PtxParseError {
509    PtxParseError {
510        kind: ParseErrorKind::InvalidLiteral(message.into()),
511        span,
512    }
513}
514
515pub(crate) fn expect_directive_value(
516    stream: &mut PtxTokenStream,
517    expected: &str,
518) -> Result<(), PtxParseError> {
519    let (value, span) = stream.expect_directive()?;
520    if value == expected {
521        Ok(())
522    } else {
523        Err(unexpected_value(
524            span,
525            &[&format!(".{expected}")],
526            format!(".{value}"),
527        ))
528    }
529}
530
531pub(crate) fn peek_directive(
532    stream: &mut PtxTokenStream,
533) -> Result<Option<(String, Span)>, PtxParseError> {
534    // Check if we have Dot followed by Identifier
535    if let Some((PtxToken::Dot, dot_span)) = stream.tokens.get(stream.index) {
536        if let Some((PtxToken::Identifier(value), id_span)) = stream.tokens.get(stream.index + 1) {
537            let span = dot_span.start..id_span.end;
538            return Ok(Some((value.clone(), span)));
539        }
540    }
541    Ok(None)
542}