1use std::sync::Arc;
40
41use crate::ast::{Expr, FnBody, FnDef, Stmt, StrPart, TopLevel};
42
43pub fn annotate_program_alias_slots(items: &mut [TopLevel]) {
44 for item in items {
45 if let TopLevel::FnDef(fd) = item {
46 annotate_fn(fd);
47 }
48 }
49}
50
51fn annotate_fn(fd: &mut FnDef) {
52 let Some(res) = fd.resolution.clone() else {
53 return;
54 };
55 let local_count = res.local_count as usize;
56 let mut aliased = vec![false; local_count];
57
58 for (i, (_, ty)) in fd.params.iter().enumerate() {
60 if param_type_is_alias_prone(ty)
61 && let Some(slot) = aliased.get_mut(i)
62 {
63 *slot = true;
64 }
65 }
66
67 let body = fd.body.clone();
71 let FnBody::Block(stmts) = body.as_ref();
72 for _ in 0..2 {
73 for stmt in stmts {
74 if let Stmt::Binding(name, _, expr) = stmt {
75 let Some(&slot) = res.local_slots.get(name) else {
76 continue;
77 };
78 if expr_is_alias_source(&expr.node, &aliased)
79 && let Some(s) = aliased.get_mut(slot as usize)
80 {
81 *s = true;
82 }
83 }
84 }
85 }
86
87 let new_res = crate::ast::FnResolution {
90 local_count: res.local_count,
91 local_slots: res.local_slots.clone(),
92 local_slot_types: res.local_slot_types.clone(),
93 aliased_slots: Arc::new(aliased),
94 };
95 fd.resolution = Some(new_res);
96}
97
/// Returns `true` when a parameter's declared type is a `Vector<...>` or
/// `Map<...>`.
///
/// Surrounding whitespace is ignored.
fn param_type_is_alias_prone(ty: &str) -> bool {
    const ALIAS_PRONE_PREFIXES: [&str; 2] = ["Vector<", "Map<"];
    let ty = ty.trim();
    ALIAS_PRONE_PREFIXES
        .iter()
        .any(|&prefix| ty.starts_with(prefix))
}
102
103fn expr_is_alias_source(expr: &Expr, aliased: &[bool]) -> bool {
104 if let Expr::Resolved { slot, .. } = expr
105 && aliased.get(*slot as usize).copied().unwrap_or(false)
106 {
107 return true;
108 }
109 contains_alias_source_call(expr)
110}
111
/// Returns `true` if `expr`, or any sub-expression reached by the recursion
/// below, is a call this pass treats as an alias source.
///
/// Two call shapes are recognised, both with a bare identifier receiver:
/// - `Vector.get(..)` and `Map.get(..)`.
/// - `Vector.new(a, b)` with exactly two arguments, where the second
///   argument's type (from `ty()`) displays as compound per `type_is_compound`.
///   Presumably this is because every element would then refer to the same
///   compound value — NOTE(review): confirm against the runtime.
///
/// The search recurses into:
/// - call callees and arguments, and attribute receivers;
/// - binary operands;
/// - match subjects and arm bodies (only `body` is inspected per arm);
/// - constructor payloads and `?` operands;
/// - tuple/list/independent-product items;
/// - map-literal keys and values;
/// - record fields and record-update bases;
/// - interpolated-string parts.
///
/// Literals, identifiers, resolved slots and tail calls are never sources.
fn contains_alias_source_call(expr: &Expr) -> bool {
    match expr {
        Expr::FnCall(callee, args) => {
            // Builtin calls of the form `Ident.member(..)`.
            if let Expr::Attr(parent, member) = &callee.node
                && let Expr::Ident(p) = &parent.node
            {
                if (p == "Vector" || p == "Map") && member == "get" {
                    return true;
                }
                // `Vector.new(_, elem)` with a compound `elem`. The check is
                // skipped when the argument carries no type (`ty()` is None).
                if p == "Vector"
                    && member == "new"
                    && args.len() == 2
                    && let Some(t) = args[1].ty()
                    && type_is_compound(&t.display())
                {
                    return true;
                }
            }
            // Not a recognised builtin itself: look inside callee and args.
            if contains_alias_source_call(&callee.node) {
                return true;
            }
            args.iter().any(|a| contains_alias_source_call(&a.node))
        }
        Expr::Attr(inner, _) => contains_alias_source_call(&inner.node),
        Expr::BinOp(_, lhs, rhs) => {
            contains_alias_source_call(&lhs.node) || contains_alias_source_call(&rhs.node)
        }
        Expr::Match { subject, arms } => {
            contains_alias_source_call(&subject.node)
                || arms
                    .iter()
                    .any(|a| contains_alias_source_call(&a.body.node))
        }
        // Payload-less constructors cannot contain a source.
        Expr::Constructor(_, payload) => payload
            .as_ref()
            .is_some_and(|p| contains_alias_source_call(&p.node)),
        Expr::ErrorProp(inner) => contains_alias_source_call(&inner.node),
        Expr::Tuple(items) | Expr::List(items) | Expr::IndependentProduct(items, _) => {
            items.iter().any(|i| contains_alias_source_call(&i.node))
        }
        Expr::MapLiteral(pairs) => pairs.iter().any(|(k, v)| {
            contains_alias_source_call(&k.node) || contains_alias_source_call(&v.node)
        }),
        Expr::RecordCreate { fields, .. } => fields
            .iter()
            .any(|(_, e)| contains_alias_source_call(&e.node)),
        Expr::RecordUpdate { base, updates, .. } => {
            contains_alias_source_call(&base.node)
                || updates
                    .iter()
                    .any(|(_, e)| contains_alias_source_call(&e.node))
        }
        Expr::InterpolatedStr(parts) => parts.iter().any(|p| match p {
            StrPart::Parsed(e) => contains_alias_source_call(&e.node),
            StrPart::Literal(_) => false,
        }),
        // Leaves (and tail calls) are never treated as alias sources here.
        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::TailCall(_) => false,
    }
}
171
/// Heuristically decides whether a displayed type name denotes a compound
/// (non-scalar) type.
///
/// After trimming whitespace, a type is compound when both of these hold:
/// - it starts with an ASCII uppercase letter;
/// - it is not one of the scalar builtins `Int`, `Float`, `Bool`, `String`
///   or `Unit`.
///
/// The generic containers `Vector<..>`, `Map<..>`, `List<..>`, `Tuple<..>`,
/// `Result<..>` and `Option<..>` all satisfy that rule. Anything starting with
/// another character (lowercase, `(`, a digit) or an empty string does not.
fn type_is_compound(ty: &str) -> bool {
    const SCALARS: [&str; 5] = ["Int", "Float", "Bool", "String", "Unit"];
    let ty = ty.trim();
    let starts_uppercase = matches!(ty.chars().next(), Some(c) if c.is_ascii_uppercase());
    starts_uppercase && !SCALARS.contains(&ty)
}