drasi_query_cypher/
lib.rs

1#![allow(clippy::redundant_closure_call)]
2// Copyright 2024 The Drasi Authors.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use drasi_query_ast::{
17    api::{QueryParseError, QueryParser},
18    ast::{self, Expression, ParentExpression, ProjectionClause},
19};
20use peg::{error::ParseError, str::LineCol};
21use std::{collections::HashSet, sync::Arc};
22
23#[cfg(test)]
24mod tests;
25
26peg::parser! {
27    grammar cypher() for str {
28        use drasi_query_ast::ast::*;
29
30        rule kw_match()     = ("MATCH" / "match")
31        rule kw_create()    = ("CREATE" / "create")
32        rule kw_set()       = ("SET" / "set")
33        rule kw_delete()    = ("DELETE" / "delete")
34        rule kw_where()     = ("WHERE" / "where")
35        rule kw_return()    = ("RETURN" / "return")
36        rule kw_true()      = ("TRUE" / "true")
37        rule kw_false()     = ("FALSE" / "false")
38        rule kw_null()      = ("NULL" / "null")
39        rule kw_and()       = ("AND" / "and")
40        rule kw_or()        = ("OR" / "or")
41        rule kw_not()       = ("NOT" / "not")
42        rule kw_is()        = ("IS" / "is")
43        rule kw_id()        = ("ID" / "id")
44        rule kw_label()     = ("LABEL" / "label")
45        rule kw_as()        = ("AS" / "as")
46        rule kw_case()      = ("CASE" / "case")
47        rule kw_when()      = ("WHEN" / "when")
48        rule kw_then()      = ("THEN" / "then")
49        rule kw_else()      = ("ELSE" / "else")
50        rule kw_end()       = ("END" / "end")
51        rule kw_with()      = ("WITH" / "with")
52        rule kw_in()        = ("IN" / "in")
53        rule kw_exists()    = ("EXISTS" / "exists")
54
55        rule _()
56            = quiet!{[' ']}
57
58        rule __()
59            = quiet!{[' ' | '\n' | '\t']}
60            / comment()
61
62        rule comment()
63            = quiet!{ "//" (!"\n" [_])* ("\n" / ![_])}
64
65        rule alpha()
66            = ['a'..='z' | 'A'..='Z']
67
68        rule num()
69            = quiet! {
70                ['0'..='9']
71            }
72            / expected!("a number")
73
74        rule alpha_num()
75            = ['a'..='z' | 'A'..='Z' | '0'..='9' | '_']
76
77
78        // e.g. '42', '-1'
79        rule integer() -> i64
80            = integer:$("-"?num()+) {? integer.parse().or(Err("invalid integer")) }
81
82        // e.g. '-0.53', '34346.245', '236.0'
83        rule real() -> f64
84            = real:$("-"? num()+ "." num()+) {? real.parse().or(Err("invalid real"))}
85
86        // e.g. 'TRUE', 'FALSE'
87        rule boolean() -> bool
88            = kw_true() { true } / kw_false() { false }
89
90        // e.g. 'hello world'
91        rule text() -> Arc<str>
92            = quiet! {
93                "'" text:$((date_for_date_time() "T" time_format() timezone())) "'"{ Arc::from(text) }
94            }
95            / quiet! {
96                "'" text:$((date_for_date_time() "T" time_format())) "'"{ Arc::from(text) }
97            }
98            / quiet! {
99                "'" text:$(date_format()) "'" { Arc::from(text) }
100            }
101            / quiet! {
102                "'" text:$(time_format() timezone()) "'" { Arc::from(text) }
103            }
104            / quiet! {
105                "'" text:$(time_format()) "'" { Arc::from(text) }
106            }
107            /quiet!{
108                "'" text:$(duration()) "'" { Arc::from(text) }
109            }
110            /quiet! {
111                "'" text:$([^ '\'' | '\n' | '\r']*) "'" { Arc::from(text) }
112            }
113            / expected!("a quoted string")
114
115        // e.g. 'TRUE', '42', 'hello world'
116        rule literal() -> Literal
117            = r:real() { Literal::Real(r) }
118            / i:integer() { Literal::Integer(i) }
119            / b:boolean() { Literal::Boolean(b) }
120            / t:text() { Literal::Text(t) }
121            / kw_null() { Literal::Null }
122
123        rule year() -> Arc<str>
124            = year:$(['0'..='9']*<4>) { Arc::from(year) }
125
126        rule month() -> Arc<str>
127            = month:$(['0'..='1']['0'..='9']) { Arc::from(month) }
128
129        rule day() -> Arc<str>
130            = day:$(['0'..='9']*<2>) { Arc::from(day) }
131
132        rule week() -> Arc<str>
133            = week:$("W" ['0'..='9']*<0,2>) { Arc::from(week) }
134
135        rule quarter() -> Arc<str>
136            = quarter:$("Q" ['1'..='4']) { Arc::from(quarter)}
137
138        rule date_format() -> Arc<str>
139            = date_format:$(year() "-"? month() "-"? day()) { Arc::from(date_format) }
140            / date_format:$(year() "-"? week() "-"? ['0'..='7']) { Arc::from(date_format) }
141            / date_format:$(year() "-"? week()) { Arc::from(date_format) }
142            / date_format:$(year() "-"? quarter() "-"? day()) { Arc::from(date_format) }
143            / date_format:$(year() "-"? month()) { Arc::from(date_format) }
144            / date_format:$(year() "-"? ['0'..='9']*<3>) { Arc::from(date_format) }
145            / date_format:$(year()) { Arc::from(date_format) }
146
147
148        rule hour() -> Arc<str>
149            = hour:$(['0'..='2']['0'..='9']) { Arc::from(hour) }
150
151        rule minute() -> Arc<str>
152            = minute:$(['0'..='5']['0'..='9']) { Arc::from(minute) }
153
154        rule second() -> Arc<str>
155            = second:$(['0'..='5']['0'..='9']) { Arc::from(second) }
156
157        rule time_fraction() -> Arc<str>
158            = time_fraction:$("." ['0'..='9']*<0,9>) { Arc::from(time_fraction) }
159
160        rule time_format() ->  Arc<str>
161            = time_format:$(hour() ":"? minute() ":"? second()? time_fraction()?) { Arc::from(time_format) }
162            / time_format:$(hour() ":"? minute() ":"? second()) { Arc::from(time_format) }
163            / time_format:$(hour() ":"? minute()) { Arc::from(time_format) }
164            / time_format:$(hour()) { Arc::from(time_format) }
165
166        rule date_for_date_time() -> Arc<str>
167            = date_for_date_time:$(year() "-"? month() "-"? day()) { Arc::from(date_for_date_time) }
168            / date_for_date_time:$(year() "-"? week() "-"? ['0'..='7']) { Arc::from(date_for_date_time) }
169            / date_for_date_time:$(year() "-"? quarter() "-"? day() ) { Arc::from(date_for_date_time) }
170            / date_for_date_time:$(year() "-"? ['0'..='9']*<3>) { Arc::from(date_for_date_time) }
171
172
173
174        rule timezone() -> Arc<str>
175            = timezone:$("Z") { Arc::from(timezone) }
176            / timezone:$("[" ['a'..='z' | 'A'..='Z' | '_']+ "/" ['a'..='z' | 'A'..='Z' | '_']+  "]") { Arc::from(timezone) }
177            / timezone:$("+" hour() ":"? minute()?) { Arc::from(timezone) }
178            / timezone:$("-" hour() ":"? minute()?) { Arc::from(timezone) }
179            / timezone:$("+" hour() "[" ['a'..='z' | 'A'..='Z' | '_']+ "/" ['a'..='z' | 'A'..='Z' | '_']+ "]" ) { Arc::from(timezone) }
180            / timezone:$("-" hour() "[" ['a'..='z' | 'A'..='Z' | '_']+ "/" ['a'..='z' | 'A'..='Z' | '_']+  "]" ) { Arc::from(timezone) }
181            //specific timezone (IANA timezone database)
182
183
184        rule space() = [' ']*
185
186        rule duration() -> Arc<str>
187            = duration:$("P" date_for_date_time()? "T" time_format()?) "\"" ")" { Arc::from(duration) } //"P2012-02-02T14:37:21.545"
188            / duration:$("P" (['0'..='9']*<0,19> "Y")? ( ['0'..='9']*<0,19> "M")? ( ['0'..='9']*<0,19> "W")? ( ['0'..='9']*<0,19>  time_fraction()?"D")? ("T" ( ['0'..='9']*<0,19> "H")? ( ['0'..='9']*<0,19> "M")? ( ['0'..='9']*<0,19> time_fraction()?  "S")?)?) { Arc::from(duration) }
189            / duration:$("P" (['0'..='9']*<0,19> "Y")? ( ['0'..='9']*<0,19> "M")? ( ['0'..='9']*<0,19> "W")? ( ['0'..='9']*<0,19>  time_fraction()?"D")? ){ Arc::from(duration) }
190            / duration:$("P" ("T" ( ['0'..='9']*<0,19> "H")? ( ['0'..='9']*<0,19>"M")? ( ['0'..='9']*<0,19>  time_fraction()? "S")?)?){ Arc::from(duration) }
191
192
193
194        rule projection_expression() -> Expression
195            = z:expression() _* kw_as() _* a:ident() { UnaryExpression::alias(z, a) }
196            / expression()
197
198        rule when_expression() -> (Expression, Expression)
199            = kw_when() __+ when:expression() __+ kw_then() __+ then:expression() __+ { (when, then) }
200
201        rule else_expression() -> Expression
202            = kw_else() __+ else_:expression() __+ { else_ }
203
204            #[cache_left_rec]
205        pub rule expression() -> Expression
206            = precedence!{
207                a:(@) __* kw_and() __* b:@ { BinaryExpression::and(a, b) }
208                a:(@) __* kw_or() __* b:@ { BinaryExpression::or(a, b) }
209                --
210                kw_not() __* c:(@) { UnaryExpression::not(c) }
211                --
212                it:iterator() { it }
213                --
214                a:(@) __* kw_in() __* b:@ { BinaryExpression::in_(a, b) }
215                a:(@) __* "="  __* b:@ { BinaryExpression::eq(a, b) }
216                a:(@) __* ("<>" / "!=") __* b:@ { BinaryExpression::ne(a, b) }
217                a:(@) __* "<"  __* b:@ { BinaryExpression::lt(a, b) }
218                a:(@) __* "<=" __* b:@ { BinaryExpression::le(a, b) }
219                a:(@) __* ">"  __* b:@ { BinaryExpression::gt(a, b) }
220                a:(@) __* ">=" __* b:@ { BinaryExpression::ge(a, b) }
221                --
222                a:(@) __* "+" __* b:@ { BinaryExpression::add(a, b) }
223                a:(@) __* "-" __* b:@ { BinaryExpression::subtract(a, b) }
224                --
225                a:(@) __* "*" __* b:@ { BinaryExpression::multiply(a, b) }
226                a:(@) __* "/" __* b:@ { BinaryExpression::divide(a, b) }
227                --
228                a:(@) __* "%" __* b:@ { BinaryExpression::modulo(a, b) }
229                a:(@) __* "^" __* b:@ { BinaryExpression::exponent(a, b) }
230                --
231                list:expression() "[" index:expression() "]" { BinaryExpression::index(list, index)}
232                e:(@) __+ kw_is() _+ kw_null() { UnaryExpression::is_null(e) }
233                e:(@) __+ kw_is() _+ kw_not() _+ kw_null() { UnaryExpression::is_not_null(e) }
234                kw_case() __* mtch:expression()? __* when:when_expression()+ __* else_:else_expression()? __* kw_end() { CaseExpression::case(mtch, when, else_) }
235                kw_case() __* when:when_expression()+ __* else_:else_expression()? __* kw_end() { CaseExpression::case(None, when, else_) }
236                pos: position!() func:function_name() _* "(" __* params:(expression() ** (__* "," __*))? __* ")" "." key:ident() {
237                    let params = params.unwrap_or_else(Vec::new);
238                    UnaryExpression::expression_property(FunctionExpression::function(func, params, pos ), key)
239                }
240                pos: position!() func:function_name() _* "(" __* params:(expression() ** (__* "," __*))? __* ")" {
241                    let params = params.unwrap_or_else(Vec::new);
242                    FunctionExpression::function(func, params, pos )
243                }
244                p:property() { UnaryExpression::property(p.0, p.1) }
245                "$" name:ident() { UnaryExpression::parameter(name) }
246                start:expression()? ".." end:expression()? { UnaryExpression::list_range(start, end) }
247                l:literal() { UnaryExpression::literal(l) }
248                i:ident() { UnaryExpression::ident(&i) } // UnaryExpression::ident(i)
249
250                --
251                "(" __* c:expression() __* ")" { c }
252                c: component() { ObjectExpression::object_from_vec(c)  } //ObjectExpression
253                "[" __* c:expression() ** (__* "," __*) __* "]" { ListExpression::list(c) }
254            }
255
256            #[cache_left_rec]
257        rule iterator() -> Expression
258            = "[" __* item:ident() __* kw_in() __* list:expression() __* kw_where() __* filter:expression() __* "|" __* map:expression()__* "]"
259                { IteratorExpression::map_with_filter(item, list, map, filter) }
260
261            / "[" __* item:ident() __* kw_in() __* list:expression() __*  "|" __* map:expression()__* "]"
262                { IteratorExpression::map(item, list, map) }
263
264            / "[" __* item:ident() __* kw_in() __* list:expression() __* kw_where() __* filter:expression()__* "]"
265                { IteratorExpression::iterator_with_filter(item, list, filter) }
266
267            / "[" __* item:ident() __* kw_in() __* list:expression()__* "]"
268                { IteratorExpression::iterator(item, list) }
269
270            / item:ident() __* kw_in() __* list:expression() __* kw_where() __* filter:expression() __* "|" __* map:expression()
271                { IteratorExpression::map_with_filter(item, list, map, filter) }
272
273            / item:ident() __* kw_in() __* list:expression() __*  "|" __* map:expression()
274                { IteratorExpression::map(item, list, map) }
275
276            / item:ident() __* kw_in() __* list:expression() __* kw_where() __* filter:expression()
277                { IteratorExpression::iterator_with_filter(item, list, filter) }
278
279            / item:ident() __* kw_in() __* list:expression()
280                { IteratorExpression::iterator(item, list) }
281
282
283        // e.g. 'hello_world', 'Rust', 'HAS_PROPERTY'
284        rule ident() -> Arc<str>
285            = quiet!{ident:$(alpha()alpha_num()*) { Arc::from(ident) }}
286            / expected!("an identifier")
287
288        // e.g. 'sign', 'duration_between'
289        rule function_name()  -> Arc<str>
290            = quiet!{func:$(alpha()alpha_num()* ("." alpha_num()+)?) { Arc::from(func) }}
291            / expected!("function name")
292
293        rule component() -> Vec<(Arc<str>, Expression)>
294            = "{" __* entries:( (k:ident() __* ":" __* v:expression() { (k, v) }) ++ (__* "," __*) ) __* "}" { entries }
295
296        // e.g. 'a', 'a : PERSON', ': KNOWS'
297        rule annotation() -> Annotation
298            = name:ident()? { Annotation { name } }
299
300
301        // e.g. '{answer: 42, book: 'Hitchhikers Guide'}'
302        rule property_map() -> Vec<(Arc<str>, Expression)>
303            = "{" __* entries:( (k:ident() __* ":" __* v:expression() { (k, v) }) ++ (__* "," __*) ) __* "}" { entries }
304
305        rule property_map_predicate() -> Vec<Expression>
306            = "{" __* entries:( (k:ident() __* ":" __* v:expression() { BinaryExpression::eq(UnaryExpression::property("".into(), k), v) }) ++ (__* "," __*) ) __* "}" { entries }
307
308        rule element_match() -> (Annotation, Vec<Arc<str>>, Vec<Expression>)
309            = a:annotation() labels:(":" label:ident() ** "|" { label })? _* p:(pm:property_map_predicate() { pm } / ( w:(where_clause() ** (__+) )? { w.unwrap_or_else(Vec::new) } ) ) {
310                (a, labels.unwrap_or_else(Vec::new), p)
311            }
312
313        // e.g. '()', '( a:PERSON )', '(b)', '(a : OTHER_THING)'
314        rule node() -> NodeMatch
315            = "(" _* element:element_match() _* ")" {
316                NodeMatch::new(element.0, element.1, element.2)
317              }
318            / expected!("node match pattern, e.g. '()', '( a:PERSON )', '(b)', '(a : OTHER_THING)'")
319
320        // e.g. '-', '<-', '-[ name:KIND ]-', '<-[name]-'
321        rule relation() -> RelationMatch
322            =  "-[" _* element:element_match() _* vl:variable_length()? _* "]->" {
323                RelationMatch::right(element.0, element.1, element.2, vl)
324            }
325            /  "-[" _* element:element_match() _* vl:variable_length()? _* "]-"  {
326                RelationMatch::either(element.0, element.1, element.2, vl)
327            }
328            / "<-[" _* element:element_match() _* vl:variable_length()? _* "]-"  {
329                RelationMatch::left(element.0, element.1, element.2, vl)
330            }
331            / "<-" { RelationMatch::left(Annotation::empty(), Vec::new(), Vec::new(), None) }
332            / "->" { RelationMatch::right(Annotation::empty(), Vec::new(), Vec::new(), None) }
333            / "-" { RelationMatch::either(Annotation::empty(), Vec::new(), Vec::new(), None) }
334            / expected!("relation match pattern, e.g. '-', '<-', '-[ name:KIND ]-', '<-[name]-'")
335
336
337        rule variable_length() -> VariableLengthMatch
338            = quiet!{
339                "*" min_hops:integer()? max_hops:(".." r:integer() {r})? { VariableLengthMatch{ min_hops, max_hops } }
340            }
341            / expected!("variable length match pattern, e.g. '*', '*2', '*..3'")
342
343        rule property() -> (Arc<str>, Arc<str>)
344            = name:ident() "." key:ident() { (name, key) }
345            / "$" name:ident() "." key:ident() { (name, key) }
346
347        // e.g. 'MATCH (a)', 'MATCH (a) -> (b) <- (c)', ...
348        rule match_clause() -> Vec<MatchClause>
349            = kw_match() __+ items:( (start:node()
350                    path:( (__* e:relation() __* n:node() { (e, n) }) ** "" ) {
351                    MatchClause { start, path }
352                }) ++ (__* "," __*) ) { items }
353
354        // e.g. 'WHERE a.name <> b.name', 'WHERE a.age > b.age AND a.age <= 42'
355        rule where_clause() -> Expression
356            = kw_where() __+ c:expression() { c }
357
358        // e.g. 'SET a.name = 'Peter Parker''
359        rule set_clause() -> SetClause
360            = kw_set() __+ p:property() _* "=" _* e:expression() {
361                SetClause { name: p.0, key: p.1, value: e }
362            }
363
364        // e.g. 'DELETE a'
365        rule delete_clause() -> Arc<str>
366            = kw_delete() __+ name:ident() { name }
367
368        // e.g. 'RETURN a, b'
369        rule return_clause() -> Vec<Expression>
370            = kw_return() __+ items:( projection_expression() ++ (__* "," __*) ) { items }
371
372        rule with_clause() -> Vec<Expression>
373            = kw_with() __+ items:( projection_expression() ++ (__* "," __*) ) { items }
374
375        rule with_or_return() -> Vec<Expression>
376            = w:with_clause() { w }
377            / r:return_clause() { r }
378
379        rule part(config: &dyn CypherConfiguration) -> QueryPart
380            = match_clauses:( __* m:(match_clause() ** (__+) )? { m.unwrap_or_else(Vec::new).into_iter().flatten().collect() } )
381                where_clauses:( __* w:(where_clause() ** (__+) )? { w.unwrap_or_else(Vec::new) } )
382                //create_clauses:( __* c:(create_clause() ** (__+) )? { c.unwrap_or_else(Vec::new) } )
383                set_clauses:( __* s:(set_clause() ** (__+) )? { s.unwrap_or_else(Vec::new) } )
384                delete_clauses:( __* d:(delete_clause() ** (__+) )? { d.unwrap_or_else(Vec::new) } )
385                return_clause:( with_or_return() )
386                {
387                    QueryPart {
388                        match_clauses,
389                        where_clauses,
390                        return_clause: return_clause.into_projection_clause(config),
391                    }
392                }
393
394        pub rule query(config: &dyn CypherConfiguration) -> Query
395            = __*
396              parts:(w:( part(config)+ ) { w } )
397              __* {
398                Query {
399                    parts,
400                }
401            }
402    }
403}
404
405pub fn parse(
406    input: &str,
407    config: &dyn CypherConfiguration,
408) -> Result<ast::Query, ParseError<LineCol>> {
409    cypher::query(input, config)
410}
411
412pub fn parse_expression(input: &str) -> Result<ast::Expression, ParseError<LineCol>> {
413    cypher::expression(input)
414}
415
/// Dialect settings consulted while parsing Cypher queries.
///
/// `Send + Sync` is a supertrait bound, so an `Arc<dyn CypherConfiguration>`
/// can be moved and shared across threads.
pub trait CypherConfiguration: Send + Sync {
    /// Returns the names of functions that count as aggregations.
    ///
    /// A projection item whose expression tree calls one of these names is
    /// treated as an aggregate. Matching is an exact, case-sensitive set lookup.
    /// The set is requested once per projection item, so implementations
    /// should be cheap.
    fn get_aggregating_function_names(&self) -> HashSet<String>;
}
419
420pub trait IntoProjectionClause {
421    fn into_projection_clause(self, config: &dyn CypherConfiguration) -> ProjectionClause;
422}
423
424impl IntoProjectionClause for Vec<Expression> {
425    fn into_projection_clause(self, config: &dyn CypherConfiguration) -> ProjectionClause {
426        let mut keys = Vec::new();
427        let mut aggs = Vec::new();
428
429        for expr in self {
430            if contains_aggregating_function(&expr, config) {
431                aggs.push(expr);
432            } else {
433                keys.push(expr);
434            }
435        }
436
437        if aggs.is_empty() {
438            ProjectionClause::Item(keys)
439        } else {
440            ProjectionClause::GroupBy {
441                grouping: keys,
442                aggregates: aggs,
443            }
444        }
445    }
446}
447
448pub fn contains_aggregating_function(
449    expression: &Expression,
450    config: &dyn CypherConfiguration,
451) -> bool {
452    let stack = &mut vec![expression];
453    let aggr_funcs = config.get_aggregating_function_names();
454
455    while let Some(expr) = stack.pop() {
456        if let Expression::FunctionExpression(ref function) = expr {
457            if aggr_funcs.contains(&function.name.to_string()) {
458                return true;
459            }
460        }
461
462        for c in expr.get_children() {
463            stack.push(c);
464        }
465    }
466
467    false
468}
469
470pub struct CypherParser {
471    config: Arc<dyn CypherConfiguration>,
472}
473
474impl CypherParser {
475    pub fn new(config: Arc<dyn CypherConfiguration>) -> Self {
476        CypherParser { config }
477    }
478}
479
480impl QueryParser for CypherParser {
481    fn parse(&self, input: &str) -> Result<ast::Query, QueryParseError> {
482        match parse(input, &*self.config) {
483            Ok(query) => Ok(query),
484            Err(e) => Err(QueryParseError::ParserError(Box::new(e))),
485        }
486    }
487}