Skip to main content

darklua_core/nodes/expressions/
binary.rs

1use crate::nodes::{Expression, FunctionReturnType, Token, Type};
2
/// The operator of a [`BinaryExpression`]: the token written between its left
/// and right operands.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryOperator {
    /// `and` — logical conjunction.
    And,
    /// `or` — logical disjunction.
    Or,
    /// `==` — equality comparison.
    Equal,
    /// `~=` — inequality comparison.
    NotEqual,
    /// `<` — strictly-less-than comparison.
    LowerThan,
    /// `<=` — less-than-or-equal comparison.
    LowerOrEqualThan,
    /// `>` — strictly-greater-than comparison.
    GreaterThan,
    /// `>=` — greater-than-or-equal comparison.
    GreaterOrEqualThan,
    /// `+` — addition.
    Plus,
    /// `-` — subtraction.
    Minus,
    /// `*` — multiplication.
    Asterisk,
    /// `/` — division.
    Slash,
    /// `//` — integer division.
    DoubleSlash,
    /// `%` — modulo.
    Percent,
    /// `^` — exponentiation.
    Caret,
    /// `..` — string concatenation.
    Concat,
}
39
40#[inline]
41fn ends_with_if_expression(expression: &Expression) -> bool {
42    let mut current = expression;
43
44    loop {
45        match current {
46            Expression::If(_) => break true,
47            Expression::Binary(binary) => current = binary.right(),
48            Expression::Unary(unary) => current = unary.get_expression(),
49            Expression::Call(_)
50            | Expression::False(_)
51            | Expression::Field(_)
52            | Expression::Function(_)
53            | Expression::Identifier(_)
54            | Expression::Index(_)
55            | Expression::Nil(_)
56            | Expression::Number(_)
57            | Expression::Parenthese(_)
58            | Expression::String(_)
59            | Expression::InterpolatedString(_)
60            | Expression::Table(_)
61            | Expression::True(_)
62            | Expression::VariableArguments(_)
63            | Expression::TypeCast(_)
64            | Expression::TypeInstantiation(_) => break false,
65        }
66    }
67}
68
69#[inline]
70fn ends_with_type_cast_to_type_name_without_type_parameters(expression: &Expression) -> bool {
71    let mut current = expression;
72
73    loop {
74        match current {
75            Expression::If(if_statement) => current = if_statement.get_else_result(),
76            Expression::Binary(binary) => current = binary.right(),
77            Expression::Unary(unary) => current = unary.get_expression(),
78            Expression::TypeCast(type_cast) => {
79                let mut current_type = type_cast.get_type();
80
81                break loop {
82                    match current_type {
83                        Type::Name(name) => break !name.has_type_parameters(),
84                        Type::Field(field) => break !field.get_type_name().has_type_parameters(),
85                        Type::Function(function) => {
86                            current_type = match function.get_return_type() {
87                                FunctionReturnType::Type(r#type) => r#type,
88                                FunctionReturnType::TypePack(_)
89                                | FunctionReturnType::GenericTypePack(_) => break false,
90                                FunctionReturnType::VariadicTypePack(variadic_type) => {
91                                    variadic_type.get_type()
92                                }
93                            }
94                        }
95                        Type::Intersection(intersection) => {
96                            current_type = intersection.last_type();
97                        }
98                        Type::Union(union_type) => {
99                            current_type = union_type.last_type();
100                        }
101                        Type::True(_)
102                        | Type::False(_)
103                        | Type::Nil(_)
104                        | Type::String(_)
105                        | Type::Array(_)
106                        | Type::Table(_)
107                        | Type::TypeOf(_)
108                        | Type::Parenthese(_)
109                        | Type::Optional(_) => break false,
110                    }
111                };
112            }
113            Expression::Call(_)
114            | Expression::False(_)
115            | Expression::Field(_)
116            | Expression::Function(_)
117            | Expression::Identifier(_)
118            | Expression::Index(_)
119            | Expression::Nil(_)
120            | Expression::Number(_)
121            | Expression::Parenthese(_)
122            | Expression::String(_)
123            | Expression::InterpolatedString(_)
124            | Expression::Table(_)
125            | Expression::True(_)
126            | Expression::TypeInstantiation(_)
127            | Expression::VariableArguments(_) => break false,
128        }
129    }
130}
131
132impl BinaryOperator {
133    /// Checks if this operator has higher precedence than another operator.
134    #[inline]
135    pub fn precedes(&self, other: Self) -> bool {
136        self.get_precedence() > other.get_precedence()
137    }
138
139    /// Checks if this operator has higher precedence than unary expressions.
140    ///
141    /// Currently only the exponentiation operator (`^`) has this property.
142    #[inline]
143    pub fn precedes_unary_expression(&self) -> bool {
144        matches!(self, Self::Caret)
145    }
146
147    /// Determines if this operator is left associative.
148    ///
149    /// Left associative operators like `+` evaluate expressions from left to right:
150    /// `a + b + c` is evaluated as `(a + b) + c`.
151    #[inline]
152    pub fn is_left_associative(&self) -> bool {
153        !matches!(self, Self::Caret | Self::Concat)
154    }
155
156    /// Determines if this operator is right associative.
157    ///
158    /// Right associative operators like `^` evaluate expressions from right to left:
159    /// `a ^ b ^ c` is evaluated as `a ^ (b ^ c)`.
160    #[inline]
161    pub fn is_right_associative(&self) -> bool {
162        matches!(self, Self::Caret | Self::Concat)
163    }
164
165    /// Determines if the left operand needs parentheses when generating code.
166    pub fn left_needs_parentheses(&self, left: &Expression) -> bool {
167        let needs_parentheses = match left {
168            Expression::Binary(left) => {
169                if self.is_left_associative() {
170                    self.precedes(left.operator())
171                } else {
172                    !left.operator().precedes(*self)
173                }
174            }
175            Expression::Unary(_) => self.precedes_unary_expression(),
176            Expression::If(_) => true,
177            _ => false,
178        };
179        needs_parentheses
180            || ends_with_if_expression(left)
181            || (matches!(self, BinaryOperator::LowerThan)
182                && ends_with_type_cast_to_type_name_without_type_parameters(left))
183    }
184
185    /// Determines if the right operand needs parentheses when generating code.
186    pub fn right_needs_parentheses(&self, right: &Expression) -> bool {
187        match right {
188            Expression::Binary(right) => {
189                if self.is_right_associative() {
190                    self.precedes(right.operator())
191                } else {
192                    !right.operator().precedes(*self)
193                }
194            }
195            Expression::Unary(_) => false,
196            _ => false,
197        }
198    }
199
200    /// Returns the string representation of this operator.
201    pub fn to_str(&self) -> &'static str {
202        match self {
203            Self::And => "and",
204            Self::Or => "or",
205            Self::Equal => "==",
206            Self::NotEqual => "~=",
207            Self::LowerThan => "<",
208            Self::LowerOrEqualThan => "<=",
209            Self::GreaterThan => ">",
210            Self::GreaterOrEqualThan => ">=",
211            Self::Plus => "+",
212            Self::Minus => "-",
213            Self::Asterisk => "*",
214            Self::Slash => "/",
215            Self::DoubleSlash => "//",
216            Self::Percent => "%",
217            Self::Caret => "^",
218            Self::Concat => "..",
219        }
220    }
221
222    /// Returns the precedence level of this operator (higher value = higher precedence).
223    fn get_precedence(&self) -> u8 {
224        match self {
225            Self::Or => 0,
226            Self::And => 1,
227            Self::Equal
228            | Self::NotEqual
229            | Self::LowerThan
230            | Self::LowerOrEqualThan
231            | Self::GreaterThan
232            | Self::GreaterOrEqualThan => 2,
233            Self::Concat => 3,
234            Self::Plus | Self::Minus => 4,
235            Self::Asterisk | Self::Slash | Self::DoubleSlash | Self::Percent => 5,
236            Self::Caret => 7,
237        }
238    }
239}
240
241/// Represents a binary operation in expressions.
242#[derive(Clone, Debug, PartialEq, Eq)]
243pub struct BinaryExpression {
244    operator: BinaryOperator,
245    left: Expression,
246    right: Expression,
247    token: Option<Token>,
248}
249
250impl BinaryExpression {
251    /// Creates a new binary expression with the given operator and operands.
252    pub fn new<T: Into<Expression>, U: Into<Expression>>(
253        operator: BinaryOperator,
254        left: T,
255        right: U,
256    ) -> Self {
257        Self {
258            operator,
259            left: left.into(),
260            right: right.into(),
261            token: None,
262        }
263    }
264
265    /// Associates a token with this expression.
266    pub fn with_token(mut self, token: Token) -> Self {
267        self.token = Some(token);
268        self
269    }
270
271    /// Associates a token with this expression.
272    #[inline]
273    pub fn set_token(&mut self, token: Token) {
274        self.token = Some(token);
275    }
276
277    /// Returns the token associated with this expression, if any.
278    #[inline]
279    pub fn get_token(&self) -> Option<&Token> {
280        self.token.as_ref()
281    }
282
283    /// Returns a mutable reference to the left operand.
284    #[inline]
285    pub fn mutate_left(&mut self) -> &mut Expression {
286        &mut self.left
287    }
288
289    /// Returns a mutable reference to the right operand.
290    #[inline]
291    pub fn mutate_right(&mut self) -> &mut Expression {
292        &mut self.right
293    }
294
295    /// Returns a reference to the left operand.
296    #[inline]
297    pub fn left(&self) -> &Expression {
298        &self.left
299    }
300
301    /// Returns a reference to the right operand.
302    #[inline]
303    pub fn right(&self) -> &Expression {
304        &self.right
305    }
306
307    /// Returns the binary operator.
308    #[inline]
309    pub fn operator(&self) -> BinaryOperator {
310        self.operator
311    }
312
313    /// Changes the operator and updates the associated token's content if present.
314    #[inline]
315    pub fn set_operator(&mut self, operator: BinaryOperator) {
316        if self.operator == operator {
317            return;
318        }
319        self.operator = operator;
320        if let Some(token) = self.token.as_mut() {
321            token.replace_with_content(operator.to_str());
322        }
323    }
324
325    /// Returns a mutable reference to the last token for this binary expression.
326    pub fn mutate_last_token(&mut self) -> &mut Token {
327        self.right.mutate_last_token()
328    }
329
330    super::impl_token_fns!(iter = [token]);
331}
332
#[cfg(test)]
mod test {
    use super::*;

    mod precedence {
        use super::*;

        use BinaryOperator::*;

        #[test]
        fn caret() {
            assert!(Caret.precedes(And));
            assert!(Caret.precedes(Or));
            assert!(Caret.precedes(Equal));
            assert!(Caret.precedes(NotEqual));
            assert!(Caret.precedes(LowerThan));
            assert!(Caret.precedes(LowerOrEqualThan));
            assert!(Caret.precedes(GreaterThan));
            assert!(Caret.precedes(GreaterOrEqualThan));
            assert!(Caret.precedes(Plus));
            assert!(Caret.precedes(Minus));
            assert!(Caret.precedes(Asterisk));
            assert!(Caret.precedes(Slash));
            assert!(Caret.precedes(DoubleSlash));
            assert!(Caret.precedes(Percent));
            assert!(Caret.precedes(Concat));
            assert!(!Caret.precedes(Caret));
            assert!(Caret.precedes_unary_expression());
        }

        #[test]
        fn asterisk() {
            assert!(Asterisk.precedes(And));
            assert!(Asterisk.precedes(Or));
            assert!(Asterisk.precedes(Equal));
            assert!(Asterisk.precedes(NotEqual));
            assert!(Asterisk.precedes(LowerThan));
            assert!(Asterisk.precedes(LowerOrEqualThan));
            assert!(Asterisk.precedes(GreaterThan));
            assert!(Asterisk.precedes(GreaterOrEqualThan));
            assert!(Asterisk.precedes(Plus));
            assert!(Asterisk.precedes(Minus));
            assert!(!Asterisk.precedes(Asterisk));
            assert!(!Asterisk.precedes(Slash));
            assert!(!Asterisk.precedes(DoubleSlash));
            assert!(!Asterisk.precedes(Percent));
            assert!(Asterisk.precedes(Concat));
            assert!(!Asterisk.precedes(Caret));
            assert!(!Asterisk.precedes_unary_expression());
        }

        #[test]
        fn slash() {
            assert!(Slash.precedes(And));
            assert!(Slash.precedes(Or));
            assert!(Slash.precedes(Equal));
            assert!(Slash.precedes(NotEqual));
            assert!(Slash.precedes(LowerThan));
            assert!(Slash.precedes(LowerOrEqualThan));
            assert!(Slash.precedes(GreaterThan));
            assert!(Slash.precedes(GreaterOrEqualThan));
            assert!(Slash.precedes(Plus));
            assert!(Slash.precedes(Minus));
            assert!(!Slash.precedes(Asterisk));
            assert!(!Slash.precedes(Slash));
            assert!(!Slash.precedes(DoubleSlash));
            assert!(!Slash.precedes(Percent));
            assert!(Slash.precedes(Concat));
            assert!(!Slash.precedes(Caret));
            assert!(!Slash.precedes_unary_expression());
        }

        #[test]
        fn double_slash() {
            assert!(DoubleSlash.precedes(And));
            assert!(DoubleSlash.precedes(Or));
            assert!(DoubleSlash.precedes(Equal));
            assert!(DoubleSlash.precedes(NotEqual));
            assert!(DoubleSlash.precedes(LowerThan));
            assert!(DoubleSlash.precedes(LowerOrEqualThan));
            assert!(DoubleSlash.precedes(GreaterThan));
            assert!(DoubleSlash.precedes(GreaterOrEqualThan));
            assert!(DoubleSlash.precedes(Plus));
            assert!(DoubleSlash.precedes(Minus));
            assert!(!DoubleSlash.precedes(Asterisk));
            assert!(!DoubleSlash.precedes(Slash));
            assert!(!DoubleSlash.precedes(DoubleSlash));
            assert!(!DoubleSlash.precedes(Percent));
            assert!(DoubleSlash.precedes(Concat));
            assert!(!DoubleSlash.precedes(Caret));
            assert!(!DoubleSlash.precedes_unary_expression());
        }

        #[test]
        fn percent() {
            assert!(Percent.precedes(And));
            assert!(Percent.precedes(Or));
            assert!(Percent.precedes(Equal));
            assert!(Percent.precedes(NotEqual));
            assert!(Percent.precedes(LowerThan));
            assert!(Percent.precedes(LowerOrEqualThan));
            assert!(Percent.precedes(GreaterThan));
            assert!(Percent.precedes(GreaterOrEqualThan));
            assert!(Percent.precedes(Plus));
            assert!(Percent.precedes(Minus));
            assert!(!Percent.precedes(Asterisk));
            assert!(!Percent.precedes(Slash));
            assert!(!Percent.precedes(DoubleSlash));
            assert!(!Percent.precedes(Percent));
            assert!(Percent.precedes(Concat));
            assert!(!Percent.precedes(Caret));
            assert!(!Percent.precedes_unary_expression());
        }

        #[test]
        fn plus() {
            assert!(Plus.precedes(And));
            assert!(Plus.precedes(Or));
            assert!(Plus.precedes(Equal));
            assert!(Plus.precedes(NotEqual));
            assert!(Plus.precedes(LowerThan));
            assert!(Plus.precedes(LowerOrEqualThan));
            assert!(Plus.precedes(GreaterThan));
            assert!(Plus.precedes(GreaterOrEqualThan));
            assert!(!Plus.precedes(Plus));
            assert!(!Plus.precedes(Minus));
            assert!(!Plus.precedes(Asterisk));
            assert!(!Plus.precedes(Slash));
            assert!(!Plus.precedes(DoubleSlash));
            assert!(!Plus.precedes(Percent));
            assert!(Plus.precedes(Concat));
            assert!(!Plus.precedes(Caret));
            assert!(!Plus.precedes_unary_expression());
        }

        #[test]
        fn minus() {
            assert!(Minus.precedes(And));
            assert!(Minus.precedes(Or));
            assert!(Minus.precedes(Equal));
            assert!(Minus.precedes(NotEqual));
            assert!(Minus.precedes(LowerThan));
            assert!(Minus.precedes(LowerOrEqualThan));
            assert!(Minus.precedes(GreaterThan));
            assert!(Minus.precedes(GreaterOrEqualThan));
            assert!(!Minus.precedes(Plus));
            assert!(!Minus.precedes(Minus));
            assert!(!Minus.precedes(Asterisk));
            assert!(!Minus.precedes(Slash));
            assert!(!Minus.precedes(DoubleSlash));
            assert!(!Minus.precedes(Percent));
            assert!(Minus.precedes(Concat));
            assert!(!Minus.precedes(Caret));
            assert!(!Minus.precedes_unary_expression());
        }

        #[test]
        fn concat() {
            assert!(Concat.precedes(And));
            assert!(Concat.precedes(Or));
            assert!(Concat.precedes(Equal));
            assert!(Concat.precedes(NotEqual));
            assert!(Concat.precedes(LowerThan));
            assert!(Concat.precedes(LowerOrEqualThan));
            assert!(Concat.precedes(GreaterThan));
            assert!(Concat.precedes(GreaterOrEqualThan));
            assert!(!Concat.precedes(Plus));
            assert!(!Concat.precedes(Minus));
            assert!(!Concat.precedes(Asterisk));
            assert!(!Concat.precedes(Slash));
            assert!(!Concat.precedes(DoubleSlash));
            assert!(!Concat.precedes(Percent));
            assert!(!Concat.precedes(Concat));
            assert!(!Concat.precedes(Caret));
            assert!(!Concat.precedes_unary_expression());
        }

        /// All six comparison operators share one level: they only precede
        /// `and` and `or`.
        #[test]
        fn comparisons() {
            let comparisons = [
                Equal,
                NotEqual,
                LowerThan,
                LowerOrEqualThan,
                GreaterThan,
                GreaterOrEqualThan,
            ];
            let tighter = [
                Concat,
                Plus,
                Minus,
                Asterisk,
                Slash,
                DoubleSlash,
                Percent,
                Caret,
            ];

            for operator in comparisons {
                assert!(operator.precedes(And), "{:?} should precede And", operator);
                assert!(operator.precedes(Or), "{:?} should precede Or", operator);

                for other in comparisons {
                    assert!(
                        !operator.precedes(other),
                        "{:?} should not precede {:?}",
                        operator,
                        other
                    );
                }

                for other in tighter {
                    assert!(
                        !operator.precedes(other),
                        "{:?} should not precede {:?}",
                        operator,
                        other
                    );
                }

                assert!(!operator.precedes_unary_expression());
            }
        }

        #[test]
        fn and() {
            assert!(!And.precedes(And));
            assert!(And.precedes(Or));
            assert!(!And.precedes(Equal));
            assert!(!And.precedes(NotEqual));
            assert!(!And.precedes(LowerThan));
            assert!(!And.precedes(LowerOrEqualThan));
            assert!(!And.precedes(GreaterThan));
            assert!(!And.precedes(GreaterOrEqualThan));
            assert!(!And.precedes(Plus));
            assert!(!And.precedes(Minus));
            assert!(!And.precedes(Asterisk));
            assert!(!And.precedes(Slash));
            assert!(!And.precedes(DoubleSlash));
            assert!(!And.precedes(Percent));
            assert!(!And.precedes(Concat));
            assert!(!And.precedes(Caret));
            assert!(!And.precedes_unary_expression());
        }

        /// `or` has the lowest precedence, so it precedes nothing.
        #[test]
        fn or() {
            let all = [
                And,
                Or,
                Equal,
                NotEqual,
                LowerThan,
                LowerOrEqualThan,
                GreaterThan,
                GreaterOrEqualThan,
                Plus,
                Minus,
                Asterisk,
                Slash,
                DoubleSlash,
                Percent,
                Caret,
                Concat,
            ];

            for other in all {
                assert!(!Or.precedes(other), "Or should not precede {:?}", other);
            }
            assert!(!Or.precedes_unary_expression());
        }
    }

    mod associativity {
        use super::*;

        use BinaryOperator::*;

        #[test]
        fn caret_and_concat_are_right_associative() {
            for operator in [Caret, Concat] {
                assert!(operator.is_right_associative(), "{:?}", operator);
                assert!(!operator.is_left_associative(), "{:?}", operator);
            }
        }

        #[test]
        fn other_operators_are_left_associative() {
            let left_associative = [
                And,
                Or,
                Equal,
                NotEqual,
                LowerThan,
                LowerOrEqualThan,
                GreaterThan,
                GreaterOrEqualThan,
                Plus,
                Minus,
                Asterisk,
                Slash,
                DoubleSlash,
                Percent,
            ];

            for operator in left_associative {
                assert!(operator.is_left_associative(), "{:?}", operator);
                assert!(!operator.is_right_associative(), "{:?}", operator);
            }
        }
    }

    mod to_str {
        use super::*;

        use BinaryOperator::*;

        #[test]
        fn every_operator_has_its_source_text() {
            let expected = [
                (And, "and"),
                (Or, "or"),
                (Equal, "=="),
                (NotEqual, "~="),
                (LowerThan, "<"),
                (LowerOrEqualThan, "<="),
                (GreaterThan, ">"),
                (GreaterOrEqualThan, ">="),
                (Plus, "+"),
                (Minus, "-"),
                (Asterisk, "*"),
                (Slash, "/"),
                (DoubleSlash, "//"),
                (Percent, "%"),
                (Caret, "^"),
                (Concat, ".."),
            ];

            for (operator, text) in expected {
                assert_eq!(operator.to_str(), text, "{:?}", operator);
            }
        }
    }
}