1use std::collections::BTreeMap;
4
5use itertools::Itertools;
6use thiserror::Error;
7
8use crate::{
9 IncomingPort, OutgoingPort, Port, PortIndex,
10 hugr::HugrMut,
11 ops::{
12 OpParent, OpTrait, OpType,
13 handle::{DataflowParentID, DfgID},
14 },
15 types::{NoRV, Signature, Type, TypeBase},
16};
17
18use super::RootChecked;
19
// Generates the signature-editing methods on `RootChecked` views whose
// entrypoint is checked against `$handle_type` (a handle type generic over the
// node type). The generated code is identical for every handle type the macro
// is invoked with.
macro_rules! impl_dataflow_parent_methods {
    ($handle_type:ident) => {
        impl<H: HugrMut> RootChecked<H, $handle_type<H::Node>> {
            /// Get the input and output nodes of the entrypoint.
            ///
            /// # Panics
            ///
            /// Panics if the entrypoint has no Input/Output children.
            pub fn get_io(&self) -> [H::Node; 2] {
                self.hugr()
                    .get_io(self.hugr().entrypoint())
                    .expect("valid DFG graph")
            }

            /// Permute, drop or duplicate the inputs and outputs of the entrypoint.
            ///
            /// The new signature has `new_inputs.len()` inputs and
            /// `new_outputs.len()` outputs. `new_inputs[i]` (resp. `new_outputs[i]`)
            /// is the position in the old signature whose type and links move to
            /// new position `i`. The Input and Output nodes are rewired accordingly,
            /// including links that went directly from Input to Output.
            ///
            /// The signature of the entrypoint and of its Input/Output nodes is
            /// updated. Every enclosing `DFG`/`FuncDefn` ancestor is given the same
            /// signature and has its links to the nested child rewired.
            ///
            /// NOTE(review): only `DFG` and `FuncDefn` entrypoints have their
            /// signatures updated. For other dataflow parents the rewiring uses the
            /// new port positions on unchanged nodes. TODO confirm intended support.
            ///
            /// # Errors
            ///
            /// - [`InvalidSignature::UnknownIO`] if a position is out of range.
            /// - [`InvalidSignature::MissingIO`] if an old input is not mapped while
            ///   it has links or a non-copyable type, or an old non-copyable output
            ///   is not mapped.
            /// - [`InvalidSignature::DuplicateInput`] if an input position appears
            ///   more than once.
            /// - [`InvalidSignature::LinearityViolation`] if a non-copyable output
            ///   is mapped more than once.
            ///
            /// All checks run before any mutation, so the graph is left untouched
            /// when an error is returned.
            ///
            /// # Panics
            ///
            /// Panics if an enclosing `DFG`/`FuncDefn` ancestor has children other
            /// than its Input, Output and a single nested node.
            pub fn map_function_type(
                &mut self,
                new_inputs: &[usize],
                new_outputs: &[usize],
            ) -> Result<(), InvalidSignature> {
                let [inp, out] = self.get_io();
                let Self(hugr, _) = self;

                // Record the current links of every Input output port and every
                // Output input port, so they can be restored after rewiring.
                let old_inputs_incoming = hugr
                    .node_outputs(inp)
                    .map(|p| hugr.linked_inputs(inp, p).collect_vec())
                    .collect_vec();
                let old_outputs_outgoing = hugr
                    .node_inputs(out)
                    .map(|p| hugr.linked_outputs(out, p).collect_vec())
                    .collect_vec();

                // Old types: the Input node's outputs and the Output node's inputs.
                let old_inp_sig = hugr
                    .get_optype(inp)
                    .dataflow_signature()
                    .expect("input has signature");
                let old_inp_sig = old_inp_sig.output_types();
                let old_out_sig = hugr
                    .get_optype(out)
                    .dataflow_signature()
                    .expect("output has signature");
                let old_out_sig = old_out_sig.input_types();

                // Validate before mutating anything, so an error leaves the graph intact.
                check_valid_inputs(&old_inputs_incoming, old_inp_sig, new_inputs)?;
                check_valid_outputs(old_out_sig, new_outputs)?;

                let new_inp_sig = new_inputs
                    .iter()
                    .map(|&i| old_inp_sig[i].clone())
                    .collect_vec();
                let new_out_sig = new_outputs
                    .iter()
                    .map(|&i| old_out_sig[i].clone())
                    .collect_vec();
                let new_sig = Signature::new(new_inp_sig, new_out_sig);

                disconnect_all(hugr, inp);
                disconnect_all(hugr, out);

                Self::propagate_signature(hugr, &new_sig);

                // Reconnect the Input node. Links that went straight to the Output
                // node are remembered (old Output port -> new Input port) and are
                // restored when the Output node is reconnected below.
                let mut old_output_to_new_input = BTreeMap::<IncomingPort, OutgoingPort>::new();
                for (inp_pos, &old_pos) in new_inputs.iter().enumerate() {
                    for &(node, port) in &old_inputs_incoming[old_pos] {
                        if node != out {
                            hugr.connect(inp, inp_pos, node, port);
                        } else {
                            old_output_to_new_input.insert(port, inp_pos.into());
                        }
                    }
                }

                for (out_pos, &old_pos) in new_outputs.iter().enumerate() {
                    for &(node, port) in &old_outputs_outgoing[old_pos] {
                        if node != inp {
                            hugr.connect(node, port, out, out_pos);
                        } else {
                            // `check_valid_inputs` rejects dropping an input that has
                            // links, so this entry was recorded in the loop above.
                            let &inp_pos = old_output_to_new_input
                                .get(&old_pos.into())
                                .expect("input linked to output is kept");
                            hugr.connect(inp, inp_pos, out, out_pos);
                        }
                    }
                }

                Ok(())
            }

            /// Append `new_inputs` to the inputs of the entrypoint's signature.
            ///
            /// The entrypoint's Input node gains new output ports, which are left
            /// unconnected. Every enclosing `DFG`/`FuncDefn` ancestor gets the same
            /// extended signature and has its Input ports wired through to the
            /// nested child.
            ///
            /// # Errors
            ///
            /// Returns [`InvalidSignature::ExpectedCopyable`] if a new type is not
            /// copyable. The error carries the entrypoint incoming port that the type
            /// would occupy. The graph is left untouched in that case.
            ///
            /// # Panics
            ///
            /// Panics if an enclosing `DFG`/`FuncDefn` ancestor has children other
            /// than its Input, Output and a single nested node.
            pub fn extend_inputs<'a>(
                &mut self,
                new_inputs: impl IntoIterator<Item = &'a Type>,
            ) -> Result<(), InvalidSignature> {
                let Self(hugr, _) = self;
                let curr_sig = hugr
                    .get_optype(hugr.entrypoint())
                    .inner_function_type()
                    .expect("valid DFG graph")
                    .into_owned();

                let n_inputs = curr_sig.input_count();

                // The new inputs are not connected to anything inside the
                // entrypoint, hence the copyability requirement.
                let new_inputs: Vec<_> = new_inputs
                    .into_iter()
                    .enumerate()
                    .map(|(i, t)| {
                        if t.copyable() {
                            Ok(t)
                        } else {
                            let p = IncomingPort::from(n_inputs + i);
                            Err(InvalidSignature::ExpectedCopyable(p.into()))
                        }
                    })
                    .try_collect()?;

                let new_sig = Signature::new(curr_sig.input.extend(new_inputs), curr_sig.output);

                Self::propagate_signature(hugr, &new_sig);

                Ok(())
            }

            /// Set `new_sig` on the entrypoint and its Input/Output nodes. Then walk
            /// up through enclosing `DFG`/`FuncDefn` ancestors, giving each the same
            /// signature and rewiring its links to the nested child.
            ///
            /// The walk stops at the first node that is neither a `DFG` nor a
            /// `FuncDefn`, or at the root. If the entrypoint itself is neither,
            /// nothing is updated.
            fn propagate_signature(hugr: &mut H, new_sig: &Signature) {
                let mut node = hugr.entrypoint();
                let mut is_ancestor = false;
                while matches!(hugr.get_optype(node), OpType::FuncDefn(_) | OpType::DFG(_)) {
                    let [inner_inp, inner_out] = hugr.get_io(node).expect("valid DFG graph");
                    for n in [node, inner_inp, inner_out] {
                        update_signature(hugr, n, new_sig);
                    }
                    // The entrypoint's own links were already handled by the caller.
                    // Ancestors must re-link their single nested child, whose port
                    // count just changed.
                    if is_ancestor {
                        update_inner_dfg_links(hugr, node);
                    }
                    match hugr.get_parent(node) {
                        Some(parent) => {
                            node = parent;
                            is_ancestor = true;
                        }
                        None => break,
                    }
                }
            }
        }
    };
}
216
217impl_dataflow_parent_methods!(DataflowParentID);
218impl_dataflow_parent_methods!(DfgID);
219
220fn update_inner_dfg_links<H: HugrMut>(hugr: &mut H, node: H::Node) {
222 let inner_dfg = hugr
224 .children(node)
225 .skip(2)
226 .exactly_one()
227 .ok()
228 .expect("no non-trivial inner DFG");
229
230 let [inp, out] = hugr.get_io(node).expect("valid DFG graph");
231 disconnect_all(hugr, inner_dfg);
232 for (out_port, _) in hugr.out_value_types(inp).collect_vec() {
233 hugr.connect(inp, out_port, inner_dfg, out_port.index());
234 }
235 for (in_port, _) in hugr.in_value_types(out).collect_vec() {
236 hugr.connect(inner_dfg, in_port.index(), out, in_port);
237 }
238}
239
240fn disconnect_all<H: HugrMut>(hugr: &mut H, node: H::Node) {
241 let all_ports = hugr.all_node_ports(node).collect_vec();
242 for port in all_ports {
243 hugr.disconnect(node, port);
244 }
245}
246
247fn update_signature<H: HugrMut>(hugr: &mut H, node: H::Node, new_sig: &Signature) {
248 match hugr.optype_mut(node) {
249 OpType::DFG(dfg) => {
250 dfg.signature = new_sig.clone();
251 }
252 OpType::FuncDefn(fn_def_op) => *fn_def_op.signature_mut() = new_sig.clone().into(),
253 OpType::Input(inp) => {
254 inp.types = new_sig.input().clone();
255 }
256 OpType::Output(out) => out.types = new_sig.output().clone(),
257 _ => panic!("only update signature of DFG, FuncDefn, Input, or Output"),
258 };
259 let new_op = hugr.get_optype(node);
260 hugr.set_num_ports(node, new_op.input_count(), new_op.output_count());
261}
262
263fn check_valid_inputs<V>(
264 old_ports: &[Vec<V>],
265 old_sig: &[TypeBase<NoRV>],
266 map_sig: &[usize],
267) -> Result<(), InvalidSignature> {
268 if let Some(old_pos) = map_sig
269 .iter()
270 .find_map(|&old_pos| (old_pos >= old_sig.len()).then_some(old_pos))
271 {
272 return Err(InvalidSignature::UnknownIO(old_pos, "input"));
273 }
274
275 let counts = map_sig.iter().copied().counts();
276 if let Some(old_pos) = old_ports.iter().enumerate().find_map(|(old_pos, vec)| {
277 ((!vec.is_empty() || old_sig.get(old_pos).is_some_and(|t| !t.copyable()))
278 && !counts.contains_key(&old_pos))
279 .then_some(old_pos)
280 }) {
281 return Err(InvalidSignature::MissingIO(old_pos, "input"));
282 }
283
284 if let Some(old_pos) = counts
285 .iter()
286 .find_map(|(&old_pos, &count)| (count > 1).then_some(old_pos))
287 {
288 return Err(InvalidSignature::DuplicateInput(old_pos));
289 }
290
291 Ok(())
292}
293
294fn check_valid_outputs(
295 old_sig: &[TypeBase<NoRV>],
296 map_sig: &[usize],
297) -> Result<(), InvalidSignature> {
298 if let Some(old_pos) = map_sig
299 .iter()
300 .find_map(|&old_pos| (old_pos >= old_sig.len()).then_some(old_pos))
301 {
302 return Err(InvalidSignature::UnknownIO(old_pos, "output"));
303 }
304
305 let counts = map_sig.iter().copied().counts();
306 let linear_types = old_sig
307 .iter()
308 .enumerate()
309 .filter_map(|(pos, t)| (!t.copyable()).then_some(pos));
310 for old_pos in linear_types {
311 let Some(&cnt) = counts.get(&old_pos) else {
312 return Err(InvalidSignature::MissingIO(old_pos, "output"));
313 };
314 if cnt != 1 {
315 return Err(InvalidSignature::LinearityViolation(old_pos, "output"));
316 }
317 }
318
319 Ok(())
320}
321
322#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Error)]
324#[non_exhaustive]
325pub enum InvalidSignature {
326 #[error("{1} at position {0} is required but missing in new signature")]
328 MissingIO(usize, &'static str),
329 #[error("No {1} at position {0} in signature")]
332 UnknownIO(usize, &'static str),
333 #[error("Linearity of {1} at position {0} is not preserved in new signature")]
335 LinearityViolation(usize, &'static str),
336 #[error("Input at position {0} is duplicated in new signature")]
338 DuplicateInput(usize),
339 #[error("Type at port {0:?} must be copyable")]
341 ExpectedCopyable(Port),
342}
343
#[cfg(test)]
mod test {
    use insta::assert_snapshot;

    use super::*;
    use crate::builder::{
        DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, endo_sig,
    };
    use crate::extension::prelude::{bool_t, qb_t};
    use crate::hugr::views::root_checked::RootChecked;
    use crate::ops::handle::NodeHandle;
    use crate::ops::{NamedOp, OpParent};
    use crate::std_extensions::arithmetic::float_types::float64_type;
    use crate::types::Signature;
    use crate::utils::test_quantum_extension::cx_gate;
    use crate::{Hugr, HugrView};

    /// Build a Hugr whose entrypoint is a DFG with signature `sig` that wires
    /// each input directly to the output at the same position.
    fn new_empty_dfg(sig: Signature) -> Hugr {
        let dfg_builder = DFGBuilder::new(sig).unwrap();
        let wires = dfg_builder.input_wires();
        dfg_builder.finish_hugr_with_outputs(wires).unwrap()
    }

    /// Swap the two qubit inputs of an identity DFG (viewed as a
    /// `DataflowParentID`), then exercise the error cases of `map_function_type`.
    #[test]
    fn test_map_io() {
        let sig = Signature::new_endo(vec![qb_t(), qb_t()]);
        let mut hugr = new_empty_dfg(sig);

        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();

        // Swap the inputs, keep the outputs in place.
        let input_map = vec![1, 0];
        let output_map = vec![0, 1];

        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        let dfg_hugr = dfg_view.hugr();
        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 2);
        assert_eq!(new_sig.output_count(), 2);

        // Qubit input 1 is non-copyable (and linked), so it cannot be dropped.
        let invalid_input_map = vec![0, 0];
        let err = dfg_view.map_function_type(&invalid_input_map, &output_map);
        assert!(matches!(err, Err(InvalidSignature::MissingIO(1, "input"))));

        // Every input is present, but input 0 is listed twice.
        let invalid_input_map = vec![0, 0, 1];
        assert!(matches!(
            dfg_view.map_function_type(&invalid_input_map, &output_map),
            Err(InvalidSignature::DuplicateInput(0))
        ));

        // There is no output at position 2.
        let invalid_output_map = vec![0, 2];
        assert!(matches!(
            dfg_view.map_function_type(&input_map, &invalid_output_map),
            Err(InvalidSignature::UnknownIO(2, "output"))
        ));
    }

    /// Same scenario as `test_map_io`, but through a `DfgID`-checked view.
    #[test]
    fn test_map_io_dfg_id() {
        let sig = Signature::new_endo(vec![qb_t(), qb_t()]);
        let mut hugr = new_empty_dfg(sig);

        let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();

        // Swap the inputs, keep the outputs in place.
        let input_map = vec![1, 0];
        let output_map = vec![0, 1];

        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        let dfg_hugr = dfg_view.hugr();
        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 2);
        assert_eq!(new_sig.output_count(), 2);

        // Qubit input 1 is non-copyable (and linked), so it cannot be dropped.
        let invalid_input_map = vec![0, 0];
        let err = dfg_view.map_function_type(&invalid_input_map, &output_map);
        assert!(matches!(err, Err(InvalidSignature::MissingIO(1, "input"))));

        // Every input is present, but input 0 is listed twice.
        let invalid_input_map = vec![0, 0, 1];
        assert!(matches!(
            dfg_view.map_function_type(&invalid_input_map, &output_map),
            Err(InvalidSignature::DuplicateInput(0))
        ));

        // There is no output at position 2.
        let invalid_output_map = vec![0, 2];
        assert!(matches!(
            dfg_view.map_function_type(&input_map, &invalid_output_map),
            Err(InvalidSignature::UnknownIO(2, "output"))
        ));
    }

    /// A copyable (bool) output may be duplicated; the result must validate.
    // Presumably skipped under miri because `assert_snapshot!` reads snapshot
    // files from disk — confirm.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_map_io_duplicate_output() {
        let sig = Signature::new_endo(vec![bool_t()]);
        let mut hugr = new_empty_dfg(sig);

        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();

        let input_map = vec![0];
        let output_map = vec![0, 0];

        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        let dfg_hugr = dfg_view.hugr();
        if let Err(err) = dfg_hugr.validate() {
            panic!("Invalid Hugr: {err}");
        }

        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 1);
        assert_eq!(new_sig.output_count(), 2);
        assert_snapshot!(dfg_hugr.mermaid_string());
    }

    /// Swapping the inputs of a DFG containing a CX gate must cross the wires
    /// into the gate while the gate outputs still feed the outputs in order.
    // Presumably skipped under miri because `assert_snapshot!` reads snapshot
    // files from disk — confirm.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_map_io_cx_gate() {
        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()])).unwrap();
        let [wire0, wire1] = dfg_builder.input_wires_arr();
        let cx_handle = dfg_builder
            .add_dataflow_op(cx_gate(), vec![wire0, wire1])
            .unwrap();
        let cx_node = cx_handle.node();
        let [wire0, wire1] = cx_handle.outputs_arr();
        let mut hugr = dfg_builder
            .finish_hugr_with_outputs(vec![wire0, wire1])
            .unwrap();

        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();

        let input_map = vec![1, 0];
        let output_map = vec![0, 1];

        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        let dfg_hugr = dfg_view.hugr();
        if let Err(err) = dfg_hugr.validate() {
            panic!("Invalid Hugr: {err}");
        }

        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 2);
        assert_eq!(new_sig.output_count(), 2);

        // New input 0 (old input 1) now feeds CX port 1, and vice versa.
        let [new_inp, new_out] = dfg_view.get_io();
        assert_eq!(
            dfg_hugr.linked_inputs(new_inp, 0).collect_vec(),
            vec![(cx_node, 1.into())]
        );
        assert_eq!(
            dfg_hugr.linked_inputs(new_inp, 1).collect_vec(),
            vec![(cx_node, 0.into())]
        );
        // Outputs are unchanged: CX output i still feeds output i.
        assert_eq!(
            dfg_hugr.linked_outputs(new_out, 0).collect_vec(),
            vec![(cx_node, 0.into())]
        );
        assert_eq!(
            dfg_hugr.linked_outputs(new_out, 1).collect_vec(),
            vec![(cx_node, 1.into())]
        );

        assert_snapshot!(dfg_hugr.mermaid_string());
    }

    /// Rotate three qubit inputs where input 2 is wired straight to output 2.
    /// This checks that a direct Input -> Output link follows its input to
    /// the new position.
    // Presumably skipped under miri because `assert_snapshot!` reads snapshot
    // files from disk — confirm.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_map_io_cycle_3qb() {
        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(); 3])).unwrap();
        let [wire0, wire1, wire2] = dfg_builder.input_wires_arr();
        let cx_handle = dfg_builder
            .add_dataflow_op(cx_gate(), vec![wire0, wire1])
            .unwrap();
        let cx_node = cx_handle.node();
        let [wire0, wire1] = cx_handle.outputs_arr();
        let mut hugr = dfg_builder
            .finish_hugr_with_outputs(vec![wire0, wire1, wire2])
            .unwrap();

        let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();

        // New inputs are old inputs (1, 2, 0); outputs stay in place.
        let input_map = vec![1, 2, 0];
        let output_map = vec![0, 1, 2];

        dfg_view.map_function_type(&input_map, &output_map).unwrap();
        let [dfg_inp, dfg_out] = dfg_view.get_io();

        let dfg_hugr = dfg_view.hugr();
        if let Err(err) = dfg_hugr.validate() {
            panic!("Invalid Hugr: {err}");
        }

        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 3);
        assert_eq!(new_sig.output_count(), 3);

        // New input i was old input (i + 1) % 3. That old input fed CX port 0,
        // CX port 1, or output 2 respectively.
        for (i, exp_gate) in [cx_node, dfg_out, cx_node].into_iter().enumerate() {
            assert_eq!(
                dfg_hugr.linked_inputs(dfg_inp, i).collect_vec(),
                vec![(exp_gate, ((i + 1) % 3).into())]
            );
        }
        // Outputs 0 and 1 come from the CX gate. Output 2 comes from new input 1,
        // the old input 2.
        for (i, exp_gate) in [cx_node, cx_node, dfg_inp].into_iter().enumerate() {
            let exp_outport = std::cmp::min(i, 1);
            assert_eq!(
                dfg_hugr.linked_outputs(dfg_out, i).collect_vec(),
                vec![(exp_gate, exp_outport.into())],
                "expected {}({exp_outport}) -> out({i})",
                dfg_hugr.get_optype(exp_gate).name()
            );
        }

        assert_snapshot!(dfg_hugr.mermaid_string());
    }

    /// Remapping a DFG nested as `foo` (FuncDefn) > dfg1 > dfg2 must update the
    /// signatures of all three nodes.
    // Presumably skipped under miri because `assert_snapshot!` reads snapshot
    // files from disk — confirm.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_map_io_recursive() {
        use crate::builder::ModuleBuilder;
        use crate::extension::prelude::{bool_t, qb_t};
        use crate::types::Signature;

        let mut module_builder = ModuleBuilder::new();

        // Build foo > dfg1 > dfg2, each just forwarding its inputs.
        let dfg_roots = {
            let mut foo_builder = module_builder
                .define_function("foo", Signature::new_endo(vec![qb_t(), bool_t()]))
                .unwrap();

            let [qb, b] = foo_builder.input_wires_arr();

            let mut dfg1_builder = foo_builder
                .dfg_builder_endo([(qb_t(), qb), (bool_t(), b)])
                .unwrap();
            let [dfg1_qb, dfg1_b] = dfg1_builder.input_wires_arr();

            let dfg2_builder = dfg1_builder
                .dfg_builder_endo([(qb_t(), dfg1_qb), (bool_t(), dfg1_b)])
                .unwrap();
            let [dfg2_qb, dfg2_b] = dfg2_builder.input_wires_arr();

            let dfg2_id = dfg2_builder.finish_with_outputs([dfg2_qb, dfg2_b]).unwrap();

            let dfg1_id = dfg1_builder.finish_with_outputs(dfg2_id.outputs()).unwrap();

            let foo_id = foo_builder.finish_with_outputs(dfg1_id.outputs()).unwrap();

            [foo_id.node(), dfg1_id.node(), dfg2_id.node()]
        };

        // Make the innermost DFG the entrypoint.
        let mut hugr = module_builder.finish_hugr().unwrap();
        hugr.set_entrypoint(dfg_roots[2]);

        let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();

        // Keep the inputs, swap the outputs.
        let input_map = vec![0, 1];
        let output_map = vec![1, 0];

        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        for node in dfg_roots {
            let sig = hugr.get_optype(node).inner_function_type().unwrap();
            assert_eq!(sig.input_types(), vec![qb_t(), bool_t()]);
            assert_eq!(sig.output_types(), vec![bool_t(), qb_t()]);
        }

        assert_snapshot!(hugr.mermaid_string());
    }

    /// Copyable inputs can be appended. A non-copyable one is rejected with
    /// the port it would have occupied.
    #[test]
    fn test_extend_inputs() {
        let dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t()])).unwrap();
        let [wire] = dfg_builder.input_wires_arr();
        let mut hugr = dfg_builder.finish_hugr_with_outputs(vec![wire]).unwrap();

        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();

        let new_inputs = vec![bool_t(), float64_type()];
        dfg_view.extend_inputs(&new_inputs).unwrap();
        assert_eq!(
            dfg_view.hugr().inner_function_type().unwrap(),
            Signature::new(vec![qb_t(), bool_t(), float64_type()], vec![qb_t()])
        );

        // Three inputs now exist, so the qubit would land on incoming port 3.
        let new_inputs_fail = vec![qb_t()];
        let err = dfg_view.extend_inputs(&new_inputs_fail);
        assert_eq!(
            err,
            Err(InvalidSignature::ExpectedCopyable(
                IncomingPort::from(3).into()
            ))
        );
    }
}