#include "megbrain/opr/internal/indexing_helper_sereg.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/sereg.h"
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/serialization/internal/mgb_cpp_opr_generated.h"
#endif
// Serialization registration for the subtensor operator family, using the
// helper macros from indexing_helper_sereg.h: the GET_ variant is used for
// Subtensor and the MODIFY_ variant for SetSubtensor / IncrSubtensor.
MGB_SEREG_GET_SUBTENSOR_OPR(Subtensor);
MGB_SEREG_MODIFY_SUBTENSOR_OPR(SetSubtensor);
MGB_SEREG_MODIFY_SUBTENSOR_OPR(IncrSubtensor);
namespace mgb {
namespace serialization {
//! Maker used when loading a Padding operator: rebuilds it from its param and
//! its single input var. Any other number of inputs yields nullptr.
template <>
struct OprMaker<opr::Padding, 1> {
    using Opr = opr::Padding;
    using Param = Opr::Param;

    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        // guard clause: only the single-input form is supported
        if (inputs.size() != 1) {
            return nullptr;
        }
        auto padded = Opr::make(inputs[0], param, config);
        return padded.node()->owner_opr();
    }
};
//! Maker used when loading a PaddingBackward operator: rebuilds it from its
//! param and exactly two input vars (forwarded in order to Opr::make).
//! Any other number of inputs yields nullptr.
template <>
struct OprMaker<opr::PaddingBackward, 2> {
    using Opr = opr::PaddingBackward;
    using Param = Opr::Param;

    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return inputs.size() == 2
                     ? Opr::make(inputs[0], inputs[1], param, config)
                               .node()
                               ->owner_opr()
                     : nullptr;
    }
};
//! Concat and GetVarShape use arity slot 0 and delegate their makers to the
//! generic OprMakerVariadic implementation (struct inheritance is public by
//! default).
template <>
struct OprMaker<opr::Concat, 0> : OprMakerVariadic<opr::Concat> {};

template <>
struct OprMaker<opr::GetVarShape, 0> : OprMakerVariadic<opr::GetVarShape> {};
//! Custom dump/load for Split.
//!
//! Only Split with Method::SPECIFY can be serialized. dump() writes just the
//! split axis (as a megdnn::param::Axis). The partition vars are the
//! operator's extra inputs, so load() rebuilds them from inputs[1..] and
//! takes the tensor to split from inputs[0].
template <>
struct OprLoadDumpImpl<opr::Split, 0> {
    using Split = opr::Split;
    using Options = Split::Options;
    using Method = Options::Method;

    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Split>();
        auto&& opt = opr.options();
        mgb_assert(
                opt.method == Method::SPECIFY,
                "only Split with SPECIFY output shapes can be serialized");
        ctx.write_param<megdnn::param::Axis>(opt.axis);
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto param = ctx.read_param<megdnn::param::Axis>();
        // inputs[0] is the source var; every further input is one partition
        mgb_assert(
                inputs.size() > 1,
                "Split load expects the source var plus at least one partition "
                "var, got %zu inputs",
                inputs.size());
        Options opt;
        opt.method = Method::SPECIFY;
        opt.axis = param.axis;
        opt.nr_part = inputs.size() - 1;
        opt.partition.resize(opt.nr_part);
        for (size_t i = 1; i < inputs.size(); ++i)
            opt.partition[i - 1] = inputs[i];
        return Split::make(inputs[0], opt, config)[0].node()->owner_opr();
    }
};
#if MGB_ENABLE_FBS_SERIALIZATION
namespace fbs {
//! FlatBuffers <-> opr::Dimshuffle::Param conversion. The fixed-size
//! `pattern` array is stored as a vector whose length becomes `pattern_len`;
//! `ndim` is stored as a scalar field.
template <>
struct ParamConverter<opr::Dimshuffle::Param> {
    using FlatBufferType = param::Dimshuffle;

    static opr::Dimshuffle::Param to_param(const FlatBufferType* fb) {
        opr::Dimshuffle::Param param;
        param.ndim = fb->ndim();
        if (fb->pattern()) {
            param.pattern_len = fb->pattern()->size();
            // serialized data may be malformed: never copy past the
            // fixed-capacity pattern array
            mgb_assert(
                    param.pattern_len <=
                            sizeof(param.pattern) / sizeof(param.pattern[0]),
                    "Dimshuffle pattern too long in serialized param: %zu",
                    static_cast<size_t>(param.pattern_len));
            memcpy(param.pattern, fb->pattern()->data(),
                   sizeof(param.pattern[0]) * param.pattern_len);
        } else {
            param.pattern_len = 0;
        }
        return param;
    }

    static flatbuffers::Offset<FlatBufferType> to_flatbuffer(
            flatbuffers::FlatBufferBuilder& builder, const opr::Dimshuffle::Param& p) {
        return param::CreateDimshuffle(
                builder, builder.CreateVector(p.pattern, p.pattern_len), p.ndim);
    }
};

//! FlatBuffers <-> opr::AxisAddRemove::Param conversion. Each entry of the
//! fixed-size `desc` array is stored as a (method, axis) pair; `nr_desc` is
//! the length of the serialized vector.
template <>
struct ParamConverter<opr::AxisAddRemove::Param> {
    using FlatBufferType = param::AxisAddRemove;

    static opr::AxisAddRemove::Param to_param(const FlatBufferType* fb) {
        opr::AxisAddRemove::Param param;
        const auto* fb_desc = fb->desc();
        if (fb_desc) {
            param.nr_desc = fb_desc->size();
            // bound the element count by the capacity of the fixed-size desc
            // array (same check as the Dimshuffle converter applies to its
            // pattern); without it a malformed model overflows param.desc
            mgb_assert(
                    param.nr_desc <= sizeof(param.desc) / sizeof(param.desc[0]),
                    "too many AxisAddRemove descs in serialized param: %zu",
                    static_cast<size_t>(param.nr_desc));
            for (uint32_t i = 0; i < param.nr_desc; i++) {
                const auto* item = fb_desc->Get(i);
                param.desc[i].axis = item->axis();
                param.desc[i].method =
                        static_cast<opr::AxisAddRemove::AxisDesc::Method>(
                                item->method());
            }
        } else {
            param.nr_desc = 0;
        }
        return param;
    }

    static flatbuffers::Offset<FlatBufferType> to_flatbuffer(
            flatbuffers::FlatBufferBuilder& builder,
            const opr::AxisAddRemove::Param& p) {
        std::vector<param::AxisDesc> desc(p.nr_desc);
        for (uint32_t i = 0; i < p.nr_desc; i++) {
            desc[i] = {
                    static_cast<param::AxisDescMethod>(p.desc[i].method),
                    p.desc[i].axis.get_raw()};
        }
        return param::CreateAxisAddRemoveDirect(builder, &desc);
    }
};
}  // namespace fbs
// a preprocessing directive must start its own line; `} #endif` on one line
// is not recognized and leaves the #if above unterminated
#endif  // MGB_ENABLE_FBS_SERIALIZATION
}
namespace opr {
// Serialization registrations for tensor-manipulation operators. The second
// macro argument is presumably the input arity; where this file specializes
// OprMaker for an operator (e.g. Concat), it matches that specialization's
// second template argument.
MGB_SEREG_OPR(Broadcast, 2);
MGB_SEREG_OPR(Dimshuffle, 1);
MGB_SEREG_OPR(AxisAddRemove, 1);
MGB_SEREG_OPR(Concat, 0);
// NOTE(review): the *V1 aliases presumably make the registry record a
// versioned operator name; confirm against the MGB_SEREG_OPR definition.
using GetVarShapeV1 = opr::GetVarShape;
MGB_SEREG_OPR(GetVarShapeV1, 0);
using ReshapeV1 = opr::Reshape;
MGB_SEREG_OPR(ReshapeV1, 2);
//! Shallow-copy a Split operator onto new inputs.
//!
//! A CALL_BACK split takes only the source var. A SPECIFY split takes the
//! source var followed by one var per partition; those vars replace the
//! partition in a copy of the original options. Returns the owner operator
//! of the first output of the rebuilt Split.
cg::OperatorNodeBase* opr_shallow_copy_split(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    using Meth = Split::Options::Method;
    // work on a copy so the partition vars can be rebound to the new inputs
    auto option = opr_.cast_final_safe<Split>().options();
    if (option.method == Meth::CALL_BACK) {
        mgb_assert(inputs.size() == 1);
    } else if (option.method == Meth::SPECIFY) {
        mgb_assert(inputs.size() == 1 + option.partition.size());
        size_t next_input = 1;
        for (auto&& part : option.partition)
            part = inputs[next_input++];
    }
    auto outputs = Split::make(inputs[0], option, config);
    return outputs.at(0).node()->owner_opr();
}
MGB_SEREG_OPR(Split, 0);
MGB_REG_OPR_SHALLOW_COPY(Split, opr_shallow_copy_split);
//! Shallow-copy a ParamPackSplit onto a new packed input (inputs[0]), reusing
//! the offsets and output shapes of the original operator. Returns the owner
//! operator of the first output.
cg::OperatorNodeBase* opr_shallow_copy_param_pack_split(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    const auto& src = opr_.cast_final_safe<ParamPackSplit>();
    auto outputs = ParamPackSplit::make(
            inputs[0], src.get_offsets(), src.get_output_shapes(), config);
    return outputs.at(0).node()->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(ParamPackSplit, opr_shallow_copy_param_pack_split);
//! Shallow-copy a ParamPackConcat onto new inputs, reusing the original
//! operator's offsets.
//!
//! All inputs except the last are passed as the vars to concatenate. The last
//! input is passed separately as the second argument of
//! ParamPackConcat::make (presumably the offsets var — confirm).
cg::OperatorNodeBase* opr_shallow_copy_param_pack_concat(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    // with no inputs, `inputs.size() - 1` wraps around and inputs.back() is UB
    mgb_assert(
            !inputs.empty(),
            "ParamPackConcat shallow copy requires at least one input");
    auto&& opr = opr_.cast_final_safe<ParamPackConcat>();
    auto&& offsets = opr.get_offsets();
    const size_t nr_params = inputs.size() - 1;
    // parentheses request `nr_params` elements; braces could select an
    // initializer-list constructor instead
    SymbolVarArray ivar(nr_params);
    for (size_t i = 0; i < nr_params; ++i)
        ivar[i] = inputs[i];
    return ParamPackConcat::make(ivar, inputs.back(), offsets, config)
            .node()
            ->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(ParamPackConcat, opr_shallow_copy_param_pack_concat);
// RelayoutFormat is registered under a V1 alias (presumably a versioned
// serialized name — confirm against MGB_SEREG_OPR). Padding and
// PaddingBackward are registered with the same numbers (1 and 2) as the
// second template argument of their OprMaker specializations in this file.
using RelayoutFormatV1 = opr::RelayoutFormat;
MGB_SEREG_OPR(RelayoutFormatV1, 1);
MGB_SEREG_OPR(Padding, 1);
MGB_SEREG_OPR(PaddingBackward, 2);
}
}