#include <memory>
#include "asm_v_wasm.h"
#include "ir/module-utils.h"
#include "ir/utils.h"
#include "pass.h"
#include "wasm.h"
namespace wasm {
// The kinds of module-level elements whose reachability is tracked below.
enum class ModuleElementKind {
  Function,
  Global,
  Event,
};

// A module-level element, identified by its kind together with its name.
using ModuleElement = std::pair<ModuleElementKind, Name>;
struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> {
Module* module;
std::vector<ModuleElement> queue;
std::set<ModuleElement> reachable;
bool usesMemory = false;
bool usesTable = false;
ReachabilityAnalyzer(Module* module, const std::vector<ModuleElement>& roots)
: module(module) {
queue = roots;
for (auto& segment : module->memory.segments) {
if (!segment.isPassive) {
walk(segment.offset);
}
}
for (auto& segment : module->table.segments) {
walk(segment.offset);
}
while (queue.size()) {
auto& curr = queue.back();
queue.pop_back();
if (reachable.count(curr) == 0) {
reachable.insert(curr);
if (curr.first == ModuleElementKind::Function) {
auto* func = module->getFunction(curr.second);
if (!func->imported()) {
walk(func->body);
}
} else if (curr.first == ModuleElementKind::Global) {
auto* global = module->getGlobal(curr.second);
if (!global->imported()) {
walk(global->init);
}
}
}
}
}
void visitCall(Call* curr) {
if (reachable.count(
ModuleElement(ModuleElementKind::Function, curr->target)) == 0) {
queue.emplace_back(ModuleElementKind::Function, curr->target);
}
}
void visitCallIndirect(CallIndirect* curr) { usesTable = true; }
void visitGlobalGet(GlobalGet* curr) {
if (reachable.count(ModuleElement(ModuleElementKind::Global, curr->name)) ==
0) {
queue.emplace_back(ModuleElementKind::Global, curr->name);
}
}
void visitGlobalSet(GlobalSet* curr) {
if (reachable.count(ModuleElement(ModuleElementKind::Global, curr->name)) ==
0) {
queue.emplace_back(ModuleElementKind::Global, curr->name);
}
}
void visitLoad(Load* curr) { usesMemory = true; }
void visitStore(Store* curr) { usesMemory = true; }
void visitAtomicCmpxchg(AtomicCmpxchg* curr) { usesMemory = true; }
void visitAtomicRMW(AtomicRMW* curr) { usesMemory = true; }
void visitAtomicWait(AtomicWait* curr) { usesMemory = true; }
void visitAtomicNotify(AtomicNotify* curr) { usesMemory = true; }
void visitAtomicFence(AtomicFence* curr) { usesMemory = true; }
void visitMemoryInit(MemoryInit* curr) { usesMemory = true; }
void visitDataDrop(DataDrop* curr) { usesMemory = true; }
void visitMemoryCopy(MemoryCopy* curr) { usesMemory = true; }
void visitMemoryFill(MemoryFill* curr) { usesMemory = true; }
void visitMemorySize(MemorySize* curr) { usesMemory = true; }
void visitMemoryGrow(MemoryGrow* curr) { usesMemory = true; }
void visitRefFunc(RefFunc* curr) {
if (reachable.count(
ModuleElement(ModuleElementKind::Function, curr->func)) == 0) {
queue.emplace_back(ModuleElementKind::Function, curr->func);
}
}
void visitThrow(Throw* curr) {
if (reachable.count(ModuleElement(ModuleElementKind::Event, curr->event)) ==
0) {
queue.emplace_back(ModuleElementKind::Event, curr->event);
}
}
void visitBrOnExn(BrOnExn* curr) {
if (reachable.count(ModuleElement(ModuleElementKind::Event, curr->event)) ==
0) {
queue.emplace_back(ModuleElementKind::Event, curr->event);
}
}
};
struct RemoveUnusedModuleElements : public Pass {
bool rootAllFunctions;
RemoveUnusedModuleElements(bool rootAllFunctions)
: rootAllFunctions(rootAllFunctions) {}
void run(PassRunner* runner, Module* module) override {
std::vector<ModuleElement> roots;
if (module->start.is()) {
auto startFunction = module->getFunction(module->start);
if (startFunction->body->is<Nop>()) {
module->start.clear();
} else {
roots.emplace_back(ModuleElementKind::Function, module->start);
}
}
if (rootAllFunctions) {
ModuleUtils::iterDefinedFunctions(*module, [&](Function* func) {
roots.emplace_back(ModuleElementKind::Function, func->name);
});
}
bool exportsMemory = false;
bool exportsTable = false;
for (auto& curr : module->exports) {
if (curr->kind == ExternalKind::Function) {
roots.emplace_back(ModuleElementKind::Function, curr->value);
} else if (curr->kind == ExternalKind::Global) {
roots.emplace_back(ModuleElementKind::Global, curr->value);
} else if (curr->kind == ExternalKind::Event) {
roots.emplace_back(ModuleElementKind::Event, curr->value);
} else if (curr->kind == ExternalKind::Memory) {
exportsMemory = true;
} else if (curr->kind == ExternalKind::Table) {
exportsTable = true;
}
}
bool importsMemory = false;
bool importsTable = false;
if (module->memory.imported()) {
importsMemory = true;
}
if (module->table.imported()) {
importsTable = true;
}
for (auto& segment : module->table.segments) {
for (auto& curr : segment.data) {
roots.emplace_back(ModuleElementKind::Function, curr);
}
}
ReachabilityAnalyzer analyzer(module, roots);
module->removeFunctions([&](Function* curr) {
return analyzer.reachable.count(
ModuleElement(ModuleElementKind::Function, curr->name)) == 0;
});
module->removeGlobals([&](Global* curr) {
return analyzer.reachable.count(
ModuleElement(ModuleElementKind::Global, curr->name)) == 0;
});
module->removeEvents([&](Event* curr) {
return analyzer.reachable.count(
ModuleElement(ModuleElementKind::Event, curr->name)) == 0;
});
if (!exportsMemory && !analyzer.usesMemory) {
if (!importsMemory) {
module->memory.segments.clear();
}
if (module->memory.segments.empty()) {
module->memory.exists = false;
module->memory.module = module->memory.base = Name();
module->memory.initial = 0;
module->memory.max = 0;
}
}
if (!exportsTable && !analyzer.usesTable) {
if (!importsTable) {
module->table.segments.clear();
}
if (module->table.segments.empty()) {
module->table.exists = false;
module->table.module = module->table.base = Name();
module->table.initial = 0;
module->table.max = 0;
}
}
}
};
// Creates the pass that removes every unreachable function, global and event,
// plus an unneeded memory/table.
Pass* createRemoveUnusedModuleElementsPass() {
  const bool rootAllFunctions = false;
  return new RemoveUnusedModuleElements(rootAllFunctions);
}

// Like createRemoveUnusedModuleElementsPass, but treats every defined function
// as a root. No defined function is ever removed; only the other elements
// (including unreferenced imported functions) can be.
Pass* createRemoveUnusedNonFunctionModuleElementsPass() {
  const bool rootAllFunctions = true;
  return new RemoveUnusedModuleElements(rootAllFunctions);
}
}