use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::{Inst, IrFunction, RegionAllocKind, VarId};
use super::escape::{
ClosureInfo, EscapeContext, EscapeMode, EscapeState, build_use_chains, build_var_defs,
classify_escape_with_ctx, collect_allocs,
};
use super::optimize::{
blocks_on_path, collect_use_blocks, dominators, has_back_edge, lca_of_many, post_dominators,
region_contains_throw,
};
/// A call site selected for cross-function region promotion, together with
/// what was learned about its callee while scanning.
struct Candidate {
/// Destination variable of the `Inst::Call` under consideration.
dst: VarId,
/// Name of the arity-specific callee function resolved for this call.
callee_fn_name: Arc<str>,
/// Variables in the callee that escape analysis classified as `Returns`.
returns_allocs: HashSet<VarId>,
/// Callee parameter count minus the call's argument count; candidates
/// recorded by `collect_candidates_in` only ever carry 0 or 1 here.
capture_count: usize,
}
/// Resolves the variable being called to the name of a concrete,
/// fixed-arity function.
///
/// `callee_var` must be defined by a `LoadGlobal`, or by a `Deref` whose
/// source is itself defined by a `LoadGlobal` or `LoadVar`, and the
/// referenced `(ns, name)` pair must be present in `defn_map`.
///
/// Returns the name of the first non-variadic arity whose parameter count
/// equals `arg_count`, or `None` if any step of the resolution fails.
fn resolve_callee_name(
    callee_var: VarId,
    arg_count: usize,
    var_defs: &HashMap<VarId, &Inst>,
    defn_map: &HashMap<(Arc<str>, Arc<str>), ClosureInfo>,
) -> Option<Arc<str>> {
    // First work out which global the callee refers to...
    let (ns, name) = match var_defs.get(&callee_var)? {
        Inst::LoadGlobal(_, ns, name) => (ns, name),
        Inst::Deref(_, src) => match var_defs.get(src)? {
            Inst::LoadGlobal(_, ns, name) | Inst::LoadVar(_, ns, name) => (ns, name),
            _ => return None,
        },
        _ => return None,
    };
    // ...then look up its closure metadata once.
    let info: &ClosureInfo = defn_map.get(&(ns.clone(), name.clone()))?;

    // `param_counts`, `is_variadic` and `arity_fn_names` are indexed in
    // lock-step, one entry per arity.
    info.param_counts
        .iter()
        .enumerate()
        .find(|&(idx, &params)| params == arg_count && !info.is_variadic[idx])
        .map(|(idx, _)| info.arity_fn_names[idx].clone())
}
/// Collects every allocation in `callee` whose escape state is `Returns`.
///
/// Each variable reported by `collect_allocs` is classified with
/// `classify_escape_with_ctx` in `EscapeMode::Alloc`; only those classified
/// as `EscapeState::Returns` are kept.
fn returns_allocs_of(callee: &IrFunction, ctx: &EscapeContext) -> HashSet<VarId> {
    let alloc_sites = collect_allocs(callee);
    let use_chains = build_use_chains(callee);
    let defs = build_var_defs(callee);

    let mut returned = HashSet::new();
    for &alloc in alloc_sites.keys() {
        let verdict = classify_escape_with_ctx(
            alloc,
            &use_chains,
            callee,
            Some(ctx),
            Some(&defs),
            EscapeMode::Alloc,
        );
        if verdict == EscapeState::Returns {
            returned.insert(alloc);
        }
    }
    returned
}
/// Maps a heap-allocating instruction to the matching `RegionAllocKind`.
///
/// Returns `None` for every instruction that is not one of the five
/// collection allocations handled here.
fn alloc_to_region_kind(inst: &Inst) -> Option<RegionAllocKind> {
    let kind = match inst {
        Inst::AllocVector(..) => RegionAllocKind::Vector,
        Inst::AllocMap(..) => RegionAllocKind::Map,
        Inst::AllocSet(..) => RegionAllocKind::Set,
        Inst::AllocList(..) => RegionAllocKind::List,
        Inst::AllocCons(..) => RegionAllocKind::Cons,
        _ => return None,
    };
    Some(kind)
}
/// Returns the operand variables of an allocation instruction as a flat list.
///
/// Vector, set and list allocations yield their elements in order; a map
/// allocation yields `key, value, key, value, ...`; a cons yields
/// `[head, tail]`. Any other instruction yields an empty vector.
fn alloc_operands(inst: &Inst) -> Vec<VarId> {
    match inst {
        Inst::AllocCons(_, head, tail) => vec![*head, *tail],
        Inst::AllocMap(_, pairs) => {
            let mut flat = Vec::new();
            for &(key, value) in pairs.iter() {
                flat.push(key);
                flat.push(value);
            }
            flat
        }
        Inst::AllocVector(_, elems) | Inst::AllocSet(_, elems) | Inst::AllocList(_, elems) => {
            elems.clone()
        }
        _ => Vec::new(),
    }
}
/// Appends `suffix` to the name of every named subfunction nested anywhere
/// below `func`, and updates `AllocClosure` templates that refer to those
/// names so they point at the renamed functions.
///
/// `func`'s own name is left alone. Does nothing when there are no named
/// subfunctions.
fn rename_inner_subfunctions(func: &mut IrFunction, suffix: &str) {
    /// Records `old name -> old name + suffix` for every named subfunction
    /// reachable from `parent`.
    fn gather_renames(
        parent: &IrFunction,
        suffix: &str,
        renames: &mut HashMap<Arc<str>, Arc<str>>,
    ) {
        for child in parent.subfunctions.iter() {
            if let Some(name) = child.name.as_ref() {
                renames.insert(
                    name.clone(),
                    Arc::from(format!("{name}{suffix}").as_str()),
                );
            }
            gather_renames(child, suffix, renames);
        }
    }

    /// Applies `renames` to closure templates in `target`'s phis and
    /// instructions, then to each subfunction's name and body, recursively.
    fn apply_renames(target: &mut IrFunction, renames: &HashMap<Arc<str>, Arc<str>>) {
        for block in target.blocks.iter_mut() {
            let every_inst = block.phis.iter_mut().chain(block.insts.iter_mut());
            for inst in every_inst {
                let Inst::AllocClosure(_, tmpl, _) = inst else {
                    continue;
                };
                for fn_name in &mut tmpl.arity_fn_names {
                    if let Some(renamed) = renames.get(fn_name) {
                        *fn_name = renamed.clone();
                    }
                }
            }
        }
        for child in target.subfunctions.iter_mut() {
            let renamed = child
                .name
                .as_ref()
                .and_then(|old| renames.get(old))
                .cloned();
            if let Some(new_name) = renamed {
                child.name = Some(new_name);
            }
            apply_renames(child, renames);
        }
    }

    let mut renames: HashMap<Arc<str>, Arc<str>> = HashMap::new();
    gather_renames(func, suffix, &mut renames);
    if !renames.is_empty() {
        apply_renames(func, &renames);
    }
}
/// Builds a region-aware copy of `original`.
///
/// Every instruction in the copy whose destination is in `targets` and for
/// which `alloc_to_region_kind` returns a kind is replaced by an
/// `Inst::RegionAlloc` on a freshly numbered region variable. That variable
/// is introduced by an `Inst::RegionParam` inserted at the very start of the
/// first block. Nested subfunctions are renamed with `suffix`, and the copy
/// itself is named `<original name><suffix>` (or `__cljrs_anon<suffix>` when
/// the original has no name).
///
/// Returns `None` when nothing was promoted or the function has no blocks.
fn specialize(original: &IrFunction, targets: &HashSet<VarId>, suffix: &str) -> Option<IrFunction> {
    let mut specialised = original.clone();
    let region_var = VarId(specialised.next_var);
    specialised.next_var += 1;

    let mut promoted_any = false;
    let all_insts = specialised
        .blocks
        .iter_mut()
        .flat_map(|block| block.insts.iter_mut());
    for inst in all_insts {
        let dst = match inst.dst() {
            Some(d) if targets.contains(&d) => d,
            _ => continue,
        };
        let Some(kind) = alloc_to_region_kind(inst) else {
            continue;
        };
        // Read the operands before the instruction is overwritten.
        let operands = alloc_operands(inst);
        *inst = Inst::RegionAlloc(dst, region_var, kind, operands);
        promoted_any = true;
    }
    if !promoted_any {
        return None;
    }

    let entry = specialised.blocks.first_mut()?;
    entry.insts.insert(0, Inst::RegionParam(region_var));

    rename_inner_subfunctions(&mut specialised, suffix);

    let renamed = match original.name.as_ref() {
        Some(n) => format!("{n}{suffix}"),
        None => format!("__cljrs_anon{suffix}"),
    };
    specialised.name = Some(Arc::from(renamed.as_str()));
    Some(specialised)
}
/// Locates the first `Inst::Call` in `func` whose destination is `dst`.
///
/// Returns `(block index, instruction index)` — positions within
/// `func.blocks` and that block's `insts` — or `None` if there is no such
/// call. Phi lists are not searched.
fn find_call_by_dst(func: &IrFunction, dst: VarId) -> Option<(usize, usize)> {
    func.blocks.iter().enumerate().find_map(|(b_idx, block)| {
        block
            .insts
            .iter()
            .position(|inst| matches!(inst, Inst::Call(d, _, _) if *d == dst))
            .map(|i_idx| (b_idx, i_idx))
    })
}
/// Rewrites the `Inst::Call` defining `dst` into an `Inst::CallWithRegion`
/// that targets `target_name`, and brackets the call together with every use
/// of `dst` between an `Inst::RegionStart` and an `Inst::RegionEnd` on a
/// freshly numbered region variable.
///
/// The region runs from `start_block` (chosen by `lca_of_many` over the
/// dominator data) to `end_block` (chosen the same way over the
/// post-dominator data), both computed from the call's block plus every
/// block that uses `dst`.
///
/// When `capture_count == 1` the callee variable itself is passed as the
/// first argument of the rewritten call (presumably the closure supplying
/// the captured value — confirm against the specialised callee's parameter
/// layout).
///
/// Returns `true` when the rewrite was applied. Returns `false`, leaving
/// every instruction untouched, when:
/// - no call defines `dst`;
/// - no common start or end block can be found;
/// - `start_block` is not in the dominator set of the call's block;
/// - the region contains a back edge or a throw;
/// - the start or end block id is not present in `func.blocks`.
fn rewrite_call_with_region_scope(
    func: &mut IrFunction,
    dst: VarId,
    target_name: Arc<str>,
    capture_count: usize,
) -> bool {
    let Some((block_idx, inst_idx)) = find_call_by_dst(func, dst) else {
        return false;
    };
    let alloc_block = func.blocks[block_idx].id;

    // The region has to cover the call itself as well as every use of its result.
    let uses = build_use_chains(func);
    let mut use_blocks = collect_use_blocks(dst, &uses, func);
    use_blocks.insert(alloc_block);

    let doms = dominators(func);
    let postdoms = post_dominators(func);
    let Some(start_block) = lca_of_many(&doms, use_blocks.iter().copied()) else {
        return false;
    };
    let Some(end_block) = lca_of_many(&postdoms, use_blocks.iter().copied()) else {
        return false;
    };

    // The region must be opened on every path that reaches the call.
    let start_dominates_call = doms
        .get(&alloc_block)
        .map(|d| d.contains(&start_block))
        .unwrap_or(false);
    if !start_dominates_call {
        return false;
    }

    let region_blocks = blocks_on_path(func, start_block, end_block);
    if has_back_edge(func, &region_blocks, &doms) {
        return false;
    }
    if region_contains_throw(func, &region_blocks) {
        return false;
    }

    // Resolve everything that can fail *before* mutating `func`, so a failure
    // can never leave a `CallWithRegion` without its start/end markers.
    let Some(start_idx) = func.blocks.iter().position(|b| b.id == start_block) else {
        return false;
    };
    let Some(end_idx) = func.blocks.iter().position(|b| b.id == end_block) else {
        return false;
    };
    let Inst::Call(call_dst, callee, args) = func.blocks[block_idx].insts[inst_idx].clone() else {
        return false;
    };
    debug_assert_eq!(call_dst, dst);

    let full_args: Vec<VarId> = if capture_count == 1 {
        let mut v = Vec::with_capacity(args.len() + 1);
        v.push(callee);
        v.extend(args);
        v
    } else {
        args
    };

    let region_var = VarId(func.next_var);
    func.next_var += 1;

    // Replace the call first: inserting `RegionStart` at index 0 may shift
    // `inst_idx` when the call lives in the start block.
    func.blocks[block_idx].insts[inst_idx] = Inst::CallWithRegion(dst, target_name, full_args);
    func.blocks[start_idx]
        .insts
        .insert(0, Inst::RegionStart(region_var));
    func.blocks[end_idx].insts.push(Inst::RegionEnd(region_var));
    true
}
/// A `Candidate` paired with the location of the function that contains it.
struct CandidateLoc {
/// Indices into successive `subfunctions` vectors, starting at the root
/// function; empty when the candidate lives in the root itself.
path: Vec<usize>,
/// The call site found in the function at `path`.
candidate: Candidate,
}
/// Scans `func` and, recursively, all of its subfunctions for calls that can
/// be redirected to a region-specialised callee, appending one
/// `CandidateLoc` per hit to `out`.
///
/// `path` is the list of subfunction indices leading from the root to
/// `func`. A call qualifies when:
/// - its result is classified `EscapeState::NoEscape` in `func`;
/// - its callee resolves (via `resolve_callee_name`) to a function present
///   in `ctx.registry`;
/// - that function has at most one parameter more than the call has
///   arguments;
/// - at least one allocation the callee returns is of a kind accepted by
///   `alloc_to_region_kind`.
///
/// Candidates of `func` are emitted before those of its subfunctions.
fn collect_candidates_in(
    func: &IrFunction,
    path: Vec<usize>,
    ctx: &EscapeContext,
    out: &mut Vec<CandidateLoc>,
) {
    let use_chains = build_use_chains(func);
    let defs = build_var_defs(func);

    for inst in func.blocks.iter().flat_map(|block| block.insts.iter()) {
        let (dst, callee, args) = match inst {
            Inst::Call(dst, callee, args) => (*dst, *callee, args),
            _ => continue,
        };

        // The call's result must stay local to this function.
        let escape = classify_escape_with_ctx(
            dst,
            &use_chains,
            func,
            Some(ctx),
            Some(&defs),
            EscapeMode::Alloc,
        );
        if escape != EscapeState::NoEscape {
            continue;
        }

        let callee_name = match resolve_callee_name(callee, args.len(), &defs, &ctx.defn_map) {
            Some(name) => name,
            None => continue,
        };
        let callee_fn = match ctx.registry.get(&callee_name) {
            Some(f) => f,
            None => continue,
        };

        // Parameters beyond the explicit arguments are captures; only zero
        // or one extra parameter is supported.
        let capture_count = match callee_fn.params.len().checked_sub(args.len()) {
            Some(extra) if extra <= 1 => extra,
            _ => continue,
        };

        let returns_allocs = returns_allocs_of(callee_fn, ctx);
        if returns_allocs.is_empty() {
            continue;
        }
        let any_promotable = callee_fn
            .blocks
            .iter()
            .flat_map(|block| block.insts.iter())
            .any(|i| {
                alloc_to_region_kind(i).is_some()
                    && matches!(i.dst(), Some(d) if returns_allocs.contains(&d))
            });
        if !any_promotable {
            continue;
        }

        out.push(CandidateLoc {
            path: path.clone(),
            candidate: Candidate {
                dst,
                callee_fn_name: callee_name,
                returns_allocs,
                capture_count,
            },
        });
    }

    for (idx, sub) in func.subfunctions.iter().enumerate() {
        let mut child_path = Vec::with_capacity(path.len() + 1);
        child_path.extend_from_slice(&path);
        child_path.push(idx);
        collect_candidates_in(sub, child_path, ctx, out);
    }
}
/// Follows `path` — one `subfunctions` index per nesting level — down from
/// `root` and returns a mutable reference to the function it ends at.
///
/// An empty path yields `root` itself. Panics if any index is out of range.
fn fn_at_path_mut<'a>(root: &'a mut IrFunction, path: &[usize]) -> &'a mut IrFunction {
    match path.split_first() {
        None => root,
        Some((&head, rest)) => fn_at_path_mut(&mut root.subfunctions[head], rest),
    }
}
/// Promotes allocations returned by callees into regions owned by the caller.
///
/// For every candidate call found by `collect_candidates_in`, a
/// region-specialised copy of the callee is created with `specialize` and
/// appended to the `subfunctions` of the function containing the call. The
/// call is then rewritten by `rewrite_call_with_region_scope` to target that
/// copy inside a region scope.
///
/// Specialised copies are cached per (containing function, callee name,
/// returned-allocation set), so several calls in the same function share one
/// copy. Copies are named with a `__rg<N>` suffix, where `N` counts
/// specialisation attempts across the whole pass.
///
/// Returns `root`, modified in place; it is returned unchanged when no
/// candidate is found.
pub fn promote_cross_fn_allocs(mut root: IrFunction, ctx: &EscapeContext) -> IrFunction {
    let mut candidates: Vec<CandidateLoc> = Vec::new();
    collect_candidates_in(&root, Vec::new(), ctx, &mut candidates);
    if candidates.is_empty() {
        return root;
    }

    type SpecialiseKey = (Vec<usize>, Arc<str>, Vec<u32>);
    let mut specialised: HashMap<SpecialiseKey, Arc<str>> = HashMap::new();
    let mut counter: usize = 0;

    for CandidateLoc { path, candidate } in candidates {
        // Sort the ids so the key does not depend on hash-set iteration order.
        let mut alloc_key: Vec<u32> = candidate.returns_allocs.iter().map(|v| v.0).collect();
        alloc_key.sort_unstable();
        let key = (path.clone(), candidate.callee_fn_name.clone(), alloc_key);

        let target_name = if let Some(existing) = specialised.get(&key) {
            existing.clone()
        } else {
            // Borrow the callee straight from the registry: `specialize`
            // makes its own copy, so cloning here would copy it twice.
            let Some(original) = ctx.registry.get(&candidate.callee_fn_name) else {
                continue;
            };
            counter += 1;
            let suffix = format!("__rg{counter}");
            let Some(clone) = specialize(original, &candidate.returns_allocs, &suffix) else {
                continue;
            };
            let new_name = clone.name.clone().expect("specialised has name");
            fn_at_path_mut(&mut root, &path).subfunctions.push(clone);
            specialised.insert(key, new_name.clone());
            new_name
        };

        // A failed rewrite is deliberately ignored: the call is left as it
        // was, and the specialised copy stays cached for later call sites
        // with the same key.
        let caller = fn_at_path_mut(&mut root, &path);
        let _ok = rewrite_call_with_region_scope(
            caller,
            candidate.dst,
            target_name,
            candidate.capture_count,
        );
    }
    root
}