hugr_core/hugr/views/root_checked/
dfg.rs

1//! RootChecked methods specific to dataflow graphs.
2
3use std::collections::BTreeMap;
4
5use itertools::Itertools;
6use thiserror::Error;
7
8use crate::{
9    IncomingPort, OutgoingPort, Port, PortIndex,
10    hugr::HugrMut,
11    ops::{
12        OpParent, OpTrait, OpType,
13        handle::{DataflowParentID, DfgID},
14    },
15    types::{NoRV, Signature, Type, TypeBase},
16};
17
18use super::RootChecked;
19
20macro_rules! impl_dataflow_parent_methods {
21    ($handle_type:ident) => {
22        impl<H: HugrMut> RootChecked<H, $handle_type<H::Node>> {
23            /// Get the input and output nodes of the DFG at the entrypoint node.
24            pub fn get_io(&self) -> [H::Node; 2] {
25                self.hugr()
26                    .get_io(self.hugr().entrypoint())
27                    .expect("valid DFG graph")
28            }
29
30            /// Rewire the inputs and outputs of the nested DFG to modify its signature.
31            ///
32            /// Reorder the outgoing resp. incoming wires at the input resp. output
33            /// node of the DFG to modify the signature of the DFG HUGR. This will
34            /// recursively update the signatures of all ancestors of the entrypoint.
35            ///
36            /// ### Arguments
37            ///
38            /// * `new_inputs`: The new input signature. After the map, the i-th input
39            ///   wire will be connected to the ports connected to the
40            ///   `new_inputs[i]`-th input of the old DFG.
41            /// * `new_outputs`: The new output signature. After the map, the i-th
42            ///   output wire will be connected to the ports connected to the
43            ///   `new_outputs[i]`-th output of the old DFG.
44            ///
45            /// Returns an `InvalidSignature` error if the new_inputs and new_outputs
46            /// map are not valid signatures.
47            ///
48            /// ### Panics
49            ///
50            /// Panics if the DFG is not trivially nested, i.e. if there is an ancestor
51            /// DFG of the entrypoint that has more than one inner DFG.
52            pub fn map_function_type(
53                &mut self,
54                new_inputs: &[usize],
55                new_outputs: &[usize],
56            ) -> Result<(), InvalidSignature> {
57                let [inp, out] = self.get_io();
58                let Self(hugr, _) = self;
59
60                // Record the old connections from and to the input and output nodes
61                let old_inputs_incoming = hugr
62                    .node_outputs(inp)
63                    .map(|p| hugr.linked_inputs(inp, p).collect_vec())
64                    .collect_vec();
65                let old_outputs_outgoing = hugr
66                    .node_inputs(out)
67                    .map(|p| hugr.linked_outputs(out, p).collect_vec())
68                    .collect_vec();
69
70                // The old signature types
71                let old_inp_sig = hugr
72                    .get_optype(inp)
73                    .dataflow_signature()
74                    .expect("input has signature");
75                let old_inp_sig = old_inp_sig.output_types();
76                let old_out_sig = hugr
77                    .get_optype(out)
78                    .dataflow_signature()
79                    .expect("output has signature");
80                let old_out_sig = old_out_sig.input_types();
81
82                // Check if the signature map is valid
83                check_valid_inputs(&old_inputs_incoming, old_inp_sig, new_inputs)?;
84                check_valid_outputs(old_out_sig, new_outputs)?;
85
86                // The new signature types
87                let new_inp_sig = new_inputs
88                    .iter()
89                    .map(|&i| old_inp_sig[i].clone())
90                    .collect_vec();
91                let new_out_sig = new_outputs
92                    .iter()
93                    .map(|&i| old_out_sig[i].clone())
94                    .collect_vec();
95                let new_sig = Signature::new(new_inp_sig, new_out_sig);
96
97                // Remove all edges of the input and output nodes
98                disconnect_all(hugr, inp);
99                disconnect_all(hugr, out);
100
101                // Update the signatures of the IO and their ancestors
102                let mut is_ancestor = false;
103                let mut node = hugr.entrypoint();
104                while matches!(hugr.get_optype(node), OpType::FuncDefn(_) | OpType::DFG(_)) {
105                    let [inner_inp, inner_out] = hugr.get_io(node).expect("valid DFG graph");
106                    for node in [node, inner_inp, inner_out] {
107                        update_signature(hugr, node, &new_sig);
108                    }
109                    if is_ancestor {
110                        update_inner_dfg_links(hugr, node);
111                    }
112                    if let Some(parent) = hugr.get_parent(node) {
113                        node = parent;
114                        is_ancestor = true;
115                    } else {
116                        break;
117                    }
118                }
119
120                // Insert the new edges at the input
121                let mut old_output_to_new_input = BTreeMap::<IncomingPort, OutgoingPort>::new();
122                for (inp_pos, &old_pos) in new_inputs.iter().enumerate() {
123                    for &(node, port) in &old_inputs_incoming[old_pos] {
124                        if node != out {
125                            hugr.connect(inp, inp_pos, node, port);
126                        } else {
127                            old_output_to_new_input.insert(port, inp_pos.into());
128                        }
129                    }
130                }
131
132                // Insert the new edges at the output
133                for (out_pos, &old_pos) in new_outputs.iter().enumerate() {
134                    for &(node, port) in &old_outputs_outgoing[old_pos] {
135                        if node != inp {
136                            hugr.connect(node, port, out, out_pos);
137                        } else {
138                            let &inp_pos = old_output_to_new_input.get(&old_pos.into()).unwrap();
139                            hugr.connect(inp, inp_pos, out, out_pos);
140                        }
141                    }
142                }
143
144                Ok(())
145            }
146
147            /// Add copyable inputs to the DFG to modify its signature.
148            ///
149            /// Append new inputs to the DFG. These will not be connected to any op and
150            /// must be copyable. This will recursively update the signatures of all
151            /// ancestors of the entrypoint.
152            ///
153            /// ### Arguments
154            ///
155            /// * `new_inputs`: The new input types to append to the signature.
156            ///
157            /// Returns an `InvalidSignature` error if the new_input types are not
158            /// copyable.
159            ///
160            /// ### Panics
161            ///
162            /// Panics if the DFG is not trivially nested, i.e. if there is an ancestor
163            /// DFG of the entrypoint that has more than one inner DFG.
164            pub fn extend_inputs<'a>(
165                &mut self,
166                new_inputs: impl IntoIterator<Item = &'a Type>,
167            ) -> Result<(), InvalidSignature> {
168                let Self(hugr, _) = self;
169                let curr_sig = hugr
170                    .get_optype(hugr.entrypoint())
171                    .inner_function_type()
172                    .expect("valid DFG graph")
173                    .into_owned();
174
175                let n_inputs = curr_sig.input_count();
176
177                let new_inputs: Vec<_> = new_inputs
178                    .into_iter()
179                    .enumerate()
180                    .map(|(i, t)| {
181                        if t.copyable() {
182                            Ok(t)
183                        } else {
184                            let p = IncomingPort::from(n_inputs + i);
185                            Err(InvalidSignature::ExpectedCopyable(p.into()))
186                        }
187                    })
188                    .try_collect()?;
189
190                let new_sig = Signature::new(curr_sig.input.extend(new_inputs), curr_sig.output);
191
192                // Update the signatures of the IO and their ancestors
193                let mut node = hugr.entrypoint();
194                let mut is_ancestor = false;
195                while matches!(hugr.get_optype(node), OpType::FuncDefn(_) | OpType::DFG(_)) {
196                    let [inner_inp, inner_out] = hugr.get_io(node).expect("valid DFG graph");
197                    for node in [node, inner_inp, inner_out] {
198                        update_signature(hugr, node, &new_sig);
199                    }
200                    if is_ancestor {
201                        update_inner_dfg_links(hugr, node);
202                    }
203                    if let Some(parent) = hugr.get_parent(node) {
204                        node = parent;
205                        is_ancestor = true;
206                    } else {
207                        break;
208                    }
209                }
210
211                Ok(())
212            }
213        }
214    };
215}
216
217impl_dataflow_parent_methods!(DataflowParentID);
218impl_dataflow_parent_methods!(DfgID);
219
220/// Panics if the DFG within `node` is not a single inner DFG.
221fn update_inner_dfg_links<H: HugrMut>(hugr: &mut H, node: H::Node) {
222    // connect all edges of the inner DFG to the input and output nodes
223    let inner_dfg = hugr
224        .children(node)
225        .skip(2)
226        .exactly_one()
227        .ok()
228        .expect("no non-trivial inner DFG");
229
230    let [inp, out] = hugr.get_io(node).expect("valid DFG graph");
231    disconnect_all(hugr, inner_dfg);
232    for (out_port, _) in hugr.out_value_types(inp).collect_vec() {
233        hugr.connect(inp, out_port, inner_dfg, out_port.index());
234    }
235    for (in_port, _) in hugr.in_value_types(out).collect_vec() {
236        hugr.connect(inner_dfg, in_port.index(), out, in_port);
237    }
238}
239
240fn disconnect_all<H: HugrMut>(hugr: &mut H, node: H::Node) {
241    let all_ports = hugr.all_node_ports(node).collect_vec();
242    for port in all_ports {
243        hugr.disconnect(node, port);
244    }
245}
246
247fn update_signature<H: HugrMut>(hugr: &mut H, node: H::Node, new_sig: &Signature) {
248    match hugr.optype_mut(node) {
249        OpType::DFG(dfg) => {
250            dfg.signature = new_sig.clone();
251        }
252        OpType::FuncDefn(fn_def_op) => *fn_def_op.signature_mut() = new_sig.clone().into(),
253        OpType::Input(inp) => {
254            inp.types = new_sig.input().clone();
255        }
256        OpType::Output(out) => out.types = new_sig.output().clone(),
257        _ => panic!("only update signature of DFG, FuncDefn, Input, or Output"),
258    };
259    let new_op = hugr.get_optype(node);
260    hugr.set_num_ports(node, new_op.input_count(), new_op.output_count());
261}
262
263fn check_valid_inputs<V>(
264    old_ports: &[Vec<V>],
265    old_sig: &[TypeBase<NoRV>],
266    map_sig: &[usize],
267) -> Result<(), InvalidSignature> {
268    if let Some(old_pos) = map_sig
269        .iter()
270        .find_map(|&old_pos| (old_pos >= old_sig.len()).then_some(old_pos))
271    {
272        return Err(InvalidSignature::UnknownIO(old_pos, "input"));
273    }
274
275    let counts = map_sig.iter().copied().counts();
276    if let Some(old_pos) = old_ports.iter().enumerate().find_map(|(old_pos, vec)| {
277        ((!vec.is_empty() || old_sig.get(old_pos).is_some_and(|t| !t.copyable()))
278            && !counts.contains_key(&old_pos))
279        .then_some(old_pos)
280    }) {
281        return Err(InvalidSignature::MissingIO(old_pos, "input"));
282    }
283
284    if let Some(old_pos) = counts
285        .iter()
286        .find_map(|(&old_pos, &count)| (count > 1).then_some(old_pos))
287    {
288        return Err(InvalidSignature::DuplicateInput(old_pos));
289    }
290
291    Ok(())
292}
293
294fn check_valid_outputs(
295    old_sig: &[TypeBase<NoRV>],
296    map_sig: &[usize],
297) -> Result<(), InvalidSignature> {
298    if let Some(old_pos) = map_sig
299        .iter()
300        .find_map(|&old_pos| (old_pos >= old_sig.len()).then_some(old_pos))
301    {
302        return Err(InvalidSignature::UnknownIO(old_pos, "output"));
303    }
304
305    let counts = map_sig.iter().copied().counts();
306    let linear_types = old_sig
307        .iter()
308        .enumerate()
309        .filter_map(|(pos, t)| (!t.copyable()).then_some(pos));
310    for old_pos in linear_types {
311        let Some(&cnt) = counts.get(&old_pos) else {
312            return Err(InvalidSignature::MissingIO(old_pos, "output"));
313        };
314        if cnt != 1 {
315            return Err(InvalidSignature::LinearityViolation(old_pos, "output"));
316        }
317    }
318
319    Ok(())
320}
321
322/// Errors that can occur when mapping the I/O of a DFG.
323#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Error)]
324#[non_exhaustive]
325pub enum InvalidSignature {
326    /// Error when a required input/output is missing from the new signature
327    #[error("{1} at position {0} is required but missing in new signature")]
328    MissingIO(usize, &'static str),
329    /// Error when trying to access an input/output that doesn't exist in the
330    /// signature
331    #[error("No {1} at position {0} in signature")]
332    UnknownIO(usize, &'static str),
333    /// Error when a linear input/output is used multiple times or not at all
334    #[error("Linearity of {1} at position {0} is not preserved in new signature")]
335    LinearityViolation(usize, &'static str),
336    /// Error when an input is used multiple times in the new signature
337    #[error("Input at position {0} is duplicated in new signature")]
338    DuplicateInput(usize),
339    /// Expected a copyable type at the given port
340    #[error("Type at port {0:?} must be copyable")]
341    ExpectedCopyable(Port),
342}
343
344#[cfg(test)]
345mod test {
346    use insta::assert_snapshot;
347
348    use super::*;
349    use crate::builder::{
350        DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, endo_sig,
351    };
352    use crate::extension::prelude::{bool_t, qb_t};
353    use crate::hugr::views::root_checked::RootChecked;
354    use crate::ops::handle::NodeHandle;
355    use crate::ops::{NamedOp, OpParent};
356    use crate::std_extensions::arithmetic::float_types::float64_type;
357    use crate::types::Signature;
358    use crate::utils::test_quantum_extension::cx_gate;
359    use crate::{Hugr, HugrView};
360
361    fn new_empty_dfg(sig: Signature) -> Hugr {
362        let dfg_builder = DFGBuilder::new(sig).unwrap();
363        let wires = dfg_builder.input_wires();
364        dfg_builder.finish_hugr_with_outputs(wires).unwrap()
365    }
366
367    #[test]
368    fn test_map_io() {
369        // Create a DFG with 2 inputs and 2 outputs
370        let sig = Signature::new_endo(vec![qb_t(), qb_t()]);
371        let mut hugr = new_empty_dfg(sig);
372
373        // Wrap in RootChecked
374        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();
375
376        // Test mapping inputs: [0,1] -> [1,0]
377        let input_map = vec![1, 0];
378        let output_map = vec![0, 1];
379
380        // Map the I/O
381        dfg_view.map_function_type(&input_map, &output_map).unwrap();
382
383        // Verify the new signature
384        let dfg_hugr = dfg_view.hugr();
385        let new_sig = dfg_hugr
386            .get_optype(dfg_hugr.entrypoint())
387            .dataflow_signature()
388            .unwrap();
389        assert_eq!(new_sig.input_count(), 2);
390        assert_eq!(new_sig.output_count(), 2);
391
392        // Test invalid mapping - missing input
393        let invalid_input_map = vec![0, 0];
394        let err = dfg_view.map_function_type(&invalid_input_map, &output_map);
395        assert!(matches!(err, Err(InvalidSignature::MissingIO(1, "input"))));
396
397        // Test invalid mapping - duplicate input
398        let invalid_input_map = vec![0, 0, 1];
399        assert!(matches!(
400            dfg_view.map_function_type(&invalid_input_map, &output_map),
401            Err(InvalidSignature::DuplicateInput(0))
402        ));
403
404        // Test invalid mapping - unknown output
405        let invalid_output_map = vec![0, 2];
406        assert!(matches!(
407            dfg_view.map_function_type(&input_map, &invalid_output_map),
408            Err(InvalidSignature::UnknownIO(2, "output"))
409        ));
410    }
411
412    #[test]
413    fn test_map_io_dfg_id() {
414        // Create a DFG with 2 inputs and 2 outputs
415        let sig = Signature::new_endo(vec![qb_t(), qb_t()]);
416        let mut hugr = new_empty_dfg(sig);
417
418        // Wrap in RootChecked
419        let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();
420
421        // Test mapping inputs: [0,1] -> [1,0]
422        let input_map = vec![1, 0];
423        let output_map = vec![0, 1];
424
425        // Map the I/O
426        dfg_view.map_function_type(&input_map, &output_map).unwrap();
427
428        // Verify the new signature
429        let dfg_hugr = dfg_view.hugr();
430        let new_sig = dfg_hugr
431            .get_optype(dfg_hugr.entrypoint())
432            .dataflow_signature()
433            .unwrap();
434        assert_eq!(new_sig.input_count(), 2);
435        assert_eq!(new_sig.output_count(), 2);
436
437        // Test invalid mapping - missing input
438        let invalid_input_map = vec![0, 0];
439        let err = dfg_view.map_function_type(&invalid_input_map, &output_map);
440        assert!(matches!(err, Err(InvalidSignature::MissingIO(1, "input"))));
441
442        // Test invalid mapping - duplicate input
443        let invalid_input_map = vec![0, 0, 1];
444        assert!(matches!(
445            dfg_view.map_function_type(&invalid_input_map, &output_map),
446            Err(InvalidSignature::DuplicateInput(0))
447        ));
448
449        // Test invalid mapping - unknown output
450        let invalid_output_map = vec![0, 2];
451        assert!(matches!(
452            dfg_view.map_function_type(&input_map, &invalid_output_map),
453            Err(InvalidSignature::UnknownIO(2, "output"))
454        ));
455    }
456
457    #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
458    #[test]
459    fn test_map_io_duplicate_output() {
460        // Create a DFG with 1 input and 1 output
461        let sig = Signature::new_endo(vec![bool_t()]);
462        let mut hugr = new_empty_dfg(sig);
463
464        // Wrap in RootChecked
465        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();
466
467        // Test mapping outputs: [0] -> [0,0] (duplicating the output)
468        let input_map = vec![0];
469        let output_map = vec![0, 0];
470
471        // Map the I/O
472        dfg_view.map_function_type(&input_map, &output_map).unwrap();
473
474        let dfg_hugr = dfg_view.hugr();
475        if let Err(err) = dfg_hugr.validate() {
476            panic!("Invalid Hugr: {err}");
477        }
478
479        // Verify the new signature
480        let new_sig = dfg_hugr
481            .get_optype(dfg_hugr.entrypoint())
482            .dataflow_signature()
483            .unwrap();
484        assert_eq!(new_sig.input_count(), 1);
485        assert_eq!(new_sig.output_count(), 2);
486        assert_snapshot!(dfg_hugr.mermaid_string());
487    }
488
489    #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
490    #[test]
491    fn test_map_io_cx_gate() {
492        // Create a DFG with 2 inputs and 2 outputs for a CX gate
493        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()])).unwrap();
494        let [wire0, wire1] = dfg_builder.input_wires_arr();
495        let cx_handle = dfg_builder
496            .add_dataflow_op(cx_gate(), vec![wire0, wire1])
497            .unwrap();
498        let cx_node = cx_handle.node();
499        let [wire0, wire1] = cx_handle.outputs_arr();
500        let mut hugr = dfg_builder
501            .finish_hugr_with_outputs(vec![wire0, wire1])
502            .unwrap();
503
504        // Wrap in RootChecked
505        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();
506
507        // Test mapping inputs: [0,1] -> [1,0] (swapping inputs)
508        let input_map = vec![1, 0];
509        let output_map = vec![0, 1];
510
511        // Map the I/O
512        dfg_view.map_function_type(&input_map, &output_map).unwrap();
513
514        let dfg_hugr = dfg_view.hugr();
515        if let Err(err) = dfg_hugr.validate() {
516            panic!("Invalid Hugr: {err}");
517        }
518
519        // Verify the new signature
520        let new_sig = dfg_hugr
521            .get_optype(dfg_hugr.entrypoint())
522            .dataflow_signature()
523            .unwrap();
524        assert_eq!(new_sig.input_count(), 2);
525        assert_eq!(new_sig.output_count(), 2);
526
527        // Verify the connections are preserved but swapped
528        let [new_inp, new_out] = dfg_view.get_io();
529        assert_eq!(
530            dfg_hugr.linked_inputs(new_inp, 0).collect_vec(),
531            vec![(cx_node, 1.into())]
532        );
533        assert_eq!(
534            dfg_hugr.linked_inputs(new_inp, 1).collect_vec(),
535            vec![(cx_node, 0.into())]
536        );
537        assert_eq!(
538            dfg_hugr.linked_outputs(new_out, 0).collect_vec(),
539            vec![(cx_node, 0.into())]
540        );
541        assert_eq!(
542            dfg_hugr.linked_outputs(new_out, 1).collect_vec(),
543            vec![(cx_node, 1.into())]
544        );
545
546        assert_snapshot!(dfg_hugr.mermaid_string());
547    }
548
549    #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
550    #[test]
551    fn test_map_io_cycle_3qb() {
552        // Create a DFG with 3 inputs and 3 outputs: CX[0, 1] and empty 2nd qubit
553        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(); 3])).unwrap();
554        let [wire0, wire1, wire2] = dfg_builder.input_wires_arr();
555        let cx_handle = dfg_builder
556            .add_dataflow_op(cx_gate(), vec![wire0, wire1])
557            .unwrap();
558        let cx_node = cx_handle.node();
559        let [wire0, wire1] = cx_handle.outputs_arr();
560        let mut hugr = dfg_builder
561            .finish_hugr_with_outputs(vec![wire0, wire1, wire2])
562            .unwrap();
563
564        // Wrap in RootChecked
565        let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();
566
567        // Test cycling outputs: [0,1,2] -> [1,2,0]
568        let input_map = vec![1, 2, 0];
569        let output_map = vec![0, 1, 2];
570
571        // Map the I/O
572        dfg_view.map_function_type(&input_map, &output_map).unwrap();
573        let [dfg_inp, dfg_out] = dfg_view.get_io();
574
575        let dfg_hugr = dfg_view.hugr();
576        if let Err(err) = dfg_hugr.validate() {
577            panic!("Invalid Hugr: {err}");
578        }
579
580        // Verify the new signature
581        let new_sig = dfg_hugr
582            .get_optype(dfg_hugr.entrypoint())
583            .dataflow_signature()
584            .unwrap();
585        assert_eq!(new_sig.input_count(), 3);
586        assert_eq!(new_sig.output_count(), 3);
587
588        // Verify inp(0) -> cx(1), inp(1) -> out(2), inp(2) -> cx(0)
589        for (i, exp_gate) in [cx_node, dfg_out, cx_node].into_iter().enumerate() {
590            assert_eq!(
591                dfg_hugr.linked_inputs(dfg_inp, i).collect_vec(),
592                vec![(exp_gate, ((i + 1) % 3).into())]
593            );
594        }
595        // Verify cx(0) -> out(0), cx(1) -> out(1), inp(1) -> out(2)
596        for (i, exp_gate) in [cx_node, cx_node, dfg_inp].into_iter().enumerate() {
597            let exp_outport = std::cmp::min(i, 1);
598            assert_eq!(
599                dfg_hugr.linked_outputs(dfg_out, i).collect_vec(),
600                vec![(exp_gate, exp_outport.into())],
601                "expected {}({exp_outport}) -> out({i})",
602                dfg_hugr.get_optype(exp_gate).name()
603            );
604        }
605
606        assert_snapshot!(dfg_hugr.mermaid_string());
607    }
608
609    #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
610    #[test]
611    fn test_map_io_recursive() {
612        use crate::builder::ModuleBuilder;
613        use crate::extension::prelude::{bool_t, qb_t};
614        use crate::types::Signature;
615
616        // Create a module with two functions: "foo" and "bar"
617        let mut module_builder = ModuleBuilder::new();
618
619        // Define function "foo" with nested DFGs
620        let dfg_roots = {
621            let mut foo_builder = module_builder
622                .define_function("foo", Signature::new_endo(vec![qb_t(), bool_t()]))
623                .unwrap();
624
625            let [qb, b] = foo_builder.input_wires_arr();
626
627            // Create first nested DFG
628            let mut dfg1_builder = foo_builder
629                .dfg_builder_endo([(qb_t(), qb), (bool_t(), b)])
630                .unwrap();
631            let [dfg1_qb, dfg1_b] = dfg1_builder.input_wires_arr();
632
633            // Create second nested DFG inside the first one
634            let dfg2_builder = dfg1_builder
635                .dfg_builder_endo([(qb_t(), dfg1_qb), (bool_t(), dfg1_b)])
636                .unwrap();
637            let [dfg2_qb, dfg2_b] = dfg2_builder.input_wires_arr();
638
639            // Connect inputs to outputs in innermost DFG
640            let dfg2_id = dfg2_builder.finish_with_outputs([dfg2_qb, dfg2_b]).unwrap();
641
642            // Connect through first DFG
643            let dfg1_id = dfg1_builder.finish_with_outputs(dfg2_id.outputs()).unwrap();
644
645            // Finish function
646            let foo_id = foo_builder.finish_with_outputs(dfg1_id.outputs()).unwrap();
647
648            [foo_id.node(), dfg1_id.node(), dfg2_id.node()]
649        };
650
651        let mut hugr = module_builder.finish_hugr().unwrap();
652        hugr.set_entrypoint(dfg_roots[2]);
653
654        // Test successful signature update in "foo"
655        let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();
656
657        // Swap the outputs: [0,1] -> [1,0]
658        let input_map = vec![0, 1];
659        let output_map = vec![1, 0];
660
661        dfg_view.map_function_type(&input_map, &output_map).unwrap();
662
663        // Verify the new signature at each level
664        for node in dfg_roots {
665            let sig = hugr.get_optype(node).inner_function_type().unwrap();
666            assert_eq!(sig.input_types(), vec![qb_t(), bool_t()]);
667            assert_eq!(sig.output_types(), vec![bool_t(), qb_t()]);
668        }
669
670        assert_snapshot!(hugr.mermaid_string());
671    }
672
673    #[test]
674    fn test_extend_inputs() {
675        // Create an empty DFG
676        let dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t()])).unwrap();
677        let [wire] = dfg_builder.input_wires_arr();
678        let mut hugr = dfg_builder.finish_hugr_with_outputs(vec![wire]).unwrap();
679
680        // Wrap in RootChecked
681        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();
682
683        // Extend the inputs
684        let new_inputs = vec![bool_t(), float64_type()];
685        dfg_view.extend_inputs(&new_inputs).unwrap();
686        assert_eq!(
687            dfg_view.hugr().inner_function_type().unwrap(),
688            Signature::new(vec![qb_t(), bool_t(), float64_type()], vec![qb_t()])
689        );
690
691        let new_inputs_fail = vec![qb_t()];
692        let err = dfg_view.extend_inputs(&new_inputs_fail);
693        assert_eq!(
694            err,
695            Err(InvalidSignature::ExpectedCopyable(
696                IncomingPort::from(3).into()
697            ))
698        );
699    }
700}