mf_expression/parser/
unary.rs

1use crate::functions::{
2    ClosureFunction, DateMethod, DeprecatedFunction, FunctionKind,
3    InternalFunction, MethodKind,
4};
5use crate::lexer::{
6    Bracket, ComparisonOperator, Identifier, LogicalOperator, Operator,
7    TokenKind,
8};
9use crate::parser::ast::{AstNodeError, Node};
10use crate::parser::constants::{Associativity, BINARY_OPERATORS, UNARY_OPERATORS};
11use crate::parser::parser::{Parser, ParserContext};
12use crate::parser::unary::UnaryNodeBehaviour::CompareWithReference;
13use crate::parser::{NodeMetadata, ParserResult};
14
/// Marker type used as the mode parameter of [`Parser`] to select the
/// unary-mode parsing methods implemented in this file.
#[derive(Debug)]
pub struct Unary;
17
/// Identifier node for `$`, the reference value. It is used as the implicit
/// left-hand operand when a unary expression is compared against the reference.
const ROOT_NODE: Node<'static> = Node::Identifier("$");
19
20impl<'arena, 'token_ref> Parser<'arena, 'token_ref, Unary> {
21    pub fn parse(&self) -> ParserResult<'arena> {
22        let root = self.root_expression();
23
24        ParserResult {
25            root,
26            is_complete: self.is_done(),
27            metadata: self.node_metadata.clone().map(|t| t.into_inner()),
28        }
29    }
30
31    fn root_expression(&self) -> &'arena Node<'arena> {
32        let mut left_node = self.expression_pair();
33
34        while !self.is_done() {
35            let Some(current_token) = self.current() else {
36                break;
37            };
38
39            let join_operator = match &current_token.kind {
40                TokenKind::Operator(Operator::Logical(
41                    LogicalOperator::And,
42                )) => Operator::Logical(LogicalOperator::And),
43                TokenKind::Operator(Operator::Logical(LogicalOperator::Or))
44                | TokenKind::Operator(Operator::Comma) => {
45                    Operator::Logical(LogicalOperator::Or)
46                },
47                _ => {
48                    return self.error(AstNodeError::Custom {
49                        message: self.bump.alloc_str(
50                            format!(
51                                "Invalid join operator `{}`",
52                                current_token.kind
53                            )
54                            .as_str(),
55                        ),
56                        span: current_token.span,
57                    });
58                },
59            };
60
61            self.next();
62            let right_node = self.expression_pair();
63            left_node = self.node(
64                Node::Binary {
65                    left: left_node,
66                    operator: join_operator,
67                    right: right_node,
68                },
69                |h| NodeMetadata {
70                    span: h.span(left_node, right_node).unwrap_or_default(),
71                },
72            );
73        }
74
75        left_node
76    }
77
78    fn expression_pair(&self) -> &'arena Node<'arena> {
79        let mut left_node = &ROOT_NODE;
80        let current_token = self.current();
81
82        if let Some(TokenKind::Operator(Operator::Comparison(_))) =
83            self.current_kind()
84        {
85            // Skips
86        } else {
87            left_node = self.binary_expression(0, ParserContext::Global);
88        }
89
90        match self.current_kind() {
91            Some(TokenKind::Operator(Operator::Comparison(comparison))) => {
92                self.next();
93                let right_node =
94                    self.binary_expression(0, ParserContext::Global);
95                left_node = self.node(
96                    Node::Binary {
97                        left: left_node,
98                        operator: Operator::Comparison(*comparison),
99                        right: right_node,
100                    },
101                    |h| NodeMetadata {
102                        span: (
103                            current_token.map(|t| t.span.0).unwrap_or_default(),
104                            h.metadata(right_node)
105                                .map(|n| n.span.1)
106                                .unwrap_or_default(),
107                        ),
108                    },
109                );
110            },
111            _ => {
112                let behaviour = UnaryNodeBehaviour::from(left_node);
113                match behaviour {
114                    CompareWithReference(comparator) => {
115                        left_node = self.node(
116                            Node::Binary {
117                                left: &ROOT_NODE,
118                                operator: Operator::Comparison(comparator),
119                                right: left_node,
120                            },
121                            |h| NodeMetadata {
122                                span: (
123                                    current_token
124                                        .map(|t| t.span.0)
125                                        .unwrap_or_default(),
126                                    h.metadata(left_node)
127                                        .map(|n| n.span.1)
128                                        .unwrap_or_default(),
129                                ),
130                            },
131                        )
132                    },
133                    UnaryNodeBehaviour::AsBoolean => {
134                        left_node = self.node(
135                            Node::FunctionCall {
136                                kind: FunctionKind::Internal(
137                                    InternalFunction::Bool,
138                                ),
139                                arguments: self
140                                    .bump
141                                    .alloc_slice_clone(&[left_node]),
142                            },
143                            |h| NodeMetadata {
144                                span: (
145                                    current_token
146                                        .map(|t| t.span.0)
147                                        .unwrap_or_default(),
148                                    h.metadata(left_node)
149                                        .map(|n| n.span.1)
150                                        .unwrap_or_default(),
151                                ),
152                            },
153                        )
154                    },
155                }
156            },
157        }
158
159        left_node
160    }
161
162    #[cfg_attr(feature = "stack-protection", recursive::recursive)]
163    fn binary_expression(
164        &self,
165        precedence: u8,
166        ctx: ParserContext,
167    ) -> &'arena Node<'arena> {
168        let mut node_left = self.unary_expression();
169        let Some(mut token) = self.current() else {
170            return node_left;
171        };
172
173        while let TokenKind::Operator(operator) = &token.kind {
174            if self.is_done() {
175                break;
176            }
177
178            if ctx == ParserContext::Global
179                && matches!(
180                    operator,
181                    Operator::Comma
182                        | Operator::Logical(LogicalOperator::And)
183                        | Operator::Logical(LogicalOperator::Or)
184                )
185            {
186                break;
187            }
188
189            let Some(op) = BINARY_OPERATORS.get(operator) else {
190                break;
191            };
192
193            if op.precedence < precedence {
194                break;
195            }
196
197            self.next();
198            let node_right = match op.associativity {
199                Associativity::Left => self.binary_expression(
200                    op.precedence + 1,
201                    ParserContext::Global,
202                ),
203                _ => {
204                    self.binary_expression(op.precedence, ParserContext::Global)
205                },
206            };
207
208            node_left = self.node(
209                Node::Binary {
210                    operator: *operator,
211                    left: node_left,
212                    right: node_right,
213                },
214                |h| NodeMetadata {
215                    span: h.span(node_left, node_right).unwrap_or_default(),
216                },
217            );
218
219            let Some(t) = self.current() else {
220                break;
221            };
222            token = t;
223        }
224
225        if precedence == 0 {
226            if let Some(conditional_node) =
227                self.conditional(node_left, &|c| self.binary_expression(0, c))
228            {
229                node_left = conditional_node;
230            }
231        }
232
233        node_left
234    }
235
236    fn unary_expression(&self) -> &'arena Node<'arena> {
237        let Some(token) = self.current() else {
238            return self.literal(&|c| self.binary_expression(0, c));
239        };
240
241        if self.depth() > 0
242            && token.kind
243                == TokenKind::Identifier(Identifier::CallbackReference)
244        {
245            self.next();
246
247            let node =
248                self.node(Node::Pointer, |_| NodeMetadata { span: token.span });
249            let (n, _) =
250                self.with_postfix(node, &|c| self.binary_expression(0, c));
251            return n;
252        }
253
254        if let TokenKind::Operator(operator) = &token.kind {
255            let Some(unary_operator) = UNARY_OPERATORS.get(operator) else {
256                return self.error(AstNodeError::UnexpectedToken {
257                    expected: self.bump.alloc_str("UnaryOperator"),
258                    received: self
259                        .bump
260                        .alloc_str(token.kind.to_string().as_str()),
261                    span: token.span,
262                });
263            };
264
265            self.next();
266            let expr = self.binary_expression(
267                unary_operator.precedence,
268                ParserContext::Global,
269            );
270            let node = self.node(
271                Node::Unary { operator: *operator, node: expr },
272                |h| NodeMetadata {
273                    span: (
274                        token.span.0,
275                        h.metadata(expr).map(|n| n.span.1).unwrap_or_default(),
276                    ),
277                },
278            );
279
280            return node;
281        }
282
283        if let Some(interval_node) =
284            self.interval(&|c| self.binary_expression(0, c))
285        {
286            return interval_node;
287        }
288
289        if token.kind == TokenKind::Bracket(Bracket::LeftParenthesis) {
290            let p_start = self.current().map(|s| s.span.0);
291
292            self.next();
293            let binary_node = self.binary_expression(0, ParserContext::Global);
294            if let Some(error_node) =
295                self.expect(TokenKind::Bracket(Bracket::RightParenthesis))
296            {
297                return error_node;
298            };
299
300            let expr =
301                self.node(Node::Parenthesized(binary_node), |_| NodeMetadata {
302                    span: (p_start.unwrap_or_default(), self.prev_token_end()),
303                });
304
305            let (n, _) =
306                self.with_postfix(expr, &|c| self.binary_expression(0, c));
307            return n;
308        }
309
310        self.literal(&|c| self.binary_expression(0, c))
311    }
312}
313
314/// Dictates the behaviour of nodes in unary mode.
315/// If `CompareWithReference` is set, node will attempt to make the comparison with the reference,
316/// essentially making it (in case of Equal operator) `$ == nodeValue`, or (in case of In operator)
317/// `$ in nodeValue`.
318///
319/// Using `AsBoolean` will cast the nodeValue to boolean and skip comparison with reference ($).
320/// You may still use references in such case directly, e.g. `contains($, 'hello')`.
321///
322/// Rationale behind this is to avoid scenarios where e.g. $ = false and expression is
323/// `contains($, 'needle')`. If we didn't ignore the reference, unary expression will be
324/// reduced to `$ == contains($, 'needle')` which will be truthy when $ does not
325/// contain needle.
326#[derive(Debug, PartialEq)]
327enum UnaryNodeBehaviour {
328    CompareWithReference(ComparisonOperator),
329    AsBoolean,
330}
331
332impl From<&Node<'_>> for UnaryNodeBehaviour {
333    fn from(value: &Node) -> Self {
334        use ComparisonOperator::*;
335        use UnaryNodeBehaviour::*;
336
337        match value {
338            Node::Null => CompareWithReference(Equal),
339            Node::Root => CompareWithReference(Equal),
340            Node::Bool(_) => CompareWithReference(Equal),
341            Node::Number(_) => CompareWithReference(Equal),
342            Node::String(_) => CompareWithReference(Equal),
343            Node::TemplateString(_) => CompareWithReference(Equal),
344            Node::Object(_) => CompareWithReference(Equal),
345            Node::Assignments { .. } => CompareWithReference(Equal),
346            Node::Pointer => AsBoolean,
347            Node::Array(_) => CompareWithReference(In),
348            Node::Identifier(_) => CompareWithReference(Equal),
349            Node::Closure(_) => AsBoolean,
350            Node::Member { .. } => CompareWithReference(Equal),
351            Node::Slice { .. } => CompareWithReference(In),
352            Node::Interval { .. } => CompareWithReference(In),
353            Node::Conditional { on_true, on_false, .. } => {
354                let a = UnaryNodeBehaviour::from(*on_true);
355                let b = UnaryNodeBehaviour::from(*on_false);
356
357                if a == b { a } else { CompareWithReference(Equal) }
358            },
359            Node::Unary { node, .. } => UnaryNodeBehaviour::from(*node),
360            Node::Parenthesized(n) => UnaryNodeBehaviour::from(*n),
361            Node::Binary { left, operator, right } => match operator {
362                Operator::Arithmetic(_) => {
363                    let a = UnaryNodeBehaviour::from(*left);
364                    let b = UnaryNodeBehaviour::from(*right);
365
366                    if a == b { a } else { CompareWithReference(Equal) }
367                },
368                Operator::Logical(_) => AsBoolean,
369                Operator::Comparison(_) => AsBoolean,
370                Operator::Range => CompareWithReference(In),
371                Operator::Slice => CompareWithReference(In),
372                Operator::Comma => AsBoolean,
373                Operator::Dot => AsBoolean,
374                Operator::QuestionMark => AsBoolean,
375                Operator::Assign => AsBoolean,
376                Operator::Semi => AsBoolean,
377            },
378            Node::FunctionCall { kind, .. } => match kind {
379                FunctionKind::Internal(i) => match i {
380                    InternalFunction::Len => CompareWithReference(Equal),
381                    InternalFunction::Upper => CompareWithReference(Equal),
382                    InternalFunction::Lower => CompareWithReference(Equal),
383                    InternalFunction::Trim => CompareWithReference(Equal),
384                    InternalFunction::Abs => CompareWithReference(Equal),
385                    InternalFunction::Sum => CompareWithReference(Equal),
386                    InternalFunction::Avg => CompareWithReference(Equal),
387                    InternalFunction::Min => CompareWithReference(Equal),
388                    InternalFunction::Max => CompareWithReference(Equal),
389                    InternalFunction::Rand => CompareWithReference(Equal),
390                    InternalFunction::Median => CompareWithReference(Equal),
391                    InternalFunction::Mode => CompareWithReference(Equal),
392                    InternalFunction::Floor => CompareWithReference(Equal),
393                    InternalFunction::Ceil => CompareWithReference(Equal),
394                    InternalFunction::Round => CompareWithReference(Equal),
395                    InternalFunction::Trunc => CompareWithReference(Equal),
396                    InternalFunction::String => CompareWithReference(Equal),
397                    InternalFunction::Number => CompareWithReference(Equal),
398                    InternalFunction::Bool => CompareWithReference(Equal),
399                    InternalFunction::Flatten => CompareWithReference(In),
400                    InternalFunction::Extract => CompareWithReference(In),
401                    InternalFunction::Contains => AsBoolean,
402                    InternalFunction::StartsWith => AsBoolean,
403                    InternalFunction::EndsWith => AsBoolean,
404                    InternalFunction::Matches => AsBoolean,
405                    InternalFunction::FuzzyMatch => CompareWithReference(Equal),
406                    InternalFunction::Split => CompareWithReference(In),
407                    InternalFunction::IsNumeric => AsBoolean,
408                    InternalFunction::Keys => CompareWithReference(In),
409                    InternalFunction::Values => CompareWithReference(In),
410                    InternalFunction::Type => CompareWithReference(Equal),
411                    InternalFunction::Date => CompareWithReference(Equal),
412                },
413                FunctionKind::Deprecated(d) => match d {
414                    DeprecatedFunction::Date => CompareWithReference(Equal),
415                    DeprecatedFunction::Time => CompareWithReference(Equal),
416                    DeprecatedFunction::Duration => CompareWithReference(Equal),
417                    DeprecatedFunction::Year => CompareWithReference(Equal),
418                    DeprecatedFunction::DayOfWeek => {
419                        CompareWithReference(Equal)
420                    },
421                    DeprecatedFunction::DayOfMonth => {
422                        CompareWithReference(Equal)
423                    },
424                    DeprecatedFunction::DayOfYear => {
425                        CompareWithReference(Equal)
426                    },
427                    DeprecatedFunction::WeekOfYear => {
428                        CompareWithReference(Equal)
429                    },
430                    DeprecatedFunction::MonthOfYear => {
431                        CompareWithReference(Equal)
432                    },
433                    DeprecatedFunction::MonthString => {
434                        CompareWithReference(Equal)
435                    },
436                    DeprecatedFunction::DateString => {
437                        CompareWithReference(Equal)
438                    },
439                    DeprecatedFunction::WeekdayString => {
440                        CompareWithReference(Equal)
441                    },
442                    DeprecatedFunction::StartOf => CompareWithReference(Equal),
443                    DeprecatedFunction::EndOf => CompareWithReference(Equal),
444                },
445                FunctionKind::Closure(c) => match c {
446                    ClosureFunction::All => AsBoolean,
447                    ClosureFunction::Some => AsBoolean,
448                    ClosureFunction::None => AsBoolean,
449                    ClosureFunction::One => AsBoolean,
450                    ClosureFunction::Filter => CompareWithReference(In),
451                    ClosureFunction::Map => CompareWithReference(In),
452                    ClosureFunction::FlatMap => CompareWithReference(In),
453                    ClosureFunction::Count => CompareWithReference(Equal),
454                },
455                FunctionKind::Mf(_) => AsBoolean,
456            },
457            Node::MethodCall { kind, .. } => match kind {
458                MethodKind::DateMethod(dm) => match dm {
459                    DateMethod::Add => CompareWithReference(Equal),
460                    DateMethod::Sub => CompareWithReference(Equal),
461                    DateMethod::Format => CompareWithReference(Equal),
462                    DateMethod::Month => CompareWithReference(Equal),
463                    DateMethod::Year => CompareWithReference(Equal),
464                    DateMethod::Set => CompareWithReference(Equal),
465                    DateMethod::StartOf => CompareWithReference(Equal),
466                    DateMethod::EndOf => CompareWithReference(Equal),
467                    DateMethod::Diff => CompareWithReference(Equal),
468                    DateMethod::Tz => CompareWithReference(Equal),
469                    DateMethod::Second => CompareWithReference(Equal),
470                    DateMethod::Minute => CompareWithReference(Equal),
471                    DateMethod::Hour => CompareWithReference(Equal),
472                    DateMethod::Day => CompareWithReference(Equal),
473                    DateMethod::DayOfYear => CompareWithReference(Equal),
474                    DateMethod::Week => CompareWithReference(Equal),
475                    DateMethod::Weekday => CompareWithReference(Equal),
476                    DateMethod::Quarter => CompareWithReference(Equal),
477                    DateMethod::Timestamp => CompareWithReference(Equal),
478                    DateMethod::OffsetName => CompareWithReference(Equal),
479                    DateMethod::IsSame => AsBoolean,
480                    DateMethod::IsBefore => AsBoolean,
481                    DateMethod::IsAfter => AsBoolean,
482                    DateMethod::IsSameOrBefore => AsBoolean,
483                    DateMethod::IsSameOrAfter => AsBoolean,
484                    DateMethod::IsValid => AsBoolean,
485                    DateMethod::IsYesterday => AsBoolean,
486                    DateMethod::IsToday => AsBoolean,
487                    DateMethod::IsTomorrow => AsBoolean,
488                    DateMethod::IsLeapYear => AsBoolean,
489                },
490            },
491            Node::Error { .. } => AsBoolean,
492        }
493    }
494}