#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "../op_trait.h"
#include "megbrain/imperative/proxy_graph_detail.h"
namespace mgb {
namespace imperative {
namespace {
/*!
 * \brief Load context that serves raw POD reads from an in-memory
 *      OprAttr::Param byte buffer.
 *
 * Only read_raw() and graph() are functional; load_tensor(),
 * load_tensor_shared() and config() unconditionally hit mgb_assert(0).
 *
 * The buffer is held by reference, so the context must not outlive the
 * OprAttr::Param it was constructed from.
 */
class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD {
    const OprAttr::Param& m_param;  //!< bytes being read; not owned
    size_t m_pos = 0;               //!< read cursor, always <= m_param.size()
    ComputingGraph* m_graph;        //!< graph handed out by graph(); not owned

    /*!
     * \brief Copy the next \p size bytes of the buffer into \p dest and
     *      advance the read cursor.
     *
     * Asserts that the request does not run past the end of the buffer.
     */
    void read_raw(void* dest, size_t size) override final {
        // Compare against the remaining byte count instead of computing
        // m_pos + size: the sum could wrap around for a huge size and let an
        // out-of-bounds copy through. m_pos only grows by amounts validated
        // here, so the subtraction cannot underflow.
        mgb_assert(size <= m_param.size() - m_pos, "too many bytes requested");
        if (size == 0) {
            // nothing to copy; also avoids handing memcpy a possibly-null
            // data() pointer when the buffer is empty
            return;
        }
        memcpy(dest, m_param.data() + m_pos, size);
        m_pos += size;
    }

    // Tensor loading and load-config access are not available here.
    std::shared_ptr<HostTensorND> load_tensor() override { mgb_assert(0); }
    std::shared_ptr<DeviceTensorND> load_tensor_shared() override { mgb_assert(0); }
    const serialization::GraphLoadConfig& config() const override { mgb_assert(0); }

public:
    /*!
     * \param param serialized parameter bytes to read from (kept by reference)
     * \param graph graph returned by graph(); dereferenced there, so it must
     *      be non-null if graph() is called
     */
    OprParamsLoadContext(const OprAttr::Param& param, ComputingGraph* graph)
            : serialization::OprLoadContextRawPOD(false),
              m_param(param),
              m_graph(graph) {}

    //! checks that every byte of the buffer has been read
    // NOTE(review): this check runs inside a destructor. If mgb_assert reports
    // failure by throwing, that would happen from a destructor, including
    // during stack unwinding when a loader fails part-way -- confirm that
    // this is intended.
    ~OprParamsLoadContext() {
        mgb_assert(m_pos == m_param.size(), "param not fully consumed");
    }

    ComputingGraph& graph() override { return *m_graph; }
};
class OprParamsDumpContext final : public serialization::OprDumpContextRawPOD {
public:
OprAttr::Param m_param;
OprParamsDumpContext() : serialization::OprDumpContextRawPOD(false) {}
void write_raw(const void* data, size_t size) {
const char* src = static_cast<const char*>(data);
m_param.insert(m_param.end(), src, src + size);
}
void dump_tensor(
const std::string& name, const HostTensorND& tensor,
TensorWriteMethod method) {
mgb_assert(0);
}
const serialization::GraphDumpConfig& config() const { mgb_assert(0); }
};
/*!
 * \brief Instantiate the operator described by an OprAttr on \p inputs.
 *
 * Looks up the registry entry by the attr's type name and runs its loader,
 * feeding it the attr's serialized parameter bytes.
 *
 * \param def op definition; must be an OprAttr
 * \param inputs input vars; must be non-empty, the graph is taken from the
 *      first one
 * \return the usable outputs reported by the loader's result
 */
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& opr_attr = def.cast_final_safe<OprAttr>();

    // the name is set on a copy, leaving the config stored in the attr alone
    auto named_config = opr_attr.config;
    named_config.name(opr_attr.make_name());

    mgb_assert(!inputs.empty());

    auto found = serialization::OprRegistry::find_by_name(opr_attr.type);
    mgb_assert(found, "operator %s not found", opr_attr.type.c_str());

    auto graph = inputs[0]->owner_graph();
    OprParamsLoadContext load_ctx{opr_attr.param, graph};
    return found->loader(load_ctx, inputs, named_config).usable_output();
}
/*!
 * \brief Build an OprAttr from an existing operator node.
 *
 * Finds the registry entry for the operator's dynamic type, runs its dumper
 * into an in-memory buffer, and packs the registry name, the dumped bytes and
 * the operator's config into a new OprAttr.
 *
 * Asserts if the type has no registry entry or the entry has no dumper.
 */
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
    OprParamsDumpContext dump_ctx;

    auto type_info = opr->dyn_typeinfo();
    auto entry = serialization::OprRegistry::find_by_type(type_info);
    mgb_assert(entry, "operator %s not found", type_info->name);
    mgb_assert(entry->dumper, "operator %s cannot be serialized", type_info->name);

    entry->dumper(dump_ctx, *opr);

    // the collected bytes are moved out; dump_ctx is not used afterwards
    return OprAttr::make(entry->name, std::move(dump_ctx.m_param), opr->config());
}
//! property list for an OprAttr: always empty, \p def is not inspected
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
    using PropList = std::vector<std::pair<const char*, std::string>>;
    return PropList{};
}
//! name of an OprAttr op: a copy of its `type` string (\p def must be an OprAttr)
std::string make_name(const OpDef& def) {
    return def.cast_final_safe<OprAttr>().type;
}
// Register the functions defined above as the make_from_op_node,
// apply_on_var_node, props and make_name hooks of the OprAttr op trait.
// NOTE(review): .fallback() presumably fills the hooks not set here with
// default implementations -- confirm against op_trait.h.
OP_TRAIT_REG(OprAttr, OprAttr)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.props(props)
.make_name(make_name)
.fallback();
}
/*!
 * \brief Equality against another OprAttr.
 *
 * Two attrs are the same when their type, their parameter bytes, and the
 * comp_node and output_dtype of their configs all compare equal; no other
 * part of the config is compared.
 *
 * \param rhs_ cast to OprAttr without a runtime check, so the caller must
 *      guarantee it really is one
 */
// NOTE(review): hash() folds in config.hash(), while equality here only looks
// at comp_node and output_dtype -- confirm config.hash() does not depend on
// config fields that are ignored here.
bool OprAttr::is_same_st(const Hashable& rhs_) const {
    const auto& other = static_cast<const OprAttr&>(rhs_);
    if (!(type == other.type)) {
        return false;
    }
    if (!(param == other.param)) {
        return false;
    }
    if (!(config.comp_node() == other.config.comp_node())) {
        return false;
    }
    if (!(config.output_dtype() == other.config.output_dtype())) {
        return false;
    }
    return true;
}
/*!
 * \brief Hash of this attr: the hashes of `type` and of the parameter bytes
 *      are combined, and the result is combined with config.hash().
 */
size_t OprAttr::hash() const {
    // View the parameter bytes through a const reference. Casting to a
    // std::vector<char> value would build a temporary copy of the whole
    // buffer on every call; the hashed contents are identical either way.
    const auto& param_bytes = static_cast<const std::vector<char>&>(param);
    const auto type_and_param =
            hash_pair_combine(mgb::hash(type), mgb::hash(param_bytes));
    return hash_pair_combine(type_and_param, config.hash());
}
// Expands MGB_DYN_TYPE_OBJ_FINAL_IMPL for OprAttr.
// NOTE(review): presumably emits the out-of-line dynamic-type-info
// definitions for the class -- confirm against the macro's definition.
MGB_DYN_TYPE_OBJ_FINAL_IMPL(OprAttr);
} }