use crate::ast::{BinaryOp, Expression, Function, UnaryOp};
use std::collections::HashMap;
/// A structural pattern that can be matched against an `Expression` by
/// `match_pattern` and, for most variants, turned back into an `Expression`
/// by `apply_pattern`.
#[derive(Debug, Clone, PartialEq)]
pub enum Pattern {
/// Matches any expression and records it under the given name. If the name
/// is already bound, the expression must be equal to the bound one.
Wildcard(String),
/// Matches any expression without recording a binding. `apply_pattern`
/// panics on this variant because there is nothing to substitute.
Any,
/// Matches only an expression equal to the contained one.
Exact(Expression),
/// Matches a binary expression with the same operator whose operands match
/// the two sub-patterns. For `Add` and `Mul` the matcher also attempts the
/// operands in swapped order.
Binary(BinaryOp, Box<Pattern>, Box<Pattern>),
/// Matches a unary expression with the same operator whose operand matches
/// the sub-pattern.
Unary(UnaryOp, Box<Pattern>),
/// Matches a call to the same function with the same number of arguments,
/// each argument matching the sub-pattern at the same position.
Function(Function, Vec<Pattern>),
/// Matches a power expression whose base and exponent match the two
/// sub-patterns (base first, exponent second).
Power(Box<Pattern>, Box<Pattern>),
/// Matches `Expression::Integer` holding exactly this value.
Integer(i64),
/// Matches any `Expression::Integer`. When a name is given, the matched
/// integer is recorded under it; `apply_pattern` panics if it is unnamed.
AnyInteger(Option<String>),
/// Matches any `Expression::Variable`. When a name is given, the matched
/// variable is recorded under it; `apply_pattern` panics if it is unnamed.
AnyVariable(Option<String>),
}
impl Pattern {
pub fn wildcard(name: &str) -> Self {
Pattern::Wildcard(name.to_string())
}
pub fn exact(expr: Expression) -> Self {
Pattern::Exact(expr)
}
pub fn binary(op: BinaryOp, left: Pattern, right: Pattern) -> Self {
Pattern::Binary(op, Box::new(left), Box::new(right))
}
pub fn unary(op: UnaryOp, operand: Pattern) -> Self {
Pattern::Unary(op, Box::new(operand))
}
pub fn function(func: Function, args: Vec<Pattern>) -> Self {
Pattern::Function(func, args)
}
pub fn power(base: Pattern, exp: Pattern) -> Self {
Pattern::Power(Box::new(base), Box::new(exp))
}
pub fn add(left: Pattern, right: Pattern) -> Self {
Pattern::binary(BinaryOp::Add, left, right)
}
pub fn sub(left: Pattern, right: Pattern) -> Self {
Pattern::binary(BinaryOp::Sub, left, right)
}
pub fn mul(left: Pattern, right: Pattern) -> Self {
Pattern::binary(BinaryOp::Mul, left, right)
}
pub fn div(left: Pattern, right: Pattern) -> Self {
Pattern::binary(BinaryOp::Div, left, right)
}
}
/// Matches `expr` against `pattern`, starting from an empty set of bindings.
///
/// Returns `Some(bindings)` holding every named capture when the pattern
/// matches, and `None` when it does not.
pub fn match_pattern(expr: &Expression, pattern: &Pattern) -> Option<HashMap<String, Expression>> {
    let mut captured = HashMap::new();
    let matched = match_pattern_internal(expr, pattern, &mut captured);
    if !matched {
        return None;
    }
    Some(captured)
}
fn match_pattern_internal(
expr: &Expression,
pattern: &Pattern,
bindings: &mut HashMap<String, Expression>,
) -> bool {
match pattern {
Pattern::Wildcard(name) => {
if let Some(existing) = bindings.get(name) {
expr == existing
} else {
bindings.insert(name.clone(), expr.clone());
true
}
}
Pattern::Any => true,
Pattern::Exact(target) => expr == target,
Pattern::Integer(n) => matches!(expr, Expression::Integer(m) if m == n),
Pattern::AnyInteger(opt_name) => {
if let Expression::Integer(n) = expr {
if let Some(name) = opt_name {
bindings.insert(name.clone(), Expression::Integer(*n));
}
true
} else {
false
}
}
Pattern::AnyVariable(opt_name) => {
if let Expression::Variable(_v) = expr {
if let Some(name) = opt_name {
bindings.insert(name.clone(), expr.clone());
}
true
} else {
false
}
}
Pattern::Binary(op, left_pat, right_pat) => {
if let Expression::Binary(expr_op, left_expr, right_expr) = expr {
if op == expr_op {
if match_pattern_internal(left_expr, left_pat, bindings)
&& match_pattern_internal(right_expr, right_pat, bindings)
{
return true;
}
if matches!(op, BinaryOp::Add | BinaryOp::Mul) {
let mut comm_bindings = bindings.clone();
if match_pattern_internal(left_expr, right_pat, &mut comm_bindings)
&& match_pattern_internal(right_expr, left_pat, &mut comm_bindings)
{
*bindings = comm_bindings;
return true;
}
}
}
}
false
}
Pattern::Unary(op, operand_pat) => {
if let Expression::Unary(expr_op, operand) = expr {
op == expr_op && match_pattern_internal(operand, operand_pat, bindings)
} else {
false
}
}
Pattern::Function(func, arg_pats) => {
if let Expression::Function(expr_func, args) = expr {
if func == expr_func && arg_pats.len() == args.len() {
for (arg_pat, arg) in arg_pats.iter().zip(args.iter()) {
if !match_pattern_internal(arg, arg_pat, bindings) {
return false;
}
}
true
} else {
false
}
} else {
false
}
}
Pattern::Power(base_pat, exp_pat) => {
if let Expression::Power(base, exp) = expr {
match_pattern_internal(base, base_pat, bindings)
&& match_pattern_internal(exp, exp_pat, bindings)
} else {
false
}
}
}
}
/// Instantiates `pattern` as an `Expression`, substituting every named capture
/// with the expression recorded for it in `bindings`.
///
/// # Panics
///
/// Panics when `pattern` contains `Pattern::Any`, an unnamed `AnyInteger` or
/// `AnyVariable`, or a name that has no entry in `bindings`.
pub fn apply_pattern(bindings: &HashMap<String, Expression>, pattern: &Pattern) -> Expression {
    match pattern {
        // Variants that carry their own value.
        Pattern::Exact(expr) => expr.clone(),
        Pattern::Integer(n) => Expression::Integer(*n),

        // Variants that cannot be substituted at all.
        Pattern::Any => panic!("Cannot apply pattern with Any - use Wildcard instead"),
        Pattern::AnyInteger(None) => panic!("Cannot apply pattern with unnamed AnyInteger"),
        Pattern::AnyVariable(None) => panic!("Cannot apply pattern with unnamed AnyVariable"),

        // Named captures: look the name up and copy the bound expression.
        Pattern::Wildcard(name) => match bindings.get(name) {
            Some(bound) => bound.clone(),
            None => panic!("Unbound wildcard: {}", name),
        },
        Pattern::AnyInteger(Some(name)) => match bindings.get(name) {
            Some(bound) => bound.clone(),
            None => panic!("Unbound integer wildcard: {}", name),
        },
        Pattern::AnyVariable(Some(name)) => match bindings.get(name) {
            Some(bound) => bound.clone(),
            None => panic!("Unbound variable wildcard: {}", name),
        },

        // Composite variants: instantiate children left to right, then rebuild.
        Pattern::Binary(op, left, right) => {
            let lhs = apply_pattern(bindings, left);
            let rhs = apply_pattern(bindings, right);
            Expression::Binary(*op, Box::new(lhs), Box::new(rhs))
        }
        Pattern::Unary(op, operand) => {
            let inner = apply_pattern(bindings, operand);
            Expression::Unary(*op, Box::new(inner))
        }
        Pattern::Function(func, args) => {
            let callee = func.clone();
            let mut built_args = Vec::with_capacity(args.len());
            for arg in args {
                built_args.push(apply_pattern(bindings, arg));
            }
            Expression::Function(callee, built_args)
        }
        Pattern::Power(base, exp) => {
            let base_expr = apply_pattern(bindings, base);
            let exp_expr = apply_pattern(bindings, exp);
            Expression::Power(Box::new(base_expr), Box::new(exp_expr))
        }
    }
}
/// A rewrite rule: expressions matching `pattern` are replaced by
/// `replacement` instantiated with the captured bindings.
#[derive(Clone)]
pub struct Rule {
/// Pattern an expression must match for the rule to fire.
pub pattern: Pattern,
/// Pattern instantiated with the match bindings to produce the result.
pub replacement: Pattern,
/// Optional predicate over the match bindings; when present and it returns
/// `false`, `apply_rule` does not rewrite the expression.
pub condition: Option<fn(&HashMap<String, Expression>) -> bool>,
/// Optional label for the rule, set through `Rule::named`.
pub name: Option<String>,
}
impl Rule {
pub fn new(pattern: Pattern, replacement: Pattern) -> Self {
Rule {
pattern,
replacement,
condition: None,
name: None,
}
}
pub fn with_condition(
pattern: Pattern,
replacement: Pattern,
condition: fn(&HashMap<String, Expression>) -> bool,
) -> Self {
Rule {
pattern,
replacement,
condition: Some(condition),
name: None,
}
}
pub fn named(mut self, name: &str) -> Self {
self.name = Some(name.to_string());
self
}
}
/// Applies `rule` to the root of `expr` only.
///
/// Returns the instantiated replacement when the rule's pattern matches `expr`
/// and its condition, if any, accepts the bindings; returns `None` otherwise.
pub fn apply_rule(expr: &Expression, rule: &Rule) -> Option<Expression> {
    let captured = match_pattern(expr, &rule.pattern)?;
    match rule.condition {
        // A condition that rejects the bindings vetoes the rewrite.
        Some(accepts) if !accepts(&captured) => None,
        _ => Some(apply_pattern(&captured, &rule.replacement)),
    }
}
/// Applies `rule` bottom-up in a single pass over `expr`.
///
/// Children of binary, unary, function and power nodes are rewritten first;
/// every other node is copied as-is. The rule is then tried once on the rebuilt
/// node, which is returned unchanged if the rule does not apply.
pub fn apply_rule_recursive(expr: &Expression, rule: &Rule) -> Expression {
    // Rewrites a child node and puts the result back in a box.
    fn rewrite_boxed(child: &Expression, rule: &Rule) -> Box<Expression> {
        Box::new(apply_rule_recursive(child, rule))
    }

    let rebuilt = match expr {
        Expression::Binary(op, left, right) => {
            let lhs = rewrite_boxed(left, rule);
            let rhs = rewrite_boxed(right, rule);
            Expression::Binary(*op, lhs, rhs)
        }
        Expression::Unary(op, operand) => Expression::Unary(*op, rewrite_boxed(operand, rule)),
        Expression::Function(func, args) => {
            let callee = func.clone();
            let mut rewritten_args = Vec::with_capacity(args.len());
            for arg in args {
                rewritten_args.push(apply_rule_recursive(arg, rule));
            }
            Expression::Function(callee, rewritten_args)
        }
        Expression::Power(base, exp) => {
            let new_base = rewrite_boxed(base, rule);
            let new_exp = rewrite_boxed(exp, rule);
            Expression::Power(new_base, new_exp)
        }
        leaf => leaf.clone(),
    };

    match apply_rule(&rebuilt, rule) {
        Some(replaced) => replaced,
        None => rebuilt,
    }
}
/// Rewrites `expr` with `rules` until nothing changes or the pass budget runs
/// out.
///
/// Each pass tries the rules in slice order via `apply_rule_recursive` and
/// keeps the result of the first rule that changes the expression; the next
/// pass then starts again from the first rule. The loop stops when a pass
/// leaves the expression unchanged or after `max_iterations` passes, whichever
/// comes first. With `max_iterations == 0` the result is a clone of `expr`.
pub fn apply_rules_to_fixpoint(
    expr: &Expression,
    rules: &[Rule],
    max_iterations: usize,
) -> Expression {
    let mut current = expr.clone();
    for _ in 0..max_iterations {
        // `map` is lazy, so rules after the first effective one are not run.
        let rewritten = rules
            .iter()
            .map(|rule| apply_rule_recursive(&current, rule))
            .find(|candidate| *candidate != current);

        match rewritten {
            Some(next) => current = next,
            // No rule changed anything: a fixpoint has been reached.
            None => break,
        }
    }
    current
}
pub mod common_rules {
use super::*;
pub fn additive_identity() -> Rule {
Rule::new(
Pattern::add(
Pattern::wildcard("x"),
Pattern::exact(Expression::Integer(0)),
),
Pattern::wildcard("x"),
)
.named("additive_identity")
}
pub fn multiplicative_identity() -> Rule {
Rule::new(
Pattern::mul(
Pattern::wildcard("x"),
Pattern::exact(Expression::Integer(1)),
),
Pattern::wildcard("x"),
)
.named("multiplicative_identity")
}
pub fn multiplicative_zero() -> Rule {
Rule::new(
Pattern::mul(
Pattern::wildcard("x"),
Pattern::exact(Expression::Integer(0)),
),
Pattern::exact(Expression::Integer(0)),
)
.named("multiplicative_zero")
}
pub fn double_negation() -> Rule {
Rule::new(
Pattern::unary(
UnaryOp::Neg,
Pattern::unary(UnaryOp::Neg, Pattern::wildcard("x")),
),
Pattern::wildcard("x"),
)
.named("double_negation")
}
pub fn power_zero() -> Rule {
Rule::new(
Pattern::power(
Pattern::wildcard("x"),
Pattern::exact(Expression::Integer(0)),
),
Pattern::exact(Expression::Integer(1)),
)
.named("power_zero")
}
pub fn power_one() -> Rule {
Rule::new(
Pattern::power(
Pattern::wildcard("x"),
Pattern::exact(Expression::Integer(1)),
),
Pattern::wildcard("x"),
)
.named("power_one")
}
pub fn all() -> Vec<Rule> {
vec![
additive_identity(),
multiplicative_identity(),
multiplicative_zero(),
double_negation(),
power_zero(),
power_one(),
]
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Variable;

    // ---- Expression construction helpers ----

    /// Variable expression named `name`.
    fn var(name: &str) -> Expression {
        Expression::Variable(Variable::new(name))
    }

    /// Integer literal expression.
    fn int(n: i64) -> Expression {
        Expression::Integer(n)
    }

    /// Binary expression with boxed operands.
    fn binary(op: BinaryOp, left: Expression, right: Expression) -> Expression {
        Expression::Binary(op, Box::new(left), Box::new(right))
    }

    fn add(left: Expression, right: Expression) -> Expression {
        binary(BinaryOp::Add, left, right)
    }

    fn mul(left: Expression, right: Expression) -> Expression {
        binary(BinaryOp::Mul, left, right)
    }

    fn power(base: Expression, exp: Expression) -> Expression {
        Expression::Power(Box::new(base), Box::new(exp))
    }

    // ---- Matching ----

    #[test]
    fn test_wildcard_matching() {
        let sum_of_two = Pattern::add(Pattern::wildcard("a"), Pattern::wildcard("b"));
        let subject = add(var("x"), var("y"));

        let captured = match_pattern(&subject, &sum_of_two).unwrap();

        assert_eq!(captured.get("a"), Some(&var("x")));
        assert_eq!(captured.get("b"), Some(&var("y")));
    }

    #[test]
    fn test_exact_matching() {
        let plus_zero = Pattern::add(Pattern::wildcard("x"), Pattern::exact(int(0)));

        let matching = add(var("y"), int(0));
        assert!(match_pattern(&matching, &plus_zero).is_some());

        let non_matching = add(var("y"), int(1));
        assert!(match_pattern(&non_matching, &plus_zero).is_none());
    }

    #[test]
    fn test_commutativity() {
        let sum_of_two = Pattern::add(Pattern::wildcard("a"), Pattern::wildcard("b"));
        let subject = add(var("y"), var("x"));

        let captured = match_pattern(&subject, &sum_of_two).unwrap();

        assert!(captured.contains_key("a"));
        assert!(captured.contains_key("b"));
    }

    #[test]
    fn test_same_wildcard_must_match_same_expr() {
        let doubled = Pattern::add(Pattern::wildcard("a"), Pattern::wildcard("a"));

        let same_operands = add(var("x"), var("x"));
        assert!(match_pattern(&same_operands, &doubled).is_some());

        let different_operands = add(var("x"), var("y"));
        assert!(match_pattern(&different_operands, &doubled).is_none());
    }

    #[test]
    fn test_nested_matching() {
        let inner_pattern = Pattern::function(Function::Cos, vec![Pattern::wildcard("a")]);
        let sin_of_cos = Pattern::function(Function::Sin, vec![inner_pattern]);

        let inner_expr = Expression::Function(Function::Cos, vec![var("x")]);
        let subject = Expression::Function(Function::Sin, vec![inner_expr]);

        let captured = match_pattern(&subject, &sin_of_cos).unwrap();
        assert_eq!(captured.get("a"), Some(&var("x")));
    }

    // ---- Substitution ----

    #[test]
    fn test_apply_pattern() {
        let mut bindings = HashMap::new();
        bindings.insert("x".to_string(), var("y"));
        let template = Pattern::add(Pattern::wildcard("x"), Pattern::exact(int(1)));

        let built = apply_pattern(&bindings, &template);

        assert_eq!(built, add(var("y"), int(1)));
    }

    // ---- Rules ----

    #[test]
    fn test_additive_identity_rule() {
        let rule = common_rules::additive_identity();
        let subject = add(var("x"), int(0));

        assert_eq!(apply_rule(&subject, &rule), Some(var("x")));
    }

    #[test]
    fn test_multiplicative_zero_rule() {
        let rule = common_rules::multiplicative_zero();
        let subject = mul(add(var("x"), var("y")), int(0));

        assert_eq!(apply_rule(&subject, &rule), Some(int(0)));
    }

    #[test]
    fn test_power_rules() {
        let zero_rule = common_rules::power_zero();
        let to_the_zero = power(var("x"), int(0));
        assert_eq!(apply_rule(&to_the_zero, &zero_rule), Some(int(1)));

        let one_rule = common_rules::power_one();
        let to_the_one = power(var("x"), int(1));
        assert_eq!(apply_rule(&to_the_one, &one_rule), Some(var("x")));
    }

    #[test]
    fn test_recursive_rule_application() {
        let rule = common_rules::additive_identity();
        let left = add(var("x"), int(0));
        let right = add(var("y"), int(0));
        let subject = add(left, right);

        let simplified = apply_rule_recursive(&subject, &rule);

        assert_eq!(simplified, add(var("x"), var("y")));
    }

    #[test]
    fn test_fixpoint_simplification() {
        let rules = common_rules::all();
        let subject = add(mul(var("x"), int(1)), int(0));

        let simplified = apply_rules_to_fixpoint(&subject, &rules, 10);

        assert_eq!(simplified, var("x"));
    }
}