1use super::Rewrite;
6use crate::ops::handle::{DfgID, NodeHandle};
7use crate::{IncomingPort, Node, OutgoingPort, PortIndex};
8
9pub struct InlineDFG(pub DfgID);
11
12#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
14#[non_exhaustive]
15pub enum InlineDFGError {
16 #[error("Node {0} was not a DFG")]
18 NotDFG(Node),
19 #[error("Node did not have a parent into which to inline")]
21 NoParent,
22}
23
24impl Rewrite for InlineDFG {
25 type ApplyResult = [Node; 3];
27 type Error = InlineDFGError;
28
29 const UNCHANGED_ON_FAILURE: bool = true;
30
31 fn verify(&self, h: &impl crate::HugrView<Node = Node>) -> Result<(), Self::Error> {
32 let n = self.0.node();
33 if h.get_optype(n).as_dfg().is_none() {
34 return Err(InlineDFGError::NotDFG(n));
35 };
36 if h.get_parent(n).is_none() {
37 return Err(InlineDFGError::NoParent);
38 };
39 Ok(())
40 }
41
42 fn apply(self, h: &mut impl crate::hugr::HugrMut) -> Result<Self::ApplyResult, Self::Error> {
43 self.verify(h)?;
44 let n = self.0.node();
45 let (oth_in, oth_out) = {
46 let dfg_ty = h.get_optype(n);
47 (
48 dfg_ty.other_input_port().unwrap(),
49 dfg_ty.other_output_port().unwrap(),
50 )
51 };
52 let parent = h.get_parent(n).unwrap();
53 let [input, output] = h.get_io(n).unwrap();
54 for ch in h.children(n).skip(2).collect::<Vec<_>>().into_iter() {
55 h.set_parent(ch, parent);
56 }
57 for (src_n, src_p) in h.linked_outputs(n, oth_in).collect::<Vec<_>>() {
59 debug_assert_eq!(Some(src_p), h.get_optype(src_n).other_output_port());
61 for tgt_n in h.output_neighbours(input).collect::<Vec<_>>() {
62 h.add_other_edge(src_n, tgt_n);
63 }
64 }
65 let input_ord_succs = h
67 .linked_inputs(input, h.get_optype(input).other_output_port().unwrap())
68 .collect::<Vec<_>>();
69 for inp in h.node_inputs(n).collect::<Vec<_>>() {
70 if inp == oth_in {
71 continue;
72 };
73 let (src_n, src_p) = h.single_linked_output(n, inp).unwrap();
75 h.disconnect(n, inp); let outp = OutgoingPort::from(inp.index());
77 let targets = h.linked_inputs(input, outp).collect::<Vec<_>>();
78 h.disconnect(input, outp);
79
80 for (tgt_n, tgt_p) in targets {
81 h.connect(src_n, src_p, tgt_n, tgt_p);
82 }
83 for (tgt, _) in input_ord_succs.iter() {
85 h.add_other_edge(src_n, *tgt);
86 }
87 }
88 for (tgt_n, tgt_p) in h.linked_inputs(n, oth_out).collect::<Vec<_>>() {
90 debug_assert_eq!(Some(tgt_p), h.get_optype(tgt_n).other_input_port());
91 for src_n in h.input_neighbours(output).collect::<Vec<_>>() {
92 h.add_other_edge(src_n, tgt_n);
93 }
94 }
95 let output_ord_preds = h
97 .linked_outputs(output, h.get_optype(output).other_input_port().unwrap())
98 .collect::<Vec<_>>();
99 for outport in h.node_outputs(n).collect::<Vec<_>>() {
100 if outport == oth_out {
101 continue;
102 };
103 let inpp = IncomingPort::from(outport.index());
104 let (src_n, src_p) = h.single_linked_output(output, inpp).unwrap();
106 h.disconnect(output, inpp);
107
108 for (tgt_n, tgt_p) in h.linked_inputs(n, outport).collect::<Vec<_>>() {
109 h.connect(src_n, src_p, tgt_n, tgt_p);
110 for (src, _) in output_ord_preds.iter() {
112 h.add_other_edge(*src, tgt_n);
113 }
114 }
115 h.disconnect(n, outport);
116 }
117 h.remove_node(input);
118 h.remove_node(output);
119 assert!(h.children(n).next().is_none());
120 h.remove_node(n);
121 Ok([n, input, output])
122 }
123
124 fn invalidation_set(&self) -> impl Iterator<Item = Node> {
125 [self.0.node()].into_iter()
126 }
127}
128
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use rstest::rstest;

    use crate::builder::{
        endo_sig, inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
        SubContainer,
    };
    use crate::extension::prelude::qb_t;
    use crate::extension::ExtensionSet;
    use crate::hugr::rewrite::inline_dfg::InlineDFGError;
    use crate::hugr::HugrMut;
    use crate::ops::handle::{DfgID, NodeHandle};
    use crate::ops::{OpType, Value};
    use crate::std_extensions::arithmetic::float_types;
    use crate::std_extensions::arithmetic::int_ops::IntOpDef;
    use crate::std_extensions::arithmetic::int_types::{self, ConstInt};
    use crate::types::Signature;
    use crate::utils::test_quantum_extension;
    use crate::{type_row, Direction, HugrView, Port};
    use crate::{Hugr, Wire};

    use super::InlineDFG;

    /// All nodes of `h` whose optype is a DFG, in `h.nodes()` order.
    fn find_dfgs<H: HugrView>(h: &H) -> Vec<H::Node> {
        h.nodes()
            .filter(|n| h.get_optype(*n).as_dfg().is_some())
            .collect()
    }
    /// All nodes of `h` whose optype is an `ExtensionOp`, in `h.nodes()` order.
    fn extension_ops<H: HugrView>(h: &H) -> Vec<H::Node> {
        h.nodes()
            .filter(|n| matches!(h.get_optype(*n), OpType::ExtensionOp(_)))
            .collect()
    }

    /// Inlines a nested DFG computing `a + 15` whose constant is built either outside the
    /// nested DFG (`nonlocal == true`) or inside it.
    #[rstest]
    #[case(true)]
    #[case(false)]
    fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box<dyn std::error::Error>> {
        let int_ty = &int_types::INT_TYPES[6];

        let mut outer = DFGBuilder::new(inout_sig(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?;
        let [a, b] = outer.input_wires_arr();
        // Adds the constant 15 (as a 2^6-bit int) to `d` and returns the loaded wire.
        fn make_const<T: AsMut<Hugr> + AsRef<Hugr>>(
            d: &mut DFGBuilder<T>,
        ) -> Result<Wire, Box<dyn std::error::Error>> {
            let cst = Value::extension(ConstInt::new_u(6, 15)?);
            let c1 = d.add_load_const(cst);

            Ok(c1)
        }
        // When `nonlocal`, the constant lives in `outer` and is used across the DFG boundary.
        let c1 = nonlocal.then(|| make_const(&mut outer));
        let inner = {
            let mut inner = outer.dfg_builder_endo([(int_ty.clone(), a)])?;
            let [a] = inner.input_wires_arr();
            let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?;
            let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?;
            inner.finish_with_outputs(a1.outputs())?
        };
        let [a1] = inner.outputs_arr();

        let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_log_width(6), [a1, b])?;
        let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs())?;

        // Sanity checks on the Hugr before inlining.
        // Inner children: Input, Output and iadd, plus the nodes added by `make_const`
        // when the constant was built inside `inner`.
        assert_eq!(
            outer.children(inner.node()).count(),
            if nonlocal { 3 } else { 5 }
        );
        assert_eq!(find_dfgs(&outer), vec![outer.root(), inner.node()]);
        let [add, sub] = extension_ops(&outer).try_into().unwrap();
        // `add` is one level deeper than `sub`.
        assert_eq!(
            outer.get_parent(outer.get_parent(add).unwrap()),
            outer.get_parent(sub)
        );
        assert_eq!(outer.nodes().count(), 10);
        {
            // Inlining the root DFG fails (it has no parent) and leaves the Hugr unchanged.
            let mut h = outer.clone();
            assert_eq!(
                h.apply_rewrite(InlineDFG(DfgID::from(h.root()))),
                Err(InlineDFGError::NoParent)
            );
            assert_eq!(h, outer);
        }

        outer.apply_rewrite(InlineDFG(*inner.handle()))?;
        outer.validate()?;
        // The inner DFG and its Input/Output nodes are gone.
        assert_eq!(outer.nodes().count(), 7);
        assert_eq!(find_dfgs(&outer), vec![outer.root()]);
        let [add, sub] = extension_ops(&outer).try_into().unwrap();
        assert_eq!(outer.get_parent(add), Some(outer.root()));
        assert_eq!(outer.get_parent(sub), Some(outer.root()));
        // `add` now feeds `sub` directly.
        assert_eq!(
            outer.node_connections(add, sub).collect::<Vec<_>>().len(),
            1
        );
        Ok(())
    }

    /// Inlines a DFG that only swaps its two inputs (each Input port wired straight to an
    /// Output port).
    #[test]
    fn permutation() -> Result<(), Box<dyn std::error::Error>> {
        let mut h = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
        let [p, q] = h.input_wires_arr();
        let [p_h] = h
            .add_dataflow_op(test_quantum_extension::h_gate(), [p])?
            .outputs_arr();
        let swap = {
            let swap = h.dfg_builder(Signature::new_endo(vec![qb_t(), qb_t()]), [p_h, q])?;
            let [a, b] = swap.input_wires_arr();
            swap.finish_with_outputs([b, a])?
        };
        let [q, p] = swap.outputs_arr();
        let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?;

        let mut h = h.finish_hugr_with_outputs(cx.outputs())?;
        assert_eq!(find_dfgs(&h), vec![h.root(), swap.node()]);
        // root + Input + Output, H, CX, swap DFG + its Input + Output
        assert_eq!(h.nodes().count(), 8);
        assert_eq!(
            h.node_connections(p_h.node(), swap.node())
                .collect::<Vec<_>>(),
            vec![[
                Port::new(Direction::Outgoing, 0),
                Port::new(Direction::Incoming, 0)
            ]]
        );
        assert_eq!(
            h.node_connections(swap.node(), cx.node())
                .collect::<Vec<_>>(),
            vec![
                [
                    Port::new(Direction::Outgoing, 0),
                    Port::new(Direction::Incoming, 0)
                ],
                [
                    Port::new(Direction::Outgoing, 1),
                    Port::new(Direction::Incoming, 1)
                ]
            ]
        );

        h.apply_rewrite(InlineDFG(*swap.handle()))?;
        assert_eq!(find_dfgs(&h), vec![h.root()]);
        assert_eq!(h.nodes().count(), 5);
        let mut ops = extension_ops(&h);
        // H has fewer outputs than CX, so this puts H first.
        ops.sort_by_key(|n| h.num_outputs(*n));
        let [h_gate, cx] = ops.try_into().unwrap();
        // With the swap inlined, H's output feeds CX's second input directly.
        assert_eq!(
            h.node_connections(h_gate, cx).collect::<Vec<_>>(),
            vec![[
                Port::new(Direction::Outgoing, 0),
                Port::new(Direction::Incoming, 1)
            ]]
        );
        Ok(())
    }

    /// Checks that order edges touching the DFG, its Input and its Output are preserved.
    #[test]
    fn order_edges() -> Result<(), Box<dyn std::error::Error>> {
        // Structure built below:
        //  * outer inputs a, b feed H gates `h_a` and `h_b`;
        //  * nested DFG `inner` consumes `h_b`'s output, with order edges h_a -> inner and
        //    inner -> h_a2 (a second H gate applied to `h_a`'s output);
        //  * inside `inner`: a loaded constant `f` with an order edge from inner's Input,
        //    `r` = rz_f64(i, f), a measure op on `r` giving (m, b), and a Conditional `if_n`
        //    on `b` with an order edge to inner's Output; `m` is inner's output;
        //  * `cx` consumes `h_a2`'s output and `inner`'s output.
        let mut outer = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
        let [a, b] = outer.input_wires_arr();
        let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?;
        let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?;
        let mut inner = outer.dfg_builder(endo_sig(qb_t()), h_b.outputs())?;
        let [i] = inner.input_wires_arr();
        let f = inner.add_load_value(float_types::ConstF64::new(1.0));
        inner.add_other_wire(inner.input().node(), f.node());
        let r = inner.add_dataflow_op(test_quantum_extension::rz_f64(), [i, f])?;
        let [m, b] = inner
            .add_dataflow_op(test_quantum_extension::measure(), r.outputs())?
            .outputs_arr();
        let mut if_n = inner.conditional_builder_exts(
            ([type_row![], type_row![]], b),
            [],
            type_row![],
            ExtensionSet::new(),
        )?;
        if_n.case_builder(0)?.finish_with_outputs([])?;
        if_n.case_builder(1)?.finish_with_outputs([])?;
        let if_n = if_n.finish_sub_container()?;
        inner.add_other_wire(if_n.node(), inner.output().node());
        let inner = inner.finish_with_outputs([m])?;
        outer.add_other_wire(h_a.node(), inner.node());
        let h_a2 = outer.add_dataflow_op(test_quantum_extension::h_gate(), h_a.outputs())?;
        outer.add_other_wire(inner.node(), h_a2.node());
        let cx = outer.add_dataflow_op(
            test_quantum_extension::cx_gate(),
            h_a2.outputs().chain(inner.outputs()),
        )?;
        let mut outer = outer.finish_hugr_with_outputs(cx.outputs())?;

        outer.apply_rewrite(InlineDFG(*inner.handle()))?;
        outer.validate()?;
        // Set of nodes linked to `n`'s order port in direction `d`.
        let order_neighbours = |n, d| {
            let p = outer.get_optype(n).other_port(d).unwrap();
            outer
                .linked_ports(n, p)
                .map(|(n, _)| n)
                .collect::<HashSet<_>>()
        };
        // The order predecessor of the DFG now precedes the former successors of its Input.
        assert_eq!(
            order_neighbours(h_a.node(), Direction::Outgoing),
            HashSet::from([r.node(), f.node()])
        );
        // A former order successor of Input now follows the DFG's order predecessor and
        // the source of the DFG's dataflow input.
        assert_eq!(
            order_neighbours(f.node(), Direction::Incoming),
            HashSet::from([h_a.node(), h_b.node()])
        );
        // The DFG's order successor now follows the former predecessors of its Output.
        assert_eq!(
            order_neighbours(h_a2.node(), Direction::Incoming),
            HashSet::from([m.node(), if_n.node()])
        );
        // A former order predecessor of Output now precedes the DFG's order successor and
        // the consumer of the DFG's dataflow output.
        assert_eq!(
            order_neighbours(if_n.node(), Direction::Outgoing),
            HashSet::from([h_a2.node(), cx.node()])
        );
        Ok(())
    }
}