Skip to main content

openscenario_rs/
expression.rs

1//! OpenSCENARIO expression parsing and evaluation engine
2//!
3//! This module provides:
4//! - Expression parsing for OpenSCENARIO's mathematical expressions
5//! - Expression evaluation with parameter substitution
6//! - Support for ${expression} syntax from the XSD schema
7//! - Comprehensive error handling for invalid expressions
8//!
9//! Supported operators: +, -, *, /, %, (, ), >, <, >=, <=, ==, !=
10//! Supported types: numeric literals, parameters, function calls, constants
11//! Supported functions: sin, cos, tan, sqrt, abs, floor, ceil, min, max
12//! Supported constants: PI, E
13//!
14//! XSD Pattern: `[$][{][ A-Za-z0-9_\+\-\*/%$\(\)\.,]*[\}]`
15
16use crate::error::{Error, Result};
17use std::collections::HashMap;
18use std::str::FromStr;
19
20/// Expression token types for parsing
21#[derive(Debug, Clone, PartialEq)]
22pub enum Token {
23    Number(f64),
24    Parameter(String),
25    Operator(Operator),
26    Function(String),
27    Constant(String),
28    LeftParen,
29    RightParen,
30    Comma,
31}
32
/// Binary operators understood by the expression engine.
///
/// The enum carries no data, so it is `Copy` and can be compared with full
/// equality (`Eq`) or used as a hash-map key (`Hash`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operator {
    /// `+`
    Add,
    /// `-` (binary subtraction; the tokenizer also emits it for a leading minus)
    Subtract,
    /// `*`
    Multiply,
    /// `/`
    Divide,
    /// `%`
    Modulo,
    // Comparison operators
    /// `>`
    Greater,
    /// `<`
    Less,
    /// `>=`
    GreaterEqual,
    /// `<=`
    LessEqual,
    /// `==`
    Equal,
    /// `!=`
    NotEqual,
}
49
50/// Abstract syntax tree node for expressions
51#[derive(Debug, Clone, PartialEq)]
52pub enum Expr {
53    Number(f64),
54    Parameter(String),
55    Constant(String),
56    BinaryOp {
57        left: Box<Expr>,
58        operator: Operator,
59        right: Box<Expr>,
60    },
61    UnaryMinus(Box<Expr>),
62    FunctionCall {
63        name: String,
64        args: Vec<Expr>,
65    },
66}
67
68/// Expression parser for OpenSCENARIO mathematical expressions
69#[derive(Debug)]
70pub struct ExpressionParser {
71    tokens: Vec<Token>,
72    current: usize,
73}
74
75impl ExpressionParser {
76    /// Create a new parser with the given expression string
77    pub fn new(expr: &str) -> Result<Self> {
78        let tokens = Self::tokenize(expr)?;
79        Ok(Self { tokens, current: 0 })
80    }
81
82    /// Parse the expression into an AST
83    pub fn parse(&mut self) -> Result<Expr> {
84        let expr = self.parse_expression()?;
85        if self.current < self.tokens.len() {
86            return Err(Error::parse_error(
87                "expression",
88                "unexpected token after expression",
89            ));
90        }
91        Ok(expr)
92    }
93
94    /// Tokenize the input expression string
95    fn tokenize(input: &str) -> Result<Vec<Token>> {
96        let mut tokens = Vec::new();
97        let mut chars = input.chars().peekable();
98
99        while let Some(&ch) = chars.peek() {
100            match ch {
101                ' ' | '\t' | '\n' | '\r' => {
102                    chars.next();
103                }
104                '+' => {
105                    tokens.push(Token::Operator(Operator::Add));
106                    chars.next();
107                }
108                '-' => {
109                    tokens.push(Token::Operator(Operator::Subtract));
110                    chars.next();
111                }
112                '*' => {
113                    tokens.push(Token::Operator(Operator::Multiply));
114                    chars.next();
115                }
116                '/' => {
117                    tokens.push(Token::Operator(Operator::Divide));
118                    chars.next();
119                }
120                '%' => {
121                    tokens.push(Token::Operator(Operator::Modulo));
122                    chars.next();
123                }
124                '(' => {
125                    tokens.push(Token::LeftParen);
126                    chars.next();
127                }
128                ')' => {
129                    tokens.push(Token::RightParen);
130                    chars.next();
131                }
132                ',' => {
133                    tokens.push(Token::Comma);
134                    chars.next();
135                }
136                '>' => {
137                    chars.next();
138                    if chars.peek() == Some(&'=') {
139                        chars.next();
140                        tokens.push(Token::Operator(Operator::GreaterEqual));
141                    } else {
142                        tokens.push(Token::Operator(Operator::Greater));
143                    }
144                }
145                '<' => {
146                    chars.next();
147                    if chars.peek() == Some(&'=') {
148                        chars.next();
149                        tokens.push(Token::Operator(Operator::LessEqual));
150                    } else {
151                        tokens.push(Token::Operator(Operator::Less));
152                    }
153                }
154                '=' => {
155                    chars.next();
156                    if chars.peek() == Some(&'=') {
157                        chars.next();
158                        tokens.push(Token::Operator(Operator::Equal));
159                    } else {
160                        return Err(Error::parse_error(
161                            input,
162                            "single '=' not supported, use '==' for equality",
163                        ));
164                    }
165                }
166                '!' => {
167                    chars.next();
168                    if chars.peek() == Some(&'=') {
169                        chars.next();
170                        tokens.push(Token::Operator(Operator::NotEqual));
171                    } else {
172                        return Err(Error::parse_error(
173                            input,
174                            "single '!' not supported, use '!=' for inequality",
175                        ));
176                    }
177                }
178                '$' => {
179                    // Parameter reference: ${paramName} or $paramName
180                    chars.next();
181                    if chars.peek() == Some(&'{') {
182                        chars.next();
183                        let param_name = Self::read_until_char(&mut chars, '}')?;
184                        if chars.next() != Some('}') {
185                            return Err(Error::parse_error(
186                                input,
187                                "missing closing brace in parameter reference",
188                            ));
189                        }
190                        tokens.push(Token::Parameter(param_name));
191                    } else {
192                        // Simple $paramName format (deprecated but supported)
193                        let param_name = Self::read_identifier(&mut chars)?;
194                        if param_name.is_empty() {
195                            return Err(Error::parse_error(input, "empty parameter name"));
196                        }
197                        tokens.push(Token::Parameter(param_name));
198                    }
199                }
200                '0'..='9' | '.' => {
201                    let number = Self::read_number(&mut chars)?;
202                    tokens.push(Token::Number(number));
203                }
204                'a'..='z' | 'A'..='Z' | '_' => {
205                    let identifier = Self::read_identifier(&mut chars)?;
206                    // Check if it's followed by '(' for function call
207                    if chars.peek() == Some(&'(') {
208                        tokens.push(Token::Function(identifier));
209                    } else if Self::is_constant(&identifier) {
210                        tokens.push(Token::Constant(identifier));
211                    } else {
212                        tokens.push(Token::Parameter(identifier));
213                    }
214                }
215                _ => {
216                    return Err(Error::parse_error(
217                        input,
218                        &format!("unexpected character: '{}'", ch),
219                    ));
220                }
221            }
222        }
223
224        Ok(tokens)
225    }
226
227    /// Read a number from the character stream
228    fn read_number(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<f64> {
229        let mut number_str = String::new();
230        let mut has_dot = false;
231
232        while let Some(&ch) = chars.peek() {
233            match ch {
234                '0'..='9' => {
235                    number_str.push(ch);
236                    chars.next();
237                }
238                '.' if !has_dot => {
239                    has_dot = true;
240                    number_str.push(ch);
241                    chars.next();
242                }
243                'e' | 'E' => {
244                    // Scientific notation
245                    number_str.push(ch);
246                    chars.next();
247                    if chars.peek() == Some(&'+') || chars.peek() == Some(&'-') {
248                        number_str.push(chars.next().unwrap());
249                    }
250                }
251                _ => break,
252            }
253        }
254
255        number_str
256            .parse::<f64>()
257            .map_err(|_| Error::parse_error(&number_str, "invalid number format"))
258    }
259
260    /// Read an identifier from the character stream
261    fn read_identifier(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<String> {
262        let mut identifier = String::new();
263
264        while let Some(&ch) = chars.peek() {
265            match ch {
266                'a'..='z' | 'A'..='Z' | '0'..='9' | '_' => {
267                    identifier.push(ch);
268                    chars.next();
269                }
270                _ => break,
271            }
272        }
273
274        if identifier.is_empty() {
275            return Err(Error::parse_error(
276                "",
277                "Expected identifier but found empty string",
278            ));
279        }
280
281        Ok(identifier)
282    }
283
284    /// Read characters until the specified terminator
285    fn read_until_char(
286        chars: &mut std::iter::Peekable<std::str::Chars>,
287        terminator: char,
288    ) -> Result<String> {
289        let mut content = String::new();
290
291        while let Some(&ch) = chars.peek() {
292            if ch == terminator {
293                break;
294            }
295            content.push(ch);
296            chars.next();
297        }
298
299        Ok(content)
300    }
301
302    /// Check if identifier is a mathematical constant
303    fn is_constant(identifier: &str) -> bool {
304        matches!(identifier, "PI" | "E")
305    }
306
307    /// Parse an expression with precedence handling
308    fn parse_expression(&mut self) -> Result<Expr> {
309        self.parse_comparison()
310    }
311
312    /// Parse comparison expressions (>, <, >=, <=, ==, !=)
313    fn parse_comparison(&mut self) -> Result<Expr> {
314        let mut left = self.parse_additive()?;
315
316        while self.current < self.tokens.len() {
317            match &self.tokens[self.current] {
318                Token::Operator(Operator::Greater) => {
319                    self.current += 1;
320                    let right = self.parse_additive()?;
321                    left = Expr::BinaryOp {
322                        left: Box::new(left),
323                        operator: Operator::Greater,
324                        right: Box::new(right),
325                    };
326                }
327                Token::Operator(Operator::Less) => {
328                    self.current += 1;
329                    let right = self.parse_additive()?;
330                    left = Expr::BinaryOp {
331                        left: Box::new(left),
332                        operator: Operator::Less,
333                        right: Box::new(right),
334                    };
335                }
336                Token::Operator(Operator::GreaterEqual) => {
337                    self.current += 1;
338                    let right = self.parse_additive()?;
339                    left = Expr::BinaryOp {
340                        left: Box::new(left),
341                        operator: Operator::GreaterEqual,
342                        right: Box::new(right),
343                    };
344                }
345                Token::Operator(Operator::LessEqual) => {
346                    self.current += 1;
347                    let right = self.parse_additive()?;
348                    left = Expr::BinaryOp {
349                        left: Box::new(left),
350                        operator: Operator::LessEqual,
351                        right: Box::new(right),
352                    };
353                }
354                Token::Operator(Operator::Equal) => {
355                    self.current += 1;
356                    let right = self.parse_additive()?;
357                    left = Expr::BinaryOp {
358                        left: Box::new(left),
359                        operator: Operator::Equal,
360                        right: Box::new(right),
361                    };
362                }
363                Token::Operator(Operator::NotEqual) => {
364                    self.current += 1;
365                    let right = self.parse_additive()?;
366                    left = Expr::BinaryOp {
367                        left: Box::new(left),
368                        operator: Operator::NotEqual,
369                        right: Box::new(right),
370                    };
371                }
372                _ => break,
373            }
374        }
375
376        Ok(left)
377    }
378
379    /// Parse additive expressions (+ and -)
380    fn parse_additive(&mut self) -> Result<Expr> {
381        let mut left = self.parse_multiplicative()?;
382
383        while self.current < self.tokens.len() {
384            match &self.tokens[self.current] {
385                Token::Operator(Operator::Add) => {
386                    self.current += 1;
387                    let right = self.parse_multiplicative()?;
388                    left = Expr::BinaryOp {
389                        left: Box::new(left),
390                        operator: Operator::Add,
391                        right: Box::new(right),
392                    };
393                }
394                Token::Operator(Operator::Subtract) => {
395                    self.current += 1;
396                    let right = self.parse_multiplicative()?;
397                    left = Expr::BinaryOp {
398                        left: Box::new(left),
399                        operator: Operator::Subtract,
400                        right: Box::new(right),
401                    };
402                }
403                _ => break,
404            }
405        }
406
407        Ok(left)
408    }
409
410    /// Parse multiplicative expressions (*, /, %)
411    fn parse_multiplicative(&mut self) -> Result<Expr> {
412        let mut left = self.parse_unary()?;
413
414        while self.current < self.tokens.len() {
415            match &self.tokens[self.current] {
416                Token::Operator(Operator::Multiply) => {
417                    self.current += 1;
418                    let right = self.parse_unary()?;
419                    left = Expr::BinaryOp {
420                        left: Box::new(left),
421                        operator: Operator::Multiply,
422                        right: Box::new(right),
423                    };
424                }
425                Token::Operator(Operator::Divide) => {
426                    self.current += 1;
427                    let right = self.parse_unary()?;
428                    left = Expr::BinaryOp {
429                        left: Box::new(left),
430                        operator: Operator::Divide,
431                        right: Box::new(right),
432                    };
433                }
434                Token::Operator(Operator::Modulo) => {
435                    self.current += 1;
436                    let right = self.parse_unary()?;
437                    left = Expr::BinaryOp {
438                        left: Box::new(left),
439                        operator: Operator::Modulo,
440                        right: Box::new(right),
441                    };
442                }
443                _ => break,
444            }
445        }
446
447        Ok(left)
448    }
449
450    /// Parse unary expressions (unary minus)
451    fn parse_unary(&mut self) -> Result<Expr> {
452        if self.current < self.tokens.len() {
453            if let Token::Operator(Operator::Subtract) = &self.tokens[self.current] {
454                self.current += 1;
455                let expr = self.parse_primary()?;
456                return Ok(Expr::UnaryMinus(Box::new(expr)));
457            }
458        }
459        self.parse_primary()
460    }
461
462    /// Parse primary expressions (numbers, parameters, parentheses)
463    fn parse_primary(&mut self) -> Result<Expr> {
464        if self.current >= self.tokens.len() {
465            return Err(Error::validation_error(
466                "expression",
467                "unexpected end of expression",
468            ));
469        }
470
471        let token = &self.tokens[self.current].clone();
472        self.current += 1;
473
474        match token {
475            Token::Number(n) => Ok(Expr::Number(*n)),
476            Token::Parameter(name) => Ok(Expr::Parameter(name.clone())),
477            Token::Constant(name) => Ok(Expr::Constant(name.clone())),
478            Token::Function(name) => {
479                // Function call: function_name(arg1, arg2, ...)
480                if self.current >= self.tokens.len()
481                    || self.tokens[self.current] != Token::LeftParen
482                {
483                    return Err(Error::validation_error(
484                        "expression",
485                        "expected '(' after function name",
486                    ));
487                }
488                self.current += 1; // consume '('
489
490                let mut args = Vec::new();
491
492                // Handle empty function calls like sin()
493                if self.current < self.tokens.len()
494                    && self.tokens[self.current] != Token::RightParen
495                {
496                    loop {
497                        args.push(self.parse_expression()?);
498
499                        if self.current >= self.tokens.len() {
500                            return Err(Error::validation_error(
501                                "expression",
502                                "missing closing parenthesis in function call",
503                            ));
504                        }
505
506                        match &self.tokens[self.current] {
507                            Token::Comma => {
508                                self.current += 1; // consume comma
509                                continue;
510                            }
511                            Token::RightParen => break,
512                            _ => {
513                                return Err(Error::validation_error(
514                                    "expression",
515                                    "expected ',' or ')' in function call",
516                                ))
517                            }
518                        }
519                    }
520                }
521
522                if self.current >= self.tokens.len()
523                    || self.tokens[self.current] != Token::RightParen
524                {
525                    return Err(Error::validation_error(
526                        "expression",
527                        "missing closing parenthesis in function call",
528                    ));
529                }
530                self.current += 1; // consume ')'
531
532                Ok(Expr::FunctionCall {
533                    name: name.clone(),
534                    args,
535                })
536            }
537            Token::LeftParen => {
538                let expr = self.parse_expression()?;
539                if self.current >= self.tokens.len()
540                    || self.tokens[self.current] != Token::RightParen
541                {
542                    return Err(Error::validation_error(
543                        "expression",
544                        "missing closing parenthesis",
545                    ));
546                }
547                self.current += 1;
548                Ok(expr)
549            }
550            _ => Err(Error::validation_error(
551                "expression",
552                &format!("unexpected token: {:?}", token),
553            )),
554        }
555    }
556}
557
/// Evaluates expression ASTs against a fixed set of named parameters.
///
/// Derives `Debug` (matching [`ExpressionParser`]) and `Clone` so an
/// evaluator can be logged and duplicated.
#[derive(Debug, Clone)]
pub struct ExpressionEvaluator {
    /// Parameter name -> textual value; values are parsed as `f64` on lookup.
    parameters: HashMap<String, String>,
}
562
563impl ExpressionEvaluator {
564    /// Create a new evaluator with the given parameter context
565    pub fn new(parameters: HashMap<String, String>) -> Self {
566        Self { parameters }
567    }
568
569    /// Evaluate an expression AST to a numeric result
570    pub fn evaluate(&self, expr: &Expr) -> Result<f64> {
571        match expr {
572            Expr::Number(n) => Ok(*n),
573            Expr::Parameter(name) => {
574                let param_value = self
575                    .parameters
576                    .get(name)
577                    .ok_or_else(|| Error::parameter_error(name, "parameter not found"))?;
578
579                param_value.parse::<f64>().map_err(|e| {
580                    Error::parameter_error(
581                        name,
582                        &format!("failed to parse '{}': {}", param_value, e),
583                    )
584                })
585            }
586            Expr::Constant(name) => match name.as_str() {
587                "PI" => Ok(std::f64::consts::PI),
588                "E" => Ok(std::f64::consts::E),
589                _ => Err(Error::parameter_error(name, "unknown constant")),
590            },
591            Expr::BinaryOp {
592                left,
593                operator,
594                right,
595            } => {
596                let left_val = self.evaluate(left)?;
597                let right_val = self.evaluate(right)?;
598
599                match operator {
600                    Operator::Add => Ok(left_val + right_val),
601                    Operator::Subtract => Ok(left_val - right_val),
602                    Operator::Multiply => Ok(left_val * right_val),
603                    Operator::Divide => {
604                        if right_val == 0.0 {
605                            Err(Error::parameter_error("division", "division by zero"))
606                        } else {
607                            Ok(left_val / right_val)
608                        }
609                    }
610                    Operator::Modulo => {
611                        if right_val == 0.0 {
612                            Err(Error::parameter_error("modulo", "modulo by zero"))
613                        } else {
614                            Ok(left_val % right_val)
615                        }
616                    }
617                    Operator::Greater => Ok(if left_val > right_val { 1.0 } else { 0.0 }),
618                    Operator::Less => Ok(if left_val < right_val { 1.0 } else { 0.0 }),
619                    Operator::GreaterEqual => Ok(if left_val >= right_val { 1.0 } else { 0.0 }),
620                    Operator::LessEqual => Ok(if left_val <= right_val { 1.0 } else { 0.0 }),
621                    Operator::Equal => Ok(if (left_val - right_val).abs() < f64::EPSILON {
622                        1.0
623                    } else {
624                        0.0
625                    }),
626                    Operator::NotEqual => Ok(if (left_val - right_val).abs() >= f64::EPSILON {
627                        1.0
628                    } else {
629                        0.0
630                    }),
631                }
632            }
633            Expr::UnaryMinus(expr) => {
634                let val = self.evaluate(expr)?;
635                Ok(-val)
636            }
637            Expr::FunctionCall { name, args } => self.evaluate_function(name, args),
638        }
639    }
640
641    /// Evaluate a function call
642    fn evaluate_function(&self, name: &str, args: &[Expr]) -> Result<f64> {
643        match name {
644            "sin" => {
645                if args.len() != 1 {
646                    return Err(Error::parameter_error(
647                        name,
648                        "sin() requires exactly 1 argument",
649                    ));
650                }
651                let arg = self.evaluate(&args[0])?;
652                Ok(arg.sin())
653            }
654            "cos" => {
655                if args.len() != 1 {
656                    return Err(Error::parameter_error(
657                        name,
658                        "cos() requires exactly 1 argument",
659                    ));
660                }
661                let arg = self.evaluate(&args[0])?;
662                Ok(arg.cos())
663            }
664            "tan" => {
665                if args.len() != 1 {
666                    return Err(Error::parameter_error(
667                        name,
668                        "tan() requires exactly 1 argument",
669                    ));
670                }
671                let arg = self.evaluate(&args[0])?;
672                Ok(arg.tan())
673            }
674            "sqrt" => {
675                if args.len() != 1 {
676                    return Err(Error::parameter_error(
677                        name,
678                        "sqrt() requires exactly 1 argument",
679                    ));
680                }
681                let arg = self.evaluate(&args[0])?;
682                if arg < 0.0 {
683                    return Err(Error::parameter_error(name, "sqrt() of negative number"));
684                }
685                Ok(arg.sqrt())
686            }
687            "abs" => {
688                if args.len() != 1 {
689                    return Err(Error::parameter_error(
690                        name,
691                        "abs() requires exactly 1 argument",
692                    ));
693                }
694                let arg = self.evaluate(&args[0])?;
695                Ok(arg.abs())
696            }
697            "floor" => {
698                if args.len() != 1 {
699                    return Err(Error::parameter_error(
700                        name,
701                        "floor() requires exactly 1 argument",
702                    ));
703                }
704                let arg = self.evaluate(&args[0])?;
705                Ok(arg.floor())
706            }
707            "ceil" => {
708                if args.len() != 1 {
709                    return Err(Error::parameter_error(
710                        name,
711                        "ceil() requires exactly 1 argument",
712                    ));
713                }
714                let arg = self.evaluate(&args[0])?;
715                Ok(arg.ceil())
716            }
717            "min" => {
718                if args.len() != 2 {
719                    return Err(Error::parameter_error(
720                        name,
721                        "min() requires exactly 2 arguments",
722                    ));
723                }
724                let arg1 = self.evaluate(&args[0])?;
725                let arg2 = self.evaluate(&args[1])?;
726                Ok(arg1.min(arg2))
727            }
728            "max" => {
729                if args.len() != 2 {
730                    return Err(Error::parameter_error(
731                        name,
732                        "max() requires exactly 2 arguments",
733                    ));
734                }
735                let arg1 = self.evaluate(&args[0])?;
736                let arg2 = self.evaluate(&args[1])?;
737                Ok(arg1.max(arg2))
738            }
739            _ => Err(Error::parameter_error(name, "unknown function")),
740        }
741    }
742}
743
744/// Parse and evaluate an OpenSCENARIO expression
745pub fn evaluate_expression<T>(expr: &str, params: &HashMap<String, String>) -> Result<T>
746where
747    T: FromStr,
748    T::Err: std::fmt::Display,
749{
750    let mut parser = ExpressionParser::new(expr)?;
751    let ast = parser.parse()?;
752    let evaluator = ExpressionEvaluator::new(params.clone());
753    let result = evaluator.evaluate(&ast)?;
754
755    // Convert the numeric result to the target type
756    let result_str = result.to_string();
757    result_str.parse::<T>().map_err(|e| {
758        Error::parameter_error(
759            expr,
760            &format!("failed to parse result '{}': {}", result_str, e),
761        )
762    })
763}
764
765#[cfg(test)]
766mod tests {
767    use super::*;
768
769    #[test]
770    fn test_tokenize_numbers() {
771        let tokens = ExpressionParser::tokenize("123.45").unwrap();
772        assert_eq!(tokens, vec![Token::Number(123.45)]);
773
774        let tokens = ExpressionParser::tokenize("1.5e-3").unwrap();
775        assert_eq!(tokens, vec![Token::Number(0.0015)]);
776    }
777
778    #[test]
779    fn test_tokenize_operators() {
780        let tokens = ExpressionParser::tokenize("+ - * / %").unwrap();
781        assert_eq!(
782            tokens,
783            vec![
784                Token::Operator(Operator::Add),
785                Token::Operator(Operator::Subtract),
786                Token::Operator(Operator::Multiply),
787                Token::Operator(Operator::Divide),
788                Token::Operator(Operator::Modulo),
789            ]
790        );
791    }
792
793    #[test]
794    fn test_tokenize_parameters() {
795        let tokens = ExpressionParser::tokenize("${speed} + $velocity").unwrap();
796        assert_eq!(
797            tokens,
798            vec![
799                Token::Parameter("speed".to_string()),
800                Token::Operator(Operator::Add),
801                Token::Parameter("velocity".to_string()),
802            ]
803        );
804    }
805
806    #[test]
807    fn test_parse_simple_expression() {
808        let mut parser = ExpressionParser::new("2 + 3").unwrap();
809        let ast = parser.parse().unwrap();
810
811        match ast {
812            Expr::BinaryOp {
813                left,
814                operator,
815                right,
816            } => {
817                assert_eq!(*left, Expr::Number(2.0));
818                assert_eq!(operator, Operator::Add);
819                assert_eq!(*right, Expr::Number(3.0));
820            }
821            _ => panic!("Expected binary operation"),
822        }
823    }
824
825    #[test]
826    fn test_parse_precedence() {
827        let mut parser = ExpressionParser::new("2 + 3 * 4").unwrap();
828        let ast = parser.parse().unwrap();
829
830        // Should parse as 2 + (3 * 4), not (2 + 3) * 4
831        match ast {
832            Expr::BinaryOp {
833                left,
834                operator,
835                right,
836            } => {
837                assert_eq!(*left, Expr::Number(2.0));
838                assert_eq!(operator, Operator::Add);
839                match *right {
840                    Expr::BinaryOp {
841                        left,
842                        operator,
843                        right,
844                    } => {
845                        assert_eq!(*left, Expr::Number(3.0));
846                        assert_eq!(operator, Operator::Multiply);
847                        assert_eq!(*right, Expr::Number(4.0));
848                    }
849                    _ => panic!("Expected multiplication on right side"),
850                }
851            }
852            _ => panic!("Expected addition at root"),
853        }
854    }
855
856    #[test]
857    fn test_parse_parentheses() {
858        let mut parser = ExpressionParser::new("(2 + 3) * 4").unwrap();
859        let ast = parser.parse().unwrap();
860
861        // Should parse as (2 + 3) * 4
862        match ast {
863            Expr::BinaryOp {
864                left,
865                operator,
866                right,
867            } => {
868                assert_eq!(operator, Operator::Multiply);
869                assert_eq!(*right, Expr::Number(4.0));
870                match *left {
871                    Expr::BinaryOp {
872                        left,
873                        operator,
874                        right,
875                    } => {
876                        assert_eq!(*left, Expr::Number(2.0));
877                        assert_eq!(operator, Operator::Add);
878                        assert_eq!(*right, Expr::Number(3.0));
879                    }
880                    _ => panic!("Expected addition in parentheses"),
881                }
882            }
883            _ => panic!("Expected multiplication at root"),
884        }
885    }
886
887    #[test]
888    fn test_evaluate_simple() {
889        let params = HashMap::new();
890        let evaluator = ExpressionEvaluator::new(params);
891
892        let expr = Expr::BinaryOp {
893            left: Box::new(Expr::Number(2.0)),
894            operator: Operator::Add,
895            right: Box::new(Expr::Number(3.0)),
896        };
897
898        let result = evaluator.evaluate(&expr).unwrap();
899        assert_eq!(result, 5.0);
900    }
901
902    #[test]
903    fn test_evaluate_with_parameters() {
904        let mut params = HashMap::new();
905        params.insert("speed".to_string(), "30.0".to_string());
906        params.insert("acceleration".to_string(), "2.5".to_string());
907
908        let evaluator = ExpressionEvaluator::new(params);
909
910        let expr = Expr::BinaryOp {
911            left: Box::new(Expr::Parameter("speed".to_string())),
912            operator: Operator::Add,
913            right: Box::new(Expr::Parameter("acceleration".to_string())),
914        };
915
916        let result = evaluator.evaluate(&expr).unwrap();
917        assert_eq!(result, 32.5);
918    }
919
920    #[test]
921    fn test_evaluate_division_by_zero() {
922        let params = HashMap::new();
923        let evaluator = ExpressionEvaluator::new(params);
924
925        let expr = Expr::BinaryOp {
926            left: Box::new(Expr::Number(5.0)),
927            operator: Operator::Divide,
928            right: Box::new(Expr::Number(0.0)),
929        };
930
931        let result = evaluator.evaluate(&expr);
932        assert!(result.is_err());
933    }
934
935    #[test]
936    fn test_end_to_end_evaluation() {
937        let mut params = HashMap::new();
938        params.insert("speed".to_string(), "30.0".to_string());
939        params.insert("time".to_string(), "2.0".to_string());
940
941        // Test: speed * time + 10
942        let result: f64 = evaluate_expression("${speed} * ${time} + 10", &params).unwrap();
943        assert_eq!(result, 70.0);
944
945        // Test: (speed + 10) / time
946        let result: f64 = evaluate_expression("(${speed} + 10) / ${time}", &params).unwrap();
947        assert_eq!(result, 20.0);
948    }
949
950    #[test]
951    fn test_complex_expression() {
952        let mut params = HashMap::new();
953        params.insert("a".to_string(), "5.0".to_string());
954        params.insert("b".to_string(), "3.0".to_string());
955        params.insert("c".to_string(), "2.0".to_string());
956
957        // Test: a * (b + c) - b / c
958        let result: f64 =
959            evaluate_expression("${a} * (${b} + ${c}) - ${b} / ${c}", &params).unwrap();
960        assert_eq!(result, 23.5); // 5 * (3 + 2) - 3 / 2 = 25 - 1.5 = 23.5
961    }
962
963    #[test]
964    fn test_unary_minus() {
965        let mut params = HashMap::new();
966        params.insert("value".to_string(), "10.0".to_string());
967
968        let result: f64 = evaluate_expression("-${value} + 5", &params).unwrap();
969        assert_eq!(result, -5.0);
970
971        let result: f64 = evaluate_expression("-(${value} + 5)", &params).unwrap();
972        assert_eq!(result, -15.0);
973    }
974
975    #[test]
976    fn test_mathematical_constants() {
977        let params = HashMap::new();
978
979        // Test PI constant
980        let result: f64 = evaluate_expression("PI", &params).unwrap();
981        assert!((result - std::f64::consts::PI).abs() < f64::EPSILON);
982
983        // Test E constant
984        let result: f64 = evaluate_expression("E", &params).unwrap();
985        assert!((result - std::f64::consts::E).abs() < f64::EPSILON);
986
987        // Test constants in expressions
988        let result: f64 = evaluate_expression("2 * PI", &params).unwrap();
989        assert!((result - 2.0 * std::f64::consts::PI).abs() < f64::EPSILON);
990    }
991
992    #[test]
993    fn test_trigonometric_functions() {
994        let params = HashMap::new();
995
996        // Test sin function
997        let result: f64 = evaluate_expression("sin(0)", &params).unwrap();
998        assert!((result - 0.0).abs() < f64::EPSILON);
999
1000        let result: f64 = evaluate_expression("sin(PI / 2)", &params).unwrap();
1001        assert!((result - 1.0).abs() < 1e-10);
1002
1003        // Test cos function
1004        let result: f64 = evaluate_expression("cos(0)", &params).unwrap();
1005        assert!((result - 1.0).abs() < f64::EPSILON);
1006
1007        let result: f64 = evaluate_expression("cos(PI)", &params).unwrap();
1008        assert!((result - (-1.0)).abs() < 1e-10);
1009
1010        // Test tan function
1011        let result: f64 = evaluate_expression("tan(0)", &params).unwrap();
1012        assert!((result - 0.0).abs() < f64::EPSILON);
1013    }
1014
1015    #[test]
1016    fn test_mathematical_functions() {
1017        let params = HashMap::new();
1018
1019        // Test sqrt function
1020        let result: f64 = evaluate_expression("sqrt(4)", &params).unwrap();
1021        assert_eq!(result, 2.0);
1022
1023        let result: f64 = evaluate_expression("sqrt(9)", &params).unwrap();
1024        assert_eq!(result, 3.0);
1025
1026        // Test abs function
1027        let result: f64 = evaluate_expression("abs(-5)", &params).unwrap();
1028        assert_eq!(result, 5.0);
1029
1030        let result: f64 = evaluate_expression("abs(3.14)", &params).unwrap();
1031        assert_eq!(result, 3.14);
1032
1033        // Test floor function
1034        let result: f64 = evaluate_expression("floor(3.7)", &params).unwrap();
1035        assert_eq!(result, 3.0);
1036
1037        let result: f64 = evaluate_expression("floor(-2.3)", &params).unwrap();
1038        assert_eq!(result, -3.0);
1039
1040        // Test ceil function
1041        let result: f64 = evaluate_expression("ceil(3.2)", &params).unwrap();
1042        assert_eq!(result, 4.0);
1043
1044        let result: f64 = evaluate_expression("ceil(-2.7)", &params).unwrap();
1045        assert_eq!(result, -2.0);
1046    }
1047
1048    #[test]
1049    fn test_min_max_functions() {
1050        let params = HashMap::new();
1051
1052        // Test min function
1053        let result: f64 = evaluate_expression("min(5, 3)", &params).unwrap();
1054        assert_eq!(result, 3.0);
1055
1056        let result: f64 = evaluate_expression("min(-2, 1)", &params).unwrap();
1057        assert_eq!(result, -2.0);
1058
1059        // Test max function
1060        let result: f64 = evaluate_expression("max(5, 3)", &params).unwrap();
1061        assert_eq!(result, 5.0);
1062
1063        let result: f64 = evaluate_expression("max(-2, 1)", &params).unwrap();
1064        assert_eq!(result, 1.0);
1065    }
1066
1067    #[test]
1068    fn test_comparison_operators() {
1069        let params = HashMap::new();
1070
1071        // Test greater than
1072        let result: f64 = evaluate_expression("5 > 3", &params).unwrap();
1073        assert_eq!(result, 1.0);
1074
1075        let result: f64 = evaluate_expression("2 > 5", &params).unwrap();
1076        assert_eq!(result, 0.0);
1077
1078        // Test less than
1079        let result: f64 = evaluate_expression("3 < 5", &params).unwrap();
1080        assert_eq!(result, 1.0);
1081
1082        let result: f64 = evaluate_expression("5 < 3", &params).unwrap();
1083        assert_eq!(result, 0.0);
1084
1085        // Test greater than or equal
1086        let result: f64 = evaluate_expression("5 >= 5", &params).unwrap();
1087        assert_eq!(result, 1.0);
1088
1089        let result: f64 = evaluate_expression("5 >= 3", &params).unwrap();
1090        assert_eq!(result, 1.0);
1091
1092        let result: f64 = evaluate_expression("3 >= 5", &params).unwrap();
1093        assert_eq!(result, 0.0);
1094
1095        // Test less than or equal
1096        let result: f64 = evaluate_expression("3 <= 3", &params).unwrap();
1097        assert_eq!(result, 1.0);
1098
1099        let result: f64 = evaluate_expression("3 <= 5", &params).unwrap();
1100        assert_eq!(result, 1.0);
1101
1102        let result: f64 = evaluate_expression("5 <= 3", &params).unwrap();
1103        assert_eq!(result, 0.0);
1104
1105        // Test equality
1106        let result: f64 = evaluate_expression("5 == 5", &params).unwrap();
1107        assert_eq!(result, 1.0);
1108
1109        let result: f64 = evaluate_expression("5 == 3", &params).unwrap();
1110        assert_eq!(result, 0.0);
1111
1112        // Test inequality
1113        let result: f64 = evaluate_expression("5 != 3", &params).unwrap();
1114        assert_eq!(result, 1.0);
1115
1116        let result: f64 = evaluate_expression("5 != 5", &params).unwrap();
1117        assert_eq!(result, 0.0);
1118    }
1119
1120    #[test]
1121    fn test_complex_expressions_with_functions() {
1122        let mut params = HashMap::new();
1123        params.insert("angle".to_string(), "0.5".to_string());
1124        params.insert("radius".to_string(), "10.0".to_string());
1125
1126        // Test complex trigonometric expression: radius * sin(angle)
1127        let result: f64 = evaluate_expression("${radius} * sin(${angle})", &params).unwrap();
1128        assert!((result - 10.0 * 0.5_f64.sin()).abs() < 1e-10);
1129
1130        // Test with constants: 2 * PI * radius
1131        let result: f64 = evaluate_expression("2 * PI * ${radius}", &params).unwrap();
1132        assert!((result - 2.0 * std::f64::consts::PI * 10.0).abs() < 1e-10);
1133
1134        // Test nested functions: sqrt(abs(-16))
1135        let result: f64 = evaluate_expression("sqrt(abs(-16))", &params).unwrap();
1136        assert_eq!(result, 4.0);
1137    }
1138
1139    #[test]
1140    fn test_function_error_handling() {
1141        let params = HashMap::new();
1142
1143        // Test wrong number of arguments
1144        assert!(evaluate_expression::<f64>("sin(1, 2)", &params).is_err());
1145        assert!(evaluate_expression::<f64>("sqrt()", &params).is_err());
1146        assert!(evaluate_expression::<f64>("min(5)", &params).is_err());
1147        assert!(evaluate_expression::<f64>("max(1, 2, 3)", &params).is_err());
1148
1149        // Test unknown function
1150        assert!(evaluate_expression::<f64>("unknown_func(5)", &params).is_err());
1151
1152        // Test sqrt of negative number
1153        assert!(evaluate_expression::<f64>("sqrt(-1)", &params).is_err());
1154
1155        // Test unknown constant
1156        assert!(evaluate_expression::<f64>("UNKNOWN_CONSTANT", &params).is_err());
1157    }
1158
1159    #[test]
1160    fn test_complex_automotive_scenarios() {
1161        let mut params = HashMap::new();
1162        params.insert("current_speed".to_string(), "50.0".to_string());
1163        params.insert("target_speed".to_string(), "30.0".to_string());
1164        params.insert("deceleration".to_string(), "3.0".to_string());
1165        params.insert("reaction_time".to_string(), "1.5".to_string());
1166
1167        // Test braking distance calculation: speed^2 / (2 * deceleration)
1168        // Note: Using multiplication instead of exponentiation for now
1169        let result: f64 = evaluate_expression(
1170            "(${current_speed} * ${current_speed}) / (2 * ${deceleration})",
1171            &params,
1172        )
1173        .unwrap();
1174        assert!((result - (50.0 * 50.0) / (2.0 * 3.0)).abs() < 1e-10);
1175
1176        // Test speed comparison: current_speed > target_speed
1177        let result: f64 =
1178            evaluate_expression("${current_speed} > ${target_speed}", &params).unwrap();
1179        assert_eq!(result, 1.0);
1180
1181        // Test time-based calculation with functions
1182        let result: f64 = evaluate_expression(
1183            "max(${reaction_time}, min(5.0, abs(${current_speed} - ${target_speed}) / 10.0))",
1184            &params,
1185        )
1186        .unwrap();
1187        assert_eq!(result, 2.0); // max(1.5, min(5.0, 20.0 / 10.0)) = max(1.5, 2.0) = 2.0
1188    }
1189}