1pub(crate) mod locy_parser;
2mod locy_walker;
3mod walker;
4
5use crate::ast::{Expr, Query};
6use crate::locy_ast::LocyProgram;
7use pest::Parser;
8use pest_derive::Parser;
9
/// Error returned by the Cypher and Locy parsing entry points in this module.
///
/// Its `Display` output is exactly `message` (via `thiserror`'s `#[error]`).
/// Messages built by `map_pest_error` / `map_locy_pest_error` start with a
/// category prefix such as `SyntaxError: ...`, `UnexpectedSyntax: ...` or
/// `LocySyntaxError: ...`.
#[derive(Debug, thiserror::Error)]
#[error("{message}")]
pub struct ParseError {
    // Full human-readable text, including any category prefix.
    message: String,
}
16
17impl ParseError {
18 pub fn new(message: String) -> Self {
19 Self { message }
20 }
21}
22
/// Pest parser generated from the Cypher grammar in `grammar/cypher.pest`.
///
/// The `#[derive(Parser)]` expansion also provides the `Rule` enum that this
/// module uses (`Rule::query`, `Rule::expression`, `Rule::identifier`, ...).
#[derive(Parser)]
#[grammar = "grammar/cypher.pest"]
pub struct CypherParser;
26
27pub fn parse(input: &str) -> Result<Query, ParseError> {
28 let pairs = CypherParser::parse(Rule::query, input).map_err(|e| map_pest_error(input, e))?;
29
30 walker::build_query(pairs)
31}
32
33pub fn parse_expression(input: &str) -> Result<Expr, ParseError> {
34 let pairs =
35 CypherParser::parse(Rule::expression, input).map_err(|e| map_pest_error(input, e))?;
36
37 walker::build_expression(pairs.into_iter().next().unwrap())
38}
39
40pub fn parse_locy(input: &str) -> Result<LocyProgram, ParseError> {
41 use locy_parser::LocyParser;
42 use locy_parser::Rule as LocyRule;
43
44 let pairs = LocyParser::parse(LocyRule::locy_query, input)
45 .map_err(|e| map_locy_pest_error(input, e))?;
46
47 locy_walker::build_program(pairs.into_iter().next().unwrap())
48}
49
50fn expects_identifier(e: &pest::error::Error<Rule>) -> bool {
54 use pest::error::ErrorVariant;
55 match &e.variant {
56 ErrorVariant::ParsingError { positives, .. } => positives
57 .iter()
58 .any(|r| matches!(r, Rule::identifier | Rule::identifier_or_keyword)),
59 _ => false,
60 }
61}
62
63fn error_position<R: pest::RuleType>(e: &pest::error::Error<R>) -> usize {
64 match e.location {
65 pest::error::InputLocation::Pos(p) => p,
66 pest::error::InputLocation::Span((s, _)) => s,
67 }
68}
69
/// Finds the byte span `(start, end)` of the "token" at or just before `pos`.
///
/// A token is a maximal run of ASCII alphanumerics or `_ - . # $`. `pos` is
/// clamped to the last byte. If the byte at `pos` is not a token byte, the
/// byte immediately before it is tried instead, so a position just past a word
/// still finds that word. Returns `None` for empty input or when neither byte
/// belongs to a token. Because token bytes are ASCII, the returned span always
/// lies on UTF-8 character boundaries.
fn extract_token_span_at(input: &str, pos: usize) -> Option<(usize, usize)> {
    fn is_token_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-' | b'.' | b'#' | b'$')
    }

    let bytes = input.as_bytes();
    let last = bytes.len().checked_sub(1)?;
    let clamped = pos.min(last);

    let anchor = if is_token_byte(bytes[clamped]) {
        clamped
    } else if clamped > 0 && is_token_byte(bytes[clamped - 1]) {
        clamped - 1
    } else {
        return None;
    };

    // Scan outwards from the anchor to the nearest non-token byte on each side.
    let start = bytes[..anchor]
        .iter()
        .rposition(|&b| !is_token_byte(b))
        .map_or(0, |idx| idx + 1);
    let end = bytes[anchor..]
        .iter()
        .position(|&b| !is_token_byte(b))
        .map_or(bytes.len(), |offset| anchor + offset);

    Some((start, end))
}
100
/// Heuristic: does the span `[start, end)` look like a map key, as in `{key: ...}`?
///
/// True when the next non-whitespace byte after the span is `:` and the
/// previous non-whitespace byte before it is `{` or `,`. Spans that fall
/// outside `input` (including any span over empty input) return `false`.
fn is_map_key_like_context(input: &str, start: usize, end: usize) -> bool {
    let bytes = input.as_bytes();
    if start >= bytes.len() || end > bytes.len() {
        return false;
    }

    let next_significant = bytes[end..].iter().find(|b| !b.is_ascii_whitespace());
    if next_significant != Some(&b':') {
        return false;
    }

    let prev_significant = bytes[..start]
        .iter()
        .rev()
        .find(|b| !b.is_ascii_whitespace());
    matches!(prev_significant, Some(&(b'{' | b',')))
}
125
/// Returns the relationship bracket text (`[...]`) that encloses `pos`, if any.
///
/// Finds the last `[` before `pos`. It only counts when the text before it,
/// ignoring trailing whitespace, ends with `-`, as in `(a)-[...]`. The segment
/// runs to the first `]` after that `[`, or to `pos` if no `]` exists.
///
/// `pos` is clamped to `input.len()`. A `pos` that does not fall on a UTF-8
/// character boundary yields `None` instead of panicking.
fn relationship_bracket_segment(input: &str, pos: usize) -> Option<&str> {
    let pos = pos.min(input.len());
    let before = input.get(..pos)?;
    let start = before.rfind('[')?;

    // Distinguish relationship brackets from list literals / indexing.
    if !input[..start].trim_end().ends_with('-') {
        return None;
    }

    // `start` is the byte index of an ASCII `[` and `]` is ASCII too, so every
    // slice below is on a character boundary.
    let end = input[start..]
        .find(']')
        .map_or(pos, |offset| start + offset + 1);
    Some(&input[start..end])
}
141
142fn is_invalid_relationship_pattern(input: &str, pos: usize) -> bool {
143 let Some(segment) = relationship_bracket_segment(input, pos) else {
144 return false;
145 };
146 (segment.contains("..") && !segment.contains('*')) || segment.contains("*-")
148}
149
150fn is_invalid_number_literal(input: &str, pos: usize) -> bool {
151 let Some((start, end)) = extract_token_span_at(input, pos) else {
152 return false;
153 };
154 if is_map_key_like_context(input, start, end) {
155 return false;
156 }
157 let token = &input[start..end];
158
159 let t = token.strip_prefix('-').unwrap_or(token);
160 if !t.as_bytes().first().is_some_and(|b| b.is_ascii_digit()) {
161 return false;
162 }
163
164 let has_only = |digits: &str, valid: fn(&char) -> bool| {
165 digits.is_empty() || !digits.chars().all(|c| valid(&c) || c == '_')
166 };
167
168 if let Some(digits) = t.strip_prefix("0x").or_else(|| t.strip_prefix("0X")) {
169 return has_only(digits, char::is_ascii_hexdigit);
170 }
171 if let Some(digits) = t.strip_prefix("0o").or_else(|| t.strip_prefix("0O")) {
172 return has_only(digits, |c| matches!(c, '0'..='7'));
173 }
174
175 t.chars().any(|c| c.is_ascii_alphabetic())
177}
178
/// Returns the character at byte offset `pos` if it is a Unicode dash that is
/// easily confused with ASCII `-`: em dash, en dash, or minus sign.
///
/// Out-of-range offsets and offsets inside a multi-byte character yield `None`.
fn invalid_unicode_character(input: &str, pos: usize) -> Option<char> {
    const LOOKALIKE_DASHES: [char; 3] = ['—', '–', '−'];

    input
        .get(pos..)
        .and_then(|rest| rest.chars().next())
        .filter(|ch| LOOKALIKE_DASHES.contains(ch))
}
183
/// Cypher words that `reserved_keyword_at` reports as reserved when they appear
/// where a name was expected.
///
/// Entries are lowercase. `reserved_keyword_at` lowercases the token before the
/// lookup, so matching is case-insensitive. Words not listed here (e.g. `end`)
/// are not reported.
const CYPHER_RESERVED_KEYWORDS: &[&str] = &[
    "match",
    "optional",
    "where",
    "create",
    "merge",
    "set",
    "remove",
    "delete",
    "detach",
    "return",
    "with",
    "unwind",
    "union",
    "call",
    "yield",
    "distinct",
    "order",
    "by",
    "asc",
    "desc",
    "skip",
    "limit",
    "as",
    "and",
    "or",
    "xor",
    "not",
    "in",
    "contains",
    "starts",
    "ends",
    "is",
    "null",
    "true",
    "false",
    "case",
    "when",
    "then",
    "else",
    "if",
    "from",
    "to",
    "on",
    "drop",
    "alter",
    "show",
    "over",
    "partition",
    "explain",
    "recursive",
    "valid_at",
    "each",
];
240
/// Additional lowercase words that `map_locy_pest_error` treats as reserved,
/// on top of [`CYPHER_RESERVED_KEYWORDS`], when mapping Locy parse errors.
const LOCY_RESERVED_KEYWORDS: &[&str] = &[
    "rule", "along", "prev", "fold", "best", "derive", "assume", "abduce", "query",
];
245
246fn reserved_keyword_at(input: &str, pos: usize, extra_keywords: &[&str]) -> Option<String> {
248 let (start, end) = extract_token_span_at(input, pos)?;
249 let token = &input[start..end];
250 let lower = token.to_lowercase();
251 if CYPHER_RESERVED_KEYWORDS.contains(&lower.as_str())
252 || extra_keywords.contains(&lower.as_str())
253 {
254 Some(token.to_string())
255 } else {
256 None
257 }
258}
259
/// Classifies a Locy parse failure by the clause keyword that precedes `pos`.
///
/// Returns the category name used in `LocySyntaxError: <category> - ...`, or
/// `None` when no known context is recognised. Checks, in priority order:
/// - the text before the error (trailing whitespace ignored) ends with
///   `BEST BY`, `ALONG`, `FOLD`, `ASSUME` or `DERIVE`;
/// - the text contains `CREATE RULE` anywhere;
/// - the text ends with `QUERY`.
///
/// Matching is ASCII case-insensitive. Suffix keywords must be whole words, so
/// identifiers like `scaffold` or `myquery` are not mistaken for `FOLD` or
/// `QUERY`. An out-of-range `pos` is clamped to `input.len()`. A `pos` inside a
/// multi-byte character yields `None` instead of panicking.
fn locy_context_category(input: &str, pos: usize) -> Option<&'static str> {
    /// True when `text` ends with `keyword` (ASCII case-insensitive) and the
    /// keyword is not just the tail of a longer identifier.
    fn ends_with_keyword(text: &str, keyword: &str) -> bool {
        let Some(split) = text.len().checked_sub(keyword.len()) else {
            return false;
        };
        let Some(tail) = text.get(split..) else {
            return false;
        };
        // `get` succeeded, so `split` is a character boundary.
        tail.eq_ignore_ascii_case(keyword)
            && !text[..split]
                .chars()
                .next_back()
                .is_some_and(|c| c.is_alphanumeric() || c == '_')
    }

    const SUFFIX_CATEGORIES: [(&str, &str); 5] = [
        ("BEST BY", "InvalidBestByClause"),
        ("ALONG", "InvalidAlongClause"),
        ("FOLD", "InvalidFoldClause"),
        ("ASSUME", "InvalidAssumeBlock"),
        ("DERIVE", "InvalidDeriveCommand"),
    ];

    let before = input.get(..pos.min(input.len()))?.trim_end();

    if let Some(&(_, category)) = SUFFIX_CATEGORIES
        .iter()
        .find(|(keyword, _)| ends_with_keyword(before, keyword))
    {
        return Some(category);
    }
    // Anywhere inside a rule definition, the rule itself is the best context.
    if before.to_ascii_uppercase().contains("CREATE RULE") {
        return Some("InvalidRuleDefinition");
    }
    if ends_with_keyword(before, "QUERY") {
        return Some("InvalidGoalQuery");
    }
    None
}
290
/// Converts a pest error from the Locy grammar into a [`ParseError`].
///
/// Heuristics are tried in order and the first hit wins:
/// 1. `InvalidRelationshipPattern`: a malformed `-[...]` segment at the error.
/// 2. `InvalidNumberLiteral`: a numeric-looking token with invalid characters.
/// 3. `InvalidUnicodeCharacter`: a look-alike dash at the error (the pest
///    error text is not appended in this case).
/// 4. `ReservedKeyword`: the token at the error is a Cypher or Locy reserved word.
/// 5. A clause category from `locy_context_category`.
///
/// Otherwise the raw pest error is wrapped as `LocySyntaxError: {e}`.
///
/// NOTE(review): unlike `map_pest_error`, step 4 is not gated on the parser
/// expecting an identifier, and it runs before step 5. `reserved_keyword_at`
/// backs up one byte from a non-token position. So an error located just after
/// a clause keyword (e.g. `... FOLD` at end of input) can be reported as that
/// keyword being misused as a variable name, and the clause category is never
/// reached. Confirm this is intended.
fn map_locy_pest_error(input: &str, e: pest::error::Error<locy_parser::Rule>) -> ParseError {
    let pos = error_position(&e);

    if is_invalid_relationship_pattern(input, pos) {
        return ParseError::new(format!("LocySyntaxError: InvalidRelationshipPattern - {e}"));
    }
    if is_invalid_number_literal(input, pos) {
        return ParseError::new(format!("LocySyntaxError: InvalidNumberLiteral - {e}"));
    }
    if let Some(ch) = invalid_unicode_character(input, pos) {
        return ParseError::new(format!(
            "LocySyntaxError: InvalidUnicodeCharacter - Invalid character '{ch}'"
        ));
    }
    // Locy adds its own keywords on top of the Cypher reserved list.
    if let Some(kw) = reserved_keyword_at(input, pos, LOCY_RESERVED_KEYWORDS) {
        return ParseError::new(format!(
            "LocySyntaxError: ReservedKeyword - \"{kw}\" is a reserved keyword \
             and cannot be used as a variable name. Use backtick-quoting: `{kw}`\n{e}"
        ));
    }

    if let Some(category) = locy_context_category(input, pos) {
        return ParseError::new(format!("LocySyntaxError: {category} - {e}"));
    }

    ParseError::new(format!("LocySyntaxError: {e}"))
}
320
321fn map_pest_error(input: &str, e: pest::error::Error<Rule>) -> ParseError {
322 let pos = error_position(&e);
323 if is_invalid_relationship_pattern(input, pos) {
324 return ParseError::new(format!("SyntaxError: InvalidRelationshipPattern - {e}"));
325 }
326 if is_invalid_number_literal(input, pos) {
327 return ParseError::new(format!("SyntaxError: InvalidNumberLiteral - {e}"));
328 }
329 if let Some(ch) = invalid_unicode_character(input, pos) {
330 return ParseError::new(format!(
331 "SyntaxError: InvalidUnicodeCharacter - Invalid character '{ch}'"
332 ));
333 }
334 if let Some(kw) = expects_identifier(&e)
335 .then(|| reserved_keyword_at(input, pos, &[]))
336 .flatten()
337 {
338 return ParseError::new(format!(
339 "SyntaxError: ReservedKeyword - \"{kw}\" is a reserved keyword \
340 and cannot be used as a variable name. Use backtick-quoting: `{kw}`\n{e}"
341 ));
342 }
343
344 ParseError::new(format!("UnexpectedSyntax: {e}"))
345}
346
// Unit tests for the parser entry points and the error-categorisation heuristics.
#[cfg(test)]
mod tests {
    use super::*;

    // Each input must be accepted by the given grammar rule directly.
    #[test]
    fn test_expression_parsing() {
        let cases = [
            ("1", Rule::integer),
            ("3.14", Rule::float),
            ("'hello'", Rule::string),
            ("n.name", Rule::expression),
            ("1 + 2", Rule::expression),
            ("a AND b OR c", Rule::expression),
        ];

        for (input, rule) in cases {
            let result = CypherParser::parse(rule, input);
            assert!(
                result.is_ok(),
                "Failed to parse '{}' as {:?}: {:?}",
                input,
                rule,
                result.err()
            );
        }
    }

    // List literals, list comprehensions and pattern comprehensions.
    #[test]
    fn test_list_expressions() {
        assert!(parse_expression("[]").is_ok());

        assert!(parse_expression("[1, 2, 3]").is_ok());

        assert!(parse_expression("[x IN range(1,10) | x * 2]").is_ok());
        assert!(parse_expression("[x IN list WHERE x > 5 | x]").is_ok());

        assert!(parse_expression("[(n)-[:KNOWS]->(m) | m.name]").is_ok());
        assert!(parse_expression("[p = (n)-->(m) WHERE m.age > 30 | p]").is_ok());
    }

    // Bracketed forms that could be read as either a list or a comprehension.
    #[test]
    fn test_ambiguous_cases() {
        assert!(parse_expression("[n]").is_ok());
        assert!(parse_expression("[n.name]").is_ok());
        assert!(parse_expression("[n IN list]").is_ok());
        assert!(parse_expression("[(n)]").is_ok());
    }

    // Helper: parse `input` as a query and return the error message (panics on success).
    fn parse_err_msg(input: &str) -> String {
        parse(input).unwrap_err().to_string()
    }

    #[test]
    fn test_invalid_relationship_pattern_missing_star_error_code() {
        let msg = parse_err_msg("MATCH (a:A)\nMATCH (a)-[:LIKES..]->(c)\nRETURN c.name");
        assert!(
            msg.contains("InvalidRelationshipPattern"),
            "expected InvalidRelationshipPattern, got: {msg}"
        );
    }

    #[test]
    fn test_invalid_number_literal_error_code_decimal_alpha() {
        let msg = parse_err_msg("RETURN 9223372h54775808 AS literal");
        assert!(
            msg.contains("InvalidNumberLiteral"),
            "expected InvalidNumberLiteral, got: {msg}"
        );
    }

    #[test]
    fn test_invalid_number_literal_error_code_hex_prefix_only() {
        let msg = parse_err_msg("RETURN 0x AS literal");
        assert!(
            msg.contains("InvalidNumberLiteral"),
            "expected InvalidNumberLiteral, got: {msg}"
        );
    }

    // The input contains an em dash (U+2014), not an ASCII minus.
    #[test]
    fn test_invalid_unicode_character_error_code() {
        let msg = parse_err_msg("RETURN 42 — 41");
        assert!(
            msg.contains("InvalidUnicodeCharacter"),
            "expected InvalidUnicodeCharacter, got: {msg}"
        );
    }

    // Non-letter junk inside a number is left to the generic error.
    #[test]
    fn test_symbol_in_number_stays_unexpected_syntax() {
        let msg = parse_err_msg("RETURN 9223372#54775808 AS literal");
        assert!(
            msg.contains("UnexpectedSyntax"),
            "expected UnexpectedSyntax, got: {msg}"
        );
    }

    // Digit-leading map keys are excluded from the number-literal heuristic.
    #[test]
    fn test_map_key_starting_with_number_stays_unexpected_syntax() {
        let msg = parse_err_msg("RETURN {1B2c3e67:1} AS literal");
        assert!(
            msg.contains("UnexpectedSyntax"),
            "expected UnexpectedSyntax, got: {msg}"
        );
    }

    // Chains of unary +/- are folded into a single integer literal.
    #[test]
    fn test_unary_minus_double() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("--5").expect("--5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(5)));
    }

    #[test]
    fn test_unary_minus_single() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("-5").expect("-5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(-5)));
    }

    #[test]
    fn test_unary_minus_triple() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("---5").expect("---5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(-5)));
    }

    #[test]
    fn test_unary_plus_identity() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("+5").expect("+5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(5)));
    }

    #[test]
    fn test_unary_plus_minus() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("+-5").expect("+-5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(-5)));
    }

    #[test]
    fn test_unary_minus_plus() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("-+5").expect("-+5 should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(-5)));
    }

    // `--` cancels out, leaving 9223372036854775808, which does not fit in i64.
    #[test]
    fn test_unary_double_minus_overflow() {
        let result = parse_expression("--9223372036854775808");
        assert!(
            result.is_err(),
            "expected overflow error, got: {:?}",
            result
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("IntegerOverflow"),
            "expected IntegerOverflow, got: {msg}"
        );
    }

    // A single minus on 9223372036854775808 yields exactly i64::MIN.
    #[test]
    fn test_unary_minus_i64_min() {
        use crate::ast::{CypherLiteral, Expr};
        let expr = parse_expression("-9223372036854775808").expect("-i64::MIN should parse");
        assert_eq!(expr, Expr::Literal(CypherLiteral::Integer(i64::MIN)));
    }

    // Stacking string/list/null predicates must be rejected with InvalidPredicateChain.
    #[test]
    fn test_stacked_predicates_is_null_is_not_null() {
        let result = parse("RETURN x IS NULL IS NOT NULL");
        assert!(
            result.is_err(),
            "expected parse error for stacked IS NULL IS NOT NULL"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("InvalidPredicateChain"),
            "expected InvalidPredicateChain, got: {msg}"
        );
    }

    #[test]
    fn test_stacked_predicates_starts_with() {
        let result = parse("RETURN x STARTS WITH 'a' STARTS WITH 'b'");
        assert!(
            result.is_err(),
            "expected parse error for stacked STARTS WITH"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("InvalidPredicateChain"),
            "expected InvalidPredicateChain, got: {msg}"
        );
    }

    #[test]
    fn test_stacked_predicates_in() {
        let result = parse("RETURN x IN [1] IN [true]");
        assert!(result.is_err(), "expected parse error for stacked IN");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("InvalidPredicateChain"),
            "expected InvalidPredicateChain, got: {msg}"
        );
    }

    #[test]
    fn test_stacked_predicates_contains_ends_with() {
        let result = parse("RETURN x CONTAINS 'a' ENDS WITH 'b'");
        assert!(
            result.is_err(),
            "expected parse error for stacked CONTAINS/ENDS WITH"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("InvalidPredicateChain"),
            "expected InvalidPredicateChain, got: {msg}"
        );
    }

    // Label predicates and comparison chains are legitimate and must still parse.
    #[test]
    fn test_label_stacking_allowed() {
        assert!(
            parse("MATCH (x) WHERE x:Person:Employee RETURN x").is_ok(),
            "label stacking should be allowed"
        );
    }

    #[test]
    fn test_range_chaining_allowed() {
        assert!(
            parse("MATCH (n) WHERE 1 < n.num < 3 RETURN n").is_ok(),
            "range chaining 1 < n.num < 3 should be allowed"
        );
    }

    // Reserved keywords used as names produce a ReservedKeyword error with a backtick hint.
    #[test]
    fn test_reserved_keyword_as_variable_name() {
        let msg = parse_err_msg("MATCH (match:N) RETURN match");
        assert!(
            msg.contains("ReservedKeyword"),
            "expected ReservedKeyword, got: {msg}"
        );
        assert!(
            msg.contains("backtick-quoting"),
            "expected backtick hint, got: {msg}"
        );
    }

    #[test]
    fn test_reserved_keyword_return_as_variable() {
        let msg = parse_err_msg("MATCH (return:N) RETURN return");
        assert!(
            msg.contains("ReservedKeyword"),
            "expected ReservedKeyword, got: {msg}"
        );
    }

    // `end` is not in CYPHER_RESERVED_KEYWORDS, so it is a valid variable name.
    #[test]
    fn test_non_reserved_keyword_allowed() {
        assert!(
            parse("MATCH (end:N) RETURN end").is_ok(),
            "non-reserved keyword 'end' should be allowed as variable name"
        );
    }

    #[test]
    fn test_backtick_escaped_reserved_keyword() {
        assert!(
            parse("MATCH (`match`:N) RETURN `match`").is_ok(),
            "backtick-escaped reserved keyword should be allowed"
        );
    }
}