roan_ast/lexer/mod.rs

use crate::{
    lexer::{
        identifier::Identifier,
        number::NumberLiteral,
        string::StringLiteral,
        token::{Token, TokenKind},
    },
    source::Source,
};
use anyhow::Result;
use roan_error::{error::RoanError::InvalidToken, position::Position, span::TextSpan};

mod identifier;
mod number;
mod string;
pub mod token;

/// The lexer is responsible for converting the source code into a list of tokens.
pub struct Lexer {
    pub source: Source,
    pub tokens: Vec<Token>,
    pub position: Position,
}

impl Lexer {
    /// Create a new lexer from a source string.
    ///
    /// # Arguments
    /// - `source` - An instance of `Source` containing the source code.
    ///
    /// # Example
    /// ```rust
    /// use roan_ast::{Lexer, TokenKind};
    /// use roan_ast::source::Source;
    /// let source = Source::from_string("let x = 10;".to_string());
    /// let mut lexer = Lexer::new(source);
    /// let tokens = lexer.lex(false).expect("Failed to lex source code");
    ///
    /// assert_eq!(tokens.first().unwrap().kind, TokenKind::Let);
    /// ```
    pub fn new(source: Source) -> Self {
        Self {
            source,
            tokens: vec![],
            position: Position::new(1, 0, 0),
        }
    }
}

impl Lexer {
    /// Lex the source code and return a list of tokens.
    ///
    /// During the lexing process, the lexer consumes the source code character by character
    /// and converts it into a list of tokens. Whitespace is always skipped; comment tokens
    /// are skipped unless `lex_comments` is `true`.
    ///
    /// # Arguments
    /// - `lex_comments` - Whether `TokenKind::Comment` tokens should be kept in the output.
    ///
    /// When EOF is reached, the lexer returns the list of tokens.
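    ///
    /// # Example
    /// Illustrative usage, mirroring the constructor example above:
    /// ```rust
    /// use roan_ast::{Lexer, TokenKind};
    /// use roan_ast::source::Source;
    /// let source = Source::from_string("// a comment\nlet x = 10;".to_string());
    /// let mut lexer = Lexer::new(source);
    /// // With `lex_comments = false`, the comment token is dropped from the output.
    /// let tokens = lexer.lex(false).expect("Failed to lex source code");
    /// assert_eq!(tokens.first().unwrap().kind, TokenKind::Let);
    /// ```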
    pub fn lex(&mut self, lex_comments: bool) -> Result<Vec<Token>> {
        loop {
            let token = self.next_token()?;

            if let Some(token) = token {
                if (token.kind == TokenKind::Comment && !lex_comments)
                    || token.kind == TokenKind::Whitespace
                {
                    continue;
                }

                if token.kind == TokenKind::EOF {
                    break;
                }

                self.tokens.push(token);
            } else {
                break;
            }
        }

        Ok(self.tokens.clone())
    }

    /// Check if the lexer has reached the end of the source code.
    pub fn is_eof(&self) -> bool {
        self.position.index >= self.source.len()
    }

    /// Get the current character in the source code.
    pub fn current(&mut self) -> Option<char> {
        self.source.chars().nth(self.position.index)
    }

    /// Consume the current character and move to the next one.
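    ///
    /// # Example
    /// Illustrative usage on a fresh lexer:
    /// ```rust
    /// use roan_ast::Lexer;
    /// use roan_ast::source::Source;
    /// let mut lexer = Lexer::new(Source::from_string("ab".to_string()));
    /// assert_eq!(lexer.consume(), Some('a'));
    /// assert_eq!(lexer.current(), Some('b'));
    /// ```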
    pub fn consume(&mut self) -> Option<char> {
        if self.position.index >= self.source.len() {
            return None;
        }
        let c = self.current();

        self.update_position(c?);

        c
    }

    /// Update the position of the lexer.
    ///
    /// The position is updated based on the current character.
    /// The position includes the line, column, and index of the character.
    ///
    /// If the character is a newline, the line is incremented and the column is reset to 0.
    fn update_position(&mut self, c: char) {
        if c == '\n' {
            self.position.line += 1;
            self.position.column = 0;
        } else {
            self.position.column += 1;
        }
        self.position.index += 1;
    }

    /// Check if the character is a valid identifier start character.
    pub fn is_identifier_start(&self, c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    /// Check if the character is a valid number start character.
    pub fn is_number_start(&self, c: char) -> bool {
        c.is_digit(10)
    }

    /// Peek at the next character in the source code.
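    ///
    /// Peeking does not advance the lexer.
    ///
    /// # Example
    /// Illustrative usage:
    /// ```rust
    /// use roan_ast::Lexer;
    /// use roan_ast::source::Source;
    /// let mut lexer = Lexer::new(Source::from_string("ab".to_string()));
    /// assert_eq!(lexer.peek(), Some('b'));
    /// assert_eq!(lexer.current(), Some('a'));
    /// ```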
    pub fn peek(&self) -> Option<char> {
        if self.position.index + 1 >= self.source.len() {
            None
        } else {
            self.source.chars().nth(self.position.index + 1)
        }
    }

    /// Get the next token in the source code.
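    ///
    /// Returns `Ok(None)` once the end of the source has been reached.
    ///
    /// # Example
    /// Illustrative usage for a single keyword token:
    /// ```rust
    /// use roan_ast::{Lexer, TokenKind};
    /// use roan_ast::source::Source;
    /// let mut lexer = Lexer::new(Source::from_string("let".to_string()));
    /// let token = lexer.next_token().expect("Failed to lex").unwrap();
    /// assert_eq!(token.kind, TokenKind::Let);
    /// ```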
    pub fn next_token(&mut self) -> Result<Option<Token>> {
        let start = self.position;
        let Some(c) = self.current() else {
            return Ok(None);
        };

        let kind = match c {
            _ if c.is_whitespace() => {
                while let Some(c) = self.current() {
                    if !c.is_whitespace() {
                        break;
                    }
                    self.consume();
                }
                TokenKind::Whitespace
            }

            _ if c == '"' => StringLiteral::lex_string(self)?,
            _ if c.is_ascii_digit() => NumberLiteral::lex_number(self, c)?,
            _ if c == '\'' => TokenKind::Char(self.parse_char()?),

            _ if Identifier::is_identifier_start(c) => Identifier::lex_identifier(self)?,

            _ => {
                let kind = match c {
                    '(' => TokenKind::LeftParen,
                    ')' => TokenKind::RightParen,
                    '{' => TokenKind::LeftBrace,
                    '}' => TokenKind::RightBrace,
                    '[' => TokenKind::LeftBracket,
                    ']' => TokenKind::RightBracket,
                    ',' => TokenKind::Comma,
                    '.' => self.lex_potential_triple(
                        '.',
                        TokenKind::Dot,
                        TokenKind::DoubleDot,
                        TokenKind::TripleDot,
                    ),
                    ':' => self.lex_potential_double(':', TokenKind::Colon, TokenKind::DoubleColon),
                    ';' => TokenKind::Semicolon,
                    '/' => {
                        if self.match_next('/') {
                            while let Some(c) = self.current() {
                                if c == '\n' {
                                    break;
                                }
                                self.consume();
                            }
                            TokenKind::Comment
                        } else {
                            self.lex_potential_double(
                                '=',
                                TokenKind::Slash,
                                TokenKind::DivideEquals,
                            )
                        }
                    }
                    '+' => {
                        if self.match_next('+') {
                            self.consume();
                            TokenKind::Increment
                        } else if self.match_next('=') {
                            self.consume();
                            TokenKind::PlusEquals
                        } else {
                            TokenKind::Plus
                        }
                    }
                    '-' => {
                        if self.match_next('-') {
                            self.consume();
                            TokenKind::Decrement
                        } else if self.match_next('=') {
                            self.consume();
                            TokenKind::MinusEquals
                        } else if self.match_next('>') {
                            self.consume();
                            TokenKind::Arrow
                        } else {
                            TokenKind::Minus
                        }
                    }
                    '*' => {
                        if self.match_next('*') {
                            self.consume();
                            TokenKind::DoubleAsterisk
                        } else if self.match_next('=') {
                            self.consume();
                            TokenKind::MultiplyEquals
                        } else {
                            TokenKind::Asterisk
                        }
                    }
                    '%' => TokenKind::Percent,
                    '^' => TokenKind::Caret,
                    '!' => self.lex_potential_double('=', TokenKind::Bang, TokenKind::BangEquals),
                    '=' => {
                        self.lex_potential_double('=', TokenKind::Equals, TokenKind::EqualsEquals)
                    }
                    '~' => TokenKind::Tilde,
                    '<' => {
                        if self.match_next('<') {
                            self.consume();
                            TokenKind::DoubleLessThan
                        } else {
                            self.lex_potential_double(
                                '=',
                                TokenKind::LessThan,
                                TokenKind::LessThanEquals,
                            )
                        }
                    }
                    '>' => {
                        if self.match_next('>') {
                            self.consume();
                            TokenKind::DoubleGreaterThan
                        } else {
                            self.lex_potential_double(
                                '=',
                                TokenKind::GreaterThan,
                                TokenKind::GreaterThanEquals,
                            )
                        }
                    }
                    '?' => TokenKind::QuestionMark,
                    '&' => self.lex_potential_double('&', TokenKind::Ampersand, TokenKind::And),
                    '|' => self.lex_potential_double('|', TokenKind::Pipe, TokenKind::Or),
                    _ => {
                        return Err(InvalidToken(
                            c.to_string(),
                            TextSpan::new(self.position, self.position, c.to_string()),
                        )
                        .into())
                    }
                };

                self.consume();

                kind
            }
        };

        let end_pos = self.position;
        let literal = self.source.get_between(start.index, end_pos.index);
        Ok(Some(Token::new(
            kind,
            TextSpan::new(start, end_pos, literal),
        )))
    }

    /// Consume a number.
    ///
    /// Can be either an integer or a float.
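    ///
    /// # Example
    /// Illustrative usage; the returned string holds the raw digits:
    /// ```rust
    /// use roan_ast::Lexer;
    /// use roan_ast::source::Source;
    /// let mut lexer = Lexer::new(Source::from_string("123;".to_string()));
    /// let (_kind, digits) = lexer.consume_number();
    /// assert_eq!(digits, "123");
    /// ```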
    pub fn consume_number(&mut self) -> (NumberType, String) {
        let mut number = String::new();
        let mut number_type = NumberType::Integer;

        while let Some(c) = self.current() {
            if c.is_digit(10) {
                number.push(c);
            } else if c == '.' {
                number.push(c);
                number_type = NumberType::Float;
            } else {
                break;
            }
            self.consume();
        }

        (number_type, number)
    }

    /// Parses a character literal. Returns an error if the literal is not a single character
    /// followed by a closing quote.
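    ///
    /// # Example
    /// Illustrative usage, calling `parse_char` directly on a fresh lexer:
    /// ```rust
    /// use roan_ast::Lexer;
    /// use roan_ast::source::Source;
    /// let mut lexer = Lexer::new(Source::from_string("'a'".to_string()));
    /// assert_eq!(lexer.parse_char().unwrap(), 'a');
    /// ```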
    pub fn parse_char(&mut self) -> Result<char> {
        self.consume(); // opening quote
        let c = self.consume();
        if c.is_some() && self.consume() == Some('\'') {
            Ok(c.unwrap())
        } else {
            // Avoid panicking when the literal is empty or the source ends early.
            let literal = c.map(|ch| ch.to_string()).unwrap_or_default();
            Err(InvalidToken(
                literal.clone(),
                TextSpan::new(self.position, self.position, literal),
            )
            .into())
        }
    }

    /// Lex a token that is either a single character or a two-character sequence.
    ///
    /// If the character after the current one equals `expected`, the lexer advances and
    /// `double_char` is returned; otherwise `one_char` is returned.
    pub fn lex_potential_double(
        &mut self,
        expected: char,
        one_char: TokenKind,
        double_char: TokenKind,
    ) -> TokenKind {
        if let Some(next) = self.peek() {
            if next == expected {
                self.consume();
                double_char
            } else {
                one_char
            }
        } else {
            one_char
        }
    }

    /// Like [`lex_potential_double`](Self::lex_potential_double), but also checks for a
    /// three-character sequence (e.g. `.` / `..` / `...`).
    pub fn lex_potential_triple(
        &mut self,
        expected: char,
        one_char: TokenKind,
        double_char: TokenKind,
        triple_char: TokenKind,
    ) -> TokenKind {
        match self.peek() {
            Some(next) if next == expected => {
                self.consume();
                match self.peek() {
                    Some(next) if next == expected => {
                        self.consume();
                        triple_char
                    }
                    _ => double_char,
                }
            }
            _ => one_char,
        }
    }

    /// Check if the next character matches the given character.
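    ///
    /// This only peeks; the character is not consumed.
    ///
    /// # Example
    /// Illustrative usage:
    /// ```rust
    /// use roan_ast::Lexer;
    /// use roan_ast::source::Source;
    /// let mut lexer = Lexer::new(Source::from_string("+=".to_string()));
    /// assert!(lexer.match_next('='));
    /// ```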
    pub fn match_next(&mut self, ch: char) -> bool {
        self.peek() == Some(ch)
    }
}

/// The type of number.
#[derive(Debug)]
pub enum NumberType {
    Integer,
    Float,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::source::Source;

    macro_rules! test_tokens {
        ($source:expr, $expected:expr) => {{
            let source = Source::from_string($source.to_string());
            let mut lexer = Lexer::new(source);
            let tokens = lexer.lex(false).expect("Lexing failed");
            let expected_kinds = $expected;
            let actual_kinds: Vec<TokenKind> = tokens.iter().map(|t| t.kind.clone()).collect();
            assert_eq!(
                actual_kinds, expected_kinds,
                "Source: {}\nExpected: {:?}\nActual: {:?}",
                $source, expected_kinds, actual_kinds
            );
        }};
    }

    #[test]
    fn test_lexer_tokens() {
        let test_cases = vec![
            // String Literal
            (
                r#""Hello, world!""#,
                vec![TokenKind::String("Hello, world!".to_string())],
            ),
            // Integer Literal
            ("123", vec![TokenKind::Integer(123)]),
            // Float Literal
            ("123.45", vec![TokenKind::Float(123.45)]),
            // Identifier
            ("foo", vec![TokenKind::Identifier]),
            // Boolean Literals
            (
                "true; false",
                vec![TokenKind::True, TokenKind::Semicolon, TokenKind::False],
            ),
            // Arrow
            ("->", vec![TokenKind::Arrow]),
            // Single Dot
            (
                "arr.len();",
                vec![
                    TokenKind::Identifier, // arr
                    TokenKind::Dot,        // .
                    TokenKind::Identifier, // len
                    TokenKind::LeftParen,  // (
                    TokenKind::RightParen, // )
                    TokenKind::Semicolon,  // ;
                ],
            ),
            // Double Dot
            ("..", vec![TokenKind::DoubleDot]),
            // Triple Dot
            ("...", vec![TokenKind::TripleDot]),
            // Double Colon
            ("::", vec![TokenKind::DoubleColon]),
            // Comment
            (
                "// This is a comment\nlet x = 10;",
                vec![
                    TokenKind::Let,
                    TokenKind::Identifier, // x
                    TokenKind::Equals,
                    TokenKind::Integer(10),
                    TokenKind::Semicolon,
                ],
            ),
            // Escape Sequences
            (
                r#""\n\r\t\\""#,
                vec![TokenKind::String("\n\r\t\\".to_string())],
            ),
            // Mixed Tokens
            (
                r#"let x = 42; if (x > 10) { return x; }"#,
                vec![
                    TokenKind::Let,
                    TokenKind::Identifier, // x
                    TokenKind::Equals,
                    TokenKind::Integer(42),
                    TokenKind::Semicolon,
                    TokenKind::If,
                    TokenKind::LeftParen,
                    TokenKind::Identifier, // x
                    TokenKind::GreaterThan,
                    TokenKind::Integer(10),
                    TokenKind::RightParen,
                    TokenKind::LeftBrace,
                    TokenKind::Return,
                    TokenKind::Identifier, // x
                    TokenKind::Semicolon,
                    TokenKind::RightBrace,
                ],
            ),
            // All Single-Character Punctuations
            (
                "(){},.;",
                vec![
                    TokenKind::LeftParen,
                    TokenKind::RightParen,
                    TokenKind::LeftBrace,
                    TokenKind::RightBrace,
                    TokenKind::Comma,
                    TokenKind::Dot,
                    TokenKind::Semicolon,
                ],
            ),
            // All Multi-Character Operators
            (
                "== != <= >= ++ -- += -= *= /= && || ::",
                vec![
                    TokenKind::EqualsEquals,
                    TokenKind::BangEquals,
                    TokenKind::LessThanEquals,
                    TokenKind::GreaterThanEquals,
                    TokenKind::Increment,
                    TokenKind::Decrement,
                    TokenKind::PlusEquals,
                    TokenKind::MinusEquals,
                    TokenKind::MultiplyEquals,
                    TokenKind::DivideEquals,
                    TokenKind::And,
                    TokenKind::Or,
                    TokenKind::DoubleColon,
                ],
            ),
            // Unicode Identifiers
            // (
            //     "变量 = 100;",
            //     vec![
            //         TokenKind::Identifier, // 变量
            //         TokenKind::Equals,
            //         TokenKind::Integer(100),
            //         TokenKind::Semicolon,
            //     ],
            // ),
            (
                "_privateVar = true;",
                vec![
                    TokenKind::Identifier, // _privateVar
                    TokenKind::Equals,
                    TokenKind::True,
                    TokenKind::Semicolon,
                ],
            ),
            // Number Edge Cases
            ("007", vec![TokenKind::Integer(7)]),
            ("123.", vec![TokenKind::Float(123.0)]),
            ("123.45", vec![TokenKind::Float(123.45)]),
            ("1", vec![TokenKind::Integer(1)]),
            ("123.45", vec![TokenKind::Float(123.45)]),
            ("0b1010", vec![TokenKind::Integer(10)]),
            ("0o755", vec![TokenKind::Integer(493)]),
            ("0xdeadbeef", vec![TokenKind::Integer(0xdeadbeef)]),
            // Complex Expressions
            (
                "fn add(a, b) -> a + b;",
                vec![
                    TokenKind::Fn,
                    TokenKind::Identifier, // add
                    TokenKind::LeftParen,
                    TokenKind::Identifier, // a
                    TokenKind::Comma,
                    TokenKind::Identifier, // b
                    TokenKind::RightParen,
                    TokenKind::Arrow,
                    TokenKind::Identifier, // a
                    TokenKind::Plus,
                    TokenKind::Identifier, // b
                    TokenKind::Semicolon,
                ],
            ),
            // Whitespace Variations
            (
                "   \n\t let    x   =   5   ;  ",
                vec![
                    TokenKind::Let,
                    TokenKind::Identifier, // x
                    TokenKind::Equals,
                    TokenKind::Integer(5),
                    TokenKind::Semicolon,
                ],
            ),
            (
                "let x = 10",
                vec![
                    TokenKind::Let,
                    TokenKind::Identifier, // x
                    TokenKind::Equals,
                    TokenKind::Integer(10),
                ],
            ),
            (
                "2 << 3",
                vec![
                    TokenKind::Integer(2),
                    TokenKind::DoubleLessThan,
                    TokenKind::Integer(3),
                ],
            ),
            (
                "2 >> 3",
                vec![
                    TokenKind::Integer(2),
                    TokenKind::DoubleGreaterThan,
                    TokenKind::Integer(3),
                ],
            ),
        ];

        for (source, expected) in test_cases {
            test_tokens!(source, expected);
        }
    }

    #[test]
    fn test_invalid_escape_sequence() {
        let source = Source::from_string(r#""\z""#.to_string());
        let mut lexer = Lexer::new(source);
        let result = lexer.lex(false);
        assert!(
            result.is_err(),
            "Expected an error for invalid escape sequence"
        );
    }
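
    #[test]
    fn test_comments_kept_when_requested() {
        // Illustrative check of the `lex_comments` flag: when it is true, `lex`
        // keeps `TokenKind::Comment` tokens instead of filtering them out.
        let source = Source::from_string("// note\nlet x = 1;".to_string());
        let mut lexer = Lexer::new(source);
        let tokens = lexer.lex(true).expect("Lexing failed");
        assert_eq!(tokens.first().unwrap().kind, TokenKind::Comment);
    }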

    #[test]
    fn test_invalid_token() {
        let source = Source::from_string(r#"@@"#.to_string());
        let mut lexer = Lexer::new(source);
        let result = lexer.lex(false);
        assert!(result.is_err(), "Expected an error for invalid tokens");
    }
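
    #[test]
    fn test_position_tracking_across_newlines() {
        // Illustrative check of how `consume` drives `update_position`: a newline
        // bumps the line counter and resets the column, while the index always advances.
        let source = Source::from_string("a\nb".to_string());
        let mut lexer = Lexer::new(source);
        assert_eq!(lexer.consume(), Some('a'));
        assert_eq!(lexer.consume(), Some('\n'));
        assert_eq!((lexer.position.line, lexer.position.column), (2, 0));
        assert_eq!(lexer.consume(), Some('b'));
        assert!(lexer.is_eof());
    }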
}