use std::{collections::HashMap, sync::Arc};
use crate::{
analysis::{
ConstValue, DefSite, MethodRef, ReturnInfo, SsaFunction, SsaInstruction, SsaOp, SsaVarId,
SsaVariable, VariableOrigin,
},
compiler::{pass::SsaPass, CompilerContext, EventKind, EventLog},
metadata::{tables::MemberRefSignature, token::Token, typesystem::CilTypeReference},
CilObject, Result,
};
/// The transformation chosen for one call site by `InliningContext::find_candidates`.
#[derive(Debug, Clone)]
enum InlineAction {
/// Replace the call with the callee's effect (see `InliningContext::inline_call`).
FullInline,
/// The callee only forwards its parameters to another method; call that method directly.
ProxyDevirtualize {
/// Method the proxy forwards to.
target_method: MethodRef,
/// For each argument of `target_method`, the index of the proxy argument that feeds it.
arg_mapping: Vec<usize>,
/// Whether the proxy's inner call was a `CallVirt` (the rewritten call keeps that kind).
is_virtual: bool,
},
/// The callee is eliminable and returns nothing: drop the call.
NoOpEliminate,
/// The callee is eliminable and always returns this constant: fold the call to it.
ConstantFold(ConstValue),
}
/// A call site in the caller paired with the action to apply to it.
#[derive(Debug, Clone)]
struct InlineCandidate {
/// Index of the caller block holding the call.
block_idx: usize,
/// Index of the call instruction within that block.
instr_idx: usize,
/// Callee token, already normalised by `resolve_to_method_def`.
callee_token: Token,
/// What to do with this call site.
action: InlineAction,
}
/// Per-method working state while `InliningPass` rewrites one caller.
struct InliningContext<'a> {
/// Pass configuration (`inline_threshold`, `proxy_only`).
pass: &'a InliningPass,
/// The caller's SSA, rewritten in place.
caller_ssa: &'a mut SsaFunction,
/// Token of the caller; used to refuse self-inlining and to tag events.
caller_token: Token,
/// Shared compiler state: callee SSA bodies, call graph, `no_inline` set, inline marks.
analysis_ctx: &'a CompilerContext,
/// Assembly metadata, used to resolve MemberRef call targets.
assembly: &'a Arc<CilObject>,
/// Events recorded by this context; merged into `analysis_ctx.events` by `run_on_method`.
changes: EventLog,
}
impl<'a> InliningContext<'a> {
fn new(
pass: &'a InliningPass,
caller_ssa: &'a mut SsaFunction,
caller_token: Token,
analysis_ctx: &'a CompilerContext,
assembly: &'a Arc<CilObject>,
) -> Self {
Self {
pass,
caller_ssa,
caller_token,
analysis_ctx,
assembly,
changes: EventLog::new(),
}
}
fn has_changes(&self) -> bool {
!self.changes.is_empty()
}
fn into_changes(self) -> EventLog {
self.changes
}
/// Whether `callee_token` may be transformed at all.
///
/// A method is never inlined into itself, and tokens listed in the context's
/// `no_inline` set are always skipped.
fn is_valid_target(&self, callee_token: Token) -> bool {
    let is_self_call = callee_token == self.caller_token;
    !is_self_call && !self.analysis_ctx.no_inline.contains(&callee_token)
}
/// Decides whether `callee_token` qualifies for full inlining.
///
/// The callee must be a valid target and have SSA available. Its instruction count must
/// not exceed the pass threshold, its purity must allow inlining, and its body must not
/// lead back to itself (see `has_recursive_call`).
fn should_inline(&self, callee_token: Token) -> bool {
    self.is_valid_target(callee_token)
        && self
            .analysis_ctx
            .with_ssa(callee_token, |callee_ssa| {
                callee_ssa.instruction_count() <= self.pass.inline_threshold
                    && callee_ssa.purity().can_inline()
                    && !self.has_recursive_call(callee_token, callee_ssa)
            })
            .unwrap_or(false)
}
/// Returns `true` if `ssa` (the body of `method_token`) calls `method_token` directly. It
/// also returns `true` if `ssa` calls a method whose call-graph callees include
/// `method_token`.
///
/// Each call target is checked both as written and after `resolve_to_method_def`. A
/// self-call spelled as a MemberRef is therefore caught the same way as one spelled as a
/// MethodDef, matching how `find_candidates` normalises callee tokens.
fn has_recursive_call(&self, method_token: Token, ssa: &SsaFunction) -> bool {
    let call_graph = &self.analysis_ctx.call_graph;
    for block in ssa.blocks() {
        for instr in block.instructions() {
            if let SsaOp::Call { method, .. } | SsaOp::CallVirt { method, .. } = instr.op() {
                let raw = method.token();
                let resolved = self.resolve_to_method_def(raw);
                if raw == method_token || resolved == method_token {
                    return true;
                }
                if call_graph.callees(raw).contains(&method_token) {
                    return true;
                }
                // Only consult the graph again when resolution produced a different key.
                if resolved != raw && call_graph.callees(resolved).contains(&method_token) {
                    return true;
                }
            }
        }
    }
    false
}
/// If `callee_token` is a valid target whose SSA matches the proxy shape, returns the
/// proxied method, the argument mapping, and whether the inner call was virtual.
/// See `InliningPass::detect_proxy_pattern`.
fn should_devirtualize_proxy(
    &self,
    callee_token: Token,
) -> Option<(MethodRef, Vec<usize>, bool)> {
    if self.is_valid_target(callee_token) {
        self.analysis_ctx
            .with_ssa(callee_token, InliningPass::detect_proxy_pattern)
            .flatten()
    } else {
        None
    }
}
/// Classifies a callee whose calls can be removed when the result is unused.
///
/// Returns `None` when the callee is not a valid target, has no SSA, is not eliminable,
/// or has a return shape other than void/constant. Otherwise it returns:
/// - `Some(None)` for a void callee (the call can be dropped);
/// - `Some(Some(c))` for a callee that always returns `c` (the call can be folded).
// The nested Option deliberately distinguishes "not a no-op" from "no-op without value".
#[allow(clippy::option_option)]
fn detect_noop_method(&self, callee_token: Token) -> Option<Option<ConstValue>> {
    if !self.is_valid_target(callee_token) {
        return None;
    }
    self.analysis_ctx
        .with_ssa(callee_token, |callee_ssa| {
            let eliminable = callee_ssa.purity().can_eliminate_if_unused();
            if eliminable {
                match callee_ssa.return_info() {
                    ReturnInfo::Void => Some(None),
                    ReturnInfo::Constant(val) => Some(Some(val.clone())),
                    _ => None,
                }
            } else {
                None
            }
        })
        .flatten()
}
/// Normalises a call-target token to a MethodDef token where possible.
///
/// - MethodDef tokens (table 0x06) are returned unchanged.
/// - A MemberRef token (table 0x0A) is resolved by name lookup among its declaring type's
///   methods. This requires a method signature and a declaring type that is a `TypeDef`
///   reference which can still be upgraded.
/// - In every other case, or when a lookup step fails, `token` is returned unchanged.
///
/// NOTE(review): the lookup matches on name only (`_method_sig` is bound but unused). For
/// overloaded methods, `find_first()` may return an overload whose signature differs from
/// the MemberRef's, and callers would then inline/devirtualize the wrong body. Consider
/// comparing signatures before accepting the match.
fn resolve_to_method_def(&self, token: Token) -> Token {
let table_id = token.table();
// 0x06 = MethodDef table (ECMA-335): already a definition token.
if table_id == 0x06 {
return token;
}
// 0x0A = MemberRef table (ECMA-335).
if table_id == 0x0A {
let refs = self.assembly.refs_members();
if let Some(member_ref_entry) = refs.get(&token) {
let member_ref = member_ref_entry.value();
// Field references cannot be call targets; leave them as-is.
let MemberRefSignature::Method(ref _method_sig) = member_ref.signature else {
return token;
};
// Only references whose parent is a TypeDef are resolved.
if let CilTypeReference::TypeDef(type_ref) = &member_ref.declaredby {
if let Some(type_info) = type_ref.upgrade() {
if let Some(method) = type_info
.query_methods()
.name(&member_ref.name)
.find_first()
{
return method.token;
}
}
}
}
}
token
}
fn find_candidates(&self) -> Vec<InlineCandidate> {
let mut candidates = Vec::new();
for (block_idx, instr_idx, instr) in self.caller_ssa.iter_instructions() {
let raw_callee_token = match instr.op() {
SsaOp::Call { method, .. } => method.token(),
SsaOp::CallVirt { method, .. } => {
let token = method.token();
if self
.analysis_ctx
.call_graph
.resolver()
.is_polymorphic(token)
{
continue;
}
token
}
_ => continue,
};
let callee_token = self.resolve_to_method_def(raw_callee_token);
if !self.pass.proxy_only {
if self.should_inline(callee_token) {
candidates.push(InlineCandidate {
block_idx,
instr_idx,
callee_token,
action: InlineAction::FullInline,
});
continue;
}
if let Some(noop_result) = self.detect_noop_method(callee_token) {
let action = match noop_result {
None => InlineAction::NoOpEliminate,
Some(val) => InlineAction::ConstantFold(val),
};
candidates.push(InlineCandidate {
block_idx,
instr_idx,
callee_token,
action,
});
continue;
}
}
if let Some((target_method, arg_mapping, is_virtual)) =
self.should_devirtualize_proxy(callee_token)
{
candidates.push(InlineCandidate {
block_idx,
instr_idx,
callee_token,
action: InlineAction::ProxyDevirtualize {
target_method,
arg_mapping,
is_virtual,
},
});
}
}
candidates
}
fn process_candidate(&mut self, candidate: &InlineCandidate) -> bool {
let call_op = match self.caller_ssa.block(candidate.block_idx) {
Some(block) => match block.instructions().get(candidate.instr_idx) {
Some(instr) => instr.op().clone(),
None => return false,
},
None => return false,
};
let success = match &candidate.action {
InlineAction::FullInline => {
let Some(callee_ssa) = self
.analysis_ctx
.with_ssa(candidate.callee_token, Clone::clone)
else {
return false;
};
self.inline_call(
&callee_ssa,
candidate.block_idx,
candidate.instr_idx,
&call_op,
candidate.callee_token,
)
}
InlineAction::ProxyDevirtualize {
target_method,
arg_mapping,
is_virtual,
} => self.devirtualize_proxy(
candidate.block_idx,
candidate.instr_idx,
&call_op,
*target_method,
arg_mapping,
*is_virtual,
candidate.callee_token,
),
InlineAction::NoOpEliminate => self.eliminate_noop_call(
candidate.block_idx,
candidate.instr_idx,
&call_op,
candidate.callee_token,
),
InlineAction::ConstantFold(const_val) => self.fold_constant_call(
candidate.block_idx,
candidate.instr_idx,
&call_op,
const_val,
candidate.callee_token,
),
};
if success {
self.analysis_ctx.mark_inlined(candidate.callee_token);
}
success
}
/// Replaces the call at `call_block_idx`/`call_instr_idx` with the effect of `callee_ssa`.
///
/// When the callee is eliminable (`can_eliminate_if_unused`) and its `ReturnInfo` is
/// fully known, the call collapses to a single instruction:
/// - constant result → `Const` (or `Nop` if the result is unused);
/// - pass-through of a parameter → `Copy` from the matching argument;
/// - void → `Nop`.
///
/// Every other case goes to `inline_full`, which splices the callee body in. This
/// includes callees that are inlinable but may have observable side effects.
///
/// Returns `true` if the caller was modified.
fn inline_call(
    &mut self,
    callee_ssa: &SsaFunction,
    call_block_idx: usize,
    call_instr_idx: usize,
    call_op: &SsaOp,
    callee_token: Token,
) -> bool {
    let (dest, args) = match call_op {
        SsaOp::Call { dest, args, .. } | SsaOp::CallVirt { dest, args, .. } => {
            (*dest, args.clone())
        }
        _ => return false,
    };

    // The shortcuts throw the callee body away. That is only sound when running it has no
    // observable effect. `can_inline()` (checked by `should_inline`) is not that guarantee,
    // so apply the same guard `detect_noop_method` uses before discarding a call.
    if callee_ssa.purity().can_eliminate_if_unused() {
        let shortcut = match (callee_ssa.return_info(), dest) {
            (ReturnInfo::Constant(value), Some(dest_var)) => Some((
                SsaOp::Const {
                    dest: dest_var,
                    value: value.clone(),
                },
                "inlined constant",
            )),
            (ReturnInfo::Constant(_), None) => Some((SsaOp::Nop, "eliminated pure call")),
            (ReturnInfo::PassThrough(param_idx), Some(dest_var)) => {
                args.get(param_idx).map(|&src_var| {
                    (
                        SsaOp::Copy {
                            dest: dest_var,
                            src: src_var,
                        },
                        "inlined passthrough",
                    )
                })
            }
            (ReturnInfo::Void, _) => Some((SsaOp::Nop, "eliminated void call")),
            _ => None,
        };

        if let Some((replacement, summary)) = shortcut {
            let Some(instr) = self
                .caller_ssa
                .block_mut(call_block_idx)
                .and_then(|block| block.instructions_mut().get_mut(call_instr_idx))
            else {
                return false;
            };
            instr.set_op(replacement);
            self.changes
                .record(EventKind::MethodInlined)
                .at(self.caller_token, call_instr_idx)
                .message(format!("{summary} {callee_token:?}"));
            return true;
        }
    }

    // No single-instruction replacement applies: splice the callee body in.
    self.inline_full(
        callee_ssa,
        call_block_idx,
        call_instr_idx,
        dest,
        &args,
        callee_token,
    )
}
/// Splices the body of a single-block callee into the caller in place of the call at
/// `call_block_idx`/`call_instr_idx`.
///
/// The call instruction is overwritten with the first inlined op, or `Nop` if nothing
/// remains once the callee's `Return` is dropped. The remaining ops are inserted directly
/// after it. When the call has a `dest`, a `Copy` from the callee's returned value is
/// placed after the whole inlined sequence, so it never reads the value before it is
/// defined.
///
/// Callee parameters (their version-0 SSA values) and `LoadArg` results are mapped onto
/// the call's `args`. Every other callee-defined value gets a fresh caller variable via
/// `remap_op`.
///
/// Returns `false` without touching the caller when any of these hold:
/// - the callee has more than one block;
/// - the callee reads a local slot (`LoadLocal`), which would alias the caller's locals;
/// - the callee loads an argument index the call does not supply;
/// - the callee uses a value that is neither a mapped parameter nor defined earlier in
///   its block;
/// - the callee returns no value although the call site expects one;
/// - the call site does not exist.
fn inline_full(
    &mut self,
    callee_ssa: &SsaFunction,
    call_block_idx: usize,
    call_instr_idx: usize,
    dest: Option<SsaVarId>,
    args: &[SsaVarId],
    callee_token: Token,
) -> bool {
    if callee_ssa.blocks().len() != 1 {
        return false;
    }
    let Some(callee_block) = callee_ssa.blocks().first() else {
        return false;
    };
    match self.caller_ssa.block(call_block_idx) {
        Some(block) if call_instr_idx < block.instructions().len() => {}
        _ => return false,
    }

    // Callee parameter (initial SSA version) -> argument value supplied by the caller.
    let mut var_remap: HashMap<SsaVarId, SsaVarId> = HashMap::new();
    for (param_idx, &arg_var) in args.iter().enumerate() {
        #[allow(clippy::cast_possible_truncation)]
        if let Some(param_var) = callee_ssa
            .variables_from_argument(param_idx as u16)
            .find(|v| v.version() == 0)
        {
            var_remap.insert(param_var.id(), arg_var);
        }
    }

    // Validation pass: decide whether the body can be inlined *before* creating caller
    // variables or rewriting instructions, so a rejected callee leaves no residue.
    let mut defined: std::collections::HashSet<SsaVarId> =
        var_remap.keys().copied().collect();
    let mut return_value: Option<SsaVarId> = None;
    for instr in callee_block.instructions() {
        let op = instr.op();
        if matches!(op, SsaOp::LoadLocal { .. }) {
            return false;
        }
        if op.uses().into_iter().any(|used| !defined.contains(&used)) {
            return false;
        }
        match op {
            SsaOp::Return { value } => return_value = *value,
            SsaOp::LoadArg { dest: loaded, arg_index } => {
                if *arg_index as usize >= args.len() {
                    return false;
                }
                defined.insert(*loaded);
            }
            _ => {}
        }
        if let Some(def) = op.dest() {
            defined.insert(def);
        }
    }
    if dest.is_some() && return_value.is_none() {
        return false;
    }

    // Emission pass: remap callee ops into the caller's variable space.
    let mut inlined_ops: Vec<SsaOp> = Vec::new();
    for instr in callee_block.instructions() {
        let op = instr.op();
        match op {
            // The return is replaced by the result copy emitted below.
            SsaOp::Return { .. } => {}
            // Inside the caller, the callee's argument slot is simply the passed value;
            // copying `LoadArg` verbatim would read the *caller's* argument instead.
            SsaOp::LoadArg { dest: loaded, arg_index } => {
                if let Some(&arg_var) = args.get(*arg_index as usize) {
                    var_remap.insert(*loaded, arg_var);
                }
            }
            _ => inlined_ops.push(Self::remap_op(
                op,
                &mut var_remap,
                callee_ssa,
                self.caller_ssa,
            )),
        }
    }

    let Some(block) = self.caller_ssa.block_mut(call_block_idx) else {
        return false;
    };
    let instructions = block.instructions_mut();
    let mut ops = inlined_ops.into_iter();
    instructions[call_instr_idx].set_op(ops.next().unwrap_or(SsaOp::Nop));
    // Position just past the last instruction of the inlined span.
    let mut next_pos = call_instr_idx + 1;
    for op in ops {
        instructions.insert(next_pos, SsaInstruction::synthetic(op));
        next_pos += 1;
    }

    if let (Some(dest_var), Some(ret_var)) = (dest, return_value) {
        let remapped_ret = var_remap.get(&ret_var).copied().unwrap_or(ret_var);
        if dest_var != remapped_ret {
            // Must follow the whole inlined sequence: `remapped_ret` is typically
            // defined by its last instruction.
            instructions.insert(
                next_pos,
                SsaInstruction::synthetic(SsaOp::Copy {
                    dest: dest_var,
                    src: remapped_ret,
                }),
            );
        }
    }

    self.changes
        .record(EventKind::MethodInlined)
        .at(self.caller_token, call_instr_idx)
        .message(format!("fully inlined {callee_token:?}"));
    true
}
#[allow(clippy::too_many_arguments)]
fn devirtualize_proxy(
&mut self,
call_block_idx: usize,
call_instr_idx: usize,
call_op: &SsaOp,
target_method: MethodRef,
arg_mapping: &[usize],
is_virtual: bool,
proxy_token: Token,
) -> bool {
let (dest, original_args) = match call_op {
SsaOp::Call { dest, args, .. } | SsaOp::CallVirt { dest, args, .. } => {
(*dest, args.clone())
}
_ => return false,
};
let mut remapped_args = Vec::with_capacity(arg_mapping.len());
for ¶m_idx in arg_mapping {
if let Some(&arg) = original_args.get(param_idx) {
remapped_args.push(arg);
} else {
return false;
}
}
let new_op = if is_virtual {
SsaOp::CallVirt {
dest,
method: target_method,
args: remapped_args,
}
} else {
SsaOp::Call {
dest,
method: target_method,
args: remapped_args,
}
};
if let Some(block) = self.caller_ssa.block_mut(call_block_idx) {
block.instructions_mut()[call_instr_idx].set_op(new_op);
self.changes
.record(EventKind::MethodInlined)
.at(self.caller_token, call_instr_idx)
.message(format!(
"devirtualized proxy {:?} -> {:?}",
proxy_token,
target_method.token()
));
return true;
}
false
}
/// Turns the call at `call_block_idx`/`call_instr_idx` into a `Nop`.
///
/// Returns `false` if `call_op` is not a call or the call site does not exist.
fn eliminate_noop_call(
    &mut self,
    call_block_idx: usize,
    call_instr_idx: usize,
    call_op: &SsaOp,
    callee_token: Token,
) -> bool {
    let is_call = matches!(call_op, SsaOp::Call { .. } | SsaOp::CallVirt { .. });
    if !is_call {
        return false;
    }
    let Some(instr) = self
        .caller_ssa
        .block_mut(call_block_idx)
        .and_then(|block| block.instructions_mut().get_mut(call_instr_idx))
    else {
        return false;
    };
    instr.set_op(SsaOp::Nop);
    self.changes
        .record(EventKind::MethodInlined)
        .at(self.caller_token, call_instr_idx)
        .message(format!(
            "eliminated no-op call to 0x{:08x}",
            callee_token.value()
        ));
    true
}
/// Folds a call to a constant-returning, eliminable callee.
///
/// If the call has a destination, it becomes `Const { dest, const_val }`. Otherwise the
/// unused result lets the call become a `Nop`. Returns `false` if `call_op` is not a call
/// or the call site does not exist.
fn fold_constant_call(
    &mut self,
    call_block_idx: usize,
    call_instr_idx: usize,
    call_op: &SsaOp,
    const_val: &ConstValue,
    callee_token: Token,
) -> bool {
    let dest = match call_op {
        SsaOp::Call { dest, .. } | SsaOp::CallVirt { dest, .. } => *dest,
        _ => return false,
    };
    let Some(instr) = self
        .caller_ssa
        .block_mut(call_block_idx)
        .and_then(|block| block.instructions_mut().get_mut(call_instr_idx))
    else {
        return false;
    };
    let note = match dest {
        Some(dest_var) => {
            instr.set_op(SsaOp::Const {
                dest: dest_var,
                value: const_val.clone(),
            });
            format!(
                "folded constant call to 0x{:08x} -> {:?}",
                callee_token.value(),
                const_val
            )
        }
        None => {
            instr.set_op(SsaOp::Nop);
            format!(
                "eliminated unused constant call to 0x{:08x}",
                callee_token.value()
            )
        }
    };
    self.changes
        .record(EventKind::MethodInlined)
        .at(self.caller_token, call_instr_idx)
        .message(note);
    true
}
/// Clones a callee op into the caller's variable space.
///
/// A defined variable is mapped to a fresh (or previously assigned) caller variable via
/// `get_or_create_var`. Each used variable is replaced by its entry in `var_remap`, or
/// kept as-is when it has no entry.
///
/// NOTE(review): uses are rewritten one `replace_uses` call at a time. This assumes no
/// remap target is itself a key of `var_remap` (caller/callee ids disjoint) — confirm.
fn remap_op(
    op: &SsaOp,
    var_remap: &mut HashMap<SsaVarId, SsaVarId>,
    callee_ssa: &SsaFunction,
    caller_ssa: &mut SsaFunction,
) -> SsaOp {
    let mut remapped = op.clone();
    let fresh_dest = remapped
        .dest()
        .map(|old| Self::get_or_create_var(old, var_remap, callee_ssa, caller_ssa));
    if let Some(new_dest) = fresh_dest {
        remapped.set_dest(new_dest);
    }
    for used in op.uses() {
        let target = var_remap.get(&used).copied().unwrap_or(used);
        remapped.replace_uses(used, target);
    }
    remapped
}
/// Returns the caller variable standing in for callee variable `var`, creating it on
/// first use.
///
/// New variables get a `Stack` origin numbered by the caller's current variable count,
/// version 0, and — when the callee knows `var` — the callee variable's type.
///
/// NOTE(review): the def site is always `DefSite::instruction(0, 0)`, not the actual
/// position of the inlined definition. Verify that later passes tolerate this.
fn get_or_create_var(
    var: SsaVarId,
    var_remap: &mut HashMap<SsaVarId, SsaVarId>,
    callee_ssa: &SsaFunction,
    caller_ssa: &mut SsaFunction,
) -> SsaVarId {
    if let Some(&existing) = var_remap.get(&var) {
        return existing;
    }
    #[allow(clippy::cast_possible_truncation)]
    let origin = VariableOrigin::Stack(caller_ssa.variable_count() as u32);
    let def_site = DefSite::instruction(0, 0);
    let fresh = match callee_ssa.variable(var) {
        Some(callee_var) => {
            SsaVariable::new_typed(origin, 0, def_site, callee_var.var_type().clone())
        }
        None => SsaVariable::new(origin, 0, def_site),
    };
    let new_id = caller_ssa.add_variable(fresh);
    var_remap.insert(var, new_id);
    new_id
}
}
/// SSA pass that inlines small eligible callees, removes or folds calls to eliminable
/// no-op/constant methods, and redirects calls through proxy methods to their targets.
#[derive(Debug)]
pub struct InliningPass {
/// Largest callee `instruction_count()` still eligible for full inlining.
inline_threshold: usize,
/// When `true`, only proxy devirtualization runs (no inlining or no-op elimination).
proxy_only: bool,
}
impl InliningPass {
    /// Creates the pass.
    ///
    /// * `threshold` – callees whose `instruction_count()` exceeds this are never fully
    ///   inlined.
    /// * `proxy_only` – when `true`, only proxy devirtualization is attempted. Full
    ///   inlining and no-op/constant elimination are disabled.
    #[must_use]
    pub fn new(threshold: usize, proxy_only: bool) -> Self {
        Self {
            inline_threshold: threshold,
            proxy_only,
        }
    }

    /// Traces `var` back to the index of the method argument it carries unchanged, if any.
    ///
    /// If `var` is defined in `instructions`, that definition decides the answer:
    /// - `LoadArg` yields its argument index;
    /// - `Copy` is followed to its source;
    /// - any other defining op means the value is computed, so the result is `None`.
    ///
    /// Only a value with no definition in the block falls back to its variable origin
    /// (`VariableOrigin::Argument`). Looking at the definition first matters for parameters
    /// re-assigned in the body (e.g. via `starg`). Several SSA versions of one argument can
    /// exist (see `variables_from_argument`), and a later version holds whatever value was
    /// stored, not the caller's original argument.
    fn find_argument_origin(
        ssa: &SsaFunction,
        var: SsaVarId,
        instructions: &[SsaInstruction],
    ) -> Option<usize> {
        let definition = instructions
            .iter()
            .map(|instr| instr.op())
            .find(|op| match op {
                SsaOp::LoadArg { dest, .. } | SsaOp::Copy { dest, .. } => *dest == var,
                other => other.dest() == Some(var),
            });
        match definition {
            Some(SsaOp::LoadArg { arg_index, .. }) => Some(*arg_index as usize),
            Some(SsaOp::Copy { src, .. }) => Self::find_argument_origin(ssa, *src, instructions),
            Some(_) => None,
            None => match ssa.variable(var)?.origin() {
                VariableOrigin::Argument(idx) => Some(idx as usize),
                _ => None,
            },
        }
    }

    /// Recognises a *proxy* method and returns `(target, arg_mapping, is_virtual)`.
    ///
    /// A proxy is a single-block method with these properties:
    /// - apart from `Return`, `Nop`, `Phi`, `LoadArg`, `LoadLocal` and `Copy`, it contains
    ///   exactly one `Call`/`CallVirt`;
    /// - every argument of that call traces back (via `find_argument_origin`) to one of
    ///   the proxy's own arguments;
    /// - every value-returning `Return` returns exactly that call's result.
    ///
    /// `arg_mapping[i]` is the proxy argument index passed as the target's `i`-th argument.
    /// `is_virtual` reports whether the inner call was a `CallVirt`.
    fn detect_proxy_pattern(ssa: &SsaFunction) -> Option<(MethodRef, Vec<usize>, bool)> {
        if ssa.blocks().len() != 1 {
            return None;
        }
        let block = ssa.blocks().first()?;
        let instructions = block.instructions();
        let mut call_info: Option<(&MethodRef, &[SsaVarId], Option<SsaVarId>, bool)> = None;
        let mut call_count = 0;
        for instr in instructions {
            match instr.op() {
                SsaOp::Call { method, args, dest } => {
                    call_count += 1;
                    call_info = Some((method, args, *dest, false));
                }
                SsaOp::CallVirt { method, args, dest } => {
                    call_count += 1;
                    call_info = Some((method, args, *dest, true));
                }
                // Plumbing that does no real work.
                SsaOp::Return { .. }
                | SsaOp::Nop
                | SsaOp::Phi { .. }
                | SsaOp::LoadArg { .. }
                | SsaOp::LoadLocal { .. }
                | SsaOp::Copy { .. } => {}
                // Any other op is computation, so this is not a pure forwarder.
                _ => return None,
            }
        }
        if call_count != 1 {
            return None;
        }
        let (target_method, call_args, call_dest, is_virtual) = call_info?;

        let num_params = ssa.num_args();
        let mut arg_mapping = Vec::with_capacity(call_args.len());
        for &arg_var in call_args {
            match Self::find_argument_origin(ssa, arg_var, instructions) {
                Some(idx) if idx < num_params => arg_mapping.push(idx),
                _ => return None,
            }
        }

        // The proxy must return nothing or exactly the forwarded call's result.
        for instr in instructions {
            if let SsaOp::Return {
                value: Some(ret_var),
            } = instr.op()
            {
                if Some(*ret_var) != call_dest {
                    return None;
                }
            }
        }
        Some((*target_method, arg_mapping, is_virtual))
    }
}
impl SsaPass for InliningPass {
    fn name(&self) -> &'static str {
        "InliningPass"
    }

    fn description(&self) -> &'static str {
        "Inlines small, pure methods and devirtualizes proxy calls"
    }

    /// Collects every eligible call site in `ssa` and applies the chosen transformation to
    /// each. The resulting events are merged into `ctx.events`.
    ///
    /// Returns `Ok(true)` when at least one call site was rewritten.
    fn run_on_method(
        &self,
        ssa: &mut SsaFunction,
        method_token: Token,
        ctx: &CompilerContext,
        assembly: &Arc<CilObject>,
    ) -> Result<bool> {
        let mut inliner = InliningContext::new(self, ssa, method_token, ctx, assembly);
        let pending = inliner.find_candidates();
        if pending.is_empty() {
            return Ok(false);
        }
        // Work back-to-front. Full inlining inserts instructions after its call site, which
        // would shift the indices of any later candidate in the same block.
        for candidate in pending.iter().rev() {
            inliner.process_candidate(candidate);
        }
        if !inliner.has_changes() {
            return Ok(false);
        }
        ctx.events.merge(&inliner.into_changes());
        Ok(true)
    }
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::{
analysis::{CallGraph, ConstValue, MethodRef, SsaFunctionBuilder, SsaOp, SsaVarId},
compiler::CompilerContext,
compiler::{
passes::inlining::{InliningContext, InliningPass},
SsaPass,
},
metadata::token::Token,
test::helpers::test_assembly_arc,
CilObject,
};
/// Builds a compiler context backed by an empty call graph.
fn test_context() -> CompilerContext {
CompilerContext::new(Arc::new(CallGraph::new()))
}
/// Shared test assembly, used only for MemberRef resolution.
fn test_assembly() -> Arc<CilObject> {
test_assembly_arc()
}
/// The constructor stores the threshold and the proxy-only flag verbatim.
#[test]
fn test_pass_creation() {
let pass = InliningPass::new(20, false);
assert_eq!(pass.name(), "InliningPass");
assert_eq!(pass.inline_threshold, 20);
assert!(!pass.proxy_only);
let pass_custom = InliningPass::new(50, false);
assert_eq!(pass_custom.inline_threshold, 50);
assert!(!pass_custom.proxy_only);
let pass_proxy = InliningPass::new(0, true);
assert_eq!(pass_proxy.inline_threshold, 0);
assert!(pass_proxy.proxy_only);
}
/// A callee that returns a constant is inlined as a `Const` into the call's destination.
#[test]
fn test_inline_constant_return() {
let callee_token = Token::new(0x06000002);
// Callee: `return 42;`
let (callee_ssa, callee_v0) = {
let mut v0_out = SsaVarId::new();
let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
f.block(0, |b| {
let v0 = b.const_i32(42);
v0_out = v0;
b.ret_val(v0);
});
});
(ssa, v0_out)
};
let caller_token = Token::new(0x06000001);
// Caller: `return callee();`
let (mut caller_ssa, call_dest) = {
let mut dest_out = SsaVarId::new();
let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
f.block(0, |b| {
let dest = b.call(MethodRef::new(callee_token), &[]);
dest_out = dest;
b.ret_val(dest);
});
});
(ssa, dest_out)
};
let ctx = test_context();
ctx.set_ssa(callee_token, callee_ssa.clone());
let call_op = caller_ssa.block(0).unwrap().instructions()[0].op().clone();
let pass = InliningPass::new(20, false);
let assembly = test_assembly();
let mut inline_ctx =
InliningContext::new(&pass, &mut caller_ssa, caller_token, &ctx, &assembly);
let result = inline_ctx.inline_call(&callee_ssa, 0, 0, &call_op, callee_token);
assert!(result, "Inlining should succeed");
let block = inline_ctx.caller_ssa.block(0).unwrap();
let first_instr = &block.instructions()[0];
match first_instr.op() {
SsaOp::Const { dest, value } => {
assert_eq!(*dest, call_dest);
assert_eq!(*value, ConstValue::I32(42));
}
other => panic!("Expected Const, got {:?}", other),
}
let _ = callee_v0;
}
/// A call to an empty void callee is replaced by a `Nop`.
#[test]
fn test_inline_void_pure() {
let callee_token = Token::new(0x06000002);
let callee_ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
f.block(0, |b| b.ret());
});
let caller_token = Token::new(0x06000001);
let mut caller_ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
f.block(0, |b| {
b.call_void(MethodRef::new(callee_token), &[]);
b.ret();
});
});
let ctx = test_context();
ctx.set_ssa(callee_token, callee_ssa.clone());
let call_op = caller_ssa.block(0).unwrap().instructions()[0].op().clone();
let pass = InliningPass::new(20, false);
let assembly = test_assembly();
let mut inline_ctx =
InliningContext::new(&pass, &mut caller_ssa, caller_token, &ctx, &assembly);
let result = inline_ctx.inline_call(&callee_ssa, 0, 0, &call_op, callee_token);
assert!(result, "Inlining should succeed");
let block = inline_ctx.caller_ssa.block(0).unwrap();
let first_instr = &block.instructions()[0];
assert!(
matches!(first_instr.op(), SsaOp::Nop),
"Expected Nop, got {:?}",
first_instr.op()
);
}
/// A method is never a valid inline target for itself.
#[test]
fn test_no_inline_self_recursion() {
let pass = InliningPass::new(20, false);
let token = Token::new(0x06000001);
let ctx = test_context();
let assembly = test_assembly();
let mut dummy_ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
f.block(0, |b| b.ret());
});
let inline_ctx = InliningContext::new(&pass, &mut dummy_ssa, token, &ctx, &assembly);
assert!(!inline_ctx.is_valid_target(token));
}
/// A callee with more instructions than the threshold (30 consts vs. 20) is not inlined.
#[test]
fn test_no_inline_large_method() {
let callee_token = Token::new(0x06000002);
let caller_token = Token::new(0x06000001);
let callee_ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
f.block(0, |b| {
for _ in 0..30 {
let _ = b.const_i32(0);
}
b.ret();
});
});
let ctx = test_context();
ctx.set_ssa(callee_token, callee_ssa);
let pass = InliningPass::new(20, false);
let assembly = test_assembly();
let mut caller_ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
f.block(0, |b| b.ret());
});
let inline_ctx =
InliningContext::new(&pass, &mut caller_ssa, caller_token, &ctx, &assembly);
assert!(!inline_ctx.should_inline(callee_token));
}
/// A callee computing `arg0 + 10` is spliced into the caller in place of the call.
///
/// NOTE(review): this only checks that the inlined `Const`/`Add` appear. It does not check
/// that the copy into the call's destination comes after the `Add`.
#[test]
fn test_inline_full_computation() {
let callee_token = Token::new(0x06000002);
let callee_ssa = SsaFunctionBuilder::new(1, 0).build_with(|f| {
let param0 = f.arg(0);
f.block(0, |b| {
let v1 = b.const_i32(10);
let v2 = b.add(param0, v1);
b.ret_val(v2);
});
});
let caller_token = Token::new(0x06000001);
let mut caller_ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
f.block(0, |b| {
let v0 = b.const_i32(5);
let v1 = b.call(MethodRef::new(callee_token), &[v0]);
b.ret_val(v1);
});
});
let ctx = test_context();
ctx.set_ssa(callee_token, callee_ssa.clone());
// The call is the second instruction of the caller's block.
let call_op = caller_ssa.block(0).unwrap().instructions()[1].op().clone();
let pass = InliningPass::new(20, false);
let assembly = test_assembly();
let mut inline_ctx =
InliningContext::new(&pass, &mut caller_ssa, caller_token, &ctx, &assembly);
let result = inline_ctx.inline_call(&callee_ssa, 0, 1, &call_op, callee_token);
assert!(result, "Full inlining should succeed");
let block = inline_ctx.caller_ssa.block(0).unwrap();
assert!(
block.instructions().len() > 3,
"Expected inlined instructions, got {} instructions",
block.instructions().len()
);
let second_instr = &block.instructions()[1];
match second_instr.op() {
SsaOp::Const { value, .. } => {
assert_eq!(*value, ConstValue::I32(10), "Expected inlined constant 10");
}
other => panic!("Expected Const for inlined instruction, got {:?}", other),
}
let has_add = block
.instructions()
.iter()
.any(|i| matches!(i.op(), SsaOp::Add { .. }));
assert!(has_add, "Expected Add instruction after inlining");
}
/// `void P(a) { T(a); }` is a proxy with mapping `[0]` and a non-virtual call.
#[test]
fn test_detect_proxy_void() {
let target_token = Token::new(0x0A000001);
let proxy_ssa = SsaFunctionBuilder::new(1, 0).build_with(|f| {
let param0 = f.arg(0);
f.block(0, |b| {
b.call_void(MethodRef::new(target_token), &[param0]);
b.ret();
});
});
let result = InliningPass::detect_proxy_pattern(&proxy_ssa);
assert!(result.is_some(), "Should detect void proxy");
let (target, arg_mapping, is_virtual) = result.unwrap();
assert_eq!(target.token(), target_token);
assert_eq!(arg_mapping, vec![0]); assert!(!is_virtual);
}
/// `P(a, b) { return T(a, b); }` is a proxy with mapping `[0, 1]`.
#[test]
fn test_detect_proxy_with_return() {
let target_token = Token::new(0x0A000002);
let proxy_ssa = SsaFunctionBuilder::new(2, 0).build_with(|f| {
let param0 = f.arg(0);
let param1 = f.arg(1);
f.block(0, |b| {
let result = b.call(MethodRef::new(target_token), &[param0, param1]);
b.ret_val(result);
});
});
let result = InliningPass::detect_proxy_pattern(&proxy_ssa);
assert!(result.is_some(), "Should detect proxy with return");
let (target, arg_mapping, is_virtual) = result.unwrap();
assert_eq!(target.token(), target_token);
assert_eq!(arg_mapping, vec![0, 1]);
assert!(!is_virtual);
}
/// `P(a, b) { return T(b, a); }` is a proxy whose mapping records the swap: `[1, 0]`.
#[test]
fn test_detect_proxy_reordered_args() {
let target_token = Token::new(0x0A000003);
let proxy_ssa = SsaFunctionBuilder::new(2, 0).build_with(|f| {
let param0 = f.arg(0);
let param1 = f.arg(1);
f.block(0, |b| {
let result = b.call(MethodRef::new(target_token), &[param1, param0]);
b.ret_val(result);
});
});
let result = InliningPass::detect_proxy_pattern(&proxy_ssa);
assert!(result.is_some(), "Should detect proxy with reordered args");
let (target, arg_mapping, is_virtual) = result.unwrap();
assert_eq!(target.token(), target_token);
assert_eq!(arg_mapping, vec![1, 0]); assert!(!is_virtual);
}
/// `P(a) { return T(a + 1); }` does real work before forwarding, so it is not a proxy.
#[test]
fn test_not_proxy_with_computation() {
    let target_token = Token::new(0x0A000004);
    let not_proxy_ssa = SsaFunctionBuilder::new(1, 0).build_with(|f| {
        let param0 = f.arg(0);
        f.block(0, |b| {
            let one = b.const_i32(1);
            let sum = b.add(param0, one);
            let result = b.call(MethodRef::new(target_token), &[sum]);
            b.ret_val(result);
        });
    });
    let result = InliningPass::detect_proxy_pattern(&not_proxy_ssa);
    assert!(
        result.is_none(),
        "Should NOT detect as proxy - has computation"
    );
}
/// A caller's call to proxy `P(a) { T(a); }` is rewritten into a direct call to `T`.
#[test]
fn test_devirtualize_proxy() {
let proxy_token = Token::new(0x06000002);
let target_token = Token::new(0x0A000001);
let proxy_ssa = SsaFunctionBuilder::new(1, 0).build_with(|f| {
let param0 = f.arg(0);
f.block(0, |b| {
b.call_void(MethodRef::new(target_token), &[param0]);
b.ret();
});
});
let caller_token = Token::new(0x06000001);
let mut caller_ssa = SsaFunctionBuilder::new(1, 0).build_with(|f| {
let arg0 = f.arg(0);
f.block(0, |b| {
b.call_void(MethodRef::new(proxy_token), &[arg0]);
b.ret();
});
});
let ctx = test_context();
ctx.set_ssa(proxy_token, proxy_ssa);
// Detect the proxy through the context, as `should_devirtualize_proxy` does.
let proxy_pattern = ctx
.with_ssa(proxy_token, InliningPass::detect_proxy_pattern)
.flatten();
assert!(proxy_pattern.is_some(), "Proxy should be detected");
let (target_method, arg_mapping, is_virtual) = proxy_pattern.unwrap();
let call_op = caller_ssa.block(0).unwrap().instructions()[0].op().clone();
let pass = InliningPass::new(20, false);
let assembly = test_assembly();
let mut inline_ctx =
InliningContext::new(&pass, &mut caller_ssa, caller_token, &ctx, &assembly);
let result = inline_ctx.devirtualize_proxy(
0,
0,
&call_op,
target_method,
&arg_mapping,
is_virtual,
proxy_token,
);
assert!(result, "Devirtualization should succeed");
let block = inline_ctx.caller_ssa.block(0).unwrap();
let first_instr = &block.instructions()[0];
match first_instr.op() {
SsaOp::Call { method, .. } => {
assert_eq!(
method.token(),
target_token,
"Call should now target {:?}",
target_token
);
}
other => panic!("Expected Call, got {:?}", other),
}
}
}