1use crate::deriv::log::{DerivationLog, RewriteStep, SideCondition};
19use crate::kernel::{ExprData, ExprId, ExprPool};
20use crate::pattern::{Pattern, Substitution};
21use crate::simplify::rules::RewriteRule;
22
23fn one_step(name: &'static str, before: ExprId, after: ExprId) -> DerivationLog {
24 let mut log = DerivationLog::new();
25 log.push(RewriteStep::simple(name, before, after));
26 log
27}
28
29pub struct SinNeg;
35
36impl RewriteRule for SinNeg {
37 fn name(&self) -> &'static str {
38 "sin_neg"
39 }
40
41 fn apply(&self, expr: ExprId, pool: &ExprPool) -> Option<(ExprId, DerivationLog)> {
42 let arg = func_arg("sin", expr, pool)?;
43 let inner = neg_inner(arg, pool)?;
44 let after_inner = pool.func("sin", vec![inner]);
45 let neg_one = pool.integer(-1_i32);
46 let after = pool.mul(vec![neg_one, after_inner]);
47 Some((after, one_step(self.name(), expr, after)))
48 }
49}
50
51pub struct CosNeg;
53
54impl RewriteRule for CosNeg {
55 fn name(&self) -> &'static str {
56 "cos_neg"
57 }
58
59 fn apply(&self, expr: ExprId, pool: &ExprPool) -> Option<(ExprId, DerivationLog)> {
60 let arg = func_arg("cos", expr, pool)?;
61 let inner = neg_inner(arg, pool)?;
62 let after = pool.func("cos", vec![inner]);
63 Some((after, one_step(self.name(), expr, after)))
64 }
65}
66
67pub struct TanExpand;
69
70impl RewriteRule for TanExpand {
71 fn name(&self) -> &'static str {
72 "tan_expand"
73 }
74
75 fn apply(&self, expr: ExprId, pool: &ExprPool) -> Option<(ExprId, DerivationLog)> {
76 let arg = func_arg("tan", expr, pool)?;
77 let sin_x = pool.func("sin", vec![arg]);
78 let cos_x = pool.func("cos", vec![arg]);
79 let cos_inv = pool.pow(cos_x, pool.integer(-1_i32));
80 let after = pool.mul(vec![sin_x, cos_inv]);
81 Some((after, one_step(self.name(), expr, after)))
82 }
83}
84
85pub struct SinCosIdentity;
90
91impl RewriteRule for SinCosIdentity {
92 fn name(&self) -> &'static str {
93 "sin_sq_plus_cos_sq"
94 }
95
96 fn apply(&self, expr: ExprId, pool: &ExprPool) -> Option<(ExprId, DerivationLog)> {
97 let args = match pool.get(expr) {
98 ExprData::Add(v) => v,
99 _ => return None,
100 };
101
102 let sin_sq_pos = args.iter().position(|&a| is_sin_sq(a, pool))?;
104 let sin_arg = sin_inner(args[sin_sq_pos], pool).unwrap();
105 let cos_sq_pos = args.iter().position(|&a| is_cos_sq_of(a, sin_arg, pool))?;
106
107 if sin_sq_pos == cos_sq_pos {
108 return None;
109 }
110
111 let one = pool.integer(1_i32);
113 let mut new_args: Vec<ExprId> = args
114 .into_iter()
115 .enumerate()
116 .filter(|&(i, _)| i != sin_sq_pos && i != cos_sq_pos)
117 .map(|(_, a)| a)
118 .collect();
119 new_args.push(one);
120
121 let after = match new_args.len() {
122 1 => new_args[0],
123 _ => pool.add(new_args),
124 };
125
126 Some((after, one_step(self.name(), expr, after)))
127 }
128}
129
130pub fn trig_rules() -> Vec<Box<dyn RewriteRule>> {
132 vec![
133 Box::new(SinNeg),
134 Box::new(CosNeg),
135 Box::new(TanExpand),
136 Box::new(SinCosIdentity),
137 ]
138}
139
140pub struct LogOfExp;
146
147impl RewriteRule for LogOfExp {
148 fn name(&self) -> &'static str {
149 "log_of_exp"
150 }
151
152 fn apply(&self, expr: ExprId, pool: &ExprPool) -> Option<(ExprId, DerivationLog)> {
153 let arg = func_arg("log", expr, pool)?;
154 let inner = func_arg("exp", arg, pool)?;
155 Some((inner, one_step(self.name(), expr, inner)))
156 }
157}
158
159pub struct ExpOfLog;
161
162impl RewriteRule for ExpOfLog {
163 fn name(&self) -> &'static str {
164 "exp_of_log"
165 }
166
167 fn apply(&self, expr: ExprId, pool: &ExprPool) -> Option<(ExprId, DerivationLog)> {
168 let arg = func_arg("exp", expr, pool)?;
169 let inner = func_arg("log", arg, pool)?;
170 Some((inner, one_step(self.name(), expr, inner)))
171 }
172}
173
174pub struct LogOfProduct;
182
183impl RewriteRule for LogOfProduct {
184 fn name(&self) -> &'static str {
185 "log_of_product"
186 }
187
188 fn apply(&self, expr: ExprId, pool: &ExprPool) -> Option<(ExprId, DerivationLog)> {
189 let arg = func_arg("log", expr, pool)?;
190 let factors = match pool.get(arg) {
191 ExprData::Mul(v) if v.len() >= 2 => v,
192 _ => return None,
193 };
194 let logs: Vec<ExprId> = factors.iter().map(|&f| pool.func("log", vec![f])).collect();
195 let after = pool.add(logs);
196 let conds: Vec<SideCondition> = factors
197 .iter()
198 .map(|&f| SideCondition::Positive(f))
199 .collect();
200 let mut log = DerivationLog::new();
201 log.push(RewriteStep::with_conditions(
202 "log_of_product",
203 expr,
204 after,
205 conds,
206 ));
207 Some((after, log))
208 }
209}
210
211pub struct LogOfPow;
213
214impl RewriteRule for LogOfPow {
215 fn name(&self) -> &'static str {
216 "log_of_pow"
217 }
218
219 fn apply(&self, expr: ExprId, pool: &ExprPool) -> Option<(ExprId, DerivationLog)> {
220 let arg = func_arg("log", expr, pool)?;
221 let (base, exp) = match pool.get(arg) {
222 ExprData::Pow { base, exp } => (base, exp),
223 _ => return None,
224 };
225 let log_base = pool.func("log", vec![base]);
226 let after = pool.mul(vec![exp, log_base]);
227 Some((after, one_step(self.name(), expr, after)))
228 }
229}
230
231pub fn log_exp_rules() -> Vec<Box<dyn RewriteRule>> {
237 vec![
238 Box::new(LogOfExp),
239 Box::new(ExpOfLog),
240 Box::new(LogOfProduct),
241 Box::new(LogOfPow),
242 ]
243}
244
245pub fn log_exp_rules_safe() -> Vec<Box<dyn RewriteRule>> {
250 vec![Box::new(LogOfExp), Box::new(ExpOfLog), Box::new(LogOfPow)]
251}
252
253pub struct PatternRule {
289 pub lhs: Pattern,
290 pub rhs: ExprId,
291 name: &'static str,
292}
293
294impl PatternRule {
295 pub fn new(lhs: Pattern, rhs: ExprId) -> Self {
296 PatternRule {
297 lhs,
298 rhs,
299 name: "pattern_rule",
300 }
301 }
302
303 pub fn named(lhs: Pattern, rhs: ExprId, name: &'static str) -> Self {
304 PatternRule { lhs, rhs, name }
305 }
306}
307
308impl RewriteRule for PatternRule {
309 fn name(&self) -> &'static str {
310 self.name
311 }
312
313 fn apply(&self, expr: ExprId, pool: &ExprPool) -> Option<(ExprId, DerivationLog)> {
314 let subst = match_at_root(&self.lhs, expr, pool)?;
316 let after = subst.apply(self.rhs, pool);
317 if after == expr {
318 return None;
319 }
320 Some((after, one_step(self.name, expr, after)))
321 }
322}
323
324fn match_at_root(pattern: &Pattern, expr: ExprId, pool: &ExprPool) -> Option<Substitution> {
326 let empty = Substitution {
327 bindings: std::collections::HashMap::new(),
328 };
329 match_root_node(pattern.root, expr, empty, pool)
330}
331
332fn match_root_node(
333 pat: ExprId,
334 expr: ExprId,
335 subst: Substitution,
336 pool: &ExprPool,
337) -> Option<Substitution> {
338 use crate::kernel::expr::ExprData as ED;
339
340 enum PN {
341 Wildcard(String),
342 Integer(i64),
343 Symbol(String),
344 Add(Vec<ExprId>),
345 Mul(Vec<ExprId>),
346 Pow(ExprId, ExprId),
347 Func(String, Vec<ExprId>),
348 Literal,
349 }
350 enum EN {
351 Integer(i64),
352 Symbol(String),
353 Add(Vec<ExprId>),
354 Mul(Vec<ExprId>),
355 Pow(ExprId, ExprId),
356 Func(String, Vec<ExprId>),
357 Other,
358 }
359
360 let pn = pool.with(pat, |d| match d {
361 ED::Symbol { name, .. } if name.starts_with(|c: char| c.is_lowercase()) => {
362 PN::Wildcard(name.clone())
363 }
364 ED::Symbol { name, .. } => PN::Symbol(name.clone()),
365 ED::Integer(n) => PN::Integer(n.0.to_i64().unwrap_or(i64::MIN)),
366 ED::Add(v) => PN::Add(v.clone()),
367 ED::Mul(v) => PN::Mul(v.clone()),
368 ED::Pow { base, exp } => PN::Pow(*base, *exp),
369 ED::Func { name, args } => PN::Func(name.clone(), args.clone()),
370 _ => PN::Literal,
371 });
372
373 let en = pool.with(expr, |d| match d {
374 ED::Symbol { name, .. } => EN::Symbol(name.clone()),
375 ED::Integer(n) => EN::Integer(n.0.to_i64().unwrap_or(i64::MIN)),
376 ED::Add(v) => EN::Add(v.clone()),
377 ED::Mul(v) => EN::Mul(v.clone()),
378 ED::Pow { base, exp } => EN::Pow(*base, *exp),
379 ED::Func { name, args } => EN::Func(name.clone(), args.clone()),
380 _ => EN::Other,
381 });
382
383 match pn {
384 PN::Wildcard(name) => {
385 let mut s = subst;
386 match s.bindings.get(&name) {
387 Some(&existing) if existing != expr => return None,
388 _ => {
389 s.bindings.insert(name, expr);
390 }
391 }
392 Some(s)
393 }
394 PN::Integer(pv) => {
395 if matches!(en, EN::Integer(ev) if ev == pv) {
396 Some(subst)
397 } else {
398 None
399 }
400 }
401 PN::Symbol(pname) => {
402 if matches!(en, EN::Symbol(ref ename) if *ename == pname) {
403 Some(subst)
404 } else {
405 None
406 }
407 }
408 PN::Add(pargs) => {
409 let EN::Add(eargs) = en else { return None };
410 match_args_exact(&pargs, &eargs, subst, pool)
411 }
412 PN::Mul(pargs) => {
413 let EN::Mul(eargs) = en else { return None };
414 match_args_exact(&pargs, &eargs, subst, pool)
415 }
416 PN::Pow(pb, pe) => {
417 let EN::Pow(eb, ee) = en else { return None };
418 let s = match_root_node(pb, eb, subst, pool)?;
419 match_root_node(pe, ee, s, pool)
420 }
421 PN::Func(pname, pargs) => {
422 let EN::Func(ename, eargs) = en else {
423 return None;
424 };
425 if pname != ename {
426 return None;
427 }
428 match_args_exact(&pargs, &eargs, subst, pool)
429 }
430 PN::Literal => {
431 if pat == expr {
432 Some(subst)
433 } else {
434 None
435 }
436 }
437 }
438}
439
440fn match_args_exact(
441 pat_args: &[ExprId],
442 expr_args: &[ExprId],
443 subst: Substitution,
444 pool: &ExprPool,
445) -> Option<Substitution> {
446 if pat_args.len() != expr_args.len() {
447 return None;
448 }
449 let mut s = subst;
450 for (&p, &e) in pat_args.iter().zip(expr_args.iter()) {
451 s = match_root_node(p, e, s, pool)?;
452 }
453 Some(s)
454}
455
456fn func_arg(name: &str, expr: ExprId, pool: &ExprPool) -> Option<ExprId> {
461 pool.with(expr, |data| match data {
462 ExprData::Func { name: n, args } if n == name && args.len() == 1 => Some(args[0]),
463 _ => None,
464 })
465}
466
467fn neg_inner(expr: ExprId, pool: &ExprPool) -> Option<ExprId> {
469 let args = match pool.get(expr) {
470 ExprData::Mul(v) => v,
471 _ => return None,
472 };
473 let neg1_pos = args
474 .iter()
475 .position(|&a| pool.with(a, |d| matches!(d, ExprData::Integer(n) if n.0 == -1)))?;
476 let others: Vec<ExprId> = args
477 .into_iter()
478 .enumerate()
479 .filter(|&(i, _)| i != neg1_pos)
480 .map(|(_, a)| a)
481 .collect();
482 Some(match others.len() {
483 0 => pool.integer(1_i32),
484 1 => others[0],
485 _ => pool.mul(others),
486 })
487}
488
489fn is_sin_sq(expr: ExprId, pool: &ExprPool) -> bool {
490 match pool.get(expr) {
491 ExprData::Pow { base, exp } => {
492 let is_two = pool.with(exp, |d| matches!(d, ExprData::Integer(n) if n.0 == 2));
493 let is_sin = pool.with(
494 base,
495 |d| matches!(d, ExprData::Func { name, .. } if name == "sin"),
496 );
497 is_two && is_sin
498 }
499 _ => false,
500 }
501}
502
503fn sin_inner(expr: ExprId, pool: &ExprPool) -> Option<ExprId> {
504 match pool.get(expr) {
505 ExprData::Pow { base, .. } => func_arg("sin", base, pool),
506 _ => None,
507 }
508}
509
510fn is_cos_sq_of(expr: ExprId, arg: ExprId, pool: &ExprPool) -> bool {
511 match pool.get(expr) {
512 ExprData::Pow { base, exp } => {
513 let is_two = pool.with(exp, |d| matches!(d, ExprData::Integer(n) if n.0 == 2));
514 let is_cos_of_arg = func_arg("cos", base, pool).is_some_and(|a| a == arg);
515 is_two && is_cos_of_arg
516 }
517 _ => false,
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};
    use crate::pattern::Pattern;
    use crate::simplify::engine::{simplify_with, SimplifyConfig};

    /// Fresh, empty expression pool.
    fn p() -> ExprPool {
        ExprPool::new()
    }

    /// Real-domain symbol called `name`.
    fn real(pool: &ExprPool, name: &str) -> ExprId {
        pool.symbol(name, Domain::Real)
    }

    /// `-1 * e`.
    fn negated(pool: &ExprPool, e: ExprId) -> ExprId {
        pool.mul(vec![pool.integer(-1_i32), e])
    }

    /// `log(x * y)` over two real symbols, returned together with `x` and `y`.
    fn log_of_xy(pool: &ExprPool) -> (ExprId, ExprId, ExprId) {
        let x = real(pool, "x");
        let y = real(pool, "y");
        let expr = pool.func("log", vec![pool.mul(vec![x, y])]);
        (expr, x, y)
    }

    #[test]
    fn sin_neg_fires() {
        let pool = p();
        let x = real(&pool, "x");
        let input = pool.func("sin", vec![negated(&pool, x)]);
        let rules = trig_rules();
        let outcome = simplify_with(input, &pool, &rules, SimplifyConfig::default());
        let want = negated(&pool, pool.func("sin", vec![x]));
        assert_eq!(outcome.value, want);
    }

    #[test]
    fn cos_neg_fires() {
        let pool = p();
        let x = real(&pool, "x");
        let input = pool.func("cos", vec![negated(&pool, x)]);
        let rules = trig_rules();
        let outcome = simplify_with(input, &pool, &rules, SimplifyConfig::default());
        assert_eq!(outcome.value, pool.func("cos", vec![x]));
    }

    #[test]
    fn tan_expand_fires() {
        let pool = p();
        let x = real(&pool, "x");
        let input = pool.func("tan", vec![x]);
        let rules = trig_rules();
        let outcome = simplify_with(input, &pool, &rules, SimplifyConfig::default());
        let sine = pool.func("sin", vec![x]);
        let cosine = pool.func("cos", vec![x]);
        let reciprocal = pool.pow(cosine, pool.integer(-1_i32));
        let want = pool.mul(vec![sine, reciprocal]);
        assert_eq!(outcome.value, want);
    }

    #[test]
    fn sin_cos_identity_fires() {
        let pool = p();
        let x = real(&pool, "x");
        let sine = pool.func("sin", vec![x]);
        let cosine = pool.func("cos", vec![x]);
        let two = pool.integer(2_i32);
        let input = pool.add(vec![pool.pow(sine, two), pool.pow(cosine, two)]);
        let rules = trig_rules();
        let outcome = simplify_with(input, &pool, &rules, SimplifyConfig::default());
        assert_eq!(outcome.value, pool.integer(1_i32));
    }

    #[test]
    fn log_of_exp_fires() {
        let pool = p();
        let x = real(&pool, "x");
        let input = pool.func("log", vec![pool.func("exp", vec![x])]);
        let rules = log_exp_rules();
        let outcome = simplify_with(input, &pool, &rules, SimplifyConfig::default());
        assert_eq!(outcome.value, x);
    }

    #[test]
    fn exp_of_log_fires() {
        let pool = p();
        let x = real(&pool, "x");
        let input = pool.func("exp", vec![pool.func("log", vec![x])]);
        let rules = log_exp_rules();
        let outcome = simplify_with(input, &pool, &rules, SimplifyConfig::default());
        assert_eq!(outcome.value, x);
    }

    #[test]
    fn log_of_product_fires() {
        let pool = p();
        let (input, x, y) = log_of_xy(&pool);
        let rules = log_exp_rules();
        let outcome = simplify_with(input, &pool, &rules, SimplifyConfig::default());
        let want = pool.add(vec![
            pool.func("log", vec![x]),
            pool.func("log", vec![y]),
        ]);
        assert_eq!(outcome.value, want);
    }

    #[test]
    fn log_of_product_records_positive_side_conditions() {
        let pool = p();
        let (input, _, _) = log_of_xy(&pool);
        let rules = log_exp_rules();
        let outcome = simplify_with(input, &pool, &rules, SimplifyConfig::default());
        let has_positive_conds = outcome
            .log
            .steps()
            .iter()
            .filter(|step| step.rule_name == "log_of_product")
            .any(|step| {
                step.side_conditions
                    .iter()
                    .any(|c| matches!(c, SideCondition::Positive(_)))
            });
        assert!(
            has_positive_conds,
            "log_of_product should record Positive side conditions"
        );
    }

    #[test]
    fn log_of_product_safe_does_not_fire() {
        let pool = p();
        let (input, _, _) = log_of_xy(&pool);
        let rules = log_exp_rules_safe();
        let outcome = simplify_with(input, &pool, &rules, SimplifyConfig::default());
        assert_eq!(
            outcome.value, input,
            "log(x*y) should NOT be split with log_exp_rules_safe"
        );
    }

    #[test]
    fn log_of_pow_fires() {
        let pool = p();
        let x = real(&pool, "x");
        let three = pool.integer(3_i32);
        let input = pool.func("log", vec![pool.pow(x, three)]);
        let rules = log_exp_rules();
        let outcome = simplify_with(input, &pool, &rules, SimplifyConfig::default());
        let want = pool.mul(vec![three, pool.func("log", vec![x])]);
        assert_eq!(outcome.value, want);
    }

    #[test]
    fn pattern_rule_simple() {
        let pool = p();
        // `a` starts with a lowercase letter, so it acts as a wildcard.
        let a = real(&pool, "a");
        let lhs = pool.add(vec![a, a]);
        let rhs = pool.mul(vec![pool.integer(2_i32), a]);
        let rule = PatternRule::new(Pattern::from_expr(lhs), rhs);
        let x = real(&pool, "x");
        let input = pool.add(vec![x, x]);
        let outcome = simplify_with(input, &pool, &[Box::new(rule)], SimplifyConfig::default());
        let want = pool.mul(vec![pool.integer(2_i32), x]);
        assert_eq!(outcome.value, want);
    }
}