import io
import numpy as np
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
import megengine.module as M
import megengine.utils.network_node as N
from megengine.jit.tracing import trace
from megengine.tensor import Tensor
from megengine.utils.comp_graph_tools import GraphInference
from megengine.utils.network import Network as Net
from megengine.utils.network import as_oprnode, set_symbolic_shape
from megengine.utils.network_node import Host2DeviceCopy, VarNode
def test_metadata():
    """Check the metadata reported by ``Net.load`` for three kinds of dump.

    1. A traced function dumped with a plain-string ``user_info``.
    2. The loaded network re-dumped with a dict ``user_info`` and with
       inference optimization flags enabled.
    3. A traced function dumped with ``enable_metadata=False``, for which
       the loaded metadata must be ``None``.
    """
    x = Tensor(0)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x):
        return x * 2

    fwd(x)

    # Each dump gets its own buffer.  Rewinding and overwriting one BytesIO
    # does not truncate it, so a shorter dump would be followed by leftover
    # bytes from the previous, longer model.
    orig_model = io.BytesIO()
    fwd.dump(orig_model, user_info="test", optimize_for_inference=False)
    orig_model.seek(0)
    graph = Net.load(orig_model)
    assert graph.metadata == {
        "user_info": "test",
        "graph_modified": False,
        "optimized_for_inference": False,
    }

    # Re-dump through the Network object, this time with structured user
    # info and optimization options; the assertion below expects all of
    # them to be recorded in the metadata.
    optimized_model = io.BytesIO()
    graph.dump(
        optimized_model,
        user_info={"str": "x", "tensor": x, "module": M.Module, "none": None},
        optimize_for_inference=True,
        enable_nchw4=True,
        enable_ioc16=True,
    )
    optimized_model.seek(0)
    graph = Net.load(optimized_model)
    assert graph.metadata == {
        "user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None},
        "graph_modified": True,
        "optimized_for_inference": True,
        "enable_nchw4": True,
        "enable_ioc16": True,
    }

    # With metadata disabled at dump time nothing should be loaded back.
    no_metadata_model = io.BytesIO()
    fwd.dump(no_metadata_model, enable_metadata=False)
    no_metadata_model.seek(0)
    graph = Net.load(no_metadata_model)
    assert graph.metadata is None
def test_replace_var():
    """Replace a var inside a loaded graph and verify the re-dumped result.

    The traced function is ``(a + b) * 2``.  The first output of the first
    operator that takes ``a`` as input is swapped for ``relu(a * b)``, and
    the dumped graph is expected to yield ``[6, 16]`` for the inputs below.
    """
    a = Tensor([1, 2])
    b = Tensor([3, 4])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a, b):
        return (a + b) * 2

    fwd(a, b)

    source_buf = io.BytesIO()
    fwd.dump(
        source_buf,
        arg_names=["a", "b"],
        output_names="o",
        optimize_for_inference=False,
    )
    source_buf.seek(0)
    network = Net.load(source_buf)

    var_a = network.var_filter.name("a").as_unique()
    var_b = network.var_filter.name("b").as_unique()

    # Build the replacement expression on the network's own var nodes.
    replacement = F.relu(F.mul(var_a, var_b))

    consumers = list(network.opr_filter.has_input(var_a))
    replaced_var = consumers[0].outputs[0]
    network.replace_vars({replaced_var: replacement})

    result_buf = io.BytesIO()
    network.dump(result_buf)
    result_buf.seek(0)

    inference = GraphInference(result_buf)
    results = inference.run(a, b)
    np.testing.assert_equal(results["o"], [6, 16])
def test_replace_opr():
    """Replace a whole operator in a loaded graph and verify the outcome.

    The operator consuming ``a`` in ``(a + b) * 2`` is replaced by the owner
    of ``relu(a - b)``.  After the replacement the original operator's
    outputs must be owned by the new operator, and the re-dumped graph is
    expected to yield ``[0, 0]`` for the inputs below.
    """
    a = Tensor([1, 2])
    b = Tensor([3, 4])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a, b):
        return (a + b) * 2

    fwd(a, b)

    source_buf = io.BytesIO()
    fwd.dump(
        source_buf,
        arg_names=["a", "b"],
        output_names="o",
        optimize_for_inference=False,
    )
    source_buf.seek(0)
    network = Net.load(source_buf)

    var_a = network.var_filter.name("a").as_unique()
    var_b = network.var_filter.name("b").as_unique()

    new_vars = F.sub(var_a, var_b)
    new_vars = F.relu(new_vars)
    new_vars = network.add_dep_oprs(new_vars)

    old_opr = network.opr_filter.has_input(var_a).as_unique()
    replacement_opr = new_vars[0].owner
    network.replace_oprs({old_opr: replacement_opr})

    # The outputs of the replaced operator must now hang off the new one,
    # at the same output positions.
    for position, out_var in enumerate(old_opr.outputs):
        assert out_var.owner is replacement_opr
        assert out_var.owner.outputs[position] is out_var

    result_buf = io.BytesIO()
    network.dump(result_buf)
    result_buf.seek(0)

    inference = GraphInference(result_buf)
    results = inference.run(a, b)
    np.testing.assert_equal(results["o"], [0, 0])
def test_splice_network():
    """Splice the output of one loaded network into the input of another.

    ``net2``'s output ``net2_o0`` replaces ``net1``'s input ``net1_i0``.
    Afterwards ``net1`` must no longer contain ``net1_i0`` but must contain
    ``net2_i0``, both before and after a dump/load round trip.
    """
    x = F.ones((2,))
    y = F.ones((2,))

    @trace(symbolic=True, capture_as_const=True)
    def fun1(a, b):
        return (a + b) * 2

    @trace(symbolic=True, capture_as_const=True)
    def fun2(a):
        return a * 2 - 1

    fun1(x, y)
    fun2(x)

    # Each dump uses its own buffer.  Rewinding and overwriting one BytesIO
    # does not truncate it, so a shorter model would be followed by stale
    # bytes from the previously dumped one.
    net1_model = io.BytesIO()
    fun1.dump(
        net1_model,
        arg_names=["net1_i0", "net1_i1"],
        output_names=["net1_o0"],
        optimize_for_inference=False,
    )
    net1_model.seek(0)
    net1 = Net.load(net1_model)

    net2_model = io.BytesIO()
    fun2.dump(
        net2_model,
        arg_names=["net2_i0"],
        output_names=["net2_o0"],
        optimize_for_inference=False,
    )
    net2_model.seek(0)
    net2 = Net.load(net2_model)

    net1.add_output(*net2.output_vars)
    var = net1.var_filter.name("net1_i0").as_unique()
    repl_var = net2.var_filter.name("net2_o0").as_unique()
    net1.replace_vars({var: repl_var})

    # Collect the names once instead of rebuilding the list per assertion.
    spliced_names = {v.name for v in net1.all_vars}
    assert "net1_i0" not in spliced_names
    assert "net2_i0" in spliced_names

    spliced_model = io.BytesIO()
    net1.dump(spliced_model, keep_var_name=2, optimize_for_inference=False)
    spliced_model.seek(0)
    net = Net.load(spliced_model)

    reloaded_names = {v.name for v in net.all_vars}
    assert "net1_i0" not in reloaded_names
    assert "net2_i0" in reloaded_names
def test_modify_params():
    """Overwrite the single parameter of a loaded graph and check the output.

    The traced function is ``(a + b) * 2``.  The graph's only parameter is
    set to ``3``, and the re-dumped graph is expected to yield ``[12, 18]``
    for the inputs below.
    """
    a = Tensor([1, 2])
    b = Tensor([3, 4])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a, b):
        return (a + b) * 2

    fwd(a, b)

    source_buf = io.BytesIO()
    fwd.dump(
        source_buf,
        arg_names=["a", "b"],
        output_names="o",
        optimize_for_inference=False,
    )
    source_buf.seek(0)
    network = Net.load(source_buf)

    # as_unique() requires that exactly one parameter exists in the graph.
    only_param = network.params_filter.as_unique()
    only_param.set_value(3)

    result_buf = io.BytesIO()
    network.dump(result_buf)
    result_buf.seek(0)

    inference = GraphInference(result_buf)
    results = inference.run(a, b)
    np.testing.assert_equal(results["o"], [12, 18])
def test_make_const():
    """Replace an input var with a constant created via ``make_const``.

    Input ``b`` of ``(a + b) * 2`` is replaced by a zero-valued constant, so
    the re-dumped graph is run with ``a`` only and is expected to yield
    ``[2, 4]``.
    """
    a = Tensor([1, 2])
    b = Tensor([3, 4])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a, b):
        return (a + b) * 2

    fwd(a, b)

    source_buf = io.BytesIO()
    fwd.dump(
        source_buf,
        arg_names=["a", "b"],
        output_names="o",
        optimize_for_inference=False,
    )
    source_buf.seek(0)
    network = Net.load(source_buf)

    # The constant is created before the lookup of "b", as in the original
    # statement order.
    zero_const = network.make_const(np.array([0.0, 0.0]), name="b")
    var_b = network.var_filter.name("b").as_unique()
    network.replace_vars({var_b: zero_const})

    result_buf = io.BytesIO()
    network.dump(result_buf)
    result_buf.seek(0)

    inference = GraphInference(result_buf)
    results = inference.run(a)
    np.testing.assert_equal(results["o"], [2, 4])
def test_add_input():
    """Add a new input node to a loaded graph and expose a new output.

    A third input ``c`` is created and added to the original output ``o``;
    the sum, named ``o1``, replaces ``o`` as the graph output.  The result
    is compared against ``(a + b) * 2 + a`` computed eagerly.
    """
    a = Tensor([1, 2])
    b = Tensor([3, 4])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a, b):
        return (a + b) * 2

    fwd(a, b)

    source_buf = io.BytesIO()
    fwd.dump(
        source_buf,
        arg_names=["a", "b"],
        output_names="o",
        optimize_for_inference=False,
    )
    source_buf.seek(0)
    network = Net.load(source_buf)

    extra_input = network.make_input_node((2,), np.int32, name="c")
    old_output = network.var_filter.name("o").as_unique()

    new_output = F.add(old_output, extra_input)
    new_output.name = "o1"

    network.remove_output(old_output)
    network.add_output(new_output)

    result_buf = io.BytesIO()
    network.dump(result_buf)
    result_buf.seek(0)

    inference = GraphInference(result_buf)
    # ``a`` is fed a second time as the value of the new input ``c``.
    results = inference.run(a, b, a)
    expected = ((a + b) * 2 + a).numpy()
    np.testing.assert_equal(results["o1"], expected)
def test_add_remove_output():
    """Drop all outputs of a loaded graph and register two new ones.

    The new outputs are ``(a + b) * 3`` (``new_o1``) and ``sigmoid(a + b)``
    (``new_o2``), built on the loaded input vars.  Both are compared with
    the same expressions evaluated eagerly.
    """
    a = Tensor([1.0, 2.0])
    b = Tensor([3.0, 4.0])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a, b):
        return (a + b) * 2, (a - b)

    fwd(a, b)

    source_buf = io.BytesIO()
    fwd.dump(
        source_buf,
        arg_names=["a", "b"],
        output_names=["o1", "o2"],
        optimize_for_inference=False,
    )
    source_buf.seek(0)
    network = Net.load(source_buf)

    in_a = network.var_filter.name("a").as_unique()
    in_b = network.var_filter.name("b").as_unique()

    scaled_sum = (in_a + in_b) * 3
    squashed_sum = F.sigmoid(in_a + in_b)

    network.remove_output(*network.output_vars)

    scaled_sum.name = "new_o1"
    squashed_sum.name = "new_o2"
    network.add_output(scaled_sum, squashed_sum)

    result_buf = io.BytesIO()
    network.dump(result_buf)
    result_buf.seek(0)

    inference = GraphInference(result_buf)
    results = inference.run(a.numpy(), b.numpy())

    np.testing.assert_equal(results["new_o1"], ((a + b) * 3).numpy())
    # The sigmoid output is compared approximately, not exactly.
    np.testing.assert_almost_equal(results["new_o2"], F.sigmoid(a + b).numpy())
def test_query():
    """Exercise the query/filter helpers of a loaded network.

    A model of three stacked ``Conv2d`` layers is traced and dumped with var
    and operator names kept, then queried by type, name, input and a
    wildcard name pattern.
    """

    class Model(M.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = M.Conv2d(3, 32, 3)
            self.conv2 = M.Conv2d(32, 32, 3)
            self.conv3 = M.Conv2d(32, 32, 3)

        def forward(self, data):
            x = self.conv1(data)
            x = self.conv2(x)
            x = self.conv3(x)
            return x

    n = Model()

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return n(data)

    fwd(Tensor(np.random.random((1, 3, 224, 224))))

    source_buf = io.BytesIO()
    fwd.dump(
        source_buf,
        arg_names=["data"],
        output_names="o",
        keep_opr_name=True,
        keep_var_name=True,
        optimize_for_inference=False,
    )
    source_buf.seek(0)
    network = Net.load(source_buf)

    # Exactly one data provider, and it is a Host2DeviceCopy operator.
    provider_count = network.data_providers_filter.as_count()
    assert provider_count == 1
    h2d_opr = network.get_opr_by_type(Host2DeviceCopy)
    assert isinstance(h2d_opr, Host2DeviceCopy)

    # Six parameters are expected for the three convolutions.
    param_count = network.params_filter.as_count()
    assert param_count == 6

    # Filtering by type and by its complement must partition all operators.
    conv_count = network.opr_filter.type(N.ConvolutionForward).as_count()
    assert conv_count == 3
    non_conv_count = network.opr_filter.not_type(N.ConvolutionForward).as_count()
    assert non_conv_count == len(network.all_oprs) - conv_count

    # Lookups keyed on the "data" input.
    data_var = network.var_filter.name("data").as_unique()
    consumer_count = network.opr_filter.has_input(data_var).as_count()
    assert consumer_count == 1

    named_opr_count = network.opr_filter.name("data").as_count()
    assert named_opr_count == 1

    data_opr = network.get_opr_by_name("data")
    assert isinstance(data_opr, Host2DeviceCopy)
    data_var = network.get_var_by_name("data")
    assert isinstance(data_var, VarNode)

    # Wildcard name matching: three vars whose names end in "bias".
    bias_count = network.var_filter.name("*bias").as_count()
    assert bias_count == 3
def test_optimize_for_inference():
    """Re-dump a loaded network with ``enable_io16xc32=True``.

    After the optimized dump, the first input of the operator that owns the
    first output var must have dtype ``float16``.
    """

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return F.exp(x)

    plain_buf = io.BytesIO()
    f(Tensor(5.0))
    f.dump(plain_buf, optimize_for_inference=False)
    plain_buf.seek(0)

    optimized_buf = io.BytesIO()
    network = Net.load(plain_buf)
    network.dump(optimized_buf, enable_io16xc32=True)
    optimized_buf.seek(0)

    loaded = G.load_graph(optimized_buf)
    first_output = loaded.output_vars_list[0]
    compute_input = first_output.owner.inputs[0]
    assert compute_input.dtype == np.float16
def test_reset_batchsize():
    """Reset the batch size of a loaded network to 1 and verify it persists.

    The graph is traced with an input of shape ``(3, 3, 224, 224)``; after
    ``reset_batch_size(1)`` and a dump/load round trip, the first dimension
    of the single data provider must be 1.
    """

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return F.exp(x)

    source_buf = io.BytesIO()
    f(Tensor(np.random.random((3, 3, 224, 224))))
    f.dump(source_buf, optimize_for_inference=False)
    source_buf.seek(0)

    result_buf = io.BytesIO()
    network = Net.load(source_buf)
    network.reset_batch_size(1)
    network.dump(result_buf, optimize_for_inference=False)
    result_buf.seek(0)

    reloaded = Net.load(result_buf)
    provider = reloaded.data_providers_filter.as_unique()
    assert provider.shape[0] == 1
def test_modify_opr_name():
    """Rename operators with a string and then with a callable.

    ``modify_opr_names`` is called first with ``"net"`` and then with a
    lambda that prepends ``"net1."``.  After a dump/load round trip the data
    provider dumped as ``"a"`` is expected to be named ``"net1.net.a"``.
    """

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return F.exp(x)

    source_buf = io.BytesIO()
    f(Tensor(np.random.random((3, 3, 224, 224))))
    f.dump(source_buf, arg_names=["a"], optimize_for_inference=False)
    source_buf.seek(0)

    result_buf = io.BytesIO()
    network = Net.load(source_buf)
    network.modify_opr_names("net")
    network.modify_opr_names(lambda x: "net1." + x)
    network.dump(result_buf, optimize_for_inference=False)
    result_buf.seek(0)

    reloaded = Net.load(result_buf)
    provider = reloaded.data_providers_filter.as_unique()
    assert provider.name == "net1.net.a"
def test_dump_cond_take():
    """Rebuild ``cond_take`` on a loaded network and check both outputs.

    The loaded graph's outputs are replaced with the value and index results
    of ``F.cond_take(var_a > 1, var_a)``, named ``"value"`` and ``"index"``.
    They are compared against the equivalent NumPy masking.
    """
    a = Tensor([1.0, 2.0])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a):
        return F.cond_take(a > 1, a)

    fwd(a)

    source_buf = io.BytesIO()
    fwd.dump(
        source_buf,
        arg_names=["a"],
        output_names=["o1", "o2"],
        optimize_for_inference=False,
    )
    source_buf.seek(0)
    network = Net.load(source_buf)

    in_a = network.input_vars[0]
    taken_values, taken_indices = F.cond_take(in_a > 1, in_a)

    network.remove_output(*network.output_vars)
    taken_values.name = "value"
    taken_indices.name = "index"
    network.add_output(taken_values, taken_indices)

    result_buf = io.BytesIO()
    network.dump(result_buf)
    result_buf.seek(0)

    inference = GraphInference(result_buf)
    results = inference.run(a.numpy())

    # Reference result computed directly with NumPy.
    data = a.numpy()
    mask = data > 1
    expected_indices = np.where(mask.reshape(-1))[0]
    np.testing.assert_equal(results["index"], expected_indices)
    np.testing.assert_equal(results["value"], data[mask])
def test_set_symbolic_shape():
    """Check how ``VarNode.shape`` reacts to the symbolic-shape switch.

    With the switch on, ``shape`` must be a ``VarNode``; with it off,
    ``shape`` must equal ``partial_shape``.
    """
    a = Tensor([1.0, 2.0])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a):
        return F.relu(a * 2)

    fwd(a)

    orig_model = io.BytesIO()
    fwd.dump(
        orig_model, arg_names=["a"], output_names=["o"], optimize_for_inference=False,
    )
    orig_model.seek(0)
    net = Net.load(orig_model)
    var_a = net.input_vars[0]

    # The value returned here is handed back to set_symbolic_shape at the
    # end -- presumably it is the previous setting.
    saved_symbolic_shape = set_symbolic_shape(True)
    try:
        assert isinstance(var_a.shape, VarNode)
        set_symbolic_shape(False)
        assert var_a.shape == var_a.partial_shape
    finally:
        # Restore the switch even when an assertion fails, so a failure here
        # cannot leak a changed setting into other tests.
        set_symbolic_shape(saved_symbolic_shape)
def test_replace_var_in_different_network():
    """Replace vars of one network with a var that lives in another network.

    Both inputs ``a`` and ``b`` of ``(a + b) * 2`` are replaced by the
    output ``o`` of a second network computing ``c + d``.  The re-dumped
    graph is expected to yield ``[16, 24]`` for the inputs below.
    """
    a = Tensor([1, 2])
    b = Tensor([3, 4])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a, b):
        return (a + b) * 2

    @trace(symbolic=True, capture_as_const=True)
    def fwd1(c, d):
        return c + d

    # Dump the first network.
    fwd(a, b)
    first_buf = io.BytesIO()
    fwd.dump(
        first_buf,
        arg_names=["a", "b"],
        output_names="o",
        optimize_for_inference=False,
    )
    first_buf.seek(0)

    # Dump the second network.
    fwd1(a, b)
    second_buf = io.BytesIO()
    fwd1.dump(
        second_buf,
        arg_names=["c", "d"],
        output_names="o",
        optimize_for_inference=False,
    )
    second_buf.seek(0)

    target_net = Net.load(first_buf)
    donor_net = Net.load(second_buf)

    var_a = target_net.var_filter.name("a").as_unique()
    var_b = target_net.var_filter.name("b").as_unique()
    donor_out = donor_net.var_filter.name("o").as_unique()

    # Both inputs of the first network now read from the donor's output.
    target_net.replace_vars({var_a: donor_out, var_b: donor_out})

    result_buf = io.BytesIO()
    target_net.dump(result_buf)
    result_buf.seek(0)

    inference = GraphInference(result_buf)
    results = inference.run(a, b)
    np.testing.assert_equal(results["o"], [16, 24])