// megenginelite-sys 1.8.2
//
// A safe megenginelite wrapper in Rust
// (source listing copied from the crate's documentation)
/**
 * \file src/opr-mm/test/mock_client.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/group_manager.h"

namespace mgb {
namespace test {

/**
 * \brief Test double for opr::GroupClient that never leaves the process.
 *
 * Each GroupClient request is handed, unchanged, to a privately owned
 * opr::GroupManager instance. The address supplied at construction is only
 * stored and reported back through get_addr(); nothing in this class uses it
 * to open a connection.
 */
class MockGroupClient final : public opr::GroupClient {
public:
    using RegisterInfo = opr::GroupManager::RegisterInfo;

    //! \param server_addr string later returned verbatim by get_addr()
    MockGroupClient(const std::string& server_addr = "mock_addr")
            : m_addr{server_addr} {}

    ~MockGroupClient() override = default;

    //! \return the address this client was constructed with
    auto get_addr() const -> const std::string& { return m_addr; }

    //! Delegates to GroupManager::opr_register and returns its result as-is.
    auto opr_register(
            const std::string& key, size_t nr_devices, bool is_root, int rank,
            uint64_t comp_node_hash) -> RegisterInfo override {
        return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash);
    }

    //! Delegates to GroupManager::bcast_addr. master_ip and port are taken by
    //! non-const reference, so the manager may read and/or write them.
    void bcast_addr(
            std::string& master_ip, int& port, const std::string& key, uint32_t size,
            uint32_t rank, uint32_t root) override {
        m_mgr.bcast_addr(master_ip, port, key, size, rank, root);
    }

    //! Delegates to GroupManager::set_output_shape.
    void set_output_shape(const std::string& key, const TensorShape& shape) override {
        m_mgr.set_output_shape(key, shape);
    }

    //! Delegates to GroupManager::get_output_shape and returns its result.
    auto get_output_shape(const std::string& key) -> TensorShape override {
        return m_mgr.get_output_shape(key);
    }

    //! Delegates to GroupManager::group_barrier and returns its result.
    auto group_barrier(uint32_t size, uint32_t rank) -> uint32_t override {
        return m_mgr.group_barrier(size, rank);
    }

private:
    const std::string m_addr;  //!< value reported by get_addr()
    opr::GroupManager m_mgr;   //!< in-object backend servicing every request
};

}  // namespace test
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}