hugr_core/hugr/rewrite/
inline_dfg.rs

//! A rewrite that inlines a DFG node, moving all children
//! of the DFG except Input+Output into the DFG's parent,
//! and deleting the DFG along with its Input and Output nodes.
4
5use super::Rewrite;
6use crate::ops::handle::{DfgID, NodeHandle};
7use crate::{IncomingPort, Node, OutgoingPort, PortIndex};
8
9/// Structure identifying an `InlineDFG` rewrite from the spec
10pub struct InlineDFG(pub DfgID);
11
12/// Errors from an [InlineDFG] rewrite.
13#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
14#[non_exhaustive]
15pub enum InlineDFGError {
16    /// Node to inline was not a DFG. (E.g. node has been overwritten since the DfgID originated.)
17    #[error("Node {0} was not a DFG")]
18    NotDFG(Node),
19    /// DFG has no parent (is the root).
20    #[error("Node did not have a parent into which to inline")]
21    NoParent,
22}
23
24impl Rewrite for InlineDFG {
25    /// Returns the removed nodes: the DFG, and its Input and Output children.
26    type ApplyResult = [Node; 3];
27    type Error = InlineDFGError;
28
29    const UNCHANGED_ON_FAILURE: bool = true;
30
31    fn verify(&self, h: &impl crate::HugrView<Node = Node>) -> Result<(), Self::Error> {
32        let n = self.0.node();
33        if h.get_optype(n).as_dfg().is_none() {
34            return Err(InlineDFGError::NotDFG(n));
35        };
36        if h.get_parent(n).is_none() {
37            return Err(InlineDFGError::NoParent);
38        };
39        Ok(())
40    }
41
42    fn apply(self, h: &mut impl crate::hugr::HugrMut) -> Result<Self::ApplyResult, Self::Error> {
43        self.verify(h)?;
44        let n = self.0.node();
45        let (oth_in, oth_out) = {
46            let dfg_ty = h.get_optype(n);
47            (
48                dfg_ty.other_input_port().unwrap(),
49                dfg_ty.other_output_port().unwrap(),
50            )
51        };
52        let parent = h.get_parent(n).unwrap();
53        let [input, output] = h.get_io(n).unwrap();
54        for ch in h.children(n).skip(2).collect::<Vec<_>>().into_iter() {
55            h.set_parent(ch, parent);
56        }
57        // DFG Inputs. Deal with Order inputs first
58        for (src_n, src_p) in h.linked_outputs(n, oth_in).collect::<Vec<_>>() {
59            // Order edge from src_n to DFG => add order edge to each successor of Input node
60            debug_assert_eq!(Some(src_p), h.get_optype(src_n).other_output_port());
61            for tgt_n in h.output_neighbours(input).collect::<Vec<_>>() {
62                h.add_other_edge(src_n, tgt_n);
63            }
64        }
65        // And remaining (Value) inputs
66        let input_ord_succs = h
67            .linked_inputs(input, h.get_optype(input).other_output_port().unwrap())
68            .collect::<Vec<_>>();
69        for inp in h.node_inputs(n).collect::<Vec<_>>() {
70            if inp == oth_in {
71                continue;
72            };
73            // Hugr is invalid if there is no output linked to the DFG input.
74            let (src_n, src_p) = h.single_linked_output(n, inp).unwrap();
75            h.disconnect(n, inp); // These disconnects allow permutations to work trivially.
76            let outp = OutgoingPort::from(inp.index());
77            let targets = h.linked_inputs(input, outp).collect::<Vec<_>>();
78            h.disconnect(input, outp);
79
80            for (tgt_n, tgt_p) in targets {
81                h.connect(src_n, src_p, tgt_n, tgt_p);
82            }
83            // Ensure order-successors of Input node execute after any node producing an input
84            for (tgt, _) in input_ord_succs.iter() {
85                h.add_other_edge(src_n, *tgt);
86            }
87        }
88        // DFG Outputs. Deal with Order outputs first.
89        for (tgt_n, tgt_p) in h.linked_inputs(n, oth_out).collect::<Vec<_>>() {
90            debug_assert_eq!(Some(tgt_p), h.get_optype(tgt_n).other_input_port());
91            for src_n in h.input_neighbours(output).collect::<Vec<_>>() {
92                h.add_other_edge(src_n, tgt_n);
93            }
94        }
95        // And remaining (Value) outputs
96        let output_ord_preds = h
97            .linked_outputs(output, h.get_optype(output).other_input_port().unwrap())
98            .collect::<Vec<_>>();
99        for outport in h.node_outputs(n).collect::<Vec<_>>() {
100            if outport == oth_out {
101                continue;
102            };
103            let inpp = IncomingPort::from(outport.index());
104            // Hugr is invalid if the Output node has no corresponding input
105            let (src_n, src_p) = h.single_linked_output(output, inpp).unwrap();
106            h.disconnect(output, inpp);
107
108            for (tgt_n, tgt_p) in h.linked_inputs(n, outport).collect::<Vec<_>>() {
109                h.connect(src_n, src_p, tgt_n, tgt_p);
110                // Ensure order-predecessors of Output node execute before any node consuming a DFG output
111                for (src, _) in output_ord_preds.iter() {
112                    h.add_other_edge(*src, tgt_n);
113                }
114            }
115            h.disconnect(n, outport);
116        }
117        h.remove_node(input);
118        h.remove_node(output);
119        assert!(h.children(n).next().is_none());
120        h.remove_node(n);
121        Ok([n, input, output])
122    }
123
124    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
125        [self.0.node()].into_iter()
126    }
127}
128
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use rstest::rstest;

    use crate::builder::{
        endo_sig, inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
        SubContainer,
    };
    use crate::extension::prelude::qb_t;
    use crate::extension::ExtensionSet;
    use crate::hugr::rewrite::inline_dfg::InlineDFGError;
    use crate::hugr::HugrMut;
    use crate::ops::handle::{DfgID, NodeHandle};
    use crate::ops::{OpType, Value};
    use crate::std_extensions::arithmetic::float_types;
    use crate::std_extensions::arithmetic::int_ops::IntOpDef;
    use crate::std_extensions::arithmetic::int_types::{self, ConstInt};
    use crate::types::Signature;
    use crate::utils::test_quantum_extension;
    use crate::{type_row, Direction, HugrView, Port};
    use crate::{Hugr, Wire};

    use super::InlineDFG;

    /// All nodes of `h` whose optype is a DFG, in `h.nodes()` order.
    fn find_dfgs<H: HugrView>(h: &H) -> Vec<H::Node> {
        h.nodes()
            .filter(|n| h.get_optype(*n).as_dfg().is_some())
            .collect()
    }

    /// All nodes of `h` whose optype is an `ExtensionOp`, in `h.nodes()` order.
    fn extension_ops<H: HugrView>(h: &H) -> Vec<H::Node> {
        h.nodes()
            .filter(|n| matches!(h.get_optype(*n), OpType::ExtensionOp(_)))
            .collect()
    }

    #[rstest]
    #[case(true)]
    #[case(false)]
    fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box<dyn std::error::Error>> {
        let int_ty = &int_types::INT_TYPES[6];

        let mut outer = DFGBuilder::new(inout_sig(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?;
        let [a, b] = outer.input_wires_arr();
        fn make_const<T: AsMut<Hugr> + AsRef<Hugr>>(
            d: &mut DFGBuilder<T>,
        ) -> Result<Wire, Box<dyn std::error::Error>> {
            let cst = Value::extension(ConstInt::new_u(6, 15)?);
            let c1 = d.add_load_const(cst);

            Ok(c1)
        }
        // When `nonlocal`, the constant lives in the outer DFG and reaches the inner one
        // via a non-local edge.
        let c1 = nonlocal.then(|| make_const(&mut outer));
        let inner = {
            let mut inner = outer.dfg_builder_endo([(int_ty.clone(), a)])?;
            let [a] = inner.input_wires_arr();
            let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?;
            let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?;
            inner.finish_with_outputs(a1.outputs())?
        };
        let [a1] = inner.outputs_arr();

        let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_log_width(6), [a1, b])?;
        let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs())?;

        // Sanity checks
        assert_eq!(
            outer.children(inner.node()).count(),
            if nonlocal { 3 } else { 5 }
        ); // Input, Output, add; + const, load_const
        assert_eq!(find_dfgs(&outer), vec![outer.root(), inner.node()]);
        let [add, sub] = extension_ops(&outer).try_into().unwrap();
        assert_eq!(
            outer.get_parent(outer.get_parent(add).unwrap()),
            outer.get_parent(sub)
        );
        // 5 above (inner Input, Output, add, const, load_const) + inner DFG
        // + outer (DFG + Input + Output + sub)
        assert_eq!(outer.nodes().count(), 10);
        {
            // Check we can't inline the outer DFG
            let mut h = outer.clone();
            assert_eq!(
                h.apply_rewrite(InlineDFG(DfgID::from(h.root()))),
                Err(InlineDFGError::NoParent)
            );
            assert_eq!(h, outer); // unchanged
        }

        outer.apply_rewrite(InlineDFG(*inner.handle()))?;
        outer.validate()?;
        // Outer DFG + Input + Output + sub + add + const + load_const
        assert_eq!(outer.nodes().count(), 7);
        assert_eq!(find_dfgs(&outer), vec![outer.root()]);
        let [add, sub] = extension_ops(&outer).try_into().unwrap();
        assert_eq!(outer.get_parent(add), Some(outer.root()));
        assert_eq!(outer.get_parent(sub), Some(outer.root()));
        assert_eq!(outer.node_connections(add, sub).count(), 1);
        Ok(())
    }

    #[test]
    fn permutation() -> Result<(), Box<dyn std::error::Error>> {
        let mut h = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
        let [p, q] = h.input_wires_arr();
        let [p_h] = h
            .add_dataflow_op(test_quantum_extension::h_gate(), [p])?
            .outputs_arr();
        let swap = {
            let swap = h.dfg_builder(Signature::new_endo(vec![qb_t(), qb_t()]), [p_h, q])?;
            let [a, b] = swap.input_wires_arr();
            swap.finish_with_outputs([b, a])?
        };
        let [q, p] = swap.outputs_arr();
        let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?;

        let mut h = h.finish_hugr_with_outputs(cx.outputs())?;
        assert_eq!(find_dfgs(&h), vec![h.root(), swap.node()]);
        assert_eq!(h.nodes().count(), 8); // Dfg+I+O, H, CX, Dfg+I+O

        // No permutation outside the swap DFG:
        assert_eq!(
            h.node_connections(p_h.node(), swap.node())
                .collect::<Vec<_>>(),
            vec![[
                Port::new(Direction::Outgoing, 0),
                Port::new(Direction::Incoming, 0)
            ]]
        );
        assert_eq!(
            h.node_connections(swap.node(), cx.node())
                .collect::<Vec<_>>(),
            vec![
                [
                    Port::new(Direction::Outgoing, 0),
                    Port::new(Direction::Incoming, 0)
                ],
                [
                    Port::new(Direction::Outgoing, 1),
                    Port::new(Direction::Incoming, 1)
                ]
            ]
        );

        h.apply_rewrite(InlineDFG(*swap.handle()))?;
        assert_eq!(find_dfgs(&h), vec![h.root()]);
        assert_eq!(h.nodes().count(), 5); // Dfg+I+O, H, CX
        let mut ops = extension_ops(&h);
        ops.sort_by_key(|n| h.num_outputs(*n)); // Put H before CX
        let [h_gate, cx] = ops.try_into().unwrap();
        // Now permutation exists:
        assert_eq!(
            h.node_connections(h_gate, cx).collect::<Vec<_>>(),
            vec![[
                Port::new(Direction::Outgoing, 0),
                Port::new(Direction::Incoming, 1)
            ]]
        );
        Ok(())
    }

    #[test]
    fn order_edges() -> Result<(), Box<dyn std::error::Error>> {
        /*      -----|-----|-----
         *           |     |
         *          H_a   H_b
         *           |.    /         NB. Order edge H_a to nested DFG
         *           | .  |
         *           |  /-|--------\
         *           |  | | .  Cst | NB. Order edge Input to LCst
         *           |  | |  . |   |
         *           |  | |   LCst |
         *           |  |  \ /     |
         *           |  |  RZ      |
         *           |  |  |       |
         *           |  |  meas    |
         *           |  |  | \     |
         *           |  |  |  if   |
         *           |  |  |  .    | NB. Order edge if to Output
         *           |  \--|-------/
         *           |  .  |
         *           | .   |         NB. Order edge nested DFG to H_a2
         *           H_a2  /
         *             \  /
         *              CX
         */
        // Extension inference here relies on quantum ops not requiring their own extension
        let mut outer = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
        let [a, b] = outer.input_wires_arr();
        let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?;
        let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?;
        let mut inner = outer.dfg_builder(endo_sig(qb_t()), h_b.outputs())?;
        let [i] = inner.input_wires_arr();
        let f = inner.add_load_value(float_types::ConstF64::new(1.0));
        inner.add_other_wire(inner.input().node(), f.node());
        let r = inner.add_dataflow_op(test_quantum_extension::rz_f64(), [i, f])?;
        let [m, b] = inner
            .add_dataflow_op(test_quantum_extension::measure(), r.outputs())?
            .outputs_arr();
        // Node using the boolean. Here we just select between two empty computations.
        let mut if_n = inner.conditional_builder_exts(
            ([type_row![], type_row![]], b),
            [],
            type_row![],
            ExtensionSet::new(),
        )?;
        if_n.case_builder(0)?.finish_with_outputs([])?;
        if_n.case_builder(1)?.finish_with_outputs([])?;
        let if_n = if_n.finish_sub_container()?;
        inner.add_other_wire(if_n.node(), inner.output().node());
        let inner = inner.finish_with_outputs([m])?;
        outer.add_other_wire(h_a.node(), inner.node());
        let h_a2 = outer.add_dataflow_op(test_quantum_extension::h_gate(), h_a.outputs())?;
        outer.add_other_wire(inner.node(), h_a2.node());
        let cx = outer.add_dataflow_op(
            test_quantum_extension::cx_gate(),
            h_a2.outputs().chain(inner.outputs()),
        )?;
        let mut outer = outer.finish_hugr_with_outputs(cx.outputs())?;

        outer.apply_rewrite(InlineDFG(*inner.handle()))?;
        outer.validate()?;
        // Nodes linked to `n` via its Order port in direction `d`.
        let order_neighbours = |n, d| {
            let p = outer.get_optype(n).other_port(d).unwrap();
            outer
                .linked_ports(n, p)
                .map(|(n, _)| n)
                .collect::<HashSet<_>>()
        };
        // h_a should have Order edges added to Rz and the F64 load_const
        assert_eq!(
            order_neighbours(h_a.node(), Direction::Outgoing),
            HashSet::from([r.node(), f.node()])
        );
        // Likewise the load_const should have Order edges from the inputs to the inner DFG, i.e. h_a and h_b
        assert_eq!(
            order_neighbours(f.node(), Direction::Incoming),
            HashSet::from([h_a.node(), h_b.node()])
        );
        // h_a2 should have Order edges from the measure and if
        assert_eq!(
            order_neighbours(h_a2.node(), Direction::Incoming),
            HashSet::from([m.node(), if_n.node()])
        );
        // the if should have Order edges to the CX and h_a2
        assert_eq!(
            order_neighbours(if_n.node(), Direction::Outgoing),
            HashSet::from([h_a2.node(), cx.node()])
        );
        Ok(())
    }
}