1use super::rules::{
2 AddZero, CanonicalOrder, ConstFold, DivSelf, ExpandMul, FlattenAdd, FlattenMul, MulOne,
3 MulZero, PowOne, PowZero, RewriteRule, SubSelf,
4};
5use crate::deriv::log::{DerivationLog, DerivedExpr};
6use crate::kernel::{ExprData, ExprId, ExprPool};
7
/// Settings for the simplifier entry points in this module.
///
/// `max_iterations` is consumed by [`simplify_with`]; `expand` is consumed by
/// [`rules_for_config`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimplifyConfig {
    /// Maximum number of whole-tree simplification passes [`simplify_with`]
    /// runs; it stops earlier as soon as a pass leaves the expression unchanged.
    /// A value of `0` runs no passes at all.
    pub max_iterations: usize,
    /// When `true`, [`rules_for_config`] appends the [`ExpandMul`] rule to the
    /// core rule set.
    pub expand: bool,
    /// NOTE(review): not read by any code in this module. Presumably it gates
    /// rewrites that are only valid away from branch cuts. Confirm where it is
    /// consumed before relying on it.
    pub allow_branch_cut_rewrites: bool,
}
31
32impl Default for SimplifyConfig {
33 fn default() -> Self {
34 SimplifyConfig {
35 max_iterations: 100,
36 expand: false,
37 allow_branch_cut_rewrites: false,
38 }
39 }
40}
41
42pub fn rules_for_config(config: &SimplifyConfig) -> Vec<Box<dyn RewriteRule>> {
48 let mut rules: Vec<Box<dyn RewriteRule>> = vec![
49 Box::new(FlattenMul),
50 Box::new(FlattenAdd),
51 Box::new(MulZero),
52 Box::new(AddZero),
53 Box::new(MulOne),
54 Box::new(PowZero),
55 Box::new(PowOne),
56 Box::new(ConstFold),
57 Box::new(SubSelf),
58 Box::new(DivSelf),
59 Box::new(CanonicalOrder),
60 ];
61 if config.expand {
62 rules.push(Box::new(ExpandMul));
63 }
64 rules
65}
66
67pub fn default_rules() -> Vec<Box<dyn RewriteRule>> {
68 rules_for_config(&SimplifyConfig::default())
69}
70
71fn simplify_node(
76 expr: ExprId,
77 pool: &ExprPool,
78 rules: &[Box<dyn RewriteRule>],
79) -> DerivedExpr<ExprId> {
80 let data = pool.get(expr);
82 let (rebuilt, child_log) = simplify_children(data, pool, rules);
83
84 let mut current = rebuilt;
86 let mut rule_log = DerivationLog::new();
87 loop {
88 let mut fired = false;
89 for rule in rules {
90 if let Some((new_expr, step_log)) = rule.apply(current, pool) {
91 rule_log = rule_log.merge(step_log);
92 current = new_expr;
93 fired = true;
94 break; }
96 }
97 if !fired {
98 break;
99 }
100 }
101
102 DerivedExpr::with_log(current, child_log.merge(rule_log))
103}
104
105fn simplify_children(
107 data: ExprData,
108 pool: &ExprPool,
109 rules: &[Box<dyn RewriteRule>],
110) -> (ExprId, DerivationLog) {
111 let mut log = DerivationLog::new();
112 match data {
113 ExprData::Add(args) => {
114 let new_args: Vec<ExprId> = args
115 .into_iter()
116 .map(|a| {
117 let r = simplify_node(a, pool, rules);
118 log = std::mem::take(&mut log).merge(r.log);
119 r.value
120 })
121 .collect();
122 (pool.add(new_args), log)
123 }
124 ExprData::Mul(args) => {
125 let new_args: Vec<ExprId> = args
126 .into_iter()
127 .map(|a| {
128 let r = simplify_node(a, pool, rules);
129 log = std::mem::take(&mut log).merge(r.log);
130 r.value
131 })
132 .collect();
133 (pool.mul(new_args), log)
134 }
135 ExprData::Pow { base, exp } => {
136 let rb = simplify_node(base, pool, rules);
137 log = log.merge(rb.log);
138 let re = simplify_node(exp, pool, rules);
139 log = log.merge(re.log);
140 (pool.pow(rb.value, re.value), log)
141 }
142 ExprData::Func { name, args } => {
143 let new_args: Vec<ExprId> = args
144 .into_iter()
145 .map(|a| {
146 let r = simplify_node(a, pool, rules);
147 log = std::mem::take(&mut log).merge(r.log);
148 r.value
149 })
150 .collect();
151 (pool.func(name, new_args), log)
152 }
153 ExprData::Piecewise { branches, default } => {
157 let new_branches: Vec<(ExprId, ExprId)> = branches
158 .into_iter()
159 .map(|(cond, val)| {
160 let rv = simplify_node(val, pool, rules);
161 log = std::mem::take(&mut log).merge(rv.log);
162 (cond, rv.value)
163 })
164 .collect();
165 let rd = simplify_node(default, pool, rules);
166 log = log.merge(rd.log);
167 (pool.piecewise(new_branches, rd.value), log)
168 }
169 ExprData::Predicate { kind, args } => {
171 let new_args: Vec<ExprId> = args
172 .into_iter()
173 .map(|a| {
174 let r = simplify_node(a, pool, rules);
175 log = std::mem::take(&mut log).merge(r.log);
176 r.value
177 })
178 .collect();
179 (pool.predicate(kind, new_args), log)
180 }
181 ExprData::Forall { var, body } => {
182 let rb = simplify_node(body, pool, rules);
183 log = log.merge(rb.log);
184 (pool.forall(var, rb.value), log)
185 }
186 ExprData::Exists { var, body } => {
187 let rb = simplify_node(body, pool, rules);
188 log = log.merge(rb.log);
189 (pool.exists(var, rb.value), log)
190 }
191 ExprData::BigO(arg) => {
192 let r = simplify_node(arg, pool, rules);
193 log = log.merge(r.log);
194 (pool.big_o(r.value), log)
195 }
196 atom => (pool.intern(atom), log),
198 }
199}
200
201pub fn simplify_with(
207 expr: ExprId,
208 pool: &ExprPool,
209 rules: &[Box<dyn RewriteRule>],
210 config: SimplifyConfig,
211) -> DerivedExpr<ExprId> {
212 let mut current = DerivedExpr::new(expr);
213 for _ in 0..config.max_iterations {
214 let result = simplify_node(current.value, pool, rules);
215 let merged_log = current.log.merge(result.log);
216 if result.value == current.value {
217 return DerivedExpr::with_log(current.value, merged_log);
218 }
219 current = DerivedExpr::with_log(result.value, merged_log);
220 }
221 current
222}
223
224pub fn simplify(expr: ExprId, pool: &ExprPool) -> DerivedExpr<ExprId> {
226 let config = SimplifyConfig::default();
227 simplify_with(expr, pool, &rules_for_config(&config), config)
228}
229
230pub fn simplify_expanded(expr: ExprId, pool: &ExprPool) -> DerivedExpr<ExprId> {
232 let config = SimplifyConfig {
233 expand: true,
234 ..SimplifyConfig::default()
235 };
236 simplify_with(expr, pool, &rules_for_config(&config), config)
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    fn p() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn simplify_x_plus_zero() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.add(vec![x, pool.integer(0_i32)]);
        let r = simplify(expr, &pool);
        assert_eq!(r.value, x);
        assert!(!r.log.is_empty(), "should have logged a step");
        assert!(
            r.log.steps().iter().any(|s| s.rule_name == "add_zero"),
            "log should mention add_zero"
        );
    }

    #[test]
    fn simplify_x_times_one() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.mul(vec![x, pool.integer(1_i32)]);
        let r = simplify(expr, &pool);
        assert_eq!(r.value, x);
    }

    #[test]
    fn simplify_x_times_zero() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.mul(vec![x, pool.integer(0_i32)]);
        let r = simplify(expr, &pool);
        assert_eq!(r.value, pool.integer(0_i32));
    }

    #[test]
    fn simplify_x_pow_one() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(1_i32));
        let r = simplify(expr, &pool);
        assert_eq!(r.value, x);
    }

    #[test]
    fn simplify_x_pow_zero() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(0_i32));
        let r = simplify(expr, &pool);
        assert_eq!(r.value, pool.integer(1_i32));
        assert!(
            r.log.steps().iter().any(|s| !s.side_conditions.is_empty()),
            "pow_zero should record side condition"
        );
    }

    #[test]
    fn simplify_const_fold_add() {
        let pool = p();
        let expr = pool.add(vec![pool.integer(2_i32), pool.integer(3_i32)]);
        let r = simplify(expr, &pool);
        assert_eq!(r.value, pool.integer(5_i32));
    }

    #[test]
    fn simplify_const_fold_mul() {
        let pool = p();
        let expr = pool.mul(vec![pool.integer(4_i32), pool.integer(5_i32)]);
        let r = simplify(expr, &pool);
        assert_eq!(r.value, pool.integer(20_i32));
    }

    #[test]
    fn simplify_const_fold_pow() {
        let pool = p();
        let expr = pool.pow(pool.integer(2_i32), pool.integer(10_i32));
        let r = simplify(expr, &pool);
        assert_eq!(r.value, pool.integer(1024_i32));
    }

    #[test]
    fn simplify_sub_self() {
        let pool = p();
        // x + (-1)*x should cancel to 0.
        let x = pool.symbol("x", Domain::Real);
        let neg_x = pool.mul(vec![pool.integer(-1_i32), x]);
        let expr = pool.add(vec![x, neg_x]);
        let r = simplify(expr, &pool);
        assert_eq!(r.value, pool.integer(0_i32));
    }

    #[test]
    fn simplify_div_self() {
        let pool = p();
        // x * x^-1 should cancel to 1.
        let x = pool.symbol("x", Domain::Real);
        let x_inv = pool.pow(x, pool.integer(-1_i32));
        let expr = pool.mul(vec![x, x_inv]);
        let r = simplify(expr, &pool);
        assert_eq!(r.value, pool.integer(1_i32));
    }

    #[test]
    fn simplify_nested() {
        let pool = p();
        // (x + 0) * 1: the inner sum must be simplified before the outer product.
        let x = pool.symbol("x", Domain::Real);
        let inner = pool.add(vec![x, pool.integer(0_i32)]);
        let expr = pool.mul(vec![inner, pool.integer(1_i32)]);
        let r = simplify(expr, &pool);
        assert_eq!(r.value, x);
    }

    #[test]
    fn simplify_idempotent_on_already_simple() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let r = simplify(x, &pool);
        assert_eq!(r.value, x);
        assert!(r.log.is_empty());
    }

    #[test]
    fn simplify_with_custom_config() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.add(vec![x, pool.integer(0_i32)]);
        let config = SimplifyConfig {
            max_iterations: 1,
            ..SimplifyConfig::default()
        };
        let r = simplify_with(expr, &pool, &default_rules(), config);
        assert_eq!(r.value, x);
    }

    #[test]
    fn simplify_with_zero_iterations_returns_input() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.add(vec![x, pool.integer(0_i32)]);
        let config = SimplifyConfig {
            max_iterations: 0,
            ..SimplifyConfig::default()
        };
        let r = simplify_with(expr, &pool, &default_rules(), config);
        assert_eq!(r.value, expr, "no passes should leave the input untouched");
    }

    #[test]
    fn default_config_values() {
        let config = SimplifyConfig::default();
        assert_eq!(config.max_iterations, 100);
        assert!(!config.expand);
        assert!(!config.allow_branch_cut_rewrites);
    }

    #[test]
    fn rules_for_config_adds_expand_only_when_requested() {
        let plain = SimplifyConfig::default();
        let expanded = SimplifyConfig {
            expand: true,
            ..SimplifyConfig::default()
        };
        assert_eq!(rules_for_config(&plain).len(), 11);
        assert_eq!(rules_for_config(&expanded).len(), 12);
        assert_eq!(default_rules().len(), 11);
    }

    #[test]
    fn simplify_expanded_still_applies_core_rules() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.add(vec![x, pool.integer(0_i32)]);
        let r = simplify_expanded(expr, &pool);
        assert_eq!(r.value, x);
    }
}