#include "megbrain/opr/cond.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/utility.h"
using namespace mgb;
using namespace opr;
#if MGB_ENABLE_COND_EXEC
namespace {
/*!
 * \brief conservatively check whether mask \p lhs can be shown to imply
 *      mask \p rhs
 *
 * Only two situations are recognized:
 *  - \p rhs is the direct parent of \p lhs;
 *  - the owner var of \p rhs is produced by a CondExecPredLogical in OR mode
 *    (optionally wrapped by a single-input AND) that takes the owner var of
 *    \p lhs as one of its inputs.
 *
 * \return true if one of the above holds; false otherwise
 */
bool can_prove_imply(cg::ExecutionMask* lhs, cg::ExecutionMask* rhs) {
    if (lhs->parent() == rhs) {
        return true;
    }

    using Mode = CondExecPredLogical::Mode;
    auto has_logical_mode = [](cg::OperatorNodeBase* node, Mode expected) {
        auto logical = node->try_cast_final<CondExecPredLogical>();
        return logical && logical->param().mode == expected;
    };

    auto rhs_opr = rhs->owner()->owner_opr();
    // look through an AND that merely wraps a single predicate
    if (has_logical_mode(rhs_opr, Mode::AND) && rhs_opr->input().size() == 1) {
        rhs_opr = rhs_opr->input(0)->owner_opr();
    }

    if (!has_logical_mode(rhs_opr, Mode::OR)) {
        return false;
    }

    // rhs is an OR: it is implied by lhs iff lhs is one of the OR operands
    auto lhs_var = lhs->owner();
    for (auto operand : rhs_opr->input()) {
        if (operand == lhs_var) {
            return true;
        }
    }
    return false;
}
/*!
 * \brief get the owner var of \p mask, asserting that it is an output of
 *      CondExecPred or CondExecPredLogical
 */
VarNode* proxy_var_from_mask(cg::ExecutionMask* mask) {
    auto var = mask->owner();
    mgb_assert(var);

    // the owner var must come from one of the predicate operators
    auto type = var->owner_opr()->dyn_typeinfo();
    mgb_assert(
            type->is<CondExecPred>() || type->is<CondExecPredLogical>(),
            "mask not from CondExec opr: %s", cg::dump_var_info({var}).c_str());
    return var;
}
#if MGB_ENABLE_LOGGING
/*!
 * \brief human-readable description of \p mask for log/error messages
 *
 * Returns "null" for a null mask; otherwise the owner opr type followed by
 * the owner var name.
 */
std::string mask2str(cg::ExecutionMask* mask) {
    if (mask == nullptr) {
        return "null";
    }

    auto var = mask->owner();
    mgb_assert(var);

    if (!var->owner_opr()->same_type<CondExecPred>()) {
        // the only other opr type that may own a mask
        mgb_assert(var->owner_opr()->same_type<CondExecPredLogical>());
        return ssprintf("CondExecPredLogical(%s)", var->cname());
    }
    return ssprintf("CondExecPred(%s)", var->cname());
}
#else
//! stub used when logging is disabled: always yields an empty string
std::string mask2str(cg::ExecutionMask*) {
    return std::string{};
}
#endif
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondExecPred);

/*!
 * \brief compares a fixed scalar predicate value against scalar keys
 *
 * The comparison routine is selected once, in the constructor, according to
 * the dtype of the predicate tensor.
 */
class CondExecPred::PredEvaluator {
public:
    //! relation of the predicate value to the key: less / equal / greater
    enum Result { LT, EQ, GT };

    PredEvaluator(const CondExecPred& opr, const DeviceTensorND& pred);

    //! validate \p key and compare the stored predicate value against it
    Result operator()(const DeviceTensorND& key) {
        pre_check(key);
        return m_compare(key);
    }

private:
    CompNode default_cpu = CompNode::default_cpu();
    thin_function<Result(const DeviceTensorND&)> m_compare;

    //! require \p val to be a scalar residing on the default cpu comp node
    void pre_check(const DeviceTensorND& val) {
        mgb_assert(val.comp_node() == default_cpu);
        const auto& shape = val.shape();
        mgb_throw_if(
                !shape.is_scalar(), GraphError,
                "CondExec predicate or branch key is not scalar: %s",
                shape.to_string().c_str());
    }
};
/*!
 * \brief build the comparison closure for the given predicate value
 *
 * \param opr operator whose param supplies the eps used for float comparison
 * \param pred scalar predicate tensor; its first element is copied into the
 *      closure, and its dtype selects the comparison routine
 *
 * Throws GraphError if the dtype of \p pred is not handled by the switch.
 */
CondExecPred::PredEvaluator::PredEvaluator(
const CondExecPred& opr, const DeviceTensorND& pred) {
pre_check(pred);
switch (pred.dtype().enumv()) {
// float dtypes: values whose absolute difference is below eps compare as EQ
#define cbf(dt) \
case DTypeTrait<dt>::enumv: { \
using ct = DTypeTrait<dt>::ctype; \
m_compare = [eps = opr.m_param.eps, \
p = pred.ptr<ct>()[0]](const DeviceTensorND& key) { \
ct k = key.ptr<ct>()[0]; \
return std::abs(p - k) < eps ? EQ : (p < k ? LT : GT); \
}; \
break; \
}
// integer dtypes: exact comparison
#define cbi(dt) \
case DTypeTrait<dt>::enumv: { \
using ct = DTypeTrait<dt>::ctype; \
m_compare = [p = pred.ptr<ct>()[0]](const DeviceTensorND& key) { \
ct k = key.ptr<ct>()[0]; \
return p == k ? EQ : (p < k ? LT : GT); \
}; \
break; \
}
// instantiate one case label per computing dtype
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cbf);
MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cbi)
#undef cbf
#undef cbi
default:
mgb_throw(GraphError, "unsupported pred dtype: %s", pred.dtype().name());
}
}
/*!
 * \brief registry stored in the graph's user data that records which
 *      ExecutionMask is associated with each var
 *
 * The association is maintained by on_new_opr(), which is invoked from an
 * OprInserted event receiver registered on first access through get().
 */
class CondExecPred::GlobalRegistry final : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

    SyncEventConnecter::ReceiverHandler m_opr_insert_handler;
    ThinHashMap<VarNode*, ExecutionMask*> m_var2mask;

    //! update m_var2mask for a newly inserted operator
    void on_new_opr(OperatorNodeBase* opr);

public:
    /*!
     * \brief get (creating if necessary) the registry of \p graph, and make
     *      sure the OprInserted receiver is registered
     */
    static GlobalRegistry* get(ComputingGraph& graph) {
        using namespace cg::event;
        auto registry =
                graph.options().user_data.get_user_data_or_create<GlobalRegistry>();
        if (registry->m_opr_insert_handler) {
            return registry;
        }
        registry->m_opr_insert_handler = graph.event().register_receiver<OprInserted>(
                [registry](const OprInserted& ev) {
                    // deduplicated or failed insertions are ignored
                    if (ev.is_dedup || ev.exc) {
                        return;
                    }
                    registry->on_new_opr(ev.opr);
                });
        return registry;
    }

    //! mask recorded for \p var, or nullptr if there is none
    ExecutionMask* get_mask_from_var(VarNode* var) const {
        auto found = m_var2mask.find(var);
        if (found == m_var2mask.end()) {
            return nullptr;
        }
        return found->second;
    }

    //! like get_mask_from_var(), but throws GraphError if no mask is recorded
    ExecutionMask* require_mask_from_var(VarNode* var) const {
        auto found = get_mask_from_var(var);
        mgb_throw_if(
                !found, GraphError, "var is not controlled by ExecutionMask: %s",
                cg::dump_var_info({var}).c_str());
        return found;
    }

    //! require that \p var has a mask and is the owner var of that mask
    ExecutionMask* check_ppv(VarNode* var) const {
        auto found = require_mask_from_var(var);
        mgb_throw_if(
                found->owner() != var, GraphError,
                "a conditional var is not PPV: mask=%s var=%s",
                mask2str(found).c_str(), cg::dump_var_info({var}).c_str());
        return found;
    }
};
MGB_TYPEINFO_OBJ_IMPL(CondExecPred::GlobalRegistry);
/*!
 * \brief derive the ExecutionMask for a newly inserted operator from the
 *      masks of its inputs, and record the masks of its outputs
 *
 * The resulting mask (if any) is registered to \p opr and stored for all of
 * its outputs. Outputs of CondExecPred / CondExecPredLogical are afterwards
 * mapped to the masks owned by those operators, which are also added as
 * nested masks of the derived mask.
 *
 * Throws GraphError when a predicate var is consumed by an operator that may
 * not receive it, when the implication between masks can not be proved, or
 * when input masks have no direct lowest mask in common.
 */
void CondExecPred::GlobalRegistry::on_new_opr(OperatorNodeBase* const opr) {
// mask that will control execution of opr; derived from the inputs below
ExecutionMask* mask = nullptr;
auto opr_type = opr->dyn_typeinfo();
bool opr_is_mark = opr_type->is<CondExecMark>(),
opr_is_merge = opr_type->is<CondExecMerge>(),
opr_is_pred_logical = opr_type->is<CondExecPredLogical>();
using MergeMode = CondExecMerge::Mode;
// -1 is used as a placeholder when opr is not a CondExecMerge
MergeMode merge_mode = opr_is_merge ? opr->cast_final<CondExecMerge>().param().mode
: static_cast<MergeMode>(-1);
// these operators take a predicate var as their last input
bool opr_follow_pred =
opr_is_mark || (opr_is_merge && merge_mode == MergeMode::SUM_COND_OUT);
auto&& inputs = opr->input();
for (size_t idx = 0; idx < inputs.size(); ++idx) {
auto i_var = inputs[idx];
ExecutionMask* i_mask = nullptr;
auto i_owner = i_var->owner_opr();
bool i_is_pred = false;
if (i_owner->same_type<CondExecPred>() ||
i_owner->same_type<CondExecPredLogical>()) {
i_is_pred = true;
// a predicate output may only be the trailing input of a mark/merge
// (SUM_COND_OUT) opr, or any input of a CondExecPredLogical
mgb_throw_if(
!((opr_follow_pred && i_var == opr->input().back()) ||
opr_is_pred_logical),
GraphError,
"predicate proxy var not received by CondExec "
"mark/merge opr: var=%s recv_opr=%s{%s}",
cg::dump_var_info({i_var}).c_str(), opr->cname(),
opr->dyn_typeinfo()->name);
}
if (opr_follow_pred && i_var == opr->input().back()) {
// the trailing predicate directly determines the mask of opr
mgb_assert(i_is_pred);
i_mask = m_var2mask.at(i_var);
if (mask) {
// a mask was already derived from the preceding inputs; the
// predicate's mask must be provably compatible with it
mgb_throw_if(
!can_prove_imply(i_mask, mask), GraphError,
"can not prove opr mask implies inputs mask: "
"opr=%s{%s}: opr_mask=%s "
"inputs_mask=%s",
opr->cname(), opr->dyn_typeinfo()->name,
mask2str(i_mask).c_str(), mask2str(mask).c_str());
}
mask = i_mask;
break;
}
// look up the mask recorded for this input var, if any
if (!i_mask) {
auto iter = m_var2mask.find(i_var);
i_mask = iter == m_var2mask.end() ? nullptr : iter->second;
}
// a logical predicate opr is placed under the parent of its input masks
if (opr_is_pred_logical && i_mask) {
i_mask = i_mask->parent();
}
if (opr_is_merge) {
if (merge_mode == MergeMode::SUM &&
idx >= inputs.size() - opr->output().size()) {
// trailing inputs of a SUM merge (one per output); a statically
// known value does not contribute a mask
if (cg::is_static_var_value(i_var)) {
i_mask = nullptr;
}
} else if (i_mask) {
// other merge inputs contribute the parent of their mask
i_mask = i_mask->parent();
}
}
if (i_mask) {
// combine with the mask accumulated so far
auto lower = ExecutionMask::find_direct_lowest(mask, i_mask);
mgb_throw_if(
!lower, GraphError,
"different ExecutionMask trees on inputs of a single "
"opr: opr=%s{%s} mask0=%s mask1=%s",
opr->cname(), opr->dyn_typeinfo()->name, mask2str(mask).c_str(),
mask2str(i_mask).c_str());
mask = lower;
}
}
if (mask) {
// opr and all of its outputs are controlled by the derived mask
mask->register_to_opr(opr);
for (auto i : opr->output()) {
m_var2mask[i] = mask;
}
}
// outputs of predicate oprs map to the masks they own (overriding the
// entries written above); those masks are nested under the derived mask
if (opr_type->is<CondExecPred>()) {
size_t idx = 0;
for (auto&& i : opr->cast_final<CondExecPred>().masks()) {
if (mask) {
mask->add_nested(i.get());
}
m_var2mask[opr->output(idx++)] = i.get();
}
} else if (opr_is_pred_logical) {
auto m = opr->cast_final<CondExecPredLogical>().mask();
if (mask) {
mask->add_nested(m);
}
m_var2mask[opr->output(0)] = m;
}
}
/*!
 * \brief construct the predicate operator
 *
 * Inputs are the keys followed by \p pred. One Int32 output (each owning an
 * ExecutionMask) is created per branch:
 *  - PIECEWISE: one output per interval delimited by the keys;
 *  - otherwise: one output per key, plus a trailing "fallback" output in
 *    CASE_FALLBACK mode.
 *
 * Throws GraphError if a key dtype differs from the dtype of \p pred.
 */
CondExecPred::CondExecPred(
        VarNode* pred, const VarNodeArrayView& keys, const Param& param,
        const OperatorNodeConfig& config)
        : Super(pred->owner_graph(), config, "cond_pred", {pred}), m_param{param} {
    const size_t nr_key = keys.size();
    m_masks.reserve(nr_key + 1);

    // create an output var together with the mask it owns
    auto new_mask_output = [this](const std::string& name) {
        auto ovar = add_output(name);
        ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC).dtype(dtype::Int32{});
        m_masks.emplace_back(std::make_shared<ExecutionMask>(ovar));
    };

    const bool piecewise = param.mode == Param::Mode::PIECEWISE;
    for (size_t i = 0; i < nr_key; ++i) {
        mgb_throw_if(
                keys[i]->dtype() != pred->dtype(), GraphError,
                "dtype mismatch: pred=%s input[%zu]=%s", pred->dtype().name(), i,
                keys[i]->dtype().name());
        add_input({keys[i]});

        if (!piecewise) {
            new_mask_output(ssprintf("branch%zu", i));
            continue;
        }

        // PIECEWISE: the first key also opens the leftmost interval
        if (i == 0) {
            new_mask_output("[-inf,k0]");
        }
        if (i + 1 == nr_key) {
            new_mask_output(ssprintf("[k%zu,inf]", i));
        } else {
            new_mask_output(ssprintf("[k%zu,k%zu]", i, i + 1));
        }
    }

    if (param.mode == Param::Mode::CASE_FALLBACK) {
        new_mask_output("fallback");
    }

    // the predicate itself is always the last input
    add_input({pred});
    add_equivalence_component<PODHash<Param>>(&m_param);

    // called for its side effect of creating the registry and its receiver
    GlobalRegistry::get(*owner_graph());
}
//! create a CondExecPred and insert it into the graph owning \p pred
cg::OperatorNodeBase* CondExecPred::make_opr(
        SymbolVar pred, const VarNodeArrayView& keys, const Param& param,
        const OperatorNodeConfig& config) {
    auto pred_var = pred.node();
    auto graph = pred_var->owner_graph();
    return graph->insert_opr(
            std::make_unique<CondExecPred>(pred_var, keys, param, config));
}
/*!
 * \brief register static shape and value inference for all outputs
 *
 * Every output has the constant shape {1}. The inferred value of an output is
 * 1 if the corresponding branch is selected by the predicate and 0 otherwise;
 * the selection rule depends on the mode (CASE / CASE_FALLBACK / PIECEWISE).
 */
void CondExecPred::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
for (auto i : output()) {
mgr.register_shape_infer(i, ShapeInferDesc::make_const({1}));
}
// register a value infer desc after passing it through
// ensure_not_replaced_by_const_folding(); if that call appended a dep, it is
// removed again afterwards so the same desc object can be reused
auto reg_value_infer_no_const = [&mgr](VarNode* var, ValueInferDesc& desc) {
auto orig_size = desc.deps.size();
mixin::ForwardInputToOutput::ensure_not_replaced_by_const_folding(desc);
mgr.register_value_infer(var, desc);
if (desc.deps.size() != orig_size) {
mgb_assert(desc.deps.size() == orig_size + 1);
desc.deps.pop_back();
}
};
// the last input is the predicate; all preceding inputs are keys
size_t nr_key = input().size() - 1;
auto mode = m_param.mode;
if (mode == Mode::CASE || mode == Mode::CASE_FALLBACK) {
// output i is 1 iff the predicate compares equal to key i
auto infer_val_eq = [this](DeviceTensorND& dest, const InpVal& inp) {
auto&& pv = inp.val[0].value();
auto&& key = inp.val[1].value();
dest.resize({1}).ptr<int>()[0] =
(PredEvaluator{*this, pv}(key) == PredEvaluator::EQ);
return true;
};
// deps: [predicate, key]; the key slot is filled in per output below
ValueInferDesc desc{
SourceType::DEP,
{{input().back(), DepType::VALUE}, {nullptr, DepType::VALUE}},
infer_val_eq};
for (size_t i = 0; i < nr_key; ++i) {
desc.deps[1].dest = input(i);
reg_value_infer_no_const(output(i), desc);
}
if (mode == Mode::CASE_FALLBACK) {
// the fallback output depends on all branch outputs and is 1 only
// if none of them is non-zero
desc.deps.clear();
for (size_t i = 0; i < nr_key; ++i) {
desc.deps.push_back({output(i), DepType::VALUE});
}
desc.infer_func = [](DeviceTensorND& dest, const InpVal& inp) {
int r = 1;
for (auto&& i : inp.val) {
if (i.value().ptr<int>()[0]) {
r = 0;
break;
}
}
dest.resize({1}).ptr<int>()[0] = r;
return true;
};
reg_value_infer_no_const(output().back(), desc);
}
} else {
mgb_assert(mode == Mode::PIECEWISE);
// first interval: predicate is less than key 0
auto infer_first = [this](DeviceTensorND& dest, const InpVal& inp) {
auto&& pv = inp.val[0].value();
auto&& key = inp.val[1].value();
dest.resize({1}).ptr<int>()[0] =
(PredEvaluator{*this, pv}(key) == PredEvaluator::LT);
return true;
};
// middle interval: predicate is not less than the left key and less than
// the right key
auto infer_mid = [this](DeviceTensorND& dest, const InpVal& inp) {
auto&& pv = inp.val[0].value();
auto&& left = inp.val[1].value();
auto&& right = inp.val[2].value();
PredEvaluator eval{*this, pv};
auto el = eval(left), er = eval(right);
dest.resize({1}).ptr<int>()[0] =
(el != PredEvaluator::LT && er == PredEvaluator::LT);
return true;
};
// last interval: predicate is not less than the last key
auto infer_last = [this](DeviceTensorND& dest, const InpVal& inp) {
auto&& pv = inp.val[0].value();
auto&& key = inp.val[1].value();
dest.resize({1}).ptr<int>()[0] =
(PredEvaluator{*this, pv}(key) != PredEvaluator::LT);
return true;
};
// output 0: deps are [predicate, key 0]
ValueInferDesc desc{
SourceType::DEP,
{{input().back(), DepType::VALUE}, {input(0), DepType::VALUE}},
infer_first};
reg_value_infer_no_const(output(0), desc);
// outputs 1 .. nr_key-1: deps are [predicate, key i-1, key i]
desc.deps.push_back({nullptr, DepType::VALUE});
desc.infer_func = infer_mid;
for (size_t i = 1; i < nr_key; ++i) {
desc.deps[1].dest = input(i - 1);
desc.deps[2].dest = input(i);
reg_value_infer_no_const(output(i), desc);
}
// output nr_key: deps are [predicate, last key]
desc.deps.resize(2);
desc.deps[1].dest = input(nr_key - 1);
desc.infer_func = infer_last;
reg_value_infer_no_const(output(nr_key), desc);
}
}
//! all existing dependencies are changed to HOST_VALUE
CondExecPred::NodeProp* CondExecPred::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    for (auto&& dep : prop->dep_map()) {
        dep.second = NodeProp::DepType::HOST_VALUE;
    }
    return prop;
}
void CondExecPred::scn_do_execute() {
auto&& mgr = owner_graph()->static_infer_manager();
PredEvaluator eval{*this, mgr.infer_value(input().back())};
auto mode = m_param.mode;
if (mode == Mode::CASE || mode == Mode::CASE_FALLBACK) {
bool enabled = false;
for (size_t i = 0; i < input().size() - 1; ++i) {
auto cur = eval(mgr.infer_value(input(i))) == PredEvaluator::EQ;
m_masks[i]->enable(cur);
enabled |= cur;
}
if (mode == Mode::CASE_FALLBACK) {
m_masks.back()->enable(!enabled);
}
} else {
mgb_assert(mode == Mode::PIECEWISE);
const DeviceTensorND *val_prev = nullptr, *val_cur = nullptr;
for (size_t i = 0; i < input().size(); ++i) {
val_prev = val_cur;
if (i == input().size() - 1) {
val_cur = nullptr;
} else {
val_cur = &mgr.infer_value(input(i));
}
PredEvaluator::Result el, er;
if (!val_prev) {
el = PredEvaluator::GT;
} else {
el = eval(*val_prev);
}
if (!val_cur) {
er = PredEvaluator::LT;
} else {
er = eval(*val_cur);
}
m_masks[i]->enable(el != PredEvaluator::LT && er == PredEvaluator::LT);
}
}
}
//! find the output var whose mask is \p mask; throws if there is none
VarNode* CondExecPred::out_var_from_mask(ExecutionMask* mask) const {
    const size_t nr_out = output().size();
    size_t idx = 0;
    while (idx < nr_out && m_masks[idx].get() != mask) {
        ++idx;
    }
    if (idx == nr_out) {
        mgb_throw(AssertionError, "bad mask");
    }
    return output(idx);
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondExecPredLogical);

/*!
 * \brief incremental evaluator for the logical combination of predicates
 *
 * The accumulator starts from the first value and is updated with each
 * further value; NOR/NAND/XNOR are computed as OR/AND/XOR with the final
 * result negated.
 */
class CondExecPredLogical::PredEvaluator {
    //! folds a value into the accumulator; a false return value tells the
    //! caller that it may stop feeding further values
    bool (*m_updater)(int*, int);
    int m_cur_val, m_negate = 0;

public:
    explicit PredEvaluator(Mode mode, int init) : m_cur_val{init} {
        auto accumulate_or = [](int* acc, int v) -> bool {
            *acc |= v;
            return !*acc;
        };
        auto accumulate_and = [](int* acc, int v) -> bool {
            *acc &= v;
            return *acc;
        };
        auto accumulate_xor = [](int* acc, int v) -> bool {
            *acc ^= v;
            return true;
        };

        // negated modes share the updater of their positive counterpart
        const bool negated =
                mode == Mode::NOR || mode == Mode::NAND || mode == Mode::XNOR;
        m_negate = negated ? 1 : 0;

        switch (mode) {
            case Mode::OR:
            case Mode::NOR:
                m_updater = accumulate_or;
                break;
            case Mode::AND:
            case Mode::NAND:
                m_updater = accumulate_and;
                break;
            case Mode::XOR:
            case Mode::XNOR:
                m_updater = accumulate_xor;
                break;
            default:
                mgb_throw(MegBrainError, "invalid CondExecPredLogical mode");
        }
    }

    //! feed one more value; see m_updater for the meaning of the result
    bool update(int val) { return m_updater(&m_cur_val, val); }

    //! current result, with negation applied
    bool get() const { return m_cur_val ^ m_negate; }
};
/*!
 * \brief construct a logical combination of the predicate vars \p preds
 *
 * Every pred must already have a mask recorded in the global registry
 * (GraphError otherwise). The single Int32 output owns the resulting mask.
 */
CondExecPredLogical::CondExecPredLogical(
        const VarNodeArrayView& preds, const Param& param,
        const OperatorNodeConfig& config)
        : Super(preds.at(0)->owner_graph(), config, mgb_cstr_log(mode2str(param.mode)),
                preds),
          m_param{param} {
    const size_t nr_pred = preds.size();
    m_input_masks.resize(nr_pred);
    auto registry = CondExecPred::GlobalRegistry::get(*owner_graph());

    for (size_t i = 0; i < nr_pred; ++i) {
        m_input_masks[i] = registry->require_mask_from_var(preds[i]);
        // sorting of the input list is requested only when adding the last one
        const bool is_last = (i + 1 == nr_pred);
        add_input(
                {preds[i]},
                is_last ? AddInputSortType::ALL : AddInputSortType::NONE);
    }

    add_output(None)->dtype(dtype::Int32{}).add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC);
    m_mask = std::make_shared<ExecutionMask>(output(0));
    add_equivalence_component<PODHash<Param>>(&m_param);
}
/*!
 * \brief create the logical predicate var
 *
 * A single pred combined with OR/XOR/AND is returned unchanged, provided
 * that \p config does not request a different comp node.
 */
SymbolVar CondExecPredLogical::make(
        const VarNodeArrayView& preds, const Param& param,
        const OperatorNodeConfig& config) {
    mgb_assert(!preds.empty());

    if (preds.size() == 1) {
        const bool comp_node_unchanged =
                !config.has_comp_node_set() ||
                config.get_single_comp_node() == preds[0]->comp_node();
        if (comp_node_unchanged) {
            switch (param.mode) {
                case Mode::OR:
                case Mode::XOR:
                case Mode::AND:
                    // these modes are the identity on a single operand
                    return preds[0];
                default:
                    break;
            }
        }
    }

    return SymbolVar{preds[0]}.insert_single_output_opr<CondExecPredLogical>(
            preds, param, config);
}
/*!
 * \brief register static inference: the output has shape {1} and its value
 *      is the logical combination of the values of all inputs
 */
void CondExecPredLogical::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    mgr.register_shape_infer(output(0), ShapeInferDesc::make_const({1}));

    auto combine = [mode = m_param.mode](DeviceTensorND& dst, const InpVal& inp) {
        PredEvaluator eval{mode, inp.val[0].value().ptr<int>()[0]};
        // stop as soon as the evaluator reports that it needs no more values
        for (size_t i = 1;
             i < inp.val.size() && eval.update(inp.val[i].value().ptr<int>()[0]);
             ++i) {
        }
        dst.resize({1}).ptr<int>()[0] = eval.get();
        return true;
    };

    ValueInferDesc desc;
    desc.src_type = SourceType::DEP;
    desc.infer_func = combine;
    desc.deps.reserve(input().size());
    for (auto ivar : input()) {
        desc.deps.push_back({ivar, DepType::VALUE});
    }
    mgr.register_value_infer(output(0), desc);
}
void CondExecPredLogical::scn_do_execute() {
PredEvaluator eval{m_param.mode, m_input_masks[0]->enabled()};
for (size_t i = 1; i < m_input_masks.size(); ++i) {
if (!eval.update(m_input_masks[i]->enabled())) {
break;
}
}
m_mask->enable(eval.get());
}
//! all existing dependencies become DEV_COMP_ORDER, and the
//! CROSS_COMP_NODE_MEMORY flag is added
CondExecPredLogical::NodeProp* CondExecPredLogical::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    for (auto&& dep : prop->dep_map()) {
        dep.second = NodeProp::DepType::DEV_COMP_ORDER;
    }
    prop->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
    return prop;
}
/*!
 * \brief name of \p mode as a string literal (e.g. "OR", "NAND")
 *
 * Throws MegBrainError if \p mode is not one of the known enumerators.
 */
const char* CondExecPredLogical::mode2str(Mode mode) {
    switch (mode) {
// helper that maps one enumerator to its stringified name
#define CASE(n)   \
    case Mode::n: \
        return #n
        CASE(OR);
        CASE(AND);
        CASE(XOR);
        CASE(NOR);
        CASE(NAND);
        CASE(XNOR);
// keep the helper macro local to this function so it can not leak into (or
// clash with) the rest of the translation unit
#undef CASE
        default:
            mgb_throw(
                    MegBrainError, "bad CondExecPredLogical mode: %d",
                    static_cast<int>(mode));
    }
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondExecMark);

/*!
 * \brief construct a mark operator that forwards \p inputs under the control
 *      of \p ppv
 *
 * \p ppv is validated by GlobalRegistry::check_ppv(). The inputs of the
 * operator are \p inputs followed by \p ppv; output i has the dtype of input
 * i and may have an empty shape.
 */
CondExecMark::CondExecMark(
        VarNode* ppv, const VarNodeArrayView& inputs, const Param& param,
        const OperatorNodeConfig& config)
        : Super(ppv->owner_graph(), config, "cond_mark", {ppv}), m_param{param} {
    auto registry = CondExecPred::GlobalRegistry::get(*owner_graph());
    registry->check_ppv(ppv);

    const size_t nr_fwd = inputs.size();
    for (size_t i = 0; i < nr_fwd; ++i) {
        auto src = inputs[i];
        add_input({src});
        auto dst = add_output(ssprintf("fwd%zu", i));
        dst->dtype(src->dtype()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
    }
    add_input({ppv});
    add_equivalence_component<PODHash<Param>>(&m_param);

    if (!has_no_shape_infer()) {
        // with shape inference, forwarding success is tracked per output
        m_mem_fwd_success.resize(inputs.size(), false);
        return;
    }

    // without shape inference, outputs are not allocated by the system and
    // inputs must not be statically allocated
    for (auto ivar : input()) {
        ivar->add_flag(VarNode::Flag::NO_SYS_STATIC_MEM_ALLOC);
    }
    for (auto ovar : output()) {
        ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC);
    }
}
void CondExecMark::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
using InferMode = Param::StaticInfer;
auto infer_mode = param().static_infer;
if (infer_mode == InferMode::NONE) {
return;
}
for (size_t i = 0; i < output().size(); ++i) {
auto s = input(i), t = output(i);
mgr.register_shape_infer(t, ShapeInferDesc::make_identity(s));
if (infer_mode != InferMode::SHAPE_ONLY) {
auto desc = ValueInferDesc::make_identity(s);
mixin::ForwardInputToOutput::ensure_not_replaced_by_const_folding(desc);
mgr.register_value_infer(t, desc);
}
}
}
void CondExecMark::scn_do_execute() {
bool no_sys_alloc = has_no_shape_infer();
for (size_t i = 0; i < output().size(); ++i) {
if (no_sys_alloc) {
bool succ = output(i)->reset_dev_tensor_from_other_var(input(i));
MGB_MARK_USED_VAR(succ);
} else {
auto &&out = output(i)->dev_tensor(), &&inp = input(i)->dev_tensor();
if (m_mem_fwd_success[i]) {
mgb_assert(
inp.raw_ptr() == out.raw_ptr() &&
out.layout().eq_layout(inp.layout()));
} else {
out.copy_from_fixlayout(inp);
}
}
}
}
void CondExecMark::init_rt_force_dynamic_mem_alloc_imply_chain() {
if (has_no_shape_infer()) {
return;
}
for (size_t i = 0; i < output().size(); ++i) {
auto s = input(i), t = output(i);
s->add_rt_force_dynamic_mem_alloc_imply_chain(t);
t->add_rt_force_dynamic_mem_alloc_imply_chain(s);
}
}
void CondExecMark::mem_plan_fwd_in2out_readonly() {
if (has_no_shape_infer()) {
return;
}
for (size_t i = 0; i < output().size(); ++i) {
auto s = input(i), t = output(i);
m_mem_fwd_success[i] = t->set_fwd_in2out_readonly(
s, SubTensorSpec::make_from_layout(s->layout()));
}
}
//! without shape inference, all inputs are required to be contiguous
void CondExecMark::add_input_layout_constraint() {
    if (!has_no_shape_infer()) {
        return;
    }
    for (auto ivar : input()) {
        ivar->add_layout_constraint_contiguous();
    }
}
//! the trailing (predicate) input is a DEV_COMP_ORDER dependency; all other
//! inputs additionally allow empty values
CondExecMark::NodeProp* CondExecMark::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    prop->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER;

    const size_t nr_fwd = input().size() - 1;
    for (size_t i = 0; i < nr_fwd; ++i) {
        prop->add_dep_type_existing_var(
                input(i), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    }
    return prop;
}
//! create a CondExecMark and insert it into the graph owning \p ppv
cg::OperatorNodeBase* CondExecMark::make_opr(
        SymbolVar ppv, const VarNodeArrayView& inputs, const Param& param,
        const OperatorNodeConfig& config) {
    auto ppv_var = ppv.node();
    auto graph = ppv_var->owner_graph();
    return graph->insert_opr(
            std::make_unique<CondExecMark>(ppv_var, inputs, param, config));
}
/*!
 * \brief mark \p input with the mask recorded for \p maybe_ppv, if any
 *
 * \return output of the new CondExecMark, or \p input itself when
 *      \p maybe_ppv has no mask recorded in the registry
 */
SymbolVar CondExecMark::mark_if_need(
        SymbolVar maybe_ppv, SymbolVar input, const Param& param,
        const OperatorNodeConfig& config) {
    auto registry =
            CondExecPred::GlobalRegistry::get(*maybe_ppv.node()->owner_graph());
    auto mask = registry->get_mask_from_var(maybe_ppv.node());
    if (!mask) {
        return input;
    }
    // use the owner var of the mask as the predicate proxy
    return make_opr(mask->owner(), {input}, param, config)->output(0);
}
#if MGB_ENABLE_GRAD
/*!
 * Gradient of CondExecMark: the output gradient is wrapped in a single-output
 * CondExecMerge whose mode follows param().grad_mode. There is no gradient
 * for the trailing predicate input or when the output gradient is missing.
 */
MGB_IMPL_OPR_GRAD(CondExecMark) {
    const bool wrt_is_pred = wrt_idx == opr.input().size() - 1;
    if (wrt_is_pred || !out_grad.at(wrt_idx)) {
        return nullptr;
    }

    using GradMode = CondExecMark::Param::GradMode;
    using MergeMode = CondExecMerge::Param::Mode;
    MergeMode grad_mode;
    SymbolVarArray grad_shapes;

    const auto requested = opr.param().grad_mode;
    if (requested == GradMode::SUM) {
        grad_mode = MergeMode::SUM;
        // SUM merge is given the shape of the input being differentiated
        grad_shapes.emplace_back(SymbolVar{opr.input(wrt_idx)}.symshape());
    } else if (requested == GradMode::SUM_COND_OUT) {
        grad_mode = MergeMode::SUM_COND_OUT;
    } else {
        mgb_throw(MegBrainError, "invalid grad_mode");
    }

    auto merge = CondExecMerge::make_opr(
            {out_grad[wrt_idx]}, grad_shapes, {1, grad_mode},
            OperatorNodeConfig{});
    return merge->output(0);
}
#endif
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondExecMerge);

/*!
 * \brief construct a merge operator
 *
 * \param inputs branch values, laid out branch-major: the value of output j
 *      in branch i is inputs[i * nr_output + j]. Every input must have a
 *      mask recorded in the global registry, and all inputs of one branch
 *      must share the same mask.
 * \param out_shapes one shape var per output; must be given iff the mode is
 *      SUM
 * \param param merge mode and number of outputs
 *
 * In SUM_COND_OUT mode, the OR of all branch predicates is appended as the
 * last input.
 *
 * Throws GraphError if the input count is not a multiple of nr_output, if a
 * branch input dtype differs from the output dtype, or if inputs of one
 * branch have different masks.
 */
CondExecMerge::CondExecMerge(
        const VarNodeArrayView& inputs, const VarNodeArrayView& out_shapes,
        const Param& param, const OperatorNodeConfig& config)
        : Super(inputs[0]->owner_graph(), config, "cond_merge", {}), m_param{param} {
    // nr_output is used as a divisor below, so it has to be validated before
    // the divisibility check (a zero divisor would be undefined behaviour)
    mgb_assert(param.nr_output);
    mgb_throw_if(
            inputs.size() % param.nr_output, GraphError,
            "input size can not divide nr_output: %zu %u", inputs.size(),
            param.nr_output);
    auto global_registry = CondExecPred::GlobalRegistry::get(*owner_graph());
    auto nr_branch = inputs.size() / param.nr_output;

    // output dtypes are taken from the first branch
    for (size_t i = 0; i < param.nr_output; ++i) {
        auto ovar = add_output(ssprintf("out%zu", i));
        ovar->dtype(inputs[i]->dtype());
        ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
    }

    // mask2str() is otherwise only referenced from error-message arguments
    MGB_MARK_USED_VAR(mask2str);

    m_branch_masks.resize(nr_branch, nullptr);
    for (size_t i = 0; i < nr_branch; ++i) {
        ExecutionMask* br_mask = nullptr;
        for (size_t j = 0; j < param.nr_output; ++j) {
            auto ivar = inputs[i * param.nr_output + j];
            auto mask = global_registry->require_mask_from_var(ivar);
            mgb_throw_if(
                    output(j)->dtype() != ivar->dtype(), GraphError,
                    "CondExecMerge input dtypes mismatch: branch=%zu %s vs %s", i,
                    output(j)->dtype().name(), ivar->dtype().name());
            if (!j) {
                br_mask = mask;
            } else {
                // all inputs of a branch must be controlled by the same mask
                mgb_throw_if(
                        br_mask != mask, GraphError,
                        "CondExecMerge branch %zu have different masks: "
                        "%s vs %s",
                        i, mask2str(br_mask).c_str(), mask2str(mask).c_str());
            }
            mgb_assert(ivar->contain_flag(VarNode::Flag::NO_SYS_STATIC_MEM_ALLOC));
            add_input({ivar});
        }
        m_branch_masks[i] = br_mask;
    }
    add_equivalence_component<PODHash<Param>>(&m_param);

    if (param.mode == Mode::SUM) {
        // SUM mode takes one explicit shape input per output
        mgb_assert(out_shapes.size() == param.nr_output);
        for (auto i : out_shapes) {
            add_input({i});
        }
    } else {
        mgb_assert(
                out_shapes.empty(),
                "out_shapes should not be given if mode is not SUM");
    }

    if (param.mode == Mode::SUM_COND_OUT) {
        // the merge itself is conditional: it is controlled by the OR of all
        // branch predicates, passed as the trailing input
        VarNodeArray preds;
        preds.reserve(nr_branch);
        for (auto i : m_branch_masks) {
            preds.emplace_back(proxy_var_from_mask(i));
        }
        auto cn = mixin_infer_output_comp_node(*this, true);
        auto preds_or =
                CondExecPredLogical::make(preds, CondExecPredLogical::Mode::OR, cn);
        add_input({preds_or.node()});
    }
}
/*!
 * \brief create a CondExecMerge and insert it into the graph
 *
 * If \p out_shapes is empty in SUM mode, the shape of each output is taken
 * from the first branch input of that output whose shape is statically
 * inferable; GraphError is thrown if there is no such input.
 */
cg::OperatorNodeBase* CondExecMerge::make_opr(
        const VarNodeArrayView& inputs, const VarNodeArrayView& out_shapes,
        const Param& param, const OperatorNodeConfig& config) {
    mgb_assert(!inputs.empty());

    const VarNodeArrayView* out_shapes_ptr = &out_shapes;
    Maybe<VarNodeArrayView> out_shapes_from_inp;
    VarNodeArray out_shapes_from_inp_storage;

    const bool deduce_shapes = out_shapes.empty() && param.mode == Mode::SUM;
    if (deduce_shapes) {
        mgb_assert(inputs.size() % param.nr_output == 0);
        const size_t nr_branch = inputs.size() / param.nr_output;

        for (size_t oidx = 0; oidx < param.nr_output; ++oidx) {
            // first branch input of this output with a static shape
            VarNode* shape_src = nullptr;
            for (size_t br = 0; br < nr_branch && !shape_src; ++br) {
                auto candidate = inputs[br * param.nr_output + oidx];
                if (cg::is_static_var_shape(candidate)) {
                    shape_src = candidate;
                }
            }
            mgb_throw_if(
                    !shape_src, GraphError,
                    "out_shapes is omitted but no input shape is "
                    "inferrable for output %zu",
                    oidx);
            out_shapes_from_inp_storage.push_back(
                    SymbolVar{shape_src}.symshape().node());
        }
        out_shapes_ptr = &out_shapes_from_inp.emplace(out_shapes_from_inp_storage);
    }

    auto graph = inputs[0]->owner_graph();
    return graph->insert_opr(
            std::make_unique<CondExecMerge>(inputs, *out_shapes_ptr, param, config));
}
/*!
 * \brief register static shape and value inference for every output
 *
 * Shape inference depends on the mode:
 *  - EXACT_ONE_SAME_SHAPE / SUM_COND_OUT: identity of one branch input;
 *  - SUM: read from the value of the explicit shape input;
 *  - otherwise: shape of the single active branch.
 * Value inference either selects the single active branch (exact-one modes)
 * or sums all active branches (SUM / SUM_COND_OUT).
 */
void CondExecMerge::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
auto nr_out = m_param.nr_output;
// inputs are laid out branch-major; with branch == number of branches this
// addresses the inputs that follow the branch inputs
auto inp = [this, nr_out](size_t branch, size_t oidx) {
return input(branch * nr_out + oidx);
};
// index of the only branch whose predicate value is non-zero; throws
// GraphError if zero or more than one branch is active
static auto select_one_branch = [](size_t nr_branch, const InpVal& bval) -> size_t {
bool found = false;
size_t ret;
for (size_t i = 0; i < nr_branch; ++i) {
if (bval.val[i].value().ptr<int>()[0]) {
if (!found) {
found = true;
ret = i;
} else {
mgb_throw(
GraphError,
"multiple branches are active in EXACT_ONE mode: "
"%zu and %zu",
ret, i);
}
}
}
mgb_throw_if(!found, GraphError, "no branch is active in EXACT_ONE mode");
return ret;
};
// value deps on the predicate proxy var of every branch; they occupy the
// first nr_branch slots of the descs built from branch_deps
DepVal branch_deps;
auto nr_branch = m_branch_masks.size();
branch_deps.reserve(nr_branch);
for (size_t i = 0; i < nr_branch; ++i) {
branch_deps.push_back({proxy_var_from_mask(m_branch_masks[i]), DepType::VALUE});
}
for (size_t oidx = 0; oidx < nr_out; oidx++) {
if (m_param.mode == Mode::EXACT_ONE_SAME_SHAPE ||
m_param.mode == Mode::SUM_COND_OUT) {
// forward the shape of the first branch input with a static shape,
// falling back to branch 0 if there is none
bool found = false;
for (size_t i = 0; i < nr_branch; ++i) {
if (cg::is_static_var_shape(inp(i, oidx))) {
mgr.register_shape_infer(
output(oidx), ShapeInferDesc::make_identity(inp(i, oidx)));
found = true;
break;
}
}
if (!found) {
mgr.register_shape_infer(
output(oidx), ShapeInferDesc::make_identity(inp(0, oidx)));
}
} else if (m_param.mode == Mode::SUM) {
// the output shape is the value of the shape input that follows the
// branch inputs
auto infer_fn = [](TensorShape& dst, const InpVal& inp) {
cg::copy_tensor_value_to_shape(dst, inp.val[0].value());
return true;
};
mgr.register_shape_infer(
output(oidx), {SourceType::DEP,
{{inp(nr_branch, oidx), DepType::VALUE}},
infer_fn});
} else {
// deps: [branch predicates..., branch input shapes...]; take the shape
// of the single active branch
auto infer_fn = [this](TensorShape& dest, const InpVal& inp) {
auto nr_branch = m_branch_masks.size();
size_t branch = select_one_branch(nr_branch, inp);
dest = inp.val.at(nr_branch + branch).shape();
return true;
};
ShapeInferDesc desc{SourceType::DEP, branch_deps, infer_fn};
for (size_t i = 0; i < nr_branch; ++i) {
desc.deps.push_back({inp(i, oidx), DepType::SHAPE});
}
mgr.register_shape_infer(output(oidx), desc);
}
// value inference; deps: [branch predicates..., branch input values...]
ValueInferDesc desc{SourceType::DEP, branch_deps, {}};
for (size_t i = 0; i < nr_branch; ++i) {
desc.deps.push_back({inp(i, oidx), DepType::VALUE});
}
if (is_exact_one()) {
// value of the single active branch
desc.infer_func = [this](DeviceTensorND& dest, const InpVal& inp) {
auto nr_branch = m_branch_masks.size();
size_t branch = select_one_branch(nr_branch, inp);
dest = inp.val.at(nr_branch + branch).value();
return true;
};
} else {
mgb_assert(m_param.mode == Mode::SUM || m_param.mode == Mode::SUM_COND_OUT);
// sum of the values of all active branches
desc.infer_func = [this](DeviceTensorND& dest, const InpVal& inp) {
auto nr_branch = m_branch_masks.size();
bool found = false, first = true;
// shape of the first branch input; an empty shape skips accumulation
auto&& shape = inp.val.at(nr_branch).shape();
for (size_t i = 0; i < nr_branch && !shape.is_empty(); ++i) {
if (!inp.val[i].value().ptr<int>()[0])
continue;
auto&& cur = inp.val.at(nr_branch + i).value();
if (!found) {
// first active branch: dest is assigned from the input value
found = true;
dest = cur;
} else {
if (first) {
// before the first accumulation, replace dest by a copy so the
// in-place add below does not write into the tensor dest was
// assigned from
first = false;
DeviceTensorND tmp;
tmp.copy_from(dest);
dest = std::move(tmp);
}
auto dnn_opr = intl::create_megdnn_opr<megdnn::Elemwise>(
dest.comp_node());
dnn_opr->param().mode = Elemwise::Mode::ADD;
dnn_opr->exec(
{dest.as_megdnn(), cur.as_megdnn()}, dest.as_megdnn());
}
}
if (!found) {
// nothing accumulated: the result is a zero-filled tensor; a fresh
// tensor is created when the storage of dest has other users
if (dest.storage().raw_storage().use_count() > 1) {
DeviceTensorND tmp{dest.comp_node(), shape, dest.dtype()};
dest = std::move(tmp);
} else {
dest.resize(shape);
}
fill_zero_dev_tensor(dest);
}
return true;
};
}
mgr.register_value_infer(output(oidx), desc);
}
}
/*!
 * \brief merge the outputs of all enabled branches
 *
 * The outputs are reset from the inputs of the first enabled branch; inputs
 * of further enabled branches are added elementwise. Having more than one
 * enabled branch, or none, is a GraphError in exact-one modes. If no branch
 * is enabled otherwise, the mode is asserted to be SUM and the outputs are
 * zero-filled using the statically inferred shapes.
 */
void CondExecMerge::scn_do_execute() {
auto nr_out = m_param.nr_output;
// inputs are laid out branch-major
auto inp = [this, nr_out](size_t branch, size_t oidx) {
return input(branch * nr_out + oidx);
};
auto cn = this->comp_node();
mgb_assert(cn == output(0)->comp_node());
// true until the first enabled branch has been handled
bool first = true;
// per output: result of reset_dev_tensor_from_other_var() for the first
// enabled branch
auto&& forwarded = m_mem_forwarded;
// per output: whether the input of the first enabled branch has an empty shape
std::vector<bool> is_shape_empty(nr_out, false);
for (size_t br = 0; br < m_branch_masks.size(); ++br) {
if (!m_branch_masks[br]->enabled()) {
continue;
}
if (first) {
first = false;
for (size_t oidx = 0; oidx < nr_out; ++oidx) {
bool succ =
output(oidx)->reset_dev_tensor_from_other_var(inp(br, oidx));
if (inp(br, oidx)->shape().is_empty()) {
is_shape_empty[oidx] = true;
continue;
}
if (!is_exact_one()) {
// forwarded is sized on first use
if (forwarded.empty()) {
forwarded.resize(nr_out);
}
forwarded[oidx] = succ;
}
}
} else {
mgb_throw_if(
is_exact_one(), GraphError,
"multiple branches are active in EXACT_ONE mode");
// (re)create the cached elemwise ADD opr if missing or on another comp node
auto&& dnn_opr = m_exec_dnn_opr;
if (!dnn_opr || dnn_opr.comp_node() != cn) {
dnn_opr = intl::create_megdnn_opr<megdnn::Elemwise>(cn);
dnn_opr->param().mode = Elemwise::Mode::ADD;
}
for (size_t oidx = 0; oidx < nr_out; ++oidx) {
auto ovar = output(oidx);
auto&& src = inp(br, oidx)->dev_tensor().as_megdnn();
auto&& dest = ovar->dev_tensor().as_megdnn();
mgb_assert(
src.layout.eq_shape(dest.layout),
"shape mismatch: %s vs %s in CondExecMerge",
src.layout.to_string().c_str(),
dest.layout.to_string().c_str());
if (is_shape_empty[oidx])
continue;
if (forwarded[oidx]) {
// the flag is set from the result of the reset above; allocate storage
// for the output and write the sum there rather than adding in place
ovar->shape_alloc(ovar->shape());
auto&& own_dest = ovar->dev_tensor().as_megdnn();
mgb_assert(own_dest.raw_ptr() != dest.raw_ptr());
dnn_opr->exec({dest, src}, own_dest);
forwarded[oidx] = false;
} else {
// accumulate in place
dnn_opr->exec({dest, src}, dest);
}
}
}
}
if (first) {
// no branch was enabled
mgb_throw_if(
is_exact_one(), GraphError, "no branch is selected in EXACT_ONE mode");
mgb_assert(m_param.mode == Param::Mode::SUM);
auto&& mgr = owner_graph()->static_infer_manager();
for (auto var : output()) {
auto&& dv = var->shape_alloc(mgr.infer_shape(var)).dev_tensor();
fill_zero_dev_tensor(dv);
}
} else if (m_param.mode == Param::Mode::SUM) {
// check the actual output shapes against the statically inferred ones
auto&& mgr = owner_graph()->static_infer_manager();
for (auto var : output()) {
auto&& shp_infer = mgr.infer_shape(var);
auto&& shp_got = var->shape();
mgb_throw_if(
!shp_infer.eq_shape(shp_got), GraphError,
"inferred shape is %s, actual shape is %s",
shp_infer.to_string().c_str(), shp_got.to_string().c_str());
}
}
}
//! all inputs are required to be contiguous
void CondExecMerge::add_input_layout_constraint() {
    for (auto ivar : input()) {
        ivar->add_layout_constraint_contiguous();
    }
}
/*!
 * \brief adjust input dependency types
 *
 * SUM: the trailing nr_output inputs are HOST_VALUE deps, all others
 * DEV_VALUE. SUM_COND_OUT: the trailing input is a DEV_COMP_ORDER dep.
 * In all modes, branch inputs additionally allow empty values.
 */
CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    using DT = NodeProp::DepType;

    const auto mode = m_param.mode;
    if (mode == Mode::SUM) {
        SmallVector<DT> dep_types(input().size(), DT::DEV_VALUE);
        // walk backwards over the trailing inputs
        for (size_t k = 0; k < m_param.nr_output; ++k) {
            dep_types[dep_types.size() - k - 1] = DT::HOST_VALUE;
        }
        prop->reset_dep_type(input(), dep_types);
    } else if (mode == Mode::SUM_COND_OUT) {
        prop->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER;
    }

    const size_t nr_branch_input = m_param.nr_output * m_branch_masks.size();
    for (size_t i = 0; i < nr_branch_input; ++i) {
        prop->add_dep_type_existing_var(
                input(i), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    }
    return prop;
}
#if MGB_ENABLE_GRAD
/*!
 * Gradient of CondExecMerge w.r.t. a branch input: the output gradient
 * marked with the predicate of that branch. There is no gradient for the
 * trailing predicate input (SUM_COND_OUT), and the gradient w.r.t. the
 * trailing shape inputs (SUM) is invalid.
 */
MGB_IMPL_OPR_GRAD(CondExecMerge) {
    using Mode = CondExecMerge::Param::Mode;
    const auto merge_mode = opr.param().mode;

    if (merge_mode == Mode::SUM_COND_OUT && wrt_idx == opr.input().size() - 1) {
        return nullptr;
    }
    if (merge_mode == Mode::SUM &&
        wrt_idx >= opr.input().size() - opr.output().size()) {
        return InvalidGrad::make(opr, wrt_idx);
    }

    // inputs are laid out branch-major
    size_t wrt_branch = wrt_idx / opr.param().nr_output;
    size_t wrt_oidx = wrt_idx % opr.param().nr_output;
    auto og = out_grad.at(wrt_oidx);
    if (!og) {
        return nullptr;
    }

    auto ppv = proxy_var_from_mask(opr.branch_mask(wrt_branch));
    const bool same_mem_node =
            ppv->comp_node().mem_node() == og->comp_node().mem_node();
    if (!same_mem_node) {
        // wrap the predicate in a single-input AND on the comp node of the
        // output gradient
        ppv = CondExecPredLogical::make(
                      {ppv}, CondExecPredLogical::Mode::AND, og->comp_node())
                      .node();
    }

    CondExecMark::Param gparam;
    if (merge_mode == Mode::EXACT_ONE) {
        gparam.static_infer = CondExecMark::Param::StaticInfer::NONE;
    }
    auto mark = CondExecMark::make_opr(
            ppv, {og}, gparam, OperatorNodeConfig{og->comp_node()});
    return mark->output(0);
}
#endif
/*!
 * \brief rewrite a list of gradients that are to be summed so that
 *      conditional components are combined by a single CondExecMerge
 *
 * \param wrt var the gradients are taken with respect to; only its shape is
 *      used, for a SUM merge
 * \param[in,out] grads gradient vars; entries may be reordered, and when the
 *      final condition holds the collected entries are removed and one merged
 *      var is appended
 *
 * Returns without changes if no ExecutionMask instance is alive or the graph
 * has no GlobalRegistry.
 */
void CondExecMerge::modify_grad_sum_list(VarNode* wrt, VarNodeArray& grads) {
if (!ExecutionMask::have_alive_instance()) {
return;
}
// look up the registry without creating it
auto global_registry_vec =
grads.at(0)
->owner_graph()
->options()
.user_data.get_user_data<CondExecPred::GlobalRegistry>();
if (!global_registry_vec.second) {
return;
}
auto global_registry = global_registry_vec.first[0];
// nr_var_remove: number of entries swapped to the tail of grads
size_t nr_var_remove = 0, nr_merge_opr = 0;
// inputs of the merge opr that may be created at the end
VarNodeArray merged_branches;
static constexpr Param::Mode BAD_MODE = static_cast<Param::Mode>(-1);
Param::Mode merged_mode = BAD_MODE;
// the first mask seen, and whether a different one was seen afterwards
ExecutionMask* part_exec_mask = nullptr;
bool have_multiple_exec_mask = false;
auto check_multiple_mask = [&part_exec_mask,
&have_multiple_exec_mask](ExecutionMask* mask) {
if (!part_exec_mask) {
part_exec_mask = mask;
} else if (part_exec_mask != mask) {
have_multiple_exec_mask = true;
}
};
// iterate backwards; collected entries are swapped towards the end of grads
// so that they can be dropped by a single resize
for (size_t i = grads.size(); i;) {
--i;
auto opr = grads[i]->owner_opr();
if (opr->same_type<CondExecMerge>()) {
// an existing merge: its branch inputs are collected directly
mgb_assert(
opr->output().size() == 1,
"CondExecMerge in grad list has multiple outputs: "
"name=%s out=%zu",
opr->cname(), opr->output().size());
auto cur_mode = opr->cast_final<CondExecMerge>().param().mode;
mgb_assert(
cur_mode == Param::Mode::SUM ||
cur_mode == Param::Mode::SUM_COND_OUT);
// once SUM_COND_OUT has been chosen it is kept
if (merged_mode != Param::Mode::SUM_COND_OUT) {
merged_mode = cur_mode;
}
merged_branches.insert(
merged_branches.end(), opr->input().begin(), opr->input().end());
if (cur_mode == Param::Mode::SUM_COND_OUT) {
// drop the trailing predicate input
mgb_assert(opr->input().size() == opr->output().size() + 1);
merged_branches.pop_back();
check_multiple_mask(
global_registry->require_mask_from_var(opr->output(0)));
} else if (cur_mode == Param::Mode::SUM) {
// drop the trailing shape inputs
mgb_assert(opr->input().size() >= opr->output().size() * 2);
merged_branches.resize(merged_branches.size() - opr->output().size());
}
++nr_merge_opr;
++nr_var_remove;
std::swap(grads[grads.size() - nr_var_remove], grads[i]);
} else if (auto mask = global_registry->get_mask_from_var(grads[i])) {
// a var with a recorded mask becomes a branch of the merged opr, which
// then has to be SUM_COND_OUT
check_multiple_mask(mask);
merged_branches.push_back(grads[i]);
++nr_var_remove;
std::swap(grads[grads.size() - nr_var_remove], grads[i]);
merged_mode = Param::Mode::SUM_COND_OUT;
}
}
// merge only if different masks were seen or more than one merge opr was
// collected; otherwise grads keeps all entries (possibly reordered)
if (have_multiple_exec_mask || nr_merge_opr > 1) {
mgb_assert(merged_mode != BAD_MODE);
grads.resize(grads.size() - nr_var_remove);
SymbolVarArray grad_shapes;
if (merged_mode == Param::Mode::SUM) {
grad_shapes.emplace_back(SymbolVar{wrt}.symshape());
}
grads.push_back(CondExecMerge::make_opr(
merged_branches, grad_shapes, {1, merged_mode},
OperatorNodeConfig{})
->output(0));
}
}
#endif