1use std::collections::HashMap;
4
5use crate::core::HugrNode;
6use crate::hugr::hugrmut::InsertionResult;
7use crate::hugr::views::SiblingSubgraph;
8use crate::hugr::{HugrMut, HugrView};
9use crate::ops::{OpTag, OpTrait, OpType};
10use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port};
11
12use derive_more::derive::From;
13use itertools::{Either, Itertools};
14
15use thiserror::Error;
16
17use super::inline_dfg::InlineDFGError;
18use super::{BoundaryPort, HostPort, PatchHugrMut, PatchVerification, ReplacementPort};
19
20#[derive(Debug, Clone)]
26pub struct SimpleReplacement<HostNode = Node> {
27 subgraph: SiblingSubgraph<HostNode>,
29 replacement: Hugr,
31 nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>,
35 nu_out: OutputBoundaryMap<HostNode>,
60}
61
62#[derive(Debug, Clone, From)]
68pub enum OutputBoundaryMap<HostNode> {
69 ByIncoming(HashMap<(HostNode, IncomingPort), IncomingPort>),
72 ByOutgoing(HashMap<(HostNode, OutgoingPort), IncomingPort>),
75}
76
77impl<N: HugrNode> OutputBoundaryMap<N> {
78 pub fn iter(&self) -> impl Iterator<Item = ((N, Port), IncomingPort)> + '_ {
83 match self {
84 OutputBoundaryMap::ByIncoming(map) => Either::Left(
85 map.iter()
86 .map(|(&(node, in_port), &v)| ((node, in_port.into()), v)),
87 ),
88 OutputBoundaryMap::ByOutgoing(map) => Either::Right(
89 map.iter()
90 .map(|(&(node, out_port), &v)| ((node, out_port.into()), v)),
91 ),
92 }
93 .into_iter()
94 }
95
96 pub fn iter_as_incoming<'a>(
101 &'a self,
102 host: &'a impl HugrView<Node = N>,
103 ) -> impl Iterator<Item = ((N, IncomingPort), IncomingPort)> + 'a {
104 self.iter()
105 .flat_map(move |((rem_out_node, rem_out_port), rep_out_port)| {
106 as_incoming_ports(rem_out_node, rem_out_port, host).map(
107 move |(rem_out_node, rem_out_port)| {
108 ((rem_out_node, rem_out_port), rep_out_port)
109 },
110 )
111 })
112 }
113
114 pub fn get<P: Into<Port>>(&self, node: N, port: P) -> Option<IncomingPort> {
119 match (self, port.into().as_directed()) {
120 (OutputBoundaryMap::ByIncoming(map), Either::Left(incoming)) => {
121 map.get(&(node, incoming)).copied()
122 }
123 (OutputBoundaryMap::ByOutgoing(map), Either::Right(outgoing)) => {
124 map.get(&(node, outgoing)).copied()
125 }
126 _ => None,
127 }
128 }
129
130 pub fn get_as_incoming(
135 &self,
136 node: N,
137 incoming: IncomingPort,
138 host: &impl HugrView<Node = N>,
139 ) -> Option<IncomingPort> {
140 match self {
141 OutputBoundaryMap::ByIncoming(map) => map.get(&(node, incoming)).copied(),
142 OutputBoundaryMap::ByOutgoing(map) => {
143 let outgoing = host
144 .single_linked_output(node, incoming)
145 .expect("invalid data flow wire");
146 map.get(&outgoing).copied()
147 }
148 }
149 }
150}
151
152impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
153 #[inline]
155 pub fn new(
156 subgraph: SiblingSubgraph<HostNode>,
157 replacement: Hugr,
158 nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>,
159 nu_out: impl Into<OutputBoundaryMap<HostNode>>,
160 ) -> Self {
161 Self {
162 subgraph,
163 replacement,
164 nu_inp,
165 nu_out: nu_out.into(),
166 }
167 }
168
169 #[inline]
171 pub fn replacement(&self) -> &Hugr {
172 &self.replacement
173 }
174
175 #[inline]
177 pub fn into_replacement(self) -> Hugr {
178 self.replacement
179 }
180
181 #[inline]
183 pub fn subgraph(&self) -> &SiblingSubgraph<HostNode> {
184 &self.subgraph
185 }
186
187 pub fn is_valid_rewrite(
189 &self,
190 h: &impl HugrView<Node = HostNode>,
191 ) -> Result<(), SimpleReplacementError> {
192 let parent = self.subgraph.get_parent(h);
193
194 if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
196 return Err(SimpleReplacementError::InvalidParentNode());
197 }
198
199 for node in self.subgraph.nodes() {
201 if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
202 return Err(SimpleReplacementError::InvalidRemovedNode());
203 }
204 }
205
206 Ok(())
207 }
208
209 pub fn get_replacement_io(&self) -> Result<[Node; 2], SimpleReplacementError> {
211 self.replacement
212 .get_io(self.replacement.entrypoint())
213 .ok_or(SimpleReplacementError::InvalidParentNode())
214 }
215
216 pub fn incoming_boundary<'a>(
225 &'a self,
226 host: &'a impl HugrView<Node = HostNode>,
227 ) -> impl Iterator<
228 Item = (
229 HostPort<HostNode, OutgoingPort>,
230 ReplacementPort<IncomingPort>,
231 ),
232 > + 'a {
233 self.nu_inp
236 .iter()
237 .filter(|&((rep_inp_node, _), _)| {
238 self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output
239 })
240 .map(
241 |(&(rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| {
242 let (rem_inp_pred_node, rem_inp_pred_port) = host
245 .single_linked_output(*rem_inp_node, *rem_inp_port)
246 .unwrap();
247 (
248 HostPort(rem_inp_pred_node, rem_inp_pred_port),
249 ReplacementPort(rep_inp_node, rep_inp_port),
250 )
251 },
252 )
253 }
254
255 pub fn outgoing_boundary<'a>(
266 &'a self,
267 host: &'a impl HugrView<Node = HostNode>,
268 ) -> impl Iterator<
269 Item = (
270 ReplacementPort<OutgoingPort>,
271 HostPort<HostNode, IncomingPort>,
272 ),
273 > + 'a {
274 let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG");
275
276 self.nu_out.iter_as_incoming(host).filter_map(
280 move |((rem_out_node, rem_out_port), rep_out_port)| {
281 let (rep_out_pred_node, rep_out_pred_port) = self
282 .replacement
283 .single_linked_output(replacement_output_node, rep_out_port)
284 .unwrap();
285 (self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input).then_some({
286 (
287 ReplacementPort(rep_out_pred_node, rep_out_pred_port),
289 HostPort(rem_out_node, rem_out_port),
290 )
291 })
292 },
293 )
294 }
295
296 pub fn host_to_host_boundary<'a>(
305 &'a self,
306 host: &'a impl HugrView<Node = HostNode>,
307 ) -> impl Iterator<
308 Item = (
309 HostPort<HostNode, OutgoingPort>,
310 HostPort<HostNode, IncomingPort>,
311 ),
312 > + 'a {
313 let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG");
314
315 self.nu_out.iter_as_incoming(host).filter_map(
318 move |((rem_out_node, rem_out_port), rep_out_port)| {
319 self.nu_inp
320 .get(&(replacement_output_node, rep_out_port))
321 .map(|&(rem_inp_node, rem_inp_port)| {
322 let (rem_inp_pred_node, rem_inp_pred_port) = host
323 .single_linked_output(rem_inp_node, rem_inp_port)
324 .unwrap();
325 (
326 HostPort(rem_inp_pred_node, rem_inp_pred_port),
327 HostPort(rem_out_node, rem_out_port),
328 )
329 })
330 },
331 )
332 }
333
334 pub fn map_host_output<P: Into<Port>>(
342 &self,
343 port: impl Into<HostPort<HostNode, P>>,
344 ) -> Option<ReplacementPort<IncomingPort>> {
345 let HostPort(node, port) = port.into();
346 let [_, rep_output] = self.get_replacement_io().expect("replacement is a DFG");
347 self.nu_out
348 .get(node, port.into())
349 .map(|rep_out_port| ReplacementPort(rep_output, rep_out_port))
350 }
351
352 pub fn map_replacement_input(
357 &self,
358 port: impl Into<ReplacementPort<IncomingPort>>,
359 ) -> Option<HostPort<HostNode, IncomingPort>> {
360 let ReplacementPort(node, port) = port.into();
361 self.nu_inp.get(&(node, port)).copied().map(Into::into)
362 }
363
364 pub fn all_boundary_edges<'a>(
373 &'a self,
374 host: &'a impl HugrView<Node = HostNode>,
375 ) -> impl Iterator<
376 Item = (
377 BoundaryPort<HostNode, OutgoingPort>,
378 BoundaryPort<HostNode, IncomingPort>,
379 ),
380 > + 'a {
381 let incoming_boundary = self
382 .incoming_boundary(host)
383 .map(|(src, tgt)| (src.into(), tgt.into()));
384 let outgoing_boundary = self
385 .outgoing_boundary(host)
386 .map(|(src, tgt)| (src.into(), tgt.into()));
387 let host_to_host_boundary = self
388 .host_to_host_boundary(host)
389 .map(|(src, tgt)| (src.into(), tgt.into()));
390
391 incoming_boundary
392 .chain(outgoing_boundary)
393 .chain(host_to_host_boundary)
394 }
395}
396
397impl<HostNode: HugrNode> PatchVerification for SimpleReplacement<HostNode> {
398 type Error = SimpleReplacementError;
399 type Node = HostNode;
400
401 fn verify(&self, h: &impl HugrView<Node = HostNode>) -> Result<(), SimpleReplacementError> {
402 self.is_valid_rewrite(h)
403 }
404
405 #[inline]
406 fn invalidation_set(&self) -> impl Iterator<Item = HostNode> {
407 let subcirc = self.subgraph.nodes().iter().copied();
408 let nu_out_nodes = match &self.nu_out {
409 OutputBoundaryMap::ByIncoming(map) => Some(map.keys().map(|key| key.0)),
410 OutputBoundaryMap::ByOutgoing(_) => None,
411 }
412 .into_iter()
413 .flatten();
414 subcirc.chain(nu_out_nodes)
415 }
416}
417
418pub struct Outcome<HostNode = Node> {
420 pub node_map: HashMap<Node, HostNode>,
422 pub removed_nodes: HashMap<HostNode, OpType>,
424}
425
426impl<N: HugrNode> PatchHugrMut for SimpleReplacement<N> {
427 type Outcome = Outcome<N>;
428 const UNCHANGED_ON_FAILURE: bool = true;
429
430 fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<Self::Outcome, Self::Error> {
431 self.is_valid_rewrite(h)?;
432
433 let parent = self.subgraph.get_parent(h);
434
435 let boundary_edges = self.all_boundary_edges(h).collect_vec();
443
444 let Self {
445 replacement,
446 subgraph,
447 ..
448 } = self;
449
450 let repl_io = replacement
452 .get_io(replacement.entrypoint())
453 .expect("replacement is DFG-rooted");
454 let repl_entrypoint = replacement.entrypoint();
455
456 let InsertionResult {
458 inserted_entrypoint: new_entrypoint,
459 mut node_map,
460 } = h.insert_hugr(parent, replacement);
461
462 for node in repl_io {
464 let node_h = node_map[&node];
465 h.remove_node(node_h);
466 node_map.remove(&node);
467 }
468
469 for child in h.children(new_entrypoint).collect_vec() {
471 h.set_parent(child, parent);
472 }
473
474 h.remove_node(new_entrypoint);
476 node_map.remove(&repl_entrypoint);
477
478 for (src, tgt) in boundary_edges {
480 let (src_node, src_port) = src.map_replacement(&node_map);
481 let (tgt_node, tgt_port) = tgt.map_replacement(&node_map);
482 h.connect(src_node, src_port, tgt_node, tgt_port);
483 }
484
485 let removed_nodes = subgraph
487 .nodes()
488 .iter()
489 .map(|&node| (node, h.remove_node(node)))
490 .collect();
491
492 Ok(Outcome {
493 node_map,
494 removed_nodes,
495 })
496 }
497}
498
499#[derive(Debug, Clone, Error, PartialEq, Eq)]
501#[non_exhaustive]
502pub enum SimpleReplacementError {
503 #[error("Parent node is invalid.")]
505 InvalidParentNode(),
506 #[error("A node requested for removal is invalid.")]
508 InvalidRemovedNode(),
509 #[error("A node in the replacement graph is invalid.")]
511 InvalidReplacementNode(),
512 #[error("Inlining replacement failed: {0}")]
514 InliningFailed(#[from] InlineDFGError),
515}
516
517fn as_incoming_ports<'a, N: HugrNode + 'a>(
518 node: N,
519 port: Port,
520 hugr: &'a impl HugrView<Node = N>,
521) -> impl Iterator<Item = (N, IncomingPort)> + 'a {
522 match port.as_directed() {
523 Either::Left(incoming) => Either::Left(std::iter::once((node, incoming))),
524 Either::Right(outgoing) => Either::Right(hugr.linked_inputs(node, outgoing)),
525 }
526 .into_iter()
527}
528
529#[cfg(test)]
530pub(in crate::hugr::patch) mod test {
531 use itertools::Itertools;
532 use rstest::{fixture, rstest};
533
534 use std::collections::{HashMap, HashSet};
535
536 use crate::builder::test::n_identity;
537 use crate::builder::{
538 endo_sig, inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr,
539 DataflowSubContainer, HugrBuilder, ModuleBuilder,
540 };
541 use crate::extension::prelude::{bool_t, qb_t};
542 use crate::hugr::patch::simple_replace::{Outcome, OutputBoundaryMap};
543 use crate::hugr::patch::{PatchVerification, ReplacementPort};
544 use crate::hugr::views::{HugrView, SiblingSubgraph};
545 use crate::hugr::{Hugr, HugrMut, Patch};
546 use crate::ops::dataflow::DataflowOpTrait;
547 use crate::ops::handle::NodeHandle;
548 use crate::ops::OpTag;
549 use crate::ops::OpTrait;
550 use crate::std_extensions::logic::test::and_op;
551 use crate::std_extensions::logic::LogicOp;
552 use crate::types::{Signature, Type};
553 use crate::utils::test_quantum_extension::{cx_gate, h_gate};
554 use crate::{Direction, IncomingPort, Node, OutgoingPort, Port};
555
556 use super::SimpleReplacement;
557
558 fn make_hugr() -> Result<Hugr, BuildError> {
568 let mut module_builder = ModuleBuilder::new();
569 let _f_id = {
570 let mut func_builder = module_builder
571 .define_function("main", Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]))?;
572
573 let [qb0, qb1, qb2] = func_builder.input_wires_arr();
574
575 let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?;
576
577 let mut inner_builder =
578 func_builder.dfg_builder_endo([(qb_t(), qb0), (qb_t(), qb1)])?;
579 let inner_graph = {
580 let [wire0, wire1] = inner_builder.input_wires_arr();
581 let wire2 = inner_builder.add_dataflow_op(h_gate(), vec![wire0])?;
582 let wire3 = inner_builder.add_dataflow_op(h_gate(), vec![wire1])?;
583 let wire45 = inner_builder
584 .add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
585 let [wire4, wire5] = wire45.outputs_arr();
586 let wire6 = inner_builder.add_dataflow_op(h_gate(), vec![wire4])?;
587 let wire7 = inner_builder.add_dataflow_op(h_gate(), vec![wire5])?;
588 inner_builder.finish_with_outputs(wire6.outputs().chain(wire7.outputs()))
589 }?;
590
591 func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))?
592 };
593 Ok(module_builder.finish_hugr()?)
594 }
595
596 #[fixture]
597 pub(in crate::hugr::patch) fn simple_hugr() -> Hugr {
598 make_hugr().unwrap()
599 }
600 fn make_dfg_hugr() -> Result<Hugr, BuildError> {
607 let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
608 let [wire0, wire1] = dfg_builder.input_wires_arr();
609 let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?;
610 let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
611 let wire45 =
612 dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
613 dfg_builder.finish_hugr_with_outputs(wire45.outputs())
614 }
615
616 #[fixture]
617 pub(in crate::hugr::patch) fn dfg_hugr() -> Hugr {
618 make_dfg_hugr().unwrap()
619 }
620
621 fn make_dfg_hugr2() -> Result<Hugr, BuildError> {
627 let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
628
629 let [wire0, wire1] = dfg_builder.input_wires_arr();
630 let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
631 let wire2out = wire2.outputs().exactly_one().unwrap();
632 let wireoutvec = vec![wire0, wire2out];
633 dfg_builder.finish_hugr_with_outputs(wireoutvec)
634 }
635
636 #[fixture]
637 pub(in crate::hugr::patch) fn dfg_hugr2() -> Hugr {
638 make_dfg_hugr2().unwrap()
639 }
640
641 #[fixture]
654 pub(in crate::hugr::patch) fn dfg_hugr_copy_bools() -> (Hugr, Vec<Node>) {
655 let mut dfg_builder =
656 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
657 let [b] = dfg_builder.input_wires_arr();
658
659 let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
660 let [b] = not_inp.outputs_arr();
661
662 let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
663 let [b0] = not_0.outputs_arr();
664 let not_1 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
665 let [b1] = not_1.outputs_arr();
666
667 (
668 dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
669 vec![not_inp.node(), not_0.node(), not_1.node()],
670 )
671 }
672
673 #[fixture]
686 pub(in crate::hugr::patch) fn dfg_hugr_half_not_bools() -> (Hugr, Vec<Node>) {
687 let mut dfg_builder =
688 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
689 let [b] = dfg_builder.input_wires_arr();
690
691 let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
692 let [b] = not_inp.outputs_arr();
693
694 let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
695 let [b0] = not_0.outputs_arr();
696 let b1 = b;
697
698 (
699 dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
700 vec![not_inp.node(), not_0.node()],
701 )
702 }
703
704 #[rstest]
705 fn test_simple_replacement(
724 simple_hugr: Hugr,
725 dfg_hugr: Hugr,
726 #[values(apply_simple, apply_replace)] applicator: impl Fn(&mut Hugr, SimpleReplacement),
727 ) {
728 let mut h: Hugr = simple_hugr;
729 let h_node_cx: Node = h
731 .entry_descendants()
732 .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
733 .unwrap();
734 let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
735 let s: Vec<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
736 let n: Hugr = dfg_hugr;
738 let n_node_cx = n
741 .entry_descendants()
742 .find(|node: &Node| *n.get_optype(*node) == cx_gate().into())
743 .unwrap();
744 let (n_node_h0, n_node_h1) = n.input_neighbours(n_node_cx).collect_tuple().unwrap();
745 let n_port_0 = n.node_inputs(n_node_h0).next().unwrap();
747 let n_port_1 = n.node_inputs(n_node_h1).next().unwrap();
748 let (n_cx_out_0, n_cx_out_1) = n.node_outputs(n_node_cx).take(2).collect_tuple().unwrap();
749 let n_port_2 = n.linked_inputs(n_node_cx, n_cx_out_0).next().unwrap().1;
750 let n_port_3 = n.linked_inputs(n_node_cx, n_cx_out_1).next().unwrap().1;
751 let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
753 let h_h0_out = h.node_outputs(h_node_h0).next().unwrap();
754 let h_h1_out = h.node_outputs(h_node_h1).next().unwrap();
755 let (h_outp_node, h_port_2) = h.linked_inputs(h_node_h0, h_h0_out).next().unwrap();
756 let h_port_3 = h.linked_inputs(h_node_h1, h_h1_out).next().unwrap().1;
757 let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
759 let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
760 nu_inp.insert((n_node_h0, n_port_0), (h_node_cx, h_port_0));
761 nu_inp.insert((n_node_h1, n_port_1), (h_node_cx, h_port_1));
762 nu_out.insert((h_outp_node, h_port_2), n_port_2);
763 nu_out.insert((h_outp_node, h_port_3), n_port_3);
764 let r = SimpleReplacement {
766 subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
767 replacement: n,
768 nu_inp,
769 nu_out: nu_out.into(),
770 };
771
772 assert_eq!(
774 r.map_host_output((h_outp_node, h_port_2)).unwrap(),
775 ReplacementPort::from((r.get_replacement_io().unwrap()[1], n_port_2))
776 );
777 assert!(r
778 .map_host_output((h_outp_node, OutgoingPort::from(0)))
779 .is_none());
780
781 assert_eq!(
783 HashSet::<_>::from_iter(r.invalidation_set()),
784 HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1, h_outp_node]),
785 );
786
787 applicator(&mut h, r);
788 assert_eq!(h.validate(), Ok(()));
795 }
796
797 #[rstest]
798 fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
816 let mut h: Hugr = simple_hugr;
817
818 let h_node_cx: Node = h
820 .entry_descendants()
821 .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
822 .unwrap();
823 let s: Vec<Node> = vec![h_node_cx];
824 let n: Hugr = dfg_hugr2;
826 let n_node_output = n.get_io(n.entrypoint()).unwrap()[1];
829 let (_n_node_input, n_node_h) = n.input_neighbours(n_node_output).collect_tuple().unwrap();
830 let (n_port_0, n_port_1) = n
832 .node_inputs(n_node_output)
833 .take(2)
834 .collect_tuple()
835 .unwrap();
836 let n_port_2 = n.node_inputs(n_node_h).next().unwrap();
837 let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
839 let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
840 let h_port_2 = h.node_inputs(h_node_h0).next().unwrap();
841 let h_port_3 = h.node_inputs(h_node_h1).next().unwrap();
842 let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
844 let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
845 nu_inp.insert((n_node_output, n_port_0), (h_node_cx, h_port_0));
846 nu_inp.insert((n_node_h, n_port_2), (h_node_cx, h_port_1));
847 nu_out.insert((h_node_h0, h_port_2), n_port_0);
848 nu_out.insert((h_node_h1, h_port_3), n_port_1);
849 let r = SimpleReplacement {
851 subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
852 replacement: n,
853 nu_inp,
854 nu_out: nu_out.into(),
855 };
856 let Outcome {
857 node_map,
858 removed_nodes,
859 } = h.apply_patch(r).unwrap();
860
861 assert_eq!(
862 node_map.into_keys().collect::<HashSet<_>>(),
863 [n_node_h].into_iter().collect::<HashSet<_>>(),
864 );
865 assert_eq!(
866 removed_nodes.into_keys().collect::<HashSet<_>>(),
867 [h_node_cx].into_iter().collect::<HashSet<_>>(),
868 );
869
870 assert_eq!(h.validate(), Ok(()));
877 }
878
879 #[test]
880 fn test_replace_cx_cross() {
881 let q_row: Vec<Type> = vec![qb_t(), qb_t()];
882 let mut builder = DFGBuilder::new(endo_sig(q_row)).unwrap();
883 let mut circ = builder.as_circuit(builder.input_wires());
884 circ.append(cx_gate(), [0, 1]).unwrap();
885 circ.append(cx_gate(), [1, 0]).unwrap();
886 let wires = circ.finish();
887 let [input, output] = builder.io();
888 let mut h = builder.finish_hugr_with_outputs(wires).unwrap();
889 let replacement = h.clone();
890 let orig = h.clone();
891
892 let removal = h
893 .entry_descendants()
894 .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
895 .collect_vec();
896 let inputs = h
897 .node_outputs(input)
898 .filter(|&p| {
899 h.get_optype(input)
900 .as_input()
901 .unwrap()
902 .signature()
903 .port_type(p)
904 .is_some()
905 })
906 .map(|p| {
907 let link = h.linked_inputs(input, p).next().unwrap();
908 (link, link)
909 })
910 .collect();
911 let outputs: HashMap<_, _> = h
912 .node_inputs(output)
913 .filter(|&p| {
914 h.get_optype(output)
915 .as_output()
916 .unwrap()
917 .signature()
918 .port_type(p)
919 .is_some()
920 })
921 .map(|p| ((output, p), p))
922 .collect();
923 h.apply_patch(SimpleReplacement::new(
924 SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
925 replacement,
926 inputs,
927 outputs,
928 ))
929 .unwrap();
930
931 assert_eq!(h.num_edges(), orig.num_edges());
933 }
934
935 #[test]
936 fn test_replace_after_copy() {
937 let one_bit = vec![bool_t()];
938 let two_bit = vec![bool_t(), bool_t()];
939
940 let mut builder = DFGBuilder::new(endo_sig(one_bit.clone())).unwrap();
941 let inw = builder.input_wires().exactly_one().unwrap();
942 let outw = builder
943 .add_dataflow_op(and_op(), [inw, inw])
944 .unwrap()
945 .outputs();
946 let [input, _] = builder.io();
947 let mut h = builder.finish_hugr_with_outputs(outw).unwrap();
948
949 let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap();
950 let inw = builder.input_wires();
951 let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs();
952 let [repl_input, repl_output] = builder.io();
953 let repl = builder.finish_hugr_with_outputs(outw).unwrap();
954
955 let orig = h.clone();
956
957 let removal = h
958 .entry_descendants()
959 .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
960 .collect_vec();
961
962 let first_out_p = h.node_outputs(input).next().unwrap();
963 let embedded_inputs = h.linked_inputs(input, first_out_p);
964 let repl_inputs = repl
965 .node_outputs(repl_input)
966 .map(|p| repl.linked_inputs(repl_input, p).next().unwrap());
967 let inputs = embedded_inputs.zip(repl_inputs).collect();
968
969 let outputs: HashMap<_, _> = repl
970 .node_inputs(repl_output)
971 .filter(|&p| repl.signature(repl_output).unwrap().port_type(p).is_some())
972 .map(|p| ((repl_output, p), p))
973 .collect();
974
975 h.apply_patch(SimpleReplacement::new(
976 SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
977 repl,
978 inputs,
979 outputs,
980 ))
981 .unwrap();
982
983 assert_eq!(h.num_nodes(), orig.num_nodes());
985 }
986
987 #[rstest]
992 fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec<Node>)) {
993 let (mut hugr, nodes) = dfg_hugr_copy_bools;
994 let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap();
995
996 let [_input, output] = hugr.get_io(hugr.entrypoint()).unwrap();
997
998 let replacement = {
999 let b =
1000 DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
1001 let [w] = b.input_wires_arr();
1002 b.finish_hugr_with_outputs([w, w]).unwrap()
1003 };
1004 let [_repl_input, repl_output] = replacement.get_io(replacement.entrypoint()).unwrap();
1005
1006 let subgraph =
1007 SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)
1008 .unwrap();
1009 let nu_inp = [
1013 (
1014 (repl_output, IncomingPort::from(0)),
1015 (input_not, IncomingPort::from(0)),
1016 ),
1017 (
1018 (repl_output, IncomingPort::from(1)),
1019 (input_not, IncomingPort::from(0)),
1020 ),
1021 ]
1022 .into_iter()
1023 .collect();
1024 let nu_out: HashMap<_, _> = [
1027 ((output, IncomingPort::from(0)), IncomingPort::from(0)),
1028 ((output, IncomingPort::from(1)), IncomingPort::from(1)),
1029 ]
1030 .into_iter()
1031 .collect();
1032
1033 let rewrite = SimpleReplacement {
1034 subgraph,
1035 replacement,
1036 nu_inp,
1037 nu_out: nu_out.into(),
1038 };
1039 rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
1040
1041 assert_eq!(hugr.validate(), Ok(()));
1042 assert_eq!(hugr.entry_descendants().count(), 3);
1043 }
1044
1045 #[rstest]
1050 fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec<Node>)) {
1051 let (mut hugr, nodes) = dfg_hugr_half_not_bools;
1052 let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap();
1053
1054 let [_input, output] = hugr.get_io(hugr.entrypoint()).unwrap();
1055
1056 let (replacement, repl_not) = {
1057 let mut b =
1058 DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
1059 let [w] = b.input_wires_arr();
1060 let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
1061 let [w_not] = not.outputs_arr();
1062 (b.finish_hugr_with_outputs([w, w_not]).unwrap(), not.node())
1063 };
1064 let [_repl_input, repl_output] = replacement.get_io(replacement.entrypoint()).unwrap();
1065
1066 let subgraph =
1067 SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap();
1068 let nu_inp = [
1072 (
1073 (repl_output, IncomingPort::from(0)),
1074 (input_not, IncomingPort::from(0)),
1075 ),
1076 (
1077 (repl_not, IncomingPort::from(0)),
1078 (input_not, IncomingPort::from(0)),
1079 ),
1080 ]
1081 .into_iter()
1082 .collect();
1083 let nu_out: HashMap<_, _> = [
1086 ((output, IncomingPort::from(0)), IncomingPort::from(0)),
1087 ((output, IncomingPort::from(1)), IncomingPort::from(1)),
1088 ]
1089 .into_iter()
1090 .collect();
1091
1092 let rewrite = SimpleReplacement {
1093 subgraph,
1094 replacement,
1095 nu_inp,
1096 nu_out: nu_out.into(),
1097 };
1098 rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
1099
1100 assert_eq!(hugr.validate(), Ok(()));
1101 assert_eq!(hugr.entry_descendants().count(), 4);
1102 }
1103
1104 #[rstest]
1105 fn test_nested_replace(dfg_hugr2: Hugr) {
1106 let mut h = dfg_hugr2;
1109 let h_node = h
1110 .entry_descendants()
1111 .find(|node: &Node| *h.get_optype(*node) == h_gate().into())
1112 .unwrap();
1113
1114 let mut nest_build = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
1116 let [input] = nest_build.input_wires_arr();
1117 let inner_build = nest_build.dfg_builder_endo([(qb_t(), input)]).unwrap();
1118 let inner_dfg = n_identity(inner_build).unwrap();
1119 let inner_dfg_node = inner_dfg.node();
1120 let replacement = nest_build
1121 .finish_hugr_with_outputs([inner_dfg.out_wire(0)])
1122 .unwrap();
1123 let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
1124 let nu_inp = vec![(
1125 (inner_dfg_node, IncomingPort::from(0)),
1126 (h_node, IncomingPort::from(0)),
1127 )]
1128 .into_iter()
1129 .collect();
1130
1131 let nu_out: HashMap<_, _> = vec![(
1132 (h.get_io(h.entrypoint()).unwrap()[1], IncomingPort::from(1)),
1133 IncomingPort::from(0),
1134 )]
1135 .into_iter()
1136 .collect();
1137
1138 let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out);
1139
1140 assert_eq!(h.entry_descendants().count(), 4);
1141
1142 rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
1143 h.validate().unwrap_or_else(|e| panic!("{e}"));
1144
1145 assert_eq!(h.entry_descendants().count(), 6);
1146 }
1147
1148 #[rstest]
1149 fn test_simple_replacement_with_empty_wires_using_outgoing_ports(
1150 simple_hugr: Hugr,
1151 dfg_hugr2: Hugr,
1152 ) {
1153 let mut h: Hugr = simple_hugr;
1154
1155 let h_node_cx: Node = h
1157 .entry_descendants()
1158 .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
1159 .unwrap();
1160 let s = vec![h_node_cx];
1161 let n: Hugr = dfg_hugr2;
1163 let [_n_node_input, n_node_output] = n.get_io(n.entrypoint()).unwrap();
1166 let n_node_h = n.input_neighbours(n_node_output).nth(1).unwrap();
1167 let (n_port_0, n_port_1) = n
1169 .node_inputs(n_node_output)
1170 .take(2)
1171 .collect_tuple()
1172 .unwrap();
1173 let n_port_2 = n.node_inputs(n_node_h).next().unwrap();
1174 let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
1176 let mut nu_inp = HashMap::new();
1178 let mut nu_out = HashMap::new();
1179 nu_inp.insert((n_node_output, n_port_0), (h_node_cx, h_port_0));
1180 nu_inp.insert((n_node_h, n_port_2), (h_node_cx, h_port_1));
1181 nu_out.insert((h_node_cx, OutgoingPort::from(0)), n_port_0);
1182 nu_out.insert((h_node_cx, OutgoingPort::from(1)), n_port_1);
1183 let r = SimpleReplacement {
1185 subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
1186 replacement: n,
1187 nu_inp,
1188 nu_out: nu_out.into(),
1189 };
1190 h.apply_patch(r).unwrap();
1191 assert_eq!(h.validate(), Ok(()));
1198 }
1199
1200 #[rstest]
1201 fn test_output_boundary_map(dfg_hugr2: Hugr) {
1202 let [inp, out] = dfg_hugr2.get_io(dfg_hugr2.entrypoint()).unwrap();
1203 let map = [
1204 ((inp, OutgoingPort::from(0)), IncomingPort::from(0)),
1205 ((inp, OutgoingPort::from(1)), IncomingPort::from(1)),
1206 ]
1207 .into_iter()
1208 .collect();
1209 let map = OutputBoundaryMap::ByOutgoing(map);
1210
1211 assert_eq!(
1213 map.get(inp, OutgoingPort::from(0)),
1214 Some(IncomingPort::from(0))
1215 );
1216 assert_eq!(
1217 map.get(inp, OutgoingPort::from(1)),
1218 Some(IncomingPort::from(1))
1219 );
1220
1221 assert!(map.get(out, IncomingPort::from(0)).is_none());
1223 assert_eq!(
1224 map.get_as_incoming(out, IncomingPort::from(0), &dfg_hugr2),
1225 Some(IncomingPort::from(0))
1226 );
1227
1228 assert_eq!(
1230 map.iter().collect::<HashSet<_>>(),
1231 HashSet::from_iter([
1232 (
1233 (inp, Port::new(Direction::Outgoing, 0)),
1234 IncomingPort::from(0)
1235 ),
1236 (
1237 (inp, Port::new(Direction::Outgoing, 1)),
1238 IncomingPort::from(1)
1239 ),
1240 ])
1241 );
1242 let h_gate = dfg_hugr2.output_neighbours(inp).nth(1).unwrap();
1243 assert_eq!(
1244 map.iter_as_incoming(&dfg_hugr2).collect::<HashSet<_>>(),
1245 HashSet::from_iter([
1246 ((out, IncomingPort::from(0)), IncomingPort::from(0)),
1247 ((h_gate, IncomingPort::from(0)), IncomingPort::from(1)),
1248 ])
1249 );
1250 }
1251
1252 use crate::hugr::patch::replace::Replacement;
1253 fn to_replace(h: &impl HugrView<Node = Node>, s: SimpleReplacement) -> Replacement {
1254 use crate::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec};
1255
1256 let mut replacement = s.replacement;
1257 let (in_, out) = replacement
1258 .children(replacement.entrypoint())
1259 .take(2)
1260 .collect_tuple()
1261 .unwrap();
1262 let mu_inp = s
1263 .nu_inp
1264 .iter()
1265 .map(|((tgt, tgt_port), (r_n, r_p))| {
1266 if *tgt == out {
1267 unimplemented!()
1268 };
1269 let (src, src_port) = h.single_linked_output(*r_n, *r_p).unwrap();
1270 NewEdgeSpec {
1271 src,
1272 tgt: *tgt,
1273 kind: NewEdgeKind::Value {
1274 src_pos: src_port,
1275 tgt_pos: *tgt_port,
1276 },
1277 }
1278 })
1279 .collect();
1280 let mu_out = s
1281 .nu_out
1282 .iter_as_incoming(&h)
1283 .map(|((tgt, tgt_port), out_port)| {
1284 let (src, src_port) = replacement.single_linked_output(out, out_port).unwrap();
1285 if src == in_ {
1286 unimplemented!()
1287 };
1288 NewEdgeSpec {
1289 src,
1290 tgt,
1291 kind: NewEdgeKind::Value {
1292 src_pos: src_port,
1293 tgt_pos: tgt_port,
1294 },
1295 }
1296 })
1297 .collect();
1298 replacement.remove_node(in_);
1299 replacement.remove_node(out);
1300 Replacement {
1301 removal: s.subgraph.nodes().to_vec(),
1302 replacement,
1303 adoptions: HashMap::new(),
1304 mu_inp,
1305 mu_out,
1306 mu_new: vec![],
1307 }
1308 }
1309
1310 fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) {
1311 h.apply_patch(rw).unwrap();
1312 }
1313
1314 fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
1315 h.apply_patch(to_replace(h, rw)).unwrap();
1316 }
1317}