# megenginelite-sys 1.8.2

# A safe megenginelite wrapper in Rust
# Documentation
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from abc import abstractmethod
from typing import Any, Callable, Dict, List

from ...core._imperative_rt import OpDef
from ...logger import get_logger
from ...module import Module
from ..expr import Expr
from ..node import Node

# Module-level logger; ExprPattern's has_attr/has_param fallbacks warn through it.
logger = get_logger(__name__)


class ExprPattern:
    """Base class of the expression patterns built in this module.

    Patterns compose into trees:

    * calling a pattern builds a ``CallPattern`` (``pattern(arg0, arg1, ...)``);
    * arithmetic operators and indexing build calls of the same-named tensor
      method pattern (``is_op("__add__")(self, other)``, ...);
    * ``a | b`` builds an ``OrPattern`` of the two alternatives.
    """

    def __init__(self):
        # Flag set by check_users(). How it is consumed is not visible here;
        # presumably the matcher inspects the users of a matched node when it
        # is True. NOTE(review): confirm against the matcher.
        self._check_users = True
        # Patterns registered through _add_users(); emptied by _clear_users().
        self._users = []

    def __call__(self, *args):
        """Return a ``CallPattern`` applying ``self`` to the pattern ``args``.

        ``pattern(None)`` is accepted and means "called with no pattern
        arguments".
        """
        args = list(args)
        if len(args) == 1 and args[0] is None:
            # This used to be ``args = None``, which made the unpacking
            # ``CallPattern(self, *None)`` below raise TypeError.
            args = []
        return CallPattern(self, *args)

    # Arithmetic operators: each builds a call pattern of the tensor method
    # with the same dunder name, applied to (self, other).

    def __add__(self, other):
        return is_op("__add__")(self, other)

    def __iadd__(self, other):
        return is_op("__iadd__")(self, other)

    def __radd__(self, other):
        return is_op("__radd__")(self, other)

    def __sub__(self, other):
        return is_op("__sub__")(self, other)

    def __isub__(self, other):
        return is_op("__isub__")(self, other)

    def __rsub__(self, other):
        return is_op("__rsub__")(self, other)

    def __mul__(self, other):
        return is_op("__mul__")(self, other)

    def __imul__(self, other):
        return is_op("__imul__")(self, other)

    def __rmul__(self, other):
        return is_op("__rmul__")(self, other)

    def __truediv__(self, other):
        return is_op("__truediv__")(self, other)

    def __itruediv__(self, other):
        return is_op("__itruediv__")(self, other)

    def __rtruediv__(self, other):
        return is_op("__rtruediv__")(self, other)

    def __or__(self, other):
        """Return an ``OrPattern`` with ``self`` and ``other`` as alternatives."""
        assert isinstance(other, ExprPattern)
        return OrPattern(self, other)

    def get_output(self, index):
        """Not implemented for patterns in this module."""
        raise NotImplementedError

    def check_users(self, check: bool = True):
        """Set the user-check flag and return ``self`` for chaining."""
        self._check_users = check
        return self

    def _add_users(self, pattern: "ExprPattern"):
        """Register ``pattern`` as a user of this pattern."""
        self._users.append(pattern)

    def _clear_users(self):
        """Forget every registered user."""
        self._users.clear()

    def __getitem__(self, index):
        # ``index`` is passed through; CallPattern keeps only ExprPattern args.
        return is_op("__getitem__")(self, index)

    def has_attr(self, **attrs):
        """Attribute constraints are only supported by ``ModulePattern``.

        Here they are ignored with a warning; ``self`` is returned for chaining.
        """
        logger.warning("has_attr only support ModulePattern")
        return self

    def has_param(self, **params):
        """Parameter constraints are only supported by ``FunctionPattern``.

        Here they are ignored with a warning; ``self`` is returned for chaining.
        """
        logger.warning("has_param only support FunctionPattern")
        return self

    def has_params(self, **params):
        """Alias of :meth:`has_param`, matching ``FunctionPattern``'s spelling."""
        return self.has_param(**params)

    @abstractmethod
    def __repr__(self) -> str:
        # Subclasses provide the textual form. ExprPattern has no ABCMeta
        # metaclass, so this marker is not enforced at instantiation time.
        raise NotImplementedError


class CallPattern(ExprPattern):
    """Pattern of ``op`` called with the pattern arguments in ``args``.

    Only the ``ExprPattern`` instances among ``args`` are kept. Anything else,
    such as ``None``, a plain constant like ``2`` in ``pattern + 2``, or an
    index in ``pattern[0]``, is dropped.
    """

    def __init__(self, op: ExprPattern, *args: Any):
        super().__init__()
        self.op = op
        self.args = [arg for arg in args if isinstance(arg, ExprPattern)]
        # Cleared by not_all_args(). Presumably it tells the matcher whether
        # ``args`` must account for every argument of the call.
        # NOTE(review): confirm against the matcher.
        self._match_all_args = True

    def __repr__(self) -> str:
        return "{}({})".format(self.op, ",".join(str(x) for x in self.args))

    def not_all_args(self):
        """Clear the match-all-args flag and return ``self`` for chaining.

        The method used to return ``None``, unlike ``check_users``. That made
        ``pat = is_op(f)(x).not_all_args()`` silently bind ``None``.
        """
        self._match_all_args = False
        return self

    def check_users(self, check: bool = True):
        """Set the user-check flag on this pattern and on ``op``; return ``self``."""
        self._check_users = check
        self.op.check_users(check)
        return self

    def _add_users(self, pattern: "ExprPattern"):
        """Register ``pattern`` as a user of this pattern and of ``op``."""
        self._users.append(pattern)
        self.op._add_users(pattern)

    def _clear_users(self):
        """Forget the registered users here and on ``op``."""
        self._users.clear()
        self.op._clear_users()


class OrPattern(ExprPattern):
    """Alternative between two patterns, written ``left | right``.

    NOTE(review): check_users() and _clear_users() are forwarded to both
    alternatives, but _add_users() is not (CallPattern does forward it to its
    ``op``). Confirm that this asymmetry is intentional.
    """

    def __init__(self, left: ExprPattern, right: ExprPattern):
        super().__init__()
        self.left, self.right = left, right

    def __repr__(self) -> str:
        return f"({self.left}|{self.right})"

    def check_users(self, check: bool = True):
        """Set the user-check flag here and on both alternatives; return ``self``."""
        self._check_users = check
        for branch in (self.left, self.right):
            branch.check_users(check)
        return self

    def _clear_users(self):
        """Forget the registered users here and on both alternatives."""
        self._users.clear()
        for branch in (self.left, self.right):
            branch._clear_users()


class GetOutputPaterrn(ExprPattern):
    """Pattern rendered as ``op[index]``, presumably selecting one output of ``op``.

    The misspelled class name ("Paterrn") is kept so existing references keep
    working.
    """

    def __init__(self, op, index):
        super().__init__()
        self.op, self.index = op, index

    def __repr__(self) -> str:
        return f"{self.op}[{self.index}]"


class ModulePattern(ExprPattern):
    """Pattern whose target is a ``Module`` subclass (built by ``is_op``).

    Attribute requirements can be attached with :meth:`has_attr`.
    """

    def __init__(self, module_cls: Module) -> None:
        super().__init__()
        self.target = module_cls
        # Attribute name -> value, accumulated by has_attr(). Presumably these
        # are compared with the matched module's attributes; confirm in matcher.
        self.attrs = {}

    def has_attr(self, **attrs):
        """Merge ``attrs`` into the recorded attributes and return ``self``."""
        for name, value in attrs.items():
            self.attrs[name] = value
        return self

    def __repr__(self) -> str:
        return self.target.__name__


class FunctionPattern(ExprPattern):
    """Pattern whose target is a plain callable (built by ``is_op``).

    Parameter requirements can be attached with :meth:`has_param` (or its
    alias :meth:`has_params`).
    """

    def __init__(self, func: Callable):
        super().__init__()
        # Parameter name -> value, accumulated by has_param(). Presumably these
        # are compared with the matched call's arguments; confirm in matcher.
        self.params = {}
        self.target = func

    def has_param(self, **params):
        """Merge ``params`` into the recorded parameters and return ``self``.

        The base class exposes this method under the name ``has_param``. Before
        this override existed, ``is_op(f).has_param(...)`` hit the base-class
        fallback, which only logs a warning and drops the parameters.
        """
        self.params.update(params)
        return self

    # Backward-compatible spelling; this was the only name originally defined.
    has_params = has_param

    def __repr__(self) -> str:
        # Not every callable has ``__name__`` (e.g. functools.partial objects);
        # fall back to its repr instead of raising AttributeError.
        return "{}".format(getattr(self.target, "__name__", repr(self.target)))


class TensorMethodPattern(ExprPattern):
    """Pattern whose target is a tensor method name such as ``"__add__"``.

    ``is_op`` builds it for string targets. Its repr is the method name itself.
    """

    def __init__(self, method: str):
        super().__init__()
        # Method name, e.g. "__getitem__" or "__mul__".
        self.target = method

    def __repr__(self) -> str:
        method_name = self.target
        return method_name


class ApplyDefPattern(ExprPattern):
    """Pattern whose target is an ``OpDef`` subclass (built by ``is_op``)."""

    def __init__(self, opdef: OpDef):
        super().__init__()
        # As passed by is_op this is the OpDef subclass itself; its __name__
        # is used as the repr.
        self.target = opdef

    def __repr__(self) -> str:
        return self.target.__name__


class VarPattern(ExprPattern):
    """Placeholder pattern rendered as ``var`` (built by ``is_var``).

    It adds no state of its own; ``ExprPattern.__init__`` is inherited as-is.
    """

    def __repr__(self) -> str:
        return "var"


class ConstantPattern(ExprPattern):
    """Placeholder pattern rendered as ``const`` (built by ``is_const``).

    It adds no state of its own; ``ExprPattern.__init__`` is inherited as-is.
    """

    def __repr__(self) -> str:
        return "const"


class AnyPattern(ExprPattern):
    """Placeholder pattern rendered as ``any`` (built by ``any_node``).

    It adds no state of its own; ``ExprPattern.__init__`` is inherited as-is.
    """

    def __repr__(self) -> str:
        return "any"


def is_op(target):
    """Build the pattern whose target is ``target``.

    Args:
        target: one of

            * a ``Module`` subclass: returns a ``ModulePattern``;
            * an ``OpDef`` subclass: returns an ``ApplyDefPattern``;
            * any other non-class callable (e.g. a function): returns a
              ``FunctionPattern``;
            * a ``str`` naming a tensor method (e.g. ``"__add__"``): returns a
              ``TensorMethodPattern``.

    Raises:
        ValueError: if ``target`` is none of the above. This includes a class
            that is neither a ``Module`` nor an ``OpDef`` subclass; such a class
            used to fall through and make this function silently return ``None``.
    """
    if isinstance(target, type):
        if issubclass(target, Module):
            return ModulePattern(target)
        if issubclass(target, OpDef):
            return ApplyDefPattern(target)
        raise ValueError("not support type {}".format(target))
    if callable(target):
        return FunctionPattern(target)
    if isinstance(target, str):
        return TensorMethodPattern(target)
    raise ValueError("not support")


def is_const():
    """Return a fresh ``ConstantPattern`` with user checking turned off."""
    pattern = ConstantPattern()
    pattern.check_users(False)
    return pattern


def any_node():
    """Return a fresh ``AnyPattern`` (rendered as ``any``)."""
    pattern = AnyPattern()
    return pattern


def is_var():
    """Return a fresh ``VarPattern`` (rendered as ``var``)."""
    pattern = VarPattern()
    return pattern