1use crate::extension::prelude::MakeTuple;
2use crate::hugr::ValidationError;
3use crate::hugr::hugrmut::InsertionResult;
4use crate::hugr::linking::{HugrLinking, NameLinkingPolicy, NodeLinkingDirective};
5use crate::hugr::views::HugrView;
6use crate::metadata::Metadata;
7use crate::ops::{self, OpTag, OpTrait, OpType, Tag, TailLoop};
8use crate::utils::collect_array;
9use crate::{Extension, IncomingPort, Node, OutgoingPort};
10
11use std::collections::HashMap;
12use std::iter;
13use std::sync::Arc;
14
15use super::{BuilderWiringError, ModuleBuilder};
16use super::{
17 CircuitBuilder,
18 handle::{BuildHandle, Outputs},
19};
20
21use crate::{
22 ops::handle::{ConstID, DataflowOpID, FuncID, NodeHandle},
23 types::EdgeKind,
24};
25
26use crate::extension::ExtensionRegistry;
27use crate::types::{Signature, Type, TypeArg, TypeRow};
28
29use itertools::Itertools;
30
31use super::{
32 BuildError, Wire, cfg::CFGBuilder, conditional::ConditionalBuilder, dataflow::DFGBuilder,
33 tail_loop::TailLoopBuilder,
34};
35
36use crate::Hugr;
37
38use crate::hugr::HugrMut;
39
40pub trait Container {
45 fn container_node(&self) -> Node;
47 fn hugr_mut(&mut self) -> &mut Hugr;
49 fn hugr(&self) -> &Hugr;
51 fn add_child_node(&mut self, node: impl Into<OpType>) -> Node {
55 let node: OpType = node.into();
56
57 let used_extensions = node
59 .used_extensions()
60 .unwrap_or_else(|e| panic!("Build-time signatures should have valid extensions. {e}"));
61 self.use_extensions(used_extensions);
62
63 let parent = self.container_node();
64 self.hugr_mut().add_node_with_parent(parent, node)
65 }
66
67 fn add_other_wire(&mut self, src: Node, dst: Node) -> Wire {
72 let (src_port, _) = self.hugr_mut().add_other_edge(src, dst);
73 Wire::new(src, src_port)
74 }
75
76 fn add_constant(&mut self, constant: impl Into<ops::Const>) -> ConstID {
85 self.add_child_node(constant.into()).into()
86 }
87
88 fn add_hugr(&mut self, child: Hugr) -> InsertionResult {
92 let region = child.entrypoint();
93 self.add_hugr_region(child, region)
94 }
95
96 fn add_hugr_region(&mut self, child: Hugr, region: Node) -> InsertionResult {
100 let parent = self.container_node();
101 self.hugr_mut().insert_region(parent, child, region)
102 }
103
104 fn add_hugr_view<H: HugrView>(&mut self, child: &H) -> InsertionResult<H::Node, Node> {
109 let parent = self.container_node();
110 self.hugr_mut().insert_from_view(parent, child)
111 }
112
113 fn set_metadata<M: Metadata>(&mut self, meta: <M as Metadata>::Type<'_>) {
115 let parent = self.container_node();
116 self.hugr_mut().set_metadata::<M>(parent, meta);
118 }
119
120 fn set_child_metadata<M: Metadata>(&mut self, child: Node, meta: <M as Metadata>::Type<'_>) {
124 self.hugr_mut().set_metadata::<M>(child, meta);
125 }
126
127 fn use_extension(&mut self, ext: impl Into<Arc<Extension>>) {
129 self.hugr_mut().use_extension(ext);
130 }
131
132 fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>)
134 where
135 ExtensionRegistry: Extend<Reg>,
136 {
137 self.hugr_mut().use_extensions(registry);
138 }
139}
140
141pub trait HugrBuilder: Container {
144 fn module_root_builder(&mut self) -> ModuleBuilder<&mut Hugr> {
147 debug_assert!(
148 self.hugr()
149 .get_optype(self.hugr().module_root())
150 .is_module()
151 );
152 ModuleBuilder(self.hugr_mut())
153 }
154
155 fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>>;
157}
158
159pub trait SubContainer: Container {
161 type ContainerHandle;
164 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError>;
167}
168pub trait Dataflow: Container {
170 fn num_inputs(&self) -> usize;
172 fn io(&self) -> [Node; 2] {
174 self.hugr()
175 .children(self.container_node())
176 .take(2)
177 .collect_vec()
178 .try_into()
179 .expect("First two children should be IO")
180 }
181 fn input(&self) -> BuildHandle<DataflowOpID> {
183 (self.io()[0], self.num_inputs()).into()
184 }
185 fn output(&self) -> DataflowOpID {
187 self.io()[1].into()
188 }
189 fn input_wires(&self) -> Outputs {
191 self.input().outputs()
192 }
193 fn add_dataflow_op(
202 &mut self,
203 nodetype: impl Into<OpType>,
204 input_wires: impl IntoIterator<Item = Wire>,
205 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
206 let outs = add_node_with_wires(self, nodetype, input_wires)?;
207
208 Ok(outs.into())
209 }
210
211 fn add_hugr_with_wires(
223 &mut self,
224 hugr: Hugr,
225 input_wires: impl IntoIterator<Item = Wire>,
226 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
227 let region = hugr.entrypoint();
228 self.add_hugr_region_with_wires(hugr, region, input_wires)
229 }
230
231 fn add_link_hugr_with_wires(
243 &mut self,
244 hugr: Hugr,
245 policy: &NameLinkingPolicy,
246 input_wires: impl IntoIterator<Item = Wire>,
247 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
248 let parent = self.container_node();
249 let node = self
250 .hugr_mut()
251 .insert_link_hugr(parent, hugr, policy)?
252 .inserted_entrypoint;
253 wire_ins_return_outs(input_wires, node, self)
254 }
255
256 fn add_hugr_region_with_wires(
267 &mut self,
268 hugr: Hugr,
269 region: Node,
270 input_wires: impl IntoIterator<Item = Wire>,
271 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
272 let node = self.add_hugr_region(hugr, region).inserted_entrypoint;
273
274 wire_ins_return_outs(input_wires, node, self)
275 }
276
277 fn add_link_hugr_by_node_with_wires(
282 &mut self,
283 hugr: Hugr,
284 input_wires: impl IntoIterator<Item = Wire>,
285 defns: HashMap<Node, NodeLinkingDirective>,
286 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
287 let parent = Some(self.container_node());
288 let ep = hugr.entrypoint();
289 let node = self
290 .hugr_mut()
291 .insert_link_hugr_by_node(parent, hugr, defns)?
292 .node_map[&ep];
293 wire_ins_return_outs(input_wires, node, self)
294 }
295
296 fn add_hugr_view_with_wires(
307 &mut self,
308 hugr: &impl HugrView,
309 input_wires: impl IntoIterator<Item = Wire>,
310 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
311 let node = self.add_hugr_view(hugr).inserted_entrypoint;
312 wire_ins_return_outs(input_wires, node, self)
313 }
314
315 fn add_link_view_with_wires(
327 &mut self,
328 hugr: &impl HugrView,
329 policy: &NameLinkingPolicy,
330 input_wires: impl IntoIterator<Item = Wire>,
331 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
332 let parent = self.container_node();
333 let insertion = self
334 .hugr_mut()
335 .insert_link_from_view(parent, hugr, policy)
336 .map_err(|ins_err| BuildError::HugrViewInsertionError(ins_err.to_string()))?;
337 let node = insertion.node_map[&hugr.entrypoint()];
338 wire_ins_return_outs(input_wires, node, self)
339 }
340
341 fn add_link_view_by_node_with_wires<H: HugrView>(
345 &mut self,
346 hugr: &H,
347 input_wires: impl IntoIterator<Item = Wire>,
348 defns: HashMap<H::Node, NodeLinkingDirective>,
349 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
350 let parent = Some(self.container_node());
351 let node = self
352 .hugr_mut()
353 .insert_link_view_by_node(parent, hugr, defns)
354 .map_err(|ins_err| BuildError::HugrViewInsertionError(ins_err.to_string()))?
355 .node_map[&hugr.entrypoint()];
356 wire_ins_return_outs(input_wires, node, self)
357 }
358
359 fn set_outputs(
365 &mut self,
366 output_wires: impl IntoIterator<Item = Wire>,
367 ) -> Result<(), BuildError> {
368 let [_, out] = self.io();
369 wire_up_inputs(output_wires.into_iter().collect_vec(), out, self).map_err(|error| {
370 BuildError::OutputWiring {
371 container_op: Box::new(self.hugr().get_optype(self.container_node()).clone()),
372 container_node: self.container_node(),
373 error,
374 }
375 })
376 }
377
378 #[track_caller]
384 fn input_wires_arr<const N: usize>(&self) -> [Wire; N] {
385 collect_array(self.input_wires())
386 }
387
388 fn dfg_builder(
398 &mut self,
399 signature: Signature,
400 input_wires: impl IntoIterator<Item = Wire>,
401 ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
402 let op = ops::DFG {
403 signature: signature.clone(),
404 };
405 let (dfg_n, _) = add_node_with_wires(self, op, input_wires)?;
406
407 DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature)
408 }
409
410 fn dfg_builder_endo(
415 &mut self,
416 inputs: impl IntoIterator<Item = (Type, Wire)>,
417 ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
418 let (types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
419 self.dfg_builder(Signature::new_endo(types), input_wires)
420 }
421
422 fn cfg_builder(
433 &mut self,
434 inputs: impl IntoIterator<Item = (Type, Wire)>,
435 output_types: TypeRow,
436 ) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
437 let (input_types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
438
439 let inputs: TypeRow = input_types.into();
440
441 let (cfg_node, _) = add_node_with_wires(
442 self,
443 ops::CFG {
444 signature: Signature::new(inputs.clone(), output_types.clone()),
445 },
446 input_wires,
447 )?;
448 CFGBuilder::create(self.hugr_mut(), cfg_node, inputs, output_types)
449 }
450
451 fn load_const(&mut self, cid: &ConstID) -> Wire {
454 let const_node = cid.node();
455 let nodetype = self.hugr().get_optype(const_node);
456 let op: ops::Const = nodetype
457 .clone()
458 .try_into()
459 .expect("ConstID does not refer to Const op.");
460
461 let load_n = self
462 .add_dataflow_op(
463 ops::LoadConstant {
464 datatype: op.get_type().clone(),
465 },
466 vec![Wire::new(const_node, OutgoingPort::from(0))],
468 )
469 .expect("The constant type should match the LoadConstant type.");
470
471 load_n.out_wire(0)
472 }
473
474 fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Wire {
477 let cid = self.add_constant(constant);
478 self.load_const(&cid)
479 }
480
481 fn add_load_value(&mut self, constant: impl Into<ops::Value>) -> Wire {
484 self.add_load_const(constant.into())
485 }
486
487 fn load_func<const DEFINED: bool>(
493 &mut self,
494 fid: &FuncID<DEFINED>,
495 type_args: &[TypeArg],
496 ) -> Result<Wire, BuildError> {
497 let func_node = fid.node();
498 let func_op = self.hugr().get_optype(func_node);
499 let func_sig = match func_op {
500 OpType::FuncDefn(fd) => fd.signature().clone(),
501 OpType::FuncDecl(fd) => fd.signature().clone(),
502 _ => {
503 return Err(BuildError::UnexpectedType {
504 node: func_node,
505 op_desc: "FuncDecl/FuncDefn",
506 });
507 }
508 };
509
510 let load_n = self.add_dataflow_op(
511 ops::LoadFunction::try_new(func_sig, type_args)?,
512 vec![Wire::new(func_node, func_op.static_output_port().unwrap())],
514 )?;
515
516 Ok(load_n.out_wire(0))
517 }
518
519 fn tail_loop_builder(
530 &mut self,
531 just_inputs: impl IntoIterator<Item = (Type, Wire)>,
532 inputs_outputs: impl IntoIterator<Item = (Type, Wire)>,
533 just_out_types: TypeRow,
534 ) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
535 let (input_types, mut input_wires): (Vec<Type>, Vec<Wire>) =
536 just_inputs.into_iter().unzip();
537 let (rest_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
538 inputs_outputs.into_iter().unzip();
539 input_wires.extend(rest_input_wires);
540
541 let tail_loop = ops::TailLoop {
542 just_inputs: input_types.into(),
543 just_outputs: just_out_types,
544 rest: rest_types.into(),
545 };
546 let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?;
548
549 TailLoopBuilder::create_with_io(self.hugr_mut(), loop_node, &tail_loop)
550 }
551
552 fn conditional_builder(
565 &mut self,
566 (sum_rows, sum_wire): (impl IntoIterator<Item = TypeRow>, Wire),
567 other_inputs: impl IntoIterator<Item = (Type, Wire)>,
568 output_types: TypeRow,
569 ) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
570 let mut input_wires = vec![sum_wire];
571 let (input_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
572 other_inputs.into_iter().unzip();
573
574 input_wires.extend(rest_input_wires);
575 let inputs: TypeRow = input_types.into();
576 let sum_rows: Vec<_> = sum_rows.into_iter().collect();
577 let n_cases = sum_rows.len();
578 let n_out_wires = output_types.len();
579
580 let conditional_id = self.add_dataflow_op(
581 ops::Conditional {
582 sum_rows,
583 other_inputs: inputs,
584 outputs: output_types,
585 },
586 input_wires,
587 )?;
588
589 Ok(ConditionalBuilder {
590 base: self.hugr_mut(),
591 conditional_node: conditional_id.node(),
592 n_out_wires,
593 case_nodes: vec![None; n_cases],
594 })
595 }
596
597 fn set_order(&mut self, before: &impl NodeHandle, after: &impl NodeHandle) {
600 self.add_other_wire(before.node(), after.node());
601 }
602
603 fn get_wire_type(&self, wire: Wire) -> Result<Type, BuildError> {
605 let kind = self.hugr().get_optype(wire.node()).port_kind(wire.source());
606
607 if let Some(EdgeKind::Value(typ)) = kind {
608 Ok(typ)
609 } else {
610 Err(BuildError::WireNotFound(wire))
611 }
612 }
613
614 fn make_tuple(&mut self, values: impl IntoIterator<Item = Wire>) -> Result<Wire, BuildError> {
622 let values = values.into_iter().collect_vec();
623 let types: Result<Vec<Type>, _> = values
624 .iter()
625 .map(|&wire| self.get_wire_type(wire))
626 .collect();
627 let types = types?.into();
628 let make_op = self.add_dataflow_op(MakeTuple(types), values)?;
629 Ok(make_op.out_wire(0))
630 }
631
632 fn make_sum(
642 &mut self,
643 tag: usize,
644 variants: impl IntoIterator<Item = TypeRow>,
645 values: impl IntoIterator<Item = Wire>,
646 ) -> Result<Wire, BuildError> {
647 let make_op = self.add_dataflow_op(
648 Tag {
649 tag,
650 variants: variants.into_iter().collect_vec(),
651 },
652 values.into_iter().collect_vec(),
653 )?;
654 Ok(make_op.out_wire(0))
655 }
656
657 fn make_continue(
667 &mut self,
668 tail_loop: ops::TailLoop,
669 values: impl IntoIterator<Item = Wire>,
670 ) -> Result<Wire, BuildError> {
671 self.make_sum(
672 TailLoop::CONTINUE_TAG,
673 [tail_loop.just_inputs, tail_loop.just_outputs],
674 values,
675 )
676 }
677
678 fn make_break(
688 &mut self,
689 loop_op: ops::TailLoop,
690 values: impl IntoIterator<Item = Wire>,
691 ) -> Result<Wire, BuildError> {
692 self.make_sum(
693 TailLoop::BREAK_TAG,
694 [loop_op.just_inputs, loop_op.just_outputs],
695 values,
696 )
697 }
698
699 fn call<const DEFINED: bool>(
708 &mut self,
709 function: &FuncID<DEFINED>,
710 type_args: &[TypeArg],
711 input_wires: impl IntoIterator<Item = Wire>,
712 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
713 let hugr = self.hugr();
714 let def_op = hugr.get_optype(function.node());
715 let type_scheme = match def_op {
716 OpType::FuncDefn(fd) => fd.signature().clone(),
717 OpType::FuncDecl(fd) => fd.signature().clone(),
718 _ => {
719 return Err(BuildError::UnexpectedType {
720 node: function.node(),
721 op_desc: "FuncDecl/FuncDefn",
722 });
723 }
724 };
725 let op: OpType = ops::Call::try_new(type_scheme, type_args)?.into();
726 let const_in_port = op.static_input_port().unwrap();
727 let op_id = self.add_dataflow_op(op, input_wires)?;
728 let src_port = self.hugr_mut().num_outputs(function.node()) - 1;
729
730 self.hugr_mut()
731 .connect(function.node(), src_port, op_id.node(), const_in_port);
732 Ok(op_id)
733 }
734
735 fn as_circuit(&mut self, wires: impl IntoIterator<Item = Wire>) -> CircuitBuilder<'_, Self> {
738 CircuitBuilder::new(wires, self)
739 }
740
741 fn add_barrier(
750 &mut self,
751 wires: impl IntoIterator<Item = Wire>,
752 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
753 let wires = wires.into_iter().collect_vec();
754 let types: Result<Vec<Type>, _> =
755 wires.iter().map(|&wire| self.get_wire_type(wire)).collect();
756 let types = types?;
757 let barrier_op =
758 self.add_dataflow_op(crate::extension::prelude::Barrier::new(types), wires)?;
759 Ok(barrier_op)
760 }
761}
762
763fn add_node_with_wires<T: Dataflow + ?Sized>(
772 data_builder: &mut T,
773 nodetype: impl Into<OpType>,
774 inputs: impl IntoIterator<Item = Wire>,
775) -> Result<(Node, usize), BuildError> {
776 let op: OpType = nodetype.into();
777 let num_outputs = op.value_output_count();
778 let op_node = data_builder.add_child_node(op.clone());
779
780 wire_up_inputs(inputs, op_node, data_builder).map_err(|error| BuildError::OperationWiring {
781 op: Box::new(op),
782 error,
783 })?;
784
785 Ok((op_node, num_outputs))
786}
787
788fn wire_up_inputs<T: Dataflow + ?Sized>(
796 inputs: impl IntoIterator<Item = Wire>,
797 op_node: Node,
798 data_builder: &mut T,
799) -> Result<(), BuilderWiringError> {
800 for (dst_port, wire) in inputs.into_iter().enumerate() {
801 wire_up(data_builder, wire.node(), wire.source(), op_node, dst_port)?;
802 }
803 Ok(())
804}
805
806fn wire_ins_return_outs<T: Dataflow + ?Sized>(
807 inputs: impl IntoIterator<Item = Wire>,
808 node: Node,
809 data_builder: &mut T,
810) -> Result<BuildHandle<DataflowOpID>, BuildError> {
811 let op = data_builder.hugr().get_optype(node).clone();
812 let num_outputs = op.value_output_count();
813 wire_up_inputs(inputs, node, data_builder).map_err(|error| BuildError::OperationWiring {
814 op: Box::new(op),
815 error,
816 })?;
817 Ok((node, num_outputs).into())
818}
819
820fn wire_up<T: Dataflow + ?Sized>(
826 data_builder: &mut T,
827 src: Node,
828 src_port: impl Into<OutgoingPort>,
829 dst: Node,
830 dst_port: impl Into<IncomingPort>,
831) -> Result<bool, BuilderWiringError> {
832 let src_port = src_port.into();
833 let dst_port = dst_port.into();
834 let base = data_builder.hugr_mut();
835
836 let src_parent = base.get_parent(src);
837 let src_parent_parent = src_parent.and_then(|src| base.get_parent(src));
838 let dst_parent = base.get_parent(dst);
839 let local_source = src_parent == dst_parent;
840 if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() {
841 if !local_source {
842 if !typ.copyable() {
844 return Err(BuilderWiringError::NonCopyableIntergraph {
845 src,
846 src_offset: src_port.into(),
847 dst,
848 dst_offset: dst_port.into(),
849 typ: Box::new(typ),
850 });
851 }
852
853 let src_parent = src_parent.expect("Node has no parent");
854 let Some(src_sibling) = iter::successors(dst_parent, |&p| base.get_parent(p))
855 .tuple_windows()
856 .find_map(|(ancestor, ancestor_parent)| {
857 (ancestor_parent == src_parent ||
858 Some(ancestor_parent) == src_parent_parent)
860 .then_some(ancestor)
861 })
862 else {
863 return Err(BuilderWiringError::NoRelationIntergraph {
864 src,
865 src_offset: src_port.into(),
866 dst,
867 dst_offset: dst_port.into(),
868 });
869 };
870
871 if !OpTag::ControlFlowChild.is_superset(base.get_optype(src).tag())
872 && !OpTag::ControlFlowChild.is_superset(base.get_optype(src_sibling).tag())
873 {
874 base.add_other_edge(src, src_sibling);
876 }
877 } else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() {
878 return Err(BuilderWiringError::NoCopyLinear {
880 typ: Box::new(typ),
881 src,
882 src_offset: src_port.into(),
883 });
884 }
885 }
886
887 data_builder
888 .hugr_mut()
889 .connect(src, src_port, dst, dst_port);
890 Ok(local_source
891 && matches!(
892 data_builder
893 .hugr_mut()
894 .get_optype(dst)
895 .port_kind(dst_port)
896 .unwrap(),
897 EdgeKind::Value(_)
898 ))
899}
900
901pub trait DataflowHugr: HugrBuilder + Dataflow {
903 fn finish_hugr_with_outputs(
909 mut self,
910 outputs: impl IntoIterator<Item = Wire>,
911 ) -> Result<Hugr, BuildError>
912 where
913 Self: Sized,
914 {
915 self.set_outputs(outputs)?;
916 Ok(self.finish_hugr()?)
917 }
918}
919
920pub trait DataflowSubContainer: SubContainer + Dataflow {
922 fn finish_with_outputs(
929 mut self,
930 outputs: impl IntoIterator<Item = Wire>,
931 ) -> Result<Self::ContainerHandle, BuildError>
932 where
933 Self: Sized,
934 {
935 self.set_outputs(outputs)?;
936 self.finish_sub_container()
937 }
938}
939
940impl<T: HugrBuilder + Dataflow> DataflowHugr for T {}
941impl<T: SubContainer + Dataflow> DataflowSubContainer for T {}