1use chumsky::prelude::*;
4use smol_str::SmolStr;
5
6use crate::ast::*;
7use crate::span::Spanned;
8use crate::token::Token;
9
10use super::expression::expression_parser;
11use super::pattern::{ident, pattern_parser};
12
13type ParserError = Simple<Token>;
14
15pub fn match_clause() -> impl Parser<Token, ReadingClause, Error = ParserError> + Clone {
21 let optional = just(Token::Optional).or_not().map(|o| o.is_some());
22
23 optional
24 .then_ignore(just(Token::Match))
25 .then(
26 pattern_parser()
27 .separated_by(just(Token::Comma))
28 .at_least(1)
29 .labelled("match pattern"),
30 )
31 .then(where_clause().or_not())
32 .map(|((is_optional, patterns), where_clause)| {
33 ReadingClause::Match(MatchClause {
34 is_optional,
35 patterns,
36 where_clause,
37 })
38 })
39 .labelled("match clause")
40}
41
42fn where_clause() -> impl Parser<Token, Spanned<Expression>, Error = ParserError> + Clone {
44 just(Token::Where)
45 .ignore_then(expression_parser())
46 .labelled("where clause")
47}
48
49pub fn unwind_clause() -> impl Parser<Token, ReadingClause, Error = ParserError> + Clone {
51 just(Token::Unwind)
52 .ignore_then(expression_parser())
53 .then_ignore(just(Token::As))
54 .then(ident().map_with_span(|n, s| (n, s)))
55 .map(|(expression, alias)| ReadingClause::Unwind(UnwindClause { expression, alias }))
56 .labelled("unwind clause")
57}
58
59pub fn create_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
65 just(Token::Create)
66 .ignore_then(
67 pattern_parser()
68 .separated_by(just(Token::Comma))
69 .at_least(1),
70 )
71 .map(UpdatingClause::Create)
72 .labelled("create clause")
73}
74
75pub fn set_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
77 just(Token::Set)
78 .ignore_then(set_item().separated_by(just(Token::Comma)).at_least(1))
79 .map(UpdatingClause::Set)
80 .labelled("set clause")
81}
82
83fn set_item() -> impl Parser<Token, SetItem, Error = ParserError> + Clone {
84 let property_chain = ident()
87 .map_with_span(|name, span| {
88 (Expression::Variable(name), span)
89 })
90 .then(
91 just(Token::Dot)
92 .ignore_then(ident().map_with_span(|n, s| (n, s)))
93 .repeated()
94 .at_least(1),
95 )
96 .foldl(|base, key| {
97 let span = base.1.start..key.1.end;
98 (
99 Expression::Property {
100 object: Box::new(base),
101 key,
102 },
103 span,
104 )
105 });
106
107 let set_property = property_chain
108 .then_ignore(just(Token::Eq))
109 .then(expression_parser())
110 .map(|(entity, value)| SetItem::Property { entity, value });
111
112 let set_labels = ident()
114 .map_with_span(|n, s| (n, s))
115 .then(
116 just(Token::Colon)
117 .ignore_then(ident().map_with_span(|n, s| (n, s)))
118 .repeated()
119 .at_least(1),
120 )
121 .map(|(entity, labels)| SetItem::Labels { entity, labels });
122
123 set_property.or(set_labels)
124}
125
126pub fn delete_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
128 let detach = just(Token::Detach).or_not().map(|d| d.is_some());
129
130 detach
131 .then_ignore(just(Token::Delete))
132 .then(
133 expression_parser()
134 .separated_by(just(Token::Comma))
135 .at_least(1),
136 )
137 .map(|(detach, expressions)| {
138 UpdatingClause::Delete(DeleteClause {
139 detach,
140 expressions,
141 })
142 })
143 .labelled("delete clause")
144}
145
146pub fn merge_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
148 let on_match = just(Token::On)
149 .then_ignore(just(Token::Match))
150 .then_ignore(just(Token::Set))
151 .ignore_then(set_item().separated_by(just(Token::Comma)).at_least(1))
152 .or_not()
153 .map(|o| o.unwrap_or_default());
154
155 let on_create = just(Token::On)
156 .then_ignore(just(Token::Create))
157 .then_ignore(just(Token::Set))
158 .ignore_then(set_item().separated_by(just(Token::Comma)).at_least(1))
159 .or_not()
160 .map(|o| o.unwrap_or_default());
161
162 just(Token::Merge)
163 .ignore_then(pattern_parser())
164 .then(on_match)
165 .then(on_create)
166 .map(|((pattern, on_match), on_create)| {
167 UpdatingClause::Merge(MergeClause {
168 pattern,
169 on_match,
170 on_create,
171 })
172 })
173 .labelled("merge clause")
174}
175
176pub fn projection_body() -> impl Parser<Token, ProjectionBody, Error = ParserError> + Clone {
182 let distinct = just(Token::Distinct).or_not().map(|d| d.is_some());
183
184 let alias = just(Token::As)
185 .ignore_then(ident().map_with_span(|n, s| (n, s)))
186 .or_not();
187
188 let item = expression_parser().then(alias);
189
190 let items = just(Token::Star)
191 .to(ProjectionItems::All)
192 .or(item
193 .separated_by(just(Token::Comma))
194 .at_least(1)
195 .map(ProjectionItems::Expressions));
196
197 let sort_order = choice((
198 just(Token::Asc).to(SortOrder::Ascending),
199 just(Token::Desc).to(SortOrder::Descending),
200 ))
201 .or_not()
202 .map(|o| o.unwrap_or(SortOrder::Ascending));
203
204 let order_by = just(Token::Order)
205 .ignore_then(just(Token::By))
206 .ignore_then(
207 expression_parser()
208 .then(sort_order)
209 .separated_by(just(Token::Comma))
210 .at_least(1),
211 )
212 .or_not()
213 .map(|o| o.unwrap_or_default());
214
215 let skip = just(Token::Skip)
216 .ignore_then(expression_parser())
217 .or_not();
218
219 let limit = just(Token::Limit)
220 .ignore_then(expression_parser())
221 .or_not();
222
223 distinct
224 .then(items)
225 .then(order_by)
226 .then(skip)
227 .then(limit)
228 .map(|((((distinct, items), order_by), skip), limit)| ProjectionBody {
229 distinct,
230 items,
231 order_by,
232 skip,
233 limit,
234 })
235}
236
237pub fn return_clause() -> impl Parser<Token, ProjectionBody, Error = ParserError> + Clone {
239 just(Token::Return)
240 .ignore_then(projection_body())
241 .labelled("return clause")
242}
243
244pub fn with_clause() -> impl Parser<Token, (ProjectionBody, Option<Spanned<Expression>>), Error = ParserError> + Clone
246{
247 just(Token::With)
248 .ignore_then(projection_body())
249 .then(where_clause().or_not())
250 .labelled("with clause")
251}
252
253pub fn reading_clause() -> impl Parser<Token, ReadingClause, Error = ParserError> + Clone {
259 choice((match_clause(), unwind_clause()))
260}
261
262pub fn updating_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
264 choice((
265 create_clause(),
266 merge_clause(),
267 set_clause(),
268 delete_clause(),
269 ))
270}
271
272pub fn standalone_call() -> impl Parser<Token, StandaloneCall, Error = ParserError> + Clone {
278 let procedure_name = ident()
280 .map_with_span(|n, s| (n, s))
281 .then(
282 just(Token::Dot)
283 .ignore_then(ident().map_with_span(|n, s| (n, s)))
284 .repeated(),
285 )
286 .map(|(first, rest)| {
287 let mut full_name = first.0.to_string();
288 let start = first.1.start;
289 let mut end = first.1.end;
290 for part in &rest {
291 full_name.push('.');
292 full_name.push_str(&part.0);
293 end = part.1.end;
294 }
295 (SmolStr::new(&full_name), start..end)
296 });
297
298 let args = expression_parser()
299 .separated_by(just(Token::Comma))
300 .delimited_by(just(Token::LeftParen), just(Token::RightParen))
301 .or_not()
302 .map(|a| a.unwrap_or_default());
303
304 just(Token::Call)
305 .ignore_then(procedure_name)
306 .then(args)
307 .map(|(procedure, args)| StandaloneCall { procedure, args })
308 .labelled("call statement")
309}
310
311pub fn transaction_statement(
317) -> impl Parser<Token, TransactionStatement, Error = ParserError> + Clone {
318 let mode = choice((
319 just(Token::Read)
320 .then_ignore(just(Token::Only))
321 .to(TransactionMode::ReadOnly),
322 just(Token::Read)
323 .then_ignore(just(Token::Write))
324 .to(TransactionMode::ReadWrite),
325 ))
326 .or_not()
327 .map(|m| m.unwrap_or(TransactionMode::ReadWrite));
328
329 let begin = just(Token::Begin)
330 .ignore_then(just(Token::Transaction).or_not())
331 .ignore_then(mode)
332 .map(TransactionStatement::Begin);
333
334 let commit = just(Token::Commit).to(TransactionStatement::Commit);
335 let rollback = just(Token::Rollback).to(TransactionStatement::Rollback);
336
337 choice((begin, commit, rollback)).labelled("transaction statement")
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Lexes `src`, asserting that the lexer reported no errors.
    fn tokens(src: &str) -> Vec<Spanned<Token>> {
        let (lexed, lex_errors) = Lexer::new(src).lex();
        assert!(lex_errors.is_empty());
        lexed
    }

    /// Runs `parser` over the whole of `src` (it must consume every token).
    ///
    /// `Token::Eof` entries are dropped before parsing. Parse errors are
    /// printed to stderr; the parser's output, if any, is returned.
    fn parse_with<T>(parser: impl Parser<Token, T, Error = ParserError>, src: &str) -> Option<T> {
        let significant = tokens(src)
            .into_iter()
            .filter(|(tok, _)| !matches!(tok, Token::Eof));
        let end_of_input = src.len()..src.len() + 1;
        let stream = chumsky::Stream::from_iter(end_of_input, significant);

        let (result, errors) = parser.then_ignore(end()).parse_recovery(stream);
        if !errors.is_empty() {
            eprintln!("parse errors: {errors:?}");
        }
        result
    }

    /// Parses `src` with `match_clause` and unwraps the `Match` variant.
    fn parse_match(src: &str) -> MatchClause {
        match parse_with(match_clause(), src).unwrap() {
            ReadingClause::Match(m) => m,
            _ => panic!("expected match clause"),
        }
    }

    /// Parses `src` with `return_clause`, panicking if parsing fails.
    fn parse_return(src: &str) -> ProjectionBody {
        parse_with(return_clause(), src).unwrap()
    }

    #[test]
    fn simple_match() {
        let m = parse_match("MATCH (n:Person)");
        assert!(!m.is_optional);
        assert_eq!(m.patterns.len(), 1);
    }

    #[test]
    fn optional_match() {
        let m = parse_match("OPTIONAL MATCH (n)");
        assert!(m.is_optional);
    }

    #[test]
    fn match_with_where() {
        let m = parse_match("MATCH (n:Person) WHERE n.age > 30");
        assert!(m.where_clause.is_some());
    }

    #[test]
    fn return_star() {
        let proj = parse_return("RETURN *");
        assert!(matches!(proj.items, ProjectionItems::All));
    }

    #[test]
    fn return_with_alias() {
        let proj = parse_return("RETURN n.name AS name");
        match &proj.items {
            ProjectionItems::Expressions(items) => {
                assert_eq!(items.len(), 1);
                assert!(items[0].1.is_some());
            }
            _ => panic!("expected expressions"),
        }
    }

    #[test]
    fn return_with_order_by() {
        let proj = parse_return("RETURN n ORDER BY n.age DESC");
        assert_eq!(proj.order_by.len(), 1);
        assert_eq!(proj.order_by[0].1, SortOrder::Descending);
    }

    #[test]
    fn return_with_limit_skip() {
        let proj = parse_return("RETURN n SKIP 10 LIMIT 5");
        assert!(proj.skip.is_some());
        assert!(proj.limit.is_some());
    }

    #[test]
    fn create_node() {
        let src = "CREATE (n:Person {name: 'Alice', age: 30})";
        match parse_with(create_clause(), src).unwrap() {
            UpdatingClause::Create(patterns) => assert_eq!(patterns.len(), 1),
            _ => panic!("expected create clause"),
        }
    }

    #[test]
    fn delete_detach() {
        match parse_with(delete_clause(), "DETACH DELETE n").unwrap() {
            UpdatingClause::Delete(d) => assert!(d.detach),
            _ => panic!("expected delete clause"),
        }
    }

    #[test]
    fn set_property() {
        let clause = parse_with(set_clause(), "SET n.age = 31").unwrap();
        assert!(matches!(clause, UpdatingClause::Set(_)));
    }

    #[test]
    fn unwind() {
        let clause = parse_with(unwind_clause(), "UNWIND [1, 2, 3] AS x").unwrap();
        assert!(matches!(clause, ReadingClause::Unwind(_)));
    }

    #[test]
    fn transaction_begin() {
        let stmt = parse_with(transaction_statement(), "BEGIN TRANSACTION READ ONLY").unwrap();
        assert!(matches!(
            stmt,
            TransactionStatement::Begin(TransactionMode::ReadOnly)
        ));
    }

    #[test]
    fn transaction_commit() {
        let stmt = parse_with(transaction_statement(), "COMMIT").unwrap();
        assert!(matches!(stmt, TransactionStatement::Commit));
    }
}