megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file src/custom/impl/manager.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/common.h"

#if MGB_CUSTOM_OP

#include <unordered_set>
#include "megbrain/custom/manager.h"

#ifndef _WIN32
#include <dlfcn.h>
#else
#include <windows.h>
#endif

using namespace mgb;

namespace custom {

#ifdef _WIN32
#define RTLD_LAZY 0

void* dlopen(const char* file, int) {
    return static_cast<void*>(LoadLibrary(file));
}

/**
 * Windows stand-in for POSIX dlclose(): releases a module obtained from the
 * dlopen() shim.
 *
 * \param handle module handle previously returned by dlopen()
 * \return 0 on success, -1 on failure (POSIX dlclose convention)
 */
int dlclose(void* handle) {
    // FreeLibrary signals success with a nonzero BOOL, while POSIX dlclose
    // signals success with 0; translate so both platforms agree.
    return FreeLibrary(static_cast<HMODULE>(handle)) ? 0 : -1;
}

/**
 * Windows stand-in for POSIX dlerror().
 *
 * \return a fixed placeholder message; no actual error details are looked up
 */
const char* dlerror(void) {
    // A string literal has static storage duration, so the returned pointer
    // stays valid after the call.
    return "no dlerror info in windows";
}
#endif

/**
 * Accessor for the single CustomOpManager object.
 *
 * \return pointer to a function-local static instance, created on first call
 */
CustomOpManager* CustomOpManager::inst(void) {
    static CustomOpManager instance;
    return &instance;
}

/**
 * Checks that the name-indexed and id-indexed tables hold the same number of
 * entries, then drops every library tracked by LibManager.
 */
CustomOpManager::~CustomOpManager() {
    mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!");
    // Clearing the library table here runs while this object's own maps are
    // still alive (members are destroyed only after the destructor body).
    auto& loaded_libs = LibManager::inst()->m_custom_libs;
    loaded_libs.clear();
}

std::shared_ptr<CustomOp> CustomOpManager::insert(
        const std::string& name, uint32_t version) {
    MGB_LOCK_GUARD(m_mtx);
    auto iter = m_name2op.find(name);
    if (iter != m_name2op.end()) {
        mgb_log_warn(
                "Register Custom Op Failed! Op %s has been registered", name.c_str());
        return std::const_pointer_cast<CustomOp, const CustomOp>(iter->second);
    }
    std::shared_ptr<const CustomOp> op =
            std::make_shared<const CustomOp>(name, version);
    m_name2op[op->op_type()] = op;
    m_id2op[op->runtime_id()] = op;
    return std::const_pointer_cast<CustomOp, const CustomOp>(op);
}

/**
 * Removes the op registered under \p name from both lookup tables.
 *
 * Runs under MGB_LOCK_GUARD(m_mtx).
 *
 * \param name name of the op to remove
 * \return true if the op was found and removed; false (after logging a
 *         warning) if no op is registered under \p name
 */
bool CustomOpManager::erase(const std::string& name) {
    MGB_LOCK_GUARD(m_mtx);

    const auto pos = m_name2op.find(name);
    const bool registered = pos != m_name2op.end();
    if (registered) {
        // Hold our own reference so the op stays alive while its entries are
        // dropped from both maps.
        const std::shared_ptr<const CustomOp> target = pos->second;
        m_id2op.erase(target->runtime_id());
        m_name2op.erase(target->op_type());
    } else {
        mgb_log_warn(
                "Erase Custom Op Failed! %s has not been registered", name.c_str());
    }
    return registered;
}

/**
 * Removes the op registered under runtime id \p id from both lookup tables.
 *
 * Runs under MGB_LOCK_GUARD(m_mtx).
 *
 * \param id runtime id of the op to remove
 * \return true if the op was found and removed; false (after logging a
 *         warning) if no op is registered under \p id
 */
bool CustomOpManager::erase(const RunTimeId& id) {
    MGB_LOCK_GUARD(m_mtx);

    const auto pos = m_id2op.find(id);
    if (pos == m_id2op.end()) {
        mgb_log_warn("Erase Custom Op Failed! The Op has not been registered");
        return false;
    }

    // Copy the handle first: the map entries are about to go away and the op
    // is still needed to read its id and type name.
    const std::shared_ptr<const CustomOp> target = pos->second;
    m_id2op.erase(target->runtime_id());
    m_name2op.erase(target->op_type());
    return true;
}

/**
 * Returns the op registered under \p name, registering a new one through
 * insert() when none exists.
 *
 * The lookup in this function is done without MGB_LOCK_GUARD; only the
 * insert() path takes the lock.
 *
 * \param name op name to look up or register
 * \param version version forwarded to insert() when registration is needed
 * \return mutable handle to the found or newly registered op
 */
std::shared_ptr<CustomOp> CustomOpManager::find_or_reg(
        const std::string& name, uint32_t version) {
    const auto hit = m_name2op.find(name);
    if (hit != m_name2op.end()) {
        return std::const_pointer_cast<CustomOp>(hit->second);
    }
    return insert(name, version);
}

/**
 * Translates an op name into its runtime id.
 *
 * \param name name of a registered op; the lookup goes through find(), which
 *        asserts that the op exists
 * \return the runtime id reported by the op
 */
RunTimeId CustomOpManager::to_id(const std::string& name) const {
    return find(name)->runtime_id();
}

std::string CustomOpManager::to_name(const RunTimeId& id) const {
    std::shared_ptr<const CustomOp> op = find(id);
    return op->op_type();
}

/**
 * Looks up a registered op by name.
 *
 * \param name op name to look up
 * \return read-only handle to the op; mgb_assert fires if \p name is not
 *         registered
 */
std::shared_ptr<const CustomOp> CustomOpManager::find(const std::string& name) const {
    const auto ret = m_name2op.find(name);
    mgb_assert(
            ret != m_name2op.end(),
            "Find Custom Op Failed! Op %s has not been registered", name.c_str());
    const std::shared_ptr<const CustomOp>& found_op = ret->second;
    return found_op;
}

/**
 * Looks up a registered op by runtime id.
 *
 * \param id runtime id to look up
 * \return read-only handle to the op; mgb_assert fires if \p id is not
 *         registered
 */
std::shared_ptr<const CustomOp> CustomOpManager::find(const RunTimeId& id) const {
    const auto ret = m_id2op.find(id);
    mgb_assert(
            ret != m_id2op.end(), "Find Custom Op Failed! Op has not been registered");
    const std::shared_ptr<const CustomOp>& found_op = ret->second;
    return found_op;
}

/**
 * Collects the names of all ops currently in the name table.
 *
 * \return a vector holding one copy of every key of m_name2op, in the map's
 *         iteration order
 */
std::vector<std::string> CustomOpManager::op_name_list(void) {
    std::vector<std::string> ret;
    ret.reserve(m_name2op.size());
    // Iterate by const reference: taking each entry by value would copy the
    // key string and the shared_ptr (refcount bump) on every iteration.
    for (const auto& kv : m_name2op) {
        ret.emplace_back(kv.first);
    }
    return ret;
}

std::vector<RunTimeId> CustomOpManager::op_id_list(void) {
    std::vector<RunTimeId> ret;
    for (auto kv : m_id2op) {
        ret.emplace_back(kv.first);
    }
    return ret;
}

/**
 * Opens the shared library at \p path and records which op names appeared in
 * CustomOpManager as a result of loading it.
 *
 * \param path path handed to dlopen()
 * \param mode flag handed to dlopen(); defaults to RTLD_LAZY
 *
 * mgb_assert fires with the dlerror() text if the library cannot be opened.
 * The handle is stored with a deleter that calls dlclose().
 */
CustomLib::CustomLib(const std::string& path, int mode = RTLD_LAZY)
        : m_handle(nullptr, [](void* handle) { dlclose(handle); }) {
    CustomOpManager* const manager = CustomOpManager::inst();

    // Snapshot of the names registered before the library is opened.
    const auto names_before = manager->op_name_list();
    const std::unordered_set<std::string> known_before(
            names_before.begin(), names_before.end());

    m_handle.reset(dlopen(path.c_str(), mode));
    mgb_assert(
            m_handle != nullptr, "open custom op lib failed, error type: %s",
            dlerror());

    // Any name present now but absent from the snapshot is attributed to
    // this library.
    for (const auto& candidate : manager->op_name_list()) {
        if (known_before.count(candidate) == 0) {
            m_ops.emplace_back(candidate);
        }
    }
}

const std::vector<std::string>& CustomLib::ops_in_lib(void) const {
    return m_ops;
}

/**
 * Unregisters every op recorded for this library from CustomOpManager.
 *
 * This body runs before the members are destroyed, i.e. before the m_handle
 * deleter (dlclose) is invoked.
 */
CustomLib::~CustomLib() {
    CustomOpManager* const manager = CustomOpManager::inst();
    for (const auto& op_name : m_ops) {
        manager->erase(op_name);
    }
}

/**
 * \return true when m_handle holds a non-null library handle
 */
bool CustomLib::valid() const {
    return static_cast<bool>(m_handle);
}

/**
 * Accessor for the single LibManager object.
 *
 * \return pointer to a function-local static instance, created on first call
 */
LibManager* LibManager::inst(void) {
    static LibManager instance;
    return &instance;
}

/**
 * Loads the library at \p path and tracks it under \p name.
 *
 * Runs under MGB_LOCK_GUARD(m_mtx).
 *
 * \param name key under which the library is stored in m_custom_libs
 * \param path path forwarded to the CustomLib constructor
 * \return the op list of the library stored under \p name. If \p name was
 *         already present, insert() keeps the existing entry, so the list of
 *         that earlier library is returned and the newly created CustomLib is
 *         released when this function returns.
 */
const std::vector<std::string>& LibManager::install(
        const std::string& name, const std::string& path) {
    MGB_LOCK_GUARD(m_mtx);
    LibHandle handle = std::make_shared<CustomLib>(path);
    // insert() already yields an iterator to the entry stored under `name`
    // (new or pre-existing), so no second lookup through operator[] is needed.
    const auto entry = m_custom_libs.insert({name, handle}).first;
    return entry->second->ops_in_lib();
}

bool LibManager::uninstall(const std::string& name) {
    MGB_LOCK_GUARD(m_mtx);
    ;
    mgb_assert(m_custom_libs.erase(name) == 1, "uninstall error");
    return true;
}

std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version) {
    return CustomOpManager::inst()->insert(opname, version);
}

}  // namespace custom

#endif