1use std::collections::HashMap;
4
5use crate::core::HugrNode;
6use crate::hugr::hugrmut::InsertionResult;
7use crate::hugr::views::SiblingSubgraph;
8pub use crate::hugr::views::sibling_subgraph::InvalidReplacement;
9use crate::hugr::{HugrMut, HugrView};
10use crate::ops::{OpTag, OpTrait, OpType};
11use crate::{Hugr, IncomingPort, Node, OutgoingPort, PortIndex};
12
13use itertools::{Either, Itertools};
14
15use thiserror::Error;
16
17use super::inline_dfg::InlineDFGError;
18use super::{BoundaryPort, HostPort, PatchHugrMut, PatchVerification, ReplacementPort};
19
20pub mod serial;
21
22#[derive(Debug, Clone)]
28pub struct SimpleReplacement<HostNode = Node> {
29 subgraph: SiblingSubgraph<HostNode>,
31 replacement: Hugr,
33}
34
35impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
36 #[inline]
39 pub fn new_unchecked(subgraph: SiblingSubgraph<HostNode>, replacement: Hugr) -> Self {
40 Self {
41 subgraph,
42 replacement,
43 }
44 }
45
46 pub fn try_new(
56 subgraph: SiblingSubgraph<HostNode>,
57 host: &impl HugrView<Node = HostNode>,
58 replacement: Hugr,
59 ) -> Result<Self, InvalidReplacement> {
60 let subgraph_sig = subgraph.poly_func_type(host);
61 let repl_sig = replacement
62 .poly_func_type()
63 .or_else(|| {
64 Some(
65 replacement
66 .inner_function_type()
67 .unwrap()
68 .into_owned()
69 .into(),
70 )
71 })
72 .ok_or(InvalidReplacement::InvalidDataflowGraph {
73 node: replacement.entrypoint(),
74 op: Box::new(replacement.entrypoint_optype().to_owned()),
75 })?;
76 if subgraph_sig != repl_sig {
77 return Err(InvalidReplacement::InvalidSignature {
78 expected: Box::new(subgraph_sig),
79 actual: Some(Box::new(repl_sig)),
80 });
81 }
82 Ok(Self {
83 subgraph,
84 replacement,
85 })
86 }
87
88 #[inline]
90 pub fn replacement(&self) -> &Hugr {
91 &self.replacement
92 }
93
94 #[inline]
96 pub fn into_replacement(self) -> Hugr {
97 self.replacement
98 }
99
100 #[inline]
102 pub fn subgraph(&self) -> &SiblingSubgraph<HostNode> {
103 &self.subgraph
104 }
105
106 pub fn is_valid_rewrite(
108 &self,
109 h: &impl HugrView<Node = HostNode>,
110 ) -> Result<(), SimpleReplacementError> {
111 let parent = self.subgraph.get_parent(h);
112
113 if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
115 return Err(SimpleReplacementError::InvalidParentNode());
116 }
117
118 for node in self.subgraph.nodes() {
120 if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
121 return Err(SimpleReplacementError::InvalidRemovedNode());
122 }
123 }
124
125 Ok(())
126 }
127
128 pub fn get_replacement_io(&self) -> [Node; 2] {
130 self.replacement
131 .get_io(self.replacement.entrypoint())
132 .expect("replacement is a DFG")
133 }
134
135 pub fn linked_replacement_output(
148 &self,
149 port: impl Into<HostPort<HostNode, IncomingPort>>,
150 host: &impl HugrView<Node = HostNode>,
151 boundary: BoundaryMode,
152 ) -> Option<BoundaryPort<HostNode, OutgoingPort>> {
153 let HostPort(node, port) = port.into();
154 let pos = self
155 .subgraph
156 .outgoing_ports()
157 .iter()
158 .position(move |&(n, p)| host.linked_inputs(n, p).contains(&(node, port)))?;
159
160 Some(self.linked_replacement_output_by_position(pos, host, boundary))
161 }
162
163 fn linked_replacement_output_by_position(
168 &self,
169 pos: usize,
170 host: &impl HugrView<Node = HostNode>,
171 boundary: BoundaryMode,
172 ) -> BoundaryPort<HostNode, OutgoingPort> {
173 debug_assert!(
174 pos < self
175 .subgraph()
176 .poly_func_type(host)
177 .into_body()
178 .output_count()
179 );
180
181 let [repl_inp, repl_out] = self.get_replacement_io();
183 let (out_node, out_port) = self
184 .replacement
185 .single_linked_output(repl_out, pos)
186 .expect("valid dfg wire");
187
188 if out_node != repl_inp || boundary == BoundaryMode::IncludeIO {
189 BoundaryPort::Replacement(out_node, out_port)
190 } else {
191 let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()]
192 .first()
193 .expect("non-empty boundary partition");
194 let (out_node, out_port) = host
195 .single_linked_output(in_node, in_port)
196 .expect("valid dfg wire");
197 BoundaryPort::Host(out_node, out_port)
198 }
199 }
200
201 pub fn linked_host_outputs(
209 &self,
210 port: impl Into<ReplacementPort<OutgoingPort>>,
211 host: &impl HugrView<Node = HostNode>,
212 ) -> impl Iterator<Item = HostPort<HostNode, IncomingPort>> {
213 let ReplacementPort(node, port) = port.into();
214 let [_, repl_out] = self.get_replacement_io();
215 let positions = self
216 .replacement
217 .linked_inputs(node, port)
218 .filter_map(move |(n, p)| (n == repl_out).then_some(p.index()));
219
220 positions
221 .map(|pos| self.subgraph.outgoing_ports()[pos])
222 .flat_map(|(out_node, out_port)| {
223 let in_nodes_ports = host.linked_inputs(out_node, out_port);
224 in_nodes_ports.map(|(n, p)| HostPort(n, p))
225 })
226 }
227
228 pub fn linked_replacement_inputs<'a>(
241 &'a self,
242 port: impl Into<HostPort<HostNode, OutgoingPort>>,
243 host: &'a impl HugrView<Node = HostNode>,
244 boundary: BoundaryMode,
245 ) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> + 'a {
246 let HostPort(node, port) = port.into();
247 let positions = self
248 .subgraph
249 .incoming_ports()
250 .iter()
251 .positions(move |ports| {
252 let (n, p) = *ports.first().expect("non-empty boundary partition");
253 host.single_linked_output(n, p).expect("valid dfg wire") == (node, port)
254 });
255
256 positions
257 .flat_map(move |pos| self.linked_replacement_inputs_by_position(pos, host, boundary))
258 }
259
260 fn linked_replacement_inputs_by_position(
262 &self,
263 pos: usize,
264 host: &impl HugrView<Node = HostNode>,
265 boundary: BoundaryMode,
266 ) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> {
267 debug_assert!(
268 pos < self
269 .subgraph()
270 .poly_func_type(host)
271 .into_body()
272 .input_count()
273 );
274
275 let [repl_inp, repl_out] = self.get_replacement_io();
276 self.replacement
277 .linked_inputs(repl_inp, pos)
278 .flat_map(move |(in_node, in_port)| {
279 if in_node != repl_out || boundary == BoundaryMode::IncludeIO {
280 Either::Left(std::iter::once(BoundaryPort::Replacement(in_node, in_port)))
281 } else {
282 let (out_node, out_port) = self.subgraph.outgoing_ports()[in_port.index()];
283 let in_nodes_ports = host.linked_inputs(out_node, out_port);
284 Either::Right(in_nodes_ports.map(|(n, p)| BoundaryPort::Host(n, p)))
285 }
286 })
287 }
288
289 pub fn linked_host_input(
297 &self,
298 port: impl Into<ReplacementPort<IncomingPort>>,
299 host: &impl HugrView<Node = HostNode>,
300 ) -> HostPort<HostNode, OutgoingPort> {
301 let ReplacementPort(node, port) = port.into();
302 let (out_node, out_port) = self
303 .replacement
304 .single_linked_output(node, port)
305 .expect("valid dfg wire");
306
307 let [repl_in, _] = self.get_replacement_io();
308 assert!(out_node == repl_in, "not a boundary port");
309
310 let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()]
311 .first()
312 .expect("non-empty input partition");
313
314 let (host_node, host_port) = host
315 .single_linked_output(in_node, in_port)
316 .expect("valid dfg wire");
317 HostPort(host_node, host_port)
318 }
319
320 pub fn incoming_boundary<'a>(
331 &'a self,
332 host: &'a impl HugrView<Node = HostNode>,
333 ) -> impl Iterator<
334 Item = (
335 HostPort<HostNode, OutgoingPort>,
336 ReplacementPort<IncomingPort>,
337 ),
338 > + 'a {
339 let subgraph_outgoing_ports = self
341 .subgraph
342 .incoming_ports()
343 .iter()
344 .map(|in_ports| *in_ports.first().expect("non-empty input partition"))
345 .map(|(node, in_port)| {
346 host.single_linked_output(node, in_port)
347 .expect("valid dfg wire")
348 });
349
350 subgraph_outgoing_ports
351 .enumerate()
352 .flat_map(|(pos, subg_np)| {
353 self.linked_replacement_inputs_by_position(pos, host, BoundaryMode::SnapToHost)
354 .filter_map(move |np| Some((np.as_replacement()?, subg_np)))
355 })
356 .map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| {
357 (
358 HostPort(subgraph_node, subgraph_port),
359 ReplacementPort(repl_node, repl_port),
360 )
361 })
362 }
363
364 pub fn outgoing_boundary<'a>(
377 &'a self,
378 host: &'a impl HugrView<Node = HostNode>,
379 ) -> impl Iterator<
380 Item = (
381 ReplacementPort<OutgoingPort>,
382 HostPort<HostNode, IncomingPort>,
383 ),
384 > + 'a {
385 let subgraph_incoming_ports = self.subgraph.outgoing_ports().iter().map(
387 move |&(subgraph_out_node, subgraph_out_port)| {
388 host.linked_inputs(subgraph_out_node, subgraph_out_port)
389 },
390 );
391
392 subgraph_incoming_ports
393 .enumerate()
394 .filter_map(|(pos, subg_all)| {
395 let np = self
396 .linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost)
397 .as_replacement()?;
398 Some((np, subg_all))
399 })
400 .flat_map(|(repl_np, subg_all)| subg_all.map(move |subg_np| (repl_np, subg_np)))
401 .map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| {
402 (
403 ReplacementPort(repl_node, repl_port),
404 HostPort(subgraph_node, subgraph_port),
405 )
406 })
407 }
408
409 pub fn host_to_host_boundary<'a>(
424 &'a self,
425 host: &'a impl HugrView<Node = HostNode>,
426 ) -> impl Iterator<
427 Item = (
428 HostPort<HostNode, OutgoingPort>,
429 HostPort<HostNode, IncomingPort>,
430 ),
431 > + 'a {
432 let subgraph_incoming_ports = self.subgraph.outgoing_ports().iter().map(
434 move |&(subgraph_out_node, subgraph_out_port)| {
435 host.linked_inputs(subgraph_out_node, subgraph_out_port)
436 },
437 );
438
439 subgraph_incoming_ports
440 .enumerate()
441 .filter_map(|(pos, subg_all)| {
442 Some((
443 self.linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost)
444 .as_host()?,
445 subg_all,
446 ))
447 })
448 .flat_map(|(host_np, subg_all)| subg_all.map(move |subg_np| (host_np, subg_np)))
449 .map(
450 |((host_out_node, host_out_port), (host_in_node, host_in_port))| {
451 (
452 HostPort(host_out_node, host_out_port),
453 HostPort(host_in_node, host_in_port),
454 )
455 },
456 )
457 }
458
459 pub fn map_host_output(
467 &self,
468 port: impl Into<HostPort<HostNode, OutgoingPort>>,
469 ) -> Option<ReplacementPort<IncomingPort>> {
470 let HostPort(node, port) = port.into();
471 let pos = self
472 .subgraph
473 .outgoing_ports()
474 .iter()
475 .position(|&node_port| node_port == (node, port))?;
476 let incoming_port: IncomingPort = pos.into();
477 let [_, rep_output] = self.get_replacement_io();
478 Some(ReplacementPort(rep_output, incoming_port))
479 }
480
481 pub fn map_replacement_input(
486 &self,
487 port: impl Into<ReplacementPort<OutgoingPort>>,
488 ) -> impl Iterator<Item = HostPort<HostNode, IncomingPort>> + '_ {
489 let ReplacementPort(node, port) = port.into();
490 let [repl_input, _] = self.get_replacement_io();
491
492 let ports = if node == repl_input {
493 self.subgraph.incoming_ports().get(port.index())
494 } else {
495 None
496 };
497 ports
498 .into_iter()
499 .flat_map(|ports| ports.iter().map(|&(n, p)| HostPort(n, p)))
500 }
501
502 pub fn all_boundary_edges<'a>(
511 &'a self,
512 host: &'a impl HugrView<Node = HostNode>,
513 ) -> impl Iterator<
514 Item = (
515 BoundaryPort<HostNode, OutgoingPort>,
516 BoundaryPort<HostNode, IncomingPort>,
517 ),
518 > + 'a {
519 let incoming_boundary = self
520 .incoming_boundary(host)
521 .map(|(src, tgt)| (src.into(), tgt.into()));
522 let outgoing_boundary = self
523 .outgoing_boundary(host)
524 .map(|(src, tgt)| (src.into(), tgt.into()));
525 let host_to_host_boundary = self
526 .host_to_host_boundary(host)
527 .map(|(src, tgt)| (src.into(), tgt.into()));
528
529 incoming_boundary
530 .chain(outgoing_boundary)
531 .chain(host_to_host_boundary)
532 }
533
534 pub fn map_host_nodes<N: HugrNode>(
542 &self,
543 node_map: impl Fn(HostNode) -> N,
544 new_host: &impl HugrView<Node = N>,
545 ) -> Result<SimpleReplacement<N>, InvalidReplacement> {
546 let Self {
547 subgraph,
548 replacement,
549 } = self;
550 let subgraph = subgraph.map_nodes(node_map);
551 SimpleReplacement::try_new(subgraph, new_host, replacement.clone())
552 }
553
554 pub fn invalidation_set(&self) -> impl Iterator<Item = HostNode> {
557 self.subgraph.nodes().iter().copied()
558 }
559}
560
561impl<HostNode: HugrNode> PatchVerification for SimpleReplacement<HostNode> {
562 type Error = SimpleReplacementError;
563 type Node = HostNode;
564
565 fn verify(&self, h: &impl HugrView<Node = HostNode>) -> Result<(), SimpleReplacementError> {
566 self.is_valid_rewrite(h)
567 }
568
569 #[inline]
570 fn invalidated_nodes(
571 &self,
572 _: &impl HugrView<Node = Self::Node>,
573 ) -> impl Iterator<Item = Self::Node> {
574 self.invalidation_set()
575 }
576}
577
578#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
584pub enum BoundaryMode {
585 #[default]
590 SnapToHost,
591 IncludeIO,
594}
595
596pub struct Outcome<HostNode = Node> {
598 pub node_map: HashMap<Node, HostNode>,
600 pub removed_nodes: HashMap<HostNode, OpType>,
602}
603
604impl<N: HugrNode> PatchHugrMut for SimpleReplacement<N> {
605 type Outcome = Outcome<N>;
606 const UNCHANGED_ON_FAILURE: bool = true;
607
608 fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<Self::Outcome, Self::Error> {
609 self.is_valid_rewrite(h)?;
610
611 let parent = self.subgraph.get_parent(h);
612
613 let boundary_edges = self.all_boundary_edges(h).collect_vec();
621
622 let Self {
623 replacement,
624 subgraph,
625 ..
626 } = self;
627
628 let repl_io = replacement
630 .get_io(replacement.entrypoint())
631 .expect("replacement is DFG-rooted");
632 let repl_entrypoint = replacement.entrypoint();
633
634 let InsertionResult {
636 inserted_entrypoint: new_entrypoint,
637 mut node_map,
638 } = h.insert_hugr(parent, replacement);
639
640 for node in repl_io {
642 let node_h = node_map[&node];
643 h.remove_node(node_h);
644 node_map.remove(&node);
645 }
646
647 for child in h.children(new_entrypoint).collect_vec() {
649 h.set_parent(child, parent);
650 }
651
652 h.remove_node(new_entrypoint);
654 node_map.remove(&repl_entrypoint);
655
656 for (src, tgt) in boundary_edges {
658 let (src_node, src_port) = src.map_replacement(&node_map);
659 let (tgt_node, tgt_port) = tgt.map_replacement(&node_map);
660 h.connect(src_node, src_port, tgt_node, tgt_port);
661 }
662
663 let removed_nodes = subgraph
665 .nodes()
666 .iter()
667 .map(|&node| (node, h.remove_node(node)))
668 .collect();
669
670 Ok(Outcome {
671 node_map,
672 removed_nodes,
673 })
674 }
675}
676
677#[derive(Debug, Clone, Error, PartialEq, Eq)]
679#[non_exhaustive]
680pub enum SimpleReplacementError {
681 #[error("Parent node is invalid.")]
683 InvalidParentNode(),
684 #[error("A node requested for removal is invalid.")]
686 InvalidRemovedNode(),
687 #[error("A node in the replacement graph is invalid.")]
689 InvalidReplacementNode(),
690 #[error("Inlining replacement failed: {0}")]
692 InliningFailed(#[from] InlineDFGError),
693}
694
695#[cfg(test)]
696pub(in crate::hugr::patch) mod test {
697 use itertools::Itertools;
698 use rstest::{fixture, rstest};
699
700 use std::collections::{BTreeSet, HashMap, HashSet};
701
702 use crate::builder::test::n_identity;
703 use crate::builder::{
704 BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
705 ModuleBuilder, endo_sig, inout_sig,
706 };
707 use crate::extension::prelude::{bool_t, qb_t};
708 use crate::hugr::patch::simple_replace::{BoundaryMode, Outcome};
709 use crate::hugr::patch::{BoundaryPort, HostPort, PatchVerification, ReplacementPort};
710 use crate::hugr::views::{HugrView, SiblingSubgraph};
711 use crate::hugr::{Hugr, HugrMut, Patch};
712 use crate::ops::OpTag;
713 use crate::ops::OpTrait;
714 use crate::ops::handle::NodeHandle;
715 use crate::std_extensions::logic::LogicOp;
716 use crate::std_extensions::logic::test::and_op;
717 use crate::types::{Signature, Type};
718 use crate::utils::test_quantum_extension::{cx_gate, h_gate};
719 use crate::{IncomingPort, Node, OutgoingPort};
720
721 use super::SimpleReplacement;
722
723 fn make_hugr() -> Result<Hugr, BuildError> {
733 let mut module_builder = ModuleBuilder::new();
734 let _f_id = {
735 let mut func_builder = module_builder
736 .define_function("main", Signature::new_endo([qb_t(), qb_t(), qb_t()]))?;
737
738 let [qb0, qb1, qb2] = func_builder.input_wires_arr();
739
740 let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?;
741
742 let mut inner_builder =
743 func_builder.dfg_builder_endo([(qb_t(), qb0), (qb_t(), qb1)])?;
744 let inner_graph = {
745 let [wire0, wire1] = inner_builder.input_wires_arr();
746 let wire2 = inner_builder.add_dataflow_op(h_gate(), vec![wire0])?;
747 let wire3 = inner_builder.add_dataflow_op(h_gate(), vec![wire1])?;
748 let wire45 = inner_builder
749 .add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
750 let [wire4, wire5] = wire45.outputs_arr();
751 let wire6 = inner_builder.add_dataflow_op(h_gate(), vec![wire4])?;
752 let wire7 = inner_builder.add_dataflow_op(h_gate(), vec![wire5])?;
753 inner_builder.finish_with_outputs(wire6.outputs().chain(wire7.outputs()))
754 }?;
755
756 func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))?
757 };
758 Ok(module_builder.finish_hugr()?)
759 }
760
761 #[fixture]
762 pub(in crate::hugr::patch) fn simple_hugr() -> Hugr {
763 make_hugr().unwrap()
764 }
765 fn make_dfg_hugr() -> Result<Hugr, BuildError> {
772 let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
773 let [wire0, wire1] = dfg_builder.input_wires_arr();
774 let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?;
775 let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
776 let wire45 =
777 dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
778 dfg_builder.finish_hugr_with_outputs(wire45.outputs())
779 }
780
781 #[fixture]
782 pub(in crate::hugr::patch) fn dfg_hugr() -> Hugr {
783 make_dfg_hugr().unwrap()
784 }
785
786 fn make_dfg_hugr2() -> Result<Hugr, BuildError> {
792 let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
793
794 let [wire0, wire1] = dfg_builder.input_wires_arr();
795 let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
796 let wire2out = wire2.outputs().exactly_one().unwrap();
797 let wireoutvec = vec![wire0, wire2out];
798 dfg_builder.finish_hugr_with_outputs(wireoutvec)
799 }
800
801 #[fixture]
802 pub(in crate::hugr::patch) fn dfg_hugr2() -> Hugr {
803 make_dfg_hugr2().unwrap()
804 }
805
806 #[fixture]
819 pub(in crate::hugr::patch) fn dfg_hugr_copy_bools() -> (Hugr, Vec<Node>) {
820 let mut dfg_builder =
821 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
822 let [b] = dfg_builder.input_wires_arr();
823
824 let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
825 let [b] = not_inp.outputs_arr();
826
827 let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
828 let [b0] = not_0.outputs_arr();
829 let not_1 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
830 let [b1] = not_1.outputs_arr();
831
832 (
833 dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
834 vec![not_inp.node(), not_0.node(), not_1.node()],
835 )
836 }
837
838 #[fixture]
851 pub(in crate::hugr::patch) fn dfg_hugr_half_not_bools() -> (Hugr, Vec<Node>) {
852 let mut dfg_builder =
853 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
854 let [b] = dfg_builder.input_wires_arr();
855
856 let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
857 let [b] = not_inp.outputs_arr();
858
859 let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
860 let [b0] = not_0.outputs_arr();
861 let b1 = b;
862
863 (
864 dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
865 vec![not_inp.node(), not_0.node()],
866 )
867 }
868
869 #[rstest]
870 fn test_simple_replacement(
889 simple_hugr: Hugr,
890 dfg_hugr: Hugr,
891 #[values(apply_simple, apply_replace)] applicator: impl Fn(&mut Hugr, SimpleReplacement),
892 ) {
893 let mut h: Hugr = simple_hugr;
894 let h_node_cx: Node = h
896 .entry_descendants()
897 .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
898 .unwrap();
899 let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
900 let s: Vec<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
901 let n: Hugr = dfg_hugr;
903 let n_node_cx = n
906 .entry_descendants()
907 .find(|node: &Node| *n.get_optype(*node) == cx_gate().into())
908 .unwrap();
909 let (n_cx_out_0, _n_cx_out_1) = n.node_outputs(n_node_cx).take(2).collect_tuple().unwrap();
911 let n_port_2 = n.linked_inputs(n_node_cx, n_cx_out_0).next().unwrap().1;
912 let h_h0_out = h.node_outputs(h_node_h0).next().unwrap();
914 let r = SimpleReplacement {
916 subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
917 replacement: n,
918 };
919
920 assert_eq!(
922 r.map_host_output((h_node_h0, h_h0_out)).unwrap(),
923 ReplacementPort::from((r.get_replacement_io()[1], n_port_2))
924 );
925
926 assert_eq!(
928 HashSet::<_>::from_iter(r.invalidated_nodes(&h)),
929 HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1]),
930 );
931
932 applicator(&mut h, r);
933 assert_eq!(h.validate(), Ok(()));
940 }
941
942 #[rstest]
943 fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
961 let mut h: Hugr = simple_hugr;
962
963 let h_node_cx: Node = h
965 .entry_descendants()
966 .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
967 .unwrap();
968 let s: Vec<Node> = vec![h_node_cx];
969 let n: Hugr = dfg_hugr2;
971 let n_node_output = n.get_io(n.entrypoint()).unwrap()[1];
974 let (_n_node_input, n_node_h) = n.input_neighbours(n_node_output).collect_tuple().unwrap();
975 let r = SimpleReplacement {
977 subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
978 replacement: n,
979 };
980 let Outcome {
981 node_map,
982 removed_nodes,
983 } = h.apply_patch(r).unwrap();
984
985 assert_eq!(
986 node_map.into_keys().collect::<HashSet<_>>(),
987 [n_node_h].into_iter().collect::<HashSet<_>>(),
988 );
989 assert_eq!(
990 removed_nodes.into_keys().collect::<HashSet<_>>(),
991 [h_node_cx].into_iter().collect::<HashSet<_>>(),
992 );
993
994 assert_eq!(h.validate(), Ok(()));
1001 }
1002
1003 #[test]
1004 fn test_replace_cx_cross() {
1005 let q_row: Vec<Type> = vec![qb_t(), qb_t()];
1006 let mut builder = DFGBuilder::new(endo_sig(q_row)).unwrap();
1007 let mut circ = builder.as_circuit(builder.input_wires());
1008 circ.append(cx_gate(), [0, 1]).unwrap();
1009 circ.append(cx_gate(), [1, 0]).unwrap();
1010 let wires = circ.finish();
1011 let mut h = builder.finish_hugr_with_outputs(wires).unwrap();
1012 let replacement = h.clone();
1013 let orig = h.clone();
1014
1015 let removal = h
1016 .entry_descendants()
1017 .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
1018 .collect_vec();
1019 h.apply_patch(
1020 SimpleReplacement::try_new(
1021 SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
1022 &h,
1023 replacement,
1024 )
1025 .unwrap(),
1026 )
1027 .unwrap();
1028
1029 assert_eq!(h.num_edges(), orig.num_edges());
1031 }
1032
1033 #[test]
1034 fn test_replace_after_copy() {
1035 let one_bit = vec![bool_t()];
1036 let two_bit = vec![bool_t(), bool_t()];
1037
1038 let mut builder = DFGBuilder::new(endo_sig(one_bit.clone())).unwrap();
1039 let inw = builder.input_wires().exactly_one().unwrap();
1040 let outw = builder
1041 .add_dataflow_op(and_op(), [inw, inw])
1042 .unwrap()
1043 .outputs();
1044 let mut h = builder.finish_hugr_with_outputs(outw).unwrap();
1045
1046 let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap();
1047 let inw = builder.input_wires();
1048 let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs();
1049 let repl = builder.finish_hugr_with_outputs(outw).unwrap();
1050
1051 let orig = h.clone();
1052
1053 let removal = h
1054 .entry_descendants()
1055 .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
1056 .collect_vec();
1057
1058 h.apply_patch(
1059 SimpleReplacement::try_new(
1060 SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
1061 &h,
1062 repl,
1063 )
1064 .unwrap(),
1065 )
1066 .unwrap();
1067
1068 assert_eq!(h.num_nodes(), orig.num_nodes());
1070 }
1071
1072 #[rstest]
1077 fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec<Node>)) {
1078 let (mut hugr, nodes) = dfg_hugr_copy_bools;
1079 let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap();
1080
1081 let replacement = {
1082 let b =
1083 DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
1084 let [w] = b.input_wires_arr();
1085 b.finish_hugr_with_outputs([w, w]).unwrap()
1086 };
1087
1088 let subgraph =
1089 SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)
1090 .unwrap();
1091
1092 let rewrite = SimpleReplacement {
1093 subgraph,
1094 replacement,
1095 };
1096 rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
1097
1098 assert_eq!(hugr.validate(), Ok(()));
1099 assert_eq!(hugr.entry_descendants().count(), 3);
1100 }
1101
1102 #[rstest]
1107 fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec<Node>)) {
1108 let (mut hugr, nodes) = dfg_hugr_half_not_bools;
1109 let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap();
1110
1111 let replacement = {
1112 let mut b =
1113 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
1114 let [w] = b.input_wires_arr();
1115 let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
1116 let [w_not] = not.outputs_arr();
1117 b.finish_hugr_with_outputs([w, w_not]).unwrap()
1118 };
1119
1120 let subgraph =
1121 SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap();
1122
1123 let rewrite = SimpleReplacement {
1124 subgraph,
1125 replacement,
1126 };
1127 rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
1128
1129 assert_eq!(hugr.validate(), Ok(()));
1130 assert_eq!(hugr.entry_descendants().count(), 4);
1131 }
1132
1133 #[rstest]
1134 fn test_nested_replace(dfg_hugr2: Hugr) {
1135 let mut h = dfg_hugr2;
1138 let h_node = h
1139 .entry_descendants()
1140 .find(|node: &Node| *h.get_optype(*node) == h_gate().into())
1141 .unwrap();
1142
1143 let mut nest_build = DFGBuilder::new(Signature::new_endo([qb_t()])).unwrap();
1145 let [input] = nest_build.input_wires_arr();
1146 let inner_build = nest_build.dfg_builder_endo([(qb_t(), input)]).unwrap();
1147 let inner_dfg = n_identity(inner_build).unwrap();
1148 let replacement = nest_build
1149 .finish_hugr_with_outputs([inner_dfg.out_wire(0)])
1150 .unwrap();
1151 let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
1152
1153 let rewrite = SimpleReplacement::try_new(subgraph, &h, replacement).unwrap();
1154
1155 assert_eq!(h.entry_descendants().count(), 4);
1156
1157 rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
1158 h.validate().unwrap_or_else(|e| panic!("{e}"));
1159
1160 assert_eq!(h.entry_descendants().count(), 6);
1161 }
1162
1163 #[fixture]
1165 fn copy_not_not_copy_hugr() -> Hugr {
1166 let mut b = DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(); 4])).unwrap();
1167 let [w] = b.input_wires_arr();
1168 let not1 = b.add_dataflow_op(LogicOp::Not, [w]).unwrap();
1169 let not2 = b.add_dataflow_op(LogicOp::Not, [w]).unwrap();
1170
1171 let [out1] = not1.outputs_arr();
1172 let [out2] = not2.outputs_arr();
1173
1174 b.finish_hugr_with_outputs([out1, out2, out1, out2])
1175 .unwrap()
1176 }
1177
1178 #[rstest]
1179 fn test_boundary_traversal_empty_replacement(copy_not_not_copy_hugr: Hugr) {
1180 let hugr = copy_not_not_copy_hugr;
1181 let [inp, out] = hugr.get_io(hugr.entrypoint()).unwrap();
1182 let [not1, not2] = hugr.output_neighbours(inp).collect_array().unwrap();
1183 let subg_incoming = vec![
1184 vec![(not1, IncomingPort::from(0))],
1185 vec![(not2, IncomingPort::from(0))],
1186 ];
1187 let subg_outgoing = [not1, not2].map(|n| (n, OutgoingPort::from(0))).to_vec();
1188
1189 let subgraph = SiblingSubgraph::try_new(subg_incoming, subg_outgoing, &hugr).unwrap();
1190
1191 let repl = {
1193 let b = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap();
1194 let [w1, w2] = b.input_wires_arr();
1195 let repl_hugr = b.finish_hugr_with_outputs([w1, w2]).unwrap();
1196 SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).unwrap()
1197 };
1198
1199 let replacement_inputs: Vec<_> = repl
1201 .linked_replacement_inputs(
1202 (inp, OutgoingPort::from(0)),
1203 &hugr,
1204 BoundaryMode::SnapToHost,
1205 )
1206 .collect();
1207
1208 assert_eq!(
1209 BTreeSet::from_iter(replacement_inputs),
1210 (0..4)
1211 .map(|i| BoundaryPort::Host(out, IncomingPort::from(i)))
1212 .collect()
1213 );
1214
1215 let replacement_output = (0..4)
1217 .map(|i| {
1218 repl.linked_replacement_output(
1219 (out, IncomingPort::from(i)),
1220 &hugr,
1221 BoundaryMode::SnapToHost,
1222 )
1223 .unwrap()
1224 })
1225 .collect_vec();
1226
1227 assert_eq!(
1228 replacement_output,
1229 vec![BoundaryPort::Host(inp, OutgoingPort::from(0)); 4]
1230 );
1231 }
1232
1233 #[rstest]
1234 fn test_boundary_traversal_copy_empty_replacement(copy_not_not_copy_hugr: Hugr) {
1235 let hugr = copy_not_not_copy_hugr;
1236 let [inp, out] = hugr.get_io(hugr.entrypoint()).unwrap();
1237 let [not1, not2] = hugr.output_neighbours(inp).collect_array().unwrap();
1238 let subg_incoming = vec![vec![
1239 (not1, IncomingPort::from(0)),
1240 (not2, IncomingPort::from(0)),
1241 ]];
1242 let subg_outgoing = [not1, not2].map(|n| (n, OutgoingPort::from(0))).to_vec();
1243
1244 let subgraph = SiblingSubgraph::try_new(subg_incoming, subg_outgoing, &hugr).unwrap();
1245
1246 let repl = {
1248 let b = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(); 2])).unwrap();
1249 let [w] = b.input_wires_arr();
1250 let repl_hugr = b.finish_hugr_with_outputs([w, w]).unwrap();
1251 SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).unwrap()
1252 };
1253
1254 let replacement_inputs: Vec<_> = repl
1255 .linked_replacement_inputs(
1256 (inp, OutgoingPort::from(0)),
1257 &hugr,
1258 BoundaryMode::SnapToHost,
1259 )
1260 .collect();
1261
1262 assert_eq!(
1263 BTreeSet::from_iter(replacement_inputs),
1264 (0..4)
1265 .map(|i| BoundaryPort::Host(out, IncomingPort::from(i)))
1266 .collect()
1267 );
1268
1269 let replacement_output = (0..4)
1270 .map(|i| {
1271 repl.linked_replacement_output(
1272 (out, IncomingPort::from(i)),
1273 &hugr,
1274 BoundaryMode::SnapToHost,
1275 )
1276 .unwrap()
1277 })
1278 .collect_vec();
1279
1280 assert_eq!(
1281 replacement_output,
1282 vec![BoundaryPort::Host(inp, OutgoingPort::from(0)); 4]
1283 );
1284 }
1285
1286 #[rstest]
1287 fn test_boundary_traversal_non_empty_replacement(copy_not_not_copy_hugr: Hugr) {
1288 let hugr = copy_not_not_copy_hugr;
1289 let [inp, out] = hugr.get_io(hugr.entrypoint()).unwrap();
1290 let [not1, not2] = hugr.output_neighbours(inp).collect_array().unwrap();
1291 let subg_incoming = vec![
1292 vec![(not1, IncomingPort::from(0))],
1293 vec![(not2, IncomingPort::from(0))],
1294 ];
1295 let subg_outgoing = [not1, not2].map(|n| (n, OutgoingPort::from(0))).to_vec();
1296
1297 let subgraph = SiblingSubgraph::try_new(subg_incoming, subg_outgoing, &hugr).unwrap();
1298
1299 let (repl, or_node) = {
1301 let mut b = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap();
1302 let [w1, w2] = b.input_wires_arr();
1303 let or_handle = b.add_dataflow_op(LogicOp::Or, [w1, w2]).unwrap();
1304 let [out] = or_handle.outputs_arr();
1305 let repl_hugr = b.finish_hugr_with_outputs([out, out]).unwrap();
1306 (
1307 SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).unwrap(),
1308 or_handle.node(),
1309 )
1310 };
1311
1312 let replacement_inputs: Vec<_> = repl
1313 .linked_replacement_inputs(
1314 (inp, OutgoingPort::from(0)),
1315 &hugr,
1316 BoundaryMode::SnapToHost,
1317 )
1318 .collect();
1319
1320 assert_eq!(
1321 BTreeSet::from_iter(replacement_inputs),
1322 (0..2)
1323 .map(|i| BoundaryPort::Replacement(or_node, IncomingPort::from(i)))
1324 .collect()
1325 );
1326 assert_eq!(
1327 repl.linked_host_input((or_node, IncomingPort::from(0)), &hugr),
1328 (inp, OutgoingPort::from(0)).into()
1329 );
1330
1331 let replacement_output = (0..4)
1332 .map(|i| {
1333 repl.linked_replacement_output(
1334 (out, IncomingPort::from(i)),
1335 &hugr,
1336 BoundaryMode::SnapToHost,
1337 )
1338 .unwrap()
1339 })
1340 .collect_vec();
1341
1342 assert_eq!(
1343 replacement_output,
1344 vec![BoundaryPort::Replacement(or_node, OutgoingPort::from(0)); 4]
1345 );
1346 assert_eq!(
1347 BTreeSet::from_iter(repl.linked_host_outputs((or_node, OutgoingPort::from(0)), &hugr)),
1348 BTreeSet::from_iter((0..4).map(|i| HostPort(out, IncomingPort::from(i))))
1349 );
1350 }
1351
1352 use crate::hugr::patch::replace::Replacement;
1353 fn to_replace(h: &impl HugrView<Node = Node>, s: SimpleReplacement) -> Replacement {
1354 use crate::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec};
1355
1356 let [in_, out] = s.get_replacement_io();
1357 let mu_inp = s
1358 .incoming_boundary(h)
1359 .map(
1360 |(HostPort(src, src_port), ReplacementPort(tgt, tgt_port))| {
1361 if tgt == out {
1362 unimplemented!()
1363 }
1364 NewEdgeSpec {
1365 src,
1366 tgt,
1367 kind: NewEdgeKind::Value {
1368 src_pos: src_port,
1369 tgt_pos: tgt_port,
1370 },
1371 }
1372 },
1373 )
1374 .collect();
1375 let mu_out = s
1376 .outgoing_boundary(h)
1377 .map(
1378 |(ReplacementPort(src, src_port), HostPort(tgt, tgt_port))| {
1379 if src == in_ {
1380 unimplemented!()
1381 }
1382 NewEdgeSpec {
1383 src,
1384 tgt,
1385 kind: NewEdgeKind::Value {
1386 src_pos: src_port,
1387 tgt_pos: tgt_port,
1388 },
1389 }
1390 },
1391 )
1392 .collect();
1393 let mut replacement = s.replacement;
1394 replacement.remove_node(in_);
1395 replacement.remove_node(out);
1396 Replacement {
1397 removal: s.subgraph.nodes().to_vec(),
1398 replacement,
1399 adoptions: HashMap::new(),
1400 mu_inp,
1401 mu_out,
1402 mu_new: vec![],
1403 }
1404 }
1405
1406 fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) {
1407 h.apply_patch(rw).unwrap();
1408 }
1409
1410 fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
1411 h.apply_patch(to_replace(h, rw)).unwrap();
1412 }
1413}