megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import copy
from abc import abstractmethod
from collections import OrderedDict, namedtuple
from functools import partial
from re import T
from typing import Any, Callable, Dict, Iterable, List, Union

from ...logger import get_logger
from ..expr import Expr
from ..traced_module import InternalGraph, TracedModule
from .utils import register_obj

# Module-level logger; used in this file to warn when a pass is skipped.
logger = get_logger(__name__)


class PassContext:
    """Settings consulted while optimization passes are applied.

    The context keeps a set of pass names that must not run, plus a
    caller-supplied configuration object that is stored as-is.
    """

    def __init__(
        self, disabled_pass: Iterable[str] = None, pass_config: Dict[str, Any] = None
    ):
        """Create a context.

        Args:
            disabled_pass: a pass name, or an iterable of pass names, to disable.
                A falsy value (``None``, empty iterable) disables nothing.
            pass_config: arbitrary configuration; stored without inspection.
        """
        self._handle = None
        self._config = pass_config
        self._disabled_pass = set()
        if disabled_pass:
            self.add_diabled_pass(disabled_pass)

    def add_diabled_pass(self, passes: Iterable[str]):
        """Mark one pass name (a plain ``str``) or several names as disabled."""
        # A bare string is a single name, not an iterable of characters.
        names = [passes] if isinstance(passes, str) else passes
        self._disabled_pass.update(names)

    def pass_enabled(self, pas: Union["BasePass", str]):
        """Return True unless the given pass (instance or name) is disabled."""
        if isinstance(pas, BasePass):
            key = pas.name
        else:
            key = pas
        return key not in self._disabled_pass


# Single PassContext created at import time with nothing disabled; it is also
# what BasePass.__call__ binds as the default value of its ``pass_ctx`` argument.
_default_context = PassContext()


def get_default_pass_context():
    """Return the module-level default ``PassContext`` (same object on every call)."""
    return _default_context


# Registry of passes, looked up by name in get_registered_pass.
_pass_dict = OrderedDict()
# ``register_obj`` with its ``_dict`` argument pre-bound to the pass registry.
# NOTE(review): the exact calling convention (decorator vs. direct call) is
# defined by ``register_obj`` in ``.utils`` — confirm there.
register_pass = partial(register_obj, _dict=_pass_dict)


def get_registered_pass(pass_name: str):
    """Look up a registered pass by name.

    Args:
        pass_name: key under which the pass was stored in the pass registry.

    Returns:
        The object stored in ``_pass_dict`` under ``pass_name``.

    Raises:
        AssertionError: if nothing is registered under ``pass_name``.
    """
    pas = _pass_dict.get(pass_name)
    if pas is None:
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``; the exception type stays AssertionError so existing
        # callers observe the same failure as before.
        raise AssertionError(
            "{} is not found, please call `register_pass` to register it".format(
                pass_name
            )
        )
    return pas


class BasePass:
    """Base class for optimization passes applied to a ``TracedModule``.

    Subclasses configure themselves through the class attributes below and
    implement :meth:`visit_graph`; :meth:`run_transform` and
    :meth:`before_visit_graph` are optional hooks with no-op defaults.

    NOTE: ``visit_graph`` is decorated with ``abstractmethod`` but this class
    does not use ``ABCMeta``, so instantiation is not blocked; calling the
    un-overridden method raises ``NotImplementedError``.
    """

    # When True, visit_graph is run a single time; otherwise it is repeated
    # for as long as it reports that the graph changed (up to 100 runs).
    run_once = True  # bool
    # Names of registered passes that are applied before this one.
    required_pass = []  # Iterable[str]
    # Name used to check whether this pass is disabled in a PassContext.
    name = ""  # str

    def __init__(self):
        super().__init__()

    def __call__(
        self, mod: TracedModule, pass_ctx: PassContext = get_default_pass_context()
    ) -> TracedModule:
        """Apply this pass to ``mod``; shorthand for :meth:`apply_optimization`.

        Args:
            mod: module to optimize.
            pass_ctx: context deciding which passes are enabled. Defaults to
                the module-level default context.

        Returns:
            The module returned by :meth:`apply_optimization`.
        """
        assert isinstance(pass_ctx, PassContext)
        return self.apply_optimization(mod, pass_ctx)

    def apply_optimization(
        self, mod: TracedModule, pass_ctx: PassContext
    ) -> TracedModule:
        """Run the required passes and then this pass on ``mod``.

        If this pass or any of its required passes is disabled in
        ``pass_ctx``, a warning is logged and ``mod`` is returned untouched.

        Args:
            mod: module to optimize.
            pass_ctx: context queried through ``pass_enabled``.

        Returns:
            The module produced by the required passes, whose graph has then
            been visited by this pass.

        Raises:
            AssertionError: if the graph still changes after 100 visits.
        """
        # Normalise once so any iterable (list, tuple, ...) of names works;
        # the previous ``required_pass + [name]`` only accepted lists.
        required = list(self.required_pass)

        new_mod = mod
        for pass_name in required + [self.name]:
            if not pass_ctx.pass_enabled(pass_name):
                logger.warning(
                    "Since {} is disabled, {} will skipped".format(pass_name, self.name)
                )
                return mod

        # Each required pass is instantiated from the registry and chained.
        for pass_name in required:
            pass_func = get_registered_pass(pass_name)()
            new_mod = pass_func(new_mod, pass_ctx)

        iter_num = 1
        graph_changed = self.visit_graph(new_mod.graph)
        while not self.run_once and graph_changed:
            graph_changed = self.visit_graph(new_mod.graph)
            iter_num += 1
            if iter_num == 100:
                break
        # Reaching the cap means the graph never settled; report which pass
        # was responsible (the message was previously left unformatted).
        assert (
            iter_num < 100
        ), "{} was run 100 times, please check for pass conflict.".format(self.name)

        return new_mod

    @abstractmethod
    def visit_graph(self, graph: InternalGraph):
        """Traverse ``graph`` and apply the pass.

        Implementations return a truthy value when the graph was changed;
        :meth:`apply_optimization` uses it to decide whether to visit again.
        """
        raise NotImplementedError

    def before_visit_graph(self, graph: InternalGraph):
        """Hook for subclasses; the default does nothing."""
        pass

    def run_transform(self, expr: Expr) -> Expr:
        """Transform a single expression; the default returns it unchanged.

        The traversals in this module treat a returned object that is not
        ``expr`` itself as a modification of the graph.
        """
        return expr

    def __repr__(self) -> str:
        return self.name


class ForwardPass(BasePass):
    """Pass whose traversal hands an expression to ``run_transform`` only after
    the expressions producing its inputs have been pushed and dealt with."""

    def visit_graph(self, graph: InternalGraph):
        """Walk ``graph`` from its outputs with an explicit stack.

        Each expression is encountered twice: first to push the expressions
        behind its inputs, then (once those are handled) to be transformed.
        Nested graphs exposed through an expression's ``graph`` attribute are
        visited recursively, at most once per call.

        Args:
            graph: the graph to traverse.

        Returns:
            Truthy if ``run_transform`` returned a different expression for
            any visited expression, here or in a nested graph.
        """

        class Item:
            """Stack entry: an expression plus whether its inputs were pushed."""

            def __init__(self, expr: Expr, child_expanded: bool = False):
                self.expr = expr
                self.child_expanded = child_expanded

        self.before_visit_graph(graph)
        graph_changed = False
        # ``queue`` is used as a stack: peek at the end, append to the end.
        queue = [Item(n.expr) for n in graph.outputs]
        visited_expr, visited_graph = set(), set()
        while queue:
            item = queue[-1]
            if item.expr in visited_expr:
                # Already fully handled (possibly via another stack entry).
                queue.pop()
            elif item.child_expanded:
                # Second encounter: the inputs were pushed above this entry earlier.
                if item.expr not in graph._exprs:
                    # Expression is not in the graph's expression list; skip it.
                    queue.pop()
                    continue
                new_expr = self.run_transform(item.expr)
                if new_expr is not item.expr:
                    # A different object came back: record the change and
                    # schedule the new expression; the old entry stays on
                    # the stack and is re-examined afterwards.
                    graph_changed = True
                    assert new_expr not in visited_expr
                    queue.append(Item(new_expr))
                    continue
                if (
                    hasattr(item.expr, "graph")
                    and item.expr.graph is not None
                    and item.expr.graph not in visited_graph
                ):
                    graph_changed |= self.visit_graph(item.expr.graph)
                    visited_graph.add(item.expr.graph)
                # Marked visited here; the entry itself is popped on the next
                # loop iteration by the first branch.
                visited_expr.add(item.expr)
            else:
                # First encounter: push the expressions behind each input.
                item.child_expanded = True
                for i in item.expr.inputs:
                    expr = i.expr
                    # NOTE(review): ``queue`` holds Item wrappers, so
                    # ``expr not in queue`` compares an Expr with Item objects;
                    # unless Expr defines a matching __eq__ this is always
                    # true and only ``visited_expr`` filters here — confirm
                    # before changing, the traversal order depends on it.
                    if expr not in queue and expr not in visited_expr:
                        queue.append(Item(expr))
        return graph_changed


class BackwardPass(BasePass):
    """Pass that walks from the graph outputs towards the inputs, transforming
    an expression before the expressions that produce its inputs."""

    def visit_graph(self, graph: InternalGraph):
        """Traverse ``graph`` starting at its outputs and apply ``run_transform``.

        Args:
            graph: the graph to traverse.

        Returns:
            Truthy if ``run_transform`` returned a different expression for
            any visited expression, here or in a nested graph.
        """
        self.before_visit_graph(graph)
        changed = False
        # Work list used as a stack; seeded with the output-producing exprs.
        pending = [out.expr for out in graph.outputs]
        seen_exprs, seen_graphs = set(), set()
        while pending:
            current = pending.pop()
            if current not in graph._exprs:
                continue

            replacement = self.run_transform(current)
            if replacement is not current:
                # A different expression came back: process it next instead.
                changed = True
                pending.append(replacement)
                continue
            seen_exprs.add(current)

            # Recurse into a nested graph at most once per call.
            has_unvisited_graph = (
                hasattr(current, "graph")
                and current.graph is not None
                and current.graph not in seen_graphs
            )
            if has_unvisited_graph:
                changed |= self.visit_graph(current.graph)
                seen_graphs.add(current.graph)

            # Schedule producers of the inputs unless already queued or done.
            for node in current.inputs:
                producer = node.expr
                if producer in pending or producer in seen_exprs:
                    continue
                pending.append(producer)
        return changed