1use std::collections::{HashMap, HashSet};
2
3use logicaffeine_base::Symbol;
4use logicaffeine_language::ast::{Expr, Stmt};
5use logicaffeine_language::ast::stmt::ClosureBody;
6
7use super::callgraph::CallGraph;
8use super::types::{LogosType, TypeEnv};
9
/// Result of the read-only parameter analysis produced by `ReadonlyParams::analyze`.
pub struct ReadonlyParams {
    /// For every function definition (keyed by the function's name symbol), the set of
    /// its `Seq`-typed parameters that the analysis found to be read-only.
    pub readonly: HashMap<Symbol, HashSet<Symbol>>,
}
34
35impl ReadonlyParams {
36 pub fn analyze(stmts: &[Stmt<'_>], callgraph: &CallGraph, type_env: &TypeEnv) -> Self {
46 let mut fn_params: HashMap<Symbol, Vec<Symbol>> = HashMap::new();
48 for stmt in stmts {
49 if let Stmt::FunctionDef { name, params, .. } = stmt {
50 let syms: Vec<Symbol> = params.iter().map(|(s, _)| *s).collect();
51 fn_params.insert(*name, syms);
52 }
53 }
54
55 let mut readonly: HashMap<Symbol, HashSet<Symbol>> = HashMap::new();
57 for stmt in stmts {
58 if let Stmt::FunctionDef { name, params, .. } = stmt {
59 let mut candidates = HashSet::new();
60 for (sym, _) in params {
61 if is_seq_type(type_env.lookup(*sym)) {
62 candidates.insert(*sym);
63 }
64 }
65 readonly.insert(*name, candidates);
66 }
67 }
68
69 for stmt in stmts {
71 if let Stmt::FunctionDef { name, params, body, is_native, .. } = stmt {
72 if *is_native {
73 continue;
74 }
75 let param_set: HashSet<Symbol> = params.iter().map(|(s, _)| *s).collect();
76 let mutated = collect_direct_mutations(body, ¶m_set);
77 if let Some(candidates) = readonly.get_mut(name) {
78 for sym in &mutated {
79 candidates.remove(sym);
80 }
81 }
82 }
83 }
84
85 loop {
87 let mut changed = false;
88
89 for stmt in stmts {
90 if let Stmt::FunctionDef { name: caller, body, is_native, .. } = stmt {
91 if *is_native {
92 continue;
93 }
94
95 let call_sites = collect_call_sites(body);
97
98 for (callee, arg_syms) in &call_sites {
99 let callee_params = match fn_params.get(callee) {
100 Some(p) => p,
101 None => continue, };
103
104 for (i, maybe_arg_sym) in arg_syms.iter().enumerate() {
105 let arg_sym = match maybe_arg_sym {
106 Some(s) => s,
107 None => continue, };
109
110 let callee_param = match callee_params.get(i) {
111 Some(p) => p,
112 None => continue,
113 };
114
115 let callee_param_readonly = readonly
117 .get(callee)
118 .map(|s| s.contains(callee_param))
119 .unwrap_or(true); if !callee_param_readonly {
122 if let Some(caller_readonly) = readonly.get_mut(caller) {
124 if caller_readonly.remove(arg_sym) {
125 changed = true;
126 }
127 }
128 }
129 }
130 }
131 }
132 }
133
134 if !changed {
135 break;
136 }
137 }
138
139 Self { readonly }
140 }
141
142 pub fn is_readonly(&self, fn_sym: Symbol, param_sym: Symbol) -> bool {
144 self.readonly
145 .get(&fn_sym)
146 .map(|s| s.contains(¶m_sym))
147 .unwrap_or(false)
148 }
149}
150
151fn is_seq_type(ty: &LogosType) -> bool {
152 matches!(ty, LogosType::Seq(_))
153}
154
155fn collect_direct_mutations(stmts: &[Stmt<'_>], param_set: &HashSet<Symbol>) -> HashSet<Symbol> {
170 let mut mutated = HashSet::new();
171 for stmt in stmts {
172 collect_mutations_from_stmt(stmt, param_set, &mut mutated);
173 }
174 collect_consumed_params(stmts, param_set, &mut mutated);
179 mutated
180}
181
182fn collect_mutations_from_stmt(stmt: &Stmt<'_>, param_set: &HashSet<Symbol>, mutated: &mut HashSet<Symbol>) {
183 match stmt {
184 Stmt::Push { collection, .. } => {
185 if let Expr::Identifier(sym) = **collection {
186 if param_set.contains(&sym) {
187 mutated.insert(sym);
188 }
189 }
190 }
191 Stmt::Pop { collection, .. } => {
192 if let Expr::Identifier(sym) = **collection {
193 if param_set.contains(&sym) {
194 mutated.insert(sym);
195 }
196 }
197 }
198 Stmt::Add { collection, .. } => {
199 if let Expr::Identifier(sym) = **collection {
200 if param_set.contains(&sym) {
201 mutated.insert(sym);
202 }
203 }
204 }
205 Stmt::Remove { collection, .. } => {
206 if let Expr::Identifier(sym) = **collection {
207 if param_set.contains(&sym) {
208 mutated.insert(sym);
209 }
210 }
211 }
212 Stmt::SetIndex { collection, .. } => {
213 if let Expr::Identifier(sym) = **collection {
214 if param_set.contains(&sym) {
215 mutated.insert(sym);
216 }
217 }
218 }
219 Stmt::SetField { object, .. } => {
220 if let Expr::Identifier(sym) = **object {
221 if param_set.contains(&sym) {
222 mutated.insert(sym);
223 }
224 }
225 }
226 Stmt::Set { target, .. } => {
227 if param_set.contains(target) {
228 mutated.insert(*target);
229 }
230 }
231 Stmt::If { then_block, else_block, .. } => {
233 for s in *then_block {
234 collect_mutations_from_stmt(s, param_set, mutated);
235 }
236 if let Some(else_b) = else_block {
237 for s in *else_b {
238 collect_mutations_from_stmt(s, param_set, mutated);
239 }
240 }
241 }
242 Stmt::While { body, .. } => {
243 for s in *body {
244 collect_mutations_from_stmt(s, param_set, mutated);
245 }
246 }
247 Stmt::Repeat { body, .. } => {
248 for s in *body {
249 collect_mutations_from_stmt(s, param_set, mutated);
250 }
251 }
252 Stmt::Inspect { arms, .. } => {
253 for arm in arms {
254 for s in arm.body {
255 collect_mutations_from_stmt(s, param_set, mutated);
256 }
257 }
258 }
259 _ => {}
260 }
261}
262
263fn collect_consumed_params(stmts: &[Stmt<'_>], param_set: &HashSet<Symbol>, consumed: &mut HashSet<Symbol>) {
266 for stmt in stmts {
267 match stmt {
268 Stmt::Let { mutable: true, value, .. } => {
269 if let Expr::Identifier(sym) = value {
270 if param_set.contains(sym) {
271 consumed.insert(*sym);
272 }
273 }
274 }
275 Stmt::If { then_block, else_block, .. } => {
276 collect_consumed_params(then_block, param_set, consumed);
277 if let Some(else_b) = else_block {
278 collect_consumed_params(else_b, param_set, consumed);
279 }
280 }
281 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
282 collect_consumed_params(body, param_set, consumed);
283 }
284 Stmt::Inspect { arms, .. } => {
285 for arm in arms {
286 collect_consumed_params(arm.body, param_set, consumed);
287 }
288 }
289 _ => {}
290 }
291 }
292}
293
294fn collect_call_sites(stmts: &[Stmt<'_>]) -> Vec<(Symbol, Vec<Option<Symbol>>)> {
303 let mut sites = Vec::new();
304 collect_call_sites_from_stmts(stmts, &mut sites);
305 sites
306}
307
308fn collect_call_sites_from_stmts(stmts: &[Stmt<'_>], sites: &mut Vec<(Symbol, Vec<Option<Symbol>>)>) {
309 for stmt in stmts {
310 collect_call_sites_from_stmt(stmt, sites);
311 }
312}
313
314fn collect_call_sites_from_stmt(stmt: &Stmt<'_>, sites: &mut Vec<(Symbol, Vec<Option<Symbol>>)>) {
315 match stmt {
316 Stmt::Call { function, args } => {
317 let arg_syms = args.iter().map(|arg| {
318 if let Expr::Identifier(sym) = *arg { Some(*sym) } else { None }
319 }).collect();
320 sites.push((*function, arg_syms));
321 for arg in args {
322 collect_call_sites_from_expr(arg, sites);
323 }
324 }
325 Stmt::Let { value, .. } => collect_call_sites_from_expr(value, sites),
326 Stmt::Set { value, .. } => collect_call_sites_from_expr(value, sites),
327 Stmt::Return { value: Some(v) } => collect_call_sites_from_expr(v, sites),
328 Stmt::If { cond, then_block, else_block } => {
329 collect_call_sites_from_expr(cond, sites);
330 collect_call_sites_from_stmts(then_block, sites);
331 if let Some(else_b) = else_block {
332 collect_call_sites_from_stmts(else_b, sites);
333 }
334 }
335 Stmt::While { cond, body, .. } => {
336 collect_call_sites_from_expr(cond, sites);
337 collect_call_sites_from_stmts(body, sites);
338 }
339 Stmt::Repeat { iterable, body, .. } => {
340 collect_call_sites_from_expr(iterable, sites);
341 collect_call_sites_from_stmts(body, sites);
342 }
343 Stmt::Push { value, collection } => {
344 collect_call_sites_from_expr(value, sites);
345 collect_call_sites_from_expr(collection, sites);
346 }
347 Stmt::Inspect { arms, .. } => {
348 for arm in arms {
349 collect_call_sites_from_stmts(arm.body, sites);
350 }
351 }
352 Stmt::Concurrent { tasks } | Stmt::Parallel { tasks } => {
353 collect_call_sites_from_stmts(tasks, sites);
354 }
355 _ => {}
356 }
357}
358
359fn collect_call_sites_from_expr(expr: &Expr<'_>, sites: &mut Vec<(Symbol, Vec<Option<Symbol>>)>) {
360 match expr {
361 Expr::Call { function, args } => {
362 let arg_syms = args.iter().map(|arg| {
363 if let Expr::Identifier(sym) = *arg { Some(*sym) } else { None }
364 }).collect();
365 sites.push((*function, arg_syms));
366 for arg in args {
367 collect_call_sites_from_expr(arg, sites);
368 }
369 }
370 Expr::Closure { body, .. } => match body {
371 ClosureBody::Expression(e) => collect_call_sites_from_expr(e, sites),
372 ClosureBody::Block(stmts) => collect_call_sites_from_stmts(stmts, sites),
373 },
374 Expr::BinaryOp { left, right, .. } => {
375 collect_call_sites_from_expr(left, sites);
376 collect_call_sites_from_expr(right, sites);
377 }
378 Expr::Index { collection, index } => {
379 collect_call_sites_from_expr(collection, sites);
380 collect_call_sites_from_expr(index, sites);
381 }
382 Expr::Length { collection } => collect_call_sites_from_expr(collection, sites),
383 Expr::Contains { collection, value } => {
384 collect_call_sites_from_expr(collection, sites);
385 collect_call_sites_from_expr(value, sites);
386 }
387 Expr::FieldAccess { object, .. } => collect_call_sites_from_expr(object, sites),
388 Expr::Copy { expr } | Expr::Give { value: expr } => {
389 collect_call_sites_from_expr(expr, sites);
390 }
391 Expr::OptionSome { value } => collect_call_sites_from_expr(value, sites),
392 Expr::WithCapacity { value, capacity } => {
393 collect_call_sites_from_expr(value, sites);
394 collect_call_sites_from_expr(capacity, sites);
395 }
396 Expr::CallExpr { callee, args } => {
397 collect_call_sites_from_expr(callee, sites);
398 for arg in args {
399 collect_call_sites_from_expr(arg, sites);
400 }
401 }
402 _ => {}
403 }
404}
405
/// Result of the mutable-borrow parameter analysis produced by
/// `MutableBorrowParams::analyze`.
pub struct MutableBorrowParams {
    /// For each function (keyed by its name symbol), the `Seq`-typed parameters that the
    /// analysis allows to be passed as mutable borrows.
    ///
    /// A function with no candidates at the start of the analysis has no entry. A
    /// function that is present may still map to an empty set after pruning.
    pub mutable_borrow: HashMap<Symbol, HashSet<Symbol>>,
}
423
424impl MutableBorrowParams {
425 pub fn analyze(stmts: &[Stmt<'_>], callgraph: &CallGraph, type_env: &TypeEnv) -> Self {
427 let mut fn_params: HashMap<Symbol, Vec<Symbol>> = HashMap::new();
428 for stmt in stmts {
429 if let Stmt::FunctionDef { name, params, .. } = stmt {
430 let syms: Vec<Symbol> = params.iter().map(|(s, _)| *s).collect();
431 fn_params.insert(*name, syms);
432 }
433 }
434
435 let mut mutable_borrow: HashMap<Symbol, HashSet<Symbol>> = HashMap::new();
436
437 for stmt in stmts {
438 if let Stmt::FunctionDef { name, params, body, is_native, is_exported, .. } = stmt {
439 if *is_native || *is_exported {
440 continue;
441 }
442
443 let mut candidates = HashSet::new();
444
445 for (sym, _) in params {
446 if !is_seq_type(type_env.lookup(*sym)) {
447 continue;
448 }
449
450 let has_set_index = has_set_index_on(body, *sym);
451 let has_structural = has_structural_mutation_on(body, *sym);
452 let has_reassign = has_reassignment_on(body, *sym);
453 let consumed = is_consumed_param(body, *sym);
454 let returned = is_sole_return_param(body, *sym);
455
456 if has_set_index && !has_structural && !has_reassign && !consumed && returned {
457 candidates.insert(*sym);
458 } else if consumed {
459 let param_idx = params.iter().position(|(s, _)| *s == *sym).unwrap_or(usize::MAX);
463 if let Some(alias) = detect_consume_alias(body, *sym) {
464 let alias_has_set_index = has_set_index_on(body, alias);
465 let alias_has_structural = has_structural_mutation_on(body, alias);
466 let alias_returned = is_sole_return_param_or_alias(body, *sym, alias);
467 let alias_reassign_ok = reassignment_only_self_calls(body, alias, *name, param_idx);
468 let param_dead = is_param_dead_after_consume(body, *sym, alias);
469
470 if alias_has_set_index && !alias_has_structural && alias_returned && alias_reassign_ok && param_dead {
471 candidates.insert(*sym);
472 }
473 }
474 }
475 }
476
477 if !candidates.is_empty() {
478 mutable_borrow.insert(*name, candidates);
479 }
480 }
481 }
482
483 loop {
485 let mut changed = false;
486 for stmt in stmts {
487 if let Stmt::FunctionDef { name: caller, body, is_native, .. } = stmt {
488 if *is_native {
489 continue;
490 }
491 let call_sites = collect_call_sites(body);
492 for (callee, arg_syms) in &call_sites {
493 let callee_params = match fn_params.get(callee) {
494 Some(p) => p,
495 None => continue,
496 };
497 for (i, maybe_arg_sym) in arg_syms.iter().enumerate() {
498 let arg_sym = match maybe_arg_sym {
499 Some(s) => s,
500 None => continue,
501 };
502 let callee_param = match callee_params.get(i) {
503 Some(p) => p,
504 None => continue,
505 };
506 let callee_is_mut_borrow = mutable_borrow
507 .get(callee)
508 .map(|s| s.contains(callee_param))
509 .unwrap_or(false);
510 if !callee_is_mut_borrow {
511 if let Some(caller_set) = mutable_borrow.get_mut(caller) {
512 if caller_set.remove(arg_sym) {
513 changed = true;
514 }
515 }
516 }
517 }
518 }
519 }
520 }
521 if !changed {
522 break;
523 }
524 }
525
526 let incompatible = collect_incompatible_mut_borrow_callsites(
532 stmts, &mutable_borrow, &fn_params,
533 );
534 for fn_sym in incompatible {
535 mutable_borrow.remove(&fn_sym);
536 }
537
538 Self { mutable_borrow }
539 }
540
541 pub fn is_mutable_borrow(&self, fn_sym: Symbol, param_sym: Symbol) -> bool {
542 self.mutable_borrow
543 .get(&fn_sym)
544 .map(|s| s.contains(¶m_sym))
545 .unwrap_or(false)
546 }
547}
548
549fn has_set_index_on(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
550 stmts.iter().any(|s| check_set_index_stmt(s, sym))
551}
552
553fn check_set_index_stmt(stmt: &Stmt<'_>, sym: Symbol) -> bool {
554 match stmt {
555 Stmt::SetIndex { collection, .. } => {
556 matches!(**collection, Expr::Identifier(s) if s == sym)
557 }
558 Stmt::If { then_block, else_block, .. } => {
559 has_set_index_on(then_block, sym)
560 || else_block.as_ref().map_or(false, |eb| has_set_index_on(eb, sym))
561 }
562 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
563 has_set_index_on(body, sym)
564 }
565 Stmt::Inspect { arms, .. } => {
566 arms.iter().any(|arm| has_set_index_on(arm.body, sym))
567 }
568 _ => false,
569 }
570}
571
572fn has_structural_mutation_on(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
573 stmts.iter().any(|s| check_structural_stmt(s, sym))
574}
575
576fn check_structural_stmt(stmt: &Stmt<'_>, sym: Symbol) -> bool {
577 match stmt {
578 Stmt::Push { collection, .. } | Stmt::Pop { collection, .. }
579 | Stmt::Add { collection, .. } | Stmt::Remove { collection, .. } => {
580 matches!(**collection, Expr::Identifier(s) if s == sym)
581 }
582 Stmt::If { then_block, else_block, .. } => {
583 has_structural_mutation_on(then_block, sym)
584 || else_block.as_ref().map_or(false, |eb| has_structural_mutation_on(eb, sym))
585 }
586 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
587 has_structural_mutation_on(body, sym)
588 }
589 Stmt::Inspect { arms, .. } => {
590 arms.iter().any(|arm| has_structural_mutation_on(arm.body, sym))
591 }
592 _ => false,
593 }
594}
595
596fn has_reassignment_on(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
597 stmts.iter().any(|s| check_reassignment_stmt(s, sym))
598}
599
600fn check_reassignment_stmt(stmt: &Stmt<'_>, sym: Symbol) -> bool {
601 match stmt {
602 Stmt::Set { target, .. } => *target == sym,
603 Stmt::If { then_block, else_block, .. } => {
604 has_reassignment_on(then_block, sym)
605 || else_block.as_ref().map_or(false, |eb| has_reassignment_on(eb, sym))
606 }
607 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
608 has_reassignment_on(body, sym)
609 }
610 Stmt::Inspect { arms, .. } => {
611 arms.iter().any(|arm| has_reassignment_on(arm.body, sym))
612 }
613 _ => false,
614 }
615}
616
617fn is_consumed_param(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
618 for stmt in stmts {
619 match stmt {
620 Stmt::Let { mutable: true, value, .. } => {
621 if matches!(value, Expr::Identifier(s) if *s == sym) {
622 return true;
623 }
624 }
625 Stmt::If { then_block, else_block, .. } => {
626 if is_consumed_param(then_block, sym) { return true; }
627 if let Some(else_b) = else_block {
628 if is_consumed_param(else_b, sym) { return true; }
629 }
630 }
631 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
632 if is_consumed_param(body, sym) { return true; }
633 }
634 _ => {}
635 }
636 }
637 false
638}
639
640fn is_sole_return_param(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
641 let mut returns = Vec::new();
642 collect_returns(stmts, &mut returns);
643 !returns.is_empty() && returns.iter().all(|r| *r == sym)
644}
645
646fn collect_returns(stmts: &[Stmt<'_>], returns: &mut Vec<Symbol>) {
647 for stmt in stmts {
648 match stmt {
649 Stmt::Return { value: Some(expr) } => {
650 if let Expr::Identifier(sym) = expr {
651 returns.push(*sym);
652 } else {
653 returns.push(Symbol::EMPTY);
655 }
656 }
657 Stmt::If { then_block, else_block, .. } => {
658 collect_returns(then_block, returns);
659 if let Some(else_b) = else_block {
660 collect_returns(else_b, returns);
661 }
662 }
663 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
664 collect_returns(body, returns);
665 }
666 Stmt::Inspect { arms, .. } => {
667 for arm in arms {
668 collect_returns(arm.body, returns);
669 }
670 }
671 _ => {}
672 }
673 }
674}
675
676fn collect_incompatible_mut_borrow_callsites(
685 stmts: &[Stmt<'_>],
686 mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
687 fn_params: &HashMap<Symbol, Vec<Symbol>>,
688) -> HashSet<Symbol> {
689 let mut incompatible = HashSet::new();
690 for stmt in stmts {
691 if let Stmt::FunctionDef { body, .. } = stmt {
692 check_callsite_compat_stmts(body, mutable_borrow, fn_params, &mut incompatible);
693 }
694 }
695 check_callsite_compat_stmts(stmts, mutable_borrow, fn_params, &mut incompatible);
697 incompatible
698}
699
700fn check_callsite_compat_stmts(
701 stmts: &[Stmt<'_>],
702 mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
703 fn_params: &HashMap<Symbol, Vec<Symbol>>,
704 incompatible: &mut HashSet<Symbol>,
705) {
706 for stmt in stmts {
707 check_callsite_compat_stmt(stmt, mutable_borrow, fn_params, incompatible);
708 }
709}
710
711fn check_callsite_compat_stmt(
712 stmt: &Stmt<'_>,
713 mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
714 fn_params: &HashMap<Symbol, Vec<Symbol>>,
715 incompatible: &mut HashSet<Symbol>,
716) {
717 match stmt {
718 Stmt::Call { args, .. } => {
720 for arg in args {
721 check_callsite_compat_expr(arg, mutable_borrow, incompatible);
722 }
723 }
724 Stmt::Set { target, value } => {
726 if let Expr::Call { function, args } = value {
727 if mutable_borrow.contains_key(function) {
728 let mut_positions: HashSet<usize> = fn_params.get(function)
730 .map(|params| {
731 params.iter().enumerate()
732 .filter(|(_, sym)| {
733 mutable_borrow.get(function)
734 .map(|s| s.contains(sym))
735 .unwrap_or(false)
736 })
737 .map(|(i, _)| i)
738 .collect()
739 })
740 .unwrap_or_default();
741
742 let target_at_mut_pos = args.iter().enumerate()
743 .any(|(i, a)| {
744 mut_positions.contains(&i)
745 && matches!(a, Expr::Identifier(sym) if *sym == *target)
746 });
747
748 if !target_at_mut_pos {
749 incompatible.insert(*function);
750 }
751 }
752 for arg in args {
754 check_callsite_compat_expr(arg, mutable_borrow, incompatible);
755 }
756 } else {
757 check_callsite_compat_expr(value, mutable_borrow, incompatible);
758 }
759 }
760 Stmt::Let { value, .. } => {
762 check_callsite_compat_expr(value, mutable_borrow, incompatible);
763 }
764 Stmt::Return { value: Some(v) } => {
765 check_callsite_compat_expr(v, mutable_borrow, incompatible);
766 }
767 Stmt::Show { object, .. } => {
768 check_callsite_compat_expr(object, mutable_borrow, incompatible);
769 }
770 Stmt::Push { value, collection } => {
771 check_callsite_compat_expr(value, mutable_borrow, incompatible);
772 check_callsite_compat_expr(collection, mutable_borrow, incompatible);
773 }
774 Stmt::SetIndex { collection, index, value } => {
775 check_callsite_compat_expr(collection, mutable_borrow, incompatible);
776 check_callsite_compat_expr(index, mutable_borrow, incompatible);
777 check_callsite_compat_expr(value, mutable_borrow, incompatible);
778 }
779 Stmt::If { cond, then_block, else_block } => {
780 check_callsite_compat_expr(cond, mutable_borrow, incompatible);
781 check_callsite_compat_stmts(then_block, mutable_borrow, fn_params, incompatible);
782 if let Some(else_b) = else_block {
783 check_callsite_compat_stmts(else_b, mutable_borrow, fn_params, incompatible);
784 }
785 }
786 Stmt::While { cond, body, .. } => {
787 check_callsite_compat_expr(cond, mutable_borrow, incompatible);
788 check_callsite_compat_stmts(body, mutable_borrow, fn_params, incompatible);
789 }
790 Stmt::Repeat { iterable, body, .. } => {
791 check_callsite_compat_expr(iterable, mutable_borrow, incompatible);
792 check_callsite_compat_stmts(body, mutable_borrow, fn_params, incompatible);
793 }
794 Stmt::Inspect { arms, .. } => {
795 for arm in arms {
796 check_callsite_compat_stmts(arm.body, mutable_borrow, fn_params, incompatible);
797 }
798 }
799 _ => {}
801 }
802}
803
804fn check_callsite_compat_expr(
806 expr: &Expr<'_>,
807 mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
808 incompatible: &mut HashSet<Symbol>,
809) {
810 match expr {
811 Expr::Call { function, args } => {
812 if mutable_borrow.contains_key(function) {
813 incompatible.insert(*function);
814 }
815 for arg in args {
816 check_callsite_compat_expr(arg, mutable_borrow, incompatible);
817 }
818 }
819 Expr::BinaryOp { left, right, .. } => {
820 check_callsite_compat_expr(left, mutable_borrow, incompatible);
821 check_callsite_compat_expr(right, mutable_borrow, incompatible);
822 }
823 Expr::Index { collection, index } => {
824 check_callsite_compat_expr(collection, mutable_borrow, incompatible);
825 check_callsite_compat_expr(index, mutable_borrow, incompatible);
826 }
827 Expr::Length { collection } => {
828 check_callsite_compat_expr(collection, mutable_borrow, incompatible);
829 }
830 Expr::Contains { collection, value } => {
831 check_callsite_compat_expr(collection, mutable_borrow, incompatible);
832 check_callsite_compat_expr(value, mutable_borrow, incompatible);
833 }
834 Expr::FieldAccess { object, .. } => {
835 check_callsite_compat_expr(object, mutable_borrow, incompatible);
836 }
837 Expr::Copy { expr: inner } | Expr::Give { value: inner } | Expr::OptionSome { value: inner } => {
838 check_callsite_compat_expr(inner, mutable_borrow, incompatible);
839 }
840 _ => {}
841 }
842}
843
844fn detect_consume_alias(body: &[Stmt<'_>], param_sym: Symbol) -> Option<Symbol> {
851 let mut alias = None;
852 for stmt in body {
853 if let Stmt::Let { var, mutable: true, value, .. } = stmt {
854 if matches!(value, Expr::Identifier(s) if *s == param_sym) {
855 if alias.is_some() {
856 return None; }
858 alias = Some(*var);
859 }
860 }
861 }
862 alias
863}
864
865fn is_sole_return_param_or_alias(stmts: &[Stmt<'_>], param_sym: Symbol, alias_sym: Symbol) -> bool {
867 let mut returns = Vec::new();
868 collect_returns(stmts, &mut returns);
869 !returns.is_empty() && returns.iter().all(|r| *r == param_sym || *r == alias_sym)
870}
871
872fn reassignment_only_self_calls(
875 body: &[Stmt<'_>],
876 alias: Symbol,
877 func_name: Symbol,
878 param_position: usize,
879) -> bool {
880 check_reassignment_self_calls(body, alias, func_name, param_position)
881}
882
883fn check_reassignment_self_calls(
884 stmts: &[Stmt<'_>],
885 alias: Symbol,
886 func_name: Symbol,
887 param_position: usize,
888) -> bool {
889 for stmt in stmts {
890 match stmt {
891 Stmt::Set { target, value } if *target == alias => {
892 match value {
894 Expr::Call { function, args } if *function == func_name => {
895 let arg_at_pos = args.get(param_position);
896 let is_alias_at_pos = arg_at_pos
897 .map(|a| matches!(a, Expr::Identifier(s) if *s == alias))
898 .unwrap_or(false);
899 if !is_alias_at_pos {
900 return false;
901 }
902 }
903 _ => return false, }
905 }
906 Stmt::If { then_block, else_block, .. } => {
907 if !check_reassignment_self_calls(then_block, alias, func_name, param_position) {
908 return false;
909 }
910 if let Some(else_b) = else_block {
911 if !check_reassignment_self_calls(else_b, alias, func_name, param_position) {
912 return false;
913 }
914 }
915 }
916 Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
917 if !check_reassignment_self_calls(body, alias, func_name, param_position) {
918 return false;
919 }
920 }
921 Stmt::Inspect { arms, .. } => {
922 for arm in arms {
923 if !check_reassignment_self_calls(arm.body, alias, func_name, param_position) {
924 return false;
925 }
926 }
927 }
928 _ => {}
929 }
930 }
931 true
932}
933
934fn is_param_dead_after_consume(body: &[Stmt<'_>], param_sym: Symbol, alias: Symbol) -> bool {
939 let mut found_consume = false;
940 for stmt in body {
941 if !found_consume {
942 if let Stmt::Let { var, mutable: true, value, .. } = stmt {
944 if *var == alias && matches!(value, Expr::Identifier(s) if *s == param_sym) {
945 found_consume = true;
946 continue;
947 }
948 }
949 } else {
950 if stmt_references_symbol(stmt, param_sym) {
952 return false;
953 }
954 }
955 }
956 found_consume }
958
959fn stmt_references_symbol(stmt: &Stmt<'_>, sym: Symbol) -> bool {
961 match stmt {
962 Stmt::Let { value, .. } => expr_references_symbol(value, sym),
963 Stmt::Set { target, value } => *target == sym || expr_references_symbol(value, sym),
964 Stmt::Call { function, args } => {
965 *function == sym || args.iter().any(|a| expr_references_symbol(a, sym))
966 }
967 Stmt::Push { value, collection } => {
968 expr_references_symbol(value, sym) || expr_references_symbol(collection, sym)
969 }
970 Stmt::Pop { collection, into } => {
971 expr_references_symbol(collection, sym)
972 || into.map_or(false, |s| s == sym)
973 }
974 Stmt::Add { value, collection } | Stmt::Remove { value, collection } => {
975 expr_references_symbol(value, sym) || expr_references_symbol(collection, sym)
976 }
977 Stmt::SetIndex { collection, index, value } => {
978 expr_references_symbol(collection, sym)
979 || expr_references_symbol(index, sym)
980 || expr_references_symbol(value, sym)
981 }
982 Stmt::SetField { object, value, .. } => {
983 expr_references_symbol(object, sym) || expr_references_symbol(value, sym)
984 }
985 Stmt::Return { value: Some(v) } => expr_references_symbol(v, sym),
986 Stmt::Return { value: None } => false,
987 Stmt::If { cond, then_block, else_block } => {
988 expr_references_symbol(cond, sym)
989 || then_block.iter().any(|s| stmt_references_symbol(s, sym))
990 || else_block.as_ref().map_or(false, |eb| eb.iter().any(|s| stmt_references_symbol(s, sym)))
991 }
992 Stmt::While { cond, body, .. } => {
993 expr_references_symbol(cond, sym)
994 || body.iter().any(|s| stmt_references_symbol(s, sym))
995 }
996 Stmt::Repeat { iterable, body, .. } => {
997 expr_references_symbol(iterable, sym)
998 || body.iter().any(|s| stmt_references_symbol(s, sym))
999 }
1000 Stmt::Inspect { arms, .. } => {
1001 arms.iter().any(|arm| arm.body.iter().any(|s| stmt_references_symbol(s, sym)))
1002 }
1003 Stmt::Show { object, .. } => expr_references_symbol(object, sym),
1004 _ => false,
1005 }
1006}
1007
1008fn expr_references_symbol(expr: &Expr<'_>, sym: Symbol) -> bool {
1009 match expr {
1010 Expr::Identifier(s) => *s == sym,
1011 Expr::BinaryOp { left, right, .. } => {
1012 expr_references_symbol(left, sym) || expr_references_symbol(right, sym)
1013 }
1014 Expr::Not { operand } => expr_references_symbol(operand, sym),
1015 Expr::Call { function, args } => {
1016 *function == sym || args.iter().any(|a| expr_references_symbol(a, sym))
1017 }
1018 Expr::Index { collection, index } => {
1019 expr_references_symbol(collection, sym) || expr_references_symbol(index, sym)
1020 }
1021 Expr::Length { collection } => expr_references_symbol(collection, sym),
1022 Expr::Contains { collection, value } => {
1023 expr_references_symbol(collection, sym) || expr_references_symbol(value, sym)
1024 }
1025 Expr::FieldAccess { object, .. } => expr_references_symbol(object, sym),
1026 Expr::Slice { collection, start, end } => {
1027 expr_references_symbol(collection, sym)
1028 || expr_references_symbol(start, sym)
1029 || expr_references_symbol(end, sym)
1030 }
1031 Expr::Copy { expr: inner } | Expr::Give { value: inner } | Expr::OptionSome { value: inner } => {
1032 expr_references_symbol(inner, sym)
1033 }
1034 Expr::CallExpr { callee, args } => {
1035 expr_references_symbol(callee, sym)
1036 || args.iter().any(|a| expr_references_symbol(a, sym))
1037 }
1038 _ => false,
1039 }
1040}