1use std::collections::HashMap;
26
27use crate::ast::{Expr, FnBody, FnDef, Stmt, StrPart};
28
29use super::calls::{expr_to_dotted_name, is_builtin_namespace};
30
/// Describes which primitive operations allocate.
///
/// The analyses in this module (`compute_alloc_info`, `count_alloc_sites_in_fn`,
/// `count_alloc_sites_in_program`) are generic over this trait, so the caller
/// decides how builtins and constructors are classified.
pub trait AllocPolicy {
    /// Returns whether applying constructor `name` allocates.
    ///
    /// `has_payload` says whether the constructor carries a payload. Within
    /// this module it is only called with `true`: payload-less constructors
    /// are treated as non-allocating without consulting the policy.
    fn constructor_allocates(&self, name: &str, has_payload: bool) -> bool;

    /// Returns whether calling the builtin `name` allocates.
    ///
    /// `name` is the full dotted call name (for example `"String.fromInt"`).
    /// Within this module it is only consulted when the first `.`-separated
    /// segment of that name is a builtin namespace.
    fn builtin_allocates(&self, name: &str) -> bool;
}
47
48fn expr_allocates<P: AllocPolicy>(
52 expr: &Expr,
53 user_allocates: &HashMap<String, bool>,
54 policy: &P,
55) -> bool {
56 match expr {
57 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => false,
58 Expr::Constructor(_, None) => false,
59
60 Expr::List(_)
63 | Expr::Tuple(_)
64 | Expr::MapLiteral(_)
65 | Expr::RecordCreate { .. }
66 | Expr::RecordUpdate { .. }
67 | Expr::IndependentProduct(_, _) => true,
68 Expr::InterpolatedStr(parts) => {
69 parts.iter().any(|p| matches!(p, StrPart::Parsed(_)))
73 || expr_children_allocate(expr, user_allocates, policy)
74 }
75 Expr::Constructor(name, Some(payload)) => {
76 policy.constructor_allocates(name, true)
77 || expr_allocates(&payload.node, user_allocates, policy)
78 }
79
80 Expr::FnCall(callee, args) => {
82 if let Some(name) = expr_to_dotted_name(&callee.node) {
83 let ns = name.split('.').next().unwrap_or("");
84 if is_builtin_namespace(ns) {
85 if policy.builtin_allocates(&name) {
86 return true;
87 }
88 } else if let Some(&true) = user_allocates.get(&name) {
89 return true;
90 }
91 }
92 args.iter()
93 .any(|a| expr_allocates(&a.node, user_allocates, policy))
94 }
95 Expr::TailCall(data) => {
96 if let Some(&true) = user_allocates.get(&data.target) {
97 return true;
98 }
99 data.args
100 .iter()
101 .any(|a| expr_allocates(&a.node, user_allocates, policy))
102 }
103
104 Expr::Attr(base, _) | Expr::ErrorProp(base) => {
106 expr_allocates(&base.node, user_allocates, policy)
107 }
108 Expr::BinOp(_, l, r) => {
109 expr_allocates(&l.node, user_allocates, policy)
110 || expr_allocates(&r.node, user_allocates, policy)
111 }
112 Expr::Match { subject, arms } => {
113 expr_allocates(&subject.node, user_allocates, policy)
114 || arms
115 .iter()
116 .any(|a| expr_allocates(&a.body.node, user_allocates, policy))
117 }
118 }
119}
120
121fn expr_children_allocate<P: AllocPolicy>(
123 expr: &Expr,
124 user_allocates: &HashMap<String, bool>,
125 policy: &P,
126) -> bool {
127 if let Expr::InterpolatedStr(parts) = expr {
128 return parts.iter().any(|p| match p {
129 StrPart::Literal(_) => false,
130 StrPart::Parsed(e) => expr_allocates(&e.node, user_allocates, policy),
131 });
132 }
133 false
134}
135
136fn body_allocates<P: AllocPolicy>(
138 body: &FnBody,
139 user_allocates: &HashMap<String, bool>,
140 policy: &P,
141) -> bool {
142 body.stmts().iter().any(|s| match s {
143 Stmt::Binding(_, _, e) | Stmt::Expr(e) => expr_allocates(&e.node, user_allocates, policy),
144 })
145}
146
147pub fn count_alloc_sites_in_fn<P: AllocPolicy>(fd: &FnDef, policy: &P) -> usize {
158 let FnBody::Block(stmts) = fd.body.as_ref();
159 let mut acc = 0;
160 for stmt in stmts {
161 match stmt {
162 Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
163 count_expr_alloc_sites(&e.node, policy, &mut acc)
164 }
165 }
166 }
167 acc
168}
169
170fn count_expr_alloc_sites<P: AllocPolicy>(expr: &Expr, policy: &P, acc: &mut usize) {
171 match expr {
172 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => {}
173 Expr::Constructor(_, None) => {}
174
175 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
176 *acc += 1;
177 for item in items {
178 count_expr_alloc_sites(&item.node, policy, acc);
179 }
180 }
181 Expr::MapLiteral(entries) => {
182 *acc += 1;
183 for (k, v) in entries {
184 count_expr_alloc_sites(&k.node, policy, acc);
185 count_expr_alloc_sites(&v.node, policy, acc);
186 }
187 }
188 Expr::RecordCreate { fields, .. } => {
189 *acc += 1;
190 for (_, v) in fields {
191 count_expr_alloc_sites(&v.node, policy, acc);
192 }
193 }
194 Expr::RecordUpdate { base, updates, .. } => {
195 *acc += 1;
196 count_expr_alloc_sites(&base.node, policy, acc);
197 for (_, v) in updates {
198 count_expr_alloc_sites(&v.node, policy, acc);
199 }
200 }
201 Expr::InterpolatedStr(parts) => {
202 if parts.iter().any(|p| matches!(p, StrPart::Parsed(_))) {
206 *acc += 1;
207 }
208 for part in parts {
209 if let StrPart::Parsed(e) = part {
210 count_expr_alloc_sites(&e.node, policy, acc);
211 }
212 }
213 }
214 Expr::Constructor(name, Some(payload)) => {
215 if policy.constructor_allocates(name, true) {
216 *acc += 1;
217 }
218 count_expr_alloc_sites(&payload.node, policy, acc);
219 }
220 Expr::FnCall(callee, args) => {
221 if let Some(name) = expr_to_dotted_name(&callee.node) {
222 let ns = name.split('.').next().unwrap_or("");
223 if is_builtin_namespace(ns) && policy.builtin_allocates(&name) {
224 *acc += 1;
225 }
226 }
227 count_expr_alloc_sites(&callee.node, policy, acc);
228 for a in args {
229 count_expr_alloc_sites(&a.node, policy, acc);
230 }
231 }
232 Expr::TailCall(data) => {
233 for a in &data.args {
234 count_expr_alloc_sites(&a.node, policy, acc);
235 }
236 }
237 Expr::Attr(base, _) | Expr::ErrorProp(base) => {
238 count_expr_alloc_sites(&base.node, policy, acc);
239 }
240 Expr::BinOp(_, l, r) => {
241 count_expr_alloc_sites(&l.node, policy, acc);
242 count_expr_alloc_sites(&r.node, policy, acc);
243 }
244 Expr::Match { subject, arms } => {
245 count_expr_alloc_sites(&subject.node, policy, acc);
246 for arm in arms {
247 count_expr_alloc_sites(&arm.body.node, policy, acc);
248 }
249 }
250 }
251}
252
253pub fn count_alloc_sites_in_program<P: AllocPolicy>(
256 items: &[crate::ast::TopLevel],
257 policy: &P,
258) -> usize {
259 items
260 .iter()
261 .filter_map(|it| match it {
262 crate::ast::TopLevel::FnDef(fd) => Some(count_alloc_sites_in_fn(fd, policy)),
263 _ => None,
264 })
265 .sum()
266}
267
268pub fn compute_alloc_info<P: AllocPolicy>(fns: &[&FnDef], policy: &P) -> HashMap<String, bool> {
279 let mut info: HashMap<String, bool> = fns
280 .iter()
281 .map(|fd| {
282 (fd.name.clone(), !fd.effects.is_empty())
284 })
285 .collect();
286
287 loop {
288 let mut changed = false;
289 for fd in fns {
290 if *info.get(&fd.name).unwrap_or(&false) {
291 continue;
292 }
293 if body_allocates(&fd.body, &info, policy) {
294 info.insert(fd.name.clone(), true);
295 changed = true;
296 }
297 }
298 if !changed {
299 break;
300 }
301 }
302
303 info
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{BinOp, FnDef, Literal, Spanned};
    use std::sync::Arc;

    /// Minimal policy: every `Map.*` builtin and `String.fromInt` allocate;
    /// no constructor does.
    struct TestPolicy;

    impl AllocPolicy for TestPolicy {
        fn builtin_allocates(&self, name: &str) -> bool {
            name.starts_with("Map.") || name == "String.fromInt"
        }
        fn constructor_allocates(&self, _name: &str, _has_payload: bool) -> bool {
            false
        }
    }

    /// Wraps `value` in a `Spanned` at line 1.
    fn sp<T>(value: T) -> Spanned<T> {
        Spanned::new(value, 1)
    }

    fn lit_int(n: i64) -> Spanned<Expr> {
        sp(Expr::Literal(Literal::Int(n)))
    }

    /// Builds the call `ns.method(args...)`, with `ns` as a bare identifier.
    fn builtin_call(ns: &'static str, method: &'static str, args: Vec<Spanned<Expr>>) -> Expr {
        Expr::FnCall(
            Box::new(sp(Expr::Attr(
                Box::new(sp(Expr::Ident(ns.into()))),
                method.into(),
            ))),
            args,
        )
    }

    /// Builds the call `name(args...)`, with the callee as a bare identifier.
    fn user_call(name: &'static str, args: Vec<Spanned<Expr>>) -> Expr {
        Expr::FnCall(Box::new(sp(Expr::Ident(name.into()))), args)
    }

    /// A function with no params and no effects whose body is `body`.
    fn fn_def_pure(name: &str, body: Expr) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: vec![],
            return_type: "Int".into(),
            effects: vec![],
            desc: None,
            body: Arc::new(FnBody::from_expr(sp(body))),
            resolution: None,
        }
    }

    #[test]
    fn pure_arithmetic_does_not_allocate() {
        let fd = fn_def_pure(
            "addOne",
            Expr::BinOp(BinOp::Add, Box::new(lit_int(1)), Box::new(lit_int(2))),
        );
        let info = compute_alloc_info(&[&fd], &TestPolicy);
        assert_eq!(info.get("addOne"), Some(&false));
    }

    #[test]
    fn list_literal_allocates() {
        let fd = fn_def_pure("makeList", Expr::List(vec![lit_int(1), lit_int(2)]));
        let info = compute_alloc_info(&[&fd], &TestPolicy);
        assert_eq!(info.get("makeList"), Some(&true));
    }

    #[test]
    fn allocating_builtin_call_allocates() {
        // String.fromInt(42)
        let fd = fn_def_pure("stringify", builtin_call("String", "fromInt", vec![lit_int(42)]));
        let info = compute_alloc_info(&[&fd], &TestPolicy);
        assert_eq!(info.get("stringify"), Some(&true));
    }

    #[test]
    fn pure_builtin_call_does_not_allocate() {
        // Int.abs(-5)
        let fd = fn_def_pure("absVal", builtin_call("Int", "abs", vec![lit_int(-5)]));
        let info = compute_alloc_info(&[&fd], &TestPolicy);
        assert_eq!(info.get("absVal"), Some(&false));
    }

    #[test]
    fn effects_force_allocating() {
        let mut fd = fn_def_pure("logIt", Expr::Literal(Literal::Int(0)));
        fd.effects = vec![sp("Console.print".into())];
        let info = compute_alloc_info(&[&fd], &TestPolicy);
        assert_eq!(info.get("logIt"), Some(&true));
    }

    #[test]
    fn transitive_user_call_propagates() {
        let inner = fn_def_pure("makeListInner", Expr::List(vec![lit_int(1)]));
        let wrapper = fn_def_pure("wrapperFn", user_call("makeListInner", vec![]));

        let info = compute_alloc_info(&[&inner, &wrapper], &TestPolicy);
        assert_eq!(info.get("makeListInner"), Some(&true));
        assert_eq!(info.get("wrapperFn"), Some(&true));
    }

    #[test]
    fn mutual_recursion_pure_stays_pure() {
        // f calls g, g calls f; neither allocates, so both stay `false`.
        let f = fn_def_pure("f", user_call("g", vec![lit_int(1)]));
        let g = fn_def_pure("g", user_call("f", vec![lit_int(2)]));
        let info = compute_alloc_info(&[&f, &g], &TestPolicy);
        assert_eq!(info.get("f"), Some(&false));
        assert_eq!(info.get("g"), Some(&false));
    }

    #[test]
    fn mutual_recursion_one_allocates_taints_the_group() {
        // f calls g, and g builds a list.
        let f = fn_def_pure("f", user_call("g", vec![lit_int(1)]));
        let g = fn_def_pure("g", Expr::List(vec![lit_int(0)]));
        let info = compute_alloc_info(&[&f, &g], &TestPolicy);
        assert_eq!(info.get("f"), Some(&true));
        assert_eq!(info.get("g"), Some(&true));
    }

    #[test]
    fn count_sites_nested_aggregates() {
        // [[1], 2]: the outer and inner lists are one site each.
        let inner = sp(Expr::List(vec![lit_int(1)]));
        let fd = fn_def_pure("nested", Expr::List(vec![inner, lit_int(2)]));
        assert_eq!(count_alloc_sites_in_fn(&fd, &TestPolicy), 2);
    }

    #[test]
    fn count_sites_builtin_calls() {
        let allocating = fn_def_pure("s", builtin_call("String", "fromInt", vec![lit_int(42)]));
        assert_eq!(count_alloc_sites_in_fn(&allocating, &TestPolicy), 1);

        let pure = fn_def_pure("a", builtin_call("Int", "abs", vec![lit_int(-5)]));
        assert_eq!(count_alloc_sites_in_fn(&pure, &TestPolicy), 0);
    }

    #[test]
    fn count_sites_are_local_to_the_function() {
        // `wrapperFn` is allocating transitively, but has no site of its own.
        let inner = fn_def_pure("makeListInner", Expr::List(vec![lit_int(1)]));
        let wrapper = fn_def_pure("wrapperFn", user_call("makeListInner", vec![]));
        let info = compute_alloc_info(&[&inner, &wrapper], &TestPolicy);
        assert_eq!(info.get("wrapperFn"), Some(&true));
        assert_eq!(count_alloc_sites_in_fn(&inner, &TestPolicy), 1);
        assert_eq!(count_alloc_sites_in_fn(&wrapper, &TestPolicy), 0);
    }
}