rssn 0.2.9

A comprehensive scientific computing library for Rust, aiming for feature parity with NumPy and SymPy.
Documentation
#![allow(clippy::match_same_arms)]

//! # Term Rewriting Systems
//!
//! This module provides tools for working with term rewriting systems (TRS).
//! It includes the `RewriteRule` type, functions that apply a set of rules to an
//! expression (a single rewrite step, or repeatedly until a normal form is reached),
//! and an implementation of the Knuth-Bendix completion algorithm, which attempts to
//! convert a set of equations into a confluent and Noetherian (terminating) term
//! rewriting system.

use std::collections::HashMap;

use serde::Deserialize;
use serde::Serialize;

use crate::symbolic::calculus::substitute;
use crate::symbolic::core::Expr;
use crate::symbolic::polynomial::contains_var;
use crate::symbolic::simplify_dag::pattern_match;
use crate::symbolic::simplify_dag::substitute_patterns;

/// A directed rewrite rule `lhs -> rhs`.
///
/// When `lhs` matches an expression (or one of its subexpressions), the pattern
/// bindings produced by that match are substituted into `rhs` to build the
/// replacement.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RewriteRule {
    /// Pattern selecting the (sub)expressions this rule rewrites.
    pub lhs: Expr,
    /// Replacement template; patterns bound while matching `lhs` are filled in here.
    pub rhs: Expr,
}

/// Size-based term ordering used to orient equations into rewrite rules.
///
/// Returns `true` exactly when `e1` has a strictly larger `complexity` (weighted
/// node count) than `e2`; expressions of equal size are incomparable.
pub(crate) fn is_greater(
    e1: &Expr,
    e2: &Expr,
) -> bool {
    let (left_size, right_size) = (complexity(e1), complexity(e2));

    right_size < left_size
}

/// Rewrites `expr` with `rules` until no rule applies anywhere, returning the
/// resulting normal form.
///
/// Each iteration performs one rewrite step via `apply_rules_once`; the loop stops
/// as soon as a step reports that nothing changed.
///
/// # Arguments
/// * `expr` - The expression to transform.
/// * `rules` - The rewrite rules, tried in slice order at each step.
///
/// # Returns
/// The normal form of `expr` with respect to `rules`.
///
/// There is no iteration bound: if `rules` admit an infinite rewrite sequence from
/// `expr`, this function does not return.
#[must_use]
pub fn apply_rules_to_normal_form(
    expr: &Expr,
    rules: &[RewriteRule],
) -> Expr {
    let mut normal = expr.clone();

    loop {
        match apply_rules_once(&normal, rules) {
            | (rewritten, true) => normal = rewritten,
            | (_, false) => return normal,
        }
    }
}

/// Applies the first applicable rule to the expression tree in a pre-order traversal.
pub(crate) fn apply_rules_once(
    expr: &Expr,
    rules: &[RewriteRule],
) -> (Expr, bool) {
    for rule in rules {
        if let Some(assignments) = pattern_match(expr, &rule.lhs) {
            return (substitute_patterns(&rule.rhs, &assignments), true);
        }
    }

    match expr {
        | Expr::Dag(node) => {
            return apply_rules_once(
                &node.to_expr().expect(
                    "Apply rules \
                         once",
                ),
                rules,
            );
        },
        | Expr::Add(a, b) => {
            let (na, ca) = apply_rules_once(a, rules);

            if ca {
                return (Expr::new_add(na, b.clone()), true);
            }

            let (nb, cb) = apply_rules_once(b, rules);

            if cb {
                return (Expr::new_add(a.clone(), nb), true);
            }
        },
        | Expr::Mul(a, b) => {
            let (na, ca) = apply_rules_once(a, rules);

            if ca {
                return (Expr::new_mul(na, b.clone()), true);
            }

            let (nb, cb) = apply_rules_once(b, rules);

            if cb {
                return (Expr::new_mul(a.clone(), nb), true);
            }
        },
        | _ => {},
    }

    (expr.clone(), false)
}

/// Attempts to produce a complete term-rewriting system from a set of equations
/// using the Knuth-Bendix completion algorithm.
///
/// Each input equation is oriented with the size ordering of `is_greater`: the side
/// with the larger `complexity` becomes the left-hand side. Equations whose sides have
/// equal complexity cannot be oriented this way and are skipped.
///
/// Completion then examines every pair of rules exactly once. Overlaps are
/// directional (subterms of one rule's `lhs` against the other rule's `lhs`), so for
/// two distinct rules critical pairs are computed in both orders. Each critical pair
/// is reduced to normal form with the current rules. If the two normal forms differ,
/// they are oriented (larger side first; ties become `n2 -> n1`) and appended as a new
/// rule, unless a rule with that left-hand side already exists. Appended rules are
/// paired with all earlier rules when the outer loop reaches them.
///
/// There is no iteration bound: completion may not terminate for some inputs.
/// NOTE(review): tied pairs are oriented arbitrarily, which the size ordering does not
/// guarantee to be terminating; consider reporting failure instead.
///
/// # Arguments
/// * `equations` - A slice of `Expr::Eq` representing the initial equations.
///
/// # Returns
/// A `Result` containing a `Vec<RewriteRule>` if the completion is successful.
///
/// # Errors
///
/// This function will return an error if any element in the `equations` slice is
/// not a valid `Expr::Eq`.
pub fn knuth_bendix(equations: &[Expr]) -> Result<Vec<RewriteRule>, String> {
    let mut rules: Vec<RewriteRule> = Vec::new();

    for eq in equations {
        match eq {
            | Expr::Eq(lhs, rhs) => {
                if is_greater(lhs, rhs) {
                    rules.push(RewriteRule {
                        lhs: lhs.as_ref().clone(),
                        rhs: rhs.as_ref().clone(),
                    });
                } else if is_greater(rhs, lhs) {
                    rules.push(RewriteRule {
                        lhs: rhs.as_ref().clone(),
                        rhs: lhs.as_ref().clone(),
                    });
                }
                // Equal complexity: not orientable by the size ordering, so skipped.
            },
            | _ => {
                return Err("Input must be a list of equations (Expr::Eq).".to_string());
            },
        }
    }

    // Rule `i` is paired with every `j <= i` exactly once. Rules pushed below get larger
    // indices and are paired with all earlier rules once `i` reaches them. A restart is
    // unnecessary because a pair that is joinable stays joinable when rules are added.
    let mut i = 0;

    while i < rules.len() {
        for j in 0..=i {
            let mut critical_pairs = find_critical_pairs(&rules[i], &rules[j]);

            if i != j {
                critical_pairs.extend(find_critical_pairs(&rules[j], &rules[i]));
            }

            for (t1, t2) in critical_pairs {
                let n1 = apply_rules_to_normal_form(&t1, &rules);

                let n2 = apply_rules_to_normal_form(&t2, &rules);

                if n1 == n2 {
                    continue;
                }

                let new_rule = if is_greater(&n1, &n2) {
                    RewriteRule { lhs: n1, rhs: n2 }
                } else {
                    RewriteRule { lhs: n2, rhs: n1 }
                };

                if !rules.iter().any(|r| r.lhs == new_rule.lhs) {
                    rules.push(new_rule);
                }
            }
        }

        i += 1;
    }

    Ok(rules)
}

/// Finds critical pairs between two rewrite rules.
pub(crate) fn find_critical_pairs(
    r1: &RewriteRule,
    r2: &RewriteRule,
) -> Vec<(Expr, Expr)> {
    let mut pairs = Vec::new();

    let mut sub_expressions = Vec::new();

    r1.lhs.pre_order_walk(&mut |sub_expr| {
        sub_expressions.push(sub_expr.clone());
    });

    for sub_expr in &sub_expressions {
        if let Some(subst) = unify(sub_expr, &r2.lhs) {
            let t1 = substitute(&r1.lhs, &sub_expr.to_string(), &r2.rhs);

            let t1_subst = substitute_patterns(&t1, &subst);

            let t2 = substitute_patterns(&r1.rhs, &subst);

            if t1_subst != t2 {
                pairs.push((t1_subst, t2));
            }
        }
    }

    pairs
}

/// Attempts to unify `e1` with `e2`.
///
/// Returns the pattern bindings (pattern name → expression) found by
/// `unify_recursive`, or `None` if it reports failure. Partial bindings from a
/// failed attempt are discarded.
pub(crate) fn unify(
    e1: &Expr,
    e2: &Expr,
) -> Option<HashMap<String, Expr>> {
    let mut bindings = HashMap::new();

    let unified = unify_recursive(e1, e2, &mut bindings);

    unified.then_some(bindings)
}

/// Recursively unifies `e1` with `e2`, recording pattern bindings in `subst`.
///
/// * A `Pattern` on either side (the left side is checked first) binds to the other
///   expression. If the pattern is already bound, the binding must be `==` to the other
///   expression. Binding is refused when `contains_var` finds the pattern's name in
///   the other expression.
/// * `Add` and `Mul` are treated as commutative. The arguments are first unified in the
///   given order; if that fails, `subst` is restored to its state on entry and the
///   swapped order is tried.
/// * `Sub`, `Div`, `Power`, `Sin`, `Cos`, `Tan`, `Log`, `Exp` and `Neg` unify
///   argument-wise, in order.
/// * Anything else unifies only if the two expressions are equal.
///
/// On failure `subst` may contain partial bindings; callers should discard it, as
/// `unify` does.
///
/// NOTE(review): an already-bound pattern is compared with `==` rather than unified
/// further, so e.g. `x + x` against `a + y` fails. This is stricter than syntactic
/// unification. The occurs check goes through `contains_var`, which presumably looks
/// for a variable (not a pattern) of that name; confirm.
pub(crate) fn unify_recursive(
    e1: &Expr,
    e2: &Expr,
    subst: &mut HashMap<String, Expr>,
) -> bool {
    match (e1, e2) {
        | (Expr::Pattern(name), other) | (other, Expr::Pattern(name)) => {
            if let Some(bound) = subst.get(name) {
                return bound == other;
            }

            let occurs = contains_var(other, name);

            if !occurs {
                subst.insert(name.clone(), other.clone());
            }

            !occurs
        },
        | (Expr::Add(l1, r1), Expr::Add(l2, r2)) | (Expr::Mul(l1, r1), Expr::Mul(l2, r2)) => {
            let snapshot = subst.clone();

            if unify_recursive(l1, l2, subst) && unify_recursive(r1, r2, subst) {
                return true;
            }

            // Commutative retry, starting again from the bindings we had on entry.
            *subst = snapshot;

            unify_recursive(l1, r2, subst) && unify_recursive(r1, l2, subst)
        },
        | (Expr::Sub(l1, r1), Expr::Sub(l2, r2))
        | (Expr::Div(l1, r1), Expr::Div(l2, r2))
        | (Expr::Power(l1, r1), Expr::Power(l2, r2)) => {
            unify_recursive(l1, l2, subst) && unify_recursive(r1, r2, subst)
        },
        | (Expr::Sin(x1), Expr::Sin(x2))
        | (Expr::Cos(x1), Expr::Cos(x2))
        | (Expr::Tan(x1), Expr::Tan(x2))
        | (Expr::Log(x1), Expr::Log(x2))
        | (Expr::Exp(x1), Expr::Exp(x2))
        | (Expr::Neg(x1), Expr::Neg(x2)) => unify_recursive(x1, x2, subst),
        | _ => e1 == e2,
    }
}

/// Returns a weighted node count for `expr`; this is the measure `is_greater` compares.
///
/// Every node contributes 1, except `Power`, which contributes 2. `Dag` nodes are
/// measured through their `to_expr` tree form. Leaves and node kinds not listed
/// below count as 1.
///
/// # Panics
/// Panics if a `Dag` node cannot be converted with `to_expr`.
pub(crate) fn complexity(expr: &Expr) -> usize {
    match expr {
        | Expr::Dag(node) => {
            let tree = node.to_expr().expect("Complexity");

            complexity(&tree)
        },
        | Expr::Power(base, exponent) => 2 + complexity(base) + complexity(exponent),
        | Expr::Add(lhs, rhs) | Expr::Mul(lhs, rhs) | Expr::Sub(lhs, rhs) | Expr::Div(lhs, rhs) => {
            1 + complexity(lhs) + complexity(rhs)
        },
        | Expr::BinaryList(_, lhs, rhs) => 1 + complexity(lhs) + complexity(rhs),
        | Expr::Neg(inner)
        | Expr::Sin(inner)
        | Expr::Cos(inner)
        | Expr::Tan(inner)
        | Expr::Log(inner)
        | Expr::Exp(inner) => 1 + complexity(inner),
        | Expr::UnaryList(_, inner) => 1 + complexity(inner),
        | Expr::NaryList(_, items) => items.iter().fold(1, |total, item| total + complexity(item)),
        | _ => 1,
    }
}