megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ... import functional as F
from ... import module as M
from ...core.ops.builtin import GetVarShape
from ...logger import get_logger
from ...tensor import Tensor
from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr
from ..node import Node, NodeMixin, TensorNode
from .matcher import PatternMatcher
from .pass_base import BackwardPass, ForwardPass, register_pass
from .pattern import is_op
from .utils import get_const_value

# Logger named after this module.
logger = get_logger(__name__)


def _as_const_node(x):
    """Wrap ``x`` in a new :class:`~.Constant` expr and return its output node.

    ``NodeMixin.wrap`` associates ``x`` with the created node so that later
    traced operations on ``x`` refer to that node.
    """
    const = Constant.make(x)
    NodeMixin.wrap(x, const)
    return const


@register_pass("AttrToConstant")
class AttrToConstant(BackwardPass):
    r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr."""

    name = "AttrToConstant"
    run_once = True

    def run_transform(self, expr: Expr):
        # Only GetAttr exprs whose output is a TensorNode are rewritten;
        # any other expr (or a GetAttr with a non-tensor output) is kept.
        is_tensor_attr = is_getattr(expr) and isinstance(expr.outputs[0], TensorNode)
        if not is_tensor_attr:
            return expr

        attr_node = expr.outputs[0]
        attr_name = attr_node.name
        graph = expr.top_graph
        attr_value = get_const_value(expr)

        # Materialize the attribute value as a Constant placed at this expr,
        # then redirect every user of the old node to the constant.
        with graph.insert_exprs(expr):
            replacement = _as_const_node(attr_value)
        graph.replace_node({attr_node: replacement})
        graph.compile()

        # Reuse the original node name for the new constant.
        replacement.name = attr_name
        return replacement.expr


@register_pass("FixInputShape")
class FixInputShape(BackwardPass):
    r"""Replace :class:`~.GetVarShape` results with a constant shape tensor.

    The shape stored on the input node (``inputs[0].shape``) is frozen into an
    int32 :class:`~.Constant`, so the graph no longer queries the shape at
    run time.
    """

    name = "FixInputShape"
    run_once = True

    def run_transform(self, expr: Expr):
        if is_apply_def(expr, GetVarShape):
            static_shape = Tensor(expr.inputs[0].shape, dtype="int32")
            graph = expr.top_graph
            shape_node = expr.outputs[0]

            with graph.insert_exprs(expr):
                const_shape = _as_const_node(static_shape)
            graph.replace_node({shape_node: const_shape})
            graph.compile()

            # Reuse the name of the replaced shape node.
            const_shape.name = shape_node.name
            return const_shape.expr
        return expr


@register_pass("FlodConstant")
class FlodConstant(ForwardPass):
    r"""Constant folding.

    An expr whose inputs are all produced by :class:`~.Constant` exprs is
    evaluated eagerly, and its first output is replaced by a new constant.
    (The registered name keeps its historical spelling because other passes
    list it in ``required_pass``.)
    """

    name = "FlodConstant"
    required_pass = ["AttrToConstant"]
    run_once = False

    def run_transform(self, expr: Expr):
        inputs = expr.inputs
        foldable = bool(inputs) and all(is_constant(node.expr) for node in inputs)
        if not foldable:
            return expr

        values = [get_const_value(node.expr) for node in inputs]
        # NOTE(review): only the first output of ``interpret`` is folded;
        # exprs with several outputs keep their remaining outputs.
        folded = expr.interpret(*values)[0]
        graph = expr.top_graph
        out_node = expr.outputs[0]

        with graph.insert_exprs(expr):
            const_node = _as_const_node(folded)
        graph.replace_node({out_node: const_node})
        graph.compile()

        # Reuse the name of the replaced output node.
        const_node.name = out_node.name
        return const_node.expr


@register_pass("NormElemWise")
class NormElemWise(BackwardPass):
    r"""Transform add/sub or mul/div expr to add-only or mul-only chains.

    Subtraction is rewritten as addition of a negated operand, division as
    multiplication by a ``-1`` power, and a constant operand is placed on the
    right-hand side.

    For example, the following code

    .. code-block::

        b = 1 - a
        c = 2 * b
        d = 1 / c

    will be changed to

    .. code-block::

        a1 = F.neg(a)
        b = a1 + 1
        c = b * 2
        c1 = F.pow(c, -1)
        d = c1 * 1
    """

    name = "NormElemWise"
    required_pass = ["FlodConstant"]
    run_once = False

    # Each group: the functional op followed by the Tensor dunder methods
    # that perform the same arithmetic operation.
    _ADD_OPS = (F.add, "__add__", "__iadd__", "__radd__")
    _SUB_OPS = (F.sub, "__sub__", "__isub__", "__rsub__")
    _MUL_OPS = (F.mul, "__mul__", "__imul__", "__rmul__")
    _DIV_OPS = (F.div, "__truediv__", "__itruediv__", "__rtruediv__")
    # Reflected non-commutative methods: the receiver (``inputs[0]``) is the
    # right-hand operand of the written expression.
    _REFLECTED = ("__rsub__", "__rtruediv__")
    # ``tensor + x`` / ``tensor * x`` through these methods needs no rewrite
    # unless the receiver itself is a constant.
    _NORMAL_FORM = ("__add__", "__mul__")

    def __init__(self):
        super().__init__()
        groups = (self._ADD_OPS, self._SUB_OPS, self._MUL_OPS, self._DIV_OPS)
        # Functional ops first, then all dunder methods, group by group.
        ops = [group[0] for group in groups]
        ops += [op for group in groups for op in group[1:]]
        self.pattern = is_op(ops[0])
        for op in ops[1:]:
            self.pattern |= is_op(op)

    def run_transform(self, expr: Expr):
        """Rewrite one matched add/sub/mul/div expr; return the new expr.

        Returns ``expr`` unchanged when it does not match or is already in
        normal form.
        """
        matcher = PatternMatcher()
        if not matcher.match(self.pattern, expr):
            return expr

        target = matcher.matched_patterns[0].target
        # ``coeff == -1`` means the tensor ``left_node`` is the operand that
        # must be negated (sub) or raised to -1 (div), e.g. ``c - a`` becomes
        # ``F.neg(a) + c``; otherwise ``right_node`` is the one transformed.
        coeff, left_node, right_node = 1, None, None
        if len(expr.inputs) == 1 and target not in self._NORMAL_FORM:
            # One tensor input; the other operand is a constant argument
            # recorded on the expr.
            left_node = expr.inputs[0]
            right_node = expr.const_val[0][-1]
            if target in self._REFLECTED:
                coeff = -1
            if target in (F.sub, F.div) and left_node is not expr.kwargs["x"]:
                coeff = -1
        elif len(expr.inputs) == 2 and (
            target not in self._NORMAL_FORM or is_constant(expr.inputs[0].expr)
        ):
            left_node, right_node = expr.inputs
            # Order the operands as they appear in the written expression.
            if target in self._REFLECTED:
                left_node, right_node = right_node, left_node
            if target in (F.sub, F.div) and left_node is not expr.kwargs["x"]:
                left_node, right_node = right_node, left_node
            # Move a constant operand to the right-hand side.
            if is_constant(left_node.expr):
                left_node, right_node = right_node, left_node
                coeff = -1

        if left_node is None:
            # Nothing to rewrite (e.g. already ``tensor + x`` / ``tensor * x``).
            return expr

        if isinstance(right_node, TensorNode):
            # Use the constant value when ``right_node`` comes from a Constant;
            # presumably ``get_const_value`` falls back to the node itself
            # (its second argument) otherwise.
            right_node = get_const_value(right_node.expr, right_node)

        graph = expr.top_graph
        with graph.insert_exprs():
            if target in self._MUL_OPS:
                out_node = left_node * right_node
            elif target in self._ADD_OPS:
                out_node = left_node + right_node
            elif target in self._SUB_OPS:
                # a - b  ->  a + (-b)
                if coeff == -1:
                    left_node = F.neg(left_node)
                elif isinstance(right_node, TensorNode):
                    right_node = F.neg(right_node)
                else:
                    right_node = -1 * right_node
                out_node = left_node + right_node
            elif target in self._DIV_OPS:
                # a / b  ->  a * b ** -1
                if coeff == -1:
                    left_node = F.pow(left_node, -1)
                elif isinstance(right_node, TensorNode):
                    right_node = F.pow(right_node, -1)
                else:
                    right_node = 1 / right_node
                out_node = left_node * right_node

        graph.replace_node({expr.outputs[0]: out_node})
        graph.compile()
        return out_node.expr