1use std::collections::{HashMap, HashSet};
16
17use crate::ast::{
18 BinOp, Expr, FnBody, FnDef, MatchArm, Pattern, Spanned, Stmt, TailCallData, TypeDef,
19};
20use crate::call_graph;
21use crate::codegen::CodegenContext;
22use crate::codegen::lean::{
23 find_type_def, pure_fns, recursive_pure_fn_names, recursive_type_names,
24 sizeof_measure_param_indices,
25};
26
27use super::{ProofModeIssue, RecursionPlan};
28
29pub(crate) fn expr_to_dotted_name(expr: &Spanned<Expr>) -> Option<String> {
30 match &expr.node {
31 Expr::Ident(name) => Some(name.clone()),
32 Expr::Attr(obj, field) => expr_to_dotted_name(obj).map(|p| format!("{}.{}", p, field)),
33 _ => None,
34 }
35}
36
/// Reports whether a call written as `name` refers to `target`.
///
/// Matches either the exact name or the last dotted segment of `name`
/// (so `Mod.f` matches target `f`).
pub(crate) fn call_matches(name: &str, target: &str) -> bool {
    if name == target {
        return true;
    }
    let last_segment = name.rsplit('.').next();
    last_segment == Some(target)
}
40
/// Reports whether the call `name` refers to any function in `targets`.
///
/// A call matches when its full name is in `targets`, or when its last
/// dotted segment is (so `Mod.f` matches a target `f`).
pub(crate) fn call_is_in_set(name: &str, targets: &HashSet<String>) -> bool {
    targets.contains(name)
        || name
            .rsplit('.')
            .next()
            .is_some_and(|last| targets.contains(last))
}
44
/// Resolves the call `name` to the member of `targets` it refers to.
///
/// The full name wins when it is present; otherwise the last dotted segment
/// is tried. Returns `None` when neither form is in `targets`.
pub(crate) fn canonical_callee_name(name: &str, targets: &HashSet<String>) -> Option<String> {
    if targets.contains(name) {
        return Some(name.to_owned());
    }
    match name.rsplit('.').next() {
        Some(last) if targets.contains(last) => Some(last.to_owned()),
        _ => None,
    }
}
54
/// Reports whether the call `name` refers to any function in `targets`,
/// either by its full name or by its last dotted segment.
pub(crate) fn call_matches_any(name: &str, targets: &HashSet<String>) -> bool {
    // `rsplit` always yields at least one piece, so the fallback never fires.
    let short = name.rsplit('.').next().unwrap_or(name);
    targets.contains(name) || targets.contains(short)
}
64
65pub(crate) fn is_int_minus_positive(expr: &Spanned<Expr>, param_name: &str) -> bool {
66 match &expr.node {
67 Expr::BinOp(BinOp::Sub, left, right) => {
68 matches!(&left.node, Expr::Ident(id) if id == param_name)
69 && matches!(&right.node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
70 }
71 Expr::FnCall(callee, args) => {
72 let Some(name) = expr_to_dotted_name(callee) else {
73 return false;
74 };
75 (name == "Int.sub" || name == "int.sub")
76 && args.len() == 2
77 && matches!(&args[0].node, Expr::Ident(id) if id == param_name)
78 && matches!(&args[1].node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
79 }
80 _ => false,
81 }
82}
83
/// Appends every call site reachable from `expr` to `out`.
///
/// Each entry pairs the callee's dotted name (as produced by
/// `expr_to_dotted_name`) with borrowed references to the call's argument
/// expressions. Both `FnCall` and `TailCall` nodes are recorded. A `FnCall`
/// whose callee cannot be rendered as a dotted name is not recorded itself,
/// but its callee and arguments are still searched.
///
/// Order is pre-order: a call is pushed before any calls nested in its callee
/// or arguments, and for `match` the subject is visited before the arm
/// bodies. Arm patterns are not searched.
pub(crate) fn collect_calls_from_expr<'a>(
    expr: &'a Spanned<Expr>,
    out: &mut Vec<(String, Vec<&'a Spanned<Expr>>)>,
) {
    match &expr.node {
        Expr::FnCall(callee, args) => {
            if let Some(name) = expr_to_dotted_name(callee) {
                out.push((name, args.iter().collect()));
            }
            // The callee expression itself is searched too, in case it contains calls.
            collect_calls_from_expr(callee, out);
            for arg in args {
                collect_calls_from_expr(arg, out);
            }
        }
        Expr::TailCall(boxed) => {
            // Tail calls carry their target as a plain name, recorded verbatim.
            let TailCallData {
                target: name, args, ..
            } = boxed.as_ref();
            out.push((name.clone(), args.iter().collect()));
            for arg in args {
                collect_calls_from_expr(arg, out);
            }
        }
        Expr::Attr(obj, _) => collect_calls_from_expr(obj, out),
        Expr::BinOp(_, left, right) => {
            collect_calls_from_expr(left, out);
            collect_calls_from_expr(right, out);
        }
        Expr::Match { subject, arms, .. } => {
            collect_calls_from_expr(subject, out);
            for arm in arms {
                collect_calls_from_expr(&arm.body, out);
            }
        }
        Expr::Constructor(_, inner) => {
            if let Some(inner) = inner {
                collect_calls_from_expr(inner, out);
            }
        }
        Expr::ErrorProp(inner) => collect_calls_from_expr(inner, out),
        Expr::InterpolatedStr(parts) => {
            // Only parsed (expression) parts can contain calls.
            for p in parts {
                if let crate::ast::StrPart::Parsed(e) = p {
                    collect_calls_from_expr(e, out);
                }
            }
        }
        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
            for item in items {
                collect_calls_from_expr(item, out);
            }
        }
        Expr::MapLiteral(entries) => {
            for (k, v) in entries {
                collect_calls_from_expr(k, out);
                collect_calls_from_expr(v, out);
            }
        }
        Expr::RecordCreate { fields, .. } => {
            for (_, v) in fields {
                collect_calls_from_expr(v, out);
            }
        }
        Expr::RecordUpdate { base, updates, .. } => {
            collect_calls_from_expr(base, out);
            for (_, v) in updates {
                collect_calls_from_expr(v, out);
            }
        }
        // Leaves: nothing to search.
        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => {}
    }
}
156
157pub(crate) fn collect_calls_from_body(body: &FnBody) -> Vec<(String, Vec<&Spanned<Expr>>)> {
158 let mut out = Vec::new();
159 for stmt in body.stmts() {
160 match stmt {
161 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => collect_calls_from_expr(expr, &mut out),
162 }
163 }
164 out
165}
166
167pub(crate) fn collect_list_tail_binders_from_expr(
168 expr: &Spanned<Expr>,
169 list_param_name: &str,
170 tails: &mut HashSet<String>,
171) {
172 match &expr.node {
173 Expr::Match { subject, arms, .. } => {
174 if matches!(&subject.node, Expr::Ident(id) if id == list_param_name) {
175 for MatchArm { pattern, .. } in arms {
176 if let Pattern::Cons(_, tail) = pattern {
177 tails.insert(tail.clone());
178 }
179 }
180 }
181 for arm in arms {
182 collect_list_tail_binders_from_expr(&arm.body, list_param_name, tails);
183 }
184 collect_list_tail_binders_from_expr(subject, list_param_name, tails);
185 }
186 Expr::FnCall(callee, args) => {
187 collect_list_tail_binders_from_expr(callee, list_param_name, tails);
188 for arg in args {
189 collect_list_tail_binders_from_expr(arg, list_param_name, tails);
190 }
191 }
192 Expr::TailCall(boxed) => {
193 let TailCallData {
194 target: _, args, ..
195 } = boxed.as_ref();
196 for arg in args {
197 collect_list_tail_binders_from_expr(arg, list_param_name, tails);
198 }
199 }
200 Expr::Attr(obj, _) => collect_list_tail_binders_from_expr(obj, list_param_name, tails),
201 Expr::BinOp(_, left, right) => {
202 collect_list_tail_binders_from_expr(left, list_param_name, tails);
203 collect_list_tail_binders_from_expr(right, list_param_name, tails);
204 }
205 Expr::Constructor(_, inner) => {
206 if let Some(inner) = inner {
207 collect_list_tail_binders_from_expr(inner, list_param_name, tails);
208 }
209 }
210 Expr::ErrorProp(inner) => {
211 collect_list_tail_binders_from_expr(inner, list_param_name, tails)
212 }
213 Expr::InterpolatedStr(parts) => {
214 for p in parts {
215 if let crate::ast::StrPart::Parsed(e) = p {
216 collect_list_tail_binders_from_expr(e, list_param_name, tails);
217 }
218 }
219 }
220 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
221 for item in items {
222 collect_list_tail_binders_from_expr(item, list_param_name, tails);
223 }
224 }
225 Expr::MapLiteral(entries) => {
226 for (k, v) in entries {
227 collect_list_tail_binders_from_expr(k, list_param_name, tails);
228 collect_list_tail_binders_from_expr(v, list_param_name, tails);
229 }
230 }
231 Expr::RecordCreate { fields, .. } => {
232 for (_, v) in fields {
233 collect_list_tail_binders_from_expr(v, list_param_name, tails);
234 }
235 }
236 Expr::RecordUpdate { base, updates, .. } => {
237 collect_list_tail_binders_from_expr(base, list_param_name, tails);
238 for (_, v) in updates {
239 collect_list_tail_binders_from_expr(v, list_param_name, tails);
240 }
241 }
242 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => {}
243 }
244}
245
246pub(crate) fn collect_list_tail_binders(fd: &FnDef, list_param_name: &str) -> HashSet<String> {
247 let mut tails = HashSet::new();
248 for stmt in fd.body.stmts() {
249 match stmt {
250 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
251 collect_list_tail_binders_from_expr(expr, list_param_name, &mut tails)
252 }
253 }
254 }
255 tails
256}
257
258pub(crate) fn recursive_constructor_binders(
259 td: &TypeDef,
260 variant_name: &str,
261 binders: &[String],
262) -> Vec<String> {
263 let variant_short = variant_name.rsplit('.').next().unwrap_or(variant_name);
264 match td {
265 TypeDef::Sum { name, variants, .. } => variants
266 .iter()
267 .find(|variant| variant.name == variant_short)
268 .map(|variant| {
269 variant
270 .fields
271 .iter()
272 .zip(binders.iter())
273 .filter_map(|(field_ty, binder)| {
274 (field_ty.trim() == name).then_some(binder.clone())
275 })
276 .collect()
277 })
278 .unwrap_or_default(),
279 TypeDef::Product { .. } => Vec::new(),
280 }
281}
282
283pub(crate) fn grow_recursive_subterm_binders_from_expr(
284 expr: &Spanned<Expr>,
285 tracked: &HashSet<String>,
286 td: &TypeDef,
287 out: &mut HashSet<String>,
288) {
289 match &expr.node {
290 Expr::Match { subject, arms, .. } => {
291 if let Expr::Ident(subject_name) = &subject.node
292 && tracked.contains(subject_name)
293 {
294 for arm in arms {
295 if let Pattern::Constructor(variant_name, binders) = &arm.pattern {
296 out.extend(recursive_constructor_binders(td, variant_name, binders));
297 }
298 }
299 }
300 grow_recursive_subterm_binders_from_expr(subject, tracked, td, out);
301 for arm in arms {
302 grow_recursive_subterm_binders_from_expr(&arm.body, tracked, td, out);
303 }
304 }
305 Expr::FnCall(callee, args) => {
306 grow_recursive_subterm_binders_from_expr(callee, tracked, td, out);
307 for arg in args {
308 grow_recursive_subterm_binders_from_expr(arg, tracked, td, out);
309 }
310 }
311 Expr::Attr(obj, _) => grow_recursive_subterm_binders_from_expr(obj, tracked, td, out),
312 Expr::BinOp(_, left, right) => {
313 grow_recursive_subterm_binders_from_expr(left, tracked, td, out);
314 grow_recursive_subterm_binders_from_expr(right, tracked, td, out);
315 }
316 Expr::Constructor(_, Some(inner)) | Expr::ErrorProp(inner) => {
317 grow_recursive_subterm_binders_from_expr(inner, tracked, td, out)
318 }
319 Expr::InterpolatedStr(parts) => {
320 for part in parts {
321 if let crate::ast::StrPart::Parsed(inner) = part {
322 grow_recursive_subterm_binders_from_expr(inner, tracked, td, out);
323 }
324 }
325 }
326 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
327 for item in items {
328 grow_recursive_subterm_binders_from_expr(item, tracked, td, out);
329 }
330 }
331 Expr::MapLiteral(entries) => {
332 for (k, v) in entries {
333 grow_recursive_subterm_binders_from_expr(k, tracked, td, out);
334 grow_recursive_subterm_binders_from_expr(v, tracked, td, out);
335 }
336 }
337 Expr::RecordCreate { fields, .. } => {
338 for (_, v) in fields {
339 grow_recursive_subterm_binders_from_expr(v, tracked, td, out);
340 }
341 }
342 Expr::RecordUpdate { base, updates, .. } => {
343 grow_recursive_subterm_binders_from_expr(base, tracked, td, out);
344 for (_, v) in updates {
345 grow_recursive_subterm_binders_from_expr(v, tracked, td, out);
346 }
347 }
348 Expr::TailCall(boxed) => {
349 for arg in &boxed.args {
350 grow_recursive_subterm_binders_from_expr(arg, tracked, td, out);
351 }
352 }
353 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
354 }
355}
356
357pub(crate) fn collect_recursive_subterm_binders(
358 fd: &FnDef,
359 param_name: &str,
360 param_type: &str,
361 ctx: &CodegenContext,
362) -> HashSet<String> {
363 let Some(td) = find_type_def(ctx, param_type) else {
364 return HashSet::new();
365 };
366 let mut tracked: HashSet<String> = HashSet::from([param_name.to_string()]);
367 loop {
368 let mut discovered = HashSet::new();
369 for stmt in fd.body.stmts() {
370 match stmt {
371 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
372 grow_recursive_subterm_binders_from_expr(expr, &tracked, td, &mut discovered);
373 }
374 }
375 }
376 let before = tracked.len();
377 tracked.extend(discovered);
378 if tracked.len() == before {
379 break;
380 }
381 }
382 tracked.remove(param_name);
383 tracked
384}
385
386pub(crate) fn single_int_countdown_param_index(fd: &FnDef) -> Option<usize> {
387 let recursive_calls: Vec<Vec<&Spanned<Expr>>> = collect_calls_from_body(fd.body.as_ref())
388 .into_iter()
389 .filter(|(name, _)| call_matches(name, &fd.name))
390 .map(|(_, args)| args)
391 .collect();
392 if recursive_calls.is_empty() {
393 return None;
394 }
395
396 fd.params
397 .iter()
398 .enumerate()
399 .find_map(|(idx, (param_name, param_ty))| {
400 if param_ty != "Int" {
401 return None;
402 }
403 let countdown_ok = recursive_calls.iter().all(|args| {
404 args.get(idx)
405 .cloned()
406 .is_some_and(|arg| is_int_minus_positive(arg, param_name))
407 });
408 if countdown_ok {
409 return Some(idx);
410 }
411
412 let ascent_ok = recursive_calls.iter().all(|args| {
415 args.get(idx)
416 .copied()
417 .is_some_and(|arg| is_int_plus_positive(arg, param_name))
418 });
419 (ascent_ok && has_negative_guarded_ascent(fd, param_name)).then_some(idx)
420 })
421}
422
423pub(crate) fn has_negative_guarded_ascent(fd: &FnDef, param_name: &str) -> bool {
424 let Some(tail) = fd.body.tail_expr() else {
425 return false;
426 };
427 let Expr::Match { subject, arms, .. } = &tail.node else {
428 return false;
429 };
430 let Expr::BinOp(BinOp::Lt, left, right) = &subject.node else {
431 return false;
432 };
433 if !is_ident(left, param_name)
434 || !matches!(&right.node, Expr::Literal(crate::ast::Literal::Int(0)))
435 {
436 return false;
437 }
438
439 let mut true_arm = None;
440 let mut false_arm = None;
441 for arm in arms {
442 match arm.pattern {
443 Pattern::Literal(crate::ast::Literal::Bool(true)) => true_arm = Some(arm.body.as_ref()),
444 Pattern::Literal(crate::ast::Literal::Bool(false)) => {
445 false_arm = Some(arm.body.as_ref())
446 }
447 _ => return false,
448 }
449 }
450
451 let Some(true_arm) = true_arm else {
452 return false;
453 };
454 let Some(false_arm) = false_arm else {
455 return false;
456 };
457
458 let mut true_calls = Vec::new();
459 collect_calls_from_expr(true_arm, &mut true_calls);
460 let mut false_calls = Vec::new();
461 collect_calls_from_expr(false_arm, &mut false_calls);
462
463 true_calls
464 .iter()
465 .any(|(name, _)| call_matches(name, &fd.name))
466 && false_calls
467 .iter()
468 .all(|(name, _)| !call_matches(name, &fd.name))
469}
470
471pub(crate) fn single_int_ascending_param(fd: &FnDef) -> Option<(usize, Spanned<Expr>)> {
474 let recursive_calls: Vec<Vec<&Spanned<Expr>>> = collect_calls_from_body(fd.body.as_ref())
475 .into_iter()
476 .filter(|(name, _)| call_matches(name, &fd.name))
477 .map(|(_, args)| args)
478 .collect();
479 if recursive_calls.is_empty() {
480 return None;
481 }
482
483 for (idx, (param_name, param_ty)) in fd.params.iter().enumerate() {
484 if param_ty != "Int" {
485 continue;
486 }
487 let ascent_ok = recursive_calls.iter().all(|args| {
488 args.get(idx)
489 .cloned()
490 .is_some_and(|arg| is_int_plus_positive(arg, param_name))
491 });
492 if !ascent_ok {
493 continue;
494 }
495 if let Some(bound) = extract_equality_bound_expr(fd, param_name) {
496 return Some((idx, bound));
497 }
498 }
499 None
500}
501
502pub(crate) fn extract_equality_bound_expr(fd: &FnDef, param_name: &str) -> Option<Spanned<Expr>> {
506 let tail = fd.body.tail_expr()?;
507 let Expr::Match { subject, arms, .. } = &tail.node else {
508 return None;
509 };
510 let Expr::BinOp(BinOp::Eq, left, right) = &subject.node else {
511 return None;
512 };
513 if !is_ident(left, param_name) {
514 return None;
515 }
516 let mut true_has_self = false;
518 let mut false_has_self = false;
519 for arm in arms {
520 match arm.pattern {
521 Pattern::Literal(crate::ast::Literal::Bool(true)) => {
522 let mut calls = Vec::new();
523 collect_calls_from_expr(&arm.body, &mut calls);
524 true_has_self = calls.iter().any(|(n, _)| call_matches(n, &fd.name));
525 }
526 Pattern::Literal(crate::ast::Literal::Bool(false)) => {
527 let mut calls = Vec::new();
528 collect_calls_from_expr(&arm.body, &mut calls);
529 false_has_self = calls.iter().any(|(n, _)| call_matches(n, &fd.name));
530 }
531 _ => return None,
532 }
533 }
534 if true_has_self || !false_has_self {
535 return None;
536 }
537 Some((**right).clone())
538}
539
540pub(crate) fn supports_single_sizeof_structural(fd: &FnDef, ctx: &CodegenContext) -> bool {
541 let recursive_calls: Vec<Vec<&Spanned<Expr>>> = collect_calls_from_body(fd.body.as_ref())
542 .into_iter()
543 .filter(|(name, _)| call_matches(name, &fd.name))
544 .map(|(_, args)| args)
545 .collect();
546 if recursive_calls.is_empty() {
547 return false;
548 }
549
550 let metric_indices = sizeof_measure_param_indices(fd);
551 if metric_indices.is_empty() {
552 return false;
553 }
554
555 let binder_sets: HashMap<usize, HashSet<String>> = metric_indices
556 .iter()
557 .filter_map(|idx| {
558 let (param_name, param_type) = fd.params.get(*idx)?;
559 recursive_type_names(ctx).contains(param_type).then(|| {
560 (
561 *idx,
562 collect_recursive_subterm_binders(fd, param_name, param_type, ctx),
563 )
564 })
565 })
566 .collect();
567
568 if binder_sets.values().all(HashSet::is_empty) {
569 return false;
570 }
571
572 recursive_calls.iter().all(|args| {
573 let mut strictly_smaller = false;
574 for idx in &metric_indices {
575 let Some((param_name, _)) = fd.params.get(*idx) else {
576 return false;
577 };
578 let Some(arg) = args.get(*idx).cloned() else {
579 return false;
580 };
581 if is_ident(arg, param_name) {
582 continue;
583 }
584 let Some(binders) = binder_sets.get(idx) else {
585 return false;
586 };
587 if matches!(&arg.node, Expr::Ident(id) if binders.contains(id)) {
588 strictly_smaller = true;
589 continue;
590 }
591 return false;
592 }
593 strictly_smaller
594 })
595}
596
597pub(crate) fn single_list_structural_param_index(fd: &FnDef) -> Option<usize> {
598 fd.params
599 .iter()
600 .enumerate()
601 .find_map(|(param_index, (param_name, param_ty))| {
602 if !(param_ty.starts_with("List<") || param_ty == "List") {
603 return None;
604 }
605
606 let tails = collect_list_tail_binders(fd, param_name);
607 if tails.is_empty() {
608 return None;
609 }
610
611 let recursive_calls: Vec<Option<&Spanned<Expr>>> =
612 collect_calls_from_body(fd.body.as_ref())
613 .into_iter()
614 .filter(|(name, _)| call_matches(name, &fd.name))
615 .map(|(_, args)| args.get(param_index).cloned())
616 .collect();
617 if recursive_calls.is_empty() {
618 return None;
619 }
620
621 recursive_calls
622 .into_iter()
623 .all(|arg| {
624 arg.is_some_and(|a| matches!(&a.node, Expr::Ident(id) if tails.contains(id)))
625 })
626 .then_some(param_index)
627 })
628}
629
630pub(crate) fn is_ident(expr: &Spanned<Expr>, name: &str) -> bool {
631 matches!(&expr.node, Expr::Ident(id) if id == name)
632}
633
634pub(crate) fn is_int_plus_positive(expr: &Spanned<Expr>, param_name: &str) -> bool {
635 match &expr.node {
636 Expr::BinOp(BinOp::Add, left, right) => {
637 matches!(&left.node, Expr::Ident(id) if id == param_name)
638 && matches!(&right.node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
639 }
640 Expr::FnCall(callee, args) => {
641 let Some(name) = expr_to_dotted_name(callee) else {
642 return false;
643 };
644 (name == "Int.add" || name == "int.add")
645 && args.len() == 2
646 && matches!(&args[0].node, Expr::Ident(id) if id == param_name)
647 && matches!(&args[1].node, Expr::Literal(crate::ast::Literal::Int(n)) if *n >= 1)
648 }
649 _ => false,
650 }
651}
652
653pub(crate) fn is_skip_ws_advance(
654 expr: &Spanned<Expr>,
655 string_param: &str,
656 pos_param: &str,
657) -> bool {
658 let Expr::FnCall(callee, args) = &expr.node else {
659 return false;
660 };
661 let Some(name) = expr_to_dotted_name(callee) else {
662 return false;
663 };
664 if !call_matches(&name, "skipWs") || args.len() != 2 {
665 return false;
666 }
667 is_ident(&args[0], string_param) && is_int_plus_positive(&args[1], pos_param)
668}
669
670pub(crate) fn is_skip_ws_same(expr: &Spanned<Expr>, string_param: &str, pos_param: &str) -> bool {
671 let Expr::FnCall(callee, args) = &expr.node else {
672 return false;
673 };
674 let Some(name) = expr_to_dotted_name(callee) else {
675 return false;
676 };
677 if !call_matches(&name, "skipWs") || args.len() != 2 {
678 return false;
679 }
680 is_ident(&args[0], string_param) && is_ident(&args[1], pos_param)
681}
682
683pub(crate) fn is_string_pos_advance(
684 expr: &Spanned<Expr>,
685 string_param: &str,
686 pos_param: &str,
687) -> bool {
688 is_int_plus_positive(expr, pos_param) || is_skip_ws_advance(expr, string_param, pos_param)
689}
690
/// How a call in a String+position recursion treats the position argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum StringPosEdge {
    /// The position is passed through without advancing.
    Same,
    /// The position is (assumed to be) advanced.
    Advance,
}
696
697pub(crate) fn classify_string_pos_edge(
698 expr: &Spanned<Expr>,
699 string_param: &str,
700 pos_param: &str,
701) -> Option<StringPosEdge> {
702 if is_ident(expr, pos_param) || is_skip_ws_same(expr, string_param, pos_param) {
703 return Some(StringPosEdge::Same);
704 }
705 if is_string_pos_advance(expr, string_param, pos_param) {
706 return Some(StringPosEdge::Advance);
707 }
708 if let Expr::FnCall(callee, args) = &expr.node {
709 let name = expr_to_dotted_name(callee)?;
710 if call_matches(&name, "skipWs")
711 && args.len() == 2
712 && is_ident(&args[0], string_param)
713 && matches!(&args[1].node, Expr::Ident(id) if id != pos_param)
714 {
715 return Some(StringPosEdge::Advance);
716 }
717 }
718 if matches!(&expr.node, Expr::Ident(id) if id != pos_param) {
719 return Some(StringPosEdge::Advance);
720 }
721 None
722}
723
/// Assigns each function in `names` a rank so that every "same position"
/// edge goes from a strictly higher rank to a strictly lower one.
///
/// `same_edges[caller]` holds the callees that `caller` reaches without
/// consuming any measure. The nodes are ordered topologically (Kahn's
/// algorithm) and ranked `n, n-1, ..., 1` in that order. Ready nodes are kept
/// in a sorted stack, so the result is deterministic regardless of
/// `HashMap`/`HashSet` iteration order.
///
/// Returns `None` when an edge targets a name outside `names`, or when the
/// edges contain a cycle (including a self-edge), since then no strict
/// ranking exists.
pub(crate) fn ranks_from_same_edges(
    names: &HashSet<String>,
    same_edges: &HashMap<String, HashSet<String>>,
) -> Option<HashMap<String, usize>> {
    let mut indegree: HashMap<String, usize> = names.iter().map(|n| (n.clone(), 0)).collect();
    for to in same_edges.values().flatten() {
        // An edge into an unknown function cannot be ranked.
        *indegree.get_mut(to)? += 1;
    }

    let mut ready: Vec<String> = indegree
        .iter()
        .filter(|&(_, &deg)| deg == 0)
        .map(|(name, _)| name.clone())
        .collect();
    ready.sort();

    let mut topo = Vec::with_capacity(names.len());
    while let Some(node) = ready.pop() {
        let mut newly_zero = Vec::new();
        // Borrow the successor set rather than cloning it for every node.
        if let Some(outs) = same_edges.get(&node) {
            for to in outs {
                let deg = indegree.get_mut(to)?;
                *deg -= 1;
                if *deg == 0 {
                    newly_zero.push(to.clone());
                }
            }
        }
        topo.push(node);
        newly_zero.sort();
        ready.extend(newly_zero);
    }

    // Nodes left unprocessed sit on (or behind) a cycle.
    if topo.len() != names.len() {
        return None;
    }

    let n = topo.len();
    Some(
        topo.into_iter()
            .enumerate()
            .map(|(idx, name)| (name, n - idx))
            .collect(),
    )
}
774
775pub(crate) fn supports_single_string_pos_advance(fd: &FnDef) -> bool {
776 let Some((string_param, string_ty)) = fd.params.first() else {
777 return false;
778 };
779 let Some((pos_param, pos_ty)) = fd.params.get(1) else {
780 return false;
781 };
782 if string_ty != "String" || pos_ty != "Int" {
783 return false;
784 }
785
786 type CallPair<'a> = (Option<&'a Spanned<Expr>>, Option<&'a Spanned<Expr>>);
787 let recursive_calls: Vec<CallPair<'_>> = collect_calls_from_body(fd.body.as_ref())
788 .into_iter()
789 .filter(|(name, _)| call_matches(name, &fd.name))
790 .map(|(_, args)| (args.first().cloned(), args.get(1).cloned()))
791 .collect();
792 if recursive_calls.is_empty() {
793 return false;
794 }
795
796 recursive_calls.into_iter().all(|(arg0, arg1)| {
797 arg0.is_some_and(|e| is_ident(e, string_param))
798 && arg1.is_some_and(|e| is_string_pos_advance(e, string_param, pos_param))
799 })
800}
801
802pub(crate) fn supports_mutual_int_countdown(component: &[&FnDef]) -> bool {
803 if component.len() < 2 {
804 return false;
805 }
806 if component
807 .iter()
808 .any(|fd| !matches!(fd.params.first(), Some((_, t)) if t == "Int"))
809 {
810 return false;
811 }
812 let names: HashSet<String> = component.iter().map(|fd| fd.name.clone()).collect();
813 let mut any_intra = false;
814 for fd in component {
815 let param_name = &fd.params[0].0;
816 for (callee, args) in collect_calls_from_body(fd.body.as_ref()) {
817 if !call_is_in_set(&callee, &names) {
818 continue;
819 }
820 any_intra = true;
821 let Some(arg0) = args.first().cloned() else {
822 return false;
823 };
824 if !is_int_minus_positive(arg0, param_name) {
825 return false;
826 }
827 }
828 }
829 any_intra
830}
831
832pub(crate) fn supports_mutual_string_pos_advance(
833 component: &[&FnDef],
834) -> Option<HashMap<String, usize>> {
835 if component.len() < 2 {
836 return None;
837 }
838 if component.iter().any(|fd| {
839 !matches!(fd.params.first(), Some((_, t)) if t == "String")
840 || !matches!(fd.params.get(1), Some((_, t)) if t == "Int")
841 }) {
842 return None;
843 }
844
845 let names: HashSet<String> = component.iter().map(|fd| fd.name.clone()).collect();
846 let mut same_edges: HashMap<String, HashSet<String>> =
847 names.iter().map(|n| (n.clone(), HashSet::new())).collect();
848 let mut any_intra = false;
849
850 for fd in component {
851 let string_param = &fd.params[0].0;
852 let pos_param = &fd.params[1].0;
853 for (callee_raw, args) in collect_calls_from_body(fd.body.as_ref()) {
854 let Some(callee) = canonical_callee_name(&callee_raw, &names) else {
855 continue;
856 };
857 any_intra = true;
858
859 let arg0 = args.first().cloned()?;
860 let arg1 = args.get(1).cloned()?;
861
862 if !is_ident(arg0, string_param) {
863 return None;
864 }
865
866 match classify_string_pos_edge(arg1, string_param, pos_param) {
867 Some(StringPosEdge::Same) => {
868 if let Some(edges) = same_edges.get_mut(&fd.name) {
869 edges.insert(callee);
870 } else {
871 return None;
872 }
873 }
874 Some(StringPosEdge::Advance) => {}
875 None => return None,
876 }
877 }
878 }
879
880 if !any_intra {
881 return None;
882 }
883
884 ranks_from_same_edges(&names, &same_edges)
885}
886
/// Returns `true` for the built-in scalar-like type names (exact,
/// case-sensitive match).
pub(crate) fn is_scalar_like_type(type_name: &str) -> bool {
    const SCALARS: [&str; 7] = ["Int", "Float", "Bool", "String", "Char", "Byte", "Unit"];
    SCALARS.iter().any(|&scalar| scalar == type_name)
}
893
894pub(crate) fn supports_mutual_sizeof_ranked(
895 component: &[&FnDef],
896) -> Option<HashMap<String, usize>> {
897 if component.len() < 2 {
898 return None;
899 }
900 let names: HashSet<String> = component.iter().map(|fd| fd.name.clone()).collect();
901 let metric_indices: HashMap<String, Vec<usize>> = component
902 .iter()
903 .map(|fd| (fd.name.clone(), sizeof_measure_param_indices(fd)))
904 .collect();
905 if component.iter().any(|fd| {
906 metric_indices
907 .get(&fd.name)
908 .is_none_or(|indices| indices.is_empty())
909 }) {
910 return None;
911 }
912
913 let mut same_edges: HashMap<String, HashSet<String>> =
914 names.iter().map(|n| (n.clone(), HashSet::new())).collect();
915 let mut any_intra = false;
916 for fd in component {
917 let caller_metric_indices = metric_indices.get(&fd.name)?;
918 let caller_metric_params: Vec<&str> = caller_metric_indices
919 .iter()
920 .filter_map(|idx| fd.params.get(*idx).map(|(name, _)| name.as_str()))
921 .collect();
922 for (callee_raw, args) in collect_calls_from_body(fd.body.as_ref()) {
923 let Some(callee) = canonical_callee_name(&callee_raw, &names) else {
924 continue;
925 };
926 any_intra = true;
927 let callee_metric_indices = metric_indices.get(&callee)?;
928 let is_same_edge = callee_metric_indices.len() == caller_metric_params.len()
929 && callee_metric_indices
930 .iter()
931 .enumerate()
932 .all(|(pos, callee_idx)| {
933 let Some(arg) = args.get(*callee_idx).cloned() else {
934 return false;
935 };
936 is_ident(arg, caller_metric_params[pos])
937 });
938 if is_same_edge {
939 if let Some(edges) = same_edges.get_mut(&fd.name) {
940 edges.insert(callee);
941 } else {
942 return None;
943 }
944 }
945 }
946 }
947 if !any_intra {
948 return None;
949 }
950
951 let ranks = ranks_from_same_edges(&names, &same_edges)?;
952 let mut out = HashMap::new();
953 for fd in component {
954 let rank = ranks.get(&fd.name).cloned()?;
955 out.insert(fd.name.clone(), rank);
956 }
957 Some(out)
958}
959
960pub fn analyze_plans(
964 ctx: &CodegenContext,
965) -> (HashMap<String, RecursionPlan>, Vec<ProofModeIssue>) {
966 let mut plans = HashMap::new();
967 let mut issues = Vec::new();
968
969 let all_pure = pure_fns(ctx);
970 let recursive_names = recursive_pure_fn_names(ctx);
971 let components = call_graph::ordered_fn_components(&all_pure, &ctx.module_prefixes);
972
973 for component in components {
974 if component.is_empty() {
975 continue;
976 }
977 let is_recursive_component =
978 component.len() > 1 || recursive_names.contains(&component[0].name);
979 if !is_recursive_component {
980 continue;
981 }
982
983 if component.len() > 1 {
984 if supports_mutual_int_countdown(&component) {
985 for fd in &component {
986 plans.insert(fd.name.clone(), RecursionPlan::MutualIntCountdown);
987 }
988 } else if let Some(ranks) = supports_mutual_string_pos_advance(&component) {
989 for fd in &component {
990 if let Some(rank) = ranks.get(&fd.name).cloned() {
991 plans.insert(
992 fd.name.clone(),
993 RecursionPlan::MutualStringPosAdvance { rank },
994 );
995 }
996 }
997 } else if let Some(rankings) = supports_mutual_sizeof_ranked(&component) {
998 for fd in &component {
999 if let Some(rank) = rankings.get(&fd.name).cloned() {
1000 plans.insert(fd.name.clone(), RecursionPlan::MutualSizeOfRanked { rank });
1001 }
1002 }
1003 } else {
1004 let names = component
1005 .iter()
1006 .map(|fd| fd.name.clone())
1007 .collect::<Vec<_>>()
1008 .join(", ");
1009 let line = component.iter().map(|fd| fd.line).min().unwrap_or(1);
1010 issues.push(ProofModeIssue {
1011 line,
1012 message: format!(
1013 "unsupported mutual recursion group (currently supported in proof mode: Int countdown on first param): {}",
1014 names
1015 ),
1016 });
1017 }
1018 continue;
1019 }
1020
1021 let fd = component[0];
1022 if crate::codegen::lean::recurrence::detect_second_order_int_linear_recurrence(fd).is_some()
1023 {
1024 plans.insert(fd.name.clone(), RecursionPlan::LinearRecurrence2);
1025 } else if let Some((param_index, bound)) = single_int_ascending_param(fd) {
1026 plans.insert(
1027 fd.name.clone(),
1028 RecursionPlan::IntAscending { param_index, bound },
1029 );
1030 } else if let Some(param_index) = single_int_countdown_param_index(fd) {
1031 plans.insert(fd.name.clone(), RecursionPlan::IntCountdown { param_index });
1032 } else if supports_single_sizeof_structural(fd, ctx) {
1033 plans.insert(fd.name.clone(), RecursionPlan::SizeOfStructural);
1034 } else if let Some(param_index) = single_list_structural_param_index(fd) {
1035 plans.insert(
1036 fd.name.clone(),
1037 RecursionPlan::ListStructural { param_index },
1038 );
1039 } else if supports_single_string_pos_advance(fd) {
1040 plans.insert(fd.name.clone(), RecursionPlan::StringPosAdvance);
1041 } else {
1042 issues.push(ProofModeIssue {
1043 line: fd.line,
1044 message: format!(
1045 "recursive function '{}' is outside proof subset (currently supported: Int countdown, second-order affine Int recurrences with pair-state worker, structural recursion on List/recursive ADTs, String+position, mutual Int countdown, mutual String+position, and ranked sizeOf recursion)",
1046 fd.name
1047 ),
1048 });
1049 }
1050 }
1051
1052 (plans, issues)
1053}