use std::collections::HashSet;
use super::expr::aver_name_to_dafny;
use crate::ast::{FnDef, TypeDef};
use crate::codegen::CodegenContext;
use crate::codegen::common::parse_type_annotation;
use crate::codegen::recursion::{RecursionPlan, fuel_helper_name, rewrite_recursive_calls_body};
use crate::types::Type;
/// Emits a fuel-based Dafny encoding for one group (SCC) of mutually
/// recursive functions.
///
/// For every function `f` in `fns` two Dafny functions are produced:
/// * a helper (named by `fuel_helper_name`) taking `fuel: nat` as its first
///   parameter with `decreases fuel`. It returns the return type's default
///   value when `fuel == 0`. Otherwise it binds `fuel' := fuel - 1` and
///   evaluates the body, with calls to group members rewritten by
///   `rewrite_recursive_calls_body` using `fuel'`;
/// * a wrapper with `f`'s original name and signature that calls the helper
///   with an initial fuel expression chosen by `emit_fuel_metric` from `f`'s
///   recursion plan.
///
/// All helpers are emitted before all wrappers, joined by newlines.
///
/// Returns `None` (emitting nothing) if any function's return type has no
/// constructible Dafny default value, because the fuel-exhausted branch
/// needs one.
pub fn emit_mutual_fuel_group(
    fns: &[&FnDef],
    ctx: &CodegenContext,
    plans: &std::collections::HashMap<String, RecursionPlan>,
) -> Option<String> {
    /// Joins `first` and `rest` into a comma-separated list. The separator is
    /// omitted when `rest` is empty, so zero-parameter functions do not get
    /// a dangling `, ` (which Dafny rejects).
    fn prepend_arg(first: &str, rest: &str) -> String {
        if rest.is_empty() {
            first.to_string()
        } else {
            format!("{}, {}", first, rest)
        }
    }

    // Resolve every default up front. This rejects the whole group before any
    // text is generated, and computes each default exactly once.
    let defaults: Vec<String> = fns
        .iter()
        .map(|fd| dafny_default_value(&fd.return_type, ctx))
        .collect::<Option<_>>()?;

    let scc_size = fns.len();
    let targets: HashSet<String> = fns.iter().map(|fd| fd.name.clone()).collect();
    let mut helper_lines: Vec<String> = Vec::new();
    let mut wrapper_lines: Vec<String> = Vec::new();

    for (fd, default_val) in fns.iter().zip(&defaults) {
        // Functions without an explicit plan fall back to rank-1 size-based fuel.
        let plan = plans
            .get(&fd.name)
            .cloned()
            .unwrap_or(RecursionPlan::MutualSizeOfRanked { rank: 1 });
        let helper_name = fuel_helper_name(&fd.name);
        let fn_name = aver_name_to_dafny(&fd.name);
        let params_str = emit_dafny_params(&fd.params);
        let ret_type_str = super::toplevel::emit_type(&fd.return_type);
        let arg_names = emit_dafny_arg_names(&fd.params);
        let metric = emit_fuel_metric(fd, &plan, scc_size);

        // Prefer the lowered body when lowering succeeds and yields one.
        // Otherwise (error or `None`) use the body as written.
        let lowered_body = crate::types::checker::effect_lifting::lower_pure_question_bang_fn(fd)
            .ok()
            .flatten()
            .map(|lowered| lowered.body.as_ref().clone())
            .unwrap_or_else(|| fd.body.as_ref().clone());
        let rewritten_body = rewrite_recursive_calls_body(&lowered_body, &targets, "fuel'");
        let body_str = super::toplevel::emit_fn_body(&rewritten_body, ctx);

        if let Some(desc) = &fd.desc {
            helper_lines.push(format!("// {}", desc));
        }
        helper_lines.push(format!(
            "function {}({}): {}",
            helper_name,
            prepend_arg("fuel: nat", &params_str),
            ret_type_str
        ));
        helper_lines.push(" decreases fuel".to_string());
        helper_lines.push("{".to_string());
        helper_lines.push(format!(" if fuel == 0 then {}", default_val));
        helper_lines.push(format!(" else var fuel' := fuel - 1; {}", body_str));
        helper_lines.push("}\n".to_string());

        wrapper_lines.push(format!(
            "function {}({}): {}",
            fn_name, params_str, ret_type_str
        ));
        wrapper_lines.push("{".to_string());
        wrapper_lines.push(format!(
            " {}({})",
            helper_name,
            prepend_arg(&metric, &arg_names)
        ));
        wrapper_lines.push("}\n".to_string());
    }

    Some(
        [helper_lines, wrapper_lines]
            .into_iter()
            .flatten()
            .collect::<Vec<_>>()
            .join("\n"),
    )
}
/// Renders `params` as a Dafny formal-parameter list such as `a: int, b: string`.
///
/// Each name is mapped through `aver_name_to_dafny`, and each type annotation
/// through `toplevel::emit_type`. An empty slice yields an empty string.
fn emit_dafny_params(params: &[(String, String)]) -> String {
    let mut rendered = Vec::with_capacity(params.len());
    for (name, annotation) in params {
        let dafny_name = aver_name_to_dafny(name);
        let dafny_type = super::toplevel::emit_type(annotation);
        rendered.push(format!("{}: {}", dafny_name, dafny_type));
    }
    rendered.join(", ")
}
/// Renders just the parameter names of `params`, mapped through
/// `aver_name_to_dafny` and separated by `, `. This is used to forward
/// arguments in a call. An empty slice yields an empty string.
fn emit_dafny_arg_names(params: &[(String, String)]) -> String {
    let mut out = String::new();
    for (idx, (name, _)) in params.iter().enumerate() {
        if idx > 0 {
            out.push_str(", ");
        }
        out.push_str(&aver_name_to_dafny(name));
    }
    out
}
/// Chooses the initial fuel expression that a wrapper passes to its fuel helper.
///
/// * `MutualIntCountdown`: for the first `Int` parameter `n`, emits `n + 1`
///   clamped at zero, written as `(if n >= 0 then n else 0) + 1`. Emits `1`
///   if there is no such parameter.
/// * `MutualStringPosAdvance { rank }` / `MutualSizeOfRanked { rank }`: for
///   the first list/vector/string parameter `s`, emits
///   `(|s| + 1) * B` where `B = rank * scc_size + 1` is computed here.
///   Emits the literal `max(rank, 1)` if there is no such parameter.
/// * Any other plan: `1`.
fn emit_fuel_metric(fd: &FnDef, plan: &RecursionPlan, scc_size: usize) -> String {
    match plan {
        RecursionPlan::MutualIntCountdown => match first_int_param(fd) {
            Some(param) => {
                let counter = aver_name_to_dafny(param);
                format!("(if {v} >= 0 then {v} else 0) + 1", v = counter)
            }
            None => "1".to_string(),
        },
        RecursionPlan::MutualStringPosAdvance { rank }
        | RecursionPlan::MutualSizeOfRanked { rank } => match first_seq_or_string_param(fd) {
            Some(param) => {
                let seq = aver_name_to_dafny(param);
                let budget = rank * scc_size + 1;
                format!("(|{n}| + 1) * {budget}", n = seq, budget = budget)
            }
            None => (*rank).max(1).to_string(),
        },
        _ => "1".to_string(),
    }
}
/// Returns the name of the first parameter of `fd` whose type annotation
/// parses to `Type::Int`, or `None` if there is none.
fn first_int_param(fd: &FnDef) -> Option<&String> {
    for (name, annotation) in fd.params.iter() {
        if parse_type_annotation(annotation) == Type::Int {
            return Some(name);
        }
    }
    None
}
/// Returns the name of the first parameter of `fd` whose type annotation
/// parses to a list, vector, or string, or `None` if there is none.
fn first_seq_or_string_param(fd: &FnDef) -> Option<&String> {
    fd.params.iter().find_map(|(name, annotation)| {
        let is_sequence_like = matches!(
            parse_type_annotation(annotation),
            Type::List(_) | Type::Vector(_) | Type::Str
        );
        is_sequence_like.then_some(name)
    })
}
/// Returns a Dafny expression that is a value of the Aver type annotation
/// `type_str`, suitable as a placeholder result.
///
/// Returns `None` when no such expression can be built. Examples are
/// function or unknown types, and named types that cannot be resolved or
/// that only lead back to a type already being expanded.
pub fn dafny_default_value(type_str: &str, ctx: &CodegenContext) -> Option<String> {
    let parsed = parse_type_annotation(type_str);
    // Fresh cycle-detection set: no named type is being expanded yet.
    type_default(&parsed, ctx, &mut HashSet::new())
}
/// Builds a default Dafny expression for `ty`.
///
/// Primitive and collection types map to fixed literals. `Result` defaults to
/// `Result.Err(<default of the error type>)`. Tuples default element-wise.
/// Named types are resolved through `named_type_default`, using `visiting`
/// to stop on cycles. Function and unknown types have no default (`None`).
fn type_default(ty: &Type, ctx: &CodegenContext, visiting: &mut HashSet<String>) -> Option<String> {
    let literal = match ty {
        Type::Int => "0",
        Type::Float => "0.0",
        Type::Str => "\"\"",
        Type::Bool => "false",
        Type::Unit => "()",
        Type::List(_) | Type::Vector(_) => "[]",
        Type::Map(_, _) => "map[]",
        Type::Option(_) => "Option.None",
        Type::Result(_, err_ty) => {
            let err_default = type_default(err_ty, ctx, visiting)?;
            return Some(format!("Result.Err({})", err_default));
        }
        Type::Tuple(elems) => {
            let mut parts: Vec<String> = Vec::new();
            for elem in elems.iter() {
                parts.push(type_default(elem, ctx, visiting)?);
            }
            return Some(format!("({})", parts.join(", ")));
        }
        Type::Named(type_name) => return named_type_default(type_name, ctx, visiting),
        Type::Fn(..) | Type::Unknown => return None,
    };
    Some(literal.to_string())
}
/// Builds a default value for the user-defined type called `name`.
///
/// `visiting` holds the named types currently being expanded on this path.
/// If `name` is already in it, the type refers back to itself along this
/// route and `None` is returned instead of recursing forever. `None` is also
/// returned when no type definition named `name` exists.
///
/// On every return path, `visiting` is left exactly as it was on entry.
fn named_type_default(
    name: &str,
    ctx: &CodegenContext,
    visiting: &mut HashSet<String>,
) -> Option<String> {
    // Resolve before touching `visiting`. Previously the name was inserted
    // first, so an unresolved name returned early and left a stale entry behind.
    let td = find_type_def(ctx, name)?;
    // `insert` returns false when `name` is already on the current path (a cycle).
    if !visiting.insert(name.to_string()) {
        return None;
    }
    let result = type_def_default(td, ctx, visiting);
    visiting.remove(name);
    result
}
/// Looks up a type definition by name. It searches `ctx.type_defs` first,
/// then each module's `type_defs` in order, and returns the first definition
/// whose `type_def_name` equals `target`.
fn find_type_def<'a>(ctx: &'a CodegenContext, target: &str) -> Option<&'a TypeDef> {
    let module_defs = ctx.modules.iter().flat_map(|m| m.type_defs.iter());
    let mut candidates = ctx.type_defs.iter().chain(module_defs);
    candidates.find(|td| crate::codegen::common::type_def_name(td) == target)
}
/// Builds a Dafny expression that is a value of the user-defined type `td`.
///
/// * Sum types: variants are tried in declaration order, and the first one
///   whose fields all have defaults is used. It is rendered as `Name.Variant`
///   when it has no fields, or `Name.Variant(arg, ...)` otherwise. Trying
///   later variants means a recursive type whose first variant recurses,
///   e.g. `Add(Expr, Expr) | Lit(Int)`, still gets a default (`Expr.Lit(0)`).
/// * Product types: rendered as `Name(field := value, ...)`. Every field must
///   have a default.
///
/// Returns `None` when no variant is constructible (including a sum type with
/// no variants), or when some product field has no default.
fn type_def_default(
    td: &TypeDef,
    ctx: &CodegenContext,
    visiting: &mut HashSet<String>,
) -> Option<String> {
    match td {
        TypeDef::Sum { name, variants, .. } => {
            'variants: for variant in variants.iter() {
                let mut args: Vec<String> = Vec::new();
                for field_ty in variant.fields.iter() {
                    match type_default(&parse_type_annotation(field_ty), ctx, visiting) {
                        Some(value) => args.push(value),
                        // This variant cannot be built (e.g. it leads back to a
                        // type on the current path). Try the next one instead.
                        None => continue 'variants,
                    }
                }
                return Some(if variant.fields.is_empty() {
                    format!("{}.{}", name, variant.name)
                } else {
                    format!("{}.{}({})", name, variant.name, args.join(", "))
                });
            }
            None
        }
        TypeDef::Product { name, fields, .. } => {
            let args: Vec<String> = fields
                .iter()
                .map(|(fname, ftype)| {
                    type_default(&parse_type_annotation(ftype), ctx, visiting)
                        .map(|v| format!("{} := {}", aver_name_to_dafny(fname), v))
                })
                .collect::<Option<_>>()?;
            Some(format!("{}({})", name, args.join(", ")))
        }
    }
}