qudit_expr/analysis/
mod.rs

1use egg::*;
2use num::Signed;
3use num::Zero;
4use ordered_float::NotNan;
5
6use crate::ComplexExpression;
7use crate::Expression;
8
9// #[cfg(test)]
10// mod extract;
11
/// E-graph over [`TrigLanguage`] with the [`ConstantFold`] analysis attached to every e-class.
pub type EGraph = egg::EGraph<TrigLanguage, ConstantFold>;
/// Payload of numeric literals. Presumably `NotNan` is used because e-nodes must be
/// comparable/hashable, which plain `f64` is not — NOTE(review): confirm.
pub type Constant = NotNan<f64>;

// Term language used by the rewrite rules in this module. The string on each
// variant is its operator name in s-expression syntax, e.g. `(sin ?x)`.
define_language! {
    pub enum TrigLanguage {
        // The constant pi, kept symbolic: its folding arm in `ConstantFold::make`
        // is commented out.
        "pi" = Pi,

        // `~` is unary negation, distinct from binary `-` (subtraction).
        "~" = Neg([Id; 1]),
        "+" = Add([Id; 2]),
        "-" = Sub([Id; 2]),
        "*" = Mul([Id; 2]),
        "/" = Div([Id; 2]),

        // `(pow base exponent)`
        "pow" = Pow([Id; 2]),
        "sqrt" = Sqrt(Id),
        "sin" = Sin(Id),
        "cos" = Cos(Id),

        // Leaves: numeric literals and free variables.
        Constant(Constant),
        Symbol(Symbol),
    }
}
34
35fn is_not_zero(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
36    let var = var.parse().unwrap();
37    move |egraph, _, subst| {
38        if let Some(n) = &egraph[subst[var]].data {
39            *(n.0) != 0.0
40        } else {
41            false
42        }
43    }
44}
45
46fn is_non_negative_conservative(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
47    let var = var.parse().unwrap();
48    move |egraph, _, subst| {
49        if let Some(n) = &egraph[subst[var]].data {
50            *(n.0) >= 0.0
51        } else {
52            false
53        }
54    }
55}
56
57#[allow(dead_code)]
58fn all_not_zero(vars: &[&str]) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
59    let vars: Vec<_> = vars.iter().map(|v| v.parse().unwrap()).collect();
60    move |egraph, _, subst| {
61        vars.iter().all(|&v| {
62            if let Some(n) = &egraph[subst[v]].data {
63                *(n.0) != 0.0
64            } else {
65                true
66            }
67        })
68    }
69}
70
/// Total-ish ordering on optional costs where any known value sorts before an
/// unknown one (`Some(_) < None`); two `None`s compare equal.
///
/// # Panics
/// Panics if both values are present but incomparable (e.g. a NaN float).
fn cmp<T: PartialOrd>(a: &Option<T>, b: &Option<T>) -> Ordering {
    match (a.as_ref(), b.as_ref()) {
        (Some(lhs), Some(rhs)) => lhs.partial_cmp(rhs).unwrap(),
        (Some(_), None) => Ordering::Less,
        (None, Some(_)) => Ordering::Greater,
        (None, None) => Ordering::Equal,
    }
}
79
80use core::f64;
81use std::cmp::Ordering;
82use std::collections::HashMap;
83
84#[allow(dead_code)]
85struct SineExtractor<'a> {
86    costs: HashMap<Id, (f64, TrigLanguage)>,
87    egraph: &'a EGraph,
88}
89
90impl<'a> SineExtractor<'a> {
91    #![allow(clippy::type_complexity)]
92    pub fn new(egraph: &'a EGraph) -> Self {
93        let costs = HashMap::default();
94        let mut extractor = SineExtractor { costs, egraph };
95        extractor.calculate_costs();
96        extractor
97    }
98
99    pub fn extract_sine(&self, eclass: Id) -> Option<RecExpr<TrigLanguage>> {
100        let id = self.egraph.find(eclass);
101        let eclass = &self.egraph[id];
102
103        // find best cosine node in eclass
104        let mut best_cost = None;
105        let mut best_node = None;
106
107        for node in &eclass.nodes {
108            if let TrigLanguage::Sin(id) = node {
109                let cost = self.costs[id].0;
110                if best_cost.is_none() || cost < best_cost.unwrap() {
111                    best_cost = Some(cost);
112                    best_node = Some(node.clone());
113                }
114            }
115        }
116
117        if let Some(node) = best_node {
118            let expr = node.build_recexpr(|id| self.find_best_node(id).clone());
119            Some(expr)
120        } else {
121            None
122        }
123    }
124
125    pub fn find_best_node(&self, eclass: Id) -> &TrigLanguage {
126        &self.costs[&self.egraph.find(eclass)].1
127    }
128
129    fn calculate_costs(&mut self) {
130        let mut did_something = true;
131        while did_something {
132            did_something = false;
133
134            for class in self.egraph.classes() {
135                let pass = self.make_pass(class);
136                match (self.costs.get(&class.id), pass) {
137                    (None, Some(new)) => {
138                        self.costs.insert(class.id, new);
139                        did_something = true;
140                    }
141                    (Some(old), Some(new)) if new.0 < old.0 => {
142                        self.costs.insert(class.id, new);
143                        did_something = true;
144                    }
145                    _ => {}
146                }
147            }
148        }
149
150        for class in self.egraph.classes() {
151            if !self.costs.contains_key(&class.id) {
152                println!("failed to calculate cost for {:?}", class);
153            }
154        }
155    }
156
157    fn make_pass(
158        &mut self,
159        eclass: &EClass<TrigLanguage, Option<(NotNan<f64>, RecExpr<ENodeOrVar<TrigLanguage>>)>>,
160    ) -> Option<(f64, TrigLanguage)> {
161        let (cost, node) = eclass
162            .iter()
163            .map(|n| (self.node_total_cost(n), n))
164            .min_by(|a, b| cmp(&a.0, &b.0))
165            .unwrap_or_else(|| panic!("eclass is empty"));
166        cost.map(|c| (c, node.clone()))
167    }
168
169    fn node_total_cost(&mut self, enode: &TrigLanguage) -> Option<f64> {
170        let eg = &self.egraph;
171        let has_cost = |id| self.costs.contains_key(&eg.find(id));
172        // TODO: these hashes can be quite expensive, see if we can cache
173        // the enode.all(has_cost) cost in maybe a vector?
174        if enode.all(has_cost) {
175            Some(self.node_cost(enode))
176        } else {
177            None
178        }
179    }
180
181    fn node_cost(&self, enode: &TrigLanguage) -> f64 {
182        let op_cost = match enode {
183            TrigLanguage::Constant(_) => 0.5,
184            TrigLanguage::Neg(_) => 1.0,
185            TrigLanguage::Add(_) | TrigLanguage::Sub(_) => 1.0,
186            TrigLanguage::Mul(_) | TrigLanguage::Div(_) => 5.0,
187            TrigLanguage::Pow(_)
188            | TrigLanguage::Sqrt(_)
189            | TrigLanguage::Sin(_)
190            | TrigLanguage::Cos(_) => 50.0,
191            _ => 0.0,
192        };
193        enode.fold(op_cost, |acc, id| acc + self.costs[&id].0)
194    }
195}
196
197// struct TestExtractor<'a> {
198//     memory: FxHashMap<Id, TrigLanguage>,
199//     currently_processing: FxHashSet<Id>,
200//     egraph: &'a EGraph,
201// }
202
203// impl<'a> TestExtractor<'a> {
204//     pub fn new(e: &'a EGraph) -> Self {
205//         let memory = FxHashMap::default();
206//         let currently_processing = FxHashSet::default();
207//         TestExtractor { memory, currently_processing, egraph: e }
208//     }
209
210//     pub fn extract_best(&mut self, eclass: Id) -> RecExpr<TrigLanguage> {
211//         let root = self.egraph.find(eclass);
212//         let node = self.calculate_best_node(root);
213//         node.1.build_recexpr(|id| self.memory.get(&id).unwrap().clone())
214//     }
215
216//     pub fn calculate_best_node(&mut self, eclass: Id) -> (f64, TrigLanguage) {
217//         let id = &self.egraph.find(eclass);
218//         if self.currently_processing.contains(&id) {
219//             return (f64::INFINITY, TrigLanguage::Constant(NotNan::new(f64::INFINITY).unwrap()));
220//         }
221//         if let Some(node) = self.memory.get(id) {
222//             return (0.0, node.clone());
223//         }
224
225//         for node in &self.egraph[*id].nodes {
226//             if node.is_leaf() {
227//                 self.memory.insert(*id, node.clone());
228//                 return (0.0, node.clone());
229//             }
230//         }
231
232//         self.currently_processing.insert(*id);
233//         let (cost, node) = self.egraph[*id].iter()
234//             .map(|n| (self.evaluate_node(n), n)).min_by(|a, b| if a.0 < b.0 { Ordering::Less } else if a.0 == b.0 { Ordering::Equal } else { Ordering::Greater }).unwrap();
235//         self.currently_processing.remove(id);
236//         self.memory.insert(*id, node.clone());
237//         (cost, node.clone())
238//     }
239
240//     pub fn evaluate_node(&mut self, enode: &'a TrigLanguage) -> f64 {
241//         match enode {
242//             TrigLanguage::Neg([a]) => self.calculate_best_node(*a).0 + 1.0,
243//             TrigLanguage::Add([a, b]) => self.calculate_best_node(*a).0 + self.calculate_best_node(*b).0 + 1.0,
244//             TrigLanguage::Sub([a, b]) => self.calculate_best_node(*a).0 + self.calculate_best_node(*b).0 + 1.0,
245//             TrigLanguage::Mul([a, b]) => self.calculate_best_node(*a).0 + self.calculate_best_node(*b).0 + 5.0,
246//             TrigLanguage::Div([a, b]) => self.calculate_best_node(*a).0 + self.calculate_best_node(*b).0 + 10.0,
247//             TrigLanguage::Pow([a, b]) => self.calculate_best_node(*a).0 + self.calculate_best_node(*b).0 + 100.0,
248//             TrigLanguage::Sqrt(a) => self.calculate_best_node(*a).0 + 50.0,
249//             TrigLanguage::Sin(a) => self.calculate_best_node(*a).0 + 50.0,
250//             TrigLanguage::Cos(a) => self.calculate_best_node(*a).0 + 50.0,
251//             _ => panic!("unexpected node"),
252//         }
253//     }
254// }
255
256struct TrigExprExtractor<'a> {
257    costs: Vec<(f64, TrigLanguage)>,
258    egraph: &'a EGraph,
259    has_changed: bool,
260}
261
262impl<'a> TrigExprExtractor<'a> {
263    #![allow(clippy::type_complexity)]
264    pub fn new(egraph: &'a EGraph) -> Self {
265        // let costs = FxHashMap::default();
266        let mut max_id = 0usize;
267        for class in egraph.classes() {
268            let id = unsafe { std::mem::transmute::<Id, u32>(class.id) } as usize;
269            if id > max_id {
270                max_id = id
271            }
272        }
273        let costs = vec![(-1.0, TrigLanguage::Pi); max_id + 1];
274        let mut extractor = TrigExprExtractor {
275            costs,
276            egraph,
277            has_changed: false,
278        };
279        extractor.calculate_costs();
280        extractor
281    }
282
283    pub fn extract_best(&mut self, eclass: Id) -> RecExpr<TrigLanguage> {
284        let root = self.get_cost(self.egraph.find(eclass)).1.clone();
285        let expr = root.build_recexpr(|id| self.extract_best_node(id));
286        // TODO: This is creates a greedy search when simplifying many expressions
287        // in a row: The expression's that are simplified later are effected greatly
288        // by the extractions done earlier. There may be better extractions over
289        // many expressions if a simultaneous extraction is done, but this is
290        // not simple to implement. // Potential for signficant speed up here
291        if self.has_changed {
292            self.recalculate_costs();
293            self.has_changed = false;
294        }
295        expr
296    }
297
298    pub fn extract_best_node(&mut self, eclass: Id) -> TrigLanguage {
299        let id = &self.egraph.find(eclass);
300        let (cost, enode) = self.get_cost(*id).clone();
301        if cost != 0.0 {
302            self.put_cost(*id, (cost, enode.clone()));
303            self.has_changed = true;
304        }
305        enode
306    }
307
308    #[allow(dead_code)]
309    pub fn find_best_node(&self, eclass: Id) -> &TrigLanguage {
310        &self.get_cost(self.egraph.find(eclass)).1
311    }
312
313    fn recalculate_costs(&mut self) {
314        let mut did_something = true;
315        while did_something {
316            did_something = false;
317
318            for class in self.egraph.classes() {
319                let pass = self.make_repass(class);
320                let old = self.get_cost(class.id);
321                match (old, pass) {
322                    (old, new) if old.0 < 0.0 || new < old.0 => {
323                        self.put_cost(class.id, (new, old.1.clone()));
324                        did_something = true;
325                    }
326                    _ => {}
327                }
328            }
329        }
330    }
331
332    fn calculate_costs(&mut self) {
333        let mut did_something = true;
334        while did_something {
335            did_something = false;
336
337            for class in self.egraph.classes() {
338                let pass = self.make_pass(class);
339                if let (old, Some(new)) = (self.get_cost(class.id), pass) {
340                    if old.0 < 0.0 || (new.0 > 0.0 && new.0 < old.0) {
341                        self.put_cost(class.id, new);
342                        did_something = true;
343                    }
344                }
345            }
346        }
347
348        for class in self.egraph.classes() {
349            if self.get_cost(class.id).0 < 0.0 {
350                println!("failed to calculate cost for {:?}", class);
351            }
352        }
353    }
354
355    fn make_pass(
356        &mut self,
357        eclass: &EClass<TrigLanguage, Option<(NotNan<f64>, RecExpr<ENodeOrVar<TrigLanguage>>)>>,
358    ) -> Option<(f64, TrigLanguage)> {
359        let (cost, node) = eclass
360            .iter()
361            .map(|n| (self.node_total_cost(n), n))
362            .min_by(|a, b| cmp(&a.0, &b.0))
363            .unwrap_or_else(|| panic!("eclass is empty"));
364        cost.map(|c| (c, node.clone()))
365    }
366
367    fn make_repass(
368        &mut self,
369        eclass: &EClass<TrigLanguage, Option<(NotNan<f64>, RecExpr<ENodeOrVar<TrigLanguage>>)>>,
370    ) -> f64 {
371        eclass
372            .iter()
373            .map(|n| self.node_cost(n))
374            .min_by(|a, b| match a < b {
375                true => Ordering::Less,
376                false => Ordering::Greater,
377            })
378            .unwrap()
379    }
380
381    fn node_total_cost(&mut self, enode: &TrigLanguage) -> Option<f64> {
382        let eg = &self.egraph;
383        let has_cost = |id| self.get_cost(eg.find(id)).0 >= 0.0;
384        if enode.all(has_cost) {
385            Some(self.node_cost(enode))
386        } else {
387            None
388        }
389    }
390
391    fn node_cost(&self, enode: &TrigLanguage) -> f64 {
392        let op_cost = match enode {
393            TrigLanguage::Constant(_) => 0.5,
394            TrigLanguage::Neg(_) => 1.0,
395            TrigLanguage::Add(_) | TrigLanguage::Sub(_) => 1.0,
396            TrigLanguage::Mul(_) | TrigLanguage::Div(_) => 5.0,
397            TrigLanguage::Sqrt(_) | TrigLanguage::Sin(_) | TrigLanguage::Cos(_) => 50.0,
398            TrigLanguage::Pow(_) => 100.0,
399            _ => 0.0,
400        };
401        enode.fold(op_cost, |acc, id| acc + self.get_cost(id).0)
402    }
403
404    #[inline(always)]
405    fn get_cost(&self, id: Id) -> &(f64, TrigLanguage) {
406        // TODO: this is actually unsafe because Id may not be transparent.
407        &self.costs[unsafe { std::mem::transmute::<Id, u32>(id) } as usize]
408    }
409
410    #[inline(always)]
411    fn put_cost(&mut self, id: Id, cost: (f64, TrigLanguage)) {
412        self.costs[unsafe { std::mem::transmute::<Id, u32>(id) } as usize] = cost;
413    }
414}
415
416struct TrigCostFn;
417impl CostFunction<TrigLanguage> for TrigCostFn {
418    type Cost = f64;
419    fn cost<C>(&mut self, enode: &TrigLanguage, mut costs: C) -> Self::Cost
420    where
421        C: FnMut(Id) -> Self::Cost,
422    {
423        let op_cost = match enode {
424            TrigLanguage::Constant(_) => 0.5,
425            TrigLanguage::Neg(_) => 1.0,
426            TrigLanguage::Add(_) | TrigLanguage::Sub(_) => 1.0,
427            TrigLanguage::Mul(_) | TrigLanguage::Div(_) => 5.0,
428            TrigLanguage::Pow(_)
429            | TrigLanguage::Sqrt(_)
430            | TrigLanguage::Sin(_)
431            | TrigLanguage::Cos(_) => 50.0,
432            _ => 0.0,
433        };
434
435        enode.fold(op_cost, |acc, id| acc + costs(id))
436    }
437}
438
439pub fn can_multiply(a: NotNan<f64>, b: NotNan<f64>) -> bool {
440    if !a.is_zero() && a.is_positive() && a < NotNan::new(1e-15).unwrap() {
441        return false;
442    }
443    if !b.is_zero() && b.is_positive() && b < NotNan::new(1e-15).unwrap() {
444        return false;
445    }
446    if !a.is_zero() && a.is_negative() && a > NotNan::new(-1e-15).unwrap() {
447        return false;
448    }
449    if !b.is_zero() && b.is_negative() && b > NotNan::new(-1e-15).unwrap() {
450        return false;
451    }
452    if a > NotNan::new(1e15).unwrap() || b > NotNan::new(1e15).unwrap() {
453        return false;
454    }
455    if a < NotNan::new(-1e15).unwrap() || b < NotNan::new(-1e15).unwrap() {
456        return false;
457    }
458    a.is_finite() && b.is_finite() && !a.is_subnormal() && !b.is_subnormal()
459}
460
461pub fn can_divide(a: NotNan<f64>, b: NotNan<f64>) -> bool {
462    if !a.is_zero() && a.is_positive() && a < NotNan::new(1e-15).unwrap() {
463        return false;
464    }
465    if !b.is_zero() && b.is_positive() && b < NotNan::new(1e-15).unwrap() {
466        return false;
467    }
468    if !a.is_zero() && a.is_negative() && a > NotNan::new(-1e-15).unwrap() {
469        return false;
470    }
471    if !b.is_zero() && b.is_negative() && b > NotNan::new(-1e-15).unwrap() {
472        return false;
473    }
474    if a > NotNan::new(1e15).unwrap() || b > NotNan::new(1e15).unwrap() {
475        return false;
476    }
477    if a < NotNan::new(-1e15).unwrap() || b < NotNan::new(-1e15).unwrap() {
478        return false;
479    }
480    a.is_finite() && b.is_finite() && !b.is_zero() && !b.is_subnormal() && !a.is_subnormal()
481}
482
/// E-class analysis that folds arithmetic on known constants.
///
/// An e-class's data is `Some((value, pattern))` when the class is known to equal
/// the constant `value`. `pattern` is the s-expression the value was folded from;
/// it is used as the left-hand side when explanations are enabled (see `modify`).
#[derive(Default)]
pub struct ConstantFold;
impl Analysis<TrigLanguage> for ConstantFold {
    type Data = Option<(Constant, PatternAst<TrigLanguage>)>;

    /// Computes the folded constant for a newly added e-node, if any.
    ///
    /// Returns `None` in three cases:
    /// - a child has no known constant (each `x(..)?`, including those inside the
    ///   match guards, exits early);
    /// - an operand fails the `can_multiply` / `can_divide` range guard, in which
    ///   case control falls through to the final `_` arm;
    /// - the operator is not folded (`pi`, `pow`, `sqrt`, `sin`, `cos`, symbols).
    ///
    /// Negation is folded without a range check.
    fn make(egraph: &mut EGraph, enode: &TrigLanguage) -> Self::Data {
        // Folded constant of child class `i`, if it has one.
        let x = |i: &Id| egraph[*i].data.as_ref().map(|d| d.0);
        Some(match enode {
            // TrigLanguage::Pi => (NotNan::new(std::f64::consts::PI).unwrap(), "pi".parse().unwrap()),
            TrigLanguage::Constant(c) => (*c, format!("{}", c).parse().unwrap()),
            // `can_multiply` is also used as the range guard for `+` and `-`.
            TrigLanguage::Add([a, b]) if can_multiply(x(a)?, x(b)?) => (
                x(a)? + x(b)?,
                format!("(+ {} {})", x(a)?, x(b)?).parse().unwrap(),
            ),
            TrigLanguage::Sub([a, b]) if can_multiply(x(a)?, x(b)?) => (
                x(a)? - x(b)?,
                format!("(- {} {})", x(a)?, x(b)?).parse().unwrap(),
            ),
            TrigLanguage::Mul([a, b]) if can_multiply(x(a)?, x(b)?) => (
                x(a)? * x(b)?,
                format!("(* {} {})", x(a)?, x(b)?).parse().unwrap(),
            ),
            TrigLanguage::Div([a, b]) if can_divide(x(a)?, x(b)?) => (
                x(a)? / x(b)?,
                format!("(/ {} {})", x(a)?, x(b)?).parse().unwrap(),
            ),
            TrigLanguage::Neg([a]) => (-x(a)?, format!("(~ {})", x(a)?).parse().unwrap()),
            // TrigLanguage::Pow([a, b]) => (
            //     NotNan::new(x(a)?.powf(x(b)?.into_inner())).unwrap(), // TODO: handle invalids
            //     format!("(pow {} {})", x(a)?, x(b)?).parse().unwrap(),
            // ),
            // TrigLanguage::Sqrt(a) => (
            //     NotNan::new(x(a)?.sqrt()).unwrap(), // TODO: handle invalids
            //     format!("(sqrt {})", x(a)?).parse().unwrap(),
            // ),
            // TrigLanguage::Cbrt(a) => (
            //     NotNan::new(x(a)?.cbrt()).unwrap(), // TODO: handle invalids
            //     format!("(cbrt {})", x(a)?).parse().unwrap(),
            // ),
            // TrigLanguage::Sin(a) => (
            //     NotNan::new(x(a)?.sin()).unwrap(), // TODO: handle invalids
            //     format!("(sin {})", x(a)?).parse().unwrap(),
            // ),
            // TrigLanguage::Cos(a) => (
            //     NotNan::new(x(a)?.cos()).unwrap(), // TODO: handle invalids
            //     format!("(cos {})", x(a)?).parse().unwrap(),
            // ),
            _ => return None,
        })
    }

    // fn pre_union(
    //         egraph: &egg::EGraph<TrigLanguage, Self>,
    //         id1: Id,
    //         id2: Id,
    //         justification: &Option<Justification>,
    //     ) {
    //     println!("pre_union: {:?} {:?}", id1, id2);
    //     println!("justification: {:?}", justification);
    // }

    /// Merges the data of two e-classes being unioned.
    ///
    /// When both sides carry a constant, they must be equal. The assertion panics
    /// otherwise, and the existing data is kept (`DidMerge(false, false)`).
    /// NOTE(review): the equality is exact on `f64`. Folding the same value along
    /// different paths (e.g. after the associativity rewrites) could differ in the
    /// last bits and trip this assert — confirm this is intended.
    fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge {
        merge_option(to, from, |a, b| {
            assert_eq!(a.0, b.0, "Merged non-equal constants");
            DidMerge(false, false)
        })
    }

    /// Materialises a folded constant in its e-class.
    ///
    /// If the class has a folded constant `c`:
    /// 1. Union the class with a literal `Constant(c)` node. When explanations are
    ///    enabled this goes through `union_instantiations`, recording the step as
    ///    "constant_fold".
    /// 2. Prune the class down to its leaf nodes.
    fn modify(egraph: &mut EGraph, id: Id) {
        let data = egraph[id].data.clone();
        if let Some((c, pat)) = data {
            if egraph.are_explanations_enabled() {
                egraph.union_instantiations(
                    &pat,
                    &format!("{}", c).parse().unwrap(),
                    &Default::default(),
                    "constant_fold".to_string(),
                );
            } else {
                let added = egraph.add(TrigLanguage::Constant(c));
                egraph.union(id, added);
            }
            // to not prune, comment this out
            egraph[id].nodes.retain(|n| n.is_leaf());

            #[cfg(debug_assertions)]
            egraph[id].assert_unique_leaves();
        }
    }
}
573
574fn make_rules() -> Vec<Rewrite<TrigLanguage, ConstantFold>> {
575    vec![
576        // Commutativity
577        rewrite!("+-commutative"; "(+ ?a ?b)" => "(+ ?b ?a)"),
578        rewrite!("*-commutative"; "(* ?a ?b)" => "(* ?b ?a)"),
579        // Associativity
580        rewrite!("associate-+r+"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
581        rewrite!("associate-+l+"; "(+ (+ ?a ?b) ?c)" => "(+ ?a (+ ?b ?c))"),
582        rewrite!("associate-+r-"; "(+ ?a (- ?b ?c))" => "(- (+ ?a ?b) ?c)"),
583        rewrite!("associate-+l-"; "(+ (- ?a ?b) ?c)" => "(- ?a (- ?b ?c))"),
584        rewrite!("associate--r+"; "(- ?a (+ ?b ?c))" => "(- (- ?a ?b) ?c)"),
585        rewrite!("associate--l+"; "(- (+ ?a ?b) ?c)" => "(+ ?a (- ?b ?c))"),
586        rewrite!("associate--l-"; "(- (- ?a ?b) ?c)" => "(- ?a (+ ?b ?c))"),
587        rewrite!("associate--r-"; "(- ?a (- ?b ?c))" => "(+ (- ?a ?b) ?c)"),
588        rewrite!("associate-*r*"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),
589        rewrite!("associate-*l*"; "(* (* ?a ?b) ?c)" => "(* ?a (* ?b ?c))"),
590        rewrite!("associate-*r/"; "(* ?a (/ ?b ?c))" => "(/ (* ?a ?b) ?c)"),
591        rewrite!("associate-*l/"; "(* (/ ?a ?b) ?c)" => "(/ (* ?a ?c) ?b)"),
592        rewrite!("associate-/r*"; "(/ ?a (* ?b ?c))" => "(/ (/ ?a ?b) ?c)"),
593        rewrite!("associate-/r/"; "(/ ?a (/ ?b ?c))" => "(* (/ ?a ?b) ?c)"),
594        rewrite!("associate-/l/"; "(/ (/ ?b ?c) ?a)" => "(/ ?b (* ?a ?c))"),
595        rewrite!("associate-/l*"; "(/ (* ?b ?c) ?a)" => "(* ?b (/ ?c ?a))"),
596        // Counting
597        rewrite!("count-2"; "(+ ?x ?x)" => "(* 2 ?x)"),
598        // Distributivity
599        rewrite!("distribute-lft-in"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
600        rewrite!("distribute-rgt-in"; "(* ?a (+ ?b ?c))" => "(+ (* ?b ?a) (* ?c ?a))"),
601        rewrite!("distribute-lft-out"; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"),
602        rewrite!("distribute-lft-out--"; "(- (* ?a ?b) (* ?a ?c))" => "(* ?a (- ?b ?c))"),
603        rewrite!("distribute-rgt-out"; "(+ (* ?b ?a) (* ?c ?a))" => "(* ?a (+ ?b ?c))"),
604        rewrite!("distribute-rgt-out--"; "(- (* ?b ?a) (* ?c ?a))" => "(* ?a (- ?b ?c))"),
605        rewrite!("distribute-lft1-in"; "(+ (* ?b ?a) ?a)" => "(* (+ ?b 1) ?a)"),
606        rewrite!("distribute-rgt1-in"; "(+ ?a (* ?c ?a))" => "(* (+ ?c 1) ?a)"),
607        // Distributivity Fp Safe
608        rewrite!("distribute-lft-neg-in"; "(~ (* ?a ?b))" => "(* (~ ?a) ?b)"),
609        rewrite!("distribute-rgt-neg-in"; "(~ (* ?a ?b))" => "(* ?a (~ ?b))"),
610        rewrite!("distribute-lft-neg-out"; "(* (~ ?a) ?b)" => "(~ (* ?a ?b))"),
611        rewrite!("distribute-rgt-neg-out"; "(* ?a (~ ?b))" => "(~ (* ?a ?b))"),
612        rewrite!("distribute-neg-in"; "(~ (+ ?a ?b))" => "(+ (~ ?a) (~ ?b))"),
613        rewrite!("distribute-neg-out"; "(+ (~ ?a) (~ ?b))" => "(~ (+ ?a ?b))"),
614        rewrite!("distribute-frac-neg"; "(/ (~ ?a) ?b)" => "(~ (/ ?a ?b))"),
615        rewrite!("distribute-frac-neg2"; "(/ ?a (~ ?b))" => "(~ (/ ?a ?b))"),
616        rewrite!("distribute-neg-frac"; "(~ (/ ?a ?b))" => "(/ (~ ?a) ?b)"),
617        rewrite!("distribute-neg-frac2"; "(~ (/ ?a ?b))" => "(/ ?a (~ ?b))"),
618        // Cancel Sign Fp Safe
619        rewrite!("cancel-sign-sub"; "(- ?a (* (~ ?b) ?c))" => "(+ ?a (* ?b ?c))"),
620        rewrite!("cancel-sign-sub-inv"; "(- ?a (* ?b ?c))" => "(+ ?a (* (~ ?b) ?c))"),
621        // Difference Of Squares Canonicalize
622        rewrite!("swap-sqr"; "(* (* ?a ?b) (* ?a ?b))" => "(* (* ?a ?a) (* ?b ?b))"),
623        rewrite!("unswap-sqr"; "(* (* ?a ?a) (* ?b ?b))" => "(* (* ?a ?b) (* ?a ?b))"),
624        rewrite!("difference-of-squares"; "(- (* ?a ?a) (* ?b ?b))" => "(* (+ ?a ?b) (- ?a ?b))"),
625        rewrite!("difference-of-sqr-1"; "(- (* ?a ?a) 1)" => "(* (+ ?a 1) (- ?a 1))"),
626        rewrite!("difference-of-sqr--1"; "(+ (* ?a ?a) -1)" => "(* (+ ?a 1) (- ?a 1))"),
627        rewrite!("pow-sqr"; "(* (pow ?a ?b) (pow ?a ?b))" => "(pow ?a (* 2 ?b))"),
628        // // Sqr Pow Expand
629        // rewrite!("sqr-pow"; "(pow ?a ?b)" => "(* (pow ?a (/ ?b 2)) (pow ?a (/ ?b 2)))"),
630        // TODO: add conditional that a > 0
631
632        // // Difference Of Squares Flip
633        // rewrite!("flip-+"; "(+ ?a ?b)" => "(/ (- (* ?a ?a) (* ?b ?b)) (- ?a ?b))"),
634        // rewrite!("flip--"; "(- ?a ?b)" => "(/ (- (* ?a ?a) (* ?b ?b)) (+ ?a ?b))"),
635        // // TODO: causes issues since dem must not be zero, not so easy to check
636
637        // Id Reduce
638        rewrite!("remove-double-div"; "(/ 1 (/ 1 ?a))" => "?a"),
639        rewrite!("rgt-mult-inverse"; "(* ?a (/ 1 ?a))" => "1" if is_not_zero("?a")),
640        rewrite!("lft-mult-inverse"; "(* (/ 1 ?a) ?a)" => "1" if is_not_zero("?a")),
641        // TODO: are the checks necessary?
642
643        // Id Reduce Fp Safe Nan
644        rewrite!("+-inverses"; "(- ?a ?a)" => "0"),
645        rewrite!("div0"; "(/ 0 ?a)" => "0" if is_not_zero("?a")),
646        rewrite!("mul0-lft"; "(* 0 ?a)" => "0"),
647        rewrite!("mul0-rgt"; "(* ?a 0)" => "0"),
648        rewrite!("*-inverses"; "(/ ?a ?a)" => "1" if is_not_zero("?a")),
649        // Id Reduce Fp Safe
650        rewrite!("+-lft-identity"; "(+ 0 ?a)" => "?a"),
651        rewrite!("+-rgt-identity"; "(+ ?a 0)" => "?a"),
652        rewrite!("--rgt-identity"; "(- ?a 0)" => "?a"),
653        rewrite!("sub0-neg"; "(- 0 ?a)" => "(~ ?a)"),
654        rewrite!("remove-double-neg"; "(~ (~ ?a))" => "?a"),
655        rewrite!("*-lft-identity"; "(* 1 ?a)" => "?a"),
656        rewrite!("*-rgt-identity"; "(* ?a 1)" => "?a"),
657        rewrite!("/-rgt-identity"; "(/ ?a 1)" => "?a"),
658        rewrite!("mul-1-neg"; "(* -1 ?a)" => "(~ ?a)"),
659        // Nan Transform Fp Safe
660        rewrite!("sub-neg"; "(- ?a ?b)" => "(+ ?a (~ ?b))"),
661        rewrite!("unsub-neg"; "(+ ?a (~ ?b))" => "(- ?a ?b)"),
662        rewrite!("neg-sub0"; "(~ ?b)" => "(- 0 ?b)"),
663        rewrite!("neg-mul-1"; "(~ ?a)" => "(* -1 ?a)"),
664        // Id Transform Safe
665        rewrite!("div-inv"; "(/ ?a ?b)" => "(* ?a (/ 1 ?b))"),
666        rewrite!("un-div-inv"; "(* ?a (/ 1 ?b))" => "(/ ?a ?b)"),
667        // // Id Transform Clear Num
668        // rewrite!("clear-num"; "(/ ?a ?b)" => "(/ 1 (/ ?b ?a))" if is_not_zero("?a")),
669        // // TODO: Causes issues with trig functions; probably a confounding issue, need further
670        // // investigation
671
672        // Id Transform Fp Safe
673        rewrite!("*-un-lft-identity"; "?a" => "(* 1 ?a)"),
674        // Difference Of Cubes
675        rewrite!("sum-cubes"; "(+ (pow ?a 3) (pow ?b 3))" => "(* (+ (* ?a ?a) (- (* ?b ?b) (* ?a ?b))) (+ ?a ?b))"),
676        rewrite!("difference-cubes"; "(- (pow ?a 3) (pow ?b 3))" => "(* (+ (* ?a ?a) (+ (* ?b ?b) (* ?a ?b))) (- ?a ?b))"),
677        // rewrite!("flip3-+"; "(+ ?a ?b)" => "(/ (+ (pow ?a 3) (pow ?b 3)) (+ (* ?a ?a) (- (* ?b ?b) (* ?a ?b))))"),
678        // rewrite!("flip3--"; "(- ?a ?b)" => "(/ (- (pow ?a 3) (pow ?b 3)) (+ (* ?a ?a) (+ (* ?b ?b) (* ?a ?b))))"),
679
680        // Fractions Distribute
681        rewrite!("div-sub"; "(/ (- ?a ?b) ?c)" => "(- (/ ?a ?c) (/ ?b ?c))"),
682        rewrite!("times-frac"; "(/ (* ?a ?b) (* ?c ?d))" => "(* (/ ?a ?c) (/ ?b ?d))"),
683        // Fractions Transform
684        rewrite!("sub-div"; "(- (/ ?a ?c) (/ ?b ?c))" => "(/ (- ?a ?b) ?c)"),
685        rewrite!("frac-add"; "(+ (/ ?a ?b) (/ ?c ?d))" => "(/ (+ (* ?a ?d) (* ?b ?c)) (* ?b ?d))"),
686        rewrite!("frac-sub"; "(- (/ ?a ?b) (/ ?c ?d))" => "(/ (- (* ?a ?d) (* ?b ?c)) (* ?b ?d))"),
687        rewrite!("frac-times"; "(* (/ ?a ?b) (/ ?c ?d))" => "(/ (* ?a ?c) (* ?b ?d))"),
688        rewrite!("frac-2neg"; "(/ ?a ?b)" => "(/ (~ ?a) (~ ?b))"),
689        // Squares Reduce
690        rewrite!("rem-square-sqrt"; "(* (sqrt ?x) (sqrt ?x))" => "?x"),
691        // Squares Reduce Fp Sound
692        rewrite!("sqr-neg"; "(* (~ ?x) (~ ?x))" => "(* ?x ?x)"),
693        // Squares Transform Sound
694        rewrite!("sqrt-pow2"; "(pow (sqrt ?x) ?y)" => "(pow ?x (/ ?y 2))"),
695        rewrite!("sqrt-unprod"; "(* (sqrt ?x) (sqrt ?y))" => "(sqrt (* ?x ?y))"),
696        rewrite!("sqrt-undiv"; "(/ (sqrt ?x) (sqrt ?y))" => "(sqrt (/ ?x ?y))"),
697        // Sqrt Canonicalize
698        rewrite!("sqrt-1"; "(sqrt 1)" => "1"),
699        rewrite!("sqrt-0"; "(sqrt 0)" => "0"),
700        rewrite!("sqrt-can"; "(/ (sqrt ?x) ?x)" => "(/ 1 (sqrt ?x))"),
701        rewrite!("sqrt-can-inv"; "(/ ?x (sqrt ?x))" => "(sqrt ?x)"),
702        rewrite!("sqrt-can-rev"; "(/ 1 (sqrt ?x))" => "(/ (sqrt ?x) ?x)"),
703        // // Squares Transform
704        // rewrite!("sqrt-pow1"; "(sqrt (pow ?x ?y))" => "(pow ?x (/ ?y 2))"),
705        // rewrite!("sqrt-prod"; "(sqrt (* ?x ?y))" => "(* (sqrt ?x) (sqrt ?y))"),
706        // rewrite!("sqrt-div"; "(sqrt (/ ?x ?y))" => "(/ (sqrt ?x) (sqrt ?y))"),
707        rewrite!("add-sqr-sqrt"; "?x" => "(* (sqrt ?x) (sqrt ?x))" if is_non_negative_conservative("?x")),
708        // TODO: determine if necessary, if so, then determine conditionals
709        // conditions x and y are negative, which causes issues splitting them over real
710
711        // Cubes Distribute
712        rewrite!("cube-prod"; "(pow (* ?x ?y) 3)" => "(* (pow ?x 3) (pow ?y 3))"),
713        rewrite!("cube-div"; "(pow (/ ?x ?y) 3)" => "(/ (pow ?x 3) (pow ?y 3))"),
714        rewrite!("cube-mult"; "(pow ?x 3)" => "(* ?x (* ?x ?x))"),
715        // Cubes Canonicalize
716        rewrite!("cube-unmult"; "(* ?x (* ?x ?x))" => "(pow ?x 3)"),
717        // Pow Reduce
718        rewrite!("unpow-1"; "(pow ?a -1)" => "(/ 1 ?a)"),
719        // Pow Reduce Fp Safe
720        rewrite!("unpow1"; "(pow ?a 1)" => "?a"),
721        // Pow Reduce Fp Safe Nan
722        rewrite!("unpow0"; "(pow ?a 0)" => "1" if is_not_zero("?a")),
723        rewrite!("pow-base-1"; "(pow 1 ?a)" => "1"),
724        // Pow Expand Fp Safe
725        rewrite!("pow1"; "?a" => "(pow ?a 1)"),
726        // Pow Canonicalize
727        rewrite!("unpow1/2"; "(pow ?a 0.5)" => "(sqrt ?a)"),
728        rewrite!("unpow2"; "(pow ?a 2)" => "(* ?a ?a)"),
729        rewrite!("unpow3"; "(pow ?a 3)" => "(* (* ?a ?a) ?a)"),
730        rewrite!("pow-plus"; "(* (pow ?a ?b) ?a)" => "(pow ?a (+ ?b 1))"),
731        // Pow Transform Sound
732        rewrite!("pow-prod-down"; "(* (pow ?b ?a) (pow ?c ?a))" => "(pow (* ?b ?c) ?a)"),
733        rewrite!("pow-prod-up"; "(* (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (+ ?b ?c))"),
734        rewrite!("pow-flip"; "(/ 1 (pow ?a ?b))" => "(pow ?a (~ ?b))"),
735        rewrite!("pow-neg"; "(pow ?a (~ ?b))" => "(/ 1 (pow ?a ?b))" if is_not_zero("?a")),
736        rewrite!("pow-div"; "(/ (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (- ?b ?c))"),
737        // Pow Specialize Sound
738        rewrite!("pow1/2"; "(sqrt ?a)" => "(pow ?a 0.5)"),
739        rewrite!("pow2"; "(* ?a ?a)" => "(pow ?a 2)"),
740        rewrite!("pow3"; "(* (* ?a ?a) ?a)" => "(pow ?a 3)"),
741        // // Pow Transform
742        // rewrite!("pow-sub"; "(pow ?a (- ?b ?c))" => "(/ (pow ?a ?b) (pow ?a ?c))"),
743        // rewrite!("pow-pow"; "(pow (pow ?a ?b) ?c)" => "(pow ?a (* ?b ?c))"),
744        // rewrite!("pow-unpow"; "(pow ?a (* ?b ?c))" => "(pow (pow ?a ?b) ?c)"),
745        // rewrite!("unpow-prod-up"; "(pow ?a (+ ?b ?c))" => "(* (pow ?a ?b) (pow ?a ?c))"),
746        // rewrite!("unpow-prod-down"; "(pow (* ?b ?c) ?a)" => "(* (pow ?b ?a) (pow ?c ?a))"),
747        // TODO: determine if necessary, if so, then determine conditionals
748
749        // Pow Transform Fp Safe Nan
750        rewrite!("pow-base-0"; "(pow 0 ?a)" => "0" if is_not_zero("?a")),
751        // Pow Transform Fp Safe
752        rewrite!("inv-pow"; "(/ 1 ?a)" => "(pow ?a -1)"),
753        // Trig Reduce Fp Sound
754        rewrite!("sin-0"; "(sin 0)" => "0"),
755        rewrite!("cos-0"; "(cos 0)" => "1"),
756        // Trig Reduce Fp Sound Nan
757        rewrite!("sin-neg"; "(sin (~ ?x))" => "(~ (sin ?x))"),
758        rewrite!("neg-sig"; "(~ (sin ?x))" => "(sin (~ ?x))"),
759        rewrite!("cos-neg"; "(cos (~ ?x))" => "(cos ?x)"),
760        rewrite!("neg-cos"; "(cos ?x)" => "(cos (~ ?x))"),
761        // Trig Expand Fp Safe
762        rewrite!("sqr-sin-b"; "(* (sin ?x) (sin ?x))" => "(- 1 (* (cos ?x) (cos ?x)))"),
763        rewrite!("sqr-cos-b"; "(* (cos ?x) (cos ?x))" => "(- 1 (* (sin ?x) (sin ?x)))"),
764        // Trig Reduce Sound
765        rewrite!("cos-sin-sum"; "(+ (* (cos ?a) (cos ?a)) (* (sin ?a) (sin ?a)))" => "1"),
766        rewrite!("1-sub-cos"; "(- 1 (* (cos ?a) (cos ?a)))" => "(* (sin ?a) (sin ?a))"),
767        rewrite!("1-sub-sin"; "(- 1 (* (sin ?a) (sin ?a)))" => "(* (cos ?a) (cos ?a))"),
768        rewrite!("-1-add-cos"; "(+ (* (cos ?a) (cos ?a)) -1)" => "(~ (* (sin ?a) (sin ?a)))"),
769        rewrite!("-1-add-sin"; "(+ (* (sin ?a) (sin ?a)) -1)" => "(~ (* (cos ?a) (cos ?a)))"),
770        rewrite!("sub-1-cos"; "(- (* (cos ?a) (cos ?a)) 1)" => "(~ (* (sin ?a) (sin ?a)))"),
771        rewrite!("sub-1-sin"; "(- (* (sin ?a) (sin ?a)) 1)" => "(~ (* (cos ?a) (cos ?a)))"),
772        rewrite!("sin-PI/6"; "(sin (/ (pi) 6))" => "0.5"),
773        rewrite!("sin-PI/4"; "(sin (/ (pi) 4))" => "(/ (sqrt 2) 2)"),
774        rewrite!("sin-PI*0.25"; "(sin (* (pi) 0.25))" => "(/ (sqrt 2) 2)"),
775        rewrite!("sin-PI*-0.25"; "(sin (* (pi) -0.25))" => "(~ (/ (sqrt 2) 2))"),
776        rewrite!("sin-PI/3"; "(sin (/ (pi) 3))" => "(/ (sqrt 3) 2)"),
777        rewrite!("sin-PI/2"; "(sin (/ (pi) 2))" => "1"),
778        rewrite!("sin-PI*0.5"; "(sin (* (pi) 0.5))" => "1"),
779        rewrite!("sin-PI"; "(sin (pi))" => "0"),
780        rewrite!("sin-+PI"; "(sin (+ ?x (pi)))" => "(~ (sin ?x))"),
781        rewrite!("sin-+PI/2"; "(sin (+ ?x (/ (pi) 2)))" => "(cos ?x)"),
782        rewrite!("cos-PI/6"; "(cos (/ (pi) 6))" => "(/ (sqrt 3) 2)"),
783        rewrite!("cos-PI/4"; "(cos (/ (pi) 4))" => "(/ (sqrt 2) 2)"),
784        rewrite!("cos-PI*0.25"; "(cos (* (pi) 0.25))" => "(/ (sqrt 2) 2)"),
785        rewrite!("cos-PI/3"; "(cos (/ (pi) 3))" => "0.5"),
786        rewrite!("cos-PI/2"; "(cos (/ (pi) 2))" => "0"),
787        rewrite!("cos-PI*0.5"; "(cos (* (pi) 0.5))" => "0"),
788        rewrite!("cos-PI"; "(cos (pi))" => "-1"),
789        rewrite!("cos-+PI"; "(cos (+ ?x (pi)))" => "(~ (cos ?x))"),
790        rewrite!("cos-+PI/2"; "(cos (+ ?x (* (pi) 0.5)))" => "(~ (sin ?x))"),
791        rewrite!("hang-0p-tan"; "(/ (sin ?a) (+ 1 (cos ?a)))" => "(/ (sin (/ ?a 2)) (cos (/ ?a 2)))"),
792        rewrite!("hang-0m-tan"; "(/ (~ (sin ?a)) (+ 1 (cos ?a)))" => "(/ (sin (/ (~ ?a) 2)) (cos (/ (~ ?a) 2)))"),
793        rewrite!("hang-p0-tan"; "(/ (- 1 (cos ?a)) (sin ?a))" => "(/ (sin (/ ?a 2)) (cos (/ ?a 2)))"),
794        rewrite!("hang-m0-tan"; "(/ (- 1 (cos ?a)) (~ (sin ?a)))" => "(/ (sin (/ (~ ?a) 2)) (cos (/ (~ ?a) 2)))"),
795        rewrite!("tan-hang-0p"; "(/ (sin (* ?a 0.5)) (cos (* ?a 0.5)))" => "(/ (sin ?a) (+ 1 (cos ?a)))"),
796        rewrite!("tan-hang-0m"; "(/ (sin (* (~ ?a) 0.5)) (cos (* (~ ?a) 0.5)))" => "(/ (~ (sin ?a)) (+ 1 (cos ?a)))"),
797        rewrite!("tan-hang-p0"; "(/ (sin (* ?a 0.5)) (cos (* ?a 0.5)))" => "(/ (- 1 (cos ?a)) (sin ?a))"),
798        rewrite!("tan-hang-m0"; "(/ (sin (* (~ ?a) 0.5)) (cos (* (~ ?a) 0.5)))" => "(/ (- 1 (cos ?a)) (~ (sin ?a)))" if is_not_zero("?a")),
799        // Trig Expand Sound
800        rewrite!("csc-cot"; "(/ 1 (* (sin ?a) (sin ?a)))" => "(+ 1 (/ (* (cos ?a) (cos ?a)) (* (sin ?a) (sin ?a))))"),
801        rewrite!("sec-tan"; "(/ 1 (* (cos ?a) (cos ?a)))" => "(+ 1 (/ (* (sin ?a) (sin ?a)) (* (cos ?a) (cos ?a))))"),
802        rewrite!("csc-sec"; "(* (/ 1 (* (cos ?a) (cos ?a))) (/ 1 (* (sin ?a) (sin ?a))))" => "(+ (/ 1 (* (cos ?a) (cos ?a))) (/ 1 (* (sin ?a) (sin ?a))))"),
803        rewrite!("sin-sum"; "(sin (+ ?x ?y))" => "(+ (* (sin ?x) (cos ?y)) (* (cos ?x) (sin ?y)))"),
804        rewrite!("cos-sum"; "(cos (+ ?x ?y))" => "(- (* (cos ?x) (cos ?y)) (* (sin ?x) (sin ?y)))"),
805        // rewrite!("tan-sum"; "(/ (sin (+ ?a ?b)) (cos (+ ?a ?b)))" => "(/ (+ (/ (sin ?a) (cos ?a)) (/ (sin ?b) (cos ?b))) (- 1 (* (/ (sin ?a) (cos ?a)) (/ (sin ?b) (cos ?b)))))"),
806        // rewrite!("cot-sum"; "(/ (cos (+ ?a ?b)) (sin (+ ?a ?b)))" => "(/ (- (* (/ (cos ?a) (sin ?a)) (/ (cos ?b) (sin ?b))) 1) (+ (/ (cos ?b) (sin ?b)) (/ (cos ?a) (sin ?a))))"),
807        rewrite!("sin-diff"; "(sin (- ?x ?y))" => "(- (* (sin ?x) (cos ?y)) (* (cos ?x) (sin ?y)))"),
808        rewrite!("cos-diff"; "(cos (- ?x ?y))" => "(+ (* (cos ?x) (cos ?y)) (* (sin ?x) (sin ?y)))"),
809        rewrite!("sin-2"; "(sin (* 2 ?x))" => "(* 2 (* (sin ?x) (cos ?x)))"),
810        rewrite!("sin-3"; "(sin (* 3 ?x))" => "(- (* 3 (sin ?x)) (* 4 (pow (sin ?x) 3)))"),
811        rewrite!("2-sin"; "(* 2 (* (sin ?x) (cos ?x)))" => "(sin (* 2 ?x))"),
812        rewrite!("3-sin"; "(- (* 3 (sin ?x)) (* 4 (pow (sin ?x) 3)))" => "(sin (* 3 ?x))"),
813        rewrite!("cos-2"; "(cos (* 2 ?x))" => "(- (* (cos ?x) (cos ?x)) (* (sin ?x) (sin ?x)))"),
814        rewrite!("cos-3"; "(cos (* 3 ?x))" => "(- (* 4 (pow (cos ?x) 3)) (* 3 (cos ?x)))"),
815        rewrite!("2-cos"; "(- (* (cos ?x) (cos ?x)) (* (sin ?x) (sin ?x)))" => "(cos (* 2 ?x))"),
816        rewrite!("3-cos"; "(- (* 4 (pow (cos ?x) 3)) (* 3 (cos ?x)))" => "(cos (* 3 ?x))"),
817        // Trig Expand Sound2
818        rewrite!("sqr-sin-a"; "(* (sin ?x) (sin ?x))" => "(- 0.5 (* 0.5 (cos (* 2 ?x))))"),
819        rewrite!("sqr-cos-a"; "(* (cos ?x) (cos ?x))" => "(+ 0.5 (* 0.5 (cos (* 2 ?x))))"),
820        rewrite!("diff-sin"; "(- (sin ?x) (sin ?y))" => "(* 2 (* (sin (/ (- ?x ?y) 2)) (cos (/ (+ ?x ?y) 2))))"),
821        rewrite!("diff-cos"; "(- (cos ?x) (cos ?y))" => "(* -2 (* (sin (/ (- ?x ?y) 2)) (sin (/ (+ ?x ?y) 2))))"),
822        rewrite!("sum-sin"; "(+ (sin ?x) (sin ?y))" => "(* 2 (* (sin (/ (+ ?x ?y) 2)) (cos (/ (- ?x ?y) 2))))"),
823        rewrite!("sum-cos"; "(+ (cos ?x) (cos ?y))" => "(* 2 (* (cos (/ (+ ?x ?y) 2)) (cos (/ (- ?x ?y) 2))))"),
824        rewrite!("cos-mult"; "(* (cos ?x) (cos ?y))" => "(/ (+ (cos (+ ?x ?y)) (cos (- ?x ?y))) 2)"),
825        rewrite!("sin-mult"; "(* (sin ?x) (sin ?y))" => "(/ (- (cos (- ?x ?y)) (cos (+ ?x ?y))) 2)"),
826        rewrite!("sin-cos-mult"; "(* (sin ?x) (cos ?y))" => "(/ (+ (sin (- ?x ?y)) (sin (+ ?x ?y))) 2)"),
827        rewrite!("tan-2"; "(/ (sin (* 2 ?x)) (cos (* 2 ?x)))" => "(/ (* 2 (/ (sin ?x) (cos ?x))) (- 1 (* (/ (sin ?x) (cos ?x)) (/ (sin ?x) (cos ?x)))))"),
828        rewrite!("2-tan"; "(/ (* 2 (/ (sin ?x) (cos ?x))) (- 1 (* (/ (sin ?x) (cos ?x)) (/ (sin ?x) (cos ?x)))))" => "(/ (sin (* 2 ?x)) (cos (* 2 ?x)))"),
829    ]
830}
831
832fn to_egg_expr(expr: &Expression) -> RecExpr<TrigLanguage> {
833    expr.to_string().parse().unwrap()
834}
835
836use crate::qgl::lexer::Lexer;
837use crate::qgl::lexer::Token;
838
/// Recursively parses the token stream of one egg s-expression into an
/// [`Expression`].
///
/// Token shapes this parser is written against (atoms, function names and
/// `pi` arrive as `Ident`; arithmetic operators arrive as `Op`):
///
/// ```text
/// pi               -> Ident("pi")
/// 5                -> Number("5")
/// (pi)             -> LParen, Ident("pi"), RParen
/// (+ 1 2)          -> LParen, Op('+'), Number("1"), Number("2"), RParen
/// (sin (pi))       -> LParen, Ident("sin"), LParen, Ident("pi"), RParen, RParen
/// (+ (- 2 (pi)) 3) -> LParen, Op('+'), LParen, Op('-'), Number("2"), LParen,
///                     Ident("pi"), RParen, RParen, Number("3"), RParen
/// (cos (* 0.5 f1)) -> LParen, Ident("cos"), LParen, Op('*'), Number("0.5"),
///                     Ident("f1"), RParen, RParen
/// ```
///
/// The `~` operator is expected as `Token::Negation`.
///
/// The top-level call receives either a single atom or a whole form including
/// its outer `LParen`/`RParen`. Recursive calls for parenthesised operands are
/// given the slice starting *after* the operand's `LParen` and ending at (and
/// including) its matching `RParen`. In those calls `tokens[0]` is the operator,
/// and the trailing `RParen` is handled by the `RParen` branch of the operand
/// loop.
///
/// An identifier other than `pi`, `sin`, `cos`, `sqrt` or `pow` in operator
/// position is returned as a variable, and any operands after it are ignored.
///
/// # Panics
///
/// Panics on an empty token list, an unknown operator or token, a form with
/// too few operands (out-of-bounds index), a number that does not parse as
/// `f64`, or a leading `LParen` without a trailing `RParen`.
fn _from_egg_expr(tokens: Vec<Token>) -> Expression {
    // `start` is the index of the operator (or lone atom); operands follow it.
    let (start, op_token) = if tokens[0] == Token::LParen {
        assert!(tokens.last() == Some(&Token::RParen));
        (1, tokens[1].clone())
    } else {
        (0, tokens[0].clone())
    };

    // Atom: numeric literal.
    if let Token::Number(num) = op_token {
        return Expression::from_float(num.parse::<f64>().unwrap());
    }

    // Atom: `pi`, or any identifier that is not one of the known function names.
    if let Token::Ident(ref id) = op_token {
        if id == "pi" {
            return Expression::Pi;
        }
        if id != "sin" && id != "cos" && id != "sqrt" && id != "pow" {
            return Expression::Variable(id.to_string());
        }
    }

    // Collect operands. A parenthesised operand is parsed recursively from the
    // token after its `LParen` through its matching `RParen`; any other token
    // is parsed as a single-token atom.
    let mut operands = vec![];
    let mut i = start + 1;
    while i < tokens.len() {
        let token = &tokens[i];
        if *token == Token::LParen {
            let mut num_open_parenthesis = 1;
            let start = i + 1;
            for (j, token) in tokens.iter().enumerate().skip(i + 1) {
                if *token == Token::LParen {
                    num_open_parenthesis += 1;
                } else if *token == Token::RParen {
                    num_open_parenthesis -= 1;
                }

                if num_open_parenthesis == 0 {
                    operands.push(_from_egg_expr(tokens[start..j + 1].to_vec()));
                    // Jump to the matching `RParen`; the `i += 1` below moves past it.
                    i = j;
                    break;
                }
            }
        } else if *token == Token::RParen {
            // Nested parens are consumed above, so an `RParen` seen here must be
            // this form's own closing paren and therefore the last token.
            assert_eq!(i, tokens.len() - 1);
        } else {
            operands.push(_from_egg_expr(tokens[i..i + 1].to_vec()));
        }
        i += 1;
    }

    // Build the node. Operand counts are not validated; a form with too few
    // operands panics on indexing.
    match op_token {
        Token::Ident(id) => match id.clone().as_str() {
            "sin" => Expression::Sin(Box::new(operands[0].clone())),
            "cos" => Expression::Cos(Box::new(operands[0].clone())),
            "sqrt" => Expression::Sqrt(Box::new(operands[0].clone())),
            "pow" => Expression::Pow(Box::new(operands[0].clone()), Box::new(operands[1].clone())),
            _ => panic!("Invalid operator during parsing of egg expression"),
        },
        Token::Negation => Expression::Neg(Box::new(operands[0].clone())),
        Token::Op(op) => match op {
            '+' => Expression::Add(Box::new(operands[0].clone()), Box::new(operands[1].clone())),
            '-' => Expression::Sub(Box::new(operands[0].clone()), Box::new(operands[1].clone())),
            '*' => Expression::Mul(Box::new(operands[0].clone()), Box::new(operands[1].clone())),
            '/' => Expression::Div(Box::new(operands[0].clone()), Box::new(operands[1].clone())),
            _ => panic!("Invalid operator during parsing of egg expression"),
        },
        _ => panic!("Invalid token during parsing of egg expression"),
    }
}
914
915fn from_egg_expr(expr: RecExpr<TrigLanguage>) -> Expression {
916    let expr_str = expr.to_string();
917    let expr_tokens = Lexer::new(&expr_str).collect::<Vec<_>>();
918    if expr_tokens.is_empty() {
919        panic!("Failure to lex expression: {}", expr_str);
920    }
921
922    // any Op('-') that is next to a Number is a negation and should be grouped
923    let mut grouped_tokens = vec![];
924    let mut i = 0;
925    while i < expr_tokens.len() {
926        if expr_tokens[i] == Token::Op('-') && i < expr_tokens.len() - 1 {
927            if let Token::Number(n) = &expr_tokens[i + 1] {
928                let n = n.parse::<f64>().unwrap();
929                grouped_tokens.push(Token::Number((-n).to_string()));
930                i += 1
931            } else {
932                grouped_tokens.push(expr_tokens[i].clone());
933            }
934        } else {
935            grouped_tokens.push(expr_tokens[i].clone());
936        }
937        i += 1;
938    }
939
940    _from_egg_expr(grouped_tokens)
941}
942
943/// parse an expression, simplify it using egg, and pretty print it back out
944pub fn simplify(expr: &Expression) -> Expression {
945    // parse the expression, the type annotation tells it which Language to use
946    let expr: RecExpr<TrigLanguage> = to_egg_expr(expr);
947
948    // simplify the expression using a Runner, which creates an e-graph with
949    // the given expression and runs the given rules over it
950    let runner = Runner::default().with_expr(&expr).run(&make_rules());
951    let mut extractor = TrigExprExtractor::new(&runner.egraph);
952
953    // the Runner knows which e-class the expression given with `with_expr` is in
954    let root = runner.roots[0];
955
956    // use an Extractor to pick the best element of the root eclass
957    let best = extractor.extract_best(root);
958    from_egg_expr(best)
959}
960
961#[allow(dead_code)]
962pub fn extract_best_sine(expr: Expression) -> Option<Expression> {
963    let expr: RecExpr<TrigLanguage> = to_egg_expr(&expr);
964    let runner = Runner::default()
965        .with_expr(&expr)
966        .with_iter_limit(1000)
967        .with_node_limit(10000000)
968        .run(&make_rules());
969    let egraph = &runner.egraph;
970    let extractor = SineExtractor::new(egraph);
971    let root = runner.roots[0];
972    let best = extractor.extract_sine(root);
973    best.map(from_egg_expr)
974}
975
976#[allow(dead_code)]
977pub fn simplify_complex(expr: ComplexExpression) -> ComplexExpression {
978    let ComplexExpression { real, imag } = expr;
979
980    let real_expr: RecExpr<TrigLanguage> = to_egg_expr(&real);
981    let imag_expr: RecExpr<TrigLanguage> = to_egg_expr(&imag);
982
983    let runner = Runner::default()
984        .with_expr(&real_expr)
985        .with_expr(&imag_expr)
986        .run(&make_rules());
987    let mut extractor = TrigExprExtractor::new(&runner.egraph);
988
989    // the Runner knows which e-class the expression given with `with_expr` is in
990    let real_root = runner.roots[0];
991    let imag_root = runner.roots[1];
992
993    // use an Extractor to pick the best element of the root eclass
994    let best_real = extractor.extract_best(real_root);
995    let real_simple = from_egg_expr(best_real);
996
997    let best_imag = extractor.extract_best(imag_root);
998    let imag_simple = from_egg_expr(best_imag);
999
1000    ComplexExpression {
1001        real: real_simple,
1002        imag: imag_simple,
1003    }
1004}
1005
1006#[allow(dead_code)]
1007pub fn simplify_matrix_no_context(
1008    matrix_expression: &Vec<Vec<ComplexExpression>>,
1009) -> Vec<Vec<ComplexExpression>> {
1010    let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default();
1011
1012    for row in matrix_expression {
1013        for expr in row {
1014            let ComplexExpression { real, imag } = expr;
1015            let real_expr: RecExpr<TrigLanguage> = to_egg_expr(real);
1016            let imag_expr: RecExpr<TrigLanguage> = to_egg_expr(imag);
1017            runner = runner.with_expr(&real_expr).with_expr(&imag_expr);
1018        }
1019    }
1020
1021    runner = runner.run(&make_rules());
1022    let extractor = Extractor::new(&runner.egraph, TrigCostFn);
1023
1024    let mut simplified_matrix = vec![vec![]; matrix_expression.len()];
1025    let nrows = matrix_expression.len();
1026    let ncols = matrix_expression[0].len();
1027
1028    for i in 0..nrows {
1029        for j in 0..ncols {
1030            let real_root = runner.roots[2 * (i * ncols + j)];
1031            let imag_root = runner.roots[2 * (i * ncols + j) + 1];
1032
1033            let (_, best_real) = extractor.find_best(real_root);
1034            let real_simple = from_egg_expr(best_real);
1035            let (_, best_imag) = extractor.find_best(imag_root);
1036            let imag_simple = from_egg_expr(best_imag);
1037
1038            simplified_matrix[i].push(ComplexExpression {
1039                real: real_simple,
1040                imag: imag_simple,
1041            });
1042        }
1043    }
1044
1045    simplified_matrix
1046}
1047
1048#[allow(dead_code)]
1049pub fn simplify_matrix(
1050    matrix_expression: &Vec<Vec<ComplexExpression>>,
1051) -> Vec<Vec<ComplexExpression>> {
1052    let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default();
1053
1054    for row in matrix_expression {
1055        for expr in row {
1056            let ComplexExpression { real, imag } = expr;
1057            let real_expr: RecExpr<TrigLanguage> = to_egg_expr(real);
1058            let imag_expr: RecExpr<TrigLanguage> = to_egg_expr(imag);
1059            runner = runner.with_expr(&real_expr).with_expr(&imag_expr);
1060        }
1061    }
1062
1063    runner = runner.run(&make_rules());
1064    let mut extractor = TrigExprExtractor::new(&runner.egraph);
1065
1066    let mut simplified_matrix = vec![vec![]; matrix_expression.len()];
1067    let nrows = matrix_expression.len();
1068    let ncols = matrix_expression[0].len();
1069
1070    for i in 0..nrows {
1071        for j in 0..ncols {
1072            let real_root = runner.roots[2 * (i * ncols + j)];
1073            let imag_root = runner.roots[2 * (i * ncols + j) + 1];
1074
1075            let best_real = extractor.extract_best(real_root);
1076            let real_simple = from_egg_expr(best_real);
1077            let best_imag = extractor.extract_best(imag_root);
1078            let imag_simple = from_egg_expr(best_imag);
1079
1080            simplified_matrix[i].push(ComplexExpression {
1081                real: real_simple,
1082                imag: imag_simple,
1083            });
1084        }
1085    }
1086
1087    simplified_matrix
1088}
1089
1090pub fn simplify_expressions_iter<'a>(
1091    expression: impl Iterator<Item = &'a Expression>,
1092) -> Vec<Expression> {
1093    let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default();
1094
1095    let mut num_expressions = 0;
1096    for expr in expression {
1097        let expr: RecExpr<TrigLanguage> = to_egg_expr(expr);
1098        runner = runner.with_expr(&expr);
1099        num_expressions += 1;
1100    }
1101
1102    runner = runner.run(&make_rules());
1103    let mut extractor = TrigExprExtractor::new(&runner.egraph);
1104
1105    let mut simplified_expressions = vec![];
1106
1107    for i in 0..num_expressions {
1108        let root = runner.roots[i];
1109        let best = extractor.extract_best(root);
1110        simplified_expressions.push(from_egg_expr(best));
1111    }
1112
1113    simplified_expressions
1114}
1115
1116#[allow(dead_code)]
1117pub fn simplify_expressions(expression: Vec<Expression>) -> Vec<Expression> {
1118    let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default();
1119
1120    let mut num_expressions = 0;
1121    for expr in expression {
1122        let expr: RecExpr<TrigLanguage> = to_egg_expr(&expr);
1123        runner = runner.with_expr(&expr);
1124        num_expressions += 1;
1125    }
1126
1127    runner = runner.run(&make_rules());
1128    let mut extractor = TrigExprExtractor::new(&runner.egraph);
1129
1130    let mut simplified_expressions = vec![];
1131
1132    for i in 0..num_expressions {
1133        let root = runner.roots[i];
1134        let best = extractor.extract_best(root);
1135        simplified_expressions.push(from_egg_expr(best));
1136    }
1137
1138    simplified_expressions
1139}
1140
1141#[allow(dead_code)]
1142pub fn simplify_matrix_and_matvec(
1143    matrix_expression: &Vec<Vec<ComplexExpression>>,
1144    matvec_expression: &Vec<Vec<Vec<ComplexExpression>>>,
1145) -> (
1146    Vec<Vec<ComplexExpression>>,
1147    Vec<Vec<Vec<ComplexExpression>>>,
1148) {
1149    let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default();
1150
1151    for row in matrix_expression {
1152        for expr in row {
1153            let ComplexExpression { real, imag } = expr;
1154            let real_expr: RecExpr<TrigLanguage> = to_egg_expr(real);
1155            let imag_expr: RecExpr<TrigLanguage> = to_egg_expr(imag);
1156            runner = runner.with_expr(&real_expr).with_expr(&imag_expr);
1157        }
1158    }
1159
1160    for mat in matvec_expression {
1161        for row in mat {
1162            for expr in row {
1163                let ComplexExpression { real, imag } = expr;
1164                let real_expr: RecExpr<TrigLanguage> = to_egg_expr(real);
1165                let imag_expr: RecExpr<TrigLanguage> = to_egg_expr(imag);
1166                runner = runner.with_expr(&real_expr).with_expr(&imag_expr);
1167            }
1168        }
1169    }
1170
1171    runner = runner.run(&make_rules());
1172    let mut extractor = TrigExprExtractor::new(&runner.egraph);
1173
1174    let mut simplified_matrix = vec![vec![]; matrix_expression.len()];
1175    let nrows = matrix_expression.len();
1176    let ncols = matrix_expression[0].len();
1177
1178    for i in 0..nrows {
1179        for j in 0..ncols {
1180            let real_root = runner.roots[2 * (i * ncols + j)];
1181            let imag_root = runner.roots[2 * (i * ncols + j) + 1];
1182
1183            let best_real = extractor.extract_best(real_root);
1184            let real_simple = from_egg_expr(best_real);
1185            let best_imag = extractor.extract_best(imag_root);
1186            let imag_simple = from_egg_expr(best_imag);
1187
1188            simplified_matrix[i].push(ComplexExpression {
1189                real: real_simple,
1190                imag: imag_simple,
1191            });
1192        }
1193    }
1194
1195    let matrix_expr_offset = 2 * nrows * ncols;
1196    let nmats = matvec_expression.len();
1197    if nmats == 0 {
1198        return (simplified_matrix, vec![]);
1199    }
1200    let nrows = matvec_expression[0].len();
1201    let ncols = matvec_expression[0][0].len();
1202    let mut simplified_matvec = vec![vec![vec![]; nrows]; nmats];
1203
1204    for m in 0..nmats {
1205        for i in 0..nrows {
1206            for j in 0..ncols {
1207                let real_root =
1208                    runner.roots[matrix_expr_offset + 2 * (m * nrows * ncols + i * ncols + j)];
1209                let imag_root =
1210                    runner.roots[matrix_expr_offset + 2 * (m * nrows * ncols + i * ncols + j) + 1];
1211
1212                let best_real = extractor.extract_best(real_root);
1213                let real_simple = from_egg_expr(best_real);
1214                let best_imag = extractor.extract_best(imag_root);
1215                let imag_simple = from_egg_expr(best_imag);
1216
1217                simplified_matvec[m][i].push(ComplexExpression {
1218                    real: real_simple,
1219                    imag: imag_simple,
1220                });
1221            }
1222        }
1223    }
1224
1225    // println!("{:?}", context);
1226
1227    (simplified_matrix, simplified_matvec)
1228}
1229
1230// #[cfg(test)]
1231// use crate::UnitaryExpression;
1232// #[cfg(test)]
1233// use extract::BottomUpExtractor;
1234
1235// #[test]
1236// fn test_simplify_matrix_and_matvec() {
1237//     let u3 = UnitaryExpression::new(
1238//         String::from("
1239//             utry U3(f1, f2, f3) {
1240//                 [
1241//                     [ cos(f1/2), ~e^(i*f3)*sin(f1/2) ],
1242//                     [ e^(i*f2)*sin(f1/2), e^(i*(f2+f3))*cos(f1/2) ]
1243//                 ]
1244//             }
1245//         "),
1246//     );
1247//     let grad = u3.differentiate();
1248//     let start = std::time::Instant::now();
1249//     let (simplified_matrix, simplified_matvec) = simplify_matrix_and_matvec(&u3.body, &grad.body);
1250//     let elapsed = start.elapsed();
1251//     println!("Time taken: {:?}", elapsed);
1252//     println!("{:?}", simplified_matrix);
1253//     println!("{:?}", simplified_matvec);
1254
1255//     // bottom-up test
1256//     let matrix_expression = &u3.body;
1257//     let matvec_expression = &grad.body;
1258
1259//     let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default();
1260
1261//     for row in matrix_expression {
1262//         for expr in row {
1263//             let ComplexExpression { real, imag } = expr;
1264//             let real_expr: RecExpr<TrigLanguage> = to_egg_expr(real);
1265//             let imag_expr: RecExpr<TrigLanguage> = to_egg_expr(imag);
1266//             runner = runner.with_expr(&real_expr).with_expr(&imag_expr);
1267//         }
1268//     }
1269
1270//     for mat in matvec_expression {
1271//         for row in mat {
1272//             for expr in row {
1273//                 let ComplexExpression { real, imag } = expr;
1274//                 let real_expr: RecExpr<TrigLanguage> = to_egg_expr(real);
1275//                 let imag_expr: RecExpr<TrigLanguage> = to_egg_expr(imag);
1276//                 runner = runner.with_expr(&real_expr).with_expr(&imag_expr);
1277//             }
1278//         }
1279//     }
1280
1281//     runner = runner.run(&make_rules());
1282//     let extractor = BottomUpExtractor;
1283//     let start = std::time::Instant::now();
1284//     let _result = extractor.extract(&runner.egraph, &runner.roots);
1285//     let elapsed = start.elapsed();
1286//     println!("Time taken: {:?}", elapsed);
1287// }
1288
1289#[allow(dead_code)]
1290pub fn check_many_equality(expr1s: &[&Expression], expr2s: &[&Expression]) -> bool {
1291    let expr1s: Vec<RecExpr<TrigLanguage>> = expr1s.iter().map(|expr| to_egg_expr(expr)).collect();
1292    let expr2s: Vec<RecExpr<TrigLanguage>> = expr2s.iter().map(|expr| to_egg_expr(expr)).collect();
1293
1294    let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1295        .with_iter_limit(120)
1296        .with_node_limit(100_000_000);
1297    for expr1 in &expr1s {
1298        runner = runner.with_expr(expr1);
1299    }
1300    for expr2 in &expr2s {
1301        runner = runner.with_expr(expr2);
1302    }
1303    runner = runner.run(&make_rules());
1304
1305    for (expr1, expr2) in expr1s.iter().zip(expr2s.iter()) {
1306        if runner.egraph.equivs(expr1, expr2).is_empty() {
1307            return false;
1308        }
1309    }
1310
1311    true
1312}
1313
1314#[allow(dead_code)]
1315pub fn check_equality(expr: &Expression, expr2: &Expression) -> bool {
1316    let expr1: RecExpr<TrigLanguage> = to_egg_expr(expr);
1317    let expr2: RecExpr<TrigLanguage> = to_egg_expr(expr2);
1318
1319    let runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1320        .with_expr(&expr1)
1321        .with_expr(&expr2)
1322        .with_iter_limit(100)
1323        .with_node_limit(1_000_000)
1324        .run(&make_rules());
1325
1326    !runner.egraph.equivs(&expr1, &expr2).is_empty()
1327    // if runner.egraph.equivs(&expr1, &expr2).len() > 0 {
1328    //     return true;
1329    // }
1330
1331    // let runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1332    //     .with_expr(&expr2)
1333    //     .with_iter_limit(40)
1334    //     .with_node_limit(25_000)
1335    //     .run(&make_rules());
1336
1337    // runner.egraph.equivs(&expr2, &expr1).len() > 0
1338}
1339
1340#[allow(dead_code)]
1341fn print_equality(s1: &str, s2: &str) {
1342    let expr1: RecExpr<TrigLanguage> = s1.parse().unwrap();
1343    let expr2: RecExpr<TrigLanguage> = s2.parse().unwrap();
1344
1345    let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1346        .with_explanations_enabled()
1347        .with_expr(&expr1)
1348        .run(&make_rules());
1349    println!(
1350        "{}",
1351        runner.explain_equivalence(&expr1, &expr2).get_flat_string()
1352    );
1353    // let egraph = runner.egraph;
1354    // let equivs = egraph.equivs(&expr1, &expr2);
1355    // if equivs.is_empty() {
1356    //     println!("{} and {} are not equivalent", s1, s2);
1357    // } else {
1358    //     println!("{} and {} are equivalent", s1, s2);
1359    // }
1360}
1361
1362#[allow(dead_code)]
1363fn check_equality_lhs_only(s1: &str, s2: &str) -> bool {
1364    let expr1: RecExpr<TrigLanguage> = s1.parse().unwrap();
1365    let expr2: RecExpr<TrigLanguage> = s2.parse().unwrap();
1366
1367    let runner: Runner<TrigLanguage, ConstantFold> =
1368        Runner::default().with_expr(&expr1).run(&make_rules());
1369
1370    !runner.egraph.equivs(&expr1, &expr2).is_empty()
1371}
1372
1373#[allow(dead_code)]
1374fn check_equality_both(s1: &str, s2: &str) -> bool {
1375    let expr1: RecExpr<TrigLanguage> = s1.parse().unwrap();
1376    let expr2: RecExpr<TrigLanguage> = s2.parse().unwrap();
1377
1378    let runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1379        .with_expr(&expr1)
1380        .with_node_limit(25_000)
1381        .run(&make_rules());
1382
1383    let lhs = !runner.egraph.equivs(&expr1, &expr2).is_empty();
1384
1385    let runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1386        .with_expr(&expr2)
1387        .with_node_limit(25_000)
1388        .run(&make_rules());
1389
1390    let rhs = !runner.egraph.equivs(&expr2, &expr1).is_empty();
1391
1392    lhs && rhs
1393}
1394
1395#[cfg(test)]
1396mod tests {
1397    // use super::*;
1398
1399    #[test]
1400    fn check_equality_test() {
1401        // let s1 = "(* (sin (* (pi) (-0.25))) (cos (* (0.5) (- (?x2) (?x0)))))";
1402        // let s2 = "(* (pow (2) (-0.5)) (~ (cos (* (0.5) (- (?x0) (?x2))))))";
1403        // let s1 = "(* (sin (* pi (~ 0.25))) (cos (* (0.5) (- (?x2) (?x0)))))";
1404        // let s2 = "(* (pow 2 -0.5) (~ (cos (* (0.5) (- (?x0) (?x2))))))";
1405        // let s1 = "(/ (cos (* (0.5) (- (?x0) (?x2)))) (~ (sqrt (2))))";
1406        // let s2 = "(* (sin (* (pi) (-0.25))) (cos (* (0.5) (- (?x0) (?x2)))))";
1407        // let s1 = "(~ (/ (sin (x1)) (sqrt (2))))";
1408        // let s2 = "(/ (sin (x1)) (~ (sqrt (2))))";
1409        // let expr1: RecExpr<TrigLanguage> = s1.parse().unwrap();
1410        // let expr2: RecExpr<TrigLanguage> = s2.parse().unwrap();
1411
1412        // let runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1413        //     .with_expr(&expr1)
1414        //     // .with_expr(&expr2)
1415        //     .with_iter_limit(50)
1416        //     .with_node_limit(50_000)
1417        //     .run(&make_rules());
1418
1419        // println!("{:?}", runner.egraph.equivs(&expr1, &expr2).len());
1420        //
1421        // let s = "(* (sin (+ (x0) (x1))) (/ (- (cos (x0)) (sin (x0))) (sqrt (2))))";
1422        // let expr1: RecExpr<TrigLanguage> = s.parse().unwrap();
1423        // let mut runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1424        //     // .with_explanations_enabled()
1425        //     .with_expr(&expr1)
1426        //     .with_iter_limit(100)
1427        //     .with_node_limit(1_000_000)
1428        //     .run(&make_rules());
1429
1430        // let expr1: RecExpr<TrigLanguage> = "0.001953125".parse().unwrap();
1431        // let expr2: RecExpr<TrigLanguage> = "7.450580596923828e-9".parse().unwrap();
1432        // println!("{}", runner.explain_equivalence(&expr1, &expr2).get_flat_string());
1433        // let runner = Runner::default().with_expr(&"(cos 0.5)".parse().unwrap()).run(&make_rules());
1434        // assert!(check_equality_lhs_only("(/ 1 4)", "(0.25)"));
1435        // assert!(check_equality_lhs_only("(/ (sqrt 2) 2)", "(/ 1 (sqrt 2))"));
1436        // assert!(check_equality_both("(sin (/ pi 4))", "(sin (* pi 0.25))"));
1437        // assert!(check_equality_lhs_only("(sin (/ pi 4))", "(/ 1 (sqrt 2))"));
1438        // assert!(check_equality_lhs_only("(sin (/ pi 4))", "(pow (2) (-0.5))"));
1439        // assert!(check_equality_both("(cos (* (0.5) (- (?x0) (?x2)))))", "(cos (* (0.5) (- (?x2) (?x0))))"));
1440        // assert!(check_equality_lhs_only("(* (sin (* (pi) (-0.25))) (cos (* (0.5) (- (?x2) (?x0)))))", "(* (pow (2) (-0.5)) (~ (cos (* (0.5) (- (?x0) (?x2))))))"));
1441        // (* (cos (x0)) (- (/ (cos (+ (x0) (x0))) (sqrt (2))) (/ (* (+ (cos (+ (x0) (x0))) (-1)) (sqrt (2))) (2))))
1442        // assert!(check_equality_both("(cos (* 2 ?x))", "(- (* (2) (* (cos ?x) (cos ?x))) 1)"));
1443        // assert!(check_equality_both("(- (cos (* 2 ?x)) 1)", "(* (-2) (* (sin ?x) (sin ?x)))"));
1444        // assert!(check_equality_both("(* (sin (~ ?x)) (sin (* 2 ?x)))", "(* (cos ?x) (- (cos (* 2 ?x)) 1))"));
1445        // assert!(check_equality_lhs_only("(* (cos ?x) (- (/ (cos (+ ?x ?x)) (sqrt 2)) (/ (* (+ (cos (+ ?x ?x)) (-1)) (sqrt (2))) (2))))", "(/ (cos ?x) (sqrt 2))"));
1446
1447        // let s1 = "(* (sin (x0)) (- (/ (* (+ (1) (cos (+ (x0) (x0)))) (sqrt (2))) (2)) (/ (+ (* (-0.5) (sin (+ (x0) (x0)))) (- (+ (* (0.5) (sin (+ (x0) (x0)))) (1)) (-1))) (sqrt (2)))))";
1448        // let s2 = "(/ (sin (x0)) (sqrt (2)))";
1449        // let s1 = "(* (sin (x0)) (- (+ (1) (cos (+ (x0) (x0)))) (+ (* (-0.5) (sin (+ (x0) (x0)))) (- (+ (* (0.5) (sin (+ (x0) (x0)))) (1)) (-1))) ))";
1450        // let s2 = "(* (sin (x0)) (+ (cos (* 2 x0)) 1))";
1451        // let s1 = "(* (sin (x0)) (- (+ (1) (cos (+ (x0) (x0)))) (+ (* (-0.5) (sin (+ (x0) (x0)))) (- (+ (* (0.5) (sin (+ (x0) (x0)))) (1)) (-1)))))";
1452        // let s2 = "(* (sin (x0)) (- (cos (* 2 x0)) 1))";
1453        // assert!(check_equality_lhs_only(s1, s2));
1454        // let s1 = "(/ (+ (* (* (sin (* (pi) (0.25))) (cos (* (0.5) (- (θ) (λ))))) (/ (cos (θ)) (sqrt (2)))) (* (~ (* (sin (* (pi) (0.25))) (sin (* (0.5) (- (θ) (λ)))))) (/ (sin (θ)) (sqrt (2))))) (+ (* (/ (cos (θ)) (sqrt (2))) (/ (cos (θ)) (sqrt (2)))) (* (/ (sin (θ)) (sqrt (2))) (/ (sin (θ)) (sqrt (2))))))";
1455        // let s2 = "(/ (- (* (~ (* (sin (* (pi) (0.25))) (sin (* (0.5) (- (θ) (λ)))))) (/ (cos (θ)) (sqrt (2)))) (* (* (sin (* (pi) (0.25))) (cos (* (0.5) (- (θ) (λ))))) (/ (sin (θ)) (sqrt (2))))) (+ (* (/ (cos (θ)) (sqrt (2))) (/ (cos (θ)) (sqrt (2)))) (* (/ (sin (θ)) (sqrt (2))) (/ (sin (θ)) (sqrt (2))))))";
1456        // // let expr1: RecExpr<TrigLanguage> = s1.parse().unwrap();
1457        // let expr2: RecExpr<TrigLanguage> = s2.parse().unwrap();
1458        // let mut _runner: Runner<TrigLanguage, ConstantFold> = Runner::default()
1459        //     // .with_explanations_enabled()
1460        //     // .with_expr(&expr1)
1461        //     .with_expr(&expr2)
1462        //     .with_iter_limit(30)
1463        //     .with_node_limit(30_000)
1464        //     .run(&make_rules());
1465        // println!("{}", runner.explain_equivalence(&"1".parse().unwrap(), &"-1".parse().unwrap()).get_flat_string());
1466    }
1467}