1use std::sync::Arc;
11
12use rustc_hash::FxHashMap;
13
14use crate::eq::{CaseBranch, Equation, Term};
15use crate::error::GatError;
16use crate::op::{Implicit, Operation};
17use crate::sort::{SortClosure, SortExpr};
18use crate::theory::Theory;
19
/// A sort expression together with the metavariables it is quantified over.
///
/// Use [`SortScheme::instantiate`] to obtain a copy of `body` in which every
/// metavariable has been renamed to a fresh variable.
#[derive(Debug, Clone)]
pub struct SortScheme {
    /// Names quantified by this scheme; empty for a scheme built with `mono`.
    pub metavars: Vec<Arc<str>>,
    /// The sort expression, which may mention the names listed in `metavars`.
    pub body: SortExpr,
}
34
35impl SortScheme {
36 #[must_use]
38 pub const fn mono(body: SortExpr) -> Self {
39 Self {
40 metavars: Vec::new(),
41 body,
42 }
43 }
44
45 #[must_use]
52 pub fn instantiate(&self, counter: usize) -> SortExpr {
53 if self.metavars.is_empty() {
54 return self.body.clone();
55 }
56 let mut subst: FxHashMap<Arc<str>, crate::eq::Term> = FxHashMap::default();
57 for mv in &self.metavars {
58 let fresh: Arc<str> = Arc::from(format!("{mv}_inst_{counter}"));
59 subst.insert(Arc::clone(mv), crate::eq::Term::Var(fresh));
60 }
61 self.body.subst(&subst)
62 }
63}
64
/// Information about one hole met while typechecking with `typecheck_term_with_holes`.
#[derive(Debug, Clone)]
pub struct HoleReport {
    /// The hole's name, when the hole carries one.
    pub name: Option<Arc<str>>,
    /// The sort the hole must have. This is the sort required by the surrounding
    /// context when that is known. Otherwise it is a placeholder sort `?name`
    /// (`?hole` for an unnamed hole).
    pub expected: SortExpr,
    /// The variables in scope at the hole, with their sorts.
    pub context: VarContext,
    /// Source location of the hole, when known. `typecheck_with_expected` currently
    /// always records `None`.
    pub position: Option<miette::SourceSpan>,
}
80
/// Typing context: maps each variable name in scope to its sort.
pub type VarContext = FxHashMap<Arc<str>, SortExpr>;
88
89pub fn typecheck_term(
108 term: &Term,
109 ctx: &VarContext,
110 theory: &Theory,
111) -> Result<SortExpr, GatError> {
112 match term {
113 Term::Var(name) => ctx
114 .get(name)
115 .cloned()
116 .ok_or_else(|| GatError::UnboundVariable(name.to_string())),
117
118 Term::Hole { name } => {
119 let mv: Arc<str> = Arc::from(format!("?{}", name.as_deref().unwrap_or("hole")));
128 Ok(SortExpr::Name(mv))
129 }
130
131 Term::App { op, args } => {
132 let operation = theory
133 .find_op(op)
134 .ok_or_else(|| GatError::OpNotFound(op.to_string()))?;
135
136 let has_implicits = operation
137 .inputs
138 .iter()
139 .any(|(_, _, imp)| matches!(imp, Implicit::Yes));
140 if has_implicits {
141 typecheck_app_with_implicits(op, args, operation, ctx, theory)
142 } else {
143 typecheck_app_explicit(op, args, operation, ctx, theory)
144 }
145 }
146
147 Term::Case {
148 scrutinee,
149 branches,
150 } => typecheck_case(scrutinee, branches, ctx, theory),
151
152 Term::Let { name, bound, body } => {
153 let bound_sort = typecheck_term(bound, ctx, theory)?;
162 let mut extended = ctx.clone();
163 extended.insert(Arc::clone(name), bound_sort);
164 typecheck_term(body, &extended, theory)
165 }
166 }
167}
168
169fn typecheck_case(
171 scrutinee: &Term,
172 branches: &[CaseBranch],
173 ctx: &VarContext,
174 theory: &Theory,
175) -> Result<SortExpr, GatError> {
176 let scrutinee_sort = typecheck_term(scrutinee, ctx, theory)?;
177 let sort_name = scrutinee_sort.head();
178 let sort_decl = theory
179 .find_sort(sort_name)
180 .ok_or_else(|| GatError::SortNotFound(sort_name.to_string()))?;
181 let constructors = match &sort_decl.closure {
182 SortClosure::Open => {
183 return Err(GatError::CaseOnOpenSort {
184 sort: sort_name.to_string(),
185 });
186 }
187 SortClosure::Closed(cs) => cs.clone(),
188 };
189
190 let mut seen: rustc_hash::FxHashSet<Arc<str>> = rustc_hash::FxHashSet::default();
194 for b in branches {
195 if !constructors.contains(&b.constructor) {
196 return Err(GatError::UnknownCaseConstructor {
197 sort: sort_name.to_string(),
198 constructor: b.constructor.to_string(),
199 });
200 }
201 if !seen.insert(Arc::clone(&b.constructor)) {
202 return Err(GatError::RedundantCaseBranch {
203 sort: sort_name.to_string(),
204 constructor: b.constructor.to_string(),
205 });
206 }
207 }
208 if seen.len() < constructors.len() {
209 let missing: Vec<String> = constructors
210 .iter()
211 .filter(|c| !seen.contains(*c))
212 .map(ToString::to_string)
213 .collect();
214 return Err(GatError::NonExhaustiveCase {
215 sort: sort_name.to_string(),
216 missing,
217 });
218 }
219
220 let mut branch_sort: Option<SortExpr> = None;
224 for b in branches {
225 let constructor_op = theory
226 .find_op(&b.constructor)
227 .ok_or_else(|| GatError::OpNotFound(b.constructor.to_string()))?;
228 if constructor_op.inputs.len() != b.binders.len() {
229 return Err(GatError::TermArityMismatch {
230 op: b.constructor.to_string(),
231 expected: constructor_op.inputs.len(),
232 got: b.binders.len(),
233 });
234 }
235 let unify_eqs: Vec<(Term, Term)> = constructor_op
239 .output
240 .args()
241 .iter()
242 .zip(scrutinee_sort.args().iter())
243 .map(|(a, b)| (a.clone(), b.clone()))
244 .collect();
245 if constructor_op.output.head() != scrutinee_sort.head()
246 || constructor_op.output.args().len() != scrutinee_sort.args().len()
247 {
248 return Err(GatError::OpTypeMismatch {
249 op: b.constructor.to_string(),
250 detail: format!(
251 "constructor output sort {} does not match scrutinee sort {scrutinee_sort}",
252 constructor_op.output
253 ),
254 });
255 }
256 let subst = unify_all(unify_eqs)?;
257 let mut extended = ctx.clone();
258 for ((_, declared_sort, _), binder) in constructor_op.inputs.iter().zip(b.binders.iter()) {
259 let binder_sort = declared_sort.subst(&subst);
260 extended.insert(Arc::clone(binder), binder_sort);
261 }
262 let body_sort = typecheck_term(&b.body, &extended, theory)?;
263 match &branch_sort {
264 None => branch_sort = Some(body_sort),
265 Some(existing) => {
266 if !existing.alpha_eq(&body_sort) {
267 return Err(GatError::EquationSortMismatch {
268 equation: "case".to_string(),
269 lhs_sort: existing.to_string(),
270 rhs_sort: body_sort.to_string(),
271 });
272 }
273 }
274 }
275 }
276
277 branch_sort.ok_or_else(|| GatError::NonExhaustiveCase {
278 sort: sort_name.to_string(),
279 missing: constructors.iter().map(ToString::to_string).collect(),
280 })
281}
282
283fn typecheck_app_explicit(
288 op: &Arc<str>,
289 args: &[Term],
290 operation: &Operation,
291 ctx: &VarContext,
292 theory: &Theory,
293) -> Result<SortExpr, GatError> {
294 if args.len() != operation.inputs.len() {
295 return Err(GatError::TermArityMismatch {
296 op: op.to_string(),
297 expected: operation.inputs.len(),
298 got: args.len(),
299 });
300 }
301
302 let mut theta: FxHashMap<Arc<str>, Term> = FxHashMap::default();
303 for (i, (arg, (param_name, declared_sort, _))) in
304 args.iter().zip(operation.inputs.iter()).enumerate()
305 {
306 let arg_sort = typecheck_term(arg, ctx, theory)?;
307 let expected = declared_sort.subst(&theta);
308 if !arg_sort.alpha_eq(&expected) {
309 return Err(GatError::ArgTypeMismatch {
310 op: op.to_string(),
311 arg_index: i,
312 expected: expected.to_string(),
313 got: arg_sort.to_string(),
314 });
315 }
316 theta.insert(Arc::clone(param_name), arg.clone());
317 }
318
319 Ok(operation.output.subst(&theta))
320}
321
322fn typecheck_app_with_implicits(
331 op: &Arc<str>,
332 args: &[Term],
333 operation: &Operation,
334 ctx: &VarContext,
335 theory: &Theory,
336) -> Result<SortExpr, GatError> {
337 let explicit_count = operation.explicit_arity();
338 if args.len() != explicit_count {
339 return Err(GatError::TermArityMismatch {
340 op: op.to_string(),
341 expected: explicit_count,
342 got: args.len(),
343 });
344 }
345
346 let mut fresh_rename: FxHashMap<Arc<str>, Term> = FxHashMap::default();
348 for (idx, (pname, _, imp)) in operation.inputs.iter().enumerate() {
349 if matches!(imp, Implicit::Yes) {
350 let mv: Arc<str> = Arc::from(format!("?{pname}_{idx}"));
351 fresh_rename.insert(Arc::clone(pname), Term::Var(mv));
352 }
353 }
354
355 let mut theta: FxHashMap<Arc<str>, Term> = fresh_rename.clone();
360 let mut term_eqs: Vec<(Term, Term)> = Vec::new();
361 let mut explicit_iter = args.iter();
362 for (pname, declared_sort, imp) in &operation.inputs {
363 match imp {
364 Implicit::Yes => {
365 }
368 Implicit::No => {
369 let Some(arg) = explicit_iter.next() else {
370 return Err(GatError::TermArityMismatch {
371 op: op.to_string(),
372 expected: explicit_count,
373 got: args.len(),
374 });
375 };
376 let arg_sort = typecheck_term(arg, ctx, theory)?;
377 let expected = declared_sort.subst(&theta);
378 push_sort_expr_eqs_into(&expected, &arg_sort, op, &mut term_eqs)?;
379 theta.insert(Arc::clone(pname), arg.clone());
380 }
381 }
382 }
383
384 let mgu = unify_all(term_eqs).map_err(|e| match e {
385 GatError::SortUnificationFailure { reason } => GatError::SortUnificationFailure {
386 reason: format!("implicit inference for {op}: {reason}"),
387 },
388 other => other,
389 })?;
390
391 let mut final_subst = theta.clone();
394 for (k, v) in &mgu {
395 final_subst.insert(Arc::clone(k), v.clone());
396 }
397 let final_subst: FxHashMap<Arc<str>, Term> = final_subst
399 .into_iter()
400 .map(|(k, v)| (k, v.substitute(&mgu)))
401 .collect();
402
403 Ok(operation.output.subst(&final_subst))
404}
405
406fn push_sort_expr_eqs_into(
413 expected: &SortExpr,
414 actual: &SortExpr,
415 op: &Arc<str>,
416 term_eqs: &mut Vec<(Term, Term)>,
417) -> Result<(), GatError> {
418 if expected.head() != actual.head() || expected.args().len() != actual.args().len() {
419 return Err(GatError::ArgTypeMismatch {
420 op: op.to_string(),
421 arg_index: 0,
422 expected: expected.to_string(),
423 got: actual.to_string(),
424 });
425 }
426 for (x, y) in expected.args().iter().zip(actual.args().iter()) {
427 term_eqs.push((x.clone(), y.clone()));
428 }
429 Ok(())
430}
431
432pub fn infer_var_sorts(eq: &Equation, theory: &Theory) -> Result<VarContext, GatError> {
449 let mut ctx = VarContext::default();
450 let mut term_eqs: Vec<(Term, Term)> = Vec::new();
451 collect_constraints(&eq.lhs, theory, &mut ctx, &mut term_eqs)?;
452 collect_constraints(&eq.rhs, theory, &mut ctx, &mut term_eqs)?;
453
454 let substitution = unify_all(term_eqs)?;
455 if !substitution.is_empty() {
456 for sort in ctx.values_mut() {
457 *sort = sort.subst(&substitution);
458 }
459 }
460 Ok(ctx)
461}
462
463fn collect_constraints(
468 term: &Term,
469 theory: &Theory,
470 ctx: &mut VarContext,
471 term_eqs: &mut Vec<(Term, Term)>,
472) -> Result<(), GatError> {
473 let (op, args) = match term {
474 Term::App { op, args } => (op, args),
475 Term::Case {
476 scrutinee,
477 branches,
478 } => {
479 collect_constraints(scrutinee, theory, ctx, term_eqs)?;
480 for b in branches {
481 collect_constraints(&b.body, theory, ctx, term_eqs)?;
482 }
483 return Ok(());
484 }
485 Term::Let { bound, body, .. } => {
486 collect_constraints(bound, theory, ctx, term_eqs)?;
487 collect_constraints(body, theory, ctx, term_eqs)?;
488 return Ok(());
489 }
490 Term::Var(_) | Term::Hole { .. } => return Ok(()),
491 };
492 let operation = theory
493 .find_op(op)
494 .ok_or_else(|| GatError::OpNotFound(op.to_string()))?;
495
496 if args.len() != operation.inputs.len() {
497 return Err(GatError::TermArityMismatch {
498 op: op.to_string(),
499 expected: operation.inputs.len(),
500 got: args.len(),
501 });
502 }
503
504 let mut theta: FxHashMap<Arc<str>, Term> = FxHashMap::default();
505 for (arg, (param_name, declared_sort, _)) in args.iter().zip(operation.inputs.iter()) {
506 let expected = declared_sort.subst(&theta);
507 match arg {
508 Term::Var(var_name) => {
509 if let Some(existing) = ctx.get(var_name).cloned() {
510 unify_sort_exprs(&existing, &expected, var_name, term_eqs)?;
511 } else {
512 ctx.insert(Arc::clone(var_name), expected);
513 }
514 }
515 Term::App { .. } | Term::Case { .. } | Term::Hole { .. } | Term::Let { .. } => {
516 collect_constraints(arg, theory, ctx, term_eqs)?;
517 }
518 }
519 theta.insert(Arc::clone(param_name), arg.clone());
520 }
521 Ok(())
522}
523
524fn unify_sort_exprs(
531 a: &SortExpr,
532 b: &SortExpr,
533 var: &Arc<str>,
534 term_eqs: &mut Vec<(Term, Term)>,
535) -> Result<(), GatError> {
536 if a.head() != b.head() {
537 return Err(GatError::ConflictingVarSort {
538 var: var.to_string(),
539 sort1: a.to_string(),
540 sort2: b.to_string(),
541 });
542 }
543 let a_args = a.args();
544 let b_args = b.args();
545 if a_args.len() != b_args.len() {
546 return Err(GatError::ConflictingVarSort {
547 var: var.to_string(),
548 sort1: a.to_string(),
549 sort2: b.to_string(),
550 });
551 }
552 for (x, y) in a_args.iter().zip(b_args.iter()) {
553 term_eqs.push((x.clone(), y.clone()));
554 }
555 Ok(())
556}
557
/// Solves the term equations in `eqs` by syntactic unification and returns the
/// resulting substitution from variable names to terms.
///
/// Equations are taken from the end of `eqs` (last in, first out). Both sides are
/// first rewritten with the substitution built so far. Then:
/// - a variable equated with itself is discarded;
/// - a variable against any other term is bound to that term after an occurs
///   check. When both sides are distinct variables, the left one is bound. The new
///   binding is also applied to the values of all existing bindings;
/// - two applications of the same operation with the same argument count are
///   decomposed into argument-wise equations;
/// - every remaining pair (e.g. two holes, or an application against a `case`) fails.
///
/// # Errors
/// `SortUnificationFailure` when the occurs check fails, operation names or argument
/// counts differ, or a pair matches none of the cases above.
fn unify_all(mut eqs: Vec<(Term, Term)>) -> Result<FxHashMap<Arc<str>, Term>, GatError> {
    let mut subst: FxHashMap<Arc<str>, Term> = FxHashMap::default();

    while let Some((a, b)) = eqs.pop() {
        // Work modulo the bindings discovered so far.
        let a = apply_subst(&a, &subst);
        let b = apply_subst(&b, &subst);
        match (a, b) {
            (Term::Var(x), Term::Var(y)) if x == y => {}
            (Term::Var(x), t) | (t, Term::Var(x)) => {
                // Binding x to a term that contains x would never terminate.
                if occurs_in(&x, &t) {
                    return Err(GatError::SortUnificationFailure {
                        reason: format!("occurs check failed: {x} in {t}"),
                    });
                }
                // Push the new binding x := t into the values of earlier bindings.
                let updated: FxHashMap<Arc<str>, Term> = subst
                    .iter()
                    .map(|(k, v)| {
                        (
                            Arc::clone(k),
                            v.substitute(&std::iter::once((Arc::clone(&x), t.clone())).collect()),
                        )
                    })
                    .collect();
                subst = updated;
                subst.insert(x, t);
            }
            (
                Term::App {
                    op: op_a,
                    args: args_a,
                },
                Term::App {
                    op: op_b,
                    args: args_b,
                },
            ) => {
                if op_a != op_b {
                    return Err(GatError::SortUnificationFailure {
                        reason: format!("cannot unify {op_a}(...) with {op_b}(...)"),
                    });
                }
                if args_a.len() != args_b.len() {
                    return Err(GatError::SortUnificationFailure {
                        reason: format!(
                            "arity mismatch unifying {op_a}: {} vs {}",
                            args_a.len(),
                            args_b.len()
                        ),
                    });
                }
                // Same operation: unify the arguments pairwise.
                for pair in args_a.into_iter().zip(args_b) {
                    eqs.push(pair);
                }
            }
            (lhs, rhs) => {
                return Err(GatError::SortUnificationFailure {
                    reason: format!("cannot unify {lhs} with {rhs}"),
                });
            }
        }
    }

    Ok(subst)
}
629
630fn apply_subst(term: &Term, subst: &FxHashMap<Arc<str>, Term>) -> Term {
631 if subst.is_empty() {
632 return term.clone();
633 }
634 term.substitute(subst)
635}
636
637fn occurs_in(var: &Arc<str>, term: &Term) -> bool {
638 match term {
639 Term::Var(v) => v == var,
640 Term::Hole { .. } => false,
641 Term::Let { name, bound, body } => {
642 occurs_in(var, bound) || (name != var && occurs_in(var, body))
643 }
644 Term::App { args, .. } => args.iter().any(|a| occurs_in(var, a)),
645 Term::Case {
646 scrutinee,
647 branches,
648 } => {
649 occurs_in(var, scrutinee)
650 || branches
651 .iter()
652 .any(|b| !b.binders.contains(var) && occurs_in(var, &b.body))
653 }
654 }
655}
656
657pub fn typecheck_term_with_holes(
672 term: &Term,
673 ctx: &VarContext,
674 theory: &Theory,
675) -> Result<(SortExpr, Vec<HoleReport>), GatError> {
676 let mut reports: Vec<HoleReport> = Vec::new();
677 let sort = typecheck_with_expected(term, None, ctx, theory, &mut reports)?;
678 Ok((sort, reports))
679}
680
681fn typecheck_with_expected(
682 term: &Term,
683 expected: Option<&SortExpr>,
684 ctx: &VarContext,
685 theory: &Theory,
686 reports: &mut Vec<HoleReport>,
687) -> Result<SortExpr, GatError> {
688 match term {
689 Term::Hole { name } => {
690 let sort = expected.cloned().unwrap_or_else(|| {
691 SortExpr::Name(Arc::from(format!("?{}", name.as_deref().unwrap_or("hole"))))
692 });
693 reports.push(HoleReport {
694 name: name.clone(),
695 expected: sort.clone(),
696 context: ctx.clone(),
697 position: None,
698 });
699 Ok(sort)
700 }
701 Term::Var(n) => ctx
702 .get(n)
703 .cloned()
704 .ok_or_else(|| GatError::UnboundVariable(n.to_string())),
705 Term::App { op, args } => {
706 let operation = theory
707 .find_op(op)
708 .ok_or_else(|| GatError::OpNotFound(op.to_string()))?;
709 let has_implicits = operation
710 .inputs
711 .iter()
712 .any(|(_, _, imp)| matches!(imp, Implicit::Yes));
713 if has_implicits {
714 typecheck_app_with_implicits_collecting_holes(
720 op, args, operation, ctx, theory, reports,
721 )
722 } else {
723 if args.len() != operation.inputs.len() {
724 return Err(GatError::TermArityMismatch {
725 op: op.to_string(),
726 expected: operation.inputs.len(),
727 got: args.len(),
728 });
729 }
730 let mut theta: FxHashMap<Arc<str>, Term> = FxHashMap::default();
731 for (i, (arg, (param_name, declared_sort, _))) in
732 args.iter().zip(operation.inputs.iter()).enumerate()
733 {
734 let expected_sort = declared_sort.subst(&theta);
735 let arg_sort =
736 typecheck_with_expected(arg, Some(&expected_sort), ctx, theory, reports)?;
737 if !term_contains_hole(arg) && !arg_sort.alpha_eq(&expected_sort) {
741 return Err(GatError::ArgTypeMismatch {
742 op: op.to_string(),
743 arg_index: i,
744 expected: expected_sort.to_string(),
745 got: arg_sort.to_string(),
746 });
747 }
748 theta.insert(Arc::clone(param_name), arg.clone());
749 }
750 Ok(operation.output.subst(&theta))
751 }
752 }
753 Term::Case {
754 scrutinee,
755 branches,
756 } => typecheck_case_with_holes(scrutinee, branches, ctx, theory, reports),
757 Term::Let { name, bound, body } => {
758 let bound_sort = typecheck_with_expected(bound, None, ctx, theory, reports)?;
759 let mut extended = ctx.clone();
760 extended.insert(Arc::clone(name), bound_sort);
761 typecheck_with_expected(body, None, &extended, theory, reports)
762 }
763 }
764}
765
766fn typecheck_case_with_holes(
767 scrutinee: &Term,
768 branches: &[CaseBranch],
769 ctx: &VarContext,
770 theory: &Theory,
771 reports: &mut Vec<HoleReport>,
772) -> Result<SortExpr, GatError> {
773 let scrutinee_sort = typecheck_with_expected(scrutinee, None, ctx, theory, reports)?;
774 check_case_exhaustiveness_soft(&scrutinee_sort, branches, theory)?;
775 let mut branch_sort: Option<SortExpr> = None;
776 for b in branches {
777 let constructor_op = theory
778 .find_op(&b.constructor)
779 .ok_or_else(|| GatError::OpNotFound(b.constructor.to_string()))?;
780 if constructor_op.inputs.len() != b.binders.len() {
781 return Err(GatError::TermArityMismatch {
782 op: b.constructor.to_string(),
783 expected: constructor_op.inputs.len(),
784 got: b.binders.len(),
785 });
786 }
787 if constructor_op.output.head() != scrutinee_sort.head()
792 || constructor_op.output.args().len() != scrutinee_sort.args().len()
793 {
794 return Err(GatError::OpTypeMismatch {
795 op: b.constructor.to_string(),
796 detail: format!(
797 "constructor output sort {} does not match scrutinee sort {scrutinee_sort}",
798 constructor_op.output
799 ),
800 });
801 }
802 let unify_eqs: Vec<(Term, Term)> = constructor_op
803 .output
804 .args()
805 .iter()
806 .zip(scrutinee_sort.args().iter())
807 .map(|(a, b)| (a.clone(), b.clone()))
808 .collect();
809 let subst = unify_all(unify_eqs)?;
810 let mut extended = ctx.clone();
811 for ((_, declared_sort, _), binder) in constructor_op.inputs.iter().zip(b.binders.iter()) {
812 let binder_sort = declared_sort.subst(&subst);
813 extended.insert(Arc::clone(binder), binder_sort);
814 }
815 let body_sort = typecheck_with_expected(&b.body, None, &extended, theory, reports)?;
816 match &branch_sort {
817 None => branch_sort = Some(body_sort),
818 Some(existing) => {
819 if !existing.alpha_eq(&body_sort) {
823 return Err(GatError::EquationSortMismatch {
824 equation: "case".to_string(),
825 lhs_sort: existing.to_string(),
826 rhs_sort: body_sort.to_string(),
827 });
828 }
829 }
830 }
831 }
832 branch_sort.ok_or_else(|| GatError::NonExhaustiveCase {
833 sort: scrutinee_sort.head().to_string(),
834 missing: Vec::new(),
835 })
836}
837
838fn check_case_exhaustiveness_soft(
839 scrutinee_sort: &SortExpr,
840 branches: &[CaseBranch],
841 theory: &Theory,
842) -> Result<(), GatError> {
843 let Some(sort_decl) = theory.find_sort(scrutinee_sort.head()) else {
844 return Ok(());
845 };
846 let SortClosure::Closed(ctors) = &sort_decl.closure else {
847 return Ok(());
848 };
849 let mut seen: rustc_hash::FxHashSet<Arc<str>> = rustc_hash::FxHashSet::default();
850 for b in branches {
851 if !ctors.contains(&b.constructor) {
852 return Err(GatError::UnknownCaseConstructor {
853 sort: scrutinee_sort.head().to_string(),
854 constructor: b.constructor.to_string(),
855 });
856 }
857 if !seen.insert(Arc::clone(&b.constructor)) {
858 return Err(GatError::RedundantCaseBranch {
859 sort: scrutinee_sort.head().to_string(),
860 constructor: b.constructor.to_string(),
861 });
862 }
863 }
864 if seen.len() < ctors.len() {
865 let missing: Vec<String> = ctors
866 .iter()
867 .filter(|c| !seen.contains(*c))
868 .map(ToString::to_string)
869 .collect();
870 return Err(GatError::NonExhaustiveCase {
871 sort: scrutinee_sort.head().to_string(),
872 missing,
873 });
874 }
875 Ok(())
876}
877
878pub fn typecheck_equation(eq: &Equation, theory: &Theory) -> Result<(), GatError> {
887 let hole_count = count_holes(&eq.lhs) + count_holes(&eq.rhs);
888 if hole_count > 0 {
889 return Err(GatError::HolesInEquation {
890 equation: eq.name.to_string(),
891 count: hole_count,
892 });
893 }
894 let ctx = infer_var_sorts(eq, theory)?;
895 let lhs_sort = typecheck_term(&eq.lhs, &ctx, theory)?;
896 let rhs_sort = typecheck_term(&eq.rhs, &ctx, theory)?;
897 if !lhs_sort.alpha_eq(&rhs_sort) {
898 return Err(GatError::EquationSortMismatch {
899 equation: eq.name.to_string(),
900 lhs_sort: lhs_sort.to_string(),
901 rhs_sort: rhs_sort.to_string(),
902 });
903 }
904 Ok(())
905}
906
907pub fn typecheck_equation_modulo_rewrites(
921 eq: &Equation,
922 theory: &Theory,
923 rules: &[crate::eq::DirectedEquation],
924 step_limit: usize,
925) -> Result<(), GatError> {
926 let hole_count = count_holes(&eq.lhs) + count_holes(&eq.rhs);
927 if hole_count > 0 {
928 return Err(GatError::HolesInEquation {
929 equation: eq.name.to_string(),
930 count: hole_count,
931 });
932 }
933 let ctx = infer_var_sorts(eq, theory)?;
934 let lhs_sort = typecheck_term(&eq.lhs, &ctx, theory)?;
935 let rhs_sort = typecheck_term(&eq.rhs, &ctx, theory)?;
936 if !lhs_sort.alpha_eq_modulo_rewrites(&rhs_sort, rules, step_limit) {
937 return Err(GatError::EquationSortMismatch {
938 equation: eq.name.to_string(),
939 lhs_sort: lhs_sort.to_string(),
940 rhs_sort: rhs_sort.to_string(),
941 });
942 }
943 Ok(())
944}
945
946pub fn typecheck_theory(theory: &Theory) -> Result<(), GatError> {
958 for op in &theory.ops {
959 check_implicits_inferrable(op)?;
960 }
961 check_closed_sorts(theory)?;
962 for eq in &theory.eqs {
963 typecheck_equation(eq, theory)?;
964 }
965 Ok(())
966}
967
968fn check_closed_sorts(theory: &Theory) -> Result<(), GatError> {
976 for sort in &theory.sorts {
977 let SortClosure::Closed(ctors) = &sort.closure else {
978 continue;
979 };
980 let ctor_set: rustc_hash::FxHashSet<Arc<str>> = ctors.iter().map(Arc::clone).collect();
981 for ctor in ctors {
982 let op =
983 theory
984 .find_op(ctor)
985 .ok_or_else(|| GatError::InvalidClosedSortConstructor {
986 sort: sort.name.to_string(),
987 constructor: ctor.to_string(),
988 detail: "op does not exist in the theory".to_string(),
989 })?;
990 if op.output.head() != &sort.name {
991 return Err(GatError::InvalidClosedSortConstructor {
992 sort: sort.name.to_string(),
993 constructor: ctor.to_string(),
994 detail: format!(
995 "op output head is {}, expected {}",
996 op.output.head(),
997 sort.name
998 ),
999 });
1000 }
1001 }
1002 for op in &theory.ops {
1003 if op.output.head() == &sort.name && !ctor_set.contains(&op.name) {
1004 return Err(GatError::InvalidClosedSortConstructor {
1005 sort: sort.name.to_string(),
1006 constructor: op.name.to_string(),
1007 detail: "op produces the closed sort but is not listed in its closure"
1008 .to_string(),
1009 });
1010 }
1011 }
1012 }
1013 Ok(())
1014}
1015
1016fn check_implicits_inferrable(op: &Operation) -> Result<(), GatError> {
1019 for (pname, _, imp) in &op.inputs {
1020 if !matches!(imp, Implicit::Yes) {
1021 continue;
1022 }
1023 let mut found = false;
1024 for (_, sort_expr, other_imp) in &op.inputs {
1025 if matches!(other_imp, Implicit::No) && sort_expr_mentions_var(sort_expr, pname) {
1026 found = true;
1027 break;
1028 }
1029 }
1030 if !found && sort_expr_mentions_var(&op.output, pname) {
1031 found = true;
1032 }
1033 if !found {
1034 return Err(GatError::NonInferrableImplicit {
1035 op: op.name.to_string(),
1036 param: pname.to_string(),
1037 });
1038 }
1039 }
1040 Ok(())
1041}
1042
1043fn sort_expr_mentions_var(sort: &SortExpr, name: &Arc<str>) -> bool {
1046 sort.args().iter().any(|t| term_mentions_var(t, name))
1047}
1048
1049fn typecheck_app_with_implicits_collecting_holes(
1055 op: &Arc<str>,
1056 args: &[Term],
1057 operation: &Operation,
1058 ctx: &VarContext,
1059 theory: &Theory,
1060 reports: &mut Vec<HoleReport>,
1061) -> Result<SortExpr, GatError> {
1062 let explicit_count = operation.explicit_arity();
1063 if args.len() != explicit_count {
1064 return Err(GatError::TermArityMismatch {
1065 op: op.to_string(),
1066 expected: explicit_count,
1067 got: args.len(),
1068 });
1069 }
1070
1071 let mut fresh_rename: FxHashMap<Arc<str>, Term> = FxHashMap::default();
1072 for (idx, (pname, _, imp)) in operation.inputs.iter().enumerate() {
1073 if matches!(imp, Implicit::Yes) {
1074 let mv: Arc<str> = Arc::from(format!("?{pname}_{idx}"));
1075 fresh_rename.insert(Arc::clone(pname), Term::Var(mv));
1076 }
1077 }
1078
1079 let mut theta: FxHashMap<Arc<str>, Term> = fresh_rename.clone();
1080 let mut term_eqs: Vec<(Term, Term)> = Vec::new();
1081 let mut explicit_iter = args.iter();
1082 for (pname, declared_sort, imp) in &operation.inputs {
1083 match imp {
1084 Implicit::Yes => {}
1085 Implicit::No => {
1086 let Some(arg) = explicit_iter.next() else {
1087 return Err(GatError::TermArityMismatch {
1088 op: op.to_string(),
1089 expected: explicit_count,
1090 got: args.len(),
1091 });
1092 };
1093 let expected = declared_sort.subst(&theta);
1094 let arg_sort = typecheck_with_expected(arg, Some(&expected), ctx, theory, reports)?;
1095 push_sort_expr_eqs_into(&expected, &arg_sort, op, &mut term_eqs)?;
1096 theta.insert(Arc::clone(pname), arg.clone());
1097 }
1098 }
1099 }
1100
1101 let mgu = unify_all(term_eqs).map_err(|e| match e {
1102 GatError::SortUnificationFailure { reason } => GatError::SortUnificationFailure {
1103 reason: format!("implicit inference for {op}: {reason}"),
1104 },
1105 other => other,
1106 })?;
1107
1108 let mut final_subst = theta.clone();
1109 for (k, v) in &mgu {
1110 final_subst.insert(Arc::clone(k), v.clone());
1111 }
1112 let final_subst: FxHashMap<Arc<str>, Term> = final_subst
1113 .into_iter()
1114 .map(|(k, v)| (k, v.substitute(&mgu)))
1115 .collect();
1116
1117 Ok(operation.output.subst(&final_subst))
1118}
1119
1120fn count_holes(t: &Term) -> usize {
1121 match t {
1122 Term::Hole { .. } => 1,
1123 Term::Var(_) => 0,
1124 Term::App { args, .. } => args.iter().map(count_holes).sum(),
1125 Term::Case {
1126 scrutinee,
1127 branches,
1128 } => count_holes(scrutinee) + branches.iter().map(|b| count_holes(&b.body)).sum::<usize>(),
1129 Term::Let { bound, body, .. } => count_holes(bound) + count_holes(body),
1130 }
1131}
1132
1133fn term_contains_hole(t: &Term) -> bool {
1134 match t {
1135 Term::Hole { .. } => true,
1136 Term::Var(_) => false,
1137 Term::Let { bound, body, .. } => term_contains_hole(bound) || term_contains_hole(body),
1138 Term::App { args, .. } => args.iter().any(term_contains_hole),
1139 Term::Case {
1140 scrutinee,
1141 branches,
1142 } => term_contains_hole(scrutinee) || branches.iter().any(|b| term_contains_hole(&b.body)),
1143 }
1144}
1145
1146fn term_mentions_var(t: &Term, name: &Arc<str>) -> bool {
1147 match t {
1148 Term::Var(v) => v == name,
1149 Term::Hole { .. } => false,
1150 Term::Let {
1151 name: binder,
1152 bound,
1153 body,
1154 } => term_mentions_var(bound, name) || (binder != name && term_mentions_var(body, name)),
1155 Term::App { args, .. } => args.iter().any(|a| term_mentions_var(a, name)),
1156 Term::Case {
1157 scrutinee,
1158 branches,
1159 } => {
1160 term_mentions_var(scrutinee, name)
1161 || branches
1162 .iter()
1163 .any(|b| !b.binders.contains(name) && term_mentions_var(&b.body, name))
1164 }
1165 }
1166}
1167
1168#[cfg(test)]
1169mod tests {
1170 use super::*;
1171 use crate::eq::Term;
1172 use crate::op::Operation;
1173 use crate::sort::{Sort, SortParam};
1174 use crate::theory::Theory;
1175
1176 fn monoid_theory() -> Theory {
1177 let carrier = Sort::simple("Carrier");
1178 let mul = Operation::new(
1179 "mul",
1180 vec![
1181 (Arc::from("a"), SortExpr::from("Carrier")),
1182 (Arc::from("b"), SortExpr::from("Carrier")),
1183 ],
1184 "Carrier",
1185 );
1186 let unit = Operation::nullary("unit", "Carrier");
1187
1188 let assoc = Equation::new(
1189 "assoc",
1190 Term::app(
1191 "mul",
1192 vec![
1193 Term::var("a"),
1194 Term::app("mul", vec![Term::var("b"), Term::var("c")]),
1195 ],
1196 ),
1197 Term::app(
1198 "mul",
1199 vec![
1200 Term::app("mul", vec![Term::var("a"), Term::var("b")]),
1201 Term::var("c"),
1202 ],
1203 ),
1204 );
1205 let left_id = Equation::new(
1206 "left_id",
1207 Term::app("mul", vec![Term::constant("unit"), Term::var("a")]),
1208 Term::var("a"),
1209 );
1210 let right_id = Equation::new(
1211 "right_id",
1212 Term::app("mul", vec![Term::var("a"), Term::constant("unit")]),
1213 Term::var("a"),
1214 );
1215
1216 Theory::new(
1217 "Monoid",
1218 vec![carrier],
1219 vec![mul, unit],
1220 vec![assoc, left_id, right_id],
1221 )
1222 }
1223
1224 fn two_sort_theory() -> Theory {
1225 Theory::new(
1226 "TwoSort",
1227 vec![Sort::simple("A"), Sort::simple("B")],
1228 vec![
1229 Operation::unary("f", "x", "A", "B"),
1230 Operation::unary("g", "x", "B", "A"),
1231 Operation::nullary("a0", "A"),
1232 ],
1233 vec![],
1234 )
1235 }
1236
1237 fn category_theory() -> Theory {
1242 let ob = Sort::simple("Ob");
1243 let hom = Sort::dependent(
1244 "Hom",
1245 vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
1246 );
1247 let hom_xx = SortExpr::App {
1248 name: Arc::from("Hom"),
1249 args: vec![Term::var("x"), Term::var("x")],
1250 };
1251 let id = Operation::unary("id", "x", "Ob", hom_xx);
1252 let hom_src_mid = SortExpr::App {
1253 name: Arc::from("Hom"),
1254 args: vec![Term::var("x"), Term::var("y")],
1255 };
1256 let hom_mid_tgt = SortExpr::App {
1257 name: Arc::from("Hom"),
1258 args: vec![Term::var("y"), Term::var("z")],
1259 };
1260 let hom_src_tgt = SortExpr::App {
1261 name: Arc::from("Hom"),
1262 args: vec![Term::var("x"), Term::var("z")],
1263 };
1264 let compose = Operation::new(
1265 "compose",
1266 vec![
1267 (Arc::from("x"), SortExpr::from("Ob")),
1268 (Arc::from("y"), SortExpr::from("Ob")),
1269 (Arc::from("z"), SortExpr::from("Ob")),
1270 (Arc::from("f"), hom_src_mid),
1271 (Arc::from("g"), hom_mid_tgt),
1272 ],
1273 hom_src_tgt,
1274 );
1275 Theory::new("Category", vec![ob, hom], vec![id, compose], Vec::new())
1276 }
1277
1278 #[test]
1279 fn typecheck_variable() -> Result<(), Box<dyn std::error::Error>> {
1280 let theory = monoid_theory();
1281 let mut ctx = VarContext::default();
1282 ctx.insert(Arc::from("x"), SortExpr::from("Carrier"));
1283 let sort = typecheck_term(&Term::var("x"), &ctx, &theory)?;
1284 assert_eq!(&**sort.head(), "Carrier");
1285 Ok(())
1286 }
1287
1288 #[test]
1289 fn typecheck_unbound_variable() {
1290 let theory = monoid_theory();
1291 let ctx = VarContext::default();
1292 let result = typecheck_term(&Term::var("z"), &ctx, &theory);
1293 assert!(matches!(result, Err(GatError::UnboundVariable(_))));
1294 }
1295
1296 #[test]
1297 fn typecheck_constant() -> Result<(), Box<dyn std::error::Error>> {
1298 let theory = monoid_theory();
1299 let ctx = VarContext::default();
1300 let sort = typecheck_term(&Term::constant("unit"), &ctx, &theory)?;
1301 assert_eq!(&**sort.head(), "Carrier");
1302 Ok(())
1303 }
1304
1305 #[test]
1306 fn typecheck_binary_op() -> Result<(), Box<dyn std::error::Error>> {
1307 let theory = monoid_theory();
1308 let mut ctx = VarContext::default();
1309 ctx.insert(Arc::from("a"), SortExpr::from("Carrier"));
1310 ctx.insert(Arc::from("b"), SortExpr::from("Carrier"));
1311 let sort = typecheck_term(
1312 &Term::app("mul", vec![Term::var("a"), Term::var("b")]),
1313 &ctx,
1314 &theory,
1315 )?;
1316 assert_eq!(&**sort.head(), "Carrier");
1317 Ok(())
1318 }
1319
/// Applying `mul` to a single argument must be rejected with
/// `GatError::TermArityMismatch`.
#[test]
fn typecheck_arity_mismatch() {
    let theory = monoid_theory();
    let mut ctx = VarContext::default();
    ctx.insert(Arc::from("a"), SortExpr::from("Carrier"));
    let result = typecheck_term(&Term::app("mul", vec![Term::var("a")]), &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::TermArityMismatch { .. })),
        "expected TermArityMismatch, got {result:?}",
    );
}
1328
/// Passing a `B`-sorted variable to `f` in the two-sort theory must be
/// rejected with `GatError::ArgTypeMismatch`.
#[test]
fn typecheck_sort_mismatch() {
    let theory = two_sort_theory();
    let mut ctx = VarContext::default();
    ctx.insert(Arc::from("x"), SortExpr::from("B"));
    let result = typecheck_term(&Term::app("f", vec![Term::var("x")]), &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::ArgTypeMismatch { .. })),
        "expected ArgTypeMismatch, got {result:?}",
    );
}
1338
/// The nested application `g(f(a0))` typechecks in an empty context and has
/// sort `A`.
#[test]
fn typecheck_nested_term() -> Result<(), Box<dyn std::error::Error>> {
    let theory = two_sort_theory();
    let inner = Term::app("f", vec![Term::constant("a0")]);
    let outer = Term::app("g", vec![inner]);
    let outer_sort = typecheck_term(&outer, &VarContext::default(), &theory)?;
    assert_eq!(&**outer_sort.head(), "A");
    Ok(())
}
1349
/// `f(f(a0))` feeds `f`'s own output back into `f`. The inner result does not
/// have the sort `f` expects, so this must be rejected with
/// `GatError::ArgTypeMismatch`.
#[test]
fn typecheck_nested_sort_mismatch() {
    let theory = two_sort_theory();
    let ctx = VarContext::default();
    let term = Term::app("f", vec![Term::app("f", vec![Term::constant("a0")])]);
    let result = typecheck_term(&term, &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::ArgTypeMismatch { .. })),
        "expected ArgTypeMismatch, got {result:?}",
    );
}
1359
/// Referencing an operation the theory does not declare must yield
/// `GatError::OpNotFound`.
#[test]
fn typecheck_unknown_op() {
    let theory = monoid_theory();
    let ctx = VarContext::default();
    let result = typecheck_term(&Term::constant("nonexistent"), &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::OpNotFound(_))),
        "expected OpNotFound, got {result:?}",
    );
}
1367
/// For the monoid equation at index 0, exactly the variables `a`, `b` and `c`
/// are inferred, each with sort `Carrier`.
#[test]
fn infer_var_sorts_monoid() -> Result<(), Box<dyn std::error::Error>> {
    let theory = monoid_theory();
    let first_eq = &theory.eqs[0];
    let inferred = infer_var_sorts(first_eq, &theory)?;
    assert_eq!(inferred.len(), 3);
    for var in ["a", "b", "c"] {
        assert_eq!(&**inferred[&Arc::from(var)].head(), "Carrier");
    }
    Ok(())
}
1379
/// For the monoid equation at index 1 (the identity law), the only inferred
/// variable is `a`, with sort `Carrier`.
#[test]
fn infer_var_sorts_identity_law() -> Result<(), Box<dyn std::error::Error>> {
    let theory = monoid_theory();
    let identity_law = &theory.eqs[1];
    let inferred = infer_var_sorts(identity_law, &theory)?;
    assert_eq!(inferred.len(), 1);
    assert_eq!(&**inferred[&Arc::from("a")].head(), "Carrier");
    Ok(())
}
1389
/// `x` is used as the argument of both `f` and `g` in the two-sort theory,
/// which are expected to demand different sorts. Inference must therefore
/// report `GatError::ConflictingVarSort`.
#[test]
fn conflicting_var_sort() {
    let theory = two_sort_theory();
    let eq = Equation::new(
        "bogus",
        Term::app("f", vec![Term::var("x")]),
        Term::app("g", vec![Term::var("x")]),
    );
    let result = infer_var_sorts(&eq, &theory);
    assert!(
        matches!(result, Err(GatError::ConflictingVarSort { .. })),
        "expected ConflictingVarSort, got {result:?}",
    );
}
1401
/// The full monoid fixture passes `typecheck_theory`.
#[test]
fn typecheck_monoid_equations() -> Result<(), Box<dyn std::error::Error>> {
    let monoid = monoid_theory();
    typecheck_theory(&monoid).map(|_| ()).map_err(Into::into)
}
1408
/// The equation `f(a0) = a0` relates terms of different sorts in the two-sort
/// theory, so it must be rejected with `GatError::EquationSortMismatch`.
#[test]
fn typecheck_equation_sort_mismatch() {
    let theory = two_sort_theory();
    let eq = Equation::new(
        "bad",
        Term::app("f", vec![Term::constant("a0")]),
        Term::constant("a0"),
    );
    let result = typecheck_equation(&eq, &theory);
    assert!(
        matches!(result, Err(GatError::EquationSortMismatch { .. })),
        "expected EquationSortMismatch, got {result:?}",
    );
}
1420
/// A plain graph theory passes `typecheck_theory`. It has sorts `Vertex` and
/// `Edge`, source/target operations `Edge -> Vertex`, and no equations.
#[test]
fn typecheck_graph_theory() -> Result<(), Box<dyn std::error::Error>> {
    let sorts = vec![Sort::simple("Vertex"), Sort::simple("Edge")];
    let projections = ["src", "tgt"]
        .into_iter()
        .map(|name| Operation::unary(name, "e", "Edge", "Vertex"))
        .collect::<Vec<_>>();
    let graph = Theory::new("Graph", sorts, projections, vec![]);
    typecheck_theory(&graph)?;
    Ok(())
}
1435
/// A reflexive graph adds `id : Vertex -> Edge` with the laws
/// `src(id(v)) = v` and `tgt(id(v)) = v`. The resulting theory must pass
/// `typecheck_theory`.
#[test]
fn typecheck_reflexive_graph_equations() -> Result<(), Box<dyn std::error::Error>> {
    let operations = vec![
        Operation::unary("src", "e", "Edge", "Vertex"),
        Operation::unary("tgt", "e", "Edge", "Vertex"),
        Operation::unary("id", "v", "Vertex", "Edge"),
    ];
    // One law per endpoint projection: projecting `id(v)` gives back `v`.
    let laws = [("src_id", "src"), ("tgt_id", "tgt")]
        .into_iter()
        .map(|(law_name, projection)| {
            let loop_at_v = Term::app("id", vec![Term::var("v")]);
            Equation::new(
                law_name,
                Term::app(projection, vec![loop_at_v]),
                Term::var("v"),
            )
        })
        .collect::<Vec<_>>();
    let theory = Theory::new(
        "ReflexiveGraph",
        vec![Sort::simple("Vertex"), Sort::simple("Edge")],
        operations,
        laws,
    );
    typecheck_theory(&theory)?;
    Ok(())
}
1462
/// A symmetric graph adds `inv : Edge -> Edge` with three laws, and the
/// resulting theory must pass `typecheck_theory`. The laws say `inv` swaps
/// source and target and is an involution.
#[test]
fn typecheck_symmetric_graph_equations() -> Result<(), Box<dyn std::error::Error>> {
    let e = || Term::var("e");
    let inv = |t: Term| Term::app("inv", vec![t]);
    let src = |t: Term| Term::app("src", vec![t]);
    let tgt = |t: Term| Term::app("tgt", vec![t]);

    let laws = vec![
        Equation::new("src_inv", src(inv(e())), tgt(e())),
        Equation::new("tgt_inv", tgt(inv(e())), src(e())),
        Equation::new("inv_inv", inv(inv(e())), e()),
    ];
    let operations = vec![
        Operation::unary("src", "e", "Edge", "Vertex"),
        Operation::unary("tgt", "e", "Edge", "Vertex"),
        Operation::unary("inv", "e", "Edge", "Edge"),
    ];
    let theory = Theory::new(
        "SymmetricGraph",
        vec![Sort::simple("Vertex"), Sort::simple("Edge")],
        operations,
        laws,
    );
    typecheck_theory(&theory)?;
    Ok(())
}
1494
/// `id(x)` for `x : Ob` has the dependent sort `Hom(x, x)`.
#[test]
fn typecheck_dependent_id_ok() -> Result<(), Box<dyn std::error::Error>> {
    let theory = category_theory();
    let ctx: VarContext = [(Arc::from("x"), SortExpr::from("Ob"))]
        .into_iter()
        .collect();
    let identity = Term::app("id", vec![Term::var("x")]);
    let id_sort = typecheck_term(&identity, &ctx, &theory)?;
    assert_eq!(&**id_sort.head(), "Hom");
    let hom_args = id_sort.args();
    assert_eq!(hom_args.len(), 2);
    assert_eq!(hom_args[0], Term::var("x"));
    assert_eq!(hom_args[1], Term::var("x"));
    Ok(())
}
1510
/// `compose(a, b, c, f, g)` with `f : Hom(a, b)` and `g : Hom(b, c)` has sort
/// `Hom(a, c)`, compared up to alpha-equivalence.
#[test]
fn typecheck_dependent_compose_ok() -> Result<(), Box<dyn std::error::Error>> {
    let theory = category_theory();
    let hom = |src: &str, tgt: &str| SortExpr::App {
        name: Arc::from("Hom"),
        args: vec![Term::var(src), Term::var(tgt)],
    };
    let mut ctx: VarContext = ["a", "b", "c"]
        .into_iter()
        .map(|obj| (Arc::from(obj), SortExpr::from("Ob")))
        .collect();
    ctx.insert(Arc::from("f"), hom("a", "b"));
    ctx.insert(Arc::from("g"), hom("b", "c"));

    let args = ["a", "b", "c", "f", "g"]
        .into_iter()
        .map(Term::var)
        .collect::<Vec<_>>();
    let result = typecheck_term(&Term::app("compose", args), &ctx, &theory)?;
    let expected = hom("a", "c");
    assert!(result.alpha_eq(&expected), "got {result}");
    Ok(())
}
1550
/// `compose(a, b, c, f, g)` with `g : Hom(c, c)` instead of `Hom(b, c)` must
/// fail with `GatError::ArgTypeMismatch`.
#[test]
fn typecheck_dependent_compose_arg_mismatch() {
    let theory = category_theory();
    let hom = |src: &str, tgt: &str| SortExpr::App {
        name: Arc::from("Hom"),
        args: vec![Term::var(src), Term::var(tgt)],
    };
    let mut ctx: VarContext = ["a", "b", "c"]
        .into_iter()
        .map(|obj| (Arc::from(obj), SortExpr::from("Ob")))
        .collect();
    ctx.insert(Arc::from("f"), hom("a", "b"));
    // Deliberately wrong: the source of `g` should be `b`.
    ctx.insert(Arc::from("g"), hom("c", "c"));

    let args = ["a", "b", "c", "f", "g"]
        .into_iter()
        .map(Term::var)
        .collect::<Vec<_>>();
    let result = typecheck_term(&Term::app("compose", args), &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::ArgTypeMismatch { .. })),
        "expected ArgTypeMismatch, got {result:?}",
    );
}
1589
/// The category theory extended with the associativity law for `compose`
/// (all object arguments explicit) passes `typecheck_theory`.
#[test]
fn typecheck_dependent_equation_ok() -> Result<(), Box<dyn std::error::Error>> {
    let mut theory = category_theory();
    let v = |name: &str| Term::var(name);
    // compose(src, mid, tgt, first, second) with first : Hom(src, mid) and
    // second : Hom(mid, tgt).
    let compose = |src: &str, mid: &str, tgt: &str, first: Term, second: Term| {
        Term::app("compose", vec![v(src), v(mid), v(tgt), first, second])
    };
    // f followed by (g followed by h)
    let lhs = compose("a", "b", "d", v("f"), compose("b", "c", "d", v("g"), v("h")));
    // (f followed by g) followed by h
    let rhs = compose("a", "c", "d", compose("a", "b", "c", v("f"), v("g")), v("h"));
    theory.eqs.push(Equation::new("assoc", lhs, rhs));
    typecheck_theory(&theory)?;
    Ok(())
}
1641
/// Unifying a variable with itself is trivially solved, so the resulting
/// substitution binds nothing.
#[test]
fn unify_same_var_yields_empty_subst() -> Result<(), Box<dyn std::error::Error>> {
    let x = Term::var("x");
    let mgu = unify_all(vec![(x.clone(), x)])?;
    assert!(mgu.is_empty());
    Ok(())
}
1650
/// Unifying a variable with a constant binds the variable to that constant.
#[test]
fn unify_var_to_constant_binds() -> Result<(), Box<dyn std::error::Error>> {
    let constant = Term::constant("c");
    let mgu = unify_all(vec![(Term::var("x"), constant.clone())])?;
    assert_eq!(mgu.get(&Arc::from("x")), Some(&constant));
    Ok(())
}
1657
/// `x =?= f(x)` has no finite solution. The occurs check must reject it with
/// `GatError::SortUnificationFailure`.
#[test]
fn unify_occurs_check_fails() {
    let r = unify_all(vec![(Term::var("x"), Term::app("f", vec![Term::var("x")]))]);
    assert!(
        matches!(r, Err(GatError::SortUnificationFailure { .. })),
        "expected SortUnificationFailure from the occurs check, got {r:?}",
    );
}
1664
/// Applications with different head symbols (`f` vs `g`) never unify and must
/// yield `GatError::SortUnificationFailure`.
#[test]
fn unify_head_mismatch_fails() {
    let r = unify_all(vec![(
        Term::app("f", vec![Term::var("x")]),
        Term::app("g", vec![Term::var("x")]),
    )]);
    assert!(
        matches!(r, Err(GatError::SortUnificationFailure { .. })),
        "expected SortUnificationFailure on head mismatch, got {r:?}",
    );
}
1673
/// The computed substitution is idempotent: for every bound variable, applying
/// it twice gives the same term as applying it once.
#[test]
fn unify_is_idempotent() -> Result<(), Box<dyn std::error::Error>> {
    let pattern = Term::app("f", vec![Term::var("x"), Term::var("y")]);
    let target = Term::app(
        "f",
        vec![Term::var("a"), Term::app("g", vec![Term::var("b")])],
    );
    let subst = unify_all(vec![(pattern, target)])?;
    for k in subst.keys() {
        let once = Term::var(Arc::clone(k)).substitute(&subst);
        let twice = once.substitute(&subst);
        assert_eq!(once, twice, "substitution not idempotent on {k}");
    }
    Ok(())
}
1694
/// Soundness check: after applying the unifier of `f(x, g(y))` and
/// `f(h(a), g(b))`, both sides are syntactically equal.
#[test]
fn unify_soundness_mgu_instantiates_both_sides() -> Result<(), Box<dyn std::error::Error>> {
    let g_of = |name: &str| Term::app("g", vec![Term::var(name)]);
    let lhs = Term::app("f", vec![Term::var("x"), g_of("y")]);
    let rhs = Term::app("f", vec![Term::app("h", vec![Term::var("a")]), g_of("b")]);
    let mgu = unify_all(vec![(lhs.clone(), rhs.clone())])?;
    assert_eq!(lhs.substitute(&mgu), rhs.substitute(&mgu));
    Ok(())
}
1715
/// Typechecking the same term in the same context twice gives equal sorts.
#[test]
fn typecheck_term_idempotent() -> Result<(), Box<dyn std::error::Error>> {
    let theory = category_theory();
    let ctx: VarContext = [(Arc::from("x"), SortExpr::from("Ob"))]
        .into_iter()
        .collect();
    let id_x = Term::app("id", vec![Term::var("x")]);
    let check = || typecheck_term(&id_x, &ctx, &theory);
    let first = check()?;
    let second = check()?;
    assert_eq!(first, second);
    Ok(())
}
1729
/// Adding a binding the term never mentions (`unused`) to the context does not
/// change the inferred sort of `id(x)`.
#[test]
fn typecheck_context_strengthening() -> Result<(), Box<dyn std::error::Error>> {
    let theory = category_theory();
    let id_x = Term::app("id", vec![Term::var("x")]);

    let mut narrow = VarContext::default();
    narrow.insert(Arc::from("x"), SortExpr::from("Ob"));
    let narrow_sort = typecheck_term(&id_x, &narrow, &theory)?;

    let mut wide = narrow.clone();
    wide.insert(Arc::from("unused"), SortExpr::from("Ob"));
    let wide_sort = typecheck_term(&id_x, &wide, &theory)?;

    assert_eq!(narrow_sort, wide_sort);
    Ok(())
}
1743
/// Renaming commutes with typechecking. Typechecking `id(x)[x := y]` in a
/// context for `y` gives the sort of `id(x)` with the same renaming applied,
/// up to alpha-equivalence.
#[test]
fn typecheck_substitution_commutes() -> Result<(), Box<dyn std::error::Error>> {
    let theory = category_theory();

    let id_x = Term::app("id", vec![Term::var("x")]);
    let ctx_x: VarContext = [(Arc::from("x"), SortExpr::from("Ob"))]
        .into_iter()
        .collect();
    let sort_x = typecheck_term(&id_x, &ctx_x, &theory)?;

    let sigma: FxHashMap<Arc<str>, Term> = [(Arc::from("x"), Term::var("y"))]
        .into_iter()
        .collect();
    let id_y = id_x.substitute(&sigma);
    let ctx_y: VarContext = [(Arc::from("y"), SortExpr::from("Ob"))]
        .into_iter()
        .collect();

    let s_prime = typecheck_term(&id_y, &ctx_y, &theory)?;
    let s_expected = sort_x.subst(&sigma);
    assert!(
        s_prime.alpha_eq(&s_expected),
        "got {s_prime}, expected {s_expected}"
    );
    Ok(())
}
1769
/// `compose(p, q, s, f, g)` with `f : Hom(p, q)` but `g : Hom(r, s)` must be
/// rejected, because the middle object `q` does not match the source of `g`.
#[test]
fn compose_with_disagreeing_middle_object_is_rejected() {
    let theory = category_theory();
    let hom = |src: &str, tgt: &str| SortExpr::App {
        name: Arc::from("Hom"),
        args: vec![Term::var(src), Term::var(tgt)],
    };
    let mut ctx: VarContext = ["p", "q", "r", "s"]
        .into_iter()
        .map(|obj| (Arc::from(obj), SortExpr::from("Ob")))
        .collect();
    ctx.insert(Arc::from("f"), hom("p", "q"));
    ctx.insert(Arc::from("g"), hom("r", "s"));

    let args = ["p", "q", "s", "f", "g"]
        .into_iter()
        .map(Term::var)
        .collect::<Vec<_>>();
    let result = typecheck_term(&Term::app("compose", args), &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::ArgTypeMismatch { .. })),
        "compose with mismatched middle object must be rejected, got {result:?}",
    );
}
1825
/// `compose(p, p, r, id(p), f)` with `f : Hom(q, r)` must be rejected, because
/// the source of `f` is `q`, not `p`.
#[test]
fn compose_of_identity_with_unrelated_arrow_is_rejected() {
    let theory = category_theory();
    let mut ctx: VarContext = ["p", "q", "r"]
        .into_iter()
        .map(|obj| (Arc::from(obj), SortExpr::from("Ob")))
        .collect();
    let f_sort = SortExpr::App {
        name: Arc::from("Hom"),
        args: vec![Term::var("q"), Term::var("r")],
    };
    ctx.insert(Arc::from("f"), f_sort);

    let id_p = Term::app("id", vec![Term::var("p")]);
    let objects = ["p", "p", "r"].into_iter().map(Term::var);
    let args = objects.chain([id_p, Term::var("f")]).collect::<Vec<_>>();
    let result = typecheck_term(&Term::app("compose", args), &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::ArgTypeMismatch { .. })),
        "compose(id(p), f) with src(f) != p must be rejected, got {result:?}",
    );
}
1862
/// `compose(p, p, q, id(p), id(q))` must be rejected when `p` and `q` are
/// distinct variables, since `id(q) : Hom(q, q)` does not start at `p`.
#[test]
fn compose_of_two_identities_at_distinct_objects_is_rejected() {
    let theory = category_theory();
    let ctx: VarContext = ["p", "q"]
        .into_iter()
        .map(|obj| (Arc::from(obj), SortExpr::from("Ob")))
        .collect();
    let id_of = |obj: &str| Term::app("id", vec![Term::var(obj)]);
    let term = Term::app(
        "compose",
        vec![
            Term::var("p"),
            Term::var("p"),
            Term::var("q"),
            id_of("p"),
            id_of("q"),
        ],
    );
    let result = typecheck_term(&term, &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::ArgTypeMismatch { .. })),
        "compose(id(p), id(q)) with p != q must be rejected, got {result:?}",
    );
}
1892
/// The equation `id(id(x)) = x` passes a `Hom` where `id` expects an object,
/// so `typecheck_equation` must return some error.
#[test]
fn equation_with_dependent_sort_arg_mismatch_errors() {
    let theory = category_theory();
    let id_x = Term::app("id", vec![Term::var("x")]);
    let id_id_x = Term::app("id", vec![id_x]);
    let eq = Equation::new("bad", id_id_x, Term::var("x"));
    let result = typecheck_equation(&eq, &theory);
    assert!(
        result.is_err(),
        "equation with argument-sort mismatch must error, got {result:?}",
    );
}
1914
/// An equation that references an undeclared operation (`mystery`) must fail
/// with `GatError::OpNotFound`.
#[test]
fn equation_with_unknown_op_errors() {
    let theory = monoid_theory();
    let lhs = Term::app("mystery", vec![Term::var("a")]);
    let eq = Equation::new("bad", lhs, Term::var("a"));
    let result = typecheck_equation(&eq, &theory);
    assert!(
        matches!(result, Err(GatError::OpNotFound(_))),
        "equation referencing unknown op must error, got {result:?}",
    );
}
1929
/// An equation that applies `mul` to one argument must fail with
/// `GatError::TermArityMismatch`.
#[test]
fn equation_with_arity_mismatch_errors() {
    let theory = monoid_theory();
    let lhs = Term::app("mul", vec![Term::var("a")]);
    let eq = Equation::new("bad", lhs, Term::var("a"));
    let result = typecheck_equation(&eq, &theory);
    assert!(
        matches!(result, Err(GatError::TermArityMismatch { .. })),
        "equation with arity mismatch must error, got {result:?}",
    );
}
1944
/// `compose(f, x, x, f, f)` passes the morphism `f : Hom(x, x)` where an
/// object is expected, so it must fail with `GatError::ArgTypeMismatch`.
#[test]
fn dependent_sort_with_ill_typed_arg_errors() {
    let theory = category_theory();
    let endo = SortExpr::App {
        name: Arc::from("Hom"),
        args: vec![Term::var("x"), Term::var("x")],
    };
    let ctx: VarContext = [
        (Arc::from("x"), SortExpr::from("Ob")),
        (Arc::from("f"), endo),
    ]
    .into_iter()
    .collect();
    let args = ["f", "x", "x", "f", "f"]
        .into_iter()
        .map(Term::var)
        .collect::<Vec<_>>();
    let result = typecheck_term(&Term::app("compose", args), &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::ArgTypeMismatch { .. })),
        "ill-typed dependent-sort argument must error, got {result:?}",
    );
}
1977
1978 fn nat_theory() -> Theory {
1982 let nat = Sort::closed(
1983 "Nat",
1984 Vec::new(),
1985 [Arc::from("zero") as Arc<str>, Arc::from("succ")],
1986 );
1987 let zero = Operation::nullary("zero", "Nat");
1988 let succ = Operation::unary("succ", "n", "Nat", "Nat");
1989 Theory::new("NatTh", vec![nat], vec![zero, succ], Vec::new())
1990 }
1991
/// A `case` on `n : Nat` with one arm per constructor (`zero`, `succ m`)
/// typechecks, and its sort is `Nat`.
#[test]
fn closed_sort_exhaustive_case_typechecks() -> Result<(), Box<dyn std::error::Error>> {
    let theory = nat_theory();
    let ctx: VarContext = [(Arc::from("n"), SortExpr::from("Nat"))]
        .into_iter()
        .collect();
    let zero_arm = CaseBranch {
        constructor: Arc::from("zero"),
        binders: Vec::new(),
        body: Term::constant("zero"),
    };
    let succ_arm = CaseBranch {
        constructor: Arc::from("succ"),
        binders: vec![Arc::from("m")],
        body: Term::var("m"),
    };
    let term = Term::Case {
        scrutinee: Box::new(Term::var("n")),
        branches: vec![zero_arm, succ_arm],
    };
    let case_sort = typecheck_term(&term, &ctx, &theory)?;
    assert_eq!(&**case_sort.head(), "Nat");
    Ok(())
}
2016
/// A `case` on `Nat` that omits the `succ` arm must fail with
/// `GatError::NonExhaustiveCase`.
#[test]
fn closed_sort_missing_branch_rejected() {
    let theory = nat_theory();
    let ctx: VarContext = [(Arc::from("n"), SortExpr::from("Nat"))]
        .into_iter()
        .collect();
    let only_zero = CaseBranch {
        constructor: Arc::from("zero"),
        binders: Vec::new(),
        body: Term::constant("zero"),
    };
    let term = Term::Case {
        scrutinee: Box::new(Term::var("n")),
        branches: vec![only_zero],
    };
    let result = typecheck_term(&term, &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::NonExhaustiveCase { .. })),
        "got {result:?}"
    );
}
2036
/// A `case` on `Nat` with two `zero` arms must fail with
/// `GatError::RedundantCaseBranch`.
#[test]
fn closed_sort_redundant_branch_rejected() {
    let theory = nat_theory();
    let ctx: VarContext = [(Arc::from("n"), SortExpr::from("Nat"))]
        .into_iter()
        .collect();
    let zero_arm = || CaseBranch {
        constructor: Arc::from("zero"),
        binders: Vec::new(),
        body: Term::constant("zero"),
    };
    let term = Term::Case {
        scrutinee: Box::new(Term::var("n")),
        branches: vec![zero_arm(), zero_arm()],
    };
    let result = typecheck_term(&term, &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::RedundantCaseBranch { .. })),
        "got {result:?}"
    );
}
2063
/// A `case` whose scrutinee has the sort `Vertex` (declared with
/// `Sort::simple`) must fail with `GatError::CaseOnOpenSort`.
#[test]
fn case_on_open_sort_rejected() {
    let theory = Theory::new(
        "Open",
        vec![Sort::simple("Vertex")],
        vec![Operation::nullary("v0", "Vertex")],
        Vec::new(),
    );
    let ctx: VarContext = [(Arc::from("x"), SortExpr::from("Vertex"))]
        .into_iter()
        .collect();
    let v0_arm = CaseBranch {
        constructor: Arc::from("v0"),
        binders: Vec::new(),
        body: Term::constant("v0"),
    };
    let term = Term::Case {
        scrutinee: Box::new(Term::var("x")),
        branches: vec![v0_arm],
    };
    let result = typecheck_term(&term, &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::CaseOnOpenSort { .. })),
        "got {result:?}"
    );
}
2086
/// A `case` on `Nat` with an arm for the undeclared constructor `nope` must
/// fail with `GatError::UnknownCaseConstructor`.
#[test]
fn case_unknown_constructor_rejected() {
    let theory = nat_theory();
    let ctx: VarContext = [(Arc::from("n"), SortExpr::from("Nat"))]
        .into_iter()
        .collect();
    let bogus_arm = CaseBranch {
        constructor: Arc::from("nope"),
        binders: Vec::new(),
        body: Term::constant("zero"),
    };
    let succ_arm = CaseBranch {
        constructor: Arc::from("succ"),
        binders: vec![Arc::from("m")],
        body: Term::var("m"),
    };
    let term = Term::Case {
        scrutinee: Box::new(Term::var("n")),
        branches: vec![bogus_arm, succ_arm],
    };
    let result = typecheck_term(&term, &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::UnknownCaseConstructor { .. })),
        "got {result:?}"
    );
}
2113
/// `Nat` is closed over `{zero}` only, yet the theory also declares
/// `sneaky : Nat`. `typecheck_theory` must reject this with
/// `GatError::InvalidClosedSortConstructor`.
#[test]
fn closed_sort_rejects_external_constructor() {
    let closed_over_zero = Sort::closed("Nat", Vec::new(), [Arc::from("zero") as Arc<str>]);
    let operations = vec![
        Operation::nullary("zero", "Nat"),
        Operation::nullary("sneaky", "Nat"),
    ];
    let theory = Theory::new("BadClosure", vec![closed_over_zero], operations, Vec::new());
    let result = typecheck_theory(&theory);
    assert!(
        matches!(result, Err(GatError::InvalidClosedSortConstructor { .. })),
        "got {result:?}"
    );
}
2128
/// A theory morphism must send the constructors of a closed sort onto the
/// constructors of the image sort.
///
/// The morphism `zero -> zero2`, `succ -> succ2` is accepted into a target
/// whose `Nat` is closed over `{zero2, succ2}`. The same morphism into a target
/// whose `Nat` is closed over `{zero2, other}` must fail with
/// `GatError::MorphismClosureMismatch`.
#[test]
fn morphism_preserves_closure_constructors() -> Result<(), Box<dyn std::error::Error>> {
    use crate::morphism::{TheoryMorphism, check_morphism};
    use std::collections::HashMap;

    let source = nat_theory();

    let good_target = Theory::new(
        "NatTh2",
        vec![Sort::closed(
            "Nat",
            Vec::new(),
            [Arc::from("zero2") as Arc<str>, Arc::from("succ2")],
        )],
        vec![
            Operation::nullary("zero2", "Nat"),
            Operation::unary("succ2", "n", "Nat", "Nat"),
        ],
        Vec::new(),
    );

    let sort_map = HashMap::from([(Arc::from("Nat"), Arc::from("Nat"))]);
    let op_map = HashMap::from([
        (Arc::from("zero"), Arc::from("zero2")),
        (Arc::from("succ"), Arc::from("succ2")),
    ]);
    let m = TheoryMorphism::new("m", "NatTh", "NatTh2", sort_map, op_map);
    check_morphism(&m, &source, &good_target)?;

    // Same theory name, but `Nat` is now closed over `other` instead of `succ2`.
    let bad_target = Theory::new(
        "NatTh2",
        vec![Sort::closed(
            "Nat",
            Vec::new(),
            [Arc::from("zero2") as Arc<str>, Arc::from("other")],
        )],
        vec![
            Operation::nullary("zero2", "Nat"),
            Operation::unary("other", "n", "Nat", "Nat"),
        ],
        Vec::new(),
    );
    let result = check_morphism(&m, &source, &bad_target);
    assert!(
        matches!(result, Err(GatError::MorphismClosureMismatch { .. })),
        "got {result:?}"
    );
    Ok(())
}
2174
/// Substituting `m := succ(zero)` into a `case` replaces the free `m` in the
/// `zero` arm. It leaves the `succ m` arm untouched, because that arm binds
/// (shadows) `m`.
#[test]
fn case_term_substitution_respects_binder_shadow() {
    let replacement = Term::app("succ", vec![Term::constant("zero")]);
    let free_arm = CaseBranch {
        constructor: Arc::from("zero"),
        binders: Vec::new(),
        body: Term::var("m"),
    };
    let shadowing_arm = CaseBranch {
        constructor: Arc::from("succ"),
        binders: vec![Arc::from("m")],
        body: Term::var("m"),
    };
    let term = Term::Case {
        scrutinee: Box::new(Term::var("n")),
        branches: vec![free_arm, shadowing_arm],
    };

    let mut subst = FxHashMap::default();
    subst.insert(Arc::from("m"), replacement.clone());
    let result = term.substitute(&subst);

    let branches = match &result {
        Term::Case { branches, .. } => branches,
        _ => panic!("expected Case, got {result:?}"),
    };
    assert_eq!(
        branches[0].body,
        replacement,
        "zero branch body should be substituted"
    );
    assert_eq!(
        branches[1].body,
        Term::var("m"),
        "succ branch body must be shadowed, body stays `m`"
    );
}
2216
2217 fn lambda_theory() -> Theory {
2227 use crate::op::Implicit;
2228 let ty = Sort::simple("Ty");
2229 let tm = Sort::dependent("Tm", vec![SortParam::new("t", "Ty")]);
2230 let arrow = Operation::new(
2231 "arrow",
2232 vec![
2233 (Arc::from("a"), SortExpr::from("Ty")),
2234 (Arc::from("b"), SortExpr::from("Ty")),
2235 ],
2236 "Ty",
2237 );
2238 let tm_a = SortExpr::App {
2239 name: Arc::from("Tm"),
2240 args: vec![Term::var("a")],
2241 };
2242 let tm_b = SortExpr::App {
2243 name: Arc::from("Tm"),
2244 args: vec![Term::var("b")],
2245 };
2246 let tm_arrow = SortExpr::App {
2247 name: Arc::from("Tm"),
2248 args: vec![Term::app("arrow", vec![Term::var("a"), Term::var("b")])],
2249 };
2250 let app = Operation::with_implicit(
2251 "app",
2252 vec![
2253 (Arc::from("a"), SortExpr::from("Ty"), Implicit::Yes),
2254 (Arc::from("b"), SortExpr::from("Ty"), Implicit::Yes),
2255 (Arc::from("f"), tm_arrow, Implicit::No),
2256 (Arc::from("x"), tm_a, Implicit::No),
2257 ],
2258 tm_b,
2259 );
2260 Theory::new("Lambda", vec![ty, tm], vec![arrow, app], Vec::new())
2261 }
2262
/// `app(f, x)` with `f : Tm(arrow(A, B))` and `x : Tm(A)` infers the implicit
/// type arguments and has sort `Tm(B)`, compared up to alpha-equivalence.
#[test]
fn app_with_inferred_implicit_types() -> Result<(), Box<dyn std::error::Error>> {
    let theory = lambda_theory();
    let tm = |ty: Term| SortExpr::App {
        name: Arc::from("Tm"),
        args: vec![ty],
    };
    let mut ctx: VarContext = ["A", "B"]
        .into_iter()
        .map(|ty_var| (Arc::from(ty_var), SortExpr::from("Ty")))
        .collect();
    ctx.insert(
        Arc::from("f"),
        tm(Term::app("arrow", vec![Term::var("A"), Term::var("B")])),
    );
    ctx.insert(Arc::from("x"), tm(Term::var("A")));

    let call = Term::app("app", vec![Term::var("f"), Term::var("x")]);
    let result = typecheck_term(&call, &ctx, &theory)?;
    let expected = tm(Term::var("B"));
    assert!(result.alpha_eq(&expected), "got {result}");
    Ok(())
}
2296
/// Implicit inference must fail on conflicting constraints. With
/// `f : Tm(arrow(tyA, tyB))` and `x : Tm(tyB)`, the call `app(f, x)` needs the
/// implicit `a` to be `tyA` (from `f`) and also `tyB` (from `x`). This must
/// fail with `GatError::SortUnificationFailure`.
#[test]
fn implicit_inference_rejects_overconstrained_call() {
    use crate::op::Implicit;
    let tm = |ty: Term| SortExpr::App {
        name: Arc::from("Tm"),
        args: vec![ty],
    };
    let arrow_of = |dom: Term, cod: Term| Term::app("arrow", vec![dom, cod]);

    let arrow = Operation::new(
        "arrow",
        vec![
            (Arc::from("a"), SortExpr::from("Ty")),
            (Arc::from("b"), SortExpr::from("Ty")),
        ],
        "Ty",
    );
    let app = Operation::with_implicit(
        "app",
        vec![
            (Arc::from("a"), SortExpr::from("Ty"), Implicit::Yes),
            (Arc::from("b"), SortExpr::from("Ty"), Implicit::Yes),
            (
                Arc::from("f"),
                tm(arrow_of(Term::var("a"), Term::var("b"))),
                Implicit::No,
            ),
            (Arc::from("x"), tm(Term::var("a")), Implicit::No),
        ],
        tm(Term::var("b")),
    );
    let theory = Theory::new(
        "LambdaGround",
        vec![
            Sort::simple("Ty"),
            Sort::dependent("Tm", vec![SortParam::new("t", "Ty")]),
        ],
        vec![
            Operation::nullary("tyA", "Ty"),
            Operation::nullary("tyB", "Ty"),
            arrow,
            app,
        ],
        Vec::new(),
    );

    let ctx: VarContext = [
        (
            Arc::from("f"),
            tm(arrow_of(Term::constant("tyA"), Term::constant("tyB"))),
        ),
        (Arc::from("x"), tm(Term::constant("tyB"))),
    ]
    .into_iter()
    .collect();
    let call = Term::app("app", vec![Term::var("f"), Term::var("x")]);
    let result = typecheck_term(&call, &ctx, &theory);
    assert!(
        matches!(result, Err(GatError::SortUnificationFailure { .. })),
        "overconstrained implicit inference must fail: got {result:?}",
    );
}
2373
/// `foo(a : Ty, {c : Ty}) : Ty` declares an implicit `c` that occurs neither in
/// another argument's sort nor in the output sort. `typecheck_theory` must
/// reject this with `GatError::NonInferrableImplicit`.
#[test]
fn implicit_declaration_rejected_when_not_inferrable() {
    use crate::op::Implicit;
    let params = vec![
        (Arc::from("a"), SortExpr::from("Ty"), Implicit::No),
        (Arc::from("c"), SortExpr::from("Ty"), Implicit::Yes),
    ];
    let foo = Operation::with_implicit("foo", params, SortExpr::from("Ty"));
    let theory = Theory::new("BadImplicit", vec![Sort::simple("Ty")], vec![foo], Vec::new());
    let result = typecheck_theory(&theory);
    assert!(
        matches!(result, Err(GatError::NonInferrableImplicit { .. })),
        "non-inferrable implicit must be rejected: got {result:?}",
    );
}
2399
/// Operations with no implicit arguments (`id` in the category theory) still
/// typecheck to the expected `Hom` sort.
#[test]
fn app_without_implicits_still_typechecks() -> Result<(), Box<dyn std::error::Error>> {
    let ctx: VarContext = [(Arc::from("x"), SortExpr::from("Ob"))]
        .into_iter()
        .collect();
    let identity = Term::app("id", vec![Term::var("x")]);
    let id_sort = typecheck_term(&identity, &ctx, &category_theory())?;
    assert_eq!(&**id_sort.head(), "Hom");
    Ok(())
}
2411
/// `let x = unit in mul(x, x)` typechecks in an empty context, and its sort is
/// `Carrier`.
#[test]
fn monomorphic_let_typechecks() -> Result<(), Box<dyn std::error::Error>> {
    let theory = monoid_theory();
    let square = Term::app("mul", vec![Term::var("x"), Term::var("x")]);
    let t = Term::Let {
        name: Arc::from("x"),
        bound: Box::new(Term::constant("unit")),
        body: Box::new(square),
    };
    let let_sort = typecheck_term(&t, &VarContext::default(), &theory)?;
    assert_eq!(&**let_sort.head(), "Carrier");
    Ok(())
}
2426
/// An equation whose left-hand side contains an anonymous hole must be
/// rejected with `GatError::HolesInEquation`.
#[test]
fn equation_with_hole_is_rejected() {
    let theory = monoid_theory();
    let eq = Equation::new(
        "bad",
        Term::app("mul", vec![Term::var("a"), Term::Hole { name: None }]),
        Term::var("a"),
    );
    let result = typecheck_equation(&eq, &theory);
    assert!(
        matches!(result, Err(GatError::HolesInEquation { .. })),
        "equation containing a hole must be rejected, got {result:?}",
    );
}
2438
2439 mod property {
2442 use super::*;
2443 use proptest::prelude::*;
2444
2445 const SORT_POOL: &[&str] = &["S0", "S1", "S2", "S3"];
2446
2447 fn arb_well_typed_theory() -> impl Strategy<Value = Theory> {
2450 prop::sample::subsequence(SORT_POOL, 1..=4).prop_flat_map(|sort_names| {
2451 let sorts: Vec<Sort> = sort_names.iter().map(|s| Sort::simple(*s)).collect();
2452 let sn: Vec<String> = sort_names.iter().map(|s| (*s).to_owned()).collect();
2453 let sn2 = sn.clone();
2454 (
2455 Just(sorts),
2456 prop::collection::vec(
2457 (
2458 0..4usize,
2459 prop::sample::select(sn),
2460 prop::sample::select(sn2),
2461 ),
2462 0..=3,
2463 ),
2464 )
2465 .prop_map(|(sorts, op_specs)| {
2466 let mut ops = Vec::new();
2467 let mut seen = std::collections::HashSet::new();
2468 for (i, (_, input_sort, output_sort)) in op_specs.iter().enumerate() {
2469 let name = format!("op{i}");
2470 if !seen.insert(name.clone()) {
2471 continue;
2472 }
2473 ops.push(Operation::unary(
2474 &*name,
2475 "x",
2476 input_sort.as_str(),
2477 output_sort.as_str(),
2478 ));
2479 }
2480 Theory::new("TypecheckTest", sorts, ops, Vec::new())
2481 })
2482 })
2483 }
2484
2485 fn arb_case_on_nat() -> impl Strategy<Value = (Theory, Vec<Arc<str>>)> {
2489 let nat = Sort::closed(
2490 "Nat",
2491 Vec::new(),
2492 [Arc::from("zero") as Arc<str>, Arc::from("succ")],
2493 );
2494 let zero = Operation::nullary("zero", "Nat");
2495 let succ = Operation::unary("succ", "n", "Nat", "Nat");
2496 let theory = Theory::new("NatTh", vec![nat], vec![zero, succ], Vec::new());
2497 (
2498 Just(theory),
2499 prop::collection::vec(
2500 prop::sample::select(vec![
2501 Arc::from("zero"),
2502 Arc::from("succ"),
2503 Arc::from("bogus"),
2504 ] as Vec<Arc<str>>),
2505 0..=3,
2506 ),
2507 )
2508 }
2509
2510 proptest! {
2511 #![proptest_config(ProptestConfig::with_cases(256))]
2512
2513 #[test]
2514 fn case_on_closed_sort_never_panics(
2515 (theory, ctors) in arb_case_on_nat()
2516 ) {
2517 let mut ctx = VarContext::default();
2518 ctx.insert(Arc::from("n"), SortExpr::from("Nat"));
2519 let branches: Vec<CaseBranch> = ctors
2520 .into_iter()
2521 .map(|c| CaseBranch {
2522 constructor: c,
2523 binders: Vec::new(),
2524 body: Term::constant("zero"),
2525 })
2526 .collect();
2527 let term = Term::Case {
2528 scrutinee: Box::new(Term::var("n")),
2529 branches,
2530 };
2531 let r = typecheck_term(&term, &ctx, &theory);
2534 match r {
2535 Ok(_)
2536 | Err(
2537 GatError::NonExhaustiveCase { .. }
2538 | GatError::RedundantCaseBranch { .. }
2539 | GatError::UnknownCaseConstructor { .. }
2540 | GatError::OpTypeMismatch { .. }
2541 | GatError::TermArityMismatch { .. },
2542 ) => {}
2543 other => prop_assert!(false, "unexpected result: {other:?}"),
2544 }
2545 }
2546
2547 #[test]
2548 fn typecheck_is_idempotent(t in arb_well_typed_theory()) {
2549 let result1 = typecheck_theory(&t);
2550 let result2 = typecheck_theory(&t);
2551 prop_assert_eq!(result1.is_ok(), result2.is_ok());
2552 }
2553
2554 #[test]
2555 fn well_typed_theory_passes(t in arb_well_typed_theory()) {
2556 prop_assert!(
2557 typecheck_theory(&t).is_ok(),
2558 "well-typed theory should pass typecheck",
2559 );
2560 }
2561
2562 #[test]
2563 fn implicit_inference_stable_across_names(
2564 a_name in prop::sample::select(&["A", "B", "C", "P", "Q"][..]).prop_map(Arc::from),
2565 b_name in prop::sample::select(&["A", "B", "C", "P", "Q"][..]).prop_map(Arc::from),
2566 ) {
2567 use crate::op::Implicit;
2568 let ty = Sort::simple("Ty");
2574 let tm = Sort::dependent("Tm", vec![SortParam::new("t", "Ty")]);
2575 let arrow = Operation::new(
2576 "arrow",
2577 vec![
2578 (Arc::from("a"), SortExpr::from("Ty")),
2579 (Arc::from("b"), SortExpr::from("Ty")),
2580 ],
2581 "Ty",
2582 );
2583 let tm_a = SortExpr::App {
2584 name: Arc::from("Tm"),
2585 args: vec![Term::var("a")],
2586 };
2587 let tm_b = SortExpr::App {
2588 name: Arc::from("Tm"),
2589 args: vec![Term::var("b")],
2590 };
2591 let tm_arrow = SortExpr::App {
2592 name: Arc::from("Tm"),
2593 args: vec![Term::app(
2594 "arrow",
2595 vec![Term::var("a"), Term::var("b")],
2596 )],
2597 };
2598 let app = Operation::with_implicit(
2599 "app",
2600 vec![
2601 (Arc::from("a"), SortExpr::from("Ty"), Implicit::Yes),
2602 (Arc::from("b"), SortExpr::from("Ty"), Implicit::Yes),
2603 (Arc::from("f"), tm_arrow, Implicit::No),
2604 (Arc::from("x"), tm_a, Implicit::No),
2605 ],
2606 tm_b,
2607 );
2608 let theory = Theory::new("Lambda", vec![ty, tm], vec![arrow, app], Vec::new());
2609
2610 let mut ctx = VarContext::default();
2611 ctx.insert(Arc::clone(&a_name), SortExpr::from("Ty"));
2612 if a_name != b_name {
2613 ctx.insert(Arc::clone(&b_name), SortExpr::from("Ty"));
2614 }
2615 ctx.insert(
2616 Arc::from("f"),
2617 SortExpr::App {
2618 name: Arc::from("Tm"),
2619 args: vec![Term::app(
2620 "arrow",
2621 vec![Term::Var(Arc::clone(&a_name)), Term::Var(Arc::clone(&b_name))],
2622 )],
2623 },
2624 );
2625 ctx.insert(
2626 Arc::from("x"),
2627 SortExpr::App {
2628 name: Arc::from("Tm"),
2629 args: vec![Term::Var(Arc::clone(&a_name))],
2630 },
2631 );
2632 let call = Term::app("app", vec![Term::var("f"), Term::var("x")]);
2633 let s1 = typecheck_term(&call, &ctx, &theory);
2634 let s2 = typecheck_term(&call, &ctx, &theory);
2635 prop_assert_eq!(s1.is_ok(), s2.is_ok());
2636 if let (Ok(a), Ok(b)) = (&s1, &s2) {
2637 prop_assert!(a.alpha_eq(b));
2638 }
2639 }
2640
2641 #[test]
2642 fn unification_soundness_on_congruent_pairs(
2643 c1 in prop::sample::select(&["a", "b", "c"][..]),
2644 c2 in prop::sample::select(&["a", "b", "c"][..]),
2645 ) {
2646 let lhs = Term::app(
2649 "f",
2650 vec![Term::var("x"), Term::var("y")],
2651 );
2652 let rhs = Term::app(
2653 "f",
2654 vec![Term::constant(c1), Term::constant(c2)],
2655 );
2656 let subst = match unify_all(vec![(lhs.clone(), rhs.clone())]) {
2657 Ok(s) => s,
2658 Err(e) => {
2659 prop_assert!(false, "unify failed: {e}");
2660 return Ok(());
2661 }
2662 };
2663 let l2 = lhs.substitute(&subst);
2664 let r2 = rhs.substitute(&subst);
2665 prop_assert_eq!(l2, r2);
2666 }
2667 }
2668 }
2669}