megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
from copy import deepcopy

from ..functional import ones, sqrt, zeros
from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU
from ..tensor import Parameter

# Lookup table used by fuse_conv_bn_relu_module to choose the fused module class.
# A key holds the types of the modules being fused, in (conv, bn, relu) order
# with absent modules left out. When a BatchNorm2d is part of the key, a
# trailing bool carries conv.training:
#   False (eval)  -> BN is folded into the conv parameters, no BN layer remains;
#   True  (train) -> the fused module keeps its own BN layer.
_MAP_TO_FUSED_MODULE = {
    (Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d,
    (Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d,
    (Conv2d, BatchNorm2d, False): Conv2d,
    (Conv2d, BatchNorm2d, True): ConvBn2d,
    # No BN involved, so no training flag in the key.
    (Conv2d, ReLU): ConvRelu2d,
}


def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5):
    """Fold batch-norm parameters into a convolution's weight and bias.

    The folded parameters are computed as::

        scale  = gamma / sqrt(bn_var + eps)
        w_fold = weight * scale      # scale applied per output channel
        b_fold = beta + (bias - bn_mean) * scale

    Args:
        weight: convolution kernel. A 5-D kernel is treated as grouped, with
            ``shape[0]`` groups and ``shape[1]`` output channels per group;
            any other rank is treated as ungrouped with ``shape[0]`` output
            channels.
        bias: convolution bias holding one value per output channel, or
            ``None`` to use zeros.
        gamma: batch-norm scale, or ``None`` to use ones.
        beta: batch-norm shift, or ``None`` to use zeros.
        bn_mean: batch-norm mean, or ``None`` to use zeros.
        bn_var: batch-norm variance, or ``None`` to use ones.
        eps: value added to ``bn_var`` before taking the square root.

    Returns:
        A ``(w_fold, b_fold)`` tuple. ``w_fold`` has the same shape as
        ``weight``; ``b_fold`` has shape ``(1, out_channels, 1, 1)``.
    """
    kernel_shape = weight.shape
    if len(kernel_shape) == 5:
        groups, channels_per_group = kernel_shape[0], kernel_shape[1]
    else:
        groups, channels_per_group = 1, kernel_shape[0]
    # Batch-norm statistics and the bias span every output channel of the
    # convolution, not just the channels of a single group.
    out_channels = groups * channels_per_group

    def _per_channel(value, fill):
        # Substitute the default when missing, then lay the values out along
        # the channel axis so all operands broadcast against each other.
        if value is None:
            value = fill((out_channels,), dtype="float32")
        return value.reshape(1, -1, 1, 1)

    gamma = _per_channel(gamma, ones)
    beta = _per_channel(beta, zeros)
    bn_mean = _per_channel(bn_mean, zeros)
    bn_var = _per_channel(bn_var, ones)
    bias = _per_channel(bias, zeros)

    bn_istd = 1.0 / sqrt(bn_var + eps)
    scale_factor = gamma * bn_istd

    # Align the per-channel scale with the output-channel axes of the kernel.
    if groups == 1:
        w_fold = weight * scale_factor.reshape(-1, 1, 1, 1)
    else:
        w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1)

    b_fold = beta + gamma * (bias - bn_mean) * bn_istd
    return w_fold, b_fold


def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
    """Build a single fused module equivalent to ``conv`` -> ``bn`` -> ``relu``.

    ``bn`` and ``relu`` may be omitted by passing ``None``. The fused class is
    looked up in ``_MAP_TO_FUSED_MODULE`` and constructed with the
    hyper-parameters of ``conv``.

    * When ``bn`` is given and ``conv`` is in eval mode, the batch-norm
      parameters are folded into the convolution weight and bias through
      ``fold_weight_bias``.
    * When ``bn`` is given and ``conv`` is in training mode, the original
      convolution parameters are copied to the fused module's ``conv``
      sub-module and a deep copy of ``bn`` is stored as its ``bn``.
    * Without ``bn``, the convolution parameters are copied unchanged.

    Args:
        conv: the convolution to fuse.
        bn: the batch-norm that follows ``conv``, or ``None``.
        relu: the ReLU that ends the sequence, or ``None``.

    Returns:
        The newly constructed fused module.

    Raises:
        AssertionError: if ``conv`` and ``bn`` disagree on training mode, or
            ``bn.num_features`` differs from ``conv.out_channels``.
        KeyError: if the combination of modules has no entry in
            ``_MAP_TO_FUSED_MODULE``.
    """
    training = conv.training

    lookup_key = tuple(type(layer) for layer in (conv, bn, relu) if layer)
    if bn:
        assert (
            training == bn.training
        ), "Conv and BN both must be in the same mode (train or eval)."
        assert (
            bn.num_features == conv.out_channels
        ), "Output channel of Conv2d must match num_features of BatchNorm2d"
        # With a BN in the sequence, the fused class depends on the mode.
        lookup_key += (training,)

    fused_cls = _MAP_TO_FUSED_MODULE[lookup_key]
    fused = fused_cls(
        in_channels=conv.in_channels,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        conv_mode=conv.conv_mode,
        compute_mode=conv.compute_mode,
        name=conv.name,
    )

    has_bn = bn is not None
    keeps_bn_layer = has_bn and training

    # Only the training-mode conv+bn modules wrap the convolution in a
    # sub-module; in every other case the fused module is the convolution.
    if keeps_bn_layer:
        target_conv = fused.conv
    else:
        target_conv = fused

    if has_bn and not training:
        new_weight, new_bias = fold_weight_bias(
            conv.weight,
            conv.bias,
            bn.weight,
            bn.bias,
            bn.running_mean,
            bn.running_var,
            bn.eps,
        )
    else:
        new_weight, new_bias = conv.weight, conv.bias

    target_conv.weight = Parameter(new_weight)
    if new_bias is not None:
        target_conv.bias = Parameter(new_bias)

    if keeps_bn_layer:
        fused.bn = deepcopy(bn)

    target_conv.training = training
    return fused