1use smol_str::SmolStr;
2
3use crate::span::{Span, Spanned};
4use crate::token::{Token, lookup_keyword};
5
/// A diagnostic recorded while lexing.
///
/// The lexer does not stop at the first problem: it records one of these per
/// issue and keeps scanning, so callers receive a best-effort token stream
/// together with every error found.
#[derive(Debug, Clone)]
pub struct LexError {
    /// Byte range in the source text where the problem occurred.
    pub span: Span,
    /// Human-readable description of the problem.
    pub message: String,
}
12
/// A hand-written, byte-oriented lexer.
///
/// Construct with `Lexer::new` and consume with `Lexer::lex`, which returns
/// all tokens plus any errors encountered along the way.
pub struct Lexer<'src> {
    /// The input as raw bytes; all spans are byte offsets into this buffer.
    source: &'src [u8],
    /// Current byte offset into `source`.
    pos: usize,
    /// Tokens produced so far, each paired with its span.
    tokens: Vec<Spanned<Token>>,
    /// Diagnostics collected so far.
    errors: Vec<LexError>,
}
23
24impl<'src> Lexer<'src> {
25 pub fn new(source: &'src str) -> Self {
26 Self {
27 source: source.as_bytes(),
28 pos: 0,
29 tokens: Vec::new(),
30 errors: Vec::new(),
31 }
32 }
33
34 pub fn lex(mut self) -> (Vec<Spanned<Token>>, Vec<LexError>) {
35 while !self.is_at_end() {
36 self.skip_whitespace_and_comments();
37 if self.is_at_end() {
38 break;
39 }
40
41 let start = self.pos;
42 let ch = self.advance();
43
44 match ch {
45 b'(' => self.push(Token::LeftParen, start),
46 b')' => self.push(Token::RightParen, start),
47 b'[' => self.push(Token::LeftBracket, start),
48 b']' => self.push(Token::RightBracket, start),
49 b'{' => self.push(Token::LeftBrace, start),
50 b'}' => self.push(Token::RightBrace, start),
51 b',' => self.push(Token::Comma, start),
52 b';' => self.push(Token::Semicolon, start),
53 b':' => self.push(Token::Colon, start),
54 b'|' => self.push(Token::Pipe, start),
55 b'*' => self.push(Token::Star, start),
56 b'%' => self.push(Token::Percent, start),
57 b'^' => self.push(Token::Caret, start),
58 b'&' => self.push(Token::Ampersand, start),
59 b'~' => self.push(Token::Tilde, start),
60 b'!' => self.push(Token::Exclaim, start),
61
62 b'.' => {
63 if self.peek() == Some(b'.') {
64 self.advance();
65 self.push(Token::DoubleDot, start);
66 } else {
67 self.push(Token::Dot, start);
68 }
69 }
70
71 b'+' => {
72 if self.peek() == Some(b'=') {
73 self.advance();
74 self.push(Token::PlusEq, start);
75 } else {
76 self.push(Token::Plus, start);
77 }
78 }
79
80 b'=' => {
81 if self.peek() == Some(b'~') {
82 self.advance();
83 self.push(Token::RegexMatch, start);
84 } else {
85 self.push(Token::Eq, start);
86 }
87 }
88
89 b'<' => match self.peek() {
90 Some(b'=') => {
91 self.advance();
92 self.push(Token::Le, start);
93 }
94 Some(b'>') => {
95 self.advance();
96 self.push(Token::Neq, start);
97 }
98 Some(b'<') => {
99 self.advance();
100 self.push(Token::ShiftLeft, start);
101 }
102 Some(b'-') => {
103 self.advance();
104 self.push(Token::LeftArrow, start);
105 }
106 _ => self.push(Token::Lt, start),
107 },
108
109 b'>' => match self.peek() {
110 Some(b'=') => {
111 self.advance();
112 self.push(Token::Ge, start);
113 }
114 Some(b'>') => {
115 self.advance();
116 self.push(Token::ShiftRight, start);
117 }
118 _ => self.push(Token::Gt, start),
119 },
120
121 b'-' => {
122 if self.peek() == Some(b'>') {
123 self.advance();
124 self.push(Token::Arrow, start);
125 } else {
126 self.push(Token::Dash, start);
127 }
128 }
129
130 b'/' => {
131 self.push(Token::Slash, start);
133 }
134
135 b'\'' | b'"' => self.lex_string(ch, start),
136
137 b'`' => self.lex_escaped_ident(start),
138
139 b'$' => self.lex_parameter(start),
140
141 b'0'..=b'9' => self.lex_number(start),
142
143 b'a'..=b'z' | b'A'..=b'Z' | b'_' => self.lex_ident_or_keyword(start),
144
145 _ => {
146 self.errors.push(LexError {
147 span: start..self.pos,
148 message: format!("unexpected character '{}'", ch as char),
149 });
150 }
151 }
152 }
153
154 self.tokens.push((Token::Eof, self.pos..self.pos));
155 (self.tokens, self.errors)
156 }
157
158 fn is_at_end(&self) -> bool {
159 self.pos >= self.source.len()
160 }
161
162 fn peek(&self) -> Option<u8> {
163 self.source.get(self.pos).copied()
164 }
165
166 fn advance(&mut self) -> u8 {
167 let ch = self.source[self.pos];
168 self.pos += 1;
169 ch
170 }
171
172 fn push(&mut self, token: Token, start: usize) {
173 self.tokens.push((token, start..self.pos));
174 }
175
176 fn skip_whitespace_and_comments(&mut self) {
177 while !self.is_at_end() {
178 let ch = self.source[self.pos];
179 match ch {
180 b' ' | b'\t' | b'\n' | b'\r' => {
181 self.pos += 1;
182 }
183 b'/' => {
184 if self.pos + 1 < self.source.len() {
185 match self.source[self.pos + 1] {
186 b'/' => {
187 self.pos += 2;
189 while !self.is_at_end() && self.source[self.pos] != b'\n' {
190 self.pos += 1;
191 }
192 }
193 b'*' => {
194 let start = self.pos;
196 self.pos += 2;
197 let mut depth = 1;
198 while !self.is_at_end() && depth > 0 {
199 if self.source[self.pos] == b'*'
200 && self.pos + 1 < self.source.len()
201 && self.source[self.pos + 1] == b'/'
202 {
203 depth -= 1;
204 self.pos += 2;
205 } else if self.source[self.pos] == b'/'
206 && self.pos + 1 < self.source.len()
207 && self.source[self.pos + 1] == b'*'
208 {
209 depth += 1;
210 self.pos += 2;
211 } else {
212 self.pos += 1;
213 }
214 }
215 if depth > 0 {
216 self.errors.push(LexError {
217 span: start..self.pos,
218 message: "unterminated block comment".to_string(),
219 });
220 }
221 }
222 _ => break,
223 }
224 } else {
225 break;
226 }
227 }
228 _ => break,
229 }
230 }
231 }
232
233 fn lex_string(&mut self, quote: u8, start: usize) {
234 let mut value = String::new();
235 loop {
236 if self.is_at_end() {
237 self.errors.push(LexError {
238 span: start..self.pos,
239 message: "unterminated string literal".to_string(),
240 });
241 break;
242 }
243 let ch = self.advance();
244 if ch == quote {
245 if self.peek() == Some(quote) {
247 self.advance();
248 value.push(quote as char);
249 } else {
250 break;
251 }
252 } else if ch == b'\\' {
253 if self.is_at_end() {
255 self.errors.push(LexError {
256 span: start..self.pos,
257 message: "unterminated string escape".to_string(),
258 });
259 break;
260 }
261 let esc = self.advance();
262 match esc {
263 b'n' => value.push('\n'),
264 b't' => value.push('\t'),
265 b'r' => value.push('\r'),
266 b'\\' => value.push('\\'),
267 b'\'' => value.push('\''),
268 b'"' => value.push('"'),
269 b'0' => value.push('\0'),
270 _ => {
271 value.push('\\');
272 value.push(esc as char);
273 }
274 }
275 } else {
276 value.push(ch as char);
277 }
278 }
279 self.push(Token::StringLiteral(SmolStr::new(&value)), start);
280 }
281
282 fn lex_escaped_ident(&mut self, start: usize) {
283 let mut value = String::new();
284 loop {
285 if self.is_at_end() {
286 self.errors.push(LexError {
287 span: start..self.pos,
288 message: "unterminated escaped identifier".to_string(),
289 });
290 break;
291 }
292 let ch = self.advance();
293 if ch == b'`' {
294 if self.peek() == Some(b'`') {
296 self.advance();
297 value.push('`');
298 } else {
299 break;
300 }
301 } else {
302 value.push(ch as char);
303 }
304 }
305 self.push(Token::EscapedIdent(SmolStr::new(&value)), start);
306 }
307
308 fn lex_parameter(&mut self, start: usize) {
309 let name_start = self.pos;
310 while !self.is_at_end() && is_ident_continue(self.source[self.pos]) {
311 self.pos += 1;
312 }
313 let name = std::str::from_utf8(&self.source[name_start..self.pos]).unwrap_or("");
314 if name.is_empty() {
315 self.errors.push(LexError {
316 span: start..self.pos,
317 message: "expected parameter name after '$'".to_string(),
318 });
319 } else {
320 self.push(Token::Parameter(SmolStr::new(name)), start);
321 }
322 }
323
324 fn lex_number(&mut self, start: usize) {
325 while !self.is_at_end() && self.source[self.pos].is_ascii_digit() {
327 self.pos += 1;
328 }
329
330 let mut is_float = false;
331
332 if self.peek() == Some(b'.')
334 && self
335 .source
336 .get(self.pos + 1)
337 .is_some_and(|c| c.is_ascii_digit())
338 {
339 is_float = true;
340 self.pos += 1; while !self.is_at_end() && self.source[self.pos].is_ascii_digit() {
342 self.pos += 1;
343 }
344 }
345
346 if self.peek() == Some(b'e') || self.peek() == Some(b'E') {
348 is_float = true;
349 self.pos += 1;
350 if self.peek() == Some(b'+') || self.peek() == Some(b'-') {
351 self.pos += 1;
352 }
353 while !self.is_at_end() && self.source[self.pos].is_ascii_digit() {
354 self.pos += 1;
355 }
356 }
357
358 let text = std::str::from_utf8(&self.source[start..self.pos]).unwrap_or("0");
359
360 if is_float {
361 self.push(Token::Float(SmolStr::new(text)), start);
362 } else {
363 match text.parse::<i64>() {
364 Ok(n) => self.push(Token::Integer(n), start),
365 Err(_) => {
366 self.errors.push(LexError {
367 span: start..self.pos,
368 message: format!("integer literal too large: {text}"),
369 });
370 }
371 }
372 }
373 }
374
375 fn lex_ident_or_keyword(&mut self, start: usize) {
376 while !self.is_at_end() && is_ident_continue(self.source[self.pos]) {
377 self.pos += 1;
378 }
379
380 let text = std::str::from_utf8(&self.source[start..self.pos]).unwrap_or("");
381
382 if let Some(kw) = lookup_keyword(text) {
383 self.push(kw, start);
384 } else {
385 self.push(Token::Ident(SmolStr::new(text)), start);
386 }
387 }
388}
389
/// True for bytes that may appear after the first character of an
/// identifier: ASCII letters, ASCII digits, and underscore.
fn is_ident_continue(ch: u8) -> bool {
    matches!(ch, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// Lexes `src` and returns the token kinds (spans dropped) plus any
    /// errors the lexer collected.
    fn tokens_and_errors(src: &str) -> (Vec<Token>, Vec<LexError>) {
        let (spanned, errors) = Lexer::new(src).lex();
        let tokens = spanned.into_iter().map(|(tok, _)| tok).collect();
        (tokens, errors)
    }

    /// Lexes `src`, panicking if the lexer reported any errors.
    fn tokens_of(src: &str) -> Vec<Token> {
        let (tokens, errors) = tokens_and_errors(src);
        assert!(errors.is_empty(), "unexpected lex errors: {errors:?}");
        tokens
    }

    #[test]
    fn empty_input() {
        assert_eq!(tokens_of(""), [Token::Eof]);
    }

    #[test]
    fn single_char_tokens() {
        assert_eq!(
            tokens_of("( ) [ ] { } , ; : | * % ^ & ~"),
            [
                Token::LeftParen,
                Token::RightParen,
                Token::LeftBracket,
                Token::RightBracket,
                Token::LeftBrace,
                Token::RightBrace,
                Token::Comma,
                Token::Semicolon,
                Token::Colon,
                Token::Pipe,
                Token::Star,
                Token::Percent,
                Token::Caret,
                Token::Ampersand,
                Token::Tilde,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn multi_char_operators() {
        assert_eq!(
            tokens_of("-> <- .. << >> =~ += <= >= <>"),
            [
                Token::Arrow,
                Token::LeftArrow,
                Token::DoubleDot,
                Token::ShiftLeft,
                Token::ShiftRight,
                Token::RegexMatch,
                Token::PlusEq,
                Token::Le,
                Token::Ge,
                Token::Neq,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn integer_literals() {
        assert_eq!(
            tokens_of("0 42 123456789"),
            [
                Token::Integer(0),
                Token::Integer(42),
                Token::Integer(123456789),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn float_literals() {
        assert_eq!(
            tokens_of("3.14 1.0e10 2.5E-3"),
            [
                Token::Float(SmolStr::new("3.14")),
                Token::Float(SmolStr::new("1.0e10")),
                Token::Float(SmolStr::new("2.5E-3")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn string_literals() {
        assert_eq!(
            tokens_of("'hello' \"world\""),
            [
                Token::StringLiteral(SmolStr::new("hello")),
                Token::StringLiteral(SmolStr::new("world")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn string_escape_sequences() {
        assert_eq!(
            tokens_of(r#"'he\'s' "tab\there""#),
            [
                Token::StringLiteral(SmolStr::new("he's")),
                Token::StringLiteral(SmolStr::new("tab\there")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn string_doubled_quotes() {
        // SQL-style `''` inside a single-quoted string is one literal quote.
        assert_eq!(
            tokens_of("'it''s'"),
            [Token::StringLiteral(SmolStr::new("it's")), Token::Eof]
        );
    }

    #[test]
    fn identifiers() {
        assert_eq!(
            tokens_of("foo _bar baz123"),
            [
                Token::Ident(SmolStr::new("foo")),
                Token::Ident(SmolStr::new("_bar")),
                Token::Ident(SmolStr::new("baz123")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn escaped_identifiers() {
        assert_eq!(
            tokens_of("`my column` `has``backtick`"),
            [
                Token::EscapedIdent(SmolStr::new("my column")),
                Token::EscapedIdent(SmolStr::new("has`backtick")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn parameters() {
        assert_eq!(
            tokens_of("$param1 $since"),
            [
                Token::Parameter(SmolStr::new("param1")),
                Token::Parameter(SmolStr::new("since")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn keywords_case_insensitive() {
        assert_eq!(
            tokens_of("MATCH Match match WHERE where"),
            [
                Token::Match,
                Token::Match,
                Token::Match,
                Token::Where,
                Token::Where,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn boolean_and_null() {
        assert_eq!(
            tokens_of("TRUE false NULL"),
            [Token::True, Token::False, Token::Null, Token::Eof]
        );
    }

    #[test]
    fn line_comments() {
        assert_eq!(
            tokens_of("MATCH // this is a comment\n(n)"),
            [
                Token::Match,
                Token::LeftParen,
                Token::Ident(SmolStr::new("n")),
                Token::RightParen,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn block_comments() {
        assert_eq!(
            tokens_of("MATCH /* comment */ (n)"),
            [
                Token::Match,
                Token::LeftParen,
                Token::Ident(SmolStr::new("n")),
                Token::RightParen,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn full_query() {
        assert_eq!(
            tokens_of("MATCH (n:Person) WHERE n.age > 30 RETURN n.name"),
            [
                Token::Match,
                Token::LeftParen,
                Token::Ident(SmolStr::new("n")),
                Token::Colon,
                Token::Ident(SmolStr::new("Person")),
                Token::RightParen,
                Token::Where,
                Token::Ident(SmolStr::new("n")),
                Token::Dot,
                Token::Ident(SmolStr::new("age")),
                Token::Gt,
                Token::Integer(30),
                Token::Return,
                Token::Ident(SmolStr::new("n")),
                Token::Dot,
                Token::Ident(SmolStr::new("name")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn relationship_arrows() {
        let tokens = tokens_of("(a)-[:KNOWS]->(b)<-[:LIKES]-(c)");
        for expected in [Token::Arrow, Token::LeftArrow, Token::Dash] {
            assert!(tokens.contains(&expected), "missing {expected:?}");
        }
    }

    #[test]
    fn unexpected_char_reports_error() {
        let (tokens, errors) = tokens_and_errors("MATCH @invalid");
        assert!(!errors.is_empty());
        assert!(errors[0].message.contains("unexpected character"));
        // Lexing keeps going after the bad character.
        assert!(tokens.len() > 1);
    }

    #[test]
    fn unterminated_string_reports_error() {
        let (_tokens, errors) = tokens_and_errors("'unterminated");
        assert!(!errors.is_empty());
        assert!(errors[0].message.contains("unterminated string"));
    }

    #[test]
    fn spans_are_correct() {
        let (tokens, _) = Lexer::new("MATCH (n)").lex();
        let spans: Vec<_> = tokens.iter().take(4).map(|(_, span)| span.clone()).collect();
        assert_eq!(spans, [0..5, 6..7, 7..8, 8..9]);
    }

    #[test]
    fn dash_vs_negative_number() {
        // The sign is a separate token; combining it is the parser's job.
        assert_eq!(
            tokens_of("-42"),
            [Token::Dash, Token::Integer(42), Token::Eof]
        );
    }

    #[test]
    fn dot_after_integer_not_float() {
        assert_eq!(
            tokens_of("n.age"),
            [
                Token::Ident(SmolStr::new("n")),
                Token::Dot,
                Token::Ident(SmolStr::new("age")),
                Token::Eof,
            ]
        );
    }
}