1use crate::parsing::*;
2use std::collections::HashMap;
3
/// A numeric expression tree made of literals, variables, arithmetic operators
/// and (inverse) trigonometric functions.
///
/// Trees are produced by [`Expression::parse`] or built by hand with the helpers in
/// the [`expr`] module, and are computed with [`Expression::evaluate`].
#[derive(Debug, Clone, PartialEq)]
pub enum Expression {
    /// A numeric literal.
    Number(f64),
    /// A named variable whose value is looked up at evaluation time.
    Variable(String),
    /// `lhs + rhs`.
    Add(Box<Expression>, Box<Expression>),
    /// `lhs - rhs`.
    Subtract(Box<Expression>, Box<Expression>),
    /// `lhs * rhs`.
    Multiply(Box<Expression>, Box<Expression>),
    /// `lhs / rhs`; a divisor that evaluates to zero is an evaluation error.
    Divide(Box<Expression>, Box<Expression>),
    /// `base ^ exponent`.
    Power(Box<Expression>, Box<Expression>),
    /// Sine of the argument (radians).
    Sin(Box<Expression>),
    /// Cosine of the argument (radians).
    Cos(Box<Expression>),
    /// Tangent of the argument (radians).
    Tan(Box<Expression>),
    /// Arcsine; the argument must lie in `[-1, 1]`.
    ArcSin(Box<Expression>),
    /// Arccosine; the argument must lie in `[-1, 1]`.
    ArcCos(Box<Expression>),
    /// Arctangent of the argument.
    ArcTan(Box<Expression>),
}
20
21impl Expression {
22 pub fn evaluate(&self, variables: &HashMap<String, f64>) -> Result<f64, String> {
23 match self {
24 Expression::Number(n) => Ok(*n),
25 Expression::Variable(name) => variables
26 .get(name)
27 .copied()
28 .ok_or(format!("Variable '{}' not found", name)),
29 Expression::Add(a, b) => Ok(a.evaluate(variables)? + b.evaluate(variables)?),
30 Expression::Subtract(a, b) => Ok(a.evaluate(variables)? - b.evaluate(variables)?),
31 Expression::Multiply(a, b) => Ok(a.evaluate(variables)? * b.evaluate(variables)?),
32 Expression::Divide(a, b) => {
33 let denominator = b.evaluate(variables)?;
34 if denominator == 0.0 {
35 return Err("Division by 0".to_string());
36 }
37 Ok(a.evaluate(variables)? / denominator)
38 }
39 Expression::Power(base, exponent) => Ok(base
40 .evaluate(variables)?
41 .powf(exponent.evaluate(variables)?)),
42 Expression::Sin(expr) => Ok(expr.evaluate(variables)?.sin()),
43 Expression::Cos(expr) => Ok(expr.evaluate(variables)?.cos()),
44 Expression::Tan(expr) => Ok(expr.evaluate(variables)?.tan()),
45 Expression::ArcSin(expr) => {
46 let val = expr.evaluate(variables)?;
47 if val < -1.0 || val > 1.0 {
48 Err("Domain error: arcsin argument must be between -1 and 1".to_string())
49 } else {
50 Ok(val.asin())
51 }
52 }
53 Expression::ArcCos(expr) => {
54 let val = expr.evaluate(variables)?;
55 if val < -1.0 || val > 1.0 {
56 Err("Domain error: arccos argument must be between -1 and 1".to_string())
57 } else {
58 Ok(val.acos())
59 }
60 }
61 Expression::ArcTan(expr) => Ok(expr.evaluate(variables)?.atan()),
62 }
63 }
64
65 pub fn parse(input: &str) -> Result<Expression, String> {
66 Parser::new(tokenize(input)?).parse_expression()
67 }
68}
69
70pub mod expr {
71 use super::Expression;
72
73 pub fn number(n: f64) -> Expression {
74 Expression::Number(n)
75 }
76
77 pub fn variable(name: &str) -> Expression {
78 Expression::Variable(name.to_string())
79 }
80
81 pub fn add(a: Expression, b: Expression) -> Expression {
82 Expression::Add(Box::new(a), Box::new(b))
83 }
84
85 pub fn subtract(a: Expression, b: Expression) -> Expression {
86 Expression::Subtract(Box::new(a), Box::new(b))
87 }
88
89 pub fn multiply(a: Expression, b: Expression) -> Expression {
90 Expression::Multiply(Box::new(a), Box::new(b))
91 }
92
93 pub fn divide(a: Expression, b: Expression) -> Expression {
94 Expression::Divide(Box::new(a), Box::new(b))
95 }
96
97 pub fn power(base: Expression, power: Expression) -> Expression {
98 Expression::Power(Box::new(base), Box::new(power))
99 }
100
101 pub fn sin(expr: Expression) -> Expression {
102 Expression::Sin(Box::new(expr))
103 }
104
105 pub fn cos(expr: Expression) -> Expression {
106 Expression::Cos(Box::new(expr))
107 }
108
109 pub fn tan(expr: Expression) -> Expression {
110 Expression::Tan(Box::new(expr))
111 }
112
113 pub fn arcsin(expr: Expression) -> Expression {
114 Expression::ArcSin(Box::new(expr))
115 }
116
117 pub fn arccos(expr: Expression) -> Expression {
118 Expression::ArcCos(Box::new(expr))
119 }
120
121 pub fn arctan(expr: Expression) -> Expression {
122 Expression::ArcTan(Box::new(expr))
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Tolerance for results that pass through transcendental functions and are
    /// therefore not guaranteed to be bit-exact.
    const EPSILON: f64 = 1e-12;

    /// Tokenizes and parses `input` through the lower-level parsing API.
    fn parse_str(input: &str) -> Result<Expression, String> {
        let tokens = tokenize(input)?;
        Parser::new(tokens).parse_expression()
    }

    /// Asserts that `actual` is within [`EPSILON`] of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_tokenize() {
        assert_eq!(
            tokenize("1 + 2").unwrap(),
            vec![Token::Number(1.0), Token::Plus, Token::Number(2.0)]
        );

        assert_eq!(
            tokenize("sin(x)").unwrap(),
            vec![
                Token::Function("sin".to_string()),
                Token::LParen,
                Token::Variable("x".to_string()),
                Token::RParen
            ]
        );

        assert_eq!(
            tokenize("x^2 + y^2").unwrap(),
            vec![
                Token::Variable("x".to_string()),
                Token::Caret,
                Token::Number(2.0),
                Token::Plus,
                Token::Variable("y".to_string()),
                Token::Caret,
                Token::Number(2.0)
            ]
        );

        assert_eq!(
            tokenize(" 1 + 2 ").unwrap(),
            vec![Token::Number(1.0), Token::Plus, Token::Number(2.0)]
        );
    }

    #[test]
    fn test_parse_numbers_and_variables() {
        assert_eq!(parse_str("42").unwrap(), Expression::Number(42.0));
        assert_eq!(
            parse_str("3.141592653589793").unwrap(),
            Expression::Number(std::f64::consts::PI)
        );
        assert_eq!(parse_str("x").unwrap(), Expression::Variable("x".to_string()));
    }

    #[test]
    fn test_parse_operators() {
        assert_eq!(
            parse_str("1 + 2").unwrap(),
            Expression::Add(Box::new(Expression::Number(1.0)), Box::new(Expression::Number(2.0)))
        );

        assert_eq!(
            parse_str("5 - 3").unwrap(),
            Expression::Subtract(
                Box::new(Expression::Number(5.0)),
                Box::new(Expression::Number(3.0))
            )
        );

        assert_eq!(
            parse_str("2 * 3").unwrap(),
            Expression::Multiply(
                Box::new(Expression::Number(2.0)),
                Box::new(Expression::Number(3.0))
            )
        );

        assert_eq!(
            parse_str("6 / 2").unwrap(),
            Expression::Divide(
                Box::new(Expression::Number(6.0)),
                Box::new(Expression::Number(2.0))
            )
        );

        assert_eq!(
            parse_str("2 ^ 3").unwrap(),
            Expression::Power(
                Box::new(Expression::Number(2.0)),
                Box::new(Expression::Number(3.0))
            )
        );
    }

    #[test]
    fn test_parse_precedence() {
        let vars = HashMap::new();

        let expr = parse_str("1 + 2 * 3").unwrap();
        assert_eq!(expr.evaluate(&vars).unwrap(), 7.0);

        let expr = parse_str("(1 + 2) * 3").unwrap();
        assert_eq!(expr.evaluate(&vars).unwrap(), 9.0);

        let expr = parse_str("2 * 3 + 4 ^ 2 / 2").unwrap();
        assert_eq!(expr.evaluate(&vars).unwrap(), 14.0);
    }

    #[test]
    fn test_parse_functions() {
        assert!(matches!(parse_str("sin(0)").unwrap(), Expression::Sin(_)));
        assert!(matches!(parse_str("cos(0)").unwrap(), Expression::Cos(_)));
        assert!(matches!(parse_str("tan(0)").unwrap(), Expression::Tan(_)));
        assert!(matches!(parse_str("arcsin(0)").unwrap(), Expression::ArcSin(_)));
        assert!(matches!(parse_str("arccos(1)").unwrap(), Expression::ArcCos(_)));
        assert!(matches!(parse_str("arctan(0)").unwrap(), Expression::ArcTan(_)));

        assert!(matches!(parse_str("sin(x + 1)").unwrap(), Expression::Sin(_)));
    }

    #[test]
    fn test_parse_complex_expressions() {
        let mut vars = HashMap::new();

        let expr = parse_str("(1 + 2) * (3 - 4) ^ 2").unwrap();
        assert_eq!(expr.evaluate(&vars).unwrap(), 3.0);

        let expr = parse_str("x^2 + y^2").unwrap();
        vars.insert("x".to_string(), 3.0);
        vars.insert("y".to_string(), 4.0);
        assert_eq!(expr.evaluate(&vars).unwrap(), 25.0);

        // sin^2 + cos^2 == 1 holds mathematically, but not necessarily bit-exactly in
        // f64, so compare with a tolerance instead of `assert_eq!`.
        let expr = parse_str("sin(x)^2 + cos(x)^2").unwrap();
        vars.insert("x".to_string(), 0.5);
        assert_close(expr.evaluate(&vars).unwrap(), 1.0);
    }

    #[test]
    fn test_parse_errors() {
        assert!(parse_str("(1 + 2").is_err());
        assert!(parse_str("unknown(x)").is_err());
        assert!(parse_str("1 + + 2").is_err());
        assert!(parse_str("").is_err());
        assert!(parse_str("sin x").is_err());
        assert!(tokenize("1 @ 2").is_err());
    }

    #[test]
    fn test_domain_errors() {
        let vars = HashMap::new();
        assert!(parse_str("arcsin(2)").unwrap().evaluate(&vars).is_err());
        assert!(parse_str("arccos(2)").unwrap().evaluate(&vars).is_err());

        let expr = parse_str("1 / (x - 2)").unwrap();
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 2.0);
        assert!(expr.evaluate(&vars).is_err());
    }

    #[test]
    fn test_evaluation_error_messages() {
        let vars = HashMap::new();
        assert_eq!(
            parse_str("x + 1").unwrap().evaluate(&vars),
            Err("Variable 'x' not found".to_string())
        );
        assert_eq!(
            parse_str("1 / 0").unwrap().evaluate(&vars),
            Err("Division by 0".to_string())
        );
    }

    #[test]
    fn test_public_parse_matches_builders() {
        use super::expr;

        assert_eq!(
            Expression::parse("1 + 2").unwrap(),
            expr::add(expr::number(1.0), expr::number(2.0))
        );
        assert_eq!(
            Expression::parse("x ^ 2").unwrap(),
            expr::power(expr::variable("x"), expr::number(2.0))
        );
        assert!(Expression::parse("(1 + 2").is_err());
    }
}