Skip to main content

aver/codegen/recursion/
mod.rs

1//! Shared proof-mode recursion analysis.
2//!
3//! Classifies each recursive pure fn into a [`RecursionPlan`] that tells
4//! the proof backends (Lean, Dafny) how to emit a fuel-guarded helper
5//! plus a wrapper with an appropriate fuel metric. The same classifier
6//! feeds both backends so supported shapes stay consistent.
7//!
8//! Emission is backend-specific (syntax, termination-proof mechanism,
9//! default-value for fuel exhaustion), but the recognition pass and the
10//! AST transform that rewrites recursive calls into helper calls are
11//! shared.
12
13pub mod detect;
14
15use std::collections::HashSet;
16
17use crate::ast::{Expr, FnBody, MatchArm, Spanned, Stmt, StrPart, TailCallData};
18use crate::codegen::common::expr_to_dotted_name;
19
20pub use detect::analyze_plans;
21
22/// Classification for a single recursive fn (or a whole mutual-recursion
23/// SCC, in which case every fn in the SCC gets its own plan from the
24/// same family).
25///
26/// `Eq` is deliberately omitted — `IntAscending` holds an AST expression
27/// which only implements `PartialEq` (float literals inside it are
28/// partially ordered). `PartialEq` still works for the uses this enum
29/// sees (pattern matching, `matches!`, equality via `.eq`).
30#[derive(Clone, Debug, PartialEq)]
31pub enum RecursionPlan {
32    /// Single-fn recursion where an `Int` parameter decreases by 1.
33    /// The wrapper supplies `n.natAbs + 1` fuel so the helper terminates.
34    IntCountdown { param_index: usize },
35    /// Single-fn recursion where an `Int` param increases by 1 up to a
36    /// bound. The bound is kept as an Aver AST expression so each
37    /// backend renders it in its own idiom; the wrapper supplies
38    /// `(bound - n).natAbs + 1` fuel.
39    IntAscending {
40        param_index: usize,
41        bound: Spanned<Expr>,
42    },
43    /// Affine second-order recurrence like `fib(n) = fib(n-1) + fib(n-2)`
44    /// with `0 / 1` bases and an `n < 0` guard. Emitted through a
45    /// private Nat helper (pair-state), not a fuel helper.
46    LinearRecurrence2,
47    /// Single-fn structural recursion on a `List<_>` parameter; proof
48    /// backends emit as structural recursion directly (no fuel).
49    ListStructural { param_index: usize },
50    /// Single-fn structural recursion on a recursive user ADT; proof
51    /// backends emit through a sizeOf-guarded fuel helper.
52    SizeOfStructural,
53    /// Single-fn recursion where the first `String` is preserved and
54    /// the second `Int` position parameter strictly advances (`pos +
55    /// k`, k ≥ 1). Wrapper fuel is derived from `s.length - pos`.
56    StringPosAdvance,
57    /// Mutual recursion SCC where the first `Int` parameter decreases
58    /// by 1 across every inter-fn call.
59    MutualIntCountdown,
60    /// Mutual recursion SCC where the first `String` is preserved and
61    /// the second `Int` either advances or stays the same across
62    /// rank-decreasing edges.
63    MutualStringPosAdvance { rank: usize },
64    /// Generic mutual recursion SCC using `sizeOf` on structural
65    /// parameters plus rank for same-measure edges.
66    MutualSizeOfRanked { rank: usize },
67}
68
/// Diagnostic for a recursive fn whose shape matches none of the supported
/// recursion patterns.
///
/// Proof backends report it as a warning, then either emit the fn through
/// a partial/axiom fallback or leave it out entirely.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProofModeIssue {
    /// Line the issue is reported against.
    pub line: usize,
    /// Human-readable description of why the fn is unsupported.
    pub message: String,
}
77
/// Name of the fuel-guarded helper generated for the fn called `name`.
///
/// The result is `name` followed by the canonical `__fuel` suffix. That
/// suffix uses only lowercase ASCII letters and underscores, so it never
/// stops Lean or Dafny from accepting the helper name as an identifier
/// without renaming.
pub fn fuel_helper_name(name: &str) -> String {
    [name, "__fuel"].concat()
}
84
85/// AST transform: walk `expr` and replace every recursive call to a fn
86/// in `targets` with `fn__fuel(fuel_var, …args)`. Inter-fn mutual calls
87/// in the same SCC are rewritten identically (the fuel parameter is
88/// threaded through the whole group).
89pub fn rewrite_recursive_calls_expr(
90    expr: &Spanned<Expr>,
91    targets: &HashSet<String>,
92    fuel_var: &str,
93) -> Spanned<Expr> {
94    let line = expr.line;
95    let new_node = match &expr.node {
96        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => return expr.clone(),
97        Expr::Attr(obj, field) => Expr::Attr(
98            Box::new(rewrite_recursive_calls_expr(obj, targets, fuel_var)),
99            field.clone(),
100        ),
101        Expr::FnCall(callee, args) => {
102            let rewritten_args: Vec<Spanned<Expr>> = args
103                .iter()
104                .map(|arg| rewrite_recursive_calls_expr(arg, targets, fuel_var))
105                .collect();
106            if let Some(name) = expr_to_dotted_name(&callee.node)
107                && targets.contains(&name)
108            {
109                let mut call_args = Vec::with_capacity(rewritten_args.len() + 1);
110                call_args.push(Spanned::new(Expr::Ident(fuel_var.to_string()), line));
111                call_args.extend(rewritten_args);
112                Expr::FnCall(
113                    Box::new(Spanned::new(Expr::Ident(fuel_helper_name(&name)), line)),
114                    call_args,
115                )
116            } else {
117                Expr::FnCall(
118                    Box::new(rewrite_recursive_calls_expr(callee, targets, fuel_var)),
119                    rewritten_args,
120                )
121            }
122        }
123        Expr::BinOp(op, left, right) => Expr::BinOp(
124            *op,
125            Box::new(rewrite_recursive_calls_expr(left, targets, fuel_var)),
126            Box::new(rewrite_recursive_calls_expr(right, targets, fuel_var)),
127        ),
128        Expr::Match { subject, arms } => Expr::Match {
129            subject: Box::new(rewrite_recursive_calls_expr(subject, targets, fuel_var)),
130            arms: arms
131                .iter()
132                .map(|arm| MatchArm {
133                    pattern: arm.pattern.clone(),
134                    body: Box::new(rewrite_recursive_calls_expr(&arm.body, targets, fuel_var)),
135                })
136                .collect(),
137        },
138        Expr::Constructor(name, arg) => Expr::Constructor(
139            name.clone(),
140            arg.as_ref()
141                .map(|inner| Box::new(rewrite_recursive_calls_expr(inner, targets, fuel_var))),
142        ),
143        Expr::ErrorProp(inner) => Expr::ErrorProp(Box::new(rewrite_recursive_calls_expr(
144            inner, targets, fuel_var,
145        ))),
146        Expr::InterpolatedStr(parts) => Expr::InterpolatedStr(
147            parts
148                .iter()
149                .map(|part| match part {
150                    StrPart::Literal(_) => part.clone(),
151                    StrPart::Parsed(inner) => StrPart::Parsed(Box::new(
152                        rewrite_recursive_calls_expr(inner, targets, fuel_var),
153                    )),
154                })
155                .collect(),
156        ),
157        Expr::List(items) => Expr::List(
158            items
159                .iter()
160                .map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
161                .collect(),
162        ),
163        Expr::Tuple(items) => Expr::Tuple(
164            items
165                .iter()
166                .map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
167                .collect(),
168        ),
169        Expr::IndependentProduct(items, flag) => Expr::IndependentProduct(
170            items
171                .iter()
172                .map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
173                .collect(),
174            *flag,
175        ),
176        Expr::MapLiteral(entries) => Expr::MapLiteral(
177            entries
178                .iter()
179                .map(|(k, v)| {
180                    (
181                        rewrite_recursive_calls_expr(k, targets, fuel_var),
182                        rewrite_recursive_calls_expr(v, targets, fuel_var),
183                    )
184                })
185                .collect(),
186        ),
187        Expr::RecordCreate { type_name, fields } => Expr::RecordCreate {
188            type_name: type_name.clone(),
189            fields: fields
190                .iter()
191                .map(|(name, value)| {
192                    (
193                        name.clone(),
194                        rewrite_recursive_calls_expr(value, targets, fuel_var),
195                    )
196                })
197                .collect(),
198        },
199        Expr::RecordUpdate {
200            type_name,
201            base,
202            updates,
203        } => Expr::RecordUpdate {
204            type_name: type_name.clone(),
205            base: Box::new(rewrite_recursive_calls_expr(base, targets, fuel_var)),
206            updates: updates
207                .iter()
208                .map(|(name, value)| {
209                    (
210                        name.clone(),
211                        rewrite_recursive_calls_expr(value, targets, fuel_var),
212                    )
213                })
214                .collect(),
215        },
216        Expr::TailCall(boxed) => {
217            let TailCallData { target, args, .. } = boxed.as_ref();
218            let rewritten_args: Vec<Spanned<Expr>> = args
219                .iter()
220                .map(|arg| rewrite_recursive_calls_expr(arg, targets, fuel_var))
221                .collect();
222            if targets.contains(target) {
223                let mut call_args = Vec::with_capacity(rewritten_args.len() + 1);
224                call_args.push(Spanned::new(Expr::Ident(fuel_var.to_string()), line));
225                call_args.extend(rewritten_args);
226                Expr::FnCall(
227                    Box::new(Spanned::new(Expr::Ident(fuel_helper_name(target)), line)),
228                    call_args,
229                )
230            } else {
231                Expr::TailCall(Box::new(TailCallData::new(target.clone(), rewritten_args)))
232            }
233        }
234    };
235    Spanned::new(new_node, line)
236}
237
238/// Body-level wrapper around [`rewrite_recursive_calls_expr`].
239pub fn rewrite_recursive_calls_body(
240    body: &FnBody,
241    targets: &HashSet<String>,
242    fuel_var: &str,
243) -> FnBody {
244    FnBody::Block(
245        body.stmts()
246            .iter()
247            .map(|stmt| match stmt {
248                Stmt::Binding(name, ty, expr) => Stmt::Binding(
249                    name.clone(),
250                    ty.clone(),
251                    rewrite_recursive_calls_expr(expr, targets, fuel_var),
252                ),
253                Stmt::Expr(expr) => {
254                    Stmt::Expr(rewrite_recursive_calls_expr(expr, targets, fuel_var))
255                }
256            })
257            .collect(),
258    )
259}