#include "abi/js.h"
#include "ir/abstract.h"
#include "ir/import-utils.h"
#include "ir/names.h"
#include "pass.h"
#include "shared-constants.h"
#include "support/debug.h"
#include "wasm-emscripten.h"
#define DEBUG_TYPE "stack-check"
namespace wasm {
// Name of the function this pass generates and exports. It takes two
// parameters of the stack pointer's type and stores them into the stack base
// and stack limit globals the pass creates.
static Name SET_STACK_LIMITS("__set_stack_limits");
// Ensures the module imports `env.<name>` as a function with signature `sig`.
// If an import with that module/base pair is already present, the module is
// left untouched; otherwise a new imported function whose internal name is
// also `name` is appended.
static void
importStackOverflowHandler(Module& module, Name name, Signature sig) {
  ImportInfo imports(module);
  if (imports.getImportedFunction(ENV, name)) {
    // Already imported; nothing to add.
    return;
  }
  auto handlerFunc = Builder::makeFunction(name, sig, {});
  handlerFunc->module = ENV;
  handlerFunc->base = name;
  module.addFunction(std::move(handlerFunc));
}
// Adds `function` to `module` and exports it as a function under its own
// name. Ownership of `function` is transferred to the module.
static void addExportedFunction(Module& module,
                                std::unique_ptr<Function> function) {
  // Grab the name before the unique_ptr is moved into the module.
  Name exportedName = function->name;
  module.addFunction(std::move(function));
  module.addExport(
    Builder::makeExport(exportedName, exportedName, ExternalKind::Function));
}
// Function-parallel walker that replaces every `global.set` of the stack
// pointer with a bounds check followed by the original store. When the new
// value is (unsigned) greater than the stack base global or less than the
// stack limit global, the configured handler is called with the new value;
// if no handler name was given, the code traps via `unreachable` instead.
struct EnforceStackLimits : public WalkerPass<PostWalker<EnforceStackLimits>> {
  EnforceStackLimits(const Global* stackPointer,
                     const Global* stackBase,
                     const Global* stackLimit,
                     Builder& builder,
                     Name handler)
    : stackPointer(stackPointer), stackBase(stackBase), stackLimit(stackLimit),
      builder(builder), handler(handler) {}

  bool isFunctionParallel() override { return true; }

  // The only locals this pass adds have the stack pointer's type, which is
  // presumably never a non-nullable reference, so no fixups are requested.
  bool requiresNonNullableLocalFixups() override { return false; }

  // Each parallel worker gets its own instance sharing the same globals,
  // builder and handler name.
  std::unique_ptr<Pass> create() override {
    return std::make_unique<EnforceStackLimits>(
      stackPointer, stackBase, stackLimit, builder, handler);
  }

  // Builds the replacement for `global.set $sp (value)`:
  //
  //   (block
  //     (if (i32.or (gt_u (local.tee $new (value)) (global.get $base))
  //                 (lt_u (local.get $new) (global.get $limit)))
  //       (then <call $handler (local.get $new)> | unreachable))
  //     (global.set $sp (local.get $new)))
  //
  // A fresh local holds the new value so that `value` is evaluated exactly
  // once even though it is used by both comparisons and the final store.
  Expression* stackBoundsCheck(Function* func, Expression* value) {
    auto newSP = Builder::addVar(func, stackPointer->type);
    Expression* handlerExpr;
    if (handler.is()) {
      // StackCheck::run imports the handler with a `(sp) -> none` signature,
      // so the call must have type `none`: giving it the stack pointer's type
      // would mismatch the callee's result type and make the `then` arm of an
      // else-less `if` produce a value, both of which fail validation.
      handlerExpr =
        builder.makeCall(handler,
                         {builder.makeLocalGet(newSP, stackPointer->type)},
                         Type::none);
    } else {
      handlerExpr = builder.makeUnreachable();
    }
    // Comparisons yield i32 regardless of operand type, hence OrInt32.
    auto check = builder.makeIf(
      builder.makeBinary(
        BinaryOp::OrInt32,
        builder.makeBinary(
          Abstract::getBinary(stackPointer->type, Abstract::GtU),
          builder.makeLocalTee(newSP, value, stackPointer->type),
          builder.makeGlobalGet(stackBase->name, stackBase->type)),
        builder.makeBinary(
          Abstract::getBinary(stackPointer->type, Abstract::LtU),
          builder.makeLocalGet(newSP, stackPointer->type),
          builder.makeGlobalGet(stackLimit->name, stackLimit->type))),
      handlerExpr);
    auto newSet = builder.makeGlobalSet(
      stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type));
    return builder.blockify(check, newSet);
  }

  // Instruments stores to the stack pointer global only; all other
  // `global.set`s are left alone.
  void visitGlobalSet(GlobalSet* curr) {
    if (getModule()->getGlobalOrNull(curr->name) == stackPointer) {
      replaceCurrent(stackBoundsCheck(getFunction(), curr->value));
    }
  }

private:
  const Global* stackPointer;
  const Global* stackBase;
  const Global* stackLimit;
  Builder& builder;
  // Internal name of the overflow handler; empty means "trap instead".
  Name handler;
};
struct StackCheck : public Pass {
bool addsEffects() override { return true; }
void run(Module* module) override {
Global* stackPointer = getStackPointerGlobal(*module);
if (!stackPointer) {
BYN_DEBUG(std::cerr << "no stack pointer found\n");
return;
}
auto stackBaseName = Names::getValidGlobalName(*module, "__stack_base");
auto stackLimitName = Names::getValidGlobalName(*module, "__stack_limit");
Name handler;
auto handlerName =
getPassOptions().getArgumentOrDefault("stack-check-handler", "");
if (handlerName != "") {
handler = handlerName;
importStackOverflowHandler(
*module, handler, Signature({stackPointer->type}, Type::none));
}
Builder builder(*module);
Type indexType =
module->memories.empty() ? Type::i32 : module->memories[0]->indexType;
auto stackBase =
module->addGlobal(builder.makeGlobal(stackBaseName,
stackPointer->type,
builder.makeConstPtr(0, indexType),
Builder::Mutable));
auto stackLimit =
module->addGlobal(builder.makeGlobal(stackLimitName,
stackPointer->type,
builder.makeConstPtr(0, indexType),
Builder::Mutable));
EnforceStackLimits(stackPointer, stackBase, stackLimit, builder, handler)
.run(getPassRunner(), module);
auto limitsFunc = builder.makeFunction(
SET_STACK_LIMITS,
Signature({stackPointer->type, stackPointer->type}, Type::none),
{});
auto* getBase = builder.makeLocalGet(0, stackPointer->type);
auto* storeBase = builder.makeGlobalSet(stackBaseName, getBase);
auto* getLimit = builder.makeLocalGet(1, stackPointer->type);
auto* storeLimit = builder.makeGlobalSet(stackLimitName, getLimit);
limitsFunc->body = builder.makeBlock({storeBase, storeLimit});
addExportedFunction(*module, std::move(limitsFunc));
}
};
// Returns a newly allocated StackCheck pass.
Pass* createStackCheckPass() {
  auto* pass = new StackCheck;
  return pass;
}
}