#pragma once
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

#include "megbrain/comp_node.h"
#include "megbrain/opr/group_manager.h"
#include "megray.h"
namespace mgb {
namespace opr {

/*!
 * \brief Convert a megdnn dtype to the MegRay dtype used by the collective
 *        communication layer.
 *
 * NOTE(review): handling of dtypes MegRay cannot represent lives in the
 * implementation file and is not visible here — confirm before relying on it.
 */
MegRay::DType get_megray_dtype(megdnn::DType);

/*!
 * \brief Look up the MegRay backend identified by a textual name.
 * \param backend backend name; the accepted spellings are defined by the
 *        implementation — verify there.
 */
MegRay::Backend get_megray_backend(const std::string& backend);

/*!
 * \brief Return the MegRay context associated with \p comp_node.
 *
 * NOTE(review): whether the context is cached or created per call is not
 * visible from this declaration.
 */
std::shared_ptr<MegRay::Context> get_megray_context(CompNode comp_node);

/*!
 * \brief Cache of MegRay communicators keyed by a 64-bit hash.
 *
 * The only public entry point is the static get_megray_comm(). The builder
 * instance itself is held in the static sm_instance (with sm_instance_mtx), so
 * the type is meant to be used as a single shared instance. The default
 * constructor is therefore private: only member functions (i.e. the
 * implementation of get_megray_comm) may create that instance, which prevents
 * stray instances with their own, separate caches.
 */
class MegRayCommBuilder {
private:
    MegRayCommBuilder() = default;

    // Non-copyable: the std::mutex member already implies this; spelled out
    // to document that the single instance must never be duplicated.
    MegRayCommBuilder(const MegRayCommBuilder&) = delete;
    MegRayCommBuilder& operator=(const MegRayCommBuilder&) = delete;

    //! Look up a cached communicator for \p hash. Presumably returns true and
    //! fills \p comm on a hit — confirm against the implementation.
    bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm);

    //! Store \p comm in the cache under \p hash.
    void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm);

    //! Cached communicators, presumably keyed by the hash given to
    //! get_megray_comm().
    std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>>
            m_megray_comms;

    //! Mutex presumably guarding m_megray_comms — verify in the .cpp.
    std::mutex m_map_mtx;

    //! The shared builder instance used by get_megray_comm().
    static MegRayCommBuilder* sm_instance;

    //! Mutex presumably serializing creation of sm_instance.
    static std::mutex sm_instance_mtx;

public:
    /*!
     * \brief Return a MegRay communicator for the described group, presumably
     *        reusing a cached one when \p hash has been seen before.
     * \param hash         cache key for the communicator
     * \param key          group key, presumably used with \p group_client to
     *                     coordinate with the other members — verify
     * \param size         presumably the number of members in the group
     * \param rank         presumably this member's rank within the group
     * \param backend      MegRay backend to build the communicator with
     * \param group_client client handle passed through to communicator setup
     * \return shared ownership of the communicator
     */
    static std::shared_ptr<MegRay::Communicator> get_megray_comm(
            uint64_t hash, std::string key, uint32_t size, uint32_t rank,
            MegRay::Backend backend,
            std::shared_ptr<mgb::opr::GroupClient> group_client);
};

}  // namespace opr
}  // namespace mgb