use crate::context::Context;
use crate::error::{KernelError, KernelResult};
use crate::reduction::normalize;
use crate::term::{Literal, Term, Universe};
pub fn infer_type(ctx: &Context, term: &Term) -> KernelResult<Term> {
match term {
Term::Sort(u) => Ok(Term::Sort(u.succ())),
Term::Var(name) => ctx
.get(name)
.cloned()
.ok_or_else(|| KernelError::UnboundVariable(name.clone())),
Term::Global(name) => ctx
.get_global(name)
.cloned()
.ok_or_else(|| KernelError::UnboundVariable(name.clone())),
Term::Pi {
param,
param_type,
body_type,
} => {
let a_sort = infer_sort(ctx, param_type)?;
let extended_ctx = ctx.extend(param, (**param_type).clone());
let b_sort = infer_sort(&extended_ctx, body_type)?;
Ok(Term::Sort(a_sort.max(&b_sort)))
}
Term::Lambda {
param,
param_type,
body,
} => {
let _ = infer_sort(ctx, param_type)?;
let extended_ctx = ctx.extend(param, (**param_type).clone());
let body_type = infer_type(&extended_ctx, body)?;
Ok(Term::Pi {
param: param.clone(),
param_type: param_type.clone(),
body_type: Box::new(body_type),
})
}
Term::App(func, arg) => {
let func_type = infer_type(ctx, func)?;
match func_type {
Term::Pi {
param,
param_type,
body_type,
} => {
check_type(ctx, arg, ¶m_type)?;
Ok(substitute(&body_type, ¶m, arg))
}
_ => Err(KernelError::NotAFunction(format!("{}", func)))
}
}
Term::Match {
discriminant,
motive,
cases,
} => {
let disc_type = infer_type(ctx, discriminant)?;
let inductive_name = extract_inductive_name(ctx, &disc_type)
.ok_or_else(|| KernelError::NotAnInductive(format!("{}", disc_type)))?;
let motive_type = infer_type(ctx, motive)?;
let effective_motive = match &motive_type {
Term::Pi {
param_type,
body_type,
..
} => {
if !types_equal(param_type, &disc_type) {
return Err(KernelError::InvalidMotive(format!(
"motive parameter {} doesn't match discriminant type {}",
param_type, disc_type
)));
}
match infer_type(ctx, body_type) {
Ok(Term::Sort(_)) => {}
_ => {
return Err(KernelError::InvalidMotive(format!(
"motive body {} is not a type",
body_type
)));
}
}
(**motive).clone()
}
Term::Sort(_) => {
Term::Lambda {
param: "_".to_string(),
param_type: Box::new(disc_type.clone()),
body: motive.clone(),
}
}
_ => {
return Err(KernelError::InvalidMotive(format!(
"motive {} is not a function or type",
motive
)));
}
};
let constructors = ctx.get_constructors(&inductive_name);
if cases.len() != constructors.len() {
return Err(KernelError::WrongNumberOfCases {
expected: constructors.len(),
found: cases.len(),
});
}
for (case, (ctor_name, ctor_type)) in cases.iter().zip(constructors.iter()) {
let expected_case_type = compute_case_type(&effective_motive, ctor_name, ctor_type, &disc_type);
check_type(ctx, case, &expected_case_type)?;
}
let return_type = Term::App(
Box::new(effective_motive),
discriminant.clone(),
);
Ok(beta_reduce(&return_type))
}
Term::Lit(lit) => {
match lit {
Literal::Int(_) => Ok(Term::Global("Int".to_string())),
Literal::Float(_) => Ok(Term::Global("Float".to_string())),
Literal::Text(_) => Ok(Term::Global("Text".to_string())),
Literal::Duration(_) => Ok(Term::Global("Duration".to_string())),
Literal::Date(_) => Ok(Term::Global("Date".to_string())),
Literal::Moment(_) => Ok(Term::Global("Moment".to_string())),
}
}
Term::Hole => Err(KernelError::CannotInferHole),
Term::Fix { name, body } => {
let structural_type = infer_fix_type_structurally(ctx, body)?;
crate::termination::check_termination(ctx, name, body)?;
let extended = ctx.extend(name, structural_type.clone());
let _ = infer_type(&extended, body)?;
Ok(structural_type)
}
}
}
/// Derives a type for the body of a `Fix` from its syntactic shape, without
/// needing the recursive name to be bound.
///
/// - A `Lambda` contributes a `Pi` binder and recurses into its body under
///   the extended context.
/// - A `Match` yields the body of its motive when the motive is a `Lambda`,
///   and the motive itself otherwise.
/// - Any other term falls back to `infer_type`.
fn infer_fix_type_structurally(ctx: &Context, term: &Term) -> KernelResult<Term> {
    match term {
        Term::Lambda {
            param,
            param_type,
            body,
        } => {
            // The parameter annotation must itself be a type; the sort is unused.
            infer_sort(ctx, param_type)?;
            let inner_ctx = ctx.extend(param, (**param_type).clone());
            let codomain = infer_fix_type_structurally(&inner_ctx, body)?;
            Ok(Term::Pi {
                param: param.clone(),
                param_type: param_type.clone(),
                body_type: Box::new(codomain),
            })
        }
        Term::Match { motive, .. } => {
            let result = match motive.as_ref() {
                Term::Lambda { body, .. } => (**body).clone(),
                _ => (**motive).clone(),
            };
            Ok(result)
        }
        _ => infer_type(ctx, term),
    }
}
/// Checks that `term` can be given the type `expected` under `ctx`.
///
/// Special cases are tried in this order before the general rule:
/// 1. A `Hole` term is accepted only when a `Sort` is expected.
/// 2. A `Hole` expected type accepts any term whose type can be inferred.
/// 3. A `Lambda` whose parameter annotation is the global `"_"`, checked
///    against a `Pi`, takes its parameter type from the `Pi` and checks the
///    body against the `Pi`'s body type.
///
/// Otherwise the type of `term` is inferred and must be a subtype of
/// `expected` according to `is_subtype`.
///
/// # Errors
///
/// Returns `KernelError::TypeMismatch` when the check fails, and propagates
/// any error raised by `infer_type`.
fn check_type(ctx: &Context, term: &Term, expected: &Term) -> KernelResult<()> {
    match (term, expected) {
        (Term::Hole, Term::Sort(_)) => return Ok(()),
        (Term::Hole, _) => {
            return Err(KernelError::TypeMismatch {
                expected: format!("{}", expected),
                found: "_".to_string(),
            });
        }
        // The term still has to be well-typed, even though any type is accepted.
        (_, Term::Hole) => return infer_type(ctx, term).map(|_| ()),
        (
            Term::Lambda {
                param,
                param_type,
                body,
            },
            Term::Pi {
                param: pi_param,
                param_type: pi_param_type,
                body_type: pi_body_type,
            },
        ) if matches!(param_type.as_ref(), Term::Global(marker) if marker == "_") => {
            let inner_ctx = ctx.extend(param, (**pi_param_type).clone());
            // Align the Pi's bound name with the lambda's before descending.
            let body_expected = if param == pi_param {
                (**pi_body_type).clone()
            } else {
                substitute(pi_body_type, pi_param, &Term::Var(param.clone()))
            };
            return check_type(&inner_ctx, body, &body_expected);
        }
        _ => {}
    }

    let inferred = infer_type(ctx, term)?;
    if !is_subtype(ctx, &inferred, expected) {
        return Err(KernelError::TypeMismatch {
            expected: format!("{}", expected),
            found: format!("{}", inferred),
        });
    }
    Ok(())
}
/// Infers the type of `term` and returns its universe when that type is a `Sort`.
///
/// # Errors
///
/// Returns `KernelError::NotAType` (carrying the rendered `term`) when the
/// inferred type is not a `Sort`, and propagates errors from `infer_type`.
fn infer_sort(ctx: &Context, term: &Term) -> KernelResult<Universe> {
    if let Term::Sort(universe) = infer_type(ctx, term)? {
        return Ok(universe);
    }
    Err(KernelError::NotAType(format!("{}", term)))
}
/// Performs at most one beta step at the root of `term`.
///
/// When `term` is an application whose head is directly a `Lambda`, the
/// argument is substituted for the lambda's parameter in its body. Any other
/// term is returned as an unchanged clone; sub-terms are never reduced.
fn beta_reduce(term: &Term) -> Term {
    if let Term::App(head, argument) = term {
        if let Term::Lambda { param, body, .. } = head.as_ref() {
            return substitute(body, param, argument);
        }
    }
    term.clone()
}
/// Builds the type a `Match` case must have for the constructor `ctor_name`.
///
/// The leading `Pi` binders of `ctor_type` are split by position: the first
/// `n` (where `n` is the number of arguments applied in `disc_type`) are
/// treated as type parameters, the rest as value parameters. The result is a
/// `Pi` over the value parameters, renamed `__arg0`, `__arg1`, ..., whose body
/// is the motive applied to the constructor applied to the type arguments and
/// those fresh variables, after one beta step.
///
/// Type-parameter names occurring in a value parameter's type are replaced by
/// the corresponding type arguments.
///
/// NOTE(review): the original value-parameter names are dropped, so a value
/// parameter type that mentions an earlier value parameter is left referring
/// to the old name; confirm constructors never have such dependencies.
fn compute_case_type(motive: &Term, ctor_name: &str, ctor_type: &Term, disc_type: &Term) -> Term {
    let type_args = extract_type_args(disc_type);

    // Peel every leading Pi binder off the constructor's type.
    let mut binders: Vec<(String, Term)> = Vec::new();
    let mut cursor = ctor_type;
    while let Term::Pi {
        param,
        param_type,
        body_type,
    } = cursor
    {
        binders.push((param.clone(), (**param_type).clone()));
        cursor = body_type;
    }

    // Positional split: type parameters first, value parameters after.
    let split_at = type_args.len().min(binders.len());
    let value_binders = binders.split_off(split_at);
    let type_param_names: Vec<String> = binders.into_iter().map(|(name, _)| name).collect();

    // Value parameters get fresh positional names.
    let fields: Vec<(String, Term)> = value_binders
        .into_iter()
        .enumerate()
        .map(|(position, (_, field_type))| (format!("__arg{}", position), field_type))
        .collect();

    // ctor type_args... __arg0 __arg1 ...
    let ctor_head = type_args.iter().fold(
        Term::Global(ctor_name.to_string()),
        |head, type_arg| Term::App(Box::new(head), Box::new(type_arg.clone())),
    );
    let ctor_applied = fields.iter().fold(ctor_head, |head, (field_name, _)| {
        Term::App(Box::new(head), Box::new(Term::Var(field_name.clone())))
    });

    let motive_result = beta_reduce(&Term::App(
        Box::new(motive.clone()),
        Box::new(ctor_applied),
    ));

    // Wrap from the innermost (last) field outwards.
    fields
        .into_iter()
        .rev()
        .fold(motive_result, |inner, (field_name, field_type)| {
            let instantiated = type_param_names
                .iter()
                .zip(type_args.iter())
                .fold(field_type, |ty, (type_param, type_arg)| {
                    substitute(&ty, type_param, type_arg)
                });
            Term::Pi {
                param: field_name,
                param_type: Box::new(instantiated),
                body_type: Box::new(inner),
            }
        })
}
/// Collects the arguments of a left-nested application spine.
///
/// For `((f a) b) c` this returns `[a, b, c]`; a term that is not an `App`
/// yields an empty vector. The head `f` itself is not included.
fn extract_type_args(ty: &Term) -> Vec<Term> {
    let mut collected = Vec::new();
    let mut head = ty;
    loop {
        match head {
            Term::App(inner, argument) => {
                collected.push((**argument).clone());
                head = inner;
            }
            _ => break,
        }
    }
    // Arguments were gathered outermost-first; flip to application order.
    collected.reverse();
    collected
}
/// Returns a copy of `body` in which free occurrences of the variable `var`
/// are replaced by `replacement`.
///
/// Binders (`Pi`, `Lambda`, `Fix`) whose bound name equals `var` shadow it:
/// the term under the binder is copied unchanged. A `Pi`/`Lambda` parameter
/// type is outside the binder's scope and is always rewritten. `Global`
/// names are never replaced.
///
/// NOTE(review): no renaming is performed when `replacement` mentions a name
/// bound inside `body`, so this does not look capture-avoiding; confirm
/// callers only pass replacements for which that cannot matter.
pub fn substitute(body: &Term, var: &str, replacement: &Term) -> Term {
    match body {
        // Leaves that cannot contain the variable.
        Term::Sort(_) | Term::Lit(_) | Term::Hole | Term::Global(_) => body.clone(),
        Term::Var(name) => {
            if name == var {
                replacement.clone()
            } else {
                body.clone()
            }
        }
        Term::Pi {
            param,
            param_type,
            body_type,
        } => {
            let shadowed = param == var;
            let domain = substitute(param_type, var, replacement);
            let codomain = if shadowed {
                body_type.clone()
            } else {
                Box::new(substitute(body_type, var, replacement))
            };
            Term::Pi {
                param: param.clone(),
                param_type: Box::new(domain),
                body_type: codomain,
            }
        }
        Term::Lambda {
            param,
            param_type,
            body: lambda_body,
        } => {
            let shadowed = param == var;
            let domain = substitute(param_type, var, replacement);
            let inner = if shadowed {
                lambda_body.clone()
            } else {
                Box::new(substitute(lambda_body, var, replacement))
            };
            Term::Lambda {
                param: param.clone(),
                param_type: Box::new(domain),
                body: inner,
            }
        }
        Term::App(func, arg) => {
            let new_func = substitute(func, var, replacement);
            let new_arg = substitute(arg, var, replacement);
            Term::App(Box::new(new_func), Box::new(new_arg))
        }
        Term::Match {
            discriminant,
            motive,
            cases,
        } => {
            let new_discriminant = substitute(discriminant, var, replacement);
            let new_motive = substitute(motive, var, replacement);
            let mut new_cases = Vec::with_capacity(cases.len());
            for case in cases.iter() {
                new_cases.push(substitute(case, var, replacement));
            }
            Term::Match {
                discriminant: Box::new(new_discriminant),
                motive: Box::new(new_motive),
                cases: new_cases,
            }
        }
        Term::Fix {
            name,
            body: fix_body,
        } => {
            // The recursive name shadows `var` inside the fixpoint body.
            let inner = if name == var {
                fix_body.clone()
            } else {
                Box::new(substitute(fix_body, var, replacement))
            };
            Term::Fix {
                name: name.clone(),
                body: inner,
            }
        }
    }
}
/// Reports whether `a` is a subtype of `b` under `ctx`.
///
/// Both terms are first passed through `normalize`, and the results are then
/// compared structurally by `is_subtype_normalized`.
pub fn is_subtype(ctx: &Context, a: &Term, b: &Term) -> bool {
    let (lhs, rhs) = (normalize(ctx, a), normalize(ctx, b));
    is_subtype_normalized(ctx, &lhs, &rhs)
}
/// Structural subtyping on two terms; callers are expected to have normalized
/// them (see `is_subtype`).
///
/// - Two sorts are compared with `Universe::is_subtype_of`.
/// - Two `Pi` types are compared with the parameter types swapped
///   (`b`'s parameter type must be a subtype of `a`'s) and the body types in
///   the original direction, after renaming `b`'s bound variable to `a`'s.
/// - Everything else falls back to `types_equal`.
fn is_subtype_normalized(ctx: &Context, a: &Term, b: &Term) -> bool {
    match (a, b) {
        (Term::Sort(lower), Term::Sort(upper)) => lower.is_subtype_of(upper),
        (
            Term::Pi {
                param: sub_param,
                param_type: sub_domain,
                body_type: sub_codomain,
            },
            Term::Pi {
                param: sup_param,
                param_type: sup_domain,
                body_type: sup_codomain,
            },
        ) => {
            // Parameter position: the direction is reversed.
            if !is_subtype_normalized(ctx, sup_domain, sub_domain) {
                return false;
            }
            // Make both bodies talk about the same bound variable.
            let sup_codomain_renamed =
                substitute(sup_codomain, sup_param, &Term::Var(sub_param.clone()));
            is_subtype_normalized(ctx, sub_codomain, &sup_codomain_renamed)
        }
        _ => types_equal(a, b),
    }
}
/// Finds the inductive type at the head of `ty`.
///
/// Application nodes are stripped by following their function position. The
/// remaining head must be a `Global` that `ctx.is_inductive` accepts;
/// otherwise `None` is returned.
fn extract_inductive_name(ctx: &Context, ty: &Term) -> Option<String> {
    let mut head = ty;
    while let Term::App(func, _) = head {
        head = func;
    }
    match head {
        Term::Global(name) if ctx.is_inductive(name) => Some(name.clone()),
        _ => None,
    }
}
/// Structural equality of two terms, up to the names of bound variables.
///
/// A `Hole` on either side compares equal to anything. For `Pi`, `Lambda` and
/// `Fix`, the right-hand binder's name is renamed to the left-hand one (via
/// `substitute`) before the bodies are compared. Terms of different variants
/// are never equal. No reduction is performed here.
fn types_equal(a: &Term, b: &Term) -> bool {
    match (a, b) {
        // A hole acts as a wildcard on either side.
        (Term::Hole, _) | (_, Term::Hole) => true,
        (Term::Sort(left), Term::Sort(right)) => left == right,
        (Term::Lit(left), Term::Lit(right)) => left == right,
        (Term::Var(left), Term::Var(right)) => left == right,
        (Term::Global(left), Term::Global(right)) => left == right,
        (
            Term::Pi {
                param: left_param,
                param_type: left_domain,
                body_type: left_codomain,
            },
            Term::Pi {
                param: right_param,
                param_type: right_domain,
                body_type: right_codomain,
            },
        ) => {
            if !types_equal(left_domain, right_domain) {
                return false;
            }
            let right_renamed =
                substitute(right_codomain, right_param, &Term::Var(left_param.clone()));
            types_equal(left_codomain, &right_renamed)
        }
        (
            Term::Lambda {
                param: left_param,
                param_type: left_domain,
                body: left_body,
            },
            Term::Lambda {
                param: right_param,
                param_type: right_domain,
                body: right_body,
            },
        ) => {
            if !types_equal(left_domain, right_domain) {
                return false;
            }
            let right_renamed =
                substitute(right_body, right_param, &Term::Var(left_param.clone()));
            types_equal(left_body, &right_renamed)
        }
        (Term::App(left_func, left_arg), Term::App(right_func, right_arg)) => {
            types_equal(left_func, right_func) && types_equal(left_arg, right_arg)
        }
        (
            Term::Match {
                discriminant: left_disc,
                motive: left_motive,
                cases: left_cases,
            },
            Term::Match {
                discriminant: right_disc,
                motive: right_motive,
                cases: right_cases,
            },
        ) => {
            if !types_equal(left_disc, right_disc) || !types_equal(left_motive, right_motive) {
                return false;
            }
            if left_cases.len() != right_cases.len() {
                return false;
            }
            left_cases
                .iter()
                .zip(right_cases.iter())
                .all(|(left_case, right_case)| types_equal(left_case, right_case))
        }
        (
            Term::Fix {
                name: left_name,
                body: left_body,
            },
            Term::Fix {
                name: right_name,
                body: right_body,
            },
        ) => {
            let right_renamed =
                substitute(right_body, right_name, &Term::Var(left_name.clone()));
            types_equal(left_body, &right_renamed)
        }
        _ => false,
    }
}