Skip to main content

aver/codegen/recursion/
detect.rs

1//! Proof-mode recursion classifier.
2//!
3//! Backend-neutral pattern matching over the Aver AST that decides,
4//! for each recursive pure fn (or mutual-recursion SCC), which
5//! [`RecursionPlan`] variant applies — or emits a [`ProofModeIssue`]
6//! when the shape falls outside the supported set.
7//!
8//! Lean and Dafny consume the same plans through
9//! [`crate::codegen::recursion::analyze_plans`]; a couple of helpers
10//! that depend on AST queries tied to Lean's `toplevel` (pure-fn
11//! predicate, recursive-type-def predicate, type-def name) still live
12//! in `crate::codegen::lean` and are re-used here through
13//! `pub(crate)` exports. That could move to a neutral AST helper
14//! module in a later pass.
15use std::collections::{HashMap, HashSet};
16
17use crate::ast::{
18    BinOp, Expr, FnBody, FnDef, MatchArm, Pattern, Spanned, Stmt, TailCallData, TypeDef,
19};
20use crate::call_graph;
21use crate::codegen::CodegenContext;
22use crate::codegen::lean::{
23    find_type_def, pure_fns, recursive_pure_fn_names, recursive_type_names,
24    sizeof_measure_param_indices,
25};
26
27use super::{ProofModeIssue, RecursionPlan};
28
29pub(crate) fn expr_to_dotted_name(expr: &Spanned<Expr>) -> Option<String> {
30    match &expr.node {
31        Expr::Ident(name) => Some(name.clone()),
32        Expr::Attr(obj, field) => expr_to_dotted_name(obj).map(|p| format!("{}.{}", p, field)),
33        _ => None,
34    }
35}
36
37/// Local-variable name regardless of whether the expression has been
38/// resolved. Resolver rewrites `Expr::Ident(name)` for locals into
39/// `Expr::Resolved { name, .. }` (slot-numbered for the VM); pre-resolution
40/// passes still see `Expr::Ident`. Recursion-shape detectors run AFTER
41/// resolution at codegen time, so they must accept both forms — otherwise
42/// every list/string-recursive function silently falls outside the proof
43/// subset.
44pub(crate) fn local_name_of(expr: &Spanned<Expr>) -> Option<&str> {
45    match &expr.node {
46        Expr::Ident(name) => Some(name.as_str()),
47        Expr::Resolved { name, .. } => Some(name.as_str()),
48        _ => None,
49    }
50}
51
/// True when `name` is exactly `target`, or when the last dot-separated
/// segment of `name` equals `target` (so `Mod.f` matches `f`).
pub(crate) fn call_matches(name: &str, target: &str) -> bool {
    if name == target {
        return true;
    }
    // Everything after the final '.', or the whole name when there is none.
    let last_segment = match name.rfind('.') {
        Some(dot) => &name[dot + 1..],
        None => name,
    };
    last_segment == target
}
55
/// True when `name`, or its last dot-separated segment, is a member of
/// `targets`. Same contract as `call_matches_any`.
pub(crate) fn call_is_in_set(name: &str, targets: &HashSet<String>) -> bool {
    let last_segment = name.rsplit('.').next();
    targets.contains(name) || last_segment.is_some_and(|last| targets.contains(last))
}
59
/// Resolve `name` to the entry of `targets` it refers to.
///
/// The exact name wins when it is a member. Otherwise the last
/// dot-separated segment is tried. Returns `None` when neither form is in
/// `targets`.
pub(crate) fn canonical_callee_name(name: &str, targets: &HashSet<String>) -> Option<String> {
    // Candidates in priority order: full name first, then the short name.
    let candidates = [Some(name), name.rsplit('.').next()];
    candidates
        .into_iter()
        .flatten()
        .find(|candidate| targets.contains(*candidate))
        .map(str::to_owned)
}
69
/// True when `name` itself, or its last dot-separated segment, is a member
/// of `targets`.
pub(crate) fn call_matches_any(name: &str, targets: &HashSet<String>) -> bool {
    targets.contains(name)
        || name
            .rsplit('.')
            .next()
            .is_some_and(|last| targets.contains(last))
}
79
80pub(crate) fn is_int_minus_positive(expr: &Spanned<Expr>, param_name: &str) -> bool {
81    match &expr.node {
82        Expr::BinOp(BinOp::Sub, left, right) => {
83            local_name_of(left).is_some_and(|id| id == param_name)
84                && matches!(&right.node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
85        }
86        Expr::FnCall(callee, args) => {
87            let Some(name) = expr_to_dotted_name(callee) else {
88                return false;
89            };
90            (name == "Int.sub" || name == "int.sub")
91                && args.len() == 2
92                && local_name_of(&args[0]).is_some_and(|id| id == param_name)
93                && matches!(&args[1].node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
94        }
95        _ => false,
96    }
97}
98
/// Append every call site reachable from `expr` to `out`, in pre-order.
///
/// Each entry pairs the callee's dotted name with borrowed references to
/// the call's argument expressions.
/// * An `Expr::FnCall` is recorded only when its callee renders through
///   `expr_to_dotted_name`. Its callee and arguments are walked either way.
/// * An `Expr::TailCall` is always recorded under its `target` name.
/// * Match subjects and arm bodies, constructor payloads, interpolated-string
///   parts, collection items, map keys/values and record fields are all
///   traversed.
/// * Literals and locals (raw or resolved) contribute nothing.
pub(crate) fn collect_calls_from_expr<'a>(
    expr: &'a Spanned<Expr>,
    out: &mut Vec<(String, Vec<&'a Spanned<Expr>>)>,
) {
    match &expr.node {
        Expr::FnCall(callee, args) => {
            // Record the call itself before its sub-expressions (pre-order).
            if let Some(name) = expr_to_dotted_name(callee) {
                out.push((name, args.iter().collect()));
            }
            collect_calls_from_expr(callee, out);
            for arg in args {
                collect_calls_from_expr(arg, out);
            }
        }
        Expr::TailCall(boxed) => {
            let TailCallData {
                target: name, args, ..
            } = boxed.as_ref();
            out.push((name.clone(), args.iter().collect()));
            for arg in args {
                collect_calls_from_expr(arg, out);
            }
        }
        Expr::Attr(obj, _) => collect_calls_from_expr(obj, out),
        Expr::BinOp(_, left, right) => {
            collect_calls_from_expr(left, out);
            collect_calls_from_expr(right, out);
        }
        Expr::Match { subject, arms, .. } => {
            collect_calls_from_expr(subject, out);
            // Only arm bodies are walked; patterns hold no expressions here.
            for arm in arms {
                collect_calls_from_expr(&arm.body, out);
            }
        }
        Expr::Constructor(_, inner) => {
            if let Some(inner) = inner {
                collect_calls_from_expr(inner, out);
            }
        }
        Expr::ErrorProp(inner) => collect_calls_from_expr(inner, out),
        Expr::InterpolatedStr(parts) => {
            // Only parsed (expression) parts can contain calls.
            for p in parts {
                if let crate::ast::StrPart::Parsed(e) = p {
                    collect_calls_from_expr(e, out);
                }
            }
        }
        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
            for item in items {
                collect_calls_from_expr(item, out);
            }
        }
        Expr::MapLiteral(entries) => {
            for (k, v) in entries {
                collect_calls_from_expr(k, out);
                collect_calls_from_expr(v, out);
            }
        }
        Expr::RecordCreate { fields, .. } => {
            for (_, v) in fields {
                collect_calls_from_expr(v, out);
            }
        }
        Expr::RecordUpdate { base, updates, .. } => {
            collect_calls_from_expr(base, out);
            for (_, v) in updates {
                collect_calls_from_expr(v, out);
            }
        }
        // Leaves: no calls inside.
        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => {}
    }
}
171
172pub(crate) fn collect_calls_from_body(body: &FnBody) -> Vec<(String, Vec<&Spanned<Expr>>)> {
173    let mut out = Vec::new();
174    for stmt in body.stmts() {
175        match stmt {
176            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => collect_calls_from_expr(expr, &mut out),
177        }
178    }
179    out
180}
181
/// Walk `expr` and insert into `tails` the tail binder of every
/// `Pattern::Cons` arm of a `match` whose subject is the local
/// `list_param_name` (raw or resolved, via `local_name_of`).
///
/// Only matches directly on `list_param_name` contribute. A match on some
/// other subject (including an earlier tail binder) adds nothing itself,
/// but every sub-expression is still traversed, so qualifying matches are
/// found at any depth.
pub(crate) fn collect_list_tail_binders_from_expr(
    expr: &Spanned<Expr>,
    list_param_name: &str,
    tails: &mut HashSet<String>,
) {
    match &expr.node {
        Expr::Match { subject, arms, .. } => {
            if local_name_of(subject).is_some_and(|id| id == list_param_name) {
                // Record the name each cons arm binds to the list's tail.
                for MatchArm { pattern, .. } in arms {
                    if let Pattern::Cons(_, tail) = pattern {
                        tails.insert(tail.clone());
                    }
                }
            }
            for arm in arms {
                collect_list_tail_binders_from_expr(&arm.body, list_param_name, tails);
            }
            collect_list_tail_binders_from_expr(subject, list_param_name, tails);
        }
        Expr::FnCall(callee, args) => {
            collect_list_tail_binders_from_expr(callee, list_param_name, tails);
            for arg in args {
                collect_list_tail_binders_from_expr(arg, list_param_name, tails);
            }
        }
        Expr::TailCall(boxed) => {
            let TailCallData {
                target: _, args, ..
            } = boxed.as_ref();
            for arg in args {
                collect_list_tail_binders_from_expr(arg, list_param_name, tails);
            }
        }
        Expr::Attr(obj, _) => collect_list_tail_binders_from_expr(obj, list_param_name, tails),
        Expr::BinOp(_, left, right) => {
            collect_list_tail_binders_from_expr(left, list_param_name, tails);
            collect_list_tail_binders_from_expr(right, list_param_name, tails);
        }
        Expr::Constructor(_, inner) => {
            if let Some(inner) = inner {
                collect_list_tail_binders_from_expr(inner, list_param_name, tails);
            }
        }
        Expr::ErrorProp(inner) => {
            collect_list_tail_binders_from_expr(inner, list_param_name, tails)
        }
        Expr::InterpolatedStr(parts) => {
            for p in parts {
                if let crate::ast::StrPart::Parsed(e) = p {
                    collect_list_tail_binders_from_expr(e, list_param_name, tails);
                }
            }
        }
        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
            for item in items {
                collect_list_tail_binders_from_expr(item, list_param_name, tails);
            }
        }
        Expr::MapLiteral(entries) => {
            for (k, v) in entries {
                collect_list_tail_binders_from_expr(k, list_param_name, tails);
                collect_list_tail_binders_from_expr(v, list_param_name, tails);
            }
        }
        Expr::RecordCreate { fields, .. } => {
            for (_, v) in fields {
                collect_list_tail_binders_from_expr(v, list_param_name, tails);
            }
        }
        Expr::RecordUpdate { base, updates, .. } => {
            collect_list_tail_binders_from_expr(base, list_param_name, tails);
            for (_, v) in updates {
                collect_list_tail_binders_from_expr(v, list_param_name, tails);
            }
        }
        // Leaves: nothing to traverse.
        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => {}
    }
}
260
261pub(crate) fn collect_list_tail_binders(fd: &FnDef, list_param_name: &str) -> HashSet<String> {
262    let mut tails = HashSet::new();
263    for stmt in fd.body.stmts() {
264        match stmt {
265            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
266                collect_list_tail_binders_from_expr(expr, list_param_name, &mut tails)
267            }
268        }
269    }
270    tails
271}
272
273pub(crate) fn recursive_constructor_binders(
274    td: &TypeDef,
275    variant_name: &str,
276    binders: &[String],
277) -> Vec<String> {
278    let variant_short = variant_name.rsplit('.').next().unwrap_or(variant_name);
279    match td {
280        TypeDef::Sum { name, variants, .. } => variants
281            .iter()
282            .find(|variant| variant.name == variant_short)
283            .map(|variant| {
284                variant
285                    .fields
286                    .iter()
287                    .zip(binders.iter())
288                    .filter_map(|(field_ty, binder)| {
289                        (field_ty.trim() == name).then_some(binder.clone())
290                    })
291                    .collect()
292            })
293            .unwrap_or_default(),
294        TypeDef::Product { .. } => Vec::new(),
295    }
296}
297
298pub(crate) fn grow_recursive_subterm_binders_from_expr(
299    expr: &Spanned<Expr>,
300    tracked: &HashSet<String>,
301    td: &TypeDef,
302    out: &mut HashSet<String>,
303) {
304    match &expr.node {
305        Expr::Match { subject, arms, .. } => {
306            if let Expr::Ident(subject_name) = &subject.node
307                && tracked.contains(subject_name)
308            {
309                for arm in arms {
310                    if let Pattern::Constructor(variant_name, binders) = &arm.pattern {
311                        out.extend(recursive_constructor_binders(td, variant_name, binders));
312                    }
313                }
314            }
315            grow_recursive_subterm_binders_from_expr(subject, tracked, td, out);
316            for arm in arms {
317                grow_recursive_subterm_binders_from_expr(&arm.body, tracked, td, out);
318            }
319        }
320        Expr::FnCall(callee, args) => {
321            grow_recursive_subterm_binders_from_expr(callee, tracked, td, out);
322            for arg in args {
323                grow_recursive_subterm_binders_from_expr(arg, tracked, td, out);
324            }
325        }
326        Expr::Attr(obj, _) => grow_recursive_subterm_binders_from_expr(obj, tracked, td, out),
327        Expr::BinOp(_, left, right) => {
328            grow_recursive_subterm_binders_from_expr(left, tracked, td, out);
329            grow_recursive_subterm_binders_from_expr(right, tracked, td, out);
330        }
331        Expr::Constructor(_, Some(inner)) | Expr::ErrorProp(inner) => {
332            grow_recursive_subterm_binders_from_expr(inner, tracked, td, out)
333        }
334        Expr::InterpolatedStr(parts) => {
335            for part in parts {
336                if let crate::ast::StrPart::Parsed(inner) = part {
337                    grow_recursive_subterm_binders_from_expr(inner, tracked, td, out);
338                }
339            }
340        }
341        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
342            for item in items {
343                grow_recursive_subterm_binders_from_expr(item, tracked, td, out);
344            }
345        }
346        Expr::MapLiteral(entries) => {
347            for (k, v) in entries {
348                grow_recursive_subterm_binders_from_expr(k, tracked, td, out);
349                grow_recursive_subterm_binders_from_expr(v, tracked, td, out);
350            }
351        }
352        Expr::RecordCreate { fields, .. } => {
353            for (_, v) in fields {
354                grow_recursive_subterm_binders_from_expr(v, tracked, td, out);
355            }
356        }
357        Expr::RecordUpdate { base, updates, .. } => {
358            grow_recursive_subterm_binders_from_expr(base, tracked, td, out);
359            for (_, v) in updates {
360                grow_recursive_subterm_binders_from_expr(v, tracked, td, out);
361            }
362        }
363        Expr::TailCall(boxed) => {
364            for arg in &boxed.args {
365                grow_recursive_subterm_binders_from_expr(arg, tracked, td, out);
366            }
367        }
368        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
369    }
370}
371
372pub(crate) fn collect_recursive_subterm_binders(
373    fd: &FnDef,
374    param_name: &str,
375    param_type: &str,
376    ctx: &CodegenContext,
377) -> HashSet<String> {
378    let Some(td) = find_type_def(ctx, param_type) else {
379        return HashSet::new();
380    };
381    let mut tracked: HashSet<String> = HashSet::from([param_name.to_string()]);
382    loop {
383        let mut discovered = HashSet::new();
384        for stmt in fd.body.stmts() {
385            match stmt {
386                Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
387                    grow_recursive_subterm_binders_from_expr(expr, &tracked, td, &mut discovered);
388                }
389            }
390        }
391        let before = tracked.len();
392        tracked.extend(discovered);
393        if tracked.len() == before {
394            break;
395        }
396    }
397    tracked.remove(param_name);
398    tracked
399}
400
401pub(crate) fn single_int_countdown_param_index(fd: &FnDef) -> Option<usize> {
402    let recursive_calls: Vec<Vec<&Spanned<Expr>>> = collect_calls_from_body(fd.body.as_ref())
403        .into_iter()
404        .filter(|(name, _)| call_matches(name, &fd.name))
405        .map(|(_, args)| args)
406        .collect();
407    if recursive_calls.is_empty() {
408        return None;
409    }
410
411    fd.params
412        .iter()
413        .enumerate()
414        .find_map(|(idx, (param_name, param_ty))| {
415            if param_ty != "Int" {
416                return None;
417            }
418            let countdown_ok = recursive_calls.iter().all(|args| {
419                args.get(idx)
420                    .cloned()
421                    .is_some_and(|arg| is_int_minus_positive(arg, param_name))
422            });
423            if countdown_ok {
424                return Some(idx);
425            }
426
427            // Negative-guarded ascent (match n < 0) is handled as countdown
428            // because the fuel is natAbs(n) which works for both directions.
429            let ascent_ok = recursive_calls.iter().all(|args| {
430                args.get(idx)
431                    .copied()
432                    .is_some_and(|arg| is_int_plus_positive(arg, param_name))
433            });
434            (ascent_ok && has_negative_guarded_ascent(fd, param_name)).then_some(idx)
435        })
436}
437
438pub(crate) fn has_negative_guarded_ascent(fd: &FnDef, param_name: &str) -> bool {
439    let Some(tail) = fd.body.tail_expr() else {
440        return false;
441    };
442    let Expr::Match { subject, arms, .. } = &tail.node else {
443        return false;
444    };
445    let Expr::BinOp(BinOp::Lt, left, right) = &subject.node else {
446        return false;
447    };
448    if !is_ident(left, param_name)
449        || !matches!(&right.node, Expr::Literal(crate::ast::Literal::Int(0)))
450    {
451        return false;
452    }
453
454    let mut true_arm = None;
455    let mut false_arm = None;
456    for arm in arms {
457        match arm.pattern {
458            Pattern::Literal(crate::ast::Literal::Bool(true)) => true_arm = Some(arm.body.as_ref()),
459            Pattern::Literal(crate::ast::Literal::Bool(false)) => {
460                false_arm = Some(arm.body.as_ref())
461            }
462            _ => return false,
463        }
464    }
465
466    let Some(true_arm) = true_arm else {
467        return false;
468    };
469    let Some(false_arm) = false_arm else {
470        return false;
471    };
472
473    let mut true_calls = Vec::new();
474    collect_calls_from_expr(true_arm, &mut true_calls);
475    let mut false_calls = Vec::new();
476    collect_calls_from_expr(false_arm, &mut false_calls);
477
478    true_calls
479        .iter()
480        .any(|(name, _)| call_matches(name, &fd.name))
481        && false_calls
482            .iter()
483            .all(|(name, _)| !call_matches(name, &fd.name))
484}
485
486/// Detect ascending-index recursion and extract the bound expression
487/// as an Aver AST (`Spanned<Expr>`). Returns `(param_index, bound)`.
488pub(crate) fn single_int_ascending_param(fd: &FnDef) -> Option<(usize, Spanned<Expr>)> {
489    let recursive_calls: Vec<Vec<&Spanned<Expr>>> = collect_calls_from_body(fd.body.as_ref())
490        .into_iter()
491        .filter(|(name, _)| call_matches(name, &fd.name))
492        .map(|(_, args)| args)
493        .collect();
494    if recursive_calls.is_empty() {
495        return None;
496    }
497
498    for (idx, (param_name, param_ty)) in fd.params.iter().enumerate() {
499        if param_ty != "Int" {
500            continue;
501        }
502        let ascent_ok = recursive_calls.iter().all(|args| {
503            args.get(idx)
504                .cloned()
505                .is_some_and(|arg| is_int_plus_positive(arg, param_name))
506        });
507        if !ascent_ok {
508            continue;
509        }
510        if let Some(bound) = extract_equality_bound_expr(fd, param_name) {
511            return Some((idx, bound));
512        }
513    }
514    None
515}
516
517/// Extract the bound expression from `match param == BOUND` as an
518/// Aver AST node. Each backend renders this into its own idiom (Lean
519/// via `bound_expr_to_lean`, Dafny via its own `emit_expr` path).
520pub(crate) fn extract_equality_bound_expr(fd: &FnDef, param_name: &str) -> Option<Spanned<Expr>> {
521    let tail = fd.body.tail_expr()?;
522    let Expr::Match { subject, arms, .. } = &tail.node else {
523        return None;
524    };
525    let Expr::BinOp(BinOp::Eq, left, right) = &subject.node else {
526        return None;
527    };
528    if !is_ident(left, param_name) {
529        return None;
530    }
531    // Verify: true arm = base (no self-call), false arm = recursive (has self-call)
532    let mut true_has_self = false;
533    let mut false_has_self = false;
534    for arm in arms {
535        match arm.pattern {
536            Pattern::Literal(crate::ast::Literal::Bool(true)) => {
537                let mut calls = Vec::new();
538                collect_calls_from_expr(&arm.body, &mut calls);
539                true_has_self = calls.iter().any(|(n, _)| call_matches(n, &fd.name));
540            }
541            Pattern::Literal(crate::ast::Literal::Bool(false)) => {
542                let mut calls = Vec::new();
543                collect_calls_from_expr(&arm.body, &mut calls);
544                false_has_self = calls.iter().any(|(n, _)| call_matches(n, &fd.name));
545            }
546            _ => return None,
547        }
548    }
549    if true_has_self || !false_has_self {
550        return None;
551    }
552    Some((**right).clone())
553}
554
555pub(crate) fn supports_single_sizeof_structural(fd: &FnDef, ctx: &CodegenContext) -> bool {
556    let recursive_calls: Vec<Vec<&Spanned<Expr>>> = collect_calls_from_body(fd.body.as_ref())
557        .into_iter()
558        .filter(|(name, _)| call_matches(name, &fd.name))
559        .map(|(_, args)| args)
560        .collect();
561    if recursive_calls.is_empty() {
562        return false;
563    }
564
565    let metric_indices = sizeof_measure_param_indices(fd);
566    if metric_indices.is_empty() {
567        return false;
568    }
569
570    let binder_sets: HashMap<usize, HashSet<String>> = metric_indices
571        .iter()
572        .filter_map(|idx| {
573            let (param_name, param_type) = fd.params.get(*idx)?;
574            recursive_type_names(ctx).contains(param_type).then(|| {
575                (
576                    *idx,
577                    collect_recursive_subterm_binders(fd, param_name, param_type, ctx),
578                )
579            })
580        })
581        .collect();
582
583    if binder_sets.values().all(HashSet::is_empty) {
584        return false;
585    }
586
587    recursive_calls.iter().all(|args| {
588        let mut strictly_smaller = false;
589        for idx in &metric_indices {
590            let Some((param_name, _)) = fd.params.get(*idx) else {
591                return false;
592            };
593            let Some(arg) = args.get(*idx).cloned() else {
594                return false;
595            };
596            if is_ident(arg, param_name) {
597                continue;
598            }
599            let Some(binders) = binder_sets.get(idx) else {
600                return false;
601            };
602            if local_name_of(arg).is_some_and(|id| binders.contains(id)) {
603                strictly_smaller = true;
604                continue;
605            }
606            return false;
607        }
608        strictly_smaller
609    })
610}
611
612pub(crate) fn single_list_structural_param_index(fd: &FnDef) -> Option<usize> {
613    fd.params
614        .iter()
615        .enumerate()
616        .find_map(|(param_index, (param_name, param_ty))| {
617            if !(param_ty.starts_with("List<") || param_ty == "List") {
618                return None;
619            }
620
621            let tails = collect_list_tail_binders(fd, param_name);
622            if tails.is_empty() {
623                return None;
624            }
625
626            let recursive_calls: Vec<Option<&Spanned<Expr>>> =
627                collect_calls_from_body(fd.body.as_ref())
628                    .into_iter()
629                    .filter(|(name, _)| call_matches(name, &fd.name))
630                    .map(|(_, args)| args.get(param_index).cloned())
631                    .collect();
632            if recursive_calls.is_empty() {
633                return None;
634            }
635
636            recursive_calls
637                .into_iter()
638                .all(|arg| {
639                    arg.and_then(local_name_of)
640                        .is_some_and(|id| tails.contains(id))
641                })
642                .then_some(param_index)
643        })
644}
645
646pub(crate) fn is_ident(expr: &Spanned<Expr>, name: &str) -> bool {
647    local_name_of(expr).is_some_and(|id| id == name)
648}
649
650pub(crate) fn is_int_plus_positive(expr: &Spanned<Expr>, param_name: &str) -> bool {
651    match &expr.node {
652        Expr::BinOp(BinOp::Add, left, right) => {
653            local_name_of(left).is_some_and(|id| id == param_name)
654                && matches!(&right.node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
655        }
656        Expr::FnCall(callee, args) => {
657            let Some(name) = expr_to_dotted_name(callee) else {
658                return false;
659            };
660            (name == "Int.add" || name == "int.add")
661                && args.len() == 2
662                && local_name_of(&args[0]).is_some_and(|id| id == param_name)
663                && matches!(&args[1].node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
664        }
665        _ => false,
666    }
667}
668
669pub(crate) fn is_skip_ws_advance(
670    expr: &Spanned<Expr>,
671    string_param: &str,
672    pos_param: &str,
673) -> bool {
674    let Expr::FnCall(callee, args) = &expr.node else {
675        return false;
676    };
677    let Some(name) = expr_to_dotted_name(callee) else {
678        return false;
679    };
680    if !call_matches(&name, "skipWs") || args.len() != 2 {
681        return false;
682    }
683    is_ident(&args[0], string_param) && is_int_plus_positive(&args[1], pos_param)
684}
685
686pub(crate) fn is_skip_ws_same(expr: &Spanned<Expr>, string_param: &str, pos_param: &str) -> bool {
687    let Expr::FnCall(callee, args) = &expr.node else {
688        return false;
689    };
690    let Some(name) = expr_to_dotted_name(callee) else {
691        return false;
692    };
693    if !call_matches(&name, "skipWs") || args.len() != 2 {
694        return false;
695    }
696    is_ident(&args[0], string_param) && is_ident(&args[1], pos_param)
697}
698
699pub(crate) fn is_string_pos_advance(
700    expr: &Spanned<Expr>,
701    string_param: &str,
702    pos_param: &str,
703) -> bool {
704    is_int_plus_positive(expr, pos_param) || is_skip_ws_advance(expr, string_param, pos_param)
705}
706
/// How the position argument of an intra-component call relates to the
/// caller's own position parameter (see `classify_string_pos_edge`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum StringPosEdge {
    /// The position is passed through unchanged.
    Same,
    /// The position is treated as having moved forward.
    Advance,
}
712
713pub(crate) fn classify_string_pos_edge(
714    expr: &Spanned<Expr>,
715    string_param: &str,
716    pos_param: &str,
717) -> Option<StringPosEdge> {
718    if is_ident(expr, pos_param) || is_skip_ws_same(expr, string_param, pos_param) {
719        return Some(StringPosEdge::Same);
720    }
721    if is_string_pos_advance(expr, string_param, pos_param) {
722        return Some(StringPosEdge::Advance);
723    }
724    if let Expr::FnCall(callee, args) = &expr.node {
725        let name = expr_to_dotted_name(callee)?;
726        if call_matches(&name, "skipWs")
727            && args.len() == 2
728            && is_ident(&args[0], string_param)
729            && local_name_of(&args[1]).is_some_and(|id| id != pos_param)
730        {
731            return Some(StringPosEdge::Advance);
732        }
733    }
734    if local_name_of(expr).is_some_and(|id| id != pos_param) {
735        return Some(StringPosEdge::Advance);
736    }
737    None
738}
739
/// Turn the "same-position" call graph of a mutual-recursion component
/// into per-function ranks.
///
/// `same_edges[f]` holds the callees that `f` calls without advancing the
/// position. Returns `None` in these cases:
/// * an edge targets a name outside `names`;
/// * the edges contain a cycle, including a self-loop;
/// * some node can never be ordered.
///
/// Otherwise Kahn's topological sort is performed. Ready nodes are kept on
/// a stack, and each newly-ready batch is sorted, so the result does not
/// depend on hash iteration order. The `i`-th node in that order (0-based)
/// gets rank `n - i`, so ranks run from `n` down to `1`, and every
/// same-position caller ranks strictly above its callee.
pub(crate) fn ranks_from_same_edges(
    names: &HashSet<String>,
    same_edges: &HashMap<String, HashSet<String>>,
) -> Option<HashMap<String, usize>> {
    let mut indegree: HashMap<String, usize> = names.iter().map(|n| (n.clone(), 0)).collect();
    for to in same_edges.values().flatten() {
        // An edge into a function outside the component cannot be ranked.
        *indegree.get_mut(to)? += 1;
    }

    // Used as a stack (pop from the back). Sorting each batch keeps the
    // visiting order deterministic.
    let mut ready: Vec<String> = indegree
        .iter()
        .filter(|&(_, &deg)| deg == 0)
        .map(|(name, _)| name.clone())
        .collect();
    ready.sort();

    let mut topo = Vec::with_capacity(names.len());
    while let Some(node) = ready.pop() {
        let mut newly_zero = Vec::new();
        // Walk the out-set by reference; cloning it per node is unnecessary.
        for to in same_edges.get(&node).into_iter().flatten() {
            let entry = indegree.get_mut(to)?;
            *entry -= 1;
            if *entry == 0 {
                newly_zero.push(to.clone());
            }
        }
        newly_zero.sort();
        ready.extend(newly_zero);
        topo.push(node);
    }

    // Nodes that never reached in-degree zero sit on, or behind, a cycle.
    if topo.len() != names.len() {
        return None;
    }

    let n = topo.len();
    Some(
        topo.into_iter()
            .enumerate()
            .map(|(idx, name)| (name, n - idx))
            .collect(),
    )
}
790
791pub(crate) fn supports_single_string_pos_advance(fd: &FnDef) -> bool {
792    let Some((string_param, string_ty)) = fd.params.first() else {
793        return false;
794    };
795    let Some((pos_param, pos_ty)) = fd.params.get(1) else {
796        return false;
797    };
798    if string_ty != "String" || pos_ty != "Int" {
799        return false;
800    }
801
802    type CallPair<'a> = (Option<&'a Spanned<Expr>>, Option<&'a Spanned<Expr>>);
803    let recursive_calls: Vec<CallPair<'_>> = collect_calls_from_body(fd.body.as_ref())
804        .into_iter()
805        .filter(|(name, _)| call_matches(name, &fd.name))
806        .map(|(_, args)| (args.first().cloned(), args.get(1).cloned()))
807        .collect();
808    if recursive_calls.is_empty() {
809        return false;
810    }
811
812    recursive_calls.into_iter().all(|(arg0, arg1)| {
813        arg0.is_some_and(|e| is_ident(e, string_param))
814            && arg1.is_some_and(|e| is_string_pos_advance(e, string_param, pos_param))
815    })
816}
817
818pub(crate) fn supports_mutual_int_countdown(component: &[&FnDef]) -> bool {
819    if component.len() < 2 {
820        return false;
821    }
822    if component
823        .iter()
824        .any(|fd| !matches!(fd.params.first(), Some((_, t)) if t == "Int"))
825    {
826        return false;
827    }
828    let names: HashSet<String> = component.iter().map(|fd| fd.name.clone()).collect();
829    let mut any_intra = false;
830    for fd in component {
831        let param_name = &fd.params[0].0;
832        for (callee, args) in collect_calls_from_body(fd.body.as_ref()) {
833            if !call_is_in_set(&callee, &names) {
834                continue;
835            }
836            any_intra = true;
837            let Some(arg0) = args.first().cloned() else {
838                return false;
839            };
840            if !is_int_minus_positive(arg0, param_name) {
841                return false;
842            }
843        }
844    }
845    any_intra
846}
847
848pub(crate) fn supports_mutual_string_pos_advance(
849    component: &[&FnDef],
850) -> Option<HashMap<String, usize>> {
851    if component.len() < 2 {
852        return None;
853    }
854    if component.iter().any(|fd| {
855        !matches!(fd.params.first(), Some((_, t)) if t == "String")
856            || !matches!(fd.params.get(1), Some((_, t)) if t == "Int")
857    }) {
858        return None;
859    }
860
861    let names: HashSet<String> = component.iter().map(|fd| fd.name.clone()).collect();
862    let mut same_edges: HashMap<String, HashSet<String>> =
863        names.iter().map(|n| (n.clone(), HashSet::new())).collect();
864    let mut any_intra = false;
865
866    for fd in component {
867        let string_param = &fd.params[0].0;
868        let pos_param = &fd.params[1].0;
869        for (callee_raw, args) in collect_calls_from_body(fd.body.as_ref()) {
870            let Some(callee) = canonical_callee_name(&callee_raw, &names) else {
871                continue;
872            };
873            any_intra = true;
874
875            let arg0 = args.first().cloned()?;
876            let arg1 = args.get(1).cloned()?;
877
878            if !is_ident(arg0, string_param) {
879                return None;
880            }
881
882            match classify_string_pos_edge(arg1, string_param, pos_param) {
883                Some(StringPosEdge::Same) => {
884                    if let Some(edges) = same_edges.get_mut(&fd.name) {
885                        edges.insert(callee);
886                    } else {
887                        return None;
888                    }
889                }
890                Some(StringPosEdge::Advance) => {}
891                None => return None,
892            }
893        }
894    }
895
896    if !any_intra {
897        return None;
898    }
899
900    ranks_from_same_edges(&names, &same_edges)
901}
902
/// Whether `type_name` is one of the built-in scalar-like Aver type names
/// (`Int`, `Float`, `Bool`, `String`, `Char`, `Byte`, `Unit`).
///
/// The comparison is exact and case-sensitive; compound or generic names
/// such as `List<Int>` are not scalar-like.
pub(crate) fn is_scalar_like_type(type_name: &str) -> bool {
    const SCALAR_LIKE: [&str; 7] = ["Int", "Float", "Bool", "String", "Char", "Byte", "Unit"];
    SCALAR_LIKE.contains(&type_name)
}
909
910pub(crate) fn supports_mutual_sizeof_ranked(
911    component: &[&FnDef],
912) -> Option<HashMap<String, usize>> {
913    if component.len() < 2 {
914        return None;
915    }
916    let names: HashSet<String> = component.iter().map(|fd| fd.name.clone()).collect();
917    let metric_indices: HashMap<String, Vec<usize>> = component
918        .iter()
919        .map(|fd| (fd.name.clone(), sizeof_measure_param_indices(fd)))
920        .collect();
921    if component.iter().any(|fd| {
922        metric_indices
923            .get(&fd.name)
924            .is_none_or(|indices| indices.is_empty())
925    }) {
926        return None;
927    }
928
929    let mut same_edges: HashMap<String, HashSet<String>> =
930        names.iter().map(|n| (n.clone(), HashSet::new())).collect();
931    let mut any_intra = false;
932    for fd in component {
933        let caller_metric_indices = metric_indices.get(&fd.name)?;
934        let caller_metric_params: Vec<&str> = caller_metric_indices
935            .iter()
936            .filter_map(|idx| fd.params.get(*idx).map(|(name, _)| name.as_str()))
937            .collect();
938        for (callee_raw, args) in collect_calls_from_body(fd.body.as_ref()) {
939            let Some(callee) = canonical_callee_name(&callee_raw, &names) else {
940                continue;
941            };
942            any_intra = true;
943            let callee_metric_indices = metric_indices.get(&callee)?;
944            let is_same_edge = callee_metric_indices.len() == caller_metric_params.len()
945                && callee_metric_indices
946                    .iter()
947                    .enumerate()
948                    .all(|(pos, callee_idx)| {
949                        let Some(arg) = args.get(*callee_idx).cloned() else {
950                            return false;
951                        };
952                        is_ident(arg, caller_metric_params[pos])
953                    });
954            if is_same_edge {
955                if let Some(edges) = same_edges.get_mut(&fd.name) {
956                    edges.insert(callee);
957                } else {
958                    return None;
959                }
960            }
961        }
962    }
963    if !any_intra {
964        return None;
965    }
966
967    let ranks = ranks_from_same_edges(&names, &same_edges)?;
968    let mut out = HashMap::new();
969    for fd in component {
970        let rank = ranks.get(&fd.name).cloned()?;
971        out.insert(fd.name.clone(), rank);
972    }
973    Some(out)
974}
975
976/// Classify every recursive pure fn in `ctx`. The returned map assigns
977/// each supported function a [`RecursionPlan`]; anything that falls
978/// outside the recognised shapes becomes a [`ProofModeIssue`].
979pub fn analyze_plans(
980    ctx: &CodegenContext,
981) -> (HashMap<String, RecursionPlan>, Vec<ProofModeIssue>) {
982    let mut plans = HashMap::new();
983    let mut issues = Vec::new();
984
985    let all_pure = pure_fns(ctx);
986    let recursive_names = recursive_pure_fn_names(ctx);
987    let components = call_graph::ordered_fn_components(&all_pure, &ctx.module_prefixes);
988
989    for component in components {
990        if component.is_empty() {
991            continue;
992        }
993        let is_recursive_component =
994            component.len() > 1 || recursive_names.contains(&component[0].name);
995        if !is_recursive_component {
996            continue;
997        }
998
999        if component.len() > 1 {
1000            if supports_mutual_int_countdown(&component) {
1001                for fd in &component {
1002                    plans.insert(fd.name.clone(), RecursionPlan::MutualIntCountdown);
1003                }
1004            } else if let Some(ranks) = supports_mutual_string_pos_advance(&component) {
1005                for fd in &component {
1006                    if let Some(rank) = ranks.get(&fd.name).cloned() {
1007                        plans.insert(
1008                            fd.name.clone(),
1009                            RecursionPlan::MutualStringPosAdvance { rank },
1010                        );
1011                    }
1012                }
1013            } else if let Some(rankings) = supports_mutual_sizeof_ranked(&component) {
1014                for fd in &component {
1015                    if let Some(rank) = rankings.get(&fd.name).cloned() {
1016                        plans.insert(fd.name.clone(), RecursionPlan::MutualSizeOfRanked { rank });
1017                    }
1018                }
1019            } else {
1020                let names = component
1021                    .iter()
1022                    .map(|fd| fd.name.clone())
1023                    .collect::<Vec<_>>()
1024                    .join(", ");
1025                let line = component.iter().map(|fd| fd.line).min().unwrap_or(1);
1026                issues.push(ProofModeIssue {
1027                    line,
1028                    message: format!(
1029                        "unsupported mutual recursion group (currently supported in proof mode: Int countdown on first param): {}",
1030                        names
1031                    ),
1032                });
1033            }
1034            continue;
1035        }
1036
1037        let fd = component[0];
1038        if crate::codegen::lean::recurrence::detect_second_order_int_linear_recurrence(fd).is_some()
1039        {
1040            plans.insert(fd.name.clone(), RecursionPlan::LinearRecurrence2);
1041        } else if let Some((param_index, bound)) = single_int_ascending_param(fd) {
1042            plans.insert(
1043                fd.name.clone(),
1044                RecursionPlan::IntAscending { param_index, bound },
1045            );
1046        } else if let Some(param_index) = single_int_countdown_param_index(fd) {
1047            plans.insert(fd.name.clone(), RecursionPlan::IntCountdown { param_index });
1048        } else if supports_single_sizeof_structural(fd, ctx) {
1049            plans.insert(fd.name.clone(), RecursionPlan::SizeOfStructural);
1050        } else if let Some(param_index) = single_list_structural_param_index(fd) {
1051            plans.insert(
1052                fd.name.clone(),
1053                RecursionPlan::ListStructural { param_index },
1054            );
1055        } else if supports_single_string_pos_advance(fd) {
1056            plans.insert(fd.name.clone(), RecursionPlan::StringPosAdvance);
1057        } else {
1058            issues.push(ProofModeIssue {
1059                line: fd.line,
1060                message: format!(
1061                    "recursive function '{}' is outside proof subset (currently supported: Int countdown, second-order affine Int recurrences with pair-state worker, structural recursion on List/recursive ADTs, String+position, mutual Int countdown, mutual String+position, and ranked sizeOf recursion)",
1062                    fd.name
1063                ),
1064            });
1065        }
1066    }
1067
1068    (plans, issues)
1069}