use std::collections::HashSet;
use crate::context::Context;
use crate::error::{KernelError, KernelResult};
use crate::term::Term;
/// State threaded through the guard check of a single fixpoint.
#[derive(Debug, Clone)]
struct GuardContext {
    /// Name under which the fixpoint body refers to itself.
    fix_name: String,
    /// Binder chosen as the structural (decreasing) parameter.
    struct_param: String,
    /// Name of the inductive type of `struct_param`.
    struct_type: String,
    /// Variables recorded as structurally smaller than `struct_param`; a
    /// recursive call is only accepted when its checked argument is one of these.
    smaller_than: HashSet<String>,
}
pub fn check_termination(ctx: &Context, fix_name: &str, body: &Term) -> KernelResult<()> {
let (struct_param, struct_type, inner) = extract_structural_param(ctx, body)?;
let guard_ctx = GuardContext {
fix_name: fix_name.to_string(),
struct_param,
struct_type,
smaller_than: HashSet::new(),
};
check_guarded(ctx, &guard_ctx, inner)
}
/// Locates the structural (decreasing) parameter of a fixpoint body.
///
/// Walks the leading chain of `Term::Lambda` binders and stops at the first
/// binder whose parameter type has an inductive type at its head, as decided
/// by `extract_inductive_name`.
///
/// Returns `(binder name, inductive type name, body under that binder)`.
///
/// # Errors
/// `KernelError::TerminationViolation` with an empty `fix_name` when the
/// lambda chain ends without such a binder. This function does not know the
/// fixpoint's name.
fn extract_structural_param<'a>(
    ctx: &Context,
    body: &'a Term,
) -> KernelResult<(String, String, &'a Term)> {
    let mut cursor = body;
    while let Term::Lambda {
        param,
        param_type,
        body: under_binder,
    } = cursor
    {
        if let Some(type_name) = extract_inductive_name(ctx, param_type) {
            return Ok((param.clone(), type_name, under_binder));
        }
        cursor = under_binder;
    }
    Err(KernelError::TerminationViolation {
        fix_name: String::new(),
        reason: "No inductive parameter found for structural recursion".to_string(),
    })
}
/// Returns the name of the inductive type at the head of `ty`, if any.
///
/// Applications are peeled off on the function side first, so an applied
/// type such as `List Nat` yields `List`. Returns `None` when the resulting
/// head is not a `Term::Global` that `ctx.is_inductive` accepts.
fn extract_inductive_name(ctx: &Context, ty: &Term) -> Option<String> {
    let mut head = ty;
    while let Term::App(func, _) = head {
        head = func;
    }
    if let Term::Global(name) = head {
        if ctx.is_inductive(name) {
            return Some(name.clone());
        }
    }
    None
}
fn check_guarded(ctx: &Context, guard_ctx: &GuardContext, term: &Term) -> KernelResult<()> {
match term {
Term::App(func, arg) => {
check_recursive_call(ctx, guard_ctx, func, arg)?;
check_guarded(ctx, guard_ctx, func)?;
check_guarded(ctx, guard_ctx, arg)
}
Term::Match {
discriminant,
cases,
..
} => {
if let Term::Var(disc_name) = discriminant.as_ref() {
if disc_name == &guard_ctx.struct_param {
return check_match_cases_guarded(ctx, guard_ctx, cases);
}
}
check_guarded(ctx, guard_ctx, discriminant)?;
for case in cases {
check_guarded(ctx, guard_ctx, case)?;
}
Ok(())
}
Term::Lambda { body, .. } => check_guarded(ctx, guard_ctx, body),
Term::Pi { body_type, .. } => check_guarded(ctx, guard_ctx, body_type),
Term::Fix { body, .. } => {
check_guarded(ctx, guard_ctx, body)
}
Term::Sort(_) | Term::Var(_) | Term::Global(_) | Term::Lit(_) | Term::Hole => Ok(()),
}
}
/// Validates one application `func arg` when its head is the fixpoint itself.
///
/// The application is decomposed with `extract_head_and_first_arg`. If the
/// head is not `Term::Var(fix_name)`, the application is not a recursive call
/// and is accepted here; the caller still checks its sub-terms. Otherwise the
/// first argument of the call must be a variable recorded in
/// `guard_ctx.smaller_than`.
///
/// NOTE(review): the *first* argument is treated as the structural one.
/// `extract_structural_param`, however, picks the first lambda binder whose
/// type is inductive, which need not be the first binder. Confirm this matches
/// how fixpoints are elaborated.
///
/// # Errors
/// `KernelError::TerminationViolation` when the first argument is a variable
/// not known to be smaller, or is not a variable at all.
fn check_recursive_call(
    _ctx: &Context,
    guard_ctx: &GuardContext,
    func: &Term,
    arg: &Term,
) -> KernelResult<()> {
    let (head, first_arg) = extract_head_and_first_arg(func, arg);
    let is_recursive_call = matches!(head, Term::Var(name) if *name == guard_ctx.fix_name);
    if !is_recursive_call {
        return Ok(());
    }
    match first_arg {
        Term::Var(arg_name) if guard_ctx.smaller_than.contains(arg_name) => Ok(()),
        Term::Var(arg_name) => Err(KernelError::TerminationViolation {
            fix_name: guard_ctx.fix_name.clone(),
            reason: format!(
                "Recursive call with '{}' which is not structurally smaller than '{}'",
                arg_name, guard_ctx.struct_param
            ),
        }),
        _ => Err(KernelError::TerminationViolation {
            fix_name: guard_ctx.fix_name.clone(),
            reason: "Recursive call with complex argument - cannot verify termination".to_string(),
        }),
    }
}
/// Splits the application `func arg` into its head and its first argument.
///
/// For `f a b c`, represented as `App(App(App(f, a), b), c)` and passed as
/// `func = App(App(f, a), b)`, `arg = c`, this returns `(f, a)`. When `func`
/// is not itself an application, the pair `(func, arg)` is returned unchanged.
fn extract_head_and_first_arg<'a>(func: &'a Term, arg: &'a Term) -> (&'a Term, &'a Term) {
    match func {
        // The innermost application holds both the head and the argument
        // that was applied first.
        Term::App(inner_func, inner_arg) => extract_head_and_first_arg(inner_func, inner_arg),
        _ => (func, arg),
    }
}
fn check_match_cases_guarded(
ctx: &Context,
guard_ctx: &GuardContext,
cases: &[Term],
) -> KernelResult<()> {
let constructors = ctx.get_constructors(&guard_ctx.struct_type);
for (case, (ctor_name, ctor_type)) in cases.iter().zip(constructors.iter()) {
let param_count = count_pi_params(ctor_type);
let (smaller_vars, case_body) = extract_lambda_params(case, param_count);
let mut extended_ctx = GuardContext {
fix_name: guard_ctx.fix_name.clone(),
struct_param: guard_ctx.struct_param.clone(),
struct_type: guard_ctx.struct_type.clone(),
smaller_than: guard_ctx.smaller_than.clone(),
};
let recursive_params = get_recursive_params(ctx, &guard_ctx.struct_type, ctor_type);
for (idx, _) in recursive_params {
if idx < smaller_vars.len() {
extended_ctx.smaller_than.insert(smaller_vars[idx].clone());
}
}
for var in &smaller_vars {
extended_ctx.smaller_than.insert(var.clone());
}
check_guarded(ctx, &extended_ctx, case_body)?;
}
Ok(())
}
/// Counts the binders in the leading chain of `Term::Pi` types of `ty`.
///
/// Returns 0 when `ty` is not a Pi type.
fn count_pi_params(ty: &Term) -> usize {
    let mut binders = 0;
    let mut cursor = ty;
    while let Term::Pi { body_type, .. } = cursor {
        binders += 1;
        cursor = body_type;
    }
    binders
}
/// Strips up to `count` leading `Term::Lambda` binders from `term`.
///
/// Returns the stripped binder names in order, together with the term that
/// remains. Stripping stops early when a non-lambda term is reached, so fewer
/// than `count` names may be returned.
fn extract_lambda_params(term: &Term, count: usize) -> (Vec<String>, &Term) {
    let mut names: Vec<String> = Vec::new();
    let mut cursor = term;
    while names.len() < count {
        match cursor {
            Term::Lambda { param, body, .. } => {
                names.push(param.clone());
                cursor = body;
            }
            _ => break,
        }
    }
    (names, cursor)
}
/// Lists the Pi binders of a constructor type whose type names `inductive`.
///
/// Each entry is `(position, binder name)`, where `position` is the 0-based
/// index of the binder in the constructor type's leading chain of Pi types.
/// A binder counts when `extract_inductive_name` finds `inductive` at the head
/// of its type (so an applied type such as `List A` also counts for `List`).
fn get_recursive_params(ctx: &Context, inductive: &str, ctor_type: &Term) -> Vec<(usize, String)> {
    let mut recursive = Vec::new();
    let mut remaining = ctor_type;
    for position in 0usize.. {
        match remaining {
            Term::Pi {
                param,
                param_type,
                body_type,
            } => {
                if extract_inductive_name(ctx, param_type).as_deref() == Some(inductive) {
                    recursive.push((position, param.clone()));
                }
                remaining = body_type;
            }
            _ => break,
        }
    }
    recursive
}