use voile_util::level::{Level, LevelType, LiftEx};
use voile_util::loc::*;
use voile_util::meta::MetaSolution;
use voile_util::tags::{PiSig::*, Plicit, VarRec, VarRec::*};
use crate::syntax::abs::{Abs, LabAbs};
use crate::syntax::core::{CaseSplit, Closure, Fields, Neutral, Val, Variants, TYPE_OMEGA};
use super::eval::compile_cons;
use super::monad::{ValTCM, TCE, TCM, TCS};
/// Checks `expr` against `expected_type`, returning the elaborated value
/// (wrapped with a source location) together with the updated checker state.
///
/// Dispatches on the pair `(expr, expected_type)`. Any combination without a
/// dedicated rule falls through to `check_fallback`, which infers the type of
/// `expr` and then requires it to be a subtype of `expected_type`.
///
/// # Errors
/// Returns a `TCE` when a rule rejects the expression. Errors from recursive
/// checks are usually wrapped with the location of the enclosing syntax.
fn check(mut tcs: TCS, expr: &Abs, expected_type: &Val) -> ValTCM {
    use Abs::*;
    match (expr, expected_type) {
        // A universe `Type(lower)` inhabits `Type(upper)` only when `upper > lower`.
        // On failure the error reports `lower + 1`, the smallest level that would fit.
        (Type(info, lower), Val::Type(upper)) => {
            if upper > lower {
                Ok((Val::Type(*lower).into_info(*info), tcs))
            } else {
                Err(TCE::LevelMismatch(expr.loc(), *lower + 1, *upper))
            }
        }
        // Row kinds are accepted in any universe strictly above level 0.
        (RowKind(info, kind, labels), Val::Type(upper)) if *upper > From::from(0 as LevelType) => {
            let labels = labels.iter().map(|l| &l.text).cloned().collect();
            let expr = Val::RowKind(Default::default(), *kind, labels);
            Ok((expr.into_info(*info), tcs))
        }
        // A meta is solved eagerly when `mock_for` finds a canonical inhabitant of
        // the expected type. Comparing against the fresh axiom (the fallback)
        // tells whether `mock_for` produced something or fell back.
        (Meta(ident, mi), ty) => {
            let axiom = Val::fresh_axiom();
            let mocked = mock_for(ty, || axiom.clone());
            if mocked != axiom {
                tcs.meta_context.solve_meta(*mi, mocked);
            }
            Ok((Val::meta(*mi).into_info(ident.loc), tcs))
        }
        // Dependent pair: the type of the second component is the sigma's closure
        // instantiated with the elaborated first component.
        (Pair(info, fst, snd), Val::Dt(Sigma, Plicit::Ex, param_ty, closure)) => {
            let (fst_term, mut tcs) = tcs.check(&**fst, &**param_ty).map_err(|e| e.wrap(*info))?;
            let fst_term_ast = fst_term.ast.clone();
            let snd_ty = closure.instantiate_borrow(&fst_term_ast);
            let param_type = param_ty.clone().into_info(fst_term.loc());
            // NOTE(review): `Pair` has no binder in the syntax, yet `snd` is checked
            // with an extra local pushed here. Confirm this does not shift the
            // de Bruijn indices that `snd` refers to.
            tcs.local_gamma.push(param_type);
            tcs.local_env.push(fst_term);
            let (snd_term, mut tcs) = tcs.check(&**snd, &snd_ty).map_err(|e| e.wrap(*info))?;
            tcs.pop_local();
            let pair = Val::pair(fst_term_ast, snd_term.ast).into_info(*info);
            Ok((pair, tcs))
        }
        // Explicit lambda: bind the parameter to its type and to a mock value.
        // The mock is a canonical inhabitant if `mock_for` finds one, otherwise a
        // postulate named by `uid`. The body is then checked against the return
        // type instantiated with that value.
        (Lam(full_loc, param_loc, uid, body), Val::Dt(Pi, Plicit::Ex, param_ty, ret_ty)) => {
            let param_type = param_ty.clone().into_info(param_loc.loc);
            tcs.local_gamma.push(param_type);
            let mocked = mock_for(&**param_ty, || Val::postulate(*uid));
            let mocked_term = mocked.clone().into_info(param_loc.loc);
            tcs.local_env.push(mocked_term);
            let ret_ty_body = ret_ty.instantiate_cloned(mocked);
            let (lam_term, mut tcs) = tcs
                .check(body, &ret_ty_body)
                .map_err(|e| e.wrap(*full_loc))?;
            tcs.pop_local();
            let lam = Val::closure_lam(lam_term.ast);
            Ok((lam.into_info(*full_loc), tcs))
        }
        // Lambda against an implicit pi: introduce the implicit parameter, mocked
        // with a fresh implicit, and re-check the *same* lambda against the
        // instantiated return type. The result is returned as-is, with no extra
        // `closure_lam` wrapper.
        (Lam(..), Val::Dt(Pi, Plicit::Im, param_ty, ret_ty)) => {
            let param_type = param_ty.clone().into_info(Default::default());
            tcs.local_gamma.push(param_type);
            let mocked = mock_for(&**param_ty, Val::fresh_implicit);
            let mocked_term = mocked.clone().into_info(Default::default());
            tcs.local_env.push(mocked_term);
            let ret_ty_body = ret_ty.instantiate_cloned(mocked);
            let (lam, mut tcs) = tcs.check(&expr.clone(), &ret_ty_body)?;
            tcs.pop_local();
            Ok((lam, tcs))
        }
        // A constructor checked against any pi type is compiled directly.
        (Cons(info), Val::Dt(Pi, ..)) => Ok((compile_cons(info.clone()), tcs)),
        // Pi/Sigma formation. The parameter type and the return type are both
        // checked against the same expected universe. The return type is checked
        // with the parameter bound to a postulate named by `uid`.
        (Dt(info, kind, uid, param_plicit, param, ret), Val::Type(..)) => {
            let (param, mut tcs) = tcs
                .check(&**param, expected_type)
                .map_err(|e| e.wrap(*info))?;
            tcs.local_gamma.push(param.clone());
            let axiom = Val::postulate(*uid).into_info(param.loc());
            tcs.local_env.push(axiom);
            let (ret, mut tcs) = tcs
                .check(&**ret, expected_type)
                .map_err(|e| e.wrap(*info))?;
            tcs.pop_local();
            let dt = Val::closure_dependent_type(*kind, *param_plicit, param.ast, ret.ast)
                .into_info(*info);
            Ok((dt, tcs))
        }
        // Row types checked against a row kind of the same flavour. The kind's
        // level is used for the entries, and its labels must not appear in the row.
        (RowPoly(info, Record, variants, ext), Val::RowKind(l, Record, labels)) => {
            check_row_polymorphic_type(tcs, *info, *l, Record, variants, ext, labels)
        }
        (RowPoly(info, Variant, variants, ext), Val::RowKind(l, Variant, labels)) => {
            check_row_polymorphic_type(tcs, *info, *l, Variant, variants, ext, labels)
        }
        // Row types checked against a universe: no labels are excluded.
        (RowPoly(info, kind, variants, ext), Val::Type(l)) => {
            check_row_polymorphic_type(tcs, *info, *l, *kind, variants, ext, &[])
        }
        // Record literal against a closed record row. With an extension, the
        // extension must supply exactly the remaining fields. Without one, every
        // expected field must have been given. Fields whose labels are not in
        // the expected row are skipped by `check_fields`.
        (Rec(info, fields, more), Val::RowPoly(Record, field_types)) => {
            let (nice_fields, rest_field_types, tcs) = check_fields(tcs, fields, field_types)?;
            match more {
                Some(more) => {
                    let more_type = Val::record_type(rest_field_types);
                    let (more, tcs) = tcs.check(&**more, &more_type)?;
                    let record = Val::Rec(nice_fields).rec_extend(more.ast);
                    Ok((record.into_info(*info), tcs))
                }
                None => check_fields_no_more(*info, nice_fields, rest_field_types, tcs),
            }
        }
        // Record literal against a record row with a neutral tail. The extension
        // is checked against the remaining fields plus that tail, or against the
        // tail alone when no fields remain.
        (Rec(info, fields, more), Val::Neut(Neutral::Row(Record, field_types, more_types))) => {
            let (nice_fields, rest_field_types, tcs) = check_fields(tcs, fields, field_types)?;
            match more {
                Some(more) => {
                    let more_type = if rest_field_types.is_empty() {
                        Val::Neut(*more_types.clone())
                    } else {
                        Val::neutral_record_type(rest_field_types, *more_types.clone())
                    };
                    let (more, tcs) = tcs.check(&**more, &more_type)?;
                    let record = Val::Rec(nice_fields).rec_extend(more.ast);
                    Ok((record.into_info(*info), tcs))
                }
                None => check_fields_no_more(*info, nice_fields, rest_field_types, tcs),
            }
        }
        // Lifted expression: check the inner expression against the expected type
        // lowered by `levels` (`fall`), then lift the elaborated result back up.
        (Lift(info, levels, expr), anything) => {
            let anything = anything.clone();
            let (expr, tcs) = tcs
                .check(&**expr, &anything.fall(*levels))
                .map_err(|e| e.wrap(*info))?;
            Ok((expr.map_ast(|ast| ast.lift(*levels)), tcs))
        }
        // `whatever` is accepted only as a function whose parameter type is the
        // empty variant row. It elaborates to a lambda over the default closure.
        (Whatever(info), Val::Dt(Pi, _, param_ty, ..)) => match &**param_ty {
            Val::RowPoly(Variant, variants) if variants.is_empty() => {
                Ok((Val::Lam(Closure::default()).into_info(*info), tcs))
            }
            ty => Err(TCE::NotEmpty(*info, ty.clone())),
        },
        // One case branch plus the remaining alternatives `or`.
        // - `label` is removed from the parameter's variant row (closed or with a
        //   neutral tail). The branch body is checked as a lambda over that
        //   label's payload type.
        // - `or` is checked as a function over the row *without* `label`.
        // - The branch is merged into `or` via `split_extend`.
        (CaseOr(label, binding, uid, body, or), Val::Dt(Pi, Plicit::Ex, param_ty, ret_ty)) => {
            let lam_info = merge_info(binding, &**body);
            let lam = Lam(lam_info, binding.clone(), *uid, body.clone());
            let (variants, ext) = match &**param_ty {
                Val::Neut(Neutral::Row(Variant, variants, ext)) => (variants, Some(&**ext)),
                Val::RowPoly(Variant, variants) => (variants, None),
                ty => {
                    let info = merge_info(label, &**or);
                    return Err(TCE::NotRowType(Variant, info, ty.clone()));
                }
            };
            let mut variants = variants.clone();
            let param_ty = variants
                .remove(&label.text)
                .ok_or_else(|| TCE::MissingVariant(Variant, label.text.clone()))?;
            let input = match ext {
                None => Val::variant_type(variants),
                Some(ext) => Val::neutral_variant_type(variants, ext.clone()),
            };
            let stripped_function = Val::pi(Plicit::Ex, input, ret_ty.clone());
            let dt = Val::pi(Plicit::Ex, param_ty, ret_ty.clone());
            let (body, tcs) = tcs.check(&lam, &dt)?;
            let mut split = CaseSplit::default();
            split.insert(label.text.clone(), Closure::plain(body.ast));
            let ext = Val::case_tree(split);
            let (or, tcs) = tcs.check(&**or, &stripped_function)?;
            Ok((or.ast.split_extend(ext).into_info(or.loc), tcs))
        }
        // No dedicated rule: infer, then require a subtype of the expected type.
        (expr, anything) => check_fallback(tcs, expr, anything),
    }
}
fn mock_for(param_ty: &Val, fallback: impl FnOnce() -> Val) -> Val {
use Val::*;
fn go(param_ty: &Val) -> Option<Val> {
match param_ty {
RowPoly(Record, v) if v.is_empty() => Some(Rec(Default::default())),
RowPoly(Variant, v) if v.len() == 1 => {
let (name, ty) = v.iter().next().unwrap();
Some(Val::cons(name.clone(), go(ty)?))
}
_ => None,
}
}
go(param_ty).unwrap_or_else(fallback)
}
/// Checks `expr` when no dedicated rule applies: infers its type, then requires
/// that type to be a subtype of `expected_type`. On success, `expr` is evaluated
/// in the resulting state.
///
/// # Errors
/// Propagates inference errors. Subtyping errors are wrapped with the location
/// of the inferred type.
fn check_fallback(tcs: TCS, expr: &Abs, expected_type: &Val) -> ValTCM {
    let (inferred, tcs) = tcs.infer(expr)?;
    let inferred_loc = inferred.loc;
    let tcs = tcs
        .subtype(&inferred.ast, expected_type)
        .map_err(|e| e.wrap(inferred_loc))?;
    Ok(tcs.evaluate(expr.clone()))
}
/// Finishes a record literal that has no extension.
///
/// Succeeds with the record built from `nice_fields` only if every expected
/// field was provided, i.e. `rest_field_types` is empty.
///
/// # Errors
/// `TCE::MissingVariant(Record, _)` naming the first field still outstanding.
fn check_fields_no_more(
    info: Loc,
    nice_fields: Fields,
    rest_field_types: Variants,
    tcs: TCS,
) -> ValTCM {
    if let Some(missing) = rest_field_types.keys().next() {
        return Err(TCE::MissingVariant(Record, missing.clone()));
    }
    let record = Val::Rec(nice_fields);
    Ok((record.into_info(info), tcs))
}
/// Checks the fields of a record literal against the expected `field_types`.
///
/// Only fields whose label appears in `field_types` are checked and kept.
/// Fields with unknown labels are skipped without being checked.
///
/// Returns a tuple of:
/// - the elaborated fields;
/// - the expected field types that the literal did not provide;
/// - the threaded state.
///
/// # Errors
/// Propagates the first error from checking a field.
fn check_fields(
    tcs: TCS,
    fields: &[LabAbs],
    field_types: &Fields,
) -> TCM<(Fields, Variants, TCS)> {
    let mut state = tcs;
    let mut checked = Fields::new();
    for field in fields {
        let label = &field.label.text;
        let expected = match field_types.get(label) {
            Some(expected) => expected,
            None => continue,
        };
        let (term, next_state) = state.check(&field.expr, expected)?;
        state = next_state;
        checked.insert(label.clone(), term.ast);
    }
    let remaining = field_types
        .iter()
        .filter(|(label, _)| !checked.contains_key(label.as_str()))
        .map(|(label, ty)| (label.clone(), ty.clone()))
        .collect();
    Ok((checked, remaining, state))
}
/// Checks a row type (`kind` selects record or variant).
///
/// Every entry must have type `Type(level)`, and no label may be repeated.
/// Labels listed in `labels` are already claimed elsewhere, so they must not
/// appear either. An extension `ext`, if present, is checked against a row kind
/// that lists every label known so far.
///
/// # Errors
/// - `TCE::OverlappingVariant` for a repeated label;
/// - `TCE::UnexpectedVariant` for a label listed in `labels`;
/// - any (location-wrapped) error from checking an entry or the extension.
fn check_row_polymorphic_type(
    tcs: TCS,
    info: Loc,
    level: Level,
    kind: VarRec,
    variants: &[LabAbs],
    ext: &Option<Box<Abs>>,
    labels: &[String],
) -> ValTCM {
    let entry_type = Val::Type(level);
    let mut state = tcs;
    let mut row = Variants::new();
    for entry in variants {
        let (checked, next_state) = state
            .check(&entry.expr, &entry_type)
            .map_err(|e| e.wrap(info))?;
        state = next_state;
        let name = &entry.label.text;
        if row.contains_key(name) {
            return Err(TCE::OverlappingVariant(checked.loc, name.clone()));
        }
        if labels.contains(name) {
            return Err(TCE::UnexpectedVariant(checked.loc, name.clone()));
        }
        row.insert(name.clone(), checked.ast);
    }
    let tail = match ext {
        Some(tail) => tail,
        None => return Ok((Val::RowPoly(kind, row).into_info(info), state)),
    };
    let known_labels = row.keys().chain(labels.iter()).cloned().collect();
    let tail_kind = Val::RowKind(Default::default(), kind, known_labels);
    let (tail, state) = state
        .check(&**tail, &tail_kind)
        .map_err(|e| e.wrap(info))?;
    let extended = Val::RowPoly(kind, row).row_extend(tail.ast);
    Ok((extended.into_info(info), state))
}
/// Infers the type of `value`, returning it (located at `value.loc()`)
/// together with the updated checker state.
///
/// # Errors
/// Returns `TCE::CannotInfer` for syntax that has no inference rule here.
/// Errors from checking or inferring sub-expressions are propagated, usually
/// wrapped with `value`'s location.
fn infer(tcs: TCS, value: &Abs) -> ValTCM {
    use Abs::*;
    let info = value.loc();
    match value {
        // Universes: `Type(l)` has type `Type(l + 1)`.
        Type(_, level) => Ok((Val::Type(*level + 1).into_info(info), tcs)),
        // Row kinds are given type `Type(1)`.
        RowKind(..) => Ok((Val::Type(From::from(1 as LevelType)).into_info(info), tcs)),
        // A row type lives one level above its highest entry. An extension is
        // checked against a row kind listing the explicitly written labels.
        RowPoly(_, kind, variants, more) => {
            let mut labels = Vec::with_capacity(variants.len());
            let mut tcs = tcs;
            let mut max_level = Level::default();
            for variant in variants {
                let (val, new_tcs) = tcs.check(&variant.expr, &TYPE_OMEGA)?;
                tcs = new_tcs;
                labels.push(variant.label.text.clone());
                max_level = max_level.max(val.ast.level());
            }
            let kind_level = max_level + 1;
            match more {
                None => Ok((Val::Type(kind_level).into_info(info), tcs)),
                Some(more) => {
                    let expected = Val::RowKind(kind_level, *kind, labels);
                    let (_, tcs) = tcs.check(&**more, &expected)?;
                    Ok((Val::Type(kind_level).into_info(info), tcs))
                }
            }
        }
        // Record literal: start from the fields contributed by the extension (if
        // any), then add each explicit field's inferred type, rejecting duplicates.
        Rec(_, fields, ext) => {
            // Without an extension the row starts empty. The *current* `tcs` must
            // be threaded through in both cases. Substituting a default state
            // here would discard the local context and any meta solutions
            // gathered so far.
            let (mut ext_fields, more, mut tcs) = match ext {
                None => (Variants::default(), None, tcs),
                Some(ext) => {
                    let (ext, tcs) = tcs.infer(&**ext).map_err(|e| e.wrap(info))?;
                    match ext.ast {
                        Val::RowPoly(Record, fields) => (fields, None, tcs),
                        Val::Neut(Neutral::Row(Record, fields, more)) => {
                            (fields, Some(*more), tcs)
                        }
                        e => return Err(TCE::NotRecVal(ext.loc, e)),
                    }
                }
            };
            for field in fields {
                if ext_fields.contains_key(&field.label.text) {
                    return Err(TCE::duplicate_field(field.label.clone()));
                }
                let (inferred, new_tcs) = tcs.infer(&field.expr).map_err(|e| e.wrap(info))?;
                tcs = new_tcs;
                ext_fields.insert(field.label.text.clone(), inferred.ast);
            }
            let ty = match more {
                None => Val::record_type(ext_fields),
                Some(more) => Val::neutral_record_type(ext_fields, more),
            };
            Ok((ty.into_info(info), tcs))
        }
        // Local variable: its type from the local context, tagged with its index.
        Var(_, _, dbi) => {
            let local = tcs.local_type(*dbi).ast.clone().attach_dbi(*dbi);
            Ok((local.into_info(info), tcs))
        }
        // Lambda: check against an explicit pi whose parameter and return types
        // are fresh metas, and report that pi.
        Lam(..) => {
            let mut tcs = tcs;
            let param_meta = tcs.fresh_meta();
            let ret_meta = tcs.fresh_meta();
            let pi = Val::pi(Plicit::Ex, param_meta, Closure::plain(ret_meta));
            let (_, tcs) = tcs.check(value, &pi)?;
            Ok((pi.into_info(info), tcs))
        }
        // Lifted expression: lift the inner expression's inferred type.
        Lift(_, levels, expr) => {
            let (expr, tcs) = tcs.infer(&**expr).map_err(|e| e.wrap(info))?;
            Ok((expr.map_ast(|ast| ast.lift(*levels)), tcs))
        }
        // Global reference: its declared type.
        Ref(_, dbi) => Ok((tcs.glob_type(*dbi).ast.clone().into_info(info), tcs)),
        // Pair: a non-dependent sigma built from both components' types.
        Pair(_, fst, snd) => {
            let (fst_ty, tcs) = tcs.infer(&**fst).map_err(|e| e.wrap(info))?;
            let (snd_ty, tcs) = tcs.infer(&**snd).map_err(|e| e.wrap(info))?;
            let sigma = Val::sig(fst_ty.ast, Closure::plain(snd_ty.ast)).into_info(info);
            Ok((sigma, tcs))
        }
        Fst(_, pair) => {
            let (pair_ty, tcs) = tcs.infer(&**pair).map_err(|e| e.wrap(info))?;
            match pair_ty.ast {
                Val::Dt(Sigma, Plicit::Ex, param_type, ..) => Ok((param_type.into_info(info), tcs)),
                ast => Err(TCE::NotSigma(pair_ty.loc, ast)),
            }
        }
        // Field projection: look the field up in the record row (closed or with
        // a neutral tail).
        Proj(_, record, field) => {
            let (record_ty, tcs) = tcs.infer(&**record).map_err(|e| e.wrap(info))?;
            match record_ty.ast {
                Val::Neut(Neutral::Row(Record, mut fields, ..))
                | Val::RowPoly(Record, mut fields) => fields
                    .remove(&field.text)
                    .map(|ty| (ty.into_info(info), tcs))
                    .ok_or_else(|| TCE::MissingVariant(Record, field.text.clone())),
                ast => Err(TCE::NotRowType(Record, record_ty.loc, ast)),
            }
        }
        // Second projection: the sigma's closure instantiated with the first
        // component of the evaluated pair.
        Snd(_, pair) => {
            let (pair_ty, tcs) = tcs.infer(&**pair).map_err(|e| e.wrap(info))?;
            match pair_ty.ast {
                Val::Dt(Sigma, Plicit::Ex, _, closure) => {
                    let (pair_compiled, tcs) = tcs.evaluate(*pair.clone());
                    let fst = pair_compiled.ast.first();
                    Ok((closure.instantiate(fst).into_info(info), tcs))
                }
                ast => Err(TCE::NotSigma(pair_ty.loc, ast)),
            }
        }
        App(_, f, _app_plicit, a) => match &**f {
            // Constructor application: a single-variant type whose label is the
            // constructor name without its first character.
            Cons(variant_info) => {
                let (a, tcs) = tcs.infer(a).map_err(|e| e.wrap(info))?;
                let mut variant = Variants::default();
                variant.insert(variant_info.text[1..].to_owned(), a.ast);
                Ok((Val::variant_type(variant).into_info(info), tcs))
            }
            // `whatever a`: the argument must inhabit the empty variant row,
            // the same parameter type the `Whatever` rule in `check` demands.
            // The result type is left as a fresh meta.
            Whatever(whatever_info) => {
                let empty = Val::RowPoly(Variant, Variants::default());
                let (_, mut tcs) = tcs.check(a, &empty).map_err(|e| e.wrap(info))?;
                Ok((tcs.fresh_meta().into_info(*whatever_info), tcs))
            }
            f => {
                let (f_ty, tcs) = tcs.infer(f).map_err(|e| e.wrap(info))?;
                check_app_type(tcs, f, info, a, &f_ty.ast)
            }
        },
        e => Err(TCE::CannotInfer(info, e.clone())),
    }
}
/// Computes the result type of applying `f` to the argument `a`, where `pi_ty`
/// is the already-inferred type of `f`.
///
/// - Implicit pi binders are instantiated with fresh metas until an explicit
///   binder is reached.
/// - A solved meta is replaced by its solution and the computation retried.
/// - At the explicit binder, `a` is checked against the parameter type.
///
/// # Errors
/// - `TCE::MetaUnsolved` when `pi_ty` is an unsolved meta;
/// - `TCE::NotPi` for any other non-function type;
/// - any error from checking `a`, wrapped with `info`.
fn check_app_type(mut tcs: TCS, f: &Abs, info: Loc, a: &Abs, pi_ty: &Val) -> ValTCM {
    match pi_ty {
        Val::Neut(Neutral::Meta(meta)) => {
            let solution = match tcs.meta_context.solution(*meta) {
                MetaSolution::Solved(sol) => *sol.clone(),
                _ => return Err(TCE::MetaUnsolved(*meta)),
            };
            check_app_type(tcs, f, info, a, &solution)
        }
        Val::Dt(Pi, Plicit::Im, _, closure) => {
            let hole = tcs.fresh_meta();
            let instantiated = closure.instantiate_cloned(hole);
            check_app_type(tcs, f, info, a, &instantiated)
        }
        Val::Dt(Pi, Plicit::Ex, param_type, closure) => {
            let (arg, tcs) = tcs.check(a, param_type).map_err(|e| e.wrap(info))?;
            let ret_type = closure.instantiate_cloned(arg.ast);
            Ok((ret_type.into_info(info), tcs))
        }
        other => Err(TCE::NotPi(info, other.clone())),
    }
}
/// Requires `sub` to be a subtype of `sup`, returning the updated state.
///
/// Rules, tried in order:
/// - Universes are cumulative: `Type(a)` (or a row kind at level `a`) fits in
///   `Type(b)` when `a <= b`.
/// - Record rows and variant rows are compared by `unify_variants`. Records
///   pass the supertype row first; variants pass the subtype row first.
/// - Pi/Sigma types of the same kind and plicity unify their parameter types,
///   then compare their bodies under a shared fresh axiom.
/// - Anything else falls back to plain unification.
fn subtype(tcs: TCS, sub: &Val, sup: &Val) -> TCM {
    use Val::*;
    match (sub, sup) {
        (Type(sub_level), Type(sup_level)) if sub_level <= sup_level => Ok(tcs),
        (RowKind(sub_level, ..), Type(sup_level)) if sub_level <= sup_level => Ok(tcs),
        (RowPoly(Record, sub_fields), RowPoly(Record, sup_fields)) => {
            tcs.unify_variants(Record, sup_fields, sub_fields)
        }
        (RowPoly(Variant, sub_cases), RowPoly(Variant, sup_cases)) => {
            tcs.unify_variants(Variant, sub_cases, sup_cases)
        }
        (Dt(sub_kind, sub_plicit, sub_param, sub_body), Dt(sup_kind, sup_plicit, sup_param, sup_body))
            if sub_kind == sup_kind && sub_plicit == sup_plicit =>
        {
            let tcs = tcs.unify(sub_param, sup_param)?;
            let witness = Val::fresh_axiom();
            let sub_ret = sub_body.instantiate_borrow(&witness);
            let sup_ret = sup_body.instantiate_cloned(witness);
            tcs.subtype(&sub_ret, &sup_ret)
        }
        (lhs, rhs) => tcs.unify(lhs, rhs),
    }
}
impl TCS {
#[inline]
pub fn check(self, expr: &Abs, expected_type: &Val) -> ValTCM {
check(self, expr, expected_type)
}
#[inline]
pub fn infer(self, value: &Abs) -> ValTCM {
infer(self, value)
}
#[inline]
pub fn subtype(self, sub: &Val, sup: &Val) -> TCM {
subtype(self, sub, sup)
}
}