# megenginelite-sys 1.8.2
#
# A safe megenginelite wrapper in Rust
# Documentation
import types
from functools import partial

from .. import functional as F
from .. import module as M
from ..utils.module_utils import set_module_mode_safe


def get_norm_mod_value(weight, norm_value):
    """Return a power-of-two factor that scales ``weight`` toward ``norm_value``.

    The exact factor ``norm_value / F.norm(weight.reshape(-1))`` is rounded
    down to a power of two, i.e. ``2 ** floor(log2(norm_value / norm))``. Up
    to floating-point rounding, the result therefore never exceeds the exact
    factor.

    Args:
        weight: parameter tensor; it is flattened before its norm is taken.
        norm_value: the target norm.

    Returns:
        The power-of-two scale, detached so no gradient flows through it.

    NOTE(review): an all-zero ``weight`` gives a zero norm, so the ratio (and
    presumably the result) is non-finite -- confirm callers never pass one.
    """
    weight_norm = F.norm(weight.reshape(-1))
    # log2 via natural logs; floor selects the largest power of two <= ratio.
    exponent = F.floor(F.log(norm_value / weight_norm) / F.log(2))
    return (2 ** exponent).detach()


def get_scaled_model(model, scale_submodel, input_shape=None):
    """Patch Conv2d/Linear submodules of ``model`` to scale their parameters in training.

    A dummy all-zeros input of ``input_shape`` is run once through ``model``
    inside ``set_module_mode_safe(model, training=False)``. During that run a
    forward pre-hook is attached to every submodule. The hooks fire in
    execution order and choose which Conv2d/Linear modules get scaled:

    * A Conv2d/Linear whose name is a key of ``scale_submodel`` is scaled on
      its own, immediately.
    * Any other module whose name is a key opens a group. Every Conv2d/Linear
      executed afterwards is collected into the group until the next
      BatchNorm2d runs. At that point the whole group is scaled and the
      running bias scale is reset to 1.0.

    Scaling a module does two things:

    * It sets ``mod.weight_scale`` (this module's weight factor) and
      ``mod.bias_scale`` (the running product of weight factors so far).
    * It wraps ``calc_conv`` (Conv2d) or ``_calc_linear`` (Linear) so that,
      only while ``mod.training`` is true, weight and bias are multiplied by
      those factors before the original computation.

    NOTE(review): the running product is only reset after a BatchNorm2d closes
    a group, so a standalone module's factor carries into later modules'
    ``bias_scale`` -- confirm this is intended.

    Args:
        model: module to patch; its submodules are modified in place.
        scale_submodel: mapping from submodule name (as produced by
            ``model.named_modules()``) to a ``(mode, value)`` pair. With
            ``mode == "CONST"``, ``value`` is used directly as the weight
            factor. Otherwise the factor is
            ``get_norm_mod_value(mod.weight, value)``.
        input_shape: shape of the dummy input for the tracing pass. Required
            despite the ``None`` default.

    Returns:
        The object yielded by ``set_module_mode_safe`` (presumably ``model``
        itself -- TODO confirm).

    Raises:
        ValueError: if ``input_shape`` is None.
    """
    if input_shape is None:
        raise ValueError("input_shape is required for calculating scale value")

    # State shared by the hooks while the tracing pass runs.
    submodule_list = None  # modules collected for the currently open group
    scale_value = None  # (mode, value) pair of the currently open group
    accumulated_scale = 1.0  # running product of weight factors (bias factor)

    def scale_calc(mod_calc_func):
        # Wrap a bound calc function: in training mode only, multiply weight
        # and bias by the module's stored factors before delegating.
        def calcfun(self, inp, weight, bias):
            scaled_weight = weight
            scaled_bias = bias
            if self.training:
                scaled_weight = (
                    weight * self.weight_scale if weight is not None else None
                )
                scaled_bias = bias * self.bias_scale if bias is not None else None
            return mod_calc_func(inp, scaled_weight, scaled_bias)

        return calcfun

    def scale_module_structure(
        scale_list: list = None, scale_value: tuple = None,
    ):
        # Assign factors to each (name, module) pair in order and patch the
        # module's calc function.
        nonlocal accumulated_scale
        for _name, mod in scale_list:
            w_scale_value = scale_value[1]
            # Compare by value: `is not` against a literal depends on string
            # interning and misfires for runtime-built "CONST" strings.
            if scale_value[0] != "CONST":
                w_scale_value = get_norm_mod_value(mod.weight, scale_value[1])

            accumulated_scale *= w_scale_value

            mod.weight_scale = w_scale_value
            mod.bias_scale = accumulated_scale

            if isinstance(mod, M.conv.Conv2d):
                mod.calc_conv = types.MethodType(scale_calc(mod.calc_conv), mod)
            else:
                mod._calc_linear = types.MethodType(scale_calc(mod._calc_linear), mod)

    def forward_hook(submodel, inputs, outputs=None, modelname=""):
        # Registered as a forward *pre*-hook, so `outputs` is never supplied.
        nonlocal submodule_list
        nonlocal scale_value
        nonlocal accumulated_scale
        if modelname in scale_submodel:
            scale_value = scale_submodel[modelname]
            if isinstance(submodel, (M.conv.Conv2d, M.linear.Linear)):
                scale_module_structure([(modelname, submodel)], scale_value)
            else:
                # Open a group: collect the Conv2d/Linear modules that follow.
                submodule_list = []

        if isinstance(submodel, (M.conv.Conv2d, M.linear.Linear)) and (
            submodule_list is not None
        ):
            submodule_list.append((modelname, submodel))

        if isinstance(submodel, M.batchnorm.BatchNorm2d) and (
            submodule_list is not None
        ):
            # A BatchNorm2d closes the open group: scale it and reset state.
            scale_module_structure(submodule_list, scale_value)
            submodule_list = None
            scale_value = None
            accumulated_scale = 1.0

    dummy_input = F.zeros(input_shape)

    hooks = []
    try:
        for modelname, submodel in model.named_modules():
            hooks.append(
                submodel.register_forward_pre_hook(
                    partial(forward_hook, modelname=modelname)
                )
            )

        with set_module_mode_safe(model, training=False) as model:
            model(dummy_input)
    finally:
        # Always detach the tracing hooks, even if the forward pass fails.
        for hook in hooks:
            hook.remove()

    return model