import io
from itertools import product
import numpy as np
import pytest
import megengine.utils.comp_graph_tools as cgtools
from megengine import jit, tensor
from megengine.device import get_device_count
from megengine.functional import expand_dims
from megengine.module import (
BatchMatMulActivation,
Conv2d,
ConvBn2d,
ConvRelu2d,
ConvTranspose2d,
DequantStub,
Module,
QuantStub,
)
from megengine.quantization.quantize import (
disable_fake_quant,
enable_fake_quant,
quantize,
quantize_qat,
)
def test_qat_convbn2d():
    """Compare a float ``ConvBn2d`` with its QAT copy while fake quant is off.

    For every ``(groups, bias)`` combination both modules receive the same
    random batch. Outputs and the BN running statistics are compared after a
    train-mode forward pass, then outputs are compared again in eval mode.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    def assert_close(lhs, rhs, atol):
        # Both arguments expose ``.numpy()``; compare the materialized arrays.
        np.testing.assert_allclose(lhs.numpy(), rhs.numpy(), atol=atol)

    for groups, bias in product((1, 4), (True, False)):
        float_module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        float_module.train()
        qat_module = quantize_qat(float_module, inplace=False)
        disable_fake_quant(qat_module)

        batch = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))

        # Train mode: the float module runs first, then the QAT copy.
        float_out = float_module(batch)
        qat_out = qat_module(batch)
        assert_close(float_out, qat_out, atol=5e-6)
        assert_close(
            float_module.bn.running_mean, qat_module.bn.running_mean, atol=5e-8
        )
        assert_close(
            float_module.bn.running_var, qat_module.bn.running_var, atol=5e-7
        )

        # Eval mode: only the outputs are compared.
        float_module.eval()
        float_out = float_module(batch)
        qat_module.eval()
        qat_out = qat_module(batch)
        assert_close(float_out, qat_out, atol=5e-6)
@pytest.mark.parametrize(
    "padding, padding_mode",
    [
        (0, "zeros"),
        ((1, 2), "zeros"),
        (3, "reflect"),
        ((1, 2), "reflect"),
        (4, "replicate"),
        ((1, 2), "replicate"),
    ],
)
def test_qat_conv(padding, padding_mode):
    """Compare a float conv network with its QAT copy while fake quant is off.

    The network is built for every ``(groups, bias)`` combination using the
    parametrized ``padding`` / ``padding_mode``; outputs of the float net and
    the QAT net are compared in train mode and again in eval mode.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    class TestNet(Module):
        """QuantStub -> Conv2d -> ConvRelu2d -> DequantStub."""

        def __init__(self, groups, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv = Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                groups=groups,
                bias=bias,
                padding=padding,
                padding_mode=padding_mode,
            )
            # Second conv maps back from out_channels to in_channels.
            self.conv_relu = ConvRelu2d(
                out_channels, in_channels, kernel_size, groups=groups, bias=bias
            )

        def forward(self, inp):
            return self.dequant(self.conv_relu(self.conv(self.quant(inp))))

    def assert_same(lhs, rhs):
        np.testing.assert_allclose(lhs.numpy(), rhs.numpy())

    batch = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
    for groups, bias in product((1, 4), (True, False)):
        float_net = TestNet(groups, bias)
        float_net.train()
        qat_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(qat_net)

        # Train-mode comparison.
        float_out = float_net(batch)
        qat_out = qat_net(batch)
        assert_same(float_out, qat_out)

        # Eval-mode comparison.
        float_net.eval()
        float_out = float_net(batch)
        qat_net.eval()
        qat_out = qat_net(batch)
        assert_same(float_out, qat_out)
@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
def test_qat_batchmatmul_activation():
    """Compare a float ``BatchMatMulActivation`` net with its QAT copy.

    Fake quant is disabled on the QAT copy; outputs of both nets are compared
    in train mode and in eval mode, with and without bias. Skipped when a GPU
    device is reported.
    """
    batch, in_features, out_features = 4, 8, 4

    class TestNet(Module):
        """QuantStub -> BatchMatMulActivation -> DequantStub."""

        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            return self.dequant(self.batch_mm(self.quant(inp)))

    def assert_same(lhs, rhs):
        np.testing.assert_allclose(lhs.numpy(), rhs.numpy())

    data = tensor(
        np.random.randn(batch, in_features, out_features).astype(np.float32)
    )
    for bias in (True, False):
        float_net = TestNet(bias)
        float_net.train()
        qat_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(qat_net)

        # Train-mode comparison.
        float_out = float_net(data)
        qat_out = qat_net(data)
        assert_same(float_out, qat_out)

        # Eval-mode comparison.
        float_net.eval()
        float_out = float_net(data)
        qat_net.eval()
        qat_out = qat_net(data)
        assert_same(float_out, qat_out)
@pytest.mark.skip(reason="FIXME: abnormal exit")
def test_quantize_batchmatmul_activation():
    """Check float -> QAT -> quantized -> dumped-graph output agreement.

    For each ``bias`` setting the test compares, in order: the float net
    against the QAT net with fake quant disabled (train and eval mode), the
    QAT net with fake quant enabled against the quantized net, and the
    quantized net against the graph dumped from a traced forward pass.
    Currently skipped unconditionally (see the skip reason).
    """
    batch, in_features, out_features = 4, 8, 4

    class TestNet(Module):
        """QuantStub -> BatchMatMulActivation -> expand_dims -> DequantStub."""

        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            hidden = self.batch_mm(self.quant(inp))
            # A trailing axis is appended before dequantizing.
            hidden = expand_dims(hidden, -1)
            return self.dequant(hidden)

    data = tensor(
        np.random.randn(batch, in_features, out_features).astype(np.float32)
    )
    for bias in (True, False):
        float_net = TestNet(bias)
        float_net.train()
        qat_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(qat_net)

        # Float vs. QAT (fake quant off), train mode.
        float_out = float_net(data)
        qat_out = qat_net(data)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Float vs. QAT (fake quant off), eval mode.
        float_net.eval()
        float_out = float_net(data)
        qat_net.eval()
        qat_out = qat_net(data)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # QAT (fake quant on) vs. the converted quantized net.
        enable_fake_quant(qat_net)
        qat_out = qat_net(data)
        qnet = quantize(qat_net, inplace=False)
        qnet.eval()
        quantized_out = qnet(data)
        np.testing.assert_allclose(
            qat_out.numpy(), quantized_out.numpy(), atol=1e-6
        )

        @jit.trace(capture_as_const=True)
        def run_quantized(x):
            qnet.eval()
            return qnet(x)

        # Trace once, dump to an in-memory buffer, then reload for inference.
        run_quantized(data)
        dump_buffer = io.BytesIO()
        run_quantized.dump(dump_buffer, enable_nchw4=True)
        dump_buffer.seek(0)
        infer_cg = cgtools.GraphInference(dump_buffer)[0]
        results = infer_cg.run(data.numpy())
        dumped_out = list(results.values())[0]
        np.testing.assert_allclose(quantized_out.numpy(), dumped_out, atol=1e-6)
def test_qat_conv_transpose2d():
    """Compare a float ``ConvTranspose2d`` net with its QAT copy.

    Fake quant is disabled on the QAT copy; outputs of both nets are compared
    in train mode and in eval mode, with and without bias.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    class TestNet(Module):
        """QuantStub -> ConvTranspose2d -> DequantStub."""

        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv = ConvTranspose2d(
                in_channels, out_channels, kernel_size, bias=bias
            )

        def forward(self, inp):
            return self.dequant(self.conv(self.quant(inp)))

    def assert_same(lhs, rhs):
        np.testing.assert_allclose(lhs.numpy(), rhs.numpy())

    batch = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
    for bias in (True, False):
        float_net = TestNet(bias)
        float_net.train()
        qat_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(qat_net)

        # Train-mode comparison.
        float_out = float_net(batch)
        qat_out = qat_net(batch)
        assert_same(float_out, qat_out)

        # Eval-mode comparison.
        float_net.eval()
        float_out = float_net(batch)
        qat_net.eval()
        qat_out = qat_net(batch)
        assert_same(float_out, qat_out)