#include "megbrain/serialization/opr_registry.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/utils/hash_ct.h"
using namespace mgb;
using namespace serialization;
namespace mgb {
void call_sereg();
}
namespace {
struct StaticData {
ThinHashMap<size_t, OprRegistry> id2reg;
ThinHashMap<Typeinfo*, OprRegistry*> type2reg;
std::unordered_map<std::string, OprRegistry*> name2reg;
ThinHashMap<size_t, OprRegistry*> unversioned_id2reg;
};
StaticData& static_data() {
static StaticData inst;
return inst;
}
OprWithOutputAccessor dynamic_loader(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto name = ctx.load_buf_with_len();
return ctx.make_opr_loader(name)(ctx, inputs, config);
}
const OprRegistry* dynamic_registry() {
static const OprRegistry* ret = nullptr;
if (ret)
return ret;
auto id = MGB_HASH_STR("dynamic");
OprRegistry::add({nullptr, id, {}, {}, dynamic_loader, {}, id});
ret = OprRegistry::find_by_id(id);
mgb_assert(ret);
return ret;
}
class _Init {
public:
_Init() {
call_sereg();
dynamic_registry();
}
};
_Init _init;
}
void OprRegistry::add(const OprRegistry& record) {
auto&& sd = static_data();
auto persist_id = record.persist_type_id;
auto registry_ins = sd.id2reg.emplace(persist_id, record);
mgb_assert(
registry_ins.second || persist_id == dynamic_registry()->persist_type_id,
"duplicated operator persist_type_id: %s",
std::to_string(persist_id).c_str());
OprRegistry* persis_record_ptr;
if (registry_ins.second) {
persis_record_ptr = ®istry_ins.first->second;
} else {
static std::vector<std::unique_ptr<OprRegistry>> dynamic_opr_reg;
mgb_assert(!record.loader);
dynamic_opr_reg.emplace_back(new OprRegistry{record});
persis_record_ptr = dynamic_opr_reg.back().get();
}
if (!record.type) {
mgb_assert(!record.dumper);
mgb_assert(!record.shallow_copy);
} else {
auto&& res = sd.type2reg.insert({record.type, persis_record_ptr});
mgb_assert(res.second, "duplicated OprRegistry type: %s", record.type->name);
if (!record.shallow_copy) {
res.first->second->shallow_copy = intl::copy_opr_shallow_default_impl;
}
}
if (!record.name.empty()) {
auto&& n2r = sd.name2reg[record.name];
mgb_assert(!n2r, "duplicated OprRegistry name: %s", record.name.c_str());
n2r = persis_record_ptr;
}
if (record.unversioned_type_id) {
auto&& res = sd.unversioned_id2reg.emplace(
record.unversioned_type_id, persis_record_ptr);
mgb_assert(
res.second || record.unversioned_type_id ==
dynamic_registry()->unversioned_type_id,
"duplicated OprRegistry unversioned id: %s",
std::to_string(record.unversioned_type_id).c_str());
}
}
const OprRegistry* OprRegistry::find_by_name(const std::string& name) {
auto&& name2reg = static_data().name2reg;
auto iter = name2reg.find(name);
return iter == name2reg.end() ? nullptr : iter->second;
}
const OprRegistry* OprRegistry::find_by_id(size_t id) {
auto&& id2reg = static_data().id2reg;
auto iter = id2reg.find(id);
return iter == id2reg.end() ? nullptr : &iter->second;
}
const OprRegistry* OprRegistry::find_by_type(Typeinfo* type) {
auto&& type2reg = static_data().type2reg;
auto iter = type2reg.find(type);
return iter == type2reg.end() ? nullptr : iter->second;
}
const OprRegistry* OprRegistry::find_by_unversioned_id(size_t unversioned_id) {
auto&& uid2reg = static_data().unversioned_id2reg;
auto iter = uid2reg.find(unversioned_id);
return iter == uid2reg.end() ? nullptr : iter->second;
}
void OprRegistry::add_using_dynamic_loader(
Typeinfo* type, const std::string& name, const OprDumper& dumper) {
add({type,
dynamic_registry()->persist_type_id,
name,
dumper,
{},
{},
dynamic_registry()->unversioned_type_id});
}
#if MGB_ENABLE_DEBUG_UTIL
std::vector<std::pair<size_t, std::string>> OprRegistry::dump_registries() {
auto&& id2reg = static_data().id2reg;
std::vector<std::pair<size_t, std::string>> result;
for (auto iter = id2reg.begin(); iter != id2reg.end(); ++iter) {
if (iter->second.name.size() == 0)
result.push_back({iter->first, "<special>"});
else
result.push_back({iter->first, iter->second.name});
}
return result;
}
#endif
namespace {
const VarNodeArray& default_accessor(const VarNodeArray& outputs) {
return outputs;
}
}
OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr) : m_opr(opr) {
m_accessor = &default_accessor;
};
OprWithOutputAccessor::OprWithOutputAccessor(
cg::OperatorNodeBase* opr, Accessor accessor)
: OprWithOutputAccessor(opr) {
if (accessor) {
m_accessor = accessor;
}
};