#include "megdnn/basic_types.h"
#include "src/common/handle_impl.h"
#include "src/common/utils.h"
#include "src/fallback/handle.h"
#include "src/naive/handle.h"
#include "midout.h"
#if MEGDNN_X86
#include "src/x86/handle.h"
#endif
#if MEGDNN_ARMV7
#include "src/armv7/handle.h"
#endif
#if MEGDNN_AARCH64
#include "src/aarch64/handle.h"
#endif
#if MEGDNN_WITH_CUDA
#include "src/cuda/handle.h"
#endif
#if MEGDNN_WITH_CAMBRICON
#include "src/cambricon/handle.h"
#endif
#ifdef MEGDNN_WITH_ATLAS
#include "src/atlas/handle.h"
#endif
using namespace megdnn;
MIDOUT_DECL(HandlePlatform);
MIDOUT_DECL(HandleOpr);
Handle::Handle(megcoreComputingHandle_t computing_handle, HandleType type)
: m_computing_handle(computing_handle), m_handle_type(type) {}
std::unique_ptr<Handle> Handle::make(
megcoreComputingHandle_t computing_handle, int debug_level) {
(void)debug_level;
megcoreDeviceHandle_t device_handle;
megcorePlatform_t platform;
megcoreGetDeviceHandle(computing_handle, &device_handle);
megcoreGetPlatform(device_handle, &platform);
if (platform == megcorePlatformCPU) {
MIDOUT_BEGIN(HandlePlatform, midout_iv(megcorePlatformCPU)) {
#if MEGDNN_NAIVE
return make_unique<naive::HandleImpl>(computing_handle);
#else
if (debug_level == 0) {
#if MEGDNN_X86
return std::unique_ptr<x86::HandleImpl>(
new x86::HandleImpl(computing_handle));
#elif MEGDNN_ARMV7
return make_unique<armv7::HandleImpl>(computing_handle);
#elif MEGDNN_AARCH64
return make_unique<aarch64::HandleImpl>(computing_handle);
#else
return make_unique<fallback::HandleImpl>(computing_handle);
#endif
} else if (debug_level == 1) {
return make_unique<fallback::HandleImpl>(computing_handle);
} else if (debug_level == 2) {
return make_unique<naive::HandleImpl>(computing_handle);
} else {
megdnn_throw("Debug level must be 0/1/2.");
}
#endif
}
MIDOUT_END();
}
else if (platform == megcorePlatformROCM) {
#if MEGDNN_WITH_ROCM
return make_rocm_handle(computing_handle);
#else
return nullptr;
#endif
} else if (platform == megcorePlatformCambricon) {
#if MEGDNN_WITH_CAMBRICON
return make_unique<cambricon::HandleImpl>(computing_handle);
#else
return nullptr;
#endif
} else if (platform == megcorePlatformAtlas) {
#if MEGDNN_WITH_ATLAS
return make_unique<atlas::HandleImpl>(computing_handle);
#else
return nullptr;
#endif
}
else {
megdnn_throw_if(
platform != megcorePlatformCUDA, megdnn_error,
"platform should be CUDA Platform");
#if MEGDNN_WITH_CUDA
return make_unique<cuda::HandleImpl>(computing_handle);
#else
return nullptr;
#endif
}
return nullptr;
}
void Handle::set_destructor(const thin_function<void()>& d) {
megdnn_assert(!m_destructor, "destructor can be set only once");
m_destructor = d;
}
Handle::~Handle() {
if (m_destructor)
m_destructor();
m_alive_magic = 0;
}
size_t Handle::alignment_requirement() const {
return 32;
}
size_t Handle::image2d_pitch_alignment() const {
megdnn_throw("image2d tensor format not supported on this handle");
}
megdnn::HandleImplHelper::HandleVendorType Handle::vendor_type() const {
return HandleVendorType::NOT_SPEC;
}
bool Handle::check_cross_dev_copy_constraint(const TensorLayout& src) {
return src.is_contiguous();
}
void Handle::on_opr_destructed(OperatorBase* opr) {
if (m_alive_magic != ALIVE_MAGIC) {
megdnn_log_error(
"Handle is destructed before opr gets destructed. "
"Please fix the destruction order as this would cause "
"undefined memory access. "
"Abort now to avoid further problems.");
abort();
}
if (m_on_opr_destructed) {
m_on_opr_destructed(opr);
}
}
OperatorBase::~OperatorBase() {
m_handle->on_opr_destructed(this);
}
template <typename Opr>
std::unique_ptr<Opr> Handle::create_operator() {
#define CASE(etype, nm) \
case HandleType::etype: { \
MIDOUT_BEGIN(HandleOpr, Opr, midout_iv(HandleType::etype)) { \
return static_cast<nm::HandleImpl*>(this)->create_operator<Opr>(); \
} \
MIDOUT_END(); \
}
switch (m_handle_type) {
CASE(NAIVE, naive);
#if !MEGDNN_NAIVE
CASE(FALLBACK, fallback);
#if MEGDNN_X86
CASE(X86, x86);
#endif
#if MEGDNN_ARMV7
CASE(ARMV7, armv7);
#endif
#if MEGDNN_AARCH64
CASE(AARCH64, aarch64);
#endif
#if MEGDNN_ARMV7 || MEGDNN_AARCH64
CASE(ARM_COMMON, arm_common);
#endif
#endif #if MEGDNN_WITH_CUDA
CASE(CUDA, cuda);
#endif
#if MEGDNN_WITH_ATLAS
CASE(ATLAS, atlas);
#endif
#if MEGDNN_WITH_ROCM
case HandleType::ROCM: {
MIDOUT_BEGIN(HandleOpr, Opr, midout_iv(HandleType::ROCM)) {
return create_rocm_operator<Opr>();
}
MIDOUT_END();
}
#endif
#if MEGDNN_WITH_CAMBRICON
CASE(CAMBRICON, cambricon);
#endif
default:
megdnn_throw("bad handle type");
}
#undef CASE
}
#define INST(opr) template std::unique_ptr<opr> Handle::create_operator();
MEGDNN_FOREACH_OPR_CLASS(INST)
#undef INST