use std::sync::Arc;
/// A first-order term with local binding forms.
///
/// `Var` and `Hole` are leaves. `App` applies a named operation to arguments.
/// `Case` and `Let` introduce binders that scope over their bodies only
/// (see `Term::free_vars`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum Term {
    /// A variable reference.
    Var(Arc<str>),
    /// Application of operation `op` to `args`; with no arguments this is a
    /// constant (see `Term::constant`).
    App {
        op: Arc<str>,
        args: Vec<Self>,
    },
    /// Case analysis on `scrutinee`; each branch binds its own `binders`
    /// within its body.
    Case {
        scrutinee: Box<Self>,
        branches: Vec<CaseBranch>,
    },
    /// A placeholder, optionally named; it contains no variables.
    Hole {
        name: Option<Arc<str>>,
    },
    /// `let name = bound in body`: `name` is in scope in `body` but not in
    /// `bound` (the binding is non-recursive).
    Let {
        name: Arc<str>,
        bound: Box<Self>,
        body: Box<Self>,
    },
}
/// One arm of a `Term::Case`.
///
/// During normalization the arm is selected when the scrutinee is an
/// application of `constructor`. Its arguments are then bound positionally to
/// `binders` inside `body`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct CaseBranch {
    /// Constructor (operation name) this arm handles.
    pub constructor: Arc<str>,
    /// Variables bound in `body`, one per constructor argument.
    pub binders: Vec<Arc<str>>,
    /// Arm body; `binders` are in scope here.
    pub body: Term,
}
impl Term {
    /// Builds a variable reference.
    #[must_use]
    pub fn var(name: impl Into<Arc<str>>) -> Self {
        Self::Var(name.into())
    }

    /// Builds an application of `op` to `args`.
    #[must_use]
    pub fn app(op: impl Into<Arc<str>>, args: Vec<Self>) -> Self {
        Self::App {
            op: op.into(),
            args,
        }
    }

    /// Builds a constant: an application of `op` to no arguments.
    #[must_use]
    pub fn constant(op: impl Into<Arc<str>>) -> Self {
        Self::App {
            op: op.into(),
            args: Vec::new(),
        }
    }

    /// Simultaneously replaces every free occurrence of a variable in `subst`
    /// by its image.
    ///
    /// The substitution is capture-avoiding for both `Let` and `Case` binders.
    /// When a binder would capture a free variable of a term being inserted
    /// under it, the binder is renamed throughout its scope. The new name is
    /// the first of `x'`, `x''`, ... that clashes with nothing relevant.
    /// Binders shadow `subst` inside their scope, and holes are left untouched.
    #[must_use]
    pub fn substitute(&self, subst: &rustc_hash::FxHashMap<Arc<str>, Self>) -> Self {
        match self {
            Self::Var(name) => subst.get(name).cloned().unwrap_or_else(|| self.clone()),
            Self::Hole { .. } => self.clone(),
            Self::App { op, args } => Self::App {
                op: Arc::clone(op),
                args: args.iter().map(|a| a.substitute(subst)).collect(),
            },
            Self::Case {
                scrutinee,
                branches,
            } => Self::Case {
                scrutinee: Box::new(scrutinee.substitute(subst)),
                branches: branches
                    .iter()
                    .map(|b| {
                        let (binders, body) =
                            Self::substitute_under_binders(&b.binders, &b.body, subst);
                        CaseBranch {
                            constructor: Arc::clone(&b.constructor),
                            binders,
                            body,
                        }
                    })
                    .collect(),
            },
            Self::Let { name, bound, body } => {
                // Non-recursive binding: `bound` lives in the outer scope.
                let bound = Box::new(bound.substitute(subst));
                let (binders, body) =
                    Self::substitute_under_binders(std::slice::from_ref(name), body, subst);
                // The helper returns exactly one binder per input binder, so
                // the fallback is never taken.
                let name = binders
                    .into_iter()
                    .next()
                    .unwrap_or_else(|| Arc::clone(name));
                Self::Let {
                    name,
                    bound,
                    body: Box::new(body),
                }
            }
        }
    }

    /// Applies `subst` to `body`, which lies in the scope of `binders`.
    ///
    /// Returns the (possibly renamed) binders, always the same number as
    /// `binders`, together with the substituted body.
    fn substitute_under_binders(
        binders: &[Arc<str>],
        body: &Self,
        subst: &rustc_hash::FxHashMap<Arc<str>, Self>,
    ) -> (Vec<Arc<str>>, Self) {
        // Binders shadow the substitution inside their scope.
        let mut inner = subst.clone();
        for binder in binders {
            inner.remove(binder);
        }

        // Only entries whose variable occurs free in the body are actually
        // inserted. Their free variables must not end up under a binder
        // of the same name.
        let body_free = body.free_vars();
        let mut inserted_free: rustc_hash::FxHashSet<Arc<str>> =
            rustc_hash::FxHashSet::default();
        for (var, replacement) in &inner {
            if body_free.contains(var) {
                inserted_free.extend(replacement.free_vars());
            }
        }

        if !binders.iter().any(|b| inserted_free.contains(b)) {
            return (binders.to_vec(), body.substitute(&inner));
        }

        let mut rename: rustc_hash::FxHashMap<Arc<str>, Self> = rustc_hash::FxHashMap::default();
        let mut chosen: rustc_hash::FxHashSet<Arc<str>> = rustc_hash::FxHashSet::default();
        let mut new_binders = Vec::with_capacity(binders.len());
        for binder in binders {
            if !inserted_free.contains(binder) {
                new_binders.push(Arc::clone(binder));
                continue;
            }
            // The fresh name must satisfy four conditions:
            // - it captures nothing (not free in the body or in an inserted term);
            // - it is not itself a substitution target, otherwise the renamed
            //   binder occurrences would be replaced;
            // - it does not collide with a sibling binder;
            // - it differs from names chosen earlier in this loop.
            let fresh = Self::fresh_variant(binder, |candidate| {
                inserted_free.contains(candidate)
                    || body_free.contains(candidate)
                    || inner.contains_key(candidate)
                    || chosen.contains(candidate)
                    || binders.iter().any(|b| &**b == candidate)
            });
            rename.insert(Arc::clone(binder), Self::Var(Arc::clone(&fresh)));
            chosen.insert(Arc::clone(&fresh));
            new_binders.push(fresh);
        }

        let renamed = body.substitute(&rename);
        (new_binders, renamed.substitute(&inner))
    }

    /// Returns the first of `base'`, `base''`, ... for which `is_taken` is false.
    fn fresh_variant(base: &str, is_taken: impl Fn(&str) -> bool) -> Arc<str> {
        let mut candidate = format!("{base}'");
        while is_taken(candidate.as_str()) {
            candidate.push('\'');
        }
        Arc::from(candidate)
    }

    /// Returns the set of variables occurring free in this term.
    ///
    /// Variables bound by a `Case` branch or by a `Let` are excluded within
    /// their scope. A `Let`'s bound expression is outside the binder's scope.
    #[must_use]
    pub fn free_vars(&self) -> rustc_hash::FxHashSet<Arc<str>> {
        let mut vars = rustc_hash::FxHashSet::default();
        self.collect_vars(&mut vars);
        vars
    }

    /// Adds the free variables of `self` to `vars`.
    fn collect_vars(&self, vars: &mut rustc_hash::FxHashSet<Arc<str>>) {
        match self {
            Self::Var(name) => {
                vars.insert(Arc::clone(name));
            }
            Self::Hole { .. } => {}
            Self::App { args, .. } => {
                for arg in args {
                    arg.collect_vars(vars);
                }
            }
            Self::Case {
                scrutinee,
                branches,
            } => {
                scrutinee.collect_vars(vars);
                for b in branches {
                    // Collect separately so binders only hide their own body's uses.
                    let mut local = rustc_hash::FxHashSet::default();
                    b.body.collect_vars(&mut local);
                    for binder in &b.binders {
                        local.remove(binder);
                    }
                    vars.extend(local);
                }
            }
            Self::Let { name, bound, body } => {
                bound.collect_vars(vars);
                let mut local = rustc_hash::FxHashSet::default();
                body.collect_vars(&mut local);
                local.remove(name);
                vars.extend(local);
            }
        }
    }

    /// Renames operation symbols according to `op_map`.
    ///
    /// Both `App` operations and `Case` constructors are renamed. Symbols
    /// missing from the map are kept, and variables and binders are
    /// untouched.
    #[must_use]
    pub fn rename_ops(&self, op_map: &std::collections::HashMap<Arc<str>, Arc<str>>) -> Self {
        match self {
            Self::Var(_) | Self::Hole { .. } => self.clone(),
            Self::App { op, args } => Self::App {
                op: op_map.get(op).cloned().unwrap_or_else(|| Arc::clone(op)),
                args: args.iter().map(|a| a.rename_ops(op_map)).collect(),
            },
            Self::Case {
                scrutinee,
                branches,
            } => Self::Case {
                scrutinee: Box::new(scrutinee.rename_ops(op_map)),
                branches: branches
                    .iter()
                    .map(|b| CaseBranch {
                        constructor: op_map
                            .get(&b.constructor)
                            .cloned()
                            .unwrap_or_else(|| Arc::clone(&b.constructor)),
                        binders: b.binders.clone(),
                        body: b.body.rename_ops(op_map),
                    })
                    .collect(),
            },
            Self::Let { name, bound, body } => Self::Let {
                name: Arc::clone(name),
                bound: Box::new(bound.rename_ops(op_map)),
                body: Box::new(body.rename_ops(op_map)),
            },
        }
    }
}
/// Composes two substitutions into one equivalent to applying `sigma` first
/// and then `tau`.
///
/// Every variable bound by `sigma` maps to its `sigma`-image with `tau`
/// applied. Variables bound only by `tau` keep their `tau`-image. Where both
/// bind a variable, the `sigma`-derived entry wins.
#[must_use]
pub fn compose_subst<S1, S2>(
    tau: &std::collections::HashMap<Arc<str>, Term, S1>,
    sigma: &std::collections::HashMap<Arc<str>, Term, S2>,
) -> rustc_hash::FxHashMap<Arc<str>, Term>
where
    S1: std::hash::BuildHasher,
    S2: std::hash::BuildHasher,
{
    // `Term::substitute` wants an FxHashMap, so re-key `tau` once up front.
    let tau_fx: rustc_hash::FxHashMap<Arc<str>, Term> = tau
        .iter()
        .map(|(var, image)| (Arc::clone(var), image.clone()))
        .collect();

    let mut composed: rustc_hash::FxHashMap<Arc<str>, Term> = sigma
        .iter()
        .map(|(var, image)| (Arc::clone(var), image.substitute(&tau_fx)))
        .collect();

    // Fill in the variables that only `tau` touches.
    for (var, image) in tau_fx {
        composed.entry(var).or_insert(image);
    }
    composed
}
/// A named, undirected equation `lhs = rhs` between two terms.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct Equation {
    /// Identifier of the equation.
    pub name: Arc<str>,
    /// Left-hand side.
    pub lhs: Term,
    /// Right-hand side.
    pub rhs: Term,
}
impl Equation {
    /// Creates the equation `lhs = rhs` under the given name.
    #[must_use]
    pub fn new(name: impl Into<Arc<str>>, lhs: Term, rhs: Term) -> Self {
        let name: Arc<str> = name.into();
        Self { name, lhs, rhs }
    }

    /// Returns a copy with operation symbols renamed on both sides via
    /// `op_map`; the equation's name is kept.
    #[must_use]
    pub fn rename_ops(&self, op_map: &std::collections::HashMap<Arc<str>, Arc<str>>) -> Self {
        let rename_side = |side: &Term| side.rename_ops(op_map);
        Self {
            name: Arc::clone(&self.name),
            lhs: rename_side(&self.lhs),
            rhs: rename_side(&self.rhs),
        }
    }
}
/// An equation oriented left-to-right, used as a rewrite rule by `normalize`.
///
/// It also carries an expression-level counterpart and optional coercion
/// metadata.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct DirectedEquation {
    /// Identifier of the rule.
    pub name: Arc<str>,
    /// Pattern matched against (already normalized) subterms during normalization.
    pub lhs: Term,
    /// Replacement, instantiated with the bindings produced by matching `lhs`.
    pub rhs: Term,
    /// Expression associated with this rule. Presumably its executable
    /// implementation; confirm against users of this field.
    pub impl_term: panproto_expr::Expr,
    /// Optional inverse expression; `DirectedEquation::with_inverse` sets it.
    pub inverse: Option<panproto_expr::Expr>,
    /// Value kind on the source side, if known; set by `with_kinds`.
    pub source_kind: Option<crate::sort::ValueKind>,
    /// Value kind on the target side, if known; set by `with_kinds`.
    pub target_kind: Option<crate::sort::ValueKind>,
    /// Coercion classification. `new` uses `Opaque`, `with_inverse` uses
    /// `Retraction`, and `with_kinds` overrides it.
    pub coercion_class: crate::sort::CoercionClass,
}
/// Decides whether `t1` and `t2` are alpha-equivalent.
///
/// Bound variables (introduced by `Let` or by `Case` branches) must refer to
/// binders at the same position on both sides. Free variables are related by
/// a bijection built up during the traversal. So `f(x, y)` is equivalent to
/// `f(a, b)` but not to `f(a, a)`, and a free variable never matches a bound
/// one.
#[must_use]
pub fn alpha_equivalent(t1: &Term, t2: &Term) -> bool {
    AlphaChecker::default().check(t1, t2)
}

/// Decides whether the equations `lhs1 = rhs1` and `lhs2 = rhs2` are
/// alpha-equivalent.
///
/// Both sides share one free-variable bijection, so a variable renamed on the
/// left-hand side must be renamed identically on the right-hand side.
#[must_use]
pub fn alpha_equivalent_equation(lhs1: &Term, rhs1: &Term, lhs2: &Term, rhs2: &Term) -> bool {
    let mut checker = AlphaChecker::default();
    checker.check(lhs1, lhs2) && checker.check(rhs1, rhs2)
}

/// Traversal state for comparing a left term against a right term.
#[derive(Default)]
struct AlphaChecker {
    /// Free-variable bijection, left name to right name.
    forward: rustc_hash::FxHashMap<Arc<str>, Arc<str>>,
    /// Free-variable bijection, right name to left name.
    backward: rustc_hash::FxHashMap<Arc<str>, Arc<str>>,
    /// Binders currently in scope on the left, outermost first.
    bound_left: Vec<Arc<str>>,
    /// Binders currently in scope on the right, outermost first. Always grows
    /// and shrinks in lockstep with `bound_left`.
    bound_right: Vec<Arc<str>>,
}

impl AlphaChecker {
    /// Structurally compares `t1` with `t2`, extending the free-variable
    /// bijection as new free variables are met.
    fn check(&mut self, t1: &Term, t2: &Term) -> bool {
        match (t1, t2) {
            (Term::Var(a), Term::Var(b)) => self.check_vars(a, b),
            (
                Term::App {
                    op: op1,
                    args: args1,
                },
                Term::App {
                    op: op2,
                    args: args2,
                },
            ) => {
                op1 == op2
                    && args1.len() == args2.len()
                    && args1
                        .iter()
                        .zip(args2)
                        .all(|(a1, a2)| self.check(a1, a2))
            }
            (
                Term::Case {
                    scrutinee: s1,
                    branches: b1,
                },
                Term::Case {
                    scrutinee: s2,
                    branches: b2,
                },
            ) => {
                self.check(s1, s2)
                    && b1.len() == b2.len()
                    && b1.iter().zip(b2).all(|(br1, br2)| {
                        br1.constructor == br2.constructor
                            && br1.binders.len() == br2.binders.len()
                            && self.check_scoped(&br1.binders, &br1.body, &br2.binders, &br2.body)
                    })
            }
            (Term::Hole { name: n1 }, Term::Hole { name: n2 }) => n1 == n2,
            (
                Term::Let {
                    name: n1,
                    bound: b1,
                    body: body1,
                },
                Term::Let {
                    name: n2,
                    bound: b2,
                    body: body2,
                },
            ) => self.check_let(n1, b1, body1, n2, b2, body2),
            (
                Term::Var(_)
                | Term::App { .. }
                | Term::Case { .. }
                | Term::Hole { .. }
                | Term::Let { .. },
                _,
            ) => false,
        }
    }

    /// Compares two variable occurrences.
    ///
    /// Two bound occurrences match iff they refer to the same binder position.
    /// Two free occurrences match iff they are consistent with the bijection.
    /// A bound occurrence never matches a free one.
    fn check_vars(&mut self, a: &Arc<str>, b: &Arc<str>) -> bool {
        // `rposition` finds the innermost binder, which is the one in scope.
        let left = self.bound_left.iter().rposition(|x| x == a);
        let right = self.bound_right.iter().rposition(|x| x == b);
        match (left, right) {
            (Some(i), Some(j)) => i == j,
            (None, None) => self.check_free(a, b),
            (Some(_), None) | (None, Some(_)) => false,
        }
    }

    /// Checks (and if new, records) that free `a` corresponds to free `b`.
    fn check_free(&mut self, a: &Arc<str>, b: &Arc<str>) -> bool {
        if let Some(mapped) = self.forward.get(a) {
            return mapped == b;
        }
        if let Some(mapped_back) = self.backward.get(b) {
            return mapped_back == a;
        }
        self.forward.insert(Arc::clone(a), Arc::clone(b));
        self.backward.insert(Arc::clone(b), Arc::clone(a));
        true
    }

    /// Compares two `let` forms. The bound expressions lie outside the
    /// binder's scope; the bodies lie inside it.
    fn check_let(
        &mut self,
        n1: &Arc<str>,
        b1: &Term,
        body1: &Term,
        n2: &Arc<str>,
        b2: &Term,
        body2: &Term,
    ) -> bool {
        self.check(b1, b2)
            && self.check_scoped(
                std::slice::from_ref(n1),
                body1,
                std::slice::from_ref(n2),
                body2,
            )
    }

    /// Compares `body1` and `body2` with `binders1` / `binders2` in scope.
    ///
    /// The binders are popped again afterwards. Free-variable mappings
    /// discovered inside the bodies are kept, because they constrain the rest
    /// of the comparison.
    fn check_scoped(
        &mut self,
        binders1: &[Arc<str>],
        body1: &Term,
        binders2: &[Arc<str>],
        body2: &Term,
    ) -> bool {
        let (depth1, depth2) = (self.bound_left.len(), self.bound_right.len());
        self.bound_left.extend(binders1.iter().cloned());
        self.bound_right.extend(binders2.iter().cloned());
        let ok = self.check(body1, body2);
        self.bound_left.truncate(depth1);
        self.bound_right.truncate(depth2);
        ok
    }
}
#[must_use]
pub fn match_pattern(pattern: &Term, term: &Term) -> Option<rustc_hash::FxHashMap<Arc<str>, Term>> {
let mut subst = rustc_hash::FxHashMap::default();
if match_pattern_inner(pattern, term, &mut subst) {
Some(subst)
} else {
None
}
}
fn match_pattern_inner(
pattern: &Term,
term: &Term,
subst: &mut rustc_hash::FxHashMap<Arc<str>, Term>,
) -> bool {
match pattern {
Term::Var(name) => {
if subst.contains_key(name) {
subst.get(name).is_some_and(|existing| existing == term)
} else {
subst.insert(Arc::clone(name), term.clone());
true
}
}
Term::App {
op: p_op,
args: p_args,
} => match term {
Term::App {
op: t_op,
args: t_args,
} => {
p_op == t_op
&& p_args.len() == t_args.len()
&& p_args
.iter()
.zip(t_args.iter())
.all(|(p, t)| match_pattern_inner(p, t, subst))
}
Term::Var(_) | Term::Case { .. } | Term::Hole { .. } | Term::Let { .. } => false,
},
Term::Case {
scrutinee: p_s,
branches: p_b,
} => match term {
Term::Case {
scrutinee: t_s,
branches: t_b,
} => {
if p_b.len() != t_b.len() || !match_pattern_inner(p_s, t_s, subst) {
return false;
}
for (pb, tb) in p_b.iter().zip(t_b.iter()) {
if pb.constructor != tb.constructor || pb.binders.len() != tb.binders.len() {
return false;
}
let saved = subst.clone();
for (pb_b, tb_b) in pb.binders.iter().zip(tb.binders.iter()) {
subst.insert(Arc::clone(pb_b), Term::Var(Arc::clone(tb_b)));
}
let ok = match_pattern_inner(&pb.body, &tb.body, subst);
*subst = saved;
if !ok {
return false;
}
}
true
}
Term::Var(_) | Term::App { .. } | Term::Hole { .. } | Term::Let { .. } => false,
},
Term::Hole { name } => match term {
Term::Hole { name: n2 } => name == n2,
Term::Var(_) | Term::App { .. } | Term::Case { .. } | Term::Let { .. } => false,
},
Term::Let {
name: p_n,
bound: p_b,
body: p_body,
} => match term {
Term::Let {
name: t_n,
bound: t_b,
body: t_body,
} => {
if !match_pattern_inner(p_b, t_b, subst) {
return false;
}
let saved = subst.clone();
subst.insert(Arc::clone(p_n), Term::Var(Arc::clone(t_n)));
let ok = match_pattern_inner(p_body, t_body, subst);
*subst = saved;
ok
}
Term::Var(_) | Term::App { .. } | Term::Case { .. } | Term::Hole { .. } => false,
},
}
}
/// Rewrites `term` with `directed_eqs` until a fixpoint or the step budget is reached.
///
/// Each pass delegates to `normalize_once`, and the `steps` counter is shared
/// across passes. Only applications of a directed equation are counted;
/// `let` elimination and case-of-constructor reduction are free. The last
/// term produced is returned. It need not be a normal form if `max_steps` was
/// exhausted.
#[must_use]
pub fn normalize(term: &Term, directed_eqs: &[DirectedEquation], max_steps: usize) -> Term {
    let mut current = term.clone();
    let mut steps = 0;
    loop {
        let next = normalize_once(&current, directed_eqs, &mut steps, max_steps);
        // A pass that changes nothing has reached a fixpoint.
        if next == current || steps >= max_steps {
            return next;
        }
        current = next;
    }
}
fn normalize_once(
term: &Term,
directed_eqs: &[DirectedEquation],
steps: &mut usize,
max_steps: usize,
) -> Term {
if *steps >= max_steps {
return term.clone();
}
let normalized_subterms = match term {
Term::Var(_) | Term::Hole { .. } => term.clone(),
Term::Let { name, bound, body } => {
let new_bound = normalize_once(bound, directed_eqs, steps, max_steps);
let mut subst = rustc_hash::FxHashMap::default();
subst.insert(Arc::clone(name), new_bound);
let substituted = body.substitute(&subst);
return normalize_once(&substituted, directed_eqs, steps, max_steps);
}
Term::App { op, args } => {
let new_args: Vec<Term> = args
.iter()
.map(|a| normalize_once(a, directed_eqs, steps, max_steps))
.collect();
Term::App {
op: Arc::clone(op),
args: new_args,
}
}
Term::Case {
scrutinee,
branches,
} => {
let new_scrut = Box::new(normalize_once(scrutinee, directed_eqs, steps, max_steps));
if let Term::App { op, args } = new_scrut.as_ref() {
if let Some(branch) = branches.iter().find(|b| &b.constructor == op) {
if branch.binders.len() == args.len() {
let mut subst = rustc_hash::FxHashMap::default();
for (binder, arg) in branch.binders.iter().zip(args.iter()) {
subst.insert(Arc::clone(binder), arg.clone());
}
let body = branch.body.substitute(&subst);
return normalize_once(&body, directed_eqs, steps, max_steps);
}
}
}
let new_branches = branches
.iter()
.map(|b| CaseBranch {
constructor: Arc::clone(&b.constructor),
binders: b.binders.clone(),
body: normalize_once(&b.body, directed_eqs, steps, max_steps),
})
.collect();
Term::Case {
scrutinee: new_scrut,
branches: new_branches,
}
}
};
for de in directed_eqs {
if let Some(subst) = match_pattern(&de.lhs, &normalized_subterms) {
*steps += 1;
let rewritten = de.rhs.substitute(&subst);
return normalize_once(&rewritten, directed_eqs, steps, max_steps);
}
}
normalized_subterms
}
impl DirectedEquation {
    /// Creates a rewrite rule `lhs -> rhs` implemented by `impl_term`.
    ///
    /// The rule has no inverse, no value kinds, and the `Opaque` coercion class.
    #[must_use]
    pub fn new(
        name: impl Into<Arc<str>>,
        lhs: Term,
        rhs: Term,
        impl_term: panproto_expr::Expr,
    ) -> Self {
        let name: Arc<str> = name.into();
        Self {
            name,
            lhs,
            rhs,
            impl_term,
            inverse: None,
            source_kind: None,
            target_kind: None,
            coercion_class: crate::sort::CoercionClass::Opaque,
        }
    }

    /// Like `new`, but records `inverse` and classifies the rule as a
    /// `Retraction`.
    #[must_use]
    pub fn with_inverse(
        name: impl Into<Arc<str>>,
        lhs: Term,
        rhs: Term,
        impl_term: panproto_expr::Expr,
        inverse: panproto_expr::Expr,
    ) -> Self {
        Self {
            inverse: Some(inverse),
            coercion_class: crate::sort::CoercionClass::Retraction,
            ..Self::new(name, lhs, rhs, impl_term)
        }
    }

    /// Builder step: sets the source and target value kinds and the coercion
    /// class, returning the updated rule.
    #[must_use]
    pub const fn with_kinds(
        mut self,
        source: crate::sort::ValueKind,
        target: crate::sort::ValueKind,
        class: crate::sort::CoercionClass,
    ) -> Self {
        self.coercion_class = class;
        self.target_kind = Some(target);
        self.source_kind = Some(source);
        self
    }

    /// Returns a copy with operation symbols in `lhs` and `rhs` renamed via
    /// `op_map`. All other fields are copied unchanged.
    #[must_use]
    pub fn rename_ops(&self, op_map: &std::collections::HashMap<Arc<str>, Arc<str>>) -> Self {
        let Self {
            name,
            lhs,
            rhs,
            impl_term,
            inverse,
            source_kind,
            target_kind,
            coercion_class,
        } = self;
        Self {
            name: Arc::clone(name),
            lhs: lhs.rename_ops(op_map),
            rhs: rhs.rename_ops(op_map),
            impl_term: impl_term.clone(),
            inverse: inverse.clone(),
            source_kind: *source_kind,
            target_kind: *target_kind,
            coercion_class: *coercion_class,
        }
    }
}
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn term_substitution() {
        let term = Term::app("add", vec![Term::var("x"), Term::constant("zero")]);
        let mut subst = rustc_hash::FxHashMap::default();
        subst.insert(Arc::from("x"), Term::var("y"));
        let result = term.substitute(&subst);
        assert_eq!(
            result,
            Term::app("add", vec![Term::var("y"), Term::constant("zero")])
        );
    }

    #[test]
    fn free_variables() {
        let term = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let vars = term.free_vars();
        assert!(vars.contains("x"));
        assert!(vars.contains("y"));
        assert_eq!(vars.len(), 2);
    }

    #[test]
    fn alpha_eq_same_vars() {
        let t1 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let t2 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        assert!(alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_renamed_vars() {
        let t1 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let t2 = Term::app("f", vec![Term::var("a"), Term::var("b")]);
        assert!(alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_non_injective_rejected() {
        let t1 = Term::app("f", vec![Term::var("x"), Term::var("x")]);
        let t2 = Term::app("f", vec![Term::var("a"), Term::var("b")]);
        assert!(!alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_non_surjective_rejected() {
        let t1 = Term::app("f", vec![Term::var("a"), Term::var("b")]);
        let t2 = Term::app("f", vec![Term::var("x"), Term::var("x")]);
        assert!(!alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_different_ops() {
        let t1 = Term::app("f", vec![Term::var("x")]);
        let t2 = Term::app("g", vec![Term::var("x")]);
        assert!(!alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_different_structure() {
        let t1 = Term::app(
            "f",
            vec![Term::var("x"), Term::app("g", vec![Term::var("y")])],
        );
        let t2 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        assert!(!alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_constants() {
        let t1 = Term::app("f", vec![Term::constant("c")]);
        let t2 = Term::app("f", vec![Term::constant("c")]);
        assert!(alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_constants_differ() {
        let t1 = Term::app("f", vec![Term::constant("c")]);
        let t2 = Term::app("f", vec![Term::constant("d")]);
        assert!(!alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_nested_renamed() {
        let t1 = Term::app(
            "f",
            vec![
                Term::app("g", vec![Term::var("x"), Term::var("y")]),
                Term::app("h", vec![Term::var("y"), Term::var("x")]),
            ],
        );
        let t2 = Term::app(
            "f",
            vec![
                Term::app("g", vec![Term::var("a"), Term::var("b")]),
                Term::app("h", vec![Term::var("b"), Term::var("a")]),
            ],
        );
        assert!(alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_equation_shared_bijection() {
        let lhs1 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let rhs1 = Term::app("g", vec![Term::var("y"), Term::var("x")]);
        let lhs2 = Term::app("f", vec![Term::var("a"), Term::var("b")]);
        let rhs2 = Term::app("g", vec![Term::var("b"), Term::var("a")]);
        assert!(alpha_equivalent_equation(&lhs1, &rhs1, &lhs2, &rhs2));
    }

    #[test]
    fn alpha_eq_equation_inconsistent_bijection() {
        let lhs1 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let rhs1 = Term::app("g", vec![Term::var("y")]);
        let lhs2 = Term::app("f", vec![Term::var("a"), Term::var("b")]);
        let rhs2 = Term::app("g", vec![Term::var("a")]);
        assert!(!alpha_equivalent_equation(&lhs1, &rhs1, &lhs2, &rhs2));
    }

    #[test]
    fn match_pattern_var_binds() {
        let pat = Term::var("x");
        let term = Term::app("f", vec![Term::constant("a")]);
        let result = match_pattern(&pat, &term);
        assert!(result.is_some());
        let subst = result.unwrap();
        assert_eq!(subst.get(&Arc::from("x")).unwrap(), &term);
    }

    #[test]
    fn match_pattern_op_matches() {
        let pat = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let term = Term::app("f", vec![Term::constant("a"), Term::constant("b")]);
        let result = match_pattern(&pat, &term);
        assert!(result.is_some());
        let subst = result.unwrap();
        assert_eq!(subst.get(&Arc::from("x")).unwrap(), &Term::constant("a"));
        assert_eq!(subst.get(&Arc::from("y")).unwrap(), &Term::constant("b"));
    }

    #[test]
    fn match_pattern_op_mismatch() {
        let pat = Term::app("f", vec![Term::var("x")]);
        let term = Term::app("g", vec![Term::constant("a")]);
        assert!(match_pattern(&pat, &term).is_none());
    }

    #[test]
    fn match_pattern_repeated_var_consistent() {
        let pat = Term::app("f", vec![Term::var("x"), Term::var("x")]);
        let term = Term::app("f", vec![Term::constant("a"), Term::constant("a")]);
        assert!(match_pattern(&pat, &term).is_some());
    }

    #[test]
    fn match_pattern_repeated_var_inconsistent() {
        let pat = Term::app("f", vec![Term::var("x"), Term::var("x")]);
        let term = Term::app("f", vec![Term::constant("a"), Term::constant("b")]);
        assert!(match_pattern(&pat, &term).is_none());
    }

    /// Builds a rule whose implementation expression is an irrelevant placeholder.
    fn make_directed_eq(name: &str, lhs: Term, rhs: Term) -> DirectedEquation {
        DirectedEquation::new(name, lhs, rhs, panproto_expr::Expr::Var("_".into()))
    }

    #[test]
    fn normalize_no_rules() {
        let term = Term::app("f", vec![Term::var("x")]);
        let result = normalize(&term, &[], 100);
        assert_eq!(result, term);
    }

    #[test]
    fn normalize_simple_rewrite() {
        let rule = make_directed_eq(
            "left_id",
            Term::app("add", vec![Term::constant("zero"), Term::var("y")]),
            Term::var("y"),
        );
        let term = Term::app("add", vec![Term::constant("zero"), Term::var("x")]);
        let result = normalize(&term, &[rule], 100);
        assert_eq!(result, Term::var("x"));
    }

    #[test]
    fn normalize_nested() {
        let rule = make_directed_eq(
            "left_id",
            Term::app("add", vec![Term::constant("zero"), Term::var("y")]),
            Term::var("y"),
        );
        let term = Term::app(
            "f",
            vec![Term::app(
                "add",
                vec![Term::constant("zero"), Term::var("x")],
            )],
        );
        let result = normalize(&term, &[rule], 100);
        assert_eq!(result, Term::app("f", vec![Term::var("x")]));
    }

    #[test]
    fn normalize_multi_step() {
        let rule = make_directed_eq(
            "left_id",
            Term::app("add", vec![Term::constant("zero"), Term::var("y")]),
            Term::var("y"),
        );
        let term = Term::app(
            "add",
            vec![
                Term::constant("zero"),
                Term::app("add", vec![Term::constant("zero"), Term::var("x")]),
            ],
        );
        let result = normalize(&term, &[rule], 100);
        assert_eq!(result, Term::var("x"));
    }

    #[test]
    fn normalize_max_steps_guard() {
        // A non-terminating rule: the step budget must still stop normalization.
        let rule = make_directed_eq(
            "expand",
            Term::app("f", vec![Term::var("x")]),
            Term::app("f", vec![Term::app("f", vec![Term::var("x")])]),
        );
        let term = Term::app("f", vec![Term::constant("a")]);
        let result = normalize(&term, &[rule], 5);
        assert!(matches!(result, Term::App { .. }));
    }

    #[test]
    fn alpha_eq_var_vs_app() {
        let t1 = Term::var("x");
        let t2 = Term::constant("c");
        assert!(!alpha_equivalent(&t1, &t2));
    }

    #[test]
    fn alpha_eq_arity_mismatch() {
        let t1 = Term::app("f", vec![Term::var("x")]);
        let t2 = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        assert!(!alpha_equivalent(&t1, &t2));
    }

    /// Builds a substitution from `(variable, replacement)` pairs.
    fn mk_subst(pairs: &[(&str, Term)]) -> rustc_hash::FxHashMap<Arc<str>, Term> {
        let mut m = rustc_hash::FxHashMap::default();
        for (k, v) in pairs {
            m.insert(Arc::from(*k), v.clone());
        }
        m
    }

    #[test]
    fn compose_subst_agrees_with_sequential_application() {
        let t = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let sigma = mk_subst(&[("x", Term::app("g", vec![Term::var("a")]))]);
        let tau = mk_subst(&[("y", Term::constant("b")), ("a", Term::constant("c"))]);
        let sequential = t.substitute(&sigma).substitute(&tau);
        let composed = t.substitute(&compose_subst(&tau, &sigma));
        assert_eq!(sequential, composed);
    }

    #[test]
    fn substitute_empty_is_identity_unit() {
        let t = Term::app("f", vec![Term::var("x"), Term::constant("c")]);
        let empty = rustc_hash::FxHashMap::default();
        assert_eq!(t.substitute(&empty), t);
    }

    #[test]
    fn compose_subst_empty_left_is_right() {
        let sigma = mk_subst(&[("x", Term::var("y"))]);
        let empty = rustc_hash::FxHashMap::default();
        let composed = compose_subst(&empty, &sigma);
        assert_eq!(composed.get(&Arc::from("x")).unwrap(), &Term::var("y"));
        assert_eq!(composed.len(), 1);
    }

    #[test]
    fn compose_subst_empty_right_is_left() {
        let tau = mk_subst(&[("y", Term::var("z"))]);
        let empty = rustc_hash::FxHashMap::default();
        let composed = compose_subst(&tau, &empty);
        assert_eq!(composed.get(&Arc::from("y")).unwrap(), &Term::var("z"));
        assert_eq!(composed.len(), 1);
    }

    #[test]
    fn alpha_equivalence_transitive() {
        let a = Term::app("f", vec![Term::var("x"), Term::var("y")]);
        let b = Term::app("f", vec![Term::var("u"), Term::var("v")]);
        let c = Term::app("f", vec![Term::var("p"), Term::var("q")]);
        assert!(alpha_equivalent(&a, &b));
        assert!(alpha_equivalent(&b, &c));
        assert!(alpha_equivalent(&a, &c));
    }

    /// `let x = a in f(x, y)` with `y := g(x)` must rename the binder so the
    /// inserted free `x` is not captured.
    #[test]
    fn let_substitute_does_not_capture() {
        let t = Term::Let {
            name: Arc::from("x"),
            bound: Box::new(Term::constant("a")),
            body: Box::new(Term::app("f", vec![Term::var("x"), Term::var("y")])),
        };
        let mut subst = rustc_hash::FxHashMap::default();
        subst.insert(Arc::from("y"), Term::app("g", vec![Term::var("x")]));
        match t.substitute(&subst) {
            Term::Let { name, bound, body } => {
                assert_ne!(&*name, "x", "binder must be renamed away from the captured variable");
                assert_eq!(*bound, Term::constant("a"));
                // The old binder occurrence follows the rename; the inserted `x` stays free.
                assert_eq!(
                    *body,
                    Term::app(
                        "f",
                        vec![
                            Term::Var(Arc::clone(&name)),
                            Term::app("g", vec![Term::var("x")]),
                        ],
                    )
                );
            }
            other => panic!("expected Let result, got {other:?}"),
        }
    }

    mod property {
        use super::*;
        use proptest::prelude::*;

        const VAR_NAMES: &[&str] = &["x", "y", "z", "a", "b"];
        const OP_NAMES: &[&str] = &["f", "g", "h", "add", "mul"];

        fn arb_name() -> impl Strategy<Value = Arc<str>> {
            prop::sample::select(VAR_NAMES).prop_map(Arc::from)
        }

        /// Binder-free terms (variables and applications) of depth at most `max_depth`.
        fn arb_term(max_depth: usize) -> BoxedStrategy<Term> {
            if max_depth == 0 {
                arb_name().prop_map(Term::Var).boxed()
            } else {
                let leaf = arb_name().prop_map(Term::Var);
                let app = (
                    prop::sample::select(OP_NAMES).prop_map(Arc::from),
                    prop::collection::vec(arb_term(max_depth - 1), 0..=3),
                )
                    .prop_map(|(op, args)| Term::App { op, args });
                prop_oneof![leaf, app].boxed()
            }
        }

        /// A uniformly random permutation of `VAR_NAMES`.
        fn arb_permutation() -> impl Strategy<Value = Vec<&'static str>> {
            Just(VAR_NAMES.to_vec()).prop_shuffle()
        }

        /// Simultaneously renames each `VAR_NAMES[i]` to `perm[i]`.
        fn rename_vars(t: &Term, perm: &[&str]) -> Term {
            let subst: rustc_hash::FxHashMap<Arc<str>, Term> = VAR_NAMES
                .iter()
                .zip(perm)
                .map(|(from, to)| (Arc::from(*from), Term::var(*to)))
                .collect();
            t.substitute(&subst)
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(256))]

            #[test]
            fn alpha_equivalence_is_reflexive(t in arb_term(3)) {
                prop_assert!(alpha_equivalent(&t, &t));
            }

            #[test]
            fn alpha_equivalence_is_symmetric(a in arb_term(2), b in arb_term(2)) {
                prop_assert_eq!(
                    alpha_equivalent(&a, &b),
                    alpha_equivalent(&b, &a),
                );
            }

            #[test]
            fn substitute_empty_is_identity(t in arb_term(3)) {
                let empty = rustc_hash::FxHashMap::default();
                prop_assert_eq!(t.substitute(&empty), t);
            }

            #[test]
            fn rename_ops_empty_is_identity(t in arb_term(3)) {
                let empty = std::collections::HashMap::new();
                prop_assert_eq!(t.rename_ops(&empty), t);
            }

            #[test]
            fn rename_ops_preserves_alpha_structure(
                t in arb_term(2),
                src in prop::sample::select(OP_NAMES),
                tgt in prop::sample::select(OP_NAMES),
            ) {
                let mut map = std::collections::HashMap::new();
                map.insert(Arc::from(src), Arc::from(tgt));
                let renamed = t.rename_ops(&map);
                // Renaming operations must leave every variable untouched.
                prop_assert_eq!(t.free_vars(), renamed.free_vars());
            }

            #[test]
            fn substitute_composition_law(
                t in arb_term(3),
                v1 in arb_name(),
                r1 in arb_term(1),
                v2 in arb_name(),
                r2 in arb_term(1),
            ) {
                let sigma = {
                    let mut m = rustc_hash::FxHashMap::default();
                    m.insert(v1, r1);
                    m
                };
                let tau = {
                    let mut m = rustc_hash::FxHashMap::default();
                    m.insert(v2, r2);
                    m
                };
                let sequential = t.substitute(&sigma).substitute(&tau);
                let composed = t.substitute(&compose_subst(&tau, &sigma));
                prop_assert_eq!(sequential, composed);
            }

            /// Two bijective variable renamings chain: t ~ u and u ~ v imply t ~ v.
            #[test]
            fn alpha_equivalence_transitive_prop(
                t in arb_term(3),
                p1 in arb_permutation(),
                p2 in arb_permutation(),
            ) {
                let u = rename_vars(&t, &p1);
                let v = rename_vars(&u, &p2);
                prop_assert!(alpha_equivalent(&t, &u));
                prop_assert!(alpha_equivalent(&u, &v));
                prop_assert!(alpha_equivalent(&t, &v));
            }

            #[test]
            fn free_vars_subset_after_substitution(
                t in arb_term(2),
                var in arb_name(),
                replacement in arb_term(1),
            ) {
                let mut subst = rustc_hash::FxHashMap::default();
                subst.insert(var.clone(), replacement.clone());
                let result = t.substitute(&subst);
                let result_vars = result.free_vars();
                let orig_vars = t.free_vars();
                let repl_vars = replacement.free_vars();
                for v in &result_vars {
                    prop_assert!(
                        (orig_vars.contains(v) && *v != var) || repl_vars.contains(v),
                        "unexpected var {:?} in result",
                        v,
                    );
                }
            }
        }
    }
}