1pub mod detect;
14
15use std::collections::HashSet;
16
17use crate::ast::{Expr, FnBody, MatchArm, Spanned, Stmt, StrPart, TailCallData};
18use crate::codegen::common::expr_to_dotted_name;
19
20pub use detect::analyze_plans;
21
22#[derive(Clone, Debug, PartialEq)]
31pub enum RecursionPlan {
32 IntCountdown { param_index: usize },
35 IntAscending {
40 param_index: usize,
41 bound: Spanned<Expr>,
42 },
43 LinearRecurrence2,
47 ListStructural { param_index: usize },
50 SizeOfStructural,
53 StringPosAdvance,
57 MutualIntCountdown,
60 MutualStringPosAdvance { rank: usize },
64 MutualSizeOfRanked { rank: usize },
67}
68
/// A single problem reported in proof mode, attached to a source line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProofModeIssue {
    /// Source line the issue refers to.
    pub line: usize,
    /// Human-readable description of the issue.
    pub message: String,
}
77
/// Returns the name of the fuel-indexed helper generated for the function `name`.
///
/// The helper name is `name` followed by the literal suffix `__fuel`,
/// e.g. `"count"` becomes `"count__fuel"`.
pub fn fuel_helper_name(name: &str) -> String {
    [name, "__fuel"].concat()
}
84
85pub fn rewrite_recursive_calls_expr(
90 expr: &Spanned<Expr>,
91 targets: &HashSet<String>,
92 fuel_var: &str,
93) -> Spanned<Expr> {
94 let line = expr.line;
95 let new_node = match &expr.node {
96 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => return expr.clone(),
97 Expr::Attr(obj, field) => Expr::Attr(
98 Box::new(rewrite_recursive_calls_expr(obj, targets, fuel_var)),
99 field.clone(),
100 ),
101 Expr::FnCall(callee, args) => {
102 let rewritten_args: Vec<Spanned<Expr>> = args
103 .iter()
104 .map(|arg| rewrite_recursive_calls_expr(arg, targets, fuel_var))
105 .collect();
106 if let Some(name) = expr_to_dotted_name(&callee.node)
107 && targets.contains(&name)
108 {
109 let mut call_args = Vec::with_capacity(rewritten_args.len() + 1);
110 call_args.push(Spanned::new(Expr::Ident(fuel_var.to_string()), line));
111 call_args.extend(rewritten_args);
112 Expr::FnCall(
113 Box::new(Spanned::new(Expr::Ident(fuel_helper_name(&name)), line)),
114 call_args,
115 )
116 } else {
117 Expr::FnCall(
118 Box::new(rewrite_recursive_calls_expr(callee, targets, fuel_var)),
119 rewritten_args,
120 )
121 }
122 }
123 Expr::BinOp(op, left, right) => Expr::BinOp(
124 *op,
125 Box::new(rewrite_recursive_calls_expr(left, targets, fuel_var)),
126 Box::new(rewrite_recursive_calls_expr(right, targets, fuel_var)),
127 ),
128 Expr::Match { subject, arms } => Expr::Match {
129 subject: Box::new(rewrite_recursive_calls_expr(subject, targets, fuel_var)),
130 arms: arms
131 .iter()
132 .map(|arm| MatchArm {
133 pattern: arm.pattern.clone(),
134 body: Box::new(rewrite_recursive_calls_expr(&arm.body, targets, fuel_var)),
135 binding_slots: std::sync::OnceLock::new(),
136 })
137 .collect(),
138 },
139 Expr::Constructor(name, arg) => Expr::Constructor(
140 name.clone(),
141 arg.as_ref()
142 .map(|inner| Box::new(rewrite_recursive_calls_expr(inner, targets, fuel_var))),
143 ),
144 Expr::ErrorProp(inner) => Expr::ErrorProp(Box::new(rewrite_recursive_calls_expr(
145 inner, targets, fuel_var,
146 ))),
147 Expr::InterpolatedStr(parts) => Expr::InterpolatedStr(
148 parts
149 .iter()
150 .map(|part| match part {
151 StrPart::Literal(_) => part.clone(),
152 StrPart::Parsed(inner) => StrPart::Parsed(Box::new(
153 rewrite_recursive_calls_expr(inner, targets, fuel_var),
154 )),
155 })
156 .collect(),
157 ),
158 Expr::List(items) => Expr::List(
159 items
160 .iter()
161 .map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
162 .collect(),
163 ),
164 Expr::Tuple(items) => Expr::Tuple(
165 items
166 .iter()
167 .map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
168 .collect(),
169 ),
170 Expr::IndependentProduct(items, flag) => Expr::IndependentProduct(
171 items
172 .iter()
173 .map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
174 .collect(),
175 *flag,
176 ),
177 Expr::MapLiteral(entries) => Expr::MapLiteral(
178 entries
179 .iter()
180 .map(|(k, v)| {
181 (
182 rewrite_recursive_calls_expr(k, targets, fuel_var),
183 rewrite_recursive_calls_expr(v, targets, fuel_var),
184 )
185 })
186 .collect(),
187 ),
188 Expr::RecordCreate { type_name, fields } => Expr::RecordCreate {
189 type_name: type_name.clone(),
190 fields: fields
191 .iter()
192 .map(|(name, value)| {
193 (
194 name.clone(),
195 rewrite_recursive_calls_expr(value, targets, fuel_var),
196 )
197 })
198 .collect(),
199 },
200 Expr::RecordUpdate {
201 type_name,
202 base,
203 updates,
204 } => Expr::RecordUpdate {
205 type_name: type_name.clone(),
206 base: Box::new(rewrite_recursive_calls_expr(base, targets, fuel_var)),
207 updates: updates
208 .iter()
209 .map(|(name, value)| {
210 (
211 name.clone(),
212 rewrite_recursive_calls_expr(value, targets, fuel_var),
213 )
214 })
215 .collect(),
216 },
217 Expr::TailCall(boxed) => {
218 let TailCallData { target, args, .. } = boxed.as_ref();
219 let rewritten_args: Vec<Spanned<Expr>> = args
220 .iter()
221 .map(|arg| rewrite_recursive_calls_expr(arg, targets, fuel_var))
222 .collect();
223 if targets.contains(target) {
224 let mut call_args = Vec::with_capacity(rewritten_args.len() + 1);
225 call_args.push(Spanned::new(Expr::Ident(fuel_var.to_string()), line));
226 call_args.extend(rewritten_args);
227 Expr::FnCall(
228 Box::new(Spanned::new(Expr::Ident(fuel_helper_name(target)), line)),
229 call_args,
230 )
231 } else {
232 Expr::TailCall(Box::new(TailCallData::new(target.clone(), rewritten_args)))
233 }
234 }
235 };
236 Spanned::new(new_node, line)
237}
238
239pub fn rewrite_recursive_calls_body(
241 body: &FnBody,
242 targets: &HashSet<String>,
243 fuel_var: &str,
244) -> FnBody {
245 FnBody::Block(
246 body.stmts()
247 .iter()
248 .map(|stmt| match stmt {
249 Stmt::Binding(name, ty, expr) => Stmt::Binding(
250 name.clone(),
251 ty.clone(),
252 rewrite_recursive_calls_expr(expr, targets, fuel_var),
253 ),
254 Stmt::Expr(expr) => {
255 Stmt::Expr(rewrite_recursive_calls_expr(expr, targets, fuel_var))
256 }
257 })
258 .collect(),
259 )
260}