import megengine.module as M
from megengine.quantization import quantize, quantize_qat
def test_repr():
class Net(M.Module):
def __init__(self):
super().__init__()
self.conv_bn = M.ConvBnRelu2d(3, 3, 3)
self.linear = M.Linear(3, 3)
def forward(self, x):
return x
net = Net()
ground_truth = (
"Net(\n"
" (conv_bn): ConvBnRelu2d(\n"
" (conv): Conv2d(3, 3, kernel_size=(3, 3))\n"
" (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
" )\n"
" (linear): Linear(in_features=3, out_features=3, bias=True)\n"
")"
)
assert net.__repr__() == ground_truth
quantize_qat(net)
ground_truth = (
"Net(\n"
" (conv_bn): QAT.ConvBnRelu2d(\n"
" (conv): Conv2d(3, 3, kernel_size=(3, 3))\n"
" (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
" (act_observer): ExponentialMovingAverageObserver()\n"
" (act_fake_quant): FakeQuantize()\n"
" (weight_observer): MinMaxObserver()\n"
" (weight_fake_quant): FakeQuantize()\n"
" )\n"
" (linear): QAT.Linear(\n"
" in_features=3, out_features=3, bias=True\n"
" (act_observer): ExponentialMovingAverageObserver()\n"
" (act_fake_quant): FakeQuantize()\n"
" (weight_observer): MinMaxObserver()\n"
" (weight_fake_quant): FakeQuantize()\n"
" )\n"
")"
)
assert net.__repr__() == ground_truth
quantize(net)
ground_truth = (
"Net(\n"
" (conv_bn): Quantized.ConvBnRelu2d(3, 3, kernel_size=(3, 3))\n"
" (linear): Quantized.Linear()\n"
")"
)
assert net.__repr__() == ground_truth