megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
from megengine import tensor
from megengine.distributed.functional import (
    all_gather,
    all_to_all,
    gather,
    reduce_scatter_sum,
    scatter,
)
from megengine.jit import trace


@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(2, 3), (8, 10), (99, 77), (2, 2, 2, 2)], ids=str)
@pytest.mark.parametrize("symbolic", [False, True], ids=str)
@pytest.mark.parametrize("axis", [0, 1], ids=str)
@pytest.mark.isolated_distributed
def test_all_gather(shape, symbolic, axis):
    """Check ``all_gather`` on two ranks against ``np.concatenate``.

    Each rank feeds its own random float32 array of ``shape`` into a traced
    ``all_gather`` along ``axis``. Both ranks are expected to end up with the
    two inputs concatenated along that same axis.
    """

    @dist.launcher(n_gpus=2)
    def worker(inputs, expected):
        rank = dist.get_rank()
        local = tensor(inputs[rank])

        @trace(symbolic=symbolic)
        def gathered():
            return all_gather(local, axis=axis)

        assert np.allclose(gathered().numpy(), expected[rank])

    first = np.random.random_sample(shape).astype("float32")
    second = np.random.random_sample(shape).astype("float32")
    combined = np.concatenate((first, second), axis=axis)
    # Every rank sees the same gathered result.
    worker((first, second), (combined, combined))


@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize(
    "shape,symbolic", [((2, 4, 6, 8), False), ((2, 4, 6, 8), True)], ids=str
)
@pytest.mark.parametrize("axis", [1, 0, 2, 3], ids=str)
@pytest.mark.isolated_distributed
def test_reduce_scatter_sum(shape, symbolic, axis):
    """Check ``reduce_scatter_sum`` on two ranks against a NumPy reference.

    The reference is the element-wise sum of both ranks' inputs, split in two
    along ``axis``; rank 0 is compared with the first half and rank 1 with the
    second half.
    """

    @dist.launcher(n_gpus=2)
    def worker(inputs, expected):
        rank = dist.get_rank()
        local = tensor(inputs[rank])

        @trace(symbolic=symbolic)
        def reduced():
            return reduce_scatter_sum(local, axis=axis)

        assert np.allclose(reduced().numpy(), expected[rank])

    first = np.random.random_sample(shape).astype("float32")
    second = np.random.random_sample(shape).astype("float32")
    total = first + second

    # Stack the two halves (taken along ``axis``) on axis 0, then cut the
    # stacked array at its midpoint to recover one chunk per rank.
    stacked = np.concatenate(np.split(total, 2, axis=axis), axis=0)
    midpoint = stacked.shape[0] // 2
    per_rank = (stacked[:midpoint], stacked[midpoint:])

    worker((first, second), per_rank)


@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize(
    "shape,symbolic", [((2, 4, 6, 8), True), ((2, 4, 6, 8), False)], ids=str
)
@pytest.mark.parametrize("axis", [1, 0, 2, 3], ids=str)
@pytest.mark.isolated_distributed
def test_scatter(shape, symbolic, axis):
    """Check ``scatter`` on two ranks against a NumPy split of rank 0's input.

    Rank 0 is given a random float32 array and rank 1 the same array plus one.
    The expected outputs are built only from rank 0's array: its two halves
    along ``axis``, one per rank.
    """

    @dist.launcher(n_gpus=2)
    def worker(inputs, expected):
        rank = dist.get_rank()
        local = tensor(inputs[rank])

        @trace(symbolic=symbolic)
        def scattered():
            return scatter(local, axis=axis)

        assert np.allclose(scattered().numpy(), expected[rank])

    source = np.random.random_sample(shape).astype("float32")
    # Rank 1's input differs on purpose and does not appear in the expectation
    # (presumably because scatter distributes the root rank's data only).
    other = source + 1

    # Stack the two halves (taken along ``axis``) on axis 0, then cut the
    # stacked array at its midpoint to recover one chunk per rank.
    stacked = np.concatenate(np.split(source, 2, axis=axis), axis=0)
    midpoint = stacked.shape[0] // 2
    per_rank = (stacked[:midpoint], stacked[midpoint:])

    worker((source, other), per_rank)


@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(2, 4, 6, 8)], ids=str)
@pytest.mark.parametrize("symbolic", [False, True], ids=str)
@pytest.mark.parametrize(
    "split_axis,concat_axis", [(0, 1), (1, 0), (2, 0), (0, 2), (2, 3)], ids=str
)
@pytest.mark.isolated_distributed
def test_all_to_all(shape, symbolic, split_axis, concat_axis):
    """Check ``all_to_all`` on two ranks by comparing two ``gather`` results.

    No NumPy reference is computed. On rank 0 the test asserts that gathering
    the ``all_to_all`` outputs along ``split_axis`` matches gathering the raw
    inputs along ``concat_axis``. Rank 1 only takes part in the collectives.
    """

    @dist.launcher(n_gpus=2)
    def worker(inputs):
        rank = dist.get_rank()
        local = tensor(inputs[rank])

        @trace(symbolic=symbolic)
        def exchange():
            exchanged = all_to_all(
                local, split_axis=split_axis, concat_axis=concat_axis
            )
            # Both gathers run on every rank; only rank 0 returns their results.
            from_inputs = gather(local, axis=concat_axis)
            from_exchanged = gather(exchanged, axis=split_axis)
            if rank != 0:
                return exchanged
            return from_exchanged, from_inputs

        result = exchange()
        if rank == 0:
            gathered_exchanged, gathered_inputs = result
            assert np.allclose(gathered_exchanged, gathered_inputs)

    first = np.random.random_sample(shape).astype("float32")
    second = np.random.random_sample(shape).astype("float32")
    worker((first, second))