#include "megbrain/opr/indexing.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/utility.h"
#include "./internal/megdnn_opr_wrapper.inl"
using namespace mgb;
using namespace opr;
namespace {
/*!
 * \brief make sure the index var (second entry of \p inputs) has dtype Int32
 *
 * The list holds pointers to the caller's SymbolVars. If the index var has a
 * different dtype, a warning is logged and the pointed-to SymbolVar is
 * replaced by a TypeCvt of itself to Int32.
 */
void check_index_dtype(std::initializer_list<SymbolVar*>& inputs) {
    mgb_assert(inputs.size() >= 2);
    // entry 1 is the index var
    SymbolVar* index_ptr = *(inputs.begin() + 1);
    if (index_ptr->dtype() == dtype::Int32()) {
        return;
    }
    mgb_log_warn(
            "dtype of index in IndexingOneHot must be Int32, "
            "got %s for variable %s; convert to Int32 implicitly",
            index_ptr->dtype().name(), index_ptr->node()->cname());
    *index_ptr = opr::TypeCvt::make(*index_ptr, dtype::Int32());
}
//! kind of modification performed by an indexing-modify operator; consumed by
//! IndexingModifyMultiAxisVecHelper::scn_do_execute when no vector indexer is
//! present (SET copies the value, INCR runs an AddUpdate)
enum IndexingModifyType { SET, INCR };

//! compile-time map from a megdnn modify operator to its IndexingModifyType;
//! the primary template intentionally has no `value` member
template <typename Opr>
struct IndexingModifyTypeGetter {};

#define DEF_MODIFY_TYPE(dnn_opr, kind)                                       \
    template <>                                                              \
    struct IndexingModifyTypeGetter<megdnn::dnn_opr> {                       \
        static constexpr IndexingModifyType value = IndexingModifyType::kind; \
    };
// operators registered as SET
DEF_MODIFY_TYPE(IndexingSetMultiAxisVec, SET)
DEF_MODIFY_TYPE(SetMeshIndexing, SET)
DEF_MODIFY_TYPE(BatchedSetMeshIndexing, SET)
// operators registered as INCR
DEF_MODIFY_TYPE(IndexingIncrMultiAxisVec, INCR)
DEF_MODIFY_TYPE(IncrMeshIndexing, INCR)
DEF_MODIFY_TYPE(BatchedIncrMeshIndexing, INCR)
#undef DEF_MODIFY_TYPE
}
namespace mgb {
namespace opr {
namespace intl {

//! input modifier for IndexingOneHot: forwards the inputs to
//! check_index_dtype(), which converts a non-Int32 index var in place
template <>
struct MegDNNOprInitInputsModifier<IndexingOneHot> {
    static void apply(
            const IndexingOneHot::Param& param,
            std::initializer_list<SymbolVar*> inputs) {
        check_index_dtype(inputs);
        MGB_MARK_USED_VAR(param);
    }
};

//! IndexingSetOneHot inherits the IndexingOneHot modifier unchanged
template <>
struct MegDNNOprInitInputsModifier<IndexingSetOneHot>
        : MegDNNOprInitInputsModifier<IndexingOneHot> {};

}  // namespace intl
}  // namespace opr
}  // namespace mgb
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Diag);
MEGDNN_OPR_INIT1(Diag, "diag")

#if MGB_ENABLE_GRAD
//! gradient of Diag: for input 0 a DiagBackward is built from the symbolic
//! shape of the input and the output gradient; other indices are invalid
MGB_IMPL_OPR_GRAD(Diag) {
    if (wrt_idx != 0) {
        return InvalidGrad::make(opr, wrt_idx);
    }
    SymbolVar src{opr.input(0)};
    auto grad = DiagBackward::make(src.symshape(), out_grad[0], opr.param());
    return grad.node();
}
#endif
MGB_DYN_TYPE_OBJ_FINAL_IMPL(DiagBackward);

/*!
 * \brief build a DiagBackward operator
 * \param shape var whose value gives the shape of the output
 * \param value var supplying the data (and dtype) of the output
 * \param param operator param, also used as an equivalence component
 */
DiagBackward::DiagBackward(
        VarNode* shape, VarNode* value, const Param& param,
        const OperatorNodeConfig& config)
        : Super{shape->owner_graph(), config, "diag_backward", {shape, value}},
          m_param{param} {
    add_input({shape, value});
    // the single output shares the dtype of the value input
    auto* out_var = add_output(None);
    out_var->dtype(value->dtype());
    add_equivalence_component<PODHash<Param>>(&m_param);
}
//! insert a DiagBackward operator into the graph of \p shape and return its
//! single output
SymbolVar DiagBackward::make(
        SymbolVar shape, SymbolVar value, const Param& param,
        const OperatorNodeConfig& config) {
    VarNode* shape_node = shape.node();
    VarNode* value_node = value.node();
    return shape.insert_single_output_opr<DiagBackward>(
            shape_node, value_node, param, config);
}
//! extend the default node property so that input 0 (the shape var) is
//! marked with a HOST_VALUE dependency
cg::OperatorNodeBase::NodeProp* DiagBackward::do_make_node_prop() const {
    auto node_prop = Super::do_make_node_prop();
    node_prop->add_dep_type(input(0), NodeProp::DepType::HOST_VALUE);
    return node_prop;
}
void DiagBackward::scn_do_execute() {
auto&& dest = output(0)->dev_tensor();
auto&& val = input(1)->dev_tensor();
auto&& layout = dest.layout();
mgb_assert(layout.ndim == 1 || layout.ndim == 2);
if (layout.ndim == 2) {
dev_tensor_memset(dest, 0);
size_t offset = (m_param.k >= 0) ? (m_param.k * layout.stride[1])
: (-m_param.k * layout.stride[0]);
auto dest_sub = dest.sub(SubTensorSpec::make_from_offset_elem(
{val.shape(), {layout.stride[0] + layout.stride[1]}, val.dtype()},
offset));
dest_sub.copy_from_fixlayout(val);
} else {
auto&& opr = m_dnn_opr;
if (!opr) {
opr = intl::create_megdnn_opr<megdnn::Diag>(comp_node());
opr->param() = m_param;
}
opr->exec(val.as_megdnn(), dest.as_megdnn(), {});
}
}
//! move the cached megdnn operator into \p deps wrapped in a MegDNNGraphDep
void DiagBackward::record_execute_deps(ExecDependencyArray& deps) {
    auto dnn_dep = std::make_unique<intl::MegDNNGraphDep>(std::move(m_dnn_opr));
    deps.emplace_back(std::move(dnn_dep));
}
//! register shape inference for output 0: its shape is read from the value
//! of input 0
void DiagBackward::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto shape_from_value = [](TensorShape& dest, const InpVal& inp) {
        cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value());
        return true;
    };
    owner_graph()->static_infer_manager().register_shape_infer(
            output(0),
            {SourceType::DEP, {{input(0), DepType::VALUE}}, shape_from_value});
}
#if MGB_ENABLE_GRAD
// DiagBackward defines no gradient: InvalidGrad is returned for every input
// index.
MGB_IMPL_OPR_GRAD(DiagBackward) {
return InvalidGrad::make(opr, wrt_idx);
}
#endif
MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingOneHot);
MEGDNN_OPR_INIT2(IndexingOneHot, "indexing_one_hot")

//! output 0 takes the dtype of input 0
void IndexingOneHot::init_output_dtype() {
    const auto src_dtype = input(0)->dtype();
    output(0)->dtype(src_dtype);
}
#if MGB_ENABLE_GRAD
//! gradient of IndexingOneHot: for input 0 the output gradient is written by
//! IndexingSetOneHot into a zero tensor shaped like input 0, using the same
//! index; other inputs have no gradient
MGB_IMPL_OPR_GRAD(IndexingOneHot) {
    if (wrt_idx != 0) {
        return InvalidGrad::make(opr, wrt_idx);
    }
    auto zeros = SymbolVar{opr.input(0)}.fill_retain_dtype(0);
    auto grad =
            IndexingSetOneHot::make(zeros, opr.input(1), out_grad[0], opr.param());
    return grad.node();
}
#endif
MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingSetOneHot);
MEGDNN_OPR_INIT3(IndexingSetOneHot, "indexing_set_one_hot")

//! output 0 takes the dtype of input 0
void IndexingSetOneHot::init_output_dtype() {
    const auto data_dtype = input(0)->dtype();
    output(0)->dtype(data_dtype);
}
// Delegates to the shared megdnn helper that installs the "contig" layout
// constraint for this operator (scn_do_execute asserts that the output layout
// is contiguous).
void IndexingSetOneHot::add_input_layout_constraint() {
mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}
// Requests writable forwarding of input 0 to output 0 through the graph
// helper; scn_do_execute handles both outcomes by comparing the raw pointers
// of input 0 and output 0.
void IndexingSetOneHot::mem_plan_fwd_in2out_writable() {
cg::request_fwd_in2out_writable_if_no_mem_ovelap(this, 0, 0);
}
//! output 0 has the same shape as input 0; the workspace inference is set up
//! by the base-class helper
void IndexingSetOneHot::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto same_as_data = ShapeInferDesc::make_identity(input(0));
    owner_graph()->static_infer_manager().register_shape_infer(
            output(0), same_as_data);
    init_output_static_infer_desc_workspace(false);
}
/*!
 * \brief run the megdnn operator in place on output 0
 *
 * If output 0 does not share storage with input 0, the input is copied into
 * it first; otherwise both layouts are asserted to be equal. The megdnn
 * operator then modifies the (contiguous) output using the index and input 2.
 */
void IndexingSetOneHot::scn_do_execute() {
    auto&& src = input(0)->dev_tensor();
    auto&& idx = input(1)->dev_tensor();
    auto&& dst = output(0)->dev_tensor();

    const bool separate_storage = src.raw_ptr() != dst.raw_ptr();
    if (separate_storage) {
        dst.copy_from_fixlayout(src);
    } else {
        mgb_assert(dst.layout().eq_layout(src.layout()));
    }
    mgb_assert(dst.layout().is_contiguous());

    auto workspace = intl::get_megdnn_workspace_from_var(output(1));
    megdnn_opr()->exec(
            dst.as_megdnn(), idx.as_megdnn(), input(2)->dev_tensor().as_megdnn(),
            workspace);
}
#if MGB_ENABLE_GRAD
//! gradient of IndexingSetOneHot:
//!  - input 0: the output gradient with zeros (shaped like input 2) set at
//!    the indexed positions
//!  - input 2: the output gradient read back through IndexingOneHot
//!  - anything else: invalid
MGB_IMPL_OPR_GRAD(IndexingSetOneHot) {
    SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)};
    switch (wrt_idx) {
        case 0:
            return IndexingSetOneHot::make(
                           og, index, sub.fill_retain_dtype(0), opr.param())
                    .node();
        case 2:
            return IndexingOneHot::make(og, index, opr.param()).node();
        default:
            return InvalidGrad::make(opr, wrt_idx);
    }
}
#endif
size_t IndexingSetOneHot::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
return megdnn_opr()->get_workspace_in_bytes(
{input_shapes[0], input(0)->dtype()}, {input_shapes[1], input(1)->dtype()},
{input_shapes[2], input(2)->dtype()});
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingRemap);
MEGDNN_OPR_INIT2(IndexingRemap, "indexing_remap")

//! require an Int32 map (input 1), then give output 0 the dtype of input 0
//! \throws GraphError if the map input is not Int32
void IndexingRemap::init_output_dtype() {
    const bool bad_map_dtype = input(1)->dtype() != dtype::Int32();
    mgb_throw_if(
            bad_map_dtype, GraphError,
            "IndexingRemap requires map input to be int32");
    const auto data_dtype = input(0)->dtype();
    output(0)->dtype(data_dtype);
}
#if MGB_ENABLE_GRAD
//! gradient of IndexingRemap: the map (input 1) has no gradient; for input 0
//! an IndexingRemapBackward is built from the output gradient, the map and
//! the forward input
MGB_IMPL_OPR_GRAD(IndexingRemap) {
    if (wrt_idx == 1) {
        return InvalidGrad::make(opr, wrt_idx);
    }
    mgb_assert(wrt_idx == 0 && out_grad[0]);
    auto grad = IndexingRemapBackward::make(
            out_grad[0], opr.input(1), opr.input(0), opr.param());
    return grad.node();
}
#endif
// Dynamic-type registration and macro-generated three-input implementation
// for IndexingRemapBackward (its make() is used by the IndexingRemap
// gradient).
// NOTE(review): the meaning of the trailing `2, false` arguments is defined by
// MEGDNN_OPR_INIT3 in megdnn_opr_wrapper.inl -- confirm there.
MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingRemapBackward);
MEGDNN_OPR_INIT3(IndexingRemapBackward, "indexing_remap_bwd", 2, false);
//! return the cached megdnn operator, (re)creating it when none exists yet
//! or when the cached one belongs to a different comp node than \p self
template <class Opr>
Opr& mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::megdnn_opr(
        cg::SingleCNOperatorNodeBase& self) {
    const auto cn = self.comp_node();
    const bool need_create = !m_dnn_opr || m_dnn_opr.comp_node() != cn;
    if (need_create) {
        m_dnn_opr = intl::create_megdnn_opr<Opr>(cn);
        // report megdnn errors against the owning graph operator
        auto* tracker = static_cast<cg::OperatorNodeBase*>(&self);
        m_dnn_opr->set_error_tracker(tracker);
    }
    return *m_dnn_opr;
}
/*!
 * \brief register the shape inference of the workspace var, opr.output(1)
 *
 * The inference depends on the shapes of \p data, \p value and every var in
 * \p idx_arr (in that order). The workspace is {0} when there is no vector
 * indexer; otherwise it is the size reported by the megdnn operator for the
 * shape of \p value, the indexed axes and the largest index ndim.
 *
 * \p index_desc and \p opr are captured by reference by the registered
 * callback.
 */
template <class Opr>
void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::register_workspace_infer(
        const indexing::IndexDesc& index_desc, cg::SingleCNOperatorNodeBase& opr,
        VarNode* data, VarNode* value, VarNodeArray idx_arr) {
    using namespace cg::static_infer;

    DepVal deps = {{data, DepType::SHAPE}, {value, DepType::SHAPE}};
    for (VarNode* idx_var : idx_arr) {
        deps.push_back({idx_var, DepType::SHAPE});
    }

    const size_t nr_idx = idx_arr.size();
    auto infer_workspace = [this, &index_desc, &opr, nr_idx](
                                   TensorShape& dest, const InpVal& inp) {
        // collect the axes that carry a vector indexer, in reversed order
        size_t axes[TensorShape::MAX_NDIM], nr_axes = 0;
        const auto data_ndim = inp.val[0].shape().ndim;
        for (auto&& item : reverse_adaptor(index_desc)) {
            if (item.idx.node()) {
                axes[nr_axes++] = item.axis.get(data_ndim);
            }
        }
        mgb_assert(nr_axes == nr_idx);

        if (nr_axes == 0) {
            dest = {0};
            return true;
        }

        // index shapes start at position 2 of the dependency list
        size_t idx_ndim = 0;
        for (size_t k = 0; k < nr_idx; ++k) {
            idx_ndim = std::max(idx_ndim, inp.val[2 + k].shape().ndim);
        }
        mgb_assert(idx_ndim > 0);
        dest = {megdnn_opr(opr).get_workspace_in_bytes(
                inp.val[1].shape(), axes, nr_axes, idx_ndim)};
        return true;
    };

    opr.owner_graph()->static_infer_manager().register_shape_infer(
            opr.output(1), {SourceType::DEP, deps, infer_workspace});
}
//! move the cached megdnn operator into \p deps wrapped in a MegDNNGraphDep
template <class Opr>
void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::record_megdnn_opr(
        mgb::cg::GraphExecutable::ExecDependencyArray& deps) {
    auto dnn_dep = std::make_unique<intl::MegDNNGraphDep>(std::move(m_dnn_opr));
    deps.emplace_back(std::move(dnn_dep));
}
/*!
 * \brief rebuild the cached megdnn index descriptor from the device tensors
 *      of the vector indexers
 * \param inp_ndim ndim used to resolve each indexer's axis
 * \param warn_all_scalar whether to warn (once per operator) when every
 *      indexer is a scalar
 * \return reference to the cached descriptor, and whether any index tensor
 *      has an empty layout
 */
std::pair<const megdnn::IndexingMultiAxisVec::IndexDesc&, bool> intl::
        MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc(
                size_t inp_ndim, bool warn_all_scalar) {
    auto&& desc = m_megdnn_index_cache;
    desc.clear();
    bool has_empty_index = false;
    for (auto indexer : reverse_adaptor(m_input2idxonly_axis_indexer)) {
        if (!indexer) {
            continue;
        }
        desc.push_back(
                {indexer->axis.get(inp_ndim),
                 indexer->idx.node()->dev_tensor().as_megdnn()});
        has_empty_index |= desc.back().vec.layout.is_empty();
    }

    const bool should_check = !m_scalar_idx_warn_printed && warn_all_scalar &&
                              !this->owner_graph()->options().imperative_proxy_graph;
    if (should_check) {
        bool all_scalar = true;
        for (auto&& entry : desc) {
            if (!entry.vec.layout.is_scalar()) {
                all_scalar = false;
                break;
            }
        }
        if (all_scalar) {
#if MGB_ENABLE_GETENV
            mgb_log_warn(
                    "%s{%s}: no vector indexer; consider using Subtensor "
                    "family for better performance; you can set "
                    "MGB_THROW_ON_SCALAR_IDX to throw an exception to help "
                    "tracking the related operator",
                    cname(), dyn_typeinfo()->name);
            mgb_throw_if(
                    MGB_GETENV("MGB_THROW_ON_SCALAR_IDX"), MegBrainError,
                    "vector-indexing operator used with all "
                    "scalar indices");
#else
            mgb_log_warn(
                    "%s{%s}: no vector indexer; consider using Subtensor "
                    "family for better performance",
                    cname(), dyn_typeinfo()->name);
#endif
        }
        // the check runs at most once, whether or not it warned
        m_scalar_idx_warn_printed = true;
    }
    return {desc, has_empty_index};
}
//! extend the default node property: input 0 and every vector indexer var
//! get the VALUE_ALLOW_EMPTY dependency type
template <class Opr>
cg::OperatorNodeBase::NodeProp* IndexingMultiAxisVecBase<Opr>::do_make_node_prop()
        const {
    using DT = NodeProp::DepType;
    auto prop = Super::do_make_node_prop();
    prop->add_dep_type_existing_var(input(0), DT::VALUE_ALLOW_EMPTY);
    for (auto indexer : m_input2idxonly_axis_indexer) {
        if (!indexer) {
            continue;
        }
        prop->add_dep_type_existing_var(
                indexer->idx.node(), DT::VALUE_ALLOW_EMPTY);
    }
    return prop;
}
// Registers static shape inference for output(0) and for the workspace var.
//
// Dependency list layout built below:
//   [0]                        shape of input(0)
//   [1, inp_interval_start)    shapes of the vector-indexer inputs, visited
//                              from the last input down to input 1
//   [inp_interval_start, ...)  values of the remaining (non-vector) inputs,
//                              in ascending input order
template <class Opr>
void IndexingMultiAxisVecBase<Opr>::init_output_static_infer_desc() {
using namespace cg::static_infer;
DepVal deps;
deps.push_back({input(0), DepType::SHAPE});
// reverse traversal; the loop stops before i == 0, so input(0) is skipped
for (size_t i = m_input2idxonly_axis_indexer.size() - 1; i; --i) {
if (m_input2idxonly_axis_indexer[i]) {
deps.push_back({input(i), DepType::SHAPE});
}
}
// first position of the value dependencies consumed by
// fancy_indexing_make_sub_spec
size_t inp_interval_start = deps.size();
for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) {
if (!m_input2idxonly_axis_indexer[i]) {
deps.push_back({input(i), DepType::VALUE});
}
}
auto infer_shape = [this, inp_interval_start](
TensorShape& dest, const InpVal& inp) {
auto&& ishp = inp.val[0].shape();
// start from the layout of the sub-spec computed from the input shape and
// the value dependencies
auto subspec = fancy_indexing_make_sub_spec(
{ishp, input(0)->dtype()}, inp, inp_interval_start);
dest = subspec.layout();
typename Opr::IndexDescLayoutOnly index_layout;
// walk the indexers in reverse so that indexer_pos follows the order in
// which their shapes were pushed into deps above
size_t indexer_pos = 1;
for (auto i : reverse_adaptor(m_input2idxonly_axis_indexer)) {
if (i) {
index_layout.push_back(
{i->axis.get(dest.ndim),
{inp.val.at(indexer_pos++).shape(), dtype::Int32()}});
}
}
// every indexer shape dependency must have been consumed exactly once
mgb_assert(indexer_pos == inp_interval_start);
if (!index_layout.empty()) {
// with at least one vector indexer the final layout is deduced by the
// megdnn operator
TensorLayout tmp;
Opr::deduce_layout({dest, input(0)->dtype()}, index_layout, tmp);
dest = tmp;
}
return true;
};
owner_graph()->static_infer_manager().register_shape_infer(
output(0), {SourceType::DEP, deps, infer_shape});
// vector-indexer inputs in ascending input order, for the workspace inference
VarNodeArray idx_arr;
for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) {
if (m_input2idxonly_axis_indexer[i]) {
idx_arr.push_back(input(i));
}
}
// here output(0) is passed as the "value" var of the workspace inference
this->register_workspace_infer(index_desc(), *this, input(0), output(0), idx_arr);
}
// Records the execution dependencies of this operator by delegating to
// record_megdnn_opr(), which moves the cached megdnn operator into deps.
template <class Opr>
void IndexingMultiAxisVecBase<Opr>::record_execute_deps(
mgb::cg::GraphExecutable::ExecDependencyArray& deps) {
this->record_megdnn_opr(deps);
}
namespace {

//! compile-time switch passed as `warn_all_scalar` to
//! make_megdnn_index_desc(); false unless enabled for the operator below
template <class Opr>
struct ShouldWarnOnScalarIndexer {
    static constexpr bool val = false;
};

#define ENABLE_SCALAR_IDX_WARN(dnn_opr)                 \
    template <>                                         \
    struct ShouldWarnOnScalarIndexer<megdnn::dnn_opr> { \
        static constexpr bool val = true;               \
    }
ENABLE_SCALAR_IDX_WARN(IndexingMultiAxisVec);
ENABLE_SCALAR_IDX_WARN(IndexingSetMultiAxisVec);
ENABLE_SCALAR_IDX_WARN(IndexingIncrMultiAxisVec);
#undef ENABLE_SCALAR_IDX_WARN

}  // anonymous namespace
/*!
 * \brief execute the indexing
 *
 * Nothing is done for an empty output. Otherwise the sub-tensor selected by
 * the non-vector indexers is taken from input 0; without vector indexers it
 * is copied to the output, with them the megdnn operator is run -- unless an
 * index tensor is empty, in which case the output is asserted to be empty.
 */
template <class Opr>
void IndexingMultiAxisVecBase<Opr>::scn_do_execute() {
    if (output(0)->layout().is_empty()) {
        return;
    }
    auto inp = input(0)->dev_tensor();
    inp = inp.sub(fancy_indexing_make_sub_spec(inp.layout()));
    auto&& index_desc = make_megdnn_index_desc(
            inp.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val);
    auto&& odev = output(0)->dev_tensor();

    if (index_desc.first.empty()) {
        odev.copy_from_fixlayout(inp);
        return;
    }
    if (index_desc.second) {
        mgb_assert(odev.empty());
        return;
    }
    this->megdnn_opr(*this).exec(
            inp.as_megdnn(), index_desc.first, odev.as_megdnn(),
            intl::get_megdnn_workspace_from_var(output(1)));
}
//! output 0 has the shape of input 0; the workspace inference is registered
//! with input 0 as data, input 1 as value and the vector-indexer inputs
template <class Opr>
void intl::IndexingModifyMultiAxisVecHelper<Opr>::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& infer_mgr = this->owner_graph()->static_infer_manager();
    infer_mgr.register_shape_infer(
            this->output(0), ShapeInferDesc::make_identity(this->input(0)));

    VarNodeArray idx_arr;
    const size_t nr_entry = m_input2idxonly_axis_indexer.size();
    for (size_t i = 1; i < nr_entry; ++i) {
        if (!m_input2idxonly_axis_indexer[i]) {
            continue;
        }
        idx_arr.push_back(input(i));
    }
    this->register_workspace_infer(index_desc(), *this, input(0), input(1), idx_arr);
}
/*!
 * \brief execute the modification
 *
 * When the destination or any index tensor is empty, the value is asserted
 * to be empty and nothing is done. With vector indexers the megdnn operator
 * is executed. Without them the operation is applied to the whole
 * destination: a copy for SET operators, an AddUpdate for INCR operators.
 */
template <class Opr>
void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() {
    auto inp = this->fancy_indexing_get_tensors_for_modify_in_scn_do_execute();
    auto index_desc = this->make_megdnn_index_desc(
            inp.first.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val);

    const bool nothing_to_do = inp.first.shape().is_empty() || index_desc.second;
    if (nothing_to_do) {
        mgb_assert(inp.second.shape().is_empty());
        return;
    }

    if (!index_desc.first.empty()) {
        this->megdnn_opr(*this).exec(
                inp.first.as_megdnn(), inp.second.as_megdnn(), index_desc.first,
                intl::get_megdnn_workspace_from_var(output(1)));
        return;
    }

    // no vector indexer: act on the whole destination
    using IMT = IndexingModifyType;
    static constexpr auto modify_type = IndexingModifyTypeGetter<Opr>::value;
    switch (modify_type) {
        case IMT::SET: {
            inp.first.copy_from_fixlayout(inp.second);
            break;
        }
        case IMT::INCR: {
            megdnn::AddUpdate* add_update =
                    intl::get_megdnn_global_opr<megdnn::AddUpdate>(comp_node());
            add_update->exec(inp.first.as_megdnn(), inp.second.as_megdnn());
            break;
        }
        default:
            mgb_throw(MegBrainError, "bad modify type");
    }
}
//! extend the default node property: input 1 and every vector indexer var
//! get the VALUE_ALLOW_EMPTY dependency type
template <class Opr>
cg::OperatorNodeBase::NodeProp* intl::IndexingModifyMultiAxisVecHelper<
        Opr>::do_make_node_prop() const {
    using DT = NodeProp::DepType;
    auto prop = Super::do_make_node_prop();
    prop->add_dep_type_existing_var(input(1), DT::VALUE_ALLOW_EMPTY);
    for (auto indexer : m_input2idxonly_axis_indexer) {
        if (!indexer) {
            continue;
        }
        prop->add_dep_type_existing_var(indexer->idx.node(), DT::VALUE_ALLOW_EMPTY);
    }
    return prop;
}
//! constrain input 1 to layouts that collapse to a single contiguous
//! dimension
template <class Opr>
void intl::IndexingModifyMultiAxisVecHelper<Opr>::add_input_layout_constraint() {
    auto collapses_to_1d = [](const TensorLayout& layout) {
        const auto collapsed = layout.collapse_contiguous();
        return collapsed.ndim == 1;
    };
    this->input(1)->add_layout_constraint(collapses_to_1d);
}
// Macro-generated implementations of the multi-axis-vec operators. The
// trailing statement passed to the macro flags output(0) with
// ALLOW_EMPTY_SHAPE; it is supplied for the get and set operators but not for
// the incr operator.
// NOTE(review): the meaning of the `false` argument is defined by the
// MGB_IMPL_FANCY_INDEXING_OPR_* macros, which are not part of this file --
// confirm there.
MGB_IMPL_FANCY_INDEXING_OPR_GET(IndexingMultiAxisVec, "indexing_multi_axis_vec", false,
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
IndexingSetMultiAxisVec, "indexing_set_multi_axis_vec", false,
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false);
//! on top of the base-class property, input 0 also gets the
//! VALUE_ALLOW_EMPTY dependency type
IndexingSetMultiAxisVec::NodeProp* IndexingSetMultiAxisVec::do_make_node_prop() const {
    using DT = NodeProp::DepType;
    auto node_prop = Super::do_make_node_prop();
    node_prop->add_dep_type_existing_var(input(0), DT::VALUE_ALLOW_EMPTY);
    return node_prop;
}
#if MGB_ENABLE_GRAD
//! gradient of IndexingMultiAxisVec: for input 0 the output gradient is
//! accumulated by IndexingIncrMultiAxisVec into zeros shaped like input 0;
//! other inputs have no gradient
MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
    if (wrt_idx != 0) {
        return InvalidGrad::make(opr, wrt_idx);
    }
    auto grad = IndexingIncrMultiAxisVec::make(
            SymbolVar{opr.input(0)}.fill_retain_dtype(0), out_grad.at(0),
            opr.index_desc());
    return grad.node();
}
#endif
#if MGB_ENABLE_GRAD
//! gradient of IndexingSetMultiAxisVec:
//!  - input 0: the output gradient with zeros (shaped like input 1) set at
//!    the indexed positions
//!  - input 1: the output gradient read at the indexed positions
//!  - inputs >= 2: invalid
MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) {
    switch (wrt_idx) {
        case 0:
            return IndexingSetMultiAxisVec::make(
                           out_grad.at(0),
                           SymbolVar{opr.input(1)}.fill_retain_dtype(0),
                           opr.index_desc())
                    .node();
        case 1:
            return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc())
                    .node();
        default:
            return InvalidGrad::make(opr, wrt_idx);
    }
}
#endif
#if MGB_ENABLE_GRAD
//! gradient of IndexingIncrMultiAxisVec:
//!  - input 0: the output gradient itself
//!  - input 1: the output gradient read at the indexed positions
//!  - inputs >= 2: invalid
MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) {
    switch (wrt_idx) {
        case 0:
            return out_grad.at(0);
        case 1:
            return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc())
                    .node();
        default:
            return InvalidGrad::make(opr, wrt_idx);
    }
}
#endif
// Macro-generated implementations of MeshIndexing and BatchedMeshIndexing;
// for both, the trailing statement flags output(0) with ALLOW_EMPTY_SHAPE.
// NOTE(review): the meaning of the `false` argument is defined by
// MGB_IMPL_FANCY_INDEXING_OPR_GET, which is not part of this file -- confirm
// there.
MGB_IMPL_FANCY_INDEXING_OPR_GET(MeshIndexing, "mesh_indexing", false,
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
MGB_IMPL_FANCY_INDEXING_OPR_GET(BatchedMeshIndexing, "batched_mesh_indexing", false,
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
#if MGB_ENABLE_GRAD
//! gradient of MeshIndexing: for input 0 the output gradient is accumulated
//! by IncrMeshIndexing into zeros shaped like input 0; other inputs have no
//! gradient
MGB_IMPL_OPR_GRAD(MeshIndexing) {
    if (wrt_idx == 0) {
        auto grad = IncrMeshIndexing::make(
                SymbolVar{opr.input(0)}.fill_retain_dtype(0), out_grad.at(0),
                opr.index_desc());
        return grad.node();
    }
    return InvalidGrad::make(opr, wrt_idx);
}
#endif
#if MGB_ENABLE_GRAD
//! gradient of BatchedMeshIndexing: for input 0 the output gradient is
//! accumulated by BatchedIncrMeshIndexing into zeros shaped like input 0;
//! other inputs have no gradient
MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) {
    if (wrt_idx == 0) {
        auto grad = BatchedIncrMeshIndexing::make(
                SymbolVar{opr.input(0)}.fill_retain_dtype(0), out_grad.at(0),
                opr.index_desc());
        return grad.node();
    }
    return InvalidGrad::make(opr, wrt_idx);
}
#endif
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing", false);

#if MGB_ENABLE_GRAD
/*!
 * Gradient of IncrMeshIndexing:
 *  - input 0: the output gradient itself
 *  - input 1: the output gradient read back through MeshIndexing
 *  - inputs >= 2: InvalidGrad
 *
 * The bound is `>= 2`, matching the gradients of IndexingIncrMultiAxisVec and
 * SetMeshIndexing. With the previous `> 2`, a request for input 2 fell through
 * and returned the value gradient. Inputs from index 2 onward are presumed to
 * be the indexer vars, as in those sibling operators.
 */
MGB_IMPL_OPR_GRAD(IncrMeshIndexing) {
    if (wrt_idx >= 2) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    if (wrt_idx == 0) {
        return out_grad.at(0);
    }
    return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
}
#endif
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
        BatchedIncrMeshIndexing, "batched_incr_mesh_indexing", false);

#if MGB_ENABLE_GRAD
/*!
 * Gradient of BatchedIncrMeshIndexing:
 *  - input 0: the output gradient itself
 *  - input 1: the output gradient read back through BatchedMeshIndexing
 *  - inputs >= 2: InvalidGrad
 *
 * The bound is `>= 2`, matching the gradients of IndexingIncrMultiAxisVec and
 * SetMeshIndexing. With the previous `> 2`, a request for input 2 fell through
 * and returned the value gradient. Inputs from index 2 onward are presumed to
 * be the indexer vars, as in those sibling operators.
 */
MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) {
    if (wrt_idx >= 2) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    if (wrt_idx == 0) {
        return out_grad.at(0);
    }
    return BatchedMeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
}
#endif
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false);

#if MGB_ENABLE_GRAD
//! gradient of SetMeshIndexing:
//!  - input 0: the output gradient with zeros (shaped like input 1) set at
//!    the indexed positions
//!  - input 1: the output gradient read back through MeshIndexing
//!  - inputs >= 2: invalid
MGB_IMPL_OPR_GRAD(SetMeshIndexing) {
    switch (wrt_idx) {
        case 0:
            return SetMeshIndexing::make(
                           out_grad.at(0),
                           SymbolVar{opr.input(1)}.fill_retain_dtype(0),
                           opr.index_desc())
                    .node();
        case 1:
            return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
        default:
            return opr::InvalidGrad::make(opr, wrt_idx);
    }
}
#endif
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
        BatchedSetMeshIndexing, "batched_set_mesh_indexing", false);

#if MGB_ENABLE_GRAD
/*!
 * Gradient of BatchedSetMeshIndexing:
 *  - input 0: the output gradient with zeros (shaped like input 1) set at the
 *    indexed positions
 *  - input 1: the output gradient read back through BatchedMeshIndexing
 *  - inputs >= 2: InvalidGrad
 *
 * The bound is `>= 2`, matching the gradient of the non-batched
 * SetMeshIndexing. With the previous `> 2`, a request for input 2 fell through
 * and returned the value gradient. Inputs from index 2 onward are presumed to
 * be the indexer vars, as in that sibling operator.
 */
MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) {
    if (wrt_idx >= 2) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    if (wrt_idx == 0) {
        return BatchedSetMeshIndexing::make(
                       out_grad.at(0), SymbolVar{opr.input(1)}.fill_retain_dtype(0),
                       opr.index_desc())
                .node();
    }
    return BatchedMeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
}
#endif