Skip to main content

logicaffeine_compile/analysis/
readonly.rs

1use std::collections::{HashMap, HashSet};
2
3use logicaffeine_base::Symbol;
4use logicaffeine_language::ast::{Expr, Stmt};
5use logicaffeine_language::ast::stmt::ClosureBody;
6
7use super::callgraph::CallGraph;
8use super::types::{LogosType, TypeEnv};
9
10/// Readonly parameter analysis result.
11///
12/// Maps each function to the set of its `Seq<T>` parameters that are never
13/// structurally mutated (no Push, Pop, Add, Remove, SetIndex, or reassignment)
14/// either directly or transitively through callees.
15///
16/// Parameters in this set are eligible for `&[T]` borrow in codegen instead
17/// of requiring ownership or cloning.
18///
19/// # Safety under Rc<RefCell> (reference) semantics
20///
21/// LogosSeq<T> is Rc<RefCell<Vec<T>>>. The analysis is safe because:
22/// - `Let mutable X be param` is caught by `collect_consumed_params`, which
23///   removes the param from the readonly set. This prevents aliasing where
24///   a mutable local could mutate through the shared Rc.
25/// - Interior mutation via Rc clone requires the `mutable` keyword on the
26///   alias, which triggers the consumed-param check.
27/// - Fixed-point propagation through the callgraph catches transitive
28///   mutations: if a callee mutates a param position, all callers passing
29///   a readonly candidate to that position lose their readonly status.
30pub struct ReadonlyParams {
31    /// fn_sym → set of param symbols that are readonly within that function.
32    pub readonly: HashMap<Symbol, HashSet<Symbol>>,
33}
34
35impl ReadonlyParams {
36    /// Analyze the program and compute readonly parameters.
37    ///
38    /// Uses fixed-point iteration: starts optimistically with all `Seq<T>`
39    /// params as readonly candidates, then eliminates those that are directly
40    /// mutated or transitively mutated via callee propagation.
41    ///
42    /// Native functions are trusted: their params remain readonly unless
43    /// the LOGOS body explicitly mutates them (which is impossible since they
44    /// have no body).
45    pub fn analyze(stmts: &[Stmt<'_>], callgraph: &CallGraph, type_env: &TypeEnv) -> Self {
46        // Build fn_params map: fn_sym → ordered list of param symbols
47        let mut fn_params: HashMap<Symbol, Vec<Symbol>> = HashMap::new();
48        for stmt in stmts {
49            if let Stmt::FunctionDef { name, params, .. } = stmt {
50                let syms: Vec<Symbol> = params.iter().map(|(s, _)| *s).collect();
51                fn_params.insert(*name, syms);
52            }
53        }
54
55        // Initialize: all Seq<T> params are readonly candidates
56        let mut readonly: HashMap<Symbol, HashSet<Symbol>> = HashMap::new();
57        for stmt in stmts {
58            if let Stmt::FunctionDef { name, params, .. } = stmt {
59                let mut candidates = HashSet::new();
60                for (sym, _) in params {
61                    if is_seq_type(type_env.lookup(*sym)) {
62                        candidates.insert(*sym);
63                    }
64                }
65                readonly.insert(*name, candidates);
66            }
67        }
68
69        // Remove directly mutated params from non-native functions
70        for stmt in stmts {
71            if let Stmt::FunctionDef { name, params, body, is_native, .. } = stmt {
72                if *is_native {
73                    continue;
74                }
75                let param_set: HashSet<Symbol> = params.iter().map(|(s, _)| *s).collect();
76                let mutated = collect_direct_mutations(body, &param_set);
77                if let Some(candidates) = readonly.get_mut(name) {
78                    for sym in &mutated {
79                        candidates.remove(sym);
80                    }
81                }
82            }
83        }
84
85        // Fixed-point: propagate non-readonly through call sites
86        loop {
87            let mut changed = false;
88
89            for stmt in stmts {
90                if let Stmt::FunctionDef { name: caller, body, is_native, .. } = stmt {
91                    if *is_native {
92                        continue;
93                    }
94
95                    // Collect all call sites in this function's body (including closures)
96                    let call_sites = collect_call_sites(body);
97
98                    for (callee, arg_syms) in &call_sites {
99                        let callee_params = match fn_params.get(callee) {
100                            Some(p) => p,
101                            None => continue, // unknown function, skip
102                        };
103
104                        for (i, maybe_arg_sym) in arg_syms.iter().enumerate() {
105                            let arg_sym = match maybe_arg_sym {
106                                Some(s) => s,
107                                None => continue, // arg is not a plain identifier
108                            };
109
110                            let callee_param = match callee_params.get(i) {
111                                Some(p) => p,
112                                None => continue,
113                            };
114
115                            // Is callee's param at position i NOT readonly?
116                            let callee_param_readonly = readonly
117                                .get(callee)
118                                .map(|s| s.contains(callee_param))
119                                .unwrap_or(true); // unknown callees are trusted
120
121                            if !callee_param_readonly {
122                                // The caller's arg is passed to a mutating position
123                                if let Some(caller_readonly) = readonly.get_mut(caller) {
124                                    if caller_readonly.remove(arg_sym) {
125                                        changed = true;
126                                    }
127                                }
128                            }
129                        }
130                    }
131                }
132            }
133
134            if !changed {
135                break;
136            }
137        }
138
139        Self { readonly }
140    }
141
142    /// Returns `true` if `param_sym` is readonly within `fn_sym`.
143    pub fn is_readonly(&self, fn_sym: Symbol, param_sym: Symbol) -> bool {
144        self.readonly
145            .get(&fn_sym)
146            .map(|s| s.contains(&param_sym))
147            .unwrap_or(false)
148    }
149}
150
151fn is_seq_type(ty: &LogosType) -> bool {
152    matches!(ty, LogosType::Seq(_))
153}
154
155// =============================================================================
156// Direct mutation detection
157// =============================================================================
158
159/// Collects param symbols that are directly mutated in the body.
160///
161/// Looks for Push, Pop, Add, Remove, SetIndex, SetField, and Set reassignment
162/// on identifiers that appear in `param_set`. Also detects "consumed"
163/// parameters: those assigned into a mutable local via `Let mutable X be param`.
164/// A consumed Seq parameter should be taken by value (not borrowed) so the
165/// copy becomes a move instead of a `.to_vec()` clone.
166///
167/// Does NOT recurse into closure bodies (closures in LOGOS capture by clone,
168/// so they don't mutate the original param directly).
169fn collect_direct_mutations(stmts: &[Stmt<'_>], param_set: &HashSet<Symbol>) -> HashSet<Symbol> {
170    let mut mutated = HashSet::new();
171    for stmt in stmts {
172        collect_mutations_from_stmt(stmt, param_set, &mut mutated);
173    }
174    // Detect consumed parameters: `Let mutable X be param` where X is
175    // subsequently mutated. Taking param by value allows a move instead
176    // of a clone. We conservatively mark any param that appears as the
177    // value of a `Let mutable` as consumed.
178    collect_consumed_params(stmts, param_set, &mut mutated);
179    mutated
180}
181
182fn collect_mutations_from_stmt(stmt: &Stmt<'_>, param_set: &HashSet<Symbol>, mutated: &mut HashSet<Symbol>) {
183    match stmt {
184        Stmt::Push { collection, .. } => {
185            if let Expr::Identifier(sym) = **collection {
186                if param_set.contains(&sym) {
187                    mutated.insert(sym);
188                }
189            }
190        }
191        Stmt::Pop { collection, .. } => {
192            if let Expr::Identifier(sym) = **collection {
193                if param_set.contains(&sym) {
194                    mutated.insert(sym);
195                }
196            }
197        }
198        Stmt::Add { collection, .. } => {
199            if let Expr::Identifier(sym) = **collection {
200                if param_set.contains(&sym) {
201                    mutated.insert(sym);
202                }
203            }
204        }
205        Stmt::Remove { collection, .. } => {
206            if let Expr::Identifier(sym) = **collection {
207                if param_set.contains(&sym) {
208                    mutated.insert(sym);
209                }
210            }
211        }
212        Stmt::SetIndex { collection, .. } => {
213            if let Expr::Identifier(sym) = **collection {
214                if param_set.contains(&sym) {
215                    mutated.insert(sym);
216                }
217            }
218        }
219        Stmt::SetField { object, .. } => {
220            if let Expr::Identifier(sym) = **object {
221                if param_set.contains(&sym) {
222                    mutated.insert(sym);
223                }
224            }
225        }
226        Stmt::Set { target, .. } => {
227            if param_set.contains(target) {
228                mutated.insert(*target);
229            }
230        }
231        // Recurse into control-flow blocks (not closures)
232        Stmt::If { then_block, else_block, .. } => {
233            for s in *then_block {
234                collect_mutations_from_stmt(s, param_set, mutated);
235            }
236            if let Some(else_b) = else_block {
237                for s in *else_b {
238                    collect_mutations_from_stmt(s, param_set, mutated);
239                }
240            }
241        }
242        Stmt::While { body, .. } => {
243            for s in *body {
244                collect_mutations_from_stmt(s, param_set, mutated);
245            }
246        }
247        Stmt::Repeat { body, .. } => {
248            for s in *body {
249                collect_mutations_from_stmt(s, param_set, mutated);
250            }
251        }
252        Stmt::Inspect { arms, .. } => {
253            for arm in arms {
254                for s in arm.body {
255                    collect_mutations_from_stmt(s, param_set, mutated);
256                }
257            }
258        }
259        _ => {}
260    }
261}
262
263/// Detects consumed parameters: those copied into a mutable local via
264/// `Let mutable X be param`. Recurses into control-flow blocks.
265fn collect_consumed_params(stmts: &[Stmt<'_>], param_set: &HashSet<Symbol>, consumed: &mut HashSet<Symbol>) {
266    for stmt in stmts {
267        match stmt {
268            Stmt::Let { mutable: true, value, .. } => {
269                if let Expr::Identifier(sym) = value {
270                    if param_set.contains(sym) {
271                        consumed.insert(*sym);
272                    }
273                }
274            }
275            Stmt::If { then_block, else_block, .. } => {
276                collect_consumed_params(then_block, param_set, consumed);
277                if let Some(else_b) = else_block {
278                    collect_consumed_params(else_b, param_set, consumed);
279                }
280            }
281            Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
282                collect_consumed_params(body, param_set, consumed);
283            }
284            Stmt::Inspect { arms, .. } => {
285                for arm in arms {
286                    collect_consumed_params(arm.body, param_set, consumed);
287                }
288            }
289            _ => {}
290        }
291    }
292}
293
294// =============================================================================
295// Call site collection (for fixed-point propagation)
296// =============================================================================
297
298/// Collects all call sites in a function body, including those inside closures.
299///
300/// Returns `Vec<(callee, [arg0?, arg1?, ...])>` where each arg is `Some(sym)`
301/// if the argument is a plain `Expr::Identifier`, otherwise `None`.
302fn collect_call_sites(stmts: &[Stmt<'_>]) -> Vec<(Symbol, Vec<Option<Symbol>>)> {
303    let mut sites = Vec::new();
304    collect_call_sites_from_stmts(stmts, &mut sites);
305    sites
306}
307
308fn collect_call_sites_from_stmts(stmts: &[Stmt<'_>], sites: &mut Vec<(Symbol, Vec<Option<Symbol>>)>) {
309    for stmt in stmts {
310        collect_call_sites_from_stmt(stmt, sites);
311    }
312}
313
314fn collect_call_sites_from_stmt(stmt: &Stmt<'_>, sites: &mut Vec<(Symbol, Vec<Option<Symbol>>)>) {
315    match stmt {
316        Stmt::Call { function, args } => {
317            let arg_syms = args.iter().map(|arg| {
318                if let Expr::Identifier(sym) = *arg { Some(*sym) } else { None }
319            }).collect();
320            sites.push((*function, arg_syms));
321            for arg in args {
322                collect_call_sites_from_expr(arg, sites);
323            }
324        }
325        Stmt::Let { value, .. } => collect_call_sites_from_expr(value, sites),
326        Stmt::Set { value, .. } => collect_call_sites_from_expr(value, sites),
327        Stmt::Return { value: Some(v) } => collect_call_sites_from_expr(v, sites),
328        Stmt::If { cond, then_block, else_block } => {
329            collect_call_sites_from_expr(cond, sites);
330            collect_call_sites_from_stmts(then_block, sites);
331            if let Some(else_b) = else_block {
332                collect_call_sites_from_stmts(else_b, sites);
333            }
334        }
335        Stmt::While { cond, body, .. } => {
336            collect_call_sites_from_expr(cond, sites);
337            collect_call_sites_from_stmts(body, sites);
338        }
339        Stmt::Repeat { iterable, body, .. } => {
340            collect_call_sites_from_expr(iterable, sites);
341            collect_call_sites_from_stmts(body, sites);
342        }
343        Stmt::Push { value, collection } => {
344            collect_call_sites_from_expr(value, sites);
345            collect_call_sites_from_expr(collection, sites);
346        }
347        Stmt::Inspect { arms, .. } => {
348            for arm in arms {
349                collect_call_sites_from_stmts(arm.body, sites);
350            }
351        }
352        Stmt::Concurrent { tasks } | Stmt::Parallel { tasks } => {
353            collect_call_sites_from_stmts(tasks, sites);
354        }
355        _ => {}
356    }
357}
358
359fn collect_call_sites_from_expr(expr: &Expr<'_>, sites: &mut Vec<(Symbol, Vec<Option<Symbol>>)>) {
360    match expr {
361        Expr::Call { function, args } => {
362            let arg_syms = args.iter().map(|arg| {
363                if let Expr::Identifier(sym) = *arg { Some(*sym) } else { None }
364            }).collect();
365            sites.push((*function, arg_syms));
366            for arg in args {
367                collect_call_sites_from_expr(arg, sites);
368            }
369        }
370        Expr::Closure { body, .. } => match body {
371            ClosureBody::Expression(e) => collect_call_sites_from_expr(e, sites),
372            ClosureBody::Block(stmts) => collect_call_sites_from_stmts(stmts, sites),
373        },
374        Expr::BinaryOp { left, right, .. } => {
375            collect_call_sites_from_expr(left, sites);
376            collect_call_sites_from_expr(right, sites);
377        }
378        Expr::Index { collection, index } => {
379            collect_call_sites_from_expr(collection, sites);
380            collect_call_sites_from_expr(index, sites);
381        }
382        Expr::Length { collection } => collect_call_sites_from_expr(collection, sites),
383        Expr::Contains { collection, value } => {
384            collect_call_sites_from_expr(collection, sites);
385            collect_call_sites_from_expr(value, sites);
386        }
387        Expr::FieldAccess { object, .. } => collect_call_sites_from_expr(object, sites),
388        Expr::Copy { expr } | Expr::Give { value: expr } => {
389            collect_call_sites_from_expr(expr, sites);
390        }
391        Expr::OptionSome { value } => collect_call_sites_from_expr(value, sites),
392        Expr::WithCapacity { value, capacity } => {
393            collect_call_sites_from_expr(value, sites);
394            collect_call_sites_from_expr(capacity, sites);
395        }
396        Expr::CallExpr { callee, args } => {
397            collect_call_sites_from_expr(callee, sites);
398            for arg in args {
399                collect_call_sites_from_expr(arg, sites);
400            }
401        }
402        _ => {}
403    }
404}
405
406// =============================================================================
407// Mutable Borrow Parameter Analysis
408// =============================================================================
409
410/// Mutable borrow parameter analysis result.
411///
412/// Identifies `Seq<T>` parameters that are only mutated via element access
413/// (SetIndex) but never structurally modified (no Push, Pop, Add, Remove,
414/// or reassignment). These parameters can be passed as `&mut [T]` instead
415/// of by value, eliminating the move-in/move-out ownership pattern.
416///
417/// Additional requirement: the function must return the parameter as its
418/// sole return value, so the call site can drop the assignment.
419pub struct MutableBorrowParams {
420    /// fn_sym → set of param symbols eligible for &mut [T] borrow.
421    pub mutable_borrow: HashMap<Symbol, HashSet<Symbol>>,
422}
423
424impl MutableBorrowParams {
425    /// Analyze the program and compute mutable borrow parameters.
426    pub fn analyze(stmts: &[Stmt<'_>], callgraph: &CallGraph, type_env: &TypeEnv) -> Self {
427        let mut fn_params: HashMap<Symbol, Vec<Symbol>> = HashMap::new();
428        for stmt in stmts {
429            if let Stmt::FunctionDef { name, params, .. } = stmt {
430                let syms: Vec<Symbol> = params.iter().map(|(s, _)| *s).collect();
431                fn_params.insert(*name, syms);
432            }
433        }
434
435        let mut mutable_borrow: HashMap<Symbol, HashSet<Symbol>> = HashMap::new();
436
437        for stmt in stmts {
438            if let Stmt::FunctionDef { name, params, body, is_native, is_exported, .. } = stmt {
439                if *is_native || *is_exported {
440                    continue;
441                }
442
443                let mut candidates = HashSet::new();
444
445                for (sym, _) in params {
446                    if !is_seq_type(type_env.lookup(*sym)) {
447                        continue;
448                    }
449
450                    let has_set_index = has_set_index_on(body, *sym);
451                    let has_structural = has_structural_mutation_on(body, *sym);
452                    let has_reassign = has_reassignment_on(body, *sym);
453                    let consumed = is_consumed_param(body, *sym);
454                    let returned = is_sole_return_param(body, *sym);
455
456                    if has_set_index && !has_structural && !has_reassign && !consumed && returned {
457                        candidates.insert(*sym);
458                    } else if consumed {
459                        // Consume-alias detection: `Let mutable result be arr`
460                        // where `arr` is never used after the consume and `result`
461                        // satisfies all &mut [T] criteria.
462                        let param_idx = params.iter().position(|(s, _)| *s == *sym).unwrap_or(usize::MAX);
463                        if let Some(alias) = detect_consume_alias(body, *sym) {
464                            let alias_has_set_index = has_set_index_on(body, alias);
465                            let alias_has_structural = has_structural_mutation_on(body, alias);
466                            let alias_returned = is_sole_return_param_or_alias(body, *sym, alias);
467                            let alias_reassign_ok = reassignment_only_self_calls(body, alias, *name, param_idx);
468                            let param_dead = is_param_dead_after_consume(body, *sym, alias);
469
470                            if alias_has_set_index && !alias_has_structural && alias_returned && alias_reassign_ok && param_dead {
471                                candidates.insert(*sym);
472                            }
473                        }
474                    }
475                }
476
477                if !candidates.is_empty() {
478                    mutable_borrow.insert(*name, candidates);
479                }
480            }
481        }
482
483        // Fixed-point: propagate through call sites.
484        loop {
485            let mut changed = false;
486            for stmt in stmts {
487                if let Stmt::FunctionDef { name: caller, body, is_native, .. } = stmt {
488                    if *is_native {
489                        continue;
490                    }
491                    let call_sites = collect_call_sites(body);
492                    for (callee, arg_syms) in &call_sites {
493                        let callee_params = match fn_params.get(callee) {
494                            Some(p) => p,
495                            None => continue,
496                        };
497                        for (i, maybe_arg_sym) in arg_syms.iter().enumerate() {
498                            let arg_sym = match maybe_arg_sym {
499                                Some(s) => s,
500                                None => continue,
501                            };
502                            let callee_param = match callee_params.get(i) {
503                                Some(p) => p,
504                                None => continue,
505                            };
506                            let callee_is_mut_borrow = mutable_borrow
507                                .get(callee)
508                                .map(|s| s.contains(callee_param))
509                                .unwrap_or(false);
510                            if !callee_is_mut_borrow {
511                                if let Some(caller_set) = mutable_borrow.get_mut(caller) {
512                                    if caller_set.remove(arg_sym) {
513                                        changed = true;
514                                    }
515                                }
516                            }
517                        }
518                    }
519                }
520            }
521            if !changed {
522                break;
523            }
524        }
525
526        // Call-site compatibility: &mut [T] suppresses the return type, so the
527        // function can only be called in void context (Stmt::Call) or in
528        // `Set x to f(x, ...)` where x is at a mut_borrow position.
529        // Remove functions from mutable_borrow if any call site uses the return
530        // value in a way that requires it (Let, Show, Return, expression context).
531        let incompatible = collect_incompatible_mut_borrow_callsites(
532            stmts, &mutable_borrow, &fn_params,
533        );
534        for fn_sym in incompatible {
535            mutable_borrow.remove(&fn_sym);
536        }
537
538        Self { mutable_borrow }
539    }
540
541    pub fn is_mutable_borrow(&self, fn_sym: Symbol, param_sym: Symbol) -> bool {
542        self.mutable_borrow
543            .get(&fn_sym)
544            .map(|s| s.contains(&param_sym))
545            .unwrap_or(false)
546    }
547}
548
549fn has_set_index_on(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
550    stmts.iter().any(|s| check_set_index_stmt(s, sym))
551}
552
553fn check_set_index_stmt(stmt: &Stmt<'_>, sym: Symbol) -> bool {
554    match stmt {
555        Stmt::SetIndex { collection, .. } => {
556            matches!(**collection, Expr::Identifier(s) if s == sym)
557        }
558        Stmt::If { then_block, else_block, .. } => {
559            has_set_index_on(then_block, sym)
560                || else_block.as_ref().map_or(false, |eb| has_set_index_on(eb, sym))
561        }
562        Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
563            has_set_index_on(body, sym)
564        }
565        Stmt::Inspect { arms, .. } => {
566            arms.iter().any(|arm| has_set_index_on(arm.body, sym))
567        }
568        _ => false,
569    }
570}
571
572fn has_structural_mutation_on(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
573    stmts.iter().any(|s| check_structural_stmt(s, sym))
574}
575
576fn check_structural_stmt(stmt: &Stmt<'_>, sym: Symbol) -> bool {
577    match stmt {
578        Stmt::Push { collection, .. } | Stmt::Pop { collection, .. }
579        | Stmt::Add { collection, .. } | Stmt::Remove { collection, .. } => {
580            matches!(**collection, Expr::Identifier(s) if s == sym)
581        }
582        Stmt::If { then_block, else_block, .. } => {
583            has_structural_mutation_on(then_block, sym)
584                || else_block.as_ref().map_or(false, |eb| has_structural_mutation_on(eb, sym))
585        }
586        Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
587            has_structural_mutation_on(body, sym)
588        }
589        Stmt::Inspect { arms, .. } => {
590            arms.iter().any(|arm| has_structural_mutation_on(arm.body, sym))
591        }
592        _ => false,
593    }
594}
595
596fn has_reassignment_on(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
597    stmts.iter().any(|s| check_reassignment_stmt(s, sym))
598}
599
600fn check_reassignment_stmt(stmt: &Stmt<'_>, sym: Symbol) -> bool {
601    match stmt {
602        Stmt::Set { target, .. } => *target == sym,
603        Stmt::If { then_block, else_block, .. } => {
604            has_reassignment_on(then_block, sym)
605                || else_block.as_ref().map_or(false, |eb| has_reassignment_on(eb, sym))
606        }
607        Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
608            has_reassignment_on(body, sym)
609        }
610        Stmt::Inspect { arms, .. } => {
611            arms.iter().any(|arm| has_reassignment_on(arm.body, sym))
612        }
613        _ => false,
614    }
615}
616
617fn is_consumed_param(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
618    for stmt in stmts {
619        match stmt {
620            Stmt::Let { mutable: true, value, .. } => {
621                if matches!(value, Expr::Identifier(s) if *s == sym) {
622                    return true;
623                }
624            }
625            Stmt::If { then_block, else_block, .. } => {
626                if is_consumed_param(then_block, sym) { return true; }
627                if let Some(else_b) = else_block {
628                    if is_consumed_param(else_b, sym) { return true; }
629                }
630            }
631            Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
632                if is_consumed_param(body, sym) { return true; }
633            }
634            _ => {}
635        }
636    }
637    false
638}
639
640fn is_sole_return_param(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
641    let mut returns = Vec::new();
642    collect_returns(stmts, &mut returns);
643    !returns.is_empty() && returns.iter().all(|r| *r == sym)
644}
645
646fn collect_returns(stmts: &[Stmt<'_>], returns: &mut Vec<Symbol>) {
647    for stmt in stmts {
648        match stmt {
649            Stmt::Return { value: Some(expr) } => {
650                if let Expr::Identifier(sym) = expr {
651                    returns.push(*sym);
652                } else {
653                    // Non-identifier return — sentinel that won't match
654                    returns.push(Symbol::EMPTY);
655                }
656            }
657            Stmt::If { then_block, else_block, .. } => {
658                collect_returns(then_block, returns);
659                if let Some(else_b) = else_block {
660                    collect_returns(else_b, returns);
661                }
662            }
663            Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
664                collect_returns(body, returns);
665            }
666            Stmt::Inspect { arms, .. } => {
667                for arm in arms {
668                    collect_returns(arm.body, returns);
669                }
670            }
671            _ => {}
672        }
673    }
674}
675
676// =============================================================================
677// Call-site compatibility for &mut [T]
678// =============================================================================
679
680/// Collect functions in `mutable_borrow` that have incompatible call sites.
681/// An incompatible call site is one where the function's return value is used
682/// (e.g., in Let, Show, Return, or expression context) because &mut [T]
683/// functions have void return.
684fn collect_incompatible_mut_borrow_callsites(
685    stmts: &[Stmt<'_>],
686    mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
687    fn_params: &HashMap<Symbol, Vec<Symbol>>,
688) -> HashSet<Symbol> {
689    let mut incompatible = HashSet::new();
690    for stmt in stmts {
691        if let Stmt::FunctionDef { body, .. } = stmt {
692            check_callsite_compat_stmts(body, mutable_borrow, fn_params, &mut incompatible);
693        }
694    }
695    // Also check main-level statements (not inside function defs)
696    check_callsite_compat_stmts(stmts, mutable_borrow, fn_params, &mut incompatible);
697    incompatible
698}
699
700fn check_callsite_compat_stmts(
701    stmts: &[Stmt<'_>],
702    mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
703    fn_params: &HashMap<Symbol, Vec<Symbol>>,
704    incompatible: &mut HashSet<Symbol>,
705) {
706    for stmt in stmts {
707        check_callsite_compat_stmt(stmt, mutable_borrow, fn_params, incompatible);
708    }
709}
710
711fn check_callsite_compat_stmt(
712    stmt: &Stmt<'_>,
713    mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
714    fn_params: &HashMap<Symbol, Vec<Symbol>>,
715    incompatible: &mut HashSet<Symbol>,
716) {
717    match stmt {
718        // Stmt::Call → void context, always OK. But check args for nested calls.
719        Stmt::Call { args, .. } => {
720            for arg in args {
721                check_callsite_compat_expr(arg, mutable_borrow, incompatible);
722            }
723        }
724        // Stmt::Set → OK if target == args[mut_borrow_pos], otherwise check expr
725        Stmt::Set { target, value } => {
726            if let Expr::Call { function, args } = value {
727                if mutable_borrow.contains_key(function) {
728                    // Check that target is at a mut_borrow position
729                    let mut_positions: HashSet<usize> = fn_params.get(function)
730                        .map(|params| {
731                            params.iter().enumerate()
732                                .filter(|(_, sym)| {
733                                    mutable_borrow.get(function)
734                                        .map(|s| s.contains(sym))
735                                        .unwrap_or(false)
736                                })
737                                .map(|(i, _)| i)
738                                .collect()
739                        })
740                        .unwrap_or_default();
741
742                    let target_at_mut_pos = args.iter().enumerate()
743                        .any(|(i, a)| {
744                            mut_positions.contains(&i)
745                                && matches!(a, Expr::Identifier(sym) if *sym == *target)
746                        });
747
748                    if !target_at_mut_pos {
749                        incompatible.insert(*function);
750                    }
751                }
752                // Check args for nested calls
753                for arg in args {
754                    check_callsite_compat_expr(arg, mutable_borrow, incompatible);
755                }
756            } else {
757                check_callsite_compat_expr(value, mutable_borrow, incompatible);
758            }
759        }
760        // Stmt::Let → if value is a call to a mut_borrow function, it's incompatible
761        Stmt::Let { value, .. } => {
762            check_callsite_compat_expr(value, mutable_borrow, incompatible);
763        }
764        Stmt::Return { value: Some(v) } => {
765            check_callsite_compat_expr(v, mutable_borrow, incompatible);
766        }
767        Stmt::Show { object, .. } => {
768            check_callsite_compat_expr(object, mutable_borrow, incompatible);
769        }
770        Stmt::Push { value, collection } => {
771            check_callsite_compat_expr(value, mutable_borrow, incompatible);
772            check_callsite_compat_expr(collection, mutable_borrow, incompatible);
773        }
774        Stmt::SetIndex { collection, index, value } => {
775            check_callsite_compat_expr(collection, mutable_borrow, incompatible);
776            check_callsite_compat_expr(index, mutable_borrow, incompatible);
777            check_callsite_compat_expr(value, mutable_borrow, incompatible);
778        }
779        Stmt::If { cond, then_block, else_block } => {
780            check_callsite_compat_expr(cond, mutable_borrow, incompatible);
781            check_callsite_compat_stmts(then_block, mutable_borrow, fn_params, incompatible);
782            if let Some(else_b) = else_block {
783                check_callsite_compat_stmts(else_b, mutable_borrow, fn_params, incompatible);
784            }
785        }
786        Stmt::While { cond, body, .. } => {
787            check_callsite_compat_expr(cond, mutable_borrow, incompatible);
788            check_callsite_compat_stmts(body, mutable_borrow, fn_params, incompatible);
789        }
790        Stmt::Repeat { iterable, body, .. } => {
791            check_callsite_compat_expr(iterable, mutable_borrow, incompatible);
792            check_callsite_compat_stmts(body, mutable_borrow, fn_params, incompatible);
793        }
794        Stmt::Inspect { arms, .. } => {
795            for arm in arms {
796                check_callsite_compat_stmts(arm.body, mutable_borrow, fn_params, incompatible);
797            }
798        }
799        // Skip FunctionDef — handled at top level
800        _ => {}
801    }
802}
803
804/// Check if an expression uses a mut_borrow function in value context (incompatible).
805fn check_callsite_compat_expr(
806    expr: &Expr<'_>,
807    mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
808    incompatible: &mut HashSet<Symbol>,
809) {
810    match expr {
811        Expr::Call { function, args } => {
812            if mutable_borrow.contains_key(function) {
813                incompatible.insert(*function);
814            }
815            for arg in args {
816                check_callsite_compat_expr(arg, mutable_borrow, incompatible);
817            }
818        }
819        Expr::BinaryOp { left, right, .. } => {
820            check_callsite_compat_expr(left, mutable_borrow, incompatible);
821            check_callsite_compat_expr(right, mutable_borrow, incompatible);
822        }
823        Expr::Index { collection, index } => {
824            check_callsite_compat_expr(collection, mutable_borrow, incompatible);
825            check_callsite_compat_expr(index, mutable_borrow, incompatible);
826        }
827        Expr::Length { collection } => {
828            check_callsite_compat_expr(collection, mutable_borrow, incompatible);
829        }
830        Expr::Contains { collection, value } => {
831            check_callsite_compat_expr(collection, mutable_borrow, incompatible);
832            check_callsite_compat_expr(value, mutable_borrow, incompatible);
833        }
834        Expr::FieldAccess { object, .. } => {
835            check_callsite_compat_expr(object, mutable_borrow, incompatible);
836        }
837        Expr::Copy { expr: inner } | Expr::Give { value: inner } | Expr::OptionSome { value: inner } => {
838            check_callsite_compat_expr(inner, mutable_borrow, incompatible);
839        }
840        _ => {}
841    }
842}
843
844// =============================================================================
845// Consume-Alias Detection for &mut [T]
846// =============================================================================
847
848/// Detect consume-alias pattern: finds exactly one `Let mutable <alias> be <param>`
849/// at the top level of the function body. Returns the alias symbol if found.
850fn detect_consume_alias(body: &[Stmt<'_>], param_sym: Symbol) -> Option<Symbol> {
851    let mut alias = None;
852    for stmt in body {
853        if let Stmt::Let { var, mutable: true, value, .. } = stmt {
854            if matches!(value, Expr::Identifier(s) if *s == param_sym) {
855                if alias.is_some() {
856                    return None; // Multiple consumes — reject
857                }
858                alias = Some(*var);
859            }
860        }
861    }
862    alias
863}
864
865/// Check that every return in the body returns either `param_sym` or `alias_sym`.
866fn is_sole_return_param_or_alias(stmts: &[Stmt<'_>], param_sym: Symbol, alias_sym: Symbol) -> bool {
867    let mut returns = Vec::new();
868    collect_returns(stmts, &mut returns);
869    !returns.is_empty() && returns.iter().all(|r| *r == param_sym || *r == alias_sym)
870}
871
872/// Check that every `Set <alias> to <expr>` in the body is a call to `func_name`
873/// with `alias` at position `param_position`. No other reassignment patterns allowed.
874fn reassignment_only_self_calls(
875    body: &[Stmt<'_>],
876    alias: Symbol,
877    func_name: Symbol,
878    param_position: usize,
879) -> bool {
880    check_reassignment_self_calls(body, alias, func_name, param_position)
881}
882
883fn check_reassignment_self_calls(
884    stmts: &[Stmt<'_>],
885    alias: Symbol,
886    func_name: Symbol,
887    param_position: usize,
888) -> bool {
889    for stmt in stmts {
890        match stmt {
891            Stmt::Set { target, value } if *target == alias => {
892                // Must be a call to func_name with alias at param_position
893                match value {
894                    Expr::Call { function, args } if *function == func_name => {
895                        let arg_at_pos = args.get(param_position);
896                        let is_alias_at_pos = arg_at_pos
897                            .map(|a| matches!(a, Expr::Identifier(s) if *s == alias))
898                            .unwrap_or(false);
899                        if !is_alias_at_pos {
900                            return false;
901                        }
902                    }
903                    _ => return false, // Non-self-call reassignment
904                }
905            }
906            Stmt::If { then_block, else_block, .. } => {
907                if !check_reassignment_self_calls(then_block, alias, func_name, param_position) {
908                    return false;
909                }
910                if let Some(else_b) = else_block {
911                    if !check_reassignment_self_calls(else_b, alias, func_name, param_position) {
912                        return false;
913                    }
914                }
915            }
916            Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
917                if !check_reassignment_self_calls(body, alias, func_name, param_position) {
918                    return false;
919                }
920            }
921            Stmt::Inspect { arms, .. } => {
922                for arm in arms {
923                    if !check_reassignment_self_calls(arm.body, alias, func_name, param_position) {
924                        return false;
925                    }
926                }
927            }
928            _ => {}
929        }
930    }
931    true
932}
933
934/// Check that `param_sym` is dead after the consume statement
935/// (`Let mutable <alias> be <param>`). Scans top-level statements:
936/// before the consume, param can be used freely. After the consume,
937/// param must never appear in any expression.
938fn is_param_dead_after_consume(body: &[Stmt<'_>], param_sym: Symbol, alias: Symbol) -> bool {
939    let mut found_consume = false;
940    for stmt in body {
941        if !found_consume {
942            // Check if this is the consume statement
943            if let Stmt::Let { var, mutable: true, value, .. } = stmt {
944                if *var == alias && matches!(value, Expr::Identifier(s) if *s == param_sym) {
945                    found_consume = true;
946                    continue;
947                }
948            }
949        } else {
950            // After the consume: param must not appear
951            if stmt_references_symbol(stmt, param_sym) {
952                return false;
953            }
954        }
955    }
956    found_consume // Must have actually found the consume
957}
958
959/// Check if a statement references a given symbol anywhere (expressions, collections, etc.).
960fn stmt_references_symbol(stmt: &Stmt<'_>, sym: Symbol) -> bool {
961    match stmt {
962        Stmt::Let { value, .. } => expr_references_symbol(value, sym),
963        Stmt::Set { target, value } => *target == sym || expr_references_symbol(value, sym),
964        Stmt::Call { function, args } => {
965            *function == sym || args.iter().any(|a| expr_references_symbol(a, sym))
966        }
967        Stmt::Push { value, collection } => {
968            expr_references_symbol(value, sym) || expr_references_symbol(collection, sym)
969        }
970        Stmt::Pop { collection, into } => {
971            expr_references_symbol(collection, sym)
972                || into.map_or(false, |s| s == sym)
973        }
974        Stmt::Add { value, collection } | Stmt::Remove { value, collection } => {
975            expr_references_symbol(value, sym) || expr_references_symbol(collection, sym)
976        }
977        Stmt::SetIndex { collection, index, value } => {
978            expr_references_symbol(collection, sym)
979                || expr_references_symbol(index, sym)
980                || expr_references_symbol(value, sym)
981        }
982        Stmt::SetField { object, value, .. } => {
983            expr_references_symbol(object, sym) || expr_references_symbol(value, sym)
984        }
985        Stmt::Return { value: Some(v) } => expr_references_symbol(v, sym),
986        Stmt::Return { value: None } => false,
987        Stmt::If { cond, then_block, else_block } => {
988            expr_references_symbol(cond, sym)
989                || then_block.iter().any(|s| stmt_references_symbol(s, sym))
990                || else_block.as_ref().map_or(false, |eb| eb.iter().any(|s| stmt_references_symbol(s, sym)))
991        }
992        Stmt::While { cond, body, .. } => {
993            expr_references_symbol(cond, sym)
994                || body.iter().any(|s| stmt_references_symbol(s, sym))
995        }
996        Stmt::Repeat { iterable, body, .. } => {
997            expr_references_symbol(iterable, sym)
998                || body.iter().any(|s| stmt_references_symbol(s, sym))
999        }
1000        Stmt::Inspect { arms, .. } => {
1001            arms.iter().any(|arm| arm.body.iter().any(|s| stmt_references_symbol(s, sym)))
1002        }
1003        Stmt::Show { object, .. } => expr_references_symbol(object, sym),
1004        _ => false,
1005    }
1006}
1007
1008fn expr_references_symbol(expr: &Expr<'_>, sym: Symbol) -> bool {
1009    match expr {
1010        Expr::Identifier(s) => *s == sym,
1011        Expr::BinaryOp { left, right, .. } => {
1012            expr_references_symbol(left, sym) || expr_references_symbol(right, sym)
1013        }
1014        Expr::Not { operand } => expr_references_symbol(operand, sym),
1015        Expr::Call { function, args } => {
1016            *function == sym || args.iter().any(|a| expr_references_symbol(a, sym))
1017        }
1018        Expr::Index { collection, index } => {
1019            expr_references_symbol(collection, sym) || expr_references_symbol(index, sym)
1020        }
1021        Expr::Length { collection } => expr_references_symbol(collection, sym),
1022        Expr::Contains { collection, value } => {
1023            expr_references_symbol(collection, sym) || expr_references_symbol(value, sym)
1024        }
1025        Expr::FieldAccess { object, .. } => expr_references_symbol(object, sym),
1026        Expr::Slice { collection, start, end } => {
1027            expr_references_symbol(collection, sym)
1028                || expr_references_symbol(start, sym)
1029                || expr_references_symbol(end, sym)
1030        }
1031        Expr::Copy { expr: inner } | Expr::Give { value: inner } | Expr::OptionSome { value: inner } => {
1032            expr_references_symbol(inner, sym)
1033        }
1034        Expr::CallExpr { callee, args } => {
1035            expr_references_symbol(callee, sym)
1036                || args.iter().any(|a| expr_references_symbol(a, sym))
1037        }
1038        _ => false,
1039    }
1040}