mod candidate;
mod use_sites;
use crate::readability::ReadabilityOptions;
use self::candidate::{
InlinePolicy, inline_candidate, stmt_is_adjacent_call_result_sink,
stmt_is_alias_initializer_sink, stmt_is_direct_return_value_sink,
};
use self::use_sites::rewrite_stmt_use_sites_with_policy;
use super::super::common::{AstBindingRef, AstBlock, AstModule, AstStmt};
use super::ReadabilityContext;
use super::binding_flow::BindingUseIndex;
use super::binding_tree::{
expr_references_binding, stmt_has_access_base_binding_use, stmt_has_call_callee_binding_use,
stmt_has_direct_call_arg_binding_use, stmt_has_index_binding_use, stmt_has_nested_binding_use,
stmt_has_nested_binding_value_use,
};
use super::walk::{self, AstRewritePass, BlockKind};
/// Entry point of the inline-expressions readability pass.
///
/// Builds an [`InlineExprsPass`] from `context.options` and hands it to
/// `walk::rewrite_module`, returning whatever that walk reports.
pub(super) fn apply(module: &mut AstModule, context: ReadabilityContext) -> bool {
    // This pass does not consult `target`.
    // NOTE(review): the read presumably exists only to keep the field "used".
    let _ = context.target;
    let mut pass = InlineExprsPass {
        options: context.options,
    };
    walk::rewrite_module(module, &mut pass)
}
/// Per-block rewrite pass that inlines single-use local expressions.
///
/// It carries only the readability options, which are forwarded unchanged to
/// every block it rewrites.
struct InlineExprsPass {
    options: ReadabilityOptions,
}

impl AstRewritePass for InlineExprsPass {
    /// Rewrites `block` in place. The block kind is ignored by this pass.
    fn rewrite_block(&mut self, block: &mut AstBlock, _kind: BlockKind) -> bool {
        let options = self.options;
        rewrite_current_block(block, options)
    }
}
/// Inlines single-use local bindings into the statement that immediately
/// follows them, then runs the run-collapsing sub-passes on the result.
///
/// Phase one walks adjacent pairs `(stmts[i], stmts[i + 1])`. When `stmts[i]`
/// is recognised by `inline_candidate`:
/// - An `InlinePolicy` is chosen from the shape of `stmts[i + 1]` and may be
///   upgraded for a few special sinks.
/// - The candidate's value is substituted into `stmts[i + 1]` if the binding
///   has exactly one use in `stmts[i + 1..]` (per `BindingUseIndex`) and the
///   use-site rewrite succeeds.
/// - On success, `stmts[i]` is dropped and scanning continues after the pair.
///
/// Phase two runs `collapse_adjacent_call_alias_runs`,
/// `collapse_terminal_call_result_alias_runs`,
/// `collapse_terminal_local_mechanical_runs` and
/// `collapse_adjacent_mechanical_alias_runs`, in that order.
///
/// Returns `true` if either phase changed `block.stmts`.
fn rewrite_current_block(block: &mut AstBlock, options: ReadabilityOptions) -> bool {
let mut changed = false;
let old_stmts = std::mem::take(&mut block.stmts);
// Use counts are computed once over the original statements; every index
// below refers to `old_stmts`.
let use_index = BindingUseIndex::for_stmts(&old_stmts);
let mut new_stmts = Vec::with_capacity(old_stmts.len());
let mut index = 0;
while index < old_stmts.len() {
// The last statement has no successor to inline into; keep it.
let Some(next_stmt) = old_stmts.get(index + 1) else {
new_stmts.push(old_stmts[index].clone());
index += 1;
continue;
};
// Only statements recognised by `inline_candidate` can be inlined.
let Some((candidate, value)) = inline_candidate(&old_stmts[index]) else {
new_stmts.push(old_stmts[index].clone());
index += 1;
continue;
};
// Local aliases get a policy keyed on the kind of sink that follows them;
// every other candidate uses the conservative policy.
let policy = if matches!(candidate, candidate::InlineCandidate::LocalAlias { .. })
&& stmt_is_alias_initializer_sink(next_stmt)
{
InlinePolicy::AliasInitializerChain
} else if matches!(candidate, candidate::InlineCandidate::LocalAlias { .. })
&& stmt_is_adjacent_call_result_sink(next_stmt)
{
InlinePolicy::AdjacentCallResultCallee
} else if matches!(candidate, candidate::InlineCandidate::LocalAlias { .. })
&& stmt_is_direct_return_value_sink(next_stmt)
{
InlinePolicy::DirectReturnConstructor
} else {
InlinePolicy::Conservative
};
// A lookup alias that starts a multi-statement lookup run is kept as-is here.
// NOTE(review): presumably so a later mechanical-run collapse can fold the
// whole run at once; confirm.
if matches!(policy, InlinePolicy::AliasInitializerChain)
&& candidate::is_lookup_inline_expr(value)
&& stmt_starts_lookup_mechanical_run(&old_stmts, index, candidate.binding())
{
new_stmts.push(old_stmts[index].clone());
index += 1;
continue;
}
// Special cases for recovered local aliases whose next statement is an
// assignment under the conservative policy:
// - A lookup used as an access base (per `stmt_has_access_base_binding_use`)
//   bypasses the `allows_expr_with_policy` check below.
// - A mechanical-run expression used as an index (per
//   `stmt_has_index_binding_use`) upgrades the policy to `MechanicalRun`.
let allows_special_lookup_access_base = matches!(
candidate,
candidate::InlineCandidate::LocalAlias {
origin: super::super::common::AstLocalOrigin::Recovered,
..
}
) && matches!(policy, InlinePolicy::Conservative)
&& matches!(next_stmt, AstStmt::Assign(_))
&& candidate::is_lookup_inline_expr(value)
&& stmt_has_access_base_binding_use(next_stmt, candidate.binding());
let allows_special_index_sink = matches!(
candidate,
candidate::InlineCandidate::LocalAlias {
origin: super::super::common::AstLocalOrigin::Recovered,
..
}
) && matches!(policy, InlinePolicy::Conservative)
&& matches!(next_stmt, AstStmt::Assign(_))
&& super::expr_analysis::is_mechanical_run_inline_expr(value)
&& stmt_has_index_binding_use(next_stmt, candidate.binding());
// A recovered local alias feeding an assignment or local declaration may use
// `AdjacentValueSink`. This applies when it is a raw global alias passed as a
// direct call argument, or a nested value use that is either recallable or a
// lookup the next statement assigns back to.
let allows_special_adjacent_value_sink =
matches!(
candidate,
candidate::InlineCandidate::LocalAlias {
origin: super::super::common::AstLocalOrigin::Recovered,
..
}
) && matches!(
policy,
InlinePolicy::Conservative | InlinePolicy::AliasInitializerChain
) && matches!(next_stmt, AstStmt::Assign(_) | AstStmt::LocalDecl(_))
&& stmt_sink_binding_allows_adjacent_value_inline(&old_stmts, index + 1)
&& ((candidate::is_raw_global_alias_expr(value)
&& stmt_has_direct_call_arg_binding_use(next_stmt, candidate.binding()))
|| (stmt_has_nested_binding_value_use(next_stmt, candidate.binding())
&& (candidate::is_recallable_inline_expr(value)
|| (candidate::is_lookup_inline_expr(value)
&& assign_targets_same_lookup_expr(next_stmt, value)))));
// `allows_special_lookup_access_base` does not change the policy; it only
// bypasses the `allows_expr_with_policy` check.
let effective_policy = if allows_special_index_sink {
InlinePolicy::MechanicalRun
} else if allows_special_adjacent_value_sink {
InlinePolicy::AdjacentValueSink
} else {
policy
};
if !candidate.allows_expr_with_policy(value, effective_policy)
&& !allows_special_lookup_access_base
{
new_stmts.push(old_stmts[index].clone());
index += 1;
continue;
}
// The binding must have exactly one use from `next_stmt` onwards.
if use_index.count_uses_in_suffix(index + 1, candidate.binding()) != 1 {
new_stmts.push(old_stmts[index].clone());
index += 1;
continue;
}
// Rewrite a copy of `next_stmt`. If that fails, one retry under
// `ExtendedCallChain` is allowed for an alias initializer whose recallable
// value is passed as a direct call argument. Otherwise the candidate is kept.
let mut rewritten_next = next_stmt.clone();
let mut rewrite_policy = effective_policy;
if !rewrite_stmt_use_sites_with_policy(
&mut rewritten_next,
candidate,
value,
options,
rewrite_policy,
) {
if matches!(policy, InlinePolicy::AliasInitializerChain)
&& candidate::is_recallable_inline_expr(value)
&& stmt_has_direct_call_arg_binding_use(next_stmt, candidate.binding())
{
rewritten_next = next_stmt.clone();
rewrite_policy = InlinePolicy::ExtendedCallChain;
if !rewrite_stmt_use_sites_with_policy(
&mut rewritten_next,
candidate,
value,
options,
rewrite_policy,
) {
new_stmts.push(old_stmts[index].clone());
index += 1;
continue;
}
} else {
new_stmts.push(old_stmts[index].clone());
index += 1;
continue;
}
}
// Success: drop the candidate and emit the rewritten successor in its place.
new_stmts.push(rewritten_next);
changed = true;
index += 2;
}
block.stmts = new_stmts;
// Phase two: run-level collapses over the phase-one output.
changed |= collapse_adjacent_call_alias_runs(block, options);
changed |= collapse_terminal_call_result_alias_runs(block, options);
changed |= collapse_terminal_local_mechanical_runs(block, options);
changed |= collapse_adjacent_mechanical_alias_runs(block, options);
changed
}
/// Collapses a run of local-alias declarations into the call statement that
/// directly follows the run.
///
/// For every maximal run `stmts[index..run_end]` of `inline_candidate`
/// statements whose successor `stmts[run_end]` is an `AstStmt::CallStmt`,
/// candidates are visited from last to first. A `LocalAlias` candidate is
/// inlined into the call with `InlinePolicy::ExtendedCallChain` when its
/// binding:
/// * has exactly one use in `stmts[candidate + 1..=run_end]`,
/// * has no use after the call statement (the declaration is dropped on
///   success, so a later reference would be left without a definition), and
/// * is not used by any statement between it and the call that is still kept
///   (lookup values ignore statements already inlined into the call; other
///   values count every intermediate statement).
///
/// The rewrite is committed only when at least two aliases were inlined.
/// Otherwise the statement at `index` is kept and scanning resumes after it.
///
/// Returns `true` if any run was collapsed.
fn collapse_adjacent_call_alias_runs(block: &mut AstBlock, options: ReadabilityOptions) -> bool {
    let old_stmts = std::mem::take(&mut block.stmts);
    let use_index = BindingUseIndex::for_stmts(&old_stmts);
    let mut new_stmts = Vec::with_capacity(old_stmts.len());
    let mut changed = false;
    let mut index = 0;
    while index < old_stmts.len() {
        let mut run_end = index;
        while run_end < old_stmts.len() && inline_candidate(&old_stmts[run_end]).is_some() {
            run_end += 1;
        }
        if run_end == index
            || run_end >= old_stmts.len()
            || !matches!(old_stmts[run_end], AstStmt::CallStmt(_))
        {
            new_stmts.push(old_stmts[index].clone());
            index += 1;
            continue;
        }
        let mut rewritten_sink = old_stmts[run_end].clone();
        // `removed[k]` marks `old_stmts[index + k]` as inlined into the call.
        let mut removed = vec![false; run_end - index];
        let mut collapsed_count = 0usize;
        for candidate_index in (index..run_end).rev() {
            let Some((candidate, value)) = inline_candidate(&old_stmts[candidate_index]) else {
                continue;
            };
            if !matches!(candidate, candidate::InlineCandidate::LocalAlias { .. }) {
                continue;
            }
            // Exactly one use up to and including the call...
            if use_index.count_uses_in_range(candidate_index + 1, run_end + 1, candidate.binding())
                != 1
            {
                continue;
            }
            // ...and none after it, because the declaration is about to be removed.
            // This mirrors the check in `collapse_adjacent_mechanical_alias_runs`.
            if use_index.count_uses_in_suffix(run_end + 1, candidate.binding()) != 0 {
                continue;
            }
            let intermediate_uses = if candidate::is_lookup_inline_expr(value) {
                count_binding_uses_in_remaining_run(
                    &use_index,
                    candidate_index + 1,
                    &removed[(candidate_index + 1 - index)..],
                    candidate.binding(),
                )
            } else {
                use_index.count_uses_in_range(candidate_index + 1, run_end, candidate.binding())
            };
            if intermediate_uses != 0 {
                continue;
            }
            let mut trial_sink = rewritten_sink.clone();
            if rewrite_stmt_use_sites_with_policy(
                &mut trial_sink,
                candidate,
                value,
                options,
                InlinePolicy::ExtendedCallChain,
            ) {
                rewritten_sink = trial_sink;
                removed[candidate_index - index] = true;
                collapsed_count += 1;
            }
        }
        if collapsed_count >= 2 {
            changed = true;
            for (offset, stmt) in old_stmts[index..run_end].iter().enumerate() {
                if !removed[offset] {
                    new_stmts.push(stmt.clone());
                }
            }
            new_stmts.push(rewritten_sink);
            index = run_end + 1;
            continue;
        }
        new_stmts.push(old_stmts[index].clone());
        index += 1;
    }
    block.stmts = new_stmts;
    changed
}
/// Collapses local-alias declarations into a call-result sink that sits later
/// in the same run of inlinable statements.
///
/// `find_terminal_call_result_sink` picks the sink: the first statement after
/// `index`, inside the contiguous run of `inline_candidate` statements, that
/// satisfies `stmt_is_adjacent_call_result_sink`. The candidates before it are
/// visited from last to first. A `LocalAlias` is inlined into the sink with
/// `InlinePolicy::ExtendedCallChain` when its binding:
/// * has exactly one use in the rest of the block,
/// * is not used by any statement between it and the sink that is still kept
///   (lookup values ignore statements already inlined into the sink; other
///   values count every intermediate statement), and
/// * has a nested use in the sink as rewritten so far.
///
/// The result is committed only when at least two aliases were inlined;
/// otherwise scanning resumes at the next statement. Returns `true` if any
/// run was collapsed.
fn collapse_terminal_call_result_alias_runs(
block: &mut AstBlock,
options: ReadabilityOptions,
) -> bool {
let old_stmts = std::mem::take(&mut block.stmts);
let use_index = BindingUseIndex::for_stmts(&old_stmts);
let mut new_stmts = Vec::with_capacity(old_stmts.len());
let mut changed = false;
let mut index = 0;
while index < old_stmts.len() {
let Some(sink_index) = find_terminal_call_result_sink(&old_stmts, index) else {
new_stmts.push(old_stmts[index].clone());
index += 1;
continue;
};
let mut rewritten_sink = old_stmts[sink_index].clone();
// `removed[k]` marks `old_stmts[index + k]` as inlined into the sink.
let mut removed = vec![false; sink_index - index];
let mut collapsed_count = 0usize;
// Visit the candidates nearest the sink first.
for candidate_index in (index..sink_index).rev() {
let Some((candidate, value)) = inline_candidate(&old_stmts[candidate_index]) else {
continue;
};
if !matches!(candidate, candidate::InlineCandidate::LocalAlias { .. }) {
continue;
}
// Exactly one use anywhere after the declaration.
if use_index.count_uses_in_suffix(candidate_index + 1, candidate.binding()) != 1 {
continue;
}
// Lookups may depend on aliases already inlined into the sink; other
// values must not be used by any intermediate statement at all.
let intermediate_uses = if candidate::is_lookup_inline_expr(value) {
count_binding_uses_in_remaining_run(
&use_index,
candidate_index + 1,
&removed[(candidate_index + 1 - index)..],
candidate.binding(),
)
} else {
use_index.count_uses_in_range(candidate_index + 1, sink_index, candidate.binding())
};
if intermediate_uses != 0
|| !stmt_has_nested_binding_use(&rewritten_sink, candidate.binding())
{
continue;
}
// Rewrite a scratch copy so a failed rewrite leaves the sink untouched.
let mut trial_sink = rewritten_sink.clone();
if rewrite_stmt_use_sites_with_policy(
&mut trial_sink,
candidate,
value,
options,
InlinePolicy::ExtendedCallChain,
) {
rewritten_sink = trial_sink;
removed[candidate_index - index] = true;
collapsed_count += 1;
}
}
if collapsed_count >= 2 {
changed = true;
// Keep the statements that were not inlined, then the rewritten sink.
for (offset, stmt) in old_stmts[index..sink_index].iter().enumerate() {
if !removed[offset] {
new_stmts.push(stmt.clone());
}
}
new_stmts.push(rewritten_sink);
index = sink_index + 1;
continue;
}
new_stmts.push(old_stmts[index].clone());
index += 1;
}
block.stmts = new_stmts;
changed
}
/// Locates the sink used by `collapse_terminal_call_result_alias_runs`.
///
/// Scans the contiguous run of `inline_candidate` statements after `index`.
/// Returns the position of the first one that satisfies
/// `stmt_is_adjacent_call_result_sink`. Returns `None` when `stmts[index]` is
/// missing or is not itself a candidate, or when the run ends before any sink
/// is found.
fn find_terminal_call_result_sink(stmts: &[AstStmt], index: usize) -> Option<usize> {
    // The run has to begin at `index` itself.
    inline_candidate(stmts.get(index)?)?;
    stmts
        .iter()
        .enumerate()
        .skip(index + 1)
        .take_while(|(_, stmt)| inline_candidate(stmt).is_some())
        .find(|(_, stmt)| stmt_is_adjacent_call_result_sink(stmt))
        .map(|(sink_index, _)| sink_index)
}
/// Decides whether the statement at `sink_index` may receive an adjacent value.
///
/// * Out of range: `false`.
/// * An assignment: always `true`.
/// * An `inline_candidate` statement: `true` only if no later statement in
///   `stmts` uses its binding as a call callee (per
///   `stmt_has_call_callee_binding_use`).
/// * Anything else: `false`.
fn stmt_sink_binding_allows_adjacent_value_inline(stmts: &[AstStmt], sink_index: usize) -> bool {
    match stmts.get(sink_index) {
        None => false,
        Some(AstStmt::Assign(_)) => true,
        Some(sink) => match inline_candidate(sink) {
            None => false,
            Some((sink_candidate, _)) => {
                let sink_binding = sink_candidate.binding();
                stmts[sink_index + 1..]
                    .iter()
                    .all(|later| !stmt_has_call_callee_binding_use(later, sink_binding))
            }
        },
    }
}
/// Collapses a run of inlinable statements into the statement that follows it,
/// using `InlinePolicy::MechanicalRun`.
///
/// The sink is the successor of each maximal run `stmts[index..run_end]` of
/// `inline_candidate` statements, provided it passes
/// `stmt_can_absorb_mechanical_run`. Candidates are visited from last to first.
/// A candidate is inlined into the sink when:
/// * its value is allowed under `MechanicalRun`,
/// * its binding is used exactly once in `stmts[candidate + 1..=run_end]` and
///   never after the sink,
/// * no still-kept statement in between uses it, and
/// * the sink as rewritten so far contains a nested use of it.
///
/// The collapse is committed only when at least two candidates were inlined
/// and one of the following holds:
/// * at least one inlined value is not a lookup expression;
/// * `stmt_prefers_pure_lookup_run_collapse` accepts the rewritten sink (an
///   assignment with a non-name target);
/// * some inlined lookup references another binding of the run and
///   `stmt_prefers_dependent_lookup_run_collapse` accepts the sink (any
///   assignment).
///
/// Returns `true` if any run was collapsed.
fn collapse_adjacent_mechanical_alias_runs(
block: &mut AstBlock,
options: ReadabilityOptions,
) -> bool {
let old_stmts = std::mem::take(&mut block.stmts);
let use_index = BindingUseIndex::for_stmts(&old_stmts);
let mut new_stmts = Vec::with_capacity(old_stmts.len());
let mut changed = false;
let mut index = 0;
while index < old_stmts.len() {
// Find the end of the run of inlinable statements starting at `index`.
let mut run_end = index;
while run_end < old_stmts.len() && inline_candidate(&old_stmts[run_end]).is_some() {
run_end += 1;
}
if run_end == index
|| run_end >= old_stmts.len()
|| !stmt_can_absorb_mechanical_run(&old_stmts[run_end])
{
new_stmts.push(old_stmts[index].clone());
index += 1;
continue;
}
let mut rewritten_sink = old_stmts[run_end].clone();
// `removed[k]` marks `old_stmts[index + k]` as inlined into the sink.
let mut removed = vec![false; run_end - index];
let mut collapsed_count = 0usize;
// Track what kinds of values were inlined; this decides whether the
// collapse is worth committing.
let mut has_non_lookup_piece = false;
let mut has_dependent_lookup_piece = false;
for candidate_index in (index..run_end).rev() {
let Some((candidate, value)) = inline_candidate(&old_stmts[candidate_index]) else {
continue;
};
if !candidate.allows_expr_with_policy(value, InlinePolicy::MechanicalRun) {
continue;
}
// Exactly one use up to and including the sink...
if use_index.count_uses_in_range(candidate_index + 1, run_end + 1, candidate.binding())
!= 1
{
continue;
}
// ...and none after it.
if use_index.count_uses_in_suffix(run_end + 1, candidate.binding()) != 0 {
continue;
}
// No use in intermediate statements that are still kept.
if count_binding_uses_in_remaining_run(
&use_index,
candidate_index + 1,
&removed[(candidate_index + 1 - index)..],
candidate.binding(),
) != 0
{
continue;
}
if !stmt_has_nested_binding_use(&rewritten_sink, candidate.binding()) {
continue;
}
// Rewrite a scratch copy so a failed rewrite leaves the sink untouched.
let mut trial_sink = rewritten_sink.clone();
if rewrite_stmt_use_sites_with_policy(
&mut trial_sink,
candidate,
value,
options,
InlinePolicy::MechanicalRun,
) {
rewritten_sink = trial_sink;
removed[candidate_index - index] = true;
collapsed_count += 1;
has_non_lookup_piece |= !candidate::is_lookup_inline_expr(value);
has_dependent_lookup_piece |= candidate::is_lookup_inline_expr(value)
&& expr_references_any_run_binding(
value,
&old_stmts[index..run_end],
candidate.binding(),
);
}
}
if collapsed_count >= 2
&& (has_non_lookup_piece
|| stmt_prefers_pure_lookup_run_collapse(&rewritten_sink)
|| (has_dependent_lookup_piece
&& stmt_prefers_dependent_lookup_run_collapse(&rewritten_sink)))
{
changed = true;
for (offset, stmt) in old_stmts[index..run_end].iter().enumerate() {
if !removed[offset] {
new_stmts.push(stmt.clone());
}
}
new_stmts.push(rewritten_sink);
index = run_end + 1;
continue;
}
new_stmts.push(old_stmts[index].clone());
index += 1;
}
block.stmts = new_stmts;
changed
}
/// Collapses a run of inlinable statements into the run's own last statement,
/// when that last statement is a local alias that is still used after the run.
///
/// Applies to every maximal run `stmts[index..run_end]` of at least two
/// `inline_candidate` statements with at least one statement after it. The
/// final statement `stmts[run_end - 1]` acts as the sink, provided it is a
/// `LocalAlias` whose binding has a use in `stmts[run_end..]`. The earlier
/// candidates are visited from last to first and inlined into that sink with
/// `InlinePolicy::MechanicalRun` when:
/// * the value is allowed under that policy,
/// * the binding is used exactly once after its definition and not at all from
///   `run_end` on,
/// * no still-kept statement in between uses it, and
/// * the sink as rewritten so far contains a nested use of it.
///
/// Changes are committed only when at least two candidates were inlined.
/// Returns `true` if any run was collapsed.
fn collapse_terminal_local_mechanical_runs(
block: &mut AstBlock,
options: ReadabilityOptions,
) -> bool {
let old_stmts = std::mem::take(&mut block.stmts);
let use_index = BindingUseIndex::for_stmts(&old_stmts);
let mut new_stmts = Vec::with_capacity(old_stmts.len());
let mut changed = false;
let mut index = 0;
while index < old_stmts.len() {
// Find the end of the run of inlinable statements starting at `index`.
let mut run_end = index;
while run_end < old_stmts.len() && inline_candidate(&old_stmts[run_end]).is_some() {
run_end += 1;
}
if run_end <= index + 1 || run_end >= old_stmts.len() {
new_stmts.push(old_stmts[index].clone());
index += 1;
continue;
}
// The run's last statement is the sink and must stay alive: it has to be a
// local alias that is referenced after the run.
let Some((sink_candidate, _)) = inline_candidate(&old_stmts[run_end - 1]) else {
new_stmts.push(old_stmts[index].clone());
index += 1;
continue;
};
if !matches!(
sink_candidate,
candidate::InlineCandidate::LocalAlias { .. }
) || use_index.count_uses_in_suffix(run_end, sink_candidate.binding()) == 0
{
new_stmts.push(old_stmts[index].clone());
index += 1;
continue;
}
let mut rewritten_sink = old_stmts[run_end - 1].clone();
// `removed[k]` marks `old_stmts[index + k]` as inlined into the sink.
let mut removed = vec![false; run_end - index - 1];
let mut collapsed_count = 0usize;
for candidate_index in (index..(run_end - 1)).rev() {
let Some((candidate, value)) = inline_candidate(&old_stmts[candidate_index]) else {
continue;
};
if !candidate.allows_expr_with_policy(value, InlinePolicy::MechanicalRun) {
continue;
}
// Exactly one use after the declaration...
if use_index.count_uses_in_suffix(candidate_index + 1, candidate.binding()) != 1 {
continue;
}
// ...and none at or after `run_end`, so that use is inside the run.
if use_index.count_uses_in_suffix(run_end, candidate.binding()) != 0 {
continue;
}
// No use in intermediate statements that are still kept.
if count_binding_uses_in_remaining_run(
&use_index,
candidate_index + 1,
&removed[(candidate_index + 1 - index)..],
candidate.binding(),
) != 0
{
continue;
}
if !stmt_has_nested_binding_use(&rewritten_sink, candidate.binding()) {
continue;
}
// Rewrite a scratch copy so a failed rewrite leaves the sink untouched.
let mut trial_sink = rewritten_sink.clone();
if rewrite_stmt_use_sites_with_policy(
&mut trial_sink,
candidate,
value,
options,
InlinePolicy::MechanicalRun,
) {
rewritten_sink = trial_sink;
removed[candidate_index - index] = true;
collapsed_count += 1;
}
}
if collapsed_count >= 2 {
changed = true;
for (offset, stmt) in old_stmts[index..(run_end - 1)].iter().enumerate() {
if !removed[offset] {
new_stmts.push(stmt.clone());
}
}
new_stmts.push(rewritten_sink);
// The sink was `run_end - 1`, so scanning resumes at `run_end`.
index = run_end;
continue;
}
new_stmts.push(old_stmts[index].clone());
index += 1;
}
block.stmts = new_stmts;
changed
}
/// Returns `true` for statement kinds that may act as the sink of a
/// mechanical-run collapse. These are assignments, call statements, returns,
/// and `if` / `while` / `repeat` / numeric-`for` / generic-`for` statements.
fn stmt_can_absorb_mechanical_run(stmt: &AstStmt) -> bool {
    let is_simple_sink = matches!(
        stmt,
        AstStmt::Assign(_) | AstStmt::CallStmt(_) | AstStmt::Return(_)
    );
    let is_control_flow = matches!(
        stmt,
        AstStmt::If(_)
            | AstStmt::While(_)
            | AstStmt::Repeat(_)
            | AstStmt::NumericFor(_)
            | AstStmt::GenericFor(_)
    );
    is_simple_sink || is_control_flow
}
/// Returns `true` when `stmt` is an assignment with at least one target that
/// is not a plain name (for example a field-access or index-access target).
fn stmt_prefers_pure_lookup_run_collapse(stmt: &AstStmt) -> bool {
    let AstStmt::Assign(assign) = stmt else {
        return false;
    };
    assign
        .targets
        .iter()
        .any(|target| !matches!(target, super::super::common::AstLValue::Name(_)))
}
/// Returns `true` when `stmt` is an assignment, whatever its targets.
fn stmt_prefers_dependent_lookup_run_collapse(stmt: &AstStmt) -> bool {
    let AstStmt::Assign(_) = stmt else {
        return false;
    };
    true
}
/// Reports whether `stmts[index]` starts a lookup-driven mechanical run.
///
/// Returns `true` only when all of the following hold:
/// * the contiguous run of `inline_candidate` statements beginning at `index`
///   has at least two members;
/// * the run is followed by a statement accepted by
///   `stmt_can_absorb_mechanical_run`;
/// * the run's second statement holds a lookup expression that references
///   `binding`.
fn stmt_starts_lookup_mechanical_run(
    stmts: &[AstStmt],
    index: usize,
    binding: AstBindingRef,
) -> bool {
    let run_len = stmts
        .iter()
        .skip(index)
        .take_while(|stmt| inline_candidate(stmt).is_some())
        .count();
    if run_len < 2 {
        return false;
    }
    let Some(sink) = stmts.get(index + run_len) else {
        return false;
    };
    if !stmt_can_absorb_mechanical_run(sink) {
        return false;
    }
    let Some((_, next_value)) = stmts.get(index + 1).and_then(inline_candidate) else {
        return false;
    };
    candidate::is_lookup_inline_expr(next_value) && expr_references_binding(next_value, binding)
}
/// Returns `true` if `expr` references the binding introduced by any
/// `inline_candidate` statement in `run`. The binding `except` is ignored.
fn expr_references_any_run_binding(
    expr: &super::super::common::AstExpr,
    run: &[AstStmt],
    except: AstBindingRef,
) -> bool {
    run.iter()
        .filter_map(|stmt| inline_candidate(stmt).map(|(candidate, _)| candidate.binding()))
        .filter(|&binding| binding != except)
        .any(|binding| expr_references_binding(expr, binding))
}
/// Returns `true` when `stmt` is an assignment and one of its targets is the
/// same field or index access as the lookup `expr` (see
/// `lvalue_matches_lookup_expr`).
fn assign_targets_same_lookup_expr(stmt: &AstStmt, expr: &super::super::common::AstExpr) -> bool {
    matches!(
        stmt,
        AstStmt::Assign(assign)
            if assign
                .targets
                .iter()
                .any(|target| lvalue_matches_lookup_expr(target, expr))
    )
}
/// Compares an assignment target with a lookup expression structurally.
///
/// A field-access target matches a field-access expression when both the field
/// and the base are equal. An index-access target matches an index-access
/// expression when both the base and the index are equal. Every other pairing
/// yields `false`.
fn lvalue_matches_lookup_expr(
    target: &super::super::common::AstLValue,
    expr: &super::super::common::AstExpr,
) -> bool {
    use super::super::common::{AstExpr, AstLValue};

    match (target, expr) {
        (AstLValue::FieldAccess(place), AstExpr::FieldAccess(lookup)) => {
            place.field == lookup.field && place.base == lookup.base
        }
        (AstLValue::IndexAccess(place), AstExpr::IndexAccess(lookup)) => {
            place.base == lookup.base && place.index == lookup.index
        }
        _ => false,
    }
}
/// Sums the uses of `binding` over the run statements that have not been
/// inlined yet.
///
/// `removed[offset]` describes the statement at `start_index + offset`;
/// entries set to `true` are skipped.
fn count_binding_uses_in_remaining_run(
    use_index: &BindingUseIndex,
    start_index: usize,
    removed: &[bool],
    binding: AstBindingRef,
) -> usize {
    let mut total: usize = 0;
    for (offset, &was_removed) in removed.iter().enumerate() {
        if was_removed {
            continue;
        }
        total += use_index.count_uses_in_stmt_index(start_index + offset, binding);
    }
    total
}
#[cfg(test)]
mod tests;