import numpy as np
from megengine.functional.tensor import zeros
from ..core.ops.builtin import BatchNorm
from .expr import CallMethod, Constant
from .node import TensorNode
from .serialization import (
register_functional_loader,
register_module_loader,
register_opdef_loader,
register_tensor_method_loader,
)
@register_module_loader(
    ("megengine.module.batchnorm", "BatchNorm1d"),
    ("megengine.module.batchnorm", "BatchNorm2d"),
    ("megengine.module.batchnorm", "SyncBatchNorm"),
)
def bn2d_module_loader(expr):
    """Backfill ``param_dim`` on a BatchNorm module loaded from an unversioned expr.

    Acts only when ``expr`` has no ``version`` attribute. In that case the
    module that owns the expr's first input gets ``param_dim = "dim_1c11"``,
    unless it already has a ``param_dim`` attribute. Exprs that carry a
    ``version`` are left untouched.

    NOTE(review): a missing ``version`` presumably means the expr was
    serialized by an older release; that is not visible in this file.
    """
    if hasattr(expr, "version"):
        return
    bn_module = expr.inputs[0].owner
    if hasattr(bn_module, "param_dim"):
        # Never overwrite an attribute that was actually deserialized.
        return
    bn_module.param_dim = "dim_1c11"
@register_module_loader(
    ("megengine.module.conv_bn", "ConvBn2d"),
    ("megengine.module.conv_bn", "ConvBnRelu2d"),
    ("megengine.module.qat.conv_bn", "ConvBn2d"),
    ("megengine.module.qat.conv_bn", "ConvBnRelu2d"),
)
def convbn2d_module_loader(expr):
    """Backfill attributes missing from a loaded fused Conv+BN module.

    The module is the owner of the expr's first input. It is expected to
    expose ``.bn`` and ``.conv`` sub-modules.

    * If ``expr`` has no ``version`` attribute and ``module.bn`` lacks
      ``param_dim``, set ``module.bn.param_dim = "dim_1c11"``.
    * Regardless of version, if ``module.conv`` lacks ``padding_mode``,
      set it to ``"zeros"``.
    """
    fused = expr.inputs[0].owner
    # ``and`` short-circuits, so ``fused.bn`` is only touched for unversioned exprs.
    if not hasattr(expr, "version") and not hasattr(fused.bn, "param_dim"):
        fused.bn.param_dim = "dim_1c11"
    if not hasattr(fused.conv, "padding_mode"):
        fused.conv.padding_mode = "zeros"
@register_opdef_loader(BatchNorm)
def bn_opdef_loader(expr):
    """Pad an unversioned BatchNorm op expr from five outputs to six.

    Acts only when ``expr`` has no ``version`` attribute. If the expr
    already has six outputs, nothing is done. If it has exactly five, a
    new ``TensorNode`` is inserted at index 4. The new node has
    ``shape=(0,)`` and ``dtype=None``, and it reuses the ``_qparams`` of
    the last existing output.

    Raises:
        AssertionError: if an unversioned expr has neither 5 nor 6 outputs.
            The check is an explicit ``raise`` rather than an ``assert``
            statement, so it still runs under ``python -O``.
    """
    if hasattr(expr, "version"):
        return
    num_outputs = len(expr.outputs)
    if num_outputs == 6:
        # Already in the expected layout.
        return
    if num_outputs != 5:
        # Inserting at index 4 is only correct for the 5-output layout;
        # refuse rather than silently corrupting the graph.
        raise AssertionError(
            "BatchNorm expr without version is expected to have 5 or 6 "
            "outputs, got {}".format(num_outputs)
        )
    last_output = expr.outputs[-1]
    placeholder = TensorNode(
        expr, shape=(0,), dtype=None, qparams=last_output._qparams
    )
    expr.outputs.insert(4, placeholder)
@register_functional_loader(
    ("megengine.functional.tensor", "ones"),
    ("megengine.functional.tensor", "zeros"),
)
def tensor_gen_func_loader(expr):
    """Rewrite a serialized ``ones``/``zeros`` call as ``(shape, dtype=..., device=...)``.

    * ``expr.version == "1.7.0"``: the first three positional arguments are
      used as shape, dtype and device. This branch indexes ``expr.args[0..2]``
      directly, so it assumes all three were recorded positionally.
    * No ``version`` attribute: each argument comes from its positional
      slot if present, otherwise from the keyword of the same name.
      ``dtype`` falls back to ``"float32"`` and ``device`` to ``None``;
      ``shape`` has no fallback (a ``KeyError`` is raised if it is absent).
    * Any other version: the expr is left untouched.
    """
    if hasattr(expr, "version"):
        if expr.version == "1.7.0":
            recorded = expr.args
            expr.set_args_kwargs(recorded[0], dtype=recorded[1], device=recorded[2])
        return

    def _resolve(position, name, fallback):
        # Positional slot first, then keyword of the same name, then fallback.
        if len(expr.args) > position:
            return expr.args[position]
        if name in expr.kwargs:
            return expr.kwargs[name]
        return fallback

    shape = expr.args[0] if len(expr.args) > 0 else expr.kwargs["shape"]
    expr.set_args_kwargs(
        shape,
        dtype=_resolve(1, "dtype", "float32"),
        device=_resolve(2, "device", None),
    )
@register_functional_loader(("megengine.functional.nn", "pad"))
def pad_func_loader(expr):
    """Rename the misspelled ``pad_witdth`` keyword of a ``pad`` call to ``pad_width``.

    If the keyword is absent, the expr is left untouched. Otherwise the
    value is moved to ``pad_width``, and the expr's arguments are re-set
    with the original positional args and the corrected keywords.

    NOTE(review): the misspelled key is presumably what an older serializer
    wrote; it must stay spelled exactly as it is to match those files.
    """
    legacy_key, fixed_key = "pad_witdth", "pad_width"
    call_kwargs = expr.kwargs
    if legacy_key not in call_kwargs:
        return
    call_kwargs[fixed_key] = call_kwargs.pop(legacy_key)
    expr.set_args_kwargs(*expr.args, **call_kwargs)
@register_module_loader(
    ("megengine.module.conv", "Conv1d"),
    ("megengine.module.conv", "Conv2d"),
    ("megengine.module.conv", "ConvRelu2d"),
    ("megengine.module.qat.conv", "Conv2d"),
    ("megengine.module.qat.conv", "ConvRelu2d"),
    ("megengine.module.quantized.conv", "Conv2d"),
    ("megengine.module.quantized.conv", "ConvRelu2d"),
)
def conv2d_module_loader(expr):
    """Default ``padding_mode`` to ``"zeros"`` on a loaded conv module that lacks it.

    The module is the owner of the expr's first input. An existing
    ``padding_mode`` attribute is never overwritten, and the expr's
    version is not consulted.
    """
    conv_module = expr.inputs[0].owner
    if hasattr(conv_module, "padding_mode"):
        return
    conv_module.padding_mode = "zeros"
@register_module_loader(
    ("megengine.module.quantized.conv_bn", "ConvBn2d"),
    ("megengine.module.quantized.conv_bn", "ConvBnRelu2d"),
)
def quantized_convbn2d_module_loader(expr):
    """Default ``padding_mode`` to ``"zeros"`` on a loaded quantized ConvBn module.

    The module is the owner of the expr's first input. Unlike the float
    and QAT ConvBn loaders, the attribute is set on the module itself, not
    on a ``.conv`` sub-module. An existing value is never overwritten.
    """
    owner = expr.inputs[0].owner
    needs_default = not hasattr(owner, "padding_mode")
    if needs_default:
        setattr(owner, "padding_mode", "zeros")