#include "analysis/cfg.h"
#include "analysis/lattice.h"
#include "analysis/lattices/inverted.h"
#include "analysis/lattices/shared.h"
#include "analysis/lattices/stack.h"
#include "analysis/lattices/tuple.h"
#include "analysis/lattices/valtype.h"
#include "analysis/lattices/vector.h"
#include "analysis/monotone-analyzer.h"
#include "ir/utils.h"
#include "pass.h"
#include "wasm-traversal.h"
#include "wasm.h"
// Set to 1 to make the pass print the CFG and the per-instruction analysis
// state to std::cerr while it runs.
#define TYPE_GENERALIZING_DEBUG 0
#if TYPE_GENERALIZING_DEBUG
// Debug build: emit the statement as written.
#define DBG(statement) statement
#else
// Non-debug build: the statement disappears entirely (arguments are not
// evaluated).
#define DBG(statement)
#endif
namespace wasm {
namespace {
using namespace analysis;
// A requirement on the type of a value: the value type lattice, inverted so
// that joining requirements moves toward more specific (sub)types.
using TypeRequirement = Inverted<ValType>;
// One requirement per local, wrapped in `Shared` so that all program points
// refer to a single vector of local requirements.
using LocalTypeRequirements = Shared<Vector<TypeRequirement>>;
// Requirements on the values on the value stack.
using ValueStackTypeRequirements = Stack<TypeRequirement>;
// The full analysis state: (local requirements, value stack requirements).
using StateLattice =
analysis::Tuple<LocalTypeRequirements, ValueStackTypeRequirements>;
struct State : StateLattice {
using Element = StateLattice::Element;
static constexpr int LocalsIndex = 0;
static constexpr int StackIndex = 1;
State(Function* func) : StateLattice{Shared{initLocals(func)}, initStack()} {}
void push(Element& elem, Type type) const noexcept {
stackLattice().push(stack(elem), std::move(type));
}
Type pop(Element& elem) const noexcept {
return stackLattice().pop(stack(elem));
}
void clearStack(Element& elem) const noexcept {
stack(elem) = stackLattice().getBottom();
}
const std::vector<Type>& getLocals(Element& elem) const noexcept {
return *locals(elem);
}
const std::vector<Type>& getLocals() const noexcept {
return *locals(getBottom());
}
Type getLocal(Element& elem, Index i) const noexcept {
return getLocals(elem)[i];
}
bool updateLocal(Element& elem, Index i, Type type) const noexcept {
return localsLattice().join(
locals(elem),
Vector<TypeRequirement>::SingletonElement(i, std::move(type)));
}
private:
static LocalTypeRequirements initLocals(Function* func) noexcept {
return Shared{Vector{Inverted{ValType{}}, func->getNumLocals()}};
}
static ValueStackTypeRequirements initStack() noexcept {
return Stack{Inverted{ValType{}}};
}
const LocalTypeRequirements& localsLattice() const noexcept {
return std::get<LocalsIndex>(lattices);
}
const ValueStackTypeRequirements& stackLattice() const noexcept {
return std::get<StackIndex>(lattices);
}
typename LocalTypeRequirements::Element&
locals(Element& elem) const noexcept {
return std::get<LocalsIndex>(elem);
}
const typename LocalTypeRequirements::Element&
locals(const Element& elem) const noexcept {
return std::get<LocalsIndex>(elem);
}
typename ValueStackTypeRequirements::Element&
stack(Element& elem) const noexcept {
return std::get<StackIndex>(elem);
}
const typename ValueStackTypeRequirements::Element&
stack(const Element& elem) const noexcept {
return std::get<StackIndex>(elem);
}
};
// Backward transfer function for type generalization.
//
// Each basic block is walked in reverse. For every expression, the
// requirement on its (reference-typed) result is popped from the value stack,
// and the requirements it places on its reference-typed operands are pushed.
// Operands are pushed in order, so the last operand (visited first in the
// reversed walk) ends up on top. Only reference-typed values are tracked: a
// visitor must push exactly one entry per reference-typed operand and pop
// exactly one per reference-typed result, or the stack goes out of alignment
// with the IR. `Type::none` is pushed for operands that carry no requirement
// (e.g. visitRefIsNull, visitRefTest).
struct TransferFn : OverriddenVisitor<TransferFn> {
  Module& wasm;
  Function* func;

  // Lattice describing analysis states.
  State lattice;

  // The state being transformed; only non-null during `transfer`.
  typename State::Element* state = nullptr;

  // For each local, the blocks containing a `local.set`/`local.tee` of it.
  // Those blocks read the local's requirement (see visitLocalSet), so they
  // must be re-analyzed whenever that requirement changes.
  std::vector<std::vector<const BasicBlock*>> localDependents;

  // Blocks that may need re-analysis because of the current transfer.
  std::unordered_set<const BasicBlock*> currDependents;

  TransferFn(Module& wasm, Function* func, CFG& cfg)
    : wasm(wasm), func(func), lattice(func),
      localDependents(func->getNumLocals()) {
    auto numLocals = func->getNumLocals();
    // Collect into sets first so each block is listed at most once per local.
    std::vector<std::unordered_set<const BasicBlock*>> dependentSets(numLocals);
    for (const auto& bb : cfg) {
      for (const auto* inst : bb) {
        if (auto set = inst->dynCast<LocalSet>()) {
          dependentSets[set->index].insert(&bb);
        }
      }
    }
    for (size_t i = 0, n = dependentSets.size(); i < n; ++i) {
      localDependents[i] = std::vector<const BasicBlock*>(
        dependentSets[i].begin(), dependentSets[i].end());
    }
  }

  Type pop() noexcept { return lattice.pop(*state); }
  void push(Type type) noexcept { lattice.push(*state, type); }
  void clearStack() noexcept { lattice.clearStack(*state); }
  Type getLocal(Index i) noexcept { return lattice.getLocal(*state, i); }
  void updateLocal(Index i, Type type) noexcept {
    if (lattice.updateLocal(*state, i, type)) {
      currDependents.insert(localDependents[i].begin(),
                            localDependents[i].end());
    }
  }

  // Whether a value of type `type` may flow into a location with requirement
  // `req`. An absent requirement, or the "no requirement" element
  // `Type::none`, is satisfied by every type; otherwise `type` must be a
  // subtype of `req`.
  static bool satisfiesRequirement(Type type, std::optional<Type> req) {
    return !req || *req == Type::none || Type::isSubType(type, *req);
  }

  // Print the current state to stderr when TYPE_GENERALIZING_DEBUG is set.
  void dumpState() {
#if TYPE_GENERALIZING_DEBUG
    std::cerr << "locals: ";
    for (size_t i = 0, n = lattice.getLocals(*state).size(); i < n; ++i) {
      if (i != 0) {
        std::cerr << ", ";
      }
      std::cerr << getLocal(i);
    }
    std::cerr << "\nstack: ";
    auto& stack = std::get<State::StackIndex>(*state);
    for (size_t i = 0, n = stack.size(); i < n; ++i) {
      if (i != 0) {
        std::cerr << ", ";
      }
      std::cerr << stack[i];
    }
    std::cerr << "\n";
#endif
  }

  // Transfer `elem` backward through `bb` and return the blocks that may need
  // to be re-analyzed as a result.
  std::unordered_set<const BasicBlock*>
  transfer(const BasicBlock& bb, typename State::Element& elem) noexcept {
    DBG(std::cerr << "transferring bb " << bb.getIndex() << "\n");
    state = &elem;

    // This is a backward analysis, so the predecessors depend on this block.
    assert(currDependents.empty());
    const auto& preds = bb.preds();
    currDependents.insert(preds.begin(), preds.end());

    dumpState();
    if (bb.isExit()) {
      DBG(std::cerr << "visiting exit\n");
      visitFunctionExit();
      dumpState();
    }
    for (auto it = bb.rbegin(); it != bb.rend(); ++it) {
      DBG(std::cerr << "visiting " << ShallowExpression{*it} << "\n");
      visit(*it);
      dumpState();
    }
    if (bb.isEntry()) {
      DBG(std::cerr << "visiting entry\n");
      visitFunctionEntry();
      dumpState();
    }
    DBG(std::cerr << "\n");

    state = nullptr;

    // Swap rather than move out of the member: a moved-from container is only
    // "valid but unspecified", while the next call asserts it is empty.
    std::unordered_set<const BasicBlock*> dependents;
    dependents.swap(currDependents);
    return dependents;
  }

  // Result types cannot change, so the final value must keep its type.
  void visitFunctionExit() {
    auto result = func->getResults();
    if (result.isRef()) {
      push(result);
    }
  }

  // Parameter types cannot change. Other locals are default-initialized, so
  // reference locals are required to stay within a nullable type of their
  // hierarchy.
  void visitFunctionEntry() {
    Index numParams = func->getNumParams();
    Index numLocals = func->getNumLocals();
    for (Index i = 0; i < numParams; ++i) {
      updateLocal(i, func->getLocalType(i));
    }
    for (Index i = numParams; i < numLocals; ++i) {
      auto type = func->getLocalType(i);
      if (type.isRef()) {
        updateLocal(i, Type(type.getHeapType().getTop(), Nullable));
      } else {
        updateLocal(i, type);
      }
    }
  }

  void visitNop(Nop* curr) {}
  void visitBlock(Block* curr) {}
  void visitIf(If* curr) {}
  void visitLoop(Loop* curr) {}

  // Shared handling for unconditional branches with an optional value. The
  // requirement on the sent value is the one on top of the successor state;
  // nothing below it is used along this path.
  void handleUnconditionalBranch(Expression* value) {
    if (value && value->type.isRef()) {
      auto type = pop();
      clearStack();
      push(type);
    } else {
      clearStack();
    }
  }

  void visitBreak(Break* curr) {
    if (curr->condition) {
      WASM_UNREACHABLE("TODO");
    }
    handleUnconditionalBranch(curr->value);
  }
  void visitSwitch(Switch* curr) { handleUnconditionalBranch(curr->value); }

  // Calls keep their parameter types (no interprocedural analysis).
  template<typename T> void handleCall(T* curr, Type params) {
    if (curr->type.isRef()) {
      pop();
    }
    for (auto param : params) {
      if (param.isRef()) {
        push(param);
      }
    }
  }
  void visitCall(Call* curr) {
    handleCall(curr, wasm.getFunction(curr->target)->getParams());
  }
  void visitCallIndirect(CallIndirect* curr) {
    handleCall(curr, curr->heapType.getSignature().params);
  }

  // The requirement on the gotten value becomes a requirement on the local.
  void visitLocalGet(LocalGet* curr) {
    if (!curr->type.isRef()) {
      return;
    }
    updateLocal(curr->index, pop());
  }
  void visitLocalSet(LocalSet* curr) {
    if (!curr->value->type.isRef()) {
      return;
    }
    if (curr->isTee()) {
      // The tee's result flows from the local, like a local.get.
      updateLocal(curr->index, pop());
    }
    // The stored value must fit the local's requirement.
    push(getLocal(curr->index));
  }
  void visitGlobalGet(GlobalGet* curr) {
    if (curr->type.isRef()) {
      pop();
    }
  }
  void visitGlobalSet(GlobalSet* curr) {
    auto type = wasm.getGlobal(curr->name)->type;
    if (type.isRef()) {
      push(type);
    }
  }
  void visitLoad(Load* curr) {}
  void visitStore(Store* curr) {}
  void visitAtomicRMW(AtomicRMW* curr) {}
  void visitAtomicCmpxchg(AtomicCmpxchg* curr) {}
  void visitAtomicWait(AtomicWait* curr) {}
  void visitAtomicNotify(AtomicNotify* curr) {}
  void visitAtomicFence(AtomicFence* curr) {}
  void visitSIMDExtract(SIMDExtract* curr) {}
  void visitSIMDReplace(SIMDReplace* curr) {}
  void visitSIMDShuffle(SIMDShuffle* curr) {}
  void visitSIMDTernary(SIMDTernary* curr) {}
  void visitSIMDShift(SIMDShift* curr) {}
  void visitSIMDLoad(SIMDLoad* curr) {}
  void visitSIMDLoadStoreLane(SIMDLoadStoreLane* curr) {}
  void visitMemoryInit(MemoryInit* curr) {}
  void visitDataDrop(DataDrop* curr) {}
  void visitMemoryCopy(MemoryCopy* curr) {}
  void visitMemoryFill(MemoryFill* curr) {}
  void visitConst(Const* curr) {}
  void visitUnary(Unary* curr) {}
  void visitBinary(Binary* curr) {}
  void visitSelect(Select* curr) {
    if (curr->type.isRef()) {
      // Both arms may be as general as the result.
      auto type = pop();
      push(type);
      push(type);
    }
  }
  void visitDrop(Drop* curr) {
    // A drop has no result (its type is never a reference), so nothing is
    // popped. Its reference operand is unconstrained, but it still needs a
    // stack slot so the operand's own transfer pops the right entry.
    if (curr->value->type.isRef()) {
      push(Type::none);
    }
  }
  // Return values are constrained by the exit block (visitFunctionExit).
  void visitReturn(Return* curr) {}
  void visitMemorySize(MemorySize* curr) {}
  void visitMemoryGrow(MemoryGrow* curr) {}
  void visitUnreachable(Unreachable* curr) { clearStack(); }
  void visitPop(Pop* curr) { WASM_UNREACHABLE("TODO"); }
  void visitRefNull(RefNull* curr) { pop(); }
  void visitRefIsNull(RefIsNull* curr) { push(Type::none); }
  void visitRefFunc(RefFunc* curr) { pop(); }
  void visitRefEq(RefEq* curr) {
    auto eqref = Type(HeapType::eq, Nullable);
    push(eqref);
    push(eqref);
  }
  void visitTableGet(TableGet* curr) { pop(); }
  void visitTableSet(TableSet* curr) {
    push(wasm.getTable(curr->table)->type);
  }
  void visitTableSize(TableSize* curr) {}
  void visitTableGrow(TableGrow* curr) {
    // The reference used to fill the new slots must fit the table's type.
    push(wasm.getTable(curr->table)->type);
  }
  void visitTableFill(TableFill* curr) {
    push(wasm.getTable(curr->table)->type);
  }
  void visitTableCopy(TableCopy* curr) {}
  void visitTry(Try* curr) { WASM_UNREACHABLE("TODO"); }
  void visitTryTable(TryTable* curr) { WASM_UNREACHABLE("TODO"); }
  void visitThrow(Throw* curr) { WASM_UNREACHABLE("TODO"); }
  void visitRethrow(Rethrow* curr) { WASM_UNREACHABLE("TODO"); }
  void visitThrowRef(ThrowRef* curr) { WASM_UNREACHABLE("TODO"); }
  void visitTupleMake(TupleMake* curr) { WASM_UNREACHABLE("TODO"); }
  void visitTupleExtract(TupleExtract* curr) { WASM_UNREACHABLE("TODO"); }
  void visitRefI31(RefI31* curr) { pop(); }
  void visitI31Get(I31Get* curr) { push(Type(HeapType::i31, Nullable)); }

  void visitCallRef(CallRef* curr) {
    auto sigType = curr->target->type.getHeapType();
    if (sigType.isBottom()) {
      // The call cannot execute: keep the target at bottom type and impose no
      // requirements on the arguments.
      clearStack();
      push(Type(HeapType::nofunc, Nullable));
      return;
    }
    auto sig = sigType.getSignature();
    auto numParams = sig.params.size();
    std::optional<Type> resultReq;
    if (sig.results.isRef()) {
      resultReq = pop();
    }
    // We only generalize across supertypes whose parameters are unchanged.
    auto sameParams = [&](const auto& candidate) {
      for (size_t i = 0; i < numParams; ++i) {
        if (candidate.params[i] != sig.params[i]) {
          return false;
        }
      }
      return true;
    };
    // Climb the declared supertype chain while the candidate's results still
    // satisfy the requirement on the call's result.
    auto targetReq = sigType;
    while (auto candidateReq = targetReq.getDeclaredSuperType()) {
      auto candidateSig = candidateReq->getSignature();
      if (!satisfiesRequirement(candidateSig.results, resultReq) ||
          !sameParams(candidateSig)) {
        break;
      }
      targetReq = *candidateReq;
    }
    auto targetSig = targetReq.getSignature();
    for (auto param : targetSig.params) {
      if (param.isRef()) {
        push(param);
      }
    }
    push(Type(targetReq, Nullable));
  }
  void visitRefTest(RefTest* curr) { push(Type::none); }
  void visitRefCast(RefCast* curr) {
    pop();
    push(Type::none);
  }
  void visitBrOn(BrOn* curr) { WASM_UNREACHABLE("TODO"); }

  void visitStructNew(StructNew* curr) {
    pop();
    if (!curr->isWithDefault()) {
      auto type = curr->type.getHeapType();
      for (const auto& field : type.getStruct().fields) {
        if (field.type.isRef()) {
          push(field.type);
        }
      }
    }
  }

  // Return the most general declared supertype of `type` that still has field
  // `index` and whose field type satisfies `reqFieldType` (if given).
  HeapType
  generalizeStructType(HeapType type,
                       Index index,
                       std::optional<Type> reqFieldType = std::nullopt) {
    while (auto candidateType = type.getDeclaredSuperType()) {
      const auto& candidateFields = candidateType->getStruct().fields;
      if (candidateFields.size() <= index) {
        break;
      }
      if (!satisfiesRequirement(candidateFields[index].type, reqFieldType)) {
        break;
      }
      type = *candidateType;
    }
    return type;
  }

  void visitStructGet(StructGet* curr) {
    auto type = curr->ref->type.getHeapType();
    if (type.isBottom()) {
      clearStack();
      push(Type(HeapType::none, Nullable));
      return;
    }
    std::optional<Type> reqFieldType;
    if (curr->type.isRef()) {
      reqFieldType = pop();
    }
    auto generalized = generalizeStructType(type, curr->index, reqFieldType);
    push(Type(generalized, Nullable));
  }
  void visitStructSet(StructSet* curr) {
    auto type = curr->ref->type.getHeapType();
    if (type.isBottom()) {
      clearStack();
      push(Type(HeapType::none, Nullable));
      if (curr->value->type.isRef()) {
        push(Type::none);
      }
      return;
    }
    auto generalized = generalizeStructType(type, curr->index);
    push(Type(generalized, Nullable));
    // Only reference-typed values are tracked on the stack.
    auto fieldType = generalized.getStruct().fields[curr->index].type;
    if (fieldType.isRef()) {
      push(fieldType);
    }
  }

  void visitArrayNew(ArrayNew* curr) {
    pop();
    if (!curr->isWithDefault()) {
      auto type = curr->type.getHeapType();
      auto fieldType = type.getArray().element.type;
      if (fieldType.isRef()) {
        push(fieldType);
      }
    }
  }
  void visitArrayNewData(ArrayNewData* curr) { pop(); }
  void visitArrayNewElem(ArrayNewElem* curr) { pop(); }
  void visitArrayNewFixed(ArrayNewFixed* curr) {
    pop();
    auto type = curr->type.getHeapType();
    auto fieldType = type.getArray().element.type;
    if (fieldType.isRef()) {
      for (size_t i = 0, n = curr->values.size(); i < n; ++i) {
        push(fieldType);
      }
    }
  }

  // Return the most general declared supertype of array type `type` whose
  // element type satisfies `reqFieldType` (if given).
  HeapType
  generalizeArrayType(HeapType type,
                      std::optional<Type> reqFieldType = std::nullopt) {
    while (auto candidateType = type.getDeclaredSuperType()) {
      auto candidateFieldType = candidateType->getArray().element.type;
      if (!satisfiesRequirement(candidateFieldType, reqFieldType)) {
        break;
      }
      type = *candidateType;
    }
    return type;
  }

  void visitArrayGet(ArrayGet* curr) {
    auto type = curr->ref->type.getHeapType();
    if (type.isBottom()) {
      clearStack();
      push(Type(HeapType::none, Nullable));
      return;
    }
    std::optional<Type> reqFieldType;
    if (curr->type.isRef()) {
      reqFieldType = pop();
    }
    auto generalized = generalizeArrayType(type, reqFieldType);
    push(Type(generalized, Nullable));
  }
  void visitArraySet(ArraySet* curr) {
    auto type = curr->ref->type.getHeapType();
    if (type.isBottom()) {
      clearStack();
      push(Type(HeapType::none, Nullable));
      if (curr->value->type.isRef()) {
        push(Type::none);
      }
      return;
    }
    auto generalized = generalizeArrayType(type);
    push(Type(generalized, Nullable));
    auto elemType = generalized.getArray().element.type;
    if (elemType.isRef()) {
      push(elemType);
    }
  }
  void visitArrayLen(ArrayLen* curr) {
    push(Type(HeapType::array, Nullable));
  }
  void visitArrayCopy(ArrayCopy* curr) {
    auto destType = curr->destRef->type.getHeapType();
    auto srcType = curr->srcRef->type.getHeapType();
    if (destType.isBottom() || srcType.isBottom()) {
      clearStack();
      auto nullref = Type(HeapType::none, Nullable);
      push(destType.isBottom() ? nullref : Type::none);
      push(srcType.isBottom() ? nullref : Type::none);
      return;
    }
    // Model the copy as a set into the destination followed (in this backward
    // walk) by a get from the source.
    ArraySet set;
    set.ref = curr->destRef;
    set.index = nullptr;
    set.value = nullptr;
    visitArraySet(&set);
    ArrayGet get;
    get.ref = curr->srcRef;
    get.index = nullptr;
    get.type = srcType.getArray().element.type;
    visitArrayGet(&get);
  }
  void visitArrayFill(ArrayFill* curr) {
    // Model the fill as a set.
    ArraySet set;
    set.ref = curr->ref;
    set.value = curr->value;
    visitArraySet(&set);
  }
  void visitArrayInitData(ArrayInitData* curr) {
    auto type = curr->ref->type.getHeapType();
    if (type.isBottom()) {
      clearStack();
      push(Type(HeapType::none, Nullable));
      return;
    }
    auto generalized = generalizeArrayType(type);
    push(Type(generalized, Nullable));
  }
  void visitArrayInitElem(ArrayInitElem* curr) {
    auto type = curr->ref->type.getHeapType();
    if (type.isBottom()) {
      clearStack();
      push(Type(HeapType::none, Nullable));
      return;
    }
    auto generalized = generalizeArrayType(type);
    push(Type(generalized, Nullable));
  }
  void visitRefAs(RefAs* curr) {
    auto type = pop();
    // `Type::none` means the result is unconstrained. It has no heap type, so
    // it must not reach getHeapType(), and it must not be turned into a
    // non-null requirement.
    bool unconstrained = type == Type::none;
    auto nullability = unconstrained ? Nullable : type.getNullability();
    switch (curr->op) {
      case RefAsNonNull:
        if (unconstrained) {
          push(Type::none);
        } else {
          push(Type(type.getHeapType(), Nullable));
        }
        return;
      case ExternInternalize:
        push(Type(HeapType::ext, nullability));
        return;
      case ExternExternalize:
        push(Type(HeapType::any, nullability));
        return;
    }
    WASM_UNREACHABLE("unexpected op");
  }
  void visitStringNew(StringNew* curr) { WASM_UNREACHABLE("TODO"); }
  void visitStringConst(StringConst* curr) { WASM_UNREACHABLE("TODO"); }
  void visitStringMeasure(StringMeasure* curr) { WASM_UNREACHABLE("TODO"); }
  void visitStringEncode(StringEncode* curr) { WASM_UNREACHABLE("TODO"); }
  void visitStringConcat(StringConcat* curr) { WASM_UNREACHABLE("TODO"); }
  void visitStringEq(StringEq* curr) { WASM_UNREACHABLE("TODO"); }
  void visitStringAs(StringAs* curr) { WASM_UNREACHABLE("TODO"); }
  void visitStringWTF8Advance(StringWTF8Advance* curr) {
    WASM_UNREACHABLE("TODO");
  }
  void visitStringWTF16Get(StringWTF16Get* curr) { WASM_UNREACHABLE("TODO"); }
  void visitStringIterNext(StringIterNext* curr) { WASM_UNREACHABLE("TODO"); }
  void visitStringIterMove(StringIterMove* curr) { WASM_UNREACHABLE("TODO"); }
  void visitStringSliceWTF(StringSliceWTF* curr) { WASM_UNREACHABLE("TODO"); }
  void visitStringSliceIter(StringSliceIter* curr) { WASM_UNREACHABLE("TODO"); }
  void visitContBind(ContBind* curr) { WASM_UNREACHABLE("TODO"); }
  void visitContNew(ContNew* curr) { WASM_UNREACHABLE("TODO"); }
  void visitResume(Resume* curr) { WASM_UNREACHABLE("TODO"); }
};
// Pass that runs the backward type-generalizing analysis on each function.
// It rewrites each non-parameter local to its computed requirement, then
// updates local.get/local.tee types to match.
struct TypeGeneralizing : WalkerPass<PostWalker<TypeGeneralizing>> {
  // Computed requirement for every local of the current function, indexed by
  // local index (parameters first).
  std::vector<Type> localTypes;

  // Set when some expression type was changed and parents must be
  // re-finalized.
  bool refinalize = false;

  bool isFunctionParallel() override { return true; }

  std::unique_ptr<Pass> create() override {
    return std::make_unique<TypeGeneralizing>();
  }

  void runOnFunction(Module* wasm, Function* func) override {
    // Remove dead code first so the analysis only sees reachable code.
    PassRunner dceRunner(getPassRunner());
    dceRunner.add("dce");
    dceRunner.runOnFunction(func);

    auto cfg = CFG::fromFunction(func);
    DBG(cfg.print(std::cerr));
    TransferFn transferFn(*wasm, func, cfg);
    MonotoneCFGAnalyzer analyzer(transferFn.lattice, transferFn, cfg);
    analyzer.evaluate();

    // Parameters keep their types; only vars are rewritten.
    localTypes = transferFn.lattice.getLocals();
    const auto firstVar = func->getNumParams();
    for (Index local = firstVar; local < localTypes.size(); ++local) {
      func->vars[local - firstVar] = localTypes[local];
    }

    // Walk the body to update local.get/local.tee types.
    super::runOnFunction(wasm, func);

    if (refinalize) {
      ReFinalize().walkFunctionInModule(func, wasm);
    }
  }

  void visitLocalGet(LocalGet* curr) {
    auto newType = localTypes[curr->index];
    if (newType == curr->type) {
      return;
    }
    curr->type = newType;
    refinalize = true;
  }

  void visitLocalSet(LocalSet* curr) {
    // Plain local.set has no result type to update.
    if (!curr->isTee()) {
      return;
    }
    auto newType = localTypes[curr->index];
    if (newType == curr->type) {
      return;
    }
    curr->type = newType;
    refinalize = true;
  }
};
}
// Create a new TypeGeneralizing pass instance; the caller owns the returned
// object.
Pass* createTypeGeneralizingPass() {
  auto* pass = new TypeGeneralizing;
  return pass;
}
}