Skip to main content

ion_core/
lexer.rs

1use crate::error::IonError;
2use crate::token::{SpannedToken, Token};
3
/// Byte-oriented lexer over a borrowed source string.
///
/// Holds a cursor into the source bytes together with the line/column position that is
/// stamped onto every token and lexing error it produces.
pub struct Lexer<'a> {
    /// Raw bytes of the source text being tokenized.
    source: &'a [u8],
    /// Byte offset of the next unread byte in `source`.
    pos: usize,
    /// Current line number; starts at 1 and grows by one for each `\n` consumed.
    line: usize,
    /// Current column; starts at 1, grows by one per byte consumed, resets to 1 after `\n`.
    col: usize,
}
10
11impl<'a> Lexer<'a> {
12    pub fn new(source: &'a str) -> Self {
13        Self {
14            source: source.as_bytes(),
15            pos: 0,
16            line: 1,
17            col: 1,
18        }
19    }
20
21    pub fn tokenize(&mut self) -> Result<Vec<SpannedToken>, IonError> {
22        let mut tokens = Vec::new();
23        loop {
24            let tok = self.next_token()?;
25            let is_eof = tok.token == Token::Eof;
26            tokens.push(tok);
27            if is_eof {
28                break;
29            }
30        }
31        Ok(tokens)
32    }
33
34    fn peek(&self) -> u8 {
35        if self.pos < self.source.len() {
36            self.source[self.pos]
37        } else {
38            0
39        }
40    }
41
42    fn peek_at(&self, offset: usize) -> u8 {
43        let idx = self.pos + offset;
44        if idx < self.source.len() {
45            self.source[idx]
46        } else {
47            0
48        }
49    }
50
51    fn advance(&mut self) -> u8 {
52        let ch = self.peek();
53        if ch == b'\n' {
54            self.line += 1;
55            self.col = 1;
56        } else {
57            self.col += 1;
58        }
59        self.pos += 1;
60        ch
61    }
62
63    /// Consume a full UTF-8 character from the source and push it to the string.
64    fn push_utf8_char(&mut self, s: &mut String) {
65        let b = self.advance();
66        let width = if b < 0x80 {
67            s.push(b as char);
68            return;
69        } else if b < 0xE0 {
70            2
71        } else if b < 0xF0 {
72            3
73        } else {
74            4
75        };
76        let mut bytes = vec![b];
77        for _ in 1..width {
78            if self.pos < self.source.len() {
79                bytes.push(self.advance());
80            }
81        }
82        if let Ok(ch) = std::str::from_utf8(&bytes) {
83            s.push_str(ch);
84        } else {
85            // Invalid UTF-8: push replacement character
86            s.push(char::REPLACEMENT_CHARACTER);
87        }
88    }
89
90    /// Parse `\u{XXXX}` Unicode escape (the `u` has already been consumed).
91    fn parse_unicode_escape(&mut self, s: &mut String) -> Result<(), IonError> {
92        if self.pos >= self.source.len() || self.peek() != b'{' {
93            return Err(IonError::lex(
94                ion_str!("expected '{' after \\u"),
95                self.line,
96                self.col,
97            ));
98        }
99        self.advance(); // consume '{'
100        let mut hex = String::new();
101        while self.pos < self.source.len() && self.peek() != b'}' {
102            hex.push(self.advance() as char);
103            if hex.len() > 6 {
104                return Err(IonError::lex(
105                    ion_str!("unicode escape too long (max 6 hex digits)"),
106                    self.line,
107                    self.col,
108                ));
109            }
110        }
111        if self.pos >= self.source.len() {
112            return Err(IonError::lex(
113                ion_str!("unterminated unicode escape"),
114                self.line,
115                self.col,
116            ));
117        }
118        self.advance(); // consume '}'
119        if hex.is_empty() {
120            return Err(IonError::lex(
121                ion_str!("empty unicode escape"),
122                self.line,
123                self.col,
124            ));
125        }
126        let code_point = u32::from_str_radix(&hex, 16)
127            .map_err(|_| IonError::lex(ion_str!("invalid unicode escape"), self.line, self.col))?;
128        let ch = char::from_u32(code_point).ok_or_else(|| {
129            IonError::lex(ion_str!("invalid unicode code point"), self.line, self.col)
130        })?;
131        s.push(ch);
132        Ok(())
133    }
134
135    fn skip_whitespace_and_comments(&mut self) {
136        loop {
137            while self.pos < self.source.len() && self.peek().is_ascii_whitespace() {
138                self.advance();
139            }
140            if self.peek() == b'/' && self.peek_at(1) == b'/' {
141                while self.pos < self.source.len() && self.peek() != b'\n' {
142                    self.advance();
143                }
144            } else {
145                break;
146            }
147        }
148    }
149
150    fn spanned(&self, token: Token, line: usize, col: usize) -> SpannedToken {
151        SpannedToken { token, line, col }
152    }
153
154    fn next_token(&mut self) -> Result<SpannedToken, IonError> {
155        self.skip_whitespace_and_comments();
156
157        let line = self.line;
158        let col = self.col;
159
160        if self.pos >= self.source.len() {
161            return Ok(self.spanned(Token::Eof, line, col));
162        }
163
164        let ch = self.peek();
165
166        // Numbers
167        if ch.is_ascii_digit() {
168            return self.lex_number(line, col);
169        }
170
171        // Triple-quoted strings (must check before single-quoted)
172        // ch == peek() since pos hasn't advanced yet, so check peek_at(1) and peek_at(2)
173        if ch == b'"' && self.peek_at(1) == b'"' && self.peek_at(2) == b'"' {
174            return self.lex_triple_string(line, col, false);
175        }
176
177        // Triple-quoted f-strings
178        if ch == b'f'
179            && self.peek_at(1) == b'"'
180            && self.peek_at(2) == b'"'
181            && self.peek_at(3) == b'"'
182        {
183            self.advance(); // consume 'f'
184            return self.lex_triple_string(line, col, true);
185        }
186
187        // Strings
188        if ch == b'"' {
189            return self.lex_string(line, col, false);
190        }
191
192        // f-strings
193        if ch == b'f' && self.peek_at(1) == b'"' {
194            self.advance(); // consume 'f'
195            return self.lex_string(line, col, true);
196        }
197
198        // byte strings
199        if ch == b'b' && self.peek_at(1) == b'"' {
200            self.advance(); // consume 'b'
201            return self.lex_bytes(line, col);
202        }
203
204        // Identifiers and keywords
205        if ch.is_ascii_alphabetic() || ch == b'_' {
206            return self.lex_ident(line, col);
207        }
208
209        // Loop labels: 'name
210        if ch == b'\'' {
211            return self.lex_label(line, col);
212        }
213
214        // Operators and punctuation
215        self.advance();
216        match ch {
217            b'(' => Ok(self.spanned(Token::LParen, line, col)),
218            b')' => Ok(self.spanned(Token::RParen, line, col)),
219            b'{' => Ok(self.spanned(Token::LBrace, line, col)),
220            b'}' => Ok(self.spanned(Token::RBrace, line, col)),
221            b'[' => Ok(self.spanned(Token::LBracket, line, col)),
222            b']' => Ok(self.spanned(Token::RBracket, line, col)),
223            b',' => Ok(self.spanned(Token::Comma, line, col)),
224            b';' => Ok(self.spanned(Token::Semicolon, line, col)),
225            b'?' => Ok(self.spanned(Token::Question, line, col)),
226            b'#' => {
227                if self.peek() == b'{' {
228                    self.advance();
229                    Ok(self.spanned(Token::HashBrace, line, col))
230                } else {
231                    Err(IonError::lex(
232                        format!("{}{}", ion_str!("unexpected character: "), '#'),
233                        line,
234                        col,
235                    ))
236                }
237            }
238            b'+' => {
239                if self.peek() == b'=' {
240                    self.advance();
241                    Ok(self.spanned(Token::PlusEq, line, col))
242                } else {
243                    Ok(self.spanned(Token::Plus, line, col))
244                }
245            }
246            b'-' => {
247                if self.peek() == b'=' {
248                    self.advance();
249                    Ok(self.spanned(Token::MinusEq, line, col))
250                } else {
251                    Ok(self.spanned(Token::Minus, line, col))
252                }
253            }
254            b'*' => {
255                if self.peek() == b'=' {
256                    self.advance();
257                    Ok(self.spanned(Token::StarEq, line, col))
258                } else {
259                    Ok(self.spanned(Token::Star, line, col))
260                }
261            }
262            b'/' => {
263                if self.peek() == b'=' {
264                    self.advance();
265                    Ok(self.spanned(Token::SlashEq, line, col))
266                } else {
267                    Ok(self.spanned(Token::Slash, line, col))
268                }
269            }
270            b'%' => Ok(self.spanned(Token::Percent, line, col)),
271            b'=' => {
272                if self.peek() == b'=' {
273                    self.advance();
274                    Ok(self.spanned(Token::EqEq, line, col))
275                } else if self.peek() == b'>' {
276                    self.advance();
277                    Ok(self.spanned(Token::Arrow, line, col))
278                } else {
279                    Ok(self.spanned(Token::Eq, line, col))
280                }
281            }
282            b'!' => {
283                if self.peek() == b'=' {
284                    self.advance();
285                    Ok(self.spanned(Token::BangEq, line, col))
286                } else {
287                    Ok(self.spanned(Token::Bang, line, col))
288                }
289            }
290            b'<' => {
291                if self.peek() == b'=' {
292                    self.advance();
293                    Ok(self.spanned(Token::LtEq, line, col))
294                } else if self.peek() == b'<' {
295                    self.advance();
296                    Ok(self.spanned(Token::Shl, line, col))
297                } else {
298                    Ok(self.spanned(Token::Lt, line, col))
299                }
300            }
301            b'>' => {
302                if self.peek() == b'=' {
303                    self.advance();
304                    Ok(self.spanned(Token::GtEq, line, col))
305                } else if self.peek() == b'>' {
306                    self.advance();
307                    Ok(self.spanned(Token::Shr, line, col))
308                } else {
309                    Ok(self.spanned(Token::Gt, line, col))
310                }
311            }
312            b'&' => {
313                if self.peek() == b'&' {
314                    self.advance();
315                    Ok(self.spanned(Token::And, line, col))
316                } else {
317                    Ok(self.spanned(Token::Ampersand, line, col))
318                }
319            }
320            b'^' => Ok(self.spanned(Token::Caret, line, col)),
321            b'|' => {
322                if self.peek() == b'|' {
323                    self.advance();
324                    Ok(self.spanned(Token::Or, line, col))
325                } else if self.peek() == b'>' {
326                    self.advance();
327                    Ok(self.spanned(Token::Pipe, line, col))
328                } else {
329                    Ok(self.spanned(Token::PipeSym, line, col))
330                }
331            }
332            b'.' => {
333                if self.peek() == b'.' {
334                    self.advance();
335                    if self.peek() == b'.' {
336                        self.advance();
337                        Ok(self.spanned(Token::DotDotDot, line, col))
338                    } else if self.peek() == b'=' {
339                        self.advance();
340                        Ok(self.spanned(Token::DotDotEq, line, col))
341                    } else {
342                        Ok(self.spanned(Token::DotDot, line, col))
343                    }
344                } else {
345                    Ok(self.spanned(Token::Dot, line, col))
346                }
347            }
348            b':' => {
349                if self.peek() == b':' {
350                    self.advance();
351                    Ok(self.spanned(Token::ColonColon, line, col))
352                } else {
353                    Ok(self.spanned(Token::Colon, line, col))
354                }
355            }
356            _ => Err(IonError::lex(
357                format!("{}{}", ion_str!("unexpected character: "), ch as char),
358                line,
359                col,
360            )),
361        }
362    }
363
364    fn lex_number(&mut self, line: usize, col: usize) -> Result<SpannedToken, IonError> {
365        let start = self.pos;
366        let mut is_float = false;
367
368        while self.peek().is_ascii_digit() || self.peek() == b'_' {
369            self.advance();
370        }
371        if self.peek() == b'.' && self.peek_at(1) != b'.' {
372            is_float = true;
373            self.advance();
374            while self.peek().is_ascii_digit() || self.peek() == b'_' {
375                self.advance();
376            }
377        }
378
379        let text: String = self.source[start..self.pos]
380            .iter()
381            .filter(|&&b| b != b'_')
382            .map(|&b| b as char)
383            .collect();
384
385        if is_float {
386            let val: f64 = text
387                .parse()
388                .map_err(|_| IonError::lex(ion_str!("invalid float literal"), line, col))?;
389            Ok(self.spanned(Token::Float(val), line, col))
390        } else {
391            let val: i64 = text
392                .parse()
393                .map_err(|_| IonError::lex(ion_str!("invalid integer literal"), line, col))?;
394            Ok(self.spanned(Token::Int(val), line, col))
395        }
396    }
397
398    fn lex_string(
399        &mut self,
400        line: usize,
401        col: usize,
402        is_fstr: bool,
403    ) -> Result<SpannedToken, IonError> {
404        self.advance(); // consume opening "
405        let mut s = String::new();
406        let mut interp_depth = 0u32; // track {} nesting in f-string interpolation
407
408        while self.pos < self.source.len() {
409            let ch = self.peek();
410
411            // In f-string interpolation mode, track braces and allow quotes
412            if is_fstr && interp_depth > 0 {
413                if ch == b'"' {
414                    // Skip over nested string inside interpolation
415                    self.advance();
416                    s.push('"');
417                    while self.pos < self.source.len() && self.peek() != b'"' {
418                        if self.peek() == b'\\' {
419                            self.advance();
420                            s.push('\\');
421                            if self.pos < self.source.len() {
422                                self.push_utf8_char(&mut s);
423                            }
424                        } else {
425                            self.push_utf8_char(&mut s);
426                        }
427                    }
428                    if self.pos < self.source.len() {
429                        self.advance(); // closing "
430                        s.push('"');
431                    }
432                    continue;
433                } else if ch == b'{' {
434                    interp_depth += 1;
435                    self.advance();
436                    s.push('{');
437                    continue;
438                } else if ch == b'}' {
439                    interp_depth -= 1;
440                    self.advance();
441                    s.push('}');
442                    continue;
443                } else {
444                    self.push_utf8_char(&mut s);
445                    continue;
446                }
447            }
448
449            // Outside interpolation: " terminates the string
450            if ch == b'"' {
451                break;
452            }
453
454            if ch == b'\\' {
455                self.advance();
456                match self.peek() {
457                    b'n' => {
458                        self.advance();
459                        s.push('\n');
460                    }
461                    b't' => {
462                        self.advance();
463                        s.push('\t');
464                    }
465                    b'r' => {
466                        self.advance();
467                        s.push('\r');
468                    }
469                    b'\\' => {
470                        self.advance();
471                        s.push('\\');
472                    }
473                    b'"' => {
474                        self.advance();
475                        s.push('"');
476                    }
477                    b'{' => {
478                        self.advance();
479                        s.push('{');
480                    }
481                    b'}' => {
482                        self.advance();
483                        s.push('}');
484                    }
485                    b'u' => {
486                        self.advance(); // consume 'u'
487                        self.parse_unicode_escape(&mut s)?;
488                    }
489                    _ => {
490                        return Err(IonError::lex(
491                            ion_str!("invalid escape sequence"),
492                            self.line,
493                            self.col,
494                        ));
495                    }
496                }
497            } else if is_fstr && ch == b'{' {
498                // Enter interpolation mode
499                interp_depth = 1;
500                self.advance();
501                s.push('{');
502            } else {
503                self.push_utf8_char(&mut s);
504            }
505        }
506
507        if self.pos >= self.source.len() {
508            return Err(IonError::lex(ion_str!("unterminated string"), line, col));
509        }
510        self.advance(); // consume closing "
511
512        if is_fstr {
513            Ok(self.spanned(Token::FStr(s), line, col))
514        } else {
515            Ok(self.spanned(Token::Str(s), line, col))
516        }
517    }
518
519    fn lex_triple_string(
520        &mut self,
521        line: usize,
522        col: usize,
523        is_fstr: bool,
524    ) -> Result<SpannedToken, IonError> {
525        // consume opening """
526        self.advance(); // first "
527        self.advance(); // second "
528        self.advance(); // third "
529                        // Skip a leading newline immediately after """
530        if self.pos < self.source.len() && self.peek() == b'\n' {
531            self.advance();
532        }
533        let mut s = String::new();
534
535        while self.pos < self.source.len() {
536            if self.peek() == b'"'
537                && self.pos + 1 < self.source.len()
538                && self.source[self.pos + 1] == b'"'
539                && self.pos + 2 < self.source.len()
540                && self.source[self.pos + 2] == b'"'
541            {
542                // consume closing """
543                self.advance();
544                self.advance();
545                self.advance();
546                if is_fstr {
547                    return Ok(self.spanned(Token::FStr(s), line, col));
548                } else {
549                    return Ok(self.spanned(Token::Str(s), line, col));
550                }
551            }
552            let ch = self.peek();
553            if ch == b'\n' {
554                self.advance();
555                s.push('\n');
556            } else if ch == b'\\' {
557                self.advance();
558                match self.peek() {
559                    b'n' => {
560                        self.advance();
561                        s.push('\n');
562                    }
563                    b't' => {
564                        self.advance();
565                        s.push('\t');
566                    }
567                    b'r' => {
568                        self.advance();
569                        s.push('\r');
570                    }
571                    b'\\' => {
572                        self.advance();
573                        s.push('\\');
574                    }
575                    b'"' => {
576                        self.advance();
577                        s.push('"');
578                    }
579                    b'{' => {
580                        self.advance();
581                        s.push('{');
582                    }
583                    b'}' => {
584                        self.advance();
585                        s.push('}');
586                    }
587                    b'u' => {
588                        self.advance(); // consume 'u'
589                        self.parse_unicode_escape(&mut s)?;
590                    }
591                    _ => {
592                        return Err(IonError::lex(
593                            ion_str!("invalid escape sequence"),
594                            self.line,
595                            self.col,
596                        ));
597                    }
598                }
599            } else {
600                self.push_utf8_char(&mut s);
601            }
602        }
603
604        Err(IonError::lex(
605            ion_str!("unterminated triple-quoted string"),
606            line,
607            col,
608        ))
609    }
610
611    fn lex_bytes(&mut self, line: usize, col: usize) -> Result<SpannedToken, IonError> {
612        self.advance(); // consume opening "
613        let mut bytes = Vec::new();
614
615        while self.pos < self.source.len() && self.peek() != b'"' {
616            let ch = self.peek();
617            if ch == b'\\' {
618                self.advance();
619                match self.peek() {
620                    b'n' => {
621                        self.advance();
622                        bytes.push(b'\n');
623                    }
624                    b't' => {
625                        self.advance();
626                        bytes.push(b'\t');
627                    }
628                    b'r' => {
629                        self.advance();
630                        bytes.push(b'\r');
631                    }
632                    b'\\' => {
633                        self.advance();
634                        bytes.push(b'\\');
635                    }
636                    b'"' => {
637                        self.advance();
638                        bytes.push(b'"');
639                    }
640                    b'0' => {
641                        self.advance();
642                        bytes.push(0);
643                    }
644                    b'x' => {
645                        self.advance(); // consume 'x'
646                        let hi = self.advance();
647                        let lo = self.advance();
648                        let val = hex_digit(hi).ok_or_else(|| {
649                            IonError::lex(
650                                ion_str!("invalid hex escape in byte string"),
651                                self.line,
652                                self.col,
653                            )
654                        })? << 4
655                            | hex_digit(lo).ok_or_else(|| {
656                                IonError::lex(
657                                    ion_str!("invalid hex escape in byte string"),
658                                    self.line,
659                                    self.col,
660                                )
661                            })?;
662                        bytes.push(val);
663                    }
664                    _ => {
665                        return Err(IonError::lex(
666                            ion_str!("invalid escape sequence in byte string"),
667                            self.line,
668                            self.col,
669                        ));
670                    }
671                }
672            } else {
673                self.advance();
674                bytes.push(ch);
675            }
676        }
677
678        if self.pos >= self.source.len() {
679            return Err(IonError::lex(
680                ion_str!("unterminated byte string"),
681                line,
682                col,
683            ));
684        }
685        self.advance(); // consume closing "
686
687        Ok(self.spanned(Token::Bytes(bytes), line, col))
688    }
689
690    fn lex_label(&mut self, line: usize, col: usize) -> Result<SpannedToken, IonError> {
691        self.advance(); // consume '
692        let start = self.pos;
693        if !(self.peek().is_ascii_alphabetic() || self.peek() == b'_') {
694            return Err(IonError::lex(
695                ion_str!("expected label name after '"),
696                line,
697                col,
698            ));
699        }
700        while self.peek().is_ascii_alphanumeric() || self.peek() == b'_' {
701            self.advance();
702        }
703        let text = std::str::from_utf8(&self.source[start..self.pos]).unwrap();
704        Ok(self.spanned(Token::Label(text.to_string()), line, col))
705    }
706
707    fn lex_ident(&mut self, line: usize, col: usize) -> Result<SpannedToken, IonError> {
708        let start = self.pos;
709        while self.peek().is_ascii_alphanumeric() || self.peek() == b'_' {
710            self.advance();
711        }
712        let text = std::str::from_utf8(&self.source[start..self.pos]).unwrap();
713        let token = match text {
714            "let" => Token::Let,
715            "mut" => Token::Mut,
716            "fn" => Token::Fn,
717            "match" => Token::Match,
718            "if" => Token::If,
719            "else" => Token::Else,
720            "for" => Token::For,
721            "while" => Token::While,
722            "loop" => Token::Loop,
723            "break" => Token::Break,
724            "continue" => Token::Continue,
725            "return" => Token::Return,
726            "in" => Token::In,
727            "as" => Token::As,
728            "true" => Token::True,
729            "false" => Token::False,
730            "None" => Token::None,
731            "Some" => Token::Some,
732            "Ok" => Token::Ok,
733            "Err" => Token::Err,
734            "async" => Token::Async,
735            "spawn" => Token::Spawn,
736            "await" => Token::Await,
737            "select" => Token::Select,
738            "try" => Token::Try,
739            "catch" => Token::Catch,
740            "use" => Token::Use,
741            _ => Token::Ident(text.to_string()),
742        };
743        Ok(self.spanned(token, line, col))
744    }
745}
746
/// Convert an ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`) to its numeric value,
/// or return `None` for any other byte.
fn hex_digit(ch: u8) -> Option<u8> {
    // A base-16 digit is at most 15, so narrowing back to `u8` is lossless.
    (ch as char).to_digit(16).map(|value| value as u8)
}
755
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenize `src`, panicking on any lexing error, and keep only the token kinds.
    fn lex(src: &str) -> Vec<Token> {
        let spanned = Lexer::new(src).tokenize().unwrap();
        spanned.into_iter().map(|t| t.token).collect()
    }

    #[test]
    fn test_basic_tokens() {
        let expected = vec![
            Token::Let,
            Token::Ident("x".into()),
            Token::Eq,
            Token::Int(42),
            Token::Semicolon,
            Token::Eof,
        ];
        assert_eq!(lex("let x = 42;"), expected);
    }

    #[test]
    fn test_string() {
        let first = lex(r#""hello world""#).remove(0);
        assert_eq!(first, Token::Str("hello world".into()));
    }

    #[test]
    fn test_fstring() {
        // The interpolation is kept verbatim inside the f-string token.
        let first = lex(r#"f"hi {name}""#).remove(0);
        assert_eq!(first, Token::FStr("hi {name}".into()));
    }

    #[test]
    fn test_hash_brace() {
        let first = lex("#{ }").remove(0);
        assert_eq!(first, Token::HashBrace);
    }

    #[test]
    fn test_operators() {
        let expected = vec![
            Token::Pipe,
            Token::DotDot,
            Token::DotDotDot,
            Token::Arrow,
            Token::ColonColon,
            Token::Question,
            Token::BangEq,
            Token::Eof,
        ];
        assert_eq!(lex("|> .. ... => :: ? !="), expected);
    }

    #[test]
    fn test_float() {
        let first = lex("3.5").remove(0);
        assert_eq!(first, Token::Float(3.5));
    }

    #[test]
    fn test_comments() {
        // The `//` comment disappears; lexing resumes on the next line.
        let expected = vec![
            Token::Let,
            Token::Ident("x".into()),
            Token::Eq,
            Token::Int(1),
            Token::Semicolon,
            Token::Let,
            Token::Ident("y".into()),
            Token::Eq,
            Token::Int(2),
            Token::Semicolon,
            Token::Eof,
        ];
        assert_eq!(lex("let x = 1; // comment\nlet y = 2;"), expected);
    }
}