darklua_core/nodes/expressions/binary.rs

1use crate::nodes::{Expression, FunctionReturnType, Token, Type};
2
/// The operators that can join two operands in a [`BinaryExpression`].
///
/// Each variant's doc comment shows the source text produced for it by
/// `BinaryOperator::to_str`. Precedence and associativity rules are provided
/// by the methods in the `impl BinaryOperator` block of this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BinaryOperator {
    /// `and`
    And,
    /// `or`
    Or,
    /// `==`
    Equal,
    /// `~=`
    NotEqual,
    /// `<`
    LowerThan,
    /// `<=`
    LowerOrEqualThan,
    /// `>`
    GreaterThan,
    /// `>=`
    GreaterOrEqualThan,
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Asterisk,
    /// `/`
    Slash,
    /// `//`
    DoubleSlash,
    /// `%`
    Percent,
    /// `^`
    Caret,
    /// `..`
    Concat,
}
22
23#[inline]
24fn ends_with_if_expression(expression: &Expression) -> bool {
25    let mut current = expression;
26
27    loop {
28        match current {
29            Expression::If(_) => break true,
30            Expression::Binary(binary) => current = binary.right(),
31            Expression::Unary(unary) => current = unary.get_expression(),
32            Expression::Call(_)
33            | Expression::False(_)
34            | Expression::Field(_)
35            | Expression::Function(_)
36            | Expression::Identifier(_)
37            | Expression::Index(_)
38            | Expression::Nil(_)
39            | Expression::Number(_)
40            | Expression::Parenthese(_)
41            | Expression::String(_)
42            | Expression::InterpolatedString(_)
43            | Expression::Table(_)
44            | Expression::True(_)
45            | Expression::VariableArguments(_)
46            | Expression::TypeCast(_) => break false,
47        }
48    }
49}
50
51#[inline]
52fn ends_with_type_cast_to_type_name_without_type_parameters(expression: &Expression) -> bool {
53    let mut current = expression;
54
55    loop {
56        match current {
57            Expression::If(if_statement) => current = if_statement.get_else_result(),
58            Expression::Binary(binary) => current = binary.right(),
59            Expression::Unary(unary) => current = unary.get_expression(),
60            Expression::TypeCast(type_cast) => {
61                let mut current_type = type_cast.get_type();
62
63                break loop {
64                    match current_type {
65                        Type::Name(name) => break !name.has_type_parameters(),
66                        Type::Field(field) => break !field.get_type_name().has_type_parameters(),
67                        Type::Function(function) => {
68                            current_type = match function.get_return_type() {
69                                FunctionReturnType::Type(r#type) => r#type,
70                                FunctionReturnType::TypePack(_)
71                                | FunctionReturnType::GenericTypePack(_) => break false,
72                                FunctionReturnType::VariadicTypePack(variadic_type) => {
73                                    variadic_type.get_type()
74                                }
75                            }
76                        }
77                        Type::Intersection(intersection) => {
78                            current_type = intersection.last_type();
79                        }
80                        Type::Union(union_type) => {
81                            current_type = union_type.last_type();
82                        }
83                        Type::True(_)
84                        | Type::False(_)
85                        | Type::Nil(_)
86                        | Type::String(_)
87                        | Type::Array(_)
88                        | Type::Table(_)
89                        | Type::TypeOf(_)
90                        | Type::Parenthese(_)
91                        | Type::Optional(_) => break false,
92                    }
93                };
94            }
95            Expression::Call(_)
96            | Expression::False(_)
97            | Expression::Field(_)
98            | Expression::Function(_)
99            | Expression::Identifier(_)
100            | Expression::Index(_)
101            | Expression::Nil(_)
102            | Expression::Number(_)
103            | Expression::Parenthese(_)
104            | Expression::String(_)
105            | Expression::InterpolatedString(_)
106            | Expression::Table(_)
107            | Expression::True(_)
108            | Expression::VariableArguments(_) => break false,
109        }
110    }
111}
112
113impl BinaryOperator {
114    #[inline]
115    pub fn precedes(&self, other: Self) -> bool {
116        self.get_precedence() > other.get_precedence()
117    }
118
119    #[inline]
120    pub fn precedes_unary_expression(&self) -> bool {
121        matches!(self, Self::Caret)
122    }
123
124    #[inline]
125    pub fn is_left_associative(&self) -> bool {
126        !matches!(self, Self::Caret | Self::Concat)
127    }
128
129    #[inline]
130    pub fn is_right_associative(&self) -> bool {
131        matches!(self, Self::Caret | Self::Concat)
132    }
133
134    pub fn left_needs_parentheses(&self, left: &Expression) -> bool {
135        let needs_parentheses = match left {
136            Expression::Binary(left) => {
137                if self.is_left_associative() {
138                    self.precedes(left.operator())
139                } else {
140                    !left.operator().precedes(*self)
141                }
142            }
143            Expression::Unary(_) => self.precedes_unary_expression(),
144            Expression::If(_) => true,
145            _ => false,
146        };
147        needs_parentheses
148            || ends_with_if_expression(left)
149            || (matches!(self, BinaryOperator::LowerThan)
150                && ends_with_type_cast_to_type_name_without_type_parameters(left))
151    }
152
153    pub fn right_needs_parentheses(&self, right: &Expression) -> bool {
154        match right {
155            Expression::Binary(right) => {
156                if self.is_right_associative() {
157                    self.precedes(right.operator())
158                } else {
159                    !right.operator().precedes(*self)
160                }
161            }
162            Expression::Unary(_) => false,
163            _ => false,
164        }
165    }
166
167    pub fn to_str(&self) -> &'static str {
168        match self {
169            Self::And => "and",
170            Self::Or => "or",
171            Self::Equal => "==",
172            Self::NotEqual => "~=",
173            Self::LowerThan => "<",
174            Self::LowerOrEqualThan => "<=",
175            Self::GreaterThan => ">",
176            Self::GreaterOrEqualThan => ">=",
177            Self::Plus => "+",
178            Self::Minus => "-",
179            Self::Asterisk => "*",
180            Self::Slash => "/",
181            Self::DoubleSlash => "//",
182            Self::Percent => "%",
183            Self::Caret => "^",
184            Self::Concat => "..",
185        }
186    }
187
188    fn get_precedence(&self) -> u8 {
189        match self {
190            Self::Or => 0,
191            Self::And => 1,
192            Self::Equal
193            | Self::NotEqual
194            | Self::LowerThan
195            | Self::LowerOrEqualThan
196            | Self::GreaterThan
197            | Self::GreaterOrEqualThan => 2,
198            Self::Concat => 3,
199            Self::Plus | Self::Minus => 4,
200            Self::Asterisk | Self::Slash | Self::DoubleSlash | Self::Percent => 5,
201            Self::Caret => 7,
202        }
203    }
204}
205
206#[derive(Clone, Debug, PartialEq, Eq)]
207pub struct BinaryExpression {
208    operator: BinaryOperator,
209    left: Expression,
210    right: Expression,
211    token: Option<Token>,
212}
213
214impl BinaryExpression {
215    pub fn new<T: Into<Expression>, U: Into<Expression>>(
216        operator: BinaryOperator,
217        left: T,
218        right: U,
219    ) -> Self {
220        Self {
221            operator,
222            left: left.into(),
223            right: right.into(),
224            token: None,
225        }
226    }
227
228    pub fn with_token(mut self, token: Token) -> Self {
229        self.token = Some(token);
230        self
231    }
232
233    #[inline]
234    pub fn set_token(&mut self, token: Token) {
235        self.token = Some(token);
236    }
237
238    #[inline]
239    pub fn get_token(&self) -> Option<&Token> {
240        self.token.as_ref()
241    }
242
243    #[inline]
244    pub fn mutate_left(&mut self) -> &mut Expression {
245        &mut self.left
246    }
247
248    #[inline]
249    pub fn mutate_right(&mut self) -> &mut Expression {
250        &mut self.right
251    }
252
253    #[inline]
254    pub fn left(&self) -> &Expression {
255        &self.left
256    }
257
258    #[inline]
259    pub fn right(&self) -> &Expression {
260        &self.right
261    }
262
263    #[inline]
264    pub fn operator(&self) -> BinaryOperator {
265        self.operator
266    }
267
268    #[inline]
269    pub fn set_operator(&mut self, operator: BinaryOperator) {
270        if self.operator == operator {
271            return;
272        }
273        self.operator = operator;
274        if let Some(token) = self.token.as_mut() {
275            token.replace_with_content(operator.to_str());
276        }
277    }
278
279    super::impl_token_fns!(iter = [token]);
280}
281
#[cfg(test)]
mod test {
    use super::*;

    mod precedence {
        use super::*;

        use BinaryOperator::*;

        /// Every operator, so a whole row of the precedence table can be checked at once.
        const ALL: [BinaryOperator; 16] = [
            And,
            Or,
            Equal,
            NotEqual,
            LowerThan,
            LowerOrEqualThan,
            GreaterThan,
            GreaterOrEqualThan,
            Plus,
            Minus,
            Asterisk,
            Slash,
            DoubleSlash,
            Percent,
            Caret,
            Concat,
        ];

        #[test]
        fn caret() {
            assert!(Caret.precedes(And));
            assert!(Caret.precedes(Or));
            assert!(Caret.precedes(Equal));
            assert!(Caret.precedes(NotEqual));
            assert!(Caret.precedes(LowerThan));
            assert!(Caret.precedes(LowerOrEqualThan));
            assert!(Caret.precedes(GreaterThan));
            assert!(Caret.precedes(GreaterOrEqualThan));
            assert!(Caret.precedes(Plus));
            assert!(Caret.precedes(Minus));
            assert!(Caret.precedes(Asterisk));
            assert!(Caret.precedes(Slash));
            assert!(Caret.precedes(DoubleSlash));
            assert!(Caret.precedes(Percent));
            assert!(Caret.precedes(Concat));
            assert!(!Caret.precedes(Caret));
            assert!(Caret.precedes_unary_expression());
        }

        #[test]
        fn asterisk() {
            assert!(Asterisk.precedes(And));
            assert!(Asterisk.precedes(Or));
            assert!(Asterisk.precedes(Equal));
            assert!(Asterisk.precedes(NotEqual));
            assert!(Asterisk.precedes(LowerThan));
            assert!(Asterisk.precedes(LowerOrEqualThan));
            assert!(Asterisk.precedes(GreaterThan));
            assert!(Asterisk.precedes(GreaterOrEqualThan));
            assert!(Asterisk.precedes(Plus));
            assert!(Asterisk.precedes(Minus));
            assert!(!Asterisk.precedes(Asterisk));
            assert!(!Asterisk.precedes(Slash));
            assert!(!Asterisk.precedes(DoubleSlash));
            assert!(!Asterisk.precedes(Percent));
            assert!(Asterisk.precedes(Concat));
            assert!(!Asterisk.precedes(Caret));
            assert!(!Asterisk.precedes_unary_expression());
        }

        #[test]
        fn slash() {
            assert!(Slash.precedes(And));
            assert!(Slash.precedes(Or));
            assert!(Slash.precedes(Equal));
            assert!(Slash.precedes(NotEqual));
            assert!(Slash.precedes(LowerThan));
            assert!(Slash.precedes(LowerOrEqualThan));
            assert!(Slash.precedes(GreaterThan));
            assert!(Slash.precedes(GreaterOrEqualThan));
            assert!(Slash.precedes(Plus));
            assert!(Slash.precedes(Minus));
            assert!(!Slash.precedes(Asterisk));
            assert!(!Slash.precedes(Slash));
            assert!(!Slash.precedes(DoubleSlash));
            assert!(!Slash.precedes(Percent));
            assert!(Slash.precedes(Concat));
            assert!(!Slash.precedes(Caret));
            assert!(!Slash.precedes_unary_expression());
        }

        #[test]
        fn double_slash() {
            assert!(DoubleSlash.precedes(And));
            assert!(DoubleSlash.precedes(Or));
            assert!(DoubleSlash.precedes(Equal));
            assert!(DoubleSlash.precedes(NotEqual));
            assert!(DoubleSlash.precedes(LowerThan));
            assert!(DoubleSlash.precedes(LowerOrEqualThan));
            assert!(DoubleSlash.precedes(GreaterThan));
            assert!(DoubleSlash.precedes(GreaterOrEqualThan));
            assert!(DoubleSlash.precedes(Plus));
            assert!(DoubleSlash.precedes(Minus));
            assert!(!DoubleSlash.precedes(Asterisk));
            assert!(!DoubleSlash.precedes(Slash));
            assert!(!DoubleSlash.precedes(DoubleSlash));
            assert!(!DoubleSlash.precedes(Percent));
            assert!(DoubleSlash.precedes(Concat));
            assert!(!DoubleSlash.precedes(Caret));
            assert!(!DoubleSlash.precedes_unary_expression());
        }

        #[test]
        fn percent() {
            assert!(Percent.precedes(And));
            assert!(Percent.precedes(Or));
            assert!(Percent.precedes(Equal));
            assert!(Percent.precedes(NotEqual));
            assert!(Percent.precedes(LowerThan));
            assert!(Percent.precedes(LowerOrEqualThan));
            assert!(Percent.precedes(GreaterThan));
            assert!(Percent.precedes(GreaterOrEqualThan));
            assert!(Percent.precedes(Plus));
            assert!(Percent.precedes(Minus));
            assert!(!Percent.precedes(Asterisk));
            assert!(!Percent.precedes(Slash));
            assert!(!Percent.precedes(DoubleSlash));
            assert!(!Percent.precedes(Percent));
            assert!(Percent.precedes(Concat));
            assert!(!Percent.precedes(Caret));
            assert!(!Percent.precedes_unary_expression());
        }

        #[test]
        fn plus() {
            assert!(Plus.precedes(And));
            assert!(Plus.precedes(Or));
            assert!(Plus.precedes(Equal));
            assert!(Plus.precedes(NotEqual));
            assert!(Plus.precedes(LowerThan));
            assert!(Plus.precedes(LowerOrEqualThan));
            assert!(Plus.precedes(GreaterThan));
            assert!(Plus.precedes(GreaterOrEqualThan));
            assert!(!Plus.precedes(Plus));
            assert!(!Plus.precedes(Minus));
            assert!(!Plus.precedes(Asterisk));
            assert!(!Plus.precedes(Slash));
            assert!(!Plus.precedes(DoubleSlash));
            assert!(!Plus.precedes(Percent));
            assert!(Plus.precedes(Concat));
            assert!(!Plus.precedes(Caret));
            assert!(!Plus.precedes_unary_expression());
        }

        #[test]
        fn minus() {
            assert!(Minus.precedes(And));
            assert!(Minus.precedes(Or));
            assert!(Minus.precedes(Equal));
            assert!(Minus.precedes(NotEqual));
            assert!(Minus.precedes(LowerThan));
            assert!(Minus.precedes(LowerOrEqualThan));
            assert!(Minus.precedes(GreaterThan));
            assert!(Minus.precedes(GreaterOrEqualThan));
            assert!(!Minus.precedes(Plus));
            assert!(!Minus.precedes(Minus));
            assert!(!Minus.precedes(Asterisk));
            assert!(!Minus.precedes(Slash));
            assert!(!Minus.precedes(DoubleSlash));
            assert!(!Minus.precedes(Percent));
            assert!(Minus.precedes(Concat));
            assert!(!Minus.precedes(Caret));
            assert!(!Minus.precedes_unary_expression());
        }

        #[test]
        fn concat() {
            assert!(Concat.precedes(And));
            assert!(Concat.precedes(Or));
            assert!(Concat.precedes(Equal));
            assert!(Concat.precedes(NotEqual));
            assert!(Concat.precedes(LowerThan));
            assert!(Concat.precedes(LowerOrEqualThan));
            assert!(Concat.precedes(GreaterThan));
            assert!(Concat.precedes(GreaterOrEqualThan));
            assert!(!Concat.precedes(Plus));
            assert!(!Concat.precedes(Minus));
            assert!(!Concat.precedes(Asterisk));
            assert!(!Concat.precedes(Slash));
            assert!(!Concat.precedes(DoubleSlash));
            assert!(!Concat.precedes(Percent));
            assert!(!Concat.precedes(Concat));
            assert!(!Concat.precedes(Caret));
            assert!(!Concat.precedes_unary_expression());
        }

        #[test]
        fn comparisons() {
            let comparisons = [
                Equal,
                NotEqual,
                LowerThan,
                LowerOrEqualThan,
                GreaterThan,
                GreaterOrEqualThan,
            ];
            for operator in comparisons.iter().copied() {
                assert!(operator.precedes(And));
                assert!(operator.precedes(Or));
                for other in ALL.iter().copied().filter(|&other| !matches!(other, And | Or)) {
                    assert!(
                        !operator.precedes(other),
                        "{:?} should not precede {:?}",
                        operator,
                        other
                    );
                }
                assert!(!operator.precedes_unary_expression());
            }
        }

        #[test]
        fn and() {
            assert!(!And.precedes(And));
            assert!(And.precedes(Or));
            assert!(!And.precedes(Equal));
            assert!(!And.precedes(NotEqual));
            assert!(!And.precedes(LowerThan));
            assert!(!And.precedes(LowerOrEqualThan));
            assert!(!And.precedes(GreaterThan));
            assert!(!And.precedes(GreaterOrEqualThan));
            assert!(!And.precedes(Plus));
            assert!(!And.precedes(Minus));
            assert!(!And.precedes(Asterisk));
            assert!(!And.precedes(Slash));
            assert!(!And.precedes(DoubleSlash));
            assert!(!And.precedes(Percent));
            assert!(!And.precedes(Concat));
            assert!(!And.precedes(Caret));
            assert!(!And.precedes_unary_expression());
        }

        #[test]
        fn or() {
            // `or` is the loosest operator: it precedes nothing.
            for other in ALL.iter().copied() {
                assert!(!Or.precedes(other), "or should not precede {:?}", other);
            }
            assert!(!Or.precedes_unary_expression());
        }
    }

    mod associativity {
        use super::*;

        use BinaryOperator::*;

        #[test]
        fn caret_and_concat_are_right_associative() {
            for operator in [Caret, Concat].iter().copied() {
                assert!(operator.is_right_associative());
                assert!(!operator.is_left_associative());
            }
        }

        #[test]
        fn other_operators_are_left_associative() {
            let operators = [
                And,
                Or,
                Equal,
                NotEqual,
                LowerThan,
                LowerOrEqualThan,
                GreaterThan,
                GreaterOrEqualThan,
                Plus,
                Minus,
                Asterisk,
                Slash,
                DoubleSlash,
                Percent,
            ];
            for operator in operators.iter().copied() {
                assert!(operator.is_left_associative());
                assert!(!operator.is_right_associative());
            }
        }
    }

    mod representation {
        use super::*;

        use BinaryOperator::*;

        #[test]
        fn to_str_matches_source_syntax() {
            assert_eq!(And.to_str(), "and");
            assert_eq!(Or.to_str(), "or");
            assert_eq!(Equal.to_str(), "==");
            assert_eq!(NotEqual.to_str(), "~=");
            assert_eq!(LowerThan.to_str(), "<");
            assert_eq!(LowerOrEqualThan.to_str(), "<=");
            assert_eq!(GreaterThan.to_str(), ">");
            assert_eq!(GreaterOrEqualThan.to_str(), ">=");
            assert_eq!(Plus.to_str(), "+");
            assert_eq!(Minus.to_str(), "-");
            assert_eq!(Asterisk.to_str(), "*");
            assert_eq!(Slash.to_str(), "/");
            assert_eq!(DoubleSlash.to_str(), "//");
            assert_eq!(Percent.to_str(), "%");
            assert_eq!(Caret.to_str(), "^");
            assert_eq!(Concat.to_str(), "..");
        }
    }
}