Skip to main content

aver/codegen/recursion/
detect.rs

1//! Proof-mode recursion classifier.
2//!
3//! Backend-neutral pattern matching over the Aver AST that decides,
4//! for each recursive pure fn (or mutual-recursion SCC), which
5//! [`RecursionPlan`] variant applies — or emits a [`ProofModeIssue`]
6//! when the shape falls outside the supported set.
7//!
8//! Lean and Dafny consume the same plans through
9//! [`crate::codegen::recursion::analyze_plans`]; a couple of helpers
10//! that depend on AST queries tied to Lean's `toplevel` (pure-fn
11//! predicate, recursive-type-def predicate, type-def name) still live
12//! in `crate::codegen::lean` and are re-used here through
13//! `pub(crate)` exports. That could move to a neutral AST helper
14//! module in a later pass.
15use std::collections::{HashMap, HashSet};
16
17use crate::ast::{
18    BinOp, Expr, FnBody, FnDef, MatchArm, Pattern, Spanned, Stmt, TailCallData, TypeDef,
19};
20use crate::call_graph;
21use crate::codegen::CodegenContext;
22use crate::codegen::lean::{
23    find_type_def, pure_fns, recursive_pure_fn_names, recursive_type_names,
24    sizeof_measure_param_indices,
25};
26
27use super::{ProofModeIssue, RecursionPlan};
28
29pub(crate) fn expr_to_dotted_name(expr: &Spanned<Expr>) -> Option<String> {
30    match &expr.node {
31        Expr::Ident(name) => Some(name.clone()),
32        Expr::Attr(obj, field) => expr_to_dotted_name(obj).map(|p| format!("{}.{}", p, field)),
33        _ => None,
34    }
35}
36
/// True when the callee `name` refers to `target`, either verbatim or by
/// its final dotted segment (so `Mod.fib` matches `fib`).
pub(crate) fn call_matches(name: &str, target: &str) -> bool {
    if name == target {
        return true;
    }
    name.rsplit('.').next().is_some_and(|last| last == target)
}
40
/// Membership test of a callee name against a set of function names.
/// Accepts either the full dotted name or its last segment.
pub(crate) fn call_is_in_set(name: &str, targets: &HashSet<String>) -> bool {
    let last_segment = name.rsplit('.').next().unwrap_or(name);
    targets.contains(name) || targets.contains(last_segment)
}
44
/// Resolve a possibly-qualified callee name to the member of `targets` it
/// refers to.
///
/// Tries the full name first, then its last dotted segment. Returns `None`
/// when neither is in the set.
pub(crate) fn canonical_callee_name(name: &str, targets: &HashSet<String>) -> Option<String> {
    if targets.contains(name) {
        return Some(name.to_owned());
    }
    let last = name.rsplit('.').next()?;
    targets.contains(last).then(|| last.to_owned())
}
54
/// True when `name` itself, or its last dotted segment, is one of `targets`.
pub(crate) fn call_matches_any(name: &str, targets: &HashSet<String>) -> bool {
    targets.contains(name)
        || name
            .rsplit('.')
            .next()
            .is_some_and(|last| targets.contains(last))
}
64
65pub(crate) fn is_int_minus_positive(expr: &Spanned<Expr>, param_name: &str) -> bool {
66    match &expr.node {
67        Expr::BinOp(BinOp::Sub, left, right) => {
68            matches!(&left.node, Expr::Ident(id) if id == param_name)
69                && matches!(&right.node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
70        }
71        Expr::FnCall(callee, args) => {
72            let Some(name) = expr_to_dotted_name(callee) else {
73                return false;
74            };
75            (name == "Int.sub" || name == "int.sub")
76                && args.len() == 2
77                && matches!(&args[0].node, Expr::Ident(id) if id == param_name)
78                && matches!(&args[1].node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
79        }
80        _ => false,
81    }
82}
83
84pub(crate) fn collect_calls_from_expr<'a>(
85    expr: &'a Spanned<Expr>,
86    out: &mut Vec<(String, Vec<&'a Spanned<Expr>>)>,
87) {
88    match &expr.node {
89        Expr::FnCall(callee, args) => {
90            if let Some(name) = expr_to_dotted_name(callee) {
91                out.push((name, args.iter().collect()));
92            }
93            collect_calls_from_expr(callee, out);
94            for arg in args {
95                collect_calls_from_expr(arg, out);
96            }
97        }
98        Expr::TailCall(boxed) => {
99            let TailCallData {
100                target: name, args, ..
101            } = boxed.as_ref();
102            out.push((name.clone(), args.iter().collect()));
103            for arg in args {
104                collect_calls_from_expr(arg, out);
105            }
106        }
107        Expr::Attr(obj, _) => collect_calls_from_expr(obj, out),
108        Expr::BinOp(_, left, right) => {
109            collect_calls_from_expr(left, out);
110            collect_calls_from_expr(right, out);
111        }
112        Expr::Match { subject, arms, .. } => {
113            collect_calls_from_expr(subject, out);
114            for arm in arms {
115                collect_calls_from_expr(&arm.body, out);
116            }
117        }
118        Expr::Constructor(_, inner) => {
119            if let Some(inner) = inner {
120                collect_calls_from_expr(inner, out);
121            }
122        }
123        Expr::ErrorProp(inner) => collect_calls_from_expr(inner, out),
124        Expr::InterpolatedStr(parts) => {
125            for p in parts {
126                if let crate::ast::StrPart::Parsed(e) = p {
127                    collect_calls_from_expr(e, out);
128                }
129            }
130        }
131        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
132            for item in items {
133                collect_calls_from_expr(item, out);
134            }
135        }
136        Expr::MapLiteral(entries) => {
137            for (k, v) in entries {
138                collect_calls_from_expr(k, out);
139                collect_calls_from_expr(v, out);
140            }
141        }
142        Expr::RecordCreate { fields, .. } => {
143            for (_, v) in fields {
144                collect_calls_from_expr(v, out);
145            }
146        }
147        Expr::RecordUpdate { base, updates, .. } => {
148            collect_calls_from_expr(base, out);
149            for (_, v) in updates {
150                collect_calls_from_expr(v, out);
151            }
152        }
153        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => {}
154    }
155}
156
157pub(crate) fn collect_calls_from_body(body: &FnBody) -> Vec<(String, Vec<&Spanned<Expr>>)> {
158    let mut out = Vec::new();
159    for stmt in body.stmts() {
160        match stmt {
161            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => collect_calls_from_expr(expr, &mut out),
162        }
163    }
164    out
165}
166
167pub(crate) fn collect_list_tail_binders_from_expr(
168    expr: &Spanned<Expr>,
169    list_param_name: &str,
170    tails: &mut HashSet<String>,
171) {
172    match &expr.node {
173        Expr::Match { subject, arms, .. } => {
174            if matches!(&subject.node, Expr::Ident(id) if id == list_param_name) {
175                for MatchArm { pattern, .. } in arms {
176                    if let Pattern::Cons(_, tail) = pattern {
177                        tails.insert(tail.clone());
178                    }
179                }
180            }
181            for arm in arms {
182                collect_list_tail_binders_from_expr(&arm.body, list_param_name, tails);
183            }
184            collect_list_tail_binders_from_expr(subject, list_param_name, tails);
185        }
186        Expr::FnCall(callee, args) => {
187            collect_list_tail_binders_from_expr(callee, list_param_name, tails);
188            for arg in args {
189                collect_list_tail_binders_from_expr(arg, list_param_name, tails);
190            }
191        }
192        Expr::TailCall(boxed) => {
193            let TailCallData {
194                target: _, args, ..
195            } = boxed.as_ref();
196            for arg in args {
197                collect_list_tail_binders_from_expr(arg, list_param_name, tails);
198            }
199        }
200        Expr::Attr(obj, _) => collect_list_tail_binders_from_expr(obj, list_param_name, tails),
201        Expr::BinOp(_, left, right) => {
202            collect_list_tail_binders_from_expr(left, list_param_name, tails);
203            collect_list_tail_binders_from_expr(right, list_param_name, tails);
204        }
205        Expr::Constructor(_, inner) => {
206            if let Some(inner) = inner {
207                collect_list_tail_binders_from_expr(inner, list_param_name, tails);
208            }
209        }
210        Expr::ErrorProp(inner) => {
211            collect_list_tail_binders_from_expr(inner, list_param_name, tails)
212        }
213        Expr::InterpolatedStr(parts) => {
214            for p in parts {
215                if let crate::ast::StrPart::Parsed(e) = p {
216                    collect_list_tail_binders_from_expr(e, list_param_name, tails);
217                }
218            }
219        }
220        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
221            for item in items {
222                collect_list_tail_binders_from_expr(item, list_param_name, tails);
223            }
224        }
225        Expr::MapLiteral(entries) => {
226            for (k, v) in entries {
227                collect_list_tail_binders_from_expr(k, list_param_name, tails);
228                collect_list_tail_binders_from_expr(v, list_param_name, tails);
229            }
230        }
231        Expr::RecordCreate { fields, .. } => {
232            for (_, v) in fields {
233                collect_list_tail_binders_from_expr(v, list_param_name, tails);
234            }
235        }
236        Expr::RecordUpdate { base, updates, .. } => {
237            collect_list_tail_binders_from_expr(base, list_param_name, tails);
238            for (_, v) in updates {
239                collect_list_tail_binders_from_expr(v, list_param_name, tails);
240            }
241        }
242        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => {}
243    }
244}
245
246pub(crate) fn collect_list_tail_binders(fd: &FnDef, list_param_name: &str) -> HashSet<String> {
247    let mut tails = HashSet::new();
248    for stmt in fd.body.stmts() {
249        match stmt {
250            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
251                collect_list_tail_binders_from_expr(expr, list_param_name, &mut tails)
252            }
253        }
254    }
255    tails
256}
257
258pub(crate) fn recursive_constructor_binders(
259    td: &TypeDef,
260    variant_name: &str,
261    binders: &[String],
262) -> Vec<String> {
263    let variant_short = variant_name.rsplit('.').next().unwrap_or(variant_name);
264    match td {
265        TypeDef::Sum { name, variants, .. } => variants
266            .iter()
267            .find(|variant| variant.name == variant_short)
268            .map(|variant| {
269                variant
270                    .fields
271                    .iter()
272                    .zip(binders.iter())
273                    .filter_map(|(field_ty, binder)| {
274                        (field_ty.trim() == name).then_some(binder.clone())
275                    })
276                    .collect()
277            })
278            .unwrap_or_default(),
279        TypeDef::Product { .. } => Vec::new(),
280    }
281}
282
283pub(crate) fn grow_recursive_subterm_binders_from_expr(
284    expr: &Spanned<Expr>,
285    tracked: &HashSet<String>,
286    td: &TypeDef,
287    out: &mut HashSet<String>,
288) {
289    match &expr.node {
290        Expr::Match { subject, arms, .. } => {
291            if let Expr::Ident(subject_name) = &subject.node
292                && tracked.contains(subject_name)
293            {
294                for arm in arms {
295                    if let Pattern::Constructor(variant_name, binders) = &arm.pattern {
296                        out.extend(recursive_constructor_binders(td, variant_name, binders));
297                    }
298                }
299            }
300            grow_recursive_subterm_binders_from_expr(subject, tracked, td, out);
301            for arm in arms {
302                grow_recursive_subterm_binders_from_expr(&arm.body, tracked, td, out);
303            }
304        }
305        Expr::FnCall(callee, args) => {
306            grow_recursive_subterm_binders_from_expr(callee, tracked, td, out);
307            for arg in args {
308                grow_recursive_subterm_binders_from_expr(arg, tracked, td, out);
309            }
310        }
311        Expr::Attr(obj, _) => grow_recursive_subterm_binders_from_expr(obj, tracked, td, out),
312        Expr::BinOp(_, left, right) => {
313            grow_recursive_subterm_binders_from_expr(left, tracked, td, out);
314            grow_recursive_subterm_binders_from_expr(right, tracked, td, out);
315        }
316        Expr::Constructor(_, Some(inner)) | Expr::ErrorProp(inner) => {
317            grow_recursive_subterm_binders_from_expr(inner, tracked, td, out)
318        }
319        Expr::InterpolatedStr(parts) => {
320            for part in parts {
321                if let crate::ast::StrPart::Parsed(inner) = part {
322                    grow_recursive_subterm_binders_from_expr(inner, tracked, td, out);
323                }
324            }
325        }
326        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
327            for item in items {
328                grow_recursive_subterm_binders_from_expr(item, tracked, td, out);
329            }
330        }
331        Expr::MapLiteral(entries) => {
332            for (k, v) in entries {
333                grow_recursive_subterm_binders_from_expr(k, tracked, td, out);
334                grow_recursive_subterm_binders_from_expr(v, tracked, td, out);
335            }
336        }
337        Expr::RecordCreate { fields, .. } => {
338            for (_, v) in fields {
339                grow_recursive_subterm_binders_from_expr(v, tracked, td, out);
340            }
341        }
342        Expr::RecordUpdate { base, updates, .. } => {
343            grow_recursive_subterm_binders_from_expr(base, tracked, td, out);
344            for (_, v) in updates {
345                grow_recursive_subterm_binders_from_expr(v, tracked, td, out);
346            }
347        }
348        Expr::TailCall(boxed) => {
349            for arg in &boxed.args {
350                grow_recursive_subterm_binders_from_expr(arg, tracked, td, out);
351            }
352        }
353        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
354    }
355}
356
357pub(crate) fn collect_recursive_subterm_binders(
358    fd: &FnDef,
359    param_name: &str,
360    param_type: &str,
361    ctx: &CodegenContext,
362) -> HashSet<String> {
363    let Some(td) = find_type_def(ctx, param_type) else {
364        return HashSet::new();
365    };
366    let mut tracked: HashSet<String> = HashSet::from([param_name.to_string()]);
367    loop {
368        let mut discovered = HashSet::new();
369        for stmt in fd.body.stmts() {
370            match stmt {
371                Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
372                    grow_recursive_subterm_binders_from_expr(expr, &tracked, td, &mut discovered);
373                }
374            }
375        }
376        let before = tracked.len();
377        tracked.extend(discovered);
378        if tracked.len() == before {
379            break;
380        }
381    }
382    tracked.remove(param_name);
383    tracked
384}
385
386pub(crate) fn single_int_countdown_param_index(fd: &FnDef) -> Option<usize> {
387    let recursive_calls: Vec<Vec<&Spanned<Expr>>> = collect_calls_from_body(fd.body.as_ref())
388        .into_iter()
389        .filter(|(name, _)| call_matches(name, &fd.name))
390        .map(|(_, args)| args)
391        .collect();
392    if recursive_calls.is_empty() {
393        return None;
394    }
395
396    fd.params
397        .iter()
398        .enumerate()
399        .find_map(|(idx, (param_name, param_ty))| {
400            if param_ty != "Int" {
401                return None;
402            }
403            let countdown_ok = recursive_calls.iter().all(|args| {
404                args.get(idx)
405                    .cloned()
406                    .is_some_and(|arg| is_int_minus_positive(arg, param_name))
407            });
408            if countdown_ok {
409                return Some(idx);
410            }
411
412            // Negative-guarded ascent (match n < 0) is handled as countdown
413            // because the fuel is natAbs(n) which works for both directions.
414            let ascent_ok = recursive_calls.iter().all(|args| {
415                args.get(idx)
416                    .copied()
417                    .is_some_and(|arg| is_int_plus_positive(arg, param_name))
418            });
419            (ascent_ok && has_negative_guarded_ascent(fd, param_name)).then_some(idx)
420        })
421}
422
423pub(crate) fn has_negative_guarded_ascent(fd: &FnDef, param_name: &str) -> bool {
424    let Some(tail) = fd.body.tail_expr() else {
425        return false;
426    };
427    let Expr::Match { subject, arms, .. } = &tail.node else {
428        return false;
429    };
430    let Expr::BinOp(BinOp::Lt, left, right) = &subject.node else {
431        return false;
432    };
433    if !is_ident(left, param_name)
434        || !matches!(&right.node, Expr::Literal(crate::ast::Literal::Int(0)))
435    {
436        return false;
437    }
438
439    let mut true_arm = None;
440    let mut false_arm = None;
441    for arm in arms {
442        match arm.pattern {
443            Pattern::Literal(crate::ast::Literal::Bool(true)) => true_arm = Some(arm.body.as_ref()),
444            Pattern::Literal(crate::ast::Literal::Bool(false)) => {
445                false_arm = Some(arm.body.as_ref())
446            }
447            _ => return false,
448        }
449    }
450
451    let Some(true_arm) = true_arm else {
452        return false;
453    };
454    let Some(false_arm) = false_arm else {
455        return false;
456    };
457
458    let mut true_calls = Vec::new();
459    collect_calls_from_expr(true_arm, &mut true_calls);
460    let mut false_calls = Vec::new();
461    collect_calls_from_expr(false_arm, &mut false_calls);
462
463    true_calls
464        .iter()
465        .any(|(name, _)| call_matches(name, &fd.name))
466        && false_calls
467            .iter()
468            .all(|(name, _)| !call_matches(name, &fd.name))
469}
470
471/// Detect ascending-index recursion and extract the bound expression
472/// as an Aver AST (`Spanned<Expr>`). Returns `(param_index, bound)`.
473pub(crate) fn single_int_ascending_param(fd: &FnDef) -> Option<(usize, Spanned<Expr>)> {
474    let recursive_calls: Vec<Vec<&Spanned<Expr>>> = collect_calls_from_body(fd.body.as_ref())
475        .into_iter()
476        .filter(|(name, _)| call_matches(name, &fd.name))
477        .map(|(_, args)| args)
478        .collect();
479    if recursive_calls.is_empty() {
480        return None;
481    }
482
483    for (idx, (param_name, param_ty)) in fd.params.iter().enumerate() {
484        if param_ty != "Int" {
485            continue;
486        }
487        let ascent_ok = recursive_calls.iter().all(|args| {
488            args.get(idx)
489                .cloned()
490                .is_some_and(|arg| is_int_plus_positive(arg, param_name))
491        });
492        if !ascent_ok {
493            continue;
494        }
495        if let Some(bound) = extract_equality_bound_expr(fd, param_name) {
496            return Some((idx, bound));
497        }
498    }
499    None
500}
501
502/// Extract the bound expression from `match param == BOUND` as an
503/// Aver AST node. Each backend renders this into its own idiom (Lean
504/// via `bound_expr_to_lean`, Dafny via its own `emit_expr` path).
505pub(crate) fn extract_equality_bound_expr(fd: &FnDef, param_name: &str) -> Option<Spanned<Expr>> {
506    let tail = fd.body.tail_expr()?;
507    let Expr::Match { subject, arms, .. } = &tail.node else {
508        return None;
509    };
510    let Expr::BinOp(BinOp::Eq, left, right) = &subject.node else {
511        return None;
512    };
513    if !is_ident(left, param_name) {
514        return None;
515    }
516    // Verify: true arm = base (no self-call), false arm = recursive (has self-call)
517    let mut true_has_self = false;
518    let mut false_has_self = false;
519    for arm in arms {
520        match arm.pattern {
521            Pattern::Literal(crate::ast::Literal::Bool(true)) => {
522                let mut calls = Vec::new();
523                collect_calls_from_expr(&arm.body, &mut calls);
524                true_has_self = calls.iter().any(|(n, _)| call_matches(n, &fd.name));
525            }
526            Pattern::Literal(crate::ast::Literal::Bool(false)) => {
527                let mut calls = Vec::new();
528                collect_calls_from_expr(&arm.body, &mut calls);
529                false_has_self = calls.iter().any(|(n, _)| call_matches(n, &fd.name));
530            }
531            _ => return None,
532        }
533    }
534    if true_has_self || !false_has_self {
535        return None;
536    }
537    Some((**right).clone())
538}
539
540pub(crate) fn supports_single_sizeof_structural(fd: &FnDef, ctx: &CodegenContext) -> bool {
541    let recursive_calls: Vec<Vec<&Spanned<Expr>>> = collect_calls_from_body(fd.body.as_ref())
542        .into_iter()
543        .filter(|(name, _)| call_matches(name, &fd.name))
544        .map(|(_, args)| args)
545        .collect();
546    if recursive_calls.is_empty() {
547        return false;
548    }
549
550    let metric_indices = sizeof_measure_param_indices(fd);
551    if metric_indices.is_empty() {
552        return false;
553    }
554
555    let binder_sets: HashMap<usize, HashSet<String>> = metric_indices
556        .iter()
557        .filter_map(|idx| {
558            let (param_name, param_type) = fd.params.get(*idx)?;
559            recursive_type_names(ctx).contains(param_type).then(|| {
560                (
561                    *idx,
562                    collect_recursive_subterm_binders(fd, param_name, param_type, ctx),
563                )
564            })
565        })
566        .collect();
567
568    if binder_sets.values().all(HashSet::is_empty) {
569        return false;
570    }
571
572    recursive_calls.iter().all(|args| {
573        let mut strictly_smaller = false;
574        for idx in &metric_indices {
575            let Some((param_name, _)) = fd.params.get(*idx) else {
576                return false;
577            };
578            let Some(arg) = args.get(*idx).cloned() else {
579                return false;
580            };
581            if is_ident(arg, param_name) {
582                continue;
583            }
584            let Some(binders) = binder_sets.get(idx) else {
585                return false;
586            };
587            if matches!(&arg.node, Expr::Ident(id) if binders.contains(id)) {
588                strictly_smaller = true;
589                continue;
590            }
591            return false;
592        }
593        strictly_smaller
594    })
595}
596
597pub(crate) fn single_list_structural_param_index(fd: &FnDef) -> Option<usize> {
598    fd.params
599        .iter()
600        .enumerate()
601        .find_map(|(param_index, (param_name, param_ty))| {
602            if !(param_ty.starts_with("List<") || param_ty == "List") {
603                return None;
604            }
605
606            let tails = collect_list_tail_binders(fd, param_name);
607            if tails.is_empty() {
608                return None;
609            }
610
611            let recursive_calls: Vec<Option<&Spanned<Expr>>> =
612                collect_calls_from_body(fd.body.as_ref())
613                    .into_iter()
614                    .filter(|(name, _)| call_matches(name, &fd.name))
615                    .map(|(_, args)| args.get(param_index).cloned())
616                    .collect();
617            if recursive_calls.is_empty() {
618                return None;
619            }
620
621            recursive_calls
622                .into_iter()
623                .all(|arg| {
624                    arg.is_some_and(|a| matches!(&a.node, Expr::Ident(id) if tails.contains(id)))
625                })
626                .then_some(param_index)
627        })
628}
629
630pub(crate) fn is_ident(expr: &Spanned<Expr>, name: &str) -> bool {
631    matches!(&expr.node, Expr::Ident(id) if id == name)
632}
633
634pub(crate) fn is_int_plus_positive(expr: &Spanned<Expr>, param_name: &str) -> bool {
635    match &expr.node {
636        Expr::BinOp(BinOp::Add, left, right) => {
637            matches!(&left.node, Expr::Ident(id) if id == param_name)
638                && matches!(&right.node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
639        }
640        Expr::FnCall(callee, args) => {
641            let Some(name) = expr_to_dotted_name(callee) else {
642                return false;
643            };
644            (name == "Int.add" || name == "int.add")
645                && args.len() == 2
646                && matches!(&args[0].node, Expr::Ident(id) if id == param_name)
647                && matches!(&args[1].node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
648        }
649        _ => false,
650    }
651}
652
653pub(crate) fn is_skip_ws_advance(
654    expr: &Spanned<Expr>,
655    string_param: &str,
656    pos_param: &str,
657) -> bool {
658    let Expr::FnCall(callee, args) = &expr.node else {
659        return false;
660    };
661    let Some(name) = expr_to_dotted_name(callee) else {
662        return false;
663    };
664    if !call_matches(&name, "skipWs") || args.len() != 2 {
665        return false;
666    }
667    is_ident(&args[0], string_param) && is_int_plus_positive(&args[1], pos_param)
668}
669
670pub(crate) fn is_skip_ws_same(expr: &Spanned<Expr>, string_param: &str, pos_param: &str) -> bool {
671    let Expr::FnCall(callee, args) = &expr.node else {
672        return false;
673    };
674    let Some(name) = expr_to_dotted_name(callee) else {
675        return false;
676    };
677    if !call_matches(&name, "skipWs") || args.len() != 2 {
678        return false;
679    }
680    is_ident(&args[0], string_param) && is_ident(&args[1], pos_param)
681}
682
683pub(crate) fn is_string_pos_advance(
684    expr: &Spanned<Expr>,
685    string_param: &str,
686    pos_param: &str,
687) -> bool {
688    is_int_plus_positive(expr, pos_param) || is_skip_ws_advance(expr, string_param, pos_param)
689}
690
/// How the position argument of an intra-component call relates to the
/// caller's own position parameter (see `classify_string_pos_edge`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum StringPosEdge {
    /// The callee receives the caller's position unchanged.
    Same,
    /// The callee receives a position treated as advanced.
    Advance,
}
696
697pub(crate) fn classify_string_pos_edge(
698    expr: &Spanned<Expr>,
699    string_param: &str,
700    pos_param: &str,
701) -> Option<StringPosEdge> {
702    if is_ident(expr, pos_param) || is_skip_ws_same(expr, string_param, pos_param) {
703        return Some(StringPosEdge::Same);
704    }
705    if is_string_pos_advance(expr, string_param, pos_param) {
706        return Some(StringPosEdge::Advance);
707    }
708    if let Expr::FnCall(callee, args) = &expr.node {
709        let name = expr_to_dotted_name(callee)?;
710        if call_matches(&name, "skipWs")
711            && args.len() == 2
712            && is_ident(&args[0], string_param)
713            && matches!(&args[1].node, Expr::Ident(id) if id != pos_param)
714        {
715            return Some(StringPosEdge::Advance);
716        }
717    }
718    if matches!(&expr.node, Expr::Ident(id) if id != pos_param) {
719        return Some(StringPosEdge::Advance);
720    }
721    None
722}
723
/// Rank every function in `names` so that each `Same` edge goes from a
/// strictly higher rank to a strictly lower one.
///
/// `same_edges` maps a caller to the callees it reaches without advancing
/// the position. Ranks are `1..=names.len()`. The first node in the
/// topological order receives the largest rank.
///
/// Ready nodes are taken in a name-sorted stack order, so the result is
/// deterministic regardless of hash iteration order.
///
/// Returns `None` in two cases:
/// - an edge targets a name outside `names`;
/// - some node can never be scheduled, because it lies on or behind a
///   cycle of `Same` edges, or is targeted from a caller outside `names`.
pub(crate) fn ranks_from_same_edges(
    names: &HashSet<String>,
    same_edges: &HashMap<String, HashSet<String>>,
) -> Option<HashMap<String, usize>> {
    let mut indegree: HashMap<String, usize> = names.iter().map(|n| (n.clone(), 0)).collect();
    for outs in same_edges.values() {
        for to in outs {
            // A target outside `names` makes the edge map inconsistent.
            *indegree.get_mut(to)? += 1;
        }
    }

    // Kahn's algorithm, using a stack seeded in sorted order. Each batch of
    // newly ready nodes is pushed sorted as well.
    let mut ready: Vec<String> = indegree
        .iter()
        .filter_map(|(name, deg)| (*deg == 0).then(|| name.clone()))
        .collect();
    ready.sort();

    let mut topo: Vec<String> = Vec::with_capacity(names.len());
    while let Some(node) = ready.pop() {
        let mut newly_ready = Vec::new();
        // Borrow the successor set instead of cloning it for every node.
        if let Some(outs) = same_edges.get(&node) {
            for to in outs {
                let deg = indegree.get_mut(to)?;
                *deg -= 1;
                if *deg == 0 {
                    newly_ready.push(to.clone());
                }
            }
        }
        newly_ready.sort();
        ready.extend(newly_ready);
        topo.push(node);
    }

    if topo.len() != names.len() {
        return None;
    }

    let n = topo.len();
    Some(
        topo.into_iter()
            .enumerate()
            .map(|(idx, name)| (name, n - idx))
            .collect(),
    )
}
774
775pub(crate) fn supports_single_string_pos_advance(fd: &FnDef) -> bool {
776    let Some((string_param, string_ty)) = fd.params.first() else {
777        return false;
778    };
779    let Some((pos_param, pos_ty)) = fd.params.get(1) else {
780        return false;
781    };
782    if string_ty != "String" || pos_ty != "Int" {
783        return false;
784    }
785
786    type CallPair<'a> = (Option<&'a Spanned<Expr>>, Option<&'a Spanned<Expr>>);
787    let recursive_calls: Vec<CallPair<'_>> = collect_calls_from_body(fd.body.as_ref())
788        .into_iter()
789        .filter(|(name, _)| call_matches(name, &fd.name))
790        .map(|(_, args)| (args.first().cloned(), args.get(1).cloned()))
791        .collect();
792    if recursive_calls.is_empty() {
793        return false;
794    }
795
796    recursive_calls.into_iter().all(|(arg0, arg1)| {
797        arg0.is_some_and(|e| is_ident(e, string_param))
798            && arg1.is_some_and(|e| is_string_pos_advance(e, string_param, pos_param))
799    })
800}
801
802pub(crate) fn supports_mutual_int_countdown(component: &[&FnDef]) -> bool {
803    if component.len() < 2 {
804        return false;
805    }
806    if component
807        .iter()
808        .any(|fd| !matches!(fd.params.first(), Some((_, t)) if t == "Int"))
809    {
810        return false;
811    }
812    let names: HashSet<String> = component.iter().map(|fd| fd.name.clone()).collect();
813    let mut any_intra = false;
814    for fd in component {
815        let param_name = &fd.params[0].0;
816        for (callee, args) in collect_calls_from_body(fd.body.as_ref()) {
817            if !call_is_in_set(&callee, &names) {
818                continue;
819            }
820            any_intra = true;
821            let Some(arg0) = args.first().cloned() else {
822                return false;
823            };
824            if !is_int_minus_positive(arg0, param_name) {
825                return false;
826            }
827        }
828    }
829    any_intra
830}
831
832pub(crate) fn supports_mutual_string_pos_advance(
833    component: &[&FnDef],
834) -> Option<HashMap<String, usize>> {
835    if component.len() < 2 {
836        return None;
837    }
838    if component.iter().any(|fd| {
839        !matches!(fd.params.first(), Some((_, t)) if t == "String")
840            || !matches!(fd.params.get(1), Some((_, t)) if t == "Int")
841    }) {
842        return None;
843    }
844
845    let names: HashSet<String> = component.iter().map(|fd| fd.name.clone()).collect();
846    let mut same_edges: HashMap<String, HashSet<String>> =
847        names.iter().map(|n| (n.clone(), HashSet::new())).collect();
848    let mut any_intra = false;
849
850    for fd in component {
851        let string_param = &fd.params[0].0;
852        let pos_param = &fd.params[1].0;
853        for (callee_raw, args) in collect_calls_from_body(fd.body.as_ref()) {
854            let Some(callee) = canonical_callee_name(&callee_raw, &names) else {
855                continue;
856            };
857            any_intra = true;
858
859            let arg0 = args.first().cloned()?;
860            let arg1 = args.get(1).cloned()?;
861
862            if !is_ident(arg0, string_param) {
863                return None;
864            }
865
866            match classify_string_pos_edge(arg1, string_param, pos_param) {
867                Some(StringPosEdge::Same) => {
868                    if let Some(edges) = same_edges.get_mut(&fd.name) {
869                        edges.insert(callee);
870                    } else {
871                        return None;
872                    }
873                }
874                Some(StringPosEdge::Advance) => {}
875                None => return None,
876            }
877        }
878    }
879
880    if !any_intra {
881        return None;
882    }
883
884    ranks_from_same_edges(&names, &same_edges)
885}
886
887pub(crate) fn is_scalar_like_type(type_name: &str) -> bool {
888    matches!(
889        type_name,
890        "Int" | "Float" | "Bool" | "String" | "Char" | "Byte" | "Unit"
891    )
892}
893
894pub(crate) fn supports_mutual_sizeof_ranked(
895    component: &[&FnDef],
896) -> Option<HashMap<String, usize>> {
897    if component.len() < 2 {
898        return None;
899    }
900    let names: HashSet<String> = component.iter().map(|fd| fd.name.clone()).collect();
901    let metric_indices: HashMap<String, Vec<usize>> = component
902        .iter()
903        .map(|fd| (fd.name.clone(), sizeof_measure_param_indices(fd)))
904        .collect();
905    if component.iter().any(|fd| {
906        metric_indices
907            .get(&fd.name)
908            .is_none_or(|indices| indices.is_empty())
909    }) {
910        return None;
911    }
912
913    let mut same_edges: HashMap<String, HashSet<String>> =
914        names.iter().map(|n| (n.clone(), HashSet::new())).collect();
915    let mut any_intra = false;
916    for fd in component {
917        let caller_metric_indices = metric_indices.get(&fd.name)?;
918        let caller_metric_params: Vec<&str> = caller_metric_indices
919            .iter()
920            .filter_map(|idx| fd.params.get(*idx).map(|(name, _)| name.as_str()))
921            .collect();
922        for (callee_raw, args) in collect_calls_from_body(fd.body.as_ref()) {
923            let Some(callee) = canonical_callee_name(&callee_raw, &names) else {
924                continue;
925            };
926            any_intra = true;
927            let callee_metric_indices = metric_indices.get(&callee)?;
928            let is_same_edge = callee_metric_indices.len() == caller_metric_params.len()
929                && callee_metric_indices
930                    .iter()
931                    .enumerate()
932                    .all(|(pos, callee_idx)| {
933                        let Some(arg) = args.get(*callee_idx).cloned() else {
934                            return false;
935                        };
936                        is_ident(arg, caller_metric_params[pos])
937                    });
938            if is_same_edge {
939                if let Some(edges) = same_edges.get_mut(&fd.name) {
940                    edges.insert(callee);
941                } else {
942                    return None;
943                }
944            }
945        }
946    }
947    if !any_intra {
948        return None;
949    }
950
951    let ranks = ranks_from_same_edges(&names, &same_edges)?;
952    let mut out = HashMap::new();
953    for fd in component {
954        let rank = ranks.get(&fd.name).cloned()?;
955        out.insert(fd.name.clone(), rank);
956    }
957    Some(out)
958}
959
960/// Classify every recursive pure fn in `ctx`. The returned map assigns
961/// each supported function a [`RecursionPlan`]; anything that falls
962/// outside the recognised shapes becomes a [`ProofModeIssue`].
963pub fn analyze_plans(
964    ctx: &CodegenContext,
965) -> (HashMap<String, RecursionPlan>, Vec<ProofModeIssue>) {
966    let mut plans = HashMap::new();
967    let mut issues = Vec::new();
968
969    let all_pure = pure_fns(ctx);
970    let recursive_names = recursive_pure_fn_names(ctx);
971    let components = call_graph::ordered_fn_components(&all_pure, &ctx.module_prefixes);
972
973    for component in components {
974        if component.is_empty() {
975            continue;
976        }
977        let is_recursive_component =
978            component.len() > 1 || recursive_names.contains(&component[0].name);
979        if !is_recursive_component {
980            continue;
981        }
982
983        if component.len() > 1 {
984            if supports_mutual_int_countdown(&component) {
985                for fd in &component {
986                    plans.insert(fd.name.clone(), RecursionPlan::MutualIntCountdown);
987                }
988            } else if let Some(ranks) = supports_mutual_string_pos_advance(&component) {
989                for fd in &component {
990                    if let Some(rank) = ranks.get(&fd.name).cloned() {
991                        plans.insert(
992                            fd.name.clone(),
993                            RecursionPlan::MutualStringPosAdvance { rank },
994                        );
995                    }
996                }
997            } else if let Some(rankings) = supports_mutual_sizeof_ranked(&component) {
998                for fd in &component {
999                    if let Some(rank) = rankings.get(&fd.name).cloned() {
1000                        plans.insert(fd.name.clone(), RecursionPlan::MutualSizeOfRanked { rank });
1001                    }
1002                }
1003            } else {
1004                let names = component
1005                    .iter()
1006                    .map(|fd| fd.name.clone())
1007                    .collect::<Vec<_>>()
1008                    .join(", ");
1009                let line = component.iter().map(|fd| fd.line).min().unwrap_or(1);
1010                issues.push(ProofModeIssue {
1011                    line,
1012                    message: format!(
1013                        "unsupported mutual recursion group (currently supported in proof mode: Int countdown on first param): {}",
1014                        names
1015                    ),
1016                });
1017            }
1018            continue;
1019        }
1020
1021        let fd = component[0];
1022        if crate::codegen::lean::recurrence::detect_second_order_int_linear_recurrence(fd).is_some()
1023        {
1024            plans.insert(fd.name.clone(), RecursionPlan::LinearRecurrence2);
1025        } else if let Some((param_index, bound)) = single_int_ascending_param(fd) {
1026            plans.insert(
1027                fd.name.clone(),
1028                RecursionPlan::IntAscending { param_index, bound },
1029            );
1030        } else if let Some(param_index) = single_int_countdown_param_index(fd) {
1031            plans.insert(fd.name.clone(), RecursionPlan::IntCountdown { param_index });
1032        } else if supports_single_sizeof_structural(fd, ctx) {
1033            plans.insert(fd.name.clone(), RecursionPlan::SizeOfStructural);
1034        } else if let Some(param_index) = single_list_structural_param_index(fd) {
1035            plans.insert(
1036                fd.name.clone(),
1037                RecursionPlan::ListStructural { param_index },
1038            );
1039        } else if supports_single_string_pos_advance(fd) {
1040            plans.insert(fd.name.clone(), RecursionPlan::StringPosAdvance);
1041        } else {
1042            issues.push(ProofModeIssue {
1043                line: fd.line,
1044                message: format!(
1045                    "recursive function '{}' is outside proof subset (currently supported: Int countdown, second-order affine Int recurrences with pair-state worker, structural recursion on List/recursive ADTs, String+position, mutual Int countdown, mutual String+position, and ranked sizeOf recursion)",
1046                    fd.name
1047                ),
1048            });
1049        }
1050    }
1051
1052    (plans, issues)
1053}