#include <pass.h>
#include <wasm.h>
#include "abi/js.h"
#include "asmjs/shared-constants.h"
#include "ir/find_all.h"
#include "ir/literal-utils.h"
#include "ir/memory-utils.h"
#include "ir/module-utils.h"
#include "passes/intrinsics-module.h"
#include "wasm-builder.h"
#include "wasm-s-parser.h"
namespace wasm {
// Lowers operations the JS backend cannot emit directly:
//  * 64-bit integer mul/div/rem, rotates, popcnt, ctz and float nearest/trunc
//    become calls to intrinsic helper functions;
//  * copysign is rewritten inline with integer bit operations;
//  * under-aligned f32/f64 loads and stores are performed on the integer bits
//    and reinterpreted.
// The intrinsic helpers are parsed from IntrinsicsModuleWast, linked into the
// module on demand (together with everything they call), and are lowered in
// turn by this same pass.
struct RemoveNonJSOpsPass : public WalkerPass<PostWalker<RemoveNonJSOpsPass>> {
  // Builder for the module being transformed; created lazily on first use.
  std::unique_ptr<Builder> builder;
  // Intrinsic functions that emitted calls refer to but that have not been
  // linked into the module yet.
  std::unordered_set<Name> neededIntrinsics;
  // Every (name, type) read by a global.get seen during the walk. Any of these
  // still missing from the module after linking is added as an import.
  std::set<std::pair<Name, Type>> neededImportedGlobals;

  // The pass accumulates state in the member sets across all functions and
  // then edits the module as a whole, so it runs serially.
  bool isFunctionParallel() override { return false; }

  Pass* create() override { return new RemoveNonJSOpsPass; }

  void doWalkModule(Module* module) {
    ABI::wasm2js::ensureHelpers(module);

    if (!builder) {
      builder = make_unique<Builder>(*module);
    }
    // Lower the existing code, recording which intrinsics and globals the
    // lowered code needs.
    PostWalker<RemoveNonJSOpsPass>::doWalkModule(module);

    if (neededIntrinsics.empty()) {
      return;
    }

    // Parse the intrinsics module. The parser wants a mutable buffer, so give
    // it the string's own storage instead of const_cast'ing c_str().
    Module intrinsicsModule;
    std::string input(IntrinsicsModuleWast);
    SExpressionParser parser(&input[0]);
    Element& root = *parser.root;
    // Constructing this populates intrinsicsModule. It is deliberately not
    // called `builder`, which would shadow the member used by the visitors.
    SExpressionWasmBuilder intrinsicsBuilder(
      intrinsicsModule, *root[0], IRProfile::Normal);

    // Link in the requested intrinsics. Lowering a linked function may request
    // further intrinsics, so iterate until no new requests appear.
    std::set<Name> neededFunctions;
    while (!neededIntrinsics.empty()) {
      // Expand each request to the intrinsic plus every non-imported function
      // it transitively calls inside the intrinsics module.
      for (auto& name : neededIntrinsics) {
        addNeededFunctions(intrinsicsModule, name, neededFunctions);
      }
      neededIntrinsics.clear();

      // Copy in whatever the module does not already define, then lower its
      // body too (this may refill neededIntrinsics).
      for (auto& name : neededFunctions) {
        auto* func = module->getFunctionOrNull(name);
        if (!func) {
          func = ModuleUtils::copyFunction(intrinsicsModule.getFunction(name),
                                           *module);
        }
        doWalkFunction(func);
      }
      neededFunctions.clear();
    }

    // Copy the intrinsics module's globals. As with functions above, an
    // existing definition with the same name is reused rather than re-added,
    // which would make addGlobal fail on a duplicate name.
    for (auto& global : intrinsicsModule.globals) {
      if (!module->getGlobalOrNull(global->name)) {
        ModuleUtils::copyGlobal(global.get(), *module);
      }
    }

    // Ensure the module has a memory (presumably used by the linked intrinsics).
    MemoryUtils::ensureExists(module->memory);

    // Any global that was read but is still undefined becomes an immutable
    // import from ENV under the same name.
    for (auto& pair : neededImportedGlobals) {
      const Name& name = pair.first;
      const Type& type = pair.second;
      if (!module->getGlobalOrNull(name)) {
        auto global = make_unique<Global>();
        global->name = name;
        global->type = type;
        global->mutable_ = false;
        global->module = ENV;
        global->base = name;
        module->addGlobal(global.release());
      }
    }
  }

  // Adds `name` and every non-imported function it transitively calls in `m`
  // to `needed`. Names already in `needed` are skipped, which also terminates
  // recursion on call cycles.
  void addNeededFunctions(Module& m, Name name, std::set<Name>& needed) {
    if (!needed.insert(name).second) {
      return;
    }
    auto* function = m.getFunction(name);
    FindAll<Call> calls(function->body);
    for (auto* call : calls.list) {
      auto* called = m.getFunction(call->target);
      if (!called->imported()) {
        addNeededFunctions(m, call->target, needed);
      }
    }
  }

  // Lowers a single function. It is also called directly on linked intrinsics,
  // so make sure the member builder exists.
  void doWalkFunction(Function* func) {
    if (!builder) {
      builder = make_unique<Builder>(*getModule());
    }
    PostWalker<RemoveNonJSOpsPass>::doWalkFunction(func);
  }

  // Under-aligned (0 < align < bytes) float loads become integer loads of the
  // same width followed by a reinterpret back to the float type.
  void visitLoad(Load* curr) {
    if (curr->align == 0 || curr->align >= curr->bytes) {
      return;
    }
    switch (curr->type.getBasic()) {
      case Type::f32:
        curr->type = Type::i32;
        replaceCurrent(builder->makeUnary(ReinterpretInt32, curr));
        break;
      case Type::f64:
        curr->type = Type::i64;
        replaceCurrent(builder->makeUnary(ReinterpretInt64, curr));
        break;
      default:
        break;
    }
  }

  // Under-aligned float stores become integer stores of the reinterpreted bits.
  void visitStore(Store* curr) {
    if (curr->align == 0 || curr->align >= curr->bytes) {
      return;
    }
    switch (curr->valueType.getBasic()) {
      case Type::f32:
        curr->valueType = Type::i32;
        curr->value = builder->makeUnary(ReinterpretFloat32, curr->value);
        break;
      case Type::f64:
        curr->valueType = Type::i64;
        curr->value = builder->makeUnary(ReinterpretFloat64, curr->value);
        break;
      default:
        break;
    }
  }

  // Replaces binary ops with no direct JS lowering by intrinsic calls;
  // copysign is expanded inline instead.
  void visitBinary(Binary* curr) {
    Name name;
    switch (curr->op) {
      case CopySignFloat32:
      case CopySignFloat64:
        rewriteCopysign(curr);
        return;
      case RotLInt32:
        name = WASM_ROTL32;
        break;
      case RotRInt32:
        name = WASM_ROTR32;
        break;
      case RotLInt64:
        name = WASM_ROTL64;
        break;
      case RotRInt64:
        name = WASM_ROTR64;
        break;
      case MulInt64:
        name = WASM_I64_MUL;
        break;
      case DivSInt64:
        name = WASM_I64_SDIV;
        break;
      case DivUInt64:
        name = WASM_I64_UDIV;
        break;
      case RemSInt64:
        name = WASM_I64_SREM;
        break;
      case RemUInt64:
        name = WASM_I64_UREM;
        break;
      default:
        return;
    }
    neededIntrinsics.insert(name);
    replaceCurrent(
      builder->makeCall(name, {curr->left, curr->right}, curr->type));
  }

  // copysign(a, b) == reinterpret((bits(a) & ~SIGN) | (bits(b) & SIGN)).
  void rewriteCopysign(Binary* curr) {
    Literal signBit, otherBits;
    UnaryOp int2float, float2int;
    BinaryOp bitAnd, bitOr;
    switch (curr->op) {
      case CopySignFloat32:
        float2int = ReinterpretFloat32;
        int2float = ReinterpretInt32;
        bitAnd = AndInt32;
        bitOr = OrInt32;
        // Shift an unsigned 1 so nothing is shifted into a signed int's sign
        // bit (the 64-bit branch below already does this).
        signBit = Literal(uint32_t(1) << 31);
        otherBits = Literal((uint32_t(1) << 31) - 1);
        break;
      case CopySignFloat64:
        float2int = ReinterpretFloat64;
        int2float = ReinterpretInt64;
        bitAnd = AndInt64;
        bitOr = OrInt64;
        signBit = Literal(uint64_t(1) << 63);
        otherBits = Literal((uint64_t(1) << 63) - 1);
        break;
      default:
        return;
    }
    replaceCurrent(builder->makeUnary(
      int2float,
      builder->makeBinary(
        bitOr,
        builder->makeBinary(bitAnd,
                            builder->makeUnary(float2int, curr->left),
                            builder->makeConst(otherBits)),
        builder->makeBinary(bitAnd,
                            builder->makeUnary(float2int, curr->right),
                            builder->makeConst(signBit)))));
  }

  // Replaces unary ops with no direct JS lowering by intrinsic calls.
  void visitUnary(Unary* curr) {
    Name functionCall;
    switch (curr->op) {
      case NearestFloat32:
        functionCall = WASM_NEAREST_F32;
        break;
      case NearestFloat64:
        functionCall = WASM_NEAREST_F64;
        break;
      case TruncFloat32:
        functionCall = WASM_TRUNC_F32;
        break;
      case TruncFloat64:
        functionCall = WASM_TRUNC_F64;
        break;
      case PopcntInt64:
        functionCall = WASM_POPCNT64;
        break;
      case PopcntInt32:
        functionCall = WASM_POPCNT32;
        break;
      case CtzInt64:
        functionCall = WASM_CTZ64;
        break;
      case CtzInt32:
        functionCall = WASM_CTZ32;
        break;
      default:
        return;
    }
    neededIntrinsics.insert(functionCall);
    replaceCurrent(builder->makeCall(functionCall, {curr->value}, curr->type));
  }

  // Record every global read; undefined ones are imported in doWalkModule.
  void visitGlobalGet(GlobalGet* curr) {
    neededImportedGlobals.insert(std::make_pair(curr->name, curr->type));
  }
};
// Replaces operations the JS backend does not support with stubs. A stub keeps
// the operands' side effects (as dropped values) and yields a zero of the
// original result type in place of the real result.
struct StubUnsupportedJSOpsPass
  : public WalkerPass<PostWalker<StubUnsupportedJSOpsPass>> {
  bool isFunctionParallel() override { return true; }

  Pass* create() override { return new StubUnsupportedJSOpsPass; }

  void visitUnary(Unary* curr) {
    switch (curr->op) {
      case ConvertUInt64ToFloat32:
        stubOut(curr->value, curr->type);
        break;
      default: {
      }
    }
  }

  // Drop every operand and the call target, then stub the call's result.
  void visitCallIndirect(CallIndirect* curr) {
    Builder builder(*getModule());
    std::vector<Expression*> items;
    items.reserve(curr->operands.size() + 1);
    for (auto* operand : curr->operands) {
      items.push_back(builder.makeDrop(operand));
    }
    items.push_back(builder.makeDrop(curr->target));
    stubOut(builder.makeBlock(items), curr->type);
  }

  // Replaces the current expression with `value` (its side effects) followed
  // by something of type `outputType`:
  //  * none        -> `value` alone;
  //  * concrete    -> `value` (dropped if it has a value), then a zero constant;
  //  * unreachable -> `value`, plus an explicit `unreachable` when `value`
  //                   itself can fall through. This happens for a tail call
  //                   (return_call_indirect), whose type is unreachable while
  //                   its operands are reachable. Without the explicit
  //                   `unreachable`, the replacement would be typed `none`.
  void stubOut(Expression* value, Type outputType) {
    Builder builder(*getModule());
    Expression* replacement = value;
    if (outputType == Type::unreachable) {
      if (value->type != Type::unreachable) {
        if (value->type != Type::none) {
          value = builder.makeDrop(value);
        }
        replacement = builder.makeSequence(value, builder.makeUnreachable());
      }
    } else if (outputType != Type::none) {
      if (value->type != Type::none) {
        value = builder.makeDrop(value);
      }
      replacement = builder.makeSequence(
        value, LiteralUtils::makeZero(outputType, *getModule()));
    }
    replaceCurrent(replacement);
  }
};
// Returns a newly allocated RemoveNonJSOpsPass; the caller owns the result.
Pass* createRemoveNonJSOpsPass() {
  return new RemoveNonJSOpsPass();
}

// Returns a newly allocated StubUnsupportedJSOpsPass; the caller owns the
// result.
Pass* createStubUnsupportedJSOpsPass() { return new StubUnsupportedJSOpsPass(); }
}