1use std::collections::HashMap;
26use std::fmt;
27
/// Error produced by [`Exp::evaluate_checked`] when an expression refers to a variable
/// that has no entry in the supplied bindings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MissingVarError {
    /// Name of the variable that had no binding.
    pub var_name: String,
}

impl fmt::Display for MissingVarError {
    /// Renders the error as `Missing variable '<name>'`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let MissingVarError { var_name } = self;
        write!(f, "Missing variable '{}'", var_name)
    }
}

/// Lets the error be used as a `std::error::Error` trait object (e.g. `Box<dyn Error>`);
/// it wraps no underlying source error.
impl std::error::Error for MissingVarError {}
40
/// A symbolic expression tree over `f64` constants and named variables.
///
/// Child nodes are owned through `Box`. The associated constructors (e.g. [`Exp::add`])
/// do the boxing, so prefer them over building variants by hand.
#[derive(Debug, Clone, PartialEq)]
pub enum Exp {
    /// A numeric constant.
    Val(f64),
    /// A variable, looked up by name in the bindings passed to the evaluators.
    Var(String),
    /// Sum `lhs + rhs`.
    Add(Box<Exp>, Box<Exp>),
    /// Difference `lhs - rhs`.
    Sub(Box<Exp>, Box<Exp>),
    /// Product `lhs * rhs`.
    Mul(Box<Exp>, Box<Exp>),
    /// Quotient `lhs / rhs`.
    Div(Box<Exp>, Box<Exp>),
    /// `base` raised to a constant exponent; the exponent is a plain `f64`, not an expression.
    Power(Box<Exp>, f64),
    /// Arithmetic negation.
    Neg(Box<Exp>),
    /// Sine of the argument (evaluated with `f64::sin`, which takes radians).
    Sin(Box<Exp>),
    /// Cosine of the argument (evaluated with `f64::cos`, which takes radians).
    Cos(Box<Exp>),
    /// Natural logarithm of the argument.
    Ln(Box<Exp>),
    /// Natural exponential `e^arg`.
    Exp(Box<Exp>),
}
56
57#[allow(clippy::self_named_constructors, clippy::should_implement_trait)]
58impl Exp {
59 pub fn var(name: impl Into<String>) -> Self {
60 Exp::Var(name.into())
61 }
62
63 pub fn val(v: f64) -> Self {
64 Exp::Val(v)
65 }
66
67 pub fn add(lhs: Exp, rhs: Exp) -> Self {
68 Exp::Add(Box::new(lhs), Box::new(rhs))
69 }
70
71 pub fn sub(lhs: Exp, rhs: Exp) -> Self {
72 Exp::Sub(Box::new(lhs), Box::new(rhs))
73 }
74
75 pub fn mul(lhs: Exp, rhs: Exp) -> Self {
76 Exp::Mul(Box::new(lhs), Box::new(rhs))
77 }
78
79 pub fn div(lhs: Exp, rhs: Exp) -> Self {
80 Exp::Div(Box::new(lhs), Box::new(rhs))
81 }
82
83 pub fn power(base: Exp, exp: f64) -> Self {
84 Exp::Power(Box::new(base), exp)
85 }
86
87 pub fn neg(exp: Exp) -> Self {
88 Exp::Neg(Box::new(exp))
89 }
90
91 pub fn sin(exp: Exp) -> Self {
92 Exp::Sin(Box::new(exp))
93 }
94
95 pub fn cos(exp: Exp) -> Self {
96 Exp::Cos(Box::new(exp))
97 }
98
99 pub fn ln(exp: Exp) -> Self {
100 Exp::Ln(Box::new(exp))
101 }
102
103 pub fn exp(exp: Exp) -> Self {
104 Exp::Exp(Box::new(exp))
105 }
106
107 pub fn evaluate_checked(
108 &self,
109 vars: &HashMap<String, f64>,
110 ) -> Result<f64, MissingVarError> {
111 match self {
112 Exp::Val(v) => Ok(*v),
113 Exp::Var(name) => vars.get(name).copied().ok_or_else(|| MissingVarError {
114 var_name: name.clone(),
115 }),
116 Exp::Add(l, r) => Ok(l.evaluate_checked(vars)? + r.evaluate_checked(vars)?),
117 Exp::Sub(l, r) => Ok(l.evaluate_checked(vars)? - r.evaluate_checked(vars)?),
118 Exp::Mul(l, r) => Ok(l.evaluate_checked(vars)? * r.evaluate_checked(vars)?),
119 Exp::Div(l, r) => Ok(l.evaluate_checked(vars)? / r.evaluate_checked(vars)?),
120 Exp::Power(base, exp) => Ok(base.evaluate_checked(vars)?.powf(*exp)),
121 Exp::Neg(e) => Ok(-e.evaluate_checked(vars)?),
122 Exp::Sin(e) => Ok(e.evaluate_checked(vars)?.sin()),
123 Exp::Cos(e) => Ok(e.evaluate_checked(vars)?.cos()),
124 Exp::Ln(e) => Ok(e.evaluate_checked(vars)?.ln()),
125 Exp::Exp(e) => Ok(e.evaluate_checked(vars)?.exp()),
126 }
127 }
128
129 pub fn evaluate(&self, vars: &HashMap<String, f64>) -> f64 {
130 match self {
131 Exp::Val(v) => *v,
132 Exp::Var(name) => *vars.get(name).unwrap_or(&0.0),
133 Exp::Add(l, r) => l.evaluate(vars) + r.evaluate(vars),
134 Exp::Sub(l, r) => l.evaluate(vars) - r.evaluate(vars),
135 Exp::Mul(l, r) => l.evaluate(vars) * r.evaluate(vars),
136 Exp::Div(l, r) => l.evaluate(vars) / r.evaluate(vars),
137 Exp::Power(base, exp) => base.evaluate(vars).powf(*exp),
138 Exp::Neg(e) => -e.evaluate(vars),
139 Exp::Sin(e) => e.evaluate(vars).sin(),
140 Exp::Cos(e) => e.evaluate(vars).cos(),
141 Exp::Ln(e) => e.evaluate(vars).ln(),
142 Exp::Exp(e) => e.evaluate(vars).exp(),
143 }
144 }
145
146 pub fn differentiate(&self, var_name: &str) -> Exp {
147 match self {
148 Exp::Val(_) => Exp::Val(0.0),
149 Exp::Var(name) => {
150 if name == var_name {
151 Exp::Val(1.0)
152 } else {
153 Exp::Val(0.0)
154 }
155 }
156 Exp::Add(l, r) => Exp::add(l.differentiate(var_name), r.differentiate(var_name)),
157 Exp::Sub(l, r) => Exp::sub(l.differentiate(var_name), r.differentiate(var_name)),
158 Exp::Mul(l, r) => {
159 let dl = l.differentiate(var_name);
160 let dr = r.differentiate(var_name);
161 Exp::add(Exp::mul(dl, (**r).clone()), Exp::mul((**l).clone(), dr))
162 }
163 Exp::Div(l, r) => {
164 let dl = l.differentiate(var_name);
165 let dr = r.differentiate(var_name);
166 Exp::div(
167 Exp::sub(Exp::mul(dl, (**r).clone()), Exp::mul((**l).clone(), dr)),
168 Exp::power((**r).clone(), 2.0),
169 )
170 }
171 Exp::Power(base, exp) => {
172 let db = base.differentiate(var_name);
173 Exp::mul(
174 Exp::mul(Exp::val(*exp), Exp::power((**base).clone(), exp - 1.0)),
175 db,
176 )
177 }
178 Exp::Neg(e) => Exp::neg(e.differentiate(var_name)),
179 Exp::Sin(e) => {
180 let de = e.differentiate(var_name);
181 Exp::mul(Exp::cos((**e).clone()), de)
182 }
183 Exp::Cos(e) => {
184 let de = e.differentiate(var_name);
185 Exp::neg(Exp::mul(Exp::sin((**e).clone()), de))
186 }
187 Exp::Ln(e) => {
188 let de = e.differentiate(var_name);
189 Exp::div(de, (**e).clone())
190 }
191 Exp::Exp(e) => {
192 let de = e.differentiate(var_name);
193 Exp::mul(Exp::exp((**e).clone()), de)
194 }
195 }
196 }
197
198 pub fn simplify(&self) -> Exp {
199 match self {
200 Exp::Add(l, r) => {
201 let ls = l.simplify();
202 let rs = r.simplify();
203 match (&ls, &rs) {
204 (Exp::Val(lv), Exp::Val(rv)) => Exp::Val(lv + rv),
205 (Exp::Val(0.0), _) => rs,
206 (_, Exp::Val(0.0)) => ls,
207 _ => Exp::add(ls, rs),
208 }
209 }
210 Exp::Sub(l, r) => {
211 let ls = l.simplify();
212 let rs = r.simplify();
213 match (&ls, &rs) {
214 (Exp::Val(lv), Exp::Val(rv)) => Exp::Val(lv - rv),
215 (_, Exp::Val(0.0)) => ls,
216 _ => Exp::sub(ls, rs),
217 }
218 }
219 Exp::Mul(l, r) => {
220 let ls = l.simplify();
221 let rs = r.simplify();
222 match (&ls, &rs) {
223 (Exp::Val(lv), Exp::Val(rv)) => Exp::Val(lv * rv),
224 (Exp::Val(0.0), _) | (_, Exp::Val(0.0)) => Exp::Val(0.0),
225 (Exp::Val(1.0), _) => rs,
226 (_, Exp::Val(1.0)) => ls,
227 _ => Exp::mul(ls, rs),
228 }
229 }
230 Exp::Div(l, r) => {
231 let ls = l.simplify();
232 let rs = r.simplify();
233 match (&ls, &rs) {
234 (Exp::Val(lv), Exp::Val(rv)) if *rv != 0.0 => Exp::Val(lv / rv),
235 (Exp::Val(0.0), _) => Exp::Val(0.0),
236 (_, Exp::Val(1.0)) => ls,
237 _ => Exp::div(ls, rs),
238 }
239 }
240 Exp::Power(base, exp) => {
241 let bs = base.simplify();
242 match &bs {
243 Exp::Val(v) => Exp::Val(v.powf(*exp)),
244 _ if *exp == 0.0 => Exp::Val(1.0),
245 _ if *exp == 1.0 => bs,
246 _ => Exp::power(bs, *exp),
247 }
248 }
249 Exp::Neg(e) => {
250 let es = e.simplify();
251 match &es {
252 Exp::Val(v) => Exp::Val(-v),
253 _ => Exp::neg(es),
254 }
255 }
256 _ => self.clone(),
257 }
258 }
259}
260
261impl fmt::Display for Exp {
262 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263 match self {
264 Exp::Val(v) => write!(f, "{v}"),
265 Exp::Var(name) => write!(f, "{name}"),
266 Exp::Add(l, r) => write!(f, "({l} + {r})"),
267 Exp::Sub(l, r) => write!(f, "({l} - {r})"),
268 Exp::Mul(l, r) => write!(f, "({l} * {r})"),
269 Exp::Div(l, r) => write!(f, "({l} / {r})"),
270 Exp::Power(base, exp) => write!(f, "({base}^{exp})"),
271 Exp::Neg(e) => write!(f, "(-{e})"),
272 Exp::Sin(e) => write!(f, "sin({e})"),
273 Exp::Cos(e) => write!(f, "cos({e})"),
274 Exp::Ln(e) => write!(f, "ln({e})"),
275 Exp::Exp(e) => write!(f, "exp({e})"),
276 }
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Bindings shared by several tests: `x = 2`, `y = 3`.
    fn xy_vars() -> HashMap<String, f64> {
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 2.0);
        vars.insert("y".to_string(), 3.0);
        vars
    }

    #[test]
    fn test_evaluate() {
        let vars = xy_vars();

        let x = Exp::var("x");
        let y = Exp::var("y");

        let expr = Exp::add(Exp::mul(x.clone(), y.clone()), Exp::val(5.0));
        assert_eq!(expr.evaluate(&vars), 11.0);

        let expr2 = Exp::power(x.clone(), 2.0);
        assert_eq!(expr2.evaluate(&vars), 4.0);
    }

    #[test]
    fn test_evaluate_defaults_missing_var_to_zero() {
        let vars = HashMap::new();
        let expr = Exp::add(Exp::var("x"), Exp::val(1.0));
        assert_eq!(expr.evaluate(&vars), 1.0);
    }

    #[test]
    fn test_evaluate_checked_ok() {
        let vars = xy_vars();
        let expr = Exp::add(Exp::mul(Exp::var("x"), Exp::var("y")), Exp::val(5.0));
        assert_eq!(expr.evaluate_checked(&vars), Ok(11.0));
    }

    #[test]
    fn test_evaluate_checked_missing_var() {
        let vars = HashMap::new();
        let x = Exp::var("x");
        let err = x.evaluate_checked(&vars).expect_err("expected missing variable error");
        assert_eq!(err.var_name, "x");
    }

    #[test]
    fn test_evaluate_checked_reports_nested_missing_var() {
        let vars = xy_vars();
        let expr = Exp::sin(Exp::add(Exp::var("x"), Exp::var("z")));
        let err = expr.evaluate_checked(&vars).expect_err("z is unbound");
        assert_eq!(err.var_name, "z");
        assert_eq!(err.to_string(), "Missing variable 'z'");
    }

    #[test]
    fn test_differentiate() {
        let x = Exp::var("x");
        let y = Exp::var("y");

        let expr = Exp::mul(x.clone(), y.clone());
        let dx = expr.differentiate("x");
        let dy = expr.differentiate("y");

        let vars = xy_vars();

        assert_eq!(dx.evaluate(&vars), 3.0);
        assert_eq!(dy.evaluate(&vars), 2.0);

        let expr2 = Exp::power(x.clone(), 3.0);
        let dx2 = expr2.differentiate("x");
        assert_eq!(dx2.evaluate(&vars), 12.0);
    }

    #[test]
    fn test_differentiate_chain_rule() {
        // d/dx sin(2x) = 2 cos(2x), which is exactly 2 at x = 0.
        let expr = Exp::sin(Exp::mul(Exp::val(2.0), Exp::var("x")));
        let mut at_zero = HashMap::new();
        at_zero.insert("x".to_string(), 0.0);
        assert_eq!(expr.differentiate("x").evaluate(&at_zero), 2.0);
    }

    #[test]
    fn test_simplify() {
        let expr = Exp::add(Exp::val(2.0), Exp::val(3.0));
        assert_eq!(expr.simplify(), Exp::val(5.0));

        let x = Exp::var("x");
        let expr2 = Exp::mul(x.clone(), Exp::val(0.0));
        assert_eq!(expr2.simplify(), Exp::val(0.0));

        let expr3 = Exp::add(x.clone(), Exp::val(0.0));
        assert_eq!(expr3.simplify(), x);

        assert_eq!(Exp::mul(Exp::val(1.0), x.clone()).simplify(), x);
        assert_eq!(Exp::sub(x.clone(), Exp::val(0.0)).simplify(), x);
        assert_eq!(Exp::div(x.clone(), Exp::val(1.0)).simplify(), x);
        assert_eq!(Exp::power(x.clone(), 1.0).simplify(), x);
        assert_eq!(Exp::power(x.clone(), 0.0).simplify(), Exp::val(1.0));
        assert_eq!(Exp::neg(Exp::val(2.0)).simplify(), Exp::val(-2.0));
    }

    #[test]
    fn test_display() {
        let x = Exp::var("x");
        let y = Exp::var("y");

        let expr = Exp::add(Exp::mul(x.clone(), y.clone()), Exp::val(5.0));
        assert_eq!(expr.to_string(), "((x * y) + 5)");
        assert_eq!(Exp::power(x.clone(), 0.5).to_string(), "(x^0.5)");
        assert_eq!(Exp::sub(x.clone(), Exp::val(-1.5)).to_string(), "(x - -1.5)");
        assert_eq!(Exp::neg(Exp::sin(x.clone())).to_string(), "(-sin(x))");
        assert_eq!(Exp::div(Exp::ln(x), Exp::exp(y)).to_string(), "(ln(x) / exp(y))");
    }
}