from typing import Optional
import numpy as np
from ..distributed.group import WORLD, Group
from ..functional.nn import batch_norm, sync_batch_norm
from ..tensor import Parameter, Tensor
from . import init
from .module import Module
class _BatchNorm(Module):
    """Common base for the batch normalization modules in this file.

    Subclasses implement :meth:`_check_input_ndim` to reject inputs of an
    unsupported rank. Parameters, running statistics and the forward
    computation all live here.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        compute_mode="default",
        param_dim="dim_1c11",
        **kwargs
    ):
        """Create the affine parameters and, optionally, running statistics.

        Args:
            num_features: number of channels ``C``. Weight, bias and running
                statistics are all allocated with shape ``(1, C, 1, 1)``.
            eps: constant added to the running variance in the freeze fast
                path; also forwarded to ``batch_norm``.
            momentum: forwarded to ``batch_norm`` while training with
                ``track_running_stats`` enabled; ``0.0`` is forwarded
                otherwise. The exact running-stat update rule is defined by
                ``batch_norm``.
            affine: if True, ``weight`` starts as ones and ``bias`` as zeros
                (both ``Parameter``); if False both are ``None``.
            track_running_stats: if True, ``running_mean`` starts as zeros and
                ``running_var`` as ones; if False both are ``None``. The
                initial value is remembered so ``forward`` can reject turning
                tracking on later.
            freeze: if True, ``forward`` detaches ``weight``/``bias`` and
                normalizes with the running statistics only. Requires
                ``track_running_stats=True``.
            compute_mode: forwarded unchanged to ``batch_norm``.
            param_dim: forwarded unchanged to ``batch_norm``.
                NOTE(review): parameters are still allocated as
                ``(1, C, 1, 1)`` whatever this value is; confirm
                ``batch_norm`` handles other layouts.
            **kwargs: forwarded to :class:`Module`.

        Raises:
            AssertionError: if ``freeze`` is True but ``track_running_stats``
                is False.
        """
        super(_BatchNorm, self).__init__(**kwargs)
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # Snapshot of the constructor value: running statistics are only
        # allocated here, so forward() refuses a later False -> True switch.
        self._track_running_stats_saved = track_running_stats
        self.freeze = freeze
        self.compute_mode = compute_mode
        self.param_dim = param_dim
        if self.freeze:
            assert (
                self._track_running_stats_saved
            ), "track_running_stats must be initilized to True if freeze is True"

        tshape = (1, self.num_features, 1, 1)

        if self.affine:
            self.weight = Parameter(np.ones(tshape, dtype=np.float32))
            self.bias = Parameter(np.zeros(tshape, dtype=np.float32))
        else:
            self.weight = None
            self.bias = None

        if self.track_running_stats:
            self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
            self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
        else:
            self.running_mean = None
            self.running_var = None

    def reset_running_stats(self) -> None:
        """Reset ``running_mean`` to zeros and ``running_var`` to ones, if tracked."""
        if self.track_running_stats:
            init.zeros_(self.running_mean)
            init.ones_(self.running_var)

    def reset_parameters(self) -> None:
        """Reset running statistics and, when affine, ``weight``/``bias`` to ones/zeros."""
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_ndim(self, inp):
        """Validate the rank of ``inp``; must be implemented by subclasses."""
        raise NotImplementedError

    def forward(self, inp):
        """Apply batch normalization to ``inp``.

        In freeze mode, the running statistics and the (detached) affine
        parameters are folded into one per-channel scale and shift. These are
        broadcast along axis 1 of ``inp``. Otherwise the computation is
        delegated to ``batch_norm``. That call runs in training mode when the
        module is training, or when no running statistics were allocated.

        Raises:
            AssertionError: if ``track_running_stats`` was False at
                construction but has since been set to True.
        """
        self._check_input_ndim(inp)
        if not self._track_running_stats_saved:
            assert (
                not self.track_running_stats
            ), "track_running_stats can not be initilized to False and changed to True later"

        _weight = self.weight
        _bias = self.bias

        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()

            # Fast path: y = (x - mean) / sqrt(var + eps) * w + b, rewritten
            # as x * scale + shift with per-channel scale and shift.
            scale = (self.running_var + self.eps) ** (-0.5)
            if _weight is not None:
                scale *= _weight
            bias = -self.running_mean * scale
            if _bias is not None:
                bias += _bias

            ndim = len(inp.shape)
            if ndim != 4:
                # scale/bias are stored as (1, C, 1, 1). Broadcasting that
                # against e.g. (N, C) or (N, C, L) would misalign the channel
                # axis, so keep C at axis 1 and match the input's rank.
                pshape = (1, self.num_features) + (1,) * (ndim - 2)
                scale = scale.reshape(pshape)
                bias = bias.reshape(pshape)
            return inp * scale + bias

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            # Running statistics are not meant to be updated in this case;
            # presumably batch_norm ignores momentum then — confirm there.
            exponential_average_factor = 0.0

        output = batch_norm(
            inp,
            self.running_mean if self.track_running_stats else None,
            self.running_var if self.track_running_stats else None,
            _weight,
            _bias,
            training=self.training
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
            compute_mode=self.compute_mode,
            param_dim=self.param_dim,
        )

        return output

    def _module_info_string(self) -> str:
        """Summarize the constructor configuration for the module repr."""
        s = (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}"
        )
        return s.format(**self.__dict__)
class SyncBatchNorm(_BatchNorm):
    """Batch normalization computed through ``sync_batch_norm``.

    Accepts 2D, 3D and 4D inputs. 2D and 3D inputs are padded with trailing
    singleton axes to 4D before normalization. The output is then reshaped
    back to the original input shape.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        group: Optional[Group] = WORLD,
        **kwargs
    ) -> None:
        """Create the module.

        Args:
            num_features, eps, momentum, affine, track_running_stats, freeze:
                passed positionally to ``_BatchNorm.__init__``.
            group: forwarded to ``sync_batch_norm``. Presumably this is the
                distributed group across which statistics are synchronized
                (default ``WORLD``) — confirm in ``sync_batch_norm``.
            **kwargs: forwarded to ``_BatchNorm.__init__``. Note that
                ``compute_mode`` and ``param_dim`` are stored there but are
                not passed to ``sync_batch_norm`` by ``forward``.
        """
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs
        )
        self.group = group

    def _check_input_ndim(self, inp):
        """Raise ``ValueError`` unless ``inp`` is 2D, 3D or 4D."""
        if len(inp.shape) not in {2, 3, 4}:
            raise ValueError(
                "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
            )

    def forward(self, inp):
        """Normalize ``inp`` with ``sync_batch_norm``.

        In freeze mode, ``weight``/``bias`` are detached and
        ``sync_batch_norm`` is called with ``training=False``, unless no
        running statistics exist.

        Raises:
            ValueError: if ``inp`` is not 2D, 3D or 4D.
            AssertionError: if ``track_running_stats`` was False at
                construction but has since been set to True. This is the
                same rule the base-class ``forward`` enforces.
        """
        self._check_input_ndim(inp)
        # Running statistics are only allocated at construction, so tracking
        # cannot be switched on afterwards.
        if not self._track_running_stats_saved:
            assert (
                not self.track_running_stats
            ), "track_running_stats can not be initilized to False and changed to True later"

        inp_shape = inp.shape
        _ndims = len(inp_shape)
        if _ndims != 4:
            # _check_input_ndim guarantees a 2D or 3D input here. Keep its
            # axes and pad the tail with ones to reach 4D. A Tensor is used
            # for the target shape, as in the original (inp.shape may be
            # symbolic).
            new_shape = Tensor([1, 1, 1, 1], device=inp.device)
            origin_shape = inp_shape
            new_shape[:_ndims] = origin_shape[:_ndims]
            inp = inp.reshape(new_shape)

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            # NOTE(review): unlike the base-class forward, running statistics
            # are still passed below after tracking has been switched off.
            # Confirm sync_batch_norm does not update them with 0.0 momentum.
            exponential_average_factor = 0.0

        _weight = self.weight
        _bias = self.bias

        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()

        output = sync_batch_norm(
            inp,
            self.running_mean,
            self.running_var,
            _weight,
            _bias,
            training=(self.training and not self.freeze)
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
            group=self.group,
        )

        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output
class BatchNorm1d(_BatchNorm):
    """Batch normalization that accepts only 2D or 3D inputs."""

    def _check_input_ndim(self, inp):
        """Raise ``ValueError`` unless ``inp`` has exactly 2 or 3 dimensions."""
        rank = len(inp.shape)
        if rank in (2, 3):
            return
        raise ValueError("expected 2D or 3D input (got {}D input)".format(rank))
class BatchNorm2d(_BatchNorm):
    """Batch normalization that accepts only 4D inputs."""

    def _check_input_ndim(self, inp):
        """Raise ``ValueError`` unless ``inp`` has exactly 4 dimensions."""
        rank = len(inp.shape)
        if rank == 4:
            return
        raise ValueError("expected 4D input (got {}D input)".format(rank))