1use super::{PatchHugrMut, PatchVerification};
6use crate::core::HugrNode;
7use crate::ops::handle::{DfgID, NodeHandle};
8use crate::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex};
9
10pub struct InlineDFG<N = Node>(pub DfgID<N>);
12
13#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
15#[non_exhaustive]
16pub enum InlineDFGError<N = Node> {
17 #[error("{node} was not a DFG")]
19 NotDFG {
20 node: N,
22 },
23 #[error("Cannot inline the entrypoint node, {node}")]
25 CantInlineEntrypoint {
26 node: N,
28 },
29}
30
31impl<N: HugrNode> PatchVerification for InlineDFG<N> {
32 type Error = InlineDFGError<N>;
33
34 type Node = N;
35
36 fn verify(&self, h: &impl crate::HugrView<Node = N>) -> Result<(), Self::Error> {
37 let n = self.0.node();
38 if h.get_optype(n).as_dfg().is_none() {
39 return Err(InlineDFGError::NotDFG { node: n });
40 }
41 if n == h.entrypoint() {
42 return Err(InlineDFGError::CantInlineEntrypoint { node: n });
43 }
44 Ok(())
45 }
46
47 fn invalidated_nodes(
48 &self,
49 _: &impl HugrView<Node = Self::Node>,
50 ) -> impl Iterator<Item = Self::Node> {
51 [self.0.node()].into_iter()
52 }
53}
54
55impl<N: HugrNode> PatchHugrMut for InlineDFG<N> {
56 type Outcome = [N; 3];
58
59 const UNCHANGED_ON_FAILURE: bool = true;
60
61 fn apply_hugr_mut(
62 self,
63 h: &mut impl crate::hugr::HugrMut<Node = N>,
64 ) -> Result<Self::Outcome, Self::Error> {
65 self.verify(h)?;
66 let n = self.0.node();
67 let (oth_in, oth_out) = {
68 let dfg_ty = h.get_optype(n);
69 (
70 dfg_ty.other_input_port().unwrap(),
71 dfg_ty.other_output_port().unwrap(),
72 )
73 };
74 let parent = h.get_parent(n).unwrap();
75 let [input, output] = h.get_io(n).unwrap();
76 for ch in h.children(n).skip(2).collect::<Vec<_>>() {
77 h.set_parent(ch, parent);
78 }
79 for (src_n, src_p) in h.linked_outputs(n, oth_in).collect::<Vec<_>>() {
81 debug_assert_eq!(Some(src_p), h.get_optype(src_n).other_output_port());
83 for tgt_n in h.output_neighbours(input).collect::<Vec<_>>() {
84 h.add_other_edge(src_n, tgt_n);
85 }
86 }
87 let input_ord_succs = h
89 .linked_inputs(input, h.get_optype(input).other_output_port().unwrap())
90 .collect::<Vec<_>>();
91 for inp in h.node_inputs(n).collect::<Vec<_>>() {
92 if inp == oth_in {
93 continue;
94 }
95 let (src_n, src_p) = h.single_linked_output(n, inp).unwrap();
97 h.disconnect(n, inp); let outp = OutgoingPort::from(inp.index());
99 let targets = h.linked_inputs(input, outp).collect::<Vec<_>>();
100 h.disconnect(input, outp);
101
102 for (tgt_n, tgt_p) in targets {
103 h.connect(src_n, src_p, tgt_n, tgt_p);
104 }
105 for (tgt, _) in &input_ord_succs {
107 h.add_other_edge(src_n, *tgt);
108 }
109 }
110 for (tgt_n, tgt_p) in h.linked_inputs(n, oth_out).collect::<Vec<_>>() {
112 debug_assert_eq!(Some(tgt_p), h.get_optype(tgt_n).other_input_port());
113 for src_n in h.input_neighbours(output).collect::<Vec<_>>() {
114 h.add_other_edge(src_n, tgt_n);
115 }
116 }
117 let output_ord_preds = h
119 .linked_outputs(output, h.get_optype(output).other_input_port().unwrap())
120 .collect::<Vec<_>>();
121 for outport in h.node_outputs(n).collect::<Vec<_>>() {
122 if outport == oth_out {
123 continue;
124 }
125 let inpp = IncomingPort::from(outport.index());
126 let (src_n, src_p) = h.single_linked_output(output, inpp).unwrap();
128 h.disconnect(output, inpp);
129
130 for (tgt_n, tgt_p) in h.linked_inputs(n, outport).collect::<Vec<_>>() {
131 h.connect(src_n, src_p, tgt_n, tgt_p);
132 for (src, _) in &output_ord_preds {
134 h.add_other_edge(*src, tgt_n);
135 }
136 }
137 h.disconnect(n, outport);
138 }
139 h.remove_node(input);
140 h.remove_node(output);
141 assert!(h.children(n).next().is_none());
142 h.remove_node(n);
143 Ok([n, input, output])
144 }
145}
146
147#[cfg(test)]
148mod test {
149 use std::collections::HashSet;
150
151 use rstest::rstest;
152
153 use crate::builder::{
154 Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer,
155 endo_sig, inout_sig,
156 };
157 use crate::extension::prelude::qb_t;
158 use crate::hugr::HugrMut;
159 use crate::ops::handle::{DfgID, NodeHandle};
160 use crate::ops::{OpType, Value};
161 use crate::std_extensions::arithmetic::float_types;
162 use crate::std_extensions::arithmetic::int_ops::IntOpDef;
163 use crate::std_extensions::arithmetic::int_types::{self, ConstInt};
164 use crate::types::Signature;
165 use crate::utils::test_quantum_extension;
166 use crate::{Direction, HugrView, Port, type_row};
167 use crate::{Hugr, Wire};
168
169 use super::InlineDFG;
170
171 fn find_dfgs<H: HugrView>(h: &H) -> Vec<H::Node> {
172 h.entry_descendants()
173 .filter(|n| h.get_optype(*n).as_dfg().is_some())
174 .collect()
175 }
176 fn extension_ops<H: HugrView>(h: &H) -> Vec<H::Node> {
177 h.nodes()
178 .filter(|n| matches!(h.get_optype(*n), OpType::ExtensionOp(_)))
179 .collect()
180 }
181
182 #[rstest]
183 #[case(true)]
184 #[case(false)]
185 fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box<dyn std::error::Error>> {
186 use crate::hugr::patch::inline_dfg::InlineDFGError;
187
188 let int_ty = &int_types::INT_TYPES[6];
189
190 let mut outer = DFGBuilder::new(inout_sig(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?;
191 let [a, b] = outer.input_wires_arr();
192 fn make_const<T: AsMut<Hugr> + AsRef<Hugr>>(
193 d: &mut DFGBuilder<T>,
194 ) -> Result<Wire, Box<dyn std::error::Error>> {
195 let cst = Value::extension(ConstInt::new_u(6, 15)?);
196 let c1 = d.add_load_const(cst);
197
198 Ok(c1)
199 }
200 let c1 = nonlocal.then(|| make_const(&mut outer));
201 let inner = {
202 let mut inner = outer.dfg_builder_endo([(int_ty.clone(), a)])?;
203 let [a] = inner.input_wires_arr();
204 let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?;
205 let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?;
206 inner.finish_with_outputs(a1.outputs())?
207 };
208 let [a1] = inner.outputs_arr();
209
210 let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_log_width(6), [a1, b])?;
211 let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs())?;
212
213 assert_eq!(
215 outer.children(inner.node()).count(),
216 if nonlocal { 3 } else { 5 }
217 ); assert_eq!(find_dfgs(&outer), vec![outer.entrypoint(), inner.node()]);
219 let [add, sub] = extension_ops(&outer).try_into().unwrap();
220 assert_eq!(
221 outer.get_parent(outer.get_parent(add).unwrap()),
222 outer.get_parent(sub)
223 );
224 assert_eq!(outer.entry_descendants().count(), 10); {
226 let mut h = outer.clone();
228 assert_eq!(
229 h.apply_patch(InlineDFG(DfgID::from(h.entrypoint()))),
230 Err(InlineDFGError::CantInlineEntrypoint {
231 node: h.entrypoint()
232 })
233 );
234 assert_eq!(h, outer); }
236
237 outer.apply_patch(InlineDFG(*inner.handle()))?;
238 outer.validate()?;
239 assert_eq!(outer.entry_descendants().count(), 7);
240 assert_eq!(find_dfgs(&outer), vec![outer.entrypoint()]);
241 let [add, sub] = extension_ops(&outer).try_into().unwrap();
242 assert_eq!(outer.get_parent(add), Some(outer.entrypoint()));
243 assert_eq!(outer.get_parent(sub), Some(outer.entrypoint()));
244 assert_eq!(
245 outer.node_connections(add, sub).collect::<Vec<_>>().len(),
246 1
247 );
248 Ok(())
249 }
250
251 #[test]
252 fn permutation() -> Result<(), Box<dyn std::error::Error>> {
253 let mut h = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
254 let [p, q] = h.input_wires_arr();
255 let [p_h] = h
256 .add_dataflow_op(test_quantum_extension::h_gate(), [p])?
257 .outputs_arr();
258 let swap = {
259 let swap = h.dfg_builder(Signature::new_endo(vec![qb_t(), qb_t()]), [p_h, q])?;
260 let [a, b] = swap.input_wires_arr();
261 swap.finish_with_outputs([b, a])?
262 };
263 let [q, p] = swap.outputs_arr();
264 let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?;
265
266 let mut h = h.finish_hugr_with_outputs(cx.outputs())?;
267 assert_eq!(find_dfgs(&h), vec![h.entrypoint(), swap.node()]);
268 assert_eq!(h.entry_descendants().count(), 8); assert_eq!(
271 h.node_connections(p_h.node(), swap.node())
272 .collect::<Vec<_>>(),
273 vec![[
274 Port::new(Direction::Outgoing, 0),
275 Port::new(Direction::Incoming, 0)
276 ]]
277 );
278 assert_eq!(
279 h.node_connections(swap.node(), cx.node())
280 .collect::<Vec<_>>(),
281 vec![
282 [
283 Port::new(Direction::Outgoing, 0),
284 Port::new(Direction::Incoming, 0)
285 ],
286 [
287 Port::new(Direction::Outgoing, 1),
288 Port::new(Direction::Incoming, 1)
289 ]
290 ]
291 );
292
293 h.apply_patch(InlineDFG(*swap.handle()))?;
294 assert_eq!(find_dfgs(&h), vec![h.entrypoint()]);
295 assert_eq!(h.entry_descendants().count(), 5); let mut ops = extension_ops(&h);
297 ops.sort_by_key(|n| h.num_outputs(*n)); let [h_gate, cx] = ops.try_into().unwrap();
299 assert_eq!(
301 h.node_connections(h_gate, cx).collect::<Vec<_>>(),
302 vec![[
303 Port::new(Direction::Outgoing, 0),
304 Port::new(Direction::Incoming, 1)
305 ]]
306 );
307 Ok(())
308 }
309
310 #[test]
311 fn order_edges() -> Result<(), Box<dyn std::error::Error>> {
312 let mut outer = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
337 let [a, b] = outer.input_wires_arr();
338 let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?;
339 let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?;
340 let mut inner = outer.dfg_builder(endo_sig(qb_t()), h_b.outputs())?;
341 let [i] = inner.input_wires_arr();
342 let f = inner.add_load_value(float_types::ConstF64::new(1.0));
343 inner.add_other_wire(inner.input().node(), f.node());
344 let r = inner.add_dataflow_op(test_quantum_extension::rz_f64(), [i, f])?;
345 let [m, b] = inner
346 .add_dataflow_op(test_quantum_extension::measure(), r.outputs())?
347 .outputs_arr();
348 let mut if_n =
350 inner.conditional_builder(([type_row![], type_row![]], b), [], type_row![])?;
351 if_n.case_builder(0)?.finish_with_outputs([])?;
352 if_n.case_builder(1)?.finish_with_outputs([])?;
353 let if_n = if_n.finish_sub_container()?;
354 inner.add_other_wire(if_n.node(), inner.output().node());
355 let inner = inner.finish_with_outputs([m])?;
356 outer.add_other_wire(h_a.node(), inner.node());
357 let h_a2 = outer.add_dataflow_op(test_quantum_extension::h_gate(), h_a.outputs())?;
358 outer.add_other_wire(inner.node(), h_a2.node());
359 let cx = outer.add_dataflow_op(
360 test_quantum_extension::cx_gate(),
361 h_a2.outputs().chain(inner.outputs()),
362 )?;
363 let mut outer = outer.finish_hugr_with_outputs(cx.outputs())?;
364
365 outer.apply_patch(InlineDFG(*inner.handle()))?;
366 outer.validate()?;
367 let order_neighbours = |n, d| {
368 let p = outer.get_optype(n).other_port(d).unwrap();
369 outer
370 .linked_ports(n, p)
371 .map(|(n, _)| n)
372 .collect::<HashSet<_>>()
373 };
374 assert_eq!(
376 order_neighbours(h_a.node(), Direction::Outgoing),
377 HashSet::from([r.node(), f.node()])
378 );
379 assert_eq!(
381 order_neighbours(f.node(), Direction::Incoming),
382 HashSet::from([h_a.node(), h_b.node()])
383 );
384 assert_eq!(
386 order_neighbours(h_a2.node(), Direction::Incoming),
387 HashSet::from([m.node(), if_n.node()])
388 );
389 assert_eq!(
391 order_neighbours(if_n.node(), Direction::Outgoing),
392 HashSet::from([h_a2.node(), cx.node()])
393 );
394 Ok(())
395 }
396}