1use std::collections::{HashMap, HashSet};
16
17use crate::ast::{
18 BinOp, Expr, FnBody, FnDef, MatchArm, Pattern, Spanned, Stmt, TailCallData, TypeDef,
19};
20use crate::call_graph;
21use crate::codegen::CodegenContext;
22use crate::codegen::lean::{
23 find_type_def, pure_fns, recursive_pure_fn_names, recursive_type_names,
24 sizeof_measure_param_indices,
25};
26
27use super::{ProofModeIssue, RecursionPlan};
28
29pub(crate) fn expr_to_dotted_name(expr: &Spanned<Expr>) -> Option<String> {
30 match &expr.node {
31 Expr::Ident(name) => Some(name.clone()),
32 Expr::Attr(obj, field) => expr_to_dotted_name(obj).map(|p| format!("{}.{}", p, field)),
33 _ => None,
34 }
35}
36
37pub(crate) fn local_name_of(expr: &Spanned<Expr>) -> Option<&str> {
45 match &expr.node {
46 Expr::Ident(name) => Some(name.as_str()),
47 Expr::Resolved { name, .. } => Some(name.as_str()),
48 _ => None,
49 }
50}
51
/// Tells whether the call name `name` refers to `target`.
///
/// A match is either an exact string match or `target` being equal to the
/// segment after the last `.` in `name` (so `Mod.foo` matches `foo`).
pub(crate) fn call_matches(name: &str, target: &str) -> bool {
    if name == target {
        return true;
    }
    // Without a dot the last segment is the whole name, which was already
    // compared above, so only the qualified case is left to check.
    name.rsplit_once('.')
        .is_some_and(|(_, last_segment)| last_segment == target)
}
55
56pub(crate) fn call_is_in_set(name: &str, targets: &HashSet<String>) -> bool {
57 call_matches_any(name, targets)
58}
59
/// Maps a possibly qualified call name to the matching entry of `targets`.
///
/// The full `name` is preferred when it is itself a member of `targets`.
/// Otherwise the segment after the last `.` is tried. Returns `None` when
/// neither form is in the set.
pub(crate) fn canonical_callee_name(name: &str, targets: &HashSet<String>) -> Option<String> {
    // `rsplit` always yields at least one item; for an undotted name it is
    // the name itself.
    let last_segment = name.rsplit('.').next().unwrap_or(name);
    let canonical = if targets.contains(name) {
        name
    } else if targets.contains(last_segment) {
        last_segment
    } else {
        return None;
    };
    Some(canonical.to_owned())
}
69
/// Tells whether the call name `name` refers to any member of `targets`.
///
/// The name matches when it is in the set verbatim, or when the segment
/// after its last `.` is in the set.
pub(crate) fn call_matches_any(name: &str, targets: &HashSet<String>) -> bool {
    targets.contains(name)
        || name
            .rsplit('.')
            .next()
            .is_some_and(|last_segment| targets.contains(last_segment))
}
79
80pub(crate) fn is_int_minus_positive(expr: &Spanned<Expr>, param_name: &str) -> bool {
81 match &expr.node {
82 Expr::BinOp(BinOp::Sub, left, right) => {
83 local_name_of(left).is_some_and(|id| id == param_name)
84 && matches!(&right.node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
85 }
86 Expr::FnCall(callee, args) => {
87 let Some(name) = expr_to_dotted_name(callee) else {
88 return false;
89 };
90 (name == "Int.sub" || name == "int.sub")
91 && args.len() == 2
92 && local_name_of(&args[0]).is_some_and(|id| id == param_name)
93 && matches!(&args[1].node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
94 }
95 _ => false,
96 }
97}
98
99pub(crate) fn collect_calls_from_expr<'a>(
100 expr: &'a Spanned<Expr>,
101 out: &mut Vec<(String, Vec<&'a Spanned<Expr>>)>,
102) {
103 match &expr.node {
104 Expr::FnCall(callee, args) => {
105 if let Some(name) = expr_to_dotted_name(callee) {
106 out.push((name, args.iter().collect()));
107 }
108 collect_calls_from_expr(callee, out);
109 for arg in args {
110 collect_calls_from_expr(arg, out);
111 }
112 }
113 Expr::TailCall(boxed) => {
114 let TailCallData {
115 target: name, args, ..
116 } = boxed.as_ref();
117 out.push((name.clone(), args.iter().collect()));
118 for arg in args {
119 collect_calls_from_expr(arg, out);
120 }
121 }
122 Expr::Attr(obj, _) => collect_calls_from_expr(obj, out),
123 Expr::BinOp(_, left, right) => {
124 collect_calls_from_expr(left, out);
125 collect_calls_from_expr(right, out);
126 }
127 Expr::Match { subject, arms, .. } => {
128 collect_calls_from_expr(subject, out);
129 for arm in arms {
130 collect_calls_from_expr(&arm.body, out);
131 }
132 }
133 Expr::Constructor(_, inner) => {
134 if let Some(inner) = inner {
135 collect_calls_from_expr(inner, out);
136 }
137 }
138 Expr::ErrorProp(inner) => collect_calls_from_expr(inner, out),
139 Expr::InterpolatedStr(parts) => {
140 for p in parts {
141 if let crate::ast::StrPart::Parsed(e) = p {
142 collect_calls_from_expr(e, out);
143 }
144 }
145 }
146 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
147 for item in items {
148 collect_calls_from_expr(item, out);
149 }
150 }
151 Expr::MapLiteral(entries) => {
152 for (k, v) in entries {
153 collect_calls_from_expr(k, out);
154 collect_calls_from_expr(v, out);
155 }
156 }
157 Expr::RecordCreate { fields, .. } => {
158 for (_, v) in fields {
159 collect_calls_from_expr(v, out);
160 }
161 }
162 Expr::RecordUpdate { base, updates, .. } => {
163 collect_calls_from_expr(base, out);
164 for (_, v) in updates {
165 collect_calls_from_expr(v, out);
166 }
167 }
168 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => {}
169 }
170}
171
172pub(crate) fn collect_calls_from_body(body: &FnBody) -> Vec<(String, Vec<&Spanned<Expr>>)> {
173 let mut out = Vec::new();
174 for stmt in body.stmts() {
175 match stmt {
176 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => collect_calls_from_expr(expr, &mut out),
177 }
178 }
179 out
180}
181
182pub(crate) fn collect_list_tail_binders_from_expr(
183 expr: &Spanned<Expr>,
184 list_param_name: &str,
185 tails: &mut HashSet<String>,
186) {
187 match &expr.node {
188 Expr::Match { subject, arms, .. } => {
189 if local_name_of(subject).is_some_and(|id| id == list_param_name) {
190 for MatchArm { pattern, .. } in arms {
191 if let Pattern::Cons(_, tail) = pattern {
192 tails.insert(tail.clone());
193 }
194 }
195 }
196 for arm in arms {
197 collect_list_tail_binders_from_expr(&arm.body, list_param_name, tails);
198 }
199 collect_list_tail_binders_from_expr(subject, list_param_name, tails);
200 }
201 Expr::FnCall(callee, args) => {
202 collect_list_tail_binders_from_expr(callee, list_param_name, tails);
203 for arg in args {
204 collect_list_tail_binders_from_expr(arg, list_param_name, tails);
205 }
206 }
207 Expr::TailCall(boxed) => {
208 let TailCallData {
209 target: _, args, ..
210 } = boxed.as_ref();
211 for arg in args {
212 collect_list_tail_binders_from_expr(arg, list_param_name, tails);
213 }
214 }
215 Expr::Attr(obj, _) => collect_list_tail_binders_from_expr(obj, list_param_name, tails),
216 Expr::BinOp(_, left, right) => {
217 collect_list_tail_binders_from_expr(left, list_param_name, tails);
218 collect_list_tail_binders_from_expr(right, list_param_name, tails);
219 }
220 Expr::Constructor(_, inner) => {
221 if let Some(inner) = inner {
222 collect_list_tail_binders_from_expr(inner, list_param_name, tails);
223 }
224 }
225 Expr::ErrorProp(inner) => {
226 collect_list_tail_binders_from_expr(inner, list_param_name, tails)
227 }
228 Expr::InterpolatedStr(parts) => {
229 for p in parts {
230 if let crate::ast::StrPart::Parsed(e) = p {
231 collect_list_tail_binders_from_expr(e, list_param_name, tails);
232 }
233 }
234 }
235 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
236 for item in items {
237 collect_list_tail_binders_from_expr(item, list_param_name, tails);
238 }
239 }
240 Expr::MapLiteral(entries) => {
241 for (k, v) in entries {
242 collect_list_tail_binders_from_expr(k, list_param_name, tails);
243 collect_list_tail_binders_from_expr(v, list_param_name, tails);
244 }
245 }
246 Expr::RecordCreate { fields, .. } => {
247 for (_, v) in fields {
248 collect_list_tail_binders_from_expr(v, list_param_name, tails);
249 }
250 }
251 Expr::RecordUpdate { base, updates, .. } => {
252 collect_list_tail_binders_from_expr(base, list_param_name, tails);
253 for (_, v) in updates {
254 collect_list_tail_binders_from_expr(v, list_param_name, tails);
255 }
256 }
257 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => {}
258 }
259}
260
261pub(crate) fn collect_list_tail_binders(fd: &FnDef, list_param_name: &str) -> HashSet<String> {
262 let mut tails = HashSet::new();
263 for stmt in fd.body.stmts() {
264 match stmt {
265 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
266 collect_list_tail_binders_from_expr(expr, list_param_name, &mut tails)
267 }
268 }
269 }
270 tails
271}
272
273pub(crate) fn recursive_constructor_binders(
274 td: &TypeDef,
275 variant_name: &str,
276 binders: &[String],
277) -> Vec<String> {
278 let variant_short = variant_name.rsplit('.').next().unwrap_or(variant_name);
279 match td {
280 TypeDef::Sum { name, variants, .. } => variants
281 .iter()
282 .find(|variant| variant.name == variant_short)
283 .map(|variant| {
284 variant
285 .fields
286 .iter()
287 .zip(binders.iter())
288 .filter_map(|(field_ty, binder)| {
289 (field_ty.trim() == name).then_some(binder.clone())
290 })
291 .collect()
292 })
293 .unwrap_or_default(),
294 TypeDef::Product { .. } => Vec::new(),
295 }
296}
297
298pub(crate) fn grow_recursive_subterm_binders_from_expr(
299 expr: &Spanned<Expr>,
300 tracked: &HashSet<String>,
301 td: &TypeDef,
302 out: &mut HashSet<String>,
303) {
304 match &expr.node {
305 Expr::Match { subject, arms, .. } => {
306 if let Expr::Ident(subject_name) = &subject.node
307 && tracked.contains(subject_name)
308 {
309 for arm in arms {
310 if let Pattern::Constructor(variant_name, binders) = &arm.pattern {
311 out.extend(recursive_constructor_binders(td, variant_name, binders));
312 }
313 }
314 }
315 grow_recursive_subterm_binders_from_expr(subject, tracked, td, out);
316 for arm in arms {
317 grow_recursive_subterm_binders_from_expr(&arm.body, tracked, td, out);
318 }
319 }
320 Expr::FnCall(callee, args) => {
321 grow_recursive_subterm_binders_from_expr(callee, tracked, td, out);
322 for arg in args {
323 grow_recursive_subterm_binders_from_expr(arg, tracked, td, out);
324 }
325 }
326 Expr::Attr(obj, _) => grow_recursive_subterm_binders_from_expr(obj, tracked, td, out),
327 Expr::BinOp(_, left, right) => {
328 grow_recursive_subterm_binders_from_expr(left, tracked, td, out);
329 grow_recursive_subterm_binders_from_expr(right, tracked, td, out);
330 }
331 Expr::Constructor(_, Some(inner)) | Expr::ErrorProp(inner) => {
332 grow_recursive_subterm_binders_from_expr(inner, tracked, td, out)
333 }
334 Expr::InterpolatedStr(parts) => {
335 for part in parts {
336 if let crate::ast::StrPart::Parsed(inner) = part {
337 grow_recursive_subterm_binders_from_expr(inner, tracked, td, out);
338 }
339 }
340 }
341 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
342 for item in items {
343 grow_recursive_subterm_binders_from_expr(item, tracked, td, out);
344 }
345 }
346 Expr::MapLiteral(entries) => {
347 for (k, v) in entries {
348 grow_recursive_subterm_binders_from_expr(k, tracked, td, out);
349 grow_recursive_subterm_binders_from_expr(v, tracked, td, out);
350 }
351 }
352 Expr::RecordCreate { fields, .. } => {
353 for (_, v) in fields {
354 grow_recursive_subterm_binders_from_expr(v, tracked, td, out);
355 }
356 }
357 Expr::RecordUpdate { base, updates, .. } => {
358 grow_recursive_subterm_binders_from_expr(base, tracked, td, out);
359 for (_, v) in updates {
360 grow_recursive_subterm_binders_from_expr(v, tracked, td, out);
361 }
362 }
363 Expr::TailCall(boxed) => {
364 for arg in &boxed.args {
365 grow_recursive_subterm_binders_from_expr(arg, tracked, td, out);
366 }
367 }
368 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
369 }
370}
371
372pub(crate) fn collect_recursive_subterm_binders(
373 fd: &FnDef,
374 param_name: &str,
375 param_type: &str,
376 ctx: &CodegenContext,
377) -> HashSet<String> {
378 let Some(td) = find_type_def(ctx, param_type) else {
379 return HashSet::new();
380 };
381 let mut tracked: HashSet<String> = HashSet::from([param_name.to_string()]);
382 loop {
383 let mut discovered = HashSet::new();
384 for stmt in fd.body.stmts() {
385 match stmt {
386 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
387 grow_recursive_subterm_binders_from_expr(expr, &tracked, td, &mut discovered);
388 }
389 }
390 }
391 let before = tracked.len();
392 tracked.extend(discovered);
393 if tracked.len() == before {
394 break;
395 }
396 }
397 tracked.remove(param_name);
398 tracked
399}
400
401pub(crate) fn single_int_countdown_param_index(fd: &FnDef) -> Option<usize> {
402 let recursive_calls: Vec<Vec<&Spanned<Expr>>> = collect_calls_from_body(fd.body.as_ref())
403 .into_iter()
404 .filter(|(name, _)| call_matches(name, &fd.name))
405 .map(|(_, args)| args)
406 .collect();
407 if recursive_calls.is_empty() {
408 return None;
409 }
410
411 fd.params
412 .iter()
413 .enumerate()
414 .find_map(|(idx, (param_name, param_ty))| {
415 if param_ty != "Int" {
416 return None;
417 }
418 let countdown_ok = recursive_calls.iter().all(|args| {
419 args.get(idx)
420 .cloned()
421 .is_some_and(|arg| is_int_minus_positive(arg, param_name))
422 });
423 if countdown_ok {
424 return Some(idx);
425 }
426
427 let ascent_ok = recursive_calls.iter().all(|args| {
430 args.get(idx)
431 .copied()
432 .is_some_and(|arg| is_int_plus_positive(arg, param_name))
433 });
434 (ascent_ok && has_negative_guarded_ascent(fd, param_name)).then_some(idx)
435 })
436}
437
438pub(crate) fn has_negative_guarded_ascent(fd: &FnDef, param_name: &str) -> bool {
439 let Some(tail) = fd.body.tail_expr() else {
440 return false;
441 };
442 let Expr::Match { subject, arms, .. } = &tail.node else {
443 return false;
444 };
445 let Expr::BinOp(BinOp::Lt, left, right) = &subject.node else {
446 return false;
447 };
448 if !is_ident(left, param_name)
449 || !matches!(&right.node, Expr::Literal(crate::ast::Literal::Int(0)))
450 {
451 return false;
452 }
453
454 let mut true_arm = None;
455 let mut false_arm = None;
456 for arm in arms {
457 match arm.pattern {
458 Pattern::Literal(crate::ast::Literal::Bool(true)) => true_arm = Some(arm.body.as_ref()),
459 Pattern::Literal(crate::ast::Literal::Bool(false)) => {
460 false_arm = Some(arm.body.as_ref())
461 }
462 _ => return false,
463 }
464 }
465
466 let Some(true_arm) = true_arm else {
467 return false;
468 };
469 let Some(false_arm) = false_arm else {
470 return false;
471 };
472
473 let mut true_calls = Vec::new();
474 collect_calls_from_expr(true_arm, &mut true_calls);
475 let mut false_calls = Vec::new();
476 collect_calls_from_expr(false_arm, &mut false_calls);
477
478 true_calls
479 .iter()
480 .any(|(name, _)| call_matches(name, &fd.name))
481 && false_calls
482 .iter()
483 .all(|(name, _)| !call_matches(name, &fd.name))
484}
485
486pub(crate) fn single_int_ascending_param(fd: &FnDef) -> Option<(usize, Spanned<Expr>)> {
489 let recursive_calls: Vec<Vec<&Spanned<Expr>>> = collect_calls_from_body(fd.body.as_ref())
490 .into_iter()
491 .filter(|(name, _)| call_matches(name, &fd.name))
492 .map(|(_, args)| args)
493 .collect();
494 if recursive_calls.is_empty() {
495 return None;
496 }
497
498 for (idx, (param_name, param_ty)) in fd.params.iter().enumerate() {
499 if param_ty != "Int" {
500 continue;
501 }
502 let ascent_ok = recursive_calls.iter().all(|args| {
503 args.get(idx)
504 .cloned()
505 .is_some_and(|arg| is_int_plus_positive(arg, param_name))
506 });
507 if !ascent_ok {
508 continue;
509 }
510 if let Some(bound) = extract_equality_bound_expr(fd, param_name) {
511 return Some((idx, bound));
512 }
513 }
514 None
515}
516
517pub(crate) fn extract_equality_bound_expr(fd: &FnDef, param_name: &str) -> Option<Spanned<Expr>> {
521 let tail = fd.body.tail_expr()?;
522 let Expr::Match { subject, arms, .. } = &tail.node else {
523 return None;
524 };
525 let Expr::BinOp(BinOp::Eq, left, right) = &subject.node else {
526 return None;
527 };
528 if !is_ident(left, param_name) {
529 return None;
530 }
531 let mut true_has_self = false;
533 let mut false_has_self = false;
534 for arm in arms {
535 match arm.pattern {
536 Pattern::Literal(crate::ast::Literal::Bool(true)) => {
537 let mut calls = Vec::new();
538 collect_calls_from_expr(&arm.body, &mut calls);
539 true_has_self = calls.iter().any(|(n, _)| call_matches(n, &fd.name));
540 }
541 Pattern::Literal(crate::ast::Literal::Bool(false)) => {
542 let mut calls = Vec::new();
543 collect_calls_from_expr(&arm.body, &mut calls);
544 false_has_self = calls.iter().any(|(n, _)| call_matches(n, &fd.name));
545 }
546 _ => return None,
547 }
548 }
549 if true_has_self || !false_has_self {
550 return None;
551 }
552 Some((**right).clone())
553}
554
555pub(crate) fn supports_single_sizeof_structural(fd: &FnDef, ctx: &CodegenContext) -> bool {
556 let recursive_calls: Vec<Vec<&Spanned<Expr>>> = collect_calls_from_body(fd.body.as_ref())
557 .into_iter()
558 .filter(|(name, _)| call_matches(name, &fd.name))
559 .map(|(_, args)| args)
560 .collect();
561 if recursive_calls.is_empty() {
562 return false;
563 }
564
565 let metric_indices = sizeof_measure_param_indices(fd);
566 if metric_indices.is_empty() {
567 return false;
568 }
569
570 let binder_sets: HashMap<usize, HashSet<String>> = metric_indices
571 .iter()
572 .filter_map(|idx| {
573 let (param_name, param_type) = fd.params.get(*idx)?;
574 recursive_type_names(ctx).contains(param_type).then(|| {
575 (
576 *idx,
577 collect_recursive_subterm_binders(fd, param_name, param_type, ctx),
578 )
579 })
580 })
581 .collect();
582
583 if binder_sets.values().all(HashSet::is_empty) {
584 return false;
585 }
586
587 recursive_calls.iter().all(|args| {
588 let mut strictly_smaller = false;
589 for idx in &metric_indices {
590 let Some((param_name, _)) = fd.params.get(*idx) else {
591 return false;
592 };
593 let Some(arg) = args.get(*idx).cloned() else {
594 return false;
595 };
596 if is_ident(arg, param_name) {
597 continue;
598 }
599 let Some(binders) = binder_sets.get(idx) else {
600 return false;
601 };
602 if local_name_of(arg).is_some_and(|id| binders.contains(id)) {
603 strictly_smaller = true;
604 continue;
605 }
606 return false;
607 }
608 strictly_smaller
609 })
610}
611
612pub(crate) fn single_list_structural_param_index(fd: &FnDef) -> Option<usize> {
613 fd.params
614 .iter()
615 .enumerate()
616 .find_map(|(param_index, (param_name, param_ty))| {
617 if !(param_ty.starts_with("List<") || param_ty == "List") {
618 return None;
619 }
620
621 let tails = collect_list_tail_binders(fd, param_name);
622 if tails.is_empty() {
623 return None;
624 }
625
626 let recursive_calls: Vec<Option<&Spanned<Expr>>> =
627 collect_calls_from_body(fd.body.as_ref())
628 .into_iter()
629 .filter(|(name, _)| call_matches(name, &fd.name))
630 .map(|(_, args)| args.get(param_index).cloned())
631 .collect();
632 if recursive_calls.is_empty() {
633 return None;
634 }
635
636 recursive_calls
637 .into_iter()
638 .all(|arg| {
639 arg.and_then(local_name_of)
640 .is_some_and(|id| tails.contains(id))
641 })
642 .then_some(param_index)
643 })
644}
645
646pub(crate) fn is_ident(expr: &Spanned<Expr>, name: &str) -> bool {
647 local_name_of(expr).is_some_and(|id| id == name)
648}
649
650pub(crate) fn is_int_plus_positive(expr: &Spanned<Expr>, param_name: &str) -> bool {
651 match &expr.node {
652 Expr::BinOp(BinOp::Add, left, right) => {
653 local_name_of(left).is_some_and(|id| id == param_name)
654 && matches!(&right.node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
655 }
656 Expr::FnCall(callee, args) => {
657 let Some(name) = expr_to_dotted_name(callee) else {
658 return false;
659 };
660 (name == "Int.add" || name == "int.add")
661 && args.len() == 2
662 && local_name_of(&args[0]).is_some_and(|id| id == param_name)
663 && matches!(&args[1].node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
664 }
665 _ => false,
666 }
667}
668
669pub(crate) fn is_skip_ws_advance(
670 expr: &Spanned<Expr>,
671 string_param: &str,
672 pos_param: &str,
673) -> bool {
674 let Expr::FnCall(callee, args) = &expr.node else {
675 return false;
676 };
677 let Some(name) = expr_to_dotted_name(callee) else {
678 return false;
679 };
680 if !call_matches(&name, "skipWs") || args.len() != 2 {
681 return false;
682 }
683 is_ident(&args[0], string_param) && is_int_plus_positive(&args[1], pos_param)
684}
685
686pub(crate) fn is_skip_ws_same(expr: &Spanned<Expr>, string_param: &str, pos_param: &str) -> bool {
687 let Expr::FnCall(callee, args) = &expr.node else {
688 return false;
689 };
690 let Some(name) = expr_to_dotted_name(callee) else {
691 return false;
692 };
693 if !call_matches(&name, "skipWs") || args.len() != 2 {
694 return false;
695 }
696 is_ident(&args[0], string_param) && is_ident(&args[1], pos_param)
697}
698
699pub(crate) fn is_string_pos_advance(
700 expr: &Spanned<Expr>,
701 string_param: &str,
702 pos_param: &str,
703) -> bool {
704 is_int_plus_positive(expr, pos_param) || is_skip_ws_advance(expr, string_param, pos_param)
705}
706
/// How a call's position argument relates to the caller's position
/// parameter, as determined by `classify_string_pos_edge`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum StringPosEdge {
    /// The position is passed on unchanged (directly or through `skipWs`).
    Same,
    /// The position argument is classified as moving forward.
    Advance,
}
712
713pub(crate) fn classify_string_pos_edge(
714 expr: &Spanned<Expr>,
715 string_param: &str,
716 pos_param: &str,
717) -> Option<StringPosEdge> {
718 if is_ident(expr, pos_param) || is_skip_ws_same(expr, string_param, pos_param) {
719 return Some(StringPosEdge::Same);
720 }
721 if is_string_pos_advance(expr, string_param, pos_param) {
722 return Some(StringPosEdge::Advance);
723 }
724 if let Expr::FnCall(callee, args) = &expr.node {
725 let name = expr_to_dotted_name(callee)?;
726 if call_matches(&name, "skipWs")
727 && args.len() == 2
728 && is_ident(&args[0], string_param)
729 && local_name_of(&args[1]).is_some_and(|id| id != pos_param)
730 {
731 return Some(StringPosEdge::Advance);
732 }
733 }
734 if local_name_of(expr).is_some_and(|id| id != pos_param) {
735 return Some(StringPosEdge::Advance);
736 }
737 None
738}
739
/// Assigns each name a rank from a topological order of the "same" edges.
///
/// `same_edges` maps a source name to the names it points at. A
/// Kahn-style topological sort is run over `names`; the i-th name in the
/// resulting order (0-based) gets rank `names.len() - i`, so the source of
/// an edge always outranks its target and the smallest rank is 1.
///
/// The order is deterministic: ready names are kept sorted and the
/// lexicographically largest is taken first.
///
/// Returns `None` when an edge points at a name outside `names`, or when the
/// edges contain a cycle (including a self-loop).
pub(crate) fn ranks_from_same_edges(
    names: &HashSet<String>,
    same_edges: &HashMap<String, HashSet<String>>,
) -> Option<HashMap<String, usize>> {
    let mut indegree: HashMap<&str, usize> = names.iter().map(|n| (n.as_str(), 0)).collect();
    for to in same_edges.values().flatten() {
        *indegree.get_mut(to.as_str())? += 1;
    }

    let mut ready: Vec<&str> = indegree
        .iter()
        .filter_map(|(&name, &deg)| (deg == 0).then_some(name))
        .collect();
    ready.sort_unstable();

    let mut topo: Vec<&str> = Vec::with_capacity(names.len());
    while let Some(node) = ready.pop() {
        topo.push(node);
        let Some(outs) = same_edges.get(node) else {
            continue;
        };
        let mut newly_zero = Vec::new();
        for to in outs {
            let remaining = indegree.get_mut(to.as_str())?;
            *remaining -= 1;
            if *remaining == 0 {
                newly_zero.push(to.as_str());
            }
        }
        // Sort so the processing order does not depend on hash iteration.
        newly_zero.sort_unstable();
        ready.extend(newly_zero);
    }

    // Names left unprocessed are on a cycle.
    if topo.len() != names.len() {
        return None;
    }

    let n = topo.len();
    Some(
        topo.into_iter()
            .enumerate()
            .map(|(idx, name)| (name.to_owned(), n - idx))
            .collect(),
    )
}
790
791pub(crate) fn supports_single_string_pos_advance(fd: &FnDef) -> bool {
792 let Some((string_param, string_ty)) = fd.params.first() else {
793 return false;
794 };
795 let Some((pos_param, pos_ty)) = fd.params.get(1) else {
796 return false;
797 };
798 if string_ty != "String" || pos_ty != "Int" {
799 return false;
800 }
801
802 type CallPair<'a> = (Option<&'a Spanned<Expr>>, Option<&'a Spanned<Expr>>);
803 let recursive_calls: Vec<CallPair<'_>> = collect_calls_from_body(fd.body.as_ref())
804 .into_iter()
805 .filter(|(name, _)| call_matches(name, &fd.name))
806 .map(|(_, args)| (args.first().cloned(), args.get(1).cloned()))
807 .collect();
808 if recursive_calls.is_empty() {
809 return false;
810 }
811
812 recursive_calls.into_iter().all(|(arg0, arg1)| {
813 arg0.is_some_and(|e| is_ident(e, string_param))
814 && arg1.is_some_and(|e| is_string_pos_advance(e, string_param, pos_param))
815 })
816}
817
818pub(crate) fn supports_mutual_int_countdown(component: &[&FnDef]) -> bool {
819 if component.len() < 2 {
820 return false;
821 }
822 if component
823 .iter()
824 .any(|fd| !matches!(fd.params.first(), Some((_, t)) if t == "Int"))
825 {
826 return false;
827 }
828 let names: HashSet<String> = component.iter().map(|fd| fd.name.clone()).collect();
829 let mut any_intra = false;
830 for fd in component {
831 let param_name = &fd.params[0].0;
832 for (callee, args) in collect_calls_from_body(fd.body.as_ref()) {
833 if !call_is_in_set(&callee, &names) {
834 continue;
835 }
836 any_intra = true;
837 let Some(arg0) = args.first().cloned() else {
838 return false;
839 };
840 if !is_int_minus_positive(arg0, param_name) {
841 return false;
842 }
843 }
844 }
845 any_intra
846}
847
848pub(crate) fn supports_mutual_string_pos_advance(
849 component: &[&FnDef],
850) -> Option<HashMap<String, usize>> {
851 if component.len() < 2 {
852 return None;
853 }
854 if component.iter().any(|fd| {
855 !matches!(fd.params.first(), Some((_, t)) if t == "String")
856 || !matches!(fd.params.get(1), Some((_, t)) if t == "Int")
857 }) {
858 return None;
859 }
860
861 let names: HashSet<String> = component.iter().map(|fd| fd.name.clone()).collect();
862 let mut same_edges: HashMap<String, HashSet<String>> =
863 names.iter().map(|n| (n.clone(), HashSet::new())).collect();
864 let mut any_intra = false;
865
866 for fd in component {
867 let string_param = &fd.params[0].0;
868 let pos_param = &fd.params[1].0;
869 for (callee_raw, args) in collect_calls_from_body(fd.body.as_ref()) {
870 let Some(callee) = canonical_callee_name(&callee_raw, &names) else {
871 continue;
872 };
873 any_intra = true;
874
875 let arg0 = args.first().cloned()?;
876 let arg1 = args.get(1).cloned()?;
877
878 if !is_ident(arg0, string_param) {
879 return None;
880 }
881
882 match classify_string_pos_edge(arg1, string_param, pos_param) {
883 Some(StringPosEdge::Same) => {
884 if let Some(edges) = same_edges.get_mut(&fd.name) {
885 edges.insert(callee);
886 } else {
887 return None;
888 }
889 }
890 Some(StringPosEdge::Advance) => {}
891 None => return None,
892 }
893 }
894 }
895
896 if !any_intra {
897 return None;
898 }
899
900 ranks_from_same_edges(&names, &same_edges)
901}
902
/// Tells whether `type_name` is exactly one of the built-in scalar-like type
/// names: `Int`, `Float`, `Bool`, `String`, `Char`, `Byte` or `Unit`.
pub(crate) fn is_scalar_like_type(type_name: &str) -> bool {
    const SCALAR_LIKE: [&str; 7] = ["Int", "Float", "Bool", "String", "Char", "Byte", "Unit"];
    SCALAR_LIKE.contains(&type_name)
}
909
910pub(crate) fn supports_mutual_sizeof_ranked(
911 component: &[&FnDef],
912) -> Option<HashMap<String, usize>> {
913 if component.len() < 2 {
914 return None;
915 }
916 let names: HashSet<String> = component.iter().map(|fd| fd.name.clone()).collect();
917 let metric_indices: HashMap<String, Vec<usize>> = component
918 .iter()
919 .map(|fd| (fd.name.clone(), sizeof_measure_param_indices(fd)))
920 .collect();
921 if component.iter().any(|fd| {
922 metric_indices
923 .get(&fd.name)
924 .is_none_or(|indices| indices.is_empty())
925 }) {
926 return None;
927 }
928
929 let mut same_edges: HashMap<String, HashSet<String>> =
930 names.iter().map(|n| (n.clone(), HashSet::new())).collect();
931 let mut any_intra = false;
932 for fd in component {
933 let caller_metric_indices = metric_indices.get(&fd.name)?;
934 let caller_metric_params: Vec<&str> = caller_metric_indices
935 .iter()
936 .filter_map(|idx| fd.params.get(*idx).map(|(name, _)| name.as_str()))
937 .collect();
938 for (callee_raw, args) in collect_calls_from_body(fd.body.as_ref()) {
939 let Some(callee) = canonical_callee_name(&callee_raw, &names) else {
940 continue;
941 };
942 any_intra = true;
943 let callee_metric_indices = metric_indices.get(&callee)?;
944 let is_same_edge = callee_metric_indices.len() == caller_metric_params.len()
945 && callee_metric_indices
946 .iter()
947 .enumerate()
948 .all(|(pos, callee_idx)| {
949 let Some(arg) = args.get(*callee_idx).cloned() else {
950 return false;
951 };
952 is_ident(arg, caller_metric_params[pos])
953 });
954 if is_same_edge {
955 if let Some(edges) = same_edges.get_mut(&fd.name) {
956 edges.insert(callee);
957 } else {
958 return None;
959 }
960 }
961 }
962 }
963 if !any_intra {
964 return None;
965 }
966
967 let ranks = ranks_from_same_edges(&names, &same_edges)?;
968 let mut out = HashMap::new();
969 for fd in component {
970 let rank = ranks.get(&fd.name).cloned()?;
971 out.insert(fd.name.clone(), rank);
972 }
973 Some(out)
974}
975
976pub fn analyze_plans(
980 ctx: &CodegenContext,
981) -> (HashMap<String, RecursionPlan>, Vec<ProofModeIssue>) {
982 let mut plans = HashMap::new();
983 let mut issues = Vec::new();
984
985 let all_pure = pure_fns(ctx);
986 let recursive_names = recursive_pure_fn_names(ctx);
987 let components = call_graph::ordered_fn_components(&all_pure, &ctx.module_prefixes);
988
989 for component in components {
990 if component.is_empty() {
991 continue;
992 }
993 let is_recursive_component =
994 component.len() > 1 || recursive_names.contains(&component[0].name);
995 if !is_recursive_component {
996 continue;
997 }
998
999 if component.len() > 1 {
1000 if supports_mutual_int_countdown(&component) {
1001 for fd in &component {
1002 plans.insert(fd.name.clone(), RecursionPlan::MutualIntCountdown);
1003 }
1004 } else if let Some(ranks) = supports_mutual_string_pos_advance(&component) {
1005 for fd in &component {
1006 if let Some(rank) = ranks.get(&fd.name).cloned() {
1007 plans.insert(
1008 fd.name.clone(),
1009 RecursionPlan::MutualStringPosAdvance { rank },
1010 );
1011 }
1012 }
1013 } else if let Some(rankings) = supports_mutual_sizeof_ranked(&component) {
1014 for fd in &component {
1015 if let Some(rank) = rankings.get(&fd.name).cloned() {
1016 plans.insert(fd.name.clone(), RecursionPlan::MutualSizeOfRanked { rank });
1017 }
1018 }
1019 } else {
1020 let names = component
1021 .iter()
1022 .map(|fd| fd.name.clone())
1023 .collect::<Vec<_>>()
1024 .join(", ");
1025 let line = component.iter().map(|fd| fd.line).min().unwrap_or(1);
1026 issues.push(ProofModeIssue {
1027 line,
1028 message: format!(
1029 "unsupported mutual recursion group (currently supported in proof mode: Int countdown on first param): {}",
1030 names
1031 ),
1032 });
1033 }
1034 continue;
1035 }
1036
1037 let fd = component[0];
1038 if crate::codegen::lean::recurrence::detect_second_order_int_linear_recurrence(fd).is_some()
1039 {
1040 plans.insert(fd.name.clone(), RecursionPlan::LinearRecurrence2);
1041 } else if let Some((param_index, bound)) = single_int_ascending_param(fd) {
1042 plans.insert(
1043 fd.name.clone(),
1044 RecursionPlan::IntAscending { param_index, bound },
1045 );
1046 } else if let Some(param_index) = single_int_countdown_param_index(fd) {
1047 plans.insert(fd.name.clone(), RecursionPlan::IntCountdown { param_index });
1048 } else if supports_single_sizeof_structural(fd, ctx) {
1049 plans.insert(fd.name.clone(), RecursionPlan::SizeOfStructural);
1050 } else if let Some(param_index) = single_list_structural_param_index(fd) {
1051 plans.insert(
1052 fd.name.clone(),
1053 RecursionPlan::ListStructural { param_index },
1054 );
1055 } else if supports_single_string_pos_advance(fd) {
1056 plans.insert(fd.name.clone(), RecursionPlan::StringPosAdvance);
1057 } else {
1058 issues.push(ProofModeIssue {
1059 line: fd.line,
1060 message: format!(
1061 "recursive function '{}' is outside proof subset (currently supported: Int countdown, second-order affine Int recurrences with pair-state worker, structural recursion on List/recursive ADTs, String+position, mutual Int countdown, mutual String+position, and ranked sizeOf recursion)",
1062 fd.name
1063 ),
1064 });
1065 }
1066 }
1067
1068 (plans, issues)
1069}