use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use super::escape::{ClosureInfo, build_defn_map, build_fn_registry};
use crate::{Block, BlockId, Inst, IrFunction, Terminator, VarId};
/// Largest callee size, counted as instructions plus phis over all blocks,
/// that `is_eligible` still accepts for inlining.
const INLINE_THRESHOLD: usize = 20;
/// Upper bound on the number of `inline_one_round` passes that `inline_fn`
/// runs over a single function.
const MAX_ROUNDS: usize = 8;
/// Returns the combined length of the `insts` and `phis` lists of every
/// block in `ir`.
fn instruction_count(ir: &IrFunction) -> usize {
    let mut total = 0;
    for block in ir.blocks.iter() {
        total += block.insts.len() + block.phis.len();
    }
    total
}
/// Reports whether any block of `ir` holds an `Inst::LoadLocal`, looking at
/// both the instruction list and the phi list of each block.
fn has_load_local(ir: &IrFunction) -> bool {
    for block in ir.blocks.iter() {
        let everything = block.insts.iter().chain(block.phis.iter());
        for inst in everything {
            if matches!(inst, Inst::LoadLocal(..)) {
                return true;
            }
        }
    }
    false
}
/// Decides whether `callee` may be inlined.
///
/// A callee qualifies only when all of the following hold:
/// - it has a name, and that name is not in `forbidden`;
/// - it has no subfunctions;
/// - it contains no `Inst::LoadLocal`;
/// - it has at least one block;
/// - its instruction count does not exceed `INLINE_THRESHOLD`.
fn is_eligible(callee: &IrFunction, forbidden: &HashSet<Arc<str>>) -> bool {
    // Anonymous functions are never inlined.
    let Some(name) = callee.name.as_ref() else {
        return false;
    };
    if forbidden.contains(name) {
        return false;
    }
    if !callee.subfunctions.is_empty() || has_load_local(callee) || callee.blocks.is_empty() {
        return false;
    }
    instruction_count(callee) <= INLINE_THRESHOLD
}
/// Renames a variable: returns the entry for `v` in `map`, or `v` itself
/// when `map` has no entry for it.
fn rv(map: &HashMap<VarId, VarId>, v: VarId) -> VarId {
    match map.get(&v) {
        Some(&renamed) => renamed,
        None => v,
    }
}
/// Renames a block: returns the entry for `b` in `map`, or `b` itself when
/// `map` has no entry for it.
fn rb(map: &HashMap<BlockId, BlockId>, b: BlockId) -> BlockId {
    match map.get(&b) {
        Some(&renamed) => renamed,
        None => b,
    }
}
fn clone_inst(
inst: &Inst,
var_map: &HashMap<VarId, VarId>,
block_map: &HashMap<BlockId, BlockId>,
) -> Inst {
match inst {
Inst::Const(dst, c) => Inst::Const(rv(var_map, *dst), c.clone()),
Inst::LoadLocal(dst, name) => Inst::LoadLocal(rv(var_map, *dst), name.clone()),
Inst::LoadGlobal(dst, ns, name) => {
Inst::LoadGlobal(rv(var_map, *dst), ns.clone(), name.clone())
}
Inst::LoadVar(dst, ns, name) => Inst::LoadVar(rv(var_map, *dst), ns.clone(), name.clone()),
Inst::AllocVector(dst, elems) => Inst::AllocVector(
rv(var_map, *dst),
elems.iter().map(|&v| rv(var_map, v)).collect(),
),
Inst::AllocMap(dst, pairs) => Inst::AllocMap(
rv(var_map, *dst),
pairs
.iter()
.map(|&(k, v)| (rv(var_map, k), rv(var_map, v)))
.collect(),
),
Inst::AllocSet(dst, elems) => Inst::AllocSet(
rv(var_map, *dst),
elems.iter().map(|&v| rv(var_map, v)).collect(),
),
Inst::AllocList(dst, elems) => Inst::AllocList(
rv(var_map, *dst),
elems.iter().map(|&v| rv(var_map, v)).collect(),
),
Inst::AllocCons(dst, h, t) => {
Inst::AllocCons(rv(var_map, *dst), rv(var_map, *h), rv(var_map, *t))
}
Inst::AllocClosure(dst, tmpl, captures) => Inst::AllocClosure(
rv(var_map, *dst),
tmpl.clone(),
captures.iter().map(|&v| rv(var_map, v)).collect(),
),
Inst::CallKnown(dst, func, args) => Inst::CallKnown(
rv(var_map, *dst),
func.clone(),
args.iter().map(|&v| rv(var_map, v)).collect(),
),
Inst::Call(dst, callee, args) => Inst::Call(
rv(var_map, *dst),
rv(var_map, *callee),
args.iter().map(|&v| rv(var_map, v)).collect(),
),
Inst::CallDirect(dst, name, args) => Inst::CallDirect(
rv(var_map, *dst),
name.clone(),
args.iter().map(|&v| rv(var_map, v)).collect(),
),
Inst::Deref(dst, src) => Inst::Deref(rv(var_map, *dst), rv(var_map, *src)),
Inst::DefVar(dst, ns, name, val) => Inst::DefVar(
rv(var_map, *dst),
ns.clone(),
name.clone(),
rv(var_map, *val),
),
Inst::SetBang(var, val) => Inst::SetBang(rv(var_map, *var), rv(var_map, *val)),
Inst::Throw(val) => Inst::Throw(rv(var_map, *val)),
Inst::Phi(dst, entries) => Inst::Phi(
rv(var_map, *dst),
entries
.iter()
.map(|&(bid, v)| (rb(block_map, bid), rv(var_map, v)))
.collect(),
),
Inst::Recur(args) => Inst::Recur(args.iter().map(|&v| rv(var_map, v)).collect()),
Inst::SourceLoc(span) => Inst::SourceLoc(span.clone()),
Inst::RegionStart(dst) => Inst::RegionStart(rv(var_map, *dst)),
Inst::RegionAlloc(dst, region, kind, ops) => Inst::RegionAlloc(
rv(var_map, *dst),
rv(var_map, *region),
*kind,
ops.iter().map(|&v| rv(var_map, v)).collect(),
),
Inst::RegionEnd(region) => Inst::RegionEnd(rv(var_map, *region)),
Inst::RegionParam(dst) => Inst::RegionParam(rv(var_map, *dst)),
Inst::CallWithRegion(dst, name, args) => Inst::CallWithRegion(
rv(var_map, *dst),
name.clone(),
args.iter().map(|&v| rv(var_map, v)).collect(),
),
}
}
/// Builds a renamed copy of `term` for a block that is being spliced into
/// another function.
///
/// `Return` is replaced by a jump to `cont_block`; the returned variable is
/// not carried by the new terminator. All other terminators keep their shape,
/// with block targets passed through `block_map` and variables through
/// `var_map` (ids without a map entry stay unchanged).
fn clone_terminator(
    term: &Terminator,
    var_map: &HashMap<VarId, VarId>,
    block_map: &HashMap<BlockId, BlockId>,
    cont_block: BlockId,
) -> Terminator {
    let blk = |b: &BlockId| rb(block_map, *b);

    match term {
        Terminator::Unreachable => Terminator::Unreachable,
        // A return from the copied code continues at the continuation block.
        Terminator::Return(_) => Terminator::Jump(cont_block),
        Terminator::Jump(dest) => Terminator::Jump(blk(dest)),
        Terminator::RecurJump { target, args } => {
            let renamed_args = args.iter().map(|v| rv(var_map, *v)).collect();
            Terminator::RecurJump {
                target: blk(target),
                args: renamed_args,
            }
        }
        Terminator::Branch {
            cond,
            then_block,
            else_block,
        } => {
            let cond = rv(var_map, *cond);
            let then_block = blk(then_block);
            let else_block = blk(else_block);
            Terminator::Branch {
                cond,
                then_block,
                else_block,
            }
        }
    }
}
/// Tries to find the registered function that a call through `callee_var`
/// with `arg_count` arguments would reach.
///
/// Resolution succeeds only when:
/// 1. `var_defs` records `callee_var` as defined by an `Inst::LoadGlobal`;
/// 2. `defn_map` has an entry for that `(ns, name)` pair;
/// 3. that entry lists an arity whose parameter count equals `arg_count`
///    and which is not variadic (the first such arity is taken);
/// 4. `registry` contains a function under that arity's name.
///
/// Returns the arity's function name together with the registered function,
/// or `None` if any step fails.
fn resolve<'r>(
    callee_var: VarId,
    arg_count: usize,
    var_defs: &HashMap<VarId, &Inst>,
    defn_map: &HashMap<(Arc<str>, Arc<str>), ClosureInfo>,
    registry: &'r HashMap<Arc<str>, IrFunction>,
) -> Option<(Arc<str>, &'r IrFunction)> {
    let Inst::LoadGlobal(_, ns, name) = var_defs.get(&callee_var)? else {
        return None;
    };
    let info = defn_map.get(&(ns.clone(), name.clone()))?;

    let arities = info
        .arity_fn_names
        .iter()
        .zip(&info.param_counts)
        .zip(&info.is_variadic);
    for ((candidate, &param_count), &variadic) in arities {
        if param_count == arg_count && !variadic {
            // Only the first matching arity is considered; if it is not
            // registered, resolution fails rather than trying later arities.
            let callee = registry.get(candidate)?;
            return Some((candidate.clone(), callee));
        }
    }
    None
}
/// Replaces the call at `caller.blocks[block_idx].insts[inst_idx]` with a
/// renamed copy of `callee`'s body and returns the rewritten caller.
///
/// The call block is split at the call:
/// - the original block keeps its phis and the instructions before the call,
///   and now jumps to the copied callee entry block;
/// - every callee block is copied with fresh block ids and fresh variable ids
///   for each instruction destination;
/// - a new continuation block receives the instructions after the call and
///   the original terminator. If the callee has any `Return`, the
///   continuation block starts with a phi that defines `call_dst` from the
///   returned values.
///
/// Parameters are bound by renaming: when the callee has exactly one more
/// parameter than there are `args`, its first parameter is mapped to
/// `callee_self` and the rest to `args`; otherwise parameters are mapped to
/// `args` positionally.
///
/// Because the original terminator moves to the continuation block, every
/// phi entry in the caller that named the original block as predecessor is
/// retargeted to the continuation block.
///
/// # Panics
///
/// Panics if `block_idx` or `inst_idx` is out of range, if `callee` has no
/// blocks, or if the callee has more positional parameters than `args`
/// provides.
fn do_inline(
    mut caller: IrFunction,
    block_idx: usize,
    inst_idx: usize,
    callee: &IrFunction,
    callee_self: VarId,
    args: Vec<VarId>,
    call_dst: VarId,
) -> IrFunction {
    let mut var_map: HashMap<VarId, VarId> = HashMap::new();
    let mut block_map: HashMap<BlockId, BlockId> = HashMap::new();

    // Bind callee parameters to the caller's values. An extra leading
    // parameter is treated as the callee's "self" slot.
    let user_params = if callee.params.len() == args.len() + 1 {
        let (self_param, rest) = callee.params.split_first().expect("non-empty");
        var_map.insert(self_param.1, callee_self);
        rest
    } else {
        &callee.params[..]
    };
    for (i, (_, param_var)) in user_params.iter().enumerate() {
        var_map.insert(*param_var, args[i]);
    }

    // Give every value defined inside the callee a fresh id in the caller.
    // `or_insert_with` leaves already-bound parameters untouched.
    for block in &callee.blocks {
        for inst in block.phis.iter().chain(block.insts.iter()) {
            if let Some(dst) = inst.dst() {
                var_map.entry(dst).or_insert_with(|| {
                    let fresh = VarId(caller.next_var);
                    caller.next_var += 1;
                    fresh
                });
            }
        }
    }

    // Fresh block ids for every copied block, plus one for the continuation.
    for block in &callee.blocks {
        let fresh = BlockId(caller.next_block);
        caller.next_block += 1;
        block_map.insert(block.id, fresh);
    }
    let cont_id = BlockId(caller.next_block);
    caller.next_block += 1;

    // (copied returning block, renamed return value) pairs feed the phi that
    // defines the call result in the continuation block.
    let return_sites: Vec<(BlockId, VarId)> = callee
        .blocks
        .iter()
        .filter_map(|b| {
            if let Terminator::Return(ret_var) = &b.terminator {
                Some((block_map[&b.id], rv(&var_map, *ret_var)))
            } else {
                None
            }
        })
        .collect();

    let cloned_blocks: Vec<Block> = callee
        .blocks
        .iter()
        .map(|block| Block {
            id: block_map[&block.id],
            phis: block
                .phis
                .iter()
                .map(|i| clone_inst(i, &var_map, &block_map))
                .collect(),
            insts: block
                .insts
                .iter()
                .map(|i| clone_inst(i, &var_map, &block_map))
                .collect(),
            terminator: clone_terminator(&block.terminator, &var_map, &block_map, cont_id),
        })
        .collect();

    // Split the call block in place: the tail after the call and the old
    // terminator are moved out, the call itself is dropped, and the block now
    // falls into the copied callee entry.
    let callee_entry = block_map[&callee.blocks[0].id];
    let orig_block = &mut caller.blocks[block_idx];
    let orig_id = orig_block.id;
    let post_insts: Vec<Inst> = orig_block.insts.split_off(inst_idx + 1);
    orig_block.insts.truncate(inst_idx);
    let post_term = std::mem::replace(
        &mut orig_block.terminator,
        Terminator::Jump(callee_entry),
    );

    // The edges that used to leave the original block now leave the
    // continuation block, so phis that named the original block as a
    // predecessor must name the continuation block instead. Only pre-existing
    // caller blocks are scanned; copied blocks use fresh ids and cannot match.
    for block in &mut caller.blocks {
        for phi in &mut block.phis {
            if let Inst::Phi(_, entries) = phi {
                for (pred, _) in entries.iter_mut() {
                    if *pred == orig_id {
                        *pred = cont_id;
                    }
                }
            }
        }
    }

    let cont_block = Block {
        id: cont_id,
        // A callee without any `Return` never reaches the continuation, so
        // there is nothing to merge into `call_dst`.
        phis: if return_sites.is_empty() {
            vec![]
        } else {
            vec![Inst::Phi(call_dst, return_sites)]
        },
        insts: post_insts,
        terminator: post_term,
    };
    caller.blocks.extend(cloned_blocks);
    caller.blocks.push(cont_block);
    caller
}
/// Performs one inlining sweep over `func` and reports whether anything was
/// inlined.
///
/// Variable definitions are snapshotted once at the start of the sweep. Each
/// block is then visited in index order, and the first `Inst::Call` in it
/// that resolves (via `resolve`) to an eligible callee (via `is_eligible`)
/// is inlined with `do_inline`. Blocks appended by inlining are visited in
/// the same sweep because the loop bound is re-read every iteration; calls
/// whose callee variable was defined by freshly copied code are absent from
/// the snapshot and therefore do not resolve during this sweep.
///
/// Returns the (possibly rewritten) function and `true` if at least one call
/// was inlined.
fn inline_one_round(
    mut func: IrFunction,
    registry: &HashMap<Arc<str>, IrFunction>,
    defn_map: &HashMap<(Arc<str>, Arc<str>), ClosureInfo>,
    forbidden: &HashSet<Arc<str>>,
) -> (IrFunction, bool) {
    // Owned snapshot: `func` is replaced while the sweep runs, so the
    // definitions cannot be borrowed from it.
    let var_defs_owned: HashMap<VarId, Inst> = func
        .blocks
        .iter()
        .flat_map(|b| b.phis.iter().chain(b.insts.iter()))
        .filter_map(|i| i.dst().map(|dst| (dst, i.clone())))
        .collect();
    let var_defs: HashMap<VarId, &Inst> = var_defs_owned.iter().map(|(k, v)| (*k, v)).collect();

    let mut changed = false;
    let mut block_idx = 0;
    while block_idx < func.blocks.len() {
        let found = func.blocks[block_idx]
            .insts
            .iter()
            .enumerate()
            .find_map(|(inst_idx, inst)| {
                let Inst::Call(dst, callee_var, args) = inst else {
                    return None;
                };
                // The callee reference borrows from `registry`, not from
                // `func`, so it can outlive this scan without cloning.
                let (_, callee) =
                    resolve(*callee_var, args.len(), &var_defs, defn_map, registry)?;
                if !is_eligible(callee, forbidden) {
                    return None;
                }
                Some((inst_idx, callee, *callee_var, args.clone(), *dst))
            });
        if let Some((inst_idx, callee, callee_self, args, call_dst)) = found {
            func = do_inline(
                func,
                block_idx,
                inst_idx,
                callee,
                callee_self,
                args,
                call_dst,
            );
            changed = true;
        }
        block_idx += 1;
    }
    (func, changed)
}
/// Repeats `inline_one_round` on `func` until a round makes no change or
/// `MAX_ROUNDS` rounds have run, and returns the result.
fn inline_fn(
    func: IrFunction,
    registry: &HashMap<Arc<str>, IrFunction>,
    defn_map: &HashMap<(Arc<str>, Arc<str>), ClosureInfo>,
    forbidden: &HashSet<Arc<str>>,
) -> IrFunction {
    let mut current = func;
    let mut round = 0;
    while round < MAX_ROUNDS {
        let (next, changed) = inline_one_round(current, registry, defn_map, forbidden);
        current = next;
        if !changed {
            // Fixed point reached: another round would find nothing new.
            return current;
        }
        round += 1;
    }
    current
}
/// Applies inlining to `func` and, recursively, to all of its subfunctions.
///
/// Subfunctions are processed first and written back in place. `func` itself
/// is then processed with a forbidden set holding only its own name (empty
/// when it has none), so it is never inlined into itself.
fn inline_tree(
    mut func: IrFunction,
    registry: &HashMap<Arc<str>, IrFunction>,
    defn_map: &HashMap<(Arc<str>, Arc<str>), ClosureInfo>,
) -> IrFunction {
    // Children first; `take` moves them out so they can be consumed by value.
    func.subfunctions = std::mem::take(&mut func.subfunctions)
        .into_iter()
        .map(|child| inline_tree(child, registry, defn_map))
        .collect();

    let forbidden: HashSet<Arc<str>> = func.name.iter().cloned().collect();
    inline_fn(func, registry, defn_map, &forbidden)
}
/// Public entry point of the inlining pass.
///
/// Builds the function registry and the defn map from `ir_func` (via
/// `build_fn_registry` and `build_defn_map`), then runs `inline_tree` over
/// `ir_func` and its subfunctions, returning the rewritten function.
pub fn inline(ir_func: IrFunction) -> IrFunction {
    let registry = build_fn_registry(&ir_func);
    let defn_map = build_defn_map(&ir_func);
    inline_tree(ir_func, &registry, &defn_map)
}
/// Test-only public wrapper around the private `instruction_count`.
#[cfg(test)]
pub fn count_insts(func: &IrFunction) -> usize {
instruction_count(func)
}
/// Test-only public wrapper around the private `has_load_local`.
#[cfg(test)]
pub fn check_has_load_local(func: &IrFunction) -> bool {
has_load_local(func)
}