1use super::{PatchHugrMut, PatchVerification};
6use crate::ops::handle::{DfgID, NodeHandle};
7use crate::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex};
8
9pub struct InlineDFG(pub DfgID);
11
12#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
14#[non_exhaustive]
15pub enum InlineDFGError {
16 #[error("{node} was not a DFG")]
18 NotDFG {
19 node: Node,
21 },
22 #[error("Cannot inline the entrypoint node, {node}")]
24 CantInlineEntrypoint {
25 node: Node,
27 },
28}
29
30impl PatchVerification for InlineDFG {
31 type Error = InlineDFGError;
32
33 type Node = Node;
34
35 fn verify(&self, h: &impl crate::HugrView<Node = Node>) -> Result<(), Self::Error> {
36 let n = self.0.node();
37 if h.get_optype(n).as_dfg().is_none() {
38 return Err(InlineDFGError::NotDFG { node: n });
39 }
40 if n == h.entrypoint() {
41 return Err(InlineDFGError::CantInlineEntrypoint { node: n });
42 }
43 Ok(())
44 }
45
46 fn invalidated_nodes(
47 &self,
48 _: &impl HugrView<Node = Self::Node>,
49 ) -> impl Iterator<Item = Self::Node> {
50 [self.0.node()].into_iter()
51 }
52}
53
54impl PatchHugrMut for InlineDFG {
55 type Outcome = [Node; 3];
57
58 const UNCHANGED_ON_FAILURE: bool = true;
59
60 fn apply_hugr_mut(
61 self,
62 h: &mut impl crate::hugr::HugrMut<Node = Node>,
63 ) -> Result<Self::Outcome, Self::Error> {
64 self.verify(h)?;
65 let n = self.0.node();
66 let (oth_in, oth_out) = {
67 let dfg_ty = h.get_optype(n);
68 (
69 dfg_ty.other_input_port().unwrap(),
70 dfg_ty.other_output_port().unwrap(),
71 )
72 };
73 let parent = h.get_parent(n).unwrap();
74 let [input, output] = h.get_io(n).unwrap();
75 for ch in h.children(n).skip(2).collect::<Vec<_>>() {
76 h.set_parent(ch, parent);
77 }
78 for (src_n, src_p) in h.linked_outputs(n, oth_in).collect::<Vec<_>>() {
80 debug_assert_eq!(Some(src_p), h.get_optype(src_n).other_output_port());
82 for tgt_n in h.output_neighbours(input).collect::<Vec<_>>() {
83 h.add_other_edge(src_n, tgt_n);
84 }
85 }
86 let input_ord_succs = h
88 .linked_inputs(input, h.get_optype(input).other_output_port().unwrap())
89 .collect::<Vec<_>>();
90 for inp in h.node_inputs(n).collect::<Vec<_>>() {
91 if inp == oth_in {
92 continue;
93 }
94 let (src_n, src_p) = h.single_linked_output(n, inp).unwrap();
96 h.disconnect(n, inp); let outp = OutgoingPort::from(inp.index());
98 let targets = h.linked_inputs(input, outp).collect::<Vec<_>>();
99 h.disconnect(input, outp);
100
101 for (tgt_n, tgt_p) in targets {
102 h.connect(src_n, src_p, tgt_n, tgt_p);
103 }
104 for (tgt, _) in &input_ord_succs {
106 h.add_other_edge(src_n, *tgt);
107 }
108 }
109 for (tgt_n, tgt_p) in h.linked_inputs(n, oth_out).collect::<Vec<_>>() {
111 debug_assert_eq!(Some(tgt_p), h.get_optype(tgt_n).other_input_port());
112 for src_n in h.input_neighbours(output).collect::<Vec<_>>() {
113 h.add_other_edge(src_n, tgt_n);
114 }
115 }
116 let output_ord_preds = h
118 .linked_outputs(output, h.get_optype(output).other_input_port().unwrap())
119 .collect::<Vec<_>>();
120 for outport in h.node_outputs(n).collect::<Vec<_>>() {
121 if outport == oth_out {
122 continue;
123 }
124 let inpp = IncomingPort::from(outport.index());
125 let (src_n, src_p) = h.single_linked_output(output, inpp).unwrap();
127 h.disconnect(output, inpp);
128
129 for (tgt_n, tgt_p) in h.linked_inputs(n, outport).collect::<Vec<_>>() {
130 h.connect(src_n, src_p, tgt_n, tgt_p);
131 for (src, _) in &output_ord_preds {
133 h.add_other_edge(*src, tgt_n);
134 }
135 }
136 h.disconnect(n, outport);
137 }
138 h.remove_node(input);
139 h.remove_node(output);
140 assert!(h.children(n).next().is_none());
141 h.remove_node(n);
142 Ok([n, input, output])
143 }
144}
145
146#[cfg(test)]
147mod test {
148 use std::collections::HashSet;
149
150 use rstest::rstest;
151
152 use crate::builder::{
153 Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer,
154 endo_sig, inout_sig,
155 };
156 use crate::extension::prelude::qb_t;
157 use crate::hugr::HugrMut;
158 use crate::ops::handle::{DfgID, NodeHandle};
159 use crate::ops::{OpType, Value};
160 use crate::std_extensions::arithmetic::float_types;
161 use crate::std_extensions::arithmetic::int_ops::IntOpDef;
162 use crate::std_extensions::arithmetic::int_types::{self, ConstInt};
163 use crate::types::Signature;
164 use crate::utils::test_quantum_extension;
165 use crate::{Direction, HugrView, Port, type_row};
166 use crate::{Hugr, Wire};
167
168 use super::InlineDFG;
169
170 fn find_dfgs<H: HugrView>(h: &H) -> Vec<H::Node> {
171 h.entry_descendants()
172 .filter(|n| h.get_optype(*n).as_dfg().is_some())
173 .collect()
174 }
175 fn extension_ops<H: HugrView>(h: &H) -> Vec<H::Node> {
176 h.nodes()
177 .filter(|n| matches!(h.get_optype(*n), OpType::ExtensionOp(_)))
178 .collect()
179 }
180
181 #[rstest]
182 #[case(true)]
183 #[case(false)]
184 fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box<dyn std::error::Error>> {
185 use crate::hugr::patch::inline_dfg::InlineDFGError;
186
187 let int_ty = &int_types::INT_TYPES[6];
188
189 let mut outer = DFGBuilder::new(inout_sig(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?;
190 let [a, b] = outer.input_wires_arr();
191 fn make_const<T: AsMut<Hugr> + AsRef<Hugr>>(
192 d: &mut DFGBuilder<T>,
193 ) -> Result<Wire, Box<dyn std::error::Error>> {
194 let cst = Value::extension(ConstInt::new_u(6, 15)?);
195 let c1 = d.add_load_const(cst);
196
197 Ok(c1)
198 }
199 let c1 = nonlocal.then(|| make_const(&mut outer));
200 let inner = {
201 let mut inner = outer.dfg_builder_endo([(int_ty.clone(), a)])?;
202 let [a] = inner.input_wires_arr();
203 let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?;
204 let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?;
205 inner.finish_with_outputs(a1.outputs())?
206 };
207 let [a1] = inner.outputs_arr();
208
209 let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_log_width(6), [a1, b])?;
210 let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs())?;
211
212 assert_eq!(
214 outer.children(inner.node()).count(),
215 if nonlocal { 3 } else { 5 }
216 ); assert_eq!(find_dfgs(&outer), vec![outer.entrypoint(), inner.node()]);
218 let [add, sub] = extension_ops(&outer).try_into().unwrap();
219 assert_eq!(
220 outer.get_parent(outer.get_parent(add).unwrap()),
221 outer.get_parent(sub)
222 );
223 assert_eq!(outer.entry_descendants().count(), 10); {
225 let mut h = outer.clone();
227 assert_eq!(
228 h.apply_patch(InlineDFG(DfgID::from(h.entrypoint()))),
229 Err(InlineDFGError::CantInlineEntrypoint {
230 node: h.entrypoint()
231 })
232 );
233 assert_eq!(h, outer); }
235
236 outer.apply_patch(InlineDFG(*inner.handle()))?;
237 outer.validate()?;
238 assert_eq!(outer.entry_descendants().count(), 7);
239 assert_eq!(find_dfgs(&outer), vec![outer.entrypoint()]);
240 let [add, sub] = extension_ops(&outer).try_into().unwrap();
241 assert_eq!(outer.get_parent(add), Some(outer.entrypoint()));
242 assert_eq!(outer.get_parent(sub), Some(outer.entrypoint()));
243 assert_eq!(
244 outer.node_connections(add, sub).collect::<Vec<_>>().len(),
245 1
246 );
247 Ok(())
248 }
249
250 #[test]
251 fn permutation() -> Result<(), Box<dyn std::error::Error>> {
252 let mut h = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
253 let [p, q] = h.input_wires_arr();
254 let [p_h] = h
255 .add_dataflow_op(test_quantum_extension::h_gate(), [p])?
256 .outputs_arr();
257 let swap = {
258 let swap = h.dfg_builder(Signature::new_endo(vec![qb_t(), qb_t()]), [p_h, q])?;
259 let [a, b] = swap.input_wires_arr();
260 swap.finish_with_outputs([b, a])?
261 };
262 let [q, p] = swap.outputs_arr();
263 let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?;
264
265 let mut h = h.finish_hugr_with_outputs(cx.outputs())?;
266 assert_eq!(find_dfgs(&h), vec![h.entrypoint(), swap.node()]);
267 assert_eq!(h.entry_descendants().count(), 8); assert_eq!(
270 h.node_connections(p_h.node(), swap.node())
271 .collect::<Vec<_>>(),
272 vec![[
273 Port::new(Direction::Outgoing, 0),
274 Port::new(Direction::Incoming, 0)
275 ]]
276 );
277 assert_eq!(
278 h.node_connections(swap.node(), cx.node())
279 .collect::<Vec<_>>(),
280 vec![
281 [
282 Port::new(Direction::Outgoing, 0),
283 Port::new(Direction::Incoming, 0)
284 ],
285 [
286 Port::new(Direction::Outgoing, 1),
287 Port::new(Direction::Incoming, 1)
288 ]
289 ]
290 );
291
292 h.apply_patch(InlineDFG(*swap.handle()))?;
293 assert_eq!(find_dfgs(&h), vec![h.entrypoint()]);
294 assert_eq!(h.entry_descendants().count(), 5); let mut ops = extension_ops(&h);
296 ops.sort_by_key(|n| h.num_outputs(*n)); let [h_gate, cx] = ops.try_into().unwrap();
298 assert_eq!(
300 h.node_connections(h_gate, cx).collect::<Vec<_>>(),
301 vec![[
302 Port::new(Direction::Outgoing, 0),
303 Port::new(Direction::Incoming, 1)
304 ]]
305 );
306 Ok(())
307 }
308
309 #[test]
310 fn order_edges() -> Result<(), Box<dyn std::error::Error>> {
311 let mut outer = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
336 let [a, b] = outer.input_wires_arr();
337 let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?;
338 let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?;
339 let mut inner = outer.dfg_builder(endo_sig(qb_t()), h_b.outputs())?;
340 let [i] = inner.input_wires_arr();
341 let f = inner.add_load_value(float_types::ConstF64::new(1.0));
342 inner.add_other_wire(inner.input().node(), f.node());
343 let r = inner.add_dataflow_op(test_quantum_extension::rz_f64(), [i, f])?;
344 let [m, b] = inner
345 .add_dataflow_op(test_quantum_extension::measure(), r.outputs())?
346 .outputs_arr();
347 let mut if_n =
349 inner.conditional_builder(([type_row![], type_row![]], b), [], type_row![])?;
350 if_n.case_builder(0)?.finish_with_outputs([])?;
351 if_n.case_builder(1)?.finish_with_outputs([])?;
352 let if_n = if_n.finish_sub_container()?;
353 inner.add_other_wire(if_n.node(), inner.output().node());
354 let inner = inner.finish_with_outputs([m])?;
355 outer.add_other_wire(h_a.node(), inner.node());
356 let h_a2 = outer.add_dataflow_op(test_quantum_extension::h_gate(), h_a.outputs())?;
357 outer.add_other_wire(inner.node(), h_a2.node());
358 let cx = outer.add_dataflow_op(
359 test_quantum_extension::cx_gate(),
360 h_a2.outputs().chain(inner.outputs()),
361 )?;
362 let mut outer = outer.finish_hugr_with_outputs(cx.outputs())?;
363
364 outer.apply_patch(InlineDFG(*inner.handle()))?;
365 outer.validate()?;
366 let order_neighbours = |n, d| {
367 let p = outer.get_optype(n).other_port(d).unwrap();
368 outer
369 .linked_ports(n, p)
370 .map(|(n, _)| n)
371 .collect::<HashSet<_>>()
372 };
373 assert_eq!(
375 order_neighbours(h_a.node(), Direction::Outgoing),
376 HashSet::from([r.node(), f.node()])
377 );
378 assert_eq!(
380 order_neighbours(f.node(), Direction::Incoming),
381 HashSet::from([h_a.node(), h_b.node()])
382 );
383 assert_eq!(
385 order_neighbours(h_a2.node(), Direction::Incoming),
386 HashSet::from([m.node(), if_n.node()])
387 );
388 assert_eq!(
390 order_neighbours(if_n.node(), Direction::Outgoing),
391 HashSet::from([h_a2.node(), cx.node()])
392 );
393 Ok(())
394 }
395}