pub mod detect;
use std::collections::HashSet;
use crate::ast::{Expr, FnBody, MatchArm, Spanned, Stmt, StrPart, TailCallData};
use crate::codegen::common::expr_to_dotted_name;
pub use detect::analyze_plans;
/// Recursion / termination shape assigned to a recursive function.
///
/// NOTE(review): the per-variant descriptions below are read off the variant
/// and field names only; confirm the exact semantics against the `detect`
/// module before relying on them.
#[derive(Debug, Clone, PartialEq)]
pub enum RecursionPlan {
    /// Presumably an integer argument counting down; `param_index` selects it.
    IntCountdown { param_index: usize },
    /// Presumably an integer argument counting up towards `bound`;
    /// `param_index` selects that argument.
    IntAscending {
        param_index: usize,
        bound: Spanned<Expr>,
    },
    /// Presumably a two-term linear recurrence.
    LinearRecurrence2,
    /// Presumably structural recursion on the list argument at `param_index`.
    ListStructural { param_index: usize },
    /// Presumably structural recursion measured by a size function.
    SizeOfStructural,
    /// Presumably recursion that advances a position within a string.
    StringPosAdvance,
    /// Presumably a mutually recursive group with an integer countdown.
    MutualIntCountdown,
    /// Presumably mutual string-position recursion; `rank` orders members.
    MutualStringPosAdvance { rank: usize },
    /// Presumably mutual size-ranked recursion; `rank` orders members.
    MutualSizeOfRanked { rank: usize },
}
/// A single problem reported against a source line, carrying a
/// human-readable `message`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProofModeIssue {
    /// Source line the issue refers to.
    pub line: usize,
    /// Description of the issue.
    pub message: String,
}
/// Returns the name of the fuel-indexed helper for `name`: `name` with
/// `__fuel` appended (e.g. `"fact"` becomes `"fact__fuel"`).
pub fn fuel_helper_name(name: &str) -> String {
    [name, "__fuel"].concat()
}
/// Rebuilds `expr`, redirecting every call to a function named in `targets`
/// to that function's fuel helper (see [`fuel_helper_name`]). The helper call
/// gets `Expr::Ident(fuel_var)` inserted as its first argument.
///
/// * An `Expr::FnCall` whose callee renders (via `expr_to_dotted_name`) to a
///   name in `targets` becomes `<name>__fuel(fuel_var, args...)`. Any other
///   call keeps its shape, but its callee is rewritten as well.
/// * An `Expr::TailCall` to a target also becomes a plain `Expr::FnCall` of
///   the helper. A tail call to anything else stays a tail call, rebuilt with
///   `TailCallData::new` from its target and the rewritten arguments.
/// * `Literal`, `Ident` and `Resolved` leaves are cloned unchanged. Every
///   other node is rebuilt with its children rewritten recursively.
///
/// Every rebuilt node carries the `line` of the node it replaces, including
/// the inserted fuel argument and helper callee. Match arms get a fresh,
/// empty `binding_slots` cell rather than a copy of the original's.
pub fn rewrite_recursive_calls_expr(
    expr: &Spanned<Expr>,
    targets: &HashSet<String>,
    fuel_var: &str,
) -> Spanned<Expr> {
    let line = expr.line;

    // Recursive step shared by every composite node.
    let go = |child: &Spanned<Expr>| -> Spanned<Expr> {
        rewrite_recursive_calls_expr(child, targets, fuel_var)
    };
    let boxed = |child: &Spanned<Expr>| -> Box<Spanned<Expr>> { Box::new(go(child)) };
    // Builds `<target>__fuel(fuel_var, call_args...)`, every node stamped with `line`.
    let fuel_call = |target: &str, call_args: Vec<Spanned<Expr>>| -> Expr {
        let fuel_arg = Spanned::new(Expr::Ident(fuel_var.to_string()), line);
        let helper = Spanned::new(Expr::Ident(fuel_helper_name(target)), line);
        Expr::FnCall(
            Box::new(helper),
            std::iter::once(fuel_arg).chain(call_args).collect(),
        )
    };

    let node = match &expr.node {
        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => return expr.clone(),
        Expr::FnCall(callee, args) => {
            let new_args: Vec<Spanned<Expr>> = args.iter().map(go).collect();
            match expr_to_dotted_name(&callee.node) {
                Some(name) if targets.contains(&name) => fuel_call(&name, new_args),
                _ => Expr::FnCall(boxed(callee), new_args),
            }
        }
        Expr::TailCall(data) => {
            let TailCallData { target, args, .. } = data.as_ref();
            let new_args: Vec<Spanned<Expr>> = args.iter().map(go).collect();
            if targets.contains(target) {
                fuel_call(target, new_args)
            } else {
                Expr::TailCall(Box::new(TailCallData::new(target.clone(), new_args)))
            }
        }
        Expr::Attr(obj, field) => Expr::Attr(boxed(obj), field.clone()),
        Expr::BinOp(op, lhs, rhs) => Expr::BinOp(*op, boxed(lhs), boxed(rhs)),
        Expr::ErrorProp(inner) => Expr::ErrorProp(boxed(inner)),
        Expr::Constructor(name, payload) => {
            Expr::Constructor(name.clone(), payload.as_deref().map(boxed))
        }
        Expr::Match { subject, arms } => Expr::Match {
            subject: boxed(subject),
            arms: arms
                .iter()
                .map(|arm| MatchArm {
                    pattern: arm.pattern.clone(),
                    body: boxed(&arm.body),
                    // Start with an empty cell for the rebuilt arm instead of
                    // copying the original arm's contents.
                    binding_slots: std::sync::OnceLock::new(),
                })
                .collect(),
        },
        Expr::InterpolatedStr(parts) => Expr::InterpolatedStr(
            parts
                .iter()
                .map(|part| match part {
                    StrPart::Parsed(inner) => StrPart::Parsed(boxed(inner)),
                    StrPart::Literal(_) => part.clone(),
                })
                .collect(),
        ),
        Expr::List(items) => Expr::List(items.iter().map(go).collect()),
        Expr::Tuple(items) => Expr::Tuple(items.iter().map(go).collect()),
        Expr::IndependentProduct(items, flag) => {
            Expr::IndependentProduct(items.iter().map(go).collect(), *flag)
        }
        Expr::MapLiteral(entries) => {
            Expr::MapLiteral(entries.iter().map(|(k, v)| (go(k), go(v))).collect())
        }
        Expr::RecordCreate { type_name, fields } => Expr::RecordCreate {
            type_name: type_name.clone(),
            fields: fields
                .iter()
                .map(|(name, value)| (name.clone(), go(value)))
                .collect(),
        },
        Expr::RecordUpdate {
            type_name,
            base,
            updates,
        } => Expr::RecordUpdate {
            type_name: type_name.clone(),
            base: boxed(base),
            updates: updates
                .iter()
                .map(|(name, value)| (name.clone(), go(value)))
                .collect(),
        },
    };
    Spanned::new(node, line)
}
pub fn rewrite_recursive_calls_body(
body: &FnBody,
targets: &HashSet<String>,
fuel_var: &str,
) -> FnBody {
FnBody::Block(
body.stmts()
.iter()
.map(|stmt| match stmt {
Stmt::Binding(name, ty, expr) => Stmt::Binding(
name.clone(),
ty.clone(),
rewrite_recursive_calls_expr(expr, targets, fuel_var),
),
Stmt::Expr(expr) => {
Stmt::Expr(rewrite_recursive_calls_expr(expr, targets, fuel_var))
}
})
.collect(),
)
}