#include "Luau/OptimizeDeadStore.h"
#include "Luau/IrBuilder.h"
#include "Luau/IrVisitUseDef.h"
#include "Luau/IrUtils.h"
#include <array>
#include "lobject.h"
namespace Luau
{
namespace CodeGen
{
// Sentinel stored in StoreRegInfo::knownTag when the register's tag is not known
constexpr uint8_t kUnknownTag = 0xff;

// Dead store tracking data for a single VM register
struct StoreRegInfo
{
    // Instruction indices of the most recent stores that have not been read yet; ~0u means "none"
    uint32_t tagInstIdx{~0u};    // partial tag store
    uint32_t valueInstIdx{~0u};  // partial value store
    uint32_t tvalueInstIdx{~0u}; // store tracked as writing the whole TValue

    // Set when the register might currently hold a GC object
    bool maybeGco{false};

    // Last tag established for the register, or kUnknownTag
    uint8_t knownTag{kUnknownTag};
};
struct RemoveDeadStoreState
{
RemoveDeadStoreState(IrFunction& function)
: function(function)
{
maxReg = function.proto ? function.proto->maxstacksize : 255;
}
void killTagStore(StoreRegInfo& regInfo)
{
if (regInfo.tagInstIdx != ~0u)
{
kill(function, function.instructions[regInfo.tagInstIdx]);
regInfo.tagInstIdx = ~0u;
regInfo.maybeGco = false;
}
}
void killValueStore(StoreRegInfo& regInfo)
{
if (regInfo.valueInstIdx != ~0u)
{
kill(function, function.instructions[regInfo.valueInstIdx]);
regInfo.valueInstIdx = ~0u;
regInfo.maybeGco = false;
}
}
void killTagAndValueStorePair(StoreRegInfo& regInfo)
{
bool tagEstablished = regInfo.tagInstIdx != ~0u || regInfo.knownTag != kUnknownTag;
bool valueEstablished = regInfo.valueInstIdx != ~0u || regInfo.knownTag == LUA_TNIL;
if (tagEstablished && valueEstablished)
{
if (regInfo.tagInstIdx != ~0u)
{
kill(function, function.instructions[regInfo.tagInstIdx]);
regInfo.tagInstIdx = ~0u;
}
if (regInfo.valueInstIdx != ~0u)
{
kill(function, function.instructions[regInfo.valueInstIdx]);
regInfo.valueInstIdx = ~0u;
}
regInfo.maybeGco = false;
}
}
void killTValueStore(StoreRegInfo& regInfo)
{
if (regInfo.tvalueInstIdx != ~0u)
{
kill(function, function.instructions[regInfo.tvalueInstIdx]);
regInfo.tvalueInstIdx = ~0u;
regInfo.maybeGco = false;
}
}
void defReg(uint8_t reg)
{
StoreRegInfo& regInfo = info[reg];
if (function.cfg.captured.regs.test(reg))
return;
killTagAndValueStorePair(regInfo);
killTValueStore(regInfo);
regInfo.knownTag = kUnknownTag;
}
void useReg(uint8_t reg)
{
StoreRegInfo& regInfo = info[reg];
regInfo.tagInstIdx = ~0u;
regInfo.valueInstIdx = ~0u;
regInfo.tvalueInstIdx = ~0u;
regInfo.maybeGco = false;
}
void checkLiveIns(IrOp op)
{
if (op.kind == IrOpKind::VmExit)
{
readAllRegs();
}
else if (op.kind == IrOpKind::Block)
{
if (op.index < function.cfg.in.size())
{
const RegisterSet& in = function.cfg.in[op.index];
for (int i = 0; i <= maxReg; i++)
{
if (in.regs.test(i) || (in.varargSeq && i >= in.varargStart))
useReg(i);
}
}
else
{
readAllRegs();
}
}
else if (op.kind == IrOpKind::Undef)
{
}
else
{
CODEGEN_ASSERT(!"unexpected jump target type");
}
}
void checkLiveOuts(const IrBlock& block)
{
uint32_t index = function.getBlockIndex(block);
if (index < function.cfg.out.size())
{
const RegisterSet& out = function.cfg.out[index];
for (int i = 0; i <= maxReg; i++)
{
bool isOut = out.regs.test(i) || (out.varargSeq && i >= out.varargStart);
if (!isOut)
{
StoreRegInfo& regInfo = info[i];
if (!function.cfg.captured.regs.test(i))
{
killTagAndValueStorePair(regInfo);
killTValueStore(regInfo);
}
}
}
}
}
void defVarargs(uint8_t varargStart)
{
for (int i = varargStart; i <= maxReg; i++)
defReg(uint8_t(i));
}
void useVarargs(uint8_t varargStart)
{
for (int i = varargStart; i <= maxReg; i++)
useReg(uint8_t(i));
}
void def(IrOp op, int offset = 0)
{
defReg(vmRegOp(op) + offset);
}
void use(IrOp op, int offset = 0)
{
useReg(vmRegOp(op) + offset);
}
void maybeDef(IrOp op)
{
if (op.kind == IrOpKind::VmReg)
defReg(vmRegOp(op));
}
void maybeUse(IrOp op)
{
if (op.kind == IrOpKind::VmReg)
useReg(vmRegOp(op));
}
void defRange(int start, int count)
{
if (count == -1)
{
defVarargs(start);
}
else
{
for (int i = start; i < start + count; i++)
defReg(i);
}
}
void useRange(int start, int count)
{
if (count == -1)
{
useVarargs(start);
}
else
{
for (int i = start; i < start + count; i++)
useReg(i);
}
}
void capture(int reg) {}
void readAllRegs()
{
for (int i = 0; i <= maxReg; i++)
useReg(i);
hasGcoToClear = false;
}
void flushGcoRegs()
{
for (int i = 0; i <= maxReg; i++)
{
StoreRegInfo& regInfo = info[i];
if (regInfo.maybeGco)
{
CODEGEN_ASSERT(regInfo.knownTag == kUnknownTag || isGCO(regInfo.knownTag));
regInfo.tagInstIdx = ~0u;
regInfo.valueInstIdx = ~0u;
regInfo.tvalueInstIdx = ~0u;
regInfo.maybeGco = false;
}
}
hasGcoToClear = false;
}
IrFunction& function;
std::array<StoreRegInfo, 256> info;
int maxReg = 255;
bool hasGcoToClear = false;
};
// Try to merge a STORE_TAG at `instIndex` (writing `tagOp` to register `targetOp`) with earlier
// pending stores to the same register, so that it is tracked as one full-TValue store.
//
// Returns true when the merge happened. In that case:
// - the superseded pending stores have been killed;
// - regInfo.tvalueInstIdx is set to instIndex;
// - knownTag/maybeGco reflect the new tag;
// - the instruction may have been rewritten into STORE_SPLIT_TVALUE.
// Returns false without modifying anything; the caller then records a plain tag store.
static bool tryReplaceTagWithFullStore(
RemoveDeadStoreState& state,
IrBuilder& build,
IrFunction& function,
IrBlock& block,
uint32_t instIndex,
IrOp targetOp,
IrOp tagOp,
StoreRegInfo& regInfo
)
{
uint8_t tag = function.tagOp(tagOp);
// Case 1: a partial tag store is pending, and the value part is either pending too
// or not needed because the previously known tag is nil
if (regInfo.tagInstIdx != ~0u && (regInfo.valueInstIdx != ~0u || regInfo.knownTag == LUA_TNIL))
{
// A non-nil tag plus a pending value store becomes STORE_SPLIT_TVALUE(tag, previous value).
// Otherwise the STORE_TAG is kept as written. This happens when the new tag is nil (presumably
// because a nil TValue needs no value part) or when there is no pending value store to reuse.
if (tag != LUA_TNIL && regInfo.valueInstIdx != ~0u)
{
// Read the previous value operand before the old value store is killed below
IrOp prevValueOp = function.instructions[regInfo.valueInstIdx].b;
replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, tagOp, prevValueOp});
}
// The old partial stores are now dead; these kills clear maybeGco, which is recomputed below
state.killTagStore(regInfo);
state.killValueStore(regInfo);
regInfo.tvalueInstIdx = instIndex;
regInfo.maybeGco = isGCO(tag);
regInfo.knownTag = tag;
state.hasGcoToClear |= regInfo.maybeGco;
return true;
}
// Case 2: the pending full store is a STORE_SPLIT_TVALUE; replace it with one carrying the new tag
// and the same value
if (regInfo.tvalueInstIdx != ~0u)
{
IrInst& prev = function.instructions[regInfo.tvalueInstIdx];
if (prev.cmd == IrCmd::STORE_SPLIT_TVALUE)
{
CODEGEN_ASSERT(prev.d.kind == IrOpKind::None);
// As above, a nil tag keeps the STORE_TAG as written
if (tag != LUA_TNIL)
{
IrOp prevValueOp = prev.c;
replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, tagOp, prevValueOp});
}
state.killTValueStore(regInfo);
regInfo.tvalueInstIdx = instIndex;
regInfo.maybeGco = isGCO(tag);
regInfo.knownTag = tag;
state.hasGcoToClear |= regInfo.maybeGco;
return true;
}
}
return false;
}
// Try to merge a value store at `instIndex` (STORE_POINTER/STORE_DOUBLE/STORE_INT writing `valueOp`
// to register `targetOp`) with earlier pending stores to the same register. On success the
// instruction is rewritten into STORE_SPLIT_TVALUE(previous tag, valueOp).
//
// Returns true when the merge happened. In that case:
// - the superseded stores are killed;
// - regInfo.tvalueInstIdx is set to instIndex;
// - knownTag is unchanged;
// - regInfo.maybeGco is left cleared by the kills, so callers storing a GC pointer must set it again.
// Returns false without modifying anything.
static bool tryReplaceValueWithFullStore(
RemoveDeadStoreState& state,
IrBuilder& build,
IrFunction& function,
IrBlock& block,
uint32_t instIndex,
IrOp targetOp,
IrOp valueOp,
StoreRegInfo& regInfo
)
{
// Case 1: both a partial tag store and a partial value store are pending
if (regInfo.tagInstIdx != ~0u && regInfo.valueInstIdx != ~0u)
{
// Read the previous tag operand before the old tag store is killed below
IrOp prevTagOp = function.instructions[regInfo.tagInstIdx].b;
uint8_t prevTag = function.tagOp(prevTagOp);
CODEGEN_ASSERT(regInfo.knownTag == prevTag);
replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, prevTagOp, valueOp});
state.killTagStore(regInfo);
state.killValueStore(regInfo);
regInfo.tvalueInstIdx = instIndex;
return true;
}
// Case 2: the pending full store is a STORE_SPLIT_TVALUE; reuse its tag with the new value
if (regInfo.tvalueInstIdx != ~0u)
{
IrInst& prev = function.instructions[regInfo.tvalueInstIdx];
if (prev.cmd == IrCmd::STORE_SPLIT_TVALUE)
{
IrOp prevTagOp = prev.b;
uint8_t prevTag = function.tagOp(prevTagOp);
CODEGEN_ASSERT(regInfo.knownTag == prevTag);
CODEGEN_ASSERT(prev.d.kind == IrOpKind::None);
replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, prevTagOp, valueOp});
state.killTValueStore(regInfo);
regInfo.tvalueInstIdx = instIndex;
return true;
}
}
return false;
}
// Dead store elimination step for instruction `inst` at position `index` inside `block`.
// Pending (not-yet-read) stores to VM registers are tracked in `state.info`:
// - a new store may kill earlier pending stores to the same register, or merge with them;
// - register reads, guards with jump targets and the GC-sensitive instructions below keep them alive.
static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, IrFunction& function, IrBlock& block, IrInst& inst, uint32_t index)
{
switch (inst.cmd)
{
case IrCmd::STORE_TAG:
if (inst.a.kind == IrOpKind::VmReg)
{
int reg = vmRegOp(inst.a);
// Stores to captured registers are never tracked or killed
if (function.cfg.captured.regs.test(reg))
return;
StoreRegInfo& regInfo = state.info[reg];
// Prefer merging with earlier pending stores into a single full store
if (tryReplaceTagWithFullStore(state, build, function, block, index, inst.a, inst.b, regInfo))
break;
uint8_t tag = function.tagOp(inst.b);
// Otherwise record it as the pending partial tag store; the register's tag becomes known
regInfo.tagInstIdx = index;
regInfo.maybeGco = isGCO(tag);
regInfo.knownTag = tag;
state.hasGcoToClear |= regInfo.maybeGco;
}
break;
case IrCmd::STORE_EXTRA:
// Treated as a read: earlier pending stores are kept, and this store itself is not tracked
if (inst.a.kind == IrOpKind::VmReg)
{
state.useReg(vmRegOp(inst.a));
}
break;
case IrCmd::STORE_POINTER:
if (inst.a.kind == IrOpKind::VmReg)
{
int reg = vmRegOp(inst.a);
if (function.cfg.captured.regs.test(reg))
return;
StoreRegInfo& regInfo = state.info[reg];
if (tryReplaceValueWithFullStore(state, build, function, block, index, inst.a, inst.b, regInfo))
{
// The merge clears maybeGco; a stored pointer may be a GC object
regInfo.maybeGco = true;
state.hasGcoToClear |= true;
break;
}
// With a known tag, the new value store makes the previous pending value store dead
if (regInfo.knownTag != kUnknownTag)
state.killValueStore(regInfo);
regInfo.valueInstIdx = index;
regInfo.maybeGco = true;
state.hasGcoToClear = true;
}
break;
case IrCmd::STORE_DOUBLE:
case IrCmd::STORE_INT:
if (inst.a.kind == IrOpKind::VmReg)
{
int reg = vmRegOp(inst.a);
if (function.cfg.captured.regs.test(reg))
return;
StoreRegInfo& regInfo = state.info[reg];
if (tryReplaceValueWithFullStore(state, build, function, block, index, inst.a, inst.b, regInfo))
break;
// With a known tag, the new value store makes the previous pending value store dead
if (regInfo.knownTag != kUnknownTag)
state.killValueStore(regInfo);
regInfo.valueInstIdx = index;
regInfo.maybeGco = false;
}
break;
case IrCmd::STORE_VECTOR:
// Not tracked: treated as a read, so earlier pending stores are kept and this one is never killed
if (inst.a.kind == IrOpKind::VmReg)
{
state.useReg(vmRegOp(inst.a));
}
break;
case IrCmd::STORE_TVALUE:
if (inst.a.kind == IrOpKind::VmReg)
{
int reg = vmRegOp(inst.a);
if (function.cfg.captured.regs.test(reg))
return;
StoreRegInfo& regInfo = state.info[reg];
// A full store overwrites everything pending for this register
state.killTagAndValueStorePair(regInfo);
state.killTValueStore(regInfo);
regInfo.tvalueInstIdx = index;
// Conservatively assume a GC object with an unknown tag, unless the source operand says otherwise
regInfo.maybeGco = true;
regInfo.knownTag = kUnknownTag;
if (IrInst* arg = function.asInstOp(inst.b))
{
if (arg->cmd == IrCmd::TAG_VECTOR)
regInfo.maybeGco = false;
// A LOAD_TVALUE with an explicit tag operand (c) tells us whether the value is a GC object
if (arg->cmd == IrCmd::LOAD_TVALUE && arg->c.kind != IrOpKind::None)
regInfo.maybeGco = isGCO(function.tagOp(arg->c));
}
state.hasGcoToClear |= regInfo.maybeGco;
}
break;
case IrCmd::STORE_SPLIT_TVALUE:
if (inst.a.kind == IrOpKind::VmReg)
{
int reg = vmRegOp(inst.a);
if (function.cfg.captured.regs.test(reg))
return;
StoreRegInfo& regInfo = state.info[reg];
// A full store overwrites everything pending for this register; its tag operand (b) becomes known
state.killTagAndValueStorePair(regInfo);
state.killTValueStore(regInfo);
regInfo.tvalueInstIdx = index;
regInfo.maybeGco = isGCO(function.tagOp(inst.b));
regInfo.knownTag = function.tagOp(inst.b);
state.hasGcoToClear |= regInfo.maybeGco;
}
break;
// Guards below may jump to a target that reads registers, so pending stores to its live-ins are kept
case IrCmd::CHECK_TAG:
state.checkLiveIns(inst.c);
// After the guard, the loaded register's tag is recorded as inst.b (presumably valid on the fall-through path)
if (IrInst* load = function.asInstOp(inst.a); load && load->cmd == IrCmd::LOAD_TAG && load->a.kind == IrOpKind::VmReg)
{
int reg = vmRegOp(load->a);
StoreRegInfo& regInfo = state.info[reg];
regInfo.knownTag = function.tagOp(inst.b);
}
break;
case IrCmd::TRY_NUM_TO_INDEX:
state.checkLiveIns(inst.b);
break;
case IrCmd::TRY_CALL_FASTGETTM:
state.checkLiveIns(inst.c);
break;
case IrCmd::CHECK_FASTCALL_RES:
state.checkLiveIns(inst.b);
break;
case IrCmd::CHECK_TRUTHY:
state.checkLiveIns(inst.c);
break;
case IrCmd::CHECK_READONLY:
state.checkLiveIns(inst.b);
break;
case IrCmd::CHECK_NO_METATABLE:
state.checkLiveIns(inst.b);
break;
case IrCmd::CHECK_SAFE_ENV:
state.checkLiveIns(inst.a);
break;
case IrCmd::CHECK_ARRAY_SIZE:
state.checkLiveIns(inst.c);
break;
case IrCmd::CHECK_SLOT_MATCH:
state.checkLiveIns(inst.c);
break;
case IrCmd::CHECK_NODE_NO_NEXT:
state.checkLiveIns(inst.b);
break;
case IrCmd::CHECK_NODE_VALUE:
state.checkLiveIns(inst.b);
break;
case IrCmd::CHECK_BUFFER_LEN:
state.checkLiveIns(inst.d);
break;
case IrCmd::CHECK_USERDATA_TAG:
state.checkLiveIns(inst.c);
break;
case IrCmd::JUMP:
// Nothing is killed here; markDeadStoresInBlockChain may carry the pending state into the jump target
break;
case IrCmd::RETURN:
visitVmRegDefsUses(state, function, inst);
// At a return, stores to registers that are not live out of the block are dead
state.checkLiveOuts(block);
break;
case IrCmd::ADJUST_STACK_TO_REG:
// Intentionally not passed to visitVmRegDefsUses, so it neither defines nor uses registers for this pass
break;
// Instructions that presumably may run the GC: first keep pending stores to registers that may hold
// a GC object, then apply the generic register defs/uses
case IrCmd::CMP_ANY:
case IrCmd::DO_ARITH:
case IrCmd::DO_LEN:
case IrCmd::GET_TABLE:
case IrCmd::SET_TABLE:
case IrCmd::GET_IMPORT:
case IrCmd::CONCAT:
case IrCmd::INTERRUPT:
case IrCmd::CHECK_GC:
case IrCmd::CALL:
case IrCmd::FORGLOOP_FALLBACK:
case IrCmd::FALLBACK_GETGLOBAL:
case IrCmd::FALLBACK_SETGLOBAL:
case IrCmd::FALLBACK_GETTABLEKS:
case IrCmd::FALLBACK_SETTABLEKS:
case IrCmd::FALLBACK_NAMECALL:
case IrCmd::FALLBACK_DUPCLOSURE:
case IrCmd::FALLBACK_FORGPREP:
if (state.hasGcoToClear)
state.flushGcoRegs();
visitVmRegDefsUses(state, function, inst);
break;
default:
// Every instruction that can jump without terminating the block must be handled explicitly above
CODEGEN_ASSERT(!isNonTerminatingJump(inst.cmd));
visitVmRegDefsUses(state, function, inst);
break;
}
}
// Feed every instruction of `block`, from block.start through block.finish inclusive, to the
// dead store tracker, in program order.
static void markDeadStoresInBlock(IrBuilder& build, IrBlock& block, RemoveDeadStoreState& state)
{
    IrFunction& fn = build.function;

    uint32_t instIdx = block.start;
    while (instIdx <= block.finish)
    {
        CODEGEN_ASSERT(instIdx < fn.instructions.size());

        markDeadStoresInInst(state, build, fn, block, fn.instructions[instIdx], instIdx);
        ++instIdx;
    }
}
// Run dead store elimination over a chain of blocks starting at `block`, using one shared tracking state.
// The chain continues through an unconditional JUMP into a target block that meets all of:
// - it has a use count of 1 (presumably the jump from the current block);
// - it has not been visited yet;
// - it is not a fallback block.
// This lets facts gathered in the predecessor carry over to the successor.
// Every processed block is marked in `visited`.
static void markDeadStoresInBlockChain(IrBuilder& build, std::vector<uint8_t>& visited, IrBlock* block)
{
    IrFunction& function = build.function;

    RemoveDeadStoreState state{function};

    while (block)
    {
        uint32_t blockIdx = function.getBlockIndex(*block);
        CODEGEN_ASSERT(!visited[blockIdx]);
        visited[blockIdx] = true;

        markDeadStoresInBlock(build, *block, state);

        IrInst& termInst = function.instructions[block->finish];

        IrBlock* nextBlock = nullptr;

        if (termInst.cmd == IrCmd::JUMP && termInst.a.kind == IrOpKind::Block)
        {
            IrBlock& target = function.blockOp(termInst.a);
            uint32_t targetIdx = function.getBlockIndex(target);

            // Was `nextBlock = ⌖` (an HTML-entity-mangled `&target;`), which does not compile
            if (target.useCount == 1 && !visited[targetIdx] && target.kind != IrBlockKind::Fallback)
                nextBlock = &target;
        }

        block = nextBlock;
    }
}
// Entry point of the pass: start a dead store elimination chain at every block that is not a
// fallback or dead block and has not already been covered by an earlier chain.
void markDeadStoresInBlockChains(IrBuilder& build)
{
    IrFunction& fn = build.function;

    // One flag per block; set once the block has been processed as part of some chain
    std::vector<uint8_t> visited(fn.blocks.size(), false);

    for (IrBlock& chainStart : fn.blocks)
    {
        const bool excludedKind = chainStart.kind == IrBlockKind::Fallback || chainStart.kind == IrBlockKind::Dead;

        if (!excludedKind && !visited[fn.getBlockIndex(chainStart)])
            markDeadStoresInBlockChain(build, visited, &chainStart);
    }
}
} }