megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
#include "megbrain/rdnn/management.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/tensor.h"
#include "megbrain/utils/metahelper.h"

#include "megdnn/handle.h"
#include "megdnn/oprs.h"

/* ================== global functions ================== */

using namespace mgb;
using namespace mgb::opr;

namespace {
template <class Opr>
class MegDNNGlobalOprContainer final : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

    std::shared_ptr<megdnn::Handle> m_megdnn_handle;
    std::unique_ptr<Opr> m_opr;

public:
    MegDNNGlobalOprContainer(CompNode cn)
            : m_megdnn_handle{intl::get_megdnn_handle_shared(cn)},
              m_opr{m_megdnn_handle->create_operator<Opr>()} {
        mgb_assert(m_opr->is_thread_safe());
    }

    Opr* get() const { return m_opr.get(); }
};

template <class Opr>
MGB_TYPEINFO_OBJ_IMPL(MegDNNGlobalOprContainer<Opr>);
}  // anonymous namespace

std::shared_ptr<megdnn::Handle> intl::get_megdnn_handle_shared(CompNode comp_node) {
    auto& handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node));
    return {handle.shared_from_this(), handle.handle()};
}

megdnn::Handle* intl::get_megdnn_handle(CompNode comp_node) {
    return MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node)).handle();
}

template <typename Opr>
Opr* intl::get_megdnn_global_opr(CompNode comp_node) {
    using T = MegDNNGlobalOprContainer<Opr>;
    auto maker = [comp_node]() { return std::make_shared<T>(comp_node); };
    return CompNodeEnv::from_comp_node(comp_node).get_user_data<T>(maker).get();
}

namespace mgb {
namespace opr {
namespace intl {

#define INST(o) template o* get_megdnn_global_opr<o>(CompNode)
INST(megdnn::AddUpdate);
INST(megdnn::Relayout);
INST(megdnn::Checksum);
#undef INST

}  // namespace intl
}  // namespace opr
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}