// hugr_core/hugr/rewrite/simple_replace.rs

//! Implementation of the [`SimpleReplacement`] operation.
2
3use std::collections::HashMap;
4
5use crate::core::HugrNode;
6use crate::hugr::hugrmut::InsertionResult;
7pub use crate::hugr::internal::HugrMutInternals;
8use crate::hugr::views::SiblingSubgraph;
9use crate::hugr::{HugrMut, HugrView, Rewrite};
10use crate::ops::{OpTag, OpTrait, OpType};
11use crate::{Hugr, IncomingPort, Node};
12
13use thiserror::Error;
14
15use super::inline_dfg::InlineDFGError;
16
17/// Specification of a simple replacement operation.
18///
19/// # Type parameters
20///
21/// - `N`: The type of nodes in the host hugr.
22#[derive(Debug, Clone)]
23pub struct SimpleReplacement<HostNode = Node> {
24    /// The subgraph of the host hugr to be replaced.
25    subgraph: SiblingSubgraph<HostNode>,
26    /// A hugr with DFG root (consisting of replacement nodes).
27    replacement: Hugr,
28    /// A map from (target ports of edges from the Input node of `replacement`) to (target ports of
29    /// edges from nodes not in `removal` to nodes in `removal`).
30    nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>,
31    /// A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to
32    /// (input ports of the Output node of `replacement`).
33    nu_out: HashMap<(HostNode, IncomingPort), IncomingPort>,
34}
35
36impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
37    /// Create a new [`SimpleReplacement`] specification.
38    #[inline]
39    pub fn new(
40        subgraph: SiblingSubgraph<HostNode>,
41        replacement: Hugr,
42        nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>,
43        nu_out: HashMap<(HostNode, IncomingPort), IncomingPort>,
44    ) -> Self {
45        Self {
46            subgraph,
47            replacement,
48            nu_inp,
49            nu_out,
50        }
51    }
52
53    /// The replacement hugr.
54    #[inline]
55    pub fn replacement(&self) -> &Hugr {
56        &self.replacement
57    }
58
59    /// Subgraph to be replaced.
60    #[inline]
61    pub fn subgraph(&self) -> &SiblingSubgraph<HostNode> {
62        &self.subgraph
63    }
64}
65
66impl Rewrite for SimpleReplacement {
67    type Error = SimpleReplacementError;
68    type ApplyResult = Vec<(Node, OpType)>;
69    const UNCHANGED_ON_FAILURE: bool = true;
70
71    fn verify(&self, _h: &impl HugrView<Node = Node>) -> Result<(), SimpleReplacementError> {
72        unimplemented!()
73    }
74
75    fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
76        let Self {
77            subgraph,
78            replacement,
79            nu_inp,
80            nu_out,
81        } = self;
82        let parent = subgraph.get_parent(h);
83        // 1. Check the parent node exists and is a DataflowParent.
84        if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
85            return Err(SimpleReplacementError::InvalidParentNode());
86        }
87        // 2. Check that all the to-be-removed nodes are children of it and are leaves.
88        for node in subgraph.nodes() {
89            if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
90                return Err(SimpleReplacementError::InvalidRemovedNode());
91            }
92        }
93
94        let replacement_output_node = replacement
95            .get_io(replacement.root())
96            .expect("parent already checked.")[1];
97
98        // 3. Do the replacement.
99        // Now we proceed to connect the edges between the newly inserted
100        // replacement and the rest of the graph.
101        //
102        // Existing connections to the removed subgraph will be automatically
103        // removed when the nodes are removed.
104
105        // 3.1. For each p = self.nu_inp[q] such that q is not an Output port, add an edge from the
106        // predecessor of p to (the new copy of) q.
107        let nu_inp_connects: Vec<_> = nu_inp
108            .iter()
109            .filter(|&((rep_inp_node, _), _)| {
110                replacement.get_optype(*rep_inp_node).tag() != OpTag::Output
111            })
112            .map(
113                |((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| {
114                    // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port)
115                    let (rem_inp_pred_node, rem_inp_pred_port) = h
116                        .single_linked_output(*rem_inp_node, *rem_inp_port)
117                        .unwrap();
118                    (
119                        rem_inp_pred_node,
120                        rem_inp_pred_port,
121                        // the new input node will be updated after insertion
122                        rep_inp_node,
123                        rep_inp_port,
124                    )
125                },
126            )
127            .collect();
128
129        // 3.2. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an
130        // edge from (the new copy of) the predecessor of q to p.
131        let nu_out_connects: Vec<_> = nu_out
132            .iter()
133            .filter_map(|((rem_out_node, rem_out_port), rep_out_port)| {
134                let (rep_out_pred_node, rep_out_pred_port) = replacement
135                    .single_linked_output(replacement_output_node, *rep_out_port)
136                    .unwrap();
137                (replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input).then_some({
138                    (
139                        // the new output node will be updated after insertion
140                        rep_out_pred_node,
141                        rep_out_pred_port,
142                        rem_out_node,
143                        rem_out_port,
144                    )
145                })
146            })
147            .collect();
148
149        // 3.3. Insert the replacement as a whole.
150        let InsertionResult {
151            new_root,
152            node_map: index_map,
153        } = h.insert_hugr(parent, replacement);
154
155        // remove the Input and Output nodes from the replacement graph
156        let replace_children = h.children(new_root).collect::<Vec<Node>>();
157        for &io in &replace_children[..2] {
158            h.remove_node(io);
159        }
160        // make all replacement top level children children of the parent
161        for &child in &replace_children[2..] {
162            h.set_parent(child, parent);
163        }
164        // remove the replacement root (which now has no children and no edges)
165        h.remove_node(new_root);
166
167        // 3.4. Update replacement nodes according to insertion mapping and connect
168        for (src_node, src_port, tgt_node, tgt_port) in nu_inp_connects {
169            h.connect(
170                src_node,
171                src_port,
172                *index_map.get(tgt_node).unwrap(),
173                *tgt_port,
174            )
175        }
176
177        for (src_node, src_port, tgt_node, tgt_port) in nu_out_connects {
178            h.connect(
179                *index_map.get(&src_node).unwrap(),
180                src_port,
181                *tgt_node,
182                *tgt_port,
183            )
184        }
185        // 3.5. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0
186        // to p1.
187        //
188        // i.e. the replacement graph has direct edges between the input and output nodes.
189        for ((rem_out_node, rem_out_port), &rep_out_port) in &nu_out {
190            let rem_inp_nodeport = nu_inp.get(&(replacement_output_node, rep_out_port));
191            if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport {
192                // add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port):
193                let (rem_inp_pred_node, rem_inp_pred_port) = h
194                    .single_linked_output(*rem_inp_node, *rem_inp_port)
195                    .unwrap();
196
197                h.connect(
198                    rem_inp_pred_node,
199                    rem_inp_pred_port,
200                    *rem_out_node,
201                    *rem_out_port,
202                );
203            }
204        }
205
206        // 3.6. Remove all nodes in subgraph and edges between them.
207        Ok(subgraph
208            .nodes()
209            .iter()
210            .map(|&node| (node, h.remove_node(node)))
211            .collect())
212    }
213
214    #[inline]
215    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
216        let subcirc = self.subgraph.nodes().iter().copied();
217        let out_neighs = self.nu_out.keys().map(|key| key.0);
218        subcirc.chain(out_neighs)
219    }
220}
221
222/// Error from a [`SimpleReplacement`] operation.
223#[derive(Debug, Clone, Error, PartialEq, Eq)]
224#[non_exhaustive]
225pub enum SimpleReplacementError {
226    /// Invalid parent node.
227    #[error("Parent node is invalid.")]
228    InvalidParentNode(),
229    /// Node requested for removal is invalid.
230    #[error("A node requested for removal is invalid.")]
231    InvalidRemovedNode(),
232    /// Node in replacement graph is invalid.
233    #[error("A node in the replacement graph is invalid.")]
234    InvalidReplacementNode(),
235    /// Inlining replacement failed.
236    #[error("Inlining replacement failed: {0}")]
237    InliningFailed(#[from] InlineDFGError),
238}
239
240#[cfg(test)]
241pub(in crate::hugr::rewrite) mod test {
242    use itertools::Itertools;
243    use rstest::{fixture, rstest};
244    use std::collections::{HashMap, HashSet};
245
246    use crate::builder::test::n_identity;
247    use crate::builder::{
248        endo_sig, inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr,
249        DataflowSubContainer, HugrBuilder, ModuleBuilder,
250    };
251    use crate::extension::prelude::{bool_t, qb_t};
252    use crate::extension::ExtensionSet;
253    use crate::hugr::views::{HugrView, SiblingSubgraph};
254    use crate::hugr::{Hugr, HugrMut, Rewrite};
255    use crate::ops::dataflow::DataflowOpTrait;
256    use crate::ops::handle::NodeHandle;
257    use crate::ops::OpTag;
258    use crate::ops::OpTrait;
259    use crate::std_extensions::logic::test::and_op;
260    use crate::std_extensions::logic::LogicOp;
261    use crate::types::{Signature, Type};
262    use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID};
263    use crate::{IncomingPort, Node};
264
265    use super::SimpleReplacement;
266
267    /// Creates a hugr like the following:
268    /// --   H   --
269    /// -- [DFG] --
270    /// where [DFG] is:
271    /// ┌───┐     ┌───┐
272    /// ┤ H ├──■──┤ H ├
273    /// ├───┤┌─┴─┐├───┤
274    /// ┤ H ├┤ X ├┤ H ├
275    /// └───┘└───┘└───┘
276    fn make_hugr() -> Result<Hugr, BuildError> {
277        let mut module_builder = ModuleBuilder::new();
278        let _f_id = {
279            let just_q: ExtensionSet = EXTENSION_ID.into();
280            let mut func_builder = module_builder.define_function(
281                "main",
282                Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])
283                    .with_extension_delta(just_q.clone()),
284            )?;
285
286            let [qb0, qb1, qb2] = func_builder.input_wires_arr();
287
288            let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?;
289
290            let mut inner_builder =
291                func_builder.dfg_builder_endo([(qb_t(), qb0), (qb_t(), qb1)])?;
292            let inner_graph = {
293                let [wire0, wire1] = inner_builder.input_wires_arr();
294                let wire2 = inner_builder.add_dataflow_op(h_gate(), vec![wire0])?;
295                let wire3 = inner_builder.add_dataflow_op(h_gate(), vec![wire1])?;
296                let wire45 = inner_builder
297                    .add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
298                let [wire4, wire5] = wire45.outputs_arr();
299                let wire6 = inner_builder.add_dataflow_op(h_gate(), vec![wire4])?;
300                let wire7 = inner_builder.add_dataflow_op(h_gate(), vec![wire5])?;
301                inner_builder.finish_with_outputs(wire6.outputs().chain(wire7.outputs()))
302            }?;
303
304            func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))?
305        };
306        Ok(module_builder.finish_hugr()?)
307    }
308
309    #[fixture]
310    pub(in crate::hugr::rewrite) fn simple_hugr() -> Hugr {
311        make_hugr().unwrap()
312    }
313    /// Creates a hugr with a DFG root like the following:
314    /// ┌───┐
315    /// ┤ H ├──■──
316    /// ├───┤┌─┴─┐
317    /// ┤ H ├┤ X ├
318    /// └───┘└───┘
319    fn make_dfg_hugr() -> Result<Hugr, BuildError> {
320        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]).with_prelude())?;
321        let [wire0, wire1] = dfg_builder.input_wires_arr();
322        let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?;
323        let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
324        let wire45 =
325            dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
326        dfg_builder.finish_hugr_with_outputs(wire45.outputs())
327    }
328
329    #[fixture]
330    pub(in crate::hugr::rewrite) fn dfg_hugr() -> Hugr {
331        make_dfg_hugr().unwrap()
332    }
333
334    /// Creates a hugr with a DFG root like the following:
335    /// ─────
336    /// ┌───┐
337    /// ┤ H ├
338    /// └───┘
339    fn make_dfg_hugr2() -> Result<Hugr, BuildError> {
340        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
341
342        let [wire0, wire1] = dfg_builder.input_wires_arr();
343        let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
344        let wire2out = wire2.outputs().exactly_one().unwrap();
345        let wireoutvec = vec![wire0, wire2out];
346        dfg_builder.finish_hugr_with_outputs(wireoutvec)
347    }
348
349    #[fixture]
350    pub(in crate::hugr::rewrite) fn dfg_hugr2() -> Hugr {
351        make_dfg_hugr2().unwrap()
352    }
353
354    /// A hugr with a DFG root mapping bool_t() to (bool_t(), bool_t())
355    ///                     ┌─────────┐
356    ///                ┌────┤ (1) NOT ├──
357    ///  ┌─────────┐   │    └─────────┘
358    /// ─┤ (0) NOT ├───┤
359    ///  └─────────┘   │    ┌─────────┐
360    ///                └────┤ (2) NOT ├──
361    ///                     └─────────┘
362    /// This can be replaced with an empty hugr coping the input to both outputs.
363    ///
364    /// Returns the hugr and the nodes of the NOT gates, in order.
365    #[fixture]
366    pub(in crate::hugr::rewrite) fn dfg_hugr_copy_bools() -> (Hugr, Vec<Node>) {
367        let mut dfg_builder =
368            DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
369        let [b] = dfg_builder.input_wires_arr();
370
371        let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
372        let [b] = not_inp.outputs_arr();
373
374        let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
375        let [b0] = not_0.outputs_arr();
376        let not_1 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
377        let [b1] = not_1.outputs_arr();
378
379        (
380            dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
381            vec![not_inp.node(), not_0.node(), not_1.node()],
382        )
383    }
384
385    /// A hugr with a DFG root mapping bool_t() to (bool_t(), bool_t())
386    ///                     ┌─────────┐
387    ///                ┌────┤ (1) NOT ├──
388    ///  ┌─────────┐   │    └─────────┘
389    /// ─┤ (0) NOT ├───┤
390    ///  └─────────┘   │
391    ///                └─────────────────
392    ///
393    /// This can be replaced with a single NOT op, coping the input to the first output.
394    ///
395    /// Returns the hugr and the nodes of the NOT ops, in order.
396    #[fixture]
397    pub(in crate::hugr::rewrite) fn dfg_hugr_half_not_bools() -> (Hugr, Vec<Node>) {
398        let mut dfg_builder =
399            DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
400        let [b] = dfg_builder.input_wires_arr();
401
402        let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
403        let [b] = not_inp.outputs_arr();
404
405        let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
406        let [b0] = not_0.outputs_arr();
407        let b1 = b;
408
409        (
410            dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
411            vec![not_inp.node(), not_0.node()],
412        )
413    }
414
    #[rstest]
    /// Replace the
    ///      ┌───┐
    /// ──■──┤ H ├
    /// ┌─┴─┐├───┤
    /// ┤ X ├┤ H ├
    /// └───┘└───┘
    /// part of
    /// ┌───┐     ┌───┐
    /// ┤ H ├──■──┤ H ├
    /// ├───┤┌─┴─┐├───┤
    /// ┤ H ├┤ X ├┤ H ├
    /// └───┘└───┘└───┘
    /// with
    /// ┌───┐
    /// ┤ H ├──■──
    /// ├───┤┌─┴─┐
    /// ┤ H ├┤ X ├
    /// └───┘└───┘
    ///
    /// The test runs once per applicator listed in `#[values(...)]`, and checks that the
    /// resulting hugr validates.
    fn test_simple_replacement(
        simple_hugr: Hugr,
        dfg_hugr: Hugr,
        #[values(apply_simple, apply_replace)] applicator: impl Fn(&mut Hugr, SimpleReplacement),
    ) {
        let mut h: Hugr = simple_hugr;
        // 1. Locate the CX and its successor H's in h
        let h_node_cx: Node = h
            .nodes()
            .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
            .unwrap();
        // NOTE(review): relies on the neighbour iterator order matching the qubit order.
        let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
        let s: Vec<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
        // 2. Construct a new DFG-rooted hugr for the replacement
        let n: Hugr = dfg_hugr;
        // 3. Construct the input and output matchings
        // 3.1. Locate the CX and its predecessor H's in n
        let n_node_cx = n
            .nodes()
            .find(|node: &Node| *n.get_optype(*node) == cx_gate().into())
            .unwrap();
        let (n_node_h0, n_node_h1) = n.input_neighbours(n_node_cx).collect_tuple().unwrap();
        // 3.2. Locate the ports we need to specify as "glue" in n
        // (inputs of the leading H's, and the Output ports fed by the CX).
        let n_port_0 = n.node_inputs(n_node_h0).next().unwrap();
        let n_port_1 = n.node_inputs(n_node_h1).next().unwrap();
        let (n_cx_out_0, n_cx_out_1) = n.node_outputs(n_node_cx).take(2).collect_tuple().unwrap();
        let n_port_2 = n.linked_inputs(n_node_cx, n_cx_out_0).next().unwrap().1;
        let n_port_3 = n.linked_inputs(n_node_cx, n_cx_out_1).next().unwrap().1;
        // 3.3. Locate the ports we need to specify as "glue" in h
        // (the CX inputs, and the ports of the nested DFG's Output fed by the trailing H's).
        let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
        let h_h0_out = h.node_outputs(h_node_h0).next().unwrap();
        let h_h1_out = h.node_outputs(h_node_h1).next().unwrap();
        let (h_outp_node, h_port_2) = h.linked_inputs(h_node_h0, h_h0_out).next().unwrap();
        let h_port_3 = h.linked_inputs(h_node_h1, h_h1_out).next().unwrap().1;
        // 3.4. Construct the maps
        let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
        let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
        nu_inp.insert((n_node_h0, n_port_0), (h_node_cx, h_port_0));
        nu_inp.insert((n_node_h1, n_port_1), (h_node_cx, h_port_1));
        nu_out.insert((h_outp_node, h_port_2), n_port_2);
        nu_out.insert((h_outp_node, h_port_3), n_port_3);
        // 4. Define the replacement
        let r = SimpleReplacement {
            subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
            replacement: n,
            nu_inp,
            nu_out,
        };
        // The invalidation set is the removed nodes plus the host nodes keyed in `nu_out`
        // (here the nested DFG's Output node).
        assert_eq!(
            HashSet::<_>::from_iter(r.invalidation_set()),
            HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1, h_outp_node]),
        );

        applicator(&mut h, r);
        // Expect [DFG] to be replaced with:
        // ┌───┐┌───┐
        // ┤ H ├┤ H ├──■──
        // ├───┤├───┤┌─┴─┐
        // ┤ H ├┤ H ├┤ X ├
        // └───┘└───┘└───┘
        assert_eq!(h.validate(), Ok(()));
    }
496
    #[rstest]
    /// Replace the
    ///
    /// ──■──
    /// ┌─┴─┐
    /// ┤ X ├
    /// └───┘
    /// part of
    /// ┌───┐     ┌───┐
    /// ┤ H ├──■──┤ H ├
    /// ├───┤┌─┴─┐├───┤
    /// ┤ H ├┤ X ├┤ H ├
    /// └───┘└───┘└───┘
    /// with
    /// ─────
    /// ┌───┐
    /// ┤ H ├
    /// └───┘
    ///
    /// Exercises a replacement in which one wire goes directly from the replacement's Input
    /// to its Output (an "empty wire").
    fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
        let mut h: Hugr = simple_hugr;

        // 1. Locate the CX in h
        let h_node_cx: Node = h
            .nodes()
            .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
            .unwrap();
        let s: Vec<Node> = vec![h_node_cx].into_iter().collect();
        // 2. Construct a new DFG-rooted hugr for the replacement
        let n: Hugr = dfg_hugr2;
        // 3. Construct the input and output matchings
        // 3.1. Locate the Output and its predecessor H in n
        let n_node_output = n
            .nodes()
            .find(|node: &Node| n.get_optype(*node).tag() == OpTag::Output)
            .unwrap();
        // NOTE(review): assumes the Output's input neighbours come out in port order
        // (Input node first, then the H).
        let (_n_node_input, n_node_h) = n.input_neighbours(n_node_output).collect_tuple().unwrap();
        // 3.2. Locate the ports we need to specify as "glue" in n
        let (n_port_0, n_port_1) = n
            .node_inputs(n_node_output)
            .take(2)
            .collect_tuple()
            .unwrap();
        let n_port_2 = n.node_inputs(n_node_h).next().unwrap();
        // 3.3. Locate the ports we need to specify as "glue" in h
        let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
        let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
        let h_port_2 = h.node_inputs(h_node_h0).next().unwrap();
        let h_port_3 = h.node_inputs(h_node_h1).next().unwrap();
        // 3.4. Construct the maps
        let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
        let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
        // Keyed on the replacement's Output port: this is the empty wire Input -> Output.
        nu_inp.insert((n_node_output, n_port_0), (h_node_cx, h_port_0));
        nu_inp.insert((n_node_h, n_port_2), (h_node_cx, h_port_1));
        nu_out.insert((h_node_h0, h_port_2), n_port_0);
        nu_out.insert((h_node_h1, h_port_3), n_port_1);
        // 4. Define the replacement
        let r = SimpleReplacement {
            subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
            replacement: n,
            nu_inp,
            nu_out,
        };
        h.apply_rewrite(r).unwrap();
        // Expect [DFG] to be replaced with:
        // ┌───┐┌───┐
        // ┤ H ├┤ H ├
        // ├───┤├───┤┌───┐
        // ┤ H ├┤ H ├┤ H ├
        // └───┘└───┘└───┘
        assert_eq!(h.validate(), Ok(()));
    }
568
569    #[test]
570    fn test_replace_cx_cross() {
571        let q_row: Vec<Type> = vec![qb_t(), qb_t()];
572        let mut builder = DFGBuilder::new(endo_sig(q_row)).unwrap();
573        let mut circ = builder.as_circuit(builder.input_wires());
574        circ.append(cx_gate(), [0, 1]).unwrap();
575        circ.append(cx_gate(), [1, 0]).unwrap();
576        let wires = circ.finish();
577        let [input, output] = builder.io();
578        let mut h = builder.finish_hugr_with_outputs(wires).unwrap();
579        let replacement = h.clone();
580        let orig = h.clone();
581
582        let removal = h
583            .nodes()
584            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
585            .collect_vec();
586        let inputs = h
587            .node_outputs(input)
588            .filter(|&p| {
589                h.get_optype(input)
590                    .as_input()
591                    .unwrap()
592                    .signature()
593                    .port_type(p)
594                    .is_some()
595            })
596            .map(|p| {
597                let link = h.linked_inputs(input, p).next().unwrap();
598                (link, link)
599            })
600            .collect();
601        let outputs = h
602            .node_inputs(output)
603            .filter(|&p| {
604                h.get_optype(output)
605                    .as_output()
606                    .unwrap()
607                    .signature()
608                    .port_type(p)
609                    .is_some()
610            })
611            .map(|p| ((output, p), p))
612            .collect();
613        h.apply_rewrite(SimpleReplacement::new(
614            SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
615            replacement,
616            inputs,
617            outputs,
618        ))
619        .unwrap();
620
621        // They should be the same, up to node indices
622        assert_eq!(h.edge_count(), orig.edge_count());
623    }
624
625    #[test]
626    fn test_replace_after_copy() {
627        let one_bit = vec![bool_t()];
628        let two_bit = vec![bool_t(), bool_t()];
629
630        let mut builder = DFGBuilder::new(endo_sig(one_bit.clone())).unwrap();
631        let inw = builder.input_wires().exactly_one().unwrap();
632        let outw = builder
633            .add_dataflow_op(and_op(), [inw, inw])
634            .unwrap()
635            .outputs();
636        let [input, _] = builder.io();
637        let mut h = builder.finish_hugr_with_outputs(outw).unwrap();
638
639        let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap();
640        let inw = builder.input_wires();
641        let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs();
642        let [repl_input, repl_output] = builder.io();
643        let repl = builder.finish_hugr_with_outputs(outw).unwrap();
644
645        let orig = h.clone();
646
647        let removal = h
648            .nodes()
649            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
650            .collect_vec();
651
652        let first_out_p = h.node_outputs(input).next().unwrap();
653        let embedded_inputs = h.linked_inputs(input, first_out_p);
654        let repl_inputs = repl
655            .node_outputs(repl_input)
656            .map(|p| repl.linked_inputs(repl_input, p).next().unwrap());
657        let inputs = embedded_inputs.zip(repl_inputs).collect();
658
659        let outputs = repl
660            .node_inputs(repl_output)
661            .filter(|&p| repl.signature(repl_output).unwrap().port_type(p).is_some())
662            .map(|p| ((repl_output, p), p))
663            .collect();
664
665        h.apply_rewrite(SimpleReplacement::new(
666            SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
667            repl,
668            inputs,
669            outputs,
670        ))
671        .unwrap();
672
673        // Nothing changed
674        assert_eq!(h.node_count(), orig.node_count());
675    }
676
677    /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the input
678    /// directly to the outputs.
679    ///
680    /// https://github.com/CQCL/hugr/issues/1190
681    #[rstest]
682    fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec<Node>)) {
683        let (mut hugr, nodes) = dfg_hugr_copy_bools;
684        let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap();
685
686        let [_input, output] = hugr.get_io(hugr.root()).unwrap();
687
688        let replacement = {
689            let b =
690                DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
691            let [w] = b.input_wires_arr();
692            b.finish_hugr_with_outputs([w, w]).unwrap()
693        };
694        let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();
695
696        let subgraph =
697            SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)
698                .unwrap();
699        // A map from (target ports of edges from the Input node of `replacement`) to (target ports of
700        // edges from nodes not in `removal` to nodes in `removal`).
701        let nu_inp = [
702            (
703                (repl_output, IncomingPort::from(0)),
704                (input_not, IncomingPort::from(0)),
705            ),
706            (
707                (repl_output, IncomingPort::from(1)),
708                (input_not, IncomingPort::from(0)),
709            ),
710        ]
711        .into_iter()
712        .collect();
713        // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to
714        // (input ports of the Output node of `replacement`).
715        let nu_out = [
716            ((output, IncomingPort::from(0)), IncomingPort::from(0)),
717            ((output, IncomingPort::from(1)), IncomingPort::from(1)),
718        ]
719        .into_iter()
720        .collect();
721
722        let rewrite = SimpleReplacement {
723            subgraph,
724            replacement,
725            nu_inp,
726            nu_out,
727        };
728        rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
729
730        assert_eq!(hugr.validate(), Ok(()));
731        assert_eq!(hugr.node_count(), 3);
732    }
733
734    /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting the input
735    /// directly to the output.
736    ///
737    /// https://github.com/CQCL/hugr/issues/1323
738    #[rstest]
739    fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec<Node>)) {
740        let (mut hugr, nodes) = dfg_hugr_half_not_bools;
741        let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap();
742
743        let [_input, output] = hugr.get_io(hugr.root()).unwrap();
744
745        let (replacement, repl_not) = {
746            let mut b =
747                DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
748            let [w] = b.input_wires_arr();
749            let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
750            let [w_not] = not.outputs_arr();
751            (b.finish_hugr_with_outputs([w, w_not]).unwrap(), not.node())
752        };
753        let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();
754
755        let subgraph =
756            SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap();
757        // A map from (target ports of edges from the Input node of `replacement`) to (target ports of
758        // edges from nodes not in `removal` to nodes in `removal`).
759        let nu_inp = [
760            (
761                (repl_output, IncomingPort::from(0)),
762                (input_not, IncomingPort::from(0)),
763            ),
764            (
765                (repl_not, IncomingPort::from(0)),
766                (input_not, IncomingPort::from(0)),
767            ),
768        ]
769        .into_iter()
770        .collect();
771        // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to
772        // (input ports of the Output node of `replacement`).
773        let nu_out = [
774            ((output, IncomingPort::from(0)), IncomingPort::from(0)),
775            ((output, IncomingPort::from(1)), IncomingPort::from(1)),
776        ]
777        .into_iter()
778        .collect();
779
780        let rewrite = SimpleReplacement {
781            subgraph,
782            replacement,
783            nu_inp,
784            nu_out,
785        };
786        rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
787
788        assert_eq!(hugr.validate(), Ok(()));
789        assert_eq!(hugr.node_count(), 4);
790    }
791
792    #[rstest]
793    fn test_nested_replace(dfg_hugr2: Hugr) {
794        // replace a node with a hugr with children
795
796        let mut h = dfg_hugr2;
797        let h_node = h
798            .nodes()
799            .find(|node: &Node| *h.get_optype(*node) == h_gate().into())
800            .unwrap();
801
802        // build a nested identity dfg
803        let mut nest_build = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
804        let [input] = nest_build.input_wires_arr();
805        let inner_build = nest_build.dfg_builder_endo([(qb_t(), input)]).unwrap();
806        let inner_dfg = n_identity(inner_build).unwrap();
807        let inner_dfg_node = inner_dfg.node();
808        let replacement = nest_build
809            .finish_hugr_with_outputs([inner_dfg.out_wire(0)])
810            .unwrap();
811        let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
812        let nu_inp = vec![(
813            (inner_dfg_node, IncomingPort::from(0)),
814            (h_node, IncomingPort::from(0)),
815        )]
816        .into_iter()
817        .collect();
818
819        let nu_out = vec![(
820            (h.get_io(h.root()).unwrap()[1], IncomingPort::from(1)),
821            IncomingPort::from(0),
822        )]
823        .into_iter()
824        .collect();
825
826        let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out);
827
828        assert_eq!(h.node_count(), 4);
829
830        rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
831        h.validate().unwrap_or_else(|e| panic!("{e}"));
832
833        assert_eq!(h.node_count(), 6);
834    }
835
836    use crate::hugr::rewrite::replace::Replacement;
837    fn to_replace(h: &impl HugrView<Node = Node>, s: SimpleReplacement) -> Replacement {
838        use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec};
839
840        let mut replacement = s.replacement;
841        let (in_, out) = replacement
842            .children(replacement.root())
843            .take(2)
844            .collect_tuple()
845            .unwrap();
846        let mu_inp = s
847            .nu_inp
848            .iter()
849            .map(|((tgt, tgt_port), (r_n, r_p))| {
850                if *tgt == out {
851                    unimplemented!()
852                };
853                let (src, src_port) = h.single_linked_output(*r_n, *r_p).unwrap();
854                NewEdgeSpec {
855                    src,
856                    tgt: *tgt,
857                    kind: NewEdgeKind::Value {
858                        src_pos: src_port,
859                        tgt_pos: *tgt_port,
860                    },
861                }
862            })
863            .collect();
864        let mu_out = s
865            .nu_out
866            .iter()
867            .map(|((tgt, tgt_port), out_port)| {
868                let (src, src_port) = replacement.single_linked_output(out, *out_port).unwrap();
869                if src == in_ {
870                    unimplemented!()
871                };
872                NewEdgeSpec {
873                    src,
874                    tgt: *tgt,
875                    kind: NewEdgeKind::Value {
876                        src_pos: src_port,
877                        tgt_pos: *tgt_port,
878                    },
879                }
880            })
881            .collect();
882        replacement.remove_node(in_);
883        replacement.remove_node(out);
884        Replacement {
885            removal: s.subgraph.nodes().to_vec(),
886            replacement,
887            adoptions: HashMap::new(),
888            mu_inp,
889            mu_out,
890            mu_new: vec![],
891        }
892    }
893
894    fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) {
895        h.apply_rewrite(rw).unwrap();
896    }
897
898    fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
899        h.apply_rewrite(to_replace(h, rw)).unwrap();
900    }
901}