use std::collections::VecDeque;
use std::sync::Arc;
use rustc_hash::{FxHashMap, FxHashSet};
use crate::eq::Term;
use crate::error::GatError;
use crate::model::{Model, ModelValue};
use crate::sort::SortExpr;
use crate::theory::Theory;
/// Tunable limits for free-model term generation.
#[derive(Debug, Clone)]
pub struct FreeModelConfig {
    /// Maximum application-nesting depth of generated terms.
    pub max_depth: usize,
    /// Hard cap on the number of terms in any one fiber bucket; reaching it
    /// aborts generation with a `ModelError`.
    pub max_terms_per_sort: usize,
}

impl Default for FreeModelConfig {
    /// Defaults: depth 3, at most 1000 terms per sort.
    fn default() -> Self {
        FreeModelConfig {
            max_depth: 3,
            max_terms_per_sort: 1000,
        }
    }
}
/// The outcome of [`free_model`]: the constructed model plus a saturation flag.
#[derive(Debug)]
pub struct FreeModelResult {
    /// The quotiented, string-valued term model.
    pub model: Model,
    /// `true` when term generation reached a fixed point within the depth
    /// bound, i.e. the carriers contain the entire term algebra.
    pub is_complete: bool,
}
/// Constructs the free (term) model of `theory`, bounded by `config`.
///
/// Pipeline: generate terms per output fiber, collapse fibers onto their
/// head sort, assign dense global indices, quotient by the theory's
/// equations (with congruence closure), then package the equivalence
/// classes as a string-valued [`Model`].
///
/// # Errors
///
/// Propagates term-generation failures: cyclic sort dependencies and
/// per-sort term-count overflow.
pub fn free_model(theory: &Theory, config: &FreeModelConfig) -> Result<FreeModelResult, GatError> {
    let (terms_by_fiber, is_complete) = generate_terms(theory, config)?;
    let mut terms_by_sort = collapse_fibers(&terms_by_fiber);
    // Sorts with no inhabiting terms still get an (empty) carrier entry.
    for sort in &theory.sorts {
        terms_by_sort.entry(Arc::clone(&sort.name)).or_default();
    }
    let (term_to_global, total_terms) = assign_global_indices(&terms_by_sort);
    let mut uf = quotient_by_equations(theory, &terms_by_sort, &term_to_global, total_terms);
    let model = build_model(theory, &terms_by_sort, &term_to_global, &mut uf);
    Ok(FreeModelResult { model, is_complete })
}
fn collapse_fibers(
terms_by_fiber: &FxHashMap<SortExpr, Vec<Term>>,
) -> FxHashMap<Arc<str>, Vec<Term>> {
let mut out: FxHashMap<Arc<str>, Vec<Term>> = FxHashMap::default();
for (fiber, terms) in terms_by_fiber {
let head = Arc::clone(fiber.head());
let bucket = out.entry(head).or_default();
for t in terms {
if !bucket.contains(t) {
bucket.push(t.clone());
}
}
}
out
}
/// Orders sort names so every sort appears after the sorts its parameters
/// depend on (Kahn's algorithm over the sort-dependency graph).
///
/// `generate_terms` calls this for cycle detection and discards the order.
///
/// # Errors
///
/// Returns [`GatError::CyclicSortDependency`] listing the sorts left
/// unordered when the dependency graph has a cycle (including
/// self-dependencies, which can never reach in-degree zero).
fn topological_sort_sorts(theory: &Theory) -> Result<Vec<Arc<str>>, GatError> {
    let sort_names: FxHashSet<Arc<str>> =
        theory.sorts.iter().map(|s| Arc::clone(&s.name)).collect();
    // in_degree[s] = number of parameter edges into s; dependents[p] = sorts
    // whose parameters mention p. Param heads outside the theory are ignored.
    let mut in_degree: FxHashMap<Arc<str>, usize> = FxHashMap::default();
    let mut dependents: FxHashMap<Arc<str>, Vec<Arc<str>>> = FxHashMap::default();
    for sort in &theory.sorts {
        in_degree.entry(Arc::clone(&sort.name)).or_insert(0);
        for param in &sort.params {
            let param_head = param.sort.head();
            if sort_names.contains(param_head) {
                *in_degree.entry(Arc::clone(&sort.name)).or_insert(0) += 1;
                dependents
                    .entry(Arc::clone(param_head))
                    .or_default()
                    .push(Arc::clone(&sort.name));
            }
        }
    }
    // Seed with dependency-free sorts, sorted because FxHashMap iteration
    // order is otherwise unspecified (keeps the seed deterministic).
    let mut initial: Vec<Arc<str>> = in_degree
        .iter()
        .filter(|(_, deg)| **deg == 0)
        .map(|(name, _)| Arc::clone(name))
        .collect();
    initial.sort();
    let mut queue: VecDeque<Arc<str>> = initial.into_iter().collect();
    let mut result = Vec::new();
    while let Some(name) = queue.pop_front() {
        result.push(Arc::clone(&name));
        if let Some(deps) = dependents.get(&name) {
            for dep in deps {
                if let Some(deg) = in_degree.get_mut(dep) {
                    *deg = deg.saturating_sub(1);
                    if *deg == 0 {
                        queue.push_back(Arc::clone(dep));
                    }
                }
            }
        }
    }
    // Anything never reaching in-degree zero sits on (or behind) a cycle.
    if result.len() < theory.sorts.len() {
        let cyclic: Vec<String> = theory
            .sorts
            .iter()
            .filter(|s| !result.contains(&s.name))
            .map(|s| s.name.to_string())
            .collect();
        return Err(GatError::CyclicSortDependency(cyclic));
    }
    Ok(result)
}
/// Enumerates the theory's terms up to `config.max_depth` nested applications.
///
/// Returns the terms grouped by their (possibly dependent) output fiber, plus
/// a completeness flag: `true` means generation reached a fixed point, so the
/// returned set is the whole term algebra; `false` means deeper terms may
/// exist beyond the bound.
///
/// # Errors
///
/// Returns [`GatError::CyclicSortDependency`] for a cyclic sort graph, and
/// [`GatError::ModelError`] when any fiber bucket hits
/// `config.max_terms_per_sort`.
fn generate_terms(
    theory: &Theory,
    config: &FreeModelConfig,
) -> Result<(FxHashMap<SortExpr, Vec<Term>>, bool), GatError> {
    #![allow(clippy::type_complexity)]
    let mut terms_by_fiber: FxHashMap<SortExpr, Vec<Term>> = FxHashMap::default();
    // Run the cycle check up front; the ordering itself is unused here.
    let _ = topological_sort_sorts(theory)?;
    // Depth 0: nullary operations seed the term sets.
    for op in &theory.ops {
        if op.inputs.is_empty() {
            let term = Term::constant(Arc::clone(&op.name));
            let fiber = op.output.clone();
            let bucket = terms_by_fiber.entry(fiber).or_default();
            if !bucket.contains(&term) {
                bucket.push(term);
            }
        }
    }
    // Saturate depth by depth. `saturated` mirrors "the last executed depth
    // added nothing"; with zero depths (max_depth == 0) the seed set is
    // trivially complete.
    let mut saturated = true;
    for _depth in 1..=config.max_depth {
        let new_terms = generate_depth(theory, &terms_by_fiber);
        let mut added_any = false;
        for (fiber, new) in new_terms {
            let bucket = terms_by_fiber.entry(fiber.clone()).or_default();
            for t in new {
                // NOTE(review): this fires even when `t` duplicates an existing
                // term; kept as-is so the limit stays a hard cap on bucket size.
                if bucket.len() >= config.max_terms_per_sort {
                    let head = fiber.head();
                    return Err(GatError::ModelError(format!(
                        "term count for sort '{head}' exceeds limit {}",
                        config.max_terms_per_sort
                    )));
                }
                if !bucket.contains(&t) {
                    bucket.push(t);
                    added_any = true;
                }
            }
        }
        if added_any {
            saturated = false;
        } else {
            // Fixed point: a pass over unchanged inputs regenerates exactly
            // the same candidates, so every further depth would also be a
            // no-op — stop early instead of re-running identical passes.
            saturated = true;
            break;
        }
    }
    Ok((terms_by_fiber, saturated))
}
/// Produces candidate terms one application deeper than `terms_by_fiber`,
/// grouped by the (substituted) output fiber of the op that built them.
fn generate_depth(
    theory: &Theory,
    terms_by_fiber: &FxHashMap<SortExpr, Vec<Term>>,
) -> FxHashMap<SortExpr, Vec<Term>> {
    let mut produced: FxHashMap<SortExpr, Vec<Term>> = FxHashMap::default();
    // Nullary ops were seeded at depth 0; only proper applications remain.
    for op in theory.ops.iter().filter(|op| !op.inputs.is_empty()) {
        let mut partial_args: Vec<Term> = Vec::with_capacity(op.inputs.len());
        let mut bindings: FxHashMap<Arc<str>, Term> = FxHashMap::default();
        extend_op_tuples(
            op,
            0,
            &mut partial_args,
            &mut bindings,
            terms_by_fiber,
            &mut produced,
        );
    }
    produced
}
/// Recursively enumerates argument tuples for `op`, one input slot at a time.
///
/// `chosen` holds the arguments picked so far and `theta` the matching
/// parameter-to-term substitution; both are unwound on backtrack. Each input
/// sort is instantiated through `theta` before lookup, so dependent fibers
/// (e.g. `Hom(x, y)`) only draw candidates from the matching fiber. Once all
/// slots are filled, the application term is recorded under the op's
/// substituted output fiber.
fn extend_op_tuples(
    op: &crate::op::Operation,
    slot: usize,
    chosen: &mut Vec<Term>,
    theta: &mut FxHashMap<Arc<str>, Term>,
    terms_by_fiber: &FxHashMap<SortExpr, Vec<Term>>,
    new_terms: &mut FxHashMap<SortExpr, Vec<Term>>,
) {
    // Base case: every slot filled -> emit the fully applied term.
    if slot == op.inputs.len() {
        let output_fiber = op.output.subst(theta);
        let term = Term::app(Arc::clone(&op.name), chosen.clone());
        new_terms.entry(output_fiber).or_default().push(term);
        return;
    }
    let (param_name, declared_sort, _implicit) = &op.inputs[slot];
    // Resolve the declared sort under the bindings chosen so far.
    let expected_fiber = declared_sort.subst(theta);
    let Some(candidates) = terms_by_fiber.get(&expected_fiber) else {
        // No terms inhabit this fiber yet; the whole subtree is empty.
        return;
    };
    for cand in candidates {
        chosen.push(cand.clone());
        // NOTE(review): assumes input parameter names are unique per op — a
        // duplicate name's earlier binding would be dropped (not restored)
        // by the `remove` below. Confirm ops cannot declare duplicates.
        theta.insert(Arc::clone(param_name), cand.clone());
        extend_op_tuples(op, slot + 1, chosen, theta, terms_by_fiber, new_terms);
        theta.remove(param_name);
        chosen.pop();
    }
}
/// Assigns each term a dense global index, sort by sort.
///
/// Sort names are processed in lexicographic order so the numbering is
/// deterministic. Returns the per-sort index vectors (parallel to the term
/// vectors) and the total term count.
fn assign_global_indices(
    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
) -> (FxHashMap<Arc<str>, Vec<usize>>, usize) {
    let mut keys: Vec<&Arc<str>> = terms_by_sort.keys().collect();
    keys.sort();
    let mut mapping: FxHashMap<Arc<str>, Vec<usize>> = FxHashMap::default();
    let mut next = 0usize;
    for key in keys {
        let count = terms_by_sort[key].len();
        let ids: Vec<usize> = (next..next + count).collect();
        next += count;
        mapping.insert(Arc::clone(key), ids);
    }
    (mapping, next)
}
/// Computes the term-equivalence classes induced by the theory's equations.
///
/// Repeatedly applies every equation (ground equations directly, variable
/// equations over all candidate substitutions) plus one congruence-closure
/// pass, until a full round performs no new merges (`merge_count` is a
/// fixed point).
fn quotient_by_equations(
    theory: &Theory,
    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
    term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
    total_terms: usize,
) -> UnionFind {
    let mut uf = UnionFind::new(total_terms);
    // Precompute per-equation data: its free variables and (when inference
    // succeeds) each variable's sort. Equations whose variable sorts cannot
    // be inferred are silently skipped below rather than failing the build.
    let eq_info: Vec<_> = theory
        .eqs
        .iter()
        .map(|eq| {
            let vars: Vec<Arc<str>> = {
                let mut all = eq.lhs.free_vars();
                all.extend(eq.rhs.free_vars());
                all.into_iter().collect()
            };
            let var_sorts = crate::typecheck::infer_var_sorts(eq, theory).ok();
            (eq, vars, var_sorts)
        })
        .collect();
    let congruence_entries = build_congruence_index(terms_by_sort, term_to_global);
    // Saturate: each round's merges may enable further merges next round.
    loop {
        let merges_before = uf.merge_count;
        for (eq, vars, var_sorts) in &eq_info {
            if vars.is_empty() {
                // Ground equation: merge the two sides directly.
                merge_constant_eq(eq, terms_by_sort, term_to_global, &mut uf);
                continue;
            }
            let Some(vs) = var_sorts else {
                continue;
            };
            merge_by_equation(eq, vars, vs, terms_by_sort, term_to_global, &mut uf);
        }
        // Propagate a ~ b to f(a) ~ f(b).
        congruence_closure_pass(&congruence_entries, &mut uf);
        if uf.merge_count == merges_before {
            break;
        }
    }
    uf
}
/// One application term in the congruence index.
struct CongruenceEntry {
    /// Global index of the application term itself.
    term_idx: usize,
    /// Global indices of the term's immediate arguments, in order.
    arg_indices: Vec<usize>,
}
/// Groups every application term by operation name, recording the global
/// indices of the term and of each argument, for congruence-closure passes.
///
/// Applications whose arguments are not all present in the generated term
/// set (e.g. truncated by the depth limit) are skipped, since their argument
/// indices cannot be resolved.
fn build_congruence_index(
    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
    term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
) -> FxHashMap<Arc<str>, Vec<CongruenceEntry>> {
    let mut index: FxHashMap<Arc<str>, Vec<CongruenceEntry>> = FxHashMap::default();
    // Reverse map term -> global index, used to resolve argument subterms.
    // NOTE(review): if an identical term inhabited two sorts, the later
    // insert would overwrite the earlier index — presumably that cannot
    // happen; confirm against the generator.
    let mut term_lookup: FxHashMap<&Term, usize> = FxHashMap::default();
    for (sort, terms) in terms_by_sort {
        let indices = &term_to_global[sort];
        for (i, term) in terms.iter().enumerate() {
            term_lookup.insert(term, indices[i]);
        }
    }
    for (sort, terms) in terms_by_sort {
        let indices = &term_to_global[sort];
        for (i, term) in terms.iter().enumerate() {
            if let Term::App { op, args } = term {
                // Nullary applications carry no arguments to be congruent over.
                if args.is_empty() {
                    continue;
                }
                let arg_indices: Vec<usize> = args
                    .iter()
                    .filter_map(|arg| term_lookup.get(arg).copied())
                    .collect();
                if arg_indices.len() == args.len() {
                    index
                        .entry(Arc::clone(op))
                        .or_default()
                        .push(CongruenceEntry {
                            term_idx: indices[i],
                            arg_indices,
                        });
                }
            }
        }
    }
    index
}
/// One round of congruence closure: two applications of the same operation
/// whose argument lists are pairwise equivalent get merged into one class.
fn congruence_closure_pass(
    entries: &FxHashMap<Arc<str>, Vec<CongruenceEntry>>,
    uf: &mut UnionFind,
) {
    for op_entries in entries.values() {
        // A lone application has nothing to be congruent with.
        if op_entries.len() < 2 {
            continue;
        }
        // Key each application by the canonical (root) form of its argument
        // list; a key collision is exactly a congruent pair.
        let mut by_canonical_args: FxHashMap<Vec<usize>, usize> = FxHashMap::default();
        for entry in op_entries {
            let key: Vec<usize> = entry.arg_indices.iter().map(|&a| uf.find(a)).collect();
            match by_canonical_args.get(&key) {
                Some(&rep) => uf.union(rep, entry.term_idx),
                None => {
                    let root = uf.find(entry.term_idx);
                    by_canonical_args.insert(key, root);
                }
            }
        }
    }
}
/// True when `term` is built purely from variables and applications — the
/// only shapes the free-model generator is expected to emit (checked by a
/// `debug_assert!` in `build_model`).
fn is_app_only(term: &Term) -> bool {
    match term {
        Term::Case { .. } | Term::Hole { .. } | Term::Let { .. } => false,
        Term::App { args, .. } => args.iter().all(is_app_only),
        Term::Var(_) => true,
    }
}
/// Pretty-prints a term; `build_model` stores this printed form as the
/// model's carrier elements and lookup keys.
fn term_to_string(term: &Term) -> String {
    match term {
        Term::Var(v) => v.to_string(),
        // Nullary and n-ary applications share one arm: joining an empty
        // argument list yields "", printing as `op()`.
        Term::App { op, args } => {
            let printed: Vec<String> = args.iter().map(term_to_string).collect();
            format!("{op}({})", printed.join(", "))
        }
        Term::Case {
            scrutinee,
            branches,
        } => {
            let printed: Vec<String> = branches
                .iter()
                .map(|branch| {
                    let binder_list = branch
                        .binders
                        .iter()
                        .map(ToString::to_string)
                        .collect::<Vec<_>>()
                        .join(", ");
                    format!(
                        "{}({}) => {}",
                        branch.constructor,
                        binder_list,
                        term_to_string(&branch.body)
                    )
                })
                .collect();
            format!(
                "case {} of {} end",
                term_to_string(scrutinee),
                printed.join(" | ")
            )
        }
        Term::Hole { name } => match name {
            Some(n) => format!("?{n}"),
            None => "?".to_string(),
        },
        Term::Let { name, bound, body } => format!(
            "let {name} = {} in {}",
            term_to_string(bound),
            term_to_string(body)
        ),
    }
}
/// Packages the quotiented term algebra as a string-valued [`Model`].
///
/// Carrier elements are printed representative terms (one per equivalence
/// class); each operation is interpreted by formatting the call as
/// `op(arg, ...)` and mapping that string back to its class representative.
fn build_model(
    theory: &Theory,
    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
    term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
    uf: &mut UnionFind,
) -> Model {
    let mut model = Model::new(&*theory.name);
    // union-find root -> printed form of the class's representative term
    let mut class_rep_string: FxHashMap<usize, String> = FxHashMap::default();
    // printed term -> printed representative of its class (op lookup table)
    let mut string_to_rep: FxHashMap<String, String> = FxHashMap::default();
    // Pass 1: pick the first term encountered in each class as its printed
    // representative, and map every term's string onto it.
    // NOTE(review): `seen_classes` resets per sort, so a class spanning two
    // sorts would have its representative overwritten by the later sort;
    // assumed classes never cross sort boundaries — confirm.
    for (sort, terms) in terms_by_sort {
        let indices = &term_to_global[sort];
        let mut seen_classes: FxHashSet<usize> = FxHashSet::default();
        for (i, term) in terms.iter().enumerate() {
            // The generator only emits Var/App terms; anything else would
            // break the syntactic op interpretation below.
            debug_assert!(
                is_app_only(term),
                "free-model generator emitted a non-App term: {term:?}",
            );
            let rep = uf.find(indices[i]);
            if seen_classes.insert(rep) {
                class_rep_string.insert(rep, term_to_string(term));
            }
            let rep_str = class_rep_string[&rep].clone();
            string_to_rep.insert(term_to_string(term), rep_str);
        }
    }
    // Pass 2: emit one carrier element per distinct class in each sort.
    for (sort, terms) in terms_by_sort {
        let indices = &term_to_global[sort];
        let mut seen_classes: FxHashSet<usize> = FxHashSet::default();
        let mut carrier = Vec::new();
        for (i, term) in terms.iter().enumerate() {
            let rep = uf.find(indices[i]);
            if seen_classes.insert(rep) {
                carrier.push(ModelValue::Str(term_to_string(term)));
            }
        }
        model.add_sort(sort.to_string(), carrier);
    }
    // All op interpretations share one immutable lookup table.
    let lookup = Arc::new(string_to_rep);
    for op in &theory.ops {
        let op_name = op.name.to_string();
        let arity = op.arity();
        let table = Arc::clone(&lookup);
        model.add_op(op_name.clone(), move |args: &[ModelValue]| {
            if args.len() != arity {
                return Err(GatError::ModelError(format!(
                    "operation '{op_name}' expects {arity} args, got {}",
                    args.len()
                )));
            }
            let mut arg_strs: Vec<String> = Vec::with_capacity(args.len());
            for (i, a) in args.iter().enumerate() {
                match a {
                    ModelValue::Str(s) => arg_strs.push(s.clone()),
                    other => {
                        return Err(GatError::ModelError(format!(
                            "operation '{op_name}' received non-string argument at index {i}: {other:?}"
                        )));
                    }
                }
            }
            // Build the syntactic application and normalize it to its class
            // representative; strings never generated (e.g. beyond the depth
            // bound) map to themselves.
            let result_str = format!("{op_name}({})", arg_strs.join(", "));
            Ok(ModelValue::Str(
                table.get(&result_str).map_or(result_str, String::clone),
            ))
        });
    }
    model
}
/// Merges the two sides of a ground (variable-free) equation, provided both
/// sides actually occur in the generated term set.
fn merge_constant_eq(
    eq: &crate::eq::Equation,
    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
    term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
    uf: &mut UnionFind,
) {
    let left = find_term_index(&eq.lhs, terms_by_sort, term_to_global);
    let right = find_term_index(&eq.rhs, terms_by_sort, term_to_global);
    // `zip` yields Some only when both indices were found.
    if let Some((l, r)) = left.zip(right) {
        uf.union(l, r);
    }
}
/// Looks up the global index of `term` by scanning every sort's term list.
///
/// Linear in the total number of generated terms; returns `None` when the
/// term was never generated (e.g. it exceeds the depth bound).
fn find_term_index(
    term: &Term,
    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
    term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
) -> Option<usize> {
    terms_by_sort.iter().find_map(|(sort, terms)| {
        let pos = terms.iter().position(|t| t == term)?;
        Some(term_to_global[sort][pos])
    })
}
/// Instantiates `eq` with every combination of candidate terms for its free
/// variables and merges the resulting lhs/rhs pairs.
///
/// Candidates for each variable are drawn from the head sort of its inferred
/// sort. Enumeration uses a mixed-radix odometer over `indices`.
/// Instantiations whose lhs or rhs was never generated (e.g. beyond the
/// depth bound) are skipped.
fn merge_by_equation(
    eq: &crate::eq::Equation,
    vars: &[Arc<str>],
    var_sorts: &FxHashMap<Arc<str>, SortExpr>,
    terms_by_sort: &FxHashMap<Arc<str>, Vec<Term>>,
    term_to_global: &FxHashMap<Arc<str>, Vec<usize>>,
    uf: &mut UnionFind,
) {
    // Pair each variable with its candidate pool; bail out if any variable
    // lacks an inferred sort or a pool, or has an empty pool (then no
    // instantiation exists).
    let var_terms: Vec<(&Arc<str>, &Vec<Term>)> = vars
        .iter()
        .filter_map(|v| {
            let sort = var_sorts.get(v)?;
            let terms = terms_by_sort.get(sort.head())?;
            Some((v, terms))
        })
        .collect();
    if var_terms.len() != vars.len() || var_terms.iter().any(|(_, terms)| terms.is_empty()) {
        return;
    }
    // indices[i] selects the current candidate for variable i.
    let mut indices = vec![0usize; var_terms.len()];
    loop {
        let mut subst = rustc_hash::FxHashMap::default();
        for (i, (var, terms)) in var_terms.iter().enumerate() {
            subst.insert(Arc::clone(var), terms[indices[i]].clone());
        }
        let lhs = eq.lhs.substitute(&subst);
        let rhs = eq.rhs.substitute(&subst);
        let lhs_idx = find_term_index(&lhs, terms_by_sort, term_to_global);
        let rhs_idx = find_term_index(&rhs, terms_by_sort, term_to_global);
        if let (Some(l), Some(r)) = (lhs_idx, rhs_idx) {
            uf.union(l, r);
        }
        // Advance the odometer: bump the last position, carrying leftward; a
        // carry out of the leftmost position means every tuple was visited.
        let mut carry = true;
        for i in (0..indices.len()).rev() {
            if carry {
                indices[i] += 1;
                if indices[i] < var_terms[i].1.len() {
                    carry = false;
                } else {
                    indices[i] = 0;
                }
            }
        }
        if carry {
            break;
        }
    }
}
/// Disjoint-set structure over dense indices `0..size`, with union by rank
/// and path halving.
struct UnionFind {
    /// parent[x] == x marks a class root.
    parent: Vec<usize>,
    /// Upper bound on subtree height, used to keep trees shallow on union.
    rank: Vec<usize>,
    /// Number of unions that actually joined two distinct classes; the
    /// equation-saturation loop reads this to detect a fixed point.
    merge_count: usize,
}

impl UnionFind {
    /// Creates `size` singleton classes.
    fn new(size: usize) -> Self {
        Self {
            parent: (0..size).collect(),
            rank: vec![0; size],
            merge_count: 0,
        }
    }

    /// Returns the root of `x`'s class, halving the path along the way.
    fn find(&mut self, mut x: usize) -> usize {
        loop {
            let p = self.parent[x];
            if p == x {
                return x;
            }
            // Path halving: point x at its grandparent, then step there.
            let gp = self.parent[p];
            self.parent[x] = gp;
            x = gp;
        }
    }

    /// Joins the classes of `x` and `y`, counting only effective merges.
    fn union(&mut self, x: usize, y: usize) {
        let (a, b) = (self.find(x), self.find(y));
        if a == b {
            return;
        }
        self.merge_count += 1;
        if self.rank[a] < self.rank[b] {
            self.parent[a] = b;
        } else if self.rank[a] > self.rank[b] {
            self.parent[b] = a;
        } else {
            // Equal ranks: attach b under a and grow a's rank.
            self.parent[b] = a;
            self.rank[a] += 1;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eq::Equation;
    use crate::op::Operation;
    use crate::sort::Sort;
    use crate::theory::Theory;

    // One nullary op => a single carrier element.
    #[test]
    fn free_model_of_pointed_set() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "PointedSet",
            vec![Sort::simple("Carrier")],
            vec![Operation::nullary("unit", "Carrier")],
            vec![],
        );
        let result = free_model(&theory, &FreeModelConfig::default())?;
        assert_eq!(result.model.sort_interp["Carrier"].len(), 1);
        Ok(())
    }

    // No operations => the sort still exists but its carrier is empty.
    #[test]
    fn free_model_empty_theory() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new("Empty", vec![Sort::simple("S")], vec![], vec![]);
        let model = free_model(&theory, &FreeModelConfig::default())?.model;
        assert!(model.sort_interp["S"].is_empty());
        Ok(())
    }

    // Distinct constants stay distinct without equations.
    #[test]
    fn free_model_two_constants() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "TwoPoints",
            vec![Sort::simple("S")],
            vec![Operation::nullary("a", "S"), Operation::nullary("b", "S")],
            vec![],
        );
        let model = free_model(&theory, &FreeModelConfig::default())?.model;
        assert_eq!(model.sort_interp["S"].len(), 2);
        Ok(())
    }

    // A ground equation a = b merges the two constants into one class.
    #[test]
    fn free_model_equation_collapses_constants() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "CollapsedPoints",
            vec![Sort::simple("S")],
            vec![Operation::nullary("a", "S"), Operation::nullary("b", "S")],
            vec![Equation::new(
                "a_eq_b",
                Term::constant("a"),
                Term::constant("b"),
            )],
        );
        let model = free_model(&theory, &FreeModelConfig::default())?.model;
        assert_eq!(model.sort_interp["S"].len(), 1);
        Ok(())
    }

    // Unit laws collapse every depth-1 product back onto unit.
    #[test]
    fn free_model_monoid_identity_collapses() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "Monoid",
            vec![Sort::simple("Carrier")],
            vec![
                Operation::new(
                    "mul",
                    vec![
                        ("a".into(), "Carrier".into()),
                        ("b".into(), "Carrier".into()),
                    ],
                    "Carrier",
                ),
                Operation::nullary("unit", "Carrier"),
            ],
            vec![
                Equation::new(
                    "left_id",
                    Term::app("mul", vec![Term::constant("unit"), Term::var("a")]),
                    Term::var("a"),
                ),
                Equation::new(
                    "right_id",
                    Term::app("mul", vec![Term::var("a"), Term::constant("unit")]),
                    Term::var("a"),
                ),
            ],
        );
        let config = FreeModelConfig {
            max_depth: 1,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;
        assert_eq!(model.sort_interp["Carrier"].len(), 1);
        Ok(())
    }

    // Without vertex constants nothing can inhabit either carrier.
    #[test]
    fn free_model_graph_theory() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "Graph",
            vec![Sort::simple("Vertex"), Sort::simple("Edge")],
            vec![
                Operation::unary("src", "e", "Edge", "Vertex"),
                Operation::unary("tgt", "e", "Edge", "Vertex"),
            ],
            vec![],
        );
        let model = free_model(&theory, &FreeModelConfig::default())?.model;
        assert!(model.sort_interp["Vertex"].is_empty());
        assert!(model.sort_interp["Edge"].is_empty());
        Ok(())
    }

    // succ-chains keep growing, so a small per-sort limit must error out.
    #[test]
    fn free_model_term_count_bounded() {
        let theory = Theory::new(
            "Chain",
            vec![Sort::simple("S")],
            vec![
                Operation::nullary("zero", "S"),
                Operation::unary("succ", "x", "S", "S"),
            ],
            vec![],
        );
        let config = FreeModelConfig {
            max_depth: 10,
            max_terms_per_sort: 5,
        };
        let result = free_model(&theory, &config);
        assert!(matches!(result, Err(GatError::ModelError(_))));
    }

    // Dependent Hom sort: id(star) should inhabit the Hom fiber.
    #[test]
    fn free_model_category_theory() -> Result<(), Box<dyn std::error::Error>> {
        use crate::sort::SortParam;
        let theory = Theory::new(
            "Category",
            vec![
                Sort::simple("Ob"),
                Sort::dependent(
                    "Hom",
                    vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
                ),
            ],
            vec![
                Operation::nullary("star", "Ob"),
                Operation::unary("id", "x", "Ob", "Hom"),
            ],
            Vec::new(),
        );
        let config = FreeModelConfig {
            max_depth: 2,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;
        assert_eq!(model.sort_interp["Ob"].len(), 1);
        assert!(
            !model.sort_interp["Hom"].is_empty(),
            "Hom should have at least the identity morphism"
        );
        Ok(())
    }

    // A dependent sort no operation targets stays empty.
    #[test]
    fn free_model_dependent_sort_no_ops() -> Result<(), Box<dyn std::error::Error>> {
        use crate::sort::SortParam;
        let theory = Theory::new(
            "T",
            vec![
                Sort::simple("A"),
                Sort::dependent("B", vec![SortParam::new("x", "A")]),
            ],
            vec![Operation::nullary("a", "A")],
            Vec::new(),
        );
        let model = free_model(&theory, &FreeModelConfig::default())?.model;
        assert_eq!(model.sort_interp["A"].len(), 1);
        assert!(
            model.sort_interp["B"].is_empty(),
            "B has no operations targeting it, so carrier should be empty"
        );
        Ok(())
    }

    // Declaration order must not matter (B declared before its dependency A).
    #[test]
    fn free_model_sort_ordering() -> Result<(), Box<dyn std::error::Error>> {
        use crate::sort::SortParam;
        let theory = Theory::new(
            "T",
            vec![
                Sort::dependent("B", vec![SortParam::new("x", "A")]),
                Sort::simple("A"),
            ],
            vec![
                Operation::nullary("a", "A"),
                Operation::unary("f", "x", "A", "B"),
            ],
            Vec::new(),
        );
        let config = FreeModelConfig {
            max_depth: 1,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;
        assert_eq!(model.sort_interp["A"].len(), 1);
        assert_eq!(model.sort_interp["B"].len(), 1);
        Ok(())
    }

    // The generated model's op interpretation is callable through eval.
    #[test]
    fn free_model_operations_work() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "PointedSet",
            vec![Sort::simple("Carrier")],
            vec![Operation::nullary("unit", "Carrier")],
            vec![],
        );
        let model = free_model(&theory, &FreeModelConfig::default())?.model;
        let result = model.eval("unit", &[])?;
        assert!(matches!(result, ModelValue::Str(_)));
        Ok(())
    }

    // a = b must also merge f(a) with f(b) via congruence closure.
    #[test]
    fn free_model_congruence_closure() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "Congruence",
            vec![Sort::simple("S")],
            vec![
                Operation::nullary("a", "S"),
                Operation::nullary("b", "S"),
                Operation::unary("f", "x", "S", "S"),
            ],
            vec![Equation::new(
                "a_eq_b",
                Term::constant("a"),
                Term::constant("b"),
            )],
        );
        let config = FreeModelConfig {
            max_depth: 1,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;
        assert_eq!(
            model.sort_interp["S"].len(),
            2,
            "a ~ b and f(a) ~ f(b) by congruence: expect 2 classes"
        );
        Ok(())
    }

    // Endomorphism fibers Hom(x, x) get one term per unary op per object.
    #[test]
    fn free_model_dependent_category() -> Result<(), Box<dyn std::error::Error>> {
        use crate::sort::{SortExpr, SortParam};
        let hom_xx = SortExpr::App {
            name: Arc::from("Hom"),
            args: vec![Term::var("x"), Term::var("x")],
        };
        let theory = Theory::new(
            "EndoCategory",
            vec![
                Sort::simple("Ob"),
                Sort::dependent(
                    "Hom",
                    vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
                ),
            ],
            vec![
                Operation::nullary("star", "Ob"),
                Operation::unary("id", "x", "Ob", hom_xx.clone()),
                Operation::unary("f", "x", "Ob", hom_xx),
            ],
            Vec::new(),
        );
        let config = FreeModelConfig {
            max_depth: 2,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;
        assert_eq!(model.sort_interp["Ob"].len(), 1);
        assert_eq!(
            model.sort_interp["Hom"].len(),
            2,
            "expected id(star) and f(star) in Hom fiber"
        );
        Ok(())
    }

    // compose draws arguments per fiber, so parallel arrows f, g : a -> b
    // admit no composites and no spurious Hom elements appear.
    #[test]
    fn free_model_parallel_arrows_no_spurious_composites() -> Result<(), Box<dyn std::error::Error>>
    {
        use crate::sort::{SortExpr, SortParam};
        let hom_ab = SortExpr::App {
            name: Arc::from("Hom"),
            args: vec![Term::constant("a"), Term::constant("b")],
        };
        let hom_xy = SortExpr::App {
            name: Arc::from("Hom"),
            args: vec![Term::var("x"), Term::var("y")],
        };
        let theory = Theory::new(
            "ParallelArrows",
            vec![
                Sort::simple("Ob"),
                Sort::dependent(
                    "Hom",
                    vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
                ),
            ],
            vec![
                Operation::nullary("a", "Ob"),
                Operation::nullary("b", "Ob"),
                Operation::nullary("f", hom_ab.clone()),
                Operation::nullary("g", hom_ab),
                Operation::unary(
                    "id",
                    "x",
                    "Ob",
                    SortExpr::App {
                        name: Arc::from("Hom"),
                        args: vec![Term::var("x"), Term::var("x")],
                    },
                ),
                Operation::new(
                    "compose",
                    vec![
                        (Arc::from("x"), SortExpr::from("Ob")),
                        (Arc::from("y"), SortExpr::from("Ob")),
                        (Arc::from("z"), SortExpr::from("Ob")),
                        (
                            Arc::from("h1"),
                            SortExpr::App {
                                name: Arc::from("Hom"),
                                args: vec![Term::var("x"), Term::var("y")],
                            },
                        ),
                        (
                            Arc::from("h2"),
                            SortExpr::App {
                                name: Arc::from("Hom"),
                                args: vec![Term::var("y"), Term::var("z")],
                            },
                        ),
                    ],
                    hom_xy,
                ),
            ],
            Vec::new(),
        );
        let config = FreeModelConfig {
            max_depth: 1,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;
        assert_eq!(model.sort_interp["Ob"].len(), 2);
        assert_eq!(
            model.sort_interp["Hom"].len(),
            4,
            "Hom fiber should contain {{id(a), id(b), f, g}}, got {:?}",
            model.sort_interp["Hom"],
        );
        Ok(())
    }

    // Every generated term must typecheck at exactly its recorded fiber.
    #[test]
    fn free_model_every_term_well_typed() -> Result<(), Box<dyn std::error::Error>> {
        use crate::sort::{SortExpr, SortParam};
        use crate::typecheck::{VarContext, typecheck_term};
        let hom_xx = SortExpr::App {
            name: Arc::from("Hom"),
            args: vec![Term::var("x"), Term::var("x")],
        };
        let theory = Theory::new(
            "EndoCat",
            vec![
                Sort::simple("Ob"),
                Sort::dependent(
                    "Hom",
                    vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
                ),
            ],
            vec![
                Operation::nullary("star", "Ob"),
                Operation::unary("id", "x", "Ob", hom_xx.clone()),
                Operation::unary("f", "x", "Ob", hom_xx),
            ],
            Vec::new(),
        );
        let config = FreeModelConfig {
            max_depth: 2,
            max_terms_per_sort: 100,
        };
        let (fibers, _) = generate_terms(&theory, &config)?;
        let ctx = VarContext::default();
        for (fiber, terms) in &fibers {
            for term in terms {
                let inferred = typecheck_term(term, &ctx, &theory)?;
                assert!(
                    inferred.alpha_eq(fiber),
                    "term {term} has fiber {fiber} but typecheck inferred {inferred}",
                );
            }
        }
        Ok(())
    }

    // Simple (non-dependent) sorts behave as in the pre-dependent-sort world.
    #[test]
    fn free_model_simple_sorts_backward_compat() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "Graph",
            vec![Sort::simple("Vertex"), Sort::simple("Edge")],
            vec![
                Operation::nullary("v0", "Vertex"),
                Operation::nullary("v1", "Vertex"),
                Operation::unary("src", "e", "Edge", "Vertex"),
                Operation::unary("tgt", "e", "Edge", "Vertex"),
            ],
            Vec::new(),
        );
        let config = FreeModelConfig {
            max_depth: 1,
            max_terms_per_sort: 100,
        };
        let model = free_model(&theory, &config)?.model;
        assert_eq!(model.sort_interp["Vertex"].len(), 2);
        assert!(model.sort_interp["Edge"].is_empty());
        Ok(())
    }

    // Mutually dependent sorts form a cycle and must be rejected up front.
    #[test]
    fn free_model_cyclic_sort_dependency_rejected() {
        use crate::sort::SortParam;
        let theory = Theory::new(
            "Cyclic",
            vec![
                Sort::dependent("A", vec![SortParam::new("x", "B")]),
                Sort::dependent("B", vec![SortParam::new("y", "A")]),
            ],
            vec![],
            vec![],
        );
        let result = free_model(&theory, &FreeModelConfig::default());
        assert!(
            matches!(result, Err(GatError::CyclicSortDependency(_))),
            "cyclic sort dependencies should be rejected"
        );
    }
}