#ifndef cfg_traversal_h
#define cfg_traversal_h
#include "wasm-traversal.h"
#include "wasm.h"
namespace wasm {
// Walker that builds a control flow graph (CFG) of basic blocks while walking
// a function. SubType is the concrete (CRTP) walker; Contents is whatever
// per-block payload the subclass wants to record in each BasicBlock.
//
// Edges are created by the do* callbacks that scan() schedules around the
// control-flow expressions. Branches whose target position is not yet known
// are buffered in `branches` and resolved when the target Block ends or the
// target Loop finishes.
template<typename SubType, typename VisitorType, typename Contents>
struct CFGWalker : public ControlFlowWalker<SubType, VisitorType> {
  // A node of the graph. `out` holds successors and `in` predecessors;
  // link() always adds both directions together.
  struct BasicBlock {
    Contents contents;
    std::vector<BasicBlock*> out, in;
  };

  // The block that is current when the function body begins to be walked.
  // nullptr until doWalkFunction() runs.
  BasicBlock* entry = nullptr;

  // Allocation hook. startBasicBlock() calls it through SubType, so a
  // subclass may shadow it. Ownership of the result is taken by basicBlocks.
  BasicBlock* makeBasicBlock() { return new BasicBlock(); }

  // Owns every block created for the current function, in creation order.
  std::vector<std::unique_ptr<BasicBlock>> basicBlocks;
  // The first block of every loop entered, in walk order.
  std::vector<BasicBlock*> loopTops;

  // The block currently being appended to, or nullptr while walking code no
  // edge can reach (e.g. after an unconditional Break, Switch, Return,
  // Unreachable, Throw or Rethrow). link() ignores null endpoints.
  BasicBlock* currBasicBlock = nullptr;

  // Pending branch edges, keyed by branch target expression (the labelled
  // Block or Loop returned by findBreakTarget). Each value lists the blocks
  // that branch there; doEndBlock/doEndLoop link them and erase the entry.
  std::map<Expression*, std::vector<BasicBlock*>> branches;
  // For each open If: the block that ended with the condition (pushed by
  // doStartIfTrue) and, when there is an else arm, the last block of the
  // ifTrue arm on top of it (pushed by doStartIfFalse).
  std::vector<BasicBlock*> ifStack;
  // First block of each Loop currently being walked.
  std::vector<BasicBlock*> loopStack;
  // Last block of the try body, for each Try whose catch body is being walked.
  std::vector<BasicBlock*> tryStack;
  // Pre-created first block of the catch body, for each Try whose body is
  // being walked; calls and throws in the body link to the innermost one.
  std::vector<BasicBlock*> catchStack;

  // Creates a new block via the SubType allocation hook, takes ownership of
  // it, and makes it the current block.
  void startBasicBlock() {
    currBasicBlock = static_cast<SubType*>(this)->makeBasicBlock();
    basicBlocks.push_back(std::unique_ptr<BasicBlock>(currBasicBlock));
  }

  // Marks the code that follows as unreachable: there is no current block.
  void startUnreachableBlock() { currBasicBlock = nullptr; }

  static void doStartUnreachableBlock(SubType* self, Expression** currp) {
    self->startUnreachableBlock();
  }

  // Adds the edge from -> to (both directions of the adjacency lists).
  // No-op when either end is null, i.e. unreachable.
  void link(BasicBlock* from, BasicBlock* to) {
    if (!from || !to) {
      return;
    }
    from->out.push_back(to);
    to->in.push_back(from);
  }

  // Runs after a Block's children. If anything branched to the block's
  // label, the code after the block is a join point: start a new block fed
  // by the fallthrough and by every recorded branch origin.
  static void doEndBlock(SubType* self, Expression** currp) {
    auto* curr = (*currp)->cast<Block>();
    if (!curr->name.is()) {
      return;
    }
    auto iter = self->branches.find(curr);
    if (iter == self->branches.end()) {
      return;
    }
    auto& origins = iter->second;
    if (origins.empty()) {
      return;
    }
    auto* last = self->currBasicBlock;
    self->startBasicBlock();
    self->link(last, self->currBasicBlock); // fallthrough
    for (auto* origin : origins) {
      self->link(origin, self->currBasicBlock);
    }
    self->branches.erase(iter);
  }

  // After the condition: start the ifTrue arm and remember the block that
  // ended with the condition (it also feeds ifFalse or the join).
  static void doStartIfTrue(SubType* self, Expression** currp) {
    auto* last = self->currBasicBlock;
    self->startBasicBlock();
    self->link(last, self->currBasicBlock);
    self->ifStack.push_back(last);
  }

  // After ifTrue: remember ifTrue's last block, then start the ifFalse arm
  // fed from the condition block (the entry below it on ifStack).
  static void doStartIfFalse(SubType* self, Expression** currp) {
    self->ifStack.push_back(self->currBasicBlock);
    self->startBasicBlock();
    self->link(self->ifStack[self->ifStack.size() - 2], self->currBasicBlock);
  }

  // Joins the arms of an If in a new block.
  static void doEndIf(SubType* self, Expression** currp) {
    auto* last = self->currBasicBlock;
    self->startBasicBlock();
    // `last` is the end of ifFalse if present, otherwise the end of ifTrue.
    self->link(last, self->currBasicBlock);
    // Top of ifStack: the end of ifTrue when there is an else arm, otherwise
    // the condition block (the path where the condition is false).
    self->link(self->ifStack.back(), self->currBasicBlock);
    self->ifStack.pop_back();
    if ((*currp)->cast<If>()->ifFalse) {
      // Also drop the condition block pushed by doStartIfTrue.
      self->ifStack.pop_back();
    }
  }

  // Before a Loop's children: start the loop-top block, which is the target
  // of branches to the loop's label.
  static void doStartLoop(SubType* self, Expression** currp) {
    auto* last = self->currBasicBlock;
    self->startBasicBlock();
    self->loopTops.push_back(self->currBasicBlock);
    self->link(last, self->currBasicBlock);
    self->loopStack.push_back(self->currBasicBlock);
  }

  // After a Loop's children: start the block after the loop (fed by the
  // fallthrough), and link every recorded branch to the loop top.
  static void doEndLoop(SubType* self, Expression** currp) {
    auto* last = self->currBasicBlock;
    self->startBasicBlock();
    self->link(last, self->currBasicBlock); // fallthrough
    auto* curr = (*currp)->cast<Loop>();
    if (curr->name.is()) {
      auto iter = self->branches.find(curr);
      if (iter != self->branches.end()) {
        auto* loopStart = self->loopStack.back();
        for (auto* origin : iter->second) {
          self->link(origin, loopStart);
        }
        self->branches.erase(iter);
      }
    }
    self->loopStack.pop_back();
  }

  // Records the current block as a branch origin for the target. A
  // conditional break may fall through into a new block; an unconditional
  // one ends reachable code.
  static void doEndBreak(SubType* self, Expression** currp) {
    auto* curr = (*currp)->cast<Break>();
    self->branches[self->findBreakTarget(curr->name)].push_back(
      self->currBasicBlock);
    if (curr->condition) {
      auto* last = self->currBasicBlock;
      self->startBasicBlock();
      self->link(last, self->currBasicBlock);
    } else {
      self->startUnreachableBlock();
    }
  }

  // Records one branch origin per distinct label (targets plus default), so
  // a label repeated in the table does not produce duplicate edges.
  static void doEndSwitch(SubType* self, Expression** currp) {
    auto* curr = (*currp)->cast<Switch>();
    std::set<Name> seen;
    for (Name target : curr->targets) {
      if (!seen.count(target)) {
        self->branches[self->findBreakTarget(target)].push_back(
          self->currBasicBlock);
        seen.insert(target);
      }
    }
    if (!seen.count(curr->default_)) {
      self->branches[self->findBreakTarget(curr->default_)].push_back(
        self->currBasicBlock);
    }
    self->startUnreachableBlock();
  }

  // Inside a try body, splits the block after a call and links it both to
  // the continuation and to the innermost catch body's first block. Outside
  // any try body the current block simply continues.
  static void doEndCall(SubType* self, Expression** currp) {
    if (!self->catchStack.empty()) {
      auto* last = self->currBasicBlock;
      self->startBasicBlock();
      self->link(last, self->currBasicBlock);
      self->link(last, self->catchStack.back());
    }
  }

  // Pre-creates the catch body's first block so the try body can link to
  // it, then restores the current block so the try body continues there.
  static void doStartTry(SubType* self, Expression** currp) {
    auto* last = self->currBasicBlock;
    self->startBasicBlock();
    self->catchStack.push_back(self->currBasicBlock);
    self->currBasicBlock = last;
  }

  // Remembers the try body's last block and continues in the pre-created
  // catch block.
  static void doStartCatch(SubType* self, Expression** currp) {
    self->tryStack.push_back(self->currBasicBlock);
    self->currBasicBlock = self->catchStack.back();
    self->catchStack.pop_back();
  }

  // Joins the end of the catch body and the end of the try body.
  static void doEndTry(SubType* self, Expression** currp) {
    auto* last = self->currBasicBlock;
    self->startBasicBlock();
    self->link(last, self->currBasicBlock);
    self->link(self->tryStack.back(), self->currBasicBlock);
    self->tryStack.pop_back();
  }

  // Links to the innermost catch body (if inside a try body), then ends
  // reachable code.
  static void doEndThrow(SubType* self, Expression** currp) {
    if (!self->catchStack.empty()) {
      self->link(self->currBasicBlock, self->catchStack.back());
    }
    self->startUnreachableBlock();
  }

  // Records a branch origin for the target and, since the branch may not be
  // taken, continues in a new block fed by the current one.
  static void doEndBrOnExn(SubType* self, Expression** currp) {
    auto* curr = (*currp)->cast<BrOnExn>();
    self->branches[self->findBreakTarget(curr->name)].push_back(
      self->currBasicBlock);
    auto* last = self->currBasicBlock;
    self->startBasicBlock();
    self->link(last, self->currBasicBlock);
  }

  // Schedules the CFG callbacks around the normal walk. Tasks are LIFO, so
  // "end" callbacks pushed before the base scan run after the expression's
  // children, and doStartLoop (pushed after it) runs before them. If and Try
  // schedule their children themselves to interleave the arm callbacks.
  static void scan(SubType* self, Expression** currp) {
    Expression* curr = *currp;
    switch (curr->_id) {
      case Expression::Id::BlockId: {
        self->pushTask(SubType::doEndBlock, currp);
        break;
      }
      case Expression::Id::IfId: {
        // Runs as: condition, doStartIfTrue, ifTrue, [doStartIfFalse,
        // ifFalse], doEndIf.
        self->pushTask(SubType::doEndIf, currp);
        auto* ifFalse = curr->cast<If>()->ifFalse;
        if (ifFalse) {
          self->pushTask(SubType::scan, &curr->cast<If>()->ifFalse);
          self->pushTask(SubType::doStartIfFalse, currp);
        }
        self->pushTask(SubType::scan, &curr->cast<If>()->ifTrue);
        self->pushTask(SubType::doStartIfTrue, currp);
        self->pushTask(SubType::scan, &curr->cast<If>()->condition);
        return; // children already scheduled; skip the base scan
      }
      case Expression::Id::LoopId: {
        self->pushTask(SubType::doEndLoop, currp);
        break;
      }
      case Expression::Id::BreakId: {
        self->pushTask(SubType::doEndBreak, currp);
        break;
      }
      case Expression::Id::SwitchId: {
        self->pushTask(SubType::doEndSwitch, currp);
        break;
      }
      case Expression::Id::ReturnId: {
        self->pushTask(SubType::doStartUnreachableBlock, currp);
        break;
      }
      case Expression::Id::UnreachableId: {
        self->pushTask(SubType::doStartUnreachableBlock, currp);
        break;
      }
      case Expression::Id::CallId:
      case Expression::Id::CallIndirectId: {
        self->pushTask(SubType::doEndCall, currp);
        break;
      }
      case Expression::Id::TryId: {
        // Runs as: doStartTry, body, doStartCatch, catchBody, doEndTry.
        self->pushTask(SubType::doEndTry, currp);
        self->pushTask(SubType::scan, &curr->cast<Try>()->catchBody);
        self->pushTask(SubType::doStartCatch, currp);
        self->pushTask(SubType::scan, &curr->cast<Try>()->body);
        self->pushTask(SubType::doStartTry, currp);
        return; // children already scheduled; skip the base scan
      }
      case Expression::Id::ThrowId:
      case Expression::Id::RethrowId: {
        self->pushTask(SubType::doEndThrow, currp);
        break;
      }
      case Expression::Id::BrOnExnId: {
        self->pushTask(SubType::doEndBrOnExn, currp);
        break;
      }
      default: {}
    }
    ControlFlowWalker<SubType, VisitorType>::scan(self, currp);
    switch (curr->_id) {
      case Expression::Id::LoopId: {
        self->pushTask(SubType::doStartLoop, currp);
        break;
      }
      default: {}
    }
  }

  // Builds the CFG for `func`, discarding any graph from a previous walk.
  void doWalkFunction(Function* func) {
    basicBlocks.clear();
    // loopTops points into basicBlocks, which was just freed; drop the
    // stale pointers so a reused walker never exposes dangling blocks.
    loopTops.clear();
    debugIds.clear();
    startBasicBlock();
    entry = currBasicBlock;
    ControlFlowWalker<SubType, VisitorType>::doWalkFunction(func);
    assert(branches.size() == 0);
    assert(ifStack.size() == 0);
    assert(loopStack.size() == 0);
    assert(tryStack.size() == 0);
    assert(catchStack.size() == 0);
  }

  // Returns the set of blocks reachable from `entry` along `out` edges.
  // Empty if no function has been walked yet.
  std::unordered_set<BasicBlock*> findLiveBlocks() {
    std::unordered_set<BasicBlock*> alive;
    if (!entry) {
      return alive;
    }
    std::vector<BasicBlock*> work{entry};
    while (!work.empty()) {
      auto* curr = work.back();
      work.pop_back();
      if (!alive.insert(curr).second) {
        continue; // already visited
      }
      for (auto* out : curr->out) {
        if (!alive.count(out)) {
          work.push_back(out);
        }
      }
    }
    return alive;
  }

  // Removes every edge touching a block not in `alive`. Dead blocks lose
  // all edges; live blocks keep only edges to other live blocks.
  void unlinkDeadBlocks(const std::unordered_set<BasicBlock*>& alive) {
    auto isDead = [&alive](BasicBlock* other) { return !alive.count(other); };
    for (auto& block : basicBlocks) {
      if (!alive.count(block.get())) {
        block->in.clear();
        block->out.clear();
        continue;
      }
      block->in.erase(
        std::remove_if(block->in.begin(), block->in.end(), isDead),
        block->in.end());
      block->out.erase(
        std::remove_if(block->out.begin(), block->out.end(), isDead),
        block->out.end());
    }
  }

  // Stable numeric ids for printing, assigned 0, 1, 2, ... in creation order.
  std::map<BasicBlock*, size_t> debugIds;

  void generateDebugIds() {
    if (debugIds.size() > 0) {
      return;
    }
    for (auto& block : basicBlocks) {
      // Read the size before operator[] inserts, so ids are 0-based
      // regardless of the compiler's evaluation order (unspecified pre-C++17).
      size_t id = debugIds.size();
      debugIds[block.get()] = id;
    }
  }

  // Prints every block's contents and outgoing edges to stdout, asserting
  // that in/out lists mirror each other and contain no duplicates.
  void dumpCFG(const std::string& message) {
    std::cout << "<==\nCFG [" << message << "]:\n";
    generateDebugIds();
    for (auto& block : basicBlocks) {
      assert(debugIds.count(block.get()) > 0);
      std::cout << "  block " << debugIds[block.get()] << ":\n";
      block->contents.dump(static_cast<SubType*>(this)->getFunction());
      for (auto& in : block->in) {
        assert(debugIds.count(in) > 0);
        // Every predecessor must list this block as a successor.
        assert(std::find(in->out.begin(), in->out.end(), block.get()) !=
               in->out.end());
      }
      for (auto& out : block->out) {
        assert(debugIds.count(out) > 0);
        std::cout << "    out: " << debugIds[out] << "\n";
        // Every successor must list this block as a predecessor.
        assert(std::find(out->in.begin(), out->in.end(), block.get()) !=
               out->in.end());
      }
      checkDuplicates(block->in);
      checkDuplicates(block->out);
    }
    std::cout << "==>\n";
  }

private:
  // Asserts that no block appears twice in an edge list.
  void checkDuplicates(std::vector<BasicBlock*>& list) {
    std::unordered_set<BasicBlock*> seen;
    for (auto* curr : list) {
      assert(seen.count(curr) == 0);
      seen.insert(curr);
    }
  }

  // Removes one occurrence of `toRemove` from `list` (order not preserved).
  // Missing entries are a logic error, including when the list has a single
  // element, which is no longer cleared blindly.
  void removeLink(std::vector<BasicBlock*>& list, BasicBlock* toRemove) {
    auto iter = std::find(list.begin(), list.end(), toRemove);
    if (iter == list.end()) {
      WASM_UNREACHABLE("not found");
    }
    // Swap-and-pop: overwrite with the last element, then shrink.
    *iter = list.back();
    list.pop_back();
  }
};
}
#endif