hugr_core/hugr/patch/
simple_replace.rs

1//! Implementation of the `SimpleReplace` operation.
2
3use std::collections::HashMap;
4
5use crate::core::HugrNode;
6use crate::hugr::hugrmut::InsertionResult;
7use crate::hugr::views::SiblingSubgraph;
8pub use crate::hugr::views::sibling_subgraph::InvalidReplacement;
9use crate::hugr::{HugrMut, HugrView};
10use crate::ops::{OpTag, OpTrait, OpType};
11use crate::types::EdgeKind;
12use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, PortIndex};
13
14use itertools::Itertools;
15
16use thiserror::Error;
17
18use super::inline_dfg::InlineDFGError;
19use super::{BoundaryPort, HostPort, PatchHugrMut, PatchVerification, ReplacementPort};
20
21/// Specification of a simple replacement operation.
22///
23/// # Type parameters
24///
25/// - `N`: The type of nodes in the host hugr.
26#[derive(Debug, Clone)]
27pub struct SimpleReplacement<HostNode = Node> {
28    /// The subgraph of the host hugr to be replaced.
29    subgraph: SiblingSubgraph<HostNode>,
30    /// A hugr with DFG root (consisting of replacement nodes).
31    replacement: Hugr,
32}
33
34impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
35    /// Create a new [`SimpleReplacement`] specification without checking that
36    /// the replacement has the same signature as the subgraph.
37    #[inline]
38    pub fn new_unchecked(subgraph: SiblingSubgraph<HostNode>, replacement: Hugr) -> Self {
39        Self {
40            subgraph,
41            replacement,
42        }
43    }
44
45    /// Create a new [`SimpleReplacement`] specification.
46    ///
47    /// Return a [`InvalidReplacement::InvalidSignature`] error if `subgraph`
48    /// and `replacement` have different signatures.
49    pub fn try_new(
50        subgraph: SiblingSubgraph<HostNode>,
51        host: &impl HugrView<Node = HostNode>,
52        replacement: Hugr,
53    ) -> Result<Self, InvalidReplacement> {
54        let subgraph_sig = subgraph.signature(host);
55        let repl_sig =
56            replacement
57                .inner_function_type()
58                .ok_or(InvalidReplacement::InvalidDataflowGraph {
59                    node: replacement.entrypoint(),
60                    op: replacement.get_optype(replacement.entrypoint()).to_owned(),
61                })?;
62        if subgraph_sig != repl_sig {
63            return Err(InvalidReplacement::InvalidSignature {
64                expected: subgraph_sig,
65                actual: Some(repl_sig.into_owned()),
66            });
67        }
68        Ok(Self {
69            subgraph,
70            replacement,
71        })
72    }
73
74    /// The replacement hugr.
75    #[inline]
76    pub fn replacement(&self) -> &Hugr {
77        &self.replacement
78    }
79
80    /// Consume self and return the replacement hugr.
81    #[inline]
82    pub fn into_replacement(self) -> Hugr {
83        self.replacement
84    }
85
86    /// Subgraph to be replaced.
87    #[inline]
88    pub fn subgraph(&self) -> &SiblingSubgraph<HostNode> {
89        &self.subgraph
90    }
91
92    /// Check if the replacement can be applied to the given hugr.
93    pub fn is_valid_rewrite(
94        &self,
95        h: &impl HugrView<Node = HostNode>,
96    ) -> Result<(), SimpleReplacementError> {
97        let parent = self.subgraph.get_parent(h);
98
99        // 1. Check the parent node exists and is a DataflowParent.
100        if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
101            return Err(SimpleReplacementError::InvalidParentNode());
102        }
103
104        // 2. Check that all the to-be-removed nodes are children of it and are leaves.
105        for node in self.subgraph.nodes() {
106            if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
107                return Err(SimpleReplacementError::InvalidRemovedNode());
108            }
109        }
110
111        Ok(())
112    }
113
114    /// Get the input and output nodes of the replacement hugr.
115    pub fn get_replacement_io(&self) -> [Node; 2] {
116        self.replacement
117            .get_io(self.replacement.entrypoint())
118            .expect("replacement is a DFG")
119    }
120
121    /// Get all edges that the replacement would add from outgoing ports in
122    /// `host` to incoming ports in `self.replacement`.
123    ///
124    /// For each pair in the returned vector, the first element is a port in
125    /// `host` and the second is a port in `self.replacement`:
126    ///  - The outgoing host ports are always linked to the input boundary of
127    ///    `subgraph`, i.e. the ports returned by
128    ///    [`SiblingSubgraph::incoming_ports`],
129    ///  - The incoming replacement ports are always linked to output ports of
130    ///    the [`OpTag::Input`] node of `self.replacement`.
131    pub fn incoming_boundary<'a>(
132        &'a self,
133        host: &'a impl HugrView<Node = HostNode>,
134    ) -> impl Iterator<
135        Item = (
136            HostPort<HostNode, OutgoingPort>,
137            ReplacementPort<IncomingPort>,
138        ),
139    > + 'a {
140        // The outgoing ports at the input boundary of `subgraph`
141        let subgraph_outgoing_ports = self
142            .subgraph
143            .incoming_ports()
144            .iter()
145            .map(|in_ports| *in_ports.first().expect("non-empty input partition"))
146            .map(|(node, in_port)| {
147                host.single_linked_output(node, in_port)
148                    .expect("valid dfg wire")
149            });
150
151        // The incoming ports at the input boundary of `replacement`
152        let [repl_inp, _] = self.get_replacement_io();
153        let repl_incoming_ports = self
154            .replacement
155            .node_outputs(repl_inp)
156            .filter(move |&port| is_value_port(&self.replacement, repl_inp, port))
157            .map(move |repl_out_port| {
158                self.replacement
159                    .linked_inputs(repl_inp, repl_out_port)
160                    .filter(|&(node, _)| self.replacement.get_optype(node).tag() != OpTag::Output)
161            });
162
163        // Zip the two iterators and add edges from each outgoing port to all
164        // corresponding incoming ports.
165        subgraph_outgoing_ports.zip(repl_incoming_ports).flat_map(
166            |((subgraph_out_node, subgraph_out_port), repl_all_incoming)| {
167                // add edge from outgoing port in subgraph to incoming port in
168                // replacement
169                repl_all_incoming.map(move |(repl_inp_node, repl_inp_port)| {
170                    (
171                        HostPort(subgraph_out_node, subgraph_out_port),
172                        ReplacementPort(repl_inp_node, repl_inp_port),
173                    )
174                })
175            },
176        )
177    }
178
179    /// Get all edges that the replacement would add from outgoing ports in
180    /// `self.replacement` to incoming ports in `host`.
181    ///
182    /// For each pair in the returned vector, the first element is a port in
183    /// `self.replacement` and the second is a port in `host`:
184    ///  - The outgoing replacement ports are always linked to inputs of the
185    ///    [`OpTag::Output`] node of `self.replacement`,
186    ///  - The incoming host ports are always linked to the output boundary of
187    ///    `subgraph`, i.e. the ports returned by
188    ///    [`SiblingSubgraph::outgoing_ports`],
189    ///
190    /// This panics if self.replacement is not a DFG.
191    pub fn outgoing_boundary<'a>(
192        &'a self,
193        host: &'a impl HugrView<Node = HostNode>,
194    ) -> impl Iterator<
195        Item = (
196            ReplacementPort<OutgoingPort>,
197            HostPort<HostNode, IncomingPort>,
198        ),
199    > + 'a {
200        // The incoming ports at the output boundary of `subgraph`
201        let subgraph_incoming_ports =
202            self.subgraph
203                .outgoing_ports()
204                .iter()
205                .map(|&(subgraph_out_node, subgraph_out_port)| {
206                    host.linked_inputs(subgraph_out_node, subgraph_out_port)
207                });
208
209        // The outgoing ports at the output boundary of `replacement`
210        let [_, repl_out] = self.get_replacement_io();
211        let repl_outgoing_ports = self
212            .replacement
213            .node_inputs(repl_out)
214            .filter(move |&port| is_value_port(&self.replacement, repl_out, port))
215            .map(move |repl_in_port| {
216                self.replacement
217                    .single_linked_output(repl_out, repl_in_port)
218                    .expect("valid dfg wire")
219            });
220
221        repl_outgoing_ports.zip(subgraph_incoming_ports).flat_map(
222            |((repl_out_node, repl_out_port), subgraph_all_incoming)| {
223                if self.replacement.get_optype(repl_out_node).tag() != OpTag::Input {
224                    Some(
225                        subgraph_all_incoming.map(move |(subgraph_in_node, subgraph_in_port)| {
226                            (
227                                // the new output node will be updated after insertion
228                                ReplacementPort(repl_out_node, repl_out_port),
229                                HostPort(subgraph_in_node, subgraph_in_port),
230                            )
231                        }),
232                    )
233                    .into_iter()
234                    .flatten()
235                } else {
236                    None.into_iter().flatten()
237                }
238            },
239        )
240    }
241
242    /// Get all edges that the replacement would add between ports in `host`.
243    ///
244    /// These correspond to direct edges between the input and output nodes
245    /// in the replacement graph.
246    ///
247    /// For each pair in the returned vector, both ports are in `host`:
248    ///  - The outgoing host ports are linked to the input boundary of
249    ///    `subgraph`, i.e. the ports returned by
250    ///    [`SiblingSubgraph::incoming_ports`],
251    ///  - The incoming host ports are linked to the output boundary of
252    ///    `subgraph`, i.e. the ports returned by
253    ///    [`SiblingSubgraph::outgoing_ports`].
254    ///
255    /// This panics if self.replacement is not a DFG.
256    pub fn host_to_host_boundary<'a>(
257        &'a self,
258        host: &'a impl HugrView<Node = HostNode>,
259    ) -> impl Iterator<
260        Item = (
261            HostPort<HostNode, OutgoingPort>,
262            HostPort<HostNode, IncomingPort>,
263        ),
264    > + 'a {
265        let [repl_in, repl_out] = self.get_replacement_io();
266
267        let empty_wires = self
268            .replacement
269            .node_inputs(repl_out)
270            .filter(move |&port| is_value_port(&self.replacement, repl_out, port))
271            .filter_map(move |repl_in_port| {
272                let (repl_out_node, repl_out_port) = self
273                    .replacement
274                    .single_linked_output(repl_out, repl_in_port)
275                    .expect("valid dfg wire");
276                (repl_out_node == repl_in).then_some((repl_out_port, repl_in_port))
277            });
278
279        // The outgoing ports at the input boundary of `subgraph`
280        let subgraph_input_boundary = self
281            .subgraph
282            .incoming_ports()
283            .iter()
284            .map(|node_ports| {
285                let (node, port) = *node_ports.first().expect("non-empty boundary partition");
286                host.single_linked_output(node, port)
287                    .expect("valid dfg wire")
288            })
289            .collect_vec();
290        // The incoming ports at the output boundary of `subgraph`
291        let subgraph_output_boundary = self
292            .subgraph
293            .outgoing_ports()
294            .iter()
295            .map(|&(node, port)| host.linked_inputs(node, port).collect_vec())
296            .collect_vec();
297
298        empty_wires.flat_map(move |(repl_out_port, repl_in_port)| {
299            let (host_out_node, host_out_port) = subgraph_input_boundary[repl_out_port.index()];
300            subgraph_output_boundary[repl_in_port.index()]
301                .clone()
302                .into_iter()
303                .map(move |(host_in_node, host_in_port)| {
304                    (
305                        HostPort(host_out_node, host_out_port),
306                        HostPort(host_in_node, host_in_port),
307                    )
308                })
309        })
310    }
311
312    /// Get the incoming port at the output node of `self.replacement`
313    /// that corresponds to the given outgoing port on the subgraph output
314    /// boundary.
315    ///
316    /// The host `port` should be a port in `self.subgraph().outgoing_ports()`.
317    ///
318    /// This panics if self.replacement is not a DFG.
319    pub fn map_host_output(
320        &self,
321        port: impl Into<HostPort<HostNode, OutgoingPort>>,
322    ) -> Option<ReplacementPort<IncomingPort>> {
323        let HostPort(node, port) = port.into();
324        let pos = self
325            .subgraph
326            .outgoing_ports()
327            .iter()
328            .position(|&node_port| node_port == (node, port))?;
329        let incoming_port: IncomingPort = pos.into();
330        let [_, rep_output] = self.get_replacement_io();
331        Some(ReplacementPort(rep_output, incoming_port))
332    }
333
334    /// Get the incoming ports in the input boundary of `subgraph` that
335    /// correspond to the given output port at the input node of `replacement`
336    ///
337    /// Return ports in `self.subgraph().incoming_ports()`.
338    ///
339    /// This panics if self.replacement is not a DFG.
340    pub fn map_replacement_input(
341        &self,
342        port: impl Into<ReplacementPort<OutgoingPort>>,
343    ) -> impl Iterator<Item = HostPort<HostNode, IncomingPort>> + '_ {
344        let ReplacementPort(node, port) = port.into();
345        let [repl_input, _] = self.get_replacement_io();
346
347        let ports = if node == repl_input {
348            self.subgraph.incoming_ports().get(port.index())
349        } else {
350            None
351        };
352        ports
353            .into_iter()
354            .flat_map(|ports| ports.iter().map(|&(n, p)| HostPort(n, p)))
355    }
356
357    /// Get all edges that the replacement would add between `host` and
358    /// `self.replacement`.
359    ///
360    /// This is equivalent to chaining the results of
361    /// [`Self::incoming_boundary`], [`Self::outgoing_boundary`], and
362    /// [`Self::host_to_host_boundary`].
363    ///
364    /// This panics if self.replacement is not a DFG.
365    pub fn all_boundary_edges<'a>(
366        &'a self,
367        host: &'a impl HugrView<Node = HostNode>,
368    ) -> impl Iterator<
369        Item = (
370            BoundaryPort<HostNode, OutgoingPort>,
371            BoundaryPort<HostNode, IncomingPort>,
372        ),
373    > + 'a {
374        let incoming_boundary = self
375            .incoming_boundary(host)
376            .map(|(src, tgt)| (src.into(), tgt.into()));
377        let outgoing_boundary = self
378            .outgoing_boundary(host)
379            .map(|(src, tgt)| (src.into(), tgt.into()));
380        let host_to_host_boundary = self
381            .host_to_host_boundary(host)
382            .map(|(src, tgt)| (src.into(), tgt.into()));
383
384        incoming_boundary
385            .chain(outgoing_boundary)
386            .chain(host_to_host_boundary)
387    }
388}
389
390impl<HostNode: HugrNode> PatchVerification for SimpleReplacement<HostNode> {
391    type Error = SimpleReplacementError;
392    type Node = HostNode;
393
394    fn verify(&self, h: &impl HugrView<Node = HostNode>) -> Result<(), SimpleReplacementError> {
395        self.is_valid_rewrite(h)
396    }
397
398    #[inline]
399    fn invalidation_set(&self) -> impl Iterator<Item = HostNode> {
400        self.subgraph.nodes().iter().copied()
401    }
402}
403
404/// Result of applying a [`SimpleReplacement`].
405pub struct Outcome<HostNode = Node> {
406    /// Map from Node in replacement to corresponding Node in the result Hugr
407    pub node_map: HashMap<Node, HostNode>,
408    /// Nodes removed from the result Hugr and their weights
409    pub removed_nodes: HashMap<HostNode, OpType>,
410}
411
412impl<N: HugrNode> PatchHugrMut for SimpleReplacement<N> {
413    type Outcome = Outcome<N>;
414    const UNCHANGED_ON_FAILURE: bool = true;
415
416    fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<Self::Outcome, Self::Error> {
417        self.is_valid_rewrite(h)?;
418
419        let parent = self.subgraph.get_parent(h);
420
421        // We proceed to connect the edges between the newly inserted
422        // replacement and the rest of the graph.
423        //
424        // Existing connections to the removed subgraph will be automatically
425        // removed when the nodes are removed.
426
427        // 1. Get the boundary edges
428        let boundary_edges = self.all_boundary_edges(h).collect_vec();
429
430        let Self {
431            replacement,
432            subgraph,
433            ..
434        } = self;
435
436        // Nodes to remove from the replacement hugr
437        let repl_io = replacement
438            .get_io(replacement.entrypoint())
439            .expect("replacement is DFG-rooted");
440        let repl_entrypoint = replacement.entrypoint();
441
442        // 2. Insert the replacement as a whole.
443        let InsertionResult {
444            inserted_entrypoint: new_entrypoint,
445            mut node_map,
446        } = h.insert_hugr(parent, replacement);
447
448        // remove the Input and Output from h and node_map
449        for node in repl_io {
450            let node_h = node_map[&node];
451            h.remove_node(node_h);
452            node_map.remove(&node);
453        }
454
455        // make all (remaining) replacement top level children children of the parent
456        for child in h.children(new_entrypoint).collect_vec() {
457            h.set_parent(child, parent);
458        }
459
460        // remove the replacement entrypoint from h and node_map
461        h.remove_node(new_entrypoint);
462        node_map.remove(&repl_entrypoint);
463
464        // 3. Insert all boundary edges.
465        for (src, tgt) in boundary_edges {
466            let (src_node, src_port) = src.map_replacement(&node_map);
467            let (tgt_node, tgt_port) = tgt.map_replacement(&node_map);
468            h.connect(src_node, src_port, tgt_node, tgt_port);
469        }
470
471        // 4. Remove all nodes in subgraph and edges between them.
472        let removed_nodes = subgraph
473            .nodes()
474            .iter()
475            .map(|&node| (node, h.remove_node(node)))
476            .collect();
477
478        Ok(Outcome {
479            node_map,
480            removed_nodes,
481        })
482    }
483}
484
485/// Error from a [`SimpleReplacement`] operation.
486#[derive(Debug, Clone, Error, PartialEq, Eq)]
487#[non_exhaustive]
488pub enum SimpleReplacementError {
489    /// Invalid parent node.
490    #[error("Parent node is invalid.")]
491    InvalidParentNode(),
492    /// Node requested for removal is invalid.
493    #[error("A node requested for removal is invalid.")]
494    InvalidRemovedNode(),
495    /// Node in replacement graph is invalid.
496    #[error("A node in the replacement graph is invalid.")]
497    InvalidReplacementNode(),
498    /// Inlining replacement failed.
499    #[error("Inlining replacement failed: {0}")]
500    InliningFailed(#[from] InlineDFGError),
501}
502
503fn is_value_port<N: HugrNode>(
504    hugr: &impl HugrView<Node = N>,
505    node: N,
506    port: impl Into<Port>,
507) -> bool {
508    hugr.get_optype(node)
509        .port_kind(port)
510        .as_ref()
511        .is_some_and(EdgeKind::is_value)
512}
513
514#[cfg(test)]
515pub(in crate::hugr::patch) mod test {
516    use itertools::Itertools;
517    use rstest::{fixture, rstest};
518
519    use std::collections::{HashMap, HashSet};
520
521    use crate::Node;
522    use crate::builder::test::n_identity;
523    use crate::builder::{
524        BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
525        HugrBuilder, ModuleBuilder, endo_sig, inout_sig,
526    };
527    use crate::extension::prelude::{bool_t, qb_t};
528    use crate::hugr::patch::simple_replace::Outcome;
529    use crate::hugr::patch::{HostPort, PatchVerification, ReplacementPort};
530    use crate::hugr::views::{HugrView, SiblingSubgraph};
531    use crate::hugr::{Hugr, HugrMut, Patch};
532    use crate::ops::OpTag;
533    use crate::ops::OpTrait;
534    use crate::ops::handle::NodeHandle;
535    use crate::std_extensions::logic::LogicOp;
536    use crate::std_extensions::logic::test::and_op;
537    use crate::types::{Signature, Type};
538    use crate::utils::test_quantum_extension::{cx_gate, h_gate};
539
540    use super::SimpleReplacement;
541
542    /// Creates a hugr like the following:
543    /// --   H   --
544    /// -- [DFG] --
545    /// where [DFG] is:
546    /// ┌───┐     ┌───┐
547    /// ┤ H ├──■──┤ H ├
548    /// ├───┤┌─┴─┐├───┤
549    /// ┤ H ├┤ X ├┤ H ├
550    /// └───┘└───┘└───┘
551    fn make_hugr() -> Result<Hugr, BuildError> {
552        let mut module_builder = ModuleBuilder::new();
553        let _f_id = {
554            let mut func_builder = module_builder
555                .define_function("main", Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]))?;
556
557            let [qb0, qb1, qb2] = func_builder.input_wires_arr();
558
559            let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?;
560
561            let mut inner_builder =
562                func_builder.dfg_builder_endo([(qb_t(), qb0), (qb_t(), qb1)])?;
563            let inner_graph = {
564                let [wire0, wire1] = inner_builder.input_wires_arr();
565                let wire2 = inner_builder.add_dataflow_op(h_gate(), vec![wire0])?;
566                let wire3 = inner_builder.add_dataflow_op(h_gate(), vec![wire1])?;
567                let wire45 = inner_builder
568                    .add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
569                let [wire4, wire5] = wire45.outputs_arr();
570                let wire6 = inner_builder.add_dataflow_op(h_gate(), vec![wire4])?;
571                let wire7 = inner_builder.add_dataflow_op(h_gate(), vec![wire5])?;
572                inner_builder.finish_with_outputs(wire6.outputs().chain(wire7.outputs()))
573            }?;
574
575            func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))?
576        };
577        Ok(module_builder.finish_hugr()?)
578    }
579
580    #[fixture]
581    pub(in crate::hugr::patch) fn simple_hugr() -> Hugr {
582        make_hugr().unwrap()
583    }
584    /// Creates a hugr with a DFG root like the following:
585    /// ┌───┐
586    /// ┤ H ├──■──
587    /// ├───┤┌─┴─┐
588    /// ┤ H ├┤ X ├
589    /// └───┘└───┘
590    fn make_dfg_hugr() -> Result<Hugr, BuildError> {
591        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
592        let [wire0, wire1] = dfg_builder.input_wires_arr();
593        let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?;
594        let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
595        let wire45 =
596            dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
597        dfg_builder.finish_hugr_with_outputs(wire45.outputs())
598    }
599
600    #[fixture]
601    pub(in crate::hugr::patch) fn dfg_hugr() -> Hugr {
602        make_dfg_hugr().unwrap()
603    }
604
605    /// Creates a hugr with a DFG root like the following:
606    /// ─────
607    /// ┌───┐
608    /// ┤ H ├
609    /// └───┘
610    fn make_dfg_hugr2() -> Result<Hugr, BuildError> {
611        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
612
613        let [wire0, wire1] = dfg_builder.input_wires_arr();
614        let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
615        let wire2out = wire2.outputs().exactly_one().unwrap();
616        let wireoutvec = vec![wire0, wire2out];
617        dfg_builder.finish_hugr_with_outputs(wireoutvec)
618    }
619
620    #[fixture]
621    pub(in crate::hugr::patch) fn dfg_hugr2() -> Hugr {
622        make_dfg_hugr2().unwrap()
623    }
624
625    /// A hugr with a DFG root mapping bool_t() to (bool_t(), bool_t())
626    ///                     ┌─────────┐
627    ///                ┌────┤ (1) NOT ├──
628    ///  ┌─────────┐   │    └─────────┘
629    /// ─┤ (0) NOT ├───┤
630    ///  └─────────┘   │    ┌─────────┐
631    ///                └────┤ (2) NOT ├──
632    ///                     └─────────┘
633    /// This can be replaced with an empty hugr coping the input to both
634    /// outputs.
635    ///
636    /// Returns the hugr and the nodes of the NOT gates, in order.
637    #[fixture]
638    pub(in crate::hugr::patch) fn dfg_hugr_copy_bools() -> (Hugr, Vec<Node>) {
639        let mut dfg_builder =
640            DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
641        let [b] = dfg_builder.input_wires_arr();
642
643        let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
644        let [b] = not_inp.outputs_arr();
645
646        let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
647        let [b0] = not_0.outputs_arr();
648        let not_1 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
649        let [b1] = not_1.outputs_arr();
650
651        (
652            dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
653            vec![not_inp.node(), not_0.node(), not_1.node()],
654        )
655    }
656
657    /// A hugr with a DFG root mapping bool_t() to (bool_t(), bool_t())
658    ///                     ┌─────────┐
659    ///                ┌────┤ (1) NOT ├──
660    ///  ┌─────────┐   │    └─────────┘
661    /// ─┤ (0) NOT ├───┤
662    ///  └─────────┘   │
663    ///                └─────────────────
664    ///
665    /// This can be replaced with a single NOT op, coping the input to the first
666    /// output.
667    ///
668    /// Returns the hugr and the nodes of the NOT ops, in order.
669    #[fixture]
670    pub(in crate::hugr::patch) fn dfg_hugr_half_not_bools() -> (Hugr, Vec<Node>) {
671        let mut dfg_builder =
672            DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
673        let [b] = dfg_builder.input_wires_arr();
674
675        let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
676        let [b] = not_inp.outputs_arr();
677
678        let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
679        let [b0] = not_0.outputs_arr();
680        let b1 = b;
681
682        (
683            dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
684            vec![not_inp.node(), not_0.node()],
685        )
686    }
687
688    #[rstest]
689    /// Replace the
690    ///      ┌───┐
691    /// ──■──┤ H ├
692    /// ┌─┴─┐├───┤
693    /// ┤ X ├┤ H ├
694    /// └───┘└───┘
695    /// part of
696    /// ┌───┐     ┌───┐
697    /// ┤ H ├──■──┤ H ├
698    /// ├───┤┌─┴─┐├───┤
699    /// ┤ H ├┤ X ├┤ H ├
700    /// └───┘└───┘└───┘
701    /// with
702    /// ┌───┐
703    /// ┤ H ├──■──
704    /// ├───┤┌─┴─┐
705    /// ┤ H ├┤ X ├
706    /// └───┘└───┘
707    fn test_simple_replacement(
708        simple_hugr: Hugr,
709        dfg_hugr: Hugr,
710        #[values(apply_simple, apply_replace)] applicator: impl Fn(&mut Hugr, SimpleReplacement),
711    ) {
712        let mut h: Hugr = simple_hugr;
713        // 1. Locate the CX and its successor H's in h
714        let h_node_cx: Node = h
715            .entry_descendants()
716            .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
717            .unwrap();
718        let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
719        let s: Vec<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
720        // 2. Construct a new DFG-rooted hugr for the replacement
721        let n: Hugr = dfg_hugr;
722        // 3. Construct the input and output matchings
723        // 3.1. Locate the CX and its predecessor H's in n
724        let n_node_cx = n
725            .entry_descendants()
726            .find(|node: &Node| *n.get_optype(*node) == cx_gate().into())
727            .unwrap();
728        // 3.2. Locate the ports we need to specify as "glue" in n
729        let (n_cx_out_0, _n_cx_out_1) = n.node_outputs(n_node_cx).take(2).collect_tuple().unwrap();
730        let n_port_2 = n.linked_inputs(n_node_cx, n_cx_out_0).next().unwrap().1;
731        // 3.3. Locate the ports we need to specify as "glue" in h
732        let h_h0_out = h.node_outputs(h_node_h0).next().unwrap();
733        // 4. Define the replacement
734        let r = SimpleReplacement {
735            subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
736            replacement: n,
737        };
738
739        // Check output boundary
740        assert_eq!(
741            r.map_host_output((h_node_h0, h_h0_out)).unwrap(),
742            ReplacementPort::from((r.get_replacement_io()[1], n_port_2))
743        );
744
745        // Check invalidation set
746        assert_eq!(
747            HashSet::<_>::from_iter(r.invalidation_set()),
748            HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1]),
749        );
750
751        applicator(&mut h, r);
752        // Expect [DFG] to be replaced with:
753        // ┌───┐┌───┐
754        // ┤ H ├┤ H ├──■──
755        // ├───┤├───┤┌─┴─┐
756        // ┤ H ├┤ H ├┤ X ├
757        // └───┘└───┘└───┘
758        assert_eq!(h.validate(), Ok(()));
759    }
760
761    #[rstest]
762    /// Replace the
763    ///
764    /// ──■──
765    /// ┌─┴─┐
766    /// ┤ X ├
767    /// └───┘
768    /// part of
769    /// ┌───┐     ┌───┐
770    /// ┤ H ├──■──┤ H ├
771    /// ├───┤┌─┴─┐├───┤
772    /// ┤ H ├┤ X ├┤ H ├
773    /// └───┘└───┘└───┘
774    /// with
775    /// ─────
776    /// ┌───┐
777    /// ┤ H ├
778    /// └───┘
779    fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
780        let mut h: Hugr = simple_hugr;
781
782        // 1. Locate the CX in h
783        let h_node_cx: Node = h
784            .entry_descendants()
785            .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
786            .unwrap();
787        let s: Vec<Node> = vec![h_node_cx];
788        // 2. Construct a new DFG-rooted hugr for the replacement
789        let n: Hugr = dfg_hugr2;
790        // 3. Construct the input and output matchings
791        // 3.1. Locate the Output and its predecessor H in n
792        let n_node_output = n.get_io(n.entrypoint()).unwrap()[1];
793        let (_n_node_input, n_node_h) = n.input_neighbours(n_node_output).collect_tuple().unwrap();
794        // 4. Define the replacement
795        let r = SimpleReplacement {
796            subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
797            replacement: n,
798        };
799        let Outcome {
800            node_map,
801            removed_nodes,
802        } = h.apply_patch(r).unwrap();
803
804        assert_eq!(
805            node_map.into_keys().collect::<HashSet<_>>(),
806            [n_node_h].into_iter().collect::<HashSet<_>>(),
807        );
808        assert_eq!(
809            removed_nodes.into_keys().collect::<HashSet<_>>(),
810            [h_node_cx].into_iter().collect::<HashSet<_>>(),
811        );
812
813        // Expect [DFG] to be replaced with:
814        // ┌───┐┌───┐
815        // ┤ H ├┤ H ├
816        // ├───┤├───┤┌───┐
817        // ┤ H ├┤ H ├┤ H ├
818        // └───┘└───┘└───┘
819        assert_eq!(h.validate(), Ok(()));
820    }
821
822    #[test]
823    fn test_replace_cx_cross() {
824        let q_row: Vec<Type> = vec![qb_t(), qb_t()];
825        let mut builder = DFGBuilder::new(endo_sig(q_row)).unwrap();
826        let mut circ = builder.as_circuit(builder.input_wires());
827        circ.append(cx_gate(), [0, 1]).unwrap();
828        circ.append(cx_gate(), [1, 0]).unwrap();
829        let wires = circ.finish();
830        let mut h = builder.finish_hugr_with_outputs(wires).unwrap();
831        let replacement = h.clone();
832        let orig = h.clone();
833
834        let removal = h
835            .entry_descendants()
836            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
837            .collect_vec();
838        h.apply_patch(
839            SimpleReplacement::try_new(
840                SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
841                &h,
842                replacement,
843            )
844            .unwrap(),
845        )
846        .unwrap();
847
848        // They should be the same, up to node indices
849        assert_eq!(h.num_edges(), orig.num_edges());
850    }
851
852    #[test]
853    fn test_replace_after_copy() {
854        let one_bit = vec![bool_t()];
855        let two_bit = vec![bool_t(), bool_t()];
856
857        let mut builder = DFGBuilder::new(endo_sig(one_bit.clone())).unwrap();
858        let inw = builder.input_wires().exactly_one().unwrap();
859        let outw = builder
860            .add_dataflow_op(and_op(), [inw, inw])
861            .unwrap()
862            .outputs();
863        let mut h = builder.finish_hugr_with_outputs(outw).unwrap();
864
865        let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap();
866        let inw = builder.input_wires();
867        let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs();
868        let repl = builder.finish_hugr_with_outputs(outw).unwrap();
869
870        let orig = h.clone();
871
872        let removal = h
873            .entry_descendants()
874            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
875            .collect_vec();
876
877        h.apply_patch(
878            SimpleReplacement::try_new(
879                SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
880                &h,
881                repl,
882            )
883            .unwrap(),
884        )
885        .unwrap();
886
887        // Nothing changed
888        assert_eq!(h.num_nodes(), orig.num_nodes());
889    }
890
891    /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the
892    /// input directly to the outputs.
893    ///
894    /// https://github.com/CQCL/hugr/issues/1190
895    #[rstest]
896    fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec<Node>)) {
897        let (mut hugr, nodes) = dfg_hugr_copy_bools;
898        let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap();
899
900        let replacement = {
901            let b =
902                DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
903            let [w] = b.input_wires_arr();
904            b.finish_hugr_with_outputs([w, w]).unwrap()
905        };
906
907        let subgraph =
908            SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)
909                .unwrap();
910
911        let rewrite = SimpleReplacement {
912            subgraph,
913            replacement,
914        };
915        rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
916
917        assert_eq!(hugr.validate(), Ok(()));
918        assert_eq!(hugr.entry_descendants().count(), 3);
919    }
920
921    /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting
922    /// the input directly to the output.
923    ///
924    /// https://github.com/CQCL/hugr/issues/1323
925    #[rstest]
926    fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec<Node>)) {
927        let (mut hugr, nodes) = dfg_hugr_half_not_bools;
928        let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap();
929
930        let replacement = {
931            let mut b =
932                DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
933            let [w] = b.input_wires_arr();
934            let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
935            let [w_not] = not.outputs_arr();
936            b.finish_hugr_with_outputs([w, w_not]).unwrap()
937        };
938
939        let subgraph =
940            SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap();
941
942        let rewrite = SimpleReplacement {
943            subgraph,
944            replacement,
945        };
946        rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
947
948        assert_eq!(hugr.validate(), Ok(()));
949        assert_eq!(hugr.entry_descendants().count(), 4);
950    }
951
952    #[rstest]
953    fn test_nested_replace(dfg_hugr2: Hugr) {
954        // replace a node with a hugr with children
955
956        let mut h = dfg_hugr2;
957        let h_node = h
958            .entry_descendants()
959            .find(|node: &Node| *h.get_optype(*node) == h_gate().into())
960            .unwrap();
961
962        // build a nested identity dfg
963        let mut nest_build = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
964        let [input] = nest_build.input_wires_arr();
965        let inner_build = nest_build.dfg_builder_endo([(qb_t(), input)]).unwrap();
966        let inner_dfg = n_identity(inner_build).unwrap();
967        let replacement = nest_build
968            .finish_hugr_with_outputs([inner_dfg.out_wire(0)])
969            .unwrap();
970        let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
971
972        let rewrite = SimpleReplacement::try_new(subgraph, &h, replacement).unwrap();
973
974        assert_eq!(h.entry_descendants().count(), 4);
975
976        rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
977        h.validate().unwrap_or_else(|e| panic!("{e}"));
978
979        assert_eq!(h.entry_descendants().count(), 6);
980    }
981
982    use crate::hugr::patch::replace::Replacement;
983    fn to_replace(h: &impl HugrView<Node = Node>, s: SimpleReplacement) -> Replacement {
984        use crate::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec};
985
986        let [in_, out] = s.get_replacement_io();
987        let mu_inp = s
988            .incoming_boundary(h)
989            .map(
990                |(HostPort(src, src_port), ReplacementPort(tgt, tgt_port))| {
991                    if tgt == out {
992                        unimplemented!()
993                    }
994                    NewEdgeSpec {
995                        src,
996                        tgt,
997                        kind: NewEdgeKind::Value {
998                            src_pos: src_port,
999                            tgt_pos: tgt_port,
1000                        },
1001                    }
1002                },
1003            )
1004            .collect();
1005        let mu_out = s
1006            .outgoing_boundary(h)
1007            .map(
1008                |(ReplacementPort(src, src_port), HostPort(tgt, tgt_port))| {
1009                    if src == in_ {
1010                        unimplemented!()
1011                    }
1012                    NewEdgeSpec {
1013                        src,
1014                        tgt,
1015                        kind: NewEdgeKind::Value {
1016                            src_pos: src_port,
1017                            tgt_pos: tgt_port,
1018                        },
1019                    }
1020                },
1021            )
1022            .collect();
1023        let mut replacement = s.replacement;
1024        replacement.remove_node(in_);
1025        replacement.remove_node(out);
1026        Replacement {
1027            removal: s.subgraph.nodes().to_vec(),
1028            replacement,
1029            adoptions: HashMap::new(),
1030            mu_inp,
1031            mu_out,
1032            mu_new: vec![],
1033        }
1034    }
1035
1036    fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) {
1037        h.apply_patch(rw).unwrap();
1038    }
1039
1040    fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
1041        h.apply_patch(to_replace(h, rw)).unwrap();
1042    }
1043}