hugr_core/hugr/patch/
inline_dfg.rs

1//! A rewrite that inlines a DFG node, moving all children
2//! of the DFG except Input+Output into the DFG's parent,
3//! and deleting the DFG along with its Input + Output
4
5use super::{PatchHugrMut, PatchVerification};
6use crate::ops::handle::{DfgID, NodeHandle};
7use crate::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex};
8
9/// Structure identifying an `InlineDFG` rewrite from the spec
10pub struct InlineDFG(pub DfgID);
11
12/// Errors from an [`InlineDFG`] rewrite.
13#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
14#[non_exhaustive]
15pub enum InlineDFGError {
16    /// Node to inline was not a DFG. (E.g. node has been overwritten since the `DfgID` originated.)
17    #[error("{node} was not a DFG")]
18    NotDFG {
19        /// The node we tried to inline
20        node: Node,
21    },
22    /// The DFG node is the hugr entrypoint
23    #[error("Cannot inline the entrypoint node, {node}")]
24    CantInlineEntrypoint {
25        /// The node we tried to inline
26        node: Node,
27    },
28}
29
30impl PatchVerification for InlineDFG {
31    type Error = InlineDFGError;
32
33    type Node = Node;
34
35    fn verify(&self, h: &impl crate::HugrView<Node = Node>) -> Result<(), Self::Error> {
36        let n = self.0.node();
37        if h.get_optype(n).as_dfg().is_none() {
38            return Err(InlineDFGError::NotDFG { node: n });
39        }
40        if n == h.entrypoint() {
41            return Err(InlineDFGError::CantInlineEntrypoint { node: n });
42        }
43        Ok(())
44    }
45
46    fn invalidated_nodes(
47        &self,
48        _: &impl HugrView<Node = Self::Node>,
49    ) -> impl Iterator<Item = Self::Node> {
50        [self.0.node()].into_iter()
51    }
52}
53
54impl PatchHugrMut for InlineDFG {
55    /// The removed nodes: the DFG, and its Input and Output children.
56    type Outcome = [Node; 3];
57
58    const UNCHANGED_ON_FAILURE: bool = true;
59
60    fn apply_hugr_mut(
61        self,
62        h: &mut impl crate::hugr::HugrMut<Node = Node>,
63    ) -> Result<Self::Outcome, Self::Error> {
64        self.verify(h)?;
65        let n = self.0.node();
66        let (oth_in, oth_out) = {
67            let dfg_ty = h.get_optype(n);
68            (
69                dfg_ty.other_input_port().unwrap(),
70                dfg_ty.other_output_port().unwrap(),
71            )
72        };
73        let parent = h.get_parent(n).unwrap();
74        let [input, output] = h.get_io(n).unwrap();
75        for ch in h.children(n).skip(2).collect::<Vec<_>>() {
76            h.set_parent(ch, parent);
77        }
78        // DFG Inputs. Deal with Order inputs first
79        for (src_n, src_p) in h.linked_outputs(n, oth_in).collect::<Vec<_>>() {
80            // Order edge from src_n to DFG => add order edge to each successor of Input node
81            debug_assert_eq!(Some(src_p), h.get_optype(src_n).other_output_port());
82            for tgt_n in h.output_neighbours(input).collect::<Vec<_>>() {
83                h.add_other_edge(src_n, tgt_n);
84            }
85        }
86        // And remaining (Value) inputs
87        let input_ord_succs = h
88            .linked_inputs(input, h.get_optype(input).other_output_port().unwrap())
89            .collect::<Vec<_>>();
90        for inp in h.node_inputs(n).collect::<Vec<_>>() {
91            if inp == oth_in {
92                continue;
93            }
94            // Hugr is invalid if there is no output linked to the DFG input.
95            let (src_n, src_p) = h.single_linked_output(n, inp).unwrap();
96            h.disconnect(n, inp); // These disconnects allow permutations to work trivially.
97            let outp = OutgoingPort::from(inp.index());
98            let targets = h.linked_inputs(input, outp).collect::<Vec<_>>();
99            h.disconnect(input, outp);
100
101            for (tgt_n, tgt_p) in targets {
102                h.connect(src_n, src_p, tgt_n, tgt_p);
103            }
104            // Ensure order-successors of Input node execute after any node producing an input
105            for (tgt, _) in &input_ord_succs {
106                h.add_other_edge(src_n, *tgt);
107            }
108        }
109        // DFG Outputs. Deal with Order outputs first.
110        for (tgt_n, tgt_p) in h.linked_inputs(n, oth_out).collect::<Vec<_>>() {
111            debug_assert_eq!(Some(tgt_p), h.get_optype(tgt_n).other_input_port());
112            for src_n in h.input_neighbours(output).collect::<Vec<_>>() {
113                h.add_other_edge(src_n, tgt_n);
114            }
115        }
116        // And remaining (Value) outputs
117        let output_ord_preds = h
118            .linked_outputs(output, h.get_optype(output).other_input_port().unwrap())
119            .collect::<Vec<_>>();
120        for outport in h.node_outputs(n).collect::<Vec<_>>() {
121            if outport == oth_out {
122                continue;
123            }
124            let inpp = IncomingPort::from(outport.index());
125            // Hugr is invalid if the Output node has no corresponding input
126            let (src_n, src_p) = h.single_linked_output(output, inpp).unwrap();
127            h.disconnect(output, inpp);
128
129            for (tgt_n, tgt_p) in h.linked_inputs(n, outport).collect::<Vec<_>>() {
130                h.connect(src_n, src_p, tgt_n, tgt_p);
131                // Ensure order-predecessors of Output node execute before any node consuming a DFG output
132                for (src, _) in &output_ord_preds {
133                    h.add_other_edge(*src, tgt_n);
134                }
135            }
136            h.disconnect(n, outport);
137        }
138        h.remove_node(input);
139        h.remove_node(output);
140        assert!(h.children(n).next().is_none());
141        h.remove_node(n);
142        Ok([n, input, output])
143    }
144}
145
146#[cfg(test)]
147mod test {
148    use std::collections::HashSet;
149
150    use rstest::rstest;
151
152    use crate::builder::{
153        Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer,
154        endo_sig, inout_sig,
155    };
156    use crate::extension::prelude::qb_t;
157    use crate::hugr::HugrMut;
158    use crate::ops::handle::{DfgID, NodeHandle};
159    use crate::ops::{OpType, Value};
160    use crate::std_extensions::arithmetic::float_types;
161    use crate::std_extensions::arithmetic::int_ops::IntOpDef;
162    use crate::std_extensions::arithmetic::int_types::{self, ConstInt};
163    use crate::types::Signature;
164    use crate::utils::test_quantum_extension;
165    use crate::{Direction, HugrView, Port, type_row};
166    use crate::{Hugr, Wire};
167
168    use super::InlineDFG;
169
170    fn find_dfgs<H: HugrView>(h: &H) -> Vec<H::Node> {
171        h.entry_descendants()
172            .filter(|n| h.get_optype(*n).as_dfg().is_some())
173            .collect()
174    }
175    fn extension_ops<H: HugrView>(h: &H) -> Vec<H::Node> {
176        h.nodes()
177            .filter(|n| matches!(h.get_optype(*n), OpType::ExtensionOp(_)))
178            .collect()
179    }
180
181    #[rstest]
182    #[case(true)]
183    #[case(false)]
184    fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box<dyn std::error::Error>> {
185        use crate::hugr::patch::inline_dfg::InlineDFGError;
186
187        let int_ty = &int_types::INT_TYPES[6];
188
189        let mut outer = DFGBuilder::new(inout_sig(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?;
190        let [a, b] = outer.input_wires_arr();
191        fn make_const<T: AsMut<Hugr> + AsRef<Hugr>>(
192            d: &mut DFGBuilder<T>,
193        ) -> Result<Wire, Box<dyn std::error::Error>> {
194            let cst = Value::extension(ConstInt::new_u(6, 15)?);
195            let c1 = d.add_load_const(cst);
196
197            Ok(c1)
198        }
199        let c1 = nonlocal.then(|| make_const(&mut outer));
200        let inner = {
201            let mut inner = outer.dfg_builder_endo([(int_ty.clone(), a)])?;
202            let [a] = inner.input_wires_arr();
203            let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?;
204            let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?;
205            inner.finish_with_outputs(a1.outputs())?
206        };
207        let [a1] = inner.outputs_arr();
208
209        let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_log_width(6), [a1, b])?;
210        let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs())?;
211
212        // Sanity checks
213        assert_eq!(
214            outer.children(inner.node()).count(),
215            if nonlocal { 3 } else { 5 }
216        ); // Input, Output, add; + const, load_const
217        assert_eq!(find_dfgs(&outer), vec![outer.entrypoint(), inner.node()]);
218        let [add, sub] = extension_ops(&outer).try_into().unwrap();
219        assert_eq!(
220            outer.get_parent(outer.get_parent(add).unwrap()),
221            outer.get_parent(sub)
222        );
223        assert_eq!(outer.entry_descendants().count(), 10); // 6 above + inner DFG + outer (DFG + Input + Output + sub)
224        {
225            // Check we can't inline the outer DFG
226            let mut h = outer.clone();
227            assert_eq!(
228                h.apply_patch(InlineDFG(DfgID::from(h.entrypoint()))),
229                Err(InlineDFGError::CantInlineEntrypoint {
230                    node: h.entrypoint()
231                })
232            );
233            assert_eq!(h, outer); // unchanged
234        }
235
236        outer.apply_patch(InlineDFG(*inner.handle()))?;
237        outer.validate()?;
238        assert_eq!(outer.entry_descendants().count(), 7);
239        assert_eq!(find_dfgs(&outer), vec![outer.entrypoint()]);
240        let [add, sub] = extension_ops(&outer).try_into().unwrap();
241        assert_eq!(outer.get_parent(add), Some(outer.entrypoint()));
242        assert_eq!(outer.get_parent(sub), Some(outer.entrypoint()));
243        assert_eq!(
244            outer.node_connections(add, sub).collect::<Vec<_>>().len(),
245            1
246        );
247        Ok(())
248    }
249
250    #[test]
251    fn permutation() -> Result<(), Box<dyn std::error::Error>> {
252        let mut h = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
253        let [p, q] = h.input_wires_arr();
254        let [p_h] = h
255            .add_dataflow_op(test_quantum_extension::h_gate(), [p])?
256            .outputs_arr();
257        let swap = {
258            let swap = h.dfg_builder(Signature::new_endo(vec![qb_t(), qb_t()]), [p_h, q])?;
259            let [a, b] = swap.input_wires_arr();
260            swap.finish_with_outputs([b, a])?
261        };
262        let [q, p] = swap.outputs_arr();
263        let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?;
264
265        let mut h = h.finish_hugr_with_outputs(cx.outputs())?;
266        assert_eq!(find_dfgs(&h), vec![h.entrypoint(), swap.node()]);
267        assert_eq!(h.entry_descendants().count(), 8); // Dfg+I+O, H, CX, Dfg+I+O
268        // No permutation outside the swap DFG:
269        assert_eq!(
270            h.node_connections(p_h.node(), swap.node())
271                .collect::<Vec<_>>(),
272            vec![[
273                Port::new(Direction::Outgoing, 0),
274                Port::new(Direction::Incoming, 0)
275            ]]
276        );
277        assert_eq!(
278            h.node_connections(swap.node(), cx.node())
279                .collect::<Vec<_>>(),
280            vec![
281                [
282                    Port::new(Direction::Outgoing, 0),
283                    Port::new(Direction::Incoming, 0)
284                ],
285                [
286                    Port::new(Direction::Outgoing, 1),
287                    Port::new(Direction::Incoming, 1)
288                ]
289            ]
290        );
291
292        h.apply_patch(InlineDFG(*swap.handle()))?;
293        assert_eq!(find_dfgs(&h), vec![h.entrypoint()]);
294        assert_eq!(h.entry_descendants().count(), 5); // Dfg+I+O
295        let mut ops = extension_ops(&h);
296        ops.sort_by_key(|n| h.num_outputs(*n)); // Put H before CX
297        let [h_gate, cx] = ops.try_into().unwrap();
298        // Now permutation exists:
299        assert_eq!(
300            h.node_connections(h_gate, cx).collect::<Vec<_>>(),
301            vec![[
302                Port::new(Direction::Outgoing, 0),
303                Port::new(Direction::Incoming, 1)
304            ]]
305        );
306        Ok(())
307    }
308
309    #[test]
310    fn order_edges() -> Result<(), Box<dyn std::error::Error>> {
311        /*      -----|-----|-----
312         *           |     |
313         *          H_a   H_b
314         *           |.    /         NB. Order edge H_a to nested DFG
315         *           | .  |
316         *           |  /-|--------\
317         *           |  | | .  Cst | NB. Order edge Input to LCst
318         *           |  | |  . |   |
319         *           |  | |   LCst |
320         *           |  |  \ /     |
321         *           |  |  RZ      |
322         *           |  |  |       |
323         *           |  |  meas    |
324         *           |  |  | \     |
325         *           |  |  |  if   |
326         *           |  |  |  .    | NB. Order edge if to Output
327         *           |  \--|-------/
328         *           |  .  |
329         *           | .   |         NB. Order edge nested DFG to H_a2
330         *           H_a2  /
331         *             \  /
332         *              CX
333         */
334        // Extension inference here relies on quantum ops not requiring their own test_quantum_extension
335        let mut outer = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
336        let [a, b] = outer.input_wires_arr();
337        let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?;
338        let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?;
339        let mut inner = outer.dfg_builder(endo_sig(qb_t()), h_b.outputs())?;
340        let [i] = inner.input_wires_arr();
341        let f = inner.add_load_value(float_types::ConstF64::new(1.0));
342        inner.add_other_wire(inner.input().node(), f.node());
343        let r = inner.add_dataflow_op(test_quantum_extension::rz_f64(), [i, f])?;
344        let [m, b] = inner
345            .add_dataflow_op(test_quantum_extension::measure(), r.outputs())?
346            .outputs_arr();
347        // Node using the boolean. Here we just select between two empty computations.
348        let mut if_n =
349            inner.conditional_builder(([type_row![], type_row![]], b), [], type_row![])?;
350        if_n.case_builder(0)?.finish_with_outputs([])?;
351        if_n.case_builder(1)?.finish_with_outputs([])?;
352        let if_n = if_n.finish_sub_container()?;
353        inner.add_other_wire(if_n.node(), inner.output().node());
354        let inner = inner.finish_with_outputs([m])?;
355        outer.add_other_wire(h_a.node(), inner.node());
356        let h_a2 = outer.add_dataflow_op(test_quantum_extension::h_gate(), h_a.outputs())?;
357        outer.add_other_wire(inner.node(), h_a2.node());
358        let cx = outer.add_dataflow_op(
359            test_quantum_extension::cx_gate(),
360            h_a2.outputs().chain(inner.outputs()),
361        )?;
362        let mut outer = outer.finish_hugr_with_outputs(cx.outputs())?;
363
364        outer.apply_patch(InlineDFG(*inner.handle()))?;
365        outer.validate()?;
366        let order_neighbours = |n, d| {
367            let p = outer.get_optype(n).other_port(d).unwrap();
368            outer
369                .linked_ports(n, p)
370                .map(|(n, _)| n)
371                .collect::<HashSet<_>>()
372        };
373        // h_a should have Order edges added to Rz and the F64 load_const
374        assert_eq!(
375            order_neighbours(h_a.node(), Direction::Outgoing),
376            HashSet::from([r.node(), f.node()])
377        );
378        // Likewise the load_const should have Order edges from the inputs to the inner DFG, i.e. h_a and h_b
379        assert_eq!(
380            order_neighbours(f.node(), Direction::Incoming),
381            HashSet::from([h_a.node(), h_b.node()])
382        );
383        // h_a2 should have Order edges from the measure and if
384        assert_eq!(
385            order_neighbours(h_a2.node(), Direction::Incoming),
386            HashSet::from([m.node(), if_n.node()])
387        );
388        // the if should have Order edges to the CX and h_a2
389        assert_eq!(
390            order_neighbours(if_n.node(), Direction::Outgoing),
391            HashSet::from([h_a2.node(), cx.node()])
392        );
393        Ok(())
394    }
395}