1use std::collections::{HashMap, HashSet};
2
3use logicaffeine_base::Symbol;
4use logicaffeine_language::ast::{Expr, Stmt};
5use logicaffeine_language::ast::stmt::ClosureBody;
6
7use super::callgraph::CallGraph;
8use super::types::{LogosType, TypeEnv};
9
/// Result of the readonly-parameter analysis produced by [`ReadonlyParams::analyze`].
pub struct ReadonlyParams {
    /// Maps each function symbol to the subset of its `Seq`-typed parameters that the
    /// analysis found are never mutated, reassigned, moved into a `let mut` local, or
    /// passed (as a bare identifier) to a callee parameter that is itself not readonly.
    pub readonly: HashMap<Symbol, HashSet<Symbol>>,
}
22
23impl ReadonlyParams {
24 pub fn analyze(stmts: &[Stmt<'_>], callgraph: &CallGraph, type_env: &TypeEnv) -> Self {
34 let mut fn_params: HashMap<Symbol, Vec<Symbol>> = HashMap::new();
36 for stmt in stmts {
37 if let Stmt::FunctionDef { name, params, .. } = stmt {
38 let syms: Vec<Symbol> = params.iter().map(|(s, _)| *s).collect();
39 fn_params.insert(*name, syms);
40 }
41 }
42
43 let mut readonly: HashMap<Symbol, HashSet<Symbol>> = HashMap::new();
45 for stmt in stmts {
46 if let Stmt::FunctionDef { name, params, .. } = stmt {
47 let mut candidates = HashSet::new();
48 for (sym, _) in params {
49 if is_seq_type(type_env.lookup(*sym)) {
50 candidates.insert(*sym);
51 }
52 }
53 readonly.insert(*name, candidates);
54 }
55 }
56
57 for stmt in stmts {
59 if let Stmt::FunctionDef { name, params, body, is_native, .. } = stmt {
60 if *is_native {
61 continue;
62 }
63 let param_set: HashSet<Symbol> = params.iter().map(|(s, _)| *s).collect();
64 let mutated = collect_direct_mutations(body, ¶m_set);
65 if let Some(candidates) = readonly.get_mut(name) {
66 for sym in &mutated {
67 candidates.remove(sym);
68 }
69 }
70 }
71 }
72
73 loop {
75 let mut changed = false;
76
77 for stmt in stmts {
78 if let Stmt::FunctionDef { name: caller, body, is_native, .. } = stmt {
79 if *is_native {
80 continue;
81 }
82
83 let call_sites = collect_call_sites(body);
85
86 for (callee, arg_syms) in &call_sites {
87 let callee_params = match fn_params.get(callee) {
88 Some(p) => p,
89 None => continue, };
91
92 for (i, maybe_arg_sym) in arg_syms.iter().enumerate() {
93 let arg_sym = match maybe_arg_sym {
94 Some(s) => s,
95 None => continue, };
97
98 let callee_param = match callee_params.get(i) {
99 Some(p) => p,
100 None => continue,
101 };
102
103 let callee_param_readonly = readonly
105 .get(callee)
106 .map(|s| s.contains(callee_param))
107 .unwrap_or(true); if !callee_param_readonly {
110 if let Some(caller_readonly) = readonly.get_mut(caller) {
112 if caller_readonly.remove(arg_sym) {
113 changed = true;
114 }
115 }
116 }
117 }
118 }
119 }
120 }
121
122 if !changed {
123 break;
124 }
125 }
126
127 Self { readonly }
128 }
129
130 pub fn is_readonly(&self, fn_sym: Symbol, param_sym: Symbol) -> bool {
132 self.readonly
133 .get(&fn_sym)
134 .map(|s| s.contains(¶m_sym))
135 .unwrap_or(false)
136 }
137}
138
139fn is_seq_type(ty: &LogosType) -> bool {
140 matches!(ty, LogosType::Seq(_))
141}
142
143fn collect_direct_mutations(stmts: &[Stmt<'_>], param_set: &HashSet<Symbol>) -> HashSet<Symbol> {
158 let mut mutated = HashSet::new();
159 for stmt in stmts {
160 collect_mutations_from_stmt(stmt, param_set, &mut mutated);
161 }
162 collect_consumed_params(stmts, param_set, &mut mutated);
167 mutated
168}
169
170fn collect_mutations_from_stmt(stmt: &Stmt<'_>, param_set: &HashSet<Symbol>, mutated: &mut HashSet<Symbol>) {
171 match stmt {
172 Stmt::Push { collection, .. } => {
173 if let Expr::Identifier(sym) = **collection {
174 if param_set.contains(&sym) {
175 mutated.insert(sym);
176 }
177 }
178 }
179 Stmt::Pop { collection, .. } => {
180 if let Expr::Identifier(sym) = **collection {
181 if param_set.contains(&sym) {
182 mutated.insert(sym);
183 }
184 }
185 }
186 Stmt::Add { collection, .. } => {
187 if let Expr::Identifier(sym) = **collection {
188 if param_set.contains(&sym) {
189 mutated.insert(sym);
190 }
191 }
192 }
193 Stmt::Remove { collection, .. } => {
194 if let Expr::Identifier(sym) = **collection {
195 if param_set.contains(&sym) {
196 mutated.insert(sym);
197 }
198 }
199 }
200 Stmt::SetIndex { collection, .. } => {
201 if let Expr::Identifier(sym) = **collection {
202 if param_set.contains(&sym) {
203 mutated.insert(sym);
204 }
205 }
206 }
207 Stmt::SetField { object, .. } => {
208 if let Expr::Identifier(sym) = **object {
209 if param_set.contains(&sym) {
210 mutated.insert(sym);
211 }
212 }
213 }
214 Stmt::Set { target, .. } => {
215 if param_set.contains(target) {
216 mutated.insert(*target);
217 }
218 }
219 Stmt::If { then_block, else_block, .. } => {
221 for s in *then_block {
222 collect_mutations_from_stmt(s, param_set, mutated);
223 }
224 if let Some(else_b) = else_block {
225 for s in *else_b {
226 collect_mutations_from_stmt(s, param_set, mutated);
227 }
228 }
229 }
230 Stmt::While { body, .. } => {
231 for s in *body {
232 collect_mutations_from_stmt(s, param_set, mutated);
233 }
234 }
235 Stmt::Repeat { body, .. } => {
236 for s in *body {
237 collect_mutations_from_stmt(s, param_set, mutated);
238 }
239 }
240 Stmt::Inspect { arms, .. } => {
241 for arm in arms {
242 for s in arm.body {
243 collect_mutations_from_stmt(s, param_set, mutated);
244 }
245 }
246 }
247 _ => {}
248 }
249}
250
251fn collect_consumed_params(stmts: &[Stmt<'_>], param_set: &HashSet<Symbol>, consumed: &mut HashSet<Symbol>) {
254 for stmt in stmts {
255 match stmt {
256 Stmt::Let { mutable: true, value, .. } => {
257 if let Expr::Identifier(sym) = value {
258 if param_set.contains(sym) {
259 consumed.insert(*sym);
260 }
261 }
262 }
263 Stmt::If { then_block, else_block, .. } => {
264 collect_consumed_params(then_block, param_set, consumed);
265 if let Some(else_b) = else_block {
266 collect_consumed_params(else_b, param_set, consumed);
267 }
268 }
269 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
270 collect_consumed_params(body, param_set, consumed);
271 }
272 Stmt::Inspect { arms, .. } => {
273 for arm in arms {
274 collect_consumed_params(arm.body, param_set, consumed);
275 }
276 }
277 _ => {}
278 }
279 }
280}
281
282fn collect_call_sites(stmts: &[Stmt<'_>]) -> Vec<(Symbol, Vec<Option<Symbol>>)> {
291 let mut sites = Vec::new();
292 collect_call_sites_from_stmts(stmts, &mut sites);
293 sites
294}
295
296fn collect_call_sites_from_stmts(stmts: &[Stmt<'_>], sites: &mut Vec<(Symbol, Vec<Option<Symbol>>)>) {
297 for stmt in stmts {
298 collect_call_sites_from_stmt(stmt, sites);
299 }
300}
301
302fn collect_call_sites_from_stmt(stmt: &Stmt<'_>, sites: &mut Vec<(Symbol, Vec<Option<Symbol>>)>) {
303 match stmt {
304 Stmt::Call { function, args } => {
305 let arg_syms = args.iter().map(|arg| {
306 if let Expr::Identifier(sym) = *arg { Some(*sym) } else { None }
307 }).collect();
308 sites.push((*function, arg_syms));
309 for arg in args {
310 collect_call_sites_from_expr(arg, sites);
311 }
312 }
313 Stmt::Let { value, .. } => collect_call_sites_from_expr(value, sites),
314 Stmt::Set { value, .. } => collect_call_sites_from_expr(value, sites),
315 Stmt::Return { value: Some(v) } => collect_call_sites_from_expr(v, sites),
316 Stmt::If { cond, then_block, else_block } => {
317 collect_call_sites_from_expr(cond, sites);
318 collect_call_sites_from_stmts(then_block, sites);
319 if let Some(else_b) = else_block {
320 collect_call_sites_from_stmts(else_b, sites);
321 }
322 }
323 Stmt::While { cond, body, .. } => {
324 collect_call_sites_from_expr(cond, sites);
325 collect_call_sites_from_stmts(body, sites);
326 }
327 Stmt::Repeat { iterable, body, .. } => {
328 collect_call_sites_from_expr(iterable, sites);
329 collect_call_sites_from_stmts(body, sites);
330 }
331 Stmt::Push { value, collection } => {
332 collect_call_sites_from_expr(value, sites);
333 collect_call_sites_from_expr(collection, sites);
334 }
335 Stmt::Inspect { arms, .. } => {
336 for arm in arms {
337 collect_call_sites_from_stmts(arm.body, sites);
338 }
339 }
340 Stmt::Concurrent { tasks } | Stmt::Parallel { tasks } => {
341 collect_call_sites_from_stmts(tasks, sites);
342 }
343 _ => {}
344 }
345}
346
347fn collect_call_sites_from_expr(expr: &Expr<'_>, sites: &mut Vec<(Symbol, Vec<Option<Symbol>>)>) {
348 match expr {
349 Expr::Call { function, args } => {
350 let arg_syms = args.iter().map(|arg| {
351 if let Expr::Identifier(sym) = *arg { Some(*sym) } else { None }
352 }).collect();
353 sites.push((*function, arg_syms));
354 for arg in args {
355 collect_call_sites_from_expr(arg, sites);
356 }
357 }
358 Expr::Closure { body, .. } => match body {
359 ClosureBody::Expression(e) => collect_call_sites_from_expr(e, sites),
360 ClosureBody::Block(stmts) => collect_call_sites_from_stmts(stmts, sites),
361 },
362 Expr::BinaryOp { left, right, .. } => {
363 collect_call_sites_from_expr(left, sites);
364 collect_call_sites_from_expr(right, sites);
365 }
366 Expr::Index { collection, index } => {
367 collect_call_sites_from_expr(collection, sites);
368 collect_call_sites_from_expr(index, sites);
369 }
370 Expr::Length { collection } => collect_call_sites_from_expr(collection, sites),
371 Expr::Contains { collection, value } => {
372 collect_call_sites_from_expr(collection, sites);
373 collect_call_sites_from_expr(value, sites);
374 }
375 Expr::FieldAccess { object, .. } => collect_call_sites_from_expr(object, sites),
376 Expr::Copy { expr } | Expr::Give { value: expr } => {
377 collect_call_sites_from_expr(expr, sites);
378 }
379 Expr::OptionSome { value } => collect_call_sites_from_expr(value, sites),
380 Expr::WithCapacity { value, capacity } => {
381 collect_call_sites_from_expr(value, sites);
382 collect_call_sites_from_expr(capacity, sites);
383 }
384 Expr::CallExpr { callee, args } => {
385 collect_call_sites_from_expr(callee, sites);
386 for arg in args {
387 collect_call_sites_from_expr(arg, sites);
388 }
389 }
390 _ => {}
391 }
392}
393
/// Result of the mutable-borrow parameter analysis produced by
/// [`MutableBorrowParams::analyze`].
pub struct MutableBorrowParams {
    /// Maps each qualifying function symbol to its `Seq`-typed parameters that satisfy all
    /// of the following:
    /// - they are only index-assigned (no push/pop/add/remove);
    /// - every `Return` returns them (or their single `let mut` alias);
    /// - they are passed only to callee parameters that also qualify;
    /// - every call site is compatible.
    ///
    /// Native and exported functions never appear. NOTE(review): presumably codegen uses
    /// this to pass such parameters as `&mut` — confirm at the consumer.
    pub mutable_borrow: HashMap<Symbol, HashSet<Symbol>>,
}
411
412impl MutableBorrowParams {
413 pub fn analyze(stmts: &[Stmt<'_>], callgraph: &CallGraph, type_env: &TypeEnv) -> Self {
415 let mut fn_params: HashMap<Symbol, Vec<Symbol>> = HashMap::new();
416 for stmt in stmts {
417 if let Stmt::FunctionDef { name, params, .. } = stmt {
418 let syms: Vec<Symbol> = params.iter().map(|(s, _)| *s).collect();
419 fn_params.insert(*name, syms);
420 }
421 }
422
423 let mut mutable_borrow: HashMap<Symbol, HashSet<Symbol>> = HashMap::new();
424
425 for stmt in stmts {
426 if let Stmt::FunctionDef { name, params, body, is_native, is_exported, .. } = stmt {
427 if *is_native || *is_exported {
428 continue;
429 }
430
431 let mut candidates = HashSet::new();
432
433 for (sym, _) in params {
434 if !is_seq_type(type_env.lookup(*sym)) {
435 continue;
436 }
437
438 let has_set_index = has_set_index_on(body, *sym);
439 let has_structural = has_structural_mutation_on(body, *sym);
440 let has_reassign = has_reassignment_on(body, *sym);
441 let consumed = is_consumed_param(body, *sym);
442 let returned = is_sole_return_param(body, *sym);
443
444 if has_set_index && !has_structural && !has_reassign && !consumed && returned {
445 candidates.insert(*sym);
446 } else if consumed {
447 let param_idx = params.iter().position(|(s, _)| *s == *sym).unwrap_or(usize::MAX);
451 if let Some(alias) = detect_consume_alias(body, *sym) {
452 let alias_has_set_index = has_set_index_on(body, alias);
453 let alias_has_structural = has_structural_mutation_on(body, alias);
454 let alias_returned = is_sole_return_param_or_alias(body, *sym, alias);
455 let alias_reassign_ok = reassignment_only_self_calls(body, alias, *name, param_idx);
456 let param_dead = is_param_dead_after_consume(body, *sym, alias);
457
458 if alias_has_set_index && !alias_has_structural && alias_returned && alias_reassign_ok && param_dead {
459 candidates.insert(*sym);
460 }
461 }
462 }
463 }
464
465 if !candidates.is_empty() {
466 mutable_borrow.insert(*name, candidates);
467 }
468 }
469 }
470
471 loop {
473 let mut changed = false;
474 for stmt in stmts {
475 if let Stmt::FunctionDef { name: caller, body, is_native, .. } = stmt {
476 if *is_native {
477 continue;
478 }
479 let call_sites = collect_call_sites(body);
480 for (callee, arg_syms) in &call_sites {
481 let callee_params = match fn_params.get(callee) {
482 Some(p) => p,
483 None => continue,
484 };
485 for (i, maybe_arg_sym) in arg_syms.iter().enumerate() {
486 let arg_sym = match maybe_arg_sym {
487 Some(s) => s,
488 None => continue,
489 };
490 let callee_param = match callee_params.get(i) {
491 Some(p) => p,
492 None => continue,
493 };
494 let callee_is_mut_borrow = mutable_borrow
495 .get(callee)
496 .map(|s| s.contains(callee_param))
497 .unwrap_or(false);
498 if !callee_is_mut_borrow {
499 if let Some(caller_set) = mutable_borrow.get_mut(caller) {
500 if caller_set.remove(arg_sym) {
501 changed = true;
502 }
503 }
504 }
505 }
506 }
507 }
508 }
509 if !changed {
510 break;
511 }
512 }
513
514 let incompatible = collect_incompatible_mut_borrow_callsites(
520 stmts, &mutable_borrow, &fn_params,
521 );
522 for fn_sym in incompatible {
523 mutable_borrow.remove(&fn_sym);
524 }
525
526 Self { mutable_borrow }
527 }
528
529 pub fn is_mutable_borrow(&self, fn_sym: Symbol, param_sym: Symbol) -> bool {
530 self.mutable_borrow
531 .get(&fn_sym)
532 .map(|s| s.contains(¶m_sym))
533 .unwrap_or(false)
534 }
535}
536
537fn has_set_index_on(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
538 stmts.iter().any(|s| check_set_index_stmt(s, sym))
539}
540
541fn check_set_index_stmt(stmt: &Stmt<'_>, sym: Symbol) -> bool {
542 match stmt {
543 Stmt::SetIndex { collection, .. } => {
544 matches!(**collection, Expr::Identifier(s) if s == sym)
545 }
546 Stmt::If { then_block, else_block, .. } => {
547 has_set_index_on(then_block, sym)
548 || else_block.as_ref().map_or(false, |eb| has_set_index_on(eb, sym))
549 }
550 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
551 has_set_index_on(body, sym)
552 }
553 Stmt::Inspect { arms, .. } => {
554 arms.iter().any(|arm| has_set_index_on(arm.body, sym))
555 }
556 _ => false,
557 }
558}
559
560fn has_structural_mutation_on(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
561 stmts.iter().any(|s| check_structural_stmt(s, sym))
562}
563
564fn check_structural_stmt(stmt: &Stmt<'_>, sym: Symbol) -> bool {
565 match stmt {
566 Stmt::Push { collection, .. } | Stmt::Pop { collection, .. }
567 | Stmt::Add { collection, .. } | Stmt::Remove { collection, .. } => {
568 matches!(**collection, Expr::Identifier(s) if s == sym)
569 }
570 Stmt::If { then_block, else_block, .. } => {
571 has_structural_mutation_on(then_block, sym)
572 || else_block.as_ref().map_or(false, |eb| has_structural_mutation_on(eb, sym))
573 }
574 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
575 has_structural_mutation_on(body, sym)
576 }
577 Stmt::Inspect { arms, .. } => {
578 arms.iter().any(|arm| has_structural_mutation_on(arm.body, sym))
579 }
580 _ => false,
581 }
582}
583
584fn has_reassignment_on(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
585 stmts.iter().any(|s| check_reassignment_stmt(s, sym))
586}
587
588fn check_reassignment_stmt(stmt: &Stmt<'_>, sym: Symbol) -> bool {
589 match stmt {
590 Stmt::Set { target, .. } => *target == sym,
591 Stmt::If { then_block, else_block, .. } => {
592 has_reassignment_on(then_block, sym)
593 || else_block.as_ref().map_or(false, |eb| has_reassignment_on(eb, sym))
594 }
595 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
596 has_reassignment_on(body, sym)
597 }
598 Stmt::Inspect { arms, .. } => {
599 arms.iter().any(|arm| has_reassignment_on(arm.body, sym))
600 }
601 _ => false,
602 }
603}
604
605fn is_consumed_param(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
606 for stmt in stmts {
607 match stmt {
608 Stmt::Let { mutable: true, value, .. } => {
609 if matches!(value, Expr::Identifier(s) if *s == sym) {
610 return true;
611 }
612 }
613 Stmt::If { then_block, else_block, .. } => {
614 if is_consumed_param(then_block, sym) { return true; }
615 if let Some(else_b) = else_block {
616 if is_consumed_param(else_b, sym) { return true; }
617 }
618 }
619 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
620 if is_consumed_param(body, sym) { return true; }
621 }
622 _ => {}
623 }
624 }
625 false
626}
627
628fn is_sole_return_param(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
629 let mut returns = Vec::new();
630 collect_returns(stmts, &mut returns);
631 !returns.is_empty() && returns.iter().all(|r| *r == sym)
632}
633
634fn collect_returns(stmts: &[Stmt<'_>], returns: &mut Vec<Symbol>) {
635 for stmt in stmts {
636 match stmt {
637 Stmt::Return { value: Some(expr) } => {
638 if let Expr::Identifier(sym) = expr {
639 returns.push(*sym);
640 } else {
641 returns.push(Symbol::EMPTY);
643 }
644 }
645 Stmt::If { then_block, else_block, .. } => {
646 collect_returns(then_block, returns);
647 if let Some(else_b) = else_block {
648 collect_returns(else_b, returns);
649 }
650 }
651 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
652 collect_returns(body, returns);
653 }
654 Stmt::Inspect { arms, .. } => {
655 for arm in arms {
656 collect_returns(arm.body, returns);
657 }
658 }
659 _ => {}
660 }
661 }
662}
663
664fn collect_incompatible_mut_borrow_callsites(
673 stmts: &[Stmt<'_>],
674 mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
675 fn_params: &HashMap<Symbol, Vec<Symbol>>,
676) -> HashSet<Symbol> {
677 let mut incompatible = HashSet::new();
678 for stmt in stmts {
679 if let Stmt::FunctionDef { body, .. } = stmt {
680 check_callsite_compat_stmts(body, mutable_borrow, fn_params, &mut incompatible);
681 }
682 }
683 check_callsite_compat_stmts(stmts, mutable_borrow, fn_params, &mut incompatible);
685 incompatible
686}
687
688fn check_callsite_compat_stmts(
689 stmts: &[Stmt<'_>],
690 mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
691 fn_params: &HashMap<Symbol, Vec<Symbol>>,
692 incompatible: &mut HashSet<Symbol>,
693) {
694 for stmt in stmts {
695 check_callsite_compat_stmt(stmt, mutable_borrow, fn_params, incompatible);
696 }
697}
698
699fn check_callsite_compat_stmt(
700 stmt: &Stmt<'_>,
701 mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
702 fn_params: &HashMap<Symbol, Vec<Symbol>>,
703 incompatible: &mut HashSet<Symbol>,
704) {
705 match stmt {
706 Stmt::Call { args, .. } => {
708 for arg in args {
709 check_callsite_compat_expr(arg, mutable_borrow, incompatible);
710 }
711 }
712 Stmt::Set { target, value } => {
714 if let Expr::Call { function, args } = value {
715 if mutable_borrow.contains_key(function) {
716 let mut_positions: HashSet<usize> = fn_params.get(function)
718 .map(|params| {
719 params.iter().enumerate()
720 .filter(|(_, sym)| {
721 mutable_borrow.get(function)
722 .map(|s| s.contains(sym))
723 .unwrap_or(false)
724 })
725 .map(|(i, _)| i)
726 .collect()
727 })
728 .unwrap_or_default();
729
730 let target_at_mut_pos = args.iter().enumerate()
731 .any(|(i, a)| {
732 mut_positions.contains(&i)
733 && matches!(a, Expr::Identifier(sym) if *sym == *target)
734 });
735
736 if !target_at_mut_pos {
737 incompatible.insert(*function);
738 }
739 }
740 for arg in args {
742 check_callsite_compat_expr(arg, mutable_borrow, incompatible);
743 }
744 } else {
745 check_callsite_compat_expr(value, mutable_borrow, incompatible);
746 }
747 }
748 Stmt::Let { value, .. } => {
750 check_callsite_compat_expr(value, mutable_borrow, incompatible);
751 }
752 Stmt::Return { value: Some(v) } => {
753 check_callsite_compat_expr(v, mutable_borrow, incompatible);
754 }
755 Stmt::Show { object, .. } => {
756 check_callsite_compat_expr(object, mutable_borrow, incompatible);
757 }
758 Stmt::Push { value, collection } => {
759 check_callsite_compat_expr(value, mutable_borrow, incompatible);
760 check_callsite_compat_expr(collection, mutable_borrow, incompatible);
761 }
762 Stmt::SetIndex { collection, index, value } => {
763 check_callsite_compat_expr(collection, mutable_borrow, incompatible);
764 check_callsite_compat_expr(index, mutable_borrow, incompatible);
765 check_callsite_compat_expr(value, mutable_borrow, incompatible);
766 }
767 Stmt::If { cond, then_block, else_block } => {
768 check_callsite_compat_expr(cond, mutable_borrow, incompatible);
769 check_callsite_compat_stmts(then_block, mutable_borrow, fn_params, incompatible);
770 if let Some(else_b) = else_block {
771 check_callsite_compat_stmts(else_b, mutable_borrow, fn_params, incompatible);
772 }
773 }
774 Stmt::While { cond, body, .. } => {
775 check_callsite_compat_expr(cond, mutable_borrow, incompatible);
776 check_callsite_compat_stmts(body, mutable_borrow, fn_params, incompatible);
777 }
778 Stmt::Repeat { iterable, body, .. } => {
779 check_callsite_compat_expr(iterable, mutable_borrow, incompatible);
780 check_callsite_compat_stmts(body, mutable_borrow, fn_params, incompatible);
781 }
782 Stmt::Inspect { arms, .. } => {
783 for arm in arms {
784 check_callsite_compat_stmts(arm.body, mutable_borrow, fn_params, incompatible);
785 }
786 }
787 _ => {}
789 }
790}
791
792fn check_callsite_compat_expr(
794 expr: &Expr<'_>,
795 mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
796 incompatible: &mut HashSet<Symbol>,
797) {
798 match expr {
799 Expr::Call { function, args } => {
800 if mutable_borrow.contains_key(function) {
801 incompatible.insert(*function);
802 }
803 for arg in args {
804 check_callsite_compat_expr(arg, mutable_borrow, incompatible);
805 }
806 }
807 Expr::BinaryOp { left, right, .. } => {
808 check_callsite_compat_expr(left, mutable_borrow, incompatible);
809 check_callsite_compat_expr(right, mutable_borrow, incompatible);
810 }
811 Expr::Index { collection, index } => {
812 check_callsite_compat_expr(collection, mutable_borrow, incompatible);
813 check_callsite_compat_expr(index, mutable_borrow, incompatible);
814 }
815 Expr::Length { collection } => {
816 check_callsite_compat_expr(collection, mutable_borrow, incompatible);
817 }
818 Expr::Contains { collection, value } => {
819 check_callsite_compat_expr(collection, mutable_borrow, incompatible);
820 check_callsite_compat_expr(value, mutable_borrow, incompatible);
821 }
822 Expr::FieldAccess { object, .. } => {
823 check_callsite_compat_expr(object, mutable_borrow, incompatible);
824 }
825 Expr::Copy { expr: inner } | Expr::Give { value: inner } | Expr::OptionSome { value: inner } => {
826 check_callsite_compat_expr(inner, mutable_borrow, incompatible);
827 }
828 _ => {}
829 }
830}
831
832fn detect_consume_alias(body: &[Stmt<'_>], param_sym: Symbol) -> Option<Symbol> {
839 let mut alias = None;
840 for stmt in body {
841 if let Stmt::Let { var, mutable: true, value, .. } = stmt {
842 if matches!(value, Expr::Identifier(s) if *s == param_sym) {
843 if alias.is_some() {
844 return None; }
846 alias = Some(*var);
847 }
848 }
849 }
850 alias
851}
852
853fn is_sole_return_param_or_alias(stmts: &[Stmt<'_>], param_sym: Symbol, alias_sym: Symbol) -> bool {
855 let mut returns = Vec::new();
856 collect_returns(stmts, &mut returns);
857 !returns.is_empty() && returns.iter().all(|r| *r == param_sym || *r == alias_sym)
858}
859
860fn reassignment_only_self_calls(
863 body: &[Stmt<'_>],
864 alias: Symbol,
865 func_name: Symbol,
866 param_position: usize,
867) -> bool {
868 check_reassignment_self_calls(body, alias, func_name, param_position)
869}
870
871fn check_reassignment_self_calls(
872 stmts: &[Stmt<'_>],
873 alias: Symbol,
874 func_name: Symbol,
875 param_position: usize,
876) -> bool {
877 for stmt in stmts {
878 match stmt {
879 Stmt::Set { target, value } if *target == alias => {
880 match value {
882 Expr::Call { function, args } if *function == func_name => {
883 let arg_at_pos = args.get(param_position);
884 let is_alias_at_pos = arg_at_pos
885 .map(|a| matches!(a, Expr::Identifier(s) if *s == alias))
886 .unwrap_or(false);
887 if !is_alias_at_pos {
888 return false;
889 }
890 }
891 _ => return false, }
893 }
894 Stmt::If { then_block, else_block, .. } => {
895 if !check_reassignment_self_calls(then_block, alias, func_name, param_position) {
896 return false;
897 }
898 if let Some(else_b) = else_block {
899 if !check_reassignment_self_calls(else_b, alias, func_name, param_position) {
900 return false;
901 }
902 }
903 }
904 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
905 if !check_reassignment_self_calls(body, alias, func_name, param_position) {
906 return false;
907 }
908 }
909 Stmt::Inspect { arms, .. } => {
910 for arm in arms {
911 if !check_reassignment_self_calls(arm.body, alias, func_name, param_position) {
912 return false;
913 }
914 }
915 }
916 _ => {}
917 }
918 }
919 true
920}
921
922fn is_param_dead_after_consume(body: &[Stmt<'_>], param_sym: Symbol, alias: Symbol) -> bool {
927 let mut found_consume = false;
928 for stmt in body {
929 if !found_consume {
930 if let Stmt::Let { var, mutable: true, value, .. } = stmt {
932 if *var == alias && matches!(value, Expr::Identifier(s) if *s == param_sym) {
933 found_consume = true;
934 continue;
935 }
936 }
937 } else {
938 if stmt_references_symbol(stmt, param_sym) {
940 return false;
941 }
942 }
943 }
944 found_consume }
946
947fn stmt_references_symbol(stmt: &Stmt<'_>, sym: Symbol) -> bool {
949 match stmt {
950 Stmt::Let { value, .. } => expr_references_symbol(value, sym),
951 Stmt::Set { target, value } => *target == sym || expr_references_symbol(value, sym),
952 Stmt::Call { function, args } => {
953 *function == sym || args.iter().any(|a| expr_references_symbol(a, sym))
954 }
955 Stmt::Push { value, collection } => {
956 expr_references_symbol(value, sym) || expr_references_symbol(collection, sym)
957 }
958 Stmt::Pop { collection, into } => {
959 expr_references_symbol(collection, sym)
960 || into.map_or(false, |s| s == sym)
961 }
962 Stmt::Add { value, collection } | Stmt::Remove { value, collection } => {
963 expr_references_symbol(value, sym) || expr_references_symbol(collection, sym)
964 }
965 Stmt::SetIndex { collection, index, value } => {
966 expr_references_symbol(collection, sym)
967 || expr_references_symbol(index, sym)
968 || expr_references_symbol(value, sym)
969 }
970 Stmt::SetField { object, value, .. } => {
971 expr_references_symbol(object, sym) || expr_references_symbol(value, sym)
972 }
973 Stmt::Return { value: Some(v) } => expr_references_symbol(v, sym),
974 Stmt::Return { value: None } => false,
975 Stmt::If { cond, then_block, else_block } => {
976 expr_references_symbol(cond, sym)
977 || then_block.iter().any(|s| stmt_references_symbol(s, sym))
978 || else_block.as_ref().map_or(false, |eb| eb.iter().any(|s| stmt_references_symbol(s, sym)))
979 }
980 Stmt::While { cond, body, .. } => {
981 expr_references_symbol(cond, sym)
982 || body.iter().any(|s| stmt_references_symbol(s, sym))
983 }
984 Stmt::Repeat { iterable, body, .. } => {
985 expr_references_symbol(iterable, sym)
986 || body.iter().any(|s| stmt_references_symbol(s, sym))
987 }
988 Stmt::Inspect { arms, .. } => {
989 arms.iter().any(|arm| arm.body.iter().any(|s| stmt_references_symbol(s, sym)))
990 }
991 Stmt::Show { object, .. } => expr_references_symbol(object, sym),
992 _ => false,
993 }
994}
995
996fn expr_references_symbol(expr: &Expr<'_>, sym: Symbol) -> bool {
997 match expr {
998 Expr::Identifier(s) => *s == sym,
999 Expr::BinaryOp { left, right, .. } => {
1000 expr_references_symbol(left, sym) || expr_references_symbol(right, sym)
1001 }
1002 Expr::Not { operand } => expr_references_symbol(operand, sym),
1003 Expr::Call { function, args } => {
1004 *function == sym || args.iter().any(|a| expr_references_symbol(a, sym))
1005 }
1006 Expr::Index { collection, index } => {
1007 expr_references_symbol(collection, sym) || expr_references_symbol(index, sym)
1008 }
1009 Expr::Length { collection } => expr_references_symbol(collection, sym),
1010 Expr::Contains { collection, value } => {
1011 expr_references_symbol(collection, sym) || expr_references_symbol(value, sym)
1012 }
1013 Expr::FieldAccess { object, .. } => expr_references_symbol(object, sym),
1014 Expr::Slice { collection, start, end } => {
1015 expr_references_symbol(collection, sym)
1016 || expr_references_symbol(start, sym)
1017 || expr_references_symbol(end, sym)
1018 }
1019 Expr::Copy { expr: inner } | Expr::Give { value: inner } | Expr::OptionSome { value: inner } => {
1020 expr_references_symbol(inner, sym)
1021 }
1022 Expr::CallExpr { callee, args } => {
1023 expr_references_symbol(callee, sym)
1024 || args.iter().any(|a| expr_references_symbol(a, sym))
1025 }
1026 _ => false,
1027 }
1028}