#include "megbrain/jit/executor_opr.h"
#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/gopt/framework.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/graph/helper.h"
#include "megbrain/jit/compiler.h"
#include "megbrain/jit/param_elem_visitor.h"
#include "megbrain/jit/placeholder_opr.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/utils/hash.h"
#if MGB_JIT
// Bring the mgb and jit namespaces into scope for the rest of this file.
using namespace mgb;
using namespace jit;
// Shorthand for the compiler property flags queried throughout this file.
using CPFlag = Compiler::Property::Flag;
// NOTE(review): presumably emits the dynamic-type boilerplate for JITExecutor;
// confirm against the macro definition.
MGB_DYN_TYPE_OBJ_FINAL_IMPL(JITExecutor);
/*!
 * \brief Construct an executor that evaluates \p internal_graph on \p inputs.
 *
 * Besides registering inputs/outputs, the constructor
 *  - decides per input whether it may be broadcast (m_input_broadcastable),
 *  - adds the workspace outputs requested by the compiler,
 *  - scans the internal graph to record Reduce/Dimshuffle feature bits and to
 *    verify that every placeholder is reachable from the graph output,
 *  - prepares dimshuffle parameters when a Dimshuffle opr was found.
 *
 * \param internal_graph fused sub-graph; needs one placeholder per input
 * \param inputs vars fed to the placeholders; inputs[0] selects the compiler
 * \param config operator node config forwarded to the base class
 */
JITExecutor::JITExecutor(
        const InternalGraphPtr& internal_graph, const VarNodeArray& inputs,
        const OperatorNodeConfig& config)
        : Super(internal_graph->output()->owner_graph(), config,
                ssprintf("JIT-Fusion{%zu}", internal_graph->placeholders().size()),
                inputs),
          m_internal_graph{internal_graph},
          m_compiler{Compiler::get(*inputs[0]->owner_graph(), inputs[0]->comp_node())} {
    for (auto var : inputs) {
        add_input({var});
    }

    const size_t nr_inputs = inputs.size();
    m_input_broadcastable.resize(nr_inputs);
    auto&& placeholders = m_internal_graph->placeholders();
    mgb_assert(placeholders.size() == inputs.size());

    // First pass: an input is broadcastable unless it is a host-value shape
    // input or is produced by MarkNoBroadcastElemwise.
    for (size_t i = 0; i < nr_inputs; ++i) {
        mgb_assert(placeholders[i]->output(0) != internal_graph->output());
        const bool never_broadcast =
                placeholders[i]->is_host_value_shape_input() ||
                input()[i]->owner_opr()->same_type<opr::MarkNoBroadcastElemwise>();
        m_input_broadcastable[i] = !never_broadcast;
    }

    // Second pass: a lone input, or the only input that is not a
    // constant-shape scalar, is never broadcast.
    if (nr_inputs == 1) {
        m_input_broadcastable[0] = false;
    } else {
        Maybe<size_t> non_scalar;
        for (size_t i = 0; i < input().size(); ++i) {
            if (placeholders[i]->is_host_value_shape_input()) {
                continue;
            }
            const bool const_scalar = cg::is_const_var_shape(input(i)) &&
                                      input(i)->shape().is_scalar();
            if (const_scalar) {
                continue;
            }
            if (non_scalar.valid()) {
                // a second non-scalar input was found: give up
                non_scalar.invalidate();
                break;
            }
            non_scalar = i;
        }
        if (non_scalar.valid()) {
            m_input_broadcastable[non_scalar.val()] = false;
        }
    }

    add_output(None)->dtype(m_internal_graph->output()->dtype());
    add_equivalence_component<ScalarHash<void*>>(internal_graph->output());

    const size_t nr_workspace = m_compiler->get_nr_workspace_outputs(this);
    for (size_t k = 0; k < nr_workspace; ++k) {
        cg::add_workspace_output(this);
    }

    // Walk every opr the internal graph output depends on: mark the
    // placeholders that are actually used and collect the feature bits.
    size_t nr_placeholders = internal_graph_ptr()->placeholders().size();
    std::vector<bool> used(nr_placeholders, false);
    cg::DepOprIter{[this, nr_placeholders, &used](cg::OperatorNodeBase* opr) {
        if (auto ph = opr->try_cast_final<JITPlaceholder>()) {
            mgb_assert(
                    ph->input_id() < nr_placeholders,
                    "bad placeholders %s in JITExecutor %s", ph->cname(), cname());
            used[ph->input_id()] = true;
            return;
        }
        if (opr->same_type<opr::Reduce>()) {
            m_feature_bits |= JITFeatureBits::REDUCE;
        }
        if (opr->same_type<opr::Dimshuffle>()) {
            m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
        }
    }}.add(internal_graph->output());

    for (size_t i = 0; i < nr_placeholders; ++i) {
        mgb_assert(
                used[i], "placeholder %s is not depended on the output of %s",
                internal_graph_ptr()->placeholders()[i]->cname(), cname());
    }

    if (has_dimshuffle()) {
        prepare_dimshuffle();
    }
}
/*!
 * \brief Constrain the layout of every input var: contiguous when the compiler
 *      sets NEED_INPUT_CONTIG, monotone otherwise.
 */
void JITExecutor::add_input_layout_constraint() {
    const bool need_contig =
            m_compiler->property().contain_flag(CPFlag::NEED_INPUT_CONTIG);
    for (auto var : input()) {
        if (need_contig) {
            var->add_layout_constraint_contiguous();
        } else {
            var->add_layout_constraint_monotone();
        }
    }
}
/*!
 * \brief Delegate output memory planning to the base class and mark the
 *      cached args as stale.
 * \param dynamic forwarded unchanged to Super::init_output_mem_plan
 */
void JITExecutor::init_output_mem_plan(bool dynamic) {
    Super::init_output_mem_plan(dynamic);
    // args() calls update_args() the next time it sees need_update set
    m_args.need_update = true;
}
/*!
 * \brief Allow writable input-to-output memory forwarding, but only when no
 *      JIT feature bit (e.g. REDUCE / DIMSHUFFLE) has been recorded.
 */
void JITExecutor::mem_plan_fwd_in2out_writable() {
    if (m_feature_bits == JITFeatureBits::NONE) {
        mixin_mem_plan_fwd_in2out_writable(*this);
    }
}
/*!
 * \brief Create a JITExecutor and insert it into the graph that owns the
 *      output var of \p internal_graph.
 * \return output(0) of the operator returned by insert_opr
 */
SymbolVar JITExecutor::make(
        const InternalGraphPtr& internal_graph, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto graph = internal_graph->output()->owner_graph();
    auto inserted = graph->insert_opr(
            std::make_unique<JITExecutor>(internal_graph, inputs, config));
    return inserted->output(0);
}
/*!
 * \brief Register static inference for output(0).
 *
 * The shape is inferred as identity of the internal graph's shape_infer var;
 * the value is inferred as identity of value_infer when the internal graph
 * provides one. Workspace size inference is delegated to the compiler.
 */
void JITExecutor::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& infer_mgr = owner_graph()->static_infer_manager();

    infer_mgr.register_shape_infer(
            output(0), ShapeInferDesc::make_identity(m_internal_graph->shape_infer()));

    m_compiler->init_workspace_size_infer(this);

    if (auto value_src = m_internal_graph->value_infer()) {
        infer_mgr.register_value_infer(
                output(0), ValueInferDesc::make_identity(value_src));
    }
}
/*!
 * \brief Run the compiled kernel, (re)compiling first when there is no
 *      executable yet or the cached args are marked stale.
 */
void JITExecutor::scn_do_execute() {
    const bool need_compile = (m_executable == nullptr) || m_args.need_update;
    if (need_compile) {
        m_executable = m_compiler->compile(this);
    }
    m_executable->execute(this);
}
/*!
 * \brief Rewrite the layouts in m_args.inputs according to the dimshuffle
 *      patterns recorded in m_jitph2dimshuffle.
 *
 * Placeholders without an entry in m_jitph2dimshuffle are left untouched.
 * For the others the input layout must have exactly the ndim recorded with
 * the pattern; its shape/stride entries are then permuted by the pattern.
 */
void JITExecutor::do_dimshuffle() {
    // Apply \p pattern to \p ily: output axis idx takes shape/stride of input
    // axis pattern[idx]; a negative entry creates a new axis with shape 1 and
    // stride 1. The pattern is taken by const reference to avoid copying it
    // on every call.
    static auto get_dimshuffled_layout = [](const TensorLayout& ily,
                                            const std::vector<int>& pattern) {
        TensorLayout oly{ily.dtype};
        oly.ndim = pattern.size();
        bool input_used[TensorLayout::MAX_NDIM] = {0};
        for (size_t idx = 0; idx < pattern.size(); ++idx) {
            const int src_axis = pattern[idx];
            if (src_axis < 0) {
                oly.shape[idx] = 1;
                oly.stride[idx] = 1;
            } else {
                input_used[src_axis] = true;
                oly.shape[idx] = ily.shape[src_axis];
                oly.stride[idx] = ily.stride[src_axis];
            }
        }
        // an input axis may only be dropped when its extent is 1
        for (size_t i = 0; i < ily.ndim; ++i) {
            mgb_assert(
                    input_used[i] || ily.shape[i] == 1,
                    "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zu",
                    static_cast<const TensorShape&>(ily).to_string().c_str(), i);
        }
        return oly;
    };

    for (auto&& i : m_internal_graph->placeholders()) {
        auto&& input = m_args.inputs[i->input_id()];
        auto&& iter = m_jitph2dimshuffle.find(i);
        if (iter == m_jitph2dimshuffle.end()) {
            continue;
        }
        // param.first: merged pattern; param.second: expected input ndim
        auto&& param = iter->second;
        mgb_assert(
                input.layout.ndim == param.second,
                "input ndim mismatch for Dimshuffle: "
                "expect=%u "
                "actual=%zu",
                param.second, input.layout.ndim);
        input.layout = get_dimshuffled_layout(input.layout, param.first);
    }
}
/*!
 * \brief Rebuild m_args from the current input/output vars.
 *
 * Steps, in order:
 *  1. record every output var together with its layout;
 *  2. record every input: host-value shape inputs get a layout whose shape is
 *     read from the statically inferred value, with default dtype and zero
 *     strides; all other inputs use the layout of the var;
 *  3. apply recorded dimshuffle patterns, if any;
 *  4. collapse layouts when the compiler sets NEED_INPUT_COLLAPSE;
 *  5. hash ndim (and shape, with BIND_SHAPE) of all inputs;
 *  6. assign a fresh version number and clear need_update.
 */
void JITExecutor::update_args() {
    m_args.outputs.clear();
    for (auto out : output()) {
        m_args.outputs.push_back({out, out->layout(), -1});
    }

    const size_t nr_inputs = input().size();
    m_args.inputs.resize(nr_inputs);

    auto is_host_value_shape_input = [this](size_t idx) {
        return m_internal_graph->placeholders().at(idx)->is_host_value_shape_input();
    };

    for (size_t i = 0; i < nr_inputs; ++i) {
        auto&& dst_data = m_args.inputs[i];
        dst_data.from = input(i);
        dst_data.idx = i;

        if (!is_host_value_shape_input(i)) {
            dst_data.layout = input(i)->layout();
            continue;
        }

        // the value of this input describes a shape: turn it into a layout
        auto&& mgr = owner_graph()->static_infer_manager();
        auto&& shpval_inp_val = mgr.infer_value(input(i));
        cg::copy_tensor_value_to_shape(dst_data.layout, shpval_inp_val);
        dst_data.layout.dtype = {};
        for (size_t axis = 0; axis < dst_data.layout.ndim; ++axis) {
            dst_data.layout.stride[axis] = 0;
        }
    }

    if (has_dimshuffle()) {
        do_dimshuffle();
    }

    if (m_compiler->property().contain_flag(CPFlag::NEED_INPUT_COLLAPSE)) {
        // host-value shape inputs do not take part in the collapse
        opr::Elemwise::TensorLayoutPtrArray inp_layouts;
        inp_layouts.reserve(m_args.inputs.size());
        for (size_t i = 0; i < m_args.inputs.size(); ++i) {
            if (is_host_value_shape_input(i)) {
                continue;
            }
            inp_layouts.push_back(&m_args.inputs[i].layout);
        }
        opr::Elemwise::broadcast_collective_collapse(
                inp_layouts, &m_args.outputs[0].layout);
    }

    XXHash hstate;
    auto prop = m_compiler->property();
    if (prop.contain_flag(CPFlag::BIND_NDIM | CPFlag::BIND_SHAPE)) {
        mgb_assert(
                prop.contain_flag(CPFlag::BIND_NDIM),
                "BIND_NDIM must be set if bind_shape is set");
        const bool bind_shape = prop.contain_flag(CPFlag::BIND_SHAPE);

        // serialized as: nr_inputs, then per input ndim [, shape...]
        std::vector<size_t> buf;
        buf.reserve(1024);
        buf.push_back(m_args.inputs.size());
        for (auto&& inp : m_args.inputs) {
            buf.push_back(inp.layout.ndim);
            if (!bind_shape) {
                continue;
            }
            for (size_t j = 0; j < inp.layout.ndim; ++j) {
                buf.push_back(inp.layout[j]);
            }
        }
        hstate.update(buf.data(), sizeof(buf[0]) * buf.size());
    }
    m_args.hash = hstate.digest();

    // every rebuild draws a new value from a counter shared by all executors
    static std::atomic_uint_fast64_t global_version;
    m_args.version = global_version.fetch_add(1);

    m_args.need_update = false;
}
/*!
 * \brief Fill m_jitph2dimshuffle with the merged dimshuffle pattern that
 *      applies to each placeholder reached below at least one Dimshuffle opr.
 *
 * The internal graph is traversed depth-first from its output opr using an
 * explicit stack. While descending, the patterns of all Dimshuffle oprs on the
 * current path are merged; when a placeholder is reached for the first time
 * with a non-empty dimshuffle stack, the current merged pattern is stored for
 * it. Each opr is expanded at most once (tracked in \c visited).
 */
void JITExecutor::prepare_dimshuffle() {
    std::unordered_set<OperatorNodeBase*> visited;
    std::vector<OperatorNodeBase*> stack;
    // idx[k] is the index of the next input of stack[k] to descend into.
    // size_t is used so that the cursor can always exceed input().size();
    // a narrower type would wrap around for oprs with 256 or more inputs.
    std::vector<size_t> idx;
    using Param = DimshuffleParam;
    // dimshuffle_stack.back() is the merged pattern of all Dimshuffle oprs on
    // the current path; .first is the pattern, .second the input ndim.
    std::vector<Param> dimshuffle_stack;

    auto merge_dimshuffle = [&](const opr::Dimshuffle::Param& p) {
        if (dimshuffle_stack.empty()) {
            // first Dimshuffle on the path: take its pattern as is
            dimshuffle_stack.emplace_back();
            auto&& param = dimshuffle_stack.back();
            param.first.insert(param.first.end(), p.pattern, p.pattern + p.pattern_len);
            param.second = p.ndim;
        } else {
            // compose with the pattern merged so far (src): each axis of src
            // is looked up in p.pattern, -1 entries stay -1
            dimshuffle_stack.emplace_back();
            auto&& param = dimshuffle_stack.back();
            auto&& src = dimshuffle_stack[dimshuffle_stack.size() - 2];
            mgb_assert(p.pattern_len == src.second);
            param.first.resize(src.first.size());
            for (size_t i = 0; i < src.first.size(); ++i) {
                if (src.first[i] == -1) {
                    param.first[i] = -1;
                } else {
                    param.first[i] = p.pattern[src.first[i]];
                }
            }
            param.second = p.ndim;
        }
    };
    // enter an opr: push a frame and, for Dimshuffle, a merged pattern
    auto push_back = [&](cg::OperatorNodeBase* op) {
        mgb_assert(!op->same_type<jit::JITPlaceholder>());
        if (auto o = op->try_cast_final<opr::Dimshuffle>()) {
            merge_dimshuffle(o->param());
        }
        stack.push_back(op);
        idx.push_back(0);
    };
    // leave an opr: undo what push_back did
    auto pop_back = [&]() {
        auto&& op = stack.back();
        if (op->same_type<opr::Dimshuffle>()) {
            dimshuffle_stack.pop_back();
        }
        stack.pop_back();
        idx.pop_back();
    };

    push_back(m_internal_graph->output()->owner_opr());
    while (!stack.empty()) {
        if (idx.back() < stack.back()->input().size()) {
            auto cur_opr = stack.back()->input(idx.back())->owner_opr();
            if (visited.insert(cur_opr).second) {
                if (auto jitph = cur_opr->try_cast_final<jit::JITPlaceholder>()) {
                    // placeholders are leaves: record the pattern, do not push
                    if (!dimshuffle_stack.empty()) {
                        mgb_assert(
                                m_jitph2dimshuffle
                                        .emplace(jitph, dimshuffle_stack.back())
                                        .second,
                                "already visited JITPlaceholder %s", jitph->cname());
                    }
                    ++idx.back();
                } else {
                    push_back(cur_opr);
                }
            } else {
                // already expanded through another path: skip
                ++idx.back();
            }
        } else {
            // all inputs of the top frame handled: return to the parent and
            // advance its cursor past the input we just finished
            pop_back();
            if (!stack.empty())
                ++idx.back();
        }
    }
}
/*!
 * \brief Return the cached args, rebuilding them first when they are stale.
 *
 * The rebuild mutates the executor, hence the const_cast in this const
 * accessor.
 */
const JITExecutor::Args& JITExecutor::args() const {
    if (!m_args.need_update) {
        return m_args;
    }
    const_cast<JITExecutor*>(this)->update_args();
    return m_args;
}
/*!
 * \brief Compare two up-to-date Args objects.
 *
 * Different hashes mean unequal, equal versions mean equal. Otherwise the
 * number of inputs/outputs must agree and, when the compiler sets BIND_NDIM
 * or BIND_SHAPE, every input and output layout is compared (by shape with
 * BIND_SHAPE, by ndim otherwise). On success both objects are assigned the
 * smaller of the two versions so that the next comparison hits the version
 * shortcut.
 */
bool JITExecutor::Args::operator==(const Args& rhs) const {
    auto&& lhs = *this;
    mgb_assert(!lhs.need_update && !rhs.need_update);

    if (lhs.hash != rhs.hash) {
        return false;
    }
    if (lhs.version == rhs.version) {
        return true;
    }
    if (lhs.outputs.size() != rhs.outputs.size() ||
        lhs.inputs.size() != rhs.inputs.size()) {
        return false;
    }

    auto prop = owner->m_compiler->property();
    if (prop.contain_flag(CPFlag::BIND_NDIM | CPFlag::BIND_SHAPE)) {
        const bool cmp_shape = prop.contain_flag(CPFlag::BIND_SHAPE);
        auto layout_match = [cmp_shape](const TensorLayout& a, const TensorLayout& b) {
            return cmp_shape ? a.eq_shape(b) : a.ndim == b.ndim;
        };

        for (size_t i = 0; i < lhs.inputs.size(); ++i) {
            if (!layout_match(lhs.inputs[i].layout, rhs.inputs[i].layout)) {
                return false;
            }
        }
        for (size_t i = 0; i < lhs.outputs.size(); ++i) {
            if (!layout_match(lhs.outputs[i].layout, rhs.outputs[i].layout)) {
                return false;
            }
        }
    }

    lhs.version = rhs.version = std::min(lhs.version, rhs.version);
    return true;
}
/*!
 * \brief Build the node property and set the dependency type of each input:
 *      HOST_VALUE for host-value shape inputs, DEV_VALUE for all others.
 */
JITExecutor::NodeProp* JITExecutor::do_make_node_prop() const {
    using DepType = NodeProp::DepType;

    auto prop = Super::do_make_node_prop();
    SmallVector<DepType> dep_types(input().size());
    auto&& placeholders = internal_graph().placeholders();
    for (size_t i = 0; i < dep_types.size(); ++i) {
        if (placeholders[i]->is_host_value_shape_input()) {
            dep_types[i] = DepType::HOST_VALUE;
        } else {
            dep_types[i] = DepType::DEV_VALUE;
        }
    }
    prop->reset_dep_type(input(), dep_types);
    return prop;
}
/*!
 * \brief Deduce the elemwise-broadcast shape of all inputs that are not
 *      host-value shape inputs.
 * \return the shape written by megdnn::Elemwise::deduce_shape
 */
megdnn::TensorShape JITExecutor::broadcasted_input_shape() const {
    megdnn::TensorShapeArray inp_shps;
    megdnn::TensorShape brdcast_shp;
    // bind by reference: the placeholder array is only read here, so copying
    // it would be wasted work
    auto&& placeholders = m_internal_graph->placeholders();
    for (auto&& ph : placeholders) {
        if (!ph->is_host_value_shape_input()) {
            inp_shps.push_back(input(ph->input_id())->shape());
        }
    }
    megdnn::Elemwise::deduce_shape(inp_shps, brdcast_shp);
    return brdcast_shp;
}
#if MGB_ENABLE_GRAD
namespace {
/*!
 * \brief Helper that rewrites the dependency graph of a destination var.
 *
 * A rewrite pass is started with iter(): the callback is invoked for every
 * opr the destination var depends on and decides, via replace_var() or
 * auto_replace_outputs(), which vars get substituted. After the pass the
 * destination var itself is updated to its replacement, if any.
 */
class InternalGraphRewriter {
    //! replacements recorded during the current pass: old var -> new var
    ThinHashMap<VarNode*, VarNode*> m_var_map;
    VarNode* m_dest_var;
    //! scratch buffer reused by auto_replace_outputs()
    VarNodeArray m_new_inp;

    //! replacement of \p var, or \p var itself when none is recorded
    VarNode* get_var(VarNode* var) {
        auto&& iter = m_var_map.find(var);
        if (iter != m_var_map.end()) {
            return iter->second;
        }
        return var;
    }

public:
    //! explicit: a VarNode* must not silently convert into a rewriter
    explicit InternalGraphRewriter(VarNode* dest_var) : m_dest_var{dest_var} {}

    /*!
     * \brief Run one rewrite pass.
     *
     * Clears the replacement map, calls \p cb through cg::DepOprIter starting
     * at the owner opr of the destination var, then replaces the destination
     * var by its recorded substitute.
     */
    void iter(thin_function<void(cg::OperatorNodeBase*)>&& cb) {
        m_var_map.clear();
        cg::DepOprIter{std::move(cb)}.add(m_dest_var->owner_opr());
        m_dest_var = get_var(m_dest_var);
    }

    //! current destination var (result of the last pass)
    VarNode* dest_var() const { return m_dest_var; }

    //! record that \p src is replaced by \p dst; lookups are not chained
    void replace_var(VarNode* src, VarNode* dst) { m_var_map[src] = dst; }

    /*!
     * \brief Re-create \p opr on replaced inputs when any input was replaced.
     *
     * \p opr must have exactly one usable output. When at least one input has
     * a recorded replacement, a shallow copy of the opr on the new inputs is
     * made and its output(0) is recorded as replacement of opr->output(0).
     */
    void auto_replace_outputs(cg::OperatorNodeBase* opr) {
        mgb_assert(opr->usable_output().size() == 1);
        m_new_inp.clear();
        bool need_replace = false;
        for (auto&& i : opr->input()) {
            auto inp = get_var(i);
            m_new_inp.push_back(inp);
            need_replace |= (inp != i);
        }
        if (need_replace) {
            auto new_op = serialization::copy_opr_shallow(*opr, m_new_inp);
            replace_var(opr->output(0), new_op->output(0));
        }
    }
};
}  // anonymous namespace

/*!
 * Gradient of a JITExecutor with respect to one of its inputs.
 *
 * NOTE(review): \c opr, \c wrt_idx and \c out_grad are not declared in this
 * block; presumably they are provided by MGB_IMPL_OPR_GRAD -- confirm against
 * the macro definition.
 *
 * The gradient expression is first built symbolically inside the forward
 * internal graph via cg::grad(). Trivial results are returned directly;
 * otherwise the expression is either expanded into the original graph or
 * wrapped in a new JITExecutor.
 */
MGB_IMPL_OPR_GRAD(JITExecutor) {
    // grad_inputs layout: [forward inputs..., forward output, output grad];
    // the last two indices match the input ids given to output_ph and og_ph
    VarNodeArray grad_inputs;
    for (auto input : opr.input())
        grad_inputs.push_back(input);
    mgb_assert(out_grad[0]);
    grad_inputs.push_back(opr.output(0));
    grad_inputs.push_back(out_grad[0]);
    auto fwd_igraph_ptr = opr.internal_graph_ptr();
    // placeholder standing for the already computed forward output
    auto output_ph = JITPlaceholder::make(
            fwd_igraph_ptr->output(), fwd_igraph_ptr->placeholders().size());
    // placeholder standing for the incoming output gradient
    auto og_ph = JITPlaceholder::make(
            out_grad[0], fwd_igraph_ptr->placeholders().size() + 1);
    // differentiate the internal graph output w.r.t. the wrt_idx-th placeholder
    auto loss = opr::VirtualLoss::make({fwd_igraph_ptr->output()}, {og_ph});
    auto gx = cg::grad(
            loss, fwd_igraph_ptr->placeholders()[wrt_idx]->output(0), false, false);
    if (!gx.node()) {
        return nullptr;
    }
    if (gx.node()->owner_opr()->same_type<opr::InvalidGrad>()) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    // early returns: the gradient is a single existing node (a placeholder,
    // the output grad, the forward output, or an immutable tensor)
    for (size_t i = 0; i < fwd_igraph_ptr->placeholders().size(); ++i) {
        if (gx.node() == fwd_igraph_ptr->placeholders()[i]->output(0)) {
            return grad_inputs[i];
        }
    }
    if (gx.node() == og_ph.node()) {
        return out_grad[0];
    }
    if (gx.node() == fwd_igraph_ptr->output()) {
        return opr.output(0);
    }
    if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(gx.node()->owner_opr())) {
        // re-create the constant on the comp node of the first input
        HostTensorND hval{grad_inputs[0]->comp_node()};
        hval.copy_from(imm->value()).sync();
        return opr::ImmutableTensor::make(*imm->owner_graph(), hval).node();
    }
    // pass 1: wherever the grad expression uses the forward graph output,
    // substitute output_ph for it
    InternalGraphRewriter rewriter{gx.node()};
    rewriter.iter([&rewriter, &fwd_igraph_ptr, &output_ph](cg::OperatorNodeBase* opr) {
        if (opr == fwd_igraph_ptr->output()->owner_opr()) {
            rewriter.replace_var(opr->output(0), output_ph.node());
            return;
        }
        rewriter.auto_replace_outputs(opr);
    });
    // callback that replaces every placeholder by the var it stands for,
    // re-creates immutable tensors on the comp node of grad_inputs[0], and
    // shallow-copies all other oprs whose inputs changed
    auto expand_into_origin_graph = [&rewriter](
                                            cg::OperatorNodeBase* opr,
                                            const VarNodeArray& grad_inputs) {
        if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
            rewriter.replace_var(opr->output(0), grad_inputs.at(ph->input_id()));
            return;
        }
        if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) {
            HostTensorND hval{grad_inputs[0]->comp_node()};
            hval.copy_from(imm->value()).sync();
            rewriter.replace_var(
                    opr->output(0),
                    opr::ImmutableTensor::make(*opr->owner_graph(), hval).node());
            return;
        }
        rewriter.auto_replace_outputs(opr);
    };
    if (opr.compiler()->property().feature_bits & JITFeatureBits::REDUCE) {
        // REDUCE bit set on the compiler property: do not build a grad
        // JITExecutor, expand the grad expression into the original graph
        using namespace std::placeholders;
        rewriter.iter(std::bind(expand_into_origin_graph, _1, std::cref(grad_inputs)));
        return rewriter.dest_var();
    } else {
        // pass 2: collect only the placeholders the grad expression actually
        // uses, renumbering them densely in visiting order
        VarNodeArray new_grad_inputs;
        PlaceholderArray placeholders;
        bool all_inp_const = true;
        rewriter.iter([&rewriter, &grad_inputs, &new_grad_inputs, &placeholders,
                       &all_inp_const](cg::OperatorNodeBase* opr) {
            if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
                new_grad_inputs.push_back(grad_inputs[ph->input_id()]);
                auto new_ph = JITPlaceholder::make(
                                      new_grad_inputs.back(), placeholders.size())
                                      .node()
                                      ->owner_opr();
                placeholders.push_back(new_ph->try_cast_final<JITPlaceholder>());
                mgb_assert(placeholders.back());
                rewriter.replace_var(opr->output(0), new_ph->output(0));
                if (!cg::is_const_var_value(new_grad_inputs.back())) {
                    all_inp_const = false;
                }
                return;
            }
            rewriter.auto_replace_outputs(opr);
        });
        if (all_inp_const) {
            // every used input has a constant value: expand into the original
            // graph instead of creating a JITExecutor
            using namespace std::placeholders;
            rewriter.iter(std::bind(
                    expand_into_origin_graph, _1, std::cref(new_grad_inputs)));
            return rewriter.dest_var();
        }
        gx = rewriter.dest_var();
        auto shape_infer = fwd_igraph_ptr->shape_infer();
        if (opr.has_dimshuffle()) {
            auto&& iter = opr.dimshuffle_params().find(
                    fwd_igraph_ptr->placeholders()[wrt_idx]);
            if (iter != opr.dimshuffle_params().end()) {
                // the wrt input was dimshuffled in the forward graph: build
                // the inverse pattern (back[j] = i where pattern[i] == j) and
                // apply it to the shape-infer var
                auto&& pattern = iter->second.first;
                auto&& ndim = iter->second.second;
                std::vector<int> back(ndim, -1);
                for (size_t i = 0; i < pattern.size(); i++) {
                    auto j = pattern[i];
                    if (j >= 0) {
                        mgb_assert(
                                back[j] == -1,
                                "taking grad for Dimshuffle with duplicated "
                                "input axis unsupported");
                        back[j] = i;
                    }
                }
                shape_infer =
                        opr::Dimshuffle::make(shape_infer, back, pattern.size()).node();
            }
        }
        // wrap the rewritten grad expression in its own JITExecutor
        auto grad_ig = std::make_shared<InternalGraph>(
                gx.node(), shape_infer, nullptr, std::move(placeholders));
        auto grad_jit = JITExecutor::make(grad_ig, new_grad_inputs);
        if (opr.input_broadcastable()[wrt_idx]) {
            // the input may have been broadcast in the forward pass: sum the
            // gradient back to the shape of that input
            grad_jit = opr::reduce_sum(
                    grad_jit, opr::GetVarShape::make(opr.input(wrt_idx)));
        }
        return grad_jit.node();
    }
}
#endif
#endif