#include "asmjs/shared-constants.h"
#include "cfg/liveness-traversal.h"
#include "ir/effects.h"
#include "ir/find_all.h"
#include "ir/linear-execution.h"
#include "ir/literal-utils.h"
#include "ir/memory-utils.h"
#include "ir/module-utils.h"
#include "ir/names.h"
#include "ir/utils.h"
#include "pass.h"
#include "passes/pass-utils.h"
#include "support/file.h"
#include "support/string.h"
#include "wasm-builder.h"
#include "wasm.h"
namespace wasm {
namespace {
// Internal names of the globals and runtime functions that asyncify adds to
// the module. The runtime functions (and asyncify_get_state) are exported
// under these same names by Asyncify::addFunctions.
const Name ASYNCIFY_STATE = "__asyncify_state";
const Name ASYNCIFY_GET_STATE = "asyncify_get_state";
const Name ASYNCIFY_DATA = "__asyncify_data";
const Name ASYNCIFY_START_UNWIND = "asyncify_start_unwind";
const Name ASYNCIFY_STOP_UNWIND = "asyncify_stop_unwind";
const Name ASYNCIFY_START_REWIND = "asyncify_start_rewind";
const Name ASYNCIFY_STOP_REWIND = "asyncify_stop_rewind";
// Target of the fake unwind intrinsic call emitted by AsyncifyFlow, and the
// label of the block AsyncifyLocals turns that call into a break to.
const Name ASYNCIFY_UNWIND = "__asyncify_unwind";
// Import module name and import bases recognized as asyncify imports.
const Name ASYNCIFY = "asyncify";
const Name START_UNWIND = "start_unwind";
const Name STOP_UNWIND = "stop_unwind";
const Name START_REWIND = "start_rewind";
const Name STOP_REWIND = "stop_rewind";
// Fake intrinsics for popping / comparing the rewind call index; replaced
// with real code by AsyncifyLocals.
const Name ASYNCIFY_GET_CALL_INDEX = "__asyncify_get_call_index";
const Name ASYNCIFY_CHECK_CALL_INDEX = "__asyncify_check_call_index";

// Values held by the asyncify state global.
enum class State : int { Normal = 0, Unwinding = 1, Rewinding = 2 };

// Byte offsets of fields inside the asyncify data structure that the data
// global points to. The stack end lives at offset 4 for 32-bit memories and
// offset 8 for 64-bit memories.
enum class DataOffset : int { BStackPos = 0, BStackEnd = 4, BStackEnd64 = 8 };

// Alignment (in bytes) used for local saves/loads on the asyncify stack.
constexpr int STACK_ALIGN = 4;
// Creates one mutable "fake" global per concrete result type of any call or
// call_indirect in the module. AsyncifyFlow routes call results through these
// globals (so a result is not written to its local while unwinding), and
// AsyncifyLocals later replaces each of them with a function-local variable.
// The globals are removed from the module again when the helper is destroyed.
//
// Because the destructor undoes module changes, the helper must not be
// copied: a copy would remove the same globals a second time.
class FakeGlobalHelper {
  Module& module;

public:
  explicit FakeGlobalHelper(Module& module) : module(module) {
    Builder builder(module);
    std::string prefix = "asyncify_fake_call_global_";
    for (auto type : collectTypes()) {
      auto global = prefix + Type(type).toString();
      map[type] = global;
      rev[global] = type;
      module.addGlobal(builder.makeGlobal(
        global, type, LiteralUtils::makeZero(type, module), Builder::Mutable));
    }
  }

  FakeGlobalHelper(const FakeGlobalHelper&) = delete;
  FakeGlobalHelper& operator=(const FakeGlobalHelper&) = delete;

  ~FakeGlobalHelper() {
    for (auto& pair : map) {
      module.removeGlobal(pair.second);
    }
  }

  // Name of the fake global for a type. The type must have been collected
  // (map.at throws otherwise).
  Name getName(Type type) { return map.at(type); }

  // Type of a fake global, or Type::none if name is not one of ours.
  Type getTypeOrNone(Name name) {
    auto iter = rev.find(name);
    if (iter != rev.end()) {
      return iter->second;
    }
    return Type::none;
  }

private:
  std::unordered_map<Type, Name> map;
  std::unordered_map<Name, Type> rev;

  using Types = std::unordered_set<Type>;

  // Gathers, in parallel over defined functions, every concrete result type
  // of a call or call_indirect.
  Types collectTypes() {
    ModuleUtils::ParallelFunctionAnalysis<Types> analysis(
      module, [&](Function* func, Types& types) {
        if (!func->body) {
          return;
        }
        struct TypeCollector : PostWalker<TypeCollector> {
          Types& types;
          TypeCollector(Types& types) : types(types) {}
          void visitCall(Call* curr) {
            if (curr->type.isConcrete()) {
              types.insert(curr->type);
            }
          }
          void visitCallIndirect(CallIndirect* curr) {
            if (curr->type.isConcrete()) {
              types.insert(curr->type);
            }
          }
        };
        TypeCollector(types).walk(func->body);
      });
    // Merge the per-function sets.
    Types types;
    for (auto& pair : analysis.map) {
      const Types& functionTypes = pair.second;
      types.insert(functionTypes.begin(), functionTypes.end());
    }
    return types;
  }
};
// Matches function names against one user-supplied asyncify list (remove,
// add or only). Entries containing '*' are wildcard patterns; every other
// entry is an exact function name. Entries are stored in escaped (internal)
// form, with a map back to what the user wrote for diagnostics.
class PatternMatcher {
public:
  // List designation ("remove", "add", "only"), used in messages.
  std::string designation;
  // Exact (escaped) function names.
  std::set<Name> names;
  // Escaped wildcard patterns.
  std::set<std::string> patterns;
  // Patterns that matched at least one function.
  std::set<std::string> patternsMatched;
  // Escaped entry -> original user-written entry.
  std::map<std::string, std::string> unescaped;

  PatternMatcher(std::string designation,
                 Module& module,
                 const String::Split& list)
    : designation(designation) {
    for (auto& entry : list) {
      auto internal = WasmBinaryReader::escape(entry);
      unescaped[internal.toString()] = entry;
      bool isPattern = entry.find('*') != std::string::npos;
      if (isPattern) {
        patterns.insert(internal.toString());
        continue;
      }
      // Exact names are validated against the module: unknown names warn,
      // imported ones are a hard error.
      auto* target = module.getFunctionOrNull(internal);
      if (target == nullptr) {
        std::cerr << "warning: Asyncify " << designation
                  << "list contained a non-existing function name: " << entry
                  << " (" << internal << ")\n";
      } else if (target->imported()) {
        Fatal() << "Asyncify " << designation
                << "list contained an imported function name (use the import "
                   "list for imports): "
                << entry << '\n';
      }
      names.insert(internal.str);
    }
  }

  // True if funcName is listed exactly or matches a pattern; a matching
  // pattern is remembered for checkPatternsMatches().
  bool match(Name funcName) {
    if (names.count(funcName) > 0) {
      return true;
    }
    auto text = funcName.toString();
    for (auto& pattern : patterns) {
      if (!String::wildcardMatch(pattern, text)) {
        continue;
      }
      patternsMatched.insert(pattern);
      return true;
    }
    return false;
  }

  // Warns about every pattern that never matched any function.
  void checkPatternsMatches() {
    for (auto& pattern : patterns) {
      if (patternsMatched.count(pattern) != 0) {
        continue;
      }
      std::cerr << "warning: Asyncify " << designation
                << "list contained a non-matching pattern: "
                << unescaped[pattern] << " (" << pattern << ")\n";
    }
  }
};
class ModuleAnalyzer {
Module& module;
bool canIndirectChangeState;
struct Info
: public ModuleUtils::CallGraphPropertyAnalysis<Info>::FunctionInfo {
Name name;
bool canChangeState = false;
bool isBottomMostRuntime = false;
bool isTopMostRuntime = false;
bool inRemoveList = false;
bool addedFromList = false;
};
using Map = std::map<Function*, Info>;
Map map;
public:
ModuleAnalyzer(Module& module,
std::function<bool(Name, Name)> canImportChangeState,
bool canIndirectChangeState,
const String::Split& removeListInput,
const String::Split& addListInput,
const String::Split& onlyListInput,
bool verbose)
: module(module), canIndirectChangeState(canIndirectChangeState),
fakeGlobals(module), verbose(verbose) {
PatternMatcher removeList("remove", module, removeListInput);
PatternMatcher addList("add", module, addListInput);
PatternMatcher onlyList("only", module, onlyListInput);
std::map<Name, Name> renamings;
for (auto& func : module.functions) {
if (func->module == ASYNCIFY) {
if (func->base == START_UNWIND) {
renamings[func->name] = ASYNCIFY_START_UNWIND;
} else if (func->base == STOP_UNWIND) {
renamings[func->name] = ASYNCIFY_STOP_UNWIND;
} else if (func->base == START_REWIND) {
renamings[func->name] = ASYNCIFY_START_REWIND;
} else if (func->base == STOP_REWIND) {
renamings[func->name] = ASYNCIFY_STOP_REWIND;
} else {
Fatal() << "call to unidenfied asyncify import: " << func->base;
}
}
}
ModuleUtils::renameFunctions(module, renamings);
ModuleUtils::CallGraphPropertyAnalysis<Info> scanner(
module, [&](Function* func, Info& info) {
info.name = func->name;
if (func->imported()) {
if (func->module == ASYNCIFY &&
(func->base == START_UNWIND || func->base == STOP_REWIND)) {
info.canChangeState = true;
} else {
info.canChangeState =
canImportChangeState(func->module, func->base);
if (verbose && info.canChangeState) {
std::cout << "[asyncify] " << func->name
<< " is an import that can change the state\n";
}
}
return;
}
struct Walker : PostWalker<Walker> {
Info& info;
Module& module;
bool canIndirectChangeState;
Walker(Info& info, Module& module, bool canIndirectChangeState)
: info(info), module(module),
canIndirectChangeState(canIndirectChangeState) {}
void visitCall(Call* curr) {
if (curr->isReturn) {
Fatal() << "tail calls not yet supported in asyncify";
}
auto* target = module.getFunction(curr->target);
if (target->imported() && target->module == ASYNCIFY) {
if (target->base == START_UNWIND) {
info.canChangeState = true;
info.isTopMostRuntime = true;
} else if (target->base == STOP_UNWIND) {
info.isBottomMostRuntime = true;
} else if (target->base == START_REWIND) {
info.isBottomMostRuntime = true;
} else if (target->base == STOP_REWIND) {
info.canChangeState = true;
info.isTopMostRuntime = true;
} else {
WASM_UNREACHABLE("call to unidenfied asyncify import");
}
}
}
void visitCallIndirect(CallIndirect* curr) {
if (curr->isReturn) {
Fatal() << "tail calls not yet supported in asyncify";
}
if (canIndirectChangeState) {
info.canChangeState = true;
}
}
};
Walker walker(info, module, canIndirectChangeState);
walker.walk(func->body);
if (info.isBottomMostRuntime) {
info.canChangeState = false;
}
if (verbose && info.canChangeState) {
std::cout << "[asyncify] " << func->name
<< " can change the state due to initial scan\n";
}
});
for (auto& [func, info] : scanner.map) {
if (removeList.match(func->name)) {
info.inRemoveList = true;
if (verbose && info.canChangeState) {
std::cout << "[asyncify] " << func->name
<< " is in the remove-list, ignore\n";
}
info.canChangeState = false;
}
}
std::vector<Name> funcsToDelete;
for (auto& [func, info] : scanner.map) {
auto& callsTo = info.callsTo;
if (func->imported() && func->module == ASYNCIFY) {
funcsToDelete.push_back(func->name);
}
std::vector<Function*> callersToDelete;
for (auto* target : callsTo) {
if (target->imported() && target->module == ASYNCIFY) {
callersToDelete.push_back(target);
}
}
for (auto* target : callersToDelete) {
callsTo.erase(target);
}
}
for (auto name : funcsToDelete) {
module.removeFunction(name);
}
scanner.propagateBack([](const Info& info) { return info.canChangeState; },
[](const Info& info) {
return !info.isBottomMostRuntime &&
!info.inRemoveList;
},
[verbose](Info& info, Function* reason) {
if (verbose && !info.canChangeState) {
std::cout << "[asyncify] " << info.name
<< " can change the state due to "
<< reason->name << "\n";
}
info.canChangeState = true;
},
scanner.IgnoreNonDirectCalls);
map.swap(scanner.map);
if (!onlyListInput.empty()) {
for (auto& func : module.functions) {
if (!func->imported()) {
auto& info = map[func.get()];
bool matched = onlyList.match(func->name);
info.canChangeState = matched;
if (matched) {
info.addedFromList = true;
}
if (verbose) {
std::cout << "[asyncify] " << func->name
<< "'s state is set based on the only-list to " << matched
<< '\n';
}
}
}
}
if (!addListInput.empty()) {
for (auto& func : module.functions) {
if (!func->imported() && addList.match(func->name)) {
auto& info = map[func.get()];
if (verbose && !info.canChangeState) {
std::cout << "[asyncify] " << func->name
<< " is in the add-list, add\n";
}
info.canChangeState = true;
info.addedFromList = true;
}
}
}
removeList.checkPatternsMatches();
addList.checkPatternsMatches();
onlyList.checkPatternsMatches();
}
bool needsInstrumentation(Function* func) {
auto& info = map[func];
return info.canChangeState && !info.isTopMostRuntime;
}
bool canChangeState(Expression* curr, Function* func) {
struct Walker : PostWalker<Walker> {
void visitCall(Call* curr) {
if (curr->target == ASYNCIFY_START_UNWIND ||
curr->target == ASYNCIFY_STOP_REWIND ||
curr->target == ASYNCIFY_GET_CALL_INDEX ||
curr->target == ASYNCIFY_CHECK_CALL_INDEX) {
canChangeState = true;
return;
}
if (curr->target == ASYNCIFY_STOP_UNWIND ||
curr->target == ASYNCIFY_START_REWIND) {
isBottomMostRuntime = true;
return;
}
auto* target = module->getFunctionOrNull(curr->target);
if (target && (*map)[target].canChangeState) {
canChangeState = true;
}
}
void visitCallIndirect(CallIndirect* curr) { hasIndirectCall = true; }
Module* module;
ModuleAnalyzer* analyzer;
Map* map;
bool hasIndirectCall = false;
bool canChangeState = false;
bool isBottomMostRuntime = false;
};
Walker walker;
walker.module = &module;
walker.analyzer = this;
walker.map = ↦
walker.walk(curr);
if (walker.hasIndirectCall &&
(canIndirectChangeState || map[func].addedFromList)) {
walker.canChangeState = true;
}
return walker.canChangeState && !walker.isBottomMostRuntime;
}
FakeGlobalHelper fakeGlobals;
bool verbose;
};
// Returns whether curr is a call or call_indirect, after unwrapping at most
// one enclosing local.set or drop.
static bool doesCall(Expression* curr) {
  Expression* inner = curr;
  if (auto* set = curr->dynCast<LocalSet>()) {
    inner = set->value;
  } else if (auto* drop = curr->dynCast<Drop>()) {
    inner = drop->value;
  }
  if (inner->is<Call>()) {
    return true;
  }
  return inner->is<CallIndirect>();
}
// Builder extended with helpers for the asyncify data structure (accessed
// through the ASYNCIFY_DATA global) and for checks of the asyncify state.
class AsyncifyBuilder : public Builder {
public:
  Module& wasm;
  // i64 for a 64-bit asyncify memory, i32 otherwise.
  Type pointerType;
  // Memory holding the asyncify data structure and stack.
  Name asyncifyMemory;

  AsyncifyBuilder(Module& wasm, Type pointerType, Name asyncifyMemory)
    : Builder(wasm), wasm(wasm), pointerType(pointerType),
      asyncifyMemory(asyncifyMemory) {}

  // Loads the current asyncify stack position (the BStackPos field).
  Expression* makeGetStackPos() {
    auto bytes = pointerType.getByteSize();
    auto* dataPtr = makeGlobalGet(ASYNCIFY_DATA, pointerType);
    return makeLoad(bytes,
                    false,
                    int(DataOffset::BStackPos),
                    bytes,
                    dataPtr,
                    pointerType,
                    asyncifyMemory);
  }

  // Adds `by` (which may be negative) to the stack position; a nop for 0.
  Expression* makeIncStackPos(int32_t by) {
    if (!by) {
      return makeNop();
    }
    auto bytes = pointerType.getByteSize();
    auto addOp = Abstract::getBinary(pointerType, Abstract::Add);
    auto* delta = makeConst(Literal::makeFromInt64(by, pointerType));
    auto* newPos = makeBinary(addOp, makeGetStackPos(), delta);
    auto* dataPtr = makeGlobalGet(ASYNCIFY_DATA, pointerType);
    return makeStore(bytes,
                     int(DataOffset::BStackPos),
                     bytes,
                     dataPtr,
                     newPos,
                     pointerType,
                     asyncifyMemory);
  }

  // i32 expression that is 1 iff the state global equals `value`.
  Expression* makeStateCheck(State value) {
    auto* current = makeGlobalGet(ASYNCIFY_STATE, Type::i32);
    auto* expected = makeConst(Literal(int32_t(value)));
    return makeBinary(EqInt32, current, expected);
  }
};
// First instrumentation pass. Asyncify::run applies it (after "flatten" and
// "dce") only to functions that need instrumentation. It linearizes control
// flow so that, while rewinding, execution can skip forward to the call being
// resumed, and it wraps calls that may change the state. Unwinding and the
// call-index bookkeeping are emitted as calls to fake intrinsics
// (ASYNCIFY_UNWIND, ASYNCIFY_GET_CALL_INDEX, ASYNCIFY_CHECK_CALL_INDEX) that
// AsyncifyLocals later replaces with real code.
struct AsyncifyFlow : public Pass {
bool isFunctionParallel() override { return true; }
ModuleAnalyzer* analyzer;
Type pointerType;
Name asyncifyMemory;
std::unique_ptr<Pass> create() override {
return std::make_unique<AsyncifyFlow>(
analyzer, pointerType, asyncifyMemory);
}
AsyncifyFlow(ModuleAnalyzer* analyzer, Type pointerType, Name asyncifyMemory)
: analyzer(analyzer), pointerType(pointerType),
asyncifyMemory(asyncifyMemory) {}
// Rewrites one function: when rewinding, first pop the call index to resume
// into, then run the processed body.
void runOnFunction(Module* module_, Function* func_) override {
module = module_;
func = func_;
builder =
std::make_unique<AsyncifyBuilder>(*module, pointerType, asyncifyMemory);
if (!analyzer->needsInstrumentation(func)) {
return;
}
auto* block = builder->makeBlock(
{builder->makeIf(builder->makeStateCheck(
State::Rewinding), makeCallIndexPop()),
process(func->body)});
// For functions with results, end the block in unreachable. Presumably this
// keeps the block valid now that the rewritten control flow may no longer
// produce a value at the end. NOTE(review): confirm.
if (func->getResults() != Type::none) {
block->list.push_back(builder->makeUnreachable());
}
block->finalize();
func->body = block;
// Making code conditional may change types; recompute them.
ReFinalize().walkFunctionInModule(func, module);
}
private:
std::unique_ptr<AsyncifyBuilder> builder;
Module* module;
Function* func;
// Next index to assign to an instrumented call in this function.
Index callIndex = 0;
// Rewrites curr iteratively, using an explicit work stack instead of
// recursion.
// - Expressions that cannot change the state are wrapped in a maybe-skip,
//   so they run only in the Normal state.
// - Blocks: only the children that can change the state are processed; each
//   maximal run of the other children is wrapped in a single maybe-skip.
// - Ifs: linearized so that both arms are entered while rewinding.
// - Loops: the body is processed.
// - Calls that can change the state: instrumented by makeCallSupport.
// Any other state-changing expression is unexpected.
// Each Work item is visited in a Scan phase (queue its children) and then,
// for control flow, in a Finish phase (consume the children's results from
// the `results` stack).
Expression* process(Expression* curr) {
struct Work {
Expression* curr;
enum { Scan, Finish } phase;
};
std::vector<Work> work;
std::vector<Expression*> results;
// Expressions that went through the work stack; in a block's Finish phase
// these are the children whose rewritten form is on `results`.
std::unordered_set<Expression*> processed;
work.push_back(Work{curr, Work::Scan});
while (!work.empty()) {
auto item = work.back();
work.pop_back();
processed.insert(item.curr);
auto* curr = item.curr;
auto phase = item.phase;
if (phase == Work::Scan && !analyzer->canChangeState(curr, func)) {
results.push_back(makeMaybeSkip(curr));
continue;
}
if (auto* block = curr->dynCast<Block>()) {
auto& list = block->list;
if (phase == Work::Scan) {
work.push_back(Work{curr, Work::Finish});
// Push children in reverse so they are processed in order.
for (size_t i = list.size(); i > 0; i--) {
auto* child = list[i - 1];
if (analyzer->canChangeState(child, func)) {
work.push_back(Work{child, Work::Scan});
}
}
continue;
}
// Finish: walk the children from last to first. Processed children take
// their rewritten form from `results`, which yields them in reverse order.
// Each maximal run of unprocessed children gets one maybe-skip. A run longer
// than one is first moved into a new block, leaving nops in the old slots.
Index i = list.size() - 1;
while (1) {
if (processed.count(list[i])) {
list[i] = results.back();
results.pop_back();
} else {
Index begin = i;
while (begin > 0 && !processed.count(list[begin - 1])) {
begin--;
}
if (begin == i) {
list[i] = makeMaybeSkip(list[i]);
} else {
auto* block = builder->makeBlock();
for (auto j = begin; j <= i; j++) {
block->list.push_back(list[j]);
}
block->finalize();
list[begin] = makeMaybeSkip(block);
for (auto j = begin + 1; j <= i; j++) {
list[j] = builder->makeNop();
}
}
i = begin;
}
if (i == 0) {
break;
} else {
i--;
}
}
results.push_back(block);
continue;
} else if (auto* iff = curr->dynCast<If>()) {
// The condition itself must not be able to change the state.
assert(!analyzer->canChangeState(iff->condition, func));
if (item.phase == Work::Scan) {
work.push_back(Work{curr, Work::Finish});
if (iff->ifFalse) {
work.push_back(Work{iff->ifFalse, Work::Scan});
}
work.push_back(Work{iff->ifTrue, Work::Scan});
continue;
}
// One-armed if: enter the arm if the condition holds or we are rewinding.
if (!iff->ifFalse) {
iff->condition = builder->makeBinary(
OrInt32, iff->condition, builder->makeStateCheck(State::Rewinding));
iff->ifTrue = results.back();
results.pop_back();
iff->finalize();
results.push_back(iff);
continue;
}
// Two arms. ifTrue was scanned first, so its result lies below ifFalse's.
// Rewrite as:
//   if (normal) temp = condition;
//   if (temp || rewinding) ifTrue;
//   if (!temp || rewinding) ifFalse;
auto* newIfFalse = results.back();
results.pop_back();
auto* newIfTrue = results.back();
results.pop_back();
auto conditionTemp = builder->addVar(func, Type::i32);
auto* pre =
makeMaybeSkip(builder->makeLocalSet(conditionTemp, iff->condition));
iff->condition = builder->makeLocalGet(conditionTemp, Type::i32);
iff->condition = builder->makeBinary(
OrInt32, iff->condition, builder->makeStateCheck(State::Rewinding));
iff->ifTrue = newIfTrue;
iff->ifFalse = nullptr;
iff->finalize();
auto* otherIf = builder->makeIf(
builder->makeBinary(
OrInt32,
builder->makeUnary(EqZInt32,
builder->makeLocalGet(conditionTemp, Type::i32)),
builder->makeStateCheck(State::Rewinding)),
newIfFalse);
otherIf->finalize();
results.push_back(builder->makeBlock({pre, iff, otherIf}));
continue;
} else if (auto* loop = curr->dynCast<Loop>()) {
if (item.phase == Work::Scan) {
work.push_back(Work{curr, Work::Finish});
work.push_back(Work{loop->body, Work::Scan});
continue;
}
loop->body = results.back();
results.pop_back();
results.push_back(loop);
continue;
} else if (doesCall(curr)) {
results.push_back(makeCallSupport(curr));
continue;
}
WASM_UNREACHABLE("unexpected expression type");
}
assert(results.size() == 1);
return results.back();
}
// Runs curr only when the state is Normal, i.e. skips it otherwise.
Expression* makeMaybeSkip(Expression* curr) {
return builder->makeIf(builder->makeStateCheck(State::Normal), curr);
}
// Instruments a call (optionally wrapped in a local.set or drop), producing:
//   if (normal | checkCallIndex(index)) {
//     call;
//     if (unwinding) unwind(index); else <set, if any>;
//   }
// A set's value is first routed through a fake global, so the local is
// written only when not unwinding. AsyncifyLocals later turns the fake
// global into a local.
Expression* makeCallSupport(Expression* curr) {
assert(doesCall(curr));
assert(curr->type == Type::none);
auto* set = curr->dynCast<LocalSet>();
if (set) {
auto name = analyzer->fakeGlobals.getName(set->value->type);
curr = builder->makeGlobalSet(name, set->value);
set->value = builder->makeGlobalGet(name, set->value->type);
}
auto index = callIndex++;
curr = builder->makeIf(
builder->makeBinary(OrInt32,
builder->makeStateCheck(State::Normal),
makeCallIndexPeek(index)),
builder->makeSequence(curr, makePossibleUnwind(index, set)));
return curr;
}
// If unwinding, call the fake ASYNCIFY_UNWIND intrinsic with the call index;
// otherwise run ifNotUnwinding, which may be null.
Expression* makePossibleUnwind(Index index, Expression* ifNotUnwinding) {
return builder->makeIf(
builder->makeStateCheck(State::Unwinding),
builder->makeCall(
ASYNCIFY_UNWIND, {builder->makeConst(int32_t(index))}, Type::none),
ifNotUnwinding);
}
// Fake intrinsic (i32): whether the call index being rewound into equals
// `index`.
Expression* makeCallIndexPeek(Index index) {
return builder->makeCall(ASYNCIFY_CHECK_CALL_INDEX,
{builder->makeConst(int32_t(index))},
Type::i32);
}
// Fake intrinsic: pop the call index to rewind into.
Expression* makeCallIndexPop() {
return builder->makeCall(ASYNCIFY_GET_CALL_INDEX, {}, Type::none);
}
};
struct AsyncifyAssertInNonInstrumented : public Pass {
bool isFunctionParallel() override { return true; }
ModuleAnalyzer* analyzer;
Type pointerType;
Name asyncifyMemory;
std::unique_ptr<Pass> create() override {
return std::make_unique<AsyncifyAssertInNonInstrumented>(
analyzer, pointerType, asyncifyMemory);
}
AsyncifyAssertInNonInstrumented(ModuleAnalyzer* analyzer,
Type pointerType,
Name asyncifyMemory)
: analyzer(analyzer), pointerType(pointerType),
asyncifyMemory(asyncifyMemory) {}
void runOnFunction(Module* module_, Function* func) override {
if (!analyzer->needsInstrumentation(func)) {
module = module_;
builder =
std::make_unique<AsyncifyBuilder>(*module, pointerType, asyncifyMemory);
addAssertsInNonInstrumented(func);
}
}
void addAssertsInNonInstrumented(Function* func) {
auto oldState = builder->addVar(func, Type::i32);
func->body = builder->makeSequence(
builder->makeLocalSet(oldState,
builder->makeGlobalGet(ASYNCIFY_STATE, Type::i32)),
func->body);
struct Walker : PostWalker<Walker> {
void visitCall(Call* curr) {
assert(!curr->isReturn);
handleCall(curr);
}
void visitCallIndirect(CallIndirect* curr) {
assert(!curr->isReturn);
handleCall(curr);
}
void handleCall(Expression* call) {
auto* check = builder->makeIf(
builder->makeBinary(NeInt32,
builder->makeGlobalGet(ASYNCIFY_STATE, Type::i32),
builder->makeLocalGet(oldState, Type::i32)),
builder->makeUnreachable());
Expression* rep;
if (call->type.isConcrete()) {
auto temp = builder->addVar(func, call->type);
rep = builder->makeBlock({
builder->makeLocalSet(temp, call),
check,
builder->makeLocalGet(temp, call->type),
});
} else {
rep = builder->makeSequence(call, check);
}
replaceCurrent(rep);
}
Function* func;
AsyncifyBuilder* builder;
Index oldState;
};
Walker walker;
walker.func = func;
walker.builder = builder.get();
walker.oldState = oldState;
walker.walk(func->body);
}
private:
std::unique_ptr<AsyncifyBuilder> builder;
Module* module;
};
// Second instrumentation pass. Asyncify::run applies it only to functions
// that need instrumentation. It does three things:
// - replaces the fake intrinsics emitted by AsyncifyFlow with real code;
// - turns the fake call-result globals into locals;
// - wraps the body so that, when rewinding, the relevant locals are reloaded
//   from the asyncify stack, and, when unwinding, the call index and those
//   locals are pushed onto it.
struct AsyncifyLocals : public WalkerPass<PostWalker<AsyncifyLocals>> {
  bool isFunctionParallel() override { return true; }

  ModuleAnalyzer* analyzer;
  Type pointerType;
  Name asyncifyMemory;

  std::unique_ptr<Pass> create() override {
    return std::make_unique<AsyncifyLocals>(
      analyzer, pointerType, asyncifyMemory);
  }

  AsyncifyLocals(ModuleAnalyzer* analyzer,
                 Type pointerType,
                 Name asyncifyMemory)
    : analyzer(analyzer), pointerType(pointerType),
      asyncifyMemory(asyncifyMemory) {}

  // Replaces calls to the fake intrinsics emitted by AsyncifyFlow.
  void visitCall(Call* curr) {
    if (curr->target == ASYNCIFY_UNWIND) {
      // Break out to the ASYNCIFY_UNWIND block built in doWalkFunction,
      // carrying the index of the call being unwound.
      replaceCurrent(builder->makeBreak(ASYNCIFY_UNWIND, curr->operands[0]));
    } else if (curr->target == ASYNCIFY_GET_CALL_INDEX) {
      // Pop the 4-byte call index off the asyncify stack into rewindIndex.
      replaceCurrent(builder->makeSequence(
        builder->makeIncStackPos(-4),
        builder->makeLocalSet(rewindIndex,
                              builder->makeLoad(4,
                                                false,
                                                0,
                                                4,
                                                builder->makeGetStackPos(),
                                                Type::i32,
                                                asyncifyMemory))));
    } else if (curr->target == ASYNCIFY_CHECK_CALL_INDEX) {
      // Compare the popped call index with this call's index.
      replaceCurrent(
        builder->makeBinary(EqInt32,
                            builder->makeLocalGet(rewindIndex, Type::i32),
                            curr->operands[0]));
    }
  }

  // Fake call-result globals become per-function locals.
  void visitGlobalSet(GlobalSet* curr) {
    auto type = analyzer->fakeGlobals.getTypeOrNone(curr->name);
    if (type != Type::none) {
      replaceCurrent(
        builder->makeLocalSet(getFakeCallLocal(type), curr->value));
    }
  }

  void visitGlobalGet(GlobalGet* curr) {
    auto type = analyzer->fakeGlobals.getTypeOrNone(curr->name);
    if (type != Type::none) {
      replaceCurrent(builder->makeLocalGet(getFakeCallLocal(type), type));
    }
  }

  // Returns the local standing in for the fake global of `type`, creating it
  // on first use.
  Index getFakeCallLocal(Type type) {
    auto iter = fakeCallLocals.find(type);
    if (iter != fakeCallLocals.end()) {
      return iter->second;
    }
    return fakeCallLocals[type] = builder->addVar(getFunction(), type);
  }

  void doWalkFunction(Function* func) {
    if (!analyzer->needsInstrumentation(func)) {
      return;
    }
    // Must run before any helper locals are added, and before walk() replaces
    // the ASYNCIFY_CHECK_CALL_INDEX calls that it looks for.
    findRelevantLiveLocals(func);
    // Create the builder before anything is called through it. Calling
    // through the still-empty unique_ptr would dereference null.
    builder = std::make_unique<AsyncifyBuilder>(
      *getModule(), pointerType, asyncifyMemory);
    // Locals for the call index being unwound from / rewound into.
    auto unwindIndex = builder->addVar(func, Type::i32);
    rewindIndex = builder->addVar(func, Type::i32);
    walk(func->body);
    // Keep normal execution from falling through into the unwind epilogue.
    Expression* barrier;
    if (func->getResults() == Type::none) {
      barrier = builder->makeReturn();
    } else {
      barrier = builder->makeUnreachable();
    }
    // New body:
    //   if (rewinding) { load locals }
    //   unwindIndex = block $ASYNCIFY_UNWIND { body; barrier }
    //   push unwindIndex; save locals
    //   [zero of the result type, for functions with results]
    auto* newBody = builder->makeBlock(
      {builder->makeIf(builder->makeStateCheck(State::Rewinding),
                       makeLocalLoading()),
       builder->makeLocalSet(
         unwindIndex,
         builder->makeBlock(ASYNCIFY_UNWIND,
                            builder->makeSequence(func->body, barrier))),
       makeCallIndexPush(unwindIndex),
       makeLocalSaving()});
    if (func->getResults() != Type::none) {
      newBody->list.push_back(
        LiteralUtils::makeZero(func->getResults(), *getModule()));
      newBody->finalize(func->getResults());
    }
    func->body = newBody;
    ReFinalize().walkFunctionInModule(func, getModule());
  }

private:
  std::unique_ptr<AsyncifyBuilder> builder;
  // Local holding the call index popped while rewinding.
  Index rewindIndex = 0;
  // Fake-global type -> replacement local.
  std::unordered_map<Type, Index> fakeCallLocals;
  // Locals that must be saved/restored.
  std::set<Index> relevantLiveLocals;

  // Collects the locals that are live at the start of any basic block
  // containing an ASYNCIFY_CHECK_CALL_INDEX call, i.e. a possible
  // unwind/rewind point.
  void findRelevantLiveLocals(Function* func) {
    struct RelevantLiveLocalsWalker
      : public LivenessWalker<RelevantLiveLocalsWalker,
                              Visitor<RelevantLiveLocalsWalker>> {
      std::set<BasicBlock*> relevantBasicBlocks;
      void visitCall(Call* curr) {
        if (!currBasicBlock) {
          return;
        }
        if (curr->target == ASYNCIFY_CHECK_CALL_INDEX) {
          relevantBasicBlocks.insert(currBasicBlock);
        }
      }
    };
    RelevantLiveLocalsWalker walker;
    walker.setFunction(func);
    walker.walkFunctionInModule(func, getModule());
    for (auto* block : walker.liveBlocks) {
      if (walker.relevantBasicBlocks.count(block)) {
        for (auto local : block->contents.start) {
          relevantLiveLocals.insert(local);
        }
      }
    }
  }

  // Moves the stack position back by the total size of the relevant locals,
  // then loads them in local-index order. Tuple locals are loaded one element
  // at a time.
  Expression* makeLocalLoading() {
    if (relevantLiveLocals.empty()) {
      return builder->makeNop();
    }
    auto* func = getFunction();
    auto numLocals = func->getNumLocals();
    Index total = 0;
    for (Index i = 0; i < numLocals; i++) {
      if (!relevantLiveLocals.count(i)) {
        continue;
      }
      total += getByteSize(func->getLocalType(i));
    }
    auto* block = builder->makeBlock();
    // Negate as a signed value; negating the unsigned Index would wrap.
    block->list.push_back(
      builder->makeIncStackPos(-static_cast<int32_t>(total)));
    auto tempIndex = builder->addVar(func, builder->pointerType);
    block->list.push_back(
      builder->makeLocalSet(tempIndex, builder->makeGetStackPos()));
    Index offset = 0;
    for (Index i = 0; i < numLocals; i++) {
      if (!relevantLiveLocals.count(i)) {
        continue;
      }
      auto localType = func->getLocalType(i);
      SmallVector<Expression*, 1> loads;
      for (const auto& type : localType) {
        auto size = getByteSize(type);
        assert(size % STACK_ALIGN == 0);
        loads.push_back(builder->makeLoad(
          size,
          true,
          offset,
          STACK_ALIGN,
          builder->makeLocalGet(tempIndex, builder->pointerType),
          type,
          asyncifyMemory));
        offset += size;
      }
      Expression* load;
      if (loads.size() == 1) {
        load = loads[0];
      } else if (localType.size() > 1) {
        load = builder->makeTupleMake(std::move(loads));
      } else {
        WASM_UNREACHABLE("Unexpected empty type");
      }
      block->list.push_back(builder->makeLocalSet(i, load));
    }
    block->finalize();
    return block;
  }

  // Stores the relevant locals, in the same order and layout that
  // makeLocalLoading reads them, then advances the stack position.
  Expression* makeLocalSaving() {
    if (relevantLiveLocals.empty()) {
      return builder->makeNop();
    }
    auto* func = getFunction();
    auto numLocals = func->getNumLocals();
    auto* block = builder->makeBlock();
    auto tempIndex = builder->addVar(func, builder->pointerType);
    block->list.push_back(
      builder->makeLocalSet(tempIndex, builder->makeGetStackPos()));
    Index offset = 0;
    for (Index i = 0; i < numLocals; i++) {
      if (!relevantLiveLocals.count(i)) {
        continue;
      }
      auto localType = func->getLocalType(i);
      size_t j = 0;
      for (const auto& type : localType) {
        auto size = getByteSize(type);
        Expression* localGet = builder->makeLocalGet(i, localType);
        if (localType.size() > 1) {
          localGet = builder->makeTupleExtract(localGet, j);
        }
        assert(size % STACK_ALIGN == 0);
        block->list.push_back(builder->makeStore(
          size,
          offset,
          STACK_ALIGN,
          builder->makeLocalGet(tempIndex, builder->pointerType),
          localGet,
          type,
          asyncifyMemory));
        offset += size;
        ++j;
      }
    }
    block->list.push_back(builder->makeIncStackPos(offset));
    block->finalize();
    return block;
  }

  // Pushes the i32 call index held in local tempIndex onto the asyncify stack.
  Expression* makeCallIndexPush(Index tempIndex) {
    return builder->makeSequence(
      builder->makeStore(4,
                         0,
                         4,
                         builder->makeGetStackPos(),
                         builder->makeLocalGet(tempIndex, Type::i32),
                         Type::i32,
                         asyncifyMemory),
      builder->makeIncStackPos(4));
  }

  // Byte size of a local's type; fatal for types without a byte size (e.g.
  // references).
  unsigned getByteSize(Type type) {
    if (!type.hasByteSize()) {
      Fatal() << "Asyncify does not yet support non-number types, like "
                 "references (see "
                 "https://github.com/WebAssembly/binaryen/issues/3739)";
    }
    return type.getByteSize();
  }
};
}
// Returns "module.base", the form matched against the asyncify-imports list.
static std::string getFullImportName(Name module, Name base) {
  std::string full(module.str);
  full += '.';
  full += base.toString();
  return full;
}
struct Asyncify : public Pass {
bool addsEffects() override { return true; }
void run(Module* module) override {
auto& options = getPassOptions();
bool optimize = options.optimizeLevel > 0;
auto stateChangingImports = String::trim(read_possible_response_file(
options.getArgumentOrDefault("asyncify-imports", "")));
auto ignoreImports =
options.getArgumentOrDefault("asyncify-ignore-imports", "");
bool allImportsCanChangeState =
stateChangingImports == "" && ignoreImports == "";
String::Split listedImports(stateChangingImports,
String::Split::NewLineOr(","));
auto canIndirectChangeState =
!options.hasArgument("asyncify-ignore-indirect");
std::string removeListInput =
options.getArgumentOrDefault("asyncify-removelist", "");
if (removeListInput.empty()) {
removeListInput = options.getArgumentOrDefault("asyncify-blacklist", "");
}
String::Split removeList(
String::trim(read_possible_response_file(removeListInput)),
String::Split::NewLineOr(","));
String::Split addList(
String::trim(read_possible_response_file(
options.getArgumentOrDefault("asyncify-addlist", ""))),
String::Split::NewLineOr(","));
std::string onlyListInput =
options.getArgumentOrDefault("asyncify-onlylist", "");
if (onlyListInput.empty()) {
onlyListInput = options.getArgumentOrDefault("asyncify-whitelist", "");
}
String::Split onlyList(
String::trim(read_possible_response_file(onlyListInput)),
String::Split::NewLineOr(","));
auto asserts = options.hasArgument("asyncify-asserts");
auto verbose = options.hasArgument("asyncify-verbose");
auto relocatable = options.hasArgument("asyncify-relocatable");
auto secondaryMemory = options.hasArgument("asyncify-in-secondary-memory");
if (secondaryMemory) {
auto secondaryMemorySizeString =
options.getArgumentOrDefault("asyncify-secondary-memory-size", "1");
Address secondaryMemorySize = std::stoi(secondaryMemorySizeString);
asyncifyMemory = createSecondaryMemory(module, secondaryMemorySize);
} else {
MemoryUtils::ensureExists(module);
asyncifyMemory = module->memories[0]->name;
}
pointerType =
module->getMemory(asyncifyMemory)->is64() ? Type::i64 : Type::i32;
removeList = handleBracketingOperators(removeList);
addList = handleBracketingOperators(addList);
onlyList = handleBracketingOperators(onlyList);
if (!onlyList.empty() && (!removeList.empty() || !addList.empty())) {
Fatal() << "It makes no sense to use both an asyncify only-list together "
"with another list.";
}
auto canImportChangeState = [&](Name module, Name base) {
if (allImportsCanChangeState) {
return true;
}
auto full = getFullImportName(module, base);
for (auto& listedImport : listedImports) {
if (String::wildcardMatch(listedImport, full)) {
return true;
}
}
return false;
};
ModuleAnalyzer analyzer(*module,
canImportChangeState,
canIndirectChangeState,
removeList,
addList,
onlyList,
verbose);
addGlobals(module, relocatable);
PassUtils::FuncSet instrumentedFuncs;
for (auto& func : module->functions) {
if (analyzer.needsInstrumentation(func.get())) {
instrumentedFuncs.insert(func.get());
}
}
{
PassUtils::FilteredPassRunner runner(module, instrumentedFuncs);
runner.add("flatten");
runner.add("dce");
if (optimize) {
runner.add("remove-unused-names");
runner.add("simplify-locals-nonesting");
runner.add("reorder-locals");
runner.add("coalesce-locals");
runner.add("simplify-locals-nonesting");
runner.add("reorder-locals");
runner.add("merge-blocks");
}
runner.add(
std::make_unique<AsyncifyFlow>(&analyzer, pointerType, asyncifyMemory));
runner.setIsNested(true);
runner.setValidateGlobally(false);
runner.run();
}
if (asserts) {
PassRunner runner(module);
runner.add(std::make_unique<AsyncifyAssertInNonInstrumented>(
&analyzer, pointerType, asyncifyMemory));
runner.setIsNested(true);
runner.setValidateGlobally(false);
runner.run();
}
{
PassUtils::FilteredPassRunner runner(module, instrumentedFuncs);
if (optimize) {
runner.addDefaultFunctionOptimizationPasses();
}
runner.add(std::make_unique<AsyncifyLocals>(
&analyzer, pointerType, asyncifyMemory));
if (optimize) {
runner.addDefaultFunctionOptimizationPasses();
}
runner.setIsNested(true);
runner.setValidateGlobally(false);
runner.run();
}
addFunctions(module);
}
private:
void addGlobals(Module* module, bool imported) {
Builder builder(*module);
auto asyncifyState = builder.makeGlobal(ASYNCIFY_STATE,
Type::i32,
builder.makeConst(int32_t(0)),
Builder::Mutable);
if (imported) {
asyncifyState->module = ENV;
asyncifyState->base = ASYNCIFY_STATE;
}
module->addGlobal(std::move(asyncifyState));
auto asyncifyData = builder.makeGlobal(ASYNCIFY_DATA,
pointerType,
builder.makeConst(pointerType),
Builder::Mutable);
if (imported) {
asyncifyData->module = ENV;
asyncifyData->base = ASYNCIFY_DATA;
}
module->addGlobal(std::move(asyncifyData));
}
void addFunctions(Module* module) {
Builder builder(*module);
auto makeFunction = [&](Name name, bool setData, State state) {
std::vector<Type> params;
if (setData) {
params.push_back(pointerType);
}
auto* body = builder.makeBlock();
body->list.push_back(builder.makeGlobalSet(
ASYNCIFY_STATE, builder.makeConst(int32_t(state))));
if (setData) {
body->list.push_back(builder.makeGlobalSet(
ASYNCIFY_DATA, builder.makeLocalGet(0, pointerType)));
}
auto* stackPos =
builder.makeLoad(pointerType.getByteSize(),
false,
int(DataOffset::BStackPos),
pointerType.getByteSize(),
builder.makeGlobalGet(ASYNCIFY_DATA, pointerType),
pointerType,
asyncifyMemory);
auto* stackEnd =
builder.makeLoad(pointerType.getByteSize(),
false,
int(pointerType == Type::i64 ? DataOffset::BStackEnd64
: DataOffset::BStackEnd),
pointerType.getByteSize(),
builder.makeGlobalGet(ASYNCIFY_DATA, pointerType),
pointerType,
asyncifyMemory);
body->list.push_back(builder.makeIf(
builder.makeBinary(
Abstract::getBinary(pointerType, Abstract::GtU), stackPos, stackEnd),
builder.makeUnreachable()));
body->finalize();
auto func = builder.makeFunction(
name, Signature(Type(params), Type::none), {}, body);
module->addFunction(std::move(func));
module->addExport(builder.makeExport(name, name, ExternalKind::Function));
};
makeFunction(ASYNCIFY_START_UNWIND, true, State::Unwinding);
makeFunction(ASYNCIFY_STOP_UNWIND, false, State::Normal);
makeFunction(ASYNCIFY_START_REWIND, true, State::Rewinding);
makeFunction(ASYNCIFY_STOP_REWIND, false, State::Normal);
module->addFunction(
builder.makeFunction(ASYNCIFY_GET_STATE,
Signature(Type::none, Type::i32),
{},
builder.makeGlobalGet(ASYNCIFY_STATE, Type::i32)));
module->addExport(builder.makeExport(
ASYNCIFY_GET_STATE, ASYNCIFY_GET_STATE, ExternalKind::Function));
}
Name createSecondaryMemory(Module* module, Address secondaryMemorySize) {
Name name = Names::getValidMemoryName(*module, "asyncify_memory");
auto secondaryMemory =
Builder::makeMemory(name, secondaryMemorySize, secondaryMemorySize);
module->addMemory(std::move(secondaryMemory));
return name;
}
Type pointerType;
Name asyncifyMemory;
};
// Factory for the asyncify pass; the caller takes ownership.
Pass* createAsyncifyPass() {
  auto* pass = new Asyncify();
  return pass;
}
// Optimizes code that asyncify has already instrumented, under assumptions
// given as template flags:
//   neverRewind         - the state is never Rewinding;
//   neverUnwind         - the state is never Unwinding;
//   importsAlwaysUnwind - after a call to an import, the next comparison of
//                         the state against Unwinding (within the same linear
//                         execution trace) is known to be true.
// Comparisons of the state global against constants are folded accordingly.
template<bool neverRewind, bool neverUnwind, bool importsAlwaysUnwind>
struct ModAsyncify
  : public WalkerPass<LinearExecutionWalker<
      ModAsyncify<neverRewind, neverUnwind, importsAlwaysUnwind>>> {
  using Self = ModAsyncify<neverRewind, neverUnwind, importsAlwaysUnwind>;

  bool isFunctionParallel() override { return true; }

  std::unique_ptr<Pass> create() override { return std::make_unique<Self>(); }

  void doWalkFunction(Function* func) {
    // The state global is the single global written by the exported
    // asyncify_stop_unwind function.
    auto* wasm = this->getModule();
    auto* stopUnwindExport = wasm->getExport(ASYNCIFY_STOP_UNWIND);
    auto* stopUnwind = wasm->getFunction(stopUnwindExport->value);
    FindAll<GlobalSet> globalSets(stopUnwind->body);
    assert(globalSets.list.size() == 1);
    asyncifyStateName = globalSets.list[0]->name;
    this->walk(func->body);
  }

  // Folds `state == C` / `state != C` when the outcome is known.
  void visitBinary(Binary* curr) {
    bool isNe = curr->op == NeInt32;
    if (!isNe && curr->op != EqInt32) {
      return;
    }
    auto* rhs = curr->right->dynCast<Const>();
    if (rhs == nullptr) {
      return;
    }
    auto* lhs = curr->left->dynCast<GlobalGet>();
    if (lhs == nullptr || lhs->name != asyncifyStateName) {
      return;
    }
    auto compared = rhs->value.geti32();
    bool vsUnwinding = compared == int(State::Unwinding);
    bool vsRewinding = compared == int(State::Rewinding);
    int32_t equal;
    if ((vsUnwinding && neverUnwind) || (vsRewinding && neverRewind)) {
      equal = 0;
    } else if (vsUnwinding && this->unwinding) {
      equal = 1;
      unsetUnwinding();
    } else {
      return;
    }
    int32_t folded = isNe ? 1 - equal : equal;
    Builder builder(*this->getModule());
    this->replaceCurrent(builder.makeConst(int32_t(folded)));
  }

  // With neverRewind, a select on the raw state gets a constant 0 condition.
  void visitSelect(Select* curr) {
    auto* stateGet = curr->condition->dynCast<GlobalGet>();
    bool onState = stateGet && stateGet->name == asyncifyStateName;
    if (!onState || !neverRewind) {
      return;
    }
    Builder builder(*this->getModule());
    curr->condition = builder.makeConst(int32_t(0));
  }

  // With neverRewind, `eqz(state)` becomes the constant 1.
  void visitUnary(Unary* curr) {
    if (curr->op != EqZInt32) {
      return;
    }
    auto* stateGet = curr->value->dynCast<GlobalGet>();
    bool onState = stateGet && stateGet->name == asyncifyStateName;
    if (!onState || !neverRewind) {
      return;
    }
    Builder builder(*this->getModule());
    this->replaceCurrent(builder.makeConst(int32_t(1)));
  }

  // Any call resets tracking; a call to an import starts it when imports are
  // assumed to always unwind.
  void visitCall(Call* curr) {
    unsetUnwinding();
    if (importsAlwaysUnwind &&
        this->getModule()->getFunction(curr->target)->imported()) {
      this->unwinding = true;
    }
  }

  void visitCallIndirect(CallIndirect*) { unsetUnwinding(); }

  // Stop tracking whenever control flow is non-linear.
  static void doNoteNonLinear(Self* self, Expression**) {
    self->unsetUnwinding();
  }

  void visitGlobalSet(GlobalSet*) { unsetUnwinding(); }

private:
  Name asyncifyStateName;
  // True right after a call to an import while importsAlwaysUnwind holds.
  bool unwinding = false;

  void unsetUnwinding() { this->unwinding = false; }
};
// ModAsyncify assuming imports always unwind and rewinding never happens.
Pass* createModAsyncifyAlwaysOnlyUnwindPass() {
  constexpr bool neverRewind = true;
  constexpr bool neverUnwind = false;
  constexpr bool importsAlwaysUnwind = true;
  return new ModAsyncify<neverRewind, neverUnwind, importsAlwaysUnwind>();
}
// A pass whose run() does nothing. Not referenced elsewhere in this file;
// createModAsyncifyNeverUnwindPass returns a ModAsyncify instead.
struct ModAsyncifyNeverUnwind : public Pass {
  void run(Module*) override {
    // Intentionally empty.
  }
};
// ModAsyncify assuming unwinding never happens (rewinding still may).
Pass* createModAsyncifyNeverUnwindPass() {
  constexpr bool neverRewind = false;
  constexpr bool neverUnwind = true;
  constexpr bool importsAlwaysUnwind = false;
  return new ModAsyncify<neverRewind, neverUnwind, importsAlwaysUnwind>();
}
}