1use crate::error::ParseError;
2
/// Lexical tokens produced by [`tokenize`].
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // --- Keywords (matched case-insensitively by `keyword`) ---
    Match,
    Optional,
    Create,
    Merge,
    Delete,
    Detach,
    Set,
    Remove,
    Unwind,
    Union,
    All,
    Where,
    Return,
    With,
    As,
    And,
    Or,
    Not,
    Is,
    Null,
    True,
    False,
    Order,
    By,
    Limit,
    Skip,
    Asc,
    Desc,
    Starts,
    Ends,
    Contains,
    In,
    Case,
    When,
    Then,
    Else,
    End,
    // --- Aggregate-function keywords ---
    Count,
    Sum,
    Avg,
    Min,
    Max,
    Collect,

    // --- Identifiers and literals, carrying their parsed payloads ---
    Ident(String),
    Integer(i64),
    Float(f64),
    Str(String),

    // --- Operators and punctuation ---
    Eq,
    Ne,
    Lt,
    Gt,
    Le,
    Ge,
    Dot,
    Comma,
    Star,
    Plus,
    Slash,
    Dollar,
    LParen,
    RParen,
    LBrack,
    RBrack,
    LBrace,
    RBrace,
    Colon,
    Pipe,
    Arrow,
    LArrow,
    Dash,
    /// End-of-input marker; always the final token in a token stream.
    Eof,
}
86
/// A [`Token`] annotated with its position in the source text.
#[derive(Debug, Clone)]
pub struct Spanned {
    /// The lexed token.
    pub token: Token,
    /// Byte offset of the token's first byte in the input string.
    pub offset: usize,
    /// 1-based line number where the token starts.
    pub line: u32,
    /// 1-based column where the token starts. NOTE(review): the lexer
    /// counts columns in bytes, so a multi-byte UTF-8 character advances
    /// the column by more than one.
    pub col: u32,
}
95
96pub fn tokenize(input: &str) -> Result<Vec<Spanned>, ParseError> {
97 let mut tokens = Vec::new();
98 let bytes = input.as_bytes();
99 let mut i = 0;
100 let mut line: u32 = 1;
101 let mut col: u32 = 1;
102
103 macro_rules! advance {
105 () => {{
106 if bytes[i] == b'\n' {
107 line += 1;
108 col = 1;
109 } else {
110 col += 1;
111 }
112 i += 1;
113 }};
114 }
115
116 while i < bytes.len() {
117 if bytes[i].is_ascii_whitespace() {
119 advance!();
120 continue;
121 }
122
123 if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
125 while i < bytes.len() && bytes[i] != b'\n' {
126 advance!();
127 }
128 continue;
129 }
130
131 let start = i;
132 let tok_line = line;
133 let tok_col = col;
134
135 if bytes[i] == b'"' || bytes[i] == b'\'' {
137 let quote = bytes[i];
138 advance!();
139 let mut s = String::new();
140 while i < bytes.len() && bytes[i] != quote {
141 if bytes[i] == b'\\' && i + 1 < bytes.len() {
142 advance!(); match bytes[i] {
144 b'n' => s.push('\n'),
145 b't' => s.push('\t'),
146 b'r' => s.push('\r'),
147 b'\\' => s.push('\\'),
148 b'\'' => s.push('\''),
149 b'"' => s.push('"'),
150 c => {
151 s.push('\\');
152 s.push(c as char);
153 }
154 }
155 } else {
156 s.push(bytes[i] as char);
157 }
158 advance!();
159 }
160 if i >= bytes.len() {
161 return Err(ParseError::new(
162 "unterminated string literal",
163 tok_line,
164 tok_col,
165 ));
166 }
167 advance!(); tokens.push(Spanned {
169 token: Token::Str(s),
170 offset: start,
171 line: tok_line,
172 col: tok_col,
173 });
174 continue;
175 }
176
177 if bytes[i].is_ascii_digit()
179 || (bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit())
180 {
181 let neg = bytes[i] == b'-';
182 if neg {
183 col += 1;
184 i += 1;
185 }
186 let num_start = i;
187 while i < bytes.len() && bytes[i].is_ascii_digit() {
188 col += 1;
189 i += 1;
190 }
191 if i < bytes.len()
192 && bytes[i] == b'.'
193 && i + 1 < bytes.len()
194 && bytes[i + 1].is_ascii_digit()
195 {
196 col += 1;
197 i += 1;
198 while i < bytes.len() && bytes[i].is_ascii_digit() {
199 col += 1;
200 i += 1;
201 }
202 let s = &input[if neg { start } else { num_start }..i];
203 let v: f64 = s.parse().map_err(|_| {
204 ParseError::new(format!("invalid float: {s}"), tok_line, tok_col)
205 })?;
206 tokens.push(Spanned {
207 token: Token::Float(v),
208 offset: start,
209 line: tok_line,
210 col: tok_col,
211 });
212 } else {
213 let s = &input[num_start..i];
214 let v: i64 = s.parse().map_err(|_| {
215 ParseError::new(format!("invalid integer: {s}"), tok_line, tok_col)
216 })?;
217 tokens.push(Spanned {
218 token: Token::Integer(if neg { -v } else { v }),
219 offset: start,
220 line: tok_line,
221 col: tok_col,
222 });
223 }
224 continue;
225 }
226
227 if bytes[i].is_ascii_alphabetic() || bytes[i] == b'_' {
229 while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
230 col += 1;
231 i += 1;
232 }
233 let word = &input[start..i];
234 let tok = keyword(word).unwrap_or_else(|| Token::Ident(word.to_string()));
235 tokens.push(Spanned {
236 token: tok,
237 offset: start,
238 line: tok_line,
239 col: tok_col,
240 });
241 continue;
242 }
243
244 if i + 1 < bytes.len() {
246 let two = &bytes[i..i + 2];
247 let maybe = match two {
248 b"<>" => Some(Token::Ne),
249 b"<=" => Some(Token::Le),
250 b">=" => Some(Token::Ge),
251 b"->" => Some(Token::Arrow),
252 b"<-" => Some(Token::LArrow),
253 _ => None,
254 };
255 if let Some(tok) = maybe {
256 tokens.push(Spanned {
257 token: tok,
258 offset: start,
259 line: tok_line,
260 col: tok_col,
261 });
262 col += 2;
263 i += 2;
264 continue;
265 }
266 }
267
268 let tok = match bytes[i] {
270 b'=' => Token::Eq,
271 b'<' => Token::Lt,
272 b'>' => Token::Gt,
273 b'.' => Token::Dot,
274 b',' => Token::Comma,
275 b'*' => Token::Star,
276 b'+' => Token::Plus,
277 b'/' => Token::Slash,
278 b'$' => Token::Dollar,
279 b'(' => Token::LParen,
280 b')' => Token::RParen,
281 b'[' => Token::LBrack,
282 b']' => Token::RBrack,
283 b'{' => Token::LBrace,
284 b'}' => Token::RBrace,
285 b':' => Token::Colon,
286 b'|' => Token::Pipe,
287 b'-' => Token::Dash,
288 c => {
289 return Err(ParseError::new(
290 format!("unexpected character: '{}'", c as char),
291 tok_line,
292 tok_col,
293 ))
294 }
295 };
296 tokens.push(Spanned {
297 token: tok,
298 offset: start,
299 line: tok_line,
300 col: tok_col,
301 });
302 advance!();
303 }
304
305 tokens.push(Spanned {
306 token: Token::Eof,
307 offset: input.len(),
308 line,
309 col,
310 });
311 Ok(tokens)
312}
313
314fn keyword(s: &str) -> Option<Token> {
315 match s.to_ascii_uppercase().as_str() {
316 "MATCH" => Some(Token::Match),
317 "OPTIONAL" => Some(Token::Optional),
318 "CREATE" => Some(Token::Create),
319 "MERGE" => Some(Token::Merge),
320 "DELETE" => Some(Token::Delete),
321 "DETACH" => Some(Token::Detach),
322 "SET" => Some(Token::Set),
323 "REMOVE" => Some(Token::Remove),
324 "UNWIND" => Some(Token::Unwind),
325 "UNION" => Some(Token::Union),
326 "ALL" => Some(Token::All),
327 "WHERE" => Some(Token::Where),
328 "RETURN" => Some(Token::Return),
329 "WITH" => Some(Token::With),
330 "AS" => Some(Token::As),
331 "AND" => Some(Token::And),
332 "OR" => Some(Token::Or),
333 "NOT" => Some(Token::Not),
334 "IS" => Some(Token::Is),
335 "NULL" => Some(Token::Null),
336 "TRUE" => Some(Token::True),
337 "FALSE" => Some(Token::False),
338 "ORDER" => Some(Token::Order),
339 "BY" => Some(Token::By),
340 "LIMIT" => Some(Token::Limit),
341 "SKIP" => Some(Token::Skip),
342 "ASC" => Some(Token::Asc),
343 "DESC" => Some(Token::Desc),
344 "STARTS" => Some(Token::Starts),
345 "ENDS" => Some(Token::Ends),
346 "CONTAINS" => Some(Token::Contains),
347 "IN" => Some(Token::In),
348 "CASE" => Some(Token::Case),
349 "WHEN" => Some(Token::When),
350 "THEN" => Some(Token::Then),
351 "ELSE" => Some(Token::Else),
352 "END" => Some(Token::End),
353 "COUNT" => Some(Token::Count),
354 "SUM" => Some(Token::Sum),
355 "AVG" => Some(Token::Avg),
356 "MIN" => Some(Token::Min),
357 "MAX" => Some(Token::Max),
358 "COLLECT" => Some(Token::Collect),
359 _ => None,
360 }
361}
362
// Unit tests for the lexer: token kinds, spans, and error positions.
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenizes `input`, unwrapping errors and discarding span
    /// information so tests can compare token kinds directly.
    fn toks(input: &str) -> Vec<Token> {
        tokenize(input)
            .unwrap()
            .into_iter()
            .map(|s| s.token)
            .collect()
    }

    // Every keyword should lex the same regardless of letter case.
    #[test]
    fn keywords_case_insensitive() {
        assert_eq!(
            toks("MATCH match Match OPTIONAL optional CREATE create MERGE merge DELETE delete DETACH detach SET set REMOVE remove UNWIND unwind UNION union ALL all CASE case WHEN when THEN then ELSE else END end"),
            vec![
                Token::Match,
                Token::Match,
                Token::Match,
                Token::Optional,
                Token::Optional,
                Token::Create,
                Token::Create,
                Token::Merge,
                Token::Merge,
                Token::Delete,
                Token::Delete,
                Token::Detach,
                Token::Detach,
                Token::Set,
                Token::Set,
                Token::Remove,
                Token::Remove,
                Token::Unwind,
                Token::Unwind,
                Token::Union,
                Token::Union,
                Token::All,
                Token::All,
                Token::Case,
                Token::Case,
                Token::When,
                Token::When,
                Token::Then,
                Token::Then,
                Token::Else,
                Token::Else,
                Token::End,
                Token::End,
                Token::Eof
            ]
        );
    }

    // Non-keyword words (including leading/embedded underscores and
    // digits) lex as identifiers.
    #[test]
    fn identifiers() {
        assert_eq!(
            toks("n r_1 _foo"),
            vec![
                Token::Ident("n".into()),
                Token::Ident("r_1".into()),
                Token::Ident("_foo".into()),
                Token::Eof,
            ]
        );
    }

    #[test]
    fn integer_literals() {
        assert_eq!(
            toks("42 0"),
            vec![Token::Integer(42), Token::Integer(0), Token::Eof]
        );
    }

    // Floats are compared with a tolerance rather than exact equality.
    #[test]
    fn float_literals() {
        let t = toks("3.14");
        assert!(matches!(t[0], Token::Float(f) if (f - (314.0_f64 / 100.0)).abs() < 1e-9));
    }

    // Both quote styles produce the same Str token kind.
    #[test]
    fn string_literals() {
        assert_eq!(
            toks("\"hello\" 'world'"),
            vec![
                Token::Str("hello".into()),
                Token::Str("world".into()),
                Token::Eof,
            ]
        );
    }

    // A backslash escape in the source becomes the escaped character.
    #[test]
    fn string_escape() {
        assert_eq!(
            toks(r#""a\nb""#),
            vec![Token::Str("a\nb".into()), Token::Eof]
        );
    }

    #[test]
    fn operators() {
        assert_eq!(
            toks("= <> < > <= >= + / $"),
            vec![
                Token::Eq,
                Token::Ne,
                Token::Lt,
                Token::Gt,
                Token::Le,
                Token::Ge,
                Token::Plus,
                Token::Slash,
                Token::Dollar,
                Token::Eof,
            ]
        );
    }

    // "->" and "<-" are two-character tokens, not Dash + Lt/Gt.
    #[test]
    fn arrows() {
        assert_eq!(toks("-> <-"), vec![Token::Arrow, Token::LArrow, Token::Eof]);
    }

    #[test]
    fn punctuation() {
        assert_eq!(
            toks("( ) [ ] : , . * - + / $"),
            vec![
                Token::LParen,
                Token::RParen,
                Token::LBrack,
                Token::RBrack,
                Token::Colon,
                Token::Comma,
                Token::Dot,
                Token::Star,
                Token::Dash,
                Token::Plus,
                Token::Slash,
                Token::Dollar,
                Token::Eof,
            ]
        );
    }

    // A "//" comment is discarded up to the end of its line.
    #[test]
    fn line_comment_skipped() {
        assert_eq!(
            toks("MATCH // this is a comment\nRETURN"),
            vec![Token::Match, Token::Return, Token::Eof,]
        );
    }

    // The error points at the opening quote of the unterminated string.
    #[test]
    fn unterminated_string_error() {
        let err = tokenize("\"oops").unwrap_err();
        assert_eq!(err.line, 1);
        assert_eq!(err.col, 1);
    }

    #[test]
    fn unknown_char_error() {
        let err = tokenize("@").unwrap_err();
        assert_eq!(err.line, 1);
        assert_eq!(err.col, 1);
    }

    #[test]
    fn span_first_token() {
        let tokens = tokenize("MATCH").unwrap();
        assert_eq!(tokens[0].line, 1);
        assert_eq!(tokens[0].col, 1);
    }

    // A newline resets the column and bumps the line counter.
    #[test]
    fn span_second_line() {
        let tokens = tokenize("MATCH\nRETURN").unwrap();
        assert_eq!(tokens[0].line, 1);
        assert_eq!(tokens[0].col, 1);
        assert_eq!(tokens[1].line, 2);
        assert_eq!(tokens[1].col, 1);
    }

    // "MATCH (n)": the '(' sits at column 7 (1-based).
    #[test]
    fn span_column_offset() {
        let tokens = tokenize("MATCH (n)").unwrap();
        assert_eq!(tokens[1].line, 1);
        assert_eq!(tokens[1].col, 7);
    }

    // Error positions account for tokens consumed earlier on the line.
    #[test]
    fn unknown_char_mid_line_error() {
        let err = tokenize("MATCH @").unwrap_err();
        assert_eq!(err.line, 1);
        assert_eq!(err.col, 7);
    }

    // true/false are keyword tokens, not identifiers, in any case.
    #[test]
    fn bool_literals() {
        assert_eq!(
            toks("true false TRUE FALSE"),
            vec![
                Token::True,
                Token::False,
                Token::True,
                Token::False,
                Token::Eof,
            ]
        );
    }
}