Skip to main content

trueno_ptx_debug/parser/lexer/
mod.rs

//! PTX Lexer - Tokenization of PTX source code
2
3use super::ast::SourceLocation;
4use super::error::ParseError;
5
/// Token kinds for PTX lexing.
///
/// The enum carries no payload; the matched source text lives in
/// [`Token::text`]. `Hash` is derived alongside `Eq` so kinds can be used as
/// `HashMap` / `HashSet` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TokenKind {
    /// End of file
    Eof,
    /// Any dot-prefixed word that has no dedicated variant below
    /// (`.version`, `.target`, `.address_size`, etc.)
    Directive,
    /// `.entry` keyword
    Entry,
    /// `.func` keyword
    Func,
    /// `.reg` declaration
    Reg,
    /// `.shared` declaration
    Shared,
    /// `.local` declaration
    Local,
    /// `.global` declaration
    Global,
    /// `.param` declaration
    Param,
    /// Identifier (register name, label reference, etc.)
    Identifier,
    /// Integer literal (decimal or `0x` hexadecimal)
    Integer,
    /// Float literal
    Float,
    /// Instruction (`ld`, `st`, `mov`, `add`, etc.)
    Instruction,
    /// Label definition (`name:`)
    Label,
    /// Comment (`//` or `/* */`)
    Comment,
    /// Opening brace `{`
    LBrace,
    /// Closing brace `}`
    RBrace,
    /// Opening parenthesis `(`
    LParen,
    /// Closing parenthesis `)`
    RParen,
    /// Opening bracket `[`
    LBracket,
    /// Closing bracket `]`
    RBracket,
    /// Comma `,`
    Comma,
    /// Semicolon `;`
    Semicolon,
    /// Colon `:`
    Colon,
    /// Unknown token (a character the lexer has no rule for)
    Unknown,
}
60
61/// A token in the PTX source
62#[derive(Debug, Clone)]
63pub struct Token {
64    /// The kind of token
65    pub kind: TokenKind,
66    /// The text of the token
67    pub text: String,
68    /// Source location
69    pub location: SourceLocation,
70}
71
72impl Default for Token {
73    fn default() -> Self {
74        Self {
75            kind: TokenKind::Eof,
76            text: String::new(),
77            location: SourceLocation::default(),
78        }
79    }
80}
81
/// PTX Lexer: a cursor over a borrowed source string that hands out tokens
/// one at a time.
pub struct Lexer<'a> {
    /// The complete text being tokenized
    source: &'a str,
    /// Byte offset into `source` of the next unread character
    pos: usize,
    /// Line of the next unread character, starting at 1
    line: usize,
    /// Column of the next unread character, starting at 1 and counted in characters
    column: usize,
}
89
90impl<'a> Lexer<'a> {
91    /// Create a new lexer for the given source
92    pub fn new(source: &'a str) -> Self {
93        Self {
94            source,
95            pos: 0,
96            line: 1,
97            column: 1,
98        }
99    }
100
101    fn peek(&self) -> Option<char> {
102        self.source[self.pos..].chars().next()
103    }
104
105    fn peek_at(&self, offset: usize) -> Option<char> {
106        self.source[self.pos..].chars().nth(offset)
107    }
108
109    fn advance(&mut self) -> Option<char> {
110        let c = self.peek()?;
111        self.pos += c.len_utf8();
112        if c == '\n' {
113            self.line += 1;
114            self.column = 1;
115        } else {
116            self.column += 1;
117        }
118        Some(c)
119    }
120
121    fn skip_whitespace(&mut self) {
122        while let Some(c) = self.peek() {
123            if c.is_whitespace() {
124                self.advance();
125            } else if c == '/' {
126                if !self.skip_comment() {
127                    break;
128                }
129            } else {
130                break;
131            }
132        }
133    }
134
135    /// Skip a comment starting at current position. Returns true if a comment was skipped.
136    fn skip_comment(&mut self) -> bool {
137        if self.peek_at(1) == Some('/') {
138            self.skip_line_comment();
139            true
140        } else if self.peek_at(1) == Some('*') {
141            self.skip_block_comment();
142            true
143        } else {
144            false
145        }
146    }
147
148    fn skip_line_comment(&mut self) {
149        while let Some(c) = self.peek() {
150            self.advance();
151            if c == '\n' {
152                break;
153            }
154        }
155    }
156
157    fn skip_block_comment(&mut self) {
158        self.advance(); // skip /
159        self.advance(); // skip *
160        while let Some(c) = self.peek() {
161            self.advance();
162            if c == '*' && self.peek() == Some('/') {
163                self.advance();
164                break;
165            }
166        }
167    }
168
169    /// Get the next token from the source
170    pub fn next_token(&mut self) -> Result<Token, ParseError> {
171        self.skip_whitespace();
172
173        let location = SourceLocation {
174            line: self.line,
175            column: self.column,
176            file: None,
177        };
178
179        let Some(c) = self.peek() else {
180            return Ok(Token {
181                kind: TokenKind::Eof,
182                text: String::new(),
183                location,
184            });
185        };
186
187        match c {
188            '{' => {
189                self.advance();
190                Ok(Token {
191                    kind: TokenKind::LBrace,
192                    text: "{".into(),
193                    location,
194                })
195            }
196            '}' => {
197                self.advance();
198                Ok(Token {
199                    kind: TokenKind::RBrace,
200                    text: "}".into(),
201                    location,
202                })
203            }
204            '(' => {
205                self.advance();
206                Ok(Token {
207                    kind: TokenKind::LParen,
208                    text: "(".into(),
209                    location,
210                })
211            }
212            ')' => {
213                self.advance();
214                Ok(Token {
215                    kind: TokenKind::RParen,
216                    text: ")".into(),
217                    location,
218                })
219            }
220            '[' => {
221                self.advance();
222                Ok(Token {
223                    kind: TokenKind::LBracket,
224                    text: "[".into(),
225                    location,
226                })
227            }
228            ']' => {
229                self.advance();
230                Ok(Token {
231                    kind: TokenKind::RBracket,
232                    text: "]".into(),
233                    location,
234                })
235            }
236            ',' => {
237                self.advance();
238                Ok(Token {
239                    kind: TokenKind::Comma,
240                    text: ",".into(),
241                    location,
242                })
243            }
244            ';' => {
245                self.advance();
246                Ok(Token {
247                    kind: TokenKind::Semicolon,
248                    text: ";".into(),
249                    location,
250                })
251            }
252            '.' => self.read_directive(location),
253            '%' => self.read_register(location),
254            '@' => self.read_predicate(location),
255            '0'..='9' | '-' => self.read_number(location),
256            _ if c.is_alphabetic() || c == '_' => self.read_identifier_or_instruction(location),
257            _ => {
258                self.advance();
259                Ok(Token {
260                    kind: TokenKind::Unknown,
261                    text: c.to_string(),
262                    location,
263                })
264            }
265        }
266    }
267
268    fn read_directive(&mut self, location: SourceLocation) -> Result<Token, ParseError> {
269        let start = self.pos;
270        self.advance(); // skip '.'
271
272        // Read directive name
273        while let Some(c) = self.peek() {
274            if c.is_alphanumeric() || c == '_' {
275                self.advance();
276            } else {
277                break;
278            }
279        }
280
281        let directive_name = &self.source[start..self.pos];
282
283        // For version, target, address_size - read the value too
284        let text = if directive_name.starts_with(".version")
285            || directive_name.starts_with(".target")
286            || directive_name.starts_with(".address_size")
287        {
288            self.skip_whitespace();
289            let value_start = self.pos;
290            while let Some(c) = self.peek() {
291                if c == '\n' || c == ';' || c == '{' || c == '(' {
292                    break;
293                }
294                self.advance();
295            }
296            format!(
297                "{} {}",
298                directive_name,
299                self.source[value_start..self.pos].trim()
300            )
301        } else {
302            directive_name.to_string()
303        };
304
305        let kind = self.classify_directive(&text);
306        Ok(Token {
307            kind,
308            text,
309            location,
310        })
311    }
312
313    fn classify_directive(&self, text: &str) -> TokenKind {
314        const DIRECTIVE_MAP: &[(&str, TokenKind)] = &[
315            (".entry", TokenKind::Entry),
316            (".func", TokenKind::Func),
317            (".reg", TokenKind::Reg),
318            (".shared", TokenKind::Shared),
319            (".local", TokenKind::Local),
320            (".global", TokenKind::Global),
321            (".param", TokenKind::Param),
322        ];
323
324        DIRECTIVE_MAP
325            .iter()
326            .find(|(prefix, _)| text.starts_with(prefix))
327            .map_or(TokenKind::Directive, |(_, kind)| kind.clone())
328    }
329
330    fn read_register(&mut self, location: SourceLocation) -> Result<Token, ParseError> {
331        let start = self.pos;
332        self.advance(); // skip '%'
333
334        while let Some(c) = self.peek() {
335            if c.is_alphanumeric() || c == '_' {
336                self.advance();
337            } else {
338                break;
339            }
340        }
341
342        Ok(Token {
343            kind: TokenKind::Identifier,
344            text: self.source[start..self.pos].to_string(),
345            location,
346        })
347    }
348
349    fn read_predicate(&mut self, location: SourceLocation) -> Result<Token, ParseError> {
350        let start = self.pos;
351        self.advance(); // skip '@'
352
353        // May have '!' for negation
354        if self.peek() == Some('!') {
355            self.advance();
356        }
357
358        // Read predicate register name
359        while let Some(c) = self.peek() {
360            if c.is_alphanumeric() || c == '_' || c == '%' {
361                self.advance();
362            } else {
363                break;
364            }
365        }
366
367        Ok(Token {
368            kind: TokenKind::Identifier,
369            text: self.source[start..self.pos].to_string(),
370            location,
371        })
372    }
373
374    fn read_number(&mut self, location: SourceLocation) -> Result<Token, ParseError> {
375        let start = self.pos;
376
377        if self.peek() == Some('-') {
378            self.advance();
379        }
380
381        if let Some(tok) = self.try_read_hex(start, &location) {
382            return Ok(tok);
383        }
384
385        self.advance_while(|c| c.is_ascii_digit());
386
387        let is_float = self.read_fractional_part() | self.read_exponent_part();
388
389        Ok(Token {
390            kind: if is_float {
391                TokenKind::Float
392            } else {
393                TokenKind::Integer
394            },
395            text: self.source[start..self.pos].to_string(),
396            location,
397        })
398    }
399
400    fn try_read_hex(&mut self, start: usize, location: &SourceLocation) -> Option<Token> {
401        if self.peek() != Some('0') {
402            return None;
403        }
404        self.advance();
405        if !matches!(self.peek(), Some('x' | 'X')) {
406            return None;
407        }
408        self.advance();
409        self.advance_while(|c| c.is_ascii_hexdigit());
410        Some(Token {
411            kind: TokenKind::Integer,
412            text: self.source[start..self.pos].to_string(),
413            location: location.clone(),
414        })
415    }
416
417    fn read_fractional_part(&mut self) -> bool {
418        if self.peek() != Some('.') {
419            return false;
420        }
421        self.advance();
422        self.advance_while(|c| c.is_ascii_digit());
423        true
424    }
425
426    fn read_exponent_part(&mut self) -> bool {
427        if !matches!(self.peek(), Some('e' | 'E')) {
428            return false;
429        }
430        self.advance();
431        if matches!(self.peek(), Some('+' | '-')) {
432            self.advance();
433        }
434        self.advance_while(|c| c.is_ascii_digit());
435        true
436    }
437
438    fn advance_while(&mut self, predicate: impl Fn(char) -> bool) {
439        while let Some(c) = self.peek() {
440            if predicate(c) {
441                self.advance();
442            } else {
443                break;
444            }
445        }
446    }
447
448    fn read_identifier_or_instruction(
449        &mut self,
450        location: SourceLocation,
451    ) -> Result<Token, ParseError> {
452        let start = self.pos;
453
454        while let Some(c) = self.peek() {
455            if c.is_alphanumeric() || c == '_' {
456                self.advance();
457            } else {
458                break;
459            }
460        }
461
462        let text = &self.source[start..self.pos];
463
464        // Check if it's a label (followed by :)
465        if self.peek() == Some(':') {
466            self.advance();
467            return Ok(Token {
468                kind: TokenKind::Label,
469                text: self.source[start..self.pos].to_string(),
470                location,
471            });
472        }
473
474        // Check if it's an instruction
475        if self.is_instruction(text) {
476            // For instructions, read modifiers and operands
477            let instr_end = self.pos;
478            self.skip_whitespace();
479
480            // Read the rest of the line for operands
481            let operand_start = self.pos;
482            while let Some(c) = self.peek() {
483                if c == '\n' || c == ';' || c == '{' || c == '}' {
484                    break;
485                }
486                self.advance();
487            }
488
489            let full_text = if operand_start < self.pos {
490                format!(
491                    "{} {}",
492                    &self.source[start..instr_end],
493                    self.source[operand_start..self.pos].trim()
494                )
495            } else {
496                self.source[start..instr_end].to_string()
497            };
498
499            return Ok(Token {
500                kind: TokenKind::Instruction,
501                text: full_text,
502                location,
503            });
504        }
505
506        Ok(Token {
507            kind: TokenKind::Identifier,
508            text: text.to_string(),
509            location,
510        })
511    }
512
513    fn is_instruction(&self, text: &str) -> bool {
514        matches!(
515            text,
516            "ld" | "st"
517                | "mov"
518                | "add"
519                | "sub"
520                | "mul"
521                | "div"
522                | "rem"
523                | "mad"
524                | "fma"
525                | "neg"
526                | "abs"
527                | "min"
528                | "max"
529                | "and"
530                | "or"
531                | "xor"
532                | "not"
533                | "shl"
534                | "shr"
535                | "setp"
536                | "selp"
537                | "cvt"
538                | "cvta"
539                | "bra"
540                | "call"
541                | "ret"
542                | "exit"
543                | "bar"
544                | "membar"
545                | "atom"
546                | "red"
547                | "tex"
548                | "tld4"
549                | "suld"
550                | "sust"
551                | "shfl"
552                | "vote"
553                | "match"
554                | "mma"
555                | "wmma"
556                | "ldmatrix"
557                | "cp"
558                | "prefetch"
559                | "prefetchu"
560        )
561    }
562}
563
564#[cfg(test)]
565mod tests;