megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import traceback
from typing import Sequence

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.ops import ROIAlign, ROIPooling
from ..core.ops.builtin import Copy
from ..tensor import Tensor
from .tm_config import _exclude_from_trace


class TracedModuleChecker:
    """Cross-check traced-module expressions against ground-truth results.

    The checker keeps a stack of ``node -> value`` dictionaries, one per scope
    pushed with :meth:`push_scope`. Values recorded for nodes are fed back into
    ``expr.interpret`` and the interpreted outputs are compared with the
    ground-truth outputs using ``np.testing.assert_allclose``.
    """

    def __init__(self, tracer):
        # Stack of node -> value dicts; the last entry is the active scope.
        self._active_node2values = []
        # Only ``tracer.current_scope()`` is used here, to remember which scope
        # a missing node was requested from.
        self.tracer = tracer

        # node -> [scope, extracted stack] for every input node whose value was
        # not found in the active scope when it was requested.
        self.node_without_tensor_info = {}

    def push_scope(self):
        """Open a new, empty node -> value scope on top of the stack."""
        self._active_node2values.append({})

    def pop_scope(self):
        """Discard the innermost node -> value scope."""
        self._active_node2values.pop()

    def current_node2values(self):
        """Return the innermost node -> value dict (IndexError if none pushed)."""
        return self._active_node2values[-1]

    def reset_checker(self):
        """Drop every scope.

        ``node_without_tensor_info`` is intentionally left untouched.
        """
        self._active_node2values = []

    def check_node_not_in_scope(self):
        """Report input nodes whose values were missing during checking.

        For every recorded missing node, scan the exprs of the scope captured
        when the node was requested. If the node is an input or output of any
        of those exprs, print the stack captured at request time and raise.

        Returns:
            True if at least one missing node was recorded (and none of them
            triggered the error), False if nothing was recorded.

        Raises:
            ValueError: if a recorded node is used by an expr of its scope.
        """
        if self.node_without_tensor_info:
            for node, info in self.node_without_tensor_info.items():
                scope, stack = info
                for expr in scope._exprs:
                    if node in expr.inputs or node in expr.outputs:
                        # Show where the missing value was first requested.
                        traceback.print_list(stack)
                        raise ValueError(
                            "node({}) not in the graph:\n{}".format(node, scope)
                        )
            return True
        else:
            return False

    def check_net_outputs(self, tm_res, gt_res):
        """Assert that traced-module outputs match ground-truth outputs.

        Handles a single ``Tensor``, a ``Sequence`` of tensors (compared
        pairwise via ``zip``), or an object whose ``__dict__`` attributes are
        tensors (compared attribute by attribute).
        """
        if isinstance(tm_res, Tensor):
            np.testing.assert_allclose(tm_res.numpy(), gt_res.numpy())
        elif isinstance(tm_res, Sequence):
            for i, j in zip(tm_res, gt_res):
                np.testing.assert_allclose(i.numpy(), j.numpy())
        else:
            for k in tm_res.__dict__.keys():
                np.testing.assert_allclose(
                    getattr(tm_res, k).numpy(), getattr(gt_res, k).numpy()
                )

    def record_nodemixin(self, node, value):
        """Store ``value`` for ``node`` in the active scope as-is (no copy)."""
        self.current_node2values()[node] = value

    def record_node2value(self, node, value):
        """Store a copy of tensor ``value`` for ``node`` in the active scope.

        The copy is made by applying the ``Copy`` op on ``value.device`` with
        tracing disabled, so the copy itself is not recorded in the trace.
        """
        with _exclude_from_trace():
            self.current_node2values()[node] = apply(
                Copy(comp_node=value.device), value
            )[0]

    def check_apply_special_cases(self, opdef, num_outputs):
        """Return the output indices that should be compared for ``opdef``.

        For ``ROIAlign`` / ``ROIPooling`` in ``AVERAGE`` mode, the last output
        is excluded from comparison (presumably because it is an auxiliary
        output that is not meaningful in that mode -- TODO confirm).
        """
        indexs = list(range(num_outputs))
        if isinstance(opdef, ROIAlign) and opdef.mode == ROIAlign.Mode.AVERAGE:
            indexs.pop(-1)
        if isinstance(opdef, ROIPooling) and opdef.mode == ROIPooling.Mode.AVERAGE:
            indexs.pop(-1)
        return indexs

    def check_expr_results(self, expr_outputs, gt_outputs, indexs=None):
        """Assert that interpreted outputs match ground-truth outputs.

        Non-``Sequence`` outputs are wrapped into 1-tuples first. When
        ``indexs`` is given, only those positions are compared (via
        ``.numpy()``); otherwise the whole sequences are compared directly.

        Raises:
            AssertionError: (from numpy) on mismatch.
        """
        expr_outputs = (
            (expr_outputs,) if not isinstance(expr_outputs, Sequence) else expr_outputs
        )
        gt_outputs = (
            (gt_outputs,) if not isinstance(gt_outputs, Sequence) else gt_outputs
        )
        if indexs is not None:
            for i in indexs:
                np.testing.assert_allclose(
                    expr_outputs[i].numpy(), gt_outputs[i].numpy()
                )
        else:
            np.testing.assert_allclose(expr_outputs, gt_outputs)

    def get_node2value(self, inputs, start_idx=0):
        """Look up recorded values for ``inputs[start_idx:]`` in the active scope.

        Nodes without a recorded value are remembered in
        ``node_without_tensor_info`` together with the tracer's current scope
        and the current call stack, for later reporting by
        :meth:`check_node_not_in_scope`.

        Returns:
            ``(values, has_node_not_in_scope)``. ``values`` holds the found
            values in input order; missing nodes are skipped.
        """
        inp_values = []
        has_node_not_in_scope = False
        for node in inputs[start_idx:]:
            try:
                inp_values.append(self.current_node2values()[node])
            except KeyError:
                # Only a missing dict entry means "not in scope"; anything
                # else (including KeyboardInterrupt) must propagate.
                has_node_not_in_scope = True
                self.node_without_tensor_info[node] = [
                    self.tracer.current_scope(),
                    traceback.extract_stack(),
                ]
        return inp_values, has_node_not_in_scope

    def check_expr_interpret(self, expr, gt_outputs):
        """Re-run ``expr`` on recorded inputs and compare with ``gt_outputs``.

        Skipped silently when any input value is missing from the scope.

        Raises:
            ValueError: if the outputs do not match (original error chained).
        """
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
        if not has_node_not_in_scope:
            expr_res = expr.interpret(*ori_in)
            try:
                self.check_expr_results(expr_res, gt_outputs)
            except Exception as exc:
                raise ValueError(
                    "Error occurred when checking expr: {}".format(expr)
                ) from exc

    def check_apply(self, expr, gt_outputs, opdef):
        """Like :meth:`check_expr_interpret`, for an ``apply`` of ``opdef``.

        Only the output indices returned by
        :meth:`check_apply_special_cases` are compared.

        Raises:
            ValueError: if the selected outputs do not match.
        """
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
        if not has_node_not_in_scope:
            expr_res = expr.interpret(*ori_in)
            indexs = self.check_apply_special_cases(opdef, len(gt_outputs))
            try:
                self.check_expr_results(expr_res, gt_outputs, indexs=indexs)
            except Exception as exc:
                raise ValueError(
                    "Error occurred when checking expr: {}".format(expr)
                ) from exc

    def check_builtin_module(self, module, expr, gt_outputs):
        """Re-run a builtin-module call expr and compare with ``gt_outputs``.

        ``expr.inputs[0]`` is skipped for the value lookup and the live
        ``module`` object is passed as the first interpret argument instead.

        Raises:
            ValueError: if the outputs do not match; the message names both
                the module and the expr.
        """
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs, start_idx=1)
        if not has_node_not_in_scope:
            ori_in.insert(0, module)
            expr_res = expr.interpret(*ori_in)
            try:
                self.check_expr_results(expr_res, gt_outputs)
            except Exception as exc:
                # Both placeholders need an argument; the original passed only
                # ``expr`` and raised IndexError instead of this ValueError.
                raise ValueError(
                    "{}, Error occurred when checking expr: {}".format(module, expr)
                ) from exc