minicas_crs/
lib.rs

1//! Mechanical simplification / factorization rules for algebraic expressions.
2//!
3//! Companion crate to [minicas_core].
4//!
5//! ```
6//! # use minicas_crs::simplify;
7//! # use minicas_core::ast::*;
8//! let mut n = Node::try_from("5x * 2x").unwrap();
9//!
//! // true additionally applies the factorization rules on top of the basic simplification ones
11//! simplify(&mut n, true).unwrap();
12//!
13//! assert_eq!(n, Node::try_from("10 * pow(x, 2)").unwrap());
14//! ```
15
use lazy_static::lazy_static;
use minicas_core::ast::{AstNode, Node};
use minicas_core::rules::{Rule, RuleSpec};
use std::collections::{BTreeMap, HashMap};
use toml::de;
21
22include!(concat!(env!("OUT_DIR"), "/rules.rs"));
23
24/// Returns all CRS rules.
25fn rules() -> HashMap<String, RuleSpec> {
26    de::from_str::<HashMap<String, RuleSpec>>(RULES_SRC).unwrap()
27}
28
29lazy_static! {
30    static ref RULES: HashMap<String, Rule> = rules()
31        .into_iter()
32        .map(|(name, spec)| { (name, spec.try_into().unwrap()) })
33        .collect();
34}
35
36/// Iteratively applies the simplification rules to the given node until none match.
37///
38/// When `all` is true, factorization and more aggressive simplification rules are additionally applied.
39///
40/// ```
41/// # use minicas_crs::simplify;
42/// # use minicas_core::ast::{AstNode, Node};
43/// let mut n = Node::try_from("(a - a) / b").unwrap();
44/// simplify(&mut n, false).unwrap();
45/// assert_eq!(n, Node::try_from("0").unwrap());
46/// ```
47pub fn simplify(n: &mut Node, all: bool) -> Result<(), ()> {
48    let (mut rule_matched, mut i) = (true, 0usize);
49    while rule_matched && i < 50 {
50        rule_matched = false;
51        i += 1;
52
53        n.walk_mut(true, &mut |n| {
54            for (_name, rule) in RULES.iter() {
55                if (all && !rule.meta.alt_form) || rule.meta.is_simplify {
56                    rule_matched |= rule.eval(n).unwrap(); // TODO: dont just unwrap
57                }
58            }
59
60            true
61        });
62    }
63
64    if i >= 50 {
65        Err(())
66    } else {
67        Ok(())
68    }
69}
70
/// Runs the built-in self-test of the named rule, panicking on the first failing case.
///
/// Panics if no rule with that name exists, or if the rule's self-test reports
/// a failure (including the index of the failing case).
#[cfg(test)]
fn do_test_rule(name: &str) {
    let Some(rule) = RULES.get(name) else {
        panic!("rule doesnt exist");
    };

    match rule.self_test() {
        Ok(_) => {}
        Err((index, err)) => panic!("failed test at index {}: {}", index, err),
    }
}
78
#[cfg(test)]
#[test]
fn test_simplify_rules() {
    // Each case: (input, whether to apply the full rule set, expected result).
    let cases = [
        // Simplification alone must not perform constant folding.
        ("1 + 1", false, "1 + 1"),
        // Two steps: a - a => 0, then 0 / b => 0.
        ("(a - a) / b", false, "0"),
        // Combining coefficients.
        ("3 * 5x", true, "15x"),
        ("5x -- 2x", true, "7x"),
        ("5x * 2x", true, "10 * pow(x, 2)"),
    ];

    for &(input, all, expected) in cases.iter() {
        let mut n = Node::try_from(input).unwrap();
        let result = simplify(&mut n, all);
        // Name the failing case explicitly; a shared loop otherwise hides which one broke.
        assert_eq!(result, Ok(()), "simplify({:?}, all = {}) failed", input, all);
        assert_eq!(
            n,
            Node::try_from(expected).unwrap(),
            "simplify({:?}, all = {}) gave the wrong result",
            input,
            all
        );
    }
}