1use std::collections::BTreeMap;
4
5use itertools::Itertools;
6use thiserror::Error;
7
8use crate::{
9 IncomingPort, OutgoingPort, PortIndex,
10 hugr::HugrMut,
11 ops::{
12 OpTrait, OpType,
13 handle::{DataflowParentID, DfgID},
14 },
15 types::{NoRV, Signature, TypeBase},
16};
17
18use super::RootChecked;
19
/// Implement the dataflow-parent helpers (`get_io`, `map_function_type`) on
/// `RootChecked<H, $handle_type<H::Node>>` for the given handle type.
macro_rules! impl_dataflow_parent_methods {
    ($handle_type:ident) => {
        impl<H: HugrMut> RootChecked<H, $handle_type<H::Node>> {
            /// Get the `[input, output]` nodes of the entrypoint's dataflow region.
            ///
            /// # Panics
            ///
            /// Panics if `get_io` returns `None` for the entrypoint.
            pub fn get_io(&self) -> [H::Node; 2] {
                self.hugr()
                    .get_io(self.hugr().entrypoint())
                    .expect("valid DFG graph")
            }

            /// Rewire the entrypoint's `Input`/`Output` nodes to permute, drop or
            /// duplicate entries of its signature.
            ///
            /// After the call, new input `i` carries the links previously attached
            /// to old input `new_inputs[i]`, and new output `j` receives the links
            /// previously attached to old output `new_outputs[j]`. The new
            /// signature is written to the entrypoint and then, walking upwards, to
            /// every enclosing `DFG`/`FuncDefn` (and their `Input`/`Output` nodes)
            /// until the first ancestor of another kind, or the root.
            ///
            /// Inputs may be dropped only if they are unconnected and copyable, and
            /// may not be duplicated. Copyable outputs may be dropped or duplicated;
            /// non-copyable outputs must appear exactly once.
            ///
            /// # Errors
            ///
            /// Returns an [`InvalidSignature`] if either map is invalid (see
            /// `check_valid_inputs` / `check_valid_outputs`). Validation happens
            /// before any mutation, so the HUGR is left untouched on error.
            ///
            /// # Panics
            ///
            /// Panics if an enclosing `DFG`/`FuncDefn` ancestor does not have
            /// exactly one child besides its first two children (see
            /// `update_inner_dfg_links`).
            ///
            /// NOTE(review): if the entrypoint itself is neither a `DFG` nor a
            /// `FuncDefn` (reachable through the `DataflowParentID` impl), the
            /// signature-update loop below never runs. The `Input`/`Output` links
            /// are then rewired without their types being updated. Confirm whether
            /// such entrypoints should be rejected.
            pub fn map_function_type(
                &mut self,
                new_inputs: &[usize],
                new_outputs: &[usize],
            ) -> Result<(), InvalidSignature> {
                let [inp, out] = self.get_io();
                let Self(hugr, _) = self;

                // Snapshot the current links of every port of the Input/Output nodes
                // before anything is disconnected. NOTE(review): `node_outputs` /
                // `node_inputs` iterate all ports, so any non-value (e.g. order) port
                // presumably shows up as a trailing entry here. Such entries are
                // never reconnected below, because the maps only index value ports.
                let old_inputs_incoming = hugr
                    .node_outputs(inp)
                    .map(|p| hugr.linked_inputs(inp, p).collect_vec())
                    .collect_vec();
                let old_outputs_outgoing = hugr
                    .node_inputs(out)
                    .map(|p| hugr.linked_outputs(out, p).collect_vec())
                    .collect_vec();

                // Old value types: the Input node's outputs and the Output node's inputs.
                let old_inp_sig = hugr
                    .get_optype(inp)
                    .dataflow_signature()
                    .expect("input has signature");
                let old_inp_sig = old_inp_sig.output_types();
                let old_out_sig = hugr
                    .get_optype(out)
                    .dataflow_signature()
                    .expect("output has signature");
                let old_out_sig = old_out_sig.input_types();

                // Validate both maps before mutating anything.
                check_valid_inputs(&old_inputs_incoming, old_inp_sig, new_inputs)?;
                check_valid_outputs(old_out_sig, new_outputs)?;

                // Indexing is in range: the checks above reject out-of-range positions.
                let new_inp_sig = new_inputs
                    .iter()
                    .map(|&i| old_inp_sig[i].clone())
                    .collect_vec();
                let new_out_sig = new_outputs
                    .iter()
                    .map(|&i| old_out_sig[i].clone())
                    .collect_vec();
                let new_sig = Signature::new(new_inp_sig, new_out_sig);

                // Drop all links of the Input/Output nodes. They are re-created from
                // the snapshots once the port counts have been updated.
                disconnect_all(hugr, inp);
                disconnect_all(hugr, out);

                // Propagate the new signature from the entrypoint up through every
                // enclosing DFG/FuncDefn. For each ancestor, the links to its single
                // inner child are rebuilt to match the new port counts.
                // NOTE(review): the loop stops at the first ancestor that is not a
                // DFG/FuncDefn. The links between the outermost updated node and its
                // siblings are not rewired here; confirm callers only use trivially
                // nested DFG chains.
                let mut is_ancestor = false;
                let mut node = hugr.entrypoint();
                while matches!(hugr.get_optype(node), OpType::FuncDefn(_) | OpType::DFG(_)) {
                    let [inner_inp, inner_out] = hugr.get_io(node).expect("valid DFG graph");
                    for node in [node, inner_inp, inner_out] {
                        update_signature(hugr, node, &new_sig);
                    }
                    if is_ancestor {
                        update_inner_dfg_links(hugr, node);
                    }
                    if let Some(parent) = hugr.get_parent(node) {
                        node = parent;
                        is_ancestor = true;
                    } else {
                        break;
                    }
                }

                // Reconnect the new input ports. A link that went straight from the
                // Input node into the Output node cannot be made yet, because the
                // Output port's new position is unknown. Remember which new input
                // port feeds each such old Output port instead.
                let mut old_output_to_new_input = BTreeMap::<IncomingPort, OutgoingPort>::new();
                for (inp_pos, &old_pos) in new_inputs.iter().enumerate() {
                    for &(node, port) in &old_inputs_incoming[old_pos] {
                        if node != out {
                            hugr.connect(inp, inp_pos, node, port);
                        } else {
                            old_output_to_new_input.insert(port, inp_pos.into());
                        }
                    }
                }

                // Reconnect the new output ports. The `unwrap` relies on
                // `check_valid_inputs`: any connected old input must appear in
                // `new_inputs`, so every Input->Output link was recorded above.
                for (out_pos, &old_pos) in new_outputs.iter().enumerate() {
                    for &(node, port) in &old_outputs_outgoing[old_pos] {
                        if node != inp {
                            hugr.connect(node, port, out, out_pos);
                        } else {
                            let &inp_pos = old_output_to_new_input.get(&old_pos.into()).unwrap();
                            hugr.connect(inp, inp_pos, out, out_pos);
                        }
                    }
                }

                Ok(())
            }
        }
    };
}

// The same helpers are available for any dataflow parent and for DFG roots.
impl_dataflow_parent_methods!(DataflowParentID);
impl_dataflow_parent_methods!(DfgID);
152
153fn update_inner_dfg_links<H: HugrMut>(hugr: &mut H, node: H::Node) {
155 let inner_dfg = hugr
157 .children(node)
158 .skip(2)
159 .exactly_one()
160 .ok()
161 .expect("no non-trivial inner DFG");
162
163 let [inp, out] = hugr.get_io(node).expect("valid DFG graph");
164 disconnect_all(hugr, inner_dfg);
165 for (out_port, _) in hugr.out_value_types(inp).collect_vec() {
166 hugr.connect(inp, out_port, inner_dfg, out_port.index());
167 }
168 for (in_port, _) in hugr.in_value_types(out).collect_vec() {
169 hugr.connect(inner_dfg, in_port.index(), out, in_port);
170 }
171}
172
173fn disconnect_all<H: HugrMut>(hugr: &mut H, node: H::Node) {
174 let all_ports = hugr.all_node_ports(node).collect_vec();
175 for port in all_ports {
176 hugr.disconnect(node, port);
177 }
178}
179
180fn update_signature<H: HugrMut>(hugr: &mut H, node: H::Node, new_sig: &Signature) {
181 match hugr.optype_mut(node) {
182 OpType::DFG(dfg) => {
183 dfg.signature = new_sig.clone();
184 }
185 OpType::FuncDefn(fn_def_op) => *fn_def_op.signature_mut() = new_sig.clone().into(),
186 OpType::Input(inp) => {
187 inp.types = new_sig.input().clone();
188 }
189 OpType::Output(out) => out.types = new_sig.output().clone(),
190 _ => panic!("only update signature of DFG, FuncDefn, Input, or Output"),
191 };
192 let new_op = hugr.get_optype(node);
193 hugr.set_num_ports(node, new_op.input_count(), new_op.output_count());
194}
195
196fn check_valid_inputs<V>(
197 old_ports: &[Vec<V>],
198 old_sig: &[TypeBase<NoRV>],
199 map_sig: &[usize],
200) -> Result<(), InvalidSignature> {
201 if let Some(old_pos) = map_sig
202 .iter()
203 .find_map(|&old_pos| (old_pos >= old_sig.len()).then_some(old_pos))
204 {
205 return Err(InvalidSignature::UnknownIO(old_pos, "input"));
206 }
207
208 let counts = map_sig.iter().copied().counts();
209 if let Some(old_pos) = old_ports.iter().enumerate().find_map(|(old_pos, vec)| {
210 ((!vec.is_empty() || old_sig.get(old_pos).is_some_and(|t| !t.copyable()))
211 && !counts.contains_key(&old_pos))
212 .then_some(old_pos)
213 }) {
214 return Err(InvalidSignature::MissingIO(old_pos, "input"));
215 }
216
217 if let Some(old_pos) = counts
218 .iter()
219 .find_map(|(&old_pos, &count)| (count > 1).then_some(old_pos))
220 {
221 return Err(InvalidSignature::DuplicateInput(old_pos));
222 }
223
224 Ok(())
225}
226
227fn check_valid_outputs(
228 old_sig: &[TypeBase<NoRV>],
229 map_sig: &[usize],
230) -> Result<(), InvalidSignature> {
231 if let Some(old_pos) = map_sig
232 .iter()
233 .find_map(|&old_pos| (old_pos >= old_sig.len()).then_some(old_pos))
234 {
235 return Err(InvalidSignature::UnknownIO(old_pos, "output"));
236 }
237
238 let counts = map_sig.iter().copied().counts();
239 let linear_types = old_sig
240 .iter()
241 .enumerate()
242 .filter_map(|(pos, t)| (!t.copyable()).then_some(pos));
243 for old_pos in linear_types {
244 let Some(&cnt) = counts.get(&old_pos) else {
245 return Err(InvalidSignature::MissingIO(old_pos, "output"));
246 };
247 if cnt != 1 {
248 return Err(InvalidSignature::LinearityViolation(old_pos, "output"));
249 }
250 }
251
252 Ok(())
253}
254
/// Reasons a requested signature map is rejected by `map_function_type`.
///
/// Where present, the `&'static str` field names the side of the signature the
/// error refers to. The checks in this module set it to `"input"` or `"output"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Error)]
#[non_exhaustive]
pub enum InvalidSignature {
    /// An old port that must be kept is absent from the map. For inputs, that
    /// means a connected or non-copyable port; for outputs, a non-copyable port.
    #[error("{1} at position {0} is required but missing in new signature")]
    MissingIO(usize, &'static str),
    /// The map refers to an old position that does not exist in the signature.
    #[error("No {1} at position {0} in signature")]
    UnknownIO(usize, &'static str),
    /// A non-copyable value would be used more than once.
    #[error("Linearity of {1} at position {0} is not preserved in new signature")]
    LinearityViolation(usize, &'static str),
    /// The same old input is mapped to more than one new input.
    #[error("Input at position {0} is duplicated in new signature")]
    DuplicateInput(usize),
}
273
#[cfg(test)]
mod test {
    use insta::assert_snapshot;

    use super::*;
    use crate::builder::{
        DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, endo_sig,
    };
    use crate::extension::prelude::{bool_t, qb_t};
    use crate::hugr::views::root_checked::RootChecked;
    use crate::ops::handle::NodeHandle;
    use crate::ops::{NamedOp, OpParent};
    use crate::types::Signature;
    use crate::utils::test_quantum_extension::cx_gate;
    use crate::{Hugr, HugrView};

    /// Build a DFG whose input wires are passed straight to its outputs.
    fn new_empty_dfg(sig: Signature) -> Hugr {
        let dfg_builder = DFGBuilder::new(sig).unwrap();
        let wires = dfg_builder.input_wires();
        dfg_builder.finish_hugr_with_outputs(wires).unwrap()
    }

    /// Valid and invalid maps on a two-qubit identity DFG, via `DataflowParentID`.
    #[test]
    fn test_map_io() {
        let sig = Signature::new_endo(vec![qb_t(), qb_t()]);
        let mut hugr = new_empty_dfg(sig);

        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();

        // Swap the inputs, keep the outputs in place.
        let input_map = vec![1, 0];
        let output_map = vec![0, 1];

        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        let dfg_hugr = dfg_view.hugr();
        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 2);
        assert_eq!(new_sig.output_count(), 2);

        // Input 1 is connected (and a qubit) but omitted.
        let invalid_input_map = vec![0, 0];
        let err = dfg_view.map_function_type(&invalid_input_map, &output_map);
        assert!(matches!(err, Err(InvalidSignature::MissingIO(1, "input"))));

        // Every input is present, but input 0 is used twice.
        let invalid_input_map = vec![0, 0, 1];
        assert!(matches!(
            dfg_view.map_function_type(&invalid_input_map, &output_map),
            Err(InvalidSignature::DuplicateInput(0))
        ));

        // There is no output at position 2.
        let invalid_output_map = vec![0, 2];
        assert!(matches!(
            dfg_view.map_function_type(&input_map, &invalid_output_map),
            Err(InvalidSignature::UnknownIO(2, "output"))
        ));
    }

    /// Same scenario as `test_map_io`, through the `DfgID` impl.
    #[test]
    fn test_map_io_dfg_id() {
        let sig = Signature::new_endo(vec![qb_t(), qb_t()]);
        let mut hugr = new_empty_dfg(sig);

        let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();

        let input_map = vec![1, 0];
        let output_map = vec![0, 1];

        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        let dfg_hugr = dfg_view.hugr();
        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 2);
        assert_eq!(new_sig.output_count(), 2);

        let invalid_input_map = vec![0, 0];
        let err = dfg_view.map_function_type(&invalid_input_map, &output_map);
        assert!(matches!(err, Err(InvalidSignature::MissingIO(1, "input"))));

        let invalid_input_map = vec![0, 0, 1];
        assert!(matches!(
            dfg_view.map_function_type(&invalid_input_map, &output_map),
            Err(InvalidSignature::DuplicateInput(0))
        ));

        let invalid_output_map = vec![0, 2];
        assert!(matches!(
            dfg_view.map_function_type(&input_map, &invalid_output_map),
            Err(InvalidSignature::UnknownIO(2, "output"))
        ));
    }

    /// A copyable (bool) output may be duplicated; the result must still validate.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_map_io_duplicate_output() {
        let sig = Signature::new_endo(vec![bool_t()]);
        let mut hugr = new_empty_dfg(sig);

        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();

        let input_map = vec![0];
        let output_map = vec![0, 0];

        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        let dfg_hugr = dfg_view.hugr();
        if let Err(err) = dfg_hugr.validate() {
            panic!("Invalid Hugr: {err}");
        }

        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 1);
        assert_eq!(new_sig.output_count(), 2);
        assert_snapshot!(dfg_hugr.mermaid_string());
    }

    /// Swapping the inputs of a DFG that contains a CX gate rewires the gate's
    /// incoming links, while the outgoing links keep their positions.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_map_io_cx_gate() {
        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()])).unwrap();
        let [wire0, wire1] = dfg_builder.input_wires_arr();
        let cx_handle = dfg_builder
            .add_dataflow_op(cx_gate(), vec![wire0, wire1])
            .unwrap();
        let cx_node = cx_handle.node();
        let [wire0, wire1] = cx_handle.outputs_arr();
        let mut hugr = dfg_builder
            .finish_hugr_with_outputs(vec![wire0, wire1])
            .unwrap();

        let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();

        let input_map = vec![1, 0];
        let output_map = vec![0, 1];

        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        let dfg_hugr = dfg_view.hugr();
        if let Err(err) = dfg_hugr.validate() {
            panic!("Invalid Hugr: {err}");
        }

        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 2);
        assert_eq!(new_sig.output_count(), 2);

        // New input 0 now feeds CX port 1, and new input 1 feeds CX port 0.
        let [new_inp, new_out] = dfg_view.get_io();
        assert_eq!(
            dfg_hugr.linked_inputs(new_inp, 0).collect_vec(),
            vec![(cx_node, 1.into())]
        );
        assert_eq!(
            dfg_hugr.linked_inputs(new_inp, 1).collect_vec(),
            vec![(cx_node, 0.into())]
        );
        assert_eq!(
            dfg_hugr.linked_outputs(new_out, 0).collect_vec(),
            vec![(cx_node, 0.into())]
        );
        assert_eq!(
            dfg_hugr.linked_outputs(new_out, 1).collect_vec(),
            vec![(cx_node, 1.into())]
        );

        assert_snapshot!(dfg_hugr.mermaid_string());
    }

    /// Rotating three inputs, one of which is wired directly Input -> Output,
    /// exercises the direct-link bookkeeping in `map_function_type`.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_map_io_cycle_3qb() {
        let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(); 3])).unwrap();
        let [wire0, wire1, wire2] = dfg_builder.input_wires_arr();
        let cx_handle = dfg_builder
            .add_dataflow_op(cx_gate(), vec![wire0, wire1])
            .unwrap();
        let cx_node = cx_handle.node();
        let [wire0, wire1] = cx_handle.outputs_arr();
        let mut hugr = dfg_builder
            .finish_hugr_with_outputs(vec![wire0, wire1, wire2])
            .unwrap();

        let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();

        let input_map = vec![1, 2, 0];
        let output_map = vec![0, 1, 2];

        dfg_view.map_function_type(&input_map, &output_map).unwrap();
        let [dfg_inp, dfg_out] = dfg_view.get_io();

        let dfg_hugr = dfg_view.hugr();
        if let Err(err) = dfg_hugr.validate() {
            panic!("Invalid Hugr: {err}");
        }

        let new_sig = dfg_hugr
            .get_optype(dfg_hugr.entrypoint())
            .dataflow_signature()
            .unwrap();
        assert_eq!(new_sig.input_count(), 3);
        assert_eq!(new_sig.output_count(), 3);

        // New input i carries what old input (i + 1) % 3 carried.
        for (i, exp_gate) in [cx_node, dfg_out, cx_node].into_iter().enumerate() {
            assert_eq!(
                dfg_hugr.linked_inputs(dfg_inp, i).collect_vec(),
                vec![(exp_gate, ((i + 1) % 3).into())]
            );
        }
        // Outputs 0 and 1 come from the CX gate. Output 2 is fed directly by the
        // new Input port 1 (old input 2).
        for (i, exp_gate) in [cx_node, cx_node, dfg_inp].into_iter().enumerate() {
            let exp_outport = std::cmp::min(i, 1);
            assert_eq!(
                dfg_hugr.linked_outputs(dfg_out, i).collect_vec(),
                vec![(exp_gate, exp_outport.into())],
                "expected {}({exp_outport}) -> out({i})",
                dfg_hugr.get_optype(exp_gate).name()
            );
        }

        assert_snapshot!(dfg_hugr.mermaid_string());
    }

    /// The new signature propagates from a nested DFG through its DFG and
    /// FuncDefn ancestors.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_map_io_recursive() {
        use crate::builder::ModuleBuilder;
        use crate::extension::prelude::{bool_t, qb_t};
        use crate::types::Signature;

        let mut module_builder = ModuleBuilder::new();

        // foo (FuncDefn) > dfg1 (DFG) > dfg2 (DFG), each passing (qb, bool) through.
        let dfg_roots = {
            let mut foo_builder = module_builder
                .define_function("foo", Signature::new_endo(vec![qb_t(), bool_t()]))
                .unwrap();

            let [qb, b] = foo_builder.input_wires_arr();

            let mut dfg1_builder = foo_builder
                .dfg_builder_endo([(qb_t(), qb), (bool_t(), b)])
                .unwrap();
            let [dfg1_qb, dfg1_b] = dfg1_builder.input_wires_arr();

            let dfg2_builder = dfg1_builder
                .dfg_builder_endo([(qb_t(), dfg1_qb), (bool_t(), dfg1_b)])
                .unwrap();
            let [dfg2_qb, dfg2_b] = dfg2_builder.input_wires_arr();

            let dfg2_id = dfg2_builder.finish_with_outputs([dfg2_qb, dfg2_b]).unwrap();

            let dfg1_id = dfg1_builder.finish_with_outputs(dfg2_id.outputs()).unwrap();

            let foo_id = foo_builder.finish_with_outputs(dfg1_id.outputs()).unwrap();

            [foo_id.node(), dfg1_id.node(), dfg2_id.node()]
        };

        let mut hugr = module_builder.finish_hugr().unwrap();
        // Make the innermost DFG the entrypoint.
        hugr.set_entrypoint(dfg_roots[2]);

        let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();

        // Swap the outputs only.
        let input_map = vec![0, 1];
        let output_map = vec![1, 0];

        dfg_view.map_function_type(&input_map, &output_map).unwrap();

        for node in dfg_roots {
            let sig = hugr.get_optype(node).inner_function_type().unwrap();
            assert_eq!(sig.input_types(), vec![qb_t(), bool_t()]);
            assert_eq!(sig.output_types(), vec![bool_t(), qb_t()]);
        }

        assert_snapshot!(hugr.mermaid_string());
    }
}