#ifndef wasm_ir_branch_h
#define wasm_ir_branch_h
#include "ir/iteration.h"
#include "wasm-traversal.h"
#include "wasm.h"
namespace wasm {
namespace BranchUtils {
// A br/br_if can actually be taken only if neither its sent value nor its
// condition (when present) is unreachable.
inline bool isBranchReachable(Break* br) {
  if (br->value && br->value->type == Type::unreachable) {
    return false;
  }
  if (br->condition && br->condition->type == Type::unreachable) {
    return false;
  }
  return true;
}
// A br_table can actually be taken only if its optional value and its
// (mandatory) index condition are both reachable.
inline bool isBranchReachable(Switch* sw) {
  if (sw->value && sw->value->type == Type::unreachable) {
    return false;
  }
  return sw->condition->type != Type::unreachable;
}
// A br_on_exn can actually be taken only if its exnref operand is reachable.
inline bool isBranchReachable(BrOnExn* br) {
  const bool exnrefUnreachable = br->exnref->type == Type::unreachable;
  return !exnrefUnreachable;
}
// Dispatches to the type-specific overload. |expr| must be a Break, Switch or
// BrOnExn; any other expression kind hits WASM_UNREACHABLE.
inline bool isBranchReachable(Expression* expr) {
  if (auto* brk = expr->dynCast<Break>()) {
    return isBranchReachable(brk);
  }
  if (auto* table = expr->dynCast<Switch>()) {
    return isBranchReachable(table);
  }
  if (auto* onExn = expr->dynCast<BrOnExn>()) {
    return isBranchReachable(onExn);
  }
  WASM_UNREACHABLE("unexpected expression type");
}
// Set of label names. std::set is ordered, so iterating it is deterministic.
using NameSet = std::set<Name>;
// A br/br_if always names exactly one target label.
inline NameSet getUniqueTargets(Break* br) {
  NameSet labels;
  labels.insert(br->name);
  return labels;
}
// Collects every distinct label a br_table can jump to: its default plus all
// table entries, with duplicates collapsed by the set.
inline NameSet getUniqueTargets(Switch* sw) {
  NameSet labels;
  labels.insert(sw->default_);
  for (const auto& entry : sw->targets) {
    labels.insert(entry);
  }
  return labels;
}
// A br_on_exn names exactly one target label.
inline NameSet getUniqueTargets(BrOnExn* br) {
  NameSet labels;
  labels.insert(br->name);
  return labels;
}
// Returns the distinct branch targets named directly by |expr| (not its
// children). Non-branch expressions yield an empty set.
inline NameSet getUniqueTargets(Expression* expr) {
  if (auto* brk = expr->dynCast<Break>()) {
    return getUniqueTargets(brk);
  } else if (auto* table = expr->dynCast<Switch>()) {
    return getUniqueTargets(table);
  } else if (auto* onExn = expr->dynCast<BrOnExn>()) {
    return getUniqueTargets(onExn);
  } else {
    return {};
  }
}
// Rewrites every occurrence of label |from| to |to| in the branch |branch|.
// Returns true if at least one label was rewritten. |branch| must be a Break,
// Switch or BrOnExn; anything else hits WASM_UNREACHABLE.
inline bool replacePossibleTarget(Expression* branch, Name from, Name to) {
  if (auto* brk = branch->dynCast<Break>()) {
    if (brk->name != from) {
      return false;
    }
    brk->name = to;
    return true;
  }
  if (auto* table = branch->dynCast<Switch>()) {
    bool changed = false;
    // Every matching entry and the default are all rewritten; no early exit.
    for (auto& entry : table->targets) {
      if (entry == from) {
        entry = to;
        changed = true;
      }
    }
    if (table->default_ == from) {
      table->default_ = to;
      changed = true;
    }
    return changed;
  }
  if (auto* onExn = branch->dynCast<BrOnExn>()) {
    if (onExn->name != from) {
      return false;
    }
    onExn->name = to;
    return true;
  }
  WASM_UNREACHABLE("unexpected expression type");
}
// Returns the labels targeted by branches inside |ast| that are not defined by
// a named Block or Loop within |ast|. Because the walk is post-order, a
// label's definition is visited after the branches nested inside it, and at
// that point it is removed from the set.
inline NameSet getExitingBranches(Expression* ast) {
  struct Scanner : public PostWalker<Scanner> {
    NameSet targets;

    // Drops |label| from the pending set if it is an actual (non-empty) name.
    void forgetLabel(Name label) {
      if (label.is()) {
        targets.erase(label);
      }
    }

    void visitBreak(Break* curr) { targets.insert(curr->name); }
    void visitSwitch(Switch* curr) {
      auto entries = getUniqueTargets(curr);
      targets.insert(entries.begin(), entries.end());
    }
    void visitBrOnExn(BrOnExn* curr) { targets.insert(curr->name); }
    void visitBlock(Block* curr) { forgetLabel(curr->name); }
    void visitLoop(Loop* curr) { forgetLabel(curr->name); }
  };
  Scanner scanner;
  scanner.walk(ast);
  return scanner.targets;
}
// Returns the names of every named Block and Loop inside |ast| (including
// |ast| itself), i.e. all labels that branches in |ast| could target.
inline NameSet getBranchTargets(Expression* ast) {
  struct Scanner : public PostWalker<Scanner> {
    NameSet targets;

    // Records |label| if it is an actual (non-empty) name.
    void recordLabel(Name label) {
      if (label.is()) {
        targets.insert(label);
      }
    }

    void visitBlock(Block* curr) { recordLabel(curr->name); }
    void visitLoop(Loop* curr) { recordLabel(curr->name); }
  };
  Scanner scanner;
  scanner.walk(ast);
  return scanner.targets;
}
// Post-order walker that counts the branches to a single label |target| and
// records the type of the values they send.
struct BranchSeeker : public PostWalker<BranchSeeker> {
  // The label whose incoming branches are being counted.
  Name target;
  // Number of branch edges to |target| found so far. A Switch contributes once
  // for every entry (table entries and default) that names |target|.
  Index found = 0;
  // Type of the most recent non-unreachable value sent to |target|, or
  // unreachable if every value seen so far was unreachable. Only assigned
  // once |found| becomes non-zero.
  Type valueType;

  BranchSeeker(Name target) : target(target) {}

  // Records one branch edge sending |value|; a null |value| counts as none.
  void noteFound(Expression* value) {
    if (!value) {
      noteFound(Type::none);
      return;
    }
    noteFound(value->type);
  }

  // Records one branch edge sending a value of |type|.
  void noteFound(Type type) {
    ++found;
    if (type != Type::unreachable) {
      valueType = type;
    } else if (found == 1) {
      // The first edge sends an unreachable value: start from unreachable so
      // a later reachable edge can still override it.
      valueType = Type::unreachable;
    }
  }

  void visitBreak(Break* curr) {
    if (curr->name != target) {
      return;
    }
    noteFound(curr->value);
  }

  void visitSwitch(Switch* curr) {
    for (const auto& entry : curr->targets) {
      if (entry == target) {
        noteFound(curr->value);
      }
    }
    if (curr->default_ == target) {
      noteFound(curr->value);
    }
  }

  void visitBrOnExn(BrOnExn* curr) {
    if (curr->name != target) {
      return;
    }
    noteFound(curr->sent);
  }

  // Number of branch edges to |target| within |tree|; 0 for an empty name.
  static Index count(Expression* tree, Name target) {
    if (!target.is()) {
      return 0;
    }
    BranchSeeker seeker(target);
    seeker.walk(tree);
    return seeker.found;
  }

  // Whether |tree| contains at least one branch to |target|.
  static bool has(Expression* tree, Name target) {
    return count(tree, target) > 0;
  }
};
// Walks an expression tree and gathers every label targeted by any branch in
// it (Break, Switch or BrOnExn), in a single set.
struct BranchAccumulator
  : public PostWalker<BranchAccumulator,
                      UnifiedExpressionVisitor<BranchAccumulator>> {
  NameSet branches;

  // Called for every expression; non-branches contribute nothing.
  void visitExpression(Expression* curr) {
    for (const auto& label : getUniqueTargets(curr)) {
      branches.insert(label);
    }
  }
};
// Caches, per expression, the set of labels targeted by branches in that
// expression (the node itself plus all descendants). When a parent is
// computed, the cached sets of its direct children are reused and then erased
// from the cache. The parent's entry supersedes them, which keeps the cache
// small when queries proceed bottom-up. Children without a cached entry are
// scanned with a full BranchAccumulator walk instead.
class BranchSeekerCache {
  std::unordered_map<Expression*, NameSet> branches;

public:
  // Returns the set of labels targeted by branches in |curr| and its subtree.
  // The returned reference points into the cache and is invalidated once a
  // later getBranches() call on an ancestor consumes (erases) this entry.
  const NameSet& getBranches(Expression* curr) {
    auto cached = branches.find(curr);
    if (cached != branches.end()) {
      return cached->second;
    }
    NameSet currBranches;
    // Merges |moreBranches| into currBranches. While currBranches is still
    // empty a swap is used to avoid copying, so |moreBranches| must not be
    // relied on afterwards.
    auto add = [&](NameSet& moreBranches) {
      if (currBranches.empty()) {
        currBranches.swap(moreBranches);
      } else {
        currBranches.insert(moreBranches.begin(), moreBranches.end());
      }
    };
    for (auto child : ChildIterator(curr)) {
      auto childIter = branches.find(child);
      if (childIter != branches.end()) {
        // Reuse the child's cached result; the parent's entry replaces it.
        add(childIter->second);
        branches.erase(childIter);
      } else {
        BranchAccumulator childBranches;
        childBranches.walk(child);
        add(childBranches.branches);
      }
    }
    auto selfBranches = getUniqueTargets(curr);
    add(selfBranches);
    return branches[curr] = std::move(currBranches);
  }

  // Whether |curr| (or anything inside it) contains a branch to |target|.
  bool hasBranch(Expression* curr, Name target) {
    bool result = getBranches(curr).count(target) > 0;
#ifdef BRANCH_UTILS_DEBUG
    // Cross-check the cached answer against an uncached full walk.
    assert(result == BranchSeeker::has(curr, target));
#endif
    return result;
  }
};
}
}
#endif