#include "type-updating.h"
#include "find_all.h"
#include "ir/local-structural-dominance.h"
#include "ir/module-utils.h"
#include "ir/utils.h"
#include "support/topological_sort.h"
#include "wasm-type-ordering.h"
#include "wasm-type.h"
#include "wasm.h"
namespace wasm {
GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm) : wasm(wasm) {}
void GlobalTypeRewriter::update() { mapTypes(rebuildTypes()); }
// Builds a fresh copy of every rewritable heap type and returns a map from each
// old type to its replacement.
//
// The rewritable types are the module's private heap types
// (ModuleUtils::getPrivateHeapTypes) plus `additionalPrivateTypes`. Each copy
// is built in `typeBuilder`. Along the way:
//  - references to other rewritable types are redirected to the new copies
//    (via getTempType);
//  - the modifySignature / modifyStruct / modifyArray / modifyTypeBuilderEntry
//    hooks may adjust each definition;
//  - declared supertypes (from getDeclaredSuperType) are preserved.
//
// Names in wasm.typeNames are copied from each old type to its new type.
// Returns an empty map when there is nothing to rewrite.
GlobalTypeRewriter::TypeMap GlobalTypeRewriter::rebuildTypes(
  const std::vector<HeapType>& additionalPrivateTypes) {
  Index i = 0;
  auto privateTypes = ModuleUtils::getPrivateHeapTypes(wasm);
  if (!additionalPrivateTypes.empty()) {
    // Append only types that are not already present, including repeats
    // within additionalPrivateTypes itself. Every type must receive exactly
    // one builder slot. A repeated input could otherwise be handed a second
    // index below, leaving indices >= typeIndices.size() (the number of
    // builder slots), depending on how the sort treats repeats.
    std::unordered_set<HeapType> seen(privateTypes.begin(),
                                      privateTypes.end());
    for (auto t : additionalPrivateTypes) {
      if (seen.insert(t).second) {
        privateTypes.push_back(t);
      }
    }
  }

  // Order the types so that supertypes precede their subtypes. The subtyping
  // code below relies on this (see the assert there). Supertypes are looked up
  // through this rewriter's getDeclaredSuperType().
  struct SupertypesFirst
    : HeapTypeOrdering::SupertypesFirstBase<SupertypesFirst> {
    GlobalTypeRewriter& parent;
    SupertypesFirst(GlobalTypeRewriter& parent) : parent(parent) {}
    std::optional<HeapType> getDeclaredSuperType(HeapType type) {
      return parent.getDeclaredSuperType(type);
    }
  };
  SupertypesFirst sortedTypes(*this);
  for (auto type : sortedTypes.sort(privateTypes)) {
    typeIndices[type] = i++;
  }

  // Nothing to rewrite.
  if (typeIndices.size() == 0) {
    return {};
  }

  typeBuilder.grow(typeIndices.size());
  // Put every new type into one recursion group. Presumably this keeps the
  // rebuilt types distinct and lets them reference each other freely.
  typeBuilder.createRecGroup(0, typeBuilder.size());

  // Define each builder slot from its old type.
  // NOTE(review): this assumes typeIndices iterates in insertion order, so
  // that `i` equals each type's stored index.
  i = 0;
  for (auto [type, _] : typeIndices) {
    typeBuilder[i].setOpen(type.isOpen());
    if (type.isSignature()) {
      auto sig = type.getSignature();
      TypeList newParams, newResults;
      for (auto t : sig.params) {
        newParams.push_back(getTempType(t));
      }
      for (auto t : sig.results) {
        newResults.push_back(getTempType(t));
      }
      Signature newSig(typeBuilder.getTempTupleType(newParams),
                       typeBuilder.getTempTupleType(newResults));
      modifySignature(type, newSig);
      typeBuilder[i] = newSig;
    } else if (type.isStruct()) {
      auto struct_ = type.getStruct();
      auto newStruct = struct_;
      for (auto& field : newStruct.fields) {
        field.type = getTempType(field.type);
      }
      modifyStruct(type, newStruct);
      typeBuilder[i] = newStruct;
    } else if (type.isArray()) {
      auto array = type.getArray();
      auto newArray = array;
      newArray.element.type = getTempType(newArray.element.type);
      modifyArray(type, newArray);
      typeBuilder[i] = newArray;
    } else {
      WASM_UNREACHABLE("bad type");
    }

    // Preserve the declared supertype. If the supertype is itself being
    // rebuilt, point at its new slot; the supertypes-first order guarantees
    // that slot was already defined. Otherwise keep the original supertype.
    if (auto super = getDeclaredSuperType(type)) {
      if (auto it = typeIndices.find(*super); it != typeIndices.end()) {
        assert(it->second < i);
        typeBuilder[i].subTypeOf(typeBuilder[it->second]);
      } else {
        typeBuilder[i].subTypeOf(*super);
      }
    }
    modifyTypeBuilderEntry(typeBuilder, i, type);
    i++;
  }

  auto buildResults = typeBuilder.build();
  // Check unconditionally. Previously this was debug-only, so a failed build
  // in a release build went on to dereference an error result below with no
  // diagnostic.
  if (auto* err = buildResults.getError()) {
    Fatal() << "Internal GlobalTypeRewriter build error: " << err->reason
            << " at index " << err->index;
  }
  auto& newTypes = *buildResults;

  TypeMap oldToNewTypes;
  for (auto [type, index] : typeIndices) {
    oldToNewTypes[type] = newTypes[index];
  }

  // Carry each old type's name over to its replacement.
  for (auto& [old, new_] : oldToNewTypes) {
    if (auto it = wasm.typeNames.find(old); it != wasm.typeNames.end()) {
      wasm.typeNames[new_] = it->second;
    }
  }
  return oldToNewTypes;
}
// Rewrites every use of a heap type in the module according to `oldToNewTypes`.
// This covers:
//  - expression types and type-valued fields in function bodies and in
//    module-level code;
//  - function signatures and vars;
//  - tables, element segments, globals and tags.
// Heap types absent from the map are left unchanged.
void GlobalTypeRewriter::mapTypes(const TypeMap& oldToNewTypes) {
// Walker that rewrites types in place. It holds only a reference to the map;
// create() gives each new instance that same reference.
struct CodeUpdater
: public WalkerPass<
PostWalker<CodeUpdater, UnifiedExpressionVisitor<CodeUpdater>>> {
bool isFunctionParallel() override { return true; }
const TypeMap& oldToNewTypes;
CodeUpdater(const TypeMap& oldToNewTypes) : oldToNewTypes(oldToNewTypes) {}
std::unique_ptr<Pass> create() override {
return std::make_unique<CodeUpdater>(oldToNewTypes);
}
// Maps a value type. References keep their nullability but point at the
// mapped heap type; tuples are mapped element-wise; all else is unchanged.
Type getNew(Type type) {
if (type.isRef()) {
return Type(getNew(type.getHeapType()), type.getNullability());
}
if (type.isTuple()) {
auto tuple = type.getTuple();
for (auto& t : tuple) {
t = getNew(t);
}
return Type(tuple);
}
return type;
}
// Maps a heap type through the table, or returns it unchanged if absent.
HeapType getNew(HeapType type) {
auto iter = oldToNewTypes.find(type);
if (iter != oldToNewTypes.end()) {
return iter->second;
}
return type;
}
// Maps a signature by mapping its params and results.
Signature getNew(Signature sig) {
return Signature(getNew(sig.params), getNew(sig.results));
}
void visitExpression(Expression* curr) {
// The types of local.get and local.tee are taken from the function's local
// declarations rather than mapped from the old expression type. Those
// declarations (signature and vars) were already mapped before this walk.
if (auto* get = curr->dynCast<LocalGet>()) {
curr->type = getFunction()->getLocalType(get->index);
return;
} else if (auto* tee = curr->dynCast<LocalSet>()) {
// Only tees carry the local's type. Plain sets (type none) and unreachable
// ones are left alone.
if (tee->type != Type::none && tee->type != Type::unreachable) {
curr->type = getFunction()->getLocalType(tee->index);
}
return;
}
curr->type = getNew(curr->type);
// Expand the per-expression field delegations. Every Type and HeapType field
// is mapped through getNew(); all other field kinds are ignored.
#define DELEGATE_ID curr->_id
#define DELEGATE_START(id) [[maybe_unused]] auto* cast = curr->cast<id>();
#define DELEGATE_GET_FIELD(id, field) cast->field
#define DELEGATE_FIELD_TYPE(id, field) cast->field = getNew(cast->field);
#define DELEGATE_FIELD_HEAPTYPE(id, field) cast->field = getNew(cast->field);
#define DELEGATE_FIELD_CHILD(id, field)
#define DELEGATE_FIELD_OPTIONAL_CHILD(id, field)
#define DELEGATE_FIELD_INT(id, field)
#define DELEGATE_FIELD_LITERAL(id, field)
#define DELEGATE_FIELD_NAME(id, field)
#define DELEGATE_FIELD_SCOPE_NAME_DEF(id, field)
#define DELEGATE_FIELD_SCOPE_NAME_USE(id, field)
#define DELEGATE_FIELD_ADDRESS(id, field)
#include "wasm-delegations-fields.def"
}
};
CodeUpdater updater(oldToNewTypes);
PassRunner runner(&wasm);
// Map signatures and vars first. The walk below reads local types from each
// function when fixing local.get and local.tee.
for (auto& func : wasm.functions) {
func->type = updater.getNew(func->type);
for (auto& var : func->vars) {
var = updater.getNew(var);
}
}
// Rewrite function bodies (as a function-parallel pass), then module-level
// code (presumably global initializers, segment offsets and the like).
updater.run(&runner, &wasm);
updater.walkModuleCode(&wasm);
// Rewrite types stored directly on module-level entities.
for (auto& table : wasm.tables) {
table->type = updater.getNew(table->type);
}
for (auto& elementSegment : wasm.elementSegments) {
elementSegment->type = updater.getNew(elementSegment->type);
}
for (auto& global : wasm.globals) {
global->type = updater.getNew(global->type);
}
for (auto& tag : wasm.tags) {
tag->sig = updater.getNew(tag->sig);
}
}
// Translates `type` into the rewriter's TypeBuilder:
//  - references to heap types being rebuilt (those in typeIndices) become
//    temporary references to the matching builder slot, keeping their
//    nullability;
//  - tuples are translated element by element;
//  - basic types and references to other heap types are returned unchanged.
Type GlobalTypeRewriter::getTempType(Type type) {
  if (type.isBasic()) {
    return type;
  }
  if (type.isRef()) {
    auto found = typeIndices.find(type.getHeapType());
    if (found == typeIndices.end()) {
      // Not one of the types being rebuilt; keep referring to it directly.
      return type;
    }
    return typeBuilder.getTempRefType(typeBuilder[found->second],
                                      type.getNullability());
  }
  if (type.isTuple()) {
    auto elems = type.getTuple();
    for (auto& elem : elems) {
      elem = getTempType(elem);
    }
    return typeBuilder.getTempTupleType(elems);
  }
  WASM_UNREACHABLE("bad type");
}
// Creates a temporary tuple type with the given elements in this rewriter's
// TypeBuilder.
Type GlobalTypeRewriter::getTempTupleType(Tuple tuple) {
  Type tempTuple = typeBuilder.getTempTupleType(tuple);
  return tempTuple;
}
namespace TypeUpdating {
// Returns whether a value of `type` can be held in a local, which is exactly
// when the type is concrete.
bool canHandleAsLocal(Type type) {
  const bool concrete = type.isConcrete();
  return concrete;
}
// Fixes up vars whose non-nullable type is invalid where they are used.
// "Invalid" means reported by LocalStructuralDominance in NonNullableOnly mode.
// For each such var:
//  - its declared type is relaxed with getValidLocalType();
//  - each local.get of it is replaced via fixLocalGet();
//  - each reachable local.tee of it is rewritten so the tee's result keeps the
//    original type.
// Params are never changed. The function returns early when reference types
// are disabled, when no var has a non-nullable type, or when every such var
// already validates.
void handleNonDefaultableLocals(Function* func, Module& wasm) {
if (!wasm.features.hasReferenceTypes()) {
return;
}
// Quick scan: does any var (or any tuple element of a var) have a
// non-nullable type? The `break` leaves only the inner loop.
bool hasNonNullable = false;
for (auto varType : func->vars) {
for (auto type : varType) {
if (type.isNonNullable()) {
hasNonNullable = true;
break;
}
}
}
if (!hasNonNullable) {
return;
}
// Collect the locals whose sets do not structurally dominate their gets.
// The asserts document the expectation that these are non-params with
// non-nullable (or tuple) types.
LocalStructuralDominance info(
func, wasm, LocalStructuralDominance::NonNullableOnly);
std::unordered_set<Index> badIndexes;
for (auto index : info.nonDominatingIndices) {
badIndexes.insert(index);
assert(func->getLocalType(index).isNonNullable() ||
func->getLocalType(index).isTuple());
assert(!func->isParam(index));
}
if (badIndexes.empty()) {
return;
}
Builder builder(wasm);
// Replace gets of bad locals with the adapted expression from fixLocalGet().
for (auto** getp : FindAllPointers<LocalGet>(func->body).list) {
auto* get = (*getp)->cast<LocalGet>();
if (badIndexes.count(get->index)) {
*getp = fixLocalGet(get, wasm);
}
}
// Adapt tees of bad vars. Plain sets and unreachable tees are skipped.
for (auto** setp : FindAllPointers<LocalSet>(func->body).list) {
auto* set = (*setp)->cast<LocalSet>();
if (!func->isVar(set->index)) {
continue;
}
if (!set->isTee() || set->type == Type::unreachable) {
continue;
}
if (badIndexes.count(set->index)) {
// Still the original (non-nullable) declared type: var types are only
// relaxed at the end of this function.
auto type = func->getLocalType(set->index);
auto validType = getValidLocalType(type, wasm.features);
if (type.isRef()) {
// The tee now yields the nullable type; ref.as_non_null restores the
// original type for the consumer.
set->type = validType;
*setp = builder.makeRefAs(RefAsNonNull, set);
} else {
// Tuple: turn the tee into a plain set, then rebuild the value by
// re-reading the local element by element, re-applying ref.as_non_null
// to each originally non-nullable element.
assert(type.isTuple());
set->makeSet();
std::vector<Expression*> elems(type.size());
for (size_t i = 0, size = type.size(); i < size; ++i) {
elems[i] = builder.makeTupleExtract(
builder.makeLocalGet(set->index, validType), i);
if (type[i].isNonNullable()) {
elems[i] = builder.makeRefAs(RefAsNonNull, elems[i]);
}
}
*setp =
builder.makeSequence(set, builder.makeTupleMake(std::move(elems)));
}
}
}
// Finally relax the declared var types. This is done last because the loops
// above read the original types via getLocalType().
for (auto index : badIndexes) {
func->vars[index - func->getNumParams()] =
getValidLocalType(func->getLocalType(index), wasm.features);
}
}
// Returns a type valid for a local in place of the concrete type `type`:
//  - non-nullable references become nullable;
//  - tuples are relaxed element by element;
//  - every other type is returned unchanged.
// `features` is only forwarded to the recursive calls.
Type getValidLocalType(Type type, FeatureSet features) {
  assert(type.isConcrete());
  if (type.isNonNullable()) {
    return Type(type.getHeapType(), Nullable);
  }
  if (!type.isTuple()) {
    return type;
  }
  std::vector<Type> relaxed;
  relaxed.reserve(type.size());
  for (auto elem : type) {
    relaxed.push_back(getValidLocalType(elem, features));
  }
  return Type(std::move(relaxed));
}
// Adapts a local.get whose local type is being relaxed by getValidLocalType(),
// so that the resulting expression still yields the get's original type.
// Returns the expression that should replace `get` (possibly `get` itself):
//  - non-nullable reference: `get` is made nullable and wrapped in
//    ref.as_non_null;
//  - tuple: `get` is retyped to the relaxed tuple, and the value is rebuilt
//    with tuple.make from one tuple.extract per element. The first extract
//    reuses `get`; the rest read the local again. Each originally non-nullable
//    element is wrapped in ref.as_non_null;
//  - anything else: `get` is returned untouched.
Expression* fixLocalGet(LocalGet* get, Module& wasm) {
  const Type original = get->type;
  if (original.isNonNullable()) {
    get->type = getValidLocalType(original, wasm.features);
    Builder builder(wasm);
    return builder.makeRefAs(RefAsNonNull, get);
  }
  if (!original.isTuple()) {
    return get;
  }
  get->type = getValidLocalType(original, wasm.features);
  Builder builder(wasm);
  const Index count = original.size();
  std::vector<Expression*> parts;
  parts.reserve(count);
  for (Index idx = 0; idx < count; ++idx) {
    Expression* source = get;
    if (idx > 0) {
      source = builder.makeLocalGet(get->index, get->type);
    }
    Expression* part = builder.makeTupleExtract(source, idx);
    if (original[idx].isNonNullable()) {
      part = builder.makeRefAs(RefAsNonNull, part);
    }
    parts.push_back(part);
  }
  return builder.makeTupleMake(std::move(parts));
}
// Prepares `func` for its params to take the types in `newParamTypes`
// (indexed by param index).
//
// A param that is assigned (local.set/tee) a value that is not a subtype of
// its new type is moved to a fresh var of its current type:
//  - the var is initialized from the param at function entry;
//  - every get and set of that param is redirected to the var.
//
// In Update mode, the remaining gets and tees of params are given the new
// param types. Finally the function is refinalized. If any vars were added,
// non-nullable locals are then fixed up.
// NOTE(review): this does not change func's signature itself; presumably the
// caller does that.
void updateParamTypes(Function* func,
const std::vector<Type>& newParamTypes,
Module& wasm,
LocalUpdatingMode localUpdating) {
// Map from param index to the fresh var that replaces it in the body.
std::unordered_map<Index, Index> paramFixups;
FindAll<LocalSet> sets(func->body);
for (auto* set : sets.list) {
auto index = set->index;
if (func->isParam(index) && !paramFixups.count(index) &&
!Type::isSubType(set->value->type, newParamTypes[index])) {
paramFixups[index] = Builder::addVar(func, func->getLocalType(index));
}
}
// Gets are collected now, before the body is wrapped below; the entry-block
// gets created below are therefore not in this list.
FindAll<LocalGet> gets(func->body);
if (!paramFixups.empty()) {
// Prepend "fixup = param" for each replaced param, in param order. In Update
// mode the param is read with its new type, otherwise with its current type.
Builder builder(wasm);
std::vector<Expression*> contents;
for (Index index = 0; index < func->getNumParams(); index++) {
auto iter = paramFixups.find(index);
if (iter != paramFixups.end()) {
auto fixup = iter->second;
contents.push_back(builder.makeLocalSet(
fixup,
builder.makeLocalGet(index,
localUpdating == Update
? newParamTypes[index]
: func->getLocalType(index))));
}
}
contents.push_back(func->body);
func->body = builder.makeBlock(contents);
// Redirect the original body's gets and sets of replaced params to their vars.
for (auto* get : gets.list) {
auto iter = paramFixups.find(get->index);
if (iter != paramFixups.end()) {
get->index = iter->second;
}
}
for (auto* set : sets.list) {
auto iter = paramFixups.find(set->index);
if (iter != paramFixups.end()) {
set->index = iter->second;
}
}
}
// Retype remaining param accesses (redirected ones now refer to vars and are
// skipped by the isParam checks).
if (localUpdating == Update) {
for (auto* get : gets.list) {
auto index = get->index;
if (func->isParam(index)) {
get->type = newParamTypes[index];
}
}
for (auto* set : sets.list) {
auto index = set->index;
if (func->isParam(index) && set->isTee()) {
set->type = newParamTypes[index];
set->finalize();
}
}
}
// Propagate the changed expression types through the rest of the body.
ReFinalize().walkFunctionInModule(func, &wasm);
// New vars copy the params' current types, which may be non-nullable.
if (!paramFixups.empty()) {
TypeUpdating::handleNonDefaultableLocals(func, wasm);
}
}
}
}