// kyu_parser/lexer.rs
1use smol_str::SmolStr;
2
3use crate::span::{Span, Spanned};
4use crate::token::{Token, lookup_keyword};
5
/// Lexical analysis error.
#[derive(Debug, Clone)]
pub struct LexError {
    // Byte range in the source text that the error refers to.
    pub span: Span,
    // Human-readable description of what went wrong.
    pub message: String,
}
12
/// Hand-written Cypher lexer producing a flat token stream.
///
/// Separating lex from parse gives precise byte positions for every
/// error span and avoids running combinators on individual characters.
pub struct Lexer<'src> {
    // Raw bytes of the query text; scanning is byte-oriented.
    source: &'src [u8],
    // Current byte offset into `source`.
    pos: usize,
    // Tokens produced so far, each paired with its byte span.
    tokens: Vec<Spanned<Token>>,
    // Errors collected so far; lexing continues after an error.
    errors: Vec<LexError>,
}
23
24impl<'src> Lexer<'src> {
25    pub fn new(source: &'src str) -> Self {
26        Self {
27            source: source.as_bytes(),
28            pos: 0,
29            tokens: Vec::new(),
30            errors: Vec::new(),
31        }
32    }
33
34    pub fn lex(mut self) -> (Vec<Spanned<Token>>, Vec<LexError>) {
35        while !self.is_at_end() {
36            self.skip_whitespace_and_comments();
37            if self.is_at_end() {
38                break;
39            }
40
41            let start = self.pos;
42            let ch = self.advance();
43
44            match ch {
45                b'(' => self.push(Token::LeftParen, start),
46                b')' => self.push(Token::RightParen, start),
47                b'[' => self.push(Token::LeftBracket, start),
48                b']' => self.push(Token::RightBracket, start),
49                b'{' => self.push(Token::LeftBrace, start),
50                b'}' => self.push(Token::RightBrace, start),
51                b',' => self.push(Token::Comma, start),
52                b';' => self.push(Token::Semicolon, start),
53                b':' => self.push(Token::Colon, start),
54                b'|' => self.push(Token::Pipe, start),
55                b'*' => self.push(Token::Star, start),
56                b'%' => self.push(Token::Percent, start),
57                b'^' => self.push(Token::Caret, start),
58                b'&' => self.push(Token::Ampersand, start),
59                b'~' => self.push(Token::Tilde, start),
60                b'!' => self.push(Token::Exclaim, start),
61
62                b'.' => {
63                    if self.peek() == Some(b'.') {
64                        self.advance();
65                        self.push(Token::DoubleDot, start);
66                    } else {
67                        self.push(Token::Dot, start);
68                    }
69                }
70
71                b'+' => {
72                    if self.peek() == Some(b'=') {
73                        self.advance();
74                        self.push(Token::PlusEq, start);
75                    } else {
76                        self.push(Token::Plus, start);
77                    }
78                }
79
80                b'=' => {
81                    if self.peek() == Some(b'~') {
82                        self.advance();
83                        self.push(Token::RegexMatch, start);
84                    } else {
85                        self.push(Token::Eq, start);
86                    }
87                }
88
89                b'<' => match self.peek() {
90                    Some(b'=') => {
91                        self.advance();
92                        self.push(Token::Le, start);
93                    }
94                    Some(b'>') => {
95                        self.advance();
96                        self.push(Token::Neq, start);
97                    }
98                    Some(b'<') => {
99                        self.advance();
100                        self.push(Token::ShiftLeft, start);
101                    }
102                    Some(b'-') => {
103                        self.advance();
104                        self.push(Token::LeftArrow, start);
105                    }
106                    _ => self.push(Token::Lt, start),
107                },
108
109                b'>' => match self.peek() {
110                    Some(b'=') => {
111                        self.advance();
112                        self.push(Token::Ge, start);
113                    }
114                    Some(b'>') => {
115                        self.advance();
116                        self.push(Token::ShiftRight, start);
117                    }
118                    _ => self.push(Token::Gt, start),
119                },
120
121                b'-' => {
122                    if self.peek() == Some(b'>') {
123                        self.advance();
124                        self.push(Token::Arrow, start);
125                    } else {
126                        self.push(Token::Dash, start);
127                    }
128                }
129
130                b'/' => {
131                    // Forward slash as division — comments already handled in skip_whitespace
132                    self.push(Token::Slash, start);
133                }
134
135                b'\'' | b'"' => self.lex_string(ch, start),
136
137                b'`' => self.lex_escaped_ident(start),
138
139                b'$' => self.lex_parameter(start),
140
141                b'0'..=b'9' => self.lex_number(start),
142
143                b'a'..=b'z' | b'A'..=b'Z' | b'_' => self.lex_ident_or_keyword(start),
144
145                _ => {
146                    self.errors.push(LexError {
147                        span: start..self.pos,
148                        message: format!("unexpected character '{}'", ch as char),
149                    });
150                }
151            }
152        }
153
154        self.tokens.push((Token::Eof, self.pos..self.pos));
155        (self.tokens, self.errors)
156    }
157
158    fn is_at_end(&self) -> bool {
159        self.pos >= self.source.len()
160    }
161
162    fn peek(&self) -> Option<u8> {
163        self.source.get(self.pos).copied()
164    }
165
166    fn advance(&mut self) -> u8 {
167        let ch = self.source[self.pos];
168        self.pos += 1;
169        ch
170    }
171
172    fn push(&mut self, token: Token, start: usize) {
173        self.tokens.push((token, start..self.pos));
174    }
175
176    fn skip_whitespace_and_comments(&mut self) {
177        while !self.is_at_end() {
178            let ch = self.source[self.pos];
179            match ch {
180                b' ' | b'\t' | b'\n' | b'\r' => {
181                    self.pos += 1;
182                }
183                b'/' => {
184                    if self.pos + 1 < self.source.len() {
185                        match self.source[self.pos + 1] {
186                            b'/' => {
187                                // Line comment: skip to end of line
188                                self.pos += 2;
189                                while !self.is_at_end() && self.source[self.pos] != b'\n' {
190                                    self.pos += 1;
191                                }
192                            }
193                            b'*' => {
194                                // Block comment: skip to */
195                                let start = self.pos;
196                                self.pos += 2;
197                                let mut depth = 1;
198                                while !self.is_at_end() && depth > 0 {
199                                    if self.source[self.pos] == b'*'
200                                        && self.pos + 1 < self.source.len()
201                                        && self.source[self.pos + 1] == b'/'
202                                    {
203                                        depth -= 1;
204                                        self.pos += 2;
205                                    } else if self.source[self.pos] == b'/'
206                                        && self.pos + 1 < self.source.len()
207                                        && self.source[self.pos + 1] == b'*'
208                                    {
209                                        depth += 1;
210                                        self.pos += 2;
211                                    } else {
212                                        self.pos += 1;
213                                    }
214                                }
215                                if depth > 0 {
216                                    self.errors.push(LexError {
217                                        span: start..self.pos,
218                                        message: "unterminated block comment".to_string(),
219                                    });
220                                }
221                            }
222                            _ => break,
223                        }
224                    } else {
225                        break;
226                    }
227                }
228                _ => break,
229            }
230        }
231    }
232
233    fn lex_string(&mut self, quote: u8, start: usize) {
234        let mut value = String::new();
235        loop {
236            if self.is_at_end() {
237                self.errors.push(LexError {
238                    span: start..self.pos,
239                    message: "unterminated string literal".to_string(),
240                });
241                break;
242            }
243            let ch = self.advance();
244            if ch == quote {
245                // Check for escaped quote (doubled): '' or ""
246                if self.peek() == Some(quote) {
247                    self.advance();
248                    value.push(quote as char);
249                } else {
250                    break;
251                }
252            } else if ch == b'\\' {
253                // Backslash escape
254                if self.is_at_end() {
255                    self.errors.push(LexError {
256                        span: start..self.pos,
257                        message: "unterminated string escape".to_string(),
258                    });
259                    break;
260                }
261                let esc = self.advance();
262                match esc {
263                    b'n' => value.push('\n'),
264                    b't' => value.push('\t'),
265                    b'r' => value.push('\r'),
266                    b'\\' => value.push('\\'),
267                    b'\'' => value.push('\''),
268                    b'"' => value.push('"'),
269                    b'0' => value.push('\0'),
270                    _ => {
271                        value.push('\\');
272                        value.push(esc as char);
273                    }
274                }
275            } else {
276                value.push(ch as char);
277            }
278        }
279        self.push(Token::StringLiteral(SmolStr::new(&value)), start);
280    }
281
282    fn lex_escaped_ident(&mut self, start: usize) {
283        let mut value = String::new();
284        loop {
285            if self.is_at_end() {
286                self.errors.push(LexError {
287                    span: start..self.pos,
288                    message: "unterminated escaped identifier".to_string(),
289                });
290                break;
291            }
292            let ch = self.advance();
293            if ch == b'`' {
294                // Doubled backtick = literal backtick
295                if self.peek() == Some(b'`') {
296                    self.advance();
297                    value.push('`');
298                } else {
299                    break;
300                }
301            } else {
302                value.push(ch as char);
303            }
304        }
305        self.push(Token::EscapedIdent(SmolStr::new(&value)), start);
306    }
307
308    fn lex_parameter(&mut self, start: usize) {
309        let name_start = self.pos;
310        while !self.is_at_end() && is_ident_continue(self.source[self.pos]) {
311            self.pos += 1;
312        }
313        let name = std::str::from_utf8(&self.source[name_start..self.pos]).unwrap_or("");
314        if name.is_empty() {
315            self.errors.push(LexError {
316                span: start..self.pos,
317                message: "expected parameter name after '$'".to_string(),
318            });
319        } else {
320            self.push(Token::Parameter(SmolStr::new(name)), start);
321        }
322    }
323
324    fn lex_number(&mut self, start: usize) {
325        // Consume integer part
326        while !self.is_at_end() && self.source[self.pos].is_ascii_digit() {
327            self.pos += 1;
328        }
329
330        let mut is_float = false;
331
332        // Check for fractional part
333        if self.peek() == Some(b'.')
334            && self
335                .source
336                .get(self.pos + 1)
337                .is_some_and(|c| c.is_ascii_digit())
338        {
339            is_float = true;
340            self.pos += 1; // consume '.'
341            while !self.is_at_end() && self.source[self.pos].is_ascii_digit() {
342                self.pos += 1;
343            }
344        }
345
346        // Check for exponent
347        if self.peek() == Some(b'e') || self.peek() == Some(b'E') {
348            is_float = true;
349            self.pos += 1;
350            if self.peek() == Some(b'+') || self.peek() == Some(b'-') {
351                self.pos += 1;
352            }
353            while !self.is_at_end() && self.source[self.pos].is_ascii_digit() {
354                self.pos += 1;
355            }
356        }
357
358        let text = std::str::from_utf8(&self.source[start..self.pos]).unwrap_or("0");
359
360        if is_float {
361            self.push(Token::Float(SmolStr::new(text)), start);
362        } else {
363            match text.parse::<i64>() {
364                Ok(n) => self.push(Token::Integer(n), start),
365                Err(_) => {
366                    self.errors.push(LexError {
367                        span: start..self.pos,
368                        message: format!("integer literal too large: {text}"),
369                    });
370                }
371            }
372        }
373    }
374
375    fn lex_ident_or_keyword(&mut self, start: usize) {
376        while !self.is_at_end() && is_ident_continue(self.source[self.pos]) {
377            self.pos += 1;
378        }
379
380        let text = std::str::from_utf8(&self.source[start..self.pos]).unwrap_or("");
381
382        if let Some(kw) = lookup_keyword(text) {
383            self.push(kw, start);
384        } else {
385            self.push(Token::Ident(SmolStr::new(text)), start);
386        }
387    }
388}
389
/// True for bytes that may continue an identifier or parameter name:
/// ASCII letters, ASCII digits, and underscore.
fn is_ident_continue(ch: u8) -> bool {
    matches!(ch, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
393
// Unit tests for the lexer: token shapes, literals, comments, spans, and
// error recovery.
#[cfg(test)]
mod tests {
    use super::*;

    // Lexes `src` and returns just the tokens, failing the test on any
    // lex error.
    fn lex(src: &str) -> Vec<Token> {
        let (tokens, errors) = Lexer::new(src).lex();
        assert!(errors.is_empty(), "unexpected lex errors: {errors:?}");
        tokens.into_iter().map(|(tok, _)| tok).collect()
    }

    // Lexes `src` and returns both tokens and errors, for error-path tests.
    fn lex_with_errors(src: &str) -> (Vec<Token>, Vec<LexError>) {
        let (tokens, errors) = Lexer::new(src).lex();
        let toks = tokens.into_iter().map(|(tok, _)| tok).collect();
        (toks, errors)
    }

    #[test]
    fn empty_input() {
        let tokens = lex("");
        assert_eq!(tokens, vec![Token::Eof]);
    }

    #[test]
    fn single_char_tokens() {
        let tokens = lex("( ) [ ] { } , ; : | * % ^ & ~");
        assert_eq!(
            tokens,
            vec![
                Token::LeftParen,
                Token::RightParen,
                Token::LeftBracket,
                Token::RightBracket,
                Token::LeftBrace,
                Token::RightBrace,
                Token::Comma,
                Token::Semicolon,
                Token::Colon,
                Token::Pipe,
                Token::Star,
                Token::Percent,
                Token::Caret,
                Token::Ampersand,
                Token::Tilde,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn multi_char_operators() {
        let tokens = lex("-> <- .. << >> =~ += <= >= <>");
        assert_eq!(
            tokens,
            vec![
                Token::Arrow,
                Token::LeftArrow,
                Token::DoubleDot,
                Token::ShiftLeft,
                Token::ShiftRight,
                Token::RegexMatch,
                Token::PlusEq,
                Token::Le,
                Token::Ge,
                Token::Neq,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn integer_literals() {
        let tokens = lex("0 42 123456789");
        assert_eq!(
            tokens,
            vec![
                Token::Integer(0),
                Token::Integer(42),
                Token::Integer(123456789),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn float_literals() {
        // Floats keep their source text (as SmolStr) rather than parsing to f64.
        let tokens = lex("3.14 1.0e10 2.5E-3");
        assert_eq!(
            tokens,
            vec![
                Token::Float(SmolStr::new("3.14")),
                Token::Float(SmolStr::new("1.0e10")),
                Token::Float(SmolStr::new("2.5E-3")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn string_literals() {
        let tokens = lex("'hello' \"world\"");
        assert_eq!(
            tokens,
            vec![
                Token::StringLiteral(SmolStr::new("hello")),
                Token::StringLiteral(SmolStr::new("world")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn string_escape_sequences() {
        let tokens = lex(r#"'he\'s' "tab\there""#);
        assert_eq!(
            tokens,
            vec![
                Token::StringLiteral(SmolStr::new("he's")),
                Token::StringLiteral(SmolStr::new("tab\there")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn string_doubled_quotes() {
        // SQL-style doubled quote inside a string encodes a literal quote.
        let tokens = lex("'it''s'");
        assert_eq!(
            tokens,
            vec![Token::StringLiteral(SmolStr::new("it's")), Token::Eof]
        );
    }

    #[test]
    fn identifiers() {
        let tokens = lex("foo _bar baz123");
        assert_eq!(
            tokens,
            vec![
                Token::Ident(SmolStr::new("foo")),
                Token::Ident(SmolStr::new("_bar")),
                Token::Ident(SmolStr::new("baz123")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn escaped_identifiers() {
        // Backtick-quoted identifiers; doubled backtick = literal backtick.
        let tokens = lex("`my column` `has``backtick`");
        assert_eq!(
            tokens,
            vec![
                Token::EscapedIdent(SmolStr::new("my column")),
                Token::EscapedIdent(SmolStr::new("has`backtick")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn parameters() {
        let tokens = lex("$param1 $since");
        assert_eq!(
            tokens,
            vec![
                Token::Parameter(SmolStr::new("param1")),
                Token::Parameter(SmolStr::new("since")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn keywords_case_insensitive() {
        let tokens = lex("MATCH Match match WHERE where");
        assert_eq!(
            tokens,
            vec![
                Token::Match,
                Token::Match,
                Token::Match,
                Token::Where,
                Token::Where,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn boolean_and_null() {
        let tokens = lex("TRUE false NULL");
        assert_eq!(
            tokens,
            vec![Token::True, Token::False, Token::Null, Token::Eof]
        );
    }

    #[test]
    fn line_comments() {
        let tokens = lex("MATCH // this is a comment\n(n)");
        assert_eq!(
            tokens,
            vec![
                Token::Match,
                Token::LeftParen,
                Token::Ident(SmolStr::new("n")),
                Token::RightParen,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn block_comments() {
        let tokens = lex("MATCH /* comment */ (n)");
        assert_eq!(
            tokens,
            vec![
                Token::Match,
                Token::LeftParen,
                Token::Ident(SmolStr::new("n")),
                Token::RightParen,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn full_query() {
        let tokens = lex("MATCH (n:Person) WHERE n.age > 30 RETURN n.name");
        // Check that we get the right sequence of tokens
        assert_eq!(tokens[0], Token::Match);
        assert_eq!(tokens[1], Token::LeftParen);
        assert_eq!(tokens[2], Token::Ident(SmolStr::new("n")));
        assert_eq!(tokens[3], Token::Colon);
        assert_eq!(tokens[4], Token::Ident(SmolStr::new("Person")));
        assert_eq!(tokens[5], Token::RightParen);
        assert_eq!(tokens[6], Token::Where);
        assert_eq!(tokens[7], Token::Ident(SmolStr::new("n")));
        assert_eq!(tokens[8], Token::Dot);
        assert_eq!(tokens[9], Token::Ident(SmolStr::new("age")));
        assert_eq!(tokens[10], Token::Gt);
        assert_eq!(tokens[11], Token::Integer(30));
        assert_eq!(tokens[12], Token::Return);
        assert_eq!(tokens[13], Token::Ident(SmolStr::new("n")));
        assert_eq!(tokens[14], Token::Dot);
        assert_eq!(tokens[15], Token::Ident(SmolStr::new("name")));
        assert_eq!(tokens[16], Token::Eof);
    }

    #[test]
    fn relationship_arrows() {
        let tokens = lex("(a)-[:KNOWS]->(b)<-[:LIKES]-(c)");
        assert!(tokens.contains(&Token::Arrow));
        assert!(tokens.contains(&Token::LeftArrow));
        assert!(tokens.contains(&Token::Dash));
    }

    #[test]
    fn unexpected_char_reports_error() {
        let (tokens, errors) = lex_with_errors("MATCH @invalid");
        assert!(!errors.is_empty());
        assert!(errors[0].message.contains("unexpected character"));
        // Lexing continues past the error
        assert!(tokens.len() > 1);
    }

    #[test]
    fn unterminated_string_reports_error() {
        let (_tokens, errors) = lex_with_errors("'unterminated");
        assert!(!errors.is_empty());
        assert!(errors[0].message.contains("unterminated string"));
    }

    #[test]
    fn spans_are_correct() {
        // Spans are byte ranges into the original source text.
        let (tokens, _) = Lexer::new("MATCH (n)").lex();
        assert_eq!(tokens[0].1, 0..5); // MATCH
        assert_eq!(tokens[1].1, 6..7); // (
        assert_eq!(tokens[2].1, 7..8); // n
        assert_eq!(tokens[3].1, 8..9); // )
    }

    #[test]
    fn dash_vs_negative_number() {
        // In Cypher, `-` before a number is a dash (operator), not a negative sign.
        // The parser handles unary minus. The lexer always emits Dash.
        let tokens = lex("-42");
        assert_eq!(tokens[0], Token::Dash);
        assert_eq!(tokens[1], Token::Integer(42));
    }

    #[test]
    fn dot_after_integer_not_float() {
        // `n.age` — the dot is property access, not a decimal point
        let tokens = lex("n.age");
        assert_eq!(
            tokens,
            vec![
                Token::Ident(SmolStr::new("n")),
                Token::Dot,
                Token::Ident(SmolStr::new("age")),
                Token::Eof,
            ]
        );
    }
}