Skip to main content

kyu_parser/parser/
clause.rs

1//! MATCH, RETURN, WITH, WHERE, CREATE, SET, DELETE and other clause parsers.
2
3use chumsky::prelude::*;
4use smol_str::SmolStr;
5
6use crate::ast::*;
7use crate::span::Spanned;
8use crate::token::Token;
9
10use super::expression::expression_parser;
11use super::pattern::{ident, pattern_parser};
12
13type ParserError = Simple<Token>;
14
15// =============================================================================
16// Reading Clauses
17// =============================================================================
18
19/// Parse a MATCH clause: `[OPTIONAL] MATCH pattern [WHERE expr]`
20pub fn match_clause() -> impl Parser<Token, ReadingClause, Error = ParserError> + Clone {
21    let optional = just(Token::Optional).or_not().map(|o| o.is_some());
22
23    optional
24        .then_ignore(just(Token::Match))
25        .then(
26            pattern_parser()
27                .separated_by(just(Token::Comma))
28                .at_least(1)
29                .labelled("match pattern"),
30        )
31        .then(where_clause().or_not())
32        .map(|((is_optional, patterns), where_clause)| {
33            ReadingClause::Match(MatchClause {
34                is_optional,
35                patterns,
36                where_clause,
37            })
38        })
39        .labelled("match clause")
40}
41
42/// Parse a WHERE clause: `WHERE expr`
43fn where_clause() -> impl Parser<Token, Spanned<Expression>, Error = ParserError> + Clone {
44    just(Token::Where)
45        .ignore_then(expression_parser())
46        .labelled("where clause")
47}
48
49/// Parse an UNWIND clause: `UNWIND expr AS alias`
50pub fn unwind_clause() -> impl Parser<Token, ReadingClause, Error = ParserError> + Clone {
51    just(Token::Unwind)
52        .ignore_then(expression_parser())
53        .then_ignore(just(Token::As))
54        .then(ident().map_with_span(|n, s| (n, s)))
55        .map(|(expression, alias)| ReadingClause::Unwind(UnwindClause { expression, alias }))
56        .labelled("unwind clause")
57}
58
59// =============================================================================
60// Updating Clauses
61// =============================================================================
62
63/// Parse a CREATE clause: `CREATE pattern, pattern, ...`
64pub fn create_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
65    just(Token::Create)
66        .ignore_then(
67            pattern_parser()
68                .separated_by(just(Token::Comma))
69                .at_least(1),
70        )
71        .map(UpdatingClause::Create)
72        .labelled("create clause")
73}
74
75/// Parse a SET clause: `SET item, item, ...`
76pub fn set_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
77    just(Token::Set)
78        .ignore_then(set_item().separated_by(just(Token::Comma)).at_least(1))
79        .map(UpdatingClause::Set)
80        .labelled("set clause")
81}
82
83fn set_item() -> impl Parser<Token, SetItem, Error = ParserError> + Clone {
84    // Parse left side as a property chain: n.prop or n.a.b
85    // We can't use expression_parser() here because it would consume `=` as comparison.
86    let property_chain = ident()
87        .map_with_span(|name, span| {
88            (Expression::Variable(name), span)
89        })
90        .then(
91            just(Token::Dot)
92                .ignore_then(ident().map_with_span(|n, s| (n, s)))
93                .repeated()
94                .at_least(1),
95        )
96        .foldl(|base, key| {
97            let span = base.1.start..key.1.end;
98            (
99                Expression::Property {
100                    object: Box::new(base),
101                    key,
102                },
103                span,
104            )
105        });
106
107    let set_property = property_chain
108        .then_ignore(just(Token::Eq))
109        .then(expression_parser())
110        .map(|(entity, value)| SetItem::Property { entity, value });
111
112    // SET n:Label1:Label2 — setting labels on a node
113    let set_labels = ident()
114        .map_with_span(|n, s| (n, s))
115        .then(
116            just(Token::Colon)
117                .ignore_then(ident().map_with_span(|n, s| (n, s)))
118                .repeated()
119                .at_least(1),
120        )
121        .map(|(entity, labels)| SetItem::Labels { entity, labels });
122
123    set_property.or(set_labels)
124}
125
126/// Parse a DELETE clause: `[DETACH] DELETE expr, expr, ...`
127pub fn delete_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
128    let detach = just(Token::Detach).or_not().map(|d| d.is_some());
129
130    detach
131        .then_ignore(just(Token::Delete))
132        .then(
133            expression_parser()
134                .separated_by(just(Token::Comma))
135                .at_least(1),
136        )
137        .map(|(detach, expressions)| {
138            UpdatingClause::Delete(DeleteClause {
139                detach,
140                expressions,
141            })
142        })
143        .labelled("delete clause")
144}
145
146/// Parse a MERGE clause: `MERGE pattern [ON MATCH SET ...] [ON CREATE SET ...]`
147pub fn merge_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
148    let on_match = just(Token::On)
149        .then_ignore(just(Token::Match))
150        .then_ignore(just(Token::Set))
151        .ignore_then(set_item().separated_by(just(Token::Comma)).at_least(1))
152        .or_not()
153        .map(|o| o.unwrap_or_default());
154
155    let on_create = just(Token::On)
156        .then_ignore(just(Token::Create))
157        .then_ignore(just(Token::Set))
158        .ignore_then(set_item().separated_by(just(Token::Comma)).at_least(1))
159        .or_not()
160        .map(|o| o.unwrap_or_default());
161
162    just(Token::Merge)
163        .ignore_then(pattern_parser())
164        .then(on_match)
165        .then(on_create)
166        .map(|((pattern, on_match), on_create)| {
167            UpdatingClause::Merge(MergeClause {
168                pattern,
169                on_match,
170                on_create,
171            })
172        })
173        .labelled("merge clause")
174}
175
176// =============================================================================
177// Projection: RETURN and WITH
178// =============================================================================
179
180/// Parse a projection body (shared by RETURN and WITH).
181pub fn projection_body() -> impl Parser<Token, ProjectionBody, Error = ParserError> + Clone {
182    let distinct = just(Token::Distinct).or_not().map(|d| d.is_some());
183
184    let alias = just(Token::As)
185        .ignore_then(ident().map_with_span(|n, s| (n, s)))
186        .or_not();
187
188    let item = expression_parser().then(alias);
189
190    let items = just(Token::Star)
191        .to(ProjectionItems::All)
192        .or(item
193            .separated_by(just(Token::Comma))
194            .at_least(1)
195            .map(ProjectionItems::Expressions));
196
197    let sort_order = choice((
198        just(Token::Asc).to(SortOrder::Ascending),
199        just(Token::Desc).to(SortOrder::Descending),
200    ))
201    .or_not()
202    .map(|o| o.unwrap_or(SortOrder::Ascending));
203
204    let order_by = just(Token::Order)
205        .ignore_then(just(Token::By))
206        .ignore_then(
207            expression_parser()
208                .then(sort_order)
209                .separated_by(just(Token::Comma))
210                .at_least(1),
211        )
212        .or_not()
213        .map(|o| o.unwrap_or_default());
214
215    let skip = just(Token::Skip)
216        .ignore_then(expression_parser())
217        .or_not();
218
219    let limit = just(Token::Limit)
220        .ignore_then(expression_parser())
221        .or_not();
222
223    distinct
224        .then(items)
225        .then(order_by)
226        .then(skip)
227        .then(limit)
228        .map(|((((distinct, items), order_by), skip), limit)| ProjectionBody {
229            distinct,
230            items,
231            order_by,
232            skip,
233            limit,
234        })
235}
236
237/// Parse RETURN clause.
238pub fn return_clause() -> impl Parser<Token, ProjectionBody, Error = ParserError> + Clone {
239    just(Token::Return)
240        .ignore_then(projection_body())
241        .labelled("return clause")
242}
243
244/// Parse WITH clause.
245pub fn with_clause() -> impl Parser<Token, (ProjectionBody, Option<Spanned<Expression>>), Error = ParserError> + Clone
246{
247    just(Token::With)
248        .ignore_then(projection_body())
249        .then(where_clause().or_not())
250        .labelled("with clause")
251}
252
253// =============================================================================
254// Reading clause dispatcher
255// =============================================================================
256
257/// Parse any reading clause.
258pub fn reading_clause() -> impl Parser<Token, ReadingClause, Error = ParserError> + Clone {
259    choice((match_clause(), unwind_clause()))
260}
261
262/// Parse any updating clause.
263pub fn updating_clause() -> impl Parser<Token, UpdatingClause, Error = ParserError> + Clone {
264    choice((
265        create_clause(),
266        merge_clause(),
267        set_clause(),
268        delete_clause(),
269    ))
270}
271
272// =============================================================================
273// Standalone CALL
274// =============================================================================
275
276/// Parse a standalone CALL statement: `CALL procedure(args...)` or `CALL db.schema`
277pub fn standalone_call() -> impl Parser<Token, StandaloneCall, Error = ParserError> + Clone {
278    // Dotted procedure name: `db.schema` or just `table_info`
279    let procedure_name = ident()
280        .map_with_span(|n, s| (n, s))
281        .then(
282            just(Token::Dot)
283                .ignore_then(ident().map_with_span(|n, s| (n, s)))
284                .repeated(),
285        )
286        .map(|(first, rest)| {
287            let mut full_name = first.0.to_string();
288            let start = first.1.start;
289            let mut end = first.1.end;
290            for part in &rest {
291                full_name.push('.');
292                full_name.push_str(&part.0);
293                end = part.1.end;
294            }
295            (SmolStr::new(&full_name), start..end)
296        });
297
298    let args = expression_parser()
299        .separated_by(just(Token::Comma))
300        .delimited_by(just(Token::LeftParen), just(Token::RightParen))
301        .or_not()
302        .map(|a| a.unwrap_or_default());
303
304    just(Token::Call)
305        .ignore_then(procedure_name)
306        .then(args)
307        .map(|(procedure, args)| StandaloneCall { procedure, args })
308        .labelled("call statement")
309}
310
311// =============================================================================
312// Transaction
313// =============================================================================
314
315/// Parse transaction statements: BEGIN [READ ONLY | READ WRITE], COMMIT, ROLLBACK
316pub fn transaction_statement(
317) -> impl Parser<Token, TransactionStatement, Error = ParserError> + Clone {
318    let mode = choice((
319        just(Token::Read)
320            .then_ignore(just(Token::Only))
321            .to(TransactionMode::ReadOnly),
322        just(Token::Read)
323            .then_ignore(just(Token::Write))
324            .to(TransactionMode::ReadWrite),
325    ))
326    .or_not()
327    .map(|m| m.unwrap_or(TransactionMode::ReadWrite));
328
329    let begin = just(Token::Begin)
330        .ignore_then(just(Token::Transaction).or_not())
331        .ignore_then(mode)
332        .map(TransactionStatement::Begin);
333
334    let commit = just(Token::Commit).to(TransactionStatement::Commit);
335    let rollback = just(Token::Rollback).to(TransactionStatement::Rollback);
336
337    choice((begin, commit, rollback)).labelled("transaction statement")
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Lexes `src`, asserting that the lexer reported no errors.
    fn tokens(src: &str) -> Vec<Spanned<Token>> {
        let (lexed, lex_errors) = Lexer::new(src).lex();
        assert!(lex_errors.is_empty());
        lexed
    }

    /// Runs `parser` over all of `src` (end of input is required) and returns the
    /// parsed value, if any. Parse errors are printed to stderr.
    fn parse_with<T>(
        parser: impl Parser<Token, T, Error = ParserError>,
        src: &str,
    ) -> Option<T> {
        // The end-of-input span sits just past the last byte of the source.
        let eoi = src.len()..src.len() + 1;
        // Drop any `Eof` tokens so that `end()` sees the end of the stream.
        let significant = tokens(src)
            .into_iter()
            .filter(|(tok, _)| !matches!(tok, Token::Eof));
        let stream = chumsky::Stream::from_iter(eoi, significant);

        let (result, errors) = parser.then_ignore(end()).parse_recovery(stream);
        if !errors.is_empty() {
            eprintln!("parse errors: {errors:?}");
        }
        result
    }

    #[test]
    fn simple_match() {
        match parse_with(match_clause(), "MATCH (n:Person)").unwrap() {
            ReadingClause::Match(m) => {
                assert!(!m.is_optional);
                assert_eq!(m.patterns.len(), 1);
            }
            _ => panic!("expected match clause"),
        }
    }

    #[test]
    fn optional_match() {
        match parse_with(match_clause(), "OPTIONAL MATCH (n)").unwrap() {
            ReadingClause::Match(m) => assert!(m.is_optional),
            _ => panic!("expected match clause"),
        }
    }

    #[test]
    fn match_with_where() {
        match parse_with(match_clause(), "MATCH (n:Person) WHERE n.age > 30").unwrap() {
            ReadingClause::Match(m) => assert!(m.where_clause.is_some()),
            _ => panic!("expected match clause"),
        }
    }

    #[test]
    fn return_star() {
        let body = parse_with(return_clause(), "RETURN *").unwrap();
        assert!(matches!(body.items, ProjectionItems::All));
    }

    #[test]
    fn return_with_alias() {
        let body = parse_with(return_clause(), "RETURN n.name AS name").unwrap();
        match &body.items {
            ProjectionItems::Expressions(items) => {
                assert_eq!(items.len(), 1);
                assert!(items[0].1.is_some());
            }
            _ => panic!("expected expressions"),
        }
    }

    #[test]
    fn return_with_order_by() {
        let body = parse_with(return_clause(), "RETURN n ORDER BY n.age DESC").unwrap();
        assert_eq!(body.order_by.len(), 1);
        assert_eq!(body.order_by[0].1, SortOrder::Descending);
    }

    #[test]
    fn return_with_limit_skip() {
        let body = parse_with(return_clause(), "RETURN n SKIP 10 LIMIT 5").unwrap();
        assert!(body.skip.is_some());
        assert!(body.limit.is_some());
    }

    #[test]
    fn create_node() {
        let source = "CREATE (n:Person {name: 'Alice', age: 30})";
        match parse_with(create_clause(), source).unwrap() {
            UpdatingClause::Create(patterns) => assert_eq!(patterns.len(), 1),
            _ => panic!("expected create clause"),
        }
    }

    #[test]
    fn delete_detach() {
        match parse_with(delete_clause(), "DETACH DELETE n").unwrap() {
            UpdatingClause::Delete(d) => assert!(d.detach),
            _ => panic!("expected delete clause"),
        }
    }

    #[test]
    fn set_property() {
        let parsed = parse_with(set_clause(), "SET n.age = 31").unwrap();
        assert!(matches!(parsed, UpdatingClause::Set(_)));
    }

    #[test]
    fn unwind() {
        let parsed = parse_with(unwind_clause(), "UNWIND [1, 2, 3] AS x").unwrap();
        assert!(matches!(parsed, ReadingClause::Unwind(_)));
    }

    #[test]
    fn transaction_begin() {
        let parsed = parse_with(transaction_statement(), "BEGIN TRANSACTION READ ONLY").unwrap();
        assert!(matches!(
            parsed,
            TransactionStatement::Begin(TransactionMode::ReadOnly)
        ));
    }

    #[test]
    fn transaction_commit() {
        let parsed = parse_with(transaction_statement(), "COMMIT").unwrap();
        assert!(matches!(parsed, TransactionStatement::Commit));
    }
}
481}