use std::collections::BTreeMap;
use crate::{
analysis::{
CilTarget, DefSite, MethodRef, ReturnInfo, SsaFunction, SsaInstruction, SsaOp, SsaType,
SsaVarId, VariableOrigin,
},
compiler::{
pass::{PassCapability, SsaPass},
CompilerContext, EventKind, EventLog,
},
metadata::{tables::MemberRefSignature, token::Token, typesystem::CilTypeReference},
CilObject,
};
/// A call site selected for inlining by `InliningContext::find_candidates`.
#[derive(Debug, Clone)]
struct InlineCandidate {
    /// Index of the caller block that holds the call.
    block_idx: usize,
    /// Index of the call instruction within that block.
    instr_idx: usize,
    /// Callee token, already passed through `resolve_to_method_def`.
    callee_token: Token,
}
/// Per-caller state for one run of `InliningPass` over a single method.
struct InliningContext<'a> {
    /// Pass configuration (provides `inline_threshold`).
    pass: &'a InliningPass,
    /// SSA of the method being transformed; call sites are rewritten in place.
    caller_ssa: &'a mut SsaFunction,
    /// Token of the method being transformed (used to reject self-calls and
    /// to attribute recorded events).
    caller_token: Token,
    /// Shared compiler state: peer SSA (`with_ssa`), the call graph and the
    /// `no_inline` set.
    analysis_ctx: &'a CompilerContext,
    /// Assembly used to resolve `MemberRef` tokens to `MethodDef` tokens.
    assembly: &'a CilObject,
    /// Events recorded during this run; `run_on_method` merges them into the
    /// host's event log.
    changes: EventLog,
}
impl<'a> InliningContext<'a> {
    /// Creates a context for inlining calls inside `caller_ssa`, starting with
    /// an empty event log.
    fn new(
        pass: &'a InliningPass,
        caller_ssa: &'a mut SsaFunction,
        caller_token: Token,
        analysis_ctx: &'a CompilerContext,
        assembly: &'a CilObject,
    ) -> Self {
        Self {
            pass,
            caller_ssa,
            caller_token,
            analysis_ctx,
            assembly,
            changes: EventLog::new(),
        }
    }

    /// Returns `true` if at least one inlining event has been recorded.
    fn has_changes(&self) -> bool {
        !self.changes.is_empty()
    }

    /// Consumes the context and returns the recorded events.
    fn into_changes(self) -> EventLog {
        self.changes
    }

    /// Cheap filter applied before the callee's SSA is inspected.
    ///
    /// Rejects the caller itself (direct self-recursion) and any method listed
    /// in the context's `no_inline` set.
    fn is_valid_target(&self, callee_token: Token) -> bool {
        callee_token != self.caller_token && !self.analysis_ctx.no_inline.contains(&callee_token)
    }

    /// Decides whether `callee_token` should be inlined into the caller.
    ///
    /// The callee must pass [`Self::is_valid_target`], have SSA available in
    /// the context, contain at most `inline_threshold` instructions, report a
    /// purity that allows inlining, and not be recursive according to
    /// [`Self::has_recursive_call`]. Returns `false` when no SSA is available.
    fn should_inline(&self, callee_token: Token) -> bool {
        if !self.is_valid_target(callee_token) {
            return false;
        }
        self.analysis_ctx
            .with_ssa(callee_token, |callee_ssa| {
                callee_ssa.instruction_count() <= self.pass.inline_threshold
                    && callee_ssa.purity().can_inline()
                    && !self.has_recursive_call(callee_token, callee_ssa)
            })
            .unwrap_or(false)
    }

    /// Returns `true` if `ssa` calls `method_token` directly, or calls a method
    /// whose call-graph callees include `method_token`.
    ///
    /// Only one level of indirection is examined; longer mutual-recursion
    /// cycles are not detected here.
    fn has_recursive_call(&self, method_token: Token, ssa: &SsaFunction) -> bool {
        for block in ssa.blocks() {
            for instr in block.instructions() {
                if let SsaOp::Call { method, .. } | SsaOp::CallVirt { method, .. } = instr.op() {
                    if method.token() == method_token {
                        return true;
                    }
                    let callees = self.analysis_ctx.call_graph.callees(method.token());
                    if callees.contains(&method_token) {
                        return true;
                    }
                }
            }
        }
        false
    }

    /// Maps a call target to a `MethodDef` token where possible.
    ///
    /// MethodDef tokens (table 0x06) and tokens from any table other than
    /// MemberRef (0x0A) are returned unchanged. A MemberRef is resolved only
    /// when it names a method on a type defined in this assembly; otherwise the
    /// original token is returned.
    fn resolve_to_method_def(&self, token: Token) -> Token {
        match token.table() {
            0x0A => self.resolve_member_ref(token).unwrap_or(token),
            _ => token,
        }
    }

    /// Resolves a MemberRef method reference to the `MethodDef` token of a
    /// same-named method on its declaring `TypeDef`, if there is one.
    ///
    /// NOTE(review): the lookup matches by name only (the signature is not
    /// compared), so for overloaded methods the first same-named method is
    /// returned, which may not be the overload actually referenced — TODO
    /// compare signatures before trusting this for inlining.
    fn resolve_member_ref(&self, token: Token) -> Option<Token> {
        let refs = self.assembly.refs_members();
        let member_ref_entry = refs.get(&token)?;
        let member_ref = member_ref_entry.value();
        if !matches!(member_ref.signature, MemberRefSignature::Method(_)) {
            return None;
        }
        let CilTypeReference::TypeDef(type_ref) = &member_ref.declaredby else {
            return None;
        };
        let type_info = type_ref.upgrade()?;
        type_info
            .query_methods()
            .name(&member_ref.name)
            .find_first()
            .map(|method| method.token)
    }

    /// Collects the caller's call sites that qualify for inlining.
    ///
    /// `Call` sites are always considered. `CallVirt` sites are skipped when
    /// the call-graph resolver reports them as polymorphic. Each target is
    /// resolved with [`Self::resolve_to_method_def`] and then filtered by
    /// [`Self::should_inline`].
    fn find_candidates(&self) -> Vec<InlineCandidate> {
        let mut candidates = Vec::new();
        for (block_idx, instr_idx, instr) in self.caller_ssa.iter_instructions() {
            let raw_callee_token = match instr.op() {
                SsaOp::Call { method, .. } => method.token(),
                SsaOp::CallVirt { method, .. } => {
                    let token = method.token();
                    if self
                        .analysis_ctx
                        .call_graph
                        .resolver()
                        .is_polymorphic(token)
                    {
                        continue;
                    }
                    token
                }
                _ => continue,
            };
            let callee_token = self.resolve_to_method_def(raw_callee_token);
            if self.should_inline(callee_token) {
                candidates.push(InlineCandidate {
                    block_idx,
                    instr_idx,
                    callee_token,
                });
            }
        }
        candidates
    }

    /// Inlines a single candidate.
    ///
    /// Re-reads the call instruction at the candidate's position and fetches a
    /// clone of the callee's SSA. On success it marks the callee as inlined in
    /// the shared context. Returns whether the call site was rewritten.
    fn process_candidate(&mut self, candidate: &InlineCandidate) -> bool {
        let Some(call_op) = self
            .caller_ssa
            .block(candidate.block_idx)
            .and_then(|block| block.instructions().get(candidate.instr_idx))
            .map(|instr| instr.op().clone())
        else {
            return false;
        };
        let Some(callee_ssa) = self
            .analysis_ctx
            .with_ssa(candidate.callee_token, Clone::clone)
        else {
            return false;
        };
        let success = self.inline_call(
            &callee_ssa,
            candidate.block_idx,
            candidate.instr_idx,
            &call_op,
            candidate.callee_token,
        );
        if success {
            self.analysis_ctx.mark_inlined(candidate.callee_token);
        }
        success
    }

    /// Rewrites the call at (`call_block_idx`, `call_instr_idx`) using the
    /// callee's summarized return behaviour.
    ///
    /// - `Constant`: the call becomes a `Const` into its destination, or a
    ///   `Nop` if the result is unused.
    /// - `PassThrough(i)`: the call becomes a `Copy` of argument `i` (only when
    ///   the call has a destination and argument `i` exists).
    /// - `Void`: the call becomes a `Nop`.
    /// - Anything else: falls back to [`Self::inline_full`].
    ///
    /// Returns `false` if `call_op` is not a call or nothing was rewritten.
    fn inline_call(
        &mut self,
        callee_ssa: &SsaFunction,
        call_block_idx: usize,
        call_instr_idx: usize,
        call_op: &SsaOp,
        callee_token: Token,
    ) -> bool {
        let (dest, args) = match call_op {
            SsaOp::Call { dest, args, .. } | SsaOp::CallVirt { dest, args, .. } => {
                (*dest, args.clone())
            }
            _ => return false,
        };
        match callee_ssa.return_info() {
            ReturnInfo::Constant(value) => {
                let (op, message) = match dest {
                    Some(dest_var) => (
                        SsaOp::Const {
                            dest: dest_var,
                            value: value.clone(),
                        },
                        format!("inlined constant {callee_token:?}"),
                    ),
                    None => (
                        SsaOp::Nop,
                        format!("eliminated pure call {callee_token:?}"),
                    ),
                };
                self.replace_call_site(call_block_idx, call_instr_idx, op, message)
            }
            ReturnInfo::PassThrough(param_idx) => {
                let (Some(dest_var), Some(&src_var)) = (dest, args.get(param_idx)) else {
                    return false;
                };
                self.replace_call_site(
                    call_block_idx,
                    call_instr_idx,
                    SsaOp::Copy {
                        dest: dest_var,
                        src: src_var,
                    },
                    format!("inlined passthrough {callee_token:?}"),
                )
            }
            ReturnInfo::Void => self.replace_call_site(
                call_block_idx,
                call_instr_idx,
                SsaOp::Nop,
                format!("eliminated void call {callee_token:?}"),
            ),
            ReturnInfo::PureComputation | ReturnInfo::Dynamic | ReturnInfo::Unknown => self
                .inline_full(
                    callee_ssa,
                    call_block_idx,
                    call_instr_idx,
                    dest,
                    &args,
                    callee_token,
                ),
        }
    }

    /// Overwrites the instruction at (`block_idx`, `instr_idx`) with `op` and
    /// records a `MethodInlined` event carrying `message`.
    ///
    /// Returns `false` (and records nothing) if that position does not exist.
    fn replace_call_site(
        &mut self,
        block_idx: usize,
        instr_idx: usize,
        op: SsaOp,
        message: String,
    ) -> bool {
        let Some(instr) = self
            .caller_ssa
            .block_mut(block_idx)
            .and_then(|block| block.instructions_mut().get_mut(instr_idx))
        else {
            return false;
        };
        instr.set_op(op);
        self.changes
            .record(EventKind::MethodInlined)
            .at(self.caller_token, instr_idx)
            .message(message);
        true
    }

    /// Splices the body of a single-block callee into the caller at the call.
    ///
    /// The call instruction is reused for the first spliced instruction and
    /// the rest are inserted directly after it. If the call has a destination,
    /// a `Copy` from the (renamed) returned variable into that destination is
    /// placed after the last spliced instruction, so the value is defined
    /// before it is copied.
    ///
    /// Returns `false` without touching the caller's instructions when any of
    /// the following holds:
    /// - the callee has more than one block;
    /// - its block does not end in its only `Return`;
    /// - the call expects a value but the callee returns none;
    /// - the call site no longer exists;
    /// - a callee operand is neither a mapped parameter nor defined earlier in
    ///   the block;
    /// - operand renaming cannot be sequenced (see [`Self::remap_op`]).
    ///
    /// In the last case, variables already created for earlier spliced
    /// instructions remain in the caller, unused.
    fn inline_full(
        &mut self,
        callee_ssa: &SsaFunction,
        call_block_idx: usize,
        call_instr_idx: usize,
        dest: Option<SsaVarId>,
        args: &[SsaVarId],
        callee_token: Token,
    ) -> bool {
        if callee_ssa.blocks().len() != 1 {
            return false;
        }
        let Some(callee_block) = callee_ssa.blocks().first() else {
            return false;
        };

        // The block must end in a `Return` and contain no other one. Any other
        // terminator (e.g. a branch back to this block, or a throw) would carry
        // callee-relative control flow into the caller.
        let Some((terminator, body)) = callee_block.instructions().split_last() else {
            return false;
        };
        let SsaOp::Return {
            value: return_value,
        } = terminator.op()
        else {
            return false;
        };
        let return_value = *return_value;
        if body
            .iter()
            .any(|instr| matches!(instr.op(), SsaOp::Return { .. }))
        {
            return false;
        }
        // The call site expects a result that the callee never produces.
        if dest.is_some() && return_value.is_none() {
            return false;
        }

        // Confirm the call site still exists before creating any variables.
        let Some(call_block) = self.caller_ssa.block(call_block_idx) else {
            return false;
        };
        if call_instr_idx >= call_block.instructions().len() {
            return false;
        }

        // Seed the renaming with the incoming (version 0) parameter values.
        let mut var_remap: BTreeMap<SsaVarId, SsaVarId> = BTreeMap::new();
        for (param_idx, &arg_var) in args.iter().enumerate() {
            let Ok(param_idx) = u16::try_from(param_idx) else {
                break;
            };
            if let Some(param_var) = callee_ssa
                .variables_from_argument(param_idx)
                .find(|v| v.version() == 0)
            {
                var_remap.insert(param_var.id(), arg_var);
            }
        }

        // Every operand must be a mapped parameter or a value defined earlier
        // in the block. An unmapped callee id would otherwise silently alias
        // whichever caller variable happens to share its index.
        let mut available: std::collections::BTreeSet<SsaVarId> =
            var_remap.keys().copied().collect();
        for instr in body {
            let op = instr.op();
            for used in op.uses() {
                if !available.contains(&used) {
                    return false;
                }
            }
            if let Some(defined) = op.dest() {
                available.insert(defined);
            }
        }
        if let Some(ret_var) = return_value {
            if !available.contains(&ret_var) {
                return false;
            }
        }

        let mut new_ops: Vec<SsaOp> = Vec::with_capacity(body.len() + 1);
        for instr in body {
            let Some(op) =
                Self::remap_op(instr.op(), &mut var_remap, callee_ssa, self.caller_ssa)
            else {
                return false;
            };
            new_ops.push(op);
        }

        // Forward the result only after every spliced instruction, so that the
        // returned variable is already defined when it is copied.
        if let (Some(dest_var), Some(ret_var)) = (dest, return_value) {
            let src = var_remap.get(&ret_var).copied().unwrap_or(ret_var);
            if dest_var != src {
                new_ops.push(SsaOp::Copy {
                    dest: dest_var,
                    src,
                });
            }
        }

        let Some(block) = self.caller_ssa.block_mut(call_block_idx) else {
            return false;
        };
        let instructions = block.instructions_mut();
        let mut ops = new_ops.into_iter();
        let Some(call_instr) = instructions.get_mut(call_instr_idx) else {
            return false;
        };
        // Reuse the call instruction for the first spliced op; an empty body
        // with no result to forward leaves a `Nop` behind.
        call_instr.set_op(ops.next().unwrap_or(SsaOp::Nop));
        for (offset, op) in ops.enumerate() {
            instructions.insert(call_instr_idx + 1 + offset, SsaInstruction::synthetic(op));
        }

        self.changes
            .record(EventKind::MethodInlined)
            .at(self.caller_token, call_instr_idx)
            .message(format!("fully inlined {callee_token:?}"));
        true
    }

    /// Clones `op` and renames its operands and destination into caller
    /// variables.
    ///
    /// Operands are renamed through `var_remap` (operands without an entry
    /// keep their id). The destination, if any, gets a fresh caller variable
    /// via [`Self::get_or_create_var`], which is created only after the
    /// operands were renamed successfully.
    ///
    /// `replace_uses` rewrites one name at a time. A rename whose target is
    /// another operand still waiting to be renamed would therefore be
    /// clobbered by that later rewrite, so such operands are renamed first.
    /// Returns `None` when the renames form a cycle (e.g. `a <-> b`), which
    /// cannot be sequenced without a scratch name.
    fn remap_op(
        op: &SsaOp,
        var_remap: &mut BTreeMap<SsaVarId, SsaVarId>,
        callee_ssa: &SsaFunction,
        caller_ssa: &mut SsaFunction,
    ) -> Option<SsaOp> {
        let mut remapped = op.clone();

        // Distinct operands whose name actually changes, as (old, new) pairs.
        let mut pending: Vec<(SsaVarId, SsaVarId)> = Vec::new();
        for used in op.uses() {
            let target = var_remap.get(&used).copied().unwrap_or(used);
            if target != used && !pending.iter().any(|&(old, _)| old == used) {
                pending.push((used, target));
            }
        }

        while !pending.is_empty() {
            // A rename is safe once no other pending rename still looks for
            // its target name.
            let ready = pending
                .iter()
                .position(|&(_, target)| pending.iter().all(|&(old, _)| old != target))?;
            let (old, target) = pending.swap_remove(ready);
            remapped.replace_uses(old, target);
        }

        if let Some(dest) = op.dest() {
            let new_dest = Self::get_or_create_var(dest, var_remap, callee_ssa, caller_ssa);
            remapped.set_dest(new_dest);
        }
        Some(remapped)
    }

    /// Returns the caller variable that stands for callee variable `var`.
    ///
    /// Creates one on first use, copying the callee variable's type
    /// (`SsaType::Unknown` if the callee has no record of it), and records
    /// the mapping in `var_remap`.
    ///
    /// NOTE(review): new variables are tagged `VariableOrigin::Phi` with a
    /// placeholder def site `(0, 0)`, neither of which describes the spliced
    /// instruction that defines them. Confirm that later passes do not rely
    /// on these fields.
    fn get_or_create_var(
        var: SsaVarId,
        var_remap: &mut BTreeMap<SsaVarId, SsaVarId>,
        callee_ssa: &SsaFunction,
        caller_ssa: &mut SsaFunction,
    ) -> SsaVarId {
        if let Some(&remapped) = var_remap.get(&var) {
            return remapped;
        }
        let var_type = callee_ssa
            .variable(var)
            .map(|v| v.var_type().clone())
            .unwrap_or(SsaType::Unknown);
        let new_id = caller_ssa.create_variable(
            VariableOrigin::Phi,
            0,
            DefSite::instruction(0, 0),
            var_type,
        );
        var_remap.insert(var, new_id);
        new_id
    }
}
/// SSA pass that replaces calls to small, pure methods with the callee's body
/// (or a cheaper summary of it, such as a constant or a copied argument).
///
/// Callees with more SSA instructions than `inline_threshold` are not inlined.
#[derive(Debug)]
pub struct InliningPass {
    /// Largest callee size, in SSA instructions, that may still be inlined.
    inline_threshold: usize,
}

impl InliningPass {
    /// Builds the pass with `threshold` as the maximum callee instruction count.
    #[must_use]
    pub fn new(threshold: usize) -> Self {
        let inline_threshold = threshold;
        Self { inline_threshold }
    }
}
impl SsaPass<CilTarget, CompilerContext> for InliningPass {
    fn name(&self) -> &'static str {
        "InliningPass"
    }

    fn description(&self) -> &'static str {
        "Inlines small, pure methods at their call sites"
    }

    /// Inlining reads the SSA of other (callee) methods.
    fn reads_peer_ssa(&self) -> bool {
        true
    }

    fn provides(&self) -> &[PassCapability] {
        &[PassCapability::InlinedMethods]
    }

    fn requires(&self) -> &[PassCapability] {
        &[PassCapability::DevirtualizedCalls]
    }

    /// Inlines every qualifying call site in `ssa`.
    ///
    /// Returns `Ok(true)` when at least one call site was rewritten; the
    /// recorded events are then merged into `host.events`.
    ///
    /// # Errors
    ///
    /// Fails when `host` carries no assembly.
    fn run_on_method(
        &self,
        ssa: &mut SsaFunction,
        method: &MethodRef,
        host: &CompilerContext,
    ) -> analyssa::Result<bool> {
        let assembly = host
            .assembly()
            .ok_or_else(|| analyssa::Error::new("InliningPass requires an assembly"))?;
        let mut ctx = InliningContext::new(self, ssa, method.0, host, &assembly);

        let sites = ctx.find_candidates();
        if sites.is_empty() {
            return Ok(false);
        }

        // Back to front: full inlining inserts instructions after a call,
        // which would shift the indices of later sites in the same block
        // (presumably `find_candidates` yields sites in program order).
        for site in sites.iter().rev() {
            ctx.process_candidate(site);
        }

        if !ctx.has_changes() {
            return Ok(false);
        }
        host.events.merge(&ctx.into_changes());
        Ok(true)
    }
}
// Unit tests for the inlining pass. They drive `InliningContext` directly
// with SSA built by `SsaFunctionBuilder`, using an empty call graph and the
// shared test assembly.
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use crate::{
        analysis::{
            CallGraph, ConstValue, MethodRef, SsaFunctionBuilder, SsaOp, SsaType, SsaVarId,
        },
        compiler::{
            passes::inlining::{InliningContext, InliningPass},
            CompilerContext, SsaPass,
        },
        metadata::token::Token,
        test::helpers::test_assembly_arc,
        CilObject,
    };

    /// Builds a `CompilerContext` over an empty call graph.
    fn test_context() -> CompilerContext {
        CompilerContext::new(Arc::new(CallGraph::new()))
    }

    /// Returns the shared test assembly from the crate's test helpers.
    fn test_assembly() -> Arc<CilObject> {
        test_assembly_arc()
    }

    /// The constructor stores the threshold and the pass reports its name.
    #[test]
    fn test_pass_creation() {
        let pass = InliningPass::new(20);
        assert_eq!(pass.name(), "InliningPass");
        assert_eq!(pass.inline_threshold, 20);
        let pass_custom = InliningPass::new(50);
        assert_eq!(pass_custom.inline_threshold, 50);
    }

    /// A callee that returns the constant 42 is folded into a `Const` that
    /// writes the call's destination.
    #[test]
    fn test_inline_constant_return() {
        let callee_token = Token::new(0x06000002);
        // Callee: `v0 = 42; return v0`. The builder closure reports v0 out.
        let (callee_ssa, callee_v0) = {
            let mut v0_out = SsaVarId::from_index(0);
            let ssa = SsaFunctionBuilder::new(0, 0)
                .build_with(|f| {
                    f.block(0, |b| {
                        let v0 = b.const_i32(42);
                        v0_out = v0;
                        b.ret_val(v0);
                    });
                })
                .unwrap();
            (ssa, v0_out)
        };
        let caller_token = Token::new(0x06000001);
        // Caller: `dest = call callee(); return dest`.
        let (mut caller_ssa, call_dest) = {
            let mut dest_out = SsaVarId::from_index(1);
            let ssa = SsaFunctionBuilder::new(0, 0)
                .build_with(|f| {
                    f.block(0, |b| {
                        let dest = b.call(MethodRef::new(callee_token), &[], SsaType::I32);
                        dest_out = dest;
                        b.ret_val(dest);
                    });
                })
                .unwrap();
            (ssa, dest_out)
        };
        let ctx = test_context();
        ctx.set_ssa(callee_token, callee_ssa.clone());
        let call_op = caller_ssa.block(0).unwrap().instructions()[0].op().clone();
        let pass = InliningPass::new(20);
        let assembly = test_assembly();
        let mut inline_ctx =
            InliningContext::new(&pass, &mut caller_ssa, caller_token, &ctx, &assembly);
        let result = inline_ctx.inline_call(&callee_ssa, 0, 0, &call_op, callee_token);
        assert!(result, "Inlining should succeed");
        let block = inline_ctx.caller_ssa.block(0).unwrap();
        let first_instr = &block.instructions()[0];
        match first_instr.op() {
            SsaOp::Const { dest, value } => {
                assert_eq!(*dest, call_dest);
                assert_eq!(*value, ConstValue::I32(42));
            }
            other => panic!("Expected Const, got {:?}", other),
        }
        // The callee's variable id is not needed by the assertions above.
        let _ = callee_v0;
    }

    /// A call to an empty void callee is replaced by a `Nop`.
    #[test]
    fn test_inline_void_pure() {
        let callee_token = Token::new(0x06000002);
        let callee_ssa = SsaFunctionBuilder::new(0, 0)
            .build_with(|f| {
                f.block(0, |b| b.ret());
            })
            .unwrap();
        let caller_token = Token::new(0x06000001);
        let mut caller_ssa = SsaFunctionBuilder::new(0, 0)
            .build_with(|f| {
                f.block(0, |b| {
                    b.call_void(MethodRef::new(callee_token), &[]);
                    b.ret();
                });
            })
            .unwrap();
        let ctx = test_context();
        ctx.set_ssa(callee_token, callee_ssa.clone());
        let call_op = caller_ssa.block(0).unwrap().instructions()[0].op().clone();
        let pass = InliningPass::new(20);
        let assembly = test_assembly();
        let mut inline_ctx =
            InliningContext::new(&pass, &mut caller_ssa, caller_token, &ctx, &assembly);
        let result = inline_ctx.inline_call(&callee_ssa, 0, 0, &call_op, callee_token);
        assert!(result, "Inlining should succeed");
        let block = inline_ctx.caller_ssa.block(0).unwrap();
        let first_instr = &block.instructions()[0];
        assert!(
            matches!(first_instr.op(), SsaOp::Nop),
            "Expected Nop, got {:?}",
            first_instr.op()
        );
    }

    /// A method is never a valid inline target inside itself.
    #[test]
    fn test_no_inline_self_recursion() {
        let pass = InliningPass::new(20);
        let token = Token::new(0x06000001);
        let ctx = test_context();
        let assembly = test_assembly();
        let mut dummy_ssa = SsaFunctionBuilder::new(0, 0)
            .build_with(|f| {
                f.block(0, |b| b.ret());
            })
            .unwrap();
        let inline_ctx = InliningContext::new(&pass, &mut dummy_ssa, token, &ctx, &assembly);
        assert!(!inline_ctx.is_valid_target(token));
    }

    /// A callee with more instructions than the threshold (30 consts + ret
    /// vs. 20) is rejected.
    #[test]
    fn test_no_inline_large_method() {
        let callee_token = Token::new(0x06000002);
        let caller_token = Token::new(0x06000001);
        let callee_ssa = SsaFunctionBuilder::new(0, 0)
            .build_with(|f| {
                f.block(0, |b| {
                    for _ in 0..30 {
                        let _ = b.const_i32(0);
                    }
                    b.ret();
                });
            })
            .unwrap();
        let ctx = test_context();
        ctx.set_ssa(callee_token, callee_ssa);
        let pass = InliningPass::new(20);
        let assembly = test_assembly();
        let mut caller_ssa = SsaFunctionBuilder::new(0, 0)
            .build_with(|f| {
                f.block(0, |b| b.ret());
            })
            .unwrap();
        let inline_ctx =
            InliningContext::new(&pass, &mut caller_ssa, caller_token, &ctx, &assembly);
        assert!(!inline_ctx.should_inline(callee_token));
    }

    /// `callee(x) = x + 10`, called with a constant 5, is spliced into the
    /// caller.
    ///
    /// NOTE(review): this only checks that the callee's `Const` and `Add` were
    /// spliced in; it does not check where the result `Copy` is placed.
    #[test]
    fn test_inline_full_computation() {
        let callee_token = Token::new(0x06000002);
        let callee_ssa = SsaFunctionBuilder::new(1, 0)
            .build_with(|f| {
                let param0 = f.arg(0, SsaType::I32);
                f.block(0, |b| {
                    let v1 = b.const_i32(10);
                    let v2 = b.add(param0, v1);
                    b.ret_val(v2);
                });
            })
            .unwrap();
        let caller_token = Token::new(0x06000001);
        let mut caller_ssa = SsaFunctionBuilder::new(0, 0)
            .build_with(|f| {
                f.block(0, |b| {
                    let v0 = b.const_i32(5);
                    let v1 = b.call(MethodRef::new(callee_token), &[v0], SsaType::I32);
                    b.ret_val(v1);
                });
            })
            .unwrap();
        let ctx = test_context();
        ctx.set_ssa(callee_token, callee_ssa.clone());
        // The call is the caller's second instruction (index 1).
        let call_op = caller_ssa.block(0).unwrap().instructions()[1].op().clone();
        let pass = InliningPass::new(20);
        let assembly = test_assembly();
        let mut inline_ctx =
            InliningContext::new(&pass, &mut caller_ssa, caller_token, &ctx, &assembly);
        let result = inline_ctx.inline_call(&callee_ssa, 0, 1, &call_op, callee_token);
        assert!(result, "Full inlining should succeed");
        let block = inline_ctx.caller_ssa.block(0).unwrap();
        assert!(
            block.instructions().len() > 3,
            "Expected inlined instructions, got {} instructions",
            block.instructions().len()
        );
        // The call slot is reused for the callee's first instruction.
        let second_instr = &block.instructions()[1];
        match second_instr.op() {
            SsaOp::Const { value, .. } => {
                assert_eq!(*value, ConstValue::I32(10), "Expected inlined constant 10");
            }
            other => panic!("Expected Const for inlined instruction, got {:?}", other),
        }
        let has_add = block
            .instructions()
            .iter()
            .any(|i| matches!(i.op(), SsaOp::Add { .. }));
        assert!(has_add, "Expected Add instruction after inlining");
    }
}