hugr_core/hugr/rewrite/
simple_replace.rs

1//! Implementation of the `SimpleReplace` operation.
2
3use std::collections::HashMap;
4
5use crate::core::HugrNode;
6use crate::hugr::hugrmut::InsertionResult;
7pub use crate::hugr::internal::HugrMutInternals;
8use crate::hugr::views::SiblingSubgraph;
9use crate::hugr::{HugrMut, HugrView, Rewrite};
10use crate::ops::{OpTag, OpTrait, OpType};
11use crate::{Hugr, IncomingPort, Node, OutgoingPort};
12
13use itertools::Itertools;
14
15use thiserror::Error;
16
17use super::inline_dfg::InlineDFGError;
18use super::{BoundaryPort, HostPort, ReplacementPort};
19
20/// Specification of a simple replacement operation.
21///
22/// # Type parameters
23///
24/// - `N`: The type of nodes in the host hugr.
25#[derive(Debug, Clone)]
26pub struct SimpleReplacement<HostNode = Node> {
27    /// The subgraph of the host hugr to be replaced.
28    subgraph: SiblingSubgraph<HostNode>,
29    /// A hugr with DFG root (consisting of replacement nodes).
30    replacement: Hugr,
31    /// A map from (target ports of edges from the Input node of `replacement`)
32    /// to (target ports of edges from nodes not in `subgraph` to nodes in `subgraph`).
33    nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>,
34    /// A map from (target ports of edges from nodes in `subgraph` to nodes not
35    /// in `subgraph`) to (input ports of the Output node of `replacement`).
36    nu_out: HashMap<(HostNode, IncomingPort), IncomingPort>,
37}
38
39impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
40    /// Create a new [`SimpleReplacement`] specification.
41    #[inline]
42    pub fn new(
43        subgraph: SiblingSubgraph<HostNode>,
44        replacement: Hugr,
45        nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>,
46        nu_out: HashMap<(HostNode, IncomingPort), IncomingPort>,
47    ) -> Self {
48        Self {
49            subgraph,
50            replacement,
51            nu_inp,
52            nu_out,
53        }
54    }
55
56    /// The replacement hugr.
57    #[inline]
58    pub fn replacement(&self) -> &Hugr {
59        &self.replacement
60    }
61
62    /// Consume self and return the replacement hugr.
63    #[inline]
64    pub fn into_replacement(self) -> Hugr {
65        self.replacement
66    }
67
68    /// Subgraph to be replaced.
69    #[inline]
70    pub fn subgraph(&self) -> &SiblingSubgraph<HostNode> {
71        &self.subgraph
72    }
73
74    /// Check if the replacement can be applied to the given hugr.
75    pub fn is_valid_rewrite(
76        &self,
77        h: &impl HugrView<Node = HostNode>,
78    ) -> Result<(), SimpleReplacementError> {
79        let parent = self.subgraph.get_parent(h);
80
81        // 1. Check the parent node exists and is a DataflowParent.
82        if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
83            return Err(SimpleReplacementError::InvalidParentNode());
84        }
85
86        // 2. Check that all the to-be-removed nodes are children of it and are leaves.
87        for node in self.subgraph.nodes() {
88            if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
89                return Err(SimpleReplacementError::InvalidRemovedNode());
90            }
91        }
92
93        Ok(())
94    }
95
96    /// Get the input and output nodes of the replacement hugr.
97    pub fn get_replacement_io(&self) -> Result<[Node; 2], SimpleReplacementError> {
98        self.replacement
99            .get_io(self.replacement.root())
100            .ok_or(SimpleReplacementError::InvalidParentNode())
101    }
102
103    /// Get all edges that the replacement would add from outgoing ports in
104    /// `host` to incoming ports in `self.replacement`.
105    ///
106    /// The incoming ports returned are always connected to outputs of
107    /// the [`OpTag::Input`] node of `self.replacement`.
108    ///
109    /// For each pair in the returned vector, the first element is a port in
110    /// `host` and the second is a port in `self.replacement`.
111    pub fn incoming_boundary<'a>(
112        &'a self,
113        host: &'a impl HugrView<Node = HostNode>,
114    ) -> impl Iterator<
115        Item = (
116            HostPort<HostNode, OutgoingPort>,
117            ReplacementPort<IncomingPort>,
118        ),
119    > + 'a {
120        // For each p = self.nu_inp[q] such that q is not an Output port,
121        // there will be an edge from the predecessor of p to (the new copy of) q.
122        self.nu_inp
123            .iter()
124            .filter(|&((rep_inp_node, _), _)| {
125                self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output
126            })
127            .map(
128                |(&(rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| {
129                    // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port)
130                    let (rem_inp_pred_node, rem_inp_pred_port) = host
131                        .single_linked_output(*rem_inp_node, *rem_inp_port)
132                        .unwrap();
133                    (
134                        HostPort(rem_inp_pred_node, rem_inp_pred_port),
135                        ReplacementPort(rep_inp_node, rep_inp_port),
136                    )
137                },
138            )
139    }
140
141    /// Get all edges that the replacement would add from outgoing ports in
142    /// `self.replacement` to incoming ports in `host`.
143    ///
144    /// The outgoing ports returned are always connected to inputs of
145    /// the [`OpTag::Output`] node of `self.replacement`.
146    ///
147    /// For each pair in the returned vector, the first element is a port in
148    /// `self.replacement` and the second is a port in `host`.
149    ///
150    /// This panics if self.replacement is not a DFG.
151    pub fn outgoing_boundary<'a>(
152        &'a self,
153        _host: &'a impl HugrView<Node = HostNode>,
154    ) -> impl Iterator<
155        Item = (
156            ReplacementPort<OutgoingPort>,
157            HostPort<HostNode, IncomingPort>,
158        ),
159    > + 'a {
160        let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG");
161
162        // For each q = self.nu_out[p] such that the predecessor of q is not an Input port,
163        // there will be an edge from (the new copy of) the predecessor of q to p.
164        self.nu_out
165            .iter()
166            .filter_map(move |(&(rem_out_node, rem_out_port), rep_out_port)| {
167                let (rep_out_pred_node, rep_out_pred_port) = self
168                    .replacement
169                    .single_linked_output(replacement_output_node, *rep_out_port)
170                    .unwrap();
171                (self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input).then_some({
172                    (
173                        // the new output node will be updated after insertion
174                        ReplacementPort(rep_out_pred_node, rep_out_pred_port),
175                        HostPort(rem_out_node, rem_out_port),
176                    )
177                })
178            })
179    }
180
181    /// Get all edges that the replacement would add between ports in `host`.
182    ///
183    /// These correspond to direct edges between the input and output nodes
184    /// in the replacement graph.
185    ///
186    /// For each pair in the returned vector, the both ports are in `host`.
187    ///
188    /// This panics if self.replacement is not a DFG.
189    pub fn host_to_host_boundary<'a>(
190        &'a self,
191        host: &'a impl HugrView<Node = HostNode>,
192    ) -> impl Iterator<
193        Item = (
194            HostPort<HostNode, OutgoingPort>,
195            HostPort<HostNode, IncomingPort>,
196        ),
197    > + 'a {
198        let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG");
199
200        // For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0
201        // to p1.
202        self.nu_out
203            .iter()
204            .filter_map(move |(&(rem_out_node, rem_out_port), &rep_out_port)| {
205                self.nu_inp
206                    .get(&(replacement_output_node, rep_out_port))
207                    .map(|&(rem_inp_node, rem_inp_port)| {
208                        let (rem_inp_pred_node, rem_inp_pred_port) = host
209                            .single_linked_output(rem_inp_node, rem_inp_port)
210                            .unwrap();
211                        (
212                            HostPort(rem_inp_pred_node, rem_inp_pred_port),
213                            HostPort(rem_out_node, rem_out_port),
214                        )
215                    })
216            })
217    }
218
219    /// Get the incoming port at the output node of `self.replacement` that
220    /// corresponds to the given host output port.
221    ///
222    /// This panics if self.replacement is not a DFG.
223    pub fn map_host_output(
224        &self,
225        port: impl Into<HostPort<HostNode, IncomingPort>>,
226    ) -> Option<ReplacementPort<IncomingPort>> {
227        let HostPort(node, port) = port.into();
228        let [_, rep_output] = self.get_replacement_io().expect("replacement is a DFG");
229        self.nu_out
230            .get(&(node, port))
231            .map(|&rep_out_port| ReplacementPort(rep_output, rep_out_port))
232    }
233
234    /// Get the incoming port in `subgraph` that corresponds to the given
235    /// replacement input port.
236    ///
237    /// This panics if self.replacement is not a DFG.
238    pub fn map_replacement_input(
239        &self,
240        port: impl Into<ReplacementPort<IncomingPort>>,
241    ) -> Option<HostPort<HostNode, IncomingPort>> {
242        let ReplacementPort(node, port) = port.into();
243        self.nu_inp.get(&(node, port)).copied().map(Into::into)
244    }
245
246    /// Get all edges that the replacement would add between `host` and
247    /// `self.replacement`.
248    ///
249    /// This is equivalent to chaining the results of [`Self::incoming_boundary`],
250    /// [`Self::outgoing_boundary`], and [`Self::host_to_host_boundary`].
251    ///
252    /// This panics if self.replacement is not a DFG.
253    pub fn all_boundary_edges<'a>(
254        &'a self,
255        host: &'a impl HugrView<Node = HostNode>,
256    ) -> impl Iterator<
257        Item = (
258            BoundaryPort<HostNode, OutgoingPort>,
259            BoundaryPort<HostNode, IncomingPort>,
260        ),
261    > + 'a {
262        let incoming_boundary = self
263            .incoming_boundary(host)
264            .map(|(src, tgt)| (src.into(), tgt.into()));
265        let outgoing_boundary = self
266            .outgoing_boundary(host)
267            .map(|(src, tgt)| (src.into(), tgt.into()));
268        let host_to_host_boundary = self
269            .host_to_host_boundary(host)
270            .map(|(src, tgt)| (src.into(), tgt.into()));
271
272        incoming_boundary
273            .chain(outgoing_boundary)
274            .chain(host_to_host_boundary)
275    }
276}
277
278impl Rewrite for SimpleReplacement {
279    type Error = SimpleReplacementError;
280    type ApplyResult = Vec<(Node, OpType)>;
281    const UNCHANGED_ON_FAILURE: bool = true;
282
283    fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), SimpleReplacementError> {
284        self.is_valid_rewrite(h)
285    }
286
287    fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
288        self.is_valid_rewrite(h)?;
289
290        let parent = self.subgraph.get_parent(h);
291
292        // We proceed to connect the edges between the newly inserted
293        // replacement and the rest of the graph.
294        //
295        // Existing connections to the removed subgraph will be automatically
296        // removed when the nodes are removed.
297
298        // 1. Get the boundary edges
299        let boundary_edges = self.all_boundary_edges(h).collect_vec();
300
301        let Self {
302            replacement,
303            subgraph,
304            ..
305        } = self;
306
307        // 2. Insert the replacement as a whole.
308        let InsertionResult {
309            new_root,
310            node_map: index_map,
311        } = h.insert_hugr(parent, replacement);
312
313        // remove the Input and Output nodes from the replacement graph
314        let replace_children = h.children(new_root).collect::<Vec<Node>>();
315        for &io in &replace_children[..2] {
316            h.remove_node(io);
317        }
318        // make all replacement top level children children of the parent
319        for &child in &replace_children[2..] {
320            h.set_parent(child, parent);
321        }
322        // remove the replacement root (which now has no children and no edges)
323        h.remove_node(new_root);
324
325        // 3. Insert all boundary edges.
326        for (src, tgt) in boundary_edges {
327            let (src_node, src_port) = src.map_replacement(&index_map);
328            let (tgt_node, tgt_port) = tgt.map_replacement(&index_map);
329            h.connect(src_node, src_port, tgt_node, tgt_port);
330        }
331
332        // 4. Remove all nodes in subgraph and edges between them.
333        Ok(subgraph
334            .nodes()
335            .iter()
336            .map(|&node| (node, h.remove_node(node)))
337            .collect())
338    }
339
340    #[inline]
341    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
342        let subcirc = self.subgraph.nodes().iter().copied();
343        let out_neighs = self.nu_out.keys().map(|key| key.0);
344        subcirc.chain(out_neighs)
345    }
346}
347
348/// Error from a [`SimpleReplacement`] operation.
349#[derive(Debug, Clone, Error, PartialEq, Eq)]
350#[non_exhaustive]
351pub enum SimpleReplacementError {
352    /// Invalid parent node.
353    #[error("Parent node is invalid.")]
354    InvalidParentNode(),
355    /// Node requested for removal is invalid.
356    #[error("A node requested for removal is invalid.")]
357    InvalidRemovedNode(),
358    /// Node in replacement graph is invalid.
359    #[error("A node in the replacement graph is invalid.")]
360    InvalidReplacementNode(),
361    /// Inlining replacement failed.
362    #[error("Inlining replacement failed: {0}")]
363    InliningFailed(#[from] InlineDFGError),
364}
365
366#[cfg(test)]
367pub(in crate::hugr::rewrite) mod test {
368    use itertools::Itertools;
369    use rstest::{fixture, rstest};
370    use std::collections::{HashMap, HashSet};
371
372    use crate::builder::test::n_identity;
373    use crate::builder::{
374        endo_sig, inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr,
375        DataflowSubContainer, HugrBuilder, ModuleBuilder,
376    };
377    use crate::extension::prelude::{bool_t, qb_t};
378    use crate::extension::ExtensionSet;
379    use crate::hugr::views::{HugrView, SiblingSubgraph};
380    use crate::hugr::{Hugr, HugrMut, Rewrite};
381    use crate::ops::dataflow::DataflowOpTrait;
382    use crate::ops::handle::NodeHandle;
383    use crate::ops::OpTag;
384    use crate::ops::OpTrait;
385    use crate::std_extensions::logic::test::and_op;
386    use crate::std_extensions::logic::LogicOp;
387    use crate::types::{Signature, Type};
388    use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID};
389    use crate::{IncomingPort, Node};
390
391    use super::SimpleReplacement;
392
393    /// Creates a hugr like the following:
394    /// --   H   --
395    /// -- [DFG] --
396    /// where [DFG] is:
397    /// ┌───┐     ┌───┐
398    /// ┤ H ├──■──┤ H ├
399    /// ├───┤┌─┴─┐├───┤
400    /// ┤ H ├┤ X ├┤ H ├
401    /// └───┘└───┘└───┘
402    fn make_hugr() -> Result<Hugr, BuildError> {
403        let mut module_builder = ModuleBuilder::new();
404        let _f_id = {
405            let just_q: ExtensionSet = EXTENSION_ID.into();
406            let mut func_builder = module_builder.define_function(
407                "main",
408                Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])
409                    .with_extension_delta(just_q.clone()),
410            )?;
411
412            let [qb0, qb1, qb2] = func_builder.input_wires_arr();
413
414            let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?;
415
416            let mut inner_builder =
417                func_builder.dfg_builder_endo([(qb_t(), qb0), (qb_t(), qb1)])?;
418            let inner_graph = {
419                let [wire0, wire1] = inner_builder.input_wires_arr();
420                let wire2 = inner_builder.add_dataflow_op(h_gate(), vec![wire0])?;
421                let wire3 = inner_builder.add_dataflow_op(h_gate(), vec![wire1])?;
422                let wire45 = inner_builder
423                    .add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
424                let [wire4, wire5] = wire45.outputs_arr();
425                let wire6 = inner_builder.add_dataflow_op(h_gate(), vec![wire4])?;
426                let wire7 = inner_builder.add_dataflow_op(h_gate(), vec![wire5])?;
427                inner_builder.finish_with_outputs(wire6.outputs().chain(wire7.outputs()))
428            }?;
429
430            func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))?
431        };
432        Ok(module_builder.finish_hugr()?)
433    }
434
435    #[fixture]
436    pub(in crate::hugr::rewrite) fn simple_hugr() -> Hugr {
437        make_hugr().unwrap()
438    }
439    /// Creates a hugr with a DFG root like the following:
440    /// ┌───┐
441    /// ┤ H ├──■──
442    /// ├───┤┌─┴─┐
443    /// ┤ H ├┤ X ├
444    /// └───┘└───┘
445    fn make_dfg_hugr() -> Result<Hugr, BuildError> {
446        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]).with_prelude())?;
447        let [wire0, wire1] = dfg_builder.input_wires_arr();
448        let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?;
449        let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
450        let wire45 =
451            dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?;
452        dfg_builder.finish_hugr_with_outputs(wire45.outputs())
453    }
454
455    #[fixture]
456    pub(in crate::hugr::rewrite) fn dfg_hugr() -> Hugr {
457        make_dfg_hugr().unwrap()
458    }
459
460    /// Creates a hugr with a DFG root like the following:
461    /// ─────
462    /// ┌───┐
463    /// ┤ H ├
464    /// └───┘
465    fn make_dfg_hugr2() -> Result<Hugr, BuildError> {
466        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
467
468        let [wire0, wire1] = dfg_builder.input_wires_arr();
469        let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?;
470        let wire2out = wire2.outputs().exactly_one().unwrap();
471        let wireoutvec = vec![wire0, wire2out];
472        dfg_builder.finish_hugr_with_outputs(wireoutvec)
473    }
474
475    #[fixture]
476    pub(in crate::hugr::rewrite) fn dfg_hugr2() -> Hugr {
477        make_dfg_hugr2().unwrap()
478    }
479
480    /// A hugr with a DFG root mapping bool_t() to (bool_t(), bool_t())
481    ///                     ┌─────────┐
482    ///                ┌────┤ (1) NOT ├──
483    ///  ┌─────────┐   │    └─────────┘
484    /// ─┤ (0) NOT ├───┤
485    ///  └─────────┘   │    ┌─────────┐
486    ///                └────┤ (2) NOT ├──
487    ///                     └─────────┘
488    /// This can be replaced with an empty hugr coping the input to both outputs.
489    ///
490    /// Returns the hugr and the nodes of the NOT gates, in order.
491    #[fixture]
492    pub(in crate::hugr::rewrite) fn dfg_hugr_copy_bools() -> (Hugr, Vec<Node>) {
493        let mut dfg_builder =
494            DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
495        let [b] = dfg_builder.input_wires_arr();
496
497        let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
498        let [b] = not_inp.outputs_arr();
499
500        let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
501        let [b0] = not_0.outputs_arr();
502        let not_1 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
503        let [b1] = not_1.outputs_arr();
504
505        (
506            dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
507            vec![not_inp.node(), not_0.node(), not_1.node()],
508        )
509    }
510
511    /// A hugr with a DFG root mapping bool_t() to (bool_t(), bool_t())
512    ///                     ┌─────────┐
513    ///                ┌────┤ (1) NOT ├──
514    ///  ┌─────────┐   │    └─────────┘
515    /// ─┤ (0) NOT ├───┤
516    ///  └─────────┘   │
517    ///                └─────────────────
518    ///
519    /// This can be replaced with a single NOT op, coping the input to the first output.
520    ///
521    /// Returns the hugr and the nodes of the NOT ops, in order.
522    #[fixture]
523    pub(in crate::hugr::rewrite) fn dfg_hugr_half_not_bools() -> (Hugr, Vec<Node>) {
524        let mut dfg_builder =
525            DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
526        let [b] = dfg_builder.input_wires_arr();
527
528        let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
529        let [b] = not_inp.outputs_arr();
530
531        let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap();
532        let [b0] = not_0.outputs_arr();
533        let b1 = b;
534
535        (
536            dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(),
537            vec![not_inp.node(), not_0.node()],
538        )
539    }
540
541    #[rstest]
542    /// Replace the
543    ///      ┌───┐
544    /// ──■──┤ H ├
545    /// ┌─┴─┐├───┤
546    /// ┤ X ├┤ H ├
547    /// └───┘└───┘
548    /// part of
549    /// ┌───┐     ┌───┐
550    /// ┤ H ├──■──┤ H ├
551    /// ├───┤┌─┴─┐├───┤
552    /// ┤ H ├┤ X ├┤ H ├
553    /// └───┘└───┘└───┘
554    /// with
555    /// ┌───┐
556    /// ┤ H ├──■──
557    /// ├───┤┌─┴─┐
558    /// ┤ H ├┤ X ├
559    /// └───┘└───┘
560    fn test_simple_replacement(
561        simple_hugr: Hugr,
562        dfg_hugr: Hugr,
563        #[values(apply_simple, apply_replace)] applicator: impl Fn(&mut Hugr, SimpleReplacement),
564    ) {
565        let mut h: Hugr = simple_hugr;
566        // 1. Locate the CX and its successor H's in h
567        let h_node_cx: Node = h
568            .nodes()
569            .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
570            .unwrap();
571        let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
572        let s: Vec<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
573        // 2. Construct a new DFG-rooted hugr for the replacement
574        let n: Hugr = dfg_hugr;
575        // 3. Construct the input and output matchings
576        // 3.1. Locate the CX and its predecessor H's in n
577        let n_node_cx = n
578            .nodes()
579            .find(|node: &Node| *n.get_optype(*node) == cx_gate().into())
580            .unwrap();
581        let (n_node_h0, n_node_h1) = n.input_neighbours(n_node_cx).collect_tuple().unwrap();
582        // 3.2. Locate the ports we need to specify as "glue" in n
583        let n_port_0 = n.node_inputs(n_node_h0).next().unwrap();
584        let n_port_1 = n.node_inputs(n_node_h1).next().unwrap();
585        let (n_cx_out_0, n_cx_out_1) = n.node_outputs(n_node_cx).take(2).collect_tuple().unwrap();
586        let n_port_2 = n.linked_inputs(n_node_cx, n_cx_out_0).next().unwrap().1;
587        let n_port_3 = n.linked_inputs(n_node_cx, n_cx_out_1).next().unwrap().1;
588        // 3.3. Locate the ports we need to specify as "glue" in h
589        let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
590        let h_h0_out = h.node_outputs(h_node_h0).next().unwrap();
591        let h_h1_out = h.node_outputs(h_node_h1).next().unwrap();
592        let (h_outp_node, h_port_2) = h.linked_inputs(h_node_h0, h_h0_out).next().unwrap();
593        let h_port_3 = h.linked_inputs(h_node_h1, h_h1_out).next().unwrap().1;
594        // 3.4. Construct the maps
595        let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
596        let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
597        nu_inp.insert((n_node_h0, n_port_0), (h_node_cx, h_port_0));
598        nu_inp.insert((n_node_h1, n_port_1), (h_node_cx, h_port_1));
599        nu_out.insert((h_outp_node, h_port_2), n_port_2);
600        nu_out.insert((h_outp_node, h_port_3), n_port_3);
601        // 4. Define the replacement
602        let r = SimpleReplacement {
603            subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
604            replacement: n,
605            nu_inp,
606            nu_out,
607        };
608        assert_eq!(
609            HashSet::<_>::from_iter(r.invalidation_set()),
610            HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1, h_outp_node]),
611        );
612
613        applicator(&mut h, r);
614        // Expect [DFG] to be replaced with:
615        // ┌───┐┌───┐
616        // ┤ H ├┤ H ├──■──
617        // ├───┤├───┤┌─┴─┐
618        // ┤ H ├┤ H ├┤ X ├
619        // └───┘└───┘└───┘
620        assert_eq!(h.validate(), Ok(()));
621    }
622
623    #[rstest]
624    /// Replace the
625    ///
626    /// ──■──
627    /// ┌─┴─┐
628    /// ┤ X ├
629    /// └───┘
630    /// part of
631    /// ┌───┐     ┌───┐
632    /// ┤ H ├──■──┤ H ├
633    /// ├───┤┌─┴─┐├───┤
634    /// ┤ H ├┤ X ├┤ H ├
635    /// └───┘└───┘└───┘
636    /// with
637    /// ─────
638    /// ┌───┐
639    /// ┤ H ├
640    /// └───┘
641    fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
642        let mut h: Hugr = simple_hugr;
643
644        // 1. Locate the CX in h
645        let h_node_cx: Node = h
646            .nodes()
647            .find(|node: &Node| *h.get_optype(*node) == cx_gate().into())
648            .unwrap();
649        let s: Vec<Node> = vec![h_node_cx].into_iter().collect();
650        // 2. Construct a new DFG-rooted hugr for the replacement
651        let n: Hugr = dfg_hugr2;
652        // 3. Construct the input and output matchings
653        // 3.1. Locate the Output and its predecessor H in n
654        let n_node_output = n
655            .nodes()
656            .find(|node: &Node| n.get_optype(*node).tag() == OpTag::Output)
657            .unwrap();
658        let (_n_node_input, n_node_h) = n.input_neighbours(n_node_output).collect_tuple().unwrap();
659        // 3.2. Locate the ports we need to specify as "glue" in n
660        let (n_port_0, n_port_1) = n
661            .node_inputs(n_node_output)
662            .take(2)
663            .collect_tuple()
664            .unwrap();
665        let n_port_2 = n.node_inputs(n_node_h).next().unwrap();
666        // 3.3. Locate the ports we need to specify as "glue" in h
667        let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
668        let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
669        let h_port_2 = h.node_inputs(h_node_h0).next().unwrap();
670        let h_port_3 = h.node_inputs(h_node_h1).next().unwrap();
671        // 3.4. Construct the maps
672        let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new();
673        let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new();
674        nu_inp.insert((n_node_output, n_port_0), (h_node_cx, h_port_0));
675        nu_inp.insert((n_node_h, n_port_2), (h_node_cx, h_port_1));
676        nu_out.insert((h_node_h0, h_port_2), n_port_0);
677        nu_out.insert((h_node_h1, h_port_3), n_port_1);
678        // 4. Define the replacement
679        let r = SimpleReplacement {
680            subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
681            replacement: n,
682            nu_inp,
683            nu_out,
684        };
685        h.apply_rewrite(r).unwrap();
686        // Expect [DFG] to be replaced with:
687        // ┌───┐┌───┐
688        // ┤ H ├┤ H ├
689        // ├───┤├───┤┌───┐
690        // ┤ H ├┤ H ├┤ H ├
691        // └───┘└───┘└───┘
692        assert_eq!(h.validate(), Ok(()));
693    }
694
695    #[test]
696    fn test_replace_cx_cross() {
697        let q_row: Vec<Type> = vec![qb_t(), qb_t()];
698        let mut builder = DFGBuilder::new(endo_sig(q_row)).unwrap();
699        let mut circ = builder.as_circuit(builder.input_wires());
700        circ.append(cx_gate(), [0, 1]).unwrap();
701        circ.append(cx_gate(), [1, 0]).unwrap();
702        let wires = circ.finish();
703        let [input, output] = builder.io();
704        let mut h = builder.finish_hugr_with_outputs(wires).unwrap();
705        let replacement = h.clone();
706        let orig = h.clone();
707
708        let removal = h
709            .nodes()
710            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
711            .collect_vec();
712        let inputs = h
713            .node_outputs(input)
714            .filter(|&p| {
715                h.get_optype(input)
716                    .as_input()
717                    .unwrap()
718                    .signature()
719                    .port_type(p)
720                    .is_some()
721            })
722            .map(|p| {
723                let link = h.linked_inputs(input, p).next().unwrap();
724                (link, link)
725            })
726            .collect();
727        let outputs = h
728            .node_inputs(output)
729            .filter(|&p| {
730                h.get_optype(output)
731                    .as_output()
732                    .unwrap()
733                    .signature()
734                    .port_type(p)
735                    .is_some()
736            })
737            .map(|p| ((output, p), p))
738            .collect();
739        h.apply_rewrite(SimpleReplacement::new(
740            SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
741            replacement,
742            inputs,
743            outputs,
744        ))
745        .unwrap();
746
747        // They should be the same, up to node indices
748        assert_eq!(h.edge_count(), orig.edge_count());
749    }
750
751    #[test]
752    fn test_replace_after_copy() {
753        let one_bit = vec![bool_t()];
754        let two_bit = vec![bool_t(), bool_t()];
755
756        let mut builder = DFGBuilder::new(endo_sig(one_bit.clone())).unwrap();
757        let inw = builder.input_wires().exactly_one().unwrap();
758        let outw = builder
759            .add_dataflow_op(and_op(), [inw, inw])
760            .unwrap()
761            .outputs();
762        let [input, _] = builder.io();
763        let mut h = builder.finish_hugr_with_outputs(outw).unwrap();
764
765        let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap();
766        let inw = builder.input_wires();
767        let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs();
768        let [repl_input, repl_output] = builder.io();
769        let repl = builder.finish_hugr_with_outputs(outw).unwrap();
770
771        let orig = h.clone();
772
773        let removal = h
774            .nodes()
775            .filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
776            .collect_vec();
777
778        let first_out_p = h.node_outputs(input).next().unwrap();
779        let embedded_inputs = h.linked_inputs(input, first_out_p);
780        let repl_inputs = repl
781            .node_outputs(repl_input)
782            .map(|p| repl.linked_inputs(repl_input, p).next().unwrap());
783        let inputs = embedded_inputs.zip(repl_inputs).collect();
784
785        let outputs = repl
786            .node_inputs(repl_output)
787            .filter(|&p| repl.signature(repl_output).unwrap().port_type(p).is_some())
788            .map(|p| ((repl_output, p), p))
789            .collect();
790
791        h.apply_rewrite(SimpleReplacement::new(
792            SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
793            repl,
794            inputs,
795            outputs,
796        ))
797        .unwrap();
798
799        // Nothing changed
800        assert_eq!(h.node_count(), orig.node_count());
801    }
802
803    /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the input
804    /// directly to the outputs.
805    ///
806    /// https://github.com/CQCL/hugr/issues/1190
807    #[rstest]
808    fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec<Node>)) {
809        let (mut hugr, nodes) = dfg_hugr_copy_bools;
810        let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap();
811
812        let [_input, output] = hugr.get_io(hugr.root()).unwrap();
813
814        let replacement = {
815            let b =
816                DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
817            let [w] = b.input_wires_arr();
818            b.finish_hugr_with_outputs([w, w]).unwrap()
819        };
820        let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();
821
822        let subgraph =
823            SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)
824                .unwrap();
825        // A map from (target ports of edges from the Input node of `replacement`) to (target ports of
826        // edges from nodes not in `removal` to nodes in `removal`).
827        let nu_inp = [
828            (
829                (repl_output, IncomingPort::from(0)),
830                (input_not, IncomingPort::from(0)),
831            ),
832            (
833                (repl_output, IncomingPort::from(1)),
834                (input_not, IncomingPort::from(0)),
835            ),
836        ]
837        .into_iter()
838        .collect();
839        // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to
840        // (input ports of the Output node of `replacement`).
841        let nu_out = [
842            ((output, IncomingPort::from(0)), IncomingPort::from(0)),
843            ((output, IncomingPort::from(1)), IncomingPort::from(1)),
844        ]
845        .into_iter()
846        .collect();
847
848        let rewrite = SimpleReplacement {
849            subgraph,
850            replacement,
851            nu_inp,
852            nu_out,
853        };
854        rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
855
856        assert_eq!(hugr.validate(), Ok(()));
857        assert_eq!(hugr.node_count(), 3);
858    }
859
860    /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting the input
861    /// directly to the output.
862    ///
863    /// https://github.com/CQCL/hugr/issues/1323
864    #[rstest]
865    fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec<Node>)) {
866        let (mut hugr, nodes) = dfg_hugr_half_not_bools;
867        let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap();
868
869        let [_input, output] = hugr.get_io(hugr.root()).unwrap();
870
871        let (replacement, repl_not) = {
872            let mut b =
873                DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap();
874            let [w] = b.input_wires_arr();
875            let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
876            let [w_not] = not.outputs_arr();
877            (b.finish_hugr_with_outputs([w, w_not]).unwrap(), not.node())
878        };
879        let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();
880
881        let subgraph =
882            SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap();
883        // A map from (target ports of edges from the Input node of `replacement`) to (target ports of
884        // edges from nodes not in `removal` to nodes in `removal`).
885        let nu_inp = [
886            (
887                (repl_output, IncomingPort::from(0)),
888                (input_not, IncomingPort::from(0)),
889            ),
890            (
891                (repl_not, IncomingPort::from(0)),
892                (input_not, IncomingPort::from(0)),
893            ),
894        ]
895        .into_iter()
896        .collect();
897        // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to
898        // (input ports of the Output node of `replacement`).
899        let nu_out = [
900            ((output, IncomingPort::from(0)), IncomingPort::from(0)),
901            ((output, IncomingPort::from(1)), IncomingPort::from(1)),
902        ]
903        .into_iter()
904        .collect();
905
906        let rewrite = SimpleReplacement {
907            subgraph,
908            replacement,
909            nu_inp,
910            nu_out,
911        };
912        rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));
913
914        assert_eq!(hugr.validate(), Ok(()));
915        assert_eq!(hugr.node_count(), 4);
916    }
917
918    #[rstest]
919    fn test_nested_replace(dfg_hugr2: Hugr) {
920        // replace a node with a hugr with children
921
922        let mut h = dfg_hugr2;
923        let h_node = h
924            .nodes()
925            .find(|node: &Node| *h.get_optype(*node) == h_gate().into())
926            .unwrap();
927
928        // build a nested identity dfg
929        let mut nest_build = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
930        let [input] = nest_build.input_wires_arr();
931        let inner_build = nest_build.dfg_builder_endo([(qb_t(), input)]).unwrap();
932        let inner_dfg = n_identity(inner_build).unwrap();
933        let inner_dfg_node = inner_dfg.node();
934        let replacement = nest_build
935            .finish_hugr_with_outputs([inner_dfg.out_wire(0)])
936            .unwrap();
937        let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
938        let nu_inp = vec![(
939            (inner_dfg_node, IncomingPort::from(0)),
940            (h_node, IncomingPort::from(0)),
941        )]
942        .into_iter()
943        .collect();
944
945        let nu_out = vec![(
946            (h.get_io(h.root()).unwrap()[1], IncomingPort::from(1)),
947            IncomingPort::from(0),
948        )]
949        .into_iter()
950        .collect();
951
952        let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out);
953
954        assert_eq!(h.node_count(), 4);
955
956        rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
957        h.validate().unwrap_or_else(|e| panic!("{e}"));
958
959        assert_eq!(h.node_count(), 6);
960    }
961
962    use crate::hugr::rewrite::replace::Replacement;
963    fn to_replace(h: &impl HugrView<Node = Node>, s: SimpleReplacement) -> Replacement {
964        use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec};
965
966        let mut replacement = s.replacement;
967        let (in_, out) = replacement
968            .children(replacement.root())
969            .take(2)
970            .collect_tuple()
971            .unwrap();
972        let mu_inp = s
973            .nu_inp
974            .iter()
975            .map(|((tgt, tgt_port), (r_n, r_p))| {
976                if *tgt == out {
977                    unimplemented!()
978                };
979                let (src, src_port) = h.single_linked_output(*r_n, *r_p).unwrap();
980                NewEdgeSpec {
981                    src,
982                    tgt: *tgt,
983                    kind: NewEdgeKind::Value {
984                        src_pos: src_port,
985                        tgt_pos: *tgt_port,
986                    },
987                }
988            })
989            .collect();
990        let mu_out = s
991            .nu_out
992            .iter()
993            .map(|((tgt, tgt_port), out_port)| {
994                let (src, src_port) = replacement.single_linked_output(out, *out_port).unwrap();
995                if src == in_ {
996                    unimplemented!()
997                };
998                NewEdgeSpec {
999                    src,
1000                    tgt: *tgt,
1001                    kind: NewEdgeKind::Value {
1002                        src_pos: src_port,
1003                        tgt_pos: *tgt_port,
1004                    },
1005                }
1006            })
1007            .collect();
1008        replacement.remove_node(in_);
1009        replacement.remove_node(out);
1010        Replacement {
1011            removal: s.subgraph.nodes().to_vec(),
1012            replacement,
1013            adoptions: HashMap::new(),
1014            mu_inp,
1015            mu_out,
1016            mu_new: vec![],
1017        }
1018    }
1019
1020    fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) {
1021        h.apply_rewrite(rw).unwrap();
1022    }
1023
1024    fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
1025        h.apply_rewrite(to_replace(h, rw)).unwrap();
1026    }
1027}