prexel/
evaluator.rs

1use std::fmt::Debug;
2use std::marker::PhantomData;
3use std::str::FromStr;
4
5use crate::context::{Context, DefaultContext};
6use crate::error::{Error, ErrorKind};
7use crate::num::checked::CheckedNum;
8use crate::token::Token;
9use crate::token::Token::*;
10use crate::tokenizer::Tokenizer;
11use crate::Result;
12
/// Represents the default `Evaluator`.
///
/// An `Evaluator` turns an expression string into tokens with its `Tokenizer` and
/// evaluates those tokens against its `Context`, which supplies the variables,
/// constants, operators and functions.
pub struct Evaluator<'a, N, C: Context<'a, N> = DefaultContext<'a, N>> {
    /// The context used for evaluation.
    context: C,
    /// The tokenizer used by `eval` to split an expression into tokens.
    tokenizer: Tokenizer<'a, N, C>,
    // NOTE(review): `'a` and `N` already appear in the type of `tokenizer`, so this marker
    // looks redundant; if one is kept, `PhantomData<&'a N>` is the conventional form.
    _marker: &'a PhantomData<N>,
}
20
21impl<'a, N: CheckedNum> Default for Evaluator<'a, N, DefaultContext<'a, N>> {
22    fn default() -> Self {
23        Evaluator::new()
24    }
25}
26
27impl<'a, N: CheckedNum> Evaluator<'a, N, DefaultContext<'a, N>> {
28    /// Constructs a new `Evaluator` using the checked `DefaultContext`.
29    #[inline]
30    pub fn new() -> Self {
31        Evaluator {
32            context: DefaultContext::new_checked(),
33            tokenizer: Tokenizer::new(),
34            _marker: &PhantomData,
35        }
36    }
37}
38
39impl<'a, N, C> Evaluator<'a, N, C>
40where
41    C: Context<'a, N>,
42{
43    /// Constructs a new `Evaluator` using the specified `Context`.
44    #[inline]
45    pub fn with_context(context: C) -> Self {
46        Evaluator {
47            context,
48            tokenizer: Tokenizer::new(),
49            _marker: &PhantomData,
50        }
51    }
52
53    /// Gets a reference to the `Context` used by this evaluator.
54    #[inline]
55    pub fn context(&self) -> &C {
56        &self.context
57    }
58
59    /// Gets a mutable reference to the `Context` used by this evaluator.
60    #[inline]
61    pub fn mut_context(&mut self) -> &mut C {
62        &mut self.context
63    }
64}
65
66impl<'a, N, C> Evaluator<'a, N, C>
67where
68    C: Context<'a, N>,
69{
70    pub fn with_context_and_tokenizer(context: C, tokenizer: Tokenizer<'a, N, C>) -> Self {
71        Evaluator {
72            context,
73            tokenizer,
74            _marker: &PhantomData,
75        }
76    }
77}
78
79impl<'a, N, C> Evaluator<'a, N, C>
80where
81    C: Context<'a, N>,
82    N: FromStr + Debug + Clone,
83{
84    /// Evaluates the given `str` expression.
85    ///
86    /// # Example
87    /// ```
88    /// use prexel::evaluator::Evaluator;
89    ///
90    /// let evaluator : Evaluator<f64> = Evaluator::new();
91    /// match evaluator.eval("3 + 2 * 5"){
92    ///     Ok(n) => {
93    ///         assert_eq!(n, 13_f64);
94    ///         println!("Result: {}", n);
95    ///      },
96    ///     Err(e) => println!("{}", e)
97    /// }
98    /// ```
99    #[inline]
100    pub fn eval(&'a self, expression: &str) -> Result<N> {
101        let context = self.context();
102        let tokens = self.tokenizer.tokenize(context, expression)?;
103        self.eval_tokens(&tokens)
104    }
105}
106
107impl<'a, C, N> Evaluator<'a, N, C>
108where
109    C: Context<'a, N>,
110    N: Debug + Clone,
111{
112    #[inline]
113    pub fn eval_tokens(&self, tokens: &[Token<N>]) -> Result<N> {
114        rpn_eval(tokens, self.context())
115    }
116}
117
118/// Evaluates an array of tokens in `Reverse Polish Notation`.
119///
120/// # Arguments
121/// - token: The tokens of the expression to convert.
122/// - context: the context which contains the variables, constants and functions.
123///
124/// See: `https://en.wikipedia.org/wiki/Reverse_Polish_notation`
125pub fn rpn_eval<'a, N, C>(tokens: &[Token<N>], context: &C) -> Result<N>
126where
127    N: Debug + Clone,
128    C: Context<'a, N>,
129{
130    // Converts the array of tokens to RPN.
131    let rpn = shunting_yard::infix_to_rpn(tokens, context)?;
132    // Stores the resulting values
133    let mut values: Vec<N> = Vec::new();
134    // Stores the argument count of the current function, if any.
135    let mut arg_count: Option<usize> = None;
136
137    for token in &rpn {
138        match token {
139            Number(n) => values.push(n.clone()),
140            Variable(name) => {
141                let n = context
142                    .get_variable(name)
143                    .ok_or_else(|| {
144                        Error::new(
145                            ErrorKind::InvalidInput,
146                            format!("Variable `{}` not found", name),
147                        )
148                    })?
149                    .clone();
150
151                values.push(n);
152            }
153            Constant(name) => {
154                let n = context
155                    .get_constant(name)
156                    .ok_or_else(|| {
157                        Error::new(
158                            ErrorKind::InvalidInput,
159                            format!("Constant `{}` not found", name),
160                        )
161                    })?
162                    .clone();
163
164                values.push(n);
165            }
166            ArgCount(n) => {
167                debug_assert_eq!(arg_count, None);
168                arg_count = Some(*n);
169            }
170            UnaryOperator(name) => {
171                let func = context.get_unary_function(name).ok_or_else(|| {
172                    Error::new(
173                        ErrorKind::InvalidInput,
174                        format!("Unary operator `{}` not found", name),
175                    )
176                })?;
177
178                match values.pop() {
179                    Some(n) => {
180                        let result = func.call(n)?;
181                        values.push(result);
182                    }
183                    _ => {
184                        return Err(Error::new(
185                            ErrorKind::InvalidExpression,
186                            format!("{:?}", &tokens),
187                        ));
188                    }
189                }
190            }
191            BinaryOperator(name) => {
192                let func = context.get_binary_function(name).ok_or_else(|| {
193                    Error::new(
194                        ErrorKind::InvalidInput,
195                        format!("Binary operator `{}` not found", name),
196                    )
197                })?;
198
199                match (values.pop(), values.pop()) {
200                    (Some(x), Some(y)) => {
201                        let result = func.call(y, x)?;
202                        values.push(result);
203                    }
204                    _ => {
205                        return Err(Error::new(
206                            ErrorKind::InvalidExpression,
207                            format!("{:?}", &tokens),
208                        ));
209                    }
210                }
211            }
212            Function(name) => {
213                // A reference to the function
214                let func = context.get_function(name).ok_or_else(|| {
215                    Error::new(
216                        ErrorKind::InvalidInput,
217                        format!("Function `{}` not found", name),
218                    )
219                })?;
220
221                // The number of arguments the function takes
222                let n = arg_count.ok_or_else(|| {
223                    Error::new(
224                        ErrorKind::InvalidInput,
225                        format!(
226                            "Cannot evaluate function `{}`, unknown number of arguments",
227                            name
228                        ),
229                    )
230                })?;
231
232                // Stores the arguments to pass to the function.
233                let mut args = Vec::new();
234
235                for _ in 0..n {
236                    match values.pop() {
237                        Some(n) => args.push(n.clone()),
238                        None => {
239                            Error::new(
240                                ErrorKind::InvalidArgumentCount,
241                                format!("expected {} arguments but {} was get", n, args.len()),
242                            );
243                        }
244                    }
245                }
246
247                // Reverse the order of the arguments.
248                // For a function as `TakeFirst(1, 2, 3)`, values are taken from last,
249                // so `args` will contain [3, 2, 1], so reverse is needed.
250                args.reverse();
251                let result = func.call(&args)?;
252                values.push(result);
253                arg_count = None;
254            }
255            _ => {
256                return Err(Error::new(
257                    ErrorKind::InvalidInput,
258                    format!("Unknown token: `{:?}`", token),
259                ));
260            }
261        }
262    }
263
264    // If there is a single value left, that is the result
265    if values.len() == 1 {
266        Ok(values[0].clone())
267    } else {
268        Err(Error::from(ErrorKind::InvalidExpression))
269    }
270}
271
272/// Converts the given array of tokens to reverse polish notation.
273///
274/// # Arguments
275/// - token: The tokens of the expression to convert.
276/// - context: the context which contains the variables, constants and functions.
277///
278/// # Example
279/// ```
280/// use prexel::token::Token::*;
281/// use prexel::evaluator;
282/// use prexel::context::DefaultContext;
283///
284/// let tokens = [Number(5), BinaryOperator("+".to_string()), Number(2)];
285/// let context = DefaultContext::new_checked();
286/// let rpn = evaluator::infix_to_rpn(&tokens, &context).unwrap();
287///
288/// assert_eq!(&rpn, &[Number(5), Number(2), BinaryOperator("+".to_string())]);
289/// ```
290#[inline(always)]
291pub fn infix_to_rpn<'a, N, C>(tokens: &[Token<N>], context: &C) -> Result<Vec<Token<N>>>
292where
293    N: Clone + Debug,
294    C: Context<'a, N>,
295{
296    shunting_yard::infix_to_rpn(tokens, context)
297}
298
299mod shunting_yard {
300    use std::fmt::Debug;
301
302    use crate::context::Context;
303    use crate::error::{Error, ErrorKind};
304    use crate::function::{Associativity, Notation};
305    use crate::token::Token;
306    use crate::token::Token::*;
307    use crate::Result;
308
309    /// Converts an `infix` notation expression to `rpn` (Reverse Polish Notation) using
310    /// the shunting yard algorithm.
311    ///
312    /// # Arguments
313    /// - token: The tokens of the expression to convert.
314    /// - context: the context which contains the variables, constants and functions.
315    ///
316    /// See: https://en.wikipedia.org/wiki/Shunting-yard_algorithm
317    pub fn infix_to_rpn<'a, N, C>(tokens: &[Token<N>], context: &C) -> Result<Vec<Token<N>>>
318    where
319        N: Clone + Debug,
320        C: Context<'a, N>,
321    {
322        let mut output = Vec::new();
323        let mut operators = Vec::new();
324        let mut arg_count: Vec<usize> = Vec::new();
325        let mut grouping_count: Vec<usize> = Vec::new();
326
327        let mut token_iterator = tokens.iter().enumerate().peekable();
328        while let Some((pos, token)) = token_iterator.next() {
329            match token {
330                Token::Number(_) | Token::Variable(_) | Token::Constant(_) => {
331                    push_number(context, &mut output, &mut operators, token)
332                }
333                Token::BinaryOperator(name) => {
334                    push_binary_function(context, &mut output, &mut operators, token, name)?;
335                }
336                Token::UnaryOperator(name) => {
337                    push_unary_function(context, &mut output, &mut operators, token, name)?
338                }
339                Token::Function(name) => {
340                    if !context.config().custom_function_call {
341                        // Checks the function call starts with a parentheses open
342                        // We only allow function arguments in a parentheses, so function calls
343                        // with custom grouping symbols are invalid eg: Max[1,2,3], Sum<2,4,6>
344                        if !token_iterator
345                            .peek()
346                            .map_or(false, |t| t.1.contains_symbol('('))
347                        {
348                            return Err(Error::new(
349                                ErrorKind::InvalidInput,
350                                format!(
351                                    "Function arguments of `{}` are not within a parentheses",
352                                    name
353                                ),
354                            ));
355                        }
356                    }
357
358                    arg_count.push(0);
359                    operators.push(token.clone());
360                }
361                Token::GroupingOpen(_) => {
362                    operators.push(token.clone());
363                    if !arg_count.is_empty() {
364                        grouping_count.push(pos);
365                    }
366                }
367                Token::GroupingClose(c) => {
368                    push_grouping_close(context, *c, &mut output, &mut operators, &mut arg_count)?;
369
370                    // Checking for empty grouping symbols: eg: `Random(())`, `()+2`
371                    if pos > 1 {
372                        if let Token::GroupingOpen(s) = tokens[pos - 1] {
373                            if context
374                                .config()
375                                .get_group_open_for(*c)
376                                .map_or(false, |v| v == s)
377                                && !tokens[pos - 2].is_function()
378                            {
379                                return Err(Error::new(
380                                    ErrorKind::InvalidInput,
381                                    format!(
382                                        // Empty grouping symbols: ()
383                                        "Empty grouping symbols: {}{}",
384                                        context.config().get_group_open_for(*c).unwrap(),
385                                        c
386                                    ),
387                                ));
388                            }
389                        }
390                    }
391
392                    if !arg_count.is_empty() {
393                        grouping_count.pop();
394                    }
395                }
396                Token::Comma => {
397                    check_comma_position(tokens, &grouping_count, pos)?;
398                    push_comma(&mut output, &mut operators, &mut arg_count)?
399                }
400                _ => {
401                    return Err(Error::new(
402                        ErrorKind::InvalidInput,
403                        format!("Invalid token: {:?}", token),
404                    ))
405                }
406            }
407
408            // If implicit multiplication
409            if context.config().implicit_mul {
410                if token.is_number() {
411                    // 2Max, 2PI, 2x, 2(4)
412                    if let Some(next_token) = token_iterator.peek() {
413                        match next_token.1 {
414                            Token::Function(_)
415                            | Token::Constant(_)
416                            | Token::Variable(_)
417                            | Token::GroupingOpen(_) => {
418                                operators.push(BinaryOperator('*'.to_string()));
419                            }
420                            _ => {}
421                        }
422                    }
423                } else if token.is_grouping_close() {
424                    //(2)2, (2)PI, (2)x, (4)(2), Sin(30)Cos(30), Tan(45)2
425                    if let Some(next_token) = token_iterator.peek() {
426                        match next_token.1 {
427                            Number(_) | Variable(_) | Constant(_) | Function(_)
428                            | GroupingOpen(_) => operators.push(BinaryOperator('*'.to_string())),
429                            _ => {}
430                        }
431                    }
432                }
433            }
434        }
435
436        while let Some(t) = operators.pop() {
437            if t.is_grouping_close() || t.is_grouping_close() {
438                return Err(Error::new(
439                    ErrorKind::InvalidExpression,
440                    "Misplace parentheses",
441                ));
442            }
443
444            output.push(t)
445        }
446
447        Ok(output)
448    }
449
450    fn check_comma_position<N>(
451        tokens: &[Token<N>],
452        grouping_count: &[usize],
453        pos: usize,
454    ) -> Result<()> {
455        // TODO: Moves this comma checks to its own function
456        if pos == 0 {
457            return Err(Error::new(ErrorKind::InvalidInput, "Misplaced comma"));
458        }
459
460        if tokens.get(pos - 1).map_or(false, |t| t.is_grouping_open()) {
461            // Invalid expression: `(,`
462            return Err(Error::new(ErrorKind::InvalidInput, "Misplaced comma: `(,`"));
463        }
464
465        if tokens.get(pos + 1).map_or(false, |t| t.is_grouping_close()) {
466            // Invalid expression: `,)`
467            return Err(Error::new(ErrorKind::InvalidInput, "Misplaced comma: `,)`"));
468        }
469
470        // We avoid all function arguments wrapped by grouping symbols,
471        // eg: Max((1,2,3))
472        if !grouping_count.is_empty()
473            && !tokens
474                .get(*grouping_count.last().unwrap() - 1)
475                .map_or(false, |t| t.is_function())
476        {
477            return Err(Error::new(ErrorKind::InvalidInput, "Misplaced comma"));
478        }
479
480        Ok(())
481    }
482
483    fn push_number<'a, N: Clone + Debug>(
484        context: &impl Context<'a, N>,
485        output: &mut Vec<Token<N>>,
486        operators: &mut Vec<Token<N>>,
487        token: &Token<N>,
488    ) {
489        output.push(token.clone());
490        if let Some(Token::UnaryOperator(op)) = operators.last() {
491            if context.get_unary_function(op).is_some() {
492                output.push(operators.pop().unwrap());
493            }
494        }
495    }
496
497    fn push_unary_function<'a, N: Clone + Debug>(
498        context: &impl Context<'a, N>,
499        output: &mut Vec<Token<N>>,
500        operators: &mut Vec<Token<N>>,
501        token: &Token<N>,
502        name: &str,
503    ) -> Result<()> {
504        if let Some(unary) = context.get_unary_function(name) {
505            match unary.notation() {
506                Notation::Prefix => {
507                    //+6
508                    operators.push(token.clone());
509                }
510                Notation::Postfix => {
511                    // 5!
512                    if !output.is_empty() {
513                        output.push(token.clone())
514                    } else {
515                        return Err(Error::new(
516                            ErrorKind::InvalidExpression,
517                            "Misplace unary operator",
518                        ));
519                    }
520                }
521            }
522
523            Ok(())
524        } else {
525            Err(Error::new(
526                ErrorKind::InvalidInput,
527                format!("Unary operator `{}` not found", name),
528            ))
529        }
530    }
531
532    fn push_binary_function<'a, N: Clone + Debug>(
533        context: &impl Context<'a, N>,
534        output: &mut Vec<Token<N>>,
535        operators: &mut Vec<Token<N>>,
536        token: &Token<N>,
537        name: &str,
538    ) -> Result<()> {
539        let operator = context.get_binary_function(name).ok_or_else(|| {
540            Error::new(
541                ErrorKind::InvalidInput,
542                format!("Binary function `{}` not found", name),
543            )
544        })?;
545
546        while let Some(t) = operators.last() {
547            if let Token::GroupingOpen(_) = t {
548                break;
549            }
550
551            if t.is_function() {
552                output.push(operators.pop().unwrap());
553            } else {
554                let top_operator = match t {
555                    Token::BinaryOperator(op) => context.get_binary_function(op),
556                    _ => {
557                        break;
558                    }
559                };
560
561                match top_operator {
562                    Some(top) => {
563                        if (top.precedence() > operator.precedence())
564                            || (top.precedence() == operator.precedence()
565                                && top.associativity() == Associativity::Left)
566                        {
567                            output.push(operators.pop().unwrap());
568                        } else {
569                            break;
570                        }
571                    }
572                    _ => {
573                        break;
574                    }
575                }
576            }
577        }
578
579        operators.push(token.clone());
580        Ok(())
581    }
582
583    fn push_grouping_close<'a, N: Clone + Debug>(
584        context: &impl Context<'a, N>,
585        group_close: char,
586        output: &mut Vec<Token<N>>,
587        operators: &mut Vec<Token<N>>,
588        arg_count: &mut Vec<usize>,
589    ) -> Result<()> {
590        // Flag used for detect misplaced grouping symbol.
591        let mut is_group_open = false;
592
593        // Pop tokens from the operator stack and push then into the output stack
594        // until a group close token is found.
595        while let Some(t) = operators.pop() {
596            match t {
597                Token::GroupingOpen(c) => {
598                    if let Some((_, close)) = context.config().get_group_symbol(c) {
599                        if close == group_close {
600                            is_group_open = true;
601                            // If `arg_count` is not empty we are inside a function.
602                            // So we pop the argument count and function token into the output stack.
603                            if !arg_count.is_empty() {
604                                if let Some(Token::Function(_)) = operators.last() {
605                                    let count = arg_count.pop().unwrap() + 1;
606                                    output.push(Token::ArgCount(count));
607                                    output.push(operators.pop().unwrap());
608                                }
609                            }
610                        }
611                    }
612
613                    break;
614                }
615                _ => output.push(t.clone()),
616            }
617        }
618
619        if !is_group_open {
620            Err(Error::new(
621                ErrorKind::InvalidExpression,
622                "Misplace grouping symbol",
623            ))
624        } else {
625            Ok(())
626        }
627    }
628
629    fn push_comma<N: Clone + Debug>(
630        output: &mut Vec<Token<N>>,
631        operators: &mut Vec<Token<N>>,
632        arg_count: &mut Vec<usize>,
633    ) -> Result<()> {
634        match arg_count.last_mut() {
635            None => {
636                return Err(Error::new(
637                    ErrorKind::InvalidExpression,
638                    "Comma found but not function",
639                ))
640            }
641            Some(n) => *n += 1,
642        }
643
644        let mut is_group_open = false;
645        while let Some(t) = operators.last() {
646            match t {
647                Token::GroupingOpen(_) => {
648                    is_group_open = true;
649                    break;
650                }
651                _ => {
652                    output.push(operators.pop().unwrap());
653                }
654            }
655        }
656
657        if !is_group_open {
658            Err(Error::new(ErrorKind::InvalidExpression, "Misplace comma"))
659        } else {
660            Ok(())
661        }
662    }
663
664    #[cfg(test)]
665    mod tests {
666        use super::*;
667        use crate::context::{Config, DefaultContext};
668
669        #[test]
670        fn unary_ops_test1() {
671            let context = &DefaultContext::new_checked();
672
673            assert_eq!(
674                infix_to_rpn(
675                    // -(+10) -> 10 + -
676                    &[
677                        UnaryOperator('-'.to_string()),
678                        GroupingOpen('('),
679                        UnaryOperator('+'.to_string()),
680                        Number(10),
681                        GroupingClose(')')
682                    ],
683                    context
684                )
685                .unwrap(),
686                [
687                    Number(10),
688                    UnaryOperator('+'.to_string()),
689                    UnaryOperator('-'.to_string())
690                ]
691            );
692        }
693
694        #[test]
695        fn binary_ops_test1() {
696            let context = &DefaultContext::new_checked();
697
698            assert_eq!(
699                infix_to_rpn(
700                    // 3 + 2 -> 3 2 +
701                    &[Number(3), BinaryOperator('+'.to_string()), Number(2)],
702                    context
703                )
704                .unwrap(),
705                [Number(3), Number(2), BinaryOperator('+'.to_string())]
706            );
707        }
708
709        #[test]
710        fn binary_ops_test2() {
711            let context = &DefaultContext::new_checked();
712
713            assert_eq!(
714                infix_to_rpn(
715                    // 2 + 3 * 5 -> 2 3 5 + *
716                    &[
717                        Number(2),
718                        BinaryOperator('+'.to_string()),
719                        Number(3),
720                        BinaryOperator('*'.to_string()),
721                        Number(5)
722                    ],
723                    context
724                )
725                .unwrap(),
726                [
727                    Number(2),
728                    Number(3),
729                    Number(5),
730                    BinaryOperator('*'.to_string()),
731                    BinaryOperator('+'.to_string())
732                ]
733            );
734        }
735
736        #[test]
737        fn binary_ops_test3() {
738            let context = &DefaultContext::new_checked();
739
740            assert_eq!(
741                infix_to_rpn(
742                    // 2 ^ 3 ^ 4 - 1
743                    &[
744                        Number(2),
745                        BinaryOperator('^'.to_string()),
746                        Number(3),
747                        BinaryOperator('^'.to_string()),
748                        Number(4),
749                        BinaryOperator('-'.to_string()),
750                        Number(1)
751                    ],
752                    context
753                )
754                .unwrap(),
755                [
756                    Number(2),
757                    Number(3),
758                    Number(4),
759                    BinaryOperator('^'.to_string()),
760                    BinaryOperator('^'.to_string()),
761                    Number(1),
762                    BinaryOperator('-'.to_string())
763                ]
764            );
765        }
766
767        #[test]
768        fn binary_ops_test4() {
769            let context = &DefaultContext::new_checked();
770
771            assert_eq!(
772                infix_to_rpn(
773                    // (5 + (-3)) ^ Max(1, 2*5, (30/2))
774                    &[
775                        GroupingOpen('('),
776                        Number(5),
777                        BinaryOperator('+'.to_string()),
778                        GroupingOpen('('),
779                        UnaryOperator('-'.to_string()),
780                        Number(3),
781                        GroupingClose(')'),
782                        GroupingClose(')'),
783                        BinaryOperator('^'.to_string()),
784                        Function("Max".to_string()),
785                        GroupingOpen('('),
786                        Number(1),
787                        Comma,
788                        Number(2),
789                        BinaryOperator('*'.to_string()),
790                        Number(5),
791                        Comma,
792                        GroupingOpen('('),
793                        Number(30),
794                        BinaryOperator('/'.to_string()),
795                        Number(2),
796                        GroupingClose(')'),
797                        GroupingClose(')'),
798                    ],
799                    context
800                )
801                .unwrap(),
802                [
803                    Number(5),
804                    Number(3),
805                    UnaryOperator('-'.to_string()),
806                    BinaryOperator('+'.to_string()),
807                    Number(1),
808                    Number(2),
809                    Number(5),
810                    BinaryOperator('*'.to_string()),
811                    Number(30),
812                    Number(2),
813                    BinaryOperator('/'.to_string()),
814                    ArgCount(3),
815                    Function("Max".to_string()),
816                    BinaryOperator('^'.to_string())
817                ]
818            );
819        }
820
821        #[test]
822        fn infix_ops_test() {
823            let context = &DefaultContext::new_checked();
824
825            assert_eq!(
826                infix_to_rpn(
827                    // 10 mod 2 -> 10 2 mod
828                    &[Number(10), BinaryOperator(String::from("mod")), Number(2)],
829                    context
830                )
831                .unwrap(),
832                [Number(10), Number(2), BinaryOperator(String::from("mod"))]
833            );
834        }
835
836        #[test]
837        fn function_test() {
838            let context = &DefaultContext::new_checked();
839
840            assert_eq!(
841                infix_to_rpn(
842                    // 5 * Sum(2, 3) -> 2 3 2arg Sum 5 *
843                    &[
844                        Number(5),
845                        BinaryOperator('*'.to_string()),
846                        Function(String::from("Sum")),
847                        GroupingOpen('('),
848                        Number(2),
849                        Comma,
850                        Number(3),
851                        GroupingClose(')'),
852                    ],
853                    context
854                )
855                .unwrap(),
856                [
857                    Number(5),
858                    Number(2),
859                    Number(3),
860                    ArgCount(2),
861                    Function(String::from("Sum")),
862                    BinaryOperator('*'.to_string()),
863                ]
864            );
865        }
866
867        #[test]
868        fn implicit_mul_test1() {
869            let config = Config::new().with_implicit_mul(true);
870            let context = DefaultContext::with_config_checked(config);
871
872            let infix = &[Token::Number(10), Token::Constant("PI".to_string())];
873            let rpn = infix_to_rpn(infix, &context).unwrap();
874            assert_eq!(
875                rpn,
876                &[
877                    Token::Number(10),
878                    Token::Constant("PI".to_string()),
879                    Token::BinaryOperator('*'.to_string())
880                ]
881            );
882        }
883
884        #[test]
885        fn implicit_mul_test2() {
886            let config = Config::new().with_implicit_mul(true);
887            let context = DefaultContext::with_config_checked(config);
888
889            let infix = &[
890                Token::GroupingOpen('('),
891                Token::Number(2),
892                Token::GroupingClose(')'),
893                Token::GroupingOpen('('),
894                Token::Number(3),
895                Token::GroupingClose(')'),
896            ];
897
898            let rpn = infix_to_rpn(infix, &context).unwrap();
899            assert_eq!(
900                rpn,
901                &[
902                    Token::Number(2),
903                    Token::Number(3),
904                    Token::BinaryOperator('*'.to_string())
905                ]
906            );
907        }
908    }
909}
910
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Config, Grouping};
    use crate::ops::math::Function;

    /// Evaluates well-formed and malformed expressions against a checked
    /// context configured with `Grouping::Parenthesis`.
    #[test]
    fn eval_test() {
        let config = Config::new().with_grouping(Grouping::Parenthesis);
        let evaluator: Evaluator<i64> =
            Evaluator::with_context(DefaultContext::with_config_checked(config));

        assert_eq!(evaluator.eval("(2 ^ 3) ^ 4").unwrap(), 4096);
        assert_eq!(evaluator.eval("Min(10, 2) + Max(10, 2)").unwrap(), 12);
        assert_eq!(
            evaluator
                .eval("Sum(1, 2, 3) * 2 - Max(2, 10/2, 2^3)")
                .unwrap(),
            4
        );

        assert!(evaluator.eval("5").is_ok());
        assert!(evaluator.eval("-2").is_ok());
        assert!(evaluator.eval("(10)").is_ok());
        assert!(evaluator.eval("([(25)])").is_ok());
        assert!(evaluator.eval("-(+(6))").is_ok());
        assert!(evaluator.eval("+10").is_ok());
        assert!(evaluator.eval("((10)+(((2)))*(3))").is_ok());
        assert!(evaluator.eval("-(2)^(((4)))").is_ok());
        assert!(evaluator.eval("-(+(-(+(5))))").is_ok());
        assert!(evaluator.eval("10 - - +2").is_ok());
        assert!(evaluator.eval("+2!").is_ok());
        assert!(evaluator.eval("5 * Sin(40)").is_ok());
        assert!(evaluator.eval("Sin(30) * 5").is_ok());
        assert!(evaluator.eval("Cos(30) * Sin(30)").is_ok());

        assert!(evaluator.eval("((20) + 2").is_err());
        assert!(evaluator.eval("(1,23) + 1").is_err());
        assert!(evaluator.eval("2^").is_err());
        assert!(evaluator.eval("10 2").is_err());
        assert!(evaluator.eval("2 3 +").is_err());
        assert!(evaluator.eval("^10!").is_err());
        assert!(evaluator.eval("8+").is_err());
        assert!(evaluator.eval("([10)]").is_err());
        assert!(evaluator.eval("()+5").is_err());
        assert!(evaluator.eval("Sum 2 3 4").is_err());
        assert!(evaluator.eval("Max(,)").is_err());
        assert!(evaluator.eval("Max(2, )").is_err());
        assert!(evaluator.eval("Max( ,3)").is_err());
        assert!(evaluator.eval("Max(2, 3,)").is_err());
        assert!(evaluator.eval("Sum((10, 2, 3))").is_err());
        assert!(evaluator.eval("(())").is_err());
        assert!(evaluator.eval("Random(())").is_err());
    }

    /// Evaluates expressions that rely on implicit multiplication, such as a
    /// number directly followed by a variable, function, or group.
    #[test]
    fn eval_implicit_mul_test() {
        let config = Config::new().with_implicit_mul(true);
        let context = DefaultContext::with_config_checked(config);
        let mut evaluator: Evaluator<i64> = Evaluator::with_context(context);

        evaluator.mut_context().set_variable("x", 10).unwrap();
        assert_eq!(evaluator.eval("2x").unwrap(), 20);

        evaluator.mut_context().set_variable("x", 5).unwrap();
        assert_eq!(evaluator.eval("3x").unwrap(), 15);

        assert!(evaluator.eval("2Sin(50)").is_ok());
        assert!(evaluator.eval("(2)(4)").is_ok());
        assert!(evaluator.eval("2(4)").is_ok());
        assert!(evaluator.eval("(2)4").is_ok());
        assert!(evaluator.eval("Cos(30) * 2").is_ok());
        assert!(evaluator.eval("Cos(30)(2)").is_ok());

        // Not currently allowed because `x(2)` looks like a function call.
        assert!(evaluator.eval("5x(2)").is_err());

        // Ambiguous expression: should be rejected.
        assert!(evaluator.eval("3 2Sin(50)").is_err());
    }

    /// Evaluates pre-tokenized infix input directly, bypassing the tokenizer.
    #[test]
    fn eval_tokens_test() {
        let evaluator = Evaluator::new();

        // 3 + 2
        assert_eq!(
            evaluator
                .eval_tokens(&[
                    Token::Number(3),
                    Token::BinaryOperator('+'.to_string()),
                    Token::Number(2)
                ])
                .unwrap(),
            5
        );

        // 2 ^ 3 ^ 2 is right-associative: 2 ^ (3 ^ 2) = 512
        assert_eq!(
            evaluator
                .eval_tokens(&[
                    Token::Number(2),
                    Token::BinaryOperator('^'.to_string()),
                    Token::Number(3),
                    Token::BinaryOperator('^'.to_string()),
                    Token::Number(2)
                ])
                .unwrap(),
            512
        );

        // (2 ^ 3) ^ 4
        assert_eq!(
            evaluator
                .eval_tokens(&[
                    Token::GroupingOpen('('),
                    Token::Number(2),
                    Token::BinaryOperator('^'.to_string()),
                    Token::Number(3),
                    Token::GroupingClose(')'),
                    Token::BinaryOperator('^'.to_string()),
                    Token::Number(4)
                ])
                .unwrap(),
            4096
        );
    }

    /// A variable set through `mut_context` is visible to later evaluations.
    #[test]
    fn eval_using_variable_test() {
        let mut evaluator = Evaluator::new();
        evaluator.mut_context().set_variable("x", 10).unwrap();

        assert_eq!(evaluator.eval("x + 2").unwrap(), 12);
    }

    /// A user-defined function can be called by its name and by each alias.
    #[test]
    fn eval_with_alias_test() {
        struct SumFunction;
        impl Function<f64> for SumFunction {
            fn name(&self) -> &str {
                "sum"
            }

            fn call(&self, args: &[f64]) -> Result<f64> {
                Ok(args.iter().sum())
            }

            fn aliases(&self) -> Option<&[&str]> {
                Some(&["add", "∑"])
            }
        }

        let mut context: DefaultContext<f64> = DefaultContext::new();
        context.add_function(SumFunction).unwrap();

        let evaluator = Evaluator::with_context(context);
        assert_eq!(evaluator.eval("sum(1, 2, 3, 4)").unwrap(), 10.0);
        assert_eq!(evaluator.eval("add(1, 2, 3, 4)").unwrap(), 10.0);
        assert_eq!(evaluator.eval("∑(1, 2, 3, 4)").unwrap(), 10.0);
    }
}