// qudit_expr/expressions/base.rs

1use num::FromPrimitive;
2use num::ToPrimitive;
3use num::bigint::BigInt;
4use num::rational::Ratio;
5use qudit_core::RealScalar;
6use std::collections::HashMap;
7use std::collections::HashSet;
8
9use crate::analysis::simplify;
10
/// Arbitrary-precision rational number (`BigInt` numerator over `BigInt` denominator).
pub type Rational = Ratio<BigInt>;
/// Payload type of [`Expression::Constant`]: an exact arbitrary-precision rational.
pub type Constant = Rational;
13
/// Symbolic real-valued expression tree.
///
/// Leaves are [`Expression::Pi`], named [`Expression::Variable`]s and exact rational
/// [`Expression::Constant`]s; every other variant is an operator node owning its
/// operand(s) through a `Box`.
#[derive(Clone)]
pub enum Expression {
    /// The constant pi.
    Pi,
    /// A free variable, looked up by name when the expression is evaluated.
    Variable(String),
    /// An exact rational constant.
    Constant(Constant),
    /// Negation: `-expr`.
    Neg(Box<Expression>),
    /// Sum: `lhs + rhs`.
    Add(Box<Expression>, Box<Expression>),
    /// Difference: `lhs - rhs`.
    Sub(Box<Expression>, Box<Expression>),
    /// Product: `lhs * rhs`.
    Mul(Box<Expression>, Box<Expression>),
    /// Quotient: `lhs / rhs`.
    Div(Box<Expression>, Box<Expression>),
    /// Power: `lhs` raised to `rhs`.
    Pow(Box<Expression>, Box<Expression>),
    /// Square root of the operand.
    Sqrt(Box<Expression>),
    /// Sine of the operand.
    Sin(Box<Expression>),
    /// Cosine of the operand.
    Cos(Box<Expression>),
}
29
30impl Expression {
31    pub fn zero() -> Self {
32        Expression::Constant(Constant::new(BigInt::from(0), BigInt::from(1)))
33    }
34
35    pub fn one() -> Self {
36        Expression::Constant(Constant::new(BigInt::from(1), BigInt::from(1)))
37    }
38
39    pub fn from_int(n: i64) -> Self {
40        Expression::Constant(Constant::new(BigInt::from(n), BigInt::from(1)))
41    }
42
43    pub fn from_float(f: f64) -> Self {
44        Self::from_float_64(f)
45    }
46
47    pub fn from_float_32(f: f32) -> Self {
48        Expression::Constant(Constant::from_f32(f).unwrap())
49    }
50
51    pub fn from_float_64(f: f64) -> Self {
52        Expression::Constant(Constant::from_f64(f).unwrap())
53    }
54
55    pub fn to_float(&self) -> f64 {
56        match self {
57            Expression::Constant(c) => c.to_f64().unwrap(),
58            Expression::Variable(_) => panic!("Cannot convert variable to float"),
59            Expression::Pi => std::f64::consts::PI,
60            Expression::Neg(expr) => -expr.to_float(),
61            Expression::Add(lhs, rhs) => lhs.to_float() + rhs.to_float(),
62            Expression::Sub(lhs, rhs) => lhs.to_float() - rhs.to_float(),
63            Expression::Mul(lhs, rhs) => lhs.to_float() * rhs.to_float(),
64            Expression::Div(lhs, rhs) => lhs.to_float() / rhs.to_float(),
65            Expression::Pow(lhs, rhs) => lhs.to_float().powf(rhs.to_float()),
66            Expression::Sqrt(expr) => expr.to_float().sqrt(),
67            Expression::Sin(expr) => expr.to_float().sin(),
68            Expression::Cos(expr) => expr.to_float().cos(),
69        }
70    }
71
72    pub fn to_constant(&self) -> Constant {
73        // TODO: Figure out how to maintain precision by doing math over Constants
74        Constant::from_float(self.to_float()).unwrap()
75    }
76
77    pub fn gather_context(&self) -> HashSet<String> {
78        let mut context = HashSet::new();
79        context.insert(self.to_string());
80        match self {
81            Expression::Pi => {
82                context.insert("pi".to_string());
83            }
84            Expression::Variable(var) => {
85                context.insert(var.clone());
86            }
87            Expression::Constant(_) => {
88                context.insert(self.to_string());
89                context.insert(self.to_float().to_string());
90            }
91            Expression::Neg(expr) => {
92                context.extend(expr.gather_context());
93            }
94            Expression::Add(lhs, rhs) => {
95                context.extend(lhs.gather_context());
96                context.extend(rhs.gather_context());
97            }
98            Expression::Sub(lhs, rhs) => {
99                context.extend(lhs.gather_context());
100                context.extend(rhs.gather_context());
101            }
102            Expression::Mul(lhs, rhs) => {
103                context.extend(lhs.gather_context());
104                context.extend(rhs.gather_context());
105            }
106            Expression::Div(lhs, rhs) => {
107                context.extend(lhs.gather_context());
108                context.extend(rhs.gather_context());
109            }
110            Expression::Pow(lhs, rhs) => {
111                context.extend(lhs.gather_context());
112                context.extend(rhs.gather_context());
113            }
114            Expression::Sqrt(expr) => {
115                context.extend(expr.gather_context());
116            }
117            Expression::Sin(expr) => {
118                context.extend(expr.gather_context());
119            }
120            Expression::Cos(expr) => {
121                context.extend(expr.gather_context());
122            }
123        }
124        context
125    }
126
127    pub fn is_zero(&self) -> bool {
128        match self {
129            Expression::Constant(c) => *c.numer() == BigInt::from(0),
130            Expression::Neg(expr) => expr.is_zero(),
131            Expression::Add(lhs, rhs) => lhs.is_zero() && rhs.is_zero(),
132            Expression::Sub(lhs, rhs) => (lhs.is_zero() && rhs.is_zero()) || lhs == rhs,
133            Expression::Mul(lhs, rhs) => lhs.is_zero() || rhs.is_zero(),
134            Expression::Div(lhs, _) => lhs.is_zero(),
135            Expression::Pow(lhs, rhs) => lhs.is_zero() && !rhs.is_zero(),
136            Expression::Sqrt(expr) => expr.is_zero(),
137            Expression::Sin(expr) => expr.is_zero(),
138            Expression::Cos(expr) => {
139                !expr.is_parameterized()
140                    && (expr.eval::<f64>(&HashMap::new()) - std::f64::consts::PI / 2.0) < 1e-6
141            }
142            Expression::Pi => false,
143            Expression::Variable(_) => false,
144        }
145    }
146
147    /// Conservative check for zero. This is faster than the exact check.
148    pub fn is_zero_fast(&self) -> bool {
149        match self {
150            Expression::Constant(c) => *c.numer() == BigInt::from(0),
151            Expression::Neg(expr) => expr.is_zero_fast(),
152            Expression::Add(lhs, rhs) => lhs.is_zero_fast() && rhs.is_zero_fast(),
153            Expression::Sub(lhs, rhs) => lhs.is_zero_fast() && rhs.is_zero_fast(),
154            Expression::Mul(lhs, rhs) => lhs.is_zero_fast() || rhs.is_zero_fast(),
155            Expression::Div(lhs, _) => lhs.is_zero_fast(),
156            Expression::Pow(lhs, rhs) => lhs.is_zero_fast() && !rhs.is_zero_fast(),
157            Expression::Sqrt(expr) => expr.is_zero_fast(),
158            Expression::Sin(expr) => expr.is_zero_fast(),
159            Expression::Cos(_expr) => false,
160            Expression::Pi => false,
161            Expression::Variable(_) => false,
162        }
163    }
164
165    pub fn is_one(&self) -> bool {
166        match self {
167            Expression::Constant(c) => *c.numer() == *c.denom(),
168            Expression::Neg(expr) => {
169                !expr.is_parameterized() && expr.eval::<f64>(&HashMap::new()) == -1.0
170            }
171            Expression::Add(lhs, rhs) => {
172                lhs.is_one() && rhs.is_zero() || lhs.is_zero() && rhs.is_one()
173            }
174            Expression::Sub(lhs, rhs) => lhs.is_one() && rhs.is_zero(),
175            Expression::Mul(lhs, rhs) => lhs.is_one() && rhs.is_one(),
176            Expression::Div(lhs, rhs) => lhs == rhs && !rhs.is_zero(),
177            Expression::Pow(lhs, _rhs) => lhs.is_one(),
178            Expression::Sqrt(expr) => expr.is_one(),
179            Expression::Sin(expr) => {
180                !expr.is_parameterized()
181                    && (expr.eval::<f64>(&HashMap::new()) - std::f64::consts::PI / 2.0) < 1e-6
182            }
183            Expression::Cos(expr) => expr.is_zero(),
184            Expression::Pi => false,
185            Expression::Variable(_) => false,
186        }
187    }
188
189    pub fn is_one_fast(&self) -> bool {
190        match self {
191            Expression::Constant(c) => *c.numer() == *c.denom(),
192            Expression::Neg(_expr) => false,
193            Expression::Add(lhs, rhs) => {
194                lhs.is_one_fast() && rhs.is_zero_fast() || lhs.is_zero_fast() && rhs.is_one_fast()
195            }
196            Expression::Sub(lhs, rhs) => lhs.is_one_fast() && rhs.is_zero_fast(),
197            Expression::Mul(lhs, rhs) => lhs.is_one_fast() && rhs.is_one_fast(),
198            Expression::Div(lhs, rhs) => lhs.is_one_fast() && rhs.is_one_fast(),
199            Expression::Pow(lhs, _rhs) => lhs.is_one_fast(),
200            Expression::Sqrt(expr) => expr.is_one_fast(),
201            Expression::Sin(_expr) => false,
202            Expression::Cos(expr) => expr.is_zero_fast(),
203            Expression::Pi => false,
204            Expression::Variable(_) => false,
205        }
206    }
207
208    pub fn contains_variable<T: AsRef<str>>(&self, var: T) -> bool {
209        let var = var.as_ref();
210        match self {
211            Expression::Pi => false,
212            Expression::Variable(v) => v == var,
213            Expression::Constant(_) => false,
214            Expression::Neg(expr) => expr.contains_variable(var),
215            Expression::Add(lhs, rhs) => lhs.contains_variable(var) || rhs.contains_variable(var),
216            Expression::Sub(lhs, rhs) => lhs.contains_variable(var) || rhs.contains_variable(var),
217            Expression::Mul(lhs, rhs) => lhs.contains_variable(var) || rhs.contains_variable(var),
218            Expression::Div(lhs, rhs) => lhs.contains_variable(var) || rhs.contains_variable(var),
219            Expression::Pow(lhs, rhs) => lhs.contains_variable(var) || rhs.contains_variable(var),
220            Expression::Sqrt(expr) => expr.contains_variable(var),
221            Expression::Sin(expr) => expr.contains_variable(var),
222            Expression::Cos(expr) => expr.contains_variable(var),
223        }
224    }
225
226    pub fn is_parameterized(&self) -> bool {
227        match self {
228            Expression::Pi => false,
229            Expression::Variable(_) => true,
230            Expression::Constant(_) => false,
231            Expression::Neg(expr) => expr.is_parameterized(),
232            Expression::Add(lhs, rhs) => lhs.is_parameterized() || rhs.is_parameterized(),
233            Expression::Sub(lhs, rhs) => lhs.is_parameterized() || rhs.is_parameterized(),
234            Expression::Mul(lhs, rhs) => lhs.is_parameterized() || rhs.is_parameterized(),
235            Expression::Div(lhs, rhs) => lhs.is_parameterized() || rhs.is_parameterized(),
236            Expression::Pow(lhs, rhs) => lhs.is_parameterized() || rhs.is_parameterized(),
237            Expression::Sqrt(expr) => expr.is_parameterized(),
238            Expression::Sin(expr) => expr.is_parameterized(),
239            Expression::Cos(expr) => expr.is_parameterized(),
240        }
241    }
242
243    pub fn eval<R: RealScalar>(&self, args: &HashMap<&str, R>) -> R {
244        match self {
245            Expression::Pi => R::PI(),
246            Expression::Variable(var) => {
247                if let Some(val) = args.get(var.as_str()) {
248                    *val
249                } else {
250                    panic!("Variable {} not found in arguments", var)
251                }
252            }
253            Expression::Constant(c) => R::from_ratio(c.clone()).unwrap(),
254            Expression::Neg(expr) => -expr.eval(args),
255            Expression::Add(lhs, rhs) => lhs.eval(args) + rhs.eval(args),
256            Expression::Sub(lhs, rhs) => lhs.eval(args) - rhs.eval(args),
257            Expression::Mul(lhs, rhs) => lhs.eval(args) * rhs.eval(args),
258            Expression::Div(lhs, rhs) => lhs.eval(args) / rhs.eval(args),
259            Expression::Pow(lhs, rhs) => lhs.eval(args).powf(rhs.eval(args)),
260            Expression::Sqrt(expr) => expr.eval(args).sqrt(),
261            Expression::Sin(expr) => expr.eval(args).sin(),
262            Expression::Cos(expr) => expr.eval(args).cos(),
263        }
264    }
265
266    /// Uses a magic value to evaluate the expression. This is useful for hashing expressions.
267    pub fn hash_eval(&self) -> f64 {
268        let val = match self {
269            Expression::Pi => self.to_float(),
270            Expression::Variable(_) => 1.7,
271            Expression::Constant(_) => self.to_float(),
272            Expression::Neg(expr) => -expr.hash_eval(),
273            Expression::Add(lhs, rhs) => lhs.hash_eval() + rhs.hash_eval(),
274            Expression::Sub(lhs, rhs) => lhs.hash_eval() - rhs.hash_eval(),
275            Expression::Mul(lhs, rhs) => lhs.hash_eval() * rhs.hash_eval(),
276            Expression::Div(lhs, rhs) => lhs.hash_eval() / rhs.hash_eval(),
277            Expression::Pow(lhs, rhs) => lhs.hash_eval().powf(rhs.hash_eval()),
278            Expression::Sqrt(expr) => expr.hash_eval().sqrt(),
279            Expression::Sin(expr) => expr.hash_eval().sin(),
280            Expression::Cos(expr) => expr.hash_eval().cos(),
281        };
282
283        if val.is_nan() || val.is_subnormal() {
284            0.0
285        } else {
286            val
287        }
288    }
289
290    pub fn map_var_names(&self, var_map: &HashMap<String, String>) -> Self {
291        match self {
292            Expression::Pi => Expression::Pi,
293            Expression::Variable(var) => {
294                if let Some(new_var) = var_map.get(var.as_str()) {
295                    Expression::Variable(new_var.to_string())
296                } else {
297                    Expression::Variable(var.clone())
298                }
299            }
300            Expression::Constant(c) => Expression::Constant(c.clone()),
301            Expression::Neg(expr) => Expression::Neg(Box::new(expr.map_var_names(var_map))),
302            Expression::Add(lhs, rhs) => Expression::Add(
303                Box::new(lhs.map_var_names(var_map)),
304                Box::new(rhs.map_var_names(var_map)),
305            ),
306            Expression::Sub(lhs, rhs) => Expression::Sub(
307                Box::new(lhs.map_var_names(var_map)),
308                Box::new(rhs.map_var_names(var_map)),
309            ),
310            Expression::Mul(lhs, rhs) => Expression::Mul(
311                Box::new(lhs.map_var_names(var_map)),
312                Box::new(rhs.map_var_names(var_map)),
313            ),
314            Expression::Div(lhs, rhs) => Expression::Div(
315                Box::new(lhs.map_var_names(var_map)),
316                Box::new(rhs.map_var_names(var_map)),
317            ),
318            Expression::Pow(lhs, rhs) => Expression::Pow(
319                Box::new(lhs.map_var_names(var_map)),
320                Box::new(rhs.map_var_names(var_map)),
321            ),
322            Expression::Sqrt(expr) => Expression::Sqrt(Box::new(expr.map_var_names(var_map))),
323            Expression::Sin(expr) => Expression::Sin(Box::new(expr.map_var_names(var_map))),
324            Expression::Cos(expr) => Expression::Cos(Box::new(expr.map_var_names(var_map))),
325        }
326    }
327
328    pub fn rename_variable<S: AsRef<str>, T: AsRef<str>>(&self, original: S, new: T) -> Self {
329        let original = original.as_ref();
330        let new = new.as_ref();
331        match self {
332            Expression::Pi => Expression::Pi,
333            Expression::Variable(var) => {
334                if var == original {
335                    Expression::Variable(new.to_string())
336                } else {
337                    Expression::Variable(var.clone())
338                }
339            }
340            Expression::Constant(c) => Expression::Constant(c.clone()),
341            Expression::Neg(expr) => Expression::Neg(Box::new(expr.rename_variable(original, new))),
342            Expression::Add(lhs, rhs) => Expression::Add(
343                Box::new(lhs.rename_variable(original, new)),
344                Box::new(rhs.rename_variable(original, new)),
345            ),
346            Expression::Sub(lhs, rhs) => Expression::Sub(
347                Box::new(lhs.rename_variable(original, new)),
348                Box::new(rhs.rename_variable(original, new)),
349            ),
350            Expression::Mul(lhs, rhs) => Expression::Mul(
351                Box::new(lhs.rename_variable(original, new)),
352                Box::new(rhs.rename_variable(original, new)),
353            ),
354            Expression::Div(lhs, rhs) => Expression::Div(
355                Box::new(lhs.rename_variable(original, new)),
356                Box::new(rhs.rename_variable(original, new)),
357            ),
358            Expression::Pow(lhs, rhs) => Expression::Pow(
359                Box::new(lhs.rename_variable(original, new)),
360                Box::new(rhs.rename_variable(original, new)),
361            ),
362            Expression::Sqrt(expr) => {
363                Expression::Sqrt(Box::new(expr.rename_variable(original, new)))
364            }
365            Expression::Sin(expr) => Expression::Sin(Box::new(expr.rename_variable(original, new))),
366            Expression::Cos(expr) => Expression::Cos(Box::new(expr.rename_variable(original, new))),
367        }
368    }
369
370    pub fn differentiate<S: AsRef<str>>(&self, wrt: S) -> Self {
371        let wrt = wrt.as_ref();
372        match self {
373            Expression::Pi => Expression::zero(),
374            Expression::Variable(var) => {
375                if var == wrt {
376                    Expression::one()
377                } else {
378                    Expression::zero()
379                }
380            }
381            Expression::Constant(_) => Expression::zero(),
382            Expression::Neg(expr) => Expression::Neg(Box::new(expr.differentiate(wrt))),
383            Expression::Add(lhs, rhs) => Expression::Add(
384                Box::new(lhs.differentiate(wrt)),
385                Box::new(rhs.differentiate(wrt)),
386            ),
387            Expression::Sub(lhs, rhs) => Expression::Sub(
388                Box::new(lhs.differentiate(wrt)),
389                Box::new(rhs.differentiate(wrt)),
390            ),
391            Expression::Mul(lhs, rhs) => {
392                lhs.differentiate(wrt) * *rhs.clone() + *lhs.clone() * rhs.differentiate(wrt)
393            }
394            Expression::Div(lhs, rhs) => {
395                (lhs.differentiate(wrt) * *rhs.clone() - *lhs.clone() * rhs.differentiate(wrt))
396                    / (*rhs.clone() * *rhs.clone())
397            }
398            Expression::Pow(lhs, rhs) => {
399                let base_fn_x = lhs.contains_variable(wrt);
400                let exponent_fn_x = rhs.contains_variable(wrt);
401
402                if !base_fn_x && !exponent_fn_x {
403                    Expression::zero()
404                } else if !base_fn_x && exponent_fn_x {
405                    if lhs.is_parameterized() {
406                        todo!(
407                            "Cannot differentiate with respect to a parameterized power base until ln is implemented"
408                        )
409                    } else {
410                        self.clone()
411                            * rhs.differentiate(wrt)
412                            * Expression::from_float(lhs.eval::<f64>(&HashMap::new()).ln())
413                    }
414                } else if base_fn_x && !exponent_fn_x {
415                    *rhs.clone()
416                        * Expression::Pow(
417                            Box::new(*lhs.clone()),
418                            Box::new(*rhs.clone() - Expression::one()),
419                        )
420                        * lhs.differentiate(wrt)
421                } else {
422                    todo!(
423                        "Cannot differentiate with respect to a parameterized base and exponent until ln is implemented"
424                    )
425                }
426            }
427            Expression::Sqrt(expr) => {
428                let two = Expression::from_int(2);
429                (Expression::one() / (two * self.clone())) * expr.differentiate(wrt)
430            }
431            Expression::Sin(expr) => {
432                Expression::Cos(Box::new(*expr.clone())) * expr.differentiate(wrt)
433            }
434            Expression::Cos(expr) => {
435                Expression::Neg(Box::new(Expression::Sin(Box::new(*expr.clone()))))
436                    * expr.differentiate(wrt)
437            }
438        }
439    }
440
441    pub fn get_ancestors<S: AsRef<str>>(&self, variable: S) -> Vec<Expression> {
442        let variable = variable.as_ref();
443        let mut ancestors = Vec::new();
444        match self {
445            Expression::Pi => {}
446            Expression::Variable(var) => {
447                if var == variable {
448                    ancestors.push(self.clone());
449                }
450            }
451            Expression::Constant(_) => {}
452            Expression::Neg(expr) => {
453                let node_ancsestors = expr.get_ancestors(variable);
454                let is_empty = node_ancsestors.is_empty();
455                ancestors.extend(node_ancsestors);
456                if !is_empty {
457                    ancestors.push(self.clone());
458                }
459            }
460            Expression::Add(lhs, rhs) => {
461                let lhs_ancestors = lhs.get_ancestors(variable);
462                let rhs_ancestors = rhs.get_ancestors(variable);
463                let is_empty = lhs_ancestors.is_empty() && rhs_ancestors.is_empty();
464                ancestors.extend(lhs_ancestors);
465                ancestors.extend(rhs_ancestors);
466                if !is_empty {
467                    ancestors.push(self.clone());
468                }
469            }
470            Expression::Sub(lhs, rhs) => {
471                let lhs_ancestors = lhs.get_ancestors(variable);
472                let rhs_ancestors = rhs.get_ancestors(variable);
473                let is_empty = lhs_ancestors.is_empty() && rhs_ancestors.is_empty();
474                ancestors.extend(lhs_ancestors);
475                ancestors.extend(rhs_ancestors);
476                if !is_empty {
477                    ancestors.push(self.clone());
478                }
479            }
480            Expression::Mul(lhs, rhs) => {
481                let lhs_ancestors = lhs.get_ancestors(variable);
482                let rhs_ancestors = rhs.get_ancestors(variable);
483                let is_empty = lhs_ancestors.is_empty() && rhs_ancestors.is_empty();
484                ancestors.extend(lhs_ancestors);
485                ancestors.extend(rhs_ancestors);
486                if !is_empty {
487                    ancestors.push(self.clone());
488                }
489            }
490            Expression::Div(lhs, rhs) => {
491                let lhs_ancestors = lhs.get_ancestors(variable);
492                let rhs_ancestors = rhs.get_ancestors(variable);
493                let is_empty = lhs_ancestors.is_empty() && rhs_ancestors.is_empty();
494                ancestors.extend(lhs_ancestors);
495                ancestors.extend(rhs_ancestors);
496                if !is_empty {
497                    ancestors.push(self.clone());
498                }
499            }
500            Expression::Pow(lhs, rhs) => {
501                let lhs_ancestors = lhs.get_ancestors(variable);
502                let rhs_ancestors = rhs.get_ancestors(variable);
503                let is_empty = lhs_ancestors.is_empty() && rhs_ancestors.is_empty();
504                ancestors.extend(lhs_ancestors);
505                ancestors.extend(rhs_ancestors);
506                if !is_empty {
507                    ancestors.push(self.clone());
508                }
509            }
510            Expression::Sqrt(expr) => {
511                let node_ancsestors = expr.get_ancestors(variable);
512                let is_empty = node_ancsestors.is_empty();
513                ancestors.extend(node_ancsestors);
514                if !is_empty {
515                    ancestors.push(self.clone());
516                }
517            }
518            Expression::Sin(expr) => {
519                let node_ancsestors = expr.get_ancestors(variable);
520                let is_empty = node_ancsestors.is_empty();
521                ancestors.extend(node_ancsestors);
522                if !is_empty {
523                    ancestors.push(self.clone());
524                }
525            }
526            Expression::Cos(expr) => {
527                let node_ancsestors = expr.get_ancestors(variable);
528                let is_empty = node_ancsestors.is_empty();
529                ancestors.extend(node_ancsestors);
530                if !is_empty {
531                    ancestors.push(self.clone());
532                }
533            }
534        }
535        ancestors
536    }
537
538    pub fn fast_eq(&self, other: &Expression) -> bool {
539        match (self, other) {
540            (Expression::Pi, Expression::Pi) => true,
541            (Expression::Variable(var1), Expression::Variable(var2)) => var1 == var2,
542            (Expression::Constant(c1), Expression::Constant(c2)) => c1 == c2,
543            (Expression::Neg(expr1), Expression::Neg(expr2)) => expr1.fast_eq(expr2),
544            (Expression::Add(lhs1, rhs1), Expression::Add(lhs2, rhs2)) => {
545                (lhs1.fast_eq(lhs2) && rhs1.fast_eq(rhs2))
546                    || (lhs1.fast_eq(rhs2) && rhs1.fast_eq(lhs2))
547            }
548            (Expression::Sub(lhs1, rhs1), Expression::Sub(lhs2, rhs2)) => {
549                (lhs1.fast_eq(lhs2) && rhs1.fast_eq(rhs2))
550                    || (lhs1.fast_eq(rhs2) && rhs1.fast_eq(lhs2))
551            }
552            (Expression::Mul(lhs1, rhs1), Expression::Mul(lhs2, rhs2)) => {
553                (lhs1.fast_eq(lhs2) && rhs1.fast_eq(rhs2))
554                    || (lhs1.fast_eq(rhs2) && rhs1.fast_eq(lhs2))
555            }
556            (Expression::Div(lhs1, rhs1), Expression::Div(lhs2, rhs2)) => {
557                (lhs1.fast_eq(lhs2) && rhs1.fast_eq(rhs2))
558                    || (lhs1.fast_eq(rhs2) && rhs1.fast_eq(lhs2))
559            }
560            (Expression::Pow(lhs1, rhs1), Expression::Pow(lhs2, rhs2)) => {
561                (lhs1.fast_eq(lhs2) && rhs1.fast_eq(rhs2))
562                    || (lhs1.fast_eq(rhs2) && rhs1.fast_eq(lhs2))
563            }
564            (Expression::Sqrt(expr1), Expression::Sqrt(expr2)) => expr1.fast_eq(expr2),
565            (Expression::Sin(expr1), Expression::Sin(expr2)) => expr1.fast_eq(expr2),
566            (Expression::Cos(expr1), Expression::Cos(expr2)) => expr1.fast_eq(expr2),
567            _ => false,
568        }
569    }
570
571    pub fn substitute<S: AsRef<Expression>, T: AsRef<Expression>>(
572        &self,
573        original: S,
574        substitution: T,
575    ) -> Self {
576        let original = original.as_ref();
577        let substitution = substitution.as_ref();
578        if self.fast_eq(original) {
579            return substitution.clone();
580        }
581        match self {
582            Expression::Pi => self.clone(),
583            Expression::Variable(_) => self.clone(),
584            Expression::Constant(_) => self.clone(),
585            Expression::Neg(expr) => {
586                Expression::Neg(Box::new(expr.substitute(original, substitution)))
587            }
588            Expression::Add(lhs, rhs) => Expression::Add(
589                Box::new(lhs.substitute(original, substitution)),
590                Box::new(rhs.substitute(original, substitution)),
591            ),
592            Expression::Sub(lhs, rhs) => Expression::Sub(
593                Box::new(lhs.substitute(original, substitution)),
594                Box::new(rhs.substitute(original, substitution)),
595            ),
596            Expression::Mul(lhs, rhs) => Expression::Mul(
597                Box::new(lhs.substitute(original, substitution)),
598                Box::new(rhs.substitute(original, substitution)),
599            ),
600            Expression::Div(lhs, rhs) => Expression::Div(
601                Box::new(lhs.substitute(original, substitution)),
602                Box::new(rhs.substitute(original, substitution)),
603            ),
604            Expression::Pow(lhs, rhs) => Expression::Pow(
605                Box::new(lhs.substitute(original, substitution)),
606                Box::new(rhs.substitute(original, substitution)),
607            ),
608            Expression::Sqrt(expr) => {
609                Expression::Sqrt(Box::new(expr.substitute(original, substitution)))
610            }
611            Expression::Sin(expr) => {
612                Expression::Sin(Box::new(expr.substitute(original, substitution)))
613            }
614            Expression::Cos(expr) => {
615                Expression::Cos(Box::new(expr.substitute(original, substitution)))
616            }
617        }
618    }
619
620    pub fn simplify(&self) -> Self {
621        simplify(self)
622    }
623
624    pub fn get_unique_variables(&self) -> Vec<String> {
625        match self {
626            Expression::Pi => {
627                vec![]
628            }
629            Expression::Variable(s) => {
630                vec![s.clone()]
631            }
632            Expression::Constant(_) => {
633                vec![]
634            }
635            Expression::Neg(expr) => expr.get_unique_variables(),
636            Expression::Add(lhs, rhs) => {
637                let mut l = lhs.get_unique_variables();
638                for r in rhs.get_unique_variables().into_iter() {
639                    if !l.contains(&r) {
640                        l.push(r)
641                    }
642                }
643                l
644            }
645            Expression::Sub(lhs, rhs) => {
646                let mut l = lhs.get_unique_variables();
647                for r in rhs.get_unique_variables().into_iter() {
648                    if !l.contains(&r) {
649                        l.push(r)
650                    }
651                }
652                l
653            }
654            Expression::Mul(lhs, rhs) => {
655                let mut l = lhs.get_unique_variables();
656                for r in rhs.get_unique_variables().into_iter() {
657                    if !l.contains(&r) {
658                        l.push(r)
659                    }
660                }
661                l
662            }
663            Expression::Div(lhs, rhs) => {
664                let mut l = lhs.get_unique_variables();
665                for r in rhs.get_unique_variables().into_iter() {
666                    if !l.contains(&r) {
667                        l.push(r)
668                    }
669                }
670                l
671            }
672            Expression::Pow(lhs, rhs) => {
673                let mut l = lhs.get_unique_variables();
674                for r in rhs.get_unique_variables().into_iter() {
675                    if !l.contains(&r) {
676                        l.push(r)
677                    }
678                }
679                l
680            }
681            Expression::Sqrt(expr) => expr.get_unique_variables(),
682            Expression::Sin(expr) => expr.get_unique_variables(),
683            Expression::Cos(expr) => expr.get_unique_variables(),
684        }
685    }
686}
687
688impl std::ops::Add<Expression> for Expression {
689    type Output = Self;
690
691    fn add(self, other: Self) -> Self {
692        &self + &other
693    }
694}
695
696impl std::ops::Add<&Expression> for Expression {
697    type Output = Expression;
698
699    fn add(self, other: &Expression) -> Expression {
700        &self + other
701    }
702}
703
704impl std::ops::Add<Expression> for &Expression {
705    type Output = Expression;
706
707    fn add(self, other: Expression) -> Expression {
708        self + &other
709    }
710}
711
712impl std::ops::Add<&Expression> for &Expression {
713    type Output = Expression;
714
715    fn add(self, other: &Expression) -> Expression {
716        if let Expression::Constant(c1) = self
717            && let Expression::Constant(c2) = other
718        {
719            return Expression::Constant(c1 + c2);
720        }
721        if other.is_zero_fast() {
722            self.clone()
723        } else if self.is_zero_fast() {
724            other.clone()
725        } else {
726            Expression::Add(Box::new(self.clone()), Box::new(other.clone()))
727        }
728    }
729}
730
731impl std::ops::Sub<Expression> for Expression {
732    type Output = Self;
733
734    fn sub(self, other: Self) -> Self {
735        &self - &other
736    }
737}
738
739impl std::ops::Sub<&Expression> for Expression {
740    type Output = Expression;
741
742    fn sub(self, other: &Expression) -> Expression {
743        &self - other
744    }
745}
746
747impl std::ops::Sub<Expression> for &Expression {
748    type Output = Expression;
749
750    fn sub(self, other: Expression) -> Expression {
751        self - &other
752    }
753}
754
755impl std::ops::Sub<&Expression> for &Expression {
756    type Output = Expression;
757
758    fn sub(self, other: &Expression) -> Expression {
759        if let Expression::Constant(c1) = self
760            && let Expression::Constant(c2) = other
761        {
762            return Expression::Constant(c1 - c2);
763        }
764        if other.is_zero_fast() {
765            self.clone()
766        } else if self.is_zero_fast() {
767            -other.clone()
768        } else {
769            Expression::Sub(Box::new(self.clone()), Box::new(other.clone()))
770        }
771    }
772}
773
774impl std::ops::Mul<Expression> for Expression {
775    type Output = Self;
776
777    fn mul(self, other: Self) -> Self {
778        &self * &other
779    }
780}
781
782impl std::ops::Mul<&Expression> for Expression {
783    type Output = Expression;
784
785    fn mul(self, other: &Expression) -> Expression {
786        &self * other
787    }
788}
789
790impl std::ops::Mul<Expression> for &Expression {
791    type Output = Expression;
792
793    fn mul(self, other: Expression) -> Expression {
794        self * &other
795    }
796}
797
798impl std::ops::Mul<&Expression> for &Expression {
799    type Output = Expression;
800
801    fn mul(self, other: &Expression) -> Expression {
802        if let Expression::Constant(c1) = self
803            && let Expression::Constant(c2) = other
804        {
805            return Expression::Constant(c1 * c2);
806        }
807        if other.is_zero_fast() || self.is_zero_fast() {
808            Expression::zero()
809        } else if other.is_one_fast() {
810            self.clone()
811        } else if self.is_one_fast() {
812            other.clone()
813        } else {
814            Expression::Mul(Box::new(self.clone()), Box::new(other.clone()))
815        }
816    }
817}
818
819impl std::ops::Div<Expression> for Expression {
820    type Output = Self;
821
822    fn div(self, other: Self) -> Self {
823        &self / &other
824    }
825}
826
827impl std::ops::Div<&Expression> for Expression {
828    type Output = Expression;
829
830    fn div(self, other: &Expression) -> Expression {
831        &self / other
832    }
833}
834
835impl std::ops::Div<Expression> for &Expression {
836    type Output = Expression;
837
838    fn div(self, other: Expression) -> Expression {
839        self / &other
840    }
841}
842
843impl std::ops::Div<&Expression> for &Expression {
844    type Output = Expression;
845
846    fn div(self, other: &Expression) -> Expression {
847        if other.is_zero_fast() {
848            panic!("Cannot divide by zero")
849        } else if let (Expression::Constant(c1), Expression::Constant(c2)) = (self, other) {
850            Expression::Constant(c1 / c2)
851        } else if self.is_zero_fast() {
852            Expression::zero()
853        } else {
854            Expression::Div(Box::new(self.clone()), Box::new(other.clone()))
855        }
856    }
857}
858
859impl std::ops::Neg for Expression {
860    type Output = Self;
861
862    fn neg(self) -> Self {
863        -&self
864    }
865}
866
867impl std::ops::Neg for &Expression {
868    type Output = Expression;
869
870    fn neg(self) -> Expression {
871        if self.is_zero_fast() {
872            self.clone()
873        } else {
874            Expression::Neg(Box::new(self.clone()))
875        }
876    }
877}
878
879impl std::fmt::Debug for Expression {
880    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
881        write!(f, "{}", self)
882    }
883}
884
885impl PartialEq for Expression {
886    fn eq(&self, other: &Self) -> bool {
887        self.fast_eq(other)
888        // if self.fast_eq(other) {
889        //     return true;
890        // }
891        // check_equality(self, other)
892    }
893}
894
895impl Eq for Expression {}
896
897impl std::hash::Hash for Expression {
898    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
899        let val = self.hash_eval();
900        (val * 1e5_f64).round().to_bits().hash(state);
901    }
902}
903
904impl AsRef<Expression> for Expression {
905    fn as_ref(&self) -> &Expression {
906        self
907    }
908}
909
910impl std::fmt::Display for Expression {
911    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
912        let inner = match self {
913            Expression::Pi => "pi".to_string(),
914            Expression::Variable(var) => var.clone(),
915            Expression::Constant(_c) => self.to_float().to_string(),
916            Expression::Neg(expr) => format!("~ {}", expr),
917            Expression::Add(lhs, rhs) => format!("+ {} {}", lhs, rhs),
918            Expression::Sub(lhs, rhs) => format!("- {} {}", lhs, rhs),
919            Expression::Mul(lhs, rhs) => format!("* {} {}", lhs, rhs),
920            Expression::Div(lhs, rhs) => format!("/ {} {}", lhs, rhs),
921            Expression::Pow(lhs, rhs) => format!("pow {} {}", lhs, rhs),
922            Expression::Sqrt(expr) => format!("sqrt {}", expr),
923            Expression::Sin(expr) => format!("sin {}", expr),
924            Expression::Cos(expr) => format!("cos {}", expr),
925        };
926        write!(f, "({})", inner)
927    }
928}
929
930impl<R: RealScalar> From<R> for Expression {
931    fn from(value: R) -> Self {
932        Expression::from_float(value.to64())
933    }
934}