from collections import OrderedDict
import numpy as np
import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.module import Module
from megengine.traced_module import TracedModule, enable_expr_checker, trace_module
from megengine.traced_module.expr import Apply, CallFunction, Constant
class MyModule1(M.Module):
    """Fixture that builds a new ``Tensor`` from its input inside ``forward``."""

    def forward(self, x):
        """Return ``(x + 2, Tensor(x) + 1)``.

        The second element is produced by constructing a fresh ``Tensor``
        from ``x`` and incrementing it in place; the first is an ordinary
        out-of-place addition on the input itself.
        """
        rebuilt = Tensor(x)
        # Kept as an augmented assignment on purpose: the in-place form is
        # part of what this fixture exercises.
        rebuilt += 1
        shifted = x + 2
        return shifted, rebuilt
class MyModule2(M.Module):
    """Fixture that embeds its input in a list literal passed to ``Tensor``."""

    def forward(self, x):
        """Return ``(x + 2, Tensor([1, x, 1]) + 1)``.

        The second element is a new ``Tensor`` built from a three-item list
        that mixes Python ints with the input, then incremented in place.
        """
        packed = Tensor([1, x, 1])
        # Kept as an augmented assignment on purpose: the in-place form is
        # part of what this fixture exercises.
        packed += 1
        shifted = x + 2
        return shifted, packed
class MyModule3(M.Module):
    """Fixture whose sub-modules live in a plain list that nests an OrderedDict.

    ``test_trace_module`` traces this module and then inspects
    ``tm3.modules.__dict__`` by the string keys ``"0"``, ``"2"`` and ``"3"``.
    """

    def __init__(self):
        super().__init__()
        # Index layout relied on by ``forward``:
        #   0, 1 -> binary ADD
        #   2    -> OrderedDict with keys "a", "b" (both binary ADD)
        #   3, 4 -> unary RELU
        self.modules = [
            M.Elemwise("ADD"),
            M.Elemwise("ADD"),
            OrderedDict([("a", M.Elemwise("ADD")), ("b", M.Elemwise("ADD"))]),
            M.Elemwise("RELU"),
            M.Elemwise("RELU"),
        ]

    def forward(self, a, b):
        """Combine ``a`` and ``b`` through every stored sub-module in order."""
        x = self.modules[0](a, b)
        y = self.modules[1](a, b)
        # The dict entry must still expose its keys in insertion order.
        assert list(self.modules[2].keys()) == ["a", "b"]
        # Fold ``y`` through each dict-held ADD, always adding the same ``x``.
        for _, m in self.modules[2].items():
            y = m(x, y)
        # Everything from index 3 onwards is applied as a unary module.
        for m in self.modules[3:]:
            y = m(y)
        return y
class MyModule4(M.Module):
    """Fixture that stores the functional ``F.add`` as an instance attribute."""

    def __init__(self):
        super().__init__()
        # A plain function (not a Module) held on the instance.
        self.add = F.add

    def forward(self, x, y):
        """Return ``self.add(x, y)``."""
        total = self.add(x, y)
        return total
def test_trace_module():
    """Trace the fixture modules and compare traced results with eager ones.

    Covers, in order: the two single-input fixtures, the container fixture
    (``MyModule3``), tracing ``MyModule4`` with positional / mixed / keyword
    example arguments, re-tracing an already traced module, and nesting a
    traced module inside another module before tracing again.
    """
    enable_expr_checker()

    def assert_same_outputs(actual, expected):
        # Element-wise comparison of two equally ordered output tuples.
        for got, want in zip(actual, expected):
            np.testing.assert_equal(got.numpy(), want.numpy())

    # Trace with one sample value, then evaluate with a different one.
    sample = Tensor(1)
    m1 = MyModule1()
    tm1 = trace_module(m1, sample)
    m2 = MyModule2()
    tm2 = trace_module(m2, sample)

    inp = Tensor(2)
    expected1 = m1(inp)
    assert_same_outputs(tm1(inp), expected1)
    expected2 = m2(inp)
    assert_same_outputs(tm2(inp), expected2)

    a, b = Tensor(1), Tensor(2)

    # Container fixture: eager result first, then trace and compare.
    m3 = MyModule3()
    expected3 = m3(a, b)
    tm3 = trace_module(m3, a, b)
    np.testing.assert_equal(tm3(a, b).numpy(), expected3.numpy())

    traced_children = tm3.modules.__dict__
    assert isinstance(traced_children["0"], M.Elemwise)
    assert isinstance(traced_children["2"], TracedModule)
    assert isinstance(traced_children["2"].a, M.Elemwise)
    assert isinstance(traced_children["3"], M.Elemwise)

    def assert_adds_to_three(traced):
        # Whatever argument style was used for tracing, every call style
        # must be accepted afterwards and yield 1 + 2.
        np.testing.assert_equal(traced(a, b).numpy(), 3)
        np.testing.assert_equal(traced(a, y=b).numpy(), 3)
        np.testing.assert_equal(traced(x=a, y=b).numpy(), 3)

    # (positional, keyword) example arguments handed to ``trace_module``.
    trace_arg_styles = (
        ((a, b), {}),
        ((a,), {"y": b}),
        ((), {"x": a, "y": b}),
    )

    m4 = MyModule4()
    for args, kwargs in trace_arg_styles:
        tm4 = trace_module(m4, *args, **kwargs)
        assert_adds_to_three(tm4)

    # Re-trace the last ``tm4`` (the keyword-only trace) with each style.
    for args, kwargs in trace_arg_styles:
        tm5 = trace_module(tm4, *args, **kwargs)
        assert_adds_to_three(tm5)

    assert len(tm4.graph._exprs) == 1
    assert isinstance(tm4.graph._exprs[0], CallFunction)

    class MyModule5(Module):
        """Wrapper that holds the traced ``tm4`` as a child module."""

        def __init__(self):
            super().__init__()
            self.m1 = tm4

        def forward(self, x, y):
            return self.m1(x, y)

    tm6 = trace_module(MyModule5(), a, b)
    assert tm6.m1.argspec is None
    assert tm6.m1._is_top is False
def test_trace_module_2():
    """Trace a module that feeds ``x.shape`` into a raw ``apply`` call.

    Checks the kinds of the first three recorded expressions and that the
    traced module returns 3 as its first output for a two-element input.
    """

    class Model(M.Module):
        """Adds 1 to the input's shape through a builtin Elemwise op."""

        def __init__(self):
            super().__init__()

        def forward(self, x):
            shape = x.shape
            return apply(builtin.Elemwise(mode="ADD"), shape, Tensor(1))

    traced_model = trace_module(Model(), Tensor([1]))
    exprs = traced_model.graph._exprs

    # Expression 0: the shape query.
    assert isinstance(exprs[0], Apply)
    assert isinstance(exprs[0].opdef, builtin.GetVarShape)

    # Expression 1: a constant.
    assert isinstance(exprs[1], Constant)

    # Expression 2: the element-wise op passed to ``apply``.
    assert isinstance(exprs[2], Apply)
    assert isinstance(exprs[2].opdef, builtin.Elemwise)

    outputs = traced_model(Tensor([1, 2]))
    assert int(outputs[0]) == 3