1pub mod detect;
14
15use std::collections::HashSet;
16
17use crate::ast::{Expr, FnBody, MatchArm, Spanned, Stmt, StrPart, TailCallData};
18use crate::codegen::common::expr_to_dotted_name;
19
20pub use detect::analyze_plans;
21
22#[derive(Clone, Debug, PartialEq)]
31pub enum RecursionPlan {
32 IntCountdown { param_index: usize },
35 IntAscending {
40 param_index: usize,
41 bound: Spanned<Expr>,
42 },
43 LinearRecurrence2,
47 ListStructural { param_index: usize },
50 SizeOfStructural,
53 StringPosAdvance,
57 MutualIntCountdown,
60 MutualStringPosAdvance { rank: usize },
64 MutualSizeOfRanked { rank: usize },
67}
68
/// A single problem reported for proof mode, attached to a source line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProofModeIssue {
    /// Source line the issue is reported at.
    pub line: usize,
    /// Human-readable description of the issue.
    pub message: String,
}
77
/// Name of the fuel-threaded helper generated for function `name`: the name
/// itself followed by the suffix `__fuel` (e.g. `go` becomes `go__fuel`).
pub fn fuel_helper_name(name: &str) -> String {
    [name, "__fuel"].concat()
}
84
85pub fn rewrite_recursive_calls_expr(
90 expr: &Spanned<Expr>,
91 targets: &HashSet<String>,
92 fuel_var: &str,
93) -> Spanned<Expr> {
94 let line = expr.line;
95 let new_node = match &expr.node {
96 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => return expr.clone(),
97 Expr::Attr(obj, field) => Expr::Attr(
98 Box::new(rewrite_recursive_calls_expr(obj, targets, fuel_var)),
99 field.clone(),
100 ),
101 Expr::FnCall(callee, args) => {
102 let rewritten_args: Vec<Spanned<Expr>> = args
103 .iter()
104 .map(|arg| rewrite_recursive_calls_expr(arg, targets, fuel_var))
105 .collect();
106 if let Some(name) = expr_to_dotted_name(&callee.node)
107 && targets.contains(&name)
108 {
109 let mut call_args = Vec::with_capacity(rewritten_args.len() + 1);
110 call_args.push(Spanned::new(Expr::Ident(fuel_var.to_string()), line));
111 call_args.extend(rewritten_args);
112 Expr::FnCall(
113 Box::new(Spanned::new(Expr::Ident(fuel_helper_name(&name)), line)),
114 call_args,
115 )
116 } else {
117 Expr::FnCall(
118 Box::new(rewrite_recursive_calls_expr(callee, targets, fuel_var)),
119 rewritten_args,
120 )
121 }
122 }
123 Expr::BinOp(op, left, right) => Expr::BinOp(
124 *op,
125 Box::new(rewrite_recursive_calls_expr(left, targets, fuel_var)),
126 Box::new(rewrite_recursive_calls_expr(right, targets, fuel_var)),
127 ),
128 Expr::Match { subject, arms } => Expr::Match {
129 subject: Box::new(rewrite_recursive_calls_expr(subject, targets, fuel_var)),
130 arms: arms
131 .iter()
132 .map(|arm| MatchArm {
133 pattern: arm.pattern.clone(),
134 body: Box::new(rewrite_recursive_calls_expr(&arm.body, targets, fuel_var)),
135 })
136 .collect(),
137 },
138 Expr::Constructor(name, arg) => Expr::Constructor(
139 name.clone(),
140 arg.as_ref()
141 .map(|inner| Box::new(rewrite_recursive_calls_expr(inner, targets, fuel_var))),
142 ),
143 Expr::ErrorProp(inner) => Expr::ErrorProp(Box::new(rewrite_recursive_calls_expr(
144 inner, targets, fuel_var,
145 ))),
146 Expr::InterpolatedStr(parts) => Expr::InterpolatedStr(
147 parts
148 .iter()
149 .map(|part| match part {
150 StrPart::Literal(_) => part.clone(),
151 StrPart::Parsed(inner) => StrPart::Parsed(Box::new(
152 rewrite_recursive_calls_expr(inner, targets, fuel_var),
153 )),
154 })
155 .collect(),
156 ),
157 Expr::List(items) => Expr::List(
158 items
159 .iter()
160 .map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
161 .collect(),
162 ),
163 Expr::Tuple(items) => Expr::Tuple(
164 items
165 .iter()
166 .map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
167 .collect(),
168 ),
169 Expr::IndependentProduct(items, flag) => Expr::IndependentProduct(
170 items
171 .iter()
172 .map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
173 .collect(),
174 *flag,
175 ),
176 Expr::MapLiteral(entries) => Expr::MapLiteral(
177 entries
178 .iter()
179 .map(|(k, v)| {
180 (
181 rewrite_recursive_calls_expr(k, targets, fuel_var),
182 rewrite_recursive_calls_expr(v, targets, fuel_var),
183 )
184 })
185 .collect(),
186 ),
187 Expr::RecordCreate { type_name, fields } => Expr::RecordCreate {
188 type_name: type_name.clone(),
189 fields: fields
190 .iter()
191 .map(|(name, value)| {
192 (
193 name.clone(),
194 rewrite_recursive_calls_expr(value, targets, fuel_var),
195 )
196 })
197 .collect(),
198 },
199 Expr::RecordUpdate {
200 type_name,
201 base,
202 updates,
203 } => Expr::RecordUpdate {
204 type_name: type_name.clone(),
205 base: Box::new(rewrite_recursive_calls_expr(base, targets, fuel_var)),
206 updates: updates
207 .iter()
208 .map(|(name, value)| {
209 (
210 name.clone(),
211 rewrite_recursive_calls_expr(value, targets, fuel_var),
212 )
213 })
214 .collect(),
215 },
216 Expr::TailCall(boxed) => {
217 let TailCallData { target, args, .. } = boxed.as_ref();
218 let rewritten_args: Vec<Spanned<Expr>> = args
219 .iter()
220 .map(|arg| rewrite_recursive_calls_expr(arg, targets, fuel_var))
221 .collect();
222 if targets.contains(target) {
223 let mut call_args = Vec::with_capacity(rewritten_args.len() + 1);
224 call_args.push(Spanned::new(Expr::Ident(fuel_var.to_string()), line));
225 call_args.extend(rewritten_args);
226 Expr::FnCall(
227 Box::new(Spanned::new(Expr::Ident(fuel_helper_name(target)), line)),
228 call_args,
229 )
230 } else {
231 Expr::TailCall(Box::new(TailCallData::new(target.clone(), rewritten_args)))
232 }
233 }
234 };
235 Spanned::new(new_node, line)
236}
237
238pub fn rewrite_recursive_calls_body(
240 body: &FnBody,
241 targets: &HashSet<String>,
242 fuel_var: &str,
243) -> FnBody {
244 FnBody::Block(
245 body.stmts()
246 .iter()
247 .map(|stmt| match stmt {
248 Stmt::Binding(name, ty, expr) => Stmt::Binding(
249 name.clone(),
250 ty.clone(),
251 rewrite_recursive_calls_expr(expr, targets, fuel_var),
252 ),
253 Stmt::Expr(expr) => {
254 Stmt::Expr(rewrite_recursive_calls_expr(expr, targets, fuel_var))
255 }
256 })
257 .collect(),
258 )
259}