#include "megbrain/rdnn/management.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/tensor.h"
#include "megbrain/utils/metahelper.h"
#include "megdnn/handle.h"
#include "megdnn/oprs.h"
using namespace mgb;
using namespace mgb::opr;
namespace {
template <class Opr>
class MegDNNGlobalOprContainer final : public UserDataContainer::UserData {
MGB_TYPEINFO_OBJ_DECL;
std::shared_ptr<megdnn::Handle> m_megdnn_handle;
std::unique_ptr<Opr> m_opr;
public:
MegDNNGlobalOprContainer(CompNode cn)
: m_megdnn_handle{intl::get_megdnn_handle_shared(cn)},
m_opr{m_megdnn_handle->create_operator<Opr>()} {
mgb_assert(m_opr->is_thread_safe());
}
Opr* get() const { return m_opr.get(); }
};
template <class Opr>
MGB_TYPEINFO_OBJ_IMPL(MegDNNGlobalOprContainer<Opr>);
}
std::shared_ptr<megdnn::Handle> intl::get_megdnn_handle_shared(CompNode comp_node) {
auto& handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node));
return {handle.shared_from_this(), handle.handle()};
}
megdnn::Handle* intl::get_megdnn_handle(CompNode comp_node) {
return MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node)).handle();
}
template <typename Opr>
Opr* intl::get_megdnn_global_opr(CompNode comp_node) {
using T = MegDNNGlobalOprContainer<Opr>;
auto maker = [comp_node]() { return std::make_shared<T>(comp_node); };
return CompNodeEnv::from_comp_node(comp_node).get_user_data<T>(maker).get();
}
namespace mgb {
namespace opr {
namespace intl {
#define INST(o) template o* get_megdnn_global_opr<o>(CompNode)
INST(megdnn::AddUpdate);
INST(megdnn::Relayout);
INST(megdnn::Checksum);
#undef INST
} } }