hiko_syntax/lexer.rs

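//! Lexer for `hiko_syntax`: turns a source string into a flat `Vec<Token>`
//! terminated by a single `Eof` token.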
use crate::span::Span;
use crate::token::{Token, TokenKind};

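/// Error reported when lexing fails; `span` locates the offending input.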
#[derive(Debug, Clone)]
pub struct LexError {
    pub message: String,
    pub span: Span,
}

fn hex_digit(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}

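/// Byte-oriented lexer over a single source file; `file_id` is stamped into
/// every `Span` the lexer produces.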
pub struct Lexer<'src> {
    source: &'src str,
    bytes: &'src [u8],
    pos: usize,
    file_id: u32,
}

impl<'src> Lexer<'src> {
    pub fn new(source: &'src str, file_id: u32) -> Self {
        Self {
            source,
            bytes: source.as_bytes(),
            pos: 0,
            file_id,
        }
    }

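    /// Lex the entire input, returning every token plus a trailing `Eof`.
    /// A minimal usage sketch (mirroring the `lex` helper in the tests below;
    /// the `?` assumes a caller that can propagate `LexError`):
    ///
    /// ```ignore
    /// let mut lexer = Lexer::new("val x = 1", 0);
    /// let tokens = lexer.tokenize()?;
    /// ```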
    pub fn tokenize(&mut self) -> Result<Vec<Token>, LexError> {
        let mut tokens = Vec::new();
        loop {
            let tok = self.next_token()?;
            let is_eof = tok.kind == TokenKind::Eof;
            tokens.push(tok);
            if is_eof {
                break;
            }
        }
        Ok(tokens)
    }

    fn next_token(&mut self) -> Result<Token, LexError> {
        self.skip_whitespace_and_comments()?;

        if self.pos >= self.bytes.len() {
            return Ok(self.make_token(TokenKind::Eof, self.pos, self.pos));
        }

        let start = self.pos;
        let ch = self.bytes[start];

        // String literal
        if ch == b'"' {
            return self.lex_string();
        }

        // Char literal: #"c"
        if ch == b'#' && self.peek_at(1) == Some(b'"') {
            return self.lex_char();
        }

        // Number (int or float)
        if ch.is_ascii_digit() {
            return self.lex_number();
        }

        // Type variable: 'a, 'b, etc.
        if ch == b'\''
            && self
                .peek_at(1)
                .is_some_and(|c| c.is_ascii_alphabetic() || c == b'_')
        {
            return self.lex_tyvar();
        }

        // Identifier or keyword
        if ch.is_ascii_alphabetic() || ch == b'_' {
            return self.lex_ident();
        }

        // Operators and delimiters
        self.lex_operator_or_delimiter()
    }

    fn lex_string(&mut self) -> Result<Token, LexError> {
        let start = self.pos;
        self.pos += 1; // skip opening "
        let mut value = String::new();

        loop {
            if self.pos >= self.bytes.len() {
                return Err(self.err("unterminated string literal", start));
            }
            match self.bytes[self.pos] {
                b'"' => {
                    self.pos += 1;
                    return Ok(self.make_token(TokenKind::StringLit(value), start, self.pos));
                }
                b'\\' => {
                    value.push(self.parse_escape(start)?);
                }
                _ => {
                    let rest = &self.source[self.pos..];
                    if let Some(c) = rest.chars().next() {
                        value.push(c);
                        self.pos += c.len_utf8();
                    }
                }
            }
        }
    }

    fn lex_char(&mut self) -> Result<Token, LexError> {
        let start = self.pos;
        self.pos += 2; // skip #"

        if self.pos >= self.bytes.len() {
            return Err(self.err("unterminated character literal", start));
        }

        let c = if self.bytes[self.pos] == b'\\' {
            self.parse_escape(start)?
        } else {
            let rest = &self.source[self.pos..];
            let c = rest.chars().next().unwrap();
            self.pos += c.len_utf8();
            c
        };

        if self.pos >= self.bytes.len() || self.bytes[self.pos] != b'"' {
            return Err(self.err("unterminated character literal, expected closing \"", start));
        }
        self.pos += 1; // skip closing "

        Ok(self.make_token(TokenKind::CharLit(c), start, self.pos))
    }

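    /// Decode one escape sequence with the cursor on the backslash. Handles
    /// `\n`, `\t`, `\r`, `\0`, `\\`, `\"` and `\xHH` (two hex digits).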
    fn parse_escape(&mut self, literal_start: usize) -> Result<char, LexError> {
        self.pos += 1; // skip backslash
        if self.pos >= self.bytes.len() {
            return Err(self.err("unterminated escape sequence", literal_start));
        }
        let c = match self.bytes[self.pos] {
            b'n' => '\n',
            b't' => '\t',
            b'r' => '\r',
            b'0' => '\0',
            b'\\' => '\\',
            b'"' => '"',
            b'x' => {
                // \xHH: two hex digits
                if self.pos + 2 >= self.bytes.len() {
                    return Err(self.err("incomplete \\x escape", literal_start));
                }
                let hi = self.bytes[self.pos + 1];
                let lo = self.bytes[self.pos + 2];
                let val = hex_digit(hi)
                    .and_then(|h| hex_digit(lo).map(|l| h * 16 + l))
                    .ok_or_else(|| self.err("invalid hex digit in \\x escape", literal_start))?;
                self.pos += 2; // skip the two hex digits (main +1 below)
                val as char
            }
            other => {
                return Err(LexError {
                    message: format!("unknown escape sequence: \\{}", other as char),
                    span: self.span(self.pos - 1, self.pos + 1),
                });
            }
        };
        self.pos += 1;
        Ok(c)
    }

    fn lex_number(&mut self) -> Result<Token, LexError> {
        let start = self.pos;
        self.consume_digits();

        let mut is_float = false;

        // Check for decimal point followed by digit
        if self.pos < self.bytes.len()
            && self.bytes[self.pos] == b'.'
            && self
                .bytes
                .get(self.pos + 1)
                .is_some_and(|c| c.is_ascii_digit())
        {
            is_float = true;
            self.pos += 1; // skip '.'
            self.consume_digits();
        }

        // Check for exponent
        if self.pos < self.bytes.len() && matches!(self.bytes[self.pos], b'e' | b'E') {
            is_float = true;
            self.pos += 1;
            if self.pos < self.bytes.len() && matches!(self.bytes[self.pos], b'+' | b'-') {
                self.pos += 1;
            }
            if self.pos >= self.bytes.len() || !self.bytes[self.pos].is_ascii_digit() {
                return Err(self.err("expected digits after exponent", start));
            }
            self.consume_digits();
        }

        let text = &self.source[start..self.pos];
        if is_float {
            let value: f64 = text
                .parse()
                .map_err(|_| self.err(&format!("invalid float literal: {text}"), start))?;
            Ok(self.make_token(TokenKind::FloatLit(value), start, self.pos))
        } else {
            let value: i64 = text
                .parse()
                .map_err(|_| self.err(&format!("invalid integer literal: {text}"), start))?;
            Ok(self.make_token(TokenKind::IntLit(value), start, self.pos))
        }
    }

    fn lex_tyvar(&mut self) -> Result<Token, LexError> {
        let start = self.pos;
        self.pos += 1; // skip '
        self.consume_ident_chars();
        let name = self.source[start..self.pos].to_string();
        Ok(self.make_token(TokenKind::TyVar(name), start, self.pos))
    }

    fn lex_ident(&mut self) -> Result<Token, LexError> {
        let start = self.pos;
        self.consume_ident_chars();
        let text = &self.source[start..self.pos];

        if text == "_" {
            return Ok(self.make_token(TokenKind::Underscore, start, self.pos));
        }

        if let Some(kw) = TokenKind::keyword_from_str(text) {
            return Ok(self.make_token(kw, start, self.pos));
        }

        let first = text.as_bytes()[0];
        if first.is_ascii_uppercase() {
            Ok(self.make_token(TokenKind::UpperIdent(text.to_string()), start, self.pos))
        } else {
            Ok(self.make_token(TokenKind::Ident(text.to_string()), start, self.pos))
        }
    }

    fn lex_operator_or_delimiter(&mut self) -> Result<Token, LexError> {
        let start = self.pos;
        let ch = self.bytes[start];

        let kind = match ch {
            b'(' => {
                self.pos += 1;
                TokenKind::LParen
            }
            b')' => {
                self.pos += 1;
                TokenKind::RParen
            }
            b'[' => {
                self.pos += 1;
                TokenKind::LBracket
            }
            b']' => {
                self.pos += 1;
                TokenKind::RBracket
            }
            b',' => {
                self.pos += 1;
                TokenKind::Comma
            }
            b';' => {
                self.pos += 1;
                TokenKind::Semicolon
            }
            b'~' => {
                self.pos += 1;
                TokenKind::Tilde
            }
            b'#' => {
                self.pos += 1;
                TokenKind::Hash
            }
            b'^' => {
                self.pos += 1;
                TokenKind::Caret
            }
            b'|' => {
                self.pos += 1;
                TokenKind::Bar
            }

            b':' => {
                self.pos += 1;
                if self.peek() == Some(b':') {
                    self.pos += 1;
                    TokenKind::ColonColon
                } else {
                    TokenKind::Colon
                }
            }

            b'=' => {
                self.pos += 1;
                if self.peek() == Some(b'>') {
                    self.pos += 1;
                    TokenKind::Arrow
                } else {
                    TokenKind::Eq
                }
            }

            b'-' => {
                self.pos += 1;
                if self.peek() == Some(b'>') {
                    self.pos += 1;
                    TokenKind::ThinArrow
                } else if self.peek() == Some(b'.') {
                    self.pos += 1;
                    TokenKind::MinusDot
                } else {
                    TokenKind::Minus
                }
            }

            b'+' => {
                self.pos += 1;
                if self.peek() == Some(b'.') {
                    self.pos += 1;
                    TokenKind::PlusDot
                } else {
                    TokenKind::Plus
                }
            }

            b'*' => {
                self.pos += 1;
                if self.peek() == Some(b'.') {
                    self.pos += 1;
                    TokenKind::StarDot
                } else {
                    TokenKind::Star
                }
            }

            b'/' => {
                self.pos += 1;
                if self.peek() == Some(b'.') {
                    self.pos += 1;
                    TokenKind::SlashDot
                } else {
                    TokenKind::Slash
                }
            }

            b'<' => {
                self.pos += 1;
                match self.peek() {
                    Some(b'>') => {
                        self.pos += 1;
                        TokenKind::Ne
                    }
                    Some(b'=') => {
                        self.pos += 1;
                        if self.peek() == Some(b'.') {
                            self.pos += 1;
                            TokenKind::LeDot
                        } else {
                            TokenKind::Le
                        }
                    }
                    Some(b'.') => {
                        self.pos += 1;
                        TokenKind::LtDot
                    }
                    _ => TokenKind::Lt,
                }
            }

            b'>' => {
                self.pos += 1;
                match self.peek() {
                    Some(b'=') => {
                        self.pos += 1;
                        if self.peek() == Some(b'.') {
                            self.pos += 1;
                            TokenKind::GeDot
                        } else {
                            TokenKind::Ge
                        }
                    }
                    Some(b'.') => {
                        self.pos += 1;
                        TokenKind::GtDot
                    }
                    _ => TokenKind::Gt,
                }
            }

            _ => {
                self.pos += 1;
                return Err(LexError {
                    message: format!("unexpected character: '{}'", ch as char),
                    span: self.span(start, self.pos),
                });
            }
        };

        Ok(self.make_token(kind, start, self.pos))
    }

    fn skip_whitespace_and_comments(&mut self) -> Result<(), LexError> {
        loop {
            // Skip whitespace
            while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_whitespace() {
                self.pos += 1;
            }

            // Skip nested comments (* ... *)
            if self.pos + 1 < self.bytes.len()
                && self.bytes[self.pos] == b'('
                && self.bytes[self.pos + 1] == b'*'
            {
                self.skip_comment()?;
            } else {
                break;
            }
        }
        Ok(())
    }

    fn skip_comment(&mut self) -> Result<(), LexError> {
        let start = self.pos;
        self.pos += 2; // skip (*
        let mut depth = 1u32;

        while self.pos < self.bytes.len() && depth > 0 {
            if self.pos + 1 < self.bytes.len()
                && self.bytes[self.pos] == b'('
                && self.bytes[self.pos + 1] == b'*'
            {
                depth += 1;
                self.pos += 2;
            } else if self.pos + 1 < self.bytes.len()
                && self.bytes[self.pos] == b'*'
                && self.bytes[self.pos + 1] == b')'
            {
                depth -= 1;
                self.pos += 2;
            } else {
                self.pos += 1;
            }
        }

        if depth > 0 {
            return Err(self.err("unterminated comment", start));
        }
        Ok(())
    }

    fn consume_digits(&mut self) {
        while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_digit() {
            self.pos += 1;
        }
    }

    fn consume_ident_chars(&mut self) {
        while self.pos < self.bytes.len()
            && (self.bytes[self.pos].is_ascii_alphanumeric() || self.bytes[self.pos] == b'_')
        {
            self.pos += 1;
        }
    }

    fn peek(&self) -> Option<u8> {
        self.bytes.get(self.pos).copied()
    }

    fn peek_at(&self, offset: usize) -> Option<u8> {
        self.bytes.get(self.pos + offset).copied()
    }

    fn span(&self, start: usize, end: usize) -> Span {
        Span::new(self.file_id, start as u32, end as u32)
    }

    fn err(&self, message: &str, start: usize) -> LexError {
        LexError {
            message: message.to_string(),
            span: self.span(start, self.pos),
        }
    }

    fn make_token(&self, kind: TokenKind, start: usize, end: usize) -> Token {
        Token {
            kind,
            span: self.span(start, end),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn lex(input: &str) -> Vec<TokenKind> {
        let mut lexer = Lexer::new(input, 0);
        lexer
            .tokenize()
            .unwrap()
            .into_iter()
            .map(|t| t.kind)
            .collect()
    }

    fn lex_err(input: &str) -> String {
        let mut lexer = Lexer::new(input, 0);
        lexer.tokenize().unwrap_err().message
    }

    #[test]
    fn test_int_literals() {
        assert_eq!(lex("42"), vec![TokenKind::IntLit(42), TokenKind::Eof]);
        assert_eq!(lex("0"), vec![TokenKind::IntLit(0), TokenKind::Eof]);
        assert_eq!(lex("12345"), vec![TokenKind::IntLit(12345), TokenKind::Eof]);
    }

    #[test]
    fn test_float_literals() {
        assert_eq!(lex("3.14"), vec![TokenKind::FloatLit(3.14), TokenKind::Eof]);
        assert_eq!(
            lex("1.0e10"),
            vec![TokenKind::FloatLit(1.0e10), TokenKind::Eof]
        );
        assert_eq!(
            lex("2.5E-3"),
            vec![TokenKind::FloatLit(2.5e-3), TokenKind::Eof]
        );
        assert_eq!(lex("0.0"), vec![TokenKind::FloatLit(0.0), TokenKind::Eof]);
    }

    #[test]
    fn test_string_literals() {
        assert_eq!(
            lex(r#""hello""#),
            vec![TokenKind::StringLit("hello".to_string()), TokenKind::Eof]
        );
        assert_eq!(
            lex(r#""a\nb""#),
            vec![TokenKind::StringLit("a\nb".to_string()), TokenKind::Eof]
        );
        assert_eq!(
            lex(r#""a\\b""#),
            vec![TokenKind::StringLit("a\\b".to_string()), TokenKind::Eof]
        );
    }

    #[test]
    fn test_char_literals() {
        assert_eq!(
            lex(r#"#"a""#),
            vec![TokenKind::CharLit('a'), TokenKind::Eof]
        );
        assert_eq!(
            lex(r#"#"\n""#),
            vec![TokenKind::CharLit('\n'), TokenKind::Eof]
        );
    }

    #[test]
    fn test_keywords() {
        assert_eq!(lex("val"), vec![TokenKind::Val, TokenKind::Eof]);
        assert_eq!(lex("fun"), vec![TokenKind::Fun, TokenKind::Eof]);
        assert_eq!(lex("fn"), vec![TokenKind::Fn, TokenKind::Eof]);
        assert_eq!(lex("let"), vec![TokenKind::Let, TokenKind::Eof]);
        assert_eq!(lex("case"), vec![TokenKind::Case, TokenKind::Eof]);
        assert_eq!(lex("datatype"), vec![TokenKind::Datatype, TokenKind::Eof]);
        assert_eq!(lex("andalso"), vec![TokenKind::Andalso, TokenKind::Eof]);
        assert_eq!(lex("orelse"), vec![TokenKind::Orelse, TokenKind::Eof]);
    }

    #[test]
    fn test_identifiers() {
        assert_eq!(
            lex("foo"),
            vec![TokenKind::Ident("foo".to_string()), TokenKind::Eof]
        );
        assert_eq!(
            lex("x1"),
            vec![TokenKind::Ident("x1".to_string()), TokenKind::Eof]
        );
        assert_eq!(
            lex("_bar"),
            vec![TokenKind::Ident("_bar".to_string()), TokenKind::Eof]
        );
    }

    #[test]
    fn test_upper_idents() {
        assert_eq!(
            lex("Some"),
            vec![TokenKind::UpperIdent("Some".to_string()), TokenKind::Eof]
        );
        assert_eq!(
            lex("None"),
            vec![TokenKind::UpperIdent("None".to_string()), TokenKind::Eof]
        );
    }

    #[test]
    fn test_tyvars() {
        assert_eq!(
            lex("'a"),
            vec![TokenKind::TyVar("'a".to_string()), TokenKind::Eof]
        );
        assert_eq!(
            lex("'abc"),
            vec![TokenKind::TyVar("'abc".to_string()), TokenKind::Eof]
        );
    }

    #[test]
    fn test_operators() {
        assert_eq!(
            lex("+ +. - -. * *. / /."),
            vec![
                TokenKind::Plus,
                TokenKind::PlusDot,
                TokenKind::Minus,
                TokenKind::MinusDot,
                TokenKind::Star,
                TokenKind::StarDot,
                TokenKind::Slash,
                TokenKind::SlashDot,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_comparison_operators() {
        assert_eq!(
            lex("< <. <= <=. > >. >= >=."),
            vec![
                TokenKind::Lt,
                TokenKind::LtDot,
                TokenKind::Le,
                TokenKind::LeDot,
                TokenKind::Gt,
                TokenKind::GtDot,
                TokenKind::Ge,
                TokenKind::GeDot,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_equality_operators() {
        assert_eq!(
            lex("= <>"),
            vec![TokenKind::Eq, TokenKind::Ne, TokenKind::Eof]
        );
    }

    #[test]
    fn test_arrows_and_cons() {
        assert_eq!(
            lex("=> -> ::"),
            vec![
                TokenKind::Arrow,
                TokenKind::ThinArrow,
                TokenKind::ColonColon,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_delimiters() {
        assert_eq!(
            lex("( ) [ ] , : ; | _ #"),
            vec![
                TokenKind::LParen,
                TokenKind::RParen,
                TokenKind::LBracket,
                TokenKind::RBracket,
                TokenKind::Comma,
                TokenKind::Colon,
                TokenKind::Semicolon,
                TokenKind::Bar,
                TokenKind::Underscore,
                TokenKind::Hash,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_comments() {
        assert_eq!(
            lex("(* comment *) 42"),
            vec![TokenKind::IntLit(42), TokenKind::Eof]
        );
    }

    #[test]
    fn test_nested_comments() {
        assert_eq!(
            lex("(* outer (* inner *) still outer *) 1"),
            vec![TokenKind::IntLit(1), TokenKind::Eof],
        );
    }

    #[test]
    fn test_unterminated_comment() {
        assert_eq!(lex_err("(* oops"), "unterminated comment");
    }

    #[test]
    fn test_unterminated_string() {
        assert_eq!(lex_err(r#""oops"#), "unterminated string literal");
    }

    #[test]
    fn test_unexpected_char() {
        assert_eq!(lex_err("@"), "unexpected character: '@'");
    }

    #[test]
    fn test_negation_is_operator() {
        assert_eq!(
            lex("~42"),
            vec![TokenKind::Tilde, TokenKind::IntLit(42), TokenKind::Eof,]
        );
    }

    #[test]
    fn test_full_expression() {
        let tokens = lex("fun fib n = if n < 2 then n else fib (n - 1) + fib (n - 2)");
        assert_eq!(tokens[0], TokenKind::Fun);
        assert_eq!(tokens[1], TokenKind::Ident("fib".to_string()));
        assert_eq!(tokens[2], TokenKind::Ident("n".to_string()));
        assert_eq!(tokens[3], TokenKind::Eq);
        assert_eq!(tokens[4], TokenKind::If);
    }

    #[test]
    fn test_empty_input() {
        assert_eq!(lex(""), vec![TokenKind::Eof]);
    }

    #[test]
    fn test_comments_only() {
        assert_eq!(lex("(* just a comment *)"), vec![TokenKind::Eof]);
    }

    #[test]
    fn test_datatype_declaration() {
        let tokens = lex("datatype 'a option = None | Some of 'a");
        assert_eq!(tokens[0], TokenKind::Datatype);
        assert_eq!(tokens[1], TokenKind::TyVar("'a".to_string()));
        assert_eq!(tokens[2], TokenKind::Ident("option".to_string()));
        assert_eq!(tokens[3], TokenKind::Eq);
        assert_eq!(tokens[4], TokenKind::UpperIdent("None".to_string()));
        assert_eq!(tokens[5], TokenKind::Bar);
        assert_eq!(tokens[6], TokenKind::UpperIdent("Some".to_string()));
        assert_eq!(tokens[7], TokenKind::Of);
        assert_eq!(tokens[8], TokenKind::TyVar("'a".to_string()));
    }
}