# megenginelite-sys 1.8.2
#
# A safe megenginelite wrapper in Rust
# Documentation
from collections import OrderedDict

import numpy as np

import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.module import Module
from megengine.traced_module import TracedModule, enable_expr_checker, trace_module
from megengine.traced_module.expr import Apply, CallFunction, Constant


class MyModule1(M.Module):
    """Module that copies its input into a fresh ``Tensor`` and bumps the copy in place.

    ``forward`` returns the pair ``(x + 2, copy)``, where ``copy`` starts as
    ``Tensor(x)`` and then receives an in-place ``+= 1``.
    """

    def forward(self, x):
        fresh = Tensor(x)
        fresh += 1
        shifted = x + 2
        return shifted, fresh


class MyModule2(M.Module):
    """Module that embeds its input inside a Python list handed to ``Tensor``.

    ``forward`` returns the pair ``(x + 2, packed)``, where ``packed`` starts
    as ``Tensor([1, x, 1])`` and then receives an in-place ``+= 1``.
    """

    def forward(self, x):
        packed = Tensor([1, x, 1])
        packed += 1
        shifted = x + 2
        return shifted, packed


class MyModule3(M.Module):
    """Module that keeps its submodules in a plain list containing a nested OrderedDict.

    Used to exercise ``trace_module`` on containers of modules. The traced
    copy exposes the list entries as attributes named "0", "1", ... (see the
    assertions on ``tm3.modules`` in ``test_trace_module``).
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): the name ``modules`` appears to shadow the base-class
        # ``Module.modules()`` method; presumably intentional so that a
        # list-valued attribute with this name gets traced — confirm.
        self.modules = [
            M.Elemwise("ADD"),
            M.Elemwise("ADD"),
            OrderedDict([("a", M.Elemwise("ADD")), ("b", M.Elemwise("ADD"))]),
            M.Elemwise("RELU"),
            M.Elemwise("RELU"),
        ]

    def forward(self, a, b):
        # Two separate ADD modules applied to the same pair of inputs.
        x = self.modules[0](a, b)
        y = self.modules[1](a, b)
        # The dict must keep insertion order, so "a" is applied before "b".
        assert list(self.modules[2].keys()) == ["a", "b"]
        for _, m in self.modules[2].items():
            y = m(x, y)
        # The remaining entries are single-input RELU modules applied in sequence.
        for m in self.modules[3:]:
            y = m(y)
        return y


class MyModule4(M.Module):
    """Module that calls the functional ``F.add`` through an instance attribute.

    ``test_trace_module`` traces this with positional and keyword inputs and
    expects the traced graph to contain a single ``CallFunction`` expr.
    """

    def __init__(self):
        super().__init__()
        # A plain function (not a submodule) stored on the instance; presumably
        # to check that a function reached via ``self`` is traced as a call.
        self.add = F.add

    def forward(self, x, y):
        return self.add(x, y)


def _assert_outputs_equal(actual, expected):
    """Assert that two output sequences match element-wise.

    ``zip`` silently stops at the shorter sequence, so the lengths are
    compared first; otherwise a traced module that drops an output would
    still pass.
    """
    actual, expected = list(actual), list(expected)
    assert len(actual) == len(expected), (len(actual), len(expected))
    for got, want in zip(actual, expected):
        np.testing.assert_equal(got.numpy(), want.numpy())


def _assert_call_forms(module, x, y, expected):
    """Assert ``module`` returns ``expected`` however (x, y) are passed.

    Covers a positional call, a mixed positional/keyword call and a
    keyword-only call, using the parameter names ``x`` and ``y``.
    """
    np.testing.assert_equal(module(x, y).numpy(), expected)
    np.testing.assert_equal(module(x, y=y).numpy(), expected)
    np.testing.assert_equal(module(x=x, y=y).numpy(), expected)


def test_trace_module():
    """Trace several toy modules and check the traced copies match eager results."""
    enable_expr_checker()
    x = Tensor(1)
    m1 = MyModule1()
    tm1 = trace_module(m1, x)

    m2 = MyModule2()
    tm2 = trace_module(m2, x)
    inp = Tensor(2)
    gt1 = m1(inp)
    _assert_outputs_equal(tm1(inp), gt1)

    gt2 = m2(inp)
    _assert_outputs_equal(tm2(inp), gt2)

    # Containers of modules (list with a nested OrderedDict).
    a, b = Tensor(1), Tensor(2)
    m3 = MyModule3()
    gt3 = m3(a, b)
    tm3 = trace_module(m3, a, b)
    out = tm3(a, b)
    np.testing.assert_equal(out.numpy(), gt3.numpy())
    assert isinstance(tm3.modules.__dict__["0"], M.Elemwise)
    assert isinstance(tm3.modules.__dict__["2"], TracedModule)
    assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)
    assert isinstance(tm3.modules.__dict__["3"], M.Elemwise)

    # (args, kwargs) used for tracing: positional, mixed and keyword-only.
    trace_inputs = (
        ((a, b), {}),
        ((a,), {"y": b}),
        ((), {"x": a, "y": b}),
    )

    m4 = MyModule4()
    for args, kwargs in trace_inputs:
        tm4 = trace_module(m4, *args, **kwargs)
        _assert_call_forms(tm4, a, b, 3)

    # Re-trace the last (keyword-traced) tm4 with every input form as well.
    for args, kwargs in trace_inputs:
        tm5 = trace_module(tm4, *args, **kwargs)
        _assert_call_forms(tm5, a, b, 3)

    # After being re-traced, tm4's graph still holds a single CallFunction expr.
    assert len(tm4.graph._exprs) == 1
    assert isinstance(tm4.graph._exprs[0], CallFunction)

    class MyModule5(Module):
        def __init__(self):
            super().__init__()
            self.m1 = tm4

        def forward(self, x, y):
            return self.m1(x, y)

    # A traced module nested inside another traced module is not top-level.
    tm6 = trace_module(MyModule5(), a, b)
    assert tm6.m1.argspec is None
    assert tm6.m1._is_top is False


def test_trace_module_2():
    """Trace a raw ``apply`` on a tensor's shape and check the recorded exprs."""

    class Model(M.Module):
        def forward(self, x):
            out = x.shape
            out = apply(builtin.Elemwise(mode="ADD"), out, Tensor(1))
            return out

    traced_model = trace_module(Model(), Tensor([1]))
    exprs = traced_model.graph._exprs

    # Check each property separately so a failure names the exact mismatch.
    assert isinstance(exprs[0], Apply)
    assert isinstance(exprs[0].opdef, builtin.GetVarShape)
    assert isinstance(exprs[1], Constant)
    assert isinstance(exprs[2], Apply)
    assert isinstance(exprs[2].opdef, builtin.Elemwise)
    # The shape of a length-2 tensor is (2,); adding 1 gives 3.
    assert int(traced_model(Tensor([1, 2]))[0]) == 3