1use std::collections::HashMap;
4
5use crate::hugr::hugrmut::InsertionResult;
6pub use crate::hugr::internal::HugrMutInternals;
7use crate::hugr::views::SiblingSubgraph;
8use crate::hugr::{HugrMut, HugrView, Rewrite};
9use crate::ops::{OpTag, OpTrait, OpType};
10use crate::{Hugr, IncomingPort, Node};
11use thiserror::Error;
12
13use super::inline_dfg::InlineDFGError;
14
/// Specification of a rewrite that swaps a sibling subgraph of a host hugr
/// for the contents of another hugr.
#[derive(Debug, Clone)]
pub struct SimpleReplacement {
    /// The nodes of the host hugr that the rewrite removes.
    subgraph: SiblingSubgraph,
    /// The hugr whose root's children (other than its first two, the I/O
    /// nodes) are moved into the host in place of `subgraph`.
    replacement: Hugr,
    /// Maps an incoming port of a node in `replacement` to an incoming port of
    /// a node in the host. When applied, the output currently linked to the
    /// host port is connected to the inserted copy of the replacement port.
    /// Keys on the replacement's Output node describe wires that pass
    /// straight through the replacement.
    nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)>,
    /// Maps an incoming port of a host node to an incoming port of the Output
    /// node of `replacement`. When applied, whatever feeds that Output port
    /// is connected to the host port.
    nu_out: HashMap<(Node, IncomingPort), IncomingPort>,
}
29
30impl SimpleReplacement {
31 #[inline]
33 pub fn new(
34 subgraph: SiblingSubgraph,
35 replacement: Hugr,
36 nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)>,
37 nu_out: HashMap<(Node, IncomingPort), IncomingPort>,
38 ) -> Self {
39 Self {
40 subgraph,
41 replacement,
42 nu_inp,
43 nu_out,
44 }
45 }
46
47 #[inline]
49 pub fn replacement(&self) -> &Hugr {
50 &self.replacement
51 }
52
53 #[inline]
55 pub fn subgraph(&self) -> &SiblingSubgraph {
56 &self.subgraph
57 }
58}
59
60impl Rewrite for SimpleReplacement {
61 type Error = SimpleReplacementError;
62 type ApplyResult = Vec<(Node, OpType)>;
63 const UNCHANGED_ON_FAILURE: bool = true;
64
65 fn verify(&self, _h: &impl HugrView) -> Result<(), SimpleReplacementError> {
66 unimplemented!()
67 }
68
69 fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
70 let Self {
71 subgraph,
72 replacement,
73 nu_inp,
74 nu_out,
75 } = self;
76 let parent = subgraph.get_parent(h);
77 if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
79 return Err(SimpleReplacementError::InvalidParentNode());
80 }
81 for node in subgraph.nodes() {
83 if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
84 return Err(SimpleReplacementError::InvalidRemovedNode());
85 }
86 }
87
88 let replacement_output_node = replacement
89 .get_io(replacement.root())
90 .expect("parent already checked.")[1];
91
92 let nu_inp_connects: Vec<_> = nu_inp
102 .iter()
103 .filter(|&((rep_inp_node, _), _)| {
104 replacement.get_optype(*rep_inp_node).tag() != OpTag::Output
105 })
106 .map(
107 |((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| {
108 let (rem_inp_pred_node, rem_inp_pred_port) = h
110 .single_linked_output(*rem_inp_node, *rem_inp_port)
111 .unwrap();
112 (
113 rem_inp_pred_node,
114 rem_inp_pred_port,
115 rep_inp_node,
117 rep_inp_port,
118 )
119 },
120 )
121 .collect();
122
123 let nu_out_connects: Vec<_> = nu_out
126 .iter()
127 .filter_map(|((rem_out_node, rem_out_port), rep_out_port)| {
128 let (rep_out_pred_node, rep_out_pred_port) = replacement
129 .single_linked_output(replacement_output_node, *rep_out_port)
130 .unwrap();
131 (replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input).then_some({
132 (
133 rep_out_pred_node,
135 rep_out_pred_port,
136 rem_out_node,
137 rem_out_port,
138 )
139 })
140 })
141 .collect();
142
143 let InsertionResult {
145 new_root,
146 node_map: index_map,
147 } = h.insert_hugr(parent, replacement);
148
149 let replace_children = h.children(new_root).collect::<Vec<Node>>();
151 for &io in &replace_children[..2] {
152 h.remove_node(io);
153 }
154 for &child in &replace_children[2..] {
156 h.set_parent(child, parent);
157 }
158 h.remove_node(new_root);
160
161 for (src_node, src_port, tgt_node, tgt_port) in nu_inp_connects {
163 h.connect(
164 src_node,
165 src_port,
166 *index_map.get(tgt_node).unwrap(),
167 *tgt_port,
168 )
169 }
170
171 for (src_node, src_port, tgt_node, tgt_port) in nu_out_connects {
172 h.connect(
173 *index_map.get(&src_node).unwrap(),
174 src_port,
175 *tgt_node,
176 *tgt_port,
177 )
178 }
179 for ((rem_out_node, rem_out_port), &rep_out_port) in &nu_out {
184 let rem_inp_nodeport = nu_inp.get(&(replacement_output_node, rep_out_port));
185 if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport {
186 let (rem_inp_pred_node, rem_inp_pred_port) = h
188 .single_linked_output(*rem_inp_node, *rem_inp_port)
189 .unwrap();
190
191 h.connect(
192 rem_inp_pred_node,
193 rem_inp_pred_port,
194 *rem_out_node,
195 *rem_out_port,
196 );
197 }
198 }
199
200 Ok(subgraph
202 .nodes()
203 .iter()
204 .map(|&node| (node, h.remove_node(node)))
205 .collect())
206 }
207
208 #[inline]
209 fn invalidation_set(&self) -> impl Iterator<Item = Node> {
210 let subcirc = self.subgraph.nodes().iter().copied();
211 let out_neighs = self.nu_out.keys().map(|key| key.0);
212 subcirc.chain(out_neighs)
213 }
214}
215
/// Errors that can be reported for a [`SimpleReplacement`].
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum SimpleReplacementError {
    /// The parent of the subgraph to be replaced is not acceptable; `apply`
    /// reports this when that parent is not tagged as a dataflow parent.
    #[error("Parent node is invalid.")]
    InvalidParentNode(),
    /// A node of the subgraph to be replaced is not acceptable; `apply`
    /// reports this when such a node is not a child of the subgraph's parent
    /// or has children of its own.
    #[error("A node requested for removal is invalid.")]
    InvalidRemovedNode(),
    /// A node in the replacement graph is invalid.
    #[error("A node in the replacement graph is invalid.")]
    InvalidReplacementNode(),
    /// Wraps an [`InlineDFGError`]; convertible from it via `From`.
    #[error("Inlining replacement failed: {0}")]
    InliningFailed(#[from] InlineDFGError),
}
233
234#[cfg(test)]
235pub(in crate::hugr::rewrite) mod test {
236 use itertools::Itertools;
237 use rstest::{fixture, rstest};
238 use std::collections::{HashMap, HashSet};
239
240 use crate::builder::test::n_identity;
241 use crate::builder::{
242 endo_sig, inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr,
243 DataflowSubContainer, HugrBuilder, ModuleBuilder,
244 };
245 use crate::extension::prelude::{bool_t, qb_t};
246 use crate::extension::ExtensionSet;
247 use crate::hugr::views::{HugrView, SiblingSubgraph};
248 use crate::hugr::{Hugr, HugrMut, Rewrite};
249 use crate::ops::dataflow::DataflowOpTrait;
250 use crate::ops::handle::NodeHandle;
251 use crate::ops::OpTag;
252 use crate::ops::OpTrait;
253 use crate::std_extensions::logic::test::and_op;
254 use crate::std_extensions::logic::LogicOp;
255 use crate::types::{Signature, Type};
256 use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID};
257 use crate::{IncomingPort, Node};
258
259 use super::SimpleReplacement;
260
261 fn make_hugr() -> Result<Hugr, BuildError> {
271 let mut module_builder = ModuleBuilder::new();
272 let _f_id = {
273 let just_q: ExtensionSet = EXTENSION_ID.into();
274 let mut func_builder = module_builder.define_function(
275 "main",
276 Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])
277 .with_extension_delta(just_q.clone()),
278 )?;
279
280 let [qb0, qb1, qb2] = func_builder.input_wires_arr();
281
282 let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?;
283
284 let mut inner_builder =
285 func_builder.dfg_builder_endo([(qb_t(), qb0), (qb_t(), qb1)])?;
286 let inner_graph = {
287 let [wire0, wire1] = inner_builder.input_wires_arr();
288 let wire2 = inner_builder.add_dataflow_op(h_gate(), vec![wire0])?;
289 let wire3 = inner_builder.add_dataflow_op(h_gate(), vec![wire1])?;
290 let wire45 = inner_builder
291 .add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
292 let [wire4, wire5] = wire45.outputs_arr();
293 let wire6 = inner_builder.add_dataflow_op(h_gate(), vec![wire4])?;
294 let wire7 = inner_builder.add_dataflow_op(h_gate(), vec![wire5])?;
295 inner_builder.finish_with_outputs(wire6.outputs().chain(wire7.outputs()))
296 }?;
297
298 func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))?
299 };
300 Ok(module_builder.finish_hugr()?)
301 }
302
303 #[fixture]
304 pub(in crate::hugr::rewrite) fn simple_hugr() -> Hugr {
305 make_hugr().unwrap()
306 }
307 fn make_dfg_hugr() -> Result<Hugr, BuildError> {
314 let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]).with_prelude())?;
315 let [wire0, wire1] = dfg_builder.input_wires_arr();
316 let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?;
317 let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
318 let wire45 =
319 dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
320 dfg_builder.finish_hugr_with_outputs(wire45.outputs())
321 }
322
323 #[fixture]
324 pub(in crate::hugr::rewrite) fn dfg_hugr() -> Hugr {
325 make_dfg_hugr().unwrap()
326 }
327
328 fn make_dfg_hugr2() -> Result<Hugr, BuildError> {
334 let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
335
336 let [wire0, wire1] = dfg_builder.input_wires_arr();
337 let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
338 let wire2out = wire2.outputs().exactly_one().unwrap();
339 let wireoutvec = vec![wire0, wire2out];
340 dfg_builder.finish_hugr_with_outputs(wireoutvec)
341 }
342
343 #[fixture]
344 pub(in crate::hugr::rewrite) fn dfg_hugr2() -> Hugr {
345 make_dfg_hugr2().unwrap()
346 }
347
348 #[fixture]
360 pub(in crate::hugr::rewrite) fn dfg_hugr_copy_bools() -> (Hugr, Vec<Node>) {
361 let mut dfg_builder =
362 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
363 let [b] = dfg_builder.input_wires_arr();
364
365 let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
366 let [b] = not_inp.outputs_arr();
367
368 let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
369 let [b0] = not_0.outputs_arr();
370 let not_1 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
371 let [b1] = not_1.outputs_arr();
372
373 (
374 dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
375 vec![not_inp.node(), not_0.node(), not_1.node()],
376 )
377 }
378
379 #[fixture]
391 pub(in crate::hugr::rewrite) fn dfg_hugr_half_not_bools() -> (Hugr, Vec<Node>) {
392 let mut dfg_builder =
393 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
394 let [b] = dfg_builder.input_wires_arr();
395
396 let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
397 let [b] = not_inp.outputs_arr();
398
399 let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
400 let [b0] = not_0.outputs_arr();
401 let b1 = b;
402
403 (
404 dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
405 vec![not_inp.node(), not_0.node()],
406 )
407 }
408
409 #[rstest]
410 fn test_simple_replacement(
429 simple_hugr: Hugr,
430 dfg_hugr: Hugr,
431 #[values(apply_simple, apply_replace)] applicator: impl Fn(&mut Hugr, SimpleReplacement),
432 ) {
433 let mut h: Hugr = simple_hugr;
434 let h_node_cx: Node = h
436 .nodes()
437 .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
438 .unwrap();
439 let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
440 let s: Vec<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
441 let n: Hugr = dfg_hugr;
443 let n_node_cx = n
446 .nodes()
447 .find(|node: &Node| *n.get_optype(*node) == cx_gate().into())
448 .unwrap();
449 let (n_node_h0, n_node_h1) = n.input_neighbours(n_node_cx).collect_tuple().unwrap();
450 let n_port_0 = n.node_inputs(n_node_h0).next().unwrap();
452 let n_port_1 = n.node_inputs(n_node_h1).next().unwrap();
453 let (n_cx_out_0, n_cx_out_1) = n.node_outputs(n_node_cx).take(2).collect_tuple().unwrap();
454 let n_port_2 = n.linked_inputs(n_node_cx, n_cx_out_0).next().unwrap().1;
455 let n_port_3 = n.linked_inputs(n_node_cx, n_cx_out_1).next().unwrap().1;
456 let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
458 let h_h0_out = h.node_outputs(h_node_h0).next().unwrap();
459 let h_h1_out = h.node_outputs(h_node_h1).next().unwrap();
460 let (h_outp_node, h_port_2) = h.linked_inputs(h_node_h0, h_h0_out).next().unwrap();
461 let h_port_3 = h.linked_inputs(h_node_h1, h_h1_out).next().unwrap().1;
462 let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
464 let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
465 nu_inp.insert((n_node_h0, n_port_0), (h_node_cx, h_port_0));
466 nu_inp.insert((n_node_h1, n_port_1), (h_node_cx, h_port_1));
467 nu_out.insert((h_outp_node, h_port_2), n_port_2);
468 nu_out.insert((h_outp_node, h_port_3), n_port_3);
469 let r = SimpleReplacement {
471 subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
472 replacement: n,
473 nu_inp,
474 nu_out,
475 };
476 assert_eq!(
477 HashSet::<_>::from_iter(r.invalidation_set()),
478 HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1, h_outp_node]),
479 );
480
481 applicator(&mut h, r);
482 assert_eq!(h.validate(), Ok(()));
489 }
490
491 #[rstest]
492 fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
510 let mut h: Hugr = simple_hugr;
511
512 let h_node_cx: Node = h
514 .nodes()
515 .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
516 .unwrap();
517 let s: Vec<Node> = vec![h_node_cx].into_iter().collect();
518 let n: Hugr = dfg_hugr2;
520 let n_node_output = n
523 .nodes()
524 .find(|node: &Node| n.get_optype(*node).tag() == OpTag::Output)
525 .unwrap();
526 let (_n_node_input, n_node_h) = n.input_neighbours(n_node_output).collect_tuple().unwrap();
527 let (n_port_0, n_port_1) = n
529 .node_inputs(n_node_output)
530 .take(2)
531 .collect_tuple()
532 .unwrap();
533 let n_port_2 = n.node_inputs(n_node_h).next().unwrap();
534 let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
536 let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
537 let h_port_2 = h.node_inputs(h_node_h0).next().unwrap();
538 let h_port_3 = h.node_inputs(h_node_h1).next().unwrap();
539 let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
541 let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
542 nu_inp.insert((n_node_output, n_port_0), (h_node_cx, h_port_0));
543 nu_inp.insert((n_node_h, n_port_2), (h_node_cx, h_port_1));
544 nu_out.insert((h_node_h0, h_port_2), n_port_0);
545 nu_out.insert((h_node_h1, h_port_3), n_port_1);
546 let r = SimpleReplacement {
548 subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
549 replacement: n,
550 nu_inp,
551 nu_out,
552 };
553 h.apply_rewrite(r).unwrap();
554 assert_eq!(h.validate(), Ok(()));
561 }
562
563 #[test]
564 fn test_replace_cx_cross() {
565 let q_row: Vec<Type> = vec![qb_t(), qb_t()];
566 let mut builder = DFGBuilder::new(endo_sig(q_row)).unwrap();
567 let mut circ = builder.as_circuit(builder.input_wires());
568 circ.append(cx_gate(), [0, 1]).unwrap();
569 circ.append(cx_gate(), [1, 0]).unwrap();
570 let wires = circ.finish();
571 let [input, output] = builder.io();
572 let mut h = builder.finish_hugr_with_outputs(wires).unwrap();
573 let replacement = h.clone();
574 let orig = h.clone();
575
576 let removal = h
577 .nodes()
578 .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
579 .collect_vec();
580 let inputs = h
581 .node_outputs(input)
582 .filter(|&p| {
583 h.get_optype(input)
584 .as_input()
585 .unwrap()
586 .signature()
587 .port_type(p)
588 .is_some()
589 })
590 .map(|p| {
591 let link = h.linked_inputs(input, p).next().unwrap();
592 (link, link)
593 })
594 .collect();
595 let outputs = h
596 .node_inputs(output)
597 .filter(|&p| {
598 h.get_optype(output)
599 .as_output()
600 .unwrap()
601 .signature()
602 .port_type(p)
603 .is_some()
604 })
605 .map(|p| ((output, p), p))
606 .collect();
607 h.apply_rewrite(SimpleReplacement::new(
608 SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
609 replacement,
610 inputs,
611 outputs,
612 ))
613 .unwrap();
614
615 assert_eq!(h.edge_count(), orig.edge_count());
617 }
618
619 #[test]
620 fn test_replace_after_copy() {
621 let one_bit = vec![bool_t()];
622 let two_bit = vec![bool_t(), bool_t()];
623
624 let mut builder = DFGBuilder::new(endo_sig(one_bit.clone())).unwrap();
625 let inw = builder.input_wires().exactly_one().unwrap();
626 let outw = builder
627 .add_dataflow_op(and_op(), [inw, inw])
628 .unwrap()
629 .outputs();
630 let [input, _] = builder.io();
631 let mut h = builder.finish_hugr_with_outputs(outw).unwrap();
632
633 let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap();
634 let inw = builder.input_wires();
635 let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs();
636 let [repl_input, repl_output] = builder.io();
637 let repl = builder.finish_hugr_with_outputs(outw).unwrap();
638
639 let orig = h.clone();
640
641 let removal = h
642 .nodes()
643 .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
644 .collect_vec();
645
646 let first_out_p = h.node_outputs(input).next().unwrap();
647 let embedded_inputs = h.linked_inputs(input, first_out_p);
648 let repl_inputs = repl
649 .node_outputs(repl_input)
650 .map(|p| repl.linked_inputs(repl_input, p).next().unwrap());
651 let inputs = embedded_inputs.zip(repl_inputs).collect();
652
653 let outputs = repl
654 .node_inputs(repl_output)
655 .filter(|&p| repl.signature(repl_output).unwrap().port_type(p).is_some())
656 .map(|p| ((repl_output, p), p))
657 .collect();
658
659 h.apply_rewrite(SimpleReplacement::new(
660 SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
661 repl,
662 inputs,
663 outputs,
664 ))
665 .unwrap();
666
667 assert_eq!(h.node_count(), orig.node_count());
669 }
670
671 #[rstest]
676 fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec<Node>)) {
677 let (mut hugr, nodes) = dfg_hugr_copy_bools;
678 let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap();
679
680 let [_input, output] = hugr.get_io(hugr.root()).unwrap();
681
682 let replacement = {
683 let b =
684 DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
685 let [w] = b.input_wires_arr();
686 b.finish_hugr_with_outputs([w, w]).unwrap()
687 };
688 let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();
689
690 let subgraph =
691 SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)
692 .unwrap();
693 let nu_inp = [
696 (
697 (repl_output, IncomingPort::from(0)),
698 (input_not, IncomingPort::from(0)),
699 ),
700 (
701 (repl_output, IncomingPort::from(1)),
702 (input_not, IncomingPort::from(0)),
703 ),
704 ]
705 .into_iter()
706 .collect();
707 let nu_out = [
710 ((output, IncomingPort::from(0)), IncomingPort::from(0)),
711 ((output, IncomingPort::from(1)), IncomingPort::from(1)),
712 ]
713 .into_iter()
714 .collect();
715
716 let rewrite = SimpleReplacement {
717 subgraph,
718 replacement,
719 nu_inp,
720 nu_out,
721 };
722 rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
723
724 assert_eq!(hugr.validate(), Ok(()));
725 assert_eq!(hugr.node_count(), 3);
726 }
727
728 #[rstest]
733 fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec<Node>)) {
734 let (mut hugr, nodes) = dfg_hugr_half_not_bools;
735 let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap();
736
737 let [_input, output] = hugr.get_io(hugr.root()).unwrap();
738
739 let (replacement, repl_not) = {
740 let mut b =
741 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
742 let [w] = b.input_wires_arr();
743 let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
744 let [w_not] = not.outputs_arr();
745 (b.finish_hugr_with_outputs([w, w_not]).unwrap(), not.node())
746 };
747 let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();
748
749 let subgraph =
750 SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap();
751 let nu_inp = [
754 (
755 (repl_output, IncomingPort::from(0)),
756 (input_not, IncomingPort::from(0)),
757 ),
758 (
759 (repl_not, IncomingPort::from(0)),
760 (input_not, IncomingPort::from(0)),
761 ),
762 ]
763 .into_iter()
764 .collect();
765 let nu_out = [
768 ((output, IncomingPort::from(0)), IncomingPort::from(0)),
769 ((output, IncomingPort::from(1)), IncomingPort::from(1)),
770 ]
771 .into_iter()
772 .collect();
773
774 let rewrite = SimpleReplacement {
775 subgraph,
776 replacement,
777 nu_inp,
778 nu_out,
779 };
780 rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
781
782 assert_eq!(hugr.validate(), Ok(()));
783 assert_eq!(hugr.node_count(), 4);
784 }
785
786 #[rstest]
787 fn test_nested_replace(dfg_hugr2: Hugr) {
788 let mut h = dfg_hugr2;
791 let h_node = h
792 .nodes()
793 .find(|node: &Node| *h.get_optype(*node) == h_gate().into())
794 .unwrap();
795
796 let mut nest_build = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
798 let [input] = nest_build.input_wires_arr();
799 let inner_build = nest_build.dfg_builder_endo([(qb_t(), input)]).unwrap();
800 let inner_dfg = n_identity(inner_build).unwrap();
801 let inner_dfg_node = inner_dfg.node();
802 let replacement = nest_build
803 .finish_hugr_with_outputs([inner_dfg.out_wire(0)])
804 .unwrap();
805 let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
806 let nu_inp = vec![(
807 (inner_dfg_node, IncomingPort::from(0)),
808 (h_node, IncomingPort::from(0)),
809 )]
810 .into_iter()
811 .collect();
812
813 let nu_out = vec![(
814 (h.get_io(h.root()).unwrap()[1], IncomingPort::from(1)),
815 IncomingPort::from(0),
816 )]
817 .into_iter()
818 .collect();
819
820 let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out);
821
822 assert_eq!(h.node_count(), 4);
823
824 rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
825 h.validate().unwrap_or_else(|e| panic!("{e}"));
826
827 assert_eq!(h.node_count(), 6);
828 }
829
830 use crate::hugr::rewrite::replace::Replacement;
831 fn to_replace(h: &impl HugrView, s: SimpleReplacement) -> Replacement {
832 use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec};
833
834 let mut replacement = s.replacement;
835 let (in_, out) = replacement
836 .children(replacement.root())
837 .take(2)
838 .collect_tuple()
839 .unwrap();
840 let mu_inp = s
841 .nu_inp
842 .iter()
843 .map(|((tgt, tgt_port), (r_n, r_p))| {
844 if *tgt == out {
845 unimplemented!()
846 };
847 let (src, src_port) = h.single_linked_output(*r_n, *r_p).unwrap();
848 NewEdgeSpec {
849 src,
850 tgt: *tgt,
851 kind: NewEdgeKind::Value {
852 src_pos: src_port,
853 tgt_pos: *tgt_port,
854 },
855 }
856 })
857 .collect();
858 let mu_out = s
859 .nu_out
860 .iter()
861 .map(|((tgt, tgt_port), out_port)| {
862 let (src, src_port) = replacement.single_linked_output(out, *out_port).unwrap();
863 if src == in_ {
864 unimplemented!()
865 };
866 NewEdgeSpec {
867 src,
868 tgt: *tgt,
869 kind: NewEdgeKind::Value {
870 src_pos: src_port,
871 tgt_pos: *tgt_port,
872 },
873 }
874 })
875 .collect();
876 replacement.remove_node(in_);
877 replacement.remove_node(out);
878 Replacement {
879 removal: s.subgraph.nodes().to_vec(),
880 replacement,
881 adoptions: HashMap::new(),
882 mu_inp,
883 mu_out,
884 mu_new: vec![],
885 }
886 }
887
    /// Applies `rw` to `h` directly via `apply_rewrite`, panicking if the
    /// rewrite returns an error.
    fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) {
        h.apply_rewrite(rw).unwrap();
    }
891
892 fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
893 h.apply_rewrite(to_replace(h, rw)).unwrap();
894 }
895}