hugr_core/hugr/rewrite/
simple_replace.rs

1//! Implementation of the `SimpleReplace` operation.
2
3use std::collections::HashMap;
4
5use crate::hugr::hugrmut::InsertionResult;
6pub use crate::hugr::internal::HugrMutInternals;
7use crate::hugr::views::SiblingSubgraph;
8use crate::hugr::{HugrMut, HugrView, Rewrite};
9use crate::ops::{OpTag, OpTrait, OpType};
10use crate::{Hugr, IncomingPort, Node};
11use thiserror::Error;
12
13use super::inline_dfg::InlineDFGError;
14
/// Specification of a simple replacement operation.
///
/// Bundles the subgraph to remove, the hugr whose contents take its place,
/// and the two maps used to reconnect the boundary edges.
#[derive(Debug, Clone)]
pub struct SimpleReplacement {
    /// The subgraph of the hugr to be replaced.
    subgraph: SiblingSubgraph,
    /// A hugr with DFG root (consisting of replacement nodes).
    replacement: Hugr,
    /// A map from (target ports of edges from the Input node of `replacement`) to (target ports of
    /// edges from nodes not in `subgraph` to nodes in `subgraph`).
    ///
    /// Keys are `(node, port)` pairs in `replacement`; values are `(node, port)` pairs in the
    /// hugr being rewritten.
    nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)>,
    /// A map from (target ports of edges from nodes in `subgraph` to nodes not in `subgraph`) to
    /// (input ports of the Output node of `replacement`).
    ///
    /// Keys are `(node, port)` pairs in the hugr being rewritten.
    nu_out: HashMap<(Node, IncomingPort), IncomingPort>,
}
29
30impl SimpleReplacement {
31    /// Create a new [`SimpleReplacement`] specification.
32    #[inline]
33    pub fn new(
34        subgraph: SiblingSubgraph,
35        replacement: Hugr,
36        nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)>,
37        nu_out: HashMap<(Node, IncomingPort), IncomingPort>,
38    ) -> Self {
39        Self {
40            subgraph,
41            replacement,
42            nu_inp,
43            nu_out,
44        }
45    }
46
47    /// The replacement hugr.
48    #[inline]
49    pub fn replacement(&self) -> &Hugr {
50        &self.replacement
51    }
52
53    /// Subgraph to be replaced.
54    #[inline]
55    pub fn subgraph(&self) -> &SiblingSubgraph {
56        &self.subgraph
57    }
58}
59
60impl Rewrite for SimpleReplacement {
61    type Error = SimpleReplacementError;
62    type ApplyResult = Vec<(Node, OpType)>;
63    const UNCHANGED_ON_FAILURE: bool = true;
64
65    fn verify(&self, _h: &impl HugrView) -> Result<(), SimpleReplacementError> {
66        unimplemented!()
67    }
68
69    fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
70        let Self {
71            subgraph,
72            replacement,
73            nu_inp,
74            nu_out,
75        } = self;
76        let parent = subgraph.get_parent(h);
77        // 1. Check the parent node exists and is a DataflowParent.
78        if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
79            return Err(SimpleReplacementError::InvalidParentNode());
80        }
81        // 2. Check that all the to-be-removed nodes are children of it and are leaves.
82        for node in subgraph.nodes() {
83            if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
84                return Err(SimpleReplacementError::InvalidRemovedNode());
85            }
86        }
87
88        let replacement_output_node = replacement
89            .get_io(replacement.root())
90            .expect("parent already checked.")[1];
91
92        // 3. Do the replacement.
93        // Now we proceed to connect the edges between the newly inserted
94        // replacement and the rest of the graph.
95        //
96        // Existing connections to the removed subgraph will be automatically
97        // removed when the nodes are removed.
98
99        // 3.1. For each p = self.nu_inp[q] such that q is not an Output port, add an edge from the
100        // predecessor of p to (the new copy of) q.
101        let nu_inp_connects: Vec<_> = nu_inp
102            .iter()
103            .filter(|&((rep_inp_node, _), _)| {
104                replacement.get_optype(*rep_inp_node).tag() != OpTag::Output
105            })
106            .map(
107                |((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| {
108                    // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port)
109                    let (rem_inp_pred_node, rem_inp_pred_port) = h
110                        .single_linked_output(*rem_inp_node, *rem_inp_port)
111                        .unwrap();
112                    (
113                        rem_inp_pred_node,
114                        rem_inp_pred_port,
115                        // the new input node will be updated after insertion
116                        rep_inp_node,
117                        rep_inp_port,
118                    )
119                },
120            )
121            .collect();
122
123        // 3.2. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an
124        // edge from (the new copy of) the predecessor of q to p.
125        let nu_out_connects: Vec<_> = nu_out
126            .iter()
127            .filter_map(|((rem_out_node, rem_out_port), rep_out_port)| {
128                let (rep_out_pred_node, rep_out_pred_port) = replacement
129                    .single_linked_output(replacement_output_node, *rep_out_port)
130                    .unwrap();
131                (replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input).then_some({
132                    (
133                        // the new output node will be updated after insertion
134                        rep_out_pred_node,
135                        rep_out_pred_port,
136                        rem_out_node,
137                        rem_out_port,
138                    )
139                })
140            })
141            .collect();
142
143        // 3.3. Insert the replacement as a whole.
144        let InsertionResult {
145            new_root,
146            node_map: index_map,
147        } = h.insert_hugr(parent, replacement);
148
149        // remove the Input and Output nodes from the replacement graph
150        let replace_children = h.children(new_root).collect::<Vec<Node>>();
151        for &io in &replace_children[..2] {
152            h.remove_node(io);
153        }
154        // make all replacement top level children children of the parent
155        for &child in &replace_children[2..] {
156            h.set_parent(child, parent);
157        }
158        // remove the replacement root (which now has no children and no edges)
159        h.remove_node(new_root);
160
161        // 3.4. Update replacement nodes according to insertion mapping and connect
162        for (src_node, src_port, tgt_node, tgt_port) in nu_inp_connects {
163            h.connect(
164                src_node,
165                src_port,
166                *index_map.get(tgt_node).unwrap(),
167                *tgt_port,
168            )
169        }
170
171        for (src_node, src_port, tgt_node, tgt_port) in nu_out_connects {
172            h.connect(
173                *index_map.get(&src_node).unwrap(),
174                src_port,
175                *tgt_node,
176                *tgt_port,
177            )
178        }
179        // 3.5. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0
180        // to p1.
181        //
182        // i.e. the replacement graph has direct edges between the input and output nodes.
183        for ((rem_out_node, rem_out_port), &rep_out_port) in &nu_out {
184            let rem_inp_nodeport = nu_inp.get(&(replacement_output_node, rep_out_port));
185            if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport {
186                // add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port):
187                let (rem_inp_pred_node, rem_inp_pred_port) = h
188                    .single_linked_output(*rem_inp_node, *rem_inp_port)
189                    .unwrap();
190
191                h.connect(
192                    rem_inp_pred_node,
193                    rem_inp_pred_port,
194                    *rem_out_node,
195                    *rem_out_port,
196                );
197            }
198        }
199
200        // 3.6. Remove all nodes in subgraph and edges between them.
201        Ok(subgraph
202            .nodes()
203            .iter()
204            .map(|&node| (node, h.remove_node(node)))
205            .collect())
206    }
207
208    #[inline]
209    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
210        let subcirc = self.subgraph.nodes().iter().copied();
211        let out_neighs = self.nu_out.keys().map(|key| key.0);
212        subcirc.chain(out_neighs)
213    }
214}
215
/// Error from a [`SimpleReplacement`] operation.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum SimpleReplacementError {
    /// Invalid parent node.
    ///
    /// Returned by `apply` when the parent of the subgraph is not a dataflow parent.
    #[error("Parent node is invalid.")]
    InvalidParentNode(),
    /// Node requested for removal is invalid.
    ///
    /// Returned by `apply` when a subgraph node is not a child of the parent,
    /// or has children of its own.
    #[error("A node requested for removal is invalid.")]
    InvalidRemovedNode(),
    /// Node in replacement graph is invalid.
    #[error("A node in the replacement graph is invalid.")]
    InvalidReplacementNode(),
    /// Inlining replacement failed.
    ///
    /// Wraps an [`InlineDFGError`]; the `From` conversion is derived.
    #[error("Inlining replacement failed: {0}")]
    InliningFailed(#[from] InlineDFGError),
}
233
234#[cfg(test)]
235pub(in crate::hugr::rewrite) mod test {
236    use itertools::Itertools;
237    use rstest::{fixture, rstest};
238    use std::collections::{HashMap, HashSet};
239
240    use crate::builder::test::n_identity;
241    use crate::builder::{
242        endo_sig, inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr,
243        DataflowSubContainer, HugrBuilder, ModuleBuilder,
244    };
245    use crate::extension::prelude::{bool_t, qb_t};
246    use crate::extension::ExtensionSet;
247    use crate::hugr::views::{HugrView, SiblingSubgraph};
248    use crate::hugr::{Hugr, HugrMut, Rewrite};
249    use crate::ops::dataflow::DataflowOpTrait;
250    use crate::ops::handle::NodeHandle;
251    use crate::ops::OpTag;
252    use crate::ops::OpTrait;
253    use crate::std_extensions::logic::test::and_op;
254    use crate::std_extensions::logic::LogicOp;
255    use crate::types::{Signature, Type};
256    use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID};
257    use crate::{IncomingPort, Node};
258
259    use super::SimpleReplacement;
260
261    /// Creates a hugr like the following:
262    /// --   H   --
263    /// -- [DFG] --
264    /// where [DFG] is:
265    /// ┌───┐     ┌───┐
266    /// ┤ H ├──■──┤ H ├
267    /// ├───┤┌─┴─┐├───┤
268    /// ┤ H ├┤ X ├┤ H ├
269    /// └───┘└───┘└───┘
270    fn make_hugr() -> Result<Hugr, BuildError> {
271        let mut module_builder = ModuleBuilder::new();
272        let _f_id = {
273            let just_q: ExtensionSet = EXTENSION_ID.into();
274            let mut func_builder = module_builder.define_function(
275                "main",
276                Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])
277                    .with_extension_delta(just_q.clone()),
278            )?;
279
280            let [qb0, qb1, qb2] = func_builder.input_wires_arr();
281
282            let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?;
283
284            let mut inner_builder =
285                func_builder.dfg_builder_endo([(qb_t(), qb0), (qb_t(), qb1)])?;
286            let inner_graph = {
287                let [wire0, wire1] = inner_builder.input_wires_arr();
288                let wire2 = inner_builder.add_dataflow_op(h_gate(), vec![wire0])?;
289                let wire3 = inner_builder.add_dataflow_op(h_gate(), vec![wire1])?;
290                let wire45 = inner_builder
291                    .add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
292                let [wire4, wire5] = wire45.outputs_arr();
293                let wire6 = inner_builder.add_dataflow_op(h_gate(), vec![wire4])?;
294                let wire7 = inner_builder.add_dataflow_op(h_gate(), vec![wire5])?;
295                inner_builder.finish_with_outputs(wire6.outputs().chain(wire7.outputs()))
296            }?;
297
298            func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))?
299        };
300        Ok(module_builder.finish_hugr()?)
301    }
302
303    #[fixture]
304    pub(in crate::hugr::rewrite) fn simple_hugr() -> Hugr {
305        make_hugr().unwrap()
306    }
307    /// Creates a hugr with a DFG root like the following:
308    /// ┌───┐
309    /// ┤ H ├──■──
310    /// ├───┤┌─┴─┐
311    /// ┤ H ├┤ X ├
312    /// └───┘└───┘
313    fn make_dfg_hugr() -> Result<Hugr, BuildError> {
314        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]).with_prelude())?;
315        let [wire0, wire1] = dfg_builder.input_wires_arr();
316        let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?;
317        let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
318        let wire45 =
319            dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
320        dfg_builder.finish_hugr_with_outputs(wire45.outputs())
321    }
322
323    #[fixture]
324    pub(in crate::hugr::rewrite) fn dfg_hugr() -> Hugr {
325        make_dfg_hugr().unwrap()
326    }
327
328    /// Creates a hugr with a DFG root like the following:
329    /// ─────
330    /// ┌───┐
331    /// ┤ H ├
332    /// └───┘
333    fn make_dfg_hugr2() -> Result<Hugr, BuildError> {
334        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
335
336        let [wire0, wire1] = dfg_builder.input_wires_arr();
337        let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
338        let wire2out = wire2.outputs().exactly_one().unwrap();
339        let wireoutvec = vec![wire0, wire2out];
340        dfg_builder.finish_hugr_with_outputs(wireoutvec)
341    }
342
343    #[fixture]
344    pub(in crate::hugr::rewrite) fn dfg_hugr2() -> Hugr {
345        make_dfg_hugr2().unwrap()
346    }
347
348    /// A hugr with a DFG root mapping bool_t() to (bool_t(), bool_t())
349    ///                     ┌─────────┐
350    ///                ┌────┤ (1) NOT ├──
351    ///  ┌─────────┐   │    └─────────┘
352    /// ─┤ (0) NOT ├───┤
353    ///  └─────────┘   │    ┌─────────┐
354    ///                └────┤ (2) NOT ├──
355    ///                     └─────────┘
356    /// This can be replaced with an empty hugr coping the input to both outputs.
357    ///
358    /// Returns the hugr and the nodes of the NOT gates, in order.
359    #[fixture]
360    pub(in crate::hugr::rewrite) fn dfg_hugr_copy_bools() -> (Hugr, Vec<Node>) {
361        let mut dfg_builder =
362            DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
363        let [b] = dfg_builder.input_wires_arr();
364
365        let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
366        let [b] = not_inp.outputs_arr();
367
368        let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
369        let [b0] = not_0.outputs_arr();
370        let not_1 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
371        let [b1] = not_1.outputs_arr();
372
373        (
374            dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
375            vec![not_inp.node(), not_0.node(), not_1.node()],
376        )
377    }
378
379    /// A hugr with a DFG root mapping bool_t() to (bool_t(), bool_t())
380    ///                     ┌─────────┐
381    ///                ┌────┤ (1) NOT ├──
382    ///  ┌─────────┐   │    └─────────┘
383    /// ─┤ (0) NOT ├───┤
384    ///  └─────────┘   │
385    ///                └─────────────────
386    ///
387    /// This can be replaced with a single NOT op, coping the input to the first output.
388    ///
389    /// Returns the hugr and the nodes of the NOT ops, in order.
390    #[fixture]
391    pub(in crate::hugr::rewrite) fn dfg_hugr_half_not_bools() -> (Hugr, Vec<Node>) {
392        let mut dfg_builder =
393            DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
394        let [b] = dfg_builder.input_wires_arr();
395
396        let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
397        let [b] = not_inp.outputs_arr();
398
399        let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
400        let [b0] = not_0.outputs_arr();
401        let b1 = b;
402
403        (
404            dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
405            vec![not_inp.node(), not_0.node()],
406        )
407    }
408
409    #[rstest]
410    /// Replace the
411    ///      ┌───┐
412    /// ──■──┤ H ├
413    /// ┌─┴─┐├───┤
414    /// ┤ X ├┤ H ├
415    /// └───┘└───┘
416    /// part of
417    /// ┌───┐     ┌───┐
418    /// ┤ H ├──■──┤ H ├
419    /// ├───┤┌─┴─┐├───┤
420    /// ┤ H ├┤ X ├┤ H ├
421    /// └───┘└───┘└───┘
422    /// with
423    /// ┌───┐
424    /// ┤ H ├──■──
425    /// ├───┤┌─┴─┐
426    /// ┤ H ├┤ X ├
427    /// └───┘└───┘
428    fn test_simple_replacement(
429        simple_hugr: Hugr,
430        dfg_hugr: Hugr,
431        #[values(apply_simple, apply_replace)] applicator: impl Fn(&mut Hugr, SimpleReplacement),
432    ) {
433        let mut h: Hugr = simple_hugr;
434        // 1. Locate the CX and its successor H's in h
435        let h_node_cx: Node = h
436            .nodes()
437            .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
438            .unwrap();
439        let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
440        let s: Vec<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
441        // 2. Construct a new DFG-rooted hugr for the replacement
442        let n: Hugr = dfg_hugr;
443        // 3. Construct the input and output matchings
444        // 3.1. Locate the CX and its predecessor H's in n
445        let n_node_cx = n
446            .nodes()
447            .find(|node: &Node| *n.get_optype(*node) == cx_gate().into())
448            .unwrap();
449        let (n_node_h0, n_node_h1) = n.input_neighbours(n_node_cx).collect_tuple().unwrap();
450        // 3.2. Locate the ports we need to specify as "glue" in n
451        let n_port_0 = n.node_inputs(n_node_h0).next().unwrap();
452        let n_port_1 = n.node_inputs(n_node_h1).next().unwrap();
453        let (n_cx_out_0, n_cx_out_1) = n.node_outputs(n_node_cx).take(2).collect_tuple().unwrap();
454        let n_port_2 = n.linked_inputs(n_node_cx, n_cx_out_0).next().unwrap().1;
455        let n_port_3 = n.linked_inputs(n_node_cx, n_cx_out_1).next().unwrap().1;
456        // 3.3. Locate the ports we need to specify as "glue" in h
457        let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
458        let h_h0_out = h.node_outputs(h_node_h0).next().unwrap();
459        let h_h1_out = h.node_outputs(h_node_h1).next().unwrap();
460        let (h_outp_node, h_port_2) = h.linked_inputs(h_node_h0, h_h0_out).next().unwrap();
461        let h_port_3 = h.linked_inputs(h_node_h1, h_h1_out).next().unwrap().1;
462        // 3.4. Construct the maps
463        let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
464        let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
465        nu_inp.insert((n_node_h0, n_port_0), (h_node_cx, h_port_0));
466        nu_inp.insert((n_node_h1, n_port_1), (h_node_cx, h_port_1));
467        nu_out.insert((h_outp_node, h_port_2), n_port_2);
468        nu_out.insert((h_outp_node, h_port_3), n_port_3);
469        // 4. Define the replacement
470        let r = SimpleReplacement {
471            subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
472            replacement: n,
473            nu_inp,
474            nu_out,
475        };
476        assert_eq!(
477            HashSet::<_>::from_iter(r.invalidation_set()),
478            HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1, h_outp_node]),
479        );
480
481        applicator(&mut h, r);
482        // Expect [DFG] to be replaced with:
483        // ┌───┐┌───┐
484        // ┤ H ├┤ H ├──■──
485        // ├───┤├───┤┌─┴─┐
486        // ┤ H ├┤ H ├┤ X ├
487        // └───┘└───┘└───┘
488        assert_eq!(h.validate(), Ok(()));
489    }
490
491    #[rstest]
492    /// Replace the
493    ///
494    /// ──■──
495    /// ┌─┴─┐
496    /// ┤ X ├
497    /// └───┘
498    /// part of
499    /// ┌───┐     ┌───┐
500    /// ┤ H ├──■──┤ H ├
501    /// ├───┤┌─┴─┐├───┤
502    /// ┤ H ├┤ X ├┤ H ├
503    /// └───┘└───┘└───┘
504    /// with
505    /// ─────
506    /// ┌───┐
507    /// ┤ H ├
508    /// └───┘
509    fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
510        let mut h: Hugr = simple_hugr;
511
512        // 1. Locate the CX in h
513        let h_node_cx: Node = h
514            .nodes()
515            .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
516            .unwrap();
517        let s: Vec<Node> = vec![h_node_cx].into_iter().collect();
518        // 2. Construct a new DFG-rooted hugr for the replacement
519        let n: Hugr = dfg_hugr2;
520        // 3. Construct the input and output matchings
521        // 3.1. Locate the Output and its predecessor H in n
522        let n_node_output = n
523            .nodes()
524            .find(|node: &Node| n.get_optype(*node).tag() == OpTag::Output)
525            .unwrap();
526        let (_n_node_input, n_node_h) = n.input_neighbours(n_node_output).collect_tuple().unwrap();
527        // 3.2. Locate the ports we need to specify as "glue" in n
528        let (n_port_0, n_port_1) = n
529            .node_inputs(n_node_output)
530            .take(2)
531            .collect_tuple()
532            .unwrap();
533        let n_port_2 = n.node_inputs(n_node_h).next().unwrap();
534        // 3.3. Locate the ports we need to specify as "glue" in h
535        let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
536        let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
537        let h_port_2 = h.node_inputs(h_node_h0).next().unwrap();
538        let h_port_3 = h.node_inputs(h_node_h1).next().unwrap();
539        // 3.4. Construct the maps
540        let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
541        let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
542        nu_inp.insert((n_node_output, n_port_0), (h_node_cx, h_port_0));
543        nu_inp.insert((n_node_h, n_port_2), (h_node_cx, h_port_1));
544        nu_out.insert((h_node_h0, h_port_2), n_port_0);
545        nu_out.insert((h_node_h1, h_port_3), n_port_1);
546        // 4. Define the replacement
547        let r = SimpleReplacement {
548            subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
549            replacement: n,
550            nu_inp,
551            nu_out,
552        };
553        h.apply_rewrite(r).unwrap();
554        // Expect [DFG] to be replaced with:
555        // ┌───┐┌───┐
556        // ┤ H ├┤ H ├
557        // ├───┤├───┤┌───┐
558        // ┤ H ├┤ H ├┤ H ├
559        // └───┘└───┘└───┘
560        assert_eq!(h.validate(), Ok(()));
561    }
562
563    #[test]
564    fn test_replace_cx_cross() {
565        let q_row: Vec<Type> = vec![qb_t(), qb_t()];
566        let mut builder = DFGBuilder::new(endo_sig(q_row)).unwrap();
567        let mut circ = builder.as_circuit(builder.input_wires());
568        circ.append(cx_gate(), [0, 1]).unwrap();
569        circ.append(cx_gate(), [1, 0]).unwrap();
570        let wires = circ.finish();
571        let [input, output] = builder.io();
572        let mut h = builder.finish_hugr_with_outputs(wires).unwrap();
573        let replacement = h.clone();
574        let orig = h.clone();
575
576        let removal = h
577            .nodes()
578            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
579            .collect_vec();
580        let inputs = h
581            .node_outputs(input)
582            .filter(|&p| {
583                h.get_optype(input)
584                    .as_input()
585                    .unwrap()
586                    .signature()
587                    .port_type(p)
588                    .is_some()
589            })
590            .map(|p| {
591                let link = h.linked_inputs(input, p).next().unwrap();
592                (link, link)
593            })
594            .collect();
595        let outputs = h
596            .node_inputs(output)
597            .filter(|&p| {
598                h.get_optype(output)
599                    .as_output()
600                    .unwrap()
601                    .signature()
602                    .port_type(p)
603                    .is_some()
604            })
605            .map(|p| ((output, p), p))
606            .collect();
607        h.apply_rewrite(SimpleReplacement::new(
608            SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
609            replacement,
610            inputs,
611            outputs,
612        ))
613        .unwrap();
614
615        // They should be the same, up to node indices
616        assert_eq!(h.edge_count(), orig.edge_count());
617    }
618
619    #[test]
620    fn test_replace_after_copy() {
621        let one_bit = vec![bool_t()];
622        let two_bit = vec![bool_t(), bool_t()];
623
624        let mut builder = DFGBuilder::new(endo_sig(one_bit.clone())).unwrap();
625        let inw = builder.input_wires().exactly_one().unwrap();
626        let outw = builder
627            .add_dataflow_op(and_op(), [inw, inw])
628            .unwrap()
629            .outputs();
630        let [input, _] = builder.io();
631        let mut h = builder.finish_hugr_with_outputs(outw).unwrap();
632
633        let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap();
634        let inw = builder.input_wires();
635        let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs();
636        let [repl_input, repl_output] = builder.io();
637        let repl = builder.finish_hugr_with_outputs(outw).unwrap();
638
639        let orig = h.clone();
640
641        let removal = h
642            .nodes()
643            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
644            .collect_vec();
645
646        let first_out_p = h.node_outputs(input).next().unwrap();
647        let embedded_inputs = h.linked_inputs(input, first_out_p);
648        let repl_inputs = repl
649            .node_outputs(repl_input)
650            .map(|p| repl.linked_inputs(repl_input, p).next().unwrap());
651        let inputs = embedded_inputs.zip(repl_inputs).collect();
652
653        let outputs = repl
654            .node_inputs(repl_output)
655            .filter(|&p| repl.signature(repl_output).unwrap().port_type(p).is_some())
656            .map(|p| ((repl_output, p), p))
657            .collect();
658
659        h.apply_rewrite(SimpleReplacement::new(
660            SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
661            repl,
662            inputs,
663            outputs,
664        ))
665        .unwrap();
666
667        // Nothing changed
668        assert_eq!(h.node_count(), orig.node_count());
669    }
670
671    /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the input
672    /// directly to the outputs.
673    ///
674    /// https://github.com/CQCL/hugr/issues/1190
675    #[rstest]
676    fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec<Node>)) {
677        let (mut hugr, nodes) = dfg_hugr_copy_bools;
678        let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap();
679
680        let [_input, output] = hugr.get_io(hugr.root()).unwrap();
681
682        let replacement = {
683            let b =
684                DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
685            let [w] = b.input_wires_arr();
686            b.finish_hugr_with_outputs([w, w]).unwrap()
687        };
688        let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();
689
690        let subgraph =
691            SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)
692                .unwrap();
693        // A map from (target ports of edges from the Input node of `replacement`) to (target ports of
694        // edges from nodes not in `removal` to nodes in `removal`).
695        let nu_inp = [
696            (
697                (repl_output, IncomingPort::from(0)),
698                (input_not, IncomingPort::from(0)),
699            ),
700            (
701                (repl_output, IncomingPort::from(1)),
702                (input_not, IncomingPort::from(0)),
703            ),
704        ]
705        .into_iter()
706        .collect();
707        // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to
708        // (input ports of the Output node of `replacement`).
709        let nu_out = [
710            ((output, IncomingPort::from(0)), IncomingPort::from(0)),
711            ((output, IncomingPort::from(1)), IncomingPort::from(1)),
712        ]
713        .into_iter()
714        .collect();
715
716        let rewrite = SimpleReplacement {
717            subgraph,
718            replacement,
719            nu_inp,
720            nu_out,
721        };
722        rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
723
724        assert_eq!(hugr.validate(), Ok(()));
725        assert_eq!(hugr.node_count(), 3);
726    }
727
728    /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting the input
729    /// directly to the output.
730    ///
731    /// https://github.com/CQCL/hugr/issues/1323
732    #[rstest]
733    fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec<Node>)) {
734        let (mut hugr, nodes) = dfg_hugr_half_not_bools;
735        let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap();
736
737        let [_input, output] = hugr.get_io(hugr.root()).unwrap();
738
739        let (replacement, repl_not) = {
740            let mut b =
741                DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
742            let [w] = b.input_wires_arr();
743            let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
744            let [w_not] = not.outputs_arr();
745            (b.finish_hugr_with_outputs([w, w_not]).unwrap(), not.node())
746        };
747        let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();
748
749        let subgraph =
750            SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap();
751        // A map from (target ports of edges from the Input node of `replacement`) to (target ports of
752        // edges from nodes not in `removal` to nodes in `removal`).
753        let nu_inp = [
754            (
755                (repl_output, IncomingPort::from(0)),
756                (input_not, IncomingPort::from(0)),
757            ),
758            (
759                (repl_not, IncomingPort::from(0)),
760                (input_not, IncomingPort::from(0)),
761            ),
762        ]
763        .into_iter()
764        .collect();
765        // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to
766        // (input ports of the Output node of `replacement`).
767        let nu_out = [
768            ((output, IncomingPort::from(0)), IncomingPort::from(0)),
769            ((output, IncomingPort::from(1)), IncomingPort::from(1)),
770        ]
771        .into_iter()
772        .collect();
773
774        let rewrite = SimpleReplacement {
775            subgraph,
776            replacement,
777            nu_inp,
778            nu_out,
779        };
780        rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
781
782        assert_eq!(hugr.validate(), Ok(()));
783        assert_eq!(hugr.node_count(), 4);
784    }
785
786    #[rstest]
787    fn test_nested_replace(dfg_hugr2: Hugr) {
788        // replace a node with a hugr with children
789
790        let mut h = dfg_hugr2;
791        let h_node = h
792            .nodes()
793            .find(|node: &Node| *h.get_optype(*node) == h_gate().into())
794            .unwrap();
795
796        // build a nested identity dfg
797        let mut nest_build = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
798        let [input] = nest_build.input_wires_arr();
799        let inner_build = nest_build.dfg_builder_endo([(qb_t(), input)]).unwrap();
800        let inner_dfg = n_identity(inner_build).unwrap();
801        let inner_dfg_node = inner_dfg.node();
802        let replacement = nest_build
803            .finish_hugr_with_outputs([inner_dfg.out_wire(0)])
804            .unwrap();
805        let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
806        let nu_inp = vec![(
807            (inner_dfg_node, IncomingPort::from(0)),
808            (h_node, IncomingPort::from(0)),
809        )]
810        .into_iter()
811        .collect();
812
813        let nu_out = vec![(
814            (h.get_io(h.root()).unwrap()[1], IncomingPort::from(1)),
815            IncomingPort::from(0),
816        )]
817        .into_iter()
818        .collect();
819
820        let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out);
821
822        assert_eq!(h.node_count(), 4);
823
824        rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
825        h.validate().unwrap_or_else(|e| panic!("{e}"));
826
827        assert_eq!(h.node_count(), 6);
828    }
829
830    use crate::hugr::rewrite::replace::Replacement;
831    fn to_replace(h: &impl HugrView, s: SimpleReplacement) -> Replacement {
832        use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec};
833
834        let mut replacement = s.replacement;
835        let (in_, out) = replacement
836            .children(replacement.root())
837            .take(2)
838            .collect_tuple()
839            .unwrap();
840        let mu_inp = s
841            .nu_inp
842            .iter()
843            .map(|((tgt, tgt_port), (r_n, r_p))| {
844                if *tgt == out {
845                    unimplemented!()
846                };
847                let (src, src_port) = h.single_linked_output(*r_n, *r_p).unwrap();
848                NewEdgeSpec {
849                    src,
850                    tgt: *tgt,
851                    kind: NewEdgeKind::Value {
852                        src_pos: src_port,
853                        tgt_pos: *tgt_port,
854                    },
855                }
856            })
857            .collect();
858        let mu_out = s
859            .nu_out
860            .iter()
861            .map(|((tgt, tgt_port), out_port)| {
862                let (src, src_port) = replacement.single_linked_output(out, *out_port).unwrap();
863                if src == in_ {
864                    unimplemented!()
865                };
866                NewEdgeSpec {
867                    src,
868                    tgt: *tgt,
869                    kind: NewEdgeKind::Value {
870                        src_pos: src_port,
871                        tgt_pos: *tgt_port,
872                    },
873                }
874            })
875            .collect();
876        replacement.remove_node(in_);
877        replacement.remove_node(out);
878        Replacement {
879            removal: s.subgraph.nodes().to_vec(),
880            replacement,
881            adoptions: HashMap::new(),
882            mu_inp,
883            mu_out,
884            mu_new: vec![],
885        }
886    }
887
888    fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) {
889        h.apply_rewrite(rw).unwrap();
890    }
891
892    fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
893        h.apply_rewrite(to_replace(h, rw)).unwrap();
894    }
895}