1use crate::extension::prelude::MakeTuple;
2use crate::hugr::hugrmut::InsertionResult;
3use crate::hugr::views::HugrView;
4use crate::hugr::{NodeMetadata, ValidationError};
5use crate::ops::{self, OpTag, OpTrait, OpType, Tag, TailLoop};
6use crate::utils::collect_array;
7use crate::{Extension, IncomingPort, Node, OutgoingPort};
8
9use std::iter;
10use std::sync::Arc;
11
12use super::{
13 handle::{BuildHandle, Outputs},
14 CircuitBuilder,
15};
16use super::{BuilderWiringError, FunctionBuilder};
17
18use crate::{
19 ops::handle::{ConstID, DataflowOpID, FuncID, NodeHandle},
20 types::EdgeKind,
21};
22
23use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED};
24use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeRow};
25
26use itertools::Itertools;
27
28use super::{
29 cfg::CFGBuilder, conditional::ConditionalBuilder, dataflow::DFGBuilder,
30 tail_loop::TailLoopBuilder, BuildError, Wire,
31};
32
33use crate::Hugr;
34
35use crate::hugr::HugrMut;
36
37pub trait Container {
42 fn container_node(&self) -> Node;
44 fn hugr_mut(&mut self) -> &mut Hugr;
46 fn hugr(&self) -> &Hugr;
48 fn add_child_node(&mut self, node: impl Into<OpType>) -> Node {
52 let node: OpType = node.into();
53
54 let used_extensions = node
56 .used_extensions()
57 .unwrap_or_else(|e| panic!("Build-time signatures should have valid extensions. {e}"));
58 self.use_extensions(used_extensions);
59
60 let parent = self.container_node();
61 self.hugr_mut().add_node_with_parent(parent, node)
62 }
63
64 fn add_other_wire(&mut self, src: Node, dst: Node) -> Wire {
69 let (src_port, _) = self.hugr_mut().add_other_edge(src, dst);
70 Wire::new(src, src_port)
71 }
72
73 fn add_constant(&mut self, constant: impl Into<ops::Const>) -> ConstID {
82 self.add_child_node(constant.into()).into()
83 }
84
85 fn define_function(
93 &mut self,
94 name: impl Into<String>,
95 signature: impl Into<PolyFuncType>,
96 ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
97 let signature = signature.into();
98 let body = signature.body().clone();
99 let f_node = self.add_child_node(ops::FuncDefn {
100 name: name.into(),
101 signature,
102 });
103
104 self.use_extensions(
106 body.used_extensions().unwrap_or_else(|e| {
107 panic!("Build-time signatures should have valid extensions. {e}")
108 }),
109 );
110
111 let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?;
112 Ok(FunctionBuilder::from_dfg_builder(db))
113 }
114
115 fn add_hugr(&mut self, child: Hugr) -> InsertionResult {
117 let parent = self.container_node();
118 self.hugr_mut().insert_hugr(parent, child)
119 }
120
121 fn add_hugr_view(&mut self, child: &impl HugrView) -> InsertionResult {
123 let parent = self.container_node();
124 self.hugr_mut().insert_from_view(parent, child)
125 }
126
127 fn set_metadata(&mut self, key: impl AsRef<str>, meta: impl Into<NodeMetadata>) {
129 let parent = self.container_node();
130 self.hugr_mut().set_metadata(parent, key, meta);
132 }
133
134 fn set_child_metadata(
138 &mut self,
139 child: Node,
140 key: impl AsRef<str>,
141 meta: impl Into<NodeMetadata>,
142 ) {
143 self.hugr_mut().set_metadata(child, key, meta);
144 }
145
146 fn use_extension(&mut self, ext: impl Into<Arc<Extension>>) {
148 self.hugr_mut().use_extension(ext);
149 }
150
151 fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>)
153 where
154 ExtensionRegistry: Extend<Reg>,
155 {
156 self.hugr_mut().extensions_mut().extend(registry);
157 }
158}
159
160pub trait HugrBuilder: Container {
163 fn finish_hugr(self) -> Result<Hugr, ValidationError>;
165}
166
167pub trait SubContainer: Container {
169 type ContainerHandle;
172 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError>;
175}
176pub trait Dataflow: Container {
178 fn num_inputs(&self) -> usize;
180 fn io(&self) -> [Node; 2] {
182 self.hugr()
183 .children(self.container_node())
184 .take(2)
185 .collect_vec()
186 .try_into()
187 .expect("First two children should be IO")
188 }
189 fn input(&self) -> BuildHandle<DataflowOpID> {
191 (self.io()[0], self.num_inputs()).into()
192 }
193 fn output(&self) -> DataflowOpID {
195 self.io()[1].into()
196 }
197 fn input_wires(&self) -> Outputs {
199 self.input().outputs()
200 }
201 fn add_dataflow_op(
210 &mut self,
211 nodetype: impl Into<OpType>,
212 input_wires: impl IntoIterator<Item = Wire>,
213 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
214 let outs = add_node_with_wires(self, nodetype, input_wires)?;
215
216 Ok(outs.into())
217 }
218
219 fn add_hugr_with_wires(
227 &mut self,
228 hugr: Hugr,
229 input_wires: impl IntoIterator<Item = Wire>,
230 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
231 let optype = hugr.get_optype(hugr.root()).clone();
232 let num_outputs = optype.value_output_count();
233 let node = self.add_hugr(hugr).new_root;
234
235 wire_up_inputs(input_wires, node, self)
236 .map_err(|error| BuildError::OperationWiring { op: optype, error })?;
237
238 Ok((node, num_outputs).into())
239 }
240
241 fn add_hugr_view_with_wires(
249 &mut self,
250 hugr: &impl HugrView,
251 input_wires: impl IntoIterator<Item = Wire>,
252 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
253 let node = self.add_hugr_view(hugr).new_root;
254 let optype = hugr.get_optype(hugr.root()).clone();
255 let num_outputs = optype.value_output_count();
256
257 wire_up_inputs(input_wires, node, self)
258 .map_err(|error| BuildError::OperationWiring { op: optype, error })?;
259
260 Ok((node, num_outputs).into())
261 }
262
263 fn set_outputs(
269 &mut self,
270 output_wires: impl IntoIterator<Item = Wire>,
271 ) -> Result<(), BuildError> {
272 let [_, out] = self.io();
273 wire_up_inputs(output_wires.into_iter().collect_vec(), out, self).map_err(|error| {
274 BuildError::OutputWiring {
275 container_op: self.hugr().get_optype(self.container_node()).clone(),
276 container_node: self.container_node(),
277 error,
278 }
279 })
280 }
281
282 fn input_wires_arr<const N: usize>(&self) -> [Wire; N] {
288 collect_array(self.input_wires())
289 }
290
291 fn dfg_builder(
301 &mut self,
302 signature: Signature,
303 input_wires: impl IntoIterator<Item = Wire>,
304 ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
305 let op = ops::DFG {
306 signature: signature.clone(),
307 };
308 let (dfg_n, _) = add_node_with_wires(self, op, input_wires)?;
309
310 DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature)
311 }
312
313 fn dfg_builder_endo(
318 &mut self,
319 inputs: impl IntoIterator<Item = (Type, Wire)>,
320 ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
321 let (types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
322 self.dfg_builder(
323 Signature::new_endo(types).with_extension_delta(TO_BE_INFERRED),
324 input_wires,
325 )
326 }
327
328 fn cfg_builder(
340 &mut self,
341 inputs: impl IntoIterator<Item = (Type, Wire)>,
342 output_types: TypeRow,
343 ) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
344 self.cfg_builder_exts(inputs, output_types, TO_BE_INFERRED)
345 }
346
347 fn cfg_builder_exts(
360 &mut self,
361 inputs: impl IntoIterator<Item = (Type, Wire)>,
362 output_types: TypeRow,
363 extension_delta: impl Into<ExtensionSet>,
364 ) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
365 let (input_types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
366
367 let inputs: TypeRow = input_types.into();
368
369 let (cfg_node, _) = add_node_with_wires(
370 self,
371 ops::CFG {
372 signature: Signature::new(inputs.clone(), output_types.clone())
373 .with_extension_delta(extension_delta),
374 },
375 input_wires,
376 )?;
377 CFGBuilder::create(self.hugr_mut(), cfg_node, inputs, output_types)
378 }
379
380 fn load_const(&mut self, cid: &ConstID) -> Wire {
383 let const_node = cid.node();
384 let nodetype = self.hugr().get_optype(const_node);
385 let op: ops::Const = nodetype
386 .clone()
387 .try_into()
388 .expect("ConstID does not refer to Const op.");
389
390 let load_n = self
391 .add_dataflow_op(
392 ops::LoadConstant {
393 datatype: op.get_type().clone(),
394 },
395 vec![Wire::new(const_node, OutgoingPort::from(0))],
397 )
398 .expect("The constant type should match the LoadConstant type.");
399
400 load_n.out_wire(0)
401 }
402
403 fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Wire {
406 let cid = self.add_constant(constant);
407 self.load_const(&cid)
408 }
409
410 fn add_load_value(&mut self, constant: impl Into<ops::Value>) -> Wire {
413 self.add_load_const(constant.into())
414 }
415
416 fn load_func<const DEFINED: bool>(
422 &mut self,
423 fid: &FuncID<DEFINED>,
424 type_args: &[TypeArg],
425 ) -> Result<Wire, BuildError> {
426 let func_node = fid.node();
427 let func_op = self.hugr().get_optype(func_node);
428 let func_sig = match func_op {
429 OpType::FuncDefn(ops::FuncDefn { signature, .. })
430 | OpType::FuncDecl(ops::FuncDecl { signature, .. }) => signature.clone(),
431 _ => {
432 return Err(BuildError::UnexpectedType {
433 node: func_node,
434 op_desc: "FuncDecl/FuncDefn",
435 })
436 }
437 };
438
439 let load_n = self.add_dataflow_op(
440 ops::LoadFunction::try_new(func_sig, type_args)?,
441 vec![Wire::new(func_node, func_op.static_output_port().unwrap())],
443 )?;
444
445 Ok(load_n.out_wire(0))
446 }
447
448 fn tail_loop_builder(
460 &mut self,
461 just_inputs: impl IntoIterator<Item = (Type, Wire)>,
462 inputs_outputs: impl IntoIterator<Item = (Type, Wire)>,
463 just_out_types: TypeRow,
464 ) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
465 self.tail_loop_builder_exts(just_inputs, inputs_outputs, just_out_types, TO_BE_INFERRED)
466 }
467
468 fn tail_loop_builder_exts(
480 &mut self,
481 just_inputs: impl IntoIterator<Item = (Type, Wire)>,
482 inputs_outputs: impl IntoIterator<Item = (Type, Wire)>,
483 just_out_types: TypeRow,
484 extension_delta: impl Into<ExtensionSet>,
485 ) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
486 let (input_types, mut input_wires): (Vec<Type>, Vec<Wire>) =
487 just_inputs.into_iter().unzip();
488 let (rest_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
489 inputs_outputs.into_iter().unzip();
490 input_wires.extend(rest_input_wires);
491
492 let tail_loop = ops::TailLoop {
493 just_inputs: input_types.into(),
494 just_outputs: just_out_types,
495 rest: rest_types.into(),
496 extension_delta: extension_delta.into(),
497 };
498 let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?;
500
501 TailLoopBuilder::create_with_io(self.hugr_mut(), loop_node, &tail_loop)
502 }
503
504 fn conditional_builder(
517 &mut self,
518 sum_input: (impl IntoIterator<Item = TypeRow>, Wire),
519 other_inputs: impl IntoIterator<Item = (Type, Wire)>,
520 output_types: TypeRow,
521 ) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
522 self.conditional_builder_exts(sum_input, other_inputs, output_types, TO_BE_INFERRED)
523 }
524
525 fn conditional_builder_exts(
540 &mut self,
541 (sum_rows, sum_wire): (impl IntoIterator<Item = TypeRow>, Wire),
542 other_inputs: impl IntoIterator<Item = (Type, Wire)>,
543 output_types: TypeRow,
544 extension_delta: impl Into<ExtensionSet>,
545 ) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
546 let mut input_wires = vec![sum_wire];
547 let (input_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
548 other_inputs.into_iter().unzip();
549
550 input_wires.extend(rest_input_wires);
551 let inputs: TypeRow = input_types.into();
552 let sum_rows: Vec<_> = sum_rows.into_iter().collect();
553 let n_cases = sum_rows.len();
554 let n_out_wires = output_types.len();
555
556 let conditional_id = self.add_dataflow_op(
557 ops::Conditional {
558 sum_rows,
559 other_inputs: inputs,
560 outputs: output_types,
561 extension_delta: extension_delta.into(),
562 },
563 input_wires,
564 )?;
565
566 Ok(ConditionalBuilder {
567 base: self.hugr_mut(),
568 conditional_node: conditional_id.node(),
569 n_out_wires,
570 case_nodes: vec![None; n_cases],
571 })
572 }
573
574 fn set_order(&mut self, before: &impl NodeHandle, after: &impl NodeHandle) {
577 self.add_other_wire(before.node(), after.node());
578 }
579
580 fn get_wire_type(&self, wire: Wire) -> Result<Type, BuildError> {
582 let kind = self.hugr().get_optype(wire.node()).port_kind(wire.source());
583
584 if let Some(EdgeKind::Value(typ)) = kind {
585 Ok(typ)
586 } else {
587 Err(BuildError::WireNotFound(wire))
588 }
589 }
590
591 fn make_tuple(&mut self, values: impl IntoIterator<Item = Wire>) -> Result<Wire, BuildError> {
599 let values = values.into_iter().collect_vec();
600 let types: Result<Vec<Type>, _> = values
601 .iter()
602 .map(|&wire| self.get_wire_type(wire))
603 .collect();
604 let types = types?.into();
605 let make_op = self.add_dataflow_op(MakeTuple(types), values)?;
606 Ok(make_op.out_wire(0))
607 }
608
609 fn make_sum(
619 &mut self,
620 tag: usize,
621 variants: impl IntoIterator<Item = TypeRow>,
622 values: impl IntoIterator<Item = Wire>,
623 ) -> Result<Wire, BuildError> {
624 let make_op = self.add_dataflow_op(
625 Tag {
626 tag,
627 variants: variants.into_iter().collect_vec(),
628 },
629 values.into_iter().collect_vec(),
630 )?;
631 Ok(make_op.out_wire(0))
632 }
633
634 fn make_continue(
644 &mut self,
645 tail_loop: ops::TailLoop,
646 values: impl IntoIterator<Item = Wire>,
647 ) -> Result<Wire, BuildError> {
648 self.make_sum(
649 TailLoop::CONTINUE_TAG,
650 [tail_loop.just_inputs, tail_loop.just_outputs],
651 values,
652 )
653 }
654
655 fn make_break(
665 &mut self,
666 loop_op: ops::TailLoop,
667 values: impl IntoIterator<Item = Wire>,
668 ) -> Result<Wire, BuildError> {
669 self.make_sum(
670 TailLoop::BREAK_TAG,
671 [loop_op.just_inputs, loop_op.just_outputs],
672 values,
673 )
674 }
675
676 fn call<const DEFINED: bool>(
685 &mut self,
686 function: &FuncID<DEFINED>,
687 type_args: &[TypeArg],
688 input_wires: impl IntoIterator<Item = Wire>,
689 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
690 let hugr = self.hugr();
691 let def_op = hugr.get_optype(function.node());
692 let type_scheme = match def_op {
693 OpType::FuncDefn(ops::FuncDefn { signature, .. })
694 | OpType::FuncDecl(ops::FuncDecl { signature, .. }) => signature.clone(),
695 _ => {
696 return Err(BuildError::UnexpectedType {
697 node: function.node(),
698 op_desc: "FuncDecl/FuncDefn",
699 })
700 }
701 };
702 let op: OpType = ops::Call::try_new(type_scheme, type_args)?.into();
703 let const_in_port = op.static_input_port().unwrap();
704 let op_id = self.add_dataflow_op(op, input_wires)?;
705 let src_port = self.hugr_mut().num_outputs(function.node()) - 1;
706
707 self.hugr_mut()
708 .connect(function.node(), src_port, op_id.node(), const_in_port);
709 Ok(op_id)
710 }
711
712 fn as_circuit(&mut self, wires: impl IntoIterator<Item = Wire>) -> CircuitBuilder<Self> {
715 CircuitBuilder::new(wires, self)
716 }
717
718 fn add_barrier(
727 &mut self,
728 wires: impl IntoIterator<Item = Wire>,
729 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
730 let wires = wires.into_iter().collect_vec();
731 let types: Result<Vec<Type>, _> =
732 wires.iter().map(|&wire| self.get_wire_type(wire)).collect();
733 let types = types?;
734 let barrier_op =
735 self.add_dataflow_op(crate::extension::prelude::Barrier::new(types), wires)?;
736 Ok(barrier_op)
737 }
738}
739
740fn add_node_with_wires<T: Dataflow + ?Sized>(
749 data_builder: &mut T,
750 nodetype: impl Into<OpType>,
751 inputs: impl IntoIterator<Item = Wire>,
752) -> Result<(Node, usize), BuildError> {
753 let op = nodetype.into();
754 let num_outputs = op.value_output_count();
755 let op_node = data_builder.add_child_node(op.clone());
756
757 wire_up_inputs(inputs, op_node, data_builder)
758 .map_err(|error| BuildError::OperationWiring { op, error })?;
759
760 Ok((op_node, num_outputs))
761}
762
763fn wire_up_inputs<T: Dataflow + ?Sized>(
771 inputs: impl IntoIterator<Item = Wire>,
772 op_node: Node,
773 data_builder: &mut T,
774) -> Result<(), BuilderWiringError> {
775 for (dst_port, wire) in inputs.into_iter().enumerate() {
776 wire_up(data_builder, wire.node(), wire.source(), op_node, dst_port)?;
777 }
778 Ok(())
779}
780
781fn wire_up<T: Dataflow + ?Sized>(
787 data_builder: &mut T,
788 src: Node,
789 src_port: impl Into<OutgoingPort>,
790 dst: Node,
791 dst_port: impl Into<IncomingPort>,
792) -> Result<bool, BuilderWiringError> {
793 let src_port = src_port.into();
794 let dst_port = dst_port.into();
795 let base = data_builder.hugr_mut();
796
797 let src_parent = base.get_parent(src);
798 let src_parent_parent = src_parent.and_then(|src| base.get_parent(src));
799 let dst_parent = base.get_parent(dst);
800 let local_source = src_parent == dst_parent;
801 if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() {
802 if !local_source {
803 if !typ.copyable() {
805 return Err(BuilderWiringError::NonCopyableIntergraph {
806 src,
807 src_offset: src_port.into(),
808 dst,
809 dst_offset: dst_port.into(),
810 typ,
811 });
812 }
813
814 let src_parent = src_parent.expect("Node has no parent");
815 let Some(src_sibling) = iter::successors(dst_parent, |&p| base.get_parent(p))
816 .tuple_windows()
817 .find_map(|(ancestor, ancestor_parent)| {
818 (ancestor_parent == src_parent ||
819 Some(ancestor_parent) == src_parent_parent)
821 .then_some(ancestor)
822 })
823 else {
824 return Err(BuilderWiringError::NoRelationIntergraph {
825 src,
826 src_offset: src_port.into(),
827 dst,
828 dst_offset: dst_port.into(),
829 });
830 };
831
832 if !OpTag::ControlFlowChild.is_superset(base.get_optype(src).tag())
833 && !OpTag::ControlFlowChild.is_superset(base.get_optype(src_sibling).tag())
834 {
835 base.add_other_edge(src, src_sibling);
837 }
838 } else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() {
839 return Err(BuilderWiringError::NoCopyLinear {
841 typ,
842 src,
843 src_offset: src_port.into(),
844 });
845 }
846 }
847
848 data_builder
849 .hugr_mut()
850 .connect(src, src_port, dst, dst_port);
851 Ok(local_source
852 && matches!(
853 data_builder
854 .hugr_mut()
855 .get_optype(dst)
856 .port_kind(dst_port)
857 .unwrap(),
858 EdgeKind::Value(_)
859 ))
860}
861
862pub trait DataflowHugr: HugrBuilder + Dataflow {
864 fn finish_hugr_with_outputs(
870 mut self,
871 outputs: impl IntoIterator<Item = Wire>,
872 ) -> Result<Hugr, BuildError>
873 where
874 Self: Sized,
875 {
876 self.set_outputs(outputs)?;
877 Ok(self.finish_hugr()?)
878 }
879}
880
881pub trait DataflowSubContainer: SubContainer + Dataflow {
883 fn finish_with_outputs(
890 mut self,
891 outputs: impl IntoIterator<Item = Wire>,
892 ) -> Result<Self::ContainerHandle, BuildError>
893 where
894 Self: Sized,
895 {
896 self.set_outputs(outputs)?;
897 self.finish_sub_container()
898 }
899}
900
901impl<T: HugrBuilder + Dataflow> DataflowHugr for T {}
902impl<T: SubContainer + Dataflow> DataflowSubContainer for T {}