megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from copy import deepcopy
from typing import List, Optional, Set, Union

from ...logger import get_logger
from ..traced_module import TracedModule
from .pass_base import get_default_pass_context, get_registered_pass

# Module-level logger; ``optimize`` below uses it to emit warnings about the
# requested optimization passes.
logger = get_logger(__name__)


def optimize(
    module: TracedModule, enabled_pass: Optional[Union[str, List[str]]] = None,
) -> TracedModule:
    r"""Performs a set of optimization passes to optimize a `TracedModule` for inference.

    The following passes are currently supported:

        * FuseConvBn: fuse BN layers into conv2d
        * FuseAddMul: fold adjacent const add or mul binary operations
        * BackwardFoldScale: backward fold const scaling into weights of conv2d

    ``BackwardFoldScale`` requires ``FuseConvBn``: if it is requested without
    ``FuseConvBn``, a warning is logged and ``FuseConvBn`` is enabled too.
    When ``BackwardFoldScale`` is enabled, ``FuseAddMul`` (if also enabled) is
    scheduled a second time, after ``BackwardFoldScale``. Passes run in the
    fixed order above, regardless of the order given in ``enabled_pass``.

    Args:
        module: the :class:`TracedModule` to be optimized. It is deep-copied
            first, so the module passed in is not modified.
        enabled_pass: optimization passes to be enabled during optimization.
            A single pass name may be given as a ``str``; ``None`` selects
            the default. The caller's list is never modified.
            Default: ["FuseConvBn"]

    Returns:
        the optimized :class:`TracedModule`.
    """
    # Normalize into a fresh list owned by this call. The dependency handling
    # below appends to it, and that must not leak into the caller's list (or
    # into a shared default value).
    if enabled_pass is None:
        enabled_pass = ["FuseConvBn"]
    elif isinstance(enabled_pass, str):
        enabled_pass = [enabled_pass]
    else:
        enabled_pass = list(enabled_pass)

    supported_passes = {"FuseConvBn", "FuseAddMul", "BackwardFoldScale"}
    for name in enabled_pass:
        if name not in supported_passes:
            # Such names are never scheduled below; make that visible instead
            # of silently dropping them (e.g. a misspelled pass name).
            logger.warning(
                "Unknown optimization pass {!r} will be ignored.".format(name)
            )

    # Execution order of passes; a pass only runs if it is in enabled_pass.
    pass_order = ["FuseConvBn", "FuseAddMul"]

    if "BackwardFoldScale" in enabled_pass:
        if "FuseConvBn" not in enabled_pass:
            logger.warning(
                "Since BackwardFoldScale requires FuseConvBn"
                ", FuseConvBn will be enabled."
            )
            enabled_pass.append("FuseConvBn")
        # FuseAddMul is scheduled again after BackwardFoldScale.
        pass_order.extend(["BackwardFoldScale", "FuseAddMul"])

    pass_ctx = get_default_pass_context()

    module = deepcopy(module)
    for pass_name in pass_order:
        if pass_name in enabled_pass:
            # A fresh pass object is created for every scheduled run.
            pass_func = get_registered_pass(pass_name)()
            module = pass_func(module, pass_ctx)

    return module