1use smol_str::SmolStr;
2
3use crate::span::{Span, Spanned};
4use crate::token::{lookup_keyword, Token};
5
/// A lexical error, carrying the byte span of the offending input and a
/// human-readable message. Errors are collected rather than aborting the
/// lex, so callers receive every diagnostic in one pass.
#[derive(Debug, Clone)]
pub struct LexError {
    // Byte range in the source where the error occurred.
    pub span: Span,
    // Human-readable description of the problem.
    pub message: String,
}
12
/// A hand-written lexer that walks the UTF-8 bytes of the source string
/// and produces spanned tokens plus any lexical errors.
pub struct Lexer<'src> {
    // Raw bytes of the input; always valid UTF-8 since it is built from a
    // `&str` in `new`. All spans are byte offsets into this slice.
    source: &'src [u8],
    // Current byte offset into `source`.
    pos: usize,
    // Tokens produced so far.
    tokens: Vec<Spanned<Token>>,
    // Errors collected so far; lexing continues after an error.
    errors: Vec<LexError>,
}
23
24impl<'src> Lexer<'src> {
25 pub fn new(source: &'src str) -> Self {
26 Self {
27 source: source.as_bytes(),
28 pos: 0,
29 tokens: Vec::new(),
30 errors: Vec::new(),
31 }
32 }
33
34 pub fn lex(mut self) -> (Vec<Spanned<Token>>, Vec<LexError>) {
35 while !self.is_at_end() {
36 self.skip_whitespace_and_comments();
37 if self.is_at_end() {
38 break;
39 }
40
41 let start = self.pos;
42 let ch = self.advance();
43
44 match ch {
45 b'(' => self.push(Token::LeftParen, start),
46 b')' => self.push(Token::RightParen, start),
47 b'[' => self.push(Token::LeftBracket, start),
48 b']' => self.push(Token::RightBracket, start),
49 b'{' => self.push(Token::LeftBrace, start),
50 b'}' => self.push(Token::RightBrace, start),
51 b',' => self.push(Token::Comma, start),
52 b';' => self.push(Token::Semicolon, start),
53 b':' => self.push(Token::Colon, start),
54 b'|' => self.push(Token::Pipe, start),
55 b'*' => self.push(Token::Star, start),
56 b'%' => self.push(Token::Percent, start),
57 b'^' => self.push(Token::Caret, start),
58 b'&' => self.push(Token::Ampersand, start),
59 b'~' => self.push(Token::Tilde, start),
60 b'!' => self.push(Token::Exclaim, start),
61
62 b'.' => {
63 if self.peek() == Some(b'.') {
64 self.advance();
65 self.push(Token::DoubleDot, start);
66 } else {
67 self.push(Token::Dot, start);
68 }
69 }
70
71 b'+' => {
72 if self.peek() == Some(b'=') {
73 self.advance();
74 self.push(Token::PlusEq, start);
75 } else {
76 self.push(Token::Plus, start);
77 }
78 }
79
80 b'=' => {
81 if self.peek() == Some(b'~') {
82 self.advance();
83 self.push(Token::RegexMatch, start);
84 } else {
85 self.push(Token::Eq, start);
86 }
87 }
88
89 b'<' => {
90 match self.peek() {
91 Some(b'=') => {
92 self.advance();
93 self.push(Token::Le, start);
94 }
95 Some(b'>') => {
96 self.advance();
97 self.push(Token::Neq, start);
98 }
99 Some(b'<') => {
100 self.advance();
101 self.push(Token::ShiftLeft, start);
102 }
103 Some(b'-') => {
104 self.advance();
105 self.push(Token::LeftArrow, start);
106 }
107 _ => self.push(Token::Lt, start),
108 }
109 }
110
111 b'>' => {
112 match self.peek() {
113 Some(b'=') => {
114 self.advance();
115 self.push(Token::Ge, start);
116 }
117 Some(b'>') => {
118 self.advance();
119 self.push(Token::ShiftRight, start);
120 }
121 _ => self.push(Token::Gt, start),
122 }
123 }
124
125 b'-' => {
126 if self.peek() == Some(b'>') {
127 self.advance();
128 self.push(Token::Arrow, start);
129 } else {
130 self.push(Token::Dash, start);
131 }
132 }
133
134 b'/' => {
135 self.push(Token::Slash, start);
137 }
138
139 b'\'' | b'"' => self.lex_string(ch, start),
140
141 b'`' => self.lex_escaped_ident(start),
142
143 b'$' => self.lex_parameter(start),
144
145 b'0'..=b'9' => self.lex_number(start),
146
147 b'a'..=b'z' | b'A'..=b'Z' | b'_' => self.lex_ident_or_keyword(start),
148
149 _ => {
150 self.errors.push(LexError {
151 span: start..self.pos,
152 message: format!("unexpected character '{}'", ch as char),
153 });
154 }
155 }
156 }
157
158 self.tokens.push((Token::Eof, self.pos..self.pos));
159 (self.tokens, self.errors)
160 }
161
162 fn is_at_end(&self) -> bool {
163 self.pos >= self.source.len()
164 }
165
166 fn peek(&self) -> Option<u8> {
167 self.source.get(self.pos).copied()
168 }
169
170 fn advance(&mut self) -> u8 {
171 let ch = self.source[self.pos];
172 self.pos += 1;
173 ch
174 }
175
176 fn push(&mut self, token: Token, start: usize) {
177 self.tokens.push((token, start..self.pos));
178 }
179
180 fn skip_whitespace_and_comments(&mut self) {
181 while !self.is_at_end() {
182 let ch = self.source[self.pos];
183 match ch {
184 b' ' | b'\t' | b'\n' | b'\r' => {
185 self.pos += 1;
186 }
187 b'/' => {
188 if self.pos + 1 < self.source.len() {
189 match self.source[self.pos + 1] {
190 b'/' => {
191 self.pos += 2;
193 while !self.is_at_end() && self.source[self.pos] != b'\n' {
194 self.pos += 1;
195 }
196 }
197 b'*' => {
198 let start = self.pos;
200 self.pos += 2;
201 let mut depth = 1;
202 while !self.is_at_end() && depth > 0 {
203 if self.source[self.pos] == b'*'
204 && self.pos + 1 < self.source.len()
205 && self.source[self.pos + 1] == b'/'
206 {
207 depth -= 1;
208 self.pos += 2;
209 } else if self.source[self.pos] == b'/'
210 && self.pos + 1 < self.source.len()
211 && self.source[self.pos + 1] == b'*'
212 {
213 depth += 1;
214 self.pos += 2;
215 } else {
216 self.pos += 1;
217 }
218 }
219 if depth > 0 {
220 self.errors.push(LexError {
221 span: start..self.pos,
222 message: "unterminated block comment".to_string(),
223 });
224 }
225 }
226 _ => break,
227 }
228 } else {
229 break;
230 }
231 }
232 _ => break,
233 }
234 }
235 }
236
237 fn lex_string(&mut self, quote: u8, start: usize) {
238 let mut value = String::new();
239 loop {
240 if self.is_at_end() {
241 self.errors.push(LexError {
242 span: start..self.pos,
243 message: "unterminated string literal".to_string(),
244 });
245 break;
246 }
247 let ch = self.advance();
248 if ch == quote {
249 if self.peek() == Some(quote) {
251 self.advance();
252 value.push(quote as char);
253 } else {
254 break;
255 }
256 } else if ch == b'\\' {
257 if self.is_at_end() {
259 self.errors.push(LexError {
260 span: start..self.pos,
261 message: "unterminated string escape".to_string(),
262 });
263 break;
264 }
265 let esc = self.advance();
266 match esc {
267 b'n' => value.push('\n'),
268 b't' => value.push('\t'),
269 b'r' => value.push('\r'),
270 b'\\' => value.push('\\'),
271 b'\'' => value.push('\''),
272 b'"' => value.push('"'),
273 b'0' => value.push('\0'),
274 _ => {
275 value.push('\\');
276 value.push(esc as char);
277 }
278 }
279 } else {
280 value.push(ch as char);
281 }
282 }
283 self.push(Token::StringLiteral(SmolStr::new(&value)), start);
284 }
285
286 fn lex_escaped_ident(&mut self, start: usize) {
287 let mut value = String::new();
288 loop {
289 if self.is_at_end() {
290 self.errors.push(LexError {
291 span: start..self.pos,
292 message: "unterminated escaped identifier".to_string(),
293 });
294 break;
295 }
296 let ch = self.advance();
297 if ch == b'`' {
298 if self.peek() == Some(b'`') {
300 self.advance();
301 value.push('`');
302 } else {
303 break;
304 }
305 } else {
306 value.push(ch as char);
307 }
308 }
309 self.push(Token::EscapedIdent(SmolStr::new(&value)), start);
310 }
311
312 fn lex_parameter(&mut self, start: usize) {
313 let name_start = self.pos;
314 while !self.is_at_end() && is_ident_continue(self.source[self.pos]) {
315 self.pos += 1;
316 }
317 let name = std::str::from_utf8(&self.source[name_start..self.pos]).unwrap_or("");
318 if name.is_empty() {
319 self.errors.push(LexError {
320 span: start..self.pos,
321 message: "expected parameter name after '$'".to_string(),
322 });
323 } else {
324 self.push(Token::Parameter(SmolStr::new(name)), start);
325 }
326 }
327
328 fn lex_number(&mut self, start: usize) {
329 while !self.is_at_end() && self.source[self.pos].is_ascii_digit() {
331 self.pos += 1;
332 }
333
334 let mut is_float = false;
335
336 if self.peek() == Some(b'.')
338 && self
339 .source
340 .get(self.pos + 1)
341 .is_some_and(|c| c.is_ascii_digit())
342 {
343 is_float = true;
344 self.pos += 1; while !self.is_at_end() && self.source[self.pos].is_ascii_digit() {
346 self.pos += 1;
347 }
348 }
349
350 if self.peek() == Some(b'e') || self.peek() == Some(b'E') {
352 is_float = true;
353 self.pos += 1;
354 if self.peek() == Some(b'+') || self.peek() == Some(b'-') {
355 self.pos += 1;
356 }
357 while !self.is_at_end() && self.source[self.pos].is_ascii_digit() {
358 self.pos += 1;
359 }
360 }
361
362 let text = std::str::from_utf8(&self.source[start..self.pos]).unwrap_or("0");
363
364 if is_float {
365 self.push(Token::Float(SmolStr::new(text)), start);
366 } else {
367 match text.parse::<i64>() {
368 Ok(n) => self.push(Token::Integer(n), start),
369 Err(_) => {
370 self.errors.push(LexError {
371 span: start..self.pos,
372 message: format!("integer literal too large: {text}"),
373 });
374 }
375 }
376 }
377 }
378
379 fn lex_ident_or_keyword(&mut self, start: usize) {
380 while !self.is_at_end() && is_ident_continue(self.source[self.pos]) {
381 self.pos += 1;
382 }
383
384 let text = std::str::from_utf8(&self.source[start..self.pos]).unwrap_or("");
385
386 if let Some(kw) = lookup_keyword(text) {
387 self.push(kw, start);
388 } else {
389 self.push(Token::Ident(SmolStr::new(text)), start);
390 }
391 }
392}
393
/// True for bytes that may appear after the first character of an
/// identifier: ASCII letters, ASCII digits, and underscore.
fn is_ident_continue(ch: u8) -> bool {
    matches!(ch, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
397
#[cfg(test)]
mod tests {
    use super::*;

    /// Lexes `src`, asserts that no errors occurred, and returns the
    /// tokens (spans discarded).
    fn lex(src: &str) -> Vec<Token> {
        let (tokens, errors) = Lexer::new(src).lex();
        assert!(errors.is_empty(), "unexpected lex errors: {errors:?}");
        tokens.into_iter().map(|(tok, _)| tok).collect()
    }

    /// Lexes `src` and returns both the tokens (spans discarded) and any
    /// errors, for tests that exercise error recovery.
    fn lex_with_errors(src: &str) -> (Vec<Token>, Vec<LexError>) {
        let (tokens, errors) = Lexer::new(src).lex();
        let toks = tokens.into_iter().map(|(tok, _)| tok).collect();
        (toks, errors)
    }

    #[test]
    fn empty_input() {
        // Even empty input yields the EOF sentinel.
        let tokens = lex("");
        assert_eq!(tokens, vec![Token::Eof]);
    }

    #[test]
    fn single_char_tokens() {
        let tokens = lex("( ) [ ] { } , ; : | * % ^ & ~");
        assert_eq!(
            tokens,
            vec![
                Token::LeftParen,
                Token::RightParen,
                Token::LeftBracket,
                Token::RightBracket,
                Token::LeftBrace,
                Token::RightBrace,
                Token::Comma,
                Token::Semicolon,
                Token::Colon,
                Token::Pipe,
                Token::Star,
                Token::Percent,
                Token::Caret,
                Token::Ampersand,
                Token::Tilde,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn multi_char_operators() {
        let tokens = lex("-> <- .. << >> =~ += <= >= <>");
        assert_eq!(
            tokens,
            vec![
                Token::Arrow,
                Token::LeftArrow,
                Token::DoubleDot,
                Token::ShiftLeft,
                Token::ShiftRight,
                Token::RegexMatch,
                Token::PlusEq,
                Token::Le,
                Token::Ge,
                Token::Neq,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn integer_literals() {
        let tokens = lex("0 42 123456789");
        assert_eq!(
            tokens,
            vec![
                Token::Integer(0),
                Token::Integer(42),
                Token::Integer(123456789),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn float_literals() {
        // Floats keep their original spelling (exponent case, sign).
        let tokens = lex("3.14 1.0e10 2.5E-3");
        assert_eq!(
            tokens,
            vec![
                Token::Float(SmolStr::new("3.14")),
                Token::Float(SmolStr::new("1.0e10")),
                Token::Float(SmolStr::new("2.5E-3")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn string_literals() {
        // Both quote styles delimit strings.
        let tokens = lex("'hello' \"world\"");
        assert_eq!(
            tokens,
            vec![
                Token::StringLiteral(SmolStr::new("hello")),
                Token::StringLiteral(SmolStr::new("world")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn string_escape_sequences() {
        let tokens = lex(r#"'he\'s' "tab\there""#);
        assert_eq!(
            tokens,
            vec![
                Token::StringLiteral(SmolStr::new("he's")),
                Token::StringLiteral(SmolStr::new("tab\there")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn string_doubled_quotes() {
        // SQL-style doubled quote inside a string escapes itself.
        let tokens = lex("'it''s'");
        assert_eq!(
            tokens,
            vec![Token::StringLiteral(SmolStr::new("it's")), Token::Eof]
        );
    }

    #[test]
    fn identifiers() {
        let tokens = lex("foo _bar baz123");
        assert_eq!(
            tokens,
            vec![
                Token::Ident(SmolStr::new("foo")),
                Token::Ident(SmolStr::new("_bar")),
                Token::Ident(SmolStr::new("baz123")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn escaped_identifiers() {
        // Backticks delimit; a doubled backtick is a literal one.
        let tokens = lex("`my column` `has``backtick`");
        assert_eq!(
            tokens,
            vec![
                Token::EscapedIdent(SmolStr::new("my column")),
                Token::EscapedIdent(SmolStr::new("has`backtick")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn parameters() {
        let tokens = lex("$param1 $since");
        assert_eq!(
            tokens,
            vec![
                Token::Parameter(SmolStr::new("param1")),
                Token::Parameter(SmolStr::new("since")),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn keywords_case_insensitive() {
        let tokens = lex("MATCH Match match WHERE where");
        assert_eq!(
            tokens,
            vec![
                Token::Match,
                Token::Match,
                Token::Match,
                Token::Where,
                Token::Where,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn boolean_and_null() {
        let tokens = lex("TRUE false NULL");
        assert_eq!(
            tokens,
            vec![Token::True, Token::False, Token::Null, Token::Eof]
        );
    }

    #[test]
    fn line_comments() {
        let tokens = lex("MATCH // this is a comment\n(n)");
        assert_eq!(
            tokens,
            vec![
                Token::Match,
                Token::LeftParen,
                Token::Ident(SmolStr::new("n")),
                Token::RightParen,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn block_comments() {
        let tokens = lex("MATCH /* comment */ (n)");
        assert_eq!(
            tokens,
            vec![
                Token::Match,
                Token::LeftParen,
                Token::Ident(SmolStr::new("n")),
                Token::RightParen,
                Token::Eof,
            ]
        );
    }

    #[test]
    fn full_query() {
        let tokens = lex("MATCH (n:Person) WHERE n.age > 30 RETURN n.name");
        assert_eq!(tokens[0], Token::Match);
        assert_eq!(tokens[1], Token::LeftParen);
        assert_eq!(tokens[2], Token::Ident(SmolStr::new("n")));
        assert_eq!(tokens[3], Token::Colon);
        assert_eq!(tokens[4], Token::Ident(SmolStr::new("Person")));
        assert_eq!(tokens[5], Token::RightParen);
        assert_eq!(tokens[6], Token::Where);
        assert_eq!(tokens[7], Token::Ident(SmolStr::new("n")));
        assert_eq!(tokens[8], Token::Dot);
        assert_eq!(tokens[9], Token::Ident(SmolStr::new("age")));
        assert_eq!(tokens[10], Token::Gt);
        assert_eq!(tokens[11], Token::Integer(30));
        assert_eq!(tokens[12], Token::Return);
        assert_eq!(tokens[13], Token::Ident(SmolStr::new("n")));
        assert_eq!(tokens[14], Token::Dot);
        assert_eq!(tokens[15], Token::Ident(SmolStr::new("name")));
        assert_eq!(tokens[16], Token::Eof);
    }

    #[test]
    fn relationship_arrows() {
        let tokens = lex("(a)-[:KNOWS]->(b)<-[:LIKES]-(c)");
        assert!(tokens.contains(&Token::Arrow));
        assert!(tokens.contains(&Token::LeftArrow));
        assert!(tokens.contains(&Token::Dash));
    }

    #[test]
    fn unexpected_char_reports_error() {
        let (tokens, errors) = lex_with_errors("MATCH @invalid");
        assert!(!errors.is_empty());
        assert!(errors[0].message.contains("unexpected character"));
        // Lexing recovered and kept producing tokens after the error.
        assert!(tokens.len() > 1);
    }

    #[test]
    fn unterminated_string_reports_error() {
        let (_tokens, errors) = lex_with_errors("'unterminated");
        assert!(!errors.is_empty());
        assert!(errors[0].message.contains("unterminated string"));
    }

    #[test]
    fn spans_are_correct() {
        let (tokens, _) = Lexer::new("MATCH (n)").lex();
        assert_eq!(tokens[0].1, 0..5);
        assert_eq!(tokens[1].1, 6..7);
        assert_eq!(tokens[2].1, 7..8);
        assert_eq!(tokens[3].1, 8..9);
    }

    #[test]
    fn dash_vs_negative_number() {
        // The lexer never folds a leading `-` into the number; the parser
        // handles unary minus.
        let tokens = lex("-42");
        assert_eq!(tokens[0], Token::Dash);
        assert_eq!(tokens[1], Token::Integer(42));
    }

    #[test]
    fn dot_after_integer_not_float() {
        let tokens = lex("n.age");
        assert_eq!(
            tokens,
            vec![
                Token::Ident(SmolStr::new("n")),
                Token::Dot,
                Token::Ident(SmolStr::new("age")),
                Token::Eof,
            ]
        );
    }
}