// kyu_parser/lexer.rs
use std::fmt;

use smol_str::SmolStr;

use crate::span::{Span, Spanned};
use crate::token::{lookup_keyword, Token};
5
6/// Lexical analysis error.
7#[derive(Debug, Clone)]
8pub struct LexError {
9    pub span: Span,
10    pub message: String,
11}
12
/// Hand-written Cypher lexer producing a flat token stream.
///
/// Separating lex from parse gives precise byte positions for every
/// error span and avoids running combinators on individual characters.
pub struct Lexer<'src> {
    source: &'src [u8],          // raw query bytes; all indexing is byte-based
    pos: usize,                  // current byte offset into `source`
    tokens: Vec<Spanned<Token>>, // tokens produced so far, each with its byte span
    errors: Vec<LexError>,       // collected errors; lexing never aborts early
}
23
24impl<'src> Lexer<'src> {
25    pub fn new(source: &'src str) -> Self {
26        Self {
27            source: source.as_bytes(),
28            pos: 0,
29            tokens: Vec::new(),
30            errors: Vec::new(),
31        }
32    }
33
34    pub fn lex(mut self) -> (Vec<Spanned<Token>>, Vec<LexError>) {
35        while !self.is_at_end() {
36            self.skip_whitespace_and_comments();
37            if self.is_at_end() {
38                break;
39            }
40
41            let start = self.pos;
42            let ch = self.advance();
43
44            match ch {
45                b'(' => self.push(Token::LeftParen, start),
46                b')' => self.push(Token::RightParen, start),
47                b'[' => self.push(Token::LeftBracket, start),
48                b']' => self.push(Token::RightBracket, start),
49                b'{' => self.push(Token::LeftBrace, start),
50                b'}' => self.push(Token::RightBrace, start),
51                b',' => self.push(Token::Comma, start),
52                b';' => self.push(Token::Semicolon, start),
53                b':' => self.push(Token::Colon, start),
54                b'|' => self.push(Token::Pipe, start),
55                b'*' => self.push(Token::Star, start),
56                b'%' => self.push(Token::Percent, start),
57                b'^' => self.push(Token::Caret, start),
58                b'&' => self.push(Token::Ampersand, start),
59                b'~' => self.push(Token::Tilde, start),
60                b'!' => self.push(Token::Exclaim, start),
61
62                b'.' => {
63                    if self.peek() == Some(b'.') {
64                        self.advance();
65                        self.push(Token::DoubleDot, start);
66                    } else {
67                        self.push(Token::Dot, start);
68                    }
69                }
70
71                b'+' => {
72                    if self.peek() == Some(b'=') {
73                        self.advance();
74                        self.push(Token::PlusEq, start);
75                    } else {
76                        self.push(Token::Plus, start);
77                    }
78                }
79
80                b'=' => {
81                    if self.peek() == Some(b'~') {
82                        self.advance();
83                        self.push(Token::RegexMatch, start);
84                    } else {
85                        self.push(Token::Eq, start);
86                    }
87                }
88
89                b'<' => {
90                    match self.peek() {
91                        Some(b'=') => {
92                            self.advance();
93                            self.push(Token::Le, start);
94                        }
95                        Some(b'>') => {
96                            self.advance();
97                            self.push(Token::Neq, start);
98                        }
99                        Some(b'<') => {
100                            self.advance();
101                            self.push(Token::ShiftLeft, start);
102                        }
103                        Some(b'-') => {
104                            self.advance();
105                            self.push(Token::LeftArrow, start);
106                        }
107                        _ => self.push(Token::Lt, start),
108                    }
109                }
110
111                b'>' => {
112                    match self.peek() {
113                        Some(b'=') => {
114                            self.advance();
115                            self.push(Token::Ge, start);
116                        }
117                        Some(b'>') => {
118                            self.advance();
119                            self.push(Token::ShiftRight, start);
120                        }
121                        _ => self.push(Token::Gt, start),
122                    }
123                }
124
125                b'-' => {
126                    if self.peek() == Some(b'>') {
127                        self.advance();
128                        self.push(Token::Arrow, start);
129                    } else {
130                        self.push(Token::Dash, start);
131                    }
132                }
133
134                b'/' => {
135                    // Forward slash as division — comments already handled in skip_whitespace
136                    self.push(Token::Slash, start);
137                }
138
139                b'\'' | b'"' => self.lex_string(ch, start),
140
141                b'`' => self.lex_escaped_ident(start),
142
143                b'$' => self.lex_parameter(start),
144
145                b'0'..=b'9' => self.lex_number(start),
146
147                b'a'..=b'z' | b'A'..=b'Z' | b'_' => self.lex_ident_or_keyword(start),
148
149                _ => {
150                    self.errors.push(LexError {
151                        span: start..self.pos,
152                        message: format!("unexpected character '{}'", ch as char),
153                    });
154                }
155            }
156        }
157
158        self.tokens.push((Token::Eof, self.pos..self.pos));
159        (self.tokens, self.errors)
160    }
161
162    fn is_at_end(&self) -> bool {
163        self.pos >= self.source.len()
164    }
165
166    fn peek(&self) -> Option<u8> {
167        self.source.get(self.pos).copied()
168    }
169
170    fn advance(&mut self) -> u8 {
171        let ch = self.source[self.pos];
172        self.pos += 1;
173        ch
174    }
175
176    fn push(&mut self, token: Token, start: usize) {
177        self.tokens.push((token, start..self.pos));
178    }
179
180    fn skip_whitespace_and_comments(&mut self) {
181        while !self.is_at_end() {
182            let ch = self.source[self.pos];
183            match ch {
184                b' ' | b'\t' | b'\n' | b'\r' => {
185                    self.pos += 1;
186                }
187                b'/' => {
188                    if self.pos + 1 < self.source.len() {
189                        match self.source[self.pos + 1] {
190                            b'/' => {
191                                // Line comment: skip to end of line
192                                self.pos += 2;
193                                while !self.is_at_end() && self.source[self.pos] != b'\n' {
194                                    self.pos += 1;
195                                }
196                            }
197                            b'*' => {
198                                // Block comment: skip to */
199                                let start = self.pos;
200                                self.pos += 2;
201                                let mut depth = 1;
202                                while !self.is_at_end() && depth > 0 {
203                                    if self.source[self.pos] == b'*'
204                                        && self.pos + 1 < self.source.len()
205                                        && self.source[self.pos + 1] == b'/'
206                                    {
207                                        depth -= 1;
208                                        self.pos += 2;
209                                    } else if self.source[self.pos] == b'/'
210                                        && self.pos + 1 < self.source.len()
211                                        && self.source[self.pos + 1] == b'*'
212                                    {
213                                        depth += 1;
214                                        self.pos += 2;
215                                    } else {
216                                        self.pos += 1;
217                                    }
218                                }
219                                if depth > 0 {
220                                    self.errors.push(LexError {
221                                        span: start..self.pos,
222                                        message: "unterminated block comment".to_string(),
223                                    });
224                                }
225                            }
226                            _ => break,
227                        }
228                    } else {
229                        break;
230                    }
231                }
232                _ => break,
233            }
234        }
235    }
236
237    fn lex_string(&mut self, quote: u8, start: usize) {
238        let mut value = String::new();
239        loop {
240            if self.is_at_end() {
241                self.errors.push(LexError {
242                    span: start..self.pos,
243                    message: "unterminated string literal".to_string(),
244                });
245                break;
246            }
247            let ch = self.advance();
248            if ch == quote {
249                // Check for escaped quote (doubled): '' or ""
250                if self.peek() == Some(quote) {
251                    self.advance();
252                    value.push(quote as char);
253                } else {
254                    break;
255                }
256            } else if ch == b'\\' {
257                // Backslash escape
258                if self.is_at_end() {
259                    self.errors.push(LexError {
260                        span: start..self.pos,
261                        message: "unterminated string escape".to_string(),
262                    });
263                    break;
264                }
265                let esc = self.advance();
266                match esc {
267                    b'n' => value.push('\n'),
268                    b't' => value.push('\t'),
269                    b'r' => value.push('\r'),
270                    b'\\' => value.push('\\'),
271                    b'\'' => value.push('\''),
272                    b'"' => value.push('"'),
273                    b'0' => value.push('\0'),
274                    _ => {
275                        value.push('\\');
276                        value.push(esc as char);
277                    }
278                }
279            } else {
280                value.push(ch as char);
281            }
282        }
283        self.push(Token::StringLiteral(SmolStr::new(&value)), start);
284    }
285
286    fn lex_escaped_ident(&mut self, start: usize) {
287        let mut value = String::new();
288        loop {
289            if self.is_at_end() {
290                self.errors.push(LexError {
291                    span: start..self.pos,
292                    message: "unterminated escaped identifier".to_string(),
293                });
294                break;
295            }
296            let ch = self.advance();
297            if ch == b'`' {
298                // Doubled backtick = literal backtick
299                if self.peek() == Some(b'`') {
300                    self.advance();
301                    value.push('`');
302                } else {
303                    break;
304                }
305            } else {
306                value.push(ch as char);
307            }
308        }
309        self.push(Token::EscapedIdent(SmolStr::new(&value)), start);
310    }
311
312    fn lex_parameter(&mut self, start: usize) {
313        let name_start = self.pos;
314        while !self.is_at_end() && is_ident_continue(self.source[self.pos]) {
315            self.pos += 1;
316        }
317        let name = std::str::from_utf8(&self.source[name_start..self.pos]).unwrap_or("");
318        if name.is_empty() {
319            self.errors.push(LexError {
320                span: start..self.pos,
321                message: "expected parameter name after '$'".to_string(),
322            });
323        } else {
324            self.push(Token::Parameter(SmolStr::new(name)), start);
325        }
326    }
327
328    fn lex_number(&mut self, start: usize) {
329        // Consume integer part
330        while !self.is_at_end() && self.source[self.pos].is_ascii_digit() {
331            self.pos += 1;
332        }
333
334        let mut is_float = false;
335
336        // Check for fractional part
337        if self.peek() == Some(b'.')
338            && self
339                .source
340                .get(self.pos + 1)
341                .is_some_and(|c| c.is_ascii_digit())
342        {
343            is_float = true;
344            self.pos += 1; // consume '.'
345            while !self.is_at_end() && self.source[self.pos].is_ascii_digit() {
346                self.pos += 1;
347            }
348        }
349
350        // Check for exponent
351        if self.peek() == Some(b'e') || self.peek() == Some(b'E') {
352            is_float = true;
353            self.pos += 1;
354            if self.peek() == Some(b'+') || self.peek() == Some(b'-') {
355                self.pos += 1;
356            }
357            while !self.is_at_end() && self.source[self.pos].is_ascii_digit() {
358                self.pos += 1;
359            }
360        }
361
362        let text = std::str::from_utf8(&self.source[start..self.pos]).unwrap_or("0");
363
364        if is_float {
365            self.push(Token::Float(SmolStr::new(text)), start);
366        } else {
367            match text.parse::<i64>() {
368                Ok(n) => self.push(Token::Integer(n), start),
369                Err(_) => {
370                    self.errors.push(LexError {
371                        span: start..self.pos,
372                        message: format!("integer literal too large: {text}"),
373                    });
374                }
375            }
376        }
377    }
378
379    fn lex_ident_or_keyword(&mut self, start: usize) {
380        while !self.is_at_end() && is_ident_continue(self.source[self.pos]) {
381            self.pos += 1;
382        }
383
384        let text = std::str::from_utf8(&self.source[start..self.pos]).unwrap_or("");
385
386        if let Some(kw) = lookup_keyword(text) {
387            self.push(kw, start);
388        } else {
389            self.push(Token::Ident(SmolStr::new(text)), start);
390        }
391    }
392}
393
/// True for bytes that may continue an identifier: ASCII letters, digits, `_`.
fn is_ident_continue(ch: u8) -> bool {
    matches!(ch, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
397
// Unit tests exercise the lexer end-to-end on small source strings, checking
// token sequences, spans, escape handling, and error recovery.
#[cfg(test)]
mod tests {
    use super::*;

    // Lex `src` and return just the tokens, asserting there were no errors.
    fn lex(src: &str) -> Vec<Token> {
        let (tokens, errors) = Lexer::new(src).lex();
        assert!(errors.is_empty(), "unexpected lex errors: {errors:?}");
        tokens.into_iter().map(|(tok, _)| tok).collect()
    }

    // Lex `src` and return both tokens and errors, for error-path tests.
    fn lex_with_errors(src: &str) -> (Vec<Token>, Vec<LexError>) {
        let (tokens, errors) = Lexer::new(src).lex();
        let toks = tokens.into_iter().map(|(tok, _)| tok).collect();
        (toks, errors)
    }

    #[test]
    fn empty_input() {
        let tokens = lex("");
        assert_eq!(tokens, vec![Token::Eof]);
    }

    #[test]
    fn single_char_tokens() {
        let tokens = lex("( ) [ ] { } , ; : | * % ^ & ~");
        assert_eq!(
            tokens,
            vec![
                Token::LeftParen,
                Token::RightParen,
                Token::LeftBracket,
                Token::RightBracket,
                Token::LeftBrace,
                Token::RightBrace,
                Token::Comma,
                Token::Semicolon,
                Token::Colon,
                Token::Pipe,
                Token::Star,
                Token::Percent,
                Token::Caret,
                Token::Ampersand,
                Token::Tilde,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn multi_char_operators() {
        let tokens = lex("-> <- .. << >> =~ += <= >= <>");
        assert_eq!(
            tokens,
            vec![
                Token::Arrow,
                Token::LeftArrow,
                Token::DoubleDot,
                Token::ShiftLeft,
                Token::ShiftRight,
                Token::RegexMatch,
                Token::PlusEq,
                Token::Le,
                Token::Ge,
                Token::Neq,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn integer_literals() {
        let tokens = lex("0 42 123456789");
        assert_eq!(
            tokens,
            vec![
                Token::Integer(0),
                Token::Integer(42),
                Token::Integer(123456789),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn float_literals() {
        // Floats keep their exact source text rather than being parsed.
        let tokens = lex("3.14 1.0e10 2.5E-3");
        assert_eq!(
            tokens,
            vec![
                Token::Float(SmolStr::new("3.14")),
                Token::Float(SmolStr::new("1.0e10")),
                Token::Float(SmolStr::new("2.5E-3")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn string_literals() {
        let tokens = lex("'hello' \"world\"");
        assert_eq!(
            tokens,
            vec![
                Token::StringLiteral(SmolStr::new("hello")),
                Token::StringLiteral(SmolStr::new("world")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn string_escape_sequences() {
        let tokens = lex(r#"'he\'s' "tab\there""#);
        assert_eq!(
            tokens,
            vec![
                Token::StringLiteral(SmolStr::new("he's")),
                Token::StringLiteral(SmolStr::new("tab\there")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn string_doubled_quotes() {
        // SQL-style doubled quote escapes a literal quote.
        let tokens = lex("'it''s'");
        assert_eq!(
            tokens,
            vec![Token::StringLiteral(SmolStr::new("it's")), Token::Eof]
        );
    }

    #[test]
    fn identifiers() {
        let tokens = lex("foo _bar baz123");
        assert_eq!(
            tokens,
            vec![
                Token::Ident(SmolStr::new("foo")),
                Token::Ident(SmolStr::new("_bar")),
                Token::Ident(SmolStr::new("baz123")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn escaped_identifiers() {
        // Doubled backtick inside backticks escapes a literal backtick.
        let tokens = lex("`my column` `has``backtick`");
        assert_eq!(
            tokens,
            vec![
                Token::EscapedIdent(SmolStr::new("my column")),
                Token::EscapedIdent(SmolStr::new("has`backtick")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn parameters() {
        let tokens = lex("$param1 $since");
        assert_eq!(
            tokens,
            vec![
                Token::Parameter(SmolStr::new("param1")),
                Token::Parameter(SmolStr::new("since")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn keywords_case_insensitive() {
        let tokens = lex("MATCH Match match WHERE where");
        assert_eq!(
            tokens,
            vec![
                Token::Match,
                Token::Match,
                Token::Match,
                Token::Where,
                Token::Where,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn boolean_and_null() {
        let tokens = lex("TRUE false NULL");
        assert_eq!(
            tokens,
            vec![Token::True, Token::False, Token::Null, Token::Eof]
        );
    }

    #[test]
    fn line_comments() {
        let tokens = lex("MATCH // this is a comment\n(n)");
        assert_eq!(
            tokens,
            vec![
                Token::Match,
                Token::LeftParen,
                Token::Ident(SmolStr::new("n")),
                Token::RightParen,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn block_comments() {
        let tokens = lex("MATCH /* comment */ (n)");
        assert_eq!(
            tokens,
            vec![
                Token::Match,
                Token::LeftParen,
                Token::Ident(SmolStr::new("n")),
                Token::RightParen,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn full_query() {
        let tokens = lex("MATCH (n:Person) WHERE n.age > 30 RETURN n.name");
        // Check that we get the right sequence of tokens
        assert_eq!(tokens[0], Token::Match);
        assert_eq!(tokens[1], Token::LeftParen);
        assert_eq!(tokens[2], Token::Ident(SmolStr::new("n")));
        assert_eq!(tokens[3], Token::Colon);
        assert_eq!(tokens[4], Token::Ident(SmolStr::new("Person")));
        assert_eq!(tokens[5], Token::RightParen);
        assert_eq!(tokens[6], Token::Where);
        assert_eq!(tokens[7], Token::Ident(SmolStr::new("n")));
        assert_eq!(tokens[8], Token::Dot);
        assert_eq!(tokens[9], Token::Ident(SmolStr::new("age")));
        assert_eq!(tokens[10], Token::Gt);
        assert_eq!(tokens[11], Token::Integer(30));
        assert_eq!(tokens[12], Token::Return);
        assert_eq!(tokens[13], Token::Ident(SmolStr::new("n")));
        assert_eq!(tokens[14], Token::Dot);
        assert_eq!(tokens[15], Token::Ident(SmolStr::new("name")));
        assert_eq!(tokens[16], Token::Eof);
    }

    #[test]
    fn relationship_arrows() {
        let tokens = lex("(a)-[:KNOWS]->(b)<-[:LIKES]-(c)");
        assert!(tokens.contains(&Token::Arrow));
        assert!(tokens.contains(&Token::LeftArrow));
        assert!(tokens.contains(&Token::Dash));
    }

    #[test]
    fn unexpected_char_reports_error() {
        let (tokens, errors) = lex_with_errors("MATCH @invalid");
        assert!(!errors.is_empty());
        assert!(errors[0].message.contains("unexpected character"));
        // Lexing continues past the error
        assert!(tokens.len() > 1);
    }

    #[test]
    fn unterminated_string_reports_error() {
        let (_tokens, errors) = lex_with_errors("'unterminated");
        assert!(!errors.is_empty());
        assert!(errors[0].message.contains("unterminated string"));
    }

    #[test]
    fn spans_are_correct() {
        // Spans are byte offsets into the source string.
        let (tokens, _) = Lexer::new("MATCH (n)").lex();
        assert_eq!(tokens[0].1, 0..5); // MATCH
        assert_eq!(tokens[1].1, 6..7); // (
        assert_eq!(tokens[2].1, 7..8); // n
        assert_eq!(tokens[3].1, 8..9); // )
    }

    #[test]
    fn dash_vs_negative_number() {
        // In Cypher, `-` before a number is a dash (operator), not a negative sign.
        // The parser handles unary minus. The lexer always emits Dash.
        let tokens = lex("-42");
        assert_eq!(tokens[0], Token::Dash);
        assert_eq!(tokens[1], Token::Integer(42));
    }

    #[test]
    fn dot_after_integer_not_float() {
        // `n.age` — the dot is property access, not a decimal point
        let tokens = lex("n.age");
        assert_eq!(
            tokens,
            vec![
                Token::Ident(SmolStr::new("n")),
                Token::Dot,
                Token::Ident(SmolStr::new("age")),
                Token::Eof,
            ]
        );
    }
}