1use egg::*;
2use num::Signed;
3use num::Zero;
4use ordered_float::NotNan;
5
6use crate::ComplexExpression;
7use crate::Expression;
8
/// E-graph over [`TrigLanguage`] terms that uses [`ConstantFold`] as its analysis.
pub type EGraph = egg::EGraph<TrigLanguage, ConstantFold>;
/// Numeric literal stored in the language: an `f64` wrapped in `NotNan`.
pub type Constant = NotNan<f64>;
14
// Term language of the e-graph: arithmetic, `pow`/`sqrt`, `sin`/`cos`,
// the constant `pi`, numeric literals and free symbols.
define_language! {
    pub enum TrigLanguage {
        // Leaf written `pi`.
        "pi" = Pi,

        // Unary negation is written `~`; binary subtraction is the separate `-` operator.
        "~" = Neg([Id; 1]),
        "+" = Add([Id; 2]),
        "-" = Sub([Id; 2]),
        "*" = Mul([Id; 2]),
        "/" = Div([Id; 2]),

        // `pow` takes two children (written `(pow ?a ?b)` in the rules); the others take one.
        "pow" = Pow([Id; 2]),
        "sqrt" = Sqrt(Id),
        "sin" = Sin(Id),
        "cos" = Cos(Id),

        // Leaves without a fixed operator string: numeric literals and symbols.
        Constant(Constant),
        Symbol(Symbol),
    }
}
34
35fn is_not_zero(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
36 let var = var.parse().unwrap();
37 move |egraph, _, subst| {
38 if let Some(n) = &egraph[subst[var]].data {
39 *(n.0) != 0.0
40 } else {
41 false
42 }
43 }
44}
45
46fn is_non_negative_conservative(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
47 let var = var.parse().unwrap();
48 move |egraph, _, subst| {
49 if let Some(n) = &egraph[subst[var]].data {
50 *(n.0) >= 0.0
51 } else {
52 false
53 }
54 }
55}
56
57#[allow(dead_code)]
58fn all_not_zero(vars: &[&str]) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
59 let vars: Vec<_> = vars.iter().map(|v| v.parse().unwrap()).collect();
60 move |egraph, _, subst| {
61 vars.iter().all(|&v| {
62 if let Some(n) = &egraph[subst[v]].data {
63 *(n.0) != 0.0
64 } else {
65 true
66 }
67 })
68 }
69}
70
/// Orders two optional values so that any `Some` sorts before `None`.
///
/// Two `None`s are equal; two `Some`s are compared by their contents.
///
/// # Panics
/// Panics if both values are `Some` but are not comparable
/// (`partial_cmp` returns `None`, e.g. a float NaN).
fn cmp<T: PartialOrd>(a: &Option<T>, b: &Option<T>) -> Ordering {
    match (a.as_ref(), b.as_ref()) {
        (Some(lhs), Some(rhs)) => lhs.partial_cmp(rhs).unwrap(),
        (Some(_), None) => Ordering::Less,
        (None, Some(_)) => Ordering::Greater,
        (None, None) => Ordering::Equal,
    }
}
79
80use core::f64;
81use std::cmp::Ordering;
82use std::collections::HashMap;
83
84#[allow(dead_code)]
85struct SineExtractor<'a> {
86 costs: HashMap<Id, (f64, TrigLanguage)>,
87 egraph: &'a EGraph,
88}
89
90impl<'a> SineExtractor<'a> {
91 #![allow(clippy::type_complexity)]
92 pub fn new(egraph: &'a EGraph) -> Self {
93 let costs = HashMap::default();
94 let mut extractor = SineExtractor { costs, egraph };
95 extractor.calculate_costs();
96 extractor
97 }
98
99 pub fn extract_sine(&self, eclass: Id) -> Option<RecExpr<TrigLanguage>> {
100 let id = self.egraph.find(eclass);
101 let eclass = &self.egraph[id];
102
103 let mut best_cost = None;
105 let mut best_node = None;
106
107 for node in &eclass.nodes {
108 if let TrigLanguage::Sin(id) = node {
109 let cost = self.costs[id].0;
110 if best_cost.is_none() || cost < best_cost.unwrap() {
111 best_cost = Some(cost);
112 best_node = Some(node.clone());
113 }
114 }
115 }
116
117 if let Some(node) = best_node {
118 let expr = node.build_recexpr(|id| self.find_best_node(id).clone());
119 Some(expr)
120 } else {
121 None
122 }
123 }
124
125 pub fn find_best_node(&self, eclass: Id) -> &TrigLanguage {
126 &self.costs[&self.egraph.find(eclass)].1
127 }
128
129 fn calculate_costs(&mut self) {
130 let mut did_something = true;
131 while did_something {
132 did_something = false;
133
134 for class in self.egraph.classes() {
135 let pass = self.make_pass(class);
136 match (self.costs.get(&class.id), pass) {
137 (None, Some(new)) => {
138 self.costs.insert(class.id, new);
139 did_something = true;
140 }
141 (Some(old), Some(new)) if new.0 < old.0 => {
142 self.costs.insert(class.id, new);
143 did_something = true;
144 }
145 _ => {}
146 }
147 }
148 }
149
150 for class in self.egraph.classes() {
151 if !self.costs.contains_key(&class.id) {
152 println!("failed to calculate cost for {:?}", class);
153 }
154 }
155 }
156
157 fn make_pass(
158 &mut self,
159 eclass: &EClass<TrigLanguage, Option<(NotNan<f64>, RecExpr<ENodeOrVar<TrigLanguage>>)>>,
160 ) -> Option<(f64, TrigLanguage)> {
161 let (cost, node) = eclass
162 .iter()
163 .map(|n| (self.node_total_cost(n), n))
164 .min_by(|a, b| cmp(&a.0, &b.0))
165 .unwrap_or_else(|| panic!("eclass is empty"));
166 cost.map(|c| (c, node.clone()))
167 }
168
169 fn node_total_cost(&mut self, enode: &TrigLanguage) -> Option<f64> {
170 let eg = &self.egraph;
171 let has_cost = |id| self.costs.contains_key(&eg.find(id));
172 if enode.all(has_cost) {
175 Some(self.node_cost(enode))
176 } else {
177 None
178 }
179 }
180
181 fn node_cost(&self, enode: &TrigLanguage) -> f64 {
182 let op_cost = match enode {
183 TrigLanguage::Constant(_) => 0.5,
184 TrigLanguage::Neg(_) => 1.0,
185 TrigLanguage::Add(_) | TrigLanguage::Sub(_) => 1.0,
186 TrigLanguage::Mul(_) | TrigLanguage::Div(_) => 5.0,
187 TrigLanguage::Pow(_)
188 | TrigLanguage::Sqrt(_)
189 | TrigLanguage::Sin(_)
190 | TrigLanguage::Cos(_) => 50.0,
191 _ => 0.0,
192 };
193 enode.fold(op_cost, |acc, id| acc + self.costs[&id].0)
194 }
195}
196
197struct TrigExprExtractor<'a> {
257 costs: Vec<(f64, TrigLanguage)>,
258 egraph: &'a EGraph,
259 has_changed: bool,
260}
261
262impl<'a> TrigExprExtractor<'a> {
263 #![allow(clippy::type_complexity)]
264 pub fn new(egraph: &'a EGraph) -> Self {
265 let mut max_id = 0usize;
267 for class in egraph.classes() {
268 let id = unsafe { std::mem::transmute::<Id, u32>(class.id) } as usize;
269 if id > max_id {
270 max_id = id
271 }
272 }
273 let costs = vec![(-1.0, TrigLanguage::Pi); max_id + 1];
274 let mut extractor = TrigExprExtractor {
275 costs,
276 egraph,
277 has_changed: false,
278 };
279 extractor.calculate_costs();
280 extractor
281 }
282
283 pub fn extract_best(&mut self, eclass: Id) -> RecExpr<TrigLanguage> {
284 let root = self.get_cost(self.egraph.find(eclass)).1.clone();
285 let expr = root.build_recexpr(|id| self.extract_best_node(id));
286 if self.has_changed {
292 self.recalculate_costs();
293 self.has_changed = false;
294 }
295 expr
296 }
297
298 pub fn extract_best_node(&mut self, eclass: Id) -> TrigLanguage {
299 let id = &self.egraph.find(eclass);
300 let (cost, enode) = self.get_cost(*id).clone();
301 if cost != 0.0 {
302 self.put_cost(*id, (cost, enode.clone()));
303 self.has_changed = true;
304 }
305 enode
306 }
307
308 #[allow(dead_code)]
309 pub fn find_best_node(&self, eclass: Id) -> &TrigLanguage {
310 &self.get_cost(self.egraph.find(eclass)).1
311 }
312
313 fn recalculate_costs(&mut self) {
314 let mut did_something = true;
315 while did_something {
316 did_something = false;
317
318 for class in self.egraph.classes() {
319 let pass = self.make_repass(class);
320 let old = self.get_cost(class.id);
321 match (old, pass) {
322 (old, new) if old.0 < 0.0 || new < old.0 => {
323 self.put_cost(class.id, (new, old.1.clone()));
324 did_something = true;
325 }
326 _ => {}
327 }
328 }
329 }
330 }
331
332 fn calculate_costs(&mut self) {
333 let mut did_something = true;
334 while did_something {
335 did_something = false;
336
337 for class in self.egraph.classes() {
338 let pass = self.make_pass(class);
339 if let (old, Some(new)) = (self.get_cost(class.id), pass) {
340 if old.0 < 0.0 || (new.0 > 0.0 && new.0 < old.0) {
341 self.put_cost(class.id, new);
342 did_something = true;
343 }
344 }
345 }
346 }
347
348 for class in self.egraph.classes() {
349 if self.get_cost(class.id).0 < 0.0 {
350 println!("failed to calculate cost for {:?}", class);
351 }
352 }
353 }
354
355 fn make_pass(
356 &mut self,
357 eclass: &EClass<TrigLanguage, Option<(NotNan<f64>, RecExpr<ENodeOrVar<TrigLanguage>>)>>,
358 ) -> Option<(f64, TrigLanguage)> {
359 let (cost, node) = eclass
360 .iter()
361 .map(|n| (self.node_total_cost(n), n))
362 .min_by(|a, b| cmp(&a.0, &b.0))
363 .unwrap_or_else(|| panic!("eclass is empty"));
364 cost.map(|c| (c, node.clone()))
365 }
366
367 fn make_repass(
368 &mut self,
369 eclass: &EClass<TrigLanguage, Option<(NotNan<f64>, RecExpr<ENodeOrVar<TrigLanguage>>)>>,
370 ) -> f64 {
371 eclass
372 .iter()
373 .map(|n| self.node_cost(n))
374 .min_by(|a, b| match a < b {
375 true => Ordering::Less,
376 false => Ordering::Greater,
377 })
378 .unwrap()
379 }
380
381 fn node_total_cost(&mut self, enode: &TrigLanguage) -> Option<f64> {
382 let eg = &self.egraph;
383 let has_cost = |id| self.get_cost(eg.find(id)).0 >= 0.0;
384 if enode.all(has_cost) {
385 Some(self.node_cost(enode))
386 } else {
387 None
388 }
389 }
390
391 fn node_cost(&self, enode: &TrigLanguage) -> f64 {
392 let op_cost = match enode {
393 TrigLanguage::Constant(_) => 0.5,
394 TrigLanguage::Neg(_) => 1.0,
395 TrigLanguage::Add(_) | TrigLanguage::Sub(_) => 1.0,
396 TrigLanguage::Mul(_) | TrigLanguage::Div(_) => 5.0,
397 TrigLanguage::Sqrt(_) | TrigLanguage::Sin(_) | TrigLanguage::Cos(_) => 50.0,
398 TrigLanguage::Pow(_) => 100.0,
399 _ => 0.0,
400 };
401 enode.fold(op_cost, |acc, id| acc + self.get_cost(id).0)
402 }
403
404 #[inline(always)]
405 fn get_cost(&self, id: Id) -> &(f64, TrigLanguage) {
406 &self.costs[unsafe { std::mem::transmute::<Id, u32>(id) } as usize]
408 }
409
410 #[inline(always)]
411 fn put_cost(&mut self, id: Id, cost: (f64, TrigLanguage)) {
412 self.costs[unsafe { std::mem::transmute::<Id, u32>(id) } as usize] = cost;
413 }
414}
415
416struct TrigCostFn;
417impl CostFunction<TrigLanguage> for TrigCostFn {
418 type Cost = f64;
419 fn cost<C>(&mut self, enode: &TrigLanguage, mut costs: C) -> Self::Cost
420 where
421 C: FnMut(Id) -> Self::Cost,
422 {
423 let op_cost = match enode {
424 TrigLanguage::Constant(_) => 0.5,
425 TrigLanguage::Neg(_) => 1.0,
426 TrigLanguage::Add(_) | TrigLanguage::Sub(_) => 1.0,
427 TrigLanguage::Mul(_) | TrigLanguage::Div(_) => 5.0,
428 TrigLanguage::Pow(_)
429 | TrigLanguage::Sqrt(_)
430 | TrigLanguage::Sin(_)
431 | TrigLanguage::Cos(_) => 50.0,
432 _ => 0.0,
433 };
434
435 enode.fold(op_cost, |acc, id| acc + costs(id))
436 }
437}
438
439pub fn can_multiply(a: NotNan<f64>, b: NotNan<f64>) -> bool {
440 if !a.is_zero() && a.is_positive() && a < NotNan::new(1e-15).unwrap() {
441 return false;
442 }
443 if !b.is_zero() && b.is_positive() && b < NotNan::new(1e-15).unwrap() {
444 return false;
445 }
446 if !a.is_zero() && a.is_negative() && a > NotNan::new(-1e-15).unwrap() {
447 return false;
448 }
449 if !b.is_zero() && b.is_negative() && b > NotNan::new(-1e-15).unwrap() {
450 return false;
451 }
452 if a > NotNan::new(1e15).unwrap() || b > NotNan::new(1e15).unwrap() {
453 return false;
454 }
455 if a < NotNan::new(-1e15).unwrap() || b < NotNan::new(-1e15).unwrap() {
456 return false;
457 }
458 a.is_finite() && b.is_finite() && !a.is_subnormal() && !b.is_subnormal()
459}
460
461pub fn can_divide(a: NotNan<f64>, b: NotNan<f64>) -> bool {
462 if !a.is_zero() && a.is_positive() && a < NotNan::new(1e-15).unwrap() {
463 return false;
464 }
465 if !b.is_zero() && b.is_positive() && b < NotNan::new(1e-15).unwrap() {
466 return false;
467 }
468 if !a.is_zero() && a.is_negative() && a > NotNan::new(-1e-15).unwrap() {
469 return false;
470 }
471 if !b.is_zero() && b.is_negative() && b > NotNan::new(-1e-15).unwrap() {
472 return false;
473 }
474 if a > NotNan::new(1e15).unwrap() || b > NotNan::new(1e15).unwrap() {
475 return false;
476 }
477 if a < NotNan::new(-1e15).unwrap() || b < NotNan::new(-1e15).unwrap() {
478 return false;
479 }
480 a.is_finite() && b.is_finite() && !b.is_zero() && !b.is_subnormal() && !a.is_subnormal()
481}
482
483#[derive(Default)]
484pub struct ConstantFold;
485impl Analysis<TrigLanguage> for ConstantFold {
486 type Data = Option<(Constant, PatternAst<TrigLanguage>)>;
487
488 fn make(egraph: &mut EGraph, enode: &TrigLanguage) -> Self::Data {
489 let x = |i: &Id| egraph[*i].data.as_ref().map(|d| d.0);
490 Some(match enode {
491 TrigLanguage::Constant(c) => (*c, format!("{}", c).parse().unwrap()),
493 TrigLanguage::Add([a, b]) if can_multiply(x(a)?, x(b)?) => (
494 x(a)? + x(b)?,
495 format!("(+ {} {})", x(a)?, x(b)?).parse().unwrap(),
496 ),
497 TrigLanguage::Sub([a, b]) if can_multiply(x(a)?, x(b)?) => (
498 x(a)? - x(b)?,
499 format!("(- {} {})", x(a)?, x(b)?).parse().unwrap(),
500 ),
501 TrigLanguage::Mul([a, b]) if can_multiply(x(a)?, x(b)?) => (
502 x(a)? * x(b)?,
503 format!("(* {} {})", x(a)?, x(b)?).parse().unwrap(),
504 ),
505 TrigLanguage::Div([a, b]) if can_divide(x(a)?, x(b)?) => (
506 x(a)? / x(b)?,
507 format!("(/ {} {})", x(a)?, x(b)?).parse().unwrap(),
508 ),
509 TrigLanguage::Neg([a]) => (-x(a)?, format!("(~ {})", x(a)?).parse().unwrap()),
510 _ => return None,
531 })
532 }
533
534 fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge {
545 merge_option(to, from, |a, b| {
546 assert_eq!(a.0, b.0, "Merged non-equal constants");
547 DidMerge(false, false)
548 })
549 }
550
551 fn modify(egraph: &mut EGraph, id: Id) {
552 let data = egraph[id].data.clone();
553 if let Some((c, pat)) = data {
554 if egraph.are_explanations_enabled() {
555 egraph.union_instantiations(
556 &pat,
557 &format!("{}", c).parse().unwrap(),
558 &Default::default(),
559 "constant_fold".to_string(),
560 );
561 } else {
562 let added = egraph.add(TrigLanguage::Constant(c));
563 egraph.union(id, added);
564 }
565 egraph[id].nodes.retain(|n| n.is_leaf());
567
568 #[cfg(debug_assertions)]
569 egraph[id].assert_unique_leaves();
570 }
571 }
572}
573
/// Builds the list of rewrite rules passed to the runner by the `simplify*`
/// functions in this file.
///
/// Each rule has a name, a left-hand pattern and a right-hand pattern; a few
/// carry a side condition (`is_not_zero`, `is_non_negative_conservative`) on a
/// pattern variable. The rules are returned in the order listed here.
fn make_rules() -> Vec<Rewrite<TrigLanguage, ConstantFold>> {
    vec![
        // Commutativity.
        rewrite!("+-commutative"; "(+ ?a ?b)" => "(+ ?b ?a)"),
        rewrite!("*-commutative"; "(* ?a ?b)" => "(* ?b ?a)"),
        // Associativity of +, -, * and /.
        rewrite!("associate-+r+"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
        rewrite!("associate-+l+"; "(+ (+ ?a ?b) ?c)" => "(+ ?a (+ ?b ?c))"),
        rewrite!("associate-+r-"; "(+ ?a (- ?b ?c))" => "(- (+ ?a ?b) ?c)"),
        rewrite!("associate-+l-"; "(+ (- ?a ?b) ?c)" => "(- ?a (- ?b ?c))"),
        rewrite!("associate--r+"; "(- ?a (+ ?b ?c))" => "(- (- ?a ?b) ?c)"),
        rewrite!("associate--l+"; "(- (+ ?a ?b) ?c)" => "(+ ?a (- ?b ?c))"),
        rewrite!("associate--l-"; "(- (- ?a ?b) ?c)" => "(- ?a (+ ?b ?c))"),
        rewrite!("associate--r-"; "(- ?a (- ?b ?c))" => "(+ (- ?a ?b) ?c)"),
        rewrite!("associate-*r*"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),
        rewrite!("associate-*l*"; "(* (* ?a ?b) ?c)" => "(* ?a (* ?b ?c))"),
        rewrite!("associate-*r/"; "(* ?a (/ ?b ?c))" => "(/ (* ?a ?b) ?c)"),
        rewrite!("associate-*l/"; "(* (/ ?a ?b) ?c)" => "(/ (* ?a ?c) ?b)"),
        rewrite!("associate-/r*"; "(/ ?a (* ?b ?c))" => "(/ (/ ?a ?b) ?c)"),
        rewrite!("associate-/r/"; "(/ ?a (/ ?b ?c))" => "(* (/ ?a ?b) ?c)"),
        rewrite!("associate-/l/"; "(/ (/ ?b ?c) ?a)" => "(/ ?b (* ?a ?c))"),
        rewrite!("associate-/l*"; "(/ (* ?b ?c) ?a)" => "(* ?b (/ ?c ?a))"),
        // x + x = 2x.
        rewrite!("count-2"; "(+ ?x ?x)" => "(* 2 ?x)"),
        // Distributivity of * over + and -.
        rewrite!("distribute-lft-in"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
        rewrite!("distribute-rgt-in"; "(* ?a (+ ?b ?c))" => "(+ (* ?b ?a) (* ?c ?a))"),
        rewrite!("distribute-lft-out"; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"),
        rewrite!("distribute-lft-out--"; "(- (* ?a ?b) (* ?a ?c))" => "(* ?a (- ?b ?c))"),
        rewrite!("distribute-rgt-out"; "(+ (* ?b ?a) (* ?c ?a))" => "(* ?a (+ ?b ?c))"),
        rewrite!("distribute-rgt-out--"; "(- (* ?b ?a) (* ?c ?a))" => "(* ?a (- ?b ?c))"),
        rewrite!("distribute-lft1-in"; "(+ (* ?b ?a) ?a)" => "(* (+ ?b 1) ?a)"),
        rewrite!("distribute-rgt1-in"; "(+ ?a (* ?c ?a))" => "(* (+ ?c 1) ?a)"),
        // Moving negation in and out of products, sums and fractions.
        rewrite!("distribute-lft-neg-in"; "(~ (* ?a ?b))" => "(* (~ ?a) ?b)"),
        rewrite!("distribute-rgt-neg-in"; "(~ (* ?a ?b))" => "(* ?a (~ ?b))"),
        rewrite!("distribute-lft-neg-out"; "(* (~ ?a) ?b)" => "(~ (* ?a ?b))"),
        rewrite!("distribute-rgt-neg-out"; "(* ?a (~ ?b))" => "(~ (* ?a ?b))"),
        rewrite!("distribute-neg-in"; "(~ (+ ?a ?b))" => "(+ (~ ?a) (~ ?b))"),
        rewrite!("distribute-neg-out"; "(+ (~ ?a) (~ ?b))" => "(~ (+ ?a ?b))"),
        rewrite!("distribute-frac-neg"; "(/ (~ ?a) ?b)" => "(~ (/ ?a ?b))"),
        rewrite!("distribute-frac-neg2"; "(/ ?a (~ ?b))" => "(~ (/ ?a ?b))"),
        rewrite!("distribute-neg-frac"; "(~ (/ ?a ?b))" => "(/ (~ ?a) ?b)"),
        rewrite!("distribute-neg-frac2"; "(~ (/ ?a ?b))" => "(/ ?a (~ ?b))"),
        // Subtracting a product versus adding its negation.
        rewrite!("cancel-sign-sub"; "(- ?a (* (~ ?b) ?c))" => "(+ ?a (* ?b ?c))"),
        rewrite!("cancel-sign-sub-inv"; "(- ?a (* ?b ?c))" => "(+ ?a (* (~ ?b) ?c))"),
        // Squares and difference of squares.
        rewrite!("swap-sqr"; "(* (* ?a ?b) (* ?a ?b))" => "(* (* ?a ?a) (* ?b ?b))"),
        rewrite!("unswap-sqr"; "(* (* ?a ?a) (* ?b ?b))" => "(* (* ?a ?b) (* ?a ?b))"),
        rewrite!("difference-of-squares"; "(- (* ?a ?a) (* ?b ?b))" => "(* (+ ?a ?b) (- ?a ?b))"),
        rewrite!("difference-of-sqr-1"; "(- (* ?a ?a) 1)" => "(* (+ ?a 1) (- ?a 1))"),
        rewrite!("difference-of-sqr--1"; "(+ (* ?a ?a) -1)" => "(* (+ ?a 1) (- ?a 1))"),
        rewrite!("pow-sqr"; "(* (pow ?a ?b) (pow ?a ?b))" => "(pow ?a (* 2 ?b))"),
        // Inverses and zero; rules that divide are guarded by a known non-zero constant
        // where a condition is given.
        rewrite!("remove-double-div"; "(/ 1 (/ 1 ?a))" => "?a"),
        rewrite!("rgt-mult-inverse"; "(* ?a (/ 1 ?a))" => "1" if is_not_zero("?a")),
        rewrite!("lft-mult-inverse"; "(* (/ 1 ?a) ?a)" => "1" if is_not_zero("?a")),
        rewrite!("+-inverses"; "(- ?a ?a)" => "0"),
        rewrite!("div0"; "(/ 0 ?a)" => "0" if is_not_zero("?a")),
        rewrite!("mul0-lft"; "(* 0 ?a)" => "0"),
        rewrite!("mul0-rgt"; "(* ?a 0)" => "0"),
        rewrite!("*-inverses"; "(/ ?a ?a)" => "1" if is_not_zero("?a")),
        // Additive and multiplicative identities.
        rewrite!("+-lft-identity"; "(+ 0 ?a)" => "?a"),
        rewrite!("+-rgt-identity"; "(+ ?a 0)" => "?a"),
        rewrite!("--rgt-identity"; "(- ?a 0)" => "?a"),
        rewrite!("sub0-neg"; "(- 0 ?a)" => "(~ ?a)"),
        rewrite!("remove-double-neg"; "(~ (~ ?a))" => "?a"),
        rewrite!("*-lft-identity"; "(* 1 ?a)" => "?a"),
        rewrite!("*-rgt-identity"; "(* ?a 1)" => "?a"),
        rewrite!("/-rgt-identity"; "(/ ?a 1)" => "?a"),
        rewrite!("mul-1-neg"; "(* -1 ?a)" => "(~ ?a)"),
        // Conversions between subtraction, negation and multiplication by -1.
        rewrite!("sub-neg"; "(- ?a ?b)" => "(+ ?a (~ ?b))"),
        rewrite!("unsub-neg"; "(+ ?a (~ ?b))" => "(- ?a ?b)"),
        rewrite!("neg-sub0"; "(~ ?b)" => "(- 0 ?b)"),
        rewrite!("neg-mul-1"; "(~ ?a)" => "(* -1 ?a)"),
        // Division as multiplication by a reciprocal.
        rewrite!("div-inv"; "(/ ?a ?b)" => "(* ?a (/ 1 ?b))"),
        rewrite!("un-div-inv"; "(* ?a (/ 1 ?b))" => "(/ ?a ?b)"),
        rewrite!("*-un-lft-identity"; "?a" => "(* 1 ?a)"),
        // Factoring sums and differences of cubes.
        rewrite!("sum-cubes"; "(+ (pow ?a 3) (pow ?b 3))" => "(* (+ (* ?a ?a) (- (* ?b ?b) (* ?a ?b))) (+ ?a ?b))"),
        rewrite!("difference-cubes"; "(- (pow ?a 3) (pow ?b 3))" => "(* (+ (* ?a ?a) (+ (* ?b ?b) (* ?a ?b))) (- ?a ?b))"),
        // Fraction arithmetic.
        rewrite!("div-sub"; "(/ (- ?a ?b) ?c)" => "(- (/ ?a ?c) (/ ?b ?c))"),
        rewrite!("times-frac"; "(/ (* ?a ?b) (* ?c ?d))" => "(* (/ ?a ?c) (/ ?b ?d))"),
        rewrite!("sub-div"; "(- (/ ?a ?c) (/ ?b ?c))" => "(/ (- ?a ?b) ?c)"),
        rewrite!("frac-add"; "(+ (/ ?a ?b) (/ ?c ?d))" => "(/ (+ (* ?a ?d) (* ?b ?c)) (* ?b ?d))"),
        rewrite!("frac-sub"; "(- (/ ?a ?b) (/ ?c ?d))" => "(/ (- (* ?a ?d) (* ?b ?c)) (* ?b ?d))"),
        rewrite!("frac-times"; "(* (/ ?a ?b) (/ ?c ?d))" => "(/ (* ?a ?c) (* ?b ?d))"),
        rewrite!("frac-2neg"; "(/ ?a ?b)" => "(/ (~ ?a) (~ ?b))"),
        // Square roots.
        rewrite!("rem-square-sqrt"; "(* (sqrt ?x) (sqrt ?x))" => "?x"),
        rewrite!("sqr-neg"; "(* (~ ?x) (~ ?x))" => "(* ?x ?x)"),
        rewrite!("sqrt-pow2"; "(pow (sqrt ?x) ?y)" => "(pow ?x (/ ?y 2))"),
        rewrite!("sqrt-unprod"; "(* (sqrt ?x) (sqrt ?y))" => "(sqrt (* ?x ?y))"),
        rewrite!("sqrt-undiv"; "(/ (sqrt ?x) (sqrt ?y))" => "(sqrt (/ ?x ?y))"),
        rewrite!("sqrt-1"; "(sqrt 1)" => "1"),
        rewrite!("sqrt-0"; "(sqrt 0)" => "0"),
        rewrite!("sqrt-can"; "(/ (sqrt ?x) ?x)" => "(/ 1 (sqrt ?x))"),
        rewrite!("sqrt-can-inv"; "(/ ?x (sqrt ?x))" => "(sqrt ?x)"),
        rewrite!("sqrt-can-rev"; "(/ 1 (sqrt ?x))" => "(/ (sqrt ?x) ?x)"),
        // Only introduced when ?x is a known constant >= 0.
        rewrite!("add-sqr-sqrt"; "?x" => "(* (sqrt ?x) (sqrt ?x))" if is_non_negative_conservative("?x")),
        // Cubes.
        rewrite!("cube-prod"; "(pow (* ?x ?y) 3)" => "(* (pow ?x 3) (pow ?y 3))"),
        rewrite!("cube-div"; "(pow (/ ?x ?y) 3)" => "(/ (pow ?x 3) (pow ?y 3))"),
        rewrite!("cube-mult"; "(pow ?x 3)" => "(* ?x (* ?x ?x))"),
        rewrite!("cube-unmult"; "(* ?x (* ?x ?x))" => "(pow ?x 3)"),
        // Powers.
        rewrite!("unpow-1"; "(pow ?a -1)" => "(/ 1 ?a)"),
        rewrite!("unpow1"; "(pow ?a 1)" => "?a"),
        rewrite!("unpow0"; "(pow ?a 0)" => "1" if is_not_zero("?a")),
        rewrite!("pow-base-1"; "(pow 1 ?a)" => "1"),
        rewrite!("pow1"; "?a" => "(pow ?a 1)"),
        rewrite!("unpow1/2"; "(pow ?a 0.5)" => "(sqrt ?a)"),
        rewrite!("unpow2"; "(pow ?a 2)" => "(* ?a ?a)"),
        rewrite!("unpow3"; "(pow ?a 3)" => "(* (* ?a ?a) ?a)"),
        rewrite!("pow-plus"; "(* (pow ?a ?b) ?a)" => "(pow ?a (+ ?b 1))"),
        rewrite!("pow-prod-down"; "(* (pow ?b ?a) (pow ?c ?a))" => "(pow (* ?b ?c) ?a)"),
        rewrite!("pow-prod-up"; "(* (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (+ ?b ?c))"),
        rewrite!("pow-flip"; "(/ 1 (pow ?a ?b))" => "(pow ?a (~ ?b))"),
        rewrite!("pow-neg"; "(pow ?a (~ ?b))" => "(/ 1 (pow ?a ?b))" if is_not_zero("?a")),
        rewrite!("pow-div"; "(/ (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (- ?b ?c))"),
        rewrite!("pow1/2"; "(sqrt ?a)" => "(pow ?a 0.5)"),
        rewrite!("pow2"; "(* ?a ?a)" => "(pow ?a 2)"),
        rewrite!("pow3"; "(* (* ?a ?a) ?a)" => "(pow ?a 3)"),
        rewrite!("pow-base-0"; "(pow 0 ?a)" => "0" if is_not_zero("?a")),
        rewrite!("inv-pow"; "(/ 1 ?a)" => "(pow ?a -1)"),
        // Basic sine/cosine facts and parity.
        rewrite!("sin-0"; "(sin 0)" => "0"),
        rewrite!("cos-0"; "(cos 0)" => "1"),
        rewrite!("sin-neg"; "(sin (~ ?x))" => "(~ (sin ?x))"),
        // NOTE(review): the name "neg-sig" looks like a typo for "neg-sin"; it is a
        // runtime string, so it is left unchanged here.
        rewrite!("neg-sig"; "(~ (sin ?x))" => "(sin (~ ?x))"),
        rewrite!("cos-neg"; "(cos (~ ?x))" => "(cos ?x)"),
        rewrite!("neg-cos"; "(cos ?x)" => "(cos (~ ?x))"),
        // Pythagorean identity in its various arrangements.
        rewrite!("sqr-sin-b"; "(* (sin ?x) (sin ?x))" => "(- 1 (* (cos ?x) (cos ?x)))"),
        rewrite!("sqr-cos-b"; "(* (cos ?x) (cos ?x))" => "(- 1 (* (sin ?x) (sin ?x)))"),
        rewrite!("cos-sin-sum"; "(+ (* (cos ?a) (cos ?a)) (* (sin ?a) (sin ?a)))" => "1"),
        rewrite!("1-sub-cos"; "(- 1 (* (cos ?a) (cos ?a)))" => "(* (sin ?a) (sin ?a))"),
        rewrite!("1-sub-sin"; "(- 1 (* (sin ?a) (sin ?a)))" => "(* (cos ?a) (cos ?a))"),
        rewrite!("-1-add-cos"; "(+ (* (cos ?a) (cos ?a)) -1)" => "(~ (* (sin ?a) (sin ?a)))"),
        rewrite!("-1-add-sin"; "(+ (* (sin ?a) (sin ?a)) -1)" => "(~ (* (cos ?a) (cos ?a)))"),
        rewrite!("sub-1-cos"; "(- (* (cos ?a) (cos ?a)) 1)" => "(~ (* (sin ?a) (sin ?a)))"),
        rewrite!("sub-1-sin"; "(- (* (sin ?a) (sin ?a)) 1)" => "(~ (* (cos ?a) (cos ?a)))"),
        // Sine at special multiples of pi, and shifts by pi and pi/2.
        rewrite!("sin-PI/6"; "(sin (/ (pi) 6))" => "0.5"),
        rewrite!("sin-PI/4"; "(sin (/ (pi) 4))" => "(/ (sqrt 2) 2)"),
        rewrite!("sin-PI*0.25"; "(sin (* (pi) 0.25))" => "(/ (sqrt 2) 2)"),
        rewrite!("sin-PI*-0.25"; "(sin (* (pi) -0.25))" => "(~ (/ (sqrt 2) 2))"),
        rewrite!("sin-PI/3"; "(sin (/ (pi) 3))" => "(/ (sqrt 3) 2)"),
        rewrite!("sin-PI/2"; "(sin (/ (pi) 2))" => "1"),
        rewrite!("sin-PI*0.5"; "(sin (* (pi) 0.5))" => "1"),
        rewrite!("sin-PI"; "(sin (pi))" => "0"),
        rewrite!("sin-+PI"; "(sin (+ ?x (pi)))" => "(~ (sin ?x))"),
        rewrite!("sin-+PI/2"; "(sin (+ ?x (/ (pi) 2)))" => "(cos ?x)"),
        // Cosine at special multiples of pi, and shifts by pi and pi/2.
        rewrite!("cos-PI/6"; "(cos (/ (pi) 6))" => "(/ (sqrt 3) 2)"),
        rewrite!("cos-PI/4"; "(cos (/ (pi) 4))" => "(/ (sqrt 2) 2)"),
        rewrite!("cos-PI*0.25"; "(cos (* (pi) 0.25))" => "(/ (sqrt 2) 2)"),
        rewrite!("cos-PI/3"; "(cos (/ (pi) 3))" => "0.5"),
        rewrite!("cos-PI/2"; "(cos (/ (pi) 2))" => "0"),
        rewrite!("cos-PI*0.5"; "(cos (* (pi) 0.5))" => "0"),
        rewrite!("cos-PI"; "(cos (pi))" => "-1"),
        rewrite!("cos-+PI"; "(cos (+ ?x (pi)))" => "(~ (cos ?x))"),
        rewrite!("cos-+PI/2"; "(cos (+ ?x (* (pi) 0.5)))" => "(~ (sin ?x))"),
        // Half-angle tangent forms, written as sin/cos quotients.
        rewrite!("hang-0p-tan"; "(/ (sin ?a) (+ 1 (cos ?a)))" => "(/ (sin (/ ?a 2)) (cos (/ ?a 2)))"),
        rewrite!("hang-0m-tan"; "(/ (~ (sin ?a)) (+ 1 (cos ?a)))" => "(/ (sin (/ (~ ?a) 2)) (cos (/ (~ ?a) 2)))"),
        rewrite!("hang-p0-tan"; "(/ (- 1 (cos ?a)) (sin ?a))" => "(/ (sin (/ ?a 2)) (cos (/ ?a 2)))"),
        rewrite!("hang-m0-tan"; "(/ (- 1 (cos ?a)) (~ (sin ?a)))" => "(/ (sin (/ (~ ?a) 2)) (cos (/ (~ ?a) 2)))"),
        rewrite!("tan-hang-0p"; "(/ (sin (* ?a 0.5)) (cos (* ?a 0.5)))" => "(/ (sin ?a) (+ 1 (cos ?a)))"),
        rewrite!("tan-hang-0m"; "(/ (sin (* (~ ?a) 0.5)) (cos (* (~ ?a) 0.5)))" => "(/ (~ (sin ?a)) (+ 1 (cos ?a)))"),
        rewrite!("tan-hang-p0"; "(/ (sin (* ?a 0.5)) (cos (* ?a 0.5)))" => "(/ (- 1 (cos ?a)) (sin ?a))"),
        rewrite!("tan-hang-m0"; "(/ (sin (* (~ ?a) 0.5)) (cos (* (~ ?a) 0.5)))" => "(/ (- 1 (cos ?a)) (~ (sin ?a)))" if is_not_zero("?a")),
        // Reciprocal-square identities.
        rewrite!("csc-cot"; "(/ 1 (* (sin ?a) (sin ?a)))" => "(+ 1 (/ (* (cos ?a) (cos ?a)) (* (sin ?a) (sin ?a))))"),
        rewrite!("sec-tan"; "(/ 1 (* (cos ?a) (cos ?a)))" => "(+ 1 (/ (* (sin ?a) (sin ?a)) (* (cos ?a) (cos ?a))))"),
        rewrite!("csc-sec"; "(* (/ 1 (* (cos ?a) (cos ?a))) (/ 1 (* (sin ?a) (sin ?a))))" => "(+ (/ 1 (* (cos ?a) (cos ?a))) (/ 1 (* (sin ?a) (sin ?a))))"),
        // Angle sum and difference.
        rewrite!("sin-sum"; "(sin (+ ?x ?y))" => "(+ (* (sin ?x) (cos ?y)) (* (cos ?x) (sin ?y)))"),
        rewrite!("cos-sum"; "(cos (+ ?x ?y))" => "(- (* (cos ?x) (cos ?y)) (* (sin ?x) (sin ?y)))"),
        rewrite!("sin-diff"; "(sin (- ?x ?y))" => "(- (* (sin ?x) (cos ?y)) (* (cos ?x) (sin ?y)))"),
        rewrite!("cos-diff"; "(cos (- ?x ?y))" => "(+ (* (cos ?x) (cos ?y)) (* (sin ?x) (sin ?y)))"),
        // Double and triple angles, and power reduction.
        rewrite!("sin-2"; "(sin (* 2 ?x))" => "(* 2 (* (sin ?x) (cos ?x)))"),
        rewrite!("sin-3"; "(sin (* 3 ?x))" => "(- (* 3 (sin ?x)) (* 4 (pow (sin ?x) 3)))"),
        rewrite!("2-sin"; "(* 2 (* (sin ?x) (cos ?x)))" => "(sin (* 2 ?x))"),
        rewrite!("3-sin"; "(- (* 3 (sin ?x)) (* 4 (pow (sin ?x) 3)))" => "(sin (* 3 ?x))"),
        rewrite!("cos-2"; "(cos (* 2 ?x))" => "(- (* (cos ?x) (cos ?x)) (* (sin ?x) (sin ?x)))"),
        rewrite!("cos-3"; "(cos (* 3 ?x))" => "(- (* 4 (pow (cos ?x) 3)) (* 3 (cos ?x)))"),
        rewrite!("2-cos"; "(- (* (cos ?x) (cos ?x)) (* (sin ?x) (sin ?x)))" => "(cos (* 2 ?x))"),
        rewrite!("3-cos"; "(- (* 4 (pow (cos ?x) 3)) (* 3 (cos ?x)))" => "(cos (* 3 ?x))"),
        rewrite!("sqr-sin-a"; "(* (sin ?x) (sin ?x))" => "(- 0.5 (* 0.5 (cos (* 2 ?x))))"),
        rewrite!("sqr-cos-a"; "(* (cos ?x) (cos ?x))" => "(+ 0.5 (* 0.5 (cos (* 2 ?x))))"),
        // Sum-to-product and product-to-sum.
        rewrite!("diff-sin"; "(- (sin ?x) (sin ?y))" => "(* 2 (* (sin (/ (- ?x ?y) 2)) (cos (/ (+ ?x ?y) 2))))"),
        rewrite!("diff-cos"; "(- (cos ?x) (cos ?y))" => "(* -2 (* (sin (/ (- ?x ?y) 2)) (sin (/ (+ ?x ?y) 2))))"),
        rewrite!("sum-sin"; "(+ (sin ?x) (sin ?y))" => "(* 2 (* (sin (/ (+ ?x ?y) 2)) (cos (/ (- ?x ?y) 2))))"),
        rewrite!("sum-cos"; "(+ (cos ?x) (cos ?y))" => "(* 2 (* (cos (/ (+ ?x ?y) 2)) (cos (/ (- ?x ?y) 2))))"),
        rewrite!("cos-mult"; "(* (cos ?x) (cos ?y))" => "(/ (+ (cos (+ ?x ?y)) (cos (- ?x ?y))) 2)"),
        rewrite!("sin-mult"; "(* (sin ?x) (sin ?y))" => "(/ (- (cos (- ?x ?y)) (cos (+ ?x ?y))) 2)"),
        rewrite!("sin-cos-mult"; "(* (sin ?x) (cos ?y))" => "(/ (+ (sin (- ?x ?y)) (sin (+ ?x ?y))) 2)"),
        // Tangent double angle, written as sin/cos quotients.
        rewrite!("tan-2"; "(/ (sin (* 2 ?x)) (cos (* 2 ?x)))" => "(/ (* 2 (/ (sin ?x) (cos ?x))) (- 1 (* (/ (sin ?x) (cos ?x)) (/ (sin ?x) (cos ?x)))))"),
        rewrite!("2-tan"; "(/ (* 2 (/ (sin ?x) (cos ?x))) (- 1 (* (/ (sin ?x) (cos ?x)) (/ (sin ?x) (cos ?x)))))" => "(/ (sin (* 2 ?x)) (cos (* 2 ?x)))"),
    ]
}
831
832fn to_egg_expr(expr: &Expression) -> RecExpr<TrigLanguage> {
833 expr.to_string().parse().unwrap()
834}
835
836use crate::qgl::lexer::Lexer;
837use crate::qgl::lexer::Token;
838
839fn _from_egg_expr(tokens: Vec<Token>) -> Expression {
847 let (start, op_token) = if tokens[0] == Token::LParen {
848 assert!(tokens.last() == Some(&Token::RParen));
849 (1, tokens[1].clone())
850 } else {
851 (0, tokens[0].clone())
852 };
853
854 if let Token::Number(num) = op_token {
855 return Expression::from_float(num.parse::<f64>().unwrap());
856 }
857
858 if let Token::Ident(ref id) = op_token {
859 if id == "pi" {
860 return Expression::Pi;
861 }
862 if id != "sin" && id != "cos" && id != "sqrt" && id != "pow" {
863 return Expression::Variable(id.to_string());
864 }
865 }
866
867 let mut operands = vec![];
868 let mut i = start + 1;
869 while i < tokens.len() {
870 let token = &tokens[i];
871 if *token == Token::LParen {
872 let mut num_open_parenthesis = 1;
873 let start = i + 1;
874 for (j, token) in tokens.iter().enumerate().skip(i + 1) {
875 if *token == Token::LParen {
876 num_open_parenthesis += 1;
877 } else if *token == Token::RParen {
878 num_open_parenthesis -= 1;
879 }
880
881 if num_open_parenthesis == 0 {
882 operands.push(_from_egg_expr(tokens[start..j + 1].to_vec()));
883 i = j;
884 break;
885 }
886 }
887 } else if *token == Token::RParen {
888 assert_eq!(i, tokens.len() - 1);
889 } else {
890 operands.push(_from_egg_expr(tokens[i..i + 1].to_vec()));
891 }
892 i += 1;
893 }
894
895 match op_token {
896 Token::Ident(id) => match id.clone().as_str() {
897 "sin" => Expression::Sin(Box::new(operands[0].clone())),
898 "cos" => Expression::Cos(Box::new(operands[0].clone())),
899 "sqrt" => Expression::Sqrt(Box::new(operands[0].clone())),
900 "pow" => Expression::Pow(Box::new(operands[0].clone()), Box::new(operands[1].clone())),
901 _ => panic!("Invalid operator during parsing of egg expression"),
902 },
903 Token::Negation => Expression::Neg(Box::new(operands[0].clone())),
904 Token::Op(op) => match op {
905 '+' => Expression::Add(Box::new(operands[0].clone()), Box::new(operands[1].clone())),
906 '-' => Expression::Sub(Box::new(operands[0].clone()), Box::new(operands[1].clone())),
907 '*' => Expression::Mul(Box::new(operands[0].clone()), Box::new(operands[1].clone())),
908 '/' => Expression::Div(Box::new(operands[0].clone()), Box::new(operands[1].clone())),
909 _ => panic!("Invalid operator during parsing of egg expression"),
910 },
911 _ => panic!("Invalid token during parsing of egg expression"),
912 }
913}
914
915fn from_egg_expr(expr: RecExpr<TrigLanguage>) -> Expression {
916 let expr_str = expr.to_string();
917 let expr_tokens = Lexer::new(&expr_str).collect::<Vec<_>>();
918 if expr_tokens.is_empty() {
919 panic!("Failure to lex expression: {}", expr_str);
920 }
921
922 let mut grouped_tokens = vec![];
924 let mut i = 0;
925 while i < expr_tokens.len() {
926 if expr_tokens[i] == Token::Op('-') && i < expr_tokens.len() - 1 {
927 if let Token::Number(n) = &expr_tokens[i + 1] {
928 let n = n.parse::<f64>().unwrap();
929 grouped_tokens.push(Token::Number((-n).to_string()));
930 i += 1
931 } else {
932 grouped_tokens.push(expr_tokens[i].clone());
933 }
934 } else {
935 grouped_tokens.push(expr_tokens[i].clone());
936 }
937 i += 1;
938 }
939
940 _from_egg_expr(grouped_tokens)
941}
942
943pub fn simplify(expr: &Expression) -> Expression {
945 let expr: RecExpr<TrigLanguage> = to_egg_expr(expr);
947
948 let runner = Runner::default().with_expr(&expr).run(&make_rules());
951 let mut extractor = TrigExprExtractor::new(&runner.egraph);
952
953 let root = runner.roots[0];
955
956 let best = extractor.extract_best(root);
958 from_egg_expr(best)
959}
960
961#[allow(dead_code)]
962pub fn extract_best_sine(expr: Expression) -> Option<Expression> {
963 let expr: RecExpr<TrigLanguage> = to_egg_expr(&expr);
964 let runner = Runner::default()
965 .with_expr(&expr)
966 .with_iter_limit(1000)
967 .with_node_limit(10000000)
968 .run(&make_rules());
969 let egraph = &runner.egraph;
970 let extractor = SineExtractor::new(egraph);
971 let root = runner.roots[0];
972 let best = extractor.extract_sine(root);
973 best.map(from_egg_expr)
974}
975
976#[allow(dead_code)]
977pub fn simplify_complex(expr: ComplexExpression) -> ComplexExpression {
978 let ComplexExpression { real, imag } = expr;
979
980 let real_expr: RecExpr<TrigLanguage> = to_egg_expr(&real);
981 let imag_expr: RecExpr<TrigLanguage> = to_egg_expr(&imag);
982
983 let runner = Runner::default()
984 .with_expr(&real_expr)
985 .with_expr(&imag_expr)
986 .run(&make_rules());
987 let mut extractor = TrigExprExtractor::new(&runner.egraph);
988
989 let real_root = runner.roots[0];
991 let imag_root = runner.roots[1];
992
993 let best_real = extractor.extract_best(real_root);
995 let real_simple = from_egg_expr(best_real);
996
997 let best_imag = extractor.extract_best(imag_root);
998 let imag_simple = from_egg_expr(best_imag);
999
1000 ComplexExpression {
1001 real: real_simple,
1002 imag: imag_simple,
1003 }
1004}
1005
1006#[allow(dead_code)]
1007pub fn simplify_matrix_no_context(
1008 matrix_expression: &Vec<Vec<ComplexExpression>>,
1009) -> Vec<Vec<ComplexExpression>> {
1010 let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default();
1011
1012 for row in matrix_expression {
1013 for expr in row {
1014 let ComplexExpression { real, imag } = expr;
1015 let real_expr: RecExpr<TrigLanguage> = to_egg_expr(real);
1016 let imag_expr: RecExpr<TrigLanguage> = to_egg_expr(imag);
1017 runner = runner.with_expr(&real_expr).with_expr(&imag_expr);
1018 }
1019 }
1020
1021 runner = runner.run(&make_rules());
1022 let extractor = Extractor::new(&runner.egraph, TrigCostFn);
1023
1024 let mut simplified_matrix = vec![vec![]; matrix_expression.len()];
1025 let nrows = matrix_expression.len();
1026 let ncols = matrix_expression[0].len();
1027
1028 for i in 0..nrows {
1029 for j in 0..ncols {
1030 let real_root = runner.roots[2 * (i * ncols + j)];
1031 let imag_root = runner.roots[2 * (i * ncols + j) + 1];
1032
1033 let (_, best_real) = extractor.find_best(real_root);
1034 let real_simple = from_egg_expr(best_real);
1035 let (_, best_imag) = extractor.find_best(imag_root);
1036 let imag_simple = from_egg_expr(best_imag);
1037
1038 simplified_matrix[i].push(ComplexExpression {
1039 real: real_simple,
1040 imag: imag_simple,
1041 });
1042 }
1043 }
1044
1045 simplified_matrix
1046}
1047
1048#[allow(dead_code)]
1049pub fn simplify_matrix(
1050 matrix_expression: &Vec<Vec<ComplexExpression>>,
1051) -> Vec<Vec<ComplexExpression>> {
1052 let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default();
1053
1054 for row in matrix_expression {
1055 for expr in row {
1056 let ComplexExpression { real, imag } = expr;
1057 let real_expr: RecExpr<TrigLanguage> = to_egg_expr(real);
1058 let imag_expr: RecExpr<TrigLanguage> = to_egg_expr(imag);
1059 runner = runner.with_expr(&real_expr).with_expr(&imag_expr);
1060 }
1061 }
1062
1063 runner = runner.run(&make_rules());
1064 let mut extractor = TrigExprExtractor::new(&runner.egraph);
1065
1066 let mut simplified_matrix = vec![vec![]; matrix_expression.len()];
1067 let nrows = matrix_expression.len();
1068 let ncols = matrix_expression[0].len();
1069
1070 for i in 0..nrows {
1071 for j in 0..ncols {
1072 let real_root = runner.roots[2 * (i * ncols + j)];
1073 let imag_root = runner.roots[2 * (i * ncols + j) + 1];
1074
1075 let best_real = extractor.extract_best(real_root);
1076 let real_simple = from_egg_expr(best_real);
1077 let best_imag = extractor.extract_best(imag_root);
1078 let imag_simple = from_egg_expr(best_imag);
1079
1080 simplified_matrix[i].push(ComplexExpression {
1081 real: real_simple,
1082 imag: imag_simple,
1083 });
1084 }
1085 }
1086
1087 simplified_matrix
1088}
1089
1090pub fn simplify_expressions_iter<'a>(
1091 expression: impl Iterator<Item = &'a Expression>,
1092) -> Vec<Expression> {
1093 let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default();
1094
1095 let mut num_expressions = 0;
1096 for expr in expression {
1097 let expr: RecExpr<TrigLanguage> = to_egg_expr(expr);
1098 runner = runner.with_expr(&expr);
1099 num_expressions += 1;
1100 }
1101
1102 runner = runner.run(&make_rules());
1103 let mut extractor = TrigExprExtractor::new(&runner.egraph);
1104
1105 let mut simplified_expressions = vec![];
1106
1107 for i in 0..num_expressions {
1108 let root = runner.roots[i];
1109 let best = extractor.extract_best(root);
1110 simplified_expressions.push(from_egg_expr(best));
1111 }
1112
1113 simplified_expressions
1114}
1115
1116#[allow(dead_code)]
1117pub fn simplify_expressions(expression: Vec<Expression>) -> Vec<Expression> {
1118 let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default();
1119
1120 let mut num_expressions = 0;
1121 for expr in expression {
1122 let expr: RecExpr<TrigLanguage> = to_egg_expr(&expr);
1123 runner = runner.with_expr(&expr);
1124 num_expressions += 1;
1125 }
1126
1127 runner = runner.run(&make_rules());
1128 let mut extractor = TrigExprExtractor::new(&runner.egraph);
1129
1130 let mut simplified_expressions = vec![];
1131
1132 for i in 0..num_expressions {
1133 let root = runner.roots[i];
1134 let best = extractor.extract_best(root);
1135 simplified_expressions.push(from_egg_expr(best));
1136 }
1137
1138 simplified_expressions
1139}
1140
1141#[allow(dead_code)]
1142pub fn simplify_matrix_and_matvec(
1143 matrix_expression: &Vec<Vec<ComplexExpression>>,
1144 matvec_expression: &Vec<Vec<Vec<ComplexExpression>>>,
1145) -> (
1146 Vec<Vec<ComplexExpression>>,
1147 Vec<Vec<Vec<ComplexExpression>>>,
1148) {
1149 let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default();
1150
1151 for row in matrix_expression {
1152 for expr in row {
1153 let ComplexExpression { real, imag } = expr;
1154 let real_expr: RecExpr<TrigLanguage> = to_egg_expr(real);
1155 let imag_expr: RecExpr<TrigLanguage> = to_egg_expr(imag);
1156 runner = runner.with_expr(&real_expr).with_expr(&imag_expr);
1157 }
1158 }
1159
1160 for mat in matvec_expression {
1161 for row in mat {
1162 for expr in row {
1163 let ComplexExpression { real, imag } = expr;
1164 let real_expr: RecExpr<TrigLanguage> = to_egg_expr(real);
1165 let imag_expr: RecExpr<TrigLanguage> = to_egg_expr(imag);
1166 runner = runner.with_expr(&real_expr).with_expr(&imag_expr);
1167 }
1168 }
1169 }
1170
1171 runner = runner.run(&make_rules());
1172 let mut extractor = TrigExprExtractor::new(&runner.egraph);
1173
1174 let mut simplified_matrix = vec![vec![]; matrix_expression.len()];
1175 let nrows = matrix_expression.len();
1176 let ncols = matrix_expression[0].len();
1177
1178 for i in 0..nrows {
1179 for j in 0..ncols {
1180 let real_root = runner.roots[2 * (i * ncols + j)];
1181 let imag_root = runner.roots[2 * (i * ncols + j) + 1];
1182
1183 let best_real = extractor.extract_best(real_root);
1184 let real_simple = from_egg_expr(best_real);
1185 let best_imag = extractor.extract_best(imag_root);
1186 let imag_simple = from_egg_expr(best_imag);
1187
1188 simplified_matrix[i].push(ComplexExpression {
1189 real: real_simple,
1190 imag: imag_simple,
1191 });
1192 }
1193 }
1194
1195 let matrix_expr_offset = 2 * nrows * ncols;
1196 let nmats = matvec_expression.len();
1197 if nmats == 0 {
1198 return (simplified_matrix, vec![]);
1199 }
1200 let nrows = matvec_expression[0].len();
1201 let ncols = matvec_expression[0][0].len();
1202 let mut simplified_matvec = vec![vec![vec![]; nrows]; nmats];
1203
1204 for m in 0..nmats {
1205 for i in 0..nrows {
1206 for j in 0..ncols {
1207 let real_root =
1208 runner.roots[matrix_expr_offset + 2 * (m * nrows * ncols + i * ncols + j)];
1209 let imag_root =
1210 runner.roots[matrix_expr_offset + 2 * (m * nrows * ncols + i * ncols + j) + 1];
1211
1212 let best_real = extractor.extract_best(real_root);
1213 let real_simple = from_egg_expr(best_real);
1214 let best_imag = extractor.extract_best(imag_root);
1215 let imag_simple = from_egg_expr(best_imag);
1216
1217 simplified_matvec[m][i].push(ComplexExpression {
1218 real: real_simple,
1219 imag: imag_simple,
1220 });
1221 }
1222 }
1223 }
1224
1225 (simplified_matrix, simplified_matvec)
1228}
1229
1230#[allow(dead_code)]
1290pub fn check_many_equality(expr1s: &[&Expression], expr2s: &[&Expression]) -> bool {
1291 let expr1s: Vec<RecExpr<TrigLanguage>> = expr1s.iter().map(|expr| to_egg_expr(expr)).collect();
1292 let expr2s: Vec<RecExpr<TrigLanguage>> = expr2s.iter().map(|expr| to_egg_expr(expr)).collect();
1293
1294 let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1295 .with_iter_limit(120)
1296 .with_node_limit(100_000_000);
1297 for expr1 in &expr1s {
1298 runner = runner.with_expr(expr1);
1299 }
1300 for expr2 in &expr2s {
1301 runner = runner.with_expr(expr2);
1302 }
1303 runner = runner.run(&make_rules());
1304
1305 for (expr1, expr2) in expr1s.iter().zip(expr2s.iter()) {
1306 if runner.egraph.equivs(expr1, expr2).is_empty() {
1307 return false;
1308 }
1309 }
1310
1311 true
1312}
1313
1314#[allow(dead_code)]
1315pub fn check_equality(expr: &Expression, expr2: &Expression) -> bool {
1316 let expr1: RecExpr<TrigLanguage> = to_egg_expr(expr);
1317 let expr2: RecExpr<TrigLanguage> = to_egg_expr(expr2);
1318
1319 let runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1320 .with_expr(&expr1)
1321 .with_expr(&expr2)
1322 .with_iter_limit(100)
1323 .with_node_limit(1_000_000)
1324 .run(&make_rules());
1325
1326 !runner.egraph.equivs(&expr1, &expr2).is_empty()
1327 }
1339
1340#[allow(dead_code)]
1341fn print_equality(s1: &str, s2: &str) {
1342 let expr1: RecExpr<TrigLanguage> = s1.parse().unwrap();
1343 let expr2: RecExpr<TrigLanguage> = s2.parse().unwrap();
1344
1345 let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1346 .with_explanations_enabled()
1347 .with_expr(&expr1)
1348 .run(&make_rules());
1349 println!(
1350 "{}",
1351 runner.explain_equivalence(&expr1, &expr2).get_flat_string()
1352 );
1353 }
1361
1362#[allow(dead_code)]
1363fn check_equality_lhs_only(s1: &str, s2: &str) -> bool {
1364 let expr1: RecExpr<TrigLanguage> = s1.parse().unwrap();
1365 let expr2: RecExpr<TrigLanguage> = s2.parse().unwrap();
1366
1367 let runner: Runner<TrigLanguage, ConstantFold> =
1368 Runner::default().with_expr(&expr1).run(&make_rules());
1369
1370 !runner.egraph.equivs(&expr1, &expr2).is_empty()
1371}
1372
1373#[allow(dead_code)]
1374fn check_equality_both(s1: &str, s2: &str) -> bool {
1375 let expr1: RecExpr<TrigLanguage> = s1.parse().unwrap();
1376 let expr2: RecExpr<TrigLanguage> = s2.parse().unwrap();
1377
1378 let runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1379 .with_expr(&expr1)
1380 .with_node_limit(25_000)
1381 .run(&make_rules());
1382
1383 let lhs = !runner.egraph.equivs(&expr1, &expr2).is_empty();
1384
1385 let runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1386 .with_expr(&expr2)
1387 .with_node_limit(25_000)
1388 .run(&make_rules());
1389
1390 let rhs = !runner.egraph.equivs(&expr2, &expr1).is_empty();
1391
1392 lhs && rhs
1393}
1394
#[cfg(test)]
mod tests {
    /// Placeholder test: the body is empty, so it currently asserts nothing.
    #[test]
    fn check_equality_test() {}
}