Skip to main content

kyu_parser/parser/
clause.rs

//! Clause parsers: MATCH / OPTIONAL MATCH, UNWIND, WHERE, CREATE, MERGE, SET, [DETACH] DELETE,
//! RETURN / WITH projections, standalone CALL, and transaction statements.
2
3use chumsky::prelude::*;
4use smol_str::SmolStr;
5
6use crate::ast::*;
7use crate::span::Spanned;
8use crate::token::Token;
9
10use super::expression::expression_parser;
11use super::pattern::{ident, pattern_parser};
12
/// Error type shared by every clause parser in this module: chumsky's `Simple` error over lexer
/// [`Token`]s.
type ParserError = Simple<Token>;
14
15// =============================================================================
16// Reading Clauses
17// =============================================================================
18
19/// Parse a MATCH clause: `[OPTIONAL] MATCH pattern [WHERE expr]`
20pub fn match_clause() -> impl Parser<Token, ReadingClause, Error = ParserError> + Clone {
21    let optional = just(Token::Optional).or_not().map(|o| o.is_some());
22
23    optional
24        .then_ignore(just(Token::Match))
25        .then(
26            pattern_parser()
27                .separated_by(just(Token::Comma))
28                .at_least(1)
29                .labelled("match pattern"),
30        )
31        .then(where_clause().or_not())
32        .map(|((is_optional, patterns), where_clause)| {
33            ReadingClause::Match(MatchClause {
34                is_optional,
35                patterns,
36                where_clause,
37            })
38        })
39        .labelled("match clause")
40}
41
42/// Parse a WHERE clause: `WHERE expr`
43fn where_clause() -> impl Parser<Token, Spanned<Expression>, Error = ParserError> + Clone {
44    just(Token::Where)
45        .ignore_then(expression_parser())
46        .labelled("where clause")
47}
48
49/// Parse an UNWIND clause: `UNWIND expr AS alias`
50pub fn unwind_clause() -> impl Parser<Token, ReadingClause, Error = ParserError> + Clone {
51    just(Token::Unwind)
52        .ignore_then(expression_parser())
53        .then_ignore(just(Token::As))
54        .then(ident().map_with_span(|n, s| (n, s)))
55        .map(|(expression, alias)| ReadingClause::Unwind(UnwindClause { expression, alias }))
56        .labelled("unwind clause")
57}
58
59// =============================================================================
60// Updating Clauses
61// =============================================================================
62
63/// Parse a CREATE clause: `CREATE pattern, pattern, ...`
64pub fn create_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
65    just(Token::Create)
66        .ignore_then(
67            pattern_parser()
68                .separated_by(just(Token::Comma))
69                .at_least(1),
70        )
71        .map(UpdatingClause::Create)
72        .labelled("create clause")
73}
74
75/// Parse a SET clause: `SET item, item, ...`
76pub fn set_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
77    just(Token::Set)
78        .ignore_then(set_item().separated_by(just(Token::Comma)).at_least(1))
79        .map(UpdatingClause::Set)
80        .labelled("set clause")
81}
82
83fn set_item() -> impl Parser<Token, SetItem, Error = ParserError> + Clone {
84    // Parse left side as a property chain: n.prop or n.a.b
85    // We can't use expression_parser() here because it would consume `=` as comparison.
86    let property_chain = ident()
87        .map_with_span(|name, span| (Expression::Variable(name), span))
88        .then(
89            just(Token::Dot)
90                .ignore_then(ident().map_with_span(|n, s| (n, s)))
91                .repeated()
92                .at_least(1),
93        )
94        .foldl(|base, key| {
95            let span = base.1.start..key.1.end;
96            (
97                Expression::Property {
98                    object: Box::new(base),
99                    key,
100                },
101                span,
102            )
103        });
104
105    let set_property = property_chain
106        .then_ignore(just(Token::Eq))
107        .then(expression_parser())
108        .map(|(entity, value)| SetItem::Property { entity, value });
109
110    // SET n:Label1:Label2 — setting labels on a node
111    let set_labels = ident()
112        .map_with_span(|n, s| (n, s))
113        .then(
114            just(Token::Colon)
115                .ignore_then(ident().map_with_span(|n, s| (n, s)))
116                .repeated()
117                .at_least(1),
118        )
119        .map(|(entity, labels)| SetItem::Labels { entity, labels });
120
121    set_property.or(set_labels)
122}
123
124/// Parse a DELETE clause: `[DETACH] DELETE expr, expr, ...`
125pub fn delete_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
126    let detach = just(Token::Detach).or_not().map(|d| d.is_some());
127
128    detach
129        .then_ignore(just(Token::Delete))
130        .then(
131            expression_parser()
132                .separated_by(just(Token::Comma))
133                .at_least(1),
134        )
135        .map(|(detach, expressions)| {
136            UpdatingClause::Delete(DeleteClause {
137                detach,
138                expressions,
139            })
140        })
141        .labelled("delete clause")
142}
143
144/// Parse a MERGE clause: `MERGE pattern [ON MATCH SET ...] [ON CREATE SET ...]`
145pub fn merge_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
146    let on_match = just(Token::On)
147        .then_ignore(just(Token::Match))
148        .then_ignore(just(Token::Set))
149        .ignore_then(set_item().separated_by(just(Token::Comma)).at_least(1))
150        .or_not()
151        .map(|o| o.unwrap_or_default());
152
153    let on_create = just(Token::On)
154        .then_ignore(just(Token::Create))
155        .then_ignore(just(Token::Set))
156        .ignore_then(set_item().separated_by(just(Token::Comma)).at_least(1))
157        .or_not()
158        .map(|o| o.unwrap_or_default());
159
160    just(Token::Merge)
161        .ignore_then(pattern_parser())
162        .then(on_match)
163        .then(on_create)
164        .map(|((pattern, on_match), on_create)| {
165            UpdatingClause::Merge(MergeClause {
166                pattern,
167                on_match,
168                on_create,
169            })
170        })
171        .labelled("merge clause")
172}
173
174// =============================================================================
175// Projection: RETURN and WITH
176// =============================================================================
177
178/// Parse a projection body (shared by RETURN and WITH).
179pub fn projection_body() -> impl Parser<Token, ProjectionBody, Error = ParserError> + Clone {
180    let distinct = just(Token::Distinct).or_not().map(|d| d.is_some());
181
182    let alias = just(Token::As)
183        .ignore_then(ident().map_with_span(|n, s| (n, s)))
184        .or_not();
185
186    let item = expression_parser().then(alias);
187
188    let items = just(Token::Star).to(ProjectionItems::All).or(item
189        .separated_by(just(Token::Comma))
190        .at_least(1)
191        .map(ProjectionItems::Expressions));
192
193    let sort_order = choice((
194        just(Token::Asc).to(SortOrder::Ascending),
195        just(Token::Desc).to(SortOrder::Descending),
196    ))
197    .or_not()
198    .map(|o| o.unwrap_or(SortOrder::Ascending));
199
200    let order_by = just(Token::Order)
201        .ignore_then(just(Token::By))
202        .ignore_then(
203            expression_parser()
204                .then(sort_order)
205                .separated_by(just(Token::Comma))
206                .at_least(1),
207        )
208        .or_not()
209        .map(|o| o.unwrap_or_default());
210
211    let skip = just(Token::Skip).ignore_then(expression_parser()).or_not();
212
213    let limit = just(Token::Limit).ignore_then(expression_parser()).or_not();
214
215    distinct
216        .then(items)
217        .then(order_by)
218        .then(skip)
219        .then(limit)
220        .map(
221            |((((distinct, items), order_by), skip), limit)| ProjectionBody {
222                distinct,
223                items,
224                order_by,
225                skip,
226                limit,
227            },
228        )
229}
230
231/// Parse RETURN clause.
232pub fn return_clause() -> impl Parser<Token, ProjectionBody, Error = ParserError> + Clone {
233    just(Token::Return)
234        .ignore_then(projection_body())
235        .labelled("return clause")
236}
237
238/// Parse WITH clause.
239pub fn with_clause()
240-> impl Parser<Token, (ProjectionBody, Option<Spanned<Expression>>), Error = ParserError> + Clone {
241    just(Token::With)
242        .ignore_then(projection_body())
243        .then(where_clause().or_not())
244        .labelled("with clause")
245}
246
247// =============================================================================
248// Reading clause dispatcher
249// =============================================================================
250
251/// Parse any reading clause.
252pub fn reading_clause() -> impl Parser<Token, ReadingClause, Error = ParserError> + Clone {
253    choice((match_clause(), unwind_clause()))
254}
255
256/// Parse any updating clause.
257pub fn updating_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
258    choice((
259        create_clause(),
260        merge_clause(),
261        set_clause(),
262        delete_clause(),
263    ))
264}
265
266// =============================================================================
267// Standalone CALL
268// =============================================================================
269
270/// Parse a standalone CALL statement: `CALL procedure(args...)` or `CALL db.schema`
271pub fn standalone_call() -> impl Parser<Token, StandaloneCall, Error = ParserError> + Clone {
272    // Dotted procedure name: `db.schema` or just `table_info`
273    let procedure_name = ident()
274        .map_with_span(|n, s| (n, s))
275        .then(
276            just(Token::Dot)
277                .ignore_then(ident().map_with_span(|n, s| (n, s)))
278                .repeated(),
279        )
280        .map(|(first, rest)| {
281            let mut full_name = first.0.to_string();
282            let start = first.1.start;
283            let mut end = first.1.end;
284            for part in &rest {
285                full_name.push('.');
286                full_name.push_str(&part.0);
287                end = part.1.end;
288            }
289            (SmolStr::new(&full_name), start..end)
290        });
291
292    let args = expression_parser()
293        .separated_by(just(Token::Comma))
294        .delimited_by(just(Token::LeftParen), just(Token::RightParen))
295        .or_not()
296        .map(|a| a.unwrap_or_default());
297
298    just(Token::Call)
299        .ignore_then(procedure_name)
300        .then(args)
301        .map(|(procedure, args)| StandaloneCall { procedure, args })
302        .labelled("call statement")
303}
304
305// =============================================================================
306// Transaction
307// =============================================================================
308
309/// Parse transaction statements: BEGIN [READ ONLY | READ WRITE], COMMIT, ROLLBACK
310pub fn transaction_statement()
311-> impl Parser<Token, TransactionStatement, Error = ParserError> + Clone {
312    let mode = choice((
313        just(Token::Read)
314            .then_ignore(just(Token::Only))
315            .to(TransactionMode::ReadOnly),
316        just(Token::Read)
317            .then_ignore(just(Token::Write))
318            .to(TransactionMode::ReadWrite),
319    ))
320    .or_not()
321    .map(|m| m.unwrap_or(TransactionMode::ReadWrite));
322
323    let begin = just(Token::Begin)
324        .ignore_then(just(Token::Transaction).or_not())
325        .ignore_then(mode)
326        .map(TransactionStatement::Begin);
327
328    let commit = just(Token::Commit).to(TransactionStatement::Commit);
329    let rollback = just(Token::Rollback).to(TransactionStatement::Rollback);
330
331    choice((begin, commit, rollback)).labelled("transaction statement")
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Lex `src`, failing the test if the lexer reports any error.
    fn tokens(src: &str) -> Vec<Spanned<Token>> {
        let (tokens, errors) = Lexer::new(src).lex();
        assert!(errors.is_empty(), "failed to lex {src:?}");
        tokens
    }

    /// Run `parser` over the whole of `src`; trailing input is an error.
    ///
    /// Panics if the parser reports any error, including errors it recovered from, so a test
    /// cannot pass on a partially-recovered parse. Returns `None` only if no output was
    /// produced at all.
    fn parse_with<T>(parser: impl Parser<Token, T, Error = ParserError>, src: &str) -> Option<T> {
        let toks = tokens(src);
        let len = src.len();
        // The end-of-input span sits just past the last source byte. The lexer's `Eof` token
        // is dropped so that `end()` does not see it as trailing input.
        let stream = chumsky::Stream::from_iter(
            len..len + 1,
            toks.into_iter()
                .filter(|(tok, _)| !matches!(tok, Token::Eof)),
        );
        let (result, errors) = parser.then_ignore(end()).parse_recovery(stream);
        assert!(errors.is_empty(), "parse errors for {src:?}: {errors:?}");
        result
    }

    #[test]
    fn simple_match() {
        let clause = parse_with(match_clause(), "MATCH (n:Person)").unwrap();
        if let ReadingClause::Match(m) = clause {
            assert!(!m.is_optional);
            assert_eq!(m.patterns.len(), 1);
        } else {
            panic!("expected match clause");
        }
    }

    #[test]
    fn match_multiple_patterns() {
        let clause = parse_with(match_clause(), "MATCH (a), (b)").unwrap();
        if let ReadingClause::Match(m) = clause {
            assert_eq!(m.patterns.len(), 2);
        } else {
            panic!("expected match clause");
        }
    }

    #[test]
    fn optional_match() {
        let clause = parse_with(match_clause(), "OPTIONAL MATCH (n)").unwrap();
        if let ReadingClause::Match(m) = clause {
            assert!(m.is_optional);
        } else {
            panic!("expected match clause");
        }
    }

    #[test]
    fn match_with_where() {
        let clause = parse_with(match_clause(), "MATCH (n:Person) WHERE n.age > 30").unwrap();
        if let ReadingClause::Match(m) = clause {
            assert!(m.where_clause.is_some());
        } else {
            panic!("expected match clause");
        }
    }

    #[test]
    fn return_star() {
        let proj = parse_with(return_clause(), "RETURN *").unwrap();
        assert!(matches!(proj.items, ProjectionItems::All));
    }

    #[test]
    fn return_with_alias() {
        let proj = parse_with(return_clause(), "RETURN n.name AS name").unwrap();
        if let ProjectionItems::Expressions(items) = &proj.items {
            assert_eq!(items.len(), 1);
            assert!(items[0].1.is_some());
        } else {
            panic!("expected expressions");
        }
    }

    #[test]
    fn return_with_order_by() {
        let proj = parse_with(return_clause(), "RETURN n ORDER BY n.age DESC").unwrap();
        assert_eq!(proj.order_by.len(), 1);
        assert_eq!(proj.order_by[0].1, SortOrder::Descending);
    }

    #[test]
    fn return_with_limit_skip() {
        let proj = parse_with(return_clause(), "RETURN n SKIP 10 LIMIT 5").unwrap();
        assert!(proj.skip.is_some());
        assert!(proj.limit.is_some());
    }

    #[test]
    fn with_and_where() {
        let (proj, predicate) = parse_with(with_clause(), "WITH n WHERE n.age > 30").unwrap();
        assert!(!proj.distinct);
        assert!(predicate.is_some());
    }

    #[test]
    fn create_node() {
        let clause = parse_with(
            create_clause(),
            "CREATE (n:Person {name: 'Alice', age: 30})",
        )
        .unwrap();
        if let UpdatingClause::Create(patterns) = clause {
            assert_eq!(patterns.len(), 1);
        } else {
            panic!("expected create clause");
        }
    }

    #[test]
    fn merge_with_actions() {
        let clause = parse_with(
            merge_clause(),
            "MERGE (n:Person) ON MATCH SET n.seen = 1 ON CREATE SET n.created = 1",
        )
        .unwrap();
        if let UpdatingClause::Merge(m) = clause {
            assert_eq!(m.on_match.len(), 1);
            assert_eq!(m.on_create.len(), 1);
        } else {
            panic!("expected merge clause");
        }
    }

    #[test]
    fn delete_detach() {
        let clause = parse_with(delete_clause(), "DETACH DELETE n").unwrap();
        if let UpdatingClause::Delete(d) = clause {
            assert!(d.detach);
        } else {
            panic!("expected delete clause");
        }
    }

    #[test]
    fn set_property() {
        let clause = parse_with(set_clause(), "SET n.age = 31").unwrap();
        assert!(matches!(clause, UpdatingClause::Set(_)));
    }

    #[test]
    fn set_labels() {
        let clause = parse_with(set_clause(), "SET n:Admin:Employee").unwrap();
        if let UpdatingClause::Set(items) = clause {
            assert!(matches!(
                &items[..],
                [SetItem::Labels { labels, .. }] if labels.len() == 2
            ));
        } else {
            panic!("expected set clause");
        }
    }

    #[test]
    fn unwind() {
        let clause = parse_with(unwind_clause(), "UNWIND [1, 2, 3] AS x").unwrap();
        assert!(matches!(clause, ReadingClause::Unwind(_)));
    }

    #[test]
    fn call_dotted_procedure() {
        let call = parse_with(standalone_call(), "CALL db.schema").unwrap();
        assert_eq!(call.procedure.0, "db.schema");
        assert!(call.args.is_empty());
    }

    #[test]
    fn transaction_begin() {
        let stmt = parse_with(transaction_statement(), "BEGIN TRANSACTION READ ONLY").unwrap();
        assert!(matches!(
            stmt,
            TransactionStatement::Begin(TransactionMode::ReadOnly)
        ));
    }

    #[test]
    fn transaction_begin_defaults_to_read_write() {
        let stmt = parse_with(transaction_statement(), "BEGIN TRANSACTION").unwrap();
        assert!(matches!(
            stmt,
            TransactionStatement::Begin(TransactionMode::ReadWrite)
        ));
    }

    #[test]
    fn transaction_commit() {
        let stmt = parse_with(transaction_statement(), "COMMIT").unwrap();
        assert!(matches!(stmt, TransactionStatement::Commit));
    }

    #[test]
    fn transaction_rollback() {
        let stmt = parse_with(transaction_statement(), "ROLLBACK").unwrap();
        assert!(matches!(stmt, TransactionStatement::Rollback));
    }
}