expression_core/
expression.rs

1use crate::parsing::*;
2use std::collections::HashMap;
3
/// A node in a mathematical expression tree.
///
/// Leaves are numeric literals and named variables; interior nodes are the
/// binary arithmetic operators and the unary trigonometric functions. Child
/// nodes are boxed so that this recursive enum has a fixed size. Trees can be
/// built directly, with the helpers in [`expr`], or by parsing text with
/// [`Expression::parse`], and are evaluated with [`Expression::evaluate`].
///
/// `Clone` is derived so a tree can be duplicated (e.g. to keep an original
/// around while building a modified copy).
#[derive(Debug, Clone, PartialEq)]
pub enum Expression {
    /// A numeric literal.
    Number(f64),
    /// A named variable, looked up in the variable map at evaluation time.
    Variable(String),
    /// `a + b`.
    Add(Box<Expression>, Box<Expression>),
    /// `a - b`.
    Subtract(Box<Expression>, Box<Expression>),
    /// `a * b`.
    Multiply(Box<Expression>, Box<Expression>),
    /// `a / b`; evaluation fails when `b` evaluates to zero.
    Divide(Box<Expression>, Box<Expression>),
    /// `base ^ exponent`.
    Power(Box<Expression>, Box<Expression>),
    /// Sine of the argument (radians).
    Sin(Box<Expression>),
    /// Cosine of the argument (radians).
    Cos(Box<Expression>),
    /// Tangent of the argument (radians).
    Tan(Box<Expression>),
    /// Inverse sine; evaluation fails for arguments outside `[-1, 1]`.
    ArcSin(Box<Expression>),
    /// Inverse cosine; evaluation fails for arguments outside `[-1, 1]`.
    ArcCos(Box<Expression>),
    /// Inverse tangent.
    ArcTan(Box<Expression>),
}
20
21impl Expression {
22    pub fn evaluate(&self, variables: &HashMap<String, f64>) -> Result<f64, String> {
23        match self {
24            Expression::Number(n) => Ok(*n),
25            Expression::Variable(name) => variables
26                .get(name)
27                .copied()
28                .ok_or(format!("Variable '{}' not found", name)),
29            Expression::Add(a, b) => Ok(a.evaluate(variables)? + b.evaluate(variables)?),
30            Expression::Subtract(a, b) => Ok(a.evaluate(variables)? - b.evaluate(variables)?),
31            Expression::Multiply(a, b) => Ok(a.evaluate(variables)? * b.evaluate(variables)?),
32            Expression::Divide(a, b) => {
33                let denominator = b.evaluate(variables)?;
34                if denominator == 0.0 {
35                    return Err("Division by 0".to_string());
36                }
37                Ok(a.evaluate(variables)? / denominator)
38            }
39            Expression::Power(base, exponent) => Ok(base
40                .evaluate(variables)?
41                .powf(exponent.evaluate(variables)?)),
42            Expression::Sin(expr) => Ok(expr.evaluate(variables)?.sin()),
43            Expression::Cos(expr) => Ok(expr.evaluate(variables)?.cos()),
44            Expression::Tan(expr) => Ok(expr.evaluate(variables)?.tan()),
45            Expression::ArcSin(expr) => {
46                let val = expr.evaluate(variables)?;
47                if val < -1.0 || val > 1.0 {
48                    Err("Domain error: arcsin argument must be between -1 and 1".to_string())
49                } else {
50                    Ok(val.asin())
51                }
52            }
53            Expression::ArcCos(expr) => {
54                let val = expr.evaluate(variables)?;
55                if val < -1.0 || val > 1.0 {
56                    Err("Domain error: arccos argument must be between -1 and 1".to_string())
57                } else {
58                    Ok(val.acos())
59                }
60            }
61            Expression::ArcTan(expr) => Ok(expr.evaluate(variables)?.atan()),
62        }
63    }
64
65    pub fn parse(input: &str) -> Result<Expression, String> {
66        Parser::new(tokenize(input)?).parse_expression()
67    }
68}
69
70pub mod expr {
71    use super::Expression;
72
73    pub fn number(n: f64) -> Expression {
74        Expression::Number(n)
75    }
76
77    pub fn variable(name: &str) -> Expression {
78        Expression::Variable(name.to_string())
79    }
80
81    pub fn add(a: Expression, b: Expression) -> Expression {
82        Expression::Add(Box::new(a), Box::new(b))
83    }
84
85    pub fn subtract(a: Expression, b: Expression) -> Expression {
86        Expression::Subtract(Box::new(a), Box::new(b))
87    }
88
89    pub fn multiply(a: Expression, b: Expression) -> Expression {
90        Expression::Multiply(Box::new(a), Box::new(b))
91    }
92
93    pub fn divide(a: Expression, b: Expression) -> Expression {
94        Expression::Divide(Box::new(a), Box::new(b))
95    }
96
97    pub fn power(base: Expression, power: Expression) -> Expression {
98        Expression::Power(Box::new(base), Box::new(power))
99    }
100
101    pub fn sin(expr: Expression) -> Expression {
102        Expression::Sin(Box::new(expr))
103    }
104
105    pub fn cos(expr: Expression) -> Expression {
106        Expression::Cos(Box::new(expr))
107    }
108
109    pub fn tan(expr: Expression) -> Expression {
110        Expression::Tan(Box::new(expr))
111    }
112
113    pub fn arcsin(expr: Expression) -> Expression {
114        Expression::ArcSin(Box::new(expr))
115    }
116
117    pub fn arccos(expr: Expression) -> Expression {
118        Expression::ArcCos(Box::new(expr))
119    }
120
121    pub fn arctan(expr: Expression) -> Expression {
122        Expression::ArcTan(Box::new(expr))
123    }
124}
125
/// Unit tests covering tokenization, parsing, operator precedence, evaluation
/// and error reporting.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Absolute tolerance for results that pass through `sin`/`cos`, whose
    /// last-bit rounding the standard library does not guarantee.
    const EPSILON: f64 = 1e-12;

    /// Parses `input` through the public entry point, so the tests exercise
    /// the same tokenize-then-parse pipeline as callers of `Expression::parse`.
    fn parse_str(input: &str) -> Result<Expression, String> {
        Expression::parse(input)
    }

    /// Asserts that `actual` is within `EPSILON` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_tokenize() {
        assert_eq!(
            tokenize("1 + 2").unwrap(),
            vec![Token::Number(1.0), Token::Plus, Token::Number(2.0)]
        );

        assert_eq!(
            tokenize("sin(x)").unwrap(),
            vec![
                Token::Function("sin".to_string()),
                Token::LParen,
                Token::Variable("x".to_string()),
                Token::RParen
            ]
        );

        assert_eq!(
            tokenize("x^2 + y^2").unwrap(),
            vec![
                Token::Variable("x".to_string()),
                Token::Caret,
                Token::Number(2.0),
                Token::Plus,
                Token::Variable("y".to_string()),
                Token::Caret,
                Token::Number(2.0)
            ]
        );

        // Surrounding and repeated whitespace is ignored.
        assert_eq!(
            tokenize("  1  +  2  ").unwrap(),
            vec![Token::Number(1.0), Token::Plus, Token::Number(2.0)]
        );
    }

    #[test]
    fn test_parse_numbers_and_variables() {
        assert_eq!(parse_str("42").unwrap(), Expression::Number(42.0));
        assert_eq!(
            parse_str("3.141592653589793").unwrap(),
            Expression::Number(std::f64::consts::PI)
        );
        assert_eq!(parse_str("x").unwrap(), Expression::Variable("x".to_string()));
    }

    #[test]
    fn test_parse_operators() {
        // Addition
        assert_eq!(
            parse_str("1 + 2").unwrap(),
            Expression::Add(Box::new(Expression::Number(1.0)), Box::new(Expression::Number(2.0)))
        );

        // Subtraction
        assert_eq!(
            parse_str("5 - 3").unwrap(),
            Expression::Subtract(
                Box::new(Expression::Number(5.0)),
                Box::new(Expression::Number(3.0))
            )
        );

        // Multiplication
        assert_eq!(
            parse_str("2 * 3").unwrap(),
            Expression::Multiply(
                Box::new(Expression::Number(2.0)),
                Box::new(Expression::Number(3.0))
            )
        );

        // Division
        assert_eq!(
            parse_str("6 / 2").unwrap(),
            Expression::Divide(
                Box::new(Expression::Number(6.0)),
                Box::new(Expression::Number(2.0))
            )
        );

        // Power
        assert_eq!(
            parse_str("2 ^ 3").unwrap(),
            Expression::Power(
                Box::new(Expression::Number(2.0)),
                Box::new(Expression::Number(3.0))
            )
        );
    }

    #[test]
    fn test_parse_precedence() {
        let vars = HashMap::new();

        // Multiplication binds tighter than addition: 1 + (2 * 3) = 7
        let expr = parse_str("1 + 2 * 3").unwrap();
        assert_eq!(expr.evaluate(&vars).unwrap(), 7.0);

        // Parentheses override precedence: (1 + 2) * 3 = 9
        let expr = parse_str("(1 + 2) * 3").unwrap();
        assert_eq!(expr.evaluate(&vars).unwrap(), 9.0);

        // Mixed operators: 2 * 3 + 4 ^ 2 / 2 = 6 + 8 = 14
        let expr = parse_str("2 * 3 + 4 ^ 2 / 2").unwrap();
        assert_eq!(expr.evaluate(&vars).unwrap(), 14.0);
    }

    #[test]
    fn test_parse_functions() {
        // Each function name maps to its own variant.
        assert!(matches!(parse_str("sin(0)").unwrap(), Expression::Sin(_)));
        assert!(matches!(parse_str("cos(0)").unwrap(), Expression::Cos(_)));
        assert!(matches!(parse_str("tan(0)").unwrap(), Expression::Tan(_)));
        assert!(matches!(parse_str("arcsin(0)").unwrap(), Expression::ArcSin(_)));
        assert!(matches!(parse_str("arccos(1)").unwrap(), Expression::ArcCos(_)));
        assert!(matches!(parse_str("arctan(0)").unwrap(), Expression::ArcTan(_)));

        // A function argument may be a full expression.
        assert!(matches!(parse_str("sin(x + 1)").unwrap(), Expression::Sin(_)));
    }

    #[test]
    fn test_parse_complex_expressions() {
        // Nested expressions: (1 + 2) * (3 - 4)^2 = 3 * 1 = 3
        let expr = parse_str("(1 + 2) * (3 - 4) ^ 2").unwrap();
        let mut vars = HashMap::new();
        assert_eq!(expr.evaluate(&vars).unwrap(), 3.0);

        // Variables: 3^2 + 4^2 = 25
        let expr = parse_str("x^2 + y^2").unwrap();
        vars.insert("x".to_string(), 3.0);
        vars.insert("y".to_string(), 4.0);
        assert_eq!(expr.evaluate(&vars).unwrap(), 25.0);

        // Pythagorean identity. Compare with a tolerance: sin/cos results are
        // rounded, so the sum need not be exactly 1.0.
        let expr = parse_str("sin(x)^2 + cos(x)^2").unwrap();
        vars.insert("x".to_string(), 0.5);
        assert_close(expr.evaluate(&vars).unwrap(), 1.0);
    }

    #[test]
    fn test_parse_errors() {
        // Missing closing parenthesis
        assert!(parse_str("(1 + 2").is_err());

        // Unknown function
        assert!(parse_str("unknown(x)").is_err());

        // Invalid syntax
        assert!(parse_str("1 + + 2").is_err());

        // Empty expression
        assert!(parse_str("").is_err());

        // Expected parenthesis after function
        assert!(parse_str("sin x").is_err());

        // Invalid characters
        assert!(tokenize("1 @ 2").is_err());
    }

    #[test]
    fn test_domain_errors() {
        let no_vars = HashMap::new();

        // Inverse sine and cosine are only defined on [-1, 1].
        for input in ["arcsin(2)", "arccos(2)"].iter() {
            let expr = parse_str(input).unwrap();
            assert!(
                expr.evaluate(&no_vars).is_err(),
                "{} should be a domain error",
                input
            );
        }

        // Division by zero is reported as an error rather than returning infinity.
        let expr = parse_str("1 / (x - 2)").unwrap();
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 2.0);
        assert_eq!(expr.evaluate(&vars), Err("Division by 0".to_string()));
    }

    #[test]
    fn test_missing_variable() {
        let expr = parse_str("x + 1").unwrap();
        assert_eq!(
            expr.evaluate(&HashMap::new()),
            Err("Variable 'x' not found".to_string())
        );
    }
}