hugr_core/hugr/patch/
inline_dfg.rs

1//! A rewrite that inlines a DFG node, moving all children
2//! of the DFG except Input+Output into the DFG's parent,
3//! and deleting the DFG along with its Input + Output
4
5use super::{PatchHugrMut, PatchVerification};
6use crate::core::HugrNode;
7use crate::ops::handle::{DfgID, NodeHandle};
8use crate::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex};
9
10/// Structure identifying an `InlineDFG` rewrite from the spec
11pub struct InlineDFG<N = Node>(pub DfgID<N>);
12
13/// Errors from an [`InlineDFG`] rewrite.
14#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
15#[non_exhaustive]
16pub enum InlineDFGError<N = Node> {
17    /// Node to inline was not a DFG. (E.g. node has been overwritten since the `DfgID` originated.)
18    #[error("{node} was not a DFG")]
19    NotDFG {
20        /// The node we tried to inline
21        node: N,
22    },
23    /// The DFG node is the hugr entrypoint
24    #[error("Cannot inline the entrypoint node, {node}")]
25    CantInlineEntrypoint {
26        /// The node we tried to inline
27        node: N,
28    },
29}
30
31impl<N: HugrNode> PatchVerification for InlineDFG<N> {
32    type Error = InlineDFGError<N>;
33
34    type Node = N;
35
36    fn verify(&self, h: &impl crate::HugrView<Node = N>) -> Result<(), Self::Error> {
37        let n = self.0.node();
38        if h.get_optype(n).as_dfg().is_none() {
39            return Err(InlineDFGError::NotDFG { node: n });
40        }
41        if n == h.entrypoint() {
42            return Err(InlineDFGError::CantInlineEntrypoint { node: n });
43        }
44        Ok(())
45    }
46
47    fn invalidated_nodes(
48        &self,
49        _: &impl HugrView<Node = Self::Node>,
50    ) -> impl Iterator<Item = Self::Node> {
51        [self.0.node()].into_iter()
52    }
53}
54
55impl<N: HugrNode> PatchHugrMut for InlineDFG<N> {
56    /// The removed nodes: the DFG, and its Input and Output children.
57    type Outcome = [N; 3];
58
59    const UNCHANGED_ON_FAILURE: bool = true;
60
61    fn apply_hugr_mut(
62        self,
63        h: &mut impl crate::hugr::HugrMut<Node = N>,
64    ) -> Result<Self::Outcome, Self::Error> {
65        self.verify(h)?;
66        let n = self.0.node();
67        let (oth_in, oth_out) = {
68            let dfg_ty = h.get_optype(n);
69            (
70                dfg_ty.other_input_port().unwrap(),
71                dfg_ty.other_output_port().unwrap(),
72            )
73        };
74        let parent = h.get_parent(n).unwrap();
75        let [input, output] = h.get_io(n).unwrap();
76        for ch in h.children(n).skip(2).collect::<Vec<_>>() {
77            h.set_parent(ch, parent);
78        }
79        // DFG Inputs. Deal with Order inputs first
80        for (src_n, src_p) in h.linked_outputs(n, oth_in).collect::<Vec<_>>() {
81            // Order edge from src_n to DFG => add order edge to each successor of Input node
82            debug_assert_eq!(Some(src_p), h.get_optype(src_n).other_output_port());
83            for tgt_n in h.output_neighbours(input).collect::<Vec<_>>() {
84                h.add_other_edge(src_n, tgt_n);
85            }
86        }
87        // And remaining (Value) inputs
88        let input_ord_succs = h
89            .linked_inputs(input, h.get_optype(input).other_output_port().unwrap())
90            .collect::<Vec<_>>();
91        for inp in h.node_inputs(n).collect::<Vec<_>>() {
92            if inp == oth_in {
93                continue;
94            }
95            // Hugr is invalid if there is no output linked to the DFG input.
96            let (src_n, src_p) = h.single_linked_output(n, inp).unwrap();
97            h.disconnect(n, inp); // These disconnects allow permutations to work trivially.
98            let outp = OutgoingPort::from(inp.index());
99            let targets = h.linked_inputs(input, outp).collect::<Vec<_>>();
100            h.disconnect(input, outp);
101
102            for (tgt_n, tgt_p) in targets {
103                h.connect(src_n, src_p, tgt_n, tgt_p);
104            }
105            // Ensure order-successors of Input node execute after any node producing an input
106            for (tgt, _) in &input_ord_succs {
107                h.add_other_edge(src_n, *tgt);
108            }
109        }
110        // DFG Outputs. Deal with Order outputs first.
111        for (tgt_n, tgt_p) in h.linked_inputs(n, oth_out).collect::<Vec<_>>() {
112            debug_assert_eq!(Some(tgt_p), h.get_optype(tgt_n).other_input_port());
113            for src_n in h.input_neighbours(output).collect::<Vec<_>>() {
114                h.add_other_edge(src_n, tgt_n);
115            }
116        }
117        // And remaining (Value) outputs
118        let output_ord_preds = h
119            .linked_outputs(output, h.get_optype(output).other_input_port().unwrap())
120            .collect::<Vec<_>>();
121        for outport in h.node_outputs(n).collect::<Vec<_>>() {
122            if outport == oth_out {
123                continue;
124            }
125            let inpp = IncomingPort::from(outport.index());
126            // Hugr is invalid if the Output node has no corresponding input
127            let (src_n, src_p) = h.single_linked_output(output, inpp).unwrap();
128            h.disconnect(output, inpp);
129
130            for (tgt_n, tgt_p) in h.linked_inputs(n, outport).collect::<Vec<_>>() {
131                h.connect(src_n, src_p, tgt_n, tgt_p);
132                // Ensure order-predecessors of Output node execute before any node consuming a DFG output
133                for (src, _) in &output_ord_preds {
134                    h.add_other_edge(*src, tgt_n);
135                }
136            }
137            h.disconnect(n, outport);
138        }
139        h.remove_node(input);
140        h.remove_node(output);
141        assert!(h.children(n).next().is_none());
142        h.remove_node(n);
143        Ok([n, input, output])
144    }
145}
146
147#[cfg(test)]
148mod test {
149    use std::collections::HashSet;
150
151    use rstest::rstest;
152
153    use crate::builder::{
154        Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer,
155        endo_sig, inout_sig,
156    };
157    use crate::extension::prelude::qb_t;
158    use crate::hugr::HugrMut;
159    use crate::ops::handle::{DfgID, NodeHandle};
160    use crate::ops::{OpType, Value};
161    use crate::std_extensions::arithmetic::float_types;
162    use crate::std_extensions::arithmetic::int_ops::IntOpDef;
163    use crate::std_extensions::arithmetic::int_types::{self, ConstInt};
164    use crate::types::Signature;
165    use crate::utils::test_quantum_extension;
166    use crate::{Direction, HugrView, Port, type_row};
167    use crate::{Hugr, Wire};
168
169    use super::InlineDFG;
170
171    fn find_dfgs<H: HugrView>(h: &H) -> Vec<H::Node> {
172        h.entry_descendants()
173            .filter(|n| h.get_optype(*n).as_dfg().is_some())
174            .collect()
175    }
176    fn extension_ops<H: HugrView>(h: &H) -> Vec<H::Node> {
177        h.nodes()
178            .filter(|n| matches!(h.get_optype(*n), OpType::ExtensionOp(_)))
179            .collect()
180    }
181
182    #[rstest]
183    #[case(true)]
184    #[case(false)]
185    fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box<dyn std::error::Error>> {
186        use crate::hugr::patch::inline_dfg::InlineDFGError;
187
188        let int_ty = &int_types::INT_TYPES[6];
189
190        let mut outer = DFGBuilder::new(inout_sig(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?;
191        let [a, b] = outer.input_wires_arr();
192        fn make_const<T: AsMut<Hugr> + AsRef<Hugr>>(
193            d: &mut DFGBuilder<T>,
194        ) -> Result<Wire, Box<dyn std::error::Error>> {
195            let cst = Value::extension(ConstInt::new_u(6, 15)?);
196            let c1 = d.add_load_const(cst);
197
198            Ok(c1)
199        }
200        let c1 = nonlocal.then(|| make_const(&mut outer));
201        let inner = {
202            let mut inner = outer.dfg_builder_endo([(int_ty.clone(), a)])?;
203            let [a] = inner.input_wires_arr();
204            let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?;
205            let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?;
206            inner.finish_with_outputs(a1.outputs())?
207        };
208        let [a1] = inner.outputs_arr();
209
210        let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_log_width(6), [a1, b])?;
211        let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs())?;
212
213        // Sanity checks
214        assert_eq!(
215            outer.children(inner.node()).count(),
216            if nonlocal { 3 } else { 5 }
217        ); // Input, Output, add; + const, load_const
218        assert_eq!(find_dfgs(&outer), vec![outer.entrypoint(), inner.node()]);
219        let [add, sub] = extension_ops(&outer).try_into().unwrap();
220        assert_eq!(
221            outer.get_parent(outer.get_parent(add).unwrap()),
222            outer.get_parent(sub)
223        );
224        assert_eq!(outer.entry_descendants().count(), 10); // 6 above + inner DFG + outer (DFG + Input + Output + sub)
225        {
226            // Check we can't inline the outer DFG
227            let mut h = outer.clone();
228            assert_eq!(
229                h.apply_patch(InlineDFG(DfgID::from(h.entrypoint()))),
230                Err(InlineDFGError::CantInlineEntrypoint {
231                    node: h.entrypoint()
232                })
233            );
234            assert_eq!(h, outer); // unchanged
235        }
236
237        outer.apply_patch(InlineDFG(*inner.handle()))?;
238        outer.validate()?;
239        assert_eq!(outer.entry_descendants().count(), 7);
240        assert_eq!(find_dfgs(&outer), vec![outer.entrypoint()]);
241        let [add, sub] = extension_ops(&outer).try_into().unwrap();
242        assert_eq!(outer.get_parent(add), Some(outer.entrypoint()));
243        assert_eq!(outer.get_parent(sub), Some(outer.entrypoint()));
244        assert_eq!(
245            outer.node_connections(add, sub).collect::<Vec<_>>().len(),
246            1
247        );
248        Ok(())
249    }
250
251    #[test]
252    fn permutation() -> Result<(), Box<dyn std::error::Error>> {
253        let mut h = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
254        let [p, q] = h.input_wires_arr();
255        let [p_h] = h
256            .add_dataflow_op(test_quantum_extension::h_gate(), [p])?
257            .outputs_arr();
258        let swap = {
259            let swap = h.dfg_builder(Signature::new_endo(vec![qb_t(), qb_t()]), [p_h, q])?;
260            let [a, b] = swap.input_wires_arr();
261            swap.finish_with_outputs([b, a])?
262        };
263        let [q, p] = swap.outputs_arr();
264        let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?;
265
266        let mut h = h.finish_hugr_with_outputs(cx.outputs())?;
267        assert_eq!(find_dfgs(&h), vec![h.entrypoint(), swap.node()]);
268        assert_eq!(h.entry_descendants().count(), 8); // Dfg+I+O, H, CX, Dfg+I+O
269        // No permutation outside the swap DFG:
270        assert_eq!(
271            h.node_connections(p_h.node(), swap.node())
272                .collect::<Vec<_>>(),
273            vec![[
274                Port::new(Direction::Outgoing, 0),
275                Port::new(Direction::Incoming, 0)
276            ]]
277        );
278        assert_eq!(
279            h.node_connections(swap.node(), cx.node())
280                .collect::<Vec<_>>(),
281            vec![
282                [
283                    Port::new(Direction::Outgoing, 0),
284                    Port::new(Direction::Incoming, 0)
285                ],
286                [
287                    Port::new(Direction::Outgoing, 1),
288                    Port::new(Direction::Incoming, 1)
289                ]
290            ]
291        );
292
293        h.apply_patch(InlineDFG(*swap.handle()))?;
294        assert_eq!(find_dfgs(&h), vec![h.entrypoint()]);
295        assert_eq!(h.entry_descendants().count(), 5); // Dfg+I+O
296        let mut ops = extension_ops(&h);
297        ops.sort_by_key(|n| h.num_outputs(*n)); // Put H before CX
298        let [h_gate, cx] = ops.try_into().unwrap();
299        // Now permutation exists:
300        assert_eq!(
301            h.node_connections(h_gate, cx).collect::<Vec<_>>(),
302            vec![[
303                Port::new(Direction::Outgoing, 0),
304                Port::new(Direction::Incoming, 1)
305            ]]
306        );
307        Ok(())
308    }
309
310    #[test]
311    fn order_edges() -> Result<(), Box<dyn std::error::Error>> {
312        /*      -----|-----|-----
313         *           |     |
314         *          H_a   H_b
315         *           |.    /         NB. Order edge H_a to nested DFG
316         *           | .  |
317         *           |  /-|--------\
318         *           |  | | .  Cst | NB. Order edge Input to LCst
319         *           |  | |  . |   |
320         *           |  | |   LCst |
321         *           |  |  \ /     |
322         *           |  |  RZ      |
323         *           |  |  |       |
324         *           |  |  meas    |
325         *           |  |  | \     |
326         *           |  |  |  if   |
327         *           |  |  |  .    | NB. Order edge if to Output
328         *           |  \--|-------/
329         *           |  .  |
330         *           | .   |         NB. Order edge nested DFG to H_a2
331         *           H_a2  /
332         *             \  /
333         *              CX
334         */
335        // Extension inference here relies on quantum ops not requiring their own test_quantum_extension
336        let mut outer = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
337        let [a, b] = outer.input_wires_arr();
338        let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?;
339        let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?;
340        let mut inner = outer.dfg_builder(endo_sig(qb_t()), h_b.outputs())?;
341        let [i] = inner.input_wires_arr();
342        let f = inner.add_load_value(float_types::ConstF64::new(1.0));
343        inner.add_other_wire(inner.input().node(), f.node());
344        let r = inner.add_dataflow_op(test_quantum_extension::rz_f64(), [i, f])?;
345        let [m, b] = inner
346            .add_dataflow_op(test_quantum_extension::measure(), r.outputs())?
347            .outputs_arr();
348        // Node using the boolean. Here we just select between two empty computations.
349        let mut if_n =
350            inner.conditional_builder(([type_row![], type_row![]], b), [], type_row![])?;
351        if_n.case_builder(0)?.finish_with_outputs([])?;
352        if_n.case_builder(1)?.finish_with_outputs([])?;
353        let if_n = if_n.finish_sub_container()?;
354        inner.add_other_wire(if_n.node(), inner.output().node());
355        let inner = inner.finish_with_outputs([m])?;
356        outer.add_other_wire(h_a.node(), inner.node());
357        let h_a2 = outer.add_dataflow_op(test_quantum_extension::h_gate(), h_a.outputs())?;
358        outer.add_other_wire(inner.node(), h_a2.node());
359        let cx = outer.add_dataflow_op(
360            test_quantum_extension::cx_gate(),
361            h_a2.outputs().chain(inner.outputs()),
362        )?;
363        let mut outer = outer.finish_hugr_with_outputs(cx.outputs())?;
364
365        outer.apply_patch(InlineDFG(*inner.handle()))?;
366        outer.validate()?;
367        let order_neighbours = |n, d| {
368            let p = outer.get_optype(n).other_port(d).unwrap();
369            outer
370                .linked_ports(n, p)
371                .map(|(n, _)| n)
372                .collect::<HashSet<_>>()
373        };
374        // h_a should have Order edges added to Rz and the F64 load_const
375        assert_eq!(
376            order_neighbours(h_a.node(), Direction::Outgoing),
377            HashSet::from([r.node(), f.node()])
378        );
379        // Likewise the load_const should have Order edges from the inputs to the inner DFG, i.e. h_a and h_b
380        assert_eq!(
381            order_neighbours(f.node(), Direction::Incoming),
382            HashSet::from([h_a.node(), h_b.node()])
383        );
384        // h_a2 should have Order edges from the measure and if
385        assert_eq!(
386            order_neighbours(h_a2.node(), Direction::Incoming),
387            HashSet::from([m.node(), if_n.node()])
388        );
389        // the if should have Order edges to the CX and h_a2
390        assert_eq!(
391            order_neighbours(if_n.node(), Direction::Outgoing),
392            HashSet::from([h_a2.node(), cx.node()])
393        );
394        Ok(())
395    }
396}