#include "cfg/liveness-traversal.h"
#include "ir/effects.h"
#include "ir/find_all.h"
#include "ir/literal-utils.h"
#include "ir/memory-utils.h"
#include "ir/module-utils.h"
#include "ir/utils.h"
#include "pass.h"
#include "support/file.h"
#include "support/string.h"
#include "wasm-builder.h"
#include "wasm.h"
namespace wasm {
namespace {
// Global holding the current asyncify State, and the exported getter for it.
const Name ASYNCIFY_STATE("__asyncify_state");
const Name ASYNCIFY_GET_STATE("asyncify_get_state");
// Global holding a pointer to the asyncify data structure (see DataOffset).
const Name ASYNCIFY_DATA("__asyncify_data");
// Exported runtime entry points added by the pass.
const Name ASYNCIFY_START_UNWIND("asyncify_start_unwind");
const Name ASYNCIFY_STOP_UNWIND("asyncify_stop_unwind");
const Name ASYNCIFY_START_REWIND("asyncify_start_rewind");
const Name ASYNCIFY_STOP_REWIND("asyncify_stop_rewind");
// Fake intrinsic for "unwind here"; also the label of the unwind block.
const Name ASYNCIFY_UNWIND("__asyncify_unwind");
// Import module name and the import base names recognized inside it.
const Name ASYNCIFY("asyncify");
const Name START_UNWIND("start_unwind");
const Name STOP_UNWIND("stop_unwind");
const Name START_REWIND("start_rewind");
const Name STOP_REWIND("stop_rewind");
// Fake intrinsics for popping / checking the saved call index; replaced with
// real code by AsyncifyLocals.
const Name ASYNCIFY_GET_CALL_INDEX("__asyncify_get_call_index");
const Name ASYNCIFY_CHECK_CALL_INDEX("__asyncify_check_call_index");

// Values stored in the ASYNCIFY_STATE global.
enum class State { Normal = 0, Unwinding = 1, Rewinding = 2 };

// Byte offsets of the fields of the structure ASYNCIFY_DATA points to: the
// current stack position and the stack end.
enum class DataOffset { BStackPos = 0, BStackEnd = 4 };

// Alignment (in bytes) used when saving/loading locals on the asyncify stack.
constexpr int STACK_ALIGN = 4;
// Creates one mutable "fake" global for each concrete type returned by any
// call or call_indirect in the module. AsyncifyFlow routes call results
// through these globals, and AsyncifyLocals later replaces every access to
// them with a per-function local. The globals are removed from the module
// again when this helper is destroyed.
class FakeGlobalHelper {
  Module& module;

public:
  explicit FakeGlobalHelper(Module& module) : module(module) {
    Builder builder(module);
    std::string prefix = "asyncify_fake_call_global_";
    for (auto type : collectTypes()) {
      auto global = prefix + type.toString();
      map[type] = global;
      rev[global] = type;
      module.addGlobal(builder.makeGlobal(
        global, type, LiteralUtils::makeZero(type, module), Builder::Mutable));
    }
  }

  // The destructor removes the globals from the module, so a copy would try to
  // remove them a second time. The helper is therefore non-copyable.
  FakeGlobalHelper(const FakeGlobalHelper&) = delete;
  FakeGlobalHelper& operator=(const FakeGlobalHelper&) = delete;

  ~FakeGlobalHelper() {
    for (auto& pair : map) {
      module.removeGlobal(pair.second);
    }
  }

  // Name of the fake global for `type`. Throws std::out_of_range if no call in
  // the module returned that type.
  Name getName(Type type) { return map.at(type); }

  // Type of the fake global called `name`, or Type::none if `name` is not one
  // of the fake globals.
  Type getTypeOrNone(Name name) {
    auto iter = rev.find(name);
    if (iter != rev.end()) {
      return iter->second;
    }
    return Type::none;
  }

private:
  // type -> fake global name, and the reverse mapping.
  std::map<Type, Name> map;
  std::map<Name, Type> rev;

  using Types = std::unordered_set<Type>;

  // Gathers, in parallel over all defined functions, the concrete result
  // types of every call and call_indirect.
  Types collectTypes() {
    ModuleUtils::ParallelFunctionAnalysis<Types> analysis(
      module, [&](Function* func, Types& types) {
        if (!func->body) {
          return;
        }
        struct TypeCollector : PostWalker<TypeCollector> {
          Types& types;
          TypeCollector(Types& types) : types(types) {}
          void visitCall(Call* curr) {
            if (curr->type.isConcrete()) {
              types.insert(curr->type);
            }
          }
          void visitCallIndirect(CallIndirect* curr) {
            if (curr->type.isConcrete()) {
              types.insert(curr->type);
            }
          }
        };
        TypeCollector(types).walk(func->body);
      });
    Types types;
    for (auto& pair : analysis.map) {
      Types& functionTypes = pair.second;
      for (auto t : functionTypes) {
        types.insert(t);
      }
    }
    return types;
  }
};
// Matches function names against a user-supplied list (remove-, add- or
// only-list). Entries containing '*' are wildcard patterns. All other entries
// are exact function names, compared in their escaped internal form.
class PatternMatcher {
public:
  // Prefix used in diagnostics, e.g. "remove" produces "removelist ...".
  std::string designation;
  // Exact (escaped) function names.
  std::set<Name> names;
  // Escaped wildcard patterns.
  std::set<std::string> patterns;
  // Patterns that matched at least one name passed to match().
  std::set<std::string> patternsMatched;
  // Escaped form -> original list entry, for diagnostics.
  std::map<std::string, std::string> unescaped;

  // Warns about exact names that do not exist in `module`. Aborts (Fatal) on
  // exact names that refer to imported functions.
  PatternMatcher(std::string designation,
                 Module& module,
                 const String::Split& list)
    : designation(std::move(designation)) {
    // NOTE: the parameter has been moved from. Use this->designation below.
    for (auto& name : list) {
      auto escaped = WasmBinaryBuilder::escape(name);
      unescaped[escaped.str] = name;
      if (name.find('*') != std::string::npos) {
        patterns.insert(escaped.str);
        continue;
      }
      auto* func = module.getFunctionOrNull(escaped);
      if (!func) {
        std::cerr << "warning: Asyncify " << this->designation
                  << "list contained a non-existing function name: " << name
                  << " (" << escaped << ")\n";
      } else if (func->imported()) {
        Fatal() << "Asyncify " << this->designation
                << "list contained an imported function name (use the import "
                   "list for imports): "
                << name << '\n';
      }
      names.insert(escaped.str);
    }
  }

  // True if funcName is listed exactly or matches a wildcard pattern. Records
  // the first pattern that matched so unmatched patterns can be reported.
  bool match(Name funcName) {
    if (names.count(funcName) > 0) {
      return true;
    }
    for (auto& pattern : patterns) {
      if (String::wildcardMatch(pattern, funcName.str)) {
        patternsMatched.insert(pattern);
        return true;
      }
    }
    return false;
  }

  // Warns about each wildcard pattern that never matched anything.
  void checkPatternsMatches() const {
    for (auto& pattern : patterns) {
      if (patternsMatched.count(pattern) == 0) {
        // Every pattern was recorded in `unescaped` by the constructor.
        std::cerr << "warning: Asyncify " << designation
                  << "list contained a non-matching pattern: "
                  << unescaped.at(pattern) << " (" << pattern << ")\n";
      }
    }
  }
};
class ModuleAnalyzer {
Module& module;
bool canIndirectChangeState;
struct Info
: public ModuleUtils::CallGraphPropertyAnalysis<Info>::FunctionInfo {
Name name;
bool canChangeState = false;
bool isBottomMostRuntime = false;
bool isTopMostRuntime = false;
bool inRemoveList = false;
bool addedFromList = false;
};
typedef std::map<Function*, Info> Map;
Map map;
public:
ModuleAnalyzer(Module& module,
std::function<bool(Name, Name)> canImportChangeState,
bool canIndirectChangeState,
const String::Split& removeListInput,
const String::Split& addListInput,
const String::Split& onlyListInput,
bool asserts,
bool verbose)
: module(module), canIndirectChangeState(canIndirectChangeState),
fakeGlobals(module), asserts(asserts), verbose(verbose) {
PatternMatcher removeList("remove", module, removeListInput);
PatternMatcher addList("add", module, addListInput);
PatternMatcher onlyList("only", module, onlyListInput);
std::map<Name, Name> renamings;
for (auto& func : module.functions) {
if (func->module == ASYNCIFY) {
if (func->base == START_UNWIND) {
renamings[func->name] = ASYNCIFY_START_UNWIND;
} else if (func->base == STOP_UNWIND) {
renamings[func->name] = ASYNCIFY_STOP_UNWIND;
} else if (func->base == START_REWIND) {
renamings[func->name] = ASYNCIFY_START_REWIND;
} else if (func->base == STOP_REWIND) {
renamings[func->name] = ASYNCIFY_STOP_REWIND;
} else {
Fatal() << "call to unidenfied asyncify import: " << func->base;
}
}
}
ModuleUtils::renameFunctions(module, renamings);
ModuleUtils::CallGraphPropertyAnalysis<Info> scanner(
module, [&](Function* func, Info& info) {
info.name = func->name;
if (func->imported()) {
if (func->module == ASYNCIFY &&
(func->base == START_UNWIND || func->base == STOP_REWIND)) {
info.canChangeState = true;
} else {
info.canChangeState =
canImportChangeState(func->module, func->base);
if (verbose && info.canChangeState) {
std::cout << "[asyncify] " << func->name
<< " is an import that can change the state\n";
}
}
return;
}
struct Walker : PostWalker<Walker> {
Info& info;
Module& module;
bool canIndirectChangeState;
Walker(Info& info, Module& module, bool canIndirectChangeState)
: info(info), module(module),
canIndirectChangeState(canIndirectChangeState) {}
void visitCall(Call* curr) {
if (curr->isReturn) {
Fatal() << "tail calls not yet supported in asyncify";
}
auto* target = module.getFunction(curr->target);
if (target->imported() && target->module == ASYNCIFY) {
if (target->base == START_UNWIND) {
info.canChangeState = true;
info.isTopMostRuntime = true;
} else if (target->base == STOP_UNWIND) {
info.isBottomMostRuntime = true;
} else if (target->base == START_REWIND) {
info.isBottomMostRuntime = true;
} else if (target->base == STOP_REWIND) {
info.canChangeState = true;
info.isTopMostRuntime = true;
} else {
WASM_UNREACHABLE("call to unidenfied asyncify import");
}
}
}
void visitCallIndirect(CallIndirect* curr) {
if (curr->isReturn) {
Fatal() << "tail calls not yet supported in asyncify";
}
if (canIndirectChangeState) {
info.canChangeState = true;
}
}
};
Walker walker(info, module, canIndirectChangeState);
walker.walk(func->body);
if (info.isBottomMostRuntime) {
info.canChangeState = false;
}
if (verbose && info.canChangeState) {
std::cout << "[asyncify] " << func->name
<< " can change the state due to initial scan\n";
}
});
for (auto& pair : scanner.map) {
auto* func = pair.first;
auto& info = pair.second;
if (removeList.match(func->name)) {
info.inRemoveList = true;
if (verbose && info.canChangeState) {
std::cout << "[asyncify] " << func->name
<< " is in the remove-list, ignore\n";
}
info.canChangeState = false;
}
}
std::vector<Name> funcsToDelete;
for (auto& pair : scanner.map) {
auto* func = pair.first;
auto& callsTo = pair.second.callsTo;
if (func->imported() && func->module == ASYNCIFY) {
funcsToDelete.push_back(func->name);
}
std::vector<Function*> callersToDelete;
for (auto* target : callsTo) {
if (target->imported() && target->module == ASYNCIFY) {
callersToDelete.push_back(target);
}
}
for (auto* target : callersToDelete) {
callsTo.erase(target);
}
}
for (auto name : funcsToDelete) {
module.removeFunction(name);
}
scanner.propagateBack([](const Info& info) { return info.canChangeState; },
[](const Info& info) {
return !info.isBottomMostRuntime &&
!info.inRemoveList;
},
[verbose](Info& info, Function* reason) {
if (verbose && !info.canChangeState) {
std::cout << "[asyncify] " << info.name
<< " can change the state due to "
<< reason->name << "\n";
}
info.canChangeState = true;
},
scanner.IgnoreIndirectCalls);
map.swap(scanner.map);
if (!onlyListInput.empty()) {
for (auto& func : module.functions) {
if (!func->imported()) {
auto& info = map[func.get()];
bool matched = onlyList.match(func->name);
info.canChangeState = matched;
if (matched) {
info.addedFromList = true;
}
if (verbose) {
std::cout << "[asyncify] " << func->name
<< "'s state is set based on the only-list to " << matched
<< '\n';
}
}
}
}
if (!addListInput.empty()) {
for (auto& func : module.functions) {
if (!func->imported() && addList.match(func->name)) {
auto& info = map[func.get()];
if (verbose && !info.canChangeState) {
std::cout << "[asyncify] " << func->name
<< " is in the add-list, add\n";
}
info.canChangeState = true;
info.addedFromList = true;
}
}
}
removeList.checkPatternsMatches();
addList.checkPatternsMatches();
onlyList.checkPatternsMatches();
}
bool needsInstrumentation(Function* func) {
auto& info = map[func];
return info.canChangeState && !info.isTopMostRuntime;
}
bool canChangeState(Expression* curr, Function* func) {
struct Walker : PostWalker<Walker> {
void visitCall(Call* curr) {
if (curr->target == ASYNCIFY_START_UNWIND ||
curr->target == ASYNCIFY_STOP_REWIND ||
curr->target == ASYNCIFY_GET_CALL_INDEX ||
curr->target == ASYNCIFY_CHECK_CALL_INDEX) {
canChangeState = true;
return;
}
if (curr->target == ASYNCIFY_STOP_UNWIND ||
curr->target == ASYNCIFY_START_REWIND) {
isBottomMostRuntime = true;
return;
}
auto* target = module->getFunctionOrNull(curr->target);
if (target && (*map)[target].canChangeState) {
canChangeState = true;
}
}
void visitCallIndirect(CallIndirect* curr) { hasIndirectCall = true; }
Module* module;
ModuleAnalyzer* analyzer;
Map* map;
bool hasIndirectCall = false;
bool canChangeState = false;
bool isBottomMostRuntime = false;
};
Walker walker;
walker.module = &module;
walker.analyzer = this;
walker.map = ↦
walker.walk(curr);
if (walker.hasIndirectCall &&
(canIndirectChangeState || map[func].addedFromList)) {
walker.canChangeState = true;
}
return walker.canChangeState && !walker.isBottomMostRuntime;
}
FakeGlobalHelper fakeGlobals;
bool asserts;
bool verbose;
};
// True if `curr` is a call or call_indirect, optionally wrapped in a single
// local.set or drop (the shapes produced by flat IR).
static bool doesCall(Expression* curr) {
  Expression* inner = curr;
  if (auto* set = curr->dynCast<LocalSet>()) {
    inner = set->value;
  } else if (auto* drop = curr->dynCast<Drop>()) {
    inner = drop->value;
  }
  const bool isDirect = inner->is<Call>();
  return isDirect || inner->is<CallIndirect>();
}
// Builder extended with helpers for the asyncify data structure (pointed to by
// ASYNCIFY_DATA) and the ASYNCIFY_STATE global.
class AsyncifyBuilder : public Builder {
public:
  AsyncifyBuilder(Module& wasm) : Builder(wasm) {}

  // Loads the current stack position: the i32 at DataOffset::BStackPos.
  Expression* makeGetStackPos() {
    auto* dataPtr = makeGlobalGet(ASYNCIFY_DATA, Type::i32);
    return makeLoad(
      4, false, int32_t(DataOffset::BStackPos), 4, dataPtr, Type::i32);
  }

  // Adds `by` (which may be negative) to the stored stack position.
  // Emits a nop when `by` is 0.
  Expression* makeIncStackPos(int32_t by) {
    if (by == 0) {
      return makeNop();
    }
    auto* dataPtr = makeGlobalGet(ASYNCIFY_DATA, Type::i32);
    auto* newPos =
      makeBinary(AddInt32, makeGetStackPos(), makeConst(Literal(by)));
    return makeStore(
      4, int32_t(DataOffset::BStackPos), 4, dataPtr, newPos, Type::i32);
  }

  // i32 expression: ASYNCIFY_STATE == value.
  Expression* makeStateCheck(State value) {
    auto* current = makeGlobalGet(ASYNCIFY_STATE, Type::i32);
    auto* expected = makeConst(Literal(int32_t(value)));
    return makeBinary(EqInt32, current, expected);
  }

  // i32 expression: !(ASYNCIFY_STATE == value).
  Expression* makeNegatedStateCheck(State value) {
    auto* check = makeStateCheck(value);
    return makeUnary(EqZInt32, check);
  }
};
// First instrumentation stage (function-parallel). For each function the
// analyzer says needs instrumentation, it rewrites the control flow so that:
//  * when rewinding, execution skips forward through all code that cannot
//    change the state, and re-enters the call whose index matches the saved
//    one;
//  * after each state-changing call, if unwinding began, it emits a fake
//    __asyncify_unwind intrinsic carrying the call index.
// The intrinsics are replaced with real code later by AsyncifyLocals.
// Functions that need no instrumentation are left alone, or get runtime
// asserts when the analyzer's `asserts` flag is set.
struct AsyncifyFlow : public Pass {
bool isFunctionParallel() override { return true; }
ModuleAnalyzer* analyzer;
AsyncifyFlow* create() override { return new AsyncifyFlow(analyzer); }
AsyncifyFlow(ModuleAnalyzer* analyzer) : analyzer(analyzer) {}
void
runOnFunction(PassRunner* runner, Module* module_, Function* func_) override {
module = module_;
func = func_;
builder = make_unique<AsyncifyBuilder>(*module);
if (!analyzer->needsInstrumentation(func)) {
if (analyzer->asserts) {
addAssertsInNonInstrumented(func);
}
return;
}
// New body: when rewinding, first pop the saved call index (fake
// intrinsic), then run the processed original body.
auto* block = builder->makeBlock(
{builder->makeIf(builder->makeStateCheck(
State::Rewinding), makeCallIndexPop()),
process(func->body)});
// Rewritten control flow may fall off the end; end a value-returning
// function with an unreachable so the body still validates.
if (func->sig.results != Type::none) {
block->list.push_back(builder->makeUnreachable());
}
block->finalize();
func->body = block;
// Conditional execution can alter expression types; recompute them.
ReFinalize().walkFunctionInModule(func, module);
}
private:
std::unique_ptr<AsyncifyBuilder> builder;
Module* module;
Function* func;
// Sequential index given to each instrumented call in the function, in the
// order makeCallSupport() visits them.
Index callIndex = 0;
// Rewrites `curr` so that, when rewinding, execution passes linearly through
// it, only re-entering the call matching the saved index. Code that cannot
// change the state is wrapped in "if (state == Normal)". Relies on flat IR:
// state changes can only be in blocks, if arms, loop bodies or calls (the if
// condition is asserted not to change state).
Expression* process(Expression* curr) {
if (!analyzer->canChangeState(curr, func)) {
return makeMaybeSkip(curr);
}
if (auto* block = curr->dynCast<Block>()) {
// Process children that may change state individually. Group each
// maximal run [i, end) of children that cannot into one skippable
// block, leaving nops in the vacated slots.
Index i = 0;
auto& list = block->list;
while (i < list.size()) {
if (analyzer->canChangeState(list[i], func)) {
list[i] = process(list[i]);
i++;
} else {
Index end = i + 1;
while (end < list.size() &&
!analyzer->canChangeState(list[end], func)) {
end++;
}
if (end == i + 1) {
list[i] = makeMaybeSkip(list[i]);
} else {
auto* block = builder->makeBlock();
for (auto j = i; j < end; j++) {
block->list.push_back(list[j]);
}
block->finalize();
list[i] = makeMaybeSkip(block);
for (auto j = i + 1; j < end; j++) {
list[j] = builder->makeNop();
}
}
i = end;
}
}
return block;
} else if (auto* iff = curr->dynCast<If>()) {
assert(!analyzer->canChangeState(iff->condition, func));
// if without else: enter the arm when the condition holds OR when
// rewinding.
if (!iff->ifFalse) {
iff->condition = builder->makeBinary(
OrInt32, iff->condition, builder->makeStateCheck(State::Rewinding));
iff->ifTrue = process(iff->ifTrue);
iff->finalize();
return iff;
}
// if/else is split into two ifs. The condition is evaluated once (only
// when running normally) into a temp local. Each arm is entered when its
// side of the condition holds OR when rewinding.
auto conditionTemp = builder->addVar(func, Type::i32);
auto* pre =
makeMaybeSkip(builder->makeLocalSet(conditionTemp, iff->condition));
iff->condition = builder->makeLocalGet(conditionTemp, Type::i32);
iff->condition = builder->makeBinary(
OrInt32, iff->condition, builder->makeStateCheck(State::Rewinding));
iff->ifTrue = process(iff->ifTrue);
auto* otherArm = iff->ifFalse;
iff->ifFalse = nullptr;
iff->finalize();
auto* otherIf = builder->makeIf(
builder->makeBinary(
OrInt32,
builder->makeUnary(EqZInt32,
builder->makeLocalGet(conditionTemp, Type::i32)),
builder->makeStateCheck(State::Rewinding)),
process(otherArm));
otherIf->finalize();
return builder->makeBlock({pre, iff, otherIf});
} else if (auto* loop = curr->dynCast<Loop>()) {
loop->body = process(loop->body);
return loop;
} else if (doesCall(curr)) {
return makeCallSupport(curr);
}
// Everything that can change the state must have been handled above.
WASM_UNREACHABLE("unexpected expression type");
}
// Wraps `curr` so it only runs when the state is Normal (i.e. it is
// skipped while rewinding).
Expression* makeMaybeSkip(Expression* curr) {
return builder->makeIf(builder->makeStateCheck(State::Normal), curr);
}
// Instruments one call (optionally wrapped in a local.set or drop). The call
// runs when the state is Normal, or when rewinding and its index matches the
// saved one. Afterwards, if unwinding began, the unwind intrinsic is emitted.
// For a local.set, the result goes through a fake global and the set itself
// only runs when not unwinding.
Expression* makeCallSupport(Expression* curr) {
assert(doesCall(curr));
assert(curr->type == Type::none);
auto* set = curr->dynCast<LocalSet>();
if (set) {
auto name = analyzer->fakeGlobals.getName(set->value->type);
curr = builder->makeGlobalSet(name, set->value);
set->value = builder->makeGlobalGet(name, set->value->type);
}
auto index = callIndex++;
curr = builder->makeIf(
builder->makeIf(builder->makeStateCheck(State::Normal),
builder->makeConst(int32_t(1)),
makeCallIndexPeek(index)),
builder->makeSequence(curr, makePossibleUnwind(index, set)));
return curr;
}
// if (state == Unwinding) __asyncify_unwind(index) else ifNotUnwinding
// (ifNotUnwinding may be null).
Expression* makePossibleUnwind(Index index, Expression* ifNotUnwinding) {
return builder->makeIf(
builder->makeStateCheck(State::Unwinding),
builder->makeCall(
ASYNCIFY_UNWIND, {builder->makeConst(int32_t(index))}, Type::none),
ifNotUnwinding);
}
// Fake intrinsic returning whether the saved call index equals `index`.
Expression* makeCallIndexPeek(Index index) {
return builder->makeCall(ASYNCIFY_CHECK_CALL_INDEX,
{builder->makeConst(int32_t(index))},
Type::i32);
}
// Fake intrinsic that pops the saved call index.
Expression* makeCallIndexPop() {
return builder->makeCall(ASYNCIFY_GET_CALL_INDEX, {}, Type::none);
}
// For a function that is not instrumented: record ASYNCIFY_STATE on entry,
// and trap (unreachable) if the state differs from that value after any
// call or call_indirect.
void addAssertsInNonInstrumented(Function* func) {
auto oldState = builder->addVar(func, Type::i32);
func->body = builder->makeSequence(
builder->makeLocalSet(oldState,
builder->makeGlobalGet(ASYNCIFY_STATE, Type::i32)),
func->body);
struct Walker : PostWalker<Walker> {
void visitCall(Call* curr) {
assert(!curr->isReturn);
handleCall(curr);
}
void visitCallIndirect(CallIndirect* curr) {
assert(!curr->isReturn);
handleCall(curr);
}
// Replaces `call` with call + check, preserving a concrete result through
// a temp local.
void handleCall(Expression* call) {
auto* check = builder->makeIf(
builder->makeBinary(NeInt32,
builder->makeGlobalGet(ASYNCIFY_STATE, Type::i32),
builder->makeLocalGet(oldState, Type::i32)),
builder->makeUnreachable());
Expression* rep;
if (call->type.isConcrete()) {
auto temp = builder->addVar(func, call->type);
rep = builder->makeBlock({
builder->makeLocalSet(temp, call),
check,
builder->makeLocalGet(temp, call->type),
});
} else {
rep = builder->makeSequence(call, check);
}
replaceCurrent(rep);
}
Function* func;
AsyncifyBuilder* builder;
Index oldState;
};
Walker walker;
walker.func = func;
walker.builder = builder.get();
walker.oldState = oldState;
walker.walk(func->body);
}
};
// Second instrumentation stage (function-parallel), run after AsyncifyFlow.
// For each instrumented function it:
//  * replaces the fake intrinsics (__asyncify_unwind and the call-index
//    get/check calls) with real code,
//  * replaces accesses to the fake call globals with a per-function local,
//  * wraps the body so that locals live at a possible unwind point are
//    loaded from the asyncify stack when rewinding, and saved to it (after
//    the call index) when unwinding.
struct AsyncifyLocals : public WalkerPass<PostWalker<AsyncifyLocals>> {
  bool isFunctionParallel() override { return true; }

  ModuleAnalyzer* analyzer;

  AsyncifyLocals* create() override { return new AsyncifyLocals(analyzer); }

  AsyncifyLocals(ModuleAnalyzer* analyzer) : analyzer(analyzer) {}

  void visitCall(Call* curr) {
    if (curr->target == ASYNCIFY_UNWIND) {
      // Break out to the unwind block, carrying the call index.
      replaceCurrent(builder->makeBreak(ASYNCIFY_UNWIND, curr->operands[0]));
    } else if (curr->target == ASYNCIFY_GET_CALL_INDEX) {
      // Pop the saved call index off the asyncify stack into rewindIndex.
      replaceCurrent(builder->makeSequence(
        builder->makeIncStackPos(-4),
        builder->makeLocalSet(
          rewindIndex,
          builder->makeLoad(
            4, false, 0, 4, builder->makeGetStackPos(), Type::i32))));
    } else if (curr->target == ASYNCIFY_CHECK_CALL_INDEX) {
      // Compare the popped call index with this call's constant index.
      replaceCurrent(builder->makeBinary(
        EqInt32,
        builder->makeLocalGet(rewindIndex, Type::i32),
        builder->makeConst(
          Literal(int32_t(curr->operands[0]->cast<Const>()->value.geti32())))));
    }
  }

  // Writes to a fake call global become writes to the matching local.
  void visitGlobalSet(GlobalSet* curr) {
    auto type = analyzer->fakeGlobals.getTypeOrNone(curr->name);
    if (type != Type::none) {
      replaceCurrent(
        builder->makeLocalSet(getFakeCallLocal(type), curr->value));
    }
  }

  // Reads of a fake call global become reads of the matching local.
  void visitGlobalGet(GlobalGet* curr) {
    auto type = analyzer->fakeGlobals.getTypeOrNone(curr->name);
    if (type != Type::none) {
      replaceCurrent(builder->makeLocalGet(getFakeCallLocal(type), type));
    }
  }

  // Returns (creating on first use) this function's local for `type`.
  Index getFakeCallLocal(Type type) {
    auto iter = fakeCallLocals.find(type);
    if (iter != fakeCallLocals.end()) {
      return iter->second;
    }
    return fakeCallLocals[type] = builder->addVar(getFunction(), type);
  }

  void doWalkFunction(Function* func) {
    if (!analyzer->needsInstrumentation(func)) {
      return;
    }
    // Collect the locals to save/restore before any helper locals are added.
    findRelevantLiveLocals(func);
    // Create the builder before using it. A fresh pass instance starts with a
    // null `builder`, so calling addVar through it first was invalid.
    builder = make_unique<AsyncifyBuilder>(*getModule());
    auto unwindIndex = builder->addVar(func, Type::i32);
    rewindIndex = builder->addVar(func, Type::i32);
    walk(func->body);
    // Normal completion must not fall into the unwind postamble.
    Expression* barrier;
    if (func->sig.results == Type::none) {
      barrier = builder->makeReturn();
    } else {
      barrier = builder->makeUnreachable();
    }
    // Layout:
    //   if (rewinding) { load saved locals }
    //   unwindIndex = block $__asyncify_unwind { body; barrier }
    //   push unwindIndex; save locals; [zero result]
    auto* newBody = builder->makeBlock(
      {builder->makeIf(builder->makeStateCheck(State::Rewinding),
                       makeLocalLoading()),
       builder->makeLocalSet(
         unwindIndex,
         builder->makeBlock(ASYNCIFY_UNWIND,
                            builder->makeSequence(func->body, barrier))),
       makeCallIndexPush(unwindIndex),
       makeLocalSaving()});
    if (func->sig.results != Type::none) {
      // After an unwind, the function still has to produce a value.
      newBody->list.push_back(
        LiteralUtils::makeZero(func->sig.results, *getModule()));
      newBody->finalize(func->sig.results);
    }
    func->body = newBody;
    ReFinalize().walkFunctionInModule(func, getModule());
  }

private:
  std::unique_ptr<AsyncifyBuilder> builder;

  // Local holding the call index popped when rewinding.
  Index rewindIndex;
  // Per-type locals replacing the fake call globals.
  std::map<Type, Index> fakeCallLocals;
  // Locals live at the start of a basic block containing a call-index check.
  std::set<Index> relevantLiveLocals;

  // Fills relevantLiveLocals using a liveness analysis: the locals live at the
  // start of any basic block that contains an ASYNCIFY_CHECK_CALL_INDEX call.
  void findRelevantLiveLocals(Function* func) {
    struct RelevantLiveLocalsWalker
      : public LivenessWalker<RelevantLiveLocalsWalker,
                              Visitor<RelevantLiveLocalsWalker>> {
      std::set<BasicBlock*> relevantBasicBlocks;
      void visitCall(Call* curr) {
        if (!currBasicBlock) {
          return;
        }
        if (curr->target == ASYNCIFY_CHECK_CALL_INDEX) {
          relevantBasicBlocks.insert(currBasicBlock);
        }
      }
    };
    RelevantLiveLocalsWalker walker;
    walker.walkFunctionInModule(func, getModule());
    for (auto* block : walker.liveBlocks) {
      if (walker.relevantBasicBlocks.count(block)) {
        for (auto local : block->contents.start) {
          relevantLiveLocals.insert(local);
        }
      }
    }
  }

  // Pops the relevant locals off the asyncify stack (in increasing local
  // index order, matching makeLocalSaving) and assigns them.
  Expression* makeLocalLoading() {
    if (relevantLiveLocals.empty()) {
      return builder->makeNop();
    }
    auto* func = getFunction();
    auto numLocals = func->getNumLocals();
    Index total = 0;
    for (Index i = 0; i < numLocals; i++) {
      if (!relevantLiveLocals.count(i)) {
        continue;
      }
      total += func->getLocalType(i).getByteSize();
    }
    auto* block = builder->makeBlock();
    // Negate as a signed value rather than wrapping an unsigned Index.
    block->list.push_back(builder->makeIncStackPos(-int32_t(total)));
    auto tempIndex = builder->addVar(func, Type::i32);
    block->list.push_back(
      builder->makeLocalSet(tempIndex, builder->makeGetStackPos()));
    Index offset = 0;
    for (Index i = 0; i < numLocals; i++) {
      if (!relevantLiveLocals.count(i)) {
        continue;
      }
      auto localType = func->getLocalType(i);
      SmallVector<Expression*, 1> loads;
      for (const auto& type : localType) {
        auto size = type.getByteSize();
        assert(size % STACK_ALIGN == 0);
        loads.push_back(
          builder->makeLoad(size,
                            true,
                            offset,
                            STACK_ALIGN,
                            builder->makeLocalGet(tempIndex, Type::i32),
                            type));
        offset += size;
      }
      Expression* load;
      if (loads.size() == 1) {
        load = loads[0];
      } else if (localType.size() > 1) {
        load = builder->makeTupleMake(std::move(loads));
      } else {
        WASM_UNREACHABLE("Unexpected empty type");
      }
      block->list.push_back(builder->makeLocalSet(i, load));
    }
    block->finalize();
    return block;
  }

  // Pushes the relevant locals onto the asyncify stack (tuple locals one
  // element at a time) and advances the stack position past them.
  Expression* makeLocalSaving() {
    if (relevantLiveLocals.empty()) {
      return builder->makeNop();
    }
    auto* func = getFunction();
    auto numLocals = func->getNumLocals();
    auto* block = builder->makeBlock();
    auto tempIndex = builder->addVar(func, Type::i32);
    block->list.push_back(
      builder->makeLocalSet(tempIndex, builder->makeGetStackPos()));
    Index offset = 0;
    for (Index i = 0; i < numLocals; i++) {
      if (!relevantLiveLocals.count(i)) {
        continue;
      }
      auto localType = func->getLocalType(i);
      size_t j = 0;
      for (const auto& type : localType) {
        auto size = type.getByteSize();
        Expression* localGet = builder->makeLocalGet(i, localType);
        if (localType.size() > 1) {
          localGet = builder->makeTupleExtract(localGet, j);
        }
        assert(size % STACK_ALIGN == 0);
        block->list.push_back(
          builder->makeStore(size,
                             offset,
                             STACK_ALIGN,
                             builder->makeLocalGet(tempIndex, Type::i32),
                             localGet,
                             type));
        offset += size;
        ++j;
      }
    }
    block->list.push_back(builder->makeIncStackPos(offset));
    block->finalize();
    return block;
  }

  // Pushes the i32 in local `tempIndex` (the unwound call index) onto the
  // asyncify stack.
  Expression* makeCallIndexPush(Index tempIndex) {
    return builder->makeSequence(
      builder->makeStore(4,
                         0,
                         4,
                         builder->makeGetStackPos(),
                         builder->makeLocalGet(tempIndex, Type::i32),
                         Type::i32),
      builder->makeIncStackPos(4));
  }
};
}
// Returns "module.base", the form used to match entries of the
// asyncify-imports list.
static std::string getFullImportName(Name module, Name base) {
  std::string full(module.str);
  full += '.';
  full += base.str;
  return full;
}
struct Asyncify : public Pass {
void run(PassRunner* runner, Module* module) override {
bool optimize = runner->options.optimizeLevel > 0;
MemoryUtils::ensureExists(module->memory);
auto stateChangingImports = String::trim(read_possible_response_file(
runner->options.getArgumentOrDefault("asyncify-imports", "")));
auto ignoreImports =
runner->options.getArgumentOrDefault("asyncify-ignore-imports", "");
bool allImportsCanChangeState =
stateChangingImports == "" && ignoreImports == "";
String::Split listedImports(stateChangingImports, ",");
auto ignoreIndirect = runner->options.getArgumentOrDefault(
"asyncify-ignore-indirect", "") == "";
std::string removeListInput =
runner->options.getArgumentOrDefault("asyncify-removelist", "");
if (removeListInput.empty()) {
removeListInput =
runner->options.getArgumentOrDefault("asyncify-blacklist", "");
}
String::Split removeList(
String::trim(read_possible_response_file(removeListInput)), ",");
String::Split addList(
String::trim(read_possible_response_file(
runner->options.getArgumentOrDefault("asyncify-addlist", ""))),
",");
std::string onlyListInput =
runner->options.getArgumentOrDefault("asyncify-onlylist", "");
if (onlyListInput.empty()) {
onlyListInput =
runner->options.getArgumentOrDefault("asyncify-whitelist", "");
}
String::Split onlyList(
String::trim(read_possible_response_file(onlyListInput)), ",");
auto asserts =
runner->options.getArgumentOrDefault("asyncify-asserts", "") != "";
auto verbose =
runner->options.getArgumentOrDefault("asyncify-verbose", "") != "";
removeList = handleBracketingOperators(removeList);
addList = handleBracketingOperators(addList);
onlyList = handleBracketingOperators(onlyList);
if (!onlyList.empty() && (!removeList.empty() || !addList.empty())) {
Fatal() << "It makes no sense to use both an asyncify only-list together "
"with another list.";
}
auto canImportChangeState = [&](Name module, Name base) {
if (allImportsCanChangeState) {
return true;
}
auto full = getFullImportName(module, base);
for (auto& listedImport : listedImports) {
if (String::wildcardMatch(listedImport, full)) {
return true;
}
}
return false;
};
ModuleAnalyzer analyzer(*module,
canImportChangeState,
ignoreIndirect,
removeList,
addList,
onlyList,
asserts,
verbose);
addGlobals(module);
{
PassRunner runner(module);
runner.add("flatten");
runner.add("dce");
if (optimize) {
runner.add("simplify-locals-nonesting");
runner.add("reorder-locals");
runner.add("coalesce-locals");
runner.add("simplify-locals-nonesting");
runner.add("reorder-locals");
runner.add("merge-blocks");
}
runner.add(make_unique<AsyncifyFlow>(&analyzer));
runner.setIsNested(true);
runner.setValidateGlobally(false);
runner.run();
}
{
PassRunner runner(module);
if (optimize) {
runner.addDefaultFunctionOptimizationPasses();
}
runner.add(make_unique<AsyncifyLocals>(&analyzer));
if (optimize) {
runner.addDefaultFunctionOptimizationPasses();
}
runner.setIsNested(true);
runner.setValidateGlobally(false);
runner.run();
}
addFunctions(module);
}
private:
void addGlobals(Module* module) {
Builder builder(*module);
module->addGlobal(builder.makeGlobal(ASYNCIFY_STATE,
Type::i32,
builder.makeConst(int32_t(0)),
Builder::Mutable));
module->addGlobal(builder.makeGlobal(ASYNCIFY_DATA,
Type::i32,
builder.makeConst(int32_t(0)),
Builder::Mutable));
}
void addFunctions(Module* module) {
Builder builder(*module);
auto makeFunction = [&](Name name, bool setData, State state) {
std::vector<Type> params;
if (setData) {
params.push_back(Type::i32);
}
auto* body = builder.makeBlock();
body->list.push_back(builder.makeGlobalSet(
ASYNCIFY_STATE, builder.makeConst(int32_t(state))));
if (setData) {
body->list.push_back(builder.makeGlobalSet(
ASYNCIFY_DATA, builder.makeLocalGet(0, Type::i32)));
}
auto* stackPos =
builder.makeLoad(4,
false,
int32_t(DataOffset::BStackPos),
4,
builder.makeGlobalGet(ASYNCIFY_DATA, Type::i32),
Type::i32);
auto* stackEnd =
builder.makeLoad(4,
false,
int32_t(DataOffset::BStackEnd),
4,
builder.makeGlobalGet(ASYNCIFY_DATA, Type::i32),
Type::i32);
body->list.push_back(
builder.makeIf(builder.makeBinary(GtUInt32, stackPos, stackEnd),
builder.makeUnreachable()));
body->finalize();
auto* func = builder.makeFunction(
name, Signature(Type(params), Type::none), {}, body);
module->addFunction(func);
module->addExport(builder.makeExport(name, name, ExternalKind::Function));
};
makeFunction(ASYNCIFY_START_UNWIND, true, State::Unwinding);
makeFunction(ASYNCIFY_STOP_UNWIND, false, State::Normal);
makeFunction(ASYNCIFY_START_REWIND, true, State::Rewinding);
makeFunction(ASYNCIFY_STOP_REWIND, false, State::Normal);
module->addFunction(
builder.makeFunction(ASYNCIFY_GET_STATE,
Signature(Type::none, Type::i32),
{},
builder.makeGlobalGet(ASYNCIFY_STATE, Type::i32)));
module->addExport(builder.makeExport(
ASYNCIFY_GET_STATE, ASYNCIFY_GET_STATE, ExternalKind::Function));
}
};
// Factory for the main Asyncify pass. Returns a newly allocated pass;
// presumably the caller (pass registry) takes ownership.
Pass* createAsyncifyPass() {
  Pass* pass = new Asyncify();
  return pass;
}
// Post-processing for code already instrumented by Asyncify. Under
// compile-time assumptions it folds comparisons of the asyncify state global
// into constants:
//  * neverUnwind / neverRewind: state == Unwinding / Rewinding is false;
//  * importsAlwaysUnwind: right after a call to an import (within one linear
//    execution trace), the next "state == Unwinding" check is true.
template<bool neverRewind, bool neverUnwind, bool importsAlwaysUnwind>
struct ModAsyncify
  : public WalkerPass<LinearExecutionWalker<
      ModAsyncify<neverRewind, neverUnwind, importsAlwaysUnwind>>> {
  bool isFunctionParallel() override { return true; }

  ModAsyncify* create() override { return new ModAsyncify(); }

  void doWalkFunction(Function* func) {
    // The state global is the only global written by the exported
    // stop_unwind function; take its name from there.
    Module* wasm = this->getModule();
    auto* stopUnwindExport = wasm->getExport(ASYNCIFY_STOP_UNWIND);
    auto* stopUnwind = wasm->getFunction(stopUnwindExport->value);
    FindAll<GlobalSet> stateSets(stopUnwind->body);
    assert(stateSets.list.size() == 1);
    asyncifyStateName = stateSets.list.front()->name;
    this->walk(func->body);
  }

  // Folds (state == K) / (state != K) when the result is known.
  void visitBinary(Binary* curr) {
    const bool negated = curr->op == NeInt32;
    if (!negated && curr->op != EqInt32) {
      return;
    }
    auto* constant = curr->right->dynCast<Const>();
    if (constant == nullptr) {
      return;
    }
    auto* stateGet = curr->left->dynCast<GlobalGet>();
    if (stateGet == nullptr || stateGet->name != asyncifyStateName) {
      return;
    }
    const auto checked = constant->value.geti32();
    const bool impossible = (neverUnwind && checked == int(State::Unwinding)) ||
                            (neverRewind && checked == int(State::Rewinding));
    int32_t known;
    if (impossible) {
      known = 0;
    } else if (this->unwinding && checked == int(State::Unwinding)) {
      known = 1;
      unsetUnwinding();
    } else {
      return;
    }
    known = negated ? 1 - known : known;
    this->replaceCurrent(Builder(*this->getModule()).makeConst(int32_t(known)));
  }

  // A select on the raw state global: with neverRewind, force condition 0.
  void visitSelect(Select* curr) {
    auto* stateGet = curr->condition->dynCast<GlobalGet>();
    const bool onState =
      stateGet != nullptr && stateGet->name == asyncifyStateName;
    if (onState && neverRewind) {
      curr->condition = Builder(*this->getModule()).makeConst(int32_t(0));
    }
  }

  // Any call resets tracking. A call to an import starts it again when
  // imports always unwind.
  void visitCall(Call* curr) {
    unsetUnwinding();
    if (importsAlwaysUnwind &&
        this->getModule()->getFunction(curr->target)->imported()) {
      this->unwinding = true;
    }
  }

  void visitCallIndirect(CallIndirect* curr) { unsetUnwinding(); }

  // Hook called by LinearExecutionWalker at control-flow splits.
  static void doNoteNonLinear(ModAsyncify* self, Expression**) {
    self->unsetUnwinding();
  }

  void visitGlobalSet(GlobalSet*) { unsetUnwinding(); }

private:
  // Name of the asyncify state global, found in doWalkFunction().
  Name asyncifyStateName;
  // Whether an import call that must unwind was just seen in this trace.
  bool unwinding = false;

  void unsetUnwinding() { this->unwinding = false; }
};
// Variant assuming rewinding never happens and every import call unwinds
// (neverRewind = true, neverUnwind = false, importsAlwaysUnwind = true).
Pass* createModAsyncifyAlwaysOnlyUnwindPass() {
  using AlwaysOnlyUnwind = ModAsyncify<true, false, true>;
  return new AlwaysOnlyUnwind();
}
// A no-op pass: run() does nothing. Note that
// createModAsyncifyNeverUnwindPass() returns a ModAsyncify instance, not this
// type.
struct ModAsyncifyNeverUnwind : public Pass {
  void run(PassRunner* /*runner*/, Module* /*module*/) override {}
};
// Variant assuming unwinding never happens
// (neverRewind = false, neverUnwind = true, importsAlwaysUnwind = false).
Pass* createModAsyncifyNeverUnwindPass() {
  using NeverUnwind = ModAsyncify<false, true, false>;
  return new NeverUnwind();
}
}