use std::collections::{HashMap, HashSet};
use crate::ast::{FnBody, FnDef, TopLevel};
use crate::call_graph::{find_recursive_fns, tailcall_scc_components};
use crate::ir::{
AllocPolicy, BodyExprPlan, BodyPlan, CallLowerCtx, ThinKind, classify_thin_fn_def,
};
/// Allocation policy that decides purely from builtin names and constructor payloads.
///
/// Builtins are checked against a fixed table of names that are treated as
/// non-allocating:
/// - the listed `Int.*` / `Float.*` numeric builtins;
/// - `Char.toCode`;
/// - length/size and query builtins on strings and collections;
/// - the `Bool.*` operators.
///
/// Any builtin name not in that table is reported as allocating. A constructor
/// allocates exactly when it carries a payload; its name is ignored.
pub struct NeutralAllocPolicy;

impl AllocPolicy for NeutralAllocPolicy {
    fn builtin_allocates(&self, name: &str) -> bool {
        // Builtins known not to allocate; every other name is reported as allocating.
        const NON_ALLOCATING: &[&str] = &[
            "Int.abs",
            "Int.min",
            "Int.max",
            "Float.fromInt",
            "Float.abs",
            "Float.floor",
            "Float.ceil",
            "Float.round",
            "Float.min",
            "Float.max",
            "Float.sin",
            "Float.cos",
            "Float.sqrt",
            "Float.pow",
            "Float.atan2",
            "Float.pi",
            "Char.toCode",
            "String.len",
            "String.byteLength",
            "String.startsWith",
            "String.endsWith",
            "String.contains",
            "List.len",
            "List.contains",
            "Vector.len",
            "Map.size",
            "Map.contains",
            "Set.size",
            "Set.contains",
            "Bool.and",
            "Bool.or",
            "Bool.not",
        ];
        let known_non_allocating = NON_ALLOCATING.contains(&name);
        !known_non_allocating
    }

    fn constructor_allocates(&self, _name: &str, has_payload: bool) -> bool {
        // Only payload-carrying constructors are treated as allocating; the name is irrelevant.
        has_payload
    }
}
/// Result of [`analyze`]: per-function facts plus the program-wide sets that
/// some of those facts were derived from.
///
/// The `Default` value is an empty result (no functions, empty sets).
#[derive(Debug, Clone, Default)]
pub struct AnalysisResult {
/// One entry per top-level function definition, keyed by function name.
/// If two definitions share a name, the later one's entry is the one kept.
pub fn_analyses: HashMap<String, FnAnalysis>,
/// Names of functions that sit in a tail-call strongly connected component with at
/// least two members. `main` is excluded from this grouping, so it never appears here.
pub mutual_tco_members: HashSet<String>,
/// Recursive function names, exactly as returned by `find_recursive_fns`.
pub recursive_fns: HashSet<String>,
}
/// Facts gathered by [`analyze`] for a single top-level function definition.
#[derive(Debug, Clone)]
pub struct FnAnalysis {
/// Allocation verdict taken from `crate::ir::compute_alloc_info`.
/// It is `None` in two cases:
/// - `analyze` was called without an allocation policy;
/// - the allocation analysis produced no entry for this function.
/// NOTE(review): presumably `Some(true)` means "may allocate"; the exact meaning
/// is defined by `compute_alloc_info`, so confirm there.
pub allocates: Option<bool>,
/// `kind` of the thin-function plan from `classify_thin_fn_def`, or `None` when the
/// function could not be classified.
pub thin_kind: Option<ThinKind>,
/// Coarse shape of the function body; see [`BodyShape`].
pub body_shape: BodyShape,
/// `local_count` from the function's resolution data, or `None` when the definition
/// carries no resolution (presumably because name resolution has not run — confirm).
pub local_count: Option<u16>,
/// True when the function is in a tail-call SCC with two or more members
/// (`main` is never included).
pub mutual_tco_member: bool,
/// True when `find_recursive_fns` reports this function as recursive.
pub recursive: bool,
/// Count reported by `crate::call_graph::recursive_callsite_counts` for this function,
/// or 0 when it has no entry. NOTE(review): presumably the number of recursive call
/// sites in the body — confirm against that function.
pub recursive_call_count: usize,
}
/// Coarse classification of a function body, derived in [`analyze`] from the
/// result of `classify_thin_fn_def`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BodyShape {
/// Classified body that is a single expression whose plan is `BodyExprPlan::Leaf`.
LeafExpr,
/// Classified body that is a single, non-leaf expression.
SingleExpr,
/// Classified block body; the payload is the number of statements in the plan.
Block(usize),
/// `classify_thin_fn_def` returned `None`. The payload is the number of statements
/// in the function's source body (`FnBody::Block`).
Unclassified(usize),
}
/// Runs the per-function analyses over every top-level function definition in `items`.
///
/// For each `TopLevel::FnDef` this records:
/// - its allocation verdict from `crate::ir::compute_alloc_info`. This is computed only
///   when `alloc_policy` is `Some`; otherwise every `allocates` is `None`.
/// - its thin-function classification and the resulting [`BodyShape`], via
///   `classify_thin_fn_def` with `ctx`.
/// - its resolved local count, when resolution data is present.
/// - whether it belongs to a mutually tail-calling group, i.e. a tail-call SCC of two
///   or more functions. `main` is excluded from the grouping.
/// - recursion information from the call graph.
///
/// The returned map is keyed by function name. If several definitions share a name,
/// the later definition's analysis replaces the earlier one's.
pub fn analyze(
    items: &[TopLevel],
    alloc_policy: Option<&dyn AllocPolicy>,
    ctx: &impl CallLowerCtx,
) -> AnalysisResult {
    let fn_defs: Vec<&FnDef> = items
        .iter()
        .filter_map(|item| {
            if let TopLevel::FnDef(def) = item {
                Some(def)
            } else {
                None
            }
        })
        .collect();

    // Forwarding adapter that gives the borrowed trait object a concrete type before
    // it is handed to `compute_alloc_info`.
    // NOTE(review): presumably that function takes a sized generic policy, which
    // `dyn AllocPolicy` cannot satisfy directly — confirm.
    struct DynPolicy<'p>(&'p dyn AllocPolicy);
    impl AllocPolicy for DynPolicy<'_> {
        fn builtin_allocates(&self, name: &str) -> bool {
            self.0.builtin_allocates(name)
        }
        fn constructor_allocates(&self, name: &str, has_payload: bool) -> bool {
            self.0.constructor_allocates(name, has_payload)
        }
    }
    let alloc_info = alloc_policy
        .map(|policy| crate::ir::compute_alloc_info(&fn_defs, &DynPolicy(policy)));

    // `main` is left out of the tail-call grouping, so it is never a mutual-TCO member.
    let scc_candidates: Vec<&FnDef> = fn_defs
        .iter()
        .copied()
        .filter(|def| def.name != "main")
        .collect();
    let mutual_tco_set: HashSet<String> = tailcall_scc_components(&scc_candidates)
        .into_iter()
        .filter(|group| group.len() >= 2)
        .flatten()
        .map(|def| def.name.clone())
        .collect();

    let recursive_set = find_recursive_fns(items);
    let recursive_call_counts = crate::call_graph::recursive_callsite_counts(items);

    let fn_analyses: HashMap<String, FnAnalysis> = fn_defs
        .iter()
        .map(|def| {
            let plan = classify_thin_fn_def(def, ctx);
            let body_shape = if let Some(p) = &plan {
                match &p.body {
                    BodyPlan::SingleExpr(BodyExprPlan::Leaf(_)) => BodyShape::LeafExpr,
                    BodyPlan::SingleExpr(_) => BodyShape::SingleExpr,
                    BodyPlan::Block { stmts, .. } => BodyShape::Block(stmts.len()),
                }
            } else {
                // Not thin-classifiable: fall back to the source body's statement count.
                let FnBody::Block(stmts) = def.body.as_ref();
                BodyShape::Unclassified(stmts.len())
            };
            let analysis = FnAnalysis {
                allocates: alloc_info
                    .as_ref()
                    .and_then(|info| info.get(&def.name).copied()),
                thin_kind: plan.as_ref().map(|p| p.kind),
                body_shape,
                local_count: def.resolution.as_ref().map(|r| r.local_count),
                mutual_tco_member: mutual_tco_set.contains(&def.name),
                recursive: recursive_set.contains(&def.name),
                recursive_call_count: recursive_call_counts
                    .get(&def.name)
                    .copied()
                    .unwrap_or(0),
            };
            (def.name.clone(), analysis)
        })
        .collect();

    AnalysisResult {
        fn_analyses,
        mutual_tco_members: mutual_tco_set,
        recursive_fns: recursive_set,
    }
}