1use num::FromPrimitive;
2use num::ToPrimitive;
3use num::bigint::BigInt;
4use num::rational::Ratio;
5use qudit_core::RealScalar;
6use std::collections::HashMap;
7use std::collections::HashSet;
8
9use crate::analysis::simplify;
10
11pub type Rational = Ratio<BigInt>;
12pub type Constant = Rational;
13
14#[derive(Clone)]
15pub enum Expression {
16 Pi,
17 Variable(String),
18 Constant(Constant),
19 Neg(Box<Expression>),
20 Add(Box<Expression>, Box<Expression>),
21 Sub(Box<Expression>, Box<Expression>),
22 Mul(Box<Expression>, Box<Expression>),
23 Div(Box<Expression>, Box<Expression>),
24 Pow(Box<Expression>, Box<Expression>),
25 Sqrt(Box<Expression>),
26 Sin(Box<Expression>),
27 Cos(Box<Expression>),
28}
29
30impl Expression {
31 pub fn zero() -> Self {
32 Expression::Constant(Constant::new(BigInt::from(0), BigInt::from(1)))
33 }
34
35 pub fn one() -> Self {
36 Expression::Constant(Constant::new(BigInt::from(1), BigInt::from(1)))
37 }
38
39 pub fn from_int(n: i64) -> Self {
40 Expression::Constant(Constant::new(BigInt::from(n), BigInt::from(1)))
41 }
42
43 pub fn from_float(f: f64) -> Self {
44 Self::from_float_64(f)
45 }
46
47 pub fn from_float_32(f: f32) -> Self {
48 Expression::Constant(Constant::from_f32(f).unwrap())
49 }
50
51 pub fn from_float_64(f: f64) -> Self {
52 Expression::Constant(Constant::from_f64(f).unwrap())
53 }
54
55 pub fn to_float(&self) -> f64 {
56 match self {
57 Expression::Constant(c) => c.to_f64().unwrap(),
58 Expression::Variable(_) => panic!("Cannot convert variable to float"),
59 Expression::Pi => std::f64::consts::PI,
60 Expression::Neg(expr) => -expr.to_float(),
61 Expression::Add(lhs, rhs) => lhs.to_float() + rhs.to_float(),
62 Expression::Sub(lhs, rhs) => lhs.to_float() - rhs.to_float(),
63 Expression::Mul(lhs, rhs) => lhs.to_float() * rhs.to_float(),
64 Expression::Div(lhs, rhs) => lhs.to_float() / rhs.to_float(),
65 Expression::Pow(lhs, rhs) => lhs.to_float().powf(rhs.to_float()),
66 Expression::Sqrt(expr) => expr.to_float().sqrt(),
67 Expression::Sin(expr) => expr.to_float().sin(),
68 Expression::Cos(expr) => expr.to_float().cos(),
69 }
70 }
71
72 pub fn to_constant(&self) -> Constant {
73 Constant::from_float(self.to_float()).unwrap()
75 }
76
77 pub fn gather_context(&self) -> HashSet<String> {
78 let mut context = HashSet::new();
79 context.insert(self.to_string());
80 match self {
81 Expression::Pi => {
82 context.insert("pi".to_string());
83 }
84 Expression::Variable(var) => {
85 context.insert(var.clone());
86 }
87 Expression::Constant(_) => {
88 context.insert(self.to_string());
89 context.insert(self.to_float().to_string());
90 }
91 Expression::Neg(expr) => {
92 context.extend(expr.gather_context());
93 }
94 Expression::Add(lhs, rhs) => {
95 context.extend(lhs.gather_context());
96 context.extend(rhs.gather_context());
97 }
98 Expression::Sub(lhs, rhs) => {
99 context.extend(lhs.gather_context());
100 context.extend(rhs.gather_context());
101 }
102 Expression::Mul(lhs, rhs) => {
103 context.extend(lhs.gather_context());
104 context.extend(rhs.gather_context());
105 }
106 Expression::Div(lhs, rhs) => {
107 context.extend(lhs.gather_context());
108 context.extend(rhs.gather_context());
109 }
110 Expression::Pow(lhs, rhs) => {
111 context.extend(lhs.gather_context());
112 context.extend(rhs.gather_context());
113 }
114 Expression::Sqrt(expr) => {
115 context.extend(expr.gather_context());
116 }
117 Expression::Sin(expr) => {
118 context.extend(expr.gather_context());
119 }
120 Expression::Cos(expr) => {
121 context.extend(expr.gather_context());
122 }
123 }
124 context
125 }
126
127 pub fn is_zero(&self) -> bool {
128 match self {
129 Expression::Constant(c) => *c.numer() == BigInt::from(0),
130 Expression::Neg(expr) => expr.is_zero(),
131 Expression::Add(lhs, rhs) => lhs.is_zero() && rhs.is_zero(),
132 Expression::Sub(lhs, rhs) => (lhs.is_zero() && rhs.is_zero()) || lhs == rhs,
133 Expression::Mul(lhs, rhs) => lhs.is_zero() || rhs.is_zero(),
134 Expression::Div(lhs, _) => lhs.is_zero(),
135 Expression::Pow(lhs, rhs) => lhs.is_zero() && !rhs.is_zero(),
136 Expression::Sqrt(expr) => expr.is_zero(),
137 Expression::Sin(expr) => expr.is_zero(),
138 Expression::Cos(expr) => {
139 !expr.is_parameterized()
140 && (expr.eval::<f64>(&HashMap::new()) - std::f64::consts::PI / 2.0) < 1e-6
141 }
142 Expression::Pi => false,
143 Expression::Variable(_) => false,
144 }
145 }
146
147 pub fn is_zero_fast(&self) -> bool {
149 match self {
150 Expression::Constant(c) => *c.numer() == BigInt::from(0),
151 Expression::Neg(expr) => expr.is_zero_fast(),
152 Expression::Add(lhs, rhs) => lhs.is_zero_fast() && rhs.is_zero_fast(),
153 Expression::Sub(lhs, rhs) => lhs.is_zero_fast() && rhs.is_zero_fast(),
154 Expression::Mul(lhs, rhs) => lhs.is_zero_fast() || rhs.is_zero_fast(),
155 Expression::Div(lhs, _) => lhs.is_zero_fast(),
156 Expression::Pow(lhs, rhs) => lhs.is_zero_fast() && !rhs.is_zero_fast(),
157 Expression::Sqrt(expr) => expr.is_zero_fast(),
158 Expression::Sin(expr) => expr.is_zero_fast(),
159 Expression::Cos(_expr) => false,
160 Expression::Pi => false,
161 Expression::Variable(_) => false,
162 }
163 }
164
165 pub fn is_one(&self) -> bool {
166 match self {
167 Expression::Constant(c) => *c.numer() == *c.denom(),
168 Expression::Neg(expr) => {
169 !expr.is_parameterized() && expr.eval::<f64>(&HashMap::new()) == -1.0
170 }
171 Expression::Add(lhs, rhs) => {
172 lhs.is_one() && rhs.is_zero() || lhs.is_zero() && rhs.is_one()
173 }
174 Expression::Sub(lhs, rhs) => lhs.is_one() && rhs.is_zero(),
175 Expression::Mul(lhs, rhs) => lhs.is_one() && rhs.is_one(),
176 Expression::Div(lhs, rhs) => lhs == rhs && !rhs.is_zero(),
177 Expression::Pow(lhs, _rhs) => lhs.is_one(),
178 Expression::Sqrt(expr) => expr.is_one(),
179 Expression::Sin(expr) => {
180 !expr.is_parameterized()
181 && (expr.eval::<f64>(&HashMap::new()) - std::f64::consts::PI / 2.0) < 1e-6
182 }
183 Expression::Cos(expr) => expr.is_zero(),
184 Expression::Pi => false,
185 Expression::Variable(_) => false,
186 }
187 }
188
189 pub fn is_one_fast(&self) -> bool {
190 match self {
191 Expression::Constant(c) => *c.numer() == *c.denom(),
192 Expression::Neg(_expr) => false,
193 Expression::Add(lhs, rhs) => {
194 lhs.is_one_fast() && rhs.is_zero_fast() || lhs.is_zero_fast() && rhs.is_one_fast()
195 }
196 Expression::Sub(lhs, rhs) => lhs.is_one_fast() && rhs.is_zero_fast(),
197 Expression::Mul(lhs, rhs) => lhs.is_one_fast() && rhs.is_one_fast(),
198 Expression::Div(lhs, rhs) => lhs.is_one_fast() && rhs.is_one_fast(),
199 Expression::Pow(lhs, _rhs) => lhs.is_one_fast(),
200 Expression::Sqrt(expr) => expr.is_one_fast(),
201 Expression::Sin(_expr) => false,
202 Expression::Cos(expr) => expr.is_zero_fast(),
203 Expression::Pi => false,
204 Expression::Variable(_) => false,
205 }
206 }
207
208 pub fn contains_variable<T: AsRef<str>>(&self, var: T) -> bool {
209 let var = var.as_ref();
210 match self {
211 Expression::Pi => false,
212 Expression::Variable(v) => v == var,
213 Expression::Constant(_) => false,
214 Expression::Neg(expr) => expr.contains_variable(var),
215 Expression::Add(lhs, rhs) => lhs.contains_variable(var) || rhs.contains_variable(var),
216 Expression::Sub(lhs, rhs) => lhs.contains_variable(var) || rhs.contains_variable(var),
217 Expression::Mul(lhs, rhs) => lhs.contains_variable(var) || rhs.contains_variable(var),
218 Expression::Div(lhs, rhs) => lhs.contains_variable(var) || rhs.contains_variable(var),
219 Expression::Pow(lhs, rhs) => lhs.contains_variable(var) || rhs.contains_variable(var),
220 Expression::Sqrt(expr) => expr.contains_variable(var),
221 Expression::Sin(expr) => expr.contains_variable(var),
222 Expression::Cos(expr) => expr.contains_variable(var),
223 }
224 }
225
226 pub fn is_parameterized(&self) -> bool {
227 match self {
228 Expression::Pi => false,
229 Expression::Variable(_) => true,
230 Expression::Constant(_) => false,
231 Expression::Neg(expr) => expr.is_parameterized(),
232 Expression::Add(lhs, rhs) => lhs.is_parameterized() || rhs.is_parameterized(),
233 Expression::Sub(lhs, rhs) => lhs.is_parameterized() || rhs.is_parameterized(),
234 Expression::Mul(lhs, rhs) => lhs.is_parameterized() || rhs.is_parameterized(),
235 Expression::Div(lhs, rhs) => lhs.is_parameterized() || rhs.is_parameterized(),
236 Expression::Pow(lhs, rhs) => lhs.is_parameterized() || rhs.is_parameterized(),
237 Expression::Sqrt(expr) => expr.is_parameterized(),
238 Expression::Sin(expr) => expr.is_parameterized(),
239 Expression::Cos(expr) => expr.is_parameterized(),
240 }
241 }
242
243 pub fn eval<R: RealScalar>(&self, args: &HashMap<&str, R>) -> R {
244 match self {
245 Expression::Pi => R::PI(),
246 Expression::Variable(var) => {
247 if let Some(val) = args.get(var.as_str()) {
248 *val
249 } else {
250 panic!("Variable {} not found in arguments", var)
251 }
252 }
253 Expression::Constant(c) => R::from_ratio(c.clone()).unwrap(),
254 Expression::Neg(expr) => -expr.eval(args),
255 Expression::Add(lhs, rhs) => lhs.eval(args) + rhs.eval(args),
256 Expression::Sub(lhs, rhs) => lhs.eval(args) - rhs.eval(args),
257 Expression::Mul(lhs, rhs) => lhs.eval(args) * rhs.eval(args),
258 Expression::Div(lhs, rhs) => lhs.eval(args) / rhs.eval(args),
259 Expression::Pow(lhs, rhs) => lhs.eval(args).powf(rhs.eval(args)),
260 Expression::Sqrt(expr) => expr.eval(args).sqrt(),
261 Expression::Sin(expr) => expr.eval(args).sin(),
262 Expression::Cos(expr) => expr.eval(args).cos(),
263 }
264 }
265
266 pub fn hash_eval(&self) -> f64 {
268 let val = match self {
269 Expression::Pi => self.to_float(),
270 Expression::Variable(_) => 1.7,
271 Expression::Constant(_) => self.to_float(),
272 Expression::Neg(expr) => -expr.hash_eval(),
273 Expression::Add(lhs, rhs) => lhs.hash_eval() + rhs.hash_eval(),
274 Expression::Sub(lhs, rhs) => lhs.hash_eval() - rhs.hash_eval(),
275 Expression::Mul(lhs, rhs) => lhs.hash_eval() * rhs.hash_eval(),
276 Expression::Div(lhs, rhs) => lhs.hash_eval() / rhs.hash_eval(),
277 Expression::Pow(lhs, rhs) => lhs.hash_eval().powf(rhs.hash_eval()),
278 Expression::Sqrt(expr) => expr.hash_eval().sqrt(),
279 Expression::Sin(expr) => expr.hash_eval().sin(),
280 Expression::Cos(expr) => expr.hash_eval().cos(),
281 };
282
283 if val.is_nan() || val.is_subnormal() {
284 0.0
285 } else {
286 val
287 }
288 }
289
290 pub fn map_var_names(&self, var_map: &HashMap<String, String>) -> Self {
291 match self {
292 Expression::Pi => Expression::Pi,
293 Expression::Variable(var) => {
294 if let Some(new_var) = var_map.get(var.as_str()) {
295 Expression::Variable(new_var.to_string())
296 } else {
297 Expression::Variable(var.clone())
298 }
299 }
300 Expression::Constant(c) => Expression::Constant(c.clone()),
301 Expression::Neg(expr) => Expression::Neg(Box::new(expr.map_var_names(var_map))),
302 Expression::Add(lhs, rhs) => Expression::Add(
303 Box::new(lhs.map_var_names(var_map)),
304 Box::new(rhs.map_var_names(var_map)),
305 ),
306 Expression::Sub(lhs, rhs) => Expression::Sub(
307 Box::new(lhs.map_var_names(var_map)),
308 Box::new(rhs.map_var_names(var_map)),
309 ),
310 Expression::Mul(lhs, rhs) => Expression::Mul(
311 Box::new(lhs.map_var_names(var_map)),
312 Box::new(rhs.map_var_names(var_map)),
313 ),
314 Expression::Div(lhs, rhs) => Expression::Div(
315 Box::new(lhs.map_var_names(var_map)),
316 Box::new(rhs.map_var_names(var_map)),
317 ),
318 Expression::Pow(lhs, rhs) => Expression::Pow(
319 Box::new(lhs.map_var_names(var_map)),
320 Box::new(rhs.map_var_names(var_map)),
321 ),
322 Expression::Sqrt(expr) => Expression::Sqrt(Box::new(expr.map_var_names(var_map))),
323 Expression::Sin(expr) => Expression::Sin(Box::new(expr.map_var_names(var_map))),
324 Expression::Cos(expr) => Expression::Cos(Box::new(expr.map_var_names(var_map))),
325 }
326 }
327
328 pub fn rename_variable<S: AsRef<str>, T: AsRef<str>>(&self, original: S, new: T) -> Self {
329 let original = original.as_ref();
330 let new = new.as_ref();
331 match self {
332 Expression::Pi => Expression::Pi,
333 Expression::Variable(var) => {
334 if var == original {
335 Expression::Variable(new.to_string())
336 } else {
337 Expression::Variable(var.clone())
338 }
339 }
340 Expression::Constant(c) => Expression::Constant(c.clone()),
341 Expression::Neg(expr) => Expression::Neg(Box::new(expr.rename_variable(original, new))),
342 Expression::Add(lhs, rhs) => Expression::Add(
343 Box::new(lhs.rename_variable(original, new)),
344 Box::new(rhs.rename_variable(original, new)),
345 ),
346 Expression::Sub(lhs, rhs) => Expression::Sub(
347 Box::new(lhs.rename_variable(original, new)),
348 Box::new(rhs.rename_variable(original, new)),
349 ),
350 Expression::Mul(lhs, rhs) => Expression::Mul(
351 Box::new(lhs.rename_variable(original, new)),
352 Box::new(rhs.rename_variable(original, new)),
353 ),
354 Expression::Div(lhs, rhs) => Expression::Div(
355 Box::new(lhs.rename_variable(original, new)),
356 Box::new(rhs.rename_variable(original, new)),
357 ),
358 Expression::Pow(lhs, rhs) => Expression::Pow(
359 Box::new(lhs.rename_variable(original, new)),
360 Box::new(rhs.rename_variable(original, new)),
361 ),
362 Expression::Sqrt(expr) => {
363 Expression::Sqrt(Box::new(expr.rename_variable(original, new)))
364 }
365 Expression::Sin(expr) => Expression::Sin(Box::new(expr.rename_variable(original, new))),
366 Expression::Cos(expr) => Expression::Cos(Box::new(expr.rename_variable(original, new))),
367 }
368 }
369
370 pub fn differentiate<S: AsRef<str>>(&self, wrt: S) -> Self {
371 let wrt = wrt.as_ref();
372 match self {
373 Expression::Pi => Expression::zero(),
374 Expression::Variable(var) => {
375 if var == wrt {
376 Expression::one()
377 } else {
378 Expression::zero()
379 }
380 }
381 Expression::Constant(_) => Expression::zero(),
382 Expression::Neg(expr) => Expression::Neg(Box::new(expr.differentiate(wrt))),
383 Expression::Add(lhs, rhs) => Expression::Add(
384 Box::new(lhs.differentiate(wrt)),
385 Box::new(rhs.differentiate(wrt)),
386 ),
387 Expression::Sub(lhs, rhs) => Expression::Sub(
388 Box::new(lhs.differentiate(wrt)),
389 Box::new(rhs.differentiate(wrt)),
390 ),
391 Expression::Mul(lhs, rhs) => {
392 lhs.differentiate(wrt) * *rhs.clone() + *lhs.clone() * rhs.differentiate(wrt)
393 }
394 Expression::Div(lhs, rhs) => {
395 (lhs.differentiate(wrt) * *rhs.clone() - *lhs.clone() * rhs.differentiate(wrt))
396 / (*rhs.clone() * *rhs.clone())
397 }
398 Expression::Pow(lhs, rhs) => {
399 let base_fn_x = lhs.contains_variable(wrt);
400 let exponent_fn_x = rhs.contains_variable(wrt);
401
402 if !base_fn_x && !exponent_fn_x {
403 Expression::zero()
404 } else if !base_fn_x && exponent_fn_x {
405 if lhs.is_parameterized() {
406 todo!(
407 "Cannot differentiate with respect to a parameterized power base until ln is implemented"
408 )
409 } else {
410 self.clone()
411 * rhs.differentiate(wrt)
412 * Expression::from_float(lhs.eval::<f64>(&HashMap::new()).ln())
413 }
414 } else if base_fn_x && !exponent_fn_x {
415 *rhs.clone()
416 * Expression::Pow(
417 Box::new(*lhs.clone()),
418 Box::new(*rhs.clone() - Expression::one()),
419 )
420 * lhs.differentiate(wrt)
421 } else {
422 todo!(
423 "Cannot differentiate with respect to a parameterized base and exponent until ln is implemented"
424 )
425 }
426 }
427 Expression::Sqrt(expr) => {
428 let two = Expression::from_int(2);
429 (Expression::one() / (two * self.clone())) * expr.differentiate(wrt)
430 }
431 Expression::Sin(expr) => {
432 Expression::Cos(Box::new(*expr.clone())) * expr.differentiate(wrt)
433 }
434 Expression::Cos(expr) => {
435 Expression::Neg(Box::new(Expression::Sin(Box::new(*expr.clone()))))
436 * expr.differentiate(wrt)
437 }
438 }
439 }
440
441 pub fn get_ancestors<S: AsRef<str>>(&self, variable: S) -> Vec<Expression> {
442 let variable = variable.as_ref();
443 let mut ancestors = Vec::new();
444 match self {
445 Expression::Pi => {}
446 Expression::Variable(var) => {
447 if var == variable {
448 ancestors.push(self.clone());
449 }
450 }
451 Expression::Constant(_) => {}
452 Expression::Neg(expr) => {
453 let node_ancsestors = expr.get_ancestors(variable);
454 let is_empty = node_ancsestors.is_empty();
455 ancestors.extend(node_ancsestors);
456 if !is_empty {
457 ancestors.push(self.clone());
458 }
459 }
460 Expression::Add(lhs, rhs) => {
461 let lhs_ancestors = lhs.get_ancestors(variable);
462 let rhs_ancestors = rhs.get_ancestors(variable);
463 let is_empty = lhs_ancestors.is_empty() && rhs_ancestors.is_empty();
464 ancestors.extend(lhs_ancestors);
465 ancestors.extend(rhs_ancestors);
466 if !is_empty {
467 ancestors.push(self.clone());
468 }
469 }
470 Expression::Sub(lhs, rhs) => {
471 let lhs_ancestors = lhs.get_ancestors(variable);
472 let rhs_ancestors = rhs.get_ancestors(variable);
473 let is_empty = lhs_ancestors.is_empty() && rhs_ancestors.is_empty();
474 ancestors.extend(lhs_ancestors);
475 ancestors.extend(rhs_ancestors);
476 if !is_empty {
477 ancestors.push(self.clone());
478 }
479 }
480 Expression::Mul(lhs, rhs) => {
481 let lhs_ancestors = lhs.get_ancestors(variable);
482 let rhs_ancestors = rhs.get_ancestors(variable);
483 let is_empty = lhs_ancestors.is_empty() && rhs_ancestors.is_empty();
484 ancestors.extend(lhs_ancestors);
485 ancestors.extend(rhs_ancestors);
486 if !is_empty {
487 ancestors.push(self.clone());
488 }
489 }
490 Expression::Div(lhs, rhs) => {
491 let lhs_ancestors = lhs.get_ancestors(variable);
492 let rhs_ancestors = rhs.get_ancestors(variable);
493 let is_empty = lhs_ancestors.is_empty() && rhs_ancestors.is_empty();
494 ancestors.extend(lhs_ancestors);
495 ancestors.extend(rhs_ancestors);
496 if !is_empty {
497 ancestors.push(self.clone());
498 }
499 }
500 Expression::Pow(lhs, rhs) => {
501 let lhs_ancestors = lhs.get_ancestors(variable);
502 let rhs_ancestors = rhs.get_ancestors(variable);
503 let is_empty = lhs_ancestors.is_empty() && rhs_ancestors.is_empty();
504 ancestors.extend(lhs_ancestors);
505 ancestors.extend(rhs_ancestors);
506 if !is_empty {
507 ancestors.push(self.clone());
508 }
509 }
510 Expression::Sqrt(expr) => {
511 let node_ancsestors = expr.get_ancestors(variable);
512 let is_empty = node_ancsestors.is_empty();
513 ancestors.extend(node_ancsestors);
514 if !is_empty {
515 ancestors.push(self.clone());
516 }
517 }
518 Expression::Sin(expr) => {
519 let node_ancsestors = expr.get_ancestors(variable);
520 let is_empty = node_ancsestors.is_empty();
521 ancestors.extend(node_ancsestors);
522 if !is_empty {
523 ancestors.push(self.clone());
524 }
525 }
526 Expression::Cos(expr) => {
527 let node_ancsestors = expr.get_ancestors(variable);
528 let is_empty = node_ancsestors.is_empty();
529 ancestors.extend(node_ancsestors);
530 if !is_empty {
531 ancestors.push(self.clone());
532 }
533 }
534 }
535 ancestors
536 }
537
538 pub fn fast_eq(&self, other: &Expression) -> bool {
539 match (self, other) {
540 (Expression::Pi, Expression::Pi) => true,
541 (Expression::Variable(var1), Expression::Variable(var2)) => var1 == var2,
542 (Expression::Constant(c1), Expression::Constant(c2)) => c1 == c2,
543 (Expression::Neg(expr1), Expression::Neg(expr2)) => expr1.fast_eq(expr2),
544 (Expression::Add(lhs1, rhs1), Expression::Add(lhs2, rhs2)) => {
545 (lhs1.fast_eq(lhs2) && rhs1.fast_eq(rhs2))
546 || (lhs1.fast_eq(rhs2) && rhs1.fast_eq(lhs2))
547 }
548 (Expression::Sub(lhs1, rhs1), Expression::Sub(lhs2, rhs2)) => {
549 (lhs1.fast_eq(lhs2) && rhs1.fast_eq(rhs2))
550 || (lhs1.fast_eq(rhs2) && rhs1.fast_eq(lhs2))
551 }
552 (Expression::Mul(lhs1, rhs1), Expression::Mul(lhs2, rhs2)) => {
553 (lhs1.fast_eq(lhs2) && rhs1.fast_eq(rhs2))
554 || (lhs1.fast_eq(rhs2) && rhs1.fast_eq(lhs2))
555 }
556 (Expression::Div(lhs1, rhs1), Expression::Div(lhs2, rhs2)) => {
557 (lhs1.fast_eq(lhs2) && rhs1.fast_eq(rhs2))
558 || (lhs1.fast_eq(rhs2) && rhs1.fast_eq(lhs2))
559 }
560 (Expression::Pow(lhs1, rhs1), Expression::Pow(lhs2, rhs2)) => {
561 (lhs1.fast_eq(lhs2) && rhs1.fast_eq(rhs2))
562 || (lhs1.fast_eq(rhs2) && rhs1.fast_eq(lhs2))
563 }
564 (Expression::Sqrt(expr1), Expression::Sqrt(expr2)) => expr1.fast_eq(expr2),
565 (Expression::Sin(expr1), Expression::Sin(expr2)) => expr1.fast_eq(expr2),
566 (Expression::Cos(expr1), Expression::Cos(expr2)) => expr1.fast_eq(expr2),
567 _ => false,
568 }
569 }
570
571 pub fn substitute<S: AsRef<Expression>, T: AsRef<Expression>>(
572 &self,
573 original: S,
574 substitution: T,
575 ) -> Self {
576 let original = original.as_ref();
577 let substitution = substitution.as_ref();
578 if self.fast_eq(original) {
579 return substitution.clone();
580 }
581 match self {
582 Expression::Pi => self.clone(),
583 Expression::Variable(_) => self.clone(),
584 Expression::Constant(_) => self.clone(),
585 Expression::Neg(expr) => {
586 Expression::Neg(Box::new(expr.substitute(original, substitution)))
587 }
588 Expression::Add(lhs, rhs) => Expression::Add(
589 Box::new(lhs.substitute(original, substitution)),
590 Box::new(rhs.substitute(original, substitution)),
591 ),
592 Expression::Sub(lhs, rhs) => Expression::Sub(
593 Box::new(lhs.substitute(original, substitution)),
594 Box::new(rhs.substitute(original, substitution)),
595 ),
596 Expression::Mul(lhs, rhs) => Expression::Mul(
597 Box::new(lhs.substitute(original, substitution)),
598 Box::new(rhs.substitute(original, substitution)),
599 ),
600 Expression::Div(lhs, rhs) => Expression::Div(
601 Box::new(lhs.substitute(original, substitution)),
602 Box::new(rhs.substitute(original, substitution)),
603 ),
604 Expression::Pow(lhs, rhs) => Expression::Pow(
605 Box::new(lhs.substitute(original, substitution)),
606 Box::new(rhs.substitute(original, substitution)),
607 ),
608 Expression::Sqrt(expr) => {
609 Expression::Sqrt(Box::new(expr.substitute(original, substitution)))
610 }
611 Expression::Sin(expr) => {
612 Expression::Sin(Box::new(expr.substitute(original, substitution)))
613 }
614 Expression::Cos(expr) => {
615 Expression::Cos(Box::new(expr.substitute(original, substitution)))
616 }
617 }
618 }
619
620 pub fn simplify(&self) -> Self {
621 simplify(self)
622 }
623
624 pub fn get_unique_variables(&self) -> Vec<String> {
625 match self {
626 Expression::Pi => {
627 vec![]
628 }
629 Expression::Variable(s) => {
630 vec![s.clone()]
631 }
632 Expression::Constant(_) => {
633 vec![]
634 }
635 Expression::Neg(expr) => expr.get_unique_variables(),
636 Expression::Add(lhs, rhs) => {
637 let mut l = lhs.get_unique_variables();
638 for r in rhs.get_unique_variables().into_iter() {
639 if !l.contains(&r) {
640 l.push(r)
641 }
642 }
643 l
644 }
645 Expression::Sub(lhs, rhs) => {
646 let mut l = lhs.get_unique_variables();
647 for r in rhs.get_unique_variables().into_iter() {
648 if !l.contains(&r) {
649 l.push(r)
650 }
651 }
652 l
653 }
654 Expression::Mul(lhs, rhs) => {
655 let mut l = lhs.get_unique_variables();
656 for r in rhs.get_unique_variables().into_iter() {
657 if !l.contains(&r) {
658 l.push(r)
659 }
660 }
661 l
662 }
663 Expression::Div(lhs, rhs) => {
664 let mut l = lhs.get_unique_variables();
665 for r in rhs.get_unique_variables().into_iter() {
666 if !l.contains(&r) {
667 l.push(r)
668 }
669 }
670 l
671 }
672 Expression::Pow(lhs, rhs) => {
673 let mut l = lhs.get_unique_variables();
674 for r in rhs.get_unique_variables().into_iter() {
675 if !l.contains(&r) {
676 l.push(r)
677 }
678 }
679 l
680 }
681 Expression::Sqrt(expr) => expr.get_unique_variables(),
682 Expression::Sin(expr) => expr.get_unique_variables(),
683 Expression::Cos(expr) => expr.get_unique_variables(),
684 }
685 }
686}
687
688impl std::ops::Add<Expression> for Expression {
689 type Output = Self;
690
691 fn add(self, other: Self) -> Self {
692 &self + &other
693 }
694}
695
696impl std::ops::Add<&Expression> for Expression {
697 type Output = Expression;
698
699 fn add(self, other: &Expression) -> Expression {
700 &self + other
701 }
702}
703
704impl std::ops::Add<Expression> for &Expression {
705 type Output = Expression;
706
707 fn add(self, other: Expression) -> Expression {
708 self + &other
709 }
710}
711
712impl std::ops::Add<&Expression> for &Expression {
713 type Output = Expression;
714
715 fn add(self, other: &Expression) -> Expression {
716 if let Expression::Constant(c1) = self
717 && let Expression::Constant(c2) = other
718 {
719 return Expression::Constant(c1 + c2);
720 }
721 if other.is_zero_fast() {
722 self.clone()
723 } else if self.is_zero_fast() {
724 other.clone()
725 } else {
726 Expression::Add(Box::new(self.clone()), Box::new(other.clone()))
727 }
728 }
729}
730
731impl std::ops::Sub<Expression> for Expression {
732 type Output = Self;
733
734 fn sub(self, other: Self) -> Self {
735 &self - &other
736 }
737}
738
739impl std::ops::Sub<&Expression> for Expression {
740 type Output = Expression;
741
742 fn sub(self, other: &Expression) -> Expression {
743 &self - other
744 }
745}
746
747impl std::ops::Sub<Expression> for &Expression {
748 type Output = Expression;
749
750 fn sub(self, other: Expression) -> Expression {
751 self - &other
752 }
753}
754
755impl std::ops::Sub<&Expression> for &Expression {
756 type Output = Expression;
757
758 fn sub(self, other: &Expression) -> Expression {
759 if let Expression::Constant(c1) = self
760 && let Expression::Constant(c2) = other
761 {
762 return Expression::Constant(c1 - c2);
763 }
764 if other.is_zero_fast() {
765 self.clone()
766 } else if self.is_zero_fast() {
767 -other.clone()
768 } else {
769 Expression::Sub(Box::new(self.clone()), Box::new(other.clone()))
770 }
771 }
772}
773
774impl std::ops::Mul<Expression> for Expression {
775 type Output = Self;
776
777 fn mul(self, other: Self) -> Self {
778 &self * &other
779 }
780}
781
782impl std::ops::Mul<&Expression> for Expression {
783 type Output = Expression;
784
785 fn mul(self, other: &Expression) -> Expression {
786 &self * other
787 }
788}
789
790impl std::ops::Mul<Expression> for &Expression {
791 type Output = Expression;
792
793 fn mul(self, other: Expression) -> Expression {
794 self * &other
795 }
796}
797
798impl std::ops::Mul<&Expression> for &Expression {
799 type Output = Expression;
800
801 fn mul(self, other: &Expression) -> Expression {
802 if let Expression::Constant(c1) = self
803 && let Expression::Constant(c2) = other
804 {
805 return Expression::Constant(c1 * c2);
806 }
807 if other.is_zero_fast() || self.is_zero_fast() {
808 Expression::zero()
809 } else if other.is_one_fast() {
810 self.clone()
811 } else if self.is_one_fast() {
812 other.clone()
813 } else {
814 Expression::Mul(Box::new(self.clone()), Box::new(other.clone()))
815 }
816 }
817}
818
819impl std::ops::Div<Expression> for Expression {
820 type Output = Self;
821
822 fn div(self, other: Self) -> Self {
823 &self / &other
824 }
825}
826
827impl std::ops::Div<&Expression> for Expression {
828 type Output = Expression;
829
830 fn div(self, other: &Expression) -> Expression {
831 &self / other
832 }
833}
834
835impl std::ops::Div<Expression> for &Expression {
836 type Output = Expression;
837
838 fn div(self, other: Expression) -> Expression {
839 self / &other
840 }
841}
842
843impl std::ops::Div<&Expression> for &Expression {
844 type Output = Expression;
845
846 fn div(self, other: &Expression) -> Expression {
847 if other.is_zero_fast() {
848 panic!("Cannot divide by zero")
849 } else if let (Expression::Constant(c1), Expression::Constant(c2)) = (self, other) {
850 Expression::Constant(c1 / c2)
851 } else if self.is_zero_fast() {
852 Expression::zero()
853 } else {
854 Expression::Div(Box::new(self.clone()), Box::new(other.clone()))
855 }
856 }
857}
858
859impl std::ops::Neg for Expression {
860 type Output = Self;
861
862 fn neg(self) -> Self {
863 -&self
864 }
865}
866
867impl std::ops::Neg for &Expression {
868 type Output = Expression;
869
870 fn neg(self) -> Expression {
871 if self.is_zero_fast() {
872 self.clone()
873 } else {
874 Expression::Neg(Box::new(self.clone()))
875 }
876 }
877}
878
879impl std::fmt::Debug for Expression {
880 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
881 write!(f, "{}", self)
882 }
883}
884
885impl PartialEq for Expression {
886 fn eq(&self, other: &Self) -> bool {
887 self.fast_eq(other)
888 }
893}
894
895impl Eq for Expression {}
896
897impl std::hash::Hash for Expression {
898 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
899 let val = self.hash_eval();
900 (val * 1e5_f64).round().to_bits().hash(state);
901 }
902}
903
904impl AsRef<Expression> for Expression {
905 fn as_ref(&self) -> &Expression {
906 self
907 }
908}
909
910impl std::fmt::Display for Expression {
911 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
912 let inner = match self {
913 Expression::Pi => "pi".to_string(),
914 Expression::Variable(var) => var.clone(),
915 Expression::Constant(_c) => self.to_float().to_string(),
916 Expression::Neg(expr) => format!("~ {}", expr),
917 Expression::Add(lhs, rhs) => format!("+ {} {}", lhs, rhs),
918 Expression::Sub(lhs, rhs) => format!("- {} {}", lhs, rhs),
919 Expression::Mul(lhs, rhs) => format!("* {} {}", lhs, rhs),
920 Expression::Div(lhs, rhs) => format!("/ {} {}", lhs, rhs),
921 Expression::Pow(lhs, rhs) => format!("pow {} {}", lhs, rhs),
922 Expression::Sqrt(expr) => format!("sqrt {}", expr),
923 Expression::Sin(expr) => format!("sin {}", expr),
924 Expression::Cos(expr) => format!("cos {}", expr),
925 };
926 write!(f, "({})", inner)
927 }
928}
929
930impl<R: RealScalar> From<R> for Expression {
931 fn from(value: R) -> Self {
932 Expression::from_float(value.to64())
933 }
934}