use std::sync::Arc;
use rustc_hash::FxHashMap;
use crate::eq::{Equation, Term};
use crate::error::GatError;
use crate::theory::Theory;
/// Typing context mapping each variable name to the name of its sort.
///
/// [`typecheck_term`] reads variable sorts from it, and [`infer_var_sorts`]
/// builds one from the argument positions variables occupy in an equation.
pub type VarContext = FxHashMap<Arc<str>, Arc<str>>;
/// Computes the sort of `term` under the variable sorts recorded in `ctx`.
///
/// A variable's sort is read from `ctx`. For an application `op(args...)`:
/// - the operation is looked up in `theory`;
/// - the argument count is compared with the operation's input count;
/// - each argument is checked recursively, left to right, against the
///   declared input sort at its position;
/// - the operation's declared output sort is returned.
///
/// # Errors
///
/// - [`GatError::UnboundVariable`] if a variable has no entry in `ctx`.
/// - [`GatError::OpNotFound`] if `theory` has no operation named `op`.
/// - [`GatError::TermArityMismatch`] if the number of arguments differs from
///   the operation's number of inputs.
/// - [`GatError::ArgTypeMismatch`] for the first argument whose sort differs
///   from the declared input sort. Errors raised while checking an argument
///   itself are propagated unchanged.
pub fn typecheck_term(
    term: &Term,
    ctx: &VarContext,
    theory: &Theory,
) -> Result<Arc<str>, GatError> {
    let (op, args) = match term {
        Term::Var(name) => {
            return match ctx.get(name) {
                Some(sort) => Ok(Arc::clone(sort)),
                None => Err(GatError::UnboundVariable(name.to_string())),
            };
        }
        Term::App { op, args } => (op, args),
    };

    let operation = match theory.find_op(op) {
        Some(found) => found,
        None => return Err(GatError::OpNotFound(op.to_string())),
    };

    let arity = operation.inputs.len();
    if args.len() != arity {
        return Err(GatError::TermArityMismatch {
            op: op.to_string(),
            expected: arity,
            got: args.len(),
        });
    }

    // Pair every argument with its declared input sort, keeping its position
    // for error reporting.
    let signature = args.iter().zip(operation.inputs.iter()).enumerate();
    for (arg_index, (arg, (_, declared))) in signature {
        let actual = typecheck_term(arg, ctx, theory)?;
        if actual != *declared {
            return Err(GatError::ArgTypeMismatch {
                op: op.to_string(),
                arg_index,
                expected: declared.to_string(),
                got: actual.to_string(),
            });
        }
    }

    Ok(Arc::clone(&operation.output))
}
/// Infers a sort for each variable that appears as an operation argument in `eq`.
///
/// The left-hand side is scanned first, then the right-hand side. Both sides
/// write into one shared context, so a variable required at different sorts
/// on the two sides is reported as a conflict. A variable that only appears
/// as an entire side of the equation gets no entry from that occurrence.
///
/// # Errors
///
/// Propagates the first error from constraint collection:
/// [`GatError::OpNotFound`] for an unknown operation, or
/// [`GatError::ConflictingVarSort`] when a variable is required at two
/// different sorts.
pub fn infer_var_sorts(eq: &Equation, theory: &Theory) -> Result<VarContext, GatError> {
    let mut inferred = VarContext::default();
    for side in [&eq.lhs, &eq.rhs] {
        collect_constraints(side, theory, &mut inferred)?;
    }
    Ok(inferred)
}
/// Records in `ctx` the sort each variable must have, based on the argument
/// position it occupies within `term`.
///
/// Only variables that appear directly as arguments of an operation are
/// constrained. A bare variable at the root of `term` adds nothing. Arguments
/// are paired with the operation's declared inputs via `zip`, so this
/// function does not validate arity ([`typecheck_term`] does). Nested
/// applications are visited recursively, left to right.
///
/// # Errors
///
/// - [`GatError::OpNotFound`] if an operation is missing from `theory`.
/// - [`GatError::ConflictingVarSort`] if a variable already in `ctx` is
///   required at a different sort. `sort1` is the recorded sort and `sort2`
///   is the newly required one.
fn collect_constraints(term: &Term, theory: &Theory, ctx: &mut VarContext) -> Result<(), GatError> {
    let (op, args) = match term {
        Term::App { op, args } => (op, args),
        // A variable that is not an operation argument carries no constraint.
        Term::Var(_) => return Ok(()),
    };

    let operation = match theory.find_op(op) {
        Some(found) => found,
        None => return Err(GatError::OpNotFound(op.to_string())),
    };

    for (arg, (_, want)) in args.iter().zip(operation.inputs.iter()) {
        match arg {
            Term::App { .. } => collect_constraints(arg, theory, ctx)?,
            Term::Var(var_name) => match ctx.get(var_name) {
                Some(existing) if existing != want => {
                    return Err(GatError::ConflictingVarSort {
                        var: var_name.to_string(),
                        sort1: existing.to_string(),
                        sort2: want.to_string(),
                    });
                }
                // Already recorded at the same sort: nothing to do.
                Some(_) => {}
                None => {
                    ctx.insert(Arc::clone(var_name), Arc::clone(want));
                }
            },
        }
    }

    Ok(())
}
pub fn typecheck_equation(eq: &Equation, theory: &Theory) -> Result<(), GatError> {
let ctx = infer_var_sorts(eq, theory)?;
let lhs_sort = typecheck_term(&eq.lhs, &ctx, theory)?;
let rhs_sort = typecheck_term(&eq.rhs, &ctx, theory)?;
if lhs_sort != rhs_sort {
return Err(GatError::EquationSortMismatch {
equation: eq.name.to_string(),
lhs_sort: lhs_sort.to_string(),
rhs_sort: rhs_sort.to_string(),
});
}
Ok(())
}
/// Typechecks every equation of `theory` in declaration order.
///
/// # Errors
///
/// Stops at, and returns, the first error reported by
/// [`typecheck_equation`].
pub fn typecheck_theory(theory: &Theory) -> Result<(), GatError> {
    theory
        .eqs
        .iter()
        .try_for_each(|eq| typecheck_equation(eq, theory))
}
// Unit tests for term, equation, and theory typechecking, followed by
// proptest-based property tests in the nested `property` module.
#[cfg(test)]
mod tests {
use super::*;
use crate::eq::Term;
use crate::op::Operation;
use crate::sort::Sort;
use crate::theory::Theory;
/// Monoid signature: sort `Carrier`, binary `mul` and nullary `unit` (both
/// returning `Carrier`), with equations `assoc`, `left_id`, and `right_id`.
fn monoid_theory() -> Theory {
let carrier = Sort::simple("Carrier");
let mul = Operation::new(
"mul",
vec![
("a".into(), "Carrier".into()),
("b".into(), "Carrier".into()),
],
"Carrier",
);
let unit = Operation::nullary("unit", "Carrier");
// mul(a, mul(b, c)) = mul(mul(a, b), c)
let assoc = Equation::new(
"assoc",
Term::app(
"mul",
vec![
Term::var("a"),
Term::app("mul", vec![Term::var("b"), Term::var("c")]),
],
),
Term::app(
"mul",
vec![
Term::app("mul", vec![Term::var("a"), Term::var("b")]),
Term::var("c"),
],
),
);
// mul(unit, a) = a
let left_id = Equation::new(
"left_id",
Term::app("mul", vec![Term::constant("unit"), Term::var("a")]),
Term::var("a"),
);
// mul(a, unit) = a
let right_id = Equation::new(
"right_id",
Term::app("mul", vec![Term::var("a"), Term::constant("unit")]),
Term::var("a"),
);
Theory::new(
"Monoid",
vec![carrier],
vec![mul, unit],
vec![assoc, left_id, right_id],
)
}
/// Two sorts `A` and `B`, unary `f: A -> B` and `g: B -> A`, constant
/// `a0: A`, and no equations.
fn two_sort_theory() -> Theory {
Theory::new(
"TwoSort",
vec![Sort::simple("A"), Sort::simple("B")],
vec![
Operation::unary("f", "x", "A", "B"),
Operation::unary("g", "x", "B", "A"),
Operation::nullary("a0", "A"),
],
vec![],
)
}
/// A variable's sort is taken from the context.
#[test]
fn typecheck_variable() -> Result<(), Box<dyn std::error::Error>> {
let theory = monoid_theory();
let mut ctx = VarContext::default();
ctx.insert(Arc::from("x"), Arc::from("Carrier"));
let sort = typecheck_term(&Term::var("x"), &ctx, &theory)?;
assert_eq!(&*sort, "Carrier");
Ok(())
}
/// A variable missing from the context is an `UnboundVariable` error.
#[test]
fn typecheck_unbound_variable() {
let theory = monoid_theory();
let ctx = VarContext::default();
let result = typecheck_term(&Term::var("z"), &ctx, &theory);
assert!(matches!(result, Err(GatError::UnboundVariable(_))));
}
/// A nullary operation typechecks to its output sort with an empty context.
#[test]
fn typecheck_constant() -> Result<(), Box<dyn std::error::Error>> {
let theory = monoid_theory();
let ctx = VarContext::default();
let sort = typecheck_term(&Term::constant("unit"), &ctx, &theory)?;
assert_eq!(&*sort, "Carrier");
Ok(())
}
/// `mul(a, b)` with both variables in `Carrier` has sort `Carrier`.
#[test]
fn typecheck_binary_op() -> Result<(), Box<dyn std::error::Error>> {
let theory = monoid_theory();
let mut ctx = VarContext::default();
ctx.insert(Arc::from("a"), Arc::from("Carrier"));
ctx.insert(Arc::from("b"), Arc::from("Carrier"));
let sort = typecheck_term(
&Term::app("mul", vec![Term::var("a"), Term::var("b")]),
&ctx,
&theory,
)?;
assert_eq!(&*sort, "Carrier");
Ok(())
}
/// Applying binary `mul` to a single argument is a `TermArityMismatch`.
#[test]
fn typecheck_arity_mismatch() {
let theory = monoid_theory();
let mut ctx = VarContext::default();
ctx.insert(Arc::from("a"), Arc::from("Carrier"));
let result = typecheck_term(&Term::app("mul", vec![Term::var("a")]), &ctx, &theory);
assert!(matches!(result, Err(GatError::TermArityMismatch { .. })));
}
/// Passing `x: B` to `f`, which expects an `A`, is an `ArgTypeMismatch`.
#[test]
fn typecheck_sort_mismatch() {
let theory = two_sort_theory();
let mut ctx = VarContext::default();
ctx.insert(Arc::from("x"), Arc::from("B"));
let result = typecheck_term(&Term::app("f", vec![Term::var("x")]), &ctx, &theory);
assert!(matches!(result, Err(GatError::ArgTypeMismatch { .. })));
}
/// `g(f(a0))` has sort `A`.
#[test]
fn typecheck_nested_term() -> Result<(), Box<dyn std::error::Error>> {
let theory = two_sort_theory();
let ctx = VarContext::default();
let term = Term::app("g", vec![Term::app("f", vec![Term::constant("a0")])]);
let sort = typecheck_term(&term, &ctx, &theory)?;
assert_eq!(&*sort, "A");
Ok(())
}
/// In `f(f(a0))` the inner `f(a0)` has sort `B`, so the outer `f` rejects it.
#[test]
fn typecheck_nested_sort_mismatch() {
let theory = two_sort_theory();
let ctx = VarContext::default();
let term = Term::app("f", vec![Term::app("f", vec![Term::constant("a0")])]);
let result = typecheck_term(&term, &ctx, &theory);
assert!(matches!(result, Err(GatError::ArgTypeMismatch { .. })));
}
/// A constant naming an operation absent from the theory is `OpNotFound`.
#[test]
fn typecheck_unknown_op() {
let theory = monoid_theory();
let ctx = VarContext::default();
let result = typecheck_term(&Term::constant("nonexistent"), &ctx, &theory);
assert!(matches!(result, Err(GatError::OpNotFound(_))));
}
/// The associativity law mentions `a`, `b`, `c`, all inferred as `Carrier`.
#[test]
fn infer_var_sorts_monoid() -> Result<(), Box<dyn std::error::Error>> {
let theory = monoid_theory();
let eq = &theory.eqs[0]; let ctx = infer_var_sorts(eq, &theory)?;
assert_eq!(ctx.len(), 3);
assert_eq!(&*ctx[&Arc::from("a")], "Carrier");
assert_eq!(&*ctx[&Arc::from("b")], "Carrier");
assert_eq!(&*ctx[&Arc::from("c")], "Carrier");
Ok(())
}
/// `left_id` mentions only `a`. Its sort comes from the `mul` argument on
/// the left; the bare `a` on the right adds no entry.
#[test]
fn infer_var_sorts_identity_law() -> Result<(), Box<dyn std::error::Error>> {
let theory = monoid_theory();
let eq = &theory.eqs[1]; let ctx = infer_var_sorts(eq, &theory)?;
assert_eq!(ctx.len(), 1);
assert_eq!(&*ctx[&Arc::from("a")], "Carrier");
Ok(())
}
/// `x` is required as an `A` (argument of `f`) and as a `B` (argument of `g`).
#[test]
fn conflicting_var_sort() {
let theory = two_sort_theory();
let eq = Equation::new(
"bogus",
Term::app("f", vec![Term::var("x")]),
Term::app("g", vec![Term::var("x")]),
);
let result = infer_var_sorts(&eq, &theory);
assert!(matches!(result, Err(GatError::ConflictingVarSort { .. })));
}
/// Every monoid equation typechecks.
#[test]
fn typecheck_monoid_equations() -> Result<(), Box<dyn std::error::Error>> {
let theory = monoid_theory();
typecheck_theory(&theory)?;
Ok(())
}
/// `f(a0)` has sort `B` while `a0` has sort `A`, so the equation is rejected.
#[test]
fn typecheck_equation_sort_mismatch() {
let theory = two_sort_theory();
let eq = Equation::new(
"bad",
Term::app("f", vec![Term::constant("a0")]), Term::constant("a0"), );
let result = typecheck_equation(&eq, &theory);
assert!(matches!(result, Err(GatError::EquationSortMismatch { .. })));
}
/// A graph theory with operations but no equations typechecks trivially.
#[test]
fn typecheck_graph_theory() -> Result<(), Box<dyn std::error::Error>> {
let theory = Theory::new(
"Graph",
vec![Sort::simple("Vertex"), Sort::simple("Edge")],
vec![
Operation::unary("src", "e", "Edge", "Vertex"),
Operation::unary("tgt", "e", "Edge", "Vertex"),
],
vec![],
);
typecheck_theory(&theory)?;
Ok(())
}
/// `src(id(v)) = v` and `tgt(id(v)) = v` both have sort `Vertex` on each side.
#[test]
fn typecheck_reflexive_graph_equations() -> Result<(), Box<dyn std::error::Error>> {
let theory = Theory::new(
"ReflexiveGraph",
vec![Sort::simple("Vertex"), Sort::simple("Edge")],
vec![
Operation::unary("src", "e", "Edge", "Vertex"),
Operation::unary("tgt", "e", "Edge", "Vertex"),
Operation::unary("id", "v", "Vertex", "Edge"),
],
vec![
Equation::new(
"src_id",
Term::app("src", vec![Term::app("id", vec![Term::var("v")])]),
Term::var("v"),
),
Equation::new(
"tgt_id",
Term::app("tgt", vec![Term::app("id", vec![Term::var("v")])]),
Term::var("v"),
),
],
);
typecheck_theory(&theory)?;
Ok(())
}
/// The edge-reversal laws for `inv` (swapping src/tgt, involution) typecheck.
#[test]
fn typecheck_symmetric_graph_equations() -> Result<(), Box<dyn std::error::Error>> {
let theory = Theory::new(
"SymmetricGraph",
vec![Sort::simple("Vertex"), Sort::simple("Edge")],
vec![
Operation::unary("src", "e", "Edge", "Vertex"),
Operation::unary("tgt", "e", "Edge", "Vertex"),
Operation::unary("inv", "e", "Edge", "Edge"),
],
vec![
Equation::new(
"src_inv",
Term::app("src", vec![Term::app("inv", vec![Term::var("e")])]),
Term::app("tgt", vec![Term::var("e")]),
),
Equation::new(
"tgt_inv",
Term::app("tgt", vec![Term::app("inv", vec![Term::var("e")])]),
Term::app("src", vec![Term::var("e")]),
),
Equation::new(
"inv_inv",
Term::app("inv", vec![Term::app("inv", vec![Term::var("e")])]),
Term::var("e"),
),
],
);
typecheck_theory(&theory)?;
Ok(())
}
// Property tests over randomly generated theories.
mod property {
use super::*;
use proptest::prelude::*;
// Candidate sort names the generator samples from.
const SORT_POOL: &[&str] = &["S0", "S1", "S2", "S3"];
// Generates a theory with 1-4 sorts drawn from SORT_POOL and 0-3 unary
// operations `op0`, `op1`, ..., each with input and output sorts chosen
// from those sorts. No equations are generated.
// NOTE(review): the first tuple component (`0..4usize`) is never used, and
// the `seen` dedup can never skip, because names are built from distinct
// indices `i`.
fn arb_well_typed_theory() -> impl Strategy<Value = Theory> {
prop::sample::subsequence(SORT_POOL, 1..=4).prop_flat_map(|sort_names| {
let sorts: Vec<Sort> = sort_names.iter().map(|s| Sort::simple(*s)).collect();
let sn: Vec<String> = sort_names.iter().map(|s| (*s).to_owned()).collect();
let sn2 = sn.clone();
(
Just(sorts),
prop::collection::vec(
(
0..4usize,
prop::sample::select(sn),
prop::sample::select(sn2),
),
0..=3,
),
)
.prop_map(|(sorts, op_specs)| {
let mut ops = Vec::new();
let mut seen = std::collections::HashSet::new();
for (i, (_, input_sort, output_sort)) in op_specs.iter().enumerate() {
let name = format!("op{i}");
if !seen.insert(name.clone()) {
continue;
}
ops.push(Operation::unary(
&*name,
"x",
input_sort.as_str(),
output_sort.as_str(),
));
}
Theory::new("TypecheckTest", sorts, ops, Vec::new())
})
})
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(256))]
// Typechecking the same theory twice yields the same success/failure outcome.
#[test]
fn typecheck_is_idempotent(t in arb_well_typed_theory()) {
let result1 = typecheck_theory(&t);
let result2 = typecheck_theory(&t);
prop_assert_eq!(result1.is_ok(), result2.is_ok());
}
// Generated theories have no equations, so typechecking is expected to
// succeed. NOTE(review): this only exercises the empty-equation path.
#[test]
fn well_typed_theory_passes(t in arb_well_typed_theory()) {
prop_assert!(
typecheck_theory(&t).is_ok(),
"well-typed theory should pass typecheck",
);
}
}
}
}