#include "stack-utils.h"
#include "ir/iteration.h"
#include "ir/properties.h"
namespace wasm {
namespace StackUtils {
// Removes every Nop from `block->list` in place. The expressions that remain
// keep their relative order, and the list is shrunk to the number kept.
void removeNops(Block* block) {
  auto& items = block->list;
  const size_t total = items.size();
  size_t kept = 0;
  for (size_t read = 0; read < total; ++read) {
    auto* item = items[read];
    if (item->is<Nop>()) {
      continue;
    }
    // `kept` never exceeds `read`, so this only overwrites slots that have
    // already been examined.
    items[kept++] = item;
  }
  items.resize(kept);
}
// Reports whether `expr` is one of the expression kinds that this utility
// treats as possibly unreachable:
//  - anything Properties::isControlFlowStructure accepts,
//  - a Break with no condition,
//  - a Call or CallIndirect whose `isReturn` flag is set,
//  - Return, Switch, Unreachable, Throw and Rethrow.
// Every other expression kind yields false.
bool mayBeUnreachable(Expression* expr) {
  if (Properties::isControlFlowStructure(expr)) {
    return true;
  }
  bool result = false;
  switch (expr->_id) {
    case Expression::Id::ReturnId:
    case Expression::Id::SwitchId:
    case Expression::Id::UnreachableId:
    case Expression::Id::ThrowId:
    case Expression::Id::RethrowId:
      result = true;
      break;
    case Expression::Id::BreakId:
      // Only a break without a condition counts.
      result = !expr->cast<Break>()->condition;
      break;
    case Expression::Id::CallId:
      result = expr->cast<Call>()->isReturn;
      break;
    case Expression::Id::CallIndirectId:
      result = expr->cast<CallIndirect>()->isReturn;
      break;
    default:
      break;
  }
  return result;
}
}
// Derives the stack signature of a single expression.
//  - `params` is the concatenation, in child order, of the types of the
//    expression's value children. Each child type is expanded element by
//    element, so a child whose type has several elements contributes one
//    param per element. Every child must have a concrete type (asserted).
//  - If the expression's type is unreachable, `unreachable` is set and
//    `results` is none; otherwise `results` is the expression's type.
StackSignature::StackSignature(Expression* expr) {
  std::vector<Type> operandTypes;
  for (auto* child : ValueChildIterator(expr)) {
    assert(child->type.isConcrete());
    for (const auto& element : child->type) {
      operandTypes.push_back(element);
    }
  }
  params = Type(operandTypes);
  unreachable = expr->type == Type::unreachable;
  results = unreachable ? Type::none : expr->type;
}
bool StackSignature::composes(const StackSignature& next) const {
auto checked = std::min(results.size(), next.params.size());
return std::equal(results.end() - checked,
results.end(),
next.params.end() - checked,
[](const Type& produced, const Type& consumed) {
return Type::isSubType(produced, consumed);
});
}
// Checks whether this stack signature is acceptable where `sig` is expected.
//
// The checks run in this order:
//  1. `sig` must have at least as many params and at least as many results
//     as this signature.
//  2. This signature's params are compared against the trailing params of
//     `sig`: each provided type must be a subtype of the type consumed.
//  3. This signature's results are compared against the trailing results of
//     `sig`: each produced type must be a subtype of the type expected.
//  4. If this signature is unreachable, nothing more is required.
//  5. Otherwise the leading params of `sig` that this signature does not
//     consume are compared against the leading results of `sig` that it does
//     not produce: both ranges must have the same length and each such param
//     must be a subtype of the result at the same position.
bool StackSignature::satisfies(Signature sig) const {
  const auto ownParams = params.size();
  const auto ownResults = results.size();
  if (ownParams > sig.params.size() || ownResults > sig.results.size()) {
    return false;
  }

  // First argument must be a subtype of the second.
  auto subOfSecond = [](const Type& sub, const Type& sup) {
    return Type::isSubType(sub, sup);
  };
  // Second argument must be a subtype of the first.
  auto superOfSecond = [](const Type& sup, const Type& sub) {
    return Type::isSubType(sub, sup);
  };

  // Boundaries between the untouched prefix of `sig` and the suffix that this
  // signature is matched against.
  auto paramsSplit = sig.params.end() - ownParams;
  auto resultsSplit = sig.results.end() - ownResults;

  if (!std::equal(params.begin(), params.end(), paramsSplit, superOfSecond)) {
    return false;
  }
  if (!std::equal(results.begin(), results.end(), resultsSplit, subOfSecond)) {
    return false;
  }
  if (unreachable) {
    return true;
  }
  return std::equal(sig.params.begin(),
                    paramsSplit,
                    sig.results.begin(),
                    resultsSplit,
                    subOfSecond);
}
// Appends `next` to this signature, updating it to describe the two run in
// sequence. Requires `composes(next)` (asserted).
//
// `next` first consumes this signature's results from the top. If there are
// not enough of them, the results are used up entirely and, unless this
// signature is already unreachable, the params of `next` that were not
// covered are prepended to this signature's params. Afterwards, if `next` is
// unreachable the results become exactly `next.results` and this signature is
// marked unreachable; otherwise `next.results` are pushed on top of whatever
// results were left.
StackSignature& StackSignature::operator+=(const StackSignature& next) {
  assert(composes(next));
  std::vector<Type> remaining(results.begin(), results.end());
  const size_t available = remaining.size();
  const size_t needed = next.params.size();

  if (needed > available) {
    if (!unreachable) {
      // The leading `missing` params of `next` were not covered by our
      // results, so they become additional params in front of our own.
      const size_t missing = needed - available;
      std::vector<Type> combined(next.params.begin(),
                                 next.params.begin() + missing);
      combined.insert(combined.end(), params.begin(), params.end());
      params = Type(combined);
    }
    remaining.clear();
  } else {
    remaining.resize(available - needed);
  }

  if (!next.unreachable) {
    remaining.insert(
      remaining.end(), next.results.begin(), next.results.end());
    results = Type(remaining);
    return *this;
  }
  results = next.results;
  unreachable = true;
  return *this;
}
// Returns the composition of this signature followed by `next`, leaving both
// operands unmodified. Implemented as a copy of this signature followed by
// operator+=.
StackSignature StackSignature::operator+(const StackSignature& next) const {
  StackSignature combined(*this);
  combined += next;
  return combined;
}
// Computes, for every expression in `block->list` and for `block` itself, the
// locations its consumed values come from (`srcs`) and the locations its
// produced values go to (`dests`).
//
// The work is done in two passes over the block. The first pass only counts
// how many values each unreachable instruction has to supply to the
// instructions after it. The second pass uses those counts to size the
// `dests` entries and records the actual source/destination links.
//
// The second pass asserts that an instruction never needs more values than
// are on the simulated stack unless an unreachable instruction precedes it
// (assertion text: "Block inputs not yet supported").
StackFlow::StackFlow(Block* block) {
// Calls `process(expr, signature)` for each child of the block, in order, and
// then once more for the block itself. For that final call the block is given
// a signature whose params are the block's own type (none if the block is
// unreachable) and whose results are none, so the end of the block is
// processed like an instruction that consumes the block's result values.
auto processBlock = [&block](auto process) {
for (auto* expr : block->list) {
process(expr, StackSignature(expr));
}
bool unreachable = block->type == Type::unreachable;
Type params = unreachable ? Type::none : block->type;
process(block, StackSignature(params, Type::none, unreachable));
};
// First pass: for each unreachable instruction, the number of values that
// later instructions need beyond what is on the stack.
std::unordered_map<Expression*, size_t> producedByUnreachable;
{
// Number of values currently on the simulated stack.
size_t stackSize = 0;
// Values charged so far to `lastUnreachable`.
size_t produced = 0;
Expression* lastUnreachable = nullptr;
processBlock([&](Expression* expr, const StackSignature sig) {
// Consume params.
if (sig.params.size() > stackSize) {
// More values are needed than the stack holds; the shortfall is charged to
// the most recent unreachable instruction, which must exist.
assert(lastUnreachable);
produced += sig.params.size() - stackSize;
stackSize = 0;
} else {
stackSize -= sig.params.size();
}
// Either switch to a new unreachable instruction or push the results.
if (sig.unreachable) {
if (lastUnreachable) {
// Finalize the count of the previous unreachable before replacing it.
producedByUnreachable[lastUnreachable] = produced;
produced = 0;
}
assert(produced == 0);
lastUnreachable = expr;
// An unreachable instruction leaves nothing on the simulated stack.
stackSize = 0;
} else {
stackSize += sig.results.size();
}
});
// Record the count for the final unreachable instruction, if there was one.
if (lastUnreachable) {
producedByUnreachable[lastUnreachable] = produced;
}
}
// Second pass: record srcs and dests. `values` is the simulated value stack;
// each entry is the Location at which the value was produced.
std::vector<Location> values;
Expression* lastUnreachable = nullptr;
processBlock([&](Expression* expr, const StackSignature sig) {
assert((sig.params.size() <= values.size() || lastUnreachable) &&
"Block inputs not yet supported");
// An unreachable instruction consumes everything on the stack, or its own
// params if that is more; any other instruction consumes exactly its params.
size_t consumed = sig.unreachable
? std::max(values.size(), sig.params.size())
: sig.params.size();
// For unreachable instructions, use the count computed in the first pass.
size_t produced =
sig.unreachable ? producedByUnreachable[expr] : sig.results.size();
srcs[expr] = std::vector<Location>(consumed);
dests[expr] = std::vector<Location>(produced);
assert(sig.params.size() <= consumed);
// At most one of these is nonzero:
//  - unreachableBeyondStack: values needed that the stack does not hold, so
//    they are taken from the last unreachable instruction;
//  - unreachableFromStack: extra stack values that an unreachable
//    instruction consumes beyond its own params.
size_t unreachableBeyondStack = 0;
size_t unreachableFromStack = 0;
if (consumed > values.size()) {
assert(consumed == sig.params.size());
unreachableBeyondStack = consumed - values.size();
} else if (consumed > sig.params.size()) {
assert(consumed == values.size());
unreachableFromStack = consumed - sig.params.size();
}
// Link each consumed value to its producer. Locations are initialized here
// as {expression, index, type, flag}.
// NOTE(review): the trailing flag presumably marks an unreachable
// connection; confirm against the Location declaration.
for (Index i = 0; i < consumed; ++i) {
if (i < unreachableBeyondStack) {
// This value is supplied by the last unreachable instruction. Its slot in
// that instruction's dests is counted from the end of the values that have
// not been handed out yet.
assert(lastUnreachable);
assert(producedByUnreachable[lastUnreachable] >=
unreachableBeyondStack);
Index destIndex =
producedByUnreachable[lastUnreachable] - unreachableBeyondStack + i;
Type type = sig.params[i];
srcs[expr][i] = {lastUnreachable, destIndex, type, true};
dests[lastUnreachable][destIndex] = {expr, i, type, false};
} else {
// A value taken from the simulated stack. The flag is set for the values
// counted by unreachableFromStack, which come first.
bool unreachable = i < unreachableFromStack;
// The consumed values end at the top of `values`; when some of them came
// from beyond the stack, this index starts at the bottom of `values`.
auto& src = values[values.size() + i - consumed];
srcs[expr][i] = src;
dests[src.expr][src.index] = {expr, i, src.type, unreachable};
}
}
// Remove the consumed values from the simulated stack.
if (unreachableBeyondStack) {
// The values taken from the last unreachable have now been handed out, and
// the whole stack was consumed as well.
producedByUnreachable[lastUnreachable] -= unreachableBeyondStack;
values.resize(0);
} else {
values.resize(values.size() - consumed);
}
// Push this instruction's results.
for (Index i = 0; i < sig.results.size(); ++i) {
values.push_back({expr, i, sig.results[i], false});
}
// Switch to the new unreachable instruction. By now every value charged to
// the previous one must have been handed out. If there was no previous
// unreachable, operator[] default-inserts 0 under the null key, so the
// assertion holds.
if (sig.unreachable) {
assert(producedByUnreachable[lastUnreachable] == 0);
lastUnreachable = expr;
}
});
}
// Builds the stack signature of `expr` from the recorded flow: the params are
// the types of its entries in `srcs`, the results are the types of its
// entries in `dests`, and the signature is unreachable exactly when the
// expression's type is unreachable. `expr` must have entries in both maps
// (asserted).
StackSignature StackFlow::getSignature(Expression* expr) {
  auto exprSrcs = srcs.find(expr);
  auto exprDests = dests.find(expr);
  assert(exprSrcs != srcs.end() && exprDests != dests.end());

  std::vector<Type> consumedTypes;
  for (const auto& incoming : exprSrcs->second) {
    consumedTypes.push_back(incoming.type);
  }

  std::vector<Type> producedTypes;
  for (const auto& outgoing : exprDests->second) {
    producedTypes.push_back(outgoing.type);
  }

  return StackSignature(Type(consumedTypes),
                        Type(producedTypes),
                        expr->type == Type::unreachable);
}
}