import io
import numpy as np
import megengine.core.tensor.megbrain_graph as G
import megengine.utils.comp_graph_tools as cgtools
from megengine import tensor
from megengine.core.tensor.megbrain_graph import OutputNode
from megengine.jit import trace
from megengine.utils.network_node import VarNode
def _default_compare_fn(x, y):
    """Assert that ``x`` is numerically close to the reference value ``y``.

    ``x`` may be a megengine ``tensor`` (converted with ``.numpy()``), a
    ``numpy.ndarray`` (used as-is), or anything else, which is evaluated
    through ``get_var_value``. The comparison is
    ``numpy.testing.assert_allclose`` with ``rtol=1e-6``.
    """
    if isinstance(x, tensor):
        actual = x.numpy()
    elif isinstance(x, np.ndarray):
        actual = x
    else:
        # Not an eager tensor or ndarray: evaluate it as a graph var.
        actual = get_var_value(x)
    assert isinstance(actual, np.ndarray)
    np.testing.assert_allclose(actual, y, rtol=1e-6)
def make_tensor(x, network=None, device=None):
    """Wrap ``x`` as a test input.

    Without ``network`` this returns an eager ``tensor(x, device=device)``.
    With ``network``, an existing ``VarNode`` is re-wrapped around its
    ``.var``; any other value becomes ``network.make_const(x, device=device)``.
    """
    if network is None:
        return tensor(x, device=device)
    if isinstance(x, VarNode):
        return VarNode(x.var)
    return network.make_const(x, device=device)
def get_var_value(x):
    """Evaluate a graph var and return its value as a numpy array.

    ``x`` must expose a ``var`` attribute (as the ``VarNode`` objects used in
    this module do). An ``OutputNode`` is attached to that var, its graph is
    compiled and executed, and the resulting value is converted with
    ``.numpy()``.

    Raises:
        ValueError: if building, compiling, executing or reading the output
            raises ``RuntimeError``. The original error is chained as the
            cause.
    """
    try:
        o = OutputNode(x.var)
        o.graph.compile(o.outputs).execute()
        return o.get_value().numpy()
    except RuntimeError as exc:
        # Chain explicitly so the underlying graph failure is reported as the
        # direct cause instead of an incidental exception during handling.
        raise ValueError("value invalid!") from exc
def opr_test(
    cases,
    func,
    compare_fn=_default_compare_fn,
    ref_fn=None,
    test_trace=True,
    network=None,
    **kwargs
):
    """Run ``func`` on every case and check its results against expectations.

    Each case is a dict with an ``"input"`` entry (a single value or a
    tuple/list of positional arguments). It must also have either an
    ``"output"`` entry or rely on ``ref_fn``. When ``ref_fn`` is callable,
    ``ref_fn(*inputs)`` replaces any ``"output"`` in the case.

    For every case the inputs are wrapped with ``make_tensor`` and the result
    of ``func(*inputs, **kwargs)`` is checked with ``compare_fn``. When
    ``test_trace`` is true and no ``network`` is given, ``func`` is also
    checked in three further ways:

    * under ``trace(symbolic=False)`` and ``trace(symbolic=True)``, three
      calls each;
    * under ``trace(symbolic=True, capture_as_const=True)``;
    * after dumping that trace and reloading it with
      ``cgtools.GraphInference``.

    Args:
        cases: non-empty iterable of case dicts (see above).
        func: the callable under test.
        compare_fn: ``compare_fn(actual, expected)`` must assert closeness.
        ref_fn: optional reference implementation computing expected outputs.
        test_trace: whether to also exercise the traced and dumped paths.
        network: optional network. When given, inputs are built in it, trace
            testing is skipped and result shapes are not checked.
        **kwargs: forwarded to every call of ``func``.

    Raises:
        ValueError: if ``cases`` is empty, ``func`` is not callable, or a case
            lacks an input or any expected output.
    """
    if not cases:
        raise ValueError("should give one case at least")
    if not callable(func):
        raise ValueError("the input func should be callable")

    # Compare each result with its expected value (optionally checking shape).
    def check_results(results, expected, check_shape=True):
        if not isinstance(results, (tuple, list)):
            results = (results,)
        # zip() stops at the shorter side, so unmatched outputs are not
        # compared; check_dumped relies on this (it passes only one output).
        for r, e in zip(results, expected):
            if not isinstance(r, (tensor, VarNode)):
                r = tensor(r)
            if check_shape:
                r_shape = r.numpy().shape
                e_shape = e.shape if isinstance(e, np.ndarray) else ()
                assert r_shape == e_shape, (
                    f"shape mismatch: got {r_shape}, expected {e_shape}"
                )
            compare_fn(r, e)

    # Extract (inputs, expected_outputs) from one case, both as sequences.
    def get_param(case):
        inp = case.get("input", None)
        outp = case.get("output", None)
        if inp is None:
            raise ValueError("the test case should have input")
        if not isinstance(inp, (tuple, list)):
            inp = (inp,)
        if callable(ref_fn):
            outp = ref_fn(*inp)
        if outp is None:
            raise ValueError("the test case should have output or reference function")
        if not isinstance(outp, (tuple, list)):
            outp = (outp,)
        return inp, outp

    # Check func under non-symbolic and symbolic tracing.
    def check_traced(inputs, expected):
        for symbolic in (False, True):
            traced_func = trace(symbolic=symbolic)(func)
            # Repeated calls: later calls may not follow the same path as the
            # first one (NOTE(review): presumably trace replay) — check all.
            for _ in range(3):
                check_results(traced_func(*inputs, **kwargs), expected)

    # Check func traced with captured constants, then dumped and reloaded.
    def check_dumped(inputs, expected):
        dumped_func = trace(symbolic=True, capture_as_const=True)(func)
        check_results(dumped_func(*inputs, **kwargs), expected)

        file = io.BytesIO()
        dump_info = dumped_func.dump(file)
        file.seek(0)

        # dump_info[4] is used as the dumped graph's input names. They end in
        # "_<index>", so sort numerically to line them up with `inputs`.
        input_names = sorted(
            dump_info[4], key=lambda name: int(name.split("_")[-1])
        )
        inp_dict = dict(zip(input_names, (i.numpy() for i in inputs)))
        infer_cg = cgtools.GraphInference(file)
        # Only the first output of the reloaded graph is compared.
        loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0]
        check_results(loaded_results, expected, check_shape=False)

    for case in cases:
        inp, outp = get_param(case)
        inp_tensor = [make_tensor(inpi, network) for inpi in inp]
        if test_trace and network is None:
            check_traced(inp_tensor, outp)
            check_dumped(inp_tensor, outp)
        results = func(*inp_tensor, **kwargs)
        check_results(results, outp, check_shape=(network is None))