#include "ir/manipulation.h"
#include "ir/module-utils.h"
#include "ir/utils.h"
#include "pass.h"
#include "wasm-binary.h"
#include "wasm-builder.h"
#include "wasm.h"
namespace wasm {
// A contiguous span [start, end) of a segment's data. `isZero` marks spans in
// which every byte is zero; those are produced with memory.fill (or omitted)
// instead of being stored in a segment.
struct Range {
  // Whether every byte in [start, end) is zero.
  bool isZero;
  // Offset of the first byte of the span, relative to the segment's data.
  size_t start;
  // Offset one past the last byte of the span.
  size_t end;
};
// Every memory.init / data.drop expression that names one particular segment.
using Referrers = std::vector<Expression*>;
// Deferred rewrite of one bulk memory expression. It is invoked with the
// function containing the expression (so locals can be added to it) and
// returns the expression that should take the original's place.
using Replacement = std::function<Expression*(Function*)>;
// Pending rewrite for each original bulk memory expression.
using Replacements = std::unordered_map<Expression*, Replacement>;
// Estimated encoded sizes, in bytes, of the bulk memory instructions that
// splitting a segment adds. calculateRanges uses them as a rough code-size
// cost when deciding whether a run of zero bytes is long enough to split out.
constexpr size_t MEMORY_INIT_SIZE = 10;
constexpr size_t MEMORY_FILL_SIZE = 9;
constexpr size_t DATA_DROP_SIZE = 3;
namespace {

// Builds `dest > (memory.size << 16)` for `curr`'s destination operand, i.e.
// whether the destination lies past the end of memory (memory.size counts
// 64KiB pages, so shifting left by 16 converts it to bytes). The comparison
// and shift use 64-bit operations for a 64-bit memory and 32-bit otherwise.
//
// NOTE(review): for a 32-bit memory of exactly 65536 pages the i32 shift wraps
// to 0, so every nonzero dest would compare as out of bounds — confirm whether
// that configuration can reach this pass.
Expression*
makeGtShiftedMemorySize(Builder& builder, Module& module, MemoryInit* curr) {
  const bool is64 = module.memory.is64();
  Expression* memoryBytes =
    builder.makeBinary(is64 ? ShlInt64 : ShlInt32,
                       builder.makeMemorySize(),
                       builder.makeConstPtr(16));
  return builder.makeBinary(
    is64 ? GtUInt64 : GtUInt32, curr->dest, memoryBytes);
}

} // anonymous namespace
struct MemoryPacking : public Pass {
size_t dropStateGlobalCount = 0;
uint32_t maxSegments;
void run(PassRunner* runner, Module* module) override;
void optimizeBulkMemoryOps(PassRunner* runner, Module* module);
void getSegmentReferrers(Module* module, std::vector<Referrers>& referrers);
void dropUnusedSegments(std::vector<Memory::Segment>& segments,
std::vector<Referrers>& referrers);
bool canSplit(const Memory::Segment& segment, const Referrers& referrers);
void calculateRanges(const Memory::Segment& segment,
const Referrers& referrers,
std::vector<Range>& ranges);
void createSplitSegments(Builder& builder,
const Memory::Segment& segment,
std::vector<Range>& ranges,
std::vector<Memory::Segment>& packed,
size_t segmentsRemaining);
void createReplacements(Module* module,
const std::vector<Range>& ranges,
const Referrers& referrers,
Replacements& replacements,
const Index segmentIndex);
void replaceBulkMemoryOps(PassRunner* runner,
Module* module,
Replacements& replacements);
};
// Entry point. It runs the following steps in order:
//  1. With bulk memory, simplify bulk memory instructions, collect each
//     segment's referrers and drop unused passive segments.
//  2. Split every segment into new segments (appended to `packed`) and record
//     how its referrers must be rewritten.
//  3. Install the new segments and, with bulk memory, apply the rewrites.
void MemoryPacking::run(PassRunner* runner, Module* module) {
  if (!module->memory.exists) {
    return;
  }

  const bool bulkMemory = module->features.hasBulkMemory();

  // NOTE(review): the reason for the limit of 63 with bulk memory is not
  // evident here — confirm against the tools that consume the output.
  maxSegments = bulkMemory ? 63 : uint32_t(WebLimitations::MaxDataSegments);

  auto& segments = module->memory.segments;

  // referrers[i] lists the memory.init / data.drop instructions naming segment
  // i. Without bulk memory there are none, so every entry stays empty.
  std::vector<Referrers> referrers(segments.size());
  if (bulkMemory) {
    optimizeBulkMemoryOps(runner, module);
    getSegmentReferrers(module, referrers);
    dropUnusedSegments(segments, referrers);
  }

  std::vector<Memory::Segment> packed;
  Replacements replacements;
  Builder builder(*module);

  const size_t numSegments = segments.size();
  for (size_t idx = 0; idx < numSegments; ++idx) {
    const Memory::Segment& segment = segments[idx];
    const Referrers& segmentReferrers = referrers[idx];

    std::vector<Range> ranges;
    if (canSplit(segment, segmentReferrers)) {
      calculateRanges(segment, segmentReferrers, ranges);
    } else {
      // Keep the segment whole as a single nonzero range.
      ranges.push_back({false, 0, segment.data.size()});
    }

    // The segments created for this one are numbered from the current end of
    // `packed`.
    const Index firstNewIndex = packed.size();
    createSplitSegments(builder, segment, ranges, packed, numSegments - idx);
    createReplacements(
      module, ranges, segmentReferrers, replacements, firstNewIndex);
  }

  segments.swap(packed);

  if (bulkMemory) {
    replaceBulkMemoryOps(runner, module, replacements);
  }
}
// Returns whether `segment` may be split into several segments around its runs
// of zero bytes.
//
// - Empty segments are left whole. Splitting one yields no ranges, and so no
//   replacement segment at all. For an active segment that would also remove
//   the bounds check on its offset, which the WebAssembly spec performs at
//   instantiation even for zero-length segments.
// - A passive segment can be split only if every memory.init of it has a
//   constant offset and size, so each init can be rewritten statically.
// - An active segment can be split only if its offset is a constant, so the
//   offsets of the new segments can be computed.
bool MemoryPacking::canSplit(const Memory::Segment& segment,
                             const Referrers& referrers) {
  if (segment.data.empty()) {
    return false;
  }
  if (!segment.isPassive) {
    return segment.offset->is<Const>();
  }
  for (auto* referrer : referrers) {
    if (auto* init = referrer->dynCast<MemoryInit>()) {
      if (!init->offset->is<Const>() || !init->size->is<Const>()) {
        return false;
      }
    }
  }
  return true;
}
// Partitions `segment`'s data into maximal alternating runs of zero and nonzero
// bytes, appending them to `ranges`. It then merges runs of zeroes that are too
// short to be worth splitting out into the neighboring nonzero ranges. Leaves
// `ranges` empty when the segment has no data.
//
// The merge threshold is a rough code-size estimate.
// - Passive segment: the threshold grows with each memory.init and data.drop
//   referrer, since an interior split adds instructions to every referrer.
//   Zero runs at either edge use a lower threshold, which counts only an extra
//   memory.fill per memory.init.
// - Active segment: the threshold is a fixed 8 bytes, and edge zero runs are
//   not merged.
//
// NOTE(review): zero ranges of active segments end up with no data and no
// memory.fill, so this presumably assumes memory starts zeroed — confirm for
// imported memories.
void MemoryPacking::calculateRanges(const Memory::Segment& segment,
                                    const Referrers& referrers,
                                    std::vector<Range>& ranges) {
  auto& data = segment.data;
  if (data.size() == 0) {
    return;
  }
  // Each outer iteration emits one maximal zero run (if any) followed by one
  // maximal nonzero run (if any), so the ranges alternate and are contiguous.
  size_t start = 0;
  while (start < data.size()) {
    size_t end = start;
    while (end < data.size() && data[end] == 0) {
      end++;
    }
    if (end > start) {
      ranges.push_back({true, start, end});
      start = end;
    }
    while (end < data.size() && data[end] != 0) {
      end++;
    }
    if (end > start) {
      ranges.push_back({false, start, end});
      start = end;
    }
  }
  // A zero run is kept separate only if it is longer than `threshold` bytes.
  size_t threshold = 0;
  if (segment.isPassive) {
    // Estimated fixed overhead of one more passive segment.
    threshold += 2;
    size_t edgeThreshold = 0;
    for (auto* referrer : referrers) {
      if (referrer->is<MemoryInit>()) {
        // An interior split adds a memory.fill and a memory.init to each
        // memory.init; an edge split adds only the memory.fill.
        threshold += MEMORY_FILL_SIZE + MEMORY_INIT_SIZE;
        edgeThreshold += MEMORY_FILL_SIZE;
      } else {
        // ...and another data.drop to each data.drop.
        threshold += DATA_DROP_SIZE;
      }
    }
    // Fold a short trailing zero run into the preceding (nonzero) range.
    if (ranges.size() >= 2) {
      auto last = ranges.end() - 1;
      auto penultimate = ranges.end() - 2;
      if (last->isZero && last->end - last->start <= edgeThreshold) {
        penultimate->end = last->end;
        ranges.erase(last);
      }
    }
    // Fold a short leading zero run into the following (nonzero) range.
    if (ranges.size() >= 2) {
      auto first = ranges.begin();
      auto second = ranges.begin() + 1;
      if (first->isZero && first->end - first->start <= edgeThreshold) {
        second->start = first->start;
        ranges.erase(first);
      }
    }
  } else {
    // Fixed estimate for active segments.
    threshold = 8;
  }
  // Merge short interior zero runs. Because ranges alternate, a zero `curr` is
  // flanked by nonzero ranges. `left` (the last merged range) absorbs both
  // `curr` and `right`, and the extra ++i skips `right`.
  std::vector<Range> mergedRanges = {ranges.front()};
  size_t i;
  for (i = 1; i < ranges.size() - 1; ++i) {
    auto left = mergedRanges.end() - 1;
    auto curr = ranges.begin() + i;
    auto right = ranges.begin() + i + 1;
    if (curr->isZero && curr->end - curr->start <= threshold) {
      left->end = right->end;
      ++i;
    } else {
      mergedRanges.push_back(*curr);
    }
  }
  // Append the final range unless the loop already merged it into `left`.
  if (i < ranges.size()) {
    mergedRanges.push_back(ranges.back());
  }
  std::swap(ranges, mergedRanges);
}
// Simplifies bulk memory instructions whose outcome is known statically,
// before segments are analyzed and split:
//  - a memory.init that must trap (constant offset/size past the readable
//    length) becomes drops of its operands followed by unreachable;
//  - a memory.init with constant offset 0 and size 0 becomes a bounds check on
//    its destination only;
//  - any other memory.init of an active segment becomes a check that traps
//    unless offset and size are both 0 and the destination is in bounds;
//  - a data.drop of an active segment becomes a nop.
void MemoryPacking::optimizeBulkMemoryOps(PassRunner* runner, Module* module) {
  struct Optimizer : WalkerPass<PostWalker<Optimizer>> {
    bool isFunctionParallel() override { return true; }
    Pass* create() override { return new Optimizer; }

    // Set when a replacement introduced an unconditional trap, so the
    // function's types are recomputed after the walk.
    bool needsRefinalizing;

    void visitMemoryInit(MemoryInit* curr) {
      Builder builder(*getModule());
      Memory::Segment& segment = getModule()->memory.segments[curr->segment];
      // Bytes a memory.init may read from this segment. Active segments are
      // treated as having length 0 (presumably because they are dropped once
      // instantiation has applied them).
      size_t maxRuntimeSize = segment.isPassive ? segment.data.size() : 0;
      bool mustNop = false;
      bool mustTrap = false;
      auto* offset = curr->offset->dynCast<Const>();
      auto* size = curr->size->dynCast<Const>();
      // Offset and size are unsigned i32 operands. Read them as uint32_t so
      // values >= 2^31 are not sign-extended when widened below.
      if (offset && uint32_t(offset->value.geti32()) > maxRuntimeSize) {
        mustTrap = true;
      }
      if (size && uint32_t(size->value.geti32()) > maxRuntimeSize) {
        mustTrap = true;
      }
      if (offset && size) {
        uint64_t offsetVal = uint32_t(offset->value.geti32());
        uint64_t sizeVal = uint32_t(size->value.geti32());
        if (offsetVal + sizeVal > maxRuntimeSize) {
          mustTrap = true;
        } else if (offsetVal == 0 && sizeVal == 0) {
          mustNop = true;
        }
      }
      assert(!mustNop || !mustTrap);
      if (mustNop) {
        // Nothing is copied, but an out-of-bounds destination still traps.
        replaceCurrent(
          builder.makeIf(makeGtShiftedMemorySize(builder, *getModule(), curr),
                         builder.makeUnreachable()));
      } else if (mustTrap) {
        // Evaluate (and discard) the operands, then trap.
        replaceCurrent(builder.blockify(builder.makeDrop(curr->dest),
                                        builder.makeDrop(curr->offset),
                                        builder.makeDrop(curr->size),
                                        builder.makeUnreachable()));
        needsRefinalizing = true;
      } else if (!segment.isPassive) {
        // Trap if dest is out of bounds or offset / size is nonzero.
        replaceCurrent(builder.makeIf(
          builder.makeBinary(
            OrInt32,
            makeGtShiftedMemorySize(builder, *getModule(), curr),
            builder.makeBinary(OrInt32, curr->offset, curr->size)),
          builder.makeUnreachable()));
      }
    }
    void visitDataDrop(DataDrop* curr) {
      // A data.drop of an active segment is removed.
      if (!getModule()->memory.segments[curr->segment].isPassive) {
        ExpressionManipulator::nop(curr);
      }
    }
    void doWalkFunction(Function* func) {
      needsRefinalizing = false;
      super::doWalkFunction(func);
      if (needsRefinalizing) {
        ReFinalize().walkFunctionInModule(func, getModule());
      }
    }
  } optimizer;
  optimizer.run(runner, module);
}
// Fills `referrers` so that referrers[i] lists every memory.init and data.drop
// in the module's defined functions that names segment i. Functions are
// scanned in parallel, each into its own per-segment table. The tables are
// then concatenated in the analysis map's iteration order.
void MemoryPacking::getSegmentReferrers(Module* module,
                                        std::vector<Referrers>& referrers) {
  auto scanFunction = [&](Function* func, std::vector<Referrers>& found) {
    // Imported functions have no body to scan.
    if (func->imported()) {
      return;
    }
    struct Collector : WalkerPass<PostWalker<Collector>> {
      std::vector<Referrers>& table;
      Collector(std::vector<Referrers>& table) : table(table) {}

      void doWalkFunction(Function* func) {
        // One slot per segment, so the visitors can index by segment number.
        table.resize(getModule()->memory.segments.size());
        super::doWalkFunction(func);
      }
      void visitMemoryInit(MemoryInit* curr) {
        table[curr->segment].push_back(curr);
      }
      void visitDataDrop(DataDrop* curr) {
        table[curr->segment].push_back(curr);
      }
    } collector(found);
    collector.walkFunctionInModule(func, module);
  };
  ModuleUtils::ParallelFunctionAnalysis<std::vector<Referrers>> analysis(
    *module, scanFunction);

  referrers.resize(module->memory.segments.size());
  for (auto& entry : analysis.map) {
    const std::vector<Referrers>& perSegment = entry.second;
    for (size_t seg = 0; seg < perSegment.size(); ++seg) {
      Referrers& all = referrers[seg];
      all.insert(all.end(), perSegment[seg].begin(), perSegment[seg].end());
    }
  }
}
// Removes passive segments that no memory.init reads, turning their remaining
// referrers (all data.drops) into nops. Active segments are always kept.
// `segments` and `referrers` are filtered in lockstep so indices still match.
void MemoryPacking::dropUnusedSegments(std::vector<Memory::Segment>& segments,
                                       std::vector<Referrers>& referrers) {
  std::vector<Memory::Segment> keptSegments;
  std::vector<Referrers> keptReferrers;
  for (size_t idx = 0; idx < segments.size(); ++idx) {
    Referrers& users = referrers[idx];
    bool keep = !segments[idx].isPassive;
    for (size_t j = 0; !keep && j < users.size(); ++j) {
      keep = users[j]->is<MemoryInit>();
    }
    if (!keep) {
      // No memory.init reads this segment, so its data.drops do nothing.
      for (auto* user : users) {
        ExpressionManipulator::nop(user);
      }
      continue;
    }
    keptSegments.push_back(segments[idx]);
    keptReferrers.push_back(users);
  }
  segments.swap(keptSegments);
  referrers.swap(keptReferrers);
}
// Appends to `packed` one new segment per nonzero entry of `ranges`, holding
// that range's bytes of `segment`. Zero ranges produce no segment. Active
// segments get their offset shifted by the range's start.
//
// Splitting may exceed `maxSegments`, counting the `segmentsRemaining`
// original segments still to be processed (including this one). In that case
// the current range is extended over every remaining range up to the last
// nonzero one, and those ranges are erased. `ranges` thus keeps describing
// exactly the emitted segments, which createReplacements relies on.
void MemoryPacking::createSplitSegments(Builder& builder,
                                        const Memory::Segment& segment,
                                        std::vector<Range>& ranges,
                                        std::vector<Memory::Segment>& packed,
                                        size_t segmentsRemaining) {
  for (size_t i = 0; i < ranges.size(); ++i) {
    Range& range = ranges[i];
    if (range.isZero) {
      continue;
    }
    Expression* offset = nullptr;
    if (!segment.isPassive) {
      if (auto* c = segment.offset->dynCast<Const>()) {
        // Keep the offset's own type: 64-bit memories use i64 offsets, for
        // which geti32() is not valid.
        if (c->value.type == Type::i64) {
          offset = builder.makeConst(int64_t(c->value.geti64() + range.start));
        } else {
          offset = builder.makeConst(int32_t(c->value.geti32() + range.start));
        }
      } else {
        // canSplit rejects nonconstant offsets, so the segment is whole here.
        assert(ranges.size() == 1);
        offset = segment.offset;
      }
    }
    if (maxSegments <= packed.size() + segmentsRemaining) {
      // Give up splitting and merge all remaining ranges except end zeroes.
      auto lastNonzero = ranges.end() - 1;
      if (lastNonzero->isZero) {
        --lastNonzero;
      }
      range.end = lastNonzero->end;
      ranges.erase(ranges.begin() + i + 1, lastNonzero + 1);
    }
    // data() + start instead of &data[start]: indexing is undefined when the
    // segment's data is empty.
    packed.emplace_back(segment.isPassive,
                        offset,
                        segment.data.data() + range.start,
                        range.end - range.start);
  }
}
// Records in `replacements` how each referrer of one original segment must be
// rewritten. The segment's data has been emitted as one new segment per
// nonzero entry of `ranges`, numbered consecutively from `segmentIndex`.
//
// - If the segment was not split (a single nonzero range), referrers are only
//   renumbered.
// - Otherwise each memory.init becomes a sequence of memory.init (nonzero
//   ranges) and memory.fill (zero ranges) covering the same bytes. Each
//   data.drop drops every new segment.
// - A memory.fill does not consult a segment, so a mutable global tracking
//   whether the segment was dropped is created on demand. memory.inits that
//   start with a fill, or that have zero length, check it.
//
// The replacements are closures invoked later by replaceBulkMemoryOps with the
// function being rewritten, so locals can be added to it then.
void MemoryPacking::createReplacements(Module* module,
                                       const std::vector<Range>& ranges,
                                       const Referrers& referrers,
                                       Replacements& replacements,
                                       const Index segmentIndex) {
  // If there was no transformation, only update the indices.
  if (ranges.size() == 1 && !ranges.front().isZero) {
    for (auto referrer : referrers) {
      replacements[referrer] = [referrer, segmentIndex](Function*) {
        if (auto* init = referrer->dynCast<MemoryInit>()) {
          init->segment = segmentIndex;
        } else if (auto* drop = referrer->dynCast<DataDrop>()) {
          drop->segment = segmentIndex;
        } else {
          WASM_UNREACHABLE("Unexpected bulk memory operation");
        }
        return referrer;
      };
    }
    return;
  }

  Builder builder(*module);

  // The destination of memory.init and memory.fill, and the size of
  // memory.fill, have the memory's index type (i64 for 64-bit memories).
  // memory.init's offset and size remain i32.
  const bool is64 = module->memory.is64();
  const Type ptrType = is64 ? Type::i64 : Type::i32;
  auto makePtrConst = [&](uint64_t value) -> Expression* {
    if (is64) {
      return builder.makeConst(int64_t(value));
    }
    return builder.makeConst(int32_t(value));
  };

  Name dropStateGlobal;

  // Returns the drop state global, creating it on first use. This adds a global
  // to the module and sets `dropStateGlobal`, so it must be called eagerly
  // here, not inside the deferred replacements.
  auto getDropStateGlobal = [&]() {
    if (dropStateGlobal != Name()) {
      return dropStateGlobal;
    }
    dropStateGlobal = Name(std::string("__mem_segment_drop_state_") +
                           std::to_string(dropStateGlobalCount++));
    module->addGlobal(builder.makeGlobal(dropStateGlobal,
                                         Type::i32,
                                         builder.makeConst(int32_t(0)),
                                         Builder::Mutable));
    return dropStateGlobal;
  };

  // memory.inits first, so that the data.drops below know whether a drop
  // state global is needed.
  for (auto referrer : referrers) {
    auto* init = referrer->dynCast<MemoryInit>();
    if (init == nullptr) {
      continue;
    }

    // Nonconstant offsets or sizes would have inhibited splitting. Both are
    // unsigned i32 values, so read them without sign extension.
    size_t start = uint32_t(init->offset->cast<Const>()->value.geti32());
    size_t end = start + uint32_t(init->size->cast<Const>()->value.geti32());

    // Find the first range overlapping the init. Also count the new segments
    // of the nonzero ranges that lie entirely before it, so the memory.inits
    // emitted below read from the correct segments.
    size_t firstRangeIdx = 0;
    size_t initIndex = segmentIndex;
    while (firstRangeIdx < ranges.size() &&
           ranges[firstRangeIdx].end <= start) {
      if (!ranges[firstRangeIdx].isZero) {
        ++initIndex;
      }
      ++firstRangeIdx;
    }

    // Zero-length memory.inits are handled separately, so below we may assume
    // that `start` is in bounds and some range is intersected.
    if (start == end) {
      // A 0/0 init would already have been optimized away, so the offset is
      // nonzero. Trap if dest is out of bounds or the segment was dropped.
      Expression* result = builder.makeIf(
        builder.makeBinary(
          OrInt32,
          makeGtShiftedMemorySize(builder, *module, init),
          builder.makeGlobalGet(getDropStateGlobal(), Type::i32)),
        builder.makeUnreachable());
      replacements[init] = [result](Function*) { return result; };
      continue;
    }

    assert(firstRangeIdx < ranges.size());

    Expression* result = nullptr;
    auto appendResult = [&](Expression* expr) {
      result = result ? builder.blockify(result, expr) : expr;
    };

    // A nonconstant dest is stored in a local that is only allocated at
    // replacement time; remember where its index must be patched in.
    Index* setVar = nullptr;
    std::vector<Index*> getVars;
    if (!init->dest->is<Const>()) {
      auto set = builder.makeLocalSet(-1, init->dest);
      setVar = &set->index;
      appendResult(set);
    }

    // A leading memory.fill would write zeroes even if the segment was
    // dropped, so check the drop state explicitly. A memory.init does that
    // check itself.
    if (ranges[firstRangeIdx].isZero) {
      appendResult(
        builder.makeIf(builder.makeGlobalGet(getDropStateGlobal(), Type::i32),
                       builder.makeUnreachable()));
    }

    size_t bytesWritten = 0;
    for (size_t i = firstRangeIdx; i < ranges.size() && ranges[i].start < end;
         ++i) {
      auto& range = ranges[i];

      // Destination of this piece: the original dest advanced by the bytes
      // already written.
      Expression* dest;
      if (auto* c = init->dest->dynCast<Const>()) {
        uint64_t base = is64 ? uint64_t(c->value.geti64())
                             : uint64_t(uint32_t(c->value.geti32()));
        dest = makePtrConst(base + bytesWritten);
      } else {
        auto* get = builder.makeLocalGet(-1, ptrType);
        getVars.push_back(&get->index);
        dest = get;
        if (bytesWritten > 0) {
          dest = builder.makeBinary(
            is64 ? AddInt64 : AddInt32, dest, makePtrConst(bytesWritten));
        }
      }

      // Bytes of the init that fall inside this range.
      size_t bytes = std::min(range.end, end) - std::max(range.start, start);
      bytesWritten += bytes;

      if (range.isZero) {
        Expression* value = builder.makeConst(Literal::makeZero(Type::i32));
        appendResult(builder.makeMemoryFill(dest, value, makePtrConst(bytes)));
      } else {
        size_t offsetBytes = std::max(start, range.start) - range.start;
        Expression* offset = builder.makeConst(int32_t(offsetBytes));
        Expression* size = builder.makeConst(int32_t(bytes));
        appendResult(builder.makeMemoryInit(initIndex, dest, offset, size));
        initIndex++;
      }
    }

    // Non-zero length memory.inits must have intersected some range.
    assert(result);
    replacements[init] = [module, setVar, getVars, result, ptrType](
                           Function* function) {
      if (setVar != nullptr) {
        Index destVar = Builder(*module).addVar(function, ptrType);
        *setVar = destVar;
        for (auto* getVar : getVars) {
          *getVar = destVar;
        }
      }
      return result;
    };
  }

  // data.drops: record the drop in the global (if any memory.init needed it)
  // and drop every new segment.
  for (auto drop : referrers) {
    if (!drop->is<DataDrop>()) {
      continue;
    }

    Expression* result = nullptr;
    auto appendResult = [&](Expression* expr) {
      result = result ? builder.blockify(result, expr) : expr;
    };

    if (dropStateGlobal != Name()) {
      appendResult(
        builder.makeGlobalSet(dropStateGlobal, builder.makeConst(int32_t(1))));
    }
    size_t dropIndex = segmentIndex;
    for (auto range : ranges) {
      if (!range.isZero) {
        appendResult(builder.makeDataDrop(dropIndex++));
      }
    }
    replacements[drop] = [result, module](Function*) {
      return result ? result : Builder(*module).makeNop();
    };
  }
}
// Walks every function (in parallel) and replaces each memory.init and
// data.drop with the expression produced by its entry in `replacements`. Every
// such expression is expected to have an entry.
void MemoryPacking::replaceBulkMemoryOps(PassRunner* runner,
                                         Module* module,
                                         Replacements& replacements) {
  struct Replacer : WalkerPass<PostWalker<Replacer>> {
    Replacements& replacements;

    Replacer(Replacements& replacements) : replacements(replacements) {}

    bool isFunctionParallel() override { return true; }
    Pass* create() override { return new Replacer(replacements); }

    // Invokes the recorded replacement for `curr` with the current function
    // and substitutes its result.
    void applyReplacement(Expression* curr) {
      auto found = replacements.find(curr);
      assert(found != replacements.end());
      replaceCurrent(found->second(getFunction()));
    }

    void visitMemoryInit(MemoryInit* curr) { applyReplacement(curr); }
    void visitDataDrop(DataDrop* curr) { applyReplacement(curr); }
  } replacer(replacements);
  replacer.run(runner, module);
}
// Returns a newly allocated MemoryPacking pass.
Pass* createMemoryPackingPass() {
  return new MemoryPacking();
}
}