use std::sync::Arc;
/// A first-order term: either a named variable or an operation applied to argument terms.
///
/// There are no binders, so every variable occurring in a term is free. Constants are
/// represented as an `App` with an empty argument list (see [`Term::constant`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum Term {
    /// A variable, identified by its name.
    Var(Arc<str>),
    /// Application of the operation symbol `op` to `args`, in positional order.
    App {
        op: Arc<str>,
        args: Vec<Self>,
    },
}
impl Term {
    /// Builds a variable term called `name`.
    #[must_use]
    pub fn var(name: impl Into<Arc<str>>) -> Self {
        let name: Arc<str> = name.into();
        Self::Var(name)
    }

    /// Builds the application `op(args...)`.
    #[must_use]
    pub fn app(op: impl Into<Arc<str>>, args: Vec<Self>) -> Self {
        let op: Arc<str> = op.into();
        Self::App { op, args }
    }

    /// Builds a constant, i.e. the nullary application `op()`.
    #[must_use]
    pub fn constant(op: impl Into<Arc<str>>) -> Self {
        Self::app(op, Vec::new())
    }

    /// Replaces each variable that has an entry in `subst` by a clone of its mapped term.
    ///
    /// The substitution is simultaneous: replacement terms are inserted as-is and are not
    /// themselves substituted into. Variables without an entry are kept unchanged.
    #[must_use]
    pub fn substitute(&self, subst: &rustc_hash::FxHashMap<Arc<str>, Self>) -> Self {
        match self {
            Self::Var(name) => match subst.get(name) {
                Some(replacement) => replacement.clone(),
                None => Self::Var(Arc::clone(name)),
            },
            Self::App { op, args } => {
                let mut substituted = Vec::with_capacity(args.len());
                for arg in args {
                    substituted.push(arg.substitute(subst));
                }
                Self::App {
                    op: Arc::clone(op),
                    args: substituted,
                }
            }
        }
    }

    /// Returns the set of variable names occurring anywhere in this term.
    #[must_use]
    pub fn free_vars(&self) -> rustc_hash::FxHashSet<Arc<str>> {
        let mut found = rustc_hash::FxHashSet::default();
        self.collect_vars(&mut found);
        found
    }

    /// Inserts every variable name of `self` into `vars`, visiting arguments left to right.
    fn collect_vars(&self, vars: &mut rustc_hash::FxHashSet<Arc<str>>) {
        // Explicit stack; children are pushed in reverse so the leftmost is visited first.
        let mut pending: Vec<&Self> = vec![self];
        while let Some(node) = pending.pop() {
            match node {
                Self::Var(name) => {
                    vars.insert(Arc::clone(name));
                }
                Self::App { args, .. } => pending.extend(args.iter().rev()),
            }
        }
    }

    /// Renames operation symbols according to `op_map`.
    ///
    /// Operations missing from the map, and all variables, are left unchanged.
    #[must_use]
    pub fn rename_ops(&self, op_map: &std::collections::HashMap<Arc<str>, Arc<str>>) -> Self {
        let Self::App { op, args } = self else {
            return self.clone();
        };
        let renamed_op = match op_map.get(op) {
            Some(renamed) => Arc::clone(renamed),
            None => Arc::clone(op),
        };
        let renamed_args: Vec<Self> = args.iter().map(|arg| arg.rename_ops(op_map)).collect();
        Self::App {
            op: renamed_op,
            args: renamed_args,
        }
    }
}
/// A named equation `lhs = rhs` between two terms.
///
/// Unlike `DirectedEquation`, it carries no attached implementation expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct Equation {
    /// Identifier of the equation.
    pub name: Arc<str>,
    /// Left-hand side term.
    pub lhs: Term,
    /// Right-hand side term.
    pub rhs: Term,
}
impl Equation {
    /// Creates the equation `lhs = rhs` named `name`.
    ///
    /// `name` accepts anything convertible into `Arc<str>` (for example `&str` or `String`).
    #[must_use]
    pub fn new(name: impl Into<Arc<str>>, lhs: Term, rhs: Term) -> Self {
        let name: Arc<str> = name.into();
        Self { name, lhs, rhs }
    }

    /// Returns a copy with operation symbols on both sides renamed through `op_map`.
    ///
    /// The name is shared (reference-count bump), not copied.
    #[must_use]
    pub fn rename_ops(&self, op_map: &std::collections::HashMap<Arc<str>, Arc<str>>) -> Self {
        let rename_side = |side: &Term| side.rename_ops(op_map);
        Self::new(
            Arc::clone(&self.name),
            rename_side(&self.lhs),
            rename_side(&self.rhs),
        )
    }
}
/// A named equation oriented as a rewrite rule `lhs -> rhs`.
///
/// During [`normalize`], `lhs` is used as a match pattern and every match is replaced by
/// `rhs` instantiated with the bindings the match produced.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct DirectedEquation {
    /// Identifier of the rule.
    pub name: Arc<str>,
    /// Pattern side; its variables are bound by `match_pattern` when rewriting.
    pub lhs: Term,
    /// Replacement side, instantiated with the bindings found for `lhs`.
    pub rhs: Term,
    /// Attached `panproto_expr` expression. It is not used by the rewriting code in this
    /// file; presumably an executable implementation of the rule (TODO confirm).
    pub impl_term: panproto_expr::Expr,
    /// Optional second expression, set only by `DirectedEquation::with_inverse`;
    /// presumably the inverse of `impl_term` (TODO confirm).
    pub inverse: Option<panproto_expr::Expr>,
}
/// Returns whether `t1` and `t2` are identical up to a one-to-one renaming of variables.
///
/// Operation symbols and arities must match exactly. Distinct variables in one term must
/// correspond to distinct variables in the other (the renaming is a bijection).
#[must_use]
pub fn alpha_equivalent(t1: &Term, t2: &Term) -> bool {
    let forward = rustc_hash::FxHashMap::default();
    let backward = rustc_hash::FxHashMap::default();
    AlphaChecker { forward, backward }.check(t1, t2)
}
/// Returns whether equation `lhs1 = rhs1` is alpha-equivalent to `lhs2 = rhs2`.
///
/// A single variable bijection must work for both sides at once. Renaming `x`/`y` to
/// `a`/`b` on the left-hand sides therefore constrains the right-hand sides too.
#[must_use]
pub fn alpha_equivalent_equation(lhs1: &Term, rhs1: &Term, lhs2: &Term, rhs2: &Term) -> bool {
    // One checker for both sides, so the bijection built on the lhs carries over to the rhs.
    let mut shared = AlphaChecker {
        forward: Default::default(),
        backward: Default::default(),
    };
    if !shared.check(lhs1, lhs2) {
        return false;
    }
    shared.check(rhs1, rhs2)
}
/// Working state for an alpha-equivalence check.
///
/// It holds a partial bijection between the variable names of the left term and those of
/// the right term, extended as new variable pairs are encountered.
struct AlphaChecker {
    /// Left-term variable name -> right-term variable name.
    forward: rustc_hash::FxHashMap<Arc<str>, Arc<str>>,
    /// Right-term variable name -> left-term variable name (kept as the inverse of `forward`).
    backward: rustc_hash::FxHashMap<Arc<str>, Arc<str>>,
}
impl AlphaChecker {
    /// Returns whether `t1` and `t2` agree up to the variable bijection stored in `self`.
    ///
    /// New variable pairs are added to the bijection as they are met. Operation symbols and
    /// arities must match exactly, and a variable never matches an application. After a
    /// `false` result the maps may contain partial bindings and should be discarded.
    fn check(&mut self, t1: &Term, t2: &Term) -> bool {
        match (t1, t2) {
            (Term::Var(left), Term::Var(right)) => {
                // `left` is already paired: it must be paired with exactly `right`.
                if let Some(paired) = self.forward.get(left) {
                    return paired == right;
                }
                // `right` is already claimed by a different left variable: not injective.
                if let Some(paired) = self.backward.get(right) {
                    return paired == left;
                }
                self.forward.insert(Arc::clone(left), Arc::clone(right));
                self.backward.insert(Arc::clone(right), Arc::clone(left));
                true
            }
            (
                Term::App {
                    op: left_op,
                    args: left_args,
                },
                Term::App {
                    op: right_op,
                    args: right_args,
                },
            ) => {
                if left_op != right_op || left_args.len() != right_args.len() {
                    return false;
                }
                left_args
                    .iter()
                    .zip(right_args)
                    .all(|(l, r)| self.check(l, r))
            }
            (Term::Var(_), Term::App { .. }) | (Term::App { .. }, Term::Var(_)) => false,
        }
    }
}
/// Matches `term` against `pattern` and returns the variable bindings on success.
///
/// Pattern variables bind to whole subterms of `term`. A variable that occurs more than
/// once must bind to equal subterms every time. Variables inside `term` are treated as
/// opaque and are only matched by pattern variables. Returns `None` when no match exists.
#[must_use]
pub fn match_pattern(pattern: &Term, term: &Term) -> Option<rustc_hash::FxHashMap<Arc<str>, Term>> {
    let mut bindings = rustc_hash::FxHashMap::default();
    let matched = match_pattern_inner(pattern, term, &mut bindings);
    matched.then_some(bindings)
}
/// Recursive worker for [`match_pattern`]: matches `term` against `pattern`, recording
/// bindings in `subst`.
///
/// Returns `false` on the first mismatch. In that case `subst` may hold partial bindings,
/// which the caller discards.
fn match_pattern_inner(
    pattern: &Term,
    term: &Term,
    subst: &mut rustc_hash::FxHashMap<Arc<str>, Term>,
) -> bool {
    match (pattern, term) {
        (Term::Var(name), _) => {
            // Single lookup: either check an existing binding or create a new one.
            if let Some(bound) = subst.get(name) {
                // Non-linear pattern: a repeated variable must see the same subterm again.
                bound == term
            } else {
                subst.insert(Arc::clone(name), term.clone());
                true
            }
        }
        (
            Term::App {
                op: p_op,
                args: p_args,
            },
            Term::App {
                op: t_op,
                args: t_args,
            },
        ) => {
            p_op == t_op
                && p_args.len() == t_args.len()
                && p_args
                    .iter()
                    .zip(t_args)
                    .all(|(p, t)| match_pattern_inner(p, t, subst))
        }
        // An operation in the pattern never matches a variable in the term.
        (Term::App { .. }, Term::Var(_)) => false,
    }
}
/// Rewrites `term` with the rules in `directed_eqs` until it stops changing or the rewrite
/// budget is spent.
///
/// Each pass calls `normalize_once`, which rewrites arguments before their parent and, at
/// each position, applies the first rule (in slice order) whose `lhs` matches. All passes
/// share a single step counter. The loop returns as soon as a pass leaves the term unchanged
/// or the counter has reached `max_steps`; with `max_steps == 0` a clone of `term` is
/// returned.
///
/// Termination comes from the budget, not from the rules. With a non-terminating rule set
/// (e.g. `f(x) -> f(f(x))`) the result is whatever term was reached when the budget ran out.
#[must_use]
pub fn normalize(term: &Term, directed_eqs: &[DirectedEquation], max_steps: usize) -> Term {
    let mut current = term.clone();
    let mut steps = 0;
    loop {
        let next = normalize_once(&current, directed_eqs, &mut steps, max_steps);
        if next == current || steps >= max_steps {
            return next;
        }
        current = next;
    }
}
/// One bottom-up rewriting pass over `term`, charging each rewrite to `*steps`.
///
/// Arguments are processed first, left to right. Then the first rule in `directed_eqs`
/// whose `lhs` matches the resulting term is applied at the root, and the rewritten term is
/// processed again. At most `max_steps - *steps` rewrites are performed. Once the counter
/// reaches `max_steps`, the term reached so far is returned without further rewriting.
fn normalize_once(
    term: &Term,
    directed_eqs: &[DirectedEquation],
    steps: &mut usize,
    max_steps: usize,
) -> Term {
    if *steps >= max_steps {
        return term.clone();
    }
    let normalized_subterms = match term {
        Term::Var(_) => term.clone(),
        Term::App { op, args } => {
            let new_args: Vec<Term> = args
                .iter()
                .map(|a| normalize_once(a, directed_eqs, steps, max_steps))
                .collect();
            Term::App {
                op: Arc::clone(op),
                args: new_args,
            }
        }
    };
    // Rewriting the arguments may have used up the budget. Re-check before a root rewrite;
    // otherwise each enclosing frame could add one rewrite past `max_steps`.
    if *steps >= max_steps {
        return normalized_subterms;
    }
    for de in directed_eqs {
        if let Some(subst) = match_pattern(&de.lhs, &normalized_subterms) {
            *steps += 1;
            let rewritten = de.rhs.substitute(&subst);
            return normalize_once(&rewritten, directed_eqs, steps, max_steps);
        }
    }
    normalized_subterms
}
impl DirectedEquation {
    /// Creates the rewrite rule `lhs -> rhs` named `name`, with no inverse.
    #[must_use]
    pub fn new(
        name: impl Into<Arc<str>>,
        lhs: Term,
        rhs: Term,
        impl_term: panproto_expr::Expr,
    ) -> Self {
        let name: Arc<str> = name.into();
        Self {
            name,
            lhs,
            rhs,
            impl_term,
            inverse: None,
        }
    }

    /// Like [`DirectedEquation::new`], but also records `inverse`.
    #[must_use]
    pub fn with_inverse(
        name: impl Into<Arc<str>>,
        lhs: Term,
        rhs: Term,
        impl_term: panproto_expr::Expr,
        inverse: panproto_expr::Expr,
    ) -> Self {
        Self {
            inverse: Some(inverse),
            ..Self::new(name, lhs, rhs, impl_term)
        }
    }

    /// Returns a copy with operation symbols in `lhs` and `rhs` renamed through `op_map`.
    ///
    /// `impl_term` and `inverse` are cloned unchanged.
    #[must_use]
    pub fn rename_ops(&self, op_map: &std::collections::HashMap<Arc<str>, Arc<str>>) -> Self {
        let name = Arc::clone(&self.name);
        let lhs = self.lhs.rename_ops(op_map);
        let rhs = self.rhs.rename_ops(op_map);
        Self {
            name,
            lhs,
            rhs,
            impl_term: self.impl_term.clone(),
            inverse: self.inverse.clone(),
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // Terse constructors so each test reads like the term it builds.
    fn v(name: &str) -> Term {
        Term::var(name)
    }

    fn c(name: &str) -> Term {
        Term::constant(name)
    }

    fn ap(op: &str, args: Vec<Term>) -> Term {
        Term::app(op, args)
    }

    #[test]
    fn term_substitution() {
        let original = ap("add", vec![v("x"), c("zero")]);
        let subst: rustc_hash::FxHashMap<Arc<str>, Term> =
            std::iter::once((Arc::from("x"), v("y"))).collect();
        assert_eq!(
            original.substitute(&subst),
            ap("add", vec![v("y"), c("zero")])
        );
    }

    #[test]
    fn free_variables() {
        let vars = ap("f", vec![v("x"), v("y")]).free_vars();
        assert_eq!(vars.len(), 2);
        for name in ["x", "y"] {
            assert!(vars.contains(name));
        }
    }

    #[test]
    fn alpha_eq_same_vars() {
        let term = ap("f", vec![v("x"), v("y")]);
        assert!(alpha_equivalent(&term, &term.clone()));
    }

    #[test]
    fn alpha_eq_renamed_vars() {
        let left = ap("f", vec![v("x"), v("y")]);
        let right = ap("f", vec![v("a"), v("b")]);
        assert!(alpha_equivalent(&left, &right));
    }

    #[test]
    fn alpha_eq_non_injective_rejected() {
        let left = ap("f", vec![v("x"), v("x")]);
        let right = ap("f", vec![v("a"), v("b")]);
        assert!(!alpha_equivalent(&left, &right));
    }

    #[test]
    fn alpha_eq_non_surjective_rejected() {
        let left = ap("f", vec![v("a"), v("b")]);
        let right = ap("f", vec![v("x"), v("x")]);
        assert!(!alpha_equivalent(&left, &right));
    }

    #[test]
    fn alpha_eq_different_ops() {
        assert!(!alpha_equivalent(
            &ap("f", vec![v("x")]),
            &ap("g", vec![v("x")])
        ));
    }

    #[test]
    fn alpha_eq_different_structure() {
        let left = ap("f", vec![v("x"), ap("g", vec![v("y")])]);
        let right = ap("f", vec![v("x"), v("y")]);
        assert!(!alpha_equivalent(&left, &right));
    }

    #[test]
    fn alpha_eq_constants() {
        let left = ap("f", vec![c("c")]);
        let right = ap("f", vec![c("c")]);
        assert!(alpha_equivalent(&left, &right));
    }

    #[test]
    fn alpha_eq_constants_differ() {
        assert!(!alpha_equivalent(
            &ap("f", vec![c("c")]),
            &ap("f", vec![c("d")])
        ));
    }

    #[test]
    fn alpha_eq_nested_renamed() {
        let left = ap(
            "f",
            vec![ap("g", vec![v("x"), v("y")]), ap("h", vec![v("y"), v("x")])],
        );
        let right = ap(
            "f",
            vec![ap("g", vec![v("a"), v("b")]), ap("h", vec![v("b"), v("a")])],
        );
        assert!(alpha_equivalent(&left, &right));
    }

    #[test]
    fn alpha_eq_equation_shared_bijection() {
        let (lhs1, rhs1) = (ap("f", vec![v("x"), v("y")]), ap("g", vec![v("y"), v("x")]));
        let (lhs2, rhs2) = (ap("f", vec![v("a"), v("b")]), ap("g", vec![v("b"), v("a")]));
        assert!(alpha_equivalent_equation(&lhs1, &rhs1, &lhs2, &rhs2));
    }

    #[test]
    fn alpha_eq_equation_inconsistent_bijection() {
        let (lhs1, rhs1) = (ap("f", vec![v("x"), v("y")]), ap("g", vec![v("y")]));
        let (lhs2, rhs2) = (ap("f", vec![v("a"), v("b")]), ap("g", vec![v("a")]));
        assert!(!alpha_equivalent_equation(&lhs1, &rhs1, &lhs2, &rhs2));
    }

    #[test]
    fn match_pattern_var_binds() {
        let target = ap("f", vec![c("a")]);
        let bindings = match_pattern(&v("x"), &target).unwrap();
        assert_eq!(bindings.get("x"), Some(&target));
    }

    #[test]
    fn match_pattern_op_matches() {
        let pat = ap("f", vec![v("x"), v("y")]);
        let target = ap("f", vec![c("a"), c("b")]);
        let bindings = match_pattern(&pat, &target).unwrap();
        assert_eq!(bindings.get("x"), Some(&c("a")));
        assert_eq!(bindings.get("y"), Some(&c("b")));
    }

    #[test]
    fn match_pattern_op_mismatch() {
        let pat = ap("f", vec![v("x")]);
        assert!(match_pattern(&pat, &ap("g", vec![c("a")])).is_none());
    }

    #[test]
    fn match_pattern_repeated_var_consistent() {
        let pat = ap("f", vec![v("x"), v("x")]);
        assert!(match_pattern(&pat, &ap("f", vec![c("a"), c("a")])).is_some());
    }

    #[test]
    fn match_pattern_repeated_var_inconsistent() {
        let pat = ap("f", vec![v("x"), v("x")]);
        assert!(match_pattern(&pat, &ap("f", vec![c("a"), c("b")])).is_none());
    }

    fn make_directed_eq(name: &str, lhs: Term, rhs: Term) -> DirectedEquation {
        DirectedEquation::new(name, lhs, rhs, panproto_expr::Expr::Var("_".into()))
    }

    /// The rule `add(zero, y) -> y`, shared by the normalization tests below.
    fn left_id_rule() -> DirectedEquation {
        make_directed_eq("left_id", ap("add", vec![c("zero"), v("y")]), v("y"))
    }

    #[test]
    fn normalize_no_rules() {
        let input = ap("f", vec![v("x")]);
        assert_eq!(normalize(&input, &[], 100), input);
    }

    #[test]
    fn normalize_simple_rewrite() {
        let input = ap("add", vec![c("zero"), v("x")]);
        assert_eq!(normalize(&input, &[left_id_rule()], 100), v("x"));
    }

    #[test]
    fn normalize_nested() {
        let input = ap("f", vec![ap("add", vec![c("zero"), v("x")])]);
        assert_eq!(
            normalize(&input, &[left_id_rule()], 100),
            ap("f", vec![v("x")])
        );
    }

    #[test]
    fn normalize_multi_step() {
        let inner = ap("add", vec![c("zero"), v("x")]);
        let input = ap("add", vec![c("zero"), inner]);
        assert_eq!(normalize(&input, &[left_id_rule()], 100), v("x"));
    }

    #[test]
    fn normalize_max_steps_guard() {
        let expand = make_directed_eq(
            "expand",
            ap("f", vec![v("x")]),
            ap("f", vec![ap("f", vec![v("x")])]),
        );
        let out = normalize(&ap("f", vec![c("a")]), &[expand], 5);
        assert!(matches!(out, Term::App { .. }));
    }

    #[test]
    fn alpha_eq_var_vs_app() {
        assert!(!alpha_equivalent(&v("x"), &c("c")));
    }

    #[test]
    fn alpha_eq_arity_mismatch() {
        let left = ap("f", vec![v("x")]);
        let right = ap("f", vec![v("x"), v("y")]);
        assert!(!alpha_equivalent(&left, &right));
    }
}