#include "ir/module-splitting.h"
#include "ir/element-utils.h"
#include "ir/export-utils.h"
#include "ir/manipulation.h"
#include "ir/module-utils.h"
#include "ir/names.h"
#include "pass.h"
#include "wasm-builder.h"
#include "wasm.h"
namespace wasm::ModuleSplitting {
namespace {
// Name of the mutable i32 global that setupJSPI() adds to (and exports from)
// the primary module. maybeLoadSecondary() compares it against zero before
// calling the internal load-secondary-module function.
static const Name LOAD_SECONDARY_STATUS("load_secondary_module_status");
// Invokes `f(tableName, baseGlobal, tableIndex, funcName)` for every ref.func
// entry of every active element segment in `module`.
//  - If the segment offset is an i32.const, `baseGlobal` is empty and
//    `tableIndex` is the constant plus the entry's position in the segment.
//  - If the offset is a global.get, `baseGlobal` names that global and
//    `tableIndex` is the entry's position relative to it.
//  - Any other offset form is treated as base 0 with no global.
// `funcName` refers to the name stored inside the segment's RefFunc, so `f`
// may rewrite it in place.
template<class F> void forEachElement(Module& module, F f) {
  ModuleUtils::iterActiveElementSegments(module, [&](ElementSegment* seg) {
    Name baseGlobal = "";
    Index segStart = 0;
    Expression* init = seg->offset;
    if (auto* constant = init->dynCast<Const>()) {
      segStart = constant->value.geti32();
    } else if (auto* get = init->dynCast<GlobalGet>()) {
      baseGlobal = get->name;
    }
    auto visitEntry = [&](Name& funcName, Index pos) {
      f(seg->table, baseGlobal, segStart + pos, funcName);
    };
    ElementUtils::iterElementSegmentFunctionNames(seg, visitEntry);
  });
}
// Assigns table slots to functions so that the primary module can reach
// secondary functions through call_indirect. Slots are taken from the first
// table of type (ref null func) when one exists; otherwise getSlot() creates
// a table on demand.
struct TableSlotManager {
// A table position. If `global` is non-empty, the index is `index` added to
// the value of that global; otherwise `index` is absolute.
struct Slot {
Name tableName;
Name global;
Index index = 0;
// Returns an i32 expression that evaluates to this slot's table index.
Expression* makeExpr(Module& module);
};
Module& module;
// Table that new slots are allocated in; null until found or created.
Table* activeTable = nullptr;
// Segment that new ref.func entries are appended to; null until one is
// chosen or created.
ElementSegment* activeSegment = nullptr;
// Slot corresponding to position 0 of activeSegment (getSlot() adds the
// segment's current size to it).
Slot activeBase;
// Slot of each function known so far. addSlot() never overwrites an entry,
// so the first recorded slot for a function wins.
std::map<Name, Slot> funcIndices;
// Element segments that targeted activeTable when the manager was built.
std::vector<ElementSegment*> activeTableSegments;
TableSlotManager(Module& module);
Table* makeTable();
ElementSegment* makeElementSegment();
// Returns func's slot, appending a ref.func of `type` to activeSegment
// (creating table/segment as needed) when func has no slot yet.
Slot getSlot(Name func, HeapType type);
void addSlot(Name func, Slot slot);
};
// Builds the i32 expression evaluating to this slot's table index:
//  - `i32.const index` when there is no base global;
//  - `global.get $global` when the index is zero;
//  - `global.get $global + i32.const index` otherwise.
Expression* TableSlotManager::Slot::makeExpr(Module& module) {
  Builder builder(module);
  if (!global.size()) {
    return builder.makeConst(int32_t(index));
  }
  Expression* base = builder.makeGlobalGet(global, Type::i32);
  if (index == 0) {
    return base;
  }
  return builder.makeBinary(AddInt32, base, builder.makeConst(int32_t(index)));
}
// Records `slot` as the table slot of `func`. If `func` already has a slot,
// the existing mapping is kept and `slot` is discarded.
void TableSlotManager::addSlot(Name func, Slot slot) {
  funcIndices.emplace(func, slot);
}
// Picks the first funcref table (if any) as the active table, chooses the
// element segment that new entries will be appended to, and seeds
// funcIndices with every function already referenced by an active segment.
TableSlotManager::TableSlotManager(Module& module) : module(module) {
auto funcref = Type(HeapType::func, Nullable);
auto it = std::find_if(
module.tables.begin(),
module.tables.end(),
[&](std::unique_ptr<Table>& table) { return table->type == funcref; });
// No funcref table: leave activeTable null; getSlot() will create one.
if (it == module.tables.end()) {
return;
}
activeTable = it->get();
ModuleUtils::iterTableSegments(
module, activeTable->name, [&](ElementSegment* segment) {
activeTableSegments.push_back(segment);
});
// Exactly one funcref segment with a non-constant offset (asserted to be a
// global.get): append to it, with slots relative to that global.
if (activeTableSegments.size() == 1 &&
activeTableSegments[0]->type == funcref &&
!activeTableSegments[0]->offset->is<Const>()) {
assert(activeTableSegments[0]->offset->is<GlobalGet>() &&
"Unexpected initializer instruction");
activeSegment = activeTableSegments[0];
activeBase = {activeTable->name,
activeTableSegments[0]->offset->cast<GlobalGet>()->name,
0};
} else {
// Otherwise every segment must have a constant offset. Choose the one whose
// end (base + size) is greatest, so appended entries land past all existing
// constant-offset entries. Ties go to the later segment because of `>=`.
// NOTE(review): if the table has no segments at all, activeSegment stays
// null and activeBase keeps its default (empty) table name.
Index maxIndex = 0;
for (auto& segment : activeTableSegments) {
assert(segment->offset->is<Const>() &&
"Unexpected non-const segment offset with multiple segments");
Index segmentBase = segment->offset->cast<Const>()->value.geti32();
if (segmentBase + segment->data.size() >= maxIndex) {
maxIndex = segmentBase + segment->data.size();
activeSegment = segment;
activeBase = {activeTable->name, "", segmentBase};
}
}
}
// Record existing slots. forEachElement visits segments of every table, and
// each slot is tagged with the table its segment belongs to.
forEachElement(module, [&](Name table, Name base, Index offset, Name func) {
addSlot(func, {table, base, offset});
});
}
// Creates and adds a new table to the module, named with the first free name
// derived from "0". All other table properties are Builder::makeTable's
// defaults.
Table* TableSlotManager::makeTable() {
  Name tableName = Names::getValidTableName(module, Name::fromInt(0));
  auto table = Builder::makeTable(tableName);
  return module.addTable(std::move(table));
}
// Creates and adds a new, empty active element segment for activeTable at
// offset `i32.const 0`, named with the first free name derived from "0".
ElementSegment* TableSlotManager::makeElementSegment() {
  Name segmentName =
    Names::getValidElementSegmentName(module, Name::fromInt(0));
  Expression* offset = Builder(module).makeConst(int32_t(0));
  return module.addElementSegment(
    Builder::makeElementSegment(segmentName, activeTable->name, offset));
}
// Returns the table slot holding `func`, allocating one if necessary.
//
// If `func` already has a slot (pre-existing or previously allocated), that
// slot is returned unchanged. Otherwise a `ref.func func` of `type` is
// appended to the active segment, creating the table and/or a segment at
// constant offset 0 when none is available. The table's `initial` and `max`
// sizes are grown to cover the new slot. When the module has a dylink
// section, its tableSize is kept in sync with the new `initial`.
TableSlotManager::Slot TableSlotManager::getSlot(Name func, HeapType type) {
  auto slotIt = funcIndices.find(func);
  if (slotIt != funcIndices.end()) {
    return slotIt->second;
  }
  if (activeSegment == nullptr) {
    if (activeTable == nullptr) {
      activeTable = makeTable();
    }
    // No existing segment may target the active table, otherwise we could
    // clobber its entries by starting a new segment at offset 0.
    assert(std::all_of(module.elementSegments.begin(),
                       module.elementSegments.end(),
                       [&](std::unique_ptr<ElementSegment>& segment) {
                         return segment->table != activeTable->name;
                       }));
    activeSegment = makeElementSegment();
    // The new segment sits at constant offset 0 of the active table, so the
    // base must describe exactly that. Setting it here (rather than only when
    // the table is freshly created) also covers a pre-existing funcref table
    // with no segments, for which the constructor left activeBase with an
    // empty table name. Previously that produced call_indirects referencing
    // a table named "".
    activeBase = {activeTable->name, "", 0};
  }
  Slot newSlot = {activeBase.tableName,
                  activeBase.global,
                  activeBase.index + Index(activeSegment->data.size())};
  Builder builder(module);
  activeSegment->data.push_back(builder.makeRefFunc(func, type));
  addSlot(func, newSlot);
  if (activeTable->initial <= newSlot.index) {
    activeTable->initial = newSlot.index + 1;
    if (module.dylinkSection) {
      module.dylinkSection->tableSize = activeTable->initial;
    }
  }
  if (activeTable->max <= newSlot.index) {
    activeTable->max = newSlot.index + 1;
  }
  return newSlot;
}
// Performs the split of `primary` into itself plus a new secondary module.
// All work happens in the constructor, whose body runs the phases below in
// order. NOTE: the constructor's initializer list relies on the member
// declaration order below (`secondary` is derived from `secondaryPtr`, and
// `primaryFuncs`/`secondaryFuncs` reference `classifiedFuncs`), so do not
// reorder members.
struct ModuleSplitter {
const Config& config;
// Owns the secondary module until splitFunctions() moves it out.
std::unique_ptr<Module> secondaryPtr;
Module& primary;
Module& secondary;
// (functions kept in primary, functions moved to secondary).
const std::pair<std::set<Name>, std::set<Name>> classifiedFuncs;
const std::set<Name>& primaryFuncs;
const std::set<Name>& secondaryFuncs;
TableSlotManager tableManager;
// Source of short names when config.minimizeNewExportNames is set.
Names::MinifiedNameGenerator minified;
// Primary function name -> the name it is exported under.
std::map<Name, Name> exportedPrimaryFuncs;
// Table index -> secondary function whose placeholder occupies that index.
std::map<size_t, Name> placeholderMap;
// In JSPI mode, the function that loads the secondary module.
Name internalLoadSecondaryModule;
static std::unique_ptr<Module> initSecondary(const Module& primary);
static std::pair<std::set<Name>, std::set<Name>>
classifyFunctions(const Module& primary, const Config& config);
static std::map<Name, Name> initExportedPrimaryFuncs(const Module& primary);
void exportImportFunction(Name func);
Expression* maybeLoadSecondary(Builder& builder, Expression* callIndirect);
void setupJSPI();
void moveSecondaryFunctions();
void thunkExportedSecondaryFunctions();
void indirectCallsToSecondaryFunctions();
void exportImportCalledPrimaryFunctions();
void setupTablePatching();
void shareImportableItems();
ModuleSplitter(Module& primary, const Config& config)
: config(config), secondaryPtr(initSecondary(primary)), primary(primary),
secondary(*secondaryPtr),
classifiedFuncs(classifyFunctions(primary, config)),
primaryFuncs(classifiedFuncs.first),
secondaryFuncs(classifiedFuncs.second), tableManager(primary),
exportedPrimaryFuncs(initExportedPrimaryFuncs(primary)) {
if (config.jspi) {
setupJSPI();
}
moveSecondaryFunctions();
thunkExportedSecondaryFunctions();
indirectCallsToSecondaryFunctions();
exportImportCalledPrimaryFunctions();
setupTablePatching();
shareImportableItems();
}
};
void ModuleSplitter::setupJSPI() {
assert(primary.getExportOrNull(LOAD_SECONDARY_MODULE) &&
"The load secondary module function must exist");
internalLoadSecondaryModule = primary.getExport(LOAD_SECONDARY_MODULE)->value;
primary.removeExport(LOAD_SECONDARY_MODULE);
Builder builder(primary);
primary.addGlobal(builder.makeGlobal(LOAD_SECONDARY_STATUS,
Type::i32,
builder.makeConst(int32_t(0)),
Builder::Mutable));
primary.addExport(builder.makeExport(
LOAD_SECONDARY_STATUS, LOAD_SECONDARY_STATUS, ExternalKind::Global));
}
std::unique_ptr<Module> ModuleSplitter::initSecondary(const Module& primary) {
auto secondary = std::make_unique<Module>();
secondary->features = primary.features;
secondary->hasFeaturesSection = primary.hasFeaturesSection;
return secondary;
}
std::pair<std::set<Name>, std::set<Name>>
ModuleSplitter::classifyFunctions(const Module& primary, const Config& config) {
std::set<Name> primaryFuncs, secondaryFuncs;
for (auto& func : primary.functions) {
if (func->imported() || config.primaryFuncs.count(func->name) ||
(config.jspi && ExportUtils::isExported(primary, *func))) {
primaryFuncs.insert(func->name);
} else {
assert(func->name != primary.start && "The start function must be kept");
secondaryFuncs.insert(func->name);
}
}
return std::make_pair(primaryFuncs, secondaryFuncs);
}
std::map<Name, Name>
ModuleSplitter::initExportedPrimaryFuncs(const Module& primary) {
std::map<Name, Name> functionExportNames;
for (auto& ex : primary.exports) {
if (ex->kind == ExternalKind::Function) {
functionExportNames[ex->value] = ex->name;
}
}
return functionExportNames;
}
// Makes primary function `funcName` callable from the secondary module.
//
// Reuses an existing export of the function if one is known. Otherwise it
// adds a new export named either with a minified name or
// `newExportPrefix + funcName`, in both cases prefixed by
// config.newExportPrefix. It then adds a matching import
// (config.importNamespace.<exportName>) to the secondary module unless the
// secondary module already has a function of that name. The import keeps the
// primary function's hasExplicitName flag, matching how shareImportableItems()
// mirrors other items.
void ModuleSplitter::exportImportFunction(Name funcName) {
  Name exportName;
  auto exportIt = exportedPrimaryFuncs.find(funcName);
  if (exportIt != exportedPrimaryFuncs.end()) {
    exportName = exportIt->second;
  } else {
    if (config.minimizeNewExportNames) {
      // Keep drawing minified names until one is not already an export.
      do {
        exportName = config.newExportPrefix + minified.getName();
      } while (primary.getExportOrNull(exportName) != nullptr);
    } else {
      exportName = Names::getValidExportName(
        primary, config.newExportPrefix + funcName.toString());
    }
    primary.addExport(
      Builder::makeExport(exportName, funcName, ExternalKind::Function));
    exportedPrimaryFuncs[funcName] = exportName;
  }
  if (secondary.getFunctionOrNull(funcName) == nullptr) {
    Function* primaryFunc = primary.getFunction(funcName);
    auto func = Builder::makeFunction(funcName, primaryFunc->type, {});
    func->hasExplicitName = primaryFunc->hasExplicitName;
    func->module = config.importNamespace;
    func->base = exportName;
    secondary.addFunction(std::move(func));
  }
}
void ModuleSplitter::moveSecondaryFunctions() {
for (auto funcName : secondaryFuncs) {
auto* func = primary.getFunction(funcName);
ModuleUtils::copyFunction(func, secondary);
primary.removeFunction(funcName);
}
}
void ModuleSplitter::thunkExportedSecondaryFunctions() {
Builder builder(primary);
for (auto& ex : primary.exports) {
if (ex->kind != ExternalKind::Function ||
!secondaryFuncs.count(ex->value)) {
continue;
}
Name secondaryFunc = ex->value;
if (primary.getFunctionOrNull(secondaryFunc)) {
continue;
}
auto* func = primary.addFunction(Builder::makeFunction(
secondaryFunc, secondary.getFunction(secondaryFunc)->type, {}));
std::vector<Expression*> args;
Type params = func->getParams();
for (size_t i = 0, size = params.size(); i < size; ++i) {
args.push_back(builder.makeLocalGet(i, params[i]));
}
auto tableSlot = tableManager.getSlot(secondaryFunc, func->type);
func->body = builder.makeCallIndirect(
tableSlot.tableName, tableSlot.makeExpr(primary), args, func->type);
}
}
// Without JSPI, returns `callIndirect` unchanged. With JSPI, returns
//   (if (i32.eqz (global.get LOAD_SECONDARY_STATUS))
//     (call internalLoadSecondaryModule))
//   callIndirect
// as a sequence, so the secondary module is loaded before the call when the
// status global is still zero.
Expression* ModuleSplitter::maybeLoadSecondary(Builder& builder,
                                               Expression* callIndirect) {
  if (config.jspi) {
    auto* status = builder.makeGlobalGet(LOAD_SECONDARY_STATUS, Type::i32);
    auto* notLoaded = builder.makeUnary(EqZInt32, status);
    auto* load = builder.makeCall(internalLoadSecondaryModule, {}, Type::none);
    return builder.makeSequence(builder.makeIf(notLoaded, load), callIndirect);
  }
  return callIndirect;
}
// Rewrites every direct call in the primary module whose target moved to the
// secondary module into a call_indirect (or return_call_indirect for tail
// calls) through that function's table slot. In JSPI mode, each such call is
// wrapped by maybeLoadSecondary().
void ModuleSplitter::indirectCallsToSecondaryFunctions() {
  struct CallIndirector : public WalkerPass<PostWalker<CallIndirector>> {
    ModuleSplitter& parent;
    Builder builder;
    CallIndirector(ModuleSplitter& parent)
      : parent(parent), builder(parent.primary) {}
    // Element segment contents are not walked, so their ref.func entries do
    // not reach visitRefFunc below. Table references to secondary functions
    // are handled separately by setupTablePatching().
    void walkElementSegment(ElementSegment* segment) {}
    void visitCall(Call* curr) {
      Name target = curr->target;
      if (parent.secondaryFuncs.count(target) == 0) {
        return;
      }
      HeapType calleeType = parent.secondary.getFunction(target)->type;
      auto slot = parent.tableManager.getSlot(target, calleeType);
      Expression* indirect =
        builder.makeCallIndirect(slot.tableName,
                                 slot.makeExpr(parent.primary),
                                 curr->operands,
                                 calleeType,
                                 curr->isReturn);
      replaceCurrent(parent.maybeLoadSecondary(builder, indirect));
    }
    void visitRefFunc(RefFunc* curr) {
      assert(false && "TODO: handle ref.func as well");
    }
  };
  PassRunner runner(&primary);
  CallIndirector(*this).run(&runner, &primary);
}
void ModuleSplitter::exportImportCalledPrimaryFunctions() {
ModuleUtils::ParallelFunctionAnalysis<std::vector<Name>> callCollector(
secondary, [&](Function* func, std::vector<Name>& calledPrimaryFuncs) {
struct CallCollector : PostWalker<CallCollector> {
const std::set<Name>& primaryFuncs;
std::vector<Name>& calledPrimaryFuncs;
CallCollector(const std::set<Name>& primaryFuncs,
std::vector<Name>& calledPrimaryFuncs)
: primaryFuncs(primaryFuncs), calledPrimaryFuncs(calledPrimaryFuncs) {
}
void visitCall(Call* curr) {
if (primaryFuncs.count(curr->target)) {
calledPrimaryFuncs.push_back(curr->target);
}
}
void visitRefFunc(RefFunc* curr) {
assert(false && "TODO: handle ref.func as well");
}
};
CallCollector(primaryFuncs, calledPrimaryFuncs).walkFunction(func);
});
std::set<Name> calledPrimaryFuncs;
for (auto& entry : callCollector.map) {
auto& calledFuncs = entry.second;
calledPrimaryFuncs.insert(calledFuncs.begin(), calledFuncs.end());
}
for (auto func : calledPrimaryFuncs) {
exportImportFunction(func);
}
}
void ModuleSplitter::setupTablePatching() {
if (!tableManager.activeTable) {
return;
}
std::map<Index, Function*> replacedElems;
forEachElement(primary, [&](Name, Name, Index index, Name& elem) {
if (secondaryFuncs.count(elem)) {
placeholderMap[index] = elem;
auto* secondaryFunc = secondary.getFunction(elem);
replacedElems[index] = secondaryFunc;
auto placeholder = std::make_unique<Function>();
placeholder->module = config.placeholderNamespace;
placeholder->base = std::to_string(index);
placeholder->name = Names::getValidFunctionName(
primary, std::string("placeholder_") + placeholder->base.toString());
placeholder->hasExplicitName = false;
placeholder->type = secondaryFunc->type;
elem = placeholder->name;
primary.addFunction(std::move(placeholder));
}
});
if (replacedElems.size() == 0) {
return;
}
auto secondaryTable =
ModuleUtils::copyTable(tableManager.activeTable, secondary);
if (tableManager.activeBase.global.size()) {
assert(tableManager.activeTableSegments.size() == 1 &&
"Unexpected number of segments with non-const base");
assert(secondary.tables.size() == 1 && secondary.elementSegments.empty());
ElementSegment* primarySeg = tableManager.activeTableSegments.front();
std::vector<Expression*> secondaryElems;
secondaryElems.reserve(primarySeg->data.size());
auto replacement = replacedElems.begin();
for (Index i = 0;
i < primarySeg->data.size() && replacement != replacedElems.end();
++i) {
if (replacement->first == i) {
auto* func = replacement->second;
auto* ref = Builder(secondary).makeRefFunc(func->name, func->type);
secondaryElems.push_back(ref);
++replacement;
} else if (auto* get = primarySeg->data[i]->dynCast<RefFunc>()) {
exportImportFunction(get->func);
auto* copied =
ExpressionManipulator::copy(primarySeg->data[i], secondary);
secondaryElems.push_back(copied);
}
}
auto offset = ExpressionManipulator::copy(primarySeg->offset, secondary);
auto secondarySeg = std::make_unique<ElementSegment>(
secondaryTable->name, offset, secondaryTable->type, secondaryElems);
secondarySeg->setName(primarySeg->name, primarySeg->hasExplicitName);
secondary.addElementSegment(std::move(secondarySeg));
return;
}
Index currBase = replacedElems.begin()->first;
std::vector<Expression*> currData;
auto finishSegment = [&]() {
auto* offset = Builder(secondary).makeConst(int32_t(currBase));
auto secondarySeg = std::make_unique<ElementSegment>(
secondaryTable->name, offset, secondaryTable->type, currData);
Name name = Names::getValidElementSegmentName(
secondary, Name::fromInt(secondary.elementSegments.size()));
secondarySeg->setName(name, false);
secondary.addElementSegment(std::move(secondarySeg));
};
for (auto curr = replacedElems.begin(); curr != replacedElems.end(); ++curr) {
if (curr->first != currBase + currData.size()) {
finishSegment();
currBase = curr->first;
currData.clear();
}
auto* func = curr->second;
currData.push_back(Builder(secondary).makeRefFunc(func->name, func->type));
}
if (currData.size()) {
finishSegment();
}
}
void ModuleSplitter::shareImportableItems() {
std::unordered_map<std::pair<ExternalKind, Name>, Name> exports;
for (auto& ex : primary.exports) {
if (ex->kind != ExternalKind::Function) {
exports[std::make_pair(ex->kind, ex->value)] = ex->name;
}
}
auto makeImportExport = [&](Importable& primaryItem,
Importable& secondaryItem,
const std::string& genericExportName,
ExternalKind kind) {
secondaryItem.name = primaryItem.name;
secondaryItem.hasExplicitName = primaryItem.hasExplicitName;
secondaryItem.module = config.importNamespace;
auto exportIt = exports.find(std::make_pair(kind, primaryItem.name));
if (exportIt != exports.end()) {
secondaryItem.base = exportIt->second;
} else {
Name exportName = Names::getValidExportName(
primary, config.newExportPrefix + genericExportName);
primary.addExport(new Export{exportName, primaryItem.name, kind});
secondaryItem.base = exportName;
}
};
for (auto& memory : primary.memories) {
auto secondaryMemory = ModuleUtils::copyMemory(memory.get(), secondary);
makeImportExport(*memory, *secondaryMemory, "memory", ExternalKind::Memory);
}
for (auto& table : primary.tables) {
auto secondaryTable = secondary.getTableOrNull(table->name);
if (!secondaryTable) {
secondaryTable = ModuleUtils::copyTable(table.get(), secondary);
}
makeImportExport(*table, *secondaryTable, "table", ExternalKind::Table);
}
for (auto& global : primary.globals) {
if (global->mutable_) {
assert(primary.features.hasMutableGlobals() &&
"TODO: add wrapper functions for disallowed mutable globals");
}
auto secondaryGlobal = std::make_unique<Global>();
secondaryGlobal->type = global->type;
secondaryGlobal->mutable_ = global->mutable_;
secondaryGlobal->init =
global->init == nullptr
? nullptr
: ExpressionManipulator::copy(global->init, secondary);
makeImportExport(*global, *secondaryGlobal, "global", ExternalKind::Global);
secondary.addGlobal(std::move(secondaryGlobal));
}
for (auto& tag : primary.tags) {
auto secondaryTag = std::make_unique<Tag>();
secondaryTag->sig = tag->sig;
makeImportExport(*tag, *secondaryTag, "tag", ExternalKind::Tag);
secondary.addTag(std::move(secondaryTag));
}
}
}
// Entry point: splits `primary` in place according to `config`. Returns the
// new secondary module together with placeholderMap, which maps each table
// index to the secondary function whose placeholder occupies that index.
Results splitFunctions(Module& primary, const Config& config) {
  ModuleSplitter splitter(primary, config);
  auto secondary = std::move(splitter.secondaryPtr);
  auto placeholders = std::move(splitter.placeholderMap);
  return {std::move(secondary), std::move(placeholders)};
}
}