Skip to main content

panproto_gat/
free_model.rs

1//! Bounded approximation of the free (initial) model.
2//!
3//! Generates an approximation of the initial model of a theory by
4//! enumerating closed terms up to [`FreeModelConfig::max_depth`] and
5//! quotienting by the theory's equations. The result is exact (truly
6//! initial) when no new terms appear at the final depth level; otherwise
7//! it is a finite truncation. Check [`FreeModelResult::is_complete`] to
8//! determine whether the bound was sufficient.
9
10use std::collections::VecDeque;
11use std::sync::Arc;
12
13use rustc_hash::{FxHashMap, FxHashSet};
14
15use crate::eq::Term;
16use crate::error::GatError;
17use crate::model::{Model, ModelValue};
18use crate::sort::SortExpr;
19use crate::theory::Theory;
20
/// Configuration for free model construction.
///
/// Both fields bound the size of the enumeration performed by
/// [`free_model`]; check [`FreeModelResult::is_complete`] to learn whether
/// the depth bound was large enough.
#[derive(Debug, Clone)]
pub struct FreeModelConfig {
    /// Maximum depth of term generation, i.e. the number of enumeration
    /// rounds run after the nullary seeds. Default: 3.
    pub max_depth: usize,
    /// Safety bound on the number of terms per fiber (instantiated sort
    /// expression). For a non-dependent sort the fiber is the sort itself,
    /// so this is the per-sort bound. Default: 1000.
    ///
    /// The bound is checked when a term produced by an enumeration round
    /// is added to its fiber. Nullary seeds are not checked when inserted,
    /// but they do count toward the bound.
    pub max_terms_per_sort: usize,
}

impl Default for FreeModelConfig {
    fn default() -> Self {
        Self {
            max_depth: 3,
            max_terms_per_sort: 1000,
        }
    }
}
38
39/// Result of free model construction, including completeness status.
40#[derive(Debug)]
41pub struct FreeModelResult {
42    /// The constructed model.
43    pub model: Model,
44    /// Whether the model is provably complete (initial). `true` when no
45    /// new closed terms were generated at the final depth level, meaning
46    /// increasing `max_depth` would not change the model.
47    pub is_complete: bool,
48}
49
50/// Construct a bounded approximation of the free (initial) model.
51///
52/// The carrier set of each sort is the set of closed terms of that sort
53/// (up to `max_depth`), quotiented by the theory's equations using
54/// union-find. Operations are defined by term application.
55///
56/// When [`FreeModelResult::is_complete`] is `true`, the result is the
57/// exact initial model (no new terms would appear at deeper levels).
58///
59/// # Errors
60///
61/// Returns [`GatError::ModelError`] if the term count exceeds bounds.
62pub fn free_model(theory: &Theory, config: &FreeModelConfig) -> Result<FreeModelResult, GatError> {
63    let (terms_by_fiber, is_complete) = generate_terms(theory, config)?;
64    // Collapse fiber-indexed terms to head-indexed terms for the
65    // downstream model interface, which exposes carriers by sort head
66    // name. Seed empty entries for every declared sort so callers can
67    // always look up a carrier by name.
68    let mut terms_by_sort = collapse_fibers(&terms_by_fiber);
69    for sort in &theory.sorts {
70        terms_by_sort.entry(Arc::clone(&sort.name)).or_default();
71    }
72    let (term_to_global, total_terms) = assign_global_indices(&terms_by_sort);
73    let mut uf = quotient_by_equations(theory, &terms_by_sort, &term_to_global, total_terms);
74    let model = build_model(theory, &terms_by_sort, &term_to_global, &mut uf);
75    Ok(FreeModelResult { model, is_complete })
76}
77
78/// Collapse a fiber-indexed term map down to a head-indexed term map.
79/// All terms with the same head sort are unioned into a single carrier.
80///
81/// Soundness of the collapse under the downstream quotient: GAT
82/// equations are sort-preserving (both sides of every equation have
83/// the same output sort, enforced by `typecheck_equation`), and
84/// congruence closure over a set of sort-preserving equations only
85/// ever relates terms that are already in the same fiber. The
86/// head-indexed carrier therefore exposes the fibered free model's
87/// underlying set without identifying terms across fibers; consumers
88/// that need fiber information should recover it from each term's
89/// inferred output sort via `typecheck_term`.
90fn collapse_fibers(
91    terms_by_fiber: &FxHashMap<SortExpr, Vec<Term>>,
92) -> FxHashMap<Arc<str>, Vec<Term>> {
93    let mut out: FxHashMap<Arc<str>, Vec<Term>> = FxHashMap::default();
94    for (fiber, terms) in terms_by_fiber {
95        let head = Arc::clone(fiber.head());
96        let bucket = out.entry(head).or_default();
97        for t in terms {
98            if !bucket.contains(t) {
99                bucket.push(t.clone());
100            }
101        }
102    }
103    out
104}
105
106/// Topologically sort the theory's sorts so that parameter sorts are
107/// ordered before the dependent sorts that reference them. Returns sort
108/// names in dependency order.
109///
110/// # Errors
111///
112/// Returns [`GatError::CyclicSortDependency`] if cyclic dependencies exist.
113fn topological_sort_sorts(theory: &Theory) -> Result<Vec<Arc<str>>, GatError> {
114    let sort_names: FxHashSet<Arc<str>> =
115        theory.sorts.iter().map(|s| Arc::clone(&s.name)).collect();
116    let mut in_degree: FxHashMap<Arc<str>, usize> = FxHashMap::default();
117    let mut dependents: FxHashMap<Arc<str>, Vec<Arc<str>>> = FxHashMap::default();
118
119    for sort in &theory.sorts {
120        in_degree.entry(Arc::clone(&sort.name)).or_insert(0);
121        for param in &sort.params {
122            let param_head = param.sort.head();
123            if sort_names.contains(param_head) {
124                *in_degree.entry(Arc::clone(&sort.name)).or_insert(0) += 1;
125                dependents
126                    .entry(Arc::clone(param_head))
127                    .or_default()
128                    .push(Arc::clone(&sort.name));
129            }
130        }
131    }
132
133    let mut initial: Vec<Arc<str>> = in_degree
134        .iter()
135        .filter(|(_, deg)| **deg == 0)
136        .map(|(name, _)| Arc::clone(name))
137        .collect();
138    initial.sort(); // Deterministic ordering.
139    let mut queue: VecDeque<Arc<str>> = initial.into_iter().collect();
140
141    let mut result = Vec::new();
142    while let Some(name) = queue.pop_front() {
143        result.push(Arc::clone(&name));
144        if let Some(deps) = dependents.get(&name) {
145            for dep in deps {
146                if let Some(deg) = in_degree.get_mut(dep) {
147                    *deg = deg.saturating_sub(1);
148                    if *deg == 0 {
149                        queue.push_back(Arc::clone(dep));
150                    }
151                }
152            }
153        }
154    }
155
156    // Reject cyclic sort dependencies instead of silently appending.
157    if result.len() < theory.sorts.len() {
158        let cyclic: Vec<String> = theory
159            .sorts
160            .iter()
161            .filter(|s| !result.contains(&s.name))
162            .map(|s| s.name.to_string())
163            .collect();
164        return Err(GatError::CyclicSortDependency(cyclic));
165    }
166
167    Ok(result)
168}
169
170/// Phase 1: Generate all closed terms up to `max_depth`, indexed by sort.
171///
172/// For dependent sorts `S(x1: A1, ..., xn: An)`, terms are generated
173/// fiber-by-fiber: for each tuple of parameter values drawn from the
174/// carrier sets of A1...An, we find operations whose output sort is S
175/// and whose parameter inputs match the fiber. All fiber terms are
176/// collected under the base sort name S.
177/// Returns `(terms_by_sort, is_complete)` where `is_complete` is `true`
178/// when no new terms were generated at the final depth level.
179fn generate_terms(
180    theory: &Theory,
181    config: &FreeModelConfig,
182) -> Result<(FxHashMap<SortExpr, Vec<Term>>, bool), GatError> {
183    #![allow(clippy::type_complexity)]
184    let mut terms_by_fiber: FxHashMap<SortExpr, Vec<Term>> = FxHashMap::default();
185
186    // Run topological sort for its cycle check; the head ordering is no
187    // longer consulted directly because we file terms under their
188    // instantiated output sort.
189    let _ = topological_sort_sorts(theory)?;
190
191    // Seed: nullary operations. A nullary op's output sort cannot
192    // reference any input (there are none), so `op.output` is already a
193    // closed sort expression.
194    for op in &theory.ops {
195        if op.inputs.is_empty() {
196            let term = Term::constant(Arc::clone(&op.name));
197            let fiber = op.output.clone();
198            let bucket = terms_by_fiber.entry(fiber).or_default();
199            if !bucket.contains(&term) {
200                bucket.push(term);
201            }
202        }
203    }
204
205    let mut last_depth_added = false;
206    for _depth in 1..=config.max_depth {
207        let new_terms = generate_depth(theory, &terms_by_fiber);
208
209        let mut added_any = false;
210        for (fiber, new) in new_terms {
211            let bucket = terms_by_fiber.entry(fiber.clone()).or_default();
212            for t in new {
213                if bucket.len() >= config.max_terms_per_sort {
214                    let head = fiber.head();
215                    return Err(GatError::ModelError(format!(
216                        "term count for sort '{head}' exceeds limit {}",
217                        config.max_terms_per_sort
218                    )));
219                }
220                if !bucket.contains(&t) {
221                    bucket.push(t);
222                    added_any = true;
223                }
224            }
225        }
226        last_depth_added = added_any;
227    }
228
229    let is_complete = !last_depth_added;
230    Ok((terms_by_fiber, is_complete))
231}
232
233/// Generate one depth level of terms by applying non-nullary ops to
234/// existing terms, matching argument fibers against the declared input
235/// sort expressions under a running substitution.
236fn generate_depth(
237    theory: &Theory,
238    terms_by_fiber: &FxHashMap<SortExpr, Vec<Term>>,
239) -> FxHashMap<SortExpr, Vec<Term>> {
240    let mut new_terms: FxHashMap<SortExpr, Vec<Term>> = FxHashMap::default();
241
242    for op in &theory.ops {
243        if op.inputs.is_empty() {
244            continue;
245        }
246        let mut chosen: Vec<Term> = Vec::with_capacity(op.inputs.len());
247        let mut theta: FxHashMap<Arc<str>, Term> = FxHashMap::default();
248        extend_op_tuples(
249            op,
250            0,
251            &mut chosen,
252            &mut theta,
253            terms_by_fiber,
254            &mut new_terms,
255        );
256    }
257
258    new_terms
259}
260
261/// Recursive helper: at slot `i`, try every candidate term whose fiber
262/// matches `op.inputs[i].1.subst(&theta)` and extend θ with the chosen
263/// term, recursing into slot `i + 1`. When `i == op.inputs.len()`,
264/// materialise the application and file it under `op.output.subst(&θ)`.
265fn extend_op_tuples(
266    op: &crate::op::Operation,
267    slot: usize,
268    chosen: &mut Vec<Term>,
269    theta: &mut FxHashMap<Arc<str>, Term>,
270    terms_by_fiber: &FxHashMap<SortExpr, Vec<Term>>,
271    new_terms: &mut FxHashMap<SortExpr, Vec<Term>>,
272) {
273    if slot == op.inputs.len() {
274        let output_fiber = op.output.subst(theta);
275        let term = Term::app(Arc::clone(&op.name), chosen.clone());
276        new_terms.entry(output_fiber).or_default().push(term);
277        return;
278    }
279    let (param_name, declared_sort, _implicit) = &op.inputs[slot];
280    let expected_fiber = declared_sort.subst(theta);
281    let Some(candidates) = terms_by_fiber.get(&expected_fiber) else {
282        return;
283    };
284    for cand in candidates {
285        chosen.push(cand.clone());
286        theta.insert(Arc::clone(param_name), cand.clone());
287        extend_op_tuples(op, slot + 1, chosen, theta, terms_by_fiber, new_terms);
288        theta.remove(param_name);
289        chosen.pop();
290    }
291}
292
293/// Assign consecutive global indices to all generated terms.
294///
295/// Iterates sorts in sort-name order so that the resulting indices are
296/// deterministic across runs, regardless of hash-table insertion order
297/// upstream. Any downstream consumer that hashes or compares free-model
298/// indices (the VCS layer in particular) depends on this determinism.
299fn assign_global_indices(
300    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
301) -> (FxHashMap<Arc<str>, Vec<usize>>, usize) {
302    let mut global_idx = 0usize;
303    let mut term_to_global: FxHashMap<Arc<str>, Vec<usize>> = FxHashMap::default();
304
305    let mut sorted_keys: Vec<&Arc<str>> = terms_by_sort.keys().collect();
306    sorted_keys.sort();
307    for sort in sorted_keys {
308        let terms = &terms_by_sort[sort];
309        let indices: Vec<usize> = (global_idx..global_idx + terms.len()).collect();
310        global_idx += terms.len();
311        term_to_global.insert(Arc::clone(sort), indices);
312    }
313
314    (term_to_global, global_idx)
315}
316
317/// Phase 2: Quotient terms by equations using union-find with congruence closure.
318///
319/// Runs equation merging and congruence propagation in a fixpoint loop.
320/// Congruence closure ensures that if `t1 ~ t2`, then for every operation
321/// `f`, we also get `f(... t1 ...) ~ f(... t2 ...)` when both terms exist
322/// in the generated set. This is necessary for the free model to be truly
323/// initial (the quotient must be closed under all operation congruences).
324fn quotient_by_equations(
325    theory: &Theory,
326    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
327    term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
328    total_terms: usize,
329) -> UnionFind {
330    let mut uf = UnionFind::new(total_terms);
331
332    // Precompute variable sorts for each equation.
333    let eq_info: Vec<_> = theory
334        .eqs
335        .iter()
336        .map(|eq| {
337            let vars: Vec<Arc<str>> = {
338                let mut all = eq.lhs.free_vars();
339                all.extend(eq.rhs.free_vars());
340                all.into_iter().collect()
341            };
342            let var_sorts = crate::typecheck::infer_var_sorts(eq, theory).ok();
343            (eq, vars, var_sorts)
344        })
345        .collect();
346
347    // Build a congruence index: for each compound term f(a1, ..., an),
348    // record (op_name, [global_idx_of_a1, ..., global_idx_of_an]) -> global_idx.
349    // This allows efficient congruence closure propagation.
350    let congruence_entries = build_congruence_index(terms_by_sort, term_to_global);
351
352    // Fixpoint loop: keep merging until no new merges occur.
353    loop {
354        let merges_before = uf.merge_count;
355
356        // Pass 1: equation substitution instances.
357        for (eq, vars, var_sorts) in &eq_info {
358            if vars.is_empty() {
359                merge_constant_eq(eq, terms_by_sort, term_to_global, &mut uf);
360                continue;
361            }
362
363            let Some(vs) = var_sorts else {
364                continue;
365            };
366
367            merge_by_equation(eq, vars, vs, terms_by_sort, term_to_global, &mut uf);
368        }
369
370        // Pass 2: congruence closure. If t1 ~ t2, then f(..., t1, ...) ~ f(..., t2, ...)
371        // for all operations f where both compound terms exist.
372        congruence_closure_pass(&congruence_entries, &mut uf);
373
374        if uf.merge_count == merges_before {
375            break;
376        }
377    }
378
379    uf
380}
381
/// Entry in the congruence index: one compound term, identified by its
/// global index, together with the global indices of its arguments.
///
/// The operation name is not stored here. It is the key under which the
/// entry is grouped in the congruence index.
#[derive(Debug)]
struct CongruenceEntry {
    /// Global index of this term.
    term_idx: usize,
    /// Global indices of each subterm (argument), in argument order.
    arg_indices: Vec<usize>,
}
390
391/// Build an index of all compound (non-nullary) generated terms, grouped by
392/// operation name and arity. This enables efficient congruence closure.
393fn build_congruence_index(
394    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
395    term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
396) -> FxHashMap<Arc<str>, Vec<CongruenceEntry>> {
397    let mut index: FxHashMap<Arc<str>, Vec<CongruenceEntry>> = FxHashMap::default();
398
399    // Build a flat lookup: term -> global_idx for subterm resolution.
400    let mut term_lookup: FxHashMap<&Term, usize> = FxHashMap::default();
401    for (sort, terms) in terms_by_sort {
402        let indices = &term_to_global[sort];
403        for (i, term) in terms.iter().enumerate() {
404            term_lookup.insert(term, indices[i]);
405        }
406    }
407
408    for (sort, terms) in terms_by_sort {
409        let indices = &term_to_global[sort];
410        for (i, term) in terms.iter().enumerate() {
411            if let Term::App { op, args } = term {
412                if args.is_empty() {
413                    continue;
414                }
415                let arg_indices: Vec<usize> = args
416                    .iter()
417                    .filter_map(|arg| term_lookup.get(arg).copied())
418                    .collect();
419                // Only include if all subterms were found.
420                if arg_indices.len() == args.len() {
421                    index
422                        .entry(Arc::clone(op))
423                        .or_default()
424                        .push(CongruenceEntry {
425                            term_idx: indices[i],
426                            arg_indices,
427                        });
428                }
429            }
430        }
431    }
432
433    index
434}
435
436/// Propagate congruence: for terms sharing the same operation, if their
437/// argument tuples are pointwise equivalent under the union-find, merge them.
438fn congruence_closure_pass(
439    entries: &FxHashMap<Arc<str>, Vec<CongruenceEntry>>,
440    uf: &mut UnionFind,
441) {
442    for group in entries.values() {
443        if group.len() < 2 {
444            continue;
445        }
446        // Group entries by their canonical argument tuple.
447        let mut canonical_groups: FxHashMap<Vec<usize>, usize> = FxHashMap::default();
448        for entry in group {
449            let canonical_args: Vec<usize> =
450                entry.arg_indices.iter().map(|&i| uf.find(i)).collect();
451            if let Some(&representative) = canonical_groups.get(&canonical_args) {
452                uf.union(representative, entry.term_idx);
453            } else {
454                canonical_groups.insert(canonical_args, uf.find(entry.term_idx));
455            }
456        }
457    }
458}
459
460/// Phase 3: Build the Model from equivalence class representatives.
461/// Format a term as a human-readable string (e.g., `mul(unit(), x)`).
462///
463/// This must be used consistently for both carrier set values and
464/// operation results to ensure that `check_model` can match them.
465/// Check whether a term is built entirely from `Var` and `App` nodes.
466/// The free-model generator only produces App-only terms, so this
467/// invariant holds on every term that reaches `build_model`'s
468/// stringification path and makes `term_to_string` injective.
469fn is_app_only(term: &Term) -> bool {
470    match term {
471        Term::Var(_) => true,
472        Term::App { args, .. } => args.iter().all(is_app_only),
473        Term::Case { .. } | Term::Hole { .. } | Term::Let { .. } => false,
474    }
475}
476
477fn term_to_string(term: &Term) -> String {
478    match term {
479        Term::Var(name) => name.to_string(),
480        Term::App { op, args } if args.is_empty() => format!("{op}()"),
481        Term::App { op, args } => {
482            let arg_strs: Vec<String> = args.iter().map(term_to_string).collect();
483            format!("{op}({})", arg_strs.join(", "))
484        }
485        Term::Case {
486            scrutinee,
487            branches,
488        } => {
489            let branch_strs: Vec<String> = branches
490                .iter()
491                .map(|b| {
492                    let binders = b
493                        .binders
494                        .iter()
495                        .map(ToString::to_string)
496                        .collect::<Vec<_>>();
497                    format!(
498                        "{}({}) => {}",
499                        b.constructor,
500                        binders.join(", "),
501                        term_to_string(&b.body)
502                    )
503                })
504                .collect();
505            format!(
506                "case {} of {} end",
507                term_to_string(scrutinee),
508                branch_strs.join(" | ")
509            )
510        }
511        Term::Hole { name } => name
512            .as_ref()
513            .map_or_else(|| "?".to_string(), |n| format!("?{n}")),
514        Term::Let { name, bound, body } => format!(
515            "let {name} = {} in {}",
516            term_to_string(bound),
517            term_to_string(body)
518        ),
519    }
520}
521
522fn build_model(
523    theory: &Theory,
524    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
525    term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
526    uf: &mut UnionFind,
527) -> Model {
528    let mut model = Model::new(&*theory.name);
529
530    // String-keyed representative lookup. Safe because the free-model
531    // generator emits only `Term::App` nodes via `extend_op_tuples`;
532    // `term_to_string` is injective on App-only terms with App-only
533    // arguments, so stringification does not collide across terms.
534    // The debug assertion guards the invariant: every term seen here
535    // must be App-only (holes, case terms, and let bindings are
536    // produced by user input to the typechecker, never by the free
537    // model enumerator).
538    let mut class_rep_string: FxHashMap<usize, String> = FxHashMap::default();
539    let mut string_to_rep: FxHashMap<String, String> = FxHashMap::default();
540    for (sort, terms) in terms_by_sort {
541        let indices = &term_to_global[sort];
542        let mut seen_classes: FxHashSet<usize> = FxHashSet::default();
543
544        for (i, term) in terms.iter().enumerate() {
545            debug_assert!(
546                is_app_only(term),
547                "free-model generator emitted a non-App term: {term:?}",
548            );
549            let rep = uf.find(indices[i]);
550            if seen_classes.insert(rep) {
551                // First term in this class becomes the representative string.
552                class_rep_string.insert(rep, term_to_string(term));
553            }
554            let rep_str = class_rep_string[&rep].clone();
555            string_to_rep.insert(term_to_string(term), rep_str);
556        }
557    }
558
559    // Build carrier sets using class representatives.
560    for (sort, terms) in terms_by_sort {
561        let indices = &term_to_global[sort];
562        let mut seen_classes: FxHashSet<usize> = FxHashSet::default();
563        let mut carrier = Vec::new();
564
565        for (i, term) in terms.iter().enumerate() {
566            let rep = uf.find(indices[i]);
567            if seen_classes.insert(rep) {
568                carrier.push(ModelValue::Str(term_to_string(term)));
569            }
570        }
571        model.add_sort(sort.to_string(), carrier);
572    }
573
574    // Build operation interpretations that map carrier → carrier.
575    // The lookup table is shared via Arc for the closures.
576    let lookup = Arc::new(string_to_rep);
577
578    for op in &theory.ops {
579        let op_name = op.name.to_string();
580        let arity = op.arity();
581        let table = Arc::clone(&lookup);
582        model.add_op(op_name.clone(), move |args: &[ModelValue]| {
583            if args.len() != arity {
584                return Err(GatError::ModelError(format!(
585                    "operation '{op_name}' expects {arity} args, got {}",
586                    args.len()
587                )));
588            }
589            // Carrier values are always ModelValue::Str here because
590            // free_model emits string carriers via term_to_string. A
591            // non-string argument indicates a caller bug; surface it
592            // rather than silently rendering as "?".
593            let mut arg_strs: Vec<String> = Vec::with_capacity(args.len());
594            for (i, a) in args.iter().enumerate() {
595                match a {
596                    ModelValue::Str(s) => arg_strs.push(s.clone()),
597                    other => {
598                        return Err(GatError::ModelError(format!(
599                            "operation '{op_name}' received non-string argument at index {i}: {other:?}"
600                        )));
601                    }
602                }
603            }
604            let result_str = format!("{op_name}({})", arg_strs.join(", "));
605
606            // Look up the result in the term table. If found, return the
607            // equivalence class representative. If not found (term exceeds
608            // depth bound), return the formatted string as-is.
609            Ok(ModelValue::Str(
610                table.get(&result_str).map_or(result_str, String::clone),
611            ))
612        });
613    }
614
615    model
616}
617
618/// Merge terms identified by a constants-only equation.
619fn merge_constant_eq(
620    eq: &crate::eq::Equation,
621    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
622    term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
623    uf: &mut UnionFind,
624) {
625    let lhs_idx = find_term_index(&eq.lhs, terms_by_sort, term_to_global);
626    let rhs_idx = find_term_index(&eq.rhs, terms_by_sort, term_to_global);
627    if let (Some(l), Some(r)) = (lhs_idx, rhs_idx) {
628        uf.union(l, r);
629    }
630}
631
632/// Find the global index of a closed term in the generated term set.
633fn find_term_index(
634    term: &Term,
635    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
636    term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
637) -> Option<usize> {
638    for (sort, terms) in terms_by_sort {
639        for (i, t) in terms.iter().enumerate() {
640            if t == term {
641                return Some(term_to_global[sort][i]);
642            }
643        }
644    }
645    None
646}
647
648/// Enumerate substitutions and merge LHS/RHS when both match generated terms.
649fn merge_by_equation(
650    eq: &crate::eq::Equation,
651    vars: &[Arc<str>],
652    var_sorts: &FxHashMap<Arc<str>, SortExpr>,
653    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
654    term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
655    uf: &mut UnionFind,
656) {
657    let var_terms: Vec<(&Arc<str>, &Vec<Term>)> = vars
658        .iter()
659        .filter_map(|v| {
660            let sort = var_sorts.get(v)?;
661            let terms = terms_by_sort.get(sort.head())?;
662            Some((v, terms))
663        })
664        .collect();
665
666    if var_terms.len() != vars.len() || var_terms.iter().any(|(_, terms)| terms.is_empty()) {
667        return;
668    }
669
670    let mut indices = vec![0usize; var_terms.len()];
671
672    loop {
673        let mut subst = rustc_hash::FxHashMap::default();
674        for (i, (var, terms)) in var_terms.iter().enumerate() {
675            subst.insert(Arc::clone(var), terms[indices[i]].clone());
676        }
677
678        let lhs = eq.lhs.substitute(&subst);
679        let rhs = eq.rhs.substitute(&subst);
680
681        let lhs_idx = find_term_index(&lhs, terms_by_sort, term_to_global);
682        let rhs_idx = find_term_index(&rhs, terms_by_sort, term_to_global);
683        if let (Some(l), Some(r)) = (lhs_idx, rhs_idx) {
684            uf.union(l, r);
685        }
686
687        let mut carry = true;
688        for i in (0..indices.len()).rev() {
689            if carry {
690                indices[i] += 1;
691                if indices[i] < var_terms[i].1.len() {
692                    carry = false;
693                } else {
694                    indices[i] = 0;
695                }
696            }
697        }
698        if carry {
699            break;
700        }
701    }
702}
703
/// Disjoint-set forest over the indices `0..size`, using union by rank and
/// path halving. Each `find` step repoints the visited node at its
/// grandparent.
struct UnionFind {
    parent: Vec<usize>,
    rank: Vec<usize>,
    /// Number of `union` calls that actually joined two distinct classes.
    merge_count: usize,
}

impl UnionFind {
    /// Create `size` singleton classes.
    fn new(size: usize) -> Self {
        let parent: Vec<usize> = (0..size).collect();
        Self {
            parent,
            rank: vec![0; size],
            merge_count: 0,
        }
    }

    /// Return the root of `x`'s class, halving the path on the way up.
    fn find(&mut self, x: usize) -> usize {
        let mut node = x;
        loop {
            let up = self.parent[node];
            if up == node {
                return node;
            }
            let grandparent = self.parent[up];
            self.parent[node] = grandparent;
            node = grandparent;
        }
    }

    /// Join the classes of `x` and `y`. The root of lower rank is hung under
    /// the other. On a tie, `x`'s root wins and its rank grows by one.
    fn union(&mut self, x: usize, y: usize) {
        let (rx, ry) = (self.find(x), self.find(y));
        if rx == ry {
            return;
        }
        self.merge_count += 1;

        let (root, child) = if self.rank[rx] < self.rank[ry] {
            (ry, rx)
        } else {
            (rx, ry)
        };
        self.parent[child] = root;
        if self.rank[rx] == self.rank[ry] {
            self.rank[root] += 1;
        }
    }
}
746
747#[cfg(test)]
748mod tests {
749    use super::*;
750    use crate::eq::Equation;
751    use crate::op::Operation;
752    use crate::sort::Sort;
753    use crate::theory::Theory;
754
755    #[test]
756    fn free_model_of_pointed_set() -> Result<(), Box<dyn std::error::Error>> {
757        let theory = Theory::new(
758            "PointedSet",
759            vec![Sort::simple("Carrier")],
760            vec![Operation::nullary("unit", "Carrier")],
761            vec![],
762        );
763        let result = free_model(&theory, &FreeModelConfig::default())?;
764        assert_eq!(result.model.sort_interp["Carrier"].len(), 1);
765        Ok(())
766    }
767
768    #[test]
769    fn free_model_empty_theory() -> Result<(), Box<dyn std::error::Error>> {
770        let theory = Theory::new("Empty", vec![Sort::simple("S")], vec![], vec![]);
771        let model = free_model(&theory, &FreeModelConfig::default())?.model;
772        assert!(model.sort_interp["S"].is_empty());
773        Ok(())
774    }
775
776    #[test]
777    fn free_model_two_constants() -> Result<(), Box<dyn std::error::Error>> {
778        let theory = Theory::new(
779            "TwoPoints",
780            vec![Sort::simple("S")],
781            vec![Operation::nullary("a", "S"), Operation::nullary("b", "S")],
782            vec![],
783        );
784        let model = free_model(&theory, &FreeModelConfig::default())?.model;
785        assert_eq!(model.sort_interp["S"].len(), 2);
786        Ok(())
787    }
788
789    #[test]
790    fn free_model_equation_collapses_constants() -> Result<(), Box<dyn std::error::Error>> {
791        let theory = Theory::new(
792            "CollapsedPoints",
793            vec![Sort::simple("S")],
794            vec![Operation::nullary("a", "S"), Operation::nullary("b", "S")],
795            vec![Equation::new(
796                "a_eq_b",
797                Term::constant("a"),
798                Term::constant("b"),
799            )],
800        );
801        let model = free_model(&theory, &FreeModelConfig::default())?.model;
802        assert_eq!(model.sort_interp["S"].len(), 1);
803        Ok(())
804    }
805
806    #[test]
807    fn free_model_monoid_identity_collapses() -> Result<(), Box<dyn std::error::Error>> {
808        let theory = Theory::new(
809            "Monoid",
810            vec![Sort::simple("Carrier")],
811            vec![
812                Operation::new(
813                    "mul",
814                    vec![
815                        ("a".into(), "Carrier".into()),
816                        ("b".into(), "Carrier".into()),
817                    ],
818                    "Carrier",
819                ),
820                Operation::nullary("unit", "Carrier"),
821            ],
822            vec![
823                Equation::new(
824                    "left_id",
825                    Term::app("mul", vec![Term::constant("unit"), Term::var("a")]),
826                    Term::var("a"),
827                ),
828                Equation::new(
829                    "right_id",
830                    Term::app("mul", vec![Term::var("a"), Term::constant("unit")]),
831                    Term::var("a"),
832                ),
833            ],
834        );
835        let config = FreeModelConfig {
836            max_depth: 1,
837            max_terms_per_sort: 100,
838        };
839        let model = free_model(&theory, &config)?.model;
840        assert_eq!(model.sort_interp["Carrier"].len(), 1);
841        Ok(())
842    }
843
844    #[test]
845    fn free_model_graph_theory() -> Result<(), Box<dyn std::error::Error>> {
846        let theory = Theory::new(
847            "Graph",
848            vec![Sort::simple("Vertex"), Sort::simple("Edge")],
849            vec![
850                Operation::unary("src", "e", "Edge", "Vertex"),
851                Operation::unary("tgt", "e", "Edge", "Vertex"),
852            ],
853            vec![],
854        );
855        let model = free_model(&theory, &FreeModelConfig::default())?.model;
856        assert!(model.sort_interp["Vertex"].is_empty());
857        assert!(model.sort_interp["Edge"].is_empty());
858        Ok(())
859    }
860
861    #[test]
862    fn free_model_term_count_bounded() {
863        let theory = Theory::new(
864            "Chain",
865            vec![Sort::simple("S")],
866            vec![
867                Operation::nullary("zero", "S"),
868                Operation::unary("succ", "x", "S", "S"),
869            ],
870            vec![],
871        );
872        let config = FreeModelConfig {
873            max_depth: 10,
874            max_terms_per_sort: 5,
875        };
876        let result = free_model(&theory, &config);
877        assert!(matches!(result, Err(GatError::ModelError(_))));
878    }
879
880    /// Free model of a category theory with dependent sorts.
881    /// Ob (objects), Hom(a: Ob, b: Ob) (morphisms), id: Ob -> Hom(a, a).
882    /// With one object constant, should generate the identity morphism.
883    #[test]
884    fn free_model_category_theory() -> Result<(), Box<dyn std::error::Error>> {
885        use crate::sort::SortParam;
886
887        let theory = Theory::new(
888            "Category",
889            vec![
890                Sort::simple("Ob"),
891                Sort::dependent(
892                    "Hom",
893                    vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
894                ),
895            ],
896            vec![
897                Operation::nullary("star", "Ob"),
898                // id: Ob -> Hom (in practice produces Hom(x, x))
899                Operation::unary("id", "x", "Ob", "Hom"),
900            ],
901            Vec::new(),
902        );
903
904        let config = FreeModelConfig {
905            max_depth: 2,
906            max_terms_per_sort: 100,
907        };
908        let model = free_model(&theory, &config)?.model;
909
910        // Ob should have one element: star().
911        assert_eq!(model.sort_interp["Ob"].len(), 1);
912
913        // Hom should have at least id(star()).
914        assert!(
915            !model.sort_interp["Hom"].is_empty(),
916            "Hom should have at least the identity morphism"
917        );
918        Ok(())
919    }
920
921    /// Dependent sort with no operations targeting it produces empty carrier.
922    #[test]
923    fn free_model_dependent_sort_no_ops() -> Result<(), Box<dyn std::error::Error>> {
924        use crate::sort::SortParam;
925
926        let theory = Theory::new(
927            "T",
928            vec![
929                Sort::simple("A"),
930                Sort::dependent("B", vec![SortParam::new("x", "A")]),
931            ],
932            vec![Operation::nullary("a", "A")],
933            Vec::new(),
934        );
935
936        let model = free_model(&theory, &FreeModelConfig::default())?.model;
937        assert_eq!(model.sort_interp["A"].len(), 1);
938        assert!(
939            model.sort_interp["B"].is_empty(),
940            "B has no operations targeting it, so carrier should be empty"
941        );
942        Ok(())
943    }
944
945    /// Topological ordering ensures parameter sorts are populated first.
946    #[test]
947    fn free_model_sort_ordering() -> Result<(), Box<dyn std::error::Error>> {
948        use crate::sort::SortParam;
949
950        // Deliberately put the dependent sort first in the list.
951        let theory = Theory::new(
952            "T",
953            vec![
954                Sort::dependent("B", vec![SortParam::new("x", "A")]),
955                Sort::simple("A"),
956            ],
957            vec![
958                Operation::nullary("a", "A"),
959                Operation::unary("f", "x", "A", "B"),
960            ],
961            Vec::new(),
962        );
963
964        let config = FreeModelConfig {
965            max_depth: 1,
966            max_terms_per_sort: 100,
967        };
968        let model = free_model(&theory, &config)?.model;
969
970        // A should have a().
971        assert_eq!(model.sort_interp["A"].len(), 1);
972        // B should have f(a()).
973        assert_eq!(model.sort_interp["B"].len(), 1);
974        Ok(())
975    }
976
977    #[test]
978    fn free_model_operations_work() -> Result<(), Box<dyn std::error::Error>> {
979        let theory = Theory::new(
980            "PointedSet",
981            vec![Sort::simple("Carrier")],
982            vec![Operation::nullary("unit", "Carrier")],
983            vec![],
984        );
985        let model = free_model(&theory, &FreeModelConfig::default())?.model;
986        let result = model.eval("unit", &[])?;
987        assert!(matches!(result, ModelValue::Str(_)));
988        Ok(())
989    }
990
991    #[test]
992    fn free_model_congruence_closure() -> Result<(), Box<dyn std::error::Error>> {
993        // Theory with a = b and f: S -> S.
994        // Congruence closure requires f(a) ~ f(b), even though no equation
995        // directly equates them. The equation a = b combined with the
996        // congruence rule for f must produce this.
997        let theory = Theory::new(
998            "Congruence",
999            vec![Sort::simple("S")],
1000            vec![
1001                Operation::nullary("a", "S"),
1002                Operation::nullary("b", "S"),
1003                Operation::unary("f", "x", "S", "S"),
1004            ],
1005            vec![Equation::new(
1006                "a_eq_b",
1007                Term::constant("a"),
1008                Term::constant("b"),
1009            )],
1010        );
1011        let config = FreeModelConfig {
1012            max_depth: 1,
1013            max_terms_per_sort: 100,
1014        };
1015        let model = free_model(&theory, &config)?.model;
1016        // a ~ b, so f(a) ~ f(b). The carrier should have at most 2 elements:
1017        // one equivalence class for {a, b} and one for {f(a), f(b)}.
1018        assert_eq!(
1019            model.sort_interp["S"].len(),
1020            2,
1021            "a ~ b and f(a) ~ f(b) by congruence: expect 2 classes"
1022        );
1023        Ok(())
1024    }
1025
1026    /// Free category on two generating morphisms. With one object and
1027    /// one endo-generator, the expected terms at depth 2 are:
1028    /// `id(star)`, `f(star)`, `f(f(star))`. Exactly three morphisms in
1029    /// the `Hom` fiber, demonstrating that fiber matching prevents the
1030    /// combinatorial blow-up of a cartesian-product model.
1031    #[test]
1032    fn free_model_dependent_category() -> Result<(), Box<dyn std::error::Error>> {
1033        use crate::sort::{SortExpr, SortParam};
1034
1035        let hom_xx = SortExpr::App {
1036            name: Arc::from("Hom"),
1037            args: vec![Term::var("x"), Term::var("x")],
1038        };
1039        let theory = Theory::new(
1040            "EndoCategory",
1041            vec![
1042                Sort::simple("Ob"),
1043                Sort::dependent(
1044                    "Hom",
1045                    vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
1046                ),
1047            ],
1048            vec![
1049                Operation::nullary("star", "Ob"),
1050                Operation::unary("id", "x", "Ob", hom_xx.clone()),
1051                Operation::unary("f", "x", "Ob", hom_xx),
1052            ],
1053            Vec::new(),
1054        );
1055
1056        let config = FreeModelConfig {
1057            max_depth: 2,
1058            max_terms_per_sort: 100,
1059        };
1060        let model = free_model(&theory, &config)?.model;
1061
1062        // Ob: exactly one element (star).
1063        assert_eq!(model.sort_interp["Ob"].len(), 1);
1064        // Hom: id(star), f(star); f(f(star)) exists only if we stack
1065        // unary ops with matching fibers, which holds here because
1066        // f: (x: Ob) -> Hom(x, x) maps an Ob to a Hom(x, x), but the
1067        // argument of f must itself be an Ob. So at depth 2, the Hom
1068        // carrier holds id(star) and f(star) only.
1069        assert_eq!(
1070            model.sort_interp["Hom"].len(),
1071            2,
1072            "expected id(star) and f(star) in Hom fiber"
1073        );
1074        Ok(())
1075    }
1076
1077    /// Free category on two parallel generating arrows `f, g : Hom(a, b)`
1078    /// at depth 1 has `{id(a), id(b), f(a, b), g(a, b)}` (4 morphisms), not
1079    /// 6 or more. `compose(f, g)` cannot form because `tgt(f) = b` but
1080    /// `src(g) = a`, so the middle-object constraint rules out composites
1081    /// in either order. This is the fiber-matching test that distinguishes
1082    /// a dependent-sort-aware generator from a cartesian-product
1083    /// generator.
1084    #[test]
1085    fn free_model_parallel_arrows_no_spurious_composites() -> Result<(), Box<dyn std::error::Error>>
1086    {
1087        use crate::sort::{SortExpr, SortParam};
1088
1089        let hom_ab = SortExpr::App {
1090            name: Arc::from("Hom"),
1091            args: vec![Term::constant("a"), Term::constant("b")],
1092        };
1093        let hom_xy = SortExpr::App {
1094            name: Arc::from("Hom"),
1095            args: vec![Term::var("x"), Term::var("y")],
1096        };
1097        let theory = Theory::new(
1098            "ParallelArrows",
1099            vec![
1100                Sort::simple("Ob"),
1101                Sort::dependent(
1102                    "Hom",
1103                    vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
1104                ),
1105            ],
1106            vec![
1107                Operation::nullary("a", "Ob"),
1108                Operation::nullary("b", "Ob"),
1109                Operation::nullary("f", hom_ab.clone()),
1110                Operation::nullary("g", hom_ab),
1111                Operation::unary(
1112                    "id",
1113                    "x",
1114                    "Ob",
1115                    SortExpr::App {
1116                        name: Arc::from("Hom"),
1117                        args: vec![Term::var("x"), Term::var("x")],
1118                    },
1119                ),
1120                Operation::new(
1121                    "compose",
1122                    vec![
1123                        (Arc::from("x"), SortExpr::from("Ob")),
1124                        (Arc::from("y"), SortExpr::from("Ob")),
1125                        (Arc::from("z"), SortExpr::from("Ob")),
1126                        (
1127                            Arc::from("h1"),
1128                            SortExpr::App {
1129                                name: Arc::from("Hom"),
1130                                args: vec![Term::var("x"), Term::var("y")],
1131                            },
1132                        ),
1133                        (
1134                            Arc::from("h2"),
1135                            SortExpr::App {
1136                                name: Arc::from("Hom"),
1137                                args: vec![Term::var("y"), Term::var("z")],
1138                            },
1139                        ),
1140                    ],
1141                    hom_xy,
1142                ),
1143            ],
1144            Vec::new(),
1145        );
1146
1147        let config = FreeModelConfig {
1148            max_depth: 1,
1149            max_terms_per_sort: 100,
1150        };
1151        let model = free_model(&theory, &config)?.model;
1152
1153        // Ob: {a, b}.
1154        assert_eq!(model.sort_interp["Ob"].len(), 2);
1155        // Hom at depth 1: id(a), id(b), f, g. No composites because f and
1156        // g share source/target (a, b), so compose(f, g) would require
1157        // tgt(f) = b = src(g) = a, which fails.
1158        assert_eq!(
1159            model.sort_interp["Hom"].len(),
1160            4,
1161            "Hom fiber should contain {{id(a), id(b), f, g}}, got {:?}",
1162            model.sort_interp["Hom"],
1163        );
1164        Ok(())
1165    }
1166
1167    /// Every term in the free model has a well-typed output sort.
1168    #[test]
1169    fn free_model_every_term_well_typed() -> Result<(), Box<dyn std::error::Error>> {
1170        use crate::sort::{SortExpr, SortParam};
1171        use crate::typecheck::{VarContext, typecheck_term};
1172
1173        let hom_xx = SortExpr::App {
1174            name: Arc::from("Hom"),
1175            args: vec![Term::var("x"), Term::var("x")],
1176        };
1177        let theory = Theory::new(
1178            "EndoCat",
1179            vec![
1180                Sort::simple("Ob"),
1181                Sort::dependent(
1182                    "Hom",
1183                    vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
1184                ),
1185            ],
1186            vec![
1187                Operation::nullary("star", "Ob"),
1188                Operation::unary("id", "x", "Ob", hom_xx.clone()),
1189                Operation::unary("f", "x", "Ob", hom_xx),
1190            ],
1191            Vec::new(),
1192        );
1193
1194        let config = FreeModelConfig {
1195            max_depth: 2,
1196            max_terms_per_sort: 100,
1197        };
1198        let (fibers, _) = generate_terms(&theory, &config)?;
1199        let ctx = VarContext::default();
1200        for (fiber, terms) in &fibers {
1201            for term in terms {
1202                let inferred = typecheck_term(term, &ctx, &theory)?;
1203                assert!(
1204                    inferred.alpha_eq(fiber),
1205                    "term {term} has fiber {fiber} but typecheck inferred {inferred}",
1206                );
1207            }
1208        }
1209        Ok(())
1210    }
1211
1212    /// For theories with only simple sorts, the fiber-matching generator
1213    /// reduces to the old head-indexed cartesian-product behavior. Verify
1214    /// that a simple-sort graph theory produces the expected carrier
1215    /// counts.
1216    #[test]
1217    fn free_model_simple_sorts_backward_compat() -> Result<(), Box<dyn std::error::Error>> {
1218        let theory = Theory::new(
1219            "Graph",
1220            vec![Sort::simple("Vertex"), Sort::simple("Edge")],
1221            vec![
1222                Operation::nullary("v0", "Vertex"),
1223                Operation::nullary("v1", "Vertex"),
1224                Operation::unary("src", "e", "Edge", "Vertex"),
1225                Operation::unary("tgt", "e", "Edge", "Vertex"),
1226            ],
1227            Vec::new(),
1228        );
1229        let config = FreeModelConfig {
1230            max_depth: 1,
1231            max_terms_per_sort: 100,
1232        };
1233        let model = free_model(&theory, &config)?.model;
1234        // Two vertices at depth 0; src/tgt need an Edge which is empty.
1235        // So at any depth: Vertex = {v0, v1}, Edge = {}.
1236        assert_eq!(model.sort_interp["Vertex"].len(), 2);
1237        assert!(model.sort_interp["Edge"].is_empty());
1238        Ok(())
1239    }
1240
1241    #[test]
1242    fn free_model_cyclic_sort_dependency_rejected() {
1243        use crate::sort::SortParam;
1244
1245        // Sort A depends on B and B depends on A: cyclic.
1246        let theory = Theory::new(
1247            "Cyclic",
1248            vec![
1249                Sort::dependent("A", vec![SortParam::new("x", "B")]),
1250                Sort::dependent("B", vec![SortParam::new("y", "A")]),
1251            ],
1252            vec![],
1253            vec![],
1254        );
1255        let result = free_model(&theory, &FreeModelConfig::default());
1256        assert!(
1257            matches!(result, Err(GatError::CyclicSortDependency(_))),
1258            "cyclic sort dependencies should be rejected"
1259        );
1260    }
1261}