use std::collections::BTreeSet;
use crate::ast::common::{
AstBlock, AstExpr, AstFunctionExpr, AstGlobalAttr, AstLValue, AstLocalAttr, AstLocalBinding,
AstLocalFunctionDecl, AstLocalOrigin, AstNameRef, AstStmt,
};
use super::super::binding_flow::stmt_references_any_binding;
use super::super::visit::{self, AstVisitor};
use super::super::walk::BlockKind;
use super::facts::MissingGlobals;
use super::insert::build_wildcard_global_decl;
pub(super) fn try_wrap_missing_collective_suffix(
block: &mut AstBlock,
kind: BlockKind,
missing: &MissingGlobals,
) -> bool {
if matches!(kind, BlockKind::ModuleBody) {
return false;
}
let Some((attr, names)) = collective_candidate(missing) else {
return false;
};
let start = block
.stmts
.iter()
.position(|stmt| stmt_mentions_any_missing_global(stmt, &names));
let Some(start) = start else {
return false;
};
let end = block
.stmts
.iter()
.rposition(|stmt| stmt_mentions_any_missing_global(stmt, &names));
let Some(mut end) = end else {
return false;
};
loop {
let bindings = collect_declared_bindings(&block.stmts[start..=end]);
let Some(next_offset) = block.stmts[(end + 1)..]
.iter()
.position(|stmt| stmt_references_any_binding(stmt, &bindings))
else {
break;
};
end += next_offset + 1;
}
if end + 1 != block.stmts.len() || !suffix_is_safe_to_wrap(&block.stmts[start..]) {
return false;
}
let suffix = block.stmts.split_off(start);
let mut inner_stmts = Vec::with_capacity(suffix.len() + 1);
inner_stmts.push(build_wildcard_global_decl(attr));
inner_stmts.extend(suffix);
block
.stmts
.push(AstStmt::DoBlock(Box::new(AstBlock { stmts: inner_stmts })));
true
}
fn collective_candidate(missing: &MissingGlobals) -> Option<(AstGlobalAttr, BTreeSet<String>)> {
match (missing.none.is_empty(), missing.const_.is_empty()) {
(true, false) => Some((
AstGlobalAttr::Const,
missing.const_.iter().cloned().collect(),
)),
(false, true) => Some((AstGlobalAttr::None, missing.none.iter().cloned().collect())),
_ => None,
}
}
fn suffix_is_safe_to_wrap(stmts: &[AstStmt]) -> bool {
stmts.iter().all(|stmt| {
!matches!(
stmt,
AstStmt::Goto(_) | AstStmt::Label(_) | AstStmt::Break | AstStmt::Continue
)
})
}
fn collect_declared_bindings(stmts: &[AstStmt]) -> Vec<AstLocalBinding> {
let mut bindings = Vec::new();
for stmt in stmts {
match stmt {
AstStmt::LocalDecl(local_decl) => bindings.extend(local_decl.bindings.iter().cloned()),
AstStmt::LocalFunctionDecl(function_decl) => {
bindings.push(synthetic_binding_for_local_function(function_decl.as_ref()));
}
AstStmt::Assign(_)
| AstStmt::CallStmt(_)
| AstStmt::Return(_)
| AstStmt::GlobalDecl(_)
| AstStmt::If(_)
| AstStmt::While(_)
| AstStmt::Repeat(_)
| AstStmt::NumericFor(_)
| AstStmt::GenericFor(_)
| AstStmt::DoBlock(_)
| AstStmt::FunctionDecl(_)
| AstStmt::Break
| AstStmt::Continue
| AstStmt::Goto(_)
| AstStmt::Label(_) | AstStmt::Error(_) => {}
}
}
bindings
}
fn synthetic_binding_for_local_function(function_decl: &AstLocalFunctionDecl) -> AstLocalBinding {
AstLocalBinding {
id: function_decl.name,
attr: AstLocalAttr::None,
origin: AstLocalOrigin::Recovered,
}
}
fn stmt_mentions_any_missing_global(stmt: &AstStmt, names: &BTreeSet<String>) -> bool {
let mut visitor = MissingGlobalStmtVisitor {
names,
found: false,
};
visit::visit_stmt(stmt, &mut visitor);
visitor.found
}
struct MissingGlobalStmtVisitor<'a> {
names: &'a BTreeSet<String>,
found: bool,
}
impl AstVisitor for MissingGlobalStmtVisitor<'_> {
fn visit_expr(&mut self, expr: &AstExpr) {
if let AstExpr::Var(AstNameRef::Global(global)) = expr
&& self.names.contains(&global.text)
{
self.found = true;
}
}
fn visit_lvalue(&mut self, lvalue: &AstLValue) {
if let AstLValue::Name(AstNameRef::Global(global)) = lvalue
&& self.names.contains(&global.text)
{
self.found = true;
}
}
fn visit_function_expr(&mut self, _function: &AstFunctionExpr) -> bool {
false
}
}