import unittest
import retworkx
class TestCompoes(unittest.TestCase):
def test_simple_dag_composition(self):
dag = retworkx.PyDAG()
dag.check_cycle = True
node_a = dag.add_node('a')
node_b = dag.add_child(node_a, 'b', {'a': 1})
node_c = dag.add_child(node_b, 'c', {'a': 2})
dag_other = retworkx.PyDAG()
node_d = dag_other.add_node('d')
dag_other.add_child(node_d, 'e', {'a': 3})
res = dag.compose(dag_other, {node_c: (node_d, {'b': 1})})
self.assertEqual({0: 3, 1: 4}, res)
self.assertEqual([0, 1, 2, 3, 4], retworkx.topological_sort(dag))
def test_compose_graph_onto_digraph_error(self):
digraph = retworkx.PyDiGraph()
graph = retworkx.PyGraph()
with self.assertRaises(TypeError):
digraph.compose(graph, {})
def test_simple_graph_composition(self):
graph = retworkx.PyGraph()
node_a = graph.add_node('a')
node_b = graph.add_node('b')
graph.add_edge(node_a, node_b, {'a': 1})
node_c = graph.add_node('c')
graph.add_edge(node_b, node_c, {'a': 2})
graph_other = retworkx.PyGraph()
node_d = graph_other.add_node('d')
node_e = graph_other.add_node('e')
graph_other.add_edge(node_d, node_e, {'a': 3})
res = graph.compose(graph_other, {node_c: (node_d, {'b': 1})})
self.assertEqual({0: 3, 1: 4}, res)
self.assertEqual([0, 1, 2, 3, 4], graph.node_indexes())
def test_compose_digraph_onto_graph_error(self):
digraph = retworkx.PyDiGraph()
graph = retworkx.PyGraph()
with self.assertRaises(TypeError):
graph.compose(digraph, {})
def test_edge_map_and_node_map_funcs_digraph_compose(self):
digraph = retworkx.PyDiGraph()
original_input_nodes = digraph.add_nodes_from(['qr[0]', 'qr[1]'])
original_op_nodes = digraph.add_nodes_from(['h'])
output_nodes = digraph.add_nodes_from(['qr[0]', 'qr[1]'])
digraph.add_edge(original_input_nodes[0], original_op_nodes[0],
'qr[0]')
digraph.add_edge(original_op_nodes[0], output_nodes[0], 'qr[0]')
other_digraph = retworkx.PyDiGraph()
input_nodes = other_digraph.add_nodes_from(['qr[2]', 'qr[3]'])
op_nodes = other_digraph.add_nodes_from(['cx'])
other_output_nodes = other_digraph.add_nodes_from(['qr[2]', 'qr[3]'])
other_digraph.add_edges_from([(input_nodes[0], op_nodes[0], 'qr[2]'),
(input_nodes[1], op_nodes[0], 'qr[3]')])
other_digraph.add_edges_from([
(op_nodes[0], other_output_nodes[0], 'qr[2]'),
(op_nodes[0], other_output_nodes[1], 'qr[3]')])
def map_fn(weight):
if weight == 'qr[2]':
return 'qr[0]'
elif weight == 'qr[3]':
return 'qr[1]'
else:
return weight
digraph.remove_nodes_from(output_nodes)
other_digraph.remove_nodes_from(input_nodes)
node_map = {original_op_nodes[0]: (op_nodes[0], 'qr[0]'),
original_input_nodes[1]: (op_nodes[0], 'qr[1]')}
res = digraph.compose(other_digraph, node_map,
node_map_func=map_fn,
edge_map_func=map_fn)
self.assertEqual({2: 4, 3: 3, 4: 5}, res)
self.assertEqual(digraph[res[other_output_nodes[0]]], 'qr[0]')
self.assertEqual(digraph[res[other_output_nodes[1]]], 'qr[1]')
self.assertTrue(digraph.has_edge(0, 2))
self.assertTrue(digraph.get_all_edge_data(0, 2), ['qr[0]'])
self.assertTrue(digraph.has_edge(1, 4))
self.assertTrue(digraph.get_all_edge_data(1, 4), ['qr[1]'])
self.assertTrue(digraph.has_edge(2, 4))
self.assertTrue(digraph.get_all_edge_data(0, 2), ['qr[0]'])
self.assertTrue(digraph.has_edge(4, 3))
self.assertTrue(digraph.get_all_edge_data(0, 2), ['qr[0]'])
self.assertTrue(digraph.has_edge(4, 5))
self.assertTrue(digraph.get_all_edge_data(0, 2), ['qr[1]'])
def test_edge_map_and_node_map_funcs_graph_compose(self):
graph = retworkx.PyGraph()
original_input_nodes = graph.add_nodes_from(['qr[0]', 'qr[1]'])
original_op_nodes = graph.add_nodes_from(['h'])
output_nodes = graph.add_nodes_from(['qr[0]', 'qr[1]'])
graph.add_edge(original_input_nodes[0], original_op_nodes[0],
'qr[0]')
graph.add_edge(original_op_nodes[0], output_nodes[0], 'qr[0]')
other_graph = retworkx.PyGraph()
input_nodes = other_graph.add_nodes_from(['qr[2]', 'qr[3]'])
op_nodes = other_graph.add_nodes_from(['cx'])
other_output_nodes = other_graph.add_nodes_from(['qr[2]', 'qr[3]'])
other_graph.add_edges_from([(input_nodes[0], op_nodes[0], 'qr[2]'),
(input_nodes[1], op_nodes[0], 'qr[3]')])
other_graph.add_edges_from([
(op_nodes[0], other_output_nodes[0], 'qr[2]'),
(op_nodes[0], other_output_nodes[1], 'qr[3]')])
def map_fn(weight):
if weight == 'qr[2]':
return 'qr[0]'
elif weight == 'qr[3]':
return 'qr[1]'
else:
return weight
graph.remove_nodes_from(output_nodes)
other_graph.remove_nodes_from(input_nodes)
node_map = {original_op_nodes[0]: (op_nodes[0], 'qr[0]'),
original_input_nodes[1]: (op_nodes[0], 'qr[1]')}
res = graph.compose(other_graph, node_map, node_map_func=map_fn,
edge_map_func=map_fn)
self.assertEqual({2: 4, 3: 3, 4: 5}, res)
self.assertEqual(graph[res[other_output_nodes[0]]], 'qr[0]')
self.assertEqual(graph[res[other_output_nodes[1]]], 'qr[1]')
self.assertTrue(graph.has_edge(0, 2))
self.assertTrue(graph.get_all_edge_data(0, 2), ['qr[0]'])
self.assertTrue(graph.has_edge(1, 4))
self.assertTrue(graph.get_all_edge_data(0, 2), ['qr[1]'])
self.assertTrue(graph.has_edge(2, 4))
self.assertTrue(graph.get_all_edge_data(0, 2), ['qr[0]'])
self.assertTrue(graph.has_edge(4, 3))
self.assertTrue(graph.get_all_edge_data(0, 2), ['qr[0]'])
self.assertTrue(graph.has_edge(4, 5))
self.assertTrue(graph.get_all_edge_data(0, 2), ['qr[1]'])