#ifndef wasm_wasm_ir_builder_h
#define wasm_wasm_ir_builder_h
#include <array>
#include <optional>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include "ir/names.h"
#include "support/result.h"
#include "wasm-builder.h"
#include "wasm-traversal.h"
#include "wasm-type.h"
#include "wasm.h"
namespace wasm {
// IRBuilder incrementally constructs Binaryen IR. Instructions are supplied one
// at a time through the make*() methods (or via visit() on existing
// expressions), and structured control flow is delimited by the visit*Start()
// methods and visitEnd(). Open constructs are tracked on `scopeStack`, one
// ScopeCtx per construct, each holding its own pending expression stack.
//
// Fallible operations report errors through Result / MaybeResult rather than
// by throwing.
//
// NOTE(review): the exact operand-popping and finalization rules live in the
// implementation file; comments here describe only what this header shows.
class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
public:
  // `func` may be null. Presumably this is for building expressions outside
  // any function body (e.g. initializer expressions) — confirm against callers.
  IRBuilder(Module& wasm, Function* func = nullptr)
    : wasm(wasm), func(func), builder(wasm) {}

  // ---- Driver API -------------------------------------------------------

  // Produces the expression built so far.
  // NOTE(review): finalization details are in the implementation.
  [[nodiscard]] Result<Expression*> build();
  // Feeds an already-constructed expression through the builder.
  [[nodiscard]] Result<> visit(Expression*);
  // Pushes an expression onto the current scope's expression stack.
  void push(Expression*);
  // Records a debug location; stored in `debugLoc` and (presumably) attached
  // to subsequently built expressions via applyDebugLoc().
  void setDebugLocation(const Function::DebugLocation&);

  // ---- Structured control flow -------------------------------------------
  // Each *Start method opens a scope for the given construct.
  // NOTE(review): presumably each is closed by a matching visitEnd().

  [[nodiscard]] Result<> visitFunctionStart(Function* func);
  [[nodiscard]] Result<> visitBlockStart(Block* block);
  [[nodiscard]] Result<> visitIfStart(If* iff, Name label = {});
  [[nodiscard]] Result<> visitElse();
  [[nodiscard]] Result<> visitLoopStart(Loop* loop);
  [[nodiscard]] Result<> visitTryStart(Try* tryy, Name label = {});
  [[nodiscard]] Result<> visitCatch(Name tag);
  [[nodiscard]] Result<> visitCatchAll();
  [[nodiscard]] Result<> visitDelegate(Index label);
  [[nodiscard]] Result<> visitTryTableStart(TryTable* trytable,
                                            Name label = {});
  [[nodiscard]] Result<> visitEnd();

  // Maps a label name to its branch-depth index. `inDelegate` selects the
  // delegate-specific interpretation of depth.
  [[nodiscard]] Result<Index> getLabelIndex(Name label,
                                            bool inDelegate = false);

  // ---- Instruction builders ----------------------------------------------
  // One method per instruction. Immediates are passed as arguments; operands
  // are (presumably) taken from the current scope's expression stack.

  [[nodiscard]] Result<> makeNop();
  [[nodiscard]] Result<> makeBlock(Name label, Type type);
  [[nodiscard]] Result<> makeIf(Name label, Type type);
  [[nodiscard]] Result<> makeLoop(Name label, Type type);
  [[nodiscard]] Result<> makeBreak(Index label, bool isConditional);
  [[nodiscard]] Result<> makeSwitch(const std::vector<Index>& labels,
                                    Index defaultLabel);
  [[nodiscard]] Result<> makeCall(Name func, bool isReturn);
  [[nodiscard]] Result<>
  makeCallIndirect(Name table, HeapType type, bool isReturn);
  [[nodiscard]] Result<> makeLocalGet(Index local);
  [[nodiscard]] Result<> makeLocalSet(Index local);
  [[nodiscard]] Result<> makeLocalTee(Index local);
  [[nodiscard]] Result<> makeGlobalGet(Name global);
  [[nodiscard]] Result<> makeGlobalSet(Name global);
  [[nodiscard]] Result<> makeLoad(unsigned bytes,
                                  bool signed_,
                                  Address offset,
                                  unsigned align,
                                  Type type,
                                  Name mem);
  [[nodiscard]] Result<> makeStore(
    unsigned bytes, Address offset, unsigned align, Type type, Name mem);
  [[nodiscard]] Result<>
  makeAtomicLoad(unsigned bytes, Address offset, Type type, Name mem);
  [[nodiscard]] Result<>
  makeAtomicStore(unsigned bytes, Address offset, Type type, Name mem);
  [[nodiscard]] Result<> makeAtomicRMW(
    AtomicRMWOp op, unsigned bytes, Address offset, Type type, Name mem);
  [[nodiscard]] Result<>
  makeAtomicCmpxchg(unsigned bytes, Address offset, Type type, Name mem);
  [[nodiscard]] Result<> makeAtomicWait(Type type, Address offset, Name mem);
  [[nodiscard]] Result<> makeAtomicNotify(Address offset, Name mem);
  [[nodiscard]] Result<> makeAtomicFence();
  [[nodiscard]] Result<> makeSIMDExtract(SIMDExtractOp op, uint8_t lane);
  [[nodiscard]] Result<> makeSIMDReplace(SIMDReplaceOp op, uint8_t lane);
  [[nodiscard]] Result<> makeSIMDShuffle(const std::array<uint8_t, 16>& lanes);
  [[nodiscard]] Result<> makeSIMDTernary(SIMDTernaryOp op);
  [[nodiscard]] Result<> makeSIMDShift(SIMDShiftOp op);
  [[nodiscard]] Result<>
  makeSIMDLoad(SIMDLoadOp op, Address offset, unsigned align, Name mem);
  [[nodiscard]] Result<> makeSIMDLoadStoreLane(SIMDLoadStoreLaneOp op,
                                               Address offset,
                                               unsigned align,
                                               uint8_t lane,
                                               Name mem);
  [[nodiscard]] Result<> makeMemoryInit(Name data, Name mem);
  [[nodiscard]] Result<> makeDataDrop(Name data);
  [[nodiscard]] Result<> makeMemoryCopy(Name destMem, Name srcMem);
  [[nodiscard]] Result<> makeMemoryFill(Name mem);
  [[nodiscard]] Result<> makeConst(Literal val);
  [[nodiscard]] Result<> makeUnary(UnaryOp op);
  [[nodiscard]] Result<> makeBinary(BinaryOp op);
  [[nodiscard]] Result<> makeSelect(std::optional<Type> type = std::nullopt);
  [[nodiscard]] Result<> makeDrop();
  [[nodiscard]] Result<> makeReturn();
  [[nodiscard]] Result<> makeMemorySize(Name mem);
  [[nodiscard]] Result<> makeMemoryGrow(Name mem);
  [[nodiscard]] Result<> makeUnreachable();
  [[nodiscard]] Result<> makePop(Type type);
  [[nodiscard]] Result<> makeRefNull(HeapType type);
  [[nodiscard]] Result<> makeRefIsNull();
  [[nodiscard]] Result<> makeRefFunc(Name func);
  [[nodiscard]] Result<> makeRefEq();
  [[nodiscard]] Result<> makeTableGet(Name table);
  [[nodiscard]] Result<> makeTableSet(Name table);
  [[nodiscard]] Result<> makeTableSize(Name table);
  [[nodiscard]] Result<> makeTableGrow(Name table);
  [[nodiscard]] Result<> makeTableFill(Name table);
  [[nodiscard]] Result<> makeTableCopy(Name destTable, Name srcTable);
  [[nodiscard]] Result<> makeTry(Name label, Type type);
  // `tags`, `labels` and `isRefs` are parallel vectors, one entry per catch
  // clause. NOTE(review): confirm the equal-length requirement in the .cpp.
  [[nodiscard]] Result<> makeTryTable(Name label,
                                      Type type,
                                      const std::vector<Name>& tags,
                                      const std::vector<Index>& labels,
                                      const std::vector<bool>& isRefs);
  [[nodiscard]] Result<> makeThrow(Name tag);
  [[nodiscard]] Result<> makeRethrow(Index label);
  [[nodiscard]] Result<> makeThrowRef();
  [[nodiscard]] Result<> makeTupleMake(uint32_t arity);
  [[nodiscard]] Result<> makeTupleExtract(uint32_t arity, uint32_t index);
  [[nodiscard]] Result<> makeTupleDrop(uint32_t arity);
  [[nodiscard]] Result<> makeRefI31();
  [[nodiscard]] Result<> makeI31Get(bool signed_);
  [[nodiscard]] Result<> makeCallRef(HeapType type, bool isReturn);
  [[nodiscard]] Result<> makeRefTest(Type type);
  [[nodiscard]] Result<> makeRefCast(Type type);
  [[nodiscard]] Result<>
  makeBrOn(Index label, BrOnOp op, Type in = Type::none, Type out = Type::none);
  [[nodiscard]] Result<> makeStructNew(HeapType type);
  [[nodiscard]] Result<> makeStructNewDefault(HeapType type);
  [[nodiscard]] Result<>
  makeStructGet(HeapType type, Index field, bool signed_);
  [[nodiscard]] Result<> makeStructSet(HeapType type, Index field);
  [[nodiscard]] Result<> makeArrayNew(HeapType type);
  [[nodiscard]] Result<> makeArrayNewDefault(HeapType type);
  [[nodiscard]] Result<> makeArrayNewData(HeapType type, Name data);
  [[nodiscard]] Result<> makeArrayNewElem(HeapType type, Name elem);
  [[nodiscard]] Result<> makeArrayNewFixed(HeapType type, uint32_t arity);
  [[nodiscard]] Result<> makeArrayGet(HeapType type, bool signed_);
  [[nodiscard]] Result<> makeArraySet(HeapType type);
  [[nodiscard]] Result<> makeArrayLen();
  [[nodiscard]] Result<> makeArrayCopy(HeapType destType, HeapType srcType);
  [[nodiscard]] Result<> makeArrayFill(HeapType type);
  [[nodiscard]] Result<> makeArrayInitData(HeapType type, Name data);
  [[nodiscard]] Result<> makeArrayInitElem(HeapType type, Name elem);
  [[nodiscard]] Result<> makeRefAs(RefAsOp op);
  [[nodiscard]] Result<> makeStringNew(StringNewOp op, bool try_, Name mem);
  [[nodiscard]] Result<> makeStringConst(Name string);
  [[nodiscard]] Result<> makeStringMeasure(StringMeasureOp op);
  [[nodiscard]] Result<> makeStringEncode(StringEncodeOp op, Name mem);
  [[nodiscard]] Result<> makeStringConcat();
  [[nodiscard]] Result<> makeStringEq(StringEqOp op);
  [[nodiscard]] Result<> makeStringAs(StringAsOp op);
  [[nodiscard]] Result<> makeStringWTF8Advance();
  [[nodiscard]] Result<> makeStringWTF16Get();
  [[nodiscard]] Result<> makeStringIterNext();
  [[nodiscard]] Result<> makeStringIterMove(StringIterMoveOp op);
  [[nodiscard]] Result<> makeStringSliceWTF(StringSliceWTFOp op);
  [[nodiscard]] Result<> makeStringSliceIter();
  [[nodiscard]] Result<> makeContBind(HeapType contTypeBefore,
                                      HeapType contTypeAfter);
  [[nodiscard]] Result<> makeContNew(HeapType ct);
  // `tags` and `labels` are parallel vectors (one handler per entry).
  // NOTE(review): confirm the equal-length requirement in the .cpp.
  [[nodiscard]] Result<> makeResume(HeapType ct,
                                    const std::vector<Name>& tags,
                                    const std::vector<Index>& labels);

  // ---- Visitor hooks ------------------------------------------------------
  // Overrides for the UnifiedExpressionVisitor base. visitExpression is the
  // generic handler; the others are class-specific hooks. Several accept an
  // optional extra argument (arity, label, or type) that lets callers supply
  // information explicitly.

  [[nodiscard]] Result<> visitExpression(Expression*);
  [[nodiscard]] Result<>
  visitDrop(Drop*, std::optional<uint32_t> arity = std::nullopt);
  [[nodiscard]] Result<> visitIf(If*);
  [[nodiscard]] Result<> visitReturn(Return*);
  [[nodiscard]] Result<> visitStructNew(StructNew*);
  [[nodiscard]] Result<> visitArrayNew(ArrayNew*);
  [[nodiscard]] Result<> visitArrayNewFixed(ArrayNewFixed*);
  [[nodiscard]] Result<> visitBreak(Break*,
                                    std::optional<Index> label = std::nullopt);
  [[nodiscard]] Result<> visitBreakWithType(Break*, Type);
  [[nodiscard]] Result<>
  visitSwitch(Switch*, std::optional<Index> defaultLabel = std::nullopt);
  [[nodiscard]] Result<> visitSwitchWithType(Switch*, Type);
  [[nodiscard]] Result<> visitCall(Call*);
  [[nodiscard]] Result<> visitCallIndirect(CallIndirect*);
  [[nodiscard]] Result<> visitCallRef(CallRef*);
  [[nodiscard]] Result<> visitLocalSet(LocalSet*);
  [[nodiscard]] Result<> visitGlobalSet(GlobalSet*);
  [[nodiscard]] Result<> visitThrow(Throw*);
  [[nodiscard]] Result<> visitStringNew(StringNew*);
  [[nodiscard]] Result<> visitStringEncode(StringEncode*);
  [[nodiscard]] Result<> visitContBind(ContBind*);
  [[nodiscard]] Result<> visitResume(Resume*);
  [[nodiscard]] Result<> visitTupleMake(TupleMake*);
  [[nodiscard]] Result<>
  visitTupleExtract(TupleExtract*,
                    std::optional<uint32_t> arity = std::nullopt);
  [[nodiscard]] Result<> visitPop(Pop*);

private:
  Module& wasm;
  // Function being built into; may be null (see constructor).
  Function* func;
  Builder builder;
  // Pending debug location set by setDebugLocation(), if any.
  std::optional<Function::DebugLocation> debugLoc;

  void applyDebugLoc(Expression* expr);

  // State for one open construct: which construct it is, the label branches
  // should use to target it, and the expressions accumulated inside it.
  struct ScopeCtx {
    // Implicit bottom-of-stack scope, created lazily by getScope() when no
    // construct is open. Label indexing skips it (see getScope(Index)).
    struct NoScope {};
    struct FuncScope {
      Function* func;
    };
    // Blocks and loops keep their original label on the IR node itself
    // (`name`), so these scopes carry no separate originalLabel.
    struct BlockScope {
      Block* block;
    };
    // For the remaining constructs, `originalLabel` is the label as given by
    // the caller; ScopeCtx::label may differ once made unique by pushScope().
    struct IfScope {
      If* iff;
      Name originalLabel;
    };
    struct ElseScope {
      If* iff;
      Name originalLabel;
    };
    struct LoopScope {
      Loop* loop;
    };
    struct TryScope {
      Try* tryy;
      Name originalLabel;
    };
    struct CatchScope {
      Try* tryy;
      Name originalLabel;
    };
    struct CatchAllScope {
      Try* tryy;
      Name originalLabel;
    };
    struct TryTableScope {
      TryTable* trytable;
      Name originalLabel;
    };
    using Scope = std::variant<NoScope,
                               FuncScope,
                               BlockScope,
                               IfScope,
                               ElseScope,
                               LoopScope,
                               TryScope,
                               CatchScope,
                               CatchAllScope,
                               TryTableScope>;

    // Which construct this scope represents.
    Scope scope;
    // Label name branches should use to target this scope. If empty when the
    // scope is pushed, pushScope() fills it with a fresh name derived from the
    // original label.
    Name label;
    // NOTE(review): presumably set when some branch targets `label`.
    bool labelUsed = false;
    // Expressions built in this scope that are still pending (presumably
    // consumed by pop() and finishScope()).
    std::vector<Expression*> exprStack;
    // NOTE(review): presumably set once code in this scope is unreachable.
    bool unreachable = false;

    ScopeCtx() : scope(NoScope{}) {}
    ScopeCtx(Scope scope) : scope(scope) {}
    ScopeCtx(Scope scope, Name label) : scope(scope), label(label) {}

    // Factory helpers. Else/catch/catch_all scopes take an explicit `label`,
    // so the label chosen for the enclosing if/try is reused instead of a new
    // fresh name being generated by pushScope().
    static ScopeCtx makeFunc(Function* func) {
      return ScopeCtx(FuncScope{func});
    }
    static ScopeCtx makeBlock(Block* block) {
      return ScopeCtx(BlockScope{block});
    }
    static ScopeCtx makeIf(If* iff, Name originalLabel = {}) {
      return ScopeCtx(IfScope{iff, originalLabel});
    }
    static ScopeCtx makeElse(If* iff, Name originalLabel, Name label) {
      return ScopeCtx(ElseScope{iff, originalLabel}, label);
    }
    static ScopeCtx makeLoop(Loop* loop) { return ScopeCtx(LoopScope{loop}); }
    static ScopeCtx makeTry(Try* tryy, Name originalLabel = {}) {
      return ScopeCtx(TryScope{tryy, originalLabel});
    }
    static ScopeCtx makeCatch(Try* tryy, Name originalLabel, Name label) {
      return ScopeCtx(CatchScope{tryy, originalLabel}, label);
    }
    static ScopeCtx makeCatchAll(Try* tryy, Name originalLabel, Name label) {
      return ScopeCtx(CatchAllScope{tryy, originalLabel}, label);
    }
    static ScopeCtx makeTryTable(TryTable* trytable, Name originalLabel = {}) {
      return ScopeCtx(TryTableScope{trytable, originalLabel});
    }

    // True for the implicit bottom-of-stack scope.
    bool isNone() { return std::holds_alternative<NoScope>(scope); }

    // Typed accessors: each returns the construct's IR node when this scope
    // is of that kind, and nullptr otherwise.
    Function* getFunction() {
      if (auto* funcScope = std::get_if<FuncScope>(&scope)) {
        return funcScope->func;
      }
      return nullptr;
    }
    Block* getBlock() {
      if (auto* blockScope = std::get_if<BlockScope>(&scope)) {
        return blockScope->block;
      }
      return nullptr;
    }
    If* getIf() {
      if (auto* ifScope = std::get_if<IfScope>(&scope)) {
        return ifScope->iff;
      }
      return nullptr;
    }
    If* getElse() {
      if (auto* elseScope = std::get_if<ElseScope>(&scope)) {
        return elseScope->iff;
      }
      return nullptr;
    }
    Loop* getLoop() {
      if (auto* loopScope = std::get_if<LoopScope>(&scope)) {
        return loopScope->loop;
      }
      return nullptr;
    }
    Try* getTry() {
      if (auto* tryScope = std::get_if<TryScope>(&scope)) {
        return tryScope->tryy;
      }
      return nullptr;
    }
    Try* getCatch() {
      if (auto* catchScope = std::get_if<CatchScope>(&scope)) {
        return catchScope->tryy;
      }
      return nullptr;
    }
    Try* getCatchAll() {
      if (auto* catchAllScope = std::get_if<CatchAllScope>(&scope)) {
        return catchAllScope->tryy;
      }
      return nullptr;
    }
    TryTable* getTryTable() {
      if (auto* tryTableScope = std::get_if<TryTableScope>(&scope)) {
        return tryTableScope->trytable;
      }
      return nullptr;
    }

    // Result type of the construct: the signature's results for a function
    // scope, otherwise the `type` of the construct's IR node. Must not be
    // called on NoScope (hits WASM_UNREACHABLE).
    Type getResultType() {
      if (auto* func = getFunction()) {
        return func->type.getSignature().results;
      }
      if (auto* block = getBlock()) {
        return block->type;
      }
      if (auto* iff = getIf()) {
        return iff->type;
      }
      if (auto* iff = getElse()) {
        return iff->type;
      }
      if (auto* loop = getLoop()) {
        return loop->type;
      }
      if (auto* tryy = getTry()) {
        return tryy->type;
      }
      if (auto* tryy = getCatch()) {
        return tryy->type;
      }
      if (auto* tryy = getCatchAll()) {
        return tryy->type;
      }
      if (auto* trytable = getTryTable()) {
        return trytable->type;
      }
      WASM_UNREACHABLE("unexpected scope kind");
    }

    // The label as originally supplied for this construct, or an empty Name
    // for NoScope and function scopes, which have no label.
    Name getOriginalLabel() {
      if (isNone() || getFunction()) {
        return Name{};
      }
      if (auto* block = getBlock()) {
        return block->name;
      }
      if (auto* ifScope = std::get_if<IfScope>(&scope)) {
        return ifScope->originalLabel;
      }
      if (auto* elseScope = std::get_if<ElseScope>(&scope)) {
        return elseScope->originalLabel;
      }
      if (auto* loop = getLoop()) {
        return loop->name;
      }
      if (auto* tryScope = std::get_if<TryScope>(&scope)) {
        return tryScope->originalLabel;
      }
      if (auto* catchScope = std::get_if<CatchScope>(&scope)) {
        return catchScope->originalLabel;
      }
      if (auto* catchAllScope = std::get_if<CatchAllScope>(&scope)) {
        return catchAllScope->originalLabel;
      }
      if (auto* tryTableScope = std::get_if<TryTableScope>(&scope)) {
        return tryTableScope->originalLabel;
      }
      WASM_UNREACHABLE("unexpected scope kind");
    }
  };

  // Open scopes, innermost last.
  std::vector<ScopeCtx> scopeStack;

  // Keyed by label name. For an original label, the vector records the
  // (1-based) scope-stack depth of each open scope that declared it.
  // makeFresh() also inserts generated names here, with no depths, so that
  // they are never handed out twice.
  std::unordered_map<Name, std::vector<Index>> labelDepths;

  // Returns a name derived from `label` that is not yet a key of labelDepths
  // and reserves it there. Candidate generation is delegated to
  // Names::getValidName.
  Name makeFresh(Name label) {
    return Names::getValidName(label, [&](Name candidate) {
      // insert() succeeds only if `candidate` was not already present.
      return labelDepths.insert({candidate, {}}).second;
    });
  }

  // Opens `scope`. If the construct has an original label, this ensures the
  // scope has a unique label name (keeping one the caller already supplied)
  // and records the depth at which the original label was bound. That depth
  // is the stack size after the push.
  void pushScope(ScopeCtx scope) {
    if (auto label = scope.getOriginalLabel()) {
      if (!scope.label) {
        scope.label = makeFresh(label);
      }
      labelDepths[label].push_back(scopeStack.size() + 1);
    }
    // `scope` is our own by-value copy; move it so its exprStack is not
    // duplicated.
    scopeStack.push_back(std::move(scope));
  }

  // Innermost scope. When no scope is open, this lazily creates the implicit
  // NoScope at the bottom of the stack. Because that may push, pointers or
  // references previously obtained into scopeStack can be invalidated.
  ScopeCtx& getScope() {
    if (scopeStack.empty()) {
      scopeStack.push_back({});
    }
    return scopeStack.back();
  }

  // Scope targeted by branch index `label`, where 0 is the innermost scope.
  // The implicit NoScope at the bottom of the stack is not a valid target.
  Result<ScopeCtx*> getScope(Index label) {
    Index numLabels = scopeStack.size();
    if (!scopeStack.empty() && scopeStack[0].isNone()) {
      --numLabels;
    }
    if (label >= numLabels) {
      return Err{"label index out of bounds"};
    }
    return &scopeStack[scopeStack.size() - 1 - label];
  }

  // NOTE(review): unlike the neighboring helpers this is not [[nodiscard]];
  // consider adding it once all call sites are confirmed to check the result.
  Result<Expression*> finishScope(Block* block = nullptr);
  [[nodiscard]] Result<Name> getLabelName(Index label);
  [[nodiscard]] Result<Name> getDelegateLabelName(Index label);
  [[nodiscard]] Result<Index> addScratchLocal(Type);
  [[nodiscard]] Result<Expression*> pop(size_t size = 1);

  // A value hoisted by hoistLastValue().
  // NOTE(review): presumably `valIndex` indexes the current exprStack and
  // `get` is non-null when the value was spilled to a scratch local.
  struct HoistedVal {
    Index valIndex;
    LocalGet* get;
  };

  [[nodiscard]] MaybeResult<HoistedVal> hoistLastValue();
  [[nodiscard]] Result<> packageHoistedValue(const HoistedVal&,
                                             size_t sizeHint = 1);
  [[nodiscard]] Result<Expression*>
  getBranchValue(Expression* curr, Name labelName, std::optional<Index> label);

  // Debugging aid; presumably prints the builder's internal state.
  void dump();
};
}
#endif