use std::collections::{BTreeMap, BTreeSet};
use crate::ast::{
AstBindingRef, AstBlock, AstCallKind, AstExpr, AstFunctionExpr, AstFunctionName, AstLValue,
AstModule, AstNameRef, AstStmt, AstSyntheticLocalId, AstTableField, AstTableKey,
};
use crate::hir::{HirModule, HirProtoRef};
#[derive(Debug, Clone, Default)]
pub(super) struct AstNamingFacts {
pub(super) functions: Vec<FunctionAstNamingFacts>,
}
#[derive(Debug, Clone, Default)]
pub(super) struct FunctionAstNamingFacts {
pub(super) debug_like_binding_order: BTreeMap<AstBindingRef, usize>,
pub(super) unused_synthetic_locals: BTreeSet<AstSyntheticLocalId>,
}
pub(super) fn collect_ast_naming_facts(module: &AstModule, hir: &HirModule) -> AstNamingFacts {
let mut facts = AstNamingFacts {
functions: vec![FunctionAstNamingFacts::default(); hir.protos.len()],
};
collect_function_facts(module.entry_function, &module.body, hir, &mut facts);
facts
}
#[derive(Debug, Default)]
struct FunctionAstCollector {
binding_order: Vec<AstBindingRef>,
seen_bindings: BTreeSet<AstBindingRef>,
declared_synthetic_locals: BTreeSet<AstSyntheticLocalId>,
mentioned_synthetic_locals: BTreeSet<AstSyntheticLocalId>,
}
impl FunctionAstCollector {
fn note_binding(&mut self, binding: AstBindingRef) {
if self.seen_bindings.insert(binding) {
self.binding_order.push(binding);
}
if let AstBindingRef::SyntheticLocal(local) = binding {
self.declared_synthetic_locals.insert(local);
}
}
fn note_name_ref(&mut self, name: &AstNameRef) {
match name {
AstNameRef::Local(local) => self.note_binding(AstBindingRef::Local(*local)),
AstNameRef::SyntheticLocal(local) => {
self.note_binding(AstBindingRef::SyntheticLocal(*local));
self.mentioned_synthetic_locals.insert(*local);
}
AstNameRef::Param(_)
| AstNameRef::Temp(_)
| AstNameRef::Upvalue(_)
| AstNameRef::Global(_) => {}
}
}
fn finish(self) -> FunctionAstNamingFacts {
let debug_like_binding_order = self
.binding_order
.into_iter()
.enumerate()
.map(|(index, binding)| (binding, index))
.collect();
let unused_synthetic_locals = self
.declared_synthetic_locals
.difference(&self.mentioned_synthetic_locals)
.copied()
.collect();
FunctionAstNamingFacts {
debug_like_binding_order,
unused_synthetic_locals,
}
}
}
fn collect_function_facts(
function: HirProtoRef,
body: &AstBlock,
hir: &HirModule,
facts: &mut AstNamingFacts,
) {
let mut collector = FunctionAstCollector::default();
note_named_vararg_binding(function, hir, &mut collector);
collect_block_facts(body, &mut collector, hir, facts);
facts.functions[function.index()] = collector.finish();
}
fn note_named_vararg_binding(
function: HirProtoRef,
hir: &HirModule,
collector: &mut FunctionAstCollector,
) {
let Some(proto) = hir.protos.get(function.index()) else {
return;
};
if proto.signature.has_vararg_param_reg
&& let Some(&local) = proto.locals.first()
{
collector.note_binding(AstBindingRef::Local(local));
}
}
fn collect_block_facts(
block: &AstBlock,
collector: &mut FunctionAstCollector,
hir: &HirModule,
facts: &mut AstNamingFacts,
) {
for stmt in &block.stmts {
collect_stmt_facts(stmt, collector, hir, facts);
}
}
fn collect_stmt_facts(
stmt: &AstStmt,
collector: &mut FunctionAstCollector,
hir: &HirModule,
facts: &mut AstNamingFacts,
) {
match stmt {
AstStmt::LocalDecl(local_decl) => {
for binding in &local_decl.bindings {
collector.note_binding(binding.id);
}
for value in &local_decl.values {
collect_expr_facts(value, collector, hir, facts);
}
}
AstStmt::GlobalDecl(global_decl) => {
for value in &global_decl.values {
collect_expr_facts(value, collector, hir, facts);
}
}
AstStmt::Assign(assign) => {
for target in &assign.targets {
collect_lvalue_facts(target, collector, hir, facts);
}
for value in &assign.values {
collect_expr_facts(value, collector, hir, facts);
}
}
AstStmt::CallStmt(call_stmt) => collect_call_facts(&call_stmt.call, collector, hir, facts),
AstStmt::Return(ret) => {
for value in &ret.values {
collect_expr_facts(value, collector, hir, facts);
}
}
AstStmt::If(if_stmt) => {
collect_expr_facts(&if_stmt.cond, collector, hir, facts);
collect_block_facts(&if_stmt.then_block, collector, hir, facts);
if let Some(else_block) = &if_stmt.else_block {
collect_block_facts(else_block, collector, hir, facts);
}
}
AstStmt::While(while_stmt) => {
collect_expr_facts(&while_stmt.cond, collector, hir, facts);
collect_block_facts(&while_stmt.body, collector, hir, facts);
}
AstStmt::Repeat(repeat_stmt) => {
collect_block_facts(&repeat_stmt.body, collector, hir, facts);
collect_expr_facts(&repeat_stmt.cond, collector, hir, facts);
}
AstStmt::NumericFor(numeric_for) => {
collector.note_binding(numeric_for.binding);
collect_expr_facts(&numeric_for.start, collector, hir, facts);
collect_expr_facts(&numeric_for.limit, collector, hir, facts);
collect_expr_facts(&numeric_for.step, collector, hir, facts);
collect_block_facts(&numeric_for.body, collector, hir, facts);
}
AstStmt::GenericFor(generic_for) => {
for &binding in &generic_for.bindings {
collector.note_binding(binding);
}
for expr in &generic_for.iterator {
collect_expr_facts(expr, collector, hir, facts);
}
collect_block_facts(&generic_for.body, collector, hir, facts);
}
AstStmt::DoBlock(block) => collect_block_facts(block, collector, hir, facts),
AstStmt::FunctionDecl(function_decl) => {
collect_function_name_facts(&function_decl.target, collector);
collect_nested_function_facts(&function_decl.func, hir, facts);
}
AstStmt::LocalFunctionDecl(local_function_decl) => {
collector.note_binding(local_function_decl.name);
collect_nested_function_facts(&local_function_decl.func, hir, facts);
}
AstStmt::Break | AstStmt::Continue | AstStmt::Goto(_) | AstStmt::Label(_) => {}
}
}
fn collect_nested_function_facts(
function_expr: &AstFunctionExpr,
hir: &HirModule,
facts: &mut AstNamingFacts,
) {
collect_function_facts(function_expr.function, &function_expr.body, hir, facts);
}
fn collect_function_name_facts(target: &AstFunctionName, collector: &mut FunctionAstCollector) {
let path = match target {
AstFunctionName::Plain(path) => path,
AstFunctionName::Method(path, _) => path,
};
collector.note_name_ref(&path.root);
}
fn collect_call_facts(
call: &AstCallKind,
collector: &mut FunctionAstCollector,
hir: &HirModule,
facts: &mut AstNamingFacts,
) {
match call {
AstCallKind::Call(call) => {
collect_expr_facts(&call.callee, collector, hir, facts);
for arg in &call.args {
collect_expr_facts(arg, collector, hir, facts);
}
}
AstCallKind::MethodCall(call) => {
collect_expr_facts(&call.receiver, collector, hir, facts);
for arg in &call.args {
collect_expr_facts(arg, collector, hir, facts);
}
}
}
}
fn collect_lvalue_facts(
target: &AstLValue,
collector: &mut FunctionAstCollector,
hir: &HirModule,
facts: &mut AstNamingFacts,
) {
match target {
AstLValue::Name(name) => collector.note_name_ref(name),
AstLValue::FieldAccess(access) => collect_expr_facts(&access.base, collector, hir, facts),
AstLValue::IndexAccess(access) => {
collect_expr_facts(&access.base, collector, hir, facts);
collect_expr_facts(&access.index, collector, hir, facts);
}
}
}
fn collect_expr_facts(
expr: &AstExpr,
collector: &mut FunctionAstCollector,
hir: &HirModule,
facts: &mut AstNamingFacts,
) {
match expr {
AstExpr::Var(name) => collector.note_name_ref(name),
AstExpr::FieldAccess(access) => collect_expr_facts(&access.base, collector, hir, facts),
AstExpr::IndexAccess(access) => {
collect_expr_facts(&access.base, collector, hir, facts);
collect_expr_facts(&access.index, collector, hir, facts);
}
AstExpr::Unary(unary) => collect_expr_facts(&unary.expr, collector, hir, facts),
AstExpr::Binary(binary) => {
collect_expr_facts(&binary.lhs, collector, hir, facts);
collect_expr_facts(&binary.rhs, collector, hir, facts);
}
AstExpr::LogicalAnd(logical) | AstExpr::LogicalOr(logical) => {
collect_expr_facts(&logical.lhs, collector, hir, facts);
collect_expr_facts(&logical.rhs, collector, hir, facts);
}
AstExpr::Call(call) => {
collect_expr_facts(&call.callee, collector, hir, facts);
for arg in &call.args {
collect_expr_facts(arg, collector, hir, facts);
}
}
AstExpr::MethodCall(call) => {
collect_expr_facts(&call.receiver, collector, hir, facts);
for arg in &call.args {
collect_expr_facts(arg, collector, hir, facts);
}
}
AstExpr::SingleValue(expr) => collect_expr_facts(expr, collector, hir, facts),
AstExpr::TableConstructor(table) => {
for field in &table.fields {
match field {
AstTableField::Array(value) => collect_expr_facts(value, collector, hir, facts),
AstTableField::Record(record) => {
if let AstTableKey::Expr(key) = &record.key {
collect_expr_facts(key, collector, hir, facts);
}
collect_expr_facts(&record.value, collector, hir, facts);
}
}
}
}
AstExpr::FunctionExpr(function_expr) => {
collect_nested_function_facts(function_expr, hir, facts)
}
AstExpr::Nil
| AstExpr::Boolean(_)
| AstExpr::Integer(_)
| AstExpr::Number(_)
| AstExpr::String(_)
| AstExpr::Int64(_)
| AstExpr::UInt64(_)
| AstExpr::Complex { .. }
| AstExpr::VarArg => {}
}
}