use std::collections::HashMap;
use crate::ast::{Expr, FnBody, FnDef, Stmt, StrPart};
use super::calls::{expr_to_dotted_name, is_builtin_namespace};
/// Backend-specific answers to "does this operation allocate?".
///
/// The analyses in this file are generic over an implementation of this
/// trait and consult it for builtin calls and payload-carrying constructors.
pub trait AllocPolicy {
/// Returns `true` if calling the builtin `name` allocates.
///
/// In this file `name` is the full dotted name of a callee whose first
/// segment is a builtin namespace.
fn builtin_allocates(&self, name: &str) -> bool;
/// Returns `true` if building constructor `name` allocates.
///
/// In this file it is only consulted for constructors that carry a payload
/// (`has_payload == true`); payload-free constructors are treated as
/// non-allocating without asking the policy.
fn constructor_allocates(&self, name: &str, has_payload: bool) -> bool;
}
/// Reports whether evaluating `expr` may allocate.
///
/// * `user_allocates` - allocation flags for user-defined functions, keyed by
///   name. A call whose target is absent from the map (or mapped to `false`)
///   is not treated as allocating by the call itself.
/// * `policy` - decides which builtin calls and payload-carrying constructors
///   allocate.
///
/// List, tuple, map, record-create, record-update and independent-product
/// expressions always count as allocating, as does an interpolated string
/// with at least one parsed part. Every other expression allocates only if
/// the call it performs is flagged or one of its sub-expressions allocates.
fn expr_allocates<P: AllocPolicy>(
    expr: &Expr,
    user_allocates: &HashMap<String, bool>,
    policy: &P,
) -> bool {
    match expr {
        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => false,
        // Payload-free constructors never allocate; the policy is not consulted.
        Expr::Constructor(_, None) => false,
        Expr::List(_)
        | Expr::Tuple(_)
        | Expr::MapLiteral(_)
        | Expr::RecordCreate { .. }
        | Expr::RecordUpdate { .. }
        | Expr::IndependentProduct(_, _) => true,
        Expr::InterpolatedStr(parts) => {
            // Any parsed part already makes the string allocating; the child
            // walk is only reached when there are no parsed parts.
            parts.iter().any(|p| matches!(p, StrPart::Parsed(_)))
                || expr_children_allocate(expr, user_allocates, policy)
        }
        Expr::Constructor(name, Some(payload)) => {
            policy.constructor_allocates(name, true)
                || expr_allocates(&payload.node, user_allocates, policy)
        }
        Expr::FnCall(callee, args) => {
            if let Some(name) = expr_to_dotted_name(&callee.node) {
                let ns = name.split('.').next().unwrap_or("");
                if is_builtin_namespace(ns) {
                    if policy.builtin_allocates(&name) {
                        return true;
                    }
                } else if let Some(&true) = user_allocates.get(&name) {
                    return true;
                }
            }
            // Also walk the callee expression itself. For a plain dotted name
            // this is always `false`, but a computed callee (for example the
            // result of another call) can allocate, and
            // `count_expr_alloc_sites` already visits it.
            expr_allocates(&callee.node, user_allocates, policy)
                || args
                    .iter()
                    .any(|a| expr_allocates(&a.node, user_allocates, policy))
        }
        Expr::TailCall(data) => {
            if let Some(&true) = user_allocates.get(&data.target) {
                return true;
            }
            data.args
                .iter()
                .any(|a| expr_allocates(&a.node, user_allocates, policy))
        }
        Expr::Attr(base, _) | Expr::ErrorProp(base) => {
            expr_allocates(&base.node, user_allocates, policy)
        }
        Expr::BinOp(_, l, r) => {
            expr_allocates(&l.node, user_allocates, policy)
                || expr_allocates(&r.node, user_allocates, policy)
        }
        Expr::Match { subject, arms } => {
            // Only the subject and the arm bodies are inspected.
            expr_allocates(&subject.node, user_allocates, policy)
                || arms
                    .iter()
                    .any(|a| expr_allocates(&a.body.node, user_allocates, policy))
        }
    }
}
/// Reports whether any parsed part of an interpolated string allocates.
///
/// Returns `false` for every expression that is not an interpolated string,
/// and for interpolated strings made up solely of literal parts.
fn expr_children_allocate<P: AllocPolicy>(
    expr: &Expr,
    user_allocates: &HashMap<String, bool>,
    policy: &P,
) -> bool {
    match expr {
        Expr::InterpolatedStr(parts) => parts
            .iter()
            // Literal parts carry no sub-expression to inspect.
            .filter_map(|part| match part {
                StrPart::Parsed(inner) => Some(inner),
                StrPart::Literal(_) => None,
            })
            .any(|inner| expr_allocates(&inner.node, user_allocates, policy)),
        _ => false,
    }
}
/// Reports whether any statement of `body` contains an allocating expression.
///
/// Statements are checked in order and the scan stops at the first one whose
/// expression allocates according to `expr_allocates`.
fn body_allocates<P: AllocPolicy>(
    body: &FnBody,
    user_allocates: &HashMap<String, bool>,
    policy: &P,
) -> bool {
    for stmt in body.stmts().iter() {
        // Bindings and bare expression statements both carry one expression.
        let value = match stmt {
            Stmt::Binding(_, _, value) | Stmt::Expr(value) => value,
        };
        if expr_allocates(&value.node, user_allocates, policy) {
            return true;
        }
    }
    false
}
/// Counts the allocation sites in the body of `fd`.
///
/// Every statement's expression is walked with `count_expr_alloc_sites` and
/// the per-statement counts are added together.
pub fn count_alloc_sites_in_fn<P: AllocPolicy>(fd: &FnDef, policy: &P) -> usize {
    let FnBody::Block(statements) = fd.body.as_ref();
    let mut total = 0;
    statements.iter().for_each(|stmt| {
        // Bindings and bare expression statements both carry one expression.
        let value = match stmt {
            Stmt::Binding(_, _, value) | Stmt::Expr(value) => value,
        };
        count_expr_alloc_sites(&value.node, policy, &mut total);
    });
    total
}
/// Adds to `acc` the number of allocation sites in `expr` and its
/// sub-expressions.
///
/// A node counts as one site when it is a list, tuple, independent product,
/// map literal, record creation or record update; an interpolated string with
/// at least one parsed part; a payload-carrying constructor the policy flags;
/// or a call to a builtin the policy flags. Calls to user functions and tail
/// calls are not counted themselves; only their sub-expressions are walked.
fn count_expr_alloc_sites<P: AllocPolicy>(expr: &Expr, policy: &P, acc: &mut usize) {
    // Pure recursive helper: sites contributed by `node` itself plus all of
    // its children. The policy is always queried for a node before its
    // children are visited.
    fn sites<P: AllocPolicy>(node: &Expr, policy: &P) -> usize {
        match node {
            Expr::Literal(_)
            | Expr::Ident(_)
            | Expr::Resolved { .. }
            | Expr::Constructor(_, None) => 0,
            Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
                1 + items
                    .iter()
                    .map(|item| sites(&item.node, policy))
                    .sum::<usize>()
            }
            Expr::MapLiteral(entries) => {
                1 + entries
                    .iter()
                    .map(|(key, value)| sites(&key.node, policy) + sites(&value.node, policy))
                    .sum::<usize>()
            }
            Expr::RecordCreate { fields, .. } => {
                1 + fields
                    .iter()
                    .map(|(_, value)| sites(&value.node, policy))
                    .sum::<usize>()
            }
            Expr::RecordUpdate { base, updates, .. } => {
                1 + sites(&base.node, policy)
                    + updates
                        .iter()
                        .map(|(_, value)| sites(&value.node, policy))
                        .sum::<usize>()
            }
            Expr::InterpolatedStr(parts) => {
                // The string itself is a site only when something is spliced in.
                let own = usize::from(parts.iter().any(|p| matches!(p, StrPart::Parsed(_))));
                own + parts
                    .iter()
                    .map(|part| match part {
                        StrPart::Parsed(inner) => sites(&inner.node, policy),
                        StrPart::Literal(_) => 0,
                    })
                    .sum::<usize>()
            }
            Expr::Constructor(name, Some(payload)) => {
                usize::from(policy.constructor_allocates(name, true))
                    + sites(&payload.node, policy)
            }
            Expr::FnCall(callee, args) => {
                let own = match expr_to_dotted_name(&callee.node) {
                    Some(name) => {
                        let ns = name.split('.').next().unwrap_or("");
                        usize::from(is_builtin_namespace(ns) && policy.builtin_allocates(&name))
                    }
                    None => 0,
                };
                own + sites(&callee.node, policy)
                    + args
                        .iter()
                        .map(|arg| sites(&arg.node, policy))
                        .sum::<usize>()
            }
            Expr::TailCall(data) => data
                .args
                .iter()
                .map(|arg| sites(&arg.node, policy))
                .sum::<usize>(),
            Expr::Attr(base, _) | Expr::ErrorProp(base) => sites(&base.node, policy),
            Expr::BinOp(_, lhs, rhs) => sites(&lhs.node, policy) + sites(&rhs.node, policy),
            Expr::Match { subject, arms } => {
                sites(&subject.node, policy)
                    + arms
                        .iter()
                        .map(|arm| sites(&arm.body.node, policy))
                        .sum::<usize>()
            }
        }
    }

    *acc += sites(expr, policy);
}
/// Sums the allocation sites of every function definition in `items`.
///
/// Top-level items that are not function definitions contribute nothing.
pub fn count_alloc_sites_in_program<P: AllocPolicy>(
    items: &[crate::ast::TopLevel],
    policy: &P,
) -> usize {
    let mut total = 0;
    for item in items {
        if let crate::ast::TopLevel::FnDef(fd) = item {
            total += count_alloc_sites_in_fn(fd, policy);
        }
    }
    total
}
/// Computes, for every function in `fns`, whether it is considered allocating.
///
/// A function with a non-empty effect list starts out flagged. The remaining
/// functions are re-examined with `body_allocates` against the current flags
/// until a full pass flips nothing, so flags propagate through calls between
/// the given functions. Flags only ever move from `false` to `true`.
///
/// Returns a map from function name to its flag.
pub fn compute_alloc_info<P: AllocPolicy>(fns: &[&FnDef], policy: &P) -> HashMap<String, bool> {
    // Seed: effectful functions are flagged up front.
    let mut info: HashMap<String, bool> = HashMap::new();
    for fd in fns {
        info.insert(fd.name.clone(), !fd.effects.is_empty());
    }

    // Iterate to a fixpoint.
    let mut grew = true;
    while grew {
        grew = false;
        for fd in fns {
            let already_flagged = info.get(&fd.name).copied().unwrap_or(false);
            if already_flagged {
                continue;
            }
            if body_allocates(&fd.body, &info, policy) {
                info.insert(fd.name.clone(), true);
                grew = true;
            }
        }
    }
    info
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{BinOp, FnDef, Literal, Spanned};
    use std::sync::Arc;

    /// Policy stub: `String.fromInt` and everything under `Map.` allocate;
    /// constructors never do.
    struct TestPolicy;

    impl AllocPolicy for TestPolicy {
        fn builtin_allocates(&self, name: &str) -> bool {
            name == "String.fromInt" || name.starts_with("Map.")
        }

        fn constructor_allocates(&self, _name: &str, _has_payload: bool) -> bool {
            false
        }
    }

    /// Wraps `value` in a `Spanned` built with line 1.
    fn sp<T>(value: T) -> Spanned<T> {
        Spanned::new(value, 1)
    }

    /// Integer literal expression.
    fn lit_int(n: i64) -> Spanned<Expr> {
        sp(Expr::Literal(Literal::Int(n)))
    }

    /// Effect-free, parameterless function whose body is the single
    /// expression `body`.
    fn fn_def_pure(name: &str, body: Expr) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: vec![],
            return_type: "Int".into(),
            effects: vec![],
            desc: None,
            body: Arc::new(FnBody::from_expr(sp(body))),
            resolution: None,
        }
    }

    /// Call expression whose callee is the bare identifier `callee`.
    fn call_ident(callee: &'static str, args: Vec<Spanned<Expr>>) -> Expr {
        Expr::FnCall(Box::new(sp(Expr::Ident(callee.into()))), args)
    }

    /// Call expression whose callee is the attribute access `namespace.member`.
    fn call_member(
        namespace: &'static str,
        member: &'static str,
        args: Vec<Spanned<Expr>>,
    ) -> Expr {
        let target = Expr::Attr(Box::new(sp(Expr::Ident(namespace.into()))), member.into());
        Expr::FnCall(Box::new(sp(target)), args)
    }

    /// Runs the analysis over `fns` and returns the flag recorded for `name`.
    fn alloc_flag(fns: &[&FnDef], name: &str) -> Option<bool> {
        compute_alloc_info(fns, &TestPolicy).get(name).copied()
    }

    #[test]
    fn pure_arithmetic_does_not_allocate() {
        let sum = Expr::BinOp(BinOp::Add, Box::new(lit_int(1)), Box::new(lit_int(2)));
        let fd = fn_def_pure("addOne", sum);
        assert_eq!(alloc_flag(&[&fd], "addOne"), Some(false));
    }

    #[test]
    fn list_literal_allocates() {
        let fd = fn_def_pure("makeList", Expr::List(vec![lit_int(1), lit_int(2)]));
        assert_eq!(alloc_flag(&[&fd], "makeList"), Some(true));
    }

    #[test]
    fn allocating_builtin_call_allocates() {
        let fd = fn_def_pure(
            "stringify",
            call_member("String", "fromInt", vec![lit_int(42)]),
        );
        assert_eq!(alloc_flag(&[&fd], "stringify"), Some(true));
    }

    #[test]
    fn pure_builtin_call_does_not_allocate() {
        let fd = fn_def_pure("absVal", call_member("Int", "abs", vec![lit_int(-5)]));
        assert_eq!(alloc_flag(&[&fd], "absVal"), Some(false));
    }

    #[test]
    fn effects_force_allocating() {
        let mut fd = fn_def_pure("logIt", Expr::Literal(Literal::Int(0)));
        fd.effects = vec![sp("Console.print".into())];
        assert_eq!(alloc_flag(&[&fd], "logIt"), Some(true));
    }

    #[test]
    fn transitive_user_call_propagates() {
        let inner = fn_def_pure("makeListInner", Expr::List(vec![lit_int(1)]));
        let wrapper = fn_def_pure("wrapperFn", call_ident("makeListInner", vec![]));
        let info = compute_alloc_info(&[&inner, &wrapper], &TestPolicy);
        assert_eq!(info.get("makeListInner"), Some(&true));
        assert_eq!(info.get("wrapperFn"), Some(&true));
    }

    #[test]
    fn mutual_recursion_pure_stays_pure() {
        let f = fn_def_pure("f", call_ident("g", vec![lit_int(1)]));
        let g = fn_def_pure("g", call_ident("f", vec![lit_int(2)]));
        let info = compute_alloc_info(&[&f, &g], &TestPolicy);
        assert_eq!(info.get("f"), Some(&false));
        assert_eq!(info.get("g"), Some(&false));
    }

    #[test]
    fn mutual_recursion_one_allocates_taints_the_group() {
        let f = fn_def_pure("f", call_ident("g", vec![lit_int(1)]));
        let g = fn_def_pure("g", Expr::List(vec![lit_int(0)]));
        let info = compute_alloc_info(&[&f, &g], &TestPolicy);
        assert_eq!(info.get("f"), Some(&true));
        assert_eq!(info.get("g"), Some(&true));
    }
}