// aver/codegen/recursion/mod.rs

//! Shared proof-mode recursion analysis.
//!
//! Classifies each recursive pure fn into a [`RecursionPlan`] that tells
//! the proof backends (Lean, Dafny) how to emit a fuel-guarded helper
//! plus a wrapper with an appropriate fuel metric. Both backends use the
//! same classifier, so the set of supported shapes stays consistent.
//!
//! Emission is backend-specific (syntax, termination-proof mechanism,
//! default value returned on fuel exhaustion), but the recognition pass
//! and the AST transform that rewrites recursive calls into helper calls
//! are shared.

13pub mod detect;
14
15use std::collections::HashSet;
16
17use crate::ast::{Expr, FnBody, MatchArm, Spanned, Stmt, StrPart, TailCallData};
18use crate::codegen::common::expr_to_dotted_name;
19
20pub use detect::analyze_plans;
21
22/// Classification for a single recursive fn (or a whole mutual-recursion
23/// SCC, in which case every fn in the SCC gets its own plan from the
24/// same family).
25///
26/// `Eq` is deliberately omitted — `IntAscending` holds an AST expression
27/// which only implements `PartialEq` (float literals inside it are
28/// partially ordered). `PartialEq` still works for the uses this enum
29/// sees (pattern matching, `matches!`, equality via `.eq`).
30#[derive(Clone, Debug, PartialEq)]
31pub enum RecursionPlan {
32    /// Single-fn recursion where an `Int` parameter decreases by 1.
33    /// The wrapper supplies `n.natAbs + 1` fuel so the helper terminates.
34    IntCountdown { param_index: usize },
35    /// Single-fn recursion where an `Int` param increases by 1 up to a
36    /// bound. The bound is kept as an Aver AST expression so each
37    /// backend renders it in its own idiom; the wrapper supplies
38    /// `(bound - n).natAbs + 1` fuel.
39    IntAscending {
40        param_index: usize,
41        bound: Spanned<Expr>,
42    },
43    /// Affine second-order recurrence like `fib(n) = fib(n-1) + fib(n-2)`
44    /// with `0 / 1` bases and an `n < 0` guard. Emitted through a
45    /// private Nat helper (pair-state), not a fuel helper.
46    LinearRecurrence2,
47    /// Single-fn structural recursion on a `List<_>` parameter; proof
48    /// backends emit as structural recursion directly (no fuel).
49    ListStructural { param_index: usize },
50    /// Single-fn structural recursion on a recursive user ADT; proof
51    /// backends emit through a sizeOf-guarded fuel helper.
52    SizeOfStructural,
53    /// Single-fn recursion where the first `String` is preserved and
54    /// the second `Int` position parameter strictly advances (`pos +
55    /// k`, k ≥ 1). Wrapper fuel is derived from `s.length - pos`.
56    StringPosAdvance,
57    /// Mutual recursion SCC where the first `Int` parameter decreases
58    /// by 1 across every inter-fn call.
59    MutualIntCountdown,
60    /// Mutual recursion SCC where the first `String` is preserved and
61    /// the second `Int` either advances or stays the same across
62    /// rank-decreasing edges.
63    MutualStringPosAdvance { rank: usize },
64    /// Generic mutual recursion SCC using `sizeOf` on structural
65    /// parameters plus rank for same-measure edges.
66    MutualSizeOfRanked { rank: usize },
67}
68
/// Diagnostic for a recursive fn that matches none of the supported
/// recursion patterns.
///
/// Proof backends report it as a warning. They then emit the fn through a
/// partial/axiom fallback, or leave it out entirely.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProofModeIssue {
    /// Line number the diagnostic is reported at.
    pub line: usize,
    /// Human-readable diagnostic text.
    pub message: String,
}

/// Returns the name of the fuel-guarded helper generated for fn `name`.
///
/// The result is `name` followed by the fixed suffix `__fuel`
/// (e.g. `fib` becomes `fib__fuel`). The suffix uses only lowercase ASCII
/// letters and underscores. It is chosen so that both the Lean and the
/// Dafny backends accept it as part of an identifier without renaming.
pub fn fuel_helper_name(name: &str) -> String {
    const SUFFIX: &str = "__fuel";
    let mut helper = String::with_capacity(name.len() + SUFFIX.len());
    helper.push_str(name);
    helper.push_str(SUFFIX);
    helper
}

85/// AST transform: walk `expr` and replace every recursive call to a fn
86/// in `targets` with `fn__fuel(fuel_var, …args)`. Inter-fn mutual calls
87/// in the same SCC are rewritten identically (the fuel parameter is
88/// threaded through the whole group).
89pub fn rewrite_recursive_calls_expr(
90    expr: &Spanned<Expr>,
91    targets: &HashSet<String>,
92    fuel_var: &str,
93) -> Spanned<Expr> {
94    let line = expr.line;
95    let new_node = match &expr.node {
96        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => return expr.clone(),
97        Expr::Attr(obj, field) => Expr::Attr(
98            Box::new(rewrite_recursive_calls_expr(obj, targets, fuel_var)),
99            field.clone(),
100        ),
101        Expr::FnCall(callee, args) => {
102            let rewritten_args: Vec<Spanned<Expr>> = args
103                .iter()
104                .map(|arg| rewrite_recursive_calls_expr(arg, targets, fuel_var))
105                .collect();
106            if let Some(name) = expr_to_dotted_name(&callee.node)
107                && targets.contains(&name)
108            {
109                let mut call_args = Vec::with_capacity(rewritten_args.len() + 1);
110                call_args.push(Spanned::new(Expr::Ident(fuel_var.to_string()), line));
111                call_args.extend(rewritten_args);
112                Expr::FnCall(
113                    Box::new(Spanned::new(Expr::Ident(fuel_helper_name(&name)), line)),
114                    call_args,
115                )
116            } else {
117                Expr::FnCall(
118                    Box::new(rewrite_recursive_calls_expr(callee, targets, fuel_var)),
119                    rewritten_args,
120                )
121            }
122        }
123        Expr::BinOp(op, left, right) => Expr::BinOp(
124            *op,
125            Box::new(rewrite_recursive_calls_expr(left, targets, fuel_var)),
126            Box::new(rewrite_recursive_calls_expr(right, targets, fuel_var)),
127        ),
128        Expr::Match { subject, arms } => Expr::Match {
129            subject: Box::new(rewrite_recursive_calls_expr(subject, targets, fuel_var)),
130            arms: arms
131                .iter()
132                .map(|arm| MatchArm {
133                    pattern: arm.pattern.clone(),
134                    body: Box::new(rewrite_recursive_calls_expr(&arm.body, targets, fuel_var)),
135                    binding_slots: std::sync::OnceLock::new(),
136                })
137                .collect(),
138        },
139        Expr::Constructor(name, arg) => Expr::Constructor(
140            name.clone(),
141            arg.as_ref()
142                .map(|inner| Box::new(rewrite_recursive_calls_expr(inner, targets, fuel_var))),
143        ),
144        Expr::ErrorProp(inner) => Expr::ErrorProp(Box::new(rewrite_recursive_calls_expr(
145            inner, targets, fuel_var,
146        ))),
147        Expr::InterpolatedStr(parts) => Expr::InterpolatedStr(
148            parts
149                .iter()
150                .map(|part| match part {
151                    StrPart::Literal(_) => part.clone(),
152                    StrPart::Parsed(inner) => StrPart::Parsed(Box::new(
153                        rewrite_recursive_calls_expr(inner, targets, fuel_var),
154                    )),
155                })
156                .collect(),
157        ),
158        Expr::List(items) => Expr::List(
159            items
160                .iter()
161                .map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
162                .collect(),
163        ),
164        Expr::Tuple(items) => Expr::Tuple(
165            items
166                .iter()
167                .map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
168                .collect(),
169        ),
170        Expr::IndependentProduct(items, flag) => Expr::IndependentProduct(
171            items
172                .iter()
173                .map(|item| rewrite_recursive_calls_expr(item, targets, fuel_var))
174                .collect(),
175            *flag,
176        ),
177        Expr::MapLiteral(entries) => Expr::MapLiteral(
178            entries
179                .iter()
180                .map(|(k, v)| {
181                    (
182                        rewrite_recursive_calls_expr(k, targets, fuel_var),
183                        rewrite_recursive_calls_expr(v, targets, fuel_var),
184                    )
185                })
186                .collect(),
187        ),
188        Expr::RecordCreate { type_name, fields } => Expr::RecordCreate {
189            type_name: type_name.clone(),
190            fields: fields
191                .iter()
192                .map(|(name, value)| {
193                    (
194                        name.clone(),
195                        rewrite_recursive_calls_expr(value, targets, fuel_var),
196                    )
197                })
198                .collect(),
199        },
200        Expr::RecordUpdate {
201            type_name,
202            base,
203            updates,
204        } => Expr::RecordUpdate {
205            type_name: type_name.clone(),
206            base: Box::new(rewrite_recursive_calls_expr(base, targets, fuel_var)),
207            updates: updates
208                .iter()
209                .map(|(name, value)| {
210                    (
211                        name.clone(),
212                        rewrite_recursive_calls_expr(value, targets, fuel_var),
213                    )
214                })
215                .collect(),
216        },
217        Expr::TailCall(boxed) => {
218            let TailCallData { target, args, .. } = boxed.as_ref();
219            let rewritten_args: Vec<Spanned<Expr>> = args
220                .iter()
221                .map(|arg| rewrite_recursive_calls_expr(arg, targets, fuel_var))
222                .collect();
223            if targets.contains(target) {
224                let mut call_args = Vec::with_capacity(rewritten_args.len() + 1);
225                call_args.push(Spanned::new(Expr::Ident(fuel_var.to_string()), line));
226                call_args.extend(rewritten_args);
227                Expr::FnCall(
228                    Box::new(Spanned::new(Expr::Ident(fuel_helper_name(target)), line)),
229                    call_args,
230                )
231            } else {
232                Expr::TailCall(Box::new(TailCallData::new(target.clone(), rewritten_args)))
233            }
234        }
235    };
236    Spanned::new(new_node, line)
237}
238
239/// Body-level wrapper around [`rewrite_recursive_calls_expr`].
240pub fn rewrite_recursive_calls_body(
241    body: &FnBody,
242    targets: &HashSet<String>,
243    fuel_var: &str,
244) -> FnBody {
245    FnBody::Block(
246        body.stmts()
247            .iter()
248            .map(|stmt| match stmt {
249                Stmt::Binding(name, ty, expr) => Stmt::Binding(
250                    name.clone(),
251                    ty.clone(),
252                    rewrite_recursive_calls_expr(expr, targets, fuel_var),
253                ),
254                Stmt::Expr(expr) => {
255                    Stmt::Expr(rewrite_recursive_calls_expr(expr, targets, fuel_var))
256                }
257            })
258            .collect(),
259    )
260}