1use std::collections::HashMap;
4
5use crate::core::HugrNode;
6use crate::hugr::hugrmut::InsertionResult;
7use crate::hugr::views::SiblingSubgraph;
8pub use crate::hugr::views::sibling_subgraph::InvalidReplacement;
9use crate::hugr::{HugrMut, HugrView};
10use crate::ops::{OpTag, OpTrait, OpType};
11use crate::{Hugr, IncomingPort, Node, OutgoingPort, PortIndex};
12
13use itertools::{Either, Itertools};
14
15use thiserror::Error;
16
17use super::inline_dfg::InlineDFGError;
18use super::{BoundaryPort, HostPort, PatchHugrMut, PatchVerification, ReplacementPort};
19
20#[derive(Debug, Clone)]
26pub struct SimpleReplacement<HostNode = Node> {
27 subgraph: SiblingSubgraph<HostNode>,
29 replacement: Hugr,
31}
32
33impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
34 #[inline]
37 pub fn new_unchecked(subgraph: SiblingSubgraph<HostNode>, replacement: Hugr) -> Self {
38 Self {
39 subgraph,
40 replacement,
41 }
42 }
43
44 pub fn try_new(
49 subgraph: SiblingSubgraph<HostNode>,
50 host: &impl HugrView<Node = HostNode>,
51 replacement: Hugr,
52 ) -> Result<Self, InvalidReplacement> {
53 let subgraph_sig = subgraph.signature(host);
54 let repl_sig =
55 replacement
56 .inner_function_type()
57 .ok_or(InvalidReplacement::InvalidDataflowGraph {
58 node: replacement.entrypoint(),
59 op: replacement.get_optype(replacement.entrypoint()).to_owned(),
60 })?;
61 if subgraph_sig != repl_sig {
62 return Err(InvalidReplacement::InvalidSignature {
63 expected: subgraph_sig,
64 actual: Some(repl_sig.into_owned()),
65 });
66 }
67 Ok(Self {
68 subgraph,
69 replacement,
70 })
71 }
72
73 #[inline]
75 pub fn replacement(&self) -> &Hugr {
76 &self.replacement
77 }
78
79 #[inline]
81 pub fn into_replacement(self) -> Hugr {
82 self.replacement
83 }
84
85 #[inline]
87 pub fn subgraph(&self) -> &SiblingSubgraph<HostNode> {
88 &self.subgraph
89 }
90
91 pub fn is_valid_rewrite(
93 &self,
94 h: &impl HugrView<Node = HostNode>,
95 ) -> Result<(), SimpleReplacementError> {
96 let parent = self.subgraph.get_parent(h);
97
98 if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
100 return Err(SimpleReplacementError::InvalidParentNode());
101 }
102
103 for node in self.subgraph.nodes() {
105 if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
106 return Err(SimpleReplacementError::InvalidRemovedNode());
107 }
108 }
109
110 Ok(())
111 }
112
113 pub fn get_replacement_io(&self) -> [Node; 2] {
115 self.replacement
116 .get_io(self.replacement.entrypoint())
117 .expect("replacement is a DFG")
118 }
119
120 pub fn linked_replacement_output(
129 &self,
130 port: impl Into<HostPort<HostNode, IncomingPort>>,
131 host: &impl HugrView<Node = HostNode>,
132 ) -> Option<BoundaryPort<HostNode, OutgoingPort>> {
133 let HostPort(node, port) = port.into();
134 let pos = self
135 .subgraph
136 .outgoing_ports()
137 .iter()
138 .position(move |&(n, p)| host.linked_inputs(n, p).contains(&(node, port)))?;
139
140 Some(self.linked_replacement_output_by_position(pos, host))
141 }
142
143 fn linked_replacement_output_by_position(
148 &self,
149 pos: usize,
150 host: &impl HugrView<Node = HostNode>,
151 ) -> BoundaryPort<HostNode, OutgoingPort> {
152 debug_assert!(pos < self.subgraph().signature(host).output_count());
153
154 let [repl_inp, repl_out] = self.get_replacement_io();
156 let (out_node, out_port) = self
157 .replacement
158 .single_linked_output(repl_out, pos)
159 .expect("valid dfg wire");
160
161 if out_node != repl_inp {
162 BoundaryPort::Replacement(out_node, out_port)
163 } else {
164 let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()]
165 .first()
166 .expect("non-empty boundary partition");
167 let (out_node, out_port) = host
168 .single_linked_output(in_node, in_port)
169 .expect("valid dfg wire");
170 BoundaryPort::Host(out_node, out_port)
171 }
172 }
173
174 pub fn linked_host_outputs(
182 &self,
183 port: impl Into<ReplacementPort<OutgoingPort>>,
184 host: &impl HugrView<Node = HostNode>,
185 ) -> impl Iterator<Item = HostPort<HostNode, IncomingPort>> {
186 let ReplacementPort(node, port) = port.into();
187 let [_, repl_out] = self.get_replacement_io();
188 let positions = self
189 .replacement
190 .linked_inputs(node, port)
191 .filter_map(move |(n, p)| (n == repl_out).then_some(p.index()));
192
193 positions
194 .map(|pos| self.subgraph.outgoing_ports()[pos])
195 .flat_map(|(out_node, out_port)| {
196 let in_nodes_ports = host.linked_inputs(out_node, out_port);
197 in_nodes_ports.map(|(n, p)| HostPort(n, p))
198 })
199 }
200
201 pub fn linked_replacement_inputs<'a>(
210 &'a self,
211 port: impl Into<HostPort<HostNode, OutgoingPort>>,
212 host: &'a impl HugrView<Node = HostNode>,
213 ) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> + 'a {
214 let HostPort(node, port) = port.into();
215 let positions = self
216 .subgraph
217 .incoming_ports()
218 .iter()
219 .positions(move |ports| {
220 let (n, p) = *ports.first().expect("non-empty boundary partition");
221 host.single_linked_output(n, p).expect("valid dfg wire") == (node, port)
222 });
223
224 positions.flat_map(|pos| self.linked_replacement_inputs_by_position(pos, host))
225 }
226
227 fn linked_replacement_inputs_by_position(
233 &self,
234 pos: usize,
235 host: &impl HugrView<Node = HostNode>,
236 ) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> {
237 debug_assert!(pos < self.subgraph().signature(host).input_count());
238
239 let [repl_inp, repl_out] = self.get_replacement_io();
240 self.replacement
241 .linked_inputs(repl_inp, pos)
242 .flat_map(move |(in_node, in_port)| {
243 if in_node != repl_out {
244 Either::Left(std::iter::once(BoundaryPort::Replacement(in_node, in_port)))
245 } else {
246 let (out_node, out_port) = self.subgraph.outgoing_ports()[in_port.index()];
247 let in_nodes_ports = host.linked_inputs(out_node, out_port);
248 Either::Right(in_nodes_ports.map(|(n, p)| BoundaryPort::Host(n, p)))
249 }
250 })
251 }
252
253 pub fn linked_host_input(
261 &self,
262 port: impl Into<ReplacementPort<IncomingPort>>,
263 host: &impl HugrView<Node = HostNode>,
264 ) -> HostPort<HostNode, OutgoingPort> {
265 let ReplacementPort(node, port) = port.into();
266 let (out_node, out_port) = self
267 .replacement
268 .single_linked_output(node, port)
269 .expect("valid dfg wire");
270
271 let [repl_in, _] = self.get_replacement_io();
272 assert!(out_node == repl_in, "not a boundary port");
273
274 let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()]
275 .first()
276 .expect("non-empty input partition");
277
278 let (host_node, host_port) = host
279 .single_linked_output(in_node, in_port)
280 .expect("valid dfg wire");
281 HostPort(host_node, host_port)
282 }
283
284 pub fn incoming_boundary<'a>(
295 &'a self,
296 host: &'a impl HugrView<Node = HostNode>,
297 ) -> impl Iterator<
298 Item = (
299 HostPort<HostNode, OutgoingPort>,
300 ReplacementPort<IncomingPort>,
301 ),
302 > + 'a {
303 let subgraph_outgoing_ports = self
305 .subgraph
306 .incoming_ports()
307 .iter()
308 .map(|in_ports| *in_ports.first().expect("non-empty input partition"))
309 .map(|(node, in_port)| {
310 host.single_linked_output(node, in_port)
311 .expect("valid dfg wire")
312 });
313
314 subgraph_outgoing_ports
315 .enumerate()
316 .flat_map(|(pos, subg_np)| {
317 self.linked_replacement_inputs_by_position(pos, host)
318 .filter_map(move |np| Some((np.as_replacement()?, subg_np)))
319 })
320 .map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| {
321 (
322 HostPort(subgraph_node, subgraph_port),
323 ReplacementPort(repl_node, repl_port),
324 )
325 })
326 }
327
328 pub fn outgoing_boundary<'a>(
341 &'a self,
342 host: &'a impl HugrView<Node = HostNode>,
343 ) -> impl Iterator<
344 Item = (
345 ReplacementPort<OutgoingPort>,
346 HostPort<HostNode, IncomingPort>,
347 ),
348 > + 'a {
349 let subgraph_incoming_ports = self.subgraph.outgoing_ports().iter().map(
351 move |&(subgraph_out_node, subgraph_out_port)| {
352 host.linked_inputs(subgraph_out_node, subgraph_out_port)
353 },
354 );
355
356 subgraph_incoming_ports
357 .enumerate()
358 .filter_map(|(pos, subg_all)| {
359 let np = self
360 .linked_replacement_output_by_position(pos, host)
361 .as_replacement()?;
362 Some((np, subg_all))
363 })
364 .flat_map(|(repl_np, subg_all)| subg_all.map(move |subg_np| (repl_np, subg_np)))
365 .map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| {
366 (
367 ReplacementPort(repl_node, repl_port),
368 HostPort(subgraph_node, subgraph_port),
369 )
370 })
371 }
372
373 pub fn host_to_host_boundary<'a>(
388 &'a self,
389 host: &'a impl HugrView<Node = HostNode>,
390 ) -> impl Iterator<
391 Item = (
392 HostPort<HostNode, OutgoingPort>,
393 HostPort<HostNode, IncomingPort>,
394 ),
395 > + 'a {
396 let subgraph_incoming_ports = self.subgraph.outgoing_ports().iter().map(
398 move |&(subgraph_out_node, subgraph_out_port)| {
399 host.linked_inputs(subgraph_out_node, subgraph_out_port)
400 },
401 );
402
403 subgraph_incoming_ports
404 .enumerate()
405 .filter_map(|(pos, subg_all)| {
406 Some((
407 self.linked_replacement_output_by_position(pos, host)
408 .as_host()?,
409 subg_all,
410 ))
411 })
412 .flat_map(|(host_np, subg_all)| subg_all.map(move |subg_np| (host_np, subg_np)))
413 .map(
414 |((host_out_node, host_out_port), (host_in_node, host_in_port))| {
415 (
416 HostPort(host_out_node, host_out_port),
417 HostPort(host_in_node, host_in_port),
418 )
419 },
420 )
421 }
422
423 pub fn map_host_output(
431 &self,
432 port: impl Into<HostPort<HostNode, OutgoingPort>>,
433 ) -> Option<ReplacementPort<IncomingPort>> {
434 let HostPort(node, port) = port.into();
435 let pos = self
436 .subgraph
437 .outgoing_ports()
438 .iter()
439 .position(|&node_port| node_port == (node, port))?;
440 let incoming_port: IncomingPort = pos.into();
441 let [_, rep_output] = self.get_replacement_io();
442 Some(ReplacementPort(rep_output, incoming_port))
443 }
444
445 pub fn map_replacement_input(
450 &self,
451 port: impl Into<ReplacementPort<OutgoingPort>>,
452 ) -> impl Iterator<Item = HostPort<HostNode, IncomingPort>> + '_ {
453 let ReplacementPort(node, port) = port.into();
454 let [repl_input, _] = self.get_replacement_io();
455
456 let ports = if node == repl_input {
457 self.subgraph.incoming_ports().get(port.index())
458 } else {
459 None
460 };
461 ports
462 .into_iter()
463 .flat_map(|ports| ports.iter().map(|&(n, p)| HostPort(n, p)))
464 }
465
466 pub fn all_boundary_edges<'a>(
475 &'a self,
476 host: &'a impl HugrView<Node = HostNode>,
477 ) -> impl Iterator<
478 Item = (
479 BoundaryPort<HostNode, OutgoingPort>,
480 BoundaryPort<HostNode, IncomingPort>,
481 ),
482 > + 'a {
483 let incoming_boundary = self
484 .incoming_boundary(host)
485 .map(|(src, tgt)| (src.into(), tgt.into()));
486 let outgoing_boundary = self
487 .outgoing_boundary(host)
488 .map(|(src, tgt)| (src.into(), tgt.into()));
489 let host_to_host_boundary = self
490 .host_to_host_boundary(host)
491 .map(|(src, tgt)| (src.into(), tgt.into()));
492
493 incoming_boundary
494 .chain(outgoing_boundary)
495 .chain(host_to_host_boundary)
496 }
497
498 pub(crate) fn map_host_nodes<N: HugrNode>(
510 &self,
511 node_map: impl Fn(HostNode) -> N,
512 ) -> SimpleReplacement<N> {
513 let Self {
514 subgraph,
515 replacement,
516 } = self;
517 let subgraph = subgraph.map_nodes(node_map);
518 SimpleReplacement::new_unchecked(subgraph, replacement.clone())
519 }
520}
521
522impl<HostNode: HugrNode> PatchVerification for SimpleReplacement<HostNode> {
523 type Error = SimpleReplacementError;
524 type Node = HostNode;
525
526 fn verify(&self, h: &impl HugrView<Node = HostNode>) -> Result<(), SimpleReplacementError> {
527 self.is_valid_rewrite(h)
528 }
529
530 #[inline]
531 fn invalidation_set(&self) -> impl Iterator<Item = HostNode> {
532 self.subgraph.nodes().iter().copied()
533 }
534}
535
536pub struct Outcome<HostNode = Node> {
538 pub node_map: HashMap<Node, HostNode>,
540 pub removed_nodes: HashMap<HostNode, OpType>,
542}
543
544impl<N: HugrNode> PatchHugrMut for SimpleReplacement<N> {
545 type Outcome = Outcome<N>;
546 const UNCHANGED_ON_FAILURE: bool = true;
547
548 fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<Self::Outcome, Self::Error> {
549 self.is_valid_rewrite(h)?;
550
551 let parent = self.subgraph.get_parent(h);
552
553 let boundary_edges = self.all_boundary_edges(h).collect_vec();
561
562 let Self {
563 replacement,
564 subgraph,
565 ..
566 } = self;
567
568 let repl_io = replacement
570 .get_io(replacement.entrypoint())
571 .expect("replacement is DFG-rooted");
572 let repl_entrypoint = replacement.entrypoint();
573
574 let InsertionResult {
576 inserted_entrypoint: new_entrypoint,
577 mut node_map,
578 } = h.insert_hugr(parent, replacement);
579
580 for node in repl_io {
582 let node_h = node_map[&node];
583 h.remove_node(node_h);
584 node_map.remove(&node);
585 }
586
587 for child in h.children(new_entrypoint).collect_vec() {
589 h.set_parent(child, parent);
590 }
591
592 h.remove_node(new_entrypoint);
594 node_map.remove(&repl_entrypoint);
595
596 for (src, tgt) in boundary_edges {
598 let (src_node, src_port) = src.map_replacement(&node_map);
599 let (tgt_node, tgt_port) = tgt.map_replacement(&node_map);
600 h.connect(src_node, src_port, tgt_node, tgt_port);
601 }
602
603 let removed_nodes = subgraph
605 .nodes()
606 .iter()
607 .map(|&node| (node, h.remove_node(node)))
608 .collect();
609
610 Ok(Outcome {
611 node_map,
612 removed_nodes,
613 })
614 }
615}
616
617#[derive(Debug, Clone, Error, PartialEq, Eq)]
619#[non_exhaustive]
620pub enum SimpleReplacementError {
621 #[error("Parent node is invalid.")]
623 InvalidParentNode(),
624 #[error("A node requested for removal is invalid.")]
626 InvalidRemovedNode(),
627 #[error("A node in the replacement graph is invalid.")]
629 InvalidReplacementNode(),
630 #[error("Inlining replacement failed: {0}")]
632 InliningFailed(#[from] InlineDFGError),
633}
634
635#[cfg(test)]
636pub(in crate::hugr::patch) mod test {
637 use itertools::Itertools;
638 use rstest::{fixture, rstest};
639
640 use std::collections::{BTreeSet, HashMap, HashSet};
641
642 use crate::builder::test::n_identity;
643 use crate::builder::{
644 BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
645 HugrBuilder, ModuleBuilder, endo_sig, inout_sig,
646 };
647 use crate::extension::prelude::{bool_t, qb_t};
648 use crate::hugr::patch::simple_replace::Outcome;
649 use crate::hugr::patch::{BoundaryPort, HostPort, PatchVerification, ReplacementPort};
650 use crate::hugr::views::{HugrView, SiblingSubgraph};
651 use crate::hugr::{Hugr, HugrMut, Patch};
652 use crate::ops::OpTag;
653 use crate::ops::OpTrait;
654 use crate::ops::handle::NodeHandle;
655 use crate::std_extensions::logic::LogicOp;
656 use crate::std_extensions::logic::test::and_op;
657 use crate::types::{Signature, Type};
658 use crate::utils::test_quantum_extension::{cx_gate, h_gate};
659 use crate::{IncomingPort, Node, OutgoingPort};
660
661 use super::SimpleReplacement;
662
663 fn make_hugr() -> Result<Hugr, BuildError> {
673 let mut module_builder = ModuleBuilder::new();
674 let _f_id = {
675 let mut func_builder = module_builder
676 .define_function("main", Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]))?;
677
678 let [qb0, qb1, qb2] = func_builder.input_wires_arr();
679
680 let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?;
681
682 let mut inner_builder =
683 func_builder.dfg_builder_endo([(qb_t(), qb0), (qb_t(), qb1)])?;
684 let inner_graph = {
685 let [wire0, wire1] = inner_builder.input_wires_arr();
686 let wire2 = inner_builder.add_dataflow_op(h_gate(), vec![wire0])?;
687 let wire3 = inner_builder.add_dataflow_op(h_gate(), vec![wire1])?;
688 let wire45 = inner_builder
689 .add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
690 let [wire4, wire5] = wire45.outputs_arr();
691 let wire6 = inner_builder.add_dataflow_op(h_gate(), vec![wire4])?;
692 let wire7 = inner_builder.add_dataflow_op(h_gate(), vec![wire5])?;
693 inner_builder.finish_with_outputs(wire6.outputs().chain(wire7.outputs()))
694 }?;
695
696 func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))?
697 };
698 Ok(module_builder.finish_hugr()?)
699 }
700
701 #[fixture]
702 pub(in crate::hugr::patch) fn simple_hugr() -> Hugr {
703 make_hugr().unwrap()
704 }
705 fn make_dfg_hugr() -> Result<Hugr, BuildError> {
712 let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
713 let [wire0, wire1] = dfg_builder.input_wires_arr();
714 let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?;
715 let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
716 let wire45 =
717 dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
718 dfg_builder.finish_hugr_with_outputs(wire45.outputs())
719 }
720
721 #[fixture]
722 pub(in crate::hugr::patch) fn dfg_hugr() -> Hugr {
723 make_dfg_hugr().unwrap()
724 }
725
726 fn make_dfg_hugr2() -> Result<Hugr, BuildError> {
732 let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
733
734 let [wire0, wire1] = dfg_builder.input_wires_arr();
735 let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
736 let wire2out = wire2.outputs().exactly_one().unwrap();
737 let wireoutvec = vec![wire0, wire2out];
738 dfg_builder.finish_hugr_with_outputs(wireoutvec)
739 }
740
741 #[fixture]
742 pub(in crate::hugr::patch) fn dfg_hugr2() -> Hugr {
743 make_dfg_hugr2().unwrap()
744 }
745
746 #[fixture]
759 pub(in crate::hugr::patch) fn dfg_hugr_copy_bools() -> (Hugr, Vec<Node>) {
760 let mut dfg_builder =
761 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
762 let [b] = dfg_builder.input_wires_arr();
763
764 let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
765 let [b] = not_inp.outputs_arr();
766
767 let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
768 let [b0] = not_0.outputs_arr();
769 let not_1 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
770 let [b1] = not_1.outputs_arr();
771
772 (
773 dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
774 vec![not_inp.node(), not_0.node(), not_1.node()],
775 )
776 }
777
778 #[fixture]
791 pub(in crate::hugr::patch) fn dfg_hugr_half_not_bools() -> (Hugr, Vec<Node>) {
792 let mut dfg_builder =
793 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
794 let [b] = dfg_builder.input_wires_arr();
795
796 let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
797 let [b] = not_inp.outputs_arr();
798
799 let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
800 let [b0] = not_0.outputs_arr();
801 let b1 = b;
802
803 (
804 dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
805 vec![not_inp.node(), not_0.node()],
806 )
807 }
808
809 #[rstest]
810 fn test_simple_replacement(
829 simple_hugr: Hugr,
830 dfg_hugr: Hugr,
831 #[values(apply_simple, apply_replace)] applicator: impl Fn(&mut Hugr, SimpleReplacement),
832 ) {
833 let mut h: Hugr = simple_hugr;
834 let h_node_cx: Node = h
836 .entry_descendants()
837 .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
838 .unwrap();
839 let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
840 let s: Vec<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
841 let n: Hugr = dfg_hugr;
843 let n_node_cx = n
846 .entry_descendants()
847 .find(|node: &Node| *n.get_optype(*node) == cx_gate().into())
848 .unwrap();
849 let (n_cx_out_0, _n_cx_out_1) = n.node_outputs(n_node_cx).take(2).collect_tuple().unwrap();
851 let n_port_2 = n.linked_inputs(n_node_cx, n_cx_out_0).next().unwrap().1;
852 let h_h0_out = h.node_outputs(h_node_h0).next().unwrap();
854 let r = SimpleReplacement {
856 subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
857 replacement: n,
858 };
859
860 assert_eq!(
862 r.map_host_output((h_node_h0, h_h0_out)).unwrap(),
863 ReplacementPort::from((r.get_replacement_io()[1], n_port_2))
864 );
865
866 assert_eq!(
868 HashSet::<_>::from_iter(r.invalidation_set()),
869 HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1]),
870 );
871
872 applicator(&mut h, r);
873 assert_eq!(h.validate(), Ok(()));
880 }
881
882 #[rstest]
883 fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
901 let mut h: Hugr = simple_hugr;
902
903 let h_node_cx: Node = h
905 .entry_descendants()
906 .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
907 .unwrap();
908 let s: Vec<Node> = vec![h_node_cx];
909 let n: Hugr = dfg_hugr2;
911 let n_node_output = n.get_io(n.entrypoint()).unwrap()[1];
914 let (_n_node_input, n_node_h) = n.input_neighbours(n_node_output).collect_tuple().unwrap();
915 let r = SimpleReplacement {
917 subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
918 replacement: n,
919 };
920 let Outcome {
921 node_map,
922 removed_nodes,
923 } = h.apply_patch(r).unwrap();
924
925 assert_eq!(
926 node_map.into_keys().collect::<HashSet<_>>(),
927 [n_node_h].into_iter().collect::<HashSet<_>>(),
928 );
929 assert_eq!(
930 removed_nodes.into_keys().collect::<HashSet<_>>(),
931 [h_node_cx].into_iter().collect::<HashSet<_>>(),
932 );
933
934 assert_eq!(h.validate(), Ok(()));
941 }
942
943 #[test]
944 fn test_replace_cx_cross() {
945 let q_row: Vec<Type> = vec![qb_t(), qb_t()];
946 let mut builder = DFGBuilder::new(endo_sig(q_row)).unwrap();
947 let mut circ = builder.as_circuit(builder.input_wires());
948 circ.append(cx_gate(), [0, 1]).unwrap();
949 circ.append(cx_gate(), [1, 0]).unwrap();
950 let wires = circ.finish();
951 let mut h = builder.finish_hugr_with_outputs(wires).unwrap();
952 let replacement = h.clone();
953 let orig = h.clone();
954
955 let removal = h
956 .entry_descendants()
957 .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
958 .collect_vec();
959 h.apply_patch(
960 SimpleReplacement::try_new(
961 SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
962 &h,
963 replacement,
964 )
965 .unwrap(),
966 )
967 .unwrap();
968
969 assert_eq!(h.num_edges(), orig.num_edges());
971 }
972
973 #[test]
974 fn test_replace_after_copy() {
975 let one_bit = vec![bool_t()];
976 let two_bit = vec![bool_t(), bool_t()];
977
978 let mut builder = DFGBuilder::new(endo_sig(one_bit.clone())).unwrap();
979 let inw = builder.input_wires().exactly_one().unwrap();
980 let outw = builder
981 .add_dataflow_op(and_op(), [inw, inw])
982 .unwrap()
983 .outputs();
984 let mut h = builder.finish_hugr_with_outputs(outw).unwrap();
985
986 let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap();
987 let inw = builder.input_wires();
988 let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs();
989 let repl = builder.finish_hugr_with_outputs(outw).unwrap();
990
991 let orig = h.clone();
992
993 let removal = h
994 .entry_descendants()
995 .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
996 .collect_vec();
997
998 h.apply_patch(
999 SimpleReplacement::try_new(
1000 SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
1001 &h,
1002 repl,
1003 )
1004 .unwrap(),
1005 )
1006 .unwrap();
1007
1008 assert_eq!(h.num_nodes(), orig.num_nodes());
1010 }
1011
1012 #[rstest]
1017 fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec<Node>)) {
1018 let (mut hugr, nodes) = dfg_hugr_copy_bools;
1019 let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap();
1020
1021 let replacement = {
1022 let b =
1023 DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
1024 let [w] = b.input_wires_arr();
1025 b.finish_hugr_with_outputs([w, w]).unwrap()
1026 };
1027
1028 let subgraph =
1029 SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)
1030 .unwrap();
1031
1032 let rewrite = SimpleReplacement {
1033 subgraph,
1034 replacement,
1035 };
1036 rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
1037
1038 assert_eq!(hugr.validate(), Ok(()));
1039 assert_eq!(hugr.entry_descendants().count(), 3);
1040 }
1041
1042 #[rstest]
1047 fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec<Node>)) {
1048 let (mut hugr, nodes) = dfg_hugr_half_not_bools;
1049 let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap();
1050
1051 let replacement = {
1052 let mut b =
1053 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
1054 let [w] = b.input_wires_arr();
1055 let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
1056 let [w_not] = not.outputs_arr();
1057 b.finish_hugr_with_outputs([w, w_not]).unwrap()
1058 };
1059
1060 let subgraph =
1061 SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap();
1062
1063 let rewrite = SimpleReplacement {
1064 subgraph,
1065 replacement,
1066 };
1067 rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
1068
1069 assert_eq!(hugr.validate(), Ok(()));
1070 assert_eq!(hugr.entry_descendants().count(), 4);
1071 }
1072
1073 #[rstest]
1074 fn test_nested_replace(dfg_hugr2: Hugr) {
1075 let mut h = dfg_hugr2;
1078 let h_node = h
1079 .entry_descendants()
1080 .find(|node: &Node| *h.get_optype(*node) == h_gate().into())
1081 .unwrap();
1082
1083 let mut nest_build = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
1085 let [input] = nest_build.input_wires_arr();
1086 let inner_build = nest_build.dfg_builder_endo([(qb_t(), input)]).unwrap();
1087 let inner_dfg = n_identity(inner_build).unwrap();
1088 let replacement = nest_build
1089 .finish_hugr_with_outputs([inner_dfg.out_wire(0)])
1090 .unwrap();
1091 let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
1092
1093 let rewrite = SimpleReplacement::try_new(subgraph, &h, replacement).unwrap();
1094
1095 assert_eq!(h.entry_descendants().count(), 4);
1096
1097 rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
1098 h.validate().unwrap_or_else(|e| panic!("{e}"));
1099
1100 assert_eq!(h.entry_descendants().count(), 6);
1101 }
1102
1103 #[fixture]
1105 fn copy_not_not_copy_hugr() -> Hugr {
1106 let mut b = DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(); 4])).unwrap();
1107 let [w] = b.input_wires_arr();
1108 let not1 = b.add_dataflow_op(LogicOp::Not, [w]).unwrap();
1109 let not2 = b.add_dataflow_op(LogicOp::Not, [w]).unwrap();
1110
1111 let [out1] = not1.outputs_arr();
1112 let [out2] = not2.outputs_arr();
1113
1114 b.finish_hugr_with_outputs([out1, out2, out1, out2])
1115 .unwrap()
1116 }
1117
1118 #[rstest]
1119 fn test_boundary_traversal_empty_replacement(copy_not_not_copy_hugr: Hugr) {
1120 let hugr = copy_not_not_copy_hugr;
1121 let [inp, out] = hugr.get_io(hugr.entrypoint()).unwrap();
1122 let [not1, not2] = hugr.output_neighbours(inp).collect_array().unwrap();
1123 let subg_incoming = vec![
1124 vec![(not1, IncomingPort::from(0))],
1125 vec![(not2, IncomingPort::from(0))],
1126 ];
1127 let subg_outgoing = [not1, not2].map(|n| (n, OutgoingPort::from(0))).to_vec();
1128
1129 let subgraph = SiblingSubgraph::try_new(subg_incoming, subg_outgoing, &hugr).unwrap();
1130
1131 let repl = {
1133 let b = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap();
1134 let [w1, w2] = b.input_wires_arr();
1135 let repl_hugr = b.finish_hugr_with_outputs([w1, w2]).unwrap();
1136 SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).unwrap()
1137 };
1138
1139 let replacement_inputs: Vec<_> = repl
1141 .linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr)
1142 .collect();
1143
1144 assert_eq!(
1145 BTreeSet::from_iter(replacement_inputs),
1146 (0..4)
1147 .map(|i| BoundaryPort::Host(out, IncomingPort::from(i)))
1148 .collect()
1149 );
1150
1151 let replacement_output = (0..4)
1153 .map(|i| {
1154 repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr)
1155 .unwrap()
1156 })
1157 .collect_vec();
1158
1159 assert_eq!(
1160 replacement_output,
1161 vec![BoundaryPort::Host(inp, OutgoingPort::from(0)); 4]
1162 );
1163 }
1164
1165 #[rstest]
1166 fn test_boundary_traversal_copy_empty_replacement(copy_not_not_copy_hugr: Hugr) {
1167 let hugr = copy_not_not_copy_hugr;
1168 let [inp, out] = hugr.get_io(hugr.entrypoint()).unwrap();
1169 let [not1, not2] = hugr.output_neighbours(inp).collect_array().unwrap();
1170 let subg_incoming = vec![vec![
1171 (not1, IncomingPort::from(0)),
1172 (not2, IncomingPort::from(0)),
1173 ]];
1174 let subg_outgoing = [not1, not2].map(|n| (n, OutgoingPort::from(0))).to_vec();
1175
1176 let subgraph = SiblingSubgraph::try_new(subg_incoming, subg_outgoing, &hugr).unwrap();
1177
1178 let repl = {
1180 let b = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(); 2])).unwrap();
1181 let [w] = b.input_wires_arr();
1182 let repl_hugr = b.finish_hugr_with_outputs([w, w]).unwrap();
1183 SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).unwrap()
1184 };
1185
1186 let replacement_inputs: Vec<_> = repl
1187 .linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr)
1188 .collect();
1189
1190 assert_eq!(
1191 BTreeSet::from_iter(replacement_inputs),
1192 (0..4)
1193 .map(|i| BoundaryPort::Host(out, IncomingPort::from(i)))
1194 .collect()
1195 );
1196
1197 let replacement_output = (0..4)
1198 .map(|i| {
1199 repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr)
1200 .unwrap()
1201 })
1202 .collect_vec();
1203
1204 assert_eq!(
1205 replacement_output,
1206 vec![BoundaryPort::Host(inp, OutgoingPort::from(0)); 4]
1207 );
1208 }
1209
1210 #[rstest]
1211 fn test_boundary_traversal_non_empty_replacement(copy_not_not_copy_hugr: Hugr) {
1212 let hugr = copy_not_not_copy_hugr;
1213 let [inp, out] = hugr.get_io(hugr.entrypoint()).unwrap();
1214 let [not1, not2] = hugr.output_neighbours(inp).collect_array().unwrap();
1215 let subg_incoming = vec![
1216 vec![(not1, IncomingPort::from(0))],
1217 vec![(not2, IncomingPort::from(0))],
1218 ];
1219 let subg_outgoing = [not1, not2].map(|n| (n, OutgoingPort::from(0))).to_vec();
1220
1221 let subgraph = SiblingSubgraph::try_new(subg_incoming, subg_outgoing, &hugr).unwrap();
1222
1223 let (repl, or_node) = {
1225 let mut b = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap();
1226 let [w1, w2] = b.input_wires_arr();
1227 let or_handle = b.add_dataflow_op(LogicOp::Or, [w1, w2]).unwrap();
1228 let [out] = or_handle.outputs_arr();
1229 let repl_hugr = b.finish_hugr_with_outputs([out, out]).unwrap();
1230 (
1231 SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).unwrap(),
1232 or_handle.node(),
1233 )
1234 };
1235
1236 let replacement_inputs: Vec<_> = repl
1237 .linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr)
1238 .collect();
1239
1240 assert_eq!(
1241 BTreeSet::from_iter(replacement_inputs),
1242 (0..2)
1243 .map(|i| BoundaryPort::Replacement(or_node, IncomingPort::from(i)))
1244 .collect()
1245 );
1246 assert_eq!(
1247 repl.linked_host_input((or_node, IncomingPort::from(0)), &hugr),
1248 (inp, OutgoingPort::from(0)).into()
1249 );
1250
1251 let replacement_output = (0..4)
1252 .map(|i| {
1253 repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr)
1254 .unwrap()
1255 })
1256 .collect_vec();
1257
1258 assert_eq!(
1259 replacement_output,
1260 vec![BoundaryPort::Replacement(or_node, OutgoingPort::from(0)); 4]
1261 );
1262 assert_eq!(
1263 BTreeSet::from_iter(repl.linked_host_outputs((or_node, OutgoingPort::from(0)), &hugr)),
1264 BTreeSet::from_iter((0..4).map(|i| HostPort(out, IncomingPort::from(i))))
1265 );
1266 }
1267
1268 use crate::hugr::patch::replace::Replacement;
1269 fn to_replace(h: &impl HugrView<Node = Node>, s: SimpleReplacement) -> Replacement {
1270 use crate::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec};
1271
1272 let [in_, out] = s.get_replacement_io();
1273 let mu_inp = s
1274 .incoming_boundary(h)
1275 .map(
1276 |(HostPort(src, src_port), ReplacementPort(tgt, tgt_port))| {
1277 if tgt == out {
1278 unimplemented!()
1279 }
1280 NewEdgeSpec {
1281 src,
1282 tgt,
1283 kind: NewEdgeKind::Value {
1284 src_pos: src_port,
1285 tgt_pos: tgt_port,
1286 },
1287 }
1288 },
1289 )
1290 .collect();
1291 let mu_out = s
1292 .outgoing_boundary(h)
1293 .map(
1294 |(ReplacementPort(src, src_port), HostPort(tgt, tgt_port))| {
1295 if src == in_ {
1296 unimplemented!()
1297 }
1298 NewEdgeSpec {
1299 src,
1300 tgt,
1301 kind: NewEdgeKind::Value {
1302 src_pos: src_port,
1303 tgt_pos: tgt_port,
1304 },
1305 }
1306 },
1307 )
1308 .collect();
1309 let mut replacement = s.replacement;
1310 replacement.remove_node(in_);
1311 replacement.remove_node(out);
1312 Replacement {
1313 removal: s.subgraph.nodes().to_vec(),
1314 replacement,
1315 adoptions: HashMap::new(),
1316 mu_inp,
1317 mu_out,
1318 mu_new: vec![],
1319 }
1320 }
1321
1322 fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) {
1323 h.apply_patch(rw).unwrap();
1324 }
1325
1326 fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
1327 h.apply_patch(to_replace(h, rw)).unwrap();
1328 }
1329}