1use std::collections::VecDeque;
11use std::sync::Arc;
12
13use rustc_hash::{FxHashMap, FxHashSet};
14
15use crate::eq::Term;
16use crate::error::GatError;
17use crate::model::{Model, ModelValue};
18use crate::sort::SortExpr;
19use crate::theory::Theory;
20
/// Bounds on the term enumeration performed when building a free model.
#[derive(Debug, Clone)]
pub struct FreeModelConfig {
    /// Number of application rounds run after the nullary operations (depth 0)
    /// have been collected.
    pub max_depth: usize,
    /// Maximum number of terms kept in a single fiber (output sort expression);
    /// generation fails with an error once it would be exceeded.
    pub max_terms_per_sort: usize,
}

impl Default for FreeModelConfig {
    /// Depth 3, at most 1000 terms per fiber.
    fn default() -> Self {
        let max_depth = 3;
        let max_terms_per_sort = 1000;
        Self {
            max_depth,
            max_terms_per_sort,
        }
    }
}
38
/// Output of `free_model`: the quotient term model plus a flag saying whether
/// term generation is known to have saturated.
#[derive(Debug)]
pub struct FreeModelResult {
    /// Term model whose carriers hold one rendered representative per
    /// equivalence class of generated terms.
    pub model: Model,
    /// Completeness flag as reported by the term generator. When `false`, the
    /// carriers are truncated at the configured depth, and operations applied to
    /// the deepest terms may return strings that are not carrier elements.
    pub is_complete: bool,
}
49
50pub fn free_model(theory: &Theory, config: &FreeModelConfig) -> Result<FreeModelResult, GatError> {
63 let (terms_by_fiber, is_complete) = generate_terms(theory, config)?;
64 let mut terms_by_sort = collapse_fibers(&terms_by_fiber);
69 for sort in &theory.sorts {
70 terms_by_sort.entry(Arc::clone(&sort.name)).or_default();
71 }
72 let (term_to_global, total_terms) = assign_global_indices(&terms_by_sort);
73 let mut uf = quotient_by_equations(theory, &terms_by_sort, &term_to_global, total_terms);
74 let model = build_model(theory, &terms_by_sort, &term_to_global, &mut uf);
75 Ok(FreeModelResult { model, is_complete })
76}
77
78fn collapse_fibers(
91 terms_by_fiber: &FxHashMap<SortExpr, Vec<Term>>,
92) -> FxHashMap<Arc<str>, Vec<Term>> {
93 let mut out: FxHashMap<Arc<str>, Vec<Term>> = FxHashMap::default();
94 for (fiber, terms) in terms_by_fiber {
95 let head = Arc::clone(fiber.head());
96 let bucket = out.entry(head).or_default();
97 for t in terms {
98 if !bucket.contains(t) {
99 bucket.push(t.clone());
100 }
101 }
102 }
103 out
104}
105
106fn topological_sort_sorts(theory: &Theory) -> Result<Vec<Arc<str>>, GatError> {
114 let sort_names: FxHashSet<Arc<str>> =
115 theory.sorts.iter().map(|s| Arc::clone(&s.name)).collect();
116 let mut in_degree: FxHashMap<Arc<str>, usize> = FxHashMap::default();
117 let mut dependents: FxHashMap<Arc<str>, Vec<Arc<str>>> = FxHashMap::default();
118
119 for sort in &theory.sorts {
120 in_degree.entry(Arc::clone(&sort.name)).or_insert(0);
121 for param in &sort.params {
122 let param_head = param.sort.head();
123 if sort_names.contains(param_head) {
124 *in_degree.entry(Arc::clone(&sort.name)).or_insert(0) += 1;
125 dependents
126 .entry(Arc::clone(param_head))
127 .or_default()
128 .push(Arc::clone(&sort.name));
129 }
130 }
131 }
132
133 let mut initial: Vec<Arc<str>> = in_degree
134 .iter()
135 .filter(|(_, deg)| **deg == 0)
136 .map(|(name, _)| Arc::clone(name))
137 .collect();
138 initial.sort(); let mut queue: VecDeque<Arc<str>> = initial.into_iter().collect();
140
141 let mut result = Vec::new();
142 while let Some(name) = queue.pop_front() {
143 result.push(Arc::clone(&name));
144 if let Some(deps) = dependents.get(&name) {
145 for dep in deps {
146 if let Some(deg) = in_degree.get_mut(dep) {
147 *deg = deg.saturating_sub(1);
148 if *deg == 0 {
149 queue.push_back(Arc::clone(dep));
150 }
151 }
152 }
153 }
154 }
155
156 if result.len() < theory.sorts.len() {
158 let cyclic: Vec<String> = theory
159 .sorts
160 .iter()
161 .filter(|s| !result.contains(&s.name))
162 .map(|s| s.name.to_string())
163 .collect();
164 return Err(GatError::CyclicSortDependency(cyclic));
165 }
166
167 Ok(result)
168}
169
170fn generate_terms(
180 theory: &Theory,
181 config: &FreeModelConfig,
182) -> Result<(FxHashMap<SortExpr, Vec<Term>>, bool), GatError> {
183 #![allow(clippy::type_complexity)]
184 let mut terms_by_fiber: FxHashMap<SortExpr, Vec<Term>> = FxHashMap::default();
185
186 let _ = topological_sort_sorts(theory)?;
190
191 for op in &theory.ops {
195 if op.inputs.is_empty() {
196 let term = Term::constant(Arc::clone(&op.name));
197 let fiber = op.output.clone();
198 let bucket = terms_by_fiber.entry(fiber).or_default();
199 if !bucket.contains(&term) {
200 bucket.push(term);
201 }
202 }
203 }
204
205 let mut last_depth_added = false;
206 for _depth in 1..=config.max_depth {
207 let new_terms = generate_depth(theory, &terms_by_fiber);
208
209 let mut added_any = false;
210 for (fiber, new) in new_terms {
211 let bucket = terms_by_fiber.entry(fiber.clone()).or_default();
212 for t in new {
213 if bucket.len() >= config.max_terms_per_sort {
214 let head = fiber.head();
215 return Err(GatError::ModelError(format!(
216 "term count for sort '{head}' exceeds limit {}",
217 config.max_terms_per_sort
218 )));
219 }
220 if !bucket.contains(&t) {
221 bucket.push(t);
222 added_any = true;
223 }
224 }
225 }
226 last_depth_added = added_any;
227 }
228
229 let is_complete = !last_depth_added;
230 Ok((terms_by_fiber, is_complete))
231}
232
233fn generate_depth(
237 theory: &Theory,
238 terms_by_fiber: &FxHashMap<SortExpr, Vec<Term>>,
239) -> FxHashMap<SortExpr, Vec<Term>> {
240 let mut new_terms: FxHashMap<SortExpr, Vec<Term>> = FxHashMap::default();
241
242 for op in &theory.ops {
243 if op.inputs.is_empty() {
244 continue;
245 }
246 let mut chosen: Vec<Term> = Vec::with_capacity(op.inputs.len());
247 let mut theta: FxHashMap<Arc<str>, Term> = FxHashMap::default();
248 extend_op_tuples(
249 op,
250 0,
251 &mut chosen,
252 &mut theta,
253 terms_by_fiber,
254 &mut new_terms,
255 );
256 }
257
258 new_terms
259}
260
261fn extend_op_tuples(
266 op: &crate::op::Operation,
267 slot: usize,
268 chosen: &mut Vec<Term>,
269 theta: &mut FxHashMap<Arc<str>, Term>,
270 terms_by_fiber: &FxHashMap<SortExpr, Vec<Term>>,
271 new_terms: &mut FxHashMap<SortExpr, Vec<Term>>,
272) {
273 if slot == op.inputs.len() {
274 let output_fiber = op.output.subst(theta);
275 let term = Term::app(Arc::clone(&op.name), chosen.clone());
276 new_terms.entry(output_fiber).or_default().push(term);
277 return;
278 }
279 let (param_name, declared_sort, _implicit) = &op.inputs[slot];
280 let expected_fiber = declared_sort.subst(theta);
281 let Some(candidates) = terms_by_fiber.get(&expected_fiber) else {
282 return;
283 };
284 for cand in candidates {
285 chosen.push(cand.clone());
286 theta.insert(Arc::clone(param_name), cand.clone());
287 extend_op_tuples(op, slot + 1, chosen, theta, terms_by_fiber, new_terms);
288 theta.remove(param_name);
289 chosen.pop();
290 }
291}
292
293fn assign_global_indices(
300 terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
301) -> (FxHashMap<Arc<str>, Vec<usize>>, usize) {
302 let mut global_idx = 0usize;
303 let mut term_to_global: FxHashMap<Arc<str>, Vec<usize>> = FxHashMap::default();
304
305 let mut sorted_keys: Vec<&Arc<str>> = terms_by_sort.keys().collect();
306 sorted_keys.sort();
307 for sort in sorted_keys {
308 let terms = &terms_by_sort[sort];
309 let indices: Vec<usize> = (global_idx..global_idx + terms.len()).collect();
310 global_idx += terms.len();
311 term_to_global.insert(Arc::clone(sort), indices);
312 }
313
314 (term_to_global, global_idx)
315}
316
317fn quotient_by_equations(
325 theory: &Theory,
326 terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
327 term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
328 total_terms: usize,
329) -> UnionFind {
330 let mut uf = UnionFind::new(total_terms);
331
332 let eq_info: Vec<_> = theory
334 .eqs
335 .iter()
336 .map(|eq| {
337 let vars: Vec<Arc<str>> = {
338 let mut all = eq.lhs.free_vars();
339 all.extend(eq.rhs.free_vars());
340 all.into_iter().collect()
341 };
342 let var_sorts = crate::typecheck::infer_var_sorts(eq, theory).ok();
343 (eq, vars, var_sorts)
344 })
345 .collect();
346
347 let congruence_entries = build_congruence_index(terms_by_sort, term_to_global);
351
352 loop {
354 let merges_before = uf.merge_count;
355
356 for (eq, vars, var_sorts) in &eq_info {
358 if vars.is_empty() {
359 merge_constant_eq(eq, terms_by_sort, term_to_global, &mut uf);
360 continue;
361 }
362
363 let Some(vs) = var_sorts else {
364 continue;
365 };
366
367 merge_by_equation(eq, vars, vs, terms_by_sort, term_to_global, &mut uf);
368 }
369
370 congruence_closure_pass(&congruence_entries, &mut uf);
373
374 if uf.merge_count == merges_before {
375 break;
376 }
377 }
378
379 uf
380}
381
/// One generated application term `op(args…)`, recorded for congruence closure.
///
/// Entries are grouped by operation name in `build_congruence_index`. Two entries of the
/// same operation whose arguments fall into the same union-find classes are merged by
/// `congruence_closure_pass`.
struct CongruenceEntry {
    /// Global union-find index of the application term itself.
    term_idx: usize,
    /// Global union-find indices of the term's arguments, in argument order.
    arg_indices: Vec<usize>,
}
390
391fn build_congruence_index(
394 terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
395 term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
396) -> FxHashMap<Arc<str>, Vec<CongruenceEntry>> {
397 let mut index: FxHashMap<Arc<str>, Vec<CongruenceEntry>> = FxHashMap::default();
398
399 let mut term_lookup: FxHashMap<&Term, usize> = FxHashMap::default();
401 for (sort, terms) in terms_by_sort {
402 let indices = &term_to_global[sort];
403 for (i, term) in terms.iter().enumerate() {
404 term_lookup.insert(term, indices[i]);
405 }
406 }
407
408 for (sort, terms) in terms_by_sort {
409 let indices = &term_to_global[sort];
410 for (i, term) in terms.iter().enumerate() {
411 if let Term::App { op, args } = term {
412 if args.is_empty() {
413 continue;
414 }
415 let arg_indices: Vec<usize> = args
416 .iter()
417 .filter_map(|arg| term_lookup.get(arg).copied())
418 .collect();
419 if arg_indices.len() == args.len() {
421 index
422 .entry(Arc::clone(op))
423 .or_default()
424 .push(CongruenceEntry {
425 term_idx: indices[i],
426 arg_indices,
427 });
428 }
429 }
430 }
431 }
432
433 index
434}
435
436fn congruence_closure_pass(
439 entries: &FxHashMap<Arc<str>, Vec<CongruenceEntry>>,
440 uf: &mut UnionFind,
441) {
442 for group in entries.values() {
443 if group.len() < 2 {
444 continue;
445 }
446 let mut canonical_groups: FxHashMap<Vec<usize>, usize> = FxHashMap::default();
448 for entry in group {
449 let canonical_args: Vec<usize> =
450 entry.arg_indices.iter().map(|&i| uf.find(i)).collect();
451 if let Some(&representative) = canonical_groups.get(&canonical_args) {
452 uf.union(representative, entry.term_idx);
453 } else {
454 canonical_groups.insert(canonical_args, uf.find(entry.term_idx));
455 }
456 }
457 }
458}
459
460fn is_app_only(term: &Term) -> bool {
470 match term {
471 Term::Var(_) => true,
472 Term::App { args, .. } => args.iter().all(is_app_only),
473 Term::Case { .. } | Term::Hole { .. } | Term::Let { .. } => false,
474 }
475}
476
477fn term_to_string(term: &Term) -> String {
478 match term {
479 Term::Var(name) => name.to_string(),
480 Term::App { op, args } if args.is_empty() => format!("{op}()"),
481 Term::App { op, args } => {
482 let arg_strs: Vec<String> = args.iter().map(term_to_string).collect();
483 format!("{op}({})", arg_strs.join(", "))
484 }
485 Term::Case {
486 scrutinee,
487 branches,
488 } => {
489 let branch_strs: Vec<String> = branches
490 .iter()
491 .map(|b| {
492 let binders = b
493 .binders
494 .iter()
495 .map(ToString::to_string)
496 .collect::<Vec<_>>();
497 format!(
498 "{}({}) => {}",
499 b.constructor,
500 binders.join(", "),
501 term_to_string(&b.body)
502 )
503 })
504 .collect();
505 format!(
506 "case {} of {} end",
507 term_to_string(scrutinee),
508 branch_strs.join(" | ")
509 )
510 }
511 Term::Hole { name } => name
512 .as_ref()
513 .map_or_else(|| "?".to_string(), |n| format!("?{n}")),
514 Term::Let { name, bound, body } => format!(
515 "let {name} = {} in {}",
516 term_to_string(bound),
517 term_to_string(body)
518 ),
519 }
520}
521
522fn build_model(
523 theory: &Theory,
524 terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
525 term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
526 uf: &mut UnionFind,
527) -> Model {
528 let mut model = Model::new(&*theory.name);
529
530 let mut class_rep_string: FxHashMap<usize, String> = FxHashMap::default();
539 let mut string_to_rep: FxHashMap<String, String> = FxHashMap::default();
540 for (sort, terms) in terms_by_sort {
541 let indices = &term_to_global[sort];
542 let mut seen_classes: FxHashSet<usize> = FxHashSet::default();
543
544 for (i, term) in terms.iter().enumerate() {
545 debug_assert!(
546 is_app_only(term),
547 "free-model generator emitted a non-App term: {term:?}",
548 );
549 let rep = uf.find(indices[i]);
550 if seen_classes.insert(rep) {
551 class_rep_string.insert(rep, term_to_string(term));
553 }
554 let rep_str = class_rep_string[&rep].clone();
555 string_to_rep.insert(term_to_string(term), rep_str);
556 }
557 }
558
559 for (sort, terms) in terms_by_sort {
561 let indices = &term_to_global[sort];
562 let mut seen_classes: FxHashSet<usize> = FxHashSet::default();
563 let mut carrier = Vec::new();
564
565 for (i, term) in terms.iter().enumerate() {
566 let rep = uf.find(indices[i]);
567 if seen_classes.insert(rep) {
568 carrier.push(ModelValue::Str(term_to_string(term)));
569 }
570 }
571 model.add_sort(sort.to_string(), carrier);
572 }
573
574 let lookup = Arc::new(string_to_rep);
577
578 for op in &theory.ops {
579 let op_name = op.name.to_string();
580 let arity = op.arity();
581 let table = Arc::clone(&lookup);
582 model.add_op(op_name.clone(), move |args: &[ModelValue]| {
583 if args.len() != arity {
584 return Err(GatError::ModelError(format!(
585 "operation '{op_name}' expects {arity} args, got {}",
586 args.len()
587 )));
588 }
589 let mut arg_strs: Vec<String> = Vec::with_capacity(args.len());
594 for (i, a) in args.iter().enumerate() {
595 match a {
596 ModelValue::Str(s) => arg_strs.push(s.clone()),
597 other => {
598 return Err(GatError::ModelError(format!(
599 "operation '{op_name}' received non-string argument at index {i}: {other:?}"
600 )));
601 }
602 }
603 }
604 let result_str = format!("{op_name}({})", arg_strs.join(", "));
605
606 Ok(ModelValue::Str(
610 table.get(&result_str).map_or(result_str, String::clone),
611 ))
612 });
613 }
614
615 model
616}
617
618fn merge_constant_eq(
620 eq: &crate::eq::Equation,
621 terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
622 term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
623 uf: &mut UnionFind,
624) {
625 let lhs_idx = find_term_index(&eq.lhs, terms_by_sort, term_to_global);
626 let rhs_idx = find_term_index(&eq.rhs, terms_by_sort, term_to_global);
627 if let (Some(l), Some(r)) = (lhs_idx, rhs_idx) {
628 uf.union(l, r);
629 }
630}
631
632fn find_term_index(
634 term: &Term,
635 terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
636 term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
637) -> Option<usize> {
638 for (sort, terms) in terms_by_sort {
639 for (i, t) in terms.iter().enumerate() {
640 if t == term {
641 return Some(term_to_global[sort][i]);
642 }
643 }
644 }
645 None
646}
647
648fn merge_by_equation(
650 eq: &crate::eq::Equation,
651 vars: &[Arc<str>],
652 var_sorts: &FxHashMap<Arc<str>, SortExpr>,
653 terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
654 term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
655 uf: &mut UnionFind,
656) {
657 let var_terms: Vec<(&Arc<str>, &Vec<Term>)> = vars
658 .iter()
659 .filter_map(|v| {
660 let sort = var_sorts.get(v)?;
661 let terms = terms_by_sort.get(sort.head())?;
662 Some((v, terms))
663 })
664 .collect();
665
666 if var_terms.len() != vars.len() || var_terms.iter().any(|(_, terms)| terms.is_empty()) {
667 return;
668 }
669
670 let mut indices = vec![0usize; var_terms.len()];
671
672 loop {
673 let mut subst = rustc_hash::FxHashMap::default();
674 for (i, (var, terms)) in var_terms.iter().enumerate() {
675 subst.insert(Arc::clone(var), terms[indices[i]].clone());
676 }
677
678 let lhs = eq.lhs.substitute(&subst);
679 let rhs = eq.rhs.substitute(&subst);
680
681 let lhs_idx = find_term_index(&lhs, terms_by_sort, term_to_global);
682 let rhs_idx = find_term_index(&rhs, terms_by_sort, term_to_global);
683 if let (Some(l), Some(r)) = (lhs_idx, rhs_idx) {
684 uf.union(l, r);
685 }
686
687 let mut carry = true;
688 for i in (0..indices.len()).rev() {
689 if carry {
690 indices[i] += 1;
691 if indices[i] < var_terms[i].1.len() {
692 carry = false;
693 } else {
694 indices[i] = 0;
695 }
696 }
697 }
698 if carry {
699 break;
700 }
701 }
702}
703
/// Disjoint-set forest over the indices `0..size`, with union by rank and path halving.
struct UnionFind {
    /// `parent[i]` is the next node on the path from `i` towards its class root; a root is
    /// its own parent.
    parent: Vec<usize>,
    /// Rank of each root, used to decide which tree is attached under which.
    rank: Vec<usize>,
    /// Number of `union` calls that actually joined two distinct classes. Callers compare
    /// it before and after a pass to detect a fixpoint.
    merge_count: usize,
}

impl UnionFind {
    /// Creates `size` singleton classes.
    fn new(size: usize) -> Self {
        let parent: Vec<usize> = (0..size).collect();
        let rank = vec![0; size];
        Self {
            parent,
            rank,
            merge_count: 0,
        }
    }

    /// Returns the root of `x`'s class. Each visited node is re-pointed at its grandparent
    /// along the way (path halving).
    fn find(&mut self, x: usize) -> usize {
        let mut node = x;
        loop {
            let up = self.parent[node];
            if up == node {
                return node;
            }
            let skip = self.parent[up];
            self.parent[node] = skip;
            node = skip;
        }
    }

    /// Merges the classes of `x` and `y`; a no-op when they already coincide.
    ///
    /// The lower-rank root is attached under the higher-rank one. On a tie, `x`'s root
    /// becomes the parent and its rank grows by one.
    fn union(&mut self, x: usize, y: usize) {
        let (mut keep, mut attach) = (self.find(x), self.find(y));
        if keep == attach {
            return;
        }
        self.merge_count += 1;
        if self.rank[keep] < self.rank[attach] {
            std::mem::swap(&mut keep, &mut attach);
        }
        self.parent[attach] = keep;
        if self.rank[keep] == self.rank[attach] {
            self.rank[keep] += 1;
        }
    }
}
746
// Unit tests for the free-model generator: carrier sizes for small theories, equation and
// congruence collapsing, dependent sorts, term limits, and cycle rejection.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eq::Equation;
    use crate::op::Operation;
    use crate::sort::Sort;
    use crate::theory::Theory;

    /// One constant, no equations: the carrier has exactly one element.
    #[test]
    fn free_model_of_pointed_set() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "PointedSet",
            vec![Sort::simple("Carrier")],
            vec![Operation::nullary("unit", "Carrier")],
            vec![],
        );
        let result = free_model(&theory, &FreeModelConfig::default())?;
        assert_eq!(result.model.sort_interp["Carrier"].len(), 1);
        Ok(())
    }

    /// A sort with no operations still gets a carrier, and it is empty.
    #[test]
    fn free_model_empty_theory() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new("Empty", vec![Sort::simple("S")], vec![], vec![]);
        let model = free_model(&theory, &FreeModelConfig::default())?.model;
        assert!(model.sort_interp["S"].is_empty());
        Ok(())
    }

    /// Two constants with no equations stay distinct.
    #[test]
    fn free_model_two_constants() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "TwoPoints",
            vec![Sort::simple("S")],
            vec![Operation::nullary("a", "S"), Operation::nullary("b", "S")],
            vec![],
        );
        let model = free_model(&theory, &FreeModelConfig::default())?.model;
        assert_eq!(model.sort_interp["S"].len(), 2);
        Ok(())
    }

    /// The closed equation `a = b` merges the two constants into one class.
    #[test]
    fn free_model_equation_collapses_constants() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "CollapsedPoints",
            vec![Sort::simple("S")],
            vec![Operation::nullary("a", "S"), Operation::nullary("b", "S")],
            vec![Equation::new(
                "a_eq_b",
                Term::constant("a"),
                Term::constant("b"),
            )],
        );
        let model = free_model(&theory, &FreeModelConfig::default())?.model;
        assert_eq!(model.sort_interp["S"].len(), 1);
        Ok(())
    }

    /// At depth 1 the only non-constant term is `mul(unit, unit)`; the unit laws identify it
    /// with `unit`, leaving one class.
    #[test]
    fn free_model_monoid_identity_collapses() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "Monoid",
            vec![Sort::simple("Carrier")],
            vec![
                Operation::new(
                    "mul",
                    vec![
                        ("a".into(), "Carrier".into()),
                        ("b".into(), "Carrier".into()),
                    ],
                    "Carrier",
                ),
                Operation::nullary("unit", "Carrier"),
            ],
            vec![
                Equation::new(
                    "left_id",
                    Term::app("mul", vec![Term::constant("unit"), Term::var("a")]),
                    Term::var("a"),
                ),
                Equation::new(
                    "right_id",
                    Term::app("mul", vec![Term::var("a"), Term::constant("unit")]),
                    Term::var("a"),
                ),
            ],
        );
        let config = FreeModelConfig {
            max_depth: 1,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;
        assert_eq!(model.sort_interp["Carrier"].len(), 1);
        Ok(())
    }

    /// Without constants nothing can be generated, so both carriers are empty.
    #[test]
    fn free_model_graph_theory() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "Graph",
            vec![Sort::simple("Vertex"), Sort::simple("Edge")],
            vec![
                Operation::unary("src", "e", "Edge", "Vertex"),
                Operation::unary("tgt", "e", "Edge", "Vertex"),
            ],
            vec![],
        );
        let model = free_model(&theory, &FreeModelConfig::default())?.model;
        assert!(model.sort_interp["Vertex"].is_empty());
        assert!(model.sort_interp["Edge"].is_empty());
        Ok(())
    }

    /// `zero, succ(zero), …` grows without bound; a small per-fiber limit must raise
    /// `ModelError`.
    #[test]
    fn free_model_term_count_bounded() {
        let theory = Theory::new(
            "Chain",
            vec![Sort::simple("S")],
            vec![
                Operation::nullary("zero", "S"),
                Operation::unary("succ", "x", "S", "S"),
            ],
            vec![],
        );
        let config = FreeModelConfig {
            max_depth: 10,
            max_terms_per_sort: 5,
        };
        let result = free_model(&theory, &config);
        assert!(matches!(result, Err(GatError::ModelError(_))));
    }

    /// A dependent sort `Hom(a, b)` is populated through an operation into it (`id`).
    #[test]
    fn free_model_category_theory() -> Result<(), Box<dyn std::error::Error>> {
        use crate::sort::SortParam;

        let theory = Theory::new(
            "Category",
            vec![
                Sort::simple("Ob"),
                Sort::dependent(
                    "Hom",
                    vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
                ),
            ],
            vec![
                Operation::nullary("star", "Ob"),
                Operation::unary("id", "x", "Ob", "Hom"),
            ],
            Vec::new(),
        );

        let config = FreeModelConfig {
            max_depth: 2,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;

        assert_eq!(model.sort_interp["Ob"].len(), 1);

        assert!(
            !model.sort_interp["Hom"].is_empty(),
            "Hom should have at least the identity morphism"
        );
        Ok(())
    }

    /// A dependent sort that no operation targets ends up with an empty carrier.
    #[test]
    fn free_model_dependent_sort_no_ops() -> Result<(), Box<dyn std::error::Error>> {
        use crate::sort::SortParam;

        let theory = Theory::new(
            "T",
            vec![
                Sort::simple("A"),
                Sort::dependent("B", vec![SortParam::new("x", "A")]),
            ],
            vec![Operation::nullary("a", "A")],
            Vec::new(),
        );

        let model = free_model(&theory, &FreeModelConfig::default())?.model;
        assert_eq!(model.sort_interp["A"].len(), 1);
        assert!(
            model.sort_interp["B"].is_empty(),
            "B has no operations targeting it, so carrier should be empty"
        );
        Ok(())
    }

    /// Declaring a dependent sort before the sort it depends on must still work.
    #[test]
    fn free_model_sort_ordering() -> Result<(), Box<dyn std::error::Error>> {
        use crate::sort::SortParam;

        let theory = Theory::new(
            "T",
            vec![
                Sort::dependent("B", vec![SortParam::new("x", "A")]),
                Sort::simple("A"),
            ],
            vec![
                Operation::nullary("a", "A"),
                Operation::unary("f", "x", "A", "B"),
            ],
            Vec::new(),
        );

        let config = FreeModelConfig {
            max_depth: 1,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;

        assert_eq!(model.sort_interp["A"].len(), 1);
        assert_eq!(model.sort_interp["B"].len(), 1);
        Ok(())
    }

    /// Evaluating an interpreted constant yields a string value.
    #[test]
    fn free_model_operations_work() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "PointedSet",
            vec![Sort::simple("Carrier")],
            vec![Operation::nullary("unit", "Carrier")],
            vec![],
        );
        let model = free_model(&theory, &FreeModelConfig::default())?.model;
        let result = model.eval("unit", &[])?;
        assert!(matches!(result, ModelValue::Str(_)));
        Ok(())
    }

    /// From `a = b`, congruence must also identify `f(a)` with `f(b)`.
    #[test]
    fn free_model_congruence_closure() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "Congruence",
            vec![Sort::simple("S")],
            vec![
                Operation::nullary("a", "S"),
                Operation::nullary("b", "S"),
                Operation::unary("f", "x", "S", "S"),
            ],
            vec![Equation::new(
                "a_eq_b",
                Term::constant("a"),
                Term::constant("b"),
            )],
        );
        let config = FreeModelConfig {
            max_depth: 1,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;
        assert_eq!(
            model.sort_interp["S"].len(),
            2,
            "a ~ b and f(a) ~ f(b) by congruence: expect 2 classes"
        );
        Ok(())
    }

    /// Operations whose output fiber depends on an input (`Hom(x, x)`) land in the
    /// instantiated fiber `Hom(star, star)`.
    #[test]
    fn free_model_dependent_category() -> Result<(), Box<dyn std::error::Error>> {
        use crate::sort::{SortExpr, SortParam};

        let hom_xx = SortExpr::App {
            name: Arc::from("Hom"),
            args: vec![Term::var("x"), Term::var("x")],
        };
        let theory = Theory::new(
            "EndoCategory",
            vec![
                Sort::simple("Ob"),
                Sort::dependent(
                    "Hom",
                    vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
                ),
            ],
            vec![
                Operation::nullary("star", "Ob"),
                Operation::unary("id", "x", "Ob", hom_xx.clone()),
                Operation::unary("f", "x", "Ob", hom_xx),
            ],
            Vec::new(),
        );

        let config = FreeModelConfig {
            max_depth: 2,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;

        assert_eq!(model.sort_interp["Ob"].len(), 1);
        assert_eq!(
            model.sort_interp["Hom"].len(),
            2,
            "expected id(star) and f(star) in Hom fiber"
        );
        Ok(())
    }

    /// `compose` needs arrows in matching fibers `Hom(x, y)` and `Hom(y, z)`. With only
    /// `f, g: Hom(a, b)` available at depth 0, no composite may be produced at depth 1.
    #[test]
    fn free_model_parallel_arrows_no_spurious_composites() -> Result<(), Box<dyn std::error::Error>>
    {
        use crate::sort::{SortExpr, SortParam};

        let hom_ab = SortExpr::App {
            name: Arc::from("Hom"),
            args: vec![Term::constant("a"), Term::constant("b")],
        };
        let hom_xy = SortExpr::App {
            name: Arc::from("Hom"),
            args: vec![Term::var("x"), Term::var("y")],
        };
        let theory = Theory::new(
            "ParallelArrows",
            vec![
                Sort::simple("Ob"),
                Sort::dependent(
                    "Hom",
                    vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
                ),
            ],
            vec![
                Operation::nullary("a", "Ob"),
                Operation::nullary("b", "Ob"),
                Operation::nullary("f", hom_ab.clone()),
                Operation::nullary("g", hom_ab),
                Operation::unary(
                    "id",
                    "x",
                    "Ob",
                    SortExpr::App {
                        name: Arc::from("Hom"),
                        args: vec![Term::var("x"), Term::var("x")],
                    },
                ),
                Operation::new(
                    "compose",
                    vec![
                        (Arc::from("x"), SortExpr::from("Ob")),
                        (Arc::from("y"), SortExpr::from("Ob")),
                        (Arc::from("z"), SortExpr::from("Ob")),
                        (
                            Arc::from("h1"),
                            SortExpr::App {
                                name: Arc::from("Hom"),
                                args: vec![Term::var("x"), Term::var("y")],
                            },
                        ),
                        (
                            Arc::from("h2"),
                            SortExpr::App {
                                name: Arc::from("Hom"),
                                args: vec![Term::var("y"), Term::var("z")],
                            },
                        ),
                    ],
                    hom_xy,
                ),
            ],
            Vec::new(),
        );

        let config = FreeModelConfig {
            max_depth: 1,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;

        assert_eq!(model.sort_interp["Ob"].len(), 2);
        assert_eq!(
            model.sort_interp["Hom"].len(),
            4,
            "Hom fiber should contain {{id(a), id(b), f, g}}, got {:?}",
            model.sort_interp["Hom"],
        );
        Ok(())
    }

    /// Every generated term must typecheck to exactly the fiber it was filed under.
    #[test]
    fn free_model_every_term_well_typed() -> Result<(), Box<dyn std::error::Error>> {
        use crate::sort::{SortExpr, SortParam};
        use crate::typecheck::{VarContext, typecheck_term};

        let hom_xx = SortExpr::App {
            name: Arc::from("Hom"),
            args: vec![Term::var("x"), Term::var("x")],
        };
        let theory = Theory::new(
            "EndoCat",
            vec![
                Sort::simple("Ob"),
                Sort::dependent(
                    "Hom",
                    vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
                ),
            ],
            vec![
                Operation::nullary("star", "Ob"),
                Operation::unary("id", "x", "Ob", hom_xx.clone()),
                Operation::unary("f", "x", "Ob", hom_xx),
            ],
            Vec::new(),
        );

        let config = FreeModelConfig {
            max_depth: 2,
            max_terms_per_sort: 100,
        };
        let (fibers, _) = generate_terms(&theory, &config)?;
        let ctx = VarContext::default();
        for (fiber, terms) in &fibers {
            for term in terms {
                let inferred = typecheck_term(term, &ctx, &theory)?;
                assert!(
                    inferred.alpha_eq(fiber),
                    "term {term} has fiber {fiber} but typecheck inferred {inferred}",
                );
            }
        }
        Ok(())
    }

    /// Simple (non-dependent) sorts behave as before: constants populate `Vertex`, and
    /// `Edge`, with no constructors, stays empty.
    #[test]
    fn free_model_simple_sorts_backward_compat() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "Graph",
            vec![Sort::simple("Vertex"), Sort::simple("Edge")],
            vec![
                Operation::nullary("v0", "Vertex"),
                Operation::nullary("v1", "Vertex"),
                Operation::unary("src", "e", "Edge", "Vertex"),
                Operation::unary("tgt", "e", "Edge", "Vertex"),
            ],
            Vec::new(),
        );
        let config = FreeModelConfig {
            max_depth: 1,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;
        assert_eq!(model.sort_interp["Vertex"].len(), 2);
        assert!(model.sort_interp["Edge"].is_empty());
        Ok(())
    }

    /// Mutually dependent sorts (`A` over `B`, `B` over `A`) must be rejected.
    #[test]
    fn free_model_cyclic_sort_dependency_rejected() {
        use crate::sort::SortParam;

        let theory = Theory::new(
            "Cyclic",
            vec![
                Sort::dependent("A", vec![SortParam::new("x", "B")]),
                Sort::dependent("B", vec![SortParam::new("y", "A")]),
            ],
            vec![],
            vec![],
        );
        let result = free_model(&theory, &FreeModelConfig::default());
        assert!(
            matches!(result, Err(GatError::CyclicSortDependency(_))),
            "cyclic sort dependencies should be rejected"
        );
    }
}
1261}