#include "asmjs/shared-constants.h"
#include "ir/trapping.h"
#include "mixed_arena.h"
#include "pass.h"
#include "support/name.h"
#include "wasm-builder.h"
#include "wasm-type.h"
#include "wasm.h"
namespace wasm {
// Names of the generated helper functions that replace the 64-bit integer
// remainder/division operations; getBinaryFuncName() returns them for the
// corresponding RemSInt64 / RemUInt64 / DivSInt64 / DivUInt64 ops.
// NOTE(review): the 32-bit counterparts (I32S_REM, ...) are not defined in
// this file; presumably they come from asmjs/shared-constants.h -- confirm.
Name I64S_REM("i64s-rem");
Name I64U_REM("i64u-rem");
Name I64S_DIV("i64s-div");
Name I64U_DIV("i64u-div");
// Returns an f64-typed expression for `expr`.
//
// An f64 input is returned untouched. An f32 input is wrapped in a newly
// allocated PromoteFloat32 unary (allocated from `allocator`). Any other type
// trips the assertion.
static Expression* ensureDouble(Expression* expr, MixedArena& allocator) {
  if (expr->type != Type::f32) {
    // Nothing to convert; the only other accepted type is f64.
    assert(expr->type == Type::f64);
    return expr;
  }
  auto* promote = allocator.alloc<Unary>();
  promote->op = PromoteFloat32;
  promote->value = expr;
  promote->type = Type::f64;
  return promote;
}
// Maps an integer division/remainder operation to the name of the helper
// function that stands in for it.
//
// Returns a default-constructed Name for every other operation; callers in
// this file test that with `!name.is()` to mean "no helper needed".
Name getBinaryFuncName(Binary* curr) {
  Name helperName{};
  switch (curr->op) {
    // 32-bit operations.
    case RemSInt32:
      helperName = I32S_REM;
      break;
    case RemUInt32:
      helperName = I32U_REM;
      break;
    case DivSInt32:
      helperName = I32S_DIV;
      break;
    case DivUInt32:
      helperName = I32U_DIV;
      break;
    // 64-bit operations.
    case RemSInt64:
      helperName = I64S_REM;
      break;
    case RemUInt64:
      helperName = I64U_REM;
      break;
    case DivSInt64:
      helperName = I64S_DIV;
      break;
    case DivUInt64:
      helperName = I64U_DIV;
      break;
    default:
      // Not a div/rem operation: leave the name empty.
      break;
  }
  return helperName;
}
// Maps a float-to-int truncation operation to the name of the helper function
// that stands in for it.
//
// Returns a default-constructed Name for every other operation; callers in
// this file test that with `!name.is()` to mean "no helper needed".
Name getUnaryFuncName(Unary* curr) {
  Name helperName{};
  switch (curr->op) {
    // f32 sources.
    case TruncSFloat32ToInt32:
      helperName = F32_TO_INT;
      break;
    case TruncUFloat32ToInt32:
      helperName = F32_TO_UINT;
      break;
    case TruncSFloat32ToInt64:
      helperName = F32_TO_INT64;
      break;
    case TruncUFloat32ToInt64:
      helperName = F32_TO_UINT64;
      break;
    // f64 sources.
    case TruncSFloat64ToInt32:
      helperName = F64_TO_INT;
      break;
    case TruncUFloat64ToInt32:
      helperName = F64_TO_UINT;
      break;
    case TruncSFloat64ToInt64:
      helperName = F64_TO_INT64;
      break;
    case TruncUFloat64ToInt64:
      helperName = F64_TO_UINT64;
      break;
    default:
      // Not a truncation: leave the name empty.
      break;
  }
  return helperName;
}
// Reports whether `op` is treated as a signed truncation.
//
// Only the four unsigned float-to-int truncations yield false; every other
// UnaryOp value (including non-truncation ops) yields true.
bool isTruncOpSigned(UnaryOp op) {
  const bool isUnsignedTrunc =
    op == TruncUFloat32ToInt32 || op == TruncUFloat32ToInt64 ||
    op == TruncUFloat64ToInt32 || op == TruncUFloat64ToInt64;
  return !isUnsignedTrunc;
}
// Builds the non-trapping helper function for the div/rem operation `curr`.
//
// The helper has signature (type, type) -> type, where `type` is `curr->type`,
// and is named by getBinaryFuncName(curr). Its body is:
//
//   if (local1 == 0) 0
//   else if (signed division && local0 == MIN && local1 == -1) 0
//   else local0 <op> local1
//
// The MIN / -1 guard is emitted only for DivSInt32 / DivSInt64. The returned
// Function is released from its unique_ptr, so the caller receives a raw
// pointer.
Function* generateBinaryFunc(Module& wasm, Binary* curr) {
  Builder builder(wasm);
  const BinaryOp op = curr->op;
  Type type = curr->type;
  const bool is64 = (type == Type::i64);

  // Small factories: every use in the tree needs its own fresh node.
  auto dividend = [&] { return builder.makeLocalGet(0, type); };
  auto divisor = [&] { return builder.makeLocalGet(1, type); };
  auto zero = [&] {
    return builder.makeConst(is64 ? Literal(int64_t(0)) : Literal(int32_t(0)));
  };

  // Innermost expression: the original operation on the two parameters.
  Expression* guarded = builder.makeBinary(op, dividend(), divisor());

  if (op == (is64 ? DivSInt64 : DivSInt32)) {
    // Signed division: yield 0 for MIN / -1 instead of performing it.
    const BinaryOp eq = is64 ? EqInt64 : EqInt32;
    Literal mostNegative = is64
                             ? Literal(std::numeric_limits<int64_t>::min())
                             : Literal(std::numeric_limits<int32_t>::min());
    Literal minusOne = is64 ? Literal(int64_t(-1)) : Literal(int32_t(-1));
    Expression* dividendIsMin =
      builder.makeBinary(eq, dividend(), builder.makeConst(mostNegative));
    Expression* divisorIsMinusOne =
      builder.makeBinary(eq, divisor(), builder.makeConst(minusOne));
    Expression* wouldOverflow =
      builder.makeBinary(AndInt32, dividendIsMin, divisorIsMinusOne);
    guarded = builder.makeIf(wouldOverflow, zero(), guarded);
  }

  // Outermost guard: a zero divisor yields 0.
  Expression* divisorIsZero =
    builder.makeUnary(is64 ? EqZInt64 : EqZInt32, divisor());

  auto signature = Signature({type, type}, type);
  auto func = Builder::makeFunction(getBinaryFuncName(curr), signature, {});
  func->body = builder.makeIf(divisorIsZero, zero(), guarded);
  return func.release();
}
// Computes the literals used to clamp a FloatType -> IntType truncation.
//
//   iMin - the smallest IntType value
//   fMin - FloatType(smallest IntType value) - 1
//   fMax - FloatType(largest IntType value) + 1
//
// The +/-1 adjustments are done in FloatType arithmetic, so for integer ranges
// wider than FloatType's precision they may be absorbed by rounding.
template<typename IntType, typename FloatType>
void makeClampLimitLiterals(Literal& iMin, Literal& fMin, Literal& fMax) {
  using Limits = std::numeric_limits<IntType>;
  const FloatType lowest = FloatType(Limits::min());
  const FloatType highest = FloatType(Limits::max());
  iMin = Literal(Limits::min());
  fMin = Literal(lowest - 1);
  fMax = Literal(highest + 1);
}
// Builds the clamping helper function for the float-to-int truncation `curr`.
//
// The helper takes one parameter of the operand's type and returns
// `curr->type`; it is named by getUnaryFuncName(curr). Its body is:
//
//   if (x != x)          iMin     // NaN
//   else if (x >= fMax)  iMin     // too large
//   else if (x <= fMin)  iMin     // too small
//   else                 trunc(x)
//
// Every out-of-range input, including one that is too large, maps to the
// minimum integer value. Any op other than the eight float-to-int truncations
// reaches WASM_UNREACHABLE. The returned Function is released from its
// unique_ptr, so the caller receives a raw pointer.
Function* generateUnaryFunc(Module& wasm, Unary* curr) {
  Builder builder(wasm);
  Type type = curr->value->type;
  Type retType = curr->type;
  UnaryOp truncOp = curr->op;
  const bool sourceIsF64 = (type == Type::f64);

  // Select the clamp limits for this (integer type, float type) pair.
  Literal iMin, fMin, fMax;
  switch (truncOp) {
    case TruncSFloat32ToInt32:
      makeClampLimitLiterals<int32_t, float>(iMin, fMin, fMax);
      break;
    case TruncUFloat32ToInt32:
      makeClampLimitLiterals<uint32_t, float>(iMin, fMin, fMax);
      break;
    case TruncSFloat32ToInt64:
      makeClampLimitLiterals<int64_t, float>(iMin, fMin, fMax);
      break;
    case TruncUFloat32ToInt64:
      makeClampLimitLiterals<uint64_t, float>(iMin, fMin, fMax);
      break;
    case TruncSFloat64ToInt32:
      makeClampLimitLiterals<int32_t, double>(iMin, fMin, fMax);
      break;
    case TruncUFloat64ToInt32:
      makeClampLimitLiterals<uint32_t, double>(iMin, fMin, fMax);
      break;
    case TruncSFloat64ToInt64:
      makeClampLimitLiterals<int64_t, double>(iMin, fMin, fMax);
      break;
    case TruncUFloat64ToInt64:
      makeClampLimitLiterals<uint64_t, double>(iMin, fMin, fMax);
      break;
    default:
      WASM_UNREACHABLE("unexpected op");
  }

  auto func =
    Builder::makeFunction(getUnaryFuncName(curr), Signature(type, retType), {});

  // Fresh read of the helper's only parameter.
  auto input = [&] { return builder.makeLocalGet(0, type); };
  // Wraps `otherwise` so that `condition` short-circuits to the minimum value.
  auto clampWhen = [&](Expression* condition,
                       Expression* otherwise) -> Expression* {
    return builder.makeIf(condition, builder.makeConst(iMin), otherwise);
  };

  // Built inside-out: each wrapper below becomes an outer check.
  Expression* body = builder.makeUnary(truncOp, input());

  Expression* tooSmall = builder.makeBinary(
    sourceIsF64 ? LeFloat64 : LeFloat32, input(), builder.makeConst(fMin));
  body = clampWhen(tooSmall, body);

  Expression* tooLarge = builder.makeBinary(
    sourceIsF64 ? GeFloat64 : GeFloat32, input(), builder.makeConst(fMax));
  body = clampWhen(tooLarge, body);

  // A value that is not equal to itself is NaN.
  Expression* isNaN =
    builder.makeBinary(sourceIsF64 ? NeFloat64 : NeFloat32, input(), input());
  body = clampWhen(isNaN, body);

  func->body = body;
  return func.release();
}
// Makes sure the helper function for the div/rem operation `curr` is
// registered in `trappingFunctions`, generating it only if no function with
// that name is registered yet.
void ensureBinaryFunc(Binary* curr,
                      Module& wasm,
                      TrappingFunctionContainer& trappingFunctions) {
  Name helperName = getBinaryFuncName(curr);
  if (!trappingFunctions.hasFunction(helperName)) {
    trappingFunctions.addFunction(generateBinaryFunc(wasm, curr));
  }
}
// Makes sure the clamping helper function for the truncation `curr` is
// registered in `trappingFunctions`, generating it only if no function with
// that name is registered yet.
void ensureUnaryFunc(Unary* curr,
                     Module& wasm,
                     TrappingFunctionContainer& trappingFunctions) {
  Name helperName = getUnaryFuncName(curr);
  if (!trappingFunctions.hasFunction(helperName)) {
    trappingFunctions.addFunction(generateUnaryFunc(wasm, curr));
  }
}
// Makes sure the imported function F64_TO_INT is registered in
// `trappingFunctions`, adding it only if no import with that name is
// registered yet.
//
// The import is declared as ASM2WASM.<F64_TO_INT> with signature f64 -> i32
// (despite the "I64" in this function's name).
// NOTE(review): the Function is heap-allocated here and handed to addImport()
// as a raw pointer; presumably the container takes ownership -- confirm.
void ensureF64ToI64JSImport(TrappingFunctionContainer& trappingFunctions) {
  if (!trappingFunctions.hasImport(F64_TO_INT)) {
    Function* jsImport = new Function;
    jsImport->module = ASM2WASM;
    jsImport->base = F64_TO_INT;
    jsImport->name = F64_TO_INT;
    jsImport->type = Signature(Type::f64, Type::i32);
    trappingFunctions.addImport(jsImport);
  }
}
// Replaces an integer division/remainder with a call to a generated helper
// function that does not trap.
//
// `curr` is returned unchanged when:
//   - its op is not a div/rem operation (no helper name exists for it),
//   - the container's mode is TrapMode::Allow, or
//   - its type is unreachable (see below).
// Otherwise the helper is generated on first use and a call to it, taking
// curr's two operands and returning curr's type, is returned.
Expression* makeTrappingBinary(Binary* curr,
                               TrappingFunctionContainer& trappingFunctions) {
  Name name = getBinaryFuncName(curr);
  if (!name.is() || trappingFunctions.getMode() == TrapMode::Allow) {
    return curr;
  }
  Type type = curr->type;
  if (type == Type::unreachable) {
    // An unreachable operand makes the whole binary unreachable: it never
    // executes, so it can never trap. It also has no concrete integer type,
    // and the helper's parameters and result are built from curr->type, so
    // generating a helper from this node would give it a bogus signature.
    return curr;
  }
  Module& wasm = trappingFunctions.getModule();
  Builder builder(wasm);
  ensureBinaryFunc(curr, wasm, trappingFunctions);
  return builder.makeCall(name, {curr->left, curr->right}, type);
}
// Replaces a float-to-int truncation with a call that does not trap.
//
// `curr` is returned unchanged when:
//   - its op is not a float-to-int truncation (no helper name exists for it),
//   - the container's mode is TrapMode::Allow, or
//   - its type is unreachable (see below).
//
// In TrapMode::JS, truncations that do not produce an i64 become a call to
// the imported F64_TO_INT function; the operand is promoted to f64 first if
// needed. The same import is used for both signed and unsigned truncations.
// In every other case (i64 results, or any non-JS mode) the truncation
// becomes a call to a generated clamping helper.
Expression* makeTrappingUnary(Unary* curr,
                              TrappingFunctionContainer& trappingFunctions) {
  Name name = getUnaryFuncName(curr);
  TrapMode mode = trappingFunctions.getMode();
  if (!name.is() || mode == TrapMode::Allow) {
    return curr;
  }
  if (curr->type == Type::unreachable) {
    // An unreachable operand makes the truncation unreachable: it never
    // executes, so it can never trap. Its operand is neither f32 nor f64, so
    // ensureDouble() would assert on it, and a clamping helper generated from
    // this node would get a bogus signature.
    return curr;
  }
  Module& wasm = trappingFunctions.getModule();
  Builder builder(wasm);
  if (curr->type != Type::i64 && mode == TrapMode::JS) {
    ensureF64ToI64JSImport(trappingFunctions);
    Expression* f64Value = ensureDouble(curr->value, wasm.allocator);
    return builder.makeCall(F64_TO_INT, {f64Value}, Type::i32);
  }
  ensureUnaryFunc(curr, wasm, trappingFunctions);
  return builder.makeCall(name, {curr->value}, curr->type);
}
// Pass that rewrites every unary and binary expression through
// makeTrappingUnary() / makeTrappingBinary(), then adds the helper functions
// and imports collected along the way to the module.
//
// A fresh TrappingFunctionContainer is created for each module walk.
struct TrapModePass : public WalkerPass<PostWalker<TrapModePass>> {
public:
  // `trapMode` must not be TrapMode::Allow (asserted).
  TrapModePass(TrapMode trapMode) : mode(trapMode) {
    assert(trapMode != TrapMode::Allow);
  }

  // All visits share the single container created in doWalkModule().
  // NOTE(review): presumably non-parallel operation is also what guarantees
  // visitModule() runs after every expression has been visited -- confirm
  // against WalkerPass.
  bool isFunctionParallel() override { return false; }

  // Returns a new pass instance configured with the same mode.
  std::unique_ptr<Pass> create() override {
    return std::make_unique<TrapModePass>(mode);
  }

  // Sets up the per-module container, then runs the normal module walk.
  void doWalkModule(Module* module) {
    trappingFunctions =
      std::make_unique<TrappingFunctionContainer>(mode, *module);
    super::doWalkModule(module);
  }

  void visitUnary(Unary* curr) {
    Expression* replacement = makeTrappingUnary(curr, *trappingFunctions);
    replaceCurrent(replacement);
  }

  void visitBinary(Binary* curr) {
    Expression* replacement = makeTrappingBinary(curr, *trappingFunctions);
    replaceCurrent(replacement);
  }

  // Hands the collected helpers and imports over to the module.
  void visitModule(Module* curr) { trappingFunctions->addToModule(); }

private:
  TrapMode mode;
  std::unique_ptr<TrappingFunctionContainer> trappingFunctions;
};
// Factory: returns a newly allocated TrapModePass configured with
// TrapMode::Clamp. The result is a raw pointer; ownership passes to the
// caller.
Pass* createTrapModeClamp() {
  constexpr TrapMode kMode = TrapMode::Clamp;
  return new TrapModePass(kMode);
}
// Factory: returns a newly allocated TrapModePass configured with
// TrapMode::JS. The result is a raw pointer; ownership passes to the caller.
Pass* createTrapModeJS() {
  constexpr TrapMode kMode = TrapMode::JS;
  return new TrapModePass(kMode);
}
}