#include <memory>
#include <utility>

#include "ir/global-utils.h"
#include "ir/import-utils.h"
#include "ir/literal-utils.h"
#include "ir/memory-utils.h"
#include "ir/module-utils.h"
#include "pass.h"
#include "support/colors.h"
#include "support/file.h"
#include "tool-options.h"
#include "wasm-builder.h"
#include "wasm-interpreter.h"
#include "wasm-io.h"
#include "wasm-validator.h"
using namespace wasm;
// Thrown whenever evaluation hits something that cannot be modeled ahead of
// time (an import call, a trap, a dangerous global, ...). `why` is a
// human-readable reason that callers print before giving up.
struct FailToEvalException {
  std::string why;
  // Take the reason by value and move it into place to avoid a second copy.
  FailToEvalException(std::string why) : why(std::move(why)) {}
};
// Global storage for the ctor-evaluating interpreter. It holds the current
// value of every global and also tracks "dangerous" globals. Any access to a
// dangerous global through operator[] aborts evaluation with a
// FailToEvalException.
class EvallingGlobalManager {
  // Current values of globals, keyed by global name.
  std::map<Name, Literals> globals;

  // Globals that must not be accessed during evaluation.
  std::set<Name> dangerousGlobals;

  // Set by seal(). NOTE(review): nothing in this class reads this flag, so
  // sealing currently has no effect on lookups.
  bool sealed = false;

public:
  void addDangerous(Name name) { dangerousGlobals.insert(name); }

  void seal() { sealed = true; }

  // Two managers compare equal when every global value matches. The
  // dangerous set and the sealed flag are not compared.
  bool operator==(const EvallingGlobalManager& other) const {
    return globals == other.globals;
  }
  bool operator!=(const EvallingGlobalManager& other) const {
    return !(*this == other);
  }

  // Returns a mutable reference to the value of `name`, inserting an empty
  // entry if it is absent.
  // Throws FailToEvalException if `name` was registered via addDangerous().
  Literals& operator[](Name name) {
    if (dangerousGlobals.count(name) > 0) {
      std::string extra;
      if (name == "___dso_handle") {
        extra = "\nrecommendation: build with -s NO_EXIT_RUNTIME=1 so that "
                "calls to atexit that use ___dso_handle are not emitted";
      }
      throw FailToEvalException(
        std::string(
          "tried to access a dangerous (import-initialized) global: ") +
        name.str + extra);
    }
    return globals[name];
  }

  // Minimal map-entry-like result of find(). `first` and `second` mirror
  // std::map's pair fields; `found` distinguishes a hit from end().
  // NOTE(review): presumably shaped like this so the interpreter can use this
  // class where it would otherwise use a std::map — confirm against
  // ModuleInstanceBase.
  struct Iterator {
    Name first;
    Literals second;
    bool found;

    Iterator() : found(false) {}
    Iterator(Name name, Literals value)
      : first(name), second(std::move(value)), found(true) {}

    bool operator==(const Iterator& other) const {
      return first == other.first && second == other.second &&
             found == other.found;
    }
    bool operator!=(const Iterator& other) const { return !(*this == other); }
  };

  // Looks `name` up without inserting it and without the dangerous-global
  // check. Returns end() when the global is not present.
  Iterator find(Name name) {
    // A single lookup: reuse the iterator instead of a second operator[].
    auto it = globals.find(name);
    if (it == globals.end()) {
      return end();
    }
    return Iterator(name, it->second);
  }

  Iterator end() { return Iterator(); }
};
// Layout of the fake stack used while evaluating. Addresses in
// [STACK_LOWER_LIMIT, STACK_UPPER_LIMIT) are backed by a separate host buffer
// rather than by the module's data segment. These values never change, so they
// are const. That prevents accidental mutation and keeps the derived limits in
// sync with STACK_START/STACK_SIZE.
static const Index STACK_SIZE = 32 * 1024 * 1024;
static const Index STACK_START = 1024 * 1024 * 1024 + STACK_SIZE;
static const Index STACK_LOWER_LIMIT = STACK_START - STACK_SIZE;
static const Index STACK_UPPER_LIMIT = STACK_START + STACK_SIZE;
// Interpreter instance used to run ctors. It is a ModuleInstanceBase with the
// evaluating global manager, plus a host-side buffer for the fake stack.
class EvallingModuleInstance
  : public ModuleInstanceBase<EvallingGlobalManager, EvallingModuleInstance> {
public:
  EvallingModuleInstance(Module& wasm, ExternalInterface* externalInterface)
    : ModuleInstanceBase(wasm, externalInterface) {
    // A defined global whose initializer is not a plain constant may depend on
    // a value we do not really know, so mark it dangerous. The one exception
    // is a global.get of env.STACKTOP / env.STACK_MAX, which the external
    // interface fills with usable values.
    ModuleUtils::iterDefinedGlobals(wasm, [&](Global* global) {
      Expression* init = global->init;
      if (init->is<Const>()) {
        return;
      }
      if (auto* get = init->dynCast<GlobalGet>()) {
        auto* source = wasm.getGlobal(get->name);
        bool fromEnv = source->module == Name(ENV);
        bool isStackValue =
          source->base == STACKTOP || source->base == STACK_MAX;
        if (fromEnv && isStackValue) {
          return;
        }
      }
      globals.addDangerous(global->name);
    });
  }

  // Host storage for addresses in [STACK_LOWER_LIMIT, STACK_UPPER_LIMIT);
  // sized by setupEnvironment().
  std::vector<char> stack;

  // Allocates the fake stack and sets memorySize to the number of wasm pages
  // needed to reach the top of the stack region.
  void setupEnvironment() {
    stack.resize(2 * STACK_SIZE);
    memorySize = (STACK_START + STACK_SIZE) / Memory::kPageSize;
  }
};
// External interface used while evaluating ctors. It reports anything that
// needs real host state by throwing FailToEvalException. That covers calls to
// imports, memory growth, traps, thrown exceptions, and table calls that cannot
// be resolved from static segments.
struct CtorEvalExternalInterface : EvallingModuleInstance::ExternalInterface {
  Module* wasm = nullptr;
  EvallingModuleInstance* instance = nullptr;

  // Remembers the module and instance so that later memory and table
  // callbacks can reach them.
  void init(Module& wasm_, EvallingModuleInstance& instance_) override {
    wasm = &wasm_;
    instance = &instance_;
  }

  // Provides initial values for globals. env.STACKTOP / env.STACK_MAX imports
  // get STACK_START. GlobalUtils::getGlobalInitializedToImport presumably
  // returns a defined global initialized from that import, which gets the
  // same value. Every other global starts as zero(s) of its type.
  // NOTE(review): STACK_MAX is also set to STACK_START (not a higher limit);
  // confirm this is intended.
  void importGlobals(EvallingGlobalManager& globals, Module& wasm_) override {
    ImportInfo imports(wasm_);
    if (auto* stackTop = imports.getImportedGlobal(ENV, STACKTOP)) {
      globals[stackTop->name] = {Literal(int32_t(STACK_START))};
      if (auto* stackTopCopy =
            GlobalUtils::getGlobalInitializedToImport(wasm_, ENV, STACKTOP)) {
        globals[stackTopCopy->name] = {Literal(int32_t(STACK_START))};
      }
    }
    if (auto* stackMax = imports.getImportedGlobal(ENV, STACK_MAX)) {
      globals[stackMax->name] = {Literal(int32_t(STACK_START))};
      if (auto* stackMaxCopy =
            GlobalUtils::getGlobalInitializedToImport(wasm_, ENV, STACK_MAX)) {
        globals[stackMaxCopy->name] = {Literal(int32_t(STACK_START))};
      }
    }
    ModuleUtils::iterDefinedGlobals(wasm_, [&](Global* defined) {
      if (globals.find(defined->name) == globals.end()) {
        globals[defined->name] = Literal::makeZeros(defined->type);
      }
    });
    ModuleUtils::iterImportedGlobals(wasm_, [&](Global* import) {
      if (globals.find(import->name) == globals.end()) {
        globals[import->name] = Literal::makeZeros(import->type);
      }
    });
  }

  // Any call to an import cannot be evaluated ahead of time.
  Literals callImport(Function* import, LiteralList& arguments) override {
    std::string extra;
    if (import->module == ENV && import->base == "___cxa_atexit") {
      extra = "\nrecommendation: build with -s NO_EXIT_RUNTIME=1 so that calls "
              "to atexit are not emitted";
    }
    throw FailToEvalException(std::string("call import: ") +
                              import->module.str + "." + import->base.str +
                              extra);
  }

  // Resolves an indirect call using only the module's static table segments.
  // The table is assumed not to have been modified at runtime. Fails on a
  // signature mismatch, on a call to an imported function, or when no segment
  // covers `index`.
  Literals callTable(Index index,
                     Signature sig,
                     LiteralList& arguments,
                     Type result,
                     EvallingModuleInstance& instance) override {
    for (auto& segment : wasm->table.segments) {
      Index start;
      if (auto* c = segment.offset->dynCast<Const>()) {
        start = c->value.getInteger();
      } else if (segment.offset->is<GlobalGet>()) {
        // NOTE(review): a global.get offset is treated as 0, presumably because
        // the table is placed at offset 0 when not dynamically linking.
        start = 0;
      } else {
        WASM_UNREACHABLE("invalid expr type");
      }
      auto end = start + segment.data.size();
      if (start <= index && index < end) {
        auto name = segment.data[index - start];
        auto* func = wasm->getFunction(name);
        if (func->sig != sig) {
          throw FailToEvalException(
            std::string("callTable signature mismatch: ") + name.str);
        }
        if (!func->imported()) {
          return instance.callFunctionInternal(name, arguments);
        } else {
          throw FailToEvalException(
            std::string("callTable on imported function: ") + name.str);
        }
      }
    }
    throw FailToEvalException(
      std::string("callTable on index not found in static segments: ") +
      std::to_string(index));
  }

  int8_t load8s(Address addr) override { return doLoad<int8_t>(addr); }
  uint8_t load8u(Address addr) override { return doLoad<uint8_t>(addr); }
  int16_t load16s(Address addr) override { return doLoad<int16_t>(addr); }
  uint16_t load16u(Address addr) override { return doLoad<uint16_t>(addr); }
  int32_t load32s(Address addr) override { return doLoad<int32_t>(addr); }
  uint32_t load32u(Address addr) override { return doLoad<uint32_t>(addr); }
  int64_t load64s(Address addr) override { return doLoad<int64_t>(addr); }
  uint64_t load64u(Address addr) override { return doLoad<uint64_t>(addr); }

  void store8(Address addr, int8_t value) override {
    doStore<int8_t>(addr, value);
  }
  void store16(Address addr, int16_t value) override {
    doStore<int16_t>(addr, value);
  }
  void store32(Address addr, int32_t value) override {
    doStore<int32_t>(addr, value);
  }
  void store64(Address addr, int64_t value) override {
    doStore<int64_t>(addr, value);
  }

  // Table stores are ignored: callTable() reads the static segments directly.
  // NOTE(review): presumably the base instance calls this while initializing
  // the table from those same segments; a runtime table write would be lost.
  void tableStore(Address addr, Name value) override {}

  bool growMemory(Address, Address newSize) override {
    throw FailToEvalException("grow memory");
  }

  void trap(const char* why) override {
    throw FailToEvalException(std::string("trap: ") + why);
  }

  void throwException(Literal exn) override {
    std::stringstream ss;
    ss << "exception thrown: " << exn;
    throw FailToEvalException(ss.str());
  }

private:
  // Maps a linear-memory address to host storage for an access of sizeof(T)
  // bytes.
  // - Addresses at or above STACK_LOWER_LIMIT are served from instance->stack.
  // - Everything else is served from the module's first data segment, which
  //   is asserted to start at offset 0. That segment is zero-extended on
  //   demand, so bytes never written read as zero.
  template<typename T> T* getMemory(Address address) {
    if (address >= STACK_LOWER_LIMIT) {
      if (address >= STACK_UPPER_LIMIT) {
        throw FailToEvalException("stack usage too high");
      }
      Address relative = address - STACK_LOWER_LIMIT;
      // Bound the whole access, not just its first byte, by the real buffer
      // size. Otherwise a multi-byte access starting just below
      // STACK_UPPER_LIMIT would run past the end of instance->stack.
      if (relative + sizeof(T) > instance->stack.size()) {
        throw FailToEvalException("stack usage too high");
      }
      return (T*)(&instance->stack[relative]);
    }
    if (wasm->memory.segments.size() == 0) {
      std::vector<char> temp;
      Builder builder(*wasm);
      wasm->memory.segments.push_back(
        Memory::Segment(builder.makeConst(int32_t(0)), temp));
    }
    assert(wasm->memory.segments[0].offset->cast<Const>()->value.getInteger() ==
           0);
    auto max = address + sizeof(T);
    auto& data = wasm->memory.segments[0].data;
    if (max > data.size()) {
      data.resize(max);
    }
    return (T*)(&data[address]);
  }

  // memcpy avoids alignment and strict-aliasing problems with the raw
  // byte storage.
  template<typename T> void doStore(Address address, T value) {
    memcpy(getMemory<T>(address), &value, sizeof(T));
  }

  template<typename T> T doLoad(Address address) {
    T ret;
    memcpy(&ret, getMemory<T>(address), sizeof(T));
    return ret;
  }
};
// Tries to evaluate each export named in `ctors`, in order.
// - Each ctor that evaluates successfully has its body replaced with a nop and
//   its export removed; the memory it wrote stays in the module.
// - Evaluation stops at the first ctor that fails or that changes any global.
//   That ctor's memory writes are rolled back.
// - Because a ctor that changes globals is rejected, global values never need
//   to be written back into the module.
// Aborts via Fatal() if memory cannot be flattened or a named export does not
// exist.
void evalCtors(Module& wasm, std::vector<std::string> ctors) {
  CtorEvalExternalInterface interface;
  try {
    // The interface's memory accessors expect a single segment at offset 0.
    if (!MemoryUtils::flatten(wasm.memory)) {
      Fatal() << " ...stopping since could not flatten memory\n";
    }
    EvallingModuleInstance instance(wasm, &interface);
    instance.setupEnvironment();
    instance.globals.seal();
    for (const auto& ctor : ctors) {
      std::cerr << "trying to eval " << ctor << '\n';
      // Snapshots for rollback if this ctor cannot be kept.
      auto memoryBefore = wasm.memory;
      auto globalsBefore = instance.globals;
      Export* ex = wasm.getExportOrNull(ctor);
      if (!ex) {
        Fatal() << "export not found: " << ctor;
      }
      try {
        instance.callFunction(ex->value, LiteralList());
      } catch (FailToEvalException& fail) {
        std::cerr << " ...stopping since could not eval: " << fail.why << "\n";
        wasm.memory = memoryBefore;
        return;
      }
      if (instance.globals != globalsBefore) {
        std::cerr << " ...stopping since globals modified\n";
        wasm.memory = memoryBefore;
        return;
      }
      std::cerr << " ...success on " << ctor << ".\n";
      // Reuse the export found above rather than looking it up again. The
      // export is destroyed by removeExport, so it is not touched afterwards.
      auto* func = wasm.getFunction(ex->value);
      func->body = wasm.allocator.alloc<Nop>();
      wasm.removeExport(ex->name);
    }
  } catch (FailToEvalException& fail) {
    std::cerr << " ...stopping since could not create module instance: "
              << fail.why << "\n";
    return;
  }
}
// wasm-ctor-eval entry point. It parses options, then reads and validates the
// input module. Next it evaluates the comma-separated --ctors list and runs
// cleanup passes. Finally it writes the result when --output is given.
int main(int argc, const char* argv[]) {
  bool emitBinary = true;
  bool debugInfo = false;
  std::string ctorsString;

  ToolOptions options("wasm-ctor-eval",
                      "Execute C++ global constructors ahead of time");
  options
    .add("--output",
         "-o",
         "Output file (stdout if not specified)",
         Options::Arguments::One,
         [](Options* o, const std::string& argument) {
           o->extra["output"] = argument;
           Colors::setEnabled(false);
         })
    .add("--emit-text",
         "-S",
         "Emit text instead of binary for the output file",
         Options::Arguments::Zero,
         [&](Options* o, const std::string& argument) { emitBinary = false; })
    .add("--debuginfo",
         "-g",
         "Emit names section and debug info",
         Options::Arguments::Zero,
         [&](Options* o, const std::string& arguments) { debugInfo = true; })
    .add(
      "--ctors",
      "-c",
      "Comma-separated list of global constructor functions to evaluate",
      Options::Arguments::One,
      [&](Options* o, const std::string& argument) { ctorsString = argument; })
    .add_positional("INFILE",
                    Options::Arguments::One,
                    [](Options* o, const std::string& argument) {
                      o->extra["infile"] = argument;
                    });
  options.parse(argc, argv);

  // ModuleReader reads the file itself; there is no need to pre-read it.
  Module wasm;
  {
    if (options.debug) {
      std::cerr << "reading...\n";
    }
    ModuleReader reader;
    try {
      reader.read(options.extra["infile"], wasm);
    } catch (ParseException& p) {
      p.dump(std::cerr);
      Fatal() << "error in parsing input";
    }
  }

  options.applyFeatures(wasm);

  if (!WasmValidator().validate(wasm)) {
    WasmPrinter::printModule(&wasm);
    Fatal() << "error in validating input";
  }

  // Split --ctors on commas.
  std::vector<std::string> ctors;
  std::istringstream stream(ctorsString);
  std::string temp;
  while (std::getline(stream, temp, ',')) {
    ctors.push_back(temp);
  }

  evalCtors(wasm, ctors);

  // Clean up the module after evaluation.
  {
    PassRunner passRunner(&wasm);
    passRunner.add("memory-packing");
    passRunner.add("remove-unused-names");
    passRunner.add("dce");
    passRunner.add("merge-blocks");
    passRunner.add("vacuum");
    passRunner.add("remove-unused-module-elements");
    passRunner.run();
  }

  // NOTE(review): the --output help text promises stdout when no output file
  // is given, but nothing is written in that case.
  if (options.extra.count("output") > 0) {
    if (options.debug) {
      std::cerr << "writing..." << std::endl;
    }
    ModuleWriter writer;
    writer.setBinary(emitBinary);
    writer.setDebugInfo(debugInfo);
    writer.write(wasm, options.extra["output"]);
  }
}