#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/opr_registry.h"
#include "megbrain/utils/big_key_hashmap.h"

#include <exception>
using namespace mgb;
using namespace serialization;
namespace {
/*!
 * \brief Dump context that serializes an operator's raw POD params into an
 *      in-memory byte buffer (read back by OprLoadContextMemory).
 *
 * Tensor dumping and config access are not supported and throw GraphError.
 */
class OprDumpContextMemory final : public OprDumpContextRawPOD {
    //! every byte written so far, in write order
    std::vector<uint8_t> m_buf;

    //! append \p size bytes starting at \p data to the end of the buffer
    void write_raw(const void* data, size_t size) override {
        const auto* first = static_cast<const uint8_t*>(data);
        m_buf.insert(m_buf.end(), first, first + size);
    }

    void dump_tensor(
            const std::string&, const HostTensorND&, TensorWriteMethod) override {
        mgb_throw(GraphError, "OprDumpContextMemory does not support dump tensor");
    }

    const GraphDumpConfig& config() const override {
        mgb_throw(GraphError, "OprDumpContextMemory has no associated config");
    }

public:
    OprDumpContextMemory() : OprDumpContextRawPOD(false) {}

    //! read-only view of the accumulated bytes
    const std::vector<uint8_t>& buf() const { return m_buf; }
};
class OprLoadContextMemory final : public OprLoadContextRawPOD {
const uint8_t* m_ptr;
size_t m_size, m_pos = 0;
ComputingGraph* m_graph;
void read_raw(void* dest, size_t size) override {
auto end = m_pos + size;
mgb_assert(end <= m_size);
memcpy(dest, m_ptr + m_pos, size);
m_pos = end;
}
ComputingGraph& graph() override { return *m_graph; }
std::shared_ptr<HostTensorND> load_tensor() override { mgb_assert(0); }
std::shared_ptr<DeviceTensorND> load_tensor_shared() override { mgb_assert(0); }
const GraphLoadConfig& config() const override {
mgb_throw(GraphError, "OprLoadContextMemory has no associated config");
}
public:
OprLoadContextMemory(ComputingGraph* graph, const OprDumpContextMemory& dumper)
: OprLoadContextRawPOD(false),
m_ptr{dumper.buf().data()},
m_size{dumper.buf().size()},
m_graph{graph} {}
~OprLoadContextMemory() { mgb_assert(m_pos == m_size); }
};
/*!
 * \brief Memo table of shallow copies, attached to a graph via its
 *      options().user_data and used by copy_opr_shallow.
 *
 * Maps (source opr, new inputs, new config) to the operator that was produced
 * when that opr was copied with those inputs/config, so repeating the same
 * copy within the graph returns the previously created operator.
 */
class ShallowCopyCacheContainer final : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

    //! equality / hash hooks for each key component of the BigKeyHashMap
    //! below (presumably chosen by overload resolution on the key type —
    //! confirm against big_key_hashmap.h)
    struct HashEq {
        //! default equality: operator==
        template <typename T>
        static bool eq(const T& x, const T& y) {
            return x == y;
        }
        //! configs are compared with is_same() rather than operator==
        static bool eq(const OperatorNodeConfig& x, const OperatorNodeConfig& y) {
            return x.is_same(y);
        }
        //! the source opr is hashed by its address
        static size_t hash(const void* ptr) { return std::hash<const void*>{}(ptr); }
        //! inputs are hashed as a raw array of VarNode pointers
        static size_t hash(const VarNodeArray& inputs) {
            return PODHash<VarNode*>::perform(inputs.data(), inputs.size());
        }
        static size_t hash(const OperatorNodeConfig& config) { return config.hash(); }
    };

public:
    //! value: the copied operator; keys: source opr pointer (held by value),
    //! inputs and config (wrapped in Ref — presumably looked up by reference
    //! and copied into the map only on insertion; verify in big_key_hashmap.h)
    big_key_hash_map::BigKeyHashMap<
            cg::OperatorNodeBase*, HashEq,
            big_key_hash_map::Copy<const cg::OperatorNodeBase*>,
            big_key_hash_map::Ref<VarNodeArray>,
            big_key_hash_map::Ref<OperatorNodeConfig>>
            cache;
};
MGB_TYPEINFO_OBJ_IMPL(ShallowCopyCacheContainer);
}
/*!
 * \brief Destination graph for a shallow copy of \p opr onto \p inputs.
 *
 * An explicitly configured owner graph wins; it must then match the graph of
 * the first input, if any. Without one, the first input's graph is used, or
 * the source opr's graph when there are no inputs.
 */
ComputingGraph* serialization::OprShallowCopyContext::owner_graph(
        const cg::OperatorNodeBase& opr, const VarNodeArray& inputs) const {
    if (m_owner_graph) {
        if (!inputs.empty())
            mgb_assert(m_owner_graph == inputs[0]->owner_graph());
        return m_owner_graph;
    }
    return inputs.empty() ? opr.owner_graph() : inputs[0]->owner_graph();
}
/*!
 * \brief Shallow-copy \p opr onto new \p inputs with \p config.
 *
 * The copy is produced by the OprRegistry entry for the opr's dynamic type.
 * When the destination graph chosen by \p ctx is the source opr's own graph,
 * results are memoized in a ShallowCopyCacheContainer stored in that graph's
 * user data, keyed by (opr, inputs, config).
 *
 * \return the copied operator; this may be \p opr itself, but only when the
 *      inputs are unchanged (checked by the final assertion).
 */
cg::OperatorNodeBase* serialization::copy_opr_shallow(
        const cg::OperatorNodeBase& opr, const VarNodeArray& inputs,
        const OperatorNodeConfig& config, const OprShallowCopyContext& ctx) {
    auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo());
    mgb_assert(
            registry, "could not find OprReceiver to copy opr %s{%s}", opr.cname(),
            opr.dyn_typeinfo()->name);
    mgb_assert(inputs.size() == opr.input().size());
    auto dst_og = ctx.owner_graph(opr, inputs);

    // Run the registered copier. If the copy targets another graph, or the
    // source graph gained operators, record where the result came from and
    // inherit the source priority, unless the copier already set them.
    auto do_copy = [&]() {
        auto nr_opr_before = opr.owner_graph()->nr_oprs_in_graph();
        auto ret = registry->shallow_copy(ctx, opr, inputs, config);
        if (dst_og != opr.owner_graph() ||
            opr.owner_graph()->nr_oprs_in_graph() != nr_opr_before) {
            auto&& attr = ret->node_prop().attribute();
            if (!attr.src_opr) {
                auto src = cg::get_opr_root_source_opr(
                        const_cast<cg::OperatorNodeBase*>(&opr));
                // never make an operator its own source
                if (ret != src)
                    attr.src_opr = src;
            }
            if (!attr.priority) {
                attr.priority = opr.node_prop().attribute().priority;
            }
        }
        return ret;
    };

    cg::OperatorNodeBase* ret;
    if (dst_og == opr.owner_graph()) {
        auto&& cache =
                dst_og->options()
                        .user_data.get_user_data_or_create<ShallowCopyCacheContainer>()
                        ->cache;
        auto ins = cache.get(&opr, inputs, config);
        // cache.get() creates the slot before the copy is made. Null it first
        // so that if do_copy() throws, the slot is left in a recognizable
        // "empty" state and is recomputed on the next call, instead of being
        // dereferenced as a valid cached operator.
        if (ins.first || !*ins.second) {
            *ins.second = nullptr;
            *ins.second = do_copy();
        } else {
            // cache hit: refresh output shapes of the previously made copy
            cg::update_output_var_shapes(*ins.second);
        }
        ret = *ins.second;
    } else {
        ret = do_copy();
    }

    // Sanity check, skipped for oprs where gopt::has_inplace_basic_arith_opt()
    // holds (their copy may presumably be rewritten into a different form):
    //  * the copy has as many usable outputs as the source, and
    //  * the source opr itself is returned only if the inputs are unchanged.
    mgb_assert(
            gopt::has_inplace_basic_arith_opt(opr) ||
                    ((opr.usable_output().size() == ret->usable_output().size()) &&
                     ((&opr != ret) || opr.input() == inputs)),
            "bad opr copy: src=%s{%s} dst=%s{%s}", opr.cname(),
            opr.dyn_typeinfo()->name, ret->cname(), ret->dyn_typeinfo()->name);
    return ret;
}
/*!
 * \brief Fallback shallow copy: serialize \p opr with its registered dumper
 *      into memory, then rebuild it on \p inputs with the registered loader.
 *
 * The load context is bound to the destination graph selected by \p ctx, the
 * same graph copy_opr_shallow() computes. Binding it to the source opr's
 * graph would ignore the context's owner graph for any loader that creates
 * its opr through the context's graph() (presumably loaders of oprs without
 * inputs).
 */
cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
        const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr,
        const VarNodeArray& inputs, const OperatorNodeConfig& config) {
    auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo());
    mgb_assert(
            registry && registry->dumper && registry->loader,
            "can not shallow_copy operator %s{%s}: "
            "no dumper/loader registered",
            opr.cname(), opr.dyn_typeinfo()->name);

    OprDumpContextMemory dumper;
    registry->dumper(dumper, opr);

    OprLoadContextMemory loader{ctx.owner_graph(opr, inputs), dumper};
    return registry->loader(loader, inputs, config).opr();
}