1use std::collections::HashMap;
4
5use crate::core::HugrNode;
6use crate::hugr::hugrmut::InsertionResult;
7pub use crate::hugr::internal::HugrMutInternals;
8use crate::hugr::views::SiblingSubgraph;
9use crate::hugr::{HugrMut, HugrView, Rewrite};
10use crate::ops::{OpTag, OpTrait, OpType};
11use crate::{Hugr, IncomingPort, Node};
12
13use thiserror::Error;
14
15use super::inline_dfg::InlineDFGError;
16
/// Specification of a simple replacement operation.
///
/// Bundles a subgraph of a host HUGR, a replacement HUGR, and the two
/// boundary maps that describe how the replacement is wired into the host.
#[derive(Debug, Clone)]
pub struct SimpleReplacement<HostNode = Node> {
    /// The subgraph of the host HUGR that is to be replaced.
    subgraph: SiblingSubgraph<HostNode>,
    /// The HUGR that takes the place of `subgraph`.
    replacement: Hugr,
    /// Maps an incoming port `(node, port)` of `replacement` to an incoming
    /// port `(node, port)` in the host.
    nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>,
    /// Maps an incoming port `(node, port)` in the host to an incoming port
    /// of the replacement.
    // NOTE(review): presumably the value always refers to a port of the
    // replacement's Output node — confirm against the code that consumes it.
    nu_out: HashMap<(HostNode, IncomingPort), IncomingPort>,
}
35
impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
    /// Create a new [`SimpleReplacement`] specification.
    ///
    /// The arguments are stored exactly as given; no validation is performed
    /// by this constructor.
    #[inline]
    pub fn new(
        subgraph: SiblingSubgraph<HostNode>,
        replacement: Hugr,
        nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>,
        nu_out: HashMap<(HostNode, IncomingPort), IncomingPort>,
    ) -> Self {
        Self {
            subgraph,
            replacement,
            nu_inp,
            nu_out,
        }
    }

    /// The replacement HUGR.
    #[inline]
    pub fn replacement(&self) -> &Hugr {
        &self.replacement
    }

    /// The subgraph of the host that is to be replaced.
    #[inline]
    pub fn subgraph(&self) -> &SiblingSubgraph<HostNode> {
        &self.subgraph
    }
}
65
66impl Rewrite for SimpleReplacement {
67 type Error = SimpleReplacementError;
68 type ApplyResult = Vec<(Node, OpType)>;
69 const UNCHANGED_ON_FAILURE: bool = true;
70
71 fn verify(&self, _h: &impl HugrView<Node = Node>) -> Result<(), SimpleReplacementError> {
72 unimplemented!()
73 }
74
75 fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
76 let Self {
77 subgraph,
78 replacement,
79 nu_inp,
80 nu_out,
81 } = self;
82 let parent = subgraph.get_parent(h);
83 if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
85 return Err(SimpleReplacementError::InvalidParentNode());
86 }
87 for node in subgraph.nodes() {
89 if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
90 return Err(SimpleReplacementError::InvalidRemovedNode());
91 }
92 }
93
94 let replacement_output_node = replacement
95 .get_io(replacement.root())
96 .expect("parent already checked.")[1];
97
98 let nu_inp_connects: Vec<_> = nu_inp
108 .iter()
109 .filter(|&((rep_inp_node, _), _)| {
110 replacement.get_optype(*rep_inp_node).tag() != OpTag::Output
111 })
112 .map(
113 |((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| {
114 let (rem_inp_pred_node, rem_inp_pred_port) = h
116 .single_linked_output(*rem_inp_node, *rem_inp_port)
117 .unwrap();
118 (
119 rem_inp_pred_node,
120 rem_inp_pred_port,
121 rep_inp_node,
123 rep_inp_port,
124 )
125 },
126 )
127 .collect();
128
129 let nu_out_connects: Vec<_> = nu_out
132 .iter()
133 .filter_map(|((rem_out_node, rem_out_port), rep_out_port)| {
134 let (rep_out_pred_node, rep_out_pred_port) = replacement
135 .single_linked_output(replacement_output_node, *rep_out_port)
136 .unwrap();
137 (replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input).then_some({
138 (
139 rep_out_pred_node,
141 rep_out_pred_port,
142 rem_out_node,
143 rem_out_port,
144 )
145 })
146 })
147 .collect();
148
149 let InsertionResult {
151 new_root,
152 node_map: index_map,
153 } = h.insert_hugr(parent, replacement);
154
155 let replace_children = h.children(new_root).collect::<Vec<Node>>();
157 for &io in &replace_children[..2] {
158 h.remove_node(io);
159 }
160 for &child in &replace_children[2..] {
162 h.set_parent(child, parent);
163 }
164 h.remove_node(new_root);
166
167 for (src_node, src_port, tgt_node, tgt_port) in nu_inp_connects {
169 h.connect(
170 src_node,
171 src_port,
172 *index_map.get(tgt_node).unwrap(),
173 *tgt_port,
174 )
175 }
176
177 for (src_node, src_port, tgt_node, tgt_port) in nu_out_connects {
178 h.connect(
179 *index_map.get(&src_node).unwrap(),
180 src_port,
181 *tgt_node,
182 *tgt_port,
183 )
184 }
185 for ((rem_out_node, rem_out_port), &rep_out_port) in &nu_out {
190 let rem_inp_nodeport = nu_inp.get(&(replacement_output_node, rep_out_port));
191 if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport {
192 let (rem_inp_pred_node, rem_inp_pred_port) = h
194 .single_linked_output(*rem_inp_node, *rem_inp_port)
195 .unwrap();
196
197 h.connect(
198 rem_inp_pred_node,
199 rem_inp_pred_port,
200 *rem_out_node,
201 *rem_out_port,
202 );
203 }
204 }
205
206 Ok(subgraph
208 .nodes()
209 .iter()
210 .map(|&node| (node, h.remove_node(node)))
211 .collect())
212 }
213
214 #[inline]
215 fn invalidation_set(&self) -> impl Iterator<Item = Node> {
216 let subcirc = self.subgraph.nodes().iter().copied();
217 let out_neighs = self.nu_out.keys().map(|key| key.0);
218 subcirc.chain(out_neighs)
219 }
220}
221
/// Error from a [`SimpleReplacement`] operation.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum SimpleReplacementError {
    /// The parent of the subgraph to be replaced is not a dataflow parent.
    #[error("Parent node is invalid.")]
    InvalidParentNode(),
    /// A node of the subgraph is not a child of the expected parent, or has
    /// children of its own.
    #[error("A node requested for removal is invalid.")]
    InvalidRemovedNode(),
    /// A node in the replacement graph is invalid.
    #[error("A node in the replacement graph is invalid.")]
    InvalidReplacementNode(),
    /// Inlining the replacement failed; wraps the underlying
    /// [`InlineDFGError`], from which it can be converted.
    #[error("Inlining replacement failed: {0}")]
    InliningFailed(#[from] InlineDFGError),
}
239
240#[cfg(test)]
241pub(in crate::hugr::rewrite) mod test {
242 use itertools::Itertools;
243 use rstest::{fixture, rstest};
244 use std::collections::{HashMap, HashSet};
245
246 use crate::builder::test::n_identity;
247 use crate::builder::{
248 endo_sig, inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr,
249 DataflowSubContainer, HugrBuilder, ModuleBuilder,
250 };
251 use crate::extension::prelude::{bool_t, qb_t};
252 use crate::extension::ExtensionSet;
253 use crate::hugr::views::{HugrView, SiblingSubgraph};
254 use crate::hugr::{Hugr, HugrMut, Rewrite};
255 use crate::ops::dataflow::DataflowOpTrait;
256 use crate::ops::handle::NodeHandle;
257 use crate::ops::OpTag;
258 use crate::ops::OpTrait;
259 use crate::std_extensions::logic::test::and_op;
260 use crate::std_extensions::logic::LogicOp;
261 use crate::types::{Signature, Type};
262 use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID};
263 use crate::{IncomingPort, Node};
264
265 use super::SimpleReplacement;
266
267 fn make_hugr() -> Result<Hugr, BuildError> {
277 let mut module_builder = ModuleBuilder::new();
278 let _f_id = {
279 let just_q: ExtensionSet = EXTENSION_ID.into();
280 let mut func_builder = module_builder.define_function(
281 "main",
282 Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])
283 .with_extension_delta(just_q.clone()),
284 )?;
285
286 let [qb0, qb1, qb2] = func_builder.input_wires_arr();
287
288 let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?;
289
290 let mut inner_builder =
291 func_builder.dfg_builder_endo([(qb_t(), qb0), (qb_t(), qb1)])?;
292 let inner_graph = {
293 let [wire0, wire1] = inner_builder.input_wires_arr();
294 let wire2 = inner_builder.add_dataflow_op(h_gate(), vec![wire0])?;
295 let wire3 = inner_builder.add_dataflow_op(h_gate(), vec![wire1])?;
296 let wire45 = inner_builder
297 .add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
298 let [wire4, wire5] = wire45.outputs_arr();
299 let wire6 = inner_builder.add_dataflow_op(h_gate(), vec![wire4])?;
300 let wire7 = inner_builder.add_dataflow_op(h_gate(), vec![wire5])?;
301 inner_builder.finish_with_outputs(wire6.outputs().chain(wire7.outputs()))
302 }?;
303
304 func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))?
305 };
306 Ok(module_builder.finish_hugr()?)
307 }
308
    /// Fixture returning the HUGR built by `make_hugr`.
    ///
    /// Panics if building fails.
    #[fixture]
    pub(in crate::hugr::rewrite) fn simple_hugr() -> Hugr {
        make_hugr().unwrap()
    }
313 fn make_dfg_hugr() -> Result<Hugr, BuildError> {
320 let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]).with_prelude())?;
321 let [wire0, wire1] = dfg_builder.input_wires_arr();
322 let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?;
323 let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
324 let wire45 =
325 dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
326 dfg_builder.finish_hugr_with_outputs(wire45.outputs())
327 }
328
    /// Fixture returning the HUGR built by `make_dfg_hugr`.
    ///
    /// Panics if building fails.
    #[fixture]
    pub(in crate::hugr::rewrite) fn dfg_hugr() -> Hugr {
        make_dfg_hugr().unwrap()
    }
333
334 fn make_dfg_hugr2() -> Result<Hugr, BuildError> {
340 let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
341
342 let [wire0, wire1] = dfg_builder.input_wires_arr();
343 let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
344 let wire2out = wire2.outputs().exactly_one().unwrap();
345 let wireoutvec = vec![wire0, wire2out];
346 dfg_builder.finish_hugr_with_outputs(wireoutvec)
347 }
348
    /// Fixture returning the HUGR built by `make_dfg_hugr2`.
    ///
    /// Panics if building fails.
    #[fixture]
    pub(in crate::hugr::rewrite) fn dfg_hugr2() -> Hugr {
        make_dfg_hugr2().unwrap()
    }
353
354 #[fixture]
366 pub(in crate::hugr::rewrite) fn dfg_hugr_copy_bools() -> (Hugr, Vec<Node>) {
367 let mut dfg_builder =
368 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
369 let [b] = dfg_builder.input_wires_arr();
370
371 let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
372 let [b] = not_inp.outputs_arr();
373
374 let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
375 let [b0] = not_0.outputs_arr();
376 let not_1 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
377 let [b1] = not_1.outputs_arr();
378
379 (
380 dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
381 vec![not_inp.node(), not_0.node(), not_1.node()],
382 )
383 }
384
385 #[fixture]
397 pub(in crate::hugr::rewrite) fn dfg_hugr_half_not_bools() -> (Hugr, Vec<Node>) {
398 let mut dfg_builder =
399 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
400 let [b] = dfg_builder.input_wires_arr();
401
402 let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
403 let [b] = not_inp.outputs_arr();
404
405 let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
406 let [b0] = not_0.outputs_arr();
407 let b1 = b;
408
409 (
410 dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
411 vec![not_inp.node(), not_0.node()],
412 )
413 }
414
415 #[rstest]
416 fn test_simple_replacement(
435 simple_hugr: Hugr,
436 dfg_hugr: Hugr,
437 #[values(apply_simple, apply_replace)] applicator: impl Fn(&mut Hugr, SimpleReplacement),
438 ) {
439 let mut h: Hugr = simple_hugr;
440 let h_node_cx: Node = h
442 .nodes()
443 .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
444 .unwrap();
445 let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
446 let s: Vec<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
447 let n: Hugr = dfg_hugr;
449 let n_node_cx = n
452 .nodes()
453 .find(|node: &Node| *n.get_optype(*node) == cx_gate().into())
454 .unwrap();
455 let (n_node_h0, n_node_h1) = n.input_neighbours(n_node_cx).collect_tuple().unwrap();
456 let n_port_0 = n.node_inputs(n_node_h0).next().unwrap();
458 let n_port_1 = n.node_inputs(n_node_h1).next().unwrap();
459 let (n_cx_out_0, n_cx_out_1) = n.node_outputs(n_node_cx).take(2).collect_tuple().unwrap();
460 let n_port_2 = n.linked_inputs(n_node_cx, n_cx_out_0).next().unwrap().1;
461 let n_port_3 = n.linked_inputs(n_node_cx, n_cx_out_1).next().unwrap().1;
462 let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
464 let h_h0_out = h.node_outputs(h_node_h0).next().unwrap();
465 let h_h1_out = h.node_outputs(h_node_h1).next().unwrap();
466 let (h_outp_node, h_port_2) = h.linked_inputs(h_node_h0, h_h0_out).next().unwrap();
467 let h_port_3 = h.linked_inputs(h_node_h1, h_h1_out).next().unwrap().1;
468 let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
470 let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
471 nu_inp.insert((n_node_h0, n_port_0), (h_node_cx, h_port_0));
472 nu_inp.insert((n_node_h1, n_port_1), (h_node_cx, h_port_1));
473 nu_out.insert((h_outp_node, h_port_2), n_port_2);
474 nu_out.insert((h_outp_node, h_port_3), n_port_3);
475 let r = SimpleReplacement {
477 subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
478 replacement: n,
479 nu_inp,
480 nu_out,
481 };
482 assert_eq!(
483 HashSet::<_>::from_iter(r.invalidation_set()),
484 HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1, h_outp_node]),
485 );
486
487 applicator(&mut h, r);
488 assert_eq!(h.validate(), Ok(()));
495 }
496
497 #[rstest]
498 fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
516 let mut h: Hugr = simple_hugr;
517
518 let h_node_cx: Node = h
520 .nodes()
521 .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
522 .unwrap();
523 let s: Vec<Node> = vec![h_node_cx].into_iter().collect();
524 let n: Hugr = dfg_hugr2;
526 let n_node_output = n
529 .nodes()
530 .find(|node: &Node| n.get_optype(*node).tag() == OpTag::Output)
531 .unwrap();
532 let (_n_node_input, n_node_h) = n.input_neighbours(n_node_output).collect_tuple().unwrap();
533 let (n_port_0, n_port_1) = n
535 .node_inputs(n_node_output)
536 .take(2)
537 .collect_tuple()
538 .unwrap();
539 let n_port_2 = n.node_inputs(n_node_h).next().unwrap();
540 let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
542 let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
543 let h_port_2 = h.node_inputs(h_node_h0).next().unwrap();
544 let h_port_3 = h.node_inputs(h_node_h1).next().unwrap();
545 let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
547 let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
548 nu_inp.insert((n_node_output, n_port_0), (h_node_cx, h_port_0));
549 nu_inp.insert((n_node_h, n_port_2), (h_node_cx, h_port_1));
550 nu_out.insert((h_node_h0, h_port_2), n_port_0);
551 nu_out.insert((h_node_h1, h_port_3), n_port_1);
552 let r = SimpleReplacement {
554 subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
555 replacement: n,
556 nu_inp,
557 nu_out,
558 };
559 h.apply_rewrite(r).unwrap();
560 assert_eq!(h.validate(), Ok(()));
567 }
568
    /// Replaces both CX gates of a two-CX circuit with a clone of the whole
    /// circuit and checks that the edge count is unchanged.
    #[test]
    fn test_replace_cx_cross() {
        let q_row: Vec<Type> = vec![qb_t(), qb_t()];
        let mut builder = DFGBuilder::new(endo_sig(q_row)).unwrap();
        let mut circ = builder.as_circuit(builder.input_wires());
        // Second CX has its qubit arguments swapped relative to the first.
        circ.append(cx_gate(), [0, 1]).unwrap();
        circ.append(cx_gate(), [1, 0]).unwrap();
        let wires = circ.finish();
        let [input, output] = builder.io();
        let mut h = builder.finish_hugr_with_outputs(wires).unwrap();
        // The replacement is a clone of the host, so node indices coincide
        // and the same (node, port) pairs can be used on both sides below.
        let replacement = h.clone();
        let orig = h.clone();

        // Remove every leaf operation.
        let removal = h
            .nodes()
            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
            .collect_vec();
        // One entry per Input port that has a type in the Input signature,
        // mapping the first port it is linked to onto itself.
        let inputs = h
            .node_outputs(input)
            .filter(|&p| {
                h.get_optype(input)
                    .as_input()
                    .unwrap()
                    .signature()
                    .port_type(p)
                    .is_some()
            })
            .map(|p| {
                let link = h.linked_inputs(input, p).next().unwrap();
                (link, link)
            })
            .collect();
        // One entry per Output port that has a type in the Output signature,
        // mapping it to the port of the same index on the replacement.
        let outputs = h
            .node_inputs(output)
            .filter(|&p| {
                h.get_optype(output)
                    .as_output()
                    .unwrap()
                    .signature()
                    .port_type(p)
                    .is_some()
            })
            .map(|p| ((output, p), p))
            .collect();
        h.apply_rewrite(SimpleReplacement::new(
            SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
            replacement,
            inputs,
            outputs,
        ))
        .unwrap();

        // Replacing the circuit with a copy of itself should keep the number
        // of edges the same.
        assert_eq!(h.edge_count(), orig.edge_count());
    }
624
    /// Replaces an AND whose two inputs come from the same wire with a
    /// two-input AND graph and checks that the node count is unchanged.
    #[test]
    fn test_replace_after_copy() {
        let one_bit = vec![bool_t()];
        let two_bit = vec![bool_t(), bool_t()];

        // Host: one bool in, AND of that bool with itself, one bool out.
        let mut builder = DFGBuilder::new(endo_sig(one_bit.clone())).unwrap();
        let inw = builder.input_wires().exactly_one().unwrap();
        let outw = builder
            .add_dataflow_op(and_op(), [inw, inw])
            .unwrap()
            .outputs();
        let [input, _] = builder.io();
        let mut h = builder.finish_hugr_with_outputs(outw).unwrap();

        // Replacement: two bools in, AND of the two, one bool out.
        let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap();
        let inw = builder.input_wires();
        let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs();
        let [repl_input, repl_output] = builder.io();
        let repl = builder.finish_hugr_with_outputs(outw).unwrap();

        let orig = h.clone();

        // Remove every leaf operation.
        let removal = h
            .nodes()
            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
            .collect_vec();

        // All ports in the host linked to the first output port of Input.
        let first_out_p = h.node_outputs(input).next().unwrap();
        let embedded_inputs = h.linked_inputs(input, first_out_p);
        // The first port linked to each output port of the replacement Input.
        let repl_inputs = repl
            .node_outputs(repl_input)
            .map(|p| repl.linked_inputs(repl_input, p).next().unwrap());
        // NOTE(review): the keys here come from the host and the values from
        // the replacement, the reverse of the direction declared for `nu_inp`;
        // presumably this only works because both graphs assign the same
        // indices to corresponding nodes — confirm.
        let inputs = embedded_inputs.zip(repl_inputs).collect();

        // NOTE(review): keyed by the replacement's Output node where a host
        // node is expected; presumably equal to the host Output node's index
        // — confirm.
        let outputs = repl
            .node_inputs(repl_output)
            .filter(|&p| repl.signature(repl_output).unwrap().port_type(p).is_some())
            .map(|p| ((repl_output, p), p))
            .collect();

        h.apply_rewrite(SimpleReplacement::new(
            SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
            repl,
            inputs,
            outputs,
        ))
        .unwrap();

        // One leaf was removed and one inserted.
        assert_eq!(h.node_count(), orig.node_count());
    }
676
677 #[rstest]
682 fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec<Node>)) {
683 let (mut hugr, nodes) = dfg_hugr_copy_bools;
684 let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap();
685
686 let [_input, output] = hugr.get_io(hugr.root()).unwrap();
687
688 let replacement = {
689 let b =
690 DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
691 let [w] = b.input_wires_arr();
692 b.finish_hugr_with_outputs([w, w]).unwrap()
693 };
694 let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();
695
696 let subgraph =
697 SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)
698 .unwrap();
699 let nu_inp = [
702 (
703 (repl_output, IncomingPort::from(0)),
704 (input_not, IncomingPort::from(0)),
705 ),
706 (
707 (repl_output, IncomingPort::from(1)),
708 (input_not, IncomingPort::from(0)),
709 ),
710 ]
711 .into_iter()
712 .collect();
713 let nu_out = [
716 ((output, IncomingPort::from(0)), IncomingPort::from(0)),
717 ((output, IncomingPort::from(1)), IncomingPort::from(1)),
718 ]
719 .into_iter()
720 .collect();
721
722 let rewrite = SimpleReplacement {
723 subgraph,
724 replacement,
725 nu_inp,
726 nu_out,
727 };
728 rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
729
730 assert_eq!(hugr.validate(), Ok(()));
731 assert_eq!(hugr.node_count(), 3);
732 }
733
734 #[rstest]
739 fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec<Node>)) {
740 let (mut hugr, nodes) = dfg_hugr_half_not_bools;
741 let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap();
742
743 let [_input, output] = hugr.get_io(hugr.root()).unwrap();
744
745 let (replacement, repl_not) = {
746 let mut b =
747 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
748 let [w] = b.input_wires_arr();
749 let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
750 let [w_not] = not.outputs_arr();
751 (b.finish_hugr_with_outputs([w, w_not]).unwrap(), not.node())
752 };
753 let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();
754
755 let subgraph =
756 SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap();
757 let nu_inp = [
760 (
761 (repl_output, IncomingPort::from(0)),
762 (input_not, IncomingPort::from(0)),
763 ),
764 (
765 (repl_not, IncomingPort::from(0)),
766 (input_not, IncomingPort::from(0)),
767 ),
768 ]
769 .into_iter()
770 .collect();
771 let nu_out = [
774 ((output, IncomingPort::from(0)), IncomingPort::from(0)),
775 ((output, IncomingPort::from(1)), IncomingPort::from(1)),
776 ]
777 .into_iter()
778 .collect();
779
780 let rewrite = SimpleReplacement {
781 subgraph,
782 replacement,
783 nu_inp,
784 nu_out,
785 };
786 rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
787
788 assert_eq!(hugr.validate(), Ok(()));
789 assert_eq!(hugr.node_count(), 4);
790 }
791
792 #[rstest]
793 fn test_nested_replace(dfg_hugr2: Hugr) {
794 let mut h = dfg_hugr2;
797 let h_node = h
798 .nodes()
799 .find(|node: &Node| *h.get_optype(*node) == h_gate().into())
800 .unwrap();
801
802 let mut nest_build = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
804 let [input] = nest_build.input_wires_arr();
805 let inner_build = nest_build.dfg_builder_endo([(qb_t(), input)]).unwrap();
806 let inner_dfg = n_identity(inner_build).unwrap();
807 let inner_dfg_node = inner_dfg.node();
808 let replacement = nest_build
809 .finish_hugr_with_outputs([inner_dfg.out_wire(0)])
810 .unwrap();
811 let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
812 let nu_inp = vec![(
813 (inner_dfg_node, IncomingPort::from(0)),
814 (h_node, IncomingPort::from(0)),
815 )]
816 .into_iter()
817 .collect();
818
819 let nu_out = vec![(
820 (h.get_io(h.root()).unwrap()[1], IncomingPort::from(1)),
821 IncomingPort::from(0),
822 )]
823 .into_iter()
824 .collect();
825
826 let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out);
827
828 assert_eq!(h.node_count(), 4);
829
830 rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
831 h.validate().unwrap_or_else(|e| panic!("{e}"));
832
833 assert_eq!(h.node_count(), 6);
834 }
835
836 use crate::hugr::rewrite::replace::Replacement;
837 fn to_replace(h: &impl HugrView<Node = Node>, s: SimpleReplacement) -> Replacement {
838 use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec};
839
840 let mut replacement = s.replacement;
841 let (in_, out) = replacement
842 .children(replacement.root())
843 .take(2)
844 .collect_tuple()
845 .unwrap();
846 let mu_inp = s
847 .nu_inp
848 .iter()
849 .map(|((tgt, tgt_port), (r_n, r_p))| {
850 if *tgt == out {
851 unimplemented!()
852 };
853 let (src, src_port) = h.single_linked_output(*r_n, *r_p).unwrap();
854 NewEdgeSpec {
855 src,
856 tgt: *tgt,
857 kind: NewEdgeKind::Value {
858 src_pos: src_port,
859 tgt_pos: *tgt_port,
860 },
861 }
862 })
863 .collect();
864 let mu_out = s
865 .nu_out
866 .iter()
867 .map(|((tgt, tgt_port), out_port)| {
868 let (src, src_port) = replacement.single_linked_output(out, *out_port).unwrap();
869 if src == in_ {
870 unimplemented!()
871 };
872 NewEdgeSpec {
873 src,
874 tgt: *tgt,
875 kind: NewEdgeKind::Value {
876 src_pos: src_port,
877 tgt_pos: *tgt_port,
878 },
879 }
880 })
881 .collect();
882 replacement.remove_node(in_);
883 replacement.remove_node(out);
884 Replacement {
885 removal: s.subgraph.nodes().to_vec(),
886 replacement,
887 adoptions: HashMap::new(),
888 mu_inp,
889 mu_out,
890 mu_new: vec![],
891 }
892 }
893
    /// Applies `rw` to `h` as-is, panicking if the rewrite fails.
    fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) {
        h.apply_rewrite(rw).unwrap();
    }
897
    /// Converts `rw` with `to_replace` and applies the result to `h`,
    /// panicking if the rewrite fails.
    fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
        h.apply_rewrite(to_replace(h, rw)).unwrap();
    }
901}