use crate::ast::{Expression, Function, SymbolicConstant};
use crate::pattern::{Pattern, Rule};
/// Returns the arithmetic self-cancellation rules, in order: `x - x => 0`, then `x / x => 1`.
pub fn arithmetic_rules() -> Vec<Rule> {
    let cancel_sub = sub_self_rule();
    let cancel_div = div_self_rule();
    Vec::from([cancel_sub, cancel_div])
}
/// Builds the rule `x - x => 0`, named `"sub_self"`.
///
/// The wildcard name `"x"` is used for both operands, so the pattern only matches
/// when the left and right operands bind to the same expression (`x - y` is not
/// rewritten).
fn sub_self_rule() -> Rule {
    use crate::ast::BinaryOp;

    // Same wildcard on both sides => non-linear pattern: operands must be equal.
    let operand_minus_itself =
        Pattern::binary(BinaryOp::Sub, Pattern::wildcard("x"), Pattern::wildcard("x"));
    let zero = Pattern::exact(Expression::Integer(0));
    Rule::new(operand_minus_itself, zero).named("sub_self")
}
/// Builds the rule `x / x => 1`, named `"div_self"`.
///
/// The wildcard name `"x"` is used for both numerator and denominator, so the
/// pattern only matches when both bind to the same expression.
///
/// NOTE(review): this rewrite is unconditional. No guard excludes `x == 0`, so
/// `0 / 0` (or any operand that is zero) is also rewritten to `1`. Consider a
/// side condition if the pattern API offers one.
fn div_self_rule() -> Rule {
    use crate::ast::BinaryOp;

    // Same wildcard on both sides => numerator and denominator must be equal.
    let operand_over_itself =
        Pattern::binary(BinaryOp::Div, Pattern::wildcard("x"), Pattern::wildcard("x"));
    let one = Pattern::exact(Expression::Integer(1));
    Rule::new(operand_over_itself, one).named("div_self")
}
/// Returns the exponent rules provided by `common_rules`, in order:
/// `power_zero()` then `power_one()`.
///
/// NOTE(review): presumably these are `x^0 => 1` and `x^1 => x`, based on their
/// names. The definitions live in `common_rules`, so confirm there.
pub fn power_rules() -> Vec<Rule> {
    use crate::pattern::common_rules;

    let zero_exponent = common_rules::power_zero();
    let unit_exponent = common_rules::power_one();
    Vec::from([zero_exponent, unit_exponent])
}
/// Returns the rules that cancel `ln` against powers of `e`, in order:
/// `ln(e^x) => x`, then `e^(ln(x)) => x`.
pub fn log_exp_rules() -> Vec<Rule> {
    let ln_then_exp = ln_of_exp_rule();
    let exp_then_ln = exp_of_ln_rule();
    Vec::from([ln_then_exp, exp_then_ln])
}
/// Builds the rule `ln(e^x) => x`, named `"ln_of_exp"`.
///
/// The base must be exactly the symbolic constant `E`. Powers with any other base
/// (e.g. `ln(2^x)`) are not matched.
fn ln_of_exp_rule() -> Rule {
    let euler = Pattern::exact(Expression::Constant(SymbolicConstant::E));
    let e_to_the_x = Pattern::power(euler, Pattern::wildcard("x"));
    let ln_of_power = Pattern::function(Function::Ln, vec![e_to_the_x]);
    Rule::new(ln_of_power, Pattern::wildcard("x")).named("ln_of_exp")
}
/// Builds the rule `e^(ln(x)) => x`, named `"exp_of_ln"`.
///
/// The base must be exactly the symbolic constant `E`, and the exponent must be a
/// single-argument `ln` call.
///
/// NOTE(review): over the reals, `e^(ln x) = x` only holds for `x > 0`. No domain
/// guard is attached here, so the rewrite is applied unconditionally.
fn exp_of_ln_rule() -> Rule {
    let euler = Pattern::exact(Expression::Constant(SymbolicConstant::E));
    let ln_x = Pattern::function(Function::Ln, vec![Pattern::wildcard("x")]);
    let e_to_the_ln = Pattern::power(euler, ln_x);
    Rule::new(e_to_the_ln, Pattern::wildcard("x")).named("exp_of_ln")
}
/// Returns the Pythagorean trig identity rules. Currently this is only
/// `sin(x)^2 + cos(x)^2 => 1`.
pub fn trig_pythagorean_rule() -> Vec<Rule> {
    Vec::from([sin_sq_plus_cos_sq_rule()])
}
fn sin_sq_plus_cos_sq_rule() -> Rule {
use crate::ast::BinaryOp;
let sin_sq = Pattern::power(
Pattern::function(Function::Sin, vec![Pattern::wildcard("x")]),
Pattern::exact(Expression::Integer(2)),
);
let cos_sq = Pattern::power(
Pattern::function(Function::Cos, vec![Pattern::wildcard("x")]),
Pattern::exact(Expression::Integer(2)),
);
Rule::new(
Pattern::binary(BinaryOp::Add, sin_sq, cos_sq),
Pattern::exact(Expression::Integer(1)),
)
.named("sin_sq_plus_cos_sq")
}
/// Collects every simplification rule exposed by this module into one list, in
/// this order:
/// 1. trig identities (`trig_pythagorean_rule`),
/// 2. log/exp cancellations (`log_exp_rules`),
/// 3. arithmetic cancellations (`arithmetic_rules`),
/// 4. the shared rules from `common_rules::all()`.
///
/// NOTE(review): `power_rules()` is not included here. Presumably
/// `common_rules::all()` already contains `power_zero`/`power_one`; confirm this
/// before adding it, to avoid duplicate rules.
pub fn all_simplification_rules() -> Vec<Rule> {
    use crate::pattern::common_rules;

    trig_pythagorean_rule()
        .into_iter()
        .chain(log_exp_rules())
        .chain(arithmetic_rules())
        .chain(common_rules::all())
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{BinaryOp, UnaryOp, Variable};
    use crate::pattern::{apply_rule, apply_rules_to_fixpoint};

    // ---- Expression builders (keep the test bodies focused on the shapes) ----

    fn var(name: &str) -> Expression {
        Expression::Variable(Variable::new(name))
    }

    fn int(n: i64) -> Expression {
        Expression::Integer(n)
    }

    fn binary(op: BinaryOp, lhs: Expression, rhs: Expression) -> Expression {
        Expression::Binary(op, Box::new(lhs), Box::new(rhs))
    }

    fn power(base: Expression, exponent: Expression) -> Expression {
        Expression::Power(Box::new(base), Box::new(exponent))
    }

    fn call(func: Function, arg: Expression) -> Expression {
        Expression::Function(func, vec![arg])
    }

    fn euler() -> Expression {
        Expression::Constant(SymbolicConstant::E)
    }

    fn neg(inner: Expression) -> Expression {
        Expression::Unary(UnaryOp::Neg, Box::new(inner))
    }

    /// `func(arg)^2`, the building block of the Pythagorean-identity tests.
    fn trig_squared(func: Function, arg: &str) -> Expression {
        power(call(func, var(arg)), int(2))
    }

    /// Runs the full rule set to a fixpoint with the iteration limit of 20.
    fn simplify(expr: &Expression) -> Expression {
        apply_rules_to_fixpoint(expr, &all_simplification_rules(), 20)
    }

    // ---- Single-rule tests ----

    #[test]
    fn test_sub_self_variable() {
        let expr = binary(BinaryOp::Sub, var("x"), var("x"));
        assert_eq!(apply_rule(&expr, &sub_self_rule()), Some(int(0)));
    }

    #[test]
    fn test_sub_self_does_not_match_different() {
        let expr = binary(BinaryOp::Sub, var("x"), var("y"));
        assert_eq!(apply_rule(&expr, &sub_self_rule()), None);
    }

    #[test]
    fn test_div_self_variable() {
        let expr = binary(BinaryOp::Div, var("x"), var("x"));
        assert_eq!(apply_rule(&expr, &div_self_rule()), Some(int(1)));
    }

    #[test]
    fn test_div_self_complex_expr() {
        let operand = binary(BinaryOp::Add, var("x"), int(1));
        let expr = binary(BinaryOp::Div, operand.clone(), operand);
        assert_eq!(apply_rule(&expr, &div_self_rule()), Some(int(1)));
    }

    #[test]
    fn test_ln_of_exp() {
        let expr = call(Function::Ln, power(euler(), var("x")));
        assert_eq!(apply_rule(&expr, &ln_of_exp_rule()), Some(var("x")));
    }

    #[test]
    fn test_ln_of_exp_does_not_match_wrong_base() {
        let expr = call(Function::Ln, power(int(2), var("x")));
        assert_eq!(apply_rule(&expr, &ln_of_exp_rule()), None);
    }

    #[test]
    fn test_exp_of_ln() {
        let expr = power(euler(), call(Function::Ln, var("x")));
        assert_eq!(apply_rule(&expr, &exp_of_ln_rule()), Some(var("x")));
    }

    #[test]
    fn test_exp_of_ln_does_not_match_wrong_base() {
        let expr = power(int(2), call(Function::Ln, var("x")));
        assert_eq!(apply_rule(&expr, &exp_of_ln_rule()), None);
    }

    #[test]
    fn test_sin_sq_plus_cos_sq() {
        let expr = binary(
            BinaryOp::Add,
            trig_squared(Function::Sin, "x"),
            trig_squared(Function::Cos, "x"),
        );
        assert_eq!(apply_rule(&expr, &sin_sq_plus_cos_sq_rule()), Some(int(1)));
    }

    #[test]
    fn test_sin_sq_plus_cos_sq_commutative() {
        // The rule's pattern lists sin first. Presumably the matcher treats Add
        // operands order-insensitively, which this test depends on.
        let expr = binary(
            BinaryOp::Add,
            trig_squared(Function::Cos, "x"),
            trig_squared(Function::Sin, "x"),
        );
        assert_eq!(apply_rule(&expr, &sin_sq_plus_cos_sq_rule()), Some(int(1)));
    }

    #[test]
    fn test_sin_sq_plus_cos_sq_different_args_no_match() {
        let expr = binary(
            BinaryOp::Add,
            trig_squared(Function::Sin, "x"),
            trig_squared(Function::Cos, "y"),
        );
        assert_eq!(apply_rule(&expr, &sin_sq_plus_cos_sq_rule()), None);
    }

    // ---- Full rule set, run to a fixpoint ----

    #[test]
    fn test_fixpoint_x_plus_0() {
        assert_eq!(simplify(&binary(BinaryOp::Add, var("x"), int(0))), var("x"));
    }

    #[test]
    fn test_fixpoint_x_times_1() {
        assert_eq!(simplify(&binary(BinaryOp::Mul, var("x"), int(1))), var("x"));
    }

    #[test]
    fn test_fixpoint_x_times_0() {
        assert_eq!(simplify(&binary(BinaryOp::Mul, var("x"), int(0))), int(0));
    }

    #[test]
    fn test_fixpoint_double_neg() {
        assert_eq!(simplify(&neg(neg(var("x")))), var("x"));
    }

    #[test]
    fn test_fixpoint_x_pow_1() {
        assert_eq!(simplify(&power(var("x"), int(1))), var("x"));
    }

    #[test]
    fn test_fixpoint_x_pow_0() {
        assert_eq!(simplify(&power(var("x"), int(0))), int(1));
    }
}