1use super::{PatchHugrMut, PatchVerification};
6use crate::ops::handle::{DfgID, NodeHandle};
7use crate::{IncomingPort, Node, OutgoingPort, PortIndex};
8
9pub struct InlineDFG(pub DfgID);
11
12#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
14#[non_exhaustive]
15pub enum InlineDFGError {
16 #[error("{node} was not a DFG")]
18 NotDFG {
19 node: Node,
21 },
22 #[error("Cannot inline the entrypoint node, {node}")]
24 CantInlineEntrypoint {
25 node: Node,
27 },
28}
29
30impl PatchVerification for InlineDFG {
31 type Error = InlineDFGError;
32
33 type Node = Node;
34
35 fn verify(&self, h: &impl crate::HugrView<Node = Node>) -> Result<(), Self::Error> {
36 let n = self.0.node();
37 if h.get_optype(n).as_dfg().is_none() {
38 return Err(InlineDFGError::NotDFG { node: n });
39 }
40 if n == h.entrypoint() {
41 return Err(InlineDFGError::CantInlineEntrypoint { node: n });
42 }
43 Ok(())
44 }
45
46 fn invalidation_set(&self) -> impl Iterator<Item = Node> {
47 [self.0.node()].into_iter()
48 }
49}
50
51impl PatchHugrMut for InlineDFG {
52 type Outcome = [Node; 3];
54
55 const UNCHANGED_ON_FAILURE: bool = true;
56
57 fn apply_hugr_mut(
58 self,
59 h: &mut impl crate::hugr::HugrMut<Node = Node>,
60 ) -> Result<Self::Outcome, Self::Error> {
61 self.verify(h)?;
62 let n = self.0.node();
63 let (oth_in, oth_out) = {
64 let dfg_ty = h.get_optype(n);
65 (
66 dfg_ty.other_input_port().unwrap(),
67 dfg_ty.other_output_port().unwrap(),
68 )
69 };
70 let parent = h.get_parent(n).unwrap();
71 let [input, output] = h.get_io(n).unwrap();
72 for ch in h.children(n).skip(2).collect::<Vec<_>>() {
73 h.set_parent(ch, parent);
74 }
75 for (src_n, src_p) in h.linked_outputs(n, oth_in).collect::<Vec<_>>() {
77 debug_assert_eq!(Some(src_p), h.get_optype(src_n).other_output_port());
79 for tgt_n in h.output_neighbours(input).collect::<Vec<_>>() {
80 h.add_other_edge(src_n, tgt_n);
81 }
82 }
83 let input_ord_succs = h
85 .linked_inputs(input, h.get_optype(input).other_output_port().unwrap())
86 .collect::<Vec<_>>();
87 for inp in h.node_inputs(n).collect::<Vec<_>>() {
88 if inp == oth_in {
89 continue;
90 }
91 let (src_n, src_p) = h.single_linked_output(n, inp).unwrap();
93 h.disconnect(n, inp); let outp = OutgoingPort::from(inp.index());
95 let targets = h.linked_inputs(input, outp).collect::<Vec<_>>();
96 h.disconnect(input, outp);
97
98 for (tgt_n, tgt_p) in targets {
99 h.connect(src_n, src_p, tgt_n, tgt_p);
100 }
101 for (tgt, _) in &input_ord_succs {
103 h.add_other_edge(src_n, *tgt);
104 }
105 }
106 for (tgt_n, tgt_p) in h.linked_inputs(n, oth_out).collect::<Vec<_>>() {
108 debug_assert_eq!(Some(tgt_p), h.get_optype(tgt_n).other_input_port());
109 for src_n in h.input_neighbours(output).collect::<Vec<_>>() {
110 h.add_other_edge(src_n, tgt_n);
111 }
112 }
113 let output_ord_preds = h
115 .linked_outputs(output, h.get_optype(output).other_input_port().unwrap())
116 .collect::<Vec<_>>();
117 for outport in h.node_outputs(n).collect::<Vec<_>>() {
118 if outport == oth_out {
119 continue;
120 }
121 let inpp = IncomingPort::from(outport.index());
122 let (src_n, src_p) = h.single_linked_output(output, inpp).unwrap();
124 h.disconnect(output, inpp);
125
126 for (tgt_n, tgt_p) in h.linked_inputs(n, outport).collect::<Vec<_>>() {
127 h.connect(src_n, src_p, tgt_n, tgt_p);
128 for (src, _) in &output_ord_preds {
130 h.add_other_edge(*src, tgt_n);
131 }
132 }
133 h.disconnect(n, outport);
134 }
135 h.remove_node(input);
136 h.remove_node(output);
137 assert!(h.children(n).next().is_none());
138 h.remove_node(n);
139 Ok([n, input, output])
140 }
141}
142
143#[cfg(test)]
144mod test {
145 use std::collections::HashSet;
146
147 use rstest::rstest;
148
149 use crate::builder::{
150 Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer,
151 endo_sig, inout_sig,
152 };
153 use crate::extension::prelude::qb_t;
154 use crate::hugr::HugrMut;
155 use crate::ops::handle::{DfgID, NodeHandle};
156 use crate::ops::{OpType, Value};
157 use crate::std_extensions::arithmetic::float_types;
158 use crate::std_extensions::arithmetic::int_ops::IntOpDef;
159 use crate::std_extensions::arithmetic::int_types::{self, ConstInt};
160 use crate::types::Signature;
161 use crate::utils::test_quantum_extension;
162 use crate::{Direction, HugrView, Port, type_row};
163 use crate::{Hugr, Wire};
164
165 use super::InlineDFG;
166
167 fn find_dfgs<H: HugrView>(h: &H) -> Vec<H::Node> {
168 h.entry_descendants()
169 .filter(|n| h.get_optype(*n).as_dfg().is_some())
170 .collect()
171 }
172 fn extension_ops<H: HugrView>(h: &H) -> Vec<H::Node> {
173 h.nodes()
174 .filter(|n| matches!(h.get_optype(*n), OpType::ExtensionOp(_)))
175 .collect()
176 }
177
178 #[rstest]
179 #[case(true)]
180 #[case(false)]
181 fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box<dyn std::error::Error>> {
182 use crate::hugr::patch::inline_dfg::InlineDFGError;
183
184 let int_ty = &int_types::INT_TYPES[6];
185
186 let mut outer = DFGBuilder::new(inout_sig(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?;
187 let [a, b] = outer.input_wires_arr();
188 fn make_const<T: AsMut<Hugr> + AsRef<Hugr>>(
189 d: &mut DFGBuilder<T>,
190 ) -> Result<Wire, Box<dyn std::error::Error>> {
191 let cst = Value::extension(ConstInt::new_u(6, 15)?);
192 let c1 = d.add_load_const(cst);
193
194 Ok(c1)
195 }
196 let c1 = nonlocal.then(|| make_const(&mut outer));
197 let inner = {
198 let mut inner = outer.dfg_builder_endo([(int_ty.clone(), a)])?;
199 let [a] = inner.input_wires_arr();
200 let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?;
201 let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?;
202 inner.finish_with_outputs(a1.outputs())?
203 };
204 let [a1] = inner.outputs_arr();
205
206 let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_log_width(6), [a1, b])?;
207 let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs())?;
208
209 assert_eq!(
211 outer.children(inner.node()).count(),
212 if nonlocal { 3 } else { 5 }
213 ); assert_eq!(find_dfgs(&outer), vec![outer.entrypoint(), inner.node()]);
215 let [add, sub] = extension_ops(&outer).try_into().unwrap();
216 assert_eq!(
217 outer.get_parent(outer.get_parent(add).unwrap()),
218 outer.get_parent(sub)
219 );
220 assert_eq!(outer.entry_descendants().count(), 10); {
222 let mut h = outer.clone();
224 assert_eq!(
225 h.apply_patch(InlineDFG(DfgID::from(h.entrypoint()))),
226 Err(InlineDFGError::CantInlineEntrypoint {
227 node: h.entrypoint()
228 })
229 );
230 assert_eq!(h, outer); }
232
233 outer.apply_patch(InlineDFG(*inner.handle()))?;
234 outer.validate()?;
235 assert_eq!(outer.entry_descendants().count(), 7);
236 assert_eq!(find_dfgs(&outer), vec![outer.entrypoint()]);
237 let [add, sub] = extension_ops(&outer).try_into().unwrap();
238 assert_eq!(outer.get_parent(add), Some(outer.entrypoint()));
239 assert_eq!(outer.get_parent(sub), Some(outer.entrypoint()));
240 assert_eq!(
241 outer.node_connections(add, sub).collect::<Vec<_>>().len(),
242 1
243 );
244 Ok(())
245 }
246
247 #[test]
248 fn permutation() -> Result<(), Box<dyn std::error::Error>> {
249 let mut h = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
250 let [p, q] = h.input_wires_arr();
251 let [p_h] = h
252 .add_dataflow_op(test_quantum_extension::h_gate(), [p])?
253 .outputs_arr();
254 let swap = {
255 let swap = h.dfg_builder(Signature::new_endo(vec![qb_t(), qb_t()]), [p_h, q])?;
256 let [a, b] = swap.input_wires_arr();
257 swap.finish_with_outputs([b, a])?
258 };
259 let [q, p] = swap.outputs_arr();
260 let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?;
261
262 let mut h = h.finish_hugr_with_outputs(cx.outputs())?;
263 assert_eq!(find_dfgs(&h), vec![h.entrypoint(), swap.node()]);
264 assert_eq!(h.entry_descendants().count(), 8); assert_eq!(
267 h.node_connections(p_h.node(), swap.node())
268 .collect::<Vec<_>>(),
269 vec![[
270 Port::new(Direction::Outgoing, 0),
271 Port::new(Direction::Incoming, 0)
272 ]]
273 );
274 assert_eq!(
275 h.node_connections(swap.node(), cx.node())
276 .collect::<Vec<_>>(),
277 vec![
278 [
279 Port::new(Direction::Outgoing, 0),
280 Port::new(Direction::Incoming, 0)
281 ],
282 [
283 Port::new(Direction::Outgoing, 1),
284 Port::new(Direction::Incoming, 1)
285 ]
286 ]
287 );
288
289 h.apply_patch(InlineDFG(*swap.handle()))?;
290 assert_eq!(find_dfgs(&h), vec![h.entrypoint()]);
291 assert_eq!(h.entry_descendants().count(), 5); let mut ops = extension_ops(&h);
293 ops.sort_by_key(|n| h.num_outputs(*n)); let [h_gate, cx] = ops.try_into().unwrap();
295 assert_eq!(
297 h.node_connections(h_gate, cx).collect::<Vec<_>>(),
298 vec![[
299 Port::new(Direction::Outgoing, 0),
300 Port::new(Direction::Incoming, 1)
301 ]]
302 );
303 Ok(())
304 }
305
306 #[test]
307 fn order_edges() -> Result<(), Box<dyn std::error::Error>> {
308 let mut outer = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
333 let [a, b] = outer.input_wires_arr();
334 let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?;
335 let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?;
336 let mut inner = outer.dfg_builder(endo_sig(qb_t()), h_b.outputs())?;
337 let [i] = inner.input_wires_arr();
338 let f = inner.add_load_value(float_types::ConstF64::new(1.0));
339 inner.add_other_wire(inner.input().node(), f.node());
340 let r = inner.add_dataflow_op(test_quantum_extension::rz_f64(), [i, f])?;
341 let [m, b] = inner
342 .add_dataflow_op(test_quantum_extension::measure(), r.outputs())?
343 .outputs_arr();
344 let mut if_n =
346 inner.conditional_builder(([type_row![], type_row![]], b), [], type_row![])?;
347 if_n.case_builder(0)?.finish_with_outputs([])?;
348 if_n.case_builder(1)?.finish_with_outputs([])?;
349 let if_n = if_n.finish_sub_container()?;
350 inner.add_other_wire(if_n.node(), inner.output().node());
351 let inner = inner.finish_with_outputs([m])?;
352 outer.add_other_wire(h_a.node(), inner.node());
353 let h_a2 = outer.add_dataflow_op(test_quantum_extension::h_gate(), h_a.outputs())?;
354 outer.add_other_wire(inner.node(), h_a2.node());
355 let cx = outer.add_dataflow_op(
356 test_quantum_extension::cx_gate(),
357 h_a2.outputs().chain(inner.outputs()),
358 )?;
359 let mut outer = outer.finish_hugr_with_outputs(cx.outputs())?;
360
361 outer.apply_patch(InlineDFG(*inner.handle()))?;
362 outer.validate()?;
363 let order_neighbours = |n, d| {
364 let p = outer.get_optype(n).other_port(d).unwrap();
365 outer
366 .linked_ports(n, p)
367 .map(|(n, _)| n)
368 .collect::<HashSet<_>>()
369 };
370 assert_eq!(
372 order_neighbours(h_a.node(), Direction::Outgoing),
373 HashSet::from([r.node(), f.node()])
374 );
375 assert_eq!(
377 order_neighbours(f.node(), Direction::Incoming),
378 HashSet::from([h_a.node(), h_b.node()])
379 );
380 assert_eq!(
382 order_neighbours(h_a2.node(), Direction::Incoming),
383 HashSet::from([m.node(), if_n.node()])
384 );
385 assert_eq!(
387 order_neighbours(if_n.node(), Direction::Outgoing),
388 HashSet::from([h_a2.node(), cx.node()])
389 );
390 Ok(())
391 }
392}