# megenginelite-sys 1.8.2
#
# A safe megenginelite wrapper in Rust
# Documentation
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import types

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.traced_module as tm


class myconv(M.Conv2d):
    """User-defined subclass of ``M.Conv2d`` that adds nothing of its own.

    Used as a ``conv_cls`` parameter value in the tests of this file.
    """


class mybn(M.BatchNorm2d):
    """User-defined subclass of ``M.BatchNorm2d`` that adds nothing of its own.

    Used as a ``bn_cls`` parameter value in the tests of this file.
    """


class MyBlock(M.Module):
    """Two parallel conv -> bn -> relu branches with constant scaling.

    Each branch output is multiplied by one element of ``self.scale``; the
    branches are summed, shifted by constants, passed through relu and
    multiplied by 3.

    Args:
        conv_cls: class used to build both convolutions; called as
            ``conv_cls(3, 3, 1, 1, 0)``.
        bn_cls: class used to build both batch-norm layers; called as
            ``bn_cls(3)``.
    """

    def __init__(self, conv_cls, bn_cls):
        super().__init__()
        # First branch.
        self.conv = conv_cls(3, 3, 1, 1, 0)
        self.bn = bn_cls(3)
        # Second branch.
        self.conv2 = conv_cls(3, 3, 1, 1, 0)
        self.bn2 = bn_cls(3)
        # scale[0] is used by the first branch (and added once more after the
        # merge); scale[1] is used by the second branch.
        self.scale = mge.Tensor([3, 4])

    def forward(self, x):
        """Apply both branches to ``x`` and return the merged result."""
        # NOTE: plain binary operators are used on purpose; the tests look up
        # "__mul__" calls in the traced graph, so operand order and operator
        # form are kept explicit.

        # First branch: conv -> bn -> relu, scaled by scale[0].
        first = self.conv(x)
        first = self.bn(first)
        first = F.relu(first)
        first = first * self.scale[0]

        # Second branch: conv2 -> bn2 -> relu, scaled by scale[1].
        second = self.conv2(x)
        second = self.bn2(second)
        second = F.relu(second)
        second = second * self.scale[1]

        # Merge, add the constant 4 and then scale[0], relu, multiply by 3.
        merged = first + second
        merged = merged + 4
        merged = self.scale[0] + merged
        merged = F.relu(merged) * 3
        return merged


class MyModule(M.Module):
    """Two ``MyBlock`` instances fed the same input, summed and rescaled.

    Args:
        conv_cls: convolution class forwarded to both ``MyBlock`` instances.
        bn_cls: batch-norm class forwarded to both ``MyBlock`` instances.
    """

    def __init__(self, conv_cls, bn_cls):
        super().__init__()
        self.block_0 = MyBlock(conv_cls, bn_cls)
        self.block_1 = MyBlock(conv_cls, bn_cls)

    def forward(self, x):
        """Run both blocks on ``x``, add, reshape with -1 and multiply by 3."""
        out_0 = self.block_0(x)
        out_1 = self.block_1(x)
        total = out_0 + out_1
        # The target shape is the bare int -1 (parentheses alone would not
        # turn it into a tuple).
        total = F.reshape(total, -1)
        total = total * 3
        return total


@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
def test_backward_fold_scale(conv_cls, bn_cls):
    """Check the "BackwardFoldScale" pass on a traced ``MyModule``.

    The optimized network must reproduce the eager output (atol=1e-4) and its
    graph must contain no ``__mul__`` method calls afterwards.
    """
    module = MyModule(conv_cls, bn_cls)
    module.eval()

    inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
    expected = module(inp)

    # Trace, flatten, then run the pass under test.
    flat_net = tm.trace_module(module, inp).flatten()
    optimized_net = tm.optimize(flat_net, "BackwardFoldScale")

    result = optimized_net(inp)
    np.testing.assert_allclose(result, expected, atol=1e-4)

    # Every multiplication is expected to have been folded away.
    remaining_muls = optimized_net.graph.get_method_by_type("__mul__").as_list()
    assert len(remaining_muls) == 0


@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
def test_fuse_bn(conv_cls, bn_cls):
    """Check the "FuseConvBn" pass on a traced ``MyModule``.

    The optimized network must reproduce the eager output (atol=1e-4) and its
    graph must contain neither ``F.batch_norm`` function calls nor
    ``M.BatchNorm2d`` modules afterwards.
    """
    module = MyModule(conv_cls, bn_cls)
    module.eval()

    inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
    expected = module(inp)

    # Trace, flatten, then run the pass under test.
    flat_net = tm.trace_module(module, inp).flatten()
    optimized_net = tm.optimize(flat_net, "FuseConvBn")

    result = optimized_net(inp)
    np.testing.assert_allclose(result, expected, atol=1e-4)

    # No functional batch-norm call may remain in the graph ...
    functional_bns = optimized_net.graph.get_function_by_type(F.batch_norm).as_list()
    assert len(functional_bns) == 0

    # ... and no batch-norm module either.
    module_bns = optimized_net.graph.get_module_by_type(M.BatchNorm2d).as_list()
    assert len(module_bns) == 0