#include "module-utils.h"
#include "ir/debug.h"
#include "ir/intrinsics.h"
#include "ir/manipulation.h"
#include "ir/properties.h"
#include "support/insert_ordered.h"
#include "support/topological_sort.h"
namespace wasm::ModuleUtils {
// Rewrites the file index of every debug location in `locations` through
// `fileIndexMap` (old index -> new index). std::set elements cannot be
// modified in place, so a fresh set is built from remapped copies and then
// moved back into `locations`.
static void updateLocationSet(std::set<Function::DebugLocation>& locations,
                              std::vector<Index>& fileIndexMap) {
  std::set<Function::DebugLocation> remapped;
  for (const auto& location : locations) {
    auto moved = location;
    moved.fileIndex = fileIndexMap[moved.fileIndex];
    remapped.insert(moved);
  }
  locations = std::move(remapped);
}
// Copies `func` into `out` and returns the copy owned by `out`.
//
// The copy takes `newName` when it is set, and otherwise keeps `func`'s name.
// The body is deep-copied into `out` and its debug info is carried over.
// When `fileIndexMap` is provided, the file index of every debug location
// (per-expression, prolog and epilog) is translated through it.
// copyModuleItems uses this to point locations at `out`'s debug file-name
// table.
//
// Stack IR is not copied: the source function must not have any.
Function* copyFunction(Function* func,
                       Module& out,
                       Name newName,
                       std::optional<std::vector<Index>> fileIndexMap) {
  auto ret = std::make_unique<Function>();
  ret->name = newName.is() ? newName : func->name;
  // Keep the explicit-name flag, as the table/memory/segment copy helpers do.
  ret->hasExplicitName = func->hasExplicitName;
  ret->type = func->type;
  ret->vars = func->vars;
  ret->localNames = func->localNames;
  ret->localIndices = func->localIndices;
  ret->body = ExpressionManipulator::copy(func->body, out);
  debug::copyDebugInfo(func->body, ret->body, func, ret.get());
  ret->prologLocation = func->prologLocation;
  ret->epilogLocation = func->epilogLocation;
  if (fileIndexMap) {
    for (auto& iter : ret->debugLocations) {
      iter.second.fileIndex = (*fileIndexMap)[iter.second.fileIndex];
    }
    updateLocationSet(ret->prologLocation, *fileIndexMap);
    updateLocationSet(ret->epilogLocation, *fileIndexMap);
  }
  ret->module = func->module;
  ret->base = func->base;
  ret->noFullInline = func->noFullInline;
  ret->noPartialInline = func->noPartialInline;
  // TODO: copy Stack IR.
  assert(!func->stackIR);
  return out.addFunction(std::move(ret));
}
// Copies `global` into `out` and returns the copy owned by `out`.
// Imported globals have no initializer. Defined globals get a deep copy of
// their init expression.
Global* copyGlobal(Global* global, Module& out) {
  auto ret = std::make_unique<Global>();
  ret->name = global->name;
  // Keep the explicit-name flag, as the table/memory/segment copy helpers do.
  ret->hasExplicitName = global->hasExplicitName;
  ret->type = global->type;
  ret->mutable_ = global->mutable_;
  ret->module = global->module;
  ret->base = global->base;
  if (global->imported()) {
    ret->init = nullptr;
  } else {
    ret->init = ExpressionManipulator::copy(global->init, out);
  }
  // Hand ownership to the module instead of passing a raw owning pointer.
  return out.addGlobal(std::move(ret));
}
// Copies `tag` (name, signature and import info) into `out` and returns the
// copy owned by `out`.
Tag* copyTag(Tag* tag, Module& out) {
  auto ret = std::make_unique<Tag>();
  ret->name = tag->name;
  // Keep the explicit-name flag, as the table/memory/segment copy helpers do.
  ret->hasExplicitName = tag->hasExplicitName;
  ret->sig = tag->sig;
  ret->module = tag->module;
  ret->base = tag->base;
  // Hand ownership to the module instead of passing a raw owning pointer.
  return out.addTag(std::move(ret));
}
// Deep-copies an element segment into `out` and returns the copy owned by
// `out`. A segment with no table is built with the default constructor.
// Otherwise its offset expression is copied first and handed to the
// constructor along with the table name. Every item expression is deep-copied.
ElementSegment* copyElementSegment(const ElementSegment* segment, Module& out) {
  std::unique_ptr<ElementSegment> copied;
  if (!segment->table.isNull()) {
    auto* offset = ExpressionManipulator::copy(segment->offset, out);
    copied = std::make_unique<ElementSegment>(segment->table, offset);
  } else {
    copied = std::make_unique<ElementSegment>();
  }

  copied->name = segment->name;
  copied->hasExplicitName = segment->hasExplicitName;
  copied->type = segment->type;

  copied->data.reserve(segment->data.size());
  for (auto* entry : segment->data) {
    copied->data.push_back(ExpressionManipulator::copy(entry, out));
  }
  return out.addElementSegment(std::move(copied));
}
// Copies `table`'s declaration (name, import info, element type and limits)
// into `out` and returns the copy owned by `out`.
Table* copyTable(const Table* table, Module& out) {
  auto copied = std::make_unique<Table>();
  // Identity.
  copied->name = table->name;
  copied->hasExplicitName = table->hasExplicitName;
  // Import module/base.
  copied->module = table->module;
  copied->base = table->base;
  // Element type and size limits.
  copied->type = table->type;
  copied->initial = table->initial;
  copied->max = table->max;
  return out.addTable(std::move(copied));
}
// Copies `memory`'s declaration (name, limits, sharing, index type and import
// info) into `out` and returns the copy owned by `out`.
Memory* copyMemory(const Memory* memory, Module& out) {
  auto copied = Builder::makeMemory(memory->name);
  copied->hasExplicitName = memory->hasExplicitName;
  // Import module/base.
  copied->module = memory->module;
  copied->base = memory->base;
  // Limits and memory kind.
  copied->initial = memory->initial;
  copied->max = memory->max;
  copied->shared = memory->shared;
  copied->indexType = memory->indexType;
  return out.addMemory(std::move(copied));
}
// Copies a data segment into `out` and returns the copy owned by `out`.
// Only active (non-passive) segments have an offset expression to deep-copy.
DataSegment* copyDataSegment(const DataSegment* segment, Module& out) {
  auto copied = Builder::makeDataSegment();
  copied->name = segment->name;
  copied->hasExplicitName = segment->hasExplicitName;
  copied->memory = segment->memory;
  copied->isPassive = segment->isPassive;
  if (!copied->isPassive) {
    copied->offset = ExpressionManipulator::copy(segment->offset, out);
  }
  copied->data = segment->data;
  return out.addDataSegment(std::move(copied));
}
void copyModuleItems(const Module& in, Module& out) {
std::optional<std::vector<Index>> fileIndexMap;
if (!in.debugInfoFileNames.empty()) {
std::unordered_map<std::string, Index> debugInfoFileIndices;
for (Index i = 0; i < out.debugInfoFileNames.size(); i++) {
debugInfoFileIndices[out.debugInfoFileNames[i]] = i;
}
fileIndexMap.emplace();
for (Index i = 0; i < in.debugInfoFileNames.size(); i++) {
std::string file = in.debugInfoFileNames[i];
auto iter = debugInfoFileIndices.find(file);
if (iter == debugInfoFileIndices.end()) {
Index index = out.debugInfoFileNames.size();
out.debugInfoFileNames.push_back(file);
debugInfoFileIndices[file] = index;
}
fileIndexMap->push_back(debugInfoFileIndices[file]);
}
}
for (auto& curr : in.functions) {
copyFunction(curr.get(), out, Name(), fileIndexMap);
}
for (auto& curr : in.globals) {
copyGlobal(curr.get(), out);
}
for (auto& curr : in.tags) {
copyTag(curr.get(), out);
}
for (auto& curr : in.elementSegments) {
copyElementSegment(curr.get(), out);
}
for (auto& curr : in.tables) {
copyTable(curr.get(), out);
}
for (auto& curr : in.memories) {
copyMemory(curr.get(), out);
}
for (auto& curr : in.dataSegments) {
copyDataSegment(curr.get(), out);
}
}
// Copies the whole of `in` into `out`: exports, all module items (via
// copyModuleItems), the start function, custom sections, features and type
// names.
//
// Debug-info file names are NOT assigned here. copyModuleItems has already
// merged `in`'s names into `out.debugInfoFileNames` and remapped every copied
// function's file indices into that merged list. Overwriting the list with
// `in.debugInfoFileNames` afterwards would misdirect those indices whenever
// `in` lists a file more than once (the merged list is deduplicated). It
// would also discard any names `out` already had.
void copyModule(const Module& in, Module& out) {
  // Exports refer to items by name, so a plain copy of each export suffices.
  for (auto& curr : in.exports) {
    out.addExport(std::make_unique<Export>(*curr));
  }
  copyModuleItems(in, out);
  out.start = in.start;
  out.customSections = in.customSections;
  out.features = in.features;
  out.typeNames = in.typeNames;
}
// Resets `wasm` to a freshly default-constructed Module, in place: the old
// object is destroyed and a new one is constructed in the same storage.
void clearModule(Module& wasm) {
  std::destroy_at(&wasm);
  new (&wasm) Module;
}
// Renames functions according to `map` (old name -> new name) and rewrites
// every reference to them: the start function, function exports, and
// `call` / `ref.func` operands in function bodies and in module-level code.
// `T` must support iteration over (old, new) pairs and `find`.
template<typename T> void renameFunctions(Module& wasm, T& map) {
  // Rename the functions themselves. A new name may only already exist if it
  // belongs to the very function being renamed.
  for (auto& [oldName, newName] : map) {
    if (Function* func = wasm.getFunctionOrNull(oldName)) {
      assert(!wasm.getFunctionOrNull(newName) || func->name == newName);
      func->name = newName;
    }
  }
  // Names changed, so refresh the module's name lookup maps.
  wasm.updateMaps();

  // Rewrites function-name operands that appear in expressions.
  struct Updater : public WalkerPass<PostWalker<Updater>> {
    bool isFunctionParallel() override { return true; }

    T& map;

    void maybeUpdate(Name& name) {
      if (auto iter = map.find(name); iter != map.end()) {
        name = iter->second;
      }
    }

    Updater(T& map) : map(map) {}

    std::unique_ptr<Pass> create() override {
      return std::make_unique<Updater>(map);
    }

    void visitCall(Call* curr) { maybeUpdate(curr->target); }

    void visitRefFunc(RefFunc* curr) { maybeUpdate(curr->func); }
  };

  Updater updater(map);
  updater.maybeUpdate(wasm.start);
  // Exports name their function directly and are not expressions, so the
  // walker below never visits them. Without this, renaming an exported
  // function would leave its export pointing at a name that no longer exists.
  for (auto& exp : wasm.exports) {
    if (exp->kind == ExternalKind::Function) {
      updater.maybeUpdate(exp->value);
    }
  }
  PassRunner runner(&wasm);
  updater.run(&runner, &wasm);
  updater.runOnModuleCode(&runner, &wasm);
}
// Renames a single function and all references to it; a one-entry
// convenience wrapper around renameFunctions.
void renameFunction(Module& wasm, Name oldName, Name newName) {
  std::map<Name, Name> renames{{oldName, newName}};
  renameFunctions(wasm, renames);
}
namespace {
// Accumulates how many times each non-basic heap type is used; basic heap
// types are ignored. Tuple-typed control flow structures are tracked
// separately by signature and resolved to heap types later.
struct Counts {
  // Use count per heap type, in first-seen order. An entry can have a zero
  // count if it was only `include`d.
  InsertOrderedMap<HeapType, size_t> counts;
  // Use count per control flow signature whose results are a tuple.
  InsertOrderedMap<Signature, size_t> controlFlowSignatures;

  // Records one use of `type` unless it is basic.
  void note(HeapType type) {
    if (type.isBasic()) {
      return;
    }
    ++counts[type];
  }

  // Records one use of each heap type that appears in `type`.
  void note(Type type) {
    for (HeapType child : type.getHeapTypeChildren()) {
      note(child);
    }
  }

  // Ensures `type` has an entry (default zero) without adding a use.
  void include(HeapType type) {
    if (type.isBasic()) {
      return;
    }
    counts[type];
  }

  // Ensures every heap type appearing in `type` has an entry.
  void include(Type type) {
    for (HeapType child : type.getHeapTypeChildren()) {
      include(child);
    }
  }

  // Records the type of a control flow structure, whose signature must have
  // no params. Tuple results are deferred as signatures; a single result
  // type is noted directly; `none` is ignored.
  void noteControlFlow(Signature sig) {
    assert(sig.params.size() == 0);
    if (sig.results.isTuple()) {
      ++controlFlowSignatures[sig];
      return;
    }
    if (sig.results != Type::none) {
      note(sig.results[0]);
    }
  }
};
// Expression walker that records, into a caller-provided Counts, the heap
// types referenced by the expressions it visits. As a unified visitor, every
// expression is routed through visitExpression below.
struct CodeScanner
  : PostWalker<CodeScanner, UnifiedExpressionVisitor<CodeScanner>> {
  // Destination for all recorded uses (a reference; not owned here).
  Counts& counts;

  CodeScanner(Module& wasm, Counts& counts) : counts(counts) {
    setModule(&wasm);
  }

  // Notes the heap types an expression refers to. This is a single
  // else-if chain, so at most one branch applies per expression.
  void visitExpression(Expression* curr) {
    if (auto* call = curr->dynCast<CallIndirect>()) {
      counts.note(call->heapType);
    } else if (auto* call = curr->dynCast<CallRef>()) {
      // Counts the callee reference's type, not the call's result type.
      counts.note(call->target->type);
    } else if (curr->is<RefNull>()) {
      counts.note(curr->type);
    } else if (curr->is<Select>() && curr->type.isRef()) {
      // Only selects that produce a reference are counted.
      counts.note(curr->type);
    } else if (curr->is<StructNew>()) {
      counts.note(curr->type);
    } else if (curr->is<ArrayNew>()) {
      counts.note(curr->type);
    } else if (curr->is<ArrayNewData>()) {
      counts.note(curr->type);
    } else if (curr->is<ArrayNewElem>()) {
      counts.note(curr->type);
    } else if (curr->is<ArrayNewFixed>()) {
      counts.note(curr->type);
    } else if (auto* copy = curr->dynCast<ArrayCopy>()) {
      counts.note(copy->destRef->type);
      counts.note(copy->srcRef->type);
    } else if (auto* fill = curr->dynCast<ArrayFill>()) {
      counts.note(fill->ref->type);
    } else if (auto* init = curr->dynCast<ArrayInitData>()) {
      counts.note(init->ref->type);
    } else if (auto* init = curr->dynCast<ArrayInitElem>()) {
      counts.note(init->ref->type);
    } else if (auto* cast = curr->dynCast<RefCast>()) {
      counts.note(cast->type);
    } else if (auto* cast = curr->dynCast<RefTest>()) {
      counts.note(cast->castType);
    } else if (auto* cast = curr->dynCast<BrOn>()) {
      // Only the casting br_on variants reference heap types here.
      if (cast->op == BrOnCast || cast->op == BrOnCastFail) {
        counts.note(cast->ref->type);
        counts.note(cast->castType);
      }
    } else if (auto* get = curr->dynCast<StructGet>()) {
      counts.note(get->ref->type);
      // The result type only gets an entry (count unchanged) via include(),
      // rather than being counted as a use.
      counts.include(get->type);
    } else if (auto* set = curr->dynCast<StructSet>()) {
      counts.note(set->ref->type);
    } else if (auto* get = curr->dynCast<ArrayGet>()) {
      counts.note(get->ref->type);
      // As with StructGet: ensure an entry for the result type, no use added.
      counts.include(get->type);
    } else if (auto* set = curr->dynCast<ArraySet>()) {
      counts.note(set->ref->type);
    } else if (auto* contBind = curr->dynCast<ContBind>()) {
      counts.note(contBind->contTypeBefore);
      counts.note(contBind->contTypeAfter);
    } else if (auto* contNew = curr->dynCast<ContNew>()) {
      counts.note(contNew->contType);
    } else if (auto* resume = curr->dynCast<Resume>()) {
      counts.note(resume->contType);
    } else if (Properties::isControlFlowStructure(curr)) {
      // Control flow structures are modeled as taking no params; tuple
      // results are deferred as signatures by Counts::noteControlFlow.
      counts.noteControlFlow(Signature(Type::none, curr->type));
    }
  }
};
// Counts uses of every non-basic heap type in the module. The scan covers
// module-level code, globals, tags, tables, element segments, and each
// function's type, locals and body. It then closes over the types found:
// - children of each type are noted;
// - declared supertypes are included;
// - unless `prune` is set, every member of each visited rec group is
//   included.
// Tuple-typed control flow signatures are resolved to signature heap types,
// reusing an already-seen signature type when one exists.
//
// When `prune` is true, entries whose count is zero after the initial scan
// (i.e. only `include`d) are removed before the closure, and rec groups are
// not expanded.
InsertOrderedMap<HeapType, size_t> getHeapTypeCounts(Module& wasm,
                                                     bool prune = false) {
  Counts counts;
  // Code that lives at module level rather than inside functions.
  CodeScanner(wasm, counts).walkModuleCode(&wasm);
  for (auto& curr : wasm.globals) {
    counts.note(curr->type);
  }
  for (auto& curr : wasm.tags) {
    counts.note(curr->sig);
  }
  for (auto& curr : wasm.tables) {
    counts.note(curr->type);
  }
  for (auto& curr : wasm.elementSegments) {
    counts.note(curr->type);
  }
  // Each function is scanned into its own Counts; results are merged below.
  ModuleUtils::ParallelFunctionAnalysis<Counts, Immutable, InsertOrderedMap>
    analysis(wasm, [&](Function* func, Counts& counts) {
      counts.note(func->type);
      for (auto type : func->vars) {
        counts.note(type);
      }
      if (!func->imported()) {
        CodeScanner(wasm, counts).walk(func->body);
      }
    });
  for (auto& [_, functionCounts] : analysis.map) {
    for (auto& [type, count] : functionCounts.counts) {
      counts.counts[type] += count;
    }
    for (auto& [sig, count] : functionCounts.controlFlowSignatures) {
      counts.controlFlowSignatures[sig] += count;
    }
  }
  if (prune) {
    // Drop entries that were only included, never counted.
    auto it = counts.counts.begin();
    while (it != counts.counts.end()) {
      if (it->second == 0) {
        auto deleted = it++;
        counts.counts.erase(deleted);
      } else {
        ++it;
      }
    }
  }
  // Types whose children/supertype/rec group still need to be visited.
  UniqueNonrepeatingDeferredQueue<HeapType> newTypes;
  // Signature heap types seen so far, so a control flow signature can reuse
  // an existing type instead of creating one.
  std::unordered_map<Signature, HeapType> seenSigs;
  auto noteNewType = [&](HeapType type) {
    newTypes.push(type);
    if (type.isSignature()) {
      seenSigs.insert({type.getSignature(), type});
    }
  };
  for (auto& [type, _] : counts.counts) {
    noteNewType(type);
  }
  // Persists across outer iterations: each control flow signature is
  // handled exactly once.
  auto controlFlowIt = counts.controlFlowSignatures.begin();
  std::unordered_set<RecGroup> includedGroups;
  while (!newTypes.empty()) {
    // Close over everything currently queued.
    while (!newTypes.empty()) {
      auto ht = newTypes.pop();
      for (HeapType child : ht.getHeapTypeChildren()) {
        if (!child.isBasic()) {
          if (!counts.counts.count(child)) {
            noteNewType(child);
          }
          counts.note(child);
        }
      }
      if (auto super = ht.getDeclaredSuperType()) {
        if (!counts.counts.count(*super)) {
          noteNewType(*super);
          counts.include(*super);
        }
      }
      if (!prune) {
        auto recGroup = ht.getRecGroup();
        if (includedGroups.insert(recGroup).second) {
          for (auto type : recGroup) {
            if (!counts.counts.count(type)) {
              noteNewType(type);
              counts.include(type);
            }
          }
        }
      }
    }
    // Resolve control flow signatures. When a new signature type has to be
    // created, break out so its closure is computed (and recorded in
    // seenSigs) before the remaining signatures are resolved.
    for (; controlFlowIt != counts.controlFlowSignatures.end();
         ++controlFlowIt) {
      auto& [sig, count] = *controlFlowIt;
      if (auto it = seenSigs.find(sig); it != seenSigs.end()) {
        counts.counts[it->second] += count;
      } else {
        HeapType type(sig);
        noteNewType(type);
        counts.counts[type] += count;
        break;
      }
    }
  }
  return counts.counts;
}
// Fills `indexedTypes.indices` so that each type maps to its position in
// `indexedTypes.types`.
void setIndices(IndexedHeapTypes& indexedTypes) {
  Index position = 0;
  for (auto type : indexedTypes.types) {
    indexedTypes.indices[type] = position++;
  }
}
// Collects the heap types visible outside the module, in discovery order.
//
// Directly public types come from imported tables, globals and functions
// (except call.without.effects), exported functions, tables and globals, and
// getIgnorablePublicTypes(). Each such type's whole rec group is marked
// public. Finally, every non-basic type transitively referenced from a public
// type is added as well.
InsertOrderedSet<HeapType> getPublicTypeSet(Module& wasm) {
  InsertOrderedSet<HeapType> publicTypes;

  // Marks the rec group of a non-basic type as public. Insertion stops at the
  // first member that is already present: in this phase groups are inserted
  // together, so the rest of that group is already there.
  auto markGroupPublic = [&](HeapType type) {
    if (type.isBasic()) {
      return;
    }
    for (auto member : type.getRecGroup()) {
      if (!publicTypes.insert(member)) {
        break;
      }
    }
  };

  ModuleUtils::iterImportedTables(wasm, [&](Table* table) {
    assert(table->type.isRef());
    markGroupPublic(table->type.getHeapType());
  });
  ModuleUtils::iterImportedGlobals(wasm, [&](Global* global) {
    if (global->type.isRef()) {
      markGroupPublic(global->type.getHeapType());
    }
  });
  Intrinsics intrinsics(wasm);
  ModuleUtils::iterImportedFunctions(wasm, [&](Function* func) {
    // The call.without.effects intrinsic is not treated as a real import
    // here, so its type is not marked public.
    if (!intrinsics.isCallWithoutEffects(func)) {
      markGroupPublic(func->type);
    }
  });

  for (auto& exp : wasm.exports) {
    switch (exp->kind) {
      case ExternalKind::Function:
        markGroupPublic(wasm.getFunction(exp->value)->type);
        continue;
      case ExternalKind::Table: {
        auto* table = wasm.getTable(exp->value);
        assert(table->type.isRef());
        markGroupPublic(table->type.getHeapType());
        continue;
      }
      case ExternalKind::Global: {
        auto* global = wasm.getGlobal(exp->value);
        if (global->type.isRef()) {
          markGroupPublic(global->type.getHeapType());
        }
        continue;
      }
      case ExternalKind::Memory:
      case ExternalKind::Tag:
        // Nothing is marked public for these export kinds.
        continue;
      case ExternalKind::Invalid:
        break;
    }
    WASM_UNREACHABLE("unexpected export kind");
  }

  for (auto type : getIgnorablePublicTypes()) {
    markGroupPublic(type);
  }

  // Transitively add every non-basic type referenced from a public type.
  std::vector<HeapType> pending(publicTypes.begin(), publicTypes.end());
  while (!pending.empty()) {
    HeapType current = pending.back();
    pending.pop_back();
    for (auto referenced : current.getReferencedHeapTypes()) {
      if (referenced.isBasic()) {
        continue;
      }
      if (publicTypes.insert(referenced)) {
        pending.push_back(referenced);
      }
    }
  }
  return publicTypes;
}
}
// Returns every heap type the module uses, in the order getHeapTypeCounts
// first recorded them (no pruning).
std::vector<HeapType> collectHeapTypes(Module& wasm) {
  auto typeCounts = getHeapTypeCounts(wasm);
  std::vector<HeapType> result;
  result.reserve(typeCounts.size());
  for (auto& entry : typeCounts) {
    result.push_back(entry.first);
  }
  return result;
}
// Returns the module's public heap types as a vector, in the order
// getPublicTypeSet discovered them.
std::vector<HeapType> getPublicHeapTypes(Module& wasm) {
  auto publicTypes = getPublicTypeSet(wasm);
  return std::vector<HeapType>(publicTypes.begin(), publicTypes.end());
}
// Returns the heap types found by getHeapTypeCounts (with pruning enabled)
// that are not in the public type set, preserving getHeapTypeCounts' order.
std::vector<HeapType> getPrivateHeapTypes(Module& wasm) {
  auto usedTypes = getHeapTypeCounts(wasm, true);
  auto publicTypes = getPublicTypeSet(wasm);
  std::vector<HeapType> privateTypes;
  for (auto& entry : usedTypes) {
    const HeapType& type = entry.first;
    if (publicTypes.count(type) == 0) {
      privateTypes.push_back(type);
    }
  }
  return privateTypes;
}
// Orders the module's heap types for indexing, one rec group at a time.
// Groups are topologically sorted so that groups referenced by a group's
// members are its predecessors. Candidate groups and predecessor lists are
// presented to the sort ordered by average use count per member, with ties
// broken by first-seen order. Presumably this gives heavily used groups
// smaller indices; that depends on TopologicalSort's worklist order (confirm
// in support/topological_sort.h).
IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
  auto counts = getHeapTypeCounts(wasm);
  // Per-rec-group bookkeeping.
  struct GroupInfo {
    // Order in which the group was first seen in `counts`.
    size_t index;
    // Total uses of the group's types, later divided by the group size.
    double useCount = 0;
    // Other groups referenced by this group's types.
    std::unordered_set<RecGroup> preds;
    // `preds` in deterministic sorted order (filled from `preds`).
    std::vector<RecGroup> sortedPreds;
    GroupInfo(size_t index) : index(index) {}
    // "Less" means fewer average uses; on ties, the later-seen group is less.
    bool operator<(const GroupInfo& other) const {
      if (useCount != other.useCount) {
        return useCount < other.useCount;
      }
      return index > other.index;
    }
  };
  struct GroupInfoMap : std::unordered_map<RecGroup, GroupInfo> {
    // Sorts `groups` ascending by their GroupInfo ordering.
    void sort(std::vector<RecGroup>& groups) {
      std::sort(groups.begin(), groups.end(), [&](auto& a, auto& b) {
        return this->at(a) < this->at(b);
      });
    }
  };
  GroupInfoMap groupInfos;
  for (auto& [type, _] : counts) {
    RecGroup group = type.getRecGroup();
    // New groups get the current map size as their first-seen index.
    auto& info = groupInfos.insert({group, {groupInfos.size()}}).first->second;
    info.useCount += counts.at(type);
    for (auto child : type.getReferencedHeapTypes()) {
      if (!child.isBasic()) {
        RecGroup otherGroup = child.getRecGroup();
        if (otherGroup != group) {
          info.preds.insert(otherGroup);
        }
      }
    }
  }
  // Convert totals into average uses per group member.
  for (auto& [group, info] : groupInfos) {
    info.useCount /= group.size();
  }
  // Replace the unordered predecessor sets with sorted vectors so the
  // topological sort below is deterministic.
  for (auto& [group, info] : groupInfos) {
    info.sortedPreds.insert(
      info.sortedPreds.end(), info.preds.begin(), info.preds.end());
    groupInfos.sort(info.sortedPreds);
    info.preds.clear();
  }
  // Topological sort over rec groups; predecessors are the groups a group's
  // types reference.
  struct RecGroupSort : TopologicalSort<RecGroup, RecGroupSort> {
    GroupInfoMap& groupInfos;
    RecGroupSort(GroupInfoMap& groupInfos) : groupInfos(groupInfos) {
      // Seed the sort with every group, in sorted order.
      std::vector<RecGroup> sortedGroups;
      sortedGroups.reserve(groupInfos.size());
      for (auto& [group, _] : groupInfos) {
        sortedGroups.push_back(group);
      }
      groupInfos.sort(sortedGroups);
      for (auto group : sortedGroups) {
        push(group);
      }
    }
    void pushPredecessors(RecGroup group) {
      for (auto pred : groupInfos.at(group).sortedPreds) {
        push(pred);
      }
    }
  };
  // Flatten the sorted groups into the final type order and index it.
  IndexedHeapTypes indexedTypes;
  indexedTypes.types.reserve(counts.size());
  for (auto group : RecGroupSort(groupInfos)) {
    for (auto member : group) {
      indexedTypes.types.push_back(member);
    }
  }
  setIndices(indexedTypes);
  return indexedTypes;
}
}