moduforge_rules_expression/parser/unary.rs

1use crate::functions::{
2    ClosureFunction, DateMethod, DeprecatedFunction, FunctionKind,
3    InternalFunction, MethodKind,
4};
5use crate::lexer::{
6    Bracket, ComparisonOperator, Identifier, LogicalOperator, Operator,
7    TokenKind,
8};
9use crate::parser::ast::{AstNodeError, Node};
10use crate::parser::constants::{Associativity, BINARY_OPERATORS, UNARY_OPERATORS};
11use crate::parser::parser::{Parser, ParserContext};
12use crate::parser::unary::UnaryNodeBehaviour::CompareWithReference;
13use crate::parser::{NodeMetadata, ParserResult};
14
/// Marker type selecting the unary parsing mode of `Parser`.
///
/// In this mode an expression is evaluated against an implicit reference
/// value `$`; bare operands are turned into comparisons against it (or cast
/// to a boolean) by the parser.
#[derive(Debug)]
pub struct Unary;

/// The implicit reference `$`.
///
/// Used as the left operand when a pair starts with a comparison operator
/// (e.g. `< 5`), and when a bare operand is compared with the reference.
const ROOT_NODE: Node<'static> = Node::Identifier("$");
19
/// Parser routines for unary mode.
///
/// Informal grammar implemented below:
///
/// ```text
/// root := pair (("and" | "or" | ",") pair)*    -- "," behaves like "or"
/// pair := comparison_op binary                 -- left side is the reference `$`
///       | binary [comparison_op binary]
/// ```
///
/// A pair that ends without an explicit comparison is rewritten according to
/// `UnaryNodeBehaviour`. It is either compared against `$`, or wrapped in a
/// call to the internal `Bool` function.
impl<'arena, 'token_ref> Parser<'arena, 'token_ref, Unary> {
    /// Parses the whole token stream as a unary expression.
    ///
    /// `is_complete` is the value of `self.is_done()` after parsing
    /// (presumably `true` when every token was consumed). `metadata` is a
    /// clone of `node_metadata`, unwrapped with `into_inner`, when present.
    pub fn parse(&self) -> ParserResult<'arena> {
        let root = self.root_expression();

        ParserResult {
            root,
            is_complete: self.is_done(),
            metadata: self.node_metadata.clone().map(|t| t.into_inner()),
        }
    }

    /// Parses one or more expression pairs separated by `and`, `or` or `,`.
    /// The pairs are folded left-to-right into `Node::Binary` nodes.
    ///
    /// A comma is mapped to logical `or`, so `1, 2` becomes
    /// `$ == 1 or $ == 2`. Any other token between pairs returns an
    /// "Invalid join operator" error node, and the pairs parsed so far are
    /// discarded.
    fn root_expression(&self) -> &'arena Node<'arena> {
        let mut left_node = self.expression_pair();

        while !self.is_done() {
            let Some(current_token) = self.current() else {
                break;
            };

            // Normalize the separator: `,` is sugar for `or`.
            let join_operator = match &current_token.kind {
                TokenKind::Operator(Operator::Logical(
                    LogicalOperator::And,
                )) => Operator::Logical(LogicalOperator::And),
                TokenKind::Operator(Operator::Logical(LogicalOperator::Or))
                | TokenKind::Operator(Operator::Comma) => {
                    Operator::Logical(LogicalOperator::Or)
                },
                _ => {
                    return self.error(AstNodeError::Custom {
                        message: self.bump.alloc_str(
                            format!(
                                "Invalid join operator `{}`",
                                current_token.kind
                            )
                            .as_str(),
                        ),
                        span: current_token.span,
                    });
                },
            };

            self.next();
            let right_node = self.expression_pair();
            // Left-associative fold: ((p1 op p2) op p3) ...
            left_node = self.node(
                Node::Binary {
                    left: left_node,
                    operator: join_operator,
                    right: right_node,
                },
                |h| NodeMetadata {
                    span: h.span(left_node, right_node).unwrap_or_default(),
                },
            );
        }

        left_node
    }

    /// Parses a single pair, i.e. an operand with an optional comparison.
    ///
    /// * Leading comparison: `< 5` becomes `$ < 5`, using `ROOT_NODE` as the
    ///   left operand.
    /// * If a comparison operator is current after the left operand, a
    ///   comparison `Node::Binary` is built from both sides.
    /// * Otherwise the operand is rewritten via `UnaryNodeBehaviour`. For
    ///   example, `5` becomes `$ == 5` and `contains($, 'x')` becomes a call
    ///   to the internal `Bool` function wrapping it.
    ///
    /// The span of any node built here starts at the token that was current
    /// on entry.
    fn expression_pair(&self) -> &'arena Node<'arena> {
        let mut left_node = &ROOT_NODE;
        // Remembered on entry so the resulting span covers the whole pair.
        let current_token = self.current();

        if let Some(TokenKind::Operator(Operator::Comparison(_))) =
            self.current_kind()
        {
            // Skips: the pair starts with a comparison operator, so `$`
            // (ROOT_NODE) stays as the left operand.
        } else {
            left_node = self.binary_expression(0, ParserContext::Global);
        }

        match self.current_kind() {
            Some(TokenKind::Operator(Operator::Comparison(comparison))) => {
                self.next();
                let right_node =
                    self.binary_expression(0, ParserContext::Global);
                left_node = self.node(
                    Node::Binary {
                        left: left_node,
                        operator: Operator::Comparison(*comparison),
                        right: right_node,
                    },
                    |h| NodeMetadata {
                        span: (
                            current_token.map(|t| t.span.0).unwrap_or_default(),
                            h.metadata(right_node)
                                .map(|n| n.span.1)
                                .unwrap_or_default(),
                        ),
                    },
                );
            },
            _ => {
                // No explicit comparison follows: decide how the bare operand
                // relates to the reference `$`.
                let behaviour = UnaryNodeBehaviour::from(left_node);
                match behaviour {
                    CompareWithReference(comparator) => {
                        // Build `$ <comparator> operand`.
                        left_node = self.node(
                            Node::Binary {
                                left: &ROOT_NODE,
                                operator: Operator::Comparison(comparator),
                                right: left_node,
                            },
                            |h| NodeMetadata {
                                span: (
                                    current_token
                                        .map(|t| t.span.0)
                                        .unwrap_or_default(),
                                    h.metadata(left_node)
                                        .map(|n| n.span.1)
                                        .unwrap_or_default(),
                                ),
                            },
                        )
                    },
                    UnaryNodeBehaviour::AsBoolean => {
                        // Wrap the operand in the internal `Bool` function.
                        // `$` is not referenced implicitly here.
                        left_node = self.node(
                            Node::FunctionCall {
                                kind: FunctionKind::Internal(
                                    InternalFunction::Bool,
                                ),
                                arguments: self
                                    .bump
                                    .alloc_slice_clone(&[left_node]),
                            },
                            |h| NodeMetadata {
                                span: (
                                    current_token
                                        .map(|t| t.span.0)
                                        .unwrap_or_default(),
                                    h.metadata(left_node)
                                        .map(|n| n.span.1)
                                        .unwrap_or_default(),
                                ),
                            },
                        )
                    },
                }
            },
        }

        left_node
    }

    /// Precedence-climbing parser for binary operators.
    ///
    /// Only operators present in `BINARY_OPERATORS` with a precedence of at
    /// least `precedence` are consumed.
    ///
    /// * Left-associative operators parse their right operand at
    ///   `precedence + 1`; any other associativity parses it at the same
    ///   precedence.
    /// * In `ParserContext::Global`, `,`, `and` and `or` stop the loop so that
    ///   `root_expression` can handle them as pair separators.
    /// * At the outermost level (`precedence == 0`), the result is also
    ///   offered to `self.conditional`, which may wrap it in a conditional
    ///   node.
    #[cfg_attr(feature = "stack-protection", recursive::recursive)]
    fn binary_expression(
        &self,
        precedence: u8,
        ctx: ParserContext,
    ) -> &'arena Node<'arena> {
        let mut node_left = self.unary_expression();
        let Some(mut token) = self.current() else {
            return node_left;
        };

        while let TokenKind::Operator(operator) = &token.kind {
            if self.is_done() {
                break;
            }

            // Leave pair separators to `root_expression`.
            if ctx == ParserContext::Global
                && matches!(
                    operator,
                    Operator::Comma
                        | Operator::Logical(LogicalOperator::And)
                        | Operator::Logical(LogicalOperator::Or)
                )
            {
                break;
            }

            let Some(op) = BINARY_OPERATORS.get(operator) else {
                break;
            };

            // Operator binds too loosely for this level; let a caller at a
            // lower precedence consume it.
            if op.precedence < precedence {
                break;
            }

            self.next();
            // Right operands are always parsed in the global context,
            // regardless of `ctx`.
            let node_right = match op.associativity {
                Associativity::Left => self.binary_expression(
                    op.precedence + 1,
                    ParserContext::Global,
                ),
                _ => {
                    self.binary_expression(op.precedence, ParserContext::Global)
                },
            };

            node_left = self.node(
                Node::Binary {
                    operator: *operator,
                    left: node_left,
                    right: node_right,
                },
                |h| NodeMetadata {
                    span: h.span(node_left, node_right).unwrap_or_default(),
                },
            );

            let Some(t) = self.current() else {
                break;
            };
            token = t;
        }

        if precedence == 0 {
            if let Some(conditional_node) =
                self.conditional(node_left, |c| self.binary_expression(0, c))
            {
                node_left = conditional_node;
            }
        }

        node_left
    }

    /// Parses a prefix or primary expression. Alternatives are tried in
    /// this order:
    ///
    /// 1. A callback reference, only when `self.depth() > 0` (presumably
    ///    inside a closure). It produces `Node::Pointer` plus any postfix.
    /// 2. An operator token. It must be a unary operator; its operand is
    ///    parsed at that operator's precedence. Any other operator yields an
    ///    "UnexpectedToken" error node.
    /// 3. An interval, via `self.interval`.
    /// 4. A parenthesized expression plus any postfix.
    /// 5. A literal, via `self.literal`.
    ///
    /// With no current token at all, it falls straight through to
    /// `self.literal`.
    fn unary_expression(&self) -> &'arena Node<'arena> {
        let Some(token) = self.current() else {
            return self.literal(|c| self.binary_expression(0, c));
        };

        if self.depth() > 0
            && token.kind
                == TokenKind::Identifier(Identifier::CallbackReference)
        {
            self.next();

            let node =
                self.node(Node::Pointer, |_| NodeMetadata { span: token.span });
            return self.with_postfix(node, |c| self.binary_expression(0, c));
        }

        if let TokenKind::Operator(operator) = &token.kind {
            let Some(unary_operator) = UNARY_OPERATORS.get(operator) else {
                return self.error(AstNodeError::UnexpectedToken {
                    expected: self.bump.alloc_str("UnaryOperator"),
                    received: self
                        .bump
                        .alloc_str(token.kind.to_string().as_str()),
                    span: token.span,
                });
            };

            self.next();
            // The operand absorbs binary operators whose precedence is at
            // least the unary operator's precedence.
            let expr = self.binary_expression(
                unary_operator.precedence,
                ParserContext::Global,
            );
            let node = self.node(
                Node::Unary { operator: *operator, node: expr },
                |h| NodeMetadata {
                    span: (
                        token.span.0,
                        h.metadata(expr).map(|n| n.span.1).unwrap_or_default(),
                    ),
                },
            );

            return node;
        }

        // Tried before the parenthesis branch, presumably because interval
        // syntax can also start with a bracket or parenthesis.
        if let Some(interval_node) =
            self.interval(|c| self.binary_expression(0, c))
        {
            return interval_node;
        }

        if token.kind == TokenKind::Bracket(Bracket::LeftParenthesis) {
            // Start of the `(` token (the same token as `token` here).
            let p_start = self.current().map(|s| s.span.0);

            self.next();
            // NOTE(review): parsed in ParserContext::Global, so `,`/`and`/`or`
            // end the inner expression early. Something like `(a and b)`
            // presumably reaches the `)` expectation at `and` and yields an
            // error; confirm whether that is intended.
            let binary_node = self.binary_expression(0, ParserContext::Global);
            if let Some(error_node) =
                self.expect(TokenKind::Bracket(Bracket::RightParenthesis))
            {
                return error_node;
            };

            let expr =
                self.node(Node::Parenthesized(binary_node), |_| NodeMetadata {
                    span: (p_start.unwrap_or_default(), self.prev_token_end()),
                });

            return self.with_postfix(expr, |c| self.binary_expression(0, c));
        }

        self.literal(|c| self.binary_expression(0, c))
    }
}
309
310/// Dictates the behaviour of nodes in unary mode.
311/// If `CompareWithReference` is set, node will attempt to make the comparison with the reference,
312/// essentially making it (in case of Equal operator) `$ == nodeValue`, or (in case of In operator)
313/// `$ in nodeValue`.
314///
315/// Using `AsBoolean` will cast the nodeValue to boolean and skip comparison with reference ($).
316/// You may still use references in such case directly, e.g. `contains($, 'hello')`.
317///
318/// Rationale behind this is to avoid scenarios where e.g. $ = false and expression is
319/// `contains($, 'needle')`. If we didn't ignore the reference, unary expression will be
320/// reduced to `$ == contains($, 'needle')` which will be truthy when $ does not
321/// contain needle.
322#[derive(Debug, PartialEq)]
323enum UnaryNodeBehaviour {
324    CompareWithReference(ComparisonOperator),
325    AsBoolean,
326}
327
328impl From<&Node<'_>> for UnaryNodeBehaviour {
329    fn from(value: &Node) -> Self {
330        use ComparisonOperator::*;
331        use UnaryNodeBehaviour::*;
332
333        match value {
334            Node::Null => CompareWithReference(Equal),
335            Node::Root => CompareWithReference(Equal),
336            Node::Bool(_) => CompareWithReference(Equal),
337            Node::Number(_) => CompareWithReference(Equal),
338            Node::String(_) => CompareWithReference(Equal),
339            Node::TemplateString(_) => CompareWithReference(Equal),
340            Node::Object(_) => CompareWithReference(Equal),
341            Node::Pointer => AsBoolean,
342            Node::Array(_) => CompareWithReference(In),
343            Node::Identifier(_) => CompareWithReference(Equal),
344            Node::Closure(_) => AsBoolean,
345            Node::Member { .. } => CompareWithReference(Equal),
346            Node::Slice { .. } => CompareWithReference(In),
347            Node::Interval { .. } => CompareWithReference(In),
348            Node::Conditional { on_true, on_false, .. } => {
349                let a = UnaryNodeBehaviour::from(*on_true);
350                let b = UnaryNodeBehaviour::from(*on_false);
351
352                if a == b { a } else { CompareWithReference(Equal) }
353            },
354            Node::Unary { node, .. } => UnaryNodeBehaviour::from(*node),
355            Node::Parenthesized(n) => UnaryNodeBehaviour::from(*n),
356            Node::Binary { left, operator, right } => match operator {
357                Operator::Arithmetic(_) => {
358                    let a = UnaryNodeBehaviour::from(*left);
359                    let b = UnaryNodeBehaviour::from(*right);
360
361                    if a == b { a } else { CompareWithReference(Equal) }
362                },
363                Operator::Logical(_) => AsBoolean,
364                Operator::Comparison(_) => AsBoolean,
365                Operator::Range => CompareWithReference(In),
366                Operator::Slice => CompareWithReference(In),
367                Operator::Comma => AsBoolean,
368                Operator::Dot => AsBoolean,
369                Operator::QuestionMark => AsBoolean,
370            },
371            Node::FunctionCall { kind, .. } => match kind {
372                FunctionKind::Internal(i) => match i {
373                    InternalFunction::Len => CompareWithReference(Equal),
374                    InternalFunction::Upper => CompareWithReference(Equal),
375                    InternalFunction::Lower => CompareWithReference(Equal),
376                    InternalFunction::Trim => CompareWithReference(Equal),
377                    InternalFunction::Abs => CompareWithReference(Equal),
378                    InternalFunction::Sum => CompareWithReference(Equal),
379                    InternalFunction::Avg => CompareWithReference(Equal),
380                    InternalFunction::Min => CompareWithReference(Equal),
381                    InternalFunction::Max => CompareWithReference(Equal),
382                    InternalFunction::Rand => CompareWithReference(Equal),
383                    InternalFunction::Median => CompareWithReference(Equal),
384                    InternalFunction::Mode => CompareWithReference(Equal),
385                    InternalFunction::Floor => CompareWithReference(Equal),
386                    InternalFunction::Ceil => CompareWithReference(Equal),
387                    InternalFunction::Round => CompareWithReference(Equal),
388                    InternalFunction::Trunc => CompareWithReference(Equal),
389                    InternalFunction::String => CompareWithReference(Equal),
390                    InternalFunction::Number => CompareWithReference(Equal),
391                    InternalFunction::Bool => CompareWithReference(Equal),
392                    InternalFunction::Flatten => CompareWithReference(In),
393                    InternalFunction::Extract => CompareWithReference(In),
394                    InternalFunction::Contains => AsBoolean,
395                    InternalFunction::StartsWith => AsBoolean,
396                    InternalFunction::EndsWith => AsBoolean,
397                    InternalFunction::Matches => AsBoolean,
398                    InternalFunction::FuzzyMatch => CompareWithReference(Equal),
399                    InternalFunction::Split => CompareWithReference(In),
400                    InternalFunction::IsNumeric => AsBoolean,
401                    InternalFunction::Keys => CompareWithReference(In),
402                    InternalFunction::Values => CompareWithReference(In),
403                    InternalFunction::Type => CompareWithReference(Equal),
404                    InternalFunction::Date => CompareWithReference(Equal),
405                },
406                FunctionKind::Deprecated(d) => match d {
407                    DeprecatedFunction::Date => CompareWithReference(Equal),
408                    DeprecatedFunction::Time => CompareWithReference(Equal),
409                    DeprecatedFunction::Duration => CompareWithReference(Equal),
410                    DeprecatedFunction::Year => CompareWithReference(Equal),
411                    DeprecatedFunction::DayOfWeek => {
412                        CompareWithReference(Equal)
413                    },
414                    DeprecatedFunction::DayOfMonth => {
415                        CompareWithReference(Equal)
416                    },
417                    DeprecatedFunction::DayOfYear => {
418                        CompareWithReference(Equal)
419                    },
420                    DeprecatedFunction::WeekOfYear => {
421                        CompareWithReference(Equal)
422                    },
423                    DeprecatedFunction::MonthOfYear => {
424                        CompareWithReference(Equal)
425                    },
426                    DeprecatedFunction::MonthString => {
427                        CompareWithReference(Equal)
428                    },
429                    DeprecatedFunction::DateString => {
430                        CompareWithReference(Equal)
431                    },
432                    DeprecatedFunction::WeekdayString => {
433                        CompareWithReference(Equal)
434                    },
435                    DeprecatedFunction::StartOf => CompareWithReference(Equal),
436                    DeprecatedFunction::EndOf => CompareWithReference(Equal),
437                },
438                FunctionKind::Closure(c) => match c {
439                    ClosureFunction::All => AsBoolean,
440                    ClosureFunction::Some => AsBoolean,
441                    ClosureFunction::None => AsBoolean,
442                    ClosureFunction::One => AsBoolean,
443                    ClosureFunction::Filter => CompareWithReference(In),
444                    ClosureFunction::Map => CompareWithReference(In),
445                    ClosureFunction::FlatMap => CompareWithReference(In),
446                    ClosureFunction::Count => CompareWithReference(Equal),
447                },
448                FunctionKind::Custom(_) => CompareWithReference(Equal), // 自定义函数默认用等号比较
449            },
450            Node::MethodCall { kind, .. } => match kind {
451                MethodKind::DateMethod(dm) => match dm {
452                    DateMethod::Add => CompareWithReference(Equal),
453                    DateMethod::Sub => CompareWithReference(Equal),
454                    DateMethod::Format => CompareWithReference(Equal),
455                    DateMethod::Month => CompareWithReference(Equal),
456                    DateMethod::Year => CompareWithReference(Equal),
457                    DateMethod::Set => CompareWithReference(Equal),
458                    DateMethod::StartOf => CompareWithReference(Equal),
459                    DateMethod::EndOf => CompareWithReference(Equal),
460                    DateMethod::Diff => CompareWithReference(Equal),
461                    DateMethod::Tz => CompareWithReference(Equal),
462                    DateMethod::Second => CompareWithReference(Equal),
463                    DateMethod::Minute => CompareWithReference(Equal),
464                    DateMethod::Hour => CompareWithReference(Equal),
465                    DateMethod::Day => CompareWithReference(Equal),
466                    DateMethod::DayOfYear => CompareWithReference(Equal),
467                    DateMethod::Week => CompareWithReference(Equal),
468                    DateMethod::Weekday => CompareWithReference(Equal),
469                    DateMethod::Quarter => CompareWithReference(Equal),
470                    DateMethod::Timestamp => CompareWithReference(Equal),
471                    DateMethod::OffsetName => CompareWithReference(Equal),
472                    DateMethod::IsSame => AsBoolean,
473                    DateMethod::IsBefore => AsBoolean,
474                    DateMethod::IsAfter => AsBoolean,
475                    DateMethod::IsSameOrBefore => AsBoolean,
476                    DateMethod::IsSameOrAfter => AsBoolean,
477                    DateMethod::IsValid => AsBoolean,
478                    DateMethod::IsYesterday => AsBoolean,
479                    DateMethod::IsToday => AsBoolean,
480                    DateMethod::IsTomorrow => AsBoolean,
481                    DateMethod::IsLeapYear => AsBoolean,
482                },
483            },
484            Node::Error { .. } => AsBoolean,
485        }
486    }
487}