megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple

from mprop import mproperty

from ..device import _sh, set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server


class StaticData:
    r"""Plain namespace holding the state of the distributed process group.

    Every field starts out as ``None`` at class level. ``init_process_group``
    fills most of them in on an instance, and ``_set_machine_ranks`` sets
    ``machine_ranks``; ``server`` is not assigned by ``init_process_group``.
    """

    # RPC endpoints.
    server = client = None
    # Master address plus the python RPC server port and the mm_server port.
    master_ip = py_server_port = mm_server_port = None
    # Job layout: number of processes and the rank of this one.
    world_size = proc_rank = None
    # Device binding and communicator backend.
    device = backend = device_type = None
    # Ranks that live on this machine (consulted by Group.is_single_machine).
    machine_ranks = None


# Module-wide process-group state: ``None`` until init_process_group() stores a
# StaticData instance here. Most accessors below assert that it has been set.
_sd: Optional[StaticData] = None


class Group:
    r"""Include ranked nodes running collective communication (See :mod:`~.functional.distributed`).

    By default collectives operate on the default group (also called ``WORLD``)
    and require all processes to enter the distributed function call.

    An empty rank list creates a placeholder group; this is how ``WORLD``
    starts out before ``init_process_group`` fills it in through
    :meth:`reset`. Until such a group is reset with a non-empty rank list, its
    properties fail with an ``AssertionError``.

    Args:
        proc_ranks: rank list of the group, the first one is root rank.
    """

    def __init__(self, proc_ranks):
        if len(proc_ranks) == 0:  # empty placeholder group
            self.proc_ranks = None
            self.stream = None
            # Initialise the cache as well, so ``is_single_machine`` fails with a
            # clear assertion instead of an AttributeError.
            self.is_single_machine_cache = None
        else:
            self.reset(proc_ranks)

    def reset(self, proc_ranks):
        r"""(Re)bind the group to ``proc_ranks``.

        Validates the ranks, clears the cached single-machine flag and takes a
        fresh stream number from ``_sh.get_next()``.
        """
        self.check(proc_ranks)
        self.proc_ranks = proc_ranks
        self.is_single_machine_cache = None
        self.stream = _sh.get_next()

    def check(self, proc_ranks):
        r"""Assert that the process group is initialized, every rank is an int in
        ``[0, world_size)`` and the current process is a member of ``proc_ranks``."""
        assert _sd is not None, "please call init_process_group first"
        for rank in proc_ranks:
            assert isinstance(rank, int)
            assert rank >= 0 and rank < _sd.world_size
        assert _sd.proc_rank in proc_ranks

    def _assert_valid(self):
        # ``proc_ranks`` is None for an empty placeholder group, so test its
        # truthiness rather than calling len() on it.
        assert self.proc_ranks, "invalid group"

    @property
    def size(self):
        r"""Number of ranks in the group."""
        self._assert_valid()
        return len(self.proc_ranks)

    @property
    def key(self):
        r"""Comma-separated rank list identifying the group, e.g. ``"0,1,2"``."""
        self._assert_valid()
        return ",".join(map(str, self.proc_ranks))

    @property
    def rank(self):
        r"""Position of the current process inside ``proc_ranks``."""
        self._assert_valid()
        return self.proc_ranks.index(_sd.proc_rank)

    @property
    def comp_node(self):
        r"""Comp-node string ``"<device_type><device>:<stream>"`` used by this group."""
        self._assert_valid()
        return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)

    @property
    def is_single_machine(self):
        r"""Whether every rank of the group is listed in the machine ranks
        registered through ``_set_machine_ranks``."""
        if self.is_single_machine_cache is not None:
            return self.is_single_machine_cache
        assert _sd is not None, "please call init_process_group first"
        self._assert_valid()
        machine_ranks = _sd.machine_ranks
        if machine_ranks is None:
            # Machine layout not registered yet: answer False but do not cache
            # it, so a later _set_machine_ranks call is honoured.
            return False
        result = all(rank in machine_ranks for rank in self.proc_ranks)
        self.is_single_machine_cache = result
        return result


# Default group spanning every process. It is created empty here and populated
# by init_process_group() through WORLD.reset().
WORLD = Group([])

# Physical device types accepted by init_process_group() (after "xpu" is resolved).
_devices = {"gpu", "cuda", "rocm"}
# Communicator backends accepted by init_process_group().
_backends = {"nccl", "rccl", "shm", "auto"}


def init_process_group(
    master_ip: str,
    port: int,
    world_size: int,
    rank: int,
    device: int,
    backend: Optional[str] = "auto",
    device_type: str = "xpu",
) -> None:
    r"""Initialize the distributed process group and specify the device used in the current process

    The module-level state is published only after argument validation and the
    connection to the master server have succeeded. A failed call therefore does
    not leave a half-initialised process group behind.

    Args:
        master_ip: ip address of the master node.
        port: port available for all processes to communicate.
        world_size: total number of processes participating in the job.
        rank: rank of the current process.
        device: the GPU device id to bind this process to.
        backend: communicator backend, one of 'nccl', 'rccl', 'shm' or 'auto'.
        device_type: device type prefix for the default device. ``"xpu"`` is
            resolved with ``what_is_xpu()`` for validation; the resolved type
            must be one of 'gpu', 'cuda' or 'rocm'. The name as given is what
            is used to build the default device string.

    Raises:
        TypeError: if an argument has the wrong type.
        ValueError: if ``backend`` or the resolved device type is unsupported.
        AssertionError: if called more than once, or if ``world_size``,
            ``rank`` or ``port`` are out of range.
    """
    physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
    if not isinstance(master_ip, str):
        raise TypeError("Expect type str but got {}".format(type(master_ip)))
    if not isinstance(port, int):
        raise TypeError("Expect type int but got {}".format(type(port)))
    if not isinstance(world_size, int):
        raise TypeError("Expect type int but got {}".format(type(world_size)))
    if not isinstance(rank, int):
        raise TypeError("Expect type int but got {}".format(type(rank)))
    if not isinstance(device, int):
        raise TypeError("Expect type int but got {}".format(type(device)))
    if backend not in _backends:
        raise ValueError(
            "backend should be one of {} but got {}".format(_backends, backend)
        )
    if physical_device_type not in _devices:
        raise ValueError(
            "{} is not a valid distributed device type".format(device_type)
        )

    global _sd
    assert _sd is None, "init_process_group should be called only once"
    # Validate before touching global state so that a failed call can be retried.
    assert world_size > 1
    assert rank >= 0 and rank < world_size
    assert port > 0

    # Build the state locally and publish it only once the client has connected
    # and reported the mm_server port. Otherwise a failure here would leave a
    # partially filled ``_sd``, and every retry would hit the "called only once"
    # assertion.
    client = Client(master_ip, port)
    sd = StaticData()
    sd.client = client
    sd.master_ip = master_ip
    sd.py_server_port = port
    sd.mm_server_port = client.get_mm_server_port()
    sd.world_size = world_size
    sd.proc_rank = rank
    sd.device = device
    sd.backend = backend
    sd.device_type = device_type
    _sd = sd

    # Group.check (called by reset) reads the now-published _sd.
    WORLD.reset(list(range(world_size)))

    set_default_device("{}{}".format(device_type, device))
    # Offset by rank so processes started in the same second get distinct seeds.
    seed(int(time.time()) + rank)


def _set_machine_ranks(ranks) -> None:
    r"""Record the ranks that share the current machine.

    The stored value is consulted by ``Group.is_single_machine``. The process
    group must already be initialized.
    """
    # Only an attribute of _sd is mutated, so no ``global`` declaration is needed.
    state = _sd
    assert state is not None
    state.machine_ranks = ranks


@contextmanager
def override_backend(new_backend: str):
    r"""Temporarily switch the distributed backend.

    The previous backend is put back when the ``with`` block exits, whether it
    exits normally or through an exception.

    Args:
        new_backend: communicator backend set in this context.
    """
    assert _sd, "please call init_process_group first"
    saved_backend, _sd.backend = _sd.backend, new_backend
    try:
        yield
    finally:
        _sd.backend = saved_backend


def is_distributed() -> bool:
    r"""Tell whether ``init_process_group`` has already set up the process group."""
    if _sd is None:
        return False
    return True


def get_rank() -> int:
    r"""Return this process's rank, or ``0`` when no process group exists."""
    if _sd is None:
        return 0
    return _sd.proc_rank


def get_world_size() -> int:
    r"""Return the number of processes in the job, or ``1`` when not distributed."""
    if _sd is None:
        return 1
    return _sd.world_size


def get_backend() -> str:
    r"""Return the backend name of the initialized process group."""
    assert _sd is not None, "please call init_process_group first"
    if _sd is None:
        # Only reachable when assertions are stripped (``python -O``).
        return None
    return _sd.backend


def get_py_server_addr() -> Tuple[str, int]:
    r"""Return ``(master_ip, port)`` of the python XML RPC server."""
    assert _sd is not None, "please call init_process_group first"
    ip, port = _sd.master_ip, _sd.py_server_port
    return ip, port


def get_mm_server_addr() -> Tuple[str, int]:
    r"""Return ``(master_ip, port)`` of the C++ mm_server."""
    assert _sd is not None, "please call init_process_group first"
    ip, port = _sd.master_ip, _sd.mm_server_port
    return ip, port


def get_client() -> Client:
    r"""Return the client connected to the python XML RPC server."""
    assert _sd is not None, "please call init_process_group first"
    client = _sd.client
    return client


def new_group(proc_ranks: List[int]) -> Group:
    r"""Build a :class:`Group` over ``proc_ranks``; the first rank is the root."""
    group = Group(proc_ranks)
    return group


def group_barrier(group: Group = WORLD) -> None:
    r"""Block until every rank of ``group`` has reached this barrier.

    Returns immediately when no process group has been initialized, i.e. when
    running as a single process.
    """
    if _sd is not None:
        assert isinstance(group, Group)
        _sd.client.group_barrier(group.key, group.size)