#include "abi/js.h"
#include "ir/import-utils.h"
#include "pass.h"
#include "shared-constants.h"
#include "support/debug.h"
#include "wasm-emscripten.h"
#define DEBUG_TYPE "stack-check"
namespace wasm {
// Names of the two globals this pass adds to hold the permitted stack range
// (checked against every write to the stack pointer), and of the exported
// function generated to set them.
static Name STACK_BASE("__stack_base");
static Name STACK_LIMIT("__stack_limit");
static Name SET_STACK_LIMITS("__set_stack_limits");
// Makes sure the module imports `env.<name>` as a function taking and
// returning nothing, to be called when a stack bound is violated.
//
// The existing-import check is by module/base (`env`, `name`), not by the
// function's internal name. If such an import already exists, nothing is
// added, and its internal name may differ from `name`. Callers that need to
// emit calls to it must resolve that internal name themselves. A newly added
// import uses `name` as both its base and its internal name.
static void importStackOverflowHandler(Module& module, Name name) {
  ImportInfo imports(module);
  if (imports.getImportedFunction(ENV, name)) {
    // Already imported; leave the module alone.
    return;
  }
  auto* handlerImport = new Function;
  handlerImport->name = name;
  handlerImport->module = ENV;
  handlerImport->base = name;
  handlerImport->sig = Signature(Type::none, Type::none);
  module.addFunction(handlerImport);
}
// Adds `function` to `module` and exports it under its own internal name.
static void addExportedFunction(Module& module, Function* function) {
  module.addFunction(function);
  auto* functionExport = new Export;
  functionExport->kind = ExternalKind::Function;
  functionExport->value = function->name;
  functionExport->name = function->name;
  module.addExport(functionExport);
}
// Creates and exports `__set_stack_limits(i32, i32) -> none`. Its body writes
// parameter 0 into `__stack_base` and parameter 1 into `__stack_limit`.
static void generateSetStackLimitFunctions(Module& module) {
  Builder builder(module);
  Expression* writeBase =
    builder.makeGlobalSet(STACK_BASE, builder.makeLocalGet(0, Type::i32));
  Expression* writeLimit =
    builder.makeGlobalSet(STACK_LIMIT, builder.makeLocalGet(1, Type::i32));
  Function* setter = builder.makeFunction(
    SET_STACK_LIMITS, Signature({Type::i32, Type::i32}, Type::none), {});
  setter->body = builder.makeBlock({writeBase, writeLimit});
  addExportedFunction(module, setter);
}
// Function-parallel pass that rewrites every write to the stack pointer
// global into a bounds-checked write. Each `global.set $sp (value)` becomes:
//
//   (if (i32.or (i32.gt_u (local.tee $newSP (value))
//                         (global.get $__stack_base))
//               (i32.lt_u (local.get $newSP)
//                         (global.get $__stack_limit)))
//     (call $handler))          ;; or (unreachable) when no handler is set
//   (global.set $sp (local.get $newSP))
//
// NOTE(review): the comparison and `or` are always 32-bit ops, so this
// assumes the stack pointer global is i32 (wasm32). Confirm before using it
// with a 64-bit stack pointer.
struct EnforceStackLimits : public WalkerPass<PostWalker<EnforceStackLimits>> {
  EnforceStackLimits(Global* stackPointer,
                     Global* stackBase,
                     Global* stackLimit,
                     Builder& builder,
                     Name handler)
    : stackPointer(stackPointer), stackBase(stackBase), stackLimit(stackLimit),
      builder(builder), handler(handler) {}

  bool isFunctionParallel() override { return true; }

  // Each worker gets its own instance sharing the same globals, builder and
  // handler name.
  Pass* create() override {
    return new EnforceStackLimits(
      stackPointer, stackBase, stackLimit, builder, handler);
  }

  // Builds the checked replacement for `global.set $sp (value)` inside
  // `func`. A fresh local is added to `func` because the new value is read
  // several times (two comparisons plus the final store) but `value` must be
  // evaluated only once.
  //
  // Values in [__stack_limit, __stack_base] pass. Anything outside that
  // range runs the handler expression. If an imported handler returns, the
  // out-of-range value is still stored, because the store follows the `if`
  // unconditionally.
  Expression* stackBoundsCheck(Function* func, Expression* value) {
    auto newSP = Builder::addVar(func, stackPointer->type);
    // Prefer the imported handler (if configured); otherwise trap.
    Expression* handlerExpr;
    if (handler.is()) {
      handlerExpr = builder.makeCall(handler, {}, Type::none);
    } else {
      handlerExpr = builder.makeUnreachable();
    }
    auto check = builder.makeIf(
      builder.makeBinary(
        BinaryOp::OrInt32,
        builder.makeBinary(
          BinaryOp::GtUInt32,
          builder.makeLocalTee(newSP, value, stackPointer->type),
          builder.makeGlobalGet(stackBase->name, stackBase->type)),
        builder.makeBinary(
          BinaryOp::LtUInt32,
          builder.makeLocalGet(newSP, stackPointer->type),
          builder.makeGlobalGet(stackLimit->name, stackLimit->type))),
      handlerExpr);
    auto newSet = builder.makeGlobalSet(
      stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type));
    return builder.blockify(check, newSet);
  }

  void visitGlobalSet(GlobalSet* curr) {
    // Global names are unique within a module, so a name match identifies
    // the stack pointer directly. This avoids a module-wide map lookup on
    // every global.set.
    if (curr->name == stackPointer->name) {
      replaceCurrent(stackBoundsCheck(getFunction(), curr->value));
    }
  }

private:
  Global* stackPointer;
  Global* stackBase;
  Global* stackLimit;
  // Owned by the caller (StackCheck::run) and must outlive this pass's run.
  Builder& builder;
  // Internal name of the overflow handler to call; empty means "trap".
  Name handler;
};
struct StackCheck : public Pass {
void run(PassRunner* runner, Module* module) override {
Global* stackPointer = getStackPointerGlobal(*module);
if (!stackPointer) {
BYN_DEBUG(std::cerr << "no stack pointer found\n");
return;
}
Name handler;
auto handlerName =
runner->options.getArgumentOrDefault("stack-check-handler", "");
if (handlerName != "") {
handler = handlerName;
importStackOverflowHandler(*module, handler);
}
Builder builder(*module);
Global* stackBase = builder.makeGlobal(STACK_BASE,
stackPointer->type,
builder.makeConst(int32_t(0)),
Builder::Mutable);
module->addGlobal(stackBase);
Global* stackLimit = builder.makeGlobal(STACK_LIMIT,
stackPointer->type,
builder.makeConst(int32_t(0)),
Builder::Mutable);
module->addGlobal(stackLimit);
PassRunner innerRunner(module);
EnforceStackLimits(stackPointer, stackBase, stackLimit, builder, handler)
.run(&innerRunner, module);
generateSetStackLimitFunctions(*module);
}
};
// Factory returning a new heap-allocated StackCheck pass; the caller takes
// the returned pointer.
Pass* createStackCheckPass() {
  Pass* pass = new StackCheck;
  return pass;
}
}