from collections import OrderedDict
from io import BytesIO
import numpy as np
import pytest
import megengine as mge
import megengine.functional as F
from megengine import Parameter, Tensor, tensor
from megengine.module import (
BatchNorm1d,
BatchNorm2d,
Conv1d,
Conv2d,
Dropout,
Linear,
MaxPool2d,
Module,
Sequential,
Softmax,
)
from megengine.module.module import _access_structure
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.traced_module import TracedModule, trace_module
from megengine.utils.module_utils import get_expand_structure, set_expand_structure
class MLP(Module):
    """Small two-layer perceptron used as a fixture by the save/load tests.

    Layout: ``Linear(28, 50)`` -> ReLU -> ``Linear(50, 20)``.
    """

    def __init__(self):
        super().__init__()
        self.dense0 = Linear(28, 50)
        self.dense1 = Linear(50, 20)

    def forward(self, x):
        """Run ``x`` through both dense layers with a ReLU in between."""
        hidden = F.relu(self.dense0(x))
        return self.dense1(hidden)
class MyModule(Module):
    """Fixture module with a nested submodule, a batch norm, a parameter and a buffer."""

    class InnerModule(Module):
        """Submodule wrapping a single ``BatchNorm2d(4)``."""

        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(4)

        def forward(self, x):
            """Return the batch-norm output for ``x``."""
            normalized = self.bn(x)
            return normalized

    def __init__(self):
        super().__init__()
        self.i = self.InnerModule()
        self.bn = BatchNorm2d(4)
        # One standalone Parameter and one plain Tensor attribute, both
        # holding a single float32 one.
        self.param = Parameter(np.ones(1, dtype=np.float32))
        self.buff = Tensor(np.ones(1, dtype=np.float32))

    def forward(self, x):
        """Apply the inner module first, then this module's own batch norm."""
        return self.bn(self.i(x))
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api(test_traced_module):
    """Check the traversal and train/eval APIs on ``MyModule``.

    The same assertions run against the plain module and, when
    ``test_traced_module`` is true, against the result of ``trace_module``.
    """
    m = MyModule()
    if test_traced_module:
        # Hold on to the standalone tensors: the asserts below expect them to
        # be missing from the traced module's __dict__, so they are attached
        # again before the traversal checks.
        saved_buff, saved_param = m.buff, m.param
        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
        assert "buff" not in m.__dict__
        assert "param" not in m.__dict__
        m.buff = saved_buff
        m.param = saved_param

    # ---- submodule traversal ----
    assert list(m.children()) == [m.bn, m.i]
    assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)]

    expected_modules = [m, m.bn, m.i, m.i.bn]
    assert list(m.modules()) == expected_modules

    expected_named_modules = [
        ("", m),
        ("bn", m.bn),
        ("i", m.i),
        ("i.bn", m.i.bn),
    ]
    assert list(m.named_modules()) == expected_named_modules

    expected_prefixed_modules = [
        ("x", m),
        ("x.bn", m.bn),
        ("x.i", m.i),
        ("x.i.bn", m.i.bn),
    ]
    assert list(m.named_modules(prefix="x")) == expected_prefixed_modules

    # ---- buffers ----
    expected_buffers = [
        m.bn.running_mean,
        m.bn.running_var,
        m.buff,
        m.i.bn.running_mean,
        m.i.bn.running_var,
    ]
    assert list(m.buffers()) == expected_buffers
    assert list(m.buffers(recursive=False)) == [m.buff]

    expected_named_buffers = [
        ("bn.running_mean", m.bn.running_mean),
        ("bn.running_var", m.bn.running_var),
        ("buff", m.buff),
        ("i.bn.running_mean", m.i.bn.running_mean),
        ("i.bn.running_var", m.i.bn.running_var),
    ]
    assert list(m.named_buffers()) == expected_named_buffers

    # ---- parameters ----
    expected_parameters = [
        m.bn.bias,
        m.bn.weight,
        m.i.bn.bias,
        m.i.bn.weight,
        m.param,
    ]
    assert list(m.parameters()) == expected_parameters

    expected_named_parameters = [
        ("bn.bias", m.bn.bias),
        ("bn.weight", m.bn.weight),
        ("i.bn.bias", m.i.bn.bias),
        ("i.bn.weight", m.i.bn.weight),
        ("param", m.param),
    ]
    assert list(m.named_parameters()) == expected_named_parameters

    # ---- all tensors (parameters and buffers together) ----
    expected_tensors = [
        m.bn.bias,
        m.bn.running_mean,
        m.bn.running_var,
        m.bn.weight,
        m.buff,
        m.i.bn.bias,
        m.i.bn.running_mean,
        m.i.bn.running_var,
        m.i.bn.weight,
        m.param,
    ]
    assert list(m.tensors()) == expected_tensors

    expected_named_tensors = [
        ("bn.bias", m.bn.bias),
        ("bn.running_mean", m.bn.running_mean),
        ("bn.running_var", m.bn.running_var),
        ("bn.weight", m.bn.weight),
        ("buff", m.buff),
        ("i.bn.bias", m.i.bn.bias),
        ("i.bn.running_mean", m.i.bn.running_mean),
        ("i.bn.running_var", m.i.bn.running_var),
        ("i.bn.weight", m.i.bn.weight),
        ("param", m.param),
    ]
    assert list(m.named_tensors()) == expected_named_tensors

    # ---- train / eval switching ----
    # eval() on the root is expected to reach every submodule.
    m.eval()
    assert (m.training, m.bn.training, m.i.training, m.i.bn.training) == (
        False,
        False,
        False,
        False,
    )

    # train() on one child must leave the root and its sibling untouched.
    m.bn.train()
    assert (m.training, m.bn.training, m.i.bn.training) == (False, True, False)

    # train() on the nested module is expected to reach its own child too.
    m.eval()
    m.i.train()
    assert (m.training, m.bn.training, m.i.training, m.i.bn.training) == (
        False,
        False,
        True,
        True,
    )

    m.eval()
    m.train()
    assert (m.training, m.bn.training, m.i.bn.training) == (True, True, True)

    # ---- apply() ----
    def disable_training(module):
        module.training = False

    m.apply(disable_training)
    assert (m.bn.training, m.i.bn.training) == (False, False)
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api_reuse_submodule(test_traced_module):
    """Check that a submodule reachable under two names is yielded only once.

    ``m.i`` is also bound as ``m.h``; the assertions expect ``modules()`` to
    list it a single time and ``named_modules()`` to report it under ``"h"``.
    """
    m = MyModule()
    if test_traced_module:
        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))

    # Alias the inner module under a second attribute name. This must be its
    # own statement, separate from the assertion that follows.
    m.h = m.i
    assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
    assert list(m.named_modules()) == [
        ("", m),
        ("bn", m.bn),
        ("h", m.i),
        ("h.bn", m.i.bn),
    ]
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api_iterable_stability(test_traced_module):
    """Check that ``modules()`` yields the same sequence on repeated calls."""
    m = MyModule()
    if test_traced_module:
        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))

    reference = list(m.modules())
    # Re-enumerate 100 times; every pass must equal the first snapshot.
    for _attempt in range(100):
        assert list(m.modules()) == reference
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api_hooks(test_traced_module):
    """Check forward pre-hooks and forward hooks registered on every module.

    Each hook adds 1 to the tensors passing through it and counts its calls.
    The output is compared with a hand-built ``F.batch_norm`` reference, and
    the hooks are then removed to confirm they no longer fire.
    """
    net = MyModule()
    if test_traced_module:
        net = trace_module(net, Tensor(np.zeros((1, 4, 1, 1))))

    pre_hook_num = 0
    post_hook_num = 0
    hooks = []

    def pre_hook(_, inputs):
        """Count the call and return every input shifted by +1."""
        nonlocal pre_hook_num
        pre_hook_num += 1
        return tuple(inp + 1 for inp in inputs)

    def post_hook(_, __, outputs):
        """Count the call and shift the outputs by +1 in place."""
        nonlocal post_hook_num
        post_hook_num += 1
        outputs += 1
        return outputs

    def attach_pre_hook(module):
        hooks.append(module.register_forward_pre_hook(pre_hook))

    def attach_post_hook(module):
        hooks.append(module.register_forward_hook(post_hook))

    # Two separate passes: all pre-hooks are registered first, then all
    # forward hooks.
    net.apply(attach_pre_hook)
    net.apply(attach_post_hook)

    shape = (1, 4, 1, 1)
    x = tensor(np.zeros(shape, dtype=np.float32))
    y = net(x)
    assert pre_hook_num == 4
    assert post_hook_num == 4

    # Reference computation. The +3 / +3 / +2 offsets appear to mirror the
    # number of hooks crossed before each batch norm and after the last one.
    # NOTE(review): confirm against the hook call order.
    mean1 = Parameter(np.zeros(shape), dtype=np.float32)
    var1 = Parameter(np.ones(shape), dtype=np.float32)
    bn1 = F.batch_norm(x + 3, mean1, var1, training=True)
    np.testing.assert_allclose(net.i.bn.running_mean.numpy(), mean1.numpy())

    mean2 = Parameter(np.zeros(shape), dtype=np.float32)
    var2 = Parameter(np.ones(shape), dtype=np.float32)
    bn2 = F.batch_norm(bn1 + 3, mean2, var2, training=True)
    np.testing.assert_allclose(net.bn.running_mean.numpy(), mean2.numpy())

    np.testing.assert_allclose((bn2 + 2).numpy(), y.numpy())

    # With every handle removed, a second forward pass must leave both
    # counters unchanged.
    assert len(hooks) == 8
    for handle in hooks:
        handle.remove()
    y = net(x)
    assert pre_hook_num == 4
    assert post_hook_num == 4
class MyModule2(Module):
    """Fixture whose submodules are stored inside nested list / dict / tuple attributes."""

    class InnerModule(Module):
        """Submodule holding a ``BatchNorm2d(4)`` and a dict keyed by booleans."""

        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(4)
            # Non-string keys whose values are plain ints (no Tensor/Module).
            self.test_bool_key = {True: 1, False: 0}

        def forward(self, x):
            """Return the batch-norm output for ``x``."""
            # Return the result explicitly so the call yields the normalized
            # tensor instead of None.
            return self.bn(x)

    def __init__(self):
        super().__init__()
        self.bn = BatchNorm2d(4)
        # Modules nested in a list, a dict (alongside a non-module value) and
        # a tuple.
        self.a = [
            BatchNorm2d(4),
            {"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()], "z": 0},
            (self.InnerModule(),),
        ]

    def forward(self, x):
        """Identity: the structure tests never run the nested modules."""
        return x
def test_expand_structure():
    """Check ``get_expand_structure`` / ``set_expand_structure`` on nested containers.

    Every dotted key reported by ``named_modules()`` must resolve to the same
    object through ``get_expand_structure``, and must be overwritable through
    ``set_expand_structure`` unless the access callback reports a tuple.
    """
    m = MyModule2()
    expected = [
        ("", m),
        ("a.0", m.a[0]),
        ("a.1.x", m.a[1]["x"]),
        ("a.1.y.0", m.a[1]["y"][0]),
        ("a.1.y.1", m.a[1]["y"][1]),
        ("a.1.y.1.bn", m.a[1]["y"][1].bn),
        ("a.2.0", m.a[2][0]),
        ("a.2.0.bn", m.a[2][0].bn),
        ("bn", m.bn),
    ]
    assert list(m.named_modules()) == expected

    # Skip the root entry (empty key) for the key-based lookups.
    nested = expected[1:]
    for key, module in nested:
        assert get_expand_structure(m, key) == module

    # Overwrite from the last key backwards, so deeper keys such as
    # "a.1.y.1.bn" are replaced before their ancestors ("a.1.y.1").
    for key, _module in reversed(nested):
        # Entries for which the callback sees a tuple as its first argument
        # are left alone (presumably because a tuple cannot be assigned
        # into -- confirm).
        held_in_tuple = _access_structure(
            m, key, lambda p, k, o: isinstance(p, tuple)
        )
        if not held_in_tuple:
            set_expand_structure(m, key, "TEST_VALUE")
            assert get_expand_structure(m, key) == "TEST_VALUE"
def test_flatten_others():
    """Check that ``_flatten`` yields nothing for a predicate rejecting Tensors and Modules."""

    def is_other(obj):
        # True only for objects that are neither a Tensor nor a Module.
        return not isinstance(obj, Tensor) and not isinstance(obj, Module)

    m = MyModule2()
    flattened = list(m._flatten(with_key=True, predicate=is_other))
    assert len(flattened) == 0
def test_flatten_with_parent():
    """Check the ``with_parent=True`` variants of ``named_modules`` and ``modules``.

    Each entry carries the module that owns it; the root's parent is ``None``.
    """
    m = MyModule2()

    # Short names for the modules buried in the nested containers.
    list_bn = m.a[0]
    dict_bn = m.a[1]["x"]
    inner_list_bn = m.a[1]["y"][0]
    inner_in_list = m.a[1]["y"][1]
    inner_in_tuple = m.a[2][0]

    expected_named = [
        ("", m, None),
        ("a.0", list_bn, m),
        ("a.1.x", dict_bn, m),
        ("a.1.y.0", inner_list_bn, m),
        ("a.1.y.1", inner_in_list, m),
        ("a.1.y.1.bn", inner_in_list.bn, inner_in_list),
        ("a.2.0", inner_in_tuple, m),
        ("a.2.0.bn", inner_in_tuple.bn, inner_in_tuple),
        ("bn", m.bn, m),
    ]
    assert list(m.named_modules(with_parent=True)) == expected_named

    expected_unnamed = [
        (m, None),
        (list_bn, m),
        (dict_bn, m),
        (inner_list_bn, m),
        (inner_in_list, m),
        (inner_in_list.bn, inner_in_list),
        (inner_in_tuple, m),
        (inner_in_tuple.bn, inner_in_tuple),
        (m.bn, m),
    ]
    assert list(m.modules(with_parent=True)) == expected_unnamed
class MyModule3(Module):
    """Fixture combining a direct ``BatchNorm2d`` with a ``Sequential`` container."""

    class InnerModule(Module):
        """Submodule wrapping a single ``BatchNorm2d(4)``."""

        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(4)

        def forward(self, x):
            """Return the batch-norm output for ``x``."""
            # Return the result explicitly so the call yields the normalized
            # tensor instead of None.
            return self.bn(x)

    def __init__(self):
        super().__init__()
        self.bn = BatchNorm2d(4)
        self.seq = Sequential(BatchNorm2d(4), self.InnerModule(),)

    def forward(self, x):
        """Identity: only the module structure is inspected by the tests."""
        return x
def test_module_api_with_sequential():
    """Check that ``named_modules()`` names ``Sequential`` members by index."""
    m = MyModule3()
    seq = m.seq
    expected = [
        ("", m),
        ("bn", m.bn),
        ("seq", seq),
        ("seq.0", seq[0]),
        ("seq.1", seq[1]),
        ("seq.1.bn", seq[1].bn),
    ]
    assert list(m.named_modules()) == expected
def test_sequential_named_children():
    """Check that a ``Sequential`` built from an ``OrderedDict`` keeps the dict's names."""
    layers = OrderedDict(
        [
            ("name0", Linear(20, 10)),
            ("name1", Linear(10, 5)),
            ("name2", Linear(5, 1)),
        ]
    )
    m = Sequential(layers)
    children = list(m.named_children())
    # Only the names are checked, position by position.
    for position, expected_name in enumerate(("name0", "name1", "name2")):
        assert children[position][0] == expected_name
def test_state_dict():
    """Round-trip an ``MLP`` state dict through ``mge.save`` / ``mge.load``.

    Covers three cases: non-strict loading with an extra key, strict loading
    with an extra key (``KeyError``), and strict loading with a missing key
    (``KeyError``).
    """
    data = tensor(np.random.random((2, 28)))
    source_model = MLP()
    pred0 = source_model(data)

    with BytesIO() as buffer:
        mge.save(source_model.state_dict(), buffer)
        buffer.seek(0)
        state_dict = mge.load(buffer)

        # An unknown key must be tolerated when strict=False ...
        state_dict["extra"] = None
        restored_model = MLP()
        restored_model.load_state_dict(state_dict, strict=False)
        pred1 = restored_model(data)
        np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)

        # ... and rejected by the default call.
        with pytest.raises(KeyError):
            restored_model.load_state_dict(state_dict)

        # Drop the extra key and one real key: a missing entry is rejected
        # as well.
        for key in ("extra", "dense0.bias"):
            del state_dict[key]
        with pytest.raises(KeyError):
            restored_model.load_state_dict(state_dict)
class AssertModule(Module):
    """Fixture storing a Tensor under a non-string dict key.

    ``test_assert_message`` expects flattening this module to raise an
    ``AssertionError`` naming the offending key.
    """

    def __init__(self):
        super().__init__()
        # The bool key ``True`` maps to a Tensor; ``False`` maps to an int.
        bad_mapping = {True: tensor([]), False: 0}
        self.error_tensor_key = bad_mapping

    def forward(self, x):
        """Identity."""
        return x
def test_assert_message():
    """Check the assertion message raised for a non-string key holding a Tensor."""
    expected_message = "keys for Tensor and Module must be str, error key: True"
    # Construction stays inside the context manager, so an assertion raised
    # while building the module is accepted as well.
    with pytest.raises(AssertionError, match=expected_message):
        module = AssertModule()
        list(module._flatten())
class Simple(Module):
    """Two bias-free 3x3 ``Conv2d`` layers sharing one weight tensor."""

    def __init__(self):
        super().__init__()
        self.conv0 = Conv2d(1, 1, kernel_size=3, bias=False)
        self.conv1 = Conv2d(1, 1, kernel_size=3, bias=False)
        # Rebind conv1's weight to the very same object conv0 uses.
        self.conv1.weight = self.conv0.weight

    def forward(self, inputs):
        """Return the sum of both convolutions applied to ``inputs``."""
        first = self.conv0(inputs)
        second = self.conv1(inputs)
        return first + second
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_shared_param(test_traced_module):
    """Check that a weight shared by two convs survives ``mge.save`` / ``mge.load``.

    Saving the whole network must keep the sharing; saving each conv on its
    own must yield distinct weight objects that still compute the same output.
    """

    def save_and_reload(obj):
        # Serialize ``obj`` into an in-memory buffer and load it straight back.
        with BytesIO() as buffer:
            mge.save(obj, buffer)
            buffer.seek(0)
            return mge.load(buffer)

    net = Simple()
    if test_traced_module:
        net = trace_module(net, tensor(np.random.random((1, 1, 8, 8))))
    assert net.conv0.weight is net.conv1.weight

    data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
    np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())

    # Whole-network round trip: the sharing must be preserved.
    net1 = save_and_reload(net)
    assert net1.conv0.weight is net1.conv1.weight
    np.testing.assert_allclose(net1.conv0(data).numpy(), net1.conv1(data).numpy())

    # Independent round trips: separate weight objects, equal results.
    conv0 = save_and_reload(net.conv0)
    conv1 = save_and_reload(net.conv1)
    assert conv0.weight is not conv1.weight
    np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())
class Simple2(Module):
    """Two bias-free ``Conv1d`` layers (kernel size 3) sharing one weight tensor."""

    def __init__(self):
        super().__init__()
        # conv1 is created before conv0 here.
        self.conv1 = Conv1d(1, 1, kernel_size=3, bias=False)
        self.conv0 = Conv1d(1, 1, kernel_size=3, bias=False)
        # Rebind conv1's weight to the very same object conv0 uses.
        self.conv1.weight = self.conv0.weight

    def forward(self, inputs):
        """Intentionally empty; the tests call the conv layers directly."""
        return None
def test_shared_param_1d():
    """Check that a weight shared by two ``Conv1d`` layers survives save/load.

    Saving the whole network must keep the sharing; saving each conv on its
    own must yield distinct weight objects that still compute the same output.
    """

    def save_and_reload(obj):
        # Serialize ``obj`` into an in-memory buffer and load it straight back.
        with BytesIO() as buffer:
            mge.save(obj, buffer)
            buffer.seek(0)
            return mge.load(buffer)

    net = Simple2()
    assert net.conv0.weight is net.conv1.weight

    data = tensor(np.random.random((1, 1, 8)).astype(np.float32))
    np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())

    # Whole-network round trip: the sharing must be preserved.
    net1 = save_and_reload(net)
    assert net1.conv0.weight is net1.conv1.weight
    np.testing.assert_allclose(net1.conv0(data).numpy(), net1.conv1(data).numpy())

    # Independent round trips: separate weight objects, equal results.
    conv0 = save_and_reload(net.conv0)
    conv1 = save_and_reload(net.conv1)
    assert conv0.weight is not conv1.weight
    np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_pickle_module(test_traced_module):
    """Check that a whole ``MLP`` survives ``mge.save`` / ``mge.load`` twice.

    Predictions of both reloaded copies must match the live module, and the
    live module must match the prediction taken before (optional) tracing.
    """

    def save_and_reload(module):
        # Round-trip through an in-memory buffer; in the traced variant the
        # loaded object must be exactly a TracedModule.
        with BytesIO() as buffer:
            mge.save(module, buffer)
            buffer.seek(0)
            loaded = mge.load(buffer)
        if test_traced_module:
            assert type(loaded) == TracedModule
        return loaded

    data = tensor(np.random.random((2, 28)))
    mlp = MLP()
    pred_gt = mlp(data)
    if test_traced_module:
        mlp = trace_module(mlp, data)

    # First round trip, evaluated before the live module runs again.
    mlp1 = save_and_reload(mlp)
    pred0 = mlp1(data)

    pred1 = mlp(data)

    # Second round trip, taken after the live module has been run.
    mlp1 = save_and_reload(mlp)
    pred2 = mlp1(data)

    np.testing.assert_allclose(pred_gt.numpy(), pred1.numpy(), atol=5e-6)
    np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)
    np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
def test_repr_basic():
    """Compare ``__repr__`` of a module mixing direct, Sequential and container-held children."""

    class ConvModel(Module):
        """Module with convs, batch norms, pooling, a named Sequential and nested containers."""

        def __init__(self):
            super().__init__()
            self.conv1 = Conv2d(3, 128, 3, padding=1, bias=False)
            self.conv2 = Conv2d(3, 128, 3, dilation=2, bias=False)
            self.bn1 = BatchNorm1d(128)
            self.bn2 = BatchNorm2d(128)
            self.pooling = MaxPool2d(kernel_size=2, padding=0)

            # Named members of the Sequential submodule.
            separable = OrderedDict()
            separable["depthwise"] = Conv2d(256, 256, 3, 1, 0, groups=256, bias=False,)
            separable["pointwise"] = Conv2d(
                256, 256, kernel_size=1, stride=1, padding=0, bias=True,
            )
            self.submodule1 = Sequential(separable)

            # Modules held in plain Python containers.
            self.list1 = [Dropout(drop_prob=0.1), [Softmax(axis=100)]]
            self.tuple1 = (
                Dropout(drop_prob=0.1),
                (Softmax(axis=100), Dropout(drop_prob=0.2)),
            )
            self.dict1 = {"Dropout": Dropout(drop_prob=0.1)}
            self.fc1 = Linear(512, 1024)

        def forward(self, inputs):
            pass

    # NOTE(review): the comparison below is exact, including the leading
    # whitespace of every line -- confirm these literals still match the
    # indentation that Module.__repr__ emits.
    expected = (
        "ConvModel(\n"
        " (conv1): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n"
        " (conv2): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n"
        " (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
        " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
        " (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0)\n"
        " (submodule1): Sequential(\n"
        " (depthwise): Conv2d(256, 256, kernel_size=(3, 3), groups=256, bias=False)\n"
        " (pointwise): Conv2d(256, 256, kernel_size=(1, 1))\n"
        " )\n"
        " (list1.0): Dropout(drop_prob=0.1)\n"
        " (list1.1.0): Softmax(axis=100)\n"
        " (tuple1.0): Dropout(drop_prob=0.1)\n"
        " (tuple1.1.0): Softmax(axis=100)\n"
        " (tuple1.1.1): Dropout(drop_prob=0.2)\n"
        " (dict1.Dropout): Dropout(drop_prob=0.1)\n"
        " (fc1): Linear(in_features=512, out_features=1024, bias=True)\n"
        ")"
    )
    model = ConvModel()
    assert model.__repr__() == expected
def test_repr_module_reassign():
    """Check ``__repr__`` after an attribute is assigned a second module.

    The expected text lists ``conv1`` first with the settings of the later
    assignment.
    """

    class ConvModel1(Module):
        def __init__(self):
            super().__init__()
            self.conv1 = Conv2d(3, 128, 3, bias=False)
            self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False)
            # Overwrite conv1 after conv2 has been registered.
            self.conv1 = Conv2d(3, 256, 3, dilation=2, bias=False)

        def forward(self, inputs):
            pass

    expected = (
        "ConvModel1(\n"
        " (conv1): Conv2d(3, 256, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n"
        " (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n"
        ")"
    )
    model = ConvModel1()
    assert model.__repr__() == expected
def test_repr_module_rereference():
    """Check ``__repr__`` when one module instance is bound to three attributes.

    The expected text prints the shared conv once per attribute name.
    """

    class ConvModel2(Module):
        def __init__(self):
            super().__init__()
            self.conv1 = Conv2d(3, 128, 3, bias=False)
            # conv2 and conv3 are aliases of the same object as conv1.
            self.conv2 = self.conv1
            self.conv3 = self.conv1

        def forward(self, inputs):
            pass

    expected = (
        "ConvModel2(\n"
        " (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
        " (conv2): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
        " (conv3): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
        ")"
    )
    model = ConvModel2()
    assert model.__repr__() == expected
def test_repr_module_delete():
    """Check that a submodule removed with ``del`` no longer appears in ``__repr__``."""

    class ConvModel3(Module):
        def __init__(self):
            super().__init__()
            self.conv1 = Conv2d(3, 128, 3, bias=False)
            self.softmax = Softmax(100)

        def forward(self, inputs):
            pass

    expected = (
        "ConvModel3(\n"
        " (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
        ")"
    )
    model = ConvModel3()
    # Remove the softmax child before rendering.
    del model.softmax
    assert model.__repr__() == expected
def test_repr_module_reset_attr():
    """Check ``__repr__`` when an attribute switches between ``None`` and a module.

    ``flag=True`` assigns ``None`` and then a ``Linear`` (the child must be
    shown); ``flag=False`` assigns the ``Linear`` and then ``None`` (the
    child must be gone).
    """

    class ResetAttrModule(Module):
        def __init__(self, flag):
            super().__init__()
            if flag:
                self.a = None
                self.a = Linear(3, 5)
            else:
                self.a = Linear(3, 5)
                self.a = None

        def forward(self, x):
            if self.a:
                x = self.a(x)
            return x

    expected_with_child = (
        "ResetAttrModule(\n"
        " (a): Linear(in_features=3, out_features=5, bias=True)\n"
        ")"
    )
    expected_without_child = "ResetAttrModule()"

    # Build both modules before rendering either of them.
    module_with_child = ResetAttrModule(True)
    module_without_child = ResetAttrModule(False)
    rendered = [module_with_child.__repr__(), module_without_child.__repr__()]
    assert rendered == [expected_with_child, expected_without_child]
def test_module_compatible():
    """Pin the instance attributes of a bare ``Module`` subclass.

    The attribute names in ``__dict__`` must equal a fixed set; the assertion
    message explains that adding or removing one may break compatibility of
    pickle serialization.
    """

    class Empty(Module):
        def forward(self):
            pass

    empty_module = Empty()
    expected_attributes = {
        "_modules",
        "name",
        "training",
        "quantize_disabled",
        "_forward_pre_hooks",
        "_forward_hooks",
        "_name",
        "_short_name",
    }
    actual_attributes = set(empty_module.__dict__)
    assert (
        expected_attributes == actual_attributes
    ), "Add or delete attributes in Module class may break compatibility of pickle serialization"