use std::sync::Arc;
use rustc_hash::FxHashMap;
use crate::eq::{DirectedEquation, Term, normalize};
use crate::error::GatError;
use crate::theory::Theory;
/// Result of [`check_local_confluence`]: every critical pair found between the theory's
/// directed equations, whether or not it joins.
#[derive(Debug, Clone)]
pub struct ConfluenceReport {
    /// All critical pairs in discovery order: outer rule, then inner rule, then overlap
    /// position (pre-order over the outer rule's left-hand side).
    pub critical_pairs: Vec<CriticalPair>,
}
/// One overlap between two directed equations.
///
/// Suppose `rule_a` is `l1 -> r1` and `rule_b` is `l2 -> r2`. Suppose also that `l2` unifies with
/// the subterm of `l1` at a non-variable position `p` under substitution `σ`. Then the pair is
/// `(r1σ, l1σ[r2σ]_p)`. Variables that come from `rule_b` appear renamed with a
/// `__<rule_a>_cp` suffix, because `rule_b` is freshened before unification.
#[derive(Debug, Clone)]
pub struct CriticalPair {
    /// Name of the rule whose left-hand side contains the overlap position.
    pub rule_a: Arc<str>,
    /// Name of the rule whose left-hand side was unified into `rule_a`'s at that position.
    pub rule_b: Arc<str>,
    /// `rule_a`'s right-hand side under the unifier.
    pub left: Term,
    /// `rule_a`'s instantiated left-hand side, with `rule_b`'s instantiated right-hand side
    /// written in at the overlap position.
    pub right: Term,
    /// Whether `left` and `right` normalize to equal terms under the caller-supplied
    /// normalization depth bound.
    pub joins: bool,
}
/// Enumerates the critical pairs of every ordered pair of the theory's directed equations
/// (a rule paired with itself included) and records whether each one joins.
///
/// `normalize_depth` bounds the normalization used to decide joinability. A pair reported as
/// non-joining may therefore just need more rewrite steps than the bound allows.
///
/// # Errors
///
/// The current implementation always returns `Ok`.
pub fn check_local_confluence(
    theory: &Theory,
    normalize_depth: usize,
) -> Result<ConfluenceReport, GatError> {
    let rules = &theory.directed_eqs;
    let mut critical_pairs = Vec::new();
    let ordered_pairs = rules
        .iter()
        .flat_map(|outer| rules.iter().map(move |inner| (outer, inner)));
    for (outer, inner) in ordered_pairs {
        collect_critical_pairs_at_positions(
            outer,
            inner,
            rules,
            normalize_depth,
            &mut critical_pairs,
        );
    }
    Ok(ConfluenceReport { critical_pairs })
}
/// Appends to `out` every critical pair formed by unifying `r2`'s (freshened) left-hand side
/// with a non-variable subterm of `r1`'s left-hand side.
///
/// The root overlap of a rule with itself is skipped, because it always yields the trivial pair
/// `(r, r)`. Rules are identified by name for this purpose. Joinability is decided by comparing
/// both sides after `normalize` with the given depth over `rules`.
fn collect_critical_pairs_at_positions(
    r1: &DirectedEquation,
    r2: &DirectedEquation,
    rules: &[DirectedEquation],
    normalize_depth: usize,
    out: &mut Vec<CriticalPair>,
) {
    // Rename r2's variables apart from r1's so the unifier cannot conflate them.
    let renamed = freshen_rule(r2, &r1.name);
    let same_rule = r1.name == r2.name;
    for pos in non_variable_positions(&r1.lhs) {
        if same_rule && pos.is_empty() {
            continue;
        }
        let Some(sigma) = term_at_position(&r1.lhs, &pos)
            .and_then(|subterm| first_order_unify(&subterm, &renamed.lhs))
        else {
            continue;
        };
        let left = r1.rhs.substitute(&sigma);
        let instantiated_lhs = r1.lhs.substitute(&sigma);
        let contractum = renamed.rhs.substitute(&sigma);
        let Ok(right) = replace_at_position(&instantiated_lhs, &pos, contractum) else {
            continue;
        };
        let joins = normalize(&left, rules, normalize_depth)
            == normalize(&right, rules, normalize_depth);
        out.push(CriticalPair {
            rule_a: Arc::clone(&r1.name),
            rule_b: Arc::clone(&r2.name),
            left,
            right,
            joins,
        });
    }
}
fn freshen_rule(r: &DirectedEquation, other_name: &Arc<str>) -> DirectedEquation {
let vars = collect_vars(&r.lhs)
.into_iter()
.chain(collect_vars(&r.rhs))
.collect::<std::collections::BTreeSet<_>>();
let mut rename = FxHashMap::default();
for v in vars {
let fresh: Arc<str> = Arc::from(format!("{v}__{other_name}_cp"));
rename.insert(v, Term::Var(fresh));
}
DirectedEquation {
name: Arc::clone(&r.name),
lhs: r.lhs.substitute(&rename),
rhs: r.rhs.substitute(&rename),
impl_term: r.impl_term.clone(),
inverse: r.inverse.clone(),
source_kind: r.source_kind,
target_kind: r.target_kind,
coercion_class: r.coercion_class,
}
}
/// Appends to `out` each variable name in `term` that is not already present.
///
/// Names are added in pre-order, left-to-right. `Let`-bound names and `Case` binders are not
/// excluded: any `Var` node counts.
fn collect_vars_walk(term: &Term, out: &mut Vec<Arc<str>>) {
    // Children are pushed in reverse so they pop in left-to-right order.
    let mut stack: Vec<&Term> = vec![term];
    while let Some(node) = stack.pop() {
        match node {
            Term::Var(name) => {
                if !out.iter().any(|seen| seen == name) {
                    out.push(Arc::clone(name));
                }
            }
            Term::Hole { .. } => {}
            Term::Let { bound, body, .. } => {
                stack.push(body);
                stack.push(bound);
            }
            Term::App { args, .. } => stack.extend(args.iter().rev()),
            Term::Case {
                scrutinee,
                branches,
            } => {
                stack.extend(branches.iter().rev().map(|branch| &branch.body));
                stack.push(scrutinee);
            }
        }
    }
}
/// Returns the distinct variable names of `term` in first-occurrence order.
fn collect_vars(term: &Term) -> Vec<Arc<str>> {
    let mut names = Vec::new();
    collect_vars_walk(term, &mut names);
    names
}
/// Records the path of every `App` node reachable from `term` through `App` arguments.
///
/// Paths are recorded in pre-order, with `path` as the prefix. `Case`/`Let` nodes are neither
/// recorded nor descended into, which matches the `App`-only reach of `first_order_unify`.
fn non_variable_positions_walk(term: &Term, path: &mut Vec<usize>, out: &mut Vec<Vec<usize>>) {
    let Term::App { args, .. } = term else {
        return;
    };
    out.push(path.clone());
    for (idx, child) in args.iter().enumerate() {
        path.push(idx);
        non_variable_positions_walk(child, path, out);
        path.pop();
    }
}
/// Returns the argument-index paths of all `App` nodes in `term`, in pre-order.
/// The root appears as the empty path.
fn non_variable_positions(term: &Term) -> Vec<Vec<usize>> {
    let mut positions = Vec::new();
    let mut scratch = Vec::new();
    non_variable_positions_walk(term, &mut scratch, &mut positions);
    positions
}
/// Returns a clone of the subterm of `term` at path `pos`, or `None` if the path does not exist.
///
/// Path indices are interpreted per node kind:
/// - `App`: argument index.
/// - `Case`: `0` is the scrutinee; `i + 1` is branch `i`'s body.
/// - `Let`: `0` is the bound expression; `1` is the body.
///
/// `Var` and `Hole` have no children. The walk is done by reference, so only the final subterm
/// is cloned rather than the whole term plus every intermediate node.
fn term_at_position(term: &Term, pos: &[usize]) -> Option<Term> {
    let mut cur = term;
    for &i in pos {
        cur = match cur {
            Term::App { args, .. } => args.get(i)?,
            Term::Case {
                scrutinee,
                branches,
            } => {
                if i == 0 {
                    &**scrutinee
                } else {
                    &branches.get(i - 1)?.body
                }
            }
            Term::Let { bound, body, .. } => match i {
                0 => &**bound,
                1 => &**body,
                _ => return None,
            },
            Term::Var(_) | Term::Hole { .. } => return None,
        };
    }
    Some(cur.clone())
}
/// Returns a copy of `term` with the subterm at path `pos` replaced by `replacement`.
///
/// Path indices follow the same convention as [`term_at_position`]:
/// - `App`: argument index.
/// - `Case`: `0` is the scrutinee; `i + 1` is branch `i`'s body.
/// - `Let`: `0` is the bound expression; `1` is the body.
///
/// An empty path replaces the whole term. Only nodes along the path are rebuilt, and each
/// untouched sibling is cloned once. The previous version deep-cloned the entire remaining
/// subterm at every recursion level.
///
/// # Errors
///
/// Returns [`GatError::InvalidRewritePosition`] when the path leaves the term: an out-of-range
/// index, a `Let` index other than 0/1, or descending into a `Var`/`Hole`. Its `path` is the
/// path suffix remaining at the node where descent failed, not the full original path.
fn replace_at_position(term: &Term, pos: &[usize], replacement: Term) -> Result<Term, GatError> {
    let Some((&head, rest)) = pos.split_first() else {
        return Ok(replacement);
    };
    let invalid = |node_kind: &'static str| GatError::InvalidRewritePosition {
        path: pos.to_vec(),
        node_kind,
    };
    match term {
        Term::App { op, args } => {
            let slot = args
                .get(head)
                .ok_or_else(|| invalid("App (index out of range)"))?;
            let replaced = replace_at_position(slot, rest, replacement)?;
            // `head < args.len()` was checked above, so both slices are in bounds.
            let mut new_args = Vec::with_capacity(args.len());
            new_args.extend_from_slice(&args[..head]);
            new_args.push(replaced);
            new_args.extend_from_slice(&args[head + 1..]);
            Ok(Term::App {
                op: Arc::clone(op),
                args: new_args,
            })
        }
        Term::Case {
            scrutinee,
            branches,
        } => {
            if head == 0 {
                let new_scrut = replace_at_position(scrutinee, rest, replacement)?;
                return Ok(Term::Case {
                    scrutinee: Box::new(new_scrut),
                    branches: branches.clone(),
                });
            }
            let branch_idx = head - 1;
            let branch = branches
                .get(branch_idx)
                .ok_or_else(|| invalid("Case (branch index out of range)"))?;
            let new_body = replace_at_position(&branch.body, rest, replacement)?;
            // Constructor and binders are kept; only the body changes.
            let mut new_branches = branches.clone();
            new_branches[branch_idx].body = new_body;
            Ok(Term::Case {
                scrutinee: scrutinee.clone(),
                branches: new_branches,
            })
        }
        Term::Let { name, bound, body } => match head {
            0 => Ok(Term::Let {
                name: Arc::clone(name),
                bound: Box::new(replace_at_position(bound, rest, replacement)?),
                body: body.clone(),
            }),
            1 => Ok(Term::Let {
                name: Arc::clone(name),
                bound: bound.clone(),
                body: Box::new(replace_at_position(body, rest, replacement)?),
            }),
            _ => Err(invalid("Let (only indices 0 and 1 are valid)")),
        },
        Term::Var(_) => Err(invalid("Var")),
        Term::Hole { .. } => Err(invalid("Hole")),
    }
}
/// Syntactic first-order unification of `left` and `right`.
///
/// Returns a substitution (variable name to term) or `None` if the terms do not unify. Only
/// `Var` and `App` nodes are handled. Any pair involving `Case`, `Let` or `Hole` that is not
/// already resolved by a variable binding fails. The occurs check rejects cyclic bindings.
/// Each new binding is propagated into existing ones, so the result is idempotent.
fn first_order_unify(left: &Term, right: &Term) -> Option<FxHashMap<Arc<str>, Term>> {
    let mut bindings: FxHashMap<Arc<str>, Term> = FxHashMap::default();
    let mut pending: Vec<(Term, Term)> = vec![(left.clone(), right.clone())];
    while let Some((a, b)) = pending.pop() {
        match (apply(&a, &bindings), apply(&b, &bindings)) {
            (Term::Var(x), Term::Var(y)) if x == y => {}
            (Term::Var(var), other) | (other, Term::Var(var)) => {
                if occurs(&var, &other) {
                    return None;
                }
                let mut step = FxHashMap::default();
                step.insert(Arc::clone(&var), other.clone());
                bindings
                    .values_mut()
                    .for_each(|bound| *bound = bound.substitute(&step));
                bindings.insert(var, other);
            }
            (Term::App { op: f, args: fs }, Term::App { op: g, args: gs }) => {
                if f != g || fs.len() != gs.len() {
                    return None;
                }
                pending.extend(fs.into_iter().zip(gs));
            }
            _ => return None,
        }
    }
    Some(bindings)
}
/// Applies `subst` to `t`. An empty substitution short-circuits to a plain clone.
fn apply(t: &Term, subst: &FxHashMap<Arc<str>, Term>) -> Term {
    if !subst.is_empty() {
        return t.substitute(subst);
    }
    t.clone()
}
/// Returns `true` if variable `var` occurs in `t`.
///
/// A `Let` that binds `var` shadows it in its body, but its bound expression is still searched.
/// `Case` branch binders are not consulted.
fn occurs(var: &Arc<str>, t: &Term) -> bool {
    match t {
        Term::Var(v) => v == var,
        Term::Hole { .. } => false,
        Term::App { args, .. } => args.iter().any(|arg| occurs(var, arg)),
        Term::Let { name, bound, body } => {
            if occurs(var, bound) {
                return true;
            }
            name != var && occurs(var, body)
        }
        Term::Case {
            scrutinee,
            branches,
        } => {
            occurs(var, scrutinee) || branches.iter().any(|branch| occurs(var, &branch.body))
        }
    }
}
/// A precedence on operation symbols used by the lexicographic path ordering.
///
/// Names later in `order` rank higher. Names that do not appear are incomparable.
#[derive(Debug, Clone, Default)]
pub struct OpPrecedence {
    /// Operation names listed from lowest to highest precedence.
    pub order: Vec<Arc<str>>,
}
impl OpPrecedence {
    /// Builds a precedence from operation names listed lowest first.
    #[must_use]
    pub fn new<I, S>(order: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<Arc<str>>,
    {
        let mut ranked: Vec<Arc<str>> = Vec::new();
        for name in order {
            ranked.push(name.into());
        }
        Self { order: ranked }
    }
    /// Compares the ranks of `a` and `b`.
    ///
    /// Returns `None` if either name is absent. If a name is listed more than once, its first
    /// occurrence determines its rank.
    #[must_use]
    pub fn compare(&self, a: &Arc<str>, b: &Arc<str>) -> Option<std::cmp::Ordering> {
        let rank = |name: &Arc<str>| self.order.iter().position(|n| n == name);
        let (rank_a, rank_b) = (rank(a)?, rank(b)?);
        Some(rank_a.cmp(&rank_b))
    }
}
/// Result of [`check_termination_via_lpo`]: the directed equations that the lexicographic path
/// ordering failed to orient.
#[derive(Debug, Clone)]
pub struct TerminationReport {
    /// One entry per rule whose left-hand side is not `>_lpo` its right-hand side, in the
    /// theory's rule order.
    pub violations: Vec<RuleViolation>,
}
/// A rule that the LPO check could not orient.
#[derive(Debug, Clone)]
pub struct RuleViolation {
    /// Name of the offending directed equation.
    pub rule: Arc<str>,
    /// Human-readable explanation that shows both sides of the rule.
    pub reason: String,
}
/// Checks that every directed equation of `theory` is decreasing, i.e. `lhs >_lpo rhs` under
/// `precedence`.
///
/// Rules that fail are collected as [`RuleViolation`]s rather than reported as errors.
///
/// # Errors
///
/// Returns [`GatError::LpoHoleInRule`] for the first rule, in theory order, whose left- or
/// right-hand side contains a `Hole`.
pub fn check_termination_via_lpo(
    theory: &Theory,
    precedence: &OpPrecedence,
) -> Result<TerminationReport, GatError> {
    let violations = theory
        .directed_eqs
        .iter()
        .filter_map(|rule| {
            if contains_hole(&rule.lhs) || contains_hole(&rule.rhs) {
                return Some(Err(GatError::LpoHoleInRule {
                    rule: rule.name.to_string(),
                }));
            }
            let oriented = lpo_greater(&rule.lhs, &rule.rhs, precedence);
            (!oriented).then(|| {
                Ok(RuleViolation {
                    rule: Arc::clone(&rule.name),
                    reason: format!("lhs {} is not >_lpo rhs {}", rule.lhs, rule.rhs),
                })
            })
        })
        // Collecting into `Result` stops at the first hole error, as the original loop did.
        .collect::<Result<Vec<_>, _>>()?;
    Ok(TerminationReport { violations })
}
/// Returns `true` if any node of `t`, including those under binders, is a `Hole`.
fn contains_hole(t: &Term) -> bool {
    match t {
        Term::Hole { .. } => true,
        Term::Var(_) => false,
        Term::App { args, .. } => args.iter().any(contains_hole),
        Term::Let { bound, body, .. } => [bound, body].into_iter().any(|child| contains_hole(child)),
        Term::Case {
            scrutinee,
            branches,
        } => contains_hole(scrutinee) || branches.iter().map(|b| &b.body).any(contains_hole),
    }
}
/// Decides whether `s >_lpo t` in the lexicographic path ordering induced by `prec`.
///
/// Let `s_i` be the immediate subterms of `s` (see `subterms`). The function returns `true` when
/// one of the following holds:
/// 1. Some `s_i` equals `t`, or `s_i >_lpo t`.
/// 2. `t` is a variable that occurs in some `s_i`.
/// 3. The head of `s` is strictly above the head of `t` in `prec`, and `s >_lpo t_j` for every
///    immediate subterm `t_j` of `t`.
/// 4. `s` and `t` have the same head, `s >_lpo t_j` for every `t_j`, and the subterm lists of
///    `s` and `t` compare lexicographically greater (via [`lex_greater`]).
///
/// Case 4 applies whenever the heads are identical, even if that symbol is absent from `prec`.
/// Comparing arguments lexicographically needs no precedence entry. Distinct heads that `prec`
/// cannot compare yield `false`. `Case`/`Let` nodes use synthetic heads (`__case__`/`__let__`).
/// A variable or hole on the left is never greater than anything.
#[must_use]
pub fn lpo_greater(s: &Term, t: &Term, prec: &OpPrecedence) -> bool {
    if matches!(s, Term::Var(_) | Term::Hole { .. }) {
        return false;
    }
    let s_subs = subterms(s);
    // Case 1: some immediate subterm already dominates (or equals) t.
    if s_subs.iter().any(|si| si == t || lpo_greater(si, t, prec)) {
        return true;
    }
    match t {
        Term::Var(x) => s_subs.iter().any(|a| contains_var(a, x)),
        Term::Hole { .. } => false,
        _ => {
            let (Some(f), Some(g)) = (structural_head(s), structural_head(t)) else {
                return false;
            };
            // Identical heads need no precedence entry: go straight to the lexicographic case.
            let ordering = if f == g {
                Some(std::cmp::Ordering::Equal)
            } else {
                prec.compare(&f, &g)
            };
            let t_subs = subterms(t);
            let dominates_all_args = || t_subs.iter().all(|tj| lpo_greater(s, tj, prec));
            match ordering {
                Some(std::cmp::Ordering::Greater) => dominates_all_args(),
                Some(std::cmp::Ordering::Equal) => {
                    dominates_all_args() && lex_greater(&s_subs, &t_subs, prec)
                }
                _ => false,
            }
        }
    }
}
/// Returns owned clones of the immediate children of `t`, in the order LPO compares them.
///
/// - `App`: its arguments.
/// - `Case`: the scrutinee followed by each branch body (binders are ignored).
/// - `Let`: the bound expression, then the body.
/// - `Var` / `Hole`: empty.
fn subterms(t: &Term) -> Vec<Term> {
    match t {
        Term::App { args, .. } => args.clone(),
        Term::Let { bound, body, .. } => vec![Term::clone(bound), Term::clone(body)],
        Term::Case {
            scrutinee,
            branches,
        } => std::iter::once(Term::clone(scrutinee))
            .chain(branches.iter().map(|branch| branch.body.clone()))
            .collect(),
        Term::Var(_) | Term::Hole { .. } => Vec::new(),
    }
}
/// Returns the symbol LPO uses as the head of `t`.
///
/// This is the operation name for `App`, the synthetic names `__case__` / `__let__` for
/// `Case` / `Let`, and `None` for variables and holes.
fn structural_head(t: &Term) -> Option<Arc<str>> {
    let head: Arc<str> = match t {
        Term::App { op, .. } => Arc::clone(op),
        Term::Case { .. } => "__case__".into(),
        Term::Let { .. } => "__let__".into(),
        Term::Var(_) | Term::Hole { .. } => return None,
    };
    Some(head)
}
fn contains_var(t: &Term, var: &Arc<str>) -> bool {
match t {
Term::Var(v) => v == var,
Term::Hole { .. } => false,
Term::Let { name, bound, body } => {
contains_var(bound, var) || (name != var && contains_var(body, var))
}
Term::App { args, .. } => args.iter().any(|a| contains_var(a, var)),
Term::Case {
scrutinee,
branches,
} => contains_var(scrutinee, var) || branches.iter().any(|b| contains_var(&b.body, var)),
}
}
/// Lexicographic extension of `>_lpo` to argument lists.
///
/// The first position where `a` and `b` differ decides the result. If one list is a prefix of
/// the other, the longer list is greater.
fn lex_greater(a: &[Term], b: &[Term], prec: &OpPrecedence) -> bool {
    match a.iter().zip(b).find(|(x, y)| x != y) {
        Some((x, y)) => lpo_greater(x, y, prec),
        None => a.len() > b.len(),
    }
}
impl ConfluenceReport {
    /// `true` when no reported critical pair fails to join.
    ///
    /// Joinability was decided with the depth-bounded normalization the report was built with.
    #[must_use]
    pub fn is_locally_confluent(&self) -> bool {
        !self.critical_pairs.iter().any(|pair| !pair.joins)
    }
    /// Returns the critical pairs whose two sides did not normalize to the same term,
    /// in report order.
    #[must_use]
    pub fn non_joining(&self) -> Vec<&CriticalPair> {
        let mut failing = Vec::new();
        for pair in &self.critical_pairs {
            if !pair.joins {
                failing.push(pair);
            }
        }
        failing
    }
}
impl TerminationReport {
    /// `true` when the LPO oriented every rule, i.e. no violation was recorded.
    #[must_use]
    pub fn is_lpo_terminating(&self) -> bool {
        matches!(self.violations.as_slice(), [])
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eq::{DirectedEquation, Term};
    use crate::op::Operation;
    use crate::sort::Sort;
    use crate::theory::Theory;
    use panproto_expr::Expr;
    /// Builds a directed equation `name: lhs -> rhs`.
    /// The implementation expression is a placeholder (`Expr::Var("_")`).
    fn mk_rule(name: &str, lhs: Term, rhs: Term) -> DirectedEquation {
        DirectedEquation::new(name, lhs, rhs, Expr::Var("_".into()))
    }
    /// Builds theory `T` with one sort `S` and these operations over it: `zero` (nullary),
    /// `succ`, `f`, `g`, `h` (unary) and `add` (binary). `rules` are its directed equations.
    /// The remaining `Theory::full` arguments are left empty.
    fn trivial_theory(rules: Vec<DirectedEquation>) -> Theory {
        Theory::full(
            "T",
            Vec::new(),
            vec![Sort::simple("S")],
            vec![
                Operation::nullary("zero", "S"),
                Operation::unary("succ", "n", "S", "S"),
                Operation::unary("f", "x", "S", "S"),
                Operation::unary("g", "x", "S", "S"),
                Operation::unary("h", "x", "S", "S"),
                Operation::new(
                    "add",
                    vec![
                        (Arc::from("a"), crate::sort::SortExpr::from("S")),
                        (Arc::from("b"), crate::sort::SortExpr::from("S")),
                    ],
                    "S",
                ),
            ],
            Vec::new(),
            rules,
            Vec::new(),
        )
    }
    /// Rules `f(x) -> g(x)` and `h(y) -> f(y)`: whatever critical pairs are reported
    /// (possibly none) must all join.
    #[test]
    fn confluence_reports_trivially_joinable_pair() -> Result<(), GatError> {
        let r1 = mk_rule(
            "a",
            Term::app("f", vec![Term::var("x")]),
            Term::app("g", vec![Term::var("x")]),
        );
        let r2 = mk_rule(
            "b",
            Term::app("h", vec![Term::var("y")]),
            Term::app("f", vec![Term::var("y")]),
        );
        let theory = trivial_theory(vec![r1, r2]);
        let report = check_local_confluence(&theory, 32)?;
        for cp in &report.critical_pairs {
            assert!(cp.joins, "unexpected non-joining pair: {cp:?}");
        }
        Ok(())
    }
    /// `f(g(x)) -> x` overlaps `g(y) -> h(y)` at position `[0]`. This gives the pair
    /// `(x, f(h(x)))`; no rule rewrites `f(h(_))`, so the pair must be reported as
    /// non-joining with `rule_a = r1`, `rule_b = r2`.
    #[test]
    fn confluence_reports_known_critical_pair() -> Result<(), GatError> {
        let r1 = mk_rule(
            "r1",
            Term::app("f", vec![Term::app("g", vec![Term::var("x")])]),
            Term::var("x"),
        );
        let r2 = mk_rule(
            "r2",
            Term::app("g", vec![Term::var("y")]),
            Term::app("h", vec![Term::var("y")]),
        );
        let theory = trivial_theory(vec![r1, r2]);
        let report = check_local_confluence(&theory, 32)?;
        assert!(
            report
                .critical_pairs
                .iter()
                .any(|cp| !cp.joins && &*cp.rule_a == "r1" && &*cp.rule_b == "r2"),
            "expected non-joining critical pair between r1 and r2, got {:?}",
            report.critical_pairs
        );
        Ok(())
    }
    /// Two rules with the same left-hand side `f(x)` but different right-hand sides overlap
    /// at the root. This yields `g(x)` vs `h(x)`, which cannot be joined.
    #[test]
    fn confluence_detects_non_joinable() -> Result<(), GatError> {
        let r1 = mk_rule(
            "r1",
            Term::app("f", vec![Term::var("x")]),
            Term::app("g", vec![Term::var("x")]),
        );
        let r2 = mk_rule(
            "r2",
            Term::app("f", vec![Term::var("x")]),
            Term::app("h", vec![Term::var("x")]),
        );
        let theory = trivial_theory(vec![r1, r2]);
        let report = check_local_confluence(&theory, 32)?;
        assert!(!report.is_locally_confluent());
        assert!(!report.non_joining().is_empty());
        Ok(())
    }
    /// `add(zero, y) -> y` is oriented by LPO because the right-hand side is an immediate
    /// subterm of the left-hand side.
    #[test]
    fn lpo_accepts_decreasing_rule() -> Result<(), GatError> {
        let r = mk_rule(
            "left_id",
            Term::app("add", vec![Term::constant("zero"), Term::var("y")]),
            Term::var("y"),
        );
        let theory = trivial_theory(vec![r]);
        let prec = OpPrecedence::new(["zero", "add"]);
        let report = check_termination_via_lpo(&theory, &prec)?;
        assert!(
            report.is_lpo_terminating(),
            "got violations {:?}",
            report.violations,
        );
        Ok(())
    }
    /// Each single-index path into a three-argument `App` selects the matching argument.
    #[test]
    fn term_at_position_preserves_index_in_three_arg_app() {
        let term = Term::app(
            "f",
            vec![
                Term::constant("a"),
                Term::constant("b"),
                Term::constant("c"),
            ],
        );
        assert_eq!(term_at_position(&term, &[0]), Some(Term::constant("a")),);
        assert_eq!(term_at_position(&term, &[1]), Some(Term::constant("b")),);
        assert_eq!(term_at_position(&term, &[2]), Some(Term::constant("c")),);
    }
    /// Path `[1]` on a `Let` addresses its body; replacing it leaves a `Let` with the new body.
    #[test]
    fn replace_at_position_inside_let_body() {
        let t = Term::Let {
            name: Arc::from("x"),
            bound: Box::new(Term::constant("a")),
            body: Box::new(Term::app("f", vec![Term::var("x")])),
        };
        let result = replace_at_position(&t, &[1], Term::app("g", vec![Term::var("x")]));
        match result {
            Ok(Term::Let { body, .. }) => {
                assert_eq!(*body, Term::app("g", vec![Term::var("x")]));
            }
            other => panic!("expected Ok(Let), got {other:?}"),
        }
    }
    /// Descending into a variable is an invalid rewrite position.
    #[test]
    fn replace_at_position_invalid_var_errors() {
        let t = Term::var("x");
        let err = replace_at_position(&t, &[0], Term::constant("a"));
        assert!(matches!(err, Err(GatError::InvalidRewritePosition { .. })));
    }
    /// A `Hole` anywhere in a rule makes the LPO check fail with `LpoHoleInRule` rather than
    /// report a violation.
    #[test]
    fn lpo_rejects_hole_in_rule() {
        let r = mk_rule(
            "bad",
            Term::app("f", vec![Term::Hole { name: None }]),
            Term::constant("a"),
        );
        let theory = trivial_theory(vec![r]);
        let prec = OpPrecedence::new(["a", "f"]);
        let err = check_termination_via_lpo(&theory, &prec);
        assert!(matches!(err, Err(GatError::LpoHoleInRule { .. })));
    }
    /// A rule whose left-hand side is a bare variable is never LPO-decreasing.
    #[test]
    fn lpo_rejects_increasing_rule() -> Result<(), GatError> {
        let r = mk_rule(
            "expand",
            Term::var("y"),
            Term::app("f", vec![Term::var("y")]),
        );
        let theory = trivial_theory(vec![r]);
        let prec = OpPrecedence::new(["f"]);
        let report = check_termination_via_lpo(&theory, &prec)?;
        assert!(!report.is_lpo_terminating());
        Ok(())
    }
    mod property {
        use super::*;
        use proptest::prelude::*;
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(128))]
            /// For arbitrary (deduplicated) precedences over `f`, `g`, `h`, `lpo_greater` gives
            /// the same answer when called twice on the same inputs.
            #[test]
            fn lpo_result_is_stable_under_recomputation(
                order in prop::collection::vec(
                    prop::sample::select(&["f", "g", "h"][..]).prop_map(Arc::from),
                    0..=3,
                ),
            ) {
                // Keep only the first occurrence of each name, preserving order.
                let mut seen = std::collections::BTreeSet::new();
                let mut trimmed: Vec<Arc<str>> = Vec::new();
                for n in order {
                    if seen.insert(Arc::clone(&n)) {
                        trimmed.push(n);
                    }
                }
                let prec = OpPrecedence { order: trimmed };
                let s = Term::app("f", vec![Term::app("g", vec![Term::var("x")])]);
                let t = Term::app("h", vec![Term::var("x")]);
                let a = lpo_greater(&s, &t, &prec);
                let b = lpo_greater(&s, &t, &prec);
                prop_assert_eq!(a, b);
            }
        }
    }
}