Skip to main content

kyu_parser/parser/
pattern.rs

//! Node and relationship pattern parsers.
2
3use chumsky::prelude::*;
4use smol_str::SmolStr;
5
6use crate::ast::*;
7use crate::span::Spanned;
8use crate::token::Token;
9
10use super::expression::expression_parser;
11
12type ParserError = Simple<Token>;
13
14/// Parse an identifier (unquoted, escaped, or keyword used in identifier position).
15///
16/// Cypher allows most keywords to be used as identifiers in non-ambiguous contexts
17/// (variable names, labels, property names, map keys, relationship types).
18///
19/// Uses `filter_map` instead of `select!` to avoid monomorphization bloat from 70+ arms,
20/// which otherwise causes linker failures from excessively long symbol names.
21pub fn ident() -> impl Parser<Token, SmolStr, Error = ParserError> + Clone {
22    filter_map(|span, token: Token| {
23        match token_to_ident(&token) {
24            Some(name) => Ok(name),
25            None => Err(Simple::expected_input_found(span, [], Some(token))),
26        }
27    })
28}
29
30/// Convert a token to an identifier name if it can be used in identifier position.
31fn token_to_ident(token: &Token) -> Option<SmolStr> {
32    match token {
33        Token::Ident(name) | Token::EscapedIdent(name) => Some(name.clone()),
34        // DDL keywords
35        Token::Node => Some(SmolStr::new("NODE")),
36        Token::Rel => Some(SmolStr::new("REL")),
37        Token::Table => Some(SmolStr::new("TABLE")),
38        Token::Group => Some(SmolStr::new("GROUP")),
39        Token::Rdf => Some(SmolStr::new("RDF")),
40        Token::Graph => Some(SmolStr::new("GRAPH")),
41        Token::From => Some(SmolStr::new("FROM")),
42        Token::To => Some(SmolStr::new("TO")),
43        Token::Primary => Some(SmolStr::new("PRIMARY")),
44        Token::Key => Some(SmolStr::new("KEY")),
45        Token::Add => Some(SmolStr::new("ADD")),
46        Token::Column => Some(SmolStr::new("COLUMN")),
47        Token::Rename => Some(SmolStr::new("RENAME")),
48        Token::Comment => Some(SmolStr::new("COMMENT")),
49        Token::Default => Some(SmolStr::new("DEFAULT")),
50        Token::Copy => Some(SmolStr::new("COPY")),
51        Token::Load => Some(SmolStr::new("LOAD")),
52        Token::Attach => Some(SmolStr::new("ATTACH")),
53        Token::Use => Some(SmolStr::new("USE")),
54        Token::Database => Some(SmolStr::new("DATABASE")),
55        Token::Export => Some(SmolStr::new("EXPORT")),
56        Token::Import => Some(SmolStr::new("IMPORT")),
57        Token::Install => Some(SmolStr::new("INSTALL")),
58        Token::Extension => Some(SmolStr::new("EXTENSION")),
59        // Transaction keywords
60        Token::Begin => Some(SmolStr::new("BEGIN")),
61        Token::Commit => Some(SmolStr::new("COMMIT")),
62        Token::Rollback => Some(SmolStr::new("ROLLBACK")),
63        Token::Transaction => Some(SmolStr::new("TRANSACTION")),
64        Token::Read => Some(SmolStr::new("READ")),
65        Token::Write => Some(SmolStr::new("WRITE")),
66        Token::Only => Some(SmolStr::new("ONLY")),
67        // Type keywords
68        Token::ListType => Some(SmolStr::new("LIST")),
69        Token::MapType => Some(SmolStr::new("MAP")),
70        Token::StructType => Some(SmolStr::new("STRUCT")),
71        Token::UnionType => Some(SmolStr::new("UNION")),
72        Token::BoolType => Some(SmolStr::new("BOOL")),
73        Token::StringType => Some(SmolStr::new("STRING")),
74        Token::DateType => Some(SmolStr::new("DATE")),
75        Token::TimestampType => Some(SmolStr::new("TIMESTAMP")),
76        Token::IntervalType => Some(SmolStr::new("INTERVAL")),
77        Token::BlobType => Some(SmolStr::new("BLOB")),
78        Token::UuidType => Some(SmolStr::new("UUID")),
79        Token::SerialType => Some(SmolStr::new("SERIAL")),
80        Token::FloatType => Some(SmolStr::new("FLOAT")),
81        Token::DoubleType => Some(SmolStr::new("DOUBLE")),
82        Token::Int8Type => Some(SmolStr::new("INT8")),
83        Token::Int16Type => Some(SmolStr::new("INT16")),
84        Token::Int32Type => Some(SmolStr::new("INT32")),
85        Token::Int64Type => Some(SmolStr::new("INT64")),
86        Token::Int128Type => Some(SmolStr::new("INT128")),
87        Token::UInt8Type => Some(SmolStr::new("UINT8")),
88        Token::UInt16Type => Some(SmolStr::new("UINT16")),
89        Token::UInt32Type => Some(SmolStr::new("UINT32")),
90        Token::UInt64Type => Some(SmolStr::new("UINT64")),
91        // Cypher keywords usable as identifiers
92        Token::Count => Some(SmolStr::new("count")),
93        Token::Exists => Some(SmolStr::new("exists")),
94        Token::All => Some(SmolStr::new("all")),
95        Token::Any => Some(SmolStr::new("any")),
96        Token::Single => Some(SmolStr::new("single")),
97        Token::None => Some(SmolStr::new("none")),
98        Token::On => Some(SmolStr::new("ON")),
99        Token::Yield => Some(SmolStr::new("YIELD")),
100        Token::End => Some(SmolStr::new("END")),
101        Token::Call => Some(SmolStr::new("CALL")),
102        Token::If => Some(SmolStr::new("IF")),
103        Token::Macro => Some(SmolStr::new("MACRO")),
104        Token::Shortest => Some(SmolStr::new("SHORTEST")),
105        Token::Asc => Some(SmolStr::new("ASC")),
106        Token::Desc => Some(SmolStr::new("DESC")),
107        Token::In => Some(SmolStr::new("IN")),
108        Token::Is => Some(SmolStr::new("IS")),
109        Token::Contains => Some(SmolStr::new("CONTAINS")),
110        Token::Starts => Some(SmolStr::new("STARTS")),
111        Token::Ends => Some(SmolStr::new("ENDS")),
112        Token::Union => Some(SmolStr::new("UNION")),
113        Token::Drop => Some(SmolStr::new("DROP")),
114        Token::Alter => Some(SmolStr::new("ALTER")),
115        Token::Remove => Some(SmolStr::new("REMOVE")),
116        Token::Profile => Some(SmolStr::new("PROFILE")),
117        Token::Explain => Some(SmolStr::new("EXPLAIN")),
118        _ => Option::None,
119    }
120}
121
122/// Parse one or more node labels: `:Label1:Label2`
123fn node_labels() -> impl Parser<Token, Vec<Spanned<SmolStr>>, Error = ParserError> + Clone {
124    just(Token::Colon)
125        .ignore_then(ident().map_with_span(|n, s| (n, s)))
126        .repeated()
127        .at_least(1)
128}
129
130/// Parse map properties: `{key: expr, key: expr, ...}`
131pub fn map_properties(
132) -> impl Parser<Token, Vec<(Spanned<SmolStr>, Spanned<Expression>)>, Error = ParserError> + Clone
133{
134    let entry = ident()
135        .map_with_span(|n, s| (n, s))
136        .then_ignore(just(Token::Colon))
137        .then(expression_parser());
138
139    entry
140        .separated_by(just(Token::Comma))
141        .allow_trailing()
142        .delimited_by(just(Token::LeftBrace), just(Token::RightBrace))
143}
144
145/// Parse a node pattern: `(variable:Label {props})`
146pub fn node_pattern() -> impl Parser<Token, NodePattern, Error = ParserError> + Clone {
147    let variable = ident().map_with_span(|n, s| (n, s)).or_not();
148    let labels = node_labels().or_not().map(|l| l.unwrap_or_default());
149    let props = map_properties().or_not();
150
151    variable
152        .then(labels)
153        .then(props)
154        .delimited_by(just(Token::LeftParen), just(Token::RightParen))
155        .map_with_span(|((variable, labels), properties), span| NodePattern {
156            variable,
157            labels,
158            properties,
159            span,
160        })
161        .labelled("node pattern")
162}
163
164/// Parse a relationship pattern including direction arrows.
165///
166/// Handles:
167/// - `-[r:TYPE]->` (right)
168/// - `<-[r:TYPE]-` (left)
169/// - `-[r:TYPE]-`  (both)
170type RelDetail = (
171    Option<Spanned<SmolStr>>,
172    Vec<Spanned<SmolStr>>,
173    Option<(Option<u32>, Option<u32>)>,
174    Option<Vec<(Spanned<SmolStr>, Spanned<Expression>)>>,
175);
176
177fn relationship_detail() -> impl Parser<Token, RelDetail, Error = ParserError> + Clone {
178    let variable = ident().map_with_span(|n, s| (n, s)).or_not();
179
180    // Relationship types: `:TYPE1|TYPE2` or `:TYPE1|:TYPE2` (TCK allows colon after pipe)
181    let rel_types = just(Token::Colon)
182        .ignore_then(
183            ident()
184                .map_with_span(|n, s| (n, s))
185                .then(
186                    just(Token::Pipe)
187                        .ignore_then(just(Token::Colon).or_not())
188                        .ignore_then(ident().map_with_span(|n, s| (n, s)))
189                        .repeated(),
190                )
191                .map(|(first, rest)| {
192                    let mut types = vec![first];
193                    types.extend(rest);
194                    types
195                }),
196        )
197        .or_not()
198        .map(|t| t.unwrap_or_default());
199
200    // Variable-length: *min..max
201    let range = just(Token::Star)
202        .ignore_then(
203            select! { Token::Integer(n) => n as u32 }
204                .or_not()
205                .then(
206                    just(Token::DoubleDot)
207                        .ignore_then(select! { Token::Integer(n) => n as u32 }.or_not())
208                        .or_not(),
209                ),
210        )
211        .map(|(min, max_opt)| match max_opt {
212            Some(max) => (min, max),
213            None => (min, min),
214        })
215        .or_not();
216
217    let props = map_properties().or_not();
218
219    variable
220        .then(rel_types)
221        .then(range)
222        .then(props)
223        .delimited_by(just(Token::LeftBracket), just(Token::RightBracket))
224        .map(|(((variable, rel_types), range), properties)| {
225            (variable, rel_types, range, properties)
226        })
227}
228
229/// Parse a complete relationship segment with arrows/dashes.
230fn relationship_pattern() -> impl Parser<Token, RelationshipPattern, Error = ParserError> + Clone {
231    // Case 1: <-[...]-  (left arrow)
232    let left = just(Token::LeftArrow)
233        .ignore_then(relationship_detail())
234        .then_ignore(just(Token::Dash))
235        .map_with_span(|detail: RelDetail, span| {
236            let (variable, rel_types, range, properties) = detail;
237            RelationshipPattern { variable, rel_types, direction: Direction::Left, range, properties, span }
238        });
239
240    // Case 2: -[...]->  (right arrow)
241    let right = just(Token::Dash)
242        .ignore_then(relationship_detail())
243        .then_ignore(just(Token::Arrow))
244        .map_with_span(|detail: RelDetail, span| {
245            let (variable, rel_types, range, properties) = detail;
246            RelationshipPattern { variable, rel_types, direction: Direction::Right, range, properties, span }
247        });
248
249    // Case 3: -[...]-  (both directions / undirected)
250    let both = just(Token::Dash)
251        .ignore_then(relationship_detail())
252        .then_ignore(just(Token::Dash))
253        .map_with_span(|detail: RelDetail, span| {
254            let (variable, rel_types, range, properties) = detail;
255            RelationshipPattern { variable, rel_types, direction: Direction::Both, range, properties, span }
256        });
257
258    // Case 4: simple dashes without brackets: --> or <-- or --
259    let simple_right = just(Token::Dash)
260        .then_ignore(just(Token::Arrow))
261        .map_with_span(|_, span| RelationshipPattern {
262            variable: None,
263            rel_types: vec![],
264            direction: Direction::Right,
265            range: None,
266            properties: None,
267            span,
268        });
269
270    let simple_left = just(Token::LeftArrow)
271        .then_ignore(just(Token::Dash))
272        .map_with_span(|_, span| RelationshipPattern {
273            variable: None,
274            rel_types: vec![],
275            direction: Direction::Left,
276            range: None,
277            properties: None,
278            span,
279        });
280
281    // Case 6: simple undirected `--` without brackets
282    let simple_both = just(Token::Dash)
283        .then_ignore(just(Token::Dash))
284        .map_with_span(|_, span| RelationshipPattern {
285            variable: None,
286            rel_types: vec![],
287            direction: Direction::Both,
288            range: None,
289            properties: None,
290            span,
291        });
292
293    // Case 7: `<-->` shorthand (left-arrow followed by right-arrow)
294    let simple_both_arrow = just(Token::LeftArrow)
295        .then_ignore(just(Token::Arrow))
296        .map_with_span(|_, span| RelationshipPattern {
297            variable: None,
298            rel_types: vec![],
299            direction: Direction::Both,
300            range: None,
301            properties: None,
302            span,
303        });
304
305    choice((left, right, both, simple_right, simple_left, simple_both_arrow, simple_both))
306        .labelled("relationship pattern")
307}
308
309/// Parse a chain of node-rel-node-rel-... pattern elements.
310pub fn pattern_element_chain() -> impl Parser<Token, Vec<PatternElement>, Error = ParserError> + Clone
311{
312    node_pattern()
313        .map(PatternElement::Node)
314        .then(
315            relationship_pattern()
316                .map(PatternElement::Relationship)
317                .then(node_pattern().map(PatternElement::Node))
318                .repeated(),
319        )
320        .map(|(first, rest)| {
321            let mut elements = vec![first];
322            for (rel, node) in rest {
323                elements.push(rel);
324                elements.push(node);
325            }
326            elements
327        })
328}
329
330/// Parse a full pattern with optional variable assignment: `p = (a)-[:R]->(b)`
331pub fn pattern_parser() -> impl Parser<Token, Pattern, Error = ParserError> + Clone {
332    let variable_assignment = ident()
333        .map_with_span(|n, s| (n, s))
334        .then_ignore(just(Token::Eq))
335        .or_not();
336
337    variable_assignment
338        .then(pattern_element_chain())
339        .map(|(variable, elements)| Pattern { variable, elements })
340        .labelled("pattern")
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Lex and parse `src` as one complete pattern.
    ///
    /// Returns `None` if parsing fails *or* reports any error, so a result produced only
    /// through error recovery can never make a test pass.
    fn parse_pattern_from(src: &str) -> Option<Pattern> {
        let (tokens, lex_errors) = Lexer::new(src).lex();
        assert!(lex_errors.is_empty(), "lex errors: {lex_errors:?}");

        let len = src.len();
        let stream = chumsky::Stream::from_iter(
            len..len + 1,
            tokens
                .into_iter()
                .filter(|(tok, _)| !matches!(tok, Token::Eof)),
        );

        let (result, errors) = pattern_parser().then_ignore(end()).parse_recovery(stream);
        if !errors.is_empty() {
            eprintln!("parse errors: {errors:?}");
            return None;
        }
        result
    }

    /// Return the relationship at `elements[index]`, panicking if it is not one.
    fn relationship_at(p: &Pattern, index: usize) -> &RelationshipPattern {
        match &p.elements[index] {
            PatternElement::Relationship(rel) => rel,
            _ => panic!("expected relationship at element {index}"),
        }
    }

    #[test]
    fn simple_node() {
        let p = parse_pattern_from("(n)").unwrap();
        assert_eq!(p.elements.len(), 1);
        if let PatternElement::Node(node) = &p.elements[0] {
            assert_eq!(node.variable.as_ref().unwrap().0.as_str(), "n");
            assert!(node.labels.is_empty());
        } else {
            panic!("expected node");
        }
    }

    #[test]
    fn node_with_label() {
        let p = parse_pattern_from("(n:Person)").unwrap();
        if let PatternElement::Node(node) = &p.elements[0] {
            assert_eq!(node.labels.len(), 1);
            assert_eq!(node.labels[0].0.as_str(), "Person");
        } else {
            panic!("expected node");
        }
    }

    #[test]
    fn node_with_multiple_labels() {
        let p = parse_pattern_from("(n:Person:Employee)").unwrap();
        if let PatternElement::Node(node) = &p.elements[0] {
            assert_eq!(node.labels.len(), 2);
        } else {
            panic!("expected node");
        }
    }

    #[test]
    fn node_with_properties() {
        let p = parse_pattern_from("(n:Person {name: 'Alice', age: 30})").unwrap();
        if let PatternElement::Node(node) = &p.elements[0] {
            assert!(node.properties.is_some());
            assert_eq!(node.properties.as_ref().unwrap().len(), 2);
        } else {
            panic!("expected node");
        }
    }

    #[test]
    fn right_relationship() {
        let p = parse_pattern_from("(a)-[:KNOWS]->(b)").unwrap();
        assert_eq!(p.elements.len(), 3);
        if let PatternElement::Relationship(rel) = &p.elements[1] {
            assert_eq!(rel.direction, Direction::Right);
            assert_eq!(rel.rel_types.len(), 1);
            assert_eq!(rel.rel_types[0].0.as_str(), "KNOWS");
        } else {
            panic!("expected relationship");
        }
    }

    #[test]
    fn left_relationship() {
        let p = parse_pattern_from("(a)<-[:LIKES]-(b)").unwrap();
        if let PatternElement::Relationship(rel) = &p.elements[1] {
            assert_eq!(rel.direction, Direction::Left);
        } else {
            panic!("expected relationship");
        }
    }

    #[test]
    fn undirected_relationship() {
        let p = parse_pattern_from("(a)-[:FRIENDS]-(b)").unwrap();
        if let PatternElement::Relationship(rel) = &p.elements[1] {
            assert_eq!(rel.direction, Direction::Both);
        } else {
            panic!("expected relationship");
        }
    }

    #[test]
    fn variable_on_relationship() {
        let p = parse_pattern_from("(a)-[r:KNOWS]->(b)").unwrap();
        if let PatternElement::Relationship(rel) = &p.elements[1] {
            assert_eq!(rel.variable.as_ref().unwrap().0.as_str(), "r");
        } else {
            panic!("expected relationship");
        }
    }

    #[test]
    fn multi_hop() {
        let p = parse_pattern_from("(a)-[:R1]->(b)-[:R2]->(c)").unwrap();
        assert_eq!(p.elements.len(), 5); // node-rel-node-rel-node
    }

    #[test]
    fn pattern_with_variable_assignment() {
        let p = parse_pattern_from("p = (a)-[:R]->(b)").unwrap();
        assert_eq!(p.variable.as_ref().unwrap().0.as_str(), "p");
    }

    #[test]
    fn anonymous_node() {
        let p = parse_pattern_from("()").unwrap();
        if let PatternElement::Node(node) = &p.elements[0] {
            assert!(node.variable.is_none());
        } else {
            panic!("expected node");
        }
    }

    #[test]
    fn bare_arrows() {
        let cases = [
            ("(a)-->(b)", Direction::Right),
            ("(a)<--(b)", Direction::Left),
            ("(a)--(b)", Direction::Both),
            ("(a)<-->(b)", Direction::Both),
        ];
        for (src, expected) in cases {
            let p = parse_pattern_from(src).unwrap_or_else(|| panic!("failed to parse {src}"));
            let rel = relationship_at(&p, 1);
            assert_eq!(rel.direction, expected, "direction for {src}");
            assert!(rel.variable.is_none(), "variable for {src}");
            assert!(rel.rel_types.is_empty(), "rel types for {src}");
        }
    }

    #[test]
    fn variable_length_ranges() {
        let cases = [
            ("(a)-[*]->(b)", (None, None)),
            ("(a)-[*2]->(b)", (Some(2), Some(2))),
            ("(a)-[*1..3]->(b)", (Some(1), Some(3))),
            ("(a)-[*..3]->(b)", (None, Some(3))),
            ("(a)-[*2..]->(b)", (Some(2), None)),
        ];
        for (src, expected) in cases {
            let p = parse_pattern_from(src).unwrap_or_else(|| panic!("failed to parse {src}"));
            assert_eq!(relationship_at(&p, 1).range, Some(expected), "range for {src}");
        }
    }

    #[test]
    fn no_range_without_star() {
        let p = parse_pattern_from("(a)-[:R]->(b)").unwrap();
        assert!(relationship_at(&p, 1).range.is_none());
    }

    #[test]
    fn multiple_relationship_types() {
        for src in ["(a)-[:R1|R2]->(b)", "(a)-[:R1|:R2]->(b)"] {
            let p = parse_pattern_from(src).unwrap_or_else(|| panic!("failed to parse {src}"));
            let names: Vec<&str> = relationship_at(&p, 1)
                .rel_types
                .iter()
                .map(|t| t.0.as_str())
                .collect();
            assert_eq!(names, ["R1", "R2"], "rel types for {src}");
        }
    }
}