1use std::collections::HashMap;
4
5use crate::core::HugrNode;
6use crate::hugr::hugrmut::InsertionResult;
7pub use crate::hugr::internal::HugrMutInternals;
8use crate::hugr::views::SiblingSubgraph;
9use crate::hugr::{HugrMut, HugrView, Rewrite};
10use crate::ops::{OpTag, OpTrait, OpType};
11use crate::{Hugr, IncomingPort, Node, OutgoingPort};
12
13use itertools::Itertools;
14
15use thiserror::Error;
16
17use super::inline_dfg::InlineDFGError;
18use super::{BoundaryPort, HostPort, ReplacementPort};
19
20#[derive(Debug, Clone)]
26pub struct SimpleReplacement<HostNode = Node> {
27 subgraph: SiblingSubgraph<HostNode>,
29 replacement: Hugr,
31 nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>,
34 nu_out: HashMap<(HostNode, IncomingPort), IncomingPort>,
37}
38
39impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
40 #[inline]
42 pub fn new(
43 subgraph: SiblingSubgraph<HostNode>,
44 replacement: Hugr,
45 nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>,
46 nu_out: HashMap<(HostNode, IncomingPort), IncomingPort>,
47 ) -> Self {
48 Self {
49 subgraph,
50 replacement,
51 nu_inp,
52 nu_out,
53 }
54 }
55
56 #[inline]
58 pub fn replacement(&self) -> &Hugr {
59 &self.replacement
60 }
61
62 #[inline]
64 pub fn into_replacement(self) -> Hugr {
65 self.replacement
66 }
67
68 #[inline]
70 pub fn subgraph(&self) -> &SiblingSubgraph<HostNode> {
71 &self.subgraph
72 }
73
74 pub fn is_valid_rewrite(
76 &self,
77 h: &impl HugrView<Node = HostNode>,
78 ) -> Result<(), SimpleReplacementError> {
79 let parent = self.subgraph.get_parent(h);
80
81 if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
83 return Err(SimpleReplacementError::InvalidParentNode());
84 }
85
86 for node in self.subgraph.nodes() {
88 if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
89 return Err(SimpleReplacementError::InvalidRemovedNode());
90 }
91 }
92
93 Ok(())
94 }
95
96 pub fn get_replacement_io(&self) -> Result<[Node; 2], SimpleReplacementError> {
98 self.replacement
99 .get_io(self.replacement.root())
100 .ok_or(SimpleReplacementError::InvalidParentNode())
101 }
102
103 pub fn incoming_boundary<'a>(
112 &'a self,
113 host: &'a impl HugrView<Node = HostNode>,
114 ) -> impl Iterator<
115 Item = (
116 HostPort<HostNode, OutgoingPort>,
117 ReplacementPort<IncomingPort>,
118 ),
119 > + 'a {
120 self.nu_inp
123 .iter()
124 .filter(|&((rep_inp_node, _), _)| {
125 self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output
126 })
127 .map(
128 |(&(rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| {
129 let (rem_inp_pred_node, rem_inp_pred_port) = host
131 .single_linked_output(*rem_inp_node, *rem_inp_port)
132 .unwrap();
133 (
134 HostPort(rem_inp_pred_node, rem_inp_pred_port),
135 ReplacementPort(rep_inp_node, rep_inp_port),
136 )
137 },
138 )
139 }
140
141 pub fn outgoing_boundary<'a>(
152 &'a self,
153 _host: &'a impl HugrView<Node = HostNode>,
154 ) -> impl Iterator<
155 Item = (
156 ReplacementPort<OutgoingPort>,
157 HostPort<HostNode, IncomingPort>,
158 ),
159 > + 'a {
160 let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG");
161
162 self.nu_out
165 .iter()
166 .filter_map(move |(&(rem_out_node, rem_out_port), rep_out_port)| {
167 let (rep_out_pred_node, rep_out_pred_port) = self
168 .replacement
169 .single_linked_output(replacement_output_node, *rep_out_port)
170 .unwrap();
171 (self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input).then_some({
172 (
173 ReplacementPort(rep_out_pred_node, rep_out_pred_port),
175 HostPort(rem_out_node, rem_out_port),
176 )
177 })
178 })
179 }
180
181 pub fn host_to_host_boundary<'a>(
190 &'a self,
191 host: &'a impl HugrView<Node = HostNode>,
192 ) -> impl Iterator<
193 Item = (
194 HostPort<HostNode, OutgoingPort>,
195 HostPort<HostNode, IncomingPort>,
196 ),
197 > + 'a {
198 let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG");
199
200 self.nu_out
203 .iter()
204 .filter_map(move |(&(rem_out_node, rem_out_port), &rep_out_port)| {
205 self.nu_inp
206 .get(&(replacement_output_node, rep_out_port))
207 .map(|&(rem_inp_node, rem_inp_port)| {
208 let (rem_inp_pred_node, rem_inp_pred_port) = host
209 .single_linked_output(rem_inp_node, rem_inp_port)
210 .unwrap();
211 (
212 HostPort(rem_inp_pred_node, rem_inp_pred_port),
213 HostPort(rem_out_node, rem_out_port),
214 )
215 })
216 })
217 }
218
219 pub fn map_host_output(
224 &self,
225 port: impl Into<HostPort<HostNode, IncomingPort>>,
226 ) -> Option<ReplacementPort<IncomingPort>> {
227 let HostPort(node, port) = port.into();
228 let [_, rep_output] = self.get_replacement_io().expect("replacement is a DFG");
229 self.nu_out
230 .get(&(node, port))
231 .map(|&rep_out_port| ReplacementPort(rep_output, rep_out_port))
232 }
233
234 pub fn map_replacement_input(
239 &self,
240 port: impl Into<ReplacementPort<IncomingPort>>,
241 ) -> Option<HostPort<HostNode, IncomingPort>> {
242 let ReplacementPort(node, port) = port.into();
243 self.nu_inp.get(&(node, port)).copied().map(Into::into)
244 }
245
246 pub fn all_boundary_edges<'a>(
254 &'a self,
255 host: &'a impl HugrView<Node = HostNode>,
256 ) -> impl Iterator<
257 Item = (
258 BoundaryPort<HostNode, OutgoingPort>,
259 BoundaryPort<HostNode, IncomingPort>,
260 ),
261 > + 'a {
262 let incoming_boundary = self
263 .incoming_boundary(host)
264 .map(|(src, tgt)| (src.into(), tgt.into()));
265 let outgoing_boundary = self
266 .outgoing_boundary(host)
267 .map(|(src, tgt)| (src.into(), tgt.into()));
268 let host_to_host_boundary = self
269 .host_to_host_boundary(host)
270 .map(|(src, tgt)| (src.into(), tgt.into()));
271
272 incoming_boundary
273 .chain(outgoing_boundary)
274 .chain(host_to_host_boundary)
275 }
276}
277
278impl Rewrite for SimpleReplacement {
279 type Error = SimpleReplacementError;
280 type ApplyResult = Vec<(Node, OpType)>;
281 const UNCHANGED_ON_FAILURE: bool = true;
282
283 fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), SimpleReplacementError> {
284 self.is_valid_rewrite(h)
285 }
286
287 fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
288 self.is_valid_rewrite(h)?;
289
290 let parent = self.subgraph.get_parent(h);
291
292 let boundary_edges = self.all_boundary_edges(h).collect_vec();
300
301 let Self {
302 replacement,
303 subgraph,
304 ..
305 } = self;
306
307 let InsertionResult {
309 new_root,
310 node_map: index_map,
311 } = h.insert_hugr(parent, replacement);
312
313 let replace_children = h.children(new_root).collect::<Vec<Node>>();
315 for &io in &replace_children[..2] {
316 h.remove_node(io);
317 }
318 for &child in &replace_children[2..] {
320 h.set_parent(child, parent);
321 }
322 h.remove_node(new_root);
324
325 for (src, tgt) in boundary_edges {
327 let (src_node, src_port) = src.map_replacement(&index_map);
328 let (tgt_node, tgt_port) = tgt.map_replacement(&index_map);
329 h.connect(src_node, src_port, tgt_node, tgt_port);
330 }
331
332 Ok(subgraph
334 .nodes()
335 .iter()
336 .map(|&node| (node, h.remove_node(node)))
337 .collect())
338 }
339
340 #[inline]
341 fn invalidation_set(&self) -> impl Iterator<Item = Node> {
342 let subcirc = self.subgraph.nodes().iter().copied();
343 let out_neighs = self.nu_out.keys().map(|key| key.0);
344 subcirc.chain(out_neighs)
345 }
346}
347
348#[derive(Debug, Clone, Error, PartialEq, Eq)]
350#[non_exhaustive]
351pub enum SimpleReplacementError {
352 #[error("Parent node is invalid.")]
354 InvalidParentNode(),
355 #[error("A node requested for removal is invalid.")]
357 InvalidRemovedNode(),
358 #[error("A node in the replacement graph is invalid.")]
360 InvalidReplacementNode(),
361 #[error("Inlining replacement failed: {0}")]
363 InliningFailed(#[from] InlineDFGError),
364}
365
#[cfg(test)]
pub(in crate::hugr::rewrite) mod test {
    use itertools::Itertools;
    use rstest::{fixture, rstest};
    use std::collections::{HashMap, HashSet};

    use crate::builder::test::n_identity;
    use crate::builder::{
        endo_sig, inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr,
        DataflowSubContainer, HugrBuilder, ModuleBuilder,
    };
    use crate::extension::prelude::{bool_t, qb_t};
    use crate::extension::ExtensionSet;
    use crate::hugr::views::{HugrView, SiblingSubgraph};
    use crate::hugr::{Hugr, HugrMut, Rewrite};
    use crate::ops::dataflow::DataflowOpTrait;
    use crate::ops::handle::NodeHandle;
    use crate::ops::OpTag;
    use crate::ops::OpTrait;
    use crate::std_extensions::logic::test::and_op;
    use crate::std_extensions::logic::LogicOp;
    use crate::types::{Signature, Type};
    use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID};
    use crate::{IncomingPort, Node};

    use super::SimpleReplacement;

    /// Builds a module with a function `main` on three qubits.
    ///
    /// - Qubit 2 goes through a single H gate.
    /// - Qubits 0 and 1 go through a nested DFG. The DFG applies H to each
    ///   qubit, then CX to the pair, then H to each CX output.
    fn make_hugr() -> Result<Hugr, BuildError> {
        let mut module_builder = ModuleBuilder::new();
        let _f_id = {
            let just_q: ExtensionSet = EXTENSION_ID.into();
            let mut func_builder = module_builder.define_function(
                "main",
                Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])
                    .with_extension_delta(just_q.clone()),
            )?;

            let [qb0, qb1, qb2] = func_builder.input_wires_arr();

            let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?;

            let mut inner_builder =
                func_builder.dfg_builder_endo([(qb_t(), qb0), (qb_t(), qb1)])?;
            let inner_graph = {
                let [wire0, wire1] = inner_builder.input_wires_arr();
                let wire2 = inner_builder.add_dataflow_op(h_gate(), vec![wire0])?;
                let wire3 = inner_builder.add_dataflow_op(h_gate(), vec![wire1])?;
                let wire45 = inner_builder
                    .add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
                let [wire4, wire5] = wire45.outputs_arr();
                let wire6 = inner_builder.add_dataflow_op(h_gate(), vec![wire4])?;
                let wire7 = inner_builder.add_dataflow_op(h_gate(), vec![wire5])?;
                inner_builder.finish_with_outputs(wire6.outputs().chain(wire7.outputs()))
            }?;

            func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))?
        };
        Ok(module_builder.finish_hugr()?)
    }

    /// Fixture wrapping [`make_hugr`].
    #[fixture]
    pub(in crate::hugr::rewrite) fn simple_hugr() -> Hugr {
        make_hugr().unwrap()
    }

    /// Builds a DFG-rooted hugr on two qubits: H on each input, then CX on
    /// the pair, whose outputs are the DFG outputs.
    fn make_dfg_hugr() -> Result<Hugr, BuildError> {
        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]).with_prelude())?;
        let [wire0, wire1] = dfg_builder.input_wires_arr();
        let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?;
        let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
        let wire45 =
            dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
        dfg_builder.finish_hugr_with_outputs(wire45.outputs())
    }

    /// Fixture wrapping [`make_dfg_hugr`].
    #[fixture]
    pub(in crate::hugr::rewrite) fn dfg_hugr() -> Hugr {
        make_dfg_hugr().unwrap()
    }

    /// Builds a DFG-rooted hugr on two qubits. Qubit 0 passes straight from
    /// input to output (an "empty wire"); qubit 1 goes through one H gate.
    fn make_dfg_hugr2() -> Result<Hugr, BuildError> {
        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;

        let [wire0, wire1] = dfg_builder.input_wires_arr();
        let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
        let wire2out = wire2.outputs().exactly_one().unwrap();
        let wireoutvec = vec![wire0, wire2out];
        dfg_builder.finish_hugr_with_outputs(wireoutvec)
    }

    /// Fixture wrapping [`make_dfg_hugr2`].
    #[fixture]
    pub(in crate::hugr::rewrite) fn dfg_hugr2() -> Hugr {
        make_dfg_hugr2().unwrap()
    }

    /// A DFG `bool -> (bool, bool)`. The input goes through one NOT, whose
    /// output is copied into two further NOTs that feed the two outputs.
    ///
    /// Returns the hugr and the three NOT nodes: the input-side NOT first,
    /// then the NOTs feeding outputs 0 and 1.
    #[fixture]
    pub(in crate::hugr::rewrite) fn dfg_hugr_copy_bools() -> (Hugr, Vec<Node>) {
        let mut dfg_builder =
            DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
        let [b] = dfg_builder.input_wires_arr();

        let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
        let [b] = not_inp.outputs_arr();

        let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
        let [b0] = not_0.outputs_arr();
        let not_1 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
        let [b1] = not_1.outputs_arr();

        (
            dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
            vec![not_inp.node(), not_0.node(), not_1.node()],
        )
    }

    /// A DFG `bool -> (bool, bool)`. The input goes through one NOT, whose
    /// output is copied:
    /// - one copy goes through a second NOT to output 0,
    /// - the other goes directly to output 1.
    ///
    /// Returns the hugr and the two NOT nodes (input-side NOT first).
    #[fixture]
    pub(in crate::hugr::rewrite) fn dfg_hugr_half_not_bools() -> (Hugr, Vec<Node>) {
        let mut dfg_builder =
            DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
        let [b] = dfg_builder.input_wires_arr();

        let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
        let [b] = not_inp.outputs_arr();

        let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
        let [b0] = not_0.outputs_arr();
        let b1 = b;

        (
            dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
            vec![not_inp.node(), not_0.node()],
        )
    }

    /// Replaces the CX and its two successor H gates in [`simple_hugr`] with
    /// [`dfg_hugr`] (two H gates then a CX). The test runs through both
    /// `SimpleReplacement` and the equivalent general `Replacement`.
    #[rstest]
    fn test_simple_replacement(
        simple_hugr: Hugr,
        dfg_hugr: Hugr,
        #[values(apply_simple, apply_replace)] applicator: impl Fn(&mut Hugr, SimpleReplacement),
    ) {
        let mut h: Hugr = simple_hugr;
        // 1. Locate the CX and its successor H's in h.
        let h_node_cx: Node = h
            .nodes()
            .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
            .unwrap();
        let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
        let s: Vec<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
        // 2. The replacement is a DFG-rooted hugr.
        let n: Hugr = dfg_hugr;
        // 3. Locate the CX and its predecessor H's in n.
        let n_node_cx = n
            .nodes()
            .find(|node: &Node| *n.get_optype(*node) == cx_gate().into())
            .unwrap();
        let (n_node_h0, n_node_h1) = n.input_neighbours(n_node_cx).collect_tuple().unwrap();
        // 4. Locate the "glue" ports in n: H inputs, and the output-node
        //    inputs fed by CX.
        let n_port_0 = n.node_inputs(n_node_h0).next().unwrap();
        let n_port_1 = n.node_inputs(n_node_h1).next().unwrap();
        let (n_cx_out_0, n_cx_out_1) = n.node_outputs(n_node_cx).take(2).collect_tuple().unwrap();
        let n_port_2 = n.linked_inputs(n_node_cx, n_cx_out_0).next().unwrap().1;
        let n_port_3 = n.linked_inputs(n_node_cx, n_cx_out_1).next().unwrap().1;
        // 5. Locate the "glue" ports in h: CX inputs, and the ports fed by
        //    the H's after CX.
        let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
        let h_h0_out = h.node_outputs(h_node_h0).next().unwrap();
        let h_h1_out = h.node_outputs(h_node_h1).next().unwrap();
        let (h_outp_node, h_port_2) = h.linked_inputs(h_node_h0, h_h0_out).next().unwrap();
        let h_port_3 = h.linked_inputs(h_node_h1, h_h1_out).next().unwrap().1;
        // 6. Build the port maps.
        let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
        let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
        nu_inp.insert((n_node_h0, n_port_0), (h_node_cx, h_port_0));
        nu_inp.insert((n_node_h1, n_port_1), (h_node_cx, h_port_1));
        nu_out.insert((h_outp_node, h_port_2), n_port_2);
        nu_out.insert((h_outp_node, h_port_3), n_port_3);
        let r = SimpleReplacement {
            subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
            replacement: n,
            nu_inp,
            nu_out,
        };
        assert_eq!(
            HashSet::<_>::from_iter(r.invalidation_set()),
            HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1, h_outp_node]),
        );

        applicator(&mut h, r);
        assert_eq!(h.validate(), Ok(()));
    }

    /// Replaces the CX in [`simple_hugr`] with [`dfg_hugr2`]. That
    /// replacement wires its first input straight to its first output,
    /// which exercises the host-to-host boundary.
    #[rstest]
    fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
        let mut h: Hugr = simple_hugr;

        let h_node_cx: Node = h
            .nodes()
            .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
            .unwrap();
        let s: Vec<Node> = vec![h_node_cx].into_iter().collect();
        let n: Hugr = dfg_hugr2;
        let n_node_output = n
            .nodes()
            .find(|node: &Node| n.get_optype(*node).tag() == OpTag::Output)
            .unwrap();
        let (_n_node_input, n_node_h) = n.input_neighbours(n_node_output).collect_tuple().unwrap();
        let (n_port_0, n_port_1) = n
            .node_inputs(n_node_output)
            .take(2)
            .collect_tuple()
            .unwrap();
        let n_port_2 = n.node_inputs(n_node_h).next().unwrap();
        let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
        let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
        let h_port_2 = h.node_inputs(h_node_h0).next().unwrap();
        let h_port_3 = h.node_inputs(h_node_h1).next().unwrap();
        let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
        let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
        // The first entry is keyed on the replacement output node: a host-to-host wire.
        nu_inp.insert((n_node_output, n_port_0), (h_node_cx, h_port_0));
        nu_inp.insert((n_node_h, n_port_2), (h_node_cx, h_port_1));
        nu_out.insert((h_node_h0, h_port_2), n_port_0);
        nu_out.insert((h_node_h1, h_port_3), n_port_1);
        let r = SimpleReplacement {
            subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
            replacement: n,
            nu_inp,
            nu_out,
        };
        h.apply_rewrite(r).unwrap();
        assert_eq!(h.validate(), Ok(()));
    }

    /// Replaces a two-CX circuit with a clone of itself; the edge count must
    /// not change.
    #[test]
    fn test_replace_cx_cross() {
        let q_row: Vec<Type> = vec![qb_t(), qb_t()];
        let mut builder = DFGBuilder::new(endo_sig(q_row)).unwrap();
        let mut circ = builder.as_circuit(builder.input_wires());
        circ.append(cx_gate(), [0, 1]).unwrap();
        circ.append(cx_gate(), [1, 0]).unwrap();
        let wires = circ.finish();
        let [input, output] = builder.io();
        let mut h = builder.finish_hugr_with_outputs(wires).unwrap();
        let replacement = h.clone();
        let orig = h.clone();

        let removal = h
            .nodes()
            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
            .collect_vec();
        // The replacement is a clone of `h`, so every link found in `h` is
        // used as both the replacement key and the host value.
        let inputs = h
            .node_outputs(input)
            .filter(|&p| {
                h.get_optype(input)
                    .as_input()
                    .unwrap()
                    .signature()
                    .port_type(p)
                    .is_some()
            })
            .map(|p| {
                let link = h.linked_inputs(input, p).next().unwrap();
                (link, link)
            })
            .collect();
        let outputs = h
            .node_inputs(output)
            .filter(|&p| {
                h.get_optype(output)
                    .as_output()
                    .unwrap()
                    .signature()
                    .port_type(p)
                    .is_some()
            })
            .map(|p| ((output, p), p))
            .collect();
        h.apply_rewrite(SimpleReplacement::new(
            SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
            replacement,
            inputs,
            outputs,
        ))
        .unwrap();

        assert_eq!(h.edge_count(), orig.edge_count());
    }

    /// Replaces an AND whose two inputs are copies of one wire with an AND
    /// that takes two separate inputs.
    #[test]
    fn test_replace_after_copy() {
        let one_bit = vec![bool_t()];
        let two_bit = vec![bool_t(), bool_t()];

        let mut builder = DFGBuilder::new(endo_sig(one_bit.clone())).unwrap();
        let inw = builder.input_wires().exactly_one().unwrap();
        let outw = builder
            .add_dataflow_op(and_op(), [inw, inw])
            .unwrap()
            .outputs();
        let [input, output] = builder.io();
        let mut h = builder.finish_hugr_with_outputs(outw).unwrap();

        let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap();
        let inw = builder.input_wires();
        let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs();
        let [repl_input, repl_output] = builder.io();
        let repl = builder.finish_hugr_with_outputs(outw).unwrap();

        let orig = h.clone();

        let removal = h
            .nodes()
            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
            .collect_vec();

        let first_out_p = h.node_outputs(input).next().unwrap();
        let embedded_inputs = h.linked_inputs(input, first_out_p);
        let repl_inputs = repl
            .node_outputs(repl_input)
            .map(|p| repl.linked_inputs(repl_input, p).next().unwrap());
        // `nu_inp` is keyed by replacement ports and maps to host ports. Both
        // host AND inputs are fed by the same host port, so any pairing is valid.
        let inputs = repl_inputs.zip(embedded_inputs).collect();

        // `nu_out` is keyed by host ports: the inputs of the host output node.
        let outputs = repl
            .node_inputs(repl_output)
            .filter(|&p| repl.signature(repl_output).unwrap().port_type(p).is_some())
            .map(|p| ((output, p), p))
            .collect();

        h.apply_rewrite(SimpleReplacement::new(
            SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
            repl,
            inputs,
            outputs,
        ))
        .unwrap();

        assert_eq!(h.node_count(), orig.node_count());
    }

    /// Replaces all three NOTs of [`dfg_hugr_copy_bools`] with a replacement
    /// that copies its input straight to both outputs. Only the root and its
    /// input/output nodes should remain.
    #[rstest]
    fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec<Node>)) {
        let (mut hugr, nodes) = dfg_hugr_copy_bools;
        let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap();

        let [_input, output] = hugr.get_io(hugr.root()).unwrap();

        let replacement = {
            let b =
                DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
            let [w] = b.input_wires_arr();
            b.finish_hugr_with_outputs([w, w]).unwrap()
        };
        let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();

        let subgraph =
            SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)
                .unwrap();
        // Both replacement outputs are fed by the host value that feeds `input_not`.
        let nu_inp = [
            (
                (repl_output, IncomingPort::from(0)),
                (input_not, IncomingPort::from(0)),
            ),
            (
                (repl_output, IncomingPort::from(1)),
                (input_not, IncomingPort::from(0)),
            ),
        ]
        .into_iter()
        .collect();
        let nu_out = [
            ((output, IncomingPort::from(0)), IncomingPort::from(0)),
            ((output, IncomingPort::from(1)), IncomingPort::from(1)),
        ]
        .into_iter()
        .collect();

        let rewrite = SimpleReplacement {
            subgraph,
            replacement,
            nu_inp,
            nu_out,
        };
        rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));

        assert_eq!(hugr.validate(), Ok(()));
        assert_eq!(hugr.node_count(), 3);
    }

    /// Replaces both NOTs of [`dfg_hugr_half_not_bools`] with a replacement
    /// that does the following:
    /// - passes its input straight to output 0,
    /// - sends its input through one NOT to output 1.
    #[rstest]
    fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec<Node>)) {
        let (mut hugr, nodes) = dfg_hugr_half_not_bools;
        let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap();

        let [_input, output] = hugr.get_io(hugr.root()).unwrap();

        let (replacement, repl_not) = {
            let mut b =
                DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
            let [w] = b.input_wires_arr();
            let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
            let [w_not] = not.outputs_arr();
            (b.finish_hugr_with_outputs([w, w_not]).unwrap(), not.node())
        };
        let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();

        let subgraph =
            SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap();
        let nu_inp = [
            (
                (repl_output, IncomingPort::from(0)),
                (input_not, IncomingPort::from(0)),
            ),
            (
                (repl_not, IncomingPort::from(0)),
                (input_not, IncomingPort::from(0)),
            ),
        ]
        .into_iter()
        .collect();
        let nu_out = [
            ((output, IncomingPort::from(0)), IncomingPort::from(0)),
            ((output, IncomingPort::from(1)), IncomingPort::from(1)),
        ]
        .into_iter()
        .collect();

        let rewrite = SimpleReplacement {
            subgraph,
            replacement,
            nu_inp,
            nu_out,
        };
        rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));

        assert_eq!(hugr.validate(), Ok(()));
        assert_eq!(hugr.node_count(), 4);
    }

    /// Replaces the H gate of [`dfg_hugr2`] with a replacement containing a
    /// nested identity DFG. This adds a DFG node with its own input/output
    /// nodes in place of the H.
    #[rstest]
    fn test_nested_replace(dfg_hugr2: Hugr) {
        let mut h = dfg_hugr2;
        let h_node = h
            .nodes()
            .find(|node: &Node| *h.get_optype(*node) == h_gate().into())
            .unwrap();

        let mut nest_build = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
        let [input] = nest_build.input_wires_arr();
        let inner_build = nest_build.dfg_builder_endo([(qb_t(), input)]).unwrap();
        let inner_dfg = n_identity(inner_build).unwrap();
        let inner_dfg_node = inner_dfg.node();
        let replacement = nest_build
            .finish_hugr_with_outputs([inner_dfg.out_wire(0)])
            .unwrap();
        let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
        let nu_inp = vec![(
            (inner_dfg_node, IncomingPort::from(0)),
            (h_node, IncomingPort::from(0)),
        )]
        .into_iter()
        .collect();

        let nu_out = vec![(
            (h.get_io(h.root()).unwrap()[1], IncomingPort::from(1)),
            IncomingPort::from(0),
        )]
        .into_iter()
        .collect();

        let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out);

        assert_eq!(h.node_count(), 4);

        rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
        h.validate().unwrap_or_else(|e| panic!("{e}"));

        assert_eq!(h.node_count(), 6);
    }

    use crate::hugr::rewrite::replace::Replacement;

    /// Converts a [`SimpleReplacement`] into an equivalent general
    /// [`Replacement`] on host `h`.
    ///
    /// Host-to-host wires are unsupported and hit `unimplemented!()`. These
    /// are `nu_inp` entries keyed on the replacement output node, and outputs
    /// fed directly by the replacement input node.
    fn to_replace(h: &impl HugrView<Node = Node>, s: SimpleReplacement) -> Replacement {
        use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec};

        let mut replacement = s.replacement;
        // The first two children of the replacement root are its input/output nodes.
        let (in_, out) = replacement
            .children(replacement.root())
            .take(2)
            .collect_tuple()
            .unwrap();
        let mu_inp = s
            .nu_inp
            .iter()
            .map(|((tgt, tgt_port), (r_n, r_p))| {
                if *tgt == out {
                    unimplemented!()
                };
                let (src, src_port) = h.single_linked_output(*r_n, *r_p).unwrap();
                NewEdgeSpec {
                    src,
                    tgt: *tgt,
                    kind: NewEdgeKind::Value {
                        src_pos: src_port,
                        tgt_pos: *tgt_port,
                    },
                }
            })
            .collect();
        let mu_out = s
            .nu_out
            .iter()
            .map(|((tgt, tgt_port), out_port)| {
                let (src, src_port) = replacement.single_linked_output(out, *out_port).unwrap();
                if src == in_ {
                    unimplemented!()
                };
                NewEdgeSpec {
                    src,
                    tgt: *tgt,
                    kind: NewEdgeKind::Value {
                        src_pos: src_port,
                        tgt_pos: *tgt_port,
                    },
                }
            })
            .collect();
        replacement.remove_node(in_);
        replacement.remove_node(out);
        Replacement {
            removal: s.subgraph.nodes().to_vec(),
            replacement,
            adoptions: HashMap::new(),
            mu_inp,
            mu_out,
            mu_new: vec![],
        }
    }

    /// Applies `rw` directly as a [`SimpleReplacement`].
    fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) {
        h.apply_rewrite(rw).unwrap();
    }

    /// Applies `rw` after converting it with [`to_replace`].
    fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
        h.apply_rewrite(to_replace(h, rw)).unwrap();
    }
}