Skip to main content

hugr_core/hugr/patch/
simple_replace.rs

1//! Implementation of the `SimpleReplace` operation.
2
3use std::collections::HashMap;
4
5use crate::core::HugrNode;
6use crate::hugr::hugrmut::InsertionResult;
7use crate::hugr::views::SiblingSubgraph;
8pub use crate::hugr::views::sibling_subgraph::InvalidReplacement;
9use crate::hugr::{HugrMut, HugrView};
10use crate::ops::{OpTag, OpTrait, OpType};
11use crate::{Hugr, IncomingPort, Node, OutgoingPort, PortIndex};
12
13use itertools::{Either, Itertools};
14
15use thiserror::Error;
16
17use super::inline_dfg::InlineDFGError;
18use super::{BoundaryPort, HostPort, PatchHugrMut, PatchVerification, ReplacementPort};
19
20pub mod serial;
21
22/// Specification of a simple replacement operation.
23///
24/// # Type parameters
25///
26/// - `N`: The type of nodes in the host hugr.
27#[derive(Debug, Clone)]
28pub struct SimpleReplacement<HostNode = Node> {
29    /// The subgraph of the host hugr to be replaced.
30    subgraph: SiblingSubgraph<HostNode>,
31    /// A hugr with DFG root (consisting of replacement nodes).
32    replacement: Hugr,
33}
34
35impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
36    /// Create a new [`SimpleReplacement`] specification without checking that
37    /// the replacement has the same signature as the subgraph.
38    #[inline]
39    pub fn new_unchecked(subgraph: SiblingSubgraph<HostNode>, replacement: Hugr) -> Self {
40        Self {
41            subgraph,
42            replacement,
43        }
44    }
45
46    /// Create a new [`SimpleReplacement`] specification.
47    ///
48    /// The given replacement should have an entrypoint that is a dataflow container (such as a DFG
49    /// or a function definition). If the entrypoint is a function, its (potentially polymorphic)
50    /// signature will be checked against the subgraph signature, otherwise the inner function type
51    /// of the entrypoint will be checked against the subgraph signature.
52    ///
53    /// Return a [`InvalidReplacement::InvalidSignature`] error if `subgraph`
54    /// and `replacement` have different signatures.
55    pub fn try_new(
56        subgraph: SiblingSubgraph<HostNode>,
57        host: &impl HugrView<Node = HostNode>,
58        replacement: Hugr,
59    ) -> Result<Self, InvalidReplacement> {
60        let subgraph_sig = subgraph.poly_func_type(host);
61        let repl_sig = replacement
62            .poly_func_type()
63            .or_else(|| {
64                Some(
65                    replacement
66                        .inner_function_type()
67                        .unwrap()
68                        .into_owned()
69                        .into(),
70                )
71            })
72            .ok_or(InvalidReplacement::InvalidDataflowGraph {
73                node: replacement.entrypoint(),
74                op: Box::new(replacement.entrypoint_optype().to_owned()),
75            })?;
76        if subgraph_sig != repl_sig {
77            return Err(InvalidReplacement::InvalidSignature {
78                expected: Box::new(subgraph_sig),
79                actual: Some(Box::new(repl_sig)),
80            });
81        }
82        Ok(Self {
83            subgraph,
84            replacement,
85        })
86    }
87
88    /// The replacement hugr.
89    #[inline]
90    pub fn replacement(&self) -> &Hugr {
91        &self.replacement
92    }
93
94    /// Consume self and return the replacement hugr.
95    #[inline]
96    pub fn into_replacement(self) -> Hugr {
97        self.replacement
98    }
99
100    /// Subgraph to be replaced.
101    #[inline]
102    pub fn subgraph(&self) -> &SiblingSubgraph<HostNode> {
103        &self.subgraph
104    }
105
106    /// Check if the replacement can be applied to the given hugr.
107    pub fn is_valid_rewrite(
108        &self,
109        h: &impl HugrView<Node = HostNode>,
110    ) -> Result<(), SimpleReplacementError> {
111        let parent = self.subgraph.get_parent(h);
112
113        // 1. Check the parent node exists and is a DataflowParent.
114        if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
115            return Err(SimpleReplacementError::InvalidParentNode());
116        }
117
118        // 2. Check that all the to-be-removed nodes are children of it and are leaves.
119        for node in self.subgraph.nodes() {
120            if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
121                return Err(SimpleReplacementError::InvalidRemovedNode());
122            }
123        }
124
125        Ok(())
126    }
127
128    /// Get the input and output nodes of the replacement hugr.
129    pub fn get_replacement_io(&self) -> [Node; 2] {
130        self.replacement
131            .get_io(self.replacement.entrypoint())
132            .expect("replacement is a DFG")
133    }
134
135    /// Traverse output boundary edge from `host` to `replacement`.
136    ///
137    /// Given an incoming port in `host` linked to an output boundary port of
138    /// `subgraph`, return the port it will be linked to after application
139    /// of `self`.
140    ///
141    /// The returned port will be in `replacement`, unless the wire in the
142    /// replacement is empty and `boundary` is [`BoundaryMode::SnapToHost`] (the
143    /// default), in which case it will be another `host` port. If
144    /// [`BoundaryMode::IncludeIO`] is passed, the returned port will always
145    /// be in `replacement` even if it is invalid (i.e. it is an IO node in
146    /// the replacement).
147    pub fn linked_replacement_output(
148        &self,
149        port: impl Into<HostPort<HostNode, IncomingPort>>,
150        host: &impl HugrView<Node = HostNode>,
151        boundary: BoundaryMode,
152    ) -> Option<BoundaryPort<HostNode, OutgoingPort>> {
153        let HostPort(node, port) = port.into();
154        let pos = self
155            .subgraph
156            .outgoing_ports()
157            .iter()
158            .position(move |&(n, p)| host.linked_inputs(n, p).contains(&(node, port)))?;
159
160        Some(self.linked_replacement_output_by_position(pos, host, boundary))
161    }
162
163    /// The outgoing port linked to the i-th output boundary edge of `subgraph`.
164    ///
165    /// This port will be in `replacement` if the i-th output wire is not
166    /// connected to the input, and in `host` otherwise.
167    fn linked_replacement_output_by_position(
168        &self,
169        pos: usize,
170        host: &impl HugrView<Node = HostNode>,
171        boundary: BoundaryMode,
172    ) -> BoundaryPort<HostNode, OutgoingPort> {
173        debug_assert!(
174            pos < self
175                .subgraph()
176                .poly_func_type(host)
177                .into_body()
178                .output_count()
179        );
180
181        // The outgoing ports at the output boundary of `replacement`
182        let [repl_inp, repl_out] = self.get_replacement_io();
183        let (out_node, out_port) = self
184            .replacement
185            .single_linked_output(repl_out, pos)
186            .expect("valid dfg wire");
187
188        if out_node != repl_inp || boundary == BoundaryMode::IncludeIO {
189            BoundaryPort::Replacement(out_node, out_port)
190        } else {
191            let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()]
192                .first()
193                .expect("non-empty boundary partition");
194            let (out_node, out_port) = host
195                .single_linked_output(in_node, in_port)
196                .expect("valid dfg wire");
197            BoundaryPort::Host(out_node, out_port)
198        }
199    }
200
201    /// Traverse output boundary edge from `replacement` to `host`.
202    ///
203    /// `port` must be an outgoing port linked to the output node of
204    /// `replacement`.
205    ///
206    /// This is the inverse of [`Self::linked_replacement_output`], in the case
207    /// where the latter returns a [`BoundaryPort::Replacement`] port.
208    pub fn linked_host_outputs(
209        &self,
210        port: impl Into<ReplacementPort<OutgoingPort>>,
211        host: &impl HugrView<Node = HostNode>,
212    ) -> impl Iterator<Item = HostPort<HostNode, IncomingPort>> {
213        let ReplacementPort(node, port) = port.into();
214        let [_, repl_out] = self.get_replacement_io();
215        let positions = self
216            .replacement
217            .linked_inputs(node, port)
218            .filter_map(move |(n, p)| (n == repl_out).then_some(p.index()));
219
220        positions
221            .map(|pos| self.subgraph.outgoing_ports()[pos])
222            .flat_map(|(out_node, out_port)| {
223                let in_nodes_ports = host.linked_inputs(out_node, out_port);
224                in_nodes_ports.map(|(n, p)| HostPort(n, p))
225            })
226    }
227
228    /// Traverse input boundary edge from `host` to `replacement`.
229    ///
230    /// Given an outgoing port in `host` linked to an input boundary port of
231    /// `subgraph`, return the ports it will be linked to after application
232    /// of `self`.
233    ///
234    /// The returned ports will be in `replacement`, unless the wires in the
235    /// replacement are empty and `boundary` is [`BoundaryMode::SnapToHost`]
236    /// (the default), in which case they will be other `host` ports. If
237    /// [`BoundaryMode::IncludeIO`] is passed, the returned ports will
238    /// always be in `replacement` even if they are invalid (i.e. they are
239    /// an IO node in the replacement).
240    pub fn linked_replacement_inputs<'a>(
241        &'a self,
242        port: impl Into<HostPort<HostNode, OutgoingPort>>,
243        host: &'a impl HugrView<Node = HostNode>,
244        boundary: BoundaryMode,
245    ) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> + 'a {
246        let HostPort(node, port) = port.into();
247        let positions = self
248            .subgraph
249            .incoming_ports()
250            .iter()
251            .positions(move |ports| {
252                let (n, p) = *ports.first().expect("non-empty boundary partition");
253                host.single_linked_output(n, p).expect("valid dfg wire") == (node, port)
254            });
255
256        positions
257            .flat_map(move |pos| self.linked_replacement_inputs_by_position(pos, host, boundary))
258    }
259
260    /// The incoming ports linked to the i-th input boundary edge of `subgraph`.
261    fn linked_replacement_inputs_by_position(
262        &self,
263        pos: usize,
264        host: &impl HugrView<Node = HostNode>,
265        boundary: BoundaryMode,
266    ) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> {
267        debug_assert!(
268            pos < self
269                .subgraph()
270                .poly_func_type(host)
271                .into_body()
272                .input_count()
273        );
274
275        let [repl_inp, repl_out] = self.get_replacement_io();
276        self.replacement
277            .linked_inputs(repl_inp, pos)
278            .flat_map(move |(in_node, in_port)| {
279                if in_node != repl_out || boundary == BoundaryMode::IncludeIO {
280                    Either::Left(std::iter::once(BoundaryPort::Replacement(in_node, in_port)))
281                } else {
282                    let (out_node, out_port) = self.subgraph.outgoing_ports()[in_port.index()];
283                    let in_nodes_ports = host.linked_inputs(out_node, out_port);
284                    Either::Right(in_nodes_ports.map(|(n, p)| BoundaryPort::Host(n, p)))
285                }
286            })
287    }
288
289    /// Traverse output boundary edge from `replacement` to `host`.
290    ///
291    /// `port` must be an outgoing port linked to the output node of
292    /// `replacement`.
293    ///
294    /// This is the inverse of [`Self::linked_replacement_output`], in the case
295    /// where the latter returns a [`BoundaryPort::Replacement`] port.
296    pub fn linked_host_input(
297        &self,
298        port: impl Into<ReplacementPort<IncomingPort>>,
299        host: &impl HugrView<Node = HostNode>,
300    ) -> HostPort<HostNode, OutgoingPort> {
301        let ReplacementPort(node, port) = port.into();
302        let (out_node, out_port) = self
303            .replacement
304            .single_linked_output(node, port)
305            .expect("valid dfg wire");
306
307        let [repl_in, _] = self.get_replacement_io();
308        assert!(out_node == repl_in, "not a boundary port");
309
310        let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()]
311            .first()
312            .expect("non-empty input partition");
313
314        let (host_node, host_port) = host
315            .single_linked_output(in_node, in_port)
316            .expect("valid dfg wire");
317        HostPort(host_node, host_port)
318    }
319
320    /// Get all edges that the replacement would add from outgoing ports in
321    /// `host` to incoming ports in `self.replacement`.
322    ///
323    /// For each pair in the returned vector, the first element is a port in
324    /// `host` and the second is a port in `self.replacement`:
325    ///  - The outgoing host ports are always linked to the input boundary of
326    ///    `subgraph`, i.e. the ports returned by
327    ///    [`SiblingSubgraph::incoming_ports`],
328    ///  - The incoming replacement ports are always linked to output ports of
329    ///    the [`OpTag::Input`] node of `self.replacement`.
330    pub fn incoming_boundary<'a>(
331        &'a self,
332        host: &'a impl HugrView<Node = HostNode>,
333    ) -> impl Iterator<
334        Item = (
335            HostPort<HostNode, OutgoingPort>,
336            ReplacementPort<IncomingPort>,
337        ),
338    > + 'a {
339        // The outgoing ports at the input boundary of `subgraph`
340        let subgraph_outgoing_ports = self
341            .subgraph
342            .incoming_ports()
343            .iter()
344            .map(|in_ports| *in_ports.first().expect("non-empty input partition"))
345            .map(|(node, in_port)| {
346                host.single_linked_output(node, in_port)
347                    .expect("valid dfg wire")
348            });
349
350        subgraph_outgoing_ports
351            .enumerate()
352            .flat_map(|(pos, subg_np)| {
353                self.linked_replacement_inputs_by_position(pos, host, BoundaryMode::SnapToHost)
354                    .filter_map(move |np| Some((np.as_replacement()?, subg_np)))
355            })
356            .map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| {
357                (
358                    HostPort(subgraph_node, subgraph_port),
359                    ReplacementPort(repl_node, repl_port),
360                )
361            })
362    }
363
364    /// Get all edges that the replacement would add from outgoing ports in
365    /// `self.replacement` to incoming ports in `host`.
366    ///
367    /// For each pair in the returned vector, the first element is a port in
368    /// `self.replacement` and the second is a port in `host`:
369    ///  - The outgoing replacement ports are always linked to inputs of the
370    ///    [`OpTag::Output`] node of `self.replacement`,
371    ///  - The incoming host ports are always linked to the output boundary of
372    ///    `subgraph`, i.e. the ports returned by
373    ///    [`SiblingSubgraph::outgoing_ports`],
374    ///
375    /// This panics if self.replacement is not a DFG.
376    pub fn outgoing_boundary<'a>(
377        &'a self,
378        host: &'a impl HugrView<Node = HostNode>,
379    ) -> impl Iterator<
380        Item = (
381            ReplacementPort<OutgoingPort>,
382            HostPort<HostNode, IncomingPort>,
383        ),
384    > + 'a {
385        // The incoming ports at the output boundary of `subgraph`
386        let subgraph_incoming_ports = self.subgraph.outgoing_ports().iter().map(
387            move |&(subgraph_out_node, subgraph_out_port)| {
388                host.linked_inputs(subgraph_out_node, subgraph_out_port)
389            },
390        );
391
392        subgraph_incoming_ports
393            .enumerate()
394            .filter_map(|(pos, subg_all)| {
395                let np = self
396                    .linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost)
397                    .as_replacement()?;
398                Some((np, subg_all))
399            })
400            .flat_map(|(repl_np, subg_all)| subg_all.map(move |subg_np| (repl_np, subg_np)))
401            .map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| {
402                (
403                    ReplacementPort(repl_node, repl_port),
404                    HostPort(subgraph_node, subgraph_port),
405                )
406            })
407    }
408
409    /// Get all edges that the replacement would add between ports in `host`.
410    ///
411    /// These correspond to direct edges between the input and output nodes
412    /// in the replacement graph.
413    ///
414    /// For each pair in the returned vector, both ports are in `host`:
415    ///  - The outgoing host ports are linked to the input boundary of
416    ///    `subgraph`, i.e. the ports returned by
417    ///    [`SiblingSubgraph::incoming_ports`],
418    ///  - The incoming host ports are linked to the output boundary of
419    ///    `subgraph`, i.e. the ports returned by
420    ///    [`SiblingSubgraph::outgoing_ports`].
421    ///
422    /// This panics if self.replacement is not a DFG.
423    pub fn host_to_host_boundary<'a>(
424        &'a self,
425        host: &'a impl HugrView<Node = HostNode>,
426    ) -> impl Iterator<
427        Item = (
428            HostPort<HostNode, OutgoingPort>,
429            HostPort<HostNode, IncomingPort>,
430        ),
431    > + 'a {
432        // The incoming ports at the output boundary of `subgraph`
433        let subgraph_incoming_ports = self.subgraph.outgoing_ports().iter().map(
434            move |&(subgraph_out_node, subgraph_out_port)| {
435                host.linked_inputs(subgraph_out_node, subgraph_out_port)
436            },
437        );
438
439        subgraph_incoming_ports
440            .enumerate()
441            .filter_map(|(pos, subg_all)| {
442                Some((
443                    self.linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost)
444                        .as_host()?,
445                    subg_all,
446                ))
447            })
448            .flat_map(|(host_np, subg_all)| subg_all.map(move |subg_np| (host_np, subg_np)))
449            .map(
450                |((host_out_node, host_out_port), (host_in_node, host_in_port))| {
451                    (
452                        HostPort(host_out_node, host_out_port),
453                        HostPort(host_in_node, host_in_port),
454                    )
455                },
456            )
457    }
458
459    /// Get the incoming port at the output node of `self.replacement`
460    /// that corresponds to the given outgoing port on the subgraph output
461    /// boundary.
462    ///
463    /// The host `port` should be a port in `self.subgraph().outgoing_ports()`.
464    ///
465    /// This panics if self.replacement is not a DFG.
466    pub fn map_host_output(
467        &self,
468        port: impl Into<HostPort<HostNode, OutgoingPort>>,
469    ) -> Option<ReplacementPort<IncomingPort>> {
470        let HostPort(node, port) = port.into();
471        let pos = self
472            .subgraph
473            .outgoing_ports()
474            .iter()
475            .position(|&node_port| node_port == (node, port))?;
476        let incoming_port: IncomingPort = pos.into();
477        let [_, rep_output] = self.get_replacement_io();
478        Some(ReplacementPort(rep_output, incoming_port))
479    }
480
481    /// Get the incoming port in `subgraph` that corresponds to the given
482    /// replacement input port.
483    ///
484    /// This panics if self.replacement is not a DFG.
485    pub fn map_replacement_input(
486        &self,
487        port: impl Into<ReplacementPort<OutgoingPort>>,
488    ) -> impl Iterator<Item = HostPort<HostNode, IncomingPort>> + '_ {
489        let ReplacementPort(node, port) = port.into();
490        let [repl_input, _] = self.get_replacement_io();
491
492        let ports = if node == repl_input {
493            self.subgraph.incoming_ports().get(port.index())
494        } else {
495            None
496        };
497        ports
498            .into_iter()
499            .flat_map(|ports| ports.iter().map(|&(n, p)| HostPort(n, p)))
500    }
501
502    /// Get all edges that the replacement would add between `host` and
503    /// `self.replacement`.
504    ///
505    /// This is equivalent to chaining the results of
506    /// [`Self::incoming_boundary`], [`Self::outgoing_boundary`], and
507    /// [`Self::host_to_host_boundary`].
508    ///
509    /// This panics if self.replacement is not a DFG.
510    pub fn all_boundary_edges<'a>(
511        &'a self,
512        host: &'a impl HugrView<Node = HostNode>,
513    ) -> impl Iterator<
514        Item = (
515            BoundaryPort<HostNode, OutgoingPort>,
516            BoundaryPort<HostNode, IncomingPort>,
517        ),
518    > + 'a {
519        let incoming_boundary = self
520            .incoming_boundary(host)
521            .map(|(src, tgt)| (src.into(), tgt.into()));
522        let outgoing_boundary = self
523            .outgoing_boundary(host)
524            .map(|(src, tgt)| (src.into(), tgt.into()));
525        let host_to_host_boundary = self
526            .host_to_host_boundary(host)
527            .map(|(src, tgt)| (src.into(), tgt.into()));
528
529        incoming_boundary
530            .chain(outgoing_boundary)
531            .chain(host_to_host_boundary)
532    }
533
534    /// Map the host nodes in `self` according to `node_map`.
535    ///
536    /// `node_map` must map nodes in the current HUGR of the subgraph to
537    /// its equivalent nodes in some `new_host`.
538    ///
539    /// This converts a replacement that acts on nodes of type `HostNode` to
540    /// a replacement that acts on `new_host`, with nodes of type `N`.
541    pub fn map_host_nodes<N: HugrNode>(
542        &self,
543        node_map: impl Fn(HostNode) -> N,
544        new_host: &impl HugrView<Node = N>,
545    ) -> Result<SimpleReplacement<N>, InvalidReplacement> {
546        let Self {
547            subgraph,
548            replacement,
549        } = self;
550        let subgraph = subgraph.map_nodes(node_map);
551        SimpleReplacement::try_new(subgraph, new_host, replacement.clone())
552    }
553
554    /// Allows to get the [Self::invalidated_nodes] without requiring a
555    /// [HugrView].
556    pub fn invalidation_set(&self) -> impl Iterator<Item = HostNode> {
557        self.subgraph.nodes().iter().copied()
558    }
559}
560
561impl<HostNode: HugrNode> PatchVerification for SimpleReplacement<HostNode> {
562    type Error = SimpleReplacementError;
563    type Node = HostNode;
564
565    fn verify(&self, h: &impl HugrView<Node = HostNode>) -> Result<(), SimpleReplacementError> {
566        self.is_valid_rewrite(h)
567    }
568
569    #[inline]
570    fn invalidated_nodes(
571        &self,
572        _: &impl HugrView<Node = Self::Node>,
573    ) -> impl Iterator<Item = Self::Node> {
574        self.invalidation_set()
575    }
576}
577
578/// In [`SimpleReplacement::replacement`], IO nodes marking the boundary will
579/// not be valid nodes in the host after the replacement is applied.
580///
581/// This enum allows specifying whether these invalid nodes on the boundary
582/// should be returned or should be resolved to valid nodes in the host.
583#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
584pub enum BoundaryMode {
585    /// Only consider nodes that are valid after the replacement is applied.
586    ///
587    /// This means that nodes in hosts may be returned in places where nodes in
588    /// the replacement would be typically expected.
589    #[default]
590    SnapToHost,
591    /// Include all nodes, including potentially invalid ones (inputs and
592    /// outputs of replacements).
593    IncludeIO,
594}
595
596/// Result of applying a [`SimpleReplacement`].
597pub struct Outcome<HostNode = Node> {
598    /// Map from Node in replacement to corresponding Node in the result Hugr
599    pub node_map: HashMap<Node, HostNode>,
600    /// Nodes removed from the result Hugr and their weights
601    pub removed_nodes: HashMap<HostNode, OpType>,
602}
603
604impl<N: HugrNode> PatchHugrMut for SimpleReplacement<N> {
605    type Outcome = Outcome<N>;
606    const UNCHANGED_ON_FAILURE: bool = true;
607
608    fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<Self::Outcome, Self::Error> {
609        self.is_valid_rewrite(h)?;
610
611        let parent = self.subgraph.get_parent(h);
612
613        // We proceed to connect the edges between the newly inserted
614        // replacement and the rest of the graph.
615        //
616        // Existing connections to the removed subgraph will be automatically
617        // removed when the nodes are removed.
618
619        // 1. Get the boundary edges
620        let boundary_edges = self.all_boundary_edges(h).collect_vec();
621
622        let Self {
623            replacement,
624            subgraph,
625            ..
626        } = self;
627
628        // Nodes to remove from the replacement hugr
629        let repl_io = replacement
630            .get_io(replacement.entrypoint())
631            .expect("replacement is DFG-rooted");
632        let repl_entrypoint = replacement.entrypoint();
633
634        // 2. Insert the replacement as a whole.
635        let InsertionResult {
636            inserted_entrypoint: new_entrypoint,
637            mut node_map,
638        } = h.insert_hugr(parent, replacement);
639
640        // remove the Input and Output from h and node_map
641        for node in repl_io {
642            let node_h = node_map[&node];
643            h.remove_node(node_h);
644            node_map.remove(&node);
645        }
646
647        // make all (remaining) replacement top level children children of the parent
648        for child in h.children(new_entrypoint).collect_vec() {
649            h.set_parent(child, parent);
650        }
651
652        // remove the replacement entrypoint from h and node_map
653        h.remove_node(new_entrypoint);
654        node_map.remove(&repl_entrypoint);
655
656        // 3. Insert all boundary edges.
657        for (src, tgt) in boundary_edges {
658            let (src_node, src_port) = src.map_replacement(&node_map);
659            let (tgt_node, tgt_port) = tgt.map_replacement(&node_map);
660            h.connect(src_node, src_port, tgt_node, tgt_port);
661        }
662
663        // 4. Remove all nodes in subgraph and edges between them.
664        let removed_nodes = subgraph
665            .nodes()
666            .iter()
667            .map(|&node| (node, h.remove_node(node)))
668            .collect();
669
670        Ok(Outcome {
671            node_map,
672            removed_nodes,
673        })
674    }
675}
676
677/// Error from a [`SimpleReplacement`] operation.
678#[derive(Debug, Clone, Error, PartialEq, Eq)]
679#[non_exhaustive]
680pub enum SimpleReplacementError {
681    /// Invalid parent node.
682    #[error("Parent node is invalid.")]
683    InvalidParentNode(),
684    /// Node requested for removal is invalid.
685    #[error("A node requested for removal is invalid.")]
686    InvalidRemovedNode(),
687    /// Node in replacement graph is invalid.
688    #[error("A node in the replacement graph is invalid.")]
689    InvalidReplacementNode(),
690    /// Inlining replacement failed.
691    #[error("Inlining replacement failed: {0}")]
692    InliningFailed(#[from] InlineDFGError),
693}
694
695#[cfg(test)]
696pub(in crate::hugr::patch) mod test {
697    use itertools::Itertools;
698    use rstest::{fixture, rstest};
699
700    use std::collections::{BTreeSet, HashMap, HashSet};
701
702    use crate::builder::test::n_identity;
703    use crate::builder::{
704        BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
705        ModuleBuilder, endo_sig, inout_sig,
706    };
707    use crate::extension::prelude::{bool_t, qb_t};
708    use crate::hugr::patch::simple_replace::{BoundaryMode, Outcome};
709    use crate::hugr::patch::{BoundaryPort, HostPort, PatchVerification, ReplacementPort};
710    use crate::hugr::views::{HugrView, SiblingSubgraph};
711    use crate::hugr::{Hugr, HugrMut, Patch};
712    use crate::ops::OpTag;
713    use crate::ops::OpTrait;
714    use crate::ops::handle::NodeHandle;
715    use crate::std_extensions::logic::LogicOp;
716    use crate::std_extensions::logic::test::and_op;
717    use crate::types::{Signature, Type};
718    use crate::utils::test_quantum_extension::{cx_gate, h_gate};
719    use crate::{IncomingPort, Node, OutgoingPort};
720
721    use super::SimpleReplacement;
722
723    /// Creates a hugr like the following:
724    /// --   H   --
725    /// -- [DFG] --
726    /// where [DFG] is:
727    /// ┌───┐     ┌───┐
728    /// ┤ H ├──■──┤ H ├
729    /// ├───┤┌─┴─┐├───┤
730    /// ┤ H ├┤ X ├┤ H ├
731    /// └───┘└───┘└───┘
732    fn make_hugr() -> Result<Hugr, BuildError> {
733        let mut module_builder = ModuleBuilder::new();
734        let _f_id = {
735            let mut func_builder = module_builder
736                .define_function("main", Signature::new_endo([qb_t(), qb_t(), qb_t()]))?;
737
738            let [qb0, qb1, qb2] = func_builder.input_wires_arr();
739
740            let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?;
741
742            let mut inner_builder =
743                func_builder.dfg_builder_endo([(qb_t(), qb0), (qb_t(), qb1)])?;
744            let inner_graph = {
745                let [wire0, wire1] = inner_builder.input_wires_arr();
746                let wire2 = inner_builder.add_dataflow_op(h_gate(), vec![wire0])?;
747                let wire3 = inner_builder.add_dataflow_op(h_gate(), vec![wire1])?;
748                let wire45 = inner_builder
749                    .add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
750                let [wire4, wire5] = wire45.outputs_arr();
751                let wire6 = inner_builder.add_dataflow_op(h_gate(), vec![wire4])?;
752                let wire7 = inner_builder.add_dataflow_op(h_gate(), vec![wire5])?;
753                inner_builder.finish_with_outputs(wire6.outputs().chain(wire7.outputs()))
754            }?;
755
756            func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))?
757        };
758        Ok(module_builder.finish_hugr()?)
759    }
760
761    #[fixture]
762    pub(in crate::hugr::patch) fn simple_hugr() -> Hugr {
763        make_hugr().unwrap()
764    }
765    /// Creates a hugr with a DFG root like the following:
766    /// ┌───┐
767    /// ┤ H ├──■──
768    /// ├───┤┌─┴─┐
769    /// ┤ H ├┤ X ├
770    /// └───┘└───┘
771    fn make_dfg_hugr() -> Result<Hugr, BuildError> {
772        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
773        let [wire0, wire1] = dfg_builder.input_wires_arr();
774        let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?;
775        let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
776        let wire45 =
777            dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
778        dfg_builder.finish_hugr_with_outputs(wire45.outputs())
779    }
780
781    #[fixture]
782    pub(in crate::hugr::patch) fn dfg_hugr() -> Hugr {
783        make_dfg_hugr().unwrap()
784    }
785
786    /// Creates a hugr with a DFG root like the following:
787    /// ─────
788    /// ┌───┐
789    /// ┤ H ├
790    /// └───┘
791    fn make_dfg_hugr2() -> Result<Hugr, BuildError> {
792        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
793
794        let [wire0, wire1] = dfg_builder.input_wires_arr();
795        let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
796        let wire2out = wire2.outputs().exactly_one().unwrap();
797        let wireoutvec = vec![wire0, wire2out];
798        dfg_builder.finish_hugr_with_outputs(wireoutvec)
799    }
800
801    #[fixture]
802    pub(in crate::hugr::patch) fn dfg_hugr2() -> Hugr {
803        make_dfg_hugr2().unwrap()
804    }
805
806    /// A hugr with a DFG root mapping bool_t() to (bool_t(), bool_t())
807    ///                     ┌─────────┐
808    ///                ┌────┤ (1) NOT ├──
809    ///  ┌─────────┐   │    └─────────┘
810    /// ─┤ (0) NOT ├───┤
811    ///  └─────────┘   │    ┌─────────┐
812    ///                └────┤ (2) NOT ├──
813    ///                     └─────────┘
814    /// This can be replaced with an empty hugr coping the input to both
815    /// outputs.
816    ///
817    /// Returns the hugr and the nodes of the NOT gates, in order.
818    #[fixture]
819    pub(in crate::hugr::patch) fn dfg_hugr_copy_bools() -> (Hugr, Vec<Node>) {
820        let mut dfg_builder =
821            DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
822        let [b] = dfg_builder.input_wires_arr();
823
824        let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
825        let [b] = not_inp.outputs_arr();
826
827        let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
828        let [b0] = not_0.outputs_arr();
829        let not_1 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
830        let [b1] = not_1.outputs_arr();
831
832        (
833            dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
834            vec![not_inp.node(), not_0.node(), not_1.node()],
835        )
836    }
837
838    /// A hugr with a DFG root mapping bool_t() to (bool_t(), bool_t())
839    ///                     ┌─────────┐
840    ///                ┌────┤ (1) NOT ├──
841    ///  ┌─────────┐   │    └─────────┘
842    /// ─┤ (0) NOT ├───┤
843    ///  └─────────┘   │
844    ///                └─────────────────
845    ///
846    /// This can be replaced with a single NOT op, coping the input to the first
847    /// output.
848    ///
849    /// Returns the hugr and the nodes of the NOT ops, in order.
850    #[fixture]
851    pub(in crate::hugr::patch) fn dfg_hugr_half_not_bools() -> (Hugr, Vec<Node>) {
852        let mut dfg_builder =
853            DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
854        let [b] = dfg_builder.input_wires_arr();
855
856        let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
857        let [b] = not_inp.outputs_arr();
858
859        let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
860        let [b0] = not_0.outputs_arr();
861        let b1 = b;
862
863        (
864            dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
865            vec![not_inp.node(), not_0.node()],
866        )
867    }
868
869    #[rstest]
870    /// Replace the
871    ///      ┌───┐
872    /// ──■──┤ H ├
873    /// ┌─┴─┐├───┤
874    /// ┤ X ├┤ H ├
875    /// └───┘└───┘
876    /// part of
877    /// ┌───┐     ┌───┐
878    /// ┤ H ├──■──┤ H ├
879    /// ├───┤┌─┴─┐├───┤
880    /// ┤ H ├┤ X ├┤ H ├
881    /// └───┘└───┘└───┘
882    /// with
883    /// ┌───┐
884    /// ┤ H ├──■──
885    /// ├───┤┌─┴─┐
886    /// ┤ H ├┤ X ├
887    /// └───┘└───┘
888    fn test_simple_replacement(
889        simple_hugr: Hugr,
890        dfg_hugr: Hugr,
891        #[values(apply_simple, apply_replace)] applicator: impl Fn(&mut Hugr, SimpleReplacement),
892    ) {
893        let mut h: Hugr = simple_hugr;
894        // 1. Locate the CX and its successor H's in h
895        let h_node_cx: Node = h
896            .entry_descendants()
897            .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
898            .unwrap();
899        let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
900        let s: Vec<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
901        // 2. Construct a new DFG-rooted hugr for the replacement
902        let n: Hugr = dfg_hugr;
903        // 3. Construct the input and output matchings
904        // 3.1. Locate the CX and its predecessor H's in n
905        let n_node_cx = n
906            .entry_descendants()
907            .find(|node: &Node| *n.get_optype(*node) == cx_gate().into())
908            .unwrap();
909        // 3.2. Locate the ports we need to specify as "glue" in n
910        let (n_cx_out_0, _n_cx_out_1) = n.node_outputs(n_node_cx).take(2).collect_tuple().unwrap();
911        let n_port_2 = n.linked_inputs(n_node_cx, n_cx_out_0).next().unwrap().1;
912        // 3.3. Locate the ports we need to specify as "glue" in h
913        let h_h0_out = h.node_outputs(h_node_h0).next().unwrap();
914        // 4. Define the replacement
915        let r = SimpleReplacement {
916            subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
917            replacement: n,
918        };
919
920        // Check output boundary
921        assert_eq!(
922            r.map_host_output((h_node_h0, h_h0_out)).unwrap(),
923            ReplacementPort::from((r.get_replacement_io()[1], n_port_2))
924        );
925
926        // Check invalidation set
927        assert_eq!(
928            HashSet::<_>::from_iter(r.invalidated_nodes(&h)),
929            HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1]),
930        );
931
932        applicator(&mut h, r);
933        // Expect [DFG] to be replaced with:
934        // ┌───┐┌───┐
935        // ┤ H ├┤ H ├──■──
936        // ├───┤├───┤┌─┴─┐
937        // ┤ H ├┤ H ├┤ X ├
938        // └───┘└───┘└───┘
939        assert_eq!(h.validate(), Ok(()));
940    }
941
942    #[rstest]
943    /// Replace the
944    ///
945    /// ──■──
946    /// ┌─┴─┐
947    /// ┤ X ├
948    /// └───┘
949    /// part of
950    /// ┌───┐     ┌───┐
951    /// ┤ H ├──■──┤ H ├
952    /// ├───┤┌─┴─┐├───┤
953    /// ┤ H ├┤ X ├┤ H ├
954    /// └───┘└───┘└───┘
955    /// with
956    /// ─────
957    /// ┌───┐
958    /// ┤ H ├
959    /// └───┘
960    fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
961        let mut h: Hugr = simple_hugr;
962
963        // 1. Locate the CX in h
964        let h_node_cx: Node = h
965            .entry_descendants()
966            .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
967            .unwrap();
968        let s: Vec<Node> = vec![h_node_cx];
969        // 2. Construct a new DFG-rooted hugr for the replacement
970        let n: Hugr = dfg_hugr2;
971        // 3. Construct the input and output matchings
972        // 3.1. Locate the Output and its predecessor H in n
973        let n_node_output = n.get_io(n.entrypoint()).unwrap()[1];
974        let (_n_node_input, n_node_h) = n.input_neighbours(n_node_output).collect_tuple().unwrap();
975        // 4. Define the replacement
976        let r = SimpleReplacement {
977            subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
978            replacement: n,
979        };
980        let Outcome {
981            node_map,
982            removed_nodes,
983        } = h.apply_patch(r).unwrap();
984
985        assert_eq!(
986            node_map.into_keys().collect::<HashSet<_>>(),
987            [n_node_h].into_iter().collect::<HashSet<_>>(),
988        );
989        assert_eq!(
990            removed_nodes.into_keys().collect::<HashSet<_>>(),
991            [h_node_cx].into_iter().collect::<HashSet<_>>(),
992        );
993
994        // Expect [DFG] to be replaced with:
995        // ┌───┐┌───┐
996        // ┤ H ├┤ H ├
997        // ├───┤├───┤┌───┐
998        // ┤ H ├┤ H ├┤ H ├
999        // └───┘└───┘└───┘
1000        assert_eq!(h.validate(), Ok(()));
1001    }
1002
1003    #[test]
1004    fn test_replace_cx_cross() {
1005        let q_row: Vec<Type> = vec![qb_t(), qb_t()];
1006        let mut builder = DFGBuilder::new(endo_sig(q_row)).unwrap();
1007        let mut circ = builder.as_circuit(builder.input_wires());
1008        circ.append(cx_gate(), [0, 1]).unwrap();
1009        circ.append(cx_gate(), [1, 0]).unwrap();
1010        let wires = circ.finish();
1011        let mut h = builder.finish_hugr_with_outputs(wires).unwrap();
1012        let replacement = h.clone();
1013        let orig = h.clone();
1014
1015        let removal = h
1016            .entry_descendants()
1017            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
1018            .collect_vec();
1019        h.apply_patch(
1020            SimpleReplacement::try_new(
1021                SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
1022                &h,
1023                replacement,
1024            )
1025            .unwrap(),
1026        )
1027        .unwrap();
1028
1029        // They should be the same, up to node indices
1030        assert_eq!(h.num_edges(), orig.num_edges());
1031    }
1032
1033    #[test]
1034    fn test_replace_after_copy() {
1035        let one_bit = vec![bool_t()];
1036        let two_bit = vec![bool_t(), bool_t()];
1037
1038        let mut builder = DFGBuilder::new(endo_sig(one_bit.clone())).unwrap();
1039        let inw = builder.input_wires().exactly_one().unwrap();
1040        let outw = builder
1041            .add_dataflow_op(and_op(), [inw, inw])
1042            .unwrap()
1043            .outputs();
1044        let mut h = builder.finish_hugr_with_outputs(outw).unwrap();
1045
1046        let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap();
1047        let inw = builder.input_wires();
1048        let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs();
1049        let repl = builder.finish_hugr_with_outputs(outw).unwrap();
1050
1051        let orig = h.clone();
1052
1053        let removal = h
1054            .entry_descendants()
1055            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
1056            .collect_vec();
1057
1058        h.apply_patch(
1059            SimpleReplacement::try_new(
1060                SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
1061                &h,
1062                repl,
1063            )
1064            .unwrap(),
1065        )
1066        .unwrap();
1067
1068        // Nothing changed
1069        assert_eq!(h.num_nodes(), orig.num_nodes());
1070    }
1071
1072    /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the
1073    /// input directly to the outputs.
1074    ///
1075    /// https://github.com/CQCL/hugr/issues/1190
1076    #[rstest]
1077    fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec<Node>)) {
1078        let (mut hugr, nodes) = dfg_hugr_copy_bools;
1079        let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap();
1080
1081        let replacement = {
1082            let b =
1083                DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
1084            let [w] = b.input_wires_arr();
1085            b.finish_hugr_with_outputs([w, w]).unwrap()
1086        };
1087
1088        let subgraph =
1089            SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)
1090                .unwrap();
1091
1092        let rewrite = SimpleReplacement {
1093            subgraph,
1094            replacement,
1095        };
1096        rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
1097
1098        assert_eq!(hugr.validate(), Ok(()));
1099        assert_eq!(hugr.entry_descendants().count(), 3);
1100    }
1101
1102    /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting
1103    /// the input directly to the output.
1104    ///
1105    /// https://github.com/CQCL/hugr/issues/1323
1106    #[rstest]
1107    fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec<Node>)) {
1108        let (mut hugr, nodes) = dfg_hugr_half_not_bools;
1109        let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap();
1110
1111        let replacement = {
1112            let mut b =
1113                DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
1114            let [w] = b.input_wires_arr();
1115            let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
1116            let [w_not] = not.outputs_arr();
1117            b.finish_hugr_with_outputs([w, w_not]).unwrap()
1118        };
1119
1120        let subgraph =
1121            SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap();
1122
1123        let rewrite = SimpleReplacement {
1124            subgraph,
1125            replacement,
1126        };
1127        rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
1128
1129        assert_eq!(hugr.validate(), Ok(()));
1130        assert_eq!(hugr.entry_descendants().count(), 4);
1131    }
1132
1133    #[rstest]
1134    fn test_nested_replace(dfg_hugr2: Hugr) {
1135        // replace a node with a hugr with children
1136
1137        let mut h = dfg_hugr2;
1138        let h_node = h
1139            .entry_descendants()
1140            .find(|node: &Node| *h.get_optype(*node) == h_gate().into())
1141            .unwrap();
1142
1143        // build a nested identity dfg
1144        let mut nest_build = DFGBuilder::new(Signature::new_endo([qb_t()])).unwrap();
1145        let [input] = nest_build.input_wires_arr();
1146        let inner_build = nest_build.dfg_builder_endo([(qb_t(), input)]).unwrap();
1147        let inner_dfg = n_identity(inner_build).unwrap();
1148        let replacement = nest_build
1149            .finish_hugr_with_outputs([inner_dfg.out_wire(0)])
1150            .unwrap();
1151        let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
1152
1153        let rewrite = SimpleReplacement::try_new(subgraph, &h, replacement).unwrap();
1154
1155        assert_eq!(h.entry_descendants().count(), 4);
1156
1157        rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
1158        h.validate().unwrap_or_else(|e| panic!("{e}"));
1159
1160        assert_eq!(h.entry_descendants().count(), 6);
1161    }
1162
1163    /// A dfg hugr with 1 input -> copy -> 2x NOT -> 2x copy -> 4 outputs
1164    #[fixture]
1165    fn copy_not_not_copy_hugr() -> Hugr {
1166        let mut b = DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(); 4])).unwrap();
1167        let [w] = b.input_wires_arr();
1168        let not1 = b.add_dataflow_op(LogicOp::Not, [w]).unwrap();
1169        let not2 = b.add_dataflow_op(LogicOp::Not, [w]).unwrap();
1170
1171        let [out1] = not1.outputs_arr();
1172        let [out2] = not2.outputs_arr();
1173
1174        b.finish_hugr_with_outputs([out1, out2, out1, out2])
1175            .unwrap()
1176    }
1177
1178    #[rstest]
1179    fn test_boundary_traversal_empty_replacement(copy_not_not_copy_hugr: Hugr) {
1180        let hugr = copy_not_not_copy_hugr;
1181        let [inp, out] = hugr.get_io(hugr.entrypoint()).unwrap();
1182        let [not1, not2] = hugr.output_neighbours(inp).collect_array().unwrap();
1183        let subg_incoming = vec![
1184            vec![(not1, IncomingPort::from(0))],
1185            vec![(not2, IncomingPort::from(0))],
1186        ];
1187        let subg_outgoing = [not1, not2].map(|n| (n, OutgoingPort::from(0))).to_vec();
1188
1189        let subgraph = SiblingSubgraph::try_new(subg_incoming, subg_outgoing, &hugr).unwrap();
1190
1191        // Create an empty replacement (just copies)
1192        let repl = {
1193            let b = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap();
1194            let [w1, w2] = b.input_wires_arr();
1195            let repl_hugr = b.finish_hugr_with_outputs([w1, w2]).unwrap();
1196            SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).unwrap()
1197        };
1198
1199        // Test linked_replacement_inputs with empty replacement
1200        let replacement_inputs: Vec<_> = repl
1201            .linked_replacement_inputs(
1202                (inp, OutgoingPort::from(0)),
1203                &hugr,
1204                BoundaryMode::SnapToHost,
1205            )
1206            .collect();
1207
1208        assert_eq!(
1209            BTreeSet::from_iter(replacement_inputs),
1210            (0..4)
1211                .map(|i| BoundaryPort::Host(out, IncomingPort::from(i)))
1212                .collect()
1213        );
1214
1215        // Test linked_replacement_output with empty replacement
1216        let replacement_output = (0..4)
1217            .map(|i| {
1218                repl.linked_replacement_output(
1219                    (out, IncomingPort::from(i)),
1220                    &hugr,
1221                    BoundaryMode::SnapToHost,
1222                )
1223                .unwrap()
1224            })
1225            .collect_vec();
1226
1227        assert_eq!(
1228            replacement_output,
1229            vec![BoundaryPort::Host(inp, OutgoingPort::from(0)); 4]
1230        );
1231    }
1232
1233    #[rstest]
1234    fn test_boundary_traversal_copy_empty_replacement(copy_not_not_copy_hugr: Hugr) {
1235        let hugr = copy_not_not_copy_hugr;
1236        let [inp, out] = hugr.get_io(hugr.entrypoint()).unwrap();
1237        let [not1, not2] = hugr.output_neighbours(inp).collect_array().unwrap();
1238        let subg_incoming = vec![vec![
1239            (not1, IncomingPort::from(0)),
1240            (not2, IncomingPort::from(0)),
1241        ]];
1242        let subg_outgoing = [not1, not2].map(|n| (n, OutgoingPort::from(0))).to_vec();
1243
1244        let subgraph = SiblingSubgraph::try_new(subg_incoming, subg_outgoing, &hugr).unwrap();
1245
1246        // Create an empty replacement (just copies)
1247        let repl = {
1248            let b = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(); 2])).unwrap();
1249            let [w] = b.input_wires_arr();
1250            let repl_hugr = b.finish_hugr_with_outputs([w, w]).unwrap();
1251            SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).unwrap()
1252        };
1253
1254        let replacement_inputs: Vec<_> = repl
1255            .linked_replacement_inputs(
1256                (inp, OutgoingPort::from(0)),
1257                &hugr,
1258                BoundaryMode::SnapToHost,
1259            )
1260            .collect();
1261
1262        assert_eq!(
1263            BTreeSet::from_iter(replacement_inputs),
1264            (0..4)
1265                .map(|i| BoundaryPort::Host(out, IncomingPort::from(i)))
1266                .collect()
1267        );
1268
1269        let replacement_output = (0..4)
1270            .map(|i| {
1271                repl.linked_replacement_output(
1272                    (out, IncomingPort::from(i)),
1273                    &hugr,
1274                    BoundaryMode::SnapToHost,
1275                )
1276                .unwrap()
1277            })
1278            .collect_vec();
1279
1280        assert_eq!(
1281            replacement_output,
1282            vec![BoundaryPort::Host(inp, OutgoingPort::from(0)); 4]
1283        );
1284    }
1285
1286    #[rstest]
1287    fn test_boundary_traversal_non_empty_replacement(copy_not_not_copy_hugr: Hugr) {
1288        let hugr = copy_not_not_copy_hugr;
1289        let [inp, out] = hugr.get_io(hugr.entrypoint()).unwrap();
1290        let [not1, not2] = hugr.output_neighbours(inp).collect_array().unwrap();
1291        let subg_incoming = vec![
1292            vec![(not1, IncomingPort::from(0))],
1293            vec![(not2, IncomingPort::from(0))],
1294        ];
1295        let subg_outgoing = [not1, not2].map(|n| (n, OutgoingPort::from(0))).to_vec();
1296
1297        let subgraph = SiblingSubgraph::try_new(subg_incoming, subg_outgoing, &hugr).unwrap();
1298
1299        // Create a replacement with a single NOT gate
1300        let (repl, or_node) = {
1301            let mut b = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap();
1302            let [w1, w2] = b.input_wires_arr();
1303            let or_handle = b.add_dataflow_op(LogicOp::Or, [w1, w2]).unwrap();
1304            let [out] = or_handle.outputs_arr();
1305            let repl_hugr = b.finish_hugr_with_outputs([out, out]).unwrap();
1306            (
1307                SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).unwrap(),
1308                or_handle.node(),
1309            )
1310        };
1311
1312        let replacement_inputs: Vec<_> = repl
1313            .linked_replacement_inputs(
1314                (inp, OutgoingPort::from(0)),
1315                &hugr,
1316                BoundaryMode::SnapToHost,
1317            )
1318            .collect();
1319
1320        assert_eq!(
1321            BTreeSet::from_iter(replacement_inputs),
1322            (0..2)
1323                .map(|i| BoundaryPort::Replacement(or_node, IncomingPort::from(i)))
1324                .collect()
1325        );
1326        assert_eq!(
1327            repl.linked_host_input((or_node, IncomingPort::from(0)), &hugr),
1328            (inp, OutgoingPort::from(0)).into()
1329        );
1330
1331        let replacement_output = (0..4)
1332            .map(|i| {
1333                repl.linked_replacement_output(
1334                    (out, IncomingPort::from(i)),
1335                    &hugr,
1336                    BoundaryMode::SnapToHost,
1337                )
1338                .unwrap()
1339            })
1340            .collect_vec();
1341
1342        assert_eq!(
1343            replacement_output,
1344            vec![BoundaryPort::Replacement(or_node, OutgoingPort::from(0)); 4]
1345        );
1346        assert_eq!(
1347            BTreeSet::from_iter(repl.linked_host_outputs((or_node, OutgoingPort::from(0)), &hugr)),
1348            BTreeSet::from_iter((0..4).map(|i| HostPort(out, IncomingPort::from(i))))
1349        );
1350    }
1351
1352    use crate::hugr::patch::replace::Replacement;
1353    fn to_replace(h: &impl HugrView<Node = Node>, s: SimpleReplacement) -> Replacement {
1354        use crate::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec};
1355
1356        let [in_, out] = s.get_replacement_io();
1357        let mu_inp = s
1358            .incoming_boundary(h)
1359            .map(
1360                |(HostPort(src, src_port), ReplacementPort(tgt, tgt_port))| {
1361                    if tgt == out {
1362                        unimplemented!()
1363                    }
1364                    NewEdgeSpec {
1365                        src,
1366                        tgt,
1367                        kind: NewEdgeKind::Value {
1368                            src_pos: src_port,
1369                            tgt_pos: tgt_port,
1370                        },
1371                    }
1372                },
1373            )
1374            .collect();
1375        let mu_out = s
1376            .outgoing_boundary(h)
1377            .map(
1378                |(ReplacementPort(src, src_port), HostPort(tgt, tgt_port))| {
1379                    if src == in_ {
1380                        unimplemented!()
1381                    }
1382                    NewEdgeSpec {
1383                        src,
1384                        tgt,
1385                        kind: NewEdgeKind::Value {
1386                            src_pos: src_port,
1387                            tgt_pos: tgt_port,
1388                        },
1389                    }
1390                },
1391            )
1392            .collect();
1393        let mut replacement = s.replacement;
1394        replacement.remove_node(in_);
1395        replacement.remove_node(out);
1396        Replacement {
1397            removal: s.subgraph.nodes().to_vec(),
1398            replacement,
1399            adoptions: HashMap::new(),
1400            mu_inp,
1401            mu_out,
1402            mu_new: vec![],
1403        }
1404    }
1405
1406    fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) {
1407        h.apply_patch(rw).unwrap();
1408    }
1409
1410    fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
1411        h.apply_patch(to_replace(h, rw)).unwrap();
1412    }
1413}