use super::{Transform, TransformCategory, TransformLevel};
use crate::mir::{Block, Function, Instruction, Module, Operand, Register};
use std::cell::Cell;
use std::collections::HashMap;
/// Function-local transform registered under the name `function_inlining`.
///
/// Its `apply` implementation only removes void self-calls that are directly
/// followed by a bare `ret`; its own description points at `ModuleInlining`
/// for the full, module-level implementation.
#[derive(Default)]
pub struct FunctionInlining;
impl Transform for FunctionInlining {
    /// Stable identifier of this transform.
    fn name(&self) -> &'static str {
        "function_inlining"
    }

    /// Human-readable summary of the transform.
    fn description(&self) -> &'static str {
        "Module-level function inlining (see ModuleInlining for full implementation)"
    }

    /// This transform belongs to the inlining category.
    fn category(&self) -> TransformCategory {
        TransformCategory::Inlining
    }

    /// This transform is marked experimental.
    fn level(&self) -> TransformLevel {
        TransformLevel::Experimental
    }

    /// Removes every call to `func` itself that has no return register and is
    /// immediately followed by `Ret { value: None }` in the same block.
    ///
    /// Returns `Ok(true)` when at least one instruction was removed.
    fn apply(&self, func: &mut Function) -> Result<bool, String> {
        let own_name = func.sig.name.clone();
        let mut changed = false;

        for block in &mut func.blocks {
            // Collect the positions first; removal happens afterwards so the
            // indices seen while scanning stay meaningful.
            let doomed: Vec<usize> = block
                .instructions
                .iter()
                .enumerate()
                .filter_map(|(idx, instr)| {
                    let is_void_self_call = matches!(
                        instr,
                        Instruction::Call { name, ret, .. }
                            if *name == own_name && ret.is_none()
                    );
                    let followed_by_bare_ret = matches!(
                        block.instructions.get(idx + 1),
                        Some(Instruction::Ret { value: None })
                    );
                    (is_void_self_call && followed_by_bare_ret).then_some(idx)
                })
                .collect();

            // Back-to-front removal keeps the remaining indices valid.
            for idx in doomed.into_iter().rev() {
                block.instructions.remove(idx);
                changed = true;
            }
        }

        Ok(changed)
    }
}
/// Module-level inliner that replaces call instructions with the callee's body.
pub struct ModuleInlining {
    // Counter handed out by `next_inline_id`; it is used to build unique block
    // label suffixes. `Cell` lets it advance through `&self`.
    inline_counter: Cell<usize>,
}
impl Default for ModuleInlining {
fn default() -> Self {
Self::new()
}
}
impl ModuleInlining {
/// Creates an inliner whose inline-id counter starts at zero.
pub fn new() -> Self {
    let inline_counter = Cell::new(0);
    Self { inline_counter }
}
/// Returns the current inline id and advances the counter by one.
fn next_inline_id(&self) -> usize {
    let current = self.inline_counter.get();
    // `replace` stores the incremented value and hands back the previous one.
    self.inline_counter.replace(current + 1)
}
pub fn inline_functions(&self, module: &mut Module) -> Result<usize, String> {
let mut inlined_count = 0;
const MAX_INLINE_ITERATIONS: usize = 20;
const MAX_TOTAL_INSTRUCTIONS: usize = 50_000;
let total_instructions: usize = module
.functions
.values()
.map(|f| f.blocks.iter().map(|b| b.instructions.len()).sum::<usize>())
.sum();
if total_instructions > MAX_TOTAL_INSTRUCTIONS {
return Err(format!(
"Module too large for inlining ({} instructions, max {})",
total_instructions, MAX_TOTAL_INSTRUCTIONS
));
}
let mut call_sites = Vec::new();
for (func_name, func) in &module.functions {
for block in &func.blocks {
for (instr_idx, instr) in block.instructions.iter().enumerate() {
if let Instruction::Call { name, .. } = instr {
call_sites.push(CallSite {
caller: func_name.clone(),
callee: name.clone(),
block_label: block.label.clone(),
instr_idx,
});
}
}
}
}
let mut iterations = 0;
while iterations < MAX_INLINE_ITERATIONS {
let mut made_progress = false;
let mut new_call_sites = Vec::new();
for (func_name, func) in &module.functions {
for block in &func.blocks {
for (instr_idx, instr) in block.instructions.iter().enumerate() {
if let Instruction::Call { name, .. } = instr {
new_call_sites.push(CallSite {
caller: func_name.clone(),
callee: name.clone(),
block_label: block.label.clone(),
instr_idx,
});
}
}
}
}
for call_site in new_call_sites {
if self.should_inline(&call_site, module) {
match self.perform_inline(&call_site, module) {
Ok(()) => {
inlined_count += 1;
made_progress = true;
}
Err(_e) => {
continue;
}
}
}
}
if !made_progress {
break; }
iterations += 1;
}
if iterations >= MAX_INLINE_ITERATIONS {
return Err(format!(
"Inlining did not converge after {} iterations",
MAX_INLINE_ITERATIONS
));
}
Ok(inlined_count)
}
/// Decides whether the callee of `call_site` is eligible for inlining.
///
/// A site qualifies only when all of the following hold:
/// * caller and callee are different functions;
/// * the callee exists in `module`;
/// * the callee contains no call to itself and no `Switch` instruction;
/// * the callee has at most 20 blocks and at most 30 instructions.
///
/// The earlier layered thresholds (50 / 100 instructions, 30 when the callee
/// contains calls) were all subsumed by the final 30-instruction limit, so
/// only the effective limits are checked here.
fn should_inline(&self, call_site: &CallSite, module: &Module) -> bool {
    if call_site.callee == call_site.caller {
        return false;
    }
    let Some(callee_func) = module.functions.get(&call_site.callee) else {
        return false;
    };

    // Directly recursive callees and callees using `Switch` are never inlined.
    let disqualified = callee_func
        .blocks
        .iter()
        .flat_map(|block| block.instructions.iter())
        .any(|instr| match instr {
            Instruction::Call { name, .. } => name == &call_site.callee,
            Instruction::Switch { .. } => true,
            _ => false,
        });
    if disqualified {
        return false;
    }

    callee_func.blocks.len() <= 20 && callee_func.instruction_count() <= 30
}
/// Inlines the callee of `call_site` into its caller.
///
/// The callee is cloned up front, the call's arguments are paired with the
/// callee's parameters, and the work is then delegated to the single-block or
/// multi-block strategy depending on the callee's block count.
///
/// # Errors
/// Returns an error when the caller, callee or call block cannot be found,
/// when `instr_idx` is out of bounds, when the instruction at that index is
/// not a call to `call_site.callee` (for example because the site is stale),
/// when the argument count does not match the callee's parameters, or when the
/// selected strategy fails.
fn perform_inline(&self, call_site: &CallSite, module: &mut Module) -> Result<(), String> {
    let callee_func = module
        .functions
        .get(&call_site.callee)
        .ok_or_else(|| format!("Callee function '{}' not found", call_site.callee))?
        .clone();

    let call_args = {
        let caller_func = module
            .functions
            .get(&call_site.caller)
            .ok_or_else(|| format!("Caller function '{}' not found", call_site.caller))?;
        let call_block = caller_func
            .blocks
            .iter()
            .find(|b| b.label == call_site.block_label)
            .ok_or_else(|| format!("Call block '{}' not found", call_site.block_label))?;

        match call_block.instructions.get(call_site.instr_idx) {
            None => return Err("Call instruction index out of bounds".to_string()),
            Some(Instruction::Call { name, args, .. }) if *name == call_site.callee => {
                args.clone()
            }
            // The index points at a call to some other function: the recorded
            // site no longer describes this instruction, so refuse to splice
            // the wrong body in.
            Some(Instruction::Call { name, .. }) => {
                return Err(format!(
                    "Call site targets '{}' but '{}' was expected",
                    name, call_site.callee
                ));
            }
            Some(_) => return Err("Expected call instruction".to_string()),
        }
    };

    let param_mapping = self.create_param_mapping(&callee_func, &call_args)?;
    if callee_func.blocks.len() == 1 {
        self.inline_single_block_function(call_site, module, &callee_func, &param_mapping)
    } else {
        self.inline_multi_block_function(call_site, module, &callee_func, &param_mapping)
    }
}
fn inline_single_block_function(
&self,
call_site: &CallSite,
module: &mut Module,
callee_func: &Function,
param_mapping: &HashMap<Register, Operand>,
) -> Result<(), String> {
let (call_result_reg, _call_instr) = {
let caller_func = module
.functions
.get(&call_site.caller)
.ok_or_else(|| format!("Caller function '{}' not found", call_site.caller))?;
let call_block = caller_func
.blocks
.iter()
.find(|b| b.label == call_site.block_label)
.ok_or_else(|| {
format!(
"Block '{}' not found in caller function '{}'",
call_site.block_label, call_site.caller
)
})?;
let call_instr = call_block.instructions[call_site.instr_idx].clone();
let call_result_reg = if let Instruction::Call { ret, .. } = &call_instr {
ret.clone()
} else {
return Err("Expected call instruction".to_string());
};
(call_result_reg, call_instr)
};
let callee_block = callee_func
.blocks
.first()
.ok_or_else(|| "Callee function has no blocks".to_string())?;
let mut inlined_instructions = Vec::new();
let caller_func = module
.functions
.get(&call_site.caller)
.ok_or_else(|| format!("Caller function '{}' not found", call_site.caller))?;
for instr in &callee_block.instructions {
let mut new_instr = instr.clone();
if let Instruction::Ret { value } = &new_instr
&& let Some(ret_val) = value
{
if let Some(ref result_reg) = call_result_reg {
let return_type = *callee_func.sig.ret_ty.as_ref().ok_or_else(|| {
"Function has return value but no return type in signature".to_string()
})?;
let mut assign_instr = Instruction::IntBinary {
op: crate::mir::IntBinOp::Add,
dst: result_reg.clone(),
ty: return_type,
lhs: ret_val.clone(),
rhs: Operand::Immediate(crate::mir::Immediate::I64(0)),
};
self.substitute_parameters_and_rename(
&mut assign_instr,
param_mapping,
caller_func,
)?;
if let Instruction::IntBinary { dst, .. } = &mut assign_instr {
*dst = result_reg.clone();
}
inlined_instructions.push(assign_instr);
}
continue;
}
self.substitute_parameters_and_rename(&mut new_instr, param_mapping, caller_func)?;
inlined_instructions.push(new_instr);
}
let caller_func = module
.functions
.get_mut(&call_site.caller)
.ok_or_else(|| format!("Caller function '{}' not found", call_site.caller))?;
let call_block = caller_func
.blocks
.iter_mut()
.find(|b| b.label == call_site.block_label)
.ok_or_else(|| {
format!(
"Block '{}' not found in caller function '{}'",
call_site.block_label, call_site.caller
)
})?;
call_block.instructions.splice(
call_site.instr_idx..=call_site.instr_idx,
inlined_instructions,
);
Ok(())
}
/// Rewrites the registers of one callee instruction so it can live in
/// `caller_func`.
///
/// * A register that is a callee parameter bound to a caller *register* is
///   replaced by that caller register.
/// * Every other virtual register with id `k` becomes the virtual register
///   `caller_max + 1 + k` of the same class, where `caller_max` is the highest
///   virtual register id appearing in `caller_func`. The mapping depends only
///   on the register itself, so the same callee register receives the same
///   name in every instruction of one inlining operation.
/// * Non-virtual registers are left unchanged.
///
/// NOTE(review): a parameter bound to a non-register operand (e.g. an
/// immediate) is renamed like a local register and nothing here materializes
/// its value — confirm how such arguments should be handled.
///
/// # Errors
/// Returns an error when a renamed register id would overflow `u32`.
fn substitute_parameters_and_rename(
    &self,
    instr: &mut Instruction,
    param_mapping: &HashMap<Register, Operand>,
    caller_func: &Function,
) -> Result<(), String> {
    // Highest virtual register id used by the caller, across all classes, so
    // renamed registers cannot collide with existing ones.
    let mut caller_max: u32 = 0;
    for block in &caller_func.blocks {
        for existing in &block.instructions {
            if let Some(Register::Virtual(vreg)) = existing.def_reg() {
                caller_max = caller_max.max(vreg.id);
            }
            for used in existing.use_regs() {
                if let Register::Virtual(vreg) = used {
                    caller_max = caller_max.max(vreg.id);
                }
            }
        }
    }
    let base = caller_max
        .checked_add(1)
        .ok_or_else(|| "Virtual register ids exhausted while inlining".to_string())?;

    let mut overflowed = false;
    let mut map_register = |reg: &Register| -> Register {
        if let Some(Operand::Register(arg_reg)) = param_mapping.get(reg) {
            return arg_reg.clone();
        }
        match reg {
            Register::Virtual(vreg) => match base.checked_add(vreg.id) {
                Some(new_id) => {
                    let mut renamed = vreg.clone();
                    renamed.id = new_id;
                    Register::Virtual(renamed)
                }
                None => {
                    overflowed = true;
                    reg.clone()
                }
            },
            _ => reg.clone(),
        }
    };
    self.map_instruction_registers(instr, &mut map_register);

    if overflowed {
        return Err("Virtual register ids exhausted while inlining".to_string());
    }
    Ok(())
}
fn find_max_register_id(&self, func: &Function) -> u32 {
let mut max_id = 0;
for block in &func.blocks {
for instr in &block.instructions {
if let Some(reg) = instr.def_reg()
&& let Register::Virtual(vreg) = reg
&& vreg.class == crate::mir::RegisterClass::Gpr
{
max_id = max_id.max(vreg.id);
}
for use_reg in instr.use_regs() {
if let Register::Virtual(vreg) = use_reg
&& vreg.class == crate::mir::RegisterClass::Gpr
{
max_id = max_id.max(vreg.id);
}
}
}
}
max_id
}
/// Applies `map_reg` to every register held by `instr`, destination first.
///
/// Covers arithmetic, comparison, select, memory, vector, call, branch,
/// switch and value-returning `Ret` instructions. Instructions that hold no
/// registers in these positions (such as `Jmp`) are left untouched.
fn map_instruction_registers<F>(&self, instr: &mut Instruction, map_reg: &mut F)
where
    F: FnMut(&Register) -> Register,
{
    /// Renames the base (and index, if present) of an address expression.
    fn map_address<G>(addr: &mut crate::mir::AddressMode, map_reg: &mut G)
    where
        G: FnMut(&Register) -> Register,
    {
        match addr {
            crate::mir::AddressMode::BaseOffset { base, .. } => {
                *base = map_reg(base);
            }
            crate::mir::AddressMode::BaseIndexScale { base, index, .. } => {
                *base = map_reg(base);
                *index = map_reg(index);
            }
        }
    }

    match instr {
        Instruction::IntBinary { dst, lhs, rhs, .. }
        | Instruction::FloatBinary { dst, lhs, rhs, .. }
        | Instruction::IntCmp { dst, lhs, rhs, .. }
        | Instruction::FloatCmp { dst, lhs, rhs, .. } => {
            *dst = map_reg(dst);
            self.map_operand_register(lhs, map_reg);
            self.map_operand_register(rhs, map_reg);
        }
        Instruction::FloatUnary { dst, src, .. } => {
            *dst = map_reg(dst);
            self.map_operand_register(src, map_reg);
        }
        Instruction::Select {
            dst,
            cond,
            true_val,
            false_val,
            ..
        } => {
            *dst = map_reg(dst);
            *cond = map_reg(cond);
            self.map_operand_register(true_val, map_reg);
            self.map_operand_register(false_val, map_reg);
        }
        Instruction::Load { dst, addr, .. } => {
            *dst = map_reg(dst);
            map_address(addr, map_reg);
        }
        Instruction::Store { src, addr, .. } => {
            self.map_operand_register(src, map_reg);
            map_address(addr, map_reg);
        }
        Instruction::Lea { dst, base, .. } => {
            *dst = map_reg(dst);
            *base = map_reg(base);
        }
        Instruction::VectorOp { dst, operands, .. } => {
            *dst = map_reg(dst);
            for operand in operands {
                self.map_operand_register(operand, map_reg);
            }
        }
        // Calls nested in an inlined body carry registers too: without this
        // arm their arguments and result would keep the callee's names.
        Instruction::Call { args, ret, .. } => {
            if let Some(ret_reg) = ret {
                *ret_reg = map_reg(ret_reg);
            }
            for arg in args {
                self.map_operand_register(arg, map_reg);
            }
        }
        Instruction::Br { cond, .. } => {
            *cond = map_reg(cond);
        }
        Instruction::Switch { value, .. } => {
            *value = map_reg(value);
        }
        Instruction::Ret { value: Some(val) } => {
            self.map_operand_register(val, map_reg);
        }
        _ => {}
    }
}
/// Applies `map_reg` to `operand` when it is a register; any other operand
/// kind is left as it is.
fn map_operand_register<F>(&self, operand: &mut Operand, map_reg: &mut F)
where
    F: FnMut(&Register) -> Register,
{
    let Operand::Register(reg) = operand else {
        return;
    };
    *reg = map_reg(reg);
}
/// Pairs each parameter register of `callee_func` with the argument passed at
/// the same position.
///
/// # Errors
/// Returns an error when the number of arguments differs from the number of
/// parameters in the callee's signature.
fn create_param_mapping(
    &self,
    callee_func: &Function,
    call_args: &[Operand],
) -> Result<HashMap<Register, Operand>, String> {
    let params = &callee_func.sig.params;
    if params.len() != call_args.len() {
        return Err(format!(
            "Parameter count mismatch: expected {}, got {}",
            params.len(),
            call_args.len()
        ));
    }

    let mapping = params
        .iter()
        .zip(call_args)
        .map(|(param, arg)| (param.reg.clone(), arg.clone()))
        .collect();
    Ok(mapping)
}
fn clone_and_rename_blocks(
&self,
blocks: &[Block],
param_mapping: &HashMap<Register, Operand>,
suffix: &str,
caller_max_reg: u32,
) -> Result<Vec<Block>, String> {
let mut renamed_blocks = Vec::new();
let mut register_mapping = HashMap::new();
let base_reg_offset = caller_max_reg as usize + 1;
for block in blocks {
for instr in &block.instructions {
if let Some(dst) = instr.def_reg()
&& !register_mapping.contains_key(dst)
{
let new_reg = Register::Virtual(crate::mir::VirtualReg::gpr(
(base_reg_offset + register_mapping.len()) as u32,
));
register_mapping.insert(dst.clone(), new_reg);
}
for use_reg in instr.use_regs() {
if !register_mapping.contains_key(use_reg)
&& !param_mapping.contains_key(use_reg)
{
let new_reg = Register::Virtual(crate::mir::VirtualReg::gpr(
(base_reg_offset + register_mapping.len()) as u32,
));
register_mapping.insert(use_reg.clone(), new_reg);
}
}
}
}
for block in blocks {
let mut new_block = Block::new(format!("{}{}", block.label, suffix));
for instr in &block.instructions {
let mut new_instr = instr.clone();
if let Some(dst) = new_instr.def_reg()
&& let Some(new_dst) = register_mapping.get(dst)
{
self.rename_instruction_dst(&mut new_instr, new_dst.clone());
}
self.rename_instruction_uses(&mut new_instr, ®ister_mapping, param_mapping)?;
match &mut new_instr {
Instruction::Jmp { target } => {
*target = format!("{}{}", target, suffix);
}
Instruction::Br {
true_target,
false_target,
..
} => {
*true_target = format!("{}{}", true_target, suffix);
*false_target = format!("{}{}", false_target, suffix);
}
_ => {}
}
new_block.push(new_instr);
}
renamed_blocks.push(new_block);
}
Ok(renamed_blocks)
}
/// Inlines a callee with several blocks by splicing renamed copies of its
/// blocks into the caller.
///
/// The call block is cut at the call: the call is replaced by a jump to the
/// callee's entry copy (the block labelled `entry` + suffix when present,
/// otherwise the first copied block), and the instructions after the call move
/// to a new `<block>_split_<id>` block. Every copied block that ends in `Ret`
/// instead jumps to the split block, first copying the returned value into the
/// call's result register when both exist. The copied blocks are inserted
/// right after the call block, followed by the split block.
///
/// The return-value copy is typed with the callee's declared return type and
/// falls back to I64 when the signature declares none.
///
/// # Errors
/// Returns an error when block renaming fails, the callee has no blocks, the
/// caller or call block cannot be found, or the instruction at the call index
/// is missing or is not a call. All of these are detected before the caller is
/// modified.
fn inline_multi_block_function(
    &self,
    call_site: &CallSite,
    module: &mut Module,
    callee_func: &Function,
    param_mapping: &HashMap<Register, Operand>,
) -> Result<(), String> {
    let inline_id = self.next_inline_id();
    let suffix = format!("_inline_{}_{}", call_site.callee, inline_id);

    let caller_max_reg = module
        .functions
        .get(&call_site.caller)
        .map(|f| self.find_max_register_id(f))
        .unwrap_or(0);
    let mut inlined_blocks = self.clone_and_rename_blocks(
        &callee_func.blocks,
        param_mapping,
        &suffix,
        caller_max_reg,
    )?;
    if inlined_blocks.is_empty() {
        return Err("Callee has no blocks".to_string());
    }

    let return_ty = callee_func
        .sig
        .ret_ty
        .unwrap_or(crate::mir::MirType::Scalar(crate::mir::ScalarType::I64));

    let caller_func = module
        .functions
        .get_mut(&call_site.caller)
        .ok_or_else(|| format!("Caller function '{}' not found", call_site.caller))?;
    let call_block_idx = caller_func
        .blocks
        .iter()
        .position(|b| b.label == call_site.block_label)
        .ok_or_else(|| "Call block not found".to_string())?;
    let call_block = &mut caller_func.blocks[call_block_idx];

    let ret_reg = match call_block.instructions.get(call_site.instr_idx) {
        Some(Instruction::Call { ret, .. }) => ret.clone(),
        _ => return Err("Expected Call instruction".to_string()),
    };

    // Cut the block at the call: the tail moves to the split block and the
    // call itself is dropped.
    let post_call_instrs = call_block.instructions.split_off(call_site.instr_idx + 1);
    call_block.instructions.pop();

    let split_label = format!("{}_split_{}", call_site.block_label, inline_id);
    let mut split_block = Block::new(split_label.clone());
    split_block.instructions.extend(post_call_instrs);

    let expected_entry = format!("entry{}", suffix);
    let callee_entry_target = inlined_blocks
        .iter()
        .find(|b| b.label == expected_entry)
        .map(|b| b.label.clone())
        .unwrap_or_else(|| inlined_blocks[0].label.clone());
    call_block.instructions.push(Instruction::Jmp {
        target: callee_entry_target,
    });

    // Turn every trailing `Ret` into "store result, jump to the split block".
    for block in &mut inlined_blocks {
        if !matches!(block.instructions.last(), Some(Instruction::Ret { .. })) {
            continue;
        }
        let Some(Instruction::Ret { value }) = block.instructions.pop() else {
            continue;
        };
        if let (Some(val), Some(dst)) = (value, &ret_reg) {
            block.instructions.push(Instruction::IntBinary {
                op: crate::mir::IntBinOp::Add,
                ty: return_ty,
                dst: dst.clone(),
                lhs: val,
                rhs: Operand::Immediate(crate::mir::Immediate::I64(0)),
            });
        }
        block.instructions.push(Instruction::Jmp {
            target: split_label.clone(),
        });
    }

    // Layout: call block, copied callee blocks in order, then the split block.
    let insert_at = call_block_idx + 1;
    caller_func.blocks.splice(
        insert_at..insert_at,
        inlined_blocks
            .into_iter()
            .chain(std::iter::once(split_block)),
    );
    Ok(())
}
/// Overwrites the destination register of `instr` with `new_dst`.
///
/// For a `Call` the result register is set to `Some(new_dst)`. Instructions
/// without a destination in the handled set are left unchanged.
fn rename_instruction_dst(&self, instr: &mut Instruction, new_dst: Register) {
    let slot: &mut Register = match instr {
        Instruction::Call { ret, .. } => {
            *ret = Some(new_dst);
            return;
        }
        Instruction::IntBinary { dst, .. }
        | Instruction::FloatBinary { dst, .. }
        | Instruction::FloatUnary { dst, .. }
        | Instruction::IntCmp { dst, .. }
        | Instruction::FloatCmp { dst, .. }
        | Instruction::Select { dst, .. }
        | Instruction::Load { dst, .. }
        | Instruction::Lea { dst, .. }
        | Instruction::VectorOp { dst, .. } => dst,
        _ => return,
    };
    *slot = new_dst;
}
/// Rewrites every register that `instr` reads.
///
/// Operand positions go through `map_operand` (parameter substitution first,
/// then the register map); register-only positions go through `map_register`
/// (register map only). Destinations are not touched here.
fn rename_instruction_uses(
    &self,
    instr: &mut Instruction,
    register_mapping: &HashMap<Register, Register>,
    param_mapping: &HashMap<Register, Operand>,
) -> Result<(), String> {
    match instr {
        // Control flow and returns.
        Instruction::Ret {
            value: Some(returned),
        } => {
            *returned = self.map_operand(returned, register_mapping, param_mapping)?;
        }
        Instruction::Br { cond, .. } => {
            *cond = self.map_register(cond, register_mapping)?;
        }
        Instruction::Switch { value, .. } => {
            *value = self.map_register(value, register_mapping)?;
        }
        Instruction::Call { args, .. } => {
            for arg in args.iter_mut() {
                *arg = self.map_operand(arg, register_mapping, param_mapping)?;
            }
        }

        // Memory access.
        Instruction::Load { addr, .. } => {
            *addr = self.map_address_mode(addr, register_mapping)?;
        }
        Instruction::Store { src, addr, .. } => {
            *src = self.map_operand(src, register_mapping, param_mapping)?;
            *addr = self.map_address_mode(addr, register_mapping)?;
        }
        Instruction::Lea { base, .. } => {
            *base = self.map_register(base, register_mapping)?;
        }

        // Computation.
        Instruction::IntBinary {
            lhs: left,
            rhs: right,
            ..
        }
        | Instruction::FloatBinary {
            lhs: left,
            rhs: right,
            ..
        }
        | Instruction::IntCmp {
            lhs: left,
            rhs: right,
            ..
        }
        | Instruction::FloatCmp {
            lhs: left,
            rhs: right,
            ..
        } => {
            *left = self.map_operand(left, register_mapping, param_mapping)?;
            *right = self.map_operand(right, register_mapping, param_mapping)?;
        }
        Instruction::FloatUnary { src, .. } => {
            *src = self.map_operand(src, register_mapping, param_mapping)?;
        }
        Instruction::Select {
            cond,
            true_val,
            false_val,
            ..
        } => {
            *cond = self.map_register(cond, register_mapping)?;
            *true_val = self.map_operand(true_val, register_mapping, param_mapping)?;
            *false_val = self.map_operand(false_val, register_mapping, param_mapping)?;
        }
        Instruction::VectorOp { operands, .. } => {
            for operand in operands.iter_mut() {
                *operand = self.map_operand(operand, register_mapping, param_mapping)?;
            }
        }

        _ => {}
    }
    Ok(())
}
/// Translates one operand of a cloned callee instruction.
///
/// A register operand resolves to the argument bound to it in `param_mapping`
/// when there is one, otherwise to its entry in `register_mapping`, otherwise
/// to itself. Non-register operands are returned unchanged.
fn map_operand(
    &self,
    operand: &Operand,
    register_mapping: &HashMap<Register, Register>,
    param_mapping: &HashMap<Register, Operand>,
) -> Result<Operand, String> {
    let Operand::Register(reg) = operand else {
        return Ok(operand.clone());
    };

    let translated = param_mapping
        .get(reg)
        .cloned()
        .or_else(|| register_mapping.get(reg).cloned().map(Operand::Register))
        .unwrap_or_else(|| operand.clone());
    Ok(translated)
}
/// Looks `reg` up in `register_mapping`, falling back to `reg` itself when it
/// has no entry.
fn map_register(
    &self,
    reg: &Register,
    register_mapping: &HashMap<Register, Register>,
) -> Result<Register, String> {
    let resolved = register_mapping.get(reg).unwrap_or(reg);
    Ok(resolved.clone())
}
fn map_address_mode(
&self,
addr: &crate::mir::AddressMode,
register_mapping: &HashMap<Register, Register>,
) -> Result<crate::mir::AddressMode, String> {
match addr {
crate::mir::AddressMode::BaseOffset { base, offset } => {
Ok(crate::mir::AddressMode::BaseOffset {
base: self.map_register(base, register_mapping)?,
offset: *offset,
})
}
crate::mir::AddressMode::BaseIndexScale {
base,
index,
scale,
offset,
} => Ok(crate::mir::AddressMode::BaseIndexScale {
base: self.map_register(base, register_mapping)?,
index: self.map_register(index, register_mapping)?,
scale: *scale,
offset: *offset,
}),
}
}
}
/// Location of one `Call` instruction inside a module.
#[derive(Debug)]
struct CallSite {
    // Name of the function containing the call.
    caller: String,
    // Name of the function being called.
    callee: String,
    // Label of the block that holds the call instruction.
    block_label: String,
    // Position of the call within that block's instruction list.
    instr_idx: usize,
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
use super::*;
use crate::mir::{
FunctionBuilder, Immediate, IntBinOp, MirType, Operand, ScalarType, VirtualReg,
};
// Inlines a two-block callee into a caller and checks the resulting block
// structure. The test only asserts on structure (inline count, block count,
// absence of calls); it does not check the value computed by the inlined code.
#[test]
fn test_inline_multi_block() {
let mut module = Module::new("test_module");
// Callee: `entry` adds 1 to its parameter register and jumps to `exit`,
// which returns that register.
let mut callee = FunctionBuilder::new("callee")
.param(VirtualReg::gpr(0).into(), MirType::Scalar(ScalarType::I64))
.returns(MirType::Scalar(ScalarType::I64))
.block("entry")
.instr(Instruction::IntBinary {
op: IntBinOp::Add,
ty: MirType::Scalar(ScalarType::I64),
dst: VirtualReg::gpr(0).into(),
lhs: Operand::Register(VirtualReg::gpr(0).into()), rhs: Operand::Immediate(Immediate::I64(1)),
})
.instr(Instruction::Jmp {
target: "exit".to_string(),
})
.block("exit")
.instr(Instruction::Ret {
value: Some(Operand::Register(VirtualReg::gpr(0).into())),
})
.build();
// Explicitly pin the parameter register the inliner will map from.
callee.sig.params[0].reg = VirtualReg::gpr(0).into();
module.add_function(callee);
// Caller: a single block that calls `callee(10)` and returns the result.
let caller = FunctionBuilder::new("caller")
.returns(MirType::Scalar(ScalarType::I64))
.block("entry")
.instr(Instruction::Call {
name: "callee".to_string(),
args: vec![Operand::Immediate(Immediate::I64(10))],
ret: Some(VirtualReg::gpr(1).into()),
})
.instr(Instruction::Ret {
value: Some(Operand::Register(VirtualReg::gpr(1).into())),
})
.build();
module.add_function(caller);
let inline_pass = ModuleInlining::new();
let count = inline_pass
.inline_functions(&mut module)
.expect("Inlining failed");
assert!(count > 0, "Should have inlined 1 function");
let caller = module.functions.get("caller").unwrap();
// Call block + at least one copied callee block + split block.
assert!(
caller.blocks.len() >= 3,
"Expected at least 3 blocks after inlining, got {}",
caller.blocks.len()
);
let has_call = caller.blocks.iter().any(|b| {
b.instructions
.iter()
.any(|i| matches!(i, Instruction::Call { .. }))
});
assert!(!has_call, "Call instruction should be removed");
}
}