math_engine/
expression.rs

1use crate::context::Context;
2use crate::error;
3
/// A binary arithmetic operator used in [`Expression::BinOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// `lhs + rhs`
    Addition,
    /// `lhs - rhs`
    Subtraction,
    /// `lhs * rhs`
    Product,
    /// `lhs / rhs`
    Division,
}

/// An arithmetic expression tree over `f32` values.
///
/// Equality is structural. Constants are compared with `f32` semantics, so a
/// tree containing a `NaN` constant is never equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub enum Expression {
    /// A binary operation applied to a left and a right sub-expression.
    BinOp(BinOp, Box<Expression>, Box<Expression>),
    /// A literal floating point value.
    Constant(f32),
    /// A named variable, resolved through a context when the expression is evaluated.
    Variable(String),
}
18
19use std::str::FromStr;
20impl FromStr for Expression {
21    type Err = error::ParserError;
22
23    fn from_str(s: &str) -> Result<Self, Self::Err> {
24        use crate::parser::parse_expression;
25        parse_expression(s)
26    }
27}
28
29impl Expression {
30    fn to_string(&self) -> String {
31        match self {
32            Expression::Constant(val) => val.to_string(),
33            Expression::Variable(var) => var.to_string(),
34            Expression::BinOp(op, e1, e2) => {
35                let s1 = e1.to_string();
36                let s2 = e2.to_string();
37                match op {
38                    BinOp::Addition => format!("({} + {})", s1, s2),
39                    BinOp::Subtraction => format!("({} - {})", s1, s2),
40                    BinOp::Product => format!("({} * {})", s1, s2),
41                    BinOp::Division => format!("({} / {})", s1, s2),
42                }
43            }
44        }
45    }
46}
47
48use std::fmt::{Display, Error, Formatter};
49impl Display for Expression {
50    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
51        write!(f, "{}", self.to_string())
52    }
53}
54
55impl Expression {
56    /// Parses an expression from a string.
57    ///
58    /// # Examples
59    /// Basic usage;
60    ///
61    /// ```
62    /// use math_engine::expression::Expression;
63    ///
64    /// let expr = Expression::parse("1.0 + x").unwrap();
65    /// ```
66    ///
67    /// # Errors
68    /// A ParserError is returned by the parser if the string could not be
69    /// parsed properly.
70    pub fn parse(s: &str) -> Result<Self, error::ParserError> {
71        Expression::from_str(s)
72    }
73
74    /// Creates a new constant from a floating point value.
75    ///
76    /// # Examples
77    /// Basic usage:
78    ///
79    /// ```
80    /// use math_engine::expression::Expression;
81    ///
82    /// let expr = Expression::constant(2.0);
83    /// let eval = expr.eval().unwrap();
84    ///
85    /// assert_eq!(eval, 2.0);
86    /// ```
87    pub fn constant(val: f32) -> Self {
88        Expression::Constant(val)
89    }
90
91    /// Creates a variable.
92    ///
93    /// # Examples
94    /// Basic usage:
95    ///
96    /// ```
97    /// use math_engine::context::Context;
98    /// use math_engine::expression::Expression;
99    ///
100    /// let expr = Expression::variable("x");
101    /// let ctx = Context::new().with_variable("x", 32.0);
102    /// let eval = expr.eval_with_context(&ctx).unwrap();
103    ///
104    /// assert_eq!(eval, 32.0);
105    /// ```
106    pub fn variable(name: &str) -> Self {
107        Expression::Variable(name.to_string())
108    }
109
110    /// Creates an expression representing a binary operation.
111    fn binary_op(op: BinOp, e1: Expression, e2: Expression) -> Self {
112        Expression::BinOp(op, Box::new(e1), Box::new(e2))
113    }
114
115    /// Creates a new binary operation which sums two sub-expressions
116    ///
117    /// # Examples
118    /// Basic usage:
119    ///
120    /// ```
121    /// use math_engine::expression::Expression;
122    ///
123    /// let expr = Expression::addition(
124    ///     Expression::constant(2.0),
125    ///     Expression::Constant(3.0)
126    /// );
127    /// let eval = expr.eval().unwrap();
128    ///
129    /// assert_eq!(eval, 5.0);
130    /// ```
131    pub fn addition(e1: Expression, e2: Expression) -> Self {
132        Expression::BinOp(BinOp::Addition, Box::new(e1), Box::new(e2))
133    }
134
135    /// Creates a new binary operation which subtracts two sub-expressions
136    ///
137    /// # Examples
138    /// Basic usage:
139    ///
140    /// ```
141    /// use math_engine::expression::Expression;
142    ///
143    /// let expr = Expression::subtraction(
144    ///     Expression::constant(2.0),
145    ///     Expression::Constant(3.0)
146    /// );
147    /// let eval = expr.eval().unwrap();
148    ///
149    /// assert_eq!(eval, -1.0);
150    /// ```
151    pub fn subtraction(e1: Expression, e2: Expression) -> Self {
152        Expression::BinOp(BinOp::Subtraction, Box::new(e1), Box::new(e2))
153    }
154
155    /// Creates a new binary operation which multiplies two sub-expressions
156    ///
157    /// # Examples
158    /// Basic usage:
159    ///
160    /// ```
161    /// use math_engine::expression::Expression;
162    ///
163    /// let expr = Expression::product(
164    ///     Expression::constant(2.0),
165    ///     Expression::Constant(3.0)
166    /// );
167    /// let eval = expr.eval().unwrap();
168    ///
169    /// assert_eq!(eval, 6.0);
170    /// ```
171    pub fn product(e1: Expression, e2: Expression) -> Self {
172        Expression::BinOp(BinOp::Product, Box::new(e1), Box::new(e2))
173    }
174
175    /// Creates a new binary operation which divides two sub-expressions
176    ///
177    /// # Examples
178    /// Basic usage:
179    ///
180    /// ```
181    /// use math_engine::expression::Expression;
182    ///
183    /// let expr = Expression::division(
184    ///     Expression::constant(3.0),
185    ///     Expression::Constant(2.0)
186    /// );
187    /// let eval = expr.eval().unwrap();
188    ///
189    /// assert_eq!(eval, 1.5);
190    /// ```
191    pub fn division(e1: Expression, e2: Expression) -> Self {
192        Expression::BinOp(BinOp::Division, Box::new(e1), Box::new(e2))
193    }
194
195    fn eval_core(&self, ctx: Option<&Context>) -> Result<f32, error::EvalError> {
196        match self {
197            Expression::Constant(val) => Ok(*val),
198            Expression::BinOp(op, e1, e2) => {
199                let r1 = e1.eval_core(ctx)?;
200                let r2 = e2.eval_core(ctx)?;
201                let r = match op {
202                    BinOp::Addition => r1 + r2,
203                    BinOp::Subtraction => r1 - r2,
204                    BinOp::Product => r1 * r2,
205                    BinOp::Division => r1 / r2,
206                };
207                if r.is_nan() {
208                    Err(error::EvalError::NotANumber)
209                } else if r.is_infinite() {
210                    Err(error::EvalError::IsInfinite)
211                } else {
212                    Ok(r)
213                }
214            }
215            Expression::Variable(name) => match ctx {
216                Some(ctx) => match ctx.get_variable(name) {
217                    Ok(r) => Ok(r),
218                    Err(_) => Err(error::EvalError::VariableNotFound(name.clone())),
219                },
220                None => Err(error::EvalError::NoContextGiven),
221            },
222        }
223    }
224
225    /// Evaluates the expression into a floating point value without a context.
226    ///
227    /// As of now, floating point value is the only supported evaluation. Please
228    /// note that it is therefore subject to approximations due to some values
229    /// not being representable.
230    ///
231    /// # Examples
232    ///
233    /// ```
234    /// use math_engine::context::Context;
235    /// use math_engine::expression::Expression;
236    ///
237    /// // Expression is (1 - 5) + (2 * (4 + 6))
238    /// let expr = Expression::addition(
239    ///     Expression::subtraction(
240    ///         Expression::constant(1.0),
241    ///         Expression::constant(5.0)
242    ///     ),
243    ///     Expression::product(
244    ///         Expression::constant(2.0),
245    ///         Expression::addition(
246    ///             Expression::constant(4.0),
247    ///             Expression::constant(6.0)
248    ///         )
249    ///     )
250    /// );
251    /// let eval = expr.eval().unwrap();
252    ///
253    /// assert_eq!(eval, 16.0);
254    /// ```
255    ///
256    /// # Errors
257    ///
258    /// If any intermediary result is not a number of is infinity, an error is
259    /// returned.
260    /// If the expression contains a variable, an error is returned
261    pub fn eval(&self) -> Result<f32, error::EvalError> {
262        self.eval_core(None)
263    }
264
265    /// Evaluates the expression into a floating point value with a given context.
266    ///
267    /// As of now, floating point value is the only supported evaluation. Please
268    /// note that it is therefore subject to approximations due to some values
269    /// not being representable.
270    ///
271    /// # Examples
272    ///
273    /// ```
274    /// use math_engine::context::Context;
275    /// use math_engine::expression::Expression;
276    ///
277    /// // Expression is (1 / (1 + x))
278    /// let expr = Expression::division(
279    ///     Expression::constant(1.0),
280    ///     Expression::addition(
281    ///         Expression::constant(1.0),
282    ///         Expression::variable("x"),
283    ///     )
284    /// );
285    /// let ctx = Context::new().with_variable("x", 2.0);
286    /// let eval = expr.eval_with_context(&ctx).unwrap();
287    ///
288    /// assert_eq!(eval, 1.0/3.0);
289    /// ```
290    ///
291    /// # Errors
292    ///
293    /// If any intermediary result is not a number of is infinity, an error is
294    /// returned.
295    /// If the expression contains a variable but the context does not define all
296    /// the variables, an error is returned.
297    pub fn eval_with_context(&self, ctx: &Context) -> Result<f32, error::EvalError> {
298        self.eval_core(Some(ctx))
299    }
300
301    /// Calculates the derivative of an expression.
302    ///
303    /// # Examples
304    /// Basic usage:
305    ///
306    /// ```
307    /// use math_engine::expression::Expression;
308    /// use std::str::FromStr;
309    ///
310    /// //Represents y + 2x
311    /// let expr = Expression::from_str("1.0 * y + 2.0 * x");
312    ///
313    /// //Represents y + 2
314    /// let deri = expr.derivative("x");
315    /// ```
316    pub fn derivative(&self, deriv_var: &str) -> Self {
317        match self {
318            Expression::Constant(_) => Expression::constant(0.0),
319            Expression::Variable(var) => {
320                if var.as_str() == deriv_var {
321                    Expression::constant(1.0)
322                } else {
323                    Expression::variable(var.as_str())
324                }
325            }
326            Expression::BinOp(op, e1, e2) => {
327                let deriv_e1 = e1.derivative(deriv_var);
328                let deriv_e2 = e2.derivative(deriv_var);
329                match op {
330                    BinOp::Addition => Expression::addition(deriv_e1, deriv_e2),
331                    BinOp::Subtraction => Expression::subtraction(deriv_e1, deriv_e2),
332                    BinOp::Product => Expression::addition(
333                        Expression::product(*e1.clone(), deriv_e2),
334                        Expression::product(deriv_e1, *e2.clone()),
335                    ),
336                    BinOp::Division => Expression::division(
337                        Expression::subtraction(
338                            Expression::product(*e2.clone(), deriv_e1),
339                            Expression::product(deriv_e2, *e1.clone()),
340                        ),
341                        Expression::product(*e2.clone(), *e2.clone()),
342                    ),
343                }
344            }
345        }
346    }
347
348    /// Simplifies the expression by applying constant propagation.
349    ///
350    /// # Examples
351    /// Basic usage:
352    /// 
353    /// ```
354    /// use math_engine::expression::Expression;
355    ///
356    /// let expr = Expression::parse("1.0 * y + 0.0 * x + 2.0 / 3.0").unwrap();
357    ///
358    /// //Represents "y + 0.66666..."
359    /// let simp = expr.constant_propagation().unwrap()
360    /// ```
361    ///
362    /// # Errors
363    /// An EvalError (DivisionByZero) can be returned if the partial evaluation
364    /// of the expression revealed a division by zero.
365    pub fn constant_propagation(&self) -> Result<Self, error::EvalError> {
366        match self {
367            Expression::Constant(_) => Ok(self.clone()),
368            Expression::Variable(_) => Ok(self.clone()),
369            Expression::BinOp(op, e1, e2) => {
370                let e1 = e1.constant_propagation()?;
371                let e2 = e2.constant_propagation()?;
372                match (op, &e1, &e2) {
373                    (_, Expression::Constant(v1), Expression::Constant(v2)) => match op {
374                        BinOp::Addition => Ok(Expression::constant(v1 + v2)),
375                        BinOp::Subtraction => Ok(Expression::constant(v1 - v2)),
376                        BinOp::Product => Ok(Expression::constant(v1 * v2)),
377                        BinOp::Division => Ok(Expression::constant(v1 / v2)),
378                    },
379                    (BinOp::Product, Expression::Constant(v), _) if *v == 1.0 => Ok(e2),
380                    (BinOp::Product, _, Expression::Constant(v)) if *v == 1.0  => Ok(e1),
381                    (BinOp::Division, _, Expression::Constant(v)) if *v == 1.0 => Ok(e1),
382                    (_, Expression::Constant(v), _) if *v == 0.0 => match op {
383                        BinOp::Addition => Ok(e2),
384                        BinOp::Subtraction => unimplemented!(),
385                        BinOp::Product => Ok(Expression::constant(0.0)),
386                        BinOp::Division => Ok(Expression::constant(0.0)),
387                    },
388                    (_, _, Expression::Constant(v)) if *v == 0.0 => match op {
389                        BinOp::Addition => Ok(e1),
390                        BinOp::Subtraction => Ok(e1),
391                        BinOp::Product => Ok(Expression::constant(0.0)),
392                        BinOp::Division => Err(error::EvalError::DivisionByZero),
393                    },
394                    _ => Ok(Expression::binary_op(*op, e1, e2)),
395                }
396            }
397        }
398    }
399}
400
401use std::ops::{Add, Sub, Mul, Div};
402macro_rules! expression_impl_trait {
403    ($tr:ident, $tr_fun:ident, $fun:ident) => {
404        impl $tr for Expression {
405            type Output = Self;
406
407            fn $tr_fun(self, other: Self) -> Self::Output {
408                Expression::$fun(self, other)
409            }
410        }
411        //impl {$t}rAssign for Expression {
412        //    fn $tr_fun_assign(&mut self, other: Self) {
413        //        *self = Expression::$fun(self, other)
414        //    }
415        //}
416    }
417}
418expression_impl_trait!(Add, add, addition);
419expression_impl_trait!(Sub, sub, subtraction);
420expression_impl_trait!(Mul, mul, product);
421expression_impl_trait!(Div, div, division);