import functools
import platform
import numpy as np
import pytest
import megengine as mge
import megengine.amp as amp
import megengine.distributed as dist
from megengine import Tensor, jit
from megengine.autodiff.grad_manager import GradManager
from megengine.core._trace_option import use_symbolic_shape
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
@pytest.mark.parametrize("enable_amp", [False, True])
def test_syncbn(enable_amp):
    """Check multi-GPU SyncBatchNorm against a single-process numpy reference.

    Each full batch is split along its last axis into ``nr_ranks`` equal
    slices, one per rank. SyncBatchNorm must normalise every slice with the
    statistics of the *whole* batch, and every rank must end with the running
    statistics computed from the whole batches.

    Args:
        enable_amp: run the forward passes under ``amp.autocast``; the output
            check then uses a looser tolerance.
    """
    nr_chan = 8
    data_shape = (3, nr_chan, 4, 16)
    momentum = 0.9
    eps = 1e-5
    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
    steps = 4
    nr_ranks = 2
    # Width of the last-axis slice each rank receives; it must divide evenly.
    assert data_shape[3] % nr_ranks == 0
    slice_w = data_shape[3] // nr_ranks

    @dist.launcher(n_gpus=nr_ranks)
    def worker(data, yv_expect, running_mean, running_var):
        with amp.autocast(enabled=enable_amp):
            rank = dist.get_rank()
            bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
            # Check the normalised output of every step, not just the last.
            for i in range(steps):
                yv = bn(Tensor(data[rank][i]))
                if enable_amp:
                    np.testing.assert_allclose(
                        yv.numpy(), yv_expect[rank][i], atol=5e-4, rtol=5e-4
                    )
                else:
                    _assert_allclose(yv.numpy(), yv_expect[rank][i])
        _assert_allclose(bn.running_mean.numpy(), running_mean)
        _assert_allclose(bn.running_var.numpy(), running_var)

    # Build the full batches and their reference outputs / running stats.
    xv = []
    yv_full = []
    for _ in range(steps):
        x = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        # (N, C, H, W) -> (N*H*W, C) so per-channel stats are an axis-0 reduction.
        x_flat = np.transpose(x, [0, 2, 3, 1]).reshape(-1, nr_chan)
        mean = x_flat.mean(axis=0).reshape((1, nr_chan, 1, 1))
        var_biased = x_flat.var(axis=0).reshape((1, nr_chan, 1, 1))
        var_unbiased = x_flat.var(axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)
        xv.append(x)
        yv_full.append((x - mean) / np.sqrt(var_biased + eps))

    def rank_slice(arr, rank):
        """Return the last-axis slice of ``arr`` that belongs to ``rank``."""
        return arr[..., rank * slice_w : (rank + 1) * slice_w]

    data = [[rank_slice(x, r) for x in xv] for r in range(nr_ranks)]
    yv_expect = [[rank_slice(y, r) for y in yv_full] for r in range(nr_ranks)]

    worker(data, yv_expect, running_mean, running_var)
def _reference_batch_stats(xv):
    """Return per-channel ``(mean, biased var, unbiased var)`` of ``xv``.

    Statistics are reduced over every axis except axis 1 (the channel axis).
    The reduced axes are kept as size-1 dims so the results broadcast
    against ``xv``. They are computed in float64 so the reference itself adds
    no float32 rounding noise to the comparison.
    """
    x64 = np.asarray(xv, dtype=np.float64)
    axes = (0,) + tuple(range(2, x64.ndim))
    mean = x64.mean(axis=axes, keepdims=True)
    var_biased = x64.var(axis=axes, keepdims=True)
    var_unbiased = x64.var(axis=axes, ddof=1, keepdims=True)
    return mean, var_biased, var_unbiased


def _check_bn_running_stats(bn, data_shape, momentum, steps=3):
    """Train ``bn`` on random batches, then check its eval-mode behaviour.

    Training phase: for each of ``steps`` random batches, the output must equal
    the batch normalised with its own mean and biased variance. The module's
    running stats must follow
    ``running = running * momentum + batch_stat * (1 - momentum)``, using the
    unbiased variance for ``running_var``.

    Eval phase: with ``bn.training = False``, two calls on the same input must
    give identical results. They must leave the running stats untouched and
    normalise with the running stats tracked above.

    Args:
        bn: the module under test (its ``momentum`` must equal ``momentum``).
        data_shape: input shape; axis 1 is the channel axis.
        momentum: momentum used for the reference running-stat update.
        steps: number of training batches.
    """
    nr_chan = data_shape[1]
    stats_shape = (1, nr_chan) + (1,) * (len(data_shape) - 2)
    running_mean = np.zeros(stats_shape)
    running_var = np.ones(stats_shape)

    for _ in range(steps):
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        mean, var_biased, var_unbiased = _reference_batch_stats(xv)
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        yv = bn(Tensor(xv))
        _assert_allclose(yv.numpy(), (xv - mean) / np.sqrt(var_biased + bn.eps))
        # Flatten both sides: the module's stat tensors may carry a different
        # number of size-1 dims than the reference (e.g. for 1d inputs).
        _assert_allclose(bn.running_mean.numpy().reshape(-1), running_mean.reshape(-1))
        _assert_allclose(bn.running_var.numpy().reshape(-1), running_var.reshape(-1))

    mean_backup = bn.running_mean.numpy()
    var_backup = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
    data = Tensor(xv)
    yv1 = bn(data)
    yv2 = bn(data)
    # Eval mode must be deterministic and must not update the running stats.
    np.testing.assert_equal(yv1.numpy(), yv2.numpy())
    np.testing.assert_equal(mean_backup, bn.running_mean.numpy())
    np.testing.assert_equal(var_backup, bn.running_var.numpy())
    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
    _assert_allclose(yv1.numpy(), yv_expect)


def test_batchnorm():
    nr_chan = 8
    momentum = 0.9
    bn = BatchNorm1d(nr_chan, momentum=momentum)
    _check_bn_running_stats(bn, (3, nr_chan, 4), momentum)


def test_syncbn1d():
    nr_chan = 8
    momentum = 0.9
    bn = SyncBatchNorm(nr_chan, momentum=momentum)
    _check_bn_running_stats(bn, (3, nr_chan, 4), momentum)


def test_batchnorm2d():
    nr_chan = 8
    momentum = 0.9
    bn = BatchNorm2d(nr_chan, momentum=momentum)
    _check_bn_running_stats(bn, (3, nr_chan, 16, 16), momentum)


def test_syncbn2d():
    nr_chan = 8
    momentum = 0.9
    bn = SyncBatchNorm(nr_chan, momentum=momentum)
    _check_bn_running_stats(bn, (3, nr_chan, 16, 16), momentum)
def _check_bn_no_stats(bn, data_shape, steps=4, eval_from=2):
    """Check a norm module built with ``track_running_stats=False``.

    On every one of ``steps`` random batches, the output must equal the batch
    normalised with that batch's own mean and biased variance. This must hold
    both before and after ``bn.training`` is set to False at iteration
    ``eval_from``.

    Args:
        bn: the module under test.
        data_shape: input shape; axis 1 is the channel axis.
        steps: number of random batches to feed.
        eval_from: iteration index at which ``bn.training`` is set to False.
    """
    # Reduce over every axis except the channel axis (1), keeping dims so the
    # statistics broadcast against the input.
    axes = (0,) + tuple(range(2, len(data_shape)))
    for i in range(steps):
        if i == eval_from:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        # float64 reference so numpy's own rounding does not eat the tolerance.
        x64 = xv.astype(np.float64)
        mean = x64.mean(axis=axes, keepdims=True)
        sd = np.sqrt(x64.var(axis=axes, keepdims=True) + bn.eps)
        yv = bn(Tensor(xv))
        _assert_allclose(yv.numpy(), (xv - mean) / sd)


def test_batchnorm_no_stats():
    nr_chan = 8
    bn = BatchNorm1d(nr_chan, track_running_stats=False)
    _check_bn_no_stats(bn, (3, nr_chan, 4))


def test_syncbn_no_stats():
    nr_chan = 8
    bn = SyncBatchNorm(nr_chan, track_running_stats=False)
    _check_bn_no_stats(bn, (3, nr_chan, 4))


def test_batchnorm2d_no_stats():
    nr_chan = 8
    bn = BatchNorm2d(nr_chan, track_running_stats=False)
    _check_bn_no_stats(bn, (3, nr_chan, 16, 16))


def test_syncbn2d_no_stats():
    nr_chan = 8
    bn = SyncBatchNorm(nr_chan, track_running_stats=False)
    _check_bn_no_stats(bn, (3, nr_chan, 16, 16))
def test_syncbn2d_grad():
    """SyncBatchNorm must match BatchNorm2d in both forward output and input grad.

    This runs in the calling process without any distributed launcher. Both
    modules are built without running stats. Iterations 0-1 run in training
    mode and iterations 2-3 with ``training = False``.
    """
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
    syncbn = SyncBatchNorm(nr_chan, track_running_stats=False)
    bn = BatchNorm2d(nr_chan, track_running_stats=False)

    def forward_backward(module, inp, diff):
        """Run ``module`` on ``inp``, backprop ``diff``; return (output, grad)."""
        with GradManager().attach(inp) as gm:
            out = module(inp)
            gm.backward(out, diff)
        grad = inp.grad
        # Clear the grad so the next module's backward starts from scratch.
        inp.grad = None
        return out, grad

    for i in range(4):
        if i == 2:
            syncbn.training = False
            bn.training = False
        inp = Tensor(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
        diff = Tensor(np.random.normal(size=data_shape).astype(np.float32))

        oup, grad = forward_backward(syncbn, inp, diff)
        oup_expect, grad_expect = forward_backward(bn, inp, diff)

        _assert_allclose(oup.numpy(), oup_expect.numpy())
        _assert_allclose(grad.numpy(), grad_expect.numpy())
@pytest.mark.parametrize("dim", [1, 2])
@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_batchnorm_empty_tensor(dim, is_symbolic):
    """BatchNorm in training mode must pass a zero-size input through unchanged.

    Args:
        dim: 1 for ``BatchNorm1d`` on a (0, 4, 0) input, 2 for ``BatchNorm2d``
            on a (0, 4, 0, 0) input.
        is_symbolic: None to call the module eagerly; otherwise wrap the call
            in ``jit.trace(symbolic=is_symbolic)``.
    """
    if dim == 1:
        m = BatchNorm1d(4, affine=True)
        inp = mge.tensor(np.random.randn(0, 4, 0).astype("float32"))
    elif dim == 2:
        m = BatchNorm2d(4, affine=True)
        inp = mge.tensor(np.random.randn(0, 4, 0, 0).astype("float32"))
    else:
        raise NotImplementedError
    m.train()

    def fn(inp):
        return m(inp)

    if is_symbolic is not None:
        fn = jit.trace(symbolic=is_symbolic)(fn)
    # Convert once, explicitly, rather than relying on implicit Tensor->ndarray
    # conversion inside assert_equal.
    expected = inp.numpy()
    # Traced functions are called repeatedly (presumably so that later calls
    # exercise the traced path, not just the first one); eager calls run once.
    nr_calls = 1 if is_symbolic is None else 3
    for _ in range(nr_calls):
        out = fn(inp)
        np.testing.assert_equal(out.numpy(), expected)