1use crate::span::Span;
2use crate::token::{Token, TokenKind};
3
/// A lexical error produced while tokenizing: a human-readable message
/// plus the source span where lexing failed.
#[derive(Debug, Clone)]
pub struct LexError {
    // What went wrong, e.g. "unterminated string literal".
    pub message: String,
    // Byte-offset span (with file id) locating the error in the source.
    pub span: Span,
}
9
/// Converts an ASCII hex digit (`0-9`, `a-f`, `A-F`) to its numeric value.
///
/// Returns `None` for any other byte. Bytes >= 0x80 map to non-digit
/// Latin-1 characters, so they also yield `None`.
fn hex_digit(b: u8) -> Option<u8> {
    (b as char).to_digit(16).map(|d| d as u8)
}
18
/// A hand-written lexer that turns one source file into a token stream.
pub struct Lexer<'src> {
    // Full source text, kept for UTF-8-aware slicing of literals.
    source: &'src str,
    // Byte view of `source`, used for fast ASCII dispatch.
    bytes: &'src [u8],
    // Current byte offset into `source`; always on a char boundary.
    pos: usize,
    // File identifier stamped into every emitted `Span`.
    file_id: u32,
}
25
26impl<'src> Lexer<'src> {
27 pub fn new(source: &'src str, file_id: u32) -> Self {
28 Self {
29 source,
30 bytes: source.as_bytes(),
31 pos: 0,
32 file_id,
33 }
34 }
35
36 pub fn tokenize(&mut self) -> Result<Vec<Token>, LexError> {
37 let mut tokens = Vec::new();
38 loop {
39 let tok = self.next_token()?;
40 let is_eof = tok.kind == TokenKind::Eof;
41 tokens.push(tok);
42 if is_eof {
43 break;
44 }
45 }
46 Ok(tokens)
47 }
48
49 fn next_token(&mut self) -> Result<Token, LexError> {
50 self.skip_whitespace_and_comments()?;
51
52 if self.pos >= self.bytes.len() {
53 return Ok(self.make_token(TokenKind::Eof, self.pos, self.pos));
54 }
55
56 let start = self.pos;
57 let ch = self.bytes[start];
58
59 if ch == b'"' {
61 return self.lex_string();
62 }
63
64 if ch == b'#' && self.peek_at(1) == Some(b'"') {
66 return self.lex_char();
67 }
68
69 if ch.is_ascii_digit() {
71 return self.lex_number();
72 }
73
74 if ch == b'\''
76 && self
77 .peek_at(1)
78 .is_some_and(|c| c.is_ascii_alphabetic() || c == b'_')
79 {
80 return self.lex_tyvar();
81 }
82
83 if ch.is_ascii_alphabetic() || ch == b'_' {
85 return self.lex_ident();
86 }
87
88 self.lex_operator_or_delimiter()
90 }
91
92 fn lex_string(&mut self) -> Result<Token, LexError> {
93 let start = self.pos;
94 self.pos += 1; let mut value = String::new();
96
97 loop {
98 if self.pos >= self.bytes.len() {
99 return Err(self.err("unterminated string literal", start));
100 }
101 match self.bytes[self.pos] {
102 b'"' => {
103 self.pos += 1;
104 return Ok(self.make_token(TokenKind::StringLit(value), start, self.pos));
105 }
106 b'\\' => {
107 value.push(self.parse_escape(start)?);
108 }
109 _ => {
110 let rest = &self.source[self.pos..];
111 if let Some(c) = rest.chars().next() {
112 value.push(c);
113 self.pos += c.len_utf8();
114 }
115 }
116 }
117 }
118 }
119
120 fn lex_char(&mut self) -> Result<Token, LexError> {
121 let start = self.pos;
122 self.pos += 2; if self.pos >= self.bytes.len() {
125 return Err(self.err("unterminated character literal", start));
126 }
127
128 let c = if self.bytes[self.pos] == b'\\' {
129 self.parse_escape(start)?
130 } else {
131 let rest = &self.source[self.pos..];
132 let c = rest.chars().next().unwrap();
133 self.pos += c.len_utf8();
134 c
135 };
136
137 if self.pos >= self.bytes.len() || self.bytes[self.pos] != b'"' {
138 return Err(self.err("unterminated character literal, expected closing \"", start));
139 }
140 self.pos += 1; Ok(self.make_token(TokenKind::CharLit(c), start, self.pos))
143 }
144
145 fn parse_escape(&mut self, literal_start: usize) -> Result<char, LexError> {
146 self.pos += 1; if self.pos >= self.bytes.len() {
148 return Err(self.err("unterminated escape sequence", literal_start));
149 }
150 let c = match self.bytes[self.pos] {
151 b'n' => '\n',
152 b't' => '\t',
153 b'r' => '\r',
154 b'0' => '\0',
155 b'\\' => '\\',
156 b'"' => '"',
157 b'x' => {
158 if self.pos + 2 >= self.bytes.len() {
160 return Err(self.err("incomplete \\x escape", literal_start));
161 }
162 let hi = self.bytes[self.pos + 1];
163 let lo = self.bytes[self.pos + 2];
164 let val = hex_digit(hi)
165 .and_then(|h| hex_digit(lo).map(|l| h * 16 + l))
166 .ok_or_else(|| self.err("invalid hex digit in \\x escape", literal_start))?;
167 self.pos += 2; val as char
169 }
170 other => {
171 return Err(LexError {
172 message: format!("unknown escape sequence: \\{}", other as char),
173 span: self.span(self.pos - 1, self.pos + 1),
174 });
175 }
176 };
177 self.pos += 1;
178 Ok(c)
179 }
180
181 fn lex_number(&mut self) -> Result<Token, LexError> {
182 let start = self.pos;
183 self.consume_digits();
184
185 let mut is_float = false;
186
187 if self.pos < self.bytes.len()
189 && self.bytes[self.pos] == b'.'
190 && self
191 .bytes
192 .get(self.pos + 1)
193 .is_some_and(|c| c.is_ascii_digit())
194 {
195 is_float = true;
196 self.pos += 1; self.consume_digits();
198 }
199
200 if self.pos < self.bytes.len() && matches!(self.bytes[self.pos], b'e' | b'E') {
202 is_float = true;
203 self.pos += 1;
204 if self.pos < self.bytes.len() && matches!(self.bytes[self.pos], b'+' | b'-') {
205 self.pos += 1;
206 }
207 if self.pos >= self.bytes.len() || !self.bytes[self.pos].is_ascii_digit() {
208 return Err(self.err("expected digits after exponent", start));
209 }
210 self.consume_digits();
211 }
212
213 let text = &self.source[start..self.pos];
214 if is_float {
215 let value: f64 = text
216 .parse()
217 .map_err(|_| self.err(&format!("invalid float literal: {text}"), start))?;
218 Ok(self.make_token(TokenKind::FloatLit(value), start, self.pos))
219 } else {
220 let value: i64 = text
221 .parse()
222 .map_err(|_| self.err(&format!("invalid integer literal: {text}"), start))?;
223 Ok(self.make_token(TokenKind::IntLit(value), start, self.pos))
224 }
225 }
226
227 fn lex_tyvar(&mut self) -> Result<Token, LexError> {
228 let start = self.pos;
229 self.pos += 1; self.consume_ident_chars();
231 let name = self.source[start..self.pos].to_string();
232 Ok(self.make_token(TokenKind::TyVar(name), start, self.pos))
233 }
234
235 fn lex_ident(&mut self) -> Result<Token, LexError> {
236 let start = self.pos;
237 self.consume_ident_chars();
238 let text = &self.source[start..self.pos];
239
240 if text == "_" {
241 return Ok(self.make_token(TokenKind::Underscore, start, self.pos));
242 }
243
244 if let Some(kw) = TokenKind::keyword_from_str(text) {
245 return Ok(self.make_token(kw, start, self.pos));
246 }
247
248 let first = text.as_bytes()[0];
249 if first.is_ascii_uppercase() {
250 Ok(self.make_token(TokenKind::UpperIdent(text.to_string()), start, self.pos))
251 } else {
252 Ok(self.make_token(TokenKind::Ident(text.to_string()), start, self.pos))
253 }
254 }
255
256 fn lex_operator_or_delimiter(&mut self) -> Result<Token, LexError> {
257 let start = self.pos;
258 let ch = self.bytes[start];
259
260 let kind = match ch {
261 b'(' => {
262 self.pos += 1;
263 TokenKind::LParen
264 }
265 b')' => {
266 self.pos += 1;
267 TokenKind::RParen
268 }
269 b'[' => {
270 self.pos += 1;
271 TokenKind::LBracket
272 }
273 b']' => {
274 self.pos += 1;
275 TokenKind::RBracket
276 }
277 b',' => {
278 self.pos += 1;
279 TokenKind::Comma
280 }
281 b';' => {
282 self.pos += 1;
283 TokenKind::Semicolon
284 }
285 b'~' => {
286 self.pos += 1;
287 TokenKind::Tilde
288 }
289 b'#' => {
290 self.pos += 1;
291 TokenKind::Hash
292 }
293 b'^' => {
294 self.pos += 1;
295 TokenKind::Caret
296 }
297 b'|' => {
298 self.pos += 1;
299 TokenKind::Bar
300 }
301
302 b':' => {
303 self.pos += 1;
304 if self.peek() == Some(b':') {
305 self.pos += 1;
306 TokenKind::ColonColon
307 } else {
308 TokenKind::Colon
309 }
310 }
311
312 b'=' => {
313 self.pos += 1;
314 if self.peek() == Some(b'>') {
315 self.pos += 1;
316 TokenKind::Arrow
317 } else {
318 TokenKind::Eq
319 }
320 }
321
322 b'-' => {
323 self.pos += 1;
324 if self.peek() == Some(b'>') {
325 self.pos += 1;
326 TokenKind::ThinArrow
327 } else if self.peek() == Some(b'.') {
328 self.pos += 1;
329 TokenKind::MinusDot
330 } else {
331 TokenKind::Minus
332 }
333 }
334
335 b'+' => {
336 self.pos += 1;
337 if self.peek() == Some(b'.') {
338 self.pos += 1;
339 TokenKind::PlusDot
340 } else {
341 TokenKind::Plus
342 }
343 }
344
345 b'*' => {
346 self.pos += 1;
347 if self.peek() == Some(b'.') {
348 self.pos += 1;
349 TokenKind::StarDot
350 } else {
351 TokenKind::Star
352 }
353 }
354
355 b'/' => {
356 self.pos += 1;
357 if self.peek() == Some(b'.') {
358 self.pos += 1;
359 TokenKind::SlashDot
360 } else {
361 TokenKind::Slash
362 }
363 }
364
365 b'<' => {
366 self.pos += 1;
367 match self.peek() {
368 Some(b'>') => {
369 self.pos += 1;
370 TokenKind::Ne
371 }
372 Some(b'=') => {
373 self.pos += 1;
374 if self.peek() == Some(b'.') {
375 self.pos += 1;
376 TokenKind::LeDot
377 } else {
378 TokenKind::Le
379 }
380 }
381 Some(b'.') => {
382 self.pos += 1;
383 TokenKind::LtDot
384 }
385 _ => TokenKind::Lt,
386 }
387 }
388
389 b'>' => {
390 self.pos += 1;
391 match self.peek() {
392 Some(b'=') => {
393 self.pos += 1;
394 if self.peek() == Some(b'.') {
395 self.pos += 1;
396 TokenKind::GeDot
397 } else {
398 TokenKind::Ge
399 }
400 }
401 Some(b'.') => {
402 self.pos += 1;
403 TokenKind::GtDot
404 }
405 _ => TokenKind::Gt,
406 }
407 }
408
409 _ => {
410 self.pos += 1;
411 return Err(LexError {
412 message: format!("unexpected character: '{}'", ch as char),
413 span: self.span(start, self.pos),
414 });
415 }
416 };
417
418 Ok(self.make_token(kind, start, self.pos))
419 }
420
421 fn skip_whitespace_and_comments(&mut self) -> Result<(), LexError> {
422 loop {
423 while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_whitespace() {
425 self.pos += 1;
426 }
427
428 if self.pos + 1 < self.bytes.len()
430 && self.bytes[self.pos] == b'('
431 && self.bytes[self.pos + 1] == b'*'
432 {
433 self.skip_comment()?;
434 } else {
435 break;
436 }
437 }
438 Ok(())
439 }
440
441 fn skip_comment(&mut self) -> Result<(), LexError> {
442 let start = self.pos;
443 self.pos += 2; let mut depth = 1u32;
445
446 while self.pos < self.bytes.len() && depth > 0 {
447 if self.pos + 1 < self.bytes.len()
448 && self.bytes[self.pos] == b'('
449 && self.bytes[self.pos + 1] == b'*'
450 {
451 depth += 1;
452 self.pos += 2;
453 } else if self.pos + 1 < self.bytes.len()
454 && self.bytes[self.pos] == b'*'
455 && self.bytes[self.pos + 1] == b')'
456 {
457 depth -= 1;
458 self.pos += 2;
459 } else {
460 self.pos += 1;
461 }
462 }
463
464 if depth > 0 {
465 return Err(self.err("unterminated comment", start));
466 }
467 Ok(())
468 }
469
470 fn consume_digits(&mut self) {
471 while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_digit() {
472 self.pos += 1;
473 }
474 }
475
476 fn consume_ident_chars(&mut self) {
477 while self.pos < self.bytes.len()
478 && (self.bytes[self.pos].is_ascii_alphanumeric() || self.bytes[self.pos] == b'_')
479 {
480 self.pos += 1;
481 }
482 }
483
484 fn peek(&self) -> Option<u8> {
485 self.bytes.get(self.pos).copied()
486 }
487
488 fn peek_at(&self, offset: usize) -> Option<u8> {
489 self.bytes.get(self.pos + offset).copied()
490 }
491
492 fn span(&self, start: usize, end: usize) -> Span {
493 Span::new(self.file_id, start as u32, end as u32)
494 }
495
496 fn err(&self, message: &str, start: usize) -> LexError {
497 LexError {
498 message: message.to_string(),
499 span: self.span(start, self.pos),
500 }
501 }
502
503 fn make_token(&self, kind: TokenKind, start: usize, end: usize) -> Token {
504 Token {
505 kind,
506 span: self.span(start, end),
507 }
508 }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Lexes `input` and returns just the token kinds; panics on error.
    fn lex(input: &str) -> Vec<TokenKind> {
        let mut lexer = Lexer::new(input, 0);
        lexer
            .tokenize()
            .unwrap()
            .into_iter()
            .map(|t| t.kind)
            .collect()
    }

    /// Lexes `input`, expecting failure, and returns the error message.
    fn lex_err(input: &str) -> String {
        let mut lexer = Lexer::new(input, 0);
        lexer.tokenize().unwrap_err().message
    }

    #[test]
    fn test_int_literals() {
        assert_eq!(lex("42"), vec![TokenKind::IntLit(42), TokenKind::Eof]);
        assert_eq!(lex("0"), vec![TokenKind::IntLit(0), TokenKind::Eof]);
        assert_eq!(lex("12345"), vec![TokenKind::IntLit(12345), TokenKind::Eof]);
    }

    #[test]
    fn test_float_literals() {
        assert_eq!(lex("3.14"), vec![TokenKind::FloatLit(3.14), TokenKind::Eof]);
        assert_eq!(
            lex("1.0e10"),
            vec![TokenKind::FloatLit(1.0e10), TokenKind::Eof]
        );
        assert_eq!(
            lex("2.5E-3"),
            vec![TokenKind::FloatLit(2.5e-3), TokenKind::Eof]
        );
        assert_eq!(lex("0.0"), vec![TokenKind::FloatLit(0.0), TokenKind::Eof]);
    }

    #[test]
    fn test_exponent_without_fraction() {
        // An exponent alone makes the literal a float.
        assert_eq!(lex("1e3"), vec![TokenKind::FloatLit(1e3), TokenKind::Eof]);
        assert_eq!(lex_err("1e"), "expected digits after exponent");
        assert_eq!(lex_err("1e+"), "expected digits after exponent");
    }

    #[test]
    fn test_string_literals() {
        assert_eq!(
            lex(r#""hello""#),
            vec![TokenKind::StringLit("hello".to_string()), TokenKind::Eof]
        );
        assert_eq!(
            lex(r#""a\nb""#),
            vec![TokenKind::StringLit("a\nb".to_string()), TokenKind::Eof]
        );
        assert_eq!(
            lex(r#""a\\b""#),
            vec![TokenKind::StringLit("a\\b".to_string()), TokenKind::Eof]
        );
    }

    #[test]
    fn test_unicode_string_contents() {
        // Multi-byte UTF-8 inside a string literal is preserved verbatim.
        assert_eq!(
            lex("\"héllo\""),
            vec![TokenKind::StringLit("héllo".to_string()), TokenKind::Eof]
        );
    }

    #[test]
    fn test_hex_escapes() {
        assert_eq!(
            lex(r#""\x41""#),
            vec![TokenKind::StringLit("A".to_string()), TokenKind::Eof]
        );
        assert_eq!(
            lex(r#"#"\x41""#),
            vec![TokenKind::CharLit('A'), TokenKind::Eof]
        );
    }

    #[test]
    fn test_incomplete_hex_escape() {
        assert_eq!(lex_err(r#""\x4"#), "incomplete \\x escape");
        assert_eq!(lex_err(r#""\xzz""#), "invalid hex digit in \\x escape");
    }

    #[test]
    fn test_unknown_escape() {
        assert_eq!(lex_err(r#""\q""#), "unknown escape sequence: \\q");
    }

    #[test]
    fn test_char_literals() {
        assert_eq!(
            lex(r#"#"a""#),
            vec![TokenKind::CharLit('a'), TokenKind::Eof]
        );
        assert_eq!(
            lex(r#"#"\n""#),
            vec![TokenKind::CharLit('\n'), TokenKind::Eof]
        );
    }

    #[test]
    fn test_unterminated_char_literal() {
        assert_eq!(lex_err(r##"#""##), "unterminated character literal");
        assert_eq!(
            lex_err(r##"#"a"##),
            "unterminated character literal, expected closing \""
        );
    }

    #[test]
    fn test_keywords() {
        assert_eq!(lex("val"), vec![TokenKind::Val, TokenKind::Eof]);
        assert_eq!(lex("fun"), vec![TokenKind::Fun, TokenKind::Eof]);
        assert_eq!(lex("fn"), vec![TokenKind::Fn, TokenKind::Eof]);
        assert_eq!(lex("let"), vec![TokenKind::Let, TokenKind::Eof]);
        assert_eq!(lex("case"), vec![TokenKind::Case, TokenKind::Eof]);
        assert_eq!(lex("datatype"), vec![TokenKind::Datatype, TokenKind::Eof]);
        assert_eq!(lex("andalso"), vec![TokenKind::Andalso, TokenKind::Eof]);
        assert_eq!(lex("orelse"), vec![TokenKind::Orelse, TokenKind::Eof]);
    }

    #[test]
    fn test_identifiers() {
        assert_eq!(
            lex("foo"),
            vec![TokenKind::Ident("foo".to_string()), TokenKind::Eof]
        );
        assert_eq!(
            lex("x1"),
            vec![TokenKind::Ident("x1".to_string()), TokenKind::Eof]
        );
        assert_eq!(
            lex("_bar"),
            vec![TokenKind::Ident("_bar".to_string()), TokenKind::Eof]
        );
    }

    #[test]
    fn test_upper_idents() {
        assert_eq!(
            lex("Some"),
            vec![TokenKind::UpperIdent("Some".to_string()), TokenKind::Eof]
        );
        assert_eq!(
            lex("None"),
            vec![TokenKind::UpperIdent("None".to_string()), TokenKind::Eof]
        );
    }

    #[test]
    fn test_tyvars() {
        assert_eq!(
            lex("'a"),
            vec![TokenKind::TyVar("'a".to_string()), TokenKind::Eof]
        );
        assert_eq!(
            lex("'abc"),
            vec![TokenKind::TyVar("'abc".to_string()), TokenKind::Eof]
        );
    }

    #[test]
    fn test_operators() {
        assert_eq!(
            lex("+ +. - -. * *. / /."),
            vec![
                TokenKind::Plus,
                TokenKind::PlusDot,
                TokenKind::Minus,
                TokenKind::MinusDot,
                TokenKind::Star,
                TokenKind::StarDot,
                TokenKind::Slash,
                TokenKind::SlashDot,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_comparison_operators() {
        assert_eq!(
            lex("< <. <= <=. > >. >= >=."),
            vec![
                TokenKind::Lt,
                TokenKind::LtDot,
                TokenKind::Le,
                TokenKind::LeDot,
                TokenKind::Gt,
                TokenKind::GtDot,
                TokenKind::Ge,
                TokenKind::GeDot,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_equality_operators() {
        assert_eq!(
            lex("= <>"),
            vec![TokenKind::Eq, TokenKind::Ne, TokenKind::Eof]
        );
    }

    #[test]
    fn test_arrows_and_cons() {
        assert_eq!(
            lex("=> -> ::"),
            vec![
                TokenKind::Arrow,
                TokenKind::ThinArrow,
                TokenKind::ColonColon,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_delimiters() {
        assert_eq!(
            lex("( ) [ ] , : ; | _ #"),
            vec![
                TokenKind::LParen,
                TokenKind::RParen,
                TokenKind::LBracket,
                TokenKind::RBracket,
                TokenKind::Comma,
                TokenKind::Colon,
                TokenKind::Semicolon,
                TokenKind::Bar,
                TokenKind::Underscore,
                TokenKind::Hash,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_comments() {
        assert_eq!(
            lex("(* comment *) 42"),
            vec![TokenKind::IntLit(42), TokenKind::Eof]
        );
    }

    #[test]
    fn test_nested_comments() {
        assert_eq!(
            lex("(* outer (* inner *) still outer *) 1"),
            vec![TokenKind::IntLit(1), TokenKind::Eof],
        );
    }

    #[test]
    fn test_unterminated_comment() {
        assert_eq!(lex_err("(* oops"), "unterminated comment");
    }

    #[test]
    fn test_unterminated_string() {
        assert_eq!(lex_err(r#""oops"#), "unterminated string literal");
    }

    #[test]
    fn test_unexpected_char() {
        assert_eq!(lex_err("@"), "unexpected character: '@'");
    }

    #[test]
    fn test_negation_is_operator() {
        assert_eq!(
            lex("~42"),
            vec![TokenKind::Tilde, TokenKind::IntLit(42), TokenKind::Eof,]
        );
    }

    #[test]
    fn test_full_expression() {
        let tokens = lex("fun fib n = if n < 2 then n else fib (n - 1) + fib (n - 2)");
        assert_eq!(tokens[0], TokenKind::Fun);
        assert_eq!(tokens[1], TokenKind::Ident("fib".to_string()));
        assert_eq!(tokens[2], TokenKind::Ident("n".to_string()));
        assert_eq!(tokens[3], TokenKind::Eq);
        assert_eq!(tokens[4], TokenKind::If);
    }

    #[test]
    fn test_empty_input() {
        assert_eq!(lex(""), vec![TokenKind::Eof]);
    }

    #[test]
    fn test_comments_only() {
        assert_eq!(lex("(* just a comment *)"), vec![TokenKind::Eof]);
    }

    #[test]
    fn test_datatype_declaration() {
        let tokens = lex("datatype 'a option = None | Some of 'a");
        assert_eq!(tokens[0], TokenKind::Datatype);
        assert_eq!(tokens[1], TokenKind::TyVar("'a".to_string()));
        assert_eq!(tokens[2], TokenKind::Ident("option".to_string()));
        assert_eq!(tokens[3], TokenKind::Eq);
        assert_eq!(tokens[4], TokenKind::UpperIdent("None".to_string()));
        assert_eq!(tokens[5], TokenKind::Bar);
        assert_eq!(tokens[6], TokenKind::UpperIdent("Some".to_string()));
        assert_eq!(tokens[7], TokenKind::Of);
        assert_eq!(tokens[8], TokenKind::TyVar("'a".to_string()));
    }
}