mod mentioned;
mod rewrite;
mod site;
mod usage;
use std::collections::{BTreeMap, BTreeSet};
use crate::hir::common::{
HirBlock, HirCallExpr, HirExpr, HirLValue, HirProto, HirStmt, HirTableField, HirTableKey,
TempId,
};
use crate::hir::promotion::{HomeSlotKey, ProtoPromotionFacts};
use crate::readability::ReadabilityOptions;
use self::mentioned::protected_temps_for_nested_stmt;
use self::rewrite::replace_temp_in_stmt;
use self::site::{expr_touches_temp, inline_site_in_stmt};
use self::usage::{
NextStmtState, TempUseScratch, TempUseSummary, collect_stmt_temp_uses, inline_candidate,
max_temp_index_in_block,
};
/// Presumably the maximum expression "complexity" allowed when inlining a temp into a
/// nested position (inferred from the name only — TODO confirm). Not referenced by the code
/// in this file; presumably read by a child module (e.g. `site`) through `super::`.
const NESTED_INLINE_MAX_COMPLEXITY: usize = 5;
/// Presumably the maximum expression "complexity" allowed when inlining a temp into a
/// control-flow head such as a loop or branch condition (inferred from the name only — TODO
/// confirm). Like the constant above, it is not referenced by the code in this file.
const CONTROL_HEAD_INLINE_MAX_COMPLEXITY: usize = 5;
/// Runs temp inlining over the whole body of `proto`, guided by its promotion `facts`.
///
/// The scratch tables are sized to cover every temp index that appears either in
/// `proto.temps` or anywhere inside `proto.body`. The top-level body starts with no protected
/// temps and no captured home slots. Returns `true` if the body was rewritten.
pub(super) fn inline_temps_in_proto_with_facts(
    proto: &mut HirProto,
    readability: ReadabilityOptions,
    facts: &ProtoPromotionFacts,
) -> bool {
    let declared_max = proto.temps.iter().map(|temp| temp.index()).max();
    let referenced_max = max_temp_index_in_block(&proto.body);
    // One slot per index up to the largest seen; zero slots if no temp exists at all.
    let temp_count = [declared_max, referenced_max]
        .into_iter()
        .flatten()
        .max()
        .map_or(0, |max_index| max_index + 1);
    let mut scratch = TempUseScratch::new(proto, temp_count);
    let no_protected_temps = BTreeSet::new();
    let no_captured_slots = BTreeSet::new();
    inline_temps_in_block(
        &mut proto.body,
        &mut scratch,
        readability,
        facts,
        &no_protected_temps,
        &no_captured_slots,
    )
}
/// Inlines single-use temps in `block`, after first recursing into each statement's nested
/// blocks.
///
/// The pass has three phases:
/// 1. A forward scan that recurses into every statement's nested blocks and records, for each
///    statement, the set of captured home slots in effect just before it.
/// 2. `inline_call_callee_across_argument_materialization`, which folds a callee temp plus the
///    argument temps defined right after it into the call statement that consumes them.
/// 3. A reverse scan that substitutes a temp's defining value into the next kept statement
///    when that statement holds the temp's only use, and then drops the definition.
///
/// `protected_temps` are never inlined here. `inherited_captured_slots` seeds the
/// captured-slot set coming from enclosing blocks. Returns `true` if anything was rewritten.
fn inline_temps_in_block(
    block: &mut HirBlock,
    scratch: &mut TempUseScratch,
    readability: ReadabilityOptions,
    facts: &ProtoPromotionFacts,
    protected_temps: &BTreeSet<TempId>,
    inherited_captured_slots: &BTreeSet<HomeSlotKey>,
) -> bool {
    let mut changed = false;
    // Phase 1: `captured_slots_before_stmt[i]` is the captured-slot set in effect before
    // statement `i`.
    let mut captured_slots_before_stmt = Vec::with_capacity(block.stmts.len());
    let mut active_captured_slots = inherited_captured_slots.clone();
    for index in 0..block.stmts.len() {
        captured_slots_before_stmt.push(active_captured_slots.clone());
        let nested_protected =
            protected_temps_for_nested_stmt(&block.stmts, index, protected_temps);
        // Nested blocks also see whatever `collect_prefix_captured_home_slots_in_stmt`
        // reports for this statement, on top of the slots captured so far.
        let mut nested_captured_slots = active_captured_slots.clone();
        facts.collect_prefix_captured_home_slots_in_stmt(
            &block.stmts[index],
            &mut nested_captured_slots,
        );
        let stmt = &mut block.stmts[index];
        changed |= inline_temps_in_nested_blocks(
            stmt,
            scratch,
            readability,
            facts,
            &nested_protected,
            &nested_captured_slots,
        );
        // Fold the slots captured by the (already rewritten) statement into the running set
        // seen by its successors.
        facts.collect_captured_home_slots_in_stmt(stmt, &mut active_captured_slots);
    }
    // Phase 2: callee/argument materialization folding.
    if inline_call_callee_across_argument_materialization(
        block,
        scratch,
        facts,
        protected_temps,
        &captured_slots_before_stmt,
    ) {
        changed = true;
        // Statements were removed, so rebuild the per-statement captured slots to line up
        // with the new statement indices used by phase 3.
        captured_slots_before_stmt =
            captured_slots_before_stmts(block, facts, inherited_captured_slots);
    }
    // Phase 3: walk statements back to front. `suffix_use_totals` counts temp uses in the
    // statements already kept (those after the current one). `kept_rev` holds those statements
    // in reverse order, so `kept_rev.last()` is the statement right after the current one.
    let total_use_totals = collect_block_temp_use_totals(&block.stmts, scratch);
    let mut suffix_use_totals = vec![0; scratch.temp_count()];
    let mut kept_rev = Vec::with_capacity(block.stmts.len());
    let mut next_stmt_state: Option<NextStmtState> = None;
    for (index, stmt) in std::mem::take(&mut block.stmts)
        .into_iter()
        .enumerate()
        .rev()
    {
        let stmt_uses = collect_stmt_temp_uses(&stmt, scratch);
        if let Some((temp, value)) = inline_candidate(&stmt)
            && !scratch.has_debug_local_hint(temp)
            && !protected_temps.contains(&temp)
            && !temp_rebinds_captured_slot(
                temp,
                facts,
                captured_slots_before_stmt
                    .get(index)
                    .expect("forward scan should record every statement"),
            )
            // Skip values that themselves mention `temp`.
            && !expr_touches_temp(value, temp)
            // No use of `temp` in earlier statements (computed as total - suffix - current)...
            && prefix_use_count(temp, &total_use_totals, &suffix_use_totals, &stmt_uses) == 0
            // ...exactly one use after this statement...
            && suffix_use_totals.get(temp.index()).copied().unwrap_or(0) == 1
            // ...that single use sits in the next kept statement...
            && let Some(state) = &mut next_stmt_state
            && state.temp_uses.count(temp) == 1
            // ...at a site that accepts `value` under the readability options.
            && kept_rev
                .last()
                .and_then(|next_stmt| inline_site_in_stmt(next_stmt, temp))
                .is_some_and(|site| site.allows(value, readability))
        {
            // Replace the next statement's old use counts with its post-rewrite counts so
            // `suffix_use_totals` stays exact.
            state.temp_uses.remove_from_totals(&mut suffix_use_totals);
            let next_stmt = kept_rev
                .last_mut()
                .expect("next stmt metadata must track the last kept stmt");
            replace_temp_in_stmt(next_stmt, temp, value);
            state.temp_uses = collect_stmt_temp_uses(next_stmt, scratch);
            state.temp_uses.add_to_totals(&mut suffix_use_totals);
            changed = true;
            // The defining statement is dropped: it is never pushed onto `kept_rev`.
            continue;
        }
        stmt_uses.add_to_totals(&mut suffix_use_totals);
        next_stmt_state = Some(NextStmtState {
            temp_uses: stmt_uses,
        });
        kept_rev.push(stmt);
    }
    kept_rev.reverse();
    block.stmts = kept_rev;
    changed
}
/// Returns, for each statement of `block` in order, the set of captured home slots in effect
/// just before that statement.
///
/// The running set starts from `inherited_captured_slots` and grows with every slot
/// `facts.collect_captured_home_slots_in_stmt` reports for each visited statement.
fn captured_slots_before_stmts(
    block: &HirBlock,
    facts: &ProtoPromotionFacts,
    inherited_captured_slots: &BTreeSet<HomeSlotKey>,
) -> Vec<BTreeSet<HomeSlotKey>> {
    let mut running = inherited_captured_slots.clone();
    block
        .stmts
        .iter()
        .map(|stmt| {
            // Snapshot before the statement contributes its own captures.
            let before = running.clone();
            facts.collect_captured_home_slots_in_stmt(stmt, &mut running);
            before
        })
        .collect()
}
/// Folds a callee temp, plus the argument temps defined right after it, into the call
/// statement that consumes them.
///
/// Looks for runs shaped like (each definition being a statement `inline_candidate`
/// recognizes):
///
/// ```text
/// callee_temp = <callee>
/// arg_temp_1  = <arg 1>
/// ...
/// arg_temp_n  = <arg n>
/// callee_temp(arg_temp_1, ..., arg_temp_n)   -- a `HirStmt::CallStmt`
/// ```
///
/// Every temp must pass `cross_call_inline_candidate_is_safe` and be used exactly once in
/// the block, and its value must not contain an open multi-value (`expr_has_open_multivalue`).
/// The call's callee must be exactly `callee_temp`, and its arguments must be exactly the
/// argument temps, in order. When a run matches, the values are substituted into the call and
/// the definitions are removed.
///
/// `captured_slots_before_stmt` is indexed by statement positions as they were on entry.
/// Returns `true` if at least one run was folded.
fn inline_call_callee_across_argument_materialization(
    block: &mut HirBlock,
    scratch: &mut TempUseScratch,
    facts: &ProtoPromotionFacts,
    protected_temps: &BTreeSet<TempId>,
    captured_slots_before_stmt: &[BTreeSet<HomeSlotKey>],
) -> bool {
    // Computed once. A fold removes temps whose only use is replaced, and it moves their
    // values intact, so the counts of every other temp are unchanged.
    let total_use_totals = collect_block_temp_use_totals(&block.stmts, scratch);
    let mut changed = false;
    // Number of statements drained so far. Drains only happen at the current `index`, which
    // never moves backwards, so a statement now at position `i >= index` was originally at
    // `i + removed`.
    let mut removed = 0;
    let mut index = 0;
    while index + 2 < block.stmts.len() {
        // Re-base the entry-time table so it can be indexed by current positions. If this
        // were not done after a drain, lookups would read an earlier statement's (smaller)
        // captured set and could wrongly allow an inline.
        let captured_slots = captured_slots_before_stmt.get(removed..).unwrap_or(&[]);
        let Some((callee_temp, callee_value)) = inline_candidate(&block.stmts[index]) else {
            index += 1;
            continue;
        };
        if !cross_call_inline_candidate_is_safe(
            callee_temp,
            callee_value,
            index,
            scratch,
            facts,
            protected_temps,
            captured_slots,
        ) || total_use_count(callee_temp, &total_use_totals) != 1
            || expr_has_open_multivalue(callee_value)
        {
            index += 1;
            continue;
        }
        // Gather the consecutive argument definitions up to the call statement.
        let mut arg_values = Vec::new();
        let mut arg_temps = Vec::new();
        let mut call_index = index + 1;
        while call_index < block.stmts.len() {
            if matches!(block.stmts[call_index], HirStmt::CallStmt(_)) {
                break;
            }
            let Some((arg_temp, arg_value)) = inline_candidate(&block.stmts[call_index]) else {
                break;
            };
            if !cross_call_inline_candidate_is_safe(
                arg_temp,
                arg_value,
                call_index,
                scratch,
                facts,
                protected_temps,
                captured_slots,
            ) || total_use_count(arg_temp, &total_use_totals) != 1
                || expr_has_open_multivalue(arg_value)
            {
                break;
            }
            arg_temps.push(arg_temp);
            arg_values.push(arg_value.clone());
            call_index += 1;
        }
        if arg_temps.is_empty() || call_index >= block.stmts.len() {
            index += 1;
            continue;
        }
        let HirStmt::CallStmt(call_stmt) = &block.stmts[call_index] else {
            index += 1;
            continue;
        };
        if !matches!(&call_stmt.call.callee, HirExpr::TempRef(temp) if *temp == callee_temp)
            || !call_args_are_exact_temp_refs(&call_stmt.call.args, &arg_temps)
        {
            index += 1;
            continue;
        }
        let callee_value = callee_value.clone();
        let arg_replacements = arg_temps
            .iter()
            .copied()
            .zip(arg_values)
            .collect::<BTreeMap<_, _>>();
        if let HirStmt::CallStmt(call_stmt) = &mut block.stmts[call_index] {
            call_stmt.call.callee = callee_value;
            for arg in &mut call_stmt.call.args {
                let HirExpr::TempRef(temp) = arg else {
                    continue;
                };
                if let Some(value) = arg_replacements.get(temp) {
                    *arg = value.clone();
                }
            }
        }
        // Drop the now-inlined definitions. The rewritten call moves to `index`, and the
        // loop re-examines from there.
        block.stmts.drain(index..call_index);
        removed += call_index - index;
        changed = true;
    }
    changed
}
/// Shared safety gate for temps folded by the call-callee inliner.
///
/// `temp` (defined as `value` by statement `stmt_index`) may be inlined only if all of the
/// following hold:
/// - it carries no debug-local hint;
/// - it is not protected;
/// - it does not rebind a home slot already captured before `stmt_index`;
/// - its value does not mention `temp` itself.
///
/// Panics if `captured_slots_before_stmt` has no entry for `stmt_index`.
fn cross_call_inline_candidate_is_safe(
    temp: TempId,
    value: &HirExpr,
    stmt_index: usize,
    scratch: &TempUseScratch,
    facts: &ProtoPromotionFacts,
    protected_temps: &BTreeSet<TempId>,
    captured_slots_before_stmt: &[BTreeSet<HomeSlotKey>],
) -> bool {
    if scratch.has_debug_local_hint(temp) || protected_temps.contains(&temp) {
        return false;
    }
    let captured_slots = captured_slots_before_stmt
        .get(stmt_index)
        .expect("captured slot scan should cover every statement");
    if temp_rebinds_captured_slot(temp, facts, captured_slots) {
        return false;
    }
    !expr_touches_temp(value, temp)
}
/// Looks up how many times `temp` is used according to `total_use_totals`, treating an
/// out-of-range index as zero uses.
fn total_use_count(temp: TempId, total_use_totals: &[usize]) -> usize {
    match total_use_totals.get(temp.index()) {
        Some(&count) => count,
        None => 0,
    }
}
/// Returns `true` when `args` is exactly `TempRef(expected_temps[0]), TempRef(expected_temps[1]), …`,
/// with the same length and order and no other expressions.
fn call_args_are_exact_temp_refs(args: &[HirExpr], expected_temps: &[TempId]) -> bool {
    if args.len() != expected_temps.len() {
        return false;
    }
    args.iter()
        .zip(expected_temps)
        .all(|(arg, &expected)| matches!(arg, HirExpr::TempRef(temp) if *temp == expected))
}
fn expr_has_open_multivalue(expr: &HirExpr) -> bool {
match expr {
HirExpr::VarArg => true,
HirExpr::Call(call) => {
call.multiret
|| expr_has_open_multivalue(&call.callee)
|| call.args.iter().any(expr_has_open_multivalue)
}
HirExpr::TableAccess(access) => {
expr_has_open_multivalue(&access.base) || expr_has_open_multivalue(&access.key)
}
HirExpr::Unary(unary) => expr_has_open_multivalue(&unary.expr),
HirExpr::Binary(binary) => {
expr_has_open_multivalue(&binary.lhs) || expr_has_open_multivalue(&binary.rhs)
}
HirExpr::LogicalAnd(logical) | HirExpr::LogicalOr(logical) => {
expr_has_open_multivalue(&logical.lhs) || expr_has_open_multivalue(&logical.rhs)
}
HirExpr::Decision(decision) => decision.nodes.iter().any(|node| {
expr_has_open_multivalue(&node.test)
|| decision_target_has_open_multivalue(&node.truthy)
|| decision_target_has_open_multivalue(&node.falsy)
}),
HirExpr::TableConstructor(table) => {
table.fields.iter().any(|field| match field {
HirTableField::Array(value) => expr_has_open_multivalue(value),
HirTableField::Record(field) => {
matches!(&field.key, HirTableKey::Expr(key) if expr_has_open_multivalue(key))
|| expr_has_open_multivalue(&field.value)
}
}) || table
.trailing_multivalue
.as_ref()
.is_some_and(expr_has_open_multivalue)
}
HirExpr::Closure(closure) => closure
.captures
.iter()
.any(|capture| expr_has_open_multivalue(&capture.value)),
HirExpr::Nil
| HirExpr::Boolean(_)
| HirExpr::Integer(_)
| HirExpr::Number(_)
| HirExpr::String(_)
| HirExpr::Int64(_)
| HirExpr::UInt64(_)
| HirExpr::Complex { .. }
| HirExpr::ParamRef(_)
| HirExpr::LocalRef(_)
| HirExpr::UpvalueRef(_)
| HirExpr::TempRef(_)
| HirExpr::GlobalRef(_)
| HirExpr::Unresolved(_) => false,
}
}
/// Open multi-value check for a decision-node target. Only expression targets can contain
/// one; links to other nodes and the "current value" target cannot.
fn decision_target_has_open_multivalue(target: &crate::hir::common::HirDecisionTarget) -> bool {
    use crate::hir::common::HirDecisionTarget as Target;
    match target {
        Target::Node(_) | Target::CurrentValue => false,
        Target::Expr(expr) => expr_has_open_multivalue(expr),
    }
}
/// Sums the per-temp use counts of every statement in `stmts` into a vector with one entry
/// per temp index (`scratch.temp_count()` entries).
fn collect_block_temp_use_totals(stmts: &[HirStmt], scratch: &mut TempUseScratch) -> Vec<usize> {
    let temp_count = scratch.temp_count();
    stmts.iter().fold(vec![0; temp_count], |mut totals, stmt| {
        let uses = collect_stmt_temp_uses(stmt, scratch);
        uses.add_to_totals(&mut totals);
        totals
    })
}
/// Number of uses of `temp` in the statements before the current one.
///
/// Derived as `total - suffix - current`, saturating at zero at each step. Out-of-range
/// indices count as zero.
fn prefix_use_count(
    temp: TempId,
    total_use_totals: &[usize],
    suffix_use_totals: &[usize],
    current_stmt_uses: &TempUseSummary,
) -> usize {
    let slot = temp.index();
    let total = total_use_totals.get(slot).copied().unwrap_or(0);
    let suffix = suffix_use_totals.get(slot).copied().unwrap_or(0);
    let outside_suffix = total.saturating_sub(suffix);
    outside_suffix.saturating_sub(current_stmt_uses.count(temp))
}
fn inline_temps_in_nested_blocks(
stmt: &mut HirStmt,
scratch: &mut TempUseScratch,
readability: ReadabilityOptions,
facts: &ProtoPromotionFacts,
protected_temps: &BTreeSet<TempId>,
inherited_captured_slots: &BTreeSet<HomeSlotKey>,
) -> bool {
match stmt {
HirStmt::If(if_stmt) => {
let mut changed = inline_temps_in_block(
&mut if_stmt.then_block,
scratch,
readability,
facts,
protected_temps,
inherited_captured_slots,
);
if let Some(else_block) = &mut if_stmt.else_block {
changed |= inline_temps_in_block(
else_block,
scratch,
readability,
facts,
protected_temps,
inherited_captured_slots,
);
}
changed
}
HirStmt::While(while_stmt) => inline_temps_in_block(
&mut while_stmt.body,
scratch,
readability,
facts,
protected_temps,
inherited_captured_slots,
),
HirStmt::Repeat(repeat_stmt) => {
let mut repeat_protected = protected_temps.clone();
mentioned::collect_expr_mentioned_temps(&repeat_stmt.cond, &mut repeat_protected);
inline_temps_in_block(
&mut repeat_stmt.body,
scratch,
readability,
facts,
&repeat_protected,
inherited_captured_slots,
)
}
HirStmt::NumericFor(numeric_for) => inline_temps_in_block(
&mut numeric_for.body,
scratch,
readability,
facts,
protected_temps,
inherited_captured_slots,
),
HirStmt::GenericFor(generic_for) => inline_temps_in_block(
&mut generic_for.body,
scratch,
readability,
facts,
protected_temps,
inherited_captured_slots,
),
HirStmt::Block(block) => inline_temps_in_block(
block,
scratch,
readability,
facts,
protected_temps,
inherited_captured_slots,
),
HirStmt::Unstructured(unstructured) => inline_temps_in_block(
&mut unstructured.body,
scratch,
readability,
facts,
protected_temps,
inherited_captured_slots,
),
HirStmt::LocalDecl(_)
| HirStmt::Assign(_)
| HirStmt::TableSetList(_)
| HirStmt::ErrNil(_)
| HirStmt::ToBeClosed(_)
| HirStmt::Close(_)
| HirStmt::CallStmt(_)
| HirStmt::Return(_)
| HirStmt::Break
| HirStmt::Continue
| HirStmt::Goto(_)
| HirStmt::Label(_) => false,
}
}
/// Returns `true` when `temp` has a home slot (per `facts.home_slot`) and that slot is in
/// `captured_slots`. Temps without a home slot never count as rebinding a captured slot.
fn temp_rebinds_captured_slot(
    temp: TempId,
    facts: &ProtoPromotionFacts,
    captured_slots: &BTreeSet<HomeSlotKey>,
) -> bool {
    match facts.home_slot(temp) {
        Some(slot) => captured_slots.contains(&slot),
        None => false,
    }
}