Skip to main content

alkahest_cas/simplify/
engine.rs

1use super::rules::{
2    AddZero, CanonicalOrder, ConstFold, DivSelf, ExpandMul, FlattenAdd, FlattenMul, MulOne,
3    MulZero, PowOne, PowZero, RewriteRule, SubSelf,
4};
5use crate::deriv::log::{DerivationLog, DerivedExpr};
6use crate::kernel::{ExprData, ExprId, ExprPool};
7
8// ---------------------------------------------------------------------------
9// Configuration
10// ---------------------------------------------------------------------------
11
/// Options that steer the simplifier: the pass budget and which optional
/// rewrite families are enabled.
///
/// `SimplifyConfig::default()` yields `max_iterations = 100` with both
/// optional rewrite families switched off.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimplifyConfig {
    /// Maximum number of full bottom-up passes (default 100).
    pub max_iterations: usize,
    /// Whether to distribute multiplication over addition (default false).
    ///
    /// When `true`, the `ExpandMul` rule is included: `(a + b) * c → a*c + b*c`.
    /// Keep disabled unless explicitly expanding, because expansion can loop
    /// against a future `factor` rule.
    pub expand: bool,
    /// Allow branch-cut-sensitive rewrites such as `log(a*b) → log(a) + log(b)`.
    ///
    /// This identity only holds when `a` and `b` are positive reals.  Set this
    /// flag to `true` when you know all variables are positive and want the
    /// full log/exp rule set; leave it `false` (the default) for safe behaviour
    /// over complex numbers or when sign information is unavailable.
    ///
    /// NOTE(review): `rules_for_config` in this file does not read this flag;
    /// presumably it is consumed elsewhere — confirm.
    pub allow_branch_cut_rewrites: bool,
}

impl Default for SimplifyConfig {
    /// Conservative defaults: 100 passes, no expansion, no branch-cut rewrites.
    fn default() -> Self {
        Self {
            max_iterations: 100,
            expand: false,
            allow_branch_cut_rewrites: false,
        }
    }
}
41
42// ---------------------------------------------------------------------------
43// Default rule set
44// ---------------------------------------------------------------------------
45
46/// Build the rule set for a given config.
47pub fn rules_for_config(config: &SimplifyConfig) -> Vec<Box<dyn RewriteRule>> {
48    let mut rules: Vec<Box<dyn RewriteRule>> = vec![
49        Box::new(FlattenMul),
50        Box::new(FlattenAdd),
51        Box::new(MulZero),
52        Box::new(AddZero),
53        Box::new(MulOne),
54        Box::new(PowZero),
55        Box::new(PowOne),
56        Box::new(ConstFold),
57        Box::new(SubSelf),
58        Box::new(DivSelf),
59        Box::new(CanonicalOrder),
60    ];
61    if config.expand {
62        rules.push(Box::new(ExpandMul));
63    }
64    rules
65}
66
67pub fn default_rules() -> Vec<Box<dyn RewriteRule>> {
68    rules_for_config(&SimplifyConfig::default())
69}
70
71// ---------------------------------------------------------------------------
72// Internal: bottom-up traversal — simplify children, then current node
73// ---------------------------------------------------------------------------
74
75fn simplify_node(
76    expr: ExprId,
77    pool: &ExprPool,
78    rules: &[Box<dyn RewriteRule>],
79) -> DerivedExpr<ExprId> {
80    // 1. Rebuild with simplified children
81    let data = pool.get(expr);
82    let (rebuilt, child_log) = simplify_children(data, pool, rules);
83
84    // 2. Apply rules to rebuilt node until no rule fires
85    let mut current = rebuilt;
86    let mut rule_log = DerivationLog::new();
87    loop {
88        let mut fired = false;
89        for rule in rules {
90            if let Some((new_expr, step_log)) = rule.apply(current, pool) {
91                rule_log = rule_log.merge(step_log);
92                current = new_expr;
93                fired = true;
94                break; // restart from first rule after any change
95            }
96        }
97        if !fired {
98            break;
99        }
100    }
101
102    DerivedExpr::with_log(current, child_log.merge(rule_log))
103}
104
105/// Simplify children of a node and return (rebuilt_expr, child_log).
106fn simplify_children(
107    data: ExprData,
108    pool: &ExprPool,
109    rules: &[Box<dyn RewriteRule>],
110) -> (ExprId, DerivationLog) {
111    let mut log = DerivationLog::new();
112    match data {
113        ExprData::Add(args) => {
114            let new_args: Vec<ExprId> = args
115                .into_iter()
116                .map(|a| {
117                    let r = simplify_node(a, pool, rules);
118                    log = std::mem::take(&mut log).merge(r.log);
119                    r.value
120                })
121                .collect();
122            (pool.add(new_args), log)
123        }
124        ExprData::Mul(args) => {
125            let new_args: Vec<ExprId> = args
126                .into_iter()
127                .map(|a| {
128                    let r = simplify_node(a, pool, rules);
129                    log = std::mem::take(&mut log).merge(r.log);
130                    r.value
131                })
132                .collect();
133            (pool.mul(new_args), log)
134        }
135        ExprData::Pow { base, exp } => {
136            let rb = simplify_node(base, pool, rules);
137            log = log.merge(rb.log);
138            let re = simplify_node(exp, pool, rules);
139            log = log.merge(re.log);
140            (pool.pow(rb.value, re.value), log)
141        }
142        ExprData::Func { name, args } => {
143            let new_args: Vec<ExprId> = args
144                .into_iter()
145                .map(|a| {
146                    let r = simplify_node(a, pool, rules);
147                    log = std::mem::take(&mut log).merge(r.log);
148                    r.value
149                })
150                .collect();
151            (pool.func(name, new_args), log)
152        }
153        // PA-9: Simplify values in each branch and the default.
154        // The condition expressions (predicates) are passed through unchanged
155        // since there are no simplification rules for predicates yet.
156        ExprData::Piecewise { branches, default } => {
157            let new_branches: Vec<(ExprId, ExprId)> = branches
158                .into_iter()
159                .map(|(cond, val)| {
160                    let rv = simplify_node(val, pool, rules);
161                    log = std::mem::take(&mut log).merge(rv.log);
162                    (cond, rv.value)
163                })
164                .collect();
165            let rd = simplify_node(default, pool, rules);
166            log = log.merge(rd.log);
167            (pool.piecewise(new_branches, rd.value), log)
168        }
169        // Predicate args may be simplified as expressions.
170        ExprData::Predicate { kind, args } => {
171            let new_args: Vec<ExprId> = args
172                .into_iter()
173                .map(|a| {
174                    let r = simplify_node(a, pool, rules);
175                    log = std::mem::take(&mut log).merge(r.log);
176                    r.value
177                })
178                .collect();
179            (pool.predicate(kind, new_args), log)
180        }
181        ExprData::Forall { var, body } => {
182            let rb = simplify_node(body, pool, rules);
183            log = log.merge(rb.log);
184            (pool.forall(var, rb.value), log)
185        }
186        ExprData::Exists { var, body } => {
187            let rb = simplify_node(body, pool, rules);
188            log = log.merge(rb.log);
189            (pool.exists(var, rb.value), log)
190        }
191        ExprData::BigO(arg) => {
192            let r = simplify_node(arg, pool, rules);
193            log = log.merge(r.log);
194            (pool.big_o(r.value), log)
195        }
196        // Atoms have no children
197        atom => (pool.intern(atom), log),
198    }
199}
200
201// ---------------------------------------------------------------------------
202// Public API
203// ---------------------------------------------------------------------------
204
205/// Simplify `expr` with a custom rule set and config.
206pub fn simplify_with(
207    expr: ExprId,
208    pool: &ExprPool,
209    rules: &[Box<dyn RewriteRule>],
210    config: SimplifyConfig,
211) -> DerivedExpr<ExprId> {
212    let mut current = DerivedExpr::new(expr);
213    for _ in 0..config.max_iterations {
214        let result = simplify_node(current.value, pool, rules);
215        let merged_log = current.log.merge(result.log);
216        if result.value == current.value {
217            return DerivedExpr::with_log(current.value, merged_log);
218        }
219        current = DerivedExpr::with_log(result.value, merged_log);
220    }
221    current
222}
223
224/// Simplify `expr` with the default rule set.
225pub fn simplify(expr: ExprId, pool: &ExprPool) -> DerivedExpr<ExprId> {
226    let config = SimplifyConfig::default();
227    simplify_with(expr, pool, &rules_for_config(&config), config)
228}
229
230/// Simplify `expr` with expansion enabled (`(a+b)*c → a*c + b*c`).
231pub fn simplify_expanded(expr: ExprId, pool: &ExprPool) -> DerivedExpr<ExprId> {
232    let config = SimplifyConfig {
233        expand: true,
234        ..SimplifyConfig::default()
235    };
236    simplify_with(expr, pool, &rules_for_config(&config), config)
237}
238
239// ---------------------------------------------------------------------------
240// Unit tests
241// ---------------------------------------------------------------------------
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprId, ExprPool};

    /// Fresh, empty expression pool for one test.
    fn p() -> ExprPool {
        ExprPool::new()
    }

    /// Interns the real-valued symbol `x` in `pool`.
    fn real_x(pool: &ExprPool) -> ExprId {
        pool.symbol("x", Domain::Real)
    }

    /// Interns the integer literal `n` in `pool`.
    fn int(pool: &ExprPool, n: i32) -> ExprId {
        pool.integer(n)
    }

    #[test]
    fn simplify_x_plus_zero() {
        let pool = p();
        let x = real_x(&pool);
        let sum = pool.add(vec![x, int(&pool, 0)]);

        let r = simplify(sum, &pool);

        assert_eq!(r.value, x);
        assert!(!r.log.is_empty(), "should have logged a step");
        let mentions_add_zero = r.log.steps().iter().any(|s| s.rule_name == "add_zero");
        assert!(mentions_add_zero, "log should mention add_zero");
    }

    #[test]
    fn simplify_x_times_one() {
        let pool = p();
        let x = real_x(&pool);
        let product = pool.mul(vec![x, int(&pool, 1)]);

        assert_eq!(simplify(product, &pool).value, x);
    }

    #[test]
    fn simplify_x_times_zero() {
        let pool = p();
        let x = real_x(&pool);
        let product = pool.mul(vec![x, int(&pool, 0)]);

        let r = simplify(product, &pool);

        assert_eq!(r.value, int(&pool, 0));
    }

    #[test]
    fn simplify_x_pow_one() {
        let pool = p();
        let x = real_x(&pool);
        let power = pool.pow(x, int(&pool, 1));

        assert_eq!(simplify(power, &pool).value, x);
    }

    #[test]
    fn simplify_x_pow_zero() {
        let pool = p();
        let x = real_x(&pool);
        let power = pool.pow(x, int(&pool, 0));

        let r = simplify(power, &pool);

        assert_eq!(r.value, int(&pool, 1));
        let has_side_condition = r.log.steps().iter().any(|s| !s.side_conditions.is_empty());
        assert!(has_side_condition, "pow_zero should record side condition");
    }

    #[test]
    fn simplify_const_fold_add() {
        let pool = p();
        let sum = pool.add(vec![int(&pool, 2), int(&pool, 3)]);

        let r = simplify(sum, &pool);

        assert_eq!(r.value, int(&pool, 5));
    }

    #[test]
    fn simplify_const_fold_mul() {
        let pool = p();
        let product = pool.mul(vec![int(&pool, 4), int(&pool, 5)]);

        let r = simplify(product, &pool);

        assert_eq!(r.value, int(&pool, 20));
    }

    #[test]
    fn simplify_const_fold_pow() {
        let pool = p();
        let power = pool.pow(int(&pool, 2), int(&pool, 10));

        let r = simplify(power, &pool);

        assert_eq!(r.value, int(&pool, 1024));
    }

    #[test]
    fn simplify_sub_self() {
        // x + (-1)*x → 0
        let pool = p();
        let x = real_x(&pool);
        let minus_x = pool.mul(vec![int(&pool, -1), x]);
        let difference = pool.add(vec![x, minus_x]);

        let r = simplify(difference, &pool);

        assert_eq!(r.value, int(&pool, 0));
    }

    #[test]
    fn simplify_div_self() {
        // x * x^(-1) → 1
        let pool = p();
        let x = real_x(&pool);
        let reciprocal = pool.pow(x, int(&pool, -1));
        let quotient = pool.mul(vec![x, reciprocal]);

        let r = simplify(quotient, &pool);

        assert_eq!(r.value, int(&pool, 1));
    }

    #[test]
    fn simplify_nested() {
        // (x + 0) * 1 → x
        let pool = p();
        let x = real_x(&pool);
        let inner = pool.add(vec![x, int(&pool, 0)]);
        let outer = pool.mul(vec![inner, int(&pool, 1)]);

        assert_eq!(simplify(outer, &pool).value, x);
    }

    #[test]
    fn simplify_idempotent_on_already_simple() {
        let pool = p();
        let x = real_x(&pool);

        let r = simplify(x, &pool);

        assert_eq!(r.value, x);
        assert!(r.log.is_empty());
    }

    #[test]
    fn simplify_with_custom_config() {
        // A single pass is enough to remove the `+ 0`.
        let pool = p();
        let x = real_x(&pool);
        let sum = pool.add(vec![x, int(&pool, 0)]);
        let single_pass = SimplifyConfig {
            max_iterations: 1,
            ..SimplifyConfig::default()
        };

        let r = simplify_with(sum, &pool, &default_rules(), single_pass);

        assert_eq!(r.value, x);
    }
}