#include "megbrain/comp_node_env.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/sereg.h"
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/serialization/internal/mgb_cpp_opr_generated.h"
#endif
namespace mgb {
namespace serialization {
#if MGB_ENABLE_FBS_SERIALIZATION
namespace fbs {
/*!
 * \brief Converts opr::Host2DeviceCopy::Param to and from its flatbuffers
 *      table type param::Host2DeviceCopy.
 *
 * The fields enable_value_infer, dump_default_value and allow_cpu_mem_fwd are
 * copied between the two representations. to_param brace-initializes Param
 * with them in exactly that order, so it relies on Param's member order.
 */
template <>
struct ParamConverter<opr::Host2DeviceCopy::Param> {
    using FlatBufferType = param::Host2DeviceCopy;

    //! build the opr param from a deserialized flatbuffers table
    static opr::Host2DeviceCopy::Param to_param(const FlatBufferType* fb) {
        return {fb->enable_value_infer(), fb->dump_default_value(),
                fb->allow_cpu_mem_fwd()};
    }

    //! write \p p into \p builder and return the offset of the created table
    static flatbuffers::Offset<FlatBufferType> to_flatbuffer(
            flatbuffers::FlatBufferBuilder& builder,
            const opr::Host2DeviceCopy::Param& p) {
        return param::CreateHost2DeviceCopy(
                builder, p.enable_value_infer, p.dump_default_value,
                p.allow_cpu_mem_fwd);
    }
};
}  // namespace fbs
#endif  // MGB_ENABLE_FBS_SERIALIZATION
/*!
 * \brief Dump/load support for Host2DeviceCopy.
 *
 * Stream layout: the opr param, then the host tensor stored under the opr's
 * name. The tensor is written with VALUE_INPUT when param.dump_default_value
 * is set, and with META_INPUT otherwise.
 */
template <>
struct OprLoadDumpImpl<opr::Host2DeviceCopy, 0> {
    using Opr = opr::Host2DeviceCopy;

    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        using Meth = OprDumpContext::TensorWriteMethod;
        auto&& h2d = opr_.cast_final_safe<Opr>();
        auto&& prm = h2d.param();
        ctx.write_param(prm);

        Meth method = Meth::META_INPUT;
        if (prm.dump_default_value) {
            method = Meth::VALUE_INPUT;
        }
        ctx.dump_tensor(h2d.name(), *h2d.host_data(), method);
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        // read in the same order dump() wrote: param first, then the tensor
        auto prm = ctx.read_param<Opr::Param>();
        auto host_val = ctx.load_tensor();
        auto out = Opr::make(ctx.graph(), host_val, prm, config);
        return out.node()->owner_opr();
    }
};
/*!
 * \brief Generic dump/load for oprs holding one shared device tensor.
 *
 * dump() copies the device value to host, waits for the copy via sync(), and
 * writes it under the opr's name with VALUE_SHARED. load() reads it back with
 * load_tensor_shared() and rebuilds the opr via Opr::make(graph, value, config).
 */
template <class Opr>
struct SharedDeviceTensorLoadDump {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& holder = opr_.cast_final_safe<Opr>();
        HostTensorND host_copy;
        host_copy.copy_from(holder.get_dev_tensor()).sync();
        ctx.dump_tensor(
                holder.name(), host_copy,
                OprDumpContext::TensorWriteMethod::VALUE_SHARED);
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        auto dev_val = ctx.load_tensor_shared();
        auto out = Opr::make(ctx.graph(), dev_val, config);
        return out.node()->owner_opr();
    }
};
//! SharedDeviceTensor reuses the generic SharedDeviceTensorLoadDump logic
template <>
struct OprLoadDumpImpl<opr::SharedDeviceTensor, 0>
        : SharedDeviceTensorLoadDump<opr::SharedDeviceTensor> {};

//! VolatileSharedDeviceTensor reuses the generic SharedDeviceTensorLoadDump logic
template <>
struct OprLoadDumpImpl<opr::VolatileSharedDeviceTensor, 0>
        : SharedDeviceTensorLoadDump<opr::VolatileSharedDeviceTensor> {};
/*!
 * \brief Dump/load support for SharedDeviceTensorWithFormat.
 *
 * Stream layout: the tensor value (anonymous, VALUE_ANONYMOUS), then a
 * length-prefixed buffer with the serialized TensorFormat. On load, a device
 * tensor is allocated with the restored format and the value is copied into it.
 */
template <>
struct OprLoadDumpImpl<opr::SharedDeviceTensorWithFormat, 0> {
    using Opr = opr::SharedDeviceTensorWithFormat;

    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& holder = opr_.cast_final_safe<Opr>();
        auto&& dev_tensor = holder.get_dev_tensor();

        HostTensorND host_copy;
        host_copy.copy_from(dev_tensor).sync();
        ctx.dump_tensor(
                {}, host_copy, OprDumpContext::TensorWriteMethod::VALUE_ANONYMOUS);

        auto fmt_bin = dev_tensor.format().serialize();
        ctx.dump_buf_with_len(fmt_bin.data(), fmt_bin.size());
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        // read order must mirror dump(): the value first, then the format buffer
        auto host_val = ctx.load_tensor();
        auto cn = host_val->comp_node();
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        auto format = TensorFormat::deserialize(ctx.load_buf_with_len(), dnn_handle);

        TensorLayout layout{host_val->shape(), host_val->dtype(), format};
        auto dev_val = std::make_shared<DeviceTensorND>(cn, layout);
        dev_val->copy_from_fixlayout(*host_val);
        auto out = Opr::make(ctx.graph(), dev_val, config);
        dev_val->sync();
        return out.node()->owner_opr();
    }
};
/*!
 * \brief Dump/load support for ImmutableTensor.
 *
 * The constant value is copied to host and written as an anonymous tensor
 * (VALUE_ANONYMOUS); load() rebuilds the opr from the loaded host tensor.
 */
template <>
struct OprLoadDumpImpl<opr::ImmutableTensor, 0> {
    using Opr = opr::ImmutableTensor;

    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& imm = opr_.cast_final_safe<Opr>();
        HostTensorND host_val;
        host_val.copy_from(imm.value()).sync();
        ctx.dump_tensor(
                {}, host_val, OprDumpContext::TensorWriteMethod::VALUE_ANONYMOUS);
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        auto host_val = ctx.load_tensor();
        auto out = Opr::make(ctx.graph(), *host_val, config);
        return out.node()->owner_opr();
    }
};
/*!
 * \brief Dump/load support for MultipleDeviceTensorHolder.
 *
 * Stream layout: a length-prefixed buffer holding the value count as a raw
 * uint32_t, followed by each value (host copy, VALUE_SHARED) named after the
 * corresponding output var.
 */
template <>
struct OprLoadDumpImpl<opr::MultipleDeviceTensorHolder, 0> {
    using Opr = opr::MultipleDeviceTensorHolder;

    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& holder = opr_.cast_final_safe<Opr>();
        auto&& vals = holder.values();
        uint32_t count = vals.size();
        ctx.dump_buf_with_len(&count, sizeof(count));
        for (uint32_t idx = 0; idx < count; ++idx) {
            HostTensorND host_copy;
            host_copy.copy_from(*vals[idx]).sync();
            ctx.dump_tensor(
                    holder.output(idx)->name(), host_copy,
                    OprDumpContext::TensorWriteMethod::VALUE_SHARED);
        }
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        // decode the raw uint32_t value count written by dump()
        uint32_t nr;
        {
            auto t = ctx.load_buf_with_len();
            mgb_assert(t.size() == sizeof(nr));
            memcpy(&nr, t.data(), sizeof(nr));
        }
        Opr::ValueArray values(nr);
        for (size_t idx = 0; idx < values.size(); ++idx) {
            values[idx] = ctx.load_tensor_shared();
        }
        auto outs = Opr::make(ctx.graph(), std::move(values), config);
        return outs[0].node()->owner_opr();
    }
};
/*!
 * \brief Dump/load support for MultipleDeviceTensorWithFormatHolder.
 *
 * Stream layout: a length-prefixed buffer holding the value count as a raw
 * uint32_t; then, for each value, the tensor (host copy, VALUE_SHARED, named
 * after the corresponding output var) followed by a length-prefixed buffer
 * with its serialized TensorFormat.
 */
template <>
struct OprLoadDumpImpl<opr::MultipleDeviceTensorWithFormatHolder, 0> {
    using Opr = opr::MultipleDeviceTensorWithFormatHolder;

    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        using Meth = OprDumpContext::TensorWriteMethod;
        auto&& opr = opr_.cast_final_safe<Opr>();
        auto&& values = opr.values();
        uint32_t nr_val = values.size();
        ctx.dump_buf_with_len(&nr_val, sizeof(nr_val));
        for (uint32_t i = 0; i < nr_val; ++i) {
            // bind by reference: the value is only read here, so copying the
            // DeviceTensorND (layout + storage handle) is unnecessary
            const auto& value = *values[i];
            HostTensorND val;
            val.copy_from(value).sync();
            ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
            auto param_bin = value.format().serialize();
            ctx.dump_buf_with_len(param_bin.data(), param_bin.size());
        }
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        // decode the raw uint32_t value count written by dump()
        uint32_t nr;
        {
            auto t = ctx.load_buf_with_len();
            mgb_assert(t.size() == sizeof(nr));
            memcpy(&nr, t.data(), sizeof(nr));
        }
        Opr::ValueArray values(nr);
        for (auto&& i : values) {
            // per value: the tensor first, then its serialized format
            i = ctx.load_tensor_shared();
            auto handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(i->comp_node()))
                                  .handle();
            auto format = TensorFormat::deserialize(ctx.load_buf_with_len(), handle);
            TensorLayout layout_with_format{i->layout(), i->layout().dtype, format};

            if (i->storage().comp_node().mem_node() ==
                CompNode::default_cpu().mem_node()) {
                // storage is on the default CPU's memory node: re-copy the data
                // into a host tensor carrying the format, then proxy it back
                mgb_assert(
                        i->storage().ptr(),
                        "storage should not be nullptr if mem_node is "
                        "default_cpu");
                HostTensorND src{i->storage().comp_node(), layout_with_format};
                src.copy_from_fixlayout(*i).sync();
                *i = DeviceTensorND::make_proxy(src);
            } else {
                // only the layout (with format) is attached here: the storage
                // is reset with a null pointer, so no data is carried.
                // NOTE(review): presumably the loader only consumes the layout
                // of these tensors later — confirm before relying on the data
                auto size = layout_with_format.span().dist_byte();
                DeviceTensorStorage storage;
                storage.reset(i->comp_node(), size, nullptr);
                i->reset(storage, layout_with_format);
            }
        }
        return Opr::make(ctx.graph(), std::move(values), config)[0].node()->owner_opr();
    }
};
}
namespace opr {
/*!
 * \brief Shallow-copy a Host2DeviceCopy into the graph selected by \p ctx.
 *
 * The new opr is built from the source opr's host_data() and param(); the
 * source opr must have no inputs.
 */
cg::OperatorNodeBase* opr_shallow_copy_h2d(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    mgb_assert(inputs.empty());
    auto&& src = opr_.cast_final_safe<Host2DeviceCopy>();
    auto&& dst_graph = *ctx.owner_graph(src, inputs);
    auto out = Host2DeviceCopy::make(dst_graph, src.host_data(), src.param(), config);
    return out.node()->owner_opr();
}
/*!
 * \brief Shallow-copy a shared-device-tensor opr into the graph selected by
 *      \p ctx, reusing the source's dev_data() and const_value().
 */
template <class Opr>
cg::OperatorNodeBase* opr_shallow_copy_shared_device_tensor(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    mgb_assert(inputs.empty());
    auto&& src = opr_.cast_final_safe<Opr>();
    auto&& dst_graph = *ctx.owner_graph(src, inputs);
    auto out = Opr::make(dst_graph, src.dev_data(), src.const_value(), config);
    return out.node()->owner_opr();
}
/*!
 * \brief Shallow-copy an ImmutableTensor by delegating to its own
 *      shallow_copy() on the graph selected by \p ctx.
 */
cg::OperatorNodeBase* opr_shallow_copy_immutable_tensor(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    mgb_assert(inputs.empty());
    auto&& src = opr_.cast_final_safe<ImmutableTensor>();
    auto out = src.shallow_copy(*ctx.owner_graph(src, inputs), config);
    return out.node()->owner_opr();
}
// Serialization and shallow-copy registration for the io oprs.
// NOTE(review): the integer passed to MGB_SEREG_OPR presumably is the opr's
// input arity (0 matches the inputs.empty() asserts in the loaders above;
// Copy takes one input) — confirm against megbrain/serialization/sereg.h.
// Oprs registered with MGB_REG_OPR_SHALLOW_COPY use the custom shallow-copy
// functions defined in this namespace.
MGB_SEREG_OPR(Host2DeviceCopy, 0);
MGB_REG_OPR_SHALLOW_COPY(Host2DeviceCopy, opr_shallow_copy_h2d);
MGB_SEREG_OPR(SharedDeviceTensor, 0);
MGB_REG_OPR_SHALLOW_COPY(
SharedDeviceTensor, opr_shallow_copy_shared_device_tensor<SharedDeviceTensor>);
// no custom shallow-copy function is registered for this opr in this file
MGB_SEREG_OPR(SharedDeviceTensorWithFormat, 0);
MGB_SEREG_OPR(VolatileSharedDeviceTensor, 0);
MGB_REG_OPR_SHALLOW_COPY(
VolatileSharedDeviceTensor,
opr_shallow_copy_shared_device_tensor<VolatileSharedDeviceTensor>);
MGB_SEREG_OPR(ImmutableTensor, 0);
MGB_REG_OPR_SHALLOW_COPY(ImmutableTensor, opr_shallow_copy_immutable_tensor);
MGB_SEREG_OPR(Copy, 1);
// the holder oprs below also have no custom shallow-copy registration here
MGB_SEREG_OPR(MultipleDeviceTensorHolder, 0);
MGB_SEREG_OPR(MultipleDeviceTensorWithFormatHolder, 0);
} }