#include <ir/block-utils.h>
#include <ir/branch-utils.h>
#include <ir/type-updating.h>
#include <pass.h>
#include <vector>
#include <wasm-builder.h>
#include <wasm.h>
namespace wasm {
// Dead code elimination pass.
//
// Walks each function body in post-order while tracking whether the position
// currently being visited is reachable by normal control flow. Any expression
// scanned while `reachable` is false is converted in place into an
// `unreachable` (see scan()). Expressions whose children already have
// `unreachable` type are replaced by those children, with drops preserving
// any operands evaluated before them, and block/if/try nodes are offered to
// the TypeUpdater so their types can become `unreachable` where possible.
//
// NOTE(review): only expression kinds with a visit* method below are
// simplified when one of their children is unreachable; others (e.g. SIMD
// ops, memory.init/copy/fill, atomic wait/notify, tuple and GC ops) are only
// touched here when they themselves sit in dead code.
struct DeadCodeElimination
: public WalkerPass<PostWalker<DeadCodeElimination>> {
// Functions are processed independently, so the pass may run on functions in
// parallel; create() hands out a fresh instance (presumably one per worker).
bool isFunctionParallel() override { return true; }
Pass* create() override { return new DeadCodeElimination; }
// Keeps parent links and expression types consistent as this pass replaces,
// removes, and adds nodes.
TypeUpdater typeUpdater;
// Replaces the expression currently being visited with `expression` and
// informs typeUpdater of the replacement. Does nothing if `expression` is the
// current node. Returns `expression`.
Expression* replaceCurrent(Expression* expression) {
auto* old = getCurrent();
if (old == expression) {
return expression;
}
super::replaceCurrent(expression);
typeUpdater.noteReplacement(old, expression);
return expression;
}
// Whether the code position currently being walked is reachable. It is reset
// to true at the start of each function, cleared by constructs that end
// control flow (br, br_table, return, unreachable, throw, ...), and restored
// at join points (named blocks, if, try).
bool reachable;
// Resets state for a new function, lets typeUpdater record the body's
// structure before anything is changed, then runs this pass's walk.
void doWalkFunction(Function* func) {
reachable = true;
typeUpdater.walk(func->body);
walk(func->body);
}
// Labels targeted by at least one reachable branch that have not yet been
// matched with their block/loop. Entries are erased when the labelled
// block/loop is visited, so the set is empty at the end of each function
// (checked in visitFunction).
std::set<Name> reachableBreaks;
// Records a branch to `name`, but only if the branch itself is reachable:
// a branch sitting in dead code cannot make its target's continuation
// reachable.
void addBreak(Name name) {
if (reachable) {
reachableBreaks.insert(name);
}
}
// True if `child` is present and has unreachable type. Null-safe, for
// optional children such as br values/conditions and return values.
bool isDead(Expression* child) {
return child && child->type == Type::unreachable;
}
// Same check as isDead(), but `child` must be non-null.
bool isUnreachable(Expression* child) {
return child->type == Type::unreachable;
}
// br / br_if.
void visitBreak(Break* curr) {
// The value is evaluated before the condition, so an unreachable value
// makes the whole branch unreachable: replace it with the value.
if (isDead(curr->value)) {
replaceCurrent(curr->value);
return;
}
// An unreachable condition means the branch is never taken. Keep the
// value's evaluation (as a drop) and end in the unreachable condition.
if (isDead(curr->condition)) {
if (curr->value) {
auto* block = getModule()->allocator.alloc<Block>();
block->list.resize(2);
block->list[0] = drop(curr->value);
block->list[1] = curr->condition;
block->finalize(curr->type);
replaceCurrent(block);
} else {
replaceCurrent(curr->condition);
}
return;
}
addBreak(curr->name);
// Only an unconditional br ends reachability; br_if may fall through.
if (!curr->condition) {
reachable = false;
}
}
// br_table. This follows the same shape as visitBreak, except that the index
// (`condition`) is always present. Because br_table always transfers control,
// code after it is unreachable.
void visitSwitch(Switch* curr) {
if (isDead(curr->value)) {
replaceCurrent(curr->value);
return;
}
if (isUnreachable(curr->condition)) {
if (curr->value) {
auto* block = getModule()->allocator.alloc<Block>();
block->list.resize(2);
block->list[0] = drop(curr->value);
block->list[1] = curr->condition;
block->finalize(curr->type);
replaceCurrent(block);
} else {
replaceCurrent(curr->condition);
}
return;
}
// Every possible target (including the default) becomes a live branch.
for (auto target : curr->targets) {
addBreak(target);
}
addBreak(curr->default_);
reachable = false;
}
// return: replaced by its value if that value is unreachable; otherwise it
// ends reachability.
void visitReturn(Return* curr) {
if (isDead(curr->value)) {
replaceCurrent(curr->value);
return;
}
reachable = false;
}
void visitUnreachable(Unreachable* curr) { reachable = false; }
// Blocks: trim dead trailing items, re-establish reachability from branches
// to the block's label, then collapse or retype the block where possible.
void visitBlock(Block* curr) {
auto& list = curr->list;
// If the end of the block is unreachable (ignoring branches to its label),
// everything after the first unreachable-typed item is dead, so truncate
// just after it. The last item is not inspected since nothing follows it.
if (!reachable && list.size() > 1) {
for (Index i = 0; i < list.size() - 1; i++) {
if (list[i]->type == Type::unreachable) {
list.resize(i + 1);
break;
}
}
}
// A reachable branch to this block's label makes the code after the block
// reachable again. The label is now out of scope, so forget it.
if (curr->name.is()) {
reachable = reachable || reachableBreaks.count(curr->name);
reachableBreaks.erase(curr->name);
}
// A block whose only item is unreachable can be reduced to its contents.
if (list.size() == 1 && isUnreachable(list[0])) {
replaceCurrent(
BlockUtils::simplifyToContentsWithPossibleTypeChange(curr, this));
} else {
// Otherwise ask typeUpdater whether the block's type can now become
// unreachable, which lets parents simplify further in their visitors.
typeUpdater.maybeUpdateTypeToUnreachable(curr);
}
}
// Loops: branches to a loop label jump back to the loop's start, so they do
// not make the code after the loop reachable; the label is simply forgotten.
// A loop whose body is unreachable and contains no branch to the loop's
// label never iterates, so it is replaced by its body.
void visitLoop(Loop* curr) {
if (curr->name.is()) {
reachableBreaks.erase(curr->name);
}
if (isUnreachable(curr->body) &&
!BranchUtils::BranchSeeker::has(curr->body, curr->name)) {
replaceCurrent(curr->body);
return;
}
}
// Saved reachability states for the fork/join in if and try. The code after
// an if/try is reachable if either path through it falls through. So:
// - the state at the fork is saved;
// - one path's end state is stashed while the other path is walked;
// - the two states are OR-ed together in visitIf/visitTry.
std::vector<bool> ifStack;
std::vector<bool> tryStack;
// Runs after the if condition: save the state on entry to the arms.
static void doAfterIfCondition(DeadCodeElimination* self,
Expression** currp) {
self->ifStack.push_back(self->reachable);
}
// Runs between ifTrue and ifFalse, and only when there is an else arm.
// It stashes ifTrue's end state and restores the entry state for ifFalse.
static void doAfterIfElseTrue(DeadCodeElimination* self, Expression** currp) {
assert((*currp)->cast<If>()->ifFalse);
bool reachableBefore = self->ifStack.back();
self->ifStack.pop_back();
self->ifStack.push_back(self->reachable);
self->reachable = reachableBefore;
}
// Join point for if. The stack top holds one of two states:
// - with an else arm, ifTrue's end state;
// - without one, the post-condition state, since skipping ifTrue is a path.
// If the condition itself is unreachable, the if is replaced by it.
void visitIf(If* curr) {
reachable = reachable || ifStack.back();
ifStack.pop_back();
if (isUnreachable(curr->condition)) {
replaceCurrent(curr->condition);
}
typeUpdater.maybeUpdateTypeToUnreachable(curr);
}
// Runs before the try body: save the state on entry to the try.
static void doBeforeTryBody(DeadCodeElimination* self, Expression** currp) {
self->tryStack.push_back(self->reachable);
}
// Runs between the try body and the catch body. It stashes the body's end
// state and walks the catch body starting from the try's entry state.
static void doAfterTryBody(DeadCodeElimination* self, Expression** currp) {
bool reachableBefore = self->tryStack.back();
self->tryStack.pop_back();
self->tryStack.push_back(self->reachable);
self->reachable = reachableBefore;
}
// Join point for try: reachable if either the body or the catch body falls
// through.
void visitTry(Try* curr) {
reachable = reachable || tryStack.back();
tryStack.pop_back();
typeUpdater.maybeUpdateTypeToUnreachable(curr);
}
void visitThrow(Throw* curr) { reachable = false; }
void visitRethrow(Rethrow* curr) { reachable = false; }
// br_on_exn is a conditional branch: it records a live break but does not
// end reachability. It is replaced by its operand if that is unreachable.
void visitBrOnExn(BrOnExn* curr) {
if (isDead(curr->exnref)) {
replaceCurrent(curr->exnref);
return;
}
addBreak(curr->name);
}
// Custom scan with two jobs:
// (1) While the current position is unreachable, convert the expression in
//     place into an `unreachable`; its children are not walked.
// (2) Schedule if/try with extra tasks that implement the reachability
//     fork/join described above.
static void scan(DeadCodeElimination* self, Expression** currp) {
auto* curr = *currp;
if (!self->reachable) {
// DELEGATE(T) converts `curr` (of class T) into an Unreachable in place
// and keeps typeUpdater in sync, in this order:
// - remember the parent;
// - note the removal of the old subtree;
// - convert;
// - note the new node under the same parent.
// It ends with `break` to leave the switch.
#define DELEGATE(CLASS_TO_VISIT) \
{ \
auto* parent = self->typeUpdater.parents[curr]; \
self->typeUpdater.noteRecursiveRemoval(curr); \
ExpressionManipulator::convert<CLASS_TO_VISIT, Unreachable>( \
static_cast<CLASS_TO_VISIT*>(curr)); \
self->typeUpdater.noteAddition(curr, parent); \
break; \
}
switch (curr->_id) {
case Expression::Id::BlockId:
DELEGATE(Block);
case Expression::Id::IfId:
DELEGATE(If);
case Expression::Id::LoopId:
DELEGATE(Loop);
case Expression::Id::BreakId:
DELEGATE(Break);
case Expression::Id::SwitchId:
DELEGATE(Switch);
case Expression::Id::CallId:
DELEGATE(Call);
case Expression::Id::CallIndirectId:
DELEGATE(CallIndirect);
case Expression::Id::LocalGetId:
DELEGATE(LocalGet);
case Expression::Id::LocalSetId:
DELEGATE(LocalSet);
case Expression::Id::GlobalGetId:
DELEGATE(GlobalGet);
case Expression::Id::GlobalSetId:
DELEGATE(GlobalSet);
case Expression::Id::LoadId:
DELEGATE(Load);
case Expression::Id::StoreId:
DELEGATE(Store);
case Expression::Id::ConstId:
DELEGATE(Const);
case Expression::Id::UnaryId:
DELEGATE(Unary);
case Expression::Id::BinaryId:
DELEGATE(Binary);
case Expression::Id::SelectId:
DELEGATE(Select);
case Expression::Id::DropId:
DELEGATE(Drop);
case Expression::Id::ReturnId:
DELEGATE(Return);
case Expression::Id::MemorySizeId:
DELEGATE(MemorySize);
case Expression::Id::MemoryGrowId:
DELEGATE(MemoryGrow);
case Expression::Id::NopId:
DELEGATE(Nop);
// Already an unreachable: nothing to convert.
case Expression::Id::UnreachableId:
break;
case Expression::Id::AtomicCmpxchgId:
DELEGATE(AtomicCmpxchg);
case Expression::Id::AtomicRMWId:
DELEGATE(AtomicRMW);
case Expression::Id::AtomicWaitId:
DELEGATE(AtomicWait);
case Expression::Id::AtomicNotifyId:
DELEGATE(AtomicNotify);
case Expression::Id::AtomicFenceId:
DELEGATE(AtomicFence);
case Expression::Id::SIMDExtractId:
DELEGATE(SIMDExtract);
case Expression::Id::SIMDReplaceId:
DELEGATE(SIMDReplace);
case Expression::Id::SIMDShuffleId:
DELEGATE(SIMDShuffle);
case Expression::Id::SIMDTernaryId:
DELEGATE(SIMDTernary);
case Expression::Id::SIMDShiftId:
DELEGATE(SIMDShift);
case Expression::Id::SIMDLoadId:
DELEGATE(SIMDLoad);
case Expression::Id::MemoryInitId:
DELEGATE(MemoryInit);
case Expression::Id::DataDropId:
DELEGATE(DataDrop);
case Expression::Id::MemoryCopyId:
DELEGATE(MemoryCopy);
case Expression::Id::MemoryFillId:
DELEGATE(MemoryFill);
case Expression::Id::PopId:
DELEGATE(Pop);
case Expression::Id::RefNullId:
DELEGATE(RefNull);
case Expression::Id::RefIsNullId:
DELEGATE(RefIsNull);
case Expression::Id::RefFuncId:
DELEGATE(RefFunc);
case Expression::Id::RefEqId:
DELEGATE(RefEq);
case Expression::Id::TryId:
DELEGATE(Try);
case Expression::Id::ThrowId:
DELEGATE(Throw);
case Expression::Id::RethrowId:
DELEGATE(Rethrow);
case Expression::Id::BrOnExnId:
DELEGATE(BrOnExn);
case Expression::Id::TupleMakeId:
DELEGATE(TupleMake);
case Expression::Id::TupleExtractId:
DELEGATE(TupleExtract);
case Expression::Id::I31NewId:
DELEGATE(I31New);
case Expression::Id::I31GetId:
DELEGATE(I31Get);
case Expression::Id::RefTestId:
DELEGATE(RefTest);
case Expression::Id::RefCastId:
DELEGATE(RefCast);
case Expression::Id::BrOnCastId:
DELEGATE(BrOnCast);
case Expression::Id::RttCanonId:
DELEGATE(RttCanon);
case Expression::Id::RttSubId:
DELEGATE(RttSub);
case Expression::Id::StructNewId:
DELEGATE(StructNew);
case Expression::Id::StructGetId:
DELEGATE(StructGet);
case Expression::Id::StructSetId:
DELEGATE(StructSet);
case Expression::Id::ArrayNewId:
DELEGATE(ArrayNew);
case Expression::Id::ArrayGetId:
DELEGATE(ArrayGet);
case Expression::Id::ArraySetId:
DELEGATE(ArraySet);
case Expression::Id::ArrayLenId:
DELEGATE(ArrayLen);
// Sentinel ids, never real expressions.
case Expression::Id::InvalidId:
WASM_UNREACHABLE("unimp");
case Expression::Id::NumExpressionIds:
WASM_UNREACHABLE("unimp");
}
#undef DELEGATE
// The converted node has no children, so there is nothing further to walk.
return;
}
// Tasks run in reverse push order. For an if, the effective order is:
// condition, doAfterIfCondition, ifTrue, [doAfterIfElseTrue, ifFalse,]
// doVisitIf.
if (curr->is<If>()) {
self->pushTask(DeadCodeElimination::doVisitIf, currp);
if (curr->cast<If>()->ifFalse) {
self->pushTask(DeadCodeElimination::scan, &curr->cast<If>()->ifFalse);
self->pushTask(DeadCodeElimination::doAfterIfElseTrue, currp);
}
self->pushTask(DeadCodeElimination::scan, &curr->cast<If>()->ifTrue);
self->pushTask(DeadCodeElimination::doAfterIfCondition, currp);
self->pushTask(DeadCodeElimination::scan, &curr->cast<If>()->condition);
} else if (curr->is<Try>()) {
// For a try, the effective order is: doBeforeTryBody, body,
// doAfterTryBody, catchBody, doVisitTry.
self->pushTask(DeadCodeElimination::doVisitTry, currp);
self->pushTask(DeadCodeElimination::scan, &curr->cast<Try>()->catchBody);
self->pushTask(DeadCodeElimination::doAfterTryBody, currp);
self->pushTask(DeadCodeElimination::scan, &curr->cast<Try>()->body);
self->pushTask(DeadCodeElimination::doBeforeTryBody, currp);
} else {
super::scan(self, currp);
}
}
// Wraps `toDrop` in a Drop so its value is evaluated and discarded. An
// unreachable expression needs no drop and is returned unchanged.
Expression* drop(Expression* toDrop) {
if (toDrop->type == Type::unreachable) {
return toDrop;
}
return Builder(*getModule()).makeDrop(toDrop);
}
// Shared handling for call-like nodes with an `operands` list. If some
// operand is unreachable, the call is replaced by that operand, preceded by
// drops of the operands evaluated before it. Returns the replacement, or
// `curr` if nothing changed.
template<typename T> Expression* handleCall(T* curr) {
for (Index i = 0; i < curr->operands.size(); i++) {
if (isUnreachable(curr->operands[i])) {
if (i > 0) {
auto* block = getModule()->allocator.alloc<Block>();
Index newSize = i + 1;
block->list.resize(newSize);
Index j = 0;
// drop() leaves the last (unreachable) operand undropped.
for (; j < newSize; j++) {
block->list[j] = drop(curr->operands[j]);
}
block->finalize(curr->type);
return replaceCurrent(block);
} else {
return replaceCurrent(curr->operands[i]);
}
}
}
return curr;
}
// A returning call (isReturn) leaves the function, ending reachability.
void visitCall(Call* curr) {
handleCall(curr);
if (curr->isReturn) {
reachable = false;
}
}
// Operands are handled as for calls. If that did not replace the node, an
// unreachable target (evaluated after the operands) turns the node into a
// block: the dropped operands, then the target.
void visitCallIndirect(CallIndirect* curr) {
if (handleCall(curr) != curr) {
return;
}
if (isUnreachable(curr->target)) {
auto* block = getModule()->allocator.alloc<Block>();
for (auto* operand : curr->operands) {
block->list.push_back(drop(operand));
}
block->list.push_back(curr->target);
block->finalize(curr->type);
replaceCurrent(block);
}
if (curr->isReturn) {
reachable = false;
}
}
// Fixed-operand counterpart of handleCall. `list` holds the current node's
// operands, expected in evaluation order. If one of them is unreachable,
// the current node is replaced as follows:
// - if it is the first operand, by that operand directly;
// - otherwise, by a block (finalized with `type`) that drops the earlier
//   operands and ends in the unreachable one.
void blockifyReachableOperands(std::vector<Expression*>&& list, Type type) {
for (size_t i = 0; i < list.size(); ++i) {
auto* elem = list[i];
if (isUnreachable(elem)) {
auto* replacement = elem;
if (i > 0) {
auto* block = getModule()->allocator.alloc<Block>();
for (size_t j = 0; j < i; ++j) {
block->list.push_back(drop(list[j]));
}
block->list.push_back(list[i]);
block->finalize(type);
replacement = block;
}
replaceCurrent(replacement);
return;
}
}
}
// Simple expressions: each passes its operands, in evaluation order, to
// blockifyReachableOperands.
void visitLocalSet(LocalSet* curr) {
blockifyReachableOperands({curr->value}, curr->type);
}
void visitGlobalSet(GlobalSet* curr) {
blockifyReachableOperands({curr->value}, curr->type);
}
void visitLoad(Load* curr) {
blockifyReachableOperands({curr->ptr}, curr->type);
}
void visitStore(Store* curr) {
blockifyReachableOperands({curr->ptr, curr->value}, curr->type);
}
void visitAtomicRMW(AtomicRMW* curr) {
blockifyReachableOperands({curr->ptr, curr->value}, curr->type);
}
void visitAtomicCmpxchg(AtomicCmpxchg* curr) {
blockifyReachableOperands({curr->ptr, curr->expected, curr->replacement},
curr->type);
}
void visitUnary(Unary* curr) {
blockifyReachableOperands({curr->value}, curr->type);
}
void visitBinary(Binary* curr) {
blockifyReachableOperands({curr->left, curr->right}, curr->type);
}
void visitSelect(Select* curr) {
blockifyReachableOperands({curr->ifTrue, curr->ifFalse, curr->condition},
curr->type);
}
void visitDrop(Drop* curr) {
blockifyReachableOperands({curr->value}, curr->type);
}
// memory.size has no operands, so there is nothing to simplify.
void visitMemorySize(MemorySize* curr) {}
void visitMemoryGrow(MemoryGrow* curr) {
blockifyReachableOperands({curr->delta}, curr->type);
}
void visitRefIsNull(RefIsNull* curr) {
blockifyReachableOperands({curr->value}, curr->type);
}
void visitRefEq(RefEq* curr) {
blockifyReachableOperands({curr->left, curr->right}, curr->type);
}
// Every label recorded in reachableBreaks must have been consumed by its
// block/loop by the end of the function.
void visitFunction(Function* curr) { assert(reachableBreaks.size() == 0); }
};
// Factory for the dead code elimination pass. It returns a freshly allocated,
// value-initialized instance, and the caller takes ownership of the pointer.
Pass* createDeadCodeEliminationPass() {
  return new DeadCodeElimination{};
}
}