#include "asmjs/shared-constants.h"
#include "ir/element-utils.h"
#include "ir/import-utils.h"
#include "ir/literal-utils.h"
#include "ir/module-splitting.h"
#include "ir/names.h"
#include "ir/utils.h"
#include "pass.h"
#include "shared-constants.h"
#include "support/file.h"
#include "support/string.h"
#include "wasm-builder.h"
#include "wasm.h"
#include <utility>
namespace wasm {
// Joins an import's module and base names as "module.base", the key that
// imported functions are matched against in the jspi-imports list.
static std::string getFullFunctionName(Name module, Name base) {
  std::string fullName(module.str);
  fullName.push_back('.');
  fullName.append(base.toString());
  return fullName;
}
static bool canChangeState(std::string name, String::Split stateChangers) {
if (stateChangers.empty()) {
return true;
}
for (auto& stateChanger : stateChangers) {
if (String::wildcardMatch(stateChanger, name)) {
return true;
}
}
return false;
}
struct JSPI : public Pass {
Type externref = Type(HeapType::ext, Nullable);
void run(Module* module) override {
Builder builder(*module);
auto& options = getPassOptions();
auto stateChangingImports = String::trim(read_possible_response_file(
options.getArgumentOrDefault("jspi-imports", "")));
String::Split listedImports(stateChangingImports, ",");
auto stateChangingExports = String::trim(read_possible_response_file(
options.getArgumentOrDefault("jspi-exports", "")));
String::Split listedExports(stateChangingExports, ",");
bool wasmSplit = options.hasArgument("jspi-split-module");
if (wasmSplit) {
auto import =
Builder::makeFunction(ModuleSplitting::LOAD_SECONDARY_MODULE,
Signature(Type::none, Type::none),
{});
import->module = ENV;
import->base = ModuleSplitting::LOAD_SECONDARY_MODULE;
module->addFunction(std::move(import));
listedImports.push_back(
ENV.toString() + "." +
ModuleSplitting::LOAD_SECONDARY_MODULE.toString());
}
Name suspender = Names::getValidGlobalName(*module, "suspender");
module->addGlobal(builder.makeGlobal(suspender,
externref,
builder.makeRefNull(HeapType::noext),
Builder::Mutable));
std::unordered_map<Name, Name> wrappedExports;
for (auto& ex : module->exports) {
if (ex->kind == ExternalKind::Function &&
canChangeState(ex->name.toString(), listedExports)) {
auto* func = module->getFunction(ex->value);
Name wrapperName;
auto iter = wrappedExports.find(func->name);
if (iter == wrappedExports.end()) {
wrapperName = makeWrapperForExport(func, module, suspender);
wrappedExports[func->name] = wrapperName;
} else {
wrapperName = iter->second;
}
ex->value = wrapperName;
}
}
for (auto& segment : module->elementSegments) {
if (!segment->type.isFunction()) {
continue;
}
for (Index i = 0; i < segment->data.size(); i++) {
if (auto* get = segment->data[i]->dynCast<RefFunc>()) {
auto iter = wrappedExports.find(get->func);
if (iter == wrappedExports.end()) {
continue;
}
auto* replacementRef = builder.makeRefFunc(
iter->second, module->getFunction(iter->second)->type);
segment->data[i] = replacementRef;
}
}
}
std::vector<Function*> originalFunctions;
for (auto& func : module->functions) {
originalFunctions.push_back(func.get());
}
for (auto* im : originalFunctions) {
if (im->imported() &&
canChangeState(getFullFunctionName(im->module, im->base),
listedImports)) {
makeWrapperForImport(im, module, suspender, wasmSplit);
}
}
}
private:
Name makeWrapperForExport(Function* func, Module* module, Name suspender) {
Name wrapperName = Names::getValidFunctionName(
*module, std::string("export$") + func->name.toString());
Builder builder(*module);
auto* call = module->allocator.alloc<Call>();
call->target = func->name;
call->type = func->getResults();
std::vector<Type> wrapperParams;
std::vector<NameType> namedWrapperParams;
wrapperParams.push_back(externref);
namedWrapperParams.emplace_back(Names::getValidLocalName(*func, "susp"),
externref);
Index i = 0;
for (const auto& param : func->getParams()) {
call->operands.push_back(
builder.makeLocalGet(wrapperParams.size(), param));
wrapperParams.push_back(param);
namedWrapperParams.emplace_back(func->getLocalNameOrGeneric(i), param);
i++;
}
auto* block = builder.makeBlock();
block->list.push_back(
builder.makeGlobalSet(suspender, builder.makeLocalGet(0, externref)));
block->list.push_back(call);
Type resultsType = func->getResults();
if (resultsType == Type::none) {
resultsType = Type::i32;
block->list.push_back(builder.makeConst(0));
}
block->finalize();
auto wrapperFunc =
Builder::makeFunction(wrapperName,
std::move(namedWrapperParams),
Signature(Type(wrapperParams), resultsType),
{},
block);
return module->addFunction(std::move(wrapperFunc))->name;
}
void makeWrapperForImport(Function* im,
Module* module,
Name suspender,
bool wasmSplit) {
Builder builder(*module);
auto wrapperIm = std::make_unique<Function>();
wrapperIm->name = Names::getValidFunctionName(
*module, std::string("import$") + im->name.toString());
wrapperIm->module = im->module;
wrapperIm->base = im->base;
auto stub = std::make_unique<Function>();
stub->name = Name(im->name.str);
stub->type = im->type;
auto* call = module->allocator.alloc<Call>();
call->target = wrapperIm->name;
std::vector<Type> params;
params.push_back(externref);
call->operands.push_back(builder.makeGlobalGet(suspender, externref));
Index i = 0;
for (const auto& param : im->getParams()) {
call->operands.push_back(builder.makeLocalGet(i, param));
params.push_back(param);
++i;
}
auto* block = builder.makeBlock();
auto supsenderCopyIndex = Builder::addVar(stub.get(), externref);
std::optional<Index> returnIndex;
if (stub->getResults().isConcrete()) {
returnIndex = Builder::addVar(stub.get(), stub->getResults());
}
block->list.push_back(builder.makeLocalSet(
supsenderCopyIndex, builder.makeGlobalGet(suspender, externref)));
if (returnIndex) {
block->list.push_back(builder.makeLocalSet(*returnIndex, call));
} else {
block->list.push_back(call);
}
block->list.push_back(builder.makeGlobalSet(
suspender, builder.makeLocalGet(supsenderCopyIndex, externref)));
if (returnIndex) {
block->list.push_back(
builder.makeLocalGet(*returnIndex, stub->getResults()));
}
block->finalize();
call->type = im->getResults();
stub->body = block;
wrapperIm->type = Signature(Type(params), call->type);
if (wasmSplit && im->name == ModuleSplitting::LOAD_SECONDARY_MODULE) {
module->addExport(
builder.makeExport(ModuleSplitting::LOAD_SECONDARY_MODULE,
ModuleSplitting::LOAD_SECONDARY_MODULE,
ExternalKind::Function));
}
module->removeFunction(im->name);
module->addFunction(std::move(stub));
module->addFunction(std::move(wrapperIm));
}
};
// Returns a newly heap-allocated JSPI pass. The returned raw pointer is owned
// by the caller.
Pass* createJSPIPass() {
  auto* pass = new JSPI();
  return pass;
}
}