#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/serialization/file.h"
#include "megbrain/serialization/helper.h"
using namespace mgb;
using namespace serialization;
MGB_TYPEINFO_OBJ_IMPL(OprLoadContext);
/*!
 * \brief Build the operator loader registered under \p id.
 *
 * The work is delegated to the opr_loader_maker callback obtained from
 * config(); its result is returned unchanged.
 *
 * \param id identifier of the requested operator loader
 * \throws SerializationError (via mgb_throw_if) when no opr_loader_maker
 *      callback has been set in the config
 */
OprLoader OprLoadContext::make_opr_loader(const std::string& id) {
    auto&& loader_factory = config().opr_loader_maker;

    // An empty callback means the caller never configured a factory.
    mgb_throw_if(
            !loader_factory, SerializationError,
            "opr_loader_maker not set in LoadConfig; but opr loader with "
            "id %s is needed",
            id.c_str());

    return loader_factory(id);
}
/*!
 * \brief Specialization of write_param for DType values.
 *
 * When m_check_param_tag is set, FakeSerializedDType::TAG is written as a
 * 32-bit value first; the dtype itself is then emitted through
 * serialization::serialize_dtype, with every chunk forwarded to write_raw.
 */
template <>
void OprDumpContextRawPOD::write_param(const DType& param) {
    if (m_check_param_tag) {
        uint32_t dtype_tag = megdnn::param::FakeSerializedDType::TAG;
        write_raw(&dtype_tag, sizeof(dtype_tag));
    }

    // Sink handed to serialize_dtype: forwards each chunk to the raw writer.
    auto raw_sink = [this](const void* data, size_t len) { write_raw(data, len); };
    serialization::serialize_dtype(param, raw_sink);
}
/*!
 * \brief Specialization of read_param for DType values.
 *
 * When m_check_param_tag is set, a 32-bit tag is read first and compared
 * against FakeSerializedDType::TAG; a mismatch raises MegBrainError. The
 * dtype is then rebuilt by serialization::deserialize_dtype, which pulls its
 * bytes through read_raw.
 */
template <>
DType OprLoadContextRawPOD::read_param() {
    if (m_check_param_tag) {
        uint32_t stored_tag;
        read_raw(&stored_tag, sizeof(stored_tag));
        mgb_throw_if(
                stored_tag != megdnn::param::FakeSerializedDType::TAG,
                MegBrainError, "ERROR tag");
    }

    // Source handed to deserialize_dtype: fills each request from the raw reader.
    auto raw_source = [this](void* data, size_t len) { read_raw(data, len); };
    return serialization::deserialize_dtype(raw_source);
}
/*!
 * \brief Read a length-prefixed byte string.
 *
 * A 32-bit length is read via read_raw, followed by that many bytes which
 * are returned as a std::string.
 */
std::string OprLoadContextRawPOD::load_buf_with_len() {
    uint32_t nr_bytes;
    read_raw(&nr_bytes, sizeof(nr_bytes));

    // Allocate the full payload up front, then fill it in place.
    std::string payload(nr_bytes, '\0');
    read_raw(&payload[0], nr_bytes);
    return payload;
}
/*!
 * \brief Read a length-prefixed buffer as a SharedBuffer.
 *
 * A 32-bit length is read via read_raw and then passed to load_shared_buf,
 * whose result is returned.
 */
SharedBuffer OprLoadContextRawPOD::load_shared_buf_with_len() {
    uint32_t nr_bytes;
    read_raw(&nr_bytes, sizeof(nr_bytes));
    return load_shared_buf(nr_bytes);
}
/*!
 * \brief Default tensor value dumper: write the tensor bytes to \p fout.
 *
 * Writes layout().span().high_byte bytes starting at tensor.raw_ptr().
 * The operator argument is not used.
 *
 * \param fout output file receiving the bytes
 * \param tensor host tensor whose storage is written
 */
void GraphDumpConfig::default_tensor_value_dumper(
        OutputFile& fout, const cg::OperatorNodeBase& /* unused */,
        const HostTensorND& tensor) {
    const auto nr_bytes = tensor.layout().span().high_byte;
    fout.write(tensor.raw_ptr(), nr_bytes);
}
/*!
 * \brief Default tensor value loader: consume the tensor bytes from \p fin.
 *
 * The byte count is layout.span().high_byte. With a non-null \p ptr the
 * bytes are read into it; with a null \p ptr the same number of bytes is
 * skipped in the input instead.
 *
 * \param ptr destination buffer, or nullptr to skip the data
 * \param layout layout from which the byte count is taken
 * \param fin input file to read from
 */
void GraphLoadConfig::default_tensor_value_loader(
        void* ptr, const TensorLayout& layout, InputFile& fin) {
    const auto nr_bytes = layout.span().high_byte;

    // No destination given: just advance past the stored value.
    if (!ptr) {
        fin.skip(nr_bytes);
        return;
    }

    fin.read(ptr, nr_bytes);
}
/*!
 * \brief Read \p size bytes into a newly allocated shared buffer.
 *
 * The storage is a uint8_t array held by a std::shared_ptr whose deleter
 * releases it with delete[]; it is filled through read_raw.
 *
 * \param size number of bytes to allocate and read
 * \return SharedBuffer built from the storage pointer and \p size
 */
SharedBuffer OprLoadContextRawPOD::load_shared_buf(size_t size) {
    // Array allocation must be paired with array delete.
    auto array_deleter = [](uint8_t* p) { delete[] p; };
    std::shared_ptr<uint8_t> storage{new uint8_t[size], array_deleter};

    read_raw(storage.get(), size);
    return {std::move(storage), size};
}