hugr_core/hugr/views/root_checked/
dfg.rs

1//! RootChecked methods specific to dataflow graphs.
2
3use std::collections::BTreeMap;
4
5use itertools::Itertools;
6use thiserror::Error;
7
8use crate::{
9    IncomingPort, OutgoingPort, PortIndex,
10    hugr::HugrMut,
11    ops::{
12        OpTrait, OpType,
13        handle::{DataflowParentID, DfgID},
14    },
15    types::{NoRV, Signature, TypeBase},
16};
17
18use super::RootChecked;
19
macro_rules! impl_dataflow_parent_methods {
    ($handle_type:ident) => {
        impl<H: HugrMut> RootChecked<H, $handle_type<H::Node>> {
            /// Get the input and output nodes of the DFG at the entrypoint node.
            pub fn get_io(&self) -> [H::Node; 2] {
                self.hugr()
                    .get_io(self.hugr().entrypoint())
                    .expect("valid DFG graph")
            }

            /// Rewire the inputs and outputs of the nested DFG to modify its signature.
            ///
            /// Reorder the outgoing resp. incoming wires at the input resp. output
            /// node of the DFG to modify the signature of the DFG HUGR. The new
            /// signature is written to the entrypoint and then propagated through
            /// its consecutive `DFG` ancestors, up to and including the nearest
            /// enclosing `FuncDefn`. A function definition is a signature boundary,
            /// so propagation stops there. Order edges attached to the input and
            /// output nodes of the entrypoint are preserved.
            ///
            /// Only `DFG` and `FuncDefn` optypes have their signature updated. If
            /// the entrypoint is another kind of dataflow parent, the wires at its
            /// input and output nodes are permuted but no signature is modified.
            ///
            /// ### Arguments
            ///
            /// * `new_inputs`: The new input signature. After the map, the i-th input
            ///   wire will be connected to the ports connected to the
            ///   `new_inputs[i]`-th input of the old DFG.
            /// * `new_outputs`: The new output signature. After the map, the i-th
            ///   output wire will be connected to the ports connected to the
            ///   `new_outputs[i]`-th output of the old DFG.
            ///
            /// ### Errors
            ///
            /// Returns an `InvalidSignature` error if the `new_inputs` and
            /// `new_outputs` maps are not valid signatures. This covers three cases:
            /// an index is out of range; a connected or linear input is dropped, or
            /// an input is repeated; a linear output is dropped or repeated. The
            /// HUGR is not modified when an error is returned.
            ///
            /// ### Panics
            ///
            /// Panics if the DFG is not trivially nested. That is, it panics if an
            /// ancestor whose signature gets updated has any child besides its
            /// input node, its output node and a single inner node.
            pub fn map_function_type(
                &mut self,
                new_inputs: &[usize],
                new_outputs: &[usize],
            ) -> Result<(), InvalidSignature> {
                let [inp, out] = self.get_io();
                let Self(hugr, _) = self;

                // The old signature types
                let old_inp_sig = hugr
                    .get_optype(inp)
                    .dataflow_signature()
                    .expect("input has signature");
                let old_inp_sig = old_inp_sig.output_types();
                let old_out_sig = hugr
                    .get_optype(out)
                    .dataflow_signature()
                    .expect("output has signature");
                let old_out_sig = old_out_sig.input_types();

                // Record the old *value* connections from and to the input and
                // output nodes. The trailing order port is excluded here. Otherwise
                // it would be mistaken for an extra signature position. Order edges
                // are recorded separately below.
                let old_inputs_incoming = hugr
                    .node_outputs(inp)
                    .take(old_inp_sig.len())
                    .map(|p| hugr.linked_inputs(inp, p).collect_vec())
                    .collect_vec();
                let old_outputs_outgoing = hugr
                    .node_inputs(out)
                    .take(old_out_sig.len())
                    .map(|p| hugr.linked_outputs(out, p).collect_vec())
                    .collect_vec();

                // Record the order edges at the I/O nodes so they survive the
                // `disconnect_all` below.
                let inp_order_targets = hugr
                    .get_optype(inp)
                    .other_output_port()
                    .map(|p| hugr.linked_inputs(inp, p).collect_vec())
                    .unwrap_or_default();
                let out_order_sources = hugr
                    .get_optype(out)
                    .other_input_port()
                    .map(|p| hugr.linked_outputs(out, p).collect_vec())
                    .unwrap_or_default();

                // Check if the signature map is valid (before any mutation)
                check_valid_inputs(&old_inputs_incoming, old_inp_sig, new_inputs)?;
                check_valid_outputs(old_out_sig, new_outputs)?;

                // The new signature types
                let new_inp_sig = new_inputs
                    .iter()
                    .map(|&i| old_inp_sig[i].clone())
                    .collect_vec();
                let new_out_sig = new_outputs
                    .iter()
                    .map(|&i| old_out_sig[i].clone())
                    .collect_vec();
                let new_sig = Signature::new(new_inp_sig, new_out_sig);

                // Remove all edges of the input and output nodes
                disconnect_all(hugr, inp);
                disconnect_all(hugr, out);

                // Update the signatures of the entrypoint, its IO, and its ancestors
                let mut node = hugr.entrypoint();
                let mut is_ancestor = false;
                loop {
                    let is_function = match hugr.get_optype(node) {
                        OpType::DFG(_) => false,
                        OpType::FuncDefn(_) => true,
                        _ => break,
                    };
                    let [inner_inp, inner_out] = hugr.get_io(node).expect("valid DFG graph");
                    for n in [node, inner_inp, inner_out] {
                        update_signature(hugr, n, &new_sig);
                    }
                    if is_ancestor {
                        update_inner_dfg_links(hugr, node);
                    }
                    // A function definition's signature does not constrain its
                    // parent, so stop propagating here.
                    if is_function {
                        break;
                    }
                    let Some(parent) = hugr.get_parent(node) else {
                        break;
                    };
                    node = parent;
                    is_ancestor = true;
                }

                // Insert the new edges at the input
                let mut old_output_to_new_input = BTreeMap::<IncomingPort, OutgoingPort>::new();
                for (inp_pos, &old_pos) in new_inputs.iter().enumerate() {
                    for &(node, port) in &old_inputs_incoming[old_pos] {
                        if node != out {
                            hugr.connect(inp, inp_pos, node, port);
                        } else {
                            // Direct input -> output wire: wired once the new output
                            // position is known.
                            old_output_to_new_input.insert(port, inp_pos.into());
                        }
                    }
                }

                // Insert the new edges at the output
                for (out_pos, &old_pos) in new_outputs.iter().enumerate() {
                    for &(node, port) in &old_outputs_outgoing[old_pos] {
                        if node != inp {
                            hugr.connect(node, port, out, out_pos);
                        } else {
                            let &inp_pos = old_output_to_new_input
                                .get(&old_pos.into())
                                .expect("inputs wired to the output are kept by check_valid_inputs");
                            hugr.connect(inp, inp_pos, out, out_pos);
                        }
                    }
                }

                // Restore the order edges on the (possibly moved) order ports.
                let new_inp_order = hugr.get_optype(inp).other_output_port();
                let new_out_order = hugr.get_optype(out).other_input_port();
                if let Some(inp_order) = new_inp_order {
                    for &(node, port) in &inp_order_targets {
                        // A direct input -> output order edge must land on the
                        // output node's *new* order port.
                        let port = if node == out {
                            new_out_order.expect("output node has an order port")
                        } else {
                            port
                        };
                        hugr.connect(inp, inp_order, node, port);
                    }
                }
                if let Some(out_order) = new_out_order {
                    for &(node, port) in &out_order_sources {
                        // Input -> output order edges were restored above.
                        if node != inp {
                            hugr.connect(node, port, out, out_order);
                        }
                    }
                }

                Ok(())
            }
        }
    };
}
149
150impl_dataflow_parent_methods!(DataflowParentID);
151impl_dataflow_parent_methods!(DfgID);
152
153/// Panics if the DFG within `node` is not a single inner DFG.
154fn update_inner_dfg_links<H: HugrMut>(hugr: &mut H, node: H::Node) {
155    // connect all edges of the inner DFG to the input and output nodes
156    let inner_dfg = hugr
157        .children(node)
158        .skip(2)
159        .exactly_one()
160        .ok()
161        .expect("no non-trivial inner DFG");
162
163    let [inp, out] = hugr.get_io(node).expect("valid DFG graph");
164    disconnect_all(hugr, inner_dfg);
165    for (out_port, _) in hugr.out_value_types(inp).collect_vec() {
166        hugr.connect(inp, out_port, inner_dfg, out_port.index());
167    }
168    for (in_port, _) in hugr.in_value_types(out).collect_vec() {
169        hugr.connect(inner_dfg, in_port.index(), out, in_port);
170    }
171}
172
173fn disconnect_all<H: HugrMut>(hugr: &mut H, node: H::Node) {
174    let all_ports = hugr.all_node_ports(node).collect_vec();
175    for port in all_ports {
176        hugr.disconnect(node, port);
177    }
178}
179
180fn update_signature<H: HugrMut>(hugr: &mut H, node: H::Node, new_sig: &Signature) {
181    match hugr.optype_mut(node) {
182        OpType::DFG(dfg) => {
183            dfg.signature = new_sig.clone();
184        }
185        OpType::FuncDefn(fn_def_op) => *fn_def_op.signature_mut() = new_sig.clone().into(),
186        OpType::Input(inp) => {
187            inp.types = new_sig.input().clone();
188        }
189        OpType::Output(out) => out.types = new_sig.output().clone(),
190        _ => panic!("only update signature of DFG, FuncDefn, Input, or Output"),
191    };
192    let new_op = hugr.get_optype(node);
193    hugr.set_num_ports(node, new_op.input_count(), new_op.output_count());
194}
195
196fn check_valid_inputs<V>(
197    old_ports: &[Vec<V>],
198    old_sig: &[TypeBase<NoRV>],
199    map_sig: &[usize],
200) -> Result<(), InvalidSignature> {
201    if let Some(old_pos) = map_sig
202        .iter()
203        .find_map(|&old_pos| (old_pos >= old_sig.len()).then_some(old_pos))
204    {
205        return Err(InvalidSignature::UnknownIO(old_pos, "input"));
206    }
207
208    let counts = map_sig.iter().copied().counts();
209    if let Some(old_pos) = old_ports.iter().enumerate().find_map(|(old_pos, vec)| {
210        ((!vec.is_empty() || old_sig.get(old_pos).is_some_and(|t| !t.copyable()))
211            && !counts.contains_key(&old_pos))
212        .then_some(old_pos)
213    }) {
214        return Err(InvalidSignature::MissingIO(old_pos, "input"));
215    }
216
217    if let Some(old_pos) = counts
218        .iter()
219        .find_map(|(&old_pos, &count)| (count > 1).then_some(old_pos))
220    {
221        return Err(InvalidSignature::DuplicateInput(old_pos));
222    }
223
224    Ok(())
225}
226
227fn check_valid_outputs(
228    old_sig: &[TypeBase<NoRV>],
229    map_sig: &[usize],
230) -> Result<(), InvalidSignature> {
231    if let Some(old_pos) = map_sig
232        .iter()
233        .find_map(|&old_pos| (old_pos >= old_sig.len()).then_some(old_pos))
234    {
235        return Err(InvalidSignature::UnknownIO(old_pos, "output"));
236    }
237
238    let counts = map_sig.iter().copied().counts();
239    let linear_types = old_sig
240        .iter()
241        .enumerate()
242        .filter_map(|(pos, t)| (!t.copyable()).then_some(pos));
243    for old_pos in linear_types {
244        let Some(&cnt) = counts.get(&old_pos) else {
245            return Err(InvalidSignature::MissingIO(old_pos, "output"));
246        };
247        if cnt != 1 {
248            return Err(InvalidSignature::LinearityViolation(old_pos, "output"));
249        }
250    }
251
252    Ok(())
253}
254
255/// Errors that can occur when mapping the I/O of a DFG.
256#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Error)]
257#[non_exhaustive]
258pub enum InvalidSignature {
259    /// Error when a required input/output is missing from the new signature
260    #[error("{1} at position {0} is required but missing in new signature")]
261    MissingIO(usize, &'static str),
262    /// Error when trying to access an input/output that doesn't exist in the
263    /// signature
264    #[error("No {1} at position {0} in signature")]
265    UnknownIO(usize, &'static str),
266    /// Error when a linear input/output is used multiple times or not at all
267    #[error("Linearity of {1} at position {0} is not preserved in new signature")]
268    LinearityViolation(usize, &'static str),
269    /// Error when an input is used multiple times in the new signature
270    #[error("Input at position {0} is duplicated in new signature")]
271    DuplicateInput(usize),
272}
273
/// Tests for `map_function_type` on `DataflowParentID` and `DfgID`
/// root-checked views.
#[cfg(test)]
mod test {
    use insta::assert_snapshot;

    use super::*;
    use crate::builder::{
        DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, endo_sig,
    };
    use crate::extension::prelude::{bool_t, qb_t};
    use crate::hugr::views::root_checked::RootChecked;
    use crate::ops::handle::NodeHandle;
    use crate::ops::{NamedOp, OpParent};
    use crate::types::Signature;
    use crate::utils::test_quantum_extension::cx_gate;
    use crate::{Hugr, HugrView};

    /// Build a DFG HUGR with signature `sig` whose inputs are wired straight to
    /// its outputs (so `sig` must be an endomorphism).
    fn new_empty_dfg(sig: Signature) -> Hugr {
        let dfg_builder = DFGBuilder::new(sig).unwrap();
        let wires = dfg_builder.input_wires();
        dfg_builder.finish_hugr_with_outputs(wires).unwrap()
    }

    #[test]
    fn test_map_io() {
        // Create a DFG with 2 inputs and 2 outputs
        let sig = Signature::new_endo(vec![qb_t(), qb_t()]);
        let mut hugr = new_empty_dfg(sig);

        // Wrap in RootChecked
        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();

        // Test mapping inputs: [0,1] -> [1,0]
        let input_map = vec![1, 0];
        let output_map = vec![0, 1];

        // Map the I/O
        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        // Verify the new signature
        let dfg_hugr = dfg_view.hugr();
        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 2);
        assert_eq!(new_sig.output_count(), 2);

        // Test invalid mapping - missing input. Input 1 is dropped, which is
        // reported before the duplicated input 0.
        let invalid_input_map = vec![0, 0];
        let err = dfg_view.map_function_type(&invalid_input_map, &output_map);
        assert!(matches!(err, Err(InvalidSignature::MissingIO(1, "input"))));

        // Test invalid mapping - duplicate input
        let invalid_input_map = vec![0, 0, 1];
        assert!(matches!(
            dfg_view.map_function_type(&invalid_input_map, &output_map),
            Err(InvalidSignature::DuplicateInput(0))
        ));

        // Test invalid mapping - unknown output
        let invalid_output_map = vec![0, 2];
        assert!(matches!(
            dfg_view.map_function_type(&input_map, &invalid_output_map),
            Err(InvalidSignature::UnknownIO(2, "output"))
        ));
    }

    /// Same checks as `test_map_io`, through a `DfgID` view.
    #[test]
    fn test_map_io_dfg_id() {
        // Create a DFG with 2 inputs and 2 outputs
        let sig = Signature::new_endo(vec![qb_t(), qb_t()]);
        let mut hugr = new_empty_dfg(sig);

        // Wrap in RootChecked
        let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();

        // Test mapping inputs: [0,1] -> [1,0]
        let input_map = vec![1, 0];
        let output_map = vec![0, 1];

        // Map the I/O
        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        // Verify the new signature
        let dfg_hugr = dfg_view.hugr();
        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 2);
        assert_eq!(new_sig.output_count(), 2);

        // Test invalid mapping - missing input (reported before the duplicate)
        let invalid_input_map = vec![0, 0];
        let err = dfg_view.map_function_type(&invalid_input_map, &output_map);
        assert!(matches!(err, Err(InvalidSignature::MissingIO(1, "input"))));

        // Test invalid mapping - duplicate input
        let invalid_input_map = vec![0, 0, 1];
        assert!(matches!(
            dfg_view.map_function_type(&invalid_input_map, &output_map),
            Err(InvalidSignature::DuplicateInput(0))
        ));

        // Test invalid mapping - unknown output
        let invalid_output_map = vec![0, 2];
        assert!(matches!(
            dfg_view.map_function_type(&input_map, &invalid_output_map),
            Err(InvalidSignature::UnknownIO(2, "output"))
        ));
    }

    #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
    #[test]
    fn test_map_io_duplicate_output() {
        // Create a DFG with 1 input and 1 output
        let sig = Signature::new_endo(vec![bool_t()]);
        let mut hugr = new_empty_dfg(sig);

        // Wrap in RootChecked
        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();

        // Test mapping outputs: [0] -> [0,0] (duplicating the output; allowed
        // because bool_t is copyable)
        let input_map = vec![0];
        let output_map = vec![0, 0];

        // Map the I/O
        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        let dfg_hugr = dfg_view.hugr();
        if let Err(err) = dfg_hugr.validate() {
            panic!("Invalid Hugr: {err}");
        }

        // Verify the new signature
        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 1);
        assert_eq!(new_sig.output_count(), 2);
        assert_snapshot!(dfg_hugr.mermaid_string());
    }

    #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
    #[test]
    fn test_map_io_cx_gate() {
        // Create a DFG with 2 inputs and 2 outputs for a CX gate
        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()])).unwrap();
        let [wire0, wire1] = dfg_builder.input_wires_arr();
        let cx_handle = dfg_builder
            .add_dataflow_op(cx_gate(), vec![wire0, wire1])
            .unwrap();
        let cx_node = cx_handle.node();
        let [wire0, wire1] = cx_handle.outputs_arr();
        let mut hugr = dfg_builder
            .finish_hugr_with_outputs(vec![wire0, wire1])
            .unwrap();

        // Wrap in RootChecked
        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();

        // Test mapping inputs: [0,1] -> [1,0] (swapping inputs)
        let input_map = vec![1, 0];
        let output_map = vec![0, 1];

        // Map the I/O
        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        let dfg_hugr = dfg_view.hugr();
        if let Err(err) = dfg_hugr.validate() {
            panic!("Invalid Hugr: {err}");
        }

        // Verify the new signature
        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 2);
        assert_eq!(new_sig.output_count(), 2);

        // Verify the connections are preserved but swapped
        let [new_inp, new_out] = dfg_view.get_io();
        assert_eq!(
            dfg_hugr.linked_inputs(new_inp, 0).collect_vec(),
            vec![(cx_node, 1.into())]
        );
        assert_eq!(
            dfg_hugr.linked_inputs(new_inp, 1).collect_vec(),
            vec![(cx_node, 0.into())]
        );
        assert_eq!(
            dfg_hugr.linked_outputs(new_out, 0).collect_vec(),
            vec![(cx_node, 0.into())]
        );
        assert_eq!(
            dfg_hugr.linked_outputs(new_out, 1).collect_vec(),
            vec![(cx_node, 1.into())]
        );

        assert_snapshot!(dfg_hugr.mermaid_string());
    }

    #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
    #[test]
    fn test_map_io_cycle_3qb() {
        // Create a DFG with 3 inputs and 3 outputs: CX[0, 1], and the 3rd qubit
        // (index 2) wired straight from input to output
        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(); 3])).unwrap();
        let [wire0, wire1, wire2] = dfg_builder.input_wires_arr();
        let cx_handle = dfg_builder
            .add_dataflow_op(cx_gate(), vec![wire0, wire1])
            .unwrap();
        let cx_node = cx_handle.node();
        let [wire0, wire1] = cx_handle.outputs_arr();
        let mut hugr = dfg_builder
            .finish_hugr_with_outputs(vec![wire0, wire1, wire2])
            .unwrap();

        // Wrap in RootChecked
        let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();

        // Test cycling inputs: new input i is old input (i + 1) % 3
        let input_map = vec![1, 2, 0];
        let output_map = vec![0, 1, 2];

        // Map the I/O
        dfg_view.map_function_type(&input_map, &output_map).unwrap();
        let [dfg_inp, dfg_out] = dfg_view.get_io();

        let dfg_hugr = dfg_view.hugr();
        if let Err(err) = dfg_hugr.validate() {
            panic!("Invalid Hugr: {err}");
        }

        // Verify the new signature
        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 3);
        assert_eq!(new_sig.output_count(), 3);

        // Verify inp(0) -> cx(1), inp(1) -> out(2), inp(2) -> cx(0)
        for (i, exp_gate) in [cx_node, dfg_out, cx_node].into_iter().enumerate() {
            assert_eq!(
                dfg_hugr.linked_inputs(dfg_inp, i).collect_vec(),
                vec![(exp_gate, ((i + 1) % 3).into())]
            );
        }
        // Verify cx(0) -> out(0), cx(1) -> out(1), inp(1) -> out(2)
        for (i, exp_gate) in [cx_node, cx_node, dfg_inp].into_iter().enumerate() {
            let exp_outport = std::cmp::min(i, 1);
            assert_eq!(
                dfg_hugr.linked_outputs(dfg_out, i).collect_vec(),
                vec![(exp_gate, exp_outport.into())],
                "expected {}({exp_outport}) -> out({i})",
                dfg_hugr.get_optype(exp_gate).name()
            );
        }

        assert_snapshot!(dfg_hugr.mermaid_string());
    }

    #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
    #[test]
    fn test_map_io_recursive() {
        use crate::builder::ModuleBuilder;
        use crate::extension::prelude::{bool_t, qb_t};
        use crate::types::Signature;

        // Create a module with a function "foo" containing two nested DFGs
        let mut module_builder = ModuleBuilder::new();

        // Define function "foo" with nested DFGs; evaluates to the nodes
        // [foo, dfg1, dfg2]
        let dfg_roots = {
            let mut foo_builder = module_builder
                .define_function("foo", Signature::new_endo(vec![qb_t(), bool_t()]))
                .unwrap();

            let [qb, b] = foo_builder.input_wires_arr();

            // Create first nested DFG
            let mut dfg1_builder = foo_builder
                .dfg_builder_endo([(qb_t(), qb), (bool_t(), b)])
                .unwrap();
            let [dfg1_qb, dfg1_b] = dfg1_builder.input_wires_arr();

            // Create second nested DFG inside the first one
            let dfg2_builder = dfg1_builder
                .dfg_builder_endo([(qb_t(), dfg1_qb), (bool_t(), dfg1_b)])
                .unwrap();
            let [dfg2_qb, dfg2_b] = dfg2_builder.input_wires_arr();

            // Connect inputs to outputs in innermost DFG
            let dfg2_id = dfg2_builder.finish_with_outputs([dfg2_qb, dfg2_b]).unwrap();

            // Connect through first DFG
            let dfg1_id = dfg1_builder.finish_with_outputs(dfg2_id.outputs()).unwrap();

            // Finish function
            let foo_id = foo_builder.finish_with_outputs(dfg1_id.outputs()).unwrap();

            [foo_id.node(), dfg1_id.node(), dfg2_id.node()]
        };

        // Make the innermost DFG the entrypoint
        let mut hugr = module_builder.finish_hugr().unwrap();
        hugr.set_entrypoint(dfg_roots[2]);

        // Test successful signature update in "foo"
        let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();

        // Swap the outputs: [0,1] -> [1,0]
        let input_map = vec![0, 1];
        let output_map = vec![1, 0];

        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        // Verify the new signature at each level (propagated up to "foo")
        for node in dfg_roots {
            let sig = hugr.get_optype(node).inner_function_type().unwrap();
            assert_eq!(sig.input_types(), vec![qb_t(), bool_t()]);
            assert_eq!(sig.output_types(), vec![bool_t(), qb_t()]);
        }

        assert_snapshot!(hugr.mermaid_string());
    }
}
601}