#include "megbrain/cambricon/cambricon_runtime_opr.h"
#include "megbrain/common.h"
#if MGB_CAMBRICON
using namespace mgb;
using namespace opr;
namespace {
SmallVector<int> mgb_shape_to_cnrt_shape(TensorShape mgb_shp) {
int ndim = mgb_shp.ndim;
SmallVector<int> cnrt_shp(ndim);
for (int i = 0; i < ndim; ++i) {
cnrt_shp[i] = mgb_shp[i];
}
return cnrt_shp;
}
TensorShape cnrt_shape_to_mgb_shape(int* dim_values, int dim_num) {
TensorShape ret;
ret.ndim = dim_num;
for (int i = 0; i < dim_num; ++i) {
ret[i] = dim_values[i];
}
return ret;
}
DType cnrt_dtype_to_mgb_dtype(cnrtDataType_t data_type) {
switch (data_type) {
case CNRT_FLOAT16:
#if !MEGDNN_DISABLE_FLOAT16
return dtype::Float16();
#else
mgb_throw(MegBrainError, "Float16 support is disabled at compile time.");
#endif
case CNRT_FLOAT32:
return dtype::Float32();
case CNRT_INT8:
return dtype::QuantizedS8(1.f);
case CNRT_INT16:
return dtype::Int16();
case CNRT_INT32:
return dtype::Int32();
case CNRT_UINT8:
return dtype::Uint8();
case CNRT_QUANT8:
return dtype::QuantizedS8(1.f);
default:
mgb_throw(
MegBrainError, "cnrtDataType %x is not supported by MegBrain.",
data_type);
}
}
cnrtDataType_t mgb_dtype_to_cnrt_dtype(DType data_type) {
switch (data_type.enumv()) {
#if !MEGDNN_DISABLE_FLOAT16
case DTypeEnum::Float16:
return CNRT_FLOAT16;
#endif
case DTypeEnum::Float32:
return CNRT_FLOAT32;
case DTypeEnum::QuantizedS8:
return CNRT_QUANT8;
case DTypeEnum::Quantized8Asymm:
return CNRT_QUANT8;
case DTypeEnum::Int8:
return CNRT_INT8;
case DTypeEnum::Int32:
return CNRT_INT32;
case DTypeEnum::Uint8:
return CNRT_UINT8;
default:
mgb_throw(
MegBrainError, "megbrain data type %s is not supported by cnrt.",
data_type.name());
}
}
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CambriconRuntimeOpr);
CambriconRuntimeOpr::CambriconRuntimeOpr(
SharedBuffer buf, std::string symbol, const VarNodeArray& inputs,
bool tensor_dim_mutable, const OperatorNodeConfig& config)
: Super(inputs[0]->owner_graph(), config, "cambricon_runtime", inputs),
m_buffer{std::move(buf)},
m_symbol{std::move(symbol)},
m_model{nullptr},
m_function{nullptr},
m_context{nullptr},
m_tensor_dim_mutable{tensor_dim_mutable} {
mgb_assert(
inputs[0]->comp_node().device_type() == CompNode::DeviceType::CAMBRICON,
"CambriconRuntimeOpr can only be used on cambricon comp node; "
"got %s",
inputs[0]->comp_node().to_string().c_str());
for (auto i : inputs) {
add_input({i});
}
if (m_model == nullptr) {
m_model = {new cnrtModel_t(), cnrt_intl::ModelUnloader()};
MGB_CNRT_CHECK(cnrtLoadModelFromMem(
m_model.get(),
reinterpret_cast<char*>(const_cast<void*>(m_buffer.data()))));
}
if (m_function == nullptr) {
m_function = {new cnrtFunction_t(), cnrt_intl::FunctionDeleter()};
MGB_CNRT_CHECK(cnrtCreateFunction(m_function.get()));
MGB_CNRT_CHECK(
cnrtExtractFunction(m_function.get(), *m_model, m_symbol.c_str()));
}
int nr_inputs = 0;
int nr_outputs = 0;
int64_t* inputs_size = nullptr;
int64_t* outputs_size = nullptr;
MGB_CNRT_CHECK(cnrtGetInputDataSize(&inputs_size, &nr_inputs, *m_function));
mgb_assert(
static_cast<size_t>(nr_inputs) == inputs.size(),
"inputs size mismatch: expect=%d, got=%zu", nr_inputs, inputs.size());
MGB_CNRT_CHECK(cnrtGetOutputDataSize(&outputs_size, &nr_outputs, *m_function));
if (nr_outputs == 1) {
add_output(None);
} else {
for (int i = 0; i < nr_outputs; ++i) {
add_output(ssprintf("o%d", i));
}
}
add_equivalence_component<mgb::ScalarHash<const void*>>(m_buffer.data());
};
void CambriconRuntimeOpr::scn_do_execute() {
mgb_assert(m_function != nullptr);
auto&& cnrt_env = CompNodeEnv::from_comp_node(input(0)->comp_node()).cnrt_env();
cnrt_env.activate();
if (m_context == nullptr) {
m_context = {new cnrtRuntimeContext_t(), cnrt_intl::RuntimeContextDeleter()};
MGB_CNRT_CHECK(cnrtCreateRuntimeContext(m_context.get(), *m_function, nullptr));
int dev_id = cnrt_env.device;
MGB_CNRT_CHECK(cnrtSetRuntimeContextDeviceId(*m_context, dev_id));
MGB_CNRT_CHECK(cnrtInitRuntimeContext(*m_context, nullptr));
}
size_t nr_inputs = input().size(), nr_outputs = output().size();
SmallVector<void*> params(nr_inputs + nr_outputs);
SmallVector<cnrtParamDesc_t> param_descs(nr_inputs + nr_outputs);
for (size_t i = 0; i < nr_inputs; ++i) {
params[i] = input(i)->dev_tensor().raw_ptr();
MGB_CNRT_CHECK(cnrtCreateParamDesc(¶m_descs[i]));
MGB_CNRT_CHECK(cnrtSetDataTypeToParamDesc(
param_descs[i], mgb_dtype_to_cnrt_dtype(input(i)->dtype())));
auto dims = mgb_shape_to_cnrt_shape(input(i)->shape());
MGB_CNRT_CHECK(cnrtSetShapeToParamDesc(
param_descs[i], dims.data(), static_cast<int>(dims.size())));
}
for (size_t i = 0; i < nr_outputs; ++i) {
params[nr_inputs + i] = output(i)->dev_tensor().raw_ptr();
MGB_CNRT_CHECK(cnrtCreateParamDesc(¶m_descs[nr_inputs + i]));
MGB_CNRT_CHECK(cnrtSetDataTypeToParamDesc(
param_descs[nr_inputs + i],
mgb_dtype_to_cnrt_dtype(output(i)->dtype())));
auto dims = mgb_shape_to_cnrt_shape(output(i)->shape());
MGB_CNRT_CHECK(cnrtSetShapeToParamDesc(
param_descs[nr_inputs + i], dims.data(),
static_cast<int>(dims.size())));
}
MGB_CNRT_CHECK(cnrtInvokeRuntimeContext_V2(
*m_context, param_descs.data(), params.data(), cnrt_env.queue, nullptr));
for (auto& param : param_descs) {
MGB_CNRT_CHECK(cnrtDestroyParamDesc(param));
}
}
void CambriconRuntimeOpr::get_output_var_shape(
const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
mgb_assert(m_function != nullptr);
mgb_assert(input().size() == inp_shape.size());
if (m_tensor_dim_mutable) {
cnrtParamDescArray_t input_descs, output_descs;
int inp_param_num = input().size();
int out_param_num = output().size();
MGB_CNRT_CHECK(cnrtCreateParamDescArray(&input_descs, inp_param_num));
MGB_CNRT_CHECK(cnrtCreateParamDescArray(&output_descs, out_param_num));
for (int i = 0; i < inp_param_num; ++i) {
MGB_CNRT_CHECK(cnrtSetDataTypeToParamDesc(
input_descs[i], mgb_dtype_to_cnrt_dtype(input(i)->dtype())));
auto dims = mgb_shape_to_cnrt_shape(inp_shape[i]);
MGB_CNRT_CHECK(cnrtSetShapeToParamDesc(
input_descs[i], dims.data(), static_cast<int>(dims.size())));
}
MGB_CNRT_CHECK(cnrtInferFunctionOutputShape(
*m_function, inp_param_num, input_descs, out_param_num, output_descs));
for (int i = 0; i < out_param_num; ++i) {
int* dims = nullptr;
int dim_num = 0;
MGB_CNRT_CHECK(cnrtGetShapeFromParamDesc(output_descs[i], &dims, &dim_num));
out_shape[i] = cnrt_shape_to_mgb_shape(dims, dim_num);
}
MGB_CNRT_CHECK(cnrtDestroyParamDescArray(input_descs, inp_param_num));
MGB_CNRT_CHECK(cnrtDestroyParamDescArray(output_descs, out_param_num));
} else {
for (size_t i = 0; i < inp_shape.size(); ++i) {
int* dim_values = nullptr;
int dim_num = 0;
MGB_CNRT_CHECK(cnrtGetInputDataShape(
&dim_values, &dim_num, static_cast<int>(i), *m_function));
auto shp_in_func = cnrt_shape_to_mgb_shape(dim_values, dim_num);
auto inpshp = inp_shape[i];
MGB_MARK_USED_VAR(shp_in_func);
mgb_assert(
inpshp.eq_shape(shp_in_func),
"input shape(%s) mismatch with that(%s) in cnrtFunction_t.",
inpshp.to_string().c_str(), shp_in_func.to_string().c_str());
}
MGB_MARK_USED_VAR(mgb_dtype_to_cnrt_dtype);
for (size_t i = 0; i < out_shape.size(); ++i) {
int* dim_values = nullptr;
int dim_num = 0;
MGB_CNRT_CHECK(cnrtGetOutputDataShape(
&dim_values, &dim_num, static_cast<int>(i), *m_function));
out_shape[i] = cnrt_shape_to_mgb_shape(dim_values, dim_num);
}
}
}
void CambriconRuntimeOpr::add_input_layout_constraint() {
for (auto i : input()) {
i->add_layout_constraint_contiguous();
}
}
void CambriconRuntimeOpr::init_output_dtype() {
cnrtDataType_t* inp_dtype_array = nullptr;
int inp_num;
MGB_CNRT_CHECK(cnrtGetInputDataType(&inp_dtype_array, &inp_num, *m_function));
for (size_t i = 0; i < input().size(); ++i) {
auto dt_cnrt = cnrt_dtype_to_mgb_dtype(inp_dtype_array[i]);
auto dt_inp = input(i)->dtype();
MGB_MARK_USED_VAR(dt_cnrt);
MGB_MARK_USED_VAR(dt_inp);
mgb_assert(
dt_cnrt.valid() && dt_inp.valid() && dt_cnrt.enumv() == dt_inp.enumv(),
"Input %zu's data type mismatch with that in "
"cnrtFunction_t: expected %s, got %s",
i, dt_cnrt.name(), dt_inp.name());
}
cnrtDataType_t* out_dtype_array = nullptr;
int out_num;
MGB_CNRT_CHECK(cnrtGetOutputDataType(&out_dtype_array, &out_num, *m_function));
for (size_t i = 0; i < output().size(); ++i) {
auto dt_cnrt = cnrt_dtype_to_mgb_dtype(out_dtype_array[i]);
mgb_assert(
dt_cnrt.valid(),
"output dtype checking failed: invalid dtype returned.");
if (dt_cnrt.enumv() == DTypeEnum::QuantizedS8) {
mgb_assert(
output(i)->dtype().valid(),
"user should specify scale of output tensor of "
"CambriconRuntimeOpr.");
}
if (!output(i)->dtype().valid())
output(i)->dtype(dt_cnrt);
}
}
SymbolVarArray CambriconRuntimeOpr::make(
SharedBuffer buf, std::string symbol, const SymbolVarArray& src,
bool tensor_dim_mutable, const OperatorNodeConfig& config) {
VarNodeArray var_node_array = cg::to_var_node_array(src);
auto cambricon_runtime_opr = std::make_unique<CambriconRuntimeOpr>(
std::move(buf), std::move(symbol), var_node_array, tensor_dim_mutable,
config);
auto ret = cg::to_symbol_var_array(
src[0].node()
->owner_graph()
->insert_opr(std::move(cambricon_runtime_opr))
->output());
return ret;
}
SymbolVarArray CambriconRuntimeOpr::make(
const void* buf, size_t size, std::string symbol, const SymbolVarArray& src,
bool tensor_dim_mutable, const OperatorNodeConfig& config) {
mgb_throw_if(
!CompNode::get_device_count(CompNode::DeviceType::CAMBRICON), SystemError,
"can not create CambriconRuntimeOpr when Cambricon is not "
"available");
std::shared_ptr<uint8_t> shptr{new uint8_t[size], [](uint8_t* p) { delete[] p; }};
memcpy(shptr.get(), buf, size);
SharedBuffer buffer{std::move(shptr), size};
return make(std::move(buffer), std::move(symbol), src, tensor_dim_mutable, config);
}
#endif