Skip to main content

panproto_gat/
typecheck.rs

//! Term type-checking for GAT expressions.
//!
//! Verifies that terms are well-typed with respect to a theory's operation
//! signatures. Each operation is allowed to produce a dependent output
//! sort whose parameters reference the input argument terms: when an
//! operation `f : (x1 : S1, ..., xn : Sn) -> T` is applied to concrete
//! terms `a1, ..., an`, the result sort is `T[x1 := a1, ..., xn := an]`
//! (Cartmell-style substitution into the sort expression).
9
10use std::sync::Arc;
11
12use rustc_hash::FxHashMap;
13
14use crate::eq::{CaseBranch, Equation, Term};
15use crate::error::GatError;
16use crate::op::{Implicit, Operation};
17use crate::sort::{SortClosure, SortExpr};
18use crate::theory::Theory;
19
20/// A sort scheme: a sort expression universally quantified over a list
21/// of metavariable names.
22///
23/// Retained for future extension. GAT signatures are first-order and
24/// have no free sort-metavariables, so [`typecheck_term`] never
25/// constructs a non-empty `metavars` list at the present time; every
26/// scheme encountered in practice is monomorphic.
27#[derive(Debug, Clone)]
28pub struct SortScheme {
29    /// Universally-quantified metavariable names.
30    pub metavars: Vec<Arc<str>>,
31    /// The scheme's body.
32    pub body: SortExpr,
33}
34
35impl SortScheme {
36    /// Build a monomorphic scheme from a sort expression.
37    #[must_use]
38    pub const fn mono(body: SortExpr) -> Self {
39        Self {
40            metavars: Vec::new(),
41            body,
42        }
43    }
44
45    /// Instantiate the scheme by freshening each metavariable and
46    /// substituting the fresh name into the scheme's body.
47    ///
48    /// Fresh names are derived by suffixing `counter` to each
49    /// metavariable, so two calls with different counters produce
50    /// distinct sort expressions.
51    #[must_use]
52    pub fn instantiate(&self, counter: usize) -> SortExpr {
53        if self.metavars.is_empty() {
54            return self.body.clone();
55        }
56        let mut subst: FxHashMap<Arc<str>, crate::eq::Term> = FxHashMap::default();
57        for mv in &self.metavars {
58            let fresh: Arc<str> = Arc::from(format!("{mv}_inst_{counter}"));
59            subst.insert(Arc::clone(mv), crate::eq::Term::Var(fresh));
60        }
61        self.body.subst(&subst)
62    }
63}
64
65/// A report generated for each typed hole encountered by
66/// [`typecheck_term_with_holes`].
67#[derive(Debug, Clone)]
68pub struct HoleReport {
69    /// The hole's optional name (from `?name` syntax).
70    pub name: Option<Arc<str>>,
71    /// The sort expected at this hole site. For holes whose sort is not
72    /// constrained by the surrounding context, this is a fresh
73    /// metavariable sort.
74    pub expected: SortExpr,
75    /// The variable context in scope at the hole.
76    pub context: VarContext,
77    /// Optional source span (populated by the DSL when available).
78    pub position: Option<miette::SourceSpan>,
79}
80
81/// A variable typing context.
82///
83/// Maps variable names to their sort expressions. Sort expressions may
84/// themselves reference variables (as `Term::Var` nodes under a
85/// `SortExpr::App`), which is the load-bearing feature that distinguishes
86/// GATs from many-sorted equational theories.
87pub type VarContext = FxHashMap<Arc<str>, SortExpr>;
88
89/// Infer the output sort expression of a term given a variable context
90/// and theory.
91///
92/// For `Var(x)`: returns `ctx[x]` or [`GatError::UnboundVariable`].
93/// For `App { op, args }`: looks up `op` in the theory, recursively
94/// typechecks each argument, and for each argument position `i` compares
95/// the argument's inferred sort against `op.inputs[i].1.subst(θ)` where
96/// θ is the running substitution from parameter names to argument terms.
97/// The returned sort is `op.output.subst(θ)`.
98///
99/// # Errors
100///
101/// Returns an error if:
102/// - A variable is not in the context ([`GatError::UnboundVariable`]).
103/// - An operation is not in the theory ([`GatError::OpNotFound`]).
104/// - Argument count does not match ([`GatError::TermArityMismatch`]).
105/// - An argument's sort does not alpha-equal the expected input sort
106///   under the running substitution ([`GatError::ArgTypeMismatch`]).
107pub fn typecheck_term(
108    term: &Term,
109    ctx: &VarContext,
110    theory: &Theory,
111) -> Result<SortExpr, GatError> {
112    match term {
113        Term::Var(name) => ctx
114            .get(name)
115            .cloned()
116            .ok_or_else(|| GatError::UnboundVariable(name.to_string())),
117
118        Term::Hole { name } => {
119            // In strict mode (no hole collector), a hole typechecks to a
120            // fresh metavariable sort whose name is `?<name>` (or
121            // `?hole` for anonymous holes). This string does not
122            // round-trip through the DSL parser: callers that serialize
123            // a term for re-parsing must use `typecheck_term_with_holes`
124            // to collect hole reports and fill in their expected sorts
125            // before emitting. Use `typecheck_term_with_holes` to
126            // collect hole reports.
127            let mv: Arc<str> = Arc::from(format!("?{}", name.as_deref().unwrap_or("hole")));
128            Ok(SortExpr::Name(mv))
129        }
130
131        Term::App { op, args } => {
132            let operation = theory
133                .find_op(op)
134                .ok_or_else(|| GatError::OpNotFound(op.to_string()))?;
135
136            let has_implicits = operation
137                .inputs
138                .iter()
139                .any(|(_, _, imp)| matches!(imp, Implicit::Yes));
140            if has_implicits {
141                typecheck_app_with_implicits(op, args, operation, ctx, theory)
142            } else {
143                typecheck_app_explicit(op, args, operation, ctx, theory)
144            }
145        }
146
147        Term::Case {
148            scrutinee,
149            branches,
150        } => typecheck_case(scrutinee, branches, ctx, theory),
151
152        Term::Let { name, bound, body } => {
153            // Typecheck the bound term, extend the context with the
154            // resulting sort, then typecheck the body. GAT sorts are
155            // first-order with no free sort-metavariables, so there is
156            // nothing to generalize over: the binding is always
157            // monomorphic and the [`SortScheme`] produced at a
158            // hypothetical generalization step always has an empty
159            // `metavars` list. We therefore bind the inferred sort
160            // directly rather than constructing a trivial scheme.
161            let bound_sort = typecheck_term(bound, ctx, theory)?;
162            let mut extended = ctx.clone();
163            extended.insert(Arc::clone(name), bound_sort);
164            typecheck_term(body, &extended, theory)
165        }
166    }
167}
168
169/// Typecheck a [`Term::Case`] expression.
170fn typecheck_case(
171    scrutinee: &Term,
172    branches: &[CaseBranch],
173    ctx: &VarContext,
174    theory: &Theory,
175) -> Result<SortExpr, GatError> {
176    let scrutinee_sort = typecheck_term(scrutinee, ctx, theory)?;
177    let sort_name = scrutinee_sort.head();
178    let sort_decl = theory
179        .find_sort(sort_name)
180        .ok_or_else(|| GatError::SortNotFound(sort_name.to_string()))?;
181    let constructors = match &sort_decl.closure {
182        SortClosure::Open => {
183            return Err(GatError::CaseOnOpenSort {
184                sort: sort_name.to_string(),
185            });
186        }
187        SortClosure::Closed(cs) => cs.clone(),
188    };
189
190    // Exhaustiveness: every constructor appears exactly once, and no
191    // branch names an unknown constructor. Build both checks in one
192    // pass.
193    let mut seen: rustc_hash::FxHashSet<Arc<str>> = rustc_hash::FxHashSet::default();
194    for b in branches {
195        if !constructors.contains(&b.constructor) {
196            return Err(GatError::UnknownCaseConstructor {
197                sort: sort_name.to_string(),
198                constructor: b.constructor.to_string(),
199            });
200        }
201        if !seen.insert(Arc::clone(&b.constructor)) {
202            return Err(GatError::RedundantCaseBranch {
203                sort: sort_name.to_string(),
204                constructor: b.constructor.to_string(),
205            });
206        }
207    }
208    if seen.len() < constructors.len() {
209        let missing: Vec<String> = constructors
210            .iter()
211            .filter(|c| !seen.contains(*c))
212            .map(ToString::to_string)
213            .collect();
214        return Err(GatError::NonExhaustiveCase {
215            sort: sort_name.to_string(),
216            missing,
217        });
218    }
219
220    // Typecheck each branch body; all must produce alpha-equivalent
221    // output sorts. The first branch's output sort is the case
222    // expression's sort.
223    let mut branch_sort: Option<SortExpr> = None;
224    for b in branches {
225        let constructor_op = theory
226            .find_op(&b.constructor)
227            .ok_or_else(|| GatError::OpNotFound(b.constructor.to_string()))?;
228        if constructor_op.inputs.len() != b.binders.len() {
229            return Err(GatError::TermArityMismatch {
230                op: b.constructor.to_string(),
231                expected: constructor_op.inputs.len(),
232                got: b.binders.len(),
233            });
234        }
235        // Unify constructor's declared output sort with the actual
236        // scrutinee sort to instantiate the constructor's parameter
237        // vars.
238        let unify_eqs: Vec<(Term, Term)> = constructor_op
239            .output
240            .args()
241            .iter()
242            .zip(scrutinee_sort.args().iter())
243            .map(|(a, b)| (a.clone(), b.clone()))
244            .collect();
245        if constructor_op.output.head() != scrutinee_sort.head()
246            || constructor_op.output.args().len() != scrutinee_sort.args().len()
247        {
248            return Err(GatError::OpTypeMismatch {
249                op: b.constructor.to_string(),
250                detail: format!(
251                    "constructor output sort {} does not match scrutinee sort {scrutinee_sort}",
252                    constructor_op.output
253                ),
254            });
255        }
256        let subst = unify_all(unify_eqs)?;
257        let mut extended = ctx.clone();
258        for ((_, declared_sort, _), binder) in constructor_op.inputs.iter().zip(b.binders.iter()) {
259            let binder_sort = declared_sort.subst(&subst);
260            extended.insert(Arc::clone(binder), binder_sort);
261        }
262        let body_sort = typecheck_term(&b.body, &extended, theory)?;
263        match &branch_sort {
264            None => branch_sort = Some(body_sort),
265            Some(existing) => {
266                if !existing.alpha_eq(&body_sort) {
267                    return Err(GatError::EquationSortMismatch {
268                        equation: "case".to_string(),
269                        lhs_sort: existing.to_string(),
270                        rhs_sort: body_sort.to_string(),
271                    });
272                }
273            }
274        }
275    }
276
277    branch_sort.ok_or_else(|| GatError::NonExhaustiveCase {
278        sort: sort_name.to_string(),
279        missing: constructors.iter().map(ToString::to_string).collect(),
280    })
281}
282
283/// Typecheck an `App` against an operation with every input marked
284/// explicit. Uses the existing sequential-theta-propagation path: each
285/// argument's inferred sort must alpha-equal the expected input sort
286/// under the running substitution theta.
287fn typecheck_app_explicit(
288    op: &Arc<str>,
289    args: &[Term],
290    operation: &Operation,
291    ctx: &VarContext,
292    theory: &Theory,
293) -> Result<SortExpr, GatError> {
294    if args.len() != operation.inputs.len() {
295        return Err(GatError::TermArityMismatch {
296            op: op.to_string(),
297            expected: operation.inputs.len(),
298            got: args.len(),
299        });
300    }
301
302    let mut theta: FxHashMap<Arc<str>, Term> = FxHashMap::default();
303    for (i, (arg, (param_name, declared_sort, _))) in
304        args.iter().zip(operation.inputs.iter()).enumerate()
305    {
306        let arg_sort = typecheck_term(arg, ctx, theory)?;
307        let expected = declared_sort.subst(&theta);
308        if !arg_sort.alpha_eq(&expected) {
309            return Err(GatError::ArgTypeMismatch {
310                op: op.to_string(),
311                arg_index: i,
312                expected: expected.to_string(),
313                got: arg_sort.to_string(),
314            });
315        }
316        theta.insert(Arc::clone(param_name), arg.clone());
317    }
318
319    Ok(operation.output.subst(&theta))
320}
321
322/// Typecheck an `App` against an operation that declares one or more
323/// implicit inputs.
324///
325/// Implicit parameters are fresh-renamed to unique metavariables to
326/// avoid clashing with context variables. Explicit argument sorts are
327/// then unified against the declared input sorts via first-order
328/// Robinson unification; the MGU recovers the implicit params and the
329/// output sort follows by substitution.
330fn typecheck_app_with_implicits(
331    op: &Arc<str>,
332    args: &[Term],
333    operation: &Operation,
334    ctx: &VarContext,
335    theory: &Theory,
336) -> Result<SortExpr, GatError> {
337    let explicit_count = operation.explicit_arity();
338    if args.len() != explicit_count {
339        return Err(GatError::TermArityMismatch {
340            op: op.to_string(),
341            expected: explicit_count,
342            got: args.len(),
343        });
344    }
345
346    // Fresh-rename every implicit param name to a unique metavariable.
347    let mut fresh_rename: FxHashMap<Arc<str>, Term> = FxHashMap::default();
348    for (idx, (pname, _, imp)) in operation.inputs.iter().enumerate() {
349        if matches!(imp, Implicit::Yes) {
350            let mv: Arc<str> = Arc::from(format!("?{pname}_{idx}"));
351            fresh_rename.insert(Arc::clone(pname), Term::Var(mv));
352        }
353    }
354
355    // Walk explicit args, typecheck each, and push unification constraints
356    // matching declared input sort against inferred sort. Running theta
357    // also records explicit-arg values for later positions (dependent
358    // signatures).
359    let mut theta: FxHashMap<Arc<str>, Term> = fresh_rename.clone();
360    let mut term_eqs: Vec<(Term, Term)> = Vec::new();
361    let mut explicit_iter = args.iter();
362    for (pname, declared_sort, imp) in &operation.inputs {
363        match imp {
364            Implicit::Yes => {
365                // Nothing to do at this slot: the value is recovered by
366                // unification.
367            }
368            Implicit::No => {
369                let Some(arg) = explicit_iter.next() else {
370                    return Err(GatError::TermArityMismatch {
371                        op: op.to_string(),
372                        expected: explicit_count,
373                        got: args.len(),
374                    });
375                };
376                let arg_sort = typecheck_term(arg, ctx, theory)?;
377                let expected = declared_sort.subst(&theta);
378                push_sort_expr_eqs_into(&expected, &arg_sort, op, &mut term_eqs)?;
379                theta.insert(Arc::clone(pname), arg.clone());
380            }
381        }
382    }
383
384    let mgu = unify_all(term_eqs).map_err(|e| match e {
385        GatError::SortUnificationFailure { reason } => GatError::SortUnificationFailure {
386            reason: format!("implicit inference for {op}: {reason}"),
387        },
388        other => other,
389    })?;
390
391    // Compose mgu with theta to get the final substitution for the
392    // output sort.
393    let mut final_subst = theta.clone();
394    for (k, v) in &mgu {
395        final_subst.insert(Arc::clone(k), v.clone());
396    }
397    // Apply mgu to the running bindings so metavariables resolve.
398    let final_subst: FxHashMap<Arc<str>, Term> = final_subst
399        .into_iter()
400        .map(|(k, v)| (k, v.substitute(&mgu)))
401        .collect();
402
403    Ok(operation.output.subst(&final_subst))
404}
405
406/// Push head-agreement + pairwise-arg-term constraints for two sort
407/// expressions into a unification queue.
408///
409/// Heads must agree; on head mismatch this returns
410/// [`GatError::ArgTypeMismatch`] naming the two sorts. Argument terms
411/// are pushed pairwise onto `term_eqs` for later unification.
412fn push_sort_expr_eqs_into(
413    expected: &SortExpr,
414    actual: &SortExpr,
415    op: &Arc<str>,
416    term_eqs: &mut Vec<(Term, Term)>,
417) -> Result<(), GatError> {
418    if expected.head() != actual.head() || expected.args().len() != actual.args().len() {
419        return Err(GatError::ArgTypeMismatch {
420            op: op.to_string(),
421            arg_index: 0,
422            expected: expected.to_string(),
423            got: actual.to_string(),
424        });
425    }
426    for (x, y) in expected.args().iter().zip(actual.args().iter()) {
427        term_eqs.push((x.clone(), y.clone()));
428    }
429    Ok(())
430}
431
432/// Infer variable sorts from an equation's term structure.
433///
434/// Walks both sides of the equation and, for every operation-application
435/// site, imposes a sort-expression constraint on each variable argument.
436/// When two uses of the same variable impose different constraints, the
437/// constraints are unified via first-order unification over `Term`,
438/// producing a term-level substitution that is then applied back to the
439/// inferred sort expressions.
440///
441/// # Errors
442///
443/// Returns [`GatError::ConflictingVarSort`] when two sort-expression
444/// constraints on a variable have different heads,
445/// [`GatError::SortUnificationFailure`] when unification fails
446/// (including the occurs check), or [`GatError::OpNotFound`] when a
447/// referenced operation is absent from the theory.
448pub fn infer_var_sorts(eq: &Equation, theory: &Theory) -> Result<VarContext, GatError> {
449    let mut ctx = VarContext::default();
450    let mut term_eqs: Vec<(Term, Term)> = Vec::new();
451    collect_constraints(&eq.lhs, theory, &mut ctx, &mut term_eqs)?;
452    collect_constraints(&eq.rhs, theory, &mut ctx, &mut term_eqs)?;
453
454    let substitution = unify_all(term_eqs)?;
455    if !substitution.is_empty() {
456        for sort in ctx.values_mut() {
457            *sort = sort.subst(&substitution);
458        }
459    }
460    Ok(ctx)
461}
462
463/// Recursive helper: walk a term and constrain each variable argument
464/// to the expected input sort of its enclosing operation (with the
465/// running substitution θ applied so earlier arguments flow into later
466/// expected sorts).
467fn collect_constraints(
468    term: &Term,
469    theory: &Theory,
470    ctx: &mut VarContext,
471    term_eqs: &mut Vec<(Term, Term)>,
472) -> Result<(), GatError> {
473    let (op, args) = match term {
474        Term::App { op, args } => (op, args),
475        Term::Case {
476            scrutinee,
477            branches,
478        } => {
479            collect_constraints(scrutinee, theory, ctx, term_eqs)?;
480            for b in branches {
481                collect_constraints(&b.body, theory, ctx, term_eqs)?;
482            }
483            return Ok(());
484        }
485        Term::Let { bound, body, .. } => {
486            collect_constraints(bound, theory, ctx, term_eqs)?;
487            collect_constraints(body, theory, ctx, term_eqs)?;
488            return Ok(());
489        }
490        Term::Var(_) | Term::Hole { .. } => return Ok(()),
491    };
492    let operation = theory
493        .find_op(op)
494        .ok_or_else(|| GatError::OpNotFound(op.to_string()))?;
495
496    if args.len() != operation.inputs.len() {
497        return Err(GatError::TermArityMismatch {
498            op: op.to_string(),
499            expected: operation.inputs.len(),
500            got: args.len(),
501        });
502    }
503
504    let mut theta: FxHashMap<Arc<str>, Term> = FxHashMap::default();
505    for (arg, (param_name, declared_sort, _)) in args.iter().zip(operation.inputs.iter()) {
506        let expected = declared_sort.subst(&theta);
507        match arg {
508            Term::Var(var_name) => {
509                if let Some(existing) = ctx.get(var_name).cloned() {
510                    unify_sort_exprs(&existing, &expected, var_name, term_eqs)?;
511                } else {
512                    ctx.insert(Arc::clone(var_name), expected);
513                }
514            }
515            Term::App { .. } | Term::Case { .. } | Term::Hole { .. } | Term::Let { .. } => {
516                collect_constraints(arg, theory, ctx, term_eqs)?;
517            }
518        }
519        theta.insert(Arc::clone(param_name), arg.clone());
520    }
521    Ok(())
522}
523
524/// Push pairwise equality constraints between two sort expressions.
525///
526/// Returns [`GatError::ConflictingVarSort`] when the heads differ or the
527/// argument arities do not line up. On success, accumulates pairwise
528/// `(Term, Term)` constraints into `term_eqs` for a later unification
529/// pass.
530fn unify_sort_exprs(
531    a: &SortExpr,
532    b: &SortExpr,
533    var: &Arc<str>,
534    term_eqs: &mut Vec<(Term, Term)>,
535) -> Result<(), GatError> {
536    if a.head() != b.head() {
537        return Err(GatError::ConflictingVarSort {
538            var: var.to_string(),
539            sort1: a.to_string(),
540            sort2: b.to_string(),
541        });
542    }
543    let a_args = a.args();
544    let b_args = b.args();
545    if a_args.len() != b_args.len() {
546        return Err(GatError::ConflictingVarSort {
547            var: var.to_string(),
548            sort1: a.to_string(),
549            sort2: b.to_string(),
550        });
551    }
552    for (x, y) in a_args.iter().zip(b_args.iter()) {
553        term_eqs.push((x.clone(), y.clone()));
554    }
555    Ok(())
556}
557
558/// First-order unification over a list of term equality constraints.
559///
560/// Implements Robinson-style unification with an explicit occurs check.
561/// Returns a substitution mapping variable names to terms, or a
562/// [`GatError::SortUnificationFailure`] when the constraints are
563/// unsatisfiable.
564fn unify_all(mut eqs: Vec<(Term, Term)>) -> Result<FxHashMap<Arc<str>, Term>, GatError> {
565    let mut subst: FxHashMap<Arc<str>, Term> = FxHashMap::default();
566
567    while let Some((a, b)) = eqs.pop() {
568        let a = apply_subst(&a, &subst);
569        let b = apply_subst(&b, &subst);
570        match (a, b) {
571            (Term::Var(x), Term::Var(y)) if x == y => {}
572            (Term::Var(x), t) | (t, Term::Var(x)) => {
573                if occurs_in(&x, &t) {
574                    return Err(GatError::SortUnificationFailure {
575                        reason: format!("occurs check failed: {x} in {t}"),
576                    });
577                }
578                // Extend substitution with x := t, applying it to existing bindings.
579                let updated: FxHashMap<Arc<str>, Term> = subst
580                    .iter()
581                    .map(|(k, v)| {
582                        (
583                            Arc::clone(k),
584                            v.substitute(&std::iter::once((Arc::clone(&x), t.clone())).collect()),
585                        )
586                    })
587                    .collect();
588                subst = updated;
589                subst.insert(x, t);
590            }
591            (
592                Term::App {
593                    op: op_a,
594                    args: args_a,
595                },
596                Term::App {
597                    op: op_b,
598                    args: args_b,
599                },
600            ) => {
601                if op_a != op_b {
602                    return Err(GatError::SortUnificationFailure {
603                        reason: format!("cannot unify {op_a}(...) with {op_b}(...)"),
604                    });
605                }
606                if args_a.len() != args_b.len() {
607                    return Err(GatError::SortUnificationFailure {
608                        reason: format!(
609                            "arity mismatch unifying {op_a}: {} vs {}",
610                            args_a.len(),
611                            args_b.len()
612                        ),
613                    });
614                }
615                for pair in args_a.into_iter().zip(args_b) {
616                    eqs.push(pair);
617                }
618            }
619            (lhs, rhs) => {
620                return Err(GatError::SortUnificationFailure {
621                    reason: format!("cannot unify {lhs} with {rhs}"),
622                });
623            }
624        }
625    }
626
627    Ok(subst)
628}
629
630fn apply_subst(term: &Term, subst: &FxHashMap<Arc<str>, Term>) -> Term {
631    if subst.is_empty() {
632        return term.clone();
633    }
634    term.substitute(subst)
635}
636
637fn occurs_in(var: &Arc<str>, term: &Term) -> bool {
638    match term {
639        Term::Var(v) => v == var,
640        Term::Hole { .. } => false,
641        Term::Let { name, bound, body } => {
642            occurs_in(var, bound) || (name != var && occurs_in(var, body))
643        }
644        Term::App { args, .. } => args.iter().any(|a| occurs_in(var, a)),
645        Term::Case {
646            scrutinee,
647            branches,
648        } => {
649            occurs_in(var, scrutinee)
650                || branches
651                    .iter()
652                    .any(|b| !b.binders.contains(var) && occurs_in(var, &b.body))
653        }
654    }
655}
656
657/// Typecheck a term, collecting a [`HoleReport`] at every [`Term::Hole`]
658/// site.
659///
660/// Unlike [`typecheck_term`], which treats holes as fresh metavariable
661/// sorts and lets the caller discard the information, this walker
662/// propagates the surrounding context's expected sort into each hole.
663/// When the enclosing operation's input sort constrains the hole, the
664/// report carries that sort exactly; otherwise the report carries a
665/// fresh metavariable sort.
666///
667/// # Errors
668///
669/// Returns any error from [`typecheck_term`] except that hole
670/// encounters never fail.
671pub fn typecheck_term_with_holes(
672    term: &Term,
673    ctx: &VarContext,
674    theory: &Theory,
675) -> Result<(SortExpr, Vec<HoleReport>), GatError> {
676    let mut reports: Vec<HoleReport> = Vec::new();
677    let sort = typecheck_with_expected(term, None, ctx, theory, &mut reports)?;
678    Ok((sort, reports))
679}
680
681fn typecheck_with_expected(
682    term: &Term,
683    expected: Option<&SortExpr>,
684    ctx: &VarContext,
685    theory: &Theory,
686    reports: &mut Vec<HoleReport>,
687) -> Result<SortExpr, GatError> {
688    match term {
689        Term::Hole { name } => {
690            let sort = expected.cloned().unwrap_or_else(|| {
691                SortExpr::Name(Arc::from(format!("?{}", name.as_deref().unwrap_or("hole"))))
692            });
693            reports.push(HoleReport {
694                name: name.clone(),
695                expected: sort.clone(),
696                context: ctx.clone(),
697                position: None,
698            });
699            Ok(sort)
700        }
701        Term::Var(n) => ctx
702            .get(n)
703            .cloned()
704            .ok_or_else(|| GatError::UnboundVariable(n.to_string())),
705        Term::App { op, args } => {
706            let operation = theory
707                .find_op(op)
708                .ok_or_else(|| GatError::OpNotFound(op.to_string()))?;
709            let has_implicits = operation
710                .inputs
711                .iter()
712                .any(|(_, _, imp)| matches!(imp, Implicit::Yes));
713            if has_implicits {
714                // Implicit-inference path: thread the expected sort
715                // through each explicit arg so that holes attach to the
716                // right expected sort, then recover the implicit
717                // parameter values via the same unification machinery
718                // as typecheck_term.
719                typecheck_app_with_implicits_collecting_holes(
720                    op, args, operation, ctx, theory, reports,
721                )
722            } else {
723                if args.len() != operation.inputs.len() {
724                    return Err(GatError::TermArityMismatch {
725                        op: op.to_string(),
726                        expected: operation.inputs.len(),
727                        got: args.len(),
728                    });
729                }
730                let mut theta: FxHashMap<Arc<str>, Term> = FxHashMap::default();
731                for (i, (arg, (param_name, declared_sort, _))) in
732                    args.iter().zip(operation.inputs.iter()).enumerate()
733                {
734                    let expected_sort = declared_sort.subst(&theta);
735                    let arg_sort =
736                        typecheck_with_expected(arg, Some(&expected_sort), ctx, theory, reports)?;
737                    // Holes produce the expected sort by construction,
738                    // so alpha_eq holds trivially. For non-hole terms,
739                    // enforce the usual alpha_eq check.
740                    if !term_contains_hole(arg) && !arg_sort.alpha_eq(&expected_sort) {
741                        return Err(GatError::ArgTypeMismatch {
742                            op: op.to_string(),
743                            arg_index: i,
744                            expected: expected_sort.to_string(),
745                            got: arg_sort.to_string(),
746                        });
747                    }
748                    theta.insert(Arc::clone(param_name), arg.clone());
749                }
750                Ok(operation.output.subst(&theta))
751            }
752        }
753        Term::Case {
754            scrutinee,
755            branches,
756        } => typecheck_case_with_holes(scrutinee, branches, ctx, theory, reports),
757        Term::Let { name, bound, body } => {
758            let bound_sort = typecheck_with_expected(bound, None, ctx, theory, reports)?;
759            let mut extended = ctx.clone();
760            extended.insert(Arc::clone(name), bound_sort);
761            typecheck_with_expected(body, None, &extended, theory, reports)
762        }
763    }
764}
765
766fn typecheck_case_with_holes(
767    scrutinee: &Term,
768    branches: &[CaseBranch],
769    ctx: &VarContext,
770    theory: &Theory,
771    reports: &mut Vec<HoleReport>,
772) -> Result<SortExpr, GatError> {
773    let scrutinee_sort = typecheck_with_expected(scrutinee, None, ctx, theory, reports)?;
774    check_case_exhaustiveness_soft(&scrutinee_sort, branches, theory)?;
775    let mut branch_sort: Option<SortExpr> = None;
776    for b in branches {
777        let constructor_op = theory
778            .find_op(&b.constructor)
779            .ok_or_else(|| GatError::OpNotFound(b.constructor.to_string()))?;
780        if constructor_op.inputs.len() != b.binders.len() {
781            return Err(GatError::TermArityMismatch {
782                op: b.constructor.to_string(),
783                expected: constructor_op.inputs.len(),
784                got: b.binders.len(),
785            });
786        }
787        // Mirror the strict path: unify the constructor's declared
788        // output sort with the scrutinee's actual sort, apply the
789        // resulting subst to each declared binder input sort, and
790        // extend the context with the substituted sorts.
791        if constructor_op.output.head() != scrutinee_sort.head()
792            || constructor_op.output.args().len() != scrutinee_sort.args().len()
793        {
794            return Err(GatError::OpTypeMismatch {
795                op: b.constructor.to_string(),
796                detail: format!(
797                    "constructor output sort {} does not match scrutinee sort {scrutinee_sort}",
798                    constructor_op.output
799                ),
800            });
801        }
802        let unify_eqs: Vec<(Term, Term)> = constructor_op
803            .output
804            .args()
805            .iter()
806            .zip(scrutinee_sort.args().iter())
807            .map(|(a, b)| (a.clone(), b.clone()))
808            .collect();
809        let subst = unify_all(unify_eqs)?;
810        let mut extended = ctx.clone();
811        for ((_, declared_sort, _), binder) in constructor_op.inputs.iter().zip(b.binders.iter()) {
812            let binder_sort = declared_sort.subst(&subst);
813            extended.insert(Arc::clone(binder), binder_sort);
814        }
815        let body_sort = typecheck_with_expected(&b.body, None, &extended, theory, reports)?;
816        match &branch_sort {
817            None => branch_sort = Some(body_sort),
818            Some(existing) => {
819                // Require every branch body to produce an
820                // alpha-equivalent output sort. Mirror the strict
821                // typecheck_case behaviour.
822                if !existing.alpha_eq(&body_sort) {
823                    return Err(GatError::EquationSortMismatch {
824                        equation: "case".to_string(),
825                        lhs_sort: existing.to_string(),
826                        rhs_sort: body_sort.to_string(),
827                    });
828                }
829            }
830        }
831    }
832    branch_sort.ok_or_else(|| GatError::NonExhaustiveCase {
833        sort: scrutinee_sort.head().to_string(),
834        missing: Vec::new(),
835    })
836}
837
838fn check_case_exhaustiveness_soft(
839    scrutinee_sort: &SortExpr,
840    branches: &[CaseBranch],
841    theory: &Theory,
842) -> Result<(), GatError> {
843    let Some(sort_decl) = theory.find_sort(scrutinee_sort.head()) else {
844        return Ok(());
845    };
846    let SortClosure::Closed(ctors) = &sort_decl.closure else {
847        return Ok(());
848    };
849    let mut seen: rustc_hash::FxHashSet<Arc<str>> = rustc_hash::FxHashSet::default();
850    for b in branches {
851        if !ctors.contains(&b.constructor) {
852            return Err(GatError::UnknownCaseConstructor {
853                sort: scrutinee_sort.head().to_string(),
854                constructor: b.constructor.to_string(),
855            });
856        }
857        if !seen.insert(Arc::clone(&b.constructor)) {
858            return Err(GatError::RedundantCaseBranch {
859                sort: scrutinee_sort.head().to_string(),
860                constructor: b.constructor.to_string(),
861            });
862        }
863    }
864    if seen.len() < ctors.len() {
865        let missing: Vec<String> = ctors
866            .iter()
867            .filter(|c| !seen.contains(*c))
868            .map(ToString::to_string)
869            .collect();
870        return Err(GatError::NonExhaustiveCase {
871            sort: scrutinee_sort.head().to_string(),
872            missing,
873        });
874    }
875    Ok(())
876}
877
878/// Typecheck an equation: infer variable sorts, typecheck both sides,
879/// verify they produce the same output sort.
880///
881/// # Errors
882///
883/// Returns [`GatError::EquationSortMismatch`] if the two sides have
884/// different sorts, or any error from [`typecheck_term`] or
885/// [`infer_var_sorts`].
886pub fn typecheck_equation(eq: &Equation, theory: &Theory) -> Result<(), GatError> {
887    let hole_count = count_holes(&eq.lhs) + count_holes(&eq.rhs);
888    if hole_count > 0 {
889        return Err(GatError::HolesInEquation {
890            equation: eq.name.to_string(),
891            count: hole_count,
892        });
893    }
894    let ctx = infer_var_sorts(eq, theory)?;
895    let lhs_sort = typecheck_term(&eq.lhs, &ctx, theory)?;
896    let rhs_sort = typecheck_term(&eq.rhs, &ctx, theory)?;
897    if !lhs_sort.alpha_eq(&rhs_sort) {
898        return Err(GatError::EquationSortMismatch {
899            equation: eq.name.to_string(),
900            lhs_sort: lhs_sort.to_string(),
901            rhs_sort: rhs_sort.to_string(),
902        });
903    }
904    Ok(())
905}
906
907/// Typecheck an equation with sort equality relaxed modulo a set of
908/// directed rewrite rules.
909///
910/// Differs from [`typecheck_equation`] only in the final comparison
911/// step: the inferred LHS and RHS sorts are considered equal when they
912/// match under [`SortExpr::alpha_eq_modulo_rewrites`] with the given
913/// rules and step budget. Useful in dependent-type settings where
914/// judgmental equality holds modulo the theory's own rewrite system.
915///
916/// # Errors
917///
918/// Same as [`typecheck_equation`] except that the mismatch check uses
919/// the relaxed equality.
920pub fn typecheck_equation_modulo_rewrites(
921    eq: &Equation,
922    theory: &Theory,
923    rules: &[crate::eq::DirectedEquation],
924    step_limit: usize,
925) -> Result<(), GatError> {
926    let hole_count = count_holes(&eq.lhs) + count_holes(&eq.rhs);
927    if hole_count > 0 {
928        return Err(GatError::HolesInEquation {
929            equation: eq.name.to_string(),
930            count: hole_count,
931        });
932    }
933    let ctx = infer_var_sorts(eq, theory)?;
934    let lhs_sort = typecheck_term(&eq.lhs, &ctx, theory)?;
935    let rhs_sort = typecheck_term(&eq.rhs, &ctx, theory)?;
936    if !lhs_sort.alpha_eq_modulo_rewrites(&rhs_sort, rules, step_limit) {
937        return Err(GatError::EquationSortMismatch {
938            equation: eq.name.to_string(),
939            lhs_sort: lhs_sort.to_string(),
940            rhs_sort: rhs_sort.to_string(),
941        });
942    }
943    Ok(())
944}
945
946/// Typecheck all equations in a theory.
947///
948/// Also verifies, for every operation, that every implicit parameter's
949/// name occurs in at least one explicit input sort or the output sort.
950/// An implicit parameter that never appears in a position where
951/// first-order unification can pin it down is rejected with
952/// [`GatError::NonInferrableImplicit`].
953///
954/// # Errors
955///
956/// Returns the first type error encountered.
957pub fn typecheck_theory(theory: &Theory) -> Result<(), GatError> {
958    for op in &theory.ops {
959        check_implicits_inferrable(op)?;
960    }
961    check_closed_sorts(theory)?;
962    for eq in &theory.eqs {
963        typecheck_equation(eq, theory)?;
964    }
965    Ok(())
966}
967
968/// Verify every [`SortClosure::Closed`] sort's constructor list against
969/// the theory's op table.
970///
971/// For a sort `S` closed against `[c1, ..., cn]`:
972/// - each `ci` must be an op in the theory,
973/// - each `ci`'s output head must equal `S`,
974/// - no op outside `[c1, ..., cn]` produces `S`.
975fn check_closed_sorts(theory: &Theory) -> Result<(), GatError> {
976    for sort in &theory.sorts {
977        let SortClosure::Closed(ctors) = &sort.closure else {
978            continue;
979        };
980        let ctor_set: rustc_hash::FxHashSet<Arc<str>> = ctors.iter().map(Arc::clone).collect();
981        for ctor in ctors {
982            let op =
983                theory
984                    .find_op(ctor)
985                    .ok_or_else(|| GatError::InvalidClosedSortConstructor {
986                        sort: sort.name.to_string(),
987                        constructor: ctor.to_string(),
988                        detail: "op does not exist in the theory".to_string(),
989                    })?;
990            if op.output.head() != &sort.name {
991                return Err(GatError::InvalidClosedSortConstructor {
992                    sort: sort.name.to_string(),
993                    constructor: ctor.to_string(),
994                    detail: format!(
995                        "op output head is {}, expected {}",
996                        op.output.head(),
997                        sort.name
998                    ),
999                });
1000            }
1001        }
1002        for op in &theory.ops {
1003            if op.output.head() == &sort.name && !ctor_set.contains(&op.name) {
1004                return Err(GatError::InvalidClosedSortConstructor {
1005                    sort: sort.name.to_string(),
1006                    constructor: op.name.to_string(),
1007                    detail: "op produces the closed sort but is not listed in its closure"
1008                        .to_string(),
1009                });
1010            }
1011        }
1012    }
1013    Ok(())
1014}
1015
1016/// Verify that every implicit parameter of `op` occurs as a `Term::Var`
1017/// in at least one explicit input sort or in the output sort.
1018fn check_implicits_inferrable(op: &Operation) -> Result<(), GatError> {
1019    for (pname, _, imp) in &op.inputs {
1020        if !matches!(imp, Implicit::Yes) {
1021            continue;
1022        }
1023        let mut found = false;
1024        for (_, sort_expr, other_imp) in &op.inputs {
1025            if matches!(other_imp, Implicit::No) && sort_expr_mentions_var(sort_expr, pname) {
1026                found = true;
1027                break;
1028            }
1029        }
1030        if !found && sort_expr_mentions_var(&op.output, pname) {
1031            found = true;
1032        }
1033        if !found {
1034            return Err(GatError::NonInferrableImplicit {
1035                op: op.name.to_string(),
1036                param: pname.to_string(),
1037            });
1038        }
1039    }
1040    Ok(())
1041}
1042
1043/// Returns `true` if `name` appears as a [`Term::Var`] anywhere in the
1044/// argument terms of `sort`.
1045fn sort_expr_mentions_var(sort: &SortExpr, name: &Arc<str>) -> bool {
1046    sort.args().iter().any(|t| term_mentions_var(t, name))
1047}
1048
1049/// Hole-collecting variant of [`typecheck_app_with_implicits`]. Walks
1050/// explicit args through [`typecheck_with_expected`] so that holes
1051/// attach to the correct expected sort (before implicit-inference
1052/// unification pins anything down), then recovers the output sort via
1053/// the same unification pipeline as the non-collecting path.
1054fn typecheck_app_with_implicits_collecting_holes(
1055    op: &Arc<str>,
1056    args: &[Term],
1057    operation: &Operation,
1058    ctx: &VarContext,
1059    theory: &Theory,
1060    reports: &mut Vec<HoleReport>,
1061) -> Result<SortExpr, GatError> {
1062    let explicit_count = operation.explicit_arity();
1063    if args.len() != explicit_count {
1064        return Err(GatError::TermArityMismatch {
1065            op: op.to_string(),
1066            expected: explicit_count,
1067            got: args.len(),
1068        });
1069    }
1070
1071    let mut fresh_rename: FxHashMap<Arc<str>, Term> = FxHashMap::default();
1072    for (idx, (pname, _, imp)) in operation.inputs.iter().enumerate() {
1073        if matches!(imp, Implicit::Yes) {
1074            let mv: Arc<str> = Arc::from(format!("?{pname}_{idx}"));
1075            fresh_rename.insert(Arc::clone(pname), Term::Var(mv));
1076        }
1077    }
1078
1079    let mut theta: FxHashMap<Arc<str>, Term> = fresh_rename.clone();
1080    let mut term_eqs: Vec<(Term, Term)> = Vec::new();
1081    let mut explicit_iter = args.iter();
1082    for (pname, declared_sort, imp) in &operation.inputs {
1083        match imp {
1084            Implicit::Yes => {}
1085            Implicit::No => {
1086                let Some(arg) = explicit_iter.next() else {
1087                    return Err(GatError::TermArityMismatch {
1088                        op: op.to_string(),
1089                        expected: explicit_count,
1090                        got: args.len(),
1091                    });
1092                };
1093                let expected = declared_sort.subst(&theta);
1094                let arg_sort = typecheck_with_expected(arg, Some(&expected), ctx, theory, reports)?;
1095                push_sort_expr_eqs_into(&expected, &arg_sort, op, &mut term_eqs)?;
1096                theta.insert(Arc::clone(pname), arg.clone());
1097            }
1098        }
1099    }
1100
1101    let mgu = unify_all(term_eqs).map_err(|e| match e {
1102        GatError::SortUnificationFailure { reason } => GatError::SortUnificationFailure {
1103            reason: format!("implicit inference for {op}: {reason}"),
1104        },
1105        other => other,
1106    })?;
1107
1108    let mut final_subst = theta.clone();
1109    for (k, v) in &mgu {
1110        final_subst.insert(Arc::clone(k), v.clone());
1111    }
1112    let final_subst: FxHashMap<Arc<str>, Term> = final_subst
1113        .into_iter()
1114        .map(|(k, v)| (k, v.substitute(&mgu)))
1115        .collect();
1116
1117    Ok(operation.output.subst(&final_subst))
1118}
1119
1120fn count_holes(t: &Term) -> usize {
1121    match t {
1122        Term::Hole { .. } => 1,
1123        Term::Var(_) => 0,
1124        Term::App { args, .. } => args.iter().map(count_holes).sum(),
1125        Term::Case {
1126            scrutinee,
1127            branches,
1128        } => count_holes(scrutinee) + branches.iter().map(|b| count_holes(&b.body)).sum::<usize>(),
1129        Term::Let { bound, body, .. } => count_holes(bound) + count_holes(body),
1130    }
1131}
1132
1133fn term_contains_hole(t: &Term) -> bool {
1134    match t {
1135        Term::Hole { .. } => true,
1136        Term::Var(_) => false,
1137        Term::Let { bound, body, .. } => term_contains_hole(bound) || term_contains_hole(body),
1138        Term::App { args, .. } => args.iter().any(term_contains_hole),
1139        Term::Case {
1140            scrutinee,
1141            branches,
1142        } => term_contains_hole(scrutinee) || branches.iter().any(|b| term_contains_hole(&b.body)),
1143    }
1144}
1145
1146fn term_mentions_var(t: &Term, name: &Arc<str>) -> bool {
1147    match t {
1148        Term::Var(v) => v == name,
1149        Term::Hole { .. } => false,
1150        Term::Let {
1151            name: binder,
1152            bound,
1153            body,
1154        } => term_mentions_var(bound, name) || (binder != name && term_mentions_var(body, name)),
1155        Term::App { args, .. } => args.iter().any(|a| term_mentions_var(a, name)),
1156        Term::Case {
1157            scrutinee,
1158            branches,
1159        } => {
1160            term_mentions_var(scrutinee, name)
1161                || branches
1162                    .iter()
1163                    .any(|b| !b.binders.contains(name) && term_mentions_var(&b.body, name))
1164        }
1165    }
1166}
1167
1168#[cfg(test)]
1169mod tests {
1170    use super::*;
1171    use crate::eq::Term;
1172    use crate::op::Operation;
1173    use crate::sort::{Sort, SortParam};
1174    use crate::theory::Theory;
1175
1176    fn monoid_theory() -> Theory {
1177        let carrier = Sort::simple("Carrier");
1178        let mul = Operation::new(
1179            "mul",
1180            vec![
1181                (Arc::from("a"), SortExpr::from("Carrier")),
1182                (Arc::from("b"), SortExpr::from("Carrier")),
1183            ],
1184            "Carrier",
1185        );
1186        let unit = Operation::nullary("unit", "Carrier");
1187
1188        let assoc = Equation::new(
1189            "assoc",
1190            Term::app(
1191                "mul",
1192                vec![
1193                    Term::var("a"),
1194                    Term::app("mul", vec![Term::var("b"), Term::var("c")]),
1195                ],
1196            ),
1197            Term::app(
1198                "mul",
1199                vec![
1200                    Term::app("mul", vec![Term::var("a"), Term::var("b")]),
1201                    Term::var("c"),
1202                ],
1203            ),
1204        );
1205        let left_id = Equation::new(
1206            "left_id",
1207            Term::app("mul", vec![Term::constant("unit"), Term::var("a")]),
1208            Term::var("a"),
1209        );
1210        let right_id = Equation::new(
1211            "right_id",
1212            Term::app("mul", vec![Term::var("a"), Term::constant("unit")]),
1213            Term::var("a"),
1214        );
1215
1216        Theory::new(
1217            "Monoid",
1218            vec![carrier],
1219            vec![mul, unit],
1220            vec![assoc, left_id, right_id],
1221        )
1222    }
1223
1224    fn two_sort_theory() -> Theory {
1225        Theory::new(
1226            "TwoSort",
1227            vec![Sort::simple("A"), Sort::simple("B")],
1228            vec![
1229                Operation::unary("f", "x", "A", "B"),
1230                Operation::unary("g", "x", "B", "A"),
1231                Operation::nullary("a0", "A"),
1232            ],
1233            vec![],
1234        )
1235    }
1236
1237    /// A minimal category-like theory used to exercise the dependent
1238    /// sort machinery. `Hom(a, b)` is the hom-sort; `id(x)` inhabits
1239    /// `Hom(x, x)`; `compose(f, g)` is the composition with the middle
1240    /// object shared between the two hom-sorts.
1241    fn category_theory() -> Theory {
1242        let ob = Sort::simple("Ob");
1243        let hom = Sort::dependent(
1244            "Hom",
1245            vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
1246        );
1247        let hom_xx = SortExpr::App {
1248            name: Arc::from("Hom"),
1249            args: vec![Term::var("x"), Term::var("x")],
1250        };
1251        let id = Operation::unary("id", "x", "Ob", hom_xx);
1252        let hom_src_mid = SortExpr::App {
1253            name: Arc::from("Hom"),
1254            args: vec![Term::var("x"), Term::var("y")],
1255        };
1256        let hom_mid_tgt = SortExpr::App {
1257            name: Arc::from("Hom"),
1258            args: vec![Term::var("y"), Term::var("z")],
1259        };
1260        let hom_src_tgt = SortExpr::App {
1261            name: Arc::from("Hom"),
1262            args: vec![Term::var("x"), Term::var("z")],
1263        };
1264        let compose = Operation::new(
1265            "compose",
1266            vec![
1267                (Arc::from("x"), SortExpr::from("Ob")),
1268                (Arc::from("y"), SortExpr::from("Ob")),
1269                (Arc::from("z"), SortExpr::from("Ob")),
1270                (Arc::from("f"), hom_src_mid),
1271                (Arc::from("g"), hom_mid_tgt),
1272            ],
1273            hom_src_tgt,
1274        );
1275        Theory::new("Category", vec![ob, hom], vec![id, compose], Vec::new())
1276    }
1277
1278    #[test]
1279    fn typecheck_variable() -> Result<(), Box<dyn std::error::Error>> {
1280        let theory = monoid_theory();
1281        let mut ctx = VarContext::default();
1282        ctx.insert(Arc::from("x"), SortExpr::from("Carrier"));
1283        let sort = typecheck_term(&Term::var("x"), &ctx, &theory)?;
1284        assert_eq!(&**sort.head(), "Carrier");
1285        Ok(())
1286    }
1287
1288    #[test]
1289    fn typecheck_unbound_variable() {
1290        let theory = monoid_theory();
1291        let ctx = VarContext::default();
1292        let result = typecheck_term(&Term::var("z"), &ctx, &theory);
1293        assert!(matches!(result, Err(GatError::UnboundVariable(_))));
1294    }
1295
1296    #[test]
1297    fn typecheck_constant() -> Result<(), Box<dyn std::error::Error>> {
1298        let theory = monoid_theory();
1299        let ctx = VarContext::default();
1300        let sort = typecheck_term(&Term::constant("unit"), &ctx, &theory)?;
1301        assert_eq!(&**sort.head(), "Carrier");
1302        Ok(())
1303    }
1304
1305    #[test]
1306    fn typecheck_binary_op() -> Result<(), Box<dyn std::error::Error>> {
1307        let theory = monoid_theory();
1308        let mut ctx = VarContext::default();
1309        ctx.insert(Arc::from("a"), SortExpr::from("Carrier"));
1310        ctx.insert(Arc::from("b"), SortExpr::from("Carrier"));
1311        let sort = typecheck_term(
1312            &Term::app("mul", vec![Term::var("a"), Term::var("b")]),
1313            &ctx,
1314            &theory,
1315        )?;
1316        assert_eq!(&**sort.head(), "Carrier");
1317        Ok(())
1318    }
1319
1320    #[test]
1321    fn typecheck_arity_mismatch() {
1322        let theory = monoid_theory();
1323        let mut ctx = VarContext::default();
1324        ctx.insert(Arc::from("a"), SortExpr::from("Carrier"));
1325        let result = typecheck_term(&Term::app("mul", vec![Term::var("a")]), &ctx, &theory);
1326        assert!(matches!(result, Err(GatError::TermArityMismatch { .. })));
1327    }
1328
1329    #[test]
1330    fn typecheck_sort_mismatch() {
1331        let theory = two_sort_theory();
1332        let mut ctx = VarContext::default();
1333        ctx.insert(Arc::from("x"), SortExpr::from("B"));
1334        // f expects A but we give it B
1335        let result = typecheck_term(&Term::app("f", vec![Term::var("x")]), &ctx, &theory);
1336        assert!(matches!(result, Err(GatError::ArgTypeMismatch { .. })));
1337    }
1338
1339    #[test]
1340    fn typecheck_nested_term() -> Result<(), Box<dyn std::error::Error>> {
1341        let theory = two_sort_theory();
1342        let ctx = VarContext::default();
1343        // g(f(a0())) : A -- should typecheck
1344        let term = Term::app("g", vec![Term::app("f", vec![Term::constant("a0")])]);
1345        let sort = typecheck_term(&term, &ctx, &theory)?;
1346        assert_eq!(&**sort.head(), "A");
1347        Ok(())
1348    }
1349
1350    #[test]
1351    fn typecheck_nested_sort_mismatch() {
1352        let theory = two_sort_theory();
1353        let ctx = VarContext::default();
1354        // f(f(a0())) -- inner f returns B, outer f expects A
1355        let term = Term::app("f", vec![Term::app("f", vec![Term::constant("a0")])]);
1356        let result = typecheck_term(&term, &ctx, &theory);
1357        assert!(matches!(result, Err(GatError::ArgTypeMismatch { .. })));
1358    }
1359
1360    #[test]
1361    fn typecheck_unknown_op() {
1362        let theory = monoid_theory();
1363        let ctx = VarContext::default();
1364        let result = typecheck_term(&Term::constant("nonexistent"), &ctx, &theory);
1365        assert!(matches!(result, Err(GatError::OpNotFound(_))));
1366    }
1367
1368    #[test]
1369    fn infer_var_sorts_monoid() -> Result<(), Box<dyn std::error::Error>> {
1370        let theory = monoid_theory();
1371        let eq = &theory.eqs[0]; // assoc
1372        let ctx = infer_var_sorts(eq, &theory)?;
1373        assert_eq!(ctx.len(), 3);
1374        assert_eq!(&**ctx[&Arc::from("a")].head(), "Carrier");
1375        assert_eq!(&**ctx[&Arc::from("b")].head(), "Carrier");
1376        assert_eq!(&**ctx[&Arc::from("c")].head(), "Carrier");
1377        Ok(())
1378    }
1379
1380    #[test]
1381    fn infer_var_sorts_identity_law() -> Result<(), Box<dyn std::error::Error>> {
1382        let theory = monoid_theory();
1383        let eq = &theory.eqs[1]; // left_id
1384        let ctx = infer_var_sorts(eq, &theory)?;
1385        assert_eq!(ctx.len(), 1);
1386        assert_eq!(&**ctx[&Arc::from("a")].head(), "Carrier");
1387        Ok(())
1388    }
1389
1390    #[test]
1391    fn conflicting_var_sort() {
1392        let theory = two_sort_theory();
1393        let eq = Equation::new(
1394            "bogus",
1395            Term::app("f", vec![Term::var("x")]),
1396            Term::app("g", vec![Term::var("x")]),
1397        );
1398        let result = infer_var_sorts(&eq, &theory);
1399        assert!(matches!(result, Err(GatError::ConflictingVarSort { .. })));
1400    }
1401
1402    #[test]
1403    fn typecheck_monoid_equations() -> Result<(), Box<dyn std::error::Error>> {
1404        let theory = monoid_theory();
1405        typecheck_theory(&theory)?;
1406        Ok(())
1407    }
1408
1409    #[test]
1410    fn typecheck_equation_sort_mismatch() {
1411        let theory = two_sort_theory();
1412        let eq = Equation::new(
1413            "bad",
1414            Term::app("f", vec![Term::constant("a0")]),
1415            Term::constant("a0"),
1416        );
1417        let result = typecheck_equation(&eq, &theory);
1418        assert!(matches!(result, Err(GatError::EquationSortMismatch { .. })));
1419    }
1420
1421    #[test]
1422    fn typecheck_graph_theory() -> Result<(), Box<dyn std::error::Error>> {
1423        let theory = Theory::new(
1424            "Graph",
1425            vec![Sort::simple("Vertex"), Sort::simple("Edge")],
1426            vec![
1427                Operation::unary("src", "e", "Edge", "Vertex"),
1428                Operation::unary("tgt", "e", "Edge", "Vertex"),
1429            ],
1430            vec![],
1431        );
1432        typecheck_theory(&theory)?;
1433        Ok(())
1434    }
1435
1436    #[test]
1437    fn typecheck_reflexive_graph_equations() -> Result<(), Box<dyn std::error::Error>> {
1438        let theory = Theory::new(
1439            "ReflexiveGraph",
1440            vec![Sort::simple("Vertex"), Sort::simple("Edge")],
1441            vec![
1442                Operation::unary("src", "e", "Edge", "Vertex"),
1443                Operation::unary("tgt", "e", "Edge", "Vertex"),
1444                Operation::unary("id", "v", "Vertex", "Edge"),
1445            ],
1446            vec![
1447                Equation::new(
1448                    "src_id",
1449                    Term::app("src", vec![Term::app("id", vec![Term::var("v")])]),
1450                    Term::var("v"),
1451                ),
1452                Equation::new(
1453                    "tgt_id",
1454                    Term::app("tgt", vec![Term::app("id", vec![Term::var("v")])]),
1455                    Term::var("v"),
1456                ),
1457            ],
1458        );
1459        typecheck_theory(&theory)?;
1460        Ok(())
1461    }
1462
1463    #[test]
1464    fn typecheck_symmetric_graph_equations() -> Result<(), Box<dyn std::error::Error>> {
1465        let theory = Theory::new(
1466            "SymmetricGraph",
1467            vec![Sort::simple("Vertex"), Sort::simple("Edge")],
1468            vec![
1469                Operation::unary("src", "e", "Edge", "Vertex"),
1470                Operation::unary("tgt", "e", "Edge", "Vertex"),
1471                Operation::unary("inv", "e", "Edge", "Edge"),
1472            ],
1473            vec![
1474                Equation::new(
1475                    "src_inv",
1476                    Term::app("src", vec![Term::app("inv", vec![Term::var("e")])]),
1477                    Term::app("tgt", vec![Term::var("e")]),
1478                ),
1479                Equation::new(
1480                    "tgt_inv",
1481                    Term::app("tgt", vec![Term::app("inv", vec![Term::var("e")])]),
1482                    Term::app("src", vec![Term::var("e")]),
1483                ),
1484                Equation::new(
1485                    "inv_inv",
1486                    Term::app("inv", vec![Term::app("inv", vec![Term::var("e")])]),
1487                    Term::var("e"),
1488                ),
1489            ],
1490        );
1491        typecheck_theory(&theory)?;
1492        Ok(())
1493    }
1494
1495    // --- Dependent-sort tests ---
1496
1497    #[test]
1498    fn typecheck_dependent_id_ok() -> Result<(), Box<dyn std::error::Error>> {
1499        let theory = category_theory();
1500        let mut ctx = VarContext::default();
1501        ctx.insert(Arc::from("x"), SortExpr::from("Ob"));
1502        let result = typecheck_term(&Term::app("id", vec![Term::var("x")]), &ctx, &theory)?;
1503        assert_eq!(&**result.head(), "Hom");
1504        assert_eq!(result.args().len(), 2);
1505        // Both args should be `x`.
1506        assert_eq!(result.args()[0], Term::var("x"));
1507        assert_eq!(result.args()[1], Term::var("x"));
1508        Ok(())
1509    }
1510
1511    #[test]
1512    fn typecheck_dependent_compose_ok() -> Result<(), Box<dyn std::error::Error>> {
1513        let theory = category_theory();
1514        let mut ctx = VarContext::default();
1515        ctx.insert(Arc::from("a"), SortExpr::from("Ob"));
1516        ctx.insert(Arc::from("b"), SortExpr::from("Ob"));
1517        ctx.insert(Arc::from("c"), SortExpr::from("Ob"));
1518        ctx.insert(
1519            Arc::from("f"),
1520            SortExpr::App {
1521                name: Arc::from("Hom"),
1522                args: vec![Term::var("a"), Term::var("b")],
1523            },
1524        );
1525        ctx.insert(
1526            Arc::from("g"),
1527            SortExpr::App {
1528                name: Arc::from("Hom"),
1529                args: vec![Term::var("b"), Term::var("c")],
1530            },
1531        );
1532        let term = Term::app(
1533            "compose",
1534            vec![
1535                Term::var("a"),
1536                Term::var("b"),
1537                Term::var("c"),
1538                Term::var("f"),
1539                Term::var("g"),
1540            ],
1541        );
1542        let result = typecheck_term(&term, &ctx, &theory)?;
1543        let expected = SortExpr::App {
1544            name: Arc::from("Hom"),
1545            args: vec![Term::var("a"), Term::var("c")],
1546        };
1547        assert!(result.alpha_eq(&expected), "got {result}");
1548        Ok(())
1549    }
1550
1551    #[test]
1552    fn typecheck_dependent_compose_arg_mismatch() {
1553        let theory = category_theory();
1554        let mut ctx = VarContext::default();
1555        ctx.insert(Arc::from("a"), SortExpr::from("Ob"));
1556        ctx.insert(Arc::from("b"), SortExpr::from("Ob"));
1557        ctx.insert(Arc::from("c"), SortExpr::from("Ob"));
1558        // f : Hom(a, b) and g : Hom(c, c). Middle object disagrees.
1559        ctx.insert(
1560            Arc::from("f"),
1561            SortExpr::App {
1562                name: Arc::from("Hom"),
1563                args: vec![Term::var("a"), Term::var("b")],
1564            },
1565        );
1566        ctx.insert(
1567            Arc::from("g"),
1568            SortExpr::App {
1569                name: Arc::from("Hom"),
1570                args: vec![Term::var("c"), Term::var("c")],
1571            },
1572        );
1573        let term = Term::app(
1574            "compose",
1575            vec![
1576                Term::var("a"),
1577                Term::var("b"),
1578                Term::var("c"),
1579                Term::var("f"),
1580                Term::var("g"),
1581            ],
1582        );
1583        let result = typecheck_term(&term, &ctx, &theory);
1584        assert!(
1585            matches!(result, Err(GatError::ArgTypeMismatch { .. })),
1586            "expected ArgTypeMismatch, got {result:?}",
1587        );
1588    }
1589
1590    #[test]
1591    fn typecheck_dependent_equation_ok() -> Result<(), Box<dyn std::error::Error>> {
1592        // Build a category with the associativity equation
1593        // compose(a,b,d, f, compose(b,c,d, g, h))
1594        //   = compose(a,c,d, compose(a,b,c, f, g), h)
1595        let mut theory = category_theory();
1596        let assoc = Equation::new(
1597            "assoc",
1598            Term::app(
1599                "compose",
1600                vec![
1601                    Term::var("a"),
1602                    Term::var("b"),
1603                    Term::var("d"),
1604                    Term::var("f"),
1605                    Term::app(
1606                        "compose",
1607                        vec![
1608                            Term::var("b"),
1609                            Term::var("c"),
1610                            Term::var("d"),
1611                            Term::var("g"),
1612                            Term::var("h"),
1613                        ],
1614                    ),
1615                ],
1616            ),
1617            Term::app(
1618                "compose",
1619                vec![
1620                    Term::var("a"),
1621                    Term::var("c"),
1622                    Term::var("d"),
1623                    Term::app(
1624                        "compose",
1625                        vec![
1626                            Term::var("a"),
1627                            Term::var("b"),
1628                            Term::var("c"),
1629                            Term::var("f"),
1630                            Term::var("g"),
1631                        ],
1632                    ),
1633                    Term::var("h"),
1634                ],
1635            ),
1636        );
1637        theory.eqs.push(assoc);
1638        typecheck_theory(&theory)?;
1639        Ok(())
1640    }
1641
1642    // --- A3: unification soundness, occurs check, idempotence ---
1643
1644    #[test]
1645    fn unify_same_var_yields_empty_subst() -> Result<(), Box<dyn std::error::Error>> {
1646        let subst = unify_all(vec![(Term::var("x"), Term::var("x"))])?;
1647        assert!(subst.is_empty());
1648        Ok(())
1649    }
1650
1651    #[test]
1652    fn unify_var_to_constant_binds() -> Result<(), Box<dyn std::error::Error>> {
1653        let subst = unify_all(vec![(Term::var("x"), Term::constant("c"))])?;
1654        assert_eq!(subst.get(&Arc::from("x")), Some(&Term::constant("c")));
1655        Ok(())
1656    }
1657
1658    #[test]
1659    fn unify_occurs_check_fails() {
1660        // x = f(x) must fail the occurs check.
1661        let r = unify_all(vec![(Term::var("x"), Term::app("f", vec![Term::var("x")]))]);
1662        assert!(matches!(r, Err(GatError::SortUnificationFailure { .. })));
1663    }
1664
1665    #[test]
1666    fn unify_head_mismatch_fails() {
1667        let r = unify_all(vec![(
1668            Term::app("f", vec![Term::var("x")]),
1669            Term::app("g", vec![Term::var("x")]),
1670        )]);
1671        assert!(matches!(r, Err(GatError::SortUnificationFailure { .. })));
1672    }
1673
1674    #[test]
1675    fn unify_is_idempotent() -> Result<(), Box<dyn std::error::Error>> {
1676        // Unify f(x, y) = f(a, g(b)). Then applying the substitution twice
1677        // is the same as once.
1678        let eqs = vec![(
1679            Term::app("f", vec![Term::var("x"), Term::var("y")]),
1680            Term::app(
1681                "f",
1682                vec![Term::var("a"), Term::app("g", vec![Term::var("b")])],
1683            ),
1684        )];
1685        let subst = unify_all(eqs)?;
1686        // Apply to x and compare to apply-twice.
1687        for k in subst.keys() {
1688            let once = Term::var(Arc::clone(k)).substitute(&subst);
1689            let twice = once.substitute(&subst);
1690            assert_eq!(once, twice, "substitution not idempotent on {k}");
1691        }
1692        Ok(())
1693    }
1694
1695    #[test]
1696    fn unify_soundness_mgu_instantiates_both_sides() -> Result<(), Box<dyn std::error::Error>> {
1697        // f(x, g(y)) = f(h(a), g(b))
1698        let lhs = Term::app(
1699            "f",
1700            vec![Term::var("x"), Term::app("g", vec![Term::var("y")])],
1701        );
1702        let rhs = Term::app(
1703            "f",
1704            vec![
1705                Term::app("h", vec![Term::var("a")]),
1706                Term::app("g", vec![Term::var("b")]),
1707            ],
1708        );
1709        let subst = unify_all(vec![(lhs.clone(), rhs.clone())])?;
1710        let l2 = lhs.substitute(&subst);
1711        let r2 = rhs.substitute(&subst);
1712        assert_eq!(l2, r2);
1713        Ok(())
1714    }
1715
1716    // --- A4: typecheck idempotence and substitution commuting ---
1717
1718    #[test]
1719    fn typecheck_term_idempotent() -> Result<(), Box<dyn std::error::Error>> {
1720        let theory = category_theory();
1721        let mut ctx = VarContext::default();
1722        ctx.insert(Arc::from("x"), SortExpr::from("Ob"));
1723        let t = Term::app("id", vec![Term::var("x")]);
1724        let s1 = typecheck_term(&t, &ctx, &theory)?;
1725        let s2 = typecheck_term(&t, &ctx, &theory)?;
1726        assert_eq!(s1, s2);
1727        Ok(())
1728    }
1729
1730    #[test]
1731    fn typecheck_context_strengthening() -> Result<(), Box<dyn std::error::Error>> {
1732        let theory = category_theory();
1733        let mut ctx = VarContext::default();
1734        ctx.insert(Arc::from("x"), SortExpr::from("Ob"));
1735        let t = Term::app("id", vec![Term::var("x")]);
1736        let s1 = typecheck_term(&t, &ctx, &theory)?;
1737        // Extend ctx with unrelated var.
1738        ctx.insert(Arc::from("unused"), SortExpr::from("Ob"));
1739        let s2 = typecheck_term(&t, &ctx, &theory)?;
1740        assert_eq!(s1, s2);
1741        Ok(())
1742    }
1743
1744    #[test]
1745    fn typecheck_substitution_commutes() -> Result<(), Box<dyn std::error::Error>> {
1746        // typecheck(t, ctx) = s implies typecheck(t.subst(sigma), ctx.subst(sigma)) = s.subst(sigma)
1747        let theory = category_theory();
1748        let mut ctx = VarContext::default();
1749        ctx.insert(Arc::from("x"), SortExpr::from("Ob"));
1750        let t = Term::app("id", vec![Term::var("x")]);
1751        let s = typecheck_term(&t, &ctx, &theory)?;
1752
1753        // sigma maps x to a new variable y : Ob.
1754        let mut sigma: FxHashMap<Arc<str>, Term> = FxHashMap::default();
1755        sigma.insert(Arc::from("x"), Term::var("y"));
1756
1757        let t_prime = t.substitute(&sigma);
1758        let mut ctx_prime = VarContext::default();
1759        ctx_prime.insert(Arc::from("y"), SortExpr::from("Ob"));
1760
1761        let s_prime = typecheck_term(&t_prime, &ctx_prime, &theory)?;
1762        let s_expected = s.subst(&sigma);
1763        assert!(
1764            s_prime.alpha_eq(&s_expected),
1765            "got {s_prime}, expected {s_expected}"
1766        );
1767        Ok(())
1768    }
1769
1770    // --- dependent-sort middle-object agreement ---
1771    //
1772    // A shared parameter that appears in more than one argument's sort
1773    // (the middle object in `compose(f : Hom(a, b), g : Hom(b, c))`)
1774    // must take the same value in every occurrence. `typecheck_term`
1775    // propagates the substitution theta left-to-right and compares
1776    // each argument's inferred sort against the expected input sort
1777    // under theta via strict `alpha_eq`, so mismatched derivations of
1778    // a shared parameter are rejected.
1779
1780    #[test]
1781    fn compose_with_disagreeing_middle_object_is_rejected() {
1782        // compose : (x, y, z : Ob, f : Hom(x, y), g : Hom(y, z)) ->
1783        // Hom(x, z). Supply f : Hom(p, q) and g : Hom(r, s) with q != r
1784        // and call compose with an explicit middle-object choice that
1785        // cannot satisfy both input-sort constraints at once.
1786        let theory = category_theory();
1787        let mut ctx = VarContext::default();
1788        ctx.insert(Arc::from("p"), SortExpr::from("Ob"));
1789        ctx.insert(Arc::from("q"), SortExpr::from("Ob"));
1790        ctx.insert(Arc::from("r"), SortExpr::from("Ob"));
1791        ctx.insert(Arc::from("s"), SortExpr::from("Ob"));
1792        ctx.insert(
1793            Arc::from("f"),
1794            SortExpr::App {
1795                name: Arc::from("Hom"),
1796                args: vec![Term::var("p"), Term::var("q")],
1797            },
1798        );
1799        ctx.insert(
1800            Arc::from("g"),
1801            SortExpr::App {
1802                name: Arc::from("Hom"),
1803                args: vec![Term::var("r"), Term::var("s")],
1804            },
1805        );
1806        // Whatever Ob we pick for the middle argument, one of f or g
1807        // cannot match it: f wants middle = q, g wants middle = r, and
1808        // q and r are distinct Obs.
1809        let term = Term::app(
1810            "compose",
1811            vec![
1812                Term::var("p"),
1813                Term::var("q"),
1814                Term::var("s"),
1815                Term::var("f"),
1816                Term::var("g"),
1817            ],
1818        );
1819        let result = typecheck_term(&term, &ctx, &theory);
1820        assert!(
1821            matches!(result, Err(GatError::ArgTypeMismatch { .. })),
1822            "compose with mismatched middle object must be rejected, got {result:?}",
1823        );
1824    }
1825
1826    #[test]
1827    fn compose_of_identity_with_unrelated_arrow_is_rejected() {
1828        // compose(id(p), f) where f : Hom(q, r) with p != q. id(p)
1829        // has sort Hom(p, p); compose forces the middle object to
1830        // equal p in the first slot and q in the second, contradicting.
1831        let theory = category_theory();
1832        let mut ctx = VarContext::default();
1833        ctx.insert(Arc::from("p"), SortExpr::from("Ob"));
1834        ctx.insert(Arc::from("q"), SortExpr::from("Ob"));
1835        ctx.insert(Arc::from("r"), SortExpr::from("Ob"));
1836        ctx.insert(
1837            Arc::from("f"),
1838            SortExpr::App {
1839                name: Arc::from("Hom"),
1840                args: vec![Term::var("q"), Term::var("r")],
1841            },
1842        );
1843        // Explicit Obs: (src = p, mid = p, tgt = r) forces g's sort
1844        // check with expected Hom(p, r) but actual Hom(q, r), so q != p
1845        // fails.
1846        let term = Term::app(
1847            "compose",
1848            vec![
1849                Term::var("p"),
1850                Term::var("p"),
1851                Term::var("r"),
1852                Term::app("id", vec![Term::var("p")]),
1853                Term::var("f"),
1854            ],
1855        );
1856        let result = typecheck_term(&term, &ctx, &theory);
1857        assert!(
1858            matches!(result, Err(GatError::ArgTypeMismatch { .. })),
1859            "compose(id(p), f) with src(f) != p must be rejected, got {result:?}",
1860        );
1861    }
1862
1863    #[test]
1864    fn compose_of_two_identities_at_distinct_objects_is_rejected() {
1865        // compose(id(p), id(q)) with p != q. id(p) : Hom(p, p), id(q)
1866        // : Hom(q, q); the two identities cannot share a middle object
1867        // and no choice of the explicit middle-object arguments makes
1868        // both input-sort checks pass.
1869        let theory = category_theory();
1870        let mut ctx = VarContext::default();
1871        ctx.insert(Arc::from("p"), SortExpr::from("Ob"));
1872        ctx.insert(Arc::from("q"), SortExpr::from("Ob"));
1873        // Choose the middle as p; then id(p) : Hom(p, p) is ok for the
1874        // first hom-slot, but id(q) : Hom(q, q) cannot match the
1875        // expected Hom(p, q).
1876        let term = Term::app(
1877            "compose",
1878            vec![
1879                Term::var("p"),
1880                Term::var("p"),
1881                Term::var("q"),
1882                Term::app("id", vec![Term::var("p")]),
1883                Term::app("id", vec![Term::var("q")]),
1884            ],
1885        );
1886        let result = typecheck_term(&term, &ctx, &theory);
1887        assert!(
1888            matches!(result, Err(GatError::ArgTypeMismatch { .. })),
1889            "compose(id(p), id(q)) with p != q must be rejected, got {result:?}",
1890        );
1891    }
1892
1893    // --- negative typecheck cases returning specific error variants ---
1894
1895    #[test]
1896    fn equation_with_dependent_sort_arg_mismatch_errors() {
1897        // Equation whose argument sort does not unify with the
1898        // declared input sort. f : Hom(a, b); the equation uses f on
1899        // a term of simple sort Ob, which cannot typecheck.
1900        let theory = category_theory();
1901        let eq = Equation::new(
1902            "bad",
1903            Term::app("id", vec![Term::app("id", vec![Term::var("x")])]),
1904            Term::var("x"),
1905        );
1906        // id(id(x)): inner id(x) has sort Hom(x, x), but outer id
1907        // expects Ob. Typechecking this equation should error.
1908        let result = typecheck_equation(&eq, &theory);
1909        assert!(
1910            result.is_err(),
1911            "equation with argument-sort mismatch must error, got {result:?}",
1912        );
1913    }
1914
1915    #[test]
1916    fn equation_with_unknown_op_errors() {
1917        let theory = monoid_theory();
1918        let eq = Equation::new(
1919            "bad",
1920            Term::app("mystery", vec![Term::var("a")]),
1921            Term::var("a"),
1922        );
1923        let result = typecheck_equation(&eq, &theory);
1924        assert!(
1925            matches!(result, Err(GatError::OpNotFound(_))),
1926            "equation referencing unknown op must error, got {result:?}",
1927        );
1928    }
1929
1930    #[test]
1931    fn equation_with_arity_mismatch_errors() {
1932        let theory = monoid_theory();
1933        let eq = Equation::new(
1934            "bad",
1935            Term::app("mul", vec![Term::var("a")]),
1936            Term::var("a"),
1937        );
1938        let result = typecheck_equation(&eq, &theory);
1939        assert!(
1940            matches!(result, Err(GatError::TermArityMismatch { .. })),
1941            "equation with arity mismatch must error, got {result:?}",
1942        );
1943    }
1944
1945    #[test]
1946    fn dependent_sort_with_ill_typed_arg_errors() {
1947        // Passing a Hom-sorted term in an Ob-sorted argument slot
1948        // must error rather than silently proceed; this exercises the
1949        // case where the inferred sort has the wrong head altogether.
1950        let theory = category_theory();
1951        let mut ctx = VarContext::default();
1952        ctx.insert(Arc::from("x"), SortExpr::from("Ob"));
1953        ctx.insert(
1954            Arc::from("f"),
1955            SortExpr::App {
1956                name: Arc::from("Hom"),
1957                args: vec![Term::var("x"), Term::var("x")],
1958            },
1959        );
1960        // Pass f (which is a Hom, not an Ob) in the src-Ob position.
1961        let term = Term::app(
1962            "compose",
1963            vec![
1964                Term::var("f"),
1965                Term::var("x"),
1966                Term::var("x"),
1967                Term::var("f"),
1968                Term::var("f"),
1969            ],
1970        );
1971        let result = typecheck_term(&term, &ctx, &theory);
1972        assert!(
1973            matches!(result, Err(GatError::ArgTypeMismatch { .. })),
1974            "ill-typed dependent-sort argument must error, got {result:?}",
1975        );
1976    }
1977
1978    // --- closed sorts and case terms (1.2) ---
1979
1980    /// Theory with closed sort `Nat` and constructors `zero`, `succ`.
1981    fn nat_theory() -> Theory {
1982        let nat = Sort::closed(
1983            "Nat",
1984            Vec::new(),
1985            [Arc::from("zero") as Arc<str>, Arc::from("succ")],
1986        );
1987        let zero = Operation::nullary("zero", "Nat");
1988        let succ = Operation::unary("succ", "n", "Nat", "Nat");
1989        Theory::new("NatTh", vec![nat], vec![zero, succ], Vec::new())
1990    }
1991
1992    #[test]
1993    fn closed_sort_exhaustive_case_typechecks() -> Result<(), Box<dyn std::error::Error>> {
1994        let theory = nat_theory();
1995        let mut ctx = VarContext::default();
1996        ctx.insert(Arc::from("n"), SortExpr::from("Nat"));
1997        let term = Term::Case {
1998            scrutinee: Box::new(Term::var("n")),
1999            branches: vec![
2000                CaseBranch {
2001                    constructor: Arc::from("zero"),
2002                    binders: Vec::new(),
2003                    body: Term::constant("zero"),
2004                },
2005                CaseBranch {
2006                    constructor: Arc::from("succ"),
2007                    binders: vec![Arc::from("m")],
2008                    body: Term::var("m"),
2009                },
2010            ],
2011        };
2012        let sort = typecheck_term(&term, &ctx, &theory)?;
2013        assert_eq!(&**sort.head(), "Nat");
2014        Ok(())
2015    }
2016
2017    #[test]
2018    fn closed_sort_missing_branch_rejected() {
2019        let theory = nat_theory();
2020        let mut ctx = VarContext::default();
2021        ctx.insert(Arc::from("n"), SortExpr::from("Nat"));
2022        let term = Term::Case {
2023            scrutinee: Box::new(Term::var("n")),
2024            branches: vec![CaseBranch {
2025                constructor: Arc::from("zero"),
2026                binders: Vec::new(),
2027                body: Term::constant("zero"),
2028            }],
2029        };
2030        let result = typecheck_term(&term, &ctx, &theory);
2031        assert!(
2032            matches!(result, Err(GatError::NonExhaustiveCase { .. })),
2033            "got {result:?}"
2034        );
2035    }
2036
2037    #[test]
2038    fn closed_sort_redundant_branch_rejected() {
2039        let theory = nat_theory();
2040        let mut ctx = VarContext::default();
2041        ctx.insert(Arc::from("n"), SortExpr::from("Nat"));
2042        let term = Term::Case {
2043            scrutinee: Box::new(Term::var("n")),
2044            branches: vec![
2045                CaseBranch {
2046                    constructor: Arc::from("zero"),
2047                    binders: Vec::new(),
2048                    body: Term::constant("zero"),
2049                },
2050                CaseBranch {
2051                    constructor: Arc::from("zero"),
2052                    binders: Vec::new(),
2053                    body: Term::constant("zero"),
2054                },
2055            ],
2056        };
2057        let result = typecheck_term(&term, &ctx, &theory);
2058        assert!(
2059            matches!(result, Err(GatError::RedundantCaseBranch { .. })),
2060            "got {result:?}"
2061        );
2062    }
2063
2064    #[test]
2065    fn case_on_open_sort_rejected() {
2066        // Use a plain open sort `Vertex` with a `v0` nullary.
2067        let v = Sort::simple("Vertex");
2068        let v0 = Operation::nullary("v0", "Vertex");
2069        let theory = Theory::new("Open", vec![v], vec![v0], Vec::new());
2070        let mut ctx = VarContext::default();
2071        ctx.insert(Arc::from("x"), SortExpr::from("Vertex"));
2072        let term = Term::Case {
2073            scrutinee: Box::new(Term::var("x")),
2074            branches: vec![CaseBranch {
2075                constructor: Arc::from("v0"),
2076                binders: Vec::new(),
2077                body: Term::constant("v0"),
2078            }],
2079        };
2080        let result = typecheck_term(&term, &ctx, &theory);
2081        assert!(
2082            matches!(result, Err(GatError::CaseOnOpenSort { .. })),
2083            "got {result:?}"
2084        );
2085    }
2086
2087    #[test]
2088    fn case_unknown_constructor_rejected() {
2089        let theory = nat_theory();
2090        let mut ctx = VarContext::default();
2091        ctx.insert(Arc::from("n"), SortExpr::from("Nat"));
2092        let term = Term::Case {
2093            scrutinee: Box::new(Term::var("n")),
2094            branches: vec![
2095                CaseBranch {
2096                    constructor: Arc::from("nope"),
2097                    binders: Vec::new(),
2098                    body: Term::constant("zero"),
2099                },
2100                CaseBranch {
2101                    constructor: Arc::from("succ"),
2102                    binders: vec![Arc::from("m")],
2103                    body: Term::var("m"),
2104                },
2105            ],
2106        };
2107        let result = typecheck_term(&term, &ctx, &theory);
2108        assert!(
2109            matches!(result, Err(GatError::UnknownCaseConstructor { .. })),
2110            "got {result:?}"
2111        );
2112    }
2113
2114    #[test]
2115    fn closed_sort_rejects_external_constructor() {
2116        // A sort closed against [zero] but another op also produces Nat:
2117        // typecheck_theory must reject this.
2118        let nat = Sort::closed("Nat", Vec::new(), [Arc::from("zero") as Arc<str>]);
2119        let zero = Operation::nullary("zero", "Nat");
2120        let sneaky = Operation::nullary("sneaky", "Nat");
2121        let theory = Theory::new("BadClosure", vec![nat], vec![zero, sneaky], Vec::new());
2122        let result = typecheck_theory(&theory);
2123        assert!(
2124            matches!(result, Err(GatError::InvalidClosedSortConstructor { .. })),
2125            "got {result:?}"
2126        );
2127    }
2128
2129    #[test]
2130    fn morphism_preserves_closure_constructors() -> Result<(), Box<dyn std::error::Error>> {
2131        use crate::morphism::{TheoryMorphism, check_morphism};
2132        use std::collections::HashMap;
2133
2134        let nat1 = nat_theory();
2135        // Codomain theory: Nat' closed against [zero', succ']
2136        let nat_prime = Sort::closed(
2137            "Nat",
2138            Vec::new(),
2139            [Arc::from("zero2") as Arc<str>, Arc::from("succ2")],
2140        );
2141        let zero2 = Operation::nullary("zero2", "Nat");
2142        let succ2 = Operation::unary("succ2", "n", "Nat", "Nat");
2143        let nat2 = Theory::new("NatTh2", vec![nat_prime], vec![zero2, succ2], Vec::new());
2144
2145        let mut sort_map = HashMap::new();
2146        sort_map.insert(Arc::from("Nat"), Arc::from("Nat"));
2147        let mut op_map = HashMap::new();
2148        op_map.insert(Arc::from("zero"), Arc::from("zero2"));
2149        op_map.insert(Arc::from("succ"), Arc::from("succ2"));
2150        let m = TheoryMorphism::new("m", "NatTh", "NatTh2", sort_map, op_map);
2151        check_morphism(&m, &nat1, &nat2)?;
2152
2153        // Now swap succ2 to an unrelated op in the codomain closure and
2154        // verify the morphism check fails.
2155        let nat_prime_bad = Sort::closed(
2156            "Nat",
2157            Vec::new(),
2158            [Arc::from("zero2") as Arc<str>, Arc::from("other")],
2159        );
2160        let other = Operation::unary("other", "n", "Nat", "Nat");
2161        let nat2_bad = Theory::new(
2162            "NatTh2",
2163            vec![nat_prime_bad],
2164            vec![Operation::nullary("zero2", "Nat"), other],
2165            Vec::new(),
2166        );
2167        let result = check_morphism(&m, &nat1, &nat2_bad);
2168        assert!(
2169            matches!(result, Err(GatError::MorphismClosureMismatch { .. })),
2170            "got {result:?}"
2171        );
2172        Ok(())
2173    }
2174
2175    #[test]
2176    fn case_term_substitution_respects_binder_shadow() {
2177        // case n of zero => m | succ(m) => m end
2178        // Under substitution [m := succ(zero())], the zero branch's
2179        // body becomes succ(zero()), but the succ branch's body is
2180        // shadowed by its binder `m` and remains `m`.
2181        let term = Term::Case {
2182            scrutinee: Box::new(Term::var("n")),
2183            branches: vec![
2184                CaseBranch {
2185                    constructor: Arc::from("zero"),
2186                    binders: Vec::new(),
2187                    body: Term::var("m"),
2188                },
2189                CaseBranch {
2190                    constructor: Arc::from("succ"),
2191                    binders: vec![Arc::from("m")],
2192                    body: Term::var("m"),
2193                },
2194            ],
2195        };
2196        let mut subst = FxHashMap::default();
2197        subst.insert(
2198            Arc::from("m"),
2199            Term::app("succ", vec![Term::constant("zero")]),
2200        );
2201        let result = term.substitute(&subst);
2202        let Term::Case { branches, .. } = &result else {
2203            panic!("expected Case, got {result:?}");
2204        };
2205        assert_eq!(
2206            branches[0].body,
2207            Term::app("succ", vec![Term::constant("zero")]),
2208            "zero branch body should be substituted"
2209        );
2210        assert_eq!(
2211            branches[1].body,
2212            Term::var("m"),
2213            "succ branch body must be shadowed, body stays `m`"
2214        );
2215    }
2216
2217    // --- implicit arg inference (1.1) ---
2218
2219    /// Build a minimal lambda-calculus-style theory with an `app`
2220    /// operation whose first two arguments are implicit type witnesses.
2221    ///
2222    /// Sorts: `Ty` (simple), `Tm(t : Ty)` (dependent).
2223    /// Ops:
2224    /// - `arrow : (a : Ty, b : Ty) -> Ty`
2225    /// - `app : {a : Ty}{b : Ty}(f : Tm(arrow(a, b)), x : Tm(a)) -> Tm(b)`
2226    fn lambda_theory() -> Theory {
2227        use crate::op::Implicit;
2228        let ty = Sort::simple("Ty");
2229        let tm = Sort::dependent("Tm", vec![SortParam::new("t", "Ty")]);
2230        let arrow = Operation::new(
2231            "arrow",
2232            vec![
2233                (Arc::from("a"), SortExpr::from("Ty")),
2234                (Arc::from("b"), SortExpr::from("Ty")),
2235            ],
2236            "Ty",
2237        );
2238        let tm_a = SortExpr::App {
2239            name: Arc::from("Tm"),
2240            args: vec![Term::var("a")],
2241        };
2242        let tm_b = SortExpr::App {
2243            name: Arc::from("Tm"),
2244            args: vec![Term::var("b")],
2245        };
2246        let tm_arrow = SortExpr::App {
2247            name: Arc::from("Tm"),
2248            args: vec![Term::app("arrow", vec![Term::var("a"), Term::var("b")])],
2249        };
2250        let app = Operation::with_implicit(
2251            "app",
2252            vec![
2253                (Arc::from("a"), SortExpr::from("Ty"), Implicit::Yes),
2254                (Arc::from("b"), SortExpr::from("Ty"), Implicit::Yes),
2255                (Arc::from("f"), tm_arrow, Implicit::No),
2256                (Arc::from("x"), tm_a, Implicit::No),
2257            ],
2258            tm_b,
2259        );
2260        Theory::new("Lambda", vec![ty, tm], vec![arrow, app], Vec::new())
2261    }
2262
2263    #[test]
2264    fn app_with_inferred_implicit_types() -> Result<(), Box<dyn std::error::Error>> {
2265        let theory = lambda_theory();
2266        let mut ctx = VarContext::default();
2267        ctx.insert(Arc::from("A"), SortExpr::from("Ty"));
2268        ctx.insert(Arc::from("B"), SortExpr::from("Ty"));
2269        ctx.insert(
2270            Arc::from("f"),
2271            SortExpr::App {
2272                name: Arc::from("Tm"),
2273                args: vec![Term::app("arrow", vec![Term::var("A"), Term::var("B")])],
2274            },
2275        );
2276        ctx.insert(
2277            Arc::from("x"),
2278            SortExpr::App {
2279                name: Arc::from("Tm"),
2280                args: vec![Term::var("A")],
2281            },
2282        );
2283        // Call `app(f, x)` without the implicit type witnesses.
2284        let result = typecheck_term(
2285            &Term::app("app", vec![Term::var("f"), Term::var("x")]),
2286            &ctx,
2287            &theory,
2288        )?;
2289        let expected = SortExpr::App {
2290            name: Arc::from("Tm"),
2291            args: vec![Term::var("B")],
2292        };
2293        assert!(result.alpha_eq(&expected), "got {result}");
2294        Ok(())
2295    }
2296
2297    #[test]
2298    fn implicit_inference_rejects_overconstrained_call() {
2299        use crate::op::Implicit;
2300        // Build a theory extended with two ground `Ty` constants
2301        // `first_ty` and `second_ty`. A call to `app(f, x)` where f's
2302        // domain is the first but x's type is the second must fail
2303        // unification of the implicit `a`.
2304        let type_decl = Sort::simple("Ty");
2305        let term_decl = Sort::dependent("Tm", vec![SortParam::new("t", "Ty")]);
2306        let first_ty = Operation::nullary("tyA", "Ty");
2307        let second_ty = Operation::nullary("tyB", "Ty");
2308        let arrow = Operation::new(
2309            "arrow",
2310            vec![
2311                (Arc::from("a"), SortExpr::from("Ty")),
2312                (Arc::from("b"), SortExpr::from("Ty")),
2313            ],
2314            "Ty",
2315        );
2316        let tm_of_a = SortExpr::App {
2317            name: Arc::from("Tm"),
2318            args: vec![Term::var("a")],
2319        };
2320        let tm_of_b = SortExpr::App {
2321            name: Arc::from("Tm"),
2322            args: vec![Term::var("b")],
2323        };
2324        let tm_of_arrow = SortExpr::App {
2325            name: Arc::from("Tm"),
2326            args: vec![Term::app("arrow", vec![Term::var("a"), Term::var("b")])],
2327        };
2328        let app = Operation::with_implicit(
2329            "app",
2330            vec![
2331                (Arc::from("a"), SortExpr::from("Ty"), Implicit::Yes),
2332                (Arc::from("b"), SortExpr::from("Ty"), Implicit::Yes),
2333                (Arc::from("f"), tm_of_arrow, Implicit::No),
2334                (Arc::from("x"), tm_of_a, Implicit::No),
2335            ],
2336            tm_of_b,
2337        );
2338        let theory = Theory::new(
2339            "LambdaGround",
2340            vec![type_decl, term_decl],
2341            vec![first_ty, second_ty, arrow, app],
2342            Vec::new(),
2343        );
2344
2345        let mut ctx = VarContext::default();
2346        ctx.insert(
2347            Arc::from("f"),
2348            SortExpr::App {
2349                name: Arc::from("Tm"),
2350                args: vec![Term::app(
2351                    "arrow",
2352                    vec![Term::constant("tyA"), Term::constant("tyB")],
2353                )],
2354            },
2355        );
2356        ctx.insert(
2357            Arc::from("x"),
2358            SortExpr::App {
2359                name: Arc::from("Tm"),
2360                args: vec![Term::constant("tyB")],
2361            },
2362        );
2363        let result = typecheck_term(
2364            &Term::app("app", vec![Term::var("f"), Term::var("x")]),
2365            &ctx,
2366            &theory,
2367        );
2368        assert!(
2369            matches!(result, Err(GatError::SortUnificationFailure { .. })),
2370            "overconstrained implicit inference must fail: got {result:?}",
2371        );
2372    }
2373
2374    #[test]
2375    fn implicit_declaration_rejected_when_not_inferrable() {
2376        use crate::op::Implicit;
2377        // Op with implicit param `c` that appears in neither any
2378        // explicit input sort nor the output sort: not inferrable.
2379        let foo = Operation::with_implicit(
2380            "foo",
2381            vec![
2382                (Arc::from("a"), SortExpr::from("Ty"), Implicit::No),
2383                (Arc::from("c"), SortExpr::from("Ty"), Implicit::Yes),
2384            ],
2385            SortExpr::from("Ty"),
2386        );
2387        let theory = Theory::new(
2388            "BadImplicit",
2389            vec![Sort::simple("Ty")],
2390            vec![foo],
2391            Vec::new(),
2392        );
2393        let result = typecheck_theory(&theory);
2394        assert!(
2395            matches!(result, Err(GatError::NonInferrableImplicit { .. })),
2396            "non-inferrable implicit must be rejected: got {result:?}",
2397        );
2398    }
2399
2400    #[test]
2401    fn app_without_implicits_still_typechecks() -> Result<(), Box<dyn std::error::Error>> {
2402        // Sanity check that operations with no implicit inputs still
2403        // traverse the old explicit-theta path.
2404        let theory = category_theory();
2405        let mut ctx = VarContext::default();
2406        ctx.insert(Arc::from("x"), SortExpr::from("Ob"));
2407        let result = typecheck_term(&Term::app("id", vec![Term::var("x")]), &ctx, &theory)?;
2408        assert_eq!(&**result.head(), "Hom");
2409        Ok(())
2410    }
2411
2412    #[test]
2413    fn monomorphic_let_typechecks() -> Result<(), Box<dyn std::error::Error>> {
2414        // let x = unit in mul(x, x) : Carrier.
2415        let theory = monoid_theory();
2416        let ctx = VarContext::default();
2417        let t = Term::Let {
2418            name: Arc::from("x"),
2419            bound: Box::new(Term::constant("unit")),
2420            body: Box::new(Term::app("mul", vec![Term::var("x"), Term::var("x")])),
2421        };
2422        let sort = typecheck_term(&t, &ctx, &theory)?;
2423        assert_eq!(&**sort.head(), "Carrier");
2424        Ok(())
2425    }
2426
2427    #[test]
2428    fn equation_with_hole_is_rejected() {
2429        let theory = monoid_theory();
2430        let eq = Equation::new(
2431            "bad",
2432            Term::app("mul", vec![Term::var("a"), Term::Hole { name: None }]),
2433            Term::var("a"),
2434        );
2435        let result = typecheck_equation(&eq, &theory);
2436        assert!(matches!(result, Err(GatError::HolesInEquation { .. })));
2437    }
2438
2439    // --- proptest property tests ---
2440
2441    mod property {
2442        use super::*;
2443        use proptest::prelude::*;
2444
2445        const SORT_POOL: &[&str] = &["S0", "S1", "S2", "S3"];
2446
2447        /// Generate a well-typed theory: only simple sorts and operations
2448        /// with correct sort references (no equations).
2449        fn arb_well_typed_theory() -> impl Strategy<Value = Theory> {
2450            prop::sample::subsequence(SORT_POOL, 1..=4).prop_flat_map(|sort_names| {
2451                let sorts: Vec<Sort> = sort_names.iter().map(|s| Sort::simple(*s)).collect();
2452                let sn: Vec<String> = sort_names.iter().map(|s| (*s).to_owned()).collect();
2453                let sn2 = sn.clone();
2454                (
2455                    Just(sorts),
2456                    prop::collection::vec(
2457                        (
2458                            0..4usize,
2459                            prop::sample::select(sn),
2460                            prop::sample::select(sn2),
2461                        ),
2462                        0..=3,
2463                    ),
2464                )
2465                    .prop_map(|(sorts, op_specs)| {
2466                        let mut ops = Vec::new();
2467                        let mut seen = std::collections::HashSet::new();
2468                        for (i, (_, input_sort, output_sort)) in op_specs.iter().enumerate() {
2469                            let name = format!("op{i}");
2470                            if !seen.insert(name.clone()) {
2471                                continue;
2472                            }
2473                            ops.push(Operation::unary(
2474                                &*name,
2475                                "x",
2476                                input_sort.as_str(),
2477                                output_sort.as_str(),
2478                            ));
2479                        }
2480                        Theory::new("TypecheckTest", sorts, ops, Vec::new())
2481                    })
2482            })
2483        }
2484
2485        /// Generate a closed `Nat` theory plus a random list of
2486        /// constructor names chosen from `{zero, succ}` as branch
2487        /// constructors, possibly with missing/duplicate/unknown ones.
2488        fn arb_case_on_nat() -> impl Strategy<Value = (Theory, Vec<Arc<str>>)> {
2489            let nat = Sort::closed(
2490                "Nat",
2491                Vec::new(),
2492                [Arc::from("zero") as Arc<str>, Arc::from("succ")],
2493            );
2494            let zero = Operation::nullary("zero", "Nat");
2495            let succ = Operation::unary("succ", "n", "Nat", "Nat");
2496            let theory = Theory::new("NatTh", vec![nat], vec![zero, succ], Vec::new());
2497            (
2498                Just(theory),
2499                prop::collection::vec(
2500                    prop::sample::select(vec![
2501                        Arc::from("zero"),
2502                        Arc::from("succ"),
2503                        Arc::from("bogus"),
2504                    ] as Vec<Arc<str>>),
2505                    0..=3,
2506                ),
2507            )
2508        }
2509
2510        proptest! {
2511            #![proptest_config(ProptestConfig::with_cases(256))]
2512
2513            #[test]
2514            fn case_on_closed_sort_never_panics(
2515                (theory, ctors) in arb_case_on_nat()
2516            ) {
2517                let mut ctx = VarContext::default();
2518                ctx.insert(Arc::from("n"), SortExpr::from("Nat"));
2519                let branches: Vec<CaseBranch> = ctors
2520                    .into_iter()
2521                    .map(|c| CaseBranch {
2522                        constructor: c,
2523                        binders: Vec::new(),
2524                        body: Term::constant("zero"),
2525                    })
2526                    .collect();
2527                let term = Term::Case {
2528                    scrutinee: Box::new(Term::var("n")),
2529                    branches,
2530                };
2531                // Must return a well-typed sort or a specific GatError
2532                // variant; it must not panic.
2533                let r = typecheck_term(&term, &ctx, &theory);
2534                match r {
2535                    Ok(_)
2536                    | Err(
2537                        GatError::NonExhaustiveCase { .. }
2538                        | GatError::RedundantCaseBranch { .. }
2539                        | GatError::UnknownCaseConstructor { .. }
2540                        | GatError::OpTypeMismatch { .. }
2541                        | GatError::TermArityMismatch { .. },
2542                    ) => {}
2543                    other => prop_assert!(false, "unexpected result: {other:?}"),
2544                }
2545            }
2546
2547            #[test]
2548            fn typecheck_is_idempotent(t in arb_well_typed_theory()) {
2549                let result1 = typecheck_theory(&t);
2550                let result2 = typecheck_theory(&t);
2551                prop_assert_eq!(result1.is_ok(), result2.is_ok());
2552            }
2553
2554            #[test]
2555            fn well_typed_theory_passes(t in arb_well_typed_theory()) {
2556                prop_assert!(
2557                    typecheck_theory(&t).is_ok(),
2558                    "well-typed theory should pass typecheck",
2559                );
2560            }
2561
2562            #[test]
2563            fn implicit_inference_stable_across_names(
2564                a_name in prop::sample::select(&["A", "B", "C", "P", "Q"][..]).prop_map(Arc::from),
2565                b_name in prop::sample::select(&["A", "B", "C", "P", "Q"][..]).prop_map(Arc::from),
2566            ) {
2567                use crate::op::Implicit;
2568                // Theory with `arrow` and implicit-argument `app`. For
2569                // any choice of two ground `Ty` constants in ctx, the
2570                // inferred output sort is `Tm(b)` with `b` bound to the
2571                // codomain of the function's declared arrow. Running
2572                // typecheck twice yields the same sort.
2573                let ty = Sort::simple("Ty");
2574                let tm = Sort::dependent("Tm", vec![SortParam::new("t", "Ty")]);
2575                let arrow = Operation::new(
2576                    "arrow",
2577                    vec![
2578                        (Arc::from("a"), SortExpr::from("Ty")),
2579                        (Arc::from("b"), SortExpr::from("Ty")),
2580                    ],
2581                    "Ty",
2582                );
2583                let tm_a = SortExpr::App {
2584                    name: Arc::from("Tm"),
2585                    args: vec![Term::var("a")],
2586                };
2587                let tm_b = SortExpr::App {
2588                    name: Arc::from("Tm"),
2589                    args: vec![Term::var("b")],
2590                };
2591                let tm_arrow = SortExpr::App {
2592                    name: Arc::from("Tm"),
2593                    args: vec![Term::app(
2594                        "arrow",
2595                        vec![Term::var("a"), Term::var("b")],
2596                    )],
2597                };
2598                let app = Operation::with_implicit(
2599                    "app",
2600                    vec![
2601                        (Arc::from("a"), SortExpr::from("Ty"), Implicit::Yes),
2602                        (Arc::from("b"), SortExpr::from("Ty"), Implicit::Yes),
2603                        (Arc::from("f"), tm_arrow, Implicit::No),
2604                        (Arc::from("x"), tm_a, Implicit::No),
2605                    ],
2606                    tm_b,
2607                );
2608                let theory = Theory::new("Lambda", vec![ty, tm], vec![arrow, app], Vec::new());
2609
2610                let mut ctx = VarContext::default();
2611                ctx.insert(Arc::clone(&a_name), SortExpr::from("Ty"));
2612                if a_name != b_name {
2613                    ctx.insert(Arc::clone(&b_name), SortExpr::from("Ty"));
2614                }
2615                ctx.insert(
2616                    Arc::from("f"),
2617                    SortExpr::App {
2618                        name: Arc::from("Tm"),
2619                        args: vec![Term::app(
2620                            "arrow",
2621                            vec![Term::Var(Arc::clone(&a_name)), Term::Var(Arc::clone(&b_name))],
2622                        )],
2623                    },
2624                );
2625                ctx.insert(
2626                    Arc::from("x"),
2627                    SortExpr::App {
2628                        name: Arc::from("Tm"),
2629                        args: vec![Term::Var(Arc::clone(&a_name))],
2630                    },
2631                );
2632                let call = Term::app("app", vec![Term::var("f"), Term::var("x")]);
2633                let s1 = typecheck_term(&call, &ctx, &theory);
2634                let s2 = typecheck_term(&call, &ctx, &theory);
2635                prop_assert_eq!(s1.is_ok(), s2.is_ok());
2636                if let (Ok(a), Ok(b)) = (&s1, &s2) {
2637                    prop_assert!(a.alpha_eq(b));
2638                }
2639            }
2640
2641            #[test]
2642            fn unification_soundness_on_congruent_pairs(
2643                c1 in prop::sample::select(&["a", "b", "c"][..]),
2644                c2 in prop::sample::select(&["a", "b", "c"][..]),
2645            ) {
2646                // f(x, y) = f(c1, c2) under unification: the substitution
2647                // must make both sides equal.
2648                let lhs = Term::app(
2649                    "f",
2650                    vec![Term::var("x"), Term::var("y")],
2651                );
2652                let rhs = Term::app(
2653                    "f",
2654                    vec![Term::constant(c1), Term::constant(c2)],
2655                );
2656                let subst = match unify_all(vec![(lhs.clone(), rhs.clone())]) {
2657                    Ok(s) => s,
2658                    Err(e) => {
2659                        prop_assert!(false, "unify failed: {e}");
2660                        return Ok(());
2661                    }
2662                };
2663                let l2 = lhs.substitute(&subst);
2664                let r2 = rhs.substitute(&subst);
2665                prop_assert_eq!(l2, r2);
2666            }
2667        }
2668    }
2669}