Skip to main content

logicaffeine_compile/analysis/
readonly.rs

1use std::collections::{HashMap, HashSet};
2
3use logicaffeine_base::Symbol;
4use logicaffeine_language::ast::{Expr, Stmt};
5use logicaffeine_language::ast::stmt::ClosureBody;
6
7use super::callgraph::CallGraph;
8use super::types::{LogosType, TypeEnv};
9
10/// Readonly parameter analysis result.
11///
12/// Maps each function to the set of its `Seq<T>` parameters that are never
13/// structurally mutated (no Push, Pop, Add, Remove, SetIndex, or reassignment)
14/// either directly or transitively through callees.
15///
16/// Parameters in this set are eligible for `&[T]` borrow in codegen instead
17/// of requiring ownership or cloning.
18pub struct ReadonlyParams {
19    /// fn_sym → set of param symbols that are readonly within that function.
20    pub readonly: HashMap<Symbol, HashSet<Symbol>>,
21}
22
23impl ReadonlyParams {
24    /// Analyze the program and compute readonly parameters.
25    ///
26    /// Uses fixed-point iteration: starts optimistically with all `Seq<T>`
27    /// params as readonly candidates, then eliminates those that are directly
28    /// mutated or transitively mutated via callee propagation.
29    ///
30    /// Native functions are trusted: their params remain readonly unless
31    /// the LOGOS body explicitly mutates them (which is impossible since they
32    /// have no body).
33    pub fn analyze(stmts: &[Stmt<'_>], callgraph: &CallGraph, type_env: &TypeEnv) -> Self {
34        // Build fn_params map: fn_sym → ordered list of param symbols
35        let mut fn_params: HashMap<Symbol, Vec<Symbol>> = HashMap::new();
36        for stmt in stmts {
37            if let Stmt::FunctionDef { name, params, .. } = stmt {
38                let syms: Vec<Symbol> = params.iter().map(|(s, _)| *s).collect();
39                fn_params.insert(*name, syms);
40            }
41        }
42
43        // Initialize: all Seq<T> params are readonly candidates
44        let mut readonly: HashMap<Symbol, HashSet<Symbol>> = HashMap::new();
45        for stmt in stmts {
46            if let Stmt::FunctionDef { name, params, .. } = stmt {
47                let mut candidates = HashSet::new();
48                for (sym, _) in params {
49                    if is_seq_type(type_env.lookup(*sym)) {
50                        candidates.insert(*sym);
51                    }
52                }
53                readonly.insert(*name, candidates);
54            }
55        }
56
57        // Remove directly mutated params from non-native functions
58        for stmt in stmts {
59            if let Stmt::FunctionDef { name, params, body, is_native, .. } = stmt {
60                if *is_native {
61                    continue;
62                }
63                let param_set: HashSet<Symbol> = params.iter().map(|(s, _)| *s).collect();
64                let mutated = collect_direct_mutations(body, &param_set);
65                if let Some(candidates) = readonly.get_mut(name) {
66                    for sym in &mutated {
67                        candidates.remove(sym);
68                    }
69                }
70            }
71        }
72
73        // Fixed-point: propagate non-readonly through call sites
74        loop {
75            let mut changed = false;
76
77            for stmt in stmts {
78                if let Stmt::FunctionDef { name: caller, body, is_native, .. } = stmt {
79                    if *is_native {
80                        continue;
81                    }
82
83                    // Collect all call sites in this function's body (including closures)
84                    let call_sites = collect_call_sites(body);
85
86                    for (callee, arg_syms) in &call_sites {
87                        let callee_params = match fn_params.get(callee) {
88                            Some(p) => p,
89                            None => continue, // unknown function, skip
90                        };
91
92                        for (i, maybe_arg_sym) in arg_syms.iter().enumerate() {
93                            let arg_sym = match maybe_arg_sym {
94                                Some(s) => s,
95                                None => continue, // arg is not a plain identifier
96                            };
97
98                            let callee_param = match callee_params.get(i) {
99                                Some(p) => p,
100                                None => continue,
101                            };
102
103                            // Is callee's param at position i NOT readonly?
104                            let callee_param_readonly = readonly
105                                .get(callee)
106                                .map(|s| s.contains(callee_param))
107                                .unwrap_or(true); // unknown callees are trusted
108
109                            if !callee_param_readonly {
110                                // The caller's arg is passed to a mutating position
111                                if let Some(caller_readonly) = readonly.get_mut(caller) {
112                                    if caller_readonly.remove(arg_sym) {
113                                        changed = true;
114                                    }
115                                }
116                            }
117                        }
118                    }
119                }
120            }
121
122            if !changed {
123                break;
124            }
125        }
126
127        Self { readonly }
128    }
129
130    /// Returns `true` if `param_sym` is readonly within `fn_sym`.
131    pub fn is_readonly(&self, fn_sym: Symbol, param_sym: Symbol) -> bool {
132        self.readonly
133            .get(&fn_sym)
134            .map(|s| s.contains(&param_sym))
135            .unwrap_or(false)
136    }
137}
138
139fn is_seq_type(ty: &LogosType) -> bool {
140    matches!(ty, LogosType::Seq(_))
141}
142
143// =============================================================================
144// Direct mutation detection
145// =============================================================================
146
147/// Collects param symbols that are directly mutated in the body.
148///
149/// Looks for Push, Pop, Add, Remove, SetIndex, SetField, and Set reassignment
150/// on identifiers that appear in `param_set`. Also detects "consumed"
151/// parameters: those assigned into a mutable local via `Let mutable X be param`.
152/// A consumed Seq parameter should be taken by value (not borrowed) so the
153/// copy becomes a move instead of a `.to_vec()` clone.
154///
155/// Does NOT recurse into closure bodies (closures in LOGOS capture by clone,
156/// so they don't mutate the original param directly).
157fn collect_direct_mutations(stmts: &[Stmt<'_>], param_set: &HashSet<Symbol>) -> HashSet<Symbol> {
158    let mut mutated = HashSet::new();
159    for stmt in stmts {
160        collect_mutations_from_stmt(stmt, param_set, &mut mutated);
161    }
162    // Detect consumed parameters: `Let mutable X be param` where X is
163    // subsequently mutated. Taking param by value allows a move instead
164    // of a clone. We conservatively mark any param that appears as the
165    // value of a `Let mutable` as consumed.
166    collect_consumed_params(stmts, param_set, &mut mutated);
167    mutated
168}
169
170fn collect_mutations_from_stmt(stmt: &Stmt<'_>, param_set: &HashSet<Symbol>, mutated: &mut HashSet<Symbol>) {
171    match stmt {
172        Stmt::Push { collection, .. } => {
173            if let Expr::Identifier(sym) = **collection {
174                if param_set.contains(&sym) {
175                    mutated.insert(sym);
176                }
177            }
178        }
179        Stmt::Pop { collection, .. } => {
180            if let Expr::Identifier(sym) = **collection {
181                if param_set.contains(&sym) {
182                    mutated.insert(sym);
183                }
184            }
185        }
186        Stmt::Add { collection, .. } => {
187            if let Expr::Identifier(sym) = **collection {
188                if param_set.contains(&sym) {
189                    mutated.insert(sym);
190                }
191            }
192        }
193        Stmt::Remove { collection, .. } => {
194            if let Expr::Identifier(sym) = **collection {
195                if param_set.contains(&sym) {
196                    mutated.insert(sym);
197                }
198            }
199        }
200        Stmt::SetIndex { collection, .. } => {
201            if let Expr::Identifier(sym) = **collection {
202                if param_set.contains(&sym) {
203                    mutated.insert(sym);
204                }
205            }
206        }
207        Stmt::SetField { object, .. } => {
208            if let Expr::Identifier(sym) = **object {
209                if param_set.contains(&sym) {
210                    mutated.insert(sym);
211                }
212            }
213        }
214        Stmt::Set { target, .. } => {
215            if param_set.contains(target) {
216                mutated.insert(*target);
217            }
218        }
219        // Recurse into control-flow blocks (not closures)
220        Stmt::If { then_block, else_block, .. } => {
221            for s in *then_block {
222                collect_mutations_from_stmt(s, param_set, mutated);
223            }
224            if let Some(else_b) = else_block {
225                for s in *else_b {
226                    collect_mutations_from_stmt(s, param_set, mutated);
227                }
228            }
229        }
230        Stmt::While { body, .. } => {
231            for s in *body {
232                collect_mutations_from_stmt(s, param_set, mutated);
233            }
234        }
235        Stmt::Repeat { body, .. } => {
236            for s in *body {
237                collect_mutations_from_stmt(s, param_set, mutated);
238            }
239        }
240        Stmt::Inspect { arms, .. } => {
241            for arm in arms {
242                for s in arm.body {
243                    collect_mutations_from_stmt(s, param_set, mutated);
244                }
245            }
246        }
247        _ => {}
248    }
249}
250
251/// Detects consumed parameters: those copied into a mutable local via
252/// `Let mutable X be param`. Recurses into control-flow blocks.
253fn collect_consumed_params(stmts: &[Stmt<'_>], param_set: &HashSet<Symbol>, consumed: &mut HashSet<Symbol>) {
254    for stmt in stmts {
255        match stmt {
256            Stmt::Let { mutable: true, value, .. } => {
257                if let Expr::Identifier(sym) = value {
258                    if param_set.contains(sym) {
259                        consumed.insert(*sym);
260                    }
261                }
262            }
263            Stmt::If { then_block, else_block, .. } => {
264                collect_consumed_params(then_block, param_set, consumed);
265                if let Some(else_b) = else_block {
266                    collect_consumed_params(else_b, param_set, consumed);
267                }
268            }
269            Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
270                collect_consumed_params(body, param_set, consumed);
271            }
272            Stmt::Inspect { arms, .. } => {
273                for arm in arms {
274                    collect_consumed_params(arm.body, param_set, consumed);
275                }
276            }
277            _ => {}
278        }
279    }
280}
281
282// =============================================================================
283// Call site collection (for fixed-point propagation)
284// =============================================================================
285
286/// Collects all call sites in a function body, including those inside closures.
287///
288/// Returns `Vec<(callee, [arg0?, arg1?, ...])>` where each arg is `Some(sym)`
289/// if the argument is a plain `Expr::Identifier`, otherwise `None`.
290fn collect_call_sites(stmts: &[Stmt<'_>]) -> Vec<(Symbol, Vec<Option<Symbol>>)> {
291    let mut sites = Vec::new();
292    collect_call_sites_from_stmts(stmts, &mut sites);
293    sites
294}
295
296fn collect_call_sites_from_stmts(stmts: &[Stmt<'_>], sites: &mut Vec<(Symbol, Vec<Option<Symbol>>)>) {
297    for stmt in stmts {
298        collect_call_sites_from_stmt(stmt, sites);
299    }
300}
301
302fn collect_call_sites_from_stmt(stmt: &Stmt<'_>, sites: &mut Vec<(Symbol, Vec<Option<Symbol>>)>) {
303    match stmt {
304        Stmt::Call { function, args } => {
305            let arg_syms = args.iter().map(|arg| {
306                if let Expr::Identifier(sym) = *arg { Some(*sym) } else { None }
307            }).collect();
308            sites.push((*function, arg_syms));
309            for arg in args {
310                collect_call_sites_from_expr(arg, sites);
311            }
312        }
313        Stmt::Let { value, .. } => collect_call_sites_from_expr(value, sites),
314        Stmt::Set { value, .. } => collect_call_sites_from_expr(value, sites),
315        Stmt::Return { value: Some(v) } => collect_call_sites_from_expr(v, sites),
316        Stmt::If { cond, then_block, else_block } => {
317            collect_call_sites_from_expr(cond, sites);
318            collect_call_sites_from_stmts(then_block, sites);
319            if let Some(else_b) = else_block {
320                collect_call_sites_from_stmts(else_b, sites);
321            }
322        }
323        Stmt::While { cond, body, .. } => {
324            collect_call_sites_from_expr(cond, sites);
325            collect_call_sites_from_stmts(body, sites);
326        }
327        Stmt::Repeat { iterable, body, .. } => {
328            collect_call_sites_from_expr(iterable, sites);
329            collect_call_sites_from_stmts(body, sites);
330        }
331        Stmt::Push { value, collection } => {
332            collect_call_sites_from_expr(value, sites);
333            collect_call_sites_from_expr(collection, sites);
334        }
335        Stmt::Inspect { arms, .. } => {
336            for arm in arms {
337                collect_call_sites_from_stmts(arm.body, sites);
338            }
339        }
340        Stmt::Concurrent { tasks } | Stmt::Parallel { tasks } => {
341            collect_call_sites_from_stmts(tasks, sites);
342        }
343        _ => {}
344    }
345}
346
347fn collect_call_sites_from_expr(expr: &Expr<'_>, sites: &mut Vec<(Symbol, Vec<Option<Symbol>>)>) {
348    match expr {
349        Expr::Call { function, args } => {
350            let arg_syms = args.iter().map(|arg| {
351                if let Expr::Identifier(sym) = *arg { Some(*sym) } else { None }
352            }).collect();
353            sites.push((*function, arg_syms));
354            for arg in args {
355                collect_call_sites_from_expr(arg, sites);
356            }
357        }
358        Expr::Closure { body, .. } => match body {
359            ClosureBody::Expression(e) => collect_call_sites_from_expr(e, sites),
360            ClosureBody::Block(stmts) => collect_call_sites_from_stmts(stmts, sites),
361        },
362        Expr::BinaryOp { left, right, .. } => {
363            collect_call_sites_from_expr(left, sites);
364            collect_call_sites_from_expr(right, sites);
365        }
366        Expr::Index { collection, index } => {
367            collect_call_sites_from_expr(collection, sites);
368            collect_call_sites_from_expr(index, sites);
369        }
370        Expr::Length { collection } => collect_call_sites_from_expr(collection, sites),
371        Expr::Contains { collection, value } => {
372            collect_call_sites_from_expr(collection, sites);
373            collect_call_sites_from_expr(value, sites);
374        }
375        Expr::FieldAccess { object, .. } => collect_call_sites_from_expr(object, sites),
376        Expr::Copy { expr } | Expr::Give { value: expr } => {
377            collect_call_sites_from_expr(expr, sites);
378        }
379        Expr::OptionSome { value } => collect_call_sites_from_expr(value, sites),
380        Expr::WithCapacity { value, capacity } => {
381            collect_call_sites_from_expr(value, sites);
382            collect_call_sites_from_expr(capacity, sites);
383        }
384        Expr::CallExpr { callee, args } => {
385            collect_call_sites_from_expr(callee, sites);
386            for arg in args {
387                collect_call_sites_from_expr(arg, sites);
388            }
389        }
390        _ => {}
391    }
392}
393
394// =============================================================================
395// Mutable Borrow Parameter Analysis
396// =============================================================================
397
398/// Mutable borrow parameter analysis result.
399///
400/// Identifies `Seq<T>` parameters that are only mutated via element access
401/// (SetIndex) but never structurally modified (no Push, Pop, Add, Remove,
402/// or reassignment). These parameters can be passed as `&mut [T]` instead
403/// of by value, eliminating the move-in/move-out ownership pattern.
404///
405/// Additional requirement: the function must return the parameter as its
406/// sole return value, so the call site can drop the assignment.
407pub struct MutableBorrowParams {
408    /// fn_sym → set of param symbols eligible for &mut [T] borrow.
409    pub mutable_borrow: HashMap<Symbol, HashSet<Symbol>>,
410}
411
412impl MutableBorrowParams {
413    /// Analyze the program and compute mutable borrow parameters.
414    pub fn analyze(stmts: &[Stmt<'_>], callgraph: &CallGraph, type_env: &TypeEnv) -> Self {
415        let mut fn_params: HashMap<Symbol, Vec<Symbol>> = HashMap::new();
416        for stmt in stmts {
417            if let Stmt::FunctionDef { name, params, .. } = stmt {
418                let syms: Vec<Symbol> = params.iter().map(|(s, _)| *s).collect();
419                fn_params.insert(*name, syms);
420            }
421        }
422
423        let mut mutable_borrow: HashMap<Symbol, HashSet<Symbol>> = HashMap::new();
424
425        for stmt in stmts {
426            if let Stmt::FunctionDef { name, params, body, is_native, is_exported, .. } = stmt {
427                if *is_native || *is_exported {
428                    continue;
429                }
430
431                let mut candidates = HashSet::new();
432
433                for (sym, _) in params {
434                    if !is_seq_type(type_env.lookup(*sym)) {
435                        continue;
436                    }
437
438                    let has_set_index = has_set_index_on(body, *sym);
439                    let has_structural = has_structural_mutation_on(body, *sym);
440                    let has_reassign = has_reassignment_on(body, *sym);
441                    let consumed = is_consumed_param(body, *sym);
442                    let returned = is_sole_return_param(body, *sym);
443
444                    if has_set_index && !has_structural && !has_reassign && !consumed && returned {
445                        candidates.insert(*sym);
446                    } else if consumed {
447                        // Consume-alias detection: `Let mutable result be arr`
448                        // where `arr` is never used after the consume and `result`
449                        // satisfies all &mut [T] criteria.
450                        let param_idx = params.iter().position(|(s, _)| *s == *sym).unwrap_or(usize::MAX);
451                        if let Some(alias) = detect_consume_alias(body, *sym) {
452                            let alias_has_set_index = has_set_index_on(body, alias);
453                            let alias_has_structural = has_structural_mutation_on(body, alias);
454                            let alias_returned = is_sole_return_param_or_alias(body, *sym, alias);
455                            let alias_reassign_ok = reassignment_only_self_calls(body, alias, *name, param_idx);
456                            let param_dead = is_param_dead_after_consume(body, *sym, alias);
457
458                            if alias_has_set_index && !alias_has_structural && alias_returned && alias_reassign_ok && param_dead {
459                                candidates.insert(*sym);
460                            }
461                        }
462                    }
463                }
464
465                if !candidates.is_empty() {
466                    mutable_borrow.insert(*name, candidates);
467                }
468            }
469        }
470
471        // Fixed-point: propagate through call sites.
472        loop {
473            let mut changed = false;
474            for stmt in stmts {
475                if let Stmt::FunctionDef { name: caller, body, is_native, .. } = stmt {
476                    if *is_native {
477                        continue;
478                    }
479                    let call_sites = collect_call_sites(body);
480                    for (callee, arg_syms) in &call_sites {
481                        let callee_params = match fn_params.get(callee) {
482                            Some(p) => p,
483                            None => continue,
484                        };
485                        for (i, maybe_arg_sym) in arg_syms.iter().enumerate() {
486                            let arg_sym = match maybe_arg_sym {
487                                Some(s) => s,
488                                None => continue,
489                            };
490                            let callee_param = match callee_params.get(i) {
491                                Some(p) => p,
492                                None => continue,
493                            };
494                            let callee_is_mut_borrow = mutable_borrow
495                                .get(callee)
496                                .map(|s| s.contains(callee_param))
497                                .unwrap_or(false);
498                            if !callee_is_mut_borrow {
499                                if let Some(caller_set) = mutable_borrow.get_mut(caller) {
500                                    if caller_set.remove(arg_sym) {
501                                        changed = true;
502                                    }
503                                }
504                            }
505                        }
506                    }
507                }
508            }
509            if !changed {
510                break;
511            }
512        }
513
514        // Call-site compatibility: &mut [T] suppresses the return type, so the
515        // function can only be called in void context (Stmt::Call) or in
516        // `Set x to f(x, ...)` where x is at a mut_borrow position.
517        // Remove functions from mutable_borrow if any call site uses the return
518        // value in a way that requires it (Let, Show, Return, expression context).
519        let incompatible = collect_incompatible_mut_borrow_callsites(
520            stmts, &mutable_borrow, &fn_params,
521        );
522        for fn_sym in incompatible {
523            mutable_borrow.remove(&fn_sym);
524        }
525
526        Self { mutable_borrow }
527    }
528
529    pub fn is_mutable_borrow(&self, fn_sym: Symbol, param_sym: Symbol) -> bool {
530        self.mutable_borrow
531            .get(&fn_sym)
532            .map(|s| s.contains(&param_sym))
533            .unwrap_or(false)
534    }
535}
536
537fn has_set_index_on(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
538    stmts.iter().any(|s| check_set_index_stmt(s, sym))
539}
540
541fn check_set_index_stmt(stmt: &Stmt<'_>, sym: Symbol) -> bool {
542    match stmt {
543        Stmt::SetIndex { collection, .. } => {
544            matches!(**collection, Expr::Identifier(s) if s == sym)
545        }
546        Stmt::If { then_block, else_block, .. } => {
547            has_set_index_on(then_block, sym)
548                || else_block.as_ref().map_or(false, |eb| has_set_index_on(eb, sym))
549        }
550        Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
551            has_set_index_on(body, sym)
552        }
553        Stmt::Inspect { arms, .. } => {
554            arms.iter().any(|arm| has_set_index_on(arm.body, sym))
555        }
556        _ => false,
557    }
558}
559
560fn has_structural_mutation_on(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
561    stmts.iter().any(|s| check_structural_stmt(s, sym))
562}
563
564fn check_structural_stmt(stmt: &Stmt<'_>, sym: Symbol) -> bool {
565    match stmt {
566        Stmt::Push { collection, .. } | Stmt::Pop { collection, .. }
567        | Stmt::Add { collection, .. } | Stmt::Remove { collection, .. } => {
568            matches!(**collection, Expr::Identifier(s) if s == sym)
569        }
570        Stmt::If { then_block, else_block, .. } => {
571            has_structural_mutation_on(then_block, sym)
572                || else_block.as_ref().map_or(false, |eb| has_structural_mutation_on(eb, sym))
573        }
574        Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
575            has_structural_mutation_on(body, sym)
576        }
577        Stmt::Inspect { arms, .. } => {
578            arms.iter().any(|arm| has_structural_mutation_on(arm.body, sym))
579        }
580        _ => false,
581    }
582}
583
584fn has_reassignment_on(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
585    stmts.iter().any(|s| check_reassignment_stmt(s, sym))
586}
587
588fn check_reassignment_stmt(stmt: &Stmt<'_>, sym: Symbol) -> bool {
589    match stmt {
590        Stmt::Set { target, .. } => *target == sym,
591        Stmt::If { then_block, else_block, .. } => {
592            has_reassignment_on(then_block, sym)
593                || else_block.as_ref().map_or(false, |eb| has_reassignment_on(eb, sym))
594        }
595        Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
596            has_reassignment_on(body, sym)
597        }
598        Stmt::Inspect { arms, .. } => {
599            arms.iter().any(|arm| has_reassignment_on(arm.body, sym))
600        }
601        _ => false,
602    }
603}
604
605fn is_consumed_param(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
606    for stmt in stmts {
607        match stmt {
608            Stmt::Let { mutable: true, value, .. } => {
609                if matches!(value, Expr::Identifier(s) if *s == sym) {
610                    return true;
611                }
612            }
613            Stmt::If { then_block, else_block, .. } => {
614                if is_consumed_param(then_block, sym) { return true; }
615                if let Some(else_b) = else_block {
616                    if is_consumed_param(else_b, sym) { return true; }
617                }
618            }
619            Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
620                if is_consumed_param(body, sym) { return true; }
621            }
622            _ => {}
623        }
624    }
625    false
626}
627
628fn is_sole_return_param(stmts: &[Stmt<'_>], sym: Symbol) -> bool {
629    let mut returns = Vec::new();
630    collect_returns(stmts, &mut returns);
631    !returns.is_empty() && returns.iter().all(|r| *r == sym)
632}
633
634fn collect_returns(stmts: &[Stmt<'_>], returns: &mut Vec<Symbol>) {
635    for stmt in stmts {
636        match stmt {
637            Stmt::Return { value: Some(expr) } => {
638                if let Expr::Identifier(sym) = expr {
639                    returns.push(*sym);
640                } else {
641                    // Non-identifier return — sentinel that won't match
642                    returns.push(Symbol::EMPTY);
643                }
644            }
645            Stmt::If { then_block, else_block, .. } => {
646                collect_returns(then_block, returns);
647                if let Some(else_b) = else_block {
648                    collect_returns(else_b, returns);
649                }
650            }
651            Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
652                collect_returns(body, returns);
653            }
654            Stmt::Inspect { arms, .. } => {
655                for arm in arms {
656                    collect_returns(arm.body, returns);
657                }
658            }
659            _ => {}
660        }
661    }
662}
663
664// =============================================================================
665// Call-site compatibility for &mut [T]
666// =============================================================================
667
668/// Collect functions in `mutable_borrow` that have incompatible call sites.
669/// An incompatible call site is one where the function's return value is used
670/// (e.g., in Let, Show, Return, or expression context) because &mut [T]
671/// functions have void return.
672fn collect_incompatible_mut_borrow_callsites(
673    stmts: &[Stmt<'_>],
674    mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
675    fn_params: &HashMap<Symbol, Vec<Symbol>>,
676) -> HashSet<Symbol> {
677    let mut incompatible = HashSet::new();
678    for stmt in stmts {
679        if let Stmt::FunctionDef { body, .. } = stmt {
680            check_callsite_compat_stmts(body, mutable_borrow, fn_params, &mut incompatible);
681        }
682    }
683    // Also check main-level statements (not inside function defs)
684    check_callsite_compat_stmts(stmts, mutable_borrow, fn_params, &mut incompatible);
685    incompatible
686}
687
688fn check_callsite_compat_stmts(
689    stmts: &[Stmt<'_>],
690    mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
691    fn_params: &HashMap<Symbol, Vec<Symbol>>,
692    incompatible: &mut HashSet<Symbol>,
693) {
694    for stmt in stmts {
695        check_callsite_compat_stmt(stmt, mutable_borrow, fn_params, incompatible);
696    }
697}
698
699fn check_callsite_compat_stmt(
700    stmt: &Stmt<'_>,
701    mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
702    fn_params: &HashMap<Symbol, Vec<Symbol>>,
703    incompatible: &mut HashSet<Symbol>,
704) {
705    match stmt {
706        // Stmt::Call → void context, always OK. But check args for nested calls.
707        Stmt::Call { args, .. } => {
708            for arg in args {
709                check_callsite_compat_expr(arg, mutable_borrow, incompatible);
710            }
711        }
712        // Stmt::Set → OK if target == args[mut_borrow_pos], otherwise check expr
713        Stmt::Set { target, value } => {
714            if let Expr::Call { function, args } = value {
715                if mutable_borrow.contains_key(function) {
716                    // Check that target is at a mut_borrow position
717                    let mut_positions: HashSet<usize> = fn_params.get(function)
718                        .map(|params| {
719                            params.iter().enumerate()
720                                .filter(|(_, sym)| {
721                                    mutable_borrow.get(function)
722                                        .map(|s| s.contains(sym))
723                                        .unwrap_or(false)
724                                })
725                                .map(|(i, _)| i)
726                                .collect()
727                        })
728                        .unwrap_or_default();
729
730                    let target_at_mut_pos = args.iter().enumerate()
731                        .any(|(i, a)| {
732                            mut_positions.contains(&i)
733                                && matches!(a, Expr::Identifier(sym) if *sym == *target)
734                        });
735
736                    if !target_at_mut_pos {
737                        incompatible.insert(*function);
738                    }
739                }
740                // Check args for nested calls
741                for arg in args {
742                    check_callsite_compat_expr(arg, mutable_borrow, incompatible);
743                }
744            } else {
745                check_callsite_compat_expr(value, mutable_borrow, incompatible);
746            }
747        }
748        // Stmt::Let → if value is a call to a mut_borrow function, it's incompatible
749        Stmt::Let { value, .. } => {
750            check_callsite_compat_expr(value, mutable_borrow, incompatible);
751        }
752        Stmt::Return { value: Some(v) } => {
753            check_callsite_compat_expr(v, mutable_borrow, incompatible);
754        }
755        Stmt::Show { object, .. } => {
756            check_callsite_compat_expr(object, mutable_borrow, incompatible);
757        }
758        Stmt::Push { value, collection } => {
759            check_callsite_compat_expr(value, mutable_borrow, incompatible);
760            check_callsite_compat_expr(collection, mutable_borrow, incompatible);
761        }
762        Stmt::SetIndex { collection, index, value } => {
763            check_callsite_compat_expr(collection, mutable_borrow, incompatible);
764            check_callsite_compat_expr(index, mutable_borrow, incompatible);
765            check_callsite_compat_expr(value, mutable_borrow, incompatible);
766        }
767        Stmt::If { cond, then_block, else_block } => {
768            check_callsite_compat_expr(cond, mutable_borrow, incompatible);
769            check_callsite_compat_stmts(then_block, mutable_borrow, fn_params, incompatible);
770            if let Some(else_b) = else_block {
771                check_callsite_compat_stmts(else_b, mutable_borrow, fn_params, incompatible);
772            }
773        }
774        Stmt::While { cond, body, .. } => {
775            check_callsite_compat_expr(cond, mutable_borrow, incompatible);
776            check_callsite_compat_stmts(body, mutable_borrow, fn_params, incompatible);
777        }
778        Stmt::Repeat { iterable, body, .. } => {
779            check_callsite_compat_expr(iterable, mutable_borrow, incompatible);
780            check_callsite_compat_stmts(body, mutable_borrow, fn_params, incompatible);
781        }
782        Stmt::Inspect { arms, .. } => {
783            for arm in arms {
784                check_callsite_compat_stmts(arm.body, mutable_borrow, fn_params, incompatible);
785            }
786        }
787        // Skip FunctionDef — handled at top level
788        _ => {}
789    }
790}
791
792/// Check if an expression uses a mut_borrow function in value context (incompatible).
793fn check_callsite_compat_expr(
794    expr: &Expr<'_>,
795    mutable_borrow: &HashMap<Symbol, HashSet<Symbol>>,
796    incompatible: &mut HashSet<Symbol>,
797) {
798    match expr {
799        Expr::Call { function, args } => {
800            if mutable_borrow.contains_key(function) {
801                incompatible.insert(*function);
802            }
803            for arg in args {
804                check_callsite_compat_expr(arg, mutable_borrow, incompatible);
805            }
806        }
807        Expr::BinaryOp { left, right, .. } => {
808            check_callsite_compat_expr(left, mutable_borrow, incompatible);
809            check_callsite_compat_expr(right, mutable_borrow, incompatible);
810        }
811        Expr::Index { collection, index } => {
812            check_callsite_compat_expr(collection, mutable_borrow, incompatible);
813            check_callsite_compat_expr(index, mutable_borrow, incompatible);
814        }
815        Expr::Length { collection } => {
816            check_callsite_compat_expr(collection, mutable_borrow, incompatible);
817        }
818        Expr::Contains { collection, value } => {
819            check_callsite_compat_expr(collection, mutable_borrow, incompatible);
820            check_callsite_compat_expr(value, mutable_borrow, incompatible);
821        }
822        Expr::FieldAccess { object, .. } => {
823            check_callsite_compat_expr(object, mutable_borrow, incompatible);
824        }
825        Expr::Copy { expr: inner } | Expr::Give { value: inner } | Expr::OptionSome { value: inner } => {
826            check_callsite_compat_expr(inner, mutable_borrow, incompatible);
827        }
828        _ => {}
829    }
830}
831
832// =============================================================================
833// Consume-Alias Detection for &mut [T]
834// =============================================================================
835
836/// Detect consume-alias pattern: finds exactly one `Let mutable <alias> be <param>`
837/// at the top level of the function body. Returns the alias symbol if found.
838fn detect_consume_alias(body: &[Stmt<'_>], param_sym: Symbol) -> Option<Symbol> {
839    let mut alias = None;
840    for stmt in body {
841        if let Stmt::Let { var, mutable: true, value, .. } = stmt {
842            if matches!(value, Expr::Identifier(s) if *s == param_sym) {
843                if alias.is_some() {
844                    return None; // Multiple consumes — reject
845                }
846                alias = Some(*var);
847            }
848        }
849    }
850    alias
851}
852
853/// Check that every return in the body returns either `param_sym` or `alias_sym`.
854fn is_sole_return_param_or_alias(stmts: &[Stmt<'_>], param_sym: Symbol, alias_sym: Symbol) -> bool {
855    let mut returns = Vec::new();
856    collect_returns(stmts, &mut returns);
857    !returns.is_empty() && returns.iter().all(|r| *r == param_sym || *r == alias_sym)
858}
859
860/// Check that every `Set <alias> to <expr>` in the body is a call to `func_name`
861/// with `alias` at position `param_position`. No other reassignment patterns allowed.
862fn reassignment_only_self_calls(
863    body: &[Stmt<'_>],
864    alias: Symbol,
865    func_name: Symbol,
866    param_position: usize,
867) -> bool {
868    check_reassignment_self_calls(body, alias, func_name, param_position)
869}
870
871fn check_reassignment_self_calls(
872    stmts: &[Stmt<'_>],
873    alias: Symbol,
874    func_name: Symbol,
875    param_position: usize,
876) -> bool {
877    for stmt in stmts {
878        match stmt {
879            Stmt::Set { target, value } if *target == alias => {
880                // Must be a call to func_name with alias at param_position
881                match value {
882                    Expr::Call { function, args } if *function == func_name => {
883                        let arg_at_pos = args.get(param_position);
884                        let is_alias_at_pos = arg_at_pos
885                            .map(|a| matches!(a, Expr::Identifier(s) if *s == alias))
886                            .unwrap_or(false);
887                        if !is_alias_at_pos {
888                            return false;
889                        }
890                    }
891                    _ => return false, // Non-self-call reassignment
892                }
893            }
894            Stmt::If { then_block, else_block, .. } => {
895                if !check_reassignment_self_calls(then_block, alias, func_name, param_position) {
896                    return false;
897                }
898                if let Some(else_b) = else_block {
899                    if !check_reassignment_self_calls(else_b, alias, func_name, param_position) {
900                        return false;
901                    }
902                }
903            }
904            Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
905                if !check_reassignment_self_calls(body, alias, func_name, param_position) {
906                    return false;
907                }
908            }
909            Stmt::Inspect { arms, .. } => {
910                for arm in arms {
911                    if !check_reassignment_self_calls(arm.body, alias, func_name, param_position) {
912                        return false;
913                    }
914                }
915            }
916            _ => {}
917        }
918    }
919    true
920}
921
922/// Check that `param_sym` is dead after the consume statement
923/// (`Let mutable <alias> be <param>`). Scans top-level statements:
924/// before the consume, param can be used freely. After the consume,
925/// param must never appear in any expression.
926fn is_param_dead_after_consume(body: &[Stmt<'_>], param_sym: Symbol, alias: Symbol) -> bool {
927    let mut found_consume = false;
928    for stmt in body {
929        if !found_consume {
930            // Check if this is the consume statement
931            if let Stmt::Let { var, mutable: true, value, .. } = stmt {
932                if *var == alias && matches!(value, Expr::Identifier(s) if *s == param_sym) {
933                    found_consume = true;
934                    continue;
935                }
936            }
937        } else {
938            // After the consume: param must not appear
939            if stmt_references_symbol(stmt, param_sym) {
940                return false;
941            }
942        }
943    }
944    found_consume // Must have actually found the consume
945}
946
947/// Check if a statement references a given symbol anywhere (expressions, collections, etc.).
948fn stmt_references_symbol(stmt: &Stmt<'_>, sym: Symbol) -> bool {
949    match stmt {
950        Stmt::Let { value, .. } => expr_references_symbol(value, sym),
951        Stmt::Set { target, value } => *target == sym || expr_references_symbol(value, sym),
952        Stmt::Call { function, args } => {
953            *function == sym || args.iter().any(|a| expr_references_symbol(a, sym))
954        }
955        Stmt::Push { value, collection } => {
956            expr_references_symbol(value, sym) || expr_references_symbol(collection, sym)
957        }
958        Stmt::Pop { collection, into } => {
959            expr_references_symbol(collection, sym)
960                || into.map_or(false, |s| s == sym)
961        }
962        Stmt::Add { value, collection } | Stmt::Remove { value, collection } => {
963            expr_references_symbol(value, sym) || expr_references_symbol(collection, sym)
964        }
965        Stmt::SetIndex { collection, index, value } => {
966            expr_references_symbol(collection, sym)
967                || expr_references_symbol(index, sym)
968                || expr_references_symbol(value, sym)
969        }
970        Stmt::SetField { object, value, .. } => {
971            expr_references_symbol(object, sym) || expr_references_symbol(value, sym)
972        }
973        Stmt::Return { value: Some(v) } => expr_references_symbol(v, sym),
974        Stmt::Return { value: None } => false,
975        Stmt::If { cond, then_block, else_block } => {
976            expr_references_symbol(cond, sym)
977                || then_block.iter().any(|s| stmt_references_symbol(s, sym))
978                || else_block.as_ref().map_or(false, |eb| eb.iter().any(|s| stmt_references_symbol(s, sym)))
979        }
980        Stmt::While { cond, body, .. } => {
981            expr_references_symbol(cond, sym)
982                || body.iter().any(|s| stmt_references_symbol(s, sym))
983        }
984        Stmt::Repeat { iterable, body, .. } => {
985            expr_references_symbol(iterable, sym)
986                || body.iter().any(|s| stmt_references_symbol(s, sym))
987        }
988        Stmt::Inspect { arms, .. } => {
989            arms.iter().any(|arm| arm.body.iter().any(|s| stmt_references_symbol(s, sym)))
990        }
991        Stmt::Show { object, .. } => expr_references_symbol(object, sym),
992        _ => false,
993    }
994}
995
996fn expr_references_symbol(expr: &Expr<'_>, sym: Symbol) -> bool {
997    match expr {
998        Expr::Identifier(s) => *s == sym,
999        Expr::BinaryOp { left, right, .. } => {
1000            expr_references_symbol(left, sym) || expr_references_symbol(right, sym)
1001        }
1002        Expr::Not { operand } => expr_references_symbol(operand, sym),
1003        Expr::Call { function, args } => {
1004            *function == sym || args.iter().any(|a| expr_references_symbol(a, sym))
1005        }
1006        Expr::Index { collection, index } => {
1007            expr_references_symbol(collection, sym) || expr_references_symbol(index, sym)
1008        }
1009        Expr::Length { collection } => expr_references_symbol(collection, sym),
1010        Expr::Contains { collection, value } => {
1011            expr_references_symbol(collection, sym) || expr_references_symbol(value, sym)
1012        }
1013        Expr::FieldAccess { object, .. } => expr_references_symbol(object, sym),
1014        Expr::Slice { collection, start, end } => {
1015            expr_references_symbol(collection, sym)
1016                || expr_references_symbol(start, sym)
1017                || expr_references_symbol(end, sym)
1018        }
1019        Expr::Copy { expr: inner } | Expr::Give { value: inner } | Expr::OptionSome { value: inner } => {
1020            expr_references_symbol(inner, sym)
1021        }
1022        Expr::CallExpr { callee, args } => {
1023            expr_references_symbol(callee, sym)
1024                || args.iter().any(|a| expr_references_symbol(a, sym))
1025        }
1026        _ => false,
1027    }
1028}