// math_engine/expression.rs
1use crate::context::Context;
2use crate::error;
3
/// A binary arithmetic operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    /// `lhs + rhs`
    Addition,
    /// `lhs - rhs`
    Subtraction,
    /// `lhs * rhs`
    Product,
    /// `lhs / rhs`
    Division,
}

/// An arithmetic expression tree.
///
/// Equality (`PartialEq`) is structural. As with `f32` itself, a tree that
/// contains a NaN constant does not compare equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub enum Expression {
    /// A binary operation applied to a left and a right operand, in that order.
    BinOp(BinOp, Box<Expression>, Box<Expression>),
    /// A floating point literal.
    Constant(f32),
    /// A named variable whose value is looked up in a `Context` when evaluated.
    Variable(String),
}
18
19use std::str::FromStr;
20impl FromStr for Expression {
21 type Err = error::ParserError;
22
23 fn from_str(s: &str) -> Result<Self, Self::Err> {
24 use crate::parser::parse_expression;
25 parse_expression(s)
26 }
27}
28
29impl Expression {
30 fn to_string(&self) -> String {
31 match self {
32 Expression::Constant(val) => val.to_string(),
33 Expression::Variable(var) => var.to_string(),
34 Expression::BinOp(op, e1, e2) => {
35 let s1 = e1.to_string();
36 let s2 = e2.to_string();
37 match op {
38 BinOp::Addition => format!("({} + {})", s1, s2),
39 BinOp::Subtraction => format!("({} - {})", s1, s2),
40 BinOp::Product => format!("({} * {})", s1, s2),
41 BinOp::Division => format!("({} / {})", s1, s2),
42 }
43 }
44 }
45 }
46}
47
48use std::fmt::{Display, Error, Formatter};
49impl Display for Expression {
50 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
51 write!(f, "{}", self.to_string())
52 }
53}
54
55impl Expression {
56 /// Parses an expression from a string.
57 ///
58 /// # Examples
59 /// Basic usage;
60 ///
61 /// ```
62 /// use math_engine::expression::Expression;
63 ///
64 /// let expr = Expression::parse("1.0 + x").unwrap();
65 /// ```
66 ///
67 /// # Errors
68 /// A ParserError is returned by the parser if the string could not be
69 /// parsed properly.
70 pub fn parse(s: &str) -> Result<Self, error::ParserError> {
71 Expression::from_str(s)
72 }
73
74 /// Creates a new constant from a floating point value.
75 ///
76 /// # Examples
77 /// Basic usage:
78 ///
79 /// ```
80 /// use math_engine::expression::Expression;
81 ///
82 /// let expr = Expression::constant(2.0);
83 /// let eval = expr.eval().unwrap();
84 ///
85 /// assert_eq!(eval, 2.0);
86 /// ```
87 pub fn constant(val: f32) -> Self {
88 Expression::Constant(val)
89 }
90
91 /// Creates a variable.
92 ///
93 /// # Examples
94 /// Basic usage:
95 ///
96 /// ```
97 /// use math_engine::context::Context;
98 /// use math_engine::expression::Expression;
99 ///
100 /// let expr = Expression::variable("x");
101 /// let ctx = Context::new().with_variable("x", 32.0);
102 /// let eval = expr.eval_with_context(&ctx).unwrap();
103 ///
104 /// assert_eq!(eval, 32.0);
105 /// ```
106 pub fn variable(name: &str) -> Self {
107 Expression::Variable(name.to_string())
108 }
109
110 /// Creates an expression representing a binary operation.
111 fn binary_op(op: BinOp, e1: Expression, e2: Expression) -> Self {
112 Expression::BinOp(op, Box::new(e1), Box::new(e2))
113 }
114
115 /// Creates a new binary operation which sums two sub-expressions
116 ///
117 /// # Examples
118 /// Basic usage:
119 ///
120 /// ```
121 /// use math_engine::expression::Expression;
122 ///
123 /// let expr = Expression::addition(
124 /// Expression::constant(2.0),
125 /// Expression::Constant(3.0)
126 /// );
127 /// let eval = expr.eval().unwrap();
128 ///
129 /// assert_eq!(eval, 5.0);
130 /// ```
131 pub fn addition(e1: Expression, e2: Expression) -> Self {
132 Expression::BinOp(BinOp::Addition, Box::new(e1), Box::new(e2))
133 }
134
135 /// Creates a new binary operation which subtracts two sub-expressions
136 ///
137 /// # Examples
138 /// Basic usage:
139 ///
140 /// ```
141 /// use math_engine::expression::Expression;
142 ///
143 /// let expr = Expression::subtraction(
144 /// Expression::constant(2.0),
145 /// Expression::Constant(3.0)
146 /// );
147 /// let eval = expr.eval().unwrap();
148 ///
149 /// assert_eq!(eval, -1.0);
150 /// ```
151 pub fn subtraction(e1: Expression, e2: Expression) -> Self {
152 Expression::BinOp(BinOp::Subtraction, Box::new(e1), Box::new(e2))
153 }
154
155 /// Creates a new binary operation which multiplies two sub-expressions
156 ///
157 /// # Examples
158 /// Basic usage:
159 ///
160 /// ```
161 /// use math_engine::expression::Expression;
162 ///
163 /// let expr = Expression::product(
164 /// Expression::constant(2.0),
165 /// Expression::Constant(3.0)
166 /// );
167 /// let eval = expr.eval().unwrap();
168 ///
169 /// assert_eq!(eval, 6.0);
170 /// ```
171 pub fn product(e1: Expression, e2: Expression) -> Self {
172 Expression::BinOp(BinOp::Product, Box::new(e1), Box::new(e2))
173 }
174
175 /// Creates a new binary operation which divides two sub-expressions
176 ///
177 /// # Examples
178 /// Basic usage:
179 ///
180 /// ```
181 /// use math_engine::expression::Expression;
182 ///
183 /// let expr = Expression::division(
184 /// Expression::constant(3.0),
185 /// Expression::Constant(2.0)
186 /// );
187 /// let eval = expr.eval().unwrap();
188 ///
189 /// assert_eq!(eval, 1.5);
190 /// ```
191 pub fn division(e1: Expression, e2: Expression) -> Self {
192 Expression::BinOp(BinOp::Division, Box::new(e1), Box::new(e2))
193 }
194
195 fn eval_core(&self, ctx: Option<&Context>) -> Result<f32, error::EvalError> {
196 match self {
197 Expression::Constant(val) => Ok(*val),
198 Expression::BinOp(op, e1, e2) => {
199 let r1 = e1.eval_core(ctx)?;
200 let r2 = e2.eval_core(ctx)?;
201 let r = match op {
202 BinOp::Addition => r1 + r2,
203 BinOp::Subtraction => r1 - r2,
204 BinOp::Product => r1 * r2,
205 BinOp::Division => r1 / r2,
206 };
207 if r.is_nan() {
208 Err(error::EvalError::NotANumber)
209 } else if r.is_infinite() {
210 Err(error::EvalError::IsInfinite)
211 } else {
212 Ok(r)
213 }
214 }
215 Expression::Variable(name) => match ctx {
216 Some(ctx) => match ctx.get_variable(name) {
217 Ok(r) => Ok(r),
218 Err(_) => Err(error::EvalError::VariableNotFound(name.clone())),
219 },
220 None => Err(error::EvalError::NoContextGiven),
221 },
222 }
223 }
224
225 /// Evaluates the expression into a floating point value without a context.
226 ///
227 /// As of now, floating point value is the only supported evaluation. Please
228 /// note that it is therefore subject to approximations due to some values
229 /// not being representable.
230 ///
231 /// # Examples
232 ///
233 /// ```
234 /// use math_engine::context::Context;
235 /// use math_engine::expression::Expression;
236 ///
237 /// // Expression is (1 - 5) + (2 * (4 + 6))
238 /// let expr = Expression::addition(
239 /// Expression::subtraction(
240 /// Expression::constant(1.0),
241 /// Expression::constant(5.0)
242 /// ),
243 /// Expression::product(
244 /// Expression::constant(2.0),
245 /// Expression::addition(
246 /// Expression::constant(4.0),
247 /// Expression::constant(6.0)
248 /// )
249 /// )
250 /// );
251 /// let eval = expr.eval().unwrap();
252 ///
253 /// assert_eq!(eval, 16.0);
254 /// ```
255 ///
256 /// # Errors
257 ///
258 /// If any intermediary result is not a number of is infinity, an error is
259 /// returned.
260 /// If the expression contains a variable, an error is returned
261 pub fn eval(&self) -> Result<f32, error::EvalError> {
262 self.eval_core(None)
263 }
264
265 /// Evaluates the expression into a floating point value with a given context.
266 ///
267 /// As of now, floating point value is the only supported evaluation. Please
268 /// note that it is therefore subject to approximations due to some values
269 /// not being representable.
270 ///
271 /// # Examples
272 ///
273 /// ```
274 /// use math_engine::context::Context;
275 /// use math_engine::expression::Expression;
276 ///
277 /// // Expression is (1 / (1 + x))
278 /// let expr = Expression::division(
279 /// Expression::constant(1.0),
280 /// Expression::addition(
281 /// Expression::constant(1.0),
282 /// Expression::variable("x"),
283 /// )
284 /// );
285 /// let ctx = Context::new().with_variable("x", 2.0);
286 /// let eval = expr.eval_with_context(&ctx).unwrap();
287 ///
288 /// assert_eq!(eval, 1.0/3.0);
289 /// ```
290 ///
291 /// # Errors
292 ///
293 /// If any intermediary result is not a number of is infinity, an error is
294 /// returned.
295 /// If the expression contains a variable but the context does not define all
296 /// the variables, an error is returned.
297 pub fn eval_with_context(&self, ctx: &Context) -> Result<f32, error::EvalError> {
298 self.eval_core(Some(ctx))
299 }
300
301 /// Calculates the derivative of an expression.
302 ///
303 /// # Examples
304 /// Basic usage:
305 ///
306 /// ```
307 /// use math_engine::expression::Expression;
308 /// use std::str::FromStr;
309 ///
310 /// //Represents y + 2x
311 /// let expr = Expression::from_str("1.0 * y + 2.0 * x");
312 ///
313 /// //Represents y + 2
314 /// let deri = expr.derivative("x");
315 /// ```
316 pub fn derivative(&self, deriv_var: &str) -> Self {
317 match self {
318 Expression::Constant(_) => Expression::constant(0.0),
319 Expression::Variable(var) => {
320 if var.as_str() == deriv_var {
321 Expression::constant(1.0)
322 } else {
323 Expression::variable(var.as_str())
324 }
325 }
326 Expression::BinOp(op, e1, e2) => {
327 let deriv_e1 = e1.derivative(deriv_var);
328 let deriv_e2 = e2.derivative(deriv_var);
329 match op {
330 BinOp::Addition => Expression::addition(deriv_e1, deriv_e2),
331 BinOp::Subtraction => Expression::subtraction(deriv_e1, deriv_e2),
332 BinOp::Product => Expression::addition(
333 Expression::product(*e1.clone(), deriv_e2),
334 Expression::product(deriv_e1, *e2.clone()),
335 ),
336 BinOp::Division => Expression::division(
337 Expression::subtraction(
338 Expression::product(*e2.clone(), deriv_e1),
339 Expression::product(deriv_e2, *e1.clone()),
340 ),
341 Expression::product(*e2.clone(), *e2.clone()),
342 ),
343 }
344 }
345 }
346 }
347
348 /// Simplifies the expression by applying constant propagation.
349 ///
350 /// # Examples
351 /// Basic usage:
352 ///
353 /// ```
354 /// use math_engine::expression::Expression;
355 ///
356 /// let expr = Expression::parse("1.0 * y + 0.0 * x + 2.0 / 3.0").unwrap();
357 ///
358 /// //Represents "y + 0.66666..."
359 /// let simp = expr.constant_propagation().unwrap()
360 /// ```
361 ///
362 /// # Errors
363 /// An EvalError (DivisionByZero) can be returned if the partial evaluation
364 /// of the expression revealed a division by zero.
365 pub fn constant_propagation(&self) -> Result<Self, error::EvalError> {
366 match self {
367 Expression::Constant(_) => Ok(self.clone()),
368 Expression::Variable(_) => Ok(self.clone()),
369 Expression::BinOp(op, e1, e2) => {
370 let e1 = e1.constant_propagation()?;
371 let e2 = e2.constant_propagation()?;
372 match (op, &e1, &e2) {
373 (_, Expression::Constant(v1), Expression::Constant(v2)) => match op {
374 BinOp::Addition => Ok(Expression::constant(v1 + v2)),
375 BinOp::Subtraction => Ok(Expression::constant(v1 - v2)),
376 BinOp::Product => Ok(Expression::constant(v1 * v2)),
377 BinOp::Division => Ok(Expression::constant(v1 / v2)),
378 },
379 (BinOp::Product, Expression::Constant(v), _) if *v == 1.0 => Ok(e2),
380 (BinOp::Product, _, Expression::Constant(v)) if *v == 1.0 => Ok(e1),
381 (BinOp::Division, _, Expression::Constant(v)) if *v == 1.0 => Ok(e1),
382 (_, Expression::Constant(v), _) if *v == 0.0 => match op {
383 BinOp::Addition => Ok(e2),
384 BinOp::Subtraction => unimplemented!(),
385 BinOp::Product => Ok(Expression::constant(0.0)),
386 BinOp::Division => Ok(Expression::constant(0.0)),
387 },
388 (_, _, Expression::Constant(v)) if *v == 0.0 => match op {
389 BinOp::Addition => Ok(e1),
390 BinOp::Subtraction => Ok(e1),
391 BinOp::Product => Ok(Expression::constant(0.0)),
392 BinOp::Division => Err(error::EvalError::DivisionByZero),
393 },
394 _ => Ok(Expression::binary_op(*op, e1, e2)),
395 }
396 }
397 }
398 }
399}
400
401use std::ops::{Add, Sub, Mul, Div};
402macro_rules! expression_impl_trait {
403 ($tr:ident, $tr_fun:ident, $fun:ident) => {
404 impl $tr for Expression {
405 type Output = Self;
406
407 fn $tr_fun(self, other: Self) -> Self::Output {
408 Expression::$fun(self, other)
409 }
410 }
411 //impl {$t}rAssign for Expression {
412 // fn $tr_fun_assign(&mut self, other: Self) {
413 // *self = Expression::$fun(self, other)
414 // }
415 //}
416 }
417}
418expression_impl_trait!(Add, add, addition);
419expression_impl_trait!(Sub, sub, subtraction);
420expression_impl_trait!(Mul, mul, product);
421expression_impl_trait!(Div, div, division);