use std::collections::{HashMap, HashSet};
use super::expr::{aver_name_to_lean, emit_expr, emit_stmt};
use super::law_auto::{emit_verify_law_forall_auto_proof, emit_verify_law_support_theorems};
use super::recurrence::{
detect_second_order_int_linear_recurrence, recurrence_nat_helper_name, render_affine_pair_expr,
};
use super::shared::to_lower_first;
use super::types::type_annotation_to_lean;
use super::{RecursionPlan, VerifyEmitMode, sizeof_measure_param_indices};
use crate::ast::{self, TailCallData, *};
use crate::codegen::CodegenContext;
use crate::codegen::common::expr_to_dotted_name;
use crate::verify_law::canonical_spec_ref;
/// Renders the Lean declaration for a user-defined type.
///
/// Product types are delegated to `emit_product_type` and sum types to
/// `emit_sum_type`; the returned string is whatever that helper produced.
pub fn emit_type_def(td: &TypeDef) -> String {
    match td {
        TypeDef::Product { name, fields, .. } => emit_product_type(name, fields),
        TypeDef::Sum { name, variants, .. } => emit_sum_type(name, variants),
    }
}
/// Reports whether any constructor field of the sum type `name` mentions
/// `name` itself, as judged by `field_type_contains`.
fn is_recursive_type(name: &str, variants: &[TypeVariant]) -> bool {
    variants
        .iter()
        .flat_map(|variant| variant.fields.iter())
        .any(|field| field_type_contains(field, name))
}
/// Reports whether any field type of the product type `name` mentions `name`
/// itself, as judged by `field_type_contains`. Field names are ignored.
fn is_recursive_product(name: &str, fields: &[(String, String)]) -> bool {
    fields
        .iter()
        .any(|(_, field_type)| field_type_contains(field_type, name))
}
/// Reports whether the type annotation `field_type` refers to `type_name`.
///
/// A reference is an occurrence of `type_name` that is not glued to further
/// identifier characters (alphanumerics or `_`) on either side. This keeps
/// `List<TreeNode>` or `Option<SubTree>` from being mistaken for a reference
/// to `Tree`, while still recognising `Tree`, `List<Tree>`,
/// `Map<String, Tree>`, `(Tree, Int)` and similar spellings.
fn field_type_contains(field_type: &str, type_name: &str) -> bool {
    fn is_name_char(ch: char) -> bool {
        ch.is_alphanumeric() || ch == '_'
    }

    field_type.match_indices(type_name).any(|(start, matched)| {
        // Inspect the characters immediately around the candidate match; a
        // missing neighbour (start/end of the annotation) counts as a boundary.
        let before = field_type[..start].chars().next_back();
        let after = field_type[start + matched.len()..].chars().next();
        !before.is_some_and(is_name_char) && !after.is_some_and(is_name_char)
    })
}
/// Renders a sum type as a Lean `inductive` declaration.
///
/// Each variant becomes one constructor line whose name is the variant name
/// passed through `to_lower_first`; payload fields are emitted as anonymous
/// binders `(_ : T)`. Types that refer to themselves (per `is_recursive_type`)
/// omit `DecidableEq` from the `deriving` clause.
fn emit_sum_type(name: &str, variants: &[TypeVariant]) -> String {
    let mut out = vec![format!("inductive {} where", name)];

    out.extend(variants.iter().map(|variant| {
        let ctor = to_lower_first(&variant.name);
        if variant.fields.is_empty() {
            return format!(" | {}", ctor);
        }
        let binders = variant
            .fields
            .iter()
            .map(|field| format!("(_ : {})", type_annotation_to_lean(field)))
            .collect::<Vec<_>>()
            .join(" ");
        format!(" | {} {}", ctor, binders)
    }));

    let deriving = if is_recursive_type(name, variants) {
        " deriving Repr, BEq, Inhabited"
    } else {
        " deriving Repr, BEq, Inhabited, DecidableEq"
    };
    out.push(deriving.to_string());

    out.join("\n")
}
/// Renders a product type as a Lean `structure` declaration.
///
/// Every `(field_name, field_type)` pair becomes one `name : Type` line, with
/// the name converted by `aver_name_to_lean` and the type by
/// `type_annotation_to_lean`. Self-referential structures (per
/// `is_recursive_product`) omit `DecidableEq` from the `deriving` clause.
fn emit_product_type(name: &str, fields: &[(String, String)]) -> String {
    let mut out = Vec::with_capacity(fields.len() + 2);
    out.push(format!("structure {} where", name));

    out.extend(fields.iter().map(|(field_name, field_type)| {
        format!(
            " {} : {}",
            aver_name_to_lean(field_name),
            type_annotation_to_lean(field_type)
        )
    }));

    let deriving = if is_recursive_product(name, fields) {
        " deriving Repr, BEq, Inhabited"
    } else {
        " deriving Repr, BEq, Inhabited, DecidableEq"
    };
    out.push(deriving.to_string());

    out.join("\n")
}
/// Name of the generated Lean measure function for `type_name`:
/// the type name prefixed with `averMeasure`.
fn measure_fn_name(type_name: &str) -> String {
    let mut lean_name = String::from("averMeasure");
    lean_name.push_str(type_name);
    lean_name
}
/// Name of the generated Lean measure function over lists of `type_name`:
/// the plain measure name (see `measure_fn_name`) with a `List` suffix.
fn measure_list_fn_name(type_name: &str) -> String {
    let mut lean_name = measure_fn_name(type_name);
    lean_name.push_str("List");
    lean_name
}
/// Name of the generated Lean measure function over map entries whose values
/// have type `type_name` and whose keys have type `key_type`.
///
/// The key type is folded into the name as a suffix in which every character
/// that is not ASCII alphanumeric is replaced by `_`, so distinct key types
/// yield distinct, identifier-safe helper names.
fn measure_entries_fn_name(type_name: &str, key_type: &str) -> String {
    let mut key_suffix = String::new();
    for ch in key_type.chars() {
        key_suffix.push(if ch.is_ascii_alphanumeric() { ch } else { '_' });
    }
    format!("{}Entries_{}", measure_fn_name(type_name), key_suffix)
}
/// Local shorthand that forwards to `codegen::common::split_type_params`.
///
/// NOTE(review): presumably splits `s` on `delim` only outside nested
/// brackets — confirm against `split_type_params`.
fn split_top_level(s: &str, delim: char) -> Vec<String> {
    use crate::codegen::common::split_type_params;

    split_type_params(s, delim)
}
/// Returns the text between `prefix` and a trailing `>` in `type_name`.
///
/// Yields `None` when `type_name` does not start with `prefix` or does not
/// end with `>`. The prefix is expected to include the opening bracket,
/// e.g. `"List<"`.
fn unwrap_generic<'a>(type_name: &'a str, prefix: &str) -> Option<&'a str> {
    let after_prefix = type_name.strip_prefix(prefix)?;
    after_prefix.strip_suffix('>')
}
/// Builds the text of a Lean expression that computes a `Nat` measure for
/// `value_expr`, a Lean term assumed to have the Aver type `type_name`.
///
/// Rules are tried in order; the first one that applies wins:
/// 1. `type_name` is listed in `recursive_types` → its generated measure fn.
/// 2. `List<T>` → the list measure of a recursive `T`, otherwise
///    `AverMeasure.list` over the per-item measure.
/// 3. `Option<T>` → an inline `match` when `T` is `self_type`, otherwise
///    `AverMeasure.option` over the per-item measure.
/// 4. `Map<K, V>` → the entries measure of a recursive `V`, otherwise
///    `AverMeasure.list` over the map's entries.
/// 5. `Result<T, E>` → `AverMeasure.except`.
/// 6. `( ... )` tuples → sum of the measurable components plus one.
///
/// Returns `None` when no rule applies. `self_type` is the type whose measure
/// function is currently being defined, if any; it is threaded unchanged
/// through every recursive call.
fn type_measure_expr(
type_name: &str,
value_expr: &str,
recursive_types: &HashSet<String>,
self_type: Option<&str>,
) -> Option<String> {
let trimmed = type_name.trim();
// Rule 1: a known recursive type is measured by its own generated function.
if recursive_types.contains(trimmed) {
return Some(format!("{} {}", measure_fn_name(trimmed), value_expr));
}
// Rule 2: lists.
if let Some(inner) = unwrap_generic(trimmed, "List<") {
if recursive_types.contains(inner.trim()) {
return Some(format!(
"{} {}",
measure_list_fn_name(inner.trim()),
value_expr
));
}
// Items without a measure of their own each count as 1.
let item_measure = type_measure_expr(inner, "item", recursive_types, self_type)
.unwrap_or_else(|| "1".to_string());
return Some(format!(
"AverMeasure.list (fun item => {}) {}",
item_measure, value_expr
));
}
// Rule 3: options.
if let Some(inner) = unwrap_generic(trimmed, "Option<") {
// An optional reference to the type being defined is spelled out as an
// explicit match that calls the type's measure function directly.
if self_type == Some(inner.trim()) {
return Some(format!(
"(match {} with | .none => 1 | .some item => {} item + 1)",
value_expr,
measure_fn_name(inner.trim())
));
}
let item_measure = type_measure_expr(inner, "item", recursive_types, self_type)
.unwrap_or_else(|| "1".to_string());
return Some(format!(
"AverMeasure.option (fun item => {}) {}",
item_measure, value_expr
));
}
// Rule 4: maps. Only applies when exactly two type arguments are found;
// otherwise control falls through to the remaining rules.
if let Some(inner) = unwrap_generic(trimmed, "Map<") {
let args = split_top_level(inner, ',');
if args.len() == 2 {
let key_type = args[0].trim();
let value_type = args[1].trim();
if recursive_types.contains(value_type) {
return Some(format!(
"{} (AverMap.entries {})",
measure_entries_fn_name(value_type, key_type),
value_expr
));
}
let key_measure = type_measure_expr(key_type, "entry.1", recursive_types, self_type);
let value_measure =
type_measure_expr(value_type, "entry.2", recursive_types, self_type);
// Each entry contributes whichever component measures exist, plus one.
let entry_measure = match (key_measure, value_measure) {
(Some(k), Some(v)) => format!("({k}) + ({v}) + 1"),
(Some(k), None) => format!("({k}) + 1"),
(None, Some(v)) => format!("({v}) + 1"),
(None, None) => "1".to_string(),
};
return Some(format!(
"AverMeasure.list (fun entry => {}) (AverMap.entries {})",
entry_measure, value_expr
));
}
}
// Rule 5: results. Note the emitted argument order: error handler first.
if let Some(inner) = unwrap_generic(trimmed, "Result<") {
let args = split_top_level(inner, ',');
if args.len() == 2 {
let ok_measure = type_measure_expr(&args[0], "okVal", recursive_types, self_type)
.unwrap_or_else(|| "1".to_string());
let err_measure = type_measure_expr(&args[1], "errVal", recursive_types, self_type)
.unwrap_or_else(|| "1".to_string());
return Some(format!(
"AverMeasure.except (fun errVal => {}) (fun okVal => {}) {}",
err_measure, ok_measure, value_expr
));
}
}
// Rule 6: tuples. Components are addressed with 1-based projections
// (`value.1`, `value.2`, ...); components without a measure are skipped.
if trimmed.starts_with('(') && trimmed.ends_with(')') {
let inner = &trimmed[1..trimmed.len() - 1];
let parts = split_top_level(inner, ',');
if !parts.is_empty() {
let measures: Vec<String> = parts
.iter()
.enumerate()
.filter_map(|(idx, part)| {
type_measure_expr(
part,
&format!("{}.{}", value_expr, idx + 1),
recursive_types,
self_type,
)
})
.collect();
if !measures.is_empty() {
return Some(format!("({}) + 1", measures.join(" + ")));
}
}
}
// No rule matched: the caller decides how to treat an unmeasurable type.
None
}
/// Collects the distinct key types of every `Map<K, V>` annotation in
/// `type_refs` whose value type equals `value_type`.
///
/// Annotations that are not maps, or that do not split into exactly two type
/// arguments, are ignored. Keys are returned in first-seen order.
fn recursive_map_key_types(type_refs: &[String], value_type: &str) -> Vec<String> {
    let candidate_keys = type_refs.iter().filter_map(|type_ref| {
        let inner = unwrap_generic(type_ref.trim(), "Map<")?;
        let args = split_top_level(inner, ',');
        (args.len() == 2 && args[1].trim() == value_type).then(|| args[0].trim().to_string())
    });

    let mut distinct_keys: Vec<String> = Vec::new();
    for key_type in candidate_keys {
        if !distinct_keys.contains(&key_type) {
            distinct_keys.push(key_type);
        }
    }
    distinct_keys
}
/// Emits a Lean `mutual` block defining the measure functions for the
/// recursive sum type `name`.
///
/// The block contains, in order:
/// - the measure of a single value (one `match` arm per variant; a variant
///   measures as the sum of its measurable fields plus one, or `1` when it
///   has no measurable field),
/// - the measure of a `List` of such values,
/// - one entries measure per distinct key type `K` for which some field has
///   type `Map<K, name>`.
///
/// NOTE(review): the key type is written into the Lean signature verbatim
/// rather than through `type_annotation_to_lean` — confirm that is intended
/// for compound key types.
fn emit_recursive_sum_measure(
    name: &str,
    variants: &[TypeVariant],
    recursive_types: &HashSet<String>,
) -> String {
    let measure_fn = measure_fn_name(name);
    let list_fn = measure_list_fn_name(name);

    let mut out = vec![
        "mutual".to_string(),
        format!(" def {} (value : {}) : Nat :=", measure_fn, name),
        " match value with".to_string(),
    ];

    for variant in variants {
        let ctor = to_lower_first(&variant.name);
        if variant.fields.is_empty() {
            out.push(format!(" | .{} => 1", ctor));
            continue;
        }

        // Bind every field positionally as x0, x1, ... and collect the
        // measures of the fields that have one.
        let mut binders = Vec::with_capacity(variant.fields.len());
        let mut terms = Vec::new();
        for (idx, field_ty) in variant.fields.iter().enumerate() {
            let binder = format!("x{idx}");
            if let Some(term) = type_measure_expr(field_ty, &binder, recursive_types, Some(name)) {
                terms.push(term);
            }
            binders.push(binder);
        }

        let rhs = if terms.is_empty() {
            "1".to_string()
        } else {
            format!("({}) + 1", terms.join(" + "))
        };
        out.push(format!(" | .{} {} => {}", ctor, binders.join(" "), rhs));
    }

    out.push(format!(" def {} (items : List {}) : Nat :=", list_fn, name));
    out.push(" match items with".to_string());
    out.push(" | [] => 1".to_string());
    out.push(format!(
        " | head :: tail => {} head + {} tail + 1",
        measure_fn, list_fn
    ));

    let field_types: Vec<String> = variants
        .iter()
        .flat_map(|variant| variant.fields.iter().cloned())
        .collect();
    for key_type in recursive_map_key_types(&field_types, name) {
        let entries_fn = measure_entries_fn_name(name, &key_type);
        out.extend([
            format!(
                " def {} (items : List ({} × {})) : Nat :=",
                entries_fn, key_type, name
            ),
            " match items with".to_string(),
            " | [] => 1".to_string(),
            format!(
                " | (_, value) :: tail => {} value + {} tail + 1",
                measure_fn, entries_fn
            ),
        ]);
    }

    out.push("end".to_string());
    out.join("\n")
}
/// Emits a Lean `mutual` block defining the measure functions for the
/// recursive product type `name`.
///
/// The value measure is the sum of the measures of all measurable fields
/// (accessed as `value.<field>`) plus one, or `1` when no field is
/// measurable. It is followed by the `List` measure and by one entries
/// measure per distinct key type `K` for which some field has type
/// `Map<K, name>`.
fn emit_recursive_product_measure(
    name: &str,
    fields: &[(String, String)],
    recursive_types: &HashSet<String>,
) -> String {
    let measure_fn = measure_fn_name(name);
    let list_fn = measure_list_fn_name(name);

    let mut terms: Vec<String> = Vec::new();
    for (field_name, field_ty) in fields {
        let accessor = format!("value.{}", aver_name_to_lean(field_name));
        if let Some(term) = type_measure_expr(field_ty, &accessor, recursive_types, Some(name)) {
            terms.push(term);
        }
    }
    let body = if terms.is_empty() {
        "1".to_string()
    } else {
        format!("({}) + 1", terms.join(" + "))
    };

    let mut out = Vec::new();
    out.push("mutual".to_string());
    out.push(format!(" def {} (value : {}) : Nat :=", measure_fn, name));
    out.push(format!(" {}", body));
    out.push(format!(" def {} (items : List {}) : Nat :=", list_fn, name));
    out.push(" match items with".to_string());
    out.push(" | [] => 1".to_string());
    out.push(format!(
        " | head :: tail => {} head + {} tail + 1",
        measure_fn, list_fn
    ));

    let field_types: Vec<String> = fields.iter().map(|field| field.1.clone()).collect();
    for key_type in recursive_map_key_types(&field_types, name) {
        let entries_fn = measure_entries_fn_name(name, &key_type);
        out.extend([
            format!(
                " def {} (items : List ({} × {})) : Nat :=",
                entries_fn, key_type, name
            ),
            " match items with".to_string(),
            " | [] => 1".to_string(),
            format!(
                " | (_, value) :: tail => {} value + {} tail + 1",
                measure_fn, entries_fn
            ),
        ]);
    }

    out.push("end".to_string());
    out.join("\n")
}
pub fn emit_recursive_measure(td: &TypeDef, recursive_types: &HashSet<String>) -> Option<String> {
match td {
TypeDef::Sum { name, variants, .. } if is_recursive_type(name, variants) => {
Some(emit_recursive_sum_measure(name, variants, recursive_types))
}
TypeDef::Product { name, fields, .. } if is_recursive_product(name, fields) => Some(
emit_recursive_product_measure(name, fields, recursive_types),
),
_ => None,
}
}
/// Reports whether a type definition refers to itself in any of its fields,
/// using `is_recursive_product` or `is_recursive_type` as appropriate.
pub fn is_recursive_type_def(td: &TypeDef) -> bool {
    match td {
        TypeDef::Product { name, fields, .. } => is_recursive_product(name, fields),
        TypeDef::Sum { name, variants, .. } => is_recursive_type(name, variants),
    }
}
/// Returns the declared name of a type definition, whichever kind it is.
pub fn type_def_name(td: &TypeDef) -> &str {
    match td {
        TypeDef::Sum { name, .. } | TypeDef::Product { name, .. } => name,
    }
}
/// Emits Lean source that provides a `DecidableEq` instance for the type
/// `name` without deriving it.
///
/// The generated text declares an `unsafe` decision procedure built on `==`,
/// an `opaque` constant implemented by it via `@[implemented_by]`, and an
/// `instance` that exposes the opaque constant as `DecidableEq name`.
pub fn emit_recursive_decidable_eq(name: &str) -> String {
    [
        format!("private unsafe def {name}.unsafeDecEq (a b : {name}) : Decidable (a = b) :="),
        " if a == b then isTrue (unsafeCast ()) else isFalse (unsafeCast ())".to_string(),
        format!("@[implemented_by {name}.unsafeDecEq]"),
        format!("private opaque {name}.compDecEq (a b : {name}) : Decidable (a = b)"),
        format!("instance : DecidableEq {name} := {name}.compDecEq"),
    ]
    .join("\n")
}
// Name of the Lean binder that holds the remaining fuel inside a fuel helper
// (bound by the `| fuel' + 1 =>` arm). Despite the name, every fuelized
// emitter in this file passes it as the fuel variable, not only the
// string-position ones.
const STRING_POS_FUEL_VAR: &str = "fuel'";
// Lean expression emitted for the `| 0 =>` arm of a fuel helper, i.e. the
// result when the fuel argument has run out.
const PROOF_FUEL_EXHAUSTED: &str = "panic! \"Aver proof fuel exhausted\"";
/// Name of the fuel-taking helper generated for the Aver function `name`:
/// the Lean-converted function name with a `__fuel` suffix.
fn fuel_helper_name(name: &str) -> String {
    let mut helper = aver_name_to_lean(name);
    helper.push_str("__fuel");
    helper
}
/// Renders the parameter names (types dropped) as a space-separated list of
/// Lean identifiers, suitable for forwarding all arguments in a call.
fn emit_fn_param_names(params: &[(String, String)]) -> String {
    let lean_names: Vec<String> = params
        .iter()
        .map(|param| aver_name_to_lean(&param.0))
        .collect();
    lean_names.join(" ")
}
/// Splits `block` into lines and returns each one with `prefix` prepended.
///
/// Line terminators are dropped by `str::lines`, so an empty `block` yields
/// an empty vector and a trailing newline does not add an extra entry.
fn indent_lines(block: &str, prefix: &str) -> Vec<String> {
    let mut indented = Vec::new();
    for line in block.lines() {
        let mut text = String::with_capacity(prefix.len() + line.len());
        text.push_str(prefix);
        text.push_str(line);
        indented.push(text);
    }
    indented
}
/// Wraps an optional description in a Lean doc comment (`/-- ... -/`).
///
/// Returns a single-element vector when a description is present and an
/// empty vector otherwise, so the result can be spliced into a line list.
fn emit_doc_comment(desc: &Option<String>) -> Vec<String> {
    match desc {
        Some(text) => vec![format!("/-- {} -/", text)],
        None => Vec::new(),
    }
}
/// Lean return type for `fd`: `Unit` when no return type is declared,
/// otherwise the declared annotation converted by `type_annotation_to_lean`.
fn ret_type_or_unit(fd: &FnDef) -> String {
    if !fd.return_type.is_empty() {
        return type_annotation_to_lean(&fd.return_type);
    }
    String::from("Unit")
}
/// Emits the lines of a fuel-taking Lean helper definition.
///
/// The helper takes an extra leading `fuel : Nat` argument and matches on it:
/// the zero arm yields `PROOF_FUEL_EXHAUSTED`, the successor arm binds the
/// remaining fuel as `STRING_POS_FUEL_VAR` and continues with `body`, whose
/// lines are re-indented beneath the arm. `outer_indent` is prepended to
/// every emitted line so the definition can be nested (e.g. inside `mutual`).
fn emit_fuel_helper_def(
    helper_name: &str,
    params: &str,
    ret_type: &str,
    body: &str,
    outer_indent: &str,
) -> Vec<String> {
    let branch_indent = format!("{outer_indent} ");

    let mut out = vec![
        format!(
            "{outer_indent}def {} (fuel : Nat) {} : {} :=",
            helper_name, params, ret_type
        ),
        format!("{outer_indent} match fuel with"),
        format!("{outer_indent} | 0 => {}", PROOF_FUEL_EXHAUSTED),
        format!("{outer_indent} | {} + 1 =>", STRING_POS_FUEL_VAR),
    ];
    out.extend(indent_lines(body, &branch_indent));
    out
}
/// Emits the public wrapper for a string-position-advancing function.
///
/// The wrapper keeps the original signature and calls `helper_name` with an
/// initial fuel of `averStringPosFuel <first param> <second param>
/// <rank_budget>`, followed by all original arguments.
///
/// # Panics
/// Panics if `fd` has fewer than two parameters.
fn emit_string_pos_wrapper(fd: &FnDef, helper_name: &str, rank_budget: usize) -> Vec<String> {
    let s_name = aver_name_to_lean(&fd.params[0].0);
    let pos_name = aver_name_to_lean(&fd.params[1].0);

    let signature = format!(
        "def {} {} : {} :=",
        aver_name_to_lean(&fd.name),
        emit_fn_params(&fd.params),
        ret_type_or_unit(fd)
    );
    let call = format!(
        " {} (averStringPosFuel {} {} {}) {}",
        helper_name,
        s_name,
        pos_name,
        rank_budget,
        emit_fn_param_names(&fd.params)
    );
    vec![signature, call]
}
/// Emits the public wrapper for an integer-countdown function.
///
/// The wrapper keeps the original signature and calls `helper_name` with an
/// initial fuel of `(Int.natAbs <param>) + 1`, where `<param>` is the
/// parameter at `param_index`; a literal `0` is used when that index is out
/// of range.
fn emit_int_countdown_wrapper(fd: &FnDef, helper_name: &str, param_index: usize) -> Vec<String> {
    let metric_name = match fd.params.get(param_index) {
        Some((name, _)) => aver_name_to_lean(name),
        None => "0".to_string(),
    };

    let signature = format!(
        "def {} {} : {} :=",
        aver_name_to_lean(&fd.name),
        emit_fn_params(&fd.params),
        ret_type_or_unit(fd)
    );
    let call = format!(
        " {} ((Int.natAbs {}) + 1) {}",
        helper_name,
        metric_name,
        emit_fn_param_names(&fd.params)
    );
    vec![signature, call]
}
/// Emits a second-order linear recurrence as a pair of Lean definitions.
///
/// A private helper over `Nat` is defined by pattern matching on `0`, `1`
/// and `n + 2`, with the step expression produced by
/// `render_affine_pair_expr`. The public `Int` function returns the shape's
/// negative-branch expression for negative input and otherwise calls the
/// helper on `.toNat` of the argument.
fn emit_nat_linear_recurrence_fn(
    fd: &FnDef,
    shape: &super::recurrence::SecondOrderIntLinearRecurrenceShape,
    ctx: &CodegenContext,
) -> String {
    let fn_name = aver_name_to_lean(&fd.name);
    let nat_helper_name = recurrence_nat_helper_name(&fd.name);
    let lean_param = aver_name_to_lean(&shape.param_name);
    let ret_type = ret_type_or_unit(fd);
    let nat_step = render_affine_pair_expr(
        shape.recurrence,
        &format!("{nat_helper_name} n"),
        &format!("{nat_helper_name} (n + 1)"),
    );

    let mut out = emit_doc_comment(&fd.desc);
    out.push(format!(
        "private def {} : Nat -> {}",
        nat_helper_name, ret_type
    ));
    out.push(format!(" | 0 => {}", emit_expr(&shape.base0, ctx)));
    out.push(format!(" | 1 => {}", emit_expr(&shape.base1, ctx)));
    out.push(format!(" | n + 2 => {}", nat_step));
    out.push(String::new());
    out.push(format!(
        "def {} ({} : Int) : {} :=",
        fn_name, lean_param, ret_type
    ));
    out.push(format!(
        " if {} < 0 then {} else {} {}.toNat",
        lean_param,
        emit_expr(&shape.negative_branch, ctx),
        nat_helper_name,
        lean_param
    ));
    out.join("\n")
}
/// Builds the Lean expression for the structural-size measure of `fd`.
///
/// For every parameter index reported by `sizeof_measure_param_indices`, the
/// parameter's measure expression (if it has one) is collected; the terms are
/// joined with ` + `. Returns `None` when no parameter contributes a term.
fn emit_sizeof_measure_expr(fd: &FnDef, recursive_types: &HashSet<String>) -> Option<String> {
    let mut terms: Vec<String> = Vec::new();
    for idx in sizeof_measure_param_indices(fd) {
        let Some((name, type_name)) = fd.params.get(idx) else {
            continue;
        };
        let lean_name = aver_name_to_lean(name);
        if let Some(term) = type_measure_expr(type_name, &lean_name, recursive_types, None) {
            terms.push(term);
        }
    }

    if terms.is_empty() {
        None
    } else {
        Some(terms.join(" + "))
    }
}
/// Emits the public wrapper for a function whose fuel is derived from the
/// structural size of its arguments.
///
/// The initial fuel is `((measure) + 1) * rank_budget` when a size measure
/// exists for `fd`, and just `rank_budget` otherwise. The wrapper keeps the
/// original signature and forwards all arguments to `helper_name`.
fn emit_mutual_sizeof_wrapper(
    fd: &FnDef,
    helper_name: &str,
    rank_budget: usize,
    recursive_types: &HashSet<String>,
) -> Vec<String> {
    let fuel_expr = match emit_sizeof_measure_expr(fd, recursive_types) {
        Some(measure) => format!("(({}) + 1) * {}", measure, rank_budget),
        None => rank_budget.to_string(),
    };

    let signature = format!(
        "def {} {} : {} :=",
        aver_name_to_lean(&fd.name),
        emit_fn_params(&fd.params),
        ret_type_or_unit(fd)
    );
    let call = format!(
        " {} ({}) {}",
        helper_name,
        fuel_expr,
        emit_fn_param_names(&fd.params)
    );
    vec![signature, call]
}
/// Returns a copy of `expr` in which every call to a function named in
/// `targets` is redirected to that function's fuel helper.
///
/// A redirected call becomes `FnCall(<fuel helper name>, [fuel_var, args...])`,
/// i.e. the identifier `fuel_var` is inserted as the first argument. Both
/// ordinary calls (matched by the callee's dotted name) and tail calls
/// (matched by their `target`) are handled. All other expression kinds are
/// rebuilt with their sub-expressions rewritten recursively; literals,
/// identifiers and resolved references are cloned unchanged. Every rebuilt
/// node keeps the source line of the node it replaces.
fn rewrite_recursive_calls_expr(
expr: &Spanned<Expr>,
targets: &HashSet<String>,
fuel_var: &str,
) -> Spanned<Expr> {
let line = expr.line;
let new_node = match &expr.node {
// Leaves contain no calls, so they are returned as-is.
Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => return expr.clone(),
Expr::Attr(obj, field) => Expr::Attr(
Box::new(rewrite_recursive_calls_expr(obj, targets, fuel_var)),
field.clone(),
),
Expr::FnCall(callee, args) => {
// Arguments are rewritten first so nested target calls are redirected too.
let rewritten_args: Vec<Spanned<Expr>> = args
.iter()
.map(|arg| rewrite_recursive_calls_expr(arg, targets, fuel_var))
.collect();
if let Some(name) = expr_to_dotted_name(&callee.node)
&& targets.contains(&name)
{
// Target call: swap the callee for the fuel helper and prepend the fuel.
let mut call_args = Vec::with_capacity(rewritten_args.len() + 1);
call_args.push(Spanned::new(Expr::Ident(fuel_var.to_string()), line));
call_args.extend(rewritten_args);
Expr::FnCall(
Box::new(Spanned::new(Expr::Ident(fuel_helper_name(&name)), line)),
call_args,
)
} else {
// Any other call: keep it, but still rewrite inside the callee expression.
Expr::FnCall(
Box::new(rewrite_recursive_calls_expr(callee, targets, fuel_var)),
rewritten_args,
)
}
}
Expr::BinOp(op, left, right) => Expr::BinOp(
*op,
Box::new(rewrite_recursive_calls_expr(left, targets, fuel_var)),
Box::new(rewrite_recursive_calls_expr(right, targets, fuel_var)),
),
// Patterns are copied verbatim; only the subject and arm bodies can hold calls.
Expr::Match { subject, arms } => Expr::Match {
subject: Box::new(rewrite_recursive_calls_expr(subject, targets, fuel_var)),
arms: arms
.iter()
.map(|arm| MatchArm {
pattern: arm.pattern.clone(),
body: Box::new(rewrite_recursive_calls_expr(&arm.body, targets, fuel_var)),
})
.collect(),
},
Expr::Constructor(name, arg) => Expr::Constructor(
name.clone(),
arg.as_ref()
.map(|inner| Box::new(rewrite_recursive_calls_expr(inner, targets, fuel_var))),
),
Expr::ErrorProp(inner) => Expr::ErrorProp(Box::new(rewrite_recursive_calls_expr(
inner, targets, fuel_var,
))),
// Only the parsed (expression) parts of an interpolated string are rewritten.
Expr::InterpolatedStr(parts) => Expr::InterpolatedStr(
parts
.iter()
.map(|part| match part {
StrPart::Literal(_) => part.clone(),
StrPart::Parsed(inner) => StrPart::Parsed(Box::new(
rewrite_recursive_calls_expr(inner, targets, fuel_var),
)),
})
.collect(),
),
Expr::List(items) => Expr::List(
items
.iter()
.map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
.collect(),
),
Expr::Tuple(items) => Expr::Tuple(
items
.iter()
.map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
.collect(),
),
Expr::IndependentProduct(items, flag) => Expr::IndependentProduct(
items
.iter()
.map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
.collect(),
*flag,
),
Expr::MapLiteral(entries) => Expr::MapLiteral(
entries
.iter()
.map(|(k, v)| {
(
rewrite_recursive_calls_expr(k, targets, fuel_var),
rewrite_recursive_calls_expr(v, targets, fuel_var),
)
})
.collect(),
),
Expr::RecordCreate { type_name, fields } => Expr::RecordCreate {
type_name: type_name.clone(),
fields: fields
.iter()
.map(|(name, value)| {
(
name.clone(),
rewrite_recursive_calls_expr(value, targets, fuel_var),
)
})
.collect(),
},
Expr::RecordUpdate {
type_name,
base,
updates,
} => Expr::RecordUpdate {
type_name: type_name.clone(),
base: Box::new(rewrite_recursive_calls_expr(base, targets, fuel_var)),
updates: updates
.iter()
.map(|(name, value)| {
(
name.clone(),
rewrite_recursive_calls_expr(value, targets, fuel_var),
)
})
.collect(),
},
Expr::TailCall(boxed) => {
let TailCallData { target, args, .. } = boxed.as_ref();
let rewritten_args: Vec<Spanned<Expr>> = args
.iter()
.map(|arg| rewrite_recursive_calls_expr(arg, targets, fuel_var))
.collect();
if targets.contains(target) {
// A tail call to a target is lowered to an ordinary call of the fuel helper.
let mut call_args = Vec::with_capacity(rewritten_args.len() + 1);
call_args.push(Spanned::new(Expr::Ident(fuel_var.to_string()), line));
call_args.extend(rewritten_args);
Expr::FnCall(
Box::new(Spanned::new(Expr::Ident(fuel_helper_name(target)), line)),
call_args,
)
} else {
// NOTE(review): the tail call is rebuilt via `TailCallData::new`, so any
// fields skipped by the `..` pattern above are not carried over — confirm
// that `new` restores suitable values for them.
Expr::TailCall(Box::new(TailCallData::new(target.clone(), rewritten_args)))
}
}
};
Spanned::new(new_node, line)
}
fn rewrite_recursive_calls_body(
body: &FnBody,
targets: &HashSet<String>,
fuel_var: &str,
) -> FnBody {
ast::FnBody::Block(
body.stmts()
.iter()
.map(|stmt| match stmt {
Stmt::Binding(name, ty, expr) => Stmt::Binding(
name.clone(),
ty.clone(),
rewrite_recursive_calls_expr(expr, targets, fuel_var),
),
Stmt::Expr(expr) => {
Stmt::Expr(rewrite_recursive_calls_expr(expr, targets, fuel_var))
}
})
.collect(),
)
}
/// Emits a self-recursive string-position function in fuelized form.
///
/// Output, in order: the optional doc comment, a fuel helper whose recursive
/// self-calls have been redirected to the helper itself (passing the
/// remaining fuel), a blank line, and a wrapper with the original signature
/// that supplies the initial fuel with a rank budget of 1.
fn emit_fuelized_string_pos_fn(fd: &FnDef, ctx: &CodegenContext) -> String {
    let helper_name = fuel_helper_name(&fd.name);
    let params = emit_fn_params(&fd.params);
    let ret_type = ret_type_or_unit(fd);
    // Only the function itself is a rewrite target: this is the non-mutual case.
    let rewritten = rewrite_recursive_calls_body(
        &fd.body,
        &HashSet::from([fd.name.clone()]),
        STRING_POS_FUEL_VAR,
    );
    let body = emit_fn_body_for(fd, &rewritten, ctx);
    [
        emit_doc_comment(&fd.desc),
        emit_fuel_helper_def(&helper_name, &params, &ret_type, &body, ""),
        vec![String::new()],
        emit_string_pos_wrapper(fd, &helper_name, 1),
    ]
    .into_iter()
    .flatten()
    .collect::<Vec<_>>()
    .join("\n")
}
/// Removes equation binders from `match` headers in emitted Lean text.
///
/// Any line whose content (after leading whitespace) starts with `match h_`
/// and contains ` : ` is rewritten from `match h_<name> : <rest>` to
/// `match <rest>`, keeping its indentation. All other lines are left
/// untouched. Lines are re-joined with `\n`, so a trailing newline in the
/// input is not preserved.
fn strip_match_eq_binders(body: String) -> String {
    let mut out: Vec<String> = Vec::new();
    for line in body.lines() {
        let indent_len = line.len() - line.trim_start().len();
        let (indent, content) = line.split_at(indent_len);
        let rewritten = content
            .strip_prefix("match h_")
            .and_then(|rest| rest.split_once(" : "))
            .map(|(_, scrutinee)| format!("{indent}match {scrutinee}"));
        out.push(rewritten.unwrap_or_else(|| line.to_string()));
    }
    out.join("\n")
}
/// Emits a self-recursive integer-countdown function in fuelized form.
///
/// Output, in order: the optional doc comment, a fuel helper whose recursive
/// self-calls have been redirected to the helper itself, a blank line, and a
/// wrapper with the original signature whose initial fuel is derived from the
/// parameter at `param_index`. Equation binders on `match` headers are
/// stripped from the emitted helper body.
fn emit_fuelized_int_countdown_fn(fd: &FnDef, ctx: &CodegenContext, param_index: usize) -> String {
    let helper_name = fuel_helper_name(&fd.name);
    let params = emit_fn_params(&fd.params);
    let ret_type = ret_type_or_unit(fd);
    // Only the function itself is a rewrite target: this is the non-mutual case.
    let rewritten = rewrite_recursive_calls_body(
        &fd.body,
        &HashSet::from([fd.name.clone()]),
        STRING_POS_FUEL_VAR,
    );
    let body = strip_match_eq_binders(emit_fn_body_for(fd, &rewritten, ctx));
    [
        emit_doc_comment(&fd.desc),
        emit_fuel_helper_def(&helper_name, &params, &ret_type, &body, ""),
        vec![String::new()],
        emit_int_countdown_wrapper(fd, &helper_name, param_index),
    ]
    .into_iter()
    .flatten()
    .collect::<Vec<_>>()
    .join("\n")
}
/// Emits a self-recursive integer-ascending function in fuelized form.
///
/// Output, in order: the optional doc comment, a fuel helper whose recursive
/// self-calls have been redirected to the helper itself, a blank line, and a
/// wrapper with the original signature whose initial fuel is derived from the
/// distance between `bound_lean` (a Lean expression) and the parameter at
/// `param_index`. Equation binders on `match` headers are stripped from the
/// emitted helper body.
fn emit_fuelized_int_ascending_fn(
    fd: &FnDef,
    ctx: &CodegenContext,
    param_index: usize,
    bound_lean: &str,
) -> String {
    let helper_name = fuel_helper_name(&fd.name);
    let params = emit_fn_params(&fd.params);
    let ret_type = ret_type_or_unit(fd);
    // Only the function itself is a rewrite target: this is the non-mutual case.
    let rewritten = rewrite_recursive_calls_body(
        &fd.body,
        &HashSet::from([fd.name.clone()]),
        STRING_POS_FUEL_VAR,
    );
    let body = strip_match_eq_binders(emit_fn_body_for(fd, &rewritten, ctx));
    [
        emit_doc_comment(&fd.desc),
        emit_fuel_helper_def(&helper_name, &params, &ret_type, &body, ""),
        vec![String::new()],
        emit_int_ascending_wrapper(fd, &helper_name, param_index, bound_lean),
    ]
    .into_iter()
    .flatten()
    .collect::<Vec<_>>()
    .join("\n")
}
fn emit_int_ascending_wrapper(
fd: &FnDef,
helper_name: &str,
param_index: usize,
bound_lean: &str,
) -> Vec<String> {
let fn_name = super::expr::aver_name_to_lean(&fd.name);
let params = emit_fn_params(&fd.params);
let ret_type = ret_type_or_unit(fd);
let arg_names = emit_fn_param_names(&fd.params);
let metric_name = fd
.params
.get(param_index)
.map(|(name, _)| super::expr::aver_name_to_lean(name))
.unwrap_or_else(|| "0".to_string());
vec![
format!("def {} {} : {} :=", fn_name, params, ret_type),
format!(
" {} ((Int.natAbs ({} - {})) + 1) {}",
helper_name, bound_lean, metric_name, arg_names
),
]
}
/// Emits a self-recursive, structurally-decreasing function in fuelized form.
///
/// Output, in order: the optional doc comment, a fuel helper whose recursive
/// self-calls have been redirected to the helper itself, a blank line, and a
/// wrapper with the original signature whose initial fuel is derived from the
/// size measure of the arguments (rank budget 1).
fn emit_fuelized_sizeof_fn(fd: &FnDef, ctx: &CodegenContext) -> String {
    let helper_name = fuel_helper_name(&fd.name);
    let params = emit_fn_params(&fd.params);
    let ret_type = ret_type_or_unit(fd);
    // Names of all recursive types, from dependency modules and the current
    // context alike; the wrapper needs them to pick measure functions.
    let recursive_types: HashSet<String> = ctx
        .modules
        .iter()
        .flat_map(|m| m.type_defs.iter())
        .chain(ctx.type_defs.iter())
        .filter(|td| is_recursive_type_def(td))
        .map(|td| type_def_name(td).to_string())
        .collect();
    // Only the function itself is a rewrite target: this is the non-mutual case.
    let rewritten = rewrite_recursive_calls_body(
        &fd.body,
        &HashSet::from([fd.name.clone()]),
        STRING_POS_FUEL_VAR,
    );
    let body = emit_fn_body_for(fd, &rewritten, ctx);
    [
        emit_doc_comment(&fd.desc),
        emit_fuel_helper_def(&helper_name, &params, &ret_type, &body, ""),
        vec![String::new()],
        emit_mutual_sizeof_wrapper(fd, &helper_name, 1, &recursive_types),
    ]
    .into_iter()
    .flatten()
    .collect::<Vec<_>>()
    .join("\n")
}
/// Emits a group of mutually recursive string-position functions in
/// fuelized form.
///
/// The fuel helpers of all pure functions in `fns` are placed in a single
/// Lean `mutual ... end` block, with calls to any member of the group
/// redirected to that member's helper. After the block, one wrapper per pure
/// function restores the original signature. Every wrapper uses the same
/// rank budget: the largest `MutualStringPosAdvance` rank found in `plans`
/// for the group, or 1 when none is recorded.
fn emit_fuelized_mutual_string_pos_group(
    fns: &[&FnDef],
    ctx: &CodegenContext,
    plans: &HashMap<String, RecursionPlan>,
) -> String {
    // Every member counts as a rewrite target, including impure ones that are
    // skipped below.
    let targets: HashSet<String> = fns.iter().map(|fd| fd.name.clone()).collect();
    let max_rank = fns
        .iter()
        .filter_map(|fd| match plans.get(&fd.name) {
            Some(RecursionPlan::MutualStringPosAdvance { rank }) => Some(*rank),
            _ => None,
        })
        .max()
        .unwrap_or(1);
    let mut helper_lines = vec!["mutual".to_string()];
    for fd in fns {
        if !is_pure_fn(fd) {
            continue;
        }
        let helper_name = fuel_helper_name(&fd.name);
        let params = emit_fn_params(&fd.params);
        let ret_type = ret_type_or_unit(fd);
        let rewritten = rewrite_recursive_calls_body(&fd.body, &targets, STRING_POS_FUEL_VAR);
        let body = emit_fn_body_for(fd, &rewritten, ctx);
        // Doc comments are indented to sit inside the `mutual` block.
        helper_lines.extend(
            emit_doc_comment(&fd.desc)
                .into_iter()
                .map(|line| format!(" {line}")),
        );
        helper_lines.extend(emit_fuel_helper_def(
            &helper_name,
            &params,
            &ret_type,
            &body,
            " ",
        ));
        helper_lines.push(String::new());
    }
    helper_lines.push("end".to_string());
    let wrapper_lines: Vec<String> = fns
        .iter()
        .filter(|fd| is_pure_fn(fd))
        .flat_map(|fd| {
            let helper_name = fuel_helper_name(&fd.name);
            let mut lines = emit_string_pos_wrapper(fd, &helper_name, max_rank);
            lines.push(String::new());
            lines
        })
        .collect();
    [helper_lines, vec![String::new()], wrapper_lines]
        .into_iter()
        .flatten()
        .collect::<Vec<_>>()
        .join("\n")
}
/// Emits a group of mutually recursive integer-countdown functions in
/// fuelized form.
///
/// The fuel helpers of all pure functions in `fns` are placed in a single
/// Lean `mutual ... end` block, with calls to any member of the group
/// redirected to that member's helper and `match` equation binders stripped.
/// After the block, one wrapper per pure function restores the original
/// signature; each wrapper derives its initial fuel from the function's
/// first parameter (index 0).
fn emit_fuelized_mutual_int_countdown_group(fns: &[&FnDef], ctx: &CodegenContext) -> String {
    // Every member counts as a rewrite target, including impure ones that are
    // skipped below.
    let targets: HashSet<String> = fns.iter().map(|fd| fd.name.clone()).collect();
    let mut helper_lines = vec!["mutual".to_string()];
    for fd in fns {
        if !is_pure_fn(fd) {
            continue;
        }
        let helper_name = fuel_helper_name(&fd.name);
        let params = emit_fn_params(&fd.params);
        let ret_type = ret_type_or_unit(fd);
        let rewritten = rewrite_recursive_calls_body(&fd.body, &targets, STRING_POS_FUEL_VAR);
        let body = strip_match_eq_binders(emit_fn_body_for(fd, &rewritten, ctx));
        // Doc comments are indented to sit inside the `mutual` block.
        helper_lines.extend(
            emit_doc_comment(&fd.desc)
                .into_iter()
                .map(|line| format!(" {line}")),
        );
        helper_lines.extend(emit_fuel_helper_def(
            &helper_name,
            &params,
            &ret_type,
            &body,
            " ",
        ));
        helper_lines.push(String::new());
    }
    helper_lines.push("end".to_string());
    let wrapper_lines: Vec<String> = fns
        .iter()
        .filter(|fd| is_pure_fn(fd))
        .flat_map(|fd| {
            let helper_name = fuel_helper_name(&fd.name);
            let mut lines = emit_int_countdown_wrapper(fd, &helper_name, 0);
            lines.push(String::new());
            lines
        })
        .collect();
    [helper_lines, vec![String::new()], wrapper_lines]
        .into_iter()
        .flatten()
        .collect::<Vec<_>>()
        .join("\n")
}
/// Emits a group of mutually recursive, structurally-decreasing functions in
/// fuelized form.
///
/// The fuel helpers of all pure functions in `fns` are placed in a single
/// Lean `mutual ... end` block, with calls to any member of the group
/// redirected to that member's helper. After the block, one wrapper per pure
/// function restores the original signature. Every wrapper uses the same
/// rank budget: one more than the largest `MutualSizeOfRanked` rank found in
/// `plans` for the group (the rank defaults to 1 when none is recorded).
fn emit_fuelized_mutual_sizeof_group(
    fns: &[&FnDef],
    ctx: &CodegenContext,
    plans: &HashMap<String, RecursionPlan>,
) -> String {
    // Every member counts as a rewrite target, including impure ones that are
    // skipped below.
    let targets: HashSet<String> = fns.iter().map(|fd| fd.name.clone()).collect();
    // Names of all recursive types, from dependency modules and the current
    // context alike; the wrappers need them to pick measure functions.
    let recursive_types: HashSet<String> = ctx
        .modules
        .iter()
        .flat_map(|m| m.type_defs.iter())
        .chain(ctx.type_defs.iter())
        .filter(|td| is_recursive_type_def(td))
        .map(|td| type_def_name(td).to_string())
        .collect();
    let rank_budget = fns
        .iter()
        .filter_map(|fd| match plans.get(&fd.name) {
            Some(RecursionPlan::MutualSizeOfRanked { rank }) => Some(*rank),
            _ => None,
        })
        .max()
        .unwrap_or(1)
        + 1;
    let mut helper_lines = vec!["mutual".to_string()];
    for fd in fns {
        if !is_pure_fn(fd) {
            continue;
        }
        let helper_name = fuel_helper_name(&fd.name);
        let params = emit_fn_params(&fd.params);
        let ret_type = ret_type_or_unit(fd);
        let rewritten = rewrite_recursive_calls_body(&fd.body, &targets, STRING_POS_FUEL_VAR);
        let body = emit_fn_body_for(fd, &rewritten, ctx);
        // Doc comments are indented to sit inside the `mutual` block.
        helper_lines.extend(
            emit_doc_comment(&fd.desc)
                .into_iter()
                .map(|line| format!(" {line}")),
        );
        helper_lines.extend(emit_fuel_helper_def(
            &helper_name,
            &params,
            &ret_type,
            &body,
            " ",
        ));
        helper_lines.push(String::new());
    }
    helper_lines.push("end".to_string());
    let wrapper_lines: Vec<String> = fns
        .iter()
        .filter(|fd| is_pure_fn(fd))
        .flat_map(|fd| {
            let helper_name = fuel_helper_name(&fd.name);
            let mut lines =
                emit_mutual_sizeof_wrapper(fd, &helper_name, rank_budget, &recursive_types);
            lines.push(String::new());
            lines
        })
        .collect();
    [helper_lines, vec![String::new()], wrapper_lines]
        .into_iter()
        .flatten()
        .collect::<Vec<_>>()
        .join("\n")
}
/// A function is treated as pure when it declares no effects and is not the
/// program entry point `main`.
pub fn is_pure_fn(fd: &FnDef) -> bool {
    let is_entry_point = fd.name == "main";
    !is_entry_point && fd.effects.is_empty()
}
pub fn emit_fn_def(
fd: &FnDef,
recursive_fns: &HashSet<String>,
ctx: &CodegenContext,
) -> Option<String> {
if !is_pure_fn(fd) {
return None;
}
let mut lines = Vec::new();
if let Some(desc) = &fd.desc {
lines.push(format!("/-- {} -/", desc));
}
let is_recursive = recursive_fns.contains(&fd.name);
let fn_name = aver_name_to_lean(&fd.name);
let params = emit_fn_params(&fd.params);
let ret_type = if fd.return_type.is_empty() {
"Unit".to_string()
} else {
type_annotation_to_lean(&fd.return_type)
};
let prefix = if is_recursive { "partial " } else { "" };
lines.push(format!(
"{}def {} {} : {} :=",
prefix, fn_name, params, ret_type
));
lines.push(emit_fn_body_for(fd, &fd.body, ctx));
Some(lines.join("\n"))
}
pub fn emit_fn_def_proof(
fd: &FnDef,
recursion_plan: Option<RecursionPlan>,
ctx: &CodegenContext,
) -> Option<String> {
if !is_pure_fn(fd) {
return None;
}
if matches!(recursion_plan, Some(RecursionPlan::LinearRecurrence2))
&& let Some(shape) = detect_second_order_int_linear_recurrence(fd)
{
return Some(emit_nat_linear_recurrence_fn(fd, &shape, ctx));
}
if let Some(RecursionPlan::IntCountdown { param_index }) = recursion_plan {
return Some(emit_fuelized_int_countdown_fn(fd, ctx, param_index));
}
if let Some(RecursionPlan::IntAscending {
param_index,
ref bound_lean,
}) = recursion_plan
{
return Some(emit_fuelized_int_ascending_fn(
fd,
ctx,
param_index,
bound_lean,
));
}
if matches!(recursion_plan, Some(RecursionPlan::SizeOfStructural)) {
return Some(emit_fuelized_sizeof_fn(fd, ctx));
}
if matches!(recursion_plan, Some(RecursionPlan::StringPosAdvance)) {
return Some(emit_fuelized_string_pos_fn(fd, ctx));
}
let mut lines = Vec::new();
if let Some(desc) = &fd.desc {
lines.push(format!("/-- {} -/", desc));
}
let fn_name = aver_name_to_lean(&fd.name);
let params = emit_fn_params(&fd.params);
let ret_type = if fd.return_type.is_empty() {
"Unit".to_string()
} else {
type_annotation_to_lean(&fd.return_type)
};
lines.push(format!("def {} {} : {} :=", fn_name, params, ret_type));
lines.push(emit_fn_body_for(fd, &fd.body, ctx));
if let Some(plan) = recursion_plan {
match plan {
RecursionPlan::LinearRecurrence2 => {}
RecursionPlan::IntCountdown { .. } => {}
RecursionPlan::IntAscending { .. } => {}
RecursionPlan::MutualIntCountdown => {
let Some((param_name, _)) = fd.params.first() else {
return Some(lines.join("\n"));
};
let lean_param = aver_name_to_lean(param_name);
lines.push(format!("termination_by Int.natAbs {}", lean_param));
lines.push("decreasing_by".to_string());
lines.push(" omega".to_string());
}
RecursionPlan::ListStructural { param_index } => {
let Some((param_name, _)) = fd.params.get(param_index) else {
return Some(lines.join("\n"));
};
let lean_param = aver_name_to_lean(param_name);
lines.push(format!("termination_by {}.length", lean_param));
lines.push("decreasing_by".to_string());
lines.push(" decreasing_tactic".to_string());
}
RecursionPlan::SizeOfStructural => {}
RecursionPlan::StringPosAdvance => {}
RecursionPlan::MutualStringPosAdvance { .. }
| RecursionPlan::MutualSizeOfRanked { .. } => {}
}
}
Some(lines.join("\n"))
}
/// Renders a parameter list as space-separated Lean binders `(name : Type)`,
/// converting names with `aver_name_to_lean` and types with
/// `type_annotation_to_lean`.
fn emit_fn_params(params: &[(String, String)]) -> String {
    let mut binders: Vec<String> = Vec::with_capacity(params.len());
    for (name, type_ann) in params {
        let lean_type = type_annotation_to_lean(type_ann);
        let lean_name = aver_name_to_lean(name);
        binders.push(format!("({} : {})", lean_name, lean_type));
    }
    binders.join(" ")
}
/// Reports whether `expr` is, or contains in one of the walked positions, an
/// `Expr::ErrorProp` node.
///
/// The walk covers call targets and arguments, attribute receivers, both
/// operands of binary operators, match subjects and arm bodies, constructor
/// payloads, parsed parts of interpolated strings, list/tuple/independent
/// product items, map keys and values, record field values, record-update
/// bases and values, and tail-call arguments. An `ErrorProp` node answers
/// `true` immediately without inspecting its operand.
fn expr_uses_error_prop(expr: &Spanned<Expr>) -> bool {
    match &expr.node {
        // Nodes with no sub-expressions to inspect.
        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {
            false
        }

        Expr::ErrorProp(_) => true,

        // Nodes wrapping a single sub-expression.
        Expr::Attr(receiver, _) => expr_uses_error_prop(receiver),
        Expr::Constructor(_, Some(payload)) => expr_uses_error_prop(payload),

        // Nodes with a fixed pair or a head plus a collection.
        Expr::BinOp(_, lhs, rhs) => expr_uses_error_prop(lhs) || expr_uses_error_prop(rhs),
        Expr::FnCall(target, call_args) => {
            if expr_uses_error_prop(target) {
                return true;
            }
            call_args.iter().any(expr_uses_error_prop)
        }
        Expr::Match { subject, arms, .. } => {
            if expr_uses_error_prop(subject) {
                return true;
            }
            arms.iter().any(|arm| expr_uses_error_prop(&arm.body))
        }
        Expr::RecordUpdate { base, updates, .. } => {
            if expr_uses_error_prop(base) {
                return true;
            }
            updates
                .iter()
                .any(|(_, new_value)| expr_uses_error_prop(new_value))
        }

        // Nodes that are plain collections of sub-expressions.
        Expr::List(elements) | Expr::Tuple(elements) | Expr::IndependentProduct(elements, _) => {
            elements.iter().any(expr_uses_error_prop)
        }
        Expr::MapLiteral(pairs) => pairs
            .iter()
            .any(|(k, v)| expr_uses_error_prop(k) || expr_uses_error_prop(v)),
        Expr::RecordCreate { fields, .. } => fields
            .iter()
            .any(|(_, field_value)| expr_uses_error_prop(field_value)),
        Expr::TailCall(call) => call.args.iter().any(expr_uses_error_prop),
        Expr::InterpolatedStr(segments) => segments.iter().any(|segment| match segment {
            StrPart::Literal(_) => false,
            StrPart::Parsed(embedded) => expr_uses_error_prop(embedded),
        }),
    }
}
/// Reports whether any statement of `body` (a binding's value or a bare
/// expression) contains an error-propagation node per `expr_uses_error_prop`.
fn body_uses_error_prop(body: &FnBody) -> bool {
    for stmt in body.stmts().iter() {
        let value = match stmt {
            Stmt::Binding(_, _, value) => value,
            Stmt::Expr(value) => value,
        };
        if expr_uses_error_prop(value) {
            return true;
        }
    }
    false
}
/// Reports whether the declared return type of `fd` parses (via
/// `crate::types::parse_type_str`) to a `Result` type.
fn fn_returns_result(fd: &FnDef) -> bool {
    let parsed = crate::types::parse_type_str(&fd.return_type);
    matches!(parsed, crate::types::Type::Result(_, _))
}
/// Renders one statement of a `do` block body.
///
/// * A binding whose value is an error propagation becomes `let name <- inner`;
///   any other binding becomes `let name := value`.
/// * A bare expression in last position is emitted as-is (the operand itself
///   when it is an error propagation).
/// * A bare expression elsewhere is discarded through `let _ <- inner` for an
///   error propagation, or `let _ := value` otherwise.
fn emit_do_stmt(stmt: &Stmt, ctx: &CodegenContext, is_last: bool) -> String {
    match stmt {
        Stmt::Binding(name, _, expr) => match &expr.node {
            Expr::ErrorProp(inner) => format!(
                " let {} <- {}",
                aver_name_to_lean(name),
                emit_expr(inner, ctx)
            ),
            _ => format!(
                " let {} := {}",
                aver_name_to_lean(name),
                emit_expr(expr, ctx)
            ),
        },
        Stmt::Expr(expr) => match (&expr.node, is_last) {
            (Expr::ErrorProp(inner), true) => format!(" {}", emit_expr(inner, ctx)),
            (Expr::ErrorProp(inner), false) => format!(" let _ <- {}", emit_expr(inner, ctx)),
            (_, true) => format!(" {}", emit_expr(expr, ctx)),
            (_, false) => format!(" let _ := {}", emit_expr(expr, ctx)),
        },
    }
}
/// Renders a function body as newline-joined lines, one per statement.
///
/// Bindings are delegated to `emit_stmt`. The final bare expression is emitted
/// directly; earlier bare expressions are bound to `_` with `let _ := ...`.
/// A body with no statements yields an empty string.
fn emit_fn_body(body: &FnBody, ctx: &CodegenContext) -> String {
    let stmts = body.stmts();
    let total = stmts.len();
    stmts
        .iter()
        .enumerate()
        .map(|(position, stmt)| match stmt {
            Stmt::Binding(..) => format!(" {}", emit_stmt(stmt, ctx)),
            Stmt::Expr(expr) if position + 1 == total => format!(" {}", emit_expr(expr, ctx)),
            Stmt::Expr(expr) => format!(" let _ := {}", emit_expr(expr, ctx)),
        })
        .collect::<Vec<_>>()
        .join("\n")
}
/// Renders a function body as a `do` block: a ` do` header line followed by
/// one `emit_do_stmt` line per statement, joined with newlines.
fn emit_fn_body_result_do(body: &FnBody, ctx: &CodegenContext) -> String {
    let stmts = body.stmts();
    let total = stmts.len();
    let header = std::iter::once(" do".to_string());
    let rendered = stmts
        .iter()
        .enumerate()
        .map(|(position, stmt)| emit_do_stmt(stmt, ctx, position + 1 == total));
    header.chain(rendered).collect::<Vec<_>>().join("\n")
}
/// Picks the body renderer for `fd`: the `do`-block form when the function
/// returns a `Result` and `body` uses error propagation, the plain form
/// otherwise.
fn emit_fn_body_for(fd: &FnDef, body: &FnBody, ctx: &CodegenContext) -> String {
    let needs_do_block = fn_returns_result(fd) && body_uses_error_prop(body);
    if needs_do_block {
        return emit_fn_body_result_do(body, ctx);
    }
    emit_fn_body(body, ctx)
}
/// Emits the Lean text for a verify block and returns it together with the
/// next free case index (`case_index_start` plus the number of cases).
///
/// Law blocks are delegated to `emit_verify_law_block`. For every other block,
/// each `(left, right)` case becomes an equality whose shape depends on
/// `verify_mode`:
/// * `NativeDecide` - an `example` closed by `native_decide`;
/// * `Sorry` - an `example` closed by `sorry`;
/// * `TheoremSkeleton` - a theorem named `<fn>_verify_<n>` with a `sorry`
///   body on the following line, where `n` counts from `case_index_start + 1`.
pub fn emit_verify_block(
    vb: &VerifyBlock,
    ctx: &CodegenContext,
    verify_mode: VerifyEmitMode,
    case_index_start: usize,
) -> (String, usize) {
    if let VerifyKind::Law(law) = &vb.kind {
        return emit_verify_law_block(vb, law, ctx, verify_mode, case_index_start);
    }

    let next_case_index = case_index_start + vb.cases.len();
    let mut out: Vec<String> = Vec::new();
    for (offset, (expected_lhs, expected_rhs)) in vb.cases.iter().enumerate() {
        let lhs = emit_expr(expected_lhs, ctx);
        let rhs = emit_expr(expected_rhs, ctx);
        match verify_mode {
            VerifyEmitMode::TheoremSkeleton => {
                let case_number = case_index_start + offset + 1;
                let theorem_name =
                    format!("{}_verify_{}", aver_name_to_lean(&vb.fn_name), case_number);
                out.push(format!("theorem {} : {} = {} := by", theorem_name, lhs, rhs));
                out.push(" sorry".to_string());
            }
            VerifyEmitMode::Sorry => {
                out.push(format!("example : {} = {} := by sorry", lhs, rhs));
            }
            VerifyEmitMode::NativeDecide => {
                out.push(format!("example : {} = {} := by native_decide", lhs, rhs));
            }
        }
    }
    (out.join("\n"), next_case_index)
}
/// Emits the Lean text for a `verify ... law` block and returns it together
/// with the next free case index (`case_index_start` plus the number of cases).
///
/// Output order:
/// 1. `--` comment lines: a header naming the function and the law (or the
///    spec function when `canonical_spec_ref` resolves one), one line per
///    given, and an optional `when` line.
/// 2. Only when the law has at least one given: the support theorems, then the
///    universal theorem. Its proof comes from
///    `emit_verify_law_forall_auto_proof` when that returns something, and
///    otherwise falls back to a `sorry` with an explanatory comment.
/// 3. Only when there are cases: a `<base>_checked_domain` theorem that is the
///    conjunction of every case proposition.
/// 4. One `<base>_sample_<n>` theorem per case, numbered from
///    `case_index_start + 1`.
///
/// A case proposition is `lhs = rhs`, prefixed with `guard = true ->` when
/// `law.sample_guards` has an entry at the case's index.
fn emit_verify_law_block(
    vb: &VerifyBlock,
    law: &VerifyLaw,
    ctx: &CodegenContext,
    verify_mode: VerifyEmitMode,
    case_index_start: usize,
) -> (String, usize) {
    // Appends `theorem <name> : <prop>` with the proof ending selected by `mode`.
    // Shared by the checked-domain theorem and the per-sample theorems.
    fn push_closed_theorem(
        lines: &mut Vec<String>,
        mode: VerifyEmitMode,
        name: &str,
        prop: &str,
    ) {
        match mode {
            VerifyEmitMode::NativeDecide => {
                lines.push(format!("theorem {} : {} := by native_decide", name, prop));
            }
            VerifyEmitMode::Sorry => {
                lines.push(format!("theorem {} : {} := by sorry", name, prop));
            }
            VerifyEmitMode::TheoremSkeleton => {
                lines.push(format!("theorem {} : {} := by", name, prop));
                lines.push(" sorry".to_string());
            }
        }
    }

    let mut lines: Vec<String> = Vec::new();
    let fn_name = aver_name_to_lean(&vb.fn_name);
    let law_name = aver_name_to_lean(&law.name);
    let spec_ref = canonical_spec_ref(&vb.fn_name, law, &ctx.fn_sigs);
    let theorem_base = match &spec_ref {
        Some(spec_ref) => format!(
            "{}_eq_{}",
            fn_name,
            aver_name_to_lean(&spec_ref.spec_fn_name)
        ),
        None => format!("{}_law_{}", fn_name, law_name),
    };
    let lhs_template = emit_expr(&law.lhs, ctx);
    let rhs_template = emit_expr(&law.rhs, ctx);
    let when_template = law.when.as_ref().map(|expr| emit_expr(expr, ctx));
    let quant_params = law
        .givens
        .iter()
        .map(|given| {
            format!(
                "({} : {})",
                aver_name_to_lean(&given.name),
                type_annotation_to_lean(&given.type_name)
            )
        })
        .collect::<Vec<_>>()
        .join(" ");

    // --- Comment header ---------------------------------------------------
    match &spec_ref {
        Some(spec_ref) => lines.push(format!(
            "-- verify law {}.spec {} ({} cases)",
            fn_name,
            spec_ref.spec_fn_name,
            vb.cases.len()
        )),
        None => lines.push(format!(
            "-- verify law {}.{} ({} cases)",
            fn_name,
            law_name,
            vb.cases.len()
        )),
    }
    for given in &law.givens {
        lines.push(format!(
            "-- given {}: {} = {}",
            aver_name_to_lean(&given.name),
            given.type_name,
            law_given_domain_to_lean(&given.domain, ctx)
        ));
    }
    // Reuse the already-rendered `when` expression instead of emitting it again.
    if let Some(when_str) = &when_template {
        lines.push(format!("-- when {}", when_str));
    }

    // --- Universal theorem (only when something is quantified) -------------
    if !quant_params.is_empty() {
        lines.extend(emit_verify_law_support_theorems(
            vb,
            law,
            ctx,
            &theorem_base,
        ));
        let theorem_prop = law_theorem_prop(
            law,
            ctx,
            &lhs_template,
            &rhs_template,
            when_template.as_deref(),
        );
        let theorem_header = format!(
            "theorem {} : ∀ {}, {} := by",
            theorem_base, quant_params, theorem_prop
        );
        match emit_verify_law_forall_auto_proof(
            vb,
            law,
            ctx,
            verify_mode,
            &theorem_base,
            &quant_params,
            &theorem_prop,
        ) {
            Some(auto_proof) => {
                lines.extend(auto_proof.support_lines);
                // The auto proof may supply its own theorem statement.
                if !auto_proof.replaces_theorem {
                    lines.push(theorem_header);
                }
                lines.extend(auto_proof.proof_lines);
            }
            None => {
                lines.push(theorem_header);
                lines.push(
                    " -- verify law is sampled; universal proof must be provided manually"
                        .to_string(),
                );
                lines.push(" sorry".to_string());
            }
        }
    }

    // --- Sampled cases ----------------------------------------------------
    // Render each case proposition once; the checked-domain theorem and the
    // per-sample theorems both reuse these strings.
    let case_props: Vec<String> = vb
        .cases
        .iter()
        .enumerate()
        .map(|(idx, (left, right))| {
            let left_str = emit_expr(left, ctx);
            let right_str = emit_expr(right, ctx);
            match law.sample_guards.get(idx) {
                Some(guard) => format!(
                    "{} = true -> {} = {}",
                    emit_expr(guard, ctx),
                    left_str,
                    right_str
                ),
                None => format!("{} = {}", left_str, right_str),
            }
        })
        .collect();

    if !case_props.is_empty() {
        let domain_theorem_name = format!("{}_checked_domain", theorem_base);
        let domain_prop = case_props
            .iter()
            .map(|prop| format!("({})", prop))
            .collect::<Vec<_>>()
            .join(" ∧ ");
        push_closed_theorem(&mut lines, verify_mode, &domain_theorem_name, &domain_prop);
    }
    for (idx, sample_prop) in case_props.iter().enumerate() {
        let theorem_name = format!("{}_sample_{}", theorem_base, case_index_start + idx + 1);
        push_closed_theorem(&mut lines, verify_mode, &theorem_name, sample_prop);
    }

    (lines.join("\n"), case_index_start + vb.cases.len())
}
/// Builds the proposition of a law's universal theorem as a chain of
/// implications ending in `lhs = rhs`.
///
/// When the law has a `when` clause, each given contributes its domain
/// proposition (`law_given_domain_prop`) as a premise. When `when_template` is
/// present it adds the premise `<when> = true`. With no premises the result is
/// just the conclusion.
fn law_theorem_prop(
    law: &VerifyLaw,
    ctx: &CodegenContext,
    lhs_template: &str,
    rhs_template: &str,
    when_template: Option<&str>,
) -> String {
    let mut clauses: Vec<String> = Vec::new();
    if law.when.is_some() {
        for given in &law.givens {
            clauses.push(law_given_domain_prop(given, ctx));
        }
    }
    if let Some(guard) = when_template {
        clauses.push(format!("{guard} = true"));
    }
    // The conclusion is the last link; joining keeps it bare when it is alone.
    clauses.push(format!("{lhs_template} = {rhs_template}"));
    clauses.join(" -> ")
}
/// Renders a given's domain for display: `start..end` for an integer range, or
/// a bracketed, comma-separated list of emitted expressions for explicit
/// values.
fn law_given_domain_to_lean(domain: &VerifyGivenDomain, ctx: &CodegenContext) -> String {
    match domain {
        VerifyGivenDomain::Explicit(values) => {
            let rendered: Vec<String> = values.iter().map(|value| emit_expr(value, ctx)).collect();
            format!("[{}]", rendered.join(", "))
        }
        VerifyGivenDomain::IntRange { start, end } => format!("{}..{}", start, end),
    }
}
/// Renders the proposition that a given lies in its domain: a ` ∨ `-joined
/// list of `name = value` equalities, or `False` when the domain is empty.
fn law_given_domain_prop(given: &VerifyGiven, ctx: &CodegenContext) -> String {
    let given_name = aver_name_to_lean(&given.name);
    let values = law_given_domain_values(&given.domain);
    if values.is_empty() {
        return "False".to_string();
    }
    // Joining a single equality produces it unchanged, so one path covers both
    // the singleton and the multi-value case.
    let alternatives: Vec<String> = values
        .iter()
        .map(|value| format!("{given_name} = {}", emit_expr(value, ctx)))
        .collect();
    alternatives.join(" ∨ ")
}
/// Expands a given's domain into its concrete value expressions.
///
/// An integer range yields one bare integer literal per value from `start`
/// through `end` inclusive; an explicit domain yields a clone of its values.
pub(super) fn law_given_domain_values(domain: &VerifyGivenDomain) -> Vec<Spanned<Expr>> {
    match domain {
        VerifyGivenDomain::Explicit(values) => values.clone(),
        VerifyGivenDomain::IntRange { start, end } => {
            let mut literals = Vec::new();
            for n in *start..=*end {
                literals.push(Spanned::bare(Expr::Literal(Literal::Int(n))));
            }
            literals
        }
    }
}
/// Renders a decision block as a `/- ... -/` block comment.
///
/// The name, date, reason and chosen option are always listed. The rejected
/// options, impacts and author lines appear only when the corresponding field
/// is non-empty / present.
pub fn emit_decision(db: &DecisionBlock) -> String {
    let mut out = vec![
        format!("/- Decision: {}", db.name),
        format!(" Date: {}", db.date),
        format!(" Reason: {}", db.reason),
        format!(" Chosen: {}", db.chosen.node.as_context_string()),
    ];

    if !db.rejected.is_empty() {
        let rejected = db
            .rejected
            .iter()
            .map(|option| option.node.as_context_string())
            .collect::<Vec<_>>()
            .join(", ");
        out.push(format!(" Rejected: {}", rejected));
    }

    if !db.impacts.is_empty() {
        let impacts = db
            .impacts
            .iter()
            .map(|impact| impact.node.as_context_string())
            .collect::<Vec<_>>()
            .join(", ");
        out.push(format!(" Impacts: {}", impacts));
    }

    if let Some(author) = &db.author {
        out.push(format!(" Author: {}", author));
    }

    out.push("-/".to_string());
    out.join("\n")
}
/// Emits a `mutual ... end` block containing one `partial def` per function
/// in `fns` that passes `is_pure_fn`.
///
/// Each definition is preceded by its description as a doc comment when one
/// exists, uses `Unit` when the return type is empty, and is followed by a
/// blank line.
pub fn emit_mutual_group(fns: &[&FnDef], ctx: &CodegenContext) -> String {
    let mut out = vec!["mutual".to_string()];
    for fd in fns {
        if !is_pure_fn(fd) {
            continue;
        }
        if let Some(desc) = &fd.desc {
            out.push(format!(" /-- {} -/", desc));
        }

        let fn_name = aver_name_to_lean(&fd.name);
        let params = emit_fn_params(&fd.params);
        let ret_type = match fd.return_type.is_empty() {
            true => "Unit".to_string(),
            false => type_annotation_to_lean(&fd.return_type),
        };
        out.push(format!(
            " partial def {} {} : {} :=",
            fn_name, params, ret_type
        ));

        let body = emit_fn_body_for(fd, &fd.body, ctx);
        out.extend(body.lines().map(|line| format!(" {}", line)));
        out.push(String::new());
    }
    out.push("end".to_string());
    out.join("\n")
}
/// Emits a mutually recursive group in proof mode.
///
/// When every function in `fns` has the same kind of plan in `plans`, one of
/// the fuelized group emitters is used:
/// * all `MutualIntCountdown` -> `emit_fuelized_mutual_int_countdown_group`;
/// * all `MutualStringPosAdvance` -> `emit_fuelized_mutual_string_pos_group`;
/// * all `MutualSizeOfRanked` -> `emit_fuelized_mutual_sizeof_group`.
///
/// Otherwise a plain `mutual ... end` block of `def`s is emitted for the
/// functions that pass `is_pure_fn`, each followed by a termination clause
/// when its own plan is `MutualIntCountdown` (needs a first parameter) or
/// `MutualStringPosAdvance` (needs the first two parameters). Functions with
/// any other plan, or none, get no termination clause.
pub fn emit_mutual_group_proof(
    fns: &[&FnDef],
    ctx: &CodegenContext,
    plans: &std::collections::HashMap<String, RecursionPlan>,
) -> String {
    if fns
        .iter()
        .all(|fd| matches!(plans.get(&fd.name), Some(RecursionPlan::MutualIntCountdown)))
    {
        return emit_fuelized_mutual_int_countdown_group(fns, ctx);
    }
    if fns.iter().all(|fd| {
        matches!(
            plans.get(&fd.name),
            Some(RecursionPlan::MutualStringPosAdvance { .. })
        )
    }) {
        return emit_fuelized_mutual_string_pos_group(fns, ctx, plans);
    }
    if fns.iter().all(|fd| {
        matches!(
            plans.get(&fd.name),
            Some(RecursionPlan::MutualSizeOfRanked { .. })
        )
    }) {
        return emit_fuelized_mutual_sizeof_group(fns, ctx, plans);
    }

    let mut lines = vec!["mutual".to_string()];
    for fd in fns {
        if !is_pure_fn(fd) {
            continue;
        }
        if let Some(desc) = &fd.desc {
            lines.push(format!(" /-- {} -/", desc));
        }
        let fn_name = aver_name_to_lean(&fd.name);
        let params = emit_fn_params(&fd.params);
        let ret_type = if fd.return_type.is_empty() {
            "Unit".to_string()
        } else {
            type_annotation_to_lean(&fd.return_type)
        };
        lines.push(format!(" def {} {} : {} :=", fn_name, params, ret_type));
        let body = emit_fn_body_for(fd, &fd.body, ctx);
        lines.extend(body.lines().map(|line| format!(" {}", line)));

        // Match the plan by reference: nothing here needs an owned copy.
        match plans.get(&fd.name) {
            Some(RecursionPlan::MutualIntCountdown) => {
                if let Some((first_name, _)) = fd.params.first() {
                    let lean_first = aver_name_to_lean(first_name);
                    lines.push(format!(" termination_by Int.natAbs {}", lean_first));
                    lines.push(" decreasing_by".to_string());
                    lines.push(" omega".to_string());
                }
            }
            Some(RecursionPlan::MutualStringPosAdvance { rank }) => {
                if let (Some((s_name, _)), Some((pos_name, _))) =
                    (fd.params.first(), fd.params.get(1))
                {
                    let lean_s = aver_name_to_lean(s_name);
                    let lean_pos = aver_name_to_lean(pos_name);
                    lines.push(format!(
                        " termination_by (({}.data.length) - ({}.toNat), {})",
                        lean_s, lean_pos, rank
                    ));
                    lines.push(" decreasing_by".to_string());
                    lines.push(" simp_wf".to_string());
                }
            }
            // Every other plan, and a missing plan, adds no termination clause.
            _ => {}
        }
        lines.push(String::new());
    }
    lines.push("end".to_string());
    lines.join("\n")
}