megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from importlib import import_module
from typing import Dict, Tuple

from ..core._imperative_rt import OpDef
from ..core.ops import builtin
from ..tensor import Tensor
from ..version import __version__
from .utils import _convert_kwargs_to_args

# Loader registries, filled in by the ``register_*_loader`` decorators below.
# Each maps a lookup key to a loader callable that is invoked with the
# expression being loaded.

# Looked up by ``type(expr.opdef)`` in ``load_apply_expr``.
OPDEF_LOADER = {}
# Looked up by a ``(module name, qualified name)`` tuple in ``load_functional``.
FUNCTIONAL_LOADER = {}
# Looked up by ``expr.method`` in ``load_call_tensor_method_expr``.
TENSORMETHOD_LOADER = {}
# Looked up by a ``(module name, qualified name)`` tuple in
# ``load_call_module_expr``.
MODULE_LOADER = {}


class _ModuleState:
    """Serializable snapshot of a module object.

    A snapshot records where the module's class lives (``module``), a copy of
    the module's ``__dict__`` (``state``) and the library version that produced
    it (``version``). ``to_module`` rebuilds an instance from that data.
    """

    # Instance rebuilt by ``to_module``, cached after the first call. It is not
    # included in ``__getstate__``.
    obj = None

    def __init__(self, module: Tuple, state: Dict, version: str):
        """Store the snapshot data.

        Args:
            module: ``(module name, qualified class name)`` of the module type.
            state: attribute dictionary to restore on the rebuilt instance.
            version: version string recorded with the snapshot.
        """
        self.module = module
        self.state = state
        self.version = version

    @classmethod
    def get_module_state(cls, module):
        """Create or refresh the snapshot attached to ``module``.

        The snapshot is stored on the module itself under
        ``_m_dump_modulestate``; an existing snapshot is re-initialised in
        place rather than replaced.

        Args:
            module: object whose ``__dict__`` should be captured.

        Returns:
            The ``_ModuleState`` stored on ``module``.
        """
        typem = (type(module).__module__, type(module).__qualname__)
        state = module.__dict__.copy()
        # The snapshot must not contain itself.
        state.pop("_m_dump_modulestate", None)
        if hasattr(module, "_m_dump_modulestate"):
            assert isinstance(module._m_dump_modulestate, cls)
            module._m_dump_modulestate.__init__(typem, state, __version__)
        else:
            module.__dict__["_m_dump_modulestate"] = _ModuleState(
                typem, state, __version__
            )

        return module._m_dump_modulestate

    def __getstate__(self):
        """Return the picklable fields; the cached ``obj`` is left out."""
        return {"module": self.module, "state": self.state, "version": self.version}

    def to_module(self):
        """Rebuild the module instance described by this snapshot.

        The class is imported from ``self.module``, instantiated without
        calling ``__init__``, and restored through its ``__setstate__``. The
        result is cached, so repeated calls return the same object.

        Returns:
            The rebuilt (or previously cached) module instance.
        """
        if self.obj is None:
            # ``__qualname__`` is dotted for nested classes ("Outer.Inner"),
            # so walk the attribute path instead of a single ``getattr``.
            typem = import_module(self.module[0])
            for attr in self.module[1].split("."):
                typem = getattr(typem, attr)
            m_obj = typem.__new__(typem)
            m_obj.__setstate__(self.state)
            self.obj = m_obj
        return self.obj


def register_opdef_loader(*opdefs):
    """Return a decorator that registers a loader in ``OPDEF_LOADER``.

    Args:
        *opdefs: keys under which the decorated loader is stored. Each key
            must not be registered yet.

    Returns:
        A decorator that stores the loader for every key and returns the
        loader unchanged.
    """

    def _bind(loader):
        registry = OPDEF_LOADER
        for key in opdefs:
            # Keys are checked one at a time, so keys preceding a duplicate
            # are already registered when the assertion fires.
            assert key not in registry
            registry[key] = loader
        return loader

    return _bind


def register_functional_loader(*funcs):
    """Return a decorator that registers a loader in ``FUNCTIONAL_LOADER``.

    Args:
        *funcs: keys under which the decorated loader is stored. Each key
            must not be registered yet.

    Returns:
        A decorator that stores the loader for every key and returns the
        loader unchanged.
    """

    def _bind(loader):
        registry = FUNCTIONAL_LOADER
        for key in funcs:
            # Keys are checked one at a time, so keys preceding a duplicate
            # are already registered when the assertion fires.
            assert key not in registry
            registry[key] = loader
        return loader

    return _bind


def register_module_loader(*module_types):
    """Return a decorator that registers a loader in ``MODULE_LOADER``.

    Args:
        *module_types: keys under which the decorated loader is stored. Each
            key must not be registered yet.

    Returns:
        A decorator that stores the loader for every key and returns the
        loader unchanged.
    """

    def _bind(loader):
        registry = MODULE_LOADER
        for key in module_types:
            # Keys are checked one at a time, so keys preceding a duplicate
            # are already registered when the assertion fires.
            assert key not in registry
            registry[key] = loader
        return loader

    return _bind


def register_tensor_method_loader(*methods):
    """Return a decorator that registers a loader in ``TENSORMETHOD_LOADER``.

    Args:
        *methods: keys under which the decorated loader is stored. Each key
            must not be registered yet.

    Returns:
        A decorator that stores the loader for every key and returns the
        loader unchanged.
    """

    def _bind(loader):
        registry = TENSORMETHOD_LOADER
        for key in methods:
            # Keys are checked one at a time, so keys preceding a duplicate
            # are already registered when the assertion fires.
            assert key not in registry
            registry[key] = loader
        return loader

    return _bind


def _replace_args_kwargs(expr, new_args, new_kwargs):
    """Install new arguments on ``expr`` only when their layout changed.

    The layout counts as unchanged when the number of positional arguments and
    the set of keyword names both match what ``expr`` already holds; in that
    case ``expr`` is left untouched.

    Args:
        expr: object exposing ``args``, ``kwargs`` and ``set_args_kwargs``.
        new_args: replacement positional arguments.
        new_kwargs: replacement keyword arguments.
    """
    if len(new_args) == len(expr.args):
        if set(new_kwargs.keys()) == set(expr.kwargs.keys()):
            return
    expr.set_args_kwargs(*new_args, **new_kwargs)


def load_functional(expr):
    """Restore ``expr.func`` to a callable and normalise its arguments.

    ``expr.func`` may be a callable or a ``(module name, qualified name)``
    tuple. A loader registered in ``FUNCTIONAL_LOADER`` for that key is run
    first. If ``expr.func`` is still not callable afterwards, it is resolved
    by importing the module and walking the dotted qualified name. When
    ``expr`` has no ``version`` or a different one from ``__version__``, the
    arguments are re-derived with ``_convert_kwargs_to_args``.

    Args:
        expr: expression exposing ``func``, ``args``, ``kwargs`` and optionally
            ``version``.
    """
    func = (
        (expr.func.__module__, expr.func.__qualname__)
        if callable(expr.func)
        else expr.func
    )
    assert isinstance(func, tuple)
    if func in FUNCTIONAL_LOADER:
        FUNCTIONAL_LOADER[func](expr)
    # Resolve independently of whether a loader is registered: a tuple-form
    # func without a loader must still become callable, and a callable that a
    # loader installed must not be overwritten by re-importing the old name.
    if not callable(expr.func):
        mname, fname = expr.func
        f = import_module(mname)
        for attr in fname.split("."):
            f = getattr(f, attr)
        expr.func = f
    assert callable(expr.func)
    if not hasattr(expr, "version") or expr.version != __version__:
        args, kwargs = _convert_kwargs_to_args(expr.func, expr.args, expr.kwargs)
        _replace_args_kwargs(expr, args, kwargs)


def load_call_module_expr(expr):
    """Restore the module type of a call-module expression and fix its args.

    The module type is read from ``expr.inputs[0].module_type`` and may be a
    class or a ``(module name, qualified class name)`` tuple. A loader
    registered in ``MODULE_LOADER`` for that key is run first; a tuple-form
    module type is then resolved to the class it names. When ``expr`` has no
    ``version`` or a different one from ``__version__``, the arguments are
    re-derived against the class's ``forward`` with ``_convert_kwargs_to_args``.

    Args:
        expr: expression exposing ``inputs``, ``args``, ``kwargs`` and
            optionally ``version``.
    """
    m_type = expr.inputs[0].module_type
    if isinstance(m_type, type):
        m_type = (m_type.__module__, m_type.__qualname__)
    if m_type in MODULE_LOADER:
        MODULE_LOADER[m_type](expr)
    if isinstance(expr.inputs[0].module_type, tuple):
        mname, classname = expr.inputs[0].module_type
        # Qualified names of nested classes are dotted ("Outer.Inner"), so
        # walk the attribute path the same way ``load_functional`` does.
        resolved = import_module(mname)
        for attr in classname.split("."):
            resolved = getattr(resolved, attr)
        expr.inputs[0].module_type = resolved
    if not hasattr(expr, "version") or expr.version != __version__:
        fwd_func = getattr(expr.inputs[0].module_type, "forward")
        args, kwargs = _convert_kwargs_to_args(fwd_func, expr.args, expr.kwargs)
        _replace_args_kwargs(expr, args, kwargs)


def load_call_tensor_method_expr(expr):
    """Run the registered loader for a tensor-method call and fix its args.

    A loader registered in ``TENSORMETHOD_LOADER`` under ``expr.method`` is
    invoked with ``expr``. When ``expr`` has no ``version`` or a different one
    from ``__version__``, the arguments are re-derived with
    ``_convert_kwargs_to_args`` against the method looked up on ``expr.args[0]``
    (if that is a class) or on ``Tensor`` otherwise.

    Args:
        expr: expression exposing ``method``, ``args``, ``kwargs`` and
            optionally ``version``.
    """
    if expr.method in TENSORMETHOD_LOADER:
        TENSORMETHOD_LOADER[expr.method](expr)

    # Nothing more to do when the expression was written by this version.
    if hasattr(expr, "version") and expr.version == __version__:
        return

    receiver = expr.args[0]
    owner = receiver if isinstance(receiver, type) else Tensor
    tmethod = getattr(owner, expr.method)
    new_args, new_kwargs = _convert_kwargs_to_args(tmethod, expr.args, expr.kwargs)
    _replace_args_kwargs(expr, new_args, new_kwargs)


def load_apply_expr(expr):
    """Run the registered opdef loader and rebuild ``expr.opdef`` from state.

    If a loader is registered in ``OPDEF_LOADER`` for ``type(expr.opdef)``, it
    is invoked with ``expr``; a fresh opdef is then created from
    ``expr.opdef_state`` and stored on ``expr.opdef``. Without a registered
    loader the expression is left untouched.

    Args:
        expr: expression exposing ``opdef`` and ``opdef_state``.
    """
    # NOTE(review): the opdef is rebuilt only when a loader is registered --
    # confirm this is intended rather than an unconditional rebuild.
    key = type(expr.opdef)
    if key not in OPDEF_LOADER:
        return
    OPDEF_LOADER[key](expr)

    state = expr.opdef_state
    # ``pop`` removes "opdef_type" from ``expr.opdef_state`` itself.
    rebuilt_type = state.pop("opdef_type")
    rebuilt = rebuilt_type()
    rebuilt.__setstate__(state)
    expr.opdef = rebuilt