1use std::collections::HashMap;
4
5use crate::core::HugrNode;
6use crate::hugr::hugrmut::InsertionResult;
7use crate::hugr::views::SiblingSubgraph;
8pub use crate::hugr::views::sibling_subgraph::InvalidReplacement;
9use crate::hugr::{HugrMut, HugrView};
10use crate::ops::{OpTag, OpTrait, OpType};
11use crate::{Hugr, IncomingPort, Node, OutgoingPort, PortIndex};
12
13use itertools::{Either, Itertools};
14
15use thiserror::Error;
16
17use super::inline_dfg::InlineDFGError;
18use super::{BoundaryPort, HostPort, PatchHugrMut, PatchVerification, ReplacementPort};
19
20pub mod serial;
21
22#[derive(Debug, Clone)]
28pub struct SimpleReplacement<HostNode = Node> {
29 subgraph: SiblingSubgraph<HostNode>,
31 replacement: Hugr,
33}
34
35impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
36 #[inline]
39 pub fn new_unchecked(subgraph: SiblingSubgraph<HostNode>, replacement: Hugr) -> Self {
40 Self {
41 subgraph,
42 replacement,
43 }
44 }
45
46 pub fn try_new(
51 subgraph: SiblingSubgraph<HostNode>,
52 host: &impl HugrView<Node = HostNode>,
53 replacement: Hugr,
54 ) -> Result<Self, InvalidReplacement> {
55 let subgraph_sig = subgraph.signature(host);
56 let repl_sig =
57 replacement
58 .inner_function_type()
59 .ok_or(InvalidReplacement::InvalidDataflowGraph {
60 node: replacement.entrypoint(),
61 op: Box::new(replacement.get_optype(replacement.entrypoint()).to_owned()),
62 })?;
63 if subgraph_sig != repl_sig {
64 return Err(InvalidReplacement::InvalidSignature {
65 expected: Box::new(subgraph_sig),
66 actual: Some(Box::new(repl_sig.into_owned())),
67 });
68 }
69 Ok(Self {
70 subgraph,
71 replacement,
72 })
73 }
74
75 #[inline]
77 pub fn replacement(&self) -> &Hugr {
78 &self.replacement
79 }
80
81 #[inline]
83 pub fn into_replacement(self) -> Hugr {
84 self.replacement
85 }
86
87 #[inline]
89 pub fn subgraph(&self) -> &SiblingSubgraph<HostNode> {
90 &self.subgraph
91 }
92
93 pub fn is_valid_rewrite(
95 &self,
96 h: &impl HugrView<Node = HostNode>,
97 ) -> Result<(), SimpleReplacementError> {
98 let parent = self.subgraph.get_parent(h);
99
100 if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
102 return Err(SimpleReplacementError::InvalidParentNode());
103 }
104
105 for node in self.subgraph.nodes() {
107 if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
108 return Err(SimpleReplacementError::InvalidRemovedNode());
109 }
110 }
111
112 Ok(())
113 }
114
115 pub fn get_replacement_io(&self) -> [Node; 2] {
117 self.replacement
118 .get_io(self.replacement.entrypoint())
119 .expect("replacement is a DFG")
120 }
121
122 pub fn linked_replacement_output(
135 &self,
136 port: impl Into<HostPort<HostNode, IncomingPort>>,
137 host: &impl HugrView<Node = HostNode>,
138 boundary: BoundaryMode,
139 ) -> Option<BoundaryPort<HostNode, OutgoingPort>> {
140 let HostPort(node, port) = port.into();
141 let pos = self
142 .subgraph
143 .outgoing_ports()
144 .iter()
145 .position(move |&(n, p)| host.linked_inputs(n, p).contains(&(node, port)))?;
146
147 Some(self.linked_replacement_output_by_position(pos, host, boundary))
148 }
149
150 fn linked_replacement_output_by_position(
155 &self,
156 pos: usize,
157 host: &impl HugrView<Node = HostNode>,
158 boundary: BoundaryMode,
159 ) -> BoundaryPort<HostNode, OutgoingPort> {
160 debug_assert!(pos < self.subgraph().signature(host).output_count());
161
162 let [repl_inp, repl_out] = self.get_replacement_io();
164 let (out_node, out_port) = self
165 .replacement
166 .single_linked_output(repl_out, pos)
167 .expect("valid dfg wire");
168
169 if out_node != repl_inp || boundary == BoundaryMode::IncludeIO {
170 BoundaryPort::Replacement(out_node, out_port)
171 } else {
172 let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()]
173 .first()
174 .expect("non-empty boundary partition");
175 let (out_node, out_port) = host
176 .single_linked_output(in_node, in_port)
177 .expect("valid dfg wire");
178 BoundaryPort::Host(out_node, out_port)
179 }
180 }
181
182 pub fn linked_host_outputs(
190 &self,
191 port: impl Into<ReplacementPort<OutgoingPort>>,
192 host: &impl HugrView<Node = HostNode>,
193 ) -> impl Iterator<Item = HostPort<HostNode, IncomingPort>> {
194 let ReplacementPort(node, port) = port.into();
195 let [_, repl_out] = self.get_replacement_io();
196 let positions = self
197 .replacement
198 .linked_inputs(node, port)
199 .filter_map(move |(n, p)| (n == repl_out).then_some(p.index()));
200
201 positions
202 .map(|pos| self.subgraph.outgoing_ports()[pos])
203 .flat_map(|(out_node, out_port)| {
204 let in_nodes_ports = host.linked_inputs(out_node, out_port);
205 in_nodes_ports.map(|(n, p)| HostPort(n, p))
206 })
207 }
208
209 pub fn linked_replacement_inputs<'a>(
222 &'a self,
223 port: impl Into<HostPort<HostNode, OutgoingPort>>,
224 host: &'a impl HugrView<Node = HostNode>,
225 boundary: BoundaryMode,
226 ) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> + 'a {
227 let HostPort(node, port) = port.into();
228 let positions = self
229 .subgraph
230 .incoming_ports()
231 .iter()
232 .positions(move |ports| {
233 let (n, p) = *ports.first().expect("non-empty boundary partition");
234 host.single_linked_output(n, p).expect("valid dfg wire") == (node, port)
235 });
236
237 positions
238 .flat_map(move |pos| self.linked_replacement_inputs_by_position(pos, host, boundary))
239 }
240
241 fn linked_replacement_inputs_by_position(
243 &self,
244 pos: usize,
245 host: &impl HugrView<Node = HostNode>,
246 boundary: BoundaryMode,
247 ) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> {
248 debug_assert!(pos < self.subgraph().signature(host).input_count());
249
250 let [repl_inp, repl_out] = self.get_replacement_io();
251 self.replacement
252 .linked_inputs(repl_inp, pos)
253 .flat_map(move |(in_node, in_port)| {
254 if in_node != repl_out || boundary == BoundaryMode::IncludeIO {
255 Either::Left(std::iter::once(BoundaryPort::Replacement(in_node, in_port)))
256 } else {
257 let (out_node, out_port) = self.subgraph.outgoing_ports()[in_port.index()];
258 let in_nodes_ports = host.linked_inputs(out_node, out_port);
259 Either::Right(in_nodes_ports.map(|(n, p)| BoundaryPort::Host(n, p)))
260 }
261 })
262 }
263
264 pub fn linked_host_input(
272 &self,
273 port: impl Into<ReplacementPort<IncomingPort>>,
274 host: &impl HugrView<Node = HostNode>,
275 ) -> HostPort<HostNode, OutgoingPort> {
276 let ReplacementPort(node, port) = port.into();
277 let (out_node, out_port) = self
278 .replacement
279 .single_linked_output(node, port)
280 .expect("valid dfg wire");
281
282 let [repl_in, _] = self.get_replacement_io();
283 assert!(out_node == repl_in, "not a boundary port");
284
285 let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()]
286 .first()
287 .expect("non-empty input partition");
288
289 let (host_node, host_port) = host
290 .single_linked_output(in_node, in_port)
291 .expect("valid dfg wire");
292 HostPort(host_node, host_port)
293 }
294
295 pub fn incoming_boundary<'a>(
306 &'a self,
307 host: &'a impl HugrView<Node = HostNode>,
308 ) -> impl Iterator<
309 Item = (
310 HostPort<HostNode, OutgoingPort>,
311 ReplacementPort<IncomingPort>,
312 ),
313 > + 'a {
314 let subgraph_outgoing_ports = self
316 .subgraph
317 .incoming_ports()
318 .iter()
319 .map(|in_ports| *in_ports.first().expect("non-empty input partition"))
320 .map(|(node, in_port)| {
321 host.single_linked_output(node, in_port)
322 .expect("valid dfg wire")
323 });
324
325 subgraph_outgoing_ports
326 .enumerate()
327 .flat_map(|(pos, subg_np)| {
328 self.linked_replacement_inputs_by_position(pos, host, BoundaryMode::SnapToHost)
329 .filter_map(move |np| Some((np.as_replacement()?, subg_np)))
330 })
331 .map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| {
332 (
333 HostPort(subgraph_node, subgraph_port),
334 ReplacementPort(repl_node, repl_port),
335 )
336 })
337 }
338
339 pub fn outgoing_boundary<'a>(
352 &'a self,
353 host: &'a impl HugrView<Node = HostNode>,
354 ) -> impl Iterator<
355 Item = (
356 ReplacementPort<OutgoingPort>,
357 HostPort<HostNode, IncomingPort>,
358 ),
359 > + 'a {
360 let subgraph_incoming_ports = self.subgraph.outgoing_ports().iter().map(
362 move |&(subgraph_out_node, subgraph_out_port)| {
363 host.linked_inputs(subgraph_out_node, subgraph_out_port)
364 },
365 );
366
367 subgraph_incoming_ports
368 .enumerate()
369 .filter_map(|(pos, subg_all)| {
370 let np = self
371 .linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost)
372 .as_replacement()?;
373 Some((np, subg_all))
374 })
375 .flat_map(|(repl_np, subg_all)| subg_all.map(move |subg_np| (repl_np, subg_np)))
376 .map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| {
377 (
378 ReplacementPort(repl_node, repl_port),
379 HostPort(subgraph_node, subgraph_port),
380 )
381 })
382 }
383
384 pub fn host_to_host_boundary<'a>(
399 &'a self,
400 host: &'a impl HugrView<Node = HostNode>,
401 ) -> impl Iterator<
402 Item = (
403 HostPort<HostNode, OutgoingPort>,
404 HostPort<HostNode, IncomingPort>,
405 ),
406 > + 'a {
407 let subgraph_incoming_ports = self.subgraph.outgoing_ports().iter().map(
409 move |&(subgraph_out_node, subgraph_out_port)| {
410 host.linked_inputs(subgraph_out_node, subgraph_out_port)
411 },
412 );
413
414 subgraph_incoming_ports
415 .enumerate()
416 .filter_map(|(pos, subg_all)| {
417 Some((
418 self.linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost)
419 .as_host()?,
420 subg_all,
421 ))
422 })
423 .flat_map(|(host_np, subg_all)| subg_all.map(move |subg_np| (host_np, subg_np)))
424 .map(
425 |((host_out_node, host_out_port), (host_in_node, host_in_port))| {
426 (
427 HostPort(host_out_node, host_out_port),
428 HostPort(host_in_node, host_in_port),
429 )
430 },
431 )
432 }
433
434 pub fn map_host_output(
442 &self,
443 port: impl Into<HostPort<HostNode, OutgoingPort>>,
444 ) -> Option<ReplacementPort<IncomingPort>> {
445 let HostPort(node, port) = port.into();
446 let pos = self
447 .subgraph
448 .outgoing_ports()
449 .iter()
450 .position(|&node_port| node_port == (node, port))?;
451 let incoming_port: IncomingPort = pos.into();
452 let [_, rep_output] = self.get_replacement_io();
453 Some(ReplacementPort(rep_output, incoming_port))
454 }
455
456 pub fn map_replacement_input(
461 &self,
462 port: impl Into<ReplacementPort<OutgoingPort>>,
463 ) -> impl Iterator<Item = HostPort<HostNode, IncomingPort>> + '_ {
464 let ReplacementPort(node, port) = port.into();
465 let [repl_input, _] = self.get_replacement_io();
466
467 let ports = if node == repl_input {
468 self.subgraph.incoming_ports().get(port.index())
469 } else {
470 None
471 };
472 ports
473 .into_iter()
474 .flat_map(|ports| ports.iter().map(|&(n, p)| HostPort(n, p)))
475 }
476
477 pub fn all_boundary_edges<'a>(
486 &'a self,
487 host: &'a impl HugrView<Node = HostNode>,
488 ) -> impl Iterator<
489 Item = (
490 BoundaryPort<HostNode, OutgoingPort>,
491 BoundaryPort<HostNode, IncomingPort>,
492 ),
493 > + 'a {
494 let incoming_boundary = self
495 .incoming_boundary(host)
496 .map(|(src, tgt)| (src.into(), tgt.into()));
497 let outgoing_boundary = self
498 .outgoing_boundary(host)
499 .map(|(src, tgt)| (src.into(), tgt.into()));
500 let host_to_host_boundary = self
501 .host_to_host_boundary(host)
502 .map(|(src, tgt)| (src.into(), tgt.into()));
503
504 incoming_boundary
505 .chain(outgoing_boundary)
506 .chain(host_to_host_boundary)
507 }
508
509 pub fn map_host_nodes<N: HugrNode>(
517 &self,
518 node_map: impl Fn(HostNode) -> N,
519 new_host: &impl HugrView<Node = N>,
520 ) -> Result<SimpleReplacement<N>, InvalidReplacement> {
521 let Self {
522 subgraph,
523 replacement,
524 } = self;
525 let subgraph = subgraph.map_nodes(node_map);
526 SimpleReplacement::try_new(subgraph, new_host, replacement.clone())
527 }
528
529 pub fn invalidation_set(&self) -> impl Iterator<Item = HostNode> {
532 self.subgraph.nodes().iter().copied()
533 }
534}
535
536impl<HostNode: HugrNode> PatchVerification for SimpleReplacement<HostNode> {
537 type Error = SimpleReplacementError;
538 type Node = HostNode;
539
540 fn verify(&self, h: &impl HugrView<Node = HostNode>) -> Result<(), SimpleReplacementError> {
541 self.is_valid_rewrite(h)
542 }
543
544 #[inline]
545 fn invalidated_nodes(
546 &self,
547 _: &impl HugrView<Node = Self::Node>,
548 ) -> impl Iterator<Item = Self::Node> {
549 self.invalidation_set()
550 }
551}
552
553#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
559pub enum BoundaryMode {
560 #[default]
565 SnapToHost,
566 IncludeIO,
569}
570
571pub struct Outcome<HostNode = Node> {
573 pub node_map: HashMap<Node, HostNode>,
575 pub removed_nodes: HashMap<HostNode, OpType>,
577}
578
579impl<N: HugrNode> PatchHugrMut for SimpleReplacement<N> {
580 type Outcome = Outcome<N>;
581 const UNCHANGED_ON_FAILURE: bool = true;
582
583 fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<Self::Outcome, Self::Error> {
584 self.is_valid_rewrite(h)?;
585
586 let parent = self.subgraph.get_parent(h);
587
588 let boundary_edges = self.all_boundary_edges(h).collect_vec();
596
597 let Self {
598 replacement,
599 subgraph,
600 ..
601 } = self;
602
603 let repl_io = replacement
605 .get_io(replacement.entrypoint())
606 .expect("replacement is DFG-rooted");
607 let repl_entrypoint = replacement.entrypoint();
608
609 let InsertionResult {
611 inserted_entrypoint: new_entrypoint,
612 mut node_map,
613 } = h.insert_hugr(parent, replacement);
614
615 for node in repl_io {
617 let node_h = node_map[&node];
618 h.remove_node(node_h);
619 node_map.remove(&node);
620 }
621
622 for child in h.children(new_entrypoint).collect_vec() {
624 h.set_parent(child, parent);
625 }
626
627 h.remove_node(new_entrypoint);
629 node_map.remove(&repl_entrypoint);
630
631 for (src, tgt) in boundary_edges {
633 let (src_node, src_port) = src.map_replacement(&node_map);
634 let (tgt_node, tgt_port) = tgt.map_replacement(&node_map);
635 h.connect(src_node, src_port, tgt_node, tgt_port);
636 }
637
638 let removed_nodes = subgraph
640 .nodes()
641 .iter()
642 .map(|&node| (node, h.remove_node(node)))
643 .collect();
644
645 Ok(Outcome {
646 node_map,
647 removed_nodes,
648 })
649 }
650}
651
652#[derive(Debug, Clone, Error, PartialEq, Eq)]
654#[non_exhaustive]
655pub enum SimpleReplacementError {
656 #[error("Parent node is invalid.")]
658 InvalidParentNode(),
659 #[error("A node requested for removal is invalid.")]
661 InvalidRemovedNode(),
662 #[error("A node in the replacement graph is invalid.")]
664 InvalidReplacementNode(),
665 #[error("Inlining replacement failed: {0}")]
667 InliningFailed(#[from] InlineDFGError),
668}
669
670#[cfg(test)]
671pub(in crate::hugr::patch) mod test {
672 use itertools::Itertools;
673 use rstest::{fixture, rstest};
674
675 use std::collections::{BTreeSet, HashMap, HashSet};
676
677 use crate::builder::test::n_identity;
678 use crate::builder::{
679 BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
680 ModuleBuilder, endo_sig, inout_sig,
681 };
682 use crate::extension::prelude::{bool_t, qb_t};
683 use crate::hugr::patch::simple_replace::{BoundaryMode, Outcome};
684 use crate::hugr::patch::{BoundaryPort, HostPort, PatchVerification, ReplacementPort};
685 use crate::hugr::views::{HugrView, SiblingSubgraph};
686 use crate::hugr::{Hugr, HugrMut, Patch};
687 use crate::ops::OpTag;
688 use crate::ops::OpTrait;
689 use crate::ops::handle::NodeHandle;
690 use crate::std_extensions::logic::LogicOp;
691 use crate::std_extensions::logic::test::and_op;
692 use crate::types::{Signature, Type};
693 use crate::utils::test_quantum_extension::{cx_gate, h_gate};
694 use crate::{IncomingPort, Node, OutgoingPort};
695
696 use super::SimpleReplacement;
697
698 fn make_hugr() -> Result<Hugr, BuildError> {
708 let mut module_builder = ModuleBuilder::new();
709 let _f_id = {
710 let mut func_builder = module_builder
711 .define_function("main", Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]))?;
712
713 let [qb0, qb1, qb2] = func_builder.input_wires_arr();
714
715 let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?;
716
717 let mut inner_builder =
718 func_builder.dfg_builder_endo([(qb_t(), qb0), (qb_t(), qb1)])?;
719 let inner_graph = {
720 let [wire0, wire1] = inner_builder.input_wires_arr();
721 let wire2 = inner_builder.add_dataflow_op(h_gate(), vec![wire0])?;
722 let wire3 = inner_builder.add_dataflow_op(h_gate(), vec![wire1])?;
723 let wire45 = inner_builder
724 .add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
725 let [wire4, wire5] = wire45.outputs_arr();
726 let wire6 = inner_builder.add_dataflow_op(h_gate(), vec![wire4])?;
727 let wire7 = inner_builder.add_dataflow_op(h_gate(), vec![wire5])?;
728 inner_builder.finish_with_outputs(wire6.outputs().chain(wire7.outputs()))
729 }?;
730
731 func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))?
732 };
733 Ok(module_builder.finish_hugr()?)
734 }
735
736 #[fixture]
737 pub(in crate::hugr::patch) fn simple_hugr() -> Hugr {
738 make_hugr().unwrap()
739 }
740 fn make_dfg_hugr() -> Result<Hugr, BuildError> {
747 let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
748 let [wire0, wire1] = dfg_builder.input_wires_arr();
749 let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?;
750 let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
751 let wire45 =
752 dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
753 dfg_builder.finish_hugr_with_outputs(wire45.outputs())
754 }
755
756 #[fixture]
757 pub(in crate::hugr::patch) fn dfg_hugr() -> Hugr {
758 make_dfg_hugr().unwrap()
759 }
760
761 fn make_dfg_hugr2() -> Result<Hugr, BuildError> {
767 let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
768
769 let [wire0, wire1] = dfg_builder.input_wires_arr();
770 let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
771 let wire2out = wire2.outputs().exactly_one().unwrap();
772 let wireoutvec = vec![wire0, wire2out];
773 dfg_builder.finish_hugr_with_outputs(wireoutvec)
774 }
775
776 #[fixture]
777 pub(in crate::hugr::patch) fn dfg_hugr2() -> Hugr {
778 make_dfg_hugr2().unwrap()
779 }
780
781 #[fixture]
794 pub(in crate::hugr::patch) fn dfg_hugr_copy_bools() -> (Hugr, Vec<Node>) {
795 let mut dfg_builder =
796 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
797 let [b] = dfg_builder.input_wires_arr();
798
799 let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
800 let [b] = not_inp.outputs_arr();
801
802 let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
803 let [b0] = not_0.outputs_arr();
804 let not_1 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
805 let [b1] = not_1.outputs_arr();
806
807 (
808 dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
809 vec![not_inp.node(), not_0.node(), not_1.node()],
810 )
811 }
812
813 #[fixture]
826 pub(in crate::hugr::patch) fn dfg_hugr_half_not_bools() -> (Hugr, Vec<Node>) {
827 let mut dfg_builder =
828 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
829 let [b] = dfg_builder.input_wires_arr();
830
831 let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
832 let [b] = not_inp.outputs_arr();
833
834 let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
835 let [b0] = not_0.outputs_arr();
836 let b1 = b;
837
838 (
839 dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
840 vec![not_inp.node(), not_0.node()],
841 )
842 }
843
844 #[rstest]
845 fn test_simple_replacement(
864 simple_hugr: Hugr,
865 dfg_hugr: Hugr,
866 #[values(apply_simple, apply_replace)] applicator: impl Fn(&mut Hugr, SimpleReplacement),
867 ) {
868 let mut h: Hugr = simple_hugr;
869 let h_node_cx: Node = h
871 .entry_descendants()
872 .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
873 .unwrap();
874 let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
875 let s: Vec<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
876 let n: Hugr = dfg_hugr;
878 let n_node_cx = n
881 .entry_descendants()
882 .find(|node: &Node| *n.get_optype(*node) == cx_gate().into())
883 .unwrap();
884 let (n_cx_out_0, _n_cx_out_1) = n.node_outputs(n_node_cx).take(2).collect_tuple().unwrap();
886 let n_port_2 = n.linked_inputs(n_node_cx, n_cx_out_0).next().unwrap().1;
887 let h_h0_out = h.node_outputs(h_node_h0).next().unwrap();
889 let r = SimpleReplacement {
891 subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
892 replacement: n,
893 };
894
895 assert_eq!(
897 r.map_host_output((h_node_h0, h_h0_out)).unwrap(),
898 ReplacementPort::from((r.get_replacement_io()[1], n_port_2))
899 );
900
901 assert_eq!(
903 HashSet::<_>::from_iter(r.invalidated_nodes(&h)),
904 HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1]),
905 );
906
907 applicator(&mut h, r);
908 assert_eq!(h.validate(), Ok(()));
915 }
916
917 #[rstest]
918 fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
936 let mut h: Hugr = simple_hugr;
937
938 let h_node_cx: Node = h
940 .entry_descendants()
941 .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
942 .unwrap();
943 let s: Vec<Node> = vec![h_node_cx];
944 let n: Hugr = dfg_hugr2;
946 let n_node_output = n.get_io(n.entrypoint()).unwrap()[1];
949 let (_n_node_input, n_node_h) = n.input_neighbours(n_node_output).collect_tuple().unwrap();
950 let r = SimpleReplacement {
952 subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
953 replacement: n,
954 };
955 let Outcome {
956 node_map,
957 removed_nodes,
958 } = h.apply_patch(r).unwrap();
959
960 assert_eq!(
961 node_map.into_keys().collect::<HashSet<_>>(),
962 [n_node_h].into_iter().collect::<HashSet<_>>(),
963 );
964 assert_eq!(
965 removed_nodes.into_keys().collect::<HashSet<_>>(),
966 [h_node_cx].into_iter().collect::<HashSet<_>>(),
967 );
968
969 assert_eq!(h.validate(), Ok(()));
976 }
977
978 #[test]
979 fn test_replace_cx_cross() {
980 let q_row: Vec<Type> = vec![qb_t(), qb_t()];
981 let mut builder = DFGBuilder::new(endo_sig(q_row)).unwrap();
982 let mut circ = builder.as_circuit(builder.input_wires());
983 circ.append(cx_gate(), [0, 1]).unwrap();
984 circ.append(cx_gate(), [1, 0]).unwrap();
985 let wires = circ.finish();
986 let mut h = builder.finish_hugr_with_outputs(wires).unwrap();
987 let replacement = h.clone();
988 let orig = h.clone();
989
990 let removal = h
991 .entry_descendants()
992 .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
993 .collect_vec();
994 h.apply_patch(
995 SimpleReplacement::try_new(
996 SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
997 &h,
998 replacement,
999 )
1000 .unwrap(),
1001 )
1002 .unwrap();
1003
1004 assert_eq!(h.num_edges(), orig.num_edges());
1006 }
1007
1008 #[test]
1009 fn test_replace_after_copy() {
1010 let one_bit = vec![bool_t()];
1011 let two_bit = vec![bool_t(), bool_t()];
1012
1013 let mut builder = DFGBuilder::new(endo_sig(one_bit.clone())).unwrap();
1014 let inw = builder.input_wires().exactly_one().unwrap();
1015 let outw = builder
1016 .add_dataflow_op(and_op(), [inw, inw])
1017 .unwrap()
1018 .outputs();
1019 let mut h = builder.finish_hugr_with_outputs(outw).unwrap();
1020
1021 let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap();
1022 let inw = builder.input_wires();
1023 let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs();
1024 let repl = builder.finish_hugr_with_outputs(outw).unwrap();
1025
1026 let orig = h.clone();
1027
1028 let removal = h
1029 .entry_descendants()
1030 .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
1031 .collect_vec();
1032
1033 h.apply_patch(
1034 SimpleReplacement::try_new(
1035 SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
1036 &h,
1037 repl,
1038 )
1039 .unwrap(),
1040 )
1041 .unwrap();
1042
1043 assert_eq!(h.num_nodes(), orig.num_nodes());
1045 }
1046
1047 #[rstest]
1052 fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec<Node>)) {
1053 let (mut hugr, nodes) = dfg_hugr_copy_bools;
1054 let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap();
1055
1056 let replacement = {
1057 let b =
1058 DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
1059 let [w] = b.input_wires_arr();
1060 b.finish_hugr_with_outputs([w, w]).unwrap()
1061 };
1062
1063 let subgraph =
1064 SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)
1065 .unwrap();
1066
1067 let rewrite = SimpleReplacement {
1068 subgraph,
1069 replacement,
1070 };
1071 rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
1072
1073 assert_eq!(hugr.validate(), Ok(()));
1074 assert_eq!(hugr.entry_descendants().count(), 3);
1075 }
1076
1077 #[rstest]
1082 fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec<Node>)) {
1083 let (mut hugr, nodes) = dfg_hugr_half_not_bools;
1084 let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap();
1085
1086 let replacement = {
1087 let mut b =
1088 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
1089 let [w] = b.input_wires_arr();
1090 let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
1091 let [w_not] = not.outputs_arr();
1092 b.finish_hugr_with_outputs([w, w_not]).unwrap()
1093 };
1094
1095 let subgraph =
1096 SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap();
1097
1098 let rewrite = SimpleReplacement {
1099 subgraph,
1100 replacement,
1101 };
1102 rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
1103
1104 assert_eq!(hugr.validate(), Ok(()));
1105 assert_eq!(hugr.entry_descendants().count(), 4);
1106 }
1107
1108 #[rstest]
1109 fn test_nested_replace(dfg_hugr2: Hugr) {
1110 let mut h = dfg_hugr2;
1113 let h_node = h
1114 .entry_descendants()
1115 .find(|node: &Node| *h.get_optype(*node) == h_gate().into())
1116 .unwrap();
1117
1118 let mut nest_build = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
1120 let [input] = nest_build.input_wires_arr();
1121 let inner_build = nest_build.dfg_builder_endo([(qb_t(), input)]).unwrap();
1122 let inner_dfg = n_identity(inner_build).unwrap();
1123 let replacement = nest_build
1124 .finish_hugr_with_outputs([inner_dfg.out_wire(0)])
1125 .unwrap();
1126 let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
1127
1128 let rewrite = SimpleReplacement::try_new(subgraph, &h, replacement).unwrap();
1129
1130 assert_eq!(h.entry_descendants().count(), 4);
1131
1132 rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
1133 h.validate().unwrap_or_else(|e| panic!("{e}"));
1134
1135 assert_eq!(h.entry_descendants().count(), 6);
1136 }
1137
1138 #[fixture]
1140 fn copy_not_not_copy_hugr() -> Hugr {
1141 let mut b = DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(); 4])).unwrap();
1142 let [w] = b.input_wires_arr();
1143 let not1 = b.add_dataflow_op(LogicOp::Not, [w]).unwrap();
1144 let not2 = b.add_dataflow_op(LogicOp::Not, [w]).unwrap();
1145
1146 let [out1] = not1.outputs_arr();
1147 let [out2] = not2.outputs_arr();
1148
1149 b.finish_hugr_with_outputs([out1, out2, out1, out2])
1150 .unwrap()
1151 }
1152
1153 #[rstest]
1154 fn test_boundary_traversal_empty_replacement(copy_not_not_copy_hugr: Hugr) {
1155 let hugr = copy_not_not_copy_hugr;
1156 let [inp, out] = hugr.get_io(hugr.entrypoint()).unwrap();
1157 let [not1, not2] = hugr.output_neighbours(inp).collect_array().unwrap();
1158 let subg_incoming = vec![
1159 vec![(not1, IncomingPort::from(0))],
1160 vec![(not2, IncomingPort::from(0))],
1161 ];
1162 let subg_outgoing = [not1, not2].map(|n| (n, OutgoingPort::from(0))).to_vec();
1163
1164 let subgraph = SiblingSubgraph::try_new(subg_incoming, subg_outgoing, &hugr).unwrap();
1165
1166 let repl = {
1168 let b = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap();
1169 let [w1, w2] = b.input_wires_arr();
1170 let repl_hugr = b.finish_hugr_with_outputs([w1, w2]).unwrap();
1171 SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).unwrap()
1172 };
1173
1174 let replacement_inputs: Vec<_> = repl
1176 .linked_replacement_inputs(
1177 (inp, OutgoingPort::from(0)),
1178 &hugr,
1179 BoundaryMode::SnapToHost,
1180 )
1181 .collect();
1182
1183 assert_eq!(
1184 BTreeSet::from_iter(replacement_inputs),
1185 (0..4)
1186 .map(|i| BoundaryPort::Host(out, IncomingPort::from(i)))
1187 .collect()
1188 );
1189
1190 let replacement_output = (0..4)
1192 .map(|i| {
1193 repl.linked_replacement_output(
1194 (out, IncomingPort::from(i)),
1195 &hugr,
1196 BoundaryMode::SnapToHost,
1197 )
1198 .unwrap()
1199 })
1200 .collect_vec();
1201
1202 assert_eq!(
1203 replacement_output,
1204 vec![BoundaryPort::Host(inp, OutgoingPort::from(0)); 4]
1205 );
1206 }
1207
1208 #[rstest]
1209 fn test_boundary_traversal_copy_empty_replacement(copy_not_not_copy_hugr: Hugr) {
1210 let hugr = copy_not_not_copy_hugr;
1211 let [inp, out] = hugr.get_io(hugr.entrypoint()).unwrap();
1212 let [not1, not2] = hugr.output_neighbours(inp).collect_array().unwrap();
1213 let subg_incoming = vec![vec![
1214 (not1, IncomingPort::from(0)),
1215 (not2, IncomingPort::from(0)),
1216 ]];
1217 let subg_outgoing = [not1, not2].map(|n| (n, OutgoingPort::from(0))).to_vec();
1218
1219 let subgraph = SiblingSubgraph::try_new(subg_incoming, subg_outgoing, &hugr).unwrap();
1220
1221 let repl = {
1223 let b = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(); 2])).unwrap();
1224 let [w] = b.input_wires_arr();
1225 let repl_hugr = b.finish_hugr_with_outputs([w, w]).unwrap();
1226 SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).unwrap()
1227 };
1228
1229 let replacement_inputs: Vec<_> = repl
1230 .linked_replacement_inputs(
1231 (inp, OutgoingPort::from(0)),
1232 &hugr,
1233 BoundaryMode::SnapToHost,
1234 )
1235 .collect();
1236
1237 assert_eq!(
1238 BTreeSet::from_iter(replacement_inputs),
1239 (0..4)
1240 .map(|i| BoundaryPort::Host(out, IncomingPort::from(i)))
1241 .collect()
1242 );
1243
1244 let replacement_output = (0..4)
1245 .map(|i| {
1246 repl.linked_replacement_output(
1247 (out, IncomingPort::from(i)),
1248 &hugr,
1249 BoundaryMode::SnapToHost,
1250 )
1251 .unwrap()
1252 })
1253 .collect_vec();
1254
1255 assert_eq!(
1256 replacement_output,
1257 vec![BoundaryPort::Host(inp, OutgoingPort::from(0)); 4]
1258 );
1259 }
1260
1261 #[rstest]
1262 fn test_boundary_traversal_non_empty_replacement(copy_not_not_copy_hugr: Hugr) {
1263 let hugr = copy_not_not_copy_hugr;
1264 let [inp, out] = hugr.get_io(hugr.entrypoint()).unwrap();
1265 let [not1, not2] = hugr.output_neighbours(inp).collect_array().unwrap();
1266 let subg_incoming = vec![
1267 vec![(not1, IncomingPort::from(0))],
1268 vec![(not2, IncomingPort::from(0))],
1269 ];
1270 let subg_outgoing = [not1, not2].map(|n| (n, OutgoingPort::from(0))).to_vec();
1271
1272 let subgraph = SiblingSubgraph::try_new(subg_incoming, subg_outgoing, &hugr).unwrap();
1273
1274 let (repl, or_node) = {
1276 let mut b = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap();
1277 let [w1, w2] = b.input_wires_arr();
1278 let or_handle = b.add_dataflow_op(LogicOp::Or, [w1, w2]).unwrap();
1279 let [out] = or_handle.outputs_arr();
1280 let repl_hugr = b.finish_hugr_with_outputs([out, out]).unwrap();
1281 (
1282 SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).unwrap(),
1283 or_handle.node(),
1284 )
1285 };
1286
1287 let replacement_inputs: Vec<_> = repl
1288 .linked_replacement_inputs(
1289 (inp, OutgoingPort::from(0)),
1290 &hugr,
1291 BoundaryMode::SnapToHost,
1292 )
1293 .collect();
1294
1295 assert_eq!(
1296 BTreeSet::from_iter(replacement_inputs),
1297 (0..2)
1298 .map(|i| BoundaryPort::Replacement(or_node, IncomingPort::from(i)))
1299 .collect()
1300 );
1301 assert_eq!(
1302 repl.linked_host_input((or_node, IncomingPort::from(0)), &hugr),
1303 (inp, OutgoingPort::from(0)).into()
1304 );
1305
1306 let replacement_output = (0..4)
1307 .map(|i| {
1308 repl.linked_replacement_output(
1309 (out, IncomingPort::from(i)),
1310 &hugr,
1311 BoundaryMode::SnapToHost,
1312 )
1313 .unwrap()
1314 })
1315 .collect_vec();
1316
1317 assert_eq!(
1318 replacement_output,
1319 vec![BoundaryPort::Replacement(or_node, OutgoingPort::from(0)); 4]
1320 );
1321 assert_eq!(
1322 BTreeSet::from_iter(repl.linked_host_outputs((or_node, OutgoingPort::from(0)), &hugr)),
1323 BTreeSet::from_iter((0..4).map(|i| HostPort(out, IncomingPort::from(i))))
1324 );
1325 }
1326
1327 use crate::hugr::patch::replace::Replacement;
1328 fn to_replace(h: &impl HugrView<Node = Node>, s: SimpleReplacement) -> Replacement {
1329 use crate::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec};
1330
1331 let [in_, out] = s.get_replacement_io();
1332 let mu_inp = s
1333 .incoming_boundary(h)
1334 .map(
1335 |(HostPort(src, src_port), ReplacementPort(tgt, tgt_port))| {
1336 if tgt == out {
1337 unimplemented!()
1338 }
1339 NewEdgeSpec {
1340 src,
1341 tgt,
1342 kind: NewEdgeKind::Value {
1343 src_pos: src_port,
1344 tgt_pos: tgt_port,
1345 },
1346 }
1347 },
1348 )
1349 .collect();
1350 let mu_out = s
1351 .outgoing_boundary(h)
1352 .map(
1353 |(ReplacementPort(src, src_port), HostPort(tgt, tgt_port))| {
1354 if src == in_ {
1355 unimplemented!()
1356 }
1357 NewEdgeSpec {
1358 src,
1359 tgt,
1360 kind: NewEdgeKind::Value {
1361 src_pos: src_port,
1362 tgt_pos: tgt_port,
1363 },
1364 }
1365 },
1366 )
1367 .collect();
1368 let mut replacement = s.replacement;
1369 replacement.remove_node(in_);
1370 replacement.remove_node(out);
1371 Replacement {
1372 removal: s.subgraph.nodes().to_vec(),
1373 replacement,
1374 adoptions: HashMap::new(),
1375 mu_inp,
1376 mu_out,
1377 mu_new: vec![],
1378 }
1379 }
1380
1381 fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) {
1382 h.apply_patch(rw).unwrap();
1383 }
1384
1385 fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
1386 h.apply_patch(to_replace(h, rw)).unwrap();
1387 }
1388}