use std::sync::Arc;
/// A first-order term: either a variable or an operation symbol applied to arguments.
///
/// Constants are represented as [`Term::App`] with an empty argument list
/// (see [`Term::constant`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum Term {
    /// A variable, identified by its name.
    Var(Arc<str>),
    /// The operation symbol `op` applied to the argument terms `args`.
    App {
        op: Arc<str>,
        args: Vec<Self>,
    },
}
impl Term {
    /// Builds a variable term called `name`.
    #[must_use]
    pub fn var(name: impl Into<Arc<str>>) -> Self {
        let name: Arc<str> = name.into();
        Self::Var(name)
    }

    /// Builds the application of operation `op` to the terms in `args`.
    #[must_use]
    pub fn app(op: impl Into<Arc<str>>, args: Vec<Self>) -> Self {
        let op: Arc<str> = op.into();
        Self::App { op, args }
    }

    /// Builds a constant, i.e. an application of `op` to zero arguments.
    #[must_use]
    pub fn constant(op: impl Into<Arc<str>>) -> Self {
        Self::app(op, Vec::new())
    }

    /// Replaces every variable that has an entry in `subst` by a clone of its image.
    ///
    /// Variables without an entry are kept. Images are inserted as-is: they are not
    /// themselves substituted into.
    #[must_use]
    pub fn substitute(&self, subst: &rustc_hash::FxHashMap<Arc<str>, Self>) -> Self {
        match self {
            Self::Var(name) => match subst.get(name) {
                Some(image) => image.clone(),
                None => self.clone(),
            },
            Self::App { op, args } => {
                let substituted: Vec<Self> =
                    args.iter().map(|arg| arg.substitute(subst)).collect();
                Self::App {
                    op: Arc::clone(op),
                    args: substituted,
                }
            }
        }
    }

    /// Returns the set of variable names occurring anywhere in this term.
    #[must_use]
    pub fn free_vars(&self) -> rustc_hash::FxHashSet<Arc<str>> {
        let mut found = rustc_hash::FxHashSet::default();
        self.collect_vars(&mut found);
        found
    }

    /// Inserts the name of every variable occurring in this term into `vars`.
    fn collect_vars(&self, vars: &mut rustc_hash::FxHashSet<Arc<str>>) {
        // Explicit work-list instead of recursion. Arguments are pushed in reverse so
        // they are popped left-to-right, i.e. variables are visited in preorder.
        let mut pending: Vec<&Self> = vec![self];
        while let Some(current) = pending.pop() {
            match current {
                Self::Var(name) => {
                    vars.insert(Arc::clone(name));
                }
                Self::App { args, .. } => pending.extend(args.iter().rev()),
            }
        }
    }

    /// Renames operation symbols according to `op_map`, recursively through all arguments.
    ///
    /// Operations without an entry in `op_map` keep their name; variables are never renamed.
    #[must_use]
    pub fn rename_ops(&self, op_map: &std::collections::HashMap<Arc<str>, Arc<str>>) -> Self {
        let Self::App { op, args } = self else {
            return self.clone();
        };
        let renamed_op = op_map.get(op).map_or_else(|| Arc::clone(op), Arc::clone);
        Self::App {
            op: renamed_op,
            args: args.iter().map(|arg| arg.rename_ops(op_map)).collect(),
        }
    }
}
/// A named equation between two terms, `lhs = rhs`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct Equation {
    /// Name identifying the equation.
    pub name: Arc<str>,
    /// Left-hand side.
    pub lhs: Term,
    /// Right-hand side.
    pub rhs: Term,
}
impl Equation {
    /// Creates an equation called `name` relating `lhs` and `rhs`.
    #[must_use]
    pub fn new(name: impl Into<Arc<str>>, lhs: Term, rhs: Term) -> Self {
        let name: Arc<str> = name.into();
        Self { name, lhs, rhs }
    }

    /// Returns a copy of this equation whose two sides have their operation symbols
    /// renamed via [`Term::rename_ops`]; the equation keeps its name.
    #[must_use]
    pub fn rename_ops(&self, op_map: &std::collections::HashMap<Arc<str>, Arc<str>>) -> Self {
        let rename = |side: &Term| side.rename_ops(op_map);
        Self {
            name: Arc::clone(&self.name),
            lhs: rename(&self.lhs),
            rhs: rename(&self.rhs),
        }
    }
}
/// A named, oriented rewrite rule `lhs -> rhs`.
///
/// [`normalize`] uses `lhs` as a match pattern and replaces a matching term by `rhs`
/// instantiated with the resulting bindings.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct DirectedEquation {
    /// Name of the rule.
    pub name: Arc<str>,
    /// Pattern matched against terms during normalization.
    pub lhs: Term,
    /// Replacement, instantiated with the bindings obtained by matching `lhs`.
    pub rhs: Term,
    /// Expression attached to the rule. NOTE(review): presumably the executable
    /// implementation of the rewrite; it is not read in this file, so confirm.
    pub impl_term: panproto_expr::Expr,
    /// Optional inverse expression; `DirectedEquation::with_inverse` sets it.
    pub inverse: Option<panproto_expr::Expr>,
    /// Source-side value kind, if recorded (set by `DirectedEquation::with_kinds`).
    pub source_kind: Option<crate::sort::ValueKind>,
    /// Target-side value kind, if recorded (set by `DirectedEquation::with_kinds`).
    pub target_kind: Option<crate::sort::ValueKind>,
    /// Coercion class. `new` sets `Opaque`, `with_inverse` sets `Retraction`, and
    /// `with_kinds` overrides it.
    pub coercion_class: crate::sort::CoercionClass,
}
/// Returns `true` when `t1` and `t2` are identical up to a consistent one-to-one renaming
/// of variables. Operation symbols and arities must match exactly.
#[must_use]
pub fn alpha_equivalent(t1: &Term, t2: &Term) -> bool {
    AlphaChecker {
        forward: rustc_hash::FxHashMap::default(),
        backward: rustc_hash::FxHashMap::default(),
    }
    .check(t1, t2)
}
/// Returns `true` when the equations `lhs1 = rhs1` and `lhs2 = rhs2` are alpha-equivalent
/// under one variable bijection shared by both sides.
///
/// The left-hand sides are compared first. The right-hand sides are compared only if
/// the left-hand sides match, reusing the bijection built so far.
#[must_use]
pub fn alpha_equivalent_equation(lhs1: &Term, rhs1: &Term, lhs2: &Term, rhs2: &Term) -> bool {
    let mut checker = AlphaChecker {
        forward: rustc_hash::FxHashMap::default(),
        backward: rustc_hash::FxHashMap::default(),
    };
    if !checker.check(lhs1, lhs2) {
        return false;
    }
    checker.check(rhs1, rhs2)
}
/// Incrementally built variable bijection used by the alpha-equivalence checks.
///
/// `forward` maps variables of the first term to variables of the second, and `backward`
/// holds the inverse mapping. Both are updated together, so each variable is paired with
/// at most one partner.
struct AlphaChecker {
    forward: rustc_hash::FxHashMap<Arc<str>, Arc<str>>,
    backward: rustc_hash::FxHashMap<Arc<str>, Arc<str>>,
}
impl AlphaChecker {
    /// Checks that `t1` and `t2` agree structurally while extending the variable bijection.
    ///
    /// Mappings recorded here persist in `self`, so consecutive calls share one bijection.
    /// A failed check may leave partial mappings behind.
    fn check(&mut self, t1: &Term, t2: &Term) -> bool {
        match (t1, t2) {
            (Term::Var(a), Term::Var(b)) => self.bind(a, b),
            (
                Term::App {
                    op: op1,
                    args: args1,
                },
                Term::App {
                    op: op2,
                    args: args2,
                },
            ) => {
                if op1 != op2 || args1.len() != args2.len() {
                    return false;
                }
                args1
                    .iter()
                    .zip(args2)
                    .all(|(left, right)| self.check(left, right))
            }
            (Term::Var(_), Term::App { .. }) | (Term::App { .. }, Term::Var(_)) => false,
        }
    }

    /// Checks the pairing `a <-> b` against the bijection, recording it when neither
    /// variable has been paired yet.
    fn bind(&mut self, a: &Arc<str>, b: &Arc<str>) -> bool {
        if let Some(known) = self.forward.get(a) {
            return known == b;
        }
        if let Some(known) = self.backward.get(b) {
            return known == a;
        }
        self.forward.insert(Arc::clone(a), Arc::clone(b));
        self.backward.insert(Arc::clone(b), Arc::clone(a));
        true
    }
}
/// Attempts to match `pattern` against `term`.
///
/// Pattern variables bind to whole subterms of `term`. A variable occurring several times
/// in `pattern` must bind to equal subterms. Variables inside `term` can only be matched
/// by pattern variables, never by a pattern application. Returns the bindings on success
/// and `None` otherwise.
#[must_use]
pub fn match_pattern(pattern: &Term, term: &Term) -> Option<rustc_hash::FxHashMap<Arc<str>, Term>> {
    let mut bindings = rustc_hash::FxHashMap::default();
    match_pattern_inner(pattern, term, &mut bindings).then_some(bindings)
}
/// Recursive worker for [`match_pattern`].
///
/// Extends `subst` with bindings that make `pattern` equal to `term` and returns `false`
/// at the first mismatch. On failure `subst` may hold partial bindings; [`match_pattern`]
/// discards it in that case.
fn match_pattern_inner(
    pattern: &Term,
    term: &Term,
    subst: &mut rustc_hash::FxHashMap<Arc<str>, Term>,
) -> bool {
    match pattern {
        Term::Var(name) => {
            // Single lookup. A repeated (non-linear) pattern variable must bind to an
            // identical subterm; a fresh variable binds to `term`.
            if let Some(existing) = subst.get(name) {
                existing == term
            } else {
                subst.insert(Arc::clone(name), term.clone());
                true
            }
        }
        Term::App {
            op: p_op,
            args: p_args,
        } => match term {
            Term::App {
                op: t_op,
                args: t_args,
            } => {
                p_op == t_op
                    && p_args.len() == t_args.len()
                    && p_args
                        .iter()
                        .zip(t_args.iter())
                        .all(|(p, t)| match_pattern_inner(p, t, subst))
            }
            // An application pattern never matches a bare variable of the subject term.
            Term::Var(_) => false,
        },
    }
}
/// Rewrites `term` with the oriented rules in `directed_eqs` until a pass leaves it
/// unchanged or the rewrite budget `max_steps` is used up.
///
/// Rules are tried in slice order, and arguments are rewritten before their parent. When
/// the budget runs out, the partially rewritten term is returned, so the result is not
/// guaranteed to be a normal form in that case.
#[must_use]
pub fn normalize(term: &Term, directed_eqs: &[DirectedEquation], max_steps: usize) -> Term {
    let mut current = term.clone();
    // Rewrite counter shared by every pass of this call.
    let mut steps = 0;
    loop {
        let next = normalize_once(&current, directed_eqs, &mut steps, max_steps);
        // A pass that leaves the term unchanged is treated as a fixpoint.
        if next == current || steps >= max_steps {
            return next;
        }
        current = next;
    }
}
/// One bottom-up rewriting pass used by [`normalize`].
///
/// Arguments are rewritten first. Then the first rule whose `lhs` matches the resulting
/// term is applied at the root, and its result is rewritten again. `steps` counts applied
/// rewrites across the whole [`normalize`] call. No rewrite is applied once it reaches
/// `max_steps`, so the total never exceeds the budget.
fn normalize_once(
    term: &Term,
    directed_eqs: &[DirectedEquation],
    steps: &mut usize,
    max_steps: usize,
) -> Term {
    if *steps >= max_steps {
        return term.clone();
    }
    let normalized_subterms = match term {
        Term::Var(_) => term.clone(),
        Term::App { op, args } => {
            let new_args: Vec<Term> = args
                .iter()
                .map(|a| normalize_once(a, directed_eqs, steps, max_steps))
                .collect();
            Term::App {
                op: Arc::clone(op),
                args: new_args,
            }
        }
    };
    // Rewriting the arguments may have exhausted the budget. Re-check before applying a
    // root rule; otherwise every enclosing level could perform one extra rewrite.
    if *steps >= max_steps {
        return normalized_subterms;
    }
    for de in directed_eqs {
        if let Some(subst) = match_pattern(&de.lhs, &normalized_subterms) {
            *steps += 1;
            let rewritten = de.rhs.substitute(&subst);
            return normalize_once(&rewritten, directed_eqs, steps, max_steps);
        }
    }
    normalized_subterms
}
impl DirectedEquation {
    /// Creates the rule `lhs -> rhs` with implementation expression `impl_term`.
    ///
    /// No inverse and no value kinds are recorded, and the coercion class is `Opaque`.
    #[must_use]
    pub fn new(
        name: impl Into<Arc<str>>,
        lhs: Term,
        rhs: Term,
        impl_term: panproto_expr::Expr,
    ) -> Self {
        let name: Arc<str> = name.into();
        Self {
            coercion_class: crate::sort::CoercionClass::Opaque,
            source_kind: None,
            target_kind: None,
            inverse: None,
            impl_term,
            name,
            lhs,
            rhs,
        }
    }

    /// Like `DirectedEquation::new`, but also records `inverse` and sets the coercion
    /// class to `Retraction`.
    #[must_use]
    pub fn with_inverse(
        name: impl Into<Arc<str>>,
        lhs: Term,
        rhs: Term,
        impl_term: panproto_expr::Expr,
        inverse: panproto_expr::Expr,
    ) -> Self {
        let mut rule = Self::new(name, lhs, rhs, impl_term);
        rule.inverse = Some(inverse);
        rule.coercion_class = crate::sort::CoercionClass::Retraction;
        rule
    }

    /// Builder-style setter that records the source and target value kinds and replaces
    /// the coercion class with `class`.
    #[must_use]
    pub const fn with_kinds(
        mut self,
        source: crate::sort::ValueKind,
        target: crate::sort::ValueKind,
        class: crate::sort::CoercionClass,
    ) -> Self {
        self.coercion_class = class;
        self.target_kind = Some(target);
        self.source_kind = Some(source);
        self
    }

    /// Returns a copy of this rule whose `lhs` and `rhs` have their operation symbols
    /// renamed via [`Term::rename_ops`]; all other fields are copied unchanged.
    #[must_use]
    pub fn rename_ops(&self, op_map: &std::collections::HashMap<Arc<str>, Arc<str>>) -> Self {
        let lhs = self.lhs.rename_ops(op_map);
        let rhs = self.rhs.rename_ops(op_map);
        Self {
            lhs,
            rhs,
            name: Arc::clone(&self.name),
            impl_term: self.impl_term.clone(),
            inverse: self.inverse.clone(),
            source_kind: self.source_kind,
            target_kind: self.target_kind,
            coercion_class: self.coercion_class,
        }
    }
}
// Unit tests for construction, substitution, alpha-equivalence, pattern matching and
// normalization, plus proptest-based property tests in the nested `property` module.
#[cfg(test)]
// Tests use `unwrap` deliberately: a panic is how a test reports failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // Substituting `x := y` replaces the variable and leaves the constant untouched.
    #[test]
    fn term_substitution() {
        let term = Term::app("add", vec![Term::var("x"), Term::constant("zero")]);
        let mut subst = rustc_hash::FxHashMap::default();
        subst.insert(Arc::from("x"), Term::var("y"));
        let result = term.substitute(&subst);
        assert_eq!(
            result,
            Term::app("add", vec![Term::var("y"), Term::constant("zero")])
        );
    }

    #[test]
    fn free_variables() {
        let term = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let vars = term.free_vars();
        assert!(vars.contains("x"));
        assert!(vars.contains("y"));
        assert_eq!(vars.len(), 2);
    }

    // --- alpha-equivalence ---

    #[test]
    fn alpha_eq_same_vars() {
        let t1 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let t2 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        assert!(alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_renamed_vars() {
        let t1 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let t2 = Term::app("f", vec![Term::var("a"), Term::var("b")]);
        assert!(alpha_equivalent(&t1, &t2));
    }

    // `x` would have to map to both `a` and `b`.
    #[test]
    fn alpha_eq_non_injective_rejected() {
        let t1 = Term::app("f", vec![Term::var("x"), Term::var("x")]);
        let t2 = Term::app("f", vec![Term::var("a"), Term::var("b")]);
        assert!(!alpha_equivalent(&t1, &t2));
    }

    // Both `a` and `b` would have to map to `x`.
    #[test]
    fn alpha_eq_non_surjective_rejected() {
        let t1 = Term::app("f", vec![Term::var("a"), Term::var("b")]);
        let t2 = Term::app("f", vec![Term::var("x"), Term::var("x")]);
        assert!(!alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_different_ops() {
        let t1 = Term::app("f", vec![Term::var("x")]);
        let t2 = Term::app("g", vec![Term::var("x")]);
        assert!(!alpha_equivalent(&t1, &t2));
    }

    // A variable never matches an application.
    #[test]
    fn alpha_eq_different_structure() {
        let t1 = Term::app(
            "f",
            vec![Term::var("x"), Term::app("g", vec![Term::var("y")])],
        );
        let t2 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        assert!(!alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_constants() {
        let t1 = Term::app("f", vec![Term::constant("c")]);
        let t2 = Term::app("f", vec![Term::constant("c")]);
        assert!(alpha_equivalent(&t1, &t2));
    }

    // Constants are operation symbols, so they are not renamed.
    #[test]
    fn alpha_eq_constants_differ() {
        let t1 = Term::app("f", vec![Term::constant("c")]);
        let t2 = Term::app("f", vec![Term::constant("d")]);
        assert!(!alpha_equivalent(&t1, &t2));
    }

    // The bijection x<->a, y<->b must hold consistently across sibling subterms.
    #[test]
    fn alpha_eq_nested_renamed() {
        let t1 = Term::app(
            "f",
            vec![
                Term::app("g", vec![Term::var("x"), Term::var("y")]),
                Term::app("h", vec![Term::var("y"), Term::var("x")]),
            ],
        );
        let t2 = Term::app(
            "f",
            vec![
                Term::app("g", vec![Term::var("a"), Term::var("b")]),
                Term::app("h", vec![Term::var("b"), Term::var("a")]),
            ],
        );
        assert!(alpha_equivalent(&t1, &t2));
    }

    // The bijection built on the left-hand sides must also hold on the right-hand sides.
    #[test]
    fn alpha_eq_equation_shared_bijection() {
        let lhs1 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let rhs1 = Term::app("g", vec![Term::var("y"), Term::var("x")]);
        let lhs2 = Term::app("f", vec![Term::var("a"), Term::var("b")]);
        let rhs2 = Term::app("g", vec![Term::var("b"), Term::var("a")]);
        assert!(alpha_equivalent_equation(&lhs1, &rhs1, &lhs2, &rhs2));
    }

    // The rhs pairs `y` with `a`, contradicting `x <-> a` established by the lhs.
    #[test]
    fn alpha_eq_equation_inconsistent_bijection() {
        let lhs1 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let rhs1 = Term::app("g", vec![Term::var("y")]);
        let lhs2 = Term::app("f", vec![Term::var("a"), Term::var("b")]);
        let rhs2 = Term::app("g", vec![Term::var("a")]);
        assert!(!alpha_equivalent_equation(&lhs1, &rhs1, &lhs2, &rhs2));
    }

    // --- pattern matching ---

    #[test]
    fn match_pattern_var_binds() {
        let pat = Term::var("x");
        let term = Term::app("f", vec![Term::constant("a")]);
        let result = match_pattern(&pat, &term);
        assert!(result.is_some());
        let subst = result.unwrap();
        assert_eq!(subst.get(&Arc::from("x")).unwrap(), &term);
    }

    #[test]
    fn match_pattern_op_matches() {
        let pat = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let term = Term::app("f", vec![Term::constant("a"), Term::constant("b")]);
        let result = match_pattern(&pat, &term);
        assert!(result.is_some());
        let subst = result.unwrap();
        assert_eq!(subst.get(&Arc::from("x")).unwrap(), &Term::constant("a"));
        assert_eq!(subst.get(&Arc::from("y")).unwrap(), &Term::constant("b"));
    }

    #[test]
    fn match_pattern_op_mismatch() {
        let pat = Term::app("f", vec![Term::var("x")]);
        let term = Term::app("g", vec![Term::constant("a")]);
        assert!(match_pattern(&pat, &term).is_none());
    }

    // Non-linear pattern: both occurrences of `x` bind to the same subterm.
    #[test]
    fn match_pattern_repeated_var_consistent() {
        let pat = Term::app("f", vec![Term::var("x"), Term::var("x")]);
        let term = Term::app("f", vec![Term::constant("a"), Term::constant("a")]);
        assert!(match_pattern(&pat, &term).is_some());
    }

    // Non-linear pattern: `x` cannot bind to both `a` and `b`.
    #[test]
    fn match_pattern_repeated_var_inconsistent() {
        let pat = Term::app("f", vec![Term::var("x"), Term::var("x")]);
        let term = Term::app("f", vec![Term::constant("a"), Term::constant("b")]);
        assert!(match_pattern(&pat, &term).is_none());
    }

    // --- normalization ---

    // Builds a rule with a placeholder implementation expression; only `lhs`/`rhs`
    // matter to `normalize`.
    fn make_directed_eq(name: &str, lhs: Term, rhs: Term) -> DirectedEquation {
        DirectedEquation::new(name, lhs, rhs, panproto_expr::Expr::Var("_".into()))
    }

    #[test]
    fn normalize_no_rules() {
        let term = Term::app("f", vec![Term::var("x")]);
        let result = normalize(&term, &[], 100);
        assert_eq!(result, term);
    }

    #[test]
    fn normalize_simple_rewrite() {
        let rule = make_directed_eq(
            "left_id",
            Term::app("add", vec![Term::constant("zero"), Term::var("y")]),
            Term::var("y"),
        );
        let term = Term::app("add", vec![Term::constant("zero"), Term::var("x")]);
        let result = normalize(&term, &[rule], 100);
        assert_eq!(result, Term::var("x"));
    }

    // The redex sits below the root.
    #[test]
    fn normalize_nested() {
        let rule = make_directed_eq(
            "left_id",
            Term::app("add", vec![Term::constant("zero"), Term::var("y")]),
            Term::var("y"),
        );
        let term = Term::app(
            "f",
            vec![Term::app(
                "add",
                vec![Term::constant("zero"), Term::var("x")],
            )],
        );
        let result = normalize(&term, &[rule], 100);
        assert_eq!(result, Term::app("f", vec![Term::var("x")]));
    }

    // Two nested redexes must both be eliminated.
    #[test]
    fn normalize_multi_step() {
        let rule = make_directed_eq(
            "left_id",
            Term::app("add", vec![Term::constant("zero"), Term::var("y")]),
            Term::var("y"),
        );
        let term = Term::app(
            "add",
            vec![
                Term::constant("zero"),
                Term::app("add", vec![Term::constant("zero"), Term::var("x")]),
            ],
        );
        let result = normalize(&term, &[rule], 100);
        assert_eq!(result, Term::var("x"));
    }

    // `f(x) -> f(f(x))` never terminates; the step budget must stop the rewriting.
    #[test]
    fn normalize_max_steps_guard() {
        let rule = make_directed_eq(
            "expand",
            Term::app("f", vec![Term::var("x")]),
            Term::app("f", vec![Term::app("f", vec![Term::var("x")])]),
        );
        let term = Term::app("f", vec![Term::constant("a")]);
        let result = normalize(&term, &[rule], 5);
        assert!(matches!(result, Term::App { .. }));
    }

    #[test]
    fn alpha_eq_var_vs_app() {
        let t1 = Term::var("x");
        let t2 = Term::constant("c");
        assert!(!alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_arity_mismatch() {
        let t1 = Term::app("f", vec![Term::var("x")]);
        let t2 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        assert!(!alpha_equivalent(&t1, &t2));
    }

    // Property-based tests over randomly generated terms.
    mod property {
        use super::*;
        use proptest::prelude::*;

        // Small fixed alphabets, so repeated names within a generated term are likely.
        const VAR_NAMES: &[&str] = &["x", "y", "z", "a", "b"];
        const OP_NAMES: &[&str] = &["f", "g", "h", "add", "mul"];

        fn arb_name() -> impl Strategy<Value = Arc<str>> {
            prop::sample::select(VAR_NAMES).prop_map(Arc::from)
        }

        // Random terms of depth at most `max_depth`; depth 0 yields only variables, and
        // each application has 0 to 3 arguments.
        fn arb_term(max_depth: usize) -> BoxedStrategy<Term> {
            if max_depth == 0 {
                arb_name().prop_map(Term::Var).boxed()
            } else {
                let leaf = arb_name().prop_map(Term::Var);
                let app = (
                    prop::sample::select(OP_NAMES).prop_map(Arc::from),
                    prop::collection::vec(arb_term(max_depth - 1), 0..=3),
                )
                    .prop_map(|(op, args)| Term::App { op, args });
                prop_oneof![leaf, app].boxed()
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(256))]

            #[test]
            fn alpha_equivalence_is_reflexive(t in arb_term(3)) {
                prop_assert!(alpha_equivalent(&t, &t));
            }

            #[test]
            fn alpha_equivalence_is_symmetric(a in arb_term(2), b in arb_term(2)) {
                prop_assert_eq!(
                    alpha_equivalent(&a, &b),
                    alpha_equivalent(&b, &a),
                );
            }

            #[test]
            fn substitute_empty_is_identity(t in arb_term(3)) {
                let empty = rustc_hash::FxHashMap::default();
                prop_assert_eq!(t.substitute(&empty), t);
            }

            #[test]
            fn rename_ops_empty_is_identity(t in arb_term(3)) {
                let empty = std::collections::HashMap::new();
                prop_assert_eq!(t.rename_ops(&empty), t);
            }

            // Only checks that renaming one operation symbol leaves the number of
            // free variables unchanged.
            #[test]
            fn rename_ops_preserves_alpha_structure(
                t in arb_term(2),
                src in prop::sample::select(OP_NAMES),
                tgt in prop::sample::select(OP_NAMES),
            ) {
                let mut map = std::collections::HashMap::new();
                map.insert(Arc::from(src), Arc::from(tgt));
                let renamed = t.rename_ops(&map);
                prop_assert_eq!(t.free_vars().len(), renamed.free_vars().len());
            }

            // Every variable of `t[var := replacement]` comes either from `t` (other than
            // `var`) or from `replacement`.
            #[test]
            fn free_vars_subset_after_substitution(
                t in arb_term(2),
                var in arb_name(),
                replacement in arb_term(1),
            ) {
                let mut subst = rustc_hash::FxHashMap::default();
                subst.insert(var.clone(), replacement.clone());
                let result = t.substitute(&subst);
                let result_vars = result.free_vars();
                let orig_vars = t.free_vars();
                let repl_vars = replacement.free_vars();
                for v in &result_vars {
                    prop_assert!(
                        (orig_vars.contains(v) && *v != var) || repl_vars.contains(v),
                        "unexpected var {:?} in result",
                        v,
                    );
                }
            }
        }
    }
}