// constraint_solver/exp.rs

/*
MIT License

Copyright (c) 2026 Raja Lehtihet & Wael El Oraiby

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

use std::collections::HashMap;
use std::fmt;

/// Error produced by [`Exp::evaluate_checked`] when the expression refers to a
/// variable that has no entry in the supplied variable map.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MissingVarError {
    /// Name of the variable that could not be resolved.
    pub var_name: String,
}

impl fmt::Display for MissingVarError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let missing = self.var_name.as_str();
        f.write_fmt(format_args!("Missing variable '{}'", missing))
    }
}

/// Marker impl: the error has no underlying `source`.
impl std::error::Error for MissingVarError {}

/// A symbolic, real-valued expression tree over named variables.
///
/// Interior nodes own their children through `Box`. Trees are usually built
/// with the associated constructors (`Exp::var`, `Exp::add`, ...), but the
/// variants can also be used directly.
#[derive(Clone, Debug, PartialEq)]
pub enum Exp {
    /// Numeric constant.
    Val(f64),
    /// Reference to a named variable.
    Var(String),
    /// Sum `lhs + rhs`.
    Add(Box<Exp>, Box<Exp>),
    /// Difference `lhs - rhs`.
    Sub(Box<Exp>, Box<Exp>),
    /// Product `lhs * rhs`.
    Mul(Box<Exp>, Box<Exp>),
    /// Quotient `lhs / rhs`.
    Div(Box<Exp>, Box<Exp>),
    /// `base` raised to a constant (non-symbolic) exponent.
    Power(Box<Exp>, f64),
    /// Negation `-e`.
    Neg(Box<Exp>),
    /// Sine of the argument.
    Sin(Box<Exp>),
    /// Cosine of the argument.
    Cos(Box<Exp>),
    /// Natural logarithm of the argument.
    Ln(Box<Exp>),
    /// `e` raised to the argument.
    Exp(Box<Exp>),
}

57#[allow(clippy::self_named_constructors, clippy::should_implement_trait)]
58impl Exp {
59    pub fn var(name: impl Into<String>) -> Self {
60        Exp::Var(name.into())
61    }
62
63    pub fn val(v: f64) -> Self {
64        Exp::Val(v)
65    }
66
67    pub fn add(lhs: Exp, rhs: Exp) -> Self {
68        Exp::Add(Box::new(lhs), Box::new(rhs))
69    }
70
71    pub fn sub(lhs: Exp, rhs: Exp) -> Self {
72        Exp::Sub(Box::new(lhs), Box::new(rhs))
73    }
74
75    pub fn mul(lhs: Exp, rhs: Exp) -> Self {
76        Exp::Mul(Box::new(lhs), Box::new(rhs))
77    }
78
79    pub fn div(lhs: Exp, rhs: Exp) -> Self {
80        Exp::Div(Box::new(lhs), Box::new(rhs))
81    }
82
83    pub fn power(base: Exp, exp: f64) -> Self {
84        Exp::Power(Box::new(base), exp)
85    }
86
87    pub fn neg(exp: Exp) -> Self {
88        Exp::Neg(Box::new(exp))
89    }
90
91    pub fn sin(exp: Exp) -> Self {
92        Exp::Sin(Box::new(exp))
93    }
94
95    pub fn cos(exp: Exp) -> Self {
96        Exp::Cos(Box::new(exp))
97    }
98
99    pub fn ln(exp: Exp) -> Self {
100        Exp::Ln(Box::new(exp))
101    }
102
103    pub fn exp(exp: Exp) -> Self {
104        Exp::Exp(Box::new(exp))
105    }
106
107    pub fn evaluate_checked(
108        &self,
109        vars: &HashMap<String, f64>,
110    ) -> Result<f64, MissingVarError> {
111        match self {
112            Exp::Val(v) => Ok(*v),
113            Exp::Var(name) => vars.get(name).copied().ok_or_else(|| MissingVarError {
114                var_name: name.clone(),
115            }),
116            Exp::Add(l, r) => Ok(l.evaluate_checked(vars)? + r.evaluate_checked(vars)?),
117            Exp::Sub(l, r) => Ok(l.evaluate_checked(vars)? - r.evaluate_checked(vars)?),
118            Exp::Mul(l, r) => Ok(l.evaluate_checked(vars)? * r.evaluate_checked(vars)?),
119            Exp::Div(l, r) => Ok(l.evaluate_checked(vars)? / r.evaluate_checked(vars)?),
120            Exp::Power(base, exp) => Ok(base.evaluate_checked(vars)?.powf(*exp)),
121            Exp::Neg(e) => Ok(-e.evaluate_checked(vars)?),
122            Exp::Sin(e) => Ok(e.evaluate_checked(vars)?.sin()),
123            Exp::Cos(e) => Ok(e.evaluate_checked(vars)?.cos()),
124            Exp::Ln(e) => Ok(e.evaluate_checked(vars)?.ln()),
125            Exp::Exp(e) => Ok(e.evaluate_checked(vars)?.exp()),
126        }
127    }
128
129    pub fn evaluate(&self, vars: &HashMap<String, f64>) -> f64 {
130        match self {
131            Exp::Val(v) => *v,
132            Exp::Var(name) => *vars.get(name).unwrap_or(&0.0),
133            Exp::Add(l, r) => l.evaluate(vars) + r.evaluate(vars),
134            Exp::Sub(l, r) => l.evaluate(vars) - r.evaluate(vars),
135            Exp::Mul(l, r) => l.evaluate(vars) * r.evaluate(vars),
136            Exp::Div(l, r) => l.evaluate(vars) / r.evaluate(vars),
137            Exp::Power(base, exp) => base.evaluate(vars).powf(*exp),
138            Exp::Neg(e) => -e.evaluate(vars),
139            Exp::Sin(e) => e.evaluate(vars).sin(),
140            Exp::Cos(e) => e.evaluate(vars).cos(),
141            Exp::Ln(e) => e.evaluate(vars).ln(),
142            Exp::Exp(e) => e.evaluate(vars).exp(),
143        }
144    }
145
146    pub fn differentiate(&self, var_name: &str) -> Exp {
147        match self {
148            Exp::Val(_) => Exp::Val(0.0),
149            Exp::Var(name) => {
150                if name == var_name {
151                    Exp::Val(1.0)
152                } else {
153                    Exp::Val(0.0)
154                }
155            }
156            Exp::Add(l, r) => Exp::add(l.differentiate(var_name), r.differentiate(var_name)),
157            Exp::Sub(l, r) => Exp::sub(l.differentiate(var_name), r.differentiate(var_name)),
158            Exp::Mul(l, r) => {
159                let dl = l.differentiate(var_name);
160                let dr = r.differentiate(var_name);
161                Exp::add(Exp::mul(dl, (**r).clone()), Exp::mul((**l).clone(), dr))
162            }
163            Exp::Div(l, r) => {
164                let dl = l.differentiate(var_name);
165                let dr = r.differentiate(var_name);
166                Exp::div(
167                    Exp::sub(Exp::mul(dl, (**r).clone()), Exp::mul((**l).clone(), dr)),
168                    Exp::power((**r).clone(), 2.0),
169                )
170            }
171            Exp::Power(base, exp) => {
172                let db = base.differentiate(var_name);
173                Exp::mul(
174                    Exp::mul(Exp::val(*exp), Exp::power((**base).clone(), exp - 1.0)),
175                    db,
176                )
177            }
178            Exp::Neg(e) => Exp::neg(e.differentiate(var_name)),
179            Exp::Sin(e) => {
180                let de = e.differentiate(var_name);
181                Exp::mul(Exp::cos((**e).clone()), de)
182            }
183            Exp::Cos(e) => {
184                let de = e.differentiate(var_name);
185                Exp::neg(Exp::mul(Exp::sin((**e).clone()), de))
186            }
187            Exp::Ln(e) => {
188                let de = e.differentiate(var_name);
189                Exp::div(de, (**e).clone())
190            }
191            Exp::Exp(e) => {
192                let de = e.differentiate(var_name);
193                Exp::mul(Exp::exp((**e).clone()), de)
194            }
195        }
196    }
197
198    pub fn simplify(&self) -> Exp {
199        match self {
200            Exp::Add(l, r) => {
201                let ls = l.simplify();
202                let rs = r.simplify();
203                match (&ls, &rs) {
204                    (Exp::Val(lv), Exp::Val(rv)) => Exp::Val(lv + rv),
205                    (Exp::Val(0.0), _) => rs,
206                    (_, Exp::Val(0.0)) => ls,
207                    _ => Exp::add(ls, rs),
208                }
209            }
210            Exp::Sub(l, r) => {
211                let ls = l.simplify();
212                let rs = r.simplify();
213                match (&ls, &rs) {
214                    (Exp::Val(lv), Exp::Val(rv)) => Exp::Val(lv - rv),
215                    (_, Exp::Val(0.0)) => ls,
216                    _ => Exp::sub(ls, rs),
217                }
218            }
219            Exp::Mul(l, r) => {
220                let ls = l.simplify();
221                let rs = r.simplify();
222                match (&ls, &rs) {
223                    (Exp::Val(lv), Exp::Val(rv)) => Exp::Val(lv * rv),
224                    (Exp::Val(0.0), _) | (_, Exp::Val(0.0)) => Exp::Val(0.0),
225                    (Exp::Val(1.0), _) => rs,
226                    (_, Exp::Val(1.0)) => ls,
227                    _ => Exp::mul(ls, rs),
228                }
229            }
230            Exp::Div(l, r) => {
231                let ls = l.simplify();
232                let rs = r.simplify();
233                match (&ls, &rs) {
234                    (Exp::Val(lv), Exp::Val(rv)) if *rv != 0.0 => Exp::Val(lv / rv),
235                    (Exp::Val(0.0), _) => Exp::Val(0.0),
236                    (_, Exp::Val(1.0)) => ls,
237                    _ => Exp::div(ls, rs),
238                }
239            }
240            Exp::Power(base, exp) => {
241                let bs = base.simplify();
242                match &bs {
243                    Exp::Val(v) => Exp::Val(v.powf(*exp)),
244                    _ if *exp == 0.0 => Exp::Val(1.0),
245                    _ if *exp == 1.0 => bs,
246                    _ => Exp::power(bs, *exp),
247                }
248            }
249            Exp::Neg(e) => {
250                let es = e.simplify();
251                match &es {
252                    Exp::Val(v) => Exp::Val(-v),
253                    _ => Exp::neg(es),
254                }
255            }
256            _ => self.clone(),
257        }
258    }
259}
260
261impl fmt::Display for Exp {
262    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263        match self {
264            Exp::Val(v) => write!(f, "{v}"),
265            Exp::Var(name) => write!(f, "{name}"),
266            Exp::Add(l, r) => write!(f, "({l} + {r})"),
267            Exp::Sub(l, r) => write!(f, "({l} - {r})"),
268            Exp::Mul(l, r) => write!(f, "({l} * {r})"),
269            Exp::Div(l, r) => write!(f, "({l} / {r})"),
270            Exp::Power(base, exp) => write!(f, "({base}^{exp})"),
271            Exp::Neg(e) => write!(f, "(-{e})"),
272            Exp::Sin(e) => write!(f, "sin({e})"),
273            Exp::Cos(e) => write!(f, "cos({e})"),
274            Exp::Ln(e) => write!(f, "ln({e})"),
275            Exp::Exp(e) => write!(f, "exp({e})"),
276        }
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Bindings shared by the numeric checks: `x = 2`, `y = 3`.
    fn xy_bindings() -> HashMap<String, f64> {
        HashMap::from([("x".to_string(), 2.0), ("y".to_string(), 3.0)])
    }

    #[test]
    fn test_evaluate() {
        let env = xy_bindings();
        let (x, y) = (Exp::var("x"), Exp::var("y"));

        // x * y + 5 = 2 * 3 + 5
        let affine = Exp::add(Exp::mul(x.clone(), y), Exp::val(5.0));
        assert_eq!(affine.evaluate(&env), 11.0);

        // x^2 = 4
        let squared = Exp::power(x, 2.0);
        assert_eq!(squared.evaluate(&env), 4.0);
    }

    #[test]
    fn test_evaluate_checked_missing_var() {
        let empty: HashMap<String, f64> = HashMap::new();
        let missing = Exp::var("x")
            .evaluate_checked(&empty)
            .expect_err("expected missing variable error");
        assert_eq!(missing.var_name, "x");
    }

    #[test]
    fn test_differentiate() {
        let env = xy_bindings();
        let product = Exp::mul(Exp::var("x"), Exp::var("y"));

        // d(xy)/dx = y = 3 and d(xy)/dy = x = 2
        assert_eq!(product.differentiate("x").evaluate(&env), 3.0);
        assert_eq!(product.differentiate("y").evaluate(&env), 2.0);

        // d(x^3)/dx = 3x^2 = 12
        let cubed = Exp::power(Exp::var("x"), 3.0);
        assert_eq!(cubed.differentiate("x").evaluate(&env), 12.0);
    }

    #[test]
    fn test_simplify() {
        let x = Exp::var("x");

        assert_eq!(Exp::add(Exp::val(2.0), Exp::val(3.0)).simplify(), Exp::val(5.0));
        assert_eq!(Exp::mul(x.clone(), Exp::val(0.0)).simplify(), Exp::val(0.0));
        assert_eq!(Exp::add(x.clone(), Exp::val(0.0)).simplify(), x);
    }
}