#ifndef wasm_wasm_traversal_h
#define wasm_wasm_traversal_h
#include "support/small_vector.h"
#include "support/threads.h"
#include "wasm.h"
namespace wasm {
template<typename SubType, typename ReturnType = void> struct Visitor {
  // Fallback handlers, one per expression class listed in
  // wasm-delegations.def. Each ignores its argument and returns a
  // value-initialized ReturnType; SubType declares its own visitX to handle
  // a particular class.
#define DELEGATE(CLASS_TO_VISIT)                                               \
  ReturnType visit##CLASS_TO_VISIT(CLASS_TO_VISIT*) { return ReturnType(); }

#include "wasm-delegations.def"

  // Fallback handlers for module-level entities; these are no-ops as well.
  ReturnType visitExport(Export*) { return ReturnType(); }
  ReturnType visitGlobal(Global*) { return ReturnType(); }
  ReturnType visitFunction(Function*) { return ReturnType(); }
  ReturnType visitTable(Table*) { return ReturnType(); }
  ReturnType visitElementSegment(ElementSegment*) { return ReturnType(); }
  ReturnType visitMemory(Memory*) { return ReturnType(); }
  ReturnType visitDataSegment(DataSegment*) { return ReturnType(); }
  ReturnType visitTag(Tag*) { return ReturnType(); }
  ReturnType visitModule(Module*) { return ReturnType(); }

  // Dispatch `curr` (which must be non-null) to the visitX method matching
  // its expression id. The call is made through SubType, so a visitX that
  // SubType declares is picked over the fallbacks above without any virtual
  // dispatch.
  ReturnType visit(Expression* curr) {
    assert(curr);
    SubType* derived = static_cast<SubType*>(this);
    switch (curr->_id) {
#define DELEGATE(CLASS_TO_VISIT)                                               \
  case Expression::Id::CLASS_TO_VISIT##Id:                                     \
    return derived->visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr))

#include "wasm-delegations.def"

      default:
        WASM_UNREACHABLE("unexpected expression type");
    }
  }
};
template<typename SubType, typename ReturnType = void>
struct OverriddenVisitor : public Visitor<SubType, ReturnType> {
  // Replaces every per-expression fallback of Visitor with a trap. A handler
  // below is only instantiated when it is actually used, e.g. when
  // Visitor::visit dispatches to a visitX that SubType did not declare. In
  // that case &SubType::visitX names this very function, the static_assert
  // fails, and the omission is reported at compile time. The
  // WASM_UNREACHABLE that follows is only a backstop.
#define DELEGATE(CLASS_TO_VISIT)                                               \
  ReturnType visit##CLASS_TO_VISIT(CLASS_TO_VISIT*) {                          \
    static_assert(                                                             \
      &SubType::visit##CLASS_TO_VISIT !=                                       \
        &OverriddenVisitor<SubType, ReturnType>::visit##CLASS_TO_VISIT,        \
      "Derived class must implement visit" #CLASS_TO_VISIT);                   \
    WASM_UNREACHABLE("Derived class must implement visit" #CLASS_TO_VISIT);    \
  }

#include "wasm-delegations.def"
};
template<typename SubType, typename ReturnType = void>
struct UnifiedExpressionVisitor : public Visitor<SubType, ReturnType> {
  // Catch-all handler that every expression class is routed to. The default
  // does nothing; SubType declares its own visitExpression to handle all
  // expressions in one place.
  ReturnType visitExpression(Expression*) { return ReturnType(); }

  // Forward each per-class handler to SubType::visitExpression. The node is
  // passed with the concrete pointer type it was received with.
#define DELEGATE(CLASS_TO_VISIT)                                               \
  ReturnType visit##CLASS_TO_VISIT(CLASS_TO_VISIT* node) {                     \
    SubType* derived = static_cast<SubType*>(this);                            \
    return derived->visitExpression(node);                                     \
  }

#include "wasm-delegations.def"
};
template<typename SubType, typename VisitorType>
struct Walker : public VisitorType {
  // Replace the expression currently being walked with `expression` and
  // return `expression`. The write goes through the slot of the task that is
  // running, so whatever held the old expression (a parent field, or the
  // root passed to walk()) now holds the replacement.
  //
  // If a function is set and has debug locations, the replacement inherits
  // the location of the expression it replaces, unless it already has one of
  // its own. The replaced expression's entry is left in place.
  Expression* replaceCurrent(Expression* expression) {
    if (currFunction) {
      auto& debugLocations = currFunction->debugLocations;
      // Nothing to copy without debug info; never overwrite a location the
      // replacement already carries.
      if (!debugLocations.empty() && !debugLocations.count(expression)) {
        auto* curr = getCurrent();
        auto iter = debugLocations.find(curr);
        if (iter != debugLocations.end()) {
          debugLocations[expression] = iter->second;
        }
      }
    }
    return *replacep = expression;
  }

  // The expression in the slot of the task currently running.
  Expression* getCurrent() { return *replacep; }
  // The slot itself; writing through it replaces the current expression.
  Expression** getCurrentPointer() { return replacep; }

  // The module / function set for the current walk (nullptr if none is set).
  Module* getModule() { return currModule; }
  Function* getFunction() { return currFunction; }

  // Walk a global's initializer, then visit the global.
  void walkGlobal(Global* global) {
    walk(global->init);
    static_cast<SubType*>(this)->visitGlobal(global);
  }

  // With `func` set as the current function, walk it via
  // SubType::doWalkFunction and then visit it. The current module is not
  // changed.
  void walkFunction(Function* func) {
    setFunction(func);
    static_cast<SubType*>(this)->doWalkFunction(func);
    static_cast<SubType*>(this)->visitFunction(func);
    setFunction(nullptr);
  }

  // Nothing inside a tag is walked; the tag is only visited.
  void walkTag(Tag* tag) { static_cast<SubType*>(this)->visitTag(tag); }

  // Like walkFunction, but also sets `module` as the current module for the
  // duration of the walk and clears it afterwards.
  void walkFunctionInModule(Function* func, Module* module) {
    setModule(module);
    setFunction(func);
    static_cast<SubType*>(this)->doWalkFunction(func);
    static_cast<SubType*>(this)->visitFunction(func);
    setFunction(nullptr);
    setModule(nullptr);
  }

  // Default per-function traversal: walk the body. SubType may supply its
  // own doWalkFunction.
  void doWalkFunction(Function* func) { walk(func->body); }

  // Walk the segment's offset (only if the segment names a table) and each of
  // its item expressions, then visit the segment.
  void walkElementSegment(ElementSegment* segment) {
    if (segment->table.is()) {
      walk(segment->offset);
    }
    // Bind each item by reference: walk() takes its root by reference so
    // that replaceCurrent() on the root is stored back. Iterating by value
    // would hand walk() a temporary copy and silently drop such a
    // replacement.
    for (auto*& expr : segment->data) {
      walk(expr);
    }
    static_cast<SubType*>(this)->visitElementSegment(segment);
  }

  // Tables contain nothing that is walked here; they are only visited.
  void walkTable(Table* table) {
    static_cast<SubType*>(this)->visitTable(table);
  }

  // Walk the offset of a non-passive segment, then visit the segment.
  void walkDataSegment(DataSegment* segment) {
    if (!segment->isPassive) {
      walk(segment->offset);
    }
    static_cast<SubType*>(this)->visitDataSegment(segment);
  }

  // Memories contain nothing that is walked here; they are only visited.
  void walkMemory(Memory* memory) {
    static_cast<SubType*>(this)->visitMemory(memory);
  }

  // Walk a whole module via SubType::doWalkModule, then visit the module.
  // `module` is the current module throughout and is cleared afterwards.
  void walkModule(Module* module) {
    setModule(module);
    static_cast<SubType*>(this)->doWalkModule(module);
    static_cast<SubType*>(this)->visitModule(module);
    setModule(nullptr);
  }

  // Default module traversal, kind by kind in the order below. Exports are
  // only visited. Imported globals, functions and tags are visited without
  // being walked; defined ones are walked. Tables, element segments,
  // memories and data segments are always walked.
  void doWalkModule(Module* module) {
    // Dispatch statically through the SubType.
    SubType* self = static_cast<SubType*>(this);
    for (auto& curr : module->exports) {
      self->visitExport(curr.get());
    }
    for (auto& curr : module->globals) {
      if (curr->imported()) {
        self->visitGlobal(curr.get());
      } else {
        self->walkGlobal(curr.get());
      }
    }
    for (auto& curr : module->functions) {
      if (curr->imported()) {
        self->visitFunction(curr.get());
      } else {
        self->walkFunction(curr.get());
      }
    }
    for (auto& curr : module->tags) {
      if (curr->imported()) {
        self->visitTag(curr.get());
      } else {
        self->walkTag(curr.get());
      }
    }
    for (auto& curr : module->tables) {
      self->walkTable(curr.get());
    }
    for (auto& curr : module->elementSegments) {
      self->walkElementSegment(curr.get());
    }
    for (auto& curr : module->memories) {
      self->walkMemory(curr.get());
    }
    for (auto& curr : module->dataSegments) {
      self->walkDataSegment(curr.get());
    }
  }

  // Walk only the code that lives outside functions:
  //   - initializers of non-imported globals,
  //   - element segment offsets (when present) and items,
  //   - data segment offsets (when present).
  // Module-level entities are not visited, and no current function is set.
  void walkModuleCode(Module* module) {
    setModule(module);
    SubType* self = static_cast<SubType*>(this);
    for (auto& curr : module->globals) {
      if (!curr->imported()) {
        self->walk(curr->init);
      }
    }
    for (auto& curr : module->elementSegments) {
      if (curr->offset) {
        self->walk(curr->offset);
      }
      // By reference, so a replaced root expression is stored back into the
      // segment (see walkElementSegment).
      for (auto*& item : curr->data) {
        self->walk(item);
      }
    }
    for (auto& curr : module->dataSegments) {
      if (curr->offset) {
        self->walk(curr->offset);
      }
    }
    setModule(nullptr);
  }

  // Signature of a scheduled step: a static handler applied to an expression
  // slot.
  using TaskFunc = void (*)(SubType*, Expression**);

  // One pending step of the walk: call `func` on the slot `currp`.
  struct Task {
    // Null-initialized so that a default-constructed Task never holds
    // indeterminate pointers. NOTE(review): SmallVector may default-construct
    // the slots of its inline storage; confirm in support/small_vector.h.
    TaskFunc func = nullptr;
    Expression** currp = nullptr;
    Task() = default;
    Task(TaskFunc func, Expression** currp) : func(func), currp(currp) {}
  };

  // Schedule `func` on a slot that must hold a non-null expression.
  void pushTask(TaskFunc func, Expression** currp) {
    assert(*currp);
    stack.emplace_back(func, currp);
  }

  // Schedule `func` only if the slot holds an expression; null slots are
  // skipped.
  void maybePushTask(TaskFunc func, Expression** currp) {
    if (*currp) {
      stack.emplace_back(func, currp);
    }
  }

  // Remove and return the most recently pushed task.
  Task popTask() {
    auto ret = stack.back();
    stack.pop_back();
    return ret;
  }

  // Walk the expression tree rooted at `root`: seed the task stack with
  // SubType::scan on the root, then run tasks until none remain. `root` is
  // taken by reference so that replaceCurrent() can replace the root itself.
  // Asserts that the task stack is empty on entry.
  void walk(Expression*& root) {
    assert(stack.size() == 0);
    pushTask(SubType::scan, &root);
    while (stack.size() > 0) {
      auto task = popTask();
      replacep = task.currp;
      assert(*task.currp);
      task.func(static_cast<SubType*>(this), task.currp);
    }
  }

  // Placeholder: a concrete walker (SubType or an intermediate base) must
  // provide its own static scan. Reaching this one aborts.
  static void scan(SubType* self, Expression** currp) { abort(); }

  // Task adapters, one per expression class: cast the expression in the slot
  // to its class and call SubType's matching visit method.
#define DELEGATE(CLASS_TO_VISIT)                                               \
  static void doVisit##CLASS_TO_VISIT(SubType* self, Expression** currp) {     \
    self->visit##CLASS_TO_VISIT((*currp)->cast<CLASS_TO_VISIT>());             \
  }

#include "wasm-delegations.def"

  void setModule(Module* module) { currModule = module; }
  void setFunction(Function* func) { currFunction = func; }

private:
  // Slot of the task currently running; replaceCurrent() writes through it.
  Expression** replacep = nullptr;
  // Pending tasks; walk() always runs the most recently pushed one next.
  SmallVector<Task, 10> stack;
  // Function and module currently being processed, if any.
  Function* currFunction = nullptr;
  Module* currModule = nullptr;
};
// Walker that visits expressions in post-order: all of an expression's
// children (and their subtrees) are visited before the expression itself.
template<typename SubType, typename VisitorType = Visitor<SubType>>
struct PostWalker : public Walker<SubType, VisitorType> {
// Scan *currp: schedule its visit task and the scan tasks of its children.
// The macros defined below are expanded by wasm-delegations-fields.def.
// NOTE(review): that file is not visible here; it presumably switches on
// DELEGATE_ID and emits one case per expression class. Confirm against the
// .def. The relative order of sibling children also depends on the order in
// which the .def lists fields.
static void scan(SubType* self, Expression** currp) {
Expression* curr = *currp;
#define DELEGATE_ID curr->_id
// The visit task is pushed before any child task. The task stack is run
// last-in first-out, so every child is scanned (and its subtree visited)
// before this expression's visit runs. This is what makes the walk post-order.
#define DELEGATE_START(id) \
self->pushTask(SubType::doVisit##id, currp); \
[[maybe_unused]] auto* cast = curr->cast<id>();
#define DELEGATE_GET_FIELD(id, field) cast->field
// Required children must be non-null (pushTask asserts this); optional
// children are only scanned when present (maybePushTask skips null slots).
#define DELEGATE_FIELD_CHILD(id, field) \
self->pushTask(SubType::scan, &cast->field);
#define DELEGATE_FIELD_OPTIONAL_CHILD(id, field) \
self->maybePushTask(SubType::scan, &cast->field);
// Non-expression fields are not traversed, so their macros expand to nothing.
#define DELEGATE_FIELD_INT(id, field)
#define DELEGATE_FIELD_LITERAL(id, field)
#define DELEGATE_FIELD_NAME(id, field)
#define DELEGATE_FIELD_SCOPE_NAME_DEF(id, field)
#define DELEGATE_FIELD_SCOPE_NAME_USE(id, field)
#define DELEGATE_FIELD_TYPE(id, field)
#define DELEGATE_FIELD_HEAPTYPE(id, field)
#define DELEGATE_FIELD_ADDRESS(id, field)
#include "wasm-delegations-fields.def"
}
};
// Stack of expressions used by the walkers below to track control flow and
// the current path through the tree. NOTE(review): the 10 is presumably
// SmallVector's inline capacity; confirm in support/small_vector.h.
using ExpressionStack = SmallVector<Expression*, 10>;
template<typename SubType, typename VisitorType = Visitor<SubType>>
struct ControlFlowWalker : public PostWalker<SubType, VisitorType> {
  // The Block, If, Loop, Try and TryTable expressions enclosing the current
  // position of the walk, innermost at the back. A control-flow expression is
  // pushed before its children are scanned and popped after its own visit
  // has run, so it is still on the stack while it is being visited.
  ExpressionStack controlFlowStack;

  // Find the Block or Loop named `name`, searching from the innermost
  // enclosing control-flow expression outwards. Returns nullptr if nothing
  // matches. The stack must not be empty.
  Expression* findBreakTarget(Name name) {
    assert(!controlFlowStack.empty());
    for (Index depth = controlFlowStack.size(); depth > 0; --depth) {
      Expression* scope = controlFlowStack[depth - 1];
      if (auto* block = scope->template dynCast<Block>()) {
        if (name == block->name) {
          return scope;
        }
      } else if (auto* loop = scope->template dynCast<Loop>()) {
        if (name == loop->name) {
          return scope;
        }
      } else {
        // Only Blocks and Loops are matched by name; anything else on this
        // stack must be one of the other tracked kinds.
        assert(scope->template is<If>() || scope->template is<Try>() ||
               scope->template is<TryTable>());
      }
    }
    return nullptr;
  }

  // Task: enter a control-flow expression.
  static void doPreVisitControlFlow(SubType* self, Expression** currp) {
    self->controlFlowStack.push_back(*currp);
  }

  // Task: leave the innermost control-flow expression.
  static void doPostVisitControlFlow(SubType* self, Expression**) {
    self->controlFlowStack.pop_back();
  }

  // For control-flow expressions, wrap PostWalker's scan with push/pop
  // tasks. Because the task stack runs last-in first-out, the enter task
  // (pushed last) runs before the children and the visit, and the leave task
  // (pushed first) runs after them. Other expressions are scanned as-is.
  static void scan(SubType* self, Expression** currp) {
    bool isControlFlow = false;
    switch ((*currp)->_id) {
      case Expression::Id::BlockId:
      case Expression::Id::IfId:
      case Expression::Id::LoopId:
      case Expression::Id::TryId:
      case Expression::Id::TryTableId:
        isControlFlow = true;
        break;
      default:
        break;
    }
    if (isControlFlow) {
      self->pushTask(SubType::doPostVisitControlFlow, currp);
    }
    PostWalker<SubType, VisitorType>::scan(self, currp);
    if (isControlFlow) {
      self->pushTask(SubType::doPreVisitControlFlow, currp);
    }
  }
};
template<typename SubType, typename VisitorType = Visitor<SubType>>
struct ExpressionStackWalker : public PostWalker<SubType, VisitorType> {
  ExpressionStackWalker() = default;

  // Every expression on the path from the walk root to the current
  // expression, innermost at the back. An expression is pushed before its
  // children are scanned and popped after its own visit has run.
  ExpressionStack expressionStack;

  // Find the Block or Loop named `name`, searching outwards from the
  // innermost entry of expressionStack. Returns nullptr if nothing matches.
  // The stack must not be empty.
  Expression* findBreakTarget(Name name) {
    assert(!expressionStack.empty());
    for (Index depth = expressionStack.size(); depth > 0; --depth) {
      Expression* candidate = expressionStack[depth - 1];
      if (auto* block = candidate->template dynCast<Block>()) {
        if (name == block->name) {
          return candidate;
        }
      } else if (auto* loop = candidate->template dynCast<Loop>()) {
        if (name == loop->name) {
          return candidate;
        }
      }
    }
    return nullptr;
  }

  // The parent of the current expression, or nullptr when the current
  // expression is the walk root (the only entry). Must not be called while
  // the stack is empty.
  Expression* getParent() {
    const auto depth = expressionStack.size();
    if (depth == 1) {
      return nullptr;
    }
    assert(depth >= 2);
    return expressionStack[depth - 2];
  }

  // Task: push the expression that is about to be walked.
  static void doPreVisit(SubType* self, Expression** currp) {
    self->expressionStack.push_back(*currp);
  }

  // Task: pop the expression whose visit has just finished.
  static void doPostVisit(SubType* self, Expression**) {
    self->expressionStack.pop_back();
  }

  // Wrap PostWalker's scan with push/pop tasks. Because the task stack runs
  // last-in first-out, doPreVisit (pushed last) runs first, then the children
  // and visit scheduled by PostWalker, and doPostVisit (pushed first) last.
  static void scan(SubType* self, Expression** currp) {
    using Base = PostWalker<SubType, VisitorType>;
    self->pushTask(SubType::doPostVisit, currp);
    Base::scan(self, currp);
    self->pushTask(SubType::doPreVisit, currp);
  }

  // Replace the current expression. This also overwrites the top of
  // expressionStack, which holds the current expression while it is being
  // visited, so the stack refers to the replacement.
  Expression* replaceCurrent(Expression* expression) {
    using Base = PostWalker<SubType, VisitorType>;
    Base::replaceCurrent(expression);
    expressionStack.back() = expression;
    return expression;
  }
};
}
#endif