megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
import io

import numpy as np

import megengine.core.tensor.megbrain_graph as G
import megengine.utils.comp_graph_tools as cgtools
from megengine import tensor
from megengine.core.tensor.megbrain_graph import OutputNode
from megengine.jit import trace
from megengine.utils.network_node import VarNode


def _default_compare_fn(x, y):
    """Assert that ``x`` matches ``y`` with a relative tolerance of 1e-6.

    ``x`` is turned into a NumPy array before comparing: a ``tensor`` is
    converted with ``.numpy()``, an ``np.ndarray`` is used unchanged, and
    anything else is evaluated through ``get_var_value``.

    :param x: the computed result.
    :param y: the expected value, passed straight to
        ``np.testing.assert_allclose``.
    """
    if isinstance(x, tensor):
        actual = x.numpy()
    elif isinstance(x, np.ndarray):
        actual = x
    else:
        actual = get_var_value(x)
    assert isinstance(actual, np.ndarray)
    np.testing.assert_allclose(actual, y, rtol=1e-6)


def make_tensor(x, network=None, device=None):
    """Wrap ``x`` either as an eager tensor or as a node of ``network``.

    :param x: the value to wrap.
    :param network: when None, ``tensor(x, device=device)`` is returned.
        Otherwise a ``VarNode`` input is re-wrapped as ``VarNode(x.var)``
        (``device`` is not used on that path) and any other input is turned
        into ``network.make_const(x, device=device)``.
    :param device: forwarded to ``tensor`` / ``network.make_const``.
    """
    if network is None:
        return tensor(x, device=device)
    if isinstance(x, VarNode):
        return VarNode(x.var)
    return network.make_const(x, device=device)


def get_var_value(x):
    """Evaluate ``x.var`` and return its value as a NumPy array.

    An ``OutputNode`` is built around ``x.var``, the graph it belongs to is
    compiled for that node's outputs and executed, and the node's value is
    returned via ``.numpy()``.

    :param x: an object exposing a ``var`` attribute.
    :raises ValueError: if any of those steps raises ``RuntimeError``; the
        original error is attached as ``__cause__``.
    """
    try:
        o = OutputNode(x.var)
        o.graph.compile(o.outputs).execute()
        return o.get_value().numpy()
    except RuntimeError as err:
        # Chain explicitly so the underlying failure stays visible as the
        # direct cause of the ValueError.
        raise ValueError("value invalid!") from err


def opr_test(
    cases,
    func,
    compare_fn=_default_compare_fn,
    ref_fn=None,
    test_trace=True,
    network=None,
    **kwargs
):
    """
    Run ``func`` on every case and compare its results with the expected ones.

    When ``test_trace`` is true and ``network`` is falsy, each case is
    additionally run through ``trace`` with ``symbolic=False`` and
    ``symbolic=True``, then traced with ``capture_as_const=True``, dumped to an
    in-memory buffer, reloaded with ``cgtools.GraphInference`` and the first
    loaded output is checked (without the shape check). Finally ``func`` is
    called directly on the inputs and its results are checked.

    :param cases: the list which have dict element, the list length should be 2 for dynamic shape test.
           and the dict should have input,
           and should have output if ref_fn is None.
           should use list for multiple inputs and outputs for each case.
    :param func: the function to run opr.
    :param compare_fn: the function to compare the result and expected, called
        as ``compare_fn(result, expected)``. Falls back to the default
        comparison (``np.testing.assert_allclose`` with ``rtol=1e-6``) if None.
    :param ref_fn: the function to generate expected data, should assign output if None.
        When callable it is invoked as ``ref_fn(*inputs)`` and its result
        replaces any ``"output"`` entry of the case.
    :param test_trace: whether to also exercise the traced and dumped paths.
    :param network: optional network; when given, inputs are created through
        ``make_tensor(inp, network)``, the trace path is skipped and the
        result shape check is disabled.
    :param kwargs: extra keyword arguments forwarded to ``func``.
    :raises ValueError: if ``cases`` is empty, ``func`` is not callable, a case
        has no input, or a case has neither output nor a reference function.

    Examples:

    .. code-block::

        dtype = np.float32
        cases = [{"input": [10, 20]}, {"input": [20, 30]}]
        opr_test(cases,
                 F.eye,
                 ref_fn=lambda n, m: np.eye(n, m).astype(dtype),
                 dtype=dtype)

    """
    # Honor the documented ``compare_fn=None`` contract instead of failing
    # later with "NoneType is not callable".
    if compare_fn is None:
        compare_fn = _default_compare_fn

    def check_results(results, expected, check_shape=True):
        # Compare results pairwise with expected; a single result is treated
        # as a one-element tuple.
        if not isinstance(results, (tuple, list)):
            results = (results,)
        for r, e in zip(results, expected):
            if not isinstance(r, (tensor, VarNode)):
                r = tensor(r)
            if check_shape:
                r_shape = r.numpy().shape
                # Anything that is not an ndarray is treated as a scalar.
                e_shape = e.shape if isinstance(e, np.ndarray) else ()
                assert r_shape == e_shape
            compare_fn(r, e)

    def resolve_case(case):
        # Return (inputs, expected outputs) of one case, both as sequences.
        inp = case.get("input", None)
        outp = case.get("output", None)
        if inp is None:
            raise ValueError("the test case should have input")
        if not isinstance(inp, (tuple, list)):
            inp = (inp,)
        # callable(None) is False, so no separate None check is needed.
        if callable(ref_fn):
            outp = ref_fn(*inp)
        if outp is None:
            raise ValueError("the test case should have output or reference function")
        if not isinstance(outp, (tuple, list)):
            outp = (outp,)

        return inp, outp

    def run_case(case):
        inp, outp = resolve_case(case)
        inp_tensor = [make_tensor(inpi, network) for inpi in inp]

        if test_trace and not network:
            copied_inp = inp_tensor.copy()
            for symbolic in [False, True]:
                traced_func = trace(symbolic=symbolic)(func)

                # Call the traced function repeatedly; only the last result
                # is checked. NOTE(review): presumably this exercises both the
                # first (tracing) run and later replayed runs - confirm.
                for _ in range(3):
                    traced_results = traced_func(*copied_inp, **kwargs)
                check_results(traced_results, outp)

            dumped_func = trace(symbolic=True, capture_as_const=True)(func)
            dumped_results = dumped_func(*copied_inp, **kwargs)
            check_results(dumped_results, outp)

            file = io.BytesIO()
            dump_info = dumped_func.dump(file)
            file.seek(0)

            # arg_name has pattern arg_xxx, xxx is int value
            def take_number(arg_name):
                return int(arg_name.split("_")[-1])

            # Order the dumped input names by their numeric suffix so they
            # line up with the positional inputs; sorted() leaves dump_info
            # itself untouched.
            input_names = sorted(dump_info[4], key=take_number)
            inps_np = [i.numpy() for i in copied_inp]
            inp_dict = dict(zip(input_names, inps_np))
            infer_cg = cgtools.GraphInference(file)

            # assume #outputs == 1
            loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0]
            check_results(loaded_results, outp, check_shape=False)  # scalar info lost

        results = func(*inp_tensor, **kwargs)
        check_results(results, outp, check_shape=(network is None))

    if len(cases) == 0:
        raise ValueError("should give one case at least")

    if not callable(func):
        raise ValueError("the input func should be callable")

    for case in cases:
        run_case(case)