use voile_util::meta::{MetaSolution, MI};
use voile_util::tags::VarRec;
use crate::syntax::core::{CaseSplit, Closure, Neutral, TraverseNeutral, Val, Variants};
use super::monad::{TCE, TCM, TCS};
use std::cmp::Ordering;
/// Occurs check: fails with `TCE::MetaRecursion` if the meta `meta` appears
/// among the neutral parts of `rhs` visited by `try_fold_neutral`.
///
/// Used before assigning `rhs` as the solution of `meta`, so that a meta is
/// never solved with a term mentioning itself.
fn check_solution(meta: MI, rhs: Val) -> TCM<()> {
    rhs.try_fold_neutral((), |(), neutral| {
        let recursive = matches!(neutral, Neutral::Meta(found) if found == meta);
        if recursive {
            Err(TCE::MetaRecursion(meta))
        } else {
            Ok(())
        }
    })
}
/// Records `solution` as the value of the meta `meta` in the meta context.
///
/// The occurs check (`check_solution`) is run on the solution as given. The
/// value actually stored is the solution after `unimplemented_to_glob`.
///
/// # Errors
/// Propagates `TCE::MetaRecursion` from the occurs check.
fn solve_with(mut tcs: TCS, meta: MI, solution: Val) -> TCM {
    check_solution(meta, solution.clone())?;
    let stored = solution.unimplemented_to_glob();
    tcs.meta_context.solve_meta(meta, stored);
    Ok(tcs)
}
/// Unifies every labelled entry of `subset` with the entry of the same label
/// in `superset`, threading the checker state through each step.
///
/// # Errors
/// `TCE::MissingVariant(kind, label)` when a label of `subset` is absent from
/// `superset`; otherwise any error produced by unifying an entry pair.
fn unify_variants(tcs: TCS, kind: VarRec, subset: &Variants, superset: &Variants) -> TCM {
    let mut state = tcs;
    for (label, sub_ty) in subset {
        let super_ty = match superset.get(label) {
            Some(found) => found,
            None => return Err(TCE::MissingVariant(kind, label.clone())),
        };
        state = state.unify(sub_ty, super_ty)?;
    }
    Ok(state)
}
/// Unifies the values `a` and `b`, threading the checker state through.
///
/// Meta variables on either side are solved, or compared against their
/// existing solution, via `unify_meta_with`. Two neutrals that are not metas
/// are delegated to `unify_neutral`.
///
/// # Errors
/// `TCE::CannotUnify(a, b)` for structurally incompatible values. Errors from
/// nested unification (e.g. `MissingVariant`, `MetaRecursion`) propagate.
fn unify(tcs: TCS, a: &Val, b: &Val) -> TCM {
    use Neutral::{Axi, Meta, Ref};
    use Val::*;
    match (a, b) {
        (Type(sub_level), Type(super_level)) if sub_level == super_level => Ok(tcs),
        (Neut(Axi(sub)), Neut(Axi(sup))) if sub.unique_id() == sup.unique_id() => Ok(tcs),
        (Neut(Ref(x)), Neut(Ref(y))) if x == y => Ok(tcs),
        // A meta is trivially equal to itself. Without this arm, the meta arm
        // below would try to solve `?m := ?m`, which the occurs check rejects.
        (Neut(Meta(x)), Neut(Meta(y))) if x == y => Ok(tcs),
        (Dt(k0, a_plicit, input_a, clos_a), Dt(k1, b_plicit, input_b, clos_b))
            if k0 == k1 && a_plicit == b_plicit =>
        {
            tcs.unify(input_a, input_b)?.unify_closure(clos_a, clos_b)
        }
        (Lam(a), Lam(b)) => unify_closure(tcs, a, b),
        // Constructor applications only unify when they use the same label.
        (Cons(a_label, a), Cons(b_label, b)) if a_label == b_label => tcs.unify(&**a, &**b),
        (Pair(a0, a1), Pair(b0, b1)) => tcs.unify(&**a0, &**b0)?.unify(&**a1, &**b1),
        (RowPoly(a_kind, a_variants), RowPoly(b_kind, b_variants))
            if a_kind == b_kind && a_variants.len() == b_variants.len() =>
        {
            tcs.unify_variants(*a_kind, a_variants, b_variants)
        }
        (RowKind(a_level, a_kind, a_labels), RowKind(b_level, b_kind, b_labels))
            if a_level == b_level && a_kind == b_kind && a_labels.len() == b_labels.len() =>
        {
            if a_labels.iter().all(|n| b_labels.contains(n)) {
                Ok(tcs)
            } else {
                Err(TCE::CannotUnify(a.clone(), b.clone()))
            }
        }
        (Rec(a_fields), Rec(b_fields)) if a_fields.len() == b_fields.len() => {
            tcs.unify_variants(VarRec::Record, a_fields, b_fields)
        }
        // A closed record against an extended one: match the known fields,
        // then unify the leftover fields with the extension.
        (Rec(more), Neut(Neutral::Rec(less, ext))) | (Neut(Neutral::Rec(less, ext)), Rec(more)) => {
            let (more, tcs) = unify_partial_variants(tcs, more.clone(), less, VarRec::Record)?;
            tcs.unify(&Rec(more), &Neut(*ext.clone()))
        }
        (RowPoly(kind0, more), Neut(Neutral::Row(kind1, less, ext)))
        | (Neut(Neutral::Row(kind1, less, ext)), RowPoly(kind0, more))
            if kind0 == kind1 =>
        {
            let (more, tcs) = unify_partial_variants(tcs, more.clone(), less, *kind0)?;
            tcs.unify(&RowPoly(*kind0, more), &Neut(*ext.clone()))
        }
        (term, Neut(Meta(mi))) | (Neut(Meta(mi)), term) => unify_meta_with(tcs, term, *mi),
        (Neut(a), Neut(b)) => tcs.unify_neutral(a, b),
        (e, t) => Err(TCE::CannotUnify(e.clone(), t.clone())),
    }
}
/// Unifies `term` with the meta `mi`.
///
/// An unsolved meta is solved with `term` (after the occurs check in
/// `solve_with`). A solved meta has its solution unified with `term`. The
/// `Inlined` state is treated as impossible here and panics via `unreachable!`.
fn unify_meta_with(tcs: TCS, term: &Val, mi: MI) -> TCM {
    let existing = match &tcs.meta_context.solution(mi) {
        MetaSolution::Unsolved => return solve_with(tcs, mi, term.clone()),
        MetaSolution::Solved(solution) => *solution.clone(),
        MetaSolution::Inlined => unreachable!(),
    };
    tcs.unify(&existing, term)
}
/// Unifies two row-extended collections, `{a_fields | a_more}` and
/// `{b_fields | b_more}`, of the given `kind`.
///
/// When both sides list the same number of labels, the entries are unified
/// pairwise and then the extensions are unified. Otherwise the side with more
/// labels ("wide") absorbs the narrow side's labels. The wide side's leftover
/// entries plus its extension must then equal the narrow side's extension.
///
/// NOTE(review): the leftover is always rebuilt as `Neutral::Row`, a row type.
/// If this is called for record literals (`Neutral::Rec`), that looks wrong.
fn unify_neutral_variants(
    tcs: TCS,
    a_fields: &Variants,
    b_fields: &Variants,
    a_more: &Neutral,
    b_more: &Neutral,
    kind: VarRec,
) -> TCM {
    let (wide, narrow, wide_ext, narrow_ext) = match a_fields.len().cmp(&b_fields.len()) {
        Ordering::Equal => {
            let tcs = tcs.unify_variants(kind, a_fields, b_fields)?;
            return tcs.unify_neutral(a_more, b_more);
        }
        Ordering::Less => (b_fields, a_fields, b_more, a_more),
        Ordering::Greater => (a_fields, b_fields, a_more, b_more),
    };
    let (leftover, tcs) = unify_partial_variants(tcs, wide.clone(), narrow, kind)?;
    let residual = Neutral::Row(kind, leftover, Box::new(wide_ext.clone()));
    tcs.unify_neutral(&residual, narrow_ext)
}
/// Unifies each entry of `less` with the same-labelled entry removed from
/// `more`.
///
/// Returns the entries of `more` that `less` did not mention, together with
/// the updated checker state.
///
/// # Errors
/// `TCE::MissingVariant(kind, label)` when a label of `less` is not present in
/// `more`; otherwise any error from unifying an entry pair.
fn unify_partial_variants(
    tcs: TCS,
    mut more: Variants,
    less: &Variants,
    kind: VarRec,
) -> TCM<(Variants, TCS)> {
    let tcs = less.iter().try_fold(tcs, |state, (label, expected)| {
        let found = more
            .remove(label)
            .ok_or_else(|| TCE::MissingVariant(kind, label.to_owned()))?;
        state.unify(expected, &found)
    })?;
    Ok((more, tcs))
}
/// Unifies two closures by comparing their bodies.
///
/// - Two plain closures are instantiated with the same fresh axiom and their
///   bodies are unified.
/// - Two case trees are compared branch-by-branch via `unify_case_split`.
/// - When exactly one side is a case tree, each branch `label => body` is
///   instantiated with a fresh axiom `p`. The *other* closure is instantiated
///   with the constructor `label p`, and the two results are unified.
fn unify_closure(tcs: TCS, a: &Closure, b: &Closure) -> TCM {
    use Closure::*;
    match (a, b) {
        (Plain(..), Plain(..)) => {
            let p = Val::fresh_axiom();
            let a = a.instantiate_borrow(&p);
            let b = b.instantiate_cloned(p);
            tcs.unify(&a, &b)
        }
        (Tree(split_a), Tree(split_b)) => tcs.unify_case_split(split_a, split_b),
        // Bind the non-tree side explicitly. Previously, when the tree was on
        // the right, `b` (the tree itself) was instantiated, so `a` was never
        // compared at all.
        (Tree(split), other) | (other, Tree(split)) => {
            let mut tcs = tcs;
            for (label, branch) in split {
                let p = Val::fresh_axiom();
                let cons = Val::cons(label.clone(), p.clone());
                let branch_body = branch.instantiate_cloned(p);
                let other_body = other.instantiate_cloned(cons);
                tcs = tcs.unify(&branch_body, &other_body)?;
            }
            Ok(tcs)
        }
    }
}
/// Unifies two case trees label-by-label.
///
/// # Errors
/// - `TCE::CannotUnify` (carrying both trees) when the trees have different
///   numbers of branches.
/// - `TCE::MissingVariant(VarRec::Variant, label)` when a label of `split_a`
///   has no branch in `split_b`.
fn unify_case_split(tcs: TCS, split_a: &CaseSplit, split_b: &CaseSplit) -> TCM {
    if split_a.len() == split_b.len() {
        let mut state = tcs;
        for (label, closure_a) in split_a {
            state = match split_b.get(label) {
                Some(closure_b) => state.unify_closure(closure_a, closure_b)?,
                None => return Err(TCE::MissingVariant(VarRec::Variant, label.clone())),
            };
        }
        Ok(state)
    } else {
        Err(TCE::CannotUnify(
            Val::case_tree(split_a.clone()),
            Val::case_tree(split_b.clone()),
        ))
    }
}
fn unify_neutral(tcs: TCS, a: &Neutral, b: &Neutral) -> TCM {
use Neutral::*;
match (a, b) {
(Ref(x), Ref(y)) if x == y => Ok(tcs),
(Lift(x, a), Lift(y, b)) if x == y => tcs.unify_neutral(&**a, &**b),
(App(f, a), App(g, b)) if a.len() == b.len() => (a.iter().zip(b.iter()))
.try_fold(tcs.unify_neutral(&*f, &*g)?, |tcs, (x, y)| tcs.unify(x, y)),
(Rec(a_f, a_more), Rec(b_f, b_more)) => {
unify_neutral_variants(tcs, a_f, b_f, &**a_more, &**b_more, VarRec::Record)
}
(Row(a_kind, a_fields, a_more), Row(b_kind, b_fields, b_more)) if a_kind == b_kind => {
unify_neutral_variants(tcs, a_fields, b_fields, &**a_more, &**b_more, *a_kind)
}
(Snd(a), Snd(b)) | (Fst(a), Fst(b)) => tcs.unify_neutral(&**a, &**b),
(Proj(a, lab_a), Proj(b, lab_b)) if lab_a == lab_b => tcs.unify_neutral(&**a, &**b),
(SplitOn(split_a, a), SplitOn(split_b, b)) | (OrSplit(split_a, a), OrSplit(split_b, b)) => {
tcs.unify_case_split(split_a, split_b)?
.unify_neutral(&**a, &**b)
}
(Axi(a), Axi(b)) if a.unique_id() == b.unique_id() => Ok(tcs),
(Meta(mi), sol) | (sol, Meta(mi)) => unify_meta_with(tcs, &Val::Neut(sol.clone()), *mi),
(e, t) => Err(TCE::CannotUnify(Val::Neut(e.clone()), Val::Neut(t.clone()))),
}
}
/// Method-call wrappers over the free unification functions, so that calls
/// can be chained on the checker state (`tcs.unify(..)?.unify(..)`).
impl TCS {
    /// Unifies two values. See the free function `unify`.
    #[inline]
    pub fn unify(self, lhs: &Val, rhs: &Val) -> TCM {
        unify(self, lhs, rhs)
    }

    /// Unifies two neutral values. See the free function `unify_neutral`.
    #[inline]
    fn unify_neutral(self, lhs: &Neutral, rhs: &Neutral) -> TCM {
        unify_neutral(self, lhs, rhs)
    }

    /// Unifies two closures. See the free function `unify_closure`.
    #[inline]
    fn unify_closure(self, lhs: &Closure, rhs: &Closure) -> TCM {
        unify_closure(self, lhs, rhs)
    }

    /// Unifies two case trees. See the free function `unify_case_split`.
    #[inline]
    fn unify_case_split(self, lhs: &CaseSplit, rhs: &CaseSplit) -> TCM {
        unify_case_split(self, lhs, rhs)
    }

    /// Unifies every entry of `subset` with the same-labelled entry of
    /// `superset`. See the free function `unify_variants`.
    #[inline]
    pub(crate) fn unify_variants(
        self,
        kind: VarRec,
        subset: &Variants,
        superset: &Variants,
    ) -> TCM {
        unify_variants(self, kind, subset, superset)
    }
}