darklua_core/nodes/expressions/
binary.rs

1use crate::nodes::{Expression, FunctionReturnType, Token, Type};
2
/// Represents binary operators used in a binary expression.
///
/// Each variant's documentation shows the source text it stands for; see
/// [`BinaryOperator::to_str`] for the exact string emitted when generating code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOperator {
    /// `and` — logical conjunction.
    And,
    /// `or` — logical disjunction.
    Or,
    /// `==` — equality comparison.
    Equal,
    /// `~=` — inequality comparison.
    NotEqual,
    /// `<` — strictly-less-than comparison.
    LowerThan,
    /// `<=` — less-than-or-equal comparison.
    LowerOrEqualThan,
    /// `>` — strictly-greater-than comparison.
    GreaterThan,
    /// `>=` — greater-than-or-equal comparison.
    GreaterOrEqualThan,
    /// `+` — addition.
    Plus,
    /// `-` — subtraction.
    Minus,
    /// `*` — multiplication.
    Asterisk,
    /// `/` — division.
    Slash,
    /// `//` — integer (floor) division.
    DoubleSlash,
    /// `%` — modulo.
    Percent,
    /// `^` — exponentiation.
    Caret,
    /// `..` — string concatenation.
    Concat,
}
39
40#[inline]
41fn ends_with_if_expression(expression: &Expression) -> bool {
42    let mut current = expression;
43
44    loop {
45        match current {
46            Expression::If(_) => break true,
47            Expression::Binary(binary) => current = binary.right(),
48            Expression::Unary(unary) => current = unary.get_expression(),
49            Expression::Call(_)
50            | Expression::False(_)
51            | Expression::Field(_)
52            | Expression::Function(_)
53            | Expression::Identifier(_)
54            | Expression::Index(_)
55            | Expression::Nil(_)
56            | Expression::Number(_)
57            | Expression::Parenthese(_)
58            | Expression::String(_)
59            | Expression::InterpolatedString(_)
60            | Expression::Table(_)
61            | Expression::True(_)
62            | Expression::VariableArguments(_)
63            | Expression::TypeCast(_) => break false,
64        }
65    }
66}
67
68#[inline]
69fn ends_with_type_cast_to_type_name_without_type_parameters(expression: &Expression) -> bool {
70    let mut current = expression;
71
72    loop {
73        match current {
74            Expression::If(if_statement) => current = if_statement.get_else_result(),
75            Expression::Binary(binary) => current = binary.right(),
76            Expression::Unary(unary) => current = unary.get_expression(),
77            Expression::TypeCast(type_cast) => {
78                let mut current_type = type_cast.get_type();
79
80                break loop {
81                    match current_type {
82                        Type::Name(name) => break !name.has_type_parameters(),
83                        Type::Field(field) => break !field.get_type_name().has_type_parameters(),
84                        Type::Function(function) => {
85                            current_type = match function.get_return_type() {
86                                FunctionReturnType::Type(r#type) => r#type,
87                                FunctionReturnType::TypePack(_)
88                                | FunctionReturnType::GenericTypePack(_) => break false,
89                                FunctionReturnType::VariadicTypePack(variadic_type) => {
90                                    variadic_type.get_type()
91                                }
92                            }
93                        }
94                        Type::Intersection(intersection) => {
95                            current_type = intersection.last_type();
96                        }
97                        Type::Union(union_type) => {
98                            current_type = union_type.last_type();
99                        }
100                        Type::True(_)
101                        | Type::False(_)
102                        | Type::Nil(_)
103                        | Type::String(_)
104                        | Type::Array(_)
105                        | Type::Table(_)
106                        | Type::TypeOf(_)
107                        | Type::Parenthese(_)
108                        | Type::Optional(_) => break false,
109                    }
110                };
111            }
112            Expression::Call(_)
113            | Expression::False(_)
114            | Expression::Field(_)
115            | Expression::Function(_)
116            | Expression::Identifier(_)
117            | Expression::Index(_)
118            | Expression::Nil(_)
119            | Expression::Number(_)
120            | Expression::Parenthese(_)
121            | Expression::String(_)
122            | Expression::InterpolatedString(_)
123            | Expression::Table(_)
124            | Expression::True(_)
125            | Expression::VariableArguments(_) => break false,
126        }
127    }
128}
129
130impl BinaryOperator {
131    /// Checks if this operator has higher precedence than another operator.
132    #[inline]
133    pub fn precedes(&self, other: Self) -> bool {
134        self.get_precedence() > other.get_precedence()
135    }
136
137    /// Checks if this operator has higher precedence than unary expressions.
138    ///
139    /// Currently only the exponentiation operator (`^`) has this property.
140    #[inline]
141    pub fn precedes_unary_expression(&self) -> bool {
142        matches!(self, Self::Caret)
143    }
144
145    /// Determines if this operator is left associative.
146    ///
147    /// Left associative operators like `+` evaluate expressions from left to right:
148    /// `a + b + c` is evaluated as `(a + b) + c`.
149    #[inline]
150    pub fn is_left_associative(&self) -> bool {
151        !matches!(self, Self::Caret | Self::Concat)
152    }
153
154    /// Determines if this operator is right associative.
155    ///
156    /// Right associative operators like `^` evaluate expressions from right to left:
157    /// `a ^ b ^ c` is evaluated as `a ^ (b ^ c)`.
158    #[inline]
159    pub fn is_right_associative(&self) -> bool {
160        matches!(self, Self::Caret | Self::Concat)
161    }
162
163    /// Determines if the left operand needs parentheses when generating code.
164    pub fn left_needs_parentheses(&self, left: &Expression) -> bool {
165        let needs_parentheses = match left {
166            Expression::Binary(left) => {
167                if self.is_left_associative() {
168                    self.precedes(left.operator())
169                } else {
170                    !left.operator().precedes(*self)
171                }
172            }
173            Expression::Unary(_) => self.precedes_unary_expression(),
174            Expression::If(_) => true,
175            _ => false,
176        };
177        needs_parentheses
178            || ends_with_if_expression(left)
179            || (matches!(self, BinaryOperator::LowerThan)
180                && ends_with_type_cast_to_type_name_without_type_parameters(left))
181    }
182
183    /// Determines if the right operand needs parentheses when generating code.
184    pub fn right_needs_parentheses(&self, right: &Expression) -> bool {
185        match right {
186            Expression::Binary(right) => {
187                if self.is_right_associative() {
188                    self.precedes(right.operator())
189                } else {
190                    !right.operator().precedes(*self)
191                }
192            }
193            Expression::Unary(_) => false,
194            _ => false,
195        }
196    }
197
198    /// Returns the string representation of this operator.
199    pub fn to_str(&self) -> &'static str {
200        match self {
201            Self::And => "and",
202            Self::Or => "or",
203            Self::Equal => "==",
204            Self::NotEqual => "~=",
205            Self::LowerThan => "<",
206            Self::LowerOrEqualThan => "<=",
207            Self::GreaterThan => ">",
208            Self::GreaterOrEqualThan => ">=",
209            Self::Plus => "+",
210            Self::Minus => "-",
211            Self::Asterisk => "*",
212            Self::Slash => "/",
213            Self::DoubleSlash => "//",
214            Self::Percent => "%",
215            Self::Caret => "^",
216            Self::Concat => "..",
217        }
218    }
219
220    /// Returns the precedence level of this operator (higher value = higher precedence).
221    fn get_precedence(&self) -> u8 {
222        match self {
223            Self::Or => 0,
224            Self::And => 1,
225            Self::Equal
226            | Self::NotEqual
227            | Self::LowerThan
228            | Self::LowerOrEqualThan
229            | Self::GreaterThan
230            | Self::GreaterOrEqualThan => 2,
231            Self::Concat => 3,
232            Self::Plus | Self::Minus => 4,
233            Self::Asterisk | Self::Slash | Self::DoubleSlash | Self::Percent => 5,
234            Self::Caret => 7,
235        }
236    }
237}
238
/// Represents a binary operation in expressions.
///
/// Holds the operator, its two operands and an optional [`Token`] for the operator
/// (`None` when built through `BinaryExpression::new`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BinaryExpression {
    /// The operator written between the two operands.
    operator: BinaryOperator,
    /// The operand written before the operator.
    left: Expression,
    /// The operand written after the operator.
    right: Expression,
    /// Token carrying the operator's text; `set_operator` rewrites its content when the
    /// operator changes.
    token: Option<Token>,
}
247
248impl BinaryExpression {
249    /// Creates a new binary expression with the given operator and operands.
250    pub fn new<T: Into<Expression>, U: Into<Expression>>(
251        operator: BinaryOperator,
252        left: T,
253        right: U,
254    ) -> Self {
255        Self {
256            operator,
257            left: left.into(),
258            right: right.into(),
259            token: None,
260        }
261    }
262
263    /// Associates a token with this expression.
264    pub fn with_token(mut self, token: Token) -> Self {
265        self.token = Some(token);
266        self
267    }
268
269    /// Associates a token with this expression.
270    #[inline]
271    pub fn set_token(&mut self, token: Token) {
272        self.token = Some(token);
273    }
274
275    /// Returns the token associated with this expression, if any.
276    #[inline]
277    pub fn get_token(&self) -> Option<&Token> {
278        self.token.as_ref()
279    }
280
281    /// Returns a mutable reference to the left operand.
282    #[inline]
283    pub fn mutate_left(&mut self) -> &mut Expression {
284        &mut self.left
285    }
286
287    /// Returns a mutable reference to the right operand.
288    #[inline]
289    pub fn mutate_right(&mut self) -> &mut Expression {
290        &mut self.right
291    }
292
293    /// Returns a reference to the left operand.
294    #[inline]
295    pub fn left(&self) -> &Expression {
296        &self.left
297    }
298
299    /// Returns a reference to the right operand.
300    #[inline]
301    pub fn right(&self) -> &Expression {
302        &self.right
303    }
304
305    /// Returns the binary operator.
306    #[inline]
307    pub fn operator(&self) -> BinaryOperator {
308        self.operator
309    }
310
311    /// Changes the operator and updates the associated token's content if present.
312    #[inline]
313    pub fn set_operator(&mut self, operator: BinaryOperator) {
314        if self.operator == operator {
315            return;
316        }
317        self.operator = operator;
318        if let Some(token) = self.token.as_mut() {
319            token.replace_with_content(operator.to_str());
320        }
321    }
322
323    /// Returns a mutable reference to the last token for this binary expression.
324    pub fn mutate_last_token(&mut self) -> &mut Token {
325        self.right.mutate_last_token()
326    }
327
328    super::impl_token_fns!(iter = [token]);
329}
330
#[cfg(test)]
mod test {
    use super::*;

    /// Every `BinaryOperator` variant, used by the exhaustive property tests below.
    const ALL_OPERATORS: [BinaryOperator; 16] = [
        BinaryOperator::And,
        BinaryOperator::Or,
        BinaryOperator::Equal,
        BinaryOperator::NotEqual,
        BinaryOperator::LowerThan,
        BinaryOperator::LowerOrEqualThan,
        BinaryOperator::GreaterThan,
        BinaryOperator::GreaterOrEqualThan,
        BinaryOperator::Plus,
        BinaryOperator::Minus,
        BinaryOperator::Asterisk,
        BinaryOperator::Slash,
        BinaryOperator::DoubleSlash,
        BinaryOperator::Percent,
        BinaryOperator::Caret,
        BinaryOperator::Concat,
    ];

    mod precedence {
        use super::*;

        use BinaryOperator::*;

        #[test]
        fn caret() {
            assert!(Caret.precedes(And));
            assert!(Caret.precedes(Or));
            assert!(Caret.precedes(Equal));
            assert!(Caret.precedes(NotEqual));
            assert!(Caret.precedes(LowerThan));
            assert!(Caret.precedes(LowerOrEqualThan));
            assert!(Caret.precedes(GreaterThan));
            assert!(Caret.precedes(GreaterOrEqualThan));
            assert!(Caret.precedes(Plus));
            assert!(Caret.precedes(Minus));
            assert!(Caret.precedes(Asterisk));
            assert!(Caret.precedes(Slash));
            assert!(Caret.precedes(DoubleSlash));
            assert!(Caret.precedes(Percent));
            assert!(Caret.precedes(Concat));
            assert!(!Caret.precedes(Caret));
            assert!(Caret.precedes_unary_expression());
        }

        #[test]
        fn asterisk() {
            assert!(Asterisk.precedes(And));
            assert!(Asterisk.precedes(Or));
            assert!(Asterisk.precedes(Equal));
            assert!(Asterisk.precedes(NotEqual));
            assert!(Asterisk.precedes(LowerThan));
            assert!(Asterisk.precedes(LowerOrEqualThan));
            assert!(Asterisk.precedes(GreaterThan));
            assert!(Asterisk.precedes(GreaterOrEqualThan));
            assert!(Asterisk.precedes(Plus));
            assert!(Asterisk.precedes(Minus));
            assert!(!Asterisk.precedes(Asterisk));
            assert!(!Asterisk.precedes(Slash));
            assert!(!Asterisk.precedes(DoubleSlash));
            assert!(!Asterisk.precedes(Percent));
            assert!(Asterisk.precedes(Concat));
            assert!(!Asterisk.precedes(Caret));
            assert!(!Asterisk.precedes_unary_expression());
        }

        #[test]
        fn slash() {
            assert!(Slash.precedes(And));
            assert!(Slash.precedes(Or));
            assert!(Slash.precedes(Equal));
            assert!(Slash.precedes(NotEqual));
            assert!(Slash.precedes(LowerThan));
            assert!(Slash.precedes(LowerOrEqualThan));
            assert!(Slash.precedes(GreaterThan));
            assert!(Slash.precedes(GreaterOrEqualThan));
            assert!(Slash.precedes(Plus));
            assert!(Slash.precedes(Minus));
            assert!(!Slash.precedes(Asterisk));
            assert!(!Slash.precedes(Slash));
            assert!(!Slash.precedes(DoubleSlash));
            assert!(!Slash.precedes(Percent));
            assert!(Slash.precedes(Concat));
            assert!(!Slash.precedes(Caret));
            assert!(!Slash.precedes_unary_expression());
        }

        #[test]
        fn double_slash() {
            assert!(DoubleSlash.precedes(And));
            assert!(DoubleSlash.precedes(Or));
            assert!(DoubleSlash.precedes(Equal));
            assert!(DoubleSlash.precedes(NotEqual));
            assert!(DoubleSlash.precedes(LowerThan));
            assert!(DoubleSlash.precedes(LowerOrEqualThan));
            assert!(DoubleSlash.precedes(GreaterThan));
            assert!(DoubleSlash.precedes(GreaterOrEqualThan));
            assert!(DoubleSlash.precedes(Plus));
            assert!(DoubleSlash.precedes(Minus));
            assert!(!DoubleSlash.precedes(Asterisk));
            assert!(!DoubleSlash.precedes(Slash));
            assert!(!DoubleSlash.precedes(DoubleSlash));
            assert!(!DoubleSlash.precedes(Percent));
            assert!(DoubleSlash.precedes(Concat));
            assert!(!DoubleSlash.precedes(Caret));
            assert!(!DoubleSlash.precedes_unary_expression());
        }

        #[test]
        fn percent() {
            assert!(Percent.precedes(And));
            assert!(Percent.precedes(Or));
            assert!(Percent.precedes(Equal));
            assert!(Percent.precedes(NotEqual));
            assert!(Percent.precedes(LowerThan));
            assert!(Percent.precedes(LowerOrEqualThan));
            assert!(Percent.precedes(GreaterThan));
            assert!(Percent.precedes(GreaterOrEqualThan));
            assert!(Percent.precedes(Plus));
            assert!(Percent.precedes(Minus));
            assert!(!Percent.precedes(Asterisk));
            assert!(!Percent.precedes(Slash));
            assert!(!Percent.precedes(DoubleSlash));
            assert!(!Percent.precedes(Percent));
            assert!(Percent.precedes(Concat));
            assert!(!Percent.precedes(Caret));
            assert!(!Percent.precedes_unary_expression());
        }

        #[test]
        fn plus() {
            assert!(Plus.precedes(And));
            assert!(Plus.precedes(Or));
            assert!(Plus.precedes(Equal));
            assert!(Plus.precedes(NotEqual));
            assert!(Plus.precedes(LowerThan));
            assert!(Plus.precedes(LowerOrEqualThan));
            assert!(Plus.precedes(GreaterThan));
            assert!(Plus.precedes(GreaterOrEqualThan));
            assert!(!Plus.precedes(Plus));
            assert!(!Plus.precedes(Minus));
            assert!(!Plus.precedes(Asterisk));
            assert!(!Plus.precedes(Slash));
            assert!(!Plus.precedes(DoubleSlash));
            assert!(!Plus.precedes(Percent));
            assert!(Plus.precedes(Concat));
            assert!(!Plus.precedes(Caret));
            assert!(!Plus.precedes_unary_expression());
        }

        #[test]
        fn minus() {
            assert!(Minus.precedes(And));
            assert!(Minus.precedes(Or));
            assert!(Minus.precedes(Equal));
            assert!(Minus.precedes(NotEqual));
            assert!(Minus.precedes(LowerThan));
            assert!(Minus.precedes(LowerOrEqualThan));
            assert!(Minus.precedes(GreaterThan));
            assert!(Minus.precedes(GreaterOrEqualThan));
            assert!(!Minus.precedes(Plus));
            assert!(!Minus.precedes(Minus));
            assert!(!Minus.precedes(Asterisk));
            assert!(!Minus.precedes(Slash));
            assert!(!Minus.precedes(DoubleSlash));
            assert!(!Minus.precedes(Percent));
            assert!(Minus.precedes(Concat));
            assert!(!Minus.precedes(Caret));
            assert!(!Minus.precedes_unary_expression());
        }

        #[test]
        fn concat() {
            assert!(Concat.precedes(And));
            assert!(Concat.precedes(Or));
            assert!(Concat.precedes(Equal));
            assert!(Concat.precedes(NotEqual));
            assert!(Concat.precedes(LowerThan));
            assert!(Concat.precedes(LowerOrEqualThan));
            assert!(Concat.precedes(GreaterThan));
            assert!(Concat.precedes(GreaterOrEqualThan));
            assert!(!Concat.precedes(Plus));
            assert!(!Concat.precedes(Minus));
            assert!(!Concat.precedes(Asterisk));
            assert!(!Concat.precedes(Slash));
            assert!(!Concat.precedes(DoubleSlash));
            assert!(!Concat.precedes(Percent));
            assert!(!Concat.precedes(Concat));
            assert!(!Concat.precedes(Caret));
            assert!(!Concat.precedes_unary_expression());
        }

        #[test]
        fn comparison_operators() {
            let comparisons = [
                Equal,
                NotEqual,
                LowerThan,
                LowerOrEqualThan,
                GreaterThan,
                GreaterOrEqualThan,
            ];
            let not_preceded = [
                Equal,
                NotEqual,
                LowerThan,
                LowerOrEqualThan,
                GreaterThan,
                GreaterOrEqualThan,
                Concat,
                Plus,
                Minus,
                Asterisk,
                Slash,
                DoubleSlash,
                Percent,
                Caret,
            ];

            for operator in comparisons {
                assert!(operator.precedes(And), "{:?} should precede And", operator);
                assert!(operator.precedes(Or), "{:?} should precede Or", operator);
                for other in not_preceded {
                    assert!(
                        !operator.precedes(other),
                        "{:?} should not precede {:?}",
                        operator,
                        other
                    );
                }
                assert!(!operator.precedes_unary_expression());
            }
        }

        #[test]
        fn and() {
            assert!(!And.precedes(And));
            assert!(And.precedes(Or));
            assert!(!And.precedes(Equal));
            assert!(!And.precedes(NotEqual));
            assert!(!And.precedes(LowerThan));
            assert!(!And.precedes(LowerOrEqualThan));
            assert!(!And.precedes(GreaterThan));
            assert!(!And.precedes(GreaterOrEqualThan));
            assert!(!And.precedes(Plus));
            assert!(!And.precedes(Minus));
            assert!(!And.precedes(Asterisk));
            assert!(!And.precedes(Slash));
            assert!(!And.precedes(DoubleSlash));
            assert!(!And.precedes(Percent));
            assert!(!And.precedes(Concat));
            assert!(!And.precedes(Caret));
            assert!(!And.precedes_unary_expression());
        }

        #[test]
        fn or() {
            for other in ALL_OPERATORS {
                assert!(!Or.precedes(other), "Or should not precede {:?}", other);
            }
            assert!(!Or.precedes_unary_expression());
        }

        #[test]
        fn precedes_is_irreflexive_and_asymmetric() {
            for a in ALL_OPERATORS {
                assert!(!a.precedes(a), "{:?} should not precede itself", a);
                for b in ALL_OPERATORS {
                    assert!(
                        !(a.precedes(b) && b.precedes(a)),
                        "{:?} and {:?} cannot precede each other",
                        a,
                        b
                    );
                }
            }
        }

        #[test]
        fn only_caret_precedes_unary_expressions() {
            for operator in ALL_OPERATORS {
                assert_eq!(
                    operator.precedes_unary_expression(),
                    operator == Caret,
                    "{:?}",
                    operator
                );
            }
        }
    }

    mod associativity {
        use super::*;

        use BinaryOperator::*;

        #[test]
        fn caret_and_concat_are_right_associative() {
            for operator in [Caret, Concat] {
                assert!(operator.is_right_associative(), "{:?}", operator);
                assert!(!operator.is_left_associative(), "{:?}", operator);
            }
        }

        #[test]
        fn other_operators_are_left_associative() {
            for operator in ALL_OPERATORS {
                if matches!(operator, Caret | Concat) {
                    continue;
                }
                assert!(operator.is_left_associative(), "{:?}", operator);
                assert!(!operator.is_right_associative(), "{:?}", operator);
            }
        }

        #[test]
        fn left_and_right_associativity_are_exclusive() {
            for operator in ALL_OPERATORS {
                assert_ne!(
                    operator.is_left_associative(),
                    operator.is_right_associative(),
                    "{:?}",
                    operator
                );
            }
        }
    }

    mod string_representation {
        use super::*;

        use BinaryOperator::*;

        #[test]
        fn to_str_matches_operator_syntax() {
            assert_eq!(And.to_str(), "and");
            assert_eq!(Or.to_str(), "or");
            assert_eq!(Equal.to_str(), "==");
            assert_eq!(NotEqual.to_str(), "~=");
            assert_eq!(LowerThan.to_str(), "<");
            assert_eq!(LowerOrEqualThan.to_str(), "<=");
            assert_eq!(GreaterThan.to_str(), ">");
            assert_eq!(GreaterOrEqualThan.to_str(), ">=");
            assert_eq!(Plus.to_str(), "+");
            assert_eq!(Minus.to_str(), "-");
            assert_eq!(Asterisk.to_str(), "*");
            assert_eq!(Slash.to_str(), "/");
            assert_eq!(DoubleSlash.to_str(), "//");
            assert_eq!(Percent.to_str(), "%");
            assert_eq!(Caret.to_str(), "^");
            assert_eq!(Concat.to_str(), "..");
        }

        #[test]
        fn to_str_is_unique_per_operator() {
            for (index, a) in ALL_OPERATORS.iter().enumerate() {
                for b in &ALL_OPERATORS[index + 1..] {
                    assert_ne!(a.to_str(), b.to_str(), "{:?} and {:?}", a, b);
                }
            }
        }
    }
}