// zen_expression/parser/unary.rs

1use crate::functions::{
2    ClosureFunction, DateMethod, DeprecatedFunction, FunctionKind, InternalFunction, MethodKind,
3};
4use crate::lexer::{Bracket, ComparisonOperator, Identifier, LogicalOperator, Operator, TokenKind};
5use crate::parser::ast::{AstNodeError, Node};
6use crate::parser::constants::{Associativity, BINARY_OPERATORS, UNARY_OPERATORS};
7use crate::parser::parser::{Parser, ParserContext};
8use crate::parser::unary::UnaryNodeBehaviour::CompareWithReference;
9use crate::parser::{NodeMetadata, ParserResult};
10
/// Marker type that selects the unary parsing mode of [`Parser`].
///
/// In this mode, operands written without an explicit left-hand side are compared
/// against the implicit reference `$`.
#[derive(Debug)]
pub struct Unary;
13
14const ROOT_NODE: Node<'static> = Node::Identifier("$");
15
16impl<'arena, 'token_ref> Parser<'arena, 'token_ref, Unary> {
17    pub fn parse(&self) -> ParserResult<'arena> {
18        let root = self.root_expression();
19
20        ParserResult {
21            root,
22            is_complete: self.is_done(),
23            metadata: self.node_metadata.clone().map(|t| t.into_inner()),
24        }
25    }
26
27    fn root_expression(&self) -> &'arena Node<'arena> {
28        let mut left_node = self.expression_pair();
29
30        while !self.is_done() {
31            let Some(current_token) = self.current() else {
32                break;
33            };
34
35            let join_operator = match &current_token.kind {
36                TokenKind::Operator(Operator::Logical(LogicalOperator::And)) => {
37                    Operator::Logical(LogicalOperator::And)
38                }
39                TokenKind::Operator(Operator::Logical(LogicalOperator::Or))
40                | TokenKind::Operator(Operator::Comma) => Operator::Logical(LogicalOperator::Or),
41                _ => {
42                    return self.error(AstNodeError::Custom {
43                        message: self.bump.alloc_str(
44                            format!("Invalid join operator `{}`", current_token.kind).as_str(),
45                        ),
46                        span: current_token.span,
47                    })
48                }
49            };
50
51            self.next();
52            let right_node = self.expression_pair();
53            left_node = self.node(
54                Node::Binary {
55                    left: left_node,
56                    operator: join_operator,
57                    right: right_node,
58                },
59                |h| NodeMetadata {
60                    span: h.span(left_node, right_node).unwrap_or_default(),
61                },
62            );
63        }
64
65        left_node
66    }
67
68    fn expression_pair(&self) -> &'arena Node<'arena> {
69        let mut left_node = &ROOT_NODE;
70        let current_token = self.current();
71
72        if let Some(TokenKind::Operator(Operator::Comparison(_))) = self.current_kind() {
73            // Skips
74        } else {
75            left_node = self.binary_expression(0, ParserContext::Global);
76        }
77
78        match self.current_kind() {
79            Some(TokenKind::Operator(Operator::Comparison(comparison))) => {
80                self.next();
81                let right_node = self.binary_expression(0, ParserContext::Global);
82                left_node = self.node(
83                    Node::Binary {
84                        left: left_node,
85                        operator: Operator::Comparison(*comparison),
86                        right: right_node,
87                    },
88                    |h| NodeMetadata {
89                        span: (
90                            current_token.map(|t| t.span.0).unwrap_or_default(),
91                            h.metadata(right_node).map(|n| n.span.1).unwrap_or_default(),
92                        ),
93                    },
94                );
95            }
96            _ => {
97                let behaviour = UnaryNodeBehaviour::from(left_node);
98                match behaviour {
99                    CompareWithReference(comparator) => {
100                        left_node = self.node(
101                            Node::Binary {
102                                left: &ROOT_NODE,
103                                operator: Operator::Comparison(comparator),
104                                right: left_node,
105                            },
106                            |h| NodeMetadata {
107                                span: (
108                                    current_token.map(|t| t.span.0).unwrap_or_default(),
109                                    h.metadata(left_node).map(|n| n.span.1).unwrap_or_default(),
110                                ),
111                            },
112                        )
113                    }
114                    UnaryNodeBehaviour::AsBoolean => {
115                        left_node = self.node(
116                            Node::FunctionCall {
117                                kind: FunctionKind::Internal(InternalFunction::Bool),
118                                arguments: self.bump.alloc_slice_clone(&[left_node]),
119                            },
120                            |h| NodeMetadata {
121                                span: (
122                                    current_token.map(|t| t.span.0).unwrap_or_default(),
123                                    h.metadata(left_node).map(|n| n.span.1).unwrap_or_default(),
124                                ),
125                            },
126                        )
127                    }
128                }
129            }
130        }
131
132        left_node
133    }
134
135    fn binary_expression(&self, precedence: u8, ctx: ParserContext) -> &'arena Node<'arena> {
136        let mut node_left = self.unary_expression();
137        let Some(mut token) = self.current() else {
138            return node_left;
139        };
140
141        while let TokenKind::Operator(operator) = &token.kind {
142            if self.is_done() {
143                break;
144            }
145
146            if ctx == ParserContext::Global
147                && matches!(
148                    operator,
149                    Operator::Comma
150                        | Operator::Logical(LogicalOperator::And)
151                        | Operator::Logical(LogicalOperator::Or)
152                )
153            {
154                break;
155            }
156
157            let Some(op) = BINARY_OPERATORS.get(operator) else {
158                break;
159            };
160
161            if op.precedence < precedence {
162                break;
163            }
164
165            self.next();
166            let node_right = match op.associativity {
167                Associativity::Left => {
168                    self.binary_expression(op.precedence + 1, ParserContext::Global)
169                }
170                _ => self.binary_expression(op.precedence, ParserContext::Global),
171            };
172
173            node_left = self.node(
174                Node::Binary {
175                    operator: *operator,
176                    left: node_left,
177                    right: node_right,
178                },
179                |h| NodeMetadata {
180                    span: h.span(node_left, node_right).unwrap_or_default(),
181                },
182            );
183
184            let Some(t) = self.current() else {
185                break;
186            };
187            token = t;
188        }
189
190        if precedence == 0 {
191            if let Some(conditional_node) =
192                self.conditional(node_left, |c| self.binary_expression(0, c))
193            {
194                node_left = conditional_node;
195            }
196        }
197
198        node_left
199    }
200
201    fn unary_expression(&self) -> &'arena Node<'arena> {
202        let Some(token) = self.current() else {
203            return self.literal(|c| self.binary_expression(0, c));
204        };
205
206        if self.depth() > 0 && token.kind == TokenKind::Identifier(Identifier::CallbackReference) {
207            self.next();
208
209            let node = self.node(Node::Pointer, |_| NodeMetadata { span: token.span });
210            return self.with_postfix(node, |c| self.binary_expression(0, c));
211        }
212
213        if let TokenKind::Operator(operator) = &token.kind {
214            let Some(unary_operator) = UNARY_OPERATORS.get(operator) else {
215                return self.error(AstNodeError::UnexpectedToken {
216                    expected: self.bump.alloc_str("UnaryOperator"),
217                    received: self.bump.alloc_str(token.kind.to_string().as_str()),
218                    span: token.span,
219                });
220            };
221
222            self.next();
223            let expr = self.binary_expression(unary_operator.precedence, ParserContext::Global);
224            let node = self.node(
225                Node::Unary {
226                    operator: *operator,
227                    node: expr,
228                },
229                |h| NodeMetadata {
230                    span: (
231                        token.span.0,
232                        h.metadata(expr).map(|n| n.span.1).unwrap_or_default(),
233                    ),
234                },
235            );
236
237            return node;
238        }
239
240        if let Some(interval_node) = self.interval(|c| self.binary_expression(0, c)) {
241            return interval_node;
242        }
243
244        if token.kind == TokenKind::Bracket(Bracket::LeftParenthesis) {
245            let p_start = self.current().map(|s| s.span.0);
246
247            self.next();
248            let binary_node = self.binary_expression(0, ParserContext::Global);
249            if let Some(error_node) = self.expect(TokenKind::Bracket(Bracket::RightParenthesis)) {
250                return error_node;
251            };
252
253            let expr = self.node(Node::Parenthesized(binary_node), |_| NodeMetadata {
254                span: (p_start.unwrap_or_default(), self.prev_token_end()),
255            });
256
257            return self.with_postfix(expr, |c| self.binary_expression(0, c));
258        }
259
260        self.literal(|c| self.binary_expression(0, c))
261    }
262}
263
264/// Dictates the behaviour of nodes in unary mode.
265/// If `CompareWithReference` is set, node will attempt to make the comparison with the reference,
266/// essentially making it (in case of Equal operator) `$ == nodeValue`, or (in case of In operator)
267/// `$ in nodeValue`.
268///
269/// Using `AsBoolean` will cast the nodeValue to boolean and skip comparison with reference ($).
270/// You may still use references in such case directly, e.g. `contains($, 'hello')`.
271///
272/// Rationale behind this is to avoid scenarios where e.g. $ = false and expression is
273/// `contains($, 'needle')`. If we didn't ignore the reference, unary expression will be
274/// reduced to `$ == contains($, 'needle')` which will be truthy when $ does not
275/// contain needle.
276#[derive(Debug, PartialEq)]
277enum UnaryNodeBehaviour {
278    CompareWithReference(ComparisonOperator),
279    AsBoolean,
280}
281
282impl From<&Node<'_>> for UnaryNodeBehaviour {
283    fn from(value: &Node) -> Self {
284        use ComparisonOperator::*;
285        use UnaryNodeBehaviour::*;
286
287        match value {
288            Node::Null => CompareWithReference(Equal),
289            Node::Root => CompareWithReference(Equal),
290            Node::Bool(_) => CompareWithReference(Equal),
291            Node::Number(_) => CompareWithReference(Equal),
292            Node::String(_) => CompareWithReference(Equal),
293            Node::TemplateString(_) => CompareWithReference(Equal),
294            Node::Object(_) => CompareWithReference(Equal),
295            Node::Pointer => AsBoolean,
296            Node::Array(_) => CompareWithReference(In),
297            Node::Identifier(_) => CompareWithReference(Equal),
298            Node::Closure(_) => AsBoolean,
299            Node::Member { .. } => CompareWithReference(Equal),
300            Node::Slice { .. } => CompareWithReference(In),
301            Node::Interval { .. } => CompareWithReference(In),
302            Node::Conditional {
303                on_true, on_false, ..
304            } => {
305                let a = UnaryNodeBehaviour::from(*on_true);
306                let b = UnaryNodeBehaviour::from(*on_false);
307
308                if a == b {
309                    a
310                } else {
311                    CompareWithReference(Equal)
312                }
313            }
314            Node::Unary { node, .. } => UnaryNodeBehaviour::from(*node),
315            Node::Parenthesized(n) => UnaryNodeBehaviour::from(*n),
316            Node::Binary {
317                left,
318                operator,
319                right,
320            } => match operator {
321                Operator::Arithmetic(_) => {
322                    let a = UnaryNodeBehaviour::from(*left);
323                    let b = UnaryNodeBehaviour::from(*right);
324
325                    if a == b {
326                        a
327                    } else {
328                        CompareWithReference(Equal)
329                    }
330                }
331                Operator::Logical(_) => AsBoolean,
332                Operator::Comparison(_) => AsBoolean,
333                Operator::Range => CompareWithReference(In),
334                Operator::Slice => CompareWithReference(In),
335                Operator::Comma => AsBoolean,
336                Operator::Dot => AsBoolean,
337                Operator::QuestionMark => AsBoolean,
338            },
339            Node::FunctionCall { kind, .. } => match kind {
340                FunctionKind::Internal(i) => match i {
341                    InternalFunction::Len => CompareWithReference(Equal),
342                    InternalFunction::Upper => CompareWithReference(Equal),
343                    InternalFunction::Lower => CompareWithReference(Equal),
344                    InternalFunction::Trim => CompareWithReference(Equal),
345                    InternalFunction::Abs => CompareWithReference(Equal),
346                    InternalFunction::Sum => CompareWithReference(Equal),
347                    InternalFunction::Avg => CompareWithReference(Equal),
348                    InternalFunction::Min => CompareWithReference(Equal),
349                    InternalFunction::Max => CompareWithReference(Equal),
350                    InternalFunction::Rand => CompareWithReference(Equal),
351                    InternalFunction::Median => CompareWithReference(Equal),
352                    InternalFunction::Mode => CompareWithReference(Equal),
353                    InternalFunction::Floor => CompareWithReference(Equal),
354                    InternalFunction::Ceil => CompareWithReference(Equal),
355                    InternalFunction::Round => CompareWithReference(Equal),
356                    InternalFunction::Trunc => CompareWithReference(Equal),
357                    InternalFunction::String => CompareWithReference(Equal),
358                    InternalFunction::Number => CompareWithReference(Equal),
359                    InternalFunction::Bool => CompareWithReference(Equal),
360                    InternalFunction::Flatten => CompareWithReference(In),
361                    InternalFunction::Extract => CompareWithReference(In),
362                    InternalFunction::Contains => AsBoolean,
363                    InternalFunction::StartsWith => AsBoolean,
364                    InternalFunction::EndsWith => AsBoolean,
365                    InternalFunction::Matches => AsBoolean,
366                    InternalFunction::FuzzyMatch => CompareWithReference(Equal),
367                    InternalFunction::Split => CompareWithReference(In),
368                    InternalFunction::IsNumeric => AsBoolean,
369                    InternalFunction::Keys => CompareWithReference(In),
370                    InternalFunction::Values => CompareWithReference(In),
371                    InternalFunction::Type => CompareWithReference(Equal),
372                    InternalFunction::Date => CompareWithReference(Equal),
373                },
374                FunctionKind::Deprecated(d) => match d {
375                    DeprecatedFunction::Date => CompareWithReference(Equal),
376                    DeprecatedFunction::Time => CompareWithReference(Equal),
377                    DeprecatedFunction::Duration => CompareWithReference(Equal),
378                    DeprecatedFunction::Year => CompareWithReference(Equal),
379                    DeprecatedFunction::DayOfWeek => CompareWithReference(Equal),
380                    DeprecatedFunction::DayOfMonth => CompareWithReference(Equal),
381                    DeprecatedFunction::DayOfYear => CompareWithReference(Equal),
382                    DeprecatedFunction::WeekOfYear => CompareWithReference(Equal),
383                    DeprecatedFunction::MonthOfYear => CompareWithReference(Equal),
384                    DeprecatedFunction::MonthString => CompareWithReference(Equal),
385                    DeprecatedFunction::DateString => CompareWithReference(Equal),
386                    DeprecatedFunction::WeekdayString => CompareWithReference(Equal),
387                    DeprecatedFunction::StartOf => CompareWithReference(Equal),
388                    DeprecatedFunction::EndOf => CompareWithReference(Equal),
389                },
390                FunctionKind::Closure(c) => match c {
391                    ClosureFunction::All => AsBoolean,
392                    ClosureFunction::Some => AsBoolean,
393                    ClosureFunction::None => AsBoolean,
394                    ClosureFunction::One => AsBoolean,
395                    ClosureFunction::Filter => CompareWithReference(In),
396                    ClosureFunction::Map => CompareWithReference(In),
397                    ClosureFunction::FlatMap => CompareWithReference(In),
398                    ClosureFunction::Count => CompareWithReference(Equal),
399                },
400            },
401            Node::MethodCall { kind, .. } => match kind {
402                MethodKind::DateMethod(dm) => match dm {
403                    DateMethod::Add => CompareWithReference(Equal),
404                    DateMethod::Sub => CompareWithReference(Equal),
405                    DateMethod::Format => CompareWithReference(Equal),
406                    DateMethod::Month => CompareWithReference(Equal),
407                    DateMethod::Year => CompareWithReference(Equal),
408                    DateMethod::Set => CompareWithReference(Equal),
409                    DateMethod::StartOf => CompareWithReference(Equal),
410                    DateMethod::EndOf => CompareWithReference(Equal),
411                    DateMethod::Diff => CompareWithReference(Equal),
412                    DateMethod::Tz => CompareWithReference(Equal),
413                    DateMethod::Second => CompareWithReference(Equal),
414                    DateMethod::Minute => CompareWithReference(Equal),
415                    DateMethod::Hour => CompareWithReference(Equal),
416                    DateMethod::Day => CompareWithReference(Equal),
417                    DateMethod::DayOfYear => CompareWithReference(Equal),
418                    DateMethod::Week => CompareWithReference(Equal),
419                    DateMethod::Weekday => CompareWithReference(Equal),
420                    DateMethod::Quarter => CompareWithReference(Equal),
421                    DateMethod::Timestamp => CompareWithReference(Equal),
422                    DateMethod::OffsetName => CompareWithReference(Equal),
423                    DateMethod::IsSame => AsBoolean,
424                    DateMethod::IsBefore => AsBoolean,
425                    DateMethod::IsAfter => AsBoolean,
426                    DateMethod::IsSameOrBefore => AsBoolean,
427                    DateMethod::IsSameOrAfter => AsBoolean,
428                    DateMethod::IsValid => AsBoolean,
429                    DateMethod::IsYesterday => AsBoolean,
430                    DateMethod::IsToday => AsBoolean,
431                    DateMethod::IsTomorrow => AsBoolean,
432                    DateMethod::IsLeapYear => AsBoolean,
433                },
434            },
435            Node::Error { .. } => AsBoolean,
436        }
437    }
438}