ptx_parser/parser/
mod.rs

1use crate::lexer::{PtxToken, tokenize};
2use thiserror::Error;
3
4mod common;
5mod function;
6pub mod instruction;
7mod module;
8mod variable;
9
/// Half-open range of source positions covered by a token or an error, as
/// reported by the lexer.
pub type Span = core::ops::Range<usize>;
11
/// The specific reason a parse failed; paired with a [`Span`] inside [`PtxParseError`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum ParseErrorKind {
    /// The next token (or character) did not match any accepted alternative.
    #[error("unexpected token: expected one of {expected:?}, found {found}")]
    UnexpectedToken {
        /// Human-readable descriptions of what would have been accepted.
        expected: Vec<String>,
        /// Description of what was actually found (usually the token's `Debug` form).
        found: String,
    },
    /// Input ran out while more tokens (or characters within a token) were required.
    #[error("unexpected end of input")]
    UnexpectedEof,
    /// A literal could not be interpreted; the payload is a free-form message.
    /// `parse_ptx` also reports lexer failures through this variant.
    #[error("invalid literal: {0}")]
    InvalidLiteral(String),
}
24
/// A parse failure: what went wrong (`kind`) and where it happened (`span`).
///
/// Displays as `parsing error at <span>: <kind>`.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[error("parsing error at {span:?}: {kind}")]
pub struct PtxParseError {
    /// Category and details of the failure.
    pub kind: ParseErrorKind,
    /// Span of the offending token as reported by the lexer; some error paths
    /// fall back to `0..0` when no token is available.
    pub span: Span,
}
31
/// A saved location in a [`PtxTokenStream`], used for backtracking.
///
/// Obtained from `PtxTokenStream::position` and restored with
/// `PtxTokenStream::set_position`. It records both the token index and how far
/// character-level matching has progressed inside that token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamPosition {
    /// Index of the current token in the stream's token slice.
    pub index: usize,
    /// Offset already consumed within the current token's text (0 = token boundary).
    pub char_offset: usize,
}
38
/// A cursor over a lexed PTX token slice.
///
/// Supports backtracking (via [`StreamPosition`]) and character-level matching
/// inside identifier and integer tokens.
pub struct PtxTokenStream<'a> {
    /// The full token list, each token paired with its source span.
    tokens: &'a [(PtxToken, Span)],
    /// Current position (index) in the tokens list, i.e. the next unconsumed token.
    index: usize,
    /// How much of the current token's text has already been consumed by
    /// character-level parsing (for multi-char identifiers/numbers); 0 means the
    /// cursor sits at a token boundary.
    pub(crate) char_offset: usize,
}
46
47impl<'a> PtxTokenStream<'a> {
48    pub fn new(tokens: &'a [(PtxToken, Span)]) -> Self {
49        Self { tokens, index: 0, char_offset: 0 }
50    }
51
52    /// Peek at the next token without consuming it.
53    pub fn peek(&self) -> Result<&'a (PtxToken, Span), PtxParseError> {
54        self.tokens.get(self.index).ok_or_else(|| {
55            // If the stream is empty, return an EOF error
56            let span = self.tokens.last().map_or(0..0, |(_, s)| s.clone());
57            PtxParseError {
58                kind: ParseErrorKind::UnexpectedEof,
59                span,
60            }
61        })
62    }
63
64    /// Consume and return the next token.
65    pub fn consume(&mut self) -> Result<&'a (PtxToken, Span), PtxParseError> {
66        let token = self.peek()?;
67        self.index += 1;
68        Ok(token)
69    }
70
71    /// Check if the next token is the expected type, and if so, consume it.
72    /// Otherwise, return an error and do NOT consume the token.
73    pub fn expect(&mut self, expected: &PtxToken) -> Result<&'a (PtxToken, Span), PtxParseError> {
74        let token_pair = self.peek()?;
75        let (token, span) = token_pair;
76        if std::mem::discriminant(token) == std::mem::discriminant(expected) {
77            self.index += 1;
78            Ok(token_pair)
79        } else {
80            Err(PtxParseError {
81                kind: ParseErrorKind::UnexpectedToken {
82                    expected: vec![format!("{:?}", expected)],
83                    found: format!("{:?}", token),
84                },
85                span: span.clone(),
86            })
87        }
88    }
89
90    /// Generic helper to extract a String value from a token variant.
91    /// Returns the extracted string and span if the pattern matches, otherwise returns an error.
92    fn expect_token_with_string<F>(
93        &mut self,
94        expected_name: &str,
95        extractor: F,
96    ) -> Result<(String, Span), PtxParseError>
97    where
98        F: FnOnce(&PtxToken) -> Option<String>,
99    {
100        let (token, span) = self.peek()?;
101        if let Some(value) = extractor(token) {
102            let span = span.clone();
103            self.index += 1;
104            Ok((value, span))
105        } else {
106            Err(PtxParseError {
107                kind: ParseErrorKind::UnexpectedToken {
108                    expected: vec![expected_name.to_string()],
109                    found: format!("{:?}", token),
110                },
111                span: span.clone(),
112            })
113        }
114    }
115
116    /// Check if the next token is an identifier, and if so, consume it and return the String.
117    pub fn expect_identifier(&mut self) -> Result<(String, Span), PtxParseError> {
118        self.expect_token_with_string("Identifier", |token| {
119            if let PtxToken::Identifier(name) = token {
120                Some(name.clone())
121            } else {
122                None
123            }
124        })
125    }
126
127    /// Check if the next token is a register, and if so, consume it and return the String.
128    pub fn expect_register(&mut self) -> Result<(String, Span), PtxParseError> {
129        self.expect_token_with_string("Register", |token| {
130            if let PtxToken::Register(name) = token {
131                Some(name.clone())
132            } else {
133                None
134            }
135        })
136    }
137
138    /// Check if the next token is a directive (Dot + Identifier), and if so, consume them and return the String.
139    pub fn expect_directive(&mut self) -> Result<(String, Span), PtxParseError> {
140        let (_, dot_span) = self.expect(&PtxToken::Dot)?;
141        let (name, id_span) = self.expect_identifier()?;
142        let span = dot_span.start..id_span.end;
143        Ok((name, span))
144    }
145
146    /// Check if the next token is a directive that represents a modifier (type, state space, etc.).
147    /// This is an alias for expect_directive for semantic clarity when parsing modifiers.
148    pub fn expect_modifier(&mut self) -> Result<(String, Span), PtxParseError> {
149        self.expect_directive()
150    }
151
152    /// Expect and consume a double colon (::) token sequence.
153    pub fn expect_double_colon(&mut self) -> Result<(), PtxParseError> {
154        self.expect(&PtxToken::Colon)?;
155        self.expect(&PtxToken::Colon)?;
156        Ok(())
157    }
158
159    /// Try to match and consume a sequence of tokens that matches one of the candidate strings.
160    /// Returns the index of the matched candidate.
161    ///
162    /// This is used for parsing modifiers that may contain :: sequences like ".to::cluster"
163    /// The candidates should include the leading dot (e.g., [".to::cluster", ".to::cta"])
164    pub fn expect_strings(&mut self, candidates: &[&str]) -> Result<usize, PtxParseError> {
165        let start_pos = self.position();
166
167        for (idx, candidate) in candidates.iter().enumerate() {
168            self.set_position(start_pos);
169
170            // Try to match this candidate
171            if self.try_match_string(candidate) {
172                return Ok(idx);
173            }
174        }
175
176        // None matched, create error
177        let (token, span) = self.peek()?;
178        Err(PtxParseError {
179            kind: ParseErrorKind::UnexpectedToken {
180                expected: candidates.iter().map(|s| s.to_string()).collect(),
181                found: format!("{:?}", token),
182            },
183            span: span.clone(),
184        })
185    }
186
187    pub fn expect_string(&mut self, expected: &str) -> Result<(), PtxParseError> {
188        if self.try_match_string(expected) {
189            Ok(())
190        } else {
191            let (token, span) = self.peek()?;
192            Err(PtxParseError {
193                kind: ParseErrorKind::UnexpectedToken {
194                    expected: vec![expected.to_string()],
195                    found: format!("{:?}", token),
196                },
197                span: span.clone(),
198            })
199        }
200    }
201
202    /// Try to match a string pattern by consuming characters from the stream.
203    ///
204    /// # Behavior
205    /// Matches the pattern character-by-character against the token stream.
206    /// Tokens are converted to their string representation and matched from char_offset.
207    /// If all characters match, the stream is advanced and returns true.
208    /// If any character fails to match, the stream is reset and returns false.
209    ///
210    /// # Returns
211    /// - `true` if the entire pattern was successfully matched (chars consumed)
212    /// - `false` if matching failed at any point (stream position restored)
213    pub fn try_match_string(&mut self, pattern: &str) -> bool {
214        let start_pos = self.position();
215
216        // Tokenize the pattern to get expected tokens
217        let expected_tokens = match tokenize(pattern) {
218            Ok(tokens) => tokens,
219            Err(_) => {
220                // If pattern can't be tokenized, it can't match
221                return false;
222            }
223        };
224
225        // Try to match each expected token
226        for (expected_token, _) in expected_tokens {
227            match self.peek() {
228                Ok((actual_token, _)) => {
229                    // Check if we can do a partial match for Identifier tokens
230                    // This handles cases like matching ".b3210" as ".b" + "3" + "2" + "1" + "0"
231                    if let (PtxToken::Identifier(actual_id), expected_str) = (actual_token, expected_token.as_str()) {
232                        // Check if the expected string matches from the current char_offset
233                        let remaining = &actual_id[self.char_offset..];
234                        if remaining.starts_with(expected_str) {
235                            let new_offset = self.char_offset + expected_str.len();
236                            if new_offset == actual_id.len() {
237                                // Exactly consumed the entire identifier - advance to next token
238                                self.index += 1;
239                                self.char_offset = 0;
240                            } else {
241                                // Partial match! Advance char_offset but DON'T advance index
242                                self.char_offset = new_offset;
243                            }
244                            continue;
245                        }
246                    }
247                    
248                    // Normal exact match
249                    if actual_token != &expected_token {
250                        self.set_position(start_pos);
251                        return false;
252                    }
253                    // Token matches, consume it
254                    self.index += 1;
255                    self.char_offset = 0;
256                }
257                Err(_) => {
258                    // Unexpected EOF
259                    self.set_position(start_pos);
260                    return false;
261                }
262            }
263        }
264
265        // Successfully matched all tokens
266        true
267    }
268
269    /// Check if the next token matches a specific pattern.
270    pub fn check<F>(&self, predicate: F) -> bool
271    where
272        F: FnOnce(&PtxToken) -> bool,
273    {
274        self.tokens
275            .get(self.index)
276            .map_or(false, |(token, _)| predicate(token))
277    }
278
279    /// Expect that we've consumed a complete token (not stopped in the middle).
280    /// This should be called at the end of each struct parser to verify that
281    /// character-level parsing has consumed all characters from the current token.
282    ///
283    /// # Returns
284    /// - `Ok(())` if `char_offset == 0` (no partial token consumption)
285    /// - `Err(PtxParseError)` if `char_offset > 0` (stopped in middle of token)
286    pub fn expect_complete(&self) -> Result<(), PtxParseError> {
287        if self.char_offset > 0 {
288            // We're in the middle of a token - this is an error
289            let span = self.peek().map(|(_, s)| s.clone()).unwrap_or(Span { start: 0, end: 0 });
290            Err(unexpected_value(
291                span,
292                &["complete token"],
293                format!("partial token at char offset {}", self.char_offset)
294            ))
295        } else {
296            Ok(())
297        }
298    }
299
300    /// Consume the next token if it matches the predicate.
301    pub fn consume_if<F>(&mut self, predicate: F) -> Option<&'a (PtxToken, Span)>
302    where
303        F: FnOnce(&PtxToken) -> bool,
304    {
305        if self.check(predicate) {
306            self.index += 1;
307            self.tokens.get(self.index - 1)
308        } else {
309            None
310        }
311    }
312
313    /// Get the current position in the stream, for backtracking.
314    pub fn position(&self) -> StreamPosition {
315        StreamPosition {
316            index: self.index,
317            char_offset: self.char_offset,
318        }
319    }
320
321    /// Reset the stream to an old position, for backtracking.
322    pub fn set_position(&mut self, pos: StreamPosition) {
323        self.index = pos.index;
324        self.char_offset = pos.char_offset;
325    }
326
327    /// Check if we've reached the end of the token stream.
328    pub fn is_at_end(&self) -> bool {
329        self.index >= self.tokens.len()
330    }
331
332    /// Get the remaining tokens.
333    pub fn remaining(&self) -> &'a [(PtxToken, Span)] {
334        &self.tokens[self.index..]
335    }
336
337    /// Peek at the character at the current char_offset within the current token's string.
338    /// Returns None if we're at the end of the current token's string or if the token has no string content.
339    pub fn peek_char_in_token(&self) -> Option<char> {
340        if self.index >= self.tokens.len() {
341            return None;
342        }
343        
344        let (token, _) = &self.tokens[self.index];
345        let string = match token {
346            PtxToken::Identifier(s) |
347            PtxToken::DecimalInteger(s) |
348            PtxToken::HexInteger(s) |
349            PtxToken::BinaryInteger(s) |
350            PtxToken::OctalInteger(s) => s,
351            _ => return None,
352        };
353        
354        string.chars().nth(self.char_offset)
355    }
356
357    /// Consume one character from the current token by advancing char_offset.
358    /// If we reach the end of the token's string, advance to the next token and reset char_offset.
359    /// Returns the consumed character.
360    pub fn consume_char_in_token(&mut self) -> Option<char> {
361        let ch = self.peek_char_in_token()?;
362        self.char_offset += 1;
363        
364        // Check if we've consumed the entire string of this token
365        if self.index < self.tokens.len() {
366            let (token, _) = &self.tokens[self.index];
367            let string = match token {
368                PtxToken::Identifier(s) |
369                PtxToken::DecimalInteger(s) |
370                PtxToken::HexInteger(s) |
371                PtxToken::BinaryInteger(s) |
372                PtxToken::OctalInteger(s) => s,
373                _ => "",
374            };
375            
376            if self.char_offset >= string.len() {
377                // Move to next token and reset char_offset
378                self.index += 1;
379                self.char_offset = 0;
380            }
381        }
382        
383        Some(ch)
384    }
385
386    /// Match a specific character at the current position within the token.
387    /// Consumes the character if it matches.
388    pub fn expect_char_in_token(&mut self, expected: char) -> Result<char, PtxParseError> {
389        match self.peek_char_in_token() {
390            Some(ch) if ch == expected => {
391                self.consume_char_in_token();
392                Ok(ch)
393            }
394            Some(ch) => {
395                let span = if self.index < self.tokens.len() {
396                    self.tokens[self.index].1.clone()
397                } else {
398                    0..0
399                };
400                Err(PtxParseError {
401                    kind: ParseErrorKind::UnexpectedToken {
402                        expected: vec![format!("'{}'", expected)],
403                        found: format!("'{}'", ch),
404                    },
405                    span,
406                })
407            }
408            None => {
409                let span = if self.index < self.tokens.len() {
410                    self.tokens[self.index].1.clone()
411                } else {
412                    0..0
413                };
414                Err(PtxParseError {
415                    kind: ParseErrorKind::UnexpectedEof,
416                    span,
417                })
418            }
419        }
420    }
421}
422
423pub trait PtxParser
424where
425    Self: Sized,
426{
427    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError>;
428}
429
430pub fn parse_ptx(source: &str) -> Result<crate::r#type::module::Module, PtxParseError> {
431    let tokens = crate::lexer::tokenize(source).map_err(|err| PtxParseError {
432        kind: ParseErrorKind::InvalidLiteral("lexical error".into()),
433        span: err.span,
434    })?;
435    let mut stream = PtxTokenStream::new(&tokens);
436    let module = crate::r#type::module::Module::parse(&mut stream)?;
437    if !stream.is_at_end() {
438        let (token, span) = stream.peek()?;
439        return Err(unexpected_value(
440            span.clone(),
441            &["end of input"],
442            format!("{token:?}"),
443        ));
444    }
445    Ok(module)
446}
447
448pub fn unexpected_value(span: Span, expected: &[&str], found: impl Into<String>) -> PtxParseError {
449    PtxParseError {
450        kind: ParseErrorKind::UnexpectedToken {
451            expected: expected.iter().map(|s| s.to_string()).collect(),
452            found: found.into(),
453        },
454        span,
455    }
456}
457
458pub(crate) fn invalid_literal(span: Span, message: impl Into<String>) -> PtxParseError {
459    PtxParseError {
460        kind: ParseErrorKind::InvalidLiteral(message.into()),
461        span,
462    }
463}
464
465pub(crate) fn expect_directive_value(
466    stream: &mut PtxTokenStream,
467    expected: &str,
468) -> Result<(), PtxParseError> {
469    let (value, span) = stream.expect_directive()?;
470    if value == expected {
471        Ok(())
472    } else {
473        Err(unexpected_value(
474            span,
475            &[&format!(".{expected}")],
476            format!(".{value}"),
477        ))
478    }
479}
480
481pub(crate) fn peek_directive(
482    stream: &mut PtxTokenStream,
483) -> Result<Option<(String, Span)>, PtxParseError> {
484    // Check if we have Dot followed by Identifier
485    if let Some((PtxToken::Dot, dot_span)) = stream.tokens.get(stream.index) {
486        if let Some((PtxToken::Identifier(value), id_span)) = stream.tokens.get(stream.index + 1) {
487            let span = dot_span.start..id_span.end;
488            return Ok(Some((value.clone(), span)));
489        }
490    }
491    Ok(None)
492}