#ifndef wasm_ir_branch_h
#define wasm_ir_branch_h
#include "ir/iteration.h"
#include "wasm-traversal.h"
#include "wasm.h"
namespace wasm::BranchUtils {
// Returns false if any direct child of `expr` has unreachable type (execution
// cannot get past that child to the branch itself), and true otherwise.
inline bool isBranchReachable(Expression* expr) {
  for (Expression* child : ChildIterator(expr)) {
    if (child->type != Type::unreachable) {
      continue;
    }
    return false;
  }
  return true;
}
// Calls `func(Name&)` on each scope name *use* of `expr` itself (children are
// not visited): the fields that wasm-delegations-fields.def reports through
// DELEGATE_FIELD_SCOPE_NAME_USE, such as branch targets. `func` receives a
// reference to the field, so it may rewrite the name in place.
template<typename T> void operateOnScopeNameUses(Expression* expr, T func) {
// DELEGATE_ID selects the expression kind; DELEGATE_START exposes the concrete
// class as `cast` for the field hooks below.
#define DELEGATE_ID expr->_id
#define DELEGATE_START(id) [[maybe_unused]] auto* cast = expr->cast<id>();
#define DELEGATE_GET_FIELD(id, field) cast->field
// The only active hook: forward each scope name use to `func`.
#define DELEGATE_FIELD_SCOPE_NAME_USE(id, field) func(cast->field);
// Every other kind of field is ignored.
#define DELEGATE_FIELD_CHILD(id, field)
#define DELEGATE_FIELD_INT(id, field)
#define DELEGATE_FIELD_LITERAL(id, field)
#define DELEGATE_FIELD_NAME(id, field)
#define DELEGATE_FIELD_SCOPE_NAME_DEF(id, field)
#define DELEGATE_FIELD_TYPE(id, field)
#define DELEGATE_FIELD_HEAPTYPE(id, field)
#define DELEGATE_FIELD_ADDRESS(id, field)
#include "wasm-delegations-fields.def"
}
// Like operateOnScopeNameUses, but also reports the type sent to each target:
// `func(Name& name, Type sentType)`.
//  - Break / Switch: the value's type, or Type::none when there is no value.
//  - BrOn: BrOn::getSentType().
//  - TryTable / Resume: one expression may target several labels; the type is
//    taken from the parallel `sentTypes` vector at each index whose
//    destination label equals `name`.
//  - Try / Rethrow: `func` is NOT called (no sent type is reported).
template<typename T>
void operateOnScopeNameUsesAndSentTypes(Expression* expr, T func) {
  operateOnScopeNameUses(expr, [&](Name& name) {
    if (auto* br = expr->dynCast<Break>()) {
      func(name, br->value ? br->value->type : Type::none);
    } else if (auto* sw = expr->dynCast<Switch>()) {
      func(name, sw->value ? sw->value->type : Type::none);
    } else if (auto* brOn = expr->dynCast<BrOn>()) {
      func(name, brOn->getSentType());
    } else if (auto* tt = expr->dynCast<TryTable>()) {
      for (Index i = 0; i < tt->catchTags.size(); i++) {
        if (tt->catchDests[i] == name) {
          func(name, tt->sentTypes[i]);
        }
      }
    } else if (auto* r = expr->dynCast<Resume>()) {
      // `name` is a label (a scope name use), so it must be matched against
      // the handler *block* labels. The handler tags name tags, not labels,
      // and comparing against them never finds the right destination.
      for (Index i = 0; i < r->handlerBlocks.size(); i++) {
        if (r->handlerBlocks[i] == name) {
          func(name, r->sentTypes[i]);
        }
      }
    } else {
      // The remaining users of scope names are expected to be Try and Rethrow,
      // for which no sent type is reported.
      assert(expr->is<Try>() || expr->is<Rethrow>());
    }
  });
}
// Like operateOnScopeNameUses, but also reports the expression whose value is
// sent to each target: `func(Name& name, Expression* value)`.
//  - Break / Switch: their `value` child (nullptr when there is none).
//  - BrOn: its `ref` operand.
//  - TryTable / Resume: always nullptr; no single value expression is
//    reported for their targets.
//  - Try / Rethrow: `func` is NOT called.
template<typename T>
void operateOnScopeNameUsesAndSentValues(Expression* expr, T func) {
  operateOnScopeNameUses(expr, [&](Name& name) {
    if (auto* br = expr->dynCast<Break>()) {
      func(name, br->value);
    } else if (auto* sw = expr->dynCast<Switch>()) {
      func(name, sw->value);
    } else if (auto* brOn = expr->dynCast<BrOn>()) {
      func(name, brOn->ref);
    } else if (expr->is<TryTable>() || expr->is<Resume>()) {
      func(name, nullptr);
    } else {
      // The remaining users of scope names are expected to be Try and Rethrow,
      // for which no sent value is reported.
      assert(expr->is<Try>() || expr->is<Rethrow>());
    }
  });
}
// Calls `func(Name&)` on each scope name *definition* of `expr` itself
// (children are not visited): the fields that wasm-delegations-fields.def
// reports through DELEGATE_FIELD_SCOPE_NAME_DEF, i.e. the labels that branches
// may target. The reported name may be empty for an unnamed construct, which
// is why callers test `name.is()` where it matters.
template<typename T> void operateOnScopeNameDefs(Expression* expr, T func) {
// DELEGATE_ID selects the expression kind; DELEGATE_START exposes the concrete
// class as `cast` for the field hooks below.
#define DELEGATE_ID expr->_id
#define DELEGATE_START(id) [[maybe_unused]] auto* cast = expr->cast<id>();
#define DELEGATE_GET_FIELD(id, field) cast->field
// The only active hook. The call is terminated here, matching
// DELEGATE_FIELD_SCOPE_NAME_USE in operateOnScopeNameUses, instead of relying
// on whatever the .def file expands next to supply the semicolon.
#define DELEGATE_FIELD_SCOPE_NAME_DEF(id, field) func(cast->field);
// Every other kind of field is ignored.
#define DELEGATE_FIELD_CHILD(id, field)
#define DELEGATE_FIELD_INT(id, field)
#define DELEGATE_FIELD_LITERAL(id, field)
#define DELEGATE_FIELD_NAME(id, field)
#define DELEGATE_FIELD_TYPE(id, field)
#define DELEGATE_FIELD_HEAPTYPE(id, field)
#define DELEGATE_FIELD_ADDRESS(id, field)
#define DELEGATE_FIELD_SCOPE_NAME_USE(id, field)
#include "wasm-delegations-fields.def"
}
// Ordered set of scope names (labels), used as the result type of the
// name-collecting helpers in this namespace.
using NameSet = std::set<Name>;
// Returns the distinct scope names used (targeted) by `expr` itself. Its
// children are not inspected.
inline NameSet getUniqueTargets(Expression* expr) {
  NameSet targets;
  auto collect = [&targets](Name& target) { targets.insert(target); };
  operateOnScopeNameUses(expr, collect);
  return targets;
}
// Rewrites every scope name use of `branch` (only the expression itself, not
// its children) that equals `from` into `to`. Returns true if at least one
// name was rewritten.
inline bool replacePossibleTarget(Expression* branch, Name from, Name to) {
  bool replaced = false;
  auto rewrite = [&](Name& target) {
    if (target == from) {
      target = to;
      replaced = true;
    }
  };
  operateOnScopeNameUses(branch, rewrite);
  return replaced;
}
// Walks all of `ast` and, on Try and Rethrow expressions only, rewrites each
// scope name use equal to `from` into `to`. Uses on any other expression kind
// are left untouched.
inline void replaceExceptionTargets(Expression* ast, Name from, Name to) {
  struct ExceptionTargetReplacer
    : public PostWalker<ExceptionTargetReplacer,
                        UnifiedExpressionVisitor<ExceptionTargetReplacer>> {
    Name from, to;
    ExceptionTargetReplacer(Name from, Name to) : from(from), to(to) {}
    void visitExpression(Expression* curr) {
      bool usesExceptionScope = curr->is<Try>() || curr->is<Rethrow>();
      if (!usesExceptionScope) {
        return;
      }
      operateOnScopeNameUses(curr, [&](Name& name) {
        if (name == from) {
          name = to;
        }
      });
    }
  };
  ExceptionTargetReplacer replacer(from, to);
  replacer.walk(ast);
}
// Walks all of `ast` and, on expressions for which Properties::isBranch holds,
// rewrites each scope name use equal to `from` into `to`. Uses on other
// expression kinds are left untouched.
inline void replaceBranchTargets(Expression* ast, Name from, Name to) {
  struct BranchTargetReplacer
    : public PostWalker<BranchTargetReplacer,
                        UnifiedExpressionVisitor<BranchTargetReplacer>> {
    Name from, to;
    BranchTargetReplacer(Name from, Name to) : from(from), to(to) {}
    void visitExpression(Expression* curr) {
      if (!Properties::isBranch(curr)) {
        return;
      }
      operateOnScopeNameUses(curr, [&](Name& name) {
        if (name == from) {
          name = to;
        }
      });
    }
  };
  BranchTargetReplacer replacer(from, to);
  replacer.walk(ast);
}
// Returns the scope names used somewhere in `ast` that are not defined by an
// expression enclosing that use within `ast`. These are the targets through
// which branches may leave `ast`.
inline NameSet getExitingBranches(Expression* ast) {
  struct ExitScanner
    : public PostWalker<ExitScanner, UnifiedExpressionVisitor<ExitScanner>> {
    NameSet targets;
    void visitExpression(Expression* curr) {
      // Post-order walk: the uses inside `curr`'s children have already been
      // recorded. A name that `curr` defines is therefore resolved here.
      auto resolveDef = [&](Name& defined) {
        if (defined.is()) {
          targets.erase(defined);
        }
      };
      operateOnScopeNameDefs(curr, resolveDef);
      // `curr`'s own uses are added only after erasing its definitions, so a
      // use on `curr` is never cancelled by a name `curr` itself defines.
      auto recordUse = [&](Name& used) { targets.insert(used); };
      operateOnScopeNameUses(curr, recordUse);
    }
  };
  ExitScanner scanner;
  scanner.walk(ast);
  return scanner.targets;
}
// Returns every non-empty scope name defined by any expression within `ast`
// (including `ast` itself).
inline NameSet getBranchTargets(Expression* ast) {
  struct DefCollector
    : public PostWalker<DefCollector, UnifiedExpressionVisitor<DefCollector>> {
    NameSet targets;
    void visitExpression(Expression* curr) {
      operateOnScopeNameDefs(curr, [&](Name& defined) {
        if (!defined.is()) {
          return;
        }
        targets.insert(defined);
      });
    }
  };
  DefCollector collector;
  collector.walk(ast);
  return collector.targets;
}
// Returns whether some expression within `ast` (including `ast` itself)
// defines the scope name `target`. An empty `target` always yields false.
inline bool hasBranchTarget(Expression* ast, Name target) {
  if (!target.is()) {
    return false;
  }
  struct DefFinder
    : public PostWalker<DefFinder, UnifiedExpressionVisitor<DefFinder>> {
    Name wanted;
    bool has = false;
    void visitExpression(Expression* curr) {
      operateOnScopeNameDefs(
        curr, [&](Name& defined) { has = has || defined == wanted; });
    }
  };
  DefFinder finder;
  finder.wanted = target;
  finder.walk(ast);
  return finder.has;
}
// Returns the scope name defined by `curr` itself, or a default-constructed
// Name if it reports none. If several are reported, the last one wins.
inline Name getDefinedName(Expression* curr) {
  Name defined;
  auto remember = [&defined](Name& name) { defined = name; };
  operateOnScopeNameDefs(curr, remember);
  return defined;
}
// Returns the value expression that `curr` sends to its target, as reported by
// operateOnScopeNameUsesAndSentValues. Returns nullptr if nothing is reported.
// If several targets are reported, the value of the last one wins.
inline Expression* getSentValue(Expression* curr) {
  Expression* sent = nullptr;
  auto record = [&sent](Name, Expression* value) { sent = value; };
  operateOnScopeNameUsesAndSentValues(curr, record);
  return sent;
}
// Walks a tree counting the uses of one scope name (`target`) and collecting
// the types sent along them. Counting goes through
// operateOnScopeNameUsesAndSentTypes, so uses for which no type is reported
// (Try / Rethrow) are not counted.
struct BranchSeeker
  : public PostWalker<BranchSeeker, UnifiedExpressionVisitor<BranchSeeker>> {
  // The scope name being searched for.
  Name target;
  // Number of matching uses seen so far.
  Index found = 0;
  // Distinct sent types among those uses (Type::none for a value-less
  // break/switch).
  std::unordered_set<Type> types;

  BranchSeeker(Name target) : target(target) {}

  // Records one more use of `target` that sends a value of type `newType`.
  void noteFound(Type newType) {
    found++;
    types.insert(newType);
  }

  void visitExpression(Expression* curr) {
    auto onUse = [&](Name& name, Type sentType) {
      if (name == target) {
        noteFound(sentType);
      }
    };
    operateOnScopeNameUsesAndSentTypes(curr, onUse);
  }

  // Returns whether `tree` contains at least one counted use of `target`.
  // An empty `target` always yields false.
  static bool has(Expression* tree, Name target) {
    return count(tree, target) != 0;
  }

  // Returns how many counted uses of `target` appear in `tree`. An empty
  // `target` always yields 0.
  static Index count(Expression* tree, Name target) {
    if (!target.is()) {
      return 0;
    }
    BranchSeeker seeker(target);
    seeker.walk(tree);
    return seeker.found;
  }
};
// Walks a tree and gathers every scope name used anywhere in it.
struct BranchAccumulator
  : public PostWalker<BranchAccumulator,
                      UnifiedExpressionVisitor<BranchAccumulator>> {
  // All scope names used so far during the walk.
  NameSet branches;

  void visitExpression(Expression* curr) {
    operateOnScopeNameUses(curr, [&](Name& name) { branches.insert(name); });
  }

  // Convenience: returns the scope names used anywhere in `tree`.
  static NameSet get(Expression* tree) {
    BranchAccumulator accumulator;
    accumulator.walk(tree);
    return accumulator.branches;
  }
};
// Caches, per expression, the set of scope names used anywhere in that
// expression's subtree, including the expression itself.
//
// getBranches() consumes the cached sets of an expression's direct children:
// they are moved into the parent's entry and then erased. The cache is
// therefore meant to be queried bottom-up. After a parent has been queried,
// references previously returned for its children must not be used, and a
// later query on such a child recomputes its set.
class BranchSeekerCache {
  std::unordered_map<Expression*, NameSet> branches;

public:
  // Returns the scope names used within `curr`'s subtree. The returned
  // reference stays valid until `curr`'s entry is consumed by a query on one
  // of its parents.
  const NameSet& getBranches(Expression* curr) {
    auto iter = branches.find(curr);
    if (iter != branches.end()) {
      return iter->second;
    }
    NameSet currBranches;
    // Merges `moreBranches` into `currBranches`. When nothing has been
    // collected yet, it steals the contents with a swap instead of copying.
    auto add = [&](NameSet& moreBranches) {
      if (currBranches.empty()) {
        currBranches.swap(moreBranches);
      } else {
        currBranches.insert(moreBranches.begin(), moreBranches.end());
      }
    };
    for (auto child : ChildIterator(curr)) {
      auto childIter = branches.find(child);
      if (childIter != branches.end()) {
        // Reuse the child's cached result, consuming its entry.
        add(childIter->second);
        branches.erase(childIter);
      } else {
        // Not cached: gather the child's subtree with a full walk.
        BranchAccumulator childBranches;
        childBranches.walk(child);
        add(childBranches.branches);
      }
    }
    // Finally, include the names used by `curr` itself.
    auto selfBranches = getUniqueTargets(curr);
    add(selfBranches);
    return branches[curr] = std::move(currBranches);
  }

  // Returns whether the scope name `target` is used within `curr`'s subtree.
  bool hasBranch(Expression* curr, Name target) {
    bool result = getBranches(curr).count(target);
#ifdef BRANCH_UTILS_DEBUG
    // Cross-check the cached answer against an uncached walk.
    // NOTE(review): BranchSeeker only counts uses that report a sent type (not
    // Try / Rethrow), while getBranches() includes every scope name use, so
    // this may fire for delegate/rethrow targets.
    assert(result == BranchSeeker::has(curr, target));
#endif
    return result;
  }
};
// Indexes a tree by scope name in one walk, done at construction: which
// expression defines each name, and which expressions use each name.
struct BranchTargets {
  BranchTargets(Expression* expr) { inner.walk(expr); }

  // Returns the expression that defines `name` in the indexed tree, or nullptr
  // if there is none. It uses find() rather than operator[] so that looking up
  // an unknown name does not insert a null entry into the index.
  Expression* getTarget(Name name) {
    auto iter = inner.targets.find(name);
    if (iter == inner.targets.end()) {
      return nullptr;
    }
    return iter->second;
  }

  // Returns a copy of the set of expressions that use `name`. The set is
  // empty if there are none.
  std::unordered_set<Expression*> getBranches(Name name) {
    auto iter = inner.branches.find(name);
    if (iter != inner.branches.end()) {
      return iter->second;
    }
    return {};
  }

private:
  struct Inner : public PostWalker<Inner, UnifiedExpressionVisitor<Inner>> {
    void visitExpression(Expression* curr) {
      operateOnScopeNameDefs(curr, [&](Name name) {
        if (name.is()) {
          targets[name] = curr;
        }
      });
      operateOnScopeNameUses(curr, [&](Name& name) {
        if (name.is()) {
          branches[name].insert(curr);
        }
      });
    }
    // Defining expression per non-empty name. If a name is defined more than
    // once, the last definition visited (in post-order) wins.
    std::map<Name, Expression*> targets;
    // All expressions that use each non-empty name.
    std::map<Name, std::unordered_set<Expression*>> branches;
  } inner;
};
}
#endif