#ifndef wasm_dataflow_graph_h
#define wasm_dataflow_graph_h
#include <cassert>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "dataflow/node.h"
#include "ir/abstract.h"
#include "ir/iteration.h"
#include "ir/literal-utils.h"
#include "wasm.h"
namespace wasm {
namespace DataFlow {
// A data-flow graph built over a single wasm Function.
//
// Constants, supported integer operations and local.set values become Nodes;
// control-flow merges become Block / Cond / Phi nodes. Values that are not
// modelled (non-integer types, unsupported operations) are represented either
// by a fresh Var node ("could be anything") or by the canonical `bad` node.
//
// Exception-handling instructions are rejected with Fatal().
// NOTE(review): other expectations about the input IR's shape (e.g. whether it
// must be flattened) are not checked here; confirm with callers.
struct Graph : public UnifiedExpressionVisitor<Graph, Node*> {
  // The single canonical "bad" node, returned whenever a value cannot be
  // modelled. It is shared, so it must never be mutated.
  Node bad = Node(Node::Type::Bad);

  // Maps each processed local.set to the node for the value it stores.
  std::unordered_map<LocalSet*, Node*> setNodeMap;

  // Maps a control-flow expression to the i1 conditions of its arms. For an
  // `if` this holds [condition-is-true, condition-is-false].
  std::unordered_map<Expression*, std::vector<Node*>> expressionConditionMap;

  // Maps an expression to its control-flow parent (nullptr at the top level).
  // Only blocks, ifs, loops, local.sets, set values and drop values are
  // recorded.
  std::unordered_map<Expression*, Expression*> expressionParentMap;

  // Maps a node to the first local.set that stored it.
  std::unordered_map<Node*, Expression*> nodeParentMap;

  // Every processed local.set, in the order the builder reached them (a set
  // is recorded before its value is visited).
  std::vector<LocalSet*> sets;

  // The function being processed.
  Function* func;

  // The module that owns `func`; used to build helper expressions.
  Module* module;

  // Owns every node created through addNode().
  std::vector<std::unique_ptr<Node>> nodes;

  // ---- State used while building ----

  // The innermost enclosing control-flow expression being visited.
  Expression* parent = nullptr;

  // Per-path local state: locals[i] is the node currently held in local i
  // (nullptr for locals of irrelevant types). An EMPTY vector means the path
  // is unreachable.
  typedef std::vector<Node*> Locals;

  // Local state of the control-flow path currently being visited.
  Locals locals;

  // Local states recorded at each branch, keyed by the branch target label.
  std::unordered_map<Name, std::vector<Locals>> breakStates;

  // A path's local state plus the condition under which control arrives along
  // it (`bad` when the condition is unknown).
  struct FlowState {
    Locals locals;
    Node* condition;
    FlowState(Locals locals, Node* condition)
      : locals(std::move(locals)), condition(condition) {}
  };

  // ---- API ----

  // Builds the graph for `funcInit`. Parameters start as Vars (unknown
  // values); other locals start as the zero constant of their type. Locals of
  // irrelevant (non-integer) types are not tracked.
  void build(Function* funcInit, Module* moduleInit) {
    func = funcInit;
    module = moduleInit;

    auto numLocals = func->getNumLocals();
    if (numLocals == 0) {
      // No locals means no local.sets: nothing to build.
      return;
    }
    setInReachable();
    for (Index i = 0; i < numLocals; i++) {
      auto type = func->getLocalType(i);
      if (!isRelevantType(type)) {
        continue;
      }
      locals[i] = func->isParam(i) ? makeVar(type) : makeZero(type);
    }
    visit(func->body);
  }

  // Returns a new Var node (an unknown value) for relevant types, or `bad`.
  Node* makeVar(wasm::Type type) {
    if (isRelevantType(type)) {
      return addNode(Node::makeVar(type));
    } else {
      return &bad;
    }
  }

  // Constant nodes are interned: one node per distinct literal value.
  std::unordered_map<Literal, Node*> constantNodes;

  // Returns the (shared) node for the constant `value`, creating it on first
  // use.
  Node* makeConst(Literal value) {
    auto iter = constantNodes.find(value);
    if (iter != constantNodes.end()) {
      return iter->second;
    }
    Builder builder(*module);
    auto* c = builder.makeConst(value);
    auto* ret = addNode(Node::makeExpr(c, c));
    constantNodes[value] = ret;
    return ret;
  }

  // Returns the interned zero constant of `type`.
  Node* makeZero(wasm::Type type) { return makeConst(Literal::makeZero(type)); }

  // Takes ownership of `node` and returns it.
  Node* addNode(Node* node) {
    nodes.push_back(std::unique_ptr<Node>(node));
    return node;
  }

  // Creates a node comparing `node` with zero: `node == 0` when `equal` is
  // true, otherwise `node != 0`. Returns `bad` if the node's wasm type is not
  // concrete. `node` itself must not be bad.
  Node* makeZeroComp(Node* node, bool equal, Expression* origin) {
    assert(!node->isBad());
    Builder builder(*module);
    auto type = node->getWasmType();
    if (!type.isConcrete()) {
      return &bad;
    }
    auto* zero = makeZero(type);
    auto* expr = builder.makeBinary(
      Abstract::getBinary(type, equal ? Abstract::Eq : Abstract::Ne),
      makeUse(node),
      makeUse(zero));
    auto* check = addNode(Node::makeExpr(expr, origin));
    check->addValue(expandFromI1(node, origin));
    check->addValue(zero);
    return check;
  }

  // Marks the current path unreachable (empty local state).
  void setInUnreachable() { locals.clear(); }

  // Makes the current path reachable by sizing its state to all locals.
  void setInReachable() { locals.resize(func->getNumLocals()); }

  bool isInUnreachable() { return isInUnreachable(locals); }

  bool isInUnreachable(const Locals& state) { return state.empty(); }

  bool isInUnreachable(const FlowState& state) {
    return isInUnreachable(state.locals);
  }

  // ---- Visiting ----

  // Dispatches on the expression kind. Control flow, local access and the
  // supported operations get dedicated handling; everything else goes through
  // doVisitGeneric(). Exception-handling instructions are not supported.
  Node* visitExpression(Expression* curr) {
    if (auto* block = curr->dynCast<Block>()) {
      return doVisitBlock(block);
    } else if (auto* iff = curr->dynCast<If>()) {
      return doVisitIf(iff);
    } else if (auto* loop = curr->dynCast<Loop>()) {
      return doVisitLoop(loop);
    } else if (auto* get = curr->dynCast<LocalGet>()) {
      return doVisitLocalGet(get);
    } else if (auto* set = curr->dynCast<LocalSet>()) {
      return doVisitLocalSet(set);
    } else if (auto* br = curr->dynCast<Break>()) {
      return doVisitBreak(br);
    } else if (auto* sw = curr->dynCast<Switch>()) {
      return doVisitSwitch(sw);
    } else if (auto* c = curr->dynCast<Const>()) {
      return doVisitConst(c);
    } else if (auto* unary = curr->dynCast<Unary>()) {
      return doVisitUnary(unary);
    } else if (auto* binary = curr->dynCast<Binary>()) {
      return doVisitBinary(binary);
    } else if (auto* select = curr->dynCast<Select>()) {
      return doVisitSelect(select);
    } else if (auto* unreachable = curr->dynCast<Unreachable>()) {
      return doVisitUnreachable(unreachable);
    } else if (auto* drop = curr->dynCast<Drop>()) {
      return doVisitDrop(drop);
    } else if (curr->is<Try>() || curr->is<Throw>() || curr->is<Rethrow>() ||
               curr->is<BrOnExn>()) {
      Fatal() << "DataFlow does not support EH instructions yet";
    } else {
      return doVisitGeneric(curr);
    }
  }

  // Visits the children in order. For a named block that was branched to,
  // the fallthrough state (if reachable) is merged with all branch states.
  // Branch conditions are not tracked here.
  Node* doVisitBlock(Block* curr) {
    auto* oldParent = parent;
    expressionParentMap[curr] = oldParent;
    parent = curr;
    for (auto* child : curr->list) {
      visit(child);
    }
    if (curr->name.is()) {
      auto iter = breakStates.find(curr->name);
      if (iter != breakStates.end()) {
        auto& states = iter->second;
        if (!isInUnreachable()) {
          states.push_back(locals);
        }
        mergeBlock(states, locals);
      }
    }
    parent = oldParent;
    return &bad;
  }

  // Visits the condition and both arms from the same initial state, then
  // merges the arm states, tagging each with its i1 condition.
  Node* doVisitIf(If* curr) {
    auto* oldParent = parent;
    expressionParentMap[curr] = oldParent;
    parent = curr;
    Node* condition = visit(curr->condition);
    assert(condition);
    auto initialState = locals;
    visit(curr->ifTrue);
    auto afterIfTrueState = locals;
    if (curr->ifFalse) {
      locals = initialState;
      visit(curr->ifFalse);
      auto afterIfFalseState = locals;
      mergeIf(afterIfTrueState, afterIfFalseState, condition, curr, locals);
    } else {
      // With no else arm, the path that skipped the body is the one taken
      // when the condition is false, so it must be the "false" state (second
      // argument), exactly as afterIfFalseState is above.
      mergeIf(afterIfTrueState, initialState, condition, curr, locals);
    }
    parent = oldParent;
    return &bad;
  }

  // Visits a loop. Loop phis are not represented precisely: see
  // buildLabeledLoop(). `parent` is restored on every exit path.
  Node* doVisitLoop(Loop* curr) {
    auto* oldParent = parent;
    expressionParentMap[curr] = oldParent;
    parent = curr;
    if (isInUnreachable()) {
      // Nothing reaches the loop, so none of its contents matter.
    } else if (!curr->name.is()) {
      // Without a label nothing can branch back to the top: no phis.
      visit(curr->body);
    } else {
      buildLabeledLoop(curr);
    }
    parent = oldParent;
    return &bad;
  }

  // Handles a labeled loop that is reachable on entry. Every local is first
  // replaced by a fresh Var. After the body, if each back-edge carries either
  // that Var or the pre-loop value, no phi is needed and the Var is replaced
  // by the pre-loop value in every node created for the loop and in the
  // outgoing state. Otherwise the Var is kept, i.e. the loop phi is modelled
  // as an unknown value.
  void buildLabeledLoop(Loop* curr) {
    auto previous = locals;
    auto numLocals = func->getNumLocals();
    for (Index i = 0; i < numLocals; i++) {
      locals[i] = makeVar(func->getLocalType(i));
    }
    auto vars = locals;
    // Only nodes created from here on can reference the new Vars.
    auto firstNodeFromLoop = nodes.size();
    visit(curr->body);
    auto& breaks = breakStates[curr->name];
    for (Index i = 0; i < numLocals; i++) {
      if (!isRelevantType(func->getLocalType(i))) {
        continue;
      }
      auto* var = vars[i];
      auto* proper = previous[i];
      bool needPhi = false;
      for (auto& other : breaks) {
        assert(!isInUnreachable(other));
        auto& incoming = *(other[i]);
        if (incoming != *var && incoming != *proper) {
          needPhi = true;
          break;
        }
      }
      if (needPhi) {
        // Keep the Var: the phi value is unknown to us.
        continue;
      }
      for (auto j = firstNodeFromLoop; j < nodes.size(); j++) {
        for (auto*& value : nodes[j]->values) {
          if (value == var) {
            value = proper;
          }
        }
      }
      for (auto*& node : locals) {
        if (node == var) {
          node = proper;
        }
      }
    }
  }

  // Records the local state flowing to the branch target. The value and the
  // condition are evaluated before the branch decision, so they are visited
  // first, and any local.sets inside them are included in the recorded state.
  // An unconditional branch leaves the current path unreachable.
  Node* doVisitBreak(Break* curr) {
    if (curr->value) {
      visit(curr->value);
    }
    if (curr->condition) {
      visit(curr->condition);
    }
    if (!isInUnreachable()) {
      breakStates[curr->name].push_back(locals);
    }
    if (!curr->condition) {
      setInUnreachable();
    }
    return &bad;
  }

  // Records the current state once per distinct target (including the
  // default), after evaluating the value and condition. Control never falls
  // through.
  Node* doVisitSwitch(Switch* curr) {
    if (curr->value) {
      visit(curr->value);
    }
    visit(curr->condition);
    if (!isInUnreachable()) {
      std::unordered_set<Name> targets;
      for (auto target : curr->targets) {
        targets.insert(target);
      }
      targets.insert(curr->default_);
      for (auto target : targets) {
        breakStates[target].push_back(locals);
      }
    }
    setInUnreachable();
    return &bad;
  }

  // Returns the node currently held by the local, or `bad` for irrelevant
  // locals or unreachable code.
  Node* doVisitLocalGet(LocalGet* curr) {
    if (!isRelevantLocal(curr->index) || isInUnreachable()) {
      return &bad;
    }
    return locals[curr->index];
  }

  // Records the set and binds its value's node to the local.
  Node* doVisitLocalSet(LocalSet* curr) {
    if (!isRelevantLocal(curr->index) || isInUnreachable()) {
      return &bad;
    }
    assert(curr->value->type.isConcrete());
    sets.push_back(curr);
    expressionParentMap[curr] = parent;
    expressionParentMap[curr->value] = curr;
    auto* node = visit(curr->value);
    setNodeMap[curr] = node;
    // Visiting the value may have left the path unreachable (e.g. it branched
    // away), in which case `locals` is empty and must not be indexed.
    if (!isInUnreachable()) {
      locals[curr->index] = node;
    }
    // Remember only the first set that stored this node.
    nodeParentMap.emplace(node, curr);
    return &bad;
  }

  Node* doVisitConst(Const* curr) { return makeConst(curr->value); }

  // Models clz/ctz/popcnt and eqz; every other unary becomes an unknown Var.
  Node* doVisitUnary(Unary* curr) {
    switch (curr->op) {
      case ClzInt32:
      case ClzInt64:
      case CtzInt32:
      case CtzInt64:
      case PopcntInt32:
      case PopcntInt64: {
        auto* value = expandFromI1(visit(curr->value), curr);
        if (value->isBad()) {
          return value;
        }
        auto* ret = addNode(Node::makeExpr(curr, curr));
        ret->addValue(value);
        return ret;
      }
      case EqZInt32:
      case EqZInt64: {
        auto* value = expandFromI1(visit(curr->value), curr);
        if (value->isBad()) {
          return value;
        }
        return makeZeroComp(value, true, curr);
      }
      default: {
        return makeVar(curr->type);
      }
    }
  }

  // Models integer arithmetic, bitwise, shift/rotate and comparison ops.
  // Gt/Ge comparisons are rewritten as the mirrored Lt/Le with swapped
  // operands. Every other binary becomes an unknown Var.
  Node* doVisitBinary(Binary* curr) {
    switch (curr->op) {
      case AddInt32:
      case AddInt64:
      case SubInt32:
      case SubInt64:
      case MulInt32:
      case MulInt64:
      case DivSInt32:
      case DivSInt64:
      case DivUInt32:
      case DivUInt64:
      case RemSInt32:
      case RemSInt64:
      case RemUInt32:
      case RemUInt64:
      case AndInt32:
      case AndInt64:
      case OrInt32:
      case OrInt64:
      case XorInt32:
      case XorInt64:
      case ShlInt32:
      case ShlInt64:
      case ShrUInt32:
      case ShrUInt64:
      case ShrSInt32:
      case ShrSInt64:
      case RotLInt32:
      case RotLInt64:
      case RotRInt32:
      case RotRInt64:
      case EqInt32:
      case EqInt64:
      case NeInt32:
      case NeInt64:
      case LtSInt32:
      case LtSInt64:
      case LtUInt32:
      case LtUInt64:
      case LeSInt32:
      case LeSInt64:
      case LeUInt32:
      case LeUInt64: {
        auto* left = expandFromI1(visit(curr->left), curr);
        if (left->isBad()) {
          return left;
        }
        auto* right = expandFromI1(visit(curr->right), curr);
        if (right->isBad()) {
          return right;
        }
        auto* ret = addNode(Node::makeExpr(curr, curr));
        ret->addValue(left);
        ret->addValue(right);
        return ret;
      }
      case GtSInt32:
      case GtSInt64:
      case GeSInt32:
      case GeSInt64:
      case GtUInt32:
      case GtUInt64:
      case GeUInt32:
      case GeUInt64: {
        Builder builder(*module);
        BinaryOp opposite;
        switch (curr->op) {
          case GtSInt32:
            opposite = LtSInt32;
            break;
          case GtSInt64:
            opposite = LtSInt64;
            break;
          case GeSInt32:
            opposite = LeSInt32;
            break;
          case GeSInt64:
            opposite = LeSInt64;
            break;
          case GtUInt32:
            opposite = LtUInt32;
            break;
          case GtUInt64:
            opposite = LtUInt64;
            break;
          case GeUInt32:
            opposite = LeUInt32;
            break;
          case GeUInt64:
            opposite = LeUInt64;
            break;
          default:
            WASM_UNREACHABLE("unexpected op");
        }
        auto* ret =
          visitBinary(builder.makeBinary(opposite, curr->right, curr->left));
        // If an operand was bad, `ret` is the shared sentinel; never stamp an
        // origin onto it.
        if (!ret->isBad()) {
          ret->origin = curr;
        }
        return ret;
      }
      default: {
        return makeVar(curr->type);
      }
    }
  }

  // Models select(condition, ifTrue, ifFalse) with an i1 condition.
  Node* doVisitSelect(Select* curr) {
    auto* ifTrue = expandFromI1(visit(curr->ifTrue), curr);
    if (ifTrue->isBad()) {
      return ifTrue;
    }
    auto* ifFalse = expandFromI1(visit(curr->ifFalse), curr);
    if (ifFalse->isBad()) {
      return ifFalse;
    }
    auto* condition = ensureI1(visit(curr->condition), curr);
    if (condition->isBad()) {
      return condition;
    }
    auto* ret = addNode(Node::makeExpr(curr, curr));
    ret->addValue(condition);
    ret->addValue(ifTrue);
    ret->addValue(ifFalse);
    return ret;
  }

  Node* doVisitUnreachable(Unreachable* curr) {
    setInUnreachable();
    return &bad;
  }

  Node* doVisitDrop(Drop* curr) {
    visit(curr->value);
    expressionParentMap[curr->value] = curr;
    return &bad;
  }

  // Fallback: visit all children for their effects on local state, and
  // model the result as an unknown value.
  Node* doVisitGeneric(Expression* curr) {
    for (auto* child : ChildIterator(curr)) {
      visit(child);
    }
    return makeVar(curr->type);
  }

  // ---- Helpers ----

  // Only integer values are modelled.
  bool isRelevantType(wasm::Type type) { return type.isInteger(); }

  bool isRelevantLocal(Index index) {
    return isRelevantType(func->getLocalType(index));
  }

  // Merges the two outcomes of an `if` into `out`. `aState` is the state
  // when the condition is true and `bState` the state when it is false.
  // Both i1 conditions are recorded in expressionConditionMap[expr] (true
  // first); if the condition is bad, both paths get `bad` as their condition.
  void mergeIf(Locals& aState,
               Locals& bState,
               Node* condition,
               Expression* expr,
               Locals& out) {
    Node* ifTrue;
    Node* ifFalse;
    if (!condition->isBad()) {
      auto& conditions = expressionConditionMap[expr];
      ifTrue = ensureI1(condition, nullptr);
      conditions.push_back(ifTrue);
      ifFalse = makeZeroComp(condition, true, nullptr);
      conditions.push_back(ifFalse);
    } else {
      ifTrue = ifFalse = &bad;
    }
    std::vector<FlowState> states;
    if (!isInUnreachable(aState)) {
      states.emplace_back(aState, ifTrue);
    }
    if (!isInUnreachable(bState)) {
      states.emplace_back(bState, ifFalse);
    }
    merge(states, out);
  }

  // Merges all (reachable) states arriving at a block into `out`. Branch
  // conditions are not tracked, so every path's condition is `bad`.
  void mergeBlock(std::vector<Locals>& localses, Locals& out) {
    std::vector<FlowState> states;
    for (auto& incoming : localses) {
      states.emplace_back(incoming, &bad);
    }
    merge(states, out);
  }

  // Merges reachable `states` into `out`.
  // - No states: the path stays unreachable.
  // - One state: it is copied.
  // - Otherwise, per relevant local: `bad` if any incoming value is bad; the
  //   shared value if all paths agree; else a Phi on a single Block node
  //   whose values are the per-path Cond nodes (or `bad` when a path's
  //   condition is unknown).
  // NOTE: setInReachable() resizes the member `locals`, so `out` is expected
  // to alias it, as every caller in this struct does.
  void merge(std::vector<FlowState>& states, Locals& out) {
#ifndef NDEBUG
    for (auto& state : states) {
      assert(!isInUnreachable(state.locals));
    }
#endif
    Index numStates = states.size();
    if (numStates == 0) {
      assert(isInUnreachable());
      return;
    }
    setInReachable();
    if (numStates == 1) {
      out = states[0].locals;
      return;
    }
    Index numLocals = func->getNumLocals();
    Node* block = nullptr;
    for (Index i = 0; i < numLocals; i++) {
      if (!isRelevantType(func->getLocalType(i))) {
        continue;
      }
      // Named to avoid shadowing the `bad` sentinel member.
      bool anyBad = false;
      for (auto& state : states) {
        auto* node = state.locals[i];
        if (node->isBad()) {
          anyBad = true;
          out[i] = node;
          break;
        }
      }
      if (anyBad) {
        continue;
      }
      Node* first = nullptr;
      for (auto& state : states) {
        if (!first) {
          first = out[i] = state.locals[i];
        } else if (state.locals[i] != first) {
          // The paths disagree: a phi is needed. All phis in this merge
          // share one Block node, created lazily.
          if (!block) {
            block = addNode(Node::makeBlock());
            for (Index index = 0; index < numStates; index++) {
              auto* condition = states[index].condition;
              if (!condition->isBad()) {
                condition = addNode(Node::makeCond(block, index, condition));
              }
              block->addValue(condition);
            }
          }
          auto* phi = addNode(Node::makePhi(block, i));
          for (auto& incoming : states) {
            auto* value = expandFromI1(incoming.locals[i], nullptr);
            phi->addValue(value);
          }
          out[i] = phi;
          break;
        }
      }
    }
  }

  // Wraps an i1-returning node in a Zext node; other nodes are returned
  // unchanged.
  Node* expandFromI1(Node* node, Expression* origin) {
    if (!node->isBad() && node->returnsI1()) {
      node = addNode(Node::makeZext(node, origin));
    }
    return node;
  }

  // Turns a non-i1 node into an i1 via `node != 0`; i1 and bad nodes are
  // returned unchanged.
  Node* ensureI1(Node* node, Expression* origin) {
    if (!node->isBad() && !node->returnsI1()) {
      node = makeZeroComp(node, false, origin);
    }
    return node;
  }

  // Returns the first local.set that stored `node`, or nullptr.
  LocalSet* getSet(Node* node) {
    auto iter = nodeParentMap.find(node);
    if (iter == nodeParentMap.end()) {
      return nullptr;
    }
    return iter->second->dynCast<LocalSet>();
  }

  // Returns the recorded control-flow parent of `curr`, or nullptr.
  Expression* getParent(Expression* curr) {
    auto iter = expressionParentMap.find(curr);
    if (iter == expressionParentMap.end()) {
      return nullptr;
    }
    return iter->second;
  }

  // Returns the local.set whose value is `curr`, or nullptr.
  LocalSet* getSet(Expression* curr) {
    auto* parent = getParent(curr);
    return parent ? parent->dynCast<LocalSet>() : nullptr;
  }

  // Builds a wasm expression that reads `node`'s value:
  // - phis: a local.get of the phi's local;
  // - constants: a copy of the constant;
  // - other expressions: a local.get of the local their value was stored in;
  // - zexts: a use of the wrapped value;
  // - vars: a call to the placeholder FAKE_CALL.
  Expression* makeUse(Node* node) {
    Builder builder(*module);
    if (node->isPhi()) {
      auto index = node->index;
      return builder.makeLocalGet(index, func->getLocalType(index));
    } else if (node->isConst()) {
      return builder.makeConst(node->expr->cast<Const>()->value);
    } else if (node->isExpr()) {
      auto* set = getSet(node);
      assert(set && "expression node was never stored by a local.set");
      auto index = set->index;
      return builder.makeLocalGet(index, func->getLocalType(index));
    } else if (node->isZext()) {
      return makeUse(node->values[0]);
    } else if (node->isVar()) {
      return builder.makeCall(FAKE_CALL, {}, node->wasmType);
    } else {
      WASM_UNREACHABLE("unexpected node type");
    }
  }

  // Name of the placeholder function called to stand in for a Var value.
  const Name FAKE_CALL = "fake$dfo$call";
};
}
}
#endif