1use std::{cmp::Ordering, fmt::Display};
8
9use crate::compiler::grammar::expr::parser::ID_EXTERNAL;
10use crate::Event;
11use crate::{compiler::Number, runtime::Variable, Context};
12
13use crate::compiler::grammar::expr::{BinaryOperator, Constant, Expression, UnaryOperator};
14
15impl Context<'_> {
16    pub(crate) fn eval_expression(&mut self, expr: &[Expression]) -> Result<Variable, Event> {
17        let mut exprs = expr.iter().skip(self.expr_pos);
18        while let Some(expr) = exprs.next() {
19            self.expr_pos += 1;
20            match expr {
21                Expression::Variable(v) => {
22                    self.expr_stack.push(self.variable(v).unwrap_or_default());
23                }
24                Expression::Constant(val) => {
25                    self.expr_stack.push(Variable::from(val));
26                }
27                Expression::UnaryOperator(op) => {
28                    let value = self.expr_stack.pop().unwrap_or_default();
29                    self.expr_stack.push(match op {
30                        UnaryOperator::Not => value.op_not(),
31                        UnaryOperator::Minus => value.op_minus(),
32                    });
33                }
34                Expression::BinaryOperator(op) => {
35                    let right = self.expr_stack.pop().unwrap_or_default();
36                    let left = self.expr_stack.pop().unwrap_or_default();
37                    self.expr_stack.push(match op {
38                        BinaryOperator::Add => left.op_add(right),
39                        BinaryOperator::Subtract => left.op_subtract(right),
40                        BinaryOperator::Multiply => left.op_multiply(right),
41                        BinaryOperator::Divide => left.op_divide(right),
42                        BinaryOperator::And => left.op_and(right),
43                        BinaryOperator::Or => left.op_or(right),
44                        BinaryOperator::Xor => left.op_xor(right),
45                        BinaryOperator::Eq => left.op_eq(right),
46                        BinaryOperator::Ne => left.op_ne(right),
47                        BinaryOperator::Lt => left.op_lt(right),
48                        BinaryOperator::Le => left.op_le(right),
49                        BinaryOperator::Gt => left.op_gt(right),
50                        BinaryOperator::Ge => left.op_ge(right),
51                    });
52                }
53                Expression::Function { id, num_args } => {
54                    let num_args = *num_args as usize;
55
56                    if let Some(fnc) = self.runtime.functions.get(*id as usize) {
57                        let mut arguments = vec![Variable::Integer(0); num_args];
58                        for arg_num in 0..num_args {
59                            arguments[num_args - arg_num - 1] =
60                                self.expr_stack.pop().unwrap_or_default();
61                        }
62                        self.expr_stack.push((fnc)(self, arguments));
63                    } else {
64                        let mut arguments = vec![Variable::Integer(0); num_args];
65                        for arg_num in 0..num_args {
66                            arguments[num_args - arg_num - 1] =
67                                self.expr_stack.pop().unwrap_or_default();
68                        }
69                        self.pos -= 1; return Err(Event::Function {
71                            id: ID_EXTERNAL - *id,
72                            arguments,
73                        });
74                    }
75                }
76                Expression::JmpIf { val, pos } => {
77                    if self.expr_stack.last().is_some_and(|v| v.to_bool()) == *val {
78                        self.expr_pos += *pos as usize;
79                        for _ in 0..*pos {
80                            exprs.next();
81                        }
82                    }
83                }
84                Expression::ArrayAccess => {
85                    let index = self.expr_stack.pop().unwrap_or_default().to_usize();
86                    let array = self.expr_stack.pop().unwrap_or_default().into_array();
87                    self.expr_stack
88                        .push(array.get(index).cloned().unwrap_or_default());
89                }
90                Expression::ArrayBuild(num_items) => {
91                    let num_items = *num_items as usize;
92                    let mut items = vec![Variable::Integer(0); num_items];
93                    for arg_num in 0..num_items {
94                        items[num_items - arg_num - 1] = self.expr_stack.pop().unwrap_or_default();
95                    }
96                    self.expr_stack.push(Variable::Array(items.into()));
97                }
98            }
99        }
100
101        let result = self.expr_stack.pop().unwrap_or_default();
102        self.expr_stack.clear();
103        self.expr_pos = 0;
104        Ok(result)
105    }
106}
107
108impl Variable {
109    pub fn op_add(self, other: Variable) -> Variable {
110        match (self, other) {
111            (Variable::Integer(a), Variable::Integer(b)) => Variable::Integer(a.saturating_add(b)),
112            (Variable::Float(a), Variable::Float(b)) => Variable::Float(a + b),
113            (Variable::Integer(i), Variable::Float(f))
114            | (Variable::Float(f), Variable::Integer(i)) => Variable::Float(i as f64 + f),
115            (Variable::Array(a), Variable::Array(b)) => {
116                Variable::Array(a.iter().chain(b.iter()).cloned().collect::<Vec<_>>().into())
117            }
118            (Variable::Array(a), b) => a.iter().cloned().chain([b]).collect::<Vec<_>>().into(),
119            (a, Variable::Array(b)) => [a]
120                .into_iter()
121                .chain(b.iter().cloned())
122                .collect::<Vec<_>>()
123                .into(),
124            (Variable::String(a), b) => {
125                if !a.is_empty() {
126                    Variable::String(format!("{}{}", a, b).into())
127                } else {
128                    b
129                }
130            }
131            (a, Variable::String(b)) => {
132                if !b.is_empty() {
133                    Variable::String(format!("{}{}", a, b).into())
134                } else {
135                    a
136                }
137            }
138        }
139    }
140
141    pub fn op_subtract(self, other: Variable) -> Variable {
142        match (self, other) {
143            (Variable::Integer(a), Variable::Integer(b)) => Variable::Integer(a.saturating_sub(b)),
144            (Variable::Float(a), Variable::Float(b)) => Variable::Float(a - b),
145            (Variable::Integer(a), Variable::Float(b)) => Variable::Float(a as f64 - b),
146            (Variable::Float(a), Variable::Integer(b)) => Variable::Float(a - b as f64),
147            (Variable::Array(a), b) | (b, Variable::Array(a)) => Variable::Array(
148                a.iter()
149                    .filter(|v| *v != &b)
150                    .cloned()
151                    .collect::<Vec<_>>()
152                    .into(),
153            ),
154            (a, b) => a.parse_number().op_subtract(b.parse_number()),
155        }
156    }
157
158    pub fn op_multiply(self, other: Variable) -> Variable {
159        match (self, other) {
160            (Variable::Integer(a), Variable::Integer(b)) => Variable::Integer(a.saturating_mul(b)),
161            (Variable::Float(a), Variable::Float(b)) => Variable::Float(a * b),
162            (Variable::Integer(i), Variable::Float(f))
163            | (Variable::Float(f), Variable::Integer(i)) => Variable::Float(i as f64 * f),
164            (a, b) => a.parse_number().op_multiply(b.parse_number()),
165        }
166    }
167
168    pub fn op_divide(self, other: Variable) -> Variable {
169        match (self, other) {
170            (Variable::Integer(a), Variable::Integer(b)) => {
171                Variable::Float(if b != 0 { a as f64 / b as f64 } else { 0.0 })
172            }
173            (Variable::Float(a), Variable::Float(b)) => {
174                Variable::Float(if b != 0.0 { a / b } else { 0.0 })
175            }
176            (Variable::Integer(a), Variable::Float(b)) => {
177                Variable::Float(if b != 0.0 { a as f64 / b } else { 0.0 })
178            }
179            (Variable::Float(a), Variable::Integer(b)) => {
180                Variable::Float(if b != 0 { a / b as f64 } else { 0.0 })
181            }
182            (a, b) => a.parse_number().op_divide(b.parse_number()),
183        }
184    }
185
186    pub fn op_and(self, other: Variable) -> Variable {
187        Variable::Integer(i64::from(self.to_bool() & other.to_bool()))
188    }
189
190    pub fn op_or(self, other: Variable) -> Variable {
191        Variable::Integer(i64::from(self.to_bool() | other.to_bool()))
192    }
193
194    pub fn op_xor(self, other: Variable) -> Variable {
195        Variable::Integer(i64::from(self.to_bool() ^ other.to_bool()))
196    }
197
198    pub fn op_eq(self, other: Variable) -> Variable {
199        Variable::Integer(i64::from(self == other))
200    }
201
202    pub fn op_ne(self, other: Variable) -> Variable {
203        Variable::Integer(i64::from(self != other))
204    }
205
206    pub fn op_lt(self, other: Variable) -> Variable {
207        Variable::Integer(i64::from(self < other))
208    }
209
210    pub fn op_le(self, other: Variable) -> Variable {
211        Variable::Integer(i64::from(self <= other))
212    }
213
214    pub fn op_gt(self, other: Variable) -> Variable {
215        Variable::Integer(i64::from(self > other))
216    }
217
218    pub fn op_ge(self, other: Variable) -> Variable {
219        Variable::Integer(i64::from(self >= other))
220    }
221
222    pub fn op_not(self) -> Variable {
223        Variable::Integer(i64::from(!self.to_bool()))
224    }
225
226    pub fn op_minus(self) -> Variable {
227        match self {
228            Variable::Integer(n) => Variable::Integer(-n),
229            Variable::Float(n) => Variable::Float(-n),
230            _ => self.parse_number().op_minus(),
231        }
232    }
233
234    pub fn parse_number(&self) -> Variable {
235        match self {
236            Variable::String(s) if !s.is_empty() => {
237                if let Ok(n) = s.parse::<i64>() {
238                    Variable::Integer(n)
239                } else if let Ok(n) = s.parse::<f64>() {
240                    Variable::Float(n)
241                } else {
242                    Variable::Integer(0)
243                }
244            }
245            Variable::Integer(n) => Variable::Integer(*n),
246            Variable::Float(n) => Variable::Float(*n),
247            Variable::Array(l) => Variable::Integer(l.is_empty() as i64),
248            _ => Variable::Integer(0),
249        }
250    }
251
252    pub fn to_bool(&self) -> bool {
253        match self {
254            Variable::Float(f) => *f != 0.0,
255            Variable::Integer(n) => *n != 0,
256            Variable::String(s) => !s.is_empty(),
257            Variable::Array(a) => !a.is_empty(),
258        }
259    }
260}
261
262impl PartialEq for Variable {
263    fn eq(&self, other: &Self) -> bool {
264        match (self, other) {
265            (Self::Integer(a), Self::Integer(b)) => a == b,
266            (Self::Float(a), Self::Float(b)) => a == b,
267            (Self::Integer(a), Self::Float(b)) | (Self::Float(b), Self::Integer(a)) => {
268                *a as f64 == *b
269            }
270            (Self::String(a), Self::String(b)) => a == b,
271            (Self::String(_), Self::Integer(_) | Self::Float(_)) => &self.parse_number() == other,
272            (Self::Integer(_) | Self::Float(_), Self::String(_)) => self == &other.parse_number(),
273            (Self::Array(a), Self::Array(b)) => a == b,
274            _ => false,
275        }
276    }
277}
278
279impl Eq for Variable {}
280
281#[allow(clippy::non_canonical_partial_ord_impl)]
282impl PartialOrd for Variable {
283    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
284        match (self, other) {
285            (Self::Integer(a), Self::Integer(b)) => a.partial_cmp(b),
286            (Self::Float(a), Self::Float(b)) => a.partial_cmp(b),
287            (Self::Integer(a), Self::Float(b)) => (*a as f64).partial_cmp(b),
288            (Self::Float(a), Self::Integer(b)) => a.partial_cmp(&(*b as f64)),
289            (Self::String(a), Self::String(b)) => a.partial_cmp(b),
290            (Self::String(_), Self::Integer(_) | Self::Float(_)) => {
291                self.parse_number().partial_cmp(other)
292            }
293            (Self::Integer(_) | Self::Float(_), Self::String(_)) => {
294                self.partial_cmp(&other.parse_number())
295            }
296            (Self::Array(a), Self::Array(b)) => a.partial_cmp(b),
297            (Self::Array(_) | Self::String(_), _) => Ordering::Greater.into(),
298            (_, Self::Array(_)) => Ordering::Less.into(),
299        }
300    }
301}
302
303impl Ord for Variable {
304    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
305        self.partial_cmp(other).unwrap_or(Ordering::Greater)
306    }
307}
308
309impl Display for Variable {
310    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311        match self {
312            Variable::String(v) => v.fmt(f),
313            Variable::Integer(v) => v.fmt(f),
314            Variable::Float(v) => v.fmt(f),
315            Variable::Array(v) => {
316                for (i, v) in v.iter().enumerate() {
317                    if i > 0 {
318                        f.write_str("\n")?;
319                    }
320                    v.fmt(f)?;
321                }
322                Ok(())
323            }
324        }
325    }
326}
327
328impl Number {
329    pub fn is_non_zero(&self) -> bool {
330        match self {
331            Number::Integer(n) => *n != 0,
332            Number::Float(n) => *n != 0.0,
333        }
334    }
335}
336
337impl Default for Number {
338    fn default() -> Self {
339        Number::Integer(0)
340    }
341}
342
343impl From<bool> for Number {
344    #[inline(always)]
345    fn from(b: bool) -> Self {
346        Number::Integer(i64::from(b))
347    }
348}
349
350impl From<i64> for Number {
351    #[inline(always)]
352    fn from(n: i64) -> Self {
353        Number::Integer(n)
354    }
355}
356
357impl From<f64> for Number {
358    #[inline(always)]
359    fn from(n: f64) -> Self {
360        Number::Float(n)
361    }
362}
363
364impl From<i32> for Number {
365    #[inline(always)]
366    fn from(n: i32) -> Self {
367        Number::Integer(n as i64)
368    }
369}
370
371impl<'x> From<&'x Constant> for Variable {
372    fn from(value: &'x Constant) -> Self {
373        match value {
374            Constant::Integer(i) => Variable::Integer(*i),
375            Constant::Float(f) => Variable::Float(*f),
376            Constant::String(s) => Variable::String(s.clone()),
377        }
378    }
379}
380
#[cfg(test)]
mod test {
    // Tests for the expression operators: compiled expressions are run through a
    // small stand-alone evaluator and the results are cross-checked against the
    // `evalexpr` crate.
    use ahash::{HashMap, HashMapExt};

    use crate::{
        compiler::{
            grammar::expr::{
                parser::ExpressionParser, tokenizer::Tokenizer, BinaryOperator, Expression, Token,
                UnaryOperator,
            },
            VariableType,
        },
        runtime::Variable,
    };

    // `evalexpr` supplies the reference results (`eval`, `Value`).
    use evalexpr::*;

    /// Test-only evaluator for a compiled expression.
    pub trait EvalExpression {
        /// Evaluates the expression with `variables` bound as global variables.
        ///
        /// Returns `None` if a referenced variable is missing, an instruction
        /// finds too few values on the stack, or nothing is left on the stack.
        fn eval(&self, variables: &HashMap<String, Variable>) -> Option<Variable>;
    }

    impl EvalExpression for Vec<Expression> {
        fn eval(&self, variables: &HashMap<String, Variable>) -> Option<Variable> {
            let mut stack = Vec::with_capacity(self.len());
            let mut exprs = self.iter();

            while let Some(expr) = exprs.next() {
                match expr {
                    Expression::Variable(VariableType::Global(v)) => {
                        stack.push(variables.get(v)?.clone());
                    }
                    Expression::Constant(val) => {
                        stack.push(Variable::from(val));
                    }
                    Expression::UnaryOperator(op) => {
                        let value = stack.pop()?;
                        stack.push(match op {
                            UnaryOperator::Not => value.op_not(),
                            UnaryOperator::Minus => value.op_minus(),
                        });
                    }
                    Expression::BinaryOperator(op) => {
                        // The right operand was pushed last, so it is popped first.
                        let right = stack.pop()?;
                        let left = stack.pop()?;
                        stack.push(match op {
                            BinaryOperator::Add => left.op_add(right),
                            BinaryOperator::Subtract => left.op_subtract(right),
                            BinaryOperator::Multiply => left.op_multiply(right),
                            BinaryOperator::Divide => left.op_divide(right),
                            BinaryOperator::And => left.op_and(right),
                            BinaryOperator::Or => left.op_or(right),
                            BinaryOperator::Xor => left.op_xor(right),
                            BinaryOperator::Eq => left.op_eq(right),
                            BinaryOperator::Ne => left.op_ne(right),
                            BinaryOperator::Lt => left.op_lt(right),
                            BinaryOperator::Le => left.op_le(right),
                            BinaryOperator::Gt => left.op_gt(right),
                            BinaryOperator::Ge => left.op_ge(right),
                        });
                    }
                    Expression::JmpIf { val, pos } => {
                        // Conditional jump: when the truthiness of the top value
                        // equals `val`, skip the next `pos` instructions. The value
                        // itself stays on the stack.
                        if stack.last()?.to_bool() == *val {
                            for _ in 0..*pos {
                                exprs.next();
                            }
                        }
                    }
                    // Functions, array instructions and non-global variables are
                    // not handled by this test evaluator.
                    _ => unreachable!("Invalid expression"),
                }
            }
            stack.pop()
        }
    }

    /// Cross-checks arithmetic, comparison and logical expressions against
    /// textual substitution and against `evalexpr`.
    #[test]
    fn eval_expression() {
        let mut variables = HashMap::from_iter([
            ("A".to_string(), Variable::Integer(0)),
            ("B".to_string(), Variable::Integer(0)),
            ("C".to_string(), Variable::Integer(0)),
            ("D".to_string(), Variable::Integer(0)),
            ("E".to_string(), Variable::Integer(0)),
            ("F".to_string(), Variable::Integer(0)),
            ("G".to_string(), Variable::Integer(0)),
            ("H".to_string(), Variable::Integer(0)),
            ("I".to_string(), Variable::Integer(0)),
            ("J".to_string(), Variable::Integer(0)),
        ]);
        let num_vars = variables.len();

        for expr in [
            "A + B",
            "A * B",
            "A / B",
            "A - B",
            "-A",
            "A == B",
            "A != B",
            "A > B",
            "A < B",
            "A >= B",
            "A <= B",
            "A + B * C - D / E",
            "A + B + C - D - E",
            "(A + B) * (C - D) / E",
            "A - B + C * D / E * F - G",
            "A + B * C - D / E",
            "(A + B) * (C - D) / E",
            "A - B + C / D * E",
            "(A + B) / (C - D) + E",
            "A * (B + C) - D / E",
            "A / (B - C + D) * E",
            "(A + B) * C - D / (E + F)",
            "A * B - C + D / E",
            "A + B - C * D / E",
            "(A * B + C) / D - E",
            "A - B / C + D * E",
            "A + B * (C - D) / E",
            "A * B / C + (D - E)",
            "(A - B) * C / D + E",
            "A * (B / C) - D + E",
            "(A + B) / (C + D) * E",
            "A - B * C / D + E",
            "A + (B - C) * D / E",
            "(A + B) * (C / D) - E",
            "A - B / (C * D) + E",
            "(A + B) > (C - D) && E <= F",
            "A * B == C / D || E - F != G + H",
            "A / B >= C * D && E + F < G - H",
            "(A * B - C) != (D / E + F) && G > H",
            "A - B < C && D + E >= F * G",
            "(A * B) > C && (D / E) < F || G == H",
            "(A + B) <= (C - D) || E > F && G != H",
            "A * B != C + D || E - F == G / H",
            "A >= B * C && D < E - F || G != H + I",
            "(A / B + C) > D && E * F <= G - H",
            "A * (B - C) == D && E / F > G + H",
            "(A - B + C) != D || E * F >= G && H < I",
            "A < B / C && D + E * F == G - H",
            "(A + B * C) <= D && E > F / G",
            "(A * B - C) > D || E <= F + G && H != I",
            "A != B / C && D == E * F - G",
            "A <= B + C - D && E / F > G * H",
            "(A - B * C) < D || E >= F + G && H != I",
            "(A + B) / C == D && E - F < G * H",
            "A * B != C && D >= E + F / G || H < I",
            "!(A * B != C) && !(D >= E + F / G) || !(H < I)",
            "-A - B - (- C - D) - E - (-F)",
        ] {
            println!("Testing {}", expr);
            // First pass: bind the variables to 1..=num_vars (in map iteration order).
            for (pos, v) in variables.values_mut().enumerate() {
                *v = Variable::Integer(pos as i64 + 1);
            }

            assert_expr(expr, &variables);

            // Second pass: bind them to num_vars..=1 and check again.
            for (pos, v) in variables.values_mut().enumerate() {
                *v = Variable::Integer((num_vars - pos) as i64);
            }

            assert_expr(expr, &variables);
        }

        // Boolean-only expressions: our parser is given `1`/`0` in place of
        // `true`/`false`, while `evalexpr` evaluates the original text.
        for expr in [
            "true && false",
            "!true || false",
            "true && !false",
            "!(true && false)",
            "true || true && false",
            "!false && (true || false)",
            "!(true || !false) && true",
            "!(!true && !false)",
            "true || false && !true",
            "!(true && true) || !false",
            "!(!true || !false) && (!false) && !(!true)",
        ] {
            let pexp = parse_expression(expr.replace("true", "1").replace("false", "0").as_str());
            let result = pexp.eval(&HashMap::new()).unwrap();

            match (eval(expr).expect(expr), result) {
                (Value::Float(a), Variable::Float(b)) if a == b => (),
                (Value::Float(a), Variable::Integer(b)) if a == b as f64 => (),
                (Value::Boolean(a), Variable::Integer(b)) if a == (b != 0) => (),
                (a, b) => {
                    panic!("{} => {:?} != {:?}", expr, a, b)
                }
            }
        }
    }

    /// Evaluates `expr` with `variables` bound and asserts that the result
    /// matches:
    /// 1. the same expression with every variable name textually replaced by
    ///    its displayed value,
    /// 2. the same substitution using float literals (`.0` appended when the
    ///    value has no `.`), and
    /// 3. `evalexpr`'s result for the float-literal version.
    fn assert_expr(expr: &str, variables: &HashMap<String, Variable>) {
        let e = parse_expression(expr);

        let result = e.eval(variables).unwrap();

        let mut str_expr = expr.to_string();
        let mut str_expr_float = expr.to_string();
        for (k, v) in variables {
            let v = v.to_string();

            if v.contains('.') {
                str_expr_float = str_expr_float.replace(k, &v);
            } else {
                str_expr_float = str_expr_float.replace(k, &format!("{}.0", v));
            }
            str_expr = str_expr.replace(k, &v);
        }

        assert_eq!(
            parse_expression(&str_expr)
                .eval(&HashMap::new())
                .unwrap()
                .to_number()
                .to_float(),
            result.to_number().to_float()
        );

        assert_eq!(
            parse_expression(&str_expr_float)
                .eval(&HashMap::new())
                .unwrap()
                .to_number()
                .to_float(),
            result.to_number().to_float()
        );

        match (
            eval(&str_expr_float)
                // `op_divide` returns 0.0 for a zero divisor, whereas evalexpr
                // presumably yields an infinite float there; normalise before
                // comparing.
                .map(|v| {
                    if matches!(&v, Value::Float(f) if f.is_infinite()) {
                        Value::Float(0.0)
                    } else {
                        v
                    }
                })
                .expect(&str_expr),
            result,
        ) {
            (Value::Float(a), Variable::Float(b)) if a == b => (),
            (Value::Float(a), Variable::Integer(b)) if a == b as f64 => (),
            (Value::Boolean(a), Variable::Integer(b)) if a == (b != 0) => (),
            (a, b) => {
                panic!("{} => {:?} != {:?}", str_expr, a, b)
            }
        }
    }

    /// Parses `expr` into its compiled instruction list, resolving every
    /// identifier as a global variable. Panics if parsing fails.
    fn parse_expression(expr: &str) -> Vec<Expression> {
        ExpressionParser::from_tokenizer(Tokenizer::new(expr, |var_name: &str, _: bool| {
            Ok::<_, String>(Token::Variable(VariableType::Global(var_name.to_string())))
        }))
        .parse()
        .unwrap()
        .output
    }
}