#pragma once
#include "megbrain_build_config.h"
#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/collective_comm.h"
#include "megbrain/opr/group_manager.h"
namespace mgb {
namespace opr {
class GroupClientProxy : public std::enable_shared_from_this<GroupClientProxy>,
public opr::GroupClient {
public:
virtual ~GroupClientProxy() = default;
GroupClientProxy(const std::string& server_addr);
GroupManager::RegisterInfo opr_register(
const std::string& key, size_t nr_devices, bool is_root, int rank,
uint64_t comp_node_hash) override;
void bcast_addr(
std::string& master_ip, int& port, const std::string& key, uint32_t size,
uint32_t rank, uint32_t root) override;
void set_output_shape(const std::string& key, const TensorShape& shape) override;
TensorShape get_output_shape(const std::string& key) override;
uint32_t group_barrier(uint32_t size, uint32_t rank) override;
const std::string& get_addr() const override { return m_addr; }
private:
const std::string m_addr;
void* m_stub;
};
//! Declaration only; the definition lives outside this header.
//!
//! \param server_addr address string passed to the implementation
//! \param port port number passed to the implementation
//! \return an int whose meaning is defined by the implementation
//! NOTE(review): going by the name, this presumably starts a ZeroMQ-based RPC
//! server for \p server_addr and returns the port actually in use -- confirm
//! against the definition.
int create_zmqrpc_server(const std::string& server_addr, int port);
} }
#endif