import io
from functools import partial
from itertools import chain
from typing import Callable
import numpy as np
import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.module.qat as QM
import megengine.quantization as Q
from megengine import Tensor
from megengine.module.qat.module import QATModule
from megengine.traced_module import TracedModule, trace_module
from megengine.traced_module.utils import get_subattr
class MyConvBnRelu2d(M.ConvBnRelu2d):
    """Empty subclass of ``M.ConvBnRelu2d`` that adds no behavior.

    ``build_observered_net`` maps this class to ``MyQATConvBnRelu2d`` when it
    calls ``Q.quantize_qat``.
    """
class MyQATConvBnRelu2d(QM.ConvBnRelu2d):
    """Empty subclass of ``QM.ConvBnRelu2d`` that adds no behavior.

    Used as the mapping target for ``MyConvBnRelu2d`` in
    ``build_observered_net``.
    """
class Myblcok(M.Module):
    """Small block: ``conv0`` feeds two ``ConvBn2d`` branches that are merged.

    The outputs of ``conv1`` and ``conv2`` (both applied to the output of
    ``conv0``) are combined by an ``M.Elemwise("FUSE_ADD_RELU")`` module.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is: the tests in this
        # file look sub-modules up by name (e.g. "MyModule_block0_conv0").
        self.conv0 = MyConvBnRelu2d(3, 3, 3, 1, 1)
        self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0)
        self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0)
        self.add = M.Elemwise("FUSE_ADD_RELU")

    def forward(self, x):
        """Apply ``conv0``, then both branches, and return their merged output."""
        stem = self.conv0(x)
        branch0 = self.conv1(stem)
        branch1 = self.conv2(stem)
        return self.add(branch0, branch1)
class MyModule(M.Module):
    """Network made of two ``Myblcok`` instances applied one after the other."""

    def __init__(self):
        super().__init__()
        self.block0 = Myblcok()
        self.block1 = Myblcok()

    def forward(self, x):
        """Run ``x`` through ``block0`` and then ``block1``."""
        return self.block1(self.block0(x))
class MyMinMaxObserver(Q.MinMaxObserver):
    """Empty subclass of ``Q.MinMaxObserver``; used by ``test_trace_qat``."""
class MyTQT(Q.TQT):
    """Empty subclass of ``Q.TQT``; used by ``test_trace_qat``."""
def get_lsq_config(lsq_cls):
    """Return a ``Q.QConfig`` that only sets fake-quant factories.

    Args:
        lsq_cls: callable used for both fake-quant factories; it is bound with
            ``dtype="qint8_narrow"`` for weights and ``dtype="qint8"`` for
            activations.

    Returns:
        A ``Q.QConfig`` whose observers are both ``None``.
    """
    weight_factory = partial(lsq_cls, dtype="qint8_narrow")
    act_factory = partial(lsq_cls, dtype="qint8")
    return Q.QConfig(
        weight_observer=None,
        act_observer=None,
        weight_fake_quant=weight_factory,
        act_fake_quant=act_factory,
    )
def get_observer_config(observer_cls):
    """Return a ``Q.QConfig`` that only sets observer factories.

    Args:
        observer_cls: callable used for both observer factories; it is bound
            with ``dtype="qint8_narrow"`` for weights and ``dtype="qint8"``
            for activations.

    Returns:
        A ``Q.QConfig`` whose fake-quant entries are both ``None``.
    """
    weight_factory = partial(observer_cls, dtype="qint8_narrow")
    act_factory = partial(observer_cls, dtype="qint8")
    return Q.QConfig(
        weight_observer=weight_factory,
        act_observer=act_factory,
        weight_fake_quant=None,
        act_fake_quant=None,
    )
def get_qparams(mod: QATModule):
    """Collect the weight and activation qparams of ``mod``.

    For each of weight/activation, the observer's ``get_qparams()`` is used
    when the observer is not ``None``; a truthy fake-quant then takes
    precedence and overrides that value.

    Returns:
        ``(weight_qparams, act_qparams)``; an entry is ``None`` when neither
        source is available.
    """
    act_qparams = None
    act_observer = mod.act_observer
    if act_observer is not None:
        act_qparams = act_observer.get_qparams()
    act_fake_quant = mod.act_fake_quant
    if act_fake_quant:
        act_qparams = act_fake_quant.get_qparams()

    weight_qparams = None
    weight_observer = mod.weight_observer
    if weight_observer is not None:
        weight_qparams = weight_observer.get_qparams()
    weight_fake_quant = mod.weight_fake_quant
    if weight_fake_quant:
        weight_qparams = weight_fake_quant.get_qparams()

    return weight_qparams, act_qparams
def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):
    """Assert that two qparams objects describe the same quantization.

    Compares ``dtype_meta``, ``mode`` and ``scale``; ``zero_point`` is compared
    only when ``qparmsa.zero_point`` is not ``None``.
    """
    assert qparmsa.dtype_meta == qparmsb.dtype_meta
    assert qparmsa.mode == qparmsb.mode

    scale_a = qparmsa.scale.numpy()
    scale_b = qparmsb.scale.numpy()
    np.testing.assert_equal(scale_a, scale_b)

    zero_point_a = qparmsa.zero_point
    if zero_point_a is None:
        return
    np.testing.assert_equal(zero_point_a.numpy(), qparmsb.zero_point.numpy())
def build_observered_net(net: M.Module, observer_cls):
    """Convert ``net`` with ``Q.quantize_qat`` and run one observed forward pass.

    The conversion uses ``get_observer_config(observer_cls)`` and maps
    ``MyConvBnRelu2d`` to ``MyQATConvBnRelu2d``. Observers are enabled, the
    net is put in eval mode and called once on a random ``(5, 3, 32, 32)``
    input, then observers are disabled.

    Returns:
        The module returned by ``Q.quantize_qat``.
    """
    observer_qconfig = get_observer_config(observer_cls)
    custom_mapping = {MyConvBnRelu2d: MyQATConvBnRelu2d}
    qat_net = Q.quantize_qat(net, qconfig=observer_qconfig, mapping=custom_mapping)

    Q.enable_observer(qat_net)
    sample = Tensor(np.random.random(size=(5, 3, 32, 32)))
    qat_net.eval()
    qat_net(sample)
    Q.disable_observer(qat_net)

    return qat_net
def build_fakequanted_net(net: QATModule, fakequant_cls):
    """Apply ``get_lsq_config(fakequant_cls)`` to ``net`` via ``Q.reset_qconfig``.

    Returns:
        The module returned by ``Q.reset_qconfig``, switched to eval mode.
    """
    fakequant_qconfig = get_lsq_config(fakequant_cls)
    qat_net = Q.reset_qconfig(net, fakequant_qconfig)
    qat_net.eval()
    return qat_net
def test_trace_qat():
    """Tracing a QAT net keeps each QAT sub-module's qparams and graph nodes."""

    def _check_qat_module(qat_net: QATModule):
        """Trace ``qat_net`` and compare it with the untraced original."""
        sample = Tensor(np.random.random(size=(5, 3, 32, 32)))
        traced_net = trace_module(qat_net, sample)

        for name, qat_module in qat_net.named_modules():
            if not isinstance(qat_module, QATModule):
                continue
            traced_qat_module = get_subattr(traced_net, name)
            expected_pair = get_qparams(qat_module)
            traced_pair = get_qparams(traced_qat_module)
            # Pairs are (weight, act); only compare entries the original has.
            for expected, traced in zip(expected_pair, traced_pair):
                if expected:
                    check_qparams(expected, traced)

        flat_graph = traced_net.flatten().graph
        conv0_node = flat_graph.get_node_by_name("MyModule_block0_conv0").as_unique()
        conv0_out_node = flat_graph.get_node_by_name(
            "MyModule_block0_conv0_out"
        ).as_unique()
        assert isinstance(conv0_node.owner, TracedModule)
        assert conv0_out_node.expr.inputs[0] is conv0_node

    # Observer-only nets: stock observer first, then the user subclass.
    for observer_cls in (Q.MinMaxObserver, MyMinMaxObserver):
        _check_qat_module(build_observered_net(MyModule(), observer_cls))

    # Observed nets switched to fake-quant: stock TQT, then the user subclass.
    for fakequant_cls in (Q.TQT, MyTQT):
        observed_net = build_observered_net(MyModule(), Q.MinMaxObserver)
        _check_qat_module(build_fakequanted_net(observed_net, fakequant_cls))
def test_load_param():
    """A saved state dict loads into traced (and flattened traced) modules."""

    def _check_param(moda: M.Module, modb: M.Module):
        """Assert each parameter/buffer of ``moda`` equals its ``modb`` counterpart."""
        named_tensors = chain(moda.named_parameters(), moda.named_buffers())
        for name, expected in named_tensors:
            loaded = get_subattr(modb, name)
            np.testing.assert_equal(expected.numpy(), loaded.numpy())

    def _check_module(build_func: Callable):
        """Save one built net, then load its state into freshly traced nets."""
        net = build_func()
        net.eval()
        buffer = io.BytesIO()
        mge.save(net.state_dict(), buffer)

        inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
        # First the traced module as-is, then its flattened form; the buffer
        # is rewound before every load.
        for use_flatten in (False, True):
            buffer.seek(0)
            traced_net = trace_module(build_func(), inp)
            if use_flatten:
                traced_net = traced_net.flatten()
            traced_net.load_state_dict(mge.load(buffer))
            _check_param(net, traced_net)

    _check_module(MyModule)
    _check_module(lambda: build_observered_net(MyModule(), Q.MinMaxObserver))
def test_qualname():
    """Every traced node's qualname resolves on both traced and original nets."""

    def _check_qualname(net):
        """Trace ``net`` and look up each node's relative qualname on both nets."""
        inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
        net.eval()
        traced_net = trace_module(net, inp)

        # Drop the graph's own qualname plus the one character that follows it.
        prefix_len = len(traced_net.graph.qualname) + 1
        for node in traced_net.graph.nodes():
            attr_path = node.qualname[prefix_len:]
            if attr_path.endswith("]"):
                # Discard the last dotted component when it ends with "]".
                attr_path = attr_path.rsplit(".", 1)[0]
            if attr_path.startswith("["):
                attr_path = ""
            from_traced = get_subattr(traced_net, attr_path)
            from_original = get_subattr(net, attr_path)
            assert from_traced is not None
            assert from_original is not None

    _check_qualname(MyModule())
    _check_qualname(build_observered_net(MyModule(), Q.MinMaxObserver))